diff --git a/airbyte/_util/api_imports.py b/airbyte/_util/api_imports.py index 19935b6f4..df367256e 100644 --- a/airbyte/_util/api_imports.py +++ b/airbyte/_util/api_imports.py @@ -25,6 +25,7 @@ ConnectionResponse, DestinationResponse, JobResponse, + WorkspaceResponse, ) # Public-Use Classes @@ -39,4 +40,5 @@ "DestinationResponse", "JobResponse", "JobStatusEnum", + "WorkspaceResponse", ] diff --git a/airbyte/cloud/__init__.py b/airbyte/cloud/__init__.py index 617913f96..394726d35 100644 --- a/airbyte/cloud/__init__.py +++ b/airbyte/cloud/__init__.py @@ -84,9 +84,11 @@ from typing import TYPE_CHECKING +from airbyte.cloud.client import CloudClient from airbyte.cloud.client_config import CloudClientConfig from airbyte.cloud.connections import CloudConnection from airbyte.cloud.constants import JobStatusEnum +from airbyte.cloud.organizations import CloudOrganization from airbyte.cloud.sync_results import SyncResult from airbyte.cloud.workspaces import CloudWorkspace @@ -95,9 +97,11 @@ if TYPE_CHECKING: # ruff: noqa: TC004 from airbyte.cloud import ( + client, client_config, connections, constants, + organizations, sync_results, workspaces, ) @@ -106,11 +110,15 @@ __all__ = [ # Submodules "workspaces", + "client", + "organizations", "connections", "constants", "client_config", "sync_results", # Classes + "CloudClient", + "CloudOrganization", "CloudWorkspace", "CloudConnection", "CloudClientConfig", diff --git a/airbyte/cloud/_credentials.py b/airbyte/cloud/_credentials.py new file mode 100644 index 000000000..2031c6363 --- /dev/null +++ b/airbyte/cloud/_credentials.py @@ -0,0 +1,348 @@ +# Copyright (c) 2024 Airbyte, Inc., all rights reserved. +"""Credential file helpers for Airbyte Cloud authentication.""" + +from __future__ import annotations + +from dataclasses import dataclass, replace +from pathlib import Path + +import yaml + +from airbyte._util.api_util import get_bearer_token +from airbyte.constants import ( + CLOUD_API_ROOT, + CLOUD_API_ROOT_ENV_VAR, + CLOUD_BEARER_TOKEN_ENV_VAR, + CLOUD_CLIENT_ID_ENV_VAR, + CLOUD_CLIENT_SECRET_ENV_VAR, + CLOUD_CONFIG_API_ROOT, + CLOUD_CONFIG_API_ROOT_ENV_VAR, + CLOUD_ORGANIZATION_ID_ENV_VAR, + CLOUD_WORKSPACE_ID_ENV_VAR, +) +from airbyte.exceptions import PyAirbyteInputError +from airbyte.secrets.base import SecretString +from airbyte.secrets.util import try_get_secret + + +CREDENTIALS_FILE_PATH = Path("~/.airbyte/credentials").expanduser() +CLIENT_ID_ENV_VAR = "AIRBYTE_CLIENT_ID" +CLIENT_SECRET_ENV_VAR = "AIRBYTE_CLIENT_SECRET" +WORKSPACE_ID_ENV_VAR = "AIRBYTE_WORKSPACE_ID" +ORGANIZATION_ID_ENV_VAR = "AIRBYTE_ORGANIZATION_ID" +PUBLIC_API_ROOT_ENV_VAR = "AIRBYTE_API_ROOT" +BEARER_TOKEN_ENV_VAR = "AIRBYTE_BEARER_TOKEN" +CONFIG_API_ROOT_ENV_VAR = "AIRBYTE_CONFIG_API_ROOT" + + +@dataclass(frozen=True) +class CloudLoginResult: + """Result of a successful non-interactive Cloud login.""" + + credentials_file_path: Path + airbyte_api_root: str + config_api_root: str + + +@dataclass(frozen=True) +class _AirbyteCredentials: + """Resolved credentials and API roots for Airbyte control-plane APIs.""" + + client_id: SecretString | None + client_secret: SecretString | None + bearer_token: SecretString | None + public_api_root: str + config_api_root: str | None + workspace_id: str | None = None + organization_id: str | None = None + + @classmethod + def from_auth( + cls, + *, + workspace_id: str | None = None, + organization_id: str | None = None, + client_id: str | SecretString | None = None, + client_secret: str | SecretString | None = None, + bearer_token: str | SecretString | None = None, + public_api_root: str | None = None, + config_api_root: str | None = None, + credentials_file_path: Path = CREDENTIALS_FILE_PATH, + ) -> _AirbyteCredentials: + """Resolve Airbyte Cloud credentials from inputs, env vars, and credentials file.""" + file_credentials = cls.from_file(credentials_file_path) + resolved_bearer_token = _first_value( + str(bearer_token) if bearer_token is not None else None, + _env_value(BEARER_TOKEN_ENV_VAR, CLOUD_BEARER_TOKEN_ENV_VAR), + str(file_credentials.bearer_token) + if file_credentials.bearer_token is not None + else None, + ) + resolved_client_id = _first_value( + str(client_id) if client_id is not None else None, + _env_value(CLIENT_ID_ENV_VAR, CLOUD_CLIENT_ID_ENV_VAR), + str(file_credentials.client_id) if file_credentials.client_id is not None else None, + ) + resolved_client_secret = _first_value( + str(client_secret) if client_secret is not None else None, + _env_value(CLIENT_SECRET_ENV_VAR, CLOUD_CLIENT_SECRET_ENV_VAR), + str(file_credentials.client_secret) + if file_credentials.client_secret is not None + else None, + ) + + if resolved_bearer_token and (resolved_client_id or resolved_client_secret): + resolved_client_id = None + resolved_client_secret = None + elif bool(resolved_client_id) != bool(resolved_client_secret): + raise PyAirbyteInputError( + message="Client ID and client secret are both required.", + guidance="Provide both client ID and client secret, or use a bearer token.", + ) + elif not resolved_bearer_token and not resolved_client_id: + raise PyAirbyteInputError( + message="No Airbyte credentials found.", + guidance=( + "Set Airbyte Cloud credentials in environment variables or " + f"create a credentials file at {credentials_file_path}." + ), + ) + + return cls( + client_id=SecretString(resolved_client_id) if resolved_client_id else None, + client_secret=SecretString(resolved_client_secret) if resolved_client_secret else None, + bearer_token=SecretString(resolved_bearer_token) if resolved_bearer_token else None, + public_api_root=_first_value( + public_api_root, + _env_value(PUBLIC_API_ROOT_ENV_VAR, CLOUD_API_ROOT_ENV_VAR), + file_credentials.public_api_root, + ) + or CLOUD_API_ROOT, + config_api_root=_first_value( + config_api_root, + _env_value(CONFIG_API_ROOT_ENV_VAR, CLOUD_CONFIG_API_ROOT_ENV_VAR), + file_credentials.config_api_root, + ), + workspace_id=_first_value( + workspace_id, + _env_value(WORKSPACE_ID_ENV_VAR, CLOUD_WORKSPACE_ID_ENV_VAR), + file_credentials.workspace_id, + ), + organization_id=_first_value( + organization_id, + _env_value(ORGANIZATION_ID_ENV_VAR, CLOUD_ORGANIZATION_ID_ENV_VAR), + file_credentials.organization_id, + ), + ) + + @classmethod + def from_file( + cls, + credentials_file_path: Path = CREDENTIALS_FILE_PATH, + ) -> _AirbyteCredentials: + """Read Airbyte credentials from a YAML credentials file.""" + credentials: dict[str, str] = {} + if not credentials_file_path.exists(): + return cls( + client_id=None, + client_secret=None, + bearer_token=None, + public_api_root=CLOUD_API_ROOT, + config_api_root=None, + ) + + try: + content = credentials_file_path.read_text(encoding="utf-8").strip() + parsed = yaml.safe_load(content) if content else {} + credentials = _as_string_mapping(parsed) + except (OSError, yaml.YAMLError): + credentials = {} + + return cls( + client_id=SecretString(credentials["client_id"]) + if credentials.get("client_id") + else None, + client_secret=SecretString(credentials["client_secret"]) + if credentials.get("client_secret") + else None, + bearer_token=SecretString(credentials["bearer_token"]) + if credentials.get("bearer_token") + else None, + public_api_root=_first_value( + credentials.get("airbyte_api_root"), + credentials.get("public_api_root"), + credentials.get("api_url"), + ) + or CLOUD_API_ROOT, + config_api_root=credentials.get("config_api_root"), + workspace_id=credentials.get("workspace_id"), + organization_id=credentials.get("organization_id"), + ) + + def to_file( + self, + credentials_file_path: Path = CREDENTIALS_FILE_PATH, + ) -> None: + """Write Airbyte credentials to a YAML credentials file.""" + credentials = { + key: value + for key, value in { + "airbyte_api_root": self.public_api_root, + "bearer_token": str(self.bearer_token) if self.bearer_token is not None else None, + "client_id": str(self.client_id) if self.client_id is not None else None, + "client_secret": str(self.client_secret) + if self.client_secret is not None + else None, + "config_api_root": self.config_api_root, + "organization_id": self.organization_id, + "workspace_id": self.workspace_id, + }.items() + if value + } + credentials_file_path.parent.mkdir(parents=True, exist_ok=True) + credentials_file_path.write_text( + yaml.safe_dump(credentials, sort_keys=True), + encoding="utf-8", + ) + credentials_file_path.chmod(0o600) + + def login( + self, + *, + credentials_file_path: Path = CREDENTIALS_FILE_PATH, + ) -> CloudLoginResult: + """Log in using client credentials and persist a bearer token.""" + resolved_client_id, resolved_client_secret = _validate_client_credentials( + client_id=str(self.client_id) if self.client_id is not None else None, + client_secret=str(self.client_secret) if self.client_secret is not None else None, + ) + resolved_airbyte_api_root, resolved_config_api_root = _resolve_login_roots( + airbyte_api_root=self.public_api_root, + config_api_root=self.config_api_root, + ) + bearer_token = get_bearer_token( + client_id=SecretString(resolved_client_id), + client_secret=SecretString(resolved_client_secret), + api_root=resolved_airbyte_api_root, + ) + + existing_credentials = type(self).from_file(credentials_file_path) + replace( + existing_credentials, + bearer_token=SecretString(bearer_token), + client_id=None, + client_secret=None, + public_api_root=resolved_airbyte_api_root, + config_api_root=resolved_config_api_root, + workspace_id=self.workspace_id or existing_credentials.workspace_id, + organization_id=self.organization_id or existing_credentials.organization_id, + ).to_file(credentials_file_path) + + return CloudLoginResult( + credentials_file_path=credentials_file_path, + airbyte_api_root=resolved_airbyte_api_root, + config_api_root=resolved_config_api_root, + ) + + def with_workspace_id(self, workspace_id: str | None) -> _AirbyteCredentials: + """Return credentials scoped to a workspace.""" + return replace(self, workspace_id=workspace_id) + + def with_organization_id(self, organization_id: str | None) -> _AirbyteCredentials: + """Return credentials scoped to an organization.""" + return replace(self, organization_id=organization_id) + + +def _as_string_mapping(parsed: object) -> dict[str, str]: + """Return a string-only mapping from parsed YAML content.""" + if not isinstance(parsed, dict): + return {} + + result: dict[str, str] = {} + for key, value in parsed.items(): + if isinstance(key, str) and value is not None: + result[key] = str(value) + + return result + + +def _first_value(*values: str | None) -> str | None: + """Return the first non-empty string value.""" + for value in values: + if value: + return value + return None + + +def _env_value(*names: str) -> str | None: + """Return the first available environment variable value.""" + for name in names: + value = try_get_secret(name, default=None) + if value: + return str(value) + return None + + +def _raise_interactive_login_unavailable() -> None: + """Raise an error for the unsupported browser login flow.""" + raise PyAirbyteInputError( + message="Interactive Airbyte Cloud login is not implemented.", + guidance=( + "Provide `--client-id` and `--client-secret` for non-interactive login. " + "The browser login protocol has not been published in repo docs." + ), + ) + + +def _validate_client_credentials( + *, + client_id: str | None, + client_secret: str | None, +) -> tuple[str, str]: + """Validate and return client credentials for non-interactive login.""" + if not client_id and not client_secret: + _raise_interactive_login_unavailable() + + if not client_id or not client_secret: + raise PyAirbyteInputError( + message="Client ID and client secret are both required.", + guidance="Provide both `--client-id` and `--client-secret`.", + ) + + return client_id, client_secret + + +def _resolve_login_roots( + *, + airbyte_api_root: str | None, + config_api_root: str | None, +) -> tuple[str, str]: + """Resolve Cloud or self-managed API roots for login.""" + if airbyte_api_root in {None, CLOUD_API_ROOT} and config_api_root is None: + return CLOUD_API_ROOT, CLOUD_CONFIG_API_ROOT + + if airbyte_api_root is not None or config_api_root is not None: + if airbyte_api_root is not None and config_api_root is not None: + return airbyte_api_root, config_api_root + + missing_roots: list[str] = [] + if not airbyte_api_root: + missing_roots.append("airbyte_api_root") + if not config_api_root: + missing_roots.append("config_api_root") + raise PyAirbyteInputError( + message="Self-managed login requires both API roots.", + context={"missing": ", ".join(missing_roots)}, + guidance="Provide both `--public-api-root` and `--config-api-root`.", + ) + + return CLOUD_API_ROOT, CLOUD_CONFIG_API_ROOT + + +def logout( + *, + credentials_file_path: Path = CREDENTIALS_FILE_PATH, +) -> None: + """Remove locally stored Airbyte credentials.""" + if credentials_file_path.exists(): + credentials_file_path.unlink() diff --git a/airbyte/cloud/client.py b/airbyte/cloud/client.py new file mode 100644 index 000000000..fbe68420a --- /dev/null +++ b/airbyte/cloud/client.py @@ -0,0 +1,322 @@ +# Copyright (c) 2024 Airbyte, Inc., all rights reserved. +"""PyAirbyte Cloud client.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, overload + +from airbyte import exceptions as exc +from airbyte._util import api_util +from airbyte.cloud._credentials import ( + CREDENTIALS_FILE_PATH, + CloudLoginResult, + _AirbyteCredentials, +) +from airbyte.cloud._credentials import logout as remove_credentials_file +from airbyte.cloud.organizations import CloudOrganization +from airbyte.cloud.workspaces import CloudWorkspace +from airbyte.exceptions import AirbyteMissingResourceError +from airbyte.secrets.base import SecretString + + +if TYPE_CHECKING: + from collections.abc import Callable + from pathlib import Path + + from airbyte._util import api_imports + + +@dataclass(init=False, kw_only=True) +class CloudClient: + """Authenticated client for Airbyte Cloud and self-managed Airbyte APIs.""" + + _credentials: _AirbyteCredentials + + def __init__( + self, + *, + client_id: str | SecretString | None = None, + client_secret: str | SecretString | None = None, + bearer_token: str | SecretString | None = None, + public_api_root: str | None = None, + config_api_root: str | None = None, + workspace_id: str | None = None, + organization_id: str | None = None, + ) -> None: + """Initialize a `CloudClient` from explicit auth values.""" + self._credentials = _AirbyteCredentials( + client_id=SecretString(client_id) if client_id else None, + client_secret=SecretString(client_secret) if client_secret else None, + bearer_token=SecretString(bearer_token) if bearer_token else None, + public_api_root=public_api_root or api_util.CLOUD_API_ROOT, + config_api_root=config_api_root, + workspace_id=workspace_id, + organization_id=organization_id, + ) + + @property + def client_id(self) -> SecretString | None: + """OAuth client ID used for authentication.""" + return self._credentials.client_id + + @property + def client_secret(self) -> SecretString | None: + """OAuth client secret used for authentication.""" + return self._credentials.client_secret + + @property + def bearer_token(self) -> SecretString | None: + """Bearer token used for authentication.""" + return self._credentials.bearer_token + + @property + def public_api_root(self) -> str: + """Airbyte Public API root.""" + return self._credentials.public_api_root + + @property + def config_api_root(self) -> str | None: + """Airbyte Config API root.""" + return self._credentials.config_api_root + + @property + def organization_id(self) -> str | None: + """Default organization ID for organization-scoped operations.""" + return self._credentials.organization_id + + @classmethod + def from_env( + cls, + *, + client_id: str | SecretString | None = None, + client_secret: str | SecretString | None = None, + bearer_token: str | SecretString | None = None, + organization_id: str | None = None, + public_api_root: str | None = None, + config_api_root: str | None = None, + ) -> CloudClient: + """Create a client from shared environment and credentials-file resolution.""" + credentials = _AirbyteCredentials.from_auth( + client_id=client_id, + client_secret=client_secret, + bearer_token=bearer_token, + organization_id=organization_id, + public_api_root=public_api_root, + config_api_root=config_api_root, + ) + return cls._from_credentials(credentials) + + @classmethod + def from_auth( + cls, + *, + organization_id: str | None = None, + client_id: str | SecretString | None = None, + client_secret: str | SecretString | None = None, + bearer_token: str | SecretString | None = None, + public_api_root: str | None = None, + config_api_root: str | None = None, + credentials_file_path: Path = CREDENTIALS_FILE_PATH, + ) -> CloudClient: + """Create a client from explicit inputs, env vars, and credentials file.""" + credentials = _AirbyteCredentials.from_auth( + organization_id=organization_id, + client_id=client_id, + client_secret=client_secret, + bearer_token=bearer_token, + public_api_root=public_api_root, + config_api_root=config_api_root, + credentials_file_path=credentials_file_path, + ) + return cls._from_credentials(credentials) + + @classmethod + def _from_credentials(cls, credentials: _AirbyteCredentials) -> CloudClient: + """Create a client from resolved Cloud credentials.""" + return cls( + client_id=credentials.client_id, + client_secret=credentials.client_secret, + bearer_token=credentials.bearer_token, + public_api_root=credentials.public_api_root, + config_api_root=credentials.config_api_root, + workspace_id=credentials.workspace_id, + organization_id=credentials.organization_id, + ) + + def login( + self, + *, + interactive: bool | None = None, + credentials_file_path: Path = CREDENTIALS_FILE_PATH, + ) -> CloudLoginResult: + """Log in to Airbyte and persist local credentials.""" + if interactive is True: + raise NotImplementedError("Interactive Airbyte Cloud login is not implemented.") + if self.client_id is not None and self.client_secret is not None: + return self._credentials.login(credentials_file_path=credentials_file_path) + if interactive is False: + raise exc.PyAirbyteInputError( + message="Client ID and client secret are both required.", + guidance="Provide both client ID and client secret for non-interactive login.", + ) + + raise NotImplementedError("Interactive Airbyte Cloud login is not implemented.") + + def logout( + self, + *, + credentials_file_path: Path = CREDENTIALS_FILE_PATH, + ) -> None: + """Log out by removing locally stored credentials.""" + remove_credentials_file(credentials_file_path=credentials_file_path) + + def get_workspace(self, workspace_id: str | None = None) -> CloudWorkspace: + """Create a `CloudWorkspace` using this client's credentials.""" + resolved_workspace_id = workspace_id or self._credentials.workspace_id + if not resolved_workspace_id: + raise exc.PyAirbyteInputError( + message="Workspace ID is required.", + guidance="Provide a workspace ID.", + ) + + credentials = self._credentials.with_workspace_id(resolved_workspace_id) + return CloudWorkspace( + workspace_id=credentials.workspace_id, + client_id=credentials.client_id, + client_secret=credentials.client_secret, + bearer_token=credentials.bearer_token, + api_root=credentials.public_api_root, + config_api_root=credentials.config_api_root, + ) + + @overload + def list_workspaces( + self, + name: str | None = None, + *, + organization_id: None = None, + name_contains: str | None = None, + name_filter: Callable[[str], bool] | None = None, + limit: int | None = None, + ) -> list[api_imports.WorkspaceResponse]: + raise NotImplementedError + + @overload + def list_workspaces( + self, + name: str | None = None, + *, + organization_id: str, + name_contains: str | None = None, + name_filter: Callable[[str], bool] | None = None, + limit: int | None = None, + ) -> list[dict[str, object]]: + raise NotImplementedError + + def list_workspaces( + self, + name: str | None = None, + *, + organization_id: str | None = None, + name_contains: str | None = None, + name_filter: Callable[[str], bool] | None = None, + limit: int | None = None, + ) -> list[api_imports.WorkspaceResponse] | list[dict[str, object]]: + """List workspaces available to this client.""" + if organization_id is not None or self.organization_id is not None: + resolved_organization_id = organization_id or self.organization_id + if not resolved_organization_id: + raise exc.PyAirbyteInputError( + message="Organization ID is required.", + guidance="Provide an organization ID.", + ) + workspaces = api_util.list_workspaces_in_organization( + organization_id=resolved_organization_id, + api_root=self.public_api_root, + config_api_root=self.config_api_root, + client_id=self.client_id, + client_secret=self.client_secret, + bearer_token=self.bearer_token, + name_contains=name_contains or name, + limit=None if name_filter is not None else limit, + ) + if name_filter is not None: + workspaces = [ + workspace + for workspace in workspaces + if name_filter(str(workspace.get("name", ""))) + ] + if limit is not None: + workspaces = workspaces[:limit] + return workspaces + if name_contains is not None: + name = name_contains + return api_util.list_workspaces( + workspace_id="", + api_root=self.public_api_root, + name=name, + name_filter=name_filter, + client_id=self.client_id, + client_secret=self.client_secret, + bearer_token=self.bearer_token, + limit=limit, + ) + + def get_organization( + self, + organization_id: str | None = None, + *, + organization_name: str | None = None, + ) -> CloudOrganization: + """Resolve an organization by ID or exact name.""" + if organization_id and organization_name: + raise exc.PyAirbyteInputError( + message="Provide either organization ID or organization name." + ) + if not organization_id and not organization_name: + raise exc.PyAirbyteInputError( + message="Organization ID or organization name is required." + ) + + organizations = api_util.list_organizations_for_user( + api_root=self.public_api_root, + client_id=self.client_id, + client_secret=self.client_secret, + bearer_token=self.bearer_token, + ) + if organization_id: + matching_organizations = [ + organization + for organization in organizations + if organization.organization_id == organization_id + ] + else: + matching_organizations = [ + organization + for organization in organizations + if organization.organization_name == organization_name + ] + + if not matching_organizations: + raise AirbyteMissingResourceError(resource_type="organization") + if len(matching_organizations) > 1: + raise exc.PyAirbyteInputError( + message="Organization name matches multiple organizations." + ) + + organization = matching_organizations[0] + + organization_credentials = self._credentials.with_organization_id( + organization.organization_id + ) + return CloudOrganization( + organization_id=organization.organization_id, + organization_name=organization.organization_name, + email=organization.email, + client_id=organization_credentials.client_id, + client_secret=organization_credentials.client_secret, + bearer_token=organization_credentials.bearer_token, + public_api_root=organization_credentials.public_api_root, + config_api_root=organization_credentials.config_api_root, + ) diff --git a/airbyte/cloud/organizations.py b/airbyte/cloud/organizations.py new file mode 100644 index 000000000..a6325c372 --- /dev/null +++ b/airbyte/cloud/organizations.py @@ -0,0 +1,111 @@ +# Copyright (c) 2024 Airbyte, Inc., all rights reserved. +"""PyAirbyte classes and methods for Airbyte Cloud organizations.""" + +from __future__ import annotations + +from typing import Any + +from airbyte._util import api_util +from airbyte.cloud._credentials import _AirbyteCredentials +from airbyte.secrets.base import SecretString + + +class CloudOrganization: + """Information about an organization in Airbyte Cloud. + + This class provides lazy loading of organization attributes including billing status. + It is typically created via `CloudWorkspace.get_organization()`. + """ + + def __init__( + self, + organization_id: str, + organization_name: str | None = None, + email: str | None = None, + *, + client_id: str | SecretString | None = None, + client_secret: str | SecretString | None = None, + bearer_token: str | SecretString | None = None, + public_api_root: str | None = None, + config_api_root: str | None = None, + ) -> None: + """Initialize a `CloudOrganization`.""" + self.organization_id = organization_id + """The organization ID.""" + + self._organization_name = organization_name + """Display name of the organization.""" + + self._email = email + """Email associated with the organization.""" + + self._credentials = _AirbyteCredentials( + client_id=SecretString(client_id) if client_id else None, + client_secret=SecretString(client_secret) if client_secret else None, + bearer_token=SecretString(bearer_token) if bearer_token else None, + public_api_root=public_api_root or api_util.CLOUD_API_ROOT, + config_api_root=config_api_root, + organization_id=organization_id, + ) + self._organization_info: dict[str, Any] | None = None + self._organization_info_fetch_failed: bool = False + + def _fetch_organization_info(self, *, force_refresh: bool = False) -> dict[str, Any]: + """Fetch and cache organization info including billing status.""" + if force_refresh: + self._organization_info_fetch_failed = False + + if self._organization_info_fetch_failed and self._organization_info is None: + return {} + + if not force_refresh and self._organization_info is not None: + return self._organization_info + + try: + self._organization_info = api_util.get_organization_info( + organization_id=self.organization_id, + api_root=self._credentials.public_api_root, + config_api_root=self._credentials.config_api_root, + client_id=self._credentials.client_id, + client_secret=self._credentials.client_secret, + bearer_token=self._credentials.bearer_token, + ) + except Exception: + if self._organization_info is None: + self._organization_info_fetch_failed = True + return self._organization_info or {} + else: + return self._organization_info + + @property + def organization_name(self) -> str | None: + """Display name of the organization.""" + if self._organization_name is not None: + return self._organization_name + info = self._fetch_organization_info() + return info.get("organizationName") + + @property + def email(self) -> str | None: + """Email associated with the organization.""" + if self._email is not None: + return self._email + info = self._fetch_organization_info() + return info.get("email") + + @property + def payment_status(self) -> str | None: + """Payment status of the organization.""" + info = self._fetch_organization_info() + return (info.get("billing") or {}).get("paymentStatus") + + @property + def subscription_status(self) -> str | None: + """Subscription status of the organization.""" + info = self._fetch_organization_info() + return (info.get("billing") or {}).get("subscriptionStatus") + + @property + def is_account_locked(self) -> bool: + """Whether the account is locked due to billing issues.""" + return api_util.is_account_locked(self.payment_status, self.subscription_status) diff --git a/airbyte/cloud/workspaces.py b/airbyte/cloud/workspaces.py index 1c33f3f42..6358830c8 100644 --- a/airbyte/cloud/workspaces.py +++ b/airbyte/cloud/workspaces.py @@ -45,14 +45,7 @@ from airbyte import exceptions as exc from airbyte._util import api_util, text_util from airbyte._util.api_util import get_web_url_root -from airbyte.cloud.auth import ( - resolve_cloud_api_url, - resolve_cloud_bearer_token, - resolve_cloud_client_id, - resolve_cloud_client_secret, - resolve_cloud_config_api_url, - resolve_cloud_workspace_id, -) +from airbyte.cloud._credentials import _AirbyteCredentials from airbyte.cloud.client_config import CloudClientConfig from airbyte.cloud.connections import CloudConnection from airbyte.cloud.connectors import ( @@ -60,158 +53,20 @@ CloudSource, CustomCloudSourceDefinition, ) +from airbyte.cloud.organizations import CloudOrganization from airbyte.destinations.base import Destination from airbyte.exceptions import AirbyteError -from airbyte.secrets.base import SecretString if TYPE_CHECKING: from collections.abc import Callable + from airbyte._util.api_imports import WorkspaceResponse + from airbyte.secrets.base import SecretString from airbyte.sources.base import Source -class CloudOrganization: - """Information about an organization in Airbyte Cloud. - - This class provides lazy loading of organization attributes including billing status. - It is typically created via CloudWorkspace.get_organization(). - """ - - def __init__( - self, - organization_id: str, - organization_name: str | None = None, - email: str | None = None, - *, - api_root: str = api_util.CLOUD_API_ROOT, - client_id: SecretString | None = None, - client_secret: SecretString | None = None, - bearer_token: SecretString | None = None, - config_api_root: str | None = None, - ) -> None: - """Initialize a CloudOrganization. - - Args: - organization_id: The organization ID. - organization_name: Display name of the organization. - email: Email associated with the organization. - api_root: The API root URL. - client_id: OAuth client ID for authentication. - client_secret: OAuth client secret for authentication. - bearer_token: Bearer token for authentication (alternative to client credentials). - config_api_root: Optional Config API root URL. - """ - self.organization_id = organization_id - """The organization ID.""" - - self._organization_name = organization_name - """Display name of the organization.""" - - self._email = email - """Email associated with the organization.""" - - self._api_root = api_root - self._config_api_root = config_api_root - self._client_id = client_id - self._client_secret = client_secret - self._bearer_token = bearer_token - - # Cached organization info (billing, etc.) - self._organization_info: dict[str, Any] | None = None - # Flag to remember if fetching organization info failed (e.g., permission issues) - self._organization_info_fetch_failed: bool = False - - def _fetch_organization_info(self, *, force_refresh: bool = False) -> dict[str, Any]: - """Fetch and cache organization info including billing status. - - If fetching fails (e.g., due to permission issues), the failure is cached and - subsequent calls will return an empty dict without retrying. - - Args: - force_refresh: If True, always fetch from the API even if cached. - - Returns: - Dictionary containing organization info including billing data. - Returns empty dict if fetching failed or is not permitted. - """ - # Reset failure flag if force_refresh is requested - if force_refresh: - self._organization_info_fetch_failed = False - - # If we already know fetching failed, return empty dict without retrying - if self._organization_info_fetch_failed: - return {} - - if not force_refresh and self._organization_info is not None: - return self._organization_info - - try: - self._organization_info = api_util.get_organization_info( - organization_id=self.organization_id, - api_root=self._api_root, - config_api_root=self._config_api_root, - client_id=self._client_id, - client_secret=self._client_secret, - bearer_token=self._bearer_token, - ) - except Exception: - # Cache the failure so we don't retry on subsequent property accesses - self._organization_info_fetch_failed = True - return {} - else: - return self._organization_info - - @property - def organization_name(self) -> str | None: - """Display name of the organization.""" - if self._organization_name is not None: - return self._organization_name - # Try to fetch from API if not set (returns empty dict on failure) - info = self._fetch_organization_info() - return info.get("organizationName") - - @property - def email(self) -> str | None: - """Email associated with the organization.""" - if self._email is not None: - return self._email - # Try to fetch from API if not set (returns empty dict on failure) - info = self._fetch_organization_info() - return info.get("email") - - @property - def payment_status(self) -> str | None: - """Payment status of the organization. - - Possible values: 'uninitialized', 'okay', 'grace_period', 'disabled', 'locked', 'manual'. - When 'disabled', syncs are blocked due to unpaid invoices. - Returns None if billing info is not available (e.g., due to permission issues). - """ - info = self._fetch_organization_info() - return (info.get("billing") or {}).get("paymentStatus") - - @property - def subscription_status(self) -> str | None: - """Subscription status of the organization. - - Possible values: 'pre_subscription', 'subscribed', 'unsubscribed'. - Returns None if billing info is not available (e.g., due to permission issues). - """ - info = self._fetch_organization_info() - return (info.get("billing") or {}).get("subscriptionStatus") - - @property - def is_account_locked(self) -> bool: - """Whether the account is locked due to billing issues. - - Returns True if payment_status is 'disabled'/'locked' or subscription_status is - 'unsubscribed'. Defaults to False unless we have affirmative evidence of a locked state. - """ - return api_util.is_account_locked(self.payment_status, self.subscription_status) - - -@dataclass(kw_only=True) +@dataclass(init=False, kw_only=True) # noqa: PLR0904 # Core cloud API facade. class CloudWorkspace: """A remote workspace on the Airbyte Cloud. @@ -241,28 +96,52 @@ class CloudWorkspace: """ workspace_id: str - client_id: SecretString | None = None - client_secret: SecretString | None = None - api_root: str = api_util.CLOUD_API_ROOT - config_api_root: str | None = None + client_id: SecretString | None + client_secret: SecretString | None + api_root: str + config_api_root: str | None """The Config API root URL.""" - bearer_token: SecretString | None = None + bearer_token: SecretString | None - # Internal credentials object (set in __post_init__, excluded from __init__) - _credentials: CloudClientConfig | None = field(default=None, init=False, repr=False) + # Internal credentials objects (set in __init__, excluded from repr) + _credentials: _AirbyteCredentials = field(init=False, repr=False) + _client_config: CloudClientConfig = field(init=False, repr=False) - def __post_init__(self) -> None: + def __init__( + self, + *, + workspace_id: str | None = None, + client_id: str | SecretString | None = None, + client_secret: str | SecretString | None = None, + api_root: str | None = None, + config_api_root: str | None = None, + bearer_token: str | SecretString | None = None, + ) -> None: """Validate and initialize credentials.""" - # Wrap secrets in SecretString if provided - if self.client_id is not None: - self.client_id = SecretString(self.client_id) - if self.client_secret is not None: - self.client_secret = SecretString(self.client_secret) - if self.bearer_token is not None: - self.bearer_token = SecretString(self.bearer_token) + credentials = _AirbyteCredentials.from_auth( + workspace_id=workspace_id, + client_id=client_id, + client_secret=client_secret, + bearer_token=bearer_token, + public_api_root=api_root, + config_api_root=config_api_root, + ) + if not credentials.workspace_id: + raise exc.PyAirbyteInputError( + message="Workspace ID is required.", + guidance="Provide a workspace ID.", + ) + + self._credentials = credentials + self.workspace_id = credentials.workspace_id or "" + self.client_id = credentials.client_id + self.client_secret = credentials.client_secret + self.bearer_token = credentials.bearer_token + self.api_root = credentials.public_api_root + self.config_api_root = credentials.config_api_root # Create internal CloudClientConfig object (validates mutual exclusivity) - self._credentials = CloudClientConfig( + self._client_config = CloudClientConfig( client_id=self.client_id, client_secret=self.client_secret, bearer_token=self.bearer_token, @@ -321,26 +200,10 @@ def from_env( workspace = CloudWorkspace.from_env(workspace_id="your-workspace-id") ``` """ - resolved_api_root = resolve_cloud_api_url(api_root) - resolved_config_api_root = resolve_cloud_config_api_url(config_api_root) - - # Try bearer token first - bearer_token = resolve_cloud_bearer_token() - if bearer_token: - return cls( - workspace_id=resolve_cloud_workspace_id(workspace_id), - bearer_token=bearer_token, - api_root=resolved_api_root, - config_api_root=resolved_config_api_root, - ) - - # Fall back to client credentials return cls( - workspace_id=resolve_cloud_workspace_id(workspace_id), - client_id=resolve_cloud_client_id(), - client_secret=resolve_cloud_client_secret(), - api_root=resolved_api_root, - config_api_root=resolved_config_api_root, + workspace_id=workspace_id, + api_root=api_root, + config_api_root=config_api_root, ) @property @@ -425,14 +288,15 @@ def get_organization( ) return None + organization_credentials = self._credentials.with_organization_id(organization_id) return CloudOrganization( organization_id=organization_id, organization_name=organization_name, - api_root=self.api_root, - config_api_root=self.config_api_root, - client_id=self.client_id, - client_secret=self.client_secret, - bearer_token=self.bearer_token, + client_id=organization_credentials.client_id, + client_secret=organization_credentials.client_secret, + bearer_token=organization_credentials.bearer_token, + public_api_root=organization_credentials.public_api_root, + config_api_root=organization_credentials.config_api_root, ) # Test connection and creds @@ -773,7 +637,24 @@ def permanently_delete_connection( safe_mode=safe_mode, ) - # List sources, destinations, and connections + # List workspaces, sources, destinations, and connections + + def list_workspaces( + self, + name: str | None = None, + *, + name_filter: Callable | None = None, + ) -> list[WorkspaceResponse]: + """List workspaces available to the current credentials.""" + return api_util.list_workspaces( + workspace_id="", + api_root=self.api_root, + name=name, + name_filter=name_filter, + client_id=self.client_id, + client_secret=self.client_secret, + bearer_token=self.bearer_token, + ) def list_connections( self, diff --git a/airbyte/constants.py b/airbyte/constants.py index aa7632de6..a2716f219 100644 --- a/airbyte/constants.py +++ b/airbyte/constants.py @@ -232,6 +232,9 @@ def _str_to_bool(value: str) -> bool: CLOUD_WORKSPACE_ID_ENV_VAR: str = "AIRBYTE_CLOUD_WORKSPACE_ID" """The environment variable name for the Airbyte Cloud workspace ID.""" +CLOUD_ORGANIZATION_ID_ENV_VAR: str = "AIRBYTE_CLOUD_ORGANIZATION_ID" +"""The environment variable name for the Airbyte Cloud organization ID.""" + CLOUD_BEARER_TOKEN_ENV_VAR: str = "AIRBYTE_CLOUD_BEARER_TOKEN" """The environment variable name for the Airbyte Cloud bearer token. diff --git a/airbyte/mcp/cloud.py b/airbyte/mcp/cloud.py index d625f975e..0759af551 100644 --- a/airbyte/mcp/cloud.py +++ b/airbyte/mcp/cloud.py @@ -20,9 +20,11 @@ from airbyte import cloud, get_destination, get_source from airbyte._util import api_util +from airbyte.cloud.client import CloudClient from airbyte.cloud.connectors import CustomCloudSourceDefinition from airbyte.cloud.constants import FAILED_STATUSES -from airbyte.cloud.workspaces import CloudOrganization, CloudWorkspace +from airbyte.cloud.organizations import CloudOrganization +from airbyte.cloud.workspaces import CloudWorkspace from airbyte.constants import ( MCP_CONFIG_API_URL, MCP_CONFIG_BEARER_TOKEN, @@ -264,19 +266,28 @@ def _get_cloud_workspace( guidance="Set AIRBYTE_CLOUD_WORKSPACE_ID env var or pass workspace_id parameter.", ) + return _get_cloud_client(ctx).get_workspace(resolved_workspace_id) + + +def _get_cloud_client( + ctx: Context, + *, + organization_id: str | None = None, +) -> CloudClient: + """Get an authenticated `CloudClient` from MCP config.""" bearer_token = get_mcp_config(ctx, MCP_CONFIG_BEARER_TOKEN) client_id = get_mcp_config(ctx, MCP_CONFIG_CLIENT_ID) client_secret = get_mcp_config(ctx, MCP_CONFIG_CLIENT_SECRET) - api_url = get_mcp_config(ctx, MCP_CONFIG_API_URL) or api_util.CLOUD_API_ROOT + api_url = get_mcp_config(ctx, MCP_CONFIG_API_URL) config_api_url = get_mcp_config(ctx, MCP_CONFIG_CONFIG_API_URL) - return CloudWorkspace( - workspace_id=resolved_workspace_id, - client_id=SecretString(client_id) if client_id else None, - client_secret=SecretString(client_secret) if client_secret else None, - bearer_token=SecretString(bearer_token) if bearer_token else None, - api_root=api_url, + return CloudClient( + client_id=client_id, + client_secret=client_secret, + bearer_token=bearer_token, + public_api_root=api_url, config_api_root=config_api_url, + organization_id=organization_id, ) @@ -1314,62 +1325,13 @@ def _resolve_organization( message="Either 'organization_id' or 'organization_name' must be provided." ) - # Get all organizations for the user - orgs = api_util.list_organizations_for_user( - api_root=api_root, + return CloudClient( client_id=client_id, client_secret=client_secret, bearer_token=bearer_token, - ) - - org_response: api_util.models.OrganizationResponse | None = None - - if organization_id: - # Find by ID - matching_orgs = [org for org in orgs if org.organization_id == organization_id] - if not matching_orgs: - raise AirbyteMissingResourceError( - resource_type="organization", - context={ - "organization_id": organization_id, - "message": f"No organization found with ID '{organization_id}' " - "for the current user.", - }, - ) - org_response = matching_orgs[0] - else: - # Find by exact name match (case-sensitive) - matching_orgs = [org for org in orgs if org.organization_name == organization_name] - - if not matching_orgs: - raise AirbyteMissingResourceError( - resource_type="organization", - context={ - "organization_name": organization_name, - "message": f"No organization found with exact name '{organization_name}' " - "for the current user.", - }, - ) - - if len(matching_orgs) > 1: - raise PyAirbyteInputError( - message=f"Multiple organizations found with name '{organization_name}'. " - "Please use 'organization_id' instead to specify the exact organization." - ) - - org_response = matching_orgs[0] - - # Return a CloudOrganization with credentials for lazy loading of billing info - return CloudOrganization( - organization_id=org_response.organization_id, - organization_name=org_response.organization_name, - email=org_response.email, - api_root=api_root, + public_api_root=api_root, config_api_root=config_api_root, - client_id=client_id, - client_secret=client_secret, - bearer_token=bearer_token, - ) + ).get_organization(organization_id=organization_id, organization_name=organization_name) def _resolve_organization_id( @@ -1444,29 +1406,20 @@ def list_cloud_workspaces( This tool will NOT list workspaces across all organizations - you must specify which organization to list workspaces from. """ - bearer_token = get_mcp_config(ctx, MCP_CONFIG_BEARER_TOKEN) - client_id = get_mcp_config(ctx, MCP_CONFIG_CLIENT_ID) - client_secret = get_mcp_config(ctx, MCP_CONFIG_CLIENT_SECRET) - api_url = get_mcp_config(ctx, MCP_CONFIG_API_URL) or api_util.CLOUD_API_ROOT - config_api_url = get_mcp_config(ctx, MCP_CONFIG_CONFIG_API_URL) + client = _get_cloud_client(ctx) resolved_org_id = _resolve_organization_id( organization_id=organization_id, organization_name=organization_name, - api_root=api_url, - client_id=SecretString(client_id) if client_id else None, - client_secret=SecretString(client_secret) if client_secret else None, - bearer_token=SecretString(bearer_token) if bearer_token else None, - config_api_root=config_api_url, + api_root=client.public_api_root, + client_id=client.client_id, + client_secret=client.client_secret, + bearer_token=client.bearer_token, + config_api_root=client.config_api_root, ) - workspaces = api_util.list_workspaces_in_organization( + workspaces = client.list_workspaces( organization_id=resolved_org_id, - api_root=api_url, - client_id=SecretString(client_id) if client_id else None, - client_secret=SecretString(client_secret) if client_secret else None, - bearer_token=SecretString(bearer_token) if bearer_token else None, - config_api_root=config_api_url, name_contains=name_contains, limit=limit, ) @@ -1512,20 +1465,9 @@ def describe_cloud_organization( Requires either organization_id OR organization_name (exact match) to be provided. This tool is useful for looking up an organization's ID from its name, or vice versa. """ - bearer_token = get_mcp_config(ctx, MCP_CONFIG_BEARER_TOKEN) - client_id = get_mcp_config(ctx, MCP_CONFIG_CLIENT_ID) - client_secret = get_mcp_config(ctx, MCP_CONFIG_CLIENT_SECRET) - api_url = get_mcp_config(ctx, MCP_CONFIG_API_URL) or api_util.CLOUD_API_ROOT - config_api_url = get_mcp_config(ctx, MCP_CONFIG_CONFIG_API_URL) - - org = _resolve_organization( + org = _get_cloud_client(ctx).get_organization( organization_id=organization_id, organization_name=organization_name, - api_root=api_url, - client_id=SecretString(client_id) if client_id else None, - client_secret=SecretString(client_secret) if client_secret else None, - bearer_token=SecretString(bearer_token) if bearer_token else None, - config_api_root=config_api_url, ) # CloudOrganization has lazy loading of billing properties diff --git a/tests/unit_tests/test_cloud_credentials.py b/tests/unit_tests/test_cloud_credentials.py new file mode 100644 index 000000000..cf760f140 --- /dev/null +++ b/tests/unit_tests/test_cloud_credentials.py @@ -0,0 +1,257 @@ +# Copyright (c) 2024 Airbyte, Inc., all rights reserved. +from __future__ import annotations + +from pathlib import Path +import sys + +import pytest +import yaml + +from airbyte import constants +from airbyte._util import api_util +from airbyte.cloud import _credentials as cloud_credentials +from airbyte.cloud.client import CloudClient +from airbyte.cloud.organizations import CloudOrganization +from airbyte.exceptions import PyAirbyteInputError +from airbyte.secrets.base import SecretString + + +def test_login_with_client_credentials_writes_bearer_token( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + credentials_file_path = tmp_path / "credentials" + + def fake_get_bearer_token( + *, + client_id: SecretString, + client_secret: SecretString, + api_root: str, + ) -> str: + assert str(client_id) == "test-client-id" + assert str(client_secret) == "test-client-secret" + assert api_root == "https://api.example.com/v1" + return "test-bearer-token" + + monkeypatch.setattr(cloud_credentials, "get_bearer_token", fake_get_bearer_token) + + result = cloud_credentials._AirbyteCredentials( + client_id=SecretString("test-client-id"), + client_secret=SecretString("test-client-secret"), + bearer_token=None, + public_api_root="https://api.example.com/v1", + config_api_root="https://config.example.com/api/v1", + ).login(credentials_file_path=credentials_file_path) + + saved_credentials = yaml.safe_load( + credentials_file_path.read_text(encoding="utf-8") + ) + assert result.credentials_file_path == credentials_file_path + assert result.airbyte_api_root == "https://api.example.com/v1" + assert result.config_api_root == "https://config.example.com/api/v1" + assert saved_credentials == { + "airbyte_api_root": "https://api.example.com/v1", + "bearer_token": "test-bearer-token", + "config_api_root": "https://config.example.com/api/v1", + } + if sys.platform != "win32": + assert credentials_file_path.stat().st_mode & 0o777 == 0o600 + + +def test_login_without_client_credentials_raises_interactive_flow_error() -> None: + with pytest.raises(PyAirbyteInputError, match="Interactive Airbyte Cloud login"): + cloud_credentials._AirbyteCredentials( + client_id=None, + client_secret=None, + bearer_token=None, + public_api_root=constants.CLOUD_API_ROOT, + config_api_root=constants.CLOUD_CONFIG_API_ROOT, + ).login() + + +def test_login_with_partial_client_credentials_raises() -> None: + with pytest.raises(PyAirbyteInputError, match="Client ID and client secret"): + cloud_credentials._AirbyteCredentials( + client_id=SecretString("test-client-id"), + client_secret=None, + bearer_token=None, + public_api_root=constants.CLOUD_API_ROOT, + config_api_root=constants.CLOUD_CONFIG_API_ROOT, + ).login() + + +def test_self_managed_login_requires_both_api_roots() -> None: + with pytest.raises( + PyAirbyteInputError, match="Self-managed login requires both API roots" + ): + cloud_credentials._AirbyteCredentials( + client_id=SecretString("test-client-id"), + client_secret=SecretString("test-client-secret"), + bearer_token=None, + public_api_root="https://api.example.com/v1", + config_api_root=None, + ).login() + + +def test_login_with_client_credentials_uses_cloud_default_roots( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + credentials_file_path = tmp_path / "credentials" + + def fake_get_bearer_token( + *, + client_id: SecretString, + client_secret: SecretString, + api_root: str, + ) -> str: + assert str(client_id) == "test-client-id" + assert str(client_secret) == "test-client-secret" + assert api_root == constants.CLOUD_API_ROOT + return "test-bearer-token" + + monkeypatch.setattr(cloud_credentials, "get_bearer_token", fake_get_bearer_token) + + result = cloud_credentials._AirbyteCredentials( + client_id=SecretString("test-client-id"), + client_secret=SecretString("test-client-secret"), + bearer_token=None, + public_api_root=constants.CLOUD_API_ROOT, + config_api_root=None, + ).login(credentials_file_path=credentials_file_path) + + assert result.airbyte_api_root == constants.CLOUD_API_ROOT + assert result.config_api_root == constants.CLOUD_CONFIG_API_ROOT + + +def test_cloud_client_login_uses_cloud_default_roots( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + credentials_file_path = tmp_path / "credentials" + + def fake_get_bearer_token( + *, + client_id: SecretString, + client_secret: SecretString, + api_root: str, + ) -> str: + assert str(client_id) == "test-client-id" + assert str(client_secret) == "test-client-secret" + assert api_root == constants.CLOUD_API_ROOT + return "test-bearer-token" + + monkeypatch.setattr(cloud_credentials, "get_bearer_token", fake_get_bearer_token) + + result = CloudClient( + client_id=SecretString("test-client-id"), + client_secret=SecretString("test-client-secret"), + ).login(credentials_file_path=credentials_file_path) + + assert result.airbyte_api_root == constants.CLOUD_API_ROOT + assert result.config_api_root == constants.CLOUD_CONFIG_API_ROOT + + +def test_airbyte_credentials_from_auth_uses_pyairbyte_secret_lookup( + monkeypatch: pytest.MonkeyPatch, +) -> None: + secrets = { + constants.CLOUD_BEARER_TOKEN_ENV_VAR: SecretString("test-bearer-token"), + constants.CLOUD_WORKSPACE_ID_ENV_VAR: SecretString("test-workspace-id"), + } + + def fake_try_get_secret( + secret_name: str, + /, + *, + default: str | SecretString | None = None, + **_: object, + ) -> SecretString | str | None: + return secrets.get(secret_name, default) + + monkeypatch.setattr(cloud_credentials, "try_get_secret", fake_try_get_secret) + + credentials = cloud_credentials._AirbyteCredentials.from_auth() + + assert credentials.bearer_token == "test-bearer-token" + assert credentials.workspace_id == "test-workspace-id" + + +def test_cloud_client_list_workspaces_forwards_limit( + monkeypatch: pytest.MonkeyPatch, +) -> None: + captured_limit = None + + def fake_list_workspaces( + *, + limit: int | None = None, + **_: object, + ) -> list[object]: + nonlocal captured_limit + captured_limit = limit + return [] + + monkeypatch.setattr(api_util, "list_workspaces", fake_list_workspaces) + + CloudClient(bearer_token="token").list_workspaces(limit=3) + + assert captured_limit == 3 + + +def test_cloud_client_list_workspaces_in_organization_applies_name_filter_before_limit( + monkeypatch: pytest.MonkeyPatch, +) -> None: + captured_limit = None + + def fake_list_workspaces_in_organization( + *, + limit: int | None = None, + **_: object, + ) -> list[dict[str, object]]: + nonlocal captured_limit + captured_limit = limit + return [ + {"name": "miss"}, + {"name": "target-one"}, + {"name": "target-two"}, + ] + + monkeypatch.setattr( + api_util, + "list_workspaces_in_organization", + fake_list_workspaces_in_organization, + ) + + result = CloudClient( + bearer_token="token", + organization_id="organization-id", + ).list_workspaces( + name_filter=lambda name: name.startswith("target"), + limit=1, + ) + + assert captured_limit is None + assert result == [{"name": "target-one"}] + + +def test_cloud_organization_fetch_returns_cached_info_after_refresh_failure( + monkeypatch: pytest.MonkeyPatch, +) -> None: + responses: list[dict[str, object] | Exception] = [ + {"organizationName": "cached"}, + RuntimeError("temporary error"), + ] + + def fake_get_organization_info(**_: object) -> dict[str, object]: + response = responses.pop(0) + if isinstance(response, Exception): + raise response + return response + + monkeypatch.setattr(api_util, "get_organization_info", fake_get_organization_info) + organization = CloudOrganization("organization-id", bearer_token="token") + + assert organization._fetch_organization_info() == {"organizationName": "cached"} # noqa: SLF001 + assert organization._fetch_organization_info(force_refresh=True) == { # noqa: SLF001 + "organizationName": "cached" + }