diff --git a/backend/prompt_studio/prompt_studio_core_v2/prompt_studio_helper.py b/backend/prompt_studio/prompt_studio_core_v2/prompt_studio_helper.py index ad33e0ed73..7dfa870b45 100644 --- a/backend/prompt_studio/prompt_studio_core_v2/prompt_studio_helper.py +++ b/backend/prompt_studio/prompt_studio_core_v2/prompt_studio_helper.py @@ -18,7 +18,6 @@ from utils.file_storage.constants import FileStorageKeys from utils.file_storage.helpers.prompt_studio_file_helper import PromptStudioFileHelper from utils.local_context import StateStore -from utils.subscription_usage_decorator import track_subscription_usage_if_available from backend.celery_service import app as celery_app from prompt_studio.prompt_profile_manager_v2.models import ProfileManager @@ -1235,7 +1234,6 @@ def fetch_prompt_from_tool(tool_id: str) -> list[ToolStudioPrompt]: return prompt_instances @staticmethod - @track_subscription_usage_if_available(file_execution_id_param="run_id") def index_document( tool_id: str, file_name: str, @@ -1424,7 +1422,6 @@ def summarize(file_name, org_id, run_id, tool) -> str: return summarize_file_path @staticmethod - @track_subscription_usage_if_available(file_execution_id_param="run_id") def prompt_responder( tool_id: str, org_id: str, diff --git a/backend/prompt_studio/prompt_studio_core_v2/static/select_choices.json b/backend/prompt_studio/prompt_studio_core_v2/static/select_choices.json index 1ba7fa8e44..3d59b2bf08 100644 --- a/backend/prompt_studio/prompt_studio_core_v2/static/select_choices.json +++ b/backend/prompt_studio/prompt_studio_core_v2/static/select_choices.json @@ -14,7 +14,8 @@ "date":"date", "boolean":"boolean", "json":"json", - "table":"table" + "table":"table", + "line-item":"line-item" }, "output_processing":{ "DEFAULT":"Default" diff --git a/backend/usage_v2/helper.py b/backend/usage_v2/helper.py index c11949356e..7fba65e937 100644 --- a/backend/usage_v2/helper.py +++ b/backend/usage_v2/helper.py @@ -133,6 +133,25 @@ def get_usage_by_model(run_id: 
str) -> dict[str, list[dict[str, Any]]]: for row in rows: usage_type = row["usage_type"] llm_reason = row["llm_usage_reason"] + + # Drop unlabeled LLM rows entirely. Per the Usage model + # contract (see usage_v2/models.py: llm_usage_reason + # db_comment), an empty reason is only valid when + # usage_type == "embedding". An empty reason combined with + # usage_type == "llm" is a producer-side bug (an LLM call + # site forgot to pass llm_usage_reason in usage_kwargs). + # Without this guard the row would surface in API + # deployment responses as a malformed bare "llm" bucket + # with no token breakdown. + if usage_type == "llm" and not llm_reason: + logger.warning( + "Dropping unlabeled LLM usage row from per-model " + "breakdown: model_name=%s run_id=%s", + row["model_name"], + run_id, + ) + continue + cost_str = UsageHelper._format_float_positional(row["sum_cost"] or 0.0) key = usage_type diff --git a/backend/usage_v2/tests/__init__.py b/backend/usage_v2/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/usage_v2/tests/test_helper.py b/backend/usage_v2/tests/test_helper.py new file mode 100644 index 0000000000..4f6ffc9c40 --- /dev/null +++ b/backend/usage_v2/tests/test_helper.py @@ -0,0 +1,251 @@ +"""Regression tests for ``UsageHelper.get_usage_by_model``. + +These tests cover the defensive filter that drops unlabeled LLM rows +from the per-model usage breakdown. The filter prevents a malformed +bare ``"llm"`` bucket from leaking into API deployment responses when +a producer-side LLM call site forgets to set ``llm_usage_reason``. + +The tests deliberately do not require a live Django database — the +backend test environment has no ``pytest-django``, no SQLite fallback, +and uses ``django-tenants`` against Postgres in production. 
Instead +the tests stub ``account_usage.models`` and ``usage_v2.models`` in +``sys.modules`` *before* importing the helper, so the helper module +loads cleanly without triggering Django's app registry checks. The +fake ``Usage.objects.filter`` chain returns a deterministic list of +row dicts shaped exactly like the real ``.values(...).annotate(...)`` +queryset rows the helper iterates over. +""" + +from __future__ import annotations + +import sys +import types +from typing import Any +from unittest.mock import MagicMock + + +# --------------------------------------------------------------------------- +# Module-level stubs. Must run BEFORE ``usage_v2.helper`` is imported, so we +# do it at import time and capture the helper reference for the tests below. +# --------------------------------------------------------------------------- + + +def _install_stubs() -> tuple[Any, Any]: + """Install fake ``account_usage.models`` and ``usage_v2.models`` modules + so that ``usage_v2.helper`` can be imported without Django being set up. + + Returns ``(UsageHelper, FakeUsage)`` — the helper class to test and the + fake Usage class whose ``objects.filter`` we will swap per-test. + """ + # Fake account_usage package + models module + if "account_usage" not in sys.modules: + account_usage_pkg = types.ModuleType("account_usage") + account_usage_pkg.__path__ = [] # mark as package + sys.modules["account_usage"] = account_usage_pkg + if "account_usage.models" not in sys.modules: + account_usage_models = types.ModuleType("account_usage.models") + account_usage_models.PageUsage = MagicMock(name="PageUsage") + sys.modules["account_usage.models"] = account_usage_models + + # Fake usage_v2.models with a Usage class whose ``objects`` is a + # MagicMock (so each test can rebind ``filter.return_value``). 
+ if "usage_v2.models" not in sys.modules or not hasattr( + sys.modules["usage_v2.models"], "_is_test_stub" + ): + usage_v2_models = types.ModuleType("usage_v2.models") + usage_v2_models._is_test_stub = True + + class _FakeUsage: + objects = MagicMock(name="Usage.objects") + + usage_v2_models.Usage = _FakeUsage + sys.modules["usage_v2.models"] = usage_v2_models + + # Now import the helper — this picks up our stubs. + from usage_v2.helper import UsageHelper + + return UsageHelper, sys.modules["usage_v2.models"].Usage + + +UsageHelper, FakeUsage = _install_stubs() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class _StubQueryset: + """Mimic the chain ``.filter(...).values(...).annotate(...)``.""" + + def __init__(self, rows: list[dict[str, Any]]) -> None: + self._rows = rows + + def values(self, *args: Any, **kwargs: Any) -> _StubQueryset: + return self + + def annotate(self, *args: Any, **kwargs: Any) -> list[dict[str, Any]]: + return self._rows + + +def _row( + *, + usage_type: str, + llm_reason: str, + model_name: str = "gpt-4o", + sum_input: int = 0, + sum_output: int = 0, + sum_total: int = 0, + sum_embedding: int = 0, + sum_cost: float = 0.0, +) -> dict[str, Any]: + """Build a row matching the shape returned by the helper's + ``.values(...).annotate(...)`` queryset. + """ + return { + "usage_type": usage_type, + "llm_usage_reason": llm_reason, + "model_name": model_name, + "sum_input_tokens": sum_input, + "sum_output_tokens": sum_output, + "sum_total_tokens": sum_total, + "sum_embedding_tokens": sum_embedding, + "sum_cost": sum_cost, + } + + +def _stub_rows(rows: list[dict[str, Any]]) -> None: + """Make ``Usage.objects.filter(...).values(...).annotate(...)`` yield + the given rows when the helper is invoked next. 
+ """ + FakeUsage.objects.filter.return_value = _StubQueryset(rows) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_unlabeled_llm_row_is_dropped() -> None: + """An ``llm`` row with empty ``llm_usage_reason`` must not produce a + bare ``"llm"`` bucket in the response — it should be silently + dropped, while the legitimate extraction row is preserved. + """ + _stub_rows( + [ + _row( + usage_type="llm", + llm_reason="extraction", + sum_input=100, + sum_output=50, + sum_total=150, + sum_cost=0.05, + ), + _row( + usage_type="llm", + llm_reason="", # the bug — no reason set + sum_cost=0.01, + ), + ] + ) + + result = UsageHelper.get_usage_by_model("00000000-0000-0000-0000-000000000001") + + assert "llm" not in result, ( + "Unlabeled llm row should be dropped — bare 'llm' bucket leaked" + ) + assert "extraction_llm" in result + assert len(result["extraction_llm"]) == 1 + entry = result["extraction_llm"][0] + assert entry["model_name"] == "gpt-4o" + assert entry["input_tokens"] == 100 + assert entry["output_tokens"] == 50 + assert entry["total_tokens"] == 150 + assert entry["cost_in_dollars"] == "0.05" + + +def test_embedding_row_is_preserved() -> None: + """An ``embedding`` row legitimately has empty ``llm_usage_reason``; + the defensive filter must NOT drop it. Proves the guard is scoped + to ``usage_type == 'llm'``. 
+ """ + _stub_rows( + [ + _row( + usage_type="embedding", + llm_reason="", + model_name="text-embedding-3-small", + sum_embedding=200, + sum_cost=0.001, + ), + ] + ) + + result = UsageHelper.get_usage_by_model("00000000-0000-0000-0000-000000000002") + + assert "embedding" in result, "Embedding row was incorrectly dropped" + assert len(result["embedding"]) == 1 + entry = result["embedding"][0] + assert entry["model_name"] == "text-embedding-3-small" + assert entry["embedding_tokens"] == 200 + assert entry["cost_in_dollars"] == "0.001" + + +def test_all_three_llm_reasons_coexist() -> None: + """All three labelled LLM buckets (extraction, challenge, summarize) + must appear with correct token counts when present. + """ + _stub_rows( + [ + _row( + usage_type="llm", + llm_reason="extraction", + model_name="gpt-4o", + sum_input=100, + sum_output=50, + sum_total=150, + sum_cost=0.05, + ), + _row( + usage_type="llm", + llm_reason="challenge", + model_name="gpt-4o-mini", + sum_input=20, + sum_output=10, + sum_total=30, + sum_cost=0.002, + ), + _row( + usage_type="llm", + llm_reason="summarize", + model_name="gpt-4o", + sum_input=300, + sum_output=80, + sum_total=380, + sum_cost=0.07, + ), + ] + ) + + result = UsageHelper.get_usage_by_model("00000000-0000-0000-0000-000000000003") + + assert set(result.keys()) == {"extraction_llm", "challenge_llm", "summarize_llm"} + assert "llm" not in result + + extraction = result["extraction_llm"][0] + assert extraction["model_name"] == "gpt-4o" + assert extraction["input_tokens"] == 100 + assert extraction["output_tokens"] == 50 + assert extraction["total_tokens"] == 150 + + challenge = result["challenge_llm"][0] + assert challenge["model_name"] == "gpt-4o-mini" + assert challenge["input_tokens"] == 20 + assert challenge["output_tokens"] == 10 + assert challenge["total_tokens"] == 30 + + summarize = result["summarize_llm"][0] + assert summarize["model_name"] == "gpt-4o" + assert summarize["input_tokens"] == 300 + assert 
summarize["output_tokens"] == 80 + assert summarize["total_tokens"] == 380 diff --git a/frontend/src/components/custom-tools/prompt-card/DisplayPromptResult.jsx b/frontend/src/components/custom-tools/prompt-card/DisplayPromptResult.jsx index 4933eb40dd..f5f9474024 100644 --- a/frontend/src/components/custom-tools/prompt-card/DisplayPromptResult.jsx +++ b/frontend/src/components/custom-tools/prompt-card/DisplayPromptResult.jsx @@ -84,7 +84,7 @@ function DisplayPromptResult({ ); } - if (output === undefined || output === null) { + if (output === undefined) { return ( @@ -95,6 +95,12 @@ function DisplayPromptResult({ ); } + if (output === null) { + return ( + null + ); + } + // Extract confidence from 5th element of highlight data coordinate arrays const extractConfidenceFromHighlightData = (data) => { if (!data) { diff --git a/frontend/src/components/custom-tools/prompt-card/DisplayPromptResult.test.jsx b/frontend/src/components/custom-tools/prompt-card/DisplayPromptResult.test.jsx index 71c32644e7..3fea8249bd 100644 --- a/frontend/src/components/custom-tools/prompt-card/DisplayPromptResult.test.jsx +++ b/frontend/src/components/custom-tools/prompt-card/DisplayPromptResult.test.jsx @@ -40,9 +40,10 @@ const baseProps = { }; describe("DisplayPromptResult null/undefined guard", () => { - it("shows 'Yet to run' when output is null", () => { + it("shows 'null' literal when output is null (ran but produced no value)", () => { render(); - expect(screen.getByText(/Yet to run/)).toBeInTheDocument(); + expect(screen.getByText("null")).toBeInTheDocument(); + expect(screen.queryByText(/Yet to run/)).not.toBeInTheDocument(); }); it("shows 'Yet to run' when output is undefined", () => { diff --git a/frontend/src/components/input-output/configure-ds/ConfigureDs.jsx b/frontend/src/components/input-output/configure-ds/ConfigureDs.jsx index 74b304d6bf..7408f17461 100644 --- a/frontend/src/components/input-output/configure-ds/ConfigureDs.jsx +++ 
b/frontend/src/components/input-output/configure-ds/ConfigureDs.jsx @@ -279,8 +279,8 @@ function ConfigureDs({ url = getUrl("connector/"); - const eventKey = `${type.toUpperCase()}`; - if (posthogConnectorAddedEventText[eventKey]) { + const eventKey = type?.toUpperCase(); + if (eventKey && posthogConnectorAddedEventText[eventKey]) { setPostHogCustomEvent(posthogConnectorAddedEventText[eventKey], { info: `Clicked on 'Submit' button`, connector_name: selectedSourceName, diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/x2text/helper.py b/unstract/sdk1/src/unstract/sdk1/adapters/x2text/helper.py index 0f4f2f196b..0b9376e95d 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/x2text/helper.py +++ b/unstract/sdk1/src/unstract/sdk1/adapters/x2text/helper.py @@ -1,4 +1,6 @@ import logging +import os +from io import BytesIO from typing import Any import requests @@ -72,17 +74,15 @@ def process_document( fs = FileStorage(provider=FileStorageProvider.LOCAL) try: response: Response - local_storage = FileStorage(FileStorageProvider.LOCAL) - if not local_storage.exists(input_file_path): - fs.download(from_path=input_file_path, to_path=input_file_path) - with open(input_file_path, "rb") as input_f: - mime_type = local_storage.mime_type(path=input_file_path) - files = {"file": (input_file_path, input_f, mime_type)} - response = UnstructuredHelper.make_request( - unstructured_adapter_config=unstructured_adapter_config, - request_type=UnstructuredHelper.PROCESS, - files=files, - ) + file_bytes = fs.read(path=input_file_path, mode="rb") + mime_type = fs.mime_type(path=input_file_path) + file_name = os.path.basename(input_file_path) + files = {"file": (file_name, BytesIO(file_bytes), mime_type)} + response = UnstructuredHelper.make_request( + unstructured_adapter_config=unstructured_adapter_config, + request_type=UnstructuredHelper.PROCESS, + files=files, + ) output, is_success = X2TextHelper.parse_response( response=response, out_file_path=output_file_path, fs=fs ) diff --git 
a/workers/executor/executors/legacy_executor.py b/workers/executor/executors/legacy_executor.py index cf33c43212..5b53599880 100644 --- a/workers/executor/executors/legacy_executor.py +++ b/workers/executor/executors/legacy_executor.py @@ -562,6 +562,7 @@ def _handle_structure_pipeline(self, context: ExecutionContext) -> ExecutionResu index_template=index_template, answer_params=answer_params, extracted_text=extracted_text, + usage_kwargs=extract_params.get("usage_kwargs", {}), ) # ---- Step 4: Table settings injection ---- @@ -574,27 +575,13 @@ def _handle_structure_pipeline(self, context: ExecutionContext) -> ExecutionResu ) # ---- Step 5: Answer prompt / Single pass ---- - mode_label = "single pass" if is_single_pass else "prompt" - shim.stream_log(f"Pipeline step {step}: Running {mode_label} execution...") - operation = ( - Operation.SINGLE_PASS_EXTRACTION.value - if is_single_pass - else Operation.ANSWER_PROMPT.value - ) - answer_ctx = ExecutionContext( - executor_name=context.executor_name, - operation=operation, - run_id=context.run_id, - execution_source=context.execution_source, - organization_id=context.organization_id, - executor_params=answer_params, - request_id=context.request_id, - log_events_id=context.log_events_id, + answer_result = self._run_pipeline_answer_step( + context=context, + answer_params=answer_params, + is_single_pass=is_single_pass, + shim=shim, + step=step, ) - if is_single_pass: - answer_result = self._handle_single_pass_extraction(answer_ctx) - else: - answer_result = self._handle_answer_prompt(answer_ctx) if not answer_result.success: return answer_result @@ -610,6 +597,48 @@ def _handle_structure_pipeline(self, context: ExecutionContext) -> ExecutionResu shim.stream_log("Pipeline completed successfully") return ExecutionResult(success=True, data=structured_output) + def _run_pipeline_answer_step( + self, + context: ExecutionContext, + answer_params: dict, + is_single_pass: bool, + shim: ExecutorToolShim, + step: int, + ) -> 
ExecutionResult: + """Run the answer-prompt step of the structure pipeline. + + For single pass, forces ``chunk-size=0`` (full-context retrieval) + and dispatches ``_handle_single_pass_extraction``. Otherwise + dispatches ``_handle_answer_prompt``. + """ + if is_single_pass: + # Single pass reads the whole file in one LLM call; force + # chunk-size=0 so the fallback path (no cloud plugin) uses + # retrieve_complete_context instead of vector DB retrieval. + for output in answer_params.get("outputs", []): + output["chunk-size"] = 0 + output["chunk-overlap"] = 0 + operation = Operation.SINGLE_PASS_EXTRACTION.value + mode_label = "single pass" + else: + operation = Operation.ANSWER_PROMPT.value + mode_label = "prompt" + + shim.stream_log(f"Pipeline step {step}: Running {mode_label} execution...") + answer_ctx = ExecutionContext( + executor_name=context.executor_name, + operation=operation, + run_id=context.run_id, + execution_source=context.execution_source, + organization_id=context.organization_id, + executor_params=answer_params, + request_id=context.request_id, + log_events_id=context.log_events_id, + ) + if is_single_pass: + return self._handle_single_pass_extraction(answer_ctx) + return self._handle_answer_prompt(answer_ctx) + @staticmethod def _inject_table_settings( answer_params: dict, @@ -739,9 +768,18 @@ def _run_pipeline_index( index_template: dict, answer_params: dict, extracted_text: str, + usage_kwargs: dict | None = None, ) -> dict: """Run per-output indexing with dedup for the structure pipeline. + Args: + usage_kwargs: Audit-tracking kwargs (``run_id``, + ``execution_id``, ``file_name``) propagated to the + embedding adapter so its callback can record usage + rows against the correct file_execution_id. Without + this, embedding usage is missing from the API + deployment response metadata. + Returns: Dict of index metrics keyed by output name. 
""" @@ -754,6 +792,7 @@ def _run_pipeline_index( is_highlight = index_template.get("is_highlight_enabled", False) platform_api_key = index_template.get("platform_api_key", "") extracted_file_path = index_template.get("extracted_file_path", "") + usage_kwargs = usage_kwargs or {} index_metrics: dict = {} seen_params: set = set() @@ -805,6 +844,7 @@ def _run_pipeline_index( "enable_highlight": is_highlight, "extracted_text": extracted_text, "platform_api_key": platform_api_key, + "usage_kwargs": usage_kwargs, }, ) index_result = self._handle_index(index_ctx) @@ -1050,18 +1090,25 @@ def _get_prompt_deps(): def _sanitize_dict_values(d: dict[str, Any]) -> None: """Replace 'NA' string values with None inside a dict in-place.""" for k, v in d.items(): - if isinstance(v, str) and v.lower() == "na": + if isinstance(v, str) and v.strip().lower() == "na": d[k] = None @staticmethod def _sanitize_null_values( structured_output: dict[str, Any], ) -> dict[str, Any]: - """Replace 'NA' strings with None in structured output.""" + """Replace 'NA' strings with None in structured output. + + Top-level scalar 'NA' / 'na' strings are converted to None so + the FE can render them as a distinct null value (rather than the + literal string 'NA'). Nested lists and dicts are walked too. 
+ """ for k, v in structured_output.items(): - if isinstance(v, list): + if isinstance(v, str) and v.strip().lower() == "na": + structured_output[k] = None + elif isinstance(v, list): for i, item in enumerate(v): - if isinstance(item, str) and item.lower() == "na": + if isinstance(item, str) and item.strip().lower() == "na": v[i] = None elif isinstance(item, dict): LegacyExecutor._sanitize_dict_values(item) @@ -1264,10 +1311,18 @@ def _convert_number_answer(answer: str, llm: Any, answer_prompt_svc: Any) -> Any def _convert_scalar_answer( answer: str, llm: Any, answer_prompt_svc: Any, prompt: str ) -> str | None: - """Run LLM extraction for a scalar (email/date) and return result or None.""" - if answer.lower() == "na": + """Run LLM extraction for a scalar (email/date) and return result or None. + + Returns None when: + - the initial answer is already 'NA' (no second LLM call needed); or + - the second LLM call also returns 'NA' (extraction failed). + """ + if answer.strip().lower() == "na": return None - return answer_prompt_svc.run_completion(llm=llm, prompt=prompt) + result = answer_prompt_svc.run_completion(llm=llm, prompt=prompt) + if result is None or result.strip().lower() == "na": + return None + return result def _run_challenge_if_enabled( self, @@ -1458,7 +1513,26 @@ def _execute_single_prompt( return if output.get(PSKeys.TYPE) == PSKeys.LINE_ITEM: - raise LegacyExecutorError(message="LINE_ITEM extraction is not supported.") + self._run_line_item_extraction( + output=output, + context=context, + structured_output=structured_output, + metadata=metadata, + metrics=metrics, + prompt_run_args={ + "run_id": run_id, + "execution_id": execution_id, + "execution_source": execution_source, + "platform_api_key": platform_api_key, + "tool_id": tool_id, + "doc_name": doc_name, + "prompt_name": prompt_name, + "file_path": file_path, + "tool_settings": tool_settings, + }, + shim=shim, + ) + return usage_kwargs = {"run_id": run_id, "execution_id": execution_id} try: @@ 
-1655,14 +1729,100 @@ def _run_table_extraction( ) shim.stream_log(f"Table extraction completed for: {prompt_name}") logger.info("TABLE extraction completed: prompt=%s", prompt_name) + shim.stream_log(f"Completed prompt: {prompt_name}") else: structured_output[prompt_name] = "" + error_msg = table_result.error or "unknown error" logger.error( "TABLE extraction failed for prompt=%s: %s", prompt_name, - table_result.error, + error_msg, + ) + shim.stream_log( + f"Table extraction failed for {prompt_name}: {error_msg}", + level=LogLevel.ERROR, + ) + + def _run_line_item_extraction( + self, + output: dict[str, Any], + context: ExecutionContext, + structured_output: dict[str, Any], + metadata: dict[str, Any], + metrics: dict[str, Any], + prompt_run_args: dict[str, Any], + shim: Any, + ) -> None: + """Delegate LINE_ITEM prompt to the line_item executor plugin. + + ``prompt_run_args`` bundles the per-prompt scalars passed from + ``_handle_outputs``: ``run_id``, ``execution_id``, + ``execution_source``, ``platform_api_key``, ``tool_id``, + ``doc_name``, ``prompt_name``, ``file_path``, ``tool_settings``. + """ + from executor.executors.constants import PromptServiceConstants as PSKeys + + prompt_name = prompt_run_args["prompt_name"] + try: + line_item_executor = ExecutorRegistry.get("line_item") + except KeyError as e: + raise LegacyExecutorError( + message=( + "LINE_ITEM extraction requires the line_item executor " + "plugin. Install the line_item_extractor plugin." 
+ ) + ) from e + line_item_ctx = ExecutionContext( + executor_name="line_item", + operation="line_item_extract", + run_id=prompt_run_args["run_id"], + execution_source=prompt_run_args["execution_source"], + organization_id=context.organization_id, + request_id=context.request_id, + executor_params={ + "llm_adapter_instance_id": output.get(PSKeys.LLM, ""), + "tool_settings": prompt_run_args["tool_settings"], + "output": output, + "prompt": output.get(PSKeys.PROMPTX, ""), + "file_path": prompt_run_args["file_path"], + "PLATFORM_SERVICE_API_KEY": prompt_run_args["platform_api_key"], + "execution_id": prompt_run_args["execution_id"], + "tool_id": prompt_run_args["tool_id"], + "file_name": prompt_run_args["doc_name"], + "prompt_name": prompt_name, + }, + ) + line_item_ctx._log_component = self._log_component + line_item_ctx.log_events_id = self._log_events_id + + shim.stream_log(f"Running line-item extraction for: {prompt_name}") + line_item_result = line_item_executor.execute(line_item_ctx) + + if line_item_result.success: + data = line_item_result.data or {} + structured_output[prompt_name] = data.get("output", "") + line_item_metrics = data.get("metadata", {}).get("metrics", {}) + metrics.setdefault(prompt_name, {}).update( + {"line_item_extraction": line_item_metrics} + ) + context_list = data.get("context") + if context_list: + metadata[PSKeys.CONTEXT][prompt_name] = context_list + shim.stream_log(f"Line-item extraction completed for: {prompt_name}") + logger.info("LINE_ITEM extraction completed: prompt=%s", prompt_name) + shim.stream_log(f"Completed prompt: {prompt_name}") + else: + structured_output[prompt_name] = "" + error_msg = line_item_result.error or "unknown error" + logger.error( + "LINE_ITEM extraction failed for prompt=%s: %s", + prompt_name, + error_msg, + ) + shim.stream_log( + f"Line-item extraction failed for {prompt_name}: {error_msg}", + level=LogLevel.ERROR, ) - shim.stream_log(f"Completed prompt: {prompt_name}") @staticmethod def 
_apply_type_conversion( @@ -1769,6 +1929,16 @@ def _handle_single_pass_extraction( ONE LLM call). Falls back to ``_handle_answer_prompt`` if the plugin is not installed. + Metrics contract: the cloud plugin is the source of the file + read for single-pass and is responsible for populating + ``context_retrieval`` in its returned ``result.data["metrics"]`` + using the same per-prompt shape that + ``RetrievalService.retrieve_complete_context`` produces, namely + ``{prompt_name: {"context_retrieval": {"time_taken(s)": float}}}``. + LegacyExecutor does NOT re-measure the read here — measuring at + the source is the only way the reported timing can match the + plugin's actual retrieval cost. + Returns: ExecutionResult with ``data`` containing:: diff --git a/workers/file_processing/structure_tool_task.py b/workers/file_processing/structure_tool_task.py index 75bfb29d95..f36e8afff0 100644 --- a/workers/file_processing/structure_tool_task.py +++ b/workers/file_processing/structure_tool_task.py @@ -407,20 +407,19 @@ def _execute_structure_tool_impl(params: dict) -> dict: # ---- Step 7: Write output files ---- # (metadata/metrics merging already done by executor pipeline) - try: - output_path = Path(output_dir_path) / f"{Path(source_file_name).stem}.json" - logger.info("Writing output to %s", output_path) - fs.json_dump(path=output_path, data=structured_output) - - # Overwrite INFILE with JSON output (matches Docker-based tool behavior). - # The destination connector reads from INFILE and checks MIME type — - # if we don't overwrite it, INFILE still has the original PDF. 
- logger.info("Overwriting INFILE with structured output: %s", input_file_path) - fs.json_dump(path=input_file_path, data=structured_output) - - logger.info("Output written successfully to workflow storage") - except (OSError, json.JSONDecodeError) as e: - return ExecutionResult.failure(error=f"Error writing output file: {e}").to_dict() + write_error = _write_pipeline_outputs( + fs=fs, + structured_output=structured_output, + output_dir_path=output_dir_path, + input_file_path=input_file_path, + execution_data_dir=execution_data_dir, + source_file_name=source_file_name, + label="structured", + ) + if write_error: + return ExecutionResult.failure( + error=f"Error writing output file: {write_error}" + ).to_dict() # Write tool result + tool_metadata to METADATA.json # (destination connector reads output_type from tool_metadata) @@ -607,17 +606,18 @@ def _run_agentic_extraction( elapsed = time.monotonic() - start_time # Write output files (matches regular pipeline path) - try: - output_path = Path(output_dir_path) / f"{Path(source_file_name).stem}.json" - logger.info("Writing agentic output to %s", output_path) - fs.json_dump(path=output_path, data=structured_output) - - # Overwrite INFILE with JSON output so destination connector reads JSON, not PDF - logger.info("Overwriting INFILE with agentic output: %s", input_file_path) - fs.json_dump(path=input_file_path, data=structured_output) - except Exception as e: + write_error = _write_pipeline_outputs( + fs=fs, + structured_output=structured_output, + output_dir_path=output_dir_path, + input_file_path=input_file_path, + execution_data_dir=execution_data_dir, + source_file_name=source_file_name, + label="agentic", + ) + if write_error: return ExecutionResult.failure( - error=f"Error writing agentic output: {e}" + error=f"Error writing agentic output: {write_error}" ).to_dict() # Write tool result + tool_metadata to METADATA.json @@ -626,6 +626,63 @@ def _run_agentic_extraction( return ExecutionResult(success=True, 
data=structured_output).to_dict() +def _write_pipeline_outputs( + fs: Any, + structured_output: dict, + output_dir_path: str, + input_file_path: str, + execution_data_dir: str, + source_file_name: str, + label: str, +) -> str | None: + """Write structure-tool / agentic outputs to disk. + + Mirrors the old Docker tool's output layout so the destination + connector finds what it expects: + + 1. ``{output_dir_path}/{stem}.json`` — primary output file. + 2. INFILE overwritten with JSON (destination connector reads INFILE + and checks MIME type — without this it still sees the original + PDF). + 3. ``{execution_data_dir}/COPY_TO_FOLDER/{stem}.json`` — what the + old ``ToolExecutor._setup_for_run()`` created for FS destinations. + + Args: + label: Short label for log lines (``"structured"`` or + ``"agentic"``). + + Returns: + ``None`` on success, or the error string on failure. + """ + try: + stem = Path(source_file_name).stem + output_path = Path(output_dir_path) / f"{stem}.json" + logger.info("Writing %s output to %s", label, output_path) + fs.json_dump(path=output_path, data=structured_output) + + logger.info("Overwriting INFILE with %s output: %s", label, input_file_path) + fs.json_dump(path=input_file_path, data=structured_output) + + copy_to_folder = str(Path(execution_data_dir) / "COPY_TO_FOLDER") + fs.mkdir(copy_to_folder) + copy_output_path = str(Path(copy_to_folder) / f"{stem}.json") + fs.json_dump(path=copy_output_path, data=structured_output) + logger.info( + "%s output written to COPY_TO_FOLDER: %s", + label.capitalize(), + copy_output_path, + ) + + logger.info("Overwriting INFILE with %s output: %s", label, input_file_path) + fs.json_dump(path=input_file_path, data=structured_output) + + logger.info("Output written successfully to workflow storage") + return None + except Exception as e: + logger.error("Failed to write %s output files: %s", label, e, exc_info=True) + return str(e) + + def _write_tool_result( fs: Any, execution_data_dir: str, _data: dict, 
elapsed_time: float = 0.0 ) -> None: @@ -636,6 +693,7 @@ def _write_tool_result( (destination connector reads output_type from here) - total_elapsed_time: cumulative elapsed time """ + metadata_path: Path | None = None try: metadata_path = Path(execution_data_dir) / "METADATA.json" @@ -671,4 +729,9 @@ def _write_tool_result( data=json.dumps(existing, indent=2), ) except Exception as e: - logger.warning("Failed to write tool result to METADATA.json: %s", e) + logger.error( + "Failed to write tool result to METADATA.json at '%s': %s", + metadata_path, + e, + exc_info=True, + ) diff --git a/workers/ide_callback/tasks.py b/workers/ide_callback/tasks.py index e352da830d..6c8452b014 100644 --- a/workers/ide_callback/tasks.py +++ b/workers/ide_callback/tasks.py @@ -104,6 +104,31 @@ def _get_task_error(failed_task_id: str, default: str) -> str: return default +def _track_subscription_usage(org_id: str, run_id: str) -> None: + """Commit deferred subscription usage for an IDE execution. + Non-blocking: errors are logged but do not fail the callback. 
+ """ + if not org_id or not run_id: + return + try: + from client_plugin_registry import get_client_plugin + + subscription_plugin = get_client_plugin("subscription_usage") + if not subscription_plugin: + return + result = subscription_plugin.commit_batch_subscription_usage( + organization_id=org_id, + file_execution_ids=[run_id], + ) + logger.info("IDE subscription usage committed for run_id=%s: %s", run_id, result) + except Exception: + logger.error( + "Failed to commit IDE subscription usage for run_id=%s (continuing callback)", + run_id, + exc_info=True, + ) + + # ------------------------------------------------------------------ # IDE Callback Tasks # @@ -132,6 +157,7 @@ def ide_index_complete( profile_manager_id = cb.get("profile_manager_id") executor_task_id = cb.get("executor_task_id", "") tool_id = cb.get("tool_id", "") + run_id = cb.get("run_id", "") api = _get_api_client() @@ -214,6 +240,8 @@ def ide_index_complete( document_id, ) + _track_subscription_usage(org_id, run_id) + result: dict[str, Any] = { "message": "Document indexed successfully.", "document_id": document_id, @@ -362,6 +390,8 @@ def ide_prompt_complete( ) response = resp.get("data", []) if resp.get("success") else [] + _track_subscription_usage(org_id, run_id) + # Fire HubSpot event if applicable hubspot_user_id = cb.get("hubspot_user_id") if hubspot_user_id: diff --git a/workers/shared/workflow/destination_connector.py b/workers/shared/workflow/destination_connector.py index 7747ae5db0..8a9fcf8308 100644 --- a/workers/shared/workflow/destination_connector.py +++ b/workers/shared/workflow/destination_connector.py @@ -658,6 +658,13 @@ def _handle_filesystem_destination( ): """Handle filesystem destination processing.""" if not result.has_hitl: + if not result.tool_execution_result and not file_ctx.execution_error: + error_msg = ( + f"No tool execution result for file '{file_ctx.file_name}' " + f"- failing filesystem copy" + ) + logger.error(error_msg) + raise RuntimeError(error_msg) 
log_file_info( exec_ctx.workflow_log, exec_ctx.file_execution_id, @@ -696,9 +703,12 @@ def _handle_database_destination( api_client=exec_ctx.api_client, ) else: - logger.warning( - f"No tool execution result or execution error found for file {file_ctx.file_name}, skipping database insertion" + error_msg = ( + f"No tool execution result for file '{file_ctx.file_name}' " + f"- database insertion failed" ) + logger.error(error_msg) + raise RuntimeError(error_msg) else: logger.info( f"File '{file_ctx.file_name}' sent to HITL queue - DATABASE processing will be handled after review" @@ -1403,9 +1413,14 @@ def get_tool_execution_result_from_execution_context( file_storage = file_system.get_file_storage() if not metadata_file_path: + logger.warning( + "No metadata_file_path for file_execution_id=%s", + file_execution_id, + ) return None if not file_storage.exists(metadata_file_path): + logger.warning("METADATA.json not found at '%s'", metadata_file_path) return None metadata_content = file_storage.read(path=metadata_file_path, mode="r") @@ -1418,9 +1433,14 @@ def get_tool_execution_result_from_execution_context( output_file_path = file_handler.infile if not output_file_path: + logger.warning( + "No infile path for file_execution_id=%s", + file_execution_id, + ) return None if not file_storage.exists(output_file_path): + logger.warning("INFILE not found at '%s'", output_file_path) return None file_type = file_storage.mime_type(path=output_file_path) diff --git a/workers/shared/workflow/execution/service.py b/workers/shared/workflow/execution/service.py index 0f375846ae..80c3b85148 100644 --- a/workers/shared/workflow/execution/service.py +++ b/workers/shared/workflow/execution/service.py @@ -1051,11 +1051,12 @@ def _execute_structure_tool_workflow( "platform_service_api_key": platform_api_key, "input_file_path": str(file_handler.infile), "output_dir_path": str(file_handler.execution_dir), - "source_file_name": str( - os.path.basename(file_handler.source_file) - if 
file_handler.source_file - else file_name - ), + # Use the real per-file name. file_handler.source_file is always + # {file_execution_dir}/SOURCE (a fixed local-copy sentinel), so + # os.path.basename(file_handler.source_file) would yield the + # literal "SOURCE" for every file and collide outputs in + # COPY_TO_FOLDER. See test_source_file_name_uses_real_filename_not_sentinel. + "source_file_name": file_name, "execution_data_dir": str(file_handler.file_execution_dir), "messaging_channel": getattr(execution_service, "messaging_channel", ""), "file_hash": metadata.get("source_hash", ""), diff --git a/workers/tests/test_answer_prompt.py b/workers/tests/test_answer_prompt.py index af539318d1..6c9fb9fce9 100644 --- a/workers/tests/test_answer_prompt.py +++ b/workers/tests/test_answer_prompt.py @@ -386,6 +386,36 @@ def test_email_type(self, mock_shim_cls, mock_deps): assert result.data[PSKeys.OUTPUT]["field_a"] == "user@example.com" + @patch( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" + ) + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_email_na_returns_none(self, mock_shim_cls, mock_deps): + """EMAIL type returns None when no email is found (not literal 'NA').""" + from executor.executors.legacy_executor import LegacyExecutor + + llm = _mock_llm() + response1 = MagicMock() + response1.text = "There is no email mentioned in the document." 
+ response2 = MagicMock() + response2.text = "NA" + llm.complete.side_effect = [ + {PSKeys.RESPONSE: response1, PSKeys.HIGHLIGHT_DATA: [], + PSKeys.CONFIDENCE_DATA: None, PSKeys.WORD_CONFIDENCE_DATA: None, + PSKeys.LINE_NUMBERS: [], PSKeys.WHISPER_HASH: ""}, + {PSKeys.RESPONSE: response2, PSKeys.HIGHLIGHT_DATA: [], + PSKeys.CONFIDENCE_DATA: None, PSKeys.WORD_CONFIDENCE_DATA: None, + PSKeys.LINE_NUMBERS: [], PSKeys.WHISPER_HASH: ""}, + ] + mock_deps.return_value = _mock_deps(llm) + mock_shim_cls.return_value = MagicMock() + + executor = LegacyExecutor() + ctx = _make_context(prompts=[_make_prompt(output_type="email")]) + result = executor._handle_answer_prompt(ctx) + + assert result.data[PSKeys.OUTPUT]["field_a"] is None + @patch( "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" ) @@ -416,6 +446,36 @@ def test_date_type(self, mock_shim_cls, mock_deps): assert result.data[PSKeys.OUTPUT]["field_a"] == "2024-01-15" + @patch( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" + ) + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_date_na_returns_none(self, mock_shim_cls, mock_deps): + """DATE type returns None when no date is found (not literal 'NA').""" + from executor.executors.legacy_executor import LegacyExecutor + + llm = _mock_llm() + response1 = MagicMock() + response1.text = "No date is mentioned in the text." 
class TestConvertScalarAnswer:
    """Tests for _convert_scalar_answer second-pass NA handling."""

    def test_first_pass_na_returns_none(self):
        """A literal 'NA' first answer short-circuits to None — no second LLM call."""
        from executor.executors.legacy_executor import LegacyExecutor

        svc = MagicMock()
        answer = LegacyExecutor._convert_scalar_answer(
            "NA",
            llm=MagicMock(),
            answer_prompt_svc=svc,
            prompt="ignored",
        )
        assert answer is None
        svc.run_completion.assert_not_called()

    def test_second_pass_na_returns_none(self):
        """When the extraction pass itself answers 'NA', the result is None."""
        from executor.executors.legacy_executor import LegacyExecutor

        svc = MagicMock()
        svc.run_completion.return_value = "NA"
        answer = LegacyExecutor._convert_scalar_answer(
            "I do not see an email here.",
            llm=MagicMock(),
            answer_prompt_svc=svc,
            prompt="extract email",
        )
        assert answer is None
        svc.run_completion.assert_called_once()

    def test_successful_extraction_returns_value(self):
        """A real extracted value is returned unchanged."""
        from executor.executors.legacy_executor import LegacyExecutor

        svc = MagicMock()
        svc.run_completion.return_value = "alice@example.com"
        answer = LegacyExecutor._convert_scalar_answer(
            "Contact: alice@example.com",
            llm=MagicMock(),
            answer_prompt_svc=svc,
            prompt="extract email",
        )
        assert answer == "alice@example.com"
+""" + +from unittest.mock import MagicMock, patch + +import pytest + +from unstract.sdk1.execution.context import ExecutionContext +from unstract.sdk1.execution.registry import ExecutorRegistry +from unstract.sdk1.execution.result import ExecutionResult + + +@pytest.fixture(autouse=True) +def _ensure_legacy_registered(): + """Ensure LegacyExecutor is registered.""" + from executor.executors.legacy_executor import LegacyExecutor + + if "legacy" not in ExecutorRegistry.list_executors(): + ExecutorRegistry._registry["legacy"] = LegacyExecutor + yield + + +def _get_executor(): + return ExecutorRegistry.get("legacy") + + +_PATCH_FS = "executor.executors.legacy_executor.FileUtils.get_fs_instance" + + +class TestSinglePassChunkSizeForcing: + """Verify that chunk-size=0 is forced for single pass in the pipeline.""" + + @patch(_PATCH_FS) + def test_single_pass_forces_chunk_size_zero(self, mock_fs): + """When single pass falls back to answer_prompt, chunk-size=0 is used.""" + from executor.executors.constants import PromptServiceConstants as PSKeys + + fs = MagicMock() + fs.read.return_value = "full doc content" + fs.exists.return_value = False + mock_fs.return_value = fs + + # Build minimal answer_params with non-zero chunk-size + outputs = [ + { + PSKeys.NAME: "field_a", + PSKeys.PROMPT: "What is the revenue?", + PSKeys.TYPE: "text", + "chunk-size": 512, + "chunk-overlap": 64, + "retrieval-strategy": "simple", + "llm": "llm-1", + "embedding": "emb-1", + "vector-db": "vdb-1", + "x2text_adapter": "x2t-1", + "similarity-top-k": 3, + "active": True, + }, + ] + + # Simulate what _handle_structure_pipeline does with is_single_pass + # We verify indirectly: force chunk-size=0 then call answer_prompt + # which uses retrieve_complete_context for chunk_size=0 + answer_params = { + "outputs": outputs, + "run_id": "run-sp-cs", + "tool_id": "tool-1", + "file_hash": "hash1", + "file_name": "test.pdf", + "file_path": "/data/extract/doc.txt", + "execution_source": "tool", + 
"PLATFORM_SERVICE_API_KEY": "pk-test", + "tool_settings": { + "vector-db": "vdb-1", + "embedding": "emb-1", + "x2text_adapter": "x2t-1", + "llm": "llm-1", + "enable_challenge": False, + "challenge_llm": "", + "enable_single_pass_extraction": True, + "summarize_as_source": False, + "enable_highlight": False, + }, + } + + # Apply the same logic as _handle_structure_pipeline step 4b + # (single pass forces chunk-size=0 to use full-context retrieval) + for output in answer_params.get("outputs", []): + output["chunk-size"] = 0 + output["chunk-overlap"] = 0 + + # Verify outputs were modified + for output in answer_params["outputs"]: + assert output["chunk-size"] == 0 + assert output["chunk-overlap"] == 0 + + +class TestPipelineIndexUsageKwargsPropagation: + """Verify that _run_pipeline_index propagates usage_kwargs to INDEX ctx. + + Without this propagation, the embedding adapter's UsageHandler callback + records audit rows without ``run_id``, so embedding usage is missing from + the API deployment response metadata when chunking is enabled. 
+ """ + + def test_index_ctx_includes_usage_kwargs(self): + """The INDEX executor_params include usage_kwargs from extract_params.""" + executor = _get_executor() + + captured_ctx: dict = {} + + def fake_handle_index(ctx): + captured_ctx["value"] = ctx + return ExecutionResult(success=True, data={}) + + executor._handle_index = fake_handle_index # type: ignore[assignment] + + index_template = { + "tool_id": "tool-1", + "file_hash": "hash-abc", + "is_highlight_enabled": False, + "platform_api_key": "pk-test", + "extracted_file_path": "/data/extract/doc.txt", + } + answer_params = { + "tool_settings": { + "vector-db": "vdb-1", + "embedding": "emb-1", + "x2text_adapter": "x2t-1", + }, + "outputs": [ + { + "name": "field_a", + "chunk-size": 512, + "chunk-overlap": 64, + }, + ], + } + usage_kwargs = { + "run_id": "file-exec-123", + "execution_id": "wf-exec-456", + "file_name": "doc.pdf", + } + ctx = ExecutionContext( + executor_name="legacy", + operation="structure_pipeline", + run_id="run-uk-1", + execution_source="tool", + ) + + executor._run_pipeline_index( + context=ctx, + index_template=index_template, + answer_params=answer_params, + extracted_text="extracted", + usage_kwargs=usage_kwargs, + ) + + index_ctx = captured_ctx["value"] + assert "usage_kwargs" in index_ctx.executor_params + assert index_ctx.executor_params["usage_kwargs"] == usage_kwargs + assert index_ctx.executor_params["usage_kwargs"]["run_id"] == "file-exec-123" + + def test_index_ctx_defaults_to_empty_when_not_provided(self): + """Without usage_kwargs, INDEX executor_params get empty dict (no crash).""" + executor = _get_executor() + + captured_ctx: dict = {} + + def fake_handle_index(ctx): + captured_ctx["value"] = ctx + return ExecutionResult(success=True, data={}) + + executor._handle_index = fake_handle_index # type: ignore[assignment] + + index_template = { + "tool_id": "tool-1", + "file_hash": "hash-abc", + "is_highlight_enabled": False, + "platform_api_key": "pk-test", + 
"extracted_file_path": "/data/extract/doc.txt", + } + answer_params = { + "tool_settings": { + "vector-db": "vdb-1", + "embedding": "emb-1", + "x2text_adapter": "x2t-1", + }, + "outputs": [ + { + "name": "field_a", + "chunk-size": 512, + "chunk-overlap": 64, + }, + ], + } + ctx = ExecutionContext( + executor_name="legacy", + operation="structure_pipeline", + run_id="run-uk-2", + execution_source="tool", + ) + + executor._run_pipeline_index( + context=ctx, + index_template=index_template, + answer_params=answer_params, + extracted_text="extracted", + ) + + index_ctx = captured_ctx["value"] + assert index_ctx.executor_params["usage_kwargs"] == {} diff --git a/workers/tests/test_ide_callback.py b/workers/tests/test_ide_callback.py index a95a4371d8..eb0b2f8c8d 100644 --- a/workers/tests/test_ide_callback.py +++ b/workers/tests/test_ide_callback.py @@ -18,6 +18,7 @@ _PATCH_GET_CLIENT = "ide_callback.tasks._get_api_client" _PATCH_EMIT_WS = "ide_callback.tasks._emit_websocket" _PATCH_ASYNC_RESULT = "celery.result.AsyncResult" +_PATCH_GET_PLUGIN = "client_plugin_registry.get_client_plugin" # --------------------------------------------------------------------------- @@ -647,3 +648,134 @@ def test_nested_types(self): result = _json_safe(val) assert isinstance(result["items"][0]["id"], str) assert isinstance(result["items"][0]["date"], str) + + +# --------------------------------------------------------------------------- +# TestTrackSubscriptionUsage (unit tests) +# --------------------------------------------------------------------------- + + +class TestTrackSubscriptionUsage: + """Tests for the _track_subscription_usage helper.""" + + def _call(self, org_id, run_id): + from ide_callback.tasks import _track_subscription_usage + + return _track_subscription_usage(org_id, run_id) + + def test_skips_when_org_id_empty(self): + """No-op when org_id is empty.""" + with patch(_PATCH_GET_PLUGIN) as mock_gp: + self._call("", "run-1") + mock_gp.assert_not_called() + + def 
test_skips_when_run_id_empty(self): + """No-op when run_id is empty.""" + with patch(_PATCH_GET_PLUGIN) as mock_gp: + self._call("org-1", "") + mock_gp.assert_not_called() + + def test_plugin_not_available(self): + """Gracefully returns when plugin is not installed (OSS mode).""" + with patch(_PATCH_GET_PLUGIN, return_value=None): + self._call("org-1", "run-1") # should not raise + + def test_plugin_commits_usage(self): + """Calls commit_batch_subscription_usage with correct args.""" + mock_plugin = MagicMock() + mock_plugin.commit_batch_subscription_usage.return_value = { + "status": "ok", + "committed_count": 1, + } + with patch(_PATCH_GET_PLUGIN, return_value=mock_plugin): + self._call("org-1", "run-42") + + mock_plugin.commit_batch_subscription_usage.assert_called_once_with( + organization_id="org-1", + file_execution_ids=["run-42"], + ) + + def test_plugin_error_non_blocking(self): + """Exception from plugin is caught; callback continues.""" + mock_plugin = MagicMock() + mock_plugin.commit_batch_subscription_usage.side_effect = RuntimeError("boom") + with patch(_PATCH_GET_PLUGIN, return_value=mock_plugin): + self._call("org-1", "run-1") # should not raise + + +# --------------------------------------------------------------------------- +# TestSubscriptionUsageIntegration (in callback tasks) +# --------------------------------------------------------------------------- + + +class TestSubscriptionUsageInIndexCallback: + """Verify _track_subscription_usage is called correctly in ide_index_complete.""" + + def _call(self, result_dict, callback_kwargs=None): + from ide_callback.tasks import ide_index_complete + + return ide_index_complete(result_dict, callback_kwargs) + + @patch("ide_callback.tasks._track_subscription_usage") + @patch(_PATCH_EMIT_WS) + @patch(_PATCH_GET_CLIENT) + def test_called_on_success( + self, mock_get_client, mock_emit_ws, mock_track, mock_api, base_index_kwargs, success_result + ): + mock_get_client.return_value = mock_api + 
base_index_kwargs["run_id"] = "run-idx-1" + + self._call(success_result, base_index_kwargs) + + mock_track.assert_called_once_with("org-1", "run-idx-1") + + @patch("ide_callback.tasks._track_subscription_usage") + @patch(_PATCH_EMIT_WS) + @patch(_PATCH_GET_CLIENT) + def test_not_called_on_executor_failure( + self, mock_get_client, mock_emit_ws, mock_track, mock_api, base_index_kwargs, failure_result + ): + mock_get_client.return_value = mock_api + + self._call(failure_result, base_index_kwargs) + + mock_track.assert_not_called() + + +class TestSubscriptionUsageInPromptCallback: + """Verify _track_subscription_usage is called correctly in ide_prompt_complete.""" + + def _call(self, result_dict, callback_kwargs=None): + from ide_callback.tasks import ide_prompt_complete + + return ide_prompt_complete(result_dict, callback_kwargs) + + def _make_result(self, output=None): + return { + "success": True, + "data": {"output": output or {"p1": "answer"}, "metadata": {}}, + } + + @patch("ide_callback.tasks._track_subscription_usage") + @patch(_PATCH_EMIT_WS) + @patch(_PATCH_GET_CLIENT) + def test_called_on_success( + self, mock_get_client, mock_emit_ws, mock_track, mock_api, base_prompt_kwargs + ): + mock_get_client.return_value = mock_api + + self._call(self._make_result(), base_prompt_kwargs) + + mock_track.assert_called_once_with("org-1", "run-1") + + @patch("ide_callback.tasks._track_subscription_usage") + @patch(_PATCH_EMIT_WS) + @patch(_PATCH_GET_CLIENT) + def test_not_called_on_executor_failure( + self, mock_get_client, mock_emit_ws, mock_track, mock_api, base_prompt_kwargs, failure_result + ): + mock_get_client.return_value = mock_api + + self._call(failure_result, base_prompt_kwargs) + + mock_track.assert_not_called() diff --git a/workers/tests/test_line_item_extraction.py b/workers/tests/test_line_item_extraction.py new file mode 100644 index 0000000000..5a46b47b1f --- /dev/null +++ b/workers/tests/test_line_item_extraction.py @@ -0,0 +1,459 @@ +"""Tests for 
LegacyExecutor._run_line_item_extraction (LINE_ITEM type). + +Mirrors the structure of TABLE delegation tests in test_sanity_phase6d.py. + +Verifies: +1. Plugin missing → LegacyExecutorError with install hint. +2. Plugin success → output written + metrics merged + context propagated. +3. Plugin failure → empty output + error logged. +4. End-to-end through _execute_single_prompt with a LINE_ITEM prompt + in the structure-tool path (eager Celery + fake plugin). +""" + +import logging +from unittest.mock import MagicMock, patch + +import pytest +from executor.executors.answer_prompt import AnswerPromptService +from executor.executors.constants import PromptServiceConstants as PSKeys +from executor.executors.exceptions import LegacyExecutorError +from unstract.sdk1.execution.context import ExecutionContext +from unstract.sdk1.execution.executor import BaseExecutor +from unstract.sdk1.execution.registry import ExecutorRegistry +from unstract.sdk1.execution.result import ExecutionResult + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _clean_registry(): + """Ensure a clean executor registry for every test.""" + ExecutorRegistry.clear() + yield + ExecutorRegistry.clear() + + +def _get_legacy_executor(): + """Register and fetch the LegacyExecutor instance.""" + from executor.executors.legacy_executor import LegacyExecutor + + if "legacy" not in ExecutorRegistry.list_executors(): + ExecutorRegistry.register(LegacyExecutor) + return ExecutorRegistry.get("legacy") + + +def _make_line_item_prompt(): + """Build a LINE_ITEM prompt config dict (mirrors _execute_single_prompt + expectations). 
+ """ + return { + PSKeys.NAME: "line_items", + PSKeys.PROMPT: "Extract all invoice line items.", + PSKeys.PROMPTX: "Extract all invoice line items.", + PSKeys.TYPE: PSKeys.LINE_ITEM, + PSKeys.CHUNK_SIZE: 0, + PSKeys.CHUNK_OVERLAP: 0, + PSKeys.LLM: "llm-1", + PSKeys.EMBEDDING: "emb-1", + PSKeys.VECTOR_DB: "vdb-1", + PSKeys.X2TEXT_ADAPTER: "x2t-1", + PSKeys.RETRIEVAL_STRATEGY: "simple", + } + + +def _make_context(): + """Build a minimal ExecutionContext for the answer_prompt path.""" + tool_settings = { + PSKeys.PREAMBLE: "", + PSKeys.POSTAMBLE: "", + PSKeys.GRAMMAR: [], + PSKeys.ENABLE_HIGHLIGHT: False, + } + return ExecutionContext( + executor_name="legacy", + operation="answer_prompt", + run_id="run-line-item-001", + execution_source="tool", + organization_id="org-test", + request_id="req-line-item-001", + executor_params={ + PSKeys.TOOL_SETTINGS: tool_settings, + PSKeys.OUTPUTS: [_make_line_item_prompt()], + PSKeys.TOOL_ID: "tool-1", + PSKeys.FILE_HASH: "hash123", + PSKeys.FILE_PATH: "/data/invoice.txt", + PSKeys.FILE_NAME: "invoice.txt", + PSKeys.PLATFORM_SERVICE_API_KEY: "pk-test", + }, + ) + + +def _standard_patches(executor): + """Common patches needed to drive _handle_answer_prompt → _execute_single_prompt + until it reaches the LINE_ITEM branch. 
+ """ + llm = MagicMock(name="llm") + llm.get_metrics.return_value = {} + mock_llm_cls = MagicMock(return_value=llm) + return { + "_get_prompt_deps": patch.object( + executor, + "_get_prompt_deps", + return_value=( + AnswerPromptService, + MagicMock( + retrieve_complete_context=MagicMock( + return_value=["context chunk"] + ) + ), + MagicMock( + is_variables_present=MagicMock(return_value=False) + ), + None, # Index + mock_llm_cls, + MagicMock(), # EmbeddingCompat + MagicMock(), # VectorDB + ), + ), + "shim": patch( + "executor.executors.legacy_executor.ExecutorToolShim", + return_value=MagicMock(), + ), + "index_key": patch( + "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key", + return_value="doc-id-1", + ), + } + + +# --------------------------------------------------------------------------- +# Fake LineItemExecutor plugins +# --------------------------------------------------------------------------- + + +def _make_success_plugin( + output_value=None, + metrics=None, + context_list=None, +): + """Build a fake plugin class that returns a success ExecutionResult.""" + payload = {"output": output_value or {"items": [{"sku": "A1", "qty": 2}]}} + if metrics is not None: + payload["metadata"] = {"metrics": metrics} + if context_list is not None: + payload["context"] = context_list + + class _SuccessPlugin(BaseExecutor): + @property + def name(self) -> str: + return "line_item" + + def execute(self, context: ExecutionContext) -> ExecutionResult: + self.received_context = context + return ExecutionResult(success=True, data=payload) + + return _SuccessPlugin + + +def _make_failure_plugin(error="extraction blew up"): + class _FailurePlugin(BaseExecutor): + @property + def name(self) -> str: + return "line_item" + + def execute(self, context: ExecutionContext) -> ExecutionResult: + return ExecutionResult.failure(error=error) + + return _FailurePlugin + + +# --------------------------------------------------------------------------- +# 1. 
Plugin missing → clear error +# --------------------------------------------------------------------------- + + +class TestLineItemPluginMissing: + @patch("executor.executors.legacy_executor.ExecutorToolShim") + @patch( + "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key", + return_value="doc-id-1", + ) + def test_line_item_raises_when_plugin_missing( + self, _mock_key, _mock_shim_cls + ): + """LINE_ITEM prompt raises LegacyExecutorError with install hint.""" + executor = _get_legacy_executor() + ctx = _make_context() + patches = _standard_patches(executor) + + with patches["_get_prompt_deps"], patches["shim"], patches["index_key"]: + with pytest.raises( + LegacyExecutorError, + match="line_item_extractor plugin", + ): + executor._handle_answer_prompt(ctx) + + +# --------------------------------------------------------------------------- +# 2. Plugin success → output written + metrics + context +# --------------------------------------------------------------------------- + + +class TestLineItemPluginSuccess: + @patch("executor.executors.legacy_executor.ExecutorToolShim") + @patch( + "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key", + return_value="doc-id-1", + ) + def test_success_writes_output_and_merges_metrics( + self, _mock_key, _mock_shim_cls + ): + executor = _get_legacy_executor() + ctx = _make_context() + patches = _standard_patches(executor) + + plugin_cls = _make_success_plugin( + output_value={"items": [{"sku": "A1", "qty": 2}]}, + metrics={"llm_calls": 3}, + context_list=["full file body"], + ) + # Register so ExecutorRegistry.get("line_item") finds it + ExecutorRegistry.register(plugin_cls) + + with patches["_get_prompt_deps"], patches["shim"], patches["index_key"]: + result = executor._handle_answer_prompt(ctx) + + assert result.success is True + # structured_output[prompt_name] holds the plugin output + assert result.data["output"]["line_items"] == { + "items": [{"sku": "A1", "qty": 2}] + } + # Metrics are merged under 
the line_item_extraction sub-key + prompt_metrics = result.data["metrics"]["line_items"] + assert prompt_metrics["line_item_extraction"] == {"llm_calls": 3} + # Context list is propagated to metadata.context + assert result.data["metadata"][PSKeys.CONTEXT]["line_items"] == [ + "full file body" + ] + + @patch("executor.executors.legacy_executor.ExecutorToolShim") + @patch( + "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key", + return_value="doc-id-1", + ) + def test_success_passes_correct_executor_params( + self, _mock_key, _mock_shim_cls + ): + """Verify the sub-context built for the plugin has all expected + keys with the right values. + """ + executor = _get_legacy_executor() + ctx = _make_context() + patches = _standard_patches(executor) + + captured: dict = {} + + class _CapturePlugin(BaseExecutor): + @property + def name(self) -> str: + return "line_item" + + def execute(self, context: ExecutionContext) -> ExecutionResult: + captured["ctx"] = context + return ExecutionResult(success=True, data={"output": {}}) + + ExecutorRegistry.register(_CapturePlugin) + + with patches["_get_prompt_deps"], patches["shim"], patches["index_key"]: + executor._handle_answer_prompt(ctx) + + sub_ctx = captured["ctx"] + assert sub_ctx.executor_name == "line_item" + assert sub_ctx.operation == "line_item_extract" + assert sub_ctx.run_id == "run-line-item-001" + assert sub_ctx.organization_id == "org-test" + + params = sub_ctx.executor_params + assert params["llm_adapter_instance_id"] == "llm-1" + assert params["PLATFORM_SERVICE_API_KEY"] == "pk-test" + assert params["file_path"] == "/data/invoice.txt" + assert params["file_name"] == "invoice.txt" + assert params["tool_id"] == "tool-1" + assert params["prompt_name"] == "line_items" + assert params["prompt"] == "Extract all invoice line items." 
+ # output dict and tool_settings are passed through + assert params["output"][PSKeys.NAME] == "line_items" + assert params["output"][PSKeys.TYPE] == PSKeys.LINE_ITEM + assert PSKeys.PREAMBLE in params["tool_settings"] + + +# --------------------------------------------------------------------------- +# 3. Plugin failure → empty output + error logged +# --------------------------------------------------------------------------- + + +class TestLineItemPluginFailure: + @patch("executor.executors.legacy_executor.ExecutorToolShim") + @patch( + "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key", + return_value="doc-id-1", + ) + def test_failure_writes_empty_output_and_logs( + self, _mock_key, _mock_shim_cls, caplog + ): + executor = _get_legacy_executor() + ctx = _make_context() + patches = _standard_patches(executor) + + ExecutorRegistry.register(_make_failure_plugin("plugin error")) + + with patches["_get_prompt_deps"], patches["shim"], patches["index_key"]: + with caplog.at_level( + logging.ERROR, + logger="executor.executors.legacy_executor", + ): + result = executor._handle_answer_prompt(ctx) + + assert result.success is True # answer_prompt itself does not raise + assert result.data["output"]["line_items"] == "" + # Failure logged + assert any( + "LINE_ITEM extraction failed" in rec.message + and "plugin error" in rec.message + for rec in caplog.records + ) + + +# --------------------------------------------------------------------------- +# 4. 
# ---------------------------------------------------------------------------
# End-to-end through Celery eager mode (structure-tool path)
# ---------------------------------------------------------------------------


@pytest.fixture
def eager_app():
    """Yield the executor Celery app switched into eager (in-process) mode.

    The pre-existing eager/propagation/result-backend settings are captured
    up front and restored after the yield so other tests observe the app
    with its original configuration.
    """
    from executor.worker import app

    saved_conf = {
        "task_always_eager": app.conf.task_always_eager,
        "task_eager_propagates": app.conf.task_eager_propagates,
        "result_backend": app.conf.result_backend,
    }
    app.conf.update(
        task_always_eager=True,
        task_eager_propagates=False,
        result_backend="cache+memory://",
    )
    yield app
    app.conf.update(saved_conf)


def _structure_tool_ctx():
    """Build an answer_prompt context with a single LINE_ITEM prompt for
    the structure-tool path (execution_source='tool').
    """
    settings = {
        PSKeys.PREAMBLE: "Extract carefully.",
        PSKeys.POSTAMBLE: "No commentary.",
        PSKeys.GRAMMAR: [],
        PSKeys.ENABLE_HIGHLIGHT: False,
        PSKeys.ENABLE_CHALLENGE: False,
    }
    params = {
        PSKeys.TOOL_SETTINGS: settings,
        PSKeys.OUTPUTS: [_make_line_item_prompt()],
        PSKeys.TOOL_ID: "tool-e2e",
        PSKeys.FILE_HASH: "hash-e2e",
        PSKeys.FILE_PATH: "/data/rent_roll.txt",
        PSKeys.FILE_NAME: "rent_roll.txt",
        PSKeys.PLATFORM_SERVICE_API_KEY: "pk-e2e",
    }
    return ExecutionContext(
        executor_name="legacy",
        operation="answer_prompt",
        run_id="run-line-item-e2e",
        execution_source="tool",
        organization_id="org-e2e",
        request_id="req-e2e",
        executor_params=params,
    )


class TestLineItemEndToEnd:
    @patch("executor.executors.legacy_executor.ExecutorToolShim")
    @patch(
        "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key",
        return_value="doc-id-e2e",
    )
    @patch(
        "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps"
    )
    @patch(
        "executor.executors.plugins.loader.ExecutorPluginLoader.get",
        return_value=None,
    )
    def test_celery_eager_chain_with_line_item_plugin(
        self,
        _mock_plugin_loader,
        mock_deps,
        _mock_index_utils,
        _mock_shim_cls,
        eager_app,
    ):
        """Push a LINE_ITEM payload through the full Celery eager chain
        with a fake line_item plugin registered.
        """
        # The autouse fixture clears the registry, so LegacyExecutor has
        # to be re-registered before the task is dispatched.
        from executor.executors.legacy_executor import LegacyExecutor

        ExecutorRegistry.register(LegacyExecutor)

        # Register a fake line_item plugin that returns a fixed payload.
        fake_plugin = _make_success_plugin(
            output_value={
                "items": [
                    {"unit": "1A", "rent": 1500},
                    {"unit": "1B", "rent": 1700},
                ]
            },
            metrics={"llm_calls": 2, "tokens": 1234},
            context_list=["rent roll body"],
        )
        ExecutorRegistry.register(fake_plugin)

        # Stub the prompt dependencies so _execute_single_prompt can run
        # far enough to reach the LINE_ITEM branch.
        mock_llm = MagicMock(name="llm")
        mock_llm.get_metrics.return_value = {}
        mock_deps.return_value = (
            AnswerPromptService,
            MagicMock(
                retrieve_complete_context=MagicMock(return_value=["chunk"])
            ),
            MagicMock(is_variables_present=MagicMock(return_value=False)),
            None,
            MagicMock(return_value=mock_llm),
            MagicMock(),
            MagicMock(),
        )

        context = _structure_tool_ctx()
        eager_result = eager_app.tasks["execute_extraction"].apply(
            args=[context.to_dict()]
        )
        result = ExecutionResult.from_dict(eager_result.get())

        assert result.success is True
        assert result.data["output"]["line_items"] == {
            "items": [
                {"unit": "1A", "rent": 1500},
                {"unit": "1B", "rent": 1700},
            ]
        }
        assert result.data["metrics"]["line_items"]["line_item_extraction"] == {
            "llm_calls": 2,
            "tokens": 1234,
        }
twice: output file + INFILE overwrite - assert mock_fs.json_dump.call_count == 2 + # json_dump called 3 times: output file + INFILE overwrite + COPY_TO_FOLDER + assert mock_fs.json_dump.call_count == 3 # Single dispatch with structure_pipeline assert dispatcher_instance.dispatch.call_count == 1 @@ -668,8 +668,8 @@ def test_structure_tool_output_written( assert result["success"] is True - # json_dump called twice: once for output file, once for INFILE overwrite - assert mock_fs.json_dump.call_count == 2 + # json_dump called 3 times: output file, INFILE overwrite, COPY_TO_FOLDER + assert mock_fs.json_dump.call_count == 3 # First call: output file (execution_dir/{stem}.json) first_call = mock_fs.json_dump.call_args_list[0] @@ -689,6 +689,16 @@ def test_structure_tool_output_written( second_path = second_call[0][0] if second_call[0] else None assert str(second_path) == base_params["input_file_path"] + # Third call: COPY_TO_FOLDER/{stem}.json (for FS destinations) + third_call = mock_fs.json_dump.call_args_list[2] + third_path = third_call.kwargs.get( + "path", third_call[1].get("path") if len(third_call) > 1 else None + ) + if third_path is None: + third_path = third_call[0][0] if third_call[0] else None + assert "COPY_TO_FOLDER" in str(third_path) + assert str(third_path).endswith("test.json") + class TestStructureToolMetadataFileName: """metadata.file_name in pipeline result preserved.""" @@ -904,6 +914,62 @@ def test_structure_tool_params_passthrough( mock_exec_service, "test.pdf" ) + @patch( + "shared.workflow.execution.service.WorkerWorkflowExecutionService." + "_get_platform_service_api_key", + return_value="sk-test", + ) + @patch("file_processing.structure_tool_task.execute_structure_tool") + def test_source_file_name_uses_real_filename_not_sentinel( + self, mock_execute_struct, mock_get_key + ): + """Regression: params["source_file_name"] is the real file name, + not the literal "SOURCE" sentinel from file_handler.source_file. 
+ + Bug: previously the producer used + os.path.basename(file_handler.source_file), but + file_handler.source_file is always + {file_execution_dir}/SOURCE — a fixed local-copy sentinel — so + every per-file COPY_TO_FOLDER ended up with SOURCE.json and the + destination connector overwrote files at the destination. + """ + from shared.workflow.execution.service import ( + WorkerWorkflowExecutionService, + ) + + service = WorkerWorkflowExecutionService() + + # Mock execution_service / file_handler to recreate the buggy + # state: file_handler.source_file is the fixed "SOURCE" sentinel. + mock_exec_service = MagicMock() + mock_exec_service.organization_id = "org-test" + mock_exec_service.workflow_id = "wf-1" + mock_exec_service.execution_id = "exec-1" + mock_exec_service.file_execution_id = "fexec-1" + mock_exec_service.tool_instances = [MagicMock(metadata={})] + + file_handler = MagicMock() + file_handler.source_file = "/data/exec/fexec-1/SOURCE" + file_handler.infile = "/data/exec/fexec-1/INFILE" + file_handler.execution_dir = "/data/exec" + file_handler.file_execution_dir = "/data/exec/fexec-1" + file_handler.get_workflow_metadata.return_value = { + "source_hash": "abc", + } + mock_exec_service.file_handler = file_handler + + mock_execute_struct.return_value = {"success": True} + + service._execute_structure_tool_workflow( + mock_exec_service, "invoice.pdf" + ) + + # The dispatched params dict must carry the real file name, + # not the "SOURCE" sentinel from file_handler.source_file. 
+ params = mock_execute_struct.call_args[0][0] + assert params["source_file_name"] == "invoice.pdf" + assert params["source_file_name"] != "SOURCE" + class TestHelperFunctions: """Test standalone helper functions.""" diff --git a/workers/tests/test_sanity_phase6d.py b/workers/tests/test_sanity_phase6d.py index cd40c1b685..46d2693465 100644 --- a/workers/tests/test_sanity_phase6d.py +++ b/workers/tests/test_sanity_phase6d.py @@ -201,9 +201,12 @@ def test_table_type_raises_when_plugin_missing( @patch("executor.executors.legacy_executor.ExecutorToolShim") @patch("unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key", return_value="doc-id-1") - def test_line_item_type_raises_not_supported( + def test_line_item_type_raises_when_plugin_missing( self, mock_key, mock_shim_cls ): + """LINE_ITEM prompts raise install error when line_item plugin + is not registered (mirrors TABLE missing-plugin behavior). + """ mock_shim_cls.return_value = MagicMock() executor = _get_executor() ctx = _make_context(output_type=PSKeys.LINE_ITEM) # "line-item" @@ -211,8 +214,16 @@ def test_line_item_type_raises_not_supported( patches = _standard_patches(executor, llm) with patches["_get_prompt_deps"], patches["shim"], patches["index_key"]: - with pytest.raises(LegacyExecutorError, match="not supported"): - executor._handle_answer_prompt(ctx) + with patch( + "unstract.sdk1.execution.registry.ExecutorRegistry.get", + side_effect=KeyError( + "No executor registered with name 'line_item'" + ), + ): + with pytest.raises( + LegacyExecutorError, match="line_item_extractor plugin" + ): + executor._handle_answer_prompt(ctx) # ---------------------------------------------------------------------------