Skip to content

Commit 33cb6e4

Browse files
committed
NXP backend: Add tests for conv+bn fusing in QAT
1 parent db15dba commit 33cb6e4

File tree

2 files changed

+70
-0
lines changed

2 files changed

+70
-0
lines changed

backends/nxp/tests/models.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,28 @@ def forward(self, x):
457457
return self.pool(x)
458458

459459

460+
class ConvBNModule(torch.nn.Module):
461+
def __init__(self, conv_module, conv_bias, bn_affine):
462+
super().__init__()
463+
464+
if conv_module == "conv1d":
465+
self.conv = torch.nn.Conv1d(3, 64, 3, padding=1, bias=conv_bias)
466+
self.bn = torch.nn.BatchNorm1d(64, affine=bn_affine)
467+
elif conv_module == "conv2d":
468+
self.conv = torch.nn.Conv2d(3, 64, 3, padding=1, bias=conv_bias)
469+
self.bn = torch.nn.BatchNorm2d(64, affine=bn_affine)
470+
elif conv_module == "conv1d_t":
471+
self.conv = torch.nn.ConvTranspose1d(3, 64, 3, padding=1, bias=conv_bias)
472+
self.bn = torch.nn.BatchNorm1d(64, affine=bn_affine)
473+
elif conv_module == "conv2d_t":
474+
self.conv = torch.nn.ConvTranspose2d(3, 64, 3, padding=1, bias=conv_bias)
475+
self.bn = torch.nn.BatchNorm2d(64, affine=bn_affine)
476+
477+
def forward(self, x):
478+
x = self.conv(x)
479+
return self.bn(x)
480+
481+
460482
class MulTensorModule(torch.nn.Module):
461483
def __init__(self):
462484
super().__init__()

backends/nxp/tests/test_quantizer.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,3 +636,51 @@ def test_qat_produces_same_graph_as_ptq():
636636
qat_quantized_model.graph.nodes, ptq_quantized_model.graph.nodes
637637
)
638638
)
639+
640+
641+
# TODO: conv1d_t is currently unsupported, add when resolved
642+
@pytest.mark.parametrize("conv_module", ["conv1d", "conv2d", "conv2d_t"])
643+
@pytest.mark.parametrize("conv_bias", [True, False])
644+
@pytest.mark.parametrize("bn_affine", [True, False])
645+
def test_torchao_native_conv_bn_qat_fusing(conv_module, conv_bias, bn_affine):
646+
if not conv_bias:
647+
pytest.skip("Conv without bias is not supported.")
648+
649+
if conv_module.startswith("conv1d"):
650+
input_shape = (1, 3, 32)
651+
elif conv_module.startswith("conv2d"):
652+
input_shape = (1, 3, 32, 32)
653+
654+
model = models.ConvBNModule(
655+
conv_module=conv_module,
656+
conv_bias=conv_bias,
657+
bn_affine=bn_affine,
658+
)
659+
model.eval()
660+
661+
exported_model = export(model, (torch.randn(*input_shape),), strict=True)
662+
prepared_model = _prepare_for_quantization(exported_model, is_qat=True)
663+
quantized_model = convert_pt2e(prepared_model)
664+
665+
def is_conv(node):
666+
return node.op == "call_function" and node.target in [
667+
torch.ops.aten.conv1d.default,
668+
torch.ops.aten.conv2d.default,
669+
torch.ops.aten.conv_transpose2d.input,
670+
]
671+
672+
graph_nodes = list(quantized_model.graph.nodes)
673+
conv_node = next(n for n in graph_nodes if is_conv(n))
674+
conv_node_args = conv_node.args
675+
676+
if len(conv_node_args) > 3:
677+
conv_node_args = conv_node_args[:3]
678+
679+
assert len([n for n in graph_nodes if "batch_norm" in n.name]) == 0
680+
assert (
681+
len(conv_node.users) == 1
682+
and list(conv_node.users.keys())[0].target
683+
== torch.ops.quantized_decomposed.quantize_per_tensor.default
684+
)
685+
assert all(arg.name.startswith("dequantize") for arg in conv_node_args)
686+
assert len(graph_nodes) == 15

0 commit comments

Comments
 (0)