Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ Documentation = "https://github.com/microsoft/typeagent-py/tree/main/docs/README
[tool.uv.build-backend]
module-root = "src"

[tool.uv.sources]
pytest-async-benchmark = { git = "https://github.com/KRRT7/pytest-async-benchmark.git", rev = "feat/pedantic-mode" }

[tool.pytest.ini_options]
asyncio_default_fixture_loop_scope = "function"
testpaths = ["tests"]
Expand Down Expand Up @@ -91,6 +94,7 @@ dev = [
"opentelemetry-instrumentation-httpx>=0.57b0",
"pyright>=1.1.408", # 407 has a regression
"pytest>=8.3.5",
"pytest-async-benchmark",
"pytest-asyncio>=0.26.0",
"pytest-benchmark>=5.1.0",
"pytest-mock>=3.14.0",
Expand Down
29 changes: 16 additions & 13 deletions src/typeagent/knowpro/answers.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,19 +452,22 @@ async def get_scored_semantic_refs_from_ordinals_iter(
semantic_ref_matches: list[ScoredSemanticRefOrdinal],
knowledge_type: KnowledgeType,
) -> list[Scored[SemanticRef]]:
result = []
for semantic_ref_match in semantic_ref_matches:
semantic_ref = await semantic_refs.get_item(
semantic_ref_match.semantic_ref_ordinal
)
if semantic_ref.knowledge.knowledge_type == knowledge_type:
result.append(
Scored(
item=semantic_ref,
score=semantic_ref_match.score,
)
)
return result
if not semantic_ref_matches:
return []
ordinals = [m.semantic_ref_ordinal for m in semantic_ref_matches]
metadata = await semantic_refs.get_metadata_multiple(ordinals)
matching = [
(sr_match, m.ordinal)
for sr_match, m in zip(semantic_ref_matches, metadata)
if m.knowledge_type == knowledge_type
]
if not matching:
return []
full_refs = await semantic_refs.get_multiple([o for _, o in matching])
return [
Scored(item=ref, score=sr_match.score)
for (sr_match, _), ref in zip(matching, full_refs)
]


def merge_scored_concrete_entities(
Expand Down
39 changes: 25 additions & 14 deletions src/typeagent/knowpro/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,13 +331,17 @@ async def group_matches_by_type(
self,
semantic_refs: ISemanticRefCollection,
) -> dict[KnowledgeType, "SemanticRefAccumulator"]:
matches = list(self)
if not matches:
return {}
ordinals = [match.value for match in matches]
metadata = await semantic_refs.get_metadata_multiple(ordinals)
groups: dict[KnowledgeType, SemanticRefAccumulator] = {}
for match in self:
semantic_ref = await semantic_refs.get_item(match.value)
group = groups.get(semantic_ref.knowledge.knowledge_type)
for match, m in zip(matches, metadata):
group = groups.get(m.knowledge_type)
if group is None:
group = SemanticRefAccumulator(self.search_term_matches)
groups[semantic_ref.knowledge.knowledge_type] = group
groups[m.knowledge_type] = group
group.set_match(match)
return groups

Expand All @@ -346,11 +350,14 @@ async def get_matches_in_scope(
semantic_refs: ISemanticRefCollection,
ranges_in_scope: "TextRangesInScope",
) -> "SemanticRefAccumulator":
matches = list(self)
if not matches:
return SemanticRefAccumulator(self.search_term_matches)
ordinals = [match.value for match in matches]
metadata = await semantic_refs.get_metadata_multiple(ordinals)
accumulator = SemanticRefAccumulator(self.search_term_matches)
for match in self:
if ranges_in_scope.is_range_in_scope(
(await semantic_refs.get_item(match.value)).range
):
for match, m in zip(matches, metadata):
if ranges_in_scope.is_range_in_scope(m.range):
accumulator.set_match(match)
return accumulator

Expand Down Expand Up @@ -519,12 +526,16 @@ def add_ranges(self, text_ranges: "list[TextRange] | TextRangeCollection") -> No
self.add_range(text_range)

def contains_range(self, inner_range: TextRange) -> bool:
# Since ranges are sorted by start, once we pass inner_range's start
# no further range can contain it.
for outer_range in self._ranges:
if outer_range.start > inner_range.start:
break
if inner_range in outer_range:
if not self._ranges:
return False
# Bisect on start only to find all ranges with start <= inner.start,
# then scan backwards — the most likely containing range has the
# largest start still <= inner's.
hi = bisect.bisect_right(
self._ranges, inner_range.start, key=lambda r: r.start
)
for i in range(hi - 1, -1, -1):
if inner_range in self._ranges[i]:
return True
return False

Expand Down
42 changes: 18 additions & 24 deletions src/typeagent/knowpro/interfaces_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from __future__ import annotations

from collections.abc import Sequence
from datetime import datetime as Datetime
from typing import (
Any,
Expand Down Expand Up @@ -168,6 +169,11 @@ async def add_term(
semantic_ref_ordinal: SemanticRefOrdinal | ScoredSemanticRefOrdinal,
) -> str: ...

async def add_terms_batch(
self,
terms: Sequence[tuple[str, SemanticRefOrdinal | ScoredSemanticRefOrdinal]],
) -> None: ...

async def remove_term(
self, term: str, semantic_ref_ordinal: SemanticRefOrdinal
) -> None: ...
Expand Down Expand Up @@ -249,32 +255,24 @@ def __repr__(self) -> str:
else:
return f"{self.__class__.__name__}({self.start}, {self.end})"

@staticmethod
def _effective_end(tr: "TextRange") -> tuple[int, int]:
"""Return (message_ordinal, chunk_ordinal) for the effective end."""
if tr.end is not None:
return (tr.end.message_ordinal, tr.end.chunk_ordinal)
return (tr.start.message_ordinal, tr.start.chunk_ordinal + 1)

def __eq__(self, other: object) -> bool:
if not isinstance(other, TextRange):
return NotImplemented

if self.start != other.start:
return False

# Get the effective end for both ranges
self_end = self.end or TextLocation(
self.start.message_ordinal, self.start.chunk_ordinal + 1
)
other_end = other.end or TextLocation(
other.start.message_ordinal, other.start.chunk_ordinal + 1
)
return self_end == other_end
return TextRange._effective_end(self) == TextRange._effective_end(other)

def __lt__(self, other: Self) -> bool:
if self.start != other.start:
return self.start < other.start
self_end = self.end or TextLocation(
self.start.message_ordinal, self.start.chunk_ordinal + 1
)
other_end = other.end or TextLocation(
other.start.message_ordinal, other.start.chunk_ordinal + 1
)
return self_end < other_end
return TextRange._effective_end(self) < TextRange._effective_end(other)

def __gt__(self, other: Self) -> bool:
return other.__lt__(self)
Expand All @@ -286,13 +284,9 @@ def __le__(self, other: Self) -> bool:
return not other.__lt__(self)

def __contains__(self, other: Self) -> bool:
other_end = other.end or TextLocation(
other.start.message_ordinal, other.start.chunk_ordinal + 1
)
self_end = self.end or TextLocation(
self.start.message_ordinal, self.start.chunk_ordinal + 1
)
return self.start <= other.start and other_end <= self_end
if not (self.start <= other.start):
return False
return TextRange._effective_end(other) <= TextRange._effective_end(self)

def serialize(self) -> TextRangeData:
return self.__pydantic_serializer__.to_python( # type: ignore
Expand Down
5 changes: 5 additions & 0 deletions src/typeagent/knowpro/interfaces_indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ async def add_property(
semantic_ref_ordinal: SemanticRefOrdinal | ScoredSemanticRefOrdinal,
) -> None: ...

async def add_properties_batch(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In propindex.py:85 the SQLite implementation imports make_property_term_text / split_property_term_text from storage.memory.propindex. This creates a cross-layer dependency (sqlite → memory). These helpers should be extracted to a shared location (e.g., the base knowpro/propindex or a utils module).

self,
properties: Sequence[tuple[str, str, SemanticRefOrdinal | ScoredSemanticRefOrdinal]],
) -> None: ...

async def lookup_property(
self, property_name: str, value: str
) -> list[ScoredSemanticRefOrdinal] | None: ...
Expand Down
19 changes: 18 additions & 1 deletion src/typeagent/knowpro/interfaces_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@

from collections.abc import AsyncIterable, Iterable
from datetime import datetime as Datetime
from typing import Any, Protocol, Self
from typing import Any, NamedTuple, Protocol, Self

from pydantic.dataclasses import dataclass

from .interfaces_core import (
IMessage,
ITermToSemanticRefIndex,
KnowledgeType,
MessageOrdinal,
SemanticRef,
SemanticRefOrdinal,
TextRange,
)
from .interfaces_indexes import (
IConversationSecondaryIndexes,
Expand Down Expand Up @@ -57,6 +59,14 @@ class ConversationMetadata:
extra: dict[str, str] | None = None


class SemanticRefMetadata(NamedTuple):
"""Lightweight metadata for filtering without full knowledge deserialization."""

ordinal: SemanticRefOrdinal
range: TextRange
knowledge_type: KnowledgeType


class IReadonlyCollection[T, TOrdinal](AsyncIterable[T], Protocol):
async def size(self) -> int: ...

Expand Down Expand Up @@ -91,6 +101,12 @@ class IMessageCollection[TMessage: IMessage](
class ISemanticRefCollection(ICollection[SemanticRef, SemanticRefOrdinal], Protocol):
"""A collection of SemanticRefs."""

async def get_metadata_multiple(
self, ordinals: list[SemanticRefOrdinal]
) -> list[SemanticRefMetadata]:
"""Batch-fetch lightweight metadata without deserializing knowledge."""
...


class IStorageProvider[TMessage: IMessage](Protocol):
"""API spec for storage providers -- maybe in-memory or persistent."""
Expand Down Expand Up @@ -190,4 +206,5 @@ class IConversation[
"ISemanticRefCollection",
"IStorageProvider",
"STATUS_INGESTED",
"SemanticRefMetadata",
]
18 changes: 8 additions & 10 deletions src/typeagent/knowpro/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
ScoredSemanticRefOrdinal,
SearchTerm,
SemanticRef,
SemanticRefMetadata,
SemanticRefOrdinal,
SemanticRefSearchResult,
Term,
Expand Down Expand Up @@ -174,17 +175,14 @@ async def lookup_term_filtered(
semantic_ref_index: ITermToSemanticRefIndex,
term: Term,
semantic_refs: ISemanticRefCollection,
filter: Callable[[SemanticRef, ScoredSemanticRefOrdinal], bool],
filter: Callable[[SemanticRefMetadata, ScoredSemanticRefOrdinal], bool],
) -> list[ScoredSemanticRefOrdinal] | None:
"""Look up a term in the semantic reference index and filter the results."""
scored_refs = await semantic_ref_index.lookup_term(term.text)
if scored_refs:
filtered = []
for sr in scored_refs:
semantic_ref = await semantic_refs.get_item(sr.semantic_ref_ordinal)
if filter(semantic_ref, sr):
filtered.append(sr)
return filtered
ordinals = [sr.semantic_ref_ordinal for sr in scored_refs]
metadata = await semantic_refs.get_metadata_multiple(ordinals)
return [sr for sr, m in zip(scored_refs, metadata) if filter(m, sr)]
return None


Expand All @@ -202,10 +200,10 @@ async def lookup_term(
semantic_ref_index,
term,
semantic_refs,
lambda sr, _: (
not knowledge_type or sr.knowledge.knowledge_type == knowledge_type
lambda m, _: (
not knowledge_type or m.knowledge_type == knowledge_type
)
and ranges_in_scope.is_range_in_scope(sr.range),
and ranges_in_scope.is_range_in_scope(m.range),
)
return await semantic_ref_index.lookup_term(term.text)

Expand Down
13 changes: 13 additions & 0 deletions src/typeagent/storage/memory/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
IMessage,
MessageOrdinal,
SemanticRef,
SemanticRefMetadata,
SemanticRefOrdinal,
)

Expand Down Expand Up @@ -63,6 +64,18 @@ async def extend(self, items: Iterable[T]) -> None:
class MemorySemanticRefCollection(MemoryCollection[SemanticRef, SemanticRefOrdinal]):
"""A collection of semantic references."""

async def get_metadata_multiple(
self, ordinals: list[SemanticRefOrdinal]
) -> list[SemanticRefMetadata]:
return [
SemanticRefMetadata(
ordinal=o,
range=self.items[o].range,
knowledge_type=self.items[o].knowledge.knowledge_type,
)
for o in ordinals
]


class MemoryMessageCollection[TMessage: IMessage](
MemoryCollection[TMessage, MessageOrdinal]
Expand Down
Loading
Loading