@@ -377,6 +377,14 @@ def get_anchors(
377377 )
378378
379379
380+ def _is_batch_norm (node_ : Node ) -> bool :
381+ return node_ .op == "call_function" and node_ .target in [
382+ torch .ops .aten .batch_norm .default ,
383+ torch .ops .aten .native_batch_norm .default ,
384+ torch .ops .aten ._native_batch_norm_legit_no_training .default ,
385+ ]
386+
387+
380388class ConvPattern (QuantizationPattern ):
381389 @abstractmethod
382390 def partition_types (self ) -> list [OpOverload ]:
@@ -419,11 +427,20 @@ def get_anchors(
419427 if len (conv_node .args ) > 2 and conv_node .args [2 ] is not None :
420428 bias = [(conv_node , NodeArgsIdx (2 ), bias_quantization_qspec )]
421429
430+ output_specs = [(conv_node ,)]
431+ # In order for QAT to be numerically correct, there should be no quantization between
432+ # convolution node and batch norm node.
433+ if self .is_qat :
434+ conv_users = conv_node .users
435+ possibly_bn = list (conv_users .keys ())[0 ] if len (conv_users ) == 1 else None
436+ if possibly_bn and _is_batch_norm (possibly_bn ):
437+ output_specs = []
438+
422439 return PartitionAnchors (
423440 inputs = [(conv_node , NodeArgsIdx (0 ))],
424441 weights = [(conv_node , NodeArgsIdx (1 ), weight_quantization_spec )],
425442 biases = bias ,
426- output = [( conv_node ,)] ,
443+ output = output_specs ,
427444 )
428445
429446
@@ -500,6 +517,14 @@ def get_anchors(
500517 output = []
501518 activation .meta ["quantization_annotation" ].input_qspec_map = {}
502519
520+ # In order for QAT to be numerically correct, there should be no quantization between
521+ # convolution node and batch norm node.
522+ if self .is_qat :
523+ conv_users = conv_node .users
524+ possibly_bn = list (conv_users .keys ())[0 ] if len (conv_users ) == 1 else None
525+ if possibly_bn and _is_batch_norm (possibly_bn ):
526+ output = []
527+
503528 return PartitionAnchors (
504529 inputs = [(conv_node , NodeArgsIdx (0 ))],
505530 weights = [(conv_node , NodeArgsIdx (1 ), weight_quantization_spec )],
@@ -545,11 +570,20 @@ def get_anchors(
545570 if len (conv_node .args ) > 2 and conv_node .args [2 ] is not None :
546571 bias = [(conv_node , NodeArgsIdx (2 ), bias_quantization_qspec )]
547572
573+ output_specs = [(conv_node ,)]
574+ # In order for QAT to be numerically correct, there should be no quantization between
575+ # convolution node and batch norm node.
576+ if self .is_qat :
577+ conv_users = conv_node .users
578+ possibly_bn = list (conv_users .keys ())[0 ] if len (conv_users ) == 1 else None
579+ if possibly_bn and _is_batch_norm (possibly_bn ):
580+ output_specs = []
581+
548582 return PartitionAnchors (
549583 inputs = [(conv_node , NodeArgsIdx (0 ))],
550584 weights = [(conv_node , NodeArgsIdx (1 ), weight_quantization_spec )],
551585 biases = bias ,
552- output = [( conv_node ,)] ,
586+ output = output_specs ,
553587 )
554588
555589
0 commit comments