diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 219cc71ded1..cdf843a5e15 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -34,18 +34,24 @@ from executorch.extension.llm.export.config.llm_config import LlmConfig
 
 from executorch.extension.llm.export.partitioner_lib import (
     get_coreml_partitioner,
+    get_ethosu_partitioner,
     get_mps_partitioner,
     get_openvino_partitioner,
     get_qnn_partitioner,
+    get_tosa_partitioner,
+    get_vgf_partitioner,
     get_vulkan_partitioner,
     get_xnnpack_partitioner,
 )
 from executorch.extension.llm.export.quantizer_lib import (
     get_coreml_quantizer,
+    get_ethosu_quantizer,
     get_ov_quantizer,
     get_pt2e_quantization_params,
     get_pt2e_quantizers,
     get_qnn_quantizer,
+    get_tosa_quantizer,
+    get_vgf_quantizer,
     get_vulkan_quantizer,
 )
 from executorch.util.activation_memory_profiler import generate_memory_trace
@@ -210,6 +216,8 @@ def build_args_parser() -> argparse.ArgumentParser:
             "coreml_baseline_8a_c8w",
             "coreml_baseline_8a_c4w",
             "vulkan_8w",
+            "tosa_8a8w",
+            "ethosu_8a8w",
         ],
         help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
     )
@@ -788,6 +796,26 @@ def get_quantizer_and_quant_params(llm_config):
             llm_config.quantization.pt2e_quantize.value
         )
         quantizers.append(coreml_quantizer)
+    if llm_config.backend.tosa.enabled and llm_config.quantization.pt2e_quantize:
+        tosa_quantizer = get_tosa_quantizer(
+            llm_config.backend.tosa.version, llm_config.quantization.pt2e_quantize.value
+        )
+        quantizers.append(tosa_quantizer)
+    if llm_config.backend.ethosu.enabled and llm_config.quantization.pt2e_quantize:
+        ethosu_quantizer = get_ethosu_quantizer(
+            llm_config.backend.ethosu.target,
+            llm_config.backend.ethosu.system_config,
+            llm_config.backend.ethosu.memory_mode,
+            llm_config.quantization.pt2e_quantize.value,
+        )
+        quantizers.append(ethosu_quantizer)
+    if llm_config.backend.vgf.enabled and llm_config.quantization.pt2e_quantize:
+        vgf_quantizer = get_vgf_quantizer(
+            llm_config.backend.vgf.compile_spec,
+            llm_config.backend.vgf.compiler_flags,
+            llm_config.quantization.pt2e_quantize.value,
+        )
+        quantizers.append(vgf_quantizer)
     if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize:
         assert (
             len(quantizers) == 0
@@ -930,6 +958,47 @@ def _to_edge_and_lower_llama_openvino(
     return builder.to_executorch(passes=additional_passes)
 
 
+def _to_edge_and_lower_llama_arm(
+    builder_exported,
+    modelname,
+    quantizers,
+    additional_passes,
+    llm_config: LlmConfig,
+    verbose: bool = False,
+) -> LLMEdgeManager:
+
+    logging.info("Lowering model using an Arm (TOSA/Ethos-U/VGF) partitioner")
+
+    partitioners = []
+    if llm_config.backend.ethosu.enabled:
+        partitioners.append(
+            get_ethosu_partitioner(
+                llm_config.backend.ethosu.target,
+            )
+        )
+        modelname = f"ethosu_{modelname}"
+    elif llm_config.backend.vgf.enabled:
+        partitioners.append(
+            get_vgf_partitioner(
+                llm_config.backend.vgf.compile_spec,
+                llm_config.backend.vgf.compiler_flags,
+            )
+        )
+        modelname = f"vgf_{modelname}"
+    elif llm_config.backend.tosa.enabled:
+        partitioners.append(get_tosa_partitioner(llm_config.backend.tosa.version))
+        modelname = f"tosa_{modelname}"
+
+    builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
+        partitioners
+    )
+
+    if verbose:
+        print_delegation_info(builder.edge_manager.exported_program().graph_module)
+
+    return builder.to_executorch(passes=additional_passes)
+
+
 def _to_edge_and_lower_llama(  # noqa: C901
     builder_exported,
     modelname,
@@ -1119,7 +1188,14 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
     additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
 
     # export_to_edge
-    builder_exported = _prepare_for_llama_export(llm_config).export()
+    builder_manager = _prepare_for_llama_export(llm_config)
+    if (
+        llm_config.backend.tosa.enabled
+        or llm_config.backend.vgf.enabled
+        or llm_config.backend.ethosu.enabled
+    ):
+        builder_manager.skip_dim_order = False
+    builder_exported = builder_manager.export()
     builder_exported.run_canonical_optimizations()
 
     modelname = builder_exported.modelname
@@ -1162,6 +1238,19 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
             openvino_device=llm_config.backend.openvino.device,
             verbose=llm_config.debug.verbose,
         )
+    elif (
+        llm_config.backend.tosa.enabled
+        or llm_config.backend.ethosu.enabled
+        or llm_config.backend.vgf.enabled
+    ):
+        builder = _to_edge_and_lower_llama_arm(
+            builder_exported,
+            modelname,
+            quantizers,
+            additional_passes,
+            llm_config,
+            verbose=llm_config.debug.verbose,
+        )
     else:
         builder = _to_edge_and_lower_llama(
            builder_exported,
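
A minimal sketch of reaching the new Arm lowering path, assuming the remaining llm_config fields (model class, checkpoint, tokenizer, export options) are populated the same way as for the other backends:

from executorch.examples.models.llama.export_llama_lib import _export_llama
from executorch.extension.llm.export.config.llm_config import LlmConfig, Pt2eQuantize

llm_config = LlmConfig()
# ... fill in llm_config.base / llm_config.model / llm_config.export as usual ...
llm_config.backend.ethosu.enabled = True
llm_config.backend.ethosu.target = "ethos-u85-128"
llm_config.quantization.pt2e_quantize = Pt2eQuantize.ethosu_8a8w

# Routes through _to_edge_and_lower_llama_arm, keeps dim-order ops
# (skip_dim_order = False), and prefixes the artifact name with "ethosu_".
builder = _export_llama(llm_config)
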
diff --git a/examples/models/llama/tests/test_export_llama_lib.py b/examples/models/llama/tests/test_export_llama_lib.py
index 172517207de..a0c814a8b7f 100644
--- a/examples/models/llama/tests/test_export_llama_lib.py
+++ b/examples/models/llama/tests/test_export_llama_lib.py
@@ -1,17 +1,25 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
 import unittest
 
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    EthosUQuantizer,
+    TOSAQuantizer,
+    VgfQuantizer,
+)
+
 from executorch.devtools.backend_debug import get_delegation_info
 from executorch.examples.models.llama.export_llama_lib import (
     _export_llama,
     build_args_parser,
+    get_quantizer_and_quant_params,
 )
-from executorch.extension.llm.export.config.llm_config import LlmConfig
+from executorch.extension.llm.export.config.llm_config import LlmConfig, Pt2eQuantize
 
 UNWANTED_OPS = [
     "aten_permute_copy_default",
@@ -48,3 +56,46 @@ def test_has_expected_ops_and_op_counts(self):
 
         for op, _op_info in delegation_info.delegation_by_operator.items():
             self.assertTrue(op not in UNWANTED_OPS)
+
+    def test_get_quantizer_and_quant_params_returns_tosa_quantizer(self):
+        llm_config = LlmConfig()
+        llm_config.backend.tosa.enabled = True
+        llm_config.quantization.pt2e_quantize = Pt2eQuantize.tosa_8a8w
+
+        pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
+            llm_config
+        )
+
+        self.assertIsNone(pt2e_quant_params)
+        self.assertIsNone(quant_dtype)
+        self.assertEqual(len(quantizers), 1)
+        self.assertIsInstance(quantizers[0], TOSAQuantizer)
+
+    def test_get_quantizer_and_quant_params_returns_ethosu_quantizer(self):
+        llm_config = LlmConfig()
+        llm_config.backend.ethosu.enabled = True
+        llm_config.quantization.pt2e_quantize = Pt2eQuantize.ethosu_8a8w
+
+        pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
+            llm_config
+        )
+
+        self.assertIsNone(pt2e_quant_params)
+        self.assertIsNone(quant_dtype)
+        self.assertEqual(len(quantizers), 1)
+        self.assertIsInstance(quantizers[0], EthosUQuantizer)
+
+    def test_get_quantizer_and_quant_params_returns_vgf_quantizer(self):
+        llm_config = LlmConfig()
+        llm_config.backend.vgf.enabled = True
+        llm_config.backend.vgf.compile_spec = "TOSA-1.0+INT"
+        llm_config.quantization.pt2e_quantize = Pt2eQuantize.vgf_8a8w
+
+        pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
+            llm_config
+        )
+
+        self.assertIsNone(pt2e_quant_params)
+        self.assertIsNone(quant_dtype)
+        self.assertEqual(len(quantizers), 1)
+        self.assertIsInstance(quantizers[0], VgfQuantizer)
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index ae15dded91d..5b7b9bc4c74 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -96,6 +97,7 @@ def __init__(
         dynamic_shapes: Optional[Any] = None,
         save_exported_program: bool = False,
         generate_etrecord: bool = False,
+        skip_dim_order: bool = True,
     ):
         # Store necessary constructor arguments.
         self.model = model
@@ -118,6 +120,7 @@ def __init__(
         self.dynamic_shapes = dynamic_shapes
         self.save_exported_program = save_exported_program
         self.generate_etrecord = generate_etrecord
+        self.skip_dim_order = skip_dim_order
 
         # Note: treat this as the source of truth for the result of
         # torch.export'ing a model. If the overall ExportedProgram is needed,
@@ -197,7 +200,7 @@ def _get_dynamic_shape(self) -> Any:
     def _get_edge_config(self) -> EdgeCompileConfig:
         edge_config = EdgeCompileConfig(
             _check_ir_validity=False,
-            _skip_dim_order=True,
+            _skip_dim_order=self.skip_dim_order,
         )
         return edge_config
 
diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py
index b40fad88a9c..4ed00391c03 100644
--- a/extension/llm/export/config/llm_config.py
+++ b/extension/llm/export/config/llm_config.py
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -288,6 +289,9 @@ class Pt2eQuantize(str, Enum):
     coreml_baseline_8a_c8w = "coreml_baseline_8a_c8w"
     coreml_baseline_8a_c4w = "coreml_baseline_8a_c4w"
     vulkan_8w = "vulkan_8w"
+    tosa_8a8w = "tosa_8a8w"
+    ethosu_8a8w = "ethosu_8a8w"
+    vgf_8a8w = "vgf_8a8w"
 
 
 class SpinQuant(str, Enum):
@@ -474,6 +478,39 @@ class TorchAOKernelsConfig:
     use_torchao_kernels_tied_embedding: bool = False
 
 
+@dataclass
+class TosaConfig:
+    """
+    Configures the TOSA backend.
+    """
+
+    enabled: bool = False
+    version: str = "TOSA-1.0+INT"
+
+
+@dataclass
+class EthosUConfig:
+    """
+    Configures the Ethos-U backend.
+    """
+
+    enabled: bool = False
+    target: str = "ethos-u85-128"  # Default target, can be overridden.
+    memory_mode: str = "default"
+    system_config: str = "default"
+
+
+@dataclass
+class VgfConfig:
+    """
+    Configures the VGF backend.
+    """
+
+    enabled: bool = False
+    compile_spec: Optional[str] = "TOSA-1.0+INT"
+    compiler_flags: List[str] = field(default_factory=list)
+
+
 @dataclass
 class BackendConfig:
     """
@@ -488,6 +525,9 @@ class BackendConfig:
     mps: MPSConfig = field(default_factory=MPSConfig)
     openvino: OpenvinoConfig = field(default_factory=OpenvinoConfig)
     torchao: TorchAOKernelsConfig = field(default_factory=TorchAOKernelsConfig)
+    tosa: TosaConfig = field(default_factory=TosaConfig)
+    ethosu: EthosUConfig = field(default_factory=EthosUConfig)
+    vgf: VgfConfig = field(default_factory=VgfConfig)
 
 
 ################################################################################
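
A small sketch of the configuration surface these dataclasses add, with the defaults declared above; only one Arm backend is expected to be enabled at a time (the partitioner selection earlier checks ethosu, then vgf, then tosa):

from executorch.extension.llm.export.config.llm_config import (
    EthosUConfig,
    LlmConfig,
    Pt2eQuantize,
    TosaConfig,
    VgfConfig,
)

cfg = LlmConfig()
# Defaults exactly as declared in the dataclasses above.
assert cfg.backend.tosa == TosaConfig(enabled=False, version="TOSA-1.0+INT")
assert cfg.backend.vgf == VgfConfig(
    enabled=False, compile_spec="TOSA-1.0+INT", compiler_flags=[]
)
assert cfg.backend.ethosu == EthosUConfig(
    enabled=False, target="ethos-u85-128", memory_mode="default", system_config="default"
)

# Pair a backend with its matching PT2E quantization scheme.
cfg.backend.tosa.enabled = True
cfg.quantization.pt2e_quantize = Pt2eQuantize.tosa_8a8w
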
diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py
index 03ac2bd91e4..e28557112f8 100644
--- a/extension/llm/export/partitioner_lib.py
+++ b/extension/llm/export/partitioner_lib.py
@@ -1,10 +1,11 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Optional
+from typing import List, Optional
 
 
 def get_xnnpack_partitioner(dynamic_quant_only_partitioner: bool = True):
@@ -236,3 +237,32 @@ def get_qnn_partitioner(
         # TODO: if deprecated legacy export, skip_mutable_buffer can be set False
         skip_mutable_buffer=True,
     )
+
+
+def get_tosa_partitioner(version: str):
+    from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
+    from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
+
+    compile_spec = TosaCompileSpec(version)
+
+    return TOSAPartitioner(compile_spec)
+
+
+def get_ethosu_partitioner(target: str):
+    from executorch.backends.arm.ethosu.compile_spec import EthosUCompileSpec
+    from executorch.backends.arm.ethosu.partitioner import EthosUPartitioner
+
+    compile_spec = EthosUCompileSpec(target)
+
+    return EthosUPartitioner(compile_spec)
+
+
+def get_vgf_partitioner(
+    compile_spec: Optional[str], compiler_flags: Optional[List[str]]
+):
+    from executorch.backends.arm.vgf.compile_spec import VgfCompileSpec
+    from executorch.backends.arm.vgf.partitioner import VgfPartitioner
+
+    compile_spec_obj = VgfCompileSpec(compile_spec, compiler_flags)
+
+    return VgfPartitioner(compile_spec_obj)
diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py
index 592a6666dfa..80020d62a66 100644
--- a/extension/llm/export/quantizer_lib.py
+++ b/extension/llm/export/quantizer_lib.py
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -320,3 +321,66 @@ def get_vulkan_quantizer(pt2e_quantize: str):
     quantizer = VulkanQuantizer().set_global(config)
 
     return quantizer
+
+
+def get_tosa_quantizer(version: str, pt2e_quantize: str):
+    from executorch.backends.arm.quantizer.arm_quantizer import (
+        get_symmetric_quantization_config,
+        TOSAQuantizer,
+    )
+    from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
+
+    compile_spec = TosaCompileSpec(version)
+
+    quantizer = TOSAQuantizer(compile_spec)
+
+    if pt2e_quantize == "tosa_8a8w":
+        quantizer.set_global(get_symmetric_quantization_config())
+    else:
+        raise ValueError(f"Unsupported quantizer specification {pt2e_quantize}")
+
+    return quantizer
+
+
+def get_ethosu_quantizer(
+    target: str, system_config: str, memory_mode: str, pt2e_quantize: str
+):
+    from executorch.backends.arm.ethosu.compile_spec import EthosUCompileSpec
+    from executorch.backends.arm.quantizer.arm_quantizer import (
+        EthosUQuantizer,
+        get_symmetric_quantization_config,
+    )
+
+    compile_spec = EthosUCompileSpec(target, system_config, memory_mode)
+
+    quantizer = EthosUQuantizer(compile_spec)
+
+    if pt2e_quantize == "ethosu_8a8w":
+        quantizer.set_global(get_symmetric_quantization_config())
+    else:
+        raise ValueError(f"Unsupported quantizer specification {pt2e_quantize}")
+
+    return quantizer
+
+
+def get_vgf_quantizer(
+    compile_spec: Optional[str],
+    compiler_flags: Optional[List[str]],
+    pt2e_quantize: str,
+):
+    from executorch.backends.arm.quantizer.arm_quantizer import (
+        get_symmetric_quantization_config,
+        VgfQuantizer,
+    )
+    from executorch.backends.arm.vgf.compile_spec import VgfCompileSpec
+
+    compile_spec_obj = VgfCompileSpec(compile_spec, compiler_flags)
+
+    quantizer = VgfQuantizer(compile_spec_obj)
+
+    if pt2e_quantize == "vgf_8a8w":
+        quantizer.set_global(get_symmetric_quantization_config())
+    else:
+        raise ValueError(f"Unsupported quantizer specification {pt2e_quantize}")
+
+    return quantizer