diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/base1.py b/unstract/sdk1/src/unstract/sdk1/adapters/base1.py index f7d98f0fa1..920a561633 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/base1.py +++ b/unstract/sdk1/src/unstract/sdk1/adapters/base1.py @@ -1005,3 +1005,30 @@ def validate_model(adapter_metadata: dict[str, "Any"]) -> str: return model else: return f"ollama/{model}" + + +class GeminiEmbeddingParameters(BaseEmbeddingParameters): + """See https://docs.litellm.ai/docs/providers/gemini.""" + + api_key: str + embed_batch_size: int | None = None + + @staticmethod + def validate(adapter_metadata: dict[str, "Any"]) -> dict[str, "Any"]: + metadata_copy = {**adapter_metadata} + metadata_copy["model"] = GeminiEmbeddingParameters.validate_model(metadata_copy) + + return GeminiEmbeddingParameters(**metadata_copy).model_dump() + + @staticmethod + def validate_model(adapter_metadata: dict[str, "Any"]) -> str: + raw_model = adapter_metadata.get("model") + model = raw_model.strip() if isinstance(raw_model, str) else "" + if not model: + raise ValueError( + "The 'model' field is required for the Gemini embedding adapter. " + "Example: 'gemini/text-embedding-004'" + ) + if not model.startswith("gemini/"): + model = f"gemini/{model}" + return model diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/__init__.py b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/__init__.py index d1c5e4935a..3f7de6e916 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/__init__.py +++ b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/__init__.py @@ -3,6 +3,7 @@ from unstract.sdk1.adapters.base1 import register_adapters from unstract.sdk1.adapters.embedding1.azure_openai import AzureOpenAIEmbeddingAdapter from unstract.sdk1.adapters.embedding1.bedrock import AWSBedrockEmbeddingAdapter +from unstract.sdk1.adapters.embedding1.gemini import GeminiEmbeddingAdapter from unstract.sdk1.adapters.embedding1.ollama import OllamaEmbeddingAdapter from unstract.sdk1.adapters.embedding1.openai import OpenAIEmbeddingAdapter from unstract.sdk1.adapters.embedding1.vertexai import VertexAIEmbeddingAdapter @@ -16,6 +17,7 @@ "adapters", "AzureOpenAIEmbeddingAdapter", "AWSBedrockEmbeddingAdapter", + "GeminiEmbeddingAdapter", "OpenAIEmbeddingAdapter", "VertexAIEmbeddingAdapter", "OllamaEmbeddingAdapter", diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/gemini.py b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/gemini.py new file mode 100644 index 0000000000..ee6d1db87e --- /dev/null +++ b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/gemini.py @@ -0,0 +1,40 @@ +from typing import Any + +from unstract.sdk1.adapters.base1 import BaseAdapter, GeminiEmbeddingParameters +from unstract.sdk1.adapters.enums import AdapterTypes + + +class GeminiEmbeddingAdapter(GeminiEmbeddingParameters, BaseAdapter): + @staticmethod + def get_id() -> str: + return "gemini|5c2a36b8-0b8e-4f26-82c0-9f3b564cb066" + + @staticmethod + def get_metadata() -> dict[str, Any]: + return { + "name": "Gemini", + "version": "1.0.0", + "adapter": GeminiEmbeddingAdapter, + "description": "Gemini embedding adapter", + "is_active": True, + } + + @staticmethod + def get_name() -> str: + return "Gemini" + + @staticmethod + def get_description() -> str: + return "Gemini embedding adapter" + + @staticmethod + def get_provider() -> str: + return "gemini" + + @staticmethod + def get_icon() -> str: + return "/icons/adapter-icons/Gemini.png" + + @staticmethod + def get_adapter_type() -> AdapterTypes: + return AdapterTypes.EMBEDDING diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/gemini.json b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/gemini.json new file mode 100644 index 0000000000..976c18438b --- /dev/null +++ b/unstract/sdk1/src/unstract/sdk1/adapters/embedding1/static/gemini.json @@ -0,0 +1,44 @@ +{ + "title": "Gemini Embedding", + "type": "object", + "required": [ + "adapter_name", + "api_key", + "model" + ], + "properties": { + "adapter_name": { + "type": "string", + "title": "Name", + "default": "", + "description": "Provide a unique name for this adapter instance. Example: gemini-emb-1" + }, + "model": { + "type": "string", + "title": "Model", + "default": "gemini/text-embedding-004", + "description": "Provide the name of the model. The gemini/ prefix will be added automatically if omitted. Recommended: gemini/text-embedding-004" + }, + "api_key": { + "type": "string", + "title": "API Key", + "default": "", + "format": "password" + }, + "embed_batch_size": { + "type": "number", + "minimum": 1, + "multipleOf": 1, + "title": "Embed Batch Size", + "description": "Number of texts to embed in a single batch. Leave empty to use the system default." + }, + "timeout": { + "type": "number", + "minimum": 0, + "multipleOf": 1, + "title": "Timeout", + "default": 240, + "description": "Timeout in seconds" + } + } +} diff --git a/unstract/sdk1/src/unstract/sdk1/embedding.py b/unstract/sdk1/src/unstract/sdk1/embedding.py index e54a093393..0677473852 100644 --- a/unstract/sdk1/src/unstract/sdk1/embedding.py +++ b/unstract/sdk1/src/unstract/sdk1/embedding.py @@ -94,7 +94,7 @@ def __init__( self.platform_kwargs: dict[str, object] = kwargs self.kwargs: dict[str, object] = self.adapter.validate(self._adapter_metadata) self._cost_model: str | None = self.kwargs.pop("cost_model", None) - except ValidationError as e: + except (ValidationError, ValueError) as e: raise SdkError("Invalid embedding adapter metadata: " + str(e)) from e # Test connection - wrap in error handling diff --git a/unstract/sdk1/tests/test_gemini_embedding.py b/unstract/sdk1/tests/test_gemini_embedding.py new file mode 100644 index 0000000000..b43a16262a --- /dev/null +++ b/unstract/sdk1/tests/test_gemini_embedding.py @@ -0,0 +1,152 @@ +import json + +import pytest +from unstract.sdk1.adapters.embedding1.gemini import GeminiEmbeddingAdapter +from unstract.sdk1.adapters.enums import AdapterTypes + + +class TestGeminiEmbeddingAdapter: + def test_adapter_registration(self) -> None: + from unstract.sdk1.adapters.embedding1 import adapters + + gemini_ids = [k for k in adapters if "gemini" in k.lower()] + assert len(gemini_ids) == 1 + + def test_get_id_format(self) -> None: + adapter_id = GeminiEmbeddingAdapter.get_id() + assert adapter_id.startswith("gemini|") + # Standard UUID-4 with hyphens is 36 characters + uuid_part = adapter_id.split("|")[1] + assert len(uuid_part) == 36 + + def test_get_adapter_type(self) -> None: + assert GeminiEmbeddingAdapter.get_adapter_type() == AdapterTypes.EMBEDDING + + def test_get_name(self) -> None: + assert GeminiEmbeddingAdapter.get_name() == "Gemini" + + def test_get_provider(self) -> None: + assert GeminiEmbeddingAdapter.get_provider() == "gemini" + + def test_json_schema_loads(self) -> None: + schema = json.loads(GeminiEmbeddingAdapter.get_json_schema()) + assert isinstance(schema, dict) + assert "title" in schema + assert "properties" in schema + assert schema["title"] == "Gemini Embedding" + + def test_json_schema_required_fields(self) -> None: + schema = json.loads(GeminiEmbeddingAdapter.get_json_schema()) + assert set(schema["required"]) == {"adapter_name", "api_key", "model"} + + def test_json_schema_no_batch_size_default(self) -> None: + schema = json.loads(GeminiEmbeddingAdapter.get_json_schema()) + assert "default" not in schema["properties"]["embed_batch_size"] + + def test_json_schema_api_key_password_format(self) -> None: + schema = json.loads(GeminiEmbeddingAdapter.get_json_schema()) + assert schema["properties"]["api_key"]["format"] == "password" + + def test_json_schema_model_default(self) -> None: + schema = json.loads(GeminiEmbeddingAdapter.get_json_schema()) + assert schema["properties"]["model"]["default"] == "gemini/text-embedding-004" + + def test_validate_model_adds_prefix(self) -> None: + meta = {"model": "text-embedding-004", "api_key": "test"} + result = GeminiEmbeddingAdapter.validate_model(meta) + assert result == "gemini/text-embedding-004" + + def test_validate_model_idempotent(self) -> None: + meta = {"model": "gemini/text-embedding-004", "api_key": "test"} + result = GeminiEmbeddingAdapter.validate_model(meta) + assert result == "gemini/text-embedding-004" + + def test_validate_model_does_not_mutate_input(self) -> None: + meta = {"model": "text-embedding-004", "api_key": "test"} + GeminiEmbeddingAdapter.validate_model(meta) + assert meta["model"] == "text-embedding-004" + + def test_validate_does_not_mutate_input(self) -> None: + meta = {"model": "text-embedding-004", "api_key": "test-key"} + original_model = meta["model"] + GeminiEmbeddingAdapter.validate(meta) + assert meta["model"] == original_model + + def test_validate_model_empty_string_raises(self) -> None: + meta = {"model": "", "api_key": "test"} + with pytest.raises(ValueError, match="model.*required"): + GeminiEmbeddingAdapter.validate_model(meta) + + def test_validate_model_whitespace_only_raises(self) -> None: + meta = {"model": " ", "api_key": "test"} + with pytest.raises(ValueError, match="model.*required"): + GeminiEmbeddingAdapter.validate_model(meta) + + def test_validate_model_none_raises(self) -> None: + meta = {"model": None, "api_key": "test"} + with pytest.raises(ValueError, match="model.*required"): + GeminiEmbeddingAdapter.validate_model(meta) + + def test_validate_model_missing_key_raises(self) -> None: + meta = {"api_key": "test"} + with pytest.raises(ValueError, match="model.*required"): + GeminiEmbeddingAdapter.validate_model(meta) + + def test_validate_empty_model_raises(self) -> None: + meta = {"model": "", "api_key": "test-key"} + with pytest.raises(ValueError, match="model.*required"): + GeminiEmbeddingAdapter.validate(meta) + + def test_validate_none_model_raises(self) -> None: + meta = {"model": None, "api_key": "test-key"} + with pytest.raises(ValueError, match="model.*required"): + GeminiEmbeddingAdapter.validate(meta) + + def test_validate_missing_api_key_raises(self) -> None: + from pydantic import ValidationError + + meta = {"model": "gemini/text-embedding-004"} + with pytest.raises(ValidationError): + GeminiEmbeddingAdapter.validate(meta) + + def test_validate_calls_validate_model(self) -> None: + meta = {"model": "text-embedding-004", "api_key": "test-key"} + validated = GeminiEmbeddingAdapter.validate(meta) + assert validated["model"] == "gemini/text-embedding-004" + + def test_validate_embed_batch_size_none_by_default(self) -> None: + meta = {"model": "gemini/text-embedding-004", "api_key": "test-key"} + validated = GeminiEmbeddingAdapter.validate(meta) + assert validated["embed_batch_size"] is None + + def test_validate_embed_batch_size_preserved(self) -> None: + meta = { + "model": "gemini/text-embedding-004", + "api_key": "test-key", + "embed_batch_size": 50, + } + validated = GeminiEmbeddingAdapter.validate(meta) + assert validated["embed_batch_size"] == 50 + + def test_validate_strips_extra_fields(self) -> None: + meta = { + "model": "gemini/text-embedding-004", + "api_key": "test-key", + "adapter_name": "my-adapter", + "unknown_field": "should_be_dropped", + } + validated = GeminiEmbeddingAdapter.validate(meta) + assert "adapter_name" not in validated + assert "unknown_field" not in validated + + def test_validate_includes_base_fields(self) -> None: + meta = {"model": "gemini/text-embedding-004", "api_key": "test-key"} + validated = GeminiEmbeddingAdapter.validate(meta) + assert "timeout" in validated + assert "max_retries" in validated + + def test_metadata(self) -> None: + metadata = GeminiEmbeddingAdapter.get_metadata() + assert metadata["name"] == "Gemini" + assert metadata["is_active"] is True + assert metadata["adapter"] is GeminiEmbeddingAdapter