Commit 28e48b0

Fix NVFP4 QAT mixed precision
**Summary:** This commit adds support for bf16 activations + fp32 weights mixed precision for NVFP4 QAT, which previously threw a dtype assertion error:

```
File "ao/torchao/prototype/qat/nvfp4.py", line 159, in forward
    assert fq.dtype == x.dtype
```

**Test Plan:**

```
python test/quantization/test_qat.py -k test_nvfp4_fake_quantized_linear_mixed_precision
```
1 parent f3342a0 commit 28e48b0
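For context, a minimal standalone sketch (not part of this commit) of the dtype convention the QAT linear is now expected to mirror: under bf16 autocast, a plain `nn.Linear` keeps fp32 weights and fp32 weight gradients while producing bf16 activations, which is what the new test below asserts for `NVFP4FakeQuantizedLinear`.

```python
# Sketch of standard bf16-autocast mixed precision with a plain nn.Linear:
# fp32 weights, bf16 activations, fp32 weight gradients. The fixed NVFP4
# QAT linear follows the same convention (see the new test below).
import torch

if torch.cuda.is_available():
    linear = torch.nn.Linear(128, 512).cuda()            # weights stay fp32
    x = torch.randn(1, 128, dtype=torch.bfloat16).cuda()
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        out = linear(x)                                   # bf16 activations
    out.float().sum().backward()
    print(linear.weight.dtype, out.dtype, linear.weight.grad.dtype)
    # torch.float32 torch.bfloat16 torch.float32
```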

3 files changed, +39 -1 lines changed
test/quantization/test_qat.py

Lines changed: 36 additions & 0 deletions
```diff
@@ -2182,6 +2182,42 @@ def test_qat_nvfp4_training(self, use_per_tensor_scale: bool):
         self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
         self.assertFalse(torch.equal(new_weight, prev_weight))
 
+    @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    def test_nvfp4_fake_quantized_linear_mixed_precision(self):
+        """
+        Test `NVFP4FakeQuantizedLinear` with bf16 input activations and fp32 weights.
+        """
+        from torchao.prototype.qat.nvfp4 import (
+            NVFP4FakeQuantizeConfig,
+            NVFP4FakeQuantizedLinear,
+        )
+
+        activation_dtype = torch.bfloat16
+        weight_dtype = torch.float32
+        linear = torch.nn.Linear(128, 512, dtype=weight_dtype).cuda()
+        activation_config = NVFP4FakeQuantizeConfig(use_per_tensor_scale=True)
+        weight_config = NVFP4FakeQuantizeConfig(use_per_tensor_scale=True)
+        linear = NVFP4FakeQuantizedLinear.from_linear(
+            linear, activation_config, weight_config
+        )
+
+        # simulate 1 step of training
+        optimizer = torch.optim.SGD(linear.parameters())
+        loss_fn = torch.nn.CrossEntropyLoss()
+        target = torch.randn(1, 512).float().cuda()
+        x = torch.randn(1, 128, dtype=activation_dtype).cuda()
+        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+            out = linear(x)
+        self.assertEqual(linear.weight.dtype, weight_dtype)
+        self.assertEqual(x.dtype, activation_dtype)
+        self.assertEqual(out.dtype, activation_dtype)
+        loss = loss_fn(out, target)
+        loss.backward()
+        self.assertEqual(linear.weight.grad.dtype, weight_dtype)
+        optimizer.step()
+        optimizer.zero_grad()
+
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     @unittest.skipIf(
         not _is_fbgemm_gpu_genai_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
```

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -492,7 +492,7 @@ def _addmm_nvfp4_dispatch(
 
     # Add bias after scaling if needed
     if should_add_bias_separately:
-        result = result + bias
+        result = result + bias.to(a._orig_dtype)
 
     return result
 
```
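The one-line cast above sidesteps PyTorch's type promotion rules: in the bf16-activation / fp32-weight case, adding the fp32 bias to the bf16 result would otherwise promote the output back to fp32, so it would no longer match `a._orig_dtype`. A standalone illustration (not torchao code):

```python
# Adding an fp32 bias to a bf16 tensor promotes the result to fp32;
# casting the bias to the activation dtype first keeps the output in bf16.
import torch

result = torch.randn(2, 4, dtype=torch.bfloat16)
bias = torch.randn(4, dtype=torch.float32)

print((result + bias).dtype)                   # torch.float32 (promoted)
print((result + bias.to(result.dtype)).dtype)  # torch.bfloat16 (preserved)
```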

torchao/prototype/qat/nvfp4.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -39,6 +39,7 @@ class _NVFP4QuantizedForwardFakeQuantizedBackward(torch.autograd.Function):
     """
 
     @staticmethod
+    @torch.amp.custom_fwd(device_type="cuda")
     def forward(
         ctx,
         _input: torch.Tensor,
@@ -87,6 +88,7 @@ def forward(
         )
 
     @staticmethod
+    @torch.amp.custom_bwd(device_type="cuda")
     def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
         _input, weight = ctx.saved_tensors
         assert isinstance(_input, NVFP4Tensor)
```
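These two decorators are PyTorch's standard mechanism for making a custom `torch.autograd.Function` cooperate with autocast: `custom_fwd` records the autocast state active during the forward pass and `custom_bwd` re-applies it during the backward pass, so both passes see consistent dtypes. A minimal standalone sketch of the pattern (toy op, not the torchao class; assumes a PyTorch version that supports the `device_type` keyword on `torch.amp.custom_fwd`/`custom_bwd`):

```python
import torch


class _ScaleByTwo(torch.autograd.Function):
    """Toy autocast-aware op: y = 2 * x."""

    @staticmethod
    @torch.amp.custom_fwd(device_type="cuda")
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # Runs under whatever autocast state the caller set up;
        # custom_fwd remembers that state for the backward pass.
        return 2 * x

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        # custom_bwd replays the forward's autocast state here, so any
        # dtype-sensitive logic behaves the same way in both passes.
        return 2 * grad_output


if torch.cuda.is_available():
    x = torch.randn(4, device="cuda", requires_grad=True)  # fp32 leaf
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        y = _ScaleByTwo.apply(x)
    y.sum().backward()
    # Scalar multiplication is not on autocast's cast list, so for this
    # toy op both y and x.grad stay fp32.
    print(y.dtype, x.grad.dtype)
```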
