Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions scripts/lint_custom_code.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ echo "-> running on examples"
uv run mypy examples/ \
--exclude 'audio/' || ERRORS=1
echo "-> running on extra"
uv run mypy src/mistralai/extra/ || ERRORS=1
uv run --all-extras mypy src/mistralai/extra/ || ERRORS=1
echo "-> running on hooks"
uv run mypy src/mistralai/client/_hooks/ \
--exclude __init__.py --exclude sdkhooks.py --exclude types.py || ERRORS=1
Expand All @@ -48,7 +48,7 @@ echo "Running pyright..."
# TODO: Uncomment once the examples are fixed
# uv run pyright examples/ || ERRORS=1
echo "-> running on extra"
uv run pyright src/mistralai/extra/ || ERRORS=1
uv run --all-extras pyright src/mistralai/extra/ || ERRORS=1
echo "-> running on hooks"
uv run pyright src/mistralai/client/_hooks/ || ERRORS=1
echo "-> running on azure hooks"
Expand Down
1 change: 1 addition & 0 deletions src/mistralai/extra/py.typed
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Marker file for PEP 561. The package enables type hints.
134 changes: 134 additions & 0 deletions src/mistralai/extra/tests/test_workflow_encoding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""Tests for workflow encoding configuration lifecycle."""

import gc

import pytest
from pydantic import SecretStr

from mistralai.client import Mistral
from mistralai.client._hooks.workflow_encoding_hook import (
_workflow_configs,
_ENCODING_CONFIG_ID_ATTR,
configure_workflow_encoding,
)
from mistralai.extra.workflows import (
WorkflowEncodingConfig,
PayloadEncryptionConfig,
PayloadEncryptionMode,
)


@pytest.fixture
def encryption_config() -> WorkflowEncodingConfig:
    """Provide a WorkflowEncodingConfig with full payload encryption enabled."""
    # A 256-bit key written as 64 hex characters.
    key = SecretStr("0" * 64)
    encryption = PayloadEncryptionConfig(
        mode=PayloadEncryptionMode.FULL,
        main_key=key,
    )
    return WorkflowEncodingConfig(payload_encryption=encryption)


def test_payload_encoder_cleanup_on_client_gc(encryption_config: WorkflowEncodingConfig):
    """Test that PayloadEncoder is cleaned up when client is garbage collected."""
    baseline = len(_workflow_configs)

    # Configure encoding on a fresh client.
    client = Mistral(api_key="test-key")
    configure_workflow_encoding(
        encryption_config,
        namespace="test-namespace",
        sdk_config=client.sdk_configuration,
    )

    # The hook stamps an id onto the sdk configuration and registers it.
    config_id = getattr(client.sdk_configuration, _ENCODING_CONFIG_ID_ATTR)
    assert config_id is not None
    assert config_id in _workflow_configs
    assert len(_workflow_configs) == baseline + 1

    # Dropping the last reference must remove the registry entry.
    del client
    gc.collect()

    assert config_id not in _workflow_configs
    assert len(_workflow_configs) == baseline


def test_multiple_clients_independent_configs(encryption_config: WorkflowEncodingConfig):
    """Test that multiple clients have independent configs."""
    baseline = len(_workflow_configs)

    # Two clients, each configured with its own namespace.
    client1 = Mistral(api_key="test-key-1")
    client2 = Mistral(api_key="test-key-2")
    configure_workflow_encoding(
        encryption_config,
        namespace="namespace-1",
        sdk_config=client1.sdk_configuration,
    )
    configure_workflow_encoding(
        encryption_config,
        namespace="namespace-2",
        sdk_config=client2.sdk_configuration,
    )

    # Each client received a distinct registry entry.
    id1 = getattr(client1.sdk_configuration, _ENCODING_CONFIG_ID_ATTR)
    id2 = getattr(client2.sdk_configuration, _ENCODING_CONFIG_ID_ATTR)
    assert id1 != id2
    assert len(_workflow_configs) == baseline + 2

    # Namespaces must not leak between configs.
    assert _workflow_configs[id1].namespace == "namespace-1"
    assert _workflow_configs[id2].namespace == "namespace-2"

    # Collecting the first client removes only its own entry.
    del client1
    gc.collect()
    assert id1 not in _workflow_configs
    assert id2 in _workflow_configs
    assert len(_workflow_configs) == baseline + 1

    # Collecting the second client restores the baseline.
    del client2
    gc.collect()
    assert id2 not in _workflow_configs
    assert len(_workflow_configs) == baseline


def test_reconfigure_same_client(encryption_config: WorkflowEncodingConfig):
    """Test that reconfiguring the same client updates the config."""
    client = Mistral(api_key="test-key")

    # First configuration registers an entry for this client.
    configure_workflow_encoding(
        encryption_config,
        namespace="namespace-v1",
        sdk_config=client.sdk_configuration,
    )
    config_id = getattr(client.sdk_configuration, _ENCODING_CONFIG_ID_ATTR)
    assert _workflow_configs[config_id].namespace == "namespace-v1"

    # Reconfiguring reuses the same id but overwrites the namespace.
    configure_workflow_encoding(
        encryption_config,
        namespace="namespace-v2",
        sdk_config=client.sdk_configuration,
    )
    assert getattr(client.sdk_configuration, _ENCODING_CONFIG_ID_ATTR) == config_id
    assert _workflow_configs[config_id].namespace == "namespace-v2"

    # Garbage-collecting the client removes the entry.
    del client
    gc.collect()
    assert config_id not in _workflow_configs
76 changes: 76 additions & 0 deletions src/mistralai/extra/workflows/encoding/storage/_azure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from __future__ import annotations

from typing import Any, cast

from azure.core.exceptions import ResourceNotFoundError
from azure.storage.blob.aio import BlobServiceClient
from .blob_storage import BlobNotFoundError, BlobStorage


class AzureBlobStorage(BlobStorage):
    """Async ``BlobStorage`` implementation backed by an Azure Blob container.

    Must be used as an async context manager: the service and container
    clients are created in ``__aenter__`` and released in ``__aexit__``.
    """

    def __init__(
        self,
        container_name: str,
        azure_connection_string: str,
        prefix: str | None = None,
    ):
        """Store connection parameters; no network activity happens here.

        Args:
            container_name: Name of the target blob container.
            azure_connection_string: Azure Storage connection string.
            prefix: Optional key prefix prepended to every blob key.
        """
        self.container_name = container_name
        self.connection_string = azure_connection_string
        self.prefix = prefix or ""
        self._service_client: BlobServiceClient | None = None
        self._container_client: Any = None

    def _get_full_key(self, key: str) -> str:
        """Apply the configured prefix to *key*, idempotently.

        Compares against ``"<prefix>/"`` so that a key such as
        ``"database.txt"`` is not mistaken for already carrying the
        prefix ``"data"`` (a plain ``startswith(prefix)`` check would
        wrongly skip prefixing in that case).
        """
        if not self.prefix:
            return key
        if key.startswith(f"{self.prefix}/"):
            return key
        return f"{self.prefix}/{key}"

    async def __aenter__(self) -> "AzureBlobStorage":
        """Create the service and container clients for this context."""
        self._service_client = BlobServiceClient.from_connection_string(
            self.connection_string
        )
        assert self._service_client is not None
        self._container_client = self._service_client.get_container_client(
            self.container_name
        )
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Close the service client (also releases the container client's pipeline)."""
        if self._service_client:
            await self._service_client.close()
        # Drop references so accidental use after exit fails fast.
        self._service_client = None
        self._container_client = None

    async def upload_blob(self, key: str, content: bytes) -> str:
        """Upload *content* under *key* (overwriting) and return the blob URL."""
        full_key = self._get_full_key(key)
        blob_client = self._container_client.get_blob_client(full_key)
        await blob_client.upload_blob(content, overwrite=True)
        return cast(str, blob_client.url)

    async def get_blob(self, key: str) -> bytes:
        """Download and return the blob's bytes.

        Raises:
            BlobNotFoundError: If no blob exists under *key*.
        """
        full_key = self._get_full_key(key)
        blob_client = self._container_client.get_blob_client(full_key)
        try:
            stream = await blob_client.download_blob()
            return cast(bytes, await stream.readall())
        except ResourceNotFoundError as e:
            raise BlobNotFoundError(f"Blob not found: {key}") from e

    async def get_blob_properties(self, key: str) -> dict[str, Any] | None:
        """Return ``{"size", "last_modified"}`` for the blob, or None if missing."""
        full_key = self._get_full_key(key)
        blob_client = self._container_client.get_blob_client(full_key)
        try:
            props = await blob_client.get_blob_properties()
            return {"size": props.size, "last_modified": props.last_modified}
        except ResourceNotFoundError:
            return None

    async def delete_blob(self, key: str) -> None:
        """Delete the blob under *key* (raises if it does not exist)."""
        full_key = self._get_full_key(key)
        blob_client = self._container_client.get_blob_client(full_key)
        await blob_client.delete_blob()

    async def blob_exists(self, key: str) -> bool:
        """Return True if a blob exists under *key*."""
        full_key = self._get_full_key(key)
        blob_client = self._container_client.get_blob_client(full_key)
        return cast(bool, await blob_client.exists())
