From 0630fb4e8843a7a54c4340e1f3658a012b648e02 Mon Sep 17 00:00:00 2001 From: "David W. Romero" Date: Sun, 17 May 2026 21:06:56 -0700 Subject: [PATCH] fix(callbacks): wrap visualisation forwards in the precision plugin's autocast context MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Lightning's precision plugin only enters its autocast context around the training/validation/test step hooks — it does NOT wrap callbacks. For ``bf16-mixed`` / ``fp16-mixed`` runs the visualisation callbacks therefore called the model in fp32 even though training and ``val/loss`` ran in bf16/fp16. For SIREN-Hyena kernels at moderate-to-large ``ω₀`` the autocast-OFF forward produces output that is several orders of magnitude different from the autocast-ON forward (Δ ≈ 3.7 in prediction magnitude even at init, growing with training). The visible symptom was perfect ``train/loss`` and ``val/loss`` paired with structureless ``val/sequence_1d_grid`` images: the model genuinely solved the task under autocast, but the viz callback was running an unrelated fp32 path that diverged badly. Three call sites patched (``Sequence1DVisualizationCallback``, ``ValidationImageGridCallback``, ``ValidationVolumeGridCallback``): * Enter ``trainer.precision_plugin.forward_context()`` around the ``pl_module(...)`` call so the visualisation matches the model's actual training/inference output. For 32-true precision this resolves to ``nullcontext()`` (no behaviour change). * Cast ``preds.float()`` afterwards because ``torch.Tensor.numpy()`` does not support bf16. Verified on saved diag-baseline checkpoint (``81ni1fah``, ``simple_copy_1d/hyena_blockdiag_film``, val/loss=0.00065): the same checkpoint and batch produce MSE 0.0008 with autocast ON (matches val/loss) and MSE 1.37 with autocast OFF. The patched callback now renders the autocast-ON predictions and the resulting ``val/sequence_1d_grid`` shows clean digit/letter reproductions matching the labels. Co-authored-by: Cursor --- .../callbacks/image_grid_val_visualization.py | 27 ++++++++++++++++--- .../callbacks/sequence_visualization_1d.py | 22 +++++++++++++-- 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/experiments/callbacks/image_grid_val_visualization.py b/experiments/callbacks/image_grid_val_visualization.py index 41e1e530..6493c960 100644 --- a/experiments/callbacks/image_grid_val_visualization.py +++ b/experiments/callbacks/image_grid_val_visualization.py @@ -138,9 +138,22 @@ def _log_image_grid(self, trainer: pl.Trainer, pl_module: pl.LightningModule, ev if condition is not None: condition = condition.to(device) - # Forward pass + # Forward pass. Lightning's precision plugin only wraps + # training_step/validation_step/test_step in autocast — it does NOT + # wrap callbacks. For bf16/fp16-mixed runs the model is trained under + # autocast and the forward output can differ catastrophically between + # autocast-ON and autocast-OFF (see Hyena+SIREN at large ω₀). Enter + # the precision plugin's forward_context() so visualisations reflect + # the actual model output. ``.float()`` afterwards converts any bf16 + # tensors back to fp32 (numpy/matplotlib can't handle bf16). pl_module.eval() - preds = pl_module({"input": x, "condition": condition})["logits"] + forward_context = getattr(trainer.precision_plugin, "forward_context", None) + ctx = forward_context() if callable(forward_context) else torch.amp.autocast( + device_type="cuda" if x.is_cuda else "cpu", enabled=False + ) + with ctx: + preds = pl_module({"input": x, "condition": condition})["logits"] + preds = preds.float() # Convert to NCHW images, supporting flattened inputs. x_nchw = self._as_nchw_images(x) @@ -514,9 +527,15 @@ def _log_volume_grid(self, trainer: pl.Trainer, pl_module: pl.LightningModule, e if condition is not None: condition = condition.to(device) - # Forward pass + # Forward pass — see ValidationImageGridCallback above for rationale. pl_module.eval() - preds = pl_module({"input": x, "condition": condition})["logits"] + forward_context = getattr(trainer.precision_plugin, "forward_context", None) + ctx = forward_context() if callable(forward_context) else torch.amp.autocast( + device_type="cuda" if x.is_cuda else "cpu", enabled=False + ) + with ctx: + preds = pl_module({"input": x, "condition": condition})["logits"] + preds = preds.float() # Limit samples n = min(self.num_samples, x.shape[0]) diff --git a/experiments/callbacks/sequence_visualization_1d.py b/experiments/callbacks/sequence_visualization_1d.py index 51a0c8fe..118fd72b 100644 --- a/experiments/callbacks/sequence_visualization_1d.py +++ b/experiments/callbacks/sequence_visualization_1d.py @@ -223,9 +223,27 @@ def _log_visualization(self, trainer: pl.Trainer, pl_module: pl.LightningModule, if condition is not None: condition = condition.to(device) - # Forward pass + # Forward pass. Lightning's precision plugin only wraps + # ``training_step``/``validation_step``/``test_step`` in autocast — it + # does NOT wrap callbacks like this one. For bf16/fp16-mixed runs the + # model is trained under autocast, and the forward output can differ + # *significantly* (catastrophically, for SIRENs with large ω₀) between + # autocast-ON and autocast-OFF. We therefore explicitly enter the + # precision plugin's forward context so the visualisation reflects what + # the model actually produces during validation/inference. pl_module.eval() - preds = pl_module({"input": x, "condition": condition})["logits"] + forward_context = getattr(trainer.precision_plugin, "forward_context", None) + if callable(forward_context): + ctx = forward_context() + else: + ctx = torch.amp.autocast( + device_type="cuda" if x.is_cuda else "cpu", enabled=False + ) + with ctx: + preds = pl_module({"input": x, "condition": condition})["logits"] + # bf16/fp16 outputs are not supported by ``torch.Tensor.numpy()`` — + # cast back to fp32 before any downstream matplotlib/numpy use. + preds = preds.float() # x: [B, L, C_in], y: [B, segment_length, C_out], preds: [B, segment_length, C_out] n = min(self.num_samples, x.shape[0])