Commit 868cd7e
deprecate v1 of IntxWeightOnlyConfig

Summary: deprecate v1 of `IntxWeightOnlyConfig` and delete all callsites

Test Plan: CI

ghstack-source-id: 0c5c9b6
ghstack-comment-id: 3670602310
Pull-Request: #3512
1 parent: 3b29890

1 file changed: torchao/quantization/quant_api.py (18 additions, 57 deletions)
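The v1 path is deleted outright rather than shimmed, so callers that passed the v1-only `layout` argument must drop it. A minimal migration sketch follows; the `torchao.quantization` import path and the use of `torch.int4` (recent PyTorch) are assumptions for illustration, not shown in this diff:

```python
import torch
from torchao.quantization import IntxWeightOnlyConfig

# Before (v1, removed by this commit): packing was chosen via a Layout object,
# e.g. IntxWeightOnlyConfig(weight_dtype=torch.int4, layout=QDQLayout())

# After (v2): packing is chosen via `intx_packing_format`; per the diff below,
# IntxPackingFormat.UNPACKED_TO_INT8 is the default and the only format the
# remaining code path accepts.
config = IntxWeightOnlyConfig(weight_dtype=torch.int4)
```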
@@ -40,7 +40,6 @@
     Int4XPULayout,
     MarlinSparseLayout,
     PlainLayout,
-    QDQLayout,
     SemiSparseLayout,
     TensorCoreTiledLayout,
     to_affine_quantized_floatx,
@@ -1704,26 +1703,15 @@ class IntxWeightOnlyConfig(AOBaseConfig):
         `mapping_type`: The type of mapping to use for the weight quantization.
             Must be one of MappingType.ASYMMETRIC or MappingType.SYMMETRIC.
         `scale_dtype`: The dtype to use for the weight scale.
-        `layout`: The layout to use for the packed weight tensor:
-            - QDQLayout: this layout represents the quantization with Q/DQ quant primitives,
-              and is intended for export applications like ExecuTorch.
         `intx_packing_format`: The format to use for the packed weight tensor (version 2 only).
         `intx_choose_qparams_algorithm`: The algorithm to use for choosing the quantization parameters.
         `version`: version of the config to use; only a subset of the above args is valid for each version, see note for more details.
-
-    Note:
-
-    Current state for IntxWeightOnlyConfig is that it supports both v1 (legacy) and v2.
-
-    * `intx_packing_format` is used for version 2.
-    * `layout` is only used for version 1.
     """
 
     weight_dtype: torch.dtype = torch.int8
     granularity: Granularity = PerAxis(0)
     mapping_type: MappingType = MappingType.SYMMETRIC
     scale_dtype: Optional[torch.dtype] = None
-    layout: Layout = QDQLayout()
     intx_packing_format: IntxPackingFormat = IntxPackingFormat.UNPACKED_TO_INT8
     intx_choose_qparams_algorithm: IntxChooseQParamsAlgorithm = (
         IntxChooseQParamsAlgorithm.AFFINE
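For reference, a hedged sketch of a fully spelled-out v2 config using the fields that survive this hunk; the `PerGroup` and `MappingType` import paths are assumptions, not part of this diff:

```python
import torch
from torchao.quantization import IntxWeightOnlyConfig
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_primitives import MappingType

config = IntxWeightOnlyConfig(
    weight_dtype=torch.int8,             # default per the diff
    granularity=PerGroup(32),            # default is PerAxis(0)
    mapping_type=MappingType.SYMMETRIC,  # default
    scale_dtype=None,                    # None keeps scales in the weight's dtype
)
```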
@@ -1762,7 +1750,6 @@ def _intx_weight_only_quantize_tensor(
     granularity = config.granularity
     mapping_type = config.mapping_type
     scale_dtype = config.scale_dtype
-    layout = config.layout
     intx_packing_format = config.intx_packing_format
     intx_choose_qparams_algorithm = config.intx_choose_qparams_algorithm
 
@@ -1793,51 +1780,25 @@ def _intx_weight_only_quantize_tensor(
         assert weight.dim() == 4
         block_size = (1, group_size, 1, 1)
 
-    if config.version == 2:
-        if config.intx_packing_format == IntxPackingFormat.UNPACKED_TO_INT8:
-            if custom_zero_point is not None and custom_zero_point.dtype == torch.int32:
-                custom_zero_point = custom_zero_point.to(torch.int8)
-            new_weight = IntxUnpackedToInt8Tensor.from_hp(
-                weight,
-                block_size,
-                weight_dtype,
-                mapping_type=mapping_type,
-                custom_scale=custom_scale,
-                custom_zero_point=custom_zero_point,
-                intx_choose_qparams_algorithm=intx_choose_qparams_algorithm,
-            )
-            if scale_dtype is not None and scale_dtype != weight.dtype:
-                _adjust_scale_dtype_in_intx_unpacked_tensor(
-                    new_weight, weight, scale_dtype
-                )
-
-            return new_weight
-        else:
-            raise ValueError(f"Unsupported packing format: {intx_packing_format}")
+    assert config.version == 2
+    if config.intx_packing_format == IntxPackingFormat.UNPACKED_TO_INT8:
+        if custom_zero_point is not None and custom_zero_point.dtype == torch.int32:
+            custom_zero_point = custom_zero_point.to(torch.int8)
+        new_weight = IntxUnpackedToInt8Tensor.from_hp(
+            weight,
+            block_size,
+            weight_dtype,
+            mapping_type=mapping_type,
+            custom_scale=custom_scale,
+            custom_zero_point=custom_zero_point,
+            intx_choose_qparams_algorithm=intx_choose_qparams_algorithm,
+        )
+        if scale_dtype is not None and scale_dtype != weight.dtype:
+            _adjust_scale_dtype_in_intx_unpacked_tensor(new_weight, weight, scale_dtype)
 
-    # Version 1
-    assert config.intx_choose_qparams_algorithm == IntxChooseQParamsAlgorithm.AFFINE, (
-        "version 1 only supports affine algorithm"
-    )
-    assert config.version == 1
-    warnings.warn(
-        "Config Deprecation: version 1 of IntxWeightOnlyConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2967 for more details"
-    )
-    quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype]
-    weight = to_affine_quantized_intx(
-        input_float=weight,
-        mapping_type=mapping_type,
-        block_size=block_size,
-        target_dtype=torch.int8,
-        quant_min=quant_min,
-        quant_max=quant_max,
-        scale_dtype=scale_dtype,
-        zero_point_dtype=torch.int8,
-        preserve_zero=(mapping_type == MappingType.SYMMETRIC),
-        zero_point_domain=ZeroPointDomain.INT,
-        _layout=layout,
-    )
-    return weight
+        return new_weight
+    else:
+        raise ValueError(f"Unsupported packing format: {intx_packing_format}")
 
 
 @register_quantize_module_handler(IntxWeightOnlyConfig)
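After this hunk, `_intx_weight_only_quantize_tensor` hard-asserts `config.version == 2` and raises `ValueError` for any packing format other than `UNPACKED_TO_INT8`. An end-to-end usage sketch of the surviving path, with an illustrative model and import paths assumed as above:

```python
import torch
import torch.nn as nn
from torchao.quantization import IntxWeightOnlyConfig, quantize_

model = nn.Sequential(nn.Linear(256, 128))

# Only the v2 UNPACKED_TO_INT8 path remains; a config with version=1 would now
# fail the `assert config.version == 2` added in this commit.
quantize_(model, IntxWeightOnlyConfig(weight_dtype=torch.int8))

x = torch.randn(1, 256)
y = model(x)  # Linear now runs with int8 weight-only quantization
```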
