diff --git a/scripts/lint_custom_code.sh b/scripts/lint_custom_code.sh index 4baa3d88..872e2551 100755 --- a/scripts/lint_custom_code.sh +++ b/scripts/lint_custom_code.sh @@ -27,7 +27,7 @@ echo "-> running on examples" uv run mypy examples/ \ --exclude 'audio/' || ERRORS=1 echo "-> running on extra" -uv run mypy src/mistralai/extra/ || ERRORS=1 +uv run --all-extras mypy src/mistralai/extra/ || ERRORS=1 echo "-> running on hooks" uv run mypy src/mistralai/client/_hooks/ \ --exclude __init__.py --exclude sdkhooks.py --exclude types.py || ERRORS=1 @@ -48,7 +48,7 @@ echo "Running pyright..." # TODO: Uncomment once the examples are fixed # uv run pyright examples/ || ERRORS=1 echo "-> running on extra" -uv run pyright src/mistralai/extra/ || ERRORS=1 +uv run --all-extras pyright src/mistralai/extra/ || ERRORS=1 echo "-> running on hooks" uv run pyright src/mistralai/client/_hooks/ || ERRORS=1 echo "-> running on azure hooks" diff --git a/src/mistralai/extra/py.typed b/src/mistralai/extra/py.typed new file mode 100644 index 00000000..9df62a6e --- /dev/null +++ b/src/mistralai/extra/py.typed @@ -0,0 +1 @@ +# Marker file for PEP 561. The package enables type hints. \ No newline at end of file diff --git a/src/mistralai/extra/tests/test_workflow_encoding.py b/src/mistralai/extra/tests/test_workflow_encoding.py new file mode 100644 index 00000000..3703012b --- /dev/null +++ b/src/mistralai/extra/tests/test_workflow_encoding.py @@ -0,0 +1,134 @@ +"""Tests for workflow encoding configuration lifecycle.""" + +import gc + +import pytest +from pydantic import SecretStr + +from mistralai.client import Mistral +from mistralai.client._hooks.workflow_encoding_hook import ( + _workflow_configs, + _ENCODING_CONFIG_ID_ATTR, + configure_workflow_encoding, +) +from mistralai.extra.workflows import ( + WorkflowEncodingConfig, + PayloadEncryptionConfig, + PayloadEncryptionMode, +) + + +@pytest.fixture +def encryption_config() -> WorkflowEncodingConfig: + """Create a test encryption config.""" + return WorkflowEncodingConfig( + payload_encryption=PayloadEncryptionConfig( + mode=PayloadEncryptionMode.FULL, + main_key=SecretStr("0" * 64), # 256-bit key in hex + ) + ) + + +def test_payload_encoder_cleanup_on_client_gc(encryption_config: WorkflowEncodingConfig): + """Test that PayloadEncoder is cleaned up when client is garbage collected.""" + initial_config_count = len(_workflow_configs) + + # Create client and configure encoding + client = Mistral(api_key="test-key") + configure_workflow_encoding( + encryption_config, + namespace="test-namespace", + sdk_config=client.sdk_configuration, + ) + + # Verify config was added + config_id = getattr(client.sdk_configuration, _ENCODING_CONFIG_ID_ATTR) + assert config_id is not None + assert config_id in _workflow_configs + assert len(_workflow_configs) == initial_config_count + 1 + + # Delete client and force garbage collection + del client + gc.collect() + + # Verify config was cleaned up + assert config_id not in _workflow_configs + assert len(_workflow_configs) == initial_config_count + + +def test_multiple_clients_independent_configs(encryption_config: WorkflowEncodingConfig): + """Test that multiple clients have independent configs.""" + initial_config_count = len(_workflow_configs) + + # Create two clients with different namespaces + client1 = Mistral(api_key="test-key-1") + client2 = Mistral(api_key="test-key-2") + + configure_workflow_encoding( + encryption_config, + namespace="namespace-1", + sdk_config=client1.sdk_configuration, + ) + configure_workflow_encoding( + encryption_config, + namespace="namespace-2", + sdk_config=client2.sdk_configuration, + ) + + # Verify both configs exist + config_id1 = getattr(client1.sdk_configuration, _ENCODING_CONFIG_ID_ATTR) + config_id2 = getattr(client2.sdk_configuration, _ENCODING_CONFIG_ID_ATTR) + assert config_id1 != config_id2 + assert len(_workflow_configs) == initial_config_count + 2 + + # Verify namespaces are independent + assert _workflow_configs[config_id1].namespace == "namespace-1" + assert _workflow_configs[config_id2].namespace == "namespace-2" + + # Delete first client + del client1 + gc.collect() + + # First config should be cleaned up, second should remain + assert config_id1 not in _workflow_configs + assert config_id2 in _workflow_configs + assert len(_workflow_configs) == initial_config_count + 1 + + # Delete second client + del client2 + gc.collect() + + # Both configs should be cleaned up + assert config_id2 not in _workflow_configs + assert len(_workflow_configs) == initial_config_count + + +def test_reconfigure_same_client(encryption_config: WorkflowEncodingConfig): + """Test that reconfiguring the same client updates the config.""" + client = Mistral(api_key="test-key") + + # Initial configuration + configure_workflow_encoding( + encryption_config, + namespace="namespace-v1", + sdk_config=client.sdk_configuration, + ) + + config_id = getattr(client.sdk_configuration, _ENCODING_CONFIG_ID_ATTR) + assert _workflow_configs[config_id].namespace == "namespace-v1" + + # Reconfigure with different namespace + configure_workflow_encoding( + encryption_config, + namespace="namespace-v2", + sdk_config=client.sdk_configuration, + ) + + # Should use same config_id but updated namespace + assert getattr(client.sdk_configuration, _ENCODING_CONFIG_ID_ATTR) == config_id + assert _workflow_configs[config_id].namespace == "namespace-v2" + + # Cleanup + del client + gc.collect() + assert config_id not in _workflow_configs diff --git a/src/mistralai/extra/workflows/encoding/storage/_azure.py b/src/mistralai/extra/workflows/encoding/storage/_azure.py new file mode 100644 index 00000000..e62d9926 --- /dev/null +++ b/src/mistralai/extra/workflows/encoding/storage/_azure.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from typing import Any, cast + +from azure.core.exceptions import ResourceNotFoundError +from azure.storage.blob.aio import BlobServiceClient +from .blob_storage import BlobNotFoundError, BlobStorage + + +class AzureBlobStorage(BlobStorage): + def __init__( + self, + container_name: str, + azure_connection_string: str, + prefix: str | None = None, + ): + self.container_name = container_name + self.connection_string = azure_connection_string + self.prefix = prefix or "" + self._service_client: BlobServiceClient | None = None + self._container_client: Any = None + + def _get_full_key(self, key: str) -> str: + if not self.prefix: + return key + if key.startswith(self.prefix): + return key + return f"{self.prefix}/{key}" + + async def __aenter__(self) -> "AzureBlobStorage": + self._service_client = BlobServiceClient.from_connection_string( + self.connection_string + ) + assert self._service_client is not None + self._container_client = self._service_client.get_container_client( + self.container_name + ) + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self._service_client: + await self._service_client.close() + + async def upload_blob(self, key: str, content: bytes) -> str: + full_key = self._get_full_key(key) + blob_client = self._container_client.get_blob_client(full_key) + await blob_client.upload_blob(content, overwrite=True) + return cast(str, blob_client.url) + + async def get_blob(self, key: str) -> bytes: + full_key = self._get_full_key(key) + blob_client = self._container_client.get_blob_client(full_key) + try: + stream = await blob_client.download_blob() + return cast(bytes, await stream.readall()) + except ResourceNotFoundError as e: + raise BlobNotFoundError(f"Blob not found: {key}") from e + + async def get_blob_properties(self, key: str) -> dict[str, Any] | None: + full_key = self._get_full_key(key) + blob_client = self._container_client.get_blob_client(full_key) + try: + props = await blob_client.get_blob_properties() + return {"size": props.size, "last_modified": props.last_modified} + except ResourceNotFoundError: + return None + + async def delete_blob(self, key: str) -> None: + full_key = self._get_full_key(key) + blob_client = self._container_client.get_blob_client(full_key) + await blob_client.delete_blob() + + async def blob_exists(self, key: str) -> bool: + full_key = self._get_full_key(key) + blob_client = self._container_client.get_blob_client(full_key) + return cast(bool, await blob_client.exists()) diff --git a/src/mistralai/extra/workflows/encoding/storage/_gcs.py b/src/mistralai/extra/workflows/encoding/storage/_gcs.py new file mode 100644 index 00000000..c5e37d60 --- /dev/null +++ b/src/mistralai/extra/workflows/encoding/storage/_gcs.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from typing import Any, cast + +import aiohttp +from gcloud.aio.storage import Storage + +from .blob_storage import BlobNotFoundError, BlobStorage + + +class GCSBlobStorage(BlobStorage): + def __init__(self, bucket_id: str, prefix: str | None = None): + self.bucket_id = bucket_id + self.prefix = prefix or "" + self._storage: Storage | None = None + self._session: aiohttp.ClientSession | None = None + + def _get_full_key(self, key: str) -> str: + if not self.prefix: + return key + if key.startswith(self.prefix): + return key + return f"{self.prefix}/{key}" + + async def __aenter__(self) -> "GCSBlobStorage": + self._session = aiohttp.ClientSession() + self._storage = Storage(session=self._session) + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self._storage: + await self._storage.close() + if self._session: + await self._session.close() + + async def upload_blob(self, key: str, content: bytes) -> str: + full_key = self._get_full_key(key) + assert self._storage is not None + response = await self._storage.upload(self.bucket_id, full_key, content) + return str(response.get("selfLink")) + + async def get_blob(self, key: str) -> bytes: + full_key = self._get_full_key(key) + assert self._storage is not None + try: + content = await self._storage.download(self.bucket_id, full_key) + return cast(bytes, content) + except Exception as e: + if "404" in str(e) or "Not Found" in str(e): + raise BlobNotFoundError(f"Blob not found: {key}") from e + raise + + async def get_blob_properties(self, key: str) -> dict[str, Any] | None: + full_key = self._get_full_key(key) + assert self._storage is not None + try: + metadata = await self._storage.download_metadata(self.bucket_id, full_key) + return { + "size": int(metadata.get("size", 0)), + "last_modified": metadata.get("updated"), + } + except Exception as e: + if "404" in str(e) or "Not Found" in str(e): + return None + raise + + async def delete_blob(self, key: str) -> None: + full_key = self._get_full_key(key) + assert self._storage is not None + await self._storage.delete(self.bucket_id, full_key) + + async def blob_exists(self, key: str) -> bool: + full_key = self._get_full_key(key) + assert self._storage is not None + try: + await self._storage.download_metadata(self.bucket_id, full_key) + return True + except Exception as e: + if "404" in str(e) or "Not Found" in str(e): + return False + raise diff --git a/src/mistralai/extra/workflows/encoding/storage/_s3.py b/src/mistralai/extra/workflows/encoding/storage/_s3.py new file mode 100644 index 00000000..4a0ce063 --- /dev/null +++ b/src/mistralai/extra/workflows/encoding/storage/_s3.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +from typing import Any, cast + +import aioboto3 # type: ignore[import-untyped] +from botocore.exceptions import ClientError # type: ignore[import-untyped] + +from .blob_storage import BlobNotFoundError, BlobStorage + + +class S3BlobStorage(BlobStorage): + def __init__( + self, + bucket_name: str, + prefix: str | None = None, + region_name: str | None = None, + endpoint_url: str | None = None, + aws_access_key_id: str | None = None, + aws_secret_access_key: str | None = None, + ): + self.bucket_name = bucket_name + self.prefix = prefix or "" + self.region_name = region_name + self.endpoint_url = endpoint_url + self.aws_access_key_id = aws_access_key_id + self.aws_secret_access_key = aws_secret_access_key + self._session: aioboto3.Session | None = None + self._client: Any = None + + def _get_full_key(self, key: str) -> str: + if not self.prefix: + return key + if key.startswith(self.prefix): + return key + return f"{self.prefix}/{key}" + + async def __aenter__(self) -> "S3BlobStorage": + self._session = aioboto3.Session() + assert self._session is not None + kwargs: dict[str, Any] = {} + if self.region_name: + kwargs["region_name"] = self.region_name + if self.endpoint_url: + kwargs["endpoint_url"] = self.endpoint_url + if self.aws_access_key_id: + kwargs["aws_access_key_id"] = self.aws_access_key_id + if self.aws_secret_access_key: + kwargs["aws_secret_access_key"] = self.aws_secret_access_key + + self._client = self._session.client("s3", **kwargs) + self._client = await self._client.__aenter__() + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self._client: + await self._client.__aexit__(exc_type, exc_val, exc_tb) + self._session = None + self._client = None + + async def upload_blob(self, key: str, content: bytes) -> str: + full_key = self._get_full_key(key) + await self._client.put_object( + Bucket=self.bucket_name, Key=full_key, Body=content + ) + endpoint = ( + self.endpoint_url + or f"https://s3.{self.region_name or 'us-east-1'}.amazonaws.com" + ) + return f"{endpoint}/{self.bucket_name}/{full_key}" + + async def get_blob(self, key: str) -> bytes: + full_key = self._get_full_key(key) + try: + response = await self._client.get_object( + Bucket=self.bucket_name, Key=full_key + ) + async with response["Body"] as stream: + return cast(bytes, await stream.read()) + except ClientError as e: + if e.response.get("Error", {}).get("Code") == "NoSuchKey": + raise BlobNotFoundError(f"Blob not found: {key}") from e + raise + + async def get_blob_properties(self, key: str) -> dict[str, Any] | None: + full_key = self._get_full_key(key) + try: + response = await self._client.head_object( + Bucket=self.bucket_name, Key=full_key + ) + return { + "size": response["ContentLength"], + "last_modified": response["LastModified"], + } + except ClientError as e: + if e.response.get("Error", {}).get("Code") in ("NoSuchKey", "404"): + return None + raise + + async def delete_blob(self, key: str) -> None: + full_key = self._get_full_key(key) + await self._client.delete_object(Bucket=self.bucket_name, Key=full_key) + + async def blob_exists(self, key: str) -> bool: + full_key = self._get_full_key(key) + try: + await self._client.head_object(Bucket=self.bucket_name, Key=full_key) + return True + except ClientError as e: + if e.response.get("Error", {}).get("Code") in ("NoSuchKey", "404"): + return False + raise