@@ -636,3 +636,51 @@ def test_qat_produces_same_graph_as_ptq():
636636 qat_quantized_model .graph .nodes , ptq_quantized_model .graph .nodes
637637 )
638638 )
639+
640+
641+ # TODO: conv1d_t is currently unsupported, add when resolved
642+ @pytest .mark .parametrize ("conv_module" , ["conv1d" , "conv2d" , "conv2d_t" ])
643+ @pytest .mark .parametrize ("conv_bias" , [True , False ])
644+ @pytest .mark .parametrize ("bn_affine" , [True , False ])
645+ def test_torchao_native_conv_bn_qat_fusing (conv_module , conv_bias , bn_affine ):
646+ if not conv_bias :
647+ pytest .skip ("Conv without bias is not supported." )
648+
649+ if conv_module .startswith ("conv1d" ):
650+ input_shape = (1 , 3 , 32 )
651+ elif conv_module .startswith ("conv2d" ):
652+ input_shape = (1 , 3 , 32 , 32 )
653+
654+ model = models .ConvBNModule (
655+ conv_module = conv_module ,
656+ conv_bias = conv_bias ,
657+ bn_affine = bn_affine ,
658+ )
659+ model .eval ()
660+
661+ exported_model = export (model , (torch .randn (* input_shape ),), strict = True )
662+ prepared_model = _prepare_for_quantization (exported_model , is_qat = True )
663+ quantized_model = convert_pt2e (prepared_model )
664+
665+ def is_conv (node ):
666+ return node .op == "call_function" and node .target in [
667+ torch .ops .aten .conv1d .default ,
668+ torch .ops .aten .conv2d .default ,
669+ torch .ops .aten .conv_transpose2d .input ,
670+ ]
671+
672+ graph_nodes = list (quantized_model .graph .nodes )
673+ conv_node = next (n for n in graph_nodes if is_conv (n ))
674+ conv_node_args = conv_node .args
675+
676+ if len (conv_node_args ) > 3 :
677+ conv_node_args = conv_node_args [:3 ]
678+
679+ assert len ([n for n in graph_nodes if "batch_norm" in n .name ]) == 0
680+ assert (
681+ len (conv_node .users ) == 1
682+ and list (conv_node .users .keys ())[0 ].target
683+ == torch .ops .quantized_decomposed .quantize_per_tensor .default
684+ )
685+ assert all (arg .name .startswith ("dequantize" ) for arg in conv_node_args )
686+ assert len (graph_nodes ) == 15
0 commit comments