Skip to content

Commit 4e839f2

Browse files
committed
NXP backend: Remove conv output quantization annotation if followed by BN
1 parent 78ef065 commit 4e839f2

File tree

1 file changed

+36
-2
lines changed

1 file changed

+36
-2
lines changed

backends/nxp/quantizer/patterns.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,14 @@ def get_anchors(
377377
)
378378

379379

380+
def _is_batch_norm(node_: Node) -> bool:
381+
return node_.op == "call_function" and node_.target in [
382+
torch.ops.aten.batch_norm.default,
383+
torch.ops.aten.native_batch_norm.default,
384+
torch.ops.aten._native_batch_norm_legit_no_training.default,
385+
]
386+
387+
380388
class ConvPattern(QuantizationPattern):
381389
@abstractmethod
382390
def partition_types(self) -> list[OpOverload]:
@@ -419,11 +427,20 @@ def get_anchors(
419427
if len(conv_node.args) > 2 and conv_node.args[2] is not None:
420428
bias = [(conv_node, NodeArgsIdx(2), bias_quantization_qspec)]
421429

430+
output_specs = [(conv_node,)]
431+
# In order for QAT to be numerically correct, there should be no quantization between
432+
# convolution node and batch norm node.
433+
if self.is_qat:
434+
conv_users = conv_node.users
435+
possibly_bn = list(conv_users.keys())[0] if len(conv_users) == 1 else None
436+
if possibly_bn and _is_batch_norm(possibly_bn):
437+
output_specs = []
438+
422439
return PartitionAnchors(
423440
inputs=[(conv_node, NodeArgsIdx(0))],
424441
weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)],
425442
biases=bias,
426-
output=[(conv_node,)],
443+
output=output_specs,
427444
)
428445

429446

@@ -500,6 +517,14 @@ def get_anchors(
500517
output = []
501518
activation.meta["quantization_annotation"].input_qspec_map = {}
502519

520+
# In order for QAT to be numerically correct, there should be no quantization between
521+
# convolution node and batch norm node.
522+
if self.is_qat:
523+
conv_users = conv_node.users
524+
possibly_bn = list(conv_users.keys())[0] if len(conv_users) == 1 else None
525+
if possibly_bn and _is_batch_norm(possibly_bn):
526+
output = []
527+
503528
return PartitionAnchors(
504529
inputs=[(conv_node, NodeArgsIdx(0))],
505530
weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)],
@@ -545,11 +570,20 @@ def get_anchors(
545570
if len(conv_node.args) > 2 and conv_node.args[2] is not None:
546571
bias = [(conv_node, NodeArgsIdx(2), bias_quantization_qspec)]
547572

573+
output_specs = [(conv_node,)]
574+
# In order for QAT to be numerically correct, there should be no quantization between
575+
# convolution node and batch norm node.
576+
if self.is_qat:
577+
conv_users = conv_node.users
578+
possibly_bn = list(conv_users.keys())[0] if len(conv_users) == 1 else None
579+
if possibly_bn and _is_batch_norm(possibly_bn):
580+
output_specs = []
581+
548582
return PartitionAnchors(
549583
inputs=[(conv_node, NodeArgsIdx(0))],
550584
weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)],
551585
biases=bias,
552-
output=[(conv_node,)],
586+
output=output_specs,
553587
)
554588

555589

0 commit comments

Comments
 (0)