Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions experiments/callbacks/image_grid_val_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,22 @@ def _log_image_grid(self, trainer: pl.Trainer, pl_module: pl.LightningModule, ev
if condition is not None:
condition = condition.to(device)

# Forward pass
# Forward pass. Lightning's precision plugin only wraps
# training_step/validation_step/test_step in autocast — it does NOT
# wrap callbacks. For bf16/fp16-mixed runs the model is trained under
# autocast and the forward output can differ catastrophically between
# autocast-ON and autocast-OFF (see Hyena+SIREN at large ω₀). Enter
# the precision plugin's forward_context() so visualisations reflect
# the actual model output. ``.float()`` afterwards converts any bf16
# tensors back to fp32 (numpy/matplotlib can't handle bf16).
pl_module.eval()
preds = pl_module({"input": x, "condition": condition})["logits"]
forward_context = getattr(trainer.precision_plugin, "forward_context", None)
ctx = forward_context() if callable(forward_context) else torch.amp.autocast(
device_type="cuda" if x.is_cuda else "cpu", enabled=False
)
with ctx:
preds = pl_module({"input": x, "condition": condition})["logits"]
preds = preds.float()

# Convert to NCHW images, supporting flattened inputs.
x_nchw = self._as_nchw_images(x)
Expand Down Expand Up @@ -514,9 +527,15 @@ def _log_volume_grid(self, trainer: pl.Trainer, pl_module: pl.LightningModule, e
if condition is not None:
condition = condition.to(device)

# Forward pass
# Forward pass — see ValidationImageGridCallback above for rationale.
pl_module.eval()
preds = pl_module({"input": x, "condition": condition})["logits"]
forward_context = getattr(trainer.precision_plugin, "forward_context", None)
ctx = forward_context() if callable(forward_context) else torch.amp.autocast(
device_type="cuda" if x.is_cuda else "cpu", enabled=False
)
with ctx:
preds = pl_module({"input": x, "condition": condition})["logits"]
preds = preds.float()

# Limit samples
n = min(self.num_samples, x.shape[0])
Expand Down
22 changes: 20 additions & 2 deletions experiments/callbacks/sequence_visualization_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,27 @@ def _log_visualization(self, trainer: pl.Trainer, pl_module: pl.LightningModule,
if condition is not None:
condition = condition.to(device)

# Forward pass
# Forward pass. Lightning's precision plugin only wraps
# ``training_step``/``validation_step``/``test_step`` in autocast — it
# does NOT wrap callbacks like this one. For bf16/fp16-mixed runs the
# model is trained under autocast, and the forward output can differ
# *significantly* (catastrophically, for SIRENs with large ω₀) between
# autocast-ON and autocast-OFF. We therefore explicitly enter the
# precision plugin's forward context so the visualisation reflects what
# the model actually produces during validation/inference.
pl_module.eval()
preds = pl_module({"input": x, "condition": condition})["logits"]
forward_context = getattr(trainer.precision_plugin, "forward_context", None)
if callable(forward_context):
ctx = forward_context()
else:
ctx = torch.amp.autocast(
device_type="cuda" if x.is_cuda else "cpu", enabled=False
)
with ctx:
preds = pl_module({"input": x, "condition": condition})["logits"]
# bf16/fp16 outputs are not supported by ``torch.Tensor.numpy()`` —
# cast back to fp32 before any downstream matplotlib/numpy use.
preds = preds.float()

# x: [B, L, C_in], y: [B, segment_length, C_out], preds: [B, segment_length, C_out]
n = min(self.num_samples, x.shape[0])
Expand Down
Loading