Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions unstract/sdk1/src/unstract/sdk1/adapters/base1.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,3 +1005,30 @@ def validate_model(adapter_metadata: dict[str, "Any"]) -> str:
return model
else:
return f"ollama/{model}"


class GeminiEmbeddingParameters(BaseEmbeddingParameters):
"""See https://docs.litellm.ai/docs/providers/gemini."""

api_key: str
embed_batch_size: int | None = None

@staticmethod
def validate(adapter_metadata: dict[str, "Any"]) -> dict[str, "Any"]:
metadata_copy = {**adapter_metadata}
metadata_copy["model"] = GeminiEmbeddingParameters.validate_model(metadata_copy)

return GeminiEmbeddingParameters(**metadata_copy).model_dump()

@staticmethod
def validate_model(adapter_metadata: dict[str, "Any"]) -> str:
raw_model = adapter_metadata.get("model")
model = raw_model.strip() if isinstance(raw_model, str) else ""
if not model:
raise ValueError(
"The 'model' field is required for the Gemini embedding adapter. "
"Example: 'gemini/text-embedding-004'"
)
if not model.startswith("gemini/"):
model = f"gemini/{model}"
return model
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from unstract.sdk1.adapters.base1 import register_adapters
from unstract.sdk1.adapters.embedding1.azure_openai import AzureOpenAIEmbeddingAdapter
from unstract.sdk1.adapters.embedding1.bedrock import AWSBedrockEmbeddingAdapter
from unstract.sdk1.adapters.embedding1.gemini import GeminiEmbeddingAdapter
from unstract.sdk1.adapters.embedding1.ollama import OllamaEmbeddingAdapter
from unstract.sdk1.adapters.embedding1.openai import OpenAIEmbeddingAdapter
from unstract.sdk1.adapters.embedding1.vertexai import VertexAIEmbeddingAdapter
Expand All @@ -16,6 +17,7 @@
"adapters",
"AzureOpenAIEmbeddingAdapter",
"AWSBedrockEmbeddingAdapter",
"GeminiEmbeddingAdapter",
"OpenAIEmbeddingAdapter",
"VertexAIEmbeddingAdapter",
"OllamaEmbeddingAdapter",
Expand Down
40 changes: 40 additions & 0 deletions unstract/sdk1/src/unstract/sdk1/adapters/embedding1/gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Any

from unstract.sdk1.adapters.base1 import BaseAdapter, GeminiEmbeddingParameters
from unstract.sdk1.adapters.enums import AdapterTypes


class GeminiEmbeddingAdapter(GeminiEmbeddingParameters, BaseAdapter):
@staticmethod
def get_id() -> str:
return "gemini|5c2a36b8-0b8e-4f26-82c0-9f3b564cb066"

@staticmethod
def get_metadata() -> dict[str, Any]:
return {
"name": "Gemini",
"version": "1.0.0",
"adapter": GeminiEmbeddingAdapter,
"description": "Gemini embedding adapter",
"is_active": True,
}

@staticmethod
def get_name() -> str:
return "Gemini"

@staticmethod
def get_description() -> str:
return "Gemini embedding adapter"

@staticmethod
def get_provider() -> str:
return "gemini"

@staticmethod
def get_icon() -> str:
return "/icons/adapter-icons/Gemini.png"

