Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 91 additions & 13 deletions src/drs/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,32 +34,110 @@ class DrsConfig(BaseModel):
project_id: str


def _extract_flat_values(raw: dict[str, Any]) -> dict[str, Any]:
"""Extract config values from a flat (legacy) config dict."""
values = {
"uri": raw.get("uri", raw.get("endpoint")),
"pat": raw.get("pat", raw.get("token")),
"project_id": raw.get("project_id", raw.get("projectId")),
}
return {k: v for k, v in values.items() if v is not None}


def _resolve_profile(raw: dict[str, Any], profile: str | None) -> dict[str, Any]:
"""Resolve a profile from the config file.

If the file uses the new ``profiles`` format, look up the requested
profile (falling back to ``default_profile``). If the file is the
legacy flat format (no ``profiles`` key), treat it as a single
implicit profile.
"""
profiles = raw.get("profiles")
if profiles is None:
# Legacy flat format — treat as single implicit profile
return _extract_flat_values(raw)

# New profiles format
name = profile or raw.get("default_profile") or next(iter(profiles), None)
if not name or name not in profiles:
return {}
return _extract_flat_values(profiles[name])


def read_config_file(config_path: Path | None = None) -> dict[str, Any]:
"""Read and return the raw YAML config dict. Returns {} if the file doesn't exist."""
path = config_path or DEFAULT_CONFIG_PATH
if not path.exists():
return {}
with path.open() as f:
return yaml.safe_load(f) or {}


def list_profiles(config_path: Path | None = None) -> dict[str, dict[str, Any]]:
"""Return all profiles from the config file.

For legacy flat configs, returns a single ``default`` profile.
"""
raw = read_config_file(config_path)
profiles = raw.get("profiles")
if profiles is None:
flat = _extract_flat_values(raw)
if flat:
return {"default": flat}
return {}
return {name: _extract_flat_values(vals) for name, vals in profiles.items()}


def get_default_profile_name(config_path: Path | None = None) -> str | None:
"""Return the name of the default profile, or None."""
raw = read_config_file(config_path)
profiles = raw.get("profiles")
if profiles is None:
# Legacy flat format — the implicit profile name is "default"
flat = _extract_flat_values(raw)
return "default" if flat else None
return raw.get("default_profile") or next(iter(profiles), None)


def set_default_profile(profile_name: str, config_path: Path | None = None) -> None:
"""Set the default profile in the config file."""
path = config_path or DEFAULT_CONFIG_PATH
raw = read_config_file(path)
profiles = raw.get("profiles", {})
if profile_name not in profiles:
raise ValueError(f"Profile '{profile_name}' not found in config")
Comment on lines +106 to +108
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Handle implicit default profile in set_default_profile

set_default_profile only checks the explicit profiles map, so legacy flat configs (which the same module treats as an implicit default profile) always raise “Profile not found”. This makes dremio context use default fail for valid legacy configs even though context list/context current expose default, which breaks the advertised backward-compatibility path for profile management.

Useful? React with 👍 / 👎.

raw["default_profile"] = profile_name
_write_raw_config(raw, path)


def _write_raw_config(raw: dict[str, Any], config_path: Path) -> None:
"""Write a raw config dict to YAML."""
config_path.parent.mkdir(parents=True, exist_ok=True)
header = "# Dremio CLI config — generated by 'dremio setup'\n# PAT is stored in plaintext. Keep this file private (mode 600).\n"
config_path.write_text(header + yaml.dump(raw, default_flow_style=False, sort_keys=False))
config_path.chmod(0o600)


def load_config(
config_path: Path | None = None,
*,
profile: str | None = None,
cli_token: str | None = None,
cli_project_id: str | None = None,
cli_uri: str | None = None,
) -> DrsConfig:
"""Load config with resolution order: CLI args > env vars > config file > defaults.
"""Load config with resolution order: CLI args > env vars > config file profile > defaults.

