diff --git a/README.md b/README.md index a3ab108c6b..9e282b3e10 100644 --- a/README.md +++ b/README.md @@ -176,9 +176,10 @@ Also see [architecture](docs/ARCHITECTURE.md). | Provider | Status | Provider | Status | |----------|--------|----------|--------| | OpenAI | ✅ | Azure OpenAI | ✅ | -| Anthropic Claude | ✅ | Google Gemini | ✅ | -| AWS Bedrock | ✅ | Mistral AI | ✅ | -| Ollama (local) | ✅ | Anyscale | ✅ | +| OpenAI Compatible | ✅ | Anthropic Claude | ✅ | +| AWS Bedrock | ✅ | Google Gemini | ✅ | +| Ollama (local) | ✅ | Mistral AI | ✅ | +| Anyscale | ✅ | | | ### Vector Databases diff --git a/frontend/public/icons/adapter-icons/OpenAICompatible.png b/frontend/public/icons/adapter-icons/OpenAICompatible.png new file mode 100644 index 0000000000..ec23189d9a Binary files /dev/null and b/frontend/public/icons/adapter-icons/OpenAICompatible.png differ diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/base1.py b/unstract/sdk1/src/unstract/sdk1/adapters/base1.py index 8ad721c3d4..427e1d361b 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/base1.py +++ b/unstract/sdk1/src/unstract/sdk1/adapters/base1.py @@ -225,6 +225,30 @@ def validate_model(adapter_metadata: dict[str, "Any"]) -> str: return f"openai/{model}" +class OpenAICompatibleLLMParameters(BaseChatCompletionParameters): + """See https://docs.litellm.ai/docs/providers/openai_compatible/.""" + + api_key: str | None = None + api_base: str + + @staticmethod + def validate(adapter_metadata: dict[str, "Any"]) -> dict[str, "Any"]: + adapter_metadata["model"] = OpenAICompatibleLLMParameters.validate_model( + adapter_metadata + ) + api_key = adapter_metadata.get("api_key") + if isinstance(api_key, str) and not api_key.strip(): + adapter_metadata["api_key"] = None + return OpenAICompatibleLLMParameters(**adapter_metadata).model_dump() + + @staticmethod + def validate_model(adapter_metadata: dict[str, "Any"]) -> str: + model = adapter_metadata.get("model", "") + if model.startswith("custom_openai/"): + return model + return f"custom_openai/{model}" + + class AzureOpenAILLMParameters(BaseChatCompletionParameters): """See https://docs.litellm.ai/docs/providers/azure/#completion---using-azure_ad_token-api_base-api_version.""" diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/llm1/__init__.py b/unstract/sdk1/src/unstract/sdk1/adapters/llm1/__init__.py index c23a33390a..1da3590f51 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/llm1/__init__.py +++ b/unstract/sdk1/src/unstract/sdk1/adapters/llm1/__init__.py @@ -8,6 +8,7 @@ from unstract.sdk1.adapters.llm1.bedrock import AWSBedrockLLMAdapter from unstract.sdk1.adapters.llm1.ollama import OllamaLLMAdapter from unstract.sdk1.adapters.llm1.openai import OpenAILLMAdapter +from unstract.sdk1.adapters.llm1.openai_compatible import OpenAICompatibleLLMAdapter from unstract.sdk1.adapters.llm1.vertexai import VertexAILLMAdapter adapters: dict[str, dict[str, Any]] = {} @@ -22,5 +23,6 @@ "AzureOpenAILLMAdapter", "OllamaLLMAdapter", "OpenAILLMAdapter", + "OpenAICompatibleLLMAdapter", "VertexAILLMAdapter", ] diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/llm1/openai_compatible.py b/unstract/sdk1/src/unstract/sdk1/adapters/llm1/openai_compatible.py new file mode 100644 index 0000000000..3cb3ceafc4 --- /dev/null +++ b/unstract/sdk1/src/unstract/sdk1/adapters/llm1/openai_compatible.py @@ -0,0 +1,46 @@ +from typing import Any + +from unstract.sdk1.adapters.base1 import BaseAdapter, OpenAICompatibleLLMParameters +from unstract.sdk1.adapters.enums import AdapterTypes + +DESCRIPTION = ( + "Adapter for servers that implement the OpenAI Chat Completions API " + "(vLLM, LM Studio, self-hosted gateways, and third-party providers). " + "Use OpenAI for the official OpenAI service." +) + + +class OpenAICompatibleLLMAdapter(OpenAICompatibleLLMParameters, BaseAdapter): + @staticmethod + def get_id() -> str: + return "openaicompatible|b6d10f33-2c41-49fc-a8c2-58d2b247fc09" + + @staticmethod + def get_metadata() -> dict[str, Any]: + return { + "name": "OpenAI Compatible", + "version": "1.0.0", + "adapter": OpenAICompatibleLLMAdapter, + "description": DESCRIPTION, + "is_active": True, + } + + @staticmethod + def get_name() -> str: + return "OpenAI Compatible" + + @staticmethod + def get_description() -> str: + return DESCRIPTION + + @staticmethod + def get_provider() -> str: + return "custom_openai" + + @staticmethod + def get_icon() -> str: + return "/icons/adapter-icons/OpenAICompatible.png" + + @staticmethod + def get_adapter_type() -> AdapterTypes: + return AdapterTypes.LLM diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/llm1/static/custom_openai.json b/unstract/sdk1/src/unstract/sdk1/adapters/llm1/static/custom_openai.json new file mode 100644 index 0000000000..f4720c9e92 --- /dev/null +++ b/unstract/sdk1/src/unstract/sdk1/adapters/llm1/static/custom_openai.json @@ -0,0 +1,61 @@ +{ + "title": "OpenAI Compatible LLM", + "type": "object", + "required": [ + "adapter_name", + "api_base" + ], + "properties": { + "adapter_name": { + "type": "string", + "title": "Name", + "default": "", + "description": "Provide a unique name for this adapter instance. Example: compatible-gateway-1" + }, + "api_key": { + "type": [ + "string", + "null" + ], + "title": "API Key", + "format": "password", + "description": "API key for your OpenAI-compatible endpoint. Leave empty if the endpoint does not require one." + }, + "model": { + "type": "string", + "title": "Model", + "default": "gpt-4o-mini", + "description": "The model name expected by your OpenAI-compatible endpoint. Examples: gateway-model, gpt-4o-mini, openai/gpt-4o" + }, + "api_base": { + "type": "string", + "format": "url", + "title": "API Base", + "description": "Base URL for the OpenAI-compatible endpoint. Examples: https://gateway.example.com/v1, https://llm.example.net/openai/v1" + }, + "max_tokens": { + "type": "number", + "minimum": 0, + "multipleOf": 1, + "title": "Maximum Output Tokens", + "default": 4096, + "description": "Maximum number of output tokens to limit LLM replies. Leave it empty to use the provider default." + }, + "max_retries": { + "type": "number", + "minimum": 0, + "multipleOf": 1, + "title": "Max Retries", + "default": 5, + "description": "The maximum number of times to retry a request if it fails." + }, + "timeout": { + "type": "number", + "minimum": 0, + "multipleOf": 1, + "title": "Timeout", + "default": 900, + "description": "Timeout in seconds." + } + } +} diff --git a/unstract/sdk1/src/unstract/sdk1/llm.py b/unstract/sdk1/src/unstract/sdk1/llm.py index 8ff29a89d5..ce5819c9ec 100644 --- a/unstract/sdk1/src/unstract/sdk1/llm.py +++ b/unstract/sdk1/src/unstract/sdk1/llm.py @@ -539,10 +539,22 @@ def _record_usage( usage: Mapping[str, int] | None, llm_api: str, ) -> None: - prompt_tokens = token_counter(model=model, messages=messages) usage_data: Mapping[str, int] = usage or {} + prompt_tokens = usage_data.get("prompt_tokens") + if prompt_tokens is None: + try: + prompt_tokens = token_counter(model=model, messages=messages) + except Exception as e: + prompt_tokens = 0 + logger.warning( + "[sdk1][LLM][%s][%s] Failed to estimate prompt tokens; " + "recording 0 prompt tokens for usage audit: %s", + model, + llm_api, + e, + ) all_tokens = TokenCounterCompat( - prompt_tokens=usage_data.get("prompt_tokens", 0), + prompt_tokens=prompt_tokens or 0, completion_tokens=usage_data.get("completion_tokens", 0), total_tokens=usage_data.get("total_tokens", 0), ) diff --git a/unstract/sdk1/tests/test_openai_compatible_adapter.py b/unstract/sdk1/tests/test_openai_compatible_adapter.py new file mode 100644 index 0000000000..53eade4601 --- /dev/null +++ b/unstract/sdk1/tests/test_openai_compatible_adapter.py @@ -0,0 +1,192 @@ +import json +from functools import lru_cache +from importlib import import_module +from unittest.mock import MagicMock, patch + +from unstract.sdk1.adapters.base1 import OpenAICompatibleLLMParameters +from unstract.sdk1.adapters.constants import Common +from unstract.sdk1.adapters.llm1 import adapters +from unstract.sdk1.adapters.llm1.openai_compatible import OpenAICompatibleLLMAdapter + +OPENAI_COMPATIBLE_DESCRIPTION = ( + "Adapter for servers that implement the OpenAI Chat Completions API " + "(vLLM, LM Studio, self-hosted gateways, and third-party providers). " + "Use OpenAI for the official OpenAI service." +) + + +@lru_cache(maxsize=1) +def _load_llm_module() -> object: + import sys + from types import ModuleType + + with patch.dict( + sys.modules, + { + # Stub python-magic so importing LLM does not depend on libmagic + # being available in the test environment. + "magic": ModuleType("magic") + }, + ): + return import_module("unstract.sdk1.llm") + + +def _load_llm_class() -> type: + return _load_llm_module().LLM + + +def test_openai_compatible_adapter_is_registered() -> None: + adapter_id = OpenAICompatibleLLMAdapter.get_id() + + assert adapter_id in adapters + assert adapters[adapter_id][Common.MODULE] is OpenAICompatibleLLMAdapter + + +def test_openai_compatible_validate_prefixes_model() -> None: + validated = OpenAICompatibleLLMParameters.validate( + { + "api_base": "https://gateway.example.com/v1", + "api_key": "test-key", + "model": "gateway-model", + } + ) + + assert validated["model"] == "custom_openai/gateway-model" + + +def test_openai_compatible_validate_preserves_prefixed_model() -> None: + validated = OpenAICompatibleLLMParameters.validate( + { + "api_base": "https://gateway.example.com/v1", + "model": "custom_openai/openai/gpt-4o", + } + ) + + assert validated["model"] == "custom_openai/openai/gpt-4o" + assert validated["api_key"] is None + + +def test_openai_compatible_validate_normalizes_blank_api_key_to_none() -> None: + validated = OpenAICompatibleLLMParameters.validate( + { + "api_base": "https://gateway.example.com/v1", + "api_key": " ", + "model": "gateway-model", + } + ) + + assert validated["api_key"] is None + + +def test_openai_compatible_schema_is_loadable() -> None: + schema = json.loads(OpenAICompatibleLLMAdapter.get_json_schema()) + + assert schema["title"] == "OpenAI Compatible LLM" + assert schema["properties"]["api_key"]["type"] == ["string", "null"] + assert "gateway-model" in schema["properties"]["model"]["description"] + assert "ERNIE" not in schema["properties"]["model"]["description"] + assert "qianfan" not in schema["properties"]["api_base"]["description"].lower() + assert "default" not in schema["properties"]["api_base"] + + +def test_openai_compatible_adapter_uses_distinct_description_and_icon() -> None: + metadata = OpenAICompatibleLLMAdapter.get_metadata() + + assert OpenAICompatibleLLMAdapter.get_description() == OPENAI_COMPATIBLE_DESCRIPTION + assert metadata["description"] == OPENAI_COMPATIBLE_DESCRIPTION + assert OpenAICompatibleLLMAdapter.get_icon() == ( + "/icons/adapter-icons/OpenAICompatible.png" + ) + + +def test_record_usage_uses_reported_prompt_tokens_without_estimating() -> None: + llm_module = _load_llm_module() + llm_cls = llm_module.LLM + + llm = llm_cls.__new__(llm_cls) + llm._platform_api_key = "platform-key" + llm.platform_kwargs = {"run_id": "run-1"} + llm.adapter = MagicMock() + llm.adapter.get_provider.return_value = "custom_openai" + + with ( + patch.object(llm_module, "token_counter") as mock_token_counter, + patch.object(llm_module, "Audit") as mock_audit, + ): + llm._record_usage( + model="custom_openai/gateway-model", + messages=[{"role": "user", "content": "hello"}], + usage={"prompt_tokens": 3, "completion_tokens": 4, "total_tokens": 7}, + llm_api="complete", + ) + + mock_token_counter.assert_not_called() + mock_audit.return_value.push_usage_data.assert_called_once() + assert ( + mock_audit.return_value.push_usage_data.call_args.kwargs[ + "token_counter" + ].prompt_llm_token_count + == 3 + ) + + +def test_record_usage_tolerates_unmapped_models_without_prompt_tokens() -> None: + llm_module = _load_llm_module() + llm_cls = llm_module.LLM + + llm = llm_cls.__new__(llm_cls) + llm._platform_api_key = "platform-key" + llm.platform_kwargs = {"run_id": "run-1"} + llm.adapter = MagicMock() + llm.adapter.get_provider.return_value = "custom_openai" + + with ( + patch.object(llm_module, "token_counter", side_effect=Exception("unmapped")), + patch.object(llm_module, "Audit") as mock_audit, + patch.object(llm_module.logger, "warning") as mock_warning, + ): + llm._record_usage( + model="custom_openai/gateway-model", + messages=[{"role": "user", "content": "hello"}], + usage={"completion_tokens": 4, "total_tokens": 7}, + llm_api="complete", + ) + + mock_audit.return_value.push_usage_data.assert_called_once() + assert ( + mock_audit.return_value.push_usage_data.call_args.kwargs[ + "token_counter" + ].prompt_llm_token_count + == 0 + ) + assert "recording 0 prompt tokens for usage audit" in mock_warning.call_args.args[0] + + +def test_record_usage_uses_estimated_prompt_tokens_when_usage_has_none() -> None: + llm_module = _load_llm_module() + llm_cls = llm_module.LLM + + llm = llm_cls.__new__(llm_cls) + llm._platform_api_key = "platform-key" + llm.platform_kwargs = {"run_id": "run-1"} + llm.adapter = MagicMock() + llm.adapter.get_provider.return_value = "custom_openai" + + with ( + patch.object(llm_module, "token_counter", return_value=9) as mock_token_counter, + patch.object(llm_module, "Audit") as mock_audit, + ): + llm._record_usage( + model="custom_openai/gateway-model", + messages=[{"role": "user", "content": "hello"}], + usage={"prompt_tokens": None, "completion_tokens": 4, "total_tokens": 13}, + llm_api="complete", + ) + + mock_token_counter.assert_called_once() + assert ( + mock_audit.return_value.push_usage_data.call_args.kwargs[ + "token_counter" + ].prompt_llm_token_count + == 9 + )