diff --git a/src/drs/auth.py b/src/drs/auth.py index 051f365..9e1b13a 100644 --- a/src/drs/auth.py +++ b/src/drs/auth.py @@ -34,32 +34,110 @@ class DrsConfig(BaseModel): project_id: str +def _extract_flat_values(raw: dict[str, Any]) -> dict[str, Any]: + """Extract config values from a flat (legacy) config dict.""" + values = { + "uri": raw.get("uri", raw.get("endpoint")), + "pat": raw.get("pat", raw.get("token")), + "project_id": raw.get("project_id", raw.get("projectId")), + } + return {k: v for k, v in values.items() if v is not None} + + +def _resolve_profile(raw: dict[str, Any], profile: str | None) -> dict[str, Any]: + """Resolve a profile from the config file. + + If the file uses the new ``profiles`` format, look up the requested + profile (falling back to ``default_profile``). If the file is the + legacy flat format (no ``profiles`` key), treat it as a single + implicit profile. + """ + profiles = raw.get("profiles") + if profiles is None: + # Legacy flat format — treat as single implicit profile + return _extract_flat_values(raw) + + # New profiles format + name = profile or raw.get("default_profile") or next(iter(profiles), None) + if not name or name not in profiles: + return {} + return _extract_flat_values(profiles[name]) + + +def read_config_file(config_path: Path | None = None) -> dict[str, Any]: + """Read and return the raw YAML config dict. Returns {} if the file doesn't exist.""" + path = config_path or DEFAULT_CONFIG_PATH + if not path.exists(): + return {} + with path.open() as f: + return yaml.safe_load(f) or {} + + +def list_profiles(config_path: Path | None = None) -> dict[str, dict[str, Any]]: + """Return all profiles from the config file. + + For legacy flat configs, returns a single ``default`` profile. + """ + raw = read_config_file(config_path) + profiles = raw.get("profiles") + if profiles is None: + flat = _extract_flat_values(raw) + if flat: + return {"default": flat} + return {} + return {name: _extract_flat_values(vals) for name, vals in profiles.items()} + + +def get_default_profile_name(config_path: Path | None = None) -> str | None: + """Return the name of the default profile, or None.""" + raw = read_config_file(config_path) + profiles = raw.get("profiles") + if profiles is None: + # Legacy flat format — the implicit profile name is "default" + flat = _extract_flat_values(raw) + return "default" if flat else None + return raw.get("default_profile") or next(iter(profiles), None) + + +def set_default_profile(profile_name: str, config_path: Path | None = None) -> None: + """Set the default profile in the config file.""" + path = config_path or DEFAULT_CONFIG_PATH + raw = read_config_file(path) + profiles = raw.get("profiles", {}) + if profile_name not in profiles: + raise ValueError(f"Profile '{profile_name}' not found in config") + raw["default_profile"] = profile_name + _write_raw_config(raw, path) + + +def _write_raw_config(raw: dict[str, Any], config_path: Path) -> None: + """Write a raw config dict to YAML.""" + config_path.parent.mkdir(parents=True, exist_ok=True) + header = "# Dremio CLI config — generated by 'dremio setup'\n# PAT is stored in plaintext. Keep this file private (mode 600).\n" + config_path.write_text(header + yaml.dump(raw, default_flow_style=False, sort_keys=False)) + config_path.chmod(0o600) + + def load_config( config_path: Path | None = None, *, + profile: str | None = None, cli_token: str | None = None, cli_project_id: str | None = None, cli_uri: str | None = None, ) -> DrsConfig: - """Load config with resolution order: CLI args > env vars > config file > defaults. + """Load config with resolution order: CLI args > env vars > config file profile > defaults. Authentication priority: 1. --token CLI arg 2. DREMIO_TOKEN / DREMIO_PAT env var - 3. Config file pat/token field + 3. Config file profile (selected via --profile / DREMIO_PROFILE / default_profile) """ # -- Config file (lowest priority) -- - file_values: dict[str, Any] = {} - path = config_path or DEFAULT_CONFIG_PATH - if path.exists(): - with path.open() as f: - raw = yaml.safe_load(f) or {} - file_values = { - "uri": raw.get("uri", raw.get("endpoint")), - "pat": raw.get("pat", raw.get("token")), - "project_id": raw.get("project_id", raw.get("projectId")), - } - file_values = {k: v for k, v in file_values.items() if v is not None} + raw = read_config_file(config_path) + # Profile selection: CLI --profile > DREMIO_PROFILE env var > default_profile in file + effective_profile = profile or os.environ.get("DREMIO_PROFILE") + file_values = _resolve_profile(raw, effective_profile) # -- Env vars (override file) -- env_values: dict[str, Any] = {} diff --git a/src/drs/cli.py b/src/drs/cli.py index 8a86e4d..98932b4 100644 --- a/src/drs/cli.py +++ b/src/drs/cli.py @@ -30,6 +30,7 @@ from drs.client import DremioClient from drs.commands import ( chat, + context, engine, folder, grant, @@ -68,6 +69,7 @@ app.add_typer(grant.app, name="grant") app.add_typer(project.app, name="project") app.add_typer(chat.app, name="chat") +app.add_typer(context.app, name="context") app.command("setup")(setup.setup_command) # Global state for config @@ -79,6 +81,7 @@ def main( ctx: typer.Context, config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"), + profile: str | None = typer.Option(None, "--profile", "-p", help="Named profile to use from config file"), token: str | None = typer.Option(None, "--token", help="Dremio personal access token (PAT)"), project_id: str | None = typer.Option(None, "--project-id", help="Dremio Cloud project ID"), uri: str | None = typer.Option( @@ -110,6 +113,7 @@ def main( global _cli_opts _cli_opts = { "config_path": Path(config) if config else None, + "profile": profile, "cli_token": token, "cli_project_id": project_id, "cli_uri": uri, @@ -126,6 +130,7 @@ def get_config() -> DrsConfig: try: _config = load_config( _cli_opts.get("config_path"), + profile=_cli_opts.get("profile"), cli_token=_cli_opts.get("cli_token"), cli_project_id=_cli_opts.get("cli_project_id"), cli_uri=_cli_opts.get("cli_uri"), diff --git a/src/drs/commands/context.py b/src/drs/commands/context.py new file mode 100644 index 0000000..bf1b892 --- /dev/null +++ b/src/drs/commands/context.py @@ -0,0 +1,113 @@ +# +# Copyright (C) 2017-2026 Dremio Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""dremio context — switch between named profiles.""" + +from __future__ import annotations + +from pathlib import Path + +import typer +from rich.console import Console +from rich.table import Table + +from drs.auth import ( + DEFAULT_CONFIG_PATH, + DEFAULT_URI, + get_default_profile_name, + list_profiles, + set_default_profile, +) + +app = typer.Typer( + help="Manage named configuration profiles.", + context_settings={"help_option_names": ["-h", "--help"]}, +) + +console = Console() +err_console = Console(stderr=True) + + +def _get_config_path(ctx: typer.Context) -> Path: + """Resolve config path from the global --config flag.""" + if ctx.obj and ctx.obj.get("config_path"): + return ctx.obj["config_path"] + return DEFAULT_CONFIG_PATH + + +@app.command("list") +def cli_list(ctx: typer.Context) -> None: + """List all configured profiles.""" + config_path = _get_config_path(ctx) + profiles = list_profiles(config_path) + + if not profiles: + err_console.print( + "[yellow]No profiles configured.[/yellow]\nRun [bold cyan]dremio setup[/bold cyan] to create one." + ) + raise typer.Exit(1) + + default_name = get_default_profile_name(config_path) + + table = Table(show_header=True, header_style="bold") + table.add_column("") + table.add_column("Profile") + table.add_column("Region") + table.add_column("Project ID") + + for name, values in profiles.items(): + is_active = name == default_name + marker = "*" if is_active else " " + uri = values.get("uri", DEFAULT_URI) + region = "EU" if "eu.dremio" in uri else "US" + project_id = values.get("project_id", "—") + style = "bold" if is_active else "" + table.add_row(marker, name, region, project_id, style=style) + + console.print(table) + + +@app.command("use") +def cli_use( + ctx: typer.Context, + name: str = typer.Argument(help="Profile name to set as default"), +) -> None: + """Switch the default profile.""" + config_path = _get_config_path(ctx) + try: + set_default_profile(name, config_path) + except ValueError as exc: + err_console.print(f"[red]{exc}[/red]") + profiles = list_profiles(config_path) + if profiles: + err_console.print(f"Available profiles: {', '.join(profiles)}") + raise typer.Exit(1) + + console.print(f"Switched to profile [bold]{name}[/bold].") + + +@app.command("current") +def cli_current(ctx: typer.Context) -> None: + """Show the active (default) profile name.""" + config_path = _get_config_path(ctx) + name = get_default_profile_name(config_path) + + if not name: + err_console.print( + "[yellow]No profiles configured.[/yellow]\nRun [bold cyan]dremio setup[/bold cyan] to create one." + ) + raise typer.Exit(1) + + console.print(name) diff --git a/src/drs/commands/reflection.py b/src/drs/commands/reflection.py index 2171ea2..79b0ead 100644 --- a/src/drs/commands/reflection.py +++ b/src/drs/commands/reflection.py @@ -60,15 +60,42 @@ async def create(client: DremioClient, path: str, rtype: str, display_fields: li raise handle_api_error(exc) from exc -async def list_reflections(client: DremioClient, path: str) -> dict: - """List reflections on a dataset via sys.project.reflections.""" - parts = parse_path(path) - try: - entity = await client.get_catalog_by_path(parts) - except httpx.HTTPStatusError as exc: - raise handle_api_error(exc) from exc - dataset_id = entity["id"] - sql = f"SELECT * FROM sys.project.reflections WHERE dataset_id = '{dataset_id}'" +async def list_reflections( + client: DremioClient, + path: str | None = None, + *, + rtype: str | None = None, + status: str | None = None, + dataset_name: str | None = None, + limit: int | None = None, +) -> dict: + """List reflections via sys.project.reflections. + + When *path* is given, only reflections for that dataset are returned. + When omitted, all reflections in the project are returned. + Optional filters narrow results by *rtype*, *status*, or *dataset_name*. + An optional *limit* caps the number of rows returned. + """ + sql = "SELECT * FROM sys.project.reflections" + conditions: list[str] = [] + if path is not None: + parts = parse_path(path) + try: + entity = await client.get_catalog_by_path(parts) + except httpx.HTTPStatusError as exc: + raise handle_api_error(exc) from exc + dataset_id = entity["id"] + conditions.append(f"dataset_id = '{dataset_id}'") + if rtype is not None: + conditions.append(f"type = '{rtype.upper()}'") + if status is not None: + conditions.append(f"status = '{status.upper()}'") + if dataset_name is not None: + conditions.append(f"dataset_name ILIKE '%{dataset_name}%'") + if conditions: + sql += " WHERE " + " AND ".join(conditions) + if limit is not None: + sql += f" LIMIT {limit}" return await run_query(client, sql) @@ -142,12 +169,20 @@ def cli_create( @app.command("list") def cli_list( - path: str = typer.Argument(help="Dot-separated dataset path"), + path: str = typer.Argument(None, help="Dot-separated dataset path (omit to list all reflections)"), + rtype: str = typer.Option(None, "--type", "-t", help="Filter by reflection type: raw or aggregation"), + status: str = typer.Option(None, "--status", "-s", help="Filter by status (e.g. CAN_ACCELERATE, FAILED, EXPIRED)"), + dataset_name: str = typer.Option(None, "--dataset-name", "-d", help="Filter by dataset name (substring match)"), + limit: int = typer.Option(None, "--limit", "-l", help="Maximum number of reflections to return"), fmt: OutputFormat = typer.Option(OutputFormat.json, "--output", "-o", help="Output format"), ) -> None: - """List all reflections defined on a dataset.""" + """List reflections. Shows all project reflections, or those for a specific dataset.""" client = _get_client() - _run_command(list_reflections(client, path), client, fmt) + _run_command( + list_reflections(client, path, rtype=rtype, status=status, dataset_name=dataset_name, limit=limit), + client, + fmt, + ) @app.command("get") diff --git a/src/drs/commands/setup.py b/src/drs/commands/setup.py index fb77ea2..8af791a 100644 --- a/src/drs/commands/setup.py +++ b/src/drs/commands/setup.py @@ -18,18 +18,18 @@ from __future__ import annotations import asyncio +import re import sys from pathlib import Path from typing import Any import httpx import typer -import yaml from rich.console import Console from rich.panel import Panel from rich.text import Text -from drs.auth import DEFAULT_CONFIG_PATH, DEFAULT_URI, DrsConfig +from drs.auth import DEFAULT_CONFIG_PATH, DEFAULT_URI, DrsConfig, _write_raw_config, read_config_file from drs.client import DremioClient REGIONS = { @@ -41,21 +41,23 @@ err_console = Console(stderr=True) -async def validate_credentials(uri: str, pat: str, project_id: str) -> tuple[bool, str, dict[str, Any] | None]: - """Test credentials by calling get_project(). Returns (success, message, project_data).""" - config = DrsConfig(uri=uri, pat=pat, project_id=project_id) +async def validate_credentials(uri: str, pat: str) -> tuple[bool, str, list[dict[str, Any]] | None]: + """Test credentials by calling list_projects(). Returns (success, message, projects_list).""" + # We need a minimal config to create a client — project_id is not needed for listing projects. + config = DrsConfig(uri=uri, pat=pat, project_id="__discovery__") client = DremioClient(config) try: - project = await client.get_project(project_id) - return True, f"Connected to project: {project.get('name', project_id)}", project + result = await client.list_projects() + projects = result.get("data", []) if isinstance(result, dict) else result + if not projects: + return False, "No projects found — your account may not have any projects in this region.", None + return True, f"Authenticated — found {len(projects)} project(s).", projects except httpx.HTTPStatusError as exc: code = exc.response.status_code if code == 401: return False, "Authentication failed — your PAT is invalid or expired.", None if code == 403: - return False, "Access denied — your PAT may lack permissions, or the project is inaccessible.", None - if code == 404: - return False, "Project not found — check the project ID.", None + return False, "Access denied — your PAT may lack permissions.", None return False, f"API error (HTTP {code}): {exc.response.text[:200]}", None except httpx.ConnectError: return False, f"Cannot reach {uri} — check your region selection and network.", None @@ -65,18 +67,59 @@ async def validate_credentials(uri: str, pat: str, project_id: str) -> tuple[boo await client.close() -def write_config(uri: str, pat: str, project_id: str, config_path: Path) -> None: - """Write YAML config file, creating parent directories as needed.""" +def _slugify(name: str) -> str: + """Convert a project/org name to a URL-friendly profile name.""" + slug = re.sub(r"[^a-z0-9]+", "-", name.lower()).strip("-") + return slug or "default" + + +def write_profile( + uri: str, + pat: str, + project_id: str, + profile_name: str, + set_default: bool, + config_path: Path, +) -> None: + """Write a profile to the config file, preserving existing profiles.""" + raw = read_config_file(config_path) + + # Migrate legacy flat config to profiles format if needed + if "profiles" not in raw: + old_values: dict[str, Any] = {} + for key in ("uri", "endpoint", "pat", "token", "project_id", "projectId"): + if key in raw: + old_values[key] = raw.pop(key) + if old_values: + migrated = {} + if v := old_values.get("uri", old_values.get("endpoint")): + migrated["uri"] = v + if v := old_values.get("pat", old_values.get("token")): + migrated["pat"] = v + if v := old_values.get("project_id", old_values.get("projectId")): + migrated["project_id"] = v + if migrated: + raw.setdefault("profiles", {})["default"] = migrated + if "default_profile" not in raw: + raw["default_profile"] = "default" + + # Add/update the profile + raw.setdefault("profiles", {})[profile_name] = _build_profile_dict(uri, pat, project_id) + + if set_default or "default_profile" not in raw: + raw["default_profile"] = profile_name + + _write_raw_config(raw, config_path) + + +def _build_profile_dict(uri: str, pat: str, project_id: str) -> dict[str, str]: + """Build a profile dict, omitting uri if it's the default.""" data: dict[str, str] = {} if uri != DEFAULT_URI: data["uri"] = uri data["pat"] = pat data["project_id"] = project_id - - config_path.parent.mkdir(parents=True, exist_ok=True) - header = "# Dremio CLI config — generated by 'dremio setup'\n# PAT is stored in plaintext. Keep this file private (mode 600).\n" - config_path.write_text(header + yaml.dump(data, default_flow_style=False, sort_keys=False)) - config_path.chmod(0o600) + return data def _prompt_region() -> tuple[str, str]: @@ -122,26 +165,56 @@ def _prompt_pat(app_url: str) -> str: console.print("[red]PAT cannot be empty.[/red]") -def _prompt_project_id(app_url: str) -> str: - """Prompt for Project ID with step-by-step instructions.""" +def _prompt_project(projects: list[dict[str, Any]]) -> dict[str, Any]: + """Show discovered projects and let the user pick one.""" + console.print() + lines = ["[bold]Step 3: Choose a project[/bold]\n"] + for i, proj in enumerate(projects, 1): + name = proj.get("name", "Unnamed") + pid = proj.get("id", "???") + lines.append(f" [cyan]{i}[/cyan]) {name} [dim]({pid})[/dim]") + + console.print(Panel("\n".join(lines), title="Projects", border_style="blue")) + + if len(projects) == 1: + console.print(f" → Auto-selected: [bold]{projects[0].get('name', 'Unnamed')}[/bold]") + return projects[0] + + while True: + choice = typer.prompt(f"Enter 1-{len(projects)}", default="1").strip() + try: + idx = int(choice) - 1 + if 0 <= idx < len(projects): + selected = projects[idx] + console.print(f" → Project: [bold]{selected.get('name', 'Unnamed')}[/bold]") + return selected + except ValueError: + pass + console.print(f"[red]Please enter a number between 1 and {len(projects)}.[/red]") + + +def _prompt_profile_name(suggested: str, existing_profiles: dict[str, Any]) -> str: + """Prompt for a profile name with a suggested default.""" console.print() console.print( Panel( - "[bold]Step 3: Find your Project ID[/bold]\n\n" - f" 1. Open [link={app_url}]{app_url}[/link]\n" - " 2. Select your project from the top-left dropdown\n" - " 3. Go to [bold]Project Settings[/bold] → [bold]General[/bold]\n" - " 4. Copy the [bold]Project ID[/bold] (a UUID like [dim]a1b2c3d4-...[/dim])\n\n" - "[dim]Tip: The project ID is also visible in the URL bar.[/dim]", - title="Project ID", + "[bold]Step 4: Name this profile[/bold]\n\n" + " A short name to identify this configuration.\n" + f" [dim]Existing profiles: {', '.join(existing_profiles) if existing_profiles else '(none)'}[/dim]", + title="Profile Name", border_style="blue", ) ) while True: - project_id = typer.prompt("Paste your Project ID").strip() - if project_id: - return project_id - console.print("[red]Project ID cannot be empty.[/red]") + name = typer.prompt("Profile name", default=suggested).strip() + if not name: + console.print("[red]Profile name cannot be empty.[/red]") + continue + if name in existing_profiles: + if typer.confirm(f"Profile '{name}' already exists. Overwrite it?", default=False): + return name + continue + return name def setup_command( @@ -149,8 +222,10 @@ def setup_command( ) -> None: """Interactive setup wizard — configure credentials for Dremio Cloud. - Writes configuration to ~/.config/dremioai/config.yaml (or the path - specified with --config). Prompts for region (US/EU), PAT, and project ID. + Walks you through connecting to your Dremio Cloud account. Discovers + your projects automatically and saves the configuration as a named + profile in ~/.config/dremioai/config.yaml (or the path specified + with --config). """ if not sys.stdin.isatty(): err_console.print( @@ -167,6 +242,10 @@ def setup_command( global_config = ctx.obj.get("config_path") if ctx.obj else None config_path = global_config if global_config else DEFAULT_CONFIG_PATH + # Read existing config for profile awareness + existing_raw = read_config_file(config_path) + existing_profiles = existing_raw.get("profiles", {}) + # Welcome console.print() console.print( @@ -174,33 +253,23 @@ def setup_command( "This wizard will help you connect the Dremio CLI to your Dremio Cloud account.\n\n" "You'll need:\n" " • A [bold]Dremio Cloud account[/bold] (sign up at [link=https://app.dremio.cloud]app.dremio.cloud[/link])\n" - " • A [bold]Personal Access Token[/bold] (we'll walk you through creating one)\n" - " • A [bold]Project ID[/bold] (we'll show you where to find it)", + " • A [bold]Personal Access Token[/bold] (we'll walk you through creating one)\n\n" + "We'll discover your projects automatically after you authenticate.", title="[bold]Dremio CLI Setup[/bold]", border_style="cyan", ) ) - # Check existing config - if config_path.exists(): - console.print(f"\n[yellow]A config file already exists at {config_path}[/yellow]") - if not typer.confirm("Overwrite it?", default=False): - console.print("Setup cancelled.") - raise typer.Exit(0) - # Step 1: Region api_uri, app_url = _prompt_region() # Step 2: PAT (with retry loop) pat = _prompt_pat(app_url) - # Step 3: Project ID (with retry loop) - project_id = _prompt_project_id(app_url) - - # Validate + # Validate and discover projects console.print() - with console.status("[bold]Validating credentials...[/bold]", spinner="dots"): - ok, message, project_data = asyncio.run(validate_credentials(api_uri, pat, project_id)) + with console.status("[bold]Authenticating and discovering projects...[/bold]", spinner="dots"): + ok, message, projects = asyncio.run(validate_credentials(api_uri, pat)) while not ok: console.print(f"\n[red]✗ {message}[/red]") @@ -211,39 +280,47 @@ def setup_command( if "Authentication" in message: console.print("[dim]Let's try the PAT again.[/dim]") pat = _prompt_pat(app_url) - elif "Access denied" in message: - console.print( - "\n [cyan]1[/cyan]) Re-enter PAT (token may lack permissions)\n [cyan]2[/cyan]) Re-enter Project ID" - ) - choice = typer.prompt("Which would you like to fix?", default="1").strip() - if choice == "2": - project_id = _prompt_project_id(app_url) - else: - pat = _prompt_pat(app_url) - elif "Project" in message: - console.print("[dim]Let's try the Project ID again.[/dim]") - project_id = _prompt_project_id(app_url) - else: + elif "Cannot reach" in message: console.print("[dim]Let's try the region again.[/dim]") api_uri, app_url = _prompt_region() pat = _prompt_pat(app_url) - project_id = _prompt_project_id(app_url) + else: + pat = _prompt_pat(app_url) console.print() - with console.status("[bold]Validating credentials...[/bold]", spinner="dots"): - ok, message, project_data = asyncio.run(validate_credentials(api_uri, pat, project_id)) + with console.status("[bold]Authenticating and discovering projects...[/bold]", spinner="dots"): + ok, message, projects = asyncio.run(validate_credentials(api_uri, pat)) - # Success — write config - project_name = project_data.get("name", project_id) if project_data else project_id + assert projects is not None console.print(f"\n[green]✓ {message}[/green]") - write_config(api_uri, pat, project_id, config_path) + # Step 3: Choose a project + selected_project = _prompt_project(projects) + project_id = selected_project["id"] + project_name = selected_project.get("name", project_id) + + # Step 4: Name this profile + suggested_name = _slugify(project_name) + profile_name = _prompt_profile_name(suggested_name, existing_profiles) + + # Step 5: Set as default? + set_default = True + if existing_profiles and existing_raw.get("default_profile"): + set_default = typer.confirm("Set as default profile?", default=True) + + # Write config + write_profile(api_uri, pat, project_id, profile_name, set_default, config_path) console.print() success = Text() success.append("Config saved to ", style="bold") success.append(str(config_path), style="cyan") + success.append(f"\nProfile: {profile_name}") success.append(f"\nProject: {project_name}") + if set_default: + success.append(" (default)", style="dim") success.append("\n\nTry it out:\n ") success.append('dremio query run "SELECT 1 AS hello"', style="bold cyan") + if not set_default: + success.append(f'\n dremio --profile {profile_name} query run "SELECT 1"', style="bold cyan") console.print(Panel(success, title="[bold green]Setup complete[/bold green]", border_style="green")) diff --git a/src/drs/introspect.py b/src/drs/introspect.py index d68bc3c..b317e0a 100644 --- a/src/drs/introspect.py +++ b/src/drs/introspect.py @@ -302,11 +302,15 @@ "reflection.list": { "group": "reflection", "command": "list", - "description": "List all reflections defined on a dataset.", + "description": "List reflections. Shows all project reflections, or those for a specific dataset.", "mechanism": "SQL", - "sql_template": "SELECT * FROM sys.project.reflections WHERE dataset_id = '{dataset_id}'", + "sql_template": "SELECT * FROM sys.project.reflections [WHERE ...] [LIMIT {limit}]", "parameters": [ - {"name": "path", "type": "string", "required": True, "positional": True}, + {"name": "path", "type": "string", "required": False, "positional": True}, + {"name": "type", "type": "string", "required": False, "flag": "--type/-t"}, + {"name": "status", "type": "string", "required": False, "flag": "--status/-s"}, + {"name": "dataset_name", "type": "string", "required": False, "flag": "--dataset-name/-d"}, + {"name": "limit", "type": "integer", "required": False, "flag": "--limit/-l"}, {"name": "output", "type": "enum", "required": False, "default": "json", "enum": ["json", "csv", "pretty"]}, ], }, diff --git a/tests/test_auth.py b/tests/test_auth.py index 9187f81..c034657 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -25,7 +25,16 @@ import yaml from pydantic import ValidationError -from drs.auth import load_config +from drs.auth import ( + get_default_profile_name, + list_profiles, + load_config, + set_default_profile, +) + +# --------------------------------------------------------------------------- +# Legacy flat config tests (backwards compatibility) +# --------------------------------------------------------------------------- def test_config_from_env_vars(tmp_path: Path) -> None: @@ -56,7 +65,7 @@ def test_config_from_file(tmp_path: Path) -> None: ) with patch.dict(os.environ, {}, clear=False): - for k in ["DREMIO_TOKEN", "DREMIO_PAT", "DREMIO_PROJECT_ID", "DREMIO_URI"]: + for k in ["DREMIO_TOKEN", "DREMIO_PAT", "DREMIO_PROJECT_ID", "DREMIO_URI", "DREMIO_PROFILE"]: os.environ.pop(k, None) config = load_config(config_file) @@ -81,6 +90,7 @@ def test_config_env_overrides_file(tmp_path: Path) -> None: os.environ.pop("DREMIO_PAT", None) os.environ.pop("DREMIO_PROJECT_ID", None) os.environ.pop("DREMIO_URI", None) + os.environ.pop("DREMIO_PROFILE", None) config = load_config(config_file) assert config.pat == "env-token" # env wins @@ -90,7 +100,7 @@ def test_config_env_overrides_file(tmp_path: Path) -> None: def test_config_missing_required_field(tmp_path: Path) -> None: """Missing pat or project_id should raise ValidationError.""" with patch.dict(os.environ, {}, clear=False): - for k in ["DREMIO_TOKEN", "DREMIO_PAT", "DREMIO_PROJECT_ID", "DREMIO_URI"]: + for k in ["DREMIO_TOKEN", "DREMIO_PAT", "DREMIO_PROJECT_ID", "DREMIO_URI", "DREMIO_PROFILE"]: os.environ.pop(k, None) with pytest.raises(ValidationError): load_config(tmp_path / "nonexistent.yaml") @@ -110,7 +120,7 @@ def test_config_dremio_mcp_compat(tmp_path: Path) -> None: ) with patch.dict(os.environ, {}, clear=False): - for k in ["DREMIO_TOKEN", "DREMIO_PAT", "DREMIO_PROJECT_ID", "DREMIO_URI"]: + for k in ["DREMIO_TOKEN", "DREMIO_PAT", "DREMIO_PROJECT_ID", "DREMIO_URI", "DREMIO_PROFILE"]: os.environ.pop(k, None) config = load_config(config_file) @@ -162,3 +172,200 @@ def test_cli_args_override_env(tmp_path: Path) -> None: assert config.pat == "cli-token" assert config.project_id == "cli-project" assert config.uri == "https://api.eu.dremio.cloud" + + +# --------------------------------------------------------------------------- +# Profile-based config tests +# --------------------------------------------------------------------------- + + +def _write_profiles_config(tmp_path: Path, profiles: dict, default_profile: str | None = None) -> Path: + """Helper to write a profiles-format config file.""" + config_file = tmp_path / "config.yaml" + data: dict = {"profiles": profiles} + if default_profile: + data["default_profile"] = default_profile + config_file.write_text(yaml.dump(data)) + return config_file + + +def test_load_config_with_profiles(tmp_path: Path) -> None: + """Should load values from the default profile.""" + config_file = _write_profiles_config( + tmp_path, + { + "prod": {"pat": "prod-pat", "project_id": "prod-proj"}, + "dev": {"pat": "dev-pat", "project_id": "dev-proj", "uri": "https://api.eu.dremio.cloud"}, + }, + default_profile="prod", + ) + + with patch.dict(os.environ, {}, clear=False): + for k in ["DREMIO_TOKEN", "DREMIO_PAT", "DREMIO_PROJECT_ID", "DREMIO_URI", "DREMIO_PROFILE"]: + os.environ.pop(k, None) + config = load_config(config_file) + + assert config.pat == "prod-pat" + assert config.project_id == "prod-proj" + assert config.uri == "https://api.dremio.cloud" + + +def test_load_config_profile_arg(tmp_path: Path) -> None: + """--profile arg should select a specific profile.""" + config_file = _write_profiles_config( + tmp_path, + { + "prod": {"pat": "prod-pat", "project_id": "prod-proj"}, + "dev": {"pat": "dev-pat", "project_id": "dev-proj", "uri": "https://api.eu.dremio.cloud"}, + }, + default_profile="prod", + ) + + with patch.dict(os.environ, {}, clear=False): + for k in ["DREMIO_TOKEN", "DREMIO_PAT", "DREMIO_PROJECT_ID", "DREMIO_URI", "DREMIO_PROFILE"]: + os.environ.pop(k, None) + config = load_config(config_file, profile="dev") + + assert config.pat == "dev-pat" + assert config.project_id == "dev-proj" + assert config.uri == "https://api.eu.dremio.cloud" + + +def test_load_config_dremio_profile_env(tmp_path: Path) -> None: + """DREMIO_PROFILE env var should select a profile.""" + config_file = _write_profiles_config( + tmp_path, + { + "prod": {"pat": "prod-pat", "project_id": "prod-proj"}, + "dev": {"pat": "dev-pat", "project_id": "dev-proj"}, + }, + default_profile="prod", + ) + + with patch.dict(os.environ, {"DREMIO_PROFILE": "dev"}, clear=False): + for k in ["DREMIO_TOKEN", "DREMIO_PAT", "DREMIO_PROJECT_ID", "DREMIO_URI"]: + os.environ.pop(k, None) + config = load_config(config_file) + + assert config.pat == "dev-pat" + assert config.project_id == "dev-proj" + + +def test_profile_arg_overrides_env(tmp_path: Path) -> None: + """CLI --profile should override DREMIO_PROFILE env var.""" + config_file = _write_profiles_config( + tmp_path, + { + "prod": {"pat": "prod-pat", "project_id": "prod-proj"}, + "dev": {"pat": "dev-pat", "project_id": "dev-proj"}, + }, + default_profile="prod", + ) + + with patch.dict(os.environ, {"DREMIO_PROFILE": "prod"}, clear=False): + for k in ["DREMIO_TOKEN", "DREMIO_PAT", "DREMIO_PROJECT_ID", "DREMIO_URI"]: + os.environ.pop(k, None) + config = load_config(config_file, profile="dev") + + assert config.pat == "dev-pat" + + +def test_env_vars_override_profile_values(tmp_path: Path) -> None: + """Env vars (token, project_id) should override profile file values.""" + config_file = _write_profiles_config( + tmp_path, + {"prod": {"pat": "file-pat", "project_id": "file-proj"}}, + default_profile="prod", + ) + + with patch.dict(os.environ, {"DREMIO_TOKEN": "env-pat"}, clear=False): + for k in ["DREMIO_PAT", "DREMIO_PROJECT_ID", "DREMIO_URI", "DREMIO_PROFILE"]: + os.environ.pop(k, None) + config = load_config(config_file) + + assert config.pat == "env-pat" # env wins + assert config.project_id == "file-proj" # from profile + + +def test_profiles_fallback_to_first(tmp_path: Path) -> None: + """When no default_profile is set, fall back to the first profile.""" + config_file = _write_profiles_config( + tmp_path, + {"alpha": {"pat": "a-pat", "project_id": "a-proj"}}, + ) + + with patch.dict(os.environ, {}, clear=False): + for k in ["DREMIO_TOKEN", "DREMIO_PAT", "DREMIO_PROJECT_ID", "DREMIO_URI", "DREMIO_PROFILE"]: + os.environ.pop(k, None) + config = load_config(config_file) + + assert config.pat == "a-pat" + + +# --------------------------------------------------------------------------- +# list_profiles / get_default_profile_name / set_default_profile +# --------------------------------------------------------------------------- + + +def test_list_profiles_with_profiles_format(tmp_path: Path) -> None: + config_file = _write_profiles_config( + tmp_path, + { + "prod": {"pat": "p1", "project_id": "proj1"}, + "dev": {"pat": "p2", "project_id": "proj2"}, + }, + ) + profiles = list_profiles(config_file) + assert set(profiles.keys()) == {"prod", "dev"} + assert profiles["prod"]["pat"] == "p1" + + +def test_list_profiles_with_legacy_flat(tmp_path: Path) -> None: + config_file = tmp_path / "config.yaml" + config_file.write_text(yaml.dump({"pat": "tok", "project_id": "proj"})) + profiles = list_profiles(config_file) + assert set(profiles.keys()) == {"default"} + assert profiles["default"]["pat"] == "tok" + + +def test_list_profiles_empty(tmp_path: Path) -> None: + profiles = list_profiles(tmp_path / "nonexistent.yaml") + assert profiles == {} + + +def test_get_default_profile_name_profiles(tmp_path: Path) -> None: + config_file = _write_profiles_config( + tmp_path, + {"prod": {"pat": "p1", "project_id": "proj1"}}, + default_profile="prod", + ) + assert get_default_profile_name(config_file) == "prod" + + +def test_get_default_profile_name_legacy(tmp_path: Path) -> None: + config_file = tmp_path / "config.yaml" + config_file.write_text(yaml.dump({"pat": "tok", "project_id": "proj"})) + assert get_default_profile_name(config_file) == "default" + + +def test_set_default_profile(tmp_path: Path) -> None: + config_file = _write_profiles_config( + tmp_path, + { + "prod": {"pat": "p1", "project_id": "proj1"}, + "dev": {"pat": "p2", "project_id": "proj2"}, + }, + default_profile="prod", + ) + set_default_profile("dev", config_file) + assert get_default_profile_name(config_file) == "dev" + + +def test_set_default_profile_not_found(tmp_path: Path) -> None: + config_file = _write_profiles_config( + tmp_path, + {"prod": {"pat": "p1", "project_id": "proj1"}}, + default_profile="prod", + ) + with pytest.raises(ValueError, match="not found"): + set_default_profile("nonexistent", config_file) diff --git a/tests/test_commands/test_context.py b/tests/test_commands/test_context.py new file mode 100644 index 0000000..66733bb --- /dev/null +++ b/tests/test_commands/test_context.py @@ -0,0 +1,113 @@ +# +# Copyright (C) 2017-2026 Dremio Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Tests for dremio context command.""" + +from __future__ import annotations + +from pathlib import Path + +import yaml +from typer.testing import CliRunner + +from drs.cli import app + +runner = CliRunner() + + +def _write_config(tmp_path: Path, data: dict) -> Path: + config_path = tmp_path / "config.yaml" + config_path.write_text(yaml.dump(data)) + return config_path + + +def test_context_list(tmp_path: Path) -> None: + config_path = _write_config( + tmp_path, + { + "default_profile": "prod", + "profiles": { + "prod": {"pat": "p1", "project_id": "proj-1"}, + "dev": {"pat": "p2", "project_id": "proj-2", "uri": "https://api.eu.dremio.cloud"}, + }, + }, + ) + result = runner.invoke(app, ["--config", str(config_path), "context", "list"]) + assert result.exit_code == 0 + assert "prod" in result.output + assert "dev" in result.output + assert "EU" in result.output + + +def test_context_list_empty(tmp_path: Path) -> None: + config_path = tmp_path / "nonexistent.yaml" + result = runner.invoke(app, ["--config", str(config_path), "context", "list"]) + assert result.exit_code == 1 + assert "No profiles" in result.output + + +def test_context_use(tmp_path: Path) -> None: + config_path = _write_config( + tmp_path, + { + "default_profile": "prod", + "profiles": { + "prod": {"pat": "p1", "project_id": "proj-1"}, + "dev": {"pat": "p2", "project_id": "proj-2"}, + }, + }, + ) + result = runner.invoke(app, ["--config", str(config_path), "context", "use", "dev"]) + assert result.exit_code == 0 + assert "dev" in result.output + + # Verify the file was updated + data = yaml.safe_load(config_path.read_text()) + assert data["default_profile"] == "dev" + + +def test_context_use_not_found(tmp_path: Path) -> None: + config_path = _write_config( + tmp_path, + { + "default_profile": "prod", + "profiles": {"prod": {"pat": "p1", "project_id": "proj-1"}}, + }, + ) + result = runner.invoke(app, ["--config", str(config_path), "context", "use", "nonexistent"]) + assert result.exit_code == 1 + assert "not found" in result.output + + +def test_context_current(tmp_path: Path) -> None: + config_path = _write_config( + tmp_path, + { + "default_profile": "staging", + "profiles": { + "staging": {"pat": "p1", "project_id": "proj-1"}, + }, + }, + ) + result = runner.invoke(app, ["--config", str(config_path), "context", "current"]) + assert result.exit_code == 0 + assert "staging" in result.output + + +def test_context_current_no_config(tmp_path: Path) -> None: + config_path = tmp_path / "nonexistent.yaml" + result = runner.invoke(app, ["--config", str(config_path), "context", "current"]) + assert result.exit_code == 1 + assert "No profiles" in result.output diff --git a/tests/test_commands/test_reflection.py b/tests/test_commands/test_reflection.py index a6d657a..b2ca0de 100644 --- a/tests/test_commands/test_reflection.py +++ b/tests/test_commands/test_reflection.py @@ -17,11 +17,90 @@ from __future__ import annotations -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, patch import pytest -from drs.commands.reflection import delete, get_reflection, refresh +from drs.commands.reflection import delete, get_reflection, list_reflections, refresh + +QUERY_RESULT = {"job_id": "j1", "state": "COMPLETED", "rowCount": 2, "rows": [{"id": "r1"}, {"id": "r2"}]} + + +@pytest.mark.asyncio +async def test_list_reflections_all(mock_client) -> None: + """Omitting path queries all reflections without a WHERE clause.""" + with patch("drs.commands.reflection.run_query", new_callable=AsyncMock, return_value=QUERY_RESULT) as mock_rq: + result = await list_reflections(mock_client) + mock_rq.assert_called_once_with(mock_client, "SELECT * FROM sys.project.reflections") + assert result["rowCount"] == 2 + + +@pytest.mark.asyncio +async def test_list_reflections_for_dataset(mock_client) -> None: + """Providing a path filters by dataset_id.""" + mock_client.get_catalog_by_path = AsyncMock(return_value={"id": "ds-123"}) + with patch("drs.commands.reflection.run_query", new_callable=AsyncMock, return_value=QUERY_RESULT) as mock_rq: + result = await list_reflections(mock_client, path="space.my_table") + mock_rq.assert_called_once_with(mock_client, "SELECT * FROM sys.project.reflections WHERE dataset_id = 'ds-123'") + assert result["rowCount"] == 2 + + +@pytest.mark.asyncio +async def test_list_reflections_with_limit(mock_client) -> None: + """--limit appends a SQL LIMIT clause.""" + with patch("drs.commands.reflection.run_query", new_callable=AsyncMock, return_value=QUERY_RESULT) as mock_rq: + await list_reflections(mock_client, limit=50) + mock_rq.assert_called_once_with(mock_client, "SELECT * FROM sys.project.reflections LIMIT 50") + + +@pytest.mark.asyncio +async def test_list_reflections_dataset_with_limit(mock_client) -> None: + """Both path and limit combine WHERE and LIMIT.""" + mock_client.get_catalog_by_path = AsyncMock(return_value={"id": "ds-456"}) + with patch("drs.commands.reflection.run_query", new_callable=AsyncMock, return_value=QUERY_RESULT) as mock_rq: + await list_reflections(mock_client, path="space.ds", limit=10) + mock_rq.assert_called_once_with( + mock_client, "SELECT * FROM sys.project.reflections WHERE dataset_id = 'ds-456' LIMIT 10" + ) + + +@pytest.mark.asyncio +async def test_list_reflections_filter_by_type(mock_client) -> None: + """--type filters by reflection type.""" + with patch("drs.commands.reflection.run_query", new_callable=AsyncMock, return_value=QUERY_RESULT) as mock_rq: + await list_reflections(mock_client, rtype="raw") + mock_rq.assert_called_once_with(mock_client, "SELECT * FROM sys.project.reflections WHERE type = 'RAW'") + + +@pytest.mark.asyncio +async def test_list_reflections_filter_by_status(mock_client) -> None: + """--status filters by reflection status.""" + with patch("drs.commands.reflection.run_query", new_callable=AsyncMock, return_value=QUERY_RESULT) as mock_rq: + await list_reflections(mock_client, status="failed") + mock_rq.assert_called_once_with(mock_client, "SELECT * FROM sys.project.reflections WHERE status = 'FAILED'") + + +@pytest.mark.asyncio +async def test_list_reflections_filter_by_dataset_name(mock_client) -> None: + """--dataset-name filters with ILIKE substring match.""" + with patch("drs.commands.reflection.run_query", new_callable=AsyncMock, return_value=QUERY_RESULT) as mock_rq: + await list_reflections(mock_client, dataset_name="orders") + mock_rq.assert_called_once_with( + mock_client, "SELECT * FROM sys.project.reflections WHERE dataset_name ILIKE '%orders%'" + ) + + +@pytest.mark.asyncio +async def test_list_reflections_combined_filters(mock_client) -> None: + """Multiple filters combine with AND.""" + mock_client.get_catalog_by_path = AsyncMock(return_value={"id": "ds-789"}) + with patch("drs.commands.reflection.run_query", new_callable=AsyncMock, return_value=QUERY_RESULT) as mock_rq: + await list_reflections(mock_client, path="space.ds", rtype="raw", status="can_accelerate", limit=5) + mock_rq.assert_called_once_with( + mock_client, + "SELECT * FROM sys.project.reflections" + " WHERE dataset_id = 'ds-789' AND type = 'RAW' AND status = 'CAN_ACCELERATE' LIMIT 5", + ) @pytest.mark.asyncio diff --git a/tests/test_commands/test_setup.py b/tests/test_commands/test_setup.py index 2bab050..7195f28 100644 --- a/tests/test_commands/test_setup.py +++ b/tests/test_commands/test_setup.py @@ -26,128 +26,188 @@ from drs.auth import DEFAULT_URI from drs.cli import app -from drs.commands.setup import validate_credentials, write_config +from drs.commands.setup import _slugify, validate_credentials, write_profile runner = CliRunner() +# -- Unit tests for helpers -- -def test_write_config(tmp_path) -> None: + +def test_write_profile_new_file(tmp_path) -> None: config_path = tmp_path / "config.yaml" - write_config("https://api.eu.dremio.cloud", "my-pat", "my-project", config_path) + write_profile("https://api.eu.dremio.cloud", "my-pat", "my-project", "eu-prod", True, config_path) data = yaml.safe_load(config_path.read_text()) - assert data["uri"] == "https://api.eu.dremio.cloud" - assert data["pat"] == "my-pat" - assert data["project_id"] == "my-project" - # File should be owner-only readable + assert data["default_profile"] == "eu-prod" + assert data["profiles"]["eu-prod"]["uri"] == "https://api.eu.dremio.cloud" + assert data["profiles"]["eu-prod"]["pat"] == "my-pat" + assert data["profiles"]["eu-prod"]["project_id"] == "my-project" assert oct(config_path.stat().st_mode & 0o777) == "0o600" -def test_write_config_omits_default_uri(tmp_path) -> None: +def test_write_profile_omits_default_uri(tmp_path) -> None: config_path = tmp_path / "config.yaml" - write_config(DEFAULT_URI, "my-pat", "my-project", config_path) + write_profile(DEFAULT_URI, "my-pat", "my-project", "us-prod", True, config_path) data = yaml.safe_load(config_path.read_text()) - assert "uri" not in data - assert data["pat"] == "my-pat" - assert data["project_id"] == "my-project" + assert "uri" not in data["profiles"]["us-prod"] + assert data["profiles"]["us-prod"]["pat"] == "my-pat" -def test_write_config_creates_dirs(tmp_path) -> None: +def test_write_profile_creates_dirs(tmp_path) -> None: config_path = tmp_path / "nested" / "deep" / "config.yaml" - write_config(DEFAULT_URI, "my-pat", "my-project", config_path) + write_profile(DEFAULT_URI, "my-pat", "my-project", "test", True, config_path) assert config_path.exists() data = yaml.safe_load(config_path.read_text()) - assert data["pat"] == "my-pat" + assert data["profiles"]["test"]["pat"] == "my-pat" + + +def test_write_profile_preserves_existing(tmp_path) -> None: + """Adding a second profile should not overwrite the first.""" + config_path = tmp_path / "config.yaml" + write_profile(DEFAULT_URI, "pat-1", "proj-1", "first", True, config_path) + write_profile("https://api.eu.dremio.cloud", "pat-2", "proj-2", "second", False, config_path) + + data = yaml.safe_load(config_path.read_text()) + assert "first" in data["profiles"] + assert "second" in data["profiles"] + assert data["profiles"]["first"]["pat"] == "pat-1" + assert data["profiles"]["second"]["pat"] == "pat-2" + assert data["default_profile"] == "first" # not changed + + +def test_write_profile_migrates_legacy(tmp_path) -> None: + """Legacy flat config should be migrated into a 'default' profile.""" + config_path = tmp_path / "config.yaml" + config_path.write_text(yaml.dump({"pat": "old-pat", "project_id": "old-proj"})) + + write_profile(DEFAULT_URI, "new-pat", "new-proj", "new-profile", False, config_path) + + data = yaml.safe_load(config_path.read_text()) + assert "default" in data["profiles"] + assert data["profiles"]["default"]["pat"] == "old-pat" + assert "new-profile" in data["profiles"] + assert data["default_profile"] == "default" # legacy becomes default + + +def test_write_profile_set_default(tmp_path) -> None: + """set_default=True should update default_profile.""" + config_path = tmp_path / "config.yaml" + write_profile(DEFAULT_URI, "pat-1", "proj-1", "first", True, config_path) + write_profile(DEFAULT_URI, "pat-2", "proj-2", "second", True, config_path) + + data = yaml.safe_load(config_path.read_text()) + assert data["default_profile"] == "second" + + +def test_slugify() -> None: + assert _slugify("My Project") == "my-project" + assert _slugify("Production Analytics") == "production-analytics" + assert _slugify("dev_sandbox-123") == "dev-sandbox-123" + assert _slugify(" ") == "default" + + +# -- Validation tests -- @pytest.mark.asyncio async def test_validate_credentials_success() -> None: mock_client = AsyncMock() - mock_client.get_project = AsyncMock(return_value={"id": "p1", "name": "My Project"}) + mock_client.list_projects = AsyncMock(return_value={"data": [{"id": "p1", "name": "My Project"}]}) mock_client.close = AsyncMock() with patch("drs.commands.setup.DremioClient", return_value=mock_client): - ok, msg, data = await validate_credentials(DEFAULT_URI, "good-pat", "p1") + ok, msg, projects = await validate_credentials(DEFAULT_URI, "good-pat") assert ok is True - assert "My Project" in msg - assert data["name"] == "My Project" + assert "1 project" in msg + assert projects[0]["name"] == "My Project" @pytest.mark.asyncio -async def test_validate_credentials_bad_pat() -> None: +async def test_validate_credentials_multiple_projects() -> None: mock_client = AsyncMock() - response = httpx.Response(401, request=httpx.Request("GET", "https://api.dremio.cloud")) - mock_client.get_project = AsyncMock( - side_effect=httpx.HTTPStatusError("Unauthorized", request=response.request, response=response) + mock_client.list_projects = AsyncMock( + return_value={ + "data": [ + {"id": "p1", "name": "Prod"}, + {"id": "p2", "name": "Dev"}, + ] + } ) mock_client.close = AsyncMock() with patch("drs.commands.setup.DremioClient", return_value=mock_client): - ok, msg, data = await validate_credentials(DEFAULT_URI, "bad-pat", "p1") + ok, msg, projects = await validate_credentials(DEFAULT_URI, "good-pat") - assert ok is False - assert "PAT" in msg or "Authentication" in msg - assert data is None + assert ok is True + assert "2 project" in msg + assert len(projects) == 2 @pytest.mark.asyncio -async def test_validate_credentials_bad_project() -> None: +async def test_validate_credentials_bad_pat() -> None: mock_client = AsyncMock() - response = httpx.Response(404, request=httpx.Request("GET", "https://api.dremio.cloud")) - mock_client.get_project = AsyncMock( - side_effect=httpx.HTTPStatusError("Not Found", request=response.request, response=response) + response = httpx.Response(401, request=httpx.Request("GET", "https://api.dremio.cloud")) + mock_client.list_projects = AsyncMock( + side_effect=httpx.HTTPStatusError("Unauthorized", request=response.request, response=response) ) mock_client.close = AsyncMock() with patch("drs.commands.setup.DremioClient", return_value=mock_client): - ok, msg, data = await validate_credentials(DEFAULT_URI, "good-pat", "bad-project") + ok, msg, projects = await validate_credentials(DEFAULT_URI, "bad-pat") assert ok is False - assert "Project" in msg - assert data is None + assert "PAT" in msg or "Authentication" in msg + assert projects is None @pytest.mark.asyncio async def test_validate_credentials_forbidden() -> None: mock_client = AsyncMock() response = httpx.Response(403, request=httpx.Request("GET", "https://api.dremio.cloud")) - mock_client.get_project = AsyncMock( + mock_client.list_projects = AsyncMock( side_effect=httpx.HTTPStatusError("Forbidden", request=response.request, response=response) ) mock_client.close = AsyncMock() with patch("drs.commands.setup.DremioClient", return_value=mock_client): - ok, msg, data = await validate_credentials(DEFAULT_URI, "limited-pat", "p1") + ok, msg, projects = await validate_credentials(DEFAULT_URI, "limited-pat") assert ok is False assert "Access denied" in msg - assert data is None + assert projects is None @pytest.mark.asyncio async def test_validate_credentials_connection_error() -> None: mock_client = AsyncMock() - mock_client.get_project = AsyncMock(side_effect=httpx.ConnectError("Connection refused")) + mock_client.list_projects = AsyncMock(side_effect=httpx.ConnectError("Connection refused")) mock_client.close = AsyncMock() with patch("drs.commands.setup.DremioClient", return_value=mock_client): - ok, msg, data = await validate_credentials("https://api.bad.dremio.cloud", "pat", "p1") + ok, msg, projects = await validate_credentials("https://api.bad.dremio.cloud", "pat") assert ok is False assert "Cannot reach" in msg - assert data is None + assert projects is None -def test_write_config_includes_header(tmp_path) -> None: - config_path = tmp_path / "config.yaml" - write_config(DEFAULT_URI, "my-pat", "my-project", config_path) +@pytest.mark.asyncio +async def test_validate_credentials_no_projects() -> None: + mock_client = AsyncMock() + mock_client.list_projects = AsyncMock(return_value={"data": []}) + mock_client.close = AsyncMock() + + with patch("drs.commands.setup.DremioClient", return_value=mock_client): + ok, msg, _projects = await validate_credentials(DEFAULT_URI, "good-pat") + + assert ok is False + assert "No projects" in msg - raw = config_path.read_text() - assert raw.startswith("# Dremio CLI config") - assert "plaintext" in raw + +# -- CLI integration tests -- def test_setup_non_interactive(tmp_path) -> None: @@ -161,11 +221,11 @@ def test_setup_non_interactive(tmp_path) -> None: def test_setup_happy_path(tmp_path) -> None: - """Full wizard flow: region, PAT, project ID, validation, config write.""" + """Full wizard flow: region, PAT, project discovery, profile naming.""" config_path = tmp_path / "config.yaml" mock_client = AsyncMock() - mock_client.get_project = AsyncMock(return_value={"id": "p1", "name": "Test Project"}) + mock_client.list_projects = AsyncMock(return_value={"data": [{"id": "p1", "name": "Test Project"}]}) mock_client.close = AsyncMock() with ( @@ -174,43 +234,30 @@ def test_setup_happy_path(tmp_path) -> None: patch("drs.commands.setup.DEFAULT_CONFIG_PATH", config_path), ): mock_sys.stdin.isatty.return_value = True - # Input: region=1, PAT=test-pat, project_id=test-proj - result = runner.invoke(app, ["setup"], input="1\ntest-pat\ntest-proj\n") + # Input: region=1, PAT=test-pat, profile_name=test-project (accept default) + result = runner.invoke(app, ["setup"], input="1\ntest-pat\ntest-project\n") assert result.exit_code == 0 assert "Setup complete" in result.output assert config_path.exists() data = yaml.safe_load(config_path.read_text()) - assert data["pat"] == "test-pat" - assert data["project_id"] == "test-proj" + assert data["profiles"]["test-project"]["pat"] == "test-pat" + assert data["profiles"]["test-project"]["project_id"] == "p1" -def test_setup_existing_config_decline(tmp_path) -> None: - """Declining to overwrite existing config should abort.""" +def test_setup_multiple_projects(tmp_path) -> None: + """Should let user pick from multiple discovered projects.""" config_path = tmp_path / "config.yaml" - config_path.write_text("pat: old\n") - - with ( - patch("drs.commands.setup.sys") as mock_sys, - patch("drs.commands.setup.DEFAULT_CONFIG_PATH", config_path), - ): - mock_sys.stdin.isatty.return_value = True - # Input: decline overwrite (n) - result = runner.invoke(app, ["setup"], input="n\n") - - assert result.exit_code == 0 - assert "cancelled" in result.output.lower() - # Config should be unchanged - assert config_path.read_text() == "pat: old\n" - - -def test_setup_existing_config_overwrite(tmp_path) -> None: - """Accepting overwrite should proceed with the wizard.""" - config_path = tmp_path / "config.yaml" - config_path.write_text("pat: old\n") mock_client = AsyncMock() - mock_client.get_project = AsyncMock(return_value={"id": "p1", "name": "New Project"}) + mock_client.list_projects = AsyncMock( + return_value={ + "data": [ + {"id": "p1", "name": "Production"}, + {"id": "p2", "name": "Dev Sandbox"}, + ] + } + ) mock_client.close = AsyncMock() with ( @@ -219,12 +266,12 @@ def test_setup_existing_config_overwrite(tmp_path) -> None: patch("drs.commands.setup.DEFAULT_CONFIG_PATH", config_path), ): mock_sys.stdin.isatty.return_value = True - # Input: overwrite=y, region=1, PAT=new-pat, project_id=new-proj - result = runner.invoke(app, ["setup"], input="y\n1\nnew-pat\nnew-proj\n") + # Input: region=1, PAT=my-pat, pick project 2, profile_name=dev + result = runner.invoke(app, ["setup"], input="1\nmy-pat\n2\ndev\n") assert result.exit_code == 0 data = yaml.safe_load(config_path.read_text()) - assert data["pat"] == "new-pat" + assert data["profiles"]["dev"]["project_id"] == "p2" def test_setup_retry_then_abort(tmp_path) -> None: @@ -233,7 +280,7 @@ def test_setup_retry_then_abort(tmp_path) -> None: mock_client = AsyncMock() response = httpx.Response(401, request=httpx.Request("GET", "https://api.dremio.cloud")) - mock_client.get_project = AsyncMock( + mock_client.list_projects = AsyncMock( side_effect=httpx.HTTPStatusError("Unauthorized", request=response.request, response=response) ) mock_client.close = AsyncMock() @@ -244,8 +291,8 @@ def test_setup_retry_then_abort(tmp_path) -> None: patch("drs.commands.setup.DEFAULT_CONFIG_PATH", config_path), ): mock_sys.stdin.isatty.return_value = True - # Input: region=1, PAT=bad, project_id=p1, then decline retry - result = runner.invoke(app, ["setup"], input="1\nbad-pat\np1\nn\n") + # Input: region=1, PAT=bad, then decline retry + result = runner.invoke(app, ["setup"], input="1\nbad-pat\nn\n") assert result.exit_code == 1 assert "cancelled" in result.output.lower() @@ -257,7 +304,7 @@ def test_setup_global_config_passthrough(tmp_path) -> None: config_path = tmp_path / "custom.yaml" mock_client = AsyncMock() - mock_client.get_project = AsyncMock(return_value={"id": "p1", "name": "Test"}) + mock_client.list_projects = AsyncMock(return_value={"data": [{"id": "p1", "name": "Test"}]}) mock_client.close = AsyncMock() with ( @@ -265,9 +312,19 @@ def test_setup_global_config_passthrough(tmp_path) -> None: patch("drs.commands.setup.DremioClient", return_value=mock_client), ): mock_sys.stdin.isatty.return_value = True - result = runner.invoke(app, ["--config", str(config_path), "setup"], input="1\nmy-pat\nmy-proj\n") + # Input: region=1, PAT=my-pat, profile_name=test + result = runner.invoke(app, ["--config", str(config_path), "setup"], input="1\nmy-pat\ntest\n") assert result.exit_code == 0 assert config_path.exists() data = yaml.safe_load(config_path.read_text()) - assert data["pat"] == "my-pat" + assert data["profiles"]["test"]["pat"] == "my-pat" + + +def test_write_profile_includes_header(tmp_path) -> None: + config_path = tmp_path / "config.yaml" + write_profile(DEFAULT_URI, "my-pat", "my-project", "test", True, config_path) + + raw = config_path.read_text() + assert raw.startswith("# Dremio CLI config") + assert "plaintext" in raw