Authentication priority:
1. --token CLI arg
2. DREMIO_TOKEN / DREMIO_PAT env var
3. Config file pat/token field
3. Config file profile (selected via --profile / DREMIO_PROFILE / default_profile)
"""
# -- Config file (lowest priority) --
file_values: dict[str, Any] = {}
path = config_path or DEFAULT_CONFIG_PATH
if path.exists():
with path.open() as f:
raw = yaml.safe_load(f) or {}
file_values = {
"uri": raw.get("uri", raw.get("endpoint")),
"pat": raw.get("pat", raw.get("token")),
"project_id": raw.get("project_id", raw.get("projectId")),
}
file_values = {k: v for k, v in file_values.items() if v is not None}
raw = read_config_file(config_path)
# Profile selection: CLI --profile > DREMIO_PROFILE env var > default_profile in file
effective_profile = profile or os.environ.get("DREMIO_PROFILE")
file_values = _resolve_profile(raw, effective_profile)

# -- Env vars (override file) --
env_values: dict[str, Any] = {}
Expand Down
5 changes: 5 additions & 0 deletions src/drs/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from drs.client import DremioClient
from drs.commands import (
chat,
context,
engine,
folder,
grant,
Expand Down Expand Up @@ -68,6 +69,7 @@
app.add_typer(grant.app, name="grant")
app.add_typer(project.app, name="project")
app.add_typer(chat.app, name="chat")
app.add_typer(context.app, name="context")
app.command("setup")(setup.setup_command)

# Global state for config
Expand All @@ -79,6 +81,7 @@
def main(
ctx: typer.Context,
config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
profile: str | None = typer.Option(None, "--profile", "-p", help="Named profile to use from config file"),
token: str | None = typer.Option(None, "--token", help="Dremio personal access token (PAT)"),
project_id: str | None = typer.Option(None, "--project-id", help="Dremio Cloud project ID"),
uri: str | None = typer.Option(
Expand Down Expand Up @@ -110,6 +113,7 @@ def main(
global _cli_opts
_cli_opts = {
"config_path": Path(config) if config else None,
"profile": profile,
"cli_token": token,
"cli_project_id": project_id,
"cli_uri": uri,
Expand All @@ -126,6 +130,7 @@ def get_config() -> DrsConfig:
try:
_config = load_config(
_cli_opts.get("config_path"),
profile=_cli_opts.get("profile"),
cli_token=_cli_opts.get("cli_token"),
cli_project_id=_cli_opts.get("cli_project_id"),
cli_uri=_cli_opts.get("cli_uri"),
Expand Down
113 changes: 113 additions & 0 deletions src/drs/commands/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#
# Copyright (C) 2017-2026 Dremio Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""dremio context — switch between named profiles."""

from __future__ import annotations

from pathlib import Path

import typer
from rich.console import Console
from rich.table import Table

from drs.auth import (
DEFAULT_CONFIG_PATH,
DEFAULT_URI,
get_default_profile_name,
list_profiles,
set_default_profile,
)

app = typer.Typer(
help="Manage named configuration profiles.",
context_settings={"help_option_names": ["-h", "--help"]},
)

console = Console()
err_console = Console(stderr=True)


def _get_config_path(ctx: typer.Context) -> Path:
"""Resolve config path from the global --config flag."""
if ctx.obj and ctx.obj.get("config_path"):
return ctx.obj["config_path"]
return DEFAULT_CONFIG_PATH


@app.command("list")
def cli_list(ctx: typer.Context) -> None:
"""List all configured profiles."""
config_path = _get_config_path(ctx)
profiles = list_profiles(config_path)

if not profiles:
err_console.print(
"[yellow]No profiles configured.[/yellow]\nRun [bold cyan]dremio setup[/bold cyan] to create one."
)
raise typer.Exit(1)

default_name = get_default_profile_name(config_path)

table = Table(show_header=True, header_style="bold")
table.add_column("")
table.add_column("Profile")
table.add_column("Region")
table.add_column("Project ID")

for name, values in profiles.items():
is_active = name == default_name
marker = "*" if is_active else " "
uri = values.get("uri", DEFAULT_URI)
region = "EU" if "eu.dremio" in uri else "US"
project_id = values.get("project_id", "—")
style = "bold" if is_active else ""
table.add_row(marker, name, region, project_id, style=style)

console.print(table)