81 changes: 81 additions & 0 deletions src/mistralai/extra/workflows/encoding/storage/_gcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from __future__ import annotations

from typing import Any, cast

import aiohttp
from gcloud.aio.storage import Storage

from .blob_storage import BlobNotFoundError, BlobStorage


class GCSBlobStorage(BlobStorage):
    """Async ``BlobStorage`` implementation backed by a Google Cloud Storage bucket.

    Must be used as an async context manager: the HTTP session and storage
    client are created in ``__aenter__`` and released in ``__aexit__``.
    """

    def __init__(self, bucket_id: str, prefix: str | None = None):
        """Store bucket parameters; no network activity happens here.

        Args:
            bucket_id: Name of the target GCS bucket.
            prefix: Optional key prefix prepended to every blob key.
        """
        self.bucket_id = bucket_id
        self.prefix = prefix or ""
        self._storage: Storage | None = None
        self._session: aiohttp.ClientSession | None = None

    def _get_full_key(self, key: str) -> str:
        """Apply the configured prefix to *key*, idempotently.

        Compares against ``"<prefix>/"`` so that a key such as
        ``"database.txt"`` is not mistaken for already carrying the
        prefix ``"data"`` (a plain ``startswith(prefix)`` check would
        wrongly skip prefixing in that case).
        """
        if not self.prefix:
            return key
        if key.startswith(f"{self.prefix}/"):
            return key
        return f"{self.prefix}/{key}"

    async def __aenter__(self) -> "GCSBlobStorage":
        """Create the aiohttp session and the GCS storage client."""
        self._session = aiohttp.ClientSession()
        self._storage = Storage(session=self._session)
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Close the storage client first, then the underlying session."""
        if self._storage:
            await self._storage.close()
        if self._session:
            await self._session.close()

    async def upload_blob(self, key: str, content: bytes) -> str:
        """Upload *content* under *key* and return the object's selfLink URL."""
        full_key = self._get_full_key(key)
        assert self._storage is not None
        response = await self._storage.upload(self.bucket_id, full_key, content)
        return str(response.get("selfLink"))

    async def get_blob(self, key: str) -> bytes:
        """Download and return the blob's bytes.

        Raises:
            BlobNotFoundError: If no blob exists under *key*.
        """
        full_key = self._get_full_key(key)
        assert self._storage is not None
        try:
            content = await self._storage.download(self.bucket_id, full_key)
            return cast(bytes, content)
        except Exception as e:
            # gcloud-aio does not expose a typed not-found error here, so we
            # fall back to inspecting the message — TODO confirm against the
            # installed gcloud-aio-storage version.
            if "404" in str(e) or "Not Found" in str(e):
                raise BlobNotFoundError(f"Blob not found: {key}") from e
            raise

    async def get_blob_properties(self, key: str) -> dict[str, Any] | None:
        """Return ``{"size", "last_modified"}`` for the blob, or None if missing."""
        full_key = self._get_full_key(key)
        assert self._storage is not None
        try:
            metadata = await self._storage.download_metadata(self.bucket_id, full_key)
            return {
                "size": int(metadata.get("size", 0)),
                "last_modified": metadata.get("updated"),
            }
        except Exception as e:
            # Message-based not-found detection; see get_blob.
            if "404" in str(e) or "Not Found" in str(e):
                return None
            raise

    async def delete_blob(self, key: str) -> None:
        """Delete the blob under *key* (raises if it does not exist)."""
        full_key = self._get_full_key(key)
        assert self._storage is not None
        await self._storage.delete(self.bucket_id, full_key)

    async def blob_exists(self, key: str) -> bool:
        """Return True if a blob exists under *key*, probing via its metadata."""
        full_key = self._get_full_key(key)
        assert self._storage is not None
        try:
            await self._storage.download_metadata(self.bucket_id, full_key)
            return True
        except Exception as e:
            # Message-based not-found detection; see get_blob.
            if "404" in str(e) or "Not Found" in str(e):
                return False
            raise
Loading
Loading