@staticmethod
def get_adapter_type() -> AdapterTypes:
return AdapterTypes.EMBEDDING
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
{
"title": "Gemini Embedding",
"type": "object",
"required": [
"adapter_name",
"api_key",
"model"
],
"properties": {
"adapter_name": {
"type": "string",
"title": "Name",
"default": "",
"description": "Provide a unique name for this adapter instance. Example: gemini-emb-1"
},
"model": {
"type": "string",
"title": "Model",
"default": "gemini/text-embedding-004",
"description": "Provide the name of the model. The gemini/ prefix will be added automatically if omitted. Recommended: gemini/text-embedding-004"
},
"api_key": {
"type": "string",
"title": "API Key",
"default": "",
"format": "password"
},
"embed_batch_size": {
"type": "number",
"minimum": 1,
"multipleOf": 1,
"title": "Embed Batch Size",
"description": "Number of texts to embed in a single batch. Leave empty to use the system default."
},
"timeout": {
"type": "number",
"minimum": 0,
"multipleOf": 1,
"title": "Timeout",
"default": 240,
"description": "Timeout in seconds"
}
}
}
2 changes: 1 addition & 1 deletion unstract/sdk1/src/unstract/sdk1/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
self.platform_kwargs: dict[str, object] = kwargs
self.kwargs: dict[str, object] = self.adapter.validate(self._adapter_metadata)
self._cost_model: str | None = self.kwargs.pop("cost_model", None)
except ValidationError as e:
except (ValidationError, ValueError) as e:

Check warning on line 97 in unstract/sdk1/src/unstract/sdk1/embedding.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Remove this redundant Exception class; it derives from another which is already caught.

See more on https://sonarcloud.io/project/issues?id=Zipstack_unstract&issues=AZ16tN47yCqiNXJ_l20C&open=AZ16tN47yCqiNXJ_l20C&pullRequest=1891
raise SdkError("Invalid embedding adapter metadata: " + str(e)) from e

# Test connection - wrap in error handling
Expand Down
152 changes: 152 additions & 0 deletions unstract/sdk1/tests/test_gemini_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import json

import pytest
from unstract.sdk1.adapters.embedding1.gemini import GeminiEmbeddingAdapter
from unstract.sdk1.adapters.enums import AdapterTypes


class TestGeminiEmbeddingAdapter:
def test_adapter_registration(self) -> None:
from unstract.sdk1.adapters.embedding1 import adapters

gemini_ids = [k for k in adapters if "gemini" in k.lower()]
assert len(gemini_ids) == 1

def test_get_id_format(self) -> None:
adapter_id = GeminiEmbeddingAdapter.get_id()
assert adapter_id.startswith("gemini|")
# Standard UUID-4 with hyphens is 36 characters
uuid_part = adapter_id.split("|")[1]
assert len(uuid_part) == 36

def test_get_adapter_type(self) -> None:
assert GeminiEmbeddingAdapter.get_adapter_type() == AdapterTypes.EMBEDDING

def test_get_name(self) -> None:
assert GeminiEmbeddingAdapter.get_name() == "Gemini"

def test_get_provider(self) -> None:
assert GeminiEmbeddingAdapter.get_provider() == "gemini"

def test_json_schema_loads(self) -> None:
schema = json.loads(GeminiEmbeddingAdapter.get_json_schema())
assert isinstance(schema, dict)
assert "title" in schema
assert "properties" in schema
assert schema["title"] == "Gemini Embedding"

def test_json_schema_required_fields(self) -> None:
schema = json.loads(GeminiEmbeddingAdapter.get_json_schema())
assert set(schema["required"]) == {"adapter_name", "api_key", "model"}

def test_json_schema_no_batch_size_default(self) -> None:
schema = json.loads(GeminiEmbeddingAdapter.get_json_schema())
assert "default" not in schema["properties"]["embed_batch_size"]

def test_json_schema_api_key_password_format(self) -> None:
schema = json.loads(GeminiEmbeddingAdapter.get_json_schema())
assert schema["properties"]["api_key"]["format"] == "password"

def test_json_schema_model_default(self) -> None:
schema = json.loads(GeminiEmbeddingAdapter.get_json_schema())
assert schema["properties"]["model"]["default"] == "gemini/text-embedding-004"

def test_validate_model_adds_prefix(self) -> None:
meta = {"model": "text-embedding-004", "api_key": "test"}
result = GeminiEmbeddingAdapter.validate_model(meta)
assert result == "gemini/text-embedding-004"

