FT local checkpointing: corrupt replica-retrieved local ckpt (deterministic NaN at load+3 iters) has no integrity validation / global-ckpt fallback; rerun_state_machine loops forever under ft_launcher in-job restart #5281
Summary
After an in-job node-loss recovery that loads a replica-retrieved local (non-persistent) checkpoint (nvidia-resiliency-ext local checkpointing with --replication), the restored training state deterministically produces NaN in the forward loss 3 iterations after the load. We reproduced this in two independent runs, at the same iteration both times. Nothing validates the local checkpoint's integrity or quality, and the rerun state machine's only remedy is "reschedule on a different GPU", which is impossible under ft_launcher in-job restart because it returns the same GPUs by construction. The job therefore enters an infinite restart loop (we observed 44 identical ~90 s cycles before cancelling).
Two concrete asks:
1. Validate local/replica-assembled checkpoints and fall back to the persistent global checkpoint when they fail (integrity check at load, and/or treat a reproducible NaN immediately after a local-tier load as "local checkpoint bad → reload global tier"). Today the corrupt local tier is trusted unconditionally and there is no fallback path.
2. rerun_state_machine needs an escape hatch when GPU diversity is impossible: under --enable-ft-package + ft_launcher worker-group restart, "Got rescheduled on the same GPU. Need to resume again from the same checkpoint" repeats forever. A bounded retry → abort (or → global-checkpoint fallback) would convert an infinite loop into a recoverable failure.
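To make the first ask concrete, here is a minimal sketch of an integrity check at load with a global-tier fallback. It assumes a per-checkpoint `manifest.json` of SHA-256 digests written at save time; that manifest, and the names `write_manifest`, `is_intact`, and `choose_checkpoint`, are hypothetical and not part of Megatron-LM or nvidia-resiliency-ext.

```python
import hashlib
import json
from pathlib import Path

MANIFEST = "manifest.json"  # hypothetical file name, not an existing convention


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Stream a file through SHA-256 so large shards are not read into memory."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def write_manifest(ckpt_dir: Path) -> None:
    """At save time, record one digest per shard file in the checkpoint directory."""
    entries = {
        p.name: sha256_of(p)
        for p in sorted(ckpt_dir.iterdir())
        if p.is_file() and p.name != MANIFEST
    }
    (ckpt_dir / MANIFEST).write_text(json.dumps(entries))


def is_intact(ckpt_dir: Path) -> bool:
    """At load time, require every listed shard to exist and match its digest."""
    manifest = ckpt_dir / MANIFEST
    if not manifest.is_file():
        return False
    expected = json.loads(manifest.read_text())
    for name, digest in expected.items():
        shard = ckpt_dir / name
        if not shard.is_file() or sha256_of(shard) != digest:
            return False
    return True


def choose_checkpoint(local_dir: Path, global_dir: Path) -> Path:
    """Prefer the local tier only when it verifies; otherwise use the global tier."""
    return local_dir if is_intact(local_dir) else global_dir
```

A byte-level digest only catches damage introduced after the shard was written (for example during replica retrieval). It would not detect a shard that was already wrong when saved, so it complements rather than replaces the NaN-after-load signal described in the same ask.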
A related crash-safety hazard we hit on the way out: each loop cycle re-saves the diagnostic checkpoint in place at the same iteration; interrupting one of those saves (our scancel) left latest_checkpointed_iteration.txt pointing at a partial directory (5.7 GB of ~344 GB) — an atomic temp-dir + rename (tracker updated only after completion) would remove that class of failure.
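The temp-dir + rename idea can be sketched as follows. This is an illustration under stated assumptions, not Megatron-LM's save path: `atomic_save` and `write_shards` are hypothetical names, the `iter_{iteration:07d}` directory layout is assumed, and everything is assumed to live on one POSIX filesystem so that `os.rename` is atomic.

```python
import os
import shutil
import tempfile
from pathlib import Path
from typing import Callable

TRACKER = "latest_checkpointed_iteration.txt"


def fsync_dir(path: Path) -> None:
    """Flush directory metadata (renames) to stable storage on POSIX systems."""
    fd = os.open(path, os.O_RDONLY)
    try:
        os.fsync(fd)
    finally:
        os.close(fd)


def atomic_save(root: Path, iteration: int,
                write_shards: Callable[[Path], None]) -> Path:
    """Write into a hidden temp dir, publish it by rename, then update the tracker."""
    root.mkdir(parents=True, exist_ok=True)
    final_dir = root / f"iter_{iteration:07d}"  # assumed directory layout
    tmp_dir = Path(tempfile.mkdtemp(prefix=".tmp_", dir=root))
    try:
        # An interruption here leaves final_dir and the tracker untouched.
        write_shards(tmp_dir)
    except BaseException:
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise
    if final_dir.exists():
        # Re-save at the same iteration: move the old copy aside first.
        # Two renames are not one atomic swap; a crash between them leaves
        # the old copy under .old_* rather than a half-written directory.
        stale = Path(tempfile.mkdtemp(prefix=".old_", dir=root))
        os.rename(final_dir, stale / final_dir.name)
        os.rename(tmp_dir, final_dir)
        shutil.rmtree(stale, ignore_errors=True)
    else:
        os.rename(tmp_dir, final_dir)
    fsync_dir(root)
    # Only now point the tracker at the completed checkpoint.
    tracker_tmp = root / (TRACKER + ".tmp")
    with open(tracker_tmp, "w") as f:
        f.write(str(iteration))
        f.flush()
        os.fsync(f.fileno())
    os.replace(tracker_tmp, root / TRACKER)
    fsync_dir(root)
    return final_dir
```

With this ordering, a cancelled save can leave behind a stray `.tmp_*` directory, but the tracker never names a directory that was not fully written.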
Environment
Megatron-LM (mcore dev image 648b916, 2026-05), PyTorch 2.12.0a0+...nv26.04, Python 3.12
nvidia-resiliency-ext 0.6.0 (ft_launcher, c10d rendezvous; FT section timeouts configured via --ft-cfg-path)
5 nodes × 8 GPUs (B300, EFA), SLURM + pyxis
~38B MoE (150B-A12B family cut to 12 layers), --bf16 + MXFP8 (--fp8-format=e4m3 --fp8-recipe=mxfp8), EP=8, mock data
Sequence (run A = job 250184; run B = job 249886 — identical signature)
Training healthy; local checkpoint saved at iteration 120 (/dev/shm, replicated), global /fsx checkpoint at iteration 100.
Fault injection: one active node is killed entirely (agent + workers + its /dev/shm shards).
ft_launcher restarts the worker group; the spare node joins; the group loads the local iteration-120 checkpoint — the dead node's shards are retrieved from replicas. Load reports success:
successfully loaded checkpoint from <ckpt_dir> [ t 1/1, p 1/1 ] at iteration 120
Training resumes (iterations 121, 122 fine), then:
ERROR:megatron.core.rerun_state_machine:Unexpected result nan on rank 9 at iteration #123 invocation #1 (message='found NaN in local forward loss calculation')
WARNING:megatron.core.rerun_state_machine:Need to rerun step to check reproducibility of initial result
ERROR:megatron.core.rerun_state_machine:... First rerun: unexpected result is reproducible within the tolerance (nan = nan). Need to rerun on a different GPU to verify ...
WARNING:megatron.core.rerun_state_machine:Saving a checkpoint and exiting now. Please resume the job from the checkpoint to rerun the last iteration and establish a diagnostic
(Run B: NaN also at iteration #123, also reproducible, so deterministic rather than transient.)
The workers exit per the rerun protocol; ft_launcher restarts the group on the same nodes/GPUs (that is what in-job restart does):
ERROR:megatron.core.rerun_state_machine:Got rescheduled on the same GPU. Need to resume again from the same checkpoint (node: ..., gpu: 0)
→ exit → restart → same GPU → exit → ... We cancelled after 44 identical cycles (~90 s each); --max-restarts 100 would have burned ~2.5 h of 40 idle GPUs.
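A bounded retry for the loop above could look like the sketch below. Because the workers exit on every cycle, the counter has to outlive the process, so this version keeps it in a small JSON file. The function name `next_action`, the file format, and the `"retry"` / `"fallback"` return values are all hypothetical; rerun_state_machine has no such hook today, which is the point of the second ask.

```python
import json
from pathlib import Path


def next_action(state_file: Path, iteration: int, max_attempts: int) -> str:
    """Count same-GPU resumes per iteration and stop retrying past a budget.

    Returns "retry" while the budget allows another resume from the same
    checkpoint, and "fallback" once it is exhausted (abort, or reload the
    global checkpoint, depending on policy).
    """
    state = {"iteration": None, "attempts": 0}
    if state_file.is_file():
        state = json.loads(state_file.read_text())
    if state["iteration"] != iteration:
        # A different failing iteration starts a fresh budget.
        state = {"iteration": iteration, "attempts": 0}
    state["attempts"] += 1
    state_file.write_text(json.dumps(state))
    return "retry" if state["attempts"] <= max_attempts else "fallback"
```

The state file would need to live on storage that survives a node loss (not the node's own /dev/shm). With `max_attempts=3` and the ~90 s cycle observed here, the loop would end after roughly three cycles instead of running until --max-restarts is exhausted.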
Why we believe the local/replica checkpoint is the trigger
An identical drill without --replication (same overlay otherwise, same node-kill, spare swap, /fsx fallback load) recovered cleanly and trained to completion (no NaN, no rerun events).
Multiple earlier kill-tests loading non-replicated local checkpoints at 38B/75B/150B resumed cleanly past the load point.
Both replication runs hit NaN at the same iteration (#123) shortly after a replica-retrieved load, reproducibly under rerun.
Impact
A single node loss with replication enabled converts into either total job death (run B, via the rendezvous race in nvrx#348) or an unbounded restart loop (run A). The layered design (local tier → global tier → requeue) cannot help because nothing ever decides the local tier is bad. Ask (1) above is the missing piece: check local checkpoint quality and, if it is bad, use the global checkpoint.
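One way to let something "decide the local tier is bad" is a proximity heuristic: a reproducible NaN within a few iterations of a local-tier load is blamed on that checkpoint. The sketch below is only an illustration of that decision rule; `Tier`, `LoadRecord`, `should_distrust_local`, and the default window of 10 iterations are assumptions, not existing Megatron-LM or nvidia-resiliency-ext APIs.

```python
from dataclasses import dataclass
from enum import Enum


class Tier(Enum):
    LOCAL = "local"
    GLOBAL = "global"


@dataclass(frozen=True)
class LoadRecord:
    tier: Tier       # which tier the last load came from
    iteration: int   # iteration stored in the checkpoint that was loaded


def should_distrust_local(load: LoadRecord, nan_iteration: int,
                          reproducible: bool, window: int = 10) -> bool:
    """Blame the local checkpoint for a reproducible NaN shortly after its load."""
    if load.tier is not Tier.LOCAL or not reproducible:
        return False
    # Only NaNs that appear within `window` iterations after the load count.
    return 0 < nan_iteration - load.iteration <= window


# The case from this report: local load at iteration 120, reproducible NaN at 123.
assert should_distrust_local(LoadRecord(Tier.LOCAL, 120), 123, reproducible=True)
```

When the rule fires, the recovery path would reload from the global tier (iteration 100 in these runs) instead of resuming from the same local checkpoint again.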
Full (lightly redacted) worker/agent logs for both runs available on request.
Launch flags (Environment): --enable-ft-package, --ckpt-format=torch_dist, --exit-signal-handler, --non-persistent-ckpt-type local --non-persistent-local-ckpt-dir /dev/shm/<run> --non-persistent-save-interval 20 --non-persistent-local-ckpt-algo fully_parallel --replication --replication-jump 8 --replication-factor 2
--nnodes 4 (1 parked spare)