diff --git a/src/mistralai/extra/observability/serialization.py b/src/mistralai/extra/observability/formatting.py similarity index 84% rename from src/mistralai/extra/observability/serialization.py rename to src/mistralai/extra/observability/formatting.py index de3bfce2..34dc9aed 100644 --- a/src/mistralai/extra/observability/serialization.py +++ b/src/mistralai/extra/observability/formatting.py @@ -1,7 +1,9 @@ -"""Serialization helpers for converting Mistral API payloads to OTEL GenAI convention formats. +"""Formatting helpers for converting Mistral API payloads to OTEL GenAI convention formats. -These are pure functions with no OTEL dependencies — they transform dicts to JSON strings +These are pure functions with no OTEL dependencies — they transform dicts to dicts matching the GenAI semantic convention schemas for input/output messages and tool definitions. +The caller is responsible for the final JSON serialization (single json.dumps on the whole +collection) before setting span attributes. Schemas: - Input messages: https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/gen-ai-input-messages.json @@ -9,7 +11,6 @@ - Tool definitions: https://github.com/Cirilla-zmh/semantic-conventions/blob/cc4d07e7e56b80e9aa5904a3d524c134699da37f/docs/gen-ai/gen-ai-tool-definitions.json """ -import json from typing import Any @@ -72,8 +73,8 @@ def _tool_calls_to_parts(tool_calls: list[dict] | None) -> list[dict]: return parts -def serialize_input_message(message: dict[str, Any]) -> str: - """Serialize a single input message per the OTEL GenAI convention. +def format_input_message(message: dict[str, Any]) -> dict[str, Any]: + """Format a single input message per the OTEL GenAI convention. Schema: https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/gen-ai-input-messages.json ChatMessage: {role (required), parts (required), name?} @@ -89,7 +90,7 @@ def serialize_input_message(message: dict[str, Any]) -> str: part: dict = {"type": "tool_call_response", "response": message.get("result")} if (tool_call_id := message.get("tool_call_id")) is not None: part["id"] = tool_call_id - return json.dumps({"role": "tool", "parts": [part]}) + return {"role": "tool", "parts": [part]} # TODO: may need to handle other types for conversations (e.g. agent handoff) @@ -109,11 +110,11 @@ def serialize_input_message(message: dict[str, Any]) -> str: parts.extend(_content_to_parts(message.get("content"))) parts.extend(_tool_calls_to_parts(message.get("tool_calls"))) - return json.dumps({"role": role, "parts": parts}) + return {"role": role, "parts": parts} -def serialize_output_message(choice: dict[str, Any]) -> str: - """Serialize a single output choice/message per the OTEL GenAI convention. +def format_output_message(choice: dict[str, Any]) -> dict[str, Any]: + """Format a single output choice/message per the OTEL GenAI convention. Schema: https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/gen-ai-output-messages.json OutputMessage: {role (required), parts (required), finish_reason (required), name?} @@ -123,16 +124,14 @@ def serialize_output_message(choice: dict[str, Any]) -> str: parts.extend(_content_to_parts(message.get("content"))) parts.extend(_tool_calls_to_parts(message.get("tool_calls"))) - return json.dumps( - { - "role": message.get("role", "assistant"), - "parts": parts, - "finish_reason": choice.get("finish_reason", ""), - } - ) + return { + "role": message.get("role", "assistant"), + "parts": parts, + "finish_reason": choice.get("finish_reason", ""), + } -def serialize_tool_definition(tool: dict[str, Any]) -> str | None: +def format_tool_definition(tool: dict[str, Any]) -> dict[str, Any] | None: """Flatten a Mistral tool definition to the OTEL GenAI convention schema. Mistral format: {"type": "function", "function": {"name": ..., "description": ..., "parameters": ...}} @@ -148,9 +147,9 @@ def serialize_tool_definition(tool: dict[str, Any]) -> str | None: name = func.get("name") if not name: return None - serialized: dict = {"type": type, "name": name} + formatted: dict = {"type": type, "name": name} if (description := func.get("description")) is not None: - serialized["description"] = description + formatted["description"] = description if (parameters := func.get("parameters")) is not None: - serialized["parameters"] = parameters - return json.dumps(serialized) + formatted["parameters"] = parameters + return formatted diff --git a/src/mistralai/extra/observability/otel.py b/src/mistralai/extra/observability/otel.py index 6ea37389..7c75271e 100644 --- a/src/mistralai/extra/observability/otel.py +++ b/src/mistralai/extra/observability/otel.py @@ -23,10 +23,10 @@ from opentelemetry.baggage import get_baggage from opentelemetry.trace import Span, Status, StatusCode, Tracer, set_span_in_context -from .serialization import ( - serialize_input_message, - serialize_output_message, - serialize_tool_definition, +from .formatting import ( + format_input_message, + format_output_message, + format_tool_definition, ) from .streaming import accumulate_chunks_to_response_dict, parse_sse_chunks @@ -185,18 +185,20 @@ def _enrich_request_genai_attrs( # Chat/agent completion API uses messages in request body; conversation API uses inputs input_messages = request_body.get("messages") or request_body.get("inputs") if isinstance(input_messages, str): - attributes[gen_ai_attributes.GEN_AI_INPUT_MESSAGES] = [ - serialize_input_message({"role": "user", "content": input_messages}) - ] + attributes[gen_ai_attributes.GEN_AI_INPUT_MESSAGES] = json.dumps( + [format_input_message({"role": "user", "content": input_messages})] + ) elif isinstance(input_messages, list): - attributes[gen_ai_attributes.GEN_AI_INPUT_MESSAGES] = list( - map(serialize_input_message, input_messages) + attributes[gen_ai_attributes.GEN_AI_INPUT_MESSAGES] = json.dumps( + list(map(format_input_message, input_messages)) ) # Tool definitions if tools := request_body.get("tools"): - attributes[gen_ai_attributes.GEN_AI_TOOL_DEFINITIONS] = list( - filter(None, map(serialize_tool_definition, tools)) - ) + formatted_tools = list(filter(None, map(format_tool_definition, tools))) + if formatted_tools: + attributes[gen_ai_attributes.GEN_AI_TOOL_DEFINITIONS] = json.dumps( + formatted_tools + ) # TODO: For agent start conversation, add agent id and version attributes here ? set_available_attributes(span, attributes) @@ -244,8 +246,8 @@ def _enrich_response_genai_attrs( if finish_reasons: attributes[gen_ai_attributes.GEN_AI_RESPONSE_FINISH_REASONS] = finish_reasons if choices: - attributes[gen_ai_attributes.GEN_AI_OUTPUT_MESSAGES] = list( - map(serialize_output_message, choices) + attributes[gen_ai_attributes.GEN_AI_OUTPUT_MESSAGES] = json.dumps( + list(map(format_output_message, choices)) ) # Usage @@ -305,7 +307,8 @@ def _create_tool_execution_child_span( if isinstance(tool_arguments, str) else (json.dumps(tool_arguments) if tool_arguments else None), gen_ai_attributes.GEN_AI_TOOL_CALL_RESULT: tool_result - and json.dumps(tool_result), + if isinstance(tool_result, str) + else (json.dumps(tool_result) if tool_result else None), gen_ai_attributes.GEN_AI_TOOL_NAME: output.get("name"), gen_ai_attributes.GEN_AI_TOOL_TYPE: "extension", } @@ -338,9 +341,9 @@ def _create_message_output_child_span( gen_ai_attributes.GEN_AI_RESPONSE_ID: output.get("id"), gen_ai_attributes.GEN_AI_AGENT_ID: output.get("agent_id"), gen_ai_attributes.GEN_AI_RESPONSE_MODEL: output.get("model"), - gen_ai_attributes.GEN_AI_OUTPUT_MESSAGES: [ - serialize_output_message(choice_wrapper) - ], + gen_ai_attributes.GEN_AI_OUTPUT_MESSAGES: json.dumps( + [format_output_message(choice_wrapper)] + ), } set_available_attributes(child_span, message_attributes) child_span.end(end_time=end_ns) diff --git a/src/mistralai/extra/tests/test_serialization.py b/src/mistralai/extra/tests/test_formatting.py similarity index 83% rename from src/mistralai/extra/tests/test_serialization.py rename to src/mistralai/extra/tests/test_formatting.py index 3c88aa71..3dd1dee4 100644 --- a/src/mistralai/extra/tests/test_serialization.py +++ b/src/mistralai/extra/tests/test_formatting.py @@ -1,26 +1,20 @@ -"""Unit tests for the OTEL serialization helpers. +"""Unit tests for the OTEL formatting helpers. Each test covers a single function with both happy-path and edge-case inputs. -The functions are pure (dict → str/list), so no OTEL setup is needed. +The functions are pure (dict -> dict/list), so no OTEL setup is needed. """ -import json import unittest -from mistralai.extra.observability.serialization import ( +from mistralai.extra.observability.formatting import ( _content_to_parts, _tool_calls_to_parts, - serialize_input_message, - serialize_output_message, - serialize_tool_definition, + format_input_message, + format_output_message, + format_tool_definition, ) -def _parse(json_str: str): - """Shorthand: parse a JSON string returned by a serialize_* function.""" - return json.loads(json_str) - - class TestContentToParts(unittest.TestCase): def test_none(self): self.assertEqual(_content_to_parts(None), []) @@ -98,7 +92,7 @@ def test_thinking_chunk_fallback_plain_string(self): ) def test_thinking_chunk_missing_thinking_field(self): - """Empty string default → str("") fallback.""" + """Empty string default -> str("") fallback.""" chunk = {"type": "thinking"} self.assertEqual( _content_to_parts([chunk]), @@ -195,7 +189,7 @@ def test_missing_arguments(self): ) def test_missing_function(self): - """No function key → empty name.""" + """No function key -> empty name.""" tc = {"id": "1"} self.assertListEqual( _tool_calls_to_parts([tc]), @@ -210,11 +204,11 @@ def test_function_is_none(self): ) -class TestSerializeInputMessage(unittest.TestCase): +class TestFormatInputMessage(unittest.TestCase): # -- Happy paths (role-based messages) ------------------------------------ def test_user_message(self): - result = _parse(serialize_input_message({"role": "user", "content": "hi"})) + result = format_input_message({"role": "user", "content": "hi"}) self.assertDictEqual( result, { @@ -224,9 +218,7 @@ def test_user_message(self): ) def test_system_message(self): - result = _parse( - serialize_input_message({"role": "system", "content": "be helpful"}) - ) + result = format_input_message({"role": "system", "content": "be helpful"}) self.assertDictEqual( result, { @@ -241,7 +233,7 @@ def test_assistant_message_with_tool_calls(self): "content": "", "tool_calls": [{"id": "tc1", "function": {"name": "f", "arguments": "{}"}}], } - result = _parse(serialize_input_message(msg)) + result = format_input_message(msg) self.assertEqual(result["role"], "assistant") # text part from content + tool_call part self.assertListEqual( @@ -251,7 +243,7 @@ def test_assistant_message_with_tool_calls(self): def test_tool_message(self): msg = {"role": "tool", "content": "22C sunny", "tool_call_id": "tc1"} - result = _parse(serialize_input_message(msg)) + result = format_input_message(msg) self.assertDictEqual( result, { @@ -264,7 +256,7 @@ def test_tool_message(self): def test_tool_message_without_tool_call_id(self): msg = {"role": "tool", "content": "result"} - result = _parse(serialize_input_message(msg)) + result = format_input_message(msg) self.assertNotIn("id", result["parts"][0]) # -- Conversation entry: function.result ---------------------------------- @@ -275,7 +267,7 @@ def test_function_result_entry(self): "result": '{"status": "ok"}', "tool_call_id": "tc1", } - result = _parse(serialize_input_message(msg)) + result = format_input_message(msg) self.assertDictEqual( result, { @@ -292,13 +284,13 @@ def test_function_result_entry(self): def test_function_result_entry_without_tool_call_id(self): msg = {"type": "function.result", "result": "data"} - result = _parse(serialize_input_message(msg)) + result = format_input_message(msg) self.assertNotIn("id", result["parts"][0]) # -- Edge cases ----------------------------------------------------------- def test_missing_role_defaults_to_unknown(self): - result = _parse(serialize_input_message({"content": "orphan"})) + result = format_input_message({"content": "orphan"}) self.assertDictEqual( result, { @@ -308,17 +300,17 @@ def test_missing_role_defaults_to_unknown(self): ) def test_no_content_no_tool_calls(self): - result = _parse(serialize_input_message({"role": "user"})) + result = format_input_message({"role": "user"}) self.assertDictEqual(result, {"role": "user", "parts": []}) -class TestSerializeOutputMessage(unittest.TestCase): +class TestFormatOutputMessage(unittest.TestCase): def test_simple_assistant_response(self): choice = { "message": {"role": "assistant", "content": "hello"}, "finish_reason": "stop", } - result = _parse(serialize_output_message(choice)) + result = format_output_message(choice) self.assertDictEqual( result, { @@ -339,7 +331,7 @@ def test_tool_calls_response(self): }, "finish_reason": "tool_calls", } - result = _parse(serialize_output_message(choice)) + result = format_output_message(choice) self.assertEqual(result["finish_reason"], "tool_calls") self.assertListEqual( [p["type"] for p in result["parts"]], @@ -347,7 +339,7 @@ def test_tool_calls_response(self): ) def test_missing_message(self): - result = _parse(serialize_output_message({})) + result = format_output_message({}) self.assertDictEqual( result, { @@ -358,7 +350,7 @@ def test_missing_message(self): ) def test_message_is_none(self): - result = _parse(serialize_output_message({"message": None})) + result = format_output_message({"message": None}) self.assertDictEqual( result, { @@ -370,7 +362,7 @@ def test_message_is_none(self): def test_defaults_role_to_assistant(self): choice = {"message": {"content": "hi"}, "finish_reason": "stop"} - result = _parse(serialize_output_message(choice)) + result = format_output_message(choice) self.assertDictEqual( result, { @@ -381,7 +373,7 @@ def test_defaults_role_to_assistant(self): ) -class TestSerializeToolDefinition(unittest.TestCase): +class TestFormatToolDefinition(unittest.TestCase): def test_full_definition(self): tool = { "type": "function", @@ -391,11 +383,11 @@ def test_full_definition(self): "parameters": {"type": "object", "properties": {}}, }, } - serialized = serialize_tool_definition(tool) - self.assertIsNotNone(serialized) - assert serialized is not None + result = format_tool_definition(tool) + self.assertIsNotNone(result) + assert result is not None self.assertDictEqual( - _parse(serialized), + result, { "type": "function", "name": "get_weather", @@ -407,11 +399,11 @@ def test_full_definition(self): def test_minimal_definition(self): """Only name, no description or parameters.""" tool = {"function": {"name": "f"}} - serialized = serialize_tool_definition(tool) - self.assertIsNotNone(serialized) - assert serialized is not None + result = format_tool_definition(tool) + self.assertIsNotNone(result) + assert result is not None self.assertDictEqual( - _parse(serialized), + result, { "type": "function", "name": "f", @@ -419,23 +411,23 @@ def test_minimal_definition(self): ) def test_missing_function_returns_none(self): - self.assertIsNone(serialize_tool_definition({"type": "function"})) + self.assertIsNone(format_tool_definition({"type": "function"})) def test_empty_function_returns_none(self): - self.assertIsNone(serialize_tool_definition({"function": {}})) + self.assertIsNone(format_tool_definition({"function": {}})) def test_missing_name_returns_none(self): self.assertIsNone( - serialize_tool_definition({"function": {"description": "no name"}}) + format_tool_definition({"function": {"description": "no name"}}) ) def test_custom_type_preserved(self): tool = {"type": "custom_tool", "function": {"name": "f"}} - serialized = serialize_tool_definition(tool) - self.assertIsNotNone(serialized) - assert serialized is not None + result = format_tool_definition(tool) + self.assertIsNotNone(result) + assert result is not None self.assertDictEqual( - _parse(serialized), + result, { "type": "custom_tool", "name": "f", diff --git a/src/mistralai/extra/tests/test_otel_tracing.py b/src/mistralai/extra/tests/test_otel_tracing.py index ff30ba0c..f4673c3f 100644 --- a/src/mistralai/extra/tests/test_otel_tracing.py +++ b/src/mistralai/extra/tests/test_otel_tracing.py @@ -142,8 +142,8 @@ def _make_streaming_httpx_response(sse_body: bytes) -> httpx.Response: def _parse_json_list(span_attr): - """Parse a span attribute containing a list of JSON-encoded strings.""" - return [json.loads(m) for m in span_attr] + """Parse a span attribute containing a JSON-encoded array string.""" + return json.loads(span_attr) # -- Tests ---------------------------------------------------------------------