
[BUG] ChainedOptimizer applies global grad-norm clipping to Muon (orthogonalizing) param groups, silently stalling training #5394

@yuchenwang3


Describe the bug

ChainedOptimizer.step() computes a single global grad_norm across all chained sub-optimizers and applies clip_grad_by_total_norm_fp32(..., total_norm=grad_norm) to every sub-optimizer's parameters, including a Muon / orthogonalizing group (dist_muon). For an orthogonalizing optimizer this is (a) meaningless and (b) harmful.

(a) Muon discards gradient magnitude by design: Newton–Schulz orthogonalization first normalizes each matrix, so uniformly rescaling the gradient does not change the update, as long as the matrix norm stays above the normalization floor.

(b) When the global grad_norm is large, the clip coefficient c = clip_grad / grad_norm is tiny and scales the per-matrix gradients fed into Newton–Schulz below its internal F.normalize(eps=1e-7) floor. F.normalize divides by max(‖G‖, eps), so once ‖cG‖ < eps the normalized matrix has norm ‖cG‖ / eps instead of 1, and the fixed number of Newton–Schulz steps cannot push its singular values back up to ~1. Orthogonalization degenerates into a near-zero, non-orthogonal update.

Affected layers stop updating, so the loss stalls and grad_norm grows; c then shrinks further and more layers collapse. The result is a silent positive-feedback stall.
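To make the F.normalize floor concrete, here is a small sketch that only looks at the input Newton–Schulz receives. It assumes the same clip formula as the repro's c (clip_grad / (grad_norm + 1e-6), applied only when it is below 1) and an illustrative 256×256 gradient with Frobenius norm 0.1; the helper names clip_coeff and normalized_input_norm are hypothetical.

```python
import torch
import torch.nn.functional as F

NS_EPS = 1e-7  # Newton–Schulz normalization floor, as in the repro


def clip_coeff(clip_grad: float, grad_norm: float) -> float:
    # Hypothetical helper: global-norm clip coefficient, same formula as the repro's c,
    # capped at 1 because clipping only ever scales gradients down.
    return min(1.0, clip_grad / (grad_norm + 1e-6))


def normalized_input_norm(g: torch.Tensor, eps: float = NS_EPS) -> float:
    # Hypothetical helper: F.normalize divides by max(||g||_F, eps),
    # so the result has norm min(1, ||g||_F / eps).
    return F.normalize(g, p=2, dim=(-2, -1), eps=eps).norm().item()


torch.manual_seed(0)
G = torch.randn(256, 256)
G = G / G.norm() * 0.1  # illustrative per-matrix gradient with Frobenius norm 0.1

for grad_norm in (1.0, 5e7, 2e11):
    c = clip_coeff(1.0, grad_norm)
    n = normalized_input_norm(G * c)
    print(f"grad_norm={grad_norm:.0e}  c={c:.1e}  ||normalize(c*G)||={n:.2e}")
```

With these inputs the printed norms should be about 1, 2e-2 and 5e-6: for grad_norm ≥ 5e7 the matrix handed to Newton–Schulz is no longer unit-norm, which is exactly where scale invariance breaks.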

Steps/Code to reproduce bug

We hit this while fine-tuning a GatedDeltaNet model whose gradients are inherently large and imbalanced across layers (grad_norm ≈ 5e7, growing to ~2e11). Below is a minimal, torch-only repro of the mechanism, showing the clip coefficient pushing a per-matrix gradient below the Newton–Schulz floor:

import torch
import torch.nn.functional as F

clip_grad, grad_norm = 1.0, 5e7
c = clip_grad / (grad_norm + 1e-6)                 # ≈ 2e-8

def newton_schulz(x, steps=6, eps=1e-7):
    a, b, cc = 3.4445, -4.7750, 2.0315
    # Divides by max(||x||_F, eps): scale-invariant only while ||x||_F >= eps
    X = F.normalize(x, p=2, dim=(-2, -1), eps=eps)
    for _ in range(steps):
        A = X @ X.transpose(-2, -1)
        X = a * X + (b * A + cc * (A @ A)) @ X
    return X

G = torch.randn(2048, 2048)
G = G / G.norm() * 0.1                                   # a late-layer matrix grad
print("no clip :", newton_schulz(G).norm().item())       # ~40  (orthogonal)
print("clipped :", newton_schulz(G * c).norm().item())   # smaller, not orthogonal (NS input norm only 0.02); -> 0 as grad_norm grows
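A possible extension of the repro sketches the feedback loop: sweep grad_norm over the range seen in training and watch the Newton–Schulz output norm. It assumes clip_grad=1.0, the same newton_schulz iteration and coefficients as above, and a smaller illustrative 512×512 gradient (for speed), so absolute numbers differ from the 2048×2048 case; the trend is what matters.

```python
import torch
import torch.nn.functional as F


def newton_schulz(x: torch.Tensor, steps: int = 6, eps: float = 1e-7) -> torch.Tensor:
    # Same quintic iteration and coefficients as the repro.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = F.normalize(x, p=2, dim=(-2, -1), eps=eps)  # divides by max(||x||_F, eps)
    for _ in range(steps):
        A = X @ X.transpose(-2, -1)
        X = a * X + (b * A + c * (A @ A)) @ X
    return X


torch.manual_seed(0)
G = torch.randn(512, 512)
G = G / G.norm() * 0.1  # illustrative per-matrix gradient, Frobenius norm 0.1

baseline = newton_schulz(G).norm().item()  # unclipped reference
results = {}
for grad_norm in (1e8, 1e9, 1e10, 1e11, 2e11):
    coeff = 1.0 / (grad_norm + 1e-6)  # clip coefficient for clip_grad = 1.0
    results[grad_norm] = newton_schulz(G * coeff).norm().item()
    print(f"grad_norm={grad_norm:.0e}  ||NS(c*G)||={results[grad_norm]:.3e}  (unclipped: {baseline:.1f})")
```

Each 10× increase in grad_norm shrinks the normalized input 10×, and six Newton–Schulz steps amplify small singular values by at most roughly 3.4445^6 ≈ 1.7e3, so the update fades toward zero as grad_norm climbs toward 2e11.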

Training-level, single-variable experiment (fixed 32-example overfit batch, optimizer=dist_muon, only clip_grad changed):

clip_grad=1.0 : loss 0.596 0.671 ... 0.583   stalls ~0.5, never overfits; grad_norm 7.5e7 -> 2.2e11
clip_grad=0   : loss 0.596 0.530 ... 0.019   clean overfit

(Keeping clip_grad=1.0 and lowering only the Newton–Schulz eps from 1e-7 to 1e-30 also unblocks learning, confirming the clip → NS-eps interaction.)
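The eps observation can be checked in the same torch-only setting. This sketch assumes the repro's newton_schulz, an illustrative 512×512 gradient, and grad_norm fixed at the late-training value of ~2e11; it compares eps=1e-7 with eps=1e-30. With the tiny floor, F.normalize divides by the true ‖cG‖ (5e-13 here), so the clipped input is restored to unit norm and the result should match the unclipped one.

```python
import torch
import torch.nn.functional as F


def newton_schulz(x: torch.Tensor, steps: int = 6, eps: float = 1e-7) -> torch.Tensor:
    # Same quintic iteration and coefficients as the repro.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = F.normalize(x, p=2, dim=(-2, -1), eps=eps)  # divides by max(||x||_F, eps)
    for _ in range(steps):
        A = X @ X.transpose(-2, -1)
        X = a * X + (b * A + c * (A @ A)) @ X
    return X


torch.manual_seed(0)
G = torch.randn(512, 512)
G = G / G.norm() * 0.1       # illustrative per-matrix gradient, Frobenius norm 0.1
coeff = 1.0 / (2e11 + 1e-6)  # clip coefficient at grad_norm ≈ 2e11, clip_grad = 1.0

ref = newton_schulz(G).norm().item()                           # unclipped
floored = newton_schulz(G * coeff, eps=1e-7).norm().item()     # ||c*G|| = 5e-13 < eps
tiny_eps = newton_schulz(G * coeff, eps=1e-30).norm().item()   # ||c*G|| > eps again
print(f"unclipped={ref:.2f}  eps=1e-7: {floored:.2e}  eps=1e-30: {tiny_eps:.2f}")
```

Lowering eps only hides the symptom, though: the clip still rescales Muon gradients for no benefit, which is why the expected behavior below targets the clipping itself.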

Expected behavior

Orthogonalizing / Muon param groups should not be subject to magnitude-based global gradient clipping (for scale-invariant updates it is at best a no-op). Training should then proceed as it does with clip_grad=0: the fixed-batch overfit should drive the loss toward 0.
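One way the expected behavior could look, as a minimal sketch: clip only the non-orthogonalizing groups, using a norm computed over those groups alone, and pass Muon gradients through untouched. The param-group key "orthogonalized" and the helper clip_grads_skip_orthogonalized are hypothetical and are not Megatron-LM's ChainedOptimizer API; the clip formula mirrors the repro's c.

```python
import torch
from torch import nn


def clip_grads_skip_orthogonalized(param_groups, clip_grad: float) -> float:
    # Hypothetical helper (not Megatron-LM's API). Groups flagged with the hypothetical
    # "orthogonalized" key are skipped: Newton–Schulz normalizes each matrix anyway.
    grads = [
        p.grad
        for group in param_groups
        if not group.get("orthogonalized", False)
        for p in group["params"]
        if p.grad is not None
    ]
    if not grads:
        return 0.0
    # Global norm over the remaining (e.g. Adam-managed) gradients only.
    total_norm = torch.linalg.vector_norm(
        torch.stack([torch.linalg.vector_norm(g) for g in grads])
    ).item()
    coeff = clip_grad / (total_norm + 1e-6)  # same formula as the repro's c
    if clip_grad > 0 and coeff < 1.0:  # clip_grad=0 disables clipping
        for g in grads:
            g.mul_(coeff)
    return total_norm


# Usage: a 2D weight driven by Muon and a bias driven by a scalar optimizer.
torch.manual_seed(0)
w = nn.Parameter(torch.randn(64, 64))
bias = nn.Parameter(torch.zeros(64))
w.grad = torch.randn(64, 64) * 1e7  # very large gradient, as in the report
bias.grad = torch.randn(64) * 10.0
w_grad_before = w.grad.clone()

groups = [
    {"params": [w], "orthogonalized": True},
    {"params": [bias], "orthogonalized": False},
]
norm = clip_grads_skip_orthogonalized(groups, clip_grad=1.0)
print(f"non-Muon grad norm before clip: {norm:.1f}")
print("Muon grad untouched:", torch.equal(w.grad, w_grad_before))
print(f"bias grad norm after clip: {bias.grad.norm().item():.4f}")
```

Whether the logged grad_norm should still cover all parameters is a separate choice; this sketch only returns the norm of the groups it actually clips.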

Additional context
