diff --git a/backends/nxp/aten_passes/neutron_aten_pass_manager.py b/backends/nxp/aten_passes/neutron_aten_pass_manager.py
index 35205c76c68..f365761fac2 100644
--- a/backends/nxp/aten_passes/neutron_aten_pass_manager.py
+++ b/backends/nxp/aten_passes/neutron_aten_pass_manager.py
@@ -36,21 +36,30 @@
 PassType = type[Callable[[torch.fx.GraphModule], PassResult]]
 
 
+def _get_default_passes(neutron_target_spec, qat_mode: bool = False) -> list[PassType]:
+    passes = [
+        SplitGroupConvolution(),
+        SplitGRUBasedOnNumLayers(),
+        RemoveNodesWithKnownOutputs(),
+        FuseLinearAndAddPass(),
+        MoveActivationBeforeConcat(neutron_target_spec),
+    ]
+
+    if not qat_mode:
+        # In QAT mode, fusing must happen after training to preserve the
+        # batch norm statistics update mechanism.
+        passes.append(FuseBatchNormWithConvPass())
+        passes.append(FuseBatchNormWithLinearPass())
+
+    return passes
+
+
 class NeutronAtenPassManager(PassManager):
     def __init__(
         self, neutron_target_spec: NeutronTargetSpec, passes: list[PassType] = None
     ):
-        passes: list[PassType] = passes or [
-            FuseBatchNormWithConvPass(),
-            FuseBatchNormWithLinearPass(),
-            SplitGroupConvolution(),
-            SplitGRUBasedOnNumLayers(),
-            RemoveNodesWithKnownOutputs(),
-            FuseLinearAndAddPass(),
-            MoveActivationBeforeConcat(neutron_target_spec),
-        ]
-
+        passes: list[PassType] = passes or _get_default_passes(neutron_target_spec)
         super().__init__(passes)
 
     def __call__(self, module: nn.Module) -> PassResult:
diff --git a/backends/nxp/quantizer/neutron_quantizer.py b/backends/nxp/quantizer/neutron_quantizer.py
index b9186884d5e..104a521247a 100644
--- a/backends/nxp/quantizer/neutron_quantizer.py
+++ b/backends/nxp/quantizer/neutron_quantizer.py
@@ -6,6 +6,7 @@
 
 import torch
 from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
+    _get_default_passes,
     NeutronAtenPassManager,
 )
 
@@ -17,6 +18,7 @@
     AddmmPattern,
     AddTensorPattern,
     AvgPoolPattern,
+    BatchNormPattern,
     CatPattern,
     Conv1dPattern,
     Conv2dPattern,
@@ -245,6 +247,7 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec, is_qat: bool = False)
             OpQuantizer(AddTensorPattern(is_qat=is_qat), static_qconfig),
             OpQuantizer(AddmmPattern(self, is_qat=is_qat), static_fc_qconfig),
             OpQuantizer(AvgPoolPattern(is_qat=is_qat), static_qconfig),
+            OpQuantizer(BatchNormPattern(is_qat=is_qat), static_qconfig),
             OpQuantizer(CatPattern(is_qat=is_qat), static_qconfig),
             OpQuantizer(Conv1dPattern(is_qat=is_qat), static_qconfig),
             OpQuantizer(Conv2dPattern(self, is_qat=is_qat), static_qconfig),
@@ -293,7 +296,12 @@ def transform_for_annotation(
     ) -> torch.fx.GraphModule:
         model.graph.eliminate_dead_code()  # Remove dead code to simplify the graph for the passes.
 
-        model = NeutronAtenPassManager(self.neutron_target_spec)(model).graph_module
+        pass_manager = NeutronAtenPassManager(
+            self.neutron_target_spec,
+            _get_default_passes(self.neutron_target_spec, self.is_qat),
+        )
+
+        model = pass_manager(model).graph_module
 
         model.graph.eliminate_dead_code()  # Remove dead code again, in case it was created by the passes.
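For reference, a minimal sketch of how the QAT-aware pass selection above is meant to be wired up. It uses only names introduced in these hunks; `spec` and `make_pass_manager` are hypothetical placeholders, not part of the change:

```python
from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
    NeutronAtenPassManager,
    _get_default_passes,
)

def make_pass_manager(spec, is_qat: bool) -> NeutronAtenPassManager:
    # With qat_mode=True the two batch-norm fusing passes are left out, so
    # batch norm statistics keep updating during training; fusing is expected
    # to happen after training instead.
    return NeutronAtenPassManager(spec, _get_default_passes(spec, qat_mode=is_qat))
```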
diff --git a/backends/nxp/quantizer/patterns.py b/backends/nxp/quantizer/patterns.py
index e8f247d4bbc..07cb5274eab 100644
--- a/backends/nxp/quantizer/patterns.py
+++ b/backends/nxp/quantizer/patterns.py
@@ -153,6 +153,27 @@ def get_anchors(
     )
 
 
+class BatchNormPattern(QuantizationPattern):
+    def __init__(self, is_qat: bool):
+        super().__init__(is_qat=is_qat)
+
+    def partition_types(self) -> list[OpOverload]:
+        # BatchNorm quantization is only needed in QAT mode.
+        return [torch.ops.aten.batch_norm.default] if self.is_qat else []
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
+    ) -> PartitionAnchors | None:
+        node = fused_partition[0].nodes[-1]
+
+        return PartitionAnchors(
+            inputs=[],
+            weights=[],
+            biases=[],
+            output=[(node,)],
+        )
+
+
 def get_anchors_for_fixed_quant_specs(
     fused_partition: list[fx.GraphModule],
     scale: float,
@@ -356,6 +377,14 @@ def get_anchors(
     )
 
 
+def _is_batch_norm(node_: Node) -> bool:
+    return node_.op == "call_function" and node_.target in [
+        torch.ops.aten.batch_norm.default,
+        torch.ops.aten.native_batch_norm.default,
+        torch.ops.aten._native_batch_norm_legit_no_training.default,
+    ]
+
+
 class ConvPattern(QuantizationPattern):
     @abstractmethod
     def partition_types(self) -> list[OpOverload]:
@@ -398,11 +427,20 @@ def get_anchors(
         if len(conv_node.args) > 2 and conv_node.args[2] is not None:
             bias = [(conv_node, NodeArgsIdx(2), bias_quantization_qspec)]
 
+        output_specs = [(conv_node,)]
+        # For QAT to be numerically correct, there must be no quantization
+        # between the convolution node and the batch norm node.
+        if self.is_qat:
+            conv_users = conv_node.users
+            possibly_bn = list(conv_users.keys())[0] if len(conv_users) == 1 else None
+            if possibly_bn and _is_batch_norm(possibly_bn):
+                output_specs = []
+
         return PartitionAnchors(
             inputs=[(conv_node, NodeArgsIdx(0))],
             weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)],
             biases=bias,
-            output=[(conv_node,)],
+            output=output_specs,
         )
 
@@ -479,6 +517,14 @@ def get_anchors(
             output = []
             activation.meta["quantization_annotation"].input_qspec_map = {}
 
+        # For QAT to be numerically correct, there must be no quantization
+        # between the convolution node and the batch norm node.
+        if self.is_qat:
+            conv_users = conv_node.users
+            possibly_bn = list(conv_users.keys())[0] if len(conv_users) == 1 else None
+            if possibly_bn and _is_batch_norm(possibly_bn):
+                output = []
+
         return PartitionAnchors(
             inputs=[(conv_node, NodeArgsIdx(0))],
             weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)],
@@ -524,11 +570,20 @@ def get_anchors(
         if len(conv_node.args) > 2 and conv_node.args[2] is not None:
             bias = [(conv_node, NodeArgsIdx(2), bias_quantization_qspec)]
 
+        output_specs = [(conv_node,)]
+        # For QAT to be numerically correct, there must be no quantization
+        # between the convolution node and the batch norm node.
+        if self.is_qat:
+            conv_users = conv_node.users
+            possibly_bn = list(conv_users.keys())[0] if len(conv_users) == 1 else None
+            if possibly_bn and _is_batch_norm(possibly_bn):
+                output_specs = []
+
         return PartitionAnchors(
             inputs=[(conv_node, NodeArgsIdx(0))],
             weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)],
             biases=bias,
-            output=[(conv_node,)],
+            output=output_specs,
         )
diff --git a/backends/nxp/tests/models.py b/backends/nxp/tests/models.py
index e2b41aab8de..3d1992a6c40 100644
--- a/backends/nxp/tests/models.py
+++ b/backends/nxp/tests/models.py
@@ -457,6 +457,30 @@ def forward(self, x):
         return self.pool(x)
 
 
+class ConvBNModule(torch.nn.Module):
+    def __init__(self, conv_module, conv_bias, bn_affine):
+        super().__init__()
+
+        if conv_module == "conv1d":
+            self.conv = torch.nn.Conv1d(3, 64, 3, padding=1, bias=conv_bias)
+            self.bn = torch.nn.BatchNorm1d(64, affine=bn_affine)
+        elif conv_module == "conv2d":
+            self.conv = torch.nn.Conv2d(3, 64, 3, padding=1, bias=conv_bias)
+            self.bn = torch.nn.BatchNorm2d(64, affine=bn_affine)
+        elif conv_module == "conv1d_t":
+            self.conv = torch.nn.ConvTranspose1d(3, 64, 3, padding=1, bias=conv_bias)
+            self.bn = torch.nn.BatchNorm1d(64, affine=bn_affine)
+        elif conv_module == "conv2d_t":
+            self.conv = torch.nn.ConvTranspose2d(3, 64, 3, padding=1, bias=conv_bias)
+            self.bn = torch.nn.BatchNorm2d(64, affine=bn_affine)
+        else:
+            raise ValueError(f"Unknown conv_module: {conv_module}")
+
+    def forward(self, x):
+        x = self.conv(x)
+        return self.bn(x)
+
+
 class MulTensorModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/backends/nxp/tests/test_quantizer.py b/backends/nxp/tests/test_quantizer.py
index 27422f9ce1e..5d2e7ce033a 100644
--- a/backends/nxp/tests/test_quantizer.py
+++ b/backends/nxp/tests/test_quantizer.py
@@ -636,3 +636,51 @@ def test_qat_produces_same_graph_as_ptq():
            qat_quantized_model.graph.nodes, ptq_quantized_model.graph.nodes
        )
    )
+
+
+# TODO: conv1d_t is currently unsupported; add it once resolved.
+@pytest.mark.parametrize("conv_module", ["conv1d", "conv2d", "conv2d_t"])
+@pytest.mark.parametrize("conv_bias", [True, False])
+@pytest.mark.parametrize("bn_affine", [True, False])
+def test_torchao_native_conv_bn_qat_fusing(conv_module, conv_bias, bn_affine):
+    if not conv_bias:
+        pytest.skip("Conv without bias is not supported.")
+
+    if conv_module.startswith("conv1d"):
+        input_shape = (1, 3, 32)
+    elif conv_module.startswith("conv2d"):
+        input_shape = (1, 3, 32, 32)
+
+    model = models.ConvBNModule(
+        conv_module=conv_module,
+        conv_bias=conv_bias,
+        bn_affine=bn_affine,
+    )
+    model.eval()
+
+    exported_model = export(model, (torch.randn(*input_shape),), strict=True)
+    prepared_model = _prepare_for_quantization(exported_model, is_qat=True)
+    quantized_model = convert_pt2e(prepared_model)
+
+    def is_conv(node):
+        return node.op == "call_function" and node.target in [
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.conv2d.default,
+            torch.ops.aten.conv_transpose2d.input,
+        ]
+
+    graph_nodes = list(quantized_model.graph.nodes)
+    conv_node = next(n for n in graph_nodes if is_conv(n))
+    conv_node_args = conv_node.args
+
+    if len(conv_node_args) > 3:
+        conv_node_args = conv_node_args[:3]
+
+    assert len([n for n in graph_nodes if "batch_norm" in n.name]) == 0
+    assert (
+        len(conv_node.users) == 1
+        and list(conv_node.users.keys())[0].target
+        == torch.ops.quantized_decomposed.quantize_per_tensor.default
+    )
+    assert all(arg.name.startswith("dequantize") for arg in conv_node_args)
+    assert len(graph_nodes) == 15
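The test above drives the standard PT2E QAT flow; a condensed sketch of that flow follows. `_prepare_for_quantization` is a helper local to the test file and is assumed here to wrap `prepare_qat_pt2e`; the entry points shown are the public `torch.ao.quantization.quantize_pt2e` ones, and `quantize_conv_bn_qat` is a hypothetical wrapper for illustration:

```python
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
from torch.export import export

def quantize_conv_bn_qat(model, example_input, quantizer):
    # Export to an ATen-level graph, then insert fake-quant/observers for QAT.
    exported = export(model, (example_input,), strict=True)
    prepared = prepare_qat_pt2e(exported.module(), quantizer)
    prepared(example_input)  # run once so the fake-quant modules see data
    # convert_pt2e folds conv + batch norm into a single quantized conv,
    # which is why the test expects no batch_norm node in the final graph.
    return convert_pt2e(prepared)
```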
diff --git a/examples/nxp/aot_neutron_compile.py b/examples/nxp/aot_neutron_compile.py
index 175dc9d8d70..66e213e7044 100644
--- a/examples/nxp/aot_neutron_compile.py
+++ b/examples/nxp/aot_neutron_compile.py
@@ -149,6 +149,13 @@ def get_model_and_inputs_from_name(model_name: str):
         default=False,
         help="Produce a quantized model",
     )
+    parser.add_argument(
+        "--use_qat",
+        action="store_true",
+        required=False,
+        default=False,
+        help="Use QAT mode for quantization (does not include QAT training)",
+    )
     parser.add_argument(
         "-s",
         "--so_library",
@@ -218,8 +225,10 @@ def get_model_and_inputs_from_name(model_name: str):
                 "No calibration inputs available, using the example inputs instead"
             )
             calibration_inputs = example_inputs
-        quantizer = NeutronQuantizer(neutron_target_spec)
-        module = calibrate_and_quantize(module, calibration_inputs, quantizer)
+        quantizer = NeutronQuantizer(neutron_target_spec, args.use_qat)
+        module = calibrate_and_quantize(
+            module, calibration_inputs, quantizer, is_qat=args.use_qat
+        )
 
     if args.so_library is not None:
         logging.debug(f"Loading libraries: {args.so_library}")
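A hedged sketch of what the new flag changes in the example script, using only names visible in the hunks above; how `calibrate_and_quantize` uses `is_qat` internally is an assumption, not confirmed by this diff:

```python
# Equivalent of passing --use_qat on the command line.
use_qat = True

quantizer = NeutronQuantizer(neutron_target_spec, use_qat)
# is_qat presumably selects the QAT prepare path inside the helper; per the
# flag's help text, it only annotates the model for QAT and runs no training.
module = calibrate_and_quantize(module, calibration_inputs, quantizer, is_qat=use_qat)
```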