91 changes: 90 additions & 1 deletion examples/models/llama/export_llama_lib.py
@@ -34,18 +34,24 @@
from executorch.extension.llm.export.config.llm_config import LlmConfig
from executorch.extension.llm.export.partitioner_lib import (
get_coreml_partitioner,
get_ethosu_partitioner,
get_mps_partitioner,
get_openvino_partitioner,
get_qnn_partitioner,
get_tosa_partitioner,
get_vgf_partitioner,
get_vulkan_partitioner,
get_xnnpack_partitioner,
)
from executorch.extension.llm.export.quantizer_lib import (
get_coreml_quantizer,
get_ethosu_quantizer,
get_ov_quantizer,
get_pt2e_quantization_params,
get_pt2e_quantizers,
get_qnn_quantizer,
get_tosa_quantizer,
get_vgf_quantizer,
get_vulkan_quantizer,
)
from executorch.util.activation_memory_profiler import generate_memory_trace
@@ -210,6 +216,8 @@ def build_args_parser() -> argparse.ArgumentParser:
"coreml_baseline_8a_c8w",
"coreml_baseline_8a_c4w",
"vulkan_8w",
"tosa_8a8w",
"ethosu_8a8w",
],
help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
)
@@ -788,6 +796,26 @@ def get_quantizer_and_quant_params(llm_config):
llm_config.quantization.pt2e_quantize.value
)
quantizers.append(coreml_quantizer)
if llm_config.backend.tosa.enabled and llm_config.quantization.pt2e_quantize:
tosa_quantizer = get_tosa_quantizer(
llm_config.backend.tosa.version, llm_config.quantization.pt2e_quantize.value
)
quantizers.append(tosa_quantizer)
if llm_config.backend.ethosu.enabled and llm_config.quantization.pt2e_quantize:
ethosu_quantizer = get_ethosu_quantizer(
llm_config.backend.ethosu.target,
llm_config.backend.ethosu.system_config,
llm_config.backend.ethosu.memory_mode,
llm_config.quantization.pt2e_quantize.value,
)
quantizers.append(ethosu_quantizer)
if llm_config.backend.vgf.enabled and llm_config.quantization.pt2e_quantize:
vgf_quantizer = get_vgf_quantizer(
llm_config.backend.vgf.compile_spec,
llm_config.backend.vgf.compiler_flags,
llm_config.quantization.pt2e_quantize.value,
)
quantizers.append(vgf_quantizer)
if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize:
assert (
len(quantizers) == 0
@@ -930,6 +958,47 @@ def _to_edge_and_lower_llama_openvino(
return builder.to_executorch(passes=additional_passes)


def _to_edge_and_lower_llama_arm(
builder_exported,
modelname,
quantizers,
additional_passes,
llm_config: LlmConfig,
verbose: bool = False,
) -> LLMEdgeManager:

logging.info("Lowering model using Arm partitioner (TOSA/Ethos-U/VGF)")

partitioners = []
if llm_config.backend.ethosu.enabled:
partitioners.append(
get_ethosu_partitioner(
llm_config.backend.ethosu.target,
)
)
modelname = f"ethosu_{modelname}"
elif llm_config.backend.vgf.enabled:
partitioners.append(
get_vgf_partitioner(
llm_config.backend.vgf.compile_spec,
llm_config.backend.vgf.compiler_flags,
)
)
modelname = f"vgf_{modelname}"
elif llm_config.backend.tosa.enabled:
partitioners.append(get_tosa_partitioner(llm_config.backend.tosa.version))
modelname = f"tosa_{modelname}"

builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
partitioners
)

if verbose:
print_delegation_info(builder.edge_manager.exported_program().graph_module)

return builder.to_executorch(passes=additional_passes)


def _to_edge_and_lower_llama( # noqa: C901
builder_exported,
modelname,
@@ -1119,7 +1188,14 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]

# export_to_edge
builder_exported = _prepare_for_llama_export(llm_config).export()
builder_manager = _prepare_for_llama_export(llm_config)
if (
llm_config.backend.tosa.enabled
or llm_config.backend.vgf.enabled
or llm_config.backend.ethosu.enabled
):
builder_manager.skip_dim_order = False
builder_exported = builder_manager.export()
builder_exported.run_canonical_optimizations()
modelname = builder_exported.modelname

@@ -1162,6 +1238,19 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
openvino_device=llm_config.backend.openvino.device,
verbose=llm_config.debug.verbose,
)
elif (
llm_config.backend.tosa.enabled
or llm_config.backend.ethosu.enabled
or llm_config.backend.vgf.enabled
):
builder = _to_edge_and_lower_llama_arm(
builder_exported,
modelname,
quantizers,
additional_passes,
llm_config,
verbose=llm_config.debug.verbose,
)
else:
builder = _to_edge_and_lower_llama(
builder_exported,
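A minimal sketch of how the new Arm routing in `_export_llama` can be exercised. Only names that appear in this diff are used; the usual model and checkpoint fields of `LlmConfig` are omitted and would need to be filled in for a real export:

```python
from executorch.examples.models.llama.export_llama_lib import _export_llama
from executorch.extension.llm.export.config.llm_config import LlmConfig, Pt2eQuantize

llm_config = LlmConfig()
# Enabling any of tosa / ethosu / vgf routes lowering through
# _to_edge_and_lower_llama_arm and clears skip_dim_order before export.
llm_config.backend.tosa.enabled = True
llm_config.backend.tosa.version = "TOSA-1.0+INT"  # default from TosaConfig
llm_config.quantization.pt2e_quantize = Pt2eQuantize.tosa_8a8w

# Returns an LLMEdgeManager lowered via the TOSA partitioner and quantizer.
builder = _export_llama(llm_config)
```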
53 changes: 52 additions & 1 deletion examples/models/llama/tests/test_export_llama_lib.py
@@ -1,17 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from executorch.backends.arm.quantizer.arm_quantizer import (
EthosUQuantizer,
TOSAQuantizer,
VgfQuantizer,
)

from executorch.devtools.backend_debug import get_delegation_info
from executorch.examples.models.llama.export_llama_lib import (
_export_llama,
build_args_parser,
get_quantizer_and_quant_params,
)
from executorch.extension.llm.export.config.llm_config import LlmConfig
from executorch.extension.llm.export.config.llm_config import LlmConfig, Pt2eQuantize

UNWANTED_OPS = [
"aten_permute_copy_default",
@@ -48,3 +56,46 @@ def test_has_expected_ops_and_op_counts(self):

for op, _op_info in delegation_info.delegation_by_operator.items():
self.assertTrue(op not in UNWANTED_OPS)

def test_get_quantizer_and_quant_params_returns_tosa_quantizer(self):
llm_config = LlmConfig()
llm_config.backend.tosa.enabled = True
llm_config.quantization.pt2e_quantize = Pt2eQuantize.tosa_8a8w

pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
llm_config
)

self.assertIsNone(pt2e_quant_params)
self.assertIsNone(quant_dtype)
self.assertEqual(len(quantizers), 1)
self.assertIsInstance(quantizers[0], TOSAQuantizer)

def test_get_quantizer_and_quant_params_returns_ethosu_quantizer(self):
llm_config = LlmConfig()
llm_config.backend.ethosu.enabled = True
llm_config.quantization.pt2e_quantize = Pt2eQuantize.ethosu_8a8w

pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
llm_config
)

self.assertIsNone(pt2e_quant_params)
self.assertIsNone(quant_dtype)
self.assertEqual(len(quantizers), 1)
self.assertIsInstance(quantizers[0], EthosUQuantizer)

def test_get_quantizer_and_quant_params_returns_vgf_quantizer(self):
llm_config = LlmConfig()
llm_config.backend.vgf.enabled = True
llm_config.backend.vgf.compile_spec = "TOSA-1.0+INT"
llm_config.quantization.pt2e_quantize = Pt2eQuantize.vgf_8a8w

pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
llm_config
)

self.assertIsNone(pt2e_quant_params)
self.assertIsNone(quant_dtype)
self.assertEqual(len(quantizers), 1)
self.assertIsInstance(quantizers[0], VgfQuantizer)
5 changes: 4 additions & 1 deletion extension/llm/export/builder.py
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -96,6 +97,7 @@ def __init__(
dynamic_shapes: Optional[Any] = None,
save_exported_program: bool = False,
generate_etrecord: bool = False,
skip_dim_order: bool = True,
):
# Store necessary constructor arguments.
self.model = model
@@ -118,6 +120,7 @@ def __init__(
self.dynamic_shapes = dynamic_shapes
self.save_exported_program = save_exported_program
self.generate_etrecord = generate_etrecord
self.skip_dim_order = skip_dim_order

# Note: treat this as the source of truth for the result of
# torch.export'ing a model. If the overall ExportedProgram is needed,
@@ -197,7 +200,7 @@ def _get_dynamic_shape(self) -> Any:
def _get_edge_config(self) -> EdgeCompileConfig:
edge_config = EdgeCompileConfig(
_check_ir_validity=False,
_skip_dim_order=True,
_skip_dim_order=self.skip_dim_order,
)
return edge_config

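For illustration, a sketch of the edge config that `_get_edge_config` now produces once the Arm export path sets `skip_dim_order = False` on the manager. The `executorch.exir` import is an assumption for this standalone snippet and is not part of this diff:

```python
from executorch.exir import EdgeCompileConfig

# Equivalent of LLMEdgeManager._get_edge_config() after _export_llama sets
# skip_dim_order = False for the TOSA / Ethos-U / VGF backends.
edge_config = EdgeCompileConfig(
    _check_ir_validity=False,
    _skip_dim_order=False,  # keep dim-order ops, which the Arm lowering expects
)
```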
40 changes: 40 additions & 0 deletions extension/llm/export/config/llm_config.py
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -288,6 +289,9 @@ class Pt2eQuantize(str, Enum):
coreml_baseline_8a_c8w = "coreml_baseline_8a_c8w"
coreml_baseline_8a_c4w = "coreml_baseline_8a_c4w"
vulkan_8w = "vulkan_8w"
tosa_8a8w = "tosa_8a8w"
ethosu_8a8w = "ethosu_8a8w"
vgf_8a8w = "vgf_8a8w"


class SpinQuant(str, Enum):
@@ -474,6 +478,39 @@ class TorchAOKernelsConfig:
use_torchao_kernels_tied_embedding: bool = False


@dataclass
class TosaConfig:
"""
Configures the TOSA backend.
"""

enabled: bool = False
version: str = "TOSA-1.0+INT"


@dataclass
class EthosUConfig:
"""
Configures the Ethos-U backend.
"""

enabled: bool = False
target: str = "ethos-u85-128" # Default target, can be overridden.
memory_mode: str = "default"
system_config: str = "default"


@dataclass
class VgfConfig:
"""
Configures the VGF backend.
"""

enabled: bool = False
compile_spec: Optional[str] = "TOSA-1.0+INT"
compiler_flags: List[str] = field(default_factory=list)


@dataclass
class BackendConfig:
"""
@@ -488,6 +525,9 @@
mps: MPSConfig = field(default_factory=MPSConfig)
openvino: OpenvinoConfig = field(default_factory=OpenvinoConfig)
torchao: TorchAOKernelsConfig = field(default_factory=TorchAOKernelsConfig)
tosa: TosaConfig = field(default_factory=TosaConfig)
ethosu: EthosUConfig = field(default_factory=EthosUConfig)
vgf: VgfConfig = field(default_factory=VgfConfig)


################################################################################
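As a hedged usage sketch, the new backend sections are set on an `LlmConfig` like any other backend config; the values below are the dataclass defaults and the compile spec used in the VGF test above, not new values introduced here:

```python
from executorch.extension.llm.export.config.llm_config import LlmConfig, Pt2eQuantize

llm_config = LlmConfig()
llm_config.backend.vgf.enabled = True
llm_config.backend.vgf.compile_spec = "TOSA-1.0+INT"  # default from VgfConfig
llm_config.backend.vgf.compiler_flags = []            # passed through to VgfCompileSpec
llm_config.quantization.pt2e_quantize = Pt2eQuantize.vgf_8a8w
```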
32 changes: 31 additions & 1 deletion extension/llm/export/partitioner_lib.py
@@ -1,10 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional
from typing import List, Optional


def get_xnnpack_partitioner(dynamic_quant_only_partitioner: bool = True):
@@ -236,3 +237,32 @@ def get_qnn_partitioner(
# TODO: if deprecated legacy export, skip_mutable_buffer can be set False
skip_mutable_buffer=True,
)


def get_tosa_partitioner(version: str):
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
from executorch.backends.arm.tosa.partitioner import TOSAPartitioner

compile_spec = TosaCompileSpec(version)

return TOSAPartitioner(compile_spec)


def get_ethosu_partitioner(target: str):
from executorch.backends.arm.ethosu.compile_spec import EthosUCompileSpec
from executorch.backends.arm.ethosu.partitioner import EthosUPartitioner

compile_spec = EthosUCompileSpec(target)

return EthosUPartitioner(compile_spec)


def get_vgf_partitioner(
compile_spec: Optional[str], compiler_flags: Optional[List[str]]
):
from executorch.backends.arm.vgf.compile_spec import VgfCompileSpec
from executorch.backends.arm.vgf.partitioner import VgfPartitioner

compile_spec_obj = VgfCompileSpec(compile_spec, compiler_flags)

return VgfPartitioner(compile_spec_obj)
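A short sketch of the new helpers in isolation. The argument values are the defaults from the corresponding config dataclasses in this diff, so this only illustrates the call shapes:

```python
from executorch.extension.llm.export.partitioner_lib import (
    get_ethosu_partitioner,
    get_tosa_partitioner,
    get_vgf_partitioner,
)

# Each helper builds the backend-specific compile spec and wraps it in a partitioner.
tosa_partitioner = get_tosa_partitioner("TOSA-1.0+INT")
ethosu_partitioner = get_ethosu_partitioner("ethos-u85-128")
vgf_partitioner = get_vgf_partitioner("TOSA-1.0+INT", [])
```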