Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Update the streaming variants of `generate_content` to return stream wrapper classes, so that users can iterate over the stream of responses from the model.

Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import json
import os
from collections.abc import Callable
from collections.abc import AsyncIterable, Callable, Iterable
from typing import Any, Optional, Union

from google.genai.models import AsyncModels, Models
Expand All @@ -27,6 +27,10 @@
from opentelemetry.util.genai.invocation import (
InferenceInvocation,
)
from opentelemetry.util.genai.stream import (
AsyncStreamWrapper,
SyncStreamWrapper,
)
from opentelemetry.util.genai.types import (
FunctionToolDefinition,
GenericToolDefinition,
Expand Down Expand Up @@ -484,6 +488,82 @@ def _execute(
return instrumented_generate_content


class GenerateContentStreamWrapper(SyncStreamWrapper[GenerateContentResponse]):
    """Synchronous ``generate_content`` stream wrapper that records telemetry.

    Every chunk is passed to ``_apply_response_attributes`` and its candidates
    are accumulated. When the stream ends the invocation is stopped; when it
    errors the invocation is failed. In both cases the accumulated candidates
    are recorded as output messages first, if content capture is enabled.

    NOTE(review): the ``_process_chunk`` / ``_on_stream_end`` /
    ``_on_stream_error`` hooks are presumably driven by ``SyncStreamWrapper``;
    confirm the exact call contract against the base class.
    """

    def __init__(
        self,
        stream: Iterable[GenerateContentResponse],
        invocation: InferenceInvocation,
        telemetry_handler: TelemetryHandler,
    ) -> None:
        """Wrap ``stream`` and bind it to ``invocation`` for telemetry.

        Args:
            stream: The iterable of responses returned by the wrapped call.
            invocation: The inference invocation updated as chunks arrive and
                finalized when the stream ends or fails.
            telemetry_handler: Consulted to decide whether message content
                should be captured.
        """
        super().__init__(stream)
        # _self_ is a naming convention used by the wrapt library to differentiate
        # between attributes on the wrapped function and the original function.
        self._self_invocation = invocation
        self._self_telemetry_handler = telemetry_handler
        self._self_finish_reasons: list[str] = []
        self._self_candidates: list[Any] = []

    def _record_output_messages(self) -> None:
        """Store accumulated candidates on the invocation as output messages.

        Does nothing when content capture is disabled or no candidates were
        received.
        """
        if (
            self._self_telemetry_handler.should_capture_content()
            and self._self_candidates
        ):
            self._self_invocation.output_messages = to_output_messages(
                candidates=self._self_candidates
            )

    def _process_chunk(self, chunk: GenerateContentResponse) -> None:
        """Apply response attributes from ``chunk`` and keep its candidates."""
        _apply_response_attributes(
            chunk,
            self._self_finish_reasons,
            self._self_invocation,
        )
        if chunk.candidates:
            self._self_candidates.extend(chunk.candidates)

    def _on_stream_end(self) -> None:
        """Record the collected output and stop the invocation."""
        self._record_output_messages()
        self._self_invocation.stop()

    def _on_stream_error(self, error: BaseException) -> None:
        """Record any partial output, then fail the invocation with ``error``.

        Candidates received before the failure are still recorded, so a stream
        that breaks midway does not lose the output already produced.
        """
        self._record_output_messages()
        self._self_invocation.fail(error)


class AsyncGenerateContentStreamWrapper(
    AsyncStreamWrapper[GenerateContentResponse]
):
    """Asynchronous ``generate_content`` stream wrapper that records telemetry.

    Every chunk is passed to ``_apply_response_attributes`` and its candidates
    are accumulated. When the stream ends the invocation is stopped; when it
    errors the invocation is failed. In both cases the accumulated candidates
    are recorded as output messages first, if content capture is enabled.

    NOTE(review): the ``_process_chunk`` / ``_on_stream_end`` /
    ``_on_stream_error`` hooks are presumably driven by ``AsyncStreamWrapper``;
    confirm the exact call contract against the base class.
    """

    def __init__(
        self,
        stream: AsyncIterable[GenerateContentResponse],
        invocation: InferenceInvocation,
        telemetry_handler: TelemetryHandler,
    ) -> None:
        """Wrap ``stream`` and bind it to ``invocation`` for telemetry.

        Args:
            stream: The async iterable of responses returned by the wrapped
                call.
            invocation: The inference invocation updated as chunks arrive and
                finalized when the stream ends or fails.
            telemetry_handler: Consulted to decide whether message content
                should be captured.
        """
        super().__init__(stream)
        # _self_ is a naming convention used by the wrapt library to differentiate
        # between attributes on the wrapped function and the original function.
        self._self_invocation = invocation
        self._self_telemetry_handler = telemetry_handler
        self._self_finish_reasons: list[str] = []
        self._self_candidates: list[Any] = []

    def _record_output_messages(self) -> None:
        """Store accumulated candidates on the invocation as output messages.

        Does nothing when content capture is disabled or no candidates were
        received.
        """
        if (
            self._self_telemetry_handler.should_capture_content()
            and self._self_candidates
        ):
            self._self_invocation.output_messages = to_output_messages(
                candidates=self._self_candidates
            )

    def _process_chunk(self, chunk: GenerateContentResponse) -> None:
        """Apply response attributes from ``chunk`` and keep its candidates."""
        _apply_response_attributes(
            chunk,
            self._self_finish_reasons,
            self._self_invocation,
        )
        if chunk.candidates:
            self._self_candidates.extend(chunk.candidates)

    def _on_stream_end(self) -> None:
        """Record the collected output and stop the invocation."""
        self._record_output_messages()
        self._self_invocation.stop()

    def _on_stream_error(self, error: BaseException) -> None:
        """Record any partial output, then fail the invocation with ``error``.

        Candidates received before the failure are still recorded, so a stream
        that breaks midway does not lose the output already produced.
        """
        self._record_output_messages()
        self._self_invocation.fail(error)


def _create_instrumented_generate_content_stream(
telemetry_handler: TelemetryHandler,
generate_content_config_key_allowlist: AllowList,
Expand All @@ -508,57 +588,44 @@ def _execute(
telemetry_handler,
config,
)
finish_reasons = []
with telemetry_handler.inference(
invocation = telemetry_handler.inference(
provider=_determine_genai_system(instance),
request_model=model,
operation_name="generate_content",
) as invocation:
_apply_request_attributes(
wrapped_config,
generate_content_config_key_allowlist,
invocation,
)
invocation.attributes.update(
_get_extra_generate_content_attributes()
)
invocation.tool_definitions = _maybe_get_tool_definitions(
wrapped_config
)
)
_apply_request_attributes(
wrapped_config,
generate_content_config_key_allowlist,
invocation,
)
invocation.attributes.update(
_get_extra_generate_content_attributes()
)
invocation.tool_definitions = _maybe_get_tool_definitions(
wrapped_config
)

if telemetry_handler.should_capture_content():
invocation.input_messages = to_input_messages(
contents=transformers.t_contents(contents)
if telemetry_handler.should_capture_content():
invocation.input_messages = to_input_messages(
contents=transformers.t_contents(contents)
)
if wrapped_config.system_instruction:
Comment thread
DylanRussell marked this conversation as resolved.
invocation.system_instruction = to_system_instructions(
content=transformers.t_contents(
wrapped_config.system_instruction
)[0]
)
if wrapped_config.system_instruction:
invocation.system_instruction = to_system_instructions(
content=transformers.t_contents(
wrapped_config.system_instruction
)[0]
)
candidates = []
try:
for resp in wrapped(
model=model,
contents=contents,
config=wrapped_config if has_wrapped_tools else config,
*_args,
**_kwargs,
):
_apply_response_attributes(
resp, finish_reasons, invocation
)
if resp.candidates:
candidates.extend(resp.candidates)
yield resp
finally:
if (
telemetry_handler.should_capture_content()
and candidates
):
invocation.output_messages = to_output_messages(
candidates=candidates
)
return GenerateContentStreamWrapper(
wrapped(
model=model,
contents=contents,
config=wrapped_config if has_wrapped_tools else config,
*_args,
**_kwargs,
),
invocation,
telemetry_handler,
)

return _execute(*args, **kwargs)

Expand Down Expand Up @@ -670,7 +737,6 @@ async def _execute(
telemetry_handler,
config,
)
finish_reasons = []
invocation = telemetry_handler.inference(
provider=_determine_genai_system(instance),
request_model=model,
Expand Down Expand Up @@ -698,43 +764,17 @@ async def _execute(
wrapped_config.system_instruction
)[0]
)

async def _response_async_generator_wrapper():
candidates = []
try:
async for resp in await wrapped(
model=model,
contents=contents,
config=wrapped_config if has_wrapped_tools else config,
*_args,
**_kwargs,
):
_apply_response_attributes(
resp, finish_reasons, invocation
)
if resp.candidates:
candidates.extend(resp.candidates)
yield resp
if (
telemetry_handler.should_capture_content()
and candidates
):
invocation.output_messages = to_output_messages(
candidates=candidates
)
invocation.stop()
except Exception as exc:
if (
telemetry_handler.should_capture_content()
and candidates
):
invocation.output_messages = to_output_messages(
candidates=candidates
)
invocation.fail(exc)
raise

return _response_async_generator_wrapper()
return AsyncGenerateContentStreamWrapper(
await wrapped(
model=model,
contents=contents,
config=wrapped_config if has_wrapped_tools else config,
*_args,
**_kwargs,
),
invocation,
telemetry_handler,
)

return await _execute(*args, **kwargs)

Expand Down