# Key under which the request's span is stashed on httpx.Request.extensions.
# httpx attaches the originating request to its response, so the after_* hooks
# can recover the span without any shared mutable state on the hook instance
# (which would be corrupted by concurrent in-flight requests).
_SPAN_EXT_KEY = "_tracing_span"


class TracingHook(BeforeRequestHook, AfterSuccessHook, AfterErrorHook):
    """Opens an OTel span per outgoing request and finalizes it on completion.

    The span travels with the request itself (``request.extensions``) rather
    than living on the hook, so one hook instance can safely serve many
    interleaved requests.
    """

    def __init__(self) -> None:
        self.tracing_enabled, self.tracer = get_or_create_otel_tracer()

    def before_request(
        self, hook_ctx: BeforeRequestContext, request: httpx.Request
    ):
        # Refresh tracer/provider per request so tracing can be enabled if the
        # application configures OpenTelemetry after the client is instantiated.
        self.tracing_enabled, self.tracer = get_or_create_otel_tracer()
        traced_request, span = get_traced_request_and_span(
            tracing_enabled=self.tracing_enabled,
            tracer=self.tracer,
            span=None,
            operation_id=hook_ctx.operation_id,
            request=request,
        )
        # Stash the span on the request; httpx copies the request onto the
        # response it produces, which is how the after_* hooks find it again.
        traced_request.extensions[_SPAN_EXT_KEY] = span
        return traced_request

    @staticmethod
    def _get_span(response: Optional[httpx.Response]) -> Optional[Span]:
        """Return the span stashed on the request that produced *response*.

        Returns None when there is no response, no stored span, or when the
        response has no request attached (httpx raises RuntimeError for the
        latter).
        """
        if response is None:
            return None
        try:
            return response.request.extensions.get(_SPAN_EXT_KEY)
        except RuntimeError:
            # Accessing .request on a response that was built without one
            # raises RuntimeError; treat that as "no span".
            return None

    def after_success(
        self, hook_ctx: AfterSuccessContext, response: httpx.Response
    ) -> Union[httpx.Response, Exception]:
        """Finalize the span recovered from the response's originating request."""
        return get_traced_response(
            tracing_enabled=self.tracing_enabled,
            tracer=self.tracer,
            span=self._get_span(response),
            operation_id=hook_ctx.operation_id,
            response=response,
        )

    def after_error(
        self,
        hook_ctx: AfterErrorContext,
        response: Optional[httpx.Response],
        error: Optional[Exception],
    ) -> Union[Tuple[Optional[httpx.Response], Optional[Exception]], Exception]:
        """Record the failure on the request's span, when a response exists."""
        if response:
            response, error = get_response_and_error(
                tracing_enabled=self.tracing_enabled,
                tracer=self.tracer,
                span=self._get_span(response),
                operation_id=hook_ctx.operation_id,
                response=response,
                error=error,
            )
        return response, error
