Describe the bug
ChainedOptimizer.step() computes a single global grad_norm across all chained sub-optimizers and applies clip_grad_by_total_norm_fp32(..., total_norm=grad_norm) to every sub-optimizer's parameters, including a Muon / orthogonalizing group (dist_muon). For an orthogonalizing optimizer this is (a) meaningless — Muon discards gradient magnitude by design (Newton–Schulz orthogonalization is scale-invariant) — and (b) harmful: when the global grad_norm is large, the clip coefficient c = clip_grad/grad_norm is tiny and scales the per-matrix gradients fed into Newton–Schulz below its internal F.normalize(eps=1e-7) floor, so orthogonalization degenerates (near-zero, non-orthogonal update). Affected layers stop updating, loss stalls, grad_norm grows, c shrinks further, more layers collapse — a silent positive-feedback stall.
Steps/Code to reproduce bug
We hit this fine-tuning a GatedDeltaNet model whose gradients are inherently large and per-layer-imbalanced (grad_norm ≈ 5e7, growing to ~2e11). Minimal mechanism repro (torch-only) showing the clip coefficient pushing a per-matrix gradient below the Newton–Schulz floor:
import torch
import torch.nn.functional as F

clip_grad, grad_norm = 1.0, 5e7
c = clip_grad / (grad_norm + 1e-6)  # ≈ 2e-8

def newton_schulz(x, steps=6, eps=1e-7):
    a, b, cc = 3.4445, -4.7750, 2.0315  # quintic iteration coefficients
    # F.normalize divides by max(||x||_F, eps): inputs with norm < eps are NOT rescaled to unit norm.
    X = F.normalize(x, p=2, dim=(-2, -1), eps=eps)
    for _ in range(steps):
        A = X @ X.transpose(-2, -1)
        X = a * X + (b * A + cc * (A @ A)) @ X
    return X

G = torch.randn(2048, 2048)
G = G / G.norm() * 0.1  # a late-layer matrix grad
print("no clip :", newton_schulz(G).norm().item())      # ~40 (orthogonal)
print("clipped :", newton_schulz(G * c).norm().item())  # collapses toward 0
Training-level, single-variable comparison (fixed 32-example overfit batch, optimizer=dist_muon, only clip_grad changed):
clip_grad=1.0 : loss 0.596 0.671 ... 0.583 (stalls at ~0.5, never overfits; grad_norm 7.5e7 -> 2.2e11)
clip_grad=0   : loss 0.596 0.530 ... 0.019 (clean overfit)
(Keeping clip_grad=1.0 and instead only lowering Newton–Schulz eps 1e-7→1e-30 also unblocks learning, confirming the clip → NS-eps interaction.)
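The clip → eps interaction can be checked with plain arithmetic, without running Newton–Schulz. The sketch below assumes only the documented behavior of F.normalize (divide by max(norm, eps)) and reuses the 0.1 per-matrix gradient norm from the repro above; the helper name normalized_norm is made up for illustration.

```python
def normalized_norm(grad_fro_norm, clip_coeff, eps):
    """Frobenius norm of the matrix after F.normalize-style scaling,
    i.e. after dividing the clipped gradient by max(its norm, eps)."""
    scaled = grad_fro_norm * clip_coeff
    return scaled / max(scaled, eps)

clip_grad = 1.0
per_matrix_norm = 0.1  # same per-matrix gradient norm as in the repro

for grad_norm in (5e7, 2.2e11):          # global norms reported in this issue
    c = clip_grad / (grad_norm + 1e-6)   # clip coefficient
    for eps in (1e-7, 1e-30):
        out = normalized_norm(per_matrix_norm, c, eps)
        print(f"grad_norm={grad_norm:.1e} eps={eps:.0e} -> input norm to NS = {out:.3e}")
```

With eps=1e-7 the matrix enters the iteration with a norm well below 1, and the gap widens as grad_norm grows; with eps=1e-30 it is rescaled to unit norm in both cases.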
Expected behavior
Orthogonalizing / Muon param groups should not be subject to magnitude-based global gradient clipping (it is a no-op-at-best for scale-invariant updates). Training should proceed (the fixed-batch overfit should drive loss → 0), as it does with clip_grad=0.
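One way to picture the expected behavior is a clipping step that skips orthogonalizing groups entirely. This is a pure-Python sketch, not Megatron code: the group dictionaries, the "orthogonalizing" flag, and the function name clip_non_orthogonal are all hypothetical, and computing the norm over the clipped groups only is an assumption (a global norm could still be reported for logging).

```python
import math

def clip_non_orthogonal(groups, clip_grad):
    """Clip only groups whose optimizer uses gradient magnitude.

    groups: list of dicts with 'grads' (list of floats) and
            'orthogonalizing' (bool). Hypothetical structure.
    Returns the norm that was used for clipping.
    """
    clipped = [g for g in groups if not g["orthogonalizing"]]
    total_norm = math.sqrt(sum(x * x for g in clipped for x in g["grads"]))
    coeff = clip_grad / (total_norm + 1e-6)
    if coeff < 1.0:  # only ever scale down
        for g in clipped:
            g["grads"] = [x * coeff for x in g["grads"]]
    return total_norm

groups = [
    {"name": "adam", "orthogonalizing": False, "grads": [3.0e7, 4.0e7]},
    {"name": "muon", "orthogonalizing": True, "grads": [0.06, 0.08]},
]
norm = clip_non_orthogonal(groups, clip_grad=1.0)
print(norm, groups[0]["grads"], groups[1]["grads"])
```

The magnitude-sensitive group is scaled to the clip threshold, while the Muon-style group reaches its optimizer untouched and so never falls under the normalization floor.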
Additional context
megatron/core/optimizer/optimizer.py, ChainedOptimizer.step(): grad_norm = self.get_grad_norm(), then per sub-optimizer clip_grad_by_total_norm_fp32(..., total_norm=grad_norm).
Possible mitigations: clip_grad=0 for dist_muon; or clamp the clip coefficient with a floor so it cannot push grads under the optimizer's numerical eps.
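The floor idea could look roughly like the following. This is a sketch under assumptions, not a proposed patch: floored_clip_coeff and its parameters are hypothetical names, and the floor is chosen per matrix as ns_eps / param_grad_norm so that the clipped gradient norm stays at or above the Newton–Schulz eps.

```python
def floored_clip_coeff(clip_grad, grad_norm, param_grad_norm, ns_eps=1e-7):
    """Clip coefficient that will not push one matrix gradient below ns_eps.

    clip_grad / grad_norm: the usual global clipping inputs.
    param_grad_norm: Frobenius norm of this matrix's gradient (hypothetical input).
    """
    coeff = clip_grad / (grad_norm + 1e-6)
    if param_grad_norm > 0.0:
        # Smallest coefficient that keeps coeff * param_grad_norm >= ns_eps.
        coeff = max(coeff, ns_eps / param_grad_norm)
    # Never scale gradients up.
    return min(coeff, 1.0)

plain = 1.0 / (5e7 + 1e-6)
floored = floored_clip_coeff(1.0, 5e7, param_grad_norm=0.1)
print("plain   clipped norm:", plain * 0.1)
print("floored clipped norm:", floored * 0.1)
```

A per-matrix floor means different matrices may receive different coefficients, so this is no longer a uniform global clip; for a scale-invariant update that should not matter, but it is a design trade-off worth stating.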