@app.command("use")
def cli_use(
ctx: typer.Context,
name: str = typer.Argument(help="Profile name to set as default"),
) -> None:
"""Switch the default profile."""
config_path = _get_config_path(ctx)
try:
set_default_profile(name, config_path)
except ValueError as exc:
err_console.print(f"[red]{exc}[/red]")
profiles = list_profiles(config_path)
if profiles:
err_console.print(f"Available profiles: {', '.join(profiles)}")
raise typer.Exit(1)

console.print(f"Switched to profile [bold]{name}[/bold].")


@app.command("current")
def cli_current(ctx: typer.Context) -> None:
"""Show the active (default) profile name."""
config_path = _get_config_path(ctx)
name = get_default_profile_name(config_path)

if not name:
err_console.print(
"[yellow]No profiles configured.[/yellow]\nRun [bold cyan]dremio setup[/bold cyan] to create one."
)
raise typer.Exit(1)

console.print(name)
59 changes: 47 additions & 12 deletions src/drs/commands/reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,42 @@ async def create(client: DremioClient, path: str, rtype: str, display_fields: li
raise handle_api_error(exc) from exc


async def list_reflections(client: DremioClient, path: str) -> dict:
"""List reflections on a dataset via sys.project.reflections."""
parts = parse_path(path)
try:
entity = await client.get_catalog_by_path(parts)
except httpx.HTTPStatusError as exc:
raise handle_api_error(exc) from exc
dataset_id = entity["id"]
sql = f"SELECT * FROM sys.project.reflections WHERE dataset_id = '{dataset_id}'"
async def list_reflections(
client: DremioClient,
path: str | None = None,
*,
rtype: str | None = None,
status: str | None = None,
dataset_name: str | None = None,
limit: int | None = None,
) -> dict:
"""List reflections via sys.project.reflections.

When *path* is given, only reflections for that dataset are returned.
When omitted, all reflections in the project are returned.
Optional filters narrow results by *rtype*, *status*, or *dataset_name*.
An optional *limit* caps the number of rows returned.
"""
sql = "SELECT * FROM sys.project.reflections"
conditions: list[str] = []
if path is not None:
parts = parse_path(path)
try:
entity = await client.get_catalog_by_path(parts)
except httpx.HTTPStatusError as exc:
raise handle_api_error(exc) from exc
dataset_id = entity["id"]
conditions.append(f"dataset_id = '{dataset_id}'")
if rtype is not None:
conditions.append(f"type = '{rtype.upper()}'")
if status is not None:
conditions.append(f"status = '{status.upper()}'")
if dataset_name is not None:
conditions.append(f"dataset_name ILIKE '%{dataset_name}%'")
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Escape dataset-name filter before SQL interpolation

The --dataset-name value is interpolated directly into a quoted SQL literal without escaping. Inputs containing a single quote (for example O'Reilly) generate invalid SQL and cause reflection list to fail, even though this flag is intended for free-form substring matching.

Useful? React with 👍 / 👎.

if conditions:
sql += " WHERE " + " AND ".join(conditions)
if limit is not None:
sql += f" LIMIT {limit}"
return await run_query(client, sql)


Expand Down Expand Up @@ -142,12 +169,20 @@ def cli_create(

@app.command("list")
def cli_list(
path: str = typer.Argument(help="Dot-separated dataset path"),
path: str = typer.Argument(None, help="Dot-separated dataset path (omit to list all reflections)"),
rtype: str = typer.Option(None, "--type", "-t", help="Filter by reflection type: raw or aggregation"),
status: str = typer.Option(None, "--status", "-s", help="Filter by status (e.g. CAN_ACCELERATE, FAILED, EXPIRED)"),
dataset_name: str = typer.Option(None, "--dataset-name", "-d", help="Filter by dataset name (substring match)"),
limit: int = typer.Option(None, "--limit", "-l", help="Maximum number of reflections to return"),
fmt: OutputFormat = typer.Option(OutputFormat.json, "--output", "-o", help="Output format"),
) -> None:
"""List all reflections defined on a dataset."""
"""List reflections. Shows all project reflections, or those for a specific dataset."""
client = _get_client()
_run_command(list_reflections(client, path), client, fmt)
_run_command(
list_reflections(client, path, rtype=rtype, status=status, dataset_name=dataset_name, limit=limit),
client,
fmt,
)


@app.command("get")
Expand Down
Loading
Loading