Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
89 commits
Select commit Hold shift + click to select a range
28b15d3
Removed the argument "batch_size" from the trainers.
renierts Feb 13, 2026
305c7e2
Bugfix in the dataset class. When iterating over movie configurations…
renierts Feb 13, 2026
3243412
Added base script for video reconstruction. Copied from Aza's branch …
renierts Feb 13, 2026
dfc63ee
Added base script for video reconstruction. Copied from Aza's branch …
renierts Feb 13, 2026
65f48fc
Minor changes in the example scripts. More preprocessing options for …
renierts Feb 14, 2026
746f7ba
Fixed a bug where the dataset class failed when using multiple worker…
renierts Feb 14, 2026
f053586
Lots of bugfixes in the dataset, trainer, and models.
renierts Feb 16, 2026
300a4b3
Extended checkpointing - the trainer stores now:
renierts Feb 16, 2026
939360c
Extended checkpointing - the trainer stores now:
renierts Feb 16, 2026
d359e07
Adapted the other reconstruction scripts to match the new API.
renierts Feb 16, 2026
9d5bee1
Bugfix in the dataset class. When splitting inputs and targets, I for…
renierts Feb 16, 2026
9e79a91
Prepared an option to preprocess movies. This has to be fully integra…
renierts Feb 16, 2026
029b685
Added a baseline fusion transformer for latent space prediction.
renierts Feb 17, 2026
1298f37
Foundation model (#56)
renierts Feb 17, 2026
7f20db2
Moved some remaining scripts to the correct subdirectories.
renierts Feb 17, 2026
fc95315
Still working on preparing the dataset. This is not ready to push. Pr…
renierts Feb 17, 2026
5437224
Updated the data loader. Bugfix for loading the correct slices from H…
Feb 19, 2026
354e643
Added scripts for data fetching in Omega.
Feb 24, 2026
f4ff282
Added a documentation for setting up Globus CLI on Omega and start a …
Feb 24, 2026
39cfaea
Updated README.md:
Feb 24, 2026
605fc68
More PTData to fetch.
Feb 24, 2026
9f436ec
PEP-8 compatible code.
Feb 25, 2026
80ba381
Generalized make_preprocessing_stats.py and made the function compute…
Feb 25, 2026
5d2c032
A lot of bugfixes in the dataloader and prepare_data.py
Mar 2, 2026
ffa2c29
Many bugfixees in the dataset class and for computing preprocessing s…
Mar 4, 2026
33db368
Speed-ups in data_loader.py.
Mar 5, 2026
345a3d5
Speed-ups in the dataloader.
Mar 9, 2026
06a9065
drawing.py:
Mar 10, 2026
857f75a
Bugfix in processing methods of the dataloader:
Mar 11, 2026
1630475
Added a separate baseline encoder for filterscopes (renamed fast_time…
Mar 12, 2026
9924b6d
Added a weighted loss to penalize target distributions.
Mar 13, 2026
b67168b
Modified the default parameters of some profile and time-series signa…
Mar 17, 2026
850d621
Added CER related info to the dataset class and to the model factory.
Mar 17, 2026
af6e0e1
Merge remote-tracking branch 'origin/foundation_model' into dev-peter
Mar 17, 2026
4808eaf
Added dummy perceiver stuff. Be careful - this is not structured nice…
Mar 17, 2026
fcd7906
Added more RMP point names to the data fetching script.
Mar 31, 2026
62ae163
Updated all scripts according to the increased set of diagnostics and…
Apr 1, 2026
8c81907
Updated preprocessing_stats. Here, the statistics are now pre-calcula…
Apr 2, 2026
1634e70
Merge branch 'foundation_model' into dev-peter
renierts Apr 2, 2026
166a065
Dev peter (#68) (#69)
renierts Apr 2, 2026
8feb60a
TS profiles are now slow time series instead of profiles.
Apr 7, 2026
a9f83b5
Had to update all the profiles and slow time-series. The latent featu…
Apr 13, 2026
db551ac
Removed the argument "batch_size" from the trainers.
renierts Feb 13, 2026
d1109bb
Bugfix in the dataset class. When iterating over movie configurations…
renierts Feb 13, 2026
5dc6c7c
Added base script for video reconstruction. Copied from Aza's branch …
renierts Feb 13, 2026
b0c1ce7
Minor changes in the example scripts. More preprocessing options for …
renierts Feb 14, 2026
36fd17f
Fixed a bug where the dataset class failed when using multiple worker…
renierts Feb 14, 2026
e84fae4
Lots of bugfixes in the dataset, trainer, and models.
renierts Feb 16, 2026
897697c
Adapted the other reconstruction scripts to match the new API.
renierts Feb 16, 2026
39225f1
Foundation model (#56)
renierts Feb 17, 2026
7e0c537
Moved some remaining scripts to the correct subdirectories.
renierts Feb 17, 2026
d18375a
Updated the data loader. Bugfix for loading the correct slices from H…
Feb 19, 2026
1fb3a69
Added scripts for data fetching in Omega.
Feb 24, 2026
fe43bb2
Added a documentation for setting up Globus CLI on Omega and start a …
Feb 24, 2026
09691fc
Updated README.md:
Feb 24, 2026
a46d97b
More PTData to fetch.
Feb 24, 2026
bb50ad2
PEP-8 compatible code.
Feb 25, 2026
9cdca1a
A lot of bugfixes in the dataloader and prepare_data.py
Mar 2, 2026
7a1a9a4
Many bugfixees in the dataset class and for computing preprocessing s…
Mar 4, 2026
0ef276d
Speed-ups in data_loader.py.
Mar 5, 2026
946b5f7
Speed-ups in the dataloader.
Mar 9, 2026
be36ebc
Added a separate baseline encoder for filterscopes (renamed fast_time…
Mar 12, 2026
cc77bec
Updated preprocessing_stats. Here, the statistics are now pre-calcula…
Apr 2, 2026
77e72f2
Dev peter (#68) (#69)
renierts Apr 2, 2026
cf4b51e
TS profiles are now slow time series instead of profiles.
Apr 7, 2026
6cf8981
Had to update all the profiles and slow time-series. The latent featu…
Apr 13, 2026
4f68b7c
Merge branch 'dev-peter' of https://github.com/PlasmaControl/FusionAI…
Apr 13, 2026
ebc74e1
Big changes. Now, the entire foundation model is trained jointly.
Apr 23, 2026
739084a
Much better GPU utilization of the e2d pipeline now (98% on a single …
Apr 24, 2026
4ec7075
Prepared for video data. 100fps works better with the 50ms chunks tha…
Apr 24, 2026
da616d5
Stage 2 is ready for video support.
May 4, 2026
f9d6fcc
Prepared for real multi-model foundation model. TS+Video+Spectrograms.
May 7, 2026
5d41b67
Merge origin/foundation_model into dev-peter (preserve DDP + dev-pete…
May 7, 2026
4b32cd5
Prepared for real multi-model foundation model. TS+Video+Spectrograms.
May 7, 2026
5f43f64
Code changes in the e2e training pipeline.
May 11, 2026
90ae51d
Forgot to add multimodal.py that offers a better structure for multim…
May 11, 2026
576c1e5
Merge branch 'foundation_model' into dev-peter
renierts May 11, 2026
c60e3c9
Dev peter (#77) (#78)
renierts May 11, 2026
90e0798
Updated the data sampler. MultiFile for DDP is supported now. Signifi…
May 11, 2026
210bfb0
Updated the SLURM scripts to be more generalizable to different user …
May 11, 2026
bf777ce
Made the dataset more efficient when it comes to DDP.
May 12, 2026
402b37a
Bugfixes for the multimodal foundation model. Had to account for miss…
May 13, 2026
9fc118a
Stage 1 is ready for DDP and scheduled. Now bugfixing stage 2. One bu…
May 14, 2026
c8bf315
Bugfix in the validation part of spectrograms for stage 2. It was nec…
May 14, 2026
56c2b98
Stage 2 can be used now.
May 15, 2026
d6207c4
Increased model size to 50M parameters.
May 15, 2026
a4daf8a
Merge remote-tracking branch 'origin/dev-peter' into nathan_fm
May 20, 2026
ecd385d
Add scaling levers (SDPA attn, gradient checkpoint) + memory probe.
May 23, 2026
2f152c2
memory improvemtns
May 23, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ activemq-data/
.envrc
.venv
.venv-rocm
.build/
env/
venv/
ENV/
Expand Down
27 changes: 27 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# FusionAIHub (FAITH)

## Frontier setup

```bash
# 1. Clone to scratch
cd /lustre/orion/fus187/scratch/$USER
git clone git@github.com:PlasmaControl/FusionAIHub.git
cd FusionAIHub
git switch foundation_model

# 2. Install pixi
curl -fsSL https://pixi.sh/install.sh | bash
source ~/.bashrc

# 3. Install the Frontier env (~5 min)
pixi install -e frontier

# 4. Build flash-attention 2 (~2-5 min)
pixi run -e frontier setup-flash-attn


## Other platforms

- **NVIDIA/CUDA**: `pixi install` (default env), scripts in `scripts/slurm/`
- **della-milan (MI210)**: `bash scripts/slurm_della_milan/setup_rocm_env.sh`,
scripts in `scripts/slurm_della_milan/`
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Profiling file-open cost for `train_e2e_stage1` on Frontier

**Date:** 2026-05-11
**Author:** nchen
**Status:** Design — approved, plan pending

## Goal

Measure the end-to-end file-open cost of an `e2e_stage1` training job on Frontier
(Lustre filesystem, ~8753 shot HDF5 files at `/lustre/orion/fus187/proj-shared/foundation_model`),
and decide whether it is a real problem that needs mitigation.

## Background

`scripts/training/train_e2e_stage1.py` uses
`tokamak_foundation_model.data.multi_file_dataset.TokamakMultiFileDataset` to read
single-shot HDF5 files. File opens happen in two distinct places:

1. **Startup indexing pass.** `_load_or_compute_lengths()` opens every shot HDF5
sequentially to read its duration and compute a chunk count. Results are
cached to a `.pt` sidecar; subsequent runs short-circuit this entirely.
2. **Steady-state, during training.** Each DataLoader worker has its own LRU
cache of `h5py.File` handles, bounded by `max_open_files=1024`. Cache hits
are free; cold misses re-open with `h5py.File(path, "r", rdcc_nbytes=0)`.
Per-worker counters (`_prof_opens`, `_prof_hits`, `_prof_open_s`,
`_prof_close_s`, `_prof_getitem_s`) are already in place.

Existing infrastructure we'll reuse:
- `scripts/profile_indexing.py` — times Phase 1.
- `scripts/slurm_frontier/profile_indexing.sh` — Frontier launcher for the above.
- `scripts/training/profile_stage1.py` — `torch.profiler` on the full train step.
- `scripts/training/probe_stage1_loading.py` — single-process `__getitem__` timing.

Prior measurements (`logs/4555562_idx_profile.out`):
- 100-file run: 6.00 files/s, predicted ~33 min on full 8753.
- Two full-dataset attempts (jobs 4555563, 4558113) did **not** finish: the first
timed out at 1 h walltime, the second failed at 7 s (exit 1).
- `runs/lengths_cache_e2e_stage1/` is currently empty.

## Scope

**In:**
- Single Frontier job, one node, production training config (8 DDP ranks ×
4 workers/rank × batch 16, pulled from `scripts/slurm_frontier/train_e2e_stage1.sh`).
- Both phases: full-dataset indexing + ~200 steady-state training steps.
- A written verdict on whether file-open cost is acceptable or needs work.

**Out:**
- Multi-node coordination measurements.
- Multiple worker-count sweeps (4 vs 8 vs 16). One config only.
- Lustre stripe-config experimentation.
- Changes to the production training script.

## Plan

### Phase A — startup indexing (full dataset)

Run `scripts/profile_indexing.py` with no file cap against the full data
directory, writing the lengths cache to `runs/lengths_cache_e2e_stage1/`. Walltime
budget **3 h** (the prior 1 h attempt timed out).

Measurements:
- Total wall time, files/s, valid/skipped count, total chunks.

Side benefit: populates the lengths cache so all future training jobs skip the
indexing wall entirely.

### Phase B — steady-state opens during training

Run a new thin script `scripts/training/profile_stage1_opens.py` that mirrors
the existing `scripts/training/profile_stage1.py` structure (imports
`build_configs`, `build_datasets`, `resolve_shot_files`, `compute_step_loss`
from `train_e2e_stage1.py` — no changes to the production script).

Configuration to match production (`train_e2e_stage1.sh`):
- 8 DDP ranks per node, 1 GPU per rank, `--gpu-bind=closest`.
- 4 DataLoader workers per rank (32 workers total).
- `batch_size=16`, `chunk_duration_s=0.05`, `step_size_s=0.01`, `warmup_s=1.0`,
`prediction_horizon_s=0.05`, `d_model=256`, `n_layers=8`, `n_heads=8`.
- Reuse the lengths cache from Phase A.

Run ~200 training steps. At the end, each worker dumps its profiling counters
(`_prof_opens`, `_prof_hits`, `_prof_open_s`, `_prof_close_s`, `_prof_getitem_s`,
`_prof_load_s`, `_prof_process_s`) to a per-worker JSON file in
`runs/profile_e2e_stage1_opens/per_worker/`.

Rank 0 reads all per-worker JSONs after `dist.barrier()`, aggregates, and
writes `summary.json` plus a human-readable `report.md`.

If the existing in-place stdout logging (every 50 calls) is sufficient
to extract these numbers from the SLURM log, the JSON dump can be skipped in
favor of a `parse_log.py` post-processor. We will pick whichever is simpler
during implementation; the spec does not lock in one approach.

### Putting them together

Single launcher `scripts/slurm_frontier/profile_e2e_stage1_opens.sh`:
- `#SBATCH -t 03:00:00`, 1 node, account `fus187`.
- Runs Phase A first (CPU-only mode by calling the python script directly,
not via `srun`), then Phase B (via `srun -n 8 --gpu-bind=closest …`).
- Each phase writes to its own subdirectory under `runs/profile_e2e_stage1_opens/`.

## Outputs

All in `runs/profile_e2e_stage1_opens/`:

- `indexing.log` — Phase A stdout: wall time, files/s, valid/skipped, total chunks.
- `per_worker/rank{R}_worker{W}.json` — raw per-worker counters from Phase B.
- `summary.json` — aggregated open counts / hit rate / open-wall across the
32 workers; `__getitem__` time breakdown.
- `report.md` — synthesis and verdicts (see below).

Side effect: `runs/lengths_cache_e2e_stage1/lengths_e2e_stage1_{train,val}.pt`
populated for future runs.

## Verdict criteria (to include in `report.md`)

| Question | Threshold | Source |
|---|---|---|
| Is full-dataset indexing tolerable? | < 30 min OK; 30–60 min worth pre-caching; > 60 min should be a permanent cache or restripe | Phase A wall time |
| Is the training loop open-bound? | Open-wall fraction of `__getitem__` < 5 % = good, 5–20 % = OK, > 20 % = bad | Phase B `_prof_open_s / _prof_getitem_s` |
| Is `max_open_files=1024` right-sized? | Hit rate > 95 % in steps 100–200 = fine; less = LRU churn | Phase B `_prof_hits / (_prof_hits + _prof_opens)` |
| Cold-start to first useful step | Indexing + warm-up; report as a number | Phase A + Phase B step-1 timing |

Each verdict comes with a one-line recommendation: leave alone / pre-cache /
resize LRU / restripe / something else.

## Expected back-of-envelope (sanity check)

- 32 workers, 8753 files → ~274 files/worker. LRU=1024 means every worker fits
its slice — cold opens should happen at most once per file per worker.
- A pure `h5py.File()` open on Lustre is plausibly 20–100 ms (no duration
scan). At ~50 ms × 274 files = ~14 s of cold-open wall per worker, amortized
across the entire epoch.
- If the actual hit rate is much below 95 %, that's a red flag worth digging
into (DistributedSampler shard, `TwoLevelSampler` interaction, or per-worker
shard size larger than expected).
- Indexing throughput on Lustre is the dominant unknown. The prior 100-file
warm-cache extrapolation predicted 33 min but the full run timed out at 1 h,
so the true rate may be 2–4× slower than the small-N extrapolation suggested.

## Open questions / decisions deferred to plan

- Whether to dump counters via per-worker JSON files or parse the existing
stdout log (pick simpler at implementation time).
- Whether Phase A and Phase B share one SLURM job or run as two
`--dependency`-linked jobs (one job is simpler, picked here unless Phase A
is unstable enough to need re-runs).
- Whether to add an MPI broadcast of `__getitem__` step-1 timing for end-to-end
cold-start, or just report indexing wall + a single rank's step-1 time.
98 changes: 40 additions & 58 deletions pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,32 @@ toksearch_d3d = { channel = "ga-fdp" }
[tool.pixi.feature.frontier]
platforms = ["linux-64"]

[tool.pixi.feature.frontier.dependencies]
# pip is needed for the `setup-flash-attn` task below to install flash-attn
# from a git URL with --no-build-isolation. The PyTorch wheels we pull from
# the rocm7.1 index don't drag pip in transitively.
pip = "*"
# ninja: aiter (a transitive dep of flash_attn on ROCm) JIT-compiles a small
# C++ extension at first `import flash_attn`. It calls `ninja` from PATH.
ninja = "*"

[tool.pixi.feature.frontier.pypi-dependencies]
# rocm7.1 index ships torch 2.10.0 + torchvision 0.25-0.26 only.
torch = { version = ">=2.10,<2.11", index = "https://download.pytorch.org/whl/rocm7.1" }
torchvision = { version = ">=0.25,<0.27", index = "https://download.pytorch.org/whl/rocm7.1" }
# torch 2.10 declares triton-rocm as a dep; uv won't auto-discover it
# through the per-package `index = ...` above, so list it explicitly.
triton-rocm = { version = "*", index = "https://download.pytorch.org/whl/rocm7.1" }
# Flash-Attention 2 (gfx90a / MI250X) is NOT listed here intentionally:
# the build needs `module load rocm/7.1.1` + `FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE`,
# which pixi/uv can't set. Install via the `setup-flash-attn` task below; we use
# the AMD Triton backend (not Composable Kernel) per the AMD docs at
# rocm.docs.amd.com/.../model-acceleration-libraries.html — Triton skips the
# multi-hour CK template/hipcc compile and builds in ~10-15 min.

[tool.pixi.feature.frontier.tasks]
setup-flash-attn = { cmd = "bash scripts/slurm_frontier/setup_frontier_env.sh", description = "Build & install flash-attn 2 into the frontier pixi env on a Frontier compute node (gfx90a). Auto-salloc's if run from a login node." }
verify-flash-attn = { cmd = "python scripts/slurm_frontier/verify_flash_attn.py", description = "Smoke-test flash_attn on the local MI250X." }

[tool.pixi.environments]
default = ["cuda"]
Expand Down
Loading