7 changes: 4 additions & 3 deletions README.md
@@ -176,9 +176,10 @@ Also see [architecture](docs/ARCHITECTURE.md).
 | Provider | Status | Provider | Status |
 |----------|--------|----------|--------|
 | OpenAI | ✅ | Azure OpenAI | ✅ |
-| Anthropic Claude | ✅ | Google Gemini | ✅ |
-| AWS Bedrock | ✅ | Mistral AI | ✅ |
-| Ollama (local) | ✅ | Anyscale | ✅ |
+| OpenAI Compatible | ✅ | Anthropic Claude | ✅ |
+| AWS Bedrock | ✅ | Google Gemini | ✅ |
+| Ollama (local) | ✅ | Mistral AI | ✅ |
+| Anyscale | ✅ | | |
 
 ### Vector Databases
 
24 changes: 24 additions & 0 deletions unstract/sdk1/src/unstract/sdk1/adapters/base1.py
@@ -225,6 +225,30 @@ def validate_model(adapter_metadata: dict[str, "Any"]) -> str:
return f"openai/{model}"


class OpenAICompatibleLLMParameters(BaseChatCompletionParameters):
"""See https://docs.litellm.ai/docs/providers/openai_compatible/."""

api_key: str | None = None
api_base: str

@staticmethod
def validate(adapter_metadata: dict[str, "Any"]) -> dict[str, "Any"]:
adapter_metadata["model"] = OpenAICompatibleLLMParameters.validate_model(
adapter_metadata
)
api_key = adapter_metadata.get("api_key")
if isinstance(api_key, str) and not api_key.strip():
adapter_metadata["api_key"] = None
return OpenAICompatibleLLMParameters(**adapter_metadata).model_dump()

@staticmethod
def validate_model(adapter_metadata: dict[str, "Any"]) -> str:
model = adapter_metadata.get("model", "")
if model.startswith("custom_openai/"):
return model
return f"custom_openai/{model}"


class AzureOpenAILLMParameters(BaseChatCompletionParameters):
"""See https://docs.litellm.ai/docs/providers/azure/#completion---using-azure_ad_token-api_base-api_version."""

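A quick sanity check of the normalization rules above, as a minimal sketch: it assumes BaseChatCompletionParameters fills in defaults for the remaining chat-completion fields, and the endpoint URL and model name are placeholders.

```python
from unstract.sdk1.adapters.base1 import OpenAICompatibleLLMParameters

# Whitespace-only keys collapse to None; unprefixed models gain "custom_openai/".
validated = OpenAICompatibleLLMParameters.validate(
    {
        "api_base": "https://gateway.example.com/v1",  # placeholder endpoint
        "api_key": "   ",
        "model": "gateway-model",  # placeholder model name
    }
)
assert validated["model"] == "custom_openai/gateway-model"
assert validated["api_key"] is None
```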
2 changes: 2 additions & 0 deletions unstract/sdk1/src/unstract/sdk1/adapters/llm1/__init__.py
@@ -8,6 +8,7 @@
 from unstract.sdk1.adapters.llm1.bedrock import AWSBedrockLLMAdapter
 from unstract.sdk1.adapters.llm1.ollama import OllamaLLMAdapter
 from unstract.sdk1.adapters.llm1.openai import OpenAILLMAdapter
+from unstract.sdk1.adapters.llm1.openai_compatible import OpenAICompatibleLLMAdapter
 from unstract.sdk1.adapters.llm1.vertexai import VertexAILLMAdapter
 
 adapters: dict[str, dict[str, Any]] = {}
@@ -22,5 +23,6 @@
"AzureOpenAILLMAdapter",
"OllamaLLMAdapter",
"OpenAILLMAdapter",
"OpenAICompatibleLLMAdapter",
"VertexAILLMAdapter",
]
46 changes: 46 additions & 0 deletions unstract/sdk1/src/unstract/sdk1/adapters/llm1/openai_compatible.py
@@ -0,0 +1,46 @@
from typing import Any

from unstract.sdk1.adapters.base1 import BaseAdapter, OpenAICompatibleLLMParameters
from unstract.sdk1.adapters.enums import AdapterTypes

DESCRIPTION = (
    "Adapter for servers that implement the OpenAI Chat Completions API "
    "(vLLM, LM Studio, self-hosted gateways, and third-party providers). "
    "Use OpenAI for the official OpenAI service."
)


class OpenAICompatibleLLMAdapter(OpenAICompatibleLLMParameters, BaseAdapter):
    @staticmethod
    def get_id() -> str:
        return "openaicompatible|b6d10f33-2c41-49fc-a8c2-58d2b247fc09"

    @staticmethod
    def get_metadata() -> dict[str, Any]:
        return {
            "name": "OpenAI Compatible",
            "version": "1.0.0",
            "adapter": OpenAICompatibleLLMAdapter,
            "description": DESCRIPTION,
            "is_active": True,
        }

    @staticmethod
    def get_name() -> str:
        return "OpenAI Compatible"

    @staticmethod
    def get_description() -> str:
        return DESCRIPTION

    @staticmethod
    def get_provider() -> str:
        return "custom_openai"

    @staticmethod
    def get_icon() -> str:
        return "/icons/adapter-icons/OpenAICompatible.png"

    @staticmethod
    def get_adapter_type() -> AdapterTypes:
        return AdapterTypes.LLM
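For context on where the `custom_openai` provider string ends up: a hedged sketch of the eventual LiteLLM call, assuming LiteLLM's OpenAI-compatible routing accepts `api_base` and `api_key` alongside the prefixed model name (per the LiteLLM docs linked from base1.py). The gateway URL and model name are placeholders.

```python
import litellm

# The "custom_openai/" prefix tells LiteLLM to treat the endpoint as a
# generic OpenAI-compatible server rather than the official OpenAI API.
response = litellm.completion(
    model="custom_openai/gateway-model",        # placeholder model name
    api_base="https://gateway.example.com/v1",  # placeholder endpoint
    api_key=None,  # keyless endpoint after blank-key normalization
    messages=[{"role": "user", "content": "ping"}],
)
print(response.choices[0].message.content)
```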
@@ -0,0 +1,61 @@
{
  "title": "OpenAI Compatible LLM",
  "type": "object",
  "required": [
    "adapter_name",
    "api_base"
  ],
"properties": {
"adapter_name": {
"type": "string",
"title": "Name",
"default": "",
"description": "Provide a unique name for this adapter instance. Example: compatible-gateway-1"
},
"api_key": {
"type": [
"string",
"null"
],
"title": "API Key",
"format": "password",
"description": "API key for your OpenAI-compatible endpoint. Leave empty if the endpoint does not require one."
Comment thread (Contributor):
[Medium] api_key misconfiguration is indistinguishable from intentional no-auth.

"type": ["string", "null"] with no minLength and no explicit toggle means:

  • A user who forgets to paste their key gets a 401 from the provider at request time, not a validation error.
  • null, "", and " " are all accepted and functionally identical.

Consider either:

  • Adding an explicit boolean flag (e.g. requires_auth, default true) so the schema can reject empty api_key when auth is expected; or
  • Tightening validation in OpenAICompatibleLLMParameters to coerce "" → None and treat None as "keyless endpoint" explicitly.

As-is, the description ("Leave empty if the endpoint does not require one") silently papers over a common misconfiguration.

Author:
Added the smaller validation step here: blank or whitespace-only api_key values are now normalized to None, so the keyless-endpoint case stays explicit without expanding the schema.
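For comparison, the reviewer's first alternative could look roughly like the sketch below at the pydantic level; `requires_auth` is a hypothetical field from the comment above, not something this PR adds.

```python
from pydantic import BaseModel, model_validator


class StrictCompatibleParams(BaseModel):
    """Hypothetical variant: reject an empty api_key unless auth is
    explicitly disabled, instead of silently coercing it to None."""

    api_base: str
    api_key: str | None = None
    requires_auth: bool = True  # hypothetical flag from the review comment

    @model_validator(mode="after")
    def _check_api_key(self) -> "StrictCompatibleParams":
        if self.requires_auth and not (self.api_key or "").strip():
            raise ValueError("api_key is required when requires_auth is true")
        return self
```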

    },
"model": {
"type": "string",
"title": "Model",
"default": "gpt-4o-mini",
"description": "The model name expected by your OpenAI-compatible endpoint. Examples: gateway-model, gpt-4o-mini, openai/gpt-4o"
},
"api_base": {
"type": "string",
"format": "url",
"title": "API Base",
"description": "Base URL for the OpenAI-compatible endpoint. Examples: https://gateway.example.com/v1, https://llm.example.net/openai/v1"
},
"max_tokens": {
"type": "number",
"minimum": 0,
"multipleOf": 1,
"title": "Maximum Output Tokens",
"default": 4096,
"description": "Maximum number of output tokens to limit LLM replies. Leave it empty to use the provider default."
},
"max_retries": {
"type": "number",
"minimum": 0,
"multipleOf": 1,
"title": "Max Retries",
"default": 5,
"description": "The maximum number of times to retry a request if it fails."
},
"timeout": {
"type": "number",
"minimum": 0,
"multipleOf": 1,
"title": "Timeout",
"default": 900,
"description": "Timeout in seconds."
}
}
}
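Read together with the validate step in base1.py, a configuration that satisfies this schema reaches the adapter looking roughly like the dict below; all values are placeholders.

```python
# Placeholder payload matching the schema's required fields and defaults.
adapter_metadata = {
    "adapter_name": "compatible-gateway-1",
    "api_base": "https://gateway.example.com/v1",
    "api_key": None,  # blank or whitespace-only keys normalize to None
    "model": "gateway-model",  # prefixed to "custom_openai/gateway-model"
    "max_tokens": 4096,
    "max_retries": 5,
    "timeout": 900,
}
```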
16 changes: 14 additions & 2 deletions unstract/sdk1/src/unstract/sdk1/llm.py
@@ -539,10 +539,22 @@ def _record_usage(
         usage: Mapping[str, int] | None,
         llm_api: str,
     ) -> None:
-        prompt_tokens = token_counter(model=model, messages=messages)
+        usage_data: Mapping[str, int] = usage or {}
+        prompt_tokens = usage_data.get("prompt_tokens")
+        if prompt_tokens is None:
+            try:
+                prompt_tokens = token_counter(model=model, messages=messages)
+            except Exception as e:
+                prompt_tokens = 0
+                logger.warning(
+                    "[sdk1][LLM][%s][%s] Failed to estimate prompt tokens; "
+                    "recording 0 prompt tokens for usage audit: %s",
+                    model,
+                    llm_api,
+                    e,
+                )
Comment on lines +543 to +555
Contributor:
@pk-zipstack @johnyrahul is this a safe change?

Author:
Kept this scoped to usage accounting only. It still uses provider-reported prompt tokens when available, only estimates when they are missing, and the fallback paths are covered by tests now.

         all_tokens = TokenCounterCompat(
-            prompt_tokens=usage_data.get("prompt_tokens", 0),
+            prompt_tokens=prompt_tokens or 0,
             completion_tokens=usage_data.get("completion_tokens", 0),
             total_tokens=usage_data.get("total_tokens", 0),
         )
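One detail worth noting in the constructor change: `prompt_tokens or 0` coerces both None and 0 to 0, so the provider-reported, estimated, and failed-estimation paths all feed the same argument. A trivial illustration:

```python
# "or 0" maps both None (field absent) and 0 (failed estimation) to 0.
for reported in (3, 0, None):
    print(reported or 0)  # prints 3, then 0, then 0
```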
Comment on lines +543 to 560
Contributor:
⚠️ Potential issue | 🟠 Major

Silent zero-token recording risks corrupting billing/usage audit data.

When token_counter raises (e.g., unmapped custom models in LiteLLM's metadata), the code records prompt_tokens=0 into Audit().push_usage_data. Per unstract/sdk1/src/unstract/sdk1/utils/common.py:114-145 and unstract/sdk1/src/unstract/sdk1/audit.py:85-98, that zero flows directly to the platform's usage record with no sentinel/flag distinguishing "unknown" from "actually zero." For long-running workloads against an OpenAI-compatible endpoint that doesn't return usage.prompt_tokens, this could silently understate prompt-token consumption in cost attribution and analytics.

Consider one of:

  • Tagging the audit payload with an estimation_failed / prompt_tokens_source flag so downstream consumers can distinguish missing data from genuinely zero usage.
  • Narrowing the except (e.g., except (KeyError, ValueError, litellm.exceptions.*)) so truly unexpected errors still propagate instead of being swallowed.
  • Emitting a metric/counter when this fallback triggers so ops can detect silent drift.

A warning log alone is easy to miss in aggregated usage reports. This answers the question raised in the prior review thread on this range.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@unstract/sdk1/src/unstract/sdk1/llm.py` around lines 543-560, the current
catch in the prompt token estimation around token_counter (used when building
TokenCounterCompat) silently sets prompt_tokens=0; update this to (1) narrow the
except to only expected errors from the estimator (e.g., KeyError/ValueError and
litellm-specific exceptions raised by token_counter) so unexpected errors still
propagate, and (2) add a sentinel field to the usage payload (e.g.,
prompt_tokens_source or estimation_failed) before calling
Audit().push_usage_data to mark that prompt tokens were estimated/failed, and/or
increment an ops metric/counter when the fallback path occurs; reference the
token_counter call, TokenCounterCompat construction, Audit().push_usage_data,
and the existing logger to emit a clear warning and metric.
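An untested sketch of the provenance-tagging idea from the comment above; `resolve_prompt_tokens`, the source strings, and the narrowed exception tuple are illustrative, not part of this PR.

```python
from collections.abc import Callable, Mapping


def resolve_prompt_tokens(
    usage: Mapping[str, int] | None,
    estimate: Callable[..., int],  # stand-in for litellm's token_counter
    model: str,
    messages: list[dict[str, str]],
) -> tuple[int, str]:
    """Return (prompt_tokens, source), where source is a hypothetical
    provenance tag ('provider', 'estimated', or 'unavailable') so audit
    consumers can tell a missing count apart from a genuine zero."""
    usage_data = usage or {}
    tokens = usage_data.get("prompt_tokens")
    if tokens is not None:
        return tokens, "provider"
    try:
        return estimate(model=model, messages=messages), "estimated"
    except (KeyError, ValueError):  # narrowed from a bare Exception
        return 0, "unavailable"
```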

192 changes: 192 additions & 0 deletions unstract/sdk1/tests/test_openai_compatible_adapter.py
@@ -0,0 +1,192 @@
import json
from functools import lru_cache
from importlib import import_module
from unittest.mock import MagicMock, patch

from unstract.sdk1.adapters.base1 import OpenAICompatibleLLMParameters
from unstract.sdk1.adapters.constants import Common
from unstract.sdk1.adapters.llm1 import adapters
from unstract.sdk1.adapters.llm1.openai_compatible import OpenAICompatibleLLMAdapter

OPENAI_COMPATIBLE_DESCRIPTION = (
    "Adapter for servers that implement the OpenAI Chat Completions API "
    "(vLLM, LM Studio, self-hosted gateways, and third-party providers). "
    "Use OpenAI for the official OpenAI service."
)


@lru_cache(maxsize=1)
def _load_llm_module() -> object:
    import sys
    from types import ModuleType

    with patch.dict(
        sys.modules,
        {
            # Stub python-magic so importing LLM does not depend on libmagic
            # being available in the test environment.
            "magic": ModuleType("magic")
        },
    ):
        return import_module("unstract.sdk1.llm")


def _load_llm_class() -> type:
    return _load_llm_module().LLM


def test_openai_compatible_adapter_is_registered() -> None:
    adapter_id = OpenAICompatibleLLMAdapter.get_id()

    assert adapter_id in adapters
    assert adapters[adapter_id][Common.MODULE] is OpenAICompatibleLLMAdapter


def test_openai_compatible_validate_prefixes_model() -> None:
    validated = OpenAICompatibleLLMParameters.validate(
        {
            "api_base": "https://gateway.example.com/v1",
            "api_key": "test-key",
            "model": "gateway-model",
        }
    )

    assert validated["model"] == "custom_openai/gateway-model"


def test_openai_compatible_validate_preserves_prefixed_model() -> None:
    validated = OpenAICompatibleLLMParameters.validate(
        {
            "api_base": "https://gateway.example.com/v1",
            "model": "custom_openai/openai/gpt-4o",
        }
    )

    assert validated["model"] == "custom_openai/openai/gpt-4o"
    assert validated["api_key"] is None


def test_openai_compatible_validate_normalizes_blank_api_key_to_none() -> None:
    validated = OpenAICompatibleLLMParameters.validate(
        {
            "api_base": "https://gateway.example.com/v1",
            "api_key": " ",
            "model": "gateway-model",
        }
    )

    assert validated["api_key"] is None


def test_openai_compatible_schema_is_loadable() -> None:
    schema = json.loads(OpenAICompatibleLLMAdapter.get_json_schema())

    assert schema["title"] == "OpenAI Compatible LLM"
    assert schema["properties"]["api_key"]["type"] == ["string", "null"]
    assert "gateway-model" in schema["properties"]["model"]["description"]
    assert "ERNIE" not in schema["properties"]["model"]["description"]
    assert "qianfan" not in schema["properties"]["api_base"]["description"].lower()
    assert "default" not in schema["properties"]["api_base"]


def test_openai_compatible_adapter_uses_distinct_description_and_icon() -> None:
    metadata = OpenAICompatibleLLMAdapter.get_metadata()

    assert OpenAICompatibleLLMAdapter.get_description() == OPENAI_COMPATIBLE_DESCRIPTION
    assert metadata["description"] == OPENAI_COMPATIBLE_DESCRIPTION
    assert OpenAICompatibleLLMAdapter.get_icon() == (
        "/icons/adapter-icons/OpenAICompatible.png"
    )


def test_record_usage_uses_reported_prompt_tokens_without_estimating() -> None:
    llm_module = _load_llm_module()
    llm_cls = llm_module.LLM

    llm = llm_cls.__new__(llm_cls)
    llm._platform_api_key = "platform-key"
    llm.platform_kwargs = {"run_id": "run-1"}
    llm.adapter = MagicMock()
    llm.adapter.get_provider.return_value = "custom_openai"

    with (
        patch.object(llm_module, "token_counter") as mock_token_counter,
        patch.object(llm_module, "Audit") as mock_audit,
    ):
        llm._record_usage(
            model="custom_openai/gateway-model",
            messages=[{"role": "user", "content": "hello"}],
            usage={"prompt_tokens": 3, "completion_tokens": 4, "total_tokens": 7},
            llm_api="complete",
        )

    mock_token_counter.assert_not_called()
    mock_audit.return_value.push_usage_data.assert_called_once()
    assert (
        mock_audit.return_value.push_usage_data.call_args.kwargs[
            "token_counter"
        ].prompt_llm_token_count
        == 3
    )


def test_record_usage_tolerates_unmapped_models_without_prompt_tokens() -> None:
    llm_module = _load_llm_module()
    llm_cls = llm_module.LLM

    llm = llm_cls.__new__(llm_cls)
    llm._platform_api_key = "platform-key"
    llm.platform_kwargs = {"run_id": "run-1"}
    llm.adapter = MagicMock()
    llm.adapter.get_provider.return_value = "custom_openai"

    with (
        patch.object(llm_module, "token_counter", side_effect=Exception("unmapped")),
        patch.object(llm_module, "Audit") as mock_audit,
        patch.object(llm_module.logger, "warning") as mock_warning,
    ):
        llm._record_usage(
            model="custom_openai/gateway-model",
            messages=[{"role": "user", "content": "hello"}],
            usage={"completion_tokens": 4, "total_tokens": 7},
            llm_api="complete",
        )

    mock_audit.return_value.push_usage_data.assert_called_once()
    assert (
        mock_audit.return_value.push_usage_data.call_args.kwargs[
            "token_counter"
        ].prompt_llm_token_count
        == 0
    )
    assert "recording 0 prompt tokens for usage audit" in mock_warning.call_args.args[0]


def test_record_usage_uses_estimated_prompt_tokens_when_usage_has_none() -> None:
    llm_module = _load_llm_module()
    llm_cls = llm_module.LLM

    llm = llm_cls.__new__(llm_cls)
    llm._platform_api_key = "platform-key"
    llm.platform_kwargs = {"run_id": "run-1"}
    llm.adapter = MagicMock()
    llm.adapter.get_provider.return_value = "custom_openai"

    with (
        patch.object(llm_module, "token_counter", return_value=9) as mock_token_counter,
        patch.object(llm_module, "Audit") as mock_audit,
    ):
        llm._record_usage(
            model="custom_openai/gateway-model",
            messages=[{"role": "user", "content": "hello"}],
            usage={"prompt_tokens": None, "completion_tokens": 4, "total_tokens": 13},
            llm_api="complete",
        )

    mock_token_counter.assert_called_once()
    assert (
        mock_audit.return_value.push_usage_data.call_args.kwargs[
            "token_counter"
        ].prompt_llm_token_count
        == 9
    )