|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import copy |
| 8 | +import unittest |
| 9 | +from contextlib import nullcontext |
| 10 | +from typing import Tuple |
| 11 | + |
| 12 | +import torch |
| 13 | +from torch.testing._internal import common_utils |
| 14 | + |
| 15 | +from torchao.prototype.quantization.float8_static_quant.prototype_float8_tensor import ( |
| 16 | + PrototypeFloat8Tensor, |
| 17 | + _choose_quant_func_and_quantize_tensor, |
| 18 | +) |
| 19 | +from torchao.prototype.quantization.quant_api import ( |
| 20 | + Float8StaticActivationFloat8WeightConfig, |
| 21 | +) |
| 22 | +from torchao.quantization import ( |
| 23 | + Float8DynamicActivationFloat8WeightConfig, |
| 24 | + quantize_, |
| 25 | +) |
| 26 | +from torchao.quantization.granularity import PerRow, PerTensor |
| 27 | +from torchao.quantization.utils import compute_error |
| 28 | +from torchao.testing.utils import TorchAOIntegrationTestCase |
| 29 | +from torchao.utils import ( |
| 30 | + is_sm_at_least_90, |
| 31 | +) |
| 32 | + |
| 33 | + |
| 34 | +# copied from test/quantization/quantize_/workflows/float8/test_float8_tensor.py |
| 35 | +class ToyConvModel(torch.nn.Module): |
| 36 | + def __init__( |
| 37 | + self, dim, in_channels, out_channels, kernel_size, bias, padding, dtype, device |
| 38 | + ): |
| 39 | + super().__init__() |
| 40 | + convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d} |
| 41 | + self.conv = convs[dim]( |
| 42 | + in_channels, |
| 43 | + out_channels, |
| 44 | + kernel_size, |
| 45 | + bias=bias, |
| 46 | + padding=padding, |
| 47 | + dtype=dtype, |
| 48 | + device=device, |
| 49 | + ) |
| 50 | + |
| 51 | + def forward(self, x): |
| 52 | + return self.conv(x) |
| 53 | + |
| 54 | + |
| 55 | +@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 56 | +@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+") |
| 57 | +@common_utils.instantiate_parametrized_tests |
| 58 | +class TestFloat8StaticActivation(TorchAOIntegrationTestCase): |
| 59 | + def setUp(self): |
| 60 | + super().setUp() |
| 61 | + self.dtype = torch.bfloat16 |
| 62 | + torch.manual_seed(42) |
| 63 | + |
| 64 | + @common_utils.parametrize("granularity", [PerRow(), PerTensor()]) |
| 65 | + def test_static_activation_float8_weight(self, granularity): |
| 66 | + """Test that static quantization matches dynamic quantization when using the same scale""" |
| 67 | + torch.compiler.reset() |
| 68 | + |
| 69 | + dtype = torch.bfloat16 |
| 70 | + |
| 71 | + M, N, K = 32, 32, 32 |
| 72 | + input_tensor = torch.randn(M, K, dtype=dtype, device="cuda") |
| 73 | + |
| 74 | + model = torch.nn.Linear(K, N, bias=False).eval().to(device="cuda", dtype=dtype) |
| 75 | + model_static_quant = copy.deepcopy(model) |
| 76 | + model_dynamic_quant = copy.deepcopy(model) |
| 77 | + |
| 78 | + # Apply dynamic quantization |
| 79 | + dynamic_config = Float8DynamicActivationFloat8WeightConfig( |
| 80 | + granularity=granularity, |
| 81 | + ) |
| 82 | + quantize_(model_dynamic_quant, dynamic_config) |
| 83 | + |
| 84 | + dynamic_out_eager = model_dynamic_quant(input_tensor) |
| 85 | + model_dynamic_quant = torch.compile(model_dynamic_quant, fullgraph=True) |
| 86 | + dynamic_out_compile = model_dynamic_quant(input_tensor) |
| 87 | + |
| 88 | + # Get activation scale from dynamic quantization |
| 89 | + float8_input = _choose_quant_func_and_quantize_tensor( |
| 90 | + input_tensor, model_dynamic_quant.weight.act_quant_kwargs |
| 91 | + ) |
| 92 | + # Apply static quantization with the same scale using version 2 |
| 93 | + static_config = Float8StaticActivationFloat8WeightConfig( |
| 94 | + act_quant_scale=float8_input.scale.detach().clone(), |
| 95 | + granularity=granularity, |
| 96 | + version=2, |
| 97 | + ) |
| 98 | + quantize_(model_static_quant, static_config) |
| 99 | + |
| 100 | + # Verify weight is PrototypeFloat8Tensor |
| 101 | + self.assertIsInstance(model_static_quant.weight, PrototypeFloat8Tensor) |
| 102 | + self.assertIsNotNone(model_static_quant.weight.act_quant_scale) |
| 103 | + self.assertIsNotNone(model_static_quant.weight.act_quant_kwargs) |
| 104 | + |
| 105 | + static_out_eager = model_static_quant(input_tensor) |
| 106 | + model_static_quant = torch.compile(model_static_quant, fullgraph=True) |
| 107 | + static_out_compile = model_static_quant(input_tensor) |
| 108 | + |
| 109 | + sqnr_static_vs_dynamic_eager = compute_error( |
| 110 | + dynamic_out_eager, static_out_eager |
| 111 | + ) |
| 112 | + sqnr_static_vs_dynamic_compile = compute_error( |
| 113 | + dynamic_out_compile, static_out_compile |
| 114 | + ) |
| 115 | + self.assertGreater( |
| 116 | + sqnr_static_vs_dynamic_eager, |
| 117 | + 40, |
| 118 | + "SQNR of static v.s. dynamic (eager) should be > 40 dB", |
| 119 | + ) |
| 120 | + self.assertGreater( |
| 121 | + sqnr_static_vs_dynamic_compile, |
| 122 | + 40, |
| 123 | + "SQNR of static v.s. dynamic (compile) should be > 40 dB", |
| 124 | + ) |
| 125 | + |
| 126 | + @common_utils.parametrize("granularity", [PerRow(), PerTensor()]) |
| 127 | + def test_creation_and_attributes(self, granularity): |
| 128 | + """Test tensor creation, dtypes, and attributes""" |
| 129 | + M, N, K = 32, 32, 32 |
| 130 | + dtype = torch.bfloat16 |
| 131 | + |
| 132 | + input_tensor = torch.randn(M, K, dtype=dtype, device="cuda") |
| 133 | + linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda") |
| 134 | + |
| 135 | + # First get a scale from dynamic quantization |
| 136 | + dynamic_config = Float8DynamicActivationFloat8WeightConfig( |
| 137 | + granularity=granularity, |
| 138 | + ) |
| 139 | + model_dynamic = copy.deepcopy(linear) |
| 140 | + quantize_(model_dynamic, dynamic_config) |
| 141 | + |
| 142 | + quantized_input = _choose_quant_func_and_quantize_tensor( |
| 143 | + input_tensor, model_dynamic.weight.act_quant_kwargs |
| 144 | + ) |
| 145 | + |
| 146 | + # Now apply static quantization using version 2 |
| 147 | + static_config = Float8StaticActivationFloat8WeightConfig( |
| 148 | + act_quant_scale=quantized_input.scale.detach().clone(), |
| 149 | + granularity=granularity, |
| 150 | + version=2, |
| 151 | + ) |
| 152 | + quantize_(linear, static_config) |
| 153 | + |
| 154 | + w = linear.weight |
| 155 | + |
| 156 | + # Verify attributes |
| 157 | + self.assertEqual(w.shape, (N, K)) |
| 158 | + self.assertEqual(w.qdata.dtype, torch.float8_e4m3fn) |
| 159 | + self.assertIsInstance(w, PrototypeFloat8Tensor) |
| 160 | + self.assertIsNotNone(w.act_quant_kwargs) |
| 161 | + self.assertIsNotNone(w.act_quant_scale) |
| 162 | + |
| 163 | + # Check scale shape based on granularity |
| 164 | + if isinstance(granularity, PerRow): |
| 165 | + self.assertEqual(w.scale.shape, (N, 1)) |
| 166 | + elif isinstance(granularity, PerTensor): |
| 167 | + self.assertEqual(w.scale.shape, (1, 1)) |
| 168 | + |
| 169 | + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) |
| 170 | + @common_utils.parametrize("compile", [True, False]) |
| 171 | + @common_utils.parametrize("inference_mode", [True, False]) |
| 172 | + # test for 2D/3D conv |
| 173 | + # Inputs are (N, C_in, C_out, (D, H, W), kernel_size or |
| 174 | + # (N, C_in, C_out, (H, W), kernel_size |
| 175 | + @common_utils.parametrize( |
| 176 | + "sizes", |
| 177 | + [ |
| 178 | + (1, 160, 320, (3, 194, 130), 3), |
| 179 | + # Note: kernel_size can't be 1, otherwise |
| 180 | + # the weight will be channels_last even though |
| 181 | + # it's contiguous because of the value of |
| 182 | + # stride |
| 183 | + (1, 320, 640, (96, 64), 3), |
| 184 | + ], |
| 185 | + ) |
| 186 | + def test_fp8_conv_variants( |
| 187 | + self, |
| 188 | + dtype: torch.dtype, |
| 189 | + compile: bool, |
| 190 | + inference_mode: bool, |
| 191 | + sizes: Tuple, |
| 192 | + ): |
| 193 | + torch.compiler.reset() |
| 194 | + granularity = PerTensor() |
| 195 | + N, C_in, C_out, spatial_dims, kernel_size = sizes |
| 196 | + dim = len(spatial_dims) |
| 197 | + convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d} |
| 198 | + assert dim in convs, f"Unsupported dim: {dim}" |
| 199 | + conv_class = convs[dim] |
| 200 | + _is_conv = lambda m, fqn: isinstance(m, conv_class) |
| 201 | + |
| 202 | + input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device="cuda") |
| 203 | + |
| 204 | + model = ToyConvModel( |
| 205 | + dim, |
| 206 | + C_in, |
| 207 | + C_out, |
| 208 | + kernel_size, |
| 209 | + bias=False, |
| 210 | + padding=0, |
| 211 | + dtype=dtype, |
| 212 | + device="cuda", |
| 213 | + ).eval() |
| 214 | + |
| 215 | + channels_last_memory_format = ( |
| 216 | + torch.channels_last_3d if dim == 3 else torch.channels_last |
| 217 | + ) |
| 218 | + input_tensor = input_tensor.to(memory_format=channels_last_memory_format) |
| 219 | + model = model.to(memory_format=channels_last_memory_format) |
| 220 | + |
| 221 | + quantized_model = copy.deepcopy(model) |
| 222 | + |
| 223 | + dynamic_config = Float8DynamicActivationFloat8WeightConfig( |
| 224 | + granularity=granularity, |
| 225 | + ) |
| 226 | + model_dynamic_quant = copy.deepcopy(model) |
| 227 | + quantize_(model_dynamic_quant, dynamic_config, filter_fn=_is_conv) |
| 228 | + # Get activation scale from dynamic quantization |
| 229 | + tmp_input_tensor = _choose_quant_func_and_quantize_tensor( |
| 230 | + input_tensor.clone(), model_dynamic_quant.conv.weight.act_quant_kwargs |
| 231 | + ) |
| 232 | + config = Float8StaticActivationFloat8WeightConfig( |
| 233 | + act_quant_scale=tmp_input_tensor.scale.detach().clone(), |
| 234 | + granularity=granularity, |
| 235 | + version=2, |
| 236 | + ) |
| 237 | + quantize_(quantized_model, config, filter_fn=_is_conv) |
| 238 | + |
| 239 | + if compile: |
| 240 | + quantized_model = torch.compile(quantized_model, fullgraph=True) |
| 241 | + |
| 242 | + inference_mode_ctx = torch.inference_mode() if inference_mode else nullcontext() |
| 243 | + with inference_mode_ctx: |
| 244 | + output_original = model(input_tensor) |
| 245 | + output_quantized = quantized_model(input_tensor) |
| 246 | + |
| 247 | + error = compute_error(output_original, output_quantized) |
| 248 | + assert compute_error(output_original, output_quantized) > 20, ( |
| 249 | + f"Quantization error is too high got a SQNR of {error}" |
| 250 | + ) |
| 251 | + |
| 252 | + |
| 253 | +if __name__ == "__main__": |
| 254 | + common_utils.run_tests() |
0 commit comments