diff --git a/ravendb/documents/session/query.py b/ravendb/documents/session/query.py index 26e3fc99..fdb5b4a7 100644 --- a/ravendb/documents/session/query.py +++ b/ravendb/documents/session/query.py @@ -2,6 +2,7 @@ import datetime import enum import os +import warnings from copy import copy from typing import ( Generic, @@ -1041,29 +1042,45 @@ def _add_root_type(self, object_type: Type[_T]): def _vector_search_internal( self, wrapped_embedding_field: str, - vector: Union[List[float], List[int], str], + term: Union[List[float], List[int], str] = None, source_quantization_type: VectorEmbeddingType = VectorSearch.DEFAULT_EMBEDDING_TYPE, target_quantization_type: VectorEmbeddingType = VectorSearch.DEFAULT_EMBEDDING_TYPE, minimum_similarity: float = None, number_of_candidates: int = None, is_exact: bool = VectorSearch.DEFAULT_IS_EXACT, task_name: str = None, + is_index_field: bool = False, + document_id: str = None, ): - is_source_base64_encoded = False - is_vector_base64_encoded = False + if ( + source_quantization_type == VectorEmbeddingType.INT8 + or source_quantization_type == VectorEmbeddingType.BINARY + ) and target_quantization_type != source_quantization_type: + raise ValueError( + f"Cannot quantize already quantized embeddings. Source quantization type: {source_quantization_type.value}; however the target is: {target_quantization_type.value}." + ) + + if target_quantization_type == VectorEmbeddingType.TEXT: + raise ValueError( + f"Cannot set target quantization type to be {str(target_quantization_type.value)}. This option is only availabe for source_quantization_type." + ) - query_parameter_name = self.__add_query_parameter(vector) + if target_quantization_type != VectorSearch.DEFAULT_EMBEDDING_TYPE and is_index_field: + raise ValueError( + f"Cannot set target quantization when querying an index, since quantization is already done on the index side." + ) + + query_parameter_name = self.__add_query_parameter(term) vector_search_token = VectorSearchToken( wrapped_embedding_field, query_parameter_name, source_quantization_type, target_quantization_type, - is_source_base64_encoded, - is_vector_base64_encoded, minimum_similarity, number_of_candidates, is_exact, task_name, + document_id, ) self._where_tokens.append(vector_search_token) @@ -1884,58 +1901,224 @@ def search(self, field_name: str, search_terms: str, operator: SearchOperator = def vector_search( self, embedding_field: str, - vector: Union[List[float], str], # todo: docs about base 64 (|str) + vector: Union[List[float], str], minimum_similarity: float = None, number_of_candidates: int = None, is_exact: bool = VectorSearch.DEFAULT_IS_EXACT, + target_quantization: VectorEmbeddingType = VectorEmbeddingType.SINGLE, ) -> DocumentQuery[_T]: - """Perform vector search using embedding field (float32)""" + """Perform vector search using embedding field (float32) + The vector parameter can be either a list of float or a base64 string.""" self._vector_search_internal( - embedding_field, - vector, - VectorEmbeddingType.SINGLE, - VectorEmbeddingType.SINGLE, - minimum_similarity, - number_of_candidates, - is_exact, + wrapped_embedding_field=embedding_field, + term=vector, + source_quantization_type=VectorEmbeddingType.SINGLE, + target_quantization_type=target_quantization, + minimum_similarity=minimum_similarity, + number_of_candidates=number_of_candidates, + is_exact=is_exact, + ) + return self + + def vector_search_with_field( + self, + index_embedding_field: str, + vector: list[float] | str, + minimum_similarity: float = None, + number_of_candidates: int = None, + is_exact: bool = VectorSearch.DEFAULT_IS_EXACT, + ): + """Perform vector search using float32 embedding field + The vector parameter can be either a list of float or a base64 string.""" + self._vector_search_internal( + wrapped_embedding_field=index_embedding_field, + term=vector, + minimum_similarity=minimum_similarity, + number_of_candidates=number_of_candidates, + is_exact=is_exact, + is_index_field=True, + ) + return self + + def vector_search_with_i8_field( + self, + index_embedding_field: str, + vector: list[int] | str, + minimum_similarity: float = None, + number_of_candidates: int = None, + is_exact: bool = VectorSearch.DEFAULT_IS_EXACT, + ): + """Perform vector search using int8 embedding field + The vector parameter can be either a list of int or a base64 string.""" + self._vector_search_internal( + wrapped_embedding_field=index_embedding_field, + term=vector, + minimum_similarity=minimum_similarity, + number_of_candidates=number_of_candidates, + is_exact=is_exact, + is_index_field=True, + ) + return self + + def vector_search_with_i1_field( + self, + index_embedding_field: str, + vector: list[int], + minimum_similarity: float = None, + number_of_candidates: int = None, + is_exact: bool = VectorSearch.DEFAULT_IS_EXACT, + ): + """Perform vector search using int1 embedding field + The vector parameter can be either a list of int or a base64 string.""" + self._vector_search_internal( + wrapped_embedding_field=index_embedding_field, + term=vector, + minimum_similarity=minimum_similarity, + number_of_candidates=number_of_candidates, + is_exact=is_exact, + is_index_field=True, + ) + return self + + def vector_search_with_text_field( + self, + index_embedding_field: str, + search_term: str, + minimum_similarity: float = None, + number_of_candidates: int = None, + is_exact: bool = VectorSearch.DEFAULT_IS_EXACT, + ): + self._vector_search_internal( + wrapped_embedding_field=index_embedding_field, + term=search_term, + minimum_similarity=minimum_similarity, + number_of_candidates=number_of_candidates, + is_exact=is_exact, + is_index_field=True, + ) + return self + + def vector_search_with_field_for_document( + self, + index_embedding_field: str, + document_id: str, + minimum_similarity: float = None, + number_of_candidates: int = None, + is_exact: bool = VectorSearch.DEFAULT_IS_EXACT, + ): + self._vector_search_internal( + wrapped_embedding_field=index_embedding_field, + minimum_similarity=minimum_similarity, + number_of_candidates=number_of_candidates, + is_exact=is_exact, + is_index_field=True, + document_id=document_id, + ) + return self + + def vector_search_with_base64( + self, + embedding_field: str, + vector: list[float] | str, + target_quantization: VectorEmbeddingType = VectorSearch.DEFAULT_EMBEDDING_TYPE, + minimum_similarity: float = None, + number_of_candidates: int = None, + is_exact: bool = VectorSearch.DEFAULT_IS_EXACT, + ): + """Perform vector search over a base64-encoded f32 vector field. + The vector parameter can be either a list of float or a base64 string.""" + self._vector_search_internal( + wrapped_embedding_field=embedding_field, + term=vector, + source_quantization_type=VectorEmbeddingType.SINGLE, + target_quantization_type=target_quantization, + minimum_similarity=minimum_similarity, + number_of_candidates=number_of_candidates, + is_exact=is_exact, + ) + return self + + def vector_search_with_base64_i8( + self, + embedding_field: str, + vector: list[float] | str, + minimum_similarity: float = None, + number_of_candidates: int = None, + is_exact: bool = VectorSearch.DEFAULT_IS_EXACT, + ): + """Perform vector search over a base64-encoded int8 vector field. + The vector parameter can be either a list of int or a base64 string.""" + self._vector_search_internal( + wrapped_embedding_field=embedding_field, + term=vector, + source_quantization_type=VectorEmbeddingType.INT8, + target_quantization_type=VectorEmbeddingType.INT8, + minimum_similarity=minimum_similarity, + number_of_candidates=number_of_candidates, + is_exact=is_exact, + ) + return self + + def vector_search_with_base64_i1( + self, + embedding_field: str, + vector: list[int] | str, + minimum_similarity: float = None, + number_of_candidates: int = None, + is_exact: bool = VectorSearch.DEFAULT_IS_EXACT, + ): + """Perform vector search over a base64-encoded int1 vector field. + The vector parameter can be either a list of int or a base64 string.""" + self._vector_search_internal( + wrapped_embedding_field=embedding_field, + term=vector, + source_quantization_type=VectorEmbeddingType.BINARY, + target_quantization_type=VectorEmbeddingType.BINARY, + minimum_similarity=minimum_similarity, + number_of_candidates=number_of_candidates, + is_exact=is_exact, ) return self def vector_search_i8( self, embedding_field: str, - vector: List[int], + vector: List[int] | str, minimum_similarity: float = None, number_of_candidates: int = None, is_exact: bool = VectorSearch.DEFAULT_IS_EXACT, ) -> DocumentQuery[_T]: + """Perform vector search using int8 field. + The vector parameter can be either a list of int or a base64 string.""" self._vector_search_internal( - embedding_field, - vector, - VectorEmbeddingType.INT8, - VectorEmbeddingType.INT8, - minimum_similarity, - number_of_candidates, - is_exact, + wrapped_embedding_field=embedding_field, + term=vector, + source_quantization_type=VectorEmbeddingType.INT8, + target_quantization_type=VectorEmbeddingType.INT8, + minimum_similarity=minimum_similarity, + number_of_candidates=number_of_candidates, + is_exact=is_exact, ) return self def vector_search_i1( self, embedding_field: str, - vector: List[int], + vector: List[int] | str, minimum_similarity: float = None, number_of_candidates: int = None, is_exact: bool = VectorSearch.DEFAULT_IS_EXACT, ) -> DocumentQuery[_T]: + """Perform vector search using int1 field. + The vector parameter can be either a list of int or a base64 string.""" self._vector_search_internal( - embedding_field, - vector, - VectorEmbeddingType.BINARY, - VectorEmbeddingType.BINARY, - minimum_similarity, - number_of_candidates, - is_exact, + wrapped_embedding_field=embedding_field, + term=vector, + source_quantization_type=VectorEmbeddingType.BINARY, + target_quantization_type=VectorEmbeddingType.BINARY, + minimum_similarity=minimum_similarity, + number_of_candidates=number_of_candidates, + is_exact=is_exact, ) return self @@ -1946,16 +2129,41 @@ def vector_search_text( minimum_similarity: float = None, number_of_candidates: int = None, is_exact: bool = VectorSearch.DEFAULT_IS_EXACT, + target_quantization: VectorEmbeddingType = VectorSearch.DEFAULT_EMBEDDING_TYPE, + embedding_generation_task_identifier: str = None, ) -> DocumentQuery[_T]: """Perform vector search using text field""" self._vector_search_internal( - embedding_field, - vector, - VectorEmbeddingType.TEXT, - VectorEmbeddingType.SINGLE, - minimum_similarity, - number_of_candidates, - is_exact, + wrapped_embedding_field=embedding_field, + term=vector, + source_quantization_type=VectorEmbeddingType.TEXT, + target_quantization_type=target_quantization, + minimum_similarity=minimum_similarity, + number_of_candidates=number_of_candidates, + is_exact=is_exact, + task_name=embedding_generation_task_identifier, + ) + return self + + def vector_search_text_for_document( + self, + embedding_field: str, + document_id: str, + target_quantization: VectorEmbeddingType = VectorSearch.DEFAULT_EMBEDDING_TYPE, + minimum_similarity: float = None, + number_of_candidates: int = None, + is_exact: bool = VectorSearch.DEFAULT_IS_EXACT, + embedding_generation_task_identifier: str = None, + ): + self._vector_search_internal( + wrapped_embedding_field=embedding_field, + source_quantization_type=VectorEmbeddingType.TEXT, + target_quantization_type=target_quantization, + minimum_similarity=minimum_similarity, + number_of_candidates=number_of_candidates, + is_exact=is_exact, + task_name=embedding_generation_task_identifier, + document_id=document_id, ) return self @@ -1968,7 +2176,13 @@ def vector_search_text_using_task( number_of_candidates: int = None, is_exact: bool = VectorSearch.DEFAULT_IS_EXACT, ) -> DocumentQuery[_T]: - """Perform vector search using text field""" + """Deprecated: Use vector_search_text() with the embedding_generation_task_identifier parameter instead.""" + warnings.warn( + "vector_search_text_using_task is deprecated; use vector_search_text with " + "embedding_generation_task_identifier parameter instead.", + DeprecationWarning, + stacklevel=2, + ) self._vector_search_internal( embedding_field, vector, @@ -1989,6 +2203,13 @@ def vector_search_f32_i8( number_of_candidates: int = None, is_exact: bool = VectorSearch.DEFAULT_IS_EXACT, ) -> DocumentQuery[_T]: + """Deprecated: Use vector_search() with the target_quantization = VectorEmbeddingType.INT8 parameter instead.""" + warnings.warn( + "vector_search_f32_i8 is deprecated; use vector_search with " + "target_quantization = VectorEmbeddingType.INT8 parameter instead.", + DeprecationWarning, + stacklevel=2, + ) self._vector_search_internal( embedding_field, vector, @@ -2008,6 +2229,13 @@ def vector_search_f32_i1( number_of_candidates: int = None, is_exact: bool = VectorSearch.DEFAULT_IS_EXACT, ) -> DocumentQuery[_T]: + """Deprecated: Use vector_search() with the target_quantization = VectorEmbeddingType.BINARY parameter instead.""" + warnings.warn( + "vector_search_f32_i1 is deprecated; use vector_search with " + "target_quantization = VectorEmbeddingType.BINARY parameter instead.", + DeprecationWarning, + stacklevel=2, + ) self._vector_search_internal( embedding_field, vector, @@ -2027,6 +2255,13 @@ def vector_search_text_i8( number_of_candidates: int = None, is_exact: bool = VectorSearch.DEFAULT_IS_EXACT, ) -> DocumentQuery[_T]: + """Deprecated: Use vector_search_text() with the target_quantization = VectorEmbeddingType.INT8 parameter instead.""" + warnings.warn( + "vector_search_text_i8 is deprecated; use vector_search_text with " + "target_quantization = VectorEmbeddingType.INT8 parameter instead.", + DeprecationWarning, + stacklevel=2, + ) self._vector_search_internal( embedding_field, vector, @@ -2046,6 +2281,13 @@ def vector_search_text_i1( number_of_candidates: int = None, is_exact: bool = VectorSearch.DEFAULT_IS_EXACT, ) -> DocumentQuery[_T]: + """Deprecated: Use vector_search_text() with the target_quantization = VectorEmbeddingType.BINARY parameter instead.""" + warnings.warn( + "vector_search_text_i1 is deprecated; use vector_search_text with " + "target_quantization = VectorEmbeddingType.BINARY parameter instead.", + DeprecationWarning, + stacklevel=2, + ) self._vector_search_internal( embedding_field, vector, @@ -2066,6 +2308,13 @@ def vector_search_text_i8_using_task( number_of_candidates: int = None, is_exact: bool = VectorSearch.DEFAULT_IS_EXACT, ) -> DocumentQuery[_T]: + """Deprecated: Use vector_search_text() with the target_quantization = VectorEmbeddingType.INT8 and embedding_generation_task_identifier parameters instead.""" + warnings.warn( + "vector_search_text_i8_using_task is deprecated; use vector_search_text with " + "target_quantization = VectorEmbeddingType.INT8 and embedding_generation_task_identifier parameters instead.", + DeprecationWarning, + stacklevel=2, + ) self._vector_search_internal( embedding_field, vector, @@ -2087,6 +2336,13 @@ def vector_search_text_i1_using_task( number_of_candidates: int = None, is_exact: bool = VectorSearch.DEFAULT_IS_EXACT, ) -> DocumentQuery[_T]: + """Deprecated: Use vector_search_text() with the target_quantization = VectorEmbeddingType.BINARY and embedding_generation_task_identifier parameters instead.""" + warnings.warn( + "vector_search_text_i1_using_task is deprecated; use vector_search_text with " + "target_quantization = VectorEmbeddingType.BINARY and embedding_generation_task_identifier parameters instead.", + DeprecationWarning, + stacklevel=2, + ) self._vector_search_internal( embedding_field, vector, diff --git a/ravendb/documents/session/tokens/query_tokens/definitions.py b/ravendb/documents/session/tokens/query_tokens/definitions.py index 2ff98f68..8642e46b 100644 --- a/ravendb/documents/session/tokens/query_tokens/definitions.py +++ b/ravendb/documents/session/tokens/query_tokens/definitions.py @@ -1002,12 +1002,11 @@ def __init__( parameter_name: str, source_quantization_type: VectorEmbeddingType, target_quantization_type: VectorEmbeddingType, - is_source_base64_encoded: bool, - is_vector_base64_encoded: bool, similarity_threshold: float = None, number_of_candidates_for_querying: int = None, is_exact: bool = VectorSearch.DEFAULT_IS_EXACT, task_name: str = None, + document_id: str = None, ): where_options = WhereToken.WhereOptions() where_options.exact = is_exact @@ -1019,14 +1018,12 @@ def __init__( self._source_quantization_type = source_quantization_type self._target_quantization_type = target_quantization_type - self._is_source_base64_encoded = is_source_base64_encoded - self._is_vector_base64_encoded = is_vector_base64_encoded - self._similarity_threshold = similarity_threshold self._number_of_candidates_for_querying = number_of_candidates_for_querying self._is_exact = is_exact self._task_name = task_name + self._document_id = document_id def write_to(self, writer: List[str]) -> None: """ @@ -1035,7 +1032,6 @@ def write_to(self, writer: List[str]) -> None: """ if self._is_exact: writer.append("exact(") - writer.append("vector.search(") if ( @@ -1047,13 +1043,18 @@ def write_to(self, writer: List[str]) -> None: method_name = VectorSearch.configuration_to_method_name( self._source_quantization_type, self._target_quantization_type ) - writer.append(f"{method_name}({self.field_name}") - if self._task_name: - writer.append(f", ai.task('{self._task_name}')") - writer.append(")") + if self._source_quantization_type == VectorEmbeddingType.TEXT and self._task_name is not None: + writer.append( + f"{method_name}({self.field_name}, {VectorSearch.AI_TASK_METHOD_NAME}('{self._task_name}'))" + ) + else: + writer.append(f"{method_name}({self.field_name})") + writer.append(", ") - # Add main parameter - writer.append(f", ${self._parameter_name}") + if self._document_id: + writer.append(f"{VectorSearch.EMBEDDING_FOR_DOCUMENT}(${self._parameter_name})") + else: + writer.append(f"${self._parameter_name}") # Handle optional parameters parameters_are_default = self._similarity_threshold is None and self._number_of_candidates_for_querying is None diff --git a/ravendb/tests/documents_tests/query_tests/test_vector_search.py b/ravendb/tests/documents_tests/query_tests/test_vector_search.py index e643a8f7..1a95211e 100644 --- a/ravendb/tests/documents_tests/query_tests/test_vector_search.py +++ b/ravendb/tests/documents_tests/query_tests/test_vector_search.py @@ -1,4 +1,6 @@ import unittest + +from ravendb.documents.indexes.vector.embedding import VectorEmbeddingType from ravendb.documents.queries.vector import VectorQuantizer from ravendb.tests.dotnet_migrated_tests.test_ravenDB_22076 import Dto from ravendb.tests.test_base import TestBase @@ -126,46 +128,62 @@ def test_to_int1_padding_trimmed(self): class TestVectorSearch(TestBase): def test_should_generate_rql_with_text_field_using_named_ai_task(self): with self.store.open_session() as session: - q = session.query(object_type=Dto).vector_search_text_using_task("EmbeddingField", "fishing", "my-ai-task") + q = session.query(object_type=Dto).vector_search_text( + "EmbeddingField", "fishing", embedding_generation_task_identifier="my-ai-task" + ) self.assertEqual( "from 'Dtoes' where vector.search(embedding.text(EmbeddingField, ai.task('my-ai-task')), $p0)", q._to_string(), ) - q_exact = session.query(object_type=Dto).vector_search_text_using_task( - "EmbeddingField", "fishing", "my-ai-task", is_exact=True + q_exact = session.query(object_type=Dto).vector_search_text( + "EmbeddingField", "fishing", embedding_generation_task_identifier="my-ai-task", is_exact=True ) self.assertEqual( "from 'Dtoes' where exact(vector.search(embedding.text(EmbeddingField, ai.task('my-ai-task')), $p0))", q_exact._to_string(), ) - q2 = session.query(object_type=Dto).vector_search_text_i1_using_task( - "EmbeddingField", "fishing", "my-ai-task" + q2 = session.query(object_type=Dto).vector_search_text( + "EmbeddingField", + "fishing", + target_quantization=VectorEmbeddingType.BINARY, + embedding_generation_task_identifier="my-ai-task", ) self.assertEqual( "from 'Dtoes' where vector.search(embedding.text_i1(EmbeddingField, ai.task('my-ai-task')), $p0)", q2._to_string(), ) - q2_exact = session.query(object_type=Dto).vector_search_text_i1_using_task( - "EmbeddingField", "fishing", "my-ai-task", is_exact=True + q2_exact = session.query(object_type=Dto).vector_search_text( + "EmbeddingField", + "fishing", + target_quantization=VectorEmbeddingType.BINARY, + embedding_generation_task_identifier="my-ai-task", + is_exact=True, ) self.assertEqual( "from 'Dtoes' where exact(vector.search(embedding.text_i1(EmbeddingField, ai.task('my-ai-task')), $p0))", q2_exact._to_string(), ) - q3 = session.query(object_type=Dto).vector_search_text_i8_using_task( - "EmbeddingField", "fishing", "my-ai-task" + q3 = session.query(object_type=Dto).vector_search_text( + "EmbeddingField", + "fishing", + target_quantization=VectorEmbeddingType.INT8, + embedding_generation_task_identifier="my-ai-task", ) self.assertEqual( "from 'Dtoes' where vector.search(embedding.text_i8(EmbeddingField, ai.task('my-ai-task')), $p0)", q3._to_string(), ) - q3_exact = session.query(object_type=Dto).vector_search_text_i8_using_task( - "EmbeddingField", "fishing", "my-ai-task", is_exact=True + q3_exact = session.query(object_type=Dto).vector_search_text( + "EmbeddingField", + "fishing", + embedding_generation_task_identifier="my-ai-task", + is_exact=True, + target_quantization=VectorEmbeddingType.INT8, ) self.assertEqual( "from 'Dtoes' where exact(vector.search(embedding.text_i8(EmbeddingField, ai.task('my-ai-task')), $p0))", diff --git a/ravendb/tests/dotnet_migrated_tests/test_ravenDB_22076.py b/ravendb/tests/dotnet_migrated_tests/test_ravenDB_22076.py index 5e53ee2c..30fcbd66 100644 --- a/ravendb/tests/dotnet_migrated_tests/test_ravenDB_22076.py +++ b/ravendb/tests/dotnet_migrated_tests/test_ravenDB_22076.py @@ -86,7 +86,9 @@ def test_rql_generation(self): q4 = session.query(object_type=Dto).vector_search("VectorField", "aaaa==") self.assertEqual("from 'Dtoes' where vector.search(VectorField, $p0)", q4._to_string()) - q5 = session.query(object_type=Dto).vector_search_text_i8("TextField", "aaaa") + q5 = session.query(object_type=Dto).vector_search_text( + "TextField", "aaaa", target_quantization=VectorEmbeddingType.INT8 + ) self.assertEqual("from 'Dtoes' where vector.search(embedding.text_i8(TextField), $p0)", q5._to_string()) q6 = session.query(object_type=Dto).vector_search_i8("EmbeddingField", [2, 3], 0.65) @@ -94,9 +96,13 @@ def test_rql_generation(self): "from 'Dtoes' where vector.search(embedding.i8(EmbeddingField), $p0, 0.65, null)", q6._to_string() ) - q7 = session.query(object_type=Dto).vector_search_text_i8("TextField", "aaaa") + q7 = session.query(object_type=Dto).vector_search_text( + "TextField", "aaaa", target_quantization=VectorEmbeddingType.INT8 + ) self.assertEqual("from 'Dtoes' where vector.search(embedding.text_i8(TextField), $p0)", q7._to_string()) + # q8 = session.query(object_type=Dto).vector_search_with_field() + def test_rql_generation_2(self): with self.store.open_session() as session: @@ -115,12 +121,16 @@ def test_rql_generation_2(self): "from 'Dtoes' where vector.search(embedding.i8(EmbeddingField), $p0, 0.65, null)", q1._to_string() ) - q2 = session.query(object_type=Dto).vector_search_f32_i8("EmbeddingField", [2.5, 3.3], 0.65) + q2 = session.query(object_type=Dto).vector_search( + "EmbeddingField", [2.5, 3.3], 0.65, target_quantization=VectorEmbeddingType.INT8 + ) self.assertEqual( "from 'Dtoes' where vector.search(embedding.f32_i8(EmbeddingField), $p0, 0.65, null)", q2._to_string() ) - q3 = session.query(object_type=Dto).vector_search_f32_i8("EmbeddingField", "abcd==", 0.75) + q3 = session.query(object_type=Dto).vector_search( + "EmbeddingField", "abcd==", 0.75, target_quantization=VectorEmbeddingType.INT8 + ) self.assertEqual( "from 'Dtoes' where vector.search(embedding.f32_i8(EmbeddingField), $p0, 0.75, null)", q3._to_string() ) @@ -144,6 +154,79 @@ def test_rql_generation_2(self): ) self.assertEqual("from 'Dtoes' where exact(vector.search(EmbeddingBase64, $p0, null, 25))", q8._to_string()) + def test_rql_generation_3(self): + with self.store.open_session() as session: + # forDocument - text/field + q1 = session.query(object_type=Dto).vector_search_with_field_for_document("VectorField", "docs/1-A") + self.assertEqual("from 'Dtoes' where vector.search(VectorField, embedding.forDoc($p0))", q1._to_string()) + + q2 = session.query(object_type=Dto).vector_search_text_for_document( + "VectorField", "docs/1-A", target_quantization=VectorEmbeddingType.INT8 + ) + self.assertEqual( + "from 'Dtoes' where vector.search(embedding.text_i8(VectorField), embedding.forDoc($p0))", + q2._to_string(), + ) + + # withField + q3 = session.query(object_type=Dto).vector_search_with_field("VectorField", [0.1, 0.2, 0.3]) + self.assertEqual("from 'Dtoes' where vector.search(VectorField, $p0)", q3._to_string()) + + q4 = session.query(object_type=Dto).vector_search_with_text_field("VectorField", "hello") + self.assertEqual("from 'Dtoes' where vector.search(VectorField, $p0)", q4._to_string()) + + q5 = session.query(object_type=Dto).vector_search_with_i8_field("VectorField", [1, 2, 3]) + self.assertEqual("from 'Dtoes' where vector.search(VectorField, $p0)", q5._to_string()) + + q6 = session.query(object_type=Dto).vector_search_with_i1_field("VectorField", [0, 1, 0]) + self.assertEqual("from 'Dtoes' where vector.search(VectorField, $p0)", q6._to_string()) + + # with base64 + q7 = session.query(object_type=Dto).vector_search_with_base64("VectorField", "abcd==") + self.assertEqual("from 'Dtoes' where vector.search(VectorField, $p0)", q7._to_string()) + + q8 = session.query(object_type=Dto).vector_search_with_base64_i8("VectorField", "abcd==") + self.assertEqual("from 'Dtoes' where vector.search(embedding.i8(VectorField), $p0)", q8._to_string()) + + q9 = session.query(object_type=Dto).vector_search_with_base64_i1("VectorField", "abcd==") + self.assertEqual("from 'Dtoes' where vector.search(embedding.i1(VectorField), $p0)", q9._to_string()) + + # ability to search in base64 + q10 = session.query(object_type=Dto).vector_search("VectorField", "abcd==") + self.assertEqual("from 'Dtoes' where vector.search(VectorField, $p0)", q10._to_string()) + + q11 = session.query(object_type=Dto).vector_search_i8("VectorField", "abcd==") + self.assertEqual("from 'Dtoes' where vector.search(embedding.i8(VectorField), $p0)", q11._to_string()) + + q12 = session.query(object_type=Dto).vector_search_i1("VectorField", "abcd==") + self.assertEqual("from 'Dtoes' where vector.search(embedding.i1(VectorField), $p0)", q12._to_string()) + + q13 = session.query(object_type=Dto).vector_search_with_field("VectorField", "abcd==") + self.assertEqual("from 'Dtoes' where vector.search(VectorField, $p0)", q13._to_string()) + + q14 = session.query(object_type=Dto).vector_search_with_i8_field("VectorField", "abcd==") + self.assertEqual("from 'Dtoes' where vector.search(VectorField, $p0)", q14._to_string()) + + q15 = session.query(object_type=Dto).vector_search_with_i1_field("VectorField", "abcd==") + self.assertEqual("from 'Dtoes' where vector.search(VectorField, $p0)", q15._to_string()) + + # embeddingTaskIdentifier + q16 = session.query(object_type=Dto).vector_search_text( + "VectorField", "hello", embedding_generation_task_identifier="my-ai-task" + ) + self.assertEqual( + "from 'Dtoes' where vector.search(embedding.text(VectorField, ai.task('my-ai-task')), $p0)", + q16._to_string(), + ) + + q17 = session.query(object_type=Dto).vector_search_text_for_document( + "VectorField", "hello", embedding_generation_task_identifier="my-ai-task" + ) + self.assertEqual( + "from 'Dtoes' where vector.search(embedding.text(VectorField, ai.task('my-ai-task')), embedding.forDoc($p0))", + q17._to_string(), + ) + def test_embedding_dimensions_check(self): with self.store.open_session() as session: dto1 = Dto(embedding_singles=[0.5, -1.0])