diff --git a/.gitignore b/.gitignore index 05cc2d0..9147e40 100644 --- a/.gitignore +++ b/.gitignore @@ -156,6 +156,7 @@ activemq-data/ .envrc .venv .venv-rocm +.build/ env/ venv/ ENV/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..910ab56 --- /dev/null +++ b/README.md @@ -0,0 +1,27 @@ +# FusionAIHub (FAITH) + +## Frontier setup + +```bash +# 1. Clone to scratch +cd /lustre/orion/fus187/scratch/$USER +git clone git@github.com:PlasmaControl/FusionAIHub.git +cd FusionAIHub +git switch foundation_model + +# 2. Install pixi +curl -fsSL https://pixi.sh/install.sh | bash +source ~/.bashrc + +# 3. Install the Frontier env (~5 min) +pixi install -e frontier + +# 4. Build flash-attention 2 (~2-5 min) +pixi run -e frontier setup-flash-attn +``` + +## Other platforms + +- **NVIDIA/CUDA**: `pixi install` (default env), scripts in `scripts/slurm/` +- **della-milan (MI210)**: `bash scripts/slurm_della_milan/setup_rocm_env.sh`, + scripts in `scripts/slurm_della_milan/` diff --git a/docs/superpowers/specs/2026-05-11-e2e-stage1-file-open-profile-design.md b/docs/superpowers/specs/2026-05-11-e2e-stage1-file-open-profile-design.md new file mode 100644 index 0000000..b44a010 --- /dev/null +++ b/docs/superpowers/specs/2026-05-11-e2e-stage1-file-open-profile-design.md @@ -0,0 +1,150 @@ +# Profiling file-open cost for `train_e2e_stage1` on Frontier + +**Date:** 2026-05-11 +**Author:** nchen +**Status:** Design — approved, plan pending + +## Goal + +Measure the end-to-end file-open cost of an `e2e_stage1` training job on Frontier +(Lustre filesystem, ~8753 shot HDF5 files at `/lustre/orion/fus187/proj-shared/foundation_model`), +and decide whether it is a real problem that needs mitigation. + +## Background + +`scripts/training/train_e2e_stage1.py` uses +`tokamak_foundation_model.data.multi_file_dataset.TokamakMultiFileDataset` to read +single-shot HDF5 files. File opens happen in two distinct places: + +1.
**Startup indexing pass.** `_load_or_compute_lengths()` opens every shot HDF5 + sequentially to read its duration and compute a chunk count. Results are + cached to a `.pt` sidecar; subsequent runs short-circuit this entirely. +2. **Steady-state, during training.** Each DataLoader worker has its own LRU + cache of `h5py.File` handles, bounded by `max_open_files=1024`. Cache hits + are free; cold misses re-open with `h5py.File(path, "r", rdcc_nbytes=0)`. + Per-worker counters (`_prof_opens`, `_prof_hits`, `_prof_open_s`, + `_prof_close_s`, `_prof_getitem_s`) are already in place. + +Existing infrastructure we'll reuse: +- `scripts/profile_indexing.py` — times Phase 1. +- `scripts/slurm_frontier/profile_indexing.sh` — Frontier launcher for the above. +- `scripts/training/profile_stage1.py` — `torch.profiler` on the full train step. +- `scripts/training/probe_stage1_loading.py` — single-process `__getitem__` timing. + +Prior measurements (`logs/4555562_idx_profile.out`): +- 100-file run: 6.00 files/s, predicted ~33 min on full 8753. +- Two full-dataset attempts (jobs 4555563, 4558113) did **not** finish: the first + timed out at 1 h walltime, the second failed at 7 s (exit 1). +- `runs/lengths_cache_e2e_stage1/` is currently empty. + +## Scope + +**In:** +- Single Frontier job, one node, production training config (8 DDP ranks × + 4 workers/rank × batch 16, pulled from `scripts/slurm_frontier/train_e2e_stage1.sh`). +- Both phases: full-dataset indexing + ~200 steady-state training steps. +- A written verdict on whether file-open cost is acceptable or needs work. + +**Out:** +- Multi-node coordination measurements. +- Multiple worker-count sweeps (4 vs 8 vs 16). One config only. +- Lustre stripe-config experimentation. +- Changes to the production training script. 
+ +## Plan + +### Phase A — startup indexing (full dataset) + +Run `scripts/profile_indexing.py` with no file cap against the full data +directory, writing the lengths cache to `runs/lengths_cache_e2e_stage1/`. Walltime +budget **3 h** (the prior 1 h attempt timed out). + +Measurements: +- Total wall time, files/s, valid/skipped count, total chunks. + +Side benefit: populates the lengths cache so all future training jobs skip the +indexing wall entirely. + +### Phase B — steady-state opens during training + +Run a new thin script `scripts/training/profile_stage1_opens.py` that mirrors +the existing `scripts/training/profile_stage1.py` structure (imports +`build_configs`, `build_datasets`, `resolve_shot_files`, `compute_step_loss` +from `train_e2e_stage1.py` — no changes to the production script). + +Configuration to match production (`train_e2e_stage1.sh`): +- 8 DDP ranks per node, 1 GPU per rank, `--gpu-bind=closest`. +- 4 DataLoader workers per rank (32 workers total). +- `batch_size=16`, `chunk_duration_s=0.05`, `step_size_s=0.01`, `warmup_s=1.0`, + `prediction_horizon_s=0.05`, `d_model=256`, `n_layers=8`, `n_heads=8`. +- Reuse the lengths cache from Phase A. + +Run ~200 training steps. At the end, each worker dumps its profiling counters +(`_prof_opens`, `_prof_hits`, `_prof_open_s`, `_prof_close_s`, `_prof_getitem_s`, +`_prof_load_s`, `_prof_process_s`) to a per-worker JSON file in +`runs/profile_e2e_stage1_opens/per_worker/`. + +Rank 0 reads all per-worker JSONs after `dist.barrier()`, aggregates, and +writes `summary.json` plus a human-readable `report.md`. + +If the existing in-place stdout logging (every 50 calls) is sufficient +to extract these numbers from the SLURM log, the JSON dump can be skipped in +favor of a `parse_log.py` post-processor. We will pick whichever is simpler +during implementation; the spec does not lock in one approach. 
+ +### Putting them together + +Single launcher `scripts/slurm_frontier/profile_e2e_stage1_opens.sh`: +- `#SBATCH -t 03:00:00`, 1 node, account `fus187`. +- Runs Phase A first (CPU-only mode by calling the python script directly, + not via `srun`), then Phase B (via `srun -n 8 --gpu-bind=closest …`). +- Each phase writes to its own subdirectory under `runs/profile_e2e_stage1_opens/`. + +## Outputs + +All in `runs/profile_e2e_stage1_opens/`: + +- `indexing.log` — Phase A stdout: wall time, files/s, valid/skipped, total chunks. +- `per_worker/rank{R}_worker{W}.json` — raw per-worker counters from Phase B. +- `summary.json` — aggregated open counts / hit rate / open-wall across the + 32 workers; `__getitem__` time breakdown. +- `report.md` — synthesis and verdicts (see below). + +Side effect: `runs/lengths_cache_e2e_stage1/lengths_e2e_stage1_{train,val}.pt` +populated for future runs. + +## Verdict criteria (to include in `report.md`) + +| Question | Threshold | Source | +|---|---|---| +| Is full-dataset indexing tolerable? | < 30 min OK; 30–60 min worth pre-caching; > 60 min should be a permanent cache or restripe | Phase A wall time | +| Is the training loop open-bound? | Open-wall fraction of `__getitem__` < 5 % = good, 5–20 % = OK, > 20 % = bad | Phase B `_prof_open_s / _prof_getitem_s` | +| Is `max_open_files=1024` right-sized? | Hit rate > 95 % in steps 100–200 = fine; less = LRU churn | Phase B `_prof_hits / (_prof_hits + _prof_opens)` | +| Cold-start to first useful step | Indexing + warm-up; report as a number | Phase A + Phase B step-1 timing | + +Each verdict comes with a one-line recommendation: leave alone / pre-cache / +resize LRU / restripe / something else. + +## Expected back-of-envelope (sanity check) + +- 32 workers, 8753 files → ~274 files/worker. LRU=1024 means every worker fits + its slice — cold opens should happen at most once per file per worker. +- A pure `h5py.File()` open on Lustre is plausibly 20–100 ms (no duration + scan). 
At ~50 ms × 274 files = ~14 s of cold-open wall per worker, amortized + across the entire epoch. +- If the actual hit rate is much below 95 %, that's a red flag worth digging + into (DistributedSampler shard, `TwoLevelSampler` interaction, or per-worker + shard size larger than expected). +- Indexing throughput on Lustre is the dominant unknown. The prior 100-file + warm-cache extrapolation predicted 33 min but the full run timed out at 1 h, + so the true rate may be 2–4× slower than the small-N extrapolation suggested. + +## Open questions / decisions deferred to plan + +- Whether to dump counters via per-worker JSON files or parse the existing + stdout log (pick simpler at implementation time). +- Whether Phase A and Phase B share one SLURM job or run as two + `--dependency`-linked jobs (one job is simpler, picked here unless Phase A + is unstable enough to need re-runs). +- Whether to add an MPI broadcast of `__getitem__` step-1 timing for end-to-end + cold-start, or just report indexing wall + a single rank's step-1 time. 
diff --git a/pixi.lock b/pixi.lock index 1b49816..0915ca8 100644 --- a/pixi.lock +++ b/pixi.lock @@ -666,13 +666,16 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/omegaconf-2.3.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.6.2-h35e630c_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/packaging-26.2-pyhc364b38_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pip-26.1.1-pyh8b19718_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.11.15-hd63d673_0_cpython.conda - conda: https://conda.anaconda.org/conda-forge/noarch/python_abi-3.11-8_cp311.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pyyaml-6.0.3-py311h3778330_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.3-h853b02a_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/setuptools-82.0.1-pyh332efcf_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h366c992_103.conda - conda: https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.15.0-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2025c-hc9c84f9_1.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/wheel-0.47.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/yaml-0.2.5-h280c20c_3.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl @@ -754,7 +757,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/a0/60/429e9b1cb3fc651937727befe258ea24122d9663e4d5709a48c9cbfceecb/safetensors-0.7.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: 
https://files.pythonhosted.org/packages/09/7d/af933f0f6e0767995b4e2d705a0665e454d1c19402aa7e895de3951ebb04/scipy-1.17.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/bf/00/b8cc413748fb6383d1582e7cda51314f99743351c462a92dc690d5b5853b/sentry_sdk-2.59.0-py2.py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/9d/76/f789f7a86709c6b087c5a2f52f911838cad707cc613162401badc665acfe/setuptools-82.0.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c1/d4/59e74daffcb57a07668852eeeb6035af9f32cbfd7a1d2511f17d2fe6a738/smmap-5.0.3-py3-none-any.whl @@ -1854,7 +1856,7 @@ packages: - pypi: ./ name: faith version: 26.1.dev0 - sha256: a79a12427b966cbe89abbd4681f70365e3eb9940b4eb6d992b9980c7dc0667ca + sha256: aa80d437e54308cbff39c33a40977cc207bfe33afe044f10a0545121a7dad92b requires_dist: - einops>=0.8.2,<0.9 - h5py>=3.15.1,<4 @@ -5826,6 +5828,19 @@ packages: - trove-classifiers>=2024.10.12 ; extra == 'tests' - defusedxml ; extra == 'xmp' requires_python: '>=3.10' +- conda: https://conda.anaconda.org/conda-forge/noarch/pip-26.1.1-pyh8b19718_0.conda + sha256: 1bd94ef1ae08fd811ef3b26857e46ba460c7430bf1f3ccd94a4d6614fd619bd5 + md5: 35870d32aed92041d31cbb15e822dca3 + depends: + - python >=3.10,<3.13.0a0 + - setuptools + - wheel + license: MIT + license_family: MIT + purls: + - pkg:pypi/pip?source=hash-mapping + size: 1201616 + timestamp: 1777924080196 - pypi: https://files.pythonhosted.org/packages/cb/28/3bfe2fa5a7b9c46fe7e13c97bda14c895fb10fa2ebf1d0abb90e0cea7ee1/platformdirs-4.5.1-py3-none-any.whl name: platformdirs version: 4.5.1 @@ -7212,62 +7227,6 @@ packages: - importlib-metadata>=7.0.2 ; 
python_full_version < '3.10' and extra == 'type' - jaraco-develop>=7.21 ; sys_platform != 'cygwin' and extra == 'type' requires_python: '>=3.9' -- pypi: https://files.pythonhosted.org/packages/9d/76/f789f7a86709c6b087c5a2f52f911838cad707cc613162401badc665acfe/setuptools-82.0.1-py3-none-any.whl - name: setuptools - version: 82.0.1 - sha256: a59e362652f08dcd477c78bb6e7bd9d80a7995bc73ce773050228a348ce2e5bb - requires_dist: - - pytest>=6,!=8.1.* ; extra == 'test' - - virtualenv>=13.0.0 ; extra == 'test' - - wheel>=0.44.0 ; extra == 'test' - - pip>=19.1 ; extra == 'test' - - packaging>=24.2 ; extra == 'test' - - jaraco-envs>=2.2 ; extra == 'test' - - pytest-xdist>=3 ; extra == 'test' - - jaraco-path>=3.7.2 ; extra == 'test' - - build[virtualenv]>=1.0.3 ; extra == 'test' - - filelock>=3.4.0 ; extra == 'test' - - ini2toml[lite]>=0.14 ; extra == 'test' - - tomli-w>=1.0.0 ; extra == 'test' - - pytest-timeout ; extra == 'test' - - pytest-perf ; sys_platform != 'cygwin' and extra == 'test' - - jaraco-develop>=7.21 ; python_full_version >= '3.9' and sys_platform != 'cygwin' and extra == 'test' - - pytest-home>=0.5 ; extra == 'test' - - pytest-subprocess ; extra == 'test' - - pyproject-hooks!=1.1 ; extra == 'test' - - jaraco-test>=5.5 ; extra == 'test' - - sphinx>=3.5 ; extra == 'doc' - - jaraco-packaging>=9.3 ; extra == 'doc' - - rst-linker>=1.9 ; extra == 'doc' - - furo ; extra == 'doc' - - sphinx-lint ; extra == 'doc' - - jaraco-tidelift>=1.4 ; extra == 'doc' - - pygments-github-lexers==0.0.5 ; extra == 'doc' - - sphinx-favicon ; extra == 'doc' - - sphinx-inline-tabs ; extra == 'doc' - - sphinx-reredirects ; extra == 'doc' - - sphinxcontrib-towncrier ; extra == 'doc' - - sphinx-notfound-page>=1,<2 ; extra == 'doc' - - pyproject-hooks!=1.1 ; extra == 'doc' - - towncrier<24.7 ; extra == 'doc' - - packaging>=24.2 ; extra == 'core' - - more-itertools>=8.8 ; extra == 'core' - - jaraco-text>=3.7 ; extra == 'core' - - importlib-metadata>=6 ; python_full_version < '3.10' and extra 
== 'core' - - tomli>=2.0.1 ; python_full_version < '3.11' and extra == 'core' - - wheel>=0.43.0 ; extra == 'core' - - jaraco-functools>=4 ; extra == 'core' - - more-itertools ; extra == 'core' - - pytest-checkdocs>=2.4 ; extra == 'check' - - pytest-ruff>=0.2.1 ; sys_platform != 'cygwin' and extra == 'check' - - ruff>=0.13.0 ; sys_platform != 'cygwin' and extra == 'check' - - pytest-cov ; extra == 'cover' - - pytest-enabler>=2.2 ; extra == 'enabler' - - pytest-mypy ; extra == 'type' - - mypy==1.18.* ; extra == 'type' - - importlib-metadata>=7.0.2 ; python_full_version < '3.10' and extra == 'type' - - jaraco-develop>=7.21 ; sys_platform != 'cygwin' and extra == 'type' - requires_python: '>=3.9' - conda: https://conda.anaconda.org/conda-forge/noarch/setuptools-82.0.0-pyh332efcf_0.conda sha256: fd7201e38e38bf7f25818d624ca8da97b8998957ca9ae3fb7fdc9c17e6b25fcd md5: 1d00d46c634177fc8ede8b99d6089239 @@ -7279,6 +7238,17 @@ packages: - pkg:pypi/setuptools?source=compressed-mapping size: 637506 timestamp: 1770634745653 +- conda: https://conda.anaconda.org/conda-forge/noarch/setuptools-82.0.1-pyh332efcf_0.conda + sha256: 82088a6e4daa33329a30bc26dc19a98c7c1d3f05c0f73ce9845d4eab4924e9e1 + md5: 8e194e7b992f99a5015edbd4ebd38efd + depends: + - python >=3.10 + license: MIT + license_family: MIT + purls: + - pkg:pypi/setuptools?source=hash-mapping + size: 639697 + timestamp: 1773074868565 - conda: https://conda.anaconda.org/conda-forge/noarch/sh-2.2.2-pyh707e725_1.conda sha256: 0346e6d30f96ebd4a4dec849dcfd644e6e09ad798f9fac76d6720896b07526f0 md5: 49190c42cea9458405140171fc02e847 @@ -8889,6 +8859,18 @@ packages: - markupsafe>=2.1.1 - watchdog>=2.3 ; extra == 'watchdog' requires_python: '>=3.9' +- conda: https://conda.anaconda.org/conda-forge/noarch/wheel-0.47.0-pyhd8ed1ab_0.conda + sha256: 9e156ffaefb8463437144326ada4b85d1de17961b9997ac5f1cbbaf747bd8bed + md5: d0e3b2f0030cf4fca58bde71d246e94c + depends: + - packaging >=24.0 + - python >=3.10 + license: MIT + license_family: MIT + 
purls: + - pkg:pypi/wheel?source=hash-mapping + size: 33491 + timestamp: 1776878563806 - pypi: https://files.pythonhosted.org/packages/3f/0e/fa3b193432cfc60c93b42f3be03365f5f909d2b3ea410295cf36df739e31/widgetsnbextension-4.0.15-py3-none-any.whl name: widgetsnbextension version: 4.0.15 diff --git a/pyproject.toml b/pyproject.toml index 0a17573..4ded3e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,15 @@ toksearch_d3d = { channel = "ga-fdp" } [tool.pixi.feature.frontier] platforms = ["linux-64"] +[tool.pixi.feature.frontier.dependencies] +# pip is needed for the `setup-flash-attn` task below to install flash-attn +# from a git URL with --no-build-isolation. The PyTorch wheels we pull from +# the rocm7.1 index don't drag pip in transitively. +pip = "*" +# ninja: aiter (a transitive dep of flash_attn on ROCm) JIT-compiles a small +# C++ extension at first `import flash_attn`. It calls `ninja` from PATH. +ninja = "*" + [tool.pixi.feature.frontier.pypi-dependencies] # rocm7.1 index ships torch 2.10.0 + torchvision 0.25-0.26 only. torch = { version = ">=2.10,<2.11", index = "https://download.pytorch.org/whl/rocm7.1" } @@ -82,6 +91,16 @@ torchvision = { version = ">=0.25,<0.27", index = "https://download.pytorch.or # torch 2.10 declares triton-rocm as a dep; uv won't auto-discover it # through the per-package `index = ...` above, so list it explicitly. triton-rocm = { version = "*", index = "https://download.pytorch.org/whl/rocm7.1" } +# Flash-Attention 2 (gfx90a / MI250X) is NOT listed here intentionally: +# the build needs `module load rocm/7.1.1` + `FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE`, +# which pixi/uv can't set. Install via the `setup-flash-attn` task below; we use +# the AMD Triton backend (not Composable Kernel) per the AMD docs at +# rocm.docs.amd.com/.../model-acceleration-libraries.html — Triton skips the +# multi-hour CK template/hipcc compile and builds in ~10-15 min. 
+ +[tool.pixi.feature.frontier.tasks] +setup-flash-attn = { cmd = "bash scripts/slurm_frontier/setup_frontier_env.sh", description = "Build & install flash-attn 2 into the frontier pixi env on a Frontier compute node (gfx90a). Auto-salloc's if run from a login node." } +verify-flash-attn = { cmd = "python scripts/slurm_frontier/verify_flash_attn.py", description = "Smoke-test flash_attn on the local MI250X." } [tool.pixi.environments] default = ["cuda"] diff --git a/scripts/build_dataset_cache.py b/scripts/build_dataset_cache.py new file mode 100755 index 0000000..4dfbdb6 --- /dev/null +++ b/scripts/build_dataset_cache.py @@ -0,0 +1,548 @@ +#!/usr/bin/env python3 +""" +CPU-only builder for the dataset indexing caches that ``train_e2e`` jobs +expect on disk. + +Runs the per-file HDF5 scans (video-presence + chunk-count) **in parallel** +via a process pool, then writes cache files in the exact format the +training runtime expects (``filter_video_present_files`` and +``_load_or_compute_lengths`` in ``multi_file_dataset.py``). Training itself +never spawns a process pool — the parallelism lives here on purpose, where +CUDA / NCCL are not initialised, so the ``fork`` foot-gun cannot bite. + +Usage: + # Quick smoke (10 files): + python scripts/build_dataset_cache.py --max_files 10 + + # Full pass, write cache to a known location: + python scripts/build_dataset_cache.py \ + --cache_dir /lustre/orion/fus187/proj-shared/foundation_model_meta + + # Don't write the cache (pure timing measurement): + python scripts/build_dataset_cache.py --no_cache + +CPU-only: imports torch only for cache I/O, never touches CUDA. Pure h5py + +numpy + multiprocessing for the scans. 
def _video_present_worker(args: tuple) -> Optional[str]:
    """Check one shot file for usable video from any requested camera.

    ``args`` is a ``(path, camera_names)`` pair, packed into a single tuple
    so the function can be mapped over a process pool. A camera qualifies
    when its group holds a non-empty 4-D ``ydata`` dataset together with an
    ``xdata`` dataset of at least two entries.

    Returns ``str(path)`` as soon as one camera qualifies, otherwise
    ``None``. Any exception raised while opening or inspecting the file is
    swallowed and reported as ``None`` (best-effort: a file that cannot be
    read is treated as having no video).
    """
    shot_path, wanted_cameras = args
    try:
        with h5py.File(shot_path, "r") as h5:
            for name in wanted_cameras:
                # Skip cameras that are absent or carry no frame dataset.
                if not (name in h5 and "ydata" in h5[name]):
                    continue
                frames = h5[name]["ydata"]
                times = h5[name].get("xdata")
                # Guard clauses keep the original short-circuit order:
                # frame checks first, then the time axis.
                if frames.size <= 0 or frames.ndim != 4:
                    continue
                if times is None or times.size < 2:
                    continue
                return str(shot_path)
    except Exception:
        return None
    return None
def _atomic_torch_save(payload: dict, cache_path: Path) -> None:
    """Atomically serialise ``payload`` to ``cache_path`` with ``torch.save``.

    The payload is first written to a sibling ``<cache_path>.tmp`` file and
    then moved over ``cache_path`` with ``Path.replace``, so a reader only
    ever sees the previous cache or the complete new one, never a
    half-written archive that the next ``torch.load`` would choke on.

    Args:
        payload: Dictionary to serialise.
        cache_path: Final location of the cache file. Missing parent
            directories are created.

    Raises:
        Exception: Whatever ``torch.save`` or the rename raises is propagated
            unchanged. The temporary file is removed first so a failed write
            does not leave a stale ``.tmp`` next to the cache.
    """
    cache_path = Path(cache_path)
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    tmp = Path(str(cache_path) + ".tmp")
    try:
        torch.save(payload, tmp)
        tmp.replace(cache_path)
    except BaseException:
        # Drop the partial temp file, then let the original error surface.
        try:
            tmp.unlink()
        except OSError:
            pass
        raise
+ + Writes a cache file in the same format as + ``multi_file_dataset.filter_video_present_files`` so training jobs + hit it transparently. + """ + paths_key = tuple(str(p) for p in paths) + cameras_key = tuple(sorted(camera_names)) + ctx = mp.get_context("forkserver") + tasks = [(p, camera_names) for p in paths] + video_present: List[str] = [] + with ProcessPoolExecutor(max_workers=num_workers, mp_context=ctx) as exc: + for result in tqdm( + exc.map(_video_present_worker, tasks, chunksize=8), + total=len(tasks), + desc=f"Video presence ({num_workers} workers)", + ): + if result is not None: + video_present.append(result) + if cache_path is not None: + _atomic_torch_save( + { + "paths_key": paths_key, + "cameras_key": cameras_key, + "video_present": video_present, + }, + cache_path, + ) + present = set(video_present) + return [p for p in paths if str(p) in present] + + +def parallel_lengths_scan( + paths: List[Path], + signal_configs: list, + movie_configs: list, + max_duration_s: float, + warmup_s: float, + chunk_duration_s: float, + prediction_horizon_s: float, + step_size_s: float, + prediction_mode: bool, + cache_path: Optional[Path], + num_workers: int, +) -> List[int]: + """Return per-file chunk counts in input order. 
def resolve_shot_files(
    data_dir: Path,
    max_files: Optional[int],
    val_fraction: float,
    seed: int,
) -> Tuple[List[Path], List[Path]]:
    """Split the shot files under ``data_dir`` into train and validation lists.

    Every ``*_processed.h5`` file directly inside ``data_dir`` is collected,
    sorted, and shuffled with a ``random.Random(seed)`` generator so the split
    is reproducible. The first ``max(1, int(val_fraction * n))`` shuffled
    files form the validation set and the remainder the training set. When
    ``max_files`` is given, training is truncated to ``max_files`` entries and
    validation to ``max(1, max_files // 4)``.

    This is meant to reproduce the no-YAML branch of
    ``train_e2e_stage1.resolve_shot_files`` so that caches built here are
    keyed on the same path lists training indexes.

    Returns:
        ``(train_files, val_files)``; both are empty when nothing matches.
    """
    shuffled = sorted(data_dir.glob("*_processed.h5"))
    if not shuffled:
        return [], []
    random.Random(seed).shuffle(shuffled)

    val_count = max(1, int(val_fraction * len(shuffled)))
    val_split, train_split = shuffled[:val_count], shuffled[val_count:]
    if max_files is None:
        return train_split, val_split

    val_cap = max(1, max_files // 4)
    return train_split[:max_files], val_split[:val_cap]
argparse.ArgumentParser( + description="Profile build_datasets indexing throughput (CPU-only)." + ) + ap.add_argument( + "--data_dir", type=Path, + default=Path("/lustre/orion/fus187/proj-shared/foundation_model"), + ) + ap.add_argument("--max_files", type=int, default=None, + help="Cap on training files (default: all). val_files is " + "max_files // 4 to mirror train_e2e_stage1.") + ap.add_argument("--val_fraction", type=float, default=0.1) + ap.add_argument("--seed", type=int, default=42) + ap.add_argument("--chunk_duration_s", type=float, default=0.05) + ap.add_argument("--prediction_horizon_s", type=float, default=0.05) + ap.add_argument("--step_size_s", type=float, default=0.01) + ap.add_argument("--warmup_s", type=float, default=1.0) + ap.add_argument("--cache_dir", type=Path, default=None, + help="Where to save the lengths cache. Default: a unique " + "tempdir, so every run is a cold cache miss (the point of " + "this profiler). Set to a stable path to persist the cache " + "for training jobs.") + ap.add_argument("--no_cache", action="store_true", + help="Skip writing the cache entirely.") + ap.add_argument("--diagnostic_names", type=str, default=None, + help="Comma-separated list. Default: stage1 diagnostics.") + ap.add_argument("--actuator_names", type=str, default=None, + help="Comma-separated list. Default: stage1 actuators.") + ap.add_argument("--skip_val", action="store_true", + help="Profile train indexing only.") + ap.add_argument( + "--use_video", nargs="*", default=[], + help="Camera names to require present (e.g. 'tangtv'). Must match the " + "training run's --use_video so the resulting lengths cache is keyed " + "on the same path list. Empty (default) skips the video filter.", + ) + ap.add_argument( + "--video_cache_dir", type=Path, default=None, + help="Where to write/read the video-presence cache. 
Defaults to " + "--cache_dir so the training run can reuse it.", + ) + ap.add_argument( + "--num_workers", type=int, + default=int(os.environ.get("INDEXING_WORKERS", "8")), + help="Process-pool size for the parallel HDF5 scans (default 8, " + "env override INDEXING_WORKERS). One worker per concurrent open; " + "bumping this raises Lustre MDS pressure linearly.", + ) + ap.add_argument( + "--max_duration_s", type=float, default=12.0, + help="Cap on shot duration used by the lengths arithmetic. Must " + "match TokamakMultiFileDataset's default for the cache to be a " + "drop-in for training.", + ) + ap.add_argument( + "--cache_name_prefix", type=str, default="lengths_e2e_stage1", + help="Filename prefix for the lengths cache. Defaults to " + "'lengths_e2e_stage1' (matches train_e2e_stage1.py's expected " + "cache name). Override for other stages, e.g. " + "'lengths_e2e_stage2_delta'. The lengths cache contents depend " + "on (paths, prediction_horizon_s, chunk_duration_s, step_size_s, " + "warmup_s) — stages with different windowing MUST use distinct " + "prefixes to avoid overwriting each other's cache.", + ) + args = ap.parse_args() + + if not args.data_dir.is_dir(): + raise SystemExit(f"data_dir not found: {args.data_dir}") + + diagnostic_names = ( + args.diagnostic_names.split(",") if args.diagnostic_names + else DEFAULT_DIAGNOSTICS + ) + actuator_names = ( + args.actuator_names.split(",") if args.actuator_names + else DEFAULT_ACTUATORS + ) + + logger.info(f"data_dir = {args.data_dir}") + logger.info(f"diagnostics = {diagnostic_names}") + logger.info(f"actuators = {actuator_names}") + logger.info( + f"chunk_duration_s={args.chunk_duration_s} " + f"prediction_horizon_s={args.prediction_horizon_s} " + f"step_size_s={args.step_size_s} warmup_s={args.warmup_s}" + ) + + train_files, val_files = resolve_shot_files( + args.data_dir, args.max_files, args.val_fraction, args.seed, + ) + logger.info(f"Resolved files — train: {len(train_files)} val: {len(val_files)}") + if 
not train_files: + raise SystemExit(f"No *_processed.h5 files matched {args.data_dir}") + + # Cache directory selection. + if args.no_cache: + cache_dir = None + logger.info("Cache: disabled (--no_cache)") + elif args.cache_dir is not None: + cache_dir = args.cache_dir + cache_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Cache dir: {cache_dir}") + else: + cache_dir = Path(tempfile.mkdtemp(prefix="build_dataset_cache_")) + logger.info(f"Cache dir (tempdir, cold-miss every run): {cache_dir}") + + # Apply video-presence filter BEFORE building the lengths cache so the + # stored `paths` key matches what training will see at run time. Without + # this, training (with --use_video) builds a smaller filtered list, the + # cache's `paths` check fails, and the pre-warm is wasted. + if args.use_video: + video_cache_dir = args.video_cache_dir or cache_dir + n_train_before = len(train_files) + n_val_before = len(val_files) + train_files = parallel_video_presence_scan( + paths=train_files, + camera_names=args.use_video, + cache_path=( + video_cache_dir / "video_present_train.pt" + if video_cache_dir else None + ), + num_workers=args.num_workers, + ) + val_files = parallel_video_presence_scan( + paths=val_files, + camera_names=args.use_video, + cache_path=( + video_cache_dir / "video_present_val.pt" + if video_cache_dir else None + ), + num_workers=args.num_workers, + ) + logger.info( + f"Video-presence filter ({args.use_video}): " + f"train {n_train_before} -> {len(train_files)}; " + f"val {n_val_before} -> {len(val_files)}" + ) + + train_cache = (cache_dir / f"{args.cache_name_prefix}_train.pt") if cache_dir else None + val_cache = (cache_dir / f"{args.cache_name_prefix}_val.pt") if cache_dir else None + + results = [] + results.append(time_indexing( + label="train", + files=train_files, + cache_path=train_cache, + chunk_duration_s=args.chunk_duration_s, + prediction_horizon_s=args.prediction_horizon_s, + step_size_s=args.step_size_s, + warmup_s=args.warmup_s, + 
max_duration_s=args.max_duration_s, + num_workers=args.num_workers, + )) + + if val_files and not args.skip_val: + results.append(time_indexing( + label="val", + files=val_files, + cache_path=val_cache, + chunk_duration_s=args.chunk_duration_s, + prediction_horizon_s=args.prediction_horizon_s, + step_size_s=args.step_size_s, + warmup_s=args.warmup_s, + max_duration_s=args.max_duration_s, + num_workers=args.num_workers, + )) + + # ─── Aggregate summary ─────────────────────────────────────────────── + total_files = sum(r["n_total"] for r in results) + total_skipped = sum(r["n_skipped"] for r in results) + total_chunks = sum(r["n_chunks"] for r in results) + total_wall = sum(r["wall_s"] for r in results) + overall_rate = (total_files / total_wall) if total_wall > 0 else float("inf") + + print() + print("=" * 68) + print(" INDEXING PROFILE SUMMARY") + print("=" * 68) + for r in results: + print( + f" {r['label']:<6} files={r['n_total']:<6} " + f"valid={r['n_valid']:<6} skipped={r['n_skipped']:<4} " + f"chunks={r['n_chunks']:<8} " + f"time={r['wall_s']:>7.2f}s rate={r['files_per_s']:>6.2f} files/s" + ) + print("-" * 68) + print( + f" {'TOTAL':<6} files={total_files:<6} " + f"valid={total_files - total_skipped:<6} " + f"skipped={total_skipped:<4} " + f"chunks={total_chunks:<8} " + f"time={total_wall:>7.2f}s rate={overall_rate:>6.2f} files/s" + ) + print("=" * 68) + + # Predicted full-dataset cost. + if args.max_files is not None: + # Estimate total dataset size by re-globbing without the cap. 
+ full_count = len(sorted(args.data_dir.glob("*_processed.h5"))) + if full_count > total_files and overall_rate > 0: + predicted = full_count / overall_rate + print( + f" Predicted full-dataset indexing ({full_count} files): " + f"{predicted:.0f}s = {predicted / 60:.1f} min" + ) + print() + + +if __name__ == "__main__": + main() diff --git a/scripts/data_preparation/make_processing_stats.py b/scripts/data_preparation/make_processing_stats.py index ef80aad..257735f 100644 --- a/scripts/data_preparation/make_processing_stats.py +++ b/scripts/data_preparation/make_processing_stats.py @@ -4,7 +4,7 @@ def main(): hdf5_files = sorted( - Path("/scratch/gpfs/EKOLEMEN/foundation_model/").glob("*_processed.h5") + Path("/lustre/orion/fus187/proj-shared/foundation_model").glob("*_processed.h5") ) all_signals = [ @@ -45,7 +45,7 @@ def main(): compute_preprocessing_stats( hdf5_paths=hdf5_files, signal_names=all_signals, - output_path="preprocessing_stats.pt", + output_path="/lustre/orion/fus187/proj-shared/foundation_model_meta/preprocessing_stats.pt", stft_signals=stft_signals, hdf5_key_map=hdf5_key_map, zero_is_missing_signals=zero_is_missing_signals, diff --git a/scripts/profile_indexing.py b/scripts/profile_indexing.py deleted file mode 100755 index e2af387..0000000 --- a/scripts/profile_indexing.py +++ /dev/null @@ -1,280 +0,0 @@ -#!/usr/bin/env python3 -""" -CPU-only profiler for the file-length indexing pass that train_e2e jobs do -in build_datasets(). - -Replicates train_e2e_stage1.py's resolve_shot_files() and dataset construction, -times only the indexing step, and reports total wall time and files/sec -throughput. Use this to: - - - Predict how long indexing will take on N files before launching training. - - Pre-populate the lengths cache so subsequent training jobs skip the wall. 
- -Usage: - # Quick smoke (10 files): - python scripts/profile_indexing.py --max_files 10 - - # Full pass, write cache to a known location: - python scripts/profile_indexing.py \ - --cache_dir runs/lengths_cache_e2e_stage1 - - # Don't write the cache (pure measurement): - python scripts/profile_indexing.py --no_cache - -CPU-only: imports torch but never touches CUDA. Pure h5py + numpy I/O on Lustre. -""" -import argparse -import logging -import random -import sys -import tempfile -import time -from pathlib import Path -from typing import List, Optional, Tuple - -# Make sure we can import the project package without installing. -PROJECT_ROOT = Path(__file__).resolve().parents[1] -sys.path.insert(0, str(PROJECT_ROOT / "src")) - -# These imports must come after the path tweak. Note: TokamakMultiFileDataset -# pulls in torch but only uses CPU paths during indexing. -from tokamak_foundation_model.data.multi_file_dataset import TokamakMultiFileDataset # noqa: E402 - -logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s") -logger = logging.getLogger("profile_indexing") - - -# Defaults match train_e2e_stage1.py's build_configs() for stage1. -DEFAULT_DIAGNOSTICS = [ - "ts_core_density", "ts_core_temp", "ts_tangential_density", - "ts_tangential_temp", "cer_ti", "cer_rot", "mse", "filterscopes", -] -DEFAULT_ACTUATORS = [ - "pin", "beam_voltage", "ech_power", "ech_tor_angle", "ech_pol_angle", - "ech_polarization", "gas_flow", "gas_raw", "rmp", -] - - -def resolve_shot_files( - data_dir: Path, - max_files: Optional[int], - val_fraction: float, - seed: int, -) -> Tuple[List[Path], List[Path]]: - """Mirror train_e2e_stage1.resolve_shot_files for the no-YAML branch. - - Identical seeding and split logic so the returned file lists are byte-for- - byte the same as what training would index. 
- """ - rng = random.Random(seed) - all_files = sorted(data_dir.glob("*_processed.h5")) - rng.shuffle(all_files) - n = len(all_files) - if n == 0: - return [], [] - n_val = max(1, int(val_fraction * n)) - val_files = all_files[:n_val] - train_files = all_files[n_val:] - if max_files is not None: - train_files = train_files[:max_files] - val_files = val_files[: max(1, max_files // 4)] - return train_files, val_files - - -def time_indexing( - label: str, - files: List[Path], - cache_path: Optional[Path], - chunk_duration_s: float, - prediction_horizon_s: float, - step_size_s: float, - warmup_s: float, - diagnostic_names: List[str], - actuator_names: List[str], -) -> dict: - """Build a TokamakMultiFileDataset and time only the indexing pass.""" - logger.info(f"[{label}] indexing {len(files)} files…") - t0 = time.perf_counter() - ds = TokamakMultiFileDataset( - files, - chunk_duration_s=chunk_duration_s, - prediction_mode=True, - prediction_horizon_s=prediction_horizon_s, - step_size_s=step_size_s, - warmup_s=warmup_s, - preprocessing_stats={}, - input_signals=diagnostic_names, - target_signals=diagnostic_names + actuator_names, - lengths_cache_path=cache_path, - ) - dt = time.perf_counter() - t0 - - n_total = len(files) - n_valid = len(ds._valid_indices) - n_skipped = n_total - n_valid - n_chunks = int(ds._cumulative_lengths[-1]) if n_valid > 0 else 0 - rate = (n_total / dt) if dt > 0 else float("inf") - - logger.info( - f"[{label}] {n_total} files in {dt:.2f}s " - f"({rate:.2f} files/s) " - f"valid={n_valid} skipped={n_skipped} total_chunks={n_chunks}" - ) - if cache_path is not None: - logger.info(f"[{label}] cache written: {cache_path}") - return dict( - label=label, - n_total=n_total, - n_valid=n_valid, - n_skipped=n_skipped, - n_chunks=n_chunks, - wall_s=dt, - files_per_s=rate, - cache_path=str(cache_path) if cache_path else None, - ) - - -def main(): - ap = argparse.ArgumentParser( - description="Profile build_datasets indexing throughput (CPU-only)." 
- ) - ap.add_argument( - "--data_dir", type=Path, - default=Path("/lustre/orion/fus187/proj-shared/foundation_model"), - ) - ap.add_argument("--max_files", type=int, default=None, - help="Cap on training files (default: all). val_files is " - "max_files // 4 to mirror train_e2e_stage1.") - ap.add_argument("--val_fraction", type=float, default=0.1) - ap.add_argument("--seed", type=int, default=42) - ap.add_argument("--chunk_duration_s", type=float, default=0.05) - ap.add_argument("--prediction_horizon_s", type=float, default=0.05) - ap.add_argument("--step_size_s", type=float, default=0.01) - ap.add_argument("--warmup_s", type=float, default=1.0) - ap.add_argument("--cache_dir", type=Path, default=None, - help="Where to save the lengths cache. Default: a unique " - "tempdir, so every run is a cold cache miss (the point of " - "this profiler). Set to a stable path to persist the cache " - "for training jobs.") - ap.add_argument("--no_cache", action="store_true", - help="Skip writing the cache entirely.") - ap.add_argument("--diagnostic_names", type=str, default=None, - help="Comma-separated list. Default: stage1 diagnostics.") - ap.add_argument("--actuator_names", type=str, default=None, - help="Comma-separated list. 
Default: stage1 actuators.") - ap.add_argument("--skip_val", action="store_true", - help="Profile train indexing only.") - args = ap.parse_args() - - if not args.data_dir.is_dir(): - raise SystemExit(f"data_dir not found: {args.data_dir}") - - diagnostic_names = ( - args.diagnostic_names.split(",") if args.diagnostic_names - else DEFAULT_DIAGNOSTICS - ) - actuator_names = ( - args.actuator_names.split(",") if args.actuator_names - else DEFAULT_ACTUATORS - ) - - logger.info(f"data_dir = {args.data_dir}") - logger.info(f"diagnostics = {diagnostic_names}") - logger.info(f"actuators = {actuator_names}") - logger.info( - f"chunk_duration_s={args.chunk_duration_s} " - f"prediction_horizon_s={args.prediction_horizon_s} " - f"step_size_s={args.step_size_s} warmup_s={args.warmup_s}" - ) - - train_files, val_files = resolve_shot_files( - args.data_dir, args.max_files, args.val_fraction, args.seed, - ) - logger.info(f"Resolved files — train: {len(train_files)} val: {len(val_files)}") - if not train_files: - raise SystemExit(f"No *_processed.h5 files matched {args.data_dir}") - - # Cache directory selection. 
- if args.no_cache: - cache_dir = None - logger.info("Cache: disabled (--no_cache)") - elif args.cache_dir is not None: - cache_dir = args.cache_dir - cache_dir.mkdir(parents=True, exist_ok=True) - logger.info(f"Cache dir: {cache_dir}") - else: - cache_dir = Path(tempfile.mkdtemp(prefix="profile_indexing_")) - logger.info(f"Cache dir (tempdir, cold-miss every run): {cache_dir}") - - train_cache = (cache_dir / "lengths_e2e_stage1_train.pt") if cache_dir else None - val_cache = (cache_dir / "lengths_e2e_stage1_val.pt") if cache_dir else None - - results = [] - results.append(time_indexing( - label="train", - files=train_files, - cache_path=train_cache, - chunk_duration_s=args.chunk_duration_s, - prediction_horizon_s=args.prediction_horizon_s, - step_size_s=args.step_size_s, - warmup_s=args.warmup_s, - diagnostic_names=diagnostic_names, - actuator_names=actuator_names, - )) - - if val_files and not args.skip_val: - results.append(time_indexing( - label="val", - files=val_files, - cache_path=val_cache, - chunk_duration_s=args.chunk_duration_s, - prediction_horizon_s=args.prediction_horizon_s, - step_size_s=args.step_size_s, - warmup_s=args.warmup_s, - diagnostic_names=diagnostic_names, - actuator_names=actuator_names, - )) - - # ─── Aggregate summary ─────────────────────────────────────────────── - total_files = sum(r["n_total"] for r in results) - total_skipped = sum(r["n_skipped"] for r in results) - total_chunks = sum(r["n_chunks"] for r in results) - total_wall = sum(r["wall_s"] for r in results) - overall_rate = (total_files / total_wall) if total_wall > 0 else float("inf") - - print() - print("=" * 68) - print(" INDEXING PROFILE SUMMARY") - print("=" * 68) - for r in results: - print( - f" {r['label']:<6} files={r['n_total']:<6} " - f"valid={r['n_valid']:<6} skipped={r['n_skipped']:<4} " - f"chunks={r['n_chunks']:<8} " - f"time={r['wall_s']:>7.2f}s rate={r['files_per_s']:>6.2f} files/s" - ) - print("-" * 68) - print( - f" {'TOTAL':<6} files={total_files:<6} " - 
f"valid={total_files - total_skipped:<6} " - f"skipped={total_skipped:<4} " - f"chunks={total_chunks:<8} " - f"time={total_wall:>7.2f}s rate={overall_rate:>6.2f} files/s" - ) - print("=" * 68) - - # Predicted full-dataset cost. - if args.max_files is not None: - # Estimate total dataset size by re-globbing without the cap. - full_count = len(sorted(args.data_dir.glob("*_processed.h5"))) - if full_count > total_files and overall_rate > 0: - predicted = full_count / overall_rate - print( - f" Predicted full-dataset indexing ({full_count} files): " - f"{predicted:.0f}s = {predicted / 60:.1f} min" - ) - print() - - -if __name__ == "__main__": - main() diff --git a/scripts/slurm_rocm/setup_rocm_env.sh b/scripts/slurm_della_milan/setup_rocm_env.sh old mode 100755 new mode 100644 similarity index 93% rename from scripts/slurm_rocm/setup_rocm_env.sh rename to scripts/slurm_della_milan/setup_rocm_env.sh index 5f267f4..f99ed57 --- a/scripts/slurm_rocm/setup_rocm_env.sh +++ b/scripts/slurm_della_milan/setup_rocm_env.sh @@ -1,6 +1,7 @@ #!/bin/bash # Run this once on della-milan to create a ROCm venv for MI210 (gfx90a). -# Usage: bash scripts/slurm_rocm/setup_rocm_env.sh +# For OLCF Frontier (MI250X), use scripts/slurm_frontier/setup_frontier_env.sh instead. 
+# Usage: bash scripts/slurm_della_milan/setup_rocm_env.sh set -euo pipefail PROJECT_DIR=/scratch/gpfs/EKOLEMEN/nc1514/FusionAIHub diff --git a/scripts/slurm_rocm/submit_all.sh b/scripts/slurm_della_milan/submit_all.sh similarity index 100% rename from scripts/slurm_rocm/submit_all.sh rename to scripts/slurm_della_milan/submit_all.sh diff --git a/scripts/slurm_rocm/train_bes.sh b/scripts/slurm_della_milan/train_bes.sh similarity index 100% rename from scripts/slurm_rocm/train_bes.sh rename to scripts/slurm_della_milan/train_bes.sh diff --git a/scripts/slurm_rocm/train_bolo_raw.sh b/scripts/slurm_della_milan/train_bolo_raw.sh similarity index 100% rename from scripts/slurm_rocm/train_bolo_raw.sh rename to scripts/slurm_della_milan/train_bolo_raw.sh diff --git a/scripts/slurm_rocm/train_cer_rot.sh b/scripts/slurm_della_milan/train_cer_rot.sh similarity index 100% rename from scripts/slurm_rocm/train_cer_rot.sh rename to scripts/slurm_della_milan/train_cer_rot.sh diff --git a/scripts/slurm_rocm/train_cer_ti.sh b/scripts/slurm_della_milan/train_cer_ti.sh similarity index 100% rename from scripts/slurm_rocm/train_cer_ti.sh rename to scripts/slurm_della_milan/train_cer_ti.sh diff --git a/scripts/slurm_rocm/train_co2.sh b/scripts/slurm_della_milan/train_co2.sh similarity index 100% rename from scripts/slurm_rocm/train_co2.sh rename to scripts/slurm_della_milan/train_co2.sh diff --git a/scripts/slurm_rocm/train_ddp.sh b/scripts/slurm_della_milan/train_ddp.sh old mode 100755 new mode 100644 similarity index 97% rename from scripts/slurm_rocm/train_ddp.sh rename to scripts/slurm_della_milan/train_ddp.sh index 3e0fc83..2e099e6 --- a/scripts/slurm_rocm/train_ddp.sh +++ b/scripts/slurm_della_milan/train_ddp.sh @@ -1,7 +1,7 @@ #!/bin/bash # 2-GPU DDP launcher for ROCm on della-milan. 
# Usage: -# SIGNAL=ece bash scripts/slurm_rocm/train_ddp.sh +# SIGNAL=ece bash scripts/slurm_della_milan/train_ddp.sh # Env: # SIGNAL required signal name (matches MODEL_REGISTRY entry) # BATCH_SIZE per-GPU batch size (default: 4) diff --git a/scripts/slurm_rocm/train_e2e_stage1_ddp.sh b/scripts/slurm_della_milan/train_e2e_stage1_ddp.sh old mode 100755 new mode 100644 similarity index 98% rename from scripts/slurm_rocm/train_e2e_stage1_ddp.sh rename to scripts/slurm_della_milan/train_e2e_stage1_ddp.sh index c16ef94..4843c4f --- a/scripts/slurm_rocm/train_e2e_stage1_ddp.sh +++ b/scripts/slurm_della_milan/train_e2e_stage1_ddp.sh @@ -1,7 +1,7 @@ #!/bin/bash # 2-GPU DDP launcher for E2E Stage 1 on AMD MI210 (della-milan). # Usage: -# bash scripts/slurm_rocm/train_e2e_stage1_ddp.sh +# bash scripts/slurm_della_milan/train_e2e_stage1_ddp.sh # Env overrides: # GPUS (default: "0,1") # BATCH_SIZE (per-rank, default: 16) diff --git a/scripts/slurm_rocm/train_e2e_stage2_ddp.sh b/scripts/slurm_della_milan/train_e2e_stage2_ddp.sh old mode 100755 new mode 100644 similarity index 98% rename from scripts/slurm_rocm/train_e2e_stage2_ddp.sh rename to scripts/slurm_della_milan/train_e2e_stage2_ddp.sh index 2a23fa1..640011e --- a/scripts/slurm_rocm/train_e2e_stage2_ddp.sh +++ b/scripts/slurm_della_milan/train_e2e_stage2_ddp.sh @@ -1,7 +1,7 @@ #!/bin/bash # 2-GPU DDP launcher for E2E Stage 2 on AMD MI210 (della-milan). 
# Usage: -# bash scripts/slurm_rocm/train_e2e_stage2_ddp.sh +# bash scripts/slurm_della_milan/train_e2e_stage2_ddp.sh # Env overrides: # GPUS (default: "0,1") # BATCH_SIZE per-rank, (default: 8 — bf16 rollouts are heavier than stage1) diff --git a/scripts/slurm_rocm/train_e2e_stage2_delta_ddp.sh b/scripts/slurm_della_milan/train_e2e_stage2_delta_ddp.sh old mode 100755 new mode 100644 similarity index 97% rename from scripts/slurm_rocm/train_e2e_stage2_delta_ddp.sh rename to scripts/slurm_della_milan/train_e2e_stage2_delta_ddp.sh index cdc9983..bdeba56 --- a/scripts/slurm_rocm/train_e2e_stage2_delta_ddp.sh +++ b/scripts/slurm_della_milan/train_e2e_stage2_delta_ddp.sh @@ -1,6 +1,6 @@ #!/bin/bash # 2-GPU DDP launcher for E2E Stage 2_delta on AMD MI210. -# Usage: bash scripts/slurm_rocm/train_e2e_stage2_delta_ddp.sh +# Usage: bash scripts/slurm_della_milan/train_e2e_stage2_delta_ddp.sh # #SBATCH --job-name=e2e_stage2_delta_ddp_rocm #SBATCH --output=logs/%j_e2e_stage2_delta_ddp.out diff --git a/scripts/slurm_rocm/train_e2e_stage2_extended_ddp.sh b/scripts/slurm_della_milan/train_e2e_stage2_extended_ddp.sh similarity index 100% rename from scripts/slurm_rocm/train_e2e_stage2_extended_ddp.sh rename to scripts/slurm_della_milan/train_e2e_stage2_extended_ddp.sh diff --git a/scripts/slurm_rocm/train_e2e_stage3_ddp.sh b/scripts/slurm_della_milan/train_e2e_stage3_ddp.sh similarity index 100% rename from scripts/slurm_rocm/train_e2e_stage3_ddp.sh rename to scripts/slurm_della_milan/train_e2e_stage3_ddp.sh diff --git a/scripts/slurm_rocm/train_ece.sh b/scripts/slurm_della_milan/train_ece.sh similarity index 100% rename from scripts/slurm_rocm/train_ece.sh rename to scripts/slurm_della_milan/train_ece.sh diff --git a/scripts/slurm_rocm/train_filterscopes.sh b/scripts/slurm_della_milan/train_filterscopes.sh similarity index 100% rename from scripts/slurm_rocm/train_filterscopes.sh rename to scripts/slurm_della_milan/train_filterscopes.sh diff --git 
a/scripts/slurm_rocm/train_i_coil.sh b/scripts/slurm_della_milan/train_i_coil.sh similarity index 100% rename from scripts/slurm_rocm/train_i_coil.sh rename to scripts/slurm_della_milan/train_i_coil.sh diff --git a/scripts/slurm_rocm/train_ich.sh b/scripts/slurm_della_milan/train_ich.sh similarity index 100% rename from scripts/slurm_rocm/train_ich.sh rename to scripts/slurm_della_milan/train_ich.sh diff --git a/scripts/slurm_rocm/train_langmuir.sh b/scripts/slurm_della_milan/train_langmuir.sh similarity index 100% rename from scripts/slurm_rocm/train_langmuir.sh rename to scripts/slurm_della_milan/train_langmuir.sh diff --git a/scripts/slurm_rocm/train_mhr.sh b/scripts/slurm_della_milan/train_mhr.sh similarity index 100% rename from scripts/slurm_rocm/train_mhr.sh rename to scripts/slurm_della_milan/train_mhr.sh diff --git a/scripts/slurm_rocm/train_mirnov.sh b/scripts/slurm_della_milan/train_mirnov.sh similarity index 100% rename from scripts/slurm_rocm/train_mirnov.sh rename to scripts/slurm_della_milan/train_mirnov.sh diff --git a/scripts/slurm_rocm/train_mse.sh b/scripts/slurm_della_milan/train_mse.sh similarity index 100% rename from scripts/slurm_rocm/train_mse.sh rename to scripts/slurm_della_milan/train_mse.sh diff --git a/scripts/slurm_rocm/train_neutron_rate.sh b/scripts/slurm_della_milan/train_neutron_rate.sh similarity index 100% rename from scripts/slurm_rocm/train_neutron_rate.sh rename to scripts/slurm_della_milan/train_neutron_rate.sh diff --git a/scripts/slurm_rocm/train_sxr.sh b/scripts/slurm_della_milan/train_sxr.sh similarity index 100% rename from scripts/slurm_rocm/train_sxr.sh rename to scripts/slurm_della_milan/train_sxr.sh diff --git a/scripts/slurm_rocm/train_ts_core_density.sh b/scripts/slurm_della_milan/train_ts_core_density.sh similarity index 100% rename from scripts/slurm_rocm/train_ts_core_density.sh rename to scripts/slurm_della_milan/train_ts_core_density.sh diff --git a/scripts/slurm_rocm/train_ts_core_temp.sh 
b/scripts/slurm_della_milan/train_ts_core_temp.sh similarity index 100% rename from scripts/slurm_rocm/train_ts_core_temp.sh rename to scripts/slurm_della_milan/train_ts_core_temp.sh diff --git a/scripts/slurm_rocm/train_ts_tangential_density.sh b/scripts/slurm_della_milan/train_ts_tangential_density.sh similarity index 100% rename from scripts/slurm_rocm/train_ts_tangential_density.sh rename to scripts/slurm_della_milan/train_ts_tangential_density.sh diff --git a/scripts/slurm_rocm/train_ts_tangential_temp.sh b/scripts/slurm_della_milan/train_ts_tangential_temp.sh similarity index 100% rename from scripts/slurm_rocm/train_ts_tangential_temp.sh rename to scripts/slurm_della_milan/train_ts_tangential_temp.sh diff --git a/scripts/slurm_rocm/train_vib.sh b/scripts/slurm_della_milan/train_vib.sh similarity index 100% rename from scripts/slurm_rocm/train_vib.sh rename to scripts/slurm_della_milan/train_vib.sh diff --git a/scripts/slurm_frontier/_compare_profiles.py b/scripts/slurm_frontier/_compare_profiles.py new file mode 100755 index 0000000..67ac2f4 --- /dev/null +++ b/scripts/slurm_frontier/_compare_profiles.py @@ -0,0 +1,74 @@ +"""Diff two memory.json outputs from profile_stage1.py and print a table. + +Usage: + python _compare_profiles.py + +Prints rows: step_time_s, throughput_steps_per_s, peak_alloc_GB, +peak_reserved_GB. Each row has baseline value, treatment value, delta +(treatment - baseline), and ratio (treatment / baseline). Pure stdlib. 
+""" + +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path + + +def fmt(x: float | None) -> str: + if x is None: + return " n/a" + return f"{x:>7.3f}" + + +def main() -> int: + p = argparse.ArgumentParser() + p.add_argument("baseline", type=Path) + p.add_argument("treatment", type=Path) + args = p.parse_args() + + with args.baseline.open() as f: + base = json.load(f) + with args.treatment.open() as f: + treat = json.load(f) + + rows = [ + ("step_time_s", "active_mean_step_s", True), + ("throughput_steps_per_s", "throughput_steps_per_s", False), + ("peak_alloc_GB", "peak_alloc_GB", True), + ("peak_reserved_GB", "peak_reserved_GB", True), + ] + + print(f"baseline ({base.get('attn_impl')}): {args.baseline}") + print(f"treatment ({treat.get('attn_impl')}): {args.treatment}") + print() + print(f"{'metric':<24} {'baseline':>9} {'treatment':>10} {'delta':>9} {'ratio':>8}") + print("-" * 64) + for label, key, lower_is_better in rows: + b = base.get(key) + t = treat.get(key) + delta = (t - b) if (b is not None and t is not None) else None + ratio = (t / b) if (b not in (None, 0) and t is not None) else None + arrow = "" + if delta is not None: + if lower_is_better: + arrow = "↓" if delta < 0 else "↑" + else: + arrow = "↑" if delta > 0 else "↓" + print( + f"{label:<24} {fmt(b):>9} {fmt(t):>10} " + f"{fmt(delta):>9} {fmt(ratio):>8} {arrow}" + ) + print() + # Headline line for grep-friendly summary. 
+ b_step = base.get("active_mean_step_s") + t_step = treat.get("active_mean_step_s") + if b_step and t_step: + speedup = b_step / t_step + print(f"SUMMARY: {speedup:.2f}x speedup with {treat.get('attn_impl')}") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/slurm_frontier/_frontier_common.sh b/scripts/slurm_frontier/_frontier_common.sh deleted file mode 100755 index 554a4c5..0000000 --- a/scripts/slurm_frontier/_frontier_common.sh +++ /dev/null @@ -1,49 +0,0 @@ -# Frontier-common environment for ROCm DDP jobs. -# Source from every Frontier SLURM script BEFORE activating the venv. -# Sets modules, RCCL/NCCL knobs, MIOpen cache, and MASTER_ADDR/PORT. -# -# Frontier hardware reminders (see docs.olcf.ornl.gov): -# - 4x MI250X = 8 GCDs per node, each appears as a separate GPU. -# - HSN is Slingshot via libfabric/cxi; RCCL needs hsn0 + kdreg2. -# - MIOpen cache in $HOME is slow & contended; redirect to /tmp. - -# shellcheck shell=bash - -module load PrgEnv-gnu/8.7.0 -module load cpe/26.03 -module load rocm/7.1.1 -module load craype-accel-amd-gfx90a -export LD_LIBRARY_PATH="${CRAY_LD_LIBRARY_PATH}:${LD_LIBRARY_PATH:-}" - -# Pixi env activation (replaces the old conda env). One-time setup: -# pixi install -e frontier -# Each SLURM script then sources this file to get the env on PATH. 
-export PATH="$HOME/.pixi/bin:$PATH" -# shellcheck disable=SC1091,SC2046 -eval "$(pixi shell-hook -e frontier --manifest-path /lustre/orion/fus187/scratch/nchen/FusionAIHub/pyproject.toml)" - -# Performance / correctness knobs -export PYTORCH_ROCM_ARCH=gfx90a -export OMP_NUM_THREADS=1 -export PYTHONUNBUFFERED=1 -export HSA_FORCE_FINE_GRAIN_PCIE=1 - -# RCCL over Slingshot HSN -export NCCL_SOCKET_IFNAME=hsn0 -export NCCL_NET_GDR_LEVEL=3 -export FI_MR_CACHE_MONITOR=kdreg2 -export FI_CXI_DEFAULT_CQ_SIZE=131072 - -# MIOpen kernel cache: per-job, node-local -export MIOPEN_USER_DB_PATH="/tmp/${USER}-miopen-${SLURM_JOB_ID:-local}" -export MIOPEN_CUSTOM_CACHE_DIR="$MIOPEN_USER_DB_PATH" -mkdir -p "$MIOPEN_USER_DB_PATH" - -# Distributed master endpoint derived from SLURM allocation -if [ -n "${SLURM_NODELIST:-}" ]; then - MASTER_ADDR="$(scontrol show hostnames "$SLURM_NODELIST" | head -n1)" -else - MASTER_ADDR="127.0.0.1" -fi -export MASTER_ADDR -export MASTER_PORT="${MASTER_PORT:-29500}" diff --git a/scripts/slurm_frontier/_frontier_settings.sh b/scripts/slurm_frontier/_frontier_settings.sh new file mode 100755 index 0000000..4f3e5dd --- /dev/null +++ b/scripts/slurm_frontier/_frontier_settings.sh @@ -0,0 +1,39 @@ +# shellcheck shell=bash +# Sourced by every Frontier SLURM wrapper. Wrappers cd to the FusionAIHub +# repo root before sourcing, so $PWD = repo root here. 
+ +module load PrgEnv-gnu/8.7.0 +module load cpe/26.03 +module load rocm/7.1.1 +module load craype-accel-amd-gfx90a +export LD_LIBRARY_PATH="${CRAY_LD_LIBRARY_PATH}:${LD_LIBRARY_PATH}" + +PIXI_ENV="$PWD/.pixi/envs/frontier" +export PATH="${PIXI_ENV}/bin:${PATH}" +export LD_LIBRARY_PATH="${PIXI_ENV}/lib:${LD_LIBRARY_PATH}" +export CONDA_PREFIX="${PIXI_ENV}" + +# Performance / correctness knobs +export PYTORCH_ROCM_ARCH=gfx90a +export OMP_NUM_THREADS=1 +export PYTHONUNBUFFERED=1 +export HSA_FORCE_FINE_GRAIN_PCIE=1 + +# flash-attn 2 on ROCm: main_perf branch requires this at IMPORT time to +# take the Triton-AMD (aiter) path; otherwise it tries `flash_attn_2_cuda`. +export FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE + +# RCCL over Slingshot HSN +export NCCL_SOCKET_IFNAME=hsn0 +export NCCL_NET_GDR_LEVEL=3 +export FI_MR_CACHE_MONITOR=kdreg2 +export FI_CXI_DEFAULT_CQ_SIZE=131072 + +# MIOpen kernel cache: per-job, node-local +export MIOPEN_USER_DB_PATH="/tmp/${USER}-miopen-${SLURM_JOB_ID}" +export MIOPEN_CUSTOM_CACHE_DIR="$MIOPEN_USER_DB_PATH" +mkdir -p "$MIOPEN_USER_DB_PATH" + +# Distributed master endpoint +export MASTER_ADDR="$(scontrol show hostnames "$SLURM_NODELIST" | head -n1)" +export MASTER_PORT=29500 diff --git a/scripts/slurm_frontier/benchmark_attn_kernels.sh b/scripts/slurm_frontier/benchmark_attn_kernels.sh new file mode 100644 index 0000000..85cf63f --- /dev/null +++ b/scripts/slurm_frontier/benchmark_attn_kernels.sh @@ -0,0 +1,53 @@ +#!/bin/bash +# Kernel-level benchmark of attention implementations on MI250X. +# Sweeps head_dim x seq_len for 4 impls (flash_ext, sdpa_math, sdpa_flash, +# sdpa_auto). Sanity-checks whether flash-attn wins anywhere on Frontier +# before we commit to it for any production stage. 
+# +# Usage: +# sbatch scripts/slurm_frontier/benchmark_attn_kernels.sh +# +#SBATCH -A fus187 +#SBATCH -J attn_bench +#SBATCH -o logs/%j_attn_bench.out +#SBATCH -e logs/%j_attn_bench.err +#SBATCH -t 00:30:00 +#SBATCH -p batch +#SBATCH -q debug +#SBATCH -N 1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-task=1 +#SBATCH --gpu-bind=closest +#SBATCH --cpus-per-task=7 +set -uo pipefail + +PROJECT_DIR="${SLURM_SUBMIT_DIR:-$PWD}" +if [ ! -f "${PROJECT_DIR}/scripts/slurm_frontier/_frontier_settings.sh" ]; then + echo "ERROR: SLURM_SUBMIT_DIR (${PROJECT_DIR}) is not the repo root." >&2 + echo " cd into the FusionAIHub repo before sbatch." >&2 + exit 1 +fi +cd "${PROJECT_DIR}" +mkdir -p logs + +# shellcheck disable=SC1091 +source scripts/slurm_frontier/_frontier_settings.sh + +OUT_DIR="profile/${SLURM_JOB_ID}_attn_bench" +mkdir -p "$OUT_DIR" +echo "[bench] outputs -> $OUT_DIR" +echo "[bench] FLASH_ATTENTION_TRITON_AMD_ENABLE=${FLASH_ATTENTION_TRITON_AMD_ENABLE}" + +srun -N 1 -n 1 -c "$SLURM_CPUS_PER_TASK" \ + --gpus-per-task=1 --gpu-bind=closest \ + scripts/slurm_frontier/_srun_rank_wrapper.sh \ + scripts/training/benchmark_attn_kernels.py \ + --out_dir "$OUT_DIR" \ + --batch 4 \ + --n_heads 16 \ + --head_dims 32 64 128 \ + --seq_lens 32 128 512 2048 4096 \ + --dtype bf16 + +echo "" +echo "=== Done. Summary: $OUT_DIR/summary.md ===" diff --git a/scripts/slurm_frontier/build_dataset_cache.sh b/scripts/slurm_frontier/build_dataset_cache.sh new file mode 100644 index 0000000..c6b310b --- /dev/null +++ b/scripts/slurm_frontier/build_dataset_cache.sh @@ -0,0 +1,88 @@ +#!/bin/bash +# Frontier CPU-only launcher for scripts/build_dataset_cache.py. +# Builds the dataset indexing caches (video-presence + per-file chunk counts) +# in parallel so subsequent train_e2e jobs hit them at __init__ time and skip +# the indexing wall entirely. 
+# +# Usage: +# # Smoke (100 files): +# MAX_FILES=100 sbatch scripts/slurm_frontier/build_dataset_cache.sh +# +# # Full pass, persist cache for training jobs to reuse: +# sbatch scripts/slurm_frontier/build_dataset_cache.sh +# +# # Don't allocate a GPU node at all — source _frontier_settings.sh (which +# # activates the pixi `frontier` env) on a login or compute node and call +# # python directly: +# python scripts/build_dataset_cache.py --max_files 100 +# +# Common env overrides: +# MAX_FILES= # cap on training files (default: unset = all) +# DATA_DIR= # override data root +# CACHE_DIR= # where to write the indexing caches (default: +# # /lustre/orion/fus187/proj-shared/foundation_model_meta, +# # matches the train_e2e_stage1.py default so +# # subsequent training jobs reuse the cache) +# NO_CACHE=1 # skip cache write (pure timing measurement) +# +#SBATCH -A fus187 +#SBATCH -J build_dataset_cache +#SBATCH -o logs/%j_build_dataset_cache.out +#SBATCH -e logs/%j_build_dataset_cache.err +#SBATCH -t 0:30:00 +#SBATCH -p batch +#SBATCH -N 1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-task=0 +#SBATCH --cpus-per-task=16 +set -uo pipefail + +# SLURM stages the submit script under /var/spool/slurmd/... so BASH_SOURCE +# is useless for locating the repo. Use SLURM_SUBMIT_DIR — submit from the +# repo root: `cd && sbatch scripts/slurm_frontier/build_dataset_cache.sh`. +PROJECT_DIR="${SLURM_SUBMIT_DIR:-$PWD}" +if [ ! -f "${PROJECT_DIR}/scripts/slurm_frontier/_frontier_settings.sh" ]; then + echo "ERROR: SLURM_SUBMIT_DIR (${PROJECT_DIR}) is not the repo root." >&2 + echo " cd into the FusionAIHub repo before sbatch." 
>&2 + exit 1 +fi +cd "${PROJECT_DIR}" +mkdir -p logs + +# shellcheck disable=SC1091 +source scripts/slurm_frontier/_frontier_settings.sh + +DATA_DIR="${DATA_DIR:-/lustre/orion/fus187/proj-shared/foundation_model}" +CACHE_DIR="${CACHE_DIR:-/lustre/orion/fus187/proj-shared/foundation_model_meta}" +# Must mirror train_e2e_stage1.sh's --use_video so the produced lengths cache +# is keyed on the same (post-filter) path list training will see. Set empty +# to skip the filter — but then the cache won't be reusable by --use_video +# training runs. +USE_VIDEO="${USE_VIDEO:-tangtv}" + +MAX_FILES_FLAG="" +[ -n "${MAX_FILES:-}" ] && MAX_FILES_FLAG="--max_files $MAX_FILES" + +CACHE_FLAG="--cache_dir $CACHE_DIR" +[ "${NO_CACHE:-0}" = "1" ] && CACHE_FLAG="--no_cache" + +VIDEO_FLAG="" +[ -n "${USE_VIDEO}" ] && VIDEO_FLAG="--use_video $USE_VIDEO" + +# Stage selector. PREDICTION_HORIZON_S and CACHE_NAME_PREFIX must agree: +# the lengths cache contents depend on prediction_horizon_s, so we name +# the cache file per stage to avoid one stage overwriting another. 
+PREDICTION_HORIZON_S="${PREDICTION_HORIZON_S:-0.05}" +CACHE_NAME_PREFIX="${CACHE_NAME_PREFIX:-lengths_e2e_stage1}" + +echo "[build_dataset_cache] data_dir=$DATA_DIR cache=$CACHE_DIR \ +use_video=${USE_VIDEO:-none} max_files=${MAX_FILES:-all} \ +prediction_horizon_s=${PREDICTION_HORIZON_S} prefix=${CACHE_NAME_PREFIX}" + +python -u scripts/build_dataset_cache.py \ + --data_dir "$DATA_DIR" \ + --prediction_horizon_s "$PREDICTION_HORIZON_S" \ + --cache_name_prefix "$CACHE_NAME_PREFIX" \ + $CACHE_FLAG \ + $VIDEO_FLAG \ + $MAX_FILES_FLAG diff --git a/scripts/slurm_frontier/make_processing_stats.sh b/scripts/slurm_frontier/make_processing_stats.sh new file mode 100755 index 0000000..198440d --- /dev/null +++ b/scripts/slurm_frontier/make_processing_stats.sh @@ -0,0 +1,28 @@ +#!/bin/bash +#SBATCH -A fus187 +#SBATCH -J make_processing_stats +#SBATCH -o logs/%j_make_processing_stats.out +#SBATCH -e logs/%j_make_processing_stats.err +#SBATCH -p extended +#SBATCH -N 1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=16 +#SBATCH -t 24:00:00 +set -uo pipefail + +# SLURM stages the submit script under /var/spool/slurmd/... so BASH_SOURCE +# is useless for locating the repo. Use SLURM_SUBMIT_DIR — submit from the +# repo root: `cd && sbatch scripts/slurm_frontier/make_processing_stats.sh`. +PROJECT_DIR="${SLURM_SUBMIT_DIR:-$PWD}" +if [ ! -f "${PROJECT_DIR}/scripts/slurm_frontier/_frontier_settings.sh" ]; then + echo "ERROR: SLURM_SUBMIT_DIR (${PROJECT_DIR}) is not the repo root." >&2 + echo " cd into the FusionAIHub repo before sbatch." 
>&2 + exit 1 +fi +cd "${PROJECT_DIR}" +mkdir -p logs + +# shellcheck disable=SC1091 +source scripts/slurm_frontier/_frontier_settings.sh + +srun python -u scripts/data_preparation/make_processing_stats.py diff --git a/scripts/slurm_frontier/memory_probe_e2e.sh b/scripts/slurm_frontier/memory_probe_e2e.sh new file mode 100644 index 0000000..8d47a11 --- /dev/null +++ b/scripts/slurm_frontier/memory_probe_e2e.sh @@ -0,0 +1,54 @@ +#!/bin/bash +#SBATCH -A fus187 +#SBATCH -J mem_probe +#SBATCH -o logs/%j_mem_probe.out +#SBATCH -e logs/%j_mem_probe.err +#SBATCH -t 01:30:00 +#SBATCH -p batch +#SBATCH -q debug +#SBATCH -N 1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-task=1 +#SBATCH --gpu-bind=closest +#SBATCH --cpus-per-task=7 +set -uo pipefail + +PROJECT_DIR="${SLURM_SUBMIT_DIR:-$PWD}" +if [ ! -f "${PROJECT_DIR}/scripts/slurm_frontier/_frontier_settings.sh" ]; then + echo "ERROR: SLURM_SUBMIT_DIR (${PROJECT_DIR}) is not the repo root." >&2 + echo " cd into the FusionAIHub repo before sbatch." 
>&2 + exit 1 +fi +cd "${PROJECT_DIR}" +mkdir -p logs + +# shellcheck disable=SC1091 +source scripts/slurm_frontier/_frontier_settings.sh + +BATCH="${BATCH:-1}" + +run_probe() { + local label="$1"; local d_model="$2"; local n_layers="$3" + local n_heads="$4"; local k="$5"; shift 5 + echo "" + echo "================================================================" + echo "=== $label (d_model=$d_model n_layers=$n_layers n_heads=$n_heads K=$k batch=$BATCH) ===" + echo "================================================================" + srun -N 1 -n 1 -c "$SLURM_CPUS_PER_TASK" \ + --gpus-per-task=1 --gpu-bind=closest \ + scripts/slurm_frontier/_srun_rank_wrapper.sh \ + scripts/training/memory_probe_e2e.py \ + --d_model "$d_model" --n_layers "$n_layers" --n_heads "$n_heads" \ + --batch_size "$BATCH" --K_rollout "$k" \ + "$@" || echo "[$label] non-zero exit (likely OOM — see above)" +} + +COMMON_FLAGS=(--attn_impl sdpa --gradient_checkpoint) + +# Single-shot probe: does 2.68B fit at K=50? +# Prior at this exact shape: K=25 → 53.73 GB peak (optim.step-bound). +# K=50 doubles rollout activations; predicted borderline (60-65 GB peak). +run_probe "2.68B @ K=50 (d=2048 L=32)" 2048 32 32 50 "${COMMON_FLAGS[@]}" + +echo "" +echo "=== Done. ===" diff --git a/scripts/slurm_frontier/profile_indexing.sh b/scripts/slurm_frontier/profile_indexing.sh deleted file mode 100644 index 0622871..0000000 --- a/scripts/slurm_frontier/profile_indexing.sh +++ /dev/null @@ -1,59 +0,0 @@ -#!/bin/bash -# Frontier CPU-only launcher for scripts/profile_indexing.py. -# Times the file-length indexing pass that train_e2e jobs do in build_datasets, -# and reports files/sec throughput. Optionally pre-populates a lengths cache -# so future training jobs skip the indexing wall entirely. 
-# -# Usage: -# # Smoke (100 files, ~1 min): -# MAX_FILES=100 sbatch scripts/slurm_frontier/profile_indexing.sh -# -# # Full pass, persist cache for training jobs to reuse: -# sbatch scripts/slurm_frontier/profile_indexing.sh -# -# # Don't allocate a GPU node at all by calling python directly after `conda -# # activate $CONDA_ENV_PATH` from a login or compute node: -# python scripts/profile_indexing.py --max_files 100 -# -# Common env overrides: -# MAX_FILES= # cap on training files (default: unset = all) -# DATA_DIR= # override data root -# CACHE_DIR= # where to write the lengths cache (default: -# # runs/lengths_cache_e2e_stage1/, persists for -# # subsequent training jobs) -# NO_CACHE=1 # skip cache write (pure profile) -# -#SBATCH -A fus187 -#SBATCH -J e2e_idx_profile -#SBATCH -o logs/%j_idx_profile.out -#SBATCH -e logs/%j_idx_profile.err -#SBATCH -t 01:00:00 -#SBATCH -p batch -#SBATCH -N 1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-task=0 -#SBATCH --cpus-per-task=8 -set -uo pipefail - -PROJECT_DIR=/lustre/orion/fus187/scratch/nchen/FusionAIHub -cd "$PROJECT_DIR" -mkdir -p logs - -# shellcheck disable=SC1091 -source scripts/slurm_frontier/_frontier_common.sh - -DATA_DIR="${DATA_DIR:-/lustre/orion/fus187/proj-shared/foundation_model}" -CACHE_DIR="${CACHE_DIR:-runs/lengths_cache_e2e_stage1}" - -MAX_FILES_FLAG="" -[ -n "${MAX_FILES:-}" ] && MAX_FILES_FLAG="--max_files $MAX_FILES" - -CACHE_FLAG="--cache_dir $CACHE_DIR" -[ "${NO_CACHE:-0}" = "1" ] && CACHE_FLAG="--no_cache" - -echo "[idx_profile] data_dir=$DATA_DIR cache=$CACHE_DIR max_files=${MAX_FILES:-all}" - -python -u scripts/profile_indexing.py \ - --data_dir "$DATA_DIR" \ - $CACHE_FLAG \ - $MAX_FILES_FLAG diff --git a/scripts/slurm_frontier/profile_stage1_1x1.sh b/scripts/slurm_frontier/profile_stage1_1x1.sh new file mode 100644 index 0000000..b47d729 --- /dev/null +++ b/scripts/slurm_frontier/profile_stage1_1x1.sh @@ -0,0 +1,92 @@ +#!/bin/bash +#SBATCH -A fus187 +#SBATCH -J e2e_s1_prof +#SBATCH -o 
logs/%j_e2e_s1_prof.out +#SBATCH -e logs/%j_e2e_s1_prof.err +#SBATCH -t 00:30:00 +#SBATCH -p batch +#SBATCH -q debug +#SBATCH -N 1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-task=1 +#SBATCH --gpu-bind=closest +#SBATCH --cpus-per-task=7 +set -uo pipefail + +PROJECT_DIR="${SLURM_SUBMIT_DIR:-$PWD}" +if [ ! -f "${PROJECT_DIR}/scripts/slurm_frontier/_frontier_settings.sh" ]; then + echo "ERROR: SLURM_SUBMIT_DIR (${PROJECT_DIR}) is not the repo root." >&2 + echo " cd into the FusionAIHub repo before sbatch." >&2 + exit 1 +fi +cd "${PROJECT_DIR}" +mkdir -p logs + +# shellcheck disable=SC1091 +source scripts/slurm_frontier/_frontier_settings.sh + +DATA_DIR="${DATA_DIR:-/lustre/orion/fus187/proj-shared/foundation_model}" +STATS_PATH="${STATS_PATH:-/lustre/orion/fus187/proj-shared/foundation_model_meta/preprocessing_stats.pt}" +LENGTHS_CACHE_DIR="${LENGTHS_CACHE_DIR:-runs/profile_stage1_lengths_cache}" +mkdir -p "$LENGTHS_CACHE_DIR" +BATCH_SIZE="${BATCH_SIZE:-4}" +NUM_WORKERS="${NUM_WORKERS:-4}" +MAX_FILES="${MAX_FILES:-15}" +N_LAYERS="${N_LAYERS:-26}" +D_MODEL="${D_MODEL:-256}" +N_HEADS="${N_HEADS:-8}" +PROFILE_WAIT="${PROFILE_WAIT:-3}" +PROFILE_WARMUP="${PROFILE_WARMUP:-3}" +PROFILE_ACTIVE="${PROFILE_ACTIVE:-15}" + +PROF_ROOT="profile/${SLURM_JOB_ID}_stage1_1x1" +mkdir -p "$PROF_ROOT/without_flash" "$PROF_ROOT/with_flash" +echo "[profile/1x1] outputs -> $PROF_ROOT" +echo "[profile/1x1] n_layers=$N_LAYERS d_model=$D_MODEL n_heads=$N_HEADS \ +batch=$BATCH_SIZE active_steps=$PROFILE_ACTIVE max_files=$MAX_FILES" + +run_profile() { + local out_dir="$1" + local extra_flag="$2" + local label="$3" + echo "" + echo "=== [$label] starting profile run ===" + srun -N 1 -n 1 -c "$SLURM_CPUS_PER_TASK" \ + --gpus-per-task=1 --gpu-bind=closest \ + scripts/slurm_frontier/_srun_rank_wrapper.sh \ + scripts/training/profile_stage1.py \ + --data_dir "$DATA_DIR" \ + --stats_path "$STATS_PATH" \ + --lengths_cache_dir "$LENGTHS_CACHE_DIR" \ + --output_dir "$out_dir" \ + --batch_size 
"$BATCH_SIZE" \ + --num_workers "$NUM_WORKERS" \ + --max_files "$MAX_FILES" \ + --d_model "$D_MODEL" \ + --n_layers "$N_LAYERS" \ + --n_heads "$N_HEADS" \ + --profile_wait "$PROFILE_WAIT" \ + --profile_warmup "$PROFILE_WARMUP" \ + --profile_active "$PROFILE_ACTIVE" \ + --use_video tangtv \ + --use_spectro ece co2 bes \ + $extra_flag +} + +# Order matters: run WITHOUT first so MIOpen kernel cache is identical for +# both runs (flash-attn doesn't touch MIOpen, but other ops do). +run_profile "$PROF_ROOT/without_flash" "" "no-flash" +run_profile "$PROF_ROOT/with_flash" "--use_flash_attn" "flash" + +echo "" +echo "=== Comparison ===" +python scripts/slurm_frontier/_compare_profiles.py \ + "$PROF_ROOT/without_flash/memory.json" \ + "$PROF_ROOT/with_flash/memory.json" \ + | tee "$PROF_ROOT/comparison.txt" + +echo "" +echo "=== Done ===" +echo "Open traces in chrome://tracing or Perfetto:" +echo " $PROF_ROOT/without_flash/trace.json" +echo " $PROF_ROOT/with_flash/trace.json" diff --git a/scripts/slurm_frontier/setup_frontier_env.sh b/scripts/slurm_frontier/setup_frontier_env.sh new file mode 100755 index 0000000..14cc928 --- /dev/null +++ b/scripts/slurm_frontier/setup_frontier_env.sh @@ -0,0 +1,78 @@ +#!/bin/bash +# Build & install flash-attention 2 (Triton backend) for OLCF Frontier (MI250X / gfx90a). +# +# Run from the repo root on a Frontier LOGIN node: +# pixi run -e frontier setup-flash-attn +# +# Builds entirely on the login node — no SLURM allocation, no GPU. The Triton +# backend (FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE) replaces the multi-hour +# Composable Kernel template/hipcc compile with a quick pure-Python install +# (~2-5 min). Triton kernels are JIT-compiled at first use, so no GPU is +# needed at build time. +# +# A separate `verify-flash-attn` pixi task tests the install on a GPU; run it +# from inside any SLURM allocation that has --gpus. +# +# Prerequisite: `pixi install -e frontier` has been run once. 
+set -euo pipefail + +PROJECT_DIR="$(cd "$(dirname "$0")/../.." && pwd)" +FLASH_ATTN_SHA=5301a359f59ef8fa10f211618d9f7a69716a8898 +FLASH_ATTN_URL="https://github.com/ROCm/flash-attention.git" +FLASH_ATTN_LOCAL="${PROJECT_DIR}/.build/flash-attention" +ROCM_MODULE=rocm/7.1.1 + +cd "$PROJECT_DIR" + +echo "=== Ensuring local flash-attention checkout ===" +mkdir -p "$(dirname "${FLASH_ATTN_LOCAL}")" +if [ ! -d "${FLASH_ATTN_LOCAL}/.git" ]; then + echo " cloning ${FLASH_ATTN_URL} -> ${FLASH_ATTN_LOCAL}" + git clone --filter=blob:none "${FLASH_ATTN_URL}" "${FLASH_ATTN_LOCAL}" +fi +pushd "${FLASH_ATTN_LOCAL}" >/dev/null +HAVE_SHA="$(git rev-parse HEAD 2>/dev/null || echo none)" +if [ "${HAVE_SHA}" != "${FLASH_ATTN_SHA}" ]; then + echo " fetching + checking out ${FLASH_ATTN_SHA}" + git fetch origin "${FLASH_ATTN_SHA}" + git checkout -q "${FLASH_ATTN_SHA}" +fi +echo " initializing submodules" +git submodule update --init --recursive +popd >/dev/null + +# Locate the pixi env's python. We bypass `pixi run` / `pixi install` because +# both re-resolve the lock file on every invocation (slow on PyPI sockets, +# and pixi/uv hangs on autofs locks under contention). +PIXI_PY="${PROJECT_DIR}/.pixi/envs/frontier/bin/python" +if [ ! -x "$PIXI_PY" ]; then + echo "ERROR: frontier pixi env not provisioned at $PIXI_PY." >&2 + echo " Run \`pixi install -e frontier\` first." >&2 + exit 1 +fi + +# Module load on the login node. The Triton backend doesn't strictly require +# the ROCm module at build time (Triton compiles kernels JIT at first call, +# inside whatever ROCm environment the runtime uses), but we load it for +# consistency with the runtime environment. +# shellcheck disable=SC1091 +source /etc/profile.d/lmod.sh 2>/dev/null || true +module load PrgEnv-gnu "${ROCM_MODULE}" craype-accel-amd-gfx90a + +# Triton backend — no Composable Kernel, no hipcc template explosion. 
+export FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE +export PYTORCH_ROCM_ARCH=gfx90a + +echo "" +echo "=== Installing flash-attn 2 (Triton backend) on login node ===" +echo " source = ${FLASH_ATTN_LOCAL}" +echo " pinned SHA = ${FLASH_ATTN_SHA}" +echo " python = ${PIXI_PY}" +echo " FLASH_ATTENTION_TRITON_AMD_ENABLE=${FLASH_ATTENTION_TRITON_AMD_ENABLE}" +"$PIXI_PY" -m pip install --no-build-isolation -v "${FLASH_ATTN_LOCAL}" + +echo "" +echo "=== Login-node install complete ===" +echo "Test the install on a GPU from inside a SLURM allocation:" +echo " salloc -A fus187 -t 00:10:00 -N 1 --gpus=1" +echo " pixi run -e frontier verify-flash-attn" diff --git a/scripts/slurm_frontier/train_e2e_stage1.sh b/scripts/slurm_frontier/train_e2e_stage1.sh index 894fd31..3a448ea 100644 --- a/scripts/slurm_frontier/train_e2e_stage1.sh +++ b/scripts/slurm_frontier/train_e2e_stage1.sh @@ -3,28 +3,63 @@ #SBATCH -J e2e_stage1 #SBATCH -o logs/%j_e2e_stage1.out #SBATCH -e logs/%j_e2e_stage1.err -#SBATCH -t 02:00:00 -#SBATCH -p batch -#SBATCH -N 1 +#SBATCH -t 24:00:00 +#SBATCH -p extended +#SBATCH -N 8 #SBATCH --ntasks-per-node=8 +#SBATCH --gres=gpu:8 #SBATCH --gpus-per-task=1 #SBATCH --gpu-bind=closest #SBATCH --cpus-per-task=7 +#SBATCH --mem=0 set -e -cd /lustre/orion/fus187/scratch/nchen/FusionAIHub -mkdir -p logs runs/e2e_stage1 +# SLURM stages the submit script under /var/spool/slurmd/... so BASH_SOURCE +# is useless for locating the repo. Use SLURM_SUBMIT_DIR — submit from the +# repo root: `cd && sbatch scripts/slurm_frontier/train_e2e_stage1.sh`. +PROJECT_DIR="${SLURM_SUBMIT_DIR:-$PWD}" +if [ ! -f "${PROJECT_DIR}/scripts/slurm_frontier/_frontier_settings.sh" ]; then + echo "ERROR: SLURM_SUBMIT_DIR (${PROJECT_DIR}) is not the repo root." >&2 + echo " cd into the FusionAIHub repo before sbatch." 
>&2 + exit 1 +fi +cd "${PROJECT_DIR}" +CHECKPOINT_DIR="/lustre/orion/fus187/proj-shared/models/e2e_stage1" +mkdir -p logs "${CHECKPOINT_DIR}" export MASTER_PORT=29500 -source scripts/slurm_frontier/_frontier_common.sh +source scripts/slurm_frontier/_frontier_settings.sh + +# Auto-resume from previous chained submission. Pass --resume_checkpoint +# only when a `_latest.pt` is on disk; the Python script's flag guard +# would otherwise fall through to fresh init anyway, but being explicit +# makes the log line show whether we resumed or started cold. +RESUME_FLAG="" +LATEST_CKPT="${CHECKPOINT_DIR}/e2e_stage1_latest.pt" +if [ -f "${LATEST_CKPT}" ]; then + echo "[train_e2e_stage1] resuming from ${LATEST_CKPT}" + RESUME_FLAG="--resume_checkpoint ${LATEST_CKPT}" +else + echo "[train_e2e_stage1] no latest checkpoint at ${LATEST_CKPT}; starting fresh" +fi + +# Per-node sampler: one line per node per minute with mean GPU busy%, +# host RAM, and mean VRAM%. Launched as a side srun step with --overlap +# so it shares the allocation without stealing GPUs. Cost ~0.1% of one +# CPU/node. Killed when this script exits (walltime or normal end). +SAMPLER_LOG="logs/${SLURM_JOB_ID}_sampler.log" +srun --overlap -N "$SLURM_JOB_NUM_NODES" --ntasks-per-node=1 -c 1 \ + scripts/slurm_frontier/_node_sampler.sh > "$SAMPLER_LOG" 2>&1 & +SAMPLER_PID=$! 
+trap 'kill "$SAMPLER_PID" 2>/dev/null || true' EXIT srun -N $SLURM_JOB_NUM_NODES -n $SLURM_NTASKS -c $SLURM_CPUS_PER_TASK \ --gpus-per-task=1 --gpu-bind=closest \ scripts/slurm_frontier/_srun_rank_wrapper.sh \ scripts/training/train_e2e_stage1.py \ --data_dir /lustre/orion/fus187/proj-shared/foundation_model \ - --stats_path data/preprocessing_stats.pt \ - --checkpoint_dir runs/e2e_stage1 \ + --stats_path /lustre/orion/fus187/proj-shared/foundation_model_meta/preprocessing_stats.pt \ + --checkpoint_dir "${CHECKPOINT_DIR}" \ --val_fraction 0.1 \ --seed 42 \ --chunk_duration_s 0.05 \ @@ -32,17 +67,21 @@ srun -N $SLURM_JOB_NUM_NODES -n $SLURM_NTASKS -c $SLURM_CPUS_PER_TASK \ --step_size_s 0.01 \ --warmup_s 1.0 \ --d_model 256 \ - --n_layers 8 \ + --n_layers 26 \ --n_heads 8 \ --dropout 0.1 \ - --lr 1e-4 \ + --lr 5e-4 \ --min_lr 1e-6 \ - --warmup_steps 2000 \ + --warmup_steps 4000 \ --weight_decay 0.1 \ --grad_clip 5.0 \ - --batch_size 16 \ - --num_workers 4 \ - --max_steps 50000 \ + --batch_size 64 \ + --num_workers 6 \ + --max_steps 672000 \ --log_every 50 \ - --val_every 500 \ - --val_max_batches 20 + --val_every 1180 \ + --val_max_batches 100 \ + --use_video tangtv \ + --use_spectro ece co2 bes \ + --no_amp_val \ + ${RESUME_FLAG} diff --git a/scripts/slurm_frontier/train_e2e_stage1_1x1.sh b/scripts/slurm_frontier/train_e2e_stage1_1x1.sh deleted file mode 100644 index aa19f31..0000000 --- a/scripts/slurm_frontier/train_e2e_stage1_1x1.sh +++ /dev/null @@ -1,136 +0,0 @@ -#!/bin/bash -# Frontier DDP launcher: train_e2e Stage1 — 1 node × 1 GCD (single-GPU smoke / dev) -# -# Usage: -# sbatch scripts/slurm_frontier/train_e2e_stage1_1x1.sh -# -# Common env overrides: -# SMOKE=1 # short test: MAX_STEPS=20, MAX_FILES=4, freq logs -# MAX_STEPS= # total optimizer steps -# MAX_FILES= # cap on training shots (debug) -# BATCH_SIZE= # per-rank batch size (default 16) -# NUM_WORKERS= # DataLoader workers per rank (default 4) -# DATA_DIR= # override data root -# CHECKPOINT_DIR= # 
override checkpoint dir -# MASTER_PORT= # override port (default 29500) -# -# Override resource shape on the CLI (sbatch flags beat #SBATCH directives): -# sbatch -N 8 -t 12:00:00 scripts/slurm_frontier/train_e2e_stage1_1x1.sh -# -#SBATCH -A fus187 -#SBATCH -J e2e_s1_1x1 -#SBATCH -o logs/%j_e2e_s1_1x1.out -#SBATCH -e logs/%j_e2e_s1_1x1.err -#SBATCH -t 02:00:00 -#SBATCH -p batch -#SBATCH -N 1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-task=1 -#SBATCH --gpu-bind=closest -#SBATCH --cpus-per-task=7 -set -uo pipefail - -PROJECT_DIR=/lustre/orion/fus187/scratch/nchen/FusionAIHub -cd "$PROJECT_DIR" -mkdir -p logs - -# Per-stage MASTER_PORT default (overridable). Must be set BEFORE sourcing -# _frontier_common.sh, since that script only fills in if unset. -export MASTER_PORT="${MASTER_PORT:-29500}" - -# shellcheck disable=SC1091 -source scripts/slurm_frontier/_frontier_common.sh - -# ─── Resource shape (taken from SLURM allocation, never hard-coded) ────── -NODES="${SLURM_JOB_NUM_NODES:-1}" -TOTAL_RANKS="${SLURM_NTASKS:-$((NODES * 1))}" -CPUS_PER_TASK="${SLURM_CPUS_PER_TASK:-7}" - -# ─── SMOKE=1 overrides for end-to-end smoke testing ────────────────────── -if [ "${SMOKE:-0}" = "1" ]; then - MAX_STEPS="${MAX_STEPS:-20}" - MAX_FILES="${MAX_FILES:-4}" - NUM_WORKERS="${NUM_WORKERS:-2}" - LOG_EVERY="${LOG_EVERY:-2}" - VAL_EVERY="${VAL_EVERY:-10}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-2}" - SMOKE_BANNER="[SMOKE] " -else - MAX_STEPS="${MAX_STEPS:-1000}" - NUM_WORKERS="${NUM_WORKERS:-4}" - LOG_EVERY="${LOG_EVERY:-50}" - VAL_EVERY="${VAL_EVERY:-200}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-20}" - SMOKE_BANNER="" -fi - -MAX_FILES_FLAG="" -[ -n "${MAX_FILES:-}" ] && MAX_FILES_FLAG="--max_files $MAX_FILES" - -# ─── Stage-specific defaults & init/resume flags ───────────────────────── -BATCH_SIZE="${BATCH_SIZE:-16}" -D_MODEL="${D_MODEL:-256}" -N_LAYERS="${N_LAYERS:-8}" -N_HEADS="${N_HEADS:-8}" -DATA_DIR="${DATA_DIR:-/lustre/orion/fus187/proj-shared/foundation_model}" 
-STATS_PATH="${STATS_PATH:-data/preprocessing_stats.pt}" -CHECKPOINT_DIR="${CHECKPOINT_DIR:-runs/e2e_stage1_frontier}" -mkdir -p "$CHECKPOINT_DIR" - -# Auto-resume from latest checkpoint if it exists. -LATEST="$CHECKPOINT_DIR/e2e_stage1_latest.pt" -RESUME_FLAG="" -if [ -f "$LATEST" ]; then - RESUME_FLAG="--resume_checkpoint $LATEST" - echo "[stage1] auto-resume from $LATEST" -fi - -TRAIN_SHOTS_FLAG="" -[ -n "${TRAIN_SHOTS_YAML:-}" ] && TRAIN_SHOTS_FLAG="--train_shots_yaml $TRAIN_SHOTS_YAML" -echo "${SMOKE_BANNER}[stage1/1x1] nodes=$NODES total_ranks=$TOTAL_RANKS \ -batch=$BATCH_SIZE steps=$MAX_STEPS" -echo "${SMOKE_BANNER}[stage1/1x1] master=$MASTER_ADDR:$MASTER_PORT data=$DATA_DIR" - -# ─── Optional GPU+CPU profiling sidecar (PROFILE=1) ────────────────────── -PROF_PID="" -if [ "${PROFILE:-0}" = "1" ]; then - PROF_DIR="${PROF_DIR:-profile/${SLURM_JOB_ID}_$(basename "$0" .sh)}" - mkdir -p "$PROF_DIR" - echo "[profile] sampling rocm-smi + mpstat (1 Hz) -> $PROF_DIR" - srun --overlap --jobid="$SLURM_JOB_ID" \ - -N "$NODES" -n "$NODES" --ntasks-per-node=1 \ - --gpus-per-task=0 --cpus-per-task=2 \ - scripts/slurm_frontier/_profile_node.sh "$PROF_DIR" & - PROF_PID=$! 
-fi -trap '[ -n "${PROF_PID:-}" ] && kill "$PROF_PID" 2>/dev/null; true' EXIT - -srun --overlap -N "$NODES" -n "$TOTAL_RANKS" -c "$CPUS_PER_TASK" \ - --gpus-per-task=1 --gpu-bind=closest \ - scripts/slurm_frontier/_srun_rank_wrapper.sh \ - scripts/training/train_e2e_stage1.py \ - $RESUME_FLAG $MAX_FILES_FLAG $TRAIN_SHOTS_FLAG \ ---data_dir "$DATA_DIR" \ ---stats_path "$STATS_PATH" \ ---checkpoint_dir "$CHECKPOINT_DIR" \ ---val_fraction 0.1 \ ---seed 42 \ ---chunk_duration_s 0.05 \ ---prediction_horizon_s 0.05 \ ---step_size_s 0.01 \ ---warmup_s 1.0 \ ---d_model "$D_MODEL" \ ---n_layers "$N_LAYERS" \ ---n_heads "$N_HEADS" \ ---dropout 0.1 \ ---lr 1e-4 \ ---min_lr 1e-6 \ ---warmup_steps 2000 \ ---weight_decay 0.1 \ ---grad_clip 5.0 \ ---batch_size "$BATCH_SIZE" \ ---num_workers "$NUM_WORKERS" \ ---max_steps "$MAX_STEPS" \ ---log_every "$LOG_EVERY" \ ---val_every "$VAL_EVERY" \ ---val_max_batches "$VAL_MAX_BATCHES" \ No newline at end of file diff --git a/scripts/slurm_frontier/train_e2e_stage1_1x8.sh b/scripts/slurm_frontier/train_e2e_stage1_1x8.sh deleted file mode 100644 index a958e1b..0000000 --- a/scripts/slurm_frontier/train_e2e_stage1_1x8.sh +++ /dev/null @@ -1,122 +0,0 @@ -#!/bin/bash -# Frontier DDP launcher: train_e2e Stage1 — 1 node × 8 GCDs (production single-node DDP) -# -# Usage: -# sbatch scripts/slurm_frontier/train_e2e_stage1_1x8.sh -# -# Common env overrides: -# SMOKE=1 # short test: MAX_STEPS=20, MAX_FILES=4, freq logs -# MAX_STEPS= # total optimizer steps -# MAX_FILES= # cap on training shots (debug) -# BATCH_SIZE= # per-rank batch size (default 16) -# NUM_WORKERS= # DataLoader workers per rank (default 4) -# DATA_DIR= # override data root -# CHECKPOINT_DIR= # override checkpoint dir -# MASTER_PORT= # override port (default 29500) -# -# Override resource shape on the CLI (sbatch flags beat #SBATCH directives): -# sbatch -N 8 -t 12:00:00 scripts/slurm_frontier/train_e2e_stage1_1x8.sh -# -#SBATCH -A fus187 -#SBATCH -J e2e_s1_1x8 -#SBATCH -o 
logs/%j_e2e_s1_1x8.out -#SBATCH -e logs/%j_e2e_s1_1x8.err -#SBATCH -t 02:00:00 -#SBATCH -p batch -#SBATCH -N 1 -#SBATCH --ntasks-per-node=8 -#SBATCH --gpus-per-task=1 -#SBATCH --gpu-bind=closest -#SBATCH --cpus-per-task=7 -set -uo pipefail - -PROJECT_DIR=/lustre/orion/fus187/scratch/nchen/FusionAIHub -cd "$PROJECT_DIR" -mkdir -p logs - -# Per-stage MASTER_PORT default (overridable). Must be set BEFORE sourcing -# _frontier_common.sh, since that script only fills in if unset. -export MASTER_PORT="${MASTER_PORT:-29500}" - -# shellcheck disable=SC1091 -source scripts/slurm_frontier/_frontier_common.sh - -# ─── Resource shape (taken from SLURM allocation, never hard-coded) ────── -NODES="${SLURM_JOB_NUM_NODES:-1}" -TOTAL_RANKS="${SLURM_NTASKS:-$((NODES * 8))}" -CPUS_PER_TASK="${SLURM_CPUS_PER_TASK:-7}" - -# ─── SMOKE=1 overrides for end-to-end smoke testing ────────────────────── -if [ "${SMOKE:-0}" = "1" ]; then - MAX_STEPS="${MAX_STEPS:-20}" - MAX_FILES="${MAX_FILES:-4}" - NUM_WORKERS="${NUM_WORKERS:-2}" - LOG_EVERY="${LOG_EVERY:-2}" - VAL_EVERY="${VAL_EVERY:-10}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-2}" - SMOKE_BANNER="[SMOKE] " -else - MAX_STEPS="${MAX_STEPS:-1000}" - NUM_WORKERS="${NUM_WORKERS:-4}" - LOG_EVERY="${LOG_EVERY:-50}" - VAL_EVERY="${VAL_EVERY:-200}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-20}" - SMOKE_BANNER="" -fi - -MAX_FILES_FLAG="" -[ -n "${MAX_FILES:-}" ] && MAX_FILES_FLAG="--max_files $MAX_FILES" - -# ─── Stage-specific defaults & init/resume flags ───────────────────────── -BATCH_SIZE="${BATCH_SIZE:-16}" -D_MODEL="${D_MODEL:-256}" -N_LAYERS="${N_LAYERS:-8}" -N_HEADS="${N_HEADS:-8}" -DATA_DIR="${DATA_DIR:-/lustre/orion/fus187/proj-shared/foundation_model}" -STATS_PATH="${STATS_PATH:-data/preprocessing_stats.pt}" -CHECKPOINT_DIR="${CHECKPOINT_DIR:-runs/e2e_stage1_frontier}" -mkdir -p "$CHECKPOINT_DIR" - -# Auto-resume from latest checkpoint if it exists. 
-LATEST="$CHECKPOINT_DIR/e2e_stage1_latest.pt" -RESUME_FLAG="" -if [ -f "$LATEST" ]; then - RESUME_FLAG="--resume_checkpoint $LATEST" - echo "[stage1] auto-resume from $LATEST" -fi - -TRAIN_SHOTS_FLAG="" -[ -n "${TRAIN_SHOTS_YAML:-}" ] && TRAIN_SHOTS_FLAG="--train_shots_yaml $TRAIN_SHOTS_YAML" -echo "${SMOKE_BANNER}[stage1/1x8] nodes=$NODES total_ranks=$TOTAL_RANKS \ -batch=$BATCH_SIZE steps=$MAX_STEPS" -echo "${SMOKE_BANNER}[stage1/1x8] master=$MASTER_ADDR:$MASTER_PORT data=$DATA_DIR" - -srun -N "$NODES" -n "$TOTAL_RANKS" -c "$CPUS_PER_TASK" \ - --gpus-per-task=1 --gpu-bind=closest \ - scripts/slurm_frontier/_srun_rank_wrapper.sh \ - scripts/training/train_e2e_stage1.py \ - $RESUME_FLAG $MAX_FILES_FLAG $TRAIN_SHOTS_FLAG \ ---data_dir "$DATA_DIR" \ ---stats_path "$STATS_PATH" \ ---checkpoint_dir "$CHECKPOINT_DIR" \ ---val_fraction 0.1 \ ---seed 42 \ ---chunk_duration_s 0.05 \ ---prediction_horizon_s 0.05 \ ---step_size_s 0.01 \ ---warmup_s 1.0 \ ---d_model "$D_MODEL" \ ---n_layers "$N_LAYERS" \ ---n_heads "$N_HEADS" \ ---dropout 0.1 \ ---lr 1e-4 \ ---min_lr 1e-6 \ ---warmup_steps 2000 \ ---weight_decay 0.1 \ ---grad_clip 5.0 \ ---batch_size "$BATCH_SIZE" \ ---num_workers "$NUM_WORKERS" \ ---max_steps "$MAX_STEPS" \ ---log_every "$LOG_EVERY" \ ---val_every "$VAL_EVERY" \ ---val_max_batches "$VAL_MAX_BATCHES" \ No newline at end of file diff --git a/scripts/slurm_frontier/train_e2e_stage1_Nx1.sh b/scripts/slurm_frontier/train_e2e_stage1_Nx1.sh deleted file mode 100644 index c47dc61..0000000 --- a/scripts/slurm_frontier/train_e2e_stage1_Nx1.sh +++ /dev/null @@ -1,122 +0,0 @@ -#!/bin/bash -# Frontier DDP launcher: train_e2e Stage1 — N nodes × 1 GCD (cross-node networking smoke; default N=2) -# -# Usage: -# sbatch scripts/slurm_frontier/train_e2e_stage1_Nx1.sh -# -# Common env overrides: -# SMOKE=1 # short test: MAX_STEPS=20, MAX_FILES=4, freq logs -# MAX_STEPS= # total optimizer steps -# MAX_FILES= # cap on training shots (debug) -# BATCH_SIZE= # per-rank batch size 
(default 16) -# NUM_WORKERS= # DataLoader workers per rank (default 4) -# DATA_DIR= # override data root -# CHECKPOINT_DIR= # override checkpoint dir -# MASTER_PORT= # override port (default 29500) -# -# Override resource shape on the CLI (sbatch flags beat #SBATCH directives): -# sbatch -N 8 -t 12:00:00 scripts/slurm_frontier/train_e2e_stage1_Nx1.sh -# -#SBATCH -A fus187 -#SBATCH -J e2e_s1_Nx1 -#SBATCH -o logs/%j_e2e_s1_Nx1.out -#SBATCH -e logs/%j_e2e_s1_Nx1.err -#SBATCH -t 01:00:00 -#SBATCH -p batch -#SBATCH -N 2 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-task=1 -#SBATCH --gpu-bind=closest -#SBATCH --cpus-per-task=7 -set -uo pipefail - -PROJECT_DIR=/lustre/orion/fus187/scratch/nchen/FusionAIHub -cd "$PROJECT_DIR" -mkdir -p logs - -# Per-stage MASTER_PORT default (overridable). Must be set BEFORE sourcing -# _frontier_common.sh, since that script only fills in if unset. -export MASTER_PORT="${MASTER_PORT:-29500}" - -# shellcheck disable=SC1091 -source scripts/slurm_frontier/_frontier_common.sh - -# ─── Resource shape (taken from SLURM allocation, never hard-coded) ────── -NODES="${SLURM_JOB_NUM_NODES:-2}" -TOTAL_RANKS="${SLURM_NTASKS:-$((NODES * 1))}" -CPUS_PER_TASK="${SLURM_CPUS_PER_TASK:-7}" - -# ─── SMOKE=1 overrides for end-to-end smoke testing ────────────────────── -if [ "${SMOKE:-0}" = "1" ]; then - MAX_STEPS="${MAX_STEPS:-20}" - MAX_FILES="${MAX_FILES:-4}" - NUM_WORKERS="${NUM_WORKERS:-2}" - LOG_EVERY="${LOG_EVERY:-2}" - VAL_EVERY="${VAL_EVERY:-10}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-2}" - SMOKE_BANNER="[SMOKE] " -else - MAX_STEPS="${MAX_STEPS:-1000}" - NUM_WORKERS="${NUM_WORKERS:-4}" - LOG_EVERY="${LOG_EVERY:-50}" - VAL_EVERY="${VAL_EVERY:-200}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-20}" - SMOKE_BANNER="" -fi - -MAX_FILES_FLAG="" -[ -n "${MAX_FILES:-}" ] && MAX_FILES_FLAG="--max_files $MAX_FILES" - -# ─── Stage-specific defaults & init/resume flags ───────────────────────── -BATCH_SIZE="${BATCH_SIZE:-16}" -D_MODEL="${D_MODEL:-256}" 
-N_LAYERS="${N_LAYERS:-8}" -N_HEADS="${N_HEADS:-8}" -DATA_DIR="${DATA_DIR:-/lustre/orion/fus187/proj-shared/foundation_model}" -STATS_PATH="${STATS_PATH:-data/preprocessing_stats.pt}" -CHECKPOINT_DIR="${CHECKPOINT_DIR:-runs/e2e_stage1_frontier}" -mkdir -p "$CHECKPOINT_DIR" - -# Auto-resume from latest checkpoint if it exists. -LATEST="$CHECKPOINT_DIR/e2e_stage1_latest.pt" -RESUME_FLAG="" -if [ -f "$LATEST" ]; then - RESUME_FLAG="--resume_checkpoint $LATEST" - echo "[stage1] auto-resume from $LATEST" -fi - -TRAIN_SHOTS_FLAG="" -[ -n "${TRAIN_SHOTS_YAML:-}" ] && TRAIN_SHOTS_FLAG="--train_shots_yaml $TRAIN_SHOTS_YAML" -echo "${SMOKE_BANNER}[stage1/Nx1] nodes=$NODES total_ranks=$TOTAL_RANKS \ -batch=$BATCH_SIZE steps=$MAX_STEPS" -echo "${SMOKE_BANNER}[stage1/Nx1] master=$MASTER_ADDR:$MASTER_PORT data=$DATA_DIR" - -srun -N "$NODES" -n "$TOTAL_RANKS" -c "$CPUS_PER_TASK" \ - --gpus-per-task=1 --gpu-bind=closest \ - scripts/slurm_frontier/_srun_rank_wrapper.sh \ - scripts/training/train_e2e_stage1.py \ - $RESUME_FLAG $MAX_FILES_FLAG $TRAIN_SHOTS_FLAG \ ---data_dir "$DATA_DIR" \ ---stats_path "$STATS_PATH" \ ---checkpoint_dir "$CHECKPOINT_DIR" \ ---val_fraction 0.1 \ ---seed 42 \ ---chunk_duration_s 0.05 \ ---prediction_horizon_s 0.05 \ ---step_size_s 0.01 \ ---warmup_s 1.0 \ ---d_model "$D_MODEL" \ ---n_layers "$N_LAYERS" \ ---n_heads "$N_HEADS" \ ---dropout 0.1 \ ---lr 1e-4 \ ---min_lr 1e-6 \ ---warmup_steps 2000 \ ---weight_decay 0.1 \ ---grad_clip 5.0 \ ---batch_size "$BATCH_SIZE" \ ---num_workers "$NUM_WORKERS" \ ---max_steps "$MAX_STEPS" \ ---log_every "$LOG_EVERY" \ ---val_every "$VAL_EVERY" \ ---val_max_batches "$VAL_MAX_BATCHES" \ No newline at end of file diff --git a/scripts/slurm_frontier/train_e2e_stage1_NxN.sh b/scripts/slurm_frontier/train_e2e_stage1_NxN.sh deleted file mode 100644 index b47aa94..0000000 --- a/scripts/slurm_frontier/train_e2e_stage1_NxN.sh +++ /dev/null @@ -1,122 +0,0 @@ -#!/bin/bash -# Frontier DDP launcher: train_e2e Stage1 — N nodes × 8 
GCDs (production multi-node; default N=4, override with `sbatch -N `) -# -# Usage: -# sbatch scripts/slurm_frontier/train_e2e_stage1_NxN.sh -# -# Common env overrides: -# SMOKE=1 # short test: MAX_STEPS=20, MAX_FILES=4, freq logs -# MAX_STEPS= # total optimizer steps -# MAX_FILES= # cap on training shots (debug) -# BATCH_SIZE= # per-rank batch size (default 16) -# NUM_WORKERS= # DataLoader workers per rank (default 4) -# DATA_DIR= # override data root -# CHECKPOINT_DIR= # override checkpoint dir -# MASTER_PORT= # override port (default 29500) -# -# Override resource shape on the CLI (sbatch flags beat #SBATCH directives): -# sbatch -N 8 -t 12:00:00 scripts/slurm_frontier/train_e2e_stage1_NxN.sh -# -#SBATCH -A fus187 -#SBATCH -J e2e_s1_NxN -#SBATCH -o logs/%j_e2e_s1_NxN.out -#SBATCH -e logs/%j_e2e_s1_NxN.err -#SBATCH -t 02:00:00 -#SBATCH -p batch -#SBATCH -N 4 -#SBATCH --ntasks-per-node=8 -#SBATCH --gpus-per-task=1 -#SBATCH --gpu-bind=closest -#SBATCH --cpus-per-task=7 -set -uo pipefail - -PROJECT_DIR=/lustre/orion/fus187/scratch/nchen/FusionAIHub -cd "$PROJECT_DIR" -mkdir -p logs - -# Per-stage MASTER_PORT default (overridable). Must be set BEFORE sourcing -# _frontier_common.sh, since that script only fills in if unset. 
-export MASTER_PORT="${MASTER_PORT:-29500}" - -# shellcheck disable=SC1091 -source scripts/slurm_frontier/_frontier_common.sh - -# ─── Resource shape (taken from SLURM allocation, never hard-coded) ────── -NODES="${SLURM_JOB_NUM_NODES:-4}" -TOTAL_RANKS="${SLURM_NTASKS:-$((NODES * 8))}" -CPUS_PER_TASK="${SLURM_CPUS_PER_TASK:-7}" - -# ─── SMOKE=1 overrides for end-to-end smoke testing ────────────────────── -if [ "${SMOKE:-0}" = "1" ]; then - MAX_STEPS="${MAX_STEPS:-20}" - MAX_FILES="${MAX_FILES:-4}" - NUM_WORKERS="${NUM_WORKERS:-2}" - LOG_EVERY="${LOG_EVERY:-2}" - VAL_EVERY="${VAL_EVERY:-10}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-2}" - SMOKE_BANNER="[SMOKE] " -else - MAX_STEPS="${MAX_STEPS:-1000}" - NUM_WORKERS="${NUM_WORKERS:-4}" - LOG_EVERY="${LOG_EVERY:-50}" - VAL_EVERY="${VAL_EVERY:-200}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-20}" - SMOKE_BANNER="" -fi - -MAX_FILES_FLAG="" -[ -n "${MAX_FILES:-}" ] && MAX_FILES_FLAG="--max_files $MAX_FILES" - -# ─── Stage-specific defaults & init/resume flags ───────────────────────── -BATCH_SIZE="${BATCH_SIZE:-16}" -D_MODEL="${D_MODEL:-256}" -N_LAYERS="${N_LAYERS:-8}" -N_HEADS="${N_HEADS:-8}" -DATA_DIR="${DATA_DIR:-/lustre/orion/fus187/proj-shared/foundation_model}" -STATS_PATH="${STATS_PATH:-data/preprocessing_stats.pt}" -CHECKPOINT_DIR="${CHECKPOINT_DIR:-runs/e2e_stage1_frontier}" -mkdir -p "$CHECKPOINT_DIR" - -# Auto-resume from latest checkpoint if it exists. 
-LATEST="$CHECKPOINT_DIR/e2e_stage1_latest.pt" -RESUME_FLAG="" -if [ -f "$LATEST" ]; then - RESUME_FLAG="--resume_checkpoint $LATEST" - echo "[stage1] auto-resume from $LATEST" -fi - -TRAIN_SHOTS_FLAG="" -[ -n "${TRAIN_SHOTS_YAML:-}" ] && TRAIN_SHOTS_FLAG="--train_shots_yaml $TRAIN_SHOTS_YAML" -echo "${SMOKE_BANNER}[stage1/NxN] nodes=$NODES total_ranks=$TOTAL_RANKS \ -batch=$BATCH_SIZE steps=$MAX_STEPS" -echo "${SMOKE_BANNER}[stage1/NxN] master=$MASTER_ADDR:$MASTER_PORT data=$DATA_DIR" - -srun -N "$NODES" -n "$TOTAL_RANKS" -c "$CPUS_PER_TASK" \ - --gpus-per-task=1 --gpu-bind=closest \ - scripts/slurm_frontier/_srun_rank_wrapper.sh \ - scripts/training/train_e2e_stage1.py \ - $RESUME_FLAG $MAX_FILES_FLAG $TRAIN_SHOTS_FLAG \ ---data_dir "$DATA_DIR" \ ---stats_path "$STATS_PATH" \ ---checkpoint_dir "$CHECKPOINT_DIR" \ ---val_fraction 0.1 \ ---seed 42 \ ---chunk_duration_s 0.05 \ ---prediction_horizon_s 0.05 \ ---step_size_s 0.01 \ ---warmup_s 1.0 \ ---d_model "$D_MODEL" \ ---n_layers "$N_LAYERS" \ ---n_heads "$N_HEADS" \ ---dropout 0.1 \ ---lr 1e-4 \ ---min_lr 1e-6 \ ---warmup_steps 2000 \ ---weight_decay 0.1 \ ---grad_clip 5.0 \ ---batch_size "$BATCH_SIZE" \ ---num_workers "$NUM_WORKERS" \ ---max_steps "$MAX_STEPS" \ ---log_every "$LOG_EVERY" \ ---val_every "$VAL_EVERY" \ ---val_max_batches "$VAL_MAX_BATCHES" \ No newline at end of file diff --git a/scripts/slurm_frontier/train_e2e_stage1_flashattn.sh b/scripts/slurm_frontier/train_e2e_stage1_flashattn.sh new file mode 100755 index 0000000..6e76d47 --- /dev/null +++ b/scripts/slurm_frontier/train_e2e_stage1_flashattn.sh @@ -0,0 +1,90 @@ +#!/bin/bash +# Production stage-1 run with flash-attention 2 enabled. +# Mirrors scripts/slurm_frontier/train_e2e_stage1.sh; adds --use_flash_attn +# and uses a distinct CHECKPOINT_DIR so the flash and non-flash runs don't +# clobber each other. 
+# +# Usage: +# cd <repo-root> +# sbatch scripts/slurm_frontier/train_e2e_stage1_flashattn.sh +# +# Prerequisite: flash_attn package must be built (one-time): +# pixi run -e frontier setup-flash-attn +# +#SBATCH -A fus187 +#SBATCH -J e2e_stage1_flashattn +#SBATCH -o logs/%j_e2e_stage1_flashattn.out +#SBATCH -e logs/%j_e2e_stage1_flashattn.err +#SBATCH -t 24:00:00 +#SBATCH -p extended +#SBATCH -N 8 +#SBATCH --ntasks-per-node=8 +#SBATCH --gres=gpu:8 +#SBATCH --gpus-per-task=1 +#SBATCH --gpu-bind=closest +#SBATCH --cpus-per-task=7 +#SBATCH --mem=0 +set -e + +# SLURM stages the submit script under /var/spool/slurmd/... so BASH_SOURCE +# is useless for locating the repo. Use SLURM_SUBMIT_DIR — submit from the +# repo root: `cd <repo-root> && sbatch scripts/slurm_frontier/train_e2e_stage1_flashattn.sh`. +PROJECT_DIR="${SLURM_SUBMIT_DIR:-$PWD}" +if [ ! -f "${PROJECT_DIR}/scripts/slurm_frontier/_frontier_settings.sh" ]; then + echo "ERROR: SLURM_SUBMIT_DIR (${PROJECT_DIR}) is not the repo root." >&2 + echo " cd into the FusionAIHub repo before sbatch." >&2 + exit 1 +fi +cd "${PROJECT_DIR}" +CHECKPOINT_DIR="/lustre/orion/fus187/proj-shared/models/e2e_stage1_flashattn" +mkdir -p logs "${CHECKPOINT_DIR}" + +export MASTER_PORT=29500 +source scripts/slurm_frontier/_frontier_settings.sh + +# Auto-resume from previous chained submission. Pass --resume_checkpoint +# only when a `_latest.pt` is on disk; the Python script's flag guard +# would otherwise fall through to fresh init anyway, but being explicit +# makes the log line show whether we resumed or started cold. 
+RESUME_FLAG="" +LATEST_CKPT="${CHECKPOINT_DIR}/e2e_stage1_latest.pt" +if [ -f "${LATEST_CKPT}" ]; then + echo "[train_e2e_stage1_flashattn] resuming from ${LATEST_CKPT}" + RESUME_FLAG="--resume_checkpoint ${LATEST_CKPT}" +else + echo "[train_e2e_stage1_flashattn] no latest checkpoint at ${LATEST_CKPT}; starting fresh" +fi + +srun -N $SLURM_JOB_NUM_NODES -n $SLURM_NTASKS -c $SLURM_CPUS_PER_TASK \ + --gpus-per-task=1 --gpu-bind=closest \ + scripts/slurm_frontier/_srun_rank_wrapper.sh \ + scripts/training/train_e2e_stage1.py \ + --data_dir /lustre/orion/fus187/proj-shared/foundation_model \ + --stats_path /lustre/orion/fus187/proj-shared/foundation_model_meta/preprocessing_stats.pt \ + --checkpoint_dir "${CHECKPOINT_DIR}" \ + --val_fraction 0.1 \ + --seed 42 \ + --chunk_duration_s 0.05 \ + --prediction_horizon_s 0.05 \ + --step_size_s 0.01 \ + --warmup_s 1.0 \ + --d_model 256 \ + --n_layers 26 \ + --n_heads 8 \ + --dropout 0.1 \ + --lr 5e-4 \ + --min_lr 1e-6 \ + --warmup_steps 4000 \ + --weight_decay 0.1 \ + --grad_clip 5.0 \ + --batch_size 64 \ + --num_workers 6 \ + --max_steps 672000 \ + --log_every 50 \ + --val_every 1180 \ + --val_max_batches 100 \ + --use_video tangtv \ + --use_spectro ece co2 bes \ + --no_amp_val \ + --use_flash_attn \ + ${RESUME_FLAG} diff --git a/scripts/slurm_frontier/train_e2e_stage2.sh b/scripts/slurm_frontier/train_e2e_stage2.sh index 228f6fc..d3bb7d1 100644 --- a/scripts/slurm_frontier/train_e2e_stage2.sh +++ b/scripts/slurm_frontier/train_e2e_stage2.sh @@ -12,11 +12,17 @@ #SBATCH --cpus-per-task=7 set -e -cd /lustre/orion/fus187/scratch/nchen/FusionAIHub +PROJECT_DIR="${SLURM_SUBMIT_DIR:-$PWD}" +if [ ! -f "${PROJECT_DIR}/scripts/slurm_frontier/_frontier_settings.sh" ]; then + echo "ERROR: SLURM_SUBMIT_DIR (${PROJECT_DIR}) is not the repo root." >&2 + echo " cd into the FusionAIHub repo before sbatch." 
>&2 + exit 1 +fi +cd "${PROJECT_DIR}" mkdir -p logs runs/e2e_stage2 export MASTER_PORT=29501 -source scripts/slurm_frontier/_frontier_common.sh +source scripts/slurm_frontier/_frontier_settings.sh srun -N $SLURM_JOB_NUM_NODES -n $SLURM_NTASKS -c $SLURM_CPUS_PER_TASK \ --gpus-per-task=1 --gpu-bind=closest \ diff --git a/scripts/slurm_frontier/train_e2e_stage2_1x1.sh b/scripts/slurm_frontier/train_e2e_stage2_1x1.sh deleted file mode 100644 index 9e18f6c..0000000 --- a/scripts/slurm_frontier/train_e2e_stage2_1x1.sh +++ /dev/null @@ -1,126 +0,0 @@ -#!/bin/bash -# Frontier DDP launcher: train_e2e Stage2 — 1 node × 1 GCD (single-GPU smoke / dev) -# -# Usage: -# sbatch scripts/slurm_frontier/train_e2e_stage2_1x1.sh -# -# Common env overrides: -# SMOKE=1 # short test: MAX_STEPS=20, MAX_FILES=4, freq logs -# MAX_STEPS= # total optimizer steps -# MAX_FILES= # cap on training shots (debug) -# BATCH_SIZE= # per-rank batch size (default 8) -# NUM_WORKERS= # DataLoader workers per rank (default 4) -# DATA_DIR= # override data root -# CHECKPOINT_DIR= # override checkpoint dir -# MASTER_PORT= # override port (default 29501) -# -# Override resource shape on the CLI (sbatch flags beat #SBATCH directives): -# sbatch -N 8 -t 12:00:00 scripts/slurm_frontier/train_e2e_stage2_1x1.sh -# -#SBATCH -A fus187 -#SBATCH -J e2e_s2_1x1 -#SBATCH -o logs/%j_e2e_s2_1x1.out -#SBATCH -e logs/%j_e2e_s2_1x1.err -#SBATCH -t 02:00:00 -#SBATCH -p batch -#SBATCH -N 1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-task=1 -#SBATCH --gpu-bind=closest -#SBATCH --cpus-per-task=7 -set -uo pipefail - -PROJECT_DIR=/lustre/orion/fus187/scratch/nchen/FusionAIHub -cd "$PROJECT_DIR" -mkdir -p logs - -# Per-stage MASTER_PORT default (overridable). Must be set BEFORE sourcing -# _frontier_common.sh, since that script only fills in if unset. 
-export MASTER_PORT="${MASTER_PORT:-29501}" - -# shellcheck disable=SC1091 -source scripts/slurm_frontier/_frontier_common.sh - -# ─── Resource shape (taken from SLURM allocation, never hard-coded) ────── -NODES="${SLURM_JOB_NUM_NODES:-1}" -TOTAL_RANKS="${SLURM_NTASKS:-$((NODES * 1))}" -CPUS_PER_TASK="${SLURM_CPUS_PER_TASK:-7}" - -# ─── SMOKE=1 overrides for end-to-end smoke testing ────────────────────── -if [ "${SMOKE:-0}" = "1" ]; then - MAX_STEPS="${MAX_STEPS:-20}" - MAX_FILES="${MAX_FILES:-4}" - NUM_WORKERS="${NUM_WORKERS:-2}" - LOG_EVERY="${LOG_EVERY:-2}" - VAL_EVERY="${VAL_EVERY:-10}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-2}" - SMOKE_BANNER="[SMOKE] " -else - MAX_STEPS="${MAX_STEPS:-1000}" - NUM_WORKERS="${NUM_WORKERS:-4}" - LOG_EVERY="${LOG_EVERY:-50}" - VAL_EVERY="${VAL_EVERY:-200}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-20}" - SMOKE_BANNER="" -fi - -MAX_FILES_FLAG="" -[ -n "${MAX_FILES:-}" ] && MAX_FILES_FLAG="--max_files $MAX_FILES" - -# ─── Stage-specific defaults & init/resume flags ───────────────────────── -BATCH_SIZE="${BATCH_SIZE:-8}" -K_MAX="${K_MAX:-10}" -CURRICULUM_STEPS="${CURRICULUM_STEPS:-$((MAX_STEPS / 2))}" -D_MODEL="${D_MODEL:-256}" -N_LAYERS="${N_LAYERS:-8}" -N_HEADS="${N_HEADS:-8}" -DATA_DIR="${DATA_DIR:-/lustre/orion/fus187/proj-shared/foundation_model}" -STATS_PATH="${STATS_PATH:-data/preprocessing_stats.pt}" -CHECKPOINT_DIR="${CHECKPOINT_DIR:-runs/e2e_stage2_frontier}" -INIT_CHECKPOINT="${INIT_CHECKPOINT:-runs/e2e_stage1_frontier/e2e_stage1_best.pt}" -mkdir -p "$CHECKPOINT_DIR" - -INIT_FLAG="" -if [ -f "$INIT_CHECKPOINT" ]; then - INIT_FLAG="--init_checkpoint $INIT_CHECKPOINT" - echo "[stage2] init from $INIT_CHECKPOINT" -else - echo "[stage2] WARNING: $INIT_CHECKPOINT not found — random init" -fi - -NO_AMP_FLAG="" -[ "${NO_AMP:-0}" = "1" ] && NO_AMP_FLAG="--no_amp" -echo "${SMOKE_BANNER}[stage2/1x1] nodes=$NODES total_ranks=$TOTAL_RANKS \ -batch=$BATCH_SIZE steps=$MAX_STEPS K_max=$K_MAX" -echo "${SMOKE_BANNER}[stage2/1x1] 
master=$MASTER_ADDR:$MASTER_PORT data=$DATA_DIR" - -srun -N "$NODES" -n "$TOTAL_RANKS" -c "$CPUS_PER_TASK" \ - --gpus-per-task=1 --gpu-bind=closest \ - scripts/slurm_frontier/_srun_rank_wrapper.sh \ - scripts/training/train_e2e_stage2.py \ - $INIT_FLAG $MAX_FILES_FLAG $NO_AMP_FLAG \ ---data_dir "$DATA_DIR" \ ---stats_path "$STATS_PATH" \ ---checkpoint_dir "$CHECKPOINT_DIR" \ ---val_fraction 0.1 \ ---seed 42 \ ---chunk_duration_s 0.05 \ ---step_size_s 0.01 \ ---warmup_s 1.0 \ ---d_model "$D_MODEL" \ ---n_layers "$N_LAYERS" \ ---n_heads "$N_HEADS" \ ---dropout 0.1 \ ---K_max "$K_MAX" \ ---curriculum_steps "$CURRICULUM_STEPS" \ ---lr 3e-5 \ ---min_lr 1e-6 \ ---warmup_steps 200 \ ---weight_decay 0.1 \ ---grad_clip 5.0 \ ---batch_size "$BATCH_SIZE" \ ---num_workers "$NUM_WORKERS" \ ---max_steps "$MAX_STEPS" \ ---log_every "$LOG_EVERY" \ ---val_every "$VAL_EVERY" \ ---val_max_batches "$VAL_MAX_BATCHES" \ No newline at end of file diff --git a/scripts/slurm_frontier/train_e2e_stage2_1x8.sh b/scripts/slurm_frontier/train_e2e_stage2_1x8.sh deleted file mode 100644 index 1fead01..0000000 --- a/scripts/slurm_frontier/train_e2e_stage2_1x8.sh +++ /dev/null @@ -1,126 +0,0 @@ -#!/bin/bash -# Frontier DDP launcher: train_e2e Stage2 — 1 node × 8 GCDs (production single-node DDP) -# -# Usage: -# sbatch scripts/slurm_frontier/train_e2e_stage2_1x8.sh -# -# Common env overrides: -# SMOKE=1 # short test: MAX_STEPS=20, MAX_FILES=4, freq logs -# MAX_STEPS= # total optimizer steps -# MAX_FILES= # cap on training shots (debug) -# BATCH_SIZE= # per-rank batch size (default 8) -# NUM_WORKERS= # DataLoader workers per rank (default 4) -# DATA_DIR= # override data root -# CHECKPOINT_DIR= # override checkpoint dir -# MASTER_PORT= # override port (default 29501) -# -# Override resource shape on the CLI (sbatch flags beat #SBATCH directives): -# sbatch -N 8 -t 12:00:00 scripts/slurm_frontier/train_e2e_stage2_1x8.sh -# -#SBATCH -A fus187 -#SBATCH -J e2e_s2_1x8 -#SBATCH -o logs/%j_e2e_s2_1x8.out 
-#SBATCH -e logs/%j_e2e_s2_1x8.err -#SBATCH -t 02:00:00 -#SBATCH -p batch -#SBATCH -N 1 -#SBATCH --ntasks-per-node=8 -#SBATCH --gpus-per-task=1 -#SBATCH --gpu-bind=closest -#SBATCH --cpus-per-task=7 -set -uo pipefail - -PROJECT_DIR=/lustre/orion/fus187/scratch/nchen/FusionAIHub -cd "$PROJECT_DIR" -mkdir -p logs - -# Per-stage MASTER_PORT default (overridable). Must be set BEFORE sourcing -# _frontier_common.sh, since that script only fills in if unset. -export MASTER_PORT="${MASTER_PORT:-29501}" - -# shellcheck disable=SC1091 -source scripts/slurm_frontier/_frontier_common.sh - -# ─── Resource shape (taken from SLURM allocation, never hard-coded) ────── -NODES="${SLURM_JOB_NUM_NODES:-1}" -TOTAL_RANKS="${SLURM_NTASKS:-$((NODES * 8))}" -CPUS_PER_TASK="${SLURM_CPUS_PER_TASK:-7}" - -# ─── SMOKE=1 overrides for end-to-end smoke testing ────────────────────── -if [ "${SMOKE:-0}" = "1" ]; then - MAX_STEPS="${MAX_STEPS:-20}" - MAX_FILES="${MAX_FILES:-4}" - NUM_WORKERS="${NUM_WORKERS:-2}" - LOG_EVERY="${LOG_EVERY:-2}" - VAL_EVERY="${VAL_EVERY:-10}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-2}" - SMOKE_BANNER="[SMOKE] " -else - MAX_STEPS="${MAX_STEPS:-1000}" - NUM_WORKERS="${NUM_WORKERS:-4}" - LOG_EVERY="${LOG_EVERY:-50}" - VAL_EVERY="${VAL_EVERY:-200}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-20}" - SMOKE_BANNER="" -fi - -MAX_FILES_FLAG="" -[ -n "${MAX_FILES:-}" ] && MAX_FILES_FLAG="--max_files $MAX_FILES" - -# ─── Stage-specific defaults & init/resume flags ───────────────────────── -BATCH_SIZE="${BATCH_SIZE:-8}" -K_MAX="${K_MAX:-10}" -CURRICULUM_STEPS="${CURRICULUM_STEPS:-$((MAX_STEPS / 2))}" -D_MODEL="${D_MODEL:-256}" -N_LAYERS="${N_LAYERS:-8}" -N_HEADS="${N_HEADS:-8}" -DATA_DIR="${DATA_DIR:-/lustre/orion/fus187/proj-shared/foundation_model}" -STATS_PATH="${STATS_PATH:-data/preprocessing_stats.pt}" -CHECKPOINT_DIR="${CHECKPOINT_DIR:-runs/e2e_stage2_frontier}" -INIT_CHECKPOINT="${INIT_CHECKPOINT:-runs/e2e_stage1_frontier/e2e_stage1_best.pt}" -mkdir -p "$CHECKPOINT_DIR" - 
-INIT_FLAG="" -if [ -f "$INIT_CHECKPOINT" ]; then - INIT_FLAG="--init_checkpoint $INIT_CHECKPOINT" - echo "[stage2] init from $INIT_CHECKPOINT" -else - echo "[stage2] WARNING: $INIT_CHECKPOINT not found — random init" -fi - -NO_AMP_FLAG="" -[ "${NO_AMP:-0}" = "1" ] && NO_AMP_FLAG="--no_amp" -echo "${SMOKE_BANNER}[stage2/1x8] nodes=$NODES total_ranks=$TOTAL_RANKS \ -batch=$BATCH_SIZE steps=$MAX_STEPS K_max=$K_MAX" -echo "${SMOKE_BANNER}[stage2/1x8] master=$MASTER_ADDR:$MASTER_PORT data=$DATA_DIR" - -srun -N "$NODES" -n "$TOTAL_RANKS" -c "$CPUS_PER_TASK" \ - --gpus-per-task=1 --gpu-bind=closest \ - scripts/slurm_frontier/_srun_rank_wrapper.sh \ - scripts/training/train_e2e_stage2.py \ - $INIT_FLAG $MAX_FILES_FLAG $NO_AMP_FLAG \ ---data_dir "$DATA_DIR" \ ---stats_path "$STATS_PATH" \ ---checkpoint_dir "$CHECKPOINT_DIR" \ ---val_fraction 0.1 \ ---seed 42 \ ---chunk_duration_s 0.05 \ ---step_size_s 0.01 \ ---warmup_s 1.0 \ ---d_model "$D_MODEL" \ ---n_layers "$N_LAYERS" \ ---n_heads "$N_HEADS" \ ---dropout 0.1 \ ---K_max "$K_MAX" \ ---curriculum_steps "$CURRICULUM_STEPS" \ ---lr 3e-5 \ ---min_lr 1e-6 \ ---warmup_steps 200 \ ---weight_decay 0.1 \ ---grad_clip 5.0 \ ---batch_size "$BATCH_SIZE" \ ---num_workers "$NUM_WORKERS" \ ---max_steps "$MAX_STEPS" \ ---log_every "$LOG_EVERY" \ ---val_every "$VAL_EVERY" \ ---val_max_batches "$VAL_MAX_BATCHES" \ No newline at end of file diff --git a/scripts/slurm_frontier/train_e2e_stage2_Nx1.sh b/scripts/slurm_frontier/train_e2e_stage2_Nx1.sh deleted file mode 100644 index 3d668b8..0000000 --- a/scripts/slurm_frontier/train_e2e_stage2_Nx1.sh +++ /dev/null @@ -1,126 +0,0 @@ -#!/bin/bash -# Frontier DDP launcher: train_e2e Stage2 — N nodes × 1 GCD (cross-node networking smoke; default N=2) -# -# Usage: -# sbatch scripts/slurm_frontier/train_e2e_stage2_Nx1.sh -# -# Common env overrides: -# SMOKE=1 # short test: MAX_STEPS=20, MAX_FILES=4, freq logs -# MAX_STEPS= # total optimizer steps -# MAX_FILES= # cap on training shots (debug) -# 
BATCH_SIZE= # per-rank batch size (default 8) -# NUM_WORKERS= # DataLoader workers per rank (default 4) -# DATA_DIR= # override data root -# CHECKPOINT_DIR= # override checkpoint dir -# MASTER_PORT= # override port (default 29501) -# -# Override resource shape on the CLI (sbatch flags beat #SBATCH directives): -# sbatch -N 8 -t 12:00:00 scripts/slurm_frontier/train_e2e_stage2_Nx1.sh -# -#SBATCH -A fus187 -#SBATCH -J e2e_s2_Nx1 -#SBATCH -o logs/%j_e2e_s2_Nx1.out -#SBATCH -e logs/%j_e2e_s2_Nx1.err -#SBATCH -t 01:00:00 -#SBATCH -p batch -#SBATCH -N 2 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-task=1 -#SBATCH --gpu-bind=closest -#SBATCH --cpus-per-task=7 -set -uo pipefail - -PROJECT_DIR=/lustre/orion/fus187/scratch/nchen/FusionAIHub -cd "$PROJECT_DIR" -mkdir -p logs - -# Per-stage MASTER_PORT default (overridable). Must be set BEFORE sourcing -# _frontier_common.sh, since that script only fills in if unset. -export MASTER_PORT="${MASTER_PORT:-29501}" - -# shellcheck disable=SC1091 -source scripts/slurm_frontier/_frontier_common.sh - -# ─── Resource shape (taken from SLURM allocation, never hard-coded) ────── -NODES="${SLURM_JOB_NUM_NODES:-2}" -TOTAL_RANKS="${SLURM_NTASKS:-$((NODES * 1))}" -CPUS_PER_TASK="${SLURM_CPUS_PER_TASK:-7}" - -# ─── SMOKE=1 overrides for end-to-end smoke testing ────────────────────── -if [ "${SMOKE:-0}" = "1" ]; then - MAX_STEPS="${MAX_STEPS:-20}" - MAX_FILES="${MAX_FILES:-4}" - NUM_WORKERS="${NUM_WORKERS:-2}" - LOG_EVERY="${LOG_EVERY:-2}" - VAL_EVERY="${VAL_EVERY:-10}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-2}" - SMOKE_BANNER="[SMOKE] " -else - MAX_STEPS="${MAX_STEPS:-1000}" - NUM_WORKERS="${NUM_WORKERS:-4}" - LOG_EVERY="${LOG_EVERY:-50}" - VAL_EVERY="${VAL_EVERY:-200}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-20}" - SMOKE_BANNER="" -fi - -MAX_FILES_FLAG="" -[ -n "${MAX_FILES:-}" ] && MAX_FILES_FLAG="--max_files $MAX_FILES" - -# ─── Stage-specific defaults & init/resume flags ───────────────────────── -BATCH_SIZE="${BATCH_SIZE:-8}" 
-K_MAX="${K_MAX:-10}" -CURRICULUM_STEPS="${CURRICULUM_STEPS:-$((MAX_STEPS / 2))}" -D_MODEL="${D_MODEL:-256}" -N_LAYERS="${N_LAYERS:-8}" -N_HEADS="${N_HEADS:-8}" -DATA_DIR="${DATA_DIR:-/lustre/orion/fus187/proj-shared/foundation_model}" -STATS_PATH="${STATS_PATH:-data/preprocessing_stats.pt}" -CHECKPOINT_DIR="${CHECKPOINT_DIR:-runs/e2e_stage2_frontier}" -INIT_CHECKPOINT="${INIT_CHECKPOINT:-runs/e2e_stage1_frontier/e2e_stage1_best.pt}" -mkdir -p "$CHECKPOINT_DIR" - -INIT_FLAG="" -if [ -f "$INIT_CHECKPOINT" ]; then - INIT_FLAG="--init_checkpoint $INIT_CHECKPOINT" - echo "[stage2] init from $INIT_CHECKPOINT" -else - echo "[stage2] WARNING: $INIT_CHECKPOINT not found — random init" -fi - -NO_AMP_FLAG="" -[ "${NO_AMP:-0}" = "1" ] && NO_AMP_FLAG="--no_amp" -echo "${SMOKE_BANNER}[stage2/Nx1] nodes=$NODES total_ranks=$TOTAL_RANKS \ -batch=$BATCH_SIZE steps=$MAX_STEPS K_max=$K_MAX" -echo "${SMOKE_BANNER}[stage2/Nx1] master=$MASTER_ADDR:$MASTER_PORT data=$DATA_DIR" - -srun -N "$NODES" -n "$TOTAL_RANKS" -c "$CPUS_PER_TASK" \ - --gpus-per-task=1 --gpu-bind=closest \ - scripts/slurm_frontier/_srun_rank_wrapper.sh \ - scripts/training/train_e2e_stage2.py \ - $INIT_FLAG $MAX_FILES_FLAG $NO_AMP_FLAG \ ---data_dir "$DATA_DIR" \ ---stats_path "$STATS_PATH" \ ---checkpoint_dir "$CHECKPOINT_DIR" \ ---val_fraction 0.1 \ ---seed 42 \ ---chunk_duration_s 0.05 \ ---step_size_s 0.01 \ ---warmup_s 1.0 \ ---d_model "$D_MODEL" \ ---n_layers "$N_LAYERS" \ ---n_heads "$N_HEADS" \ ---dropout 0.1 \ ---K_max "$K_MAX" \ ---curriculum_steps "$CURRICULUM_STEPS" \ ---lr 3e-5 \ ---min_lr 1e-6 \ ---warmup_steps 200 \ ---weight_decay 0.1 \ ---grad_clip 5.0 \ ---batch_size "$BATCH_SIZE" \ ---num_workers "$NUM_WORKERS" \ ---max_steps "$MAX_STEPS" \ ---log_every "$LOG_EVERY" \ ---val_every "$VAL_EVERY" \ ---val_max_batches "$VAL_MAX_BATCHES" \ No newline at end of file diff --git a/scripts/slurm_frontier/train_e2e_stage2_NxN.sh b/scripts/slurm_frontier/train_e2e_stage2_NxN.sh deleted file mode 100644 index 
265418e..0000000 --- a/scripts/slurm_frontier/train_e2e_stage2_NxN.sh +++ /dev/null @@ -1,126 +0,0 @@ -#!/bin/bash -# Frontier DDP launcher: train_e2e Stage2 — N nodes × 8 GCDs (production multi-node; default N=4, override with `sbatch -N `) -# -# Usage: -# sbatch scripts/slurm_frontier/train_e2e_stage2_NxN.sh -# -# Common env overrides: -# SMOKE=1 # short test: MAX_STEPS=20, MAX_FILES=4, freq logs -# MAX_STEPS= # total optimizer steps -# MAX_FILES= # cap on training shots (debug) -# BATCH_SIZE= # per-rank batch size (default 8) -# NUM_WORKERS= # DataLoader workers per rank (default 4) -# DATA_DIR= # override data root -# CHECKPOINT_DIR= # override checkpoint dir -# MASTER_PORT= # override port (default 29501) -# -# Override resource shape on the CLI (sbatch flags beat #SBATCH directives): -# sbatch -N 8 -t 12:00:00 scripts/slurm_frontier/train_e2e_stage2_NxN.sh -# -#SBATCH -A fus187 -#SBATCH -J e2e_s2_NxN -#SBATCH -o logs/%j_e2e_s2_NxN.out -#SBATCH -e logs/%j_e2e_s2_NxN.err -#SBATCH -t 02:00:00 -#SBATCH -p batch -#SBATCH -N 4 -#SBATCH --ntasks-per-node=8 -#SBATCH --gpus-per-task=1 -#SBATCH --gpu-bind=closest -#SBATCH --cpus-per-task=7 -set -uo pipefail - -PROJECT_DIR=/lustre/orion/fus187/scratch/nchen/FusionAIHub -cd "$PROJECT_DIR" -mkdir -p logs - -# Per-stage MASTER_PORT default (overridable). Must be set BEFORE sourcing -# _frontier_common.sh, since that script only fills in if unset. 
-export MASTER_PORT="${MASTER_PORT:-29501}" - -# shellcheck disable=SC1091 -source scripts/slurm_frontier/_frontier_common.sh - -# ─── Resource shape (taken from SLURM allocation, never hard-coded) ────── -NODES="${SLURM_JOB_NUM_NODES:-4}" -TOTAL_RANKS="${SLURM_NTASKS:-$((NODES * 8))}" -CPUS_PER_TASK="${SLURM_CPUS_PER_TASK:-7}" - -# ─── SMOKE=1 overrides for end-to-end smoke testing ────────────────────── -if [ "${SMOKE:-0}" = "1" ]; then - MAX_STEPS="${MAX_STEPS:-20}" - MAX_FILES="${MAX_FILES:-4}" - NUM_WORKERS="${NUM_WORKERS:-2}" - LOG_EVERY="${LOG_EVERY:-2}" - VAL_EVERY="${VAL_EVERY:-10}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-2}" - SMOKE_BANNER="[SMOKE] " -else - MAX_STEPS="${MAX_STEPS:-1000}" - NUM_WORKERS="${NUM_WORKERS:-4}" - LOG_EVERY="${LOG_EVERY:-50}" - VAL_EVERY="${VAL_EVERY:-200}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-20}" - SMOKE_BANNER="" -fi - -MAX_FILES_FLAG="" -[ -n "${MAX_FILES:-}" ] && MAX_FILES_FLAG="--max_files $MAX_FILES" - -# ─── Stage-specific defaults & init/resume flags ───────────────────────── -BATCH_SIZE="${BATCH_SIZE:-8}" -K_MAX="${K_MAX:-10}" -CURRICULUM_STEPS="${CURRICULUM_STEPS:-$((MAX_STEPS / 2))}" -D_MODEL="${D_MODEL:-256}" -N_LAYERS="${N_LAYERS:-8}" -N_HEADS="${N_HEADS:-8}" -DATA_DIR="${DATA_DIR:-/lustre/orion/fus187/proj-shared/foundation_model}" -STATS_PATH="${STATS_PATH:-data/preprocessing_stats.pt}" -CHECKPOINT_DIR="${CHECKPOINT_DIR:-runs/e2e_stage2_frontier}" -INIT_CHECKPOINT="${INIT_CHECKPOINT:-runs/e2e_stage1_frontier/e2e_stage1_best.pt}" -mkdir -p "$CHECKPOINT_DIR" - -INIT_FLAG="" -if [ -f "$INIT_CHECKPOINT" ]; then - INIT_FLAG="--init_checkpoint $INIT_CHECKPOINT" - echo "[stage2] init from $INIT_CHECKPOINT" -else - echo "[stage2] WARNING: $INIT_CHECKPOINT not found — random init" -fi - -NO_AMP_FLAG="" -[ "${NO_AMP:-0}" = "1" ] && NO_AMP_FLAG="--no_amp" -echo "${SMOKE_BANNER}[stage2/NxN] nodes=$NODES total_ranks=$TOTAL_RANKS \ -batch=$BATCH_SIZE steps=$MAX_STEPS K_max=$K_MAX" -echo "${SMOKE_BANNER}[stage2/NxN] 
master=$MASTER_ADDR:$MASTER_PORT data=$DATA_DIR" - -srun -N "$NODES" -n "$TOTAL_RANKS" -c "$CPUS_PER_TASK" \ - --gpus-per-task=1 --gpu-bind=closest \ - scripts/slurm_frontier/_srun_rank_wrapper.sh \ - scripts/training/train_e2e_stage2.py \ - $INIT_FLAG $MAX_FILES_FLAG $NO_AMP_FLAG \ ---data_dir "$DATA_DIR" \ ---stats_path "$STATS_PATH" \ ---checkpoint_dir "$CHECKPOINT_DIR" \ ---val_fraction 0.1 \ ---seed 42 \ ---chunk_duration_s 0.05 \ ---step_size_s 0.01 \ ---warmup_s 1.0 \ ---d_model "$D_MODEL" \ ---n_layers "$N_LAYERS" \ ---n_heads "$N_HEADS" \ ---dropout 0.1 \ ---K_max "$K_MAX" \ ---curriculum_steps "$CURRICULUM_STEPS" \ ---lr 3e-5 \ ---min_lr 1e-6 \ ---warmup_steps 200 \ ---weight_decay 0.1 \ ---grad_clip 5.0 \ ---batch_size "$BATCH_SIZE" \ ---num_workers "$NUM_WORKERS" \ ---max_steps "$MAX_STEPS" \ ---log_every "$LOG_EVERY" \ ---val_every "$VAL_EVERY" \ ---val_max_batches "$VAL_MAX_BATCHES" \ No newline at end of file diff --git a/scripts/slurm_frontier/train_e2e_stage2_delta.sh b/scripts/slurm_frontier/train_e2e_stage2_delta.sh index 608ea13..b18265e 100644 --- a/scripts/slurm_frontier/train_e2e_stage2_delta.sh +++ b/scripts/slurm_frontier/train_e2e_stage2_delta.sh @@ -3,39 +3,99 @@ #SBATCH -J e2e_stage2_delta #SBATCH -o logs/%j_e2e_stage2_delta.out #SBATCH -e logs/%j_e2e_stage2_delta.err -#SBATCH -t 02:00:00 -#SBATCH -p batch -#SBATCH -N 1 +#SBATCH -t 24:00:00 +#SBATCH -p extended +#SBATCH -N 8 #SBATCH --ntasks-per-node=8 +#SBATCH --gres=gpu:8 #SBATCH --gpus-per-task=1 #SBATCH --gpu-bind=closest #SBATCH --cpus-per-task=7 +#SBATCH --mem=0 set -e -cd /lustre/orion/fus187/scratch/nchen/FusionAIHub -mkdir -p logs runs/e2e_stage2_delta +# Submission pattern (matches Stage 1 chained-job recipe): +# +# # First job — short to land in `batch` partition (2h cap): +# sbatch -p batch -t 2:00:00 -N 8 scripts/slurm_frontier/train_e2e_stage2_delta.sh +# +# # Followup 24h jobs on `extended`, chained via afterany so each +# # resubmit picks up the previous job's _latest.pt 
automatically: +# sbatch -p extended -t 24:00:00 -N 8 --dependency=afterany:<jobid> \ +# scripts/slurm_frontier/train_e2e_stage2_delta.sh +# Resolve repo from SLURM_SUBMIT_DIR. SLURM stages the script under +# /var/spool/slurmd/... so BASH_SOURCE is useless. Submit from repo root. +PROJECT_DIR="${SLURM_SUBMIT_DIR:-$PWD}" +if [ ! -f "${PROJECT_DIR}/scripts/slurm_frontier/_frontier_settings.sh" ]; then + echo "ERROR: SLURM_SUBMIT_DIR (${PROJECT_DIR}) is not the repo root." >&2 + echo " cd into the FusionAIHub repo before sbatch." >&2 + exit 1 +fi +cd "${PROJECT_DIR}" + +CHECKPOINT_DIR="/lustre/orion/fus187/proj-shared/models/e2e_stage2_delta" +STAGE1_CKPT_DIR="/lustre/orion/fus187/proj-shared/models/e2e_stage1" +STAGE1_BEST="${STAGE1_CKPT_DIR}/e2e_stage1_best.pt" +mkdir -p logs "${CHECKPOINT_DIR}" + +# Per-stage MASTER_PORT (different from Stage 1's 29500 so concurrent +# jobs don't collide on the rendezvous port). export MASTER_PORT=29502 -source scripts/slurm_frontier/_frontier_common.sh +source scripts/slurm_frontier/_frontier_settings.sh + +# Auto-resume from previous chained submission. If a `_latest.pt` exists +# we resume (chained-job continuation). Otherwise initialise from +# Stage 1's `e2e_stage1_best.pt` via --init_checkpoint (cold start). +RESUME_FLAG="" +INIT_FLAG="" +LATEST_CKPT="${CHECKPOINT_DIR}/e2e_stage2_delta_latest.pt" +if [ -f "${LATEST_CKPT}" ]; then + echo "[train_e2e_stage2_delta] resuming from ${LATEST_CKPT}" + RESUME_FLAG="--resume_checkpoint ${LATEST_CKPT}" +elif [ -f "${STAGE1_BEST}" ]; then + echo "[train_e2e_stage2_delta] cold start — initialising from ${STAGE1_BEST}" + INIT_FLAG="--init_checkpoint ${STAGE1_BEST}" +else + echo "ERROR: neither ${LATEST_CKPT} nor ${STAGE1_BEST} found." >&2 + echo " Stage 2 delta needs Stage 1's best.pt to bootstrap." >&2 + exit 1 +fi + +# Per-node sampler: one line per node per minute with mean GPU busy%, +# host RAM, and mean VRAM%. 
Launched as a side srun step with --overlap +# so it shares the allocation without stealing GPUs. Cost ~0.1% of one +# CPU/node. Killed when this script exits (walltime or normal end). +SAMPLER_LOG="logs/${SLURM_JOB_ID}_sampler.log" +srun --overlap -N "$SLURM_JOB_NUM_NODES" --ntasks-per-node=1 -c 1 \ + scripts/slurm_frontier/_node_sampler.sh > "$SAMPLER_LOG" 2>&1 & +SAMPLER_PID=$! +trap 'kill "$SAMPLER_PID" 2>/dev/null || true' EXIT +# Validation cadence: at 8 nodes × batch_size=8 (global batch 512), +# 4,632,251 stage-2 train chunks → 9047 steps/epoch. val_every=9047 ≈ 1 +# val per epoch — same "1 val per epoch" pattern Stage 1 settled on. +# val_max_batches=30 because Stage 2 val is K_max=10× more expensive +# per batch than Stage 1's single-step val. srun -N $SLURM_JOB_NUM_NODES -n $SLURM_NTASKS -c $SLURM_CPUS_PER_TASK \ --gpus-per-task=1 --gpu-bind=closest \ scripts/slurm_frontier/_srun_rank_wrapper.sh \ scripts/training/train_e2e_stage2_delta.py \ --data_dir /lustre/orion/fus187/proj-shared/foundation_model \ - --stats_path data/preprocessing_stats.pt \ - --checkpoint_dir runs/e2e_stage2_delta \ + --stats_path /lustre/orion/fus187/proj-shared/foundation_model_meta/preprocessing_stats.pt \ + --checkpoint_dir "${CHECKPOINT_DIR}" \ --val_fraction 0.1 \ --seed 42 \ --chunk_duration_s 0.05 \ --step_size_s 0.01 \ --warmup_s 1.0 \ --d_model 256 \ - --n_layers 8 \ + --n_layers 26 \ --n_heads 8 \ --dropout 0.1 \ --K_max 10 \ - --curriculum_steps 25000 \ + --curriculum_steps 180940 \ + --grad_checkpoint_every 0 \ --mae_weight 1.0 \ --cos_weight 0.3 \ --mag_weight 0.1 \ @@ -46,8 +106,12 @@ srun -N $SLURM_JOB_NUM_NODES -n $SLURM_NTASKS -c $SLURM_CPUS_PER_TASK \ --weight_decay 0.1 \ --grad_clip 5.0 \ --batch_size 8 \ - --num_workers 4 \ - --max_steps 50000 \ + --num_workers 6 \ + --max_steps 180940 \ --log_every 50 \ - --val_every 500 \ - --val_max_batches 20 + --val_every 9047 \ + --val_max_batches 30 \ + --use_video tangtv \ + --use_spectro ece co2 bes \ + ${INIT_FLAG} \ 
+ ${RESUME_FLAG} diff --git a/scripts/slurm_frontier/train_e2e_stage2_delta_1x1.sh b/scripts/slurm_frontier/train_e2e_stage2_delta_1x1.sh deleted file mode 100644 index 7bbfa5b..0000000 --- a/scripts/slurm_frontier/train_e2e_stage2_delta_1x1.sh +++ /dev/null @@ -1,133 +0,0 @@ -#!/bin/bash -# Frontier DDP launcher: train_e2e Stage2 Delta — 1 node × 1 GCD (single-GPU smoke / dev) -# -# Usage: -# sbatch scripts/slurm_frontier/train_e2e_stage2_delta_1x1.sh -# -# Common env overrides: -# SMOKE=1 # short test: MAX_STEPS=20, MAX_FILES=4, freq logs -# MAX_STEPS= # total optimizer steps -# MAX_FILES= # cap on training shots (debug) -# BATCH_SIZE= # per-rank batch size (default 8) -# NUM_WORKERS= # DataLoader workers per rank (default 4) -# DATA_DIR= # override data root -# CHECKPOINT_DIR= # override checkpoint dir -# MASTER_PORT= # override port (default 29502) -# -# Override resource shape on the CLI (sbatch flags beat #SBATCH directives): -# sbatch -N 8 -t 12:00:00 scripts/slurm_frontier/train_e2e_stage2_delta_1x1.sh -# -#SBATCH -A fus187 -#SBATCH -J e2e_s2d_1x1 -#SBATCH -o logs/%j_e2e_s2d_1x1.out -#SBATCH -e logs/%j_e2e_s2d_1x1.err -#SBATCH -t 02:00:00 -#SBATCH -p batch -#SBATCH -N 1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-task=1 -#SBATCH --gpu-bind=closest -#SBATCH --cpus-per-task=7 -set -uo pipefail - -PROJECT_DIR=/lustre/orion/fus187/scratch/nchen/FusionAIHub -cd "$PROJECT_DIR" -mkdir -p logs - -# Per-stage MASTER_PORT default (overridable). Must be set BEFORE sourcing -# _frontier_common.sh, since that script only fills in if unset. 
-export MASTER_PORT="${MASTER_PORT:-29502}" - -# shellcheck disable=SC1091 -source scripts/slurm_frontier/_frontier_common.sh - -# ─── Resource shape (taken from SLURM allocation, never hard-coded) ────── -NODES="${SLURM_JOB_NUM_NODES:-1}" -TOTAL_RANKS="${SLURM_NTASKS:-$((NODES * 1))}" -CPUS_PER_TASK="${SLURM_CPUS_PER_TASK:-7}" - -# ─── SMOKE=1 overrides for end-to-end smoke testing ────────────────────── -if [ "${SMOKE:-0}" = "1" ]; then - MAX_STEPS="${MAX_STEPS:-20}" - MAX_FILES="${MAX_FILES:-4}" - NUM_WORKERS="${NUM_WORKERS:-2}" - LOG_EVERY="${LOG_EVERY:-2}" - VAL_EVERY="${VAL_EVERY:-10}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-2}" - SMOKE_BANNER="[SMOKE] " -else - MAX_STEPS="${MAX_STEPS:-1000}" - NUM_WORKERS="${NUM_WORKERS:-4}" - LOG_EVERY="${LOG_EVERY:-50}" - VAL_EVERY="${VAL_EVERY:-200}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-20}" - SMOKE_BANNER="" -fi - -MAX_FILES_FLAG="" -[ -n "${MAX_FILES:-}" ] && MAX_FILES_FLAG="--max_files $MAX_FILES" - -# ─── Stage-specific defaults & init/resume flags ───────────────────────── -BATCH_SIZE="${BATCH_SIZE:-8}" -K_MAX="${K_MAX:-10}" -CURRICULUM_STEPS="${CURRICULUM_STEPS:-$((MAX_STEPS / 2))}" -D_MODEL="${D_MODEL:-256}" -N_LAYERS="${N_LAYERS:-8}" -N_HEADS="${N_HEADS:-8}" -MAE_WEIGHT="${MAE_WEIGHT:-1.0}" -COS_WEIGHT="${COS_WEIGHT:-0.3}" -MAG_WEIGHT="${MAG_WEIGHT:-0.1}" -MIN_DISP_NORM="${MIN_DISP_NORM:-0.01}" -DATA_DIR="${DATA_DIR:-/lustre/orion/fus187/proj-shared/foundation_model}" -STATS_PATH="${STATS_PATH:-data/preprocessing_stats.pt}" -CHECKPOINT_DIR="${CHECKPOINT_DIR:-runs/e2e_stage2_delta_frontier}" -INIT_CHECKPOINT="${INIT_CHECKPOINT:-runs/e2e_stage1_frontier/e2e_stage1_best.pt}" -mkdir -p "$CHECKPOINT_DIR" - -INIT_FLAG="" -[ -f "$INIT_CHECKPOINT" ] && INIT_FLAG="--init_checkpoint $INIT_CHECKPOINT" - -LATEST="$CHECKPOINT_DIR/e2e_stage2_delta_latest.pt" -RESUME_FLAG="" -[ -f "$LATEST" ] && RESUME_FLAG="--resume_checkpoint $LATEST" - -NO_AMP_FLAG="" -[ "${NO_AMP:-0}" = "1" ] && NO_AMP_FLAG="--no_amp" -echo 
"${SMOKE_BANNER}[stage2_delta/1x1] nodes=$NODES total_ranks=$TOTAL_RANKS \ -batch=$BATCH_SIZE steps=$MAX_STEPS K_max=$K_MAX" -echo "${SMOKE_BANNER}[stage2_delta/1x1] master=$MASTER_ADDR:$MASTER_PORT data=$DATA_DIR" - -srun -N "$NODES" -n "$TOTAL_RANKS" -c "$CPUS_PER_TASK" \ - --gpus-per-task=1 --gpu-bind=closest \ - scripts/slurm_frontier/_srun_rank_wrapper.sh \ - scripts/training/train_e2e_stage2_delta.py \ - $INIT_FLAG $RESUME_FLAG $MAX_FILES_FLAG $NO_AMP_FLAG \ ---data_dir "$DATA_DIR" \ ---stats_path "$STATS_PATH" \ ---checkpoint_dir "$CHECKPOINT_DIR" \ ---val_fraction 0.1 \ ---seed 42 \ ---chunk_duration_s 0.05 \ ---step_size_s 0.01 \ ---warmup_s 1.0 \ ---d_model "$D_MODEL" \ ---n_layers "$N_LAYERS" \ ---n_heads "$N_HEADS" \ ---dropout 0.1 \ ---K_max "$K_MAX" \ ---curriculum_steps "$CURRICULUM_STEPS" \ ---mae_weight "$MAE_WEIGHT" \ ---cos_weight "$COS_WEIGHT" \ ---mag_weight "$MAG_WEIGHT" \ ---min_disp_norm "$MIN_DISP_NORM" \ ---lr 5e-4 \ ---min_lr 1e-6 \ ---warmup_steps 500 \ ---weight_decay 0.1 \ ---grad_clip 5.0 \ ---batch_size "$BATCH_SIZE" \ ---num_workers "$NUM_WORKERS" \ ---max_steps "$MAX_STEPS" \ ---log_every "$LOG_EVERY" \ ---val_every "$VAL_EVERY" \ ---val_max_batches "$VAL_MAX_BATCHES" \ No newline at end of file diff --git a/scripts/slurm_frontier/train_e2e_stage2_delta_1x8.sh b/scripts/slurm_frontier/train_e2e_stage2_delta_1x8.sh deleted file mode 100644 index 9f2f035..0000000 --- a/scripts/slurm_frontier/train_e2e_stage2_delta_1x8.sh +++ /dev/null @@ -1,133 +0,0 @@ -#!/bin/bash -# Frontier DDP launcher: train_e2e Stage2 Delta — 1 node × 8 GCDs (production single-node DDP) -# -# Usage: -# sbatch scripts/slurm_frontier/train_e2e_stage2_delta_1x8.sh -# -# Common env overrides: -# SMOKE=1 # short test: MAX_STEPS=20, MAX_FILES=4, freq logs -# MAX_STEPS= # total optimizer steps -# MAX_FILES= # cap on training shots (debug) -# BATCH_SIZE= # per-rank batch size (default 8) -# NUM_WORKERS= # DataLoader workers per rank (default 4) -# DATA_DIR= # override 
data root -# CHECKPOINT_DIR= # override checkpoint dir -# MASTER_PORT= # override port (default 29502) -# -# Override resource shape on the CLI (sbatch flags beat #SBATCH directives): -# sbatch -N 8 -t 12:00:00 scripts/slurm_frontier/train_e2e_stage2_delta_1x8.sh -# -#SBATCH -A fus187 -#SBATCH -J e2e_s2d_1x8 -#SBATCH -o logs/%j_e2e_s2d_1x8.out -#SBATCH -e logs/%j_e2e_s2d_1x8.err -#SBATCH -t 02:00:00 -#SBATCH -p batch -#SBATCH -N 1 -#SBATCH --ntasks-per-node=8 -#SBATCH --gpus-per-task=1 -#SBATCH --gpu-bind=closest -#SBATCH --cpus-per-task=7 -set -uo pipefail - -PROJECT_DIR=/lustre/orion/fus187/scratch/nchen/FusionAIHub -cd "$PROJECT_DIR" -mkdir -p logs - -# Per-stage MASTER_PORT default (overridable). Must be set BEFORE sourcing -# _frontier_common.sh, since that script only fills in if unset. -export MASTER_PORT="${MASTER_PORT:-29502}" - -# shellcheck disable=SC1091 -source scripts/slurm_frontier/_frontier_common.sh - -# ─── Resource shape (taken from SLURM allocation, never hard-coded) ────── -NODES="${SLURM_JOB_NUM_NODES:-1}" -TOTAL_RANKS="${SLURM_NTASKS:-$((NODES * 8))}" -CPUS_PER_TASK="${SLURM_CPUS_PER_TASK:-7}" - -# ─── SMOKE=1 overrides for end-to-end smoke testing ────────────────────── -if [ "${SMOKE:-0}" = "1" ]; then - MAX_STEPS="${MAX_STEPS:-20}" - MAX_FILES="${MAX_FILES:-4}" - NUM_WORKERS="${NUM_WORKERS:-2}" - LOG_EVERY="${LOG_EVERY:-2}" - VAL_EVERY="${VAL_EVERY:-10}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-2}" - SMOKE_BANNER="[SMOKE] " -else - MAX_STEPS="${MAX_STEPS:-1000}" - NUM_WORKERS="${NUM_WORKERS:-4}" - LOG_EVERY="${LOG_EVERY:-50}" - VAL_EVERY="${VAL_EVERY:-200}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-20}" - SMOKE_BANNER="" -fi - -MAX_FILES_FLAG="" -[ -n "${MAX_FILES:-}" ] && MAX_FILES_FLAG="--max_files $MAX_FILES" - -# ─── Stage-specific defaults & init/resume flags ───────────────────────── -BATCH_SIZE="${BATCH_SIZE:-8}" -K_MAX="${K_MAX:-10}" -CURRICULUM_STEPS="${CURRICULUM_STEPS:-$((MAX_STEPS / 2))}" -D_MODEL="${D_MODEL:-256}" 
-N_LAYERS="${N_LAYERS:-8}" -N_HEADS="${N_HEADS:-8}" -MAE_WEIGHT="${MAE_WEIGHT:-1.0}" -COS_WEIGHT="${COS_WEIGHT:-0.3}" -MAG_WEIGHT="${MAG_WEIGHT:-0.1}" -MIN_DISP_NORM="${MIN_DISP_NORM:-0.01}" -DATA_DIR="${DATA_DIR:-/lustre/orion/fus187/proj-shared/foundation_model}" -STATS_PATH="${STATS_PATH:-data/preprocessing_stats.pt}" -CHECKPOINT_DIR="${CHECKPOINT_DIR:-runs/e2e_stage2_delta_frontier}" -INIT_CHECKPOINT="${INIT_CHECKPOINT:-runs/e2e_stage1_frontier/e2e_stage1_best.pt}" -mkdir -p "$CHECKPOINT_DIR" - -INIT_FLAG="" -[ -f "$INIT_CHECKPOINT" ] && INIT_FLAG="--init_checkpoint $INIT_CHECKPOINT" - -LATEST="$CHECKPOINT_DIR/e2e_stage2_delta_latest.pt" -RESUME_FLAG="" -[ -f "$LATEST" ] && RESUME_FLAG="--resume_checkpoint $LATEST" - -NO_AMP_FLAG="" -[ "${NO_AMP:-0}" = "1" ] && NO_AMP_FLAG="--no_amp" -echo "${SMOKE_BANNER}[stage2_delta/1x8] nodes=$NODES total_ranks=$TOTAL_RANKS \ -batch=$BATCH_SIZE steps=$MAX_STEPS K_max=$K_MAX" -echo "${SMOKE_BANNER}[stage2_delta/1x8] master=$MASTER_ADDR:$MASTER_PORT data=$DATA_DIR" - -srun -N "$NODES" -n "$TOTAL_RANKS" -c "$CPUS_PER_TASK" \ - --gpus-per-task=1 --gpu-bind=closest \ - scripts/slurm_frontier/_srun_rank_wrapper.sh \ - scripts/training/train_e2e_stage2_delta.py \ - $INIT_FLAG $RESUME_FLAG $MAX_FILES_FLAG $NO_AMP_FLAG \ ---data_dir "$DATA_DIR" \ ---stats_path "$STATS_PATH" \ ---checkpoint_dir "$CHECKPOINT_DIR" \ ---val_fraction 0.1 \ ---seed 42 \ ---chunk_duration_s 0.05 \ ---step_size_s 0.01 \ ---warmup_s 1.0 \ ---d_model "$D_MODEL" \ ---n_layers "$N_LAYERS" \ ---n_heads "$N_HEADS" \ ---dropout 0.1 \ ---K_max "$K_MAX" \ ---curriculum_steps "$CURRICULUM_STEPS" \ ---mae_weight "$MAE_WEIGHT" \ ---cos_weight "$COS_WEIGHT" \ ---mag_weight "$MAG_WEIGHT" \ ---min_disp_norm "$MIN_DISP_NORM" \ ---lr 5e-4 \ ---min_lr 1e-6 \ ---warmup_steps 500 \ ---weight_decay 0.1 \ ---grad_clip 5.0 \ ---batch_size "$BATCH_SIZE" \ ---num_workers "$NUM_WORKERS" \ ---max_steps "$MAX_STEPS" \ ---log_every "$LOG_EVERY" \ ---val_every "$VAL_EVERY" \ 
---val_max_batches "$VAL_MAX_BATCHES" \ No newline at end of file diff --git a/scripts/slurm_frontier/train_e2e_stage2_delta_Nx1.sh b/scripts/slurm_frontier/train_e2e_stage2_delta_Nx1.sh deleted file mode 100644 index 2204717..0000000 --- a/scripts/slurm_frontier/train_e2e_stage2_delta_Nx1.sh +++ /dev/null @@ -1,133 +0,0 @@ -#!/bin/bash -# Frontier DDP launcher: train_e2e Stage2 Delta — N nodes × 1 GCD (cross-node networking smoke; default N=2) -# -# Usage: -# sbatch scripts/slurm_frontier/train_e2e_stage2_delta_Nx1.sh -# -# Common env overrides: -# SMOKE=1 # short test: MAX_STEPS=20, MAX_FILES=4, freq logs -# MAX_STEPS= # total optimizer steps -# MAX_FILES= # cap on training shots (debug) -# BATCH_SIZE= # per-rank batch size (default 8) -# NUM_WORKERS= # DataLoader workers per rank (default 4) -# DATA_DIR= # override data root -# CHECKPOINT_DIR= # override checkpoint dir -# MASTER_PORT= # override port (default 29502) -# -# Override resource shape on the CLI (sbatch flags beat #SBATCH directives): -# sbatch -N 8 -t 12:00:00 scripts/slurm_frontier/train_e2e_stage2_delta_Nx1.sh -# -#SBATCH -A fus187 -#SBATCH -J e2e_s2d_Nx1 -#SBATCH -o logs/%j_e2e_s2d_Nx1.out -#SBATCH -e logs/%j_e2e_s2d_Nx1.err -#SBATCH -t 01:00:00 -#SBATCH -p batch -#SBATCH -N 2 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-task=1 -#SBATCH --gpu-bind=closest -#SBATCH --cpus-per-task=7 -set -uo pipefail - -PROJECT_DIR=/lustre/orion/fus187/scratch/nchen/FusionAIHub -cd "$PROJECT_DIR" -mkdir -p logs - -# Per-stage MASTER_PORT default (overridable). Must be set BEFORE sourcing -# _frontier_common.sh, since that script only fills in if unset. 
-export MASTER_PORT="${MASTER_PORT:-29502}" - -# shellcheck disable=SC1091 -source scripts/slurm_frontier/_frontier_common.sh - -# ─── Resource shape (taken from SLURM allocation, never hard-coded) ────── -NODES="${SLURM_JOB_NUM_NODES:-2}" -TOTAL_RANKS="${SLURM_NTASKS:-$((NODES * 1))}" -CPUS_PER_TASK="${SLURM_CPUS_PER_TASK:-7}" - -# ─── SMOKE=1 overrides for end-to-end smoke testing ────────────────────── -if [ "${SMOKE:-0}" = "1" ]; then - MAX_STEPS="${MAX_STEPS:-20}" - MAX_FILES="${MAX_FILES:-4}" - NUM_WORKERS="${NUM_WORKERS:-2}" - LOG_EVERY="${LOG_EVERY:-2}" - VAL_EVERY="${VAL_EVERY:-10}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-2}" - SMOKE_BANNER="[SMOKE] " -else - MAX_STEPS="${MAX_STEPS:-1000}" - NUM_WORKERS="${NUM_WORKERS:-4}" - LOG_EVERY="${LOG_EVERY:-50}" - VAL_EVERY="${VAL_EVERY:-200}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-20}" - SMOKE_BANNER="" -fi - -MAX_FILES_FLAG="" -[ -n "${MAX_FILES:-}" ] && MAX_FILES_FLAG="--max_files $MAX_FILES" - -# ─── Stage-specific defaults & init/resume flags ───────────────────────── -BATCH_SIZE="${BATCH_SIZE:-8}" -K_MAX="${K_MAX:-10}" -CURRICULUM_STEPS="${CURRICULUM_STEPS:-$((MAX_STEPS / 2))}" -D_MODEL="${D_MODEL:-256}" -N_LAYERS="${N_LAYERS:-8}" -N_HEADS="${N_HEADS:-8}" -MAE_WEIGHT="${MAE_WEIGHT:-1.0}" -COS_WEIGHT="${COS_WEIGHT:-0.3}" -MAG_WEIGHT="${MAG_WEIGHT:-0.1}" -MIN_DISP_NORM="${MIN_DISP_NORM:-0.01}" -DATA_DIR="${DATA_DIR:-/lustre/orion/fus187/proj-shared/foundation_model}" -STATS_PATH="${STATS_PATH:-data/preprocessing_stats.pt}" -CHECKPOINT_DIR="${CHECKPOINT_DIR:-runs/e2e_stage2_delta_frontier}" -INIT_CHECKPOINT="${INIT_CHECKPOINT:-runs/e2e_stage1_frontier/e2e_stage1_best.pt}" -mkdir -p "$CHECKPOINT_DIR" - -INIT_FLAG="" -[ -f "$INIT_CHECKPOINT" ] && INIT_FLAG="--init_checkpoint $INIT_CHECKPOINT" - -LATEST="$CHECKPOINT_DIR/e2e_stage2_delta_latest.pt" -RESUME_FLAG="" -[ -f "$LATEST" ] && RESUME_FLAG="--resume_checkpoint $LATEST" - -NO_AMP_FLAG="" -[ "${NO_AMP:-0}" = "1" ] && NO_AMP_FLAG="--no_amp" -echo 
"${SMOKE_BANNER}[stage2_delta/Nx1] nodes=$NODES total_ranks=$TOTAL_RANKS \ -batch=$BATCH_SIZE steps=$MAX_STEPS K_max=$K_MAX" -echo "${SMOKE_BANNER}[stage2_delta/Nx1] master=$MASTER_ADDR:$MASTER_PORT data=$DATA_DIR" - -srun -N "$NODES" -n "$TOTAL_RANKS" -c "$CPUS_PER_TASK" \ - --gpus-per-task=1 --gpu-bind=closest \ - scripts/slurm_frontier/_srun_rank_wrapper.sh \ - scripts/training/train_e2e_stage2_delta.py \ - $INIT_FLAG $RESUME_FLAG $MAX_FILES_FLAG $NO_AMP_FLAG \ ---data_dir "$DATA_DIR" \ ---stats_path "$STATS_PATH" \ ---checkpoint_dir "$CHECKPOINT_DIR" \ ---val_fraction 0.1 \ ---seed 42 \ ---chunk_duration_s 0.05 \ ---step_size_s 0.01 \ ---warmup_s 1.0 \ ---d_model "$D_MODEL" \ ---n_layers "$N_LAYERS" \ ---n_heads "$N_HEADS" \ ---dropout 0.1 \ ---K_max "$K_MAX" \ ---curriculum_steps "$CURRICULUM_STEPS" \ ---mae_weight "$MAE_WEIGHT" \ ---cos_weight "$COS_WEIGHT" \ ---mag_weight "$MAG_WEIGHT" \ ---min_disp_norm "$MIN_DISP_NORM" \ ---lr 5e-4 \ ---min_lr 1e-6 \ ---warmup_steps 500 \ ---weight_decay 0.1 \ ---grad_clip 5.0 \ ---batch_size "$BATCH_SIZE" \ ---num_workers "$NUM_WORKERS" \ ---max_steps "$MAX_STEPS" \ ---log_every "$LOG_EVERY" \ ---val_every "$VAL_EVERY" \ ---val_max_batches "$VAL_MAX_BATCHES" \ No newline at end of file diff --git a/scripts/slurm_frontier/train_e2e_stage2_delta_NxN.sh b/scripts/slurm_frontier/train_e2e_stage2_delta_NxN.sh deleted file mode 100644 index d54a5fe..0000000 --- a/scripts/slurm_frontier/train_e2e_stage2_delta_NxN.sh +++ /dev/null @@ -1,133 +0,0 @@ -#!/bin/bash -# Frontier DDP launcher: train_e2e Stage2 Delta — N nodes × 8 GCDs (production multi-node; default N=4, override with `sbatch -N `) -# -# Usage: -# sbatch scripts/slurm_frontier/train_e2e_stage2_delta_NxN.sh -# -# Common env overrides: -# SMOKE=1 # short test: MAX_STEPS=20, MAX_FILES=4, freq logs -# MAX_STEPS= # total optimizer steps -# MAX_FILES= # cap on training shots (debug) -# BATCH_SIZE= # per-rank batch size (default 8) -# NUM_WORKERS= # DataLoader workers per rank 
(default 4) -# DATA_DIR= # override data root -# CHECKPOINT_DIR= # override checkpoint dir -# MASTER_PORT= # override port (default 29502) -# -# Override resource shape on the CLI (sbatch flags beat #SBATCH directives): -# sbatch -N 8 -t 12:00:00 scripts/slurm_frontier/train_e2e_stage2_delta_NxN.sh -# -#SBATCH -A fus187 -#SBATCH -J e2e_s2d_NxN -#SBATCH -o logs/%j_e2e_s2d_NxN.out -#SBATCH -e logs/%j_e2e_s2d_NxN.err -#SBATCH -t 02:00:00 -#SBATCH -p batch -#SBATCH -N 4 -#SBATCH --ntasks-per-node=8 -#SBATCH --gpus-per-task=1 -#SBATCH --gpu-bind=closest -#SBATCH --cpus-per-task=7 -set -uo pipefail - -PROJECT_DIR=/lustre/orion/fus187/scratch/nchen/FusionAIHub -cd "$PROJECT_DIR" -mkdir -p logs - -# Per-stage MASTER_PORT default (overridable). Must be set BEFORE sourcing -# _frontier_common.sh, since that script only fills in if unset. -export MASTER_PORT="${MASTER_PORT:-29502}" - -# shellcheck disable=SC1091 -source scripts/slurm_frontier/_frontier_common.sh - -# ─── Resource shape (taken from SLURM allocation, never hard-coded) ────── -NODES="${SLURM_JOB_NUM_NODES:-4}" -TOTAL_RANKS="${SLURM_NTASKS:-$((NODES * 8))}" -CPUS_PER_TASK="${SLURM_CPUS_PER_TASK:-7}" - -# ─── SMOKE=1 overrides for end-to-end smoke testing ────────────────────── -if [ "${SMOKE:-0}" = "1" ]; then - MAX_STEPS="${MAX_STEPS:-20}" - MAX_FILES="${MAX_FILES:-4}" - NUM_WORKERS="${NUM_WORKERS:-2}" - LOG_EVERY="${LOG_EVERY:-2}" - VAL_EVERY="${VAL_EVERY:-10}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-2}" - SMOKE_BANNER="[SMOKE] " -else - MAX_STEPS="${MAX_STEPS:-1000}" - NUM_WORKERS="${NUM_WORKERS:-4}" - LOG_EVERY="${LOG_EVERY:-50}" - VAL_EVERY="${VAL_EVERY:-200}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-20}" - SMOKE_BANNER="" -fi - -MAX_FILES_FLAG="" -[ -n "${MAX_FILES:-}" ] && MAX_FILES_FLAG="--max_files $MAX_FILES" - -# ─── Stage-specific defaults & init/resume flags ───────────────────────── -BATCH_SIZE="${BATCH_SIZE:-8}" -K_MAX="${K_MAX:-10}" -CURRICULUM_STEPS="${CURRICULUM_STEPS:-$((MAX_STEPS / 2))}" 
-D_MODEL="${D_MODEL:-256}" -N_LAYERS="${N_LAYERS:-8}" -N_HEADS="${N_HEADS:-8}" -MAE_WEIGHT="${MAE_WEIGHT:-1.0}" -COS_WEIGHT="${COS_WEIGHT:-0.3}" -MAG_WEIGHT="${MAG_WEIGHT:-0.1}" -MIN_DISP_NORM="${MIN_DISP_NORM:-0.01}" -DATA_DIR="${DATA_DIR:-/lustre/orion/fus187/proj-shared/foundation_model}" -STATS_PATH="${STATS_PATH:-data/preprocessing_stats.pt}" -CHECKPOINT_DIR="${CHECKPOINT_DIR:-runs/e2e_stage2_delta_frontier}" -INIT_CHECKPOINT="${INIT_CHECKPOINT:-runs/e2e_stage1_frontier/e2e_stage1_best.pt}" -mkdir -p "$CHECKPOINT_DIR" - -INIT_FLAG="" -[ -f "$INIT_CHECKPOINT" ] && INIT_FLAG="--init_checkpoint $INIT_CHECKPOINT" - -LATEST="$CHECKPOINT_DIR/e2e_stage2_delta_latest.pt" -RESUME_FLAG="" -[ -f "$LATEST" ] && RESUME_FLAG="--resume_checkpoint $LATEST" - -NO_AMP_FLAG="" -[ "${NO_AMP:-0}" = "1" ] && NO_AMP_FLAG="--no_amp" -echo "${SMOKE_BANNER}[stage2_delta/NxN] nodes=$NODES total_ranks=$TOTAL_RANKS \ -batch=$BATCH_SIZE steps=$MAX_STEPS K_max=$K_MAX" -echo "${SMOKE_BANNER}[stage2_delta/NxN] master=$MASTER_ADDR:$MASTER_PORT data=$DATA_DIR" - -srun -N "$NODES" -n "$TOTAL_RANKS" -c "$CPUS_PER_TASK" \ - --gpus-per-task=1 --gpu-bind=closest \ - scripts/slurm_frontier/_srun_rank_wrapper.sh \ - scripts/training/train_e2e_stage2_delta.py \ - $INIT_FLAG $RESUME_FLAG $MAX_FILES_FLAG $NO_AMP_FLAG \ ---data_dir "$DATA_DIR" \ ---stats_path "$STATS_PATH" \ ---checkpoint_dir "$CHECKPOINT_DIR" \ ---val_fraction 0.1 \ ---seed 42 \ ---chunk_duration_s 0.05 \ ---step_size_s 0.01 \ ---warmup_s 1.0 \ ---d_model "$D_MODEL" \ ---n_layers "$N_LAYERS" \ ---n_heads "$N_HEADS" \ ---dropout 0.1 \ ---K_max "$K_MAX" \ ---curriculum_steps "$CURRICULUM_STEPS" \ ---mae_weight "$MAE_WEIGHT" \ ---cos_weight "$COS_WEIGHT" \ ---mag_weight "$MAG_WEIGHT" \ ---min_disp_norm "$MIN_DISP_NORM" \ ---lr 5e-4 \ ---min_lr 1e-6 \ ---warmup_steps 500 \ ---weight_decay 0.1 \ ---grad_clip 5.0 \ ---batch_size "$BATCH_SIZE" \ ---num_workers "$NUM_WORKERS" \ ---max_steps "$MAX_STEPS" \ ---log_every "$LOG_EVERY" \ ---val_every 
"$VAL_EVERY" \ ---val_max_batches "$VAL_MAX_BATCHES" \ No newline at end of file diff --git a/scripts/slurm_frontier/train_e2e_stage2_extended.sh b/scripts/slurm_frontier/train_e2e_stage2_extended.sh index 2138b6e..9397677 100644 --- a/scripts/slurm_frontier/train_e2e_stage2_extended.sh +++ b/scripts/slurm_frontier/train_e2e_stage2_extended.sh @@ -12,11 +12,17 @@ #SBATCH --cpus-per-task=7 set -e -cd /lustre/orion/fus187/scratch/nchen/FusionAIHub +PROJECT_DIR="${SLURM_SUBMIT_DIR:-$PWD}" +if [ ! -f "${PROJECT_DIR}/scripts/slurm_frontier/_frontier_settings.sh" ]; then + echo "ERROR: SLURM_SUBMIT_DIR (${PROJECT_DIR}) is not the repo root." >&2 + echo " cd into the FusionAIHub repo before sbatch." >&2 + exit 1 +fi +cd "${PROJECT_DIR}" mkdir -p logs runs/e2e_stage2_extended export MASTER_PORT=29503 -source scripts/slurm_frontier/_frontier_common.sh +source scripts/slurm_frontier/_frontier_settings.sh srun -N $SLURM_JOB_NUM_NODES -n $SLURM_NTASKS -c $SLURM_CPUS_PER_TASK \ --gpus-per-task=1 --gpu-bind=closest \ diff --git a/scripts/slurm_frontier/train_e2e_stage2_extended_1x1.sh b/scripts/slurm_frontier/train_e2e_stage2_extended_1x1.sh deleted file mode 100644 index 5538695..0000000 --- a/scripts/slurm_frontier/train_e2e_stage2_extended_1x1.sh +++ /dev/null @@ -1,138 +0,0 @@ -#!/bin/bash -# Frontier DDP launcher: train_e2e Stage2 Extended — 1 node × 1 GCD (single-GPU smoke / dev) -# -# Usage: -# sbatch scripts/slurm_frontier/train_e2e_stage2_extended_1x1.sh -# -# Common env overrides: -# SMOKE=1 # short test: MAX_STEPS=20, MAX_FILES=4, freq logs -# MAX_STEPS= # total optimizer steps -# MAX_FILES= # cap on training shots (debug) -# BATCH_SIZE= # per-rank batch size (default 4) -# NUM_WORKERS= # DataLoader workers per rank (default 4) -# DATA_DIR= # override data root -# CHECKPOINT_DIR= # override checkpoint dir -# MASTER_PORT= # override port (default 29503) -# -# Override resource shape on the CLI (sbatch flags beat #SBATCH directives): -# sbatch -N 8 -t 12:00:00 
scripts/slurm_frontier/train_e2e_stage2_extended_1x1.sh -# -#SBATCH -A fus187 -#SBATCH -J e2e_s2e_1x1 -#SBATCH -o logs/%j_e2e_s2e_1x1.out -#SBATCH -e logs/%j_e2e_s2e_1x1.err -#SBATCH -t 02:00:00 -#SBATCH -p batch -#SBATCH -N 1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-task=1 -#SBATCH --gpu-bind=closest -#SBATCH --cpus-per-task=7 -set -uo pipefail - -PROJECT_DIR=/lustre/orion/fus187/scratch/nchen/FusionAIHub -cd "$PROJECT_DIR" -mkdir -p logs - -# Per-stage MASTER_PORT default (overridable). Must be set BEFORE sourcing -# _frontier_common.sh, since that script only fills in if unset. -export MASTER_PORT="${MASTER_PORT:-29503}" - -# shellcheck disable=SC1091 -source scripts/slurm_frontier/_frontier_common.sh - -# ─── Resource shape (taken from SLURM allocation, never hard-coded) ────── -NODES="${SLURM_JOB_NUM_NODES:-1}" -TOTAL_RANKS="${SLURM_NTASKS:-$((NODES * 1))}" -CPUS_PER_TASK="${SLURM_CPUS_PER_TASK:-7}" - -# ─── SMOKE=1 overrides for end-to-end smoke testing ────────────────────── -if [ "${SMOKE:-0}" = "1" ]; then - MAX_STEPS="${MAX_STEPS:-20}" - MAX_FILES="${MAX_FILES:-4}" - NUM_WORKERS="${NUM_WORKERS:-2}" - LOG_EVERY="${LOG_EVERY:-2}" - VAL_EVERY="${VAL_EVERY:-10}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-2}" - SMOKE_BANNER="[SMOKE] " -else - MAX_STEPS="${MAX_STEPS:-1000}" - NUM_WORKERS="${NUM_WORKERS:-4}" - LOG_EVERY="${LOG_EVERY:-50}" - VAL_EVERY="${VAL_EVERY:-200}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-20}" - SMOKE_BANNER="" -fi - -MAX_FILES_FLAG="" -[ -n "${MAX_FILES:-}" ] && MAX_FILES_FLAG="--max_files $MAX_FILES" - -# ─── Stage-specific defaults & init/resume flags ───────────────────────── -BATCH_SIZE="${BATCH_SIZE:-4}" -CURRICULUM_KS="${CURRICULUM_KS:-2,3,4}" -BLOCK_STEPS="${BLOCK_STEPS:-$((MAX_STEPS / 3))}" -GRAD_CHECKPOINT_EVERY="${GRAD_CHECKPOINT_EVERY:-2}" -D_MODEL="${D_MODEL:-256}" -N_LAYERS="${N_LAYERS:-8}" -N_HEADS="${N_HEADS:-8}" -MAE_WEIGHT="${MAE_WEIGHT:-1.0}" -COS_WEIGHT="${COS_WEIGHT:-0.3}" -MAG_WEIGHT="${MAG_WEIGHT:-0.1}" 
-MIN_DISP_NORM="${MIN_DISP_NORM:-0.01}" -DATA_DIR="${DATA_DIR:-/lustre/orion/fus187/proj-shared/foundation_model}" -STATS_PATH="${STATS_PATH:-data/preprocessing_stats.pt}" -CHECKPOINT_DIR="${CHECKPOINT_DIR:-runs/e2e_stage2_ext_frontier}" -INIT_CHECKPOINT="${INIT_CHECKPOINT:-runs/e2e_stage2_delta_frontier/e2e_stage2_delta_best.pt}" -mkdir -p "$CHECKPOINT_DIR" - -INIT_FLAG="" -[ -f "$INIT_CHECKPOINT" ] && INIT_FLAG="--init_checkpoint $INIT_CHECKPOINT" - -LATEST="$CHECKPOINT_DIR/e2e_stage2_ext_latest.pt" -RESUME_FLAG="" -[ -f "$LATEST" ] && RESUME_FLAG="--resume_checkpoint $LATEST" - -NO_AMP_FLAG="" -[ "${NO_AMP:-0}" = "1" ] && NO_AMP_FLAG="--no_amp" - -NO_DISP_FLAG="" -[ "${NO_DISPLACEMENT_LOSS:-0}" = "1" ] && NO_DISP_FLAG="--no_displacement_loss" -echo "${SMOKE_BANNER}[stage2_extended/1x1] nodes=$NODES total_ranks=$TOTAL_RANKS \ -batch=$BATCH_SIZE steps=$MAX_STEPS Ks=$CURRICULUM_KS" -echo "${SMOKE_BANNER}[stage2_extended/1x1] master=$MASTER_ADDR:$MASTER_PORT data=$DATA_DIR" - -srun -N "$NODES" -n "$TOTAL_RANKS" -c "$CPUS_PER_TASK" \ - --gpus-per-task=1 --gpu-bind=closest \ - scripts/slurm_frontier/_srun_rank_wrapper.sh \ - scripts/training/train_e2e_stage2_extended.py \ - $INIT_FLAG $RESUME_FLAG $MAX_FILES_FLAG $NO_AMP_FLAG $NO_DISP_FLAG \ ---data_dir "$DATA_DIR" \ ---stats_path "$STATS_PATH" \ ---checkpoint_dir "$CHECKPOINT_DIR" \ ---val_fraction 0.1 \ ---seed 42 \ ---chunk_duration_s 0.05 \ ---step_size_s 0.01 \ ---warmup_s 1.0 \ ---d_model "$D_MODEL" \ ---n_layers "$N_LAYERS" \ ---n_heads "$N_HEADS" \ ---dropout 0.1 \ ---curriculum_Ks "$CURRICULUM_KS" \ ---block_steps "$BLOCK_STEPS" \ ---mae_weight "$MAE_WEIGHT" \ ---cos_weight "$COS_WEIGHT" \ ---mag_weight "$MAG_WEIGHT" \ ---min_disp_norm "$MIN_DISP_NORM" \ ---grad_checkpoint_every "$GRAD_CHECKPOINT_EVERY" \ ---lr 1e-5 \ ---min_lr 1e-7 \ ---warmup_steps 500 \ ---weight_decay 0.01 \ ---grad_clip 5.0 \ ---batch_size "$BATCH_SIZE" \ ---num_workers "$NUM_WORKERS" \ ---max_steps "$MAX_STEPS" \ ---log_every 
"$LOG_EVERY" \ ---val_every "$VAL_EVERY" \ ---val_max_batches "$VAL_MAX_BATCHES" \ No newline at end of file diff --git a/scripts/slurm_frontier/train_e2e_stage2_extended_1x8.sh b/scripts/slurm_frontier/train_e2e_stage2_extended_1x8.sh deleted file mode 100644 index c4035b3..0000000 --- a/scripts/slurm_frontier/train_e2e_stage2_extended_1x8.sh +++ /dev/null @@ -1,138 +0,0 @@ -#!/bin/bash -# Frontier DDP launcher: train_e2e Stage2 Extended — 1 node × 8 GCDs (production single-node DDP) -# -# Usage: -# sbatch scripts/slurm_frontier/train_e2e_stage2_extended_1x8.sh -# -# Common env overrides: -# SMOKE=1 # short test: MAX_STEPS=20, MAX_FILES=4, freq logs -# MAX_STEPS= # total optimizer steps -# MAX_FILES= # cap on training shots (debug) -# BATCH_SIZE= # per-rank batch size (default 4) -# NUM_WORKERS= # DataLoader workers per rank (default 4) -# DATA_DIR= # override data root -# CHECKPOINT_DIR= # override checkpoint dir -# MASTER_PORT= # override port (default 29503) -# -# Override resource shape on the CLI (sbatch flags beat #SBATCH directives): -# sbatch -N 8 -t 12:00:00 scripts/slurm_frontier/train_e2e_stage2_extended_1x8.sh -# -#SBATCH -A fus187 -#SBATCH -J e2e_s2e_1x8 -#SBATCH -o logs/%j_e2e_s2e_1x8.out -#SBATCH -e logs/%j_e2e_s2e_1x8.err -#SBATCH -t 02:00:00 -#SBATCH -p batch -#SBATCH -N 1 -#SBATCH --ntasks-per-node=8 -#SBATCH --gpus-per-task=1 -#SBATCH --gpu-bind=closest -#SBATCH --cpus-per-task=7 -set -uo pipefail - -PROJECT_DIR=/lustre/orion/fus187/scratch/nchen/FusionAIHub -cd "$PROJECT_DIR" -mkdir -p logs - -# Per-stage MASTER_PORT default (overridable). Must be set BEFORE sourcing -# _frontier_common.sh, since that script only fills in if unset. 
-export MASTER_PORT="${MASTER_PORT:-29503}" - -# shellcheck disable=SC1091 -source scripts/slurm_frontier/_frontier_common.sh - -# ─── Resource shape (taken from SLURM allocation, never hard-coded) ────── -NODES="${SLURM_JOB_NUM_NODES:-1}" -TOTAL_RANKS="${SLURM_NTASKS:-$((NODES * 8))}" -CPUS_PER_TASK="${SLURM_CPUS_PER_TASK:-7}" - -# ─── SMOKE=1 overrides for end-to-end smoke testing ────────────────────── -if [ "${SMOKE:-0}" = "1" ]; then - MAX_STEPS="${MAX_STEPS:-20}" - MAX_FILES="${MAX_FILES:-4}" - NUM_WORKERS="${NUM_WORKERS:-2}" - LOG_EVERY="${LOG_EVERY:-2}" - VAL_EVERY="${VAL_EVERY:-10}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-2}" - SMOKE_BANNER="[SMOKE] " -else - MAX_STEPS="${MAX_STEPS:-1000}" - NUM_WORKERS="${NUM_WORKERS:-4}" - LOG_EVERY="${LOG_EVERY:-50}" - VAL_EVERY="${VAL_EVERY:-200}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-20}" - SMOKE_BANNER="" -fi - -MAX_FILES_FLAG="" -[ -n "${MAX_FILES:-}" ] && MAX_FILES_FLAG="--max_files $MAX_FILES" - -# ─── Stage-specific defaults & init/resume flags ───────────────────────── -BATCH_SIZE="${BATCH_SIZE:-4}" -CURRICULUM_KS="${CURRICULUM_KS:-2,3,4}" -BLOCK_STEPS="${BLOCK_STEPS:-$((MAX_STEPS / 3))}" -GRAD_CHECKPOINT_EVERY="${GRAD_CHECKPOINT_EVERY:-2}" -D_MODEL="${D_MODEL:-256}" -N_LAYERS="${N_LAYERS:-8}" -N_HEADS="${N_HEADS:-8}" -MAE_WEIGHT="${MAE_WEIGHT:-1.0}" -COS_WEIGHT="${COS_WEIGHT:-0.3}" -MAG_WEIGHT="${MAG_WEIGHT:-0.1}" -MIN_DISP_NORM="${MIN_DISP_NORM:-0.01}" -DATA_DIR="${DATA_DIR:-/lustre/orion/fus187/proj-shared/foundation_model}" -STATS_PATH="${STATS_PATH:-data/preprocessing_stats.pt}" -CHECKPOINT_DIR="${CHECKPOINT_DIR:-runs/e2e_stage2_ext_frontier}" -INIT_CHECKPOINT="${INIT_CHECKPOINT:-runs/e2e_stage2_delta_frontier/e2e_stage2_delta_best.pt}" -mkdir -p "$CHECKPOINT_DIR" - -INIT_FLAG="" -[ -f "$INIT_CHECKPOINT" ] && INIT_FLAG="--init_checkpoint $INIT_CHECKPOINT" - -LATEST="$CHECKPOINT_DIR/e2e_stage2_ext_latest.pt" -RESUME_FLAG="" -[ -f "$LATEST" ] && RESUME_FLAG="--resume_checkpoint $LATEST" - -NO_AMP_FLAG="" -[ 
"${NO_AMP:-0}" = "1" ] && NO_AMP_FLAG="--no_amp" - -NO_DISP_FLAG="" -[ "${NO_DISPLACEMENT_LOSS:-0}" = "1" ] && NO_DISP_FLAG="--no_displacement_loss" -echo "${SMOKE_BANNER}[stage2_extended/1x8] nodes=$NODES total_ranks=$TOTAL_RANKS \ -batch=$BATCH_SIZE steps=$MAX_STEPS Ks=$CURRICULUM_KS" -echo "${SMOKE_BANNER}[stage2_extended/1x8] master=$MASTER_ADDR:$MASTER_PORT data=$DATA_DIR" - -srun -N "$NODES" -n "$TOTAL_RANKS" -c "$CPUS_PER_TASK" \ - --gpus-per-task=1 --gpu-bind=closest \ - scripts/slurm_frontier/_srun_rank_wrapper.sh \ - scripts/training/train_e2e_stage2_extended.py \ - $INIT_FLAG $RESUME_FLAG $MAX_FILES_FLAG $NO_AMP_FLAG $NO_DISP_FLAG \ ---data_dir "$DATA_DIR" \ ---stats_path "$STATS_PATH" \ ---checkpoint_dir "$CHECKPOINT_DIR" \ ---val_fraction 0.1 \ ---seed 42 \ ---chunk_duration_s 0.05 \ ---step_size_s 0.01 \ ---warmup_s 1.0 \ ---d_model "$D_MODEL" \ ---n_layers "$N_LAYERS" \ ---n_heads "$N_HEADS" \ ---dropout 0.1 \ ---curriculum_Ks "$CURRICULUM_KS" \ ---block_steps "$BLOCK_STEPS" \ ---mae_weight "$MAE_WEIGHT" \ ---cos_weight "$COS_WEIGHT" \ ---mag_weight "$MAG_WEIGHT" \ ---min_disp_norm "$MIN_DISP_NORM" \ ---grad_checkpoint_every "$GRAD_CHECKPOINT_EVERY" \ ---lr 1e-5 \ ---min_lr 1e-7 \ ---warmup_steps 500 \ ---weight_decay 0.01 \ ---grad_clip 5.0 \ ---batch_size "$BATCH_SIZE" \ ---num_workers "$NUM_WORKERS" \ ---max_steps "$MAX_STEPS" \ ---log_every "$LOG_EVERY" \ ---val_every "$VAL_EVERY" \ ---val_max_batches "$VAL_MAX_BATCHES" \ No newline at end of file diff --git a/scripts/slurm_frontier/train_e2e_stage2_extended_Nx1.sh b/scripts/slurm_frontier/train_e2e_stage2_extended_Nx1.sh deleted file mode 100644 index b0beee1..0000000 --- a/scripts/slurm_frontier/train_e2e_stage2_extended_Nx1.sh +++ /dev/null @@ -1,138 +0,0 @@ -#!/bin/bash -# Frontier DDP launcher: train_e2e Stage2 Extended — N nodes × 1 GCD (cross-node networking smoke; default N=2) -# -# Usage: -# sbatch scripts/slurm_frontier/train_e2e_stage2_extended_Nx1.sh -# -# Common env overrides: -# 
SMOKE=1 # short test: MAX_STEPS=20, MAX_FILES=4, freq logs -# MAX_STEPS= # total optimizer steps -# MAX_FILES= # cap on training shots (debug) -# BATCH_SIZE= # per-rank batch size (default 4) -# NUM_WORKERS= # DataLoader workers per rank (default 4) -# DATA_DIR= # override data root -# CHECKPOINT_DIR= # override checkpoint dir -# MASTER_PORT= # override port (default 29503) -# -# Override resource shape on the CLI (sbatch flags beat #SBATCH directives): -# sbatch -N 8 -t 12:00:00 scripts/slurm_frontier/train_e2e_stage2_extended_Nx1.sh -# -#SBATCH -A fus187 -#SBATCH -J e2e_s2e_Nx1 -#SBATCH -o logs/%j_e2e_s2e_Nx1.out -#SBATCH -e logs/%j_e2e_s2e_Nx1.err -#SBATCH -t 01:00:00 -#SBATCH -p batch -#SBATCH -N 2 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-task=1 -#SBATCH --gpu-bind=closest -#SBATCH --cpus-per-task=7 -set -uo pipefail - -PROJECT_DIR=/lustre/orion/fus187/scratch/nchen/FusionAIHub -cd "$PROJECT_DIR" -mkdir -p logs - -# Per-stage MASTER_PORT default (overridable). Must be set BEFORE sourcing -# _frontier_common.sh, since that script only fills in if unset. 
-export MASTER_PORT="${MASTER_PORT:-29503}" - -# shellcheck disable=SC1091 -source scripts/slurm_frontier/_frontier_common.sh - -# ─── Resource shape (taken from SLURM allocation, never hard-coded) ────── -NODES="${SLURM_JOB_NUM_NODES:-2}" -TOTAL_RANKS="${SLURM_NTASKS:-$((NODES * 1))}" -CPUS_PER_TASK="${SLURM_CPUS_PER_TASK:-7}" - -# ─── SMOKE=1 overrides for end-to-end smoke testing ────────────────────── -if [ "${SMOKE:-0}" = "1" ]; then - MAX_STEPS="${MAX_STEPS:-20}" - MAX_FILES="${MAX_FILES:-4}" - NUM_WORKERS="${NUM_WORKERS:-2}" - LOG_EVERY="${LOG_EVERY:-2}" - VAL_EVERY="${VAL_EVERY:-10}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-2}" - SMOKE_BANNER="[SMOKE] " -else - MAX_STEPS="${MAX_STEPS:-1000}" - NUM_WORKERS="${NUM_WORKERS:-4}" - LOG_EVERY="${LOG_EVERY:-50}" - VAL_EVERY="${VAL_EVERY:-200}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-20}" - SMOKE_BANNER="" -fi - -MAX_FILES_FLAG="" -[ -n "${MAX_FILES:-}" ] && MAX_FILES_FLAG="--max_files $MAX_FILES" - -# ─── Stage-specific defaults & init/resume flags ───────────────────────── -BATCH_SIZE="${BATCH_SIZE:-4}" -CURRICULUM_KS="${CURRICULUM_KS:-2,3,4}" -BLOCK_STEPS="${BLOCK_STEPS:-$((MAX_STEPS / 3))}" -GRAD_CHECKPOINT_EVERY="${GRAD_CHECKPOINT_EVERY:-2}" -D_MODEL="${D_MODEL:-256}" -N_LAYERS="${N_LAYERS:-8}" -N_HEADS="${N_HEADS:-8}" -MAE_WEIGHT="${MAE_WEIGHT:-1.0}" -COS_WEIGHT="${COS_WEIGHT:-0.3}" -MAG_WEIGHT="${MAG_WEIGHT:-0.1}" -MIN_DISP_NORM="${MIN_DISP_NORM:-0.01}" -DATA_DIR="${DATA_DIR:-/lustre/orion/fus187/proj-shared/foundation_model}" -STATS_PATH="${STATS_PATH:-data/preprocessing_stats.pt}" -CHECKPOINT_DIR="${CHECKPOINT_DIR:-runs/e2e_stage2_ext_frontier}" -INIT_CHECKPOINT="${INIT_CHECKPOINT:-runs/e2e_stage2_delta_frontier/e2e_stage2_delta_best.pt}" -mkdir -p "$CHECKPOINT_DIR" - -INIT_FLAG="" -[ -f "$INIT_CHECKPOINT" ] && INIT_FLAG="--init_checkpoint $INIT_CHECKPOINT" - -LATEST="$CHECKPOINT_DIR/e2e_stage2_ext_latest.pt" -RESUME_FLAG="" -[ -f "$LATEST" ] && RESUME_FLAG="--resume_checkpoint $LATEST" - -NO_AMP_FLAG="" -[ 
"${NO_AMP:-0}" = "1" ] && NO_AMP_FLAG="--no_amp" - -NO_DISP_FLAG="" -[ "${NO_DISPLACEMENT_LOSS:-0}" = "1" ] && NO_DISP_FLAG="--no_displacement_loss" -echo "${SMOKE_BANNER}[stage2_extended/Nx1] nodes=$NODES total_ranks=$TOTAL_RANKS \ -batch=$BATCH_SIZE steps=$MAX_STEPS Ks=$CURRICULUM_KS" -echo "${SMOKE_BANNER}[stage2_extended/Nx1] master=$MASTER_ADDR:$MASTER_PORT data=$DATA_DIR" - -srun -N "$NODES" -n "$TOTAL_RANKS" -c "$CPUS_PER_TASK" \ - --gpus-per-task=1 --gpu-bind=closest \ - scripts/slurm_frontier/_srun_rank_wrapper.sh \ - scripts/training/train_e2e_stage2_extended.py \ - $INIT_FLAG $RESUME_FLAG $MAX_FILES_FLAG $NO_AMP_FLAG $NO_DISP_FLAG \ ---data_dir "$DATA_DIR" \ ---stats_path "$STATS_PATH" \ ---checkpoint_dir "$CHECKPOINT_DIR" \ ---val_fraction 0.1 \ ---seed 42 \ ---chunk_duration_s 0.05 \ ---step_size_s 0.01 \ ---warmup_s 1.0 \ ---d_model "$D_MODEL" \ ---n_layers "$N_LAYERS" \ ---n_heads "$N_HEADS" \ ---dropout 0.1 \ ---curriculum_Ks "$CURRICULUM_KS" \ ---block_steps "$BLOCK_STEPS" \ ---mae_weight "$MAE_WEIGHT" \ ---cos_weight "$COS_WEIGHT" \ ---mag_weight "$MAG_WEIGHT" \ ---min_disp_norm "$MIN_DISP_NORM" \ ---grad_checkpoint_every "$GRAD_CHECKPOINT_EVERY" \ ---lr 1e-5 \ ---min_lr 1e-7 \ ---warmup_steps 500 \ ---weight_decay 0.01 \ ---grad_clip 5.0 \ ---batch_size "$BATCH_SIZE" \ ---num_workers "$NUM_WORKERS" \ ---max_steps "$MAX_STEPS" \ ---log_every "$LOG_EVERY" \ ---val_every "$VAL_EVERY" \ ---val_max_batches "$VAL_MAX_BATCHES" \ No newline at end of file diff --git a/scripts/slurm_frontier/train_e2e_stage2_extended_NxN.sh b/scripts/slurm_frontier/train_e2e_stage2_extended_NxN.sh deleted file mode 100644 index c124a0e..0000000 --- a/scripts/slurm_frontier/train_e2e_stage2_extended_NxN.sh +++ /dev/null @@ -1,138 +0,0 @@ -#!/bin/bash -# Frontier DDP launcher: train_e2e Stage2 Extended — N nodes × 8 GCDs (production multi-node; default N=4, override with `sbatch -N `) -# -# Usage: -# sbatch scripts/slurm_frontier/train_e2e_stage2_extended_NxN.sh -# -# 
Common env overrides: -# SMOKE=1 # short test: MAX_STEPS=20, MAX_FILES=4, freq logs -# MAX_STEPS= # total optimizer steps -# MAX_FILES= # cap on training shots (debug) -# BATCH_SIZE= # per-rank batch size (default 4) -# NUM_WORKERS= # DataLoader workers per rank (default 4) -# DATA_DIR= # override data root -# CHECKPOINT_DIR= # override checkpoint dir -# MASTER_PORT= # override port (default 29503) -# -# Override resource shape on the CLI (sbatch flags beat #SBATCH directives): -# sbatch -N 8 -t 12:00:00 scripts/slurm_frontier/train_e2e_stage2_extended_NxN.sh -# -#SBATCH -A fus187 -#SBATCH -J e2e_s2e_NxN -#SBATCH -o logs/%j_e2e_s2e_NxN.out -#SBATCH -e logs/%j_e2e_s2e_NxN.err -#SBATCH -t 02:00:00 -#SBATCH -p batch -#SBATCH -N 4 -#SBATCH --ntasks-per-node=8 -#SBATCH --gpus-per-task=1 -#SBATCH --gpu-bind=closest -#SBATCH --cpus-per-task=7 -set -uo pipefail - -PROJECT_DIR=/lustre/orion/fus187/scratch/nchen/FusionAIHub -cd "$PROJECT_DIR" -mkdir -p logs - -# Per-stage MASTER_PORT default (overridable). Must be set BEFORE sourcing -# _frontier_common.sh, since that script only fills in if unset. 
-export MASTER_PORT="${MASTER_PORT:-29503}" - -# shellcheck disable=SC1091 -source scripts/slurm_frontier/_frontier_common.sh - -# ─── Resource shape (taken from SLURM allocation, never hard-coded) ────── -NODES="${SLURM_JOB_NUM_NODES:-4}" -TOTAL_RANKS="${SLURM_NTASKS:-$((NODES * 8))}" -CPUS_PER_TASK="${SLURM_CPUS_PER_TASK:-7}" - -# ─── SMOKE=1 overrides for end-to-end smoke testing ────────────────────── -if [ "${SMOKE:-0}" = "1" ]; then - MAX_STEPS="${MAX_STEPS:-20}" - MAX_FILES="${MAX_FILES:-4}" - NUM_WORKERS="${NUM_WORKERS:-2}" - LOG_EVERY="${LOG_EVERY:-2}" - VAL_EVERY="${VAL_EVERY:-10}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-2}" - SMOKE_BANNER="[SMOKE] " -else - MAX_STEPS="${MAX_STEPS:-1000}" - NUM_WORKERS="${NUM_WORKERS:-4}" - LOG_EVERY="${LOG_EVERY:-50}" - VAL_EVERY="${VAL_EVERY:-200}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-20}" - SMOKE_BANNER="" -fi - -MAX_FILES_FLAG="" -[ -n "${MAX_FILES:-}" ] && MAX_FILES_FLAG="--max_files $MAX_FILES" - -# ─── Stage-specific defaults & init/resume flags ───────────────────────── -BATCH_SIZE="${BATCH_SIZE:-4}" -CURRICULUM_KS="${CURRICULUM_KS:-2,3,4}" -BLOCK_STEPS="${BLOCK_STEPS:-$((MAX_STEPS / 3))}" -GRAD_CHECKPOINT_EVERY="${GRAD_CHECKPOINT_EVERY:-2}" -D_MODEL="${D_MODEL:-256}" -N_LAYERS="${N_LAYERS:-8}" -N_HEADS="${N_HEADS:-8}" -MAE_WEIGHT="${MAE_WEIGHT:-1.0}" -COS_WEIGHT="${COS_WEIGHT:-0.3}" -MAG_WEIGHT="${MAG_WEIGHT:-0.1}" -MIN_DISP_NORM="${MIN_DISP_NORM:-0.01}" -DATA_DIR="${DATA_DIR:-/lustre/orion/fus187/proj-shared/foundation_model}" -STATS_PATH="${STATS_PATH:-data/preprocessing_stats.pt}" -CHECKPOINT_DIR="${CHECKPOINT_DIR:-runs/e2e_stage2_ext_frontier}" -INIT_CHECKPOINT="${INIT_CHECKPOINT:-runs/e2e_stage2_delta_frontier/e2e_stage2_delta_best.pt}" -mkdir -p "$CHECKPOINT_DIR" - -INIT_FLAG="" -[ -f "$INIT_CHECKPOINT" ] && INIT_FLAG="--init_checkpoint $INIT_CHECKPOINT" - -LATEST="$CHECKPOINT_DIR/e2e_stage2_ext_latest.pt" -RESUME_FLAG="" -[ -f "$LATEST" ] && RESUME_FLAG="--resume_checkpoint $LATEST" - -NO_AMP_FLAG="" -[ 
"${NO_AMP:-0}" = "1" ] && NO_AMP_FLAG="--no_amp" - -NO_DISP_FLAG="" -[ "${NO_DISPLACEMENT_LOSS:-0}" = "1" ] && NO_DISP_FLAG="--no_displacement_loss" -echo "${SMOKE_BANNER}[stage2_extended/NxN] nodes=$NODES total_ranks=$TOTAL_RANKS \ -batch=$BATCH_SIZE steps=$MAX_STEPS Ks=$CURRICULUM_KS" -echo "${SMOKE_BANNER}[stage2_extended/NxN] master=$MASTER_ADDR:$MASTER_PORT data=$DATA_DIR" - -srun -N "$NODES" -n "$TOTAL_RANKS" -c "$CPUS_PER_TASK" \ - --gpus-per-task=1 --gpu-bind=closest \ - scripts/slurm_frontier/_srun_rank_wrapper.sh \ - scripts/training/train_e2e_stage2_extended.py \ - $INIT_FLAG $RESUME_FLAG $MAX_FILES_FLAG $NO_AMP_FLAG $NO_DISP_FLAG \ ---data_dir "$DATA_DIR" \ ---stats_path "$STATS_PATH" \ ---checkpoint_dir "$CHECKPOINT_DIR" \ ---val_fraction 0.1 \ ---seed 42 \ ---chunk_duration_s 0.05 \ ---step_size_s 0.01 \ ---warmup_s 1.0 \ ---d_model "$D_MODEL" \ ---n_layers "$N_LAYERS" \ ---n_heads "$N_HEADS" \ ---dropout 0.1 \ ---curriculum_Ks "$CURRICULUM_KS" \ ---block_steps "$BLOCK_STEPS" \ ---mae_weight "$MAE_WEIGHT" \ ---cos_weight "$COS_WEIGHT" \ ---mag_weight "$MAG_WEIGHT" \ ---min_disp_norm "$MIN_DISP_NORM" \ ---grad_checkpoint_every "$GRAD_CHECKPOINT_EVERY" \ ---lr 1e-5 \ ---min_lr 1e-7 \ ---warmup_steps 500 \ ---weight_decay 0.01 \ ---grad_clip 5.0 \ ---batch_size "$BATCH_SIZE" \ ---num_workers "$NUM_WORKERS" \ ---max_steps "$MAX_STEPS" \ ---log_every "$LOG_EVERY" \ ---val_every "$VAL_EVERY" \ ---val_max_batches "$VAL_MAX_BATCHES" \ No newline at end of file diff --git a/scripts/slurm_frontier/train_e2e_stage3.sh b/scripts/slurm_frontier/train_e2e_stage3.sh index a503125..ac5249a 100644 --- a/scripts/slurm_frontier/train_e2e_stage3.sh +++ b/scripts/slurm_frontier/train_e2e_stage3.sh @@ -12,11 +12,17 @@ #SBATCH --cpus-per-task=7 set -e -cd /lustre/orion/fus187/scratch/nchen/FusionAIHub +PROJECT_DIR="${SLURM_SUBMIT_DIR:-$PWD}" +if [ ! 
-f "${PROJECT_DIR}/scripts/slurm_frontier/_frontier_settings.sh" ]; then + echo "ERROR: SLURM_SUBMIT_DIR (${PROJECT_DIR}) is not the repo root." >&2 + echo " cd into the FusionAIHub repo before sbatch." >&2 + exit 1 +fi +cd "${PROJECT_DIR}" mkdir -p logs runs/e2e_stage3 export MASTER_PORT=29504 -source scripts/slurm_frontier/_frontier_common.sh +source scripts/slurm_frontier/_frontier_settings.sh srun -N $SLURM_JOB_NUM_NODES -n $SLURM_NTASKS -c $SLURM_CPUS_PER_TASK \ --gpus-per-task=1 --gpu-bind=closest \ diff --git a/scripts/slurm_frontier/train_e2e_stage3_1x1.sh b/scripts/slurm_frontier/train_e2e_stage3_1x1.sh deleted file mode 100644 index 325cf8c..0000000 --- a/scripts/slurm_frontier/train_e2e_stage3_1x1.sh +++ /dev/null @@ -1,148 +0,0 @@ -#!/bin/bash -# Frontier DDP launcher: train_e2e Stage3 — 1 node × 1 GCD (single-GPU smoke / dev) -# -# Usage: -# sbatch scripts/slurm_frontier/train_e2e_stage3_1x1.sh -# -# Common env overrides: -# SMOKE=1 # short test: MAX_STEPS=20, MAX_FILES=4, freq logs -# MAX_STEPS= # total optimizer steps -# MAX_FILES= # cap on training shots (debug) -# BATCH_SIZE= # per-rank batch size (default 16) -# NUM_WORKERS= # DataLoader workers per rank (default 4) -# DATA_DIR= # override data root -# CHECKPOINT_DIR= # override checkpoint dir -# MASTER_PORT= # override port (default 29504) -# -# Override resource shape on the CLI (sbatch flags beat #SBATCH directives): -# sbatch -N 8 -t 12:00:00 scripts/slurm_frontier/train_e2e_stage3_1x1.sh -# -#SBATCH -A fus187 -#SBATCH -J e2e_s3_1x1 -#SBATCH -o logs/%j_e2e_s3_1x1.out -#SBATCH -e logs/%j_e2e_s3_1x1.err -#SBATCH -t 02:00:00 -#SBATCH -p batch -#SBATCH -N 1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-task=1 -#SBATCH --gpu-bind=closest -#SBATCH --cpus-per-task=7 -set -uo pipefail - -PROJECT_DIR=/lustre/orion/fus187/scratch/nchen/FusionAIHub -cd "$PROJECT_DIR" -mkdir -p logs - -# Per-stage MASTER_PORT default (overridable). 
Must be set BEFORE sourcing -# _frontier_common.sh, since that script only fills in if unset. -export MASTER_PORT="${MASTER_PORT:-29504}" - -# shellcheck disable=SC1091 -source scripts/slurm_frontier/_frontier_common.sh - -# ─── Resource shape (taken from SLURM allocation, never hard-coded) ────── -NODES="${SLURM_JOB_NUM_NODES:-1}" -TOTAL_RANKS="${SLURM_NTASKS:-$((NODES * 1))}" -CPUS_PER_TASK="${SLURM_CPUS_PER_TASK:-7}" - -# ─── SMOKE=1 overrides for end-to-end smoke testing ────────────────────── -if [ "${SMOKE:-0}" = "1" ]; then - MAX_STEPS="${MAX_STEPS:-20}" - MAX_FILES="${MAX_FILES:-4}" - NUM_WORKERS="${NUM_WORKERS:-2}" - LOG_EVERY="${LOG_EVERY:-2}" - VAL_EVERY="${VAL_EVERY:-10}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-2}" - SMOKE_BANNER="[SMOKE] " -else - MAX_STEPS="${MAX_STEPS:-1000}" - NUM_WORKERS="${NUM_WORKERS:-4}" - LOG_EVERY="${LOG_EVERY:-50}" - VAL_EVERY="${VAL_EVERY:-200}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-20}" - SMOKE_BANNER="" -fi - -MAX_FILES_FLAG="" -[ -n "${MAX_FILES:-}" ] && MAX_FILES_FLAG="--max_files $MAX_FILES" - -# ─── Stage-specific defaults & init/resume flags ───────────────────────── -BATCH_SIZE="${BATCH_SIZE:-16}" -VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-8}" -K_MIN="${K_MIN:-2}" -K_MAX="${K_MAX:-4}" -N_CURRICULUM_BLOCKS="${N_CURRICULUM_BLOCKS:-2}" -CURRICULUM_STEPS="${CURRICULUM_STEPS:-$((MAX_STEPS / 2))}" -LORA_RANK="${LORA_RANK:-16}" -LORA_ALPHA="${LORA_ALPHA:-16.0}" -POOL_SIZE="${POOL_SIZE:-50}" -BUFFER_SIZE="${BUFFER_SIZE:-500}" -BUFFER_REFRESH_PERIOD="${BUFFER_REFRESH_PERIOD:-50}" -BUFFER_REFRESH_FRACTION="${BUFFER_REFRESH_FRACTION:-0.1}" -D_MODEL="${D_MODEL:-256}" -N_LAYERS="${N_LAYERS:-8}" -N_HEADS="${N_HEADS:-8}" -DATA_DIR="${DATA_DIR:-/lustre/orion/fus187/proj-shared/foundation_model}" -STATS_PATH="${STATS_PATH:-data/preprocessing_stats.pt}" -CHECKPOINT_DIR="${CHECKPOINT_DIR:-runs/e2e_stage3_frontier}" -INIT_CHECKPOINT="${INIT_CHECKPOINT:-runs/e2e_stage2_delta_frontier/e2e_stage2_delta_best.pt}" -mkdir -p "$CHECKPOINT_DIR" - 
-INIT_FLAG="" -[ -f "$INIT_CHECKPOINT" ] && INIT_FLAG="--init_checkpoint $INIT_CHECKPOINT" - -LATEST="$CHECKPOINT_DIR/e2e_stage3_latest.pt" -RESUME_FLAG="" -[ -f "$LATEST" ] && RESUME_FLAG="--resume_checkpoint $LATEST" - -NO_AMP_FLAG="" -[ "${NO_AMP:-0}" = "1" ] && NO_AMP_FLAG="--no_amp" - -USE_DISP_FLAG="--use_displacement_loss" -[ "${NO_DISPLACEMENT_LOSS:-0}" = "1" ] && USE_DISP_FLAG="" -echo "${SMOKE_BANNER}[stage3/1x1] nodes=$NODES total_ranks=$TOTAL_RANKS \ -batch=$BATCH_SIZE steps=$MAX_STEPS K=[$K_MIN,$K_MAX]" -echo "${SMOKE_BANNER}[stage3/1x1] master=$MASTER_ADDR:$MASTER_PORT data=$DATA_DIR" - -srun -N "$NODES" -n "$TOTAL_RANKS" -c "$CPUS_PER_TASK" \ - --gpus-per-task=1 --gpu-bind=closest \ - scripts/slurm_frontier/_srun_rank_wrapper.sh \ - scripts/training/train_e2e_stage3.py \ - $INIT_FLAG $RESUME_FLAG $MAX_FILES_FLAG $NO_AMP_FLAG $USE_DISP_FLAG \ ---data_dir "$DATA_DIR" \ ---stats_path "$STATS_PATH" \ ---checkpoint_dir "$CHECKPOINT_DIR" \ ---val_fraction 0.1 \ ---seed 42 \ ---chunk_duration_s 0.05 \ ---step_size_s 0.01 \ ---warmup_s 1.0 \ ---d_model "$D_MODEL" \ ---n_layers "$N_LAYERS" \ ---n_heads "$N_HEADS" \ ---dropout 0.1 \ ---lora_rank "$LORA_RANK" \ ---lora_alpha "$LORA_ALPHA" \ ---K_min "$K_MIN" \ ---K_max "$K_MAX" \ ---n_curriculum_blocks "$N_CURRICULUM_BLOCKS" \ ---curriculum_steps "$CURRICULUM_STEPS" \ ---pool_size "$POOL_SIZE" \ ---buffer_size "$BUFFER_SIZE" \ ---buffer_refresh_period "$BUFFER_REFRESH_PERIOD" \ ---buffer_refresh_fraction "$BUFFER_REFRESH_FRACTION" \ ---lr 3e-5 \ ---min_lr 1e-7 \ ---warmup_steps 200 \ ---weight_decay 0.01 \ ---grad_clip 5.0 \ ---cos_weight 0.3 \ ---mag_weight 0.1 \ ---min_disp_norm 0.01 \ ---batch_size "$BATCH_SIZE" \ ---num_workers "$NUM_WORKERS" \ ---max_steps "$MAX_STEPS" \ ---log_every "$LOG_EVERY" \ ---val_every "$VAL_EVERY" \ ---val_batch_size "$VAL_BATCH_SIZE" \ No newline at end of file diff --git a/scripts/slurm_frontier/train_e2e_stage3_1x8.sh b/scripts/slurm_frontier/train_e2e_stage3_1x8.sh deleted 
file mode 100644 index ee344bf..0000000 --- a/scripts/slurm_frontier/train_e2e_stage3_1x8.sh +++ /dev/null @@ -1,148 +0,0 @@ -#!/bin/bash -# Frontier DDP launcher: train_e2e Stage3 — 1 node × 8 GCDs (production single-node DDP) -# -# Usage: -# sbatch scripts/slurm_frontier/train_e2e_stage3_1x8.sh -# -# Common env overrides: -# SMOKE=1 # short test: MAX_STEPS=20, MAX_FILES=4, freq logs -# MAX_STEPS= # total optimizer steps -# MAX_FILES= # cap on training shots (debug) -# BATCH_SIZE= # per-rank batch size (default 16) -# NUM_WORKERS= # DataLoader workers per rank (default 4) -# DATA_DIR= # override data root -# CHECKPOINT_DIR= # override checkpoint dir -# MASTER_PORT= # override port (default 29504) -# -# Override resource shape on the CLI (sbatch flags beat #SBATCH directives): -# sbatch -N 8 -t 12:00:00 scripts/slurm_frontier/train_e2e_stage3_1x8.sh -# -#SBATCH -A fus187 -#SBATCH -J e2e_s3_1x8 -#SBATCH -o logs/%j_e2e_s3_1x8.out -#SBATCH -e logs/%j_e2e_s3_1x8.err -#SBATCH -t 02:00:00 -#SBATCH -p batch -#SBATCH -N 1 -#SBATCH --ntasks-per-node=8 -#SBATCH --gpus-per-task=1 -#SBATCH --gpu-bind=closest -#SBATCH --cpus-per-task=7 -set -uo pipefail - -PROJECT_DIR=/lustre/orion/fus187/scratch/nchen/FusionAIHub -cd "$PROJECT_DIR" -mkdir -p logs - -# Per-stage MASTER_PORT default (overridable). Must be set BEFORE sourcing -# _frontier_common.sh, since that script only fills in if unset. 
-export MASTER_PORT="${MASTER_PORT:-29504}" - -# shellcheck disable=SC1091 -source scripts/slurm_frontier/_frontier_common.sh - -# ─── Resource shape (taken from SLURM allocation, never hard-coded) ────── -NODES="${SLURM_JOB_NUM_NODES:-1}" -TOTAL_RANKS="${SLURM_NTASKS:-$((NODES * 8))}" -CPUS_PER_TASK="${SLURM_CPUS_PER_TASK:-7}" - -# ─── SMOKE=1 overrides for end-to-end smoke testing ────────────────────── -if [ "${SMOKE:-0}" = "1" ]; then - MAX_STEPS="${MAX_STEPS:-20}" - MAX_FILES="${MAX_FILES:-4}" - NUM_WORKERS="${NUM_WORKERS:-2}" - LOG_EVERY="${LOG_EVERY:-2}" - VAL_EVERY="${VAL_EVERY:-10}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-2}" - SMOKE_BANNER="[SMOKE] " -else - MAX_STEPS="${MAX_STEPS:-1000}" - NUM_WORKERS="${NUM_WORKERS:-4}" - LOG_EVERY="${LOG_EVERY:-50}" - VAL_EVERY="${VAL_EVERY:-200}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-20}" - SMOKE_BANNER="" -fi - -MAX_FILES_FLAG="" -[ -n "${MAX_FILES:-}" ] && MAX_FILES_FLAG="--max_files $MAX_FILES" - -# ─── Stage-specific defaults & init/resume flags ───────────────────────── -BATCH_SIZE="${BATCH_SIZE:-16}" -VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-8}" -K_MIN="${K_MIN:-2}" -K_MAX="${K_MAX:-4}" -N_CURRICULUM_BLOCKS="${N_CURRICULUM_BLOCKS:-2}" -CURRICULUM_STEPS="${CURRICULUM_STEPS:-$((MAX_STEPS / 2))}" -LORA_RANK="${LORA_RANK:-16}" -LORA_ALPHA="${LORA_ALPHA:-16.0}" -POOL_SIZE="${POOL_SIZE:-50}" -BUFFER_SIZE="${BUFFER_SIZE:-500}" -BUFFER_REFRESH_PERIOD="${BUFFER_REFRESH_PERIOD:-50}" -BUFFER_REFRESH_FRACTION="${BUFFER_REFRESH_FRACTION:-0.1}" -D_MODEL="${D_MODEL:-256}" -N_LAYERS="${N_LAYERS:-8}" -N_HEADS="${N_HEADS:-8}" -DATA_DIR="${DATA_DIR:-/lustre/orion/fus187/proj-shared/foundation_model}" -STATS_PATH="${STATS_PATH:-data/preprocessing_stats.pt}" -CHECKPOINT_DIR="${CHECKPOINT_DIR:-runs/e2e_stage3_frontier}" -INIT_CHECKPOINT="${INIT_CHECKPOINT:-runs/e2e_stage2_delta_frontier/e2e_stage2_delta_best.pt}" -mkdir -p "$CHECKPOINT_DIR" - -INIT_FLAG="" -[ -f "$INIT_CHECKPOINT" ] && INIT_FLAG="--init_checkpoint $INIT_CHECKPOINT" - 
-LATEST="$CHECKPOINT_DIR/e2e_stage3_latest.pt" -RESUME_FLAG="" -[ -f "$LATEST" ] && RESUME_FLAG="--resume_checkpoint $LATEST" - -NO_AMP_FLAG="" -[ "${NO_AMP:-0}" = "1" ] && NO_AMP_FLAG="--no_amp" - -USE_DISP_FLAG="--use_displacement_loss" -[ "${NO_DISPLACEMENT_LOSS:-0}" = "1" ] && USE_DISP_FLAG="" -echo "${SMOKE_BANNER}[stage3/1x8] nodes=$NODES total_ranks=$TOTAL_RANKS \ -batch=$BATCH_SIZE steps=$MAX_STEPS K=[$K_MIN,$K_MAX]" -echo "${SMOKE_BANNER}[stage3/1x8] master=$MASTER_ADDR:$MASTER_PORT data=$DATA_DIR" - -srun -N "$NODES" -n "$TOTAL_RANKS" -c "$CPUS_PER_TASK" \ - --gpus-per-task=1 --gpu-bind=closest \ - scripts/slurm_frontier/_srun_rank_wrapper.sh \ - scripts/training/train_e2e_stage3.py \ - $INIT_FLAG $RESUME_FLAG $MAX_FILES_FLAG $NO_AMP_FLAG $USE_DISP_FLAG \ ---data_dir "$DATA_DIR" \ ---stats_path "$STATS_PATH" \ ---checkpoint_dir "$CHECKPOINT_DIR" \ ---val_fraction 0.1 \ ---seed 42 \ ---chunk_duration_s 0.05 \ ---step_size_s 0.01 \ ---warmup_s 1.0 \ ---d_model "$D_MODEL" \ ---n_layers "$N_LAYERS" \ ---n_heads "$N_HEADS" \ ---dropout 0.1 \ ---lora_rank "$LORA_RANK" \ ---lora_alpha "$LORA_ALPHA" \ ---K_min "$K_MIN" \ ---K_max "$K_MAX" \ ---n_curriculum_blocks "$N_CURRICULUM_BLOCKS" \ ---curriculum_steps "$CURRICULUM_STEPS" \ ---pool_size "$POOL_SIZE" \ ---buffer_size "$BUFFER_SIZE" \ ---buffer_refresh_period "$BUFFER_REFRESH_PERIOD" \ ---buffer_refresh_fraction "$BUFFER_REFRESH_FRACTION" \ ---lr 3e-5 \ ---min_lr 1e-7 \ ---warmup_steps 200 \ ---weight_decay 0.01 \ ---grad_clip 5.0 \ ---cos_weight 0.3 \ ---mag_weight 0.1 \ ---min_disp_norm 0.01 \ ---batch_size "$BATCH_SIZE" \ ---num_workers "$NUM_WORKERS" \ ---max_steps "$MAX_STEPS" \ ---log_every "$LOG_EVERY" \ ---val_every "$VAL_EVERY" \ ---val_batch_size "$VAL_BATCH_SIZE" \ No newline at end of file diff --git a/scripts/slurm_frontier/train_e2e_stage3_Nx1.sh b/scripts/slurm_frontier/train_e2e_stage3_Nx1.sh deleted file mode 100644 index a6717cd..0000000 --- a/scripts/slurm_frontier/train_e2e_stage3_Nx1.sh 
+++ /dev/null @@ -1,148 +0,0 @@ -#!/bin/bash -# Frontier DDP launcher: train_e2e Stage3 — N nodes × 1 GCD (cross-node networking smoke; default N=2) -# -# Usage: -# sbatch scripts/slurm_frontier/train_e2e_stage3_Nx1.sh -# -# Common env overrides: -# SMOKE=1 # short test: MAX_STEPS=20, MAX_FILES=4, freq logs -# MAX_STEPS= # total optimizer steps -# MAX_FILES= # cap on training shots (debug) -# BATCH_SIZE= # per-rank batch size (default 16) -# NUM_WORKERS= # DataLoader workers per rank (default 4) -# DATA_DIR= # override data root -# CHECKPOINT_DIR= # override checkpoint dir -# MASTER_PORT= # override port (default 29504) -# -# Override resource shape on the CLI (sbatch flags beat #SBATCH directives): -# sbatch -N 8 -t 12:00:00 scripts/slurm_frontier/train_e2e_stage3_Nx1.sh -# -#SBATCH -A fus187 -#SBATCH -J e2e_s3_Nx1 -#SBATCH -o logs/%j_e2e_s3_Nx1.out -#SBATCH -e logs/%j_e2e_s3_Nx1.err -#SBATCH -t 01:00:00 -#SBATCH -p batch -#SBATCH -N 2 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-task=1 -#SBATCH --gpu-bind=closest -#SBATCH --cpus-per-task=7 -set -uo pipefail - -PROJECT_DIR=/lustre/orion/fus187/scratch/nchen/FusionAIHub -cd "$PROJECT_DIR" -mkdir -p logs - -# Per-stage MASTER_PORT default (overridable). Must be set BEFORE sourcing -# _frontier_common.sh, since that script only fills in if unset. 
-export MASTER_PORT="${MASTER_PORT:-29504}" - -# shellcheck disable=SC1091 -source scripts/slurm_frontier/_frontier_common.sh - -# ─── Resource shape (taken from SLURM allocation, never hard-coded) ────── -NODES="${SLURM_JOB_NUM_NODES:-2}" -TOTAL_RANKS="${SLURM_NTASKS:-$((NODES * 1))}" -CPUS_PER_TASK="${SLURM_CPUS_PER_TASK:-7}" - -# ─── SMOKE=1 overrides for end-to-end smoke testing ────────────────────── -if [ "${SMOKE:-0}" = "1" ]; then - MAX_STEPS="${MAX_STEPS:-20}" - MAX_FILES="${MAX_FILES:-4}" - NUM_WORKERS="${NUM_WORKERS:-2}" - LOG_EVERY="${LOG_EVERY:-2}" - VAL_EVERY="${VAL_EVERY:-10}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-2}" - SMOKE_BANNER="[SMOKE] " -else - MAX_STEPS="${MAX_STEPS:-1000}" - NUM_WORKERS="${NUM_WORKERS:-4}" - LOG_EVERY="${LOG_EVERY:-50}" - VAL_EVERY="${VAL_EVERY:-200}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-20}" - SMOKE_BANNER="" -fi - -MAX_FILES_FLAG="" -[ -n "${MAX_FILES:-}" ] && MAX_FILES_FLAG="--max_files $MAX_FILES" - -# ─── Stage-specific defaults & init/resume flags ───────────────────────── -BATCH_SIZE="${BATCH_SIZE:-16}" -VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-8}" -K_MIN="${K_MIN:-2}" -K_MAX="${K_MAX:-4}" -N_CURRICULUM_BLOCKS="${N_CURRICULUM_BLOCKS:-2}" -CURRICULUM_STEPS="${CURRICULUM_STEPS:-$((MAX_STEPS / 2))}" -LORA_RANK="${LORA_RANK:-16}" -LORA_ALPHA="${LORA_ALPHA:-16.0}" -POOL_SIZE="${POOL_SIZE:-50}" -BUFFER_SIZE="${BUFFER_SIZE:-500}" -BUFFER_REFRESH_PERIOD="${BUFFER_REFRESH_PERIOD:-50}" -BUFFER_REFRESH_FRACTION="${BUFFER_REFRESH_FRACTION:-0.1}" -D_MODEL="${D_MODEL:-256}" -N_LAYERS="${N_LAYERS:-8}" -N_HEADS="${N_HEADS:-8}" -DATA_DIR="${DATA_DIR:-/lustre/orion/fus187/proj-shared/foundation_model}" -STATS_PATH="${STATS_PATH:-data/preprocessing_stats.pt}" -CHECKPOINT_DIR="${CHECKPOINT_DIR:-runs/e2e_stage3_frontier}" -INIT_CHECKPOINT="${INIT_CHECKPOINT:-runs/e2e_stage2_delta_frontier/e2e_stage2_delta_best.pt}" -mkdir -p "$CHECKPOINT_DIR" - -INIT_FLAG="" -[ -f "$INIT_CHECKPOINT" ] && INIT_FLAG="--init_checkpoint $INIT_CHECKPOINT" - 
-LATEST="$CHECKPOINT_DIR/e2e_stage3_latest.pt" -RESUME_FLAG="" -[ -f "$LATEST" ] && RESUME_FLAG="--resume_checkpoint $LATEST" - -NO_AMP_FLAG="" -[ "${NO_AMP:-0}" = "1" ] && NO_AMP_FLAG="--no_amp" - -USE_DISP_FLAG="--use_displacement_loss" -[ "${NO_DISPLACEMENT_LOSS:-0}" = "1" ] && USE_DISP_FLAG="" -echo "${SMOKE_BANNER}[stage3/Nx1] nodes=$NODES total_ranks=$TOTAL_RANKS \ -batch=$BATCH_SIZE steps=$MAX_STEPS K=[$K_MIN,$K_MAX]" -echo "${SMOKE_BANNER}[stage3/Nx1] master=$MASTER_ADDR:$MASTER_PORT data=$DATA_DIR" - -srun -N "$NODES" -n "$TOTAL_RANKS" -c "$CPUS_PER_TASK" \ - --gpus-per-task=1 --gpu-bind=closest \ - scripts/slurm_frontier/_srun_rank_wrapper.sh \ - scripts/training/train_e2e_stage3.py \ - $INIT_FLAG $RESUME_FLAG $MAX_FILES_FLAG $NO_AMP_FLAG $USE_DISP_FLAG \ ---data_dir "$DATA_DIR" \ ---stats_path "$STATS_PATH" \ ---checkpoint_dir "$CHECKPOINT_DIR" \ ---val_fraction 0.1 \ ---seed 42 \ ---chunk_duration_s 0.05 \ ---step_size_s 0.01 \ ---warmup_s 1.0 \ ---d_model "$D_MODEL" \ ---n_layers "$N_LAYERS" \ ---n_heads "$N_HEADS" \ ---dropout 0.1 \ ---lora_rank "$LORA_RANK" \ ---lora_alpha "$LORA_ALPHA" \ ---K_min "$K_MIN" \ ---K_max "$K_MAX" \ ---n_curriculum_blocks "$N_CURRICULUM_BLOCKS" \ ---curriculum_steps "$CURRICULUM_STEPS" \ ---pool_size "$POOL_SIZE" \ ---buffer_size "$BUFFER_SIZE" \ ---buffer_refresh_period "$BUFFER_REFRESH_PERIOD" \ ---buffer_refresh_fraction "$BUFFER_REFRESH_FRACTION" \ ---lr 3e-5 \ ---min_lr 1e-7 \ ---warmup_steps 200 \ ---weight_decay 0.01 \ ---grad_clip 5.0 \ ---cos_weight 0.3 \ ---mag_weight 0.1 \ ---min_disp_norm 0.01 \ ---batch_size "$BATCH_SIZE" \ ---num_workers "$NUM_WORKERS" \ ---max_steps "$MAX_STEPS" \ ---log_every "$LOG_EVERY" \ ---val_every "$VAL_EVERY" \ ---val_batch_size "$VAL_BATCH_SIZE" \ No newline at end of file diff --git a/scripts/slurm_frontier/train_e2e_stage3_NxN.sh b/scripts/slurm_frontier/train_e2e_stage3_NxN.sh deleted file mode 100644 index fa79119..0000000 --- a/scripts/slurm_frontier/train_e2e_stage3_NxN.sh 
+++ /dev/null @@ -1,148 +0,0 @@ -#!/bin/bash -# Frontier DDP launcher: train_e2e Stage3 — N nodes × 8 GCDs (production multi-node; default N=4, override with `sbatch -N `) -# -# Usage: -# sbatch scripts/slurm_frontier/train_e2e_stage3_NxN.sh -# -# Common env overrides: -# SMOKE=1 # short test: MAX_STEPS=20, MAX_FILES=4, freq logs -# MAX_STEPS= # total optimizer steps -# MAX_FILES= # cap on training shots (debug) -# BATCH_SIZE= # per-rank batch size (default 16) -# NUM_WORKERS= # DataLoader workers per rank (default 4) -# DATA_DIR= # override data root -# CHECKPOINT_DIR= # override checkpoint dir -# MASTER_PORT= # override port (default 29504) -# -# Override resource shape on the CLI (sbatch flags beat #SBATCH directives): -# sbatch -N 8 -t 12:00:00 scripts/slurm_frontier/train_e2e_stage3_NxN.sh -# -#SBATCH -A fus187 -#SBATCH -J e2e_s3_NxN -#SBATCH -o logs/%j_e2e_s3_NxN.out -#SBATCH -e logs/%j_e2e_s3_NxN.err -#SBATCH -t 02:00:00 -#SBATCH -p batch -#SBATCH -N 4 -#SBATCH --ntasks-per-node=8 -#SBATCH --gpus-per-task=1 -#SBATCH --gpu-bind=closest -#SBATCH --cpus-per-task=7 -set -uo pipefail - -PROJECT_DIR=/lustre/orion/fus187/scratch/nchen/FusionAIHub -cd "$PROJECT_DIR" -mkdir -p logs - -# Per-stage MASTER_PORT default (overridable). Must be set BEFORE sourcing -# _frontier_common.sh, since that script only fills in if unset. 
-export MASTER_PORT="${MASTER_PORT:-29504}" - -# shellcheck disable=SC1091 -source scripts/slurm_frontier/_frontier_common.sh - -# ─── Resource shape (taken from SLURM allocation, never hard-coded) ────── -NODES="${SLURM_JOB_NUM_NODES:-4}" -TOTAL_RANKS="${SLURM_NTASKS:-$((NODES * 8))}" -CPUS_PER_TASK="${SLURM_CPUS_PER_TASK:-7}" - -# ─── SMOKE=1 overrides for end-to-end smoke testing ────────────────────── -if [ "${SMOKE:-0}" = "1" ]; then - MAX_STEPS="${MAX_STEPS:-20}" - MAX_FILES="${MAX_FILES:-4}" - NUM_WORKERS="${NUM_WORKERS:-2}" - LOG_EVERY="${LOG_EVERY:-2}" - VAL_EVERY="${VAL_EVERY:-10}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-2}" - SMOKE_BANNER="[SMOKE] " -else - MAX_STEPS="${MAX_STEPS:-1000}" - NUM_WORKERS="${NUM_WORKERS:-4}" - LOG_EVERY="${LOG_EVERY:-50}" - VAL_EVERY="${VAL_EVERY:-200}" - VAL_MAX_BATCHES="${VAL_MAX_BATCHES:-20}" - SMOKE_BANNER="" -fi - -MAX_FILES_FLAG="" -[ -n "${MAX_FILES:-}" ] && MAX_FILES_FLAG="--max_files $MAX_FILES" - -# ─── Stage-specific defaults & init/resume flags ───────────────────────── -BATCH_SIZE="${BATCH_SIZE:-16}" -VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-8}" -K_MIN="${K_MIN:-2}" -K_MAX="${K_MAX:-4}" -N_CURRICULUM_BLOCKS="${N_CURRICULUM_BLOCKS:-2}" -CURRICULUM_STEPS="${CURRICULUM_STEPS:-$((MAX_STEPS / 2))}" -LORA_RANK="${LORA_RANK:-16}" -LORA_ALPHA="${LORA_ALPHA:-16.0}" -POOL_SIZE="${POOL_SIZE:-50}" -BUFFER_SIZE="${BUFFER_SIZE:-500}" -BUFFER_REFRESH_PERIOD="${BUFFER_REFRESH_PERIOD:-50}" -BUFFER_REFRESH_FRACTION="${BUFFER_REFRESH_FRACTION:-0.1}" -D_MODEL="${D_MODEL:-256}" -N_LAYERS="${N_LAYERS:-8}" -N_HEADS="${N_HEADS:-8}" -DATA_DIR="${DATA_DIR:-/lustre/orion/fus187/proj-shared/foundation_model}" -STATS_PATH="${STATS_PATH:-data/preprocessing_stats.pt}" -CHECKPOINT_DIR="${CHECKPOINT_DIR:-runs/e2e_stage3_frontier}" -INIT_CHECKPOINT="${INIT_CHECKPOINT:-runs/e2e_stage2_delta_frontier/e2e_stage2_delta_best.pt}" -mkdir -p "$CHECKPOINT_DIR" - -INIT_FLAG="" -[ -f "$INIT_CHECKPOINT" ] && INIT_FLAG="--init_checkpoint $INIT_CHECKPOINT" - 
-LATEST="$CHECKPOINT_DIR/e2e_stage3_latest.pt" -RESUME_FLAG="" -[ -f "$LATEST" ] && RESUME_FLAG="--resume_checkpoint $LATEST" - -NO_AMP_FLAG="" -[ "${NO_AMP:-0}" = "1" ] && NO_AMP_FLAG="--no_amp" - -USE_DISP_FLAG="--use_displacement_loss" -[ "${NO_DISPLACEMENT_LOSS:-0}" = "1" ] && USE_DISP_FLAG="" -echo "${SMOKE_BANNER}[stage3/NxN] nodes=$NODES total_ranks=$TOTAL_RANKS \ -batch=$BATCH_SIZE steps=$MAX_STEPS K=[$K_MIN,$K_MAX]" -echo "${SMOKE_BANNER}[stage3/NxN] master=$MASTER_ADDR:$MASTER_PORT data=$DATA_DIR" - -srun -N "$NODES" -n "$TOTAL_RANKS" -c "$CPUS_PER_TASK" \ - --gpus-per-task=1 --gpu-bind=closest \ - scripts/slurm_frontier/_srun_rank_wrapper.sh \ - scripts/training/train_e2e_stage3.py \ - $INIT_FLAG $RESUME_FLAG $MAX_FILES_FLAG $NO_AMP_FLAG $USE_DISP_FLAG \ ---data_dir "$DATA_DIR" \ ---stats_path "$STATS_PATH" \ ---checkpoint_dir "$CHECKPOINT_DIR" \ ---val_fraction 0.1 \ ---seed 42 \ ---chunk_duration_s 0.05 \ ---step_size_s 0.01 \ ---warmup_s 1.0 \ ---d_model "$D_MODEL" \ ---n_layers "$N_LAYERS" \ ---n_heads "$N_HEADS" \ ---dropout 0.1 \ ---lora_rank "$LORA_RANK" \ ---lora_alpha "$LORA_ALPHA" \ ---K_min "$K_MIN" \ ---K_max "$K_MAX" \ ---n_curriculum_blocks "$N_CURRICULUM_BLOCKS" \ ---curriculum_steps "$CURRICULUM_STEPS" \ ---pool_size "$POOL_SIZE" \ ---buffer_size "$BUFFER_SIZE" \ ---buffer_refresh_period "$BUFFER_REFRESH_PERIOD" \ ---buffer_refresh_fraction "$BUFFER_REFRESH_FRACTION" \ ---lr 3e-5 \ ---min_lr 1e-7 \ ---warmup_steps 200 \ ---weight_decay 0.01 \ ---grad_clip 5.0 \ ---cos_weight 0.3 \ ---mag_weight 0.1 \ ---min_disp_norm 0.01 \ ---batch_size "$BATCH_SIZE" \ ---num_workers "$NUM_WORKERS" \ ---max_steps "$MAX_STEPS" \ ---log_every "$LOG_EVERY" \ ---val_every "$VAL_EVERY" \ ---val_batch_size "$VAL_BATCH_SIZE" \ No newline at end of file diff --git a/scripts/slurm_frontier/verify_flash_attn.py b/scripts/slurm_frontier/verify_flash_attn.py new file mode 100644 index 0000000..c441114 --- /dev/null +++ b/scripts/slurm_frontier/verify_flash_attn.py 
@@ -0,0 +1,25 @@ +"""Smoke test for flash-attention 2 on Frontier (MI250X / gfx90a).""" +import sys + +import torch + +try: + import flash_attn + from flash_attn import flash_attn_func +except ImportError as e: + sys.exit(f"flash_attn not importable: {e}") + +assert torch.cuda.is_available(), "no GPU visible to torch" +assert torch.version.hip is not None, "torch is not a ROCm build" + +arch = torch.cuda.get_device_properties(0).gcnArchName +assert "gfx90a" in arch, f"unexpected gcn arch: {arch}" + +q = k = v = torch.randn(2, 8, 16, 64, device="cuda", dtype=torch.float16) +out = flash_attn_func(q, k, v, causal=True) +assert out.shape == q.shape + +print( + f"flash_attn {flash_attn.__version__} OK on " + f"{torch.cuda.get_device_name(0)} ({arch})" +) diff --git a/scripts/slurm/benchmark_data_loader.sh b/scripts/slurm_stellar/benchmark_data_loader.sh similarity index 100% rename from scripts/slurm/benchmark_data_loader.sh rename to scripts/slurm_stellar/benchmark_data_loader.sh diff --git a/scripts/slurm/benchmark_e2e_memory.sh b/scripts/slurm_stellar/benchmark_e2e_memory.sh similarity index 100% rename from scripts/slurm/benchmark_e2e_memory.sh rename to scripts/slurm_stellar/benchmark_e2e_memory.sh diff --git a/scripts/slurm/benchmark_stage2_ext.sh b/scripts/slurm_stellar/benchmark_stage2_ext.sh similarity index 100% rename from scripts/slurm/benchmark_stage2_ext.sh rename to scripts/slurm_stellar/benchmark_stage2_ext.sh diff --git a/scripts/slurm/compute_ae_token_stats.sh b/scripts/slurm_stellar/compute_ae_token_stats.sh similarity index 100% rename from scripts/slurm/compute_ae_token_stats.sh rename to scripts/slurm_stellar/compute_ae_token_stats.sh diff --git a/scripts/slurm/eval_e2e_stage1.sh b/scripts/slurm_stellar/eval_e2e_stage1.sh similarity index 100% rename from scripts/slurm/eval_e2e_stage1.sh rename to scripts/slurm_stellar/eval_e2e_stage1.sh diff --git a/scripts/slurm/eval_e2e_stage2.sh b/scripts/slurm_stellar/eval_e2e_stage2.sh similarity index 100% 
rename from scripts/slurm/eval_e2e_stage2.sh rename to scripts/slurm_stellar/eval_e2e_stage2.sh diff --git a/scripts/slurm/generate_tokens.sh b/scripts/slurm_stellar/generate_tokens.sh similarity index 100% rename from scripts/slurm/generate_tokens.sh rename to scripts/slurm_stellar/generate_tokens.sh diff --git a/scripts/slurm/make_processing_stats.sh b/scripts/slurm_stellar/make_processing_stats.sh similarity index 100% rename from scripts/slurm/make_processing_stats.sh rename to scripts/slurm_stellar/make_processing_stats.sh diff --git a/scripts/slurm/prepare_data.sh b/scripts/slurm_stellar/prepare_data.sh similarity index 100% rename from scripts/slurm/prepare_data.sh rename to scripts/slurm_stellar/prepare_data.sh diff --git a/scripts/slurm/profile_stage1.sh b/scripts/slurm_stellar/profile_stage1.sh similarity index 100% rename from scripts/slurm/profile_stage1.sh rename to scripts/slurm_stellar/profile_stage1.sh diff --git a/scripts/slurm/sample_ddp.sh b/scripts/slurm_stellar/sample_ddp.sh similarity index 100% rename from scripts/slurm/sample_ddp.sh rename to scripts/slurm_stellar/sample_ddp.sh diff --git a/scripts/slurm/test_dynamics_overfit.sh b/scripts/slurm_stellar/test_dynamics_overfit.sh similarity index 100% rename from scripts/slurm/test_dynamics_overfit.sh rename to scripts/slurm_stellar/test_dynamics_overfit.sh diff --git a/scripts/slurm/train_aurora_debug.sh b/scripts/slurm_stellar/train_aurora_debug.sh similarity index 100% rename from scripts/slurm/train_aurora_debug.sh rename to scripts/slurm_stellar/train_aurora_debug.sh diff --git a/scripts/slurm/train_bc_stage1.sh b/scripts/slurm_stellar/train_bc_stage1.sh similarity index 100% rename from scripts/slurm/train_bc_stage1.sh rename to scripts/slurm_stellar/train_bc_stage1.sh diff --git a/scripts/slurm/train_bc_stage2.sh b/scripts/slurm_stellar/train_bc_stage2.sh similarity index 100% rename from scripts/slurm/train_bc_stage2.sh rename to scripts/slurm_stellar/train_bc_stage2.sh diff --git 
a/scripts/slurm/train_bc_stage2_extended.sh b/scripts/slurm_stellar/train_bc_stage2_extended.sh similarity index 100% rename from scripts/slurm/train_bc_stage2_extended.sh rename to scripts/slurm_stellar/train_bc_stage2_extended.sh diff --git a/scripts/slurm/train_bes.sh b/scripts/slurm_stellar/train_bes.sh similarity index 100% rename from scripts/slurm/train_bes.sh rename to scripts/slurm_stellar/train_bes.sh diff --git a/scripts/slurm/train_bolo_raw.sh b/scripts/slurm_stellar/train_bolo_raw.sh similarity index 100% rename from scripts/slurm/train_bolo_raw.sh rename to scripts/slurm_stellar/train_bolo_raw.sh diff --git a/scripts/slurm/train_cer_rot.sh b/scripts/slurm_stellar/train_cer_rot.sh similarity index 100% rename from scripts/slurm/train_cer_rot.sh rename to scripts/slurm_stellar/train_cer_rot.sh diff --git a/scripts/slurm/train_cer_ti.sh b/scripts/slurm_stellar/train_cer_ti.sh similarity index 100% rename from scripts/slurm/train_cer_ti.sh rename to scripts/slurm_stellar/train_cer_ti.sh diff --git a/scripts/slurm/train_co2.sh b/scripts/slurm_stellar/train_co2.sh similarity index 100% rename from scripts/slurm/train_co2.sh rename to scripts/slurm_stellar/train_co2.sh diff --git a/scripts/slurm/train_co2_tf_only.sh b/scripts/slurm_stellar/train_co2_tf_only.sh similarity index 100% rename from scripts/slurm/train_co2_tf_only.sh rename to scripts/slurm_stellar/train_co2_tf_only.sh diff --git a/scripts/slurm/train_e2e_stage1.sh b/scripts/slurm_stellar/train_e2e_stage1.sh similarity index 100% rename from scripts/slurm/train_e2e_stage1.sh rename to scripts/slurm_stellar/train_e2e_stage1.sh diff --git a/scripts/slurm/train_e2e_stage2.sh b/scripts/slurm_stellar/train_e2e_stage2.sh similarity index 100% rename from scripts/slurm/train_e2e_stage2.sh rename to scripts/slurm_stellar/train_e2e_stage2.sh diff --git a/scripts/slurm/train_e2e_stage2_delta.sh b/scripts/slurm_stellar/train_e2e_stage2_delta.sh similarity index 100% rename from 
scripts/slurm/train_e2e_stage2_delta.sh rename to scripts/slurm_stellar/train_e2e_stage2_delta.sh diff --git a/scripts/slurm/train_e2e_stage2_extended.sh b/scripts/slurm_stellar/train_e2e_stage2_extended.sh similarity index 100% rename from scripts/slurm/train_e2e_stage2_extended.sh rename to scripts/slurm_stellar/train_e2e_stage2_extended.sh diff --git a/scripts/slurm/train_e2e_stage3.sh b/scripts/slurm_stellar/train_e2e_stage3.sh similarity index 100% rename from scripts/slurm/train_e2e_stage3.sh rename to scripts/slurm_stellar/train_e2e_stage3.sh diff --git a/scripts/slurm/train_ece.sh b/scripts/slurm_stellar/train_ece.sh similarity index 100% rename from scripts/slurm/train_ece.sh rename to scripts/slurm_stellar/train_ece.sh diff --git a/scripts/slurm/train_ece_conv_fct.sh b/scripts/slurm_stellar/train_ece_conv_fct.sh similarity index 100% rename from scripts/slurm/train_ece_conv_fct.sh rename to scripts/slurm_stellar/train_ece_conv_fct.sh diff --git a/scripts/slurm/train_ece_conv_nc.sh b/scripts/slurm_stellar/train_ece_conv_nc.sh similarity index 100% rename from scripts/slurm/train_ece_conv_nc.sh rename to scripts/slurm_stellar/train_ece_conv_nc.sh diff --git a/scripts/slurm/train_ece_conv_tfc.sh b/scripts/slurm_stellar/train_ece_conv_tfc.sh similarity index 100% rename from scripts/slurm/train_ece_conv_tfc.sh rename to scripts/slurm_stellar/train_ece_conv_tfc.sh diff --git a/scripts/slurm/train_ece_tf_only.sh b/scripts/slurm_stellar/train_ece_tf_only.sh similarity index 100% rename from scripts/slurm/train_ece_tf_only.sh rename to scripts/slurm_stellar/train_ece_tf_only.sh diff --git a/scripts/slurm/train_filterscopes.sh b/scripts/slurm_stellar/train_filterscopes.sh similarity index 100% rename from scripts/slurm/train_filterscopes.sh rename to scripts/slurm_stellar/train_filterscopes.sh diff --git a/scripts/slurm/train_foundation_model.sh b/scripts/slurm_stellar/train_foundation_model.sh similarity index 100% rename from 
scripts/slurm/train_foundation_model.sh rename to scripts/slurm_stellar/train_foundation_model.sh diff --git a/scripts/slurm/train_foundation_model_debug.sh b/scripts/slurm_stellar/train_foundation_model_debug.sh similarity index 100% rename from scripts/slurm/train_foundation_model_debug.sh rename to scripts/slurm_stellar/train_foundation_model_debug.sh diff --git a/scripts/slurm/train_i_coil.sh b/scripts/slurm_stellar/train_i_coil.sh similarity index 100% rename from scripts/slurm/train_i_coil.sh rename to scripts/slurm_stellar/train_i_coil.sh diff --git a/scripts/slurm/train_ich.sh b/scripts/slurm_stellar/train_ich.sh similarity index 100% rename from scripts/slurm/train_ich.sh rename to scripts/slurm_stellar/train_ich.sh diff --git a/scripts/slurm/train_langmuir.sh b/scripts/slurm_stellar/train_langmuir.sh similarity index 100% rename from scripts/slurm/train_langmuir.sh rename to scripts/slurm_stellar/train_langmuir.sh diff --git a/scripts/slurm/train_mhr.sh b/scripts/slurm_stellar/train_mhr.sh similarity index 100% rename from scripts/slurm/train_mhr.sh rename to scripts/slurm_stellar/train_mhr.sh diff --git a/scripts/slurm/train_mhr_conv_dw_ft.sh b/scripts/slurm_stellar/train_mhr_conv_dw_ft.sh similarity index 100% rename from scripts/slurm/train_mhr_conv_dw_ft.sh rename to scripts/slurm_stellar/train_mhr_conv_dw_ft.sh diff --git a/scripts/slurm/train_mhr_tf_only.sh b/scripts/slurm_stellar/train_mhr_tf_only.sh similarity index 100% rename from scripts/slurm/train_mhr_tf_only.sh rename to scripts/slurm_stellar/train_mhr_tf_only.sh diff --git a/scripts/slurm/train_mhr_tf_only_multinode.sh b/scripts/slurm_stellar/train_mhr_tf_only_multinode.sh similarity index 100% rename from scripts/slurm/train_mhr_tf_only_multinode.sh rename to scripts/slurm_stellar/train_mhr_tf_only_multinode.sh diff --git a/scripts/slurm/train_mhr_weighted_mse.sh b/scripts/slurm_stellar/train_mhr_weighted_mse.sh similarity index 100% rename from scripts/slurm/train_mhr_weighted_mse.sh 
rename to scripts/slurm_stellar/train_mhr_weighted_mse.sh diff --git a/scripts/slurm/train_mirnov.sh b/scripts/slurm_stellar/train_mirnov.sh similarity index 100% rename from scripts/slurm/train_mirnov.sh rename to scripts/slurm_stellar/train_mirnov.sh diff --git a/scripts/slurm/train_mse.sh b/scripts/slurm_stellar/train_mse.sh similarity index 100% rename from scripts/slurm/train_mse.sh rename to scripts/slurm_stellar/train_mse.sh diff --git a/scripts/slurm/train_multimodal.sh b/scripts/slurm_stellar/train_multimodal.sh similarity index 100% rename from scripts/slurm/train_multimodal.sh rename to scripts/slurm_stellar/train_multimodal.sh diff --git a/scripts/slurm/train_neutron_rate.sh b/scripts/slurm_stellar/train_neutron_rate.sh similarity index 100% rename from scripts/slurm/train_neutron_rate.sh rename to scripts/slurm_stellar/train_neutron_rate.sh diff --git a/scripts/slurm/train_spectrogram_ae.sh b/scripts/slurm_stellar/train_spectrogram_ae.sh similarity index 100% rename from scripts/slurm/train_spectrogram_ae.sh rename to scripts/slurm_stellar/train_spectrogram_ae.sh diff --git a/scripts/slurm/train_sxr.sh b/scripts/slurm_stellar/train_sxr.sh similarity index 100% rename from scripts/slurm/train_sxr.sh rename to scripts/slurm_stellar/train_sxr.sh diff --git a/scripts/slurm/train_ts_core_density.sh b/scripts/slurm_stellar/train_ts_core_density.sh similarity index 100% rename from scripts/slurm/train_ts_core_density.sh rename to scripts/slurm_stellar/train_ts_core_density.sh diff --git a/scripts/slurm/train_ts_core_temp.sh b/scripts/slurm_stellar/train_ts_core_temp.sh similarity index 100% rename from scripts/slurm/train_ts_core_temp.sh rename to scripts/slurm_stellar/train_ts_core_temp.sh diff --git a/scripts/slurm/train_ts_tangential_density.sh b/scripts/slurm_stellar/train_ts_tangential_density.sh similarity index 100% rename from scripts/slurm/train_ts_tangential_density.sh rename to scripts/slurm_stellar/train_ts_tangential_density.sh diff --git 
a/scripts/slurm/train_ts_tangential_temp.sh b/scripts/slurm_stellar/train_ts_tangential_temp.sh similarity index 100% rename from scripts/slurm/train_ts_tangential_temp.sh rename to scripts/slurm_stellar/train_ts_tangential_temp.sh diff --git a/scripts/slurm/train_unimodal.sh b/scripts/slurm_stellar/train_unimodal.sh similarity index 100% rename from scripts/slurm/train_unimodal.sh rename to scripts/slurm_stellar/train_unimodal.sh diff --git a/scripts/slurm/train_vib.sh b/scripts/slurm_stellar/train_vib.sh similarity index 100% rename from scripts/slurm/train_vib.sh rename to scripts/slurm_stellar/train_vib.sh diff --git a/scripts/slurm/train_video_ae.sh b/scripts/slurm_stellar/train_video_ae.sh similarity index 100% rename from scripts/slurm/train_video_ae.sh rename to scripts/slurm_stellar/train_video_ae.sh diff --git a/scripts/training/benchmark_attn_kernels.py b/scripts/training/benchmark_attn_kernels.py new file mode 100644 index 0000000..4a2f3b2 --- /dev/null +++ b/scripts/training/benchmark_attn_kernels.py @@ -0,0 +1,299 @@ +"""Kernel-level benchmark: flash-attn vs standard attention on MI250X. + +Compares four self-attention implementations on synthetic (q, k, v) of +realistic transformer shapes, on one MI250X GCD: + + flash_ext : flash_attn.flash_attn_func (external pkg, Triton-AMD/aiter) + sdpa_math : torch.nn.functional.scaled_dot_product_attention, math + backend forced (the "standard" path — what we use today) + sdpa_flash : F.scaled_dot_product_attention, flash backend forced + (PyTorch native, uses AOTriton on ROCm 7.x — completely + different code path from flash_ext) + sdpa_auto : F.scaled_dot_product_attention with defaults (PyTorch + picks; useful as a "what does torch want" reference) + +Measures forward time, backward time, peak alloc. Reports a markdown +table to stdout and a JSON dump. + +Why: the e2e profile measured flash_ext as 19% slower / 3.78× memory +than nn.MultiheadAttention at the e2e Stage 1 shape (head_dim=32, +seq_len≈26). 
Before concluding flash-attn is bad on Frontier, we need +a sanity check at shapes where flash should obviously win. +""" + +from __future__ import annotations + +import argparse +import json +import time +from contextlib import nullcontext +from pathlib import Path +from typing import Callable + +import torch +import torch.nn.functional as F + +try: + from torch.nn.attention import SDPBackend, sdpa_kernel +except ImportError: + SDPBackend = None + sdpa_kernel = None + +try: + from flash_attn import flash_attn_func as _flash_attn_func +except ImportError: + _flash_attn_func = None + + +def make_qkv( + batch: int, seq_len: int, n_heads: int, head_dim: int, + layout: str, dtype: torch.dtype, device: torch.device, + requires_grad: bool, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Allocate (q, k, v) in the layout the impl expects. + + layout='bhsd' for SDPA (batch, heads, seq, dim); + layout='bshd' for flash_attn_func (batch, seq, heads, dim). + """ + if layout == "bhsd": + shape = (batch, n_heads, seq_len, head_dim) + elif layout == "bshd": + shape = (batch, seq_len, n_heads, head_dim) + else: + raise ValueError(layout) + q = torch.randn(shape, dtype=dtype, device=device, requires_grad=requires_grad) + k = torch.randn(shape, dtype=dtype, device=device, requires_grad=requires_grad) + v = torch.randn(shape, dtype=dtype, device=device, requires_grad=requires_grad) + return q, k, v + + +def run_flash_ext(q, k, v): + # flash_attn_func expects (B, S, H, D) + return _flash_attn_func(q, k, v, causal=False) + + +def _sdpa_with_backend(backend): + def _call(q, k, v): + # SDPA expects (B, H, S, D) + ctx = sdpa_kernel(backend) if (sdpa_kernel and backend is not None) else nullcontext() + with ctx: + return F.scaled_dot_product_attention(q, k, v, is_causal=False) + return _call + + +_MHA_CACHE: dict = {} + + +def _get_nn_mha(d_model: int, n_heads: int, dtype, device) -> torch.nn.MultiheadAttention: + """Cache an nn.MultiheadAttention so we don't re-init every call. 
+ + Constructed in fp32 then cast — matches typical autocast-style usage. + """ + key = (d_model, n_heads, dtype) + mha = _MHA_CACHE.get(key) + if mha is None: + mha = torch.nn.MultiheadAttention( + d_model, n_heads, dropout=0.0, batch_first=True, bias=True, + ).to(device=device, dtype=dtype) + _MHA_CACHE[key] = mha + return mha + + +def run_nn_mha(q, k, v): + """Match stage1/2's current backbone: nn.MultiheadAttention(h, h, h). + + Input layout is (B, S, H, D); we collapse heads*dim → embed for MHA, then + re-split on output. need_weights=False is the path that *could* dispatch + to SDPA internally — this measurement tells us whether it actually does. + """ + B, S, H, D = q.shape + embed = H * D + qh = q.reshape(B, S, embed) + # MHA does its own Q/K/V projection; matching the pattern in the backbone + # which calls self.attn(h, h, h, need_weights=False). + mha = _get_nn_mha(embed, H, q.dtype, q.device) + out, _ = mha(qh, qh, qh, need_weights=False) + return out.reshape(B, S, H, D) + + +def time_fn_fwd_bwd( + fn: Callable, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + n_warmup: int, n_iters: int, do_bwd: bool, +) -> dict: + """Time fn(q, k, v) forward (and optionally backward). + + Returns dict with fwd_ms, bwd_ms (or None), peak_alloc_GB. 
+ """ + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + + # Warmup + for _ in range(n_warmup): + out = fn(q, k, v) + if do_bwd: + out.sum().backward() + q.grad = k.grad = v.grad = None + torch.cuda.synchronize() + + # Forward timing + fwd_start = torch.cuda.Event(enable_timing=True) + fwd_end = torch.cuda.Event(enable_timing=True) + fwd_start.record() + outs = [] + for _ in range(n_iters): + out = fn(q, k, v) + outs.append(out) + fwd_end.record() + torch.cuda.synchronize() + fwd_ms = fwd_start.elapsed_time(fwd_end) / n_iters + + bwd_ms = None + if do_bwd: + bwd_start = torch.cuda.Event(enable_timing=True) + bwd_end = torch.cuda.Event(enable_timing=True) + bwd_start.record() + for out in outs: + out.sum().backward(retain_graph=False) + q.grad = k.grad = v.grad = None + bwd_end.record() + torch.cuda.synchronize() + bwd_ms = bwd_start.elapsed_time(bwd_end) / n_iters + + peak_alloc_gb = torch.cuda.max_memory_allocated() / 1e9 + return {"fwd_ms": fwd_ms, "bwd_ms": bwd_ms, "peak_alloc_GB": peak_alloc_gb} + + +def main() -> None: + p = argparse.ArgumentParser() + p.add_argument("--out_dir", type=Path, required=True) + p.add_argument("--batch", type=int, default=4) + p.add_argument("--n_heads", type=int, default=16) + p.add_argument("--head_dims", type=int, nargs="+", default=[32, 64, 128]) + p.add_argument("--seq_lens", type=int, nargs="+", + default=[32, 128, 512, 2048, 4096]) + p.add_argument("--dtype", choices=["bf16", "fp16"], default="bf16") + p.add_argument("--n_warmup", type=int, default=3) + p.add_argument("--n_iters", type=int, default=10) + p.add_argument("--no_bwd", action="store_true") + args = p.parse_args() + + assert torch.cuda.is_available(), "no CUDA/HIP device visible" + device = torch.device("cuda") + dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16 + args.out_dir.mkdir(parents=True, exist_ok=True) + + print(f"device: {torch.cuda.get_device_name(0)}") + print(f"dtype : {dtype}") + print(f"shapes: batch={args.batch} 
n_heads={args.n_heads} " + f"head_dims={args.head_dims} seq_lens={args.seq_lens}") + print(f"flash_attn package: {'installed' if _flash_attn_func else 'MISSING'}") + print(f"sdpa_kernel ctx : {'available' if sdpa_kernel else 'MISSING (old torch)'}") + print() + + # Compose impl list. Skip flash_ext if package missing; skip sdpa_flash if + # the ctx manager is missing (very old torch). + impls: list[tuple[str, str, Callable]] = [] # (name, layout, fn) + if _flash_attn_func is not None: + impls.append(("flash_ext", "bshd", run_flash_ext)) + if sdpa_kernel is not None: + impls.append(("sdpa_math", "bhsd", _sdpa_with_backend(SDPBackend.MATH))) + impls.append(("sdpa_flash", "bhsd", _sdpa_with_backend(SDPBackend.FLASH_ATTENTION))) + impls.append(("sdpa_auto", "bhsd", _sdpa_with_backend(None))) + # The one we actually use in production today: nn.MultiheadAttention via + # backbone.py. Tells us whether it dispatches to SDPA internally on this + # PyTorch+ROCm build. + impls.append(("nn_mha", "bshd", run_nn_mha)) + + rows: list[dict] = [] + for head_dim in args.head_dims: + for seq_len in args.seq_lens: + print(f"-- head_dim={head_dim} seq_len={seq_len} --") + for name, layout, fn in impls: + try: + q, k, v = make_qkv( + args.batch, seq_len, args.n_heads, head_dim, + layout, dtype, device, + requires_grad=not args.no_bwd, + ) + res = time_fn_fwd_bwd( + fn, q, k, v, + n_warmup=args.n_warmup, n_iters=args.n_iters, + do_bwd=not args.no_bwd, + ) + rows.append({ + "impl": name, "head_dim": head_dim, "seq_len": seq_len, + "batch": args.batch, "n_heads": args.n_heads, + "dtype": args.dtype, **res, + }) + bwd_str = f" bwd={res['bwd_ms']:7.2f}ms" if res["bwd_ms"] else "" + print( + f" {name:<10} fwd={res['fwd_ms']:7.2f}ms" + f"{bwd_str} peak={res['peak_alloc_GB']:5.2f}GB" + ) + except Exception as e: + print(f" {name:<10} FAILED: {type(e).__name__}: {e}") + rows.append({ + "impl": name, "head_dim": head_dim, "seq_len": seq_len, + "batch": args.batch, "n_heads": args.n_heads, + 
"dtype": args.dtype, "error": f"{type(e).__name__}: {e}", + }) + finally: + del q, k, v + torch.cuda.empty_cache() + print() + + # Markdown summary + md_path = args.out_dir / "summary.md" + json_path = args.out_dir / "results.json" + with json_path.open("w") as f: + json.dump({"args": vars(args) | {"out_dir": str(args.out_dir)}, "rows": rows}, f, + indent=2, default=str) + + # Table: for each (head_dim, seq_len), show ratio of each impl vs sdpa_math + lines: list[str] = [] + lines.append( + f"# Attention kernel benchmark ({torch.cuda.get_device_name(0)}, " + f"{args.dtype}, batch={args.batch}, n_heads={args.n_heads})" + ) + lines.append("") + lines.append("Forward + backward time in ms (lower is better). " + "Peak alloc in GB. `× math` = ratio of total time to sdpa_math.") + lines.append("") + grouped: dict[tuple[int, int], dict[str, dict]] = {} + for r in rows: + if "error" in r: + continue + key = (r["head_dim"], r["seq_len"]) + grouped.setdefault(key, {})[r["impl"]] = r + for (head_dim, seq_len), impl_map in sorted(grouped.items()): + lines.append(f"## head_dim={head_dim}, seq_len={seq_len}") + lines.append("") + lines.append("| impl | fwd (ms) | bwd (ms) | total (ms) | × math | peak (GB) |") + lines.append("|---|---:|---:|---:|---:|---:|") + base = impl_map.get("sdpa_math") + base_total = (base["fwd_ms"] + (base["bwd_ms"] or 0)) if base else None + for impl_name in ("sdpa_math", "sdpa_flash", "sdpa_auto", "flash_ext", "nn_mha"): + if impl_name not in impl_map: + continue + r = impl_map[impl_name] + total = r["fwd_ms"] + (r["bwd_ms"] or 0) + ratio = f"{total / base_total:5.2f}" if base_total else " n/a" + bwd_str = f"{r['bwd_ms']:.2f}" if r["bwd_ms"] else "—" + lines.append( + f"| {impl_name} | {r['fwd_ms']:.2f} | {bwd_str} | " + f"{total:.2f} | {ratio} | {r['peak_alloc_GB']:.2f} |" + ) + lines.append("") + md = "\n".join(lines) + with md_path.open("w") as f: + f.write(md) + print() + print("=" * 60) + print(md) + print("=" * 60) + print(f"\nJSON: {json_path}") 
+ print(f"MD : {md_path}") + + +if __name__ == "__main__": + main() diff --git a/scripts/training/memory_probe_e2e.py b/scripts/training/memory_probe_e2e.py new file mode 100644 index 0000000..175f812 --- /dev/null +++ b/scripts/training/memory_probe_e2e.py @@ -0,0 +1,282 @@ +"""Memory-ceiling probe for the e2e model at scaled-up sizes. + +Constructs ``E2EFoundationModel`` at a configurable size, generates synthetic +inputs matching each modality's expected shape, and runs one forward + +backward under bf16 autocast. Prints peak memory and param count. + +Use to find the largest model that fits on one MI250X GCD under various +combinations of `attn_impl` and `gradient_checkpoint`. Reports both the +single-step ("stage 1") and K-step rollout ("stage 2") cases. + +Typical usage (inside a 1-GCD SLURM allocation): + + python scripts/training/memory_probe_e2e.py \\ + --d_model 1024 --n_layers 24 --n_heads 16 \\ + --batch_size 4 --K_rollout 1 \\ + --attn_impl sdpa --gradient_checkpoint +""" + +from __future__ import annotations + +import argparse +import gc +import sys +import time +from pathlib import Path + +import torch + +# Resolve train_e2e_stage1 without installing as a package. +sys.path.insert(0, str(Path(__file__).parent)) + +from tokamak_foundation_model.e2e.model import E2EFoundationModel # noqa: E402 +from train_e2e_stage1 import ( # type: ignore # noqa: E402 + SPECTROGRAM_MODALITIES, + VIDEO_MODALITIES, + build_configs, +) + + +def make_synthetic_inputs( + diagnostics, actuators, batch: int, device: torch.device, dtype: torch.dtype, +): + """Random tensors matching each modality's expected (channels, *spatial, samples). + + Mirrors the layout the real tokenizers expect: see the SlowTimeSeriesTokenizer, + FastTimeSeriesTokenizer, VideoTokenizer, SpectrogramTokenizer ctors and the + forward signatures in tokenizers.py. 
+ """ + diag_in: dict[str, torch.Tensor] = {} + for d in diagnostics: + if d.kind in ("slow_ts", "fast_ts"): + diag_in[d.name] = torch.randn( + batch, d.n_channels, d.window_samples, device=device, dtype=dtype + ) + elif d.kind == "video": + assert d.height is not None and d.width is not None + # VideoTokenizer's patch_embed is a Conv3d expecting + # (B, n_channels, T, H, W). For tangtv n_channels=2. + diag_in[d.name] = torch.randn( + batch, d.n_channels, d.window_samples, d.height, d.width, + device=device, dtype=dtype, + ) + elif d.kind == "spectrogram": + assert d.freq_bins is not None + diag_in[d.name] = torch.randn( + batch, d.n_channels, d.freq_bins, d.window_samples, + device=device, dtype=dtype, + ) + else: + raise ValueError(d.kind) + act_in = { + a.name: torch.randn( + batch, a.n_channels, a.window_samples, device=device, dtype=dtype + ) + for a in actuators + } + return diag_in, act_in + + +class BF16AdamW(torch.optim.AdamW): + """AdamW that allocates ``exp_avg`` / ``exp_avg_sq`` state in bf16. + + Default AdamW allocates state with ``torch.zeros_like(p)`` which inherits + the param's dtype (fp32 under our bf16-autocast setup). That doubles the + optimizer-state footprint relative to bf16. This subclass intercepts state + init and forces bf16, halving Adam's m+v from ~16 to ~8 bytes/param. + + Note: this is a memory-probe approximation. Real bf16 Adam needs + stochastic rounding on the m, v updates to avoid quantization bias — + libraries like bitsandbytes (AdamW8bit) and DeepSpeed (bf16 optimizer) + handle that. We don't, because we only care about memory here, not the + optimizer's numerical behavior. + + CURRENTLY BROKEN. The naive approach (allocate state in bf16, let the + parent step() handle the rest) hits dtype mismatches in both paths: + - foreach=True (default): "Tensors of the same index must be on the + same device and the same dtype..." 
+ - foreach=False: `exp_avg.lerp_(grad, ...)` strictly requires matching + dtypes — bf16 state + fp32 grad fails. + A correct implementation would either (a) cast grads to bf16 just before + step, (b) upcast m,v to fp32 transiently inside a custom step, or + (c) bring in bitsandbytes / DeepSpeed. None of those is worth the + iteration cost right now — use fp32 AdamW and account for bf16 savings + analytically (saves ~8 bytes/param). + """ + + def __init__(self, params, *args, **kwargs) -> None: + kwargs.setdefault("foreach", False) + kwargs.setdefault("fused", False) + super().__init__(params, *args, **kwargs) + + @torch.no_grad() + def step(self, closure=None): # type: ignore[override] + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + state = self.state[p] + if len(state) == 0: + state["step"] = torch.tensor(0.0) + state["exp_avg"] = torch.zeros_like( + p, dtype=torch.bfloat16, memory_format=torch.preserve_format, + ) + state["exp_avg_sq"] = torch.zeros_like( + p, dtype=torch.bfloat16, memory_format=torch.preserve_format, + ) + if group.get("amsgrad", False): + state["max_exp_avg_sq"] = torch.zeros_like( + p, dtype=torch.bfloat16, + memory_format=torch.preserve_format, + ) + return super().step(closure) + + +def main() -> None: + p = argparse.ArgumentParser() + p.add_argument("--d_model", type=int, default=1024) + p.add_argument("--n_layers", type=int, default=24) + p.add_argument("--n_heads", type=int, default=16) + p.add_argument("--mlp_ratio", type=float, default=4.0) + p.add_argument("--dropout", type=float, default=0.0) + p.add_argument("--batch_size", type=int, default=4) + p.add_argument("--chunk_duration_s", type=float, default=0.05) + p.add_argument( + "--use_video", nargs="*", + default=["tangtv"], + choices=[e[0] for e in VIDEO_MODALITIES], + ) + p.add_argument( + "--use_spectro", nargs="*", + default=["ece", "co2", "bes"], + choices=[e[0] for e in SPECTROGRAM_MODALITIES], + ) + p.add_argument( + 
"--attn_impl", choices=["standard", "sdpa", "flash"], default="standard", + ) + p.add_argument("--gradient_checkpoint", action="store_true") + p.add_argument( + "--K_rollout", type=int, default=1, + help="Simulate K-step rollout: repeat forward K times, backprop " + "through the chain (matches stage-2 memory pattern).", + ) + p.add_argument("--no_amp", action="store_true", + help="Disable bf16 autocast (debug only).") + p.add_argument( + "--bf16_optim_state", action="store_true", + help="Store Adam's m, v moments in bf16 instead of fp32. Halves the " + "optimizer-state memory (saves ~8 bytes/param). Memory-probe " + "approximation: real training would want stochastic rounding to " + "avoid divergence — see bitsandbytes/AdamW8bit or DeepSpeed bf16.", + ) + args = p.parse_args() + + assert torch.cuda.is_available(), "No CUDA/HIP device visible" + device = torch.device("cuda") + dtype = torch.float32 # inputs in fp32; autocast handles bf16 internally + print(f"device: {torch.cuda.get_device_name(0)}") + print(f"config: d_model={args.d_model} n_layers={args.n_layers} " + f"n_heads={args.n_heads} attn_impl={args.attn_impl} " + f"grad_ckpt={args.gradient_checkpoint} K_rollout={args.K_rollout}") + + diagnostics, actuators = build_configs( + args.chunk_duration_s, + use_video=args.use_video, + use_spectro=args.use_spectro, + ) + print(f"diagnostics: {[d.name for d in diagnostics]}") + print(f"actuators : {[a.name for a in actuators]}") + + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + mem_pre_model = torch.cuda.memory_allocated() / 1e9 + + model = E2EFoundationModel( + diagnostics=diagnostics, actuators=actuators, + d_model=args.d_model, n_heads=args.n_heads, n_layers=args.n_layers, + mlp_ratio=args.mlp_ratio, dropout=args.dropout, + attn_impl=args.attn_impl, + gradient_checkpoint=args.gradient_checkpoint, + ).to(device) + model.train() + n_params = sum(p.numel() for p in model.parameters()) + n_total_tokens = model.n_total_tokens + + mem_after_model = 
torch.cuda.memory_allocated() / 1e9 + print() + print(f"params : {n_params/1e6:.1f}M") + print(f"n_total_tokens: {n_total_tokens}") + print(f"weight mem : {mem_after_model - mem_pre_model:.2f} GB " + f"(should be ~{n_params * 4 / 1e9:.2f} GB at fp32)") + + if args.bf16_optim_state: + # WARNING: this path is currently broken — see BF16AdamW docstring. + # Use bitsandbytes / DeepSpeed in real training for bf16 Adam state. + optim = BF16AdamW(model.parameters(), lr=1e-4, weight_decay=0.1) + else: + optim = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.1) + + diag_in, act_in = make_synthetic_inputs( + diagnostics, actuators, args.batch_size, device, dtype, + ) + step_index = torch.zeros(args.batch_size, dtype=torch.long, device=device) + time_offset_s = torch.zeros(args.batch_size, dtype=dtype, device=device) + + # Reset peak so we measure only the forward+backward window + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + mem_at_start = torch.cuda.memory_allocated() / 1e9 + t0 = time.perf_counter() + + ctx = (torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) + if not args.no_amp else + torch.amp.autocast(device_type="cuda", enabled=False)) + + try: + optim.zero_grad(set_to_none=True) + loss = torch.zeros((), device=device) + with ctx: + # K-step rollout: forward K times, accumulating loss. Each forward + # holds activations needed for backward, matching stage 2's pattern. + for k in range(args.K_rollout): + outputs = model(diag_in, act_in, step_index + k, time_offset_s) + # model returns Dict[str, Tensor] (per-modality reconstructions). + # Cheap proxy loss — sum of squared outputs across all + # modalities. We only care about making backprop happen, not + # the loss value. + for v in outputs.values(): + loss = loss + (v.float() ** 2).mean() + loss.backward() + # optim.step() materializes Adam's m, v state tensors (~8 bytes/param + # in fp32) on first call. 
Including it gives a realistic training-step + # memory peak — otherwise we under-count by ~8 GB at the 1B scale. + optim.step() + torch.cuda.synchronize() + elapsed = time.perf_counter() - t0 + peak = torch.cuda.max_memory_allocated() / 1e9 + reserved = torch.cuda.max_memory_reserved() / 1e9 + print() + print(f"forward+backward+step time: {elapsed:.2f} s") + print(f"peak alloc : {peak:.2f} GB") + print(f"peak reserved : {reserved:.2f} GB") + print(f"loss : {loss.item():.4f} (sanity)") + print() + print("SUCCESS — model + step fit on this GCD.") + except torch.cuda.OutOfMemoryError as e: + peak = torch.cuda.max_memory_allocated() / 1e9 + reserved = torch.cuda.max_memory_reserved() / 1e9 + print() + print(f"OOM during forward+backward.") + print(f"peak alloc at OOM : {peak:.2f} GB") + print(f"peak reserved at OOM : {reserved:.2f} GB") + print(f"error: {e}") + sys.exit(1) + finally: + # Clean up before exit so SLURM reports a sensible final state. + del diag_in, act_in, optim, model + gc.collect() + torch.cuda.empty_cache() + + +if __name__ == "__main__": + main() diff --git a/scripts/training/profile_stage1.py b/scripts/training/profile_stage1.py index 8b371b5..ea6c863 100644 --- a/scripts/training/profile_stage1.py +++ b/scripts/training/profile_stage1.py @@ -27,6 +27,7 @@ from __future__ import annotations import argparse +import json import sys import time from pathlib import Path @@ -41,6 +42,8 @@ from tokamak_foundation_model.data.data_loader import collate_fn from tokamak_foundation_model.e2e.model import E2EFoundationModel from train_e2e_stage1 import ( # type: ignore + SPECTROGRAM_MODALITIES, + VIDEO_MODALITIES, build_configs, build_datasets, compute_step_loss, @@ -62,16 +65,40 @@ def main() -> None: ) p.add_argument("--batch_size", type=int, default=256) p.add_argument("--num_workers", type=int, default=8) + p.add_argument( + "--max_files", type=int, default=15, + help="Cap on shot files used for profiling. 
Default 15 — profiling " + "only needs enough chunks to fill the active window, and " + "scanning the full ~7878-file train set blows the wallclock.", + ) p.add_argument("--chunk_duration_s", type=float, default=0.05) p.add_argument("--prediction_horizon_s", type=float, default=0.05) p.add_argument("--step_size_s", type=float, default=0.01) p.add_argument("--warmup_s", type=float, default=1.0) p.add_argument("--d_model", type=int, default=256) - p.add_argument("--n_layers", type=int, default=8) + p.add_argument("--n_layers", type=int, default=26) p.add_argument("--n_heads", type=int, default=8) p.add_argument("--dropout", type=float, default=0.1) p.add_argument("--val_fraction", type=float, default=0.1) p.add_argument("--seed", type=int, default=42) + p.add_argument( + "--use_video", nargs="*", default=[], + choices=[entry[0] for entry in VIDEO_MODALITIES], + help="Camera names to include as video modalities (match canonical run).", + ) + p.add_argument( + "--use_spectro", nargs="*", default=[], + choices=[entry[0] for entry in SPECTROGRAM_MODALITIES], + help="Spectrogram modality names to include (match canonical run).", + ) + p.add_argument( + "--no_amp_val", action="store_true", + help="Accepted for parity with train_e2e_stage1; unused here (no validation).", + ) + p.add_argument( + "--use_flash_attn", action="store_true", + help="Use flash-attention 2 in the backbone (requires flash_attn package).", + ) # Profiler schedule: (wait, warmup, active). ``wait`` skips the dataloader # spin-up transient; ``warmup`` primes caches so the active window is # steady-state; ``active`` is what gets recorded. 
@@ -85,7 +112,11 @@ def main() -> None: print(f"Device: {device}") print(f"num_workers={args.num_workers} batch_size={args.batch_size}") - diagnostics, actuators = build_configs(args.chunk_duration_s) + diagnostics, actuators = build_configs( + args.chunk_duration_s, + use_video=args.use_video, + use_spectro=args.use_spectro, + ) diag_names = [c.name for c in diagnostics] act_names = [c.name for c in actuators] print(f"Diagnostics ({len(diag_names)}): {diag_names}") @@ -94,7 +125,7 @@ def main() -> None: train_files, val_files = resolve_shot_files( data_dir=args.data_dir, train_shots_yaml=None, val_shots_yaml=None, - max_files=None, val_fraction=args.val_fraction, seed=args.seed, + max_files=args.max_files, val_fraction=args.val_fraction, seed=args.seed, ) print(f"Train files: {len(train_files)} val: {len(val_files)}") @@ -126,6 +157,7 @@ def main() -> None: persistent_workers=args.num_workers > 0, ) + attn_impl = "flash" if args.use_flash_attn else "standard" model = E2EFoundationModel( diagnostics=diagnostics, actuators=actuators, @@ -133,10 +165,11 @@ def main() -> None: n_layers=args.n_layers, n_heads=args.n_heads, dropout=args.dropout, + attn_impl=attn_impl, ).to(device) opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.1) n_params = sum(p.numel() for p in model.parameters()) / 1e6 - print(f"Model params: {n_params:.2f}M") + print(f"Model params: {n_params:.2f}M attn_impl={attn_impl}") total_steps = args.profile_wait + args.profile_warmup + args.profile_active print( @@ -174,12 +207,15 @@ def on_ready(prof_obj: profile) -> None: model.train() step_times: list[float] = [] + active_start = args.profile_wait + args.profile_warmup t_start = time.time() prof.start() for step, batch in enumerate(loader): if step >= total_steps: break + if step == active_start and device.type == "cuda": + torch.cuda.reset_peak_memory_stats() s = time.perf_counter() opt.zero_grad(set_to_none=True) loss, _ = compute_step_loss(model, batch, device) @@ -196,15 +232,46 
@@ def on_ready(prof_obj: profile) -> None: print(f"Total wall time: {time.time() - t_start:.1f} s") print(f"Per-step wall times (s): " + " ".join(f"{t:.2f}" for t in step_times)) - active_slice = step_times[args.profile_wait + args.profile_warmup:] + active_slice = step_times[active_start:] + active_mean = (sum(active_slice) / len(active_slice)) if active_slice else float("nan") if active_slice: print( f"Active-window mean: " - f"{sum(active_slice) / len(active_slice):.2f} s/step " + f"{active_mean:.3f} s/step " f"(over {len(active_slice)} steps)" ) + + peak_alloc_gb = 0.0 + peak_reserved_gb = 0.0 + if device.type == "cuda": + peak_alloc_gb = torch.cuda.max_memory_allocated() / 1e9 + peak_reserved_gb = torch.cuda.max_memory_reserved() / 1e9 + print( + f"Active-window peak memory: " + f"alloc={peak_alloc_gb:.2f} GB reserved={peak_reserved_gb:.2f} GB" + ) + + memory_json = { + "attn_impl": attn_impl, + "n_layers": args.n_layers, + "d_model": args.d_model, + "n_heads": args.n_heads, + "batch_size": args.batch_size, + "use_video": list(args.use_video), + "use_spectro": list(args.use_spectro), + "active_steps": len(active_slice), + "active_mean_step_s": active_mean, + "throughput_steps_per_s": (1.0 / active_mean) if active_slice and active_mean > 0 else None, + "peak_alloc_GB": peak_alloc_gb, + "peak_reserved_GB": peak_reserved_gb, + } + mem_path = args.output_dir / "memory.json" + with mem_path.open("w") as f: + json.dump(memory_json, f, indent=2) + print(f"Trace : {trace_path}") print(f"Summary: {summary_path}") + print(f"Memory: {mem_path}") print("Open the trace in chrome://tracing or Perfetto.") diff --git a/scripts/training/train_e2e_stage1.py b/scripts/training/train_e2e_stage1.py index d753d38..07fbb8f 100644 --- a/scripts/training/train_e2e_stage1.py +++ b/scripts/training/train_e2e_stage1.py @@ -39,10 +39,10 @@ import torch.nn.functional as F import yaml from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler from 
tokamak_foundation_model.data.data_loader import collate_fn from tokamak_foundation_model.data.multi_file_dataset import ( + DistributedTwoLevelSampler, TokamakMultiFileDataset, TwoLevelSampler, filter_video_present_files, @@ -567,7 +567,16 @@ def validate( max_batches: Optional[int] = None, use_amp: bool = False, ) -> Dict[str, Dict[str, float]]: - """Return per-modality validation metrics. + """Return per-modality validation metrics, computed in a + distribution-aware way. + + The val_loader is assumed to be sharded across ranks (via a + ``DistributedTwoLevelSampler`` with ``shuffle=False``). Each rank + accumulates partial sums on its shard; the totals are all-reduced + once at the end so every rank ends up with the same global metric + values. This replaces the previous "every rank validates everything" + behaviour, which caused host-memory OOMs at 64+ ranks because each + rank held the full val workload in flight independently. ``out[name]`` has keys ``model_mae``, ``copy_mae``, ``pred_delta``, ``tgt_delta``, ``delta_ratio``. @@ -578,10 +587,25 @@ def validate( ``pred_delta ≈ 0``; a model predicting the true dynamics has ``delta_ratio = pred_delta / tgt_delta ∈ [0.8, 1.2]``. """ + import torch.distributed as dist + model.eval() + # Bypass the DDP wrapper for the val forward pass. DDP's pre-forward + # hook (rebuild_buckets logic) was observed to trigger GPU memory + # access faults during validation even under no_grad. The inner + # module's weights are identical across ranks (DDP keeps them in + # sync), so forwarding through it directly produces the same result. + inner = _core(model) + keys = ("model_mae", "copy_mae", "pred_delta", "tgt_delta") - sums = {k: {n: 0.0 for n in diagnostic_names} for k in keys} - n_batches = 0 + M = len(diagnostic_names) + K = len(keys) + # fp32 accumulators regardless of autocast — keeps cross-rank + # all_reduce in fp32 (bf16 all_reduce on RCCL has stability issues) + # and avoids precision loss across many batches. 
+ sums_t = torch.zeros(K, M, device=device, dtype=torch.float32) + n_batches_t = torch.zeros((), device=device, dtype=torch.float32) + name_to_col = {n: j for j, n in enumerate(diagnostic_names)} amp_ctx = ( torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) @@ -590,16 +614,19 @@ def validate( for i, batch in enumerate(loader): if max_batches is not None and i >= max_batches: break + # Only the forward pass runs inside autocast; metric math + # explicitly upcasts to fp32 below. with amp_ctx: predictions, diag_inputs, targets, masks = forward_batch( - model, batch, device + inner, batch, device ) - copy_mod = copy_baseline_mae(batch, _core(model).diagnostics, device) + copy_mod = copy_baseline_mae(batch, inner.diagnostics, device) for name in diagnostic_names: - pred = predictions[name] - inp = diag_inputs[name] - tgt = targets[name] - existing = masks[name] + j = name_to_col[name] + pred = predictions[name].float() + inp = diag_inputs[name].float() + tgt = targets[name].float() + existing = masks[name].float() if masks[name] is not None else None cleaned_pred, mask_p = _clean_and_mask(pred, None) cleaned_tgt, mask_t = _clean_and_mask(tgt, existing) @@ -616,20 +643,31 @@ def validate( (cleaned_tgt - inp).abs() * combined ).sum() / denom - sums["model_mae"][name] += model_mae_v.item() - sums["copy_mae"][name] += copy_mod[name] - sums["pred_delta"][name] += pred_delta.item() - sums["tgt_delta"][name] += tgt_delta.item() - n_batches += 1 - - denom = max(n_batches, 1) + sums_t[0, j] += model_mae_v + sums_t[1, j] += float(copy_mod[name]) + sums_t[2, j] += pred_delta + sums_t[3, j] += tgt_delta + n_batches_t += 1.0 + + # Two all-reduces across ranks, issued once after the loop: one for + # the per-metric sums, one for the batch count (both fp32). Empty-shard + # ranks contribute zeros and a count of 0, which is the correct behaviour.
+ if dist.is_available() and dist.is_initialized(): + dist.all_reduce(sums_t, op=dist.ReduceOp.SUM) + dist.all_reduce(n_batches_t, op=dist.ReduceOp.SUM) + + denom = float(n_batches_t.item()) + if denom <= 0.0: + denom = 1.0 + sums = sums_t.detach().cpu().numpy() model.train() out: Dict[str, Dict[str, float]] = {} for name in diagnostic_names: - model_mae = sums["model_mae"][name] / denom - copy_mae = sums["copy_mae"][name] / denom - pred_d = sums["pred_delta"][name] / denom - tgt_d = sums["tgt_delta"][name] / denom + j = name_to_col[name] + model_mae = float(sums[0, j]) / denom + copy_mae = float(sums[1, j]) / denom + pred_d = float(sums[2, j]) / denom + tgt_d = float(sums[3, j]) / denom ratio = pred_d / tgt_d if tgt_d > 1e-8 else float("nan") out[name] = { "model_mae": model_mae, @@ -777,6 +815,15 @@ def main() -> None: parser.add_argument("--data_dir", type=Path, required=True) parser.add_argument("--stats_path", type=Path, required=True) parser.add_argument("--checkpoint_dir", type=Path, required=True) + parser.add_argument( + "--lengths_cache_dir", + type=Path, + default=Path("/lustre/orion/fus187/proj-shared/foundation_model_meta"), + help="Directory for TokamakMultiFileDataset length-cache sidecar " + "files (lengths_e2e_stage1_{train,val}.pt). Defaults to the " + "shared foundation_model_meta dir so all ranks/jobs reuse the " + "same cache.", + ) parser.add_argument("--train_shots_yaml", type=Path, default=None) parser.add_argument("--val_shots_yaml", type=Path, default=None) parser.add_argument("--max_files", type=int, default=None) @@ -793,6 +840,13 @@ def main() -> None: parser.add_argument("--d_model", type=int, default=64) parser.add_argument("--n_layers", type=int, default=4) parser.add_argument("--n_heads", type=int, default=4) + parser.add_argument( + "--gradient_checkpoint", action="store_true", + help="Recompute backbone-block activations during backward instead " + "of storing them. 
Trades ~30%% extra compute for typically " + "5-10x less activation memory; required to scale n_layers / " + "d_model on a single GCD.", + ) parser.add_argument("--dropout", type=float, default=0.0) # Optim @@ -854,7 +908,29 @@ def main() -> None: "--no_amp", action="store_true", help="Disable bf16 mixed precision (default: AMP on when CUDA).", ) + parser.add_argument( + "--no_amp_val", action="store_true", + help="Disable bf16 autocast during validation only (training still " + "uses AMP if --no_amp not set). Workaround for the GPU memory-" + "access faults seen during distributed validation at n_layers=26 " + "on Frontier ROCm 7.1.1.", + ) + parser.add_argument( + "--use_flash_attn", action="store_true", + help="Use flash-attention 2 (external pkg) in the backbone. Requires " + "the flash_attn package (install via `pixi run -e frontier " + "setup-flash-attn`). On MI250X this is slower than --use_sdpa_attn; " + "prefer that flag instead.", + ) + parser.add_argument( + "--use_sdpa_attn", action="store_true", + help="Use F.scaled_dot_product_attention in the backbone. 
On ROCm 7.x " + "this dispatches to AOTriton flash-attn and is 1.4-5x faster than the " + "default nn.MultiheadAttention path with substantially less memory.", + ) args = parser.parse_args() + if args.use_flash_attn and args.use_sdpa_attn: + parser.error("--use_flash_attn and --use_sdpa_attn are mutually exclusive") dm = DistributedManager() @@ -906,15 +982,16 @@ def main() -> None: if args.use_video: n_train_before = len(train_files) n_val_before = len(val_files) + args.lengths_cache_dir.mkdir(parents=True, exist_ok=True) train_files = filter_video_present_files( train_files, args.use_video, - cache_path=args.checkpoint_dir / "video_present_train.pt", + cache_path=args.lengths_cache_dir / "video_present_train.pt", ) val_files = filter_video_present_files( val_files, args.use_video, - cache_path=args.checkpoint_dir / "video_present_val.pt", + cache_path=args.lengths_cache_dir / "video_present_val.pt", ) logger.info( f"Video-presence filter ({args.use_video}): " @@ -947,6 +1024,12 @@ def main() -> None: f"Actuators ({len(actuators)}): " + ", ".join(actuator_names) ) + if args.use_flash_attn: + attn_impl = "flash" + elif args.use_sdpa_attn: + attn_impl = "sdpa" + else: + attn_impl = "standard" model = E2EFoundationModel( diagnostics=diagnostics, actuators=actuators, @@ -954,6 +1037,8 @@ def main() -> None: n_heads=args.n_heads, n_layers=args.n_layers, dropout=args.dropout, + attn_impl=attn_impl, + gradient_checkpoint=args.gradient_checkpoint, ).to(device) n_params = sum(p.numel() for p in model.parameters()) n_total_tokens = model.n_total_tokens @@ -961,7 +1046,9 @@ def main() -> None: logger.info( f"Model — d_model={args.d_model} n_layers={args.n_layers} " f"n_heads={args.n_heads} tokens={n_total_tokens} " - f"params={n_params / 1e6:.2f}M ddp={dm.distributed}" + f"params={n_params / 1e6:.2f}M ddp={dm.distributed} " + f"attn_impl={attn_impl} " + f"gradient_checkpoint={args.gradient_checkpoint}" ) # ── Datasets ──────────────────────────────────────────────────────── 
@@ -976,7 +1063,7 @@ def main() -> None: warmup_s=args.warmup_s, diagnostic_names=diagnostic_names, actuator_names=actuator_names, - lengths_cache_dir=args.checkpoint_dir, + lengths_cache_dir=args.lengths_cache_dir, ) logger.info(f"Chunks — train: {len(train_ds)} val: {len(val_ds)}") @@ -989,10 +1076,14 @@ def _worker_init(_worker_id: int) -> None: torch.set_num_threads(n) if dm.distributed: - # DistributedSampler shards chunk indices across ranks. Loses the - # file-sequential cache locality of TwoLevelSampler — revisit if - # HDF5 open() time becomes a bottleneck under DDP. - train_sampler = DistributedSampler( + # DDP-aware file-level sharding. Preserves TwoLevelSampler's + # per-worker LRU file-handle cache locality (each rank owns a + # fixed slice of the file list, iterates its own files + # sequentially). PyTorch's DistributedSampler, which shards + # chunk indices instead, was observed to make HDF5 open() the + # dominant cost (~12 s/step at 2-GPU DDP vs. ~1 s/step + # single-GPU at the same batch). + train_sampler = DistributedTwoLevelSampler( train_ds, num_replicas=dm.world_size, rank=dm.rank, @@ -1015,16 +1106,41 @@ def _worker_init(_worker_id: int) -> None: persistent_workers=args.num_workers > 0, worker_init_fn=_worker_init, ) + # Distributed validation: shard the val set across ranks so each + # rank validates ~1/world_size of it. Matching the train sampler's + # file-level sharding (preserves LRU file-handle locality and avoids + # the host-OOM that hit at 64 ranks when every rank held the full + # val workload independently). Metrics are all-reduced inside + # validate() so all ranks end up with identical global numbers. + if dm.distributed: + val_sampler = DistributedTwoLevelSampler( + val_ds, + num_replicas=dm.world_size, + rank=dm.rank, + shuffle=False, + seed=args.seed, + drop_last=True, + ) + else: + val_sampler = TwoLevelSampler(val_ds, shuffle=False) + + # Val loader memory budget. 
Train workers stay alive during val and + # hold their prefetched batches (6 workers x 2 prefetch = 12 in flight + # per rank). With num_workers=6 prefetch=1 the combined peak (18) hits + # ~97% host RAM on 2-node smokes -> OOM territory. Capping val to + # 4 workers x 1 prefetch keeps the combined in-flight at 16 batches, + # within the 502 GB node budget. Workers are torn down at end-of-val. + val_num_workers = min(4, args.num_workers) val_loader = DataLoader( val_ds, batch_size=args.batch_size, - shuffle=False, - num_workers=args.num_workers, + sampler=val_sampler, + num_workers=val_num_workers, collate_fn=collate_fn, drop_last=True, - prefetch_factor=2, + prefetch_factor=1, pin_memory=False, - persistent_workers=args.num_workers > 0, + persistent_workers=False, worker_init_fn=_worker_init, ) @@ -1041,6 +1157,10 @@ def _worker_init(_worker_id: int) -> None: # bf16 mixed precision. bf16 has the same dynamic range as fp32 so # no GradScaler is required; matches train_e2e_stage2_delta.py. use_amp = (not args.no_amp) and device.type == "cuda" + # Separate flag for validation AMP. Defaults to the training value, + # but --no_amp_val turns it off independently as a workaround for + # ROCm-side GPU memory-access faults observed during distributed val. 
+ use_amp_val = use_amp and not args.no_amp_val def amp_ctx_factory(): if use_amp: @@ -1205,7 +1325,7 @@ def amp_ctx_factory(): device, diagnostic_names, max_batches=args.val_max_batches, - use_amp=use_amp, + use_amp=use_amp_val, ) logger.info( "Validation (MAE model vs copy; delta-ratio pred/tgt):" diff --git a/scripts/training/train_e2e_stage2.py b/scripts/training/train_e2e_stage2.py index 6b5430f..c597bca 100644 --- a/scripts/training/train_e2e_stage2.py +++ b/scripts/training/train_e2e_stage2.py @@ -502,6 +502,19 @@ def main() -> None: parser.add_argument("--d_model", type=int, default=256) parser.add_argument("--n_layers", type=int, default=8) parser.add_argument("--n_heads", type=int, default=8) + parser.add_argument( + "--gradient_checkpoint", action="store_true", + help="Recompute backbone-block activations during backward instead " + "of storing them. Especially helpful for K-step rollouts since " + "activation memory otherwise scales as K x layers. Costs ~30%% " + "extra compute.", + ) + parser.add_argument( + "--use_sdpa_attn", action="store_true", + help="Use F.scaled_dot_product_attention in the backbone. 
On ROCm 7.x " + "this dispatches to AOTriton flash-attn and is 1.4-5x faster with " + "substantially less memory than the default nn.MultiheadAttention path.", + ) parser.add_argument("--dropout", type=float, default=0.1) # Curriculum @@ -585,6 +598,7 @@ def main() -> None: f"Actuators ({len(actuators)}): " + ", ".join(actuator_names) ) + attn_impl = "sdpa" if args.use_sdpa_attn else "standard" model = E2EFoundationModel( diagnostics=diagnostics, actuators=actuators, @@ -592,6 +606,8 @@ def main() -> None: n_heads=args.n_heads, n_layers=args.n_layers, dropout=args.dropout, + attn_impl=attn_impl, + gradient_checkpoint=args.gradient_checkpoint, ).to(device) if args.init_checkpoint is not None: @@ -618,7 +634,9 @@ def main() -> None: logger.info( f"Model — d_model={args.d_model} n_layers={args.n_layers} " f"n_heads={args.n_heads} tokens={n_total_tokens} " - f"params={n_params / 1e6:.2f}M ddp={dm.distributed}" + f"params={n_params / 1e6:.2f}M ddp={dm.distributed} " + f"attn_impl={attn_impl} " + f"gradient_checkpoint={args.gradient_checkpoint}" ) # ── Datasets ──────────────────────────────────────────────────────── diff --git a/scripts/training/train_e2e_stage2_delta.py b/scripts/training/train_e2e_stage2_delta.py index fcc3381..20794dc 100644 --- a/scripts/training/train_e2e_stage2_delta.py +++ b/scripts/training/train_e2e_stage2_delta.py @@ -42,12 +42,15 @@ from typing import Dict, List, Optional, Tuple import torch +import torch.distributed as dist import torch.nn.functional as F +import torch.utils.checkpoint as torch_ckpt import yaml from torch.utils.data import DataLoader from tokamak_foundation_model.data.data_loader import collate_fn from tokamak_foundation_model.data.multi_file_dataset import ( + DistributedTwoLevelSampler, TokamakMultiFileDataset, TwoLevelSampler, filter_video_present_files, @@ -60,7 +63,18 @@ ) from tokamak_foundation_model.e2e.rollout import TokenSpaceRollout from tokamak_foundation_model.utils.distributed import DistributedManager -from 
torch.utils.data.distributed import DistributedSampler + +from tokamak_foundation_model.e2e.multimodal import ( + SPECTROGRAM_MODALITIES, + VIDEO_MODALITIES, + append_multimodal_diagnostics, + spectro_loss_gate as _spectro_loss_gate, + spectro_trunc_t as _spectro_trunc_t, + split_spectro_target_by_step, + split_video_target_by_step, + video_loss_gate as _video_loss_gate, + video_standardize_per_bc as _video_standardize_per_bc, +) from tokamak_foundation_model.e2e.multimodal import ( SPECTROGRAM_MODALITIES, @@ -407,6 +421,7 @@ def rollout_forward_loss_delta( video_diag_names: Optional[List[str]] = None, video_n_frames: Optional[Dict[str, int]] = None, spectro_diag_names: Optional[List[str]] = None, + grad_checkpoint_every: int = 0, ) -> Tuple[torch.Tensor, List[Dict[str, Dict[str, float]]]]: """Tokenise step-0, split targets/actuators, run K-step rollout with full backprop, and return (summed loss, per-step per-modality metrics). @@ -456,7 +471,11 @@ def rollout_forward_loss_delta( spectro_target_full: Dict[str, torch.Tensor] = {} spectro_gate: Dict[str, torch.Tensor] = {} spectro_trunc_t: Dict[str, int] = {} - cfg_by_name = {c.name: c for c in rollout.model.diagnostics} + # Use _core(rollout) for the metadata read so this works whether the + # rollout is DDP-wrapped (training) or already unwrapped (validate()). + # DDP only proxies forward(); arbitrary attribute access like .model + # raises AttributeError on the DDP wrapper. + cfg_by_name = {c.name: c for c in _core(rollout).model.diagnostics} for name in spectro_diag_names: raw = batch["targets"][name].to(device).float() cleaned, _ = _clean_and_mask(raw, None) @@ -503,14 +522,43 @@ def rollout_forward_loss_delta( target_per_step.append(tgt_k) mask_per_step.append(mk_k) - result = rollout(diag_initial, act_per_step) + # Gradient checkpointing on the rollout (ported from stage 2 extended). 
+ # When grad_checkpoint_every >= k_steps the entire K-step rollout is one + # checkpoint group: forward activations are discarded; recomputed during + # backward → ~K-fold less activation memory at ~33% step-time penalty. + # Per-group chunking (0 < g < k_steps) needs the chunk_fn pattern from + # stage 2 extended — not ported here. + # + # Bypass DDP inside the checkpointed function (use _core(rollout)) + # to avoid DDP forward hooks firing twice (first forward + recompute + # backward), which on MI250X produces "Memory access fault by GPU". + # DDP's gradient all_reduce still works correctly because the hooks + # are registered on parameters and fire when grads are populated, + # independent of which forward path produced the gradient. + inner_rollout = _core(rollout) + + def _checkpointed_rollout(diag_init, act): + return inner_rollout(diag_init, act).predictions + + if grad_checkpoint_every <= 0: + predictions = rollout(diag_initial, act_per_step).predictions + elif grad_checkpoint_every >= k_steps: + predictions = torch_ckpt.checkpoint( + _checkpointed_rollout, diag_initial, act_per_step, + use_reentrant=False, + ) + else: + raise NotImplementedError( + f"grad_checkpoint_every={grad_checkpoint_every} < " + f"k_steps={k_steps}: per-group chunking is not ported to " + "stage 2 delta. Pass 0 (off) or a value >= k_steps " + f"(single group). Current k_steps={k_steps}." + ) # Video heads emit (B, T, C, H, W); permute per step to (B, C, T, H, W) # so loss / metric paths see a single shape contract. 
for k in range(k_steps): for name in video_diag_names: - result.predictions[k][name] = ( - result.predictions[k][name].permute(0, 2, 1, 3, 4) - ) + predictions[k][name] = predictions[k][name].permute(0, 2, 1, 3, 4) # Accumulate per-(step, modality) metrics as on-device scalar tensors; # transfer them to CPU once at the end of the forward pass instead of @@ -527,7 +575,7 @@ def rollout_forward_loss_delta( mr_row: List[torch.Tensor] = [] nv_row: List[torch.Tensor] = [] for name in diagnostic_names: - pred = result.predictions[k][name] + pred = predictions[k][name] target = target_per_step[k][name] mask = mask_per_step[k][name] if name in video_diag_names or name in spectro_diag_names: @@ -727,8 +775,22 @@ def validate( mask = mask_per_step[k][name] if name in video_diag_names or name in spectro_diag_names: mae = masked_mae(pred, target, mask).item() + # Spectrogram diag_initial holds the full STFT output + # (e.g. 98 frames at the canonical config) while target + # is sliced to trunc_t (e.g. 96) by + # split_spectro_target_by_step. Truncate the copy + # baseline input to the same time-axis length so + # masked_mae's broadcast doesn't blow up. Video + # diag_initial and per-step target share the same T, + # so no truncation needed there. + if name in spectro_diag_names: + baseline_input = diag_initial[name][ + ..., : spectro_trunc_t[name] + ] + else: + baseline_input = diag_initial[name] copy_mae = masked_mae( - diag_initial[name], target, mask + baseline_input, target, mask ).item() sums[k][name]["model_mae"] += mae sums[k][name]["copy_mae"] += copy_mae @@ -752,6 +814,40 @@ def validate( counts[k][name]["disp"] += 1 rollout.model.train() + + # Aggregate metrics across DDP ranks. With the val loader sharded by + # DistributedTwoLevelSampler each rank holds sums/counts for its own + # ~1/world_size slice; without all_reduce the rank-0 logger would + # print only its slice. Flatten the nested dicts to two fp32 tensors, + # all_reduce(SUM), then unflatten. 
+ if dist.is_available() and dist.is_initialized(): + sum_keys = [ + (k, n, m) + for k in range(K_max) + for n in diagnostic_names + for m in keys + ] + cnt_keys = [ + (k, n, m) + for k in range(K_max) + for n in diagnostic_names + for m in ("mae", "disp") + ] + sum_t = torch.tensor( + [sums[k][n][m] for (k, n, m) in sum_keys], + device=device, dtype=torch.float32, + ) + cnt_t = torch.tensor( + [counts[k][n][m] for (k, n, m) in cnt_keys], + device=device, dtype=torch.float32, + ) + dist.all_reduce(sum_t, op=dist.ReduceOp.SUM) + dist.all_reduce(cnt_t, op=dist.ReduceOp.SUM) + for i, (k, n, m) in enumerate(sum_keys): + sums[k][n][m] = float(sum_t[i].item()) + for i, (k, n, m) in enumerate(cnt_keys): + counts[k][n][m] = int(cnt_t[i].item()) + out: Dict[int, Dict[str, Dict[str, float]]] = {} for k in range(K_max): out[k] = {} @@ -824,6 +920,18 @@ def main() -> None: parser.add_argument("--data_dir", type=Path, required=True) parser.add_argument("--stats_path", type=Path, required=True) parser.add_argument("--checkpoint_dir", type=Path, required=True) + parser.add_argument( + "--lengths_cache_dir", + type=Path, + default=Path("/lustre/orion/fus187/proj-shared/foundation_model_meta"), + help="Directory for TokamakMultiFileDataset length-cache sidecar " + "files (lengths_e2e_stage2_delta_{train,val}.pt) and the " + "video-presence cache (video_present_{train,val}.pt). Defaults " + "to the same shared dir Stage 1 uses so the video-presence " + "cache is reused — it only depends on (paths, camera_names), " + "not the stage. Kept separate from --checkpoint_dir so cache " + "files survive checkpoint-dir cleanups.", + ) parser.add_argument( "--init_checkpoint", type=Path, @@ -862,6 +970,17 @@ def main() -> None: ) parser.add_argument("--K_max", type=int, default=10) parser.add_argument("--curriculum_steps", type=int, default=25_000) + parser.add_argument( + "--grad_checkpoint_every", type=int, default=10, + help="Gradient checkpointing group size for the K-step rollout. 
" + "0 = disabled (full activation memory). >= k_steps = single " + "checkpoint group covering the entire rollout (recommended for " + "K_max=10: pass 10). Activations within the group are discarded " + "after forward and recomputed during backward (~33%% step-time " + "penalty in exchange for ~K-fold less activation memory). " + "Values 0 < g < k_steps would need per-group chunking (matching " + "stage 2 extended); not yet supported here.", + ) # Loss weights — Stage 2b specific. parser.add_argument("--mae_weight", type=float, default=1.0) @@ -920,6 +1039,7 @@ def main() -> None: ) if dm.is_main: args.checkpoint_dir.mkdir(parents=True, exist_ok=True) + args.lengths_cache_dir.mkdir(parents=True, exist_ok=True) dm.barrier() train_files, val_files = resolve_shot_files( @@ -933,11 +1053,11 @@ def main() -> None: n_train_pre, n_val_pre = len(train_files), len(val_files) train_files = filter_video_present_files( train_files, args.use_video, - cache_path=args.checkpoint_dir / "video_present_train.pt", + cache_path=args.lengths_cache_dir / "video_present_train.pt", ) val_files = filter_video_present_files( val_files, args.use_video, - cache_path=args.checkpoint_dir / "video_present_val.pt", + cache_path=args.lengths_cache_dir / "video_present_val.pt", ) logger.info( f"Video-presence filter ({args.use_video}): " @@ -1030,18 +1150,30 @@ def main() -> None: ) train_ds = TokamakMultiFileDataset( train_files, - lengths_cache_path=args.checkpoint_dir / "lengths_e2e_stage2_delta_train.pt", + lengths_cache_path=args.lengths_cache_dir / "lengths_e2e_stage2_delta_train.pt", **shared, ) val_ds = TokamakMultiFileDataset( val_files, - lengths_cache_path=args.checkpoint_dir / "lengths_e2e_stage2_delta_val.pt", + lengths_cache_path=args.lengths_cache_dir / "lengths_e2e_stage2_delta_val.pt", **shared, ) logger.info( f"Chunks — train: {len(train_ds)} val: {len(val_ds)} " f"prediction_horizon_s={prediction_horizon_s:.3f} (K_max={args.K_max})" ) + + # Per-worker OMP_NUM_THREADS enforcement: 
with --cpus-per-task=7 in + # the SLURM script and 6 DataLoader workers per rank, default torch + # thread heuristics can oversubscribe (each worker spawning 7 OMP + # threads → 42 threads competing for 7 cores). Match the value the + # parent process saw via OMP_NUM_THREADS (set to 1 in + # _frontier_settings.sh). + def _worker_init(_worker_id: int) -> None: + import os as _os + n = int(_os.environ.get("OMP_NUM_THREADS", "1")) + torch.set_num_threads(n) + train_loader = DataLoader( train_ds, batch_size=args.batch_size, # TwoLevelSampler: shuffle file order per epoch, sequential @@ -1050,8 +1182,14 @@ def main() -> None: # RandomSampler across 7878 files gave ~1% hit rate and # spent ~10% of worker time on HDF5 file opens (observed # via py-spy on Stage 1 job 2719669). + # DistributedTwoLevelSampler is the DDP-aware sibling: each + # rank owns a fixed slice of the file list and iterates its + # own files front-to-back, so the per-worker LRU stays warm + # across epochs. PyTorch's DistributedSampler shards chunk + # indices instead and was observed to push step time from + # ~1 s to ~12 s under 2-GPU DDP on Stage 1. sampler=( - DistributedSampler( + DistributedTwoLevelSampler( train_ds, num_replicas=dm.world_size, rank=dm.rank, @@ -1063,20 +1201,43 @@ def main() -> None: else TwoLevelSampler(train_ds, shuffle=True) ), num_workers=args.num_workers, collate_fn=collate_fn, drop_last=True, + # prefetch_factor=3 + val_num_workers=4 is the v9-validated config + # at batch=8 (RAM ~68% steady, ~75% val-overlap peak — comfortable + # under the 502 GB cap). Larger batch needs revisiting via the + # empirical model: variable cost ≈ num_workers × prefetch × + # batch × ~1.3 GB. 
+ prefetch_factor=3, pin_memory=device.type == "cuda", persistent_workers=args.num_workers > 0, + worker_init_fn=_worker_init, ) + # Val sampler mirrors the train sampler's DDP pattern: shard files + # across ranks so each rank evaluates ~1/world_size of the val set, + # then sums + counts are all_reduce'd inside validate() (see below). + if dm.distributed: + val_sampler = DistributedTwoLevelSampler( + val_ds, num_replicas=dm.world_size, rank=dm.rank, + shuffle=False, seed=args.seed, drop_last=True, + ) + else: + val_sampler = TwoLevelSampler(val_ds, shuffle=False) + # Val loader memory budget (ported from Stage 1 OOM testing): + # train workers stay alive during val (persistent=True on train) and + # hold their prefetched batches. Capping val to + # num_workers=min(4, args.num_workers), prefetch_factor=1, and + # persistent_workers=False keeps the combined in-flight footprint + # under the 502 GB node budget. Without this we OOM'd at 97% host + # RAM on 2-node smokes when val workers spun up alongside the train + # 6×2 prefetch pool. + val_num_workers = min(4, args.num_workers) val_loader = DataLoader( - val_ds, batch_size=args.batch_size, shuffle=False, - num_workers=args.num_workers, collate_fn=collate_fn, drop_last=True, - # pin_memory=False for val: each iter() call re-creates the main - # process's pin_memory thread + internal queues, and those pinned - # allocations ratchet host RSS upward across validations (observed - # +127 GB on val 1, +27 GB on val 2 with persistent_workers=True, - # OOM on val 2 at batch=256). Val is 1–20 batches per call so the - # synchronous H2D cost is negligible. 
+ val_ds, batch_size=args.batch_size, + sampler=val_sampler, + num_workers=val_num_workers, collate_fn=collate_fn, drop_last=True, + prefetch_factor=1, pin_memory=False, - persistent_workers=args.num_workers > 0, + persistent_workers=False, + worker_init_fn=_worker_init, ) opt = torch.optim.AdamW( @@ -1165,6 +1326,7 @@ def amp_ctx_factory(): video_diag_names=video_diag_names, video_n_frames=video_n_frames, spectro_diag_names=spectro_diag_names, + grad_checkpoint_every=args.grad_checkpoint_every, ) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.grad_clip) diff --git a/scripts/training/train_e2e_stage2_extended.py b/scripts/training/train_e2e_stage2_extended.py index ed01b51..3946ac4 100644 --- a/scripts/training/train_e2e_stage2_extended.py +++ b/scripts/training/train_e2e_stage2_extended.py @@ -59,6 +59,7 @@ from tokamak_foundation_model.data.data_loader import collate_fn from tokamak_foundation_model.data.multi_file_dataset import ( + DistributedTwoLevelSampler, TokamakMultiFileDataset, TwoLevelSampler, filter_video_present_files, @@ -71,7 +72,6 @@ ) from tokamak_foundation_model.e2e.rollout import TokenSpaceRollout from tokamak_foundation_model.utils.distributed import DistributedManager -from torch.utils.data.distributed import DistributedSampler from torch.nn.parallel import DistributedDataParallel as _DDP from tokamak_foundation_model.e2e.multimodal import ( @@ -895,7 +895,21 @@ def validate( # output reports them as NaN (counts[k][name]["disp"] # never advances). mae = masked_mae(pred, target, mask).item() - copy_mae = masked_mae(diag_initial[name], target, mask).item() + # Spectrogram diag_initial holds the full STFT output + # (e.g. 98 frames) while target is sliced to trunc_t + # (e.g. 96) by split_spectro_target_by_step. Truncate + # the copy baseline to match so masked_mae's + # broadcast doesn't blow up. Video shapes already + # agree. 
+ if name in spectro_set: + baseline_input = diag_initial[name][ + ..., : spectro_trunc_t_map[name] + ] + else: + baseline_input = diag_initial[name] + copy_mae = masked_mae( + baseline_input, target, mask + ).item() sums[k][name]["model_mae"] += mae sums[k][name]["copy_mae"] += copy_mae counts[k][name]["mae"] += 1 @@ -1334,8 +1348,14 @@ def forward( # RandomSampler across 7878 files gave ~1% hit rate and # spent ~10% of worker time on HDF5 file opens (observed # via py-spy on Stage 1 job 2719669). + # DistributedTwoLevelSampler is the DDP-aware sibling: each + # rank owns a fixed slice of the file list and iterates its + # own files front-to-back, so the per-worker LRU stays warm + # across epochs. PyTorch's DistributedSampler shards chunk + # indices instead and was observed to push step time from + # ~1 s to ~12 s under 2-GPU DDP on Stage 1. sampler=( - DistributedSampler( + DistributedTwoLevelSampler( train_ds, num_replicas=dm.world_size, rank=dm.rank, diff --git a/scripts/training/train_e2e_stage3.py b/scripts/training/train_e2e_stage3.py index 92fd4a9..f93bb14 100644 --- a/scripts/training/train_e2e_stage3.py +++ b/scripts/training/train_e2e_stage3.py @@ -582,6 +582,11 @@ def main() -> None: parser.add_argument("--d_model", type=int, default=256) parser.add_argument("--n_layers", type=int, default=8) parser.add_argument("--n_heads", type=int, default=8) + parser.add_argument( + "--gradient_checkpoint", action="store_true", + help="Recompute backbone-block activations during backward. 
Costs " + "~30%% extra compute; needed for deeper / wider rollouts.", + ) parser.add_argument("--dropout", type=float, default=0.1) # LoRA @@ -717,6 +722,7 @@ def main() -> None: diagnostics=diagnostics, actuators=actuators, d_model=args.d_model, n_heads=args.n_heads, n_layers=args.n_layers, dropout=args.dropout, + gradient_checkpoint=args.gradient_checkpoint, ).to(device) if args.init_checkpoint is not None: diff --git a/src/tokamak_foundation_model/data/multi_file_dataset.py b/src/tokamak_foundation_model/data/multi_file_dataset.py index 81a83fc..7785f35 100644 --- a/src/tokamak_foundation_model/data/multi_file_dataset.py +++ b/src/tokamak_foundation_model/data/multi_file_dataset.py @@ -207,6 +207,12 @@ def _load_or_compute_lengths( """ Return per-file chunk counts, loading from cache when available. + Under DDP only rank 0 reads/computes/writes the cache; all other + ranks receive the result via ``dist.broadcast_object_list``. This + avoids 8 ranks hammering the Lustre MDS with redundant scans and + prevents concurrent ``torch.save`` calls from corrupting the + sidecar zip file. + Parameters ---------- max_duration_s : float @@ -215,7 +221,8 @@ def _load_or_compute_lengths( Path to the sidecar cache file. If the file exists *and* its stored path list matches the current ``hdf5_paths``, the cached lengths are returned directly without opening any HDF5 file. - Otherwise lengths are computed and written to this path. + Otherwise lengths are computed and written to this path + atomically (``.tmp`` + ``replace``). Returns ------- @@ -223,50 +230,72 @@ def _load_or_compute_lengths( Number of chunks for each path in ``self.hdf5_paths``. Files that could not be opened have length ``0``. 
""" - paths_as_str = [str(p) for p in self.hdf5_paths] - - if lengths_cache_path is not None: - cache_path = Path(lengths_cache_path) - if cache_path.exists(): - cache = torch.load(cache_path, weights_only=False) - if cache.get("paths") == paths_as_str: - print(f"Loaded file lengths from cache: {cache_path}") - return cache["lengths"] + import torch.distributed as dist + distributed = dist.is_available() and dist.is_initialized() + rank = dist.get_rank() if distributed else 0 - lengths = [] - for path in tqdm(self.hdf5_paths, desc="Computing file lengths"): - try: - with h5py.File(path, "r") as f: - duration = min(self._compute_duration(f), max_duration_s) - # Subtract warmup: usable duration starts after warmup_s - duration = duration - self.warmup_s - if duration <= 0.0: - length = 0 - elif self.prediction_mode: - total_window = ( - self.chunk_duration_s + self.prediction_horizon_s - ) - length = max(0, int(np.floor( - (duration - total_window) / self.step_size_s - )) + 1) - else: - if duration < self.chunk_duration_s: + paths_as_str = [str(p) for p in self.hdf5_paths] + lengths: Optional[list[int]] = None + + if rank == 0: + if lengths_cache_path is not None: + cache_path = Path(lengths_cache_path) + if cache_path.exists(): + try: + cache = torch.load(cache_path, weights_only=False) + if cache.get("paths") == paths_as_str: + print(f"Loaded file lengths from cache: {cache_path}") + lengths = cache["lengths"] + except Exception as e: + print( + f"Warning: lengths cache at {cache_path} is " + f"unreadable ({e}); recomputing." 
+ ) + + if lengths is None: + lengths = [] + for path in tqdm(self.hdf5_paths, desc="Computing file lengths"): + try: + with h5py.File(path, "r") as f: + duration = min(self._compute_duration(f), max_duration_s) + # Subtract warmup: usable duration starts after warmup_s + duration = duration - self.warmup_s + if duration <= 0.0: + length = 0 + elif self.prediction_mode: + total_window = ( + self.chunk_duration_s + self.prediction_horizon_s + ) + length = max(0, int(np.floor( + (duration - total_window) / self.step_size_s + )) + 1) + else: + if duration < self.chunk_duration_s: + length = 0 + else: + length = int(np.floor( + (duration - self.chunk_duration_s) / self.step_size_s + )) + 1 + except OSError as e: + print(f"Warning: could not open {path}: {e}") length = 0 - else: - length = int(np.floor( - (duration - self.chunk_duration_s) / self.step_size_s - )) + 1 - except OSError as e: - print(f"Warning: could not open {path}: {e}") - length = 0 - lengths.append(length) - - if lengths_cache_path is not None: - torch.save( - {"paths": paths_as_str, "lengths": lengths}, - lengths_cache_path - ) - print(f"Saved file lengths to cache: {lengths_cache_path}") + lengths.append(length) + + if lengths_cache_path is not None: + # Atomic write: write to .tmp then rename, so a crashed + # write never leaves a half-written zip that the next + # torch.load would barf on. 
+ tmp_path = Path(str(lengths_cache_path) + ".tmp") + torch.save( + {"paths": paths_as_str, "lengths": lengths}, tmp_path, + ) + tmp_path.replace(Path(lengths_cache_path)) + print(f"Saved file lengths to cache: {lengths_cache_path}") + + if distributed: + payload = [lengths] if rank == 0 else [None] + dist.broadcast_object_list(payload, src=0) + lengths = payload[0] return lengths @@ -444,6 +473,148 @@ def __iter__(self): yield from range(start, end) +# ============================================================================= +# DDP-aware two-level sampler (file-level sharding) +# ============================================================================= + + +class DistributedTwoLevelSampler(Sampler): + """ + DDP-aware file-level sharding with sequential intra-file iteration. + + Combines :class:`TwoLevelSampler`'s file-sequential locality with + DDP-aware sharding. The file list is partitioned across ranks **once** + at construction (round-robin: rank ``r`` owns positions + ``r, r + N, r + 2N, …``). Each rank then iterates **its own** files, + front-to-back within each file, with per-epoch shuffling of the + rank's own file order via :meth:`set_epoch`. + + Why this matters + ---------------- + PyTorch's :class:`~torch.utils.data.distributed.DistributedSampler` + shards *chunk indices* across ranks, which scatters each rank's + accesses across the entire file pool and defeats the per-worker LRU + file-handle cache in :class:`TokamakMultiFileDataset`. On the live + DIII-D dataset (~7900 shots, LRU=100) this collapses cache hit rate + to ~1 % and makes HDF5 ``open()`` the dominant per-step cost under + DDP (observed ~12 s/step on a 2-GPU DDP run vs. ~1 s/step single-GPU + at the same batch size). + + Static (vs. rotated) sharding + ----------------------------- + The file-to-rank assignment is fixed for the lifetime of the + sampler. Each rank only ever sees its own subset of files. 
This + keeps the LRU file-handle cache warm across epochs (especially with + ``persistent_workers=True``). For many-epoch training the cross-rank + data diversity that rotated sharding would buy is dominated by + within-rank re-exposure; use PyTorch's ``DistributedSampler`` if + you'd rather have every rank eventually see every file at the cost + of cache locality. + + Length parity across ranks + -------------------------- + File sizes vary; per-rank totals may differ. Every rank truncates to + the minimum per-rank chunk count so DDP all-reduce stays in + lockstep. Padding (``drop_last=False``) is not supported. + + Parameters + ---------- + dataset : TokamakMultiFileDataset + Dataset with ``_valid_lengths`` and ``_cumulative_lengths``. + num_replicas : int + World size. + rank : int + This rank's index in ``[0, num_replicas)``. + shuffle : bool, optional + Per-epoch shuffle of the rank's own file order. Default + ``True``. + seed : int, optional + RNG seed. The per-epoch RNG uses ``seed + epoch``. Default ``0``. + drop_last : bool, optional + Must be ``True``. Present for API compatibility with + ``DistributedSampler``. Default ``True``. + """ + + def __init__( + self, + dataset: "TokamakMultiFileDataset", + num_replicas: int, + rank: int, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = True, + ) -> None: + if num_replicas < 1: + raise ValueError(f"num_replicas must be >= 1, got {num_replicas}") + if not (0 <= rank < num_replicas): + raise ValueError( + f"rank {rank} not in [0, num_replicas={num_replicas})" + ) + n_files = len(dataset._valid_lengths) + if num_replicas > n_files: + raise ValueError( + f"num_replicas={num_replicas} exceeds n_files={n_files}; " + f"cannot shard." + ) + if not drop_last: + raise NotImplementedError( + "drop_last=False (padded sampling) is not supported. " + "Pass drop_last=True so every rank sees the same number " + "of samples per epoch." 
+ ) + + self.dataset = dataset + self.num_replicas = int(num_replicas) + self.rank = int(rank) + self.shuffle = bool(shuffle) + self.seed = int(seed) + self.drop_last = True + self.epoch = 0 + + # Static round-robin partition of the *valid* file list. + self._rank_file_positions: list[int] = list( + range(self.rank, n_files, self.num_replicas) + ) + + # Pre-compute equal per-rank chunk count = min over ranks. + per_rank_totals = [ + sum(int(dataset._valid_lengths[p]) + for p in range(r, n_files, self.num_replicas)) + for r in range(self.num_replicas) + ] + self._num_samples = min(per_rank_totals) + + def set_epoch(self, epoch: int) -> None: + """Set the epoch used to seed per-epoch shuffles. Mirrors + :meth:`torch.utils.data.distributed.DistributedSampler.set_epoch`. + Call once per training epoch before iterating.""" + self.epoch = int(epoch) + + def __len__(self) -> int: + return self._num_samples + + def __iter__(self): + if self.shuffle: + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + perm = torch.randperm( + len(self._rank_file_positions), generator=g, + ).tolist() + file_order = [self._rank_file_positions[i] for i in perm] + else: + file_order = list(self._rank_file_positions) + + yielded = 0 + for pos in file_order: + start = int(self.dataset._cumulative_lengths[pos]) + end = int(self.dataset._cumulative_lengths[pos + 1]) + for chunk_idx in range(start, end): + if yielded >= self._num_samples: + return + yield chunk_idx + yielded += 1 + + # ============================================================================= # Convenience factory # ============================================================================= @@ -527,57 +698,70 @@ def filter_video_present_files( The subset of ``paths`` with at least one camera present. Order is preserved. 
""" + import torch.distributed as dist + distributed = dist.is_available() and dist.is_initialized() + rank = dist.get_rank() if distributed else 0 + paths_key = tuple(str(p) for p in paths) cameras_key = tuple(sorted(camera_names)) + video_present: Optional[list[str]] = None - if cache_path is not None and cache_path.exists(): - try: - cache = torch.load(cache_path, weights_only=False) - if ( - cache.get("paths_key") == paths_key - and cache.get("cameras_key") == cameras_key - ): - present = set(cache["video_present"]) - return [p for p in paths if str(p) in present] - except Exception: - # Corrupt or unreadable cache — fall through to rescan. - pass - - print( - f"Scanning {len(paths)} files for {cameras_key} video presence " - "(cache miss)..." - ) - video_present: list[str] = [] - for p in tqdm(paths, desc="Video presence scan"): - try: - with h5py.File(p, "r") as f: - for cam in camera_names: - if cam not in f or "ydata" not in f[cam]: - continue - yd = f[cam]["ydata"] - xd = f[cam].get("xdata") - if ( - yd.size > 0 - and yd.ndim == 4 - and xd is not None - and xd.size >= 2 - ): - video_present.append(str(p)) - break - except Exception as e: - print(f" skipping {p.name}: {e}") - - if cache_path is not None: - cache_path.parent.mkdir(parents=True, exist_ok=True) - torch.save( - { - "paths_key": paths_key, - "cameras_key": cameras_key, - "video_present": video_present, - }, - cache_path, - ) - print(f"Saved video-presence cache to {cache_path}") + if rank == 0: + if cache_path is not None and cache_path.exists(): + try: + cache = torch.load(cache_path, weights_only=False) + if ( + cache.get("paths_key") == paths_key + and cache.get("cameras_key") == cameras_key + ): + video_present = list(cache["video_present"]) + except Exception: + # Corrupt or unreadable cache — fall through to rescan. + video_present = None + + if video_present is None: + print( + f"Scanning {len(paths)} files for {cameras_key} video presence " + "(cache miss)..." 
+ ) + video_present = [] + for p in tqdm(paths, desc="Video presence scan"): + try: + with h5py.File(p, "r") as f: + for cam in camera_names: + if cam not in f or "ydata" not in f[cam]: + continue + yd = f[cam]["ydata"] + xd = f[cam].get("xdata") + if ( + yd.size > 0 + and yd.ndim == 4 + and xd is not None + and xd.size >= 2 + ): + video_present.append(str(p)) + break + except Exception as e: + print(f" skipping {p.name}: {e}") + + if cache_path is not None: + cache_path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = Path(str(cache_path) + ".tmp") + torch.save( + { + "paths_key": paths_key, + "cameras_key": cameras_key, + "video_present": video_present, + }, + tmp_path, + ) + tmp_path.replace(Path(cache_path)) + print(f"Saved video-presence cache to {cache_path}") + + if distributed: + payload = [video_present] if rank == 0 else [None] + dist.broadcast_object_list(payload, src=0) + video_present = payload[0] present = set(video_present) return [p for p in paths if str(p) in present] diff --git a/src/tokamak_foundation_model/e2e/backbone.py b/src/tokamak_foundation_model/e2e/backbone.py index c113590..3cdba1c 100644 --- a/src/tokamak_foundation_model/e2e/backbone.py +++ b/src/tokamak_foundation_model/e2e/backbone.py @@ -7,10 +7,17 @@ """ import math -from typing import List, Optional, Union, cast +from typing import List, Optional, Tuple, Union, cast import torch import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint + +try: + from flash_attn.modules.mha import MHA as _FlashMHA +except ImportError: + _FlashMHA = None def _fourier_features(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: @@ -63,6 +70,89 @@ def forward( return self.mlp(torch.cat([step_feats, time_feats], dim=-1)) +class FlashSelfAttention(nn.Module): + """flash_attn MHA wrapped to match nn.MultiheadAttention's self-attn call. + + BackboneBlock calls ``self.attn(h, h, h, need_weights=False)`` and + unpacks ``attn_out, _``. 
We mimic that signature; only self-attention + (q is k is v) is supported. Requires fp16/bf16 inputs at runtime — + the training script's bf16 autocast satisfies this. + """ + + def __init__(self, d_model: int, n_heads: int, dropout: float = 0.0) -> None: + super().__init__() + if _FlashMHA is None: + raise ImportError( + "flash_attn not installed; build it via " + "`pixi run -e frontier setup-flash-attn`" + ) + self.mha = _FlashMHA( + embed_dim=d_model, + num_heads=n_heads, + dropout=dropout, + causal=False, + ) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + need_weights: bool = False, + ) -> Tuple[torch.Tensor, None]: + del k, v, need_weights + return self.mha(q), None + + +class SDPASelfAttention(nn.Module): + """Self-attention via ``F.scaled_dot_product_attention``. + + Drop-in for ``nn.MultiheadAttention(h, h, h, need_weights=False)`` but + routes through PyTorch's SDPA, which on ROCm 7.x dispatches to AOTriton + flash-attention. Empirical wins over ``nn.MultiheadAttention`` on MI250X: + 1.4-5× attention speedup, 2-3× lower attention memory. + """ + + def __init__(self, d_model: int, n_heads: int, dropout: float = 0.0) -> None: + super().__init__() + assert d_model % n_heads == 0, ( + f"d_model={d_model} must be divisible by n_heads={n_heads}" + ) + self.n_heads = n_heads + self.head_dim = d_model // n_heads + # Fused QKV projection — single matmul, matches what nn.MultiheadAttention + # does internally but keeps the weight name distinct so a switch + # between attn_impls never silently loads a wrong-shaped checkpoint. + self.qkv = nn.Linear(d_model, 3 * d_model, bias=True) + self.out_proj = nn.Linear(d_model, d_model, bias=True) + self.dropout_p = dropout + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + need_weights: bool = False, + ) -> Tuple[torch.Tensor, None]: + # Self-attention path: BackboneBlock calls self.attn(h, h, h, ...) 
+ del k, v, need_weights + B, S, D = q.shape + # (B, S, 3*D) -> (B, S, 3, H, D_head) -> (3, B, H, S, D_head) + qkv = self.qkv(q).reshape(B, S, 3, self.n_heads, self.head_dim) + qkv = qkv.permute(2, 0, 3, 1, 4) + q_, k_, v_ = qkv[0], qkv[1], qkv[2] + out = F.scaled_dot_product_attention( + q_, k_, v_, + dropout_p=self.dropout_p if self.training else 0.0, + is_causal=False, + ) + # (B, H, S, D_head) -> (B, S, D) + out = out.transpose(1, 2).reshape(B, S, D) + return self.out_proj(out), None + + class BackboneBlock(nn.Module): """Pre-norm Transformer encoder block: norm→attn→residual, norm→MLP→residual.""" @@ -72,12 +162,23 @@ def __init__( n_heads: int, mlp_ratio: float = 4.0, dropout: float = 0.0, + attn_impl: str = "standard", ) -> None: super().__init__() self.norm1 = nn.LayerNorm(d_model) - self.attn = nn.MultiheadAttention( - d_model, n_heads, dropout=dropout, batch_first=True - ) + if attn_impl == "flash": + self.attn = FlashSelfAttention(d_model, n_heads, dropout=dropout) + elif attn_impl == "sdpa": + self.attn = SDPASelfAttention(d_model, n_heads, dropout=dropout) + elif attn_impl == "standard": + self.attn = nn.MultiheadAttention( + d_model, n_heads, dropout=dropout, batch_first=True + ) + else: + raise ValueError( + f"attn_impl must be 'standard', 'sdpa', or 'flash', got " + f"{attn_impl!r}" + ) self.norm2 = nn.LayerNorm(d_model) hidden = int(d_model * mlp_ratio) self.mlp = nn.Sequential( @@ -121,14 +222,17 @@ def __init__( n_layers: int = 8, mlp_ratio: float = 4.0, dropout: float = 0.0, + attn_impl: str = "standard", + gradient_checkpoint: bool = False, ) -> None: super().__init__() self.d_model = d_model self.n_layers = n_layers + self.gradient_checkpoint = gradient_checkpoint self.step_cond = StepConditioning(d_model) self.blocks = nn.ModuleList( [ - BackboneBlock(d_model, n_heads, mlp_ratio, dropout) + BackboneBlock(d_model, n_heads, mlp_ratio, dropout, attn_impl=attn_impl) for _ in range(n_layers) ] ) @@ -160,12 +264,21 @@ def forward( step_embed = 
self.step_cond(step_index, time_offset_s).unsqueeze(1) x = tokens + step_embed if return_intermediates: + # Intermediates path keeps every block's output anyway, so + # checkpointing would defeat its purpose — disable here. intermediates: List[torch.Tensor] = [x] for block in self.blocks: x = block(x) intermediates.append(x) intermediates.append(self.final_norm(x)) return intermediates + # Gradient checkpointing recomputes each block's activations during + # backward instead of storing them. Only active during training + # (no-op under inference / no_grad) so eval cost is unchanged. + use_ckpt = self.gradient_checkpoint and self.training and torch.is_grad_enabled() for block in self.blocks: - x = block(x) + if use_ckpt: + x = checkpoint(block, x, use_reentrant=False) + else: + x = block(x) return self.final_norm(x) \ No newline at end of file diff --git a/src/tokamak_foundation_model/e2e/model.py b/src/tokamak_foundation_model/e2e/model.py index 41d6456..f492e1c 100644 --- a/src/tokamak_foundation_model/e2e/model.py +++ b/src/tokamak_foundation_model/e2e/model.py @@ -172,6 +172,8 @@ def __init__( n_layers: int = 8, mlp_ratio: float = 4.0, dropout: float = 0.0, + attn_impl: str = "standard", + gradient_checkpoint: bool = False, ) -> None: super().__init__() self.diagnostics = list(diagnostics) @@ -271,6 +273,8 @@ def __init__( n_layers=n_layers, mlp_ratio=mlp_ratio, dropout=dropout, + attn_impl=attn_impl, + gradient_checkpoint=gradient_checkpoint, ) def tokenize( diff --git a/src/tokamak_foundation_model/e2e/output_heads.py b/src/tokamak_foundation_model/e2e/output_heads.py index d519adc..84ba42e 100644 --- a/src/tokamak_foundation_model/e2e/output_heads.py +++ b/src/tokamak_foundation_model/e2e/output_heads.py @@ -98,12 +98,36 @@ def __init__( self.patch_size = patch_size self.n_patches = window_samples // patch_size + # Post-deconv inverse-stem at sample resolution, mirroring the + # tokenizer's pre-patch stem. 
The deconv first lifts each token back + # to ``stem_channels × patch_size`` samples; the inverse stem then + # refines the per-sample reconstruction with two small-kernel convs, + # giving the head the capacity to recover sharp features (spikes, + # bursts) the linear deconv alone smooths over. + stem_channels = 64 self.deconv = nn.ConvTranspose1d( in_channels=d_model, - out_channels=1, + out_channels=stem_channels, kernel_size=patch_size, stride=patch_size, ) + self.inv_stem = nn.Sequential( + nn.Conv1d(stem_channels, stem_channels, kernel_size=3, padding=1), + nn.GELU(), + nn.Conv1d(stem_channels, 1, kernel_size=3, padding=1), + ) + + # Pre-unembed per-token MLP refiners (mirror of the tokenizer's). + n_refine_blocks = 2 + self.refine = nn.ModuleList([ + nn.Sequential( + nn.LayerNorm(d_model), + nn.Linear(d_model, d_model * 4), + nn.GELU(), + nn.Linear(d_model * 4, d_model), + ) + for _ in range(n_refine_blocks) + ]) def forward(self, tokens: torch.Tensor) -> torch.Tensor: """Reconstruct raw signal. @@ -120,10 +144,13 @@ def forward(self, tokens: torch.Tensor) -> torch.Tensor: ``(batch, n_channels, window_samples)`` raw-signal reconstruction. """ batch = tokens.shape[0] + for block in self.refine: + tokens = tokens + block(tokens) t = tokens.reshape(batch, self.n_channels, self.n_patches, self.d_model) t = t.reshape(batch * self.n_channels, self.n_patches, self.d_model) t = t.transpose(1, 2) # (B*C, d_model, n_patches) - out = self.deconv(t) # (B*C, 1, window_samples) + out = self.deconv(t) # (B*C, stem_channels, window_samples) + out = self.inv_stem(out) # (B*C, 1, window_samples) return out.reshape(batch, self.n_channels, self.window_samples) @@ -266,6 +293,18 @@ def __init__( self.n_patches_f = n_patches_f self.n_patches_t = n_patches_t + # Pre-unembed per-token MLP refiners (mirror of the tokenizer's). 
+ n_refine_blocks = 4 + self.refine = nn.ModuleList([ + nn.Sequential( + nn.LayerNorm(d_model), + nn.Linear(d_model, d_model * 4), + nn.GELU(), + nn.Linear(d_model * 4, d_model), + ) + for _ in range(n_refine_blocks) + ]) + # Inverse of the tokenizer's patch Conv2d. self.patch_unembed = nn.ConvTranspose2d( d_model, @@ -278,6 +317,8 @@ def forward(self, tokens: torch.Tensor) -> torch.Tensor: """``(B, n_tokens, d_model) -> (B, n_channels, freq_bins, n_patches_t * patch_t)``.""" B = tokens.shape[0] + for block in self.refine: + tokens = tokens + block(tokens) # (B, n_tokens, d_model) -> (B, d_model, n_patches_f, n_patches_t). # The flatten order in the tokenizer is (n_patches_f, n_patches_t) # row-major (n_patches_f slow, n_patches_t fast), so we reshape diff --git a/src/tokamak_foundation_model/e2e/tokenizers/fast_time_series.py b/src/tokamak_foundation_model/e2e/tokenizers/fast_time_series.py index bcb3355..b602414 100644 --- a/src/tokamak_foundation_model/e2e/tokenizers/fast_time_series.py +++ b/src/tokamak_foundation_model/e2e/tokenizers/fast_time_series.py @@ -57,8 +57,20 @@ def __init__( self.patch_size = patch_size self.n_patches = window_samples // patch_size + # Pre-patch convolutional stem at sample resolution. Two small-kernel + # convs lift the per-sample representation to ``stem_channels`` before + # the patch-stride embedding, so sharp local features (spikes, bursts) + # are captured before the lossy 50-sample downsample. 
+ stem_channels = 64 + self.stem = nn.Sequential( + nn.Conv1d(1, stem_channels, kernel_size=3, padding=1), + nn.GELU(), + nn.Conv1d(stem_channels, stem_channels, kernel_size=3, padding=1), + nn.GELU(), + ) + self.conv = nn.Conv1d( - in_channels=1, + in_channels=stem_channels, out_channels=d_model, kernel_size=patch_size, stride=patch_size, @@ -66,6 +78,20 @@ def __init__( self.channel_pos = nn.Parameter(torch.empty(n_channels, d_model)) self.patch_pos = nn.Parameter(torch.empty(self.n_patches, d_model)) self.modality_embed = nn.Parameter(torch.empty(d_model)) + + # Pre-backbone per-token MLP refiners (stacked ViT-style residual + # MLP blocks). Two blocks here (same block design as the spectrogram pathway, which stacks four). + n_refine_blocks = 2 + self.refine = nn.ModuleList([ + nn.Sequential( + nn.LayerNorm(d_model), + nn.Linear(d_model, d_model * 4), + nn.GELU(), + nn.Linear(d_model * 4, d_model), + ) + for _ in range(n_refine_blocks) + ]) + nn.init.normal_(self.channel_pos, std=0.02) nn.init.normal_(self.patch_pos, std=0.02) nn.init.normal_(self.modality_embed, std=0.02) @@ -86,6 +112,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ batch = x.shape[0] x_flat = x.reshape(batch * self.n_channels, 1, self.window_samples) + x_flat = self.stem(x_flat) # (B*C, stem_channels, window_samples) patches = self.conv(x_flat) # (B*C, d_model, n_patches) patches = patches.transpose(1, 2) # (B*C, n_patches, d_model) patches = patches.reshape( @@ -94,6 +121,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: patches = patches + self.patch_pos patches = patches + self.channel_pos.unsqueeze(1) patches = patches + self.modality_embed - return patches.reshape( + tokens = patches.reshape( batch, self.n_channels * self.n_patches, self.d_model ) + for block in self.refine: + tokens = tokens + block(tokens) + return tokens diff --git a/src/tokamak_foundation_model/e2e/tokenizers/spectrogram.py b/src/tokamak_foundation_model/e2e/tokenizers/spectrogram.py index 3e368e0..ccb1225 100644 --- 
a/src/tokamak_foundation_model/e2e/tokenizers/spectrogram.py +++ b/src/tokamak_foundation_model/e2e/tokenizers/spectrogram.py @@ -99,6 +99,20 @@ def __init__( # (per-batch ``mask=False``). Same pattern as VideoTokenizer. self.missing_token = nn.Parameter(torch.empty(self.n_tokens, d_model)) + # Pre-backbone per-token MLP refiners (stacked ViT-style residual MLP + # blocks). Each block is independently applied with a residual at the + # call site so adding/removing blocks is a single-line change. + n_refine_blocks = 4 + self.refine = nn.ModuleList([ + nn.Sequential( + nn.LayerNorm(d_model), + nn.Linear(d_model, d_model * 4), + nn.GELU(), + nn.Linear(d_model * 4, d_model), + ) + for _ in range(n_refine_blocks) + ]) + nn.init.normal_(self.spatial_pe, std=0.02) nn.init.normal_(self.modality_embed, std=0.02) nn.init.normal_(self.missing_token, std=0.02) @@ -109,7 +123,10 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor: x = x[..., : self.trunc_t] # (B, C, F, T_trunc) tokens = self.proj(x) # (B, d_model, n_f, n_t) tokens = tokens.flatten(2).transpose(1, 2) # (B, n_tokens, d_model) - return tokens + self.spatial_pe + self.modality_embed + tokens = tokens + self.spatial_pe + self.modality_embed + for block in self.refine: + tokens = tokens + block(tokens) + return tokens def forward( self, x: torch.Tensor, mask: torch.Tensor | None = None @@ -130,10 +147,15 @@ def forward( torch.Tensor Tokens of shape ``(B, n_tokens, d_model)``. """ + # Always invoke _encode and reference missing_token so the autograd + # graph for proj / spatial_pe / modality_embed / missing_token is + # data-independent. Lets us run DDP without `find_unused_parameters` + # (RCCL bucket rebuilds on a per-batch-changing unused-set were + # causing GPU memory faults on Frontier). Extra cost: a Conv2d on + # the masked-out rows; small relative to the backbone transformer. 
B = x.shape[0] - if mask is None or mask.all(): - return self._encode(x) - out = self.missing_token.expand(B, -1, -1).clone() - if mask.any(): - out[mask] = self._encode(x[mask]) - return out + encoded = self._encode(x) + missing = self.missing_token.expand(B, -1, -1) + if mask is None: + return encoded + 0.0 * missing.sum() + return torch.where(mask.view(B, 1, 1), encoded, missing) diff --git a/src/tokamak_foundation_model/e2e/tokenizers/video.py b/src/tokamak_foundation_model/e2e/tokenizers/video.py index 0a44064..3ae5143 100644 --- a/src/tokamak_foundation_model/e2e/tokenizers/video.py +++ b/src/tokamak_foundation_model/e2e/tokenizers/video.py @@ -134,10 +134,16 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor: def forward( self, x: torch.Tensor, mask: torch.Tensor | None = None ) -> torch.Tensor: + # Always invoke _encode and reference missing_token so the autograd + # graph for patch_embed / spatial_pe / modality_emb / missing_token + # is data-independent. Lets us run DDP without + # `find_unused_parameters` (RCCL bucket rebuilds on a per-batch- + # changing unused-set were causing GPU memory faults on Frontier). + # Extra cost: a Conv3d on the masked-out rows; minor relative to + # the backbone transformer. 
B = x.shape[0] - if mask is None or mask.all(): - return self._encode(x) - out = self.missing_token.expand(B, -1, -1).clone() - if mask.any(): - out[mask] = self._encode(x[mask]) - return out + encoded = self._encode(x) + missing = self.missing_token.expand(B, -1, -1) + if mask is None: + return encoded + 0.0 * missing.sum() + return torch.where(mask.view(B, 1, 1), encoded, missing) diff --git a/src/tokamak_foundation_model/utils/distributed.py b/src/tokamak_foundation_model/utils/distributed.py index 903bfac..a6db966 100644 --- a/src/tokamak_foundation_model/utils/distributed.py +++ b/src/tokamak_foundation_model/utils/distributed.py @@ -46,10 +46,28 @@ def device(self) -> torch.device: return torch.device("cuda", self.device_index) return torch.device("cpu") - def wrap(self, model: torch.nn.Module) -> torch.nn.Module: - """Wrap model with DDP if distributed, otherwise return as-is.""" + def wrap( + self, model: torch.nn.Module, find_unused_parameters: bool = False, + ) -> torch.nn.Module: + """Wrap model with DDP if distributed, otherwise return as-is. + + Default ``find_unused_parameters=False`` relies on every parameter + being touched in every step. The video / spectrogram tokenizers + always run ``_encode`` and reference ``missing_token`` regardless + of the per-batch validity mask, so the autograd graph is + data-independent and DDP's reducer can use static buckets. This + avoids RCCL bucket-rebuild faults observed on Frontier. + + Override to ``True`` only as a debugging escape hatch — it incurs + a per-step unused-param scan and was previously observed to + trigger GPU memory faults via RCCL on this stack. + """ if self.distributed: - return DistributedDataParallel(model, device_ids=[self.device_index]) + return DistributedDataParallel( + model, + device_ids=[self.device_index], + find_unused_parameters=find_unused_parameters, + ) return model def unwrap(self, model: torch.nn.Module):