diff --git a/instrumentation/opentelemetry-instrumentation-google-genai/.changelog/167.changed b/instrumentation/opentelemetry-instrumentation-google-genai/.changelog/167.changed new file mode 100644 index 00000000..f19a3d00 --- /dev/null +++ b/instrumentation/opentelemetry-instrumentation-google-genai/.changelog/167.changed @@ -0,0 +1,2 @@ +Update the `generate_content` streaming method variants to return streaming wrapper classes to enable users to iterate over the stream of responses from the model. + diff --git a/instrumentation/opentelemetry-instrumentation-google-genai/src/opentelemetry/instrumentation/google_genai/generate_content.py b/instrumentation/opentelemetry-instrumentation-google-genai/src/opentelemetry/instrumentation/google_genai/generate_content.py index d7ceb328..e99e6707 100644 --- a/instrumentation/opentelemetry-instrumentation-google-genai/src/opentelemetry/instrumentation/google_genai/generate_content.py +++ b/instrumentation/opentelemetry-instrumentation-google-genai/src/opentelemetry/instrumentation/google_genai/generate_content.py @@ -3,7 +3,7 @@ import json import os -from collections.abc import Callable +from collections.abc import AsyncIterable, Callable, Iterable from typing import Any, Optional, Union from google.genai.models import AsyncModels, Models @@ -27,6 +27,10 @@ from opentelemetry.util.genai.invocation import ( InferenceInvocation, ) +from opentelemetry.util.genai.stream import ( + AsyncStreamWrapper, + SyncStreamWrapper, +) from opentelemetry.util.genai.types import ( FunctionToolDefinition, GenericToolDefinition, @@ -484,6 +488,82 @@ def _execute( return instrumented_generate_content +class GenerateContentStreamWrapper(SyncStreamWrapper[GenerateContentResponse]): + def __init__( + self, + stream: Iterable[GenerateContentResponse], + invocation: InferenceInvocation, + telemetry_handler: TelemetryHandler, + ) -> None: + super().__init__(stream) + self._self_invocation = invocation + self._self_telemetry_handler = telemetry_handler + self._self_finish_reasons: list[str] = [] + self._self_candidates = [] + + def _process_chunk(self, chunk: GenerateContentResponse) -> None: + _apply_response_attributes( + chunk, + self._self_finish_reasons, + self._self_invocation, + ) + if chunk.candidates: + self._self_candidates.extend(chunk.candidates) + + def _on_stream_end(self) -> None: + if ( + self._self_telemetry_handler.should_capture_content() + and self._self_candidates + ): + self._self_invocation.output_messages = to_output_messages( + candidates=self._self_candidates + ) + self._self_invocation.stop() + + def _on_stream_error(self, error: BaseException) -> None: + self._self_invocation.fail(error) + + +class AsyncGenerateContentStreamWrapper( + AsyncStreamWrapper[GenerateContentResponse] +): + def __init__( + self, + stream: AsyncIterable[GenerateContentResponse], + invocation: InferenceInvocation, + telemetry_handler: TelemetryHandler, + ) -> None: + super().__init__(stream) + # _self_ is a naming convention used by the wrapt library to differentiate + # between attributes on the wrapped function and the original function. + self._self_invocation = invocation + self._self_telemetry_handler = telemetry_handler + self._self_finish_reasons: list[str] = [] + self._self_candidates = [] + + def _process_chunk(self, chunk: GenerateContentResponse) -> None: + _apply_response_attributes( + chunk, + self._self_finish_reasons, + self._self_invocation, + ) + if chunk.candidates: + self._self_candidates.extend(chunk.candidates) + + def _on_stream_end(self) -> None: + if ( + self._self_telemetry_handler.should_capture_content() + and self._self_candidates + ): + self._self_invocation.output_messages = to_output_messages( + candidates=self._self_candidates + ) + self._self_invocation.stop() + + def _on_stream_error(self, error: BaseException) -> None: + self._self_invocation.fail(error) + + def _create_instrumented_generate_content_stream( telemetry_handler: TelemetryHandler, generate_content_config_key_allowlist: AllowList, @@ -508,57 +588,44 @@ def _execute( telemetry_handler, config, ) - finish_reasons = [] - with telemetry_handler.inference( + invocation = telemetry_handler.inference( provider=_determine_genai_system(instance), request_model=model, operation_name="generate_content", - ) as invocation: - _apply_request_attributes( - wrapped_config, - generate_content_config_key_allowlist, - invocation, - ) - invocation.attributes.update( - _get_extra_generate_content_attributes() - ) - invocation.tool_definitions = _maybe_get_tool_definitions( - wrapped_config - ) + ) + _apply_request_attributes( + wrapped_config, + generate_content_config_key_allowlist, + invocation, + ) + invocation.attributes.update( + _get_extra_generate_content_attributes() + ) + invocation.tool_definitions = _maybe_get_tool_definitions( + wrapped_config + ) - if telemetry_handler.should_capture_content(): - invocation.input_messages = to_input_messages( - contents=transformers.t_contents(contents) + if telemetry_handler.should_capture_content(): + invocation.input_messages = to_input_messages( + contents=transformers.t_contents(contents) + ) + if wrapped_config.system_instruction: + invocation.system_instruction = to_system_instructions( + content=transformers.t_contents( + wrapped_config.system_instruction + )[0] ) - if wrapped_config.system_instruction: - invocation.system_instruction = to_system_instructions( - content=transformers.t_contents( - wrapped_config.system_instruction - )[0] - ) - candidates = [] - try: - for resp in wrapped( - model=model, - contents=contents, - config=wrapped_config if has_wrapped_tools else config, - *_args, - **_kwargs, - ): - _apply_response_attributes( - resp, finish_reasons, invocation - ) - if resp.candidates: - candidates.extend(resp.candidates) - yield resp - finally: - if ( - telemetry_handler.should_capture_content() - and candidates - ): - invocation.output_messages = to_output_messages( - candidates=candidates - ) + return GenerateContentStreamWrapper( + wrapped( + model=model, + contents=contents, + config=wrapped_config if has_wrapped_tools else config, + *_args, + **_kwargs, + ), + invocation, + telemetry_handler, + ) return _execute(*args, **kwargs) @@ -670,7 +737,6 @@ async def _execute( telemetry_handler, config, ) - finish_reasons = [] invocation = telemetry_handler.inference( provider=_determine_genai_system(instance), request_model=model, @@ -698,43 +764,17 @@ async def _execute( wrapped_config.system_instruction )[0] ) - - async def _response_async_generator_wrapper(): - candidates = [] - try: - async for resp in await wrapped( - model=model, - contents=contents, - config=wrapped_config if has_wrapped_tools else config, - *_args, - **_kwargs, - ): - _apply_response_attributes( - resp, finish_reasons, invocation - ) - if resp.candidates: - candidates.extend(resp.candidates) - yield resp - if ( - telemetry_handler.should_capture_content() - and candidates - ): - invocation.output_messages = to_output_messages( - candidates=candidates - ) - invocation.stop() - except Exception as exc: - if ( - telemetry_handler.should_capture_content() - and candidates - ): - invocation.output_messages = to_output_messages( - candidates=candidates - ) - invocation.fail(exc) - raise - - return _response_async_generator_wrapper() + return AsyncGenerateContentStreamWrapper( + await wrapped( + model=model, + contents=contents, + config=wrapped_config if has_wrapped_tools else config, + *_args, + **_kwargs, + ), + invocation, + telemetry_handler, + ) return await _execute(*args, **kwargs)