diff --git a/pyproject.toml b/pyproject.toml index f11c5905..73e17a65 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,9 @@ Documentation = "https://github.com/microsoft/typeagent-py/tree/main/docs/README [tool.uv.build-backend] module-root = "src" +[tool.uv.sources] +pytest-async-benchmark = { git = "https://github.com/KRRT7/pytest-async-benchmark.git", rev = "feat/pedantic-mode" } + [tool.pytest.ini_options] asyncio_default_fixture_loop_scope = "function" testpaths = ["tests"] @@ -91,6 +94,7 @@ dev = [ "opentelemetry-instrumentation-httpx>=0.57b0", "pyright>=1.1.408", # 407 has a regression "pytest>=8.3.5", + "pytest-async-benchmark", "pytest-asyncio>=0.26.0", "pytest-benchmark>=5.1.0", "pytest-mock>=3.14.0", diff --git a/src/typeagent/knowpro/answers.py b/src/typeagent/knowpro/answers.py index eb77e12f..9e984e18 100644 --- a/src/typeagent/knowpro/answers.py +++ b/src/typeagent/knowpro/answers.py @@ -452,19 +452,22 @@ async def get_scored_semantic_refs_from_ordinals_iter( semantic_ref_matches: list[ScoredSemanticRefOrdinal], knowledge_type: KnowledgeType, ) -> list[Scored[SemanticRef]]: - result = [] - for semantic_ref_match in semantic_ref_matches: - semantic_ref = await semantic_refs.get_item( - semantic_ref_match.semantic_ref_ordinal - ) - if semantic_ref.knowledge.knowledge_type == knowledge_type: - result.append( - Scored( - item=semantic_ref, - score=semantic_ref_match.score, - ) - ) - return result + if not semantic_ref_matches: + return [] + ordinals = [m.semantic_ref_ordinal for m in semantic_ref_matches] + metadata = await semantic_refs.get_metadata_multiple(ordinals) + matching = [ + (sr_match, m.ordinal) + for sr_match, m in zip(semantic_ref_matches, metadata) + if m.knowledge_type == knowledge_type + ] + if not matching: + return [] + full_refs = await semantic_refs.get_multiple([o for _, o in matching]) + return [ + Scored(item=ref, score=sr_match.score) + for (sr_match, _), ref in zip(matching, full_refs) + ] def merge_scored_concrete_entities( diff --git a/src/typeagent/knowpro/collections.py b/src/typeagent/knowpro/collections.py index a2716577..d7c07b19 100644 --- a/src/typeagent/knowpro/collections.py +++ b/src/typeagent/knowpro/collections.py @@ -331,13 +331,17 @@ async def group_matches_by_type( self, semantic_refs: ISemanticRefCollection, ) -> dict[KnowledgeType, "SemanticRefAccumulator"]: + matches = list(self) + if not matches: + return {} + ordinals = [match.value for match in matches] + metadata = await semantic_refs.get_metadata_multiple(ordinals) groups: dict[KnowledgeType, SemanticRefAccumulator] = {} - for match in self: - semantic_ref = await semantic_refs.get_item(match.value) - group = groups.get(semantic_ref.knowledge.knowledge_type) + for match, m in zip(matches, metadata): + group = groups.get(m.knowledge_type) if group is None: group = SemanticRefAccumulator(self.search_term_matches) - groups[semantic_ref.knowledge.knowledge_type] = group + groups[m.knowledge_type] = group group.set_match(match) return groups @@ -346,11 +350,14 @@ async def get_matches_in_scope( semantic_refs: ISemanticRefCollection, ranges_in_scope: "TextRangesInScope", ) -> "SemanticRefAccumulator": + matches = list(self) + if not matches: + return SemanticRefAccumulator(self.search_term_matches) + ordinals = [match.value for match in matches] + metadata = await semantic_refs.get_metadata_multiple(ordinals) accumulator = SemanticRefAccumulator(self.search_term_matches) - for match in self: - if ranges_in_scope.is_range_in_scope( - (await semantic_refs.get_item(match.value)).range - ): + for match, m in zip(matches, metadata): + if ranges_in_scope.is_range_in_scope(m.range): accumulator.set_match(match) return accumulator @@ -519,12 +526,16 @@ def add_ranges(self, text_ranges: "list[TextRange] | TextRangeCollection") -> No self.add_range(text_range) def contains_range(self, inner_range: TextRange) -> bool: - # Since ranges are sorted by start, once we pass inner_range's start - # no further range can contain it. - for outer_range in self._ranges: - if outer_range.start > inner_range.start: - break - if inner_range in outer_range: + if not self._ranges: + return False + # Bisect on start only to find all ranges with start <= inner.start, + # then scan backwards — the most likely containing range has the + # largest start still <= inner's. + hi = bisect.bisect_right( + self._ranges, inner_range.start, key=lambda r: r.start + ) + for i in range(hi - 1, -1, -1): + if inner_range in self._ranges[i]: return True return False diff --git a/src/typeagent/knowpro/interfaces_core.py b/src/typeagent/knowpro/interfaces_core.py index 105e45b6..4dc8fc8e 100644 --- a/src/typeagent/knowpro/interfaces_core.py +++ b/src/typeagent/knowpro/interfaces_core.py @@ -4,6 +4,7 @@ from __future__ import annotations +from collections.abc import Sequence from datetime import datetime as Datetime from typing import ( Any, @@ -168,6 +169,11 @@ async def add_term( semantic_ref_ordinal: SemanticRefOrdinal | ScoredSemanticRefOrdinal, ) -> str: ... + async def add_terms_batch( + self, + terms: Sequence[tuple[str, SemanticRefOrdinal | ScoredSemanticRefOrdinal]], + ) -> None: ... + async def remove_term( self, term: str, semantic_ref_ordinal: SemanticRefOrdinal ) -> None: ... @@ -249,32 +255,24 @@ def __repr__(self) -> str: else: return f"{self.__class__.__name__}({self.start}, {self.end})" + @staticmethod + def _effective_end(tr: "TextRange") -> tuple[int, int]: + """Return (message_ordinal, chunk_ordinal) for the effective end.""" + if tr.end is not None: + return (tr.end.message_ordinal, tr.end.chunk_ordinal) + return (tr.start.message_ordinal, tr.start.chunk_ordinal + 1) + def __eq__(self, other: object) -> bool: if not isinstance(other, TextRange): return NotImplemented - if self.start != other.start: return False - - # Get the effective end for both ranges - self_end = self.end or TextLocation( - self.start.message_ordinal, self.start.chunk_ordinal + 1 - ) - other_end = other.end or TextLocation( - other.start.message_ordinal, other.start.chunk_ordinal + 1 - ) - return self_end == other_end + return TextRange._effective_end(self) == TextRange._effective_end(other) def __lt__(self, other: Self) -> bool: if self.start != other.start: return self.start < other.start - self_end = self.end or TextLocation( - self.start.message_ordinal, self.start.chunk_ordinal + 1 - ) - other_end = other.end or TextLocation( - other.start.message_ordinal, other.start.chunk_ordinal + 1 - ) - return self_end < other_end + return TextRange._effective_end(self) < TextRange._effective_end(other) def __gt__(self, other: Self) -> bool: return other.__lt__(self) @@ -286,13 +284,9 @@ def __le__(self, other: Self) -> bool: return not other.__lt__(self) def __contains__(self, other: Self) -> bool: - other_end = other.end or TextLocation( - other.start.message_ordinal, other.start.chunk_ordinal + 1 - ) - self_end = self.end or TextLocation( - self.start.message_ordinal, self.start.chunk_ordinal + 1 - ) - return self.start <= other.start and other_end <= self_end + if not (self.start <= other.start): + return False + return TextRange._effective_end(other) <= TextRange._effective_end(self) def serialize(self) -> TextRangeData: return self.__pydantic_serializer__.to_python( # type: ignore diff --git a/src/typeagent/knowpro/interfaces_indexes.py b/src/typeagent/knowpro/interfaces_indexes.py index a894ab88..6c348a01 100644 --- a/src/typeagent/knowpro/interfaces_indexes.py +++ b/src/typeagent/knowpro/interfaces_indexes.py @@ -59,6 +59,11 @@ async def add_property( semantic_ref_ordinal: SemanticRefOrdinal | ScoredSemanticRefOrdinal, ) -> None: ... + async def add_properties_batch( + self, + properties: Sequence[tuple[str, str, SemanticRefOrdinal | ScoredSemanticRefOrdinal]], + ) -> None: ... + async def lookup_property( self, property_name: str, value: str ) -> list[ScoredSemanticRefOrdinal] | None: ... diff --git a/src/typeagent/knowpro/interfaces_storage.py b/src/typeagent/knowpro/interfaces_storage.py index a82fe7ad..97f7b600 100644 --- a/src/typeagent/knowpro/interfaces_storage.py +++ b/src/typeagent/knowpro/interfaces_storage.py @@ -6,16 +6,18 @@ from collections.abc import AsyncIterable, Iterable from datetime import datetime as Datetime -from typing import Any, Protocol, Self +from typing import Any, NamedTuple, Protocol, Self from pydantic.dataclasses import dataclass from .interfaces_core import ( IMessage, ITermToSemanticRefIndex, + KnowledgeType, MessageOrdinal, SemanticRef, SemanticRefOrdinal, + TextRange, ) from .interfaces_indexes import ( IConversationSecondaryIndexes, @@ -57,6 +59,14 @@ class ConversationMetadata: extra: dict[str, str] | None = None +class SemanticRefMetadata(NamedTuple): + """Lightweight metadata for filtering without full knowledge deserialization.""" + + ordinal: SemanticRefOrdinal + range: TextRange + knowledge_type: KnowledgeType + + class IReadonlyCollection[T, TOrdinal](AsyncIterable[T], Protocol): async def size(self) -> int: ... @@ -91,6 +101,12 @@ class IMessageCollection[TMessage: IMessage]( class ISemanticRefCollection(ICollection[SemanticRef, SemanticRefOrdinal], Protocol): """A collection of SemanticRefs.""" + async def get_metadata_multiple( + self, ordinals: list[SemanticRefOrdinal] + ) -> list[SemanticRefMetadata]: + """Batch-fetch lightweight metadata without deserializing knowledge.""" + ... + class IStorageProvider[TMessage: IMessage](Protocol): """API spec for storage providers -- maybe in-memory or persistent.""" @@ -190,4 +206,5 @@ class IConversation[ "ISemanticRefCollection", "IStorageProvider", "STATUS_INGESTED", + "SemanticRefMetadata", ] diff --git a/src/typeagent/knowpro/query.py b/src/typeagent/knowpro/query.py index 44fa06ec..5859e3bc 100644 --- a/src/typeagent/knowpro/query.py +++ b/src/typeagent/knowpro/query.py @@ -37,6 +37,7 @@ ScoredSemanticRefOrdinal, SearchTerm, SemanticRef, + SemanticRefMetadata, SemanticRefOrdinal, SemanticRefSearchResult, Term, @@ -174,17 +175,14 @@ async def lookup_term_filtered( semantic_ref_index: ITermToSemanticRefIndex, term: Term, semantic_refs: ISemanticRefCollection, - filter: Callable[[SemanticRef, ScoredSemanticRefOrdinal], bool], + filter: Callable[[SemanticRefMetadata, ScoredSemanticRefOrdinal], bool], ) -> list[ScoredSemanticRefOrdinal] | None: """Look up a term in the semantic reference index and filter the results.""" scored_refs = await semantic_ref_index.lookup_term(term.text) if scored_refs: - filtered = [] - for sr in scored_refs: - semantic_ref = await semantic_refs.get_item(sr.semantic_ref_ordinal) - if filter(semantic_ref, sr): - filtered.append(sr) - return filtered + ordinals = [sr.semantic_ref_ordinal for sr in scored_refs] + metadata = await semantic_refs.get_metadata_multiple(ordinals) + return [sr for sr, m in zip(scored_refs, metadata) if filter(m, sr)] return None @@ -202,10 +200,10 @@ async def lookup_term( semantic_ref_index, term, semantic_refs, - lambda sr, _: ( - not knowledge_type or sr.knowledge.knowledge_type == knowledge_type + lambda m, _: ( + not knowledge_type or m.knowledge_type == knowledge_type ) - and ranges_in_scope.is_range_in_scope(sr.range), + and ranges_in_scope.is_range_in_scope(m.range), ) return await semantic_ref_index.lookup_term(term.text) diff --git a/src/typeagent/storage/memory/collections.py b/src/typeagent/storage/memory/collections.py index 9973a290..8a5b14eb 100644 --- a/src/typeagent/storage/memory/collections.py +++ b/src/typeagent/storage/memory/collections.py @@ -10,6 +10,7 @@ IMessage, MessageOrdinal, SemanticRef, + SemanticRefMetadata, SemanticRefOrdinal, ) @@ -63,6 +64,18 @@ async def extend(self, items: Iterable[T]) -> None: class MemorySemanticRefCollection(MemoryCollection[SemanticRef, SemanticRefOrdinal]): """A collection of semantic references.""" + async def get_metadata_multiple( + self, ordinals: list[SemanticRefOrdinal] + ) -> list[SemanticRefMetadata]: + return [ + SemanticRefMetadata( + ordinal=o, + range=self.items[o].range, + knowledge_type=self.items[o].knowledge.knowledge_type, + ) + for o in ordinals + ] + class MemoryMessageCollection[TMessage: IMessage]( MemoryCollection[TMessage, MessageOrdinal] diff --git a/src/typeagent/storage/memory/propindex.py b/src/typeagent/storage/memory/propindex.py index acc7b89a..ecb3e85d 100644 --- a/src/typeagent/storage/memory/propindex.py +++ b/src/typeagent/storage/memory/propindex.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from collections.abc import Sequence import enum from typing import assert_never @@ -109,6 +110,63 @@ async def build_property_index(conversation: IConversation) -> None: await add_to_property_index(conversation, 0) +def collect_facet_properties( + facet: kplib.Facet | None, + ordinal: SemanticRefOrdinal, +) -> list[tuple[str, str, SemanticRefOrdinal]]: + """Collect property tuples from a facet without touching any index.""" + if facet is None: + return [] + props: list[tuple[str, str, SemanticRefOrdinal]] = [ + (PropertyNames.FacetName.value, facet.name, ordinal) + ] + value = facet.value + if value is not None: + if isinstance(value, float) and value: + value = f"{value:g}" + props.append((PropertyNames.FacetValue.value, str(value), ordinal)) + return props + + +def collect_entity_properties( + entity: kplib.ConcreteEntity, + ordinal: SemanticRefOrdinal, +) -> list[tuple[str, str, SemanticRefOrdinal]]: + """Collect all property tuples for an entity.""" + props: list[tuple[str, str, SemanticRefOrdinal]] = [ + (PropertyNames.EntityName.value, entity.name, ordinal) + ] + for t in entity.type: + props.append((PropertyNames.EntityType.value, t, ordinal)) + if entity.facets: + for facet in entity.facets: + props.extend(collect_facet_properties(facet, ordinal)) + return props + + +def collect_action_properties( + action: kplib.Action, + ordinal: SemanticRefOrdinal, +) -> list[tuple[str, str, SemanticRefOrdinal]]: + """Collect all property tuples for an action.""" + props: list[tuple[str, str, SemanticRefOrdinal]] = [ + (PropertyNames.Verb.value, " ".join(action.verbs), ordinal) + ] + if action.subject_entity_name != "none": + props.append((PropertyNames.Subject.value, action.subject_entity_name, ordinal)) + if action.object_entity_name != "none": + props.append((PropertyNames.Object.value, action.object_entity_name, ordinal)) + if action.indirect_object_entity_name != "none": + props.append( + ( + PropertyNames.IndirectObject.value, + action.indirect_object_entity_name, + ordinal, + ) + ) + return props + + async def add_to_property_index( conversation: IConversation, start_at_ordinal: SemanticRefOrdinal, @@ -127,29 +185,40 @@ async def add_to_property_index( semantic_refs = conversation.semantic_refs size = await semantic_refs.size() + collected: list[tuple[str, str, SemanticRefOrdinal]] = [] for semantic_ref_ordinal, semantic_ref in enumerate( await semantic_refs.get_slice(start_at_ordinal, size), start_at_ordinal, ): assert semantic_ref.semantic_ref_ordinal == semantic_ref_ordinal if isinstance(semantic_ref.knowledge, kplib.Action): - await add_action_properties_to_index( - semantic_ref.knowledge, property_index, semantic_ref_ordinal + collected.extend( + collect_action_properties( + semantic_ref.knowledge, semantic_ref_ordinal + ) ) elif isinstance(semantic_ref.knowledge, kplib.ConcreteEntity): - await add_entity_properties_to_index( - semantic_ref.knowledge, property_index, semantic_ref_ordinal + collected.extend( + collect_entity_properties( + semantic_ref.knowledge, semantic_ref_ordinal + ) ) elif isinstance(semantic_ref.knowledge, Tag): - tag = semantic_ref.knowledge - await property_index.add_property( - PropertyNames.Tag.value, tag.text, semantic_ref_ordinal + collected.append( + ( + PropertyNames.Tag.value, + semantic_ref.knowledge.text, + semantic_ref_ordinal, + ) ) elif isinstance(semantic_ref.knowledge, Topic): pass else: assert_never(semantic_ref.knowledge) + if collected: + await property_index.add_properties_batch(collected) + class PropertyIndex(IPropertyToSemanticRefIndex): def __init__(self): @@ -183,6 +252,15 @@ async def add_property( else: self._map[term_text] = [semantic_ref_ordinal] + async def add_properties_batch( + self, + properties: Sequence[ + tuple[str, str, SemanticRefOrdinal | ScoredSemanticRefOrdinal] + ], + ) -> None: + for name, value, ordinal in properties: + await self.add_property(name, value, ordinal) + async def clear(self) -> None: self._map = {} @@ -252,12 +330,13 @@ async def lookup_property_in_property_index( property_value, ) if ranges_in_scope is not None and scored_refs: - filtered_refs = [] - for sr in scored_refs: - semantic_ref = await semantic_refs.get_item(sr.semantic_ref_ordinal) - if ranges_in_scope.is_range_in_scope(semantic_ref.range): - filtered_refs.append(sr) - scored_refs = filtered_refs + ordinals = [sr.semantic_ref_ordinal for sr in scored_refs] + metadata = await semantic_refs.get_metadata_multiple(ordinals) + scored_refs = [ + sr + for sr, m in zip(scored_refs, metadata) + if ranges_in_scope.is_range_in_scope(m.range) + ] return scored_refs or None # Return None if no results diff --git a/src/typeagent/storage/memory/semrefindex.py b/src/typeagent/storage/memory/semrefindex.py index 6c42022d..8437bacd 100644 --- a/src/typeagent/storage/memory/semrefindex.py +++ b/src/typeagent/storage/memory/semrefindex.py @@ -3,11 +3,13 @@ from __future__ import annotations # TODO: Avoid -from collections.abc import AsyncIterable, Callable +from collections.abc import AsyncIterable, Callable, Sequence from typechat import Failure -from ...knowpro import convknowledge, knowledge_schema as kplib, secindex +from ...knowpro import convknowledge +from ...knowpro import knowledge_schema as kplib +from ...knowpro import secindex from ...knowpro.convsettings import ConversationSettings, SemanticRefIndexSettings from ...knowpro.interfaces import ( # Interfaces.; Other imports. IConversation, @@ -577,6 +579,48 @@ async def add_metadata_to_index[TMessage: IMessage]( i += 1 +def collect_facet_terms(facet: kplib.Facet | None) -> list[str]: + """Collect terms from a facet without touching any index.""" + if facet is None: + return [] + terms = [facet.name] + if facet.value is not None: + terms.append(str(facet.value)) + return terms + + +def collect_entity_terms(entity: kplib.ConcreteEntity) -> list[str]: + """Collect all terms an entity would add to the semantic ref index.""" + terms = [entity.name] + for t in entity.type: + terms.append(t) + if entity.facets: + for facet in entity.facets: + terms.extend(collect_facet_terms(facet)) + return terms + + +def collect_action_terms(action: kplib.Action) -> list[str]: + """Collect all terms an action would add to the semantic ref index.""" + terms = [" ".join(action.verbs)] + if action.subject_entity_name != "none": + terms.append(action.subject_entity_name) + if action.object_entity_name != "none": + terms.append(action.object_entity_name) + if action.indirect_object_entity_name != "none": + terms.append(action.indirect_object_entity_name) + if action.params: + for param in action.params: + if isinstance(param, str): + terms.append(param) + else: + terms.append(param.name) + if isinstance(param.value, str): + terms.append(param.value) + terms.extend(collect_facet_terms(action.subject_entity_facet)) + return terms + + async def add_metadata_to_index_from_list[TMessage: IMessage]( messages: list[TMessage], semantic_refs: ISemanticRefCollection, @@ -585,18 +629,50 @@ async def add_metadata_to_index_from_list[TMessage: IMessage]( knowledge_validator: KnowledgeValidator | None = None, ) -> None: """Extract metadata knowledge from a list of messages starting at ordinal.""" + next_ordinal = await semantic_refs.size() + collected_refs: list[SemanticRef] = [] + collected_terms: list[tuple[str, SemanticRefOrdinal]] = [] + for i, msg in enumerate(messages, start_from_ordinal): knowledge_response = msg.get_knowledge() for entity in knowledge_response.entities: if knowledge_validator is None or knowledge_validator("entity", entity): - await add_entity_to_index(entity, semantic_refs, semantic_ref_index, i) + ref = SemanticRef( + semantic_ref_ordinal=next_ordinal, + range=text_range_from_location(i), + knowledge=entity, + ) + collected_refs.append(ref) + for term in collect_entity_terms(entity): + collected_terms.append((term, next_ordinal)) + next_ordinal += 1 for action in knowledge_response.actions: if knowledge_validator is None or knowledge_validator("action", action): - await add_action_to_index(action, semantic_refs, semantic_ref_index, i) + ref = SemanticRef( + semantic_ref_ordinal=next_ordinal, + range=text_range_from_location(i), + knowledge=action, + ) + collected_refs.append(ref) + for term in collect_action_terms(action): + collected_terms.append((term, next_ordinal)) + next_ordinal += 1 for topic_response in knowledge_response.topics: topic = Topic(text=topic_response) if knowledge_validator is None or knowledge_validator("topic", topic): - await add_topic_to_index(topic, semantic_refs, semantic_ref_index, i) + ref = SemanticRef( + semantic_ref_ordinal=next_ordinal, + range=text_range_from_location(i), + knowledge=topic, + ) + collected_refs.append(ref) + collected_terms.append((topic.text, next_ordinal)) + next_ordinal += 1 + + if collected_refs: + await semantic_refs.extend(collected_refs) + if collected_terms: + await semantic_ref_index.add_terms_batch(collected_terms) class TermToSemanticRefIndex(ITermToSemanticRefIndex): @@ -635,6 +711,13 @@ async def add_term( self._map[term] = [semantic_ref_ordinal] return term + async def add_terms_batch( + self, + terms: Sequence[tuple[str, SemanticRefOrdinal | ScoredSemanticRefOrdinal]], + ) -> None: + for term, ordinal in terms: + await self.add_term(term, ordinal) + async def lookup_term(self, term: str) -> list[ScoredSemanticRefOrdinal] | None: return self._map.get(self._prepare_term(term)) or [] diff --git a/src/typeagent/storage/sqlite/collections.py b/src/typeagent/storage/sqlite/collections.py index 9730f6d1..fe394dcb 100644 --- a/src/typeagent/storage/sqlite/collections.py +++ b/src/typeagent/storage/sqlite/collections.py @@ -340,6 +340,50 @@ async def get_multiple(self, arg: list[int]) -> list[interfaces.SemanticRef]: assert set(rowdict) == set(arg) return [self._deserialize_semantic_ref_from_row(rowdict[ordl]) for ordl in arg] + async def get_metadata_multiple( + self, ordinals: list[int] + ) -> list[interfaces.SemanticRefMetadata]: + if not ordinals: + return [] + cursor = self.db.cursor() + placeholders = ",".join("?" * len(ordinals)) + cursor.execute( + f""" + SELECT semref_id, range_json, knowledge_type + FROM SemanticRefs WHERE semref_id IN ({placeholders}) + """, + ordinals, + ) + rows = cursor.fetchall() + rowdict = {r[0]: r for r in rows} + result = [] + for o in ordinals: + row = rowdict[o] + range_data = json.loads(row[1]) + start = range_data["start"] + end_data = range_data.get("end") + result.append( + interfaces.SemanticRefMetadata( + ordinal=row[0], + range=interfaces.TextRange( + start=interfaces.TextLocation( + start["messageOrdinal"], + start.get("chunkOrdinal", 0), + ), + end=( + interfaces.TextLocation( + end_data["messageOrdinal"], + end_data.get("chunkOrdinal", 0), + ) + if end_data + else None + ), + ), + knowledge_type=row[2], + ) + ) + return result + async def append(self, item: interfaces.SemanticRef) -> None: cursor = self.db.cursor() semref_id, range_json, knowledge_type, knowledge_json = ( diff --git a/src/typeagent/storage/sqlite/propindex.py b/src/typeagent/storage/sqlite/propindex.py index 5a0fa63a..f9704b45 100644 --- a/src/typeagent/storage/sqlite/propindex.py +++ b/src/typeagent/storage/sqlite/propindex.py @@ -3,6 +3,7 @@ """SQLite-based property index implementation.""" +from collections.abc import Sequence import sqlite3 from ...knowpro import interfaces @@ -67,6 +68,43 @@ async def add_property( (property_name, value, score, semref_id), ) + async def add_properties_batch( + self, + properties: Sequence[ + tuple[ + str, + str, + interfaces.SemanticRefOrdinal | interfaces.ScoredSemanticRefOrdinal, + ] + ], + ) -> None: + if not properties: + return + from ...storage.memory.propindex import ( + make_property_term_text, + split_property_term_text, + ) + + rows = [] + for property_name, value, ordinal in properties: + if isinstance(ordinal, interfaces.ScoredSemanticRefOrdinal): + semref_id = ordinal.semantic_ref_ordinal + score = ordinal.score + else: + semref_id = ordinal + score = 1.0 + term_text = make_property_term_text(property_name, value) + term_text = term_text.lower() + property_name, value = split_property_term_text(term_text) + if property_name.startswith("prop."): + property_name = property_name[5:] + rows.append((property_name, value, score, semref_id)) + cursor = self.db.cursor() + cursor.executemany( + "INSERT INTO PropertyIndex (prop_name, value_str, score, semref_id) VALUES (?, ?, ?, ?)", + rows, + ) + async def clear(self) -> None: cursor = self.db.cursor() cursor.execute("DELETE FROM PropertyIndex") diff --git a/src/typeagent/storage/sqlite/semrefindex.py b/src/typeagent/storage/sqlite/semrefindex.py index 682b8e7d..ac68a1e0 100644 --- a/src/typeagent/storage/sqlite/semrefindex.py +++ b/src/typeagent/storage/sqlite/semrefindex.py @@ -3,6 +3,7 @@ """SQLite-based semantic reference index implementation.""" +from collections.abc import Sequence import re import sqlite3 import unicodedata @@ -56,6 +57,33 @@ async def add_term( return term + async def add_terms_batch( + self, + terms: Sequence[ + tuple[ + str, interfaces.SemanticRefOrdinal | interfaces.ScoredSemanticRefOrdinal + ] + ], + ) -> None: + if not terms: + return + rows = [] + for term, ordinal in terms: + if not term: + continue + term = self._prepare_term(term) + if isinstance(ordinal, interfaces.ScoredSemanticRefOrdinal): + semref_id = ordinal.semantic_ref_ordinal + else: + semref_id = ordinal + rows.append((term, semref_id)) + if rows: + cursor = self.db.cursor() + cursor.executemany( + "INSERT OR IGNORE INTO SemanticRefIndex (term, semref_id) VALUES (?, ?)", + rows, + ) + async def remove_term( self, term: str, semantic_ref_ordinal: interfaces.SemanticRefOrdinal ) -> None: diff --git a/tests/benchmarks/test_benchmark_query.py b/tests/benchmarks/test_benchmark_query.py new file mode 100644 index 00000000..8c4dd137 --- /dev/null +++ b/tests/benchmarks/test_benchmark_query.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Benchmark for lookup_term_filtered — measures the N+1 query pattern. + +After indexing 200 synthetic messages, looks up a high-frequency term +and filters results via lookup_term_filtered. Each call triggers +one get_item() SELECT per matching semantic ref (N+1 pattern). + +Run: + uv run python -m pytest tests/benchmarks/test_benchmark_query.py -v -s +""" + +import os +import tempfile + +import pytest + +from typeagent.aitools.model_adapters import create_test_embedding_model +from typeagent.knowpro.convsettings import ConversationSettings +from typeagent.knowpro.interfaces_core import Term +from typeagent.knowpro.query import lookup_term_filtered +from typeagent.storage.sqlite.provider import SqliteStorageProvider +from typeagent.transcripts.transcript import ( + Transcript, + TranscriptMessage, + TranscriptMessageMeta, +) + + +def make_settings() -> ConversationSettings: + model = create_test_embedding_model() + settings = ConversationSettings(model=model) + settings.semantic_ref_index_settings.auto_extract_knowledge = False + return settings + + +def synthetic_messages(n: int) -> list[TranscriptMessage]: + return [ + TranscriptMessage( + text_chunks=[f"Message {i} about topic {i % 10}"], + metadata=TranscriptMessageMeta(speaker=f"Speaker{i % 3}"), + tags=[f"tag{i % 5}"], + ) + for i in range(n) + ] + + +async def create_indexed_transcript( + db_path: str, settings: ConversationSettings, n_messages: int +) -> Transcript: + """Create and index a transcript, returning it ready for queries.""" + storage = SqliteStorageProvider( + db_path, + message_type=TranscriptMessage, + message_text_index_settings=settings.message_text_index_settings, + related_term_index_settings=settings.related_term_index_settings, + ) + settings.storage_provider = storage + transcript = await Transcript.create(settings, name="bench") + messages = synthetic_messages(n_messages) + await transcript.add_messages_with_indexing(messages) + return transcript + + +@pytest.mark.asyncio +async def test_benchmark_lookup_term_filtered(async_benchmark): + """Benchmark lookup_term_filtered with N+1 get_item pattern.""" + settings = make_settings() + tmpdir = tempfile.mkdtemp() + db_path = os.path.join(tmpdir, "query_bench.db") + + transcript = await create_indexed_transcript(db_path, settings, 200) + + # Find a high-frequency term to look up. + semref_index = transcript.semantic_ref_index + terms = await semref_index.get_terms() + # Pick the term with the most matches. + best_term = None + best_count = 0 + for t in terms: + refs = await semref_index.lookup_term(t) + if refs and len(refs) > best_count: + best_count = len(refs) + best_term = t + + assert best_term is not None, "No terms found after indexing" + print(f"\nBenchmarking term '{best_term}' with {best_count} matches") + + term = Term(text=best_term) + semantic_refs = transcript.semantic_refs + # Filter that accepts all — isolates the get_item overhead. + accept_all = lambda sr, scored: True + + async def target(): + await lookup_term_filtered(semref_index, term, semantic_refs, accept_all) + + try: + await async_benchmark.pedantic(target, rounds=200, warmup_rounds=20) + finally: + await settings.storage_provider.close() + import shutil + + shutil.rmtree(tmpdir, ignore_errors=True) diff --git a/tests/conftest.py b/tests/conftest.py index dae619c1..7f0f11f5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from collections.abc import AsyncGenerator, Callable, Iterator +from collections.abc import AsyncGenerator, Callable, Iterator, Sequence import os from pathlib import Path import tempfile @@ -236,6 +236,13 @@ async def add_term( self.term_to_refs[term].append(scored_ref) return term + async def add_terms_batch( + self, + terms: Sequence[tuple[str, int | ScoredSemanticRefOrdinal]], + ) -> None: + for term, ordinal in terms: + await self.add_term(term, ordinal) + async def remove_term(self, term: str, semantic_ref_ordinal: int) -> None: if term in self.term_to_refs: self.term_to_refs[term] = [ diff --git a/uv.lock b/uv.lock index 5701a7c7..a0c0a6c5 100644 --- a/uv.lock +++ b/uv.lock @@ -1931,6 +1931,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, ] +[[package]] +name = "pytest-async-benchmark" +version = "0.2.0" +source = { git = "https://github.com/KRRT7/pytest-async-benchmark.git?rev=feat%2Fpedantic-mode#029d03634d140789baebc6c3c8f72d5c81a67f9a" } +dependencies = [ + { name = "pytest" }, + { name = "rich" }, +] + [[package]] name = "pytest-asyncio" version = "1.3.0" @@ -2420,6 +2429,7 @@ dev = [ { name = "opentelemetry-instrumentation-httpx" }, { name = "pyright" }, { name = "pytest" }, + { name = "pytest-async-benchmark" }, { name = "pytest-asyncio" }, { name = "pytest-benchmark" }, { name = "pytest-mock" }, @@ -2459,6 +2469,7 @@ dev = [ { name = "opentelemetry-instrumentation-httpx", specifier = ">=0.57b0" }, { name = "pyright", specifier = ">=1.1.408" }, { name = "pytest", specifier = ">=8.3.5" }, + { name = "pytest-async-benchmark", git = "https://github.com/KRRT7/pytest-async-benchmark.git?rev=feat%2Fpedantic-mode" }, { name = "pytest-asyncio", specifier = ">=0.26.0" }, { name = "pytest-benchmark", specifier = ">=5.1.0" }, { name = "pytest-mock", specifier = ">=3.14.0" },