self._get_span(response) response, error = get_response_and_error( tracing_enabled=self.tracing_enabled, tracer=self.tracer, - span=self.request_span, + span=span, operation_id=hook_ctx.operation_id, response=response, error=error, ) - self.request_span = None return response, error diff --git a/src/mistralai/extra/tests/test_otel_tracing.py b/src/mistralai/extra/tests/test_otel_tracing.py index ff30ba0c..24785801 100644 --- a/src/mistralai/extra/tests/test_otel_tracing.py +++ b/src/mistralai/extra/tests/test_otel_tracing.py @@ -70,6 +70,7 @@ UsageInfo, UserMessage, ) +from mistralai.client.sdk import Mistral from mistralai.extra.observability.otel import TracedResponse from mistralai.extra.run.tools import ( RunFunction, @@ -196,6 +197,9 @@ def _run_hook_lifecycle( self.assertNotIsInstance(hooked_request, Exception) assert isinstance(hooked_request, httpx.Request) + # Link response to request, as httpx.Client.send() does in real usage. + response.request = hooked_request + result = hook.after_success(AfterSuccessContext(hook_ctx), response) self.assertNotIsInstance(result, Exception) @@ -228,6 +232,9 @@ def _run_hook_error_lifecycle( self.assertNotIsInstance(hooked_request, Exception) assert isinstance(hooked_request, httpx.Request) + # Link response to request, as httpx.Client.send() does in real usage. 
+ response.request = hooked_request + result = hook.after_error(AfterErrorContext(hook_ctx), response, error) self.assertNotIsInstance(result, Exception) @@ -1526,7 +1533,6 @@ def failing_tool(x: int) -> str: "Expected an exception event on the span", ) - # -- Baggage propagation: gen_ai.conversation.id --------------------------- def test_conversation_id_from_baggage(self): @@ -1597,6 +1603,113 @@ def test_no_conversation_id_without_baggage(self): span = self._get_single_span() self.assertNotIn("gen_ai.conversation.id", span.attributes) + # -- Concurrency: interleaved requests on shared hook ---------------------- + + def test_concurrent_async_requests_get_correct_spans(self): + """Two concurrent async chat completions through a real Mistral client. + + Uses asyncio.gather to fire two requests simultaneously through the + SDK. A mock transport with an asyncio.Event gate guarantees both + before_request hooks run before either after_success, reproducing + the interleaving that corrupts self.request_span. + + Expected: each span carries its own request model AND response id. + """ + + # Gate ensures both requests have entered the transport (i.e. both + # before_request hooks have already run) before either returns. + gate = asyncio.Event() + arrived = 0 + + async def _mock_handler(request: httpx.Request) -> httpx.Response: + nonlocal arrived + arrived += 1 + if arrived < 2: + # First request: wait for the second to arrive. + await gate.wait() + else: + # Second request: both hooks have fired, unblock the first. 
+ gate.set() + + body = json.loads(request.content) + model = body["model"] + resp = _dump( + ChatCompletionResponse( + id=f"cmpl-{model}", + object="chat.completion", + model=model, + created=1700000000, + choices=[ + ChatCompletionChoice( + index=0, + message=AssistantMessage(content=f"Reply from {model}"), + finish_reason="stop", + ) + ], + usage=UsageInfo( + prompt_tokens=10, + completion_tokens=5, + total_tokens=15, + ), + ) + ) + return httpx.Response(200, json=resp) + + transport = httpx.MockTransport(_mock_handler) + async_client = httpx.AsyncClient(transport=transport) + + client = Mistral( + api_key="test-key", + async_client=async_client, + ) + + async def _run(): + return await asyncio.gather( + client.chat.complete_async( + model="mistral-large-latest", + messages=[{"role": "user", "content": "A"}], + ), + client.chat.complete_async( + model="mistral-small-latest", + messages=[{"role": "user", "content": "B"}], + ), + ) + + results = asyncio.get_event_loop().run_until_complete(_run()) + + # Both calls must succeed + self.assertEqual(len(results), 2) + + # --- Verify spans --- + spans = self._get_finished_spans() + spans_by_resp = {s.attributes.get("gen_ai.response.id"): s for s in spans} + + # Both spans must have a response.id + self.assertIn( + "cmpl-mistral-large-latest", + spans_by_resp, + "Span for large model must exist", + ) + self.assertIn( + "cmpl-mistral-small-latest", + spans_by_resp, + "Span for small model must exist", + ) + + # Each span's request model must match its response + self.assertEqual( + spans_by_resp["cmpl-mistral-large-latest"].attributes.get( + "gen_ai.request.model" + ), + "mistral-large-latest", + ) + self.assertEqual( + spans_by_resp["cmpl-mistral-small-latest"].attributes.get( + "gen_ai.request.model" + ), + "mistral-small-latest", + ) + if __name__ == "__main__": unittest.main()