def test_validate_model_idempotent(self) -> None:
meta = {"model": "gemini/text-embedding-004", "api_key": "test"}
result = GeminiEmbeddingAdapter.validate_model(meta)
assert result == "gemini/text-embedding-004"

def test_validate_model_does_not_mutate_input(self) -> None:
meta = {"model": "text-embedding-004", "api_key": "test"}
GeminiEmbeddingAdapter.validate_model(meta)
assert meta["model"] == "text-embedding-004"

def test_validate_does_not_mutate_input(self) -> None:
meta = {"model": "text-embedding-004", "api_key": "test-key"}
original_model = meta["model"]
GeminiEmbeddingAdapter.validate(meta)
assert meta["model"] == original_model

def test_validate_model_empty_string_raises(self) -> None:
meta = {"model": "", "api_key": "test"}
with pytest.raises(ValueError, match="model.*required"):
GeminiEmbeddingAdapter.validate_model(meta)

def test_validate_model_whitespace_only_raises(self) -> None:
meta = {"model": " ", "api_key": "test"}
with pytest.raises(ValueError, match="model.*required"):
GeminiEmbeddingAdapter.validate_model(meta)

def test_validate_model_none_raises(self) -> None:
meta = {"model": None, "api_key": "test"}
with pytest.raises(ValueError, match="model.*required"):
GeminiEmbeddingAdapter.validate_model(meta)

def test_validate_model_missing_key_raises(self) -> None:
meta = {"api_key": "test"}
with pytest.raises(ValueError, match="model.*required"):
GeminiEmbeddingAdapter.validate_model(meta)

def test_validate_empty_model_raises(self) -> None:
meta = {"model": "", "api_key": "test-key"}
with pytest.raises(ValueError, match="model.*required"):
GeminiEmbeddingAdapter.validate(meta)

def test_validate_none_model_raises(self) -> None:
meta = {"model": None, "api_key": "test-key"}
with pytest.raises(ValueError, match="model.*required"):
GeminiEmbeddingAdapter.validate(meta)

def test_validate_missing_api_key_raises(self) -> None:
from pydantic import ValidationError

meta = {"model": "gemini/text-embedding-004"}
with pytest.raises(ValidationError):
GeminiEmbeddingAdapter.validate(meta)

def test_validate_calls_validate_model(self) -> None:
meta = {"model": "text-embedding-004", "api_key": "test-key"}
validated = GeminiEmbeddingAdapter.validate(meta)
assert validated["model"] == "gemini/text-embedding-004"

def test_validate_embed_batch_size_none_by_default(self) -> None:
meta = {"model": "gemini/text-embedding-004", "api_key": "test-key"}
validated = GeminiEmbeddingAdapter.validate(meta)
assert validated["embed_batch_size"] is None

def test_validate_embed_batch_size_preserved(self) -> None:
meta = {
"model": "gemini/text-embedding-004",
"api_key": "test-key",
"embed_batch_size": 50,
}
validated = GeminiEmbeddingAdapter.validate(meta)
assert validated["embed_batch_size"] == 50

def test_validate_strips_extra_fields(self) -> None:
meta = {
"model": "gemini/text-embedding-004",
"api_key": "test-key",
"adapter_name": "my-adapter",
"unknown_field": "should_be_dropped",
}
validated = GeminiEmbeddingAdapter.validate(meta)
assert "adapter_name" not in validated
assert "unknown_field" not in validated

def test_validate_includes_base_fields(self) -> None:
meta = {"model": "gemini/text-embedding-004", "api_key": "test-key"}
validated = GeminiEmbeddingAdapter.validate(meta)
assert "timeout" in validated
assert "max_retries" in validated

def test_metadata(self) -> None:
metadata = GeminiEmbeddingAdapter.get_metadata()
assert metadata["name"] == "Gemini"
assert metadata["is_active"] is True
assert metadata["adapter"] is GeminiEmbeddingAdapter