Commit e4bfca1
authored
Fix NVFP4 QAT mixed precision (#3501)
**Summary:** This commit adds support for bf16 activations +
fp32 weights mixed precision for NVFP4 QAT, which previously
threw a dtype assertion error:
```
File "ao/torchao/prototype/qat/nvfp4.py", line 159, in forward
assert fq.dtype == x.dtype
```
**Test Plan:**
```
python test/quantization/test_qat.py -k test_nvfp4_fake_quantized_linear_mixed_precision
```1 parent 09b18ee commit e4bfca1
File tree
3 files changed
+39
-1
lines changed- test/quantization
- torchao/prototype
- mx_formats
- qat
3 files changed
+39
-1
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2182 | 2182 | | |
2183 | 2183 | | |
2184 | 2184 | | |
| 2185 | + | |
| 2186 | + | |
| 2187 | + | |
| 2188 | + | |
| 2189 | + | |
| 2190 | + | |
| 2191 | + | |
| 2192 | + | |
| 2193 | + | |
| 2194 | + | |
| 2195 | + | |
| 2196 | + | |
| 2197 | + | |
| 2198 | + | |
| 2199 | + | |
| 2200 | + | |
| 2201 | + | |
| 2202 | + | |
| 2203 | + | |
| 2204 | + | |
| 2205 | + | |
| 2206 | + | |
| 2207 | + | |
| 2208 | + | |
| 2209 | + | |
| 2210 | + | |
| 2211 | + | |
| 2212 | + | |
| 2213 | + | |
| 2214 | + | |
| 2215 | + | |
| 2216 | + | |
| 2217 | + | |
| 2218 | + | |
| 2219 | + | |
| 2220 | + | |
2185 | 2221 | | |
2186 | 2222 | | |
2187 | 2223 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
492 | 492 | | |
493 | 493 | | |
494 | 494 | | |
495 | | - | |
| 495 | + | |
496 | 496 | | |
497 | 497 | | |
498 | 498 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
39 | 39 | | |
40 | 40 | | |
41 | 41 | | |
| 42 | + | |
42 | 43 | | |
43 | 44 | | |
44 | 45 | | |
| |||
87 | 88 | | |
88 | 89 | | |
89 | 90 | | |
| 91 | + | |
90 | 92 | | |
91 | 93 | | |
92 | 94 | | |
| |||
0 commit comments