
Commit 09b18ee

Update fp8 conv3d to use mslk (#3530)

Summary: fbgemm_gpu_genai has been renamed to MSLK (https://github.com/meta-pytorch/MSLK/tree/main), so update the fp8 conv dependency to mslk for now (we can migrate the other ops in the future).

Next: we'll remove the permute and test the new functionality added in the fp8 conv op.

Test Plan:
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_fp8_conv_variants

Reviewers:

Subscribers:

Tasks:

Tags:

1 parent d6bbb67 commit 09b18ee

File tree

3 files changed: +19 -9 lines changed

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 3 additions & 2 deletions
@@ -31,6 +31,7 @@
 from torchao.testing.utils import TorchAOIntegrationTestCase
 from torchao.utils import (
     _is_fbgemm_gpu_genai_available,
+    _is_mslk_available,
     is_sm_at_least_89,
     is_sm_at_least_90,
     is_sm_at_least_100,
@@ -329,8 +330,8 @@ def _test_fp8_matmul_model(
         not is_sm_at_least_100(), "Requires GPU with compute capability >= 10.0"
     )
     @unittest.skipIf(
-        not _is_fbgemm_gpu_genai_available(),
-        "Requires fbgemm_gpu_genai to be installed",
+        not _is_mslk_available(),
+        "Requires mslk to be installed",
     )
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("compile", [True, False])

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 9 additions & 7 deletions
@@ -41,6 +41,7 @@
 from torchao.utils import (
     TorchAOBaseTensor,
     _is_fbgemm_gpu_genai_available,
+    _is_mslk_available,
     fill_defaults,
     is_sm_at_least_90,
     is_sm_at_least_100,
@@ -506,9 +507,7 @@ def _quantize_and_scaled_conv3d(
     assert input_tensor.dim() == 5 and weight_tensor.dim() == 5, (
         "Only support 3D conv currently"
     )
-    assert _is_fbgemm_gpu_genai_available(), (
-        "quantized fp8 conv3d requires fbgemm_gpu_genai to be available"
-    )
+    assert _is_mslk_available(), "quantized fp8 conv3d requires mslk to be available"
     act_quant_kwargs = weight_tensor.act_quant_kwargs
     # quantize activation, if `act_quant_kwargs` is specified
     if act_quant_kwargs is not None:
@@ -519,8 +518,8 @@ def _quantize_and_scaled_conv3d(
     if isinstance(input_tensor, Float8Tensor):
         kernel_choice = None
         if weight_tensor.kernel_preference == KernelPreference.AUTO:
-            if _is_fbgemm_gpu_genai_available() and is_sm_at_least_100():
-                kernel_choice = "fbgemm"
+            if _is_mslk_available() and is_sm_at_least_100():
+                kernel_choice = "mslk"
         else:
             raise NotImplementedError(
                 f"No available kernel choice for {weight_tensor.kernel_preference}"
@@ -532,7 +531,7 @@ def _quantize_and_scaled_conv3d(
                 f"No available kernel choice for {weight_tensor.kernel_preference}"
             )

-    assert kernel_choice == "fbgemm", "Only fbgemm kernel choice is supported currently"
+    assert kernel_choice == "mslk", "Only mslk kernel choice is supported currently"
     input_qdata = input_tensor.qdata
     weight_qdata = weight_tensor.qdata

@@ -560,7 +559,10 @@ def _quantize_and_scaled_conv3d(

     input_scale = input_tensor.scale
     weight_scale = weight_tensor.scale
-    output = torch.ops.fbgemm.f8f8bf16_conv(
+
+    import mslk.conv  # noqa: F401
+
+    output = torch.ops.mslk.f8f8bf16_conv(
         input_qdata,
         weight_qdata,
         input_scale * weight_scale,
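Note: `import mslk.conv  # noqa: F401` is a registration import: loading the module is what makes the custom op visible under the `torch.ops.mslk` namespace, and the `noqa` silences the unused-import warning. A hedged sketch of the call-site pattern follows; the wrapper name is hypothetical, and since the hunk above truncates the argument list after the combined scale, the remaining conv arguments passed by the real call site are elided here.

import torch


def _fp8_conv3d_via_mslk(input_qdata, weight_qdata, input_scale, weight_scale):
    # Importing mslk.conv registers torch.ops.mslk.f8f8bf16_conv with the
    # PyTorch dispatcher as a side effect; the module name itself is unused.
    import mslk.conv  # noqa: F401

    # The two scales are folded into a single rescaling factor before the
    # call, so the kernel only needs one scale argument.
    return torch.ops.mslk.f8f8bf16_conv(
        input_qdata,
        weight_qdata,
        input_scale * weight_scale,
        # ...the real call site passes further conv parameters not shown in
        # the hunk above.
    )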

torchao/utils.py

Lines changed: 7 additions & 0 deletions
@@ -1146,6 +1146,13 @@ def _is_fbgemm_gpu_genai_available():
     return True


+def _is_mslk_available():
+    if is_fbcode():
+        return True
+
+    return importlib.util.find_spec("mslk") is not None
+
+
 class DummyModule(torch.nn.Module):
     """This is used because the TorchAO quantization functions tend to operate on modules so to apply the transform to a tensor, we can load a
     DummyModule with the target tensor and then apply the transformation to the module and then extract the transformed tensor.
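Note: `importlib.util.find_spec` only asks the import machinery whether the package can be found; it does not execute the package, so the probe is cheap and side-effect free, which is why it is preferred over a try/except import for a capability check (the `is_fbcode()` branch simply assumes the dependency is always present in Meta-internal builds). A small stdlib-only illustration of the distinction:

import importlib
import importlib.util

# Probe: no code from the package runs, safe even for broken installs.
if importlib.util.find_spec("mslk") is not None:
    # Import: only now does the package (and any op registration it
    # performs at import time) actually execute.
    mslk = importlib.import_module("mslk")
else:
    mslk = None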
