diff --git a/config/interpolators-ich1.yaml b/config/interpolators-ich1.yaml index db7c26bc..d016eea2 100644 --- a/config/interpolators-ich1.yaml +++ b/config/interpolators-ich1.yaml @@ -23,6 +23,7 @@ runs: checkpoint: https://service.meteoswiss.ch/mlstore#/experiments/602/runs/c30490b6ba064e4db03b430f3a2595ad config: resources/inference/configs/sgm-multidataset-forecaster-global-ich1-oper.yaml steps: 0/120/6 + label: stage_E extra_requirements: - git+https://github.com/ecmwf/anemoi-inference.git@e369b1a90313e9701db13f63364a467aa281cf36 extra_requirements: diff --git a/config/showcase-interpolators-ich1.yaml b/config/showcase-interpolators-ich1.yaml new file mode 100644 index 00000000..def7f896 --- /dev/null +++ b/config/showcase-interpolators-ich1.yaml @@ -0,0 +1,27 @@ +# Showcase config for interpolators-ich1. +# Includes the base config and overrides only the showcase section. +# +# .venv/bin/evalml showcase config/showcase-interpolators-ich1.yaml + +include: interpolators-ich1.yaml + +showcase: + params: + - T_2M + - SP_10M + - TOT_PREC + meteograms: + enabled: false + stations: [JUN] + animations: + enabled: true + domains: + - europe + - switzerland + - globe + speed: 10 # simulated hours per second + runs: + - Varda-Single + comparisons: + - left: Varda-Single + right: KENDA-CH1 diff --git a/src/data_input/__init__.py b/src/data_input/__init__.py index 403683d9..fb6422ad 100644 --- a/src/data_input/__init__.py +++ b/src/data_input/__init__.py @@ -64,7 +64,10 @@ def load_analysis_data_from_zarr( "PMSL": "msl", "TOT_PREC": "tp", } - tot_prec_string = "TOT_PREC_6H" if min(np.diff(steps)) == 6 else "TOT_PREC_1H" + _diffs = np.diff(steps) + tot_prec_string = ( + "TOT_PREC_6H" if len(_diffs) > 0 and min(_diffs) == 6 else "TOT_PREC_1H" + ) PARAMS_MAP_COSMO1 = { v: v.replace("TOT_PREC", tot_prec_string) for v in PARAMS_MAP_COSMO2.keys() } diff --git a/src/evalml/cli.py b/src/evalml/cli.py index 51a9ed45..51f20096 100644 --- a/src/evalml/cli.py +++ 
b/src/evalml/cli.py @@ -69,9 +69,24 @@ def generate_graph( click.echo(f"Graph saved to {output_file}") +def _deep_merge(base: dict, override: dict) -> dict: + """Recursively merge override into base. Override wins on conflicts.""" + result = dict(base) + for key, value in override.items(): + if key in result and isinstance(result[key], dict) and isinstance(value, dict): + result[key] = _deep_merge(result[key], value) + else: + result[key] = value + return result + + def load_yaml(path: Path) -> dict[str, Any]: with path.open("r") as f: - return yaml.safe_load(f) + data = yaml.safe_load(f) + if include := data.pop("include", None): + base = load_yaml(path.parent / include) + data = _deep_merge(base, data) + return data def workflow_options(func): diff --git a/src/evalml/config.py b/src/evalml/config.py index 3e89b239..3f58d675 100644 --- a/src/evalml/config.py +++ b/src/evalml/config.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Dict, List, Any, ClassVar, FrozenSet +from typing import Dict, List, Any, ClassVar, FrozenSet, Optional from pydantic import BaseModel, Field, RootModel, field_validator @@ -212,6 +212,12 @@ class BaselineItem(BaseModel): baseline: BaselineConfig +class RegionConfig(BaseModel): + """A custom map region defined by name, extent, and projection.""" + + name: str = Field(..., description="Name for the custom region (used as wildcard).") + + class DomainConfig(BaseModel): """A custom map domain defined by name, extent, and projection.""" @@ -228,6 +234,13 @@ class DomainConfig(BaseModel): model_config = {"extra": "forbid"} +class AnimationComparison(BaseModel): + """A side-by-side comparison animation between two runs.""" + + left: str = Field(..., description="Label of the run shown in the left panel.") + right: str = Field(..., description="Label of the run shown in the right panel.") + + class MeteogramConfig(BaseModel): """Configuration for meteogram generation.""" @@ -248,7 +261,7 @@ class AnimationsConfig(BaseModel): 
default=True, description="Whether to generate forecast animations (GIFs per param and region).", ) - domains: List[str | DomainConfig] = Field( + domains: List[str | RegionConfig] = Field( default=["globe", "europe", "switzerland"], description=( "Domains to generate animations for. Each entry is either a named domain " @@ -257,6 +270,25 @@ class AnimationsConfig(BaseModel): "[lon_min, lon_max, lat_min, lat_max], and optional 'projection'." ), ) + speed: float = Field( + default=10.0, + gt=0, + description="Animation playback speed in simulated hours per second.", + ) + runs: Optional[List[str]] = Field( + default=None, + description=( + "Labels of runs to generate individual animations for. " + "Defaults to all candidate runs when omitted." + ), + ) + comparisons: List[AnimationComparison] = Field( + default=[], + description=( + "Side-by-side two-panel comparison animations. Each entry specifies " + "the labels of the left and right panel runs." + ), + ) class ShowcaseConfig(BaseModel): diff --git a/src/plotting/compat.py b/src/plotting/compat.py index 7c4d14d3..bf2ad9e2 100644 --- a/src/plotting/compat.py +++ b/src/plotting/compat.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, timedelta from pathlib import Path import earthkit.data as ekd @@ -77,6 +77,60 @@ def load_state_from_grib( return state +def load_state_from_zarr( + zarr_root: Path, + reftime: datetime, + lead_time_hours: int, + params: list[str], + source_type: str = "analysis", +) -> dict: + """Load a single time step from a zarr source into the state dict used by StatePlotter. + + Parameters + ---------- + zarr_root: + Path to the zarr dataset. + reftime: + Forecast reference time (init time). + lead_time_hours: + Lead time in hours to load. + params: + List of parameter names (ICON convention, e.g. ``['U_10M', 'V_10M']``). + source_type: + ``'analysis'`` for truth zarrs (loads via ``load_analysis_data_from_zarr``), + ``'baseline'`` for baseline forecast zarrs. 
+ """ + from data_input import load_analysis_data_from_zarr, load_baseline_from_zarr + + steps = [lead_time_hours] + + if source_type == "analysis": + ds = load_analysis_data_from_zarr(zarr_root, reftime, steps, params) + ds_t = ds.isel(time=0) if "time" in ds.dims else ds.squeeze() + else: + ds = load_baseline_from_zarr(zarr_root, reftime, steps, params) + ds_t = ds.isel(lead_time=0) if "lead_time" in ds.dims else ds.squeeze() + + lat = ds_t.lat.values.flatten() + lon = ds_t.lon.values.flatten() + + hull = MultiPoint(list(zip(lon.tolist(), lat.tolist()))).convex_hull + state = { + "forecast_reference_time": reftime, + "valid_time": reftime + timedelta(hours=lead_time_hours), + "longitudes": lon, + "latitudes": lat, + "lam_envelope": gpd.GeoSeries([hull], crs="EPSG:4326"), + "fields": {}, + } + for param in params: + if param in ds_t.data_vars: + state["fields"][param] = ds_t[param].values.flatten() + else: + state["fields"][param] = np.full(lat.size, np.nan, dtype=float) + return state + + def load_state_from_raw( file: Path, paramlist: list[str] | None = None ) -> dict[str, np.ndarray | dict[str, np.ndarray]]: diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 4c715931..50113691 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -1,8 +1,47 @@ import pytest +from pathlib import Path +from evalml.cli import _deep_merge, load_yaml from evalml.config import ConfigModel +def test_deep_merge_override_wins(): + base = {"a": 1, "b": {"x": 1, "y": 2}} + override = {"b": {"y": 99}, "c": 3} + result = _deep_merge(base, override) + assert result == {"a": 1, "b": {"x": 1, "y": 99}, "c": 3} + + +def test_deep_merge_non_dict_override_replaces(): + base = {"a": {"x": 1}} + override = {"a": [1, 2, 3]} + result = _deep_merge(base, override) + assert result["a"] == [1, 2, 3] + + +def test_load_yaml_without_include(tmp_path): + f = tmp_path / "config.yaml" + f.write_text("a: 1\n") + assert load_yaml(f) == {"a": 1} + + +def 
test_load_yaml_include_merges_base(tmp_path): + base = tmp_path / "base.yaml" + base.write_text("a: 1\nb:\n x: 1\n y: 2\n") + + child = tmp_path / "child.yaml" + child.write_text("include: base.yaml\nb:\n y: 99\nc: 3\n") + + result = load_yaml(child) + assert result == {"a": 1, "b": {"x": 1, "y": 99}, "c": 3} + + +def test_load_yaml_include_validates_as_config_model(): + path = Path("config/showcase-interpolators-ich1.yaml") + data = load_yaml(path) + _ = ConfigModel.model_validate(data) + + def test_example_forecasters_config(example_forecasters_config): """Test that the example config loads correctly.""" diff --git a/workflow/Snakefile b/workflow/Snakefile index 2f9f3c36..4b0ee5c0 100644 --- a/workflow/Snakefile +++ b/workflow/Snakefile @@ -140,7 +140,7 @@ rule showcase_all: expand( rules.make_forecast_animation.output, init_time=[t.strftime("%Y%m%d%H%M") for t in REFTIMES], - run_id=CANDIDATES, + run_id=SHOWCASE_ANIMATION_RUN_IDS, param=SHOWCASE_PARAMS, region=list(SHOWCASE_REGIONS.keys()), showcase=EXPERIMENT_NAME, @@ -148,6 +148,18 @@ rule showcase_all: if config["showcase"]["animations"]["enabled"] else [] ), + ( + expand( + rules.make_comparison_animation.output, + init_time=[t.strftime("%Y%m%d%H%M") for t in REFTIMES], + comparison_id=[c["id"] for c in SHOWCASE_COMPARISONS], + param=SHOWCASE_PARAMS, + region=list(SHOWCASE_REGIONS.keys()), + showcase=EXPERIMENT_NAME, + ) + if config["showcase"]["animations"]["enabled"] and SHOWCASE_COMPARISONS + else [] + ), ( expand( rules.plot_meteogram.output, diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk index 599aa8d6..cbd3cbe9 100644 --- a/workflow/rules/common.smk +++ b/workflow/rules/common.smk @@ -332,3 +332,140 @@ RUN_CONFIGS = collect_all_runs() ENV_CONFIGS = collect_all_envs() BASELINE_CONFIGS = collect_all_baselines() EXPERIMENT_PARTICIPANTS = collect_experiment_participants() + + +# ============================================================================ +# Showcase animation helpers +# 
============================================================================ + + +def sanitize_label(label: str) -> str: + """Sanitize a run label for use as a path component.""" + import re as _re + + return _re.sub(r"[^a-zA-Z0-9_-]", "_", label) + + +def collect_zarr_sources() -> dict: + """Collect zarr-based sources (truth + baselines) keyed by their label. + + Returns a dict mapping label -> {root, step, total_hours, source_type}. + """ + sources = {} + + # Truth (analysis) + truth_cfg = config.get("truth", {}) + if truth_cfg and "root" in truth_cfg: + label = truth_cfg.get("label", "truth") + sources[label] = { + "root": truth_cfg["root"], + "step": 1, + "total_hours": 120, + "source_type": "analysis", + } + + # Baselines + for baseline_id, cfg in BASELINE_CONFIGS.items(): + label = cfg.get("label", baseline_id) + _, total, step = map(int, cfg.get("steps", "0/120/1").split("/")) + sources[label] = { + "root": cfg["root"], + "step": step, + "total_hours": total, + "source_type": "baseline", + } + + return sources + + +def _resolve_label(label: str) -> dict: + """Resolve a label to a source descriptor used in comparison entries. + + Returns a dict with: + type — ``'run'`` or ``'zarr'`` + run_id — present when type == 'run' + label — present when type == 'zarr' + step — time step in hours + """ + # ML runs (candidates and non-candidates such as nested forecasters) + for run_id, cfg in RUN_CONFIGS.items(): + if cfg.get("label") == label: + return { + "type": "run", + "run_id": run_id, + "step": int(cfg["steps"].split("/")[2]), + } + # Zarr sources (truth / baselines) + if label in ZARR_SOURCES: + z = ZARR_SOURCES[label] + return {"type": "zarr", "label": label, "step": z["step"]} + + available_runs = sorted( + {cfg.get("label") for cfg in RUN_CONFIGS.values() if cfg.get("label")} + ) + available_zarr = sorted(ZARR_SOURCES.keys()) + raise ValueError( + f"No source found with label {label!r}. " + f"ML run labels: {available_runs}. 
" + f"Zarr source labels: {available_zarr}." + ) + + +def label_to_run_id(label: str) -> str: + """Return the run_id for the given label (ML runs only). + + Searches both candidate and non-candidate runs (e.g. nested forecasters). + Raises ValueError if not found. + """ + for run_id, cfg in RUN_CONFIGS.items(): + if cfg.get("label") == label: + return run_id + available = sorted( + {cfg.get("label") for cfg in RUN_CONFIGS.values() if cfg.get("label")} + ) + raise ValueError( + f"No run found with label {label!r}. Available ML run labels: {available}" + ) + + +def parse_showcase_animation_runs() -> list: + """Return the run_ids to animate individually. + + If ``animations.runs`` is set in the showcase config, filter by those labels + (ML runs only; zarr sources have their own animation pipeline). + Otherwise return all candidate run_ids. + """ + labels = config.get("showcase", {}).get("animations", {}).get("runs") + if labels is None: + return list(collect_all_candidates().keys()) + return [label_to_run_id(label) for label in labels] + + +def parse_showcase_comparisons() -> list: + """Parse ``animations.comparisons`` from the showcase config. 
+ + Each returned entry has: + id — sanitised ``{left_label}_vs_{right_label}`` path component + left — source descriptor (type, run_id/label, step) + right — source descriptor (type, run_id/label, step) + """ + comparisons = ( + config.get("showcase", {}).get("animations", {}).get("comparisons", []) + ) + result = [] + for c in comparisons: + left_label = c["left"] + right_label = c["right"] + result.append( + { + "id": f"{sanitize_label(left_label)}_vs_{sanitize_label(right_label)}", + "left": _resolve_label(left_label), + "right": _resolve_label(right_label), + } + ) + return result + + +ZARR_SOURCES = collect_zarr_sources() +SHOWCASE_ANIMATION_RUN_IDS = parse_showcase_animation_runs() +SHOWCASE_COMPARISONS = parse_showcase_comparisons() diff --git a/workflow/rules/inference.smk b/workflow/rules/inference.smk index 87782c38..5a6f0a75 100644 --- a/workflow/rules/inference.smk +++ b/workflow/rules/inference.smk @@ -334,4 +334,6 @@ rule inference_execute: ' ) >{log} 2>&1 """ - # fmt: on + + +# fmt: on diff --git a/workflow/rules/plot.smk b/workflow/rules/plot.smk index 5b6a9fd8..c8c4d4cc 100644 --- a/workflow/rules/plot.smk +++ b/workflow/rules/plot.smk @@ -118,11 +118,6 @@ rule plot_forecast_frame: --param {wildcards.param} --leadtime {wildcards.leadtime} --region {wildcards.region} \ {params.region_extra} \ --accu {params.accu} - # interactive editing (needs to set localrule: True and use only one core) - # marimo edit {input.script} -- \ - # --input {params.grib_out_dir} --date {wildcards.init_time} --outfn {output[0]}\ - # --param {wildcards.param} --leadtime {wildcards.leadtime} --region {wildcards.region}\ - # --accu {params.accu}\ """ @@ -136,6 +131,11 @@ def get_leadtimes(wc): rule make_forecast_animation: + localrule: True + wildcard_constraints: + run_id="|".join(map(re.escape, RUN_CONFIGS.keys())), + param="|".join(map(re.escape, SHOWCASE_PARAMS)), + region="|".join(map(re.escape, SHOWCASE_REGIONS.keys())), input: lambda wc: expand( 
rules.plot_forecast_frame.output, @@ -153,8 +153,154 @@ rule make_forecast_animation: region="|".join(map(re.escape, SHOWCASE_REGIONS.keys())), localrule: True params: - delay=lambda wc: 10 * int(RUN_CONFIGS[wc.run_id]["steps"].split("/")[2]), + delay=lambda wc: round( + int(RUN_CONFIGS[wc.run_id]["steps"].split("/")[2]) + / config["showcase"]["animations"].get("speed", 10.0) + * 100 + ), + shell: + """ + convert -delay {params.delay} -loop 0 {input} {output} + """ + + +def _comparison_by_id(comparison_id: str) -> dict: + """Look up a SHOWCASE_COMPARISONS entry by its id wildcard.""" + for c in SHOWCASE_COMPARISONS: + if c["id"] == comparison_id: + return c + raise ValueError(f"No comparison with id {comparison_id!r}") + + +def _side_gif_path(side: dict, wc) -> list: + """Return the GIF path list for one side of a comparison (run or zarr).""" + if side["type"] == "run": + return expand( + rules.make_forecast_animation.output, + run_id=side["run_id"], + init_time=wc.init_time, + param=wc.param, + region=wc.region, + showcase=wc.showcase, + ) + else: + return expand( + rules.make_zarr_animation.output, + source_id=side["label"], + init_time=wc.init_time, + param=wc.param, + region=wc.region, + showcase=wc.showcase, + ) + + +def get_zarr_leadtimes(wc): + """Get lead times for a zarr source, always skipping lead time 0.""" + cfg = ZARR_SOURCES[wc.source_id] + step = cfg["step"] + total = cfg["total_hours"] + start = step # always skip lead time 0 (no meaningful accumulation at t=0) + return [f"{i:03}" for i in range(start, total + 1, step)] + + +rule plot_zarr_frame: + input: + script="workflow/scripts/plot_zarr_frame.py", + output: + OUT_ROOT + / "data/zarr/{source_id}/{init_time}/frames/frame_{leadtime}_{param}_{region}.png", + wildcard_constraints: + source_id="|".join(map(re.escape, ZARR_SOURCES.keys())) or "NEVER", + leadtime=r"\d+", + region="|".join(map(re.escape, SHOWCASE_REGIONS.keys())), + resources: + slurm_partition="postproc", + cpus_per_task=1, + 
runtime="10m", + params: + zarr_path=lambda wc: ZARR_SOURCES[wc.source_id]["root"], + source_type=lambda wc: ZARR_SOURCES[wc.source_id]["source_type"], + region_extra=lambda wc: ( + "--extent {} --projection {}".format( + " ".join(map(str, SHOWCASE_REGIONS[wc.region]["extent"])), + SHOWCASE_REGIONS[wc.region]["projection"], + ) + if SHOWCASE_REGIONS.get(wc.region, {}).get("extent") is not None + else "" + ), + accu=lambda wc: ZARR_SOURCES[wc.source_id]["step"], + shell: + """ + export ECCODES_DEFINITION_PATH=$(realpath .venv/share/eccodes-cosmo-resources/definitions) + python {input.script} \ + --zarr {params.zarr_path} \ + --source_type {params.source_type} \ + --date {wildcards.init_time} \ + --outfn {output} \ + --param {wildcards.param} \ + --leadtime {wildcards.leadtime} \ + --region {wildcards.region} \ + {params.region_extra} \ + --accu {params.accu} + """ + + +rule make_zarr_animation: + localrule: True + wildcard_constraints: + source_id="|".join(map(re.escape, ZARR_SOURCES.keys())) or "NEVER", + param="|".join(map(re.escape, SHOWCASE_PARAMS)), + region="|".join(map(re.escape, SHOWCASE_REGIONS.keys())), + input: + lambda wc: expand( + rules.plot_zarr_frame.output, + source_id=wc.source_id, + init_time=wc.init_time, + param=wc.param, + region=wc.region, + leadtime=get_zarr_leadtimes(wc), + ), + output: + OUT_ROOT + / "results/{showcase}/zarr/{source_id}/{init_time}/{init_time}_{param}_{region}.gif", + params: + delay=lambda wc: round( + ZARR_SOURCES[wc.source_id]["step"] + / config["showcase"]["animations"].get("speed", 10.0) + * 100 + ), shell: """ convert -delay {params.delay} -loop 0 {input} {output} """ + + +rule make_comparison_animation: + """Side-by-side two-panel animation comparing two sources, synced in simulated time.""" + localrule: True + wildcard_constraints: + param="|".join(map(re.escape, SHOWCASE_PARAMS)), + region="|".join(map(re.escape, SHOWCASE_REGIONS.keys())), + comparison_id="|".join(map(re.escape, [c["id"] for c in 
SHOWCASE_COMPARISONS])) + or "NEVER", + input: + left=lambda wc: _side_gif_path(_comparison_by_id(wc.comparison_id)["left"], wc), + right=lambda wc: _side_gif_path( + _comparison_by_id(wc.comparison_id)["right"], wc + ), + script="workflow/scripts/plot_combine_animations.py", + output: + OUT_ROOT + / "results/{showcase}/comparisons/{comparison_id}/{init_time}/{init_time}_{param}_{region}.gif", + params: + left_step=lambda wc: _comparison_by_id(wc.comparison_id)["left"]["step"], + right_step=lambda wc: _comparison_by_id(wc.comparison_id)["right"]["step"], + speed=config["showcase"]["animations"].get("speed", 10.0), + shell: + """ + python {input.script} \ + --left {input.left} --left_step {params.left_step} \ + --right {input.right} --right_step {params.right_step} \ + --output {output} \ + --speed {params.speed} + """ diff --git a/workflow/scripts/plot_combine_animations.py b/workflow/scripts/plot_combine_animations.py new file mode 100644 index 00000000..1c5d95ab --- /dev/null +++ b/workflow/scripts/plot_combine_animations.py @@ -0,0 +1,137 @@ +"""Combine two GIF animations side by side, synced in simulated time. + +Each GIF may have a different time step (e.g. 6h vs 1h). The output plays at +the finest resolution, holding frames from the coarser GIF steady while the +finer one advances. Both panels stay in sync with respect to simulated time. 
+ +Usage +----- + python plot_combine_animations.py \\ + --left left.gif --left_step 6 \\ + --right right.gif --right_step 1 \\ + --output comparison.gif \\ + [--speed 6] # simulated hours per second (default: 6) + [--total_hours 120] # total simulated hours covered (default: 120) + [--start_hour 1] # first simulated hour in the GIFs (default: min step) +""" + +import argparse +from pathlib import Path + +from PIL import Image, ImageSequence + + +def load_frames(path: str) -> list[Image.Image]: + im = Image.open(path) + frames = [] + for frame in ImageSequence.Iterator(im): + frames.append(frame.convert("RGBA")) + return frames + + +def frame_for_hour(frames: list[Image.Image], step: int, hour: int) -> Image.Image: + """Return the frame that is valid at the given simulated hour. + + Frames are assumed to start at ``step`` (i.e. frames[0] covers hour=step, + frames[1] covers hour=2*step, etc.). Hours before the first frame return + frames[0]; hours beyond the last frame return the last frame. 
+ """ + idx = max(0, min(len(frames) - 1, (hour - 1) // step)) + return frames[idx] + + +def combine( + left_path: str, + right_path: str, + out_path: str, + left_step: int, + right_step: int, + speed: float = 6.0, + total_hours: int = 120, + start_hour: int | None = None, +) -> None: + left_frames = load_frames(left_path) + right_frames = load_frames(right_path) + + out_step = min(left_step, right_step) + if start_hour is None: + start_hour = out_step + + sim_hours = list(range(start_hour, total_hours + 1, out_step)) + delay_ms = round(out_step / speed * 1000) + + w = left_frames[0].width + right_frames[0].width + h = max(left_frames[0].height, right_frames[0].height) + + out_frames = [] + for sh in sim_hours: + canvas = Image.new("RGBA", (w, h), (255, 255, 255, 255)) + lf = frame_for_hour(left_frames, left_step, sh) + rf = frame_for_hour(right_frames, right_step, sh) + canvas.paste(lf, (0, 0)) + canvas.paste(rf, (left_frames[0].width, 0)) + out_frames.append(canvas.convert("P", palette=Image.ADAPTIVE, colors=256)) + + out_frames[0].save( + out_path, + save_all=True, + append_images=out_frames[1:], + loop=0, + duration=delay_ms, + optimize=False, + ) + print( + f"Written {len(out_frames)} frames " + f"({out_step}h steps, {delay_ms}ms/frame, " + f"{speed} sim-h/s) → {out_path}" + ) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Combine two GIF animations side by side, synced in simulated time." 
+ ) + parser.add_argument("--left", required=True, help="Path to the left GIF") + parser.add_argument("--right", required=True, help="Path to the right GIF") + parser.add_argument("--output", required=True, help="Output GIF path") + parser.add_argument( + "--left_step", type=int, required=True, help="Time step of the left GIF (h)" + ) + parser.add_argument( + "--right_step", type=int, required=True, help="Time step of the right GIF (h)" + ) + parser.add_argument( + "--speed", + type=float, + default=6.0, + help="Animation speed in simulated hours per second (default: 6)", + ) + parser.add_argument( + "--total_hours", + type=int, + default=120, + help="Total simulated hours covered by the GIFs (default: 120)", + ) + parser.add_argument( + "--start_hour", + type=int, + default=None, + help="First simulated hour in the GIFs (default: min step size)", + ) + args = parser.parse_args() + + Path(args.output).parent.mkdir(parents=True, exist_ok=True) + combine( + left_path=args.left, + right_path=args.right, + out_path=args.output, + left_step=args.left_step, + right_step=args.right_step, + speed=args.speed, + total_hours=args.total_hours, + start_hour=args.start_hour, + ) + + +if __name__ == "__main__": + main() diff --git a/workflow/scripts/plot_zarr_frame.py b/workflow/scripts/plot_zarr_frame.py new file mode 100644 index 00000000..ce0c5f75 --- /dev/null +++ b/workflow/scripts/plot_zarr_frame.py @@ -0,0 +1,200 @@ +"""Plot a single forecast frame from a zarr source (truth or baseline). + +Analogous to plot_forecast_frame.mo.py but reads zarr instead of GRIB. +TOT_PREC disaggregation is handled by the data_input loading functions, +so no accumulation arithmetic is needed here. 
+ +Usage +----- + python plot_zarr_frame.py \\ + --zarr /path/to/data.zarr \\ + --source_type analysis \\ # or 'baseline' + --date 202503270600 \\ + --leadtime 006 \\ + --param T_2M \\ + --region switzerland \\ + --outfn /path/to/frame.png \\ + [--extent LON_MIN LON_MAX LAT_MIN LAT_MAX] \\ + [--projection orthographic] \\ + [--accu 1] +""" + +import logging +from argparse import ArgumentParser +from datetime import datetime +from pathlib import Path + +import cartopy.crs as ccrs +import numpy as np + +from plotting import DOMAINS, StatePlotter, get_projection +from plotting.colormap_defaults import CMAP_DEFAULTS +from plotting.compat import load_state_from_zarr + +LOG = logging.getLogger(__name__) +LOG_FMT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" +logging.basicConfig(level=logging.INFO, format=LOG_FMT) + + +def get_style(param, units_override=None, accu=1): + lookup = f"{param}_{accu}H" if param == "TOT_PREC" else param + cfg = CMAP_DEFAULTS[lookup] + import earthkit.plots as ekp + + units = units_override if units_override is not None else cfg.get("units", "") + return { + "style": ekp.styles.Style( + levels=cfg.get("bounds", cfg.get("levels", None)), + extend="both", + units=units, + colors=cfg.get("colors", None), + ), + "norm": cfg.get("norm", None), + "cmap": cfg.get("cmap", None), + "levels": cfg.get("levels", None), + "vmin": cfg.get("vmin", None), + "vmax": cfg.get("vmax", None), + "colors": cfg.get("colors", None), + } + + +def preprocess_field(param, state): + try: + import pint + + _ureg = pint.UnitRegistry() + + def _k_to_c(arr): + try: + return (_ureg.Quantity(arr, _ureg.kelvin).to(_ureg.degC)).magnitude + except Exception: + return arr - 273.15 + + def _ms_to_knots(arr): + try: + return ( + _ureg.Quantity(arr, _ureg.meter / _ureg.second).to(_ureg.knot) + ).magnitude + except Exception: + return arr * 1.943844 + + def _m_to_mm(arr): + try: + return (_ureg.Quantity(arr, _ureg.meter).to(_ureg.millimeter)).magnitude + except Exception: + 
return arr * 1000 + + except Exception: + LOG.warning("pint not available; using hardcoded conversions") + + def _k_to_c(arr): + return arr - 273.15 + + def _ms_to_knots(arr): + return arr * 1.943844 + + def _m_to_mm(arr): + return arr * 1000 + + fields = state["fields"] + if param in ("T_2M", "TD_2M", "T", "TD"): + return _k_to_c(fields[param]), "°C" + if param == "SP_10M": + return np.sqrt(fields["U_10M"] ** 2 + fields["V_10M"] ** 2), "m/s" + if param == "SP": + return np.sqrt(fields["U"] ** 2 + fields["V"] ** 2), "m/s" + if param == "TOT_PREC": + return np.maximum(_m_to_mm(fields[param]), 0), "mm" + return fields[param], None + + +def main(): + parser = ArgumentParser() + parser.add_argument("--zarr", type=str, required=True, help="Path to zarr dataset") + parser.add_argument( + "--source_type", + type=str, + default="analysis", + choices=["analysis", "baseline"], + help="Zarr source type", + ) + parser.add_argument( + "--date", type=str, required=True, help="Reference datetime (YYYYmmddHHMM)" + ) + parser.add_argument("--outfn", type=str, required=True, help="Output filename") + parser.add_argument( + "--leadtime", type=str, required=True, help="Lead time (hours, zero-padded)" + ) + parser.add_argument("--param", type=str, required=True, help="Parameter name") + parser.add_argument("--region", type=str, required=True, help="Region name") + parser.add_argument( + "--extent", + type=float, + nargs=4, + default=None, + metavar=("LON_MIN", "LON_MAX", "LAT_MIN", "LAT_MAX"), + ) + parser.add_argument("--projection", type=str, default=None) + parser.add_argument( + "--accu", type=int, default=1, help="Accumulation period in hours" + ) + args = parser.parse_args() + + reftime = datetime.strptime(args.date, "%Y%m%d%H%M") + lead_time_hours = int(args.leadtime) + outfn = Path(args.outfn) + param = args.param + + if param == "SP_10M": + paramlist = ["U_10M", "V_10M"] + elif param == "SP": + paramlist = ["U", "V"] + else: + paramlist = [param] + + state = 
load_state_from_zarr( + zarr_root=Path(args.zarr), + reftime=reftime, + lead_time_hours=lead_time_hours, + params=paramlist, + source_type=args.source_type, + ) + + plotter = StatePlotter(state["longitudes"], state["latitudes"], outfn.parent) + + if args.extent is not None: + projection = get_projection(args.projection or "orthographic") + extent = args.extent + else: + projection = DOMAINS[args.region]["projection"] + extent = DOMAINS[args.region]["extent"] + + fig = plotter.init_geoaxes( + nrows=1, + ncols=1, + projection=projection, + bbox=extent, + name=args.region, + size=(6, 6), + ) + subplot = fig.add_map(row=0, column=0) + + field, units_override = preprocess_field(param, state) + plotter.plot_field( + subplot, field, **get_style(param, units_override, accu=args.accu) + ) + subplot.ax.add_geometries( + state["lam_envelope"], + edgecolor="black", + facecolor="none", + crs=ccrs.PlateCarree(), + ) + + validtime = state["valid_time"].strftime("%Y%m%d%H%M") + fig.title(f"{param}, time: {validtime}") + fig.save(outfn, bbox_inches="tight", dpi=200) + LOG.info(f"saved: {outfn}") + + +if __name__ == "__main__": + main() diff --git a/workflow/tools/config.schema.json b/workflow/tools/config.schema.json index 566b54bd..a685a942 100644 --- a/workflow/tools/config.schema.json +++ b/workflow/tools/config.schema.json @@ -1,5 +1,26 @@ { "$defs": { + "AnimationComparison": { + "description": "A side-by-side comparison animation between two runs.", + "properties": { + "left": { + "description": "Label of the run shown in the left panel.", + "title": "Left", + "type": "string" + }, + "right": { + "description": "Label of the run shown in the right panel.", + "title": "Right", + "type": "string" + } + }, + "required": [ + "left", + "right" + ], + "title": "AnimationComparison", + "type": "object" + }, "AnimationsConfig": { "description": "Configuration for animation generation.", "properties": { @@ -22,12 +43,44 @@ "type": "string" }, { - "$ref": "#/$defs/DomainConfig" + 
"$ref": "#/$defs/RegionConfig" } ] }, "title": "Domains", "type": "array" + }, + "speed": { + "default": 10.0, + "description": "Animation playback speed in simulated hours per second.", + "exclusiveMinimum": 0, + "title": "Speed", + "type": "number" + }, + "runs": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Labels of runs to generate individual animations for. Defaults to all candidate runs when omitted.", + "title": "Runs" + }, + "comparisons": { + "default": [], + "description": "Side-by-side two-panel comparison animations. Each entry specifies the labels of the left and right panel runs.", + "items": { + "$ref": "#/$defs/AnimationComparison" + }, + "title": "Comparisons", + "type": "array" } }, "title": "AnimationsConfig", @@ -178,44 +231,6 @@ "title": "DefaultResources", "type": "object" }, - "DomainConfig": { - "additionalProperties": false, - "description": "A custom map domain defined by name, extent, and projection.", - "properties": { - "name": { - "description": "Name for the custom domain (used as wildcard).", - "title": "Name", - "type": "string" - }, - "extent": { - "anyOf": [ - { - "items": { - "type": "number" - }, - "type": "array" - }, - { - "type": "null" - } - ], - "default": null, - "description": "Geographic extent as [lon_min, lon_max, lat_min, lat_max] in PlateCarree coordinates. None means full globe.", - "title": "Extent" - }, - "projection": { - "default": "orthographic", - "description": "Projection name (must be a key in plotting._PROJECTIONS, e.g. 
'orthographic').", - "title": "Projection", - "type": "string" - } - }, - "required": [ - "name" - ], - "title": "DomainConfig", - "type": "object" - }, "ExperimentConfig": { "description": "Configuration for the experiment workflow outputs.", "properties": { @@ -638,6 +653,21 @@ "title": "Profile", "type": "object" }, + "RegionConfig": { + "description": "A custom map region defined by name, extent, and projection.", + "properties": { + "name": { + "description": "Name for the custom region (used as wildcard).", + "title": "Name", + "type": "string" + } + }, + "required": [ + "name" + ], + "title": "RegionConfig", + "type": "object" + }, "ShowcaseConfig": { "description": "Configuration for the showcase workflow.", "properties": {