Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pydantic_ai_slim/pydantic_ai/builtin_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ class ImageGenerationTool(AbstractBuiltinTool):
Supported by:
* OpenAI Responses. Only supported for 'png' and 'webp' output formats.
* Google (Vertex AI only). Only supported for 'jpeg' output format.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please update the docs as well

"""

output_format: Literal['png', 'webp', 'jpeg'] | None = None
Expand All @@ -309,6 +310,7 @@ class ImageGenerationTool(AbstractBuiltinTool):
Supported by:
* OpenAI Responses. Default: 'png'.
* Google (Vertex AI only).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add in the default, which I'm assuming is also PNG?

"""

partial_images: int = 0
Expand Down
53 changes: 42 additions & 11 deletions pydantic_ai_slim/pydantic_ai/models/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@
_GOOGLE_IMAGE_SIZE = Literal['1K', '2K', '4K']
_GOOGLE_IMAGE_SIZES: tuple[_GOOGLE_IMAGE_SIZE, ...] = _utils.get_args(_GOOGLE_IMAGE_SIZE)

_GOOGLE_IMAGE_OUTPUT_FORMAT = Literal['png', 'jpeg', 'webp']
_GOOGLE_IMAGE_OUTPUT_FORMATS: tuple[_GOOGLE_IMAGE_OUTPUT_FORMAT, ...] = _utils.get_args(_GOOGLE_IMAGE_OUTPUT_FORMAT)


class GoogleModelSettings(ModelSettings, total=False):
"""Settings used for a Gemini model request."""
Expand Down Expand Up @@ -358,6 +361,44 @@ async def request_stream(
response = await self._generate_content(messages, True, model_settings, model_request_parameters)
yield await self._process_streamed_response(response, model_request_parameters) # type: ignore

def _build_image_config(self, tool: ImageGenerationTool) -> ImageConfigDict:
"""Build ImageConfigDict from ImageGenerationTool with validation."""
image_config = ImageConfigDict()

if tool.aspect_ratio is not None:
image_config['aspect_ratio'] = tool.aspect_ratio

if tool.size is not None:
if tool.size not in _GOOGLE_IMAGE_SIZES:
raise UserError(
f'Google image generation only supports `size` values: {_GOOGLE_IMAGE_SIZES}. '
f'Got: {tool.size!r}. Omit `size` to use the default (1K).'
)
image_config['image_size'] = tool.size

if tool.output_format is not None:
if tool.output_format not in _GOOGLE_IMAGE_OUTPUT_FORMATS:
raise UserError(
f'Google image generation only supports `output_format` values: {_GOOGLE_IMAGE_OUTPUT_FORMATS}. '
f'Got: {tool.output_format!r}.'
)
image_config['output_mime_type'] = f'image/{tool.output_format}'

if tool.output_compression != 100:
if not (0 <= tool.output_compression <= 100):
raise UserError(
f'Google image generation `output_compression` must be between 0 and 100. '
f'Got: {tool.output_compression}.'
)
if tool.output_format not in (None, 'jpeg'):
raise UserError(
f'Google image generation `output_compression` is only supported for JPEG format. '
f'Got format: {tool.output_format!r}. Either set `output_format="jpeg"` or remove `output_compression`.'
)
image_config['output_compression_quality'] = tool.output_compression

return image_config

def _get_tools(
self, model_request_parameters: ModelRequestParameters
) -> tuple[list[ToolDict] | None, ImageConfigDict | None]:
Expand Down Expand Up @@ -387,17 +428,7 @@ def _get_tools(
raise UserError(
"`ImageGenerationTool` is not supported by this model. Use a model with 'image' in the name instead."
)

image_config = ImageConfigDict()
if tool.aspect_ratio is not None:
image_config['aspect_ratio'] = tool.aspect_ratio
if tool.size is not None:
if tool.size not in _GOOGLE_IMAGE_SIZES:
raise UserError(
f'Google image generation only supports `size` values: {_GOOGLE_IMAGE_SIZES}. '
f'Got: {tool.size!r}. Omit `size` to use the default (1K).'
)
image_config['image_size'] = tool.size
image_config = self._build_image_config(tool)
else: # pragma: no cover
raise UserError(
f'`{tool.__class__.__name__}` is not supported by `GoogleModel`. If it should be, please file an issue.'
Expand Down

Large diffs are not rendered by default.

104 changes: 104 additions & 0 deletions tests/models/test_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -3694,6 +3694,110 @@ async def test_google_image_generation_auto_size_raises_error(google_provider: G
model._get_tools(params) # pyright: ignore[reportPrivateUsage]


async def test_google_image_generation_tool_output_format(google_provider: GoogleProvider) -> None:
"""Test that ImageGenerationTool.output_format is mapped to ImageConfigDict.output_mime_type."""
model = GoogleModel('gemini-3-pro-image-preview', provider=google_provider)
params = ModelRequestParameters(builtin_tools=[ImageGenerationTool(output_format='png')])

tools, image_config = model._get_tools(params) # pyright: ignore[reportPrivateUsage]
assert tools is None
assert image_config == {'output_mime_type': 'image/png'}


async def test_google_image_generation_tool_unsupported_format_raises_error(google_provider: GoogleProvider) -> None:
"""Test that unsupported output_format values raise an error."""
model = GoogleModel('gemini-3-pro-image-preview', provider=google_provider)
# 'gif' is not supported by Google
params = ModelRequestParameters(builtin_tools=[ImageGenerationTool(output_format='gif')]) # type: ignore

with pytest.raises(UserError, match='Google image generation only supports `output_format` values'):
model._get_tools(params) # pyright: ignore[reportPrivateUsage]


async def test_google_image_generation_tool_output_compression(google_provider: GoogleProvider) -> None:
"""Test that ImageGenerationTool.output_compression is mapped to ImageConfigDict.output_compression_quality."""
model = GoogleModel('gemini-3-pro-image-preview', provider=google_provider)
params = ModelRequestParameters(builtin_tools=[ImageGenerationTool(output_compression=85)])

tools, image_config = model._get_tools(params) # pyright: ignore[reportPrivateUsage]
assert tools is None
assert image_config == {'output_compression_quality': 85}


async def test_google_image_generation_tool_compression_validation(google_provider: GoogleProvider) -> None:
"""Test compression validation: range and JPEG-only."""
model = GoogleModel('gemini-3-pro-image-preview', provider=google_provider)

# Invalid range: > 100
with pytest.raises(UserError, match='`output_compression` must be between 0 and 100'):
model._get_tools(ModelRequestParameters(builtin_tools=[ImageGenerationTool(output_compression=101)])) # pyright: ignore[reportPrivateUsage]

# Invalid range: < 0
with pytest.raises(UserError, match='`output_compression` must be between 0 and 100'):
model._get_tools(ModelRequestParameters(builtin_tools=[ImageGenerationTool(output_compression=-1)])) # pyright: ignore[reportPrivateUsage]

# Non-JPEG format (PNG)
with pytest.raises(UserError, match='`output_compression` is only supported for JPEG format'):
model._get_tools( # pyright: ignore[reportPrivateUsage]
ModelRequestParameters(builtin_tools=[ImageGenerationTool(output_format='png', output_compression=90)])
)

# Non-JPEG format (WebP)
with pytest.raises(UserError, match='`output_compression` is only supported for JPEG format'):
model._get_tools( # pyright: ignore[reportPrivateUsage]
ModelRequestParameters(builtin_tools=[ImageGenerationTool(output_format='webp', output_compression=90)])
)


async def test_google_image_generation_rejected_by_gemini_api(
allow_model_requests: None, google_provider: GoogleProvider
) -> None:
"""Test that output_format and compression are rejected by Gemini API (google-gla)."""
model = GoogleModel('gemini-2.5-flash-image', provider=google_provider)

# Test output_format rejection
agent = Agent(model, builtin_tools=[ImageGenerationTool(output_format='png')], output_type=BinaryImage)
with pytest.raises(ValueError, match='output_mime_type parameter is not supported in Gemini API'):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of hitting errors, let's ignore the field silently as we would've done before. We've already documented it only works with Vertex, so we can check self.system == 'google-vertex' inside GoogleModel

await agent.run('Generate an image.')

# Test output_compression rejection (compression only, no format)
agent = Agent(model, builtin_tools=[ImageGenerationTool(output_compression=90)], output_type=BinaryImage)
with pytest.raises(ValueError, match='output_compression_quality parameter is not supported in Gemini API'):
await agent.run('Generate an image.')


async def test_google_vertexai_image_generation_with_output_format(
allow_model_requests: None, vertex_provider: GoogleProvider
): # pragma: lax no cover
"""Test that output_format works with Vertex AI."""
model = GoogleModel('gemini-2.5-flash-image', provider=vertex_provider)
agent = Agent(
model,
builtin_tools=[ImageGenerationTool(output_format='jpeg', output_compression=85)],
output_type=BinaryImage,
)

result = await agent.run('Generate an image of an axolotl.')
assert result.output.media_type == 'image/jpeg'


async def test_google_image_generation_tool_all_fields(google_provider: GoogleProvider) -> None:
"""Test that all ImageGenerationTool fields are mapped correctly."""
model = GoogleModel('gemini-3-pro-image-preview', provider=google_provider)
params = ModelRequestParameters(
builtin_tools=[ImageGenerationTool(aspect_ratio='16:9', size='2K', output_format='jpeg', output_compression=90)]
)

tools, image_config = model._get_tools(params) # pyright: ignore[reportPrivateUsage]
assert tools is None
assert image_config == {
'aspect_ratio': '16:9',
'image_size': '2K',
'output_mime_type': 'image/jpeg',
'output_compression_quality': 90,
}


async def test_google_vertexai_image_generation(allow_model_requests: None, vertex_provider: GoogleProvider):
model = GoogleModel('gemini-2.5-flash-image', provider=vertex_provider)

Expand Down