diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh
index 7dd16f856cd..5b46e62067f 100755
--- a/.ci/docker/build.sh
+++ b/.ci/docker/build.sh
@@ -23,9 +23,14 @@ MINICONDA_VERSION=23.10.0-1
 BUCK2_VERSION=$(cat ci_commit_pins/buck2.txt)
 
 case "${IMAGE_NAME}" in
-  executorch-ubuntu-22.04-gcc9)
+  executorch-ubuntu-22.04-gcc11)
+    LINTRUNNER=""
+    GCC_VERSION=11
+    ;;
+  executorch-ubuntu-22.04-gcc9-nopytorch)
     LINTRUNNER=""
     GCC_VERSION=9
+    SKIP_PYTORCH=yes
     ;;
   executorch-ubuntu-22.04-clang12)
     LINTRUNNER=""
@@ -54,13 +59,13 @@ case "${IMAGE_NAME}" in
   executorch-ubuntu-22.04-mediatek-sdk)
     MEDIATEK_SDK=yes
     CLANG_VERSION=12
-    ANDROID_NDK_VERSION=r27b
+    ANDROID_NDK_VERSION=r28c
     ;;
   executorch-ubuntu-22.04-clang12-android)
     LINTRUNNER=""
     CLANG_VERSION=12
     # From https://developer.android.com/ndk/downloads
-    ANDROID_NDK_VERSION=r27b
+    ANDROID_NDK_VERSION=r28c
     ;;
   *)
     echo "Invalid image name ${IMAGE_NAME}"
@@ -95,6 +100,7 @@ docker build \
   --build-arg "QNN_SDK=${QNN_SDK:-}" \
   --build-arg "MEDIATEK_SDK=${MEDIATEK_SDK:-}" \
   --build-arg "ANDROID_NDK_VERSION=${ANDROID_NDK_VERSION:-}" \
+  --build-arg "SKIP_PYTORCH=${SKIP_PYTORCH:-}" \
   -f "${OS}"/Dockerfile \
   "$@" \
   .
diff --git a/.ci/docker/ci_commit_pins/optimum-executorch.txt b/.ci/docker/ci_commit_pins/optimum-executorch.txt
index ef3282ba6cc..156ff2f3c82 100644
--- a/.ci/docker/ci_commit_pins/optimum-executorch.txt
+++ b/.ci/docker/ci_commit_pins/optimum-executorch.txt
@@ -1 +1 @@
-40b02a2dc61bbf901a2df91719f47c98d65368ec
+0123293118efb08ac4ffc4fefe9d330201465c93
diff --git a/.ci/docker/ci_commit_pins/pytorch.txt b/.ci/docker/ci_commit_pins/pytorch.txt
index 8c9330d6f2c..f4ec226f512 100644
--- a/.ci/docker/ci_commit_pins/pytorch.txt
+++ b/.ci/docker/ci_commit_pins/pytorch.txt
@@ -1 +1 @@
-4d4abec80f03cd8fdefe1d9cb3a60d3690cd777e
+7a064ed3eafa43f17412d434b395240c727b3000
diff --git a/.ci/docker/common/install_arm.sh b/.ci/docker/common/install_arm.sh
new file mode 100644
index 00000000000..dec8a1693ee
--- /dev/null
+++ b/.ci/docker/common/install_arm.sh
@@ -0,0 +1,18 @@
+#!/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
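+#
+# Installs the Vulkan runtime packages (mesa-vulkan-drivers, libvulkan1) used
+# by the Arm SDK CI images; invoked from the Ubuntu Dockerfile when ARM_SDK is
+# set.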
+ +set -ex + +install_arm_prerequiresites() { + apt-get update -y + apt-get install -y --no-install-recommends \ + mesa-vulkan-drivers libvulkan1 + rm -rf /var/lib/apt/lists/* +} + +install_arm_prerequiresites diff --git a/.ci/docker/common/install_pytorch.sh b/.ci/docker/common/install_pytorch.sh index 4bf33348681..9809b6a8e3c 100755 --- a/.ci/docker/common/install_pytorch.sh +++ b/.ci/docker/common/install_pytorch.sh @@ -12,8 +12,8 @@ source "$(dirname "${BASH_SOURCE[0]}")/utils.sh" install_domains() { echo "Install torchvision and torchaudio" - pip_install --no-use-pep517 --user "git+https://github.com/pytorch/audio.git@${TORCHAUDIO_VERSION}" - pip_install --no-use-pep517 --user "git+https://github.com/pytorch/vision.git@${TORCHVISION_VERSION}" + pip_install --no-build-isolation --user "git+https://github.com/pytorch/audio.git@${TORCHAUDIO_VERSION}" + pip_install --no-build-isolation --user "git+https://github.com/pytorch/vision.git@${TORCHVISION_VERSION}" } install_pytorch_and_domains() { diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index dcd2afa7a13..f25d340fdb3 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -1,12 +1,12 @@ mpmath==1.3.0 numpy>=2.0.0; python_version >= '3.10' PyYAML==6.0.1 -ruamel.yaml==0.17.32 -sympy==1.12 +ruamel.yaml==0.18.15 +sympy>=1.13.3 timm==0.6.13 tomli==2.0.1 torchsr==1.0.4 -transformers==4.47.1 +transformers==4.56.1 zstd==1.5.5.1 pandas>=2.2.2; python_version >= '3.10' pytest==7.2.0 @@ -16,18 +16,20 @@ hypothesis==6.84.2 parameterized==0.9.0 # Doc build requirements, same as https://github.com/pytorch/pytorch/blob/main/.ci/docker/requirements-docs.txt -sphinx==5.3.0 +sphinx==7.2.6 +sphinxcontrib.katex==0.9.10 +breathe==4.36.0 # only if generating C++ +exhale==0.3.7 # only if generating C++ docs +docutils==0.18.1,<0.21 +sphinx-design==0.6.1 +sphinxcontrib-mermaid==1.0.0 +myst-parser==3.0.1 # if want to contribute in markdown +sphinx-gallery==0.14.0 # only if hosting interactive tutorials +sphinx-sitemap==2.7.1 sphinx-reredirects==0.1.4 -sphinx-gallery==0.14.0 -breathe==4.34.0 -exhale==0.2.3 -docutils==0.16 matplotlib>=3.9.4 +sphinx-copybutton==0.5.2 # PyTorch Theme --e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme -myst-parser==0.18.1 -sphinx_design==0.4.1 -sphinx-copybutton==0.5.0 - +pytorch_sphinx_theme2==0.2.0 # script unit test requirements yaspin==3.1.0 diff --git a/.ci/docker/ubuntu/Dockerfile b/.ci/docker/ubuntu/Dockerfile index fddd7e6df36..b7478df5489 100644 --- a/.ci/docker/ubuntu/Dockerfile +++ b/.ci/docker/ubuntu/Dockerfile @@ -64,9 +64,10 @@ ENV SCCACHE_S3_KEY_PREFIX executorch ENV SCCACHE_REGION us-east-1 ARG TORCH_VERSION +ARG SKIP_PYTORCH COPY ./common/install_pytorch.sh install_pytorch.sh COPY ./common/utils.sh utils.sh -RUN bash ./install_pytorch.sh && rm install_pytorch.sh utils.sh +RUN if [ -z "${SKIP_PYTORCH}" ]; then bash ./install_pytorch.sh; fi && rm install_pytorch.sh utils.sh ARG LINTRUNNER # Install lintrunner if needed @@ -83,6 +84,9 @@ RUN if [ -n "${ANDROID_NDK_VERSION}" ]; then bash ./install_android.sh; fi RUN rm install_android.sh ARG ARM_SDK +COPY ./common/install_arm.sh install_arm.sh +RUN if [ -n "${ARM_SDK}" ]; then bash ./install_arm.sh; fi +RUN rm install_arm.sh ARG ZEPHYR_SDK COPY ./common/install_zephyr.sh install_zephyr.sh diff --git a/.ci/scripts/build-qnn-sdk.sh b/.ci/scripts/build-qnn-sdk.sh index 7f34e8afb63..0968fc2a096 100755 --- a/.ci/scripts/build-qnn-sdk.sh +++ b/.ci/scripts/build-qnn-sdk.sh @@ -18,7 +18,7 
@@ build_qnn_backend() { export EXECUTORCH_ROOT="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/../.." && pwd)" parallelism=$(( $(nproc) - 1 )) - bash backends/qualcomm/scripts/build.sh --skip_aarch64 --job_number ${parallelism} --release + bash backends/qualcomm/scripts/build.sh --skip_linux_android --skip_linux_embedded --job_number ${parallelism} --release } set_up_aot() { @@ -38,14 +38,14 @@ set_up_aot() { -DEXECUTORCH_BUILD_EXTENSION_EXTENSION_LLM=ON \ -DEXECUTORCH_BUILD_EXTENSION_EXTENSION_LLM_RUNNER=ON \ -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ + -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \ -DPYTHON_EXECUTABLE=python3 - cmake --build $PWD --target "PyQnnManagerAdaptor" "PyQnnWrapperAdaptor" -j$(nproc) + cmake --build $PWD --target "PyQnnManagerAdaptor" -j$(nproc) # install Python APIs to correct import path # The filename might vary depending on your Python and host version. cp -f backends/qualcomm/PyQnnManagerAdaptor.cpython-310-x86_64-linux-gnu.so $EXECUTORCH_ROOT/backends/qualcomm/python - cp -f backends/qualcomm/PyQnnWrapperAdaptor.cpython-310-x86_64-linux-gnu.so $EXECUTORCH_ROOT/backends/qualcomm/python popd # Workaround for fbs files in exir/_serialize diff --git a/.ci/scripts/cuda_benchmark.py b/.ci/scripts/cuda_benchmark.py new file mode 100644 index 00000000000..b135925d4b4 --- /dev/null +++ b/.ci/scripts/cuda_benchmark.py @@ -0,0 +1,939 @@ +""" +Benchmark script for CUDA model runners. +Runs model runner commands multiple times and collects performance metrics. +Supports whisper, voxtral, gemma3, and other CUDA models. +""" + +import argparse +import json +import statistics +import subprocess +import sys +from dataclasses import dataclass +from typing import List, Optional, Tuple + + +@dataclass +class RunMetrics: + """Metrics from a single run.""" + + generated_tokens: int + tokens_per_sec: float + model_load_time_ms: float + total_inference_time_ms: float + encoder_time_ms: float + generation_time_ms: float + first_token_latency_ms: float + + def __repr__(self): + return ( + f"Tokens: {self.generated_tokens}, " + f"Throughput: {self.tokens_per_sec:.2f} t/s, " + f"Model load: {self.model_load_time_ms:.0f}ms, " + f"Total inference: {self.total_inference_time_ms:.0f}ms, " + f"Encoder: {self.encoder_time_ms:.0f}ms, " + f"Generation: {self.generation_time_ms:.0f}ms, " + f"First token: {self.first_token_latency_ms:.0f}ms" + ) + + +def parse_pytorch_observer_log(log_line: str) -> Optional[RunMetrics]: + """Parse PyTorchObserver JSON output and compute metrics.""" + try: + # Find the JSON part in the log line + if "PyTorchObserver" not in log_line: + return None + + json_str = log_line.split("PyTorchObserver")[1].strip() + data = json.loads(json_str) + + # Extract values + generated_tokens = data.get("generated_tokens", 0) + inference_start_ms = data.get("inference_start_ms", 0) + inference_end_ms = data.get("inference_end_ms", 0) + prompt_eval_end_ms = data.get("prompt_eval_end_ms", 0) + first_token_ms = data.get("first_token_ms", 0) + model_load_start_ms = data.get("model_load_start_ms", 0) + model_load_end_ms = data.get("model_load_end_ms", 0) + + # Compute metrics + # Total inference time: from inference start to inference end + total_inference_time_ms = inference_end_ms - inference_start_ms + + # Encoder time: from inference start to prompt evaluation end + encoder_time_ms = prompt_eval_end_ms - inference_start_ms + + # Generation time: from prompt evaluation end to inference end + 
generation_time_ms = inference_end_ms - prompt_eval_end_ms + + # Calculate throughput based on generation time + tokens_per_sec = ( + (generated_tokens / generation_time_ms * 1000) + if generation_time_ms > 0 + else 0 + ) + model_load_time_ms = model_load_end_ms - model_load_start_ms + first_token_latency_ms = first_token_ms - prompt_eval_end_ms + + return RunMetrics( + generated_tokens=generated_tokens, + tokens_per_sec=tokens_per_sec, + model_load_time_ms=model_load_time_ms, + total_inference_time_ms=total_inference_time_ms, + encoder_time_ms=encoder_time_ms, + generation_time_ms=generation_time_ms, + first_token_latency_ms=first_token_latency_ms, + ) + except (json.JSONDecodeError, KeyError, ValueError) as e: + print(f"Error parsing PyTorchObserver log: {e}", file=sys.stderr) + return None + + +def get_gpu_clocks() -> Optional[Tuple[str, str]]: + """Get current GPU and memory clock frequencies.""" + try: + # Get GPU clock + result_gpu = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=clocks.gr", + "--format=csv,noheader,nounits", + ], + capture_output=True, + text=True, + timeout=10, + ) + # Get memory clock + result_mem = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=clocks.mem", + "--format=csv,noheader,nounits", + ], + capture_output=True, + text=True, + timeout=10, + ) + + if result_gpu.returncode == 0 and result_mem.returncode == 0: + gpu_clock = result_gpu.stdout.strip().split("\n")[0] + mem_clock = result_mem.stdout.strip().split("\n")[0] + return gpu_clock, mem_clock + except Exception as e: + print(f"Warning: Failed to get GPU clocks: {e}", file=sys.stderr) + return None + + +def set_gpu_clocks(gpu_clock: Optional[int] = None) -> bool: + """ + Set GPU clock frequency to a fixed value. + + Args: + gpu_clock: Target GPU clock frequency in MHz. + If None, will use max available. 
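+            The clock is pinned with `sudo nvidia-smi -lgc <clock>,<clock>`
+            after persistence mode is enabled.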
+ + Returns: + True if successful, False otherwise + """ + try: + print("\n[GPU Clock Setup] Fixing GPU clock frequency...") + + # Enable persistence mode + result = subprocess.run( + ["sudo", "nvidia-smi", "-pm", "1"], + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode != 0: + print( + f"Warning: Failed to enable persistence mode: {result.stderr}", + file=sys.stderr, + ) + return False + print("✓ Enabled persistence mode") + + # Lock GPU clocks + if gpu_clock is None: + # Get max GPU clock + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=clocks.max.gr", + "--format=csv,noheader,nounits", + ], + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode == 0: + gpu_clock = int(result.stdout.strip().split("\n")[0]) + print(f"✓ Detected max GPU clock: {gpu_clock} MHz") + + # Lock GPU clock to the target frequency + result = subprocess.run( + ["sudo", "nvidia-smi", "-lgc", f"{gpu_clock},{gpu_clock}"], + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode != 0: + print( + f"Warning: Failed to lock GPU clock: {result.stderr}", + file=sys.stderr, + ) + return False + + print(f"✓ Locked GPU clock to {gpu_clock} MHz") + return True + + except Exception as e: + print(f"Error: Failed to set GPU clocks: {e}", file=sys.stderr) + return False + + +def reset_gpu_clocks() -> bool: + """Reset GPU clock frequencies to default.""" + try: + print("\n[GPU Clock Cleanup] Resetting GPU clock frequency...") + + # Reset GPU clocks + result = subprocess.run( + ["sudo", "nvidia-smi", "-rgc"], + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode != 0: + print( + f"Warning: Failed to reset GPU clock: {result.stderr}", + file=sys.stderr, + ) + return False + print("✓ Reset GPU clock to default") + + # Disable persistence mode + result = subprocess.run( + ["sudo", "nvidia-smi", "-pm", "0"], + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode != 0: + print( + "Warning: Failed to disable persistence mode: " f"{result.stderr}", + file=sys.stderr, + ) + return False + print("✓ Disabled persistence mode") + + return True + + except Exception as e: + print(f"Error: Failed to reset GPU clocks: {e}", file=sys.stderr) + return False + + +def _print_warmup_info(warmup_runs: int) -> None: + """Print warmup phase information.""" + if warmup_runs > 0: + print(f"\n{'='*70}") + print(f"WARMUP PHASE: Running {warmup_runs} warmup iterations...") + print(f"{'='*70}") + + +def _print_benchmark_info( + actual_benchmark_runs: int, trim_count: int, num_runs: int +) -> None: + """Print benchmark phase information.""" + print(f"\n{'='*70}") + print(f"BENCHMARK PHASE: Running {actual_benchmark_runs} iterations") + print(f"Will trim top and bottom {trim_count} results (10% of {num_runs})") + print(f"Final statistics will be based on middle {num_runs} results") + print(f"{'='*70}") + + +def _run_single_iteration( + command: str, run_num: int, verbose: bool +) -> Optional[RunMetrics]: + """ + Run a single benchmark iteration and return metrics. 
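+    The command's stdout is scanned for a line containing "PyTorchObserver";
+    its trailing JSON payload is parsed into a RunMetrics instance.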
+ + Args: + command: Command to execute + run_num: Current run number + verbose: Print verbose output + + Returns: + RunMetrics if successful, None otherwise + """ + try: + # Run command and capture output + result = subprocess.run( + command, + shell=True, + capture_output=True, + text=True, + timeout=300, # 5 minute timeout + ) + + if result.returncode != 0: + print( + f"Error: Command failed with return code {result.returncode}", + file=sys.stderr, + ) + if result.stderr: + print(f"stderr: {result.stderr}", file=sys.stderr) + return None + + # Search for PyTorchObserver line in output + observer_line = None + for line in result.stdout.split("\n"): + if "PyTorchObserver" in line: + observer_line = line + break + + if observer_line is None: + print( + f"Warning: No PyTorchObserver output found in run {run_num}", + file=sys.stderr, + ) + if verbose: + print(f"stdout:\n{result.stdout}", file=sys.stderr) + return None + + # Parse and return metrics + metrics = parse_pytorch_observer_log(observer_line) + if metrics is None: + print( + f"Warning: Failed to parse metrics from run {run_num}", + file=sys.stderr, + ) + return None + + print(f"✓ {metrics}") + return metrics + + except subprocess.TimeoutExpired: + print(f"Error: Command timed out on run {run_num}", file=sys.stderr) + return None + except Exception as e: + print(f"Error on run {run_num}: {e}", file=sys.stderr) + return None + + +def run_model_benchmark( + command: str, + num_runs: int = 5, + warmup_runs: int = 0, + verbose: bool = False, +) -> List[RunMetrics]: + """ + Run the model runner command multiple times and collect metrics. + + For trimmed mean calculation, this function runs extra iterations + to ensure we can trim outliers. Based on num_runs, we calculate + trim_count = num_runs * 0.1, then run num_runs + 2*trim_count total + iterations. The top and bottom trim_count results will be discarded. + + Args: + command: Full command to run + num_runs: Number of benchmark runs requested by user (after trim) + warmup_runs: Number of warmup runs (results will be discarded) + verbose: Print detailed output + + Returns: + List of RunMetrics from benchmark runs (excluding warmup). + """ + # Calculate trim count and total runs + trim_count = int(num_runs * 0.1) + actual_benchmark_runs = num_runs + 2 * trim_count + total_runs = warmup_runs + actual_benchmark_runs + + # Print phase information + _print_warmup_info(warmup_runs) + _print_benchmark_info(actual_benchmark_runs, trim_count, num_runs) + + # Execute all runs + results = [] + for run_num in range(1, total_runs + 1): + is_warmup = run_num <= warmup_runs + phase = "Warmup" if is_warmup else "Benchmark" + benchmark_run_num = run_num - warmup_runs if not is_warmup else run_num + + # Print run header + if is_warmup: + print(f"\n[{phase} {run_num}/{warmup_runs}] Executing: {command}") + else: + print( + f"\n[{phase} {benchmark_run_num}/{actual_benchmark_runs}] " + f"Executing: {command}" + ) + + # Run iteration and collect metrics + metrics = _run_single_iteration(command, run_num, verbose) + if metrics is not None and not is_warmup: + results.append(metrics) + + return results + + +def calculate_trimmed_stats( + values: List[float], trim_count: int +) -> Tuple[List[float], float, float, float, float]: + """ + Calculate statistics on trimmed data. 
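+    Example (illustrative): values=[5, 1, 9, 3, 7] with trim_count=1 sorts to
+    [1, 3, 5, 7, 9] and keeps [3, 5, 7] after dropping one value from each end.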
+ + Args: + values: List of numeric values + trim_count: Number of values to trim from each end + + Returns: + Tuple of (trimmed_values, min, max, mean, stdev) + """ + if not values: + return [], 0.0, 0.0, 0.0, 0.0 + + # Sort values + sorted_values = sorted(values) + n = len(sorted_values) + + # Trim if we have enough data and trim_count > 0 + if trim_count > 0 and n > 2 * trim_count: + trimmed_values = sorted_values[trim_count : n - trim_count] + else: + trimmed_values = sorted_values + + # Calculate stats on trimmed data + min_val = min(trimmed_values) + max_val = max(trimmed_values) + mean_val = statistics.mean(trimmed_values) + stdev_val = statistics.stdev(trimmed_values) if len(trimmed_values) > 1 else 0.0 + + return trimmed_values, min_val, max_val, mean_val, stdev_val + + +@dataclass +class MetricStats: + """Statistics for a single metric with operations.""" + + name: str + mean: float + min_val: float + max_val: float + stdev: float + unit: str = "" + extra_info: dict | None = None + + def create_v3_record( + self, + model_name: str, + backend: str, + runner_name: str, + runner_type: str, + base_extra_info: dict, + ) -> dict: + """ + Create a v3 format record for this metric. + + Args: + model_name: Model name with quantization + backend: Backend name (e.g., "cuda-aoti") + runner_name: GPU device name + runner_type: CUDA driver version + base_extra_info: Base extra_info dict to copy + + Returns: + Complete v3 format metric record + """ + extra_stats = { + "min": self.min_val, + "max": self.max_val, + "stdev": self.stdev, + } + if self.extra_info: + extra_stats.update(self.extra_info) + + return { + "benchmark": { + "name": "ExecuTorch", + "mode": "inference", + "extra_info": base_extra_info.copy(), + }, + "model": { + "name": model_name, + "type": "OSS model", + "backend": backend, + }, + "metric": { + "name": self.name, + "benchmark_values": [self.mean], + "target_value": 0, + "extra_info": extra_stats, + }, + "runners": [{"name": runner_name, "type": runner_type}], + } + + def print_stats(self) -> None: + """Print formatted statistics for this metric.""" + # Determine precision based on metric type + is_throughput = "tokens" in self.name.lower() + precision = 2 if is_throughput else 0 + + # Format metric name for display + display_name = self.name.replace("_", " ").upper() + if self.unit: + display_name = f"{display_name} ({self.unit})" + + print(f"{display_name}:") + print(f" Min: {self.min_val:.{precision}f} {self.unit}") + print(f" Max: {self.max_val:.{precision}f} {self.unit}") + print(f" Mean: {self.mean:.{precision}f} {self.unit}") + print(f" Stdev: {self.stdev:.{precision}f} {self.unit}") + print() + + +@dataclass +class BenchmarkResults: + """Summary of benchmark results.""" + + model_name: str + total_runs: int + trimmed_runs: int + discarded_runs: int + generated_tokens: int + + # Metrics + throughput: MetricStats + model_load_time: MetricStats + total_inference_time: MetricStats + encoder_time: MetricStats + generation_time: MetricStats + first_token_latency: MetricStats + + def save_json(self, output_path: str) -> None: + """Save results to JSON file.""" + with open(output_path, "w") as f: + json.dump(self.to_dict(), f, indent=2) + print(f"\n✓ Results saved to: {output_path}") + + def to_dict(self) -> dict: + """Convert results to dictionary for JSON serialization.""" + return { + "model_name": self.model_name, + "total_runs": self.total_runs, + "trimmed_runs": self.trimmed_runs, + "discarded_runs": self.discarded_runs, + "generated_tokens": self.generated_tokens, + 
"throughput_mean": self.throughput.mean, + "throughput_min": self.throughput.min_val, + "throughput_max": self.throughput.max_val, + "throughput_stdev": self.throughput.stdev, + "model_load_time_mean": self.model_load_time.mean, + "model_load_time_min": self.model_load_time.min_val, + "model_load_time_max": self.model_load_time.max_val, + "model_load_time_stdev": self.model_load_time.stdev, + "total_inference_time_mean": self.total_inference_time.mean, + "total_inference_time_min": self.total_inference_time.min_val, + "total_inference_time_max": self.total_inference_time.max_val, + "total_inference_time_stdev": self.total_inference_time.stdev, + "encoder_time_mean": self.encoder_time.mean, + "encoder_time_min": self.encoder_time.min_val, + "encoder_time_max": self.encoder_time.max_val, + "encoder_time_stdev": self.encoder_time.stdev, + "generation_time_mean": self.generation_time.mean, + "generation_time_min": self.generation_time.min_val, + "generation_time_max": self.generation_time.max_val, + "generation_time_stdev": self.generation_time.stdev, + "first_token_latency_mean": self.first_token_latency.mean, + "first_token_latency_min": self.first_token_latency.min_val, + "first_token_latency_max": self.first_token_latency.max_val, + "first_token_latency_stdev": self.first_token_latency.stdev, + } + + def to_v3_format( + self, + model: str, + quantization: str, + git_sha: str, + workflow_run_id: str, + workflow_run_url: str = "", + gpu_name: str = "CUDA", + cuda_driver_version: str = "cuda", + ) -> List[dict]: + """ + Transform benchmark results to PyTorch benchmark database v3 format. + + Args: + model: Model name (e.g., "openai/whisper-small") + quantization: Quantization type (e.g., "non-quantized") + git_sha: Git commit SHA + workflow_run_id: GitHub workflow run ID + workflow_run_url: GitHub workflow run URL + gpu_name: GPU device name (e.g., "Tesla V100", "A100") + cuda_driver_version: CUDA driver version (e.g., "12.6", "535.104.05") + + Returns: + List of benchmark records in v3 format + """ + # Shared configuration + model_name_with_quant = f"{model}_{quantization}" + backend = "cuda-aoti" + runner_name = gpu_name + runner_type = cuda_driver_version + + # Create base extra_info + base_extra_info = { + "backend": "cuda", + "quantization": quantization, + "git_sha": git_sha, + "workflow_run_id": workflow_run_id, + } + if workflow_run_url: + base_extra_info["workflow_run_url"] = workflow_run_url + + # Create v3 records for all metrics + return [ + self.throughput.create_v3_record( + model_name_with_quant, + backend, + runner_name, + runner_type, + base_extra_info, + ), + self.model_load_time.create_v3_record( + model_name_with_quant, + backend, + runner_name, + runner_type, + base_extra_info, + ), + self.total_inference_time.create_v3_record( + model_name_with_quant, + backend, + runner_name, + runner_type, + base_extra_info, + ), + self.encoder_time.create_v3_record( + model_name_with_quant, + backend, + runner_name, + runner_type, + base_extra_info, + ), + self.generation_time.create_v3_record( + model_name_with_quant, + backend, + runner_name, + runner_type, + base_extra_info, + ), + self.first_token_latency.create_v3_record( + model_name_with_quant, + backend, + runner_name, + runner_type, + base_extra_info, + ), + ] + + +def compute_summary( + model_name: str, results: List[RunMetrics], requested_runs: int +) -> BenchmarkResults: + """ + Compute summary statistics using trimmed data. 
+ + All statistics (min, max, mean, stdev) are calculated based on + the trimmed dataset after removing outliers. + + Args: + model_name: Name of the model being benchmarked + results: List of all collected run metrics + requested_runs: Number of runs originally requested by user + + Returns: + BenchmarkResults object with all computed statistics + """ + if not results: + raise ValueError("No valid results to summarize.") + + # Calculate trim count based on requested runs (not actual runs) + trim_count = int(requested_runs * 0.1) + + # Helper to create MetricStats from values + def create_metric_stats( + name: str, values: List[float], unit: str = "", extra_info: dict | None = None + ) -> MetricStats: + _, min_val, max_val, mean_val, stdev_val = calculate_trimmed_stats( + values, trim_count + ) + return MetricStats( + name=name, + mean=mean_val, + min_val=min_val, + max_val=max_val, + stdev=stdev_val, + unit=unit, + extra_info=extra_info, + ) + + # Get the first trimmed result to get trimmed_runs count + trimmed_throughput, _, _, _, _ = calculate_trimmed_stats( + [r.tokens_per_sec for r in results], trim_count + ) + + return BenchmarkResults( + model_name=model_name, + total_runs=len(results), + trimmed_runs=len(trimmed_throughput), + discarded_runs=trim_count * 2, + generated_tokens=results[0].generated_tokens, + throughput=create_metric_stats( + "throughput(tokens/sec)", + [r.tokens_per_sec for r in results], + "t/s", + {"trimmed_runs": len(trimmed_throughput)}, + ), + model_load_time=create_metric_stats( + "model_load_time(ms)", + [r.model_load_time_ms for r in results], + "ms", + ), + total_inference_time=create_metric_stats( + "total_inference_time(ms)", + [r.total_inference_time_ms for r in results], + "ms", + ), + encoder_time=create_metric_stats( + "encoder_time(ms)", + [r.encoder_time_ms for r in results], + "ms", + ), + generation_time=create_metric_stats( + "generation_time(ms)", + [r.generation_time_ms for r in results], + "ms", + ), + first_token_latency=create_metric_stats( + "first_token_latency(ms)", + [r.first_token_latency_ms for r in results], + "ms", + ), + ) + + +def print_summary(summary: BenchmarkResults) -> None: + """Print formatted summary of benchmark results.""" + print("\n" + "=" * 70) + print(f"BENCHMARK SUMMARY for model: {summary.model_name}") + print("=" * 70) + print(f"Total runs collected: {summary.total_runs}") + print(f"Trimmed to: {summary.trimmed_runs} runs") + print( + f"(Discarded {summary.discarded_runs // 2} highest and " + f"{summary.discarded_runs // 2} lowest results)" + ) + print(f"Generated tokens per run: {summary.generated_tokens}") + print() + + # Print all metrics using their print_stats method + summary.throughput.print_stats() + summary.model_load_time.print_stats() + summary.total_inference_time.print_stats() + summary.encoder_time.print_stats() + summary.generation_time.print_stats() + summary.first_token_latency.print_stats() + + print("=" * 70) + + +def main(): + # Parse command-line arguments + parser = argparse.ArgumentParser( + description="Benchmark CUDA model runners and collect performance metrics" + ) + parser.add_argument( + "--runner_command", + type=str, + required=True, + help="Full command to run the model runner", + ) + parser.add_argument( + "--model_name", + type=str, + required=True, + help="Name of the model being benchmarked", + ) + parser.add_argument( + "--num_runs", + type=int, + default=50, + help="Number of benchmark runs (default: 50)", + ) + parser.add_argument( + "--warmup_runs", + type=int, + default=0, + 
help="Number of warmup runs before benchmark (default: 0.1 * num_runs)", + ) + parser.add_argument( + "--fix_gpu_clock", + type=bool, + default=True, + help="Fix GPU clock frequency to maximum before benchmarking", + ) + parser.add_argument( + "--gpu_clock", + type=int, + default=None, + help="Target GPU clock frequency in MHz (requires " + "--fix_gpu_clock). If not specified, uses max available.", + ) + parser.add_argument( + "--output_json", + type=str, + default=None, + help="Path to save JSON results", + ) + parser.add_argument( + "--output_v3", + type=str, + default=None, + help="Path to save v3 format JSON results for dashboard", + ) + parser.add_argument( + "--model", + type=str, + default=None, + help="Model ID (e.g., 'openai/whisper-small') - required for v3 format", + ) + parser.add_argument( + "--quantization", + type=str, + default=None, + help="Quantization type (e.g., 'non-quantized') - required for v3 format", + ) + parser.add_argument( + "--git_sha", + type=str, + default=None, + help="Git commit SHA - required for v3 format", + ) + parser.add_argument( + "--workflow_run_id", + type=str, + default=None, + help="GitHub workflow run ID - required for v3 format", + ) + parser.add_argument( + "--workflow_run_url", + type=str, + default="", + help="GitHub workflow run URL - optional for v3 format", + ) + parser.add_argument( + "--gpu_name", + type=str, + default=None, + help="GPU device name (e.g., 'Tesla V100', 'A100') - optional for v3 format", + ) + parser.add_argument( + "--cuda_driver_version", + type=str, + default=None, + help="CUDA driver version (e.g., '12.6', '535.104.05') - optional for v3 format", + ) + parser.add_argument("--verbose", action="store_true", help="Print verbose output") + + args = parser.parse_args() + + warmup_runs = ( + int(0.1 * args.num_runs) if args.warmup_runs == 0 else args.warmup_runs + ) + + print(f"Running benchmark for model: {args.model_name}") + print(f"Number of runs: {args.num_runs}") + if warmup_runs > 0: + print(f"Warmup runs: {warmup_runs}") + if args.fix_gpu_clock: + clock_str = f"{args.gpu_clock}" if args.gpu_clock else "max available" + print(f"GPU clock will be fixed to: {clock_str} MHz") + print(f"Command: {args.runner_command}\n") + + # Fix GPU clocks if requested + gpu_clock_fixed = False + if args.fix_gpu_clock: + # Get current clocks before fixing + initial_clocks = get_gpu_clocks() + if initial_clocks: + print( + f"Current GPU clocks - GPU: {initial_clocks[0]} MHz, " + f"Memory: {initial_clocks[1]} MHz" + ) + + gpu_clock_fixed = set_gpu_clocks(args.gpu_clock) + if not gpu_clock_fixed: + print( + "Warning: Failed to fix GPU clocks. 
" + "Continuing without fixed clocks...", + file=sys.stderr, + ) + + try: + # Run benchmark + results = run_model_benchmark( + command=args.runner_command, + num_runs=args.num_runs, + warmup_runs=warmup_runs, + verbose=args.verbose, + ) + + # Compute and print summary + summary = compute_summary(args.model_name, results, args.num_runs) + print_summary(summary) + + # Save JSON results if requested + if args.output_json: + summary.save_json(args.output_json) + + # Save v3 format if requested + if args.output_v3: + # Validate required parameters for v3 format + if not all( + [args.model, args.quantization, args.git_sha, args.workflow_run_id] + ): + print( + "Error: --output_v3 requires --model, --quantization, " + "--git_sha, and --workflow_run_id", + file=sys.stderr, + ) + sys.exit(1) + + v3_records = summary.to_v3_format( + model=args.model, + quantization=args.quantization, + git_sha=args.git_sha, + workflow_run_id=args.workflow_run_id, + workflow_run_url=args.workflow_run_url, + gpu_name=args.gpu_name if args.gpu_name else "UNKNOWN GPU", + cuda_driver_version=( + args.cuda_driver_version if args.cuda_driver_version else "cuda" + ), + ) + + with open(args.output_v3, "w") as f: + json.dump(v3_records, f, indent=2) + + print(f"✓ v3 format results saved to: {args.output_v3}") + print(f"✓ Generated {len(v3_records)} v3 records for dashboard upload") + + finally: + # Reset GPU clocks if they were fixed + if gpu_clock_fixed: + reset_gpu_clocks() + + +if __name__ == "__main__": + main() diff --git a/.ci/scripts/export_model_artifact.sh b/.ci/scripts/export_model_artifact.sh new file mode 100755 index 00000000000..3c173b0ea2a --- /dev/null +++ b/.ci/scripts/export_model_artifact.sh @@ -0,0 +1,187 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Export model to CUDA/Metal format with optional quantization + +show_help() { + cat << EOF +Usage: export_model_artifact.sh [quant_name] [output_dir] + +Export a HuggingFace model to CUDA/Metal format with optional quantization. 
+ +Arguments: + device cuda or metal (required) + + hf_model HuggingFace model ID (required) + Supported models: + - mistralai/Voxtral-Mini-3B-2507 + - openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}) + - google/gemma-3-4b-it + + quant_name Quantization type (optional, default: non-quantized) + Options: + - non-quantized + - quantized-int4-tile-packed + - quantized-int4-weight-only + + output_dir Output directory for artifacts (optional, default: current directory) + +Examples: + export_model_artifact.sh metal "openai/whisper-small" + export_model_artifact.sh cuda "mistralai/Voxtral-Mini-3B-2507" "quantized-int4-tile-packed" + export_model_artifact.sh cuda "google/gemma-3-4b-it" "non-quantized" "./output" +EOF +} + +if [ "${1:-}" = "-h" ] || [ "${1:-}" = "--help" ]; then + show_help + exit 0 +fi + +if [ -z "${1:-}" ]; then + echo "Error: hf_model argument is required" + echo "Run with -h or --help for usage information" + exit 1 +fi + +set -eux + +DEVICE="$1" +HF_MODEL="$2" +QUANT_NAME="${3:-non-quantized}" +OUTPUT_DIR="${4:-.}" + +case "$DEVICE" in + cuda) + ;; + metal) + ;; + *) + echo "Error: Unsupported device '$DEVICE'" + echo "Supported devices: cuda, metal" + exit 1 + ;; +esac + +# Determine model configuration based on HF model ID +case "$HF_MODEL" in + mistralai/Voxtral-Mini-3B-2507) + MODEL_NAME="voxtral" + TASK="multimodal-text-to-text" + MAX_SEQ_LEN="1024" + EXTRA_PIP="mistral-common librosa" + PREPROCESSOR_FEATURE_SIZE="128" + PREPROCESSOR_OUTPUT="voxtral_preprocessor.pte" + ;; + openai/whisper-*) + MODEL_NAME="whisper" + TASK="automatic-speech-recognition" + MAX_SEQ_LEN="" + EXTRA_PIP="librosa" + PREPROCESSOR_OUTPUT="whisper_preprocessor.pte" + if [[ "$HF_MODEL" == *"large-v3"* ]]; then + PREPROCESSOR_FEATURE_SIZE="128" + else + PREPROCESSOR_FEATURE_SIZE="80" + fi + ;; + google/gemma-3-4b-it) + if [ "$DEVICE" = "metal" ]; then + echo "Error: Export for device 'metal' is not yet tested for model '$HF_MODEL'" + exit 1 + fi + MODEL_NAME="gemma3" + TASK="multimodal-text-to-text" + MAX_SEQ_LEN="64" + EXTRA_PIP="" + PREPROCESSOR_FEATURE_SIZE="" + PREPROCESSOR_OUTPUT="" + ;; + *) + echo "Error: Unsupported model '$HF_MODEL'" + echo "Supported models: mistralai/Voxtral-Mini-3B-2507, openai/whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}, google/gemma-3-4b-it" + exit 1 + ;; +esac + +# Determine quantization args based on quant name +case "$QUANT_NAME" in + non-quantized) + EXTRA_ARGS="" + ;; + quantized-int4-tile-packed) + if [ "$DEVICE" = "metal" ]; then + echo "Error: Metal backend does not yet support quantization '$QUANT_NAME'" + exit 1 + fi + EXTRA_ARGS="--qlinear 4w --qlinear_encoder 4w --qlinear_packing_format tile_packed_to_4d --qlinear_encoder_packing_format tile_packed_to_4d" + ;; + quantized-int4-weight-only) + if [ "$DEVICE" = "metal" ]; then + echo "Error: Metal backend does not yet support quantization '$QUANT_NAME'" + exit 1 + fi + EXTRA_ARGS="--qlinear_encoder 4w" + ;; + *) + echo "Error: Unsupported quantization '$QUANT_NAME'" + echo "Supported quantizations: non-quantized, quantized-int4-tile-packed, quantized-int4-weight-only" + exit 1 + ;; +esac + +echo "::group::Export $MODEL_NAME" + +if [ -n "$EXTRA_PIP" ]; then + pip install $EXTRA_PIP +fi +pip list + +MAX_SEQ_LEN_ARG="" +if [ -n "$MAX_SEQ_LEN" ]; then + MAX_SEQ_LEN_ARG="--max_seq_len $MAX_SEQ_LEN" +fi + +DEVICE_ARG="" +if [ "$DEVICE" = "cuda" ]; then + DEVICE_ARG="--device cuda" +fi + +optimum-cli export executorch \ + --model "$HF_MODEL" \ + --task 
"$TASK" \ + --recipe "$DEVICE" \ + --dtype bfloat16 \ + ${DEVICE_ARG} \ + ${MAX_SEQ_LEN_ARG} \ + ${EXTRA_ARGS} \ + --output_dir ./ + +if [ -n "$PREPROCESSOR_OUTPUT" ]; then + python -m executorch.extension.audio.mel_spectrogram \ + --feature_size $PREPROCESSOR_FEATURE_SIZE \ + --stack_output \ + --max_audio_len 300 \ + --output_file $PREPROCESSOR_OUTPUT +fi + +test -f model.pte +test -f aoti_${DEVICE}_blob.ptd +if [ -n "$PREPROCESSOR_OUTPUT" ]; then + test -f $PREPROCESSOR_OUTPUT +fi +echo "::endgroup::" + +echo "::group::Store $MODEL_NAME Artifacts" +mkdir -p "${OUTPUT_DIR}" +mv model.pte "${OUTPUT_DIR}/" +mv aoti_${DEVICE}_blob.ptd "${OUTPUT_DIR}/" +if [ -n "$PREPROCESSOR_OUTPUT" ]; then + mv $PREPROCESSOR_OUTPUT "${OUTPUT_DIR}/" +fi +ls -al "${OUTPUT_DIR}" +echo "::endgroup::" diff --git a/.ci/scripts/setup-openvino.sh b/.ci/scripts/setup-openvino.sh index ff667619125..587494f46ac 100755 --- a/.ci/scripts/setup-openvino.sh +++ b/.ci/scripts/setup-openvino.sh @@ -10,19 +10,17 @@ set -ex # shellcheck source=/dev/null source "$(dirname "${BASH_SOURCE[0]}")/utils.sh" -git clone https://github.com/openvinotoolkit/openvino.git -cd openvino && git checkout releases/2025/1 -git submodule update --init --recursive -sudo ./install_build_dependencies.sh -mkdir build && cd build -cmake .. -DCMAKE_BUILD_TYPE=Release -DENABLE_PYTHON=ON -make -j$(nproc) +# Download and install OpenVINO from release packages +OPENVINO_VERSION="2025.3" +OPENVINO_BUILD="2025.3.0.19807.44526285f24" +OPENVINO_URL="https://storage.openvinotoolkit.org/repositories/openvino/packages/${OPENVINO_VERSION}/linux/openvino_toolkit_ubuntu22_${OPENVINO_BUILD}_x86_64.tgz" -cd .. -cmake --install build --prefix dist +curl -Lo /tmp/openvino_toolkit.tgz --retry 3 --fail ${OPENVINO_URL} +tar -xzf /tmp/openvino_toolkit.tgz +mv openvino_toolkit_ubuntu22_${OPENVINO_BUILD}_x86_64 openvino -source dist/setupvars.sh -cd ../backends/openvino +source openvino/setupvars.sh +cd backends/openvino pip install -r requirements.txt cd scripts ./openvino_build.sh --enable_python diff --git a/.ci/scripts/setup-samsung-linux-deps.sh b/.ci/scripts/setup-samsung-linux-deps.sh index ed704b2bfbd..c1f2912713b 100644 --- a/.ci/scripts/setup-samsung-linux-deps.sh +++ b/.ci/scripts/setup-samsung-linux-deps.sh @@ -11,9 +11,9 @@ set -ex download_ai_lite_core() { API_BASE="https://soc-developer.semiconductor.samsung.com/api/v1/resource/ai-litecore/download" - API_KEY="kn10SoSY3hkC-9Qny5TqD2mnqVrlupv3krnjLeBt5cY" + API_KEY=$SAMSUNG_AI_LITECORE_KEY - VERSION="0.5" + VERSION="0.7" OS_NAME="Ubuntu 22.04" OUT_FILE="/tmp/exynos-ai-litecore-v${VERSION}.tar.gz" TARGET_PATH="/tmp/exynos_ai_lite_core" @@ -52,7 +52,7 @@ download_ai_lite_core() { install_enn_backend() { NDK_INSTALLATION_DIR=/opt/ndk rm -rf "${NDK_INSTALLATION_DIR}" && sudo mkdir -p "${NDK_INSTALLATION_DIR}" - ANDROID_NDK_VERSION=r27b + ANDROID_NDK_VERSION=r28c # build Exynos backend export ANDROID_NDK_ROOT=${ANDROID_NDK_ROOT:-/opt/ndk} @@ -62,7 +62,7 @@ install_enn_backend() { export PYTHONPATH=${PYTHONPATH:-}:${EXECUTORCH_ROOT}/.. 
} -AI_LITE_CORE_VERSION=0.5.0 +AI_LITE_CORE_VERSION=0.7.0 download_ai_lite_core ${AI_LITE_CORE_VERSION} install_enn_backend diff --git a/.ci/scripts/setup-windows-msvc.ps1 b/.ci/scripts/setup-windows-msvc.ps1 new file mode 100644 index 00000000000..e15a003d803 --- /dev/null +++ b/.ci/scripts/setup-windows-msvc.ps1 @@ -0,0 +1,52 @@ +conda create --yes --quiet -n et python=3.12 +conda activate et + +# Install cmake +conda install -y cmake + +# Activate the VS environment - this is required for MSVC to work +# There are a bunch of environment variables that it requires. +# See https://learn.microsoft.com/en-us/cpp/build/building-on-the-command-line. +& "C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\Common7\Tools\Launch-VsDevShell.ps1" -Arch amd64 + +# Install CI requirements +pip install -r .ci/docker/requirements-ci.txt + +# Create build directory +$buildDir = "cmake-out-msvc" +if (Test-Path -Path $buildDir) { + Remove-Item -Path $buildDir -Recurse -Force +} +New-Item -Path $buildDir -ItemType Directory + +# Configure CMake with MSVC (not ClangCL) and disable custom/quantized ops +cmake -S . -B $buildDir ` + -DCMAKE_BUILD_TYPE=Release ` + -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON ` + -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON ` + -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON ` + -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON ` + -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON ` + -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON ` + -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON ` + -DEXECUTORCH_BUILD_KERNELS_CUSTOM=OFF ` + -DEXECUTORCH_BUILD_KERNELS_CUSTOM_AOT=OFF ` + -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=OFF ` + -DEXECUTORCH_BUILD_XNNPACK=ON ` + -DEXECUTORCH_BUILD_EXTENSION_LLM=ON ` + -DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON + +if ($LASTEXITCODE -ne 0) { + Write-Host "CMake configuration failed. Exit code: $LASTEXITCODE." + exit $LASTEXITCODE +} + +# Build with MSVC +cmake --build $buildDir --config Release -j16 + +if ($LASTEXITCODE -ne 0) { + Write-Host "Build failed. Exit code: $LASTEXITCODE." + exit $LASTEXITCODE +} + +Write-Host "MSVC build completed successfully!" diff --git a/.ci/scripts/test-cuda-build.sh b/.ci/scripts/test-cuda-build.sh new file mode 100755 index 00000000000..08673533927 --- /dev/null +++ b/.ci/scripts/test-cuda-build.sh @@ -0,0 +1,92 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -exu + +CUDA_VERSION=${1:-"12.6"} + +echo "=== Testing ExecuTorch CUDA ${CUDA_VERSION} Build ===" + +# Function to build and test ExecuTorch with CUDA support +test_executorch_cuda_build() { + local cuda_version=$1 + + echo "Building ExecuTorch with CUDA ${cuda_version} support..." + echo "ExecuTorch will automatically detect CUDA and install appropriate PyTorch wheel" + + # Check available resources before starting + echo "=== System Information ===" + echo "Available memory: $(free -h | grep Mem | awk '{print $2}')" + echo "Available disk space: $(df -h . | tail -1 | awk '{print $4}')" + echo "CPU cores: $(nproc)" + echo "CUDA version check:" + nvcc --version || echo "nvcc not found" + nvidia-smi || echo "nvidia-smi not found" + + echo "=== Starting ExecuTorch Installation ===" + # Install ExecuTorch with CUDA support with timeout and error handling + timeout 5400 ./install_executorch.sh || { + local exit_code=$? 
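+        # timeout(1) exits with status 124 when the 5400s (90 minute) limit is hit.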
+ echo "ERROR: install_executorch.sh failed with exit code: $exit_code" + if [ $exit_code -eq 124 ]; then + echo "ERROR: Installation timed out after 90 minutes" + fi + exit $exit_code + } + + echo "SUCCESS: ExecuTorch CUDA build completed" + + # Verify the installation + echo "=== Verifying ExecuTorch CUDA Installation ===" + + # Test that ExecuTorch was built successfully + python -c " +import executorch +print('SUCCESS: ExecuTorch imported successfully') +" + + # Test CUDA availability and show details + python -c " +try: + import torch + print('INFO: PyTorch version:', torch.__version__) + print('INFO: CUDA available:', torch.cuda.is_available()) + + if torch.cuda.is_available(): + print('SUCCESS: CUDA is available for ExecuTorch') + print('INFO: CUDA version:', torch.version.cuda) + print('INFO: GPU device count:', torch.cuda.device_count()) + print('INFO: Current GPU device:', torch.cuda.current_device()) + print('INFO: GPU device name:', torch.cuda.get_device_name()) + + # Test basic CUDA tensor operation + device = torch.device('cuda') + x = torch.randn(10, 10).to(device) + y = torch.randn(10, 10).to(device) + z = torch.mm(x, y) + print('SUCCESS: CUDA tensor operation completed on device:', z.device) + print('INFO: Result tensor shape:', z.shape) + + print('SUCCESS: ExecuTorch CUDA integration verified') + else: + print('WARNING: CUDA not detected, but ExecuTorch built successfully') + exit(1) +except Exception as e: + print('ERROR: ExecuTorch CUDA test failed:', e) + exit(1) +" + + echo "SUCCESS: ExecuTorch CUDA ${cuda_version} build and verification completed successfully" +} + +# Main execution +echo "Current working directory: $(pwd)" +echo "Directory contents:" +ls -la + +# Run the CUDA build test +test_executorch_cuda_build "${CUDA_VERSION}" diff --git a/.ci/scripts/test_ane_static_llama.sh b/.ci/scripts/test_ane_static_llama.sh index 3081c7ffe52..73a9c4ca54b 100644 --- a/.ci/scripts/test_ane_static_llama.sh +++ b/.ci/scripts/test_ane_static_llama.sh @@ -28,6 +28,13 @@ pushd $EXECUTORCH_ROOT/examples/apple/coreml/llama # Download stories llama110m artifacts download_stories_model_artifacts +# Test static ANE llama model +python export_static_llm_coreml.py --checkpoint stories110M.pt --params params.json --output model.pte + +# The ANE cannot run in github CI +# python run_static_llm.py --model model.pte --params params.json --tokenizer tokenizer.model --prompt "Once upon a time," --lookahead + +# Test export of deprecated model python export.py -n model.pte -p params.json -c stories110M.pt --seq_length 32 --max_seq_length 64 --dtype fp16 --coreml-quantize c4w --embedding-quantize 4,32 popd diff --git a/.ci/scripts/test_backend.sh b/.ci/scripts/test_backend.sh new file mode 100755 index 00000000000..1a8e3219be0 --- /dev/null +++ b/.ci/scripts/test_backend.sh @@ -0,0 +1,91 @@ +#!/usr/bin/env bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +set -eux + +SUITE=$1 +FLOW=$2 +ARTIFACT_DIR=$3 + +REPORT_FILE="$ARTIFACT_DIR/test-report-$FLOW-$SUITE.json" + +echo "Running backend test job for suite $SUITE, flow $FLOW." +echo "Saving job artifacts to $ARTIFACT_DIR." 
+ +eval "$(conda shell.bash hook)" +CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") +conda activate "${CONDA_ENV}" + +if [[ "$(uname)" == "Darwin" ]]; then + bash .ci/scripts/setup-conda.sh + eval "$(conda shell.bash hook)" + CONDA_RUN_CMD="${CONDA_RUN} --no-capture-output" + ${CONDA_RUN_CMD} pip install awscli==1.37.21 + IS_MACOS=1 +else + CONDA_RUN_CMD="" + IS_MACOS=0 +fi + +export PYTHON_EXECUTABLE=python + +# CMake options to use, in addition to the defaults. +EXTRA_BUILD_ARGS="" + +if [[ "$FLOW" == *qnn* ]]; then + # Setup QNN sdk and deps - note that this is a bit hacky due to the nature of the + # Qualcomm build. TODO (gjcomer) Clean this up once the QNN pybinding integration is + # cleaned up. + PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool cmake + PYTHON_EXECUTABLE=python source .ci/scripts/build-qnn-sdk.sh + QNN_X86_LIB_DIR=`realpath build-x86/lib/` + export LD_LIBRARY_PATH"=$QNN_X86_LIB_DIR:$QNN_SDK_ROOT/lib/x86_64-linux-clang/:${LD_LIBRARY_PATH:-}" + + # TODO Get SDK root from install scripts + EXTRA_BUILD_ARGS+=" -DEXECUTORCH_BUILD_QNN=ON -DQNN_SDK_ROOT=$QNN_SDK_ROOT" +fi + +if [[ "$FLOW" == *vulkan* ]]; then + # Setup swiftshader and Vulkan SDK which are required to build the Vulkan delegate. + source .ci/scripts/setup-vulkan-linux-deps.sh + + EXTRA_BUILD_ARGS+=" -DEXECUTORCH_BUILD_VULKAN=ON" +fi + +if [[ "$FLOW" == *arm* ]]; then + + # Setup ARM deps. + if [[ "$FLOW" == *vgf* ]]; then + .ci/scripts/setup-arm-baremetal-tools.sh --enable-mlsdk-deps --install-mlsdk-deps-with-pip + else + .ci/scripts/setup-arm-baremetal-tools.sh + fi + source examples/arm/arm-scratch/setup_path.sh + + if [[ "$FLOW" == *ethos_u* ]]; then + # Prepare a test runner binary that can run on the Corstone-3x0 FVPs + backends/arm/scripts/build_executorch.sh + backends/arm/test/setup_testing.sh + fi + + if [[ "$FLOW" == *vgf* ]]; then + # Prepare a test runner binary for VKML runtime + backends/arm/test/setup_testing_vkml.sh + fi +fi + +if [[ $IS_MACOS -eq 1 ]]; then + SETUP_SCRIPT=.ci/scripts/setup-macos.sh +else + SETUP_SCRIPT=.ci/scripts/setup-linux.sh +fi +CMAKE_ARGS="$EXTRA_BUILD_ARGS" ${CONDA_RUN_CMD} $SETUP_SCRIPT --build-tool cmake --build-mode Release --editable true + +EXIT_CODE=0 +${CONDA_RUN_CMD} pytest -c /dev/nul -n auto backends/test/suite/$SUITE/ -m flow_$FLOW --json-report --json-report-file="$REPORT_FILE" || EXIT_CODE=$? +# Generate markdown summary. +${CONDA_RUN_CMD} python -m executorch.backends.test.suite.generate_markdown_summary_json "$REPORT_FILE" > ${GITHUB_STEP_SUMMARY:-"step_summary.md"} --exit-code $EXIT_CODE diff --git a/.ci/scripts/test_backend_linux.sh b/.ci/scripts/test_backend_linux.sh deleted file mode 100755 index 243602fea21..00000000000 --- a/.ci/scripts/test_backend_linux.sh +++ /dev/null @@ -1,55 +0,0 @@ -#!/usr/bin/env bash -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -set -eux - -SUITE=$1 -FLOW=$2 -ARTIFACT_DIR=$3 - -REPORT_FILE="$ARTIFACT_DIR/test-report-$FLOW-$SUITE.csv" - -echo "Running backend test job for suite $SUITE, flow $FLOW." -echo "Saving job artifacts to $ARTIFACT_DIR." 
- -# The generic Linux job chooses to use base env, not the one setup by the image -eval "$(conda shell.bash hook)" -CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") -conda activate "${CONDA_ENV}" - -export PYTHON_EXECUTABLE=python - -# CMake options to use, in addition to the defaults. -EXTRA_BUILD_ARGS="" - -if [[ "$FLOW" == *qnn* ]]; then - # Setup QNN sdk and deps - note that this is a bit hacky due to the nature of the - # Qualcomm build. TODO (gjcomer) Clean this up once the QNN pybinding integration is - # cleaned up. - PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool cmake - PYTHON_EXECUTABLE=python source .ci/scripts/build-qnn-sdk.sh - QNN_X86_LIB_DIR=`realpath build-x86/lib/` - export LD_LIBRARY_PATH"=$QNN_X86_LIB_DIR:$QNN_SDK_ROOT/lib/x86_64-linux-clang/:${LD_LIBRARY_PATH:-}" - - # TODO Get SDK root from install scripts - EXTRA_BUILD_ARGS+=" -DEXECUTORCH_BUILD_QNN=ON -DQNN_SDK_ROOT=$QNN_SDK_ROOT" -fi - -if [[ "$FLOW" == *vulkan* ]]; then - # Setup swiftshader and Vulkan SDK which are required to build the Vulkan delegate - source .ci/scripts/setup-vulkan-linux-deps.sh - - EXTRA_BUILD_ARGS+=" -DEXECUTORCH_BUILD_VULKAN=ON" -fi - -# We need the runner to test the built library. -PYTHON_EXECUTABLE=python CMAKE_ARGS="$EXTRA_BUILD_ARGS" .ci/scripts/setup-linux.sh --build-tool cmake --build-mode Release --editable true - -EXIT_CODE=0 -python -m executorch.backends.test.suite.runner $SUITE --flow $FLOW --report "$REPORT_FILE" || EXIT_CODE=$? - -# Generate markdown summary. -python -m executorch.backends.test.suite.generate_markdown_summary "$REPORT_FILE" > ${GITHUB_STEP_SUMMARY:-"step_summary.md"} --exit-code $EXIT_CODE diff --git a/.ci/scripts/test_backend_macos.sh b/.ci/scripts/test_backend_macos.sh deleted file mode 100755 index c31fd504b03..00000000000 --- a/.ci/scripts/test_backend_macos.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env bash -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -set -eux - -SUITE=$1 -FLOW=$2 -ARTIFACT_DIR=$3 - -REPORT_FILE="$ARTIFACT_DIR/test-report-$FLOW-$SUITE.csv" - -echo "Running backend test job for suite $SUITE, flow $FLOW." -echo "Saving job artifacts to $ARTIFACT_DIR." - -${CONDA_RUN} --no-capture-output pip install awscli==1.37.21 - -bash .ci/scripts/setup-conda.sh -eval "$(conda shell.bash hook)" - -PYTHON_EXECUTABLE=python -${CONDA_RUN} --no-capture-output .ci/scripts/setup-macos.sh --build-tool cmake --build-mode Release - -EXIT_CODE=0 -${CONDA_RUN} --no-capture-output python -m executorch.backends.test.suite.runner $SUITE --flow $FLOW --report "$REPORT_FILE" || EXIT_CODE=$? - -# Generate markdown summary. 
-${CONDA_RUN} --no-capture-output python -m executorch.backends.test.suite.generate_markdown_summary "$REPORT_FILE" > ${GITHUB_STEP_SUMMARY:-"step_summary.md"} --exit-code $EXIT_CODE diff --git a/.ci/scripts/test_huggingface_optimum_model.py b/.ci/scripts/test_huggingface_optimum_model.py index 05b25299522..e5d815cfc00 100644 --- a/.ci/scripts/test_huggingface_optimum_model.py +++ b/.ci/scripts/test_huggingface_optimum_model.py @@ -43,7 +43,9 @@ def cli_export(command, model_dir): def check_causal_lm_output_quality( - model_id: str, generated_tokens: List[int], max_perplexity_threshold: float = 100.0 + model_id: str, + generated_tokens: List[int], + max_perplexity_threshold: float = 100.0, ): """ Evaluates the quality of text generated by a causal language model by calculating its perplexity. @@ -58,12 +60,24 @@ def check_causal_lm_output_quality( """ logging.info(f"Starting perplexity check with model '{model_id}' ...") # Load model - model = AutoModelForCausalLM.from_pretrained( - model_id, - low_cpu_mem_usage=True, - use_cache=False, - torch_dtype=torch.bfloat16, - ) + cls_name = AutoModelForCausalLM + if "llava" in model_id: + from transformers import LlavaForConditionalGeneration + + cls_name = LlavaForConditionalGeneration + try: + model = cls_name.from_pretrained( + model_id, + low_cpu_mem_usage=True, + use_cache=False, + torch_dtype=torch.bfloat16, + ) + except TypeError: + model = cls_name.from_pretrained( + model_id, + low_cpu_mem_usage=True, + torch_dtype=torch.bfloat16, + ) with torch.no_grad(): outputs = model(input_ids=generated_tokens, labels=generated_tokens) @@ -156,6 +170,86 @@ def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only assert check_causal_lm_output_quality(model_id, generated_tokens) is True +def test_llm_with_image_modality( + model_id, model_dir, recipe, *, quantize=True, run_only=False +): + command = [ + "optimum-cli", + "export", + "executorch", + "--model", + model_id, + "--task", + "multimodal-text-to-text", + "--recipe", + recipe, + "--output_dir", + model_dir, + "--use_custom_sdpa", + "--use_custom_kv_cache", + "--qlinear", + "8da4w", + "--qembedding", + "8w", + ] + if not run_only: + cli_export(command, model_dir) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.save_pretrained(model_dir) + + # input + processor = AutoProcessor.from_pretrained(model_id) + image_url = "https://llava-vl.github.io/static/images/view.jpg" + conversation = [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.", + } + ], + }, + { + "role": "user", + "content": [ + {"type": "image", "url": image_url}, + { + "type": "text", + "text": "What are the things I should be cautious about when I visit here?", + }, + ], + }, + ] + inputs = processor.apply_chat_template( + conversation, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + + from executorch.extension.llm.runner import GenerationConfig, MultimodalRunner + + runner = MultimodalRunner(f"{model_dir}/model.pte", f"{model_dir}/tokenizer.model") + generated_text = runner.generate_text_hf( + inputs, + GenerationConfig(max_new_tokens=128, temperature=0, echo=False), + processor.image_token_id, + ) + print(f"\nGenerated text:\n\t{generated_text}") + # Free memory before loading eager for quality check + del runner + gc.collect() + assert ( + check_causal_lm_output_quality( + model_id, tokenizer.encode(generated_text, return_tensors="pt") + ) + is True + ) + + def test_fill_mask(model_id, model_dir, recipe, *, quantize=True, run_only=False): command = [ "optimum-cli", @@ -353,6 +447,9 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False): required=False, help="When provided, write the pte file to this directory. Otherwise, a temporary directory is created for the test.", ) + parser.add_argument( + "--run_only", action="store_true", help="Skip export and only run the test" + ) args = parser.parse_args() _text_generation_mapping = { @@ -384,8 +481,16 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False): "vit": ("google/vit-base-patch16-224", test_vit), } + _multimodal_model_mapping = { + "gemma3-4b": ("google/gemma-3-4b-it", test_llm_with_image_modality), + "llava": ("llava-hf/llava-1.5-7b-hf", test_llm_with_image_modality), + } + model_to_model_id_and_test_function = ( - _text_generation_mapping | _mask_fill_mapping | _misc_model_mapping + _text_generation_mapping + | _mask_fill_mapping + | _misc_model_mapping + | _multimodal_model_mapping ) if args.model not in model_to_model_id_and_test_function: @@ -400,4 +505,5 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False): model_dir=tmp_dir if args.model_dir is None else args.model_dir, recipe=args.recipe, quantize=args.quantize, + run_only=args.run_only, ) diff --git a/.ci/scripts/test_ios_ci.sh b/.ci/scripts/test_ios_ci.sh index a89c2cc5809..46c3f71f021 100755 --- a/.ci/scripts/test_ios_ci.sh +++ b/.ci/scripts/test_ios_ci.sh @@ -36,6 +36,7 @@ say() { say "Cloning the Demo App" +git config --global http.postBuffer 524288000 git clone --depth 1 https://github.com/meta-pytorch/executorch-examples.git say "Installing CoreML Backend Requirements" diff --git a/.ci/scripts/test_llama.sh b/.ci/scripts/test_llama.sh index 84278e290f6..414ab85be58 100644 --- a/.ci/scripts/test_llama.sh +++ b/.ci/scripts/test_llama.sh @@ -130,7 +130,6 @@ if [[ "${MODE}" =~ .*qnn.* ]]; then cp schema/program.fbs exir/_serialize/program.fbs cp schema/scalar_type.fbs exir/_serialize/scalar_type.fbs cp -f build-x86/backends/qualcomm/PyQnnManagerAdaptor.cpython-310-x86_64-linux-gnu.so backends/qualcomm/python - cp -f build-x86/backends/qualcomm/PyQnnWrapperAdaptor.cpython-310-x86_64-linux-gnu.so backends/qualcomm/python else QNN=OFF @@ -159,6 +158,7 @@ cmake_install_executorch_libraries() { -DCMAKE_INSTALL_PREFIX=cmake-out \ -DCMAKE_BUILD_TYPE="$CMAKE_BUILD_TYPE" \ -DEXECUTORCH_BUILD_QNN="$QNN" \ + -DEXECUTORCH_ENABLE_LOGGING=ON \ 
-DQNN_SDK_ROOT="$QNN_SDK_ROOT" cmake --build cmake-out -j9 --target install --config "$CMAKE_BUILD_TYPE" } @@ -170,15 +170,14 @@ cmake_build_llama_runner() { git submodule update --init popd dir="examples/models/llama" - retry cmake \ - -DEXECUTORCH_BUILD_TESTS=ON \ - -DBUILD_TESTING=OFF \ - -DCMAKE_INSTALL_PREFIX=cmake-out \ - -DCMAKE_BUILD_TYPE="$CMAKE_BUILD_TYPE" \ - -Bcmake-out/${dir} \ - ${dir} - cmake --build cmake-out/${dir} -j9 --config "$CMAKE_BUILD_TYPE" - + if [[ "$CMAKE_BUILD_TYPE" == "Debug" ]]; then + PRESET="llama-debug" + else + PRESET="llama-release" + fi + pushd "${dir}" + cmake --workflow --preset "${PRESET}" + popd } cleanup_files() { @@ -236,7 +235,7 @@ if [[ "${CUSTOM}" == "ON" ]]; then EXPORT_ARGS="${EXPORT_ARGS} model.use_sdpa_with_kv_cache=true" fi if [[ "${QE}" == "ON" ]]; then - EXPORT_ARGS="${EXPORT_ARGS} quantization.embedding_quantize=\"8,1024\"" + EXPORT_ARGS="${EXPORT_ARGS} quantization.embedding_quantize=\"8,768\"" fi if [[ "${MPS}" == "ON" ]]; then EXPORT_ARGS="${EXPORT_ARGS} backend.mps.enabled=true model.enable_dynamic_shape=false debug.verbose=true" diff --git a/.ci/scripts/test_llama_lora.sh b/.ci/scripts/test_llama_lora.sh deleted file mode 100644 index 6337bbf76a2..00000000000 --- a/.ci/scripts/test_llama_lora.sh +++ /dev/null @@ -1,133 +0,0 @@ -#!/bin/bash -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -set -exu -# shellcheck source=/dev/null -source "$(dirname "${BASH_SOURCE[0]}")/utils.sh" - -cmake_install_executorch_libraries() { - echo "Installing libexecutorch.a, libextension_module.so, libportable_ops_lib.a" - rm -rf cmake-out - retry cmake --preset llm \ - -DCMAKE_INSTALL_PREFIX=cmake-out \ - -DCMAKE_BUILD_TYPE=Release - cmake --build cmake-out -j9 --target install --config Release -} - -cmake_build_llama_runner() { - echo "Building llama runner" - pushd extension/llm/tokenizers - echo "Updating tokenizers submodule" - git submodule update --init - popd - dir="examples/models/llama" - retry cmake \ - -DBUILD_TESTING=OFF \ - -DCMAKE_INSTALL_PREFIX=cmake-out \ - -DCMAKE_BUILD_TYPE=Release \ - -Bcmake-out/${dir} \ - ${dir} - cmake --build cmake-out/${dir} -j9 --config Release -} - -cleanup_files() { - echo "Deleting downloaded and generated files" - rm -rf "${DOWNLOADED_PATH}/" - rm result.txt -} - -# Download model artifacts from HF Hub. -# Hosting in personal repo for now. -HF_MODEL_REPO="lucylq/llama3_1B_lora" -DOWNLOADED_PATH=$( - bash "$(dirname "${BASH_SOURCE[0]}")/download_hf_hub.sh" \ - --model_id "${HF_MODEL_REPO}" \ - --files "adapter_config.json" "adapter_model.pt" "consolidated.00.pth" "params.json" "tokenizer.model" -) -# Build llama runner. -cmake_install_executorch_libraries -cmake_build_llama_runner - -# Constants. -RUNTIME_ARGS="--tokenizer_path=${DOWNLOADED_PATH}/tokenizer.model --temperature=0 --seq_len=20 --warmup=1" -PROMPT="What happens if you eat watermelon seeds?" -EXPECTED_PREFIX="What happens if you eat watermelon seeds? Watermelon seeds are a good source of vitamin C," - -# Export LoRA PTE file. 
-MODEL_NAME="llama_3_2_1B_lora" -$PYTHON_EXECUTABLE -m extension.llm.export.export_llm \ - base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \ - base.params="${DOWNLOADED_PATH}/params.json" \ - base.adapter_checkpoint="${DOWNLOADED_PATH}/adapter_model.pt" \ - base.adapter_config="${DOWNLOADED_PATH}/adapter_config.json" \ - base.tokenizer_path="${DOWNLOADED_PATH}/tokenizer.model" \ - model.use_kv_cache=true \ - model.use_sdpa_with_kv_cache=true \ - model.dtype_override="fp32" \ - backend.xnnpack.enabled=true \ - backend.xnnpack.extended_ops=true \ - export.output_name="${MODEL_NAME}.pte" - -# Run llama runner -NOW=$(date +"%H:%M:%S") -echo "Starting to run llama runner at ${NOW}" -# shellcheck source=/dev/null -cmake-out/examples/models/llama/llama_main --model_path=${MODEL_NAME}.pte --prompt="${PROMPT}" ${RUNTIME_ARGS} > result.txt -NOW=$(date +"%H:%M:%S") -echo "Finished at ${NOW}" - -RESULT=$(cat result.txt) -if [[ "${RESULT}" == "${EXPECTED_PREFIX}"* ]]; then - echo "Expected result prefix: ${EXPECTED_PREFIX}" - echo "Actual result: ${RESULT}" - # Do not clean up files if test passes, as they're re-used in the next test. - echo "Success" -else - echo "Expected result prefix: ${EXPECTED_PREFIX}" - echo "Actual result: ${RESULT}" - echo "Failure; results not the same" - cleanup_files - exit 1 -fi - -# Export LoRA PTE, PTD file. -MODEL_SEPARATE="${MODEL_NAME}_separate" -$PYTHON_EXECUTABLE -m extension.llm.export.export_llm \ - base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \ - base.params="${DOWNLOADED_PATH}/params.json" \ - base.adapter_checkpoint="${DOWNLOADED_PATH}/adapter_model.pt" \ - base.adapter_config="${DOWNLOADED_PATH}/adapter_config.json" \ - base.tokenizer_path="${DOWNLOADED_PATH}/tokenizer.model" \ - model.use_kv_cache=true \ - model.use_sdpa_with_kv_cache=true \ - model.dtype_override="fp32" \ - backend.xnnpack.enabled=true \ - backend.xnnpack.extended_ops=true \ - export.output_name="${MODEL_SEPARATE}.pte" \ - export.foundation_weights_file="${MODEL_SEPARATE}.ptd" - -# Run llama runner. 
-NOW=$(date +"%H:%M:%S") -echo "Starting to run llama runner at ${NOW}" -# shellcheck source=/dev/null -cmake-out/examples/models/llama/llama_main --model_path=${MODEL_SEPARATE}.pte --data_path=${MODEL_SEPARATE}.ptd --prompt="${PROMPT}" ${RUNTIME_ARGS} > result2.txt -NOW=$(date +"%H:%M:%S") -echo "Finished at ${NOW}" - -RESULT2=$(cat result2.txt) -if [[ "${RESULT2}" == "${EXPECTED_PREFIX}"* ]]; then - echo "Expected result prefix: ${EXPECTED_PREFIX}" - echo "Actual result: ${RESULT2}" - echo "Success" - cleanup_files -else - echo "Expected result prefix: ${EXPECTED_PREFIX}" - echo "Actual result: ${RESULT2}" - echo "Failure; results not the same" - cleanup_files - exit 1 -fi diff --git a/.ci/scripts/test_llama_torchao_lowbit.sh b/.ci/scripts/test_llama_torchao_lowbit.sh index 5f472fad63b..a7ded52ccc6 100644 --- a/.ci/scripts/test_llama_torchao_lowbit.sh +++ b/.ci/scripts/test_llama_torchao_lowbit.sh @@ -31,6 +31,7 @@ cmake -DPYTHON_EXECUTABLE=python \ -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DEXECUTORCH_BUILD_XNNPACK=OFF \ -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ diff --git a/.ci/scripts/test_llava.sh b/.ci/scripts/test_llava.sh index afed3c54123..d8cb9596ffc 100644 --- a/.ci/scripts/test_llava.sh +++ b/.ci/scripts/test_llava.sh @@ -38,6 +38,7 @@ EXECUTORCH_COMMON_CMAKE_ARGS=" \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ + -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ -DEXECUTORCH_BUILD_EXTENSION_LLM=ON \ -DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ @@ -107,7 +108,7 @@ cmake_build_llava_runner_for_android() { # only export the one without custom op for now since it's export_llava() { echo "Starting to export Llava. This will take about 6 mins" - $PYTHON_EXECUTABLE -m executorch.examples.models.llava.export_llava --pte-name llava.pte --with-artifacts + $PYTHON_EXECUTABLE -m executorch.examples.models.llava.export_llava --pte-name llava.pte --with-artifacts --max-context-len 768 } # Download a new image @@ -149,7 +150,7 @@ run_and_verify() { # verify result.txt RESULT=$(cat result.txt) - EXPECTED_PREFIX="ASSISTANT: image captures a basketball game in progress, with" + EXPECTED_PREFIX="ASSISTANT: The image captures a basketball game in progress, with" if [[ "${RESULT}" == *"${EXPECTED_PREFIX}"* ]]; then echo "Expected result prefix: ${EXPECTED_PREFIX}" diff --git a/.ci/scripts/test_lora.sh b/.ci/scripts/test_lora.sh new file mode 100644 index 00000000000..08210bf85cb --- /dev/null +++ b/.ci/scripts/test_lora.sh @@ -0,0 +1,224 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
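Both the llama runner build above and the new LoRA test below drive the build through CMake workflow presets instead of an explicit configure-then-build pair. A minimal sketch of the pattern, assuming the preset names referenced in this diff (llama-debug, llama-release, llm-release) are defined by the repository's CMakePresets.json:

```bash
# Sketch only: a workflow preset chains configure + build (+ install) into one invocation.

# Old style: explicit configure, then build.
cmake -DCMAKE_BUILD_TYPE=Release -Bcmake-out/examples/models/llama examples/models/llama
cmake --build cmake-out/examples/models/llama -j9 --config Release

# New style: one workflow preset, run from the directory that carries the presets.
cd examples/models/llama
cmake --workflow --preset llama-release
```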
+ +set -exu +# shellcheck source=/dev/null +source "$(dirname "${BASH_SOURCE[0]}")/utils.sh" + +cmake_install_executorch_libraries() { + echo "Installing libexecutorch.a, libextension_module.so, libportable_ops_lib.a" + rm -rf cmake-out + cmake --workflow llm-release +} + +cmake_build_llama_runner() { + echo "Building llama runner" + pushd extension/llm/tokenizers + echo "Updating tokenizers submodule" + git submodule update --init + popd + make llama-cpu +} + +cleanup_files() { + echo "Deleting downloaded and generated files" + rm -rf "${HF_QWEN_PATH}/" + rm -rf "${HF_ADAPTER_PATH}/" + rm -rf *.pte *.ptd + rm result*.txt +} + +# Hosting lora adapter in personal repo for now. +python -m pip install -q huggingface_hub +HF_ADAPTER_REPO="lucylq/qwen3_06B_lora_math" +HF_ADAPTER_PATH=$( + bash "$(dirname "${BASH_SOURCE[0]}")/download_hf_hub.sh" \ + --model_id "${HF_ADAPTER_REPO}" \ + --files "adapter_config.json" "adapter_model.safetensors" +) + +### SINGLE LORA PTE ### +# Export LoRA PTE file. +$PYTHON_EXECUTABLE -m extension.llm.export.export_llm \ + --config examples/models/qwen3/config/qwen3_xnnpack.yaml \ + +base.adapter_checkpoint="${HF_ADAPTER_PATH}/adapter_model.safetensors" \ + +base.adapter_config="${HF_ADAPTER_PATH}/adapter_config.json" \ + +export.output_name="qwen_lora_math_full.pte" + +# Capture the path of the downloaded qwen artifacts +HF_QWEN_PATH=$(python -c "from huggingface_hub import snapshot_download; print(snapshot_download('unsloth/Qwen3-0.6B'))") +echo "Model downloaded to: $HF_QWEN_PATH" + +### BUILD LLAMA RUNNER. +cmake_install_executorch_libraries +cmake_build_llama_runner + +# Runner constants. +RUNTIME_ARGS="--tokenizer_path=${HF_QWEN_PATH}/ --temperature=0 --seq_len=100 --warmup=1" +PROMPT="<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant" +EXPECTED_PREFIX=" +<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant +To calculate 15% of 80, we can multiply 80 by 0.15. +80 * 0.15 = 12 +So, 15% of 80 is 12. +#### 12 +The answer is: 12<|im_end|>" + +# Run llama runner on single lora PTE file. +NOW=$(date +"%H:%M:%S") +echo "Test 1: Single lora file. Starting to run llama runner at ${NOW}" +# shellcheck source=/dev/null +cmake-out/examples/models/llama/llama_main --model_path=qwen_lora_math_full.pte --prompt="${PROMPT}" ${RUNTIME_ARGS} > result.txt +NOW=$(date +"%H:%M:%S") +echo "Finished at ${NOW}" + +RESULT=$(cat result.txt) +if [[ "${RESULT}" == "${EXPECTED_PREFIX}"* ]]; then + echo "Expected result prefix: ${EXPECTED_PREFIX}" + echo "Actual result: ${RESULT}" + # Do not clean up files if test passes, as they're re-used in the next test. + echo "Test 1: Success" +else + echo "Expected result prefix: ${EXPECTED_PREFIX}" + echo "Actual result: ${RESULT}" + echo "Test 1: Failure; results not the same" + cleanup_files + exit 1 +fi + +### PROGRAM DATA SEPARATION ### +# Export LoRA PTE, LoRA PTD, foundation PTD file. +$PYTHON_EXECUTABLE -m extension.llm.export.export_llm \ + --config examples/models/qwen3/config/qwen3_xnnpack.yaml \ + +base.adapter_checkpoint="${HF_ADAPTER_PATH}/adapter_model.safetensors" \ + +base.adapter_config="${HF_ADAPTER_PATH}/adapter_config.json" \ + +export.output_name="qwen_lora_math.pte" \ + +export.foundation_weights_file="qwen_foundation.ptd" \ + +export.lora_weights_file="qwen_lora_math.ptd" + +# Run llama runner on PTE, PTD files. +NOW=$(date +"%H:%M:%S") +echo "Test 2: Program data separation lora. 
Starting to run llama runner at ${NOW}" +# shellcheck source=/dev/null +cmake-out/examples/models/llama/llama_main --model_path=qwen_lora_math.pte --data_paths="qwen_foundation.ptd,qwen_lora_math.ptd" --prompt="${PROMPT}" ${RUNTIME_ARGS} > result2.txt +NOW=$(date +"%H:%M:%S") +echo "Finished at ${NOW}" + +RESULT=$(cat result.txt) +if [[ "${RESULT}" == "${EXPECTED_PREFIX}"* ]]; then + echo "Expected result prefix: ${EXPECTED_PREFIX}" + echo "Actual result: ${RESULT}" + echo "Test 2: Success" +else + echo "Expected result prefix: ${EXPECTED_PREFIX}" + echo "Actual result: ${RESULT}" + echo "Test 2: Failure; results not the same" +# cleanup_files + exit 1 +fi + +# Confirm file sizes. +FOUNDATION_SIZE=$(stat -c%s qwen_foundation.ptd) +if [[ $FOUNDATION_SIZE -le "2400000000" ]]; then + echo "qwen_foundation_q.ptd size is: $FOUNDATION_SIZE" +else + echo "qwen_foundation_q.ptd size: $FOUNDATION_SIZE is greater than threshold 2.4GB" + cleanup_files + exit 1 +fi + +### QUANTIZATION & PROGRAM DATA SEPARATION ### +EXPECTED_QUANT_PREFIX="<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant: + +Okay, so I need to calculate 15% of 80." +EXPECTED_QUANT_LORA_PREFIX=" +<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant +To calculate 15% of 80, we can multiply 80 by 15/100. +So, 15% of 80 is equal to (80 * 15) / 100 = 1200 / 100 = 12. +#### 12 +The answer is: 12<|im_end|>" + +# Export Quantized PTE, PTD file, no LoRA. +$PYTHON_EXECUTABLE -m extension.llm.export.export_llm \ + --config examples/models/qwen3/config/qwen3_xnnpack.yaml \ + +export.output_name="qwen_q.pte" \ + +export.foundation_weights_file="qwen_foundation_q.ptd" \ + +quantization.qmode="8da4w" \ + +quantization.group_size=32 + +# Export Quantized LoRA PTE, LoRA PTD, foundation PTD file. +$PYTHON_EXECUTABLE -m extension.llm.export.export_llm \ + --config examples/models/qwen3/config/qwen3_xnnpack.yaml \ + +base.adapter_checkpoint="${HF_ADAPTER_PATH}/adapter_model.safetensors" \ + +base.adapter_config="${HF_ADAPTER_PATH}/adapter_config.json" \ + +export.output_name="qwen_lora_math_q.pte" \ + +export.foundation_weights_file="qwen_foundation_lora_q.ptd" \ + +export.lora_weights_file="qwen_lora_math_q.ptd" \ + +quantization.qmode="8da4w" \ + +quantization.group_size=32 + +# Confirm that qwen_foundation_lora_q.ptd and qwen_foundation_q.ptd are the same. +if diff -q qwen_foundation_lora_q.ptd qwen_foundation_q.ptd > /dev/null; then + echo "qwen_foundation_lora_q.ptd and qwen_foundation_q.ptd are identical." +else + echo "qwen_foundation_lora_q.ptd and qwen_foundation_q.ptd are not identical." + cleanup_files + exit 1 +fi + +# Run quantized qwen model (no adapter). +NOW=$(date +"%H:%M:%S") +echo "Test 3: Quantized qwen model (no lora). Starting to run llama runner at ${NOW}" +# shellcheck source=/dev/null +cmake-out/examples/models/llama/llama_main --model_path=qwen_q.pte --data_paths="qwen_foundation_q.ptd" --prompt="${PROMPT}" ${RUNTIME_ARGS} > result.txt +NOW=$(date +"%H:%M:%S") +echo "Finished at ${NOW}" +RESULT=$(cat result.txt) +if [[ "${RESULT}" == "${EXPECTED_QUANT_PREFIX}"* ]]; then + echo "Expected result prefix: ${EXPECTED_QUANT_PREFIX}" + echo "Actual result: ${RESULT}" + echo "Test 3: Success" +else + echo "Expected result prefix: ${EXPECTED_QUANT_PREFIX}" + echo "Actual result: ${RESULT}" + echo "Test 3: Failure; results not the same" + cleanup_files + exit 1 +fi + +# Run quantized lora adapter. +NOW=$(date +"%H:%M:%S") +echo "Test 4: Quantized, program-data separation lora. 
Starting to run llama runner at ${NOW}" +# shellcheck source=/dev/null +cmake-out/examples/models/llama/llama_main --model_path=qwen_lora_math_q.pte --data_paths="qwen_foundation_q.ptd,qwen_lora_math_q.ptd" --prompt="${PROMPT}" ${RUNTIME_ARGS} > result.txt +NOW=$(date +"%H:%M:%S") +echo "Finished at ${NOW}" + +RESULT=$(cat result.txt) +if [[ "${RESULT}" == "${EXPECTED_QUANT_LORA_PREFIX}"* ]]; then + echo "Expected result prefix: ${EXPECTED_QUANT_LORA_PREFIX}" + echo "Actual result: ${RESULT}" + echo "Test 4: Success" +else + echo "Expected result prefix: ${EXPECTED_QUANT_LORA_PREFIX}" + echo "Actual result: ${RESULT}" + echo "Test 4: Failure; results not the same" + cleanup_files + exit 1 +fi + +# Confirm qwen_foundation_q.ptd file size. +FOUNDATION_Q_SIZE=$(stat -c%s qwen_foundation_q.ptd) +if [[ $FOUNDATION_Q_SIZE -le "1000000000" ]]; then + echo "qwen_foundation_q.ptd size is: $FOUNDATION_Q_SIZE" +else + echo "qwen_foundation_q.ptd size: $FOUNDATION_Q_SIZE is greater than threshold 1GB" + cleanup_files + exit 1 +fi + +cleanup_files diff --git a/.ci/scripts/test_model.sh b/.ci/scripts/test_model.sh index 74eb75c6ddd..34063a23374 100755 --- a/.ci/scripts/test_model.sh +++ b/.ci/scripts/test_model.sh @@ -48,22 +48,33 @@ prepare_artifacts_upload() { fi } + build_cmake_executor_runner() { local backend_string_select="${1:-}" echo "Building executor_runner" rm -rf ${CMAKE_OUTPUT_DIR} mkdir ${CMAKE_OUTPUT_DIR} + # Common options: + COMMON="-DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE" if [[ "$backend_string_select" == "XNNPACK" ]]; then echo "Backend $backend_string_select selected" - (cd ${CMAKE_OUTPUT_DIR} \ - && cmake -DCMAKE_BUILD_TYPE=Release \ + cmake -DCMAKE_BUILD_TYPE=Release \ -DEXECUTORCH_BUILD_XNNPACK=ON \ - -DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" ..) + ${COMMON} \ + -B${CMAKE_OUTPUT_DIR} . + cmake --build ${CMAKE_OUTPUT_DIR} -j4 + elif [[ "$backend_string_select" == "CUDA" ]]; then + echo "Backend $backend_string_select selected" + cmake -DCMAKE_BUILD_TYPE=Release \ + -DEXECUTORCH_BUILD_CUDA=ON \ + -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ + ${COMMON} \ + -B${CMAKE_OUTPUT_DIR} . cmake --build ${CMAKE_OUTPUT_DIR} -j4 else cmake -DCMAKE_BUILD_TYPE=Debug \ -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \ - -DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \ + ${COMMON} \ -B${CMAKE_OUTPUT_DIR} . cmake --build ${CMAKE_OUTPUT_DIR} -j4 --config Debug fi @@ -131,13 +142,13 @@ test_model_with_xnnpack() { return 0 fi - # Delegation + # Delegation and test with pybindings if [[ ${WITH_QUANTIZATION} == true ]]; then SUFFIX="q8" - "${PYTHON_EXECUTABLE}" -m examples.xnnpack.aot_compiler --model_name="${MODEL_NAME}" --delegate --quantize + "${PYTHON_EXECUTABLE}" -m examples.xnnpack.aot_compiler --model_name="${MODEL_NAME}" --delegate --quantize --test_after_export else SUFFIX="fp32" - "${PYTHON_EXECUTABLE}" -m examples.xnnpack.aot_compiler --model_name="${MODEL_NAME}" --delegate + "${PYTHON_EXECUTABLE}" -m examples.xnnpack.aot_compiler --model_name="${MODEL_NAME}" --delegate --test_after_export fi OUTPUT_MODEL_PATH="${MODEL_NAME}_xnnpack_${SUFFIX}.pte" @@ -320,6 +331,13 @@ test_model_with_mediatek() { EXPORTED_MODEL=$(find "./${EXPORT_SCRIPT}" -type f -name "*.pte" -print -quit) } +test_model_with_cuda() { + # Export a basic .pte and .ptd, then run the model. 
+ "${PYTHON_EXECUTABLE}" -m examples.cuda.scripts.export --model_name="${MODEL_NAME}" --output_dir "./" + build_cmake_executor_runner "CUDA" + ./${CMAKE_OUTPUT_DIR}/executor_runner --model_path "./${MODEL_NAME}.pte" --data_path "./aoti_cuda_blob.ptd" +} + if [[ "${BACKEND}" == "portable" ]]; then echo "Testing ${MODEL_NAME} with portable kernels..." @@ -372,6 +390,12 @@ elif [[ "${BACKEND}" == "mediatek" ]]; then if [[ $? -eq 0 ]]; then prepare_artifacts_upload fi +elif [[ "${BACKEND}" == "cuda" ]]; then + echo "Testing ${MODEL_NAME} with cuda..." + test_model_with_cuda + if [[ $? -eq 0 ]]; then + prepare_artifacts_upload + fi else set +e if [[ "${BACKEND}" == *"quantization"* ]]; then diff --git a/.ci/scripts/test_model_e2e.sh b/.ci/scripts/test_model_e2e.sh new file mode 100755 index 00000000000..715c8b497cd --- /dev/null +++ b/.ci/scripts/test_model_e2e.sh @@ -0,0 +1,209 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Test CUDA/Metal model end-to-end, need to run .ci/scripts/export_model_artifact.sh first + +show_help() { + cat << EOF +Usage: test_model_e2e.sh [model_dir] + +Build and run end-to-end tests for CUDA/Metal models. + +Arguments: + device cuda or metal (required) + + hf_model HuggingFace model ID (required) + Supported models: + - mistralai/Voxtral-Mini-3B-2507 + - openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}) + - google/gemma-3-4b-it + + quant_name Quantization type (required) + Options: + - non-quantized + - quantized-int4-tile-packed + - quantized-int4-weight-only + + model_dir Directory containing model artifacts (optional, default: current directory) + Expected files: model.pte, aoti_cuda_blob.ptd/aoti_metal_blob.ptd + Tokenizers and test files will be downloaded to this directory + +Examples: + test_model_e2e.sh metal "openai/whisper-small" "non-quantized" + test_model_e2e.sh cuda "mistralai/Voxtral-Mini-3B-2507" "quantized-int4-tile-packed" "./model_output" +EOF +} + +if [ "${1:-}" = "-h" ] || [ "${1:-}" = "--help" ]; then + show_help + exit 0 +fi + +if [ -z "${1:-}" ]; then + echo "Error: hf_model argument is required" + echo "Run with -h or --help for usage information" + exit 1 +fi + +if [ -z "${2:-}" ]; then + echo "Error: quant_name argument is required" + echo "Run with -h or --help for usage information" + exit 1 +fi + +set -eux + +DEVICE="$1" +HF_MODEL="$2" +QUANT_NAME="$3" +# Download tokenizers, audio, and image files to this directory +MODEL_DIR="${4:-.}" + +echo "Testing model: $HF_MODEL (quantization: $QUANT_NAME)" + +# Make sure model.pte and aoti_${DEVICE}_blob.ptd exist +if [ ! -f "$MODEL_DIR/model.pte" ]; then + echo "Error: model.pte not found in $MODEL_DIR" + exit 1 +fi +if [ ! -f "$MODEL_DIR/aoti_${DEVICE}_blob.ptd" ]; then + echo "Error: aoti_${DEVICE}_blob.ptd not found in $MODEL_DIR" + exit 1 +fi +# Locate EXECUTORCH_ROOT from the directory of this script +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +EXECUTORCH_ROOT="$(cd "$SCRIPT_DIR/../.." 
&& pwd)" + +pushd "$EXECUTORCH_ROOT" + +# Determine model configuration based on HF model ID +case "$HF_MODEL" in + mistralai/Voxtral-Mini-3B-2507) + MODEL_NAME="voxtral" + RUNNER_TARGET="voxtral_runner" + RUNNER_PATH="voxtral" + EXPECTED_OUTPUT="poem" + PREPROCESSOR="voxtral_preprocessor.pte" + TOKENIZER_URL="https://huggingface.co/mistralai/Voxtral-Mini-3B-2507/resolve/main" # @lint-ignore + TOKENIZER_FILE="tekken.json" + AUDIO_URL="https://github.com/voxserv/audio_quality_testing_samples/raw/refs/heads/master/testaudio/16000/test01_20s.wav" + AUDIO_FILE="poem.wav" + IMAGE_PATH="" + ;; + openai/whisper-*) + MODEL_NAME="${HF_MODEL#openai/}" + RUNNER_TARGET="whisper_runner" + RUNNER_PATH="whisper" + EXPECTED_OUTPUT="Mr. Quilter is the apostle of the middle classes" + PREPROCESSOR="whisper_preprocessor.pte" + TOKENIZER_URL="https://huggingface.co/${HF_MODEL}/resolve/main" # @lint-ignore + TOKENIZER_FILE="" + AUDIO_URL="" + AUDIO_FILE="output.wav" + IMAGE_PATH="" + ;; + google/gemma-3-4b-it) + MODEL_NAME="gemma3" + RUNNER_TARGET="gemma3_e2e_runner" + RUNNER_PATH="gemma3" + EXPECTED_OUTPUT="chip" + PREPROCESSOR="" + TOKENIZER_URL="https://huggingface.co/unsloth/gemma-3-4b-it/resolve/main" # @lint-ignore + TOKENIZER_FILE="" + AUDIO_URL="" + AUDIO_FILE="" + IMAGE_PATH="docs/source/_static/img/et-logo.png" + ;; + *) + echo "Error: Unsupported model '$HF_MODEL'" + echo "Supported models: mistralai/Voxtral-Mini-3B-2507, openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}), google/gemma-3-4b-it" + exit 1 + ;; +esac + +echo "::group::Setup ExecuTorch Requirements" +./install_requirements.sh +pip list +echo "::endgroup::" + +echo "::group::Prepare $MODEL_NAME Artifacts" + + +# Download tokenizer files +if [ "$TOKENIZER_FILE" != "" ]; then + curl -L $TOKENIZER_URL/$TOKENIZER_FILE -o $MODEL_DIR/$TOKENIZER_FILE +else + curl -L $TOKENIZER_URL/tokenizer.json -o $MODEL_DIR/tokenizer.json + curl -L $TOKENIZER_URL/tokenizer_config.json -o $MODEL_DIR/tokenizer_config.json + curl -L $TOKENIZER_URL/special_tokens_map.json -o $MODEL_DIR/special_tokens_map.json +fi + +# Download test files +if [ "$AUDIO_URL" != "" ]; then + curl -L $AUDIO_URL -o ${MODEL_DIR}/$AUDIO_FILE +elif [[ "$MODEL_NAME" == *whisper* ]]; then + conda install -y -c conda-forge "ffmpeg<8" + pip install datasets soundfile + pip install torchcodec==0.10.0.dev20251211 --extra-index-url https://download.pytorch.org/whl/nightly/cpu + python -c "from datasets import load_dataset;import soundfile as sf;sample = load_dataset('distil-whisper/librispeech_long', 'clean', split='validation')[0]['audio'];sf.write('${MODEL_DIR}/$AUDIO_FILE', sample['array'][:sample['sampling_rate']*30], sample['sampling_rate'])" +fi + +ls -al +echo "::endgroup::" + +echo "::group::Build $MODEL_NAME Runner" + +if [ "$DEVICE" != "cuda" ] && [ "$DEVICE" != "metal" ]; then + echo "Error: Unsupported device '$DEVICE'. Must be 'cuda' or 'metal'." 
+ exit 1 +fi + +MAKE_TARGET="${RUNNER_PATH}-${DEVICE}" +make "${MAKE_TARGET}" +echo "::endgroup::" + +echo "::group::Run $MODEL_NAME Runner" +set +e +if [ "$DEVICE" = "cuda" ]; then + export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH +fi + +# Build runner command with common arguments +RUNNER_BIN="cmake-out/examples/models/$RUNNER_PATH/$RUNNER_TARGET" +RUNNER_ARGS="--model_path ${MODEL_DIR}/model.pte --data_path ${MODEL_DIR}/aoti_${DEVICE}_blob.ptd --temperature 0" + +# Add model-specific arguments +case "$MODEL_NAME" in + voxtral) + RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --audio_path ${MODEL_DIR}/$AUDIO_FILE --processor_path ${MODEL_DIR}/$PREPROCESSOR" + ;; + whisper-*) + RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/ --audio_path ${MODEL_DIR}/$AUDIO_FILE --processor_path ${MODEL_DIR}/$PREPROCESSOR" + ;; + gemma3) + RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/ --image_path $IMAGE_PATH" + ;; +esac + +OUTPUT=$($RUNNER_BIN $RUNNER_ARGS 2>&1) +EXIT_CODE=$? +set -e + +if ! echo "$OUTPUT" | grep -iq "$EXPECTED_OUTPUT"; then + echo "Expected output '$EXPECTED_OUTPUT' not found in output" + exit 1 +else + echo "Success: '$EXPECTED_OUTPUT' found in output" +fi + +if [ $EXIT_CODE -ne 0 ]; then + echo "Unexpected exit code: $EXIT_CODE" + exit $EXIT_CODE +fi +echo "::endgroup::" + +popd diff --git a/.ci/scripts/test_openvino.sh b/.ci/scripts/test_openvino.sh index 85884a6475b..2bb2115b1ec 100755 --- a/.ci/scripts/test_openvino.sh +++ b/.ci/scripts/test_openvino.sh @@ -10,7 +10,7 @@ set -ex # shellcheck source=/dev/null source "$(dirname "${BASH_SOURCE[0]}")/utils.sh" -source openvino/dist/setupvars.sh +source openvino/setupvars.sh cd backends/openvino/tests python test_runner.py --test_type ops python test_runner.py --test_type models diff --git a/.ci/scripts/test_phi_3_mini.sh b/.ci/scripts/test_phi_3_mini.sh index 289263ace37..086822bbad4 100644 --- a/.ci/scripts/test_phi_3_mini.sh +++ b/.ci/scripts/test_phi_3_mini.sh @@ -23,8 +23,16 @@ if hash nproc &> /dev/null; then NPROC=$(nproc); fi cmake_install_executorch_libraries() { rm -rf cmake-out - cmake --preset llm -DCMAKE_INSTALL_PREFIX=cmake-out -DCMAKE_BUILD_TYPE=${BUILD_TYPE} - cmake --build cmake-out -j16 --target install --config ${BUILD_TYPE} + + # Select workflow preset based on BUILD_TYPE + if [[ "${BUILD_TYPE}" == "Debug" ]]; then + WORKFLOW_PRESET="llm-debug" + else + WORKFLOW_PRESET="llm-release" + fi + + echo "Using workflow preset: ${WORKFLOW_PRESET}" + cmake --workflow --preset ${WORKFLOW_PRESET} } cmake_build_phi_3_mini() { @@ -36,34 +44,33 @@ cmake_build_phi_3_mini() { cmake --build ${BUILD_DIR}/${MODEL_DIR} -j${NPROC} --config ${BUILD_TYPE} } -# Download and convert tokenizer.model +# Download tokenizer.model prepare_tokenizer() { - echo "Downloading and converting tokenizer.model" - wget -O tokenizer.model "https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/resolve/main/tokenizer.model?download=true" - $PYTHON_EXECUTABLE -m pytorch_tokenizers.tools.llama2c.convert -t tokenizer.model -o tokenizer.bin + echo "Downloading tokenizer.model" + wget -O tokenizer.model "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/tokenizer.model?download=true" } # Export phi-3-mini model to pte export_phi_3_mini () { echo "Exporting phi-3-mini. 
This will take a few minutes" - $PYTHON_EXECUTABLE -m executorch.examples.models.phi-3-mini.export_phi-3-mini -c "4k" -s 128 -o phi-3-mini.pte + optimum-cli export executorch --model microsoft/Phi-3-mini-4k-instruct --task text-generation --recipe xnnpack --output_dir ./ } run_and_verify() { NOW=$(date +"%H:%M:%S") echo "Starting to run phi-3-mini runner at ${NOW}" - if [[ ! -f "phi-3-mini.pte" ]]; then - echo "Export failed. Abort" + if [[ ! -f "model.pte" ]]; then + echo "Missing model artifact. Abort" exit 1 fi - if [[ ! -f "tokenizer.bin" ]]; then - echo "tokenizer.bin is missing." + if [[ ! -f "tokenizer.model" ]]; then + echo "tokenizer.model is missing." exit 1 fi ${BUILD_DIR}/${MODEL_DIR}/phi_3_mini_runner \ - --model_path=phi-3-mini.pte \ - --tokenizer_path=tokenizer.bin \ + --model_path=model.pte \ + --tokenizer_path=tokenizer.model \ --seq_len=60 \ --temperature=0 \ --prompt="<|system|> @@ -92,7 +99,7 @@ What is the capital of France?<|end|> cmake_install_executorch_libraries cmake_build_phi_3_mini -# Step 2. Export the tokenizer and model +# Step 2. Export the model prepare_tokenizer export_phi_3_mini diff --git a/.ci/scripts/test_qnn_static_llama.sh b/.ci/scripts/test_qnn_static_llama.sh deleted file mode 100644 index 7898d03b3b9..00000000000 --- a/.ci/scripts/test_qnn_static_llama.sh +++ /dev/null @@ -1,69 +0,0 @@ -#!/bin/bash -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -set -euxo pipefail - -source "$(dirname "${BASH_SOURCE[0]}")/utils.sh" - -# Download QNN_SDK. If already downloaded, export environment path -source "$(dirname "${BASH_SOURCE[0]}")/../../backends/qualcomm/scripts/install_qnn_sdk.sh" -install_qnn - -export EXECUTORCH_ROOT="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/.." && pwd)" -export LD_LIBRARY_PATH="${QNN_SDK_ROOT}/lib/x86_64-linux-clang" -export PYTHONPATH=".." -cp schema/program.fbs exir/_serialize/program.fbs -cp schema/scalar_type.fbs exir/_serialize/scalar_type.fbs -cp -f build-x86/backends/qualcomm/PyQnnManagerAdaptor.cpython-310-x86_64-linux-gnu.so backends/qualcomm/python -cp -f build-x86/backends/qualcomm/PyQnnWrapperAdaptor.cpython-310-x86_64-linux-gnu.so backends/qualcomm/python - -if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then - PYTHON_EXECUTABLE=python3 -fi - -which "${PYTHON_EXECUTABLE}" - -# Although static llama CI does not require graphviz, it is required by test_qnn_delegate.py -pip install graphviz - -# Download stories llama110m artifacts -download_stories_model_artifacts -echo "Creating tokenizer.bin" -$PYTHON_EXECUTABLE -m pytorch_tokenizers.tools.llama2c.convert -t tokenizer.model -o tokenizer.bin - -set +e -# Compile only as weight sharing is not applicable on x86. -$PYTHON_EXECUTABLE backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_llama_stories_110m --model SM8650 --build_folder build-android/ --executorch_root . --artifact_dir ./stories_110m_pte_size --llama_artifacts . --compile_only -exit_code1=$? - -# Checks accuracy with weight sharing disabled since x86 does not support weight sharing. -$PYTHON_EXECUTABLE backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_llama_stories_110m --model SM8650 --build_folder build-x86/ --executorch_root . --artifact_dir ./stories_110m_accuracy --llama_artifacts . --enable_x86_64 -exit_code2=$? - -# Check BC -bash backends/qualcomm/bc/test_qnn_static_llama_bc.sh -exit_code3=$? 
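The QNN static-llama scripts (both the monolithic one being deleted here and the task-based replacement added below) rely on the same shell pattern: disable errexit with `set +e`, capture each sub-test's status in its own variable, then report and fail once at the end so every sub-test gets a chance to run. A minimal sketch of that pattern, with hypothetical sub-test names:

```bash
set +e   # keep going even if an individual sub-test fails
./compile_only_test.sh; exit_code1=$?   # hypothetical sub-tests
./accuracy_test.sh;     exit_code2=$?

if [ $exit_code1 -ne 0 ]; then echo "compile-only test failed: $exit_code1"; fi
if [ $exit_code2 -ne 0 ]; then echo "accuracy test failed: $exit_code2"; fi

# Fail the job only after all sub-tests have reported.
if [ $exit_code1 -ne 0 ] || [ $exit_code2 -ne 0 ]; then
  exit 1
fi
```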
- -# Check the exit codes and print messages -if [ $exit_code1 -ne 0 ]; then - echo "Static Llama compile only with weight sharing test failed. $exit_code1." -fi - -if [ $exit_code2 -ne 0 ]; then - echo "Static Llama accuracy test failed. $exit_code2." -fi - -if [ $exit_code3 -ne 0 ]; then - echo "Static Llama BACKWARD COMPATIBILITY test failed. $exit_code3." -fi - -# Return failure if either program failed -if [ $exit_code1 -ne 0 ] || [ $exit_code2 -ne 0 ] || [ $exit_code3 -ne 0 ]; then - exit 1 -else - exit 0 -fi diff --git a/.ci/scripts/test_qnn_static_llama_eval.sh b/.ci/scripts/test_qnn_static_llama_eval.sh new file mode 100644 index 00000000000..5faa0b854e8 --- /dev/null +++ b/.ci/scripts/test_qnn_static_llama_eval.sh @@ -0,0 +1,90 @@ +#!/bin/bash +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -euo pipefail + +echo ">>> Script invoked with arguments: $@" + +source "$(dirname "${BASH_SOURCE[0]}")/utils.sh" + +# Download QNN_SDK. If already downloaded, export environment path +source "$(dirname "${BASH_SOURCE[0]}")/../../backends/qualcomm/scripts/install_qnn_sdk.sh" +install_qnn + +export EXECUTORCH_ROOT="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/.." && pwd)" +export LD_LIBRARY_PATH="${QNN_SDK_ROOT}/lib/x86_64-linux-clang" +export PYTHONPATH=".." +cp schema/program.fbs exir/_serialize/program.fbs +cp schema/scalar_type.fbs exir/_serialize/scalar_type.fbs +cp -f build-x86/backends/qualcomm/PyQnnManagerAdaptor.cpython-310-x86_64-linux-gnu.so backends/qualcomm/python + +if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then + PYTHON_EXECUTABLE=python3 +fi + +which "${PYTHON_EXECUTABLE}" + +# ------------------------------- +# Parse args +# ------------------------------- +EXTRA_FLAGS="" +THRESHOLD=62.0 # default fallback + +while [[ $# -gt 0 ]]; do + case "$1" in + --flags) + EXTRA_FLAGS="$2" + shift 2 + ;; + --threshold) + THRESHOLD="$2" + shift 2 + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done + +# Config +PYTHON_EXECUTABLE="${PYTHON_EXECUTABLE:-python3}" +MODEL="qwen2_5-0_5b" +MAX_SEQ=1024 +PTQ="16a4w" + +EXTRA_FLAGS="$@" + +# Run command and capture *both stdout and stderr* +LOG_FILE="eval_${MODEL}_$(date +%Y%m%d_%H%M%S).log" + +echo ">>> Running evaluation with flags: $EXTRA_FLAGS | threshold: $THRESHOLD" +$PYTHON_EXECUTABLE -m executorch.examples.qualcomm.oss_scripts.llama.eval_llama_qnn \ + --decoder_model "$MODEL" \ + --quant_linear_only \ + --max_seq_length "$MAX_SEQ" \ + --ptq "$PTQ" \ + $EXTRA_FLAGS 2>&1 | tee "$LOG_FILE" + +# Extract last word_perplexity +LAST_PERP=$(grep "INFO:root:wikitext:" "$LOG_FILE" | tail -n 1 | sed -E "s/.*'word_perplexity,none': ([0-9.]+).*/\1/") + +if [[ -z "$LAST_PERP" ]]; then + echo "❌ Could not find word_perplexity in logs!" + exit 1 +fi + +echo ">>> Last word_perplexity = $LAST_PERP" + +# Compare against threshold +awk -v val="$LAST_PERP" -v thr="$THRESHOLD" 'BEGIN {exit (val > thr)}' +if [[ $? -ne 0 ]]; then + echo "❌ Regression detected: word_perplexity ($LAST_PERP) > threshold ($THRESHOLD)" + exit 1 +fi + +echo "✅ Check passed: word_perplexity ($LAST_PERP) <= $THRESHOLD" diff --git a/.ci/scripts/test_qnn_static_llm.sh b/.ci/scripts/test_qnn_static_llm.sh new file mode 100644 index 00000000000..46923f52127 --- /dev/null +++ b/.ci/scripts/test_qnn_static_llm.sh @@ -0,0 +1,93 @@ +#!/bin/bash +# Copyright (c) Qualcomm Innovation Center, Inc. 
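The eval wrapper above pulls the last reported word_perplexity out of the log with grep/sed and compares it to a float threshold; since bash arithmetic is integer-only, the comparison is delegated to awk and its exit status drives the pass/fail branch. A compact sketch of the same mechanism, using a hypothetical log line:

```bash
LOG_LINE="INFO:root:wikitext: {'word_perplexity,none': 57.31}"   # hypothetical log content
THRESHOLD=62.0

# Pull the numeric value out of the matching line.
PERP=$(echo "$LOG_LINE" | sed -E "s/.*'word_perplexity,none': ([0-9.]+).*/\1/")

# awk exits with status 1 when val > thr (the expression is truthy) and 0 otherwise,
# so the shell can branch on a floating-point comparison.
if awk -v val="$PERP" -v thr="$THRESHOLD" 'BEGIN {exit (val > thr)}'; then
  echo "PASS: word_perplexity ($PERP) <= $THRESHOLD"
else
  echo "FAIL: word_perplexity ($PERP) > $THRESHOLD"
fi
```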
+# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -euxo pipefail + +source "$(dirname "${BASH_SOURCE[0]}")/utils.sh" + +TASK_NAME=$1 +if [[ -z "${TASK_NAME:-}" ]]; then + echo "Missing task name, exiting..." + exit 1 +fi + + +# Download QNN_SDK. If already downloaded, export environment path +source "$(dirname "${BASH_SOURCE[0]}")/../../backends/qualcomm/scripts/install_qnn_sdk.sh" +install_qnn + +export EXECUTORCH_ROOT="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/.." && pwd)" +export LD_LIBRARY_PATH="${QNN_SDK_ROOT}/lib/x86_64-linux-clang" +export PYTHONPATH=".." +cp schema/program.fbs exir/_serialize/program.fbs +cp schema/scalar_type.fbs exir/_serialize/scalar_type.fbs +cp -f build-x86/backends/qualcomm/PyQnnManagerAdaptor.cpython-310-x86_64-linux-gnu.so backends/qualcomm/python + +if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then + PYTHON_EXECUTABLE=python3 +fi + +which "${PYTHON_EXECUTABLE}" + +# Although static llama CI does not require graphviz, it is required by test_qnn_delegate.py +pip install graphviz + +set +e + +echo "Executing task: $TASK_NAME" +if [[ "${TASK_NAME}" == "stories_110m" ]]; then + # Download stories llama110m artifacts + download_stories_model_artifacts + echo "Creating tokenizer.bin" + $PYTHON_EXECUTABLE -m pytorch_tokenizers.tools.llama2c.convert -t tokenizer.model -o tokenizer.bin + + # Compile only as weight sharing is not applicable on x86. + $PYTHON_EXECUTABLE backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_llama_stories_110m --model SM8650 --build_folder build-android/ --executorch_root . --artifact_dir ./stories_110m_pte_size --llama_artifacts . --compile_only + exit_code1=$? + + # Checks accuracy with weight sharing disabled since x86 does not support weight sharing. + $PYTHON_EXECUTABLE backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_llama_stories_110m --model SM8650 --build_folder build-x86/ --executorch_root . --artifact_dir ./stories_110m_accuracy --llama_artifacts . --enable_x86_64 + exit_code2=$? + + # Check the exit codes and print messages + if [ $exit_code1 -ne 0 ]; then + echo "Static Llama compile only with weight sharing test failed. $exit_code1." + fi + + if [ $exit_code2 -ne 0 ]; then + echo "Static Llama accuracy test failed. $exit_code2." + fi + + if [ $exit_code1 -ne 0 ] || [ $exit_code2 -ne 0 ]; then + exit 1 + else + exit 0 + fi + +elif [[ "${TASK_NAME}" == "stories_260k_bc" ]]; then + + # Check BC + bash backends/qualcomm/bc/test_qnn_static_llama_bc.sh + exit_code1=$? + if [ $exit_code1 -ne 0 ]; then + exit 1 + else + exit 0 + fi + +elif [[ "${TASK_NAME}" == "smollm2_135m" ]]; then + $PYTHON_EXECUTABLE backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_static_llm_model --model_name smollm2_135m --model SM8650 --build_folder build-x86/ --executorch_root . --artifact_dir ./static_smollm2 --enable_x86_64 + exit_code1=$? 
+ if [ $exit_code1 -ne 0 ]; then + exit 1 + else + exit 0 + fi +else + echo "Unsupported task: $TASK_NAME" + exit 1 +fi diff --git a/.ci/scripts/test_torchao_huggingface_checkpoints.sh b/.ci/scripts/test_torchao_huggingface_checkpoints.sh index c0910b47826..da50d28800a 100644 --- a/.ci/scripts/test_torchao_huggingface_checkpoints.sh +++ b/.ci/scripts/test_torchao_huggingface_checkpoints.sh @@ -1,10 +1,11 @@ #!/usr/bin/env bash -set -euo pipefail +set -euxo pipefail # ------------------------- # Args / flags # ------------------------- TEST_WITH_RUNNER=0 +USE_TORCHAO_KERNELS=0 MODEL_NAME="" # Parse args @@ -22,10 +23,14 @@ while [[ $# -gt 0 ]]; do --test_with_runner) TEST_WITH_RUNNER=1 ;; + --use_torchao_kernels) + USE_TORCHAO_KERNELS=1 + ;; -h|--help) - echo "Usage: $0 [--test_with_runner]" + echo "Usage: $0 [--test_with_runner] [--use_torchao_kernels]" echo " model_name: qwen3_4b | phi_4_mini" echo " --test_with_runner: build ET + run llama_main to sanity-check the export" + echo " --use_torchao_kernels: use torchao kernels for linear and tied embedding" exit 0 ;; *) @@ -42,6 +47,13 @@ fi MODEL_OUT=model.pte + +# Default to XNNPACK +BACKEND_ARGS="-X --xnnpack-extended-ops" +if [[ "$USE_TORCHAO_KERNELS" -eq 1 ]]; then + BACKEND_ARGS="--use-torchao-kernels" +fi + case "$MODEL_NAME" in qwen3_4b) echo "Running Qwen3-4B export..." @@ -58,12 +70,12 @@ case "$MODEL_NAME" in --output_name $MODEL_OUT \ -kv \ --use_sdpa_with_kv_cache \ - -X \ - --xnnpack-extended-ops \ --max_context_length 1024 \ --max_seq_length 1024 \ + --metadata '{"get_bos_id":199999, "get_eos_ids":[200020,199999]}' \ + --verbose \ --dtype fp32 \ - --metadata '{"get_bos_id":199999, "get_eos_ids":[200020,199999]}' + ${BACKEND_ARGS} ;; phi_4_mini) @@ -81,12 +93,12 @@ case "$MODEL_NAME" in --output_name $MODEL_OUT \ -kv \ --use_sdpa_with_kv_cache \ - -X \ - --xnnpack-extended-ops \ --max_context_length 1024 \ --max_seq_length 1024 \ + --metadata '{"get_bos_id":199999, "get_eos_ids":[200020,199999]}' \ + --verbose \ --dtype fp32 \ - --metadata '{"get_bos_id":199999, "get_eos_ids":[200020,199999]}' + ${BACKEND_ARGS} ;; *) @@ -104,6 +116,10 @@ if [[ $MODEL_SIZE -gt $EXPECTED_MODEL_SIZE_UPPER_BOUND ]]; then fi # Install ET with CMake +EXECUTORCH_BUILD_KERNELS_TORCHAO="OFF" +if [[ "$USE_TORCHAO_KERNELS" -eq 1 ]]; then + EXECUTORCH_BUILD_KERNELS_TORCHAO="ON" +fi if [[ "$TEST_WITH_RUNNER" -eq 1 ]]; then echo "[runner] Building and testing llama_main ..." cmake -DPYTHON_EXECUTABLE=python \ @@ -113,6 +129,7 @@ if [[ "$TEST_WITH_RUNNER" -eq 1 ]]; then -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DEXECUTORCH_BUILD_XNNPACK=ON \ -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ @@ -120,6 +137,7 @@ if [[ "$TEST_WITH_RUNNER" -eq 1 ]]; then -DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON \ -DEXECUTORCH_BUILD_EXTENSION_LLM=ON \ -DEXECUTORCH_BUILD_KERNELS_LLM=ON \ + -DEXECUTORCH_BUILD_KERNELS_TORCHAO=${EXECUTORCH_BUILD_KERNELS_TORCHAO} \ -Bcmake-out . 
cmake --build cmake-out -j16 --config Release --target install diff --git a/.ci/scripts/test_wheel_package_qnn.sh b/.ci/scripts/test_wheel_package_qnn.sh new file mode 100644 index 00000000000..f245554ddbe --- /dev/null +++ b/.ci/scripts/test_wheel_package_qnn.sh @@ -0,0 +1,226 @@ +#!/bin/bash +# === CI Wheel Build & Test Script === + +# Exit immediately on error, print each command, and capture all output to build.log +set -e +set -x +exec > >(tee -i build.log) 2>&1 + +# Save repo root +REPO_ROOT=$(pwd) + +# ---------------------------- +# Dynamically create script_qnn_wheel_test.py +# ---------------------------- +cat > "/tmp/script_qnn_wheel_test.py" << 'EOF' +# pyre-ignore-all-errors +import argparse + +import torch +from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer +from executorch.backends.qualcomm.utils.utils import ( + generate_htp_compiler_spec, + generate_qnn_executorch_compiler_spec, + get_soc_to_chipset_map, + to_edge_transform_and_lower_to_qnn, +) +from executorch.exir.backend.utils import format_delegated_graph +from executorch.examples.models.model_factory import EagerModelFactory +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.extension.export_util.utils import save_pte_program +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e, prepare_qat_pt2e + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("-f", "--output_folder", type=str, default="", help="The folder to store the exported program") + parser.add_argument("--soc", type=str, default="SM8650", help="Specify the SoC model.") + parser.add_argument("-q", "--quantization", choices=["ptq", "qat"], help="Run post-traininig quantization.") + args = parser.parse_args() + + class LinearModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + def forward(self, arg): + return self.linear(arg) + def get_example_inputs(self): + return (torch.randn(3, 3),) + + model = LinearModule() + example_inputs = model.get_example_inputs() + + if args.quantization: + quantizer = QnnQuantizer() + m = torch.export.export(model.eval(), example_inputs, strict=True).module() + if args.quantization == "qat": + m = prepare_qat_pt2e(m, quantizer) + m(*example_inputs) + elif args.quantization == "ptq": + m = prepare_pt2e(m, quantizer) + m(*example_inputs) + m = convert_pt2e(m) + else: + m = model + + use_fp16 = True if args.quantization is None else False + backend_options = generate_htp_compiler_spec(use_fp16=use_fp16) + compile_spec = generate_qnn_executorch_compiler_spec( + soc_model=get_soc_to_chipset_map()[args.soc], + backend_options=backend_options, + ) + delegated_program = to_edge_transform_and_lower_to_qnn(m, example_inputs, compile_spec) + output_graph = format_delegated_graph(delegated_program.exported_program().graph_module) + # Ensure QnnBackend is in the output graph + assert "QnnBackend" in output_graph + executorch_program = delegated_program.to_executorch( + config=ExecutorchBackendConfig(extract_delegate_segments=False) + ) + save_pte_program(executorch_program, "linear", args.output_folder) + +if __name__ == "__main__": + main() +EOF + +# ---------------------------- +# Wheel build and .so checks +# ---------------------------- +echo "=== Building Wheel Package ===" +source .ci/scripts/utils.sh +install_executorch +EXECUTORCH_BUILDING_WHEEL=1 python setup.py bdist_wheel +unset EXECUTORCH_BUILDING_WHEEL + +WHEEL_FILE=$(ls dist/*.whl | head -n 1) +echo "Found wheel: 
$WHEEL_FILE" + +PYTHON_VERSION=$1 +# ---------------------------- +# Check wheel does NOT contain qualcomm/sdk +# ---------------------------- +echo "Checking wheel does not contain qualcomm/sdk..." +SDK_FILES=$(unzip -l "$WHEEL_FILE" | awk '{print $4}' | grep -E "executorch/backends/qualcomm/sdk" || true) +if [ -n "$SDK_FILES" ]; then + echo "ERROR: Wheel package contains unexpected qualcomm/sdk files:" + echo "$SDK_FILES" + exit 1 +else + echo "OK: No qualcomm/sdk files found in wheel" +fi + +# ---------------------------- +# Check .so files in the wheel +# ---------------------------- +echo "Checking for .so files inside the wheel..." +WHEEL_SO_FILES=$(unzip -l "$WHEEL_FILE" | awk '{print $4}' | grep -E "executorch/backends/qualcomm/python" || true) +if [ -z "$WHEEL_SO_FILES" ]; then + echo "ERROR: No .so files found in wheel under executorch/backends/qualcomm/python" + exit 1 +else + echo "Wheel contains the following .so files:" + echo "$WHEEL_SO_FILES" +fi + +# ---------------------------- +# Helpers +# ---------------------------- +get_site_packages_dir () { + local PYBIN="$1" + "$PYBIN" - <<'PY' +import sysconfig, sys +print(sysconfig.get_paths().get("purelib") or sysconfig.get_paths().get("platlib")) +PY +} + +run_core_tests () { + local PYBIN="$1" # path to python + local PIPBIN="$2" # path to pip + local LABEL="$3" # label to print (conda/venv) + + echo "=== [$LABEL] Installing wheel & deps ===" + "$PIPBIN" install --upgrade pip + "$PIPBIN" install "$WHEEL_FILE" + TORCH_VERSION=$( + "$PYBIN" - <<'PY' +import runpy +module_vars = runpy.run_path("torch_pin.py") +print(module_vars["TORCH_VERSION"]) +PY +) + + NIGHTLY_VERSION=$( + "$PYBIN" - <<'PY' +import runpy +module_vars = runpy.run_path("torch_pin.py") +print(module_vars["NIGHTLY_VERSION"]) +PY +) + echo "=== [$LABEL] Install torch==${TORCH_VERSION}.${NIGHTLY_VERSION} ===" + + # Install torchao based on the pinned PyTorch version + "$PIPBIN" install torch=="${TORCH_VERSION}.${NIGHTLY_VERSION}" --index-url "https://download.pytorch.org/whl/nightly/cpu" + "$PIPBIN" install wheel + + # Install torchao based on the pinned commit from third-party/ao submodule + pushd "$REPO_ROOT/third-party/ao" > /dev/null + export USE_CPP=0 + "$PIPBIN" install . --no-build-isolation + popd > /dev/null + + echo "=== [$LABEL] Import smoke tests ===" + "$PYBIN" -c "import executorch; print('executorch imported successfully')" + "$PYBIN" -c "import executorch.backends.qualcomm; print('executorch.backends.qualcomm imported successfully')" + "$PYBIN" -c "from executorch.export.target_recipes import get_android_recipe; recipe = get_android_recipe('android-arm64-snapdragon-fp16'); print(f'executorch.export.target_recipes imported successfully: {recipe}')" + + echo "=== [$LABEL] List installed executorch/backends/qualcomm/python ===" + local SITE_DIR + SITE_DIR="$(get_site_packages_dir "$PYBIN")" + local SO_DIR="$SITE_DIR/executorch/backends/qualcomm/python" + ls -l "$SO_DIR" || echo "Folder does not exist!" 
+ + echo "=== [$LABEL] Run export script to generate linear.pte ===" + (cd "$REPO_ROOT" && "$PYBIN" "/tmp/script_qnn_wheel_test.py") + + if [ -f "$REPO_ROOT/linear.pte" ]; then + echo "[$LABEL] Model file linear.pte successfully created" + else + echo "ERROR: [$LABEL] Model file linear.pte was not created" + exit 1 + fi +} + +# ---------------------------- +# Conda environment setup & tests +# ---------------------------- +echo "=== Testing in Conda env ===" +TEMP_ENV_DIR=$(mktemp -d) +echo "Using temporary directory for conda: $TEMP_ENV_DIR" +conda create -y -p "$TEMP_ENV_DIR/env" python=$PYTHON_VERSION +# derive python/pip paths inside the conda env +CONDA_PY="$TEMP_ENV_DIR/env/bin/python" +CONDA_PIP="$TEMP_ENV_DIR/env/bin/pip" +# Some images require conda run; keep pip/python direct to simplify path math +run_core_tests "$CONDA_PY" "$CONDA_PIP" "conda" + +# Cleanup conda env +conda env remove -p "$TEMP_ENV_DIR/env" -y || true +rm -rf "$TEMP_ENV_DIR" + +# ---------------------------- +# Python venv setup & tests +# ---------------------------- +echo "=== Testing in Python venv ===" +TEMP_VENV_DIR=$(mktemp -d) +echo "Using temporary directory for venv: $TEMP_VENV_DIR" +python3 -m venv "$TEMP_VENV_DIR/venv" +VENV_PY="$TEMP_VENV_DIR/venv/bin/python" +VENV_PIP="$TEMP_VENV_DIR/venv/bin/pip" + +# Ensure venv has wheel/build basics if needed +"$VENV_PIP" install --upgrade pip + +run_core_tests "$VENV_PY" "$VENV_PIP" "venv" + +# Cleanup venv +rm -rf "$TEMP_VENV_DIR" + +echo "=== All tests completed! ===" diff --git a/.ci/scripts/test_yolo12.sh b/.ci/scripts/test_yolo12.sh index e3f20d5f970..594ddbf86ed 100755 --- a/.ci/scripts/test_yolo12.sh +++ b/.ci/scripts/test_yolo12.sh @@ -119,6 +119,8 @@ cmake_install_executorch_libraries() { -DEXECUTORCH_BUILD_XNNPACK="$XNNPACK" \ -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ + -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -B"${build_dir}" @@ -131,6 +133,8 @@ cmake_install_executorch_libraries() { -DEXECUTORCH_BUILD_XNNPACK="$XNNPACK" \ -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ + -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ -DEXECUTORCH_ENABLE_LOGGING=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ diff --git a/.ci/scripts/unittest-buck2.sh b/.ci/scripts/unittest-buck2.sh index f748be62ac1..e78e682faac 100755 --- a/.ci/scripts/unittest-buck2.sh +++ b/.ci/scripts/unittest-buck2.sh @@ -9,9 +9,14 @@ set -eux # TODO: expand this to //... # TODO: can't query cadence & vulkan backends # TODO: can't query //kernels/prim_ops because of non-buckified stuff in OSS. -buck2 query "//backends/apple/... + //backends/example/... + \ +# TODO: Make //backends/arm tests use runtime wrapper so we can just query //backends/arm/... +buck2 query "//backends/apple/... + //backends/arm: + //backends/arm/debug/... + \ +//backends/arm/operator_support/... + //backends/arm/operators/... + \ +//backends/arm/_passes/... + //backends/arm/runtime/... + //backends/arm/tosa/... \ ++ //backends/example/... + \ //backends/mediatek/... + //backends/transforms/... + \ -//backends/xnnpack/... + //configurations/... + //extension/flat_tensor: + \ +//backends/xnnpack/... + //codegen/tools/... + \ +//configurations/... 
+ //extension/flat_tensor: + \ //extension/llm/runner: + //kernels/aten/... + //kernels/optimized/... + \ //kernels/portable/... + //kernels/quantized/... + //kernels/test/... + \ //runtime/... + //schema/... + //test/... + //util/..." @@ -30,7 +35,17 @@ BUILDABLE_KERNELS_PRIM_OPS_TARGETS=$(buck2 query //kernels/prim_ops/... | grep - for op in "build" "test"; do buck2 $op $BUILDABLE_OPTIMIZED_OPS \ //examples/selective_build:select_all_dtype_selective_lib_portable_lib \ + //extension/llm/custom_ops/spinquant/test:fast_hadamard_transform_test \ + //extension/llm/runner/test:test_multimodal_input \ + //extension/llm/runner/test:test_generation_config \ //kernels/portable/... \ $BUILDABLE_KERNELS_PRIM_OPS_TARGETS //runtime/backend/... //runtime/core/... \ //runtime/executor: //runtime/kernel/... //runtime/platform/... done + +# Build only without testing +buck2 build //codegen/tools/... \ + //extension/llm/runner/io_manager:io_manager \ + //extension/llm/modules/... \ + //extension/llm/runner:multimodal_runner_lib \ + //extension/llm/runner:text_decoder_runner diff --git a/.ci/scripts/utils.sh b/.ci/scripts/utils.sh index f6f6ece786b..7fb7517e771 100644 --- a/.ci/scripts/utils.sh +++ b/.ci/scripts/utils.sh @@ -44,10 +44,48 @@ install_pip_dependencies() { popd || return } +dedupe_macos_loader_path_rpaths() { + if [[ "$(uname)" != "Darwin" ]]; then + return + fi + + local torch_lib_dir + pushd .. + torch_lib_dir=$(python -c "import importlib.util; print(importlib.util.find_spec('torch').submodule_search_locations[0])")/lib + popd + + if [[ -z "${torch_lib_dir}" || ! -d "${torch_lib_dir}" ]]; then + return + fi + + local torch_libs=( + "libtorch_cpu.dylib" + "libtorch.dylib" + "libc10.dylib" + ) + + for lib_name in "${torch_libs[@]}"; do + local lib_path="${torch_lib_dir}/${lib_name}" + if [[ ! -f "${lib_path}" ]]; then + continue + fi + + local removed=0 + # Repeatedly remove the @loader_path rpath entries until none remain. + while install_name_tool -delete_rpath @loader_path "${lib_path}" 2>/dev/null; do + removed=1 + done + + if [[ "${removed}" == "1" ]]; then + install_name_tool -add_rpath @loader_path "${lib_path}" || true + fi + done +} + install_domains() { echo "Install torchvision and torchaudio" - pip install --no-use-pep517 --user "git+https://github.com/pytorch/audio.git@${TORCHAUDIO_VERSION}" - pip install --no-use-pep517 --user "git+https://github.com/pytorch/vision.git@${TORCHVISION_VERSION}" + pip install --no-build-isolation --user "git+https://github.com/pytorch/audio.git@${TORCHAUDIO_VERSION}" + pip install --no-build-isolation --user "git+https://github.com/pytorch/vision.git@${TORCHVISION_VERSION}" } install_pytorch_and_domains() { @@ -101,6 +139,7 @@ install_pytorch_and_domains() { echo "Use cached wheel at ${cached_torch_wheel}" fi + dedupe_macos_loader_path_rpaths # Grab the pinned audio and vision commits from PyTorch TORCHAUDIO_VERSION=$(cat .github/ci_commit_pins/audio.txt) export TORCHAUDIO_VERSION @@ -125,14 +164,15 @@ build_executorch_runner_cmake() { clean_executorch_install_folders mkdir "${CMAKE_OUTPUT_DIR}" - pushd "${CMAKE_OUTPUT_DIR}" || return if [[ $1 == "Debug" ]]; then CXXFLAGS="-fsanitize=address,undefined" else CXXFLAGS="" fi - CXXFLAGS="$CXXFLAGS" retry cmake -DPYTHON_EXECUTABLE="${PYTHON_EXECUTABLE}" -DCMAKE_BUILD_TYPE="${1:-Release}" .. - popd || return + CXXFLAGS="$CXXFLAGS" retry cmake \ + -DPYTHON_EXECUTABLE="${PYTHON_EXECUTABLE}" \ + -DCMAKE_BUILD_TYPE="${1:-Release}" \ + -B${CMAKE_OUTPUT_DIR} . 
if [ "$(uname)" == "Darwin" ]; then CMAKE_JOBS=$(( $(sysctl -n hw.ncpu) - 1 )) diff --git a/.ci/scripts/wheel/test_base.py b/.ci/scripts/wheel/test_base.py index f8a7309a6c2..278e46fe75a 100644 --- a/.ci/scripts/wheel/test_base.py +++ b/.ci/scripts/wheel/test_base.py @@ -41,6 +41,18 @@ class ModelTest: def run_tests(model_tests: List[ModelTest]) -> None: + # Test that we can import the portable_lib module - verifies RPATH is correct + print("Testing portable_lib import...") + try: + from executorch.extension.pybindings._portable_lib import ( # noqa: F401 + _load_for_executorch, + ) + + print("✓ Successfully imported _load_for_executorch from portable_lib") + except ImportError as e: + print(f"✗ Failed to import portable_lib: {e}") + raise + # Why are we doing this envvar shenanigans? Since we build the testers, which # uses buck, we cannot run as root. This is a sneaky of getting around that # test. diff --git a/.githooks/README.md b/.githooks/README.md new file mode 100644 index 00000000000..cf79397337c --- /dev/null +++ b/.githooks/README.md @@ -0,0 +1,57 @@ +# Git Hooks + +This directory contains Git hooks for the ExecuTorch repository. + +## Pre-commit Hook + +The pre-commit hook automatically updates the PyTorch commit pin in `.ci/docker/ci_commit_pins/pytorch.txt` whenever `torch_pin.py` is modified. + +### How It Works + +1. When you commit changes to `torch_pin.py`, the hook detects the change +2. It parses the `NIGHTLY_VERSION` field (e.g., `dev20251004`) +3. Converts it to a date string (e.g., `2025-10-04`) +4. Fetches the corresponding commit hash from the PyTorch nightly branch at https://github.com/pytorch/pytorch/tree/nightly +5. Updates `.ci/docker/ci_commit_pins/pytorch.txt` with the new commit hash +6. Automatically stages the updated file for commit + +### Installation + +To install the Git hooks, run: + +```bash +.githooks/install.sh +``` + +This will copy the pre-commit hook to `.git/hooks/` and make it executable. + +### Manual Usage + +You can also run the update script manually at any time: + +```bash +python .github/scripts/update_pytorch_pin.py +``` + +### Uninstalling + +To remove the pre-commit hook: + +```bash +rm .git/hooks/pre-commit +``` + +## Troubleshooting + +If the hook fails during a commit: + +1. Check that Python 3 is available in your PATH +2. Ensure you have internet connectivity to fetch commits from GitHub +3. Verify that the `NIGHTLY_VERSION` in `torch_pin.py` is in the correct format (`devYYYYMMDD`) +4. Make sure the corresponding nightly release exists in the PyTorch nightly branch + +You can run the script manually to see detailed error messages: + +```bash +python .github/scripts/update_pytorch_pin.py +``` diff --git a/.githooks/install.sh b/.githooks/install.sh new file mode 100755 index 00000000000..b79f750177b --- /dev/null +++ b/.githooks/install.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +# Script to install Git hooks from .githooks directory + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +GIT_DIR="$(git rev-parse --git-dir)" +HOOKS_DIR="${GIT_DIR}/hooks" + +echo "Installing Git hooks..." + +# Install pre-commit hook +echo "📦 Installing pre-commit hook..." +cp "${SCRIPT_DIR}/pre-commit" "${HOOKS_DIR}/pre-commit" +chmod +x "${HOOKS_DIR}/pre-commit" +echo "✅ pre-commit hook installed" + +echo "" +echo "🎉 Git hooks installed successfully!" 
+echo "" +echo "The pre-commit hook will automatically update .ci/docker/ci_commit_pins/pytorch.txt" +echo "whenever you commit changes to torch_pin.py" diff --git a/.githooks/pre-commit b/.githooks/pre-commit new file mode 100755 index 00000000000..f29342da67e --- /dev/null +++ b/.githooks/pre-commit @@ -0,0 +1,30 @@ +#!/usr/bin/env bash + +# Pre-commit hook to automatically update PyTorch commit pin and sync c10 directories when torch_pin.py changes + +# Check if torch_pin.py is being committed +if git diff --cached --name-only | grep -q "^torch_pin.py$"; then + echo "🔍 Detected changes to torch_pin.py" + echo "📝 Updating PyTorch commit pin and syncing c10 directories..." + + # Run the update script (which now also syncs c10 directories) + if python .github/scripts/update_pytorch_pin.py; then + # Stage any modified files (pytorch.txt and grafted c10 files) + if ! git diff --quiet .ci/docker/ci_commit_pins/pytorch.txt; then + git add .ci/docker/ci_commit_pins/pytorch.txt + echo "📌 Staged .ci/docker/ci_commit_pins/pytorch.txt" + fi + + # Stage any grafted c10 files + if ! git diff --quiet runtime/core/portable_type/c10/; then + git add runtime/core/portable_type/c10/ + echo "📌 Staged grafted c10 files" + fi + else + echo "❌ Failed to update PyTorch commit pin" + echo "Please run: python .github/scripts/update_pytorch_pin.py" + exit 1 + fi +fi + +exit 0 diff --git a/.github/scripts/cherry_pick.py b/.github/scripts/cherry_pick.py index 1239ee030dd..8de5279f51b 100755 --- a/.github/scripts/cherry_pick.py +++ b/.github/scripts/cherry_pick.py @@ -39,7 +39,15 @@ def parse_args() -> Any: ) parser.add_argument( "--classification", - choices=["regression", "critical", "fixnewfeature", "docs", "release"], + choices=[ + "regression", + "critical", + "fixnewfeature", + "docs", + "release", + "examples", + "testci", + ], required=True, help="the cherry pick category", ) diff --git a/.github/scripts/propose_ghstack_orig_pr.py b/.github/scripts/propose_ghstack_orig_pr.py index 53b796adaa3..3abcc6cdcf9 100644 --- a/.github/scripts/propose_ghstack_orig_pr.py +++ b/.github/scripts/propose_ghstack_orig_pr.py @@ -86,6 +86,17 @@ def get_pr_stack_from_number(ref: str, repo: Repository) -> List[int]: return pr_stack +def get_differential_revision(pr, repo: Repository) -> str: + body = repo.get_pull(pr.number).body + matches = re.findall(r"Differential Revision: .*", body) + count = len(matches) + if count == 1: + # If there's more than one Differential Revision, let's just return empty + # so that we can disambiguate manually. + return matches[0] + return "" + + def create_prs_for_orig_branch(pr_stack: List[int], repo: Repository): # For the first PR, we want to merge to `main` branch, and we will update # as we go through the stack @@ -100,6 +111,7 @@ def create_prs_for_orig_branch(pr_stack: List[int], repo: Repository): # The PR we want to create is then "branch_to_merge" <- gh/user/x/orig # gh/user/x/orig is the clean diff between gh/user/x/base <- gh/user/x/head orig_branch_merge_head = pr.base.ref.replace("base", "orig") + differential_revision_text = get_differential_revision(pr, repo) bot_metadata = f"""This PR was created by the merge bot to help merge the original PR into the main branch. 
ghstack PR number: https://github.com/pytorch/executorch/pull/{pr.number} by @{pr.user.login} ^ Please use this as the source of truth for the PR details, comments, and reviews @@ -107,6 +119,7 @@ def create_prs_for_orig_branch(pr_stack: List[int], repo: Repository): ghstack PR head: https://github.com/pytorch/executorch/tree/{pr.head.ref} Merge bot PR base: https://github.com/pytorch/executorch/tree/{orig_branch_merge_base} Merge bot PR head: https://github.com/pytorch/executorch/tree/{orig_branch_merge_head} +{differential_revision_text} @diff-train-skip-merge""" existing_orig_pr = repo.get_pulls( diff --git a/.github/scripts/trigger_cuda_perf.sh b/.github/scripts/trigger_cuda_perf.sh new file mode 100755 index 00000000000..402dd009673 --- /dev/null +++ b/.github/scripts/trigger_cuda_perf.sh @@ -0,0 +1,100 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Quick script to trigger cuda-perf workflow via GitHub CLI +# Usage: +# ./trigger_cuda_perf.sh # Use defaults (random model + quant) +# ./trigger_cuda_perf.sh --all # Run ALL models with ALL quantizations +# ./trigger_cuda_perf.sh "openai/whisper-medium" # Single model +# ./trigger_cuda_perf.sh "openai/whisper-small,google/gemma-3-4b-it" "non-quantized,quantized-int4-tile-packed" "100" + +set -e + +# All available models and quantizations +ALL_MODELS="mistralai/Voxtral-Mini-3B-2507,openai/whisper-small,openai/whisper-medium,openai/whisper-large-v3-turbo,google/gemma-3-4b-it" +ALL_QUANTIZATIONS="non-quantized,quantized-int4-tile-packed,quantized-int4-weight-only" + +# Check if gh CLI is installed +if ! command -v gh &> /dev/null; then + echo "Error: GitHub CLI (gh) is not installed." 
+ echo "Install it from: https://cli.github.com/" + echo "" + echo "Quick install:" + echo " macOS: brew install gh" + echo " Linux: See https://github.com/cli/cli/blob/trunk/docs/install_linux.md" + exit 1 +fi + +# Check for --all flag +RUN_ALL=false +if [ "${1:-}" = "--all" ] || [ "${1:-}" = "-a" ]; then + RUN_ALL=true + shift # Remove the flag from arguments +fi + +# Default parameters +if [ "$RUN_ALL" = true ]; then + MODELS="$ALL_MODELS" + QUANT="$ALL_QUANTIZATIONS" + NUM_RUNS="${1:-50}" + RANDOM_MODEL="false" + echo "=========================================" + echo "Triggering cuda-perf workflow" + echo "Mode: RUN ALL MODELS AND QUANTIZATIONS" + echo "=========================================" + echo "Models: ALL (5 models)" + echo "Quantizations: ALL (3 quantizations)" + echo "Total configs: 15 combinations" + echo "Num runs: $NUM_RUNS" + echo "=========================================" +else + MODELS="${1:-}" + QUANT="${2:-}" + NUM_RUNS="${3:-50}" + RANDOM_MODEL="${4:-false}" + + # Display configuration + echo "=========================================" + echo "Triggering cuda-perf workflow" + echo "=========================================" + if [ -z "$MODELS" ]; then + echo "Models: (random selection)" + else + echo "Models: $MODELS" + fi + if [ -z "$QUANT" ]; then + echo "Quantizations: (random selection)" + else + echo "Quantizations: $QUANT" + fi + echo "Num runs: $NUM_RUNS" + echo "Random model: $RANDOM_MODEL" + echo "=========================================" +fi + +echo "" + +# Trigger workflow +gh workflow run cuda-perf.yml \ + -R pytorch/executorch \ + -f models="$MODELS" \ + -f quantizations="$QUANT" \ + -f num_runs="$NUM_RUNS" \ + -f random_model="$RANDOM_MODEL" + +if [ $? -eq 0 ]; then + echo "✓ Workflow triggered successfully!" + echo "" + echo "View status:" + echo " gh run list --workflow=cuda-perf.yml" + echo "" + echo "Watch the latest run:" + echo " gh run watch \$(gh run list --workflow=cuda-perf.yml --limit 1 --json databaseId --jq '.[0].databaseId')" +else + echo "✗ Failed to trigger workflow" + exit 1 +fi diff --git a/.github/scripts/update_pytorch_pin.py b/.github/scripts/update_pytorch_pin.py new file mode 100644 index 00000000000..dbc48552d9b --- /dev/null +++ b/.github/scripts/update_pytorch_pin.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python3 + +import base64 +import hashlib +import json +import re +import sys +import urllib.request +from pathlib import Path + + +def parse_nightly_version(nightly_version): + """ + Parse NIGHTLY_VERSION (e.g., 'dev20251004') to date string (e.g., '2025-10-04'). + + Args: + nightly_version: String in format 'devYYYYMMDD' + + Returns: + Date string in format 'YYYY-MM-DD' + """ + match = re.match(r"dev(\d{4})(\d{2})(\d{2})", nightly_version) + if not match: + raise ValueError(f"Invalid NIGHTLY_VERSION format: {nightly_version}") + + year, month, day = match.groups() + return f"{year}-{month}-{day}" + + +def get_torch_nightly_version(): + """ + Read NIGHTLY_VERSION from torch_pin.py. + + Returns: + NIGHTLY_VERSION string + """ + with open("torch_pin.py", "r") as f: + content = f.read() + + match = re.search(r'NIGHTLY_VERSION\s*=\s*["\']([^"\']+)["\']', content) + if not match: + raise ValueError("Could not find NIGHTLY_VERSION in torch_pin.py") + + return match.group(1) + + +def get_commit_hash_for_nightly(date_str): + """ + Fetch commit hash from PyTorch nightly branch for a given date. 
+ + Args: + date_str: Date string in format 'YYYY-MM-DD' + + Returns: + Commit hash string + """ + api_url = "https://api.github.com/repos/pytorch/pytorch/commits" + params = f"?sha=nightly&per_page=50" + url = api_url + params + + req = urllib.request.Request(url) + req.add_header("Accept", "application/vnd.github.v3+json") + req.add_header("User-Agent", "ExecuTorch-Bot") + + try: + with urllib.request.urlopen(req) as response: + commits = json.loads(response.read().decode()) + except Exception as e: + print(f"Error fetching commits: {e}", file=sys.stderr) + sys.exit(1) + + # Look for commit with title matching "{date_str} nightly release" + target_title = f"{date_str} nightly release" + + for commit in commits: + commit_msg = commit.get("commit", {}).get("message", "") + # Check if the first line of commit message matches + first_line = commit_msg.split("\n")[0].strip() + if first_line.startswith(f"{date_str} nightly"): + return extract_hash_from_title(first_line) + + raise ValueError( + f"Could not find commit with title matching '{target_title}' in nightly branch" + ) + + +def extract_hash_from_title(title): + match = re.search(r"\(([0-9a-fA-F]{7,40})\)", title) + if not match: + raise ValueError(f"Could not extract commit hash from title '{title}'") + return match.group(1) + + +def update_pytorch_pin(commit_hash): + """ + Update .ci/docker/ci_commit_pins/pytorch.txt with the new commit hash. + + Args: + commit_hash: Commit hash to write + """ + pin_file = ".ci/docker/ci_commit_pins/pytorch.txt" + with open(pin_file, "w") as f: + f.write(f"{commit_hash}\n") + print(f"Updated {pin_file} with commit hash: {commit_hash}") + + +def should_skip_file(filename): + """ + Check if a file should be skipped during sync (build files). + + Args: + filename: Base filename to check + + Returns: + True if file should be skipped + """ + skip_files = {"BUCK", "CMakeLists.txt", "TARGETS", "targets.bzl"} + return filename in skip_files + + +def fetch_file_content(commit_hash, file_path): + """ + Fetch file content from GitHub API. + + Args: + commit_hash: Commit hash to fetch from + file_path: File path in the repository + + Returns: + File content as bytes + """ + api_url = f"https://api.github.com/repos/pytorch/pytorch/contents/{file_path}?ref={commit_hash}" + + req = urllib.request.Request(api_url) + req.add_header("Accept", "application/vnd.github.v3+json") + req.add_header("User-Agent", "ExecuTorch-Bot") + + try: + with urllib.request.urlopen(req) as response: + data = json.loads(response.read().decode()) + # Content is base64 encoded + content = base64.b64decode(data["content"]) + return content + except urllib.request.HTTPError as e: + print(f"Error fetching file {file_path}: {e}", file=sys.stderr) + raise + + +def sync_directory(et_dir, pt_path, commit_hash): + """ + Sync files from PyTorch to ExecuTorch using GitHub API. + Only syncs files that already exist in ExecuTorch - does not add new files. 
+ + Args: + et_dir: ExecuTorch directory path + pt_path: PyTorch directory path in the repository (e.g., "c10") + commit_hash: Commit hash to fetch from + + Returns: + Number of files grafted + """ + files_grafted = 0 + print(f"Checking {et_dir} vs pytorch/{pt_path}...") + + if not et_dir.exists(): + print(f"Warning: ExecuTorch directory {et_dir} does not exist, skipping") + return 0 + + # Loop through files in ExecuTorch directory + for et_file in et_dir.rglob("*"): + if not et_file.is_file(): + continue + + # Skip build files + if should_skip_file(et_file.name): + continue + + # Construct corresponding path in PyTorch + rel_path = et_file.relative_to(et_dir) + pt_file_path = f"{pt_path}/{rel_path}".replace("\\", "/") + + # Fetch content from PyTorch and compare + try: + pt_content = fetch_file_content(commit_hash, pt_file_path) + et_content = et_file.read_bytes() + + if pt_content != et_content: + print(f"⚠️ Difference detected in {rel_path}") + print(f"📋 Grafting from PyTorch commit {commit_hash}...") + + et_file.write_bytes(pt_content) + print(f"✅ Grafted {et_file}") + files_grafted += 1 + except urllib.request.HTTPError as e: + if e.code != 404: # It's ok to have more files in ET than pytorch/pytorch. + print(f"Error fetching {rel_path} from PyTorch: {e}") + except Exception as e: + print(f"Error syncing {rel_path}: {e}") + continue + + return files_grafted + + +def sync_c10_directories(commit_hash): + """ + Sync c10 and torch/headeronly directories from PyTorch to ExecuTorch using GitHub API. + + Args: + commit_hash: PyTorch commit hash to sync from + + Returns: + Total number of files grafted + """ + print("\n🔄 Syncing c10 directories from PyTorch via GitHub API...") + + # Get repository root + repo_root = Path.cwd() + + # Define directory pairs to sync (from check_c10_sync.sh) + # Format: (executorch_dir, pytorch_path_in_repo) + dir_pairs = [ + ( + repo_root / "runtime/core/portable_type/c10/c10", + "c10", + ), + ( + repo_root / "runtime/core/portable_type/c10/torch/headeronly", + "torch/headeronly", + ), + ] + + total_grafted = 0 + for et_dir, pt_path in dir_pairs: + files_grafted = sync_directory(et_dir, pt_path, commit_hash) + total_grafted += files_grafted + + if total_grafted > 0: + print(f"\n✅ Successfully grafted {total_grafted} file(s) from PyTorch") + else: + print("\n✅ No differences found - c10 is in sync") + + return total_grafted + + +def main(): + try: + # Read NIGHTLY_VERSION from torch_pin.py + nightly_version = get_torch_nightly_version() + print(f"Found NIGHTLY_VERSION: {nightly_version}") + + # Parse to date string + date_str = parse_nightly_version(nightly_version) + print(f"Parsed date: {date_str}") + + # Fetch commit hash from PyTorch nightly branch + commit_hash = get_commit_hash_for_nightly(date_str) + print(f"Found commit hash: {commit_hash}") + + # Update the pin file + update_pytorch_pin(commit_hash) + + # Sync c10 directories from PyTorch + sync_c10_directories(commit_hash) + + print( + "\n✅ Successfully updated PyTorch commit pin and synced c10 directories!" 
+ ) + + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/.github/workflows/_android.yml b/.github/workflows/_android.yml index 2449e94b2af..7b67c340350 100644 --- a/.github/workflows/_android.yml +++ b/.github/workflows/_android.yml @@ -48,26 +48,13 @@ jobs: bash examples/models/llama/install_requirements.sh bash ".ci/scripts/test_llama.sh" -model stories110M -build_tool cmake -dtype fp16 -mode portable -upload ${ARTIFACTS_DIR_NAME}/fp32-xnnpack-custom - mkdir -p examples/demo-apps/android/LlamaDemo/app/libs - cp aar-out/executorch.aar examples/demo-apps/android/LlamaDemo/app/libs - pushd examples/demo-apps/android/LlamaDemo - ANDROID_HOME="${ANDROID_SDK:-/opt/android/sdk}" ./gradlew build assembleAndroidTest - popd - - DEMO_APP_DIR="${ARTIFACTS_DIR_NAME}/llm_demo" - # The app directory is named using its build flavor as a suffix. - mkdir -p "${DEMO_APP_DIR}" - # Collect the app and its test suite - cp examples/demo-apps/android/LlamaDemo/app/build/outputs/apk/debug/*.apk "${DEMO_APP_DIR}" - cp examples/demo-apps/android/LlamaDemo/app/build/outputs/apk/androidTest/debug/*.apk "${DEMO_APP_DIR}" - # Running Android emulator directly on the runner and not using Docker run-emulator: needs: build-llm-demo # NB: Use metal install for KVM support to run the emulator faster runs-on: linux.24xl.spr-metal env: - ANDROID_NDK_VERSION: r27b + ANDROID_NDK_VERSION: r28c API_LEVEL: 34 steps: - name: Setup SSH (Click me for login details) @@ -103,8 +90,6 @@ jobs: shell: bash run: | set -eux - curl -O https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/llm_demo/app-debug.apk - curl -O https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/llm_demo/app-debug-androidTest.apk curl -O https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/fp32-xnnpack-custom/model.zip curl -o android-test-debug-androidTest.apk https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/library_test_dir/executorch_android-debug-androidTest.apk unzip model.zip diff --git a/.github/workflows/_get-changed-files.yml b/.github/workflows/_get-changed-files.yml new file mode 100644 index 00000000000..55712b06527 --- /dev/null +++ b/.github/workflows/_get-changed-files.yml @@ -0,0 +1,43 @@ +name: Get Changed Files + +on: + workflow_call: + outputs: + changed-files: + description: "List of changed files (space-separated) or '*' if not in a PR" + value: ${{ jobs.get-changed-files.outputs.changed-files }} + +jobs: + get-changed-files: + runs-on: ubuntu-latest + outputs: + changed-files: ${{ steps.get-files.outputs.changed-files }} + + steps: + - name: Get changed files + id: get-files + env: + GH_TOKEN: ${{ github.token }} + run: | + # Check if we're in a pull request context + if [ "${{ github.event_name }}" = "pull_request" ] || [ "${{ github.event_name }}" = "pull_request_target" ]; then + echo "Running in PR context" + + # Get the PR number from the github context + PR_NUMBER="${{ github.event.number }}" + + # Use gh CLI to get changed files in the PR with explicit repo + CHANGED_FILES=$(gh api repos/${{ github.repository }}/pulls/$PR_NUMBER/files --paginate --jq '.[] | select(.status != "removed") | .filename' | tr '\n' ' ' | sed 's/ $//') + + if [ -z "$CHANGED_FILES" ]; then + echo "No changed files found, setting to '*'" + CHANGED_FILES="*" + fi + + echo "Changed files: $CHANGED_FILES" + echo 
"changed-files=$CHANGED_FILES" >> "$GITHUB_OUTPUT" + + else + echo "Not in PR context, setting changed files to '*'" + echo "changed-files=*" >> "$GITHUB_OUTPUT" + fi diff --git a/.github/workflows/_link_check.yml b/.github/workflows/_link_check.yml index aadd6c07420..89b3655986c 100644 --- a/.github/workflows/_link_check.yml +++ b/.github/workflows/_link_check.yml @@ -55,3 +55,29 @@ jobs: echo "Or add \`@lint-ignore\` somewhere on the same line as the reference you want to skip checking." exit 1 } + + lint-file-size: + if: ${{ github.event_name == 'pull_request' }} + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + with: + runner: linux.2xlarge + docker-image: ci-image:executorch-ubuntu-22.04-linter + submodules: false + fetch-depth: 0 + ref: ${{ inputs.ref }} + timeout: 30 + script: | + chmod +x ./scripts/lint_file_size.sh + ./scripts/lint_file_size.sh $( + if [ "${{ github.event_name }}" = "pull_request" ]; then + echo "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}" + else + echo "${{ github.event.before }}" "${{ github.sha }}" + fi + ) || { + echo + echo "File size lint failed: some files exceed the 1 MB limit." + echo "If you really need large files, consider using Git LFS or storing them elsewhere." + echo "If you really need to get unblocked and check in the file, can add it to the EXCEPTIONS list in scripts/lint_file_size.sh." + exit 1 + } diff --git a/.github/workflows/_test_backend.yml b/.github/workflows/_test_backend.yml new file mode 100644 index 00000000000..ec426af8892 --- /dev/null +++ b/.github/workflows/_test_backend.yml @@ -0,0 +1,84 @@ +name: Test Backend + +on: + workflow_call: + inputs: + backend: + description: 'Backend to test (xnnpack, coreml, vulkan, qnn)' + required: true + type: string + flows: + description: 'JSON array of flows to test' + required: true + type: string + ref: + description: 'Git ref to checkout' + required: false + type: string + default: ${{ github.sha }} + timeout: + description: 'Job timeout in minutes' + required: false + type: number + default: 120 + run-linux: + description: 'Whether to run Linux tests' + required: false + type: boolean + default: false + run-macos: + description: 'Whether to run macOS tests' + required: false + type: boolean + default: false + runner-linux: + description: 'Runner type for Linux jobs' + required: false + type: string + default: linux.4xlarge.memory + +jobs: + test-backend-linux: + if: ${{ inputs.run-linux }} + strategy: + fail-fast: false + matrix: + flow: ${{ fromJSON(inputs.flows) }} + suite: [models, operators] + + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + with: + ref: ${{ inputs.ref }} + runner: ${{ inputs.runner-linux }} + docker-image: ci-image:executorch-ubuntu-22.04-clang12 + submodules: recursive + timeout: ${{ inputs.timeout }} + upload-artifact: test-report-${{ matrix.flow }}-${{ matrix.suite }} + script: | + set -eux + + source .ci/scripts/test_backend.sh "${{ matrix.suite }}" "${{ matrix.flow }}" "${RUNNER_ARTIFACT_DIR}" + + test-backend-macos: + if: ${{ inputs.run-macos }} + strategy: + fail-fast: false + matrix: + flow: ${{ fromJSON(inputs.flows) }} + suite: [models, operators] + + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + with: + ref: ${{ inputs.ref }} + runner: macos-m1-stable + python-version: "3.12" + submodules: recursive + timeout: ${{ inputs.timeout }} + upload-artifact: test-report-${{ matrix.flow }}-${{ matrix.suite }} + script: | + set -eux + + # This is needed to get the 
prebuilt PyTorch wheel from S3 + ${CONDA_RUN} --no-capture-output pip install awscli==1.37.21 + + source .ci/scripts/test_backend.sh "${{ matrix.suite }}" "${{ matrix.flow }}" "${RUNNER_ARTIFACT_DIR}" diff --git a/.github/workflows/_unittest.yml b/.github/workflows/_unittest.yml index 587f2cf5e5a..e26e7146f2a 100644 --- a/.github/workflows/_unittest.yml +++ b/.github/workflows/_unittest.yml @@ -32,7 +32,7 @@ jobs: id-token: write contents: read with: - runner: linux.2xlarge + runner: linux.2xlarge.memory docker-image: ${{ inputs.docker-image }} submodules: 'recursive' ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} diff --git a/.github/workflows/add-unanswered-to-project.yml b/.github/workflows/add-unanswered-to-project.yml index ba2bc6c8436..5321d0f75e2 100644 --- a/.github/workflows/add-unanswered-to-project.yml +++ b/.github/workflows/add-unanswered-to-project.yml @@ -12,7 +12,7 @@ jobs: - name: Add open issues and open, non-draft PRs to org project (excluding certain authors) uses: actions/github-script@v7 with: - github-token: ${{ secrets.GITHUB_TOKEN }} + github-token: ${{ secrets.ET_EXT_CONTRIB }} script: | const projectId = "PVT_kwDOAUB9vs4A_PUL"; // PyTorch org project 136 const owner = 'pytorch'; @@ -20,20 +20,32 @@ jobs: // List of authors to exclude const excludedAuthors = new Set([ - "nil-is-all", "cbilgin", "KimishPatel", "psiddh", "digantdesai", "SS-JIA", "ahmtox", "mcr229", "shoumikhin", - "manuelcandales", "metascroy", "cccclai", "rohansjoshi", "kirklandsign", "abhinaykukkadapu", "JacobSzwejbka", - "Conarnar", "lucylq", "larryliu0820", "BujSet", "Gasoonjia", "Juntian777", "guangy10", "jackzhxng", - "GregoryComer", "leafs1", "swolchok", "mergennachin", "tarun292", "byjlw", "jathu", "Jack-Khuu", "georgehong", - "zhenyan-zhang-meta", "silverguo", "dbort", "jorgep31415", "huydhn", "mcremon-meta", "trivedivivek", "angelayi", - "helunwencser", "hsharma35", "zhxchen17", "iseeyuan", "svekars", "nathanaelsee", "dulinriley", "jerryzh168", - "cmodi-meta", "bigfootjon", "sxu", "ydwu4", "Riandy", "tugsbayasgalan", "bsoyluoglu", "yangw-dev", "YIWENX14", - "namanahuja", "yushangdi", "limintang", "pianpwk", "viveknayakatmeta", "andreanicastro", "JakeStevens", - "gmagogsfm", "zonglinpeng", "eigen-k", "derekxu", "salilsdesai", "skrtskrtfb", "pssrawat", "r-barnes", "pytorchbot", - "pytorchmergebot", "pytorchupdatebot", "facebook-github-bot", "Erik-Lundell", "zingo", "AdrianLundell", - "oscarandersson8218", "per", "Sebastian-Larsson", "SaoirseARM", "robell", "mansnils", "martinlsm", "freddan80", - "YufengShi-dudu", "tom-arm", "perheld", "Jerry-Ge", "gggekov", "fumchin", "wwwind", "haowhsu-quic", "shewu-quic", - "winskuo-quic", "chunit-quic", "DannyYuyang-quic", "chuntl", "cymbalrush", "DenisVieriu97", "billmguo", - "StrycekSimon", "jirioc", "robert-kalmar", "skywall", "neuropilot-captain" + "nil-is-all", "tanvirislam-meta", "cbilgin", "kimishpatel", "psiddh", "digantdesai", "SS-JIA", "ahmtox", "mcr229", + "shoumikhin", "manuelcandales", "metascroy", "cccclai", "rohansjoshi", "kirklandsign", "abhinaykukkadapu", + "JacobSzwejbka", "Conarnar", "lucylq", "larryliu0820", "BujSet", "Gasoonjia", "Juntian777", "guangy10", "jackzhxng", + "GregoryComer", "leafs1", "swolchok", "mergennachin", "tarun292", "byjlw", "jathu", "Jack-Khuu", "georgehong", + "zhenyan-zhang-meta", "silverguo", "harishs88ss", "AlannaBurke", "dbort", "huydhn", "mcremon-meta", "trivedivivek", + "angelayi", "helunwencser", "hsharma35", "zhxchen17", "iseeyuan", "svekars", 
"nathanaelsee", "dulinriley", + "jerryzh168", "cmodi-meta", "bigfootjon", "sxu", "ydwu4", "Riandy", "tugsbayasgalan", "bsoyluoglu", "yangw-dev", + "YIWENX14", "namanahuja", "yushangdi", "limintang", "pianpwk", "viveknayakatmeta", "andreanicastro", "JakeStevens", + "gmagogsfm", "zonglinpeng", "eigen-k", "derekxu", "salilsdesai", "skrtskrtfb", "pssrawat", "r-barnes", + "kalpit-meta-1", "Will-MingLun-Li", "KapJI", "piyengar", "j-bahr", "BoyuanFeng", "fgasperij", "DariusHolmgren", + "sammarden-meta", "kushrast", "meta-emilian", "Rittzz", "jeanschmidt", "copyrightly", "mikekgfb", "vmpuri", + "zonglinpengmeta", "maggiemoss", "aorenste", "hoangminhle98", "Solumin", "meyering", "rchen152", "AishwaryaSivaraman", + "migeed-z", "ebgraham", "Esteb37", "nausicaasnow", "Camyll", "ezyang", "huiyujie", "dltn", "cjhopman", "blackm00n", + "agunapal", "SamGondelman", "Ninja91", "ivayloen", "DrJessop", "rodrigos01meta", "akrieger", "cmt0", "yiming0416", + "ethansfng", "ThomasJannaud", "nirvanagth", "marcinkwiatkowski", "3l1", "omerjerk", "nitish2112", "yipjustin", + "ejnguyen", "andrewor14", "phaiting", "mgiordy", "LeeOHzzZ", "adicatana", "Polyomino", "ezrilow", "navsud", + "michaelmaitland", "RahulC7", "seyeong-han", "YifanShenSZ", "RdoubleA", "Olivia-liu", "Abhi-hpp", "Vysarat", + "azad-meta", "junpi", "pytorchbot", "pytorchmergebot", "pytorchupdatebot", "facebook-github-bot", "app/dependabot", + "Erik-Lundell", "zingo", "AdrianLundell", "oscarandersson8218", "per", "Sebastian-Larsson", "SaoirseARM", "robell", + "mansnils", "martinlsm", "freddan80", "YufengShi-dudu", "tom-arm", "perheld", "Jerry-Ge", "gggekov", "fumchin", "wwwind", + "benkli01", "Tessil", "maddun01", "Michiel-Olieslagers", "armwaheed", "agrima1304", "emmakujala", "annietllnd", + "MatthiasHertel80", "AlexTawseArm", "jmahbs", "haowhsu-quic", "shewu-quic", "winskuo-quic", "chunit-quic", + "DannyYuyang-quic", "chuntl", "thchenqti", "jethroqti", "chenweng-quic", "cymbalrush", "DenisVieriu97", "billmguo", + "StrycekSimon", "jirioc", "robert-kalmar", "skywall", "MartinPavella", "roman-janik-nxp", "novak-vaclav ", + "neuropilot-captain", "dijopaul", "cad-rlc", "cad-audio", "ynimmaga", "daniil-lyakhov", "emmanuel-ferdman", + "cavusmustafa", "Jiseong-oh", "alexdean08" ]); async function addItem(contentId, type, number) { @@ -80,11 +92,10 @@ jobs: owner, repo, state: 'open', - draft: false, } ); for (const pr of prs) { - if (!excludedAuthors.has(pr.user.login)) { + if (!pr.draft && !excludedAuthors.has(pr.user.login)) { await addItem(pr.node_id, 'pr', pr.number); } } diff --git a/.github/workflows/android-perf-private-device-experiment.yml b/.github/workflows/android-perf-private-device-experiment.yml deleted file mode 100644 index cf37538f620..00000000000 --- a/.github/workflows/android-perf-private-device-experiment.yml +++ /dev/null @@ -1,62 +0,0 @@ -name: android-perf (private devices) - -on: - schedule: - - cron: 0 0,4,8,12,16,20 * * * - pull_request: - paths: - - .github/workflows/android-perf-private-device-experiment.yml - push: - branches: - - main - paths: - - .github/workflows/android-perf-private-device-experiment.yml - # Note: GitHub has an upper limit of 10 inputs - workflow_dispatch: - inputs: - models: - description: Models to be benchmarked - required: false - type: string - default: Qwen/Qwen3-0.6B - devices: - description: Target devices to run benchmark - required: false - type: string - default: samsung_galaxy_s22+private - benchmark_configs: - description: The list of configs used the benchmark - required: false - type: string - 
workflow_call: - inputs: - models: - description: Models to be benchmarked - required: false - type: string - default: Qwen/Qwen3-0.6B - devices: - description: Target devices to run benchmark - required: false - type: string - default: samsung_galaxy_s22+private - benchmark_configs: - description: The list of configs used the benchmark - required: false - type: string - -concurrency: - group: android-perf-private-devices-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} - cancel-in-progress: true - -jobs: - android: - uses: ./.github/workflows/android-perf.yml - secrets: inherit - permissions: - id-token: write - contents: read - with: - models: ${{ inputs.models || github.event_name == 'schedule' && 'Qwen/Qwen3-0.6B,HuggingFaceTB/SmolLM2-135M,meta-llama/Llama-3.2-1B,allenai/OLMo-1B-hf,google/gemma-3-1b-it' || 'google/gemma-3-1b-it' }} - devices: samsung_galaxy_s22+private - benchmark_configs: ${{ inputs.benchmark_configs }} diff --git a/.github/workflows/android-perf.yml b/.github/workflows/android-perf.yml deleted file mode 100644 index 33937531a01..00000000000 --- a/.github/workflows/android-perf.yml +++ /dev/null @@ -1,562 +0,0 @@ -name: android-perf - -on: - schedule: - - cron: 0 0,8,16 * * * - pull_request: - paths: - - .github/workflows/android-perf.yml - - .ci/scripts/gather_benchmark_configs.py - - extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2 - push: - branches: - - main - paths: - - .github/workflows/android-perf.yml - - .ci/scripts/gather_benchmark_configs.py - - extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2 - # Note: GitHub has an upper limit of 10 inputs - workflow_dispatch: - inputs: - models: - description: Models to be benchmarked - required: false - type: string - default: Qwen/Qwen3-0.6B - devices: - description: Target devices to run benchmark - required: false - type: string - default: samsung_galaxy_s22+public - benchmark_configs: - description: The list of configs used the benchmark - required: false - type: string - workflow_call: - inputs: - models: - description: Models to be benchmarked - required: false - type: string - default: Qwen/Qwen3-0.6B - devices: - description: Target devices to run benchmark - required: false - type: string - default: samsung_galaxy_s22+public - benchmark_configs: - description: The list of configs used the benchmark - required: false - type: string - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} - cancel-in-progress: true - -jobs: - set-parameters: - runs-on: ubuntu-22.04 - outputs: - benchmark_configs: ${{ steps.set-parameters.outputs.benchmark_configs }} - steps: - - uses: actions/checkout@v3 - with: - submodules: 'false' - - uses: actions/setup-python@v4 - with: - python-version: '3.10' - - name: Set parameters - id: set-parameters - shell: bash - env: - # Separate default values from the workflow dispatch. To ensure defaults are accessible - # during scheduled runs and to provide flexibility for different defaults between - # on-demand and periodic benchmarking. 
- CRON_DEFAULT_MODELS: ${{ github.event_name == 'schedule' && 'mv3,mv2,ic4,ic3,resnet50,mobilebert,w2l,meta-llama/Llama-3.2-1B,meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8,meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8,Qwen/Qwen3-0.6B,HuggingFaceTB/SmolLM2-135M,allenai/OLMo-1B-hf,google/gemma-3-1b-it' || 'Qwen/Qwen3-0.6B' }} - CRON_DEFAULT_DEVICES: samsung_galaxy_s22+public - run: | - set -eux - - ARGS="--os android" - - MODELS="${{ inputs.models }}" - if [ -z "$MODELS" ]; then - MODELS="$CRON_DEFAULT_MODELS" - fi - ARGS="$ARGS --models $MODELS" - - DEVICES="${{ inputs.devices }}" - if [ -z "$DEVICES" ]; then - DEVICES="$CRON_DEFAULT_DEVICES" - fi - ARGS="$ARGS --devices $DEVICES" - - BENCHMARK_CONFIGS="${{ inputs.benchmark_configs }}" - if [ -n "$BENCHMARK_CONFIGS" ]; then - ARGS="$ARGS --configs $BENCHMARK_CONFIGS" - fi - - PYTHONPATH="${PWD}" python .ci/scripts/gather_benchmark_configs.py $ARGS - - prepare-test-specs: - runs-on: linux.2xlarge - needs: set-parameters - strategy: - matrix: ${{ fromJson(needs.set-parameters.outputs.benchmark_configs) }} - fail-fast: false - steps: - - uses: actions/checkout@v3 - - - name: Prepare the spec - id: prepare - shell: bash - env: - BENCHMARK_CONFIG: ${{ toJSON(matrix) }} - working-directory: extension/benchmark/android/benchmark - run: | - set -eux - - # The model will be exported in the next step to this S3 path - MODEL_PATH="https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/${{ matrix.model }}_${{ matrix.config }}/model.zip" - # We could write a script to properly use jinja here, but there is only one variable, - # so let's just sed it - sed -i -e 's,{{ model_path }},'"${MODEL_PATH}"',g' android-llm-device-farm-test-spec.yml.j2 - - BENCHMARK_CONFIG_ID=$(echo "${{ matrix.model }}_${{ matrix.config }}" | sed -e 's/[^A-Za-z0-9._-]/_/g') - # The config for this benchmark runs, we save it in the test spec so that it can be fetched - # later by the upload script - sed -i -e 's,{{ benchmark_config_id }},'"${BENCHMARK_CONFIG_ID}"',g' android-llm-device-farm-test-spec.yml.j2 - - cp android-llm-device-farm-test-spec.yml.j2 android-llm-device-farm-test-spec.yml - # Just print the test spec for debugging - cat android-llm-device-farm-test-spec.yml - - # Save the benchmark configs so that we can use it later in the dashboard - echo "${BENCHMARK_CONFIG}" > "${BENCHMARK_CONFIG_ID}.json" - echo "benchmark-config-id=${BENCHMARK_CONFIG_ID}" >> $GITHUB_OUTPUT - - - name: Upload the spec - uses: seemethere/upload-artifact-s3@v5 - with: - s3-bucket: gha-artifacts - s3-prefix: | - ${{ github.repository }}/${{ github.run_id }}/artifacts/${{ matrix.model }}_${{ matrix.config }} - retention-days: 1 - if-no-files-found: error - path: extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml - - - name: Update the benchmark configs - uses: seemethere/upload-artifact-s3@v5 - with: - s3-bucket: gha-artifacts - s3-prefix: | - ${{ github.repository }}/${{ github.run_id }}/artifacts/benchmark-configs/ - retention-days: 1 - if-no-files-found: error - path: extension/benchmark/android/benchmark/${{ steps.prepare.outputs.benchmark-config-id }}.json - - export-models: - name: export-models - uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main - permissions: - id-token: write - contents: read - needs: set-parameters - secrets: inherit - strategy: - matrix: ${{ fromJson(needs.set-parameters.outputs.benchmark_configs) }} - fail-fast: false - with: - runner: linux.2xlarge.memory - docker-image: 
ci-image:executorch-ubuntu-22.04-qnn-sdk - submodules: 'recursive' - timeout: 60 - upload-artifact: android-models - upload-artifact-to-s3: true - secrets-env: EXECUTORCH_HF_TOKEN - script: | - # The generic Linux job chooses to use base env, not the one setup by the image - echo "::group::Setting up dev environment" - CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") - conda activate "${CONDA_ENV}" - if [[ ${{ matrix.config }} == *"qnn"* ]]; then - PYTHON_EXECUTABLE=python bash .ci/scripts/setup-qnn-deps.sh - PYTHON_EXECUTABLE=python bash .ci/scripts/build-qnn-sdk.sh - fi - PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool "cmake" - # Install requirements for export_llama - PYTHON_EXECUTABLE=python bash examples/models/llama/install_requirements.sh - - pip install -U "huggingface_hub[cli]" - huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN - pip install accelerate sentencepiece - pip list - - ARTIFACTS_DIR_NAME=artifacts-to-be-uploaded/${{ matrix.model }}_${{ matrix.config }} - echo "::endgroup::" - - echo "::group::Exporting ${{ matrix.config }} model: ${{ matrix.model }}" - BUILD_MODE="cmake" - - if [[ ${{ matrix.model }} =~ ^[^/]+/[^/]+$ ]]; then - # HuggingFace model. Assume the pattern is always like "/" - HF_MODEL_REPO=${{ matrix.model }} - OUT_ET_MODEL_NAME="$(echo "$HF_MODEL_REPO" | awk -F'/' '{print $2}' | sed 's/_/-/g' | tr '[:upper:]' '[:lower:]')_${{ matrix.config }}" - - # Convert HF checkpoint to ET via etLLM path - if [[ "$HF_MODEL_REPO" == meta-llama/* ]]; then - if [[ ${{ matrix.config }} == "llama3_spinquant" ]]; then - # SpinQuant - # Download prequantized chceckpoint from Hugging Face - DOWNLOADED_PATH=$( - bash .ci/scripts/download_hf_hub.sh \ - --model_id "${HF_MODEL_REPO}" \ - --files "tokenizer.model" "params.json" "consolidated.00.pth" - ) - # Export using ExecuTorch's model definition - python -m extension.llm.export.export_llm \ - base.model_class="llama3_2" \ - base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \ - base.params="${DOWNLOADED_PATH}/params.json" \ - model.use_sdpa_with_kv_cache=true \ - backend.xnnpack.enabled=true \ - backend.xnnpack.extended_ops=true \ - base.preq_mode="preq_8da4w_out_8da8w" \ - base.preq_group_size=32 \ - export.max_seq_length=2048 \ - export.max_context_length=2048 \ - export.output_name="${OUT_ET_MODEL_NAME}.pte" \ - model.use_kv_cache=true \ - model.dtype_override=fp32 \ - base.preq_embedding_quantize=\'8,0\' \ - quantization.use_spin_quant=native \ - base.metadata='"{\"get_bos_id\":128000,\"get_eos_ids\":[128009,128001]}"' - ls -lh "${OUT_ET_MODEL_NAME}.pte" - elif [[ ${{ matrix.config }} == "llama3_qlora" ]]; then - # QAT + LoRA - # Download prequantized chceckpoint from Hugging Face - DOWNLOADED_PATH=$( - bash .ci/scripts/download_hf_hub.sh \ - --model_id "${HF_MODEL_REPO}" \ - --files "tokenizer.model" "params.json" "consolidated.00.pth" - ) - # Export using ExecuTorch's model definition - python -m extension.llm.export.export_llm \ - base.model_class="llama3_2" \ - base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \ - base.params="${DOWNLOADED_PATH}/params.json" \ - quantization.use_qat=true \ - base.use_lora=16 \ - base.preq_mode="preq_8da4w_out_8da8w" \ - base.preq_group_size=32 \ - base.preq_embedding_quantize=\'8,0\' \ - model.use_sdpa_with_kv_cache=true \ - model.use_kv_cache=true \ - backend.xnnpack.enabled=true \ - backend.xnnpack.extended_ops=true \ - model.dtype_override=fp32 \ - export.max_seq_length=2048 \ - export.max_context_length=2048 \ - 
export.output_name="${OUT_ET_MODEL_NAME}.pte" \ - base.metadata='"{\"get_bos_id\":128000,\"get_eos_ids\":[128009,128001]}"' - ls -lh "${OUT_ET_MODEL_NAME}.pte" - elif [[ ${{ matrix.config }} == "llama3_fb16" ]]; then - # Original BF16 version, without any quantization - DOWNLOADED_PATH=$(bash .ci/scripts/download_hf_hub.sh --model_id "${HF_MODEL_REPO}" --subdir "original" --files "tokenizer.model" "params.json" "consolidated.00.pth") - python -m extension.llm.export.export_llm \ - base.model_class="llama3_2" \ - base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \ - base.params="${DOWNLOADED_PATH}/params.json" \ - model.use_kv_cache=true \ - model.use_sdpa_with_kv_cache=true \ - backend.xnnpack.enabled=true \ - model.dtype_override=bf16 \ - base.metadata='"{\"get_bos_id\":128000,\"get_eos_ids\":[128009,128001]}"' \ - export.output_name="${OUT_ET_MODEL_NAME}.pte" - ls -lh "${OUT_ET_MODEL_NAME}.pte" - elif [[ ${{ matrix.config }} == "et_xnnpack_custom_spda_kv_cache_8da4w" ]]; then - DOWNLOADED_PATH=$(bash .ci/scripts/download_hf_hub.sh --model_id "${HF_MODEL_REPO}" --subdir "original" --files "tokenizer.model" "params.json" "consolidated.00.pth") - python -m extension.llm.export.export_llm \ - base.model_class=llama3_2 \ - base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \ - base.params="${DOWNLOADED_PATH}/params.json" \ - model.use_kv_cache=true \ - model.use_sdpa_with_kv_cache=true \ - model.dtype_override=fp32 \ - backend.xnnpack.enabled=true \ - backend.xnnpack.extended_ops=true \ - quantization.qmode=8da4w \ - quantization.group_size=32 \ - quantization.embedding_quantize=\'8,0\' \ - base.metadata='"{\"get_bos_id\":128000,\"get_eos_ids\":[128009,128001]}"' \ - export.output_name="${OUT_ET_MODEL_NAME}.pte" - ls -lh "${OUT_ET_MODEL_NAME}.pte" - elif [[ ${{ matrix.config }} == "llama3_qnn_htp" ]]; then - export QNN_SDK_ROOT=/tmp/qnn/2.37.0.250724 - export LD_LIBRARY_PATH=$QNN_SDK_ROOT/lib/x86_64-linux-clang/ - export PYTHONPATH=$(pwd)/.. - - DOWNLOADED_PATH=$(bash .ci/scripts/download_hf_hub.sh --model_id "${HF_MODEL_REPO}" --subdir "original" --files "tokenizer.model" "params.json" "consolidated.00.pth") - python -m examples.qualcomm.oss_scripts.llama3_2.llama -- \ - --checkpoint "${DOWNLOADED_PATH}/consolidated.00.pth" \ - --params "${DOWNLOADED_PATH}/params.json" \ - --tokenizer_model "${DOWNLOADED_PATH}/tokenizer.model" \ - --compile_only \ - --ptq 16a4w \ - -m SM8650 \ - --model_size 1B \ - --model_mode kv \ - --prompt "Once" - - OUT_ET_MODEL_NAME="llama3_2_qnn" # Qualcomm hard-coded it in their script - find . -name "${OUT_ET_MODEL_NAME}.pte" -not -path "./${OUT_ET_MODEL_NAME}.pte" -exec mv {} ./ \; - ls -lh "${OUT_ET_MODEL_NAME}.pte" - fi - elif [[ "$HF_MODEL_REPO" == "Qwen/Qwen3-0.6B" ]]; then - if [[ ${{ matrix.config }} == "et_xnnpack_custom_spda_kv_cache_8da4w" ]]; then - DOWNLOADED_PATH=$(bash .ci/scripts/download_hf_hub.sh --model_id "${HF_MODEL_REPO}" --subdir "." 
--files "tokenizer.json") - python -m extension.llm.export.export_llm \ - base.model_class=qwen3_0_6b \ - base.params=examples/models/qwen3/config/0_6b_config.json \ - model.use_kv_cache=true \ - model.use_sdpa_with_kv_cache=true \ - model.dtype_override=fp32 \ - backend.xnnpack.enabled=true \ - backend.xnnpack.extended_ops=true \ - quantization.qmode=8da4w \ - quantization.group_size=32 \ - quantization.embedding_quantize=\'8,0\' \ - base.metadata='"{\"get_bos_id\":151644,\"get_eos_ids\":[151645]}"' \ - export.output_name="${OUT_ET_MODEL_NAME}.pte" - ls -lh "${OUT_ET_MODEL_NAME}.pte" - fi - fi - - if [[ ${{ matrix.config }} == "hf_xnnpack_custom_spda_kv_cache_8da4w" ]]; then - DOWNLOADED_PATH=$( - bash .ci/scripts/download_hf_hub.sh \ - --model_id "${HF_MODEL_REPO}" \ - --files "tokenizer.json" - ) - echo "tokenizer.json is downloaded to $DOWNLOADED_PATH" - - # Install optimum-executorch - OPTIMUM_ET_COMMIT=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) - git clone https://github.com/huggingface/optimum-executorch - pushd optimum-executorch - # There is no release yet, for CI stability, always test from the same commit on main - git checkout $OPTIMUM_ET_COMMIT - python install_dev.py --skip_override_torch - pip list - - ARGS=( - "--model" "${HF_MODEL_REPO}" - "--task" "text-generation" - "--recipe" "xnnpack" - "--use_custom_sdpa" - "--use_custom_kv_cache" - "--qlinear" "8da4w" - "--qembedding" "8w" - "--output_dir" ".." - ) - - optimum-cli export executorch "${ARGS[@]}" - popd - - mv model.pte ${OUT_ET_MODEL_NAME}.pte - ls -lh "${OUT_ET_MODEL_NAME}.pte" - fi - - zip -j model.zip ${OUT_ET_MODEL_NAME}.pte ${DOWNLOADED_PATH}/tokenizer.* - ls -lh model.zip - mkdir -p ${ARTIFACTS_DIR_NAME} - mv model.zip ${ARTIFACTS_DIR_NAME} - ls -lh ${ARTIFACTS_DIR_NAME} - elif [[ ${{ matrix.model }} == "llama" ]]; then - # Install requirements for export_llama - PYTHON_EXECUTABLE=python bash examples/models/llama/install_requirements.sh - # Test llama2 - if [[ ${{ matrix.config }} == *"xnnpack"* ]]; then - DELEGATE_CONFIG="xnnpack+custom+qe" - elif [[ ${{ matrix.config }} == *"qnn"* ]]; then - DELEGATE_CONFIG="qnn" - else - echo "Unsupported delegate ${{ matrix.config }}" - exit 1 - fi - DTYPE="fp32" - PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama.sh \ - -model "${{ matrix.model }}" \ - -build_tool "${BUILD_MODE}" \ - -dtype "${DTYPE}" \ - -mode "${DELEGATE_CONFIG}" \ - -upload "${ARTIFACTS_DIR_NAME}" - else - PYTHON_EXECUTABLE=python bash .ci/scripts/test_model.sh \ - "${{ matrix.model }}" \ - "${BUILD_MODE}" \ - "${{ matrix.config }}" \ - "${ARTIFACTS_DIR_NAME}" - fi - echo "::endgroup::" - - build-benchmark-app: - name: build-benchmark-app - uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main - permissions: - id-token: write - contents: read - needs: set-parameters - with: - runner: linux.2xlarge - docker-image: ci-image:executorch-ubuntu-22.04-clang12-android - submodules: 'recursive' - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - timeout: 90 - upload-artifact: android-apps - upload-artifact-to-s3: true - script: | - set -eux - - # Use sccache for NDK compiler as well - export CMAKE_CXX_COMPILER_LAUNCHER=sccache - export CMAKE_C_COMPILER_LAUNCHER=sccache - - # The generic Linux job chooses to use base env, not the one setup by the image - CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") - conda activate "${CONDA_ENV}" - PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool cmake - export 
ARTIFACTS_DIR_NAME=artifacts-to-be-uploaded - - PYTHON_EXECUTABLE=python bash .ci/scripts/setup-qnn-deps.sh - PYTHON_EXECUTABLE=python bash .ci/scripts/build-qnn-sdk.sh - - mkdir -p aar-out - PYTHON_EXECUTABLE=python ANDROID_ABIS="arm64-v8a" BUILD_AAR_DIR=aar-out EXECUTORCH_BUILD_QNN=ON QNN_SDK_ROOT=/tmp/qnn/2.37.0.250724 EXECUTORCH_ANDROID_PROFILING=ON bash scripts/build_android_library.sh - mkdir -p extension/benchmark/android/benchmark/app/libs - cp aar-out/executorch.aar extension/benchmark/android/benchmark/app/libs - pushd extension/benchmark/android/benchmark - ANDROID_HOME="${ANDROID_SDK:-/opt/android/sdk}" ./gradlew build assembleAndroidTest - popd - MINIBENCH_APP_DIR="${ARTIFACTS_DIR_NAME}/minibench" - mkdir -p "${MINIBENCH_APP_DIR}" - cp extension/benchmark/android/benchmark/app/build/outputs/apk/debug/*.apk "${MINIBENCH_APP_DIR}" - cp extension/benchmark/android/benchmark/app/build/outputs/apk/androidTest/debug/*.apk "${MINIBENCH_APP_DIR}" - - # Let's see how expensive this job is, we might want to tone it down by running it periodically - # CHANGE IF this job name 'benchmark-on-device' changed: extract_model_info() in executorch/.github/scripts/extract_benchmark_results.py - benchmark-on-device: - if: always() - permissions: - id-token: write - contents: read - uses: pytorch/test-infra/.github/workflows/mobile_job.yml@main - needs: - - set-parameters - - prepare-test-specs - - build-benchmark-app - - export-models - strategy: - matrix: ${{ fromJson(needs.set-parameters.outputs.benchmark_configs) }} - fail-fast: false - with: - # Due to scheduling a job may be pushed beyond the default 60m threshold - timeout: 240 - device-type: android - runner: linux.2xlarge - test-infra-ref: '' - # This is the ARN of ExecuTorch project on AWS - project-arn: arn:aws:devicefarm:us-west-2:308535385114:project:02a2cf0f-6d9b-45ee-ba1a-a086587469e6 - device-pool-arn: ${{ matrix.device_arn }} - android-app-archive: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/minibench/app-debug.apk - android-test-archive: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/minibench/app-debug-androidTest.apk - test-spec: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/${{ matrix.model }}_${{ matrix.config }}/android-llm-device-farm-test-spec.yml - new-output-format-flag: true - - upload-benchmark-results: - needs: - - benchmark-on-device - if: always() - runs-on: linux.2xlarge - environment: upload-benchmark-results - permissions: - id-token: write - contents: read - steps: - - uses: actions/checkout@v3 - with: - submodules: false - - - name: Authenticate with AWS - uses: aws-actions/configure-aws-credentials@v4 - with: - role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_upload-benchmark-results - # The max duration enforced by the server side - role-duration-seconds: 18000 - aws-region: us-east-1 - - - name: Setup conda - uses: pytorch/test-infra/.github/actions/setup-miniconda@main - with: - python-version: '3.10' - - - name: Download the list of artifacts from S3 - env: - ARTIFACTS_S3_DIR: s3://gha-artifacts/device_farm/${{ github.run_id }}/${{ github.run_attempt }}/artifacts/ - shell: bash - run: | - set -eux - ${CONDA_RUN} python -mpip install awscli==1.32.18 - - mkdir -p artifacts - pushd artifacts - ${CONDA_RUN} aws s3 sync "${ARTIFACTS_S3_DIR}" . 
- popd - - ls -lah artifacts - - - name: Download the list of benchmark configs from S3 - env: - BENCHMARK_CONFIGS_DIR: s3://gha-artifacts/${{ github.repository }}/${{ github.run_id }}/artifacts/benchmark-configs/ - shell: bash - run: | - set -eux - - mkdir -p benchmark-configs - pushd benchmark-configs - ${CONDA_RUN} aws s3 sync "${BENCHMARK_CONFIGS_DIR}" . - popd - - ls -lah benchmark-configs - - - name: Extract the benchmark results JSON - shell: bash - env: - DEVICE_TYPE: android - run: | - set -eux - - mkdir -p benchmark-results - - for ARTIFACTS_BY_JOB in artifacts/*.json; do - [ -f "${ARTIFACTS_BY_JOB}" ] || break - echo "${ARTIFACTS_BY_JOB}" - ${CONDA_RUN} python .github/scripts/extract_benchmark_results.py \ - --artifacts "${ARTIFACTS_BY_JOB}" \ - --output-dir benchmark-results \ - --app "${DEVICE_TYPE}" \ - --benchmark-configs benchmark-configs - done - - for BENCHMARK_RESULTS in benchmark-results/v3/*.json; do - cat "${BENCHMARK_RESULTS}" - echo - done - - - name: Upload the benchmark results (v3) - uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main - with: - benchmark-results-dir: benchmark-results/v3 - dry-run: false - schema-version: v3 - github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/android-release-artifacts.yml b/.github/workflows/android-release-artifacts.yml index f0b74342eb8..beda0f77c83 100644 --- a/.github/workflows/android-release-artifacts.yml +++ b/.github/workflows/android-release-artifacts.yml @@ -15,15 +15,11 @@ on: type: choice options: - "xnnpack" - - "vulkan+xnnpack" + - "vulkan" - "qnn" schedule: - cron: 0 10 * * * -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - jobs: check-if-aar-exists: name: check-if-aar-exists @@ -34,12 +30,13 @@ jobs: shell: bash run: | VERSION="${{ inputs.version }}" + FLAVOR="${{ inputs.flavor }}" if [ -z "$VERSION" ]; then echo "No version name specified. Will create a snapshot AAR" exit 0 fi - if curl -I "https://ossci-android.s3.amazonaws.com/executorch/release/${VERSION}/executorch.aar" | grep "200 OK"; then - echo "AAR already exists at https://ossci-android.s3.amazonaws.com/executorch/release/${VERSION}/executorch.aar" + if curl -I "https://ossci-android.s3.amazonaws.com/executorch/release/${VERSION}-${FLAVOR}/executorch.aar" | grep "200 OK"; then + echo "AAR already exists at https://ossci-android.s3.amazonaws.com/executorch/release/${VERSION}-${FLAVOR}/executorch.aar" echo "Will skip build/upload" exit 1 fi @@ -93,7 +90,14 @@ jobs: fi FLAVOR="${{ inputs.flavor }}" - if [[ "$FLAVOR" == "vulkan+xnnpack" || -z "$FLAVOR" ]]; then + if [ ! 
-z "$FLAVOR" ]; then + GRADLE_ARGS+=" -Dflavor=${FLAVOR}" + fi + + if [[ "$FLAVOR" == "vulkan" || -z "$FLAVOR" ]]; then + curl -O https://sdk.lunarg.com/sdk/download/1.4.321.1/linux/vulkansdk-linux-x86_64-1.4.321.1.tar.xz + tar xf vulkansdk-linux-x86_64-1.4.321.1.tar.xz -C /tmp + export PATH="/tmp/1.4.321.1/x86_64/bin:$PATH" export EXECUTORCH_BUILD_VULKAN=ON fi @@ -145,8 +149,12 @@ jobs: pip install awscli==1.32.18 AWS_CMD="aws s3 cp" VERSION="${{ inputs.version }}" + FLAVOR="${{ inputs.flavor }}" if [ -z "$VERSION" ]; then VERSION="snapshot-$(date +"%Y%m%d")" fi - ${AWS_CMD} executorch.aar s3://ossci-android/executorch/release/${VERSION}/executorch.aar --acl public-read - ${AWS_CMD} executorch.aar.sha256sums s3://ossci-android/executorch/release/${VERSION}/executorch.aar.sha256sums --acl public-read + if [ -z "$FLAVOR" ]; then + FLAVOR="xnnpack" + fi + ${AWS_CMD} executorch.aar s3://ossci-android/executorch/release/${VERSION}-${FLAVOR}/executorch.aar --acl public-read + ${AWS_CMD} executorch.aar.sha256sums s3://ossci-android/executorch/release/${VERSION}-${FLAVOR}/executorch.aar.sha256sums --acl public-read diff --git a/.github/workflows/apple-perf-private-device-experiment.yml b/.github/workflows/apple-perf-private-device-experiment.yml deleted file mode 100644 index 47e2c6c9340..00000000000 --- a/.github/workflows/apple-perf-private-device-experiment.yml +++ /dev/null @@ -1,62 +0,0 @@ -name: apple-perf (private devices) - -on: - schedule: - - cron: 0 0,4,8,12,16,20 * * * - pull_request: - paths: - - .github/workflows/apple-perf-private-device-experiment.yml - push: - branches: - - main - paths: - - .github/workflows/apple-perf-private-device-experiment.yml - # Note: GitHub has an upper limit of 10 inputs - workflow_dispatch: - inputs: - models: - description: Models to be benchmarked - required: false - type: string - default: Qwen/Qwen3-0.6B - devices: - description: Target devices to run benchmark - required: false - type: string - default: apple_iphone_15+pro_private - benchmark_configs: - description: The list of configs used the benchmark - required: false - type: string - workflow_call: - inputs: - models: - description: Models to be benchmarked - required: false - type: string - default: Qwen/Qwen3-0.6B - devices: - description: Target devices to run benchmark - required: false - type: string - default: apple_iphone_15+pro_private - benchmark_configs: - description: The list of configs used the benchmark - required: false - type: string - -concurrency: - group: apple-perf-private-devices-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} - cancel-in-progress: true - -jobs: - apple: - uses: ./.github/workflows/apple-perf.yml - secrets: inherit - permissions: - id-token: write - contents: read - with: - models: ${{ inputs.models || github.event_name == 'schedule' && 'Qwen/Qwen3-0.6B,HuggingFaceTB/SmolLM2-135M,meta-llama/Llama-3.2-1B,allenai/OLMo-1B-hf,google/gemma-3-1b-it' || 'google/gemma-3-1b-it' }} - devices: apple_iphone_15+pro_private - benchmark_configs: ${{ inputs.benchmark_configs }} diff --git a/.github/workflows/apple-perf.yml b/.github/workflows/apple-perf.yml deleted file mode 100644 index 56fc67d1617..00000000000 --- a/.github/workflows/apple-perf.yml +++ /dev/null @@ -1,603 +0,0 @@ -name: apple-perf - -on: - schedule: - - cron: 0 1 * * * - pull_request: - paths: - - .github/workflows/apple-perf.yml - - 
.ci/scripts/gather_benchmark_configs.py - - extension/benchmark/apple/Benchmark/default-ios-device-farm-appium-test-spec.yml.j2 - push: - branches: - - main - paths: - - .github/workflows/apple-perf.yml - - .ci/scripts/gather_benchmark_configs.py - - extension/benchmark/apple/Benchmark/default-ios-device-farm-appium-test-spec.yml.j2 - # Note: GitHub has an upper limit of 10 inputs - workflow_dispatch: - inputs: - models: - description: Models to be benchmarked - required: false - type: string - default: Qwen/Qwen3-0.6B - devices: - description: Target devices to run benchmark - required: false - type: string - default: apple_iphone_15+public - benchmark_configs: - description: The list of configs used the benchmark - required: false - type: string - workflow_call: - inputs: - models: - description: Models to be benchmarked - required: false - type: string - default: Qwen/Qwen3-0.6B - devices: - description: Target devices to run benchmark - required: false - type: string - default: apple_iphone_15+public - benchmark_configs: - description: The list of configs used the benchmark - required: false - type: string - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} - cancel-in-progress: true - -jobs: - set-parameters: - runs-on: ubuntu-22.04 - outputs: - benchmark_configs: ${{ steps.set-parameters.outputs.benchmark_configs }} - steps: - - uses: actions/checkout@v3 - with: - submodules: 'false' - - uses: actions/setup-python@v4 - with: - python-version: '3.10' - - name: Set parameters - id: set-parameters - shell: bash - env: - # Separate default values from the workflow dispatch. To ensure defaults are accessible - # during scheduled runs and to provide flexibility for different defaults between - # on-demand and periodic benchmarking. 
- CRON_DEFAULT_MODELS: ${{ github.event_name == 'schedule' && 'mv3,mv2,ic4,ic3,resnet50,edsr,mobilebert,w2l,meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8,meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8,Qwen/Qwen3-0.6B,HuggingFaceTB/SmolLM2-135M,meta-llama/Llama-3.2-1B,allenai/OLMo-1B-hf,google/gemma-3-1b-it' || 'Qwen/Qwen3-0.6B' }} - CRON_DEFAULT_DEVICES: apple_iphone_15+public - run: | - set -eux - - ARGS="--os ios" - - MODELS="${{ inputs.models }}" - if [ -z "$MODELS" ]; then - MODELS="$CRON_DEFAULT_MODELS" - fi - ARGS="$ARGS --models $MODELS" - - DEVICES="${{ inputs.devices }}" - if [ -z "$DEVICES" ]; then - DEVICES="$CRON_DEFAULT_DEVICES" - fi - ARGS="$ARGS --devices $DEVICES" - - BENCHMARK_CONFIGS="${{ inputs.benchmark_configs }}" - if [ -n "$BENCHMARK_CONFIGS" ]; then - ARGS="$ARGS --configs $BENCHMARK_CONFIGS" - fi - - PYTHONPATH="${PWD}" python .ci/scripts/gather_benchmark_configs.py $ARGS - - echo "benchmark_configs is: ${{ steps.set-parameters.outputs.benchmark_configs }}" - - prepare-test-specs: - runs-on: linux.2xlarge - needs: set-parameters - strategy: - matrix: ${{ fromJson(needs.set-parameters.outputs.benchmark_configs) }} - fail-fast: false - steps: - - uses: actions/checkout@v3 - - - name: Prepare the spec - id: prepare - shell: bash - env: - BENCHMARK_CONFIG: ${{ toJSON(matrix) }} - working-directory: extension/benchmark/apple/Benchmark - run: | - set -eux - - # The model will be exported in the next step to this S3 path - MODEL_PATH="https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/${{ matrix.model }}_${{ matrix.config }}/model.zip" - # We could write a script to properly use jinja here, but there is only one variable, - # so let's just sed it - sed -i -e 's,{{ model_path }},'"${MODEL_PATH}"',g' default-ios-device-farm-appium-test-spec.yml.j2 - - BENCHMARK_CONFIG_ID=$(echo "${{ matrix.model }}_${{ matrix.config }}" | sed -e 's/[^A-Za-z0-9._-]/_/g') - # The config for this benchmark runs, we save it in the test spec so that it can be fetched - # later by the upload script - sed -i -e 's,{{ benchmark_config_id }},'"${BENCHMARK_CONFIG_ID}"',g' default-ios-device-farm-appium-test-spec.yml.j2 - - cp default-ios-device-farm-appium-test-spec.yml.j2 default-ios-device-farm-appium-test-spec.yml - # Just print the test spec for debugging - cat default-ios-device-farm-appium-test-spec.yml - - # Save the benchmark configs so that we can use it later in the dashboard - echo "${BENCHMARK_CONFIG}" > "${BENCHMARK_CONFIG_ID}.json" - echo "benchmark-config-id=${BENCHMARK_CONFIG_ID}" >> $GITHUB_OUTPUT - - - name: Upload the spec - uses: seemethere/upload-artifact-s3@v5 - with: - s3-bucket: gha-artifacts - s3-prefix: | - ${{ github.repository }}/${{ github.run_id }}/artifacts/${{ matrix.model }}_${{ matrix.config }} - retention-days: 1 - if-no-files-found: error - path: extension/benchmark/apple/Benchmark/default-ios-device-farm-appium-test-spec.yml - - - name: Update the benchmark configs - uses: seemethere/upload-artifact-s3@v5 - with: - s3-bucket: gha-artifacts - s3-prefix: | - ${{ github.repository }}/${{ github.run_id }}/artifacts/benchmark-configs/ - retention-days: 1 - if-no-files-found: error - path: extension/benchmark/apple/Benchmark/${{ steps.prepare.outputs.benchmark-config-id }}.json - - export-models: - name: export-models - uses: pytorch/test-infra/.github/workflows/macos_job.yml@main - needs: set-parameters - secrets: inherit - strategy: - matrix: ${{ fromJson(needs.set-parameters.outputs.benchmark_configs) }} - fail-fast: 
false - with: - # NB: Need to use our AWS MacOS runner to upload large models to S3 - runner: macos-m1-stable - python-version: '3.11' - submodules: 'recursive' - timeout: 60 - upload-artifact: ios-models - upload-artifact-to-s3: true - secrets-env: EXECUTORCH_HF_TOKEN - script: | - set -eux - - echo "::group::Setting up CI environment" - .ci/scripts/setup-conda.sh - - BUILD_TOOL=cmake - # Setup MacOS dependencies as there is no Docker support on MacOS atm - GITHUB_RUNNER=1 PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \ - .ci/scripts/setup-macos.sh --build-tool "${BUILD_TOOL}" - - if [[ ${{ matrix.config }} == *"coreml"* ]]; then - PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \ - backends/apple/coreml/scripts/install_requirements.sh - fi - - # Install requirements for export_llama - PYTHON_EXECUTABLE=python ${CONDA_RUN} bash examples/models/llama/install_requirements.sh - - pip install -U "huggingface_hub[cli]" - huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN - ${CONDA_RUN} pip install accelerate sentencepiece - pip list - - ARTIFACTS_DIR_NAME=artifacts-to-be-uploaded/${{ matrix.model }}_${{ matrix.config }} - echo "::endgroup::" - - echo "::group::Exporting ${{ matrix.config }} model: ${{ matrix.model }}" - BUILD_MODE="cmake" - - if [[ ${{ matrix.model }} =~ ^[^/]+/[^/]+$ ]]; then - # HuggingFace model. Assume the pattern is always like "/" - HF_MODEL_REPO=${{ matrix.model }} - OUT_ET_MODEL_NAME="$(echo "$HF_MODEL_REPO" | awk -F'/' '{print $2}' | sed 's/_/-/g' | tr '[:upper:]' '[:lower:]')_${{ matrix.config }}" - - # Convert HF checkpoint to ET via etLLM path - if [[ "$HF_MODEL_REPO" == meta-llama/* ]]; then - # The benchmark app replies on the _llm suffix to determine whether the model is a LLM or not - OUT_ET_MODEL_NAME=${OUT_ET_MODEL_NAME}_llm - # Llama models on Hugging Face - if [[ ${{ matrix.config }} == "llama3_spinquant" ]]; then - # SpinQuant - # Download prequantized chceckpoint from Hugging Face - DOWNLOADED_PATH=$( - bash .ci/scripts/download_hf_hub.sh \ - --model_id "${HF_MODEL_REPO}" \ - --files "tokenizer.model" "params.json" "consolidated.00.pth" - ) - # Export using ExecuTorch's model definition - ${CONDA_RUN} python -m extension.llm.export.export_llm \ - base.model_class="llama3_2" \ - base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \ - base.params="${DOWNLOADED_PATH}/params.json" \ - model.use_sdpa_with_kv_cache=true \ - backend.xnnpack.enabled=true \ - backend.xnnpack.extended_ops=true \ - base.preq_mode="preq_8da4w_out_8da8w" \ - base.preq_group_size=32 \ - export.max_seq_length=2048 \ - export.max_context_length=2048 \ - export.output_name="${OUT_ET_MODEL_NAME}.pte" \ - model.use_kv_cache=true \ - model.dtype_override=fp32 \ - base.preq_embedding_quantize=\'8,0\' \ - quantization.use_spin_quant=native \ - base.metadata='"{\"get_bos_id\":128000,\"get_eos_ids\":[128009,128001]}"' - ls -lh "${OUT_ET_MODEL_NAME}.pte" - elif [[ ${{ matrix.config }} == "llama3_qlora" ]]; then - # QAT + LoRA - # Download prequantized chceckpoint from Hugging Face - DOWNLOADED_PATH=$( - bash .ci/scripts/download_hf_hub.sh \ - --model_id "${HF_MODEL_REPO}" \ - --files "tokenizer.model" "params.json" "consolidated.00.pth" - ) - # Export using ExecuTorch's model definition - ${CONDA_RUN} python -m extension.llm.export.export_llm \ - base.model_class="llama3_2" \ - base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \ - base.params="${DOWNLOADED_PATH}/params.json" \ - quantization.use_qat=true \ - base.use_lora=16 \ - 
base.preq_mode="preq_8da4w_out_8da8w" \ - base.preq_group_size=32 \ - base.preq_embedding_quantize=\'8,0\' \ - model.use_sdpa_with_kv_cache=true \ - model.use_kv_cache=true \ - backend.xnnpack.enabled=true \ - backend.xnnpack.extended_ops=true \ - model.dtype_override=fp32 \ - export.max_seq_length=2048 \ - export.max_context_length=2048 \ - export.output_name="${OUT_ET_MODEL_NAME}.pte" \ - base.metadata='"{\"get_bos_id\":128000,\"get_eos_ids\":[128009,128001]}"' - ls -lh "${OUT_ET_MODEL_NAME}.pte" - elif [[ ${{ matrix.config }} == "llama3_fb16" ]]; then - # Original BF16 version, without any quantization - DOWNLOADED_PATH=$(bash .ci/scripts/download_hf_hub.sh --model_id "${HF_MODEL_REPO}" --subdir "original" --files "tokenizer.model" "params.json" "consolidated.00.pth") - ${CONDA_RUN} python -m extension.llm.export.export_llm \ - base.model_class="llama3_2" \ - base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \ - base.params="${DOWNLOADED_PATH}/params.json" \ - model.use_kv_cache=true \ - model.use_sdpa_with_kv_cache=true \ - backend.xnnpack.enabled=true \ - model.dtype_override=bf16 \ - base.metadata='"{\"get_bos_id\":128000,\"get_eos_ids\":[128009,128001]}"' \ - export.output_name="${OUT_ET_MODEL_NAME}.pte" - ls -lh "${OUT_ET_MODEL_NAME}.pte" - elif [[ ${{ matrix.config }} == "et_xnnpack_custom_spda_kv_cache_8da4w" ]]; then - DOWNLOADED_PATH=$(bash .ci/scripts/download_hf_hub.sh --model_id "${HF_MODEL_REPO}" --subdir "original" --files "tokenizer.model" "params.json" "consolidated.00.pth") - ${CONDA_RUN} python -m extension.llm.export.export_llm \ - base.model_class=llama3_2 \ - base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \ - base.params="${DOWNLOADED_PATH}/params.json" \ - model.use_kv_cache=true \ - model.use_sdpa_with_kv_cache=true \ - model.dtype_override=fp32 \ - backend.xnnpack.enabled=true \ - backend.xnnpack.extended_ops=true \ - quantization.qmode=8da4w \ - quantization.group_size=32 \ - quantization.embedding_quantize=\'8,0\' \ - base.metadata='"{\"get_bos_id\":128000,\"get_eos_ids\":[128009,128001]}"' \ - export.output_name="${OUT_ET_MODEL_NAME}.pte" - ls -lh "${OUT_ET_MODEL_NAME}.pte" - elif [[ ${{ matrix.config }} == "llama3_coreml_ane" ]]; then - # ANE - DOWNLOADED_PATH=$(bash .ci/scripts/download_hf_hub.sh --model_id "${HF_MODEL_REPO}" --subdir "original" --files "tokenizer.model" "params.json" "consolidated.00.pth") - ${CONDA_RUN} python -m extension.llm.export.export_llm \ - base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \ - base.params="${DOWNLOADED_PATH}/params.json" \ - quantization.embedding_quantize=\'4,32\' \ - model.use_kv_cache=true \ - model.enable_dynamic_shape=false \ - backend.coreml.enabled=true \ - backend.coreml.ios=18 \ - backend.coreml.quantize=c4w \ - backend.coreml.compute_units=cpu_and_ne \ - export.output_name="${OUT_ET_MODEL_NAME}.pte" - ls -lh "${OUT_ET_MODEL_NAME}.pte" - fi - elif [[ "$HF_MODEL_REPO" == "Qwen/Qwen3-0.6B" ]]; then - OUT_ET_MODEL_NAME=${OUT_ET_MODEL_NAME}_llm - if [[ ${{ matrix.config }} == "et_xnnpack_custom_spda_kv_cache_8da4w" ]]; then - DOWNLOADED_PATH=$(bash .ci/scripts/download_hf_hub.sh --model_id "${HF_MODEL_REPO}" --subdir "." 
--files "tokenizer.json") - ${CONDA_RUN} python -m extension.llm.export.export_llm \ - base.model_class=qwen3_0_6b \ - base.params=examples/models/qwen3/config/0_6b_config.json \ - model.use_kv_cache=true \ - model.use_sdpa_with_kv_cache=true \ - model.dtype_override=fp32 \ - backend.xnnpack.enabled=true \ - backend.xnnpack.extended_ops=true \ - quantization.qmode=8da4w \ - quantization.group_size=32 \ - quantization.embedding_quantize=\'8,0\' \ - base.metadata='"{\"get_bos_id\":151644,\"get_eos_ids\":[151645]}"' \ - export.output_name="${OUT_ET_MODEL_NAME}.pte" - ls -lh "${OUT_ET_MODEL_NAME}.pte" - fi - fi - - if [[ ${{ matrix.config }} == "hf_xnnpack_custom_spda_kv_cache_8da4w" ]]; then - DOWNLOADED_PATH=$( - bash .ci/scripts/download_hf_hub.sh \ - --model_id "${HF_MODEL_REPO}" \ - --files "tokenizer.json" - ) - echo "tokenizer.json is downloaded to $DOWNLOADED_PATH" - - # Install optimum-executorch - OPTIMUM_ET_COMMIT=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) - git clone https://github.com/huggingface/optimum-executorch - pushd optimum-executorch - # There is no release yet, for CI stability, always test from the same commit on main - git checkout $OPTIMUM_ET_COMMIT - ${CONDA_RUN} python install_dev.py --skip_override_torch - pip list - - ARGS=( - "--model" "${HF_MODEL_REPO}" - "--task" "text-generation" - "--recipe" "xnnpack" - "--use_custom_sdpa" - "--use_custom_kv_cache" - "--qlinear" "8da4w" - "--qembedding" "8w" - "--output_dir" ".." - ) - - ${CONDA_RUN} optimum-cli export executorch "${ARGS[@]}" - popd - - # The benchmark app replies on the _llm suffix to determine whether the model is a LLM or not - OUT_ET_MODEL_NAME=${OUT_ET_MODEL_NAME}_llm - mv model.pte ${OUT_ET_MODEL_NAME}.pte - ls -lh "${OUT_ET_MODEL_NAME}.pte" - fi - - zip -j model.zip ${OUT_ET_MODEL_NAME}.pte ${DOWNLOADED_PATH}/tokenizer.* - ls -lh model.zip - mkdir -p "${ARTIFACTS_DIR_NAME}" - mv model.zip "${ARTIFACTS_DIR_NAME}" - elif [[ ${{ matrix.model }} == "llama" ]]; then - # Install requirements for export_llama - PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \ - bash examples/models/llama/install_requirements.sh - - # Test llama2 - if [[ ${{ matrix.config }} == *"xnnpack"* ]]; then - DELEGATE_CONFIG="xnnpack+custom+qe" - elif [[ ${{ matrix.config }} == *"coreml"* ]]; then - DELEGATE_CONFIG="coreml" - elif [[ ${{ matrix.config }} == *"mps"* ]]; then - DELEGATE_CONFIG="mps" - fi - DTYPE="fp32" - PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \ - bash .ci/scripts/test_llama.sh \ - -model "stories110M" \ - -build_tool "${BUILD_MODE}" \ - -dtype "${DTYPE}" \ - -mode "${DELEGATE_CONFIG}" \ - -upload "${ARTIFACTS_DIR_NAME}" - else - PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \ - bash .ci/scripts/test_model.sh \ - "${{ matrix.model }}" \ - "${BUILD_MODE}" \ - "${{ matrix.config }}" \ - "${ARTIFACTS_DIR_NAME}" - fi - echo "::endgroup::" - - build-benchmark-app: - name: build-benchmark-app - uses: pytorch/test-infra/.github/workflows/macos_job.yml@main - needs: - - set-parameters - secrets: inherit - with: - runner: macos-14-xlarge - python-version: '3.11' - submodules: 'recursive' - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - upload-artifact: ios-apps - secrets-env: BUILD_CERTIFICATE_BASE64 EXECUTORCH_BENCHMARK_BUILD_PROVISION_PROFILE_BASE64 KEYCHAIN_PASSWORD - timeout: 90 - script: | - set -eux - - echo "::group::Setting up CI environment" - .ci/scripts/setup-conda.sh - - BUILD_TOOL=cmake - # Setup MacOS 
dependencies as there is no Docker support on MacOS atm - GITHUB_RUNNER=1 PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \ - .ci/scripts/setup-macos.sh --build-tool "${BUILD_TOOL}" - export ARTIFACTS_DIR_NAME=artifacts-to-be-uploaded - - # Setup Apple certificate for iOS development - BUILD_PROVISION_PROFILE_BASE64="${SECRET_EXECUTORCH_BENCHMARK_BUILD_PROVISION_PROFILE_BASE64}" \ - BUILD_CERTIFICATE_BASE64="${SECRET_BUILD_CERTIFICATE_BASE64}" \ - KEYCHAIN_PASSWORD="${SECRET_KEYCHAIN_PASSWORD}" \ - .ci/scripts/setup-ios.sh - - # Install CoreML Backend Requirements - PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \ - backends/apple/coreml/scripts/install_requirements.sh - echo "::endgroup::" - - echo "::group::Build ExecuTorch iOS frameworks" - PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output scripts/build_apple_frameworks.sh - echo "::endgroup::" - - # NB: Although exported models can be copied to this directory and bundled together with the - # app, we don't use this in CI and rely on AWS extra data parameter to make the model and the - # tokenizer available to the benchmark. This decouples the app and the model. We just need to - # create the directory here to pass the build - mkdir -p extension/benchmark/apple/Benchmark/Models - ${CONDA_RUN} --no-capture-output \ - scripts/build_apple_llm_demo.sh ${ARTIFACTS_DIR_NAME} - - upload-benchmark-app: - needs: build-benchmark-app - runs-on: linux.2xlarge - steps: - - name: Download the apps from GitHub - uses: actions/download-artifact@v4 - with: - # The name here needs to match the name of the upload-artifact parameter - name: ios-apps - path: ${{ runner.temp }}/artifacts/ - - - name: Verify the apps - shell: bash - working-directory: ${{ runner.temp }}/artifacts/ - run: | - ls -lah ./ - - - name: Upload the apps to S3 - uses: seemethere/upload-artifact-s3@v5 - with: - s3-bucket: gha-artifacts - s3-prefix: | - ${{ github.repository }}/${{ github.run_id }}/artifacts - retention-days: 14 - if-no-files-found: ignore - path: ${{ runner.temp }}/artifacts/ - - # CHANGE IF this job name 'benchmark-on-device' changed: extract_model_info() in executorch/.github/scripts/extract_benchmark_results.py - benchmark-on-device: - if: always() - needs: - - set-parameters - - prepare-test-specs - - upload-benchmark-app - - export-models - permissions: - id-token: write - contents: read - uses: pytorch/test-infra/.github/workflows/mobile_job.yml@main - strategy: - matrix: ${{ fromJson(needs.set-parameters.outputs.benchmark_configs) }} - fail-fast: false - with: - # Due to scheduling a job may be pushed beyond the default 60m threshold - timeout: 120 - device-type: ios - # For iOS testing, the runner just needs to call AWS Device Farm, so there is no need to run this on macOS - runner: linux.2xlarge - test-infra-ref: '' - # This is the ARN of ExecuTorch project on AWS - project-arn: arn:aws:devicefarm:us-west-2:308535385114:project:02a2cf0f-6d9b-45ee-ba1a-a086587469e6 - device-pool-arn: ${{ matrix.device_arn }} - # Uploaded to S3 from the previous job - ios-ipa-archive: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/Benchmark.ipa - ios-xctestrun-zip: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/Benchmark.xctestrun.zip - test-spec: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/${{ matrix.model }}_${{ matrix.config }}/default-ios-device-farm-appium-test-spec.yml - new-output-format-flag: true 
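For illustration only: the removed apple-perf workflow's prepare-test-specs step fills in the Device Farm test spec by sed-substituting two placeholders and derives a benchmark-config id with sed -e 's/[^A-Za-z0-9._-]/_/g'. The short Python sketch below mirrors that same substitution; the function and variable names are the editor's assumptions and are not part of the patch.

# Illustrative sketch (not part of the patch): mirrors the sed-based templating in the
# removed apple-perf prepare-test-specs step.
import re

def render_test_spec(template_text: str, model_path: str, model: str, config: str):
    # Same sanitization as: sed -e 's/[^A-Za-z0-9._-]/_/g' applied to "<model>_<config>"
    benchmark_config_id = re.sub(r"[^A-Za-z0-9._-]", "_", f"{model}_{config}")
    # Same substitutions as: sed -e 's,{{ model_path }},...,g' -e 's,{{ benchmark_config_id }},...,g'
    spec = template_text.replace("{{ model_path }}", model_path)
    spec = spec.replace("{{ benchmark_config_id }}", benchmark_config_id)
    return spec, benchmark_config_id

For example, model "meta-llama/Llama-3.2-1B" with config "llama3_fb16" sanitizes to the id "meta-llama_Llama-3.2-1B_llama3_fb16", which is the name the workflow used for the uploaded benchmark-config JSON.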
- - upload-benchmark-results: - needs: - - benchmark-on-device - if: always() - runs-on: linux.2xlarge - environment: upload-benchmark-results - permissions: - id-token: write - contents: read - steps: - - uses: actions/checkout@v3 - with: - submodules: false - - - name: Authenticate with AWS - uses: aws-actions/configure-aws-credentials@v4 - with: - role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_upload-benchmark-results - # The max duration enforced by the server side - role-duration-seconds: 18000 - aws-region: us-east-1 - - - name: Setup conda - uses: pytorch/test-infra/.github/actions/setup-miniconda@main - with: - python-version: '3.10' - - - name: Download the list of artifacts from S3 - env: - ARTIFACTS_S3_DIR: s3://gha-artifacts/device_farm/${{ github.run_id }}/${{ github.run_attempt }}/artifacts/ - shell: bash - run: | - set -eux - ${CONDA_RUN} python -mpip install awscli==1.32.18 - - mkdir -p artifacts - pushd artifacts - ${CONDA_RUN} aws s3 sync "${ARTIFACTS_S3_DIR}" . - popd - - ls -lah artifacts - - - name: Download the list of benchmark configs from S3 - env: - BENCHMARK_CONFIGS_DIR: s3://gha-artifacts/${{ github.repository }}/${{ github.run_id }}/artifacts/benchmark-configs/ - shell: bash - run: | - set -eux - mkdir -p benchmark-configs - pushd benchmark-configs - ${CONDA_RUN} aws s3 sync "${BENCHMARK_CONFIGS_DIR}" . - popd - ls -lah benchmark-configs - - - name: Extract the benchmark results JSON - shell: bash - env: - DEVICE_TYPE: ios - run: | - set -eux - - mkdir -p benchmark-results - - for ARTIFACTS_BY_JOB in artifacts/*.json; do - [ -f "${ARTIFACTS_BY_JOB}" ] || break - echo "${ARTIFACTS_BY_JOB}" - ${CONDA_RUN} python .github/scripts/extract_benchmark_results.py \ - --artifacts "${ARTIFACTS_BY_JOB}" \ - --output-dir benchmark-results \ - --app "${DEVICE_TYPE}" \ - --benchmark-configs benchmark-configs - done - - for BENCHMARK_RESULTS in benchmark-results/v3/*.json; do - cat "${BENCHMARK_RESULTS}" - echo - done - - - name: Upload the benchmark results (v3) - uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main - with: - benchmark-results-dir: benchmark-results/v3 - dry-run: false - schema-version: v3 - github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/apple.yml b/.github/workflows/apple.yml index fb3c04d07fb..789af84c1d1 100644 --- a/.github/workflows/apple.yml +++ b/.github/workflows/apple.yml @@ -38,8 +38,21 @@ jobs: id: set_version shell: bash run: | - VERSION="0.8.0.$(TZ='PST8PDT' date +%Y%m%d)" + VERSION="1.1.0.$(TZ='PST8PDT' date +%Y%m%d)" echo "version=$VERSION" >> "$GITHUB_OUTPUT" + - name: Guardrail + if: ${{ (github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && github.ref == 'refs/heads/main' }} + shell: bash + run: | + VERSION="${{ steps.set_version.outputs.version }}" + BRANCH="swiftpm-${VERSION}" + + if git ls-remote --exit-code "https://github.com/${{ github.repository }}" "refs/heads/${BRANCH}" > /dev/null 2>&1; then + echo "Branch '${BRANCH}' already exists!" + echo "Aborting workflow to prevent overwriting S3 binaries. The branch existence indicates this version was already published with specific checksums." + echo "Please delete the remote branch '${BRANCH}' and re-run this workflow." 
+ exit 1 + fi build-demo-ios: name: build-demo-ios diff --git a/.github/workflows/build-wheels-aarch64-linux.yml b/.github/workflows/build-wheels-aarch64-linux.yml index abc378f9061..b8729058ec8 100644 --- a/.github/workflows/build-wheels-aarch64-linux.yml +++ b/.github/workflows/build-wheels-aarch64-linux.yml @@ -32,7 +32,7 @@ jobs: test-infra-ref: main with-cuda: disabled with-rocm: disabled - python-versions: '["3.10", "3.11", "3.12"]' + python-versions: '["3.10", "3.11", "3.12", "3.13"]' build: needs: generate-matrix diff --git a/.github/workflows/build-wheels-linux.yml b/.github/workflows/build-wheels-linux.yml index 8509ba52cb9..a149c4f5df0 100644 --- a/.github/workflows/build-wheels-linux.yml +++ b/.github/workflows/build-wheels-linux.yml @@ -32,7 +32,7 @@ jobs: test-infra-ref: main with-cuda: disabled with-rocm: disabled - python-versions: '["3.10", "3.11", "3.12"]' + python-versions: '["3.10", "3.11", "3.12", "3.13"]' build: needs: generate-matrix diff --git a/.github/workflows/build-wheels-macos.yml b/.github/workflows/build-wheels-macos.yml index 8db10c0335b..16da31ddd6d 100644 --- a/.github/workflows/build-wheels-macos.yml +++ b/.github/workflows/build-wheels-macos.yml @@ -32,7 +32,7 @@ jobs: test-infra-ref: main with-cuda: disabled with-rocm: disabled - python-versions: '["3.10", "3.11", "3.12"]' + python-versions: '["3.10", "3.11", "3.12", "3.13"]' build: needs: generate-matrix diff --git a/.github/workflows/build-wheels-windows.yml b/.github/workflows/build-wheels-windows.yml index 276edfb08d1..7fe6f880878 100644 --- a/.github/workflows/build-wheels-windows.yml +++ b/.github/workflows/build-wheels-windows.yml @@ -2,15 +2,23 @@ name: Build Windows Wheels on: pull_request: + paths: + - .ci/**/* + - .github/workflows/build-wheels-windows.yml + - examples/**/* + - pyproject.toml + - setup.py + tags: + - ciflow/binaries/* push: branches: - nightly - - main - release/* tags: - # NOTE: Binary build pipelines should only get triggered on release candidate builds - # Release candidate tags look like: v1.11.0-rc1 - - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ + # NOTE: Binary build pipelines should only get triggered on release candidate builds + # Release candidate tags look like: v1.11.0-rc1 + - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ + - ciflow/binaries/* workflow_dispatch: permissions: @@ -27,7 +35,7 @@ jobs: test-infra-ref: main with-cuda: disabled with-rocm: disabled - python-versions: '["3.10", "3.11", "3.12"]' + python-versions: '["3.10", "3.11", "3.12", "3.13"]' build: needs: generate-matrix diff --git a/.github/workflows/cuda-perf.yml b/.github/workflows/cuda-perf.yml new file mode 100644 index 00000000000..71e3adf5abc --- /dev/null +++ b/.github/workflows/cuda-perf.yml @@ -0,0 +1,439 @@ +name: cuda-perf + +on: + schedule: + - cron: 0 8 * * * # 12am / 1am PST (8am UTC) + pull_request: + paths: + - .github/workflows/cuda-perf.yml + - .ci/scripts/cuda_benchmark.py + - .ci/scripts/export_model_artifact.sh + - .ci/scripts/test_model_e2e.sh + push: + branches: + - main + paths: + - .github/workflows/cuda-perf.yml + - .ci/scripts/cuda_benchmark.py + - .ci/scripts/export_model_artifact.sh + - .ci/scripts/test_model_e2e.sh + workflow_dispatch: + inputs: + models: + description: Models to be benchmarked (comma-separated HuggingFace model IDs) + required: false + type: string + default: openai/whisper-small + quantizations: + description: Quantization types (comma-separated) + required: false + type: string + default: non-quantized + num_runs: + description: Number of benchmark runs per model + required: 
false + type: string + default: "50" + run_all_models: + description: Run all available models (overrides models input) + required: false + type: boolean + default: false + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: true + +jobs: + set-parameters: + runs-on: ubuntu-22.04 + outputs: + benchmark_configs: ${{ steps.set-parameters.outputs.benchmark_configs }} + steps: + - uses: actions/checkout@v3 + with: + submodules: 'false' + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Set parameters + id: set-parameters + shell: bash + env: + # All available models and quantizations + ALL_MODELS: 'mistralai/Voxtral-Mini-3B-2507,openai/whisper-small,openai/whisper-medium,openai/whisper-large-v3-turbo,google/gemma-3-4b-it' + ALL_QUANTIZATIONS: 'non-quantized,quantized-int4-tile-packed,quantized-int4-weight-only' + NUM_RUNS: ${{ inputs.num_runs || '50' }} + RUN_ALL_MODELS: ${{ inputs.run_all_models || 'false' }} + RANDOM_MODEL: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' && 'true' || 'false' }} + run: | + set -eux + + MODELS="${{ inputs.models }}" + QUANTIZATIONS="${{ inputs.quantizations }}" + + # If run_all_models is true, use all models + if [ "$RUN_ALL_MODELS" = "true" ]; then + MODELS="$ALL_MODELS" + echo "Running all available models: $MODELS" + # For non-schedule events (PR, manual trigger without inputs), randomly select one model and one quantization + elif [ -z "$MODELS" ] && [ "${{ github.event_name }}" != "schedule" ]; then + # Split all models into array + IFS=',' read -ra ALL_MODEL_ARRAY <<< "$ALL_MODELS" + # Randomly select one model + RANDOM_MODEL_INDEX=$((RANDOM % ${#ALL_MODEL_ARRAY[@]})) + MODELS="${ALL_MODEL_ARRAY[$RANDOM_MODEL_INDEX]}" + echo "Randomly selected model for PR/push: $MODELS" + elif [ -z "$MODELS" ]; then + # Schedule event: use all models + MODELS="$ALL_MODELS" + fi + + # If run_all_models is true, use all quantizations + if [ "$RUN_ALL_MODELS" = "true" ]; then + QUANTIZATIONS="$ALL_QUANTIZATIONS" + echo "Running all available quantizations: $QUANTIZATIONS" + elif [ -z "$QUANTIZATIONS" ] && [ "${{ github.event_name }}" != "schedule" ]; then + # Split all quantizations into array + IFS=',' read -ra ALL_QUANT_ARRAY <<< "$ALL_QUANTIZATIONS" + # Randomly select one quantization + RANDOM_QUANT_INDEX=$((RANDOM % ${#ALL_QUANT_ARRAY[@]})) + QUANTIZATIONS="${ALL_QUANT_ARRAY[$RANDOM_QUANT_INDEX]}" + echo "Randomly selected quantization for PR/push: $QUANTIZATIONS" + elif [ -z "$QUANTIZATIONS" ]; then + # Schedule event: use all quantizations + QUANTIZATIONS="$ALL_QUANTIZATIONS" + fi + + # Split models and quantizations into arrays + IFS=',' read -ra MODEL_ARRAY <<< "$MODELS" + IFS=',' read -ra QUANT_ARRAY <<< "$QUANTIZATIONS" + + # If random model is requested (for main branch push), select one random model from the already selected models + if [ "$RANDOM_MODEL" = "true" ] && [ ${#MODEL_ARRAY[@]} -gt 1 ]; then + RANDOM_INDEX=$((RANDOM % ${#MODEL_ARRAY[@]})) + MODELS="${MODEL_ARRAY[$RANDOM_INDEX]}" + MODEL_ARRAY=("$MODELS") + echo "Random model selected for main branch push: $MODELS" + fi + + # Generate benchmark configs + CONFIGS='{"include":[' + FIRST=true + for MODEL in "${MODEL_ARRAY[@]}"; do + for QUANT in "${QUANT_ARRAY[@]}"; do + if [ "$FIRST" = true ]; then + FIRST=false + else + CONFIGS+=',' + fi + # Sanitize 
model name for use in artifact paths + MODEL_SAFE=$(echo "$MODEL" | sed 's/\//_/g') + CONFIGS+="{\"model\":\"$MODEL\",\"quant\":\"$QUANT\",\"model_safe\":\"$MODEL_SAFE\",\"num_runs\":\"$NUM_RUNS\"}" + done + done + CONFIGS+=']}' + + echo "benchmark_configs=$CONFIGS" >> $GITHUB_OUTPUT + echo "Generated benchmark configs:" + echo "$CONFIGS" | python -m json.tool + + export-models: + name: export-models + needs: set-parameters + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read + secrets: inherit + strategy: + matrix: ${{ fromJson(needs.set-parameters.outputs.benchmark_configs) }} + fail-fast: false + with: + timeout: 90 + secrets-env: EXECUTORCH_HF_TOKEN + runner: linux.g5.4xlarge.nvidia.gpu + gpu-arch-type: cuda + gpu-arch-version: "12.6" + use-custom-docker-registry: false + submodules: recursive + upload-artifact: model-${{ matrix.model_safe }}-${{ matrix.quant }} + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + script: | + set -eux + echo "::group::Setup ExecuTorch" + ./install_executorch.sh + echo "::endgroup::" + + echo "::group::Setup Huggingface" + pip install -U "huggingface_hub[cli]<1.0" accelerate + huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN + OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) + pip install git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION} + echo "::endgroup::" + + echo "::group::Exporting model ${{ matrix.model }} with quantization ${{ matrix.quant }}" + OUTPUT_DIR="model_artifacts" + mkdir -p "$OUTPUT_DIR" + + bash .ci/scripts/export_model_artifact.sh cuda "${{ matrix.model }}" "${{ matrix.quant }}" "$OUTPUT_DIR" + + # Move artifacts to RUNNER_ARTIFACT_DIR for upload + mv "$OUTPUT_DIR"/* "${RUNNER_ARTIFACT_DIR}/" + ls -lah "${RUNNER_ARTIFACT_DIR}" + echo "::endgroup::" + + benchmark-cuda: + name: benchmark-cuda + needs: + - set-parameters + - export-models + if: always() + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read + strategy: + matrix: ${{ fromJson(needs.set-parameters.outputs.benchmark_configs) }} + fail-fast: false + with: + timeout: 90 + runner: linux.g5.4xlarge.nvidia.gpu + gpu-arch-type: cuda + gpu-arch-version: "12.6" + use-custom-docker-registry: false + submodules: recursive + download-artifact: model-${{ matrix.model_safe }}-${{ matrix.quant }} + upload-artifact: results-${{ matrix.model_safe }}-${{ matrix.quant }} + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + script: | + set -eux + echo "::group::Setup environment" + ./install_requirements.sh + pip list + echo "::endgroup::" + + echo "::group::Prepare model artifacts" + mkdir -p model_artifacts + cp "${RUNNER_ARTIFACT_DIR}/model.pte" model_artifacts/model.pte + cp "${RUNNER_ARTIFACT_DIR}/aoti_cuda_blob.ptd" model_artifacts/aoti_cuda_blob.ptd + + # Copy additional files if they exist + if [ -f "${RUNNER_ARTIFACT_DIR}/voxtral_preprocessor.pte" ]; then + cp "${RUNNER_ARTIFACT_DIR}/voxtral_preprocessor.pte" model_artifacts/ + fi + if [ -f "${RUNNER_ARTIFACT_DIR}/whisper_preprocessor.pte" ]; then + cp "${RUNNER_ARTIFACT_DIR}/whisper_preprocessor.pte" model_artifacts/ + fi + if [ -f "${RUNNER_ARTIFACT_DIR}/tekken.json" ]; then + cp "${RUNNER_ARTIFACT_DIR}/tekken.json" model_artifacts/ + fi + if [ -f "${RUNNER_ARTIFACT_DIR}/poem.wav" ]; then + cp "${RUNNER_ARTIFACT_DIR}/poem.wav" 
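A side note on the cuda-perf set-parameters step above: the {"include": [...]} matrix is assembled by string concatenation in bash. A rough Python equivalent (illustrative only; names are the editor's assumptions, not part of the workflow) makes the shape of the generated JSON easier to see.

# Illustrative sketch (not part of the workflow): same cross-product and sanitization
# as the bash loop that builds the GitHub Actions matrix for cuda-perf.
import itertools
import json

def build_benchmark_matrix(models, quantizations, num_runs="50"):
    include = [
        {
            "model": model,
            "quant": quant,
            "model_safe": model.replace("/", "_"),  # same as: sed 's/\//_/g'
            "num_runs": str(num_runs),
        }
        for model, quant in itertools.product(models, quantizations)
    ]
    return json.dumps({"include": include})

# Example:
# build_benchmark_matrix(["openai/whisper-small"], ["non-quantized"])
# -> '{"include": [{"model": "openai/whisper-small", "quant": "non-quantized",
#      "model_safe": "openai_whisper-small", "num_runs": "50"}]}'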
model_artifacts/ + fi + if [ -f "${RUNNER_ARTIFACT_DIR}/output.wav" ]; then + cp "${RUNNER_ARTIFACT_DIR}/output.wav" model_artifacts/ + fi + # Copy tokenizer files + for file in tokenizer.json tokenizer_config.json special_tokens_map.json; do + if [ -f "${RUNNER_ARTIFACT_DIR}/$file" ]; then + cp "${RUNNER_ARTIFACT_DIR}/$file" model_artifacts/ + fi + done + + ls -lah model_artifacts/ + echo "::endgroup::" + + echo "::group::Build runner" + bash .ci/scripts/test_model_e2e.sh cuda "${{ matrix.model }}" "${{ matrix.quant }}" model_artifacts + echo "::endgroup::" + + echo "::group::Running benchmark for ${{ matrix.model }} (${{ matrix.quant }}) with ${{ matrix.num_runs }} runs" + export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH + + # Get GPU name using nvidia-smi + GPU_NAME=$(nvidia-smi --query-gpu=name --format=csv,noheader | head -1) + echo "Detected GPU: $GPU_NAME" + + # Get CUDA driver version + CUDA_DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -1) + echo "CUDA Driver Version: $CUDA_DRIVER_VERSION" + + # Create results directory (separate from model artifacts) + RESULTS_DIR="benchmark_results" + mkdir -p "$RESULTS_DIR" + + # Determine model name and runner command based on model + case "${{ matrix.model }}" in + mistralai/Voxtral-Mini-3B-2507) + RUNNER="cmake-out/examples/models/voxtral/voxtral_runner" + PREPROCESSOR="model_artifacts/voxtral_preprocessor.pte" + TOKENIZER="model_artifacts/tekken.json" + AUDIO="model_artifacts/poem.wav" + RUNNER_CMD="$RUNNER --model_path model_artifacts/model.pte --data_path model_artifacts/aoti_cuda_blob.ptd --tokenizer_path $TOKENIZER --audio_path $AUDIO --processor_path $PREPROCESSOR --temperature 0" + MODEL_NAME="voxtral_${{ matrix.quant }}" + ;; + openai/whisper-*) + RUNNER="cmake-out/examples/models/whisper/whisper_runner" + PREPROCESSOR="model_artifacts/whisper_preprocessor.pte" + AUDIO="model_artifacts/output.wav" + RUNNER_CMD="$RUNNER --model_path model_artifacts/model.pte --data_path model_artifacts/aoti_cuda_blob.ptd --tokenizer_path model_artifacts/ --audio_path $AUDIO --processor_path $PREPROCESSOR --temperature 0" + MODEL_NAME=$(echo "${{ matrix.model }}" | sed 's/openai\///')_${{ matrix.quant }} + ;; + google/gemma-3-4b-it) + RUNNER="cmake-out/examples/models/gemma3/gemma3_e2e_runner" + IMAGE="docs/source/_static/img/et-logo.png" + RUNNER_CMD="$RUNNER --model_path model_artifacts/model.pte --data_path model_artifacts/aoti_cuda_blob.ptd --tokenizer_path model_artifacts/ --image_path $IMAGE --temperature 0" + MODEL_NAME="gemma3_${{ matrix.quant }}" + ;; + *) + echo "Error: Unsupported model '${{ matrix.model }}'" + exit 1 + ;; + esac + + # Run benchmark using cuda_benchmark.py + python .ci/scripts/cuda_benchmark.py \ + --runner_command "$RUNNER_CMD" \ + --model_name "$MODEL_NAME" \ + --num_runs "${{ matrix.num_runs }}" \ + --output_json "$RESULTS_DIR/benchmark_results.json" \ + --output_v3 "$RESULTS_DIR/benchmark_results_v3.json" \ + --model "${{ matrix.model }}" \ + --quantization "${{ matrix.quant }}" \ + --git_sha "${{ github.sha }}" \ + --workflow_run_id "${{ github.run_id }}" \ + --workflow_run_url "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}" \ + --gpu_name "$GPU_NAME" \ + --cuda_driver_version "$CUDA_DRIVER_VERSION" + + # Save additional metadata + cat > "$RESULTS_DIR/metadata.json" < /dev/null; then + echo "" + echo -e "\e[1m\e[36mYou can reproduce these results locally by using \`lintrunner --take MYPY\`. 
(If you don't get the same results, run \'lintrunner init\' to update your local linter)\e[0m" echo -e "\e[1m\e[36mSee https://github.com/pytorch/pytorch/wiki/lintrunner for setup instructions.\e[0m" RC=1 fi + + # Use jq to massage the JSON lint output into GitHub Actions workflow commands. + jq --raw-output \ + '"::\(if .severity == "advice" or .severity == "disabled" then "warning" else .severity end) file=\(.path),line=\(.line),col=\(.char),title=\(.code) \(.name)::" + (.description | gsub("\\n"; "%0A"))' \ + lint.json || true + + exit $RC + + lintrunner: + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + needs: [get-changed-files] + permissions: + id-token: write + contents: read + with: + runner: linux.2xlarge + docker-image: ci-image:executorch-ubuntu-22.04-linter + submodules: false + fetch-depth: 0 + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 90 + script: | + # The generic Linux job chooses to use base env, not the one setup by the image + CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + conda activate "${CONDA_ENV}" + + # Not sure why this isn't set up in the docker + # image. lintrunner-mypy seems to work because setup-linux.sh + # does this as part of install_executorch. + pip install -r requirements-dev.txt + + CACHE_DIRECTORY="/tmp/.lintbin" + # Try to recover the cached binaries + if [[ -d "${CACHE_DIRECTORY}" ]]; then + # It's ok to fail this as lintrunner init would download these binaries + # again if they do not exist + cp -r "${CACHE_DIRECTORY}" . || true + fi RC=0 - # Run lintrunner on all files - if ! lintrunner --force-color --all-files --tee-json=lint.json 2> /dev/null; then + CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}" + if [ "$CHANGED_FILES" = '*' ]; then + LINTRUNNER_FILES="--all-files" + else + LINTRUNNER_FILES="${CHANGED_FILES}" + fi + if ! lintrunner --force-color ${LINTRUNNER_FILES} --skip MYPY --tee-json=lint.json 2> /dev/null; then echo "" - echo -e "\e[1m\e[36mYou can reproduce these results locally by using \`lintrunner\`. (If you don't get the same results, run \'lintrunner init\' to update your local linter)\e[0m" + echo -e "\e[1m\e[36mYou can reproduce these results locally by using \`lintrunner --skip MYPY\`. 
(If you don't get the same results, run \'lintrunner init\' to update your local linter)\e[0m" echo -e "\e[1m\e[36mSee https://github.com/pytorch/pytorch/wiki/lintrunner for setup instructions.\e[0m" RC=1 fi @@ -81,21 +143,28 @@ jobs: ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} timeout: 90 script: | - FILES_NEEDS_FORMAT=$(/opt/google-java-format -n \ - extension/android/executorch_android/src/main/java/org/pytorch/executorch/*.java \ - extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/*.java \ - extension/android/executorch_android/src/main/java/org/pytorch/executorch/annotations/*.java \ - extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/*.java \ - examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/*.java \ - examples/demo-apps/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/*.java \ - extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/*.java \ - extension/benchmark/android/benchmark/app/src/androidTest/java/org/pytorch/minibench/*.java) + FILES_NEEDS_FORMAT=$(find extension/android/executorch_android/src/main/java/org/pytorch/executorch \ + extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm \ + extension/android/executorch_android/src/main/java/org/pytorch/executorch/annotations \ + extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch \ + extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench \ + extension/benchmark/android/benchmark/app/src/androidTest/java/org/pytorch/minibench \ + -type f -name "*.java" 2>/dev/null | \ + xargs -r /opt/google-java-format -n) + if [ -n "$FILES_NEEDS_FORMAT" ]; then - echo "Warning: The following files need formatting. Please use google-java-format." 
- echo "Use a binary from https://github.com/google/google-java-format/releases/" - echo "For example:" - echo "wget https://github.com/google/google-java-format/releases/download/v1.23.0/google-java-format_linux-x86-64" - echo "chmod +x google-java-format_linux-x86-64" - echo "./google-java-format_linux-x86-64 -i $FILES_NEEDS_FORMAT" + echo "Warning: The following files need formatting:" + echo "$FILES_NEEDS_FORMAT" + echo "" + echo "Please use google-java-format from https://github.com/google/google-java-format/releases/" + echo "" + echo "To fix, run one of these commands:" + echo " # Using xargs (recommended):" + echo " find -type f -name '*.java' | xargs google-java-format -i" + echo "" + echo " # Or format specific files:" + echo "$FILES_NEEDS_FORMAT" | while IFS= read -r file; do + echo " google-java-format -i \"$file\"" + done exit 1 fi diff --git a/.github/workflows/metal.yml b/.github/workflows/metal.yml new file mode 100644 index 00000000000..92351883e8f --- /dev/null +++ b/.github/workflows/metal.yml @@ -0,0 +1,131 @@ +name: Test Metal Backend + +on: + pull_request: + push: + branches: + - main + - release/* + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: false + +jobs: + test-metal-builds: + name: test-executorch-metal-build + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + with: + runner: macos-m2-stable + python-version: '3.11' + submodules: 'recursive' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 90 + script: | + set -eux + + echo "::group::Test ExecuTorch Metal build" + PYTHON_EXECUTABLE=python CMAKE_ARGS="-DEXECUTORCH_BUILD_METAL=ON" ${CONDA_RUN} --no-capture-output ./install_executorch.sh + echo "::endgroup::" + + export-model-metal-artifact: + name: export-model-metal-artifact + # Skip this job if the pull request is from a fork (HuggingFace secrets are not available) + if: github.event.pull_request.head.repo.full_name == github.repository || github.event_name != 'pull_request' + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + secrets: inherit + strategy: + fail-fast: false + matrix: + model: + - repo: "mistralai" + name: "Voxtral-Mini-3B-2507" + - repo: "openai" + name: "whisper-small" + - repo: "openai" + name: "whisper-large-v3-turbo" + quant: + - "non-quantized" + with: + runner: macos-m2-stable + python-version: '3.11' + submodules: 'recursive' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 90 + secrets-env: EXECUTORCH_HF_TOKEN + upload-artifact: ${{ matrix.model.repo }}-${{ matrix.model.name }}-metal-${{ matrix.quant }} + script: | + set -eux + + echo "::group::Setup Huggingface" + ${CONDA_RUN} pip install -U "huggingface_hub[cli]<1.0" accelerate + ${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN + echo "::endgroup::" + + echo "::group::Setup Optimum-ExecuTorch" + OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) + echo "Using optimum-executorch version: ${OPTIMUM_ET_VERSION}" + ${CONDA_RUN} pip install git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION} + echo "::endgroup::" + + echo "::group::Setup ExecuTorch" + PYTHON_EXECUTABLE=python ${CONDA_RUN} ./install_executorch.sh + echo "::endgroup::" + + echo "::group::Pip List" + ${CONDA_RUN} pip list + echo 
"::endgroup::" + + ${CONDA_RUN} bash .ci/scripts/export_model_artifact.sh metal "${{ matrix.model.repo }}/${{ matrix.model.name }}" "${{ matrix.quant }}" "${RUNNER_ARTIFACT_DIR}" + + test-model-metal-e2e: + name: test-model-metal-e2e + needs: export-model-metal-artifact + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + strategy: + fail-fast: false + matrix: + model: + - repo: "mistralai" + name: "Voxtral-Mini-3B-2507" + - repo: "openai" + name: "whisper-small" + - repo: "openai" + name: "whisper-large-v3-turbo" + quant: + - "non-quantized" + with: + runner: macos-m2-stable + python-version: '3.11' + submodules: 'recursive' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 90 + download-artifact: ${{ matrix.model.repo }}-${{ matrix.model.name }}-metal-${{ matrix.quant }} + script: | + set -eux + + echo "::group::Print machine info" + uname -a + if [ $(uname -s) == Darwin ]; then + sw_vers + # Print RAM in GB + RAM_BYTES=$(sysctl -n hw.memsize) + RAM_GB=$(echo "scale=2; $RAM_BYTES/1024/1024/1024" | bc) + echo "Available RAM (GB): $RAM_GB" + sysctl machdep.cpu.brand_string + sysctl machdep.cpu.core_count + # Print number of GPU cores (Apple Silicon) + if command -v system_profiler &> /dev/null; then + GPU_CORES=$(system_profiler SPDisplaysDataType | awk '/Total Number of Cores/ {print $5; exit}') + if [ -z "$GPU_CORES" ]; then + # Fallback: try to parse "Core Count" from Apple GPU section + GPU_CORES=$(system_profiler SPDisplaysDataType | awk '/Core Count/ {print $3; exit}') + fi + echo "GPU Cores: ${GPU_CORES:-Unknown}" + else + echo "system_profiler not available, cannot determine GPU cores." + fi + fi + echo "::endgroup::" + + ${CONDA_RUN} bash .ci/scripts/test_model_e2e.sh metal "${{ matrix.model.repo }}/${{ matrix.model.name }}" "${{ matrix.quant }}" "${RUNNER_ARTIFACT_DIR}" diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index c220b371c0a..f2aa4a3511e 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -36,51 +36,37 @@ jobs: uses: ./.github/workflows/_link_check.yml with: ref: ${{ github.sha }} - - backend-test-linux: + + test-static-hf-llm-qnn-linux: + name: test-static-hf-llm-qnn-linux uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main - strategy: - fail-fast: false - matrix: - flow: [ - qnn, qnn_16a16w, qnn_16a8w, qnn_16a4w, qnn_16a4w_block, qnn_8a8w, - vulkan, vulkan_static_int8_per_channel, - xnnpack, xnnpack_dynamic_int8_per_channel, xnnpack_static_int8_per_channel, xnnpack_static_int8_per_tensor - ] - suite: [models, operators] - with: - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - runner: linux.4xlarge.memory - docker-image: ci-image:executorch-ubuntu-22.04-clang12 - submodules: recursive - timeout: 120 - upload-artifact: test-report-${{ matrix.flow }}-${{ matrix.suite }} - script: | - set -eux - - source .ci/scripts/test_backend_linux.sh "${{ matrix.suite }}" "${{ matrix.flow }}" "${RUNNER_ARTIFACT_DIR}" - - backend-test-macos: - uses: pytorch/test-infra/.github/workflows/macos_job.yml@main permissions: id-token: write contents: read strategy: - fail-fast: false matrix: - flow: [coreml, coreml_static_int8] - suite: [models, operators] + task: [smollm2_135m] + fail-fast: false with: + runner: linux.24xlarge + docker-image: ci-image:executorch-ubuntu-22.04-qnn-sdk + submodules: 'recursive' ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || 
github.sha }} - runner: macos-m1-stable - python-version: 3.12 - submodules: recursive - timeout: 120 - upload-artifact: test-report-${{ matrix.flow }}-${{ matrix.suite }} + timeout: 900 script: | - set -eux + # The generic Linux job chooses to use base env, not the one setup by the image + CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + conda activate "${CONDA_ENV}" + + BUILD_TOOL="cmake" + + PYTHON_EXECUTABLE=python bash .ci/scripts/setup-qnn-deps.sh + PYTHON_EXECUTABLE=python bash .ci/scripts/build-qnn-sdk.sh + + # Setup executorch + PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool "${BUILD_TOOL}" - # This is needed to get the prebuilt PyTorch wheel from S3 - ${CONDA_RUN} --no-capture-output pip install awscli==1.37.21 + # Setup install_requirements for llama + PYTHON_EXECUTABLE=python bash examples/models/llama/install_requirements.sh - source .ci/scripts/test_backend_macos.sh "${{ matrix.suite }}" "${{ matrix.flow }}" "${RUNNER_ARTIFACT_DIR}" + PYTHON_EXECUTABLE=python bash .ci/scripts/test_qnn_static_llm.sh ${{ matrix.task }} diff --git a/.github/workflows/pending_user_response.py b/.github/workflows/pending_user_response.py new file mode 100644 index 00000000000..846e488e015 --- /dev/null +++ b/.github/workflows/pending_user_response.py @@ -0,0 +1,82 @@ +import datetime +import os + +from github import Github + +REPO_NAME = "pytorch/executorch" +LABEL = "need-user-input" +REMINDER_MARKER = "" +REMINDER_COMMENT = ( + f"{REMINDER_MARKER}\nHi @{0}, this issue/PR has been marked as 'need-user-input'. " + "Please respond or provide input. If we don't hear back in 30 days, this will be closed." +) +CLOSE_COMMENT = ( + f"{REMINDER_MARKER}\nClosing due to no response after 30 days. " + "If you still need help, feel free to re-open or comment again!" 
+) +DAYS_BEFORE_REMINDER = 30 +DAYS_BEFORE_CLOSE = 30 +REMINDER_COOLDOWN_DAYS = 7 # Don't post another reminder within 7 days + + +def main(): + g = Github(os.environ["GH_TOKEN"]) + repo = g.get_repo(REPO_NAME) + + print("[VALIDATION] Would connect to Github and fetch repo:", REPO_NAME) + issues = repo.get_issues(state="open", labels=[LABEL]) + print(f"[VALIDATION] Would fetch open issues with label '{LABEL}'.") + + now = datetime.datetime.utcnow() + + for issue in issues: + print(f"[VALIDATION] Would fetch comments for issue/PR #{issue.number}.") + comments = [] # Replace with mock comments if needed + last_comment = comments[-1] if comments else None + + # Find automation comments + auto_comments = [c for c in comments if REMINDER_MARKER in c.body] + user_comments = [c for c in comments if REMINDER_MARKER not in c.body] + + # ---- REMINDER LOGIC ---- + # Only remind if NO reminder in last 7 days + recent_auto_reminder = any( + (now - c.created_at).days < REMINDER_COOLDOWN_DAYS for c in auto_comments + ) + + if not auto_comments: + if ( + last_comment + and (now - last_comment.created_at).days >= DAYS_BEFORE_REMINDER + ): + user = issue.user.login + print(f"[VALIDATION] Would remind {user} on issue/PR #{issue.number}") + elif auto_comments and not recent_auto_reminder: + # Only post new reminder if last was > REMINDER_COOLDOWN_DAYS ago + last_auto = auto_comments[-1] + user = issue.user.login + if (now - last_auto.created_at).days >= REMINDER_COOLDOWN_DAYS: + print( + f"[VALIDATION] Would remind {user} again on issue/PR #{issue.number}" + ) + + # ---- EXISTING CLOSE/REMOVE LABEL LOGIC ---- + if auto_comments: + last_auto = auto_comments[-1] + user_responded = any( + c.created_at > last_auto.created_at and c.user.login == issue.user.login + for c in user_comments + ) + if not user_responded: + if (now - last_auto.created_at).days >= DAYS_BEFORE_CLOSE: + print( + f"[VALIDATION] Would close issue/PR #{issue.number} due to inactivity." + ) + else: + print( + f"[VALIDATION] Would remove label from issue/PR #{issue.number} after user response." 
+ ) + + +if __name__ == "__main__": + main() diff --git a/.github/workflows/pending_user_response.yml b/.github/workflows/pending_user_response.yml new file mode 100644 index 00000000000..4c431c7d5cb --- /dev/null +++ b/.github/workflows/pending_user_response.yml @@ -0,0 +1,26 @@ +name: Needs User Input Automation + +on: + schedule: + - cron: '0 8 * * 1' # runs every Monday at 8:00 UTC + workflow_dispatch: + +jobs: + needs-user-input: + runs-on: ubuntu-latest + steps: + - name: Checkout repo + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install dependencies + run: pip install PyGithub + + - name: Run needs-user-input script + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: python .github/scripts/pending_user_response.py diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 6e9169132e5..6f83f7b45e6 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -13,6 +13,34 @@ concurrency: cancel-in-progress: true jobs: + test-qnn-wheel-packages-linux: + name: test-qnn-wheel-packages-linux + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + if: false + permissions: + id-token: write + contents: read + strategy: + fail-fast: false + matrix: + python-version: [ "3.10", "3.11", "3.12", "3.13" ] + with: + runner: linux.2xlarge + docker-image: ci-image:executorch-ubuntu-22.04-qnn-sdk + submodules: 'recursive' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 180 + script: | + # The generic Linux job chooses to use base env, not the one setup by the image + CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + conda activate "${CONDA_ENV}" + + # Create a clean env for each python version + conda create -y -n test_env_${{ matrix.python-version }} python=${{ matrix.python-version }} + conda activate test_env_${{ matrix.python-version }} + + PYTHON_EXECUTABLE=python bash .ci/scripts/test_wheel_package_qnn.sh "${{ matrix.python-version }}" + test-setup-linux-gcc: name: test-setup-linux-gcc uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main @@ -23,7 +51,7 @@ jobs: fail-fast: false with: runner: linux.2xlarge - docker-image: ci-image:executorch-ubuntu-22.04-gcc9 + docker-image: ci-image:executorch-ubuntu-22.04-gcc11 submodules: 'recursive' ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} timeout: 90 @@ -259,15 +287,20 @@ jobs: # Test selective build PYTHON_EXECUTABLE=python bash examples/selective_build/test_selective_build.sh "${BUILD_TOOL}" - test-llava-runner-linux: - name: test-llava-runner-linux + test-multimodal-linux: + if: ${{ !github.event.pull_request.head.repo.fork }} + name: test-multimodal-linux uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main permissions: id-token: write contents: read + secrets: inherit strategy: fail-fast: false + matrix: + model: ["gemma3-4b"] # llava gives segfault so not covering. 
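One remark on the pending_user_response.py script added above: REMINDER_COMMENT uses @{0} inside an f-string, so it renders a literal "@0" rather than a username. Since the script currently only prints validation messages this is harmless, but if the comment text is ever posted for real, filling the mention at use time avoids it. A small sketch follows; the template and function names here are the editor's, not the script's.

# Sketch only: fill the mention when posting, instead of via an f-string placeholder.
REMINDER_TEMPLATE = (
    "{marker}\nHi @{user}, this issue/PR has been marked as 'need-user-input'. "
    "Please respond or provide input. If we don't hear back in 30 days, this will be closed."
)

def reminder_comment(marker: str, user: str) -> str:
    # e.g. reminder_comment(REMINDER_MARKER, issue.user.login)
    return REMINDER_TEMPLATE.format(marker=marker, user=user)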
with: + secrets-env: EXECUTORCH_HF_TOKEN runner: linux.24xlarge docker-image: ci-image:executorch-ubuntu-22.04-clang12 submodules: 'recursive' @@ -278,17 +311,20 @@ jobs: CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") conda activate "${CONDA_ENV}" + echo "::group::Setup ExecuTorch" PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool "cmake" + echo "::endgroup::" - # install Llava requirements - bash examples/models/llama/install_requirements.sh - bash examples/models/llava/install_requirements.sh - - # run python unittest - python -m unittest examples.models.llava.test.test_llava + echo "::group::Setup Huggingface" + pip install -U "huggingface_hub[cli]<1.0" accelerate + huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN + OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) + pip install git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION} + echo "::endgroup::" - # run e2e (export, tokenizer and runner) - PYTHON_EXECUTABLE=python bash .ci/scripts/test_llava.sh + echo "::group::Test ${{ matrix.model }}" + python .ci/scripts/test_huggingface_optimum_model.py --model ${{ matrix.model }} --quantize --recipe xnnpack + echo "::endgroup::" test-moshi-linux: name: test-moshi-linux @@ -305,6 +341,7 @@ jobs: ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} timeout: 90 script: | + set -eux # The generic Linux job chooses to use base env, not the one setup by the image CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") conda activate "${CONDA_ENV}" @@ -316,6 +353,7 @@ jobs: # reinstall executorch bash ./install_executorch.sh --minimal + pip list # run python unittest python -m unittest examples.models.moshi.mimi.test_mimi @@ -353,7 +391,7 @@ jobs: fail-fast: false with: runner: linux.2xlarge - docker-image: ci-image:executorch-ubuntu-22.04-gcc9 + docker-image: ci-image:executorch-ubuntu-22.04-gcc9-nopytorch submodules: 'recursive' ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} timeout: 90 @@ -362,15 +400,16 @@ jobs: CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") conda activate "${CONDA_ENV}" - ./install_requirements.sh --use-pt-pinned-commit + ./install_requirements.sh # build module for executorch.extension.pybindings.portable_lib bash test/build_size_test.sh strip cmake-out/test/size_test output=$(ls -la cmake-out/test/size_test) arr=($output) size=${arr[4]} - # threshold=48120 on devserver with gcc11.4 + # threshold=48120 on devserver with gcc9 # todo(lfq): update once binary size is below 50kb. 
+ # Note: using gcc9-nopytorch container with pinned nightly PyTorch threshold="63776" if [[ "$size" -le "$threshold" ]]; then echo "Success $size <= $threshold" @@ -530,6 +569,8 @@ jobs: id-token: write contents: read strategy: + matrix: + task: [stories_110m, stories_260k_bc] fail-fast: false with: runner: linux.2xlarge @@ -553,8 +594,7 @@ jobs: # Setup install_requirements for llama PYTHON_EXECUTABLE=python bash examples/models/llama/install_requirements.sh - # Test static llama weight sharing and accuracy - PYTHON_EXECUTABLE=python bash .ci/scripts/test_qnn_static_llama.sh + PYTHON_EXECUTABLE=python bash .ci/scripts/test_qnn_static_llm.sh ${{ matrix.task }} test-qnn-models-linux: name: test-qnn-models-linux @@ -596,11 +636,14 @@ jobs: # The generic Linux job chooses to use base env, not the one setup by the image CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") conda activate "${CONDA_ENV}" - + echo "::group::Setup ExecuTorch" PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool "cmake" + echo "::endgroup::" + echo "::group::Setup requirements" # install phi-3-mini requirements bash examples/models/phi-3-mini/install_requirements.sh + echo "::endgroup::" # run e2e (export, tokenizer and runner) PYTHON_EXECUTABLE=python bash .ci/scripts/test_phi_3_mini.sh Release @@ -687,8 +730,8 @@ jobs: # run llama runner in eager mode PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama_runner_eager.sh - test-llama-lora-linux: - name: test-llama-lora-linux + test-lora-linux: + name: test-lora-linux uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main permissions: id-token: write @@ -711,11 +754,8 @@ jobs: # Install llama requirements bash examples/models/llama/install_requirements.sh - # install a recent version of torchtune. 
- PYTHON_EXECUTABLE=python python -m pip install torchtune==0.7.0.dev20250730 --extra-index-url https://download.pytorch.org/whl/nightly/cpu - # run llama runner in eager mode - PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama_lora.sh + PYTHON_EXECUTABLE=python bash .ci/scripts/test_lora.sh test-mediatek-models-linux: name: test-mediatek-models-linux @@ -754,7 +794,7 @@ jobs: fail-fast: false with: runner: linux.2xlarge - docker-image: ci-image:executorch-ubuntu-22.04-gcc9 + docker-image: ci-image:executorch-ubuntu-22.04-gcc11 submodules: 'recursive' ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} timeout: 90 @@ -821,15 +861,24 @@ jobs: # Install Node.js and Emscripten source .ci/scripts/setup-emscripten.sh + export PNPM_VERSION=10.24.0 + + curl -fsSL https://get.pnpm.io/install.sh | env PNPM_VERSION=$PNPM_VERSION SHELL="$(which bash)" sh - + + export PNPM_HOME="$HOME/.local/share/pnpm" + export PATH="$PNPM_HOME:$PATH" + + pnpm --version + # Test selective build bash scripts/build_wasm_tests.sh ${{ matrix.enable-etdump }} # Install Jest cd cmake-out-wasm/extension/wasm/test - npm install --save-dev jest + pnpm add -D jest@30.2.0 --ignore-scripts # Run unit test - npm test + pnpm test unittest-nxp-neutron: uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main @@ -857,6 +906,7 @@ jobs: # Install test requirements pip install -r backends/nxp/requirements-tests-pypi.txt pip install -r backends/nxp/requirements-tests-eiq.txt + PYTHON_EXECUTABLE=python bash examples/nxp/setup.sh # Run pytest PYTHON_EXECUTABLE=python bash backends/nxp/run_unittests.sh @@ -868,11 +918,15 @@ jobs: test-samsung-models-linux: name: test-samsung-models-linux + # Skip this job if the pull request is from a fork (secrets are not available) + if: github.event.pull_request.head.repo.full_name == github.repository || github.event_name != 'pull_request' uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main permissions: id-token: write contents: read + secrets: inherit with: + secrets-env: SAMSUNG_AI_LITECORE_KEY runner: linux.2xlarge docker-image: ci-image:executorch-ubuntu-22.04-clang12-android submodules: 'recursive' @@ -889,6 +943,7 @@ jobs: PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool "cmake" # Setup Samsung SDK (AI Lite Core) and install enn backend + export SAMSUNG_AI_LITECORE_KEY=$SECRET_SAMSUNG_AI_LITECORE_KEY source .ci/scripts/setup-samsung-linux-deps.sh # Test models serially @@ -897,6 +952,12 @@ jobs: python -m executorch.examples.samsung.aot_compiler --model_name=$model -c E9955 done + # Test quant models + model_scripts="deeplab_v3 edsr inception_v3 inception_v4 mobilenet_v2 mobilenet_v3 resnet18 resnet50 vit wav2letter" + for m_script in $model_scripts; do + python -m executorch.examples.samsung.scripts.${m_script} -c e9955 -p A8W8 + done + # Test ops python -m unittest discover -s backends/samsung/test/ops -p "test_*.py" @@ -931,11 +992,16 @@ jobs: PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_model.sh --build # Test models serially - models="mv2 mv3 edsr resnet18 resnet50 dl3" + models="mv2 mv3 edsr resnet18 resnet50 dl3 w2l ic3 ic4" for model in $models; do python -m examples.vulkan.export --model_name=$model --test done + # For selected vision models, test with dynamic shapes + models="mv2 resnet18 resnet50 ic3 densenet161" + for model in $models; do + python -m examples.vulkan.export --model_name=$model --test -d + done test-vulkan-operators-linux: name: test-vulkan-operators-linux @@ 
-970,6 +1036,8 @@ jobs: ./cmake-out/backends/vulkan/test/custom_ops/q8csw_conv2d ./cmake-out/backends/vulkan/test/custom_ops/q4gsw_linear ./cmake-out/backends/vulkan/test/custom_ops/choose_qparams_per_row + ./cmake-out/backends/vulkan/test/custom_ops/qdq8ta_conv2d_activations + ./cmake-out/backends/vulkan/test/custom_ops/q8ta_q8ta_q8to_add # "Classic" Operator tests PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_op.sh --build diff --git a/.github/workflows/test-backend-arm.yml b/.github/workflows/test-backend-arm.yml new file mode 100644 index 00000000000..638d5a2079f --- /dev/null +++ b/.github/workflows/test-backend-arm.yml @@ -0,0 +1,32 @@ +name: Test ARM Backend + +on: + schedule: + - cron: 0 2 * * * + push: + branches: + - release/* + tags: + - ciflow/nightly/* + pull_request: + paths: + - .github/workflows/test-backend-arm.yml + - .github/workflows/_test_backend.yml + - .ci/scripts/test_backend.sh + - backends/test/suite/flow.py + - backends/test/suite/flows/arm.py + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}--${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + +jobs: + test-arm: + uses: ./.github/workflows/_test_backend.yml + with: + backend: arm + flows: '["arm_tosa_fp", "arm_tosa_int", "arm_ethos_u55", "arm_ethos_u85", "arm_vgf_fp", "arm_vgf_int"]' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 120 + run-linux: true diff --git a/.github/workflows/test-backend-coreml.yml b/.github/workflows/test-backend-coreml.yml new file mode 100644 index 00000000000..247f9576595 --- /dev/null +++ b/.github/workflows/test-backend-coreml.yml @@ -0,0 +1,29 @@ +name: Test CoreML Backend + +on: + schedule: + - cron: 0 2 * * * + push: + branches: + - release/* + tags: + - ciflow/nightly/* + pull_request: + paths: + - .github/workflows/test-backend-coreml.yml + - .github/workflows/_test_backend.yml + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}--${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + +jobs: + test-coreml: + uses: ./.github/workflows/_test_backend.yml + with: + backend: coreml + flows: '["coreml", "coreml_static_int8"]' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 120 + run-macos: true diff --git a/.github/workflows/test-backend-qnn.yml b/.github/workflows/test-backend-qnn.yml new file mode 100644 index 00000000000..907c4d2dac0 --- /dev/null +++ b/.github/workflows/test-backend-qnn.yml @@ -0,0 +1,30 @@ +name: Test QNN Backend + +on: + schedule: + - cron: 0 2 * * * + push: + branches: + - release/* + tags: + - ciflow/nightly/* + pull_request: + paths: + - .github/workflows/test-backend-qnn.yml + - .github/workflows/_test_backend.yml + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}--${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + +jobs: + test-qnn: + uses: ./.github/workflows/_test_backend.yml + with: + backend: qnn + flows: '["qnn", "qnn_16a16w", "qnn_16a8w", "qnn_16a4w", "qnn_16a4w_block", "qnn_8a8w"]' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 120 + run-linux: true + runner-linux: linux.8xlarge.memory diff --git a/.github/workflows/test-backend-vulkan.yml 
b/.github/workflows/test-backend-vulkan.yml new file mode 100644 index 00000000000..cb2478fc825 --- /dev/null +++ b/.github/workflows/test-backend-vulkan.yml @@ -0,0 +1,29 @@ +name: Test Vulkan Backend + +on: + schedule: + - cron: 0 2 * * * + push: + branches: + - release/* + tags: + - ciflow/nightly/* + pull_request: + paths: + - .github/workflows/test-backend-vulkan.yml + - .github/workflows/_test_backend.yml + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}--${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + +jobs: + test-vulkan: + uses: ./.github/workflows/_test_backend.yml + with: + backend: vulkan + flows: '["vulkan", "vulkan_static_int8_per_channel"]' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 120 + run-linux: true diff --git a/.github/workflows/test-backend-xnnpack.yml b/.github/workflows/test-backend-xnnpack.yml new file mode 100644 index 00000000000..086c9625a38 --- /dev/null +++ b/.github/workflows/test-backend-xnnpack.yml @@ -0,0 +1,29 @@ +name: Test XNNPACK Backend + +on: + schedule: + - cron: 0 2 * * * + push: + branches: + - release/* + tags: + - ciflow/nightly/* + pull_request: + paths: + - .github/workflows/test-backend-xnnpack.yml + - .github/workflows/_test_backend.yml + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}--${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + +jobs: + test-xnnpack: + uses: ./.github/workflows/_test_backend.yml + with: + backend: xnnpack + flows: '["xnnpack", "xnnpack_dynamic_int8_per_channel", "xnnpack_static_int8_per_channel", "xnnpack_static_int8_per_tensor"]' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 120 + run-linux: true diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 975a8ebbb30..eb907db3d73 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -100,7 +100,7 @@ jobs: # cd $ZEPHYR_PROJ_ROOT/modules/lib/executorch # install_executorch # .ci/scripts/setup-arm-baremetal-tools.sh --target-toolchain zephyr -# source examples/arm/ethos-u-scratch/setup_path.sh +# source examples/arm/arm-scratch/setup_path.sh # source $ZEPHYR_PROJ_ROOT/zephyr/zephyr-env.sh # # # Get the model as PTE @@ -289,6 +289,8 @@ jobs: - test_arm_baremetal: test_models_ethos-u55 - test_arm_baremetal: test_models_ethos-u85 - test_arm_baremetal: test_smaller_stories_llama + - test_arm_baremetal: test_memory_allocation + - test_arm_baremetal: test_model_smollm2-135M fail-fast: false with: runner: linux.2xlarge.memory @@ -315,6 +317,40 @@ jobs: # Test test_arm_baremetal.sh with test backends/arm/test/test_arm_baremetal.sh "${ARM_TEST}" + test-arm-backend-vkml: + name: test-arm-backend-vkml + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read + strategy: + matrix: + include: + - test_arm_baremetal: test_pytest_ops_vkml + fail-fast: false + with: + runner: linux.2xlarge.memory + docker-image: ci-image:executorch-ubuntu-22.04-arm-sdk + submodules: 'recursive' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 120 + script: | + # The generic Linux job chooses to use base env, not the one setup by the image + CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + conda 
activate "${CONDA_ENV}" + source .ci/scripts/utils.sh + install_executorch "--use-pt-pinned-commit" + + .ci/scripts/setup-arm-baremetal-tools.sh --disable-ethos-u-deps --enable-mlsdk-deps --install-mlsdk-deps-with-pip + + # Increase number of files user can monitor to bypass buck failures. + # Hopefully this is high enough for this setup. + sudo sysctl fs.inotify.max_user_watches=1048576 # 1024 * 1024 + + ARM_TEST=${{ matrix.test_arm_baremetal }} + + backends/arm/test/test_arm_baremetal.sh "${ARM_TEST}" + test-arm-cortex-m-size-test: name: test-arm-cortex-m-size-test uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main @@ -345,7 +381,7 @@ jobs: elif [[ ${{ matrix.os}} == "zephyr-preset" ]]; then setup_script_args="--target-toolchain zephyr" toolchain_prefix=arm-zephyr-eabi- - threshold="135168" # 132 KiB + threshold="136000" # 136 KiB toolchain_cmake=examples/zephyr/x86_64-linux-arm-zephyr-eabi-gcc.cmake else echo "Fail unsupport OS selection ${{ matrix.os }}" @@ -355,7 +391,7 @@ jobs: source .ci/scripts/utils.sh install_executorch "--use-pt-pinned-commit" .ci/scripts/setup-arm-baremetal-tools.sh ${setup_script_args} - source examples/arm/ethos-u-scratch/setup_path.sh + source examples/arm/arm-scratch/setup_path.sh # User toolchain ${toolchain_prefix}c++ --version @@ -423,7 +459,7 @@ jobs: install_executorch "--use-pt-pinned-commit" .ci/scripts/setup-arm-baremetal-tools.sh - source examples/arm/ethos-u-scratch/setup_path.sh + source examples/arm/arm-scratch/setup_path.sh # Install requirements for converting notebooks pip install notebook @@ -594,15 +630,22 @@ jobs: strategy: matrix: model: [qwen3_4b, phi_4_mini] + runner: [linux.2xlarge] + docker-image: [executorch-ubuntu-22.04-clang12] + backend: [xnnpack] include: - model: qwen3_4b - test_with_runner: true + runner: linux.arm64.2xlarge + docker-image: executorch-ubuntu-22.04-gcc11-aarch64 + backend: torchao - model: phi_4_mini - test_with_runner: false + runner: linux.arm64.2xlarge + docker-image: executorch-ubuntu-22.04-gcc11-aarch64 + backend: torchao fail-fast: false with: - runner: linux.2xlarge - docker-image: ci-image:executorch-ubuntu-22.04-clang12 + runner: ${{ matrix.runner }} + docker-image: ci-image:${{ matrix.docker-image }} submodules: 'recursive' ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} timeout: 900 @@ -612,38 +655,54 @@ jobs: conda activate "${CONDA_ENV}" PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool cmake - pip install -U "huggingface_hub[cli]" - - bash .ci/scripts/test_torchao_huggingface_checkpoints.sh ${{ matrix.model }} ${{ matrix.test_with_runner && '--test_with_runner' || '' }} - - # # TODO(jackzhxng): Runner consistently runs out of memory before test finishes. Try to find a more powerful runner. 
- # test-llava-runner-macos: - # name: test-llava-runner-macos - # uses: pytorch/test-infra/.github/workflows/macos_job.yml@main - # strategy: - # fail-fast: false - # with: - # runner: macos-14-xlarge - # python-version: '3.11' - # submodules: 'recursive' - # ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - # timeout: 900 - # script: | - # BUILD_TOOL=cmake - - # bash .ci/scripts/setup-conda.sh - # # Setup MacOS dependencies as there is no Docker support on MacOS atm - # GITHUB_RUNNER=1 PYTHON_EXECUTABLE=python ${CONDA_RUN} bash .ci/scripts/setup-macos.sh --build-tool "${BUILD_TOOL}" - - # # install Llava requirements - # ${CONDA_RUN} bash examples/models/llama/install_requirements.sh - # ${CONDA_RUN} bash examples/models/llava/install_requirements.sh - - # # run python unittest - # ${CONDA_RUN} python -m unittest examples.models.llava.test.test_llava - - # # run e2e (export, tokenizer and runner) - # PYTHON_EXECUTABLE=python ${CONDA_RUN} bash .ci/scripts/test_llava.sh + + if [[ "${{ matrix.backend }}" == "torchao" ]]; then + BUILD_TORCHAO_EXPERIMENTAL=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 TORCHAO_ENABLE_ARM_NEON_DOT=1 TORCHAO_PARALLEL_BACKEND=OPENMP pip install --no-build-isolation third-party/ao + fi + + pip install -U "huggingface_hub[cli]<1.0" + + bash .ci/scripts/test_torchao_huggingface_checkpoints.sh ${{ matrix.model }} --test_with_runner ${{ matrix.backend == 'torchao' && '--use_torchao_kernels' || '' }} + + test-multimodal-macos: + if: ${{ !github.event.pull_request.head.repo.fork }} + name: test-multimodal-macos + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + permissions: + id-token: write + contents: read + secrets: inherit + strategy: + fail-fast: false + matrix: + model: ["gemma3-4b"] # llava gives segfault so not covering. 
+ with: + secrets-env: EXECUTORCH_HF_TOKEN + runner: macos-15-xlarge + python-version: '3.11' + submodules: 'recursive' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 90 + script: | + echo "::group::Set up ExecuTorch" + bash .ci/scripts/setup-conda.sh + eval "$(conda shell.bash hook)" + + # Install requirements + ${CONDA_RUN} python install_executorch.py + echo "::endgroup::" + + echo "::group::Set up Huggingface" + ${CONDA_RUN} pip install -U "huggingface_hub[cli]<1.0" accelerate + ${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN + OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) + ${CONDA_RUN} pip install git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION} + ${CONDA_RUN} pip list + echo "::endgroup::" + + echo "::group::Test ${{ matrix.model }}" + ${CONDA_RUN} python .ci/scripts/test_huggingface_optimum_model.py --model ${{ matrix.model }} --quantize --recipe xnnpack + echo "::endgroup::" test-qnn-model: name: test-qnn-model @@ -800,11 +859,26 @@ jobs: echo "Recipe: $RECIPE" echo "Quantize: $QUANTIZE" - echo "::group::Set up ExecuTorch" # The generic Linux job chooses to use base env, not the one setup by the image CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") conda activate "${CONDA_ENV}" - PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool cmake + + echo "::group::Setup ExecuTorch" + PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool "cmake" + echo "::endgroup::" + + echo "::group::Setup Huggingface" + pip install -U "huggingface_hub[cli]<1.0" accelerate + huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN + OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) + pip install git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION} + echo "::endgroup::" + + echo "::group::Test MODEL: $MODEL RECIPE: $RECIPE QUANTIZE: $QUANTIZE" + export OUTPUT_DIR="$(pwd)/${MODEL}_${RECIPE}_${QUANTIZE}" + python .ci/scripts/test_huggingface_optimum_model.py --model "$MODEL" --recipe "$RECIPE" $QUANTIZE --model_dir "$OUTPUT_DIR" + echo "::endgroup::" + # Build executor_runner with ETdump enabled PYTHON_EXECUTABLE=python cmake -DPYTHON_EXECUTABLE=python \ -DCMAKE_INSTALL_PREFIX=cmake-out \ @@ -813,6 +887,7 @@ jobs: -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DEXECUTORCH_BUILD_XNNPACK=ON \ -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ @@ -822,25 +897,6 @@ jobs: -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \ -Bcmake-out . 
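# Optional sanity check between the configure step above and the build step
# below -- a sketch only, assuming the cmake-out tree written by the configure
# command; CMakeCache.txt is the standard CMake cache file and the option names
# are the ones passed in above:
grep -E "EXECUTORCH_ENABLE_EVENT_TRACER|EXECUTORCH_BUILD_XNNPACK" cmake-out/CMakeCache.txt || true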
cmake --build cmake-out -j16 --target install --config Release - echo "::endgroup::" - - echo "::group::Set up Hugging Face" - pip install -U "huggingface_hub[cli]" - huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN - OPTIMUM_ET_COMMIT=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) - git clone https://github.com/huggingface/optimum-executorch - pushd optimum-executorch - # There is no release yet, for CI stability, always test from the same commit on main - git checkout $OPTIMUM_ET_COMMIT - python install_dev.py --skip_override_torch - popd - pip list - echo "::endgroup::" - - echo "::group::Run tests" - export OUTPUT_DIR="$(pwd)/${MODEL}_${RECIPE}_${QUANTIZE}" - python .ci/scripts/test_huggingface_optimum_model.py --model ${MODEL} --recipe ${RECIPE} ${QUANTIZE} --model_dir ${OUTPUT_DIR} - echo "::endgroup::" echo "::group::Generate artifacts for performance profiling" ./cmake-out/executor_runner \ @@ -907,16 +963,11 @@ jobs: ${CONDA_RUN} python install_executorch.py echo "::endgroup::" - echo "::group::Set up Hugging Face" - pip install -U "huggingface_hub[cli]" - huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN - OPTIMUM_ET_COMMIT=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) - git clone https://github.com/huggingface/optimum-executorch - pushd optimum-executorch - # There is no release yet, for CI stability, always test from the same commit on main - git checkout $OPTIMUM_ET_COMMIT - ${CONDA_RUN} python install_dev.py --skip_override_torch - popd + echo "::group::Set up Huggingface" + ${CONDA_RUN} pip install -U "huggingface_hub[cli]<1.0" accelerate + ${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN + OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) + ${CONDA_RUN} pip install git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION} ${CONDA_RUN} pip list echo "::endgroup::" @@ -962,62 +1013,77 @@ jobs: # Test llama2 PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama.sh -model stories110M -build_tool "${BUILD_TOOL}" -mode "${MODE}" -dtype "${DTYPE}" -pt2e_quantize "${PT2E_QUANTIZE}" - unittest-release: - uses: ./.github/workflows/_unittest.yml + # this is for filtering out the qnn changes such that qnn jobs only triggered when the specific files are changed + changes: + runs-on: ubuntu-latest + outputs: + qnn: ${{ steps.filter.outputs.qnn }} + steps: + - uses: actions/checkout@v4 + - uses: dorny/paths-filter@v3 + id: filter + with: + filters: | + qnn: + - 'backends/qualcomm/**' + - 'examples/qualcomm/**' + - 'examples/models/llama/**' + + test-static-llama-qnn-eval-linux: + needs: changes # has dependency on changes jobs defined above + if: needs.changes.outputs.qnn == 'true' + name: test-static-llama-qnn-eval-linux + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main permissions: id-token: write contents: read - with: - build-mode: Release - build-tool: cmake - docker-image: ci-image:executorch-ubuntu-22.04-clang12 - - test-mcu-models: - name: test-mcu-models - uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main strategy: - matrix: - include: - - build-tool: cmake fail-fast: false - permissions: - id-token: write - contents: read + matrix: + config: + - name: "baseline" + flags: "" + threshold: 62.0 with: runner: linux.2xlarge - docker-image: ci-image:executorch-ubuntu-22.04-arm-sdk + docker-image: ci-image:executorch-ubuntu-22.04-qnn-sdk submodules: 'recursive' ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || 
github.sha }} - timeout: 90 + timeout: 180 script: | - BUILD_TOOL=${{ matrix.build-tool }} - # The generic Linux job chooses to use base env, not the one setup by the image CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") conda activate "${CONDA_ENV}" + BUILD_TOOL="cmake" + PYTHON_EXECUTABLE=python bash .ci/scripts/setup-qnn-deps.sh + PYTHON_EXECUTABLE=python bash .ci/scripts/build-qnn-sdk.sh + # Setup executorch + PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool "${BUILD_TOOL}" + # Setup install_requirements for llama + PYTHON_EXECUTABLE=python bash examples/models/llama/install_requirements.sh - # Try to mirror these as closely as possible - source .ci/scripts/utils.sh - install_executorch "--use-pt-pinned-commit" - - .ci/scripts/setup-arm-baremetal-tools.sh - source examples/arm/ethos-u-scratch/setup_path.sh - - # Run selective Build - chmod +x examples/selective_build/test_selective_build.sh - examples/selective_build/test_selective_build.sh "${BUILD_TOOL}" + echo ">>> Running config: ${{ matrix.config.name }}" + PYTHON_EXECUTABLE=python bash .ci/scripts/test_qnn_static_llama_eval.sh \ + --flags "${{ matrix.config.flags }}" \ + --threshold "${{ matrix.config.threshold }}" - # Run MCU models - chmod +x examples/arm/run_mcu_models_fvp.sh - examples/arm/run_mcu_models_fvp.sh --target=cortex-m55 + unittest-release: + uses: ./.github/workflows/_unittest.yml + permissions: + id-token: write + contents: read + with: + build-mode: Release + build-tool: cmake + docker-image: ci-image:executorch-ubuntu-22.04-clang12 test-models-windows: uses: pytorch/test-infra/.github/workflows/windows_job.yml@main strategy: fail-fast: false matrix: - model: [linear, add, add_mul, ic3, ic4, mv2, mv3, resnet18, resnet50, vit, w2l, mobilebert, emformer_join, emformer_transcribe] - backend: [portable, xnnpack-f32, xnnpack-q8] + model: [mv3, resnet50, vit, mobilebert, emformer_transcribe] + backend: [portable, xnnpack-q8] with: submodules: 'recursive' ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} @@ -1034,3 +1100,33 @@ jobs: .ci/scripts/test_model.ps1 -modelName ${{ matrix.model }} -backend ${{ matrix.backend }} }" + + test-mcu-cortex-m-backend: + name: test-mcu-cortex-m-backend + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read + with: + runner: linux.2xlarge.memory + docker-image: ci-image:executorch-ubuntu-22.04-arm-sdk + submodules: 'recursive' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 120 + script: | + # The generic Linux job chooses to use base env, not the one setup by the image + CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + conda activate "${CONDA_ENV}" + + source .ci/scripts/utils.sh + install_executorch "--use-pt-pinned-commit" + + # Install arm dependencies + .ci/scripts/setup-arm-baremetal-tools.sh + source examples/arm/arm-scratch/setup_path.sh + + # To build cortex-m test runner + backends/cortex_m/test/build_test_runner.sh + + # To run cortex_m tests + pytest --config-file=backends/arm/test/pytest.ini backends/cortex_m/test diff --git a/.github/workflows/update-viablestrict.yml b/.github/workflows/update-viablestrict.yml index e639c497549..5e3b5399bfc 100644 --- a/.github/workflows/update-viablestrict.yml +++ b/.github/workflows/update-viablestrict.yml @@ -20,7 +20,7 @@ jobs: with: repository: pytorch/executorch stable-branch: viable/strict - requires: '[\"pull\", 
\"lint\", \"trunk\", \"Build documentation\", \"^Apple$\"]' + requires: '[\"pull\", \"lint\", \"trunk\", \"Build documentation\", \"^Apple$\", \"docker-builds\"]' secret-bot-token: ${{ secrets.UPDATEBOT_TOKEN }} clickhouse-url: ${{ secrets.CLICKHOUSE_URL }} clickhouse-username: ${{ secrets.CLICKHOUSE_VIABLESTRICT_USERNAME }} diff --git a/.github/workflows/windows-msvc.yml b/.github/workflows/windows-msvc.yml new file mode 100644 index 00000000000..26312b050a4 --- /dev/null +++ b/.github/workflows/windows-msvc.yml @@ -0,0 +1,35 @@ +name: Windows MSVC Build + +on: + push: + branches: + - main + - release/* + tags: + - ciflow/trunk/* + pull_request: + paths: + - .ci/docker/ci_commit_pins/pytorch.txt + - .ci/scripts/** + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: true + +jobs: + build-windows-msvc: + name: build-windows-msvc + uses: pytorch/test-infra/.github/workflows/windows_job.yml@main + with: + submodules: 'recursive' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 60 + script: | + conda init powershell + powershell -Command "& { + Set-PSDebug -Trace 1 + \$ErrorActionPreference = 'Stop' + \$PSNativeCommandUseErrorActionPreference = \$true + .ci/scripts/setup-windows-msvc.ps1 + }" diff --git a/.gitignore b/.gitignore index 511fb324ba2..7f2a42a72ec 100644 --- a/.gitignore +++ b/.gitignore @@ -16,8 +16,10 @@ cmake-android-out/ cmake-ios-out/ cmake-out* cmake-out-android/ +build-android/ +build-x86/ dist/ -ethos-u-scratch/ +arm-scratch/ executorch.egg-info pip-out/ build-profiling/ @@ -60,7 +62,6 @@ xcuserdata/ /include/ /share/ /version.py -*.csv *_etdump # Android diff --git a/.lintrunner.toml b/.lintrunner.toml index 0b6a6eb8908..396b7fde5ac 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -206,6 +206,7 @@ exclude_patterns = [ '**/*.png', '**/*.webp', '**/*.jpeg', + '**/*.mp3', '**/*.mp4', '**/*.pte', '**/*.pth', @@ -216,6 +217,9 @@ exclude_patterns = [ '**/*.jpg', '**/*.jar', '**/*.gif', + 'extension/llm/tokenizers', + 'extension/llm/tokenizers/**', + 'examples/cuda', # File contains @generated 'extension/llm/custom_ops/spinquant/fast_hadamard_transform_special.h', 'extension/llm/custom_ops/spinquant/test/fast_hadamard_transform_special_unstrided_cpu.h', @@ -363,7 +367,7 @@ exclude_patterns = [ '**/third-party/**', 'scripts/check_binary_dependencies.py', 'profiler/test/test_profiler_e2e.py', - 'backends/arm/test/**', + 'backends/arm/test/ops/*.py', ] command = [ 'python', @@ -445,3 +449,24 @@ command = [ "--", "@{{PATHSFILE}}", ] + +[[linter]] +code = 'ETVKNODEBUG' +include_patterns = [ + "backends/vulkan/**/*.glsl", +] +command = [ + 'python', + '-m', + 'lintrunner_adapters', + 'run', + 'grep_linter', + '--pattern=((DEBUG_MODE)|(GL_EXT_debug_printf))', + '--linter-name=ETVKNODEBUG', + '--error-name=Using DEBUG_MODE or GL_EXT_debug_printf in Vulkan shader', + """--error-description=\ + #define DEBUG_MODE or #extension GL_EXT_debug_printf should only be used during development! 
+ """, + '--', + '@{{PATHSFILE}}', +] diff --git a/.mypy.ini b/.mypy.ini index cd14cbac7ea..0ce444e8a79 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -24,11 +24,14 @@ files = test, util -mypy_path = executorch +mypy_path = executorch,src [mypy-executorch.backends.*] follow_untyped_imports = True +[mypy-backends.arm.*] +disallow_untyped_decorators = False + [mypy-executorch.codegen.*] follow_untyped_imports = True @@ -74,6 +77,12 @@ ignore_missing_imports = True [mypy-pytorch_sphinx_theme] ignore_missing_imports = True +[mypy-pytorch_sphinx_theme2] +ignore_missing_imports = True + +[mypy-executorch.version] +ignore_missing_imports = True + [mypy-ruamel] ignore_missing_imports = True @@ -83,6 +92,12 @@ ignore_missing_imports = True [mypy-tosa_tools.*] ignore_missing_imports = True +[mypy-tosa_serializer] +ignore_missing_imports = True + +[mypy-tosa_serializer.*] +ignore_missing_imports = True + [mypy-setuptools.*] ignore_missing_imports = True diff --git a/CMakeLists.txt b/CMakeLists.txt index fc427d517a9..30cee4afe53 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -99,28 +99,6 @@ announce_configured_options(CCACHE_PROGRAM) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -# Setup RPATH. See -# https://gitlab.kitware.com/cmake/community/-/wikis/doc/cmake/RPATH-handling -# Use separate rpaths during build and install phases -set(CMAKE_SKIP_BUILD_RPATH OFF) -# Don't use the install-rpath during the build phase -set(CMAKE_BUILD_WITH_INSTALL_RPATH ON) -# Automatically add all linked folders that are NOT in the build directory to -# the rpath (per library?) -# -# TODO: Doesn't work for us right now because we are not installing .so's into -# the correct locations. For example we have libcustom_ops_aot_lib.so depending -# on _portable_lib.so, which was eventually put under -# /executorch/extension/pybindings/ but this rpath is not -# automatically added because at build time it seems `portable_lib` is being -# built under the same directory, so no extra rpath is being added. To properly -# fix this we need to install `portable_lib` into the correct path. -set(CMAKE_INSTALL_RPATH_USE_LINK_PATH ON) -# ------------------------------ OPTIONS ------------------------------------- -# WARNING: Please don't add example specific options in this CMakeLists.txt. -# Instead please use `find_package(executorch REQUIRED)` in the example -# directory and add a new executable in the example `CMakeLists.txt`. - if(NOT EXECUTORCH_ENABLE_LOGGING) # Avoid pulling in the logging strings, which can be large. Note that this # will set the compiler flag for all targets in this directory, and for all @@ -141,6 +119,10 @@ if(EXECUTORCH_ENABLE_EVENT_TRACER) add_definitions(-DET_EVENT_TRACER_ENABLED) endif() +if(EXECUTORCH_ENABLE_BUNDLE_IO) + add_definitions(-DET_BUNDLE_IO_ENABLED) +endif() + # -ffunction-sections -fdata-sections: breaks function and data into sections so # they can be properly gc'd. -s: strip symbol. if(WIN32) @@ -226,7 +208,7 @@ if(EXECUTORCH_BUILD_CPUINFO) install( TARGETS cpuinfo EXPORT ExecuTorchTargets - DESTINATION lib + DESTINATION ${CMAKE_INSTALL_LIBDIR} INCLUDES DESTINATION ${_common_include_directories} ) @@ -266,10 +248,22 @@ if(EXECUTORCH_BUILD_PTHREADPOOL) executorch_move_interface_include_directories_to_build_time_only( pthreadpool_interface ) + + if(APPLE) + # Use hidden visibility for pthreadpool on Apple platforms to avoid issues + # with pthreadpool symbols from libtorch_cpu taking precedence over the ones + # from the pthreadpool library statically linked in _portable_lib. 
The + # pthreadpool public APIs are marked as weak by default on some Apple + # platforms, so setting to hidden visibility works around this by not + # putting the symbol in the indirection table. See + # https://github.com/pytorch/executorch/issues/14321 for more details. + target_compile_options(pthreadpool PRIVATE -fvisibility=hidden) + endif() + install( TARGETS pthreadpool pthreadpool_interface fxdiv EXPORT ExecuTorchTargets - DESTINATION lib + DESTINATION ${CMAKE_INSTALL_LIBDIR} INCLUDES DESTINATION ${_common_include_directories} ) @@ -284,7 +278,16 @@ if(EXECUTORCH_BUILD_TESTS) endif() # TODO(dbort): Fix these warnings and remove this flag. -set(_common_compile_options -Wno-deprecated-declarations -fPIC) +list(APPEND _common_compile_options $<$:/wd4996> + $<$>:-Wno-deprecated-declarations> +) +# Set default CMAKE_POSITION_INDEPENDENT_CODE behavior if ON or not set +# (default) (and not for MSVC compiler) +if(NOT DEFINED CMAKE_POSITION_INDEPENDENT_CODE + OR CMAKE_POSITION_INDEPENDENT_CODE +) + list(APPEND _common_compile_options $<$>:-fPIC>) +endif() # Let files say "include ". # TODO(#6475): This requires/assumes that the repo lives in a directory named @@ -587,6 +590,25 @@ endif() if(EXECUTORCH_BUILD_CORTEX_M) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/cortex_m) + list(APPEND _executorch_backends coretex_m_backend) +endif() + +# Build common AOTI functionality if needed by CUDA or Metal backends +if(EXECUTORCH_BUILD_CUDA OR EXECUTORCH_BUILD_METAL) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/aoti) +endif() + +if(EXECUTORCH_BUILD_CUDA) + # Build CUDA-specific AOTI functionality + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/cuda) + # Add aoti_cuda_backend to backends - it transitively includes aoti_cuda_shims + # and cuda_platform + list(APPEND _executorch_backends aoti_cuda_backend) +endif() + +if(EXECUTORCH_BUILD_METAL) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/apple/metal) + list(APPEND _executorch_backends metal_backend) endif() if(EXECUTORCH_BUILD_EXTENSION_APPLE) @@ -630,6 +652,11 @@ if(EXECUTORCH_BUILD_EXTENSION_MODULE) list(APPEND _executorch_extensions extension_module_static) endif() +if(EXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/named_data_map) + list(APPEND _executorch_extensions extension_named_data_map) +endif() + if(EXECUTORCH_BUILD_EXTENSION_LLM) if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER) set(SUPPORT_REGEX_LOOKAHEAD ON) @@ -650,15 +677,6 @@ if(EXECUTORCH_BUILD_EXTENSION_LLM) list(APPEND _executorch_extensions tokenizers) endif() -if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER) - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/runner) - list(APPEND _executorch_extensions extension_llm_runner) -endif() - -if(EXECUTORCH_BUILD_EXTENSION_LLM_APPLE) - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/apple) -endif() - if(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/runner_util) install( @@ -717,7 +735,7 @@ if(EXECUTORCH_BUILD_KERNELS_TORCHAO) install( TARGETS torchao_ops_executorch torchao_kernels_aarch64 EXPORT ExecuTorchTargets - DESTINATION lib + DESTINATION ${CMAKE_INSTALL_LIBDIR} INCLUDES DESTINATION ${_common_include_directories} ) @@ -728,7 +746,7 @@ if(EXECUTORCH_BUILD_KERNELS_TORCHAO) install( TARGETS kleidiai EXPORT ExecuTorchTargets - DESTINATION lib + DESTINATION ${CMAKE_INSTALL_LIBDIR} INCLUDES DESTINATION ${_common_include_directories} ) @@ -738,9 +756,6 @@ endif() if(EXECUTORCH_BUILD_PYBIND) - # 
Add codegen tools subdirectory for selective_build pybind module - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/codegen/tools) - if(NOT EXECUTORCH_BUILD_EXTENSION_DATA_LOADER) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/data_loader) endif() @@ -749,6 +764,9 @@ if(EXECUTORCH_BUILD_PYBIND) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/devtools) endif() + # Add codegen tools subdirectory for selective_build pybind module + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/codegen/tools) + # Create bundled_module target only for pybindings when bundled_program exists # This target has hard dependencies on devtools generated headers if(TARGET bundled_program) @@ -769,7 +787,10 @@ if(EXECUTORCH_BUILD_PYBIND) bundled_module PUBLIC ${_common_include_directories} ) target_compile_options( - bundled_module PUBLIC -Wno-deprecated-declarations -fPIC + bundled_module + PUBLIC $<$:/wd4996> + $<$>:-Wno-deprecated-declarations + -fPIC> ) endif() @@ -790,6 +811,9 @@ if(EXECUTORCH_BUILD_PYBIND) torch ) + # RPATH for _portable_lib.so + set(_portable_lib_rpath "$ORIGIN/../../../torch/lib") + if(EXECUTORCH_BUILD_EXTENSION_MODULE) # Always use static linking for pybindings to avoid runtime symbol # resolution issues @@ -824,6 +848,7 @@ if(EXECUTORCH_BUILD_PYBIND) if(EXECUTORCH_BUILD_QNN) list(APPEND _dep_libs qnn_executorch_backend) + string(APPEND _portable_lib_rpath ":$ORIGIN/../../backends/qualcomm") endif() if(EXECUTORCH_BUILD_ENN) @@ -841,8 +866,14 @@ if(EXECUTORCH_BUILD_PYBIND) endif() # compile options for pybind - set(_pybind_compile_options -Wno-deprecated-declarations -fPIC -frtti - -fexceptions + set(_pybind_compile_options + $<$:/EHsc + /GR + /wd4996> + $<$>:-Wno-deprecated-declarations + -fPIC + -frtti + -fexceptions> ) # util lib @@ -869,6 +900,23 @@ if(EXECUTORCH_BUILD_PYBIND) target_compile_options(portable_lib PUBLIC ${_pybind_compile_options}) target_link_libraries(portable_lib PRIVATE ${_dep_libs}) + # Set RPATH to find PyTorch and backend libraries relative to the installation + # location. This goes from executorch/extension/pybindings up to + # site-packages, then to torch/lib. If QNN is enabled, also add + # backends/qualcomm/. 
Don't do this to APPLE, as it will error out on the + # following error: + # + if(APPLE) + # Skip setting @loader_path for APPLE, since it causes error like ld: + # duplicate LC_RPATH '@loader_path' in '/torch/lib/ + # libtorch_cpu.dylib' + else() + set_target_properties( + portable_lib PROPERTIES BUILD_RPATH "${_portable_lib_rpath}" + INSTALL_RPATH "${_portable_lib_rpath}" + ) + endif() + install( TARGETS portable_lib EXPORT ExecuTorchTargets @@ -889,6 +937,20 @@ if(EXECUTORCH_BUILD_EXTENSION_TRAINING) list(APPEND _executorch_extensions extension_training) endif() +if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/runner) + list(APPEND _executorch_extensions extension_llm_runner) +endif() + +if(EXECUTORCH_BUILD_EXTENSION_ASR_RUNNER) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/asr/runner) + list(APPEND _executorch_extensions extension_asr_runner) +endif() + +if(EXECUTORCH_BUILD_EXTENSION_LLM_APPLE) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/apple) +endif() + if(EXECUTORCH_BUILD_KERNELS_LLM) # TODO: move all custom kernels to ${CMAKE_CURRENT_SOURCE_DIR}/kernels/custom add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/custom_ops) @@ -984,7 +1046,7 @@ if(NOT EXECUTORCH_SELECT_OPS_YAML STREQUAL "" install( TARGETS executorch_selected_kernels EXPORT ExecuTorchTargets - DESTINATION lib + DESTINATION ${CMAKE_INSTALL_LIBDIR} ) else() # No selective build - link the full library. @@ -1006,6 +1068,10 @@ if(EXECUTORCH_BUILD_EXECUTOR_RUNNER) extension_runner_util gflags executorch_backends ) + if(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR) + list(APPEND _executor_runner_libs extension_flat_tensor) + endif() + if(EXECUTORCH_BUILD_KERNELS_OPTIMIZED) list(APPEND _executor_runner_libs optimized_native_cpu_ops_lib) elseif(EXECUTORCH_BUILD_CADENCE) @@ -1027,6 +1093,10 @@ if(EXECUTORCH_BUILD_EXECUTOR_RUNNER) list(APPEND _executor_runner_libs etdump flatccrt) endif() + if(EXECUTORCH_ENABLE_BUNDLE_IO) + list(APPEND _executor_runner_libs bundled_program) + endif() + add_executable(executor_runner ${_executor_runner__srcs}) if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") target_link_options_gc_sections(executor_runner) diff --git a/CMakePresets.json b/CMakePresets.json index bcf3bbc8d83..2b1512ac121 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -63,7 +63,8 @@ "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/third-party/ios-cmake/ios.toolchain.cmake", "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/ios.cmake", "PLATFORM": "OS64", - "DEPLOYMENT_TARGET": "17.0" + "DEPLOYMENT_TARGET": "17.0", + "CMAKE_OSX_DEPLOYMENT_TARGET": "12.0" }, "condition": { "lhs": "${hostSystemName}", @@ -80,7 +81,8 @@ "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/third-party/ios-cmake/ios.toolchain.cmake", "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/ios.cmake", "PLATFORM": "SIMULATORARM64", - "DEPLOYMENT_TARGET": "17.0" + "DEPLOYMENT_TARGET": "17.0", + "CMAKE_OSX_DEPLOYMENT_TARGET": "12.0" }, "condition": { "lhs": "${hostSystemName}", @@ -117,38 +119,118 @@ } }, { - "name": "llm", - "displayName": "Build LLM libraries", - "inherits": ["common"], - "cacheVariables": { - "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/llm.cmake", - "CMAKE_OSX_DEPLOYMENT_TARGET": "12.0" - }, - "condition": { - "type": "inList", - "string": "${hostSystemName}", - "list": ["Darwin", "Linux", "Windows"] - } + "name": "llm", + "displayName": "Build LLM libraries", + "inherits": [ + "common" + ], + "cacheVariables": { + 
"EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/llm.cmake", + "CMAKE_OSX_DEPLOYMENT_TARGET": "12.0" + }, + "condition": { + "type": "inList", + "string": "${hostSystemName}", + "list": ["Darwin", "Linux", "Windows"] + } }, { - "name": "profiling", - "displayName": "Build ExecuTorch with Profiling Enabled", - "inherits": [ - "common" - ], - "cacheVariables": { - "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/profiling.cmake", - "CMAKE_OSX_DEPLOYMENT_TARGET": "12.0" - }, - "condition": { - "type": "inList", - "string": "${hostSystemName}", - "list": [ - "Darwin", - "Linux", - "Windows" - ] - } + "name": "llm-release", + "displayName": "LLM release build", + "inherits": [ + "llm" + ], + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_INSTALL_PREFIX": "${sourceDir}/cmake-out" + } + }, + { + "name": "llm-release-cuda", + "displayName": "LLM release build with CUDA", + "inherits": [ + "llm-release" + ], + "cacheVariables": { + "EXECUTORCH_BUILD_CUDA": "ON" + }, + "condition": { + "type": "inList", + "string": "${hostSystemName}", + "list": ["Linux", "Windows"] + } + }, + { + "name": "llm-release-metal", + "displayName": "LLM release build with Metal", + "inherits": [ + "llm-release" + ], + "cacheVariables": { + "EXECUTORCH_BUILD_METAL": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "type": "equals", + "rhs": "Darwin" + } + }, + { + "name": "llm-debug", + "displayName": "LLM debug build", + "inherits": [ + "llm" + ], + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_INSTALL_PREFIX": "${sourceDir}/cmake-out" + } + }, + { + "name": "llm-debug-cuda", + "displayName": "LLM debug build with CUDA", + "inherits": [ + "llm-debug" + ], + "cacheVariables": { + "EXECUTORCH_BUILD_CUDA": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "type": "equals", + "rhs": "Linux" + } + }, + { + "name": "llm-debug-metal", + "displayName": "LLM debug build with Metal", + "inherits": [ + "llm-debug" + ], + "cacheVariables": { + "EXECUTORCH_BUILD_METAL": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "type": "equals", + "rhs": "Darwin" + } + }, + { + "name": "profiling", + "displayName": "Build ExecuTorch with Profiling Enabled", + "inherits": [ + "common" + ], + "cacheVariables": { + "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/profiling.cmake", + "CMAKE_OSX_DEPLOYMENT_TARGET": "12.0" + }, + "condition": { + "type": "inList", + "string": "${hostSystemName}", + "list": ["Darwin", "Linux", "Windows"] + } }, { "name": "windows", @@ -175,13 +257,155 @@ } }, { - "name": "arm-baremetal", - "displayName": "Build ExecuTorch for Arm baremetal", - "inherits": ["common"], - "cacheVariables": { - "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/arm_baremetal.cmake", - "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/examples/arm/ethos-u-setup/arm-none-eabi-gcc.cmake" - } + "name": "arm-baremetal", + "displayName": "Build ExecuTorch for Arm baremetal", + "inherits": ["common"], + "cacheVariables": { + "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/arm_baremetal.cmake", + "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/examples/arm/ethos-u-setup/arm-none-eabi-gcc.cmake" + } + } + ], + "buildPresets": [ + { + "name": "llm-release-install", + "displayName": "Build and install LLM extension release artifacts", + "configurePreset": "llm-release", + "targets": [ + "install" + ], + "jobs": 0 + }, + { + "name": "llm-release-cuda-install", + "displayName": "Build and install LLM extension release artifacts (CUDA)", + 
"configurePreset": "llm-release-cuda", + "targets": [ + "install" + ], + "jobs": 0 + }, + { + "name": "llm-release-metal-install", + "displayName": "Build and install LLM extension release artifacts (Metal)", + "configurePreset": "llm-release-metal", + "targets": [ + "install" + ], + "jobs": 0 + }, + { + "name": "llm-debug-install", + "displayName": "Build and install LLM extension debug artifacts", + "configurePreset": "llm-debug", + "targets": [ + "install" + ], + "jobs": 0 + }, + { + "name": "llm-debug-cuda-install", + "displayName": "Build and install LLM extension debug artifacts (CUDA)", + "configurePreset": "llm-debug-cuda", + "targets": [ + "install" + ], + "jobs": 0 + }, + { + "name": "llm-debug-metal-install", + "displayName": "Build and install LLM extension debug artifacts (Metal)", + "configurePreset": "llm-debug-metal", + "targets": [ + "install" + ], + "jobs": 0 + } + ], + "workflowPresets": [ + { + "name": "llm-release", + "displayName": "Configure, build and install ExecuTorch LLM extension with default CPU backend", + "steps": [ + { + "type": "configure", + "name": "llm-release" + }, + { + "type": "build", + "name": "llm-release-install" + } + ] + }, + { + "name": "llm-release-cuda", + "displayName": "Configure, build and install ExecuTorch LLM extension with CUDA enabled", + "steps": [ + { + "type": "configure", + "name": "llm-release-cuda" + }, + { + "type": "build", + "name": "llm-release-cuda-install" + } + ] + }, + { + "name": "llm-release-metal", + "displayName": "Configure, build and install ExecuTorch LLM extension with Metal enabled", + "steps": [ + { + "type": "configure", + "name": "llm-release-metal" + }, + { + "type": "build", + "name": "llm-release-metal-install" + } + ] + }, + { + "name": "llm-debug", + "displayName": "Configure, build and install ExecuTorch LLM extension with default CPU backend (Debug)", + "steps": [ + { + "type": "configure", + "name": "llm-debug" + }, + { + "type": "build", + "name": "llm-debug-install" + } + ] + }, + { + "name": "llm-debug-cuda", + "displayName": "Configure, build and install ExecuTorch LLM extension with CUDA enabled (Debug)", + "steps": [ + { + "type": "configure", + "name": "llm-debug-cuda" + }, + { + "type": "build", + "name": "llm-debug-cuda-install" + } + ] + }, + { + "name": "llm-debug-metal", + "displayName": "Configure, build and install ExecuTorch LLM extension with Metal enabled (Debug)", + "steps": [ + { + "type": "configure", + "name": "llm-debug-metal" + }, + { + "type": "build", + "name": "llm-debug-metal-install" + } + ] } ] } diff --git a/CODEOWNERS b/CODEOWNERS index 10baed9ede4..55108026d4e 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -14,6 +14,7 @@ /backends/transforms @kimishpatel /backends/vulkan @SS-JIA /backends/xnnpack @digantdesai @mcr229 +/backends/nxp @robert-kalmar /devtools @Gasoonjia @@ -33,6 +34,7 @@ /examples/qualcomm @cccclai /examples/selective_build @lucylq @larryliu0820 @JacobSzwejbka /examples/xnnpack @digantdesai @mcr229 +/examples/nxp @robert-kalmar /exir/backend @cccclai @kimishpatel @JacobSzwejbka /exir @JacobSzwejbka @larryliu0820 @@ -47,31 +49,31 @@ /extension/export_util @kimishpatel /extension/flat_tensor @lucylq /extension/gguf_util @larryliu0820 -/extension/kernel_util @kimishpatel @manuelcandales @swolchok -/extension/llm @jackzhxng @larryliu0820 @swolchok @mergennachin -/extension/memory_allocator @JacobSzwejbka @swolchok +/extension/kernel_util @kimishpatel @manuelcandales +/extension/llm @jackzhxng @larryliu0820 @mergennachin +/extension/memory_allocator @JacobSzwejbka 
/extension/module @shoumikhin -/extension/parallel @kimishpatel @swolchok +/extension/parallel @kimishpatel /extension/pybindings @JacobSzwejbka @larryliu0820 -/extension/pytree @JacobSzwejbka @swolchok -/extension/runner_util @swolchok +/extension/pytree @JacobSzwejbka +/extension/runner_util /extension/tensor @shoumikhin -/extension/testing_util @swolchok -/extension/threadpool @kimishpatel @swolchok +/extension/testing_util +/extension/threadpool @kimishpatel /extension/training @JacobSzwejbka -/kernels @manuelcandales @swolchok +/kernels @manuelcandales /profiler @Gasoonjia -/runtime @JacobSzwejbka @lucylq @swolchok +/runtime @JacobSzwejbka @lucylq /runtime/backend @cccclai /schema @JacobSzwejbka @lucylq -/scripts @GregoryComer @swolchok +/scripts @GregoryComer -/shim @larryliu0820 @GregoryComer @swolchok +/shim @larryliu0820 @GregoryComer /third-party @GregoryComer diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2f4de863dad..4645ff86725 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -24,8 +24,8 @@ For Apple, please refer to the [iOS documentation](docs/source/using-executorch- executorch ├── backends - Backend delegate implementations for various hardware targets. Each backend uses partitioner to split the graph into subgraphs that can be executed on specific hardware, quantizer to optimize model precision, and runtime components to execute the graph on target hardware. For details refer to the backend documentation and the Export and Lowering tutorial for more information. │ ├── apple - Apple-specific backends. -│ │ ├── coreml - CoreML backend for Apple devices. See doc. -│ │ └── mps - Metal Performance Shaders backend for Apple devices. See doc. +│ │ ├── coreml - CoreML backend for Apple devices. See doc. +│ │ └── mps - Metal Performance Shaders backend for Apple devices. See doc. │ ├── arm - ARM architecture backends. See doc. │ ├── cadence - Cadence-specific backends. See doc. │ ├── example - Example backend implementations. @@ -33,8 +33,8 @@ executorch │ ├── openvino - OpenVINO backend for Intel hardware. │ ├── qualcomm - Qualcomm-specific backends. See doc. │ ├── transforms - Transformations for backend optimization. -│ ├── vulkan - Vulkan backend for cross-platform GPU support. See doc. -│ └── xnnpack - XNNPACK backend for optimized neural network operations. See doc. +│ ├── vulkan - Vulkan backend for cross-platform GPU support. See doc. +│ └── xnnpack - XNNPACK backend for optimized neural network operations. See doc. ├── codegen - Tooling to autogenerate bindings between kernels and the runtime. ├── configurations - Configuration files. ├── devtools - Model profiling, debugging, and inspection. Please refer to the tools documentation for more information. @@ -199,8 +199,7 @@ We use [`lintrunner`](https://pypi.org/project/lintrunner/) to help make sure th code follows our standards. Set it up with: ``` -pip install lintrunner==0.12.7 -pip install lintrunner-adapters==0.12.4 +./install_requirements.sh # (automatically run by install_executorch.sh) lintrunner init ``` diff --git a/LICENSE b/LICENSE index f20b198d808..c16e59652bb 100644 --- a/LICENSE +++ b/LICENSE @@ -9,6 +9,7 @@ Copyright (c) 2023 Apple Inc. Copyright (c) 2024 MediaTek Inc. Copyright 2023 NXP Copyright (c) 2025 Samsung Electronics Co. 
LTD +Copyright (c) Intel Corporation Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/Makefile b/Makefile new file mode 100644 index 00000000000..13fc941e135 --- /dev/null +++ b/Makefile @@ -0,0 +1,199 @@ +# ============================================================================== +# ExecuTorch Targets Makefile +# ============================================================================== +# +# This Makefile provides convenient targets for building ExecuTorch model runners +# with different backend configurations (CPU, CUDA, Metal), as well as other +# binary targets. +# +# WHAT THIS BUILDS: +# ----------------- +# Each target builds: +# 1. ExecuTorch core libraries with the specified backend (CPU, CUDA, or Metal) +# 2. The model-specific runner executable in cmake-out/examples/models// +# +# SUPPORTED MODELS: +# ----------------- +# - voxtral: Multimodal voice + text model (CPU, CUDA, Metal) +# - whisper: Speech recognition model (CPU, CUDA, Metal) +# - llama: Text generation model (CPU) +# - llava: Vision + language model (CPU) +# - gemma3: Text generation model (CPU, CUDA) +# +# USAGE: +# ------ +# make - # Build a specific model with a backend +# make help # Show all available targets +# make clean # Remove all build artifacts +# +# Examples: +# make voxtral-cuda # Build Voxtral with CUDA backend +# make llama-cpu # Build Llama with CPU backend +# make whisper-metal # Build Whisper with Metal backend (macOS) +# +# HOW TO ADD A NEW MODEL: +# ----------------------- +# To add a new model (e.g., "mymodel"), follow these steps: +# +# 1. Create a CMakePresets.json in examples/models/mymodel/: +# - Define configurePresets for each backend (base, cpu, cuda, metal) +# - Define buildPresets with the target name from CMakeLists.txt +# - Define workflowPresets that combine configure + build steps +# - See examples/models/voxtral/CMakePresets.json for multi-backend reference +# - Or see examples/models/llama/CMakePresets.json for simple single-preset reference +# +# 2. Add targets to this Makefile: +# a) Add to .PHONY declaration: mymodel-cuda mymodel-cpu mymodel-metal +# b) Add help text in the help target +# c) Add target implementations following this pattern: +# +# mymodel-cuda: +# @echo "==> Building and installing ExecuTorch with CUDA..." +# cmake --workflow --preset llm-release-cuda +# @echo "==> Building MyModel runner with CUDA..." +# cd examples/models/mymodel && cmake --workflow --preset mymodel-cuda +# @echo "" +# @echo "✓ Build complete!" +# @echo " Binary: cmake-out/examples/models/mymodel/mymodel_runner" +# +# mymodel-cpu: +# @echo "==> Building and installing ExecuTorch..." +# cmake --workflow --preset llm-release +# @echo "==> Building MyModel runner (CPU)..." +# cd examples/models/mymodel && cmake --workflow --preset mymodel-cpu +# @echo "" +# @echo "✓ Build complete!" +# @echo " Binary: cmake-out/examples/models/mymodel/mymodel_runner" +# +# mymodel-metal: +# @echo "==> Building and installing ExecuTorch with Metal..." +# cmake --workflow --preset llm-release-metal +# @echo "==> Building MyModel runner with Metal..." +# cd examples/models/mymodel && cmake --workflow --preset mymodel-metal +# @echo "" +# @echo "✓ Build complete!" +# @echo " Binary: cmake-out/examples/models/mymodel/mymodel_runner" +# +# 3. 
Test your new targets: +# make mymodel-cpu # or mymodel-cuda, mymodel-metal +# +# NOTES: +# ------ +# - CUDA backend is only available on Linux systems +# - Metal backend is only available on macOS (Darwin) systems +# - Some models may not support all backends (check model documentation) +# - Binary outputs are located in cmake-out/examples/models// +# - The preset names in CMakePresets.json must match the names used in Makefile +# +# ============================================================================== + +.PHONY: voxtral-cuda voxtral-cpu voxtral-metal whisper-cuda whisper-cpu whisper-metal llama-cpu llava-cpu gemma3-cuda gemma3-cpu clean help + +help: + @echo "This Makefile adds targets to build runners for various models on various backends. Run using `make `. Available targets:" + @echo " voxtral-cuda - Build Voxtral runner with CUDA backend" + @echo " voxtral-cpu - Build Voxtral runner with CPU backend" + @echo " voxtral-metal - Build Voxtral runner with Metal backend (macOS only)" + @echo " whisper-cuda - Build Whisper runner with CUDA backend" + @echo " whisper-cpu - Build Whisper runner with CPU backend" + @echo " whisper-metal - Build Whisper runner with Metal backend (macOS only)" + @echo " llama-cpu - Build Llama runner with CPU backend" + @echo " llava-cpu - Build Llava runner with CPU backend" + @echo " gemma3-cuda - Build Gemma3 runner with CUDA backend" + @echo " gemma3-cpu - Build Gemma3 runner with CPU backend" + @echo " clean - Clean build artifacts" + +voxtral-cuda: + @echo "==> Building and installing ExecuTorch with CUDA..." + cmake --workflow --preset llm-release-cuda + @echo "==> Building Voxtral runner with CUDA..." + cd examples/models/voxtral && cmake --workflow --preset voxtral-cuda + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/voxtral/voxtral_runner" + +voxtral-cpu: + @echo "==> Building and installing ExecuTorch..." + cmake --workflow --preset llm-release + @echo "==> Building Voxtral runner (CPU)..." + cd examples/models/voxtral && cmake --workflow --preset voxtral-cpu + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/voxtral/voxtral_runner" + +voxtral-metal: + @echo "==> Building and installing ExecuTorch with Metal..." + cmake --workflow --preset llm-release-metal + @echo "==> Building Voxtral runner with Metal..." + cd examples/models/voxtral && cmake --workflow --preset voxtral-metal + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/voxtral/voxtral_runner" + +whisper-cuda: + @echo "==> Building and installing ExecuTorch with CUDA..." + cmake --workflow --preset llm-release-cuda + @echo "==> Building Whisper runner with CUDA..." + cd examples/models/whisper && cmake --workflow --preset whisper-cuda + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/whisper/whisper_runner" + +whisper-cpu: + @echo "==> Building and installing ExecuTorch..." + cmake --workflow --preset llm-release + @echo "==> Building Whisper runner (CPU)..." + cd examples/models/whisper && cmake --workflow --preset whisper-cpu + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/whisper/whisper_runner" + +whisper-metal: + @echo "==> Building and installing ExecuTorch with Metal..." + cmake --workflow --preset llm-release-metal + @echo "==> Building Whisper runner with Metal..." + cd examples/models/whisper && cmake --workflow --preset whisper-metal + @echo "" + @echo "✓ Build complete!" 
+ @echo " Binary: cmake-out/examples/models/whisper/whisper_runner" + +llama-cpu: + @echo "==> Building and installing ExecuTorch..." + cmake --workflow --preset llm-release + @echo "==> Building Llama runner (CPU)..." + cd examples/models/llama && cmake --workflow --preset llama-release + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/llama/llama_main" + +llava-cpu: + @echo "==> Building and installing ExecuTorch..." + cmake --workflow --preset llm-release + @echo "==> Building Llava runner (CPU)..." + cd examples/models/llava && cmake --workflow --preset llava + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/llava/llava_main" + +gemma3-cuda: + @echo "==> Building and installing ExecuTorch with CUDA..." + cmake --workflow --preset llm-release-cuda + @echo "==> Building Gemma3 runner with CUDA..." + cd examples/models/gemma3 && cmake --workflow --preset gemma3-cuda + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/gemma3/gemma3_e2e_runner" + +gemma3-cpu: + @echo "==> Building and installing ExecuTorch..." + cmake --workflow --preset llm-release + @echo "==> Building Gemma3 runner (CPU)..." + cd examples/models/gemma3 && cmake --workflow --preset gemma3-cpu + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/gemma3/gemma3_e2e_runner" + +clean: + rm -rf cmake-out \ + extension/llm/tokenizers/build \ + extension/llm/tokenizers/pytorch_tokenizers.egg-info diff --git a/README-wheel.md b/README-wheel.md index a59af8ea05f..a1e70a2daef 100644 --- a/README-wheel.md +++ b/README-wheel.md @@ -5,14 +5,14 @@ ExecuTorch is to enable wider customization and deployment capabilities of the PyTorch programs. The `executorch` pip package is in beta. -* Supported python versions: 3.10, 3.11, 3.12 +* Supported python versions: 3.10, 3.11, 3.12, 3.13 * Compatible systems: Linux x86_64, macOS aarch64 The prebuilt `executorch.runtime` module included in this package provides a way to run ExecuTorch `.pte` files, with some restrictions: * Only [core ATen operators](docs/source/ir-ops-set-definition.md) are linked into the prebuilt module -* Only the [XNNPACK backend delegate](docs/source/backends-xnnpack.md) is linked into the prebuilt module. -* \[macOS only] [Core ML](docs/source/backends-coreml.md) and [MPS](docs/source/backends-mps.md) backend +* Only the [XNNPACK backend delegate](docs/source/backends/xnnpack/xnnpack-overview.md) is linked into the prebuilt module. +* \[macOS only] [Core ML](docs/source/backends/coreml/coreml-overview.md) and [MPS](docs/source/backends/mps/mps-overview.md) backend are also linked into the prebuilt module. Please visit the [ExecuTorch website](https://pytorch.org/executorch) for @@ -25,6 +25,6 @@ tutorials and documentation. Here are some starting points: * [Exporting to ExecuTorch](https://pytorch.org/executorch/main/tutorials/export-to-executorch-tutorial) * Learn the fundamentals of exporting a PyTorch `nn.Module` to ExecuTorch, and optimizing its performance using quantization and hardware delegation. -* Running etLLM on [iOS](https://github.com/meta-pytorch/executorch-examples/tree/main/llm/apple) and [Android](docs/source/llm/llama-demo-android.md) devices. +* Running etLLM on [iOS](https://github.com/meta-pytorch/executorch-examples/tree/main/llm/apple) and [Android](https://github.com/meta-pytorch/executorch-examples/tree/main/llm/android) devices. 
* Build and run LLaMA in a demo mobile app, and learn how to integrate models with your own apps. diff --git a/README.md b/README.md index 17327990a1d..072c13b26f7 100644 --- a/README.md +++ b/README.md @@ -1,72 +1,254 @@
[README.md header diff: centered logo and badge images whose HTML markup did not survive extraction. Old header: logo, the title "ExecuTorch: A powerful on-device AI Framework", and Contributors, Stargazers, Discord, and Documentation badges. New header: the ExecuTorch logo mark, the title "ExecuTorch", the tagline "On-device AI inference powered by PyTorch", and PyPI, Contributors, Stars, Discord, and Documentation badges.]
-**ExecuTorch** is an end-to-end solution for on-device inference and training. It powers much of Meta's on-device AI experiences across Facebook, Instagram, Meta Quest, Ray-Ban Meta Smart Glasses, WhatsApp, and more. +**ExecuTorch** is PyTorch's unified solution for deploying AI models on-device—from smartphones to microcontrollers—built for privacy, performance, and portability. It powers Meta's on-device AI across **Instagram, WhatsApp, Quest 3, Ray-Ban Meta Smart Glasses**, and [more](https://docs.pytorch.org/executorch/main/success-stories.html). + +Deploy **LLMs, vision, speech, and multimodal models** with the same PyTorch APIs you already know—accelerating research to production with seamless model export, optimization, and deployment. No manual C++ rewrites. No format conversions. No vendor lock-in. + +
+ 📘 Table of Contents + +- [Why ExecuTorch?](#why-executorch) +- [How It Works](#how-it-works) +- [Quick Start](#quick-start) + - [Installation](#installation) + - [Export and Deploy in 3 Steps](#export-and-deploy-in-3-steps) + - [Run on Device](#run-on-device) + - [LLM Example: Llama](#llm-example-llama) +- [Platform & Hardware Support](#platform--hardware-support) +- [Production Deployments](#production-deployments) +- [Examples & Models](#examples--models) +- [Key Features](#key-features) +- [Documentation](#documentation) +- [Community & Contributing](#community--contributing) +- [License](#license) + +
+ +## Why ExecuTorch? + +- **🔒 Native PyTorch Export** — Direct export from PyTorch. No .onnx, .tflite, or intermediate format conversions. Preserve model semantics. +- **⚡ Production-Proven** — Powers billions of users at [Meta with real-time on-device inference](https://engineering.fb.com/2025/07/28/android/executorch-on-device-ml-meta-family-of-apps/). +- **💾 Tiny Runtime** — 50KB base footprint. Runs on microcontrollers to high-end smartphones. +- **🚀 [12+ Hardware Backends](https://docs.pytorch.org/executorch/main/backends-overview.html)** — Open-source acceleration for Apple, Qualcomm, ARM, MediaTek, Vulkan, and more. +- **🎯 One Export, Multiple Backends** — Switch hardware targets with a single line change. Deploy the same model everywhere. + +## How It Works + +ExecuTorch uses **ahead-of-time (AOT) compilation** to prepare PyTorch models for edge deployment: + +1. **🧩 Export** — Capture your PyTorch model graph with `torch.export()` +2. **⚙️ Compile** — Quantize, optimize, and partition to hardware backends → `.pte` +3. **🚀 Execute** — Load `.pte` on-device via lightweight C++ runtime + +Models use a standardized [Core ATen operator set](https://docs.pytorch.org/executorch/main/compiler-ir-advanced.html#intermediate-representation). [Partitioners](https://docs.pytorch.org/executorch/main/compiler-delegate-and-partitioner.html) delegate subgraphs to specialized hardware (NPU/GPU) with CPU fallback. + +Learn more: [How ExecuTorch Works](https://docs.pytorch.org/executorch/main/intro-how-it-works.html) • [Architecture Guide](https://docs.pytorch.org/executorch/main/getting-started-architecture.html) + +## Quick Start + +### Installation + +```bash +pip install executorch +``` + +For platform-specific setup (Android, iOS, embedded systems), see the [Quick Start](https://docs.pytorch.org/executorch/main/quick-start-section.html) documentation for additional info. + +### Export and Deploy in 3 Steps + +```python +import torch +from executorch.exir import to_edge_transform_and_lower +from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner + +# 1. Export your PyTorch model +model = MyModel().eval() +example_inputs = (torch.randn(1, 3, 224, 224),) +exported_program = torch.export.export(model, example_inputs) + +# 2. Optimize for target hardware (switch backends with one line) +program = to_edge_transform_and_lower( + exported_program, + partitioner=[XnnpackPartitioner()] # CPU | CoreMLPartitioner() for iOS | QnnPartitioner() for Qualcomm +).to_executorch() + +# 3. 
Save for deployment +with open("model.pte", "wb") as f: + f.write(program.buffer) + +# Test locally via ExecuTorch runtime's pybind API (optional) +from executorch.runtime import Runtime +runtime = Runtime.get() +method = runtime.load_program("model.pte").load_method("forward") +outputs = method.execute([torch.randn(1, 3, 224, 224)]) +``` + +### Run on Device + +**[C++](https://docs.pytorch.org/executorch/main/using-executorch-cpp.html)** +```cpp +#include +#include + +Module module("model.pte"); +auto tensor = make_tensor_ptr({2, 2}, {1.0f, 2.0f, 3.0f, 4.0f}); +auto outputs = module.forward(tensor); +``` + +**[Swift (iOS)](https://docs.pytorch.org/executorch/main/ios-section.html)** +```swift +import ExecuTorch + +let module = Module(filePath: "model.pte") +let input = Tensor([1.0, 2.0, 3.0, 4.0], shape: [2, 2]) +let outputs = try module.forward(input) +``` + +**[Kotlin (Android)](https://docs.pytorch.org/executorch/main/android-section.html)** +```kotlin +val module = Module.load("model.pte") +val inputTensor = Tensor.fromBlob(floatArrayOf(1.0f, 2.0f, 3.0f, 4.0f), longArrayOf(2, 2)) +val outputs = module.forward(EValue.from(inputTensor)) +``` -It supports a wide range of models including LLMs (Large Language Models), CV (Computer Vision), ASR (Automatic Speech Recognition), and TTS (Text to Speech). +### LLM Example: Llama -Platform Support: -- Operating Systems: - - iOS - - MacOS (ARM64) - - Android - - Linux - - Microcontrollers +Export Llama models using the [`export_llm`](https://docs.pytorch.org/executorch/main/llm/export-llm.html) script or [Optimum-ExecuTorch](https://github.com/huggingface/optimum-executorch): -- Hardware Acceleration: - - Apple - - Arm - - Cadence - - MediaTek - - NXP - - OpenVINO - - Qualcomm - - Vulkan - - XNNPACK +```bash +# Using export_llm +python -m executorch.extension.llm.export.export_llm --model llama3_2 --output llama.pte -Key value propositions of ExecuTorch are: +# Using Optimum-ExecuTorch +optimum-cli export executorch \ + --model meta-llama/Llama-3.2-1B \ + --task text-generation \ + --recipe xnnpack \ + --output_dir llama_model +``` -- **Portability:** Compatibility with a wide variety of computing platforms, - from high-end mobile phones to highly constrained embedded systems and - microcontrollers. -- **Productivity:** Enabling developers to use the same toolchains and Developer - Tools from PyTorch model authoring and conversion, to debugging and deployment - to a wide variety of platforms. -- **Performance:** Providing end users with a seamless and high-performance - experience due to a lightweight runtime and utilizing full hardware - capabilities such as CPUs, NPUs, and DSPs. +Run on-device with the LLM runner API: -## Getting Started -To get started you can: +**[C++](https://docs.pytorch.org/executorch/main/llm/run-with-c-plus-plus.html)** +```cpp +#include -- Visit the [Step by Step Tutorial](https://pytorch.org/executorch/stable/getting-started.html) to get things running locally and deploy a model to a device -- Use this [Colab Notebook](https://colab.research.google.com/drive/1qpxrXC3YdJQzly3mRg-4ayYiOjC6rue3?usp=sharing) to start playing around right away -- Jump straight into LLM use cases by following specific instructions for popular open-source models such as [Llama](examples/models/llama/README.md), [Qwen 3](examples/models/qwen3/README.md), [Phi-4-mini](examples/models/phi_4_mini/README.md), [Llava](examples/models/llava/README.md), [Voxtral](examples/models/voxtral/README.md), and [LFM2](examples/models/lfm2/README.md). 
+auto runner = create_llama_runner("llama.pte", "tiktoken.bin"); +executorch::extension::llm::GenerationConfig config{ + .seq_len = 128, .temperature = 0.8f}; +runner->generate("Hello, how are you?", config); +``` -## Feedback and Engagement +**[Swift (iOS)](https://docs.pytorch.org/executorch/main/llm/run-on-ios.html)** +```swift +import ExecuTorchLLM -We welcome any feedback, suggestions, and bug reports from the community to help -us improve our technology. Check out the [Discussion Board](https://github.com/pytorch/executorch/discussions) or chat real time with us on [Discord](https://discord.gg/Dh43CKSAdc) +let runner = TextRunner(modelPath: "llama.pte", tokenizerPath: "tiktoken.bin") +try runner.generate("Hello, how are you?", Config { + $0.sequenceLength = 128 +}) { token in + print(token, terminator: "") +} +``` -## Contributing +**Kotlin (Android)** — [API Docs](https://docs.pytorch.org/executorch/main/javadoc/org/pytorch/executorch/extension/llm/package-summary.html) • [Demo App](https://github.com/meta-pytorch/executorch-examples/tree/main/llm/android/LlamaDemo) +```kotlin +val llmModule = LlmModule("llama.pte", "tiktoken.bin", 0.8f) +llmModule.load() +llmModule.generate("Hello, how are you?", 128, object : LlmCallback { + override fun onResult(result: String) { print(result) } + override fun onStats(stats: String) { } +}) +``` -We welcome contributions. To get started review the [guidelines](CONTRIBUTING.md) and chat with us on [Discord](https://discord.gg/Dh43CKSAdc) +For multimodal models (vision, audio), use the [MultiModal runner API](extension/llm/runner) which extends the LLM runner to handle image and audio inputs alongside text. See [Llava](examples/models/llava/README.md) and [Voxtral](examples/models/voxtral/README.md) examples. +See [examples/models/llama](examples/models/llama/README.md) for complete workflow including quantization, mobile deployment, and advanced options. -## Directory Structure +**Next Steps:** +- 📖 [Step-by-step tutorial](https://docs.pytorch.org/executorch/main/getting-started.html) — Complete walkthrough for your first model +- ⚡ [Colab notebook](https://colab.research.google.com/drive/1qpxrXC3YdJQzly3mRg-4ayYiOjC6rue3?usp=sharing) — Try ExecuTorch instantly in your browser +- 🤖 [Deploy Llama models](examples/models/llama/README.md) — LLM workflow with quantization and mobile demos -Please refer to the [Codebase structure](CONTRIBUTING.md#codebase-structure) section of the [Contributing Guidelines](CONTRIBUTING.md) for more details. +## Platform & Hardware Support + +| **Platform** | **Supported Backends** | +|------------------|----------------------------------------------------------| +| Android | XNNPACK, Vulkan, Qualcomm, MediaTek, Samsung Exynos | +| iOS | XNNPACK, MPS, CoreML (Neural Engine) | +| Linux / Windows | XNNPACK, OpenVINO, CUDA *(experimental)* | +| macOS | XNNPACK, MPS, Metal *(experimental)* | +| Embedded / MCU | XNNPACK, ARM Ethos-U, NXP, Cadence DSP | + +See [Backend Documentation](https://docs.pytorch.org/executorch/main/backends-overview.html) for detailed hardware requirements and optimization guides. + +## Production Deployments + +ExecuTorch powers on-device AI at scale across Meta's family of apps, VR/AR devices, and partner deployments. 
[View success stories →](https://docs.pytorch.org/executorch/main/success-stories.html) + +## Examples & Models + +**LLMs:** [Llama 3.2/3.1/3](examples/models/llama/README.md), [Qwen 3](examples/models/qwen3/README.md), [Phi-4-mini](examples/models/phi_4_mini/README.md), [LiquidAI LFM2](examples/models/lfm2/README.md) + +**Multimodal:** [Llava](examples/models/llava/README.md) (vision-language), [Voxtral](examples/models/voxtral/README.md) (audio-language), [Gemma](examples/models/gemma3) (vision-language) + +**Vision/Speech:** [MobileNetV2](https://github.com/meta-pytorch/executorch-examples/tree/main/mv2), [DeepLabV3](https://github.com/meta-pytorch/executorch-examples/tree/main/dl3), [Whisper](https://github.com/meta-pytorch/executorch-examples/tree/main/whisper/android/WhisperApp) + +**Resources:** [`examples/`](examples/) directory • [executorch-examples](https://github.com/meta-pytorch/executorch-examples) out-of-tree demos • [Optimum-ExecuTorch](https://github.com/huggingface/optimum-executorch) for HuggingFace models + +## Key Features + +ExecuTorch provides advanced capabilities for production deployment: + +- **Quantization** — Built-in support via [torchao](https://docs.pytorch.org/ao) for 8-bit, 4-bit, and dynamic quantization +- **Memory Planning** — Optimize memory usage with ahead-of-time allocation strategies +- **Developer Tools** — ETDump profiler, ETRecord inspector, and model debugger +- **Selective Build** — Strip unused operators to minimize binary size +- **Custom Operators** — Extend with domain-specific kernels +- **Dynamic Shapes** — Support variable input sizes with bounded ranges + +See [Advanced Topics](https://docs.pytorch.org/executorch/main/advanced-topics-section.html) for quantization techniques, custom backends, and compiler passes. + +## Documentation + +- [**Documentation Home**](https://docs.pytorch.org/executorch/main/index.html) — Complete guides and tutorials +- [**API Reference**](https://docs.pytorch.org/executorch/main/api-section.html) — Python, C++, Java/Kotlin APIs +- [**Backend Integration**](https://docs.pytorch.org/executorch/main/backend-delegates-integration.html) — Build custom hardware backends +- [**Troubleshooting**](https://docs.pytorch.org/executorch/main/support-section.html) — Common issues and solutions + +## Community & Contributing + +We welcome contributions from the community! + +- 💬 [**GitHub Discussions**](https://github.com/pytorch/executorch/discussions) — Ask questions and share ideas +- 🎮 [**Discord**](https://discord.gg/Dh43CKSAdc) — Chat with the team and community +- 🐛 [**Issues**](https://github.com/pytorch/executorch/issues) — Report bugs or request features +- 🤝 [**Contributing Guide**](CONTRIBUTING.md) — Guidelines and codebase structure ## License -ExecuTorch is BSD licensed, as found in the LICENSE file. + +ExecuTorch is BSD licensed, as found in the [LICENSE](LICENSE) file. + +
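The Key Features list above mentions bounded dynamic shapes but does not show the export-side API. Below is a minimal sketch of declaring a bounded dynamic dimension at export time and lowering it through the same Quick Start flow; the `TextEncoder` module and the `seq_len` bound are hypothetical stand-ins, not part of this PR.

```python
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class TextEncoder(torch.nn.Module):
    """Hypothetical toy model; stands in for any sequence model."""

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return tokens.float().mean(dim=-1, keepdim=True)


model = TextEncoder().eval()
example_inputs = (torch.zeros(1, 16, dtype=torch.long),)

# Declare the sequence dimension (dim 1 of `tokens`) as dynamic with a bound.
seq_len = torch.export.Dim("seq_len", min=2, max=128)
exported = torch.export.export(
    model, example_inputs, dynamic_shapes={"tokens": {1: seq_len}}
)

# Lower exactly as in the Quick Start; the bound carries through to the .pte.
program = to_edge_transform_and_lower(
    exported, partitioner=[XnnpackPartitioner()]
).to_executorch()
with open("text_encoder.pte", "wb") as f:
    f.write(program.buffer)
```

At runtime the method then accepts any sequence length up to the declared bound, which is what ahead-of-time memory planning sizes buffers against.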

[README.md footer diff: a closing horizontal rule and a centered block reading "Part of the PyTorch ecosystem" with GitHub and Documentation links; the HTML markup did not survive extraction.]
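The Quick Start above notes that switching hardware targets is a one-line change. Here is a short sketch of the same lowering call retargeted from XNNPACK to Core ML; the `CoreMLPartitioner` import path is assumed from the Core ML backend package, and running it requires `coremltools` (typically on macOS).

```python
import torch
from executorch.backends.apple.coreml.partition import CoreMLPartitioner  # import path assumed
from executorch.exir import to_edge_transform_and_lower

# Small stand-in model; any exported program works the same way.
model = torch.nn.Sequential(torch.nn.Linear(32, 16), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 32),)
exported_program = torch.export.export(model, example_inputs)

# Same call as the Quick Start, with only the partitioner swapped.
program = to_edge_transform_and_lower(
    exported_program,
    partitioner=[CoreMLPartitioner()],  # was: XnnpackPartitioner()
).to_executorch()

with open("model_coreml.pte", "wb") as f:
    f.write(program.buffer)
```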
diff --git a/backends/aoti/CMakeLists.txt b/backends/aoti/CMakeLists.txt new file mode 100644 index 00000000000..d5582dfe7c7 --- /dev/null +++ b/backends/aoti/CMakeLists.txt @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Build AOTI backend for runtime. +# +# ### Editing this file ### +# +# This file should be formatted with +# ~~~ +# cmake-format -i CMakeLists.txt +# ~~~ +# It should also be cmake-lint clean. +# +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +# Source root directory for executorch. +if(NOT EXECUTORCH_ROOT) + set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..) +endif() + +# Use ExecuTorch's standard way to find PyTorch libraries for AOTI +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) +find_package_torch() + +# Common AOTI functionality - combines all AOTI common components +set(_aoti_common_sources common_shims.cpp) +add_library(aoti_common STATIC ${_aoti_common_sources}) +target_include_directories( + aoti_common + PUBLIC $ $ + $ +) +target_compile_options( + aoti_common + PUBLIC $<$:/EHsc /GR> + $<$>:-fexceptions -frtti -fPIC> +) +target_compile_definitions( + aoti_common PRIVATE $<$:EXPORT_AOTI_FUNCTIONS> +) +# Ensure symbols are exported properly +if(APPLE) + target_link_options(aoti_common PUBLIC -Wl,-export_dynamic) +else() + target_link_options( + aoti_common PUBLIC $<$>:-Wl,--export-dynamic> + ) +endif() + +# Link against ExecuTorch libraries and standard libraries +target_link_libraries(aoti_common PUBLIC extension_tensor ${CMAKE_DL_LIBS}) +executorch_target_link_options_shared_lib(aoti_common) + +install( + TARGETS aoti_common + EXPORT ExecuTorchTargets + DESTINATION ${CMAKE_INSTALL_LIBDIR} +) diff --git a/backends/aoti/README.md b/backends/aoti/README.md new file mode 100644 index 00000000000..74b45a35e5d --- /dev/null +++ b/backends/aoti/README.md @@ -0,0 +1,28 @@ +# AOTI Common Library + +This directory contains **common library components** for AOTI (Ahead-of-Time Inference) driven backends in ExecutorTorch, **not a standalone backend**. + +## Purpose + +The code in this directory provides shared functionality and utilities that are used by actual AOTI-driven backends such as: + +- **CUDA backend** - Uses AOTI for GPU acceleration +- Other AOTI-powered backends + +## Components + +- **`common_shims.cpp/h`** - Common shim functions that bridge ExecuTorch tensor operations with AOTI requirements +- **`aoti_model_container.cpp/h`** - Model container functionality for AOTI models +- **`utils.h`** - Utility functions and type definitions +- **`tests/`** - Unit tests for the common functionality + +## Usage + +This library is intended to be used as a dependency by actual AOTI backend implementations. It is not a backend that can be used directly for model execution. + +For example backend implementations that use this common library, see: +- `executorch/backends/cuda/` - CUDA AOTI backend + +## Building + +The common library components are built as part of the AOTI backend build process. See the `TARGETS` file for build configurations. 
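The new `backends/aoti` README above describes a common library that concrete AOTInductor-driven backends build on. Below is a sketch of what such a backend could look like when wired from the `AotiBackend` mixin and `AotiPartitioner` base class added in this PR; `MyAotiBackend` and `MyAotiPartitioner` are hypothetical names, and the real integration lives in `executorch/backends/cuda/`.

```python
from typing import Any, Dict, List

from executorch.backends.aoti.aoti_backend import AotiBackend
from executorch.backends.aoti.aoti_partitioner import AotiPartitioner
from executorch.exir.backend.backend_details import BackendDetails
from executorch.exir.backend.compile_spec_schema import CompileSpec


class MyAotiBackend(BackendDetails, AotiBackend):
    """Hypothetical AOTInductor-driven backend targeting CUDA."""

    @classmethod
    def get_device_name(cls) -> str:
        return "cuda"

    @classmethod
    def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
        # Only kernels listed here are allowed; anything else makes
        # preprocess() raise with the list of missing fallback kernels.
        return {}

    @classmethod
    def get_decomposition_table(cls) -> Dict[Any, Any]:
        return {}  # empty: let AOTInductor's own lowering handle decomposition

    @classmethod
    def get_aoti_compile_options(
        cls, compile_specs: List[CompileSpec]
    ) -> Dict[str, Any]:
        return {}  # backend-specific options passed to torch._inductor.aot_compile

    @classmethod
    def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[Any]:
        return []  # graph passes applied after ReplaceViewCopyWithViewPass


class MyAotiPartitioner(AotiPartitioner):
    def __init__(self, compile_specs: List[CompileSpec]) -> None:
        # The backend name is the identifier used to look up the
        # BackendDetails subclass above during lowering.
        super().__init__("MyAotiBackend", compile_specs)


# Typical use: tag the whole graph for the AOTI backend during lowering.
# partitioner = MyAotiPartitioner(
#     [MyAotiBackend.generate_method_name_compile_spec("forward")]
# )
```

The compile spec carries the method name that `preprocess()` later reads back via `method_name_from_compile_specs()` when naming the shared-object and weight blobs it stores in the named data store.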
diff --git a/backends/aoti/TARGETS b/backends/aoti/TARGETS new file mode 100644 index 00000000000..77871de4469 --- /dev/null +++ b/backends/aoti/TARGETS @@ -0,0 +1,3 @@ +load("targets.bzl", "define_common_targets") + +define_common_targets() diff --git a/backends/aoti/aoti_backend.py b/backends/aoti/aoti_backend.py new file mode 100644 index 00000000000..c2c587da9fe --- /dev/null +++ b/backends/aoti/aoti_backend.py @@ -0,0 +1,268 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import os +import typing +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, Dict, List, Set + +import torch +from executorch.backends.aoti.passes.replace_view_copy_with_view import ( + ReplaceViewCopyWithViewPass, +) +from executorch.exir._serialize._named_data_store import NamedDataStore +from executorch.exir._warnings import experimental +from executorch.exir.backend.backend_details import ExportedProgram, PreprocessResult +from executorch.exir.backend.compile_spec_schema import CompileSpec +from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu +from torch.export.passes import move_to_device_pass + + +class COMPILE_SPEC_KEYS(Enum): + METHOD_NAME = "method_name" + + +@experimental( + "This API and all of aoti-driven backend related functionality are experimental." +) +class AotiBackend(ABC): + """ + Base mixin class for AOTInductor-based backends. + + This class provides common functionality for compiling models using AOTInductor + with different device targets (CUDA, Metal, etc.). + + This is a mixin class, not an actual backend object, for aoti-driven backends. + Concrete backends (e.g., CudaBackend, MetalBackend) should inherit from both + BackendDetails and AotiBackend to get the full functionality. + """ + + @classmethod + @abstractmethod + def get_device_name(cls) -> str: + """Return the device name for this backend (e.g., 'cuda', 'metal').""" + pass + + @classmethod + @abstractmethod + def get_supported_fallback_kernels(cls) -> Dict[str, Any]: + """Return the set of supported fallback kernels for this backend.""" + pass + + @classmethod + @abstractmethod + def get_decomposition_table(cls) -> Dict[Any, Any]: + """Return the decomposition table for this backend.""" + pass + + @classmethod + @abstractmethod + def get_aoti_compile_options( + cls, compile_specs: List[CompileSpec] + ) -> Dict[str, typing.Any]: + """Return the AOTInductor compilation options for this backend.""" + pass + + @classmethod + @abstractmethod + def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]: + """Return the list of custom passes to apply after ReplaceViewCopyWithViewPass and before decomposition.""" + pass + + @classmethod + def get_extra_aoti_compile_context_manager(cls): + """Return extra context manager to apply during aoti_compile stage. By default returns an empty context manager.""" + return contextlib.nullcontext() + + @classmethod + @contextlib.contextmanager + def collect_unsupported_fallback_kernels(cls, missing_fallback_kernels: Set[str]): + """ + Context manager to collect unsupported fallback kernels during compilation. + Monitors both extern kernel calls and runtime lookup. 
+ """ + supported_kernels = cls.get_supported_fallback_kernels() + + original_generate_c_shim_extern_kernel_call = ( + CppWrapperCpu.generate_c_shim_extern_kernel_call + ) + original_generate_fallback_kernel_with_runtime_lookup_aot = ( + CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot + ) + + def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels( + self, kernel: str, *args: Any, **kwargs: Any + ) -> None: + if kernel not in supported_kernels: + missing_fallback_kernels.add(kernel) + + return original_generate_c_shim_extern_kernel_call( + self, kernel, *args, **kwargs + ) + + def generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels( + self, op_overload: Any, *args: Any, **kwargs: Any + ) -> None: + kernel_name = getattr(op_overload, "_name", str(op_overload)) + if kernel_name not in supported_kernels: + missing_fallback_kernels.add(kernel_name) + + return original_generate_fallback_kernel_with_runtime_lookup_aot( + self, op_overload, *args, **kwargs + ) + + CppWrapperCpu.generate_c_shim_extern_kernel_call = ( + generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels + ) + CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels + + try: + yield + finally: + CppWrapperCpu.generate_c_shim_extern_kernel_call = ( + original_generate_c_shim_extern_kernel_call + ) + CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = ( + original_generate_fallback_kernel_with_runtime_lookup_aot + ) + + @classmethod + def preprocess( + cls, + edge_program: ExportedProgram, + compile_specs: List[CompileSpec], + ) -> PreprocessResult: + """ + Preprocess the edge program and compile it using AOTInductor. + Weights are always separated from the SO file. 
+ """ + device_name = cls.get_device_name() + decomposition_table = cls.get_decomposition_table() + options = cls.get_aoti_compile_options(compile_specs) + + # Move the edge_program to the target device + device_edge_program = move_to_device_pass( + edge_program, device_name if device_name != "metal" else "mps" + ) + + # Replace view_copy with view + ReplaceViewCopyWithViewPass()(device_edge_program.graph_module) + + # Apply custom backend-specific passes + custom_passes = cls.get_custom_passes(compile_specs) + for custom_pass in custom_passes: + custom_pass(device_edge_program.graph_module) + + # Run decompositions if any + if decomposition_table: + device_edge_program = device_edge_program.run_decompositions( + decomposition_table + ) + + edge_program_module = device_edge_program.module() + + # Grab all input placeholders from the graph + user_input_names = device_edge_program.graph_signature.user_inputs + user_input_placeholders = [] + for node in device_edge_program.graph.nodes: + if node.op == "placeholder" and node.name in user_input_names: + user_input_placeholders.append(node.meta["val"]) + + # Track missing fallback kernels + missing_fallback_kernels: Set[str] = set() + + # Compile with fallback kernel collection + with cls.collect_unsupported_fallback_kernels( + missing_fallback_kernels + ), torch.no_grad(), cls.get_extra_aoti_compile_context_manager(): + paths = torch._inductor.aot_compile( + edge_program_module, tuple(user_input_placeholders), options=options + ) + + if len(missing_fallback_kernels) > 0: + formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels)) + method_name = cls.method_name_from_compile_specs(compile_specs) + raise RuntimeError( + f"Method {method_name} missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n" + "Please add them to the AOTI backend." + ) + + # Extract paths - weights are always separated + so_path = None + blob_path = None + + if isinstance(paths, list): + for path in paths: + if path.endswith(".wrapper.so"): + so_path = path + elif path.endswith(".wrapper_weights.blob"): + blob_path = path + else: + so_path = paths + + if so_path is None or blob_path is None: + raise RuntimeError( + f"Could not find required files in compiled paths, got {paths}" + ) + + # Read SO file + with open(so_path, "rb") as f: + so_data = f.read() + + # Read weights blob + with open(blob_path, "rb") as f: + blob_data = f.read() + + # Create named data store + named_data_store = NamedDataStore() + method_name = cls.method_name_from_compile_specs(compile_specs) + + # Add SO and weights blob separately + named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None) + weights_blob_data_type = f"aoti_{device_name}_blob" + named_data_store.add_named_data( + method_name + "_weights_blob", blob_data, 1, weights_blob_data_type + ) + + # Clean up the generated files + os.remove(so_path) + os.remove(blob_path) + + return PreprocessResult( + processed_bytes=b"", + debug_handle_map={}, + data_store_output=named_data_store.get_named_data_store_output(), + ) + + @classmethod + def generate_method_name_compile_spec( + cls, + method_name: str, + ) -> CompileSpec: + """ + Generate a CompileSpec for the given method name. + """ + return CompileSpec( + COMPILE_SPEC_KEYS.METHOD_NAME.value, + method_name.encode("utf-8"), + ) + + @classmethod + def method_name_from_compile_specs( + cls, + compile_specs: List[CompileSpec], + ) -> str: + """ + Extract the method name from the compile specs. 
+ """ + for spec in compile_specs: + if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value: + return spec.value.decode("utf-8") + raise RuntimeError( + f"Could not find method name in compile specs: {compile_specs}" + ) diff --git a/backends/aoti/aoti_delegate_handle.h b/backends/aoti/aoti_delegate_handle.h new file mode 100644 index 00000000000..82ce2521750 --- /dev/null +++ b/backends/aoti/aoti_delegate_handle.h @@ -0,0 +1,100 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace executorch { +namespace backends { +namespace aoti { + +using executorch::runtime::Error; +using executorch::runtime::etensor::Tensor; + +extern "C" { + +// Type definitions +using AOTITensorHandle = Tensor*; +using AOTIRuntimeError = Error; + +// Forward declarations for AOT Inductor model container +struct AOTInductorModelContainerOpaque; +using AOTInductorModelContainerHandle = AOTInductorModelContainerOpaque*; +using AOTInductorStreamHandle = void*; +using AOTIProxyExecutorHandle = void*; + +// Function pointer types for AOT Inductor model container operations +using AOTInductorModelContainerCreateWithDeviceFunc = AOTIRuntimeError (*)( + AOTInductorModelContainerHandle* container_handle, + size_t num_models, + const char* device_str, + const char* cubin_dir); + +using AOTInductorModelContainerDeleteFunc = + AOTIRuntimeError (*)(AOTInductorModelContainerHandle container_handle); + +using AOTInductorModelContainerGetNumInputsFunc = AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + size_t* num_inputs); + +using AOTInductorModelContainerGetNumOutputsFunc = AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + size_t* num_outputs); + +using AOTInductorModelContainerRunFunc = AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + Tensor** input_handles, // array of input Tensor*; handles + // are stolen; the array itself is borrowed + size_t num_inputs, + Tensor** output_handles, // array for writing output Tensor*; handles + // will be stolen by the caller; the array itself + // is borrowed + size_t n_outputs, + AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle); + +// Retrieves the name of an input tensor by index from the AOTI model container. +using AOTInductorModelContainerGetInputNameFunc = AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + size_t input_idx, + const char** input_name); + +// Retrieves the number of constants from the AOTI model container. 
+using AOTInductorModelContainerGetNumConstantsFunc = AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + size_t* num_constants); + +// Update the model container with the constant tensors +using AOTInductorModelUpdateConstantsFromBlobFunc = AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + const uint8_t* weight_blob_ptr); + +} // extern "C" + +// AOTI Delegate Handle structure +struct AOTIDelegateHandle { + void* so_handle; + std::string so_path; + AOTInductorModelContainerHandle container_handle; + void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header + // dependency + + // Function pointers specific to this handle's shared library + AOTInductorModelContainerCreateWithDeviceFunc create_with_device; + AOTInductorModelContainerDeleteFunc delete_container; + AOTInductorModelContainerGetNumInputsFunc get_num_inputs; + AOTInductorModelContainerGetNumOutputsFunc get_num_outputs; + AOTInductorModelContainerRunFunc run; + AOTInductorModelUpdateConstantsFromBlobFunc update_constants_from_blob; +}; + +} // namespace aoti +} // namespace backends +} // namespace executorch diff --git a/backends/aoti/aoti_partitioner.py b/backends/aoti/aoti_partitioner.py new file mode 100644 index 00000000000..aa56d3507e9 --- /dev/null +++ b/backends/aoti/aoti_partitioner.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Dict, List, Optional, Tuple + +import torch +from executorch.exir._warnings import experimental +from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.backend.partitioner import ( + DelegationSpec, + Partitioner, + PartitionResult, +) +from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer +from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param +from torch.export.exported_program import ExportedProgram + + +@experimental( + "This API and all of cuda backend related functionality are experimental." +) +class AotiPartitioner(Partitioner): + """ + Base partitioner for AOTInductor-driven backend integration. + + This partitioner creates a single partition containing all operators from the input graph. + It skips core ATen decomposition, allowing the backend to handle decomposition using + AOTInductor's backend-specific decomposition table. + + Only operators that cannot be handled by the aoti library will be excluded from + the partition and fall back to ExecuTorch's default or custom handling. + """ + + def __init__(self, backend_name: str, compile_spec: List[CompileSpec]) -> None: + """ + Initialize the AOTI partitioner. + + Args: + backend_name: The name of the backend (e.g., "CudaBackend", "MetalBackend") + compile_spec: List of compilation specifications + """ + self.delegation_spec = DelegationSpec(backend_name, compile_spec) + + def partition(self, exported_program: ExportedProgram) -> PartitionResult: + """ + Fully delegate the graph to AOTInductor by tagging all nodes as a single partition. 
+ """ + + partition_tags: Dict[str, DelegationSpec] = {} + tag = "tag0" + + # Tag torch.cond and other control flow operations + def is_control_flow(node: torch.fx.Node) -> bool: + return node.op == "call_function" and node.target in [ + torch.ops.higher_order.cond, + torch.ops.higher_order.map_impl, + torch.ops.higher_order.while_loop, + ] + + for node in exported_program.graph.nodes: + if node.op == "call_function": + node.meta["delegation_tag"] = tag + # Tag get_attr nodes that are used by control flow operations + elif node.op == "get_attr": + # Check if any user is a control flow operation + for user in node.users: + if is_control_flow(user): + node.meta["delegation_tag"] = tag + break + + partition_tags[tag] = self.delegation_spec + + tag_constant_data(exported_program) + tag_mutated_buffer(exported_program) + + # Tag constant placeholders that have no users + # tag_constant_data only tags constants that have users with delegation_tag + # but we need to tag all constants for this partition + for node in exported_program.graph.nodes: + if node.op == "placeholder" and ( + is_param(exported_program, node) + or is_buffer(exported_program, node) + or is_lifted_tensor_constant(exported_program, node) + ): + if "delegation_tag" not in node.meta: + node.meta["delegation_tag"] = tag + + return PartitionResult( + tagged_exported_program=exported_program, partition_tags=partition_tags + ) + + def ops_to_not_decompose( + self, ep: ExportedProgram + ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: + """ + Return a list of operations that should not be decomposed and let the AOT compiler handle them. + Currently we skip ATen decompositon for all ops, and let the backend handle them. + """ + do_not_decompose = set() + + for node in ep.graph.nodes: + if node.op == "call_function" and isinstance( + node.target, torch._ops.OpOverload + ): + do_not_decompose.add(node.target) + return list(do_not_decompose), None diff --git a/backends/aoti/common_shims.cpp b/backends/aoti/common_shims.cpp new file mode 100644 index 00000000000..abfde86db6d --- /dev/null +++ b/backends/aoti/common_shims.cpp @@ -0,0 +1,268 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include + +namespace executorch { +namespace backends { +namespace aoti { + +namespace internal { +// Global storage for tensor metadata +AOTI_SHIM_EXPORT std::unordered_map> + tensor_to_sizes; +AOTI_SHIM_EXPORT std::unordered_map> + tensor_to_strides; +} // namespace internal + +extern "C" { + +// Autograd mode functions +int32_t aoti_torch_grad_mode_is_enabled() { + // No autograd ever + return false; +} + +void aoti_torch_grad_mode_set_enabled(bool enabled) { + if (enabled) { + throw std::runtime_error("Cannot enable autograd"); + } +} + +// Tensor attribute operations +AOTITorchError aoti_torch_get_data_ptr(Tensor* tensor, void** ret_data_ptr) { + *ret_data_ptr = tensor->mutable_data_ptr(); + return Error::Ok; +} + +AOTITorchError aoti_torch_get_storage_offset( + Tensor* tensor, + int64_t* ret_storage_offset) { + // Storage offset is always 0 in ET + *ret_storage_offset = 0; + + return Error::Ok; +} + +AOTITorchError aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides) { + auto it = internal::tensor_to_strides.find(tensor); + bool needs_update = false; + + if (it == internal::tensor_to_strides.end()) { + needs_update = true; + } else { + // CRITICAL: Multimodal models reuse tensors with different shapes across + // executions (e.g., variable-length audio). We MUST validate cached + // metadata matches current tensor state, or CUDA kernels will receive + // incorrect shapes leading to memory corruption and segfaults. + auto tensor_strides = tensor->strides(); + needs_update = !std::equal( + it->second.begin(), + it->second.end(), + tensor_strides.begin(), + tensor_strides.end()); + } + + if (needs_update) { + std::vector strides(tensor->dim()); + auto tensor_strides = tensor->strides(); + for (int i = 0; i < tensor->dim(); i++) { + strides[i] = tensor_strides[i]; + } + it = + internal::tensor_to_strides.insert_or_assign(tensor, std::move(strides)) + .first; + } + + // For 0D tensors, data() returns nullptr on empty vectors, but we need to + // return a valid pointer + if (it->second.empty()) { + static int64_t empty_strides_placeholder = 0; + *ret_strides = &empty_strides_placeholder; + } else { + *ret_strides = it->second.data(); + } + + return Error::Ok; +} + +AOTITorchError aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype) { + *ret_dtype = static_cast(tensor->scalar_type()); + + return Error::Ok; +} + +AOTITorchError aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes) { + auto it = internal::tensor_to_sizes.find(tensor); + bool needs_update = false; + + if (it == internal::tensor_to_sizes.end()) { + needs_update = true; + } else { + // CRITICAL: Multimodal models reuse tensors with different shapes across + // executions (e.g., variable-length audio). We MUST validate cached + // metadata matches current tensor state, or CUDA kernels will receive + // incorrect shapes leading to memory corruption and segfaults. 
+ auto tensor_sizes = tensor->sizes(); + needs_update = !std::equal( + it->second.begin(), + it->second.end(), + tensor_sizes.begin(), + tensor_sizes.end()); + } + + if (needs_update) { + std::vector sizes(tensor->dim()); + auto tensor_sizes = tensor->sizes(); + for (int i = 0; i < tensor->dim(); i++) { + sizes[i] = tensor_sizes[i]; + } + it = internal::tensor_to_sizes.insert_or_assign(tensor, std::move(sizes)) + .first; + } + + // For 0D tensors, data() returns nullptr on empty vectors, but we need to + // return a valid pointer + if (it->second.empty()) { + static int64_t empty_sizes_placeholder = 0; + *ret_sizes = &empty_sizes_placeholder; + } else { + *ret_sizes = it->second.data(); + } + + return Error::Ok; +} + +AOTITorchError aoti_torch_get_device_index( + Tensor* tensor, + int32_t* ret_device_index) { + // Let's assume all tensors AOTI using are on CUDA:0 + *ret_device_index = 0; + return Error::Ok; +} + +AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim) { + *ret_dim = static_cast(tensor->dim()); + return Error::Ok; +} + +// Device and layout utility functions +int32_t aoti_torch_device_type_cpu() { + // Let's say cpu is 0 for ET as well + return 0; +} + +int32_t aoti_torch_layout_strided() { + // ET only support strided layout, the return value will always be 0, a.k.a + // at::Layout::Strided; + return 0; +} + +// Dtype constants - these return the PyTorch dtype codes +int32_t aoti_torch_dtype_float32() { + return 6; // PyTorch's float32 dtype code +} + +int32_t aoti_torch_dtype_bfloat16() { + return 15; // PyTorch's bfloat16 dtype code +} + +int32_t aoti_torch_dtype_int8() { + return 1; // PyTorch's int32 dtype code +} + +int32_t aoti_torch_dtype_int16() { + return 2; // PyTorch's int32 dtype code +} + +int32_t aoti_torch_dtype_int32() { + return 3; // PyTorch's int32 dtype code +} + +int32_t aoti_torch_dtype_bool() { + return 11; // PyTorch's bool dtype code +} + +int32_t aoti_torch_dtype_int64() { + return 4; // PyTorch's int64 dtype code +} + +// Dtype utility function needed by Metal backend. +// Returns the size of the dtype in bytes. 
+size_t aoti_torch_dtype_element_size(int32_t dtype) { + return dtype_to_element_size(dtype); +} + +// Cleanup functions +void cleanup_tensor_metadata() { + internal::tensor_to_sizes.clear(); + internal::tensor_to_strides.clear(); +} + +AOTI_SHIM_EXPORT void aoti_torch_warn( + const char* func, + const char* file, + uint32_t line, + const char* msg) { + ET_LOG(Error, "[%s:%u] %s: %s", file, line, func, msg); +} + +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size) { + (void)tensor; + (void)ret_size; + throw std::runtime_error("Not implemented"); + return Error::Internal; +} + +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_clone_preserve_strides(Tensor* self, Tensor** ret_new_tensor) { + (void)self; + (void)ret_new_tensor; + throw std::runtime_error("Not implemented"); + return Error::Internal; +} + +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_clone(Tensor* self, Tensor** ret_new_tensor) { + (void)self; + (void)ret_new_tensor; + throw std::runtime_error("Not implemented"); + return Error::Internal; +} + +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob( + void* data_ptr, + int64_t ndim, + const int64_t* sizes, + const int64_t* strides, + int64_t storage_offset, + int32_t dtype, + int32_t device_type, + int32_t device_index, + Tensor** ret_new_tensor) { + (void)data_ptr; + (void)ndim; + (void)sizes; + (void)strides; + (void)storage_offset; + (void)dtype; + (void)device_type; + (void)device_index; + (void)ret_new_tensor; + throw std::runtime_error("Not implemented"); + return Error::Internal; +} + +} // extern "C" + +} // namespace aoti +} // namespace backends +} // namespace executorch diff --git a/backends/aoti/common_shims.h b/backends/aoti/common_shims.h new file mode 100644 index 00000000000..675a9864e74 --- /dev/null +++ b/backends/aoti/common_shims.h @@ -0,0 +1,112 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace aoti { + +// Common using declarations for ExecuTorch types +using executorch::runtime::Error; +using executorch::runtime::etensor::Tensor; + +// Global storage for tensor metadata +extern std::unordered_map> tensor_to_sizes; +extern std::unordered_map> tensor_to_strides; + +extern "C" { + +// Common AOTI type aliases +using AOTIRuntimeError = Error; +using AOTITorchError = Error; + +// Attribute-related operations (memory-irrelevant) +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_data_ptr(Tensor* tensor, void** ret_data_ptr); + +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_storage_offset(Tensor* tensor, int64_t* ret_storage_offset); + +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides); + +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype); + +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes); + +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size); + +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_device_index(Tensor* tensor, int32_t* ret_device_index); + +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim); + +// Utility functions for device and layout information +AOTI_SHIM_EXPORT int32_t aoti_torch_device_type_cpu(); +AOTI_SHIM_EXPORT int32_t aoti_torch_layout_strided(); +AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_float32(); +AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bfloat16(); +AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int8(); +AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int16(); +AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int32(); +AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int64(); + +// Dtype utility function needed by Metal backend +AOTI_SHIM_EXPORT size_t aoti_torch_dtype_element_size(int32_t dtype); + +// Autograd mode functions +AOTI_SHIM_EXPORT int32_t aoti_torch_grad_mode_is_enabled(); +AOTI_SHIM_EXPORT void aoti_torch_grad_mode_set_enabled(bool enabled); + +// Cleanup functions for clearing global state +AOTI_SHIM_EXPORT void cleanup_tensor_metadata(); + +AOTI_SHIM_EXPORT void aoti_torch_warn( + const char* func, + const char* file, + uint32_t line, + const char* msg); + +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size); + +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_clone_preserve_strides(Tensor* self, Tensor** ret_new_tensor); + +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_clone(Tensor* self, Tensor** ret_new_tensor); + +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob( + void* data_ptr, + int64_t ndim, + const int64_t* sizes, + const int64_t* strides, + int64_t storage_offset, + int32_t dtype, + int32_t device_type, + int32_t device_index, + Tensor** ret_new_tensor); + +} // extern "C" + +} // namespace aoti +} // namespace backends +} // namespace executorch diff --git a/backends/aoti/export.h b/backends/aoti/export.h new file mode 100644 index 00000000000..7c945f405b0 --- /dev/null +++ b/backends/aoti/export.h @@ -0,0 +1,26 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +// Define export macro for Windows DLL +// When building the aoti_cuda_backend library, EXPORT_AOTI_FUNCTIONS is defined +// by CMake, which causes this macro to export symbols using +// __declspec(dllexport). When consuming the library, the macro imports symbols +// using +// __declspec(dllimport). On non-Windows platforms, the macro is empty and has +// no effect. +#ifdef _WIN32 +#ifdef EXPORT_AOTI_FUNCTIONS +#define AOTI_SHIM_EXPORT __declspec(dllexport) +#else +#define AOTI_SHIM_EXPORT __declspec(dllimport) +#endif +#else +#define AOTI_SHIM_EXPORT +#endif diff --git a/backends/aoti/passes/TARGETS b/backends/aoti/passes/TARGETS new file mode 100644 index 00000000000..82f3b40dc54 --- /dev/null +++ b/backends/aoti/passes/TARGETS @@ -0,0 +1,17 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +oncall("executorch") + +runtime.python_library( + name = "passes", + srcs = [ + "replace_view_copy_with_view.py", + ], + visibility = [ + "//executorch/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:pass_base", + ], +) diff --git a/backends/aoti/passes/replace_view_copy_with_view.py b/backends/aoti/passes/replace_view_copy_with_view.py new file mode 100644 index 00000000000..c2be14f96e5 --- /dev/null +++ b/backends/aoti/passes/replace_view_copy_with_view.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This pass replaces view_copy ops with view ops. This is different than +# exir/passes/replace_view_copy_with_view.py and exir/passes/reinplace.py +# because this should only be used in the AOTInductor backend, as it +# has less restrictions on whether the tensor memory is densely packed, + +from typing import Dict, Iterable + +import torch +from executorch.exir.dialects._ops import ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass, PassResult +from torch import fx + + +_VIEW_TARGETS: Dict[ + torch._ops.OpOverload | EdgeOpOverload, torch._ops.OpOverload | EdgeOpOverload +] = { + torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor, + ops.edge.aten.slice_copy.Tensor: ops.edge.aten.slice.Tensor, + torch.ops.aten.select_copy.int: torch.ops.aten.select.int, + ops.edge.aten.select_copy.int: ops.edge.aten.select.int, +} + + +class ReplaceViewCopyWithViewPass(ExportPass): + """Replace non-mutated ``view_copy`` type of ops with ``view`` ops.""" + + def call(self, graph_module: fx.GraphModule) -> PassResult: + graph_changed = False + + for node in graph_module.graph.nodes: + if node.op != "call_function" or node.target not in _VIEW_TARGETS: + continue + + if self._has_blocking_user(node, node.users.keys()): + continue + + node.target = _VIEW_TARGETS[node.target] + graph_changed = True + + if graph_changed: + graph_module.graph.lint() + graph_module.recompile() + + return PassResult(graph_module, graph_changed) + + def _has_blocking_user(self, node: fx.Node, users: Iterable[fx.Node]) -> bool: + for user in users: + if self._is_mutating_user(node, user) or self._is_view_user(node, user): + return True + return False + + def _is_mutating_user(self, node: fx.Node, user: fx.Node) -> bool: + if user.op == "call_method": + # Treat in-place tensor methods conservatively as mutations only when the + # method name ends with ``_`` which is the PyTorch convention for mutation. 
+ return isinstance(user.target, str) and user.target.endswith("_") + + if user.op != "call_function": + return False + + target = user.target + if not hasattr(target, "_schema"): + return False + + schema = target._schema # pyre-ignore[16] + # Positional arguments + for index, arg in enumerate(user.args): + if arg is node and self._argument_mutates(schema, index): + return True + + # Keyword arguments + for name, arg in user.kwargs.items(): + if arg is node and self._argument_mutates(schema, name): + return True + + return False + + def _is_view_user(self, node: fx.Node, user: fx.Node) -> bool: + if user.op == "call_method": + # Treat tensor methods conservatively and assume they may be view-producing. + return True + + if user.op != "call_function": + return False + + target = user.target + if getattr(target, "is_view", False): + for arg in user.args: + if arg is node: + return True + for arg in user.kwargs.values(): + if arg is node: + return True + + return False + + def _argument_mutates( + self, schema: torch._C.FunctionSchema, key: int | str + ) -> bool: + arguments = schema.arguments + if isinstance(key, int): + if key >= len(arguments): + return False + argument = arguments[key] + else: + argument = next((arg for arg in arguments if arg.name == key), None) + if argument is None: + return False + + alias_info = argument.alias_info + return bool(alias_info and alias_info.is_write) diff --git a/backends/aoti/targets.bzl b/backends/aoti/targets.bzl new file mode 100644 index 00000000000..327bef8cc53 --- /dev/null +++ b/backends/aoti/targets.bzl @@ -0,0 +1,88 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + runtime.python_library( + name = "aoti_partitioner", + srcs = [ + "aoti_partitioner.py", + ], + visibility = [ + "//executorch/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir/backend:partitioner", + "//executorch/exir/backend:utils", + ], + ) + + runtime.python_library( + name = "aoti_backend", + srcs = [ + "aoti_backend.py", + ], + visibility = [ + "//executorch/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/backends/aoti/passes:passes", + "//executorch/exir/_serialize:lib", + "//executorch/exir/backend:backend_details", + "//executorch/exir/backend:compile_spec_schema", + ], + ) + + # AOTI common shims functionality + runtime.cxx_library( + name = "common_shims", + srcs = [ + "common_shims.cpp", + ], + headers = [ + "common_shims.h", + "export.h", + "utils.h", + ], + # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) + link_whole = True, + supports_python_dlopen = True, + # Constructor needed for backend registration. + compiler_flags = ["-Wno-global-constructors"], + visibility = ["@EXECUTORCH_CLIENTS"], + deps = [ + "//executorch/runtime/core:core", + "//executorch/runtime/core/exec_aten:lib", + ], + ) + + # AOTI model container functionality + runtime.cxx_library( + name = "delegate_handle", + headers = [ + "aoti_delegate_handle.h", + ], + # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) + link_whole = True, + supports_python_dlopen = True, + # Constructor needed for backend registration. 
+ compiler_flags = ["-Wno-global-constructors"], + visibility = ["@EXECUTORCH_CLIENTS"], + deps = [ + "//executorch/runtime/backend:interface", + "//executorch/runtime/core:core", + ], + ) + + # Common AOTI functionality (combining both common_shims and delegate_handle) + runtime.cxx_library( + name = "aoti_common", + # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) + link_whole = True, + supports_python_dlopen = True, + visibility = ["@EXECUTORCH_CLIENTS"], + exported_deps = [ + ":common_shims", + ":delegate_handle", + ], + ) diff --git a/backends/aoti/tests/TARGETS b/backends/aoti/tests/TARGETS new file mode 100644 index 00000000000..8daa8abd4d7 --- /dev/null +++ b/backends/aoti/tests/TARGETS @@ -0,0 +1,22 @@ +load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest") + +oncall("executorch") + +cpp_unittest( + name = "test_common_shims", + srcs = [ + "test_common_shims.cpp", + ], + headers = [ + "utils.h", + ], + deps = [ + "//executorch/backends/aoti:common_shims", + "//executorch/extension/tensor:tensor", + "//executorch/runtime/core:core", + "//executorch/runtime/platform:platform", + "//executorch/runtime/core/exec_aten/testing_util:tensor_util", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/extension/tensor:tensor", + ], +) diff --git a/backends/aoti/tests/test_common_shims.cpp b/backends/aoti/tests/test_common_shims.cpp new file mode 100644 index 00000000000..0fd1b057f99 --- /dev/null +++ b/backends/aoti/tests/test_common_shims.cpp @@ -0,0 +1,335 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include + +using namespace executorch::backends::aoti; +using namespace executorch::backends::aoti::test; +using namespace executorch::runtime; +using executorch::runtime::etensor::Tensor; + +// Test fixture for common shims tests +class CommonShimsTest : public ::testing::Test { + protected: + void SetUp() override { + // Clean up any existing cached metadata before each test + cleanup_tensor_metadata(); + } + + void TearDown() override { + // Clean up metadata and free any tensor data + cleanup_tensor_metadata(); + for (auto& tensor : test_tensors_) { + free_tensor_data(tensor.get()); + } + test_tensors_.clear(); + } + + // Helper to create and track test tensors for cleanup + Tensor* create_tracked_tensor(const std::vector& sizes) { + auto tensor = create_test_tensor(sizes); + Tensor* ptr = tensor.get(); + test_tensors_.push_back(tensor); + return ptr; + } + + private: + std::vector> test_tensors_; +}; + +// Test aoti_torch_get_sizes basic functionality +TEST_F(CommonShimsTest, GetSizesBasicFunctionality) { + // Test 1D tensor + auto tensor_1d = create_tracked_tensor({5}); + int64_t* sizes_ptr; + AOTITorchError error = aoti_torch_get_sizes(tensor_1d, &sizes_ptr); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(sizes_ptr, nullptr); + EXPECT_EQ(sizes_ptr[0], 5); + + // Test 2D tensor + auto tensor_2d = create_tracked_tensor({3, 4}); + error = aoti_torch_get_sizes(tensor_2d, &sizes_ptr); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(sizes_ptr, nullptr); + EXPECT_EQ(sizes_ptr[0], 3); + EXPECT_EQ(sizes_ptr[1], 4); + + // Test 3D tensor + auto tensor_3d = create_tracked_tensor({2, 3, 4}); + error = aoti_torch_get_sizes(tensor_3d, &sizes_ptr); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(sizes_ptr, nullptr); 
+ EXPECT_EQ(sizes_ptr[0], 2); + EXPECT_EQ(sizes_ptr[1], 3); + EXPECT_EQ(sizes_ptr[2], 4); +} + +// Test aoti_torch_get_strides basic functionality +TEST_F(CommonShimsTest, GetStridesBasicFunctionality) { + // Test 1D tensor + auto tensor_1d = create_tracked_tensor({5}); + int64_t* strides_ptr; + AOTITorchError error = aoti_torch_get_strides(tensor_1d, &strides_ptr); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(strides_ptr, nullptr); + EXPECT_EQ(strides_ptr[0], 1); + + // Test 2D tensor - row major: [3, 4] should have strides [4, 1] + auto tensor_2d = create_tracked_tensor({3, 4}); + error = aoti_torch_get_strides(tensor_2d, &strides_ptr); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(strides_ptr, nullptr); + EXPECT_EQ(strides_ptr[0], 4); + EXPECT_EQ(strides_ptr[1], 1); + + // Test 3D tensor - row major: [2, 3, 4] should have strides [12, 4, 1] + auto tensor_3d = create_tracked_tensor({2, 3, 4}); + error = aoti_torch_get_strides(tensor_3d, &strides_ptr); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(strides_ptr, nullptr); + EXPECT_EQ(strides_ptr[0], 12); + EXPECT_EQ(strides_ptr[1], 4); + EXPECT_EQ(strides_ptr[2], 1); +} + +// Test caching logic for sizes +TEST_F(CommonShimsTest, SizesCachingLogic) { + auto tensor = create_tracked_tensor({2, 3, 4}); + + // First call should cache the sizes + int64_t* sizes_ptr1; + AOTITorchError error = aoti_torch_get_sizes(tensor, &sizes_ptr1); + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(sizes_ptr1, nullptr); + + // Second call should return the same cached pointer + int64_t* sizes_ptr2; + error = aoti_torch_get_sizes(tensor, &sizes_ptr2); + EXPECT_EQ(error, Error::Ok); + EXPECT_EQ(sizes_ptr1, sizes_ptr2); // Should be the exact same pointer + + // Values should still be correct + EXPECT_EQ(sizes_ptr2[0], 2); + EXPECT_EQ(sizes_ptr2[1], 3); + EXPECT_EQ(sizes_ptr2[2], 4); +} + +// Test caching logic for strides +TEST_F(CommonShimsTest, StridesCachingLogic) { + auto tensor = create_tracked_tensor({2, 3, 4}); + + // First call should cache the strides + int64_t* strides_ptr1; + AOTITorchError error = aoti_torch_get_strides(tensor, &strides_ptr1); + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(strides_ptr1, nullptr); + + // Second call should return the same cached pointer + int64_t* strides_ptr2; + error = aoti_torch_get_strides(tensor, &strides_ptr2); + EXPECT_EQ(error, Error::Ok); + EXPECT_EQ(strides_ptr1, strides_ptr2); // Should be the exact same pointer + + // Values should still be correct + EXPECT_EQ(strides_ptr2[0], 12); + EXPECT_EQ(strides_ptr2[1], 4); + EXPECT_EQ(strides_ptr2[2], 1); +} + +// Test that different tensors have different cached entries +TEST_F(CommonShimsTest, DifferentTensorsCacheSeparately) { + auto tensor1 = create_tracked_tensor({2, 3}); + auto tensor2 = create_tracked_tensor({4, 5}); + + // Get sizes for both tensors + int64_t* sizes1_ptr; + int64_t* sizes2_ptr; + + EXPECT_EQ(aoti_torch_get_sizes(tensor1, &sizes1_ptr), Error::Ok); + EXPECT_EQ(aoti_torch_get_sizes(tensor2, &sizes2_ptr), Error::Ok); + + // Pointers should be different (different cache entries) + EXPECT_NE(sizes1_ptr, sizes2_ptr); + + // Values should be correct + EXPECT_EQ(sizes1_ptr[0], 2); + EXPECT_EQ(sizes1_ptr[1], 3); + EXPECT_EQ(sizes2_ptr[0], 4); + EXPECT_EQ(sizes2_ptr[1], 5); + + // Test strides as well + int64_t* strides1_ptr; + int64_t* strides2_ptr; + + EXPECT_EQ(aoti_torch_get_strides(tensor1, &strides1_ptr), Error::Ok); + EXPECT_EQ(aoti_torch_get_strides(tensor2, &strides2_ptr), Error::Ok); + + // Pointers should be different (different cache entries) + 
EXPECT_NE(strides1_ptr, strides2_ptr); + + // Values should be correct + EXPECT_EQ(strides1_ptr[0], 3); + EXPECT_EQ(strides1_ptr[1], 1); + EXPECT_EQ(strides2_ptr[0], 5); + EXPECT_EQ(strides2_ptr[1], 1); +} + +// Test cache persistence across multiple calls +TEST_F(CommonShimsTest, CachePersistence) { + auto tensor = create_tracked_tensor({3, 4, 5}); + + // Multiple calls to sizes should all return the same pointer + int64_t* sizes_ptr1; + int64_t* sizes_ptr2; + int64_t* sizes_ptr3; + + EXPECT_EQ(aoti_torch_get_sizes(tensor, &sizes_ptr1), Error::Ok); + EXPECT_EQ(aoti_torch_get_sizes(tensor, &sizes_ptr2), Error::Ok); + EXPECT_EQ(aoti_torch_get_sizes(tensor, &sizes_ptr3), Error::Ok); + + EXPECT_EQ(sizes_ptr1, sizes_ptr2); + EXPECT_EQ(sizes_ptr2, sizes_ptr3); + + // Multiple calls to strides should all return the same pointer + int64_t* strides_ptr1; + int64_t* strides_ptr2; + int64_t* strides_ptr3; + + EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr1), Error::Ok); + EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr2), Error::Ok); + EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr3), Error::Ok); + + EXPECT_EQ(strides_ptr1, strides_ptr2); + EXPECT_EQ(strides_ptr2, strides_ptr3); +} + +// Test 0D tensor (scalar) +TEST_F(CommonShimsTest, ScalarTensor) { + auto tensor_0d = create_tracked_tensor({}); + + // Test sizes for 0D tensor + int64_t* sizes_ptr; + AOTITorchError error = aoti_torch_get_sizes(tensor_0d, &sizes_ptr); + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(sizes_ptr, nullptr); + + // Test strides for 0D tensor + int64_t* strides_ptr; + error = aoti_torch_get_strides(tensor_0d, &strides_ptr); + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(strides_ptr, nullptr); + + // Cache should work for 0D tensors too + int64_t* sizes_ptr2; + error = aoti_torch_get_sizes(tensor_0d, &sizes_ptr2); + EXPECT_EQ(error, Error::Ok); + EXPECT_EQ(sizes_ptr, sizes_ptr2); +} + +// Test large tensor dimensions +TEST_F(CommonShimsTest, LargeTensorDimensions) { + auto tensor = create_tracked_tensor({100, 200, 300, 400}); + + // Test sizes + int64_t* sizes_ptr; + AOTITorchError error = aoti_torch_get_sizes(tensor, &sizes_ptr); + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(sizes_ptr, nullptr); + EXPECT_EQ(sizes_ptr[0], 100); + EXPECT_EQ(sizes_ptr[1], 200); + EXPECT_EQ(sizes_ptr[2], 300); + EXPECT_EQ(sizes_ptr[3], 400); + + // Test strides - expected: [24000000, 120000, 400, 1] + int64_t* strides_ptr; + error = aoti_torch_get_strides(tensor, &strides_ptr); + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(strides_ptr, nullptr); + EXPECT_EQ(strides_ptr[0], 24000000); + EXPECT_EQ(strides_ptr[1], 120000); + EXPECT_EQ(strides_ptr[2], 400); + EXPECT_EQ(strides_ptr[3], 1); +} + +// Test that cleanup_tensor_metadata clears the cache +TEST_F(CommonShimsTest, CleanupFunctionality) { + auto tensor = create_tracked_tensor({2, 3}); + + // Cache some data + int64_t* sizes_ptr1; + int64_t* strides_ptr1; + EXPECT_EQ(aoti_torch_get_sizes(tensor, &sizes_ptr1), Error::Ok); + EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr1), Error::Ok); + + // Clear the cache + cleanup_tensor_metadata(); + + // Getting sizes/strides again should create new cache entries + // (We can't directly test if the pointers are different since that would be + // implementation-dependent, but we can at least verify the functions still + // work) + int64_t* sizes_ptr2; + int64_t* strides_ptr2; + EXPECT_EQ(aoti_torch_get_sizes(tensor, &sizes_ptr2), Error::Ok); + EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr2), Error::Ok); + + // Values should still be 
correct + EXPECT_EQ(sizes_ptr2[0], 2); + EXPECT_EQ(sizes_ptr2[1], 3); + EXPECT_EQ(strides_ptr2[0], 3); + EXPECT_EQ(strides_ptr2[1], 1); +} + +// Test mixed operations to ensure caches are independent +TEST_F(CommonShimsTest, IndependentCaches) { + auto tensor = create_tracked_tensor({2, 3, 4}); + + // Get sizes first + int64_t* sizes_ptr1; + EXPECT_EQ(aoti_torch_get_sizes(tensor, &sizes_ptr1), Error::Ok); + + // Get strides + int64_t* strides_ptr1; + EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr1), Error::Ok); + + // Get sizes again - should be cached + int64_t* sizes_ptr2; + EXPECT_EQ(aoti_torch_get_sizes(tensor, &sizes_ptr2), Error::Ok); + EXPECT_EQ(sizes_ptr1, sizes_ptr2); + + // Get strides again - should be cached + int64_t* strides_ptr2; + EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr2), Error::Ok); + EXPECT_EQ(strides_ptr1, strides_ptr2); + + // Sizes and strides pointers should be different (different caches) + EXPECT_NE(sizes_ptr1, strides_ptr1); +} + +// Test all dtype functions return correct PyTorch dtype codes +TEST_F(CommonShimsTest, AllDtypesReturnCorrectValues) { + EXPECT_EQ(aoti_torch_dtype_float32(), 6); // PyTorch's float32 dtype code + EXPECT_EQ(aoti_torch_dtype_bfloat16(), 15); // PyTorch's bfloat16 dtype code + EXPECT_EQ(aoti_torch_dtype_int8(), 1); // PyTorch's int8 dtype code + EXPECT_EQ(aoti_torch_dtype_int16(), 2); // PyTorch's int16 dtype code + EXPECT_EQ(aoti_torch_dtype_int32(), 3); // PyTorch's int32 dtype code + EXPECT_EQ(aoti_torch_dtype_int64(), 4); // PyTorch's int64 dtype code + EXPECT_EQ(aoti_torch_dtype_bool(), 11); // PyTorch's bool dtype code +} diff --git a/backends/aoti/tests/utils.h b/backends/aoti/tests/utils.h new file mode 100644 index 00000000000..1f26f7e2d51 --- /dev/null +++ b/backends/aoti/tests/utils.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace aoti { +namespace test { + +// Use the same type aliases as in common_shims.h +using executorch::runtime::etensor::Tensor; + +/** + * Creates a test tensor with the specified shape and scalar type + */ +inline std::shared_ptr create_test_tensor( + const std::vector& sizes, + exec_aten::ScalarType dtype = exec_aten::ScalarType::Float) { + // Calculate total number of elements + int64_t total_elements = 1; + for (int64_t size : sizes) { + total_elements *= size; + } + + // Calculate strides (row-major layout) + std::vector strides(sizes.size()); + if (sizes.size() > 0) { + strides[sizes.size() - 1] = 1; + for (int i = sizes.size() - 2; i >= 0; i--) { + strides[i] = strides[i + 1] * sizes[i + 1]; + } + } + + // Allocate data buffer + size_t dtype_size = exec_aten::elementSize(dtype); + void* data = malloc(total_elements * dtype_size); + + // Convert sizes and strides to the required type + std::vector sizes_converted( + sizes.begin(), sizes.end()); + std::vector strides_converted( + strides.begin(), strides.end()); + + // Create the tensor with the correct argument types and count + auto tensor = executorch::extension::from_blob( + data, sizes_converted, strides_converted, dtype); + + return tensor; +} + +/** + * Helper to clean up tensor data that was allocated with malloc + */ +inline void free_tensor_data(Tensor* tensor) { + if (tensor && tensor->mutable_data_ptr()) { + free(tensor->mutable_data_ptr()); + } +} + +} // namespace test +} // namespace aoti +} // namespace backends +} // namespace executorch diff --git a/backends/aoti/utils.h b/backends/aoti/utils.h new file mode 100644 index 00000000000..8f64bdbe7da --- /dev/null +++ b/backends/aoti/utils.h @@ -0,0 +1,166 @@ + +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace aoti { + +// Common using declarations for ExecuTorch types +using executorch::runtime::Error; + +extern "C" { + +// Common AOTI type aliases +using AOTITorchError = Error; + +// Map int32_t dtype to ExecuTorch ScalarType (robust version of hardcoded +// ScalarType::Float) +inline executorch::aten::ScalarType dtype_to_scalar_type(int32_t dtype) { + // Convert based on known PyTorch dtype codes (without CUDA-specific + // dependency) + switch (dtype) { + case 1: // PyTorch's int8 dtype code + return executorch::aten::ScalarType::Char; + case 2: // PyTorch's int16 dtype code + return executorch::aten::ScalarType::Short; + case 3: // PyTorch's int32 dtype code + return executorch::aten::ScalarType::Int; + case 4: // PyTorch's int64 dtype code + return executorch::aten::ScalarType::Long; + case 6: // PyTorch's float32 dtype code + return executorch::aten::ScalarType::Float; + case 11: // PyTorch's bool dtype code + return executorch::aten::ScalarType::Bool; + case 15: // PyTorch's bfloat16 dtype code + return executorch::aten::ScalarType::BFloat16; + // Future support for additional dtypes can be added here + default: + ET_LOG(Error, "Unsupported dtype: %d for ScalarType conversion", dtype); + return executorch::aten::ScalarType::Undefined; + } +} + +// Map int32_t dtype to number of bytes per element (reusing ExecuTorch's +// elementSize function) +inline size_t dtype_to_element_size(int32_t dtype) { + // First convert int32_t dtype to ExecuTorch ScalarType, then use existing + // elementSize function + executorch::aten::ScalarType scalar_type = dtype_to_scalar_type(dtype); + if (scalar_type == executorch::aten::ScalarType::Undefined) { + ET_LOG(Error, "Unsupported dtype: %d for element size calculation", dtype); + return 0; // Return 0 to indicate error + } + + // Reuse ExecuTorch's existing elementSize function from scalar_type_util.h + return executorch::runtime::elementSize(scalar_type); +} + +// Storage offset validation utility function +inline AOTITorchError validate_storage_offset(int64_t storage_offset) { + // Storage offset must always be 0 + if (storage_offset != 0) { + ET_LOG( + Error, + "Storage offset must be 0. Got storage_offset: %ld", + storage_offset); + return Error::InvalidArgument; + } + return Error::Ok; +} + +// Check if tensor is in contiguous memory format (NCHW for 4D tensors) +// Contiguous format means strides decrease from left to right: +// For NCHW: strides = [C*H*W, H*W, W, 1] +inline bool is_tensor_contiguous( + int64_t ndim, + const int64_t* sizes, + const int64_t* strides) { + int64_t expected_stride = 1; + for (int64_t i = ndim - 1; i >= 0; i--) { + if (strides[i] != expected_stride) { + return false; + } + expected_stride *= sizes[i]; + } + return true; +} + +} // extern "C" + +// Utility function to convert sizes pointer to vector +inline std::vector convert_sizes_to_vector( + int64_t ndim, + const int64_t* sizes_ptr) { + std::vector sizes(ndim); + for (int i = 0; i < ndim; i++) { + sizes[i] = static_cast(sizes_ptr[i]); + } + return sizes; +} + +// Utility function to convert strides pointer to vector or calculate from sizes +inline std::vector convert_strides_to_vector( + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr) { + std::vector strides(ndim); + + if (strides_ptr != nullptr) { + // Use provided strides. 
+ for (int64_t i = 0; i < ndim; i++) { + strides[i] = static_cast(strides_ptr[i]); + } + } else { + // Calculate strides from sizes. + if (ndim > 0) { + strides[ndim - 1] = static_cast( + 1); // Last dimension has stride 1 + for (int64_t i = ndim - 2; i >= 0; i--) { + if (sizes_ptr[i + 1] == 0) { + strides[i] = strides[i + 1]; // Copy stride when size is 0 + } else { + strides[i] = static_cast( + static_cast(strides[i + 1]) * sizes_ptr[i + 1]); + } + } + } + } + return strides; +} + +// Check if tensor is in contiguous memory format (NCHW for 4D tensors) +// Contiguous format means strides decrease from left to right: +// For NCHW: strides = [C*H*W, H*W, W, 1] +inline bool is_contiguous_tensor( + std::vector& sizes, + std::vector& strides) { + int64_t ndim = static_cast(strides.size()); + int64_t expected_stride = 1; + for (int64_t i = ndim - 1; i >= 0; i--) { + if (strides[i] != expected_stride) { + return false; + } + expected_stride *= sizes[i]; + } + return true; +} + +} // namespace aoti +} // namespace backends +} // namespace executorch diff --git a/backends/apple/coreml/CMakeLists.txt b/backends/apple/coreml/CMakeLists.txt index 9879a05e3dc..17e2d94e336 100644 --- a/backends/apple/coreml/CMakeLists.txt +++ b/backends/apple/coreml/CMakeLists.txt @@ -115,7 +115,7 @@ if(APPLE) endif() target_compile_options(coreml_util PUBLIC -fPIC) -install(TARGETS coreml_util DESTINATION lib) +install(TARGETS coreml_util DESTINATION ${CMAKE_INSTALL_LIBDIR}) install( DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/runtime/util @@ -154,7 +154,7 @@ target_compile_options(coreml_inmemoryfs PUBLIC -fPIC) install( TARGETS coreml_inmemoryfs - DESTINATION lib + DESTINATION ${CMAKE_INSTALL_LIBDIR} INCLUDES DESTINATION ${_common_include_directories} ) @@ -251,7 +251,7 @@ if(APPLE) install( TARGETS coremldelegate coreml_util coreml_inmemoryfs EXPORT ExecuTorchTargets - DESTINATION lib + DESTINATION ${CMAKE_INSTALL_LIBDIR} INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} ) diff --git a/backends/apple/coreml/README.md b/backends/apple/coreml/README.md index d063dfc8b71..d72f04da1a1 100644 --- a/backends/apple/coreml/README.md +++ b/backends/apple/coreml/README.md @@ -1,7 +1,7 @@ # ExecuTorch Core ML Delegate This subtree contains the Core ML Delegate implementation for ExecuTorch. -Core ML is an optimized framework for running machine learning models on Apple devices. The delegate is the mechanism for leveraging the Core ML framework to accelerate operators when running on Apple devices. To learn how to use the CoreML delegate, see the [documentation](https://github.com/pytorch/executorch/blob/main/docs/source/backends-coreml.md). +Core ML is an optimized framework for running machine learning models on Apple devices. The delegate is the mechanism for leveraging the Core ML framework to accelerate operators when running on Apple devices. To learn how to use the CoreML delegate, see the [documentation](https://github.com/pytorch/executorch/blob/main/docs/source/backends/coreml/coreml-overview.md). ## Layout - `compiler/` : Lowers a module to Core ML backend. 
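The stride helpers added in `backends/aoti/utils.h` above all encode the same row-major rule: the last dimension has stride 1 and each earlier stride is the product of the trailing sizes, with `is_tensor_contiguous()` checking a tensor against exactly that pattern. The self-contained sketch below restates that rule with plain `std::vector<int64_t>` so it can be compiled and sanity-checked in isolation; the function names are illustrative, it does not call into the ExecuTorch headers, and it skips the zero-sized-dimension special case that `convert_strides_to_vector()` handles.

```cpp
// Standalone sketch of the row-major stride rule used by the utilities above.
// Plain int64_t vectors stand in for ExecuTorch's Sizes/Strides types.
#include <cassert>
#include <cstdint>
#include <vector>

// Derive contiguous (row-major) strides from sizes: the last dimension has
// stride 1, and every earlier stride is the product of the trailing sizes.
std::vector<int64_t> contiguous_strides(const std::vector<int64_t>& sizes) {
  std::vector<int64_t> strides(sizes.size());
  int64_t running = 1;
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 1; i >= 0; --i) {
    strides[i] = running;
    running *= sizes[i];
  }
  return strides;
}

// A tensor is contiguous when its strides match the ones derived above.
bool is_row_major(const std::vector<int64_t>& sizes,
                  const std::vector<int64_t>& strides) {
  int64_t expected = 1;
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 1; i >= 0; --i) {
    if (strides[i] != expected) {
      return false;
    }
    expected *= sizes[i];
  }
  return true;
}

int main() {
  // Sizes [2, 3, 4] yield strides [12, 4, 1].
  std::vector<int64_t> sizes{2, 3, 4};
  auto strides = contiguous_strides(sizes);
  assert((strides == std::vector<int64_t>{12, 4, 1}));
  assert(is_row_major(sizes, strides));
  // A permuted stride pattern is rejected as non-contiguous.
  assert(!is_row_major(sizes, {1, 2, 6}));
  return 0;
}
```

The `{2, 3, 4} -> {12, 4, 1}` expectation matches the values asserted in `test_common_shims.cpp` above.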
diff --git a/backends/apple/coreml/TARGETS b/backends/apple/coreml/TARGETS index 444e886b4e6..2f3494b6004 100644 --- a/backends/apple/coreml/TARGETS +++ b/backends/apple/coreml/TARGETS @@ -82,6 +82,7 @@ runtime.python_library( "//executorch/exir/backend:partitioner", "//executorch/exir/backend:utils", "//executorch/export:lib", + "//executorch/runtime:runtime", # @manual ], ) diff --git a/backends/apple/coreml/compiler/coreml_preprocess.py b/backends/apple/coreml/compiler/coreml_preprocess.py index d1614f30451..32cd0df67a2 100644 --- a/backends/apple/coreml/compiler/coreml_preprocess.py +++ b/backends/apple/coreml/compiler/coreml_preprocess.py @@ -6,6 +6,7 @@ import logging import shutil +import tempfile import uuid from dataclasses import asdict, dataclass from enum import Enum @@ -42,6 +43,7 @@ class COMPILE_SPEC_KEYS(Enum): MODEL_COMPUTE_PRECISION = "model_compute_precision" OP_LINEAR_QUANTIZER_CONFIG = "op_linear_quantizer_config" ENUMERATED_SHAPES = "enumerated_shapes" + PASS_PIPELINE = "pass_pipeline" class MODEL_PATHS(Enum): @@ -219,6 +221,33 @@ def op_linear_quantizer_config_from_compile_specs( return None + @staticmethod + def generate_pass_pipeline_compile_spec(pass_names: List[str]) -> CompileSpec: + """ + Creates a compile spec representing the pass pipeline to be used by the CoreML backend + :param pass_names: the list of pass names + """ + str_representation = json.dumps(pass_names) + byte_representation = str_representation.encode("utf-8") + return CompileSpec(COMPILE_SPEC_KEYS.PASS_PIPELINE.value, byte_representation) + + @staticmethod + def pass_pipeline_from_compile_specs( + compile_specs: List[CompileSpec], + ) -> ct.PassPipeline: + """ + Creates a PassPipeline from the list of compile specs, or returns the default if none are provided. + """ + for compile_spec in compile_specs: + if compile_spec.key == COMPILE_SPEC_KEYS.PASS_PIPELINE.value: + pass_names_str = compile_spec.value.decode("utf-8") + pass_names = json.loads(pass_names_str) + return ct.PassPipeline( + pass_names, pipeline_name="executorch_user_pipeline" + ) + + return ct.PassPipeline.DEFAULT + @staticmethod def generate_enumerated_shapes_compile_spec( ep: ExportedProgram, @@ -274,6 +303,7 @@ def generate_compile_specs( compute_precision: ct.precision = ct.precision.FLOAT16, model_type: MODEL_TYPE = MODEL_TYPE.MODEL, op_linear_quantizer_config: Optional[Dict] = None, + pass_names: Optional[List[str]] = None, ) -> List[CompileSpec]: """ Returns the list of compile specs that's used by CoreMLBackend to lower the module. 
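The two pass-pipeline helpers added above serialize a list of coremltools pass names to JSON-encoded bytes and rebuild a `ct.PassPipeline` from them at preprocess time. A minimal usage sketch follows; the import path and the specific pass name are assumptions for illustration, not values taken from this diff.

```python
# Sketch: attach a custom coremltools pass pipeline via the new compile spec.
# Import path and pass name are assumptions for illustration.
from executorch.backends.apple.coreml.compiler import CoreMLBackend

# Encode the pass names into a CompileSpec (stored as JSON-encoded UTF-8 bytes).
spec = CoreMLBackend.generate_pass_pipeline_compile_spec(
    ["common::const_elimination"]  # hypothetical pass name
)

# The backend later reconstructs a ct.PassPipeline named
# "executorch_user_pipeline" from this spec; when no such spec is present,
# it falls back to ct.PassPipeline.DEFAULT.
pipeline = CoreMLBackend.pass_pipeline_from_compile_specs([spec])
```

Because the spec value is plain JSON-encoded UTF-8, it round-trips through the lowered program like any other `CompileSpec`.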
@@ -297,6 +327,10 @@ def generate_compile_specs( op_linear_quantizer_config ) ) + if pass_names is not None: + compile_specs.append( + CoreMLBackend.generate_pass_pipeline_compile_spec(pass_names) + ) return compile_specs @@ -415,7 +449,7 @@ def preprocess_model( mlmodel: ct.models.MLModel, model_type: MODEL_TYPE ) -> PreprocessResult: identifier = "executorch_" + str(uuid.uuid4()) - dir_path: Path = Path("tmp") / identifier + dir_path: Path = Path(tempfile.gettempdir()) / identifier model_dir_path: Path = dir_path / "lowered_module" model_spec: ct.proto.Model_pb2 = mlmodel.get_spec() logger.warning( @@ -502,6 +536,9 @@ def preprocess( enumerated_shapes = CoreMLBackend.enumerated_shapes_from_compile_specs( compile_specs ) + pass_pipeline: ct.PassPipeline = CoreMLBackend.pass_pipeline_from_compile_specs( + compile_specs + ) # If using enumerated shapes, we need to pass the inputs to CoreML's convert() function # explicitly @@ -529,7 +566,7 @@ def preprocess( model=edge_program, source="pytorch", convert_to="mlprogram", - pass_pipeline=ct.PassPipeline.DEFAULT, + pass_pipeline=pass_pipeline, skip_model_load=skip_model_load, compute_precision=model_compute_precision, minimum_deployment_target=minimum_deployment_target, diff --git a/backends/apple/coreml/compiler/torch_ops.py b/backends/apple/coreml/compiler/torch_ops.py index 53ac436fe38..29c7120feb7 100644 --- a/backends/apple/coreml/compiler/torch_ops.py +++ b/backends/apple/coreml/compiler/torch_ops.py @@ -20,8 +20,6 @@ NUM_TO_TORCH_DTYPE, split, to, - transpose, - unbind, ) from coremltools.converters.mil.frontend.torch.torch_op_registry import ( register_torch_op, @@ -30,18 +28,6 @@ from executorch.exir.dim_order_utils import get_memory_format -# https://github.com/apple/coremltools/pull/2556 -@register_torch_op(override=False) -def transpose_copy(context, node): - transpose(context, node) - - -# https://github.com/apple/coremltools/pull/2557 -@register_torch_op(override=False) -def unbind_copy(context, node): - unbind(context, node) - - # https://github.com/apple/coremltools/pull/2563 @register_torch_op(override=False) def split_copy(context, node): @@ -117,7 +103,9 @@ def _clone_dim_order(context, node): # https://github.com/apple/coremltools/pull/2558 @register_torch_op( torch_alias=["torchao::dequantize_affine", "torchao.dequantize_affine"], - override=False, + # coremltools did not merge the fix into 9.0 (https://github.com/apple/coremltools/pull/2589), + # so we override here + override=True, ) def dequantize_affine(context, node): inputs = _get_inputs(context, node, expected=[7, 8]) diff --git a/backends/apple/coreml/runtime/delegate/ETCoreMLAssetManager.mm b/backends/apple/coreml/runtime/delegate/ETCoreMLAssetManager.mm index 53c3d1cdc69..d98a1d7331b 100644 --- a/backends/apple/coreml/runtime/delegate/ETCoreMLAssetManager.mm +++ b/backends/apple/coreml/runtime/delegate/ETCoreMLAssetManager.mm @@ -170,46 +170,74 @@ bool set_total_assets_size(size_t total_size, return true; } -bool exclude_item_from_backup(NSURL *url, NSError * __autoreleasing *error) { - return [url setResourceValue:@(YES) forKey:NSURLIsExcludedFromBackupKey error:error]; -} -NSURL * _Nullable create_directory_if_needed(NSURL *url, - NSString *name, +NSURL * _Nullable create_directory_if_needed(NSURL *dirURL, NSFileManager *fm, - NSError * __autoreleasing *error) { - NSURL *directory_url = [url URLByAppendingPathComponent:name]; - if (![fm fileExistsAtPath:directory_url.path] && - ![fm createDirectoryAtURL:directory_url withIntermediateDirectories:NO attributes:@{} 
error:error]) { - return nil; - } + NSError **error) { + NSCParameterAssert(dirURL); + NSCParameterAssert(dirURL.isFileURL); + NSCParameterAssert(fm); - ::exclude_item_from_backup(directory_url, nil); - - return directory_url; + NSString *dirPath = dirURL.path; + + // Fast path: is it already a directory? + BOOL isDir = NO; + if (dirPath && [fm fileExistsAtPath:dirPath isDirectory:&isDir] && isDir) { + return dirURL; + } + + // Try to create the directory and its parents. + NSDictionary *attrs = @{ NSFileProtectionKey : NSFileProtectionCompleteUntilFirstUserAuthentication }; + if (![fm createDirectoryAtURL:dirURL + withIntermediateDirectories:YES + attributes:attrs + error:error]) { + // Lost a race and creation failed because something already exists, check if it's a directory. + isDir = NO; + if (dirPath && [fm fileExistsAtPath:dirPath isDirectory:&isDir] && isDir) { + if (error) { *error = nil; } + } else { + return nil; + } + } + + // Best effort: exclude from backup (ignore failure) + (void)[dirURL setResourceValue:@YES forKey:NSURLIsExcludedFromBackupKey error:nil]; + + return dirURL; } -bool is_directory_empty(NSURL *url, NSFileManager *fm, NSError * __autoreleasing *error) { - BOOL is_directory = NO; - if (![fm fileExistsAtPath:url.path isDirectory:&is_directory] && !is_directory) { +bool is_missing_or_empty_directory(NSURL *dirURL, NSFileManager *fm, NSError * __autoreleasing *error) { + NSString *dirPath = dirURL.path; + BOOL isDir = NO; + BOOL doesFileExist = dirPath && [fm fileExistsAtPath:dirPath isDirectory:&isDir]; + if (!doesFileExist) { return true; } + if (!isDir) { + return false; + } - __block NSError *local_error = nil; - BOOL (^errorHandler)(NSURL *url, NSError *error) = ^BOOL(NSURL *url, NSError *enumeration_error) { - local_error = enumeration_error; - return NO; - }; - - NSDirectoryEnumerator *enumerator = [fm enumeratorAtURL:url + __block NSError *localError = nil; + NSDirectoryEnumerator *enumerator = [fm enumeratorAtURL:dirURL includingPropertiesForKeys:@[] options:NSDirectoryEnumerationProducesRelativePathURLs - errorHandler:errorHandler]; - if (local_error && error) { - *error = local_error; + errorHandler:^BOOL(NSURL *u, NSError *e){ localError = e; return NO; }]; + + // If enumerator failed to create, do not say the directory is empty + if (!enumerator) { + return false; } - return [enumerator nextObject] == nil; + id nextObject = [enumerator nextObject]; + + // Do not treat enumeration errors as empty directory + if (localError) { + if (error) { *error = localError; } + return false; + } + + return nextObject == nil; } NSURL * _Nullable get_asset_url(const Asset& asset) { @@ -255,28 +283,6 @@ BOOL is_asset_alive(NSMapTable *assets_in_use_map, return assets; } -NSURL * _Nullable move_to_directory(NSURL *url, - NSURL *directoryURL, - NSFileManager *fileManager, - NSError * __autoreleasing *error) { - if (!url) { - ETCoreMLLogErrorAndSetNSError(error, ETCoreMLErrorInternalError, "Move operation failed: source URL is nil."); - return nil; - } - - if (!directoryURL) { - ETCoreMLLogErrorAndSetNSError(error, ETCoreMLErrorInternalError, "Move operation failed: destination URL is nil."); - return nil; - } - - NSURL *dstURL = [directoryURL URLByAppendingPathComponent:[NSUUID UUID].UUIDString]; - if (![fileManager moveItemAtURL:url toURL:dstURL error:error]) { - return nil; - } - - return dstURL; -} - } //namespace @interface ETCoreMLAssetManager () { @@ -318,23 +324,33 @@ - (nullable instancetype)initWithDatabase:(const std::shared_ptr&)data } NSFileManager 
*fileManager = [[NSFileManager alloc] init]; - NSURL *managedAssetsDirectoryURL = ::create_directory_if_needed(assetsDirectoryURL, @"models", fileManager, error); + + NSDictionary *attrs = @{ NSFileProtectionKey : NSFileProtectionCompleteUntilFirstUserAuthentication }; + + NSURL *managedAssetsDirectoryURL = [assetsDirectoryURL URLByAppendingPathComponent:@"models"]; + managedAssetsDirectoryURL = ::create_directory_if_needed(managedAssetsDirectoryURL, fileManager, error); if (!managedAssetsDirectoryURL) { return nil; } + (void)[fileManager setAttributes:attrs ofItemAtPath:managedAssetsDirectoryURL.path error:nil]; // best-effort + - NSURL *managedTrashDirectoryURL = ::create_directory_if_needed(trashDirectoryURL, @"models", fileManager, error); + NSURL *managedTrashDirectoryURL = [trashDirectoryURL URLByAppendingPathComponent:@"models"]; + managedTrashDirectoryURL = ::create_directory_if_needed(managedTrashDirectoryURL, fileManager, error); if (!managedTrashDirectoryURL) { return nil; } + (void)[fileManager setAttributes:attrs ofItemAtPath:managedTrashDirectoryURL.path error:nil]; // best-effort - NSURL *managedStagingDirectoryURL = ::create_directory_if_needed(assetsDirectoryURL, @"staging", fileManager, error); + NSURL *managedStagingDirectoryURL = [assetsDirectoryURL URLByAppendingPathComponent:@"staging"]; + managedStagingDirectoryURL = ::create_directory_if_needed(managedStagingDirectoryURL, fileManager, error); if (!managedStagingDirectoryURL) { return nil; } + (void)[fileManager setAttributes:attrs ofItemAtPath:managedStagingDirectoryURL.path error:nil]; // best-effort // If directory is empty then purge the stores - if (::is_directory_empty(managedAssetsDirectoryURL, fileManager, nil)) { + if (::is_missing_or_empty_directory(managedAssetsDirectoryURL, fileManager, nil)) { assetsMetaStore.impl()->purge(ec); assetsStore.impl()->purge(ec); } @@ -347,7 +363,6 @@ - (nullable instancetype)initWithDatabase:(const std::shared_ptr&)data _trashDirectoryURL = managedTrashDirectoryURL; _estimatedSizeInBytes = sizeInBytes.value(); _maxAssetsSizeInBytes = maxAssetsSizeInBytes; - _fileManager = fileManager; _trashQueue = dispatch_queue_create("com.executorchcoreml.assetmanager.trash", DISPATCH_QUEUE_SERIAL_WITH_AUTORELEASE_POOL); _syncQueue = dispatch_queue_create("com.executorchcoreml.assetmanager.sync", DISPATCH_QUEUE_SERIAL_WITH_AUTORELEASE_POOL); @@ -362,7 +377,35 @@ - (nullable instancetype)initWithDatabaseURL:(NSURL *)databaseURL assetsDirectoryURL:(NSURL *)assetsDirectoryURL trashDirectoryURL:(NSURL *)trashDirectoryURL maxAssetsSizeInBytes:(NSInteger)maxAssetsSizeInBytes - error:(NSError * __autoreleasing *)error { + error:(NSError * __autoreleasing *)error { + + NSURL *databaseDirectoryURL = [databaseURL URLByDeletingLastPathComponent]; + NSFileManager *fm = [[NSFileManager alloc] init]; + if (!::create_directory_if_needed(databaseDirectoryURL, fm, error)) { + return nil; + } + + // Ensure correct file protection + NSMutableArray *maybeDBPaths = [NSMutableArray array]; + NSString *databaseDirectoryPath = databaseDirectoryURL.path; + if (databaseDirectoryPath) { [maybeDBPaths addObject:databaseDirectoryPath]; } + + // Ensure correct file protection on existing database files, if any + // New database files should inherit the protection from the parent directory + NSString *databasePath = databaseURL.path; + if (databasePath) { + [maybeDBPaths addObject:databasePath]; + [maybeDBPaths addObject:[databasePath stringByAppendingString:@"-wal"]]; + [maybeDBPaths addObject:[databasePath 
stringByAppendingString:@"-shm"]]; + [maybeDBPaths addObject:[databasePath stringByAppendingString:@"-journal"]]; + } + NSDictionary *attrs = @{ NSFileProtectionKey : NSFileProtectionCompleteUntilFirstUserAuthentication }; + for (NSString *p in maybeDBPaths) { + if ([fm fileExistsAtPath:p]) { + (void)[fm setAttributes:attrs ofItemAtPath:p error:nil]; // best-effort + } + } + auto database = make_database(databaseURL, kBusyTimeIntervalInMS, error); if (!database) { return nil; @@ -381,11 +424,27 @@ - (void)withTemporaryDirectory:(void (^)(NSURL *directoryURL))block { if (![self.fileManager fileExistsAtPath:dstURL.path]) { return; } - - move_to_directory(dstURL, self.trashDirectoryURL, self.fileManager, nil); + [self moveItemAtURLToTrash:dstURL error:nil]; [self cleanupTrashDirectory]; } +- (NSURL * _Nullable) moveItemAtURLToTrash:(NSURL *)url + error:(NSError * __autoreleasing *)error { + ::create_directory_if_needed(self.trashDirectoryURL, self.fileManager, error); + NSURL *dstURL = [self.trashDirectoryURL URLByAppendingPathComponent:[NSUUID UUID].UUIDString]; + + if (!url) { + ETCoreMLLogErrorAndSetNSError(error, ETCoreMLErrorInternalError, "Move operation failed: source URL is nil."); + return nil; + } + + if (![self.fileManager moveItemAtURL:url toURL:dstURL error:error]) { + return nil; + } + + return dstURL; +} + - (void)cleanupAssetIfNeeded:(ETCoreMLAsset *)asset { if (!asset || asset.isValid) { return; @@ -394,7 +453,7 @@ - (void)cleanupAssetIfNeeded:(ETCoreMLAsset *)asset { NSString *identifier = asset.identifier; dispatch_async(self.syncQueue, ^{ NSError *cleanupError = nil; - if (![self _removeAssetWithIdentifier:asset.identifier error:&cleanupError]) { + if (![self _removeAssetWithIdentifier:asset.identifier alreadyInsideTransaction:NO error:&cleanupError]) { ETCoreMLLogError(cleanupError, "Failed to remove asset with identifier = %@", identifier); @@ -407,6 +466,7 @@ - (nullable ETCoreMLAsset *)_storeAssetAtURL:(NSURL *)srcURL error:(NSError * __autoreleasing *)error { dispatch_assert_queue(self.syncQueue); NSString *extension = srcURL.lastPathComponent.pathExtension; + ::create_directory_if_needed(self.assetsDirectoryURL, self.fileManager, error); NSURL *dstURL = [self.assetsDirectoryURL URLByAppendingPathComponent:[NSString stringWithFormat:@"%@.%@", identifier, extension]]; auto asset = Asset::make(srcURL, identifier, self.fileManager, error); if (!asset) { @@ -420,7 +480,7 @@ - (nullable ETCoreMLAsset *)_storeAssetAtURL:(NSURL *)srcURL bool status = _assetsStore.impl()->transaction([self, &assetValue, assetSizeInBytes, srcURL, dstURL, &ec, error]() { const std::string& assetIdentifier = assetValue.identifier; // If an asset exists with the same identifier then remove it. - if (![self _removeAssetWithIdentifier:@(assetIdentifier.c_str()) error:error]) { + if (![self _removeAssetWithIdentifier:@(assetIdentifier.c_str()) alreadyInsideTransaction:YES error:error]) { return false; } @@ -437,7 +497,14 @@ - (nullable ETCoreMLAsset *)_storeAssetAtURL:(NSURL *)srcURL } // If a file already exists at `dstURL`, move it to the trash for removal. - move_to_directory(dstURL, self.trashDirectoryURL, self.fileManager, nil); + if ([self.fileManager fileExistsAtPath:dstURL.path]) { + if (![self moveItemAtURLToTrash:dstURL error:error]) { + // Log error and return false + ETCoreMLLogErrorAndSetNSError(error, ETCoreMLErrorInternalError, "moveItemAtURLToTrash failed"); + return false; + } + } + // Move the asset to assets directory. 
if (![self.fileManager moveItemAtURL:srcURL toURL:dstURL error:error]) { return false; @@ -455,6 +522,7 @@ - (nullable ETCoreMLAsset *)_storeAssetAtURL:(NSURL *)srcURL [self.assetsInUseMap setObject:result forKey:identifier]; } else { [self cleanupAssetIfNeeded:result]; + return nil; } return result; @@ -550,6 +618,7 @@ - (BOOL)hasAssetWithIdentifier:(NSString *)identifier } - (BOOL)_removeAssetWithIdentifier:(NSString *)identifier + alreadyInsideTransaction:(BOOL)alreadyInsideTransaction error:(NSError * __autoreleasing *)error { dispatch_assert_queue(self.syncQueue); // Asset is alive we can't delete it. @@ -573,8 +642,9 @@ - (BOOL)_removeAssetWithIdentifier:(NSString *)identifier const auto& assetValue = asset.value(); size_t assetSizeInBytes = std::min(_estimatedSizeInBytes, static_cast(assetValue.total_size_in_bytes())); - // Update the stores inside a transaction, if anything fails it will automatically rollback to the previous state. - bool status = _assetsStore.impl()->transaction([self, &assetValue, assetSizeInBytes, &ec, error]() { + + + auto transaction = [self, &assetValue, assetSizeInBytes, &ec, error]() { if (!self->_assetsStore.impl()->remove(assetValue.identifier, ec)) { return false; } @@ -585,12 +655,20 @@ - (BOOL)_removeAssetWithIdentifier:(NSString *)identifier NSURL *assetURL = ::get_asset_url(assetValue); if ([self.fileManager fileExistsAtPath:assetURL.path] && - !move_to_directory(assetURL, self.trashDirectoryURL, self.fileManager, error)) { + ![self moveItemAtURLToTrash:assetURL error:error]) { return false; } return true; - }, Database::TransactionBehavior::Immediate, ec); + }; + + // Update the stores inside a transaction, if anything fails it will automatically rollback to the previous state. + bool status = false; + if (alreadyInsideTransaction) { + status = transaction(); + } else { + status = _assetsStore.impl()->transaction(transaction, Database::TransactionBehavior::Immediate, ec); + } // Update the estimated size if the transaction succeeded. _estimatedSizeInBytes -= status ? assetSizeInBytes : 0; @@ -602,7 +680,7 @@ - (BOOL)removeAssetWithIdentifier:(NSString *)identifier error:(NSError * __autoreleasing *)error { __block BOOL result = NO; dispatch_sync(self.syncQueue, ^{ - result = [self _removeAssetWithIdentifier:identifier error:error]; + result = [self _removeAssetWithIdentifier:identifier alreadyInsideTransaction:NO error:error]; }); return result; @@ -680,7 +758,7 @@ - (NSUInteger)_compact:(NSUInteger)sizeInBytes error:(NSError * __autoreleasing for (const auto& asset : assets) { NSError *cleanupError = nil; NSString *identifier = @(asset.identifier.c_str()); - if (![self _removeAssetWithIdentifier:identifier error:&cleanupError] && cleanupError) { + if (![self _removeAssetWithIdentifier:identifier alreadyInsideTransaction:NO error:&cleanupError] && cleanupError) { ETCoreMLLogError(cleanupError, "Failed to remove asset with identifier = %@.", identifier); @@ -742,14 +820,14 @@ - (BOOL)_purge:(NSError * __autoreleasing *)error { } // Move the the whole assets directory to the temp directory. - if (!move_to_directory(self.assetsDirectoryURL, self.trashDirectoryURL, self.fileManager, error)) { + if (![self moveItemAtURLToTrash:self.assetsDirectoryURL error:error]) { return false; } self->_estimatedSizeInBytes = 0; NSError *localError = nil; // Create the assets directory, if we fail here it's okay. 
- if (![self.fileManager createDirectoryAtURL:self.assetsDirectoryURL withIntermediateDirectories:NO attributes:@{} error:&localError]) { + if (![self.fileManager createDirectoryAtURL:self.assetsDirectoryURL withIntermediateDirectories:YES attributes:@{} error:&localError]) { ETCoreMLLogError(localError, "Failed to create assets directory."); } diff --git a/backends/apple/coreml/runtime/delegate/ETCoreMLModelLoader.h b/backends/apple/coreml/runtime/delegate/ETCoreMLModelLoader.h index 05e96ad59f5..1819710cfda 100644 --- a/backends/apple/coreml/runtime/delegate/ETCoreMLModelLoader.h +++ b/backends/apple/coreml/runtime/delegate/ETCoreMLModelLoader.h @@ -9,6 +9,7 @@ @class ETCoreMLModel; @class ETCoreMLAssetManager; +@class ETCoreMLAsset; namespace executorchcoreml { struct ModelMetadata; @@ -23,6 +24,12 @@ __attribute__((objc_subclassing_restricted)) - (instancetype)init NS_UNAVAILABLE; + ++ (nullable ETCoreMLModel*)loadModelWithCompiledAsset:(ETCoreMLAsset*)compiledAsset + configuration:(MLModelConfiguration*)configuration + metadata:(const executorchcoreml::ModelMetadata&)metadata + error:(NSError* __autoreleasing*)error; + /// Synchronously loads a model given the location of its on-disk representation and configuration. /// /// @param compiledModelURL The location of the model's on-disk representation (.mlmodelc directory). diff --git a/backends/apple/coreml/runtime/delegate/ETCoreMLModelLoader.mm b/backends/apple/coreml/runtime/delegate/ETCoreMLModelLoader.mm index 9e8ae04842e..731b8506f31 100644 --- a/backends/apple/coreml/runtime/delegate/ETCoreMLModelLoader.mm +++ b/backends/apple/coreml/runtime/delegate/ETCoreMLModelLoader.mm @@ -44,6 +44,22 @@ @implementation ETCoreMLModelLoader ++ (nullable ETCoreMLModel *)loadModelWithCompiledAsset:(ETCoreMLAsset *)compiledAsset + configuration:(MLModelConfiguration *)configuration + metadata:(const executorchcoreml::ModelMetadata&)metadata + error:(NSError * __autoreleasing *)error { + NSError *localError = nil; + ETCoreMLModel *model = (compiledAsset != nil) ? get_model_from_asset(compiledAsset, configuration, metadata, &localError) : nil; + if (model) { + return model; + } + if (error) { + *error = localError; + } + return nil; +} + + + (nullable ETCoreMLModel *)loadModelWithContentsOfURL:(NSURL *)compiledModelURL configuration:(MLModelConfiguration *)configuration metadata:(const executorchcoreml::ModelMetadata&)metadata @@ -58,7 +74,13 @@ + (nullable ETCoreMLModel *)loadModelWithContentsOfURL:(NSURL *)compiledModelURL asset = [assetManager storeAssetAtURL:compiledModelURL withIdentifier:identifier error:&localError]; } - ETCoreMLModel *model = (asset != nil) ? get_model_from_asset(asset, configuration, metadata, &localError) : nil; + ETCoreMLModel *model; + if (asset != nil) { + model = [self loadModelWithCompiledAsset:asset configuration:configuration metadata:metadata error:&localError]; + } else { + model = nil; + } + if (model) { return model; } diff --git a/backends/apple/coreml/runtime/delegate/ETCoreMLModelManager.mm b/backends/apple/coreml/runtime/delegate/ETCoreMLModelManager.mm index 524ceaf7e28..d59890ee00f 100644 --- a/backends/apple/coreml/runtime/delegate/ETCoreMLModelManager.mm +++ b/backends/apple/coreml/runtime/delegate/ETCoreMLModelManager.mm @@ -447,15 +447,13 @@ - (nullable NSURL *)compiledModelURLWithIdentifier:(NSString *)identifier // Handle based on the type of the model asset. switch (modelAssetType.value()) { case ModelAssetType::CompiledModel: { - // The model is already compiled; no further action needed. 
- // Return the existing model URL. + // Model is already compiled. ETCoreMLLogInfo("The model in the pte file is pre-compiled. Skipping compilation."); return modelURL; } case ModelAssetType::Model: { - // The model is not compiled yet. - // Compile the model at the specified URL with a maximum wait time of 5 minutes. + // Compile the model. ETCoreMLLogInfo("The model in the pte file is not pre-compiled. Compiling with a 5 min timeout."); NSURL *compiledModelURL = [ETCoreMLModelCompiler compileModelAtURL:modelURL maxWaitTimeInSeconds:(5 * 60) @@ -474,29 +472,44 @@ - (nullable ETCoreMLAsset *)compiledModelAssetWithMetadata:(const ModelMetadata& __block ETCoreMLAsset *compiledModelAsset = [self assetWithIdentifier:identifier]; if (compiledModelAsset) { ETCoreMLLogInfo("Cache Hit: Successfully retrieved compiled model with identifier=%@ from the models cache.", identifier); - } else { - ETCoreMLLogInfo("Cache Miss: Compiled Model with identifier=%@ was not found in the models cache.", identifier); + return compiledModelAsset; } - + + ETCoreMLLogInfo("Cache Miss: Compiled Model with identifier=%@ was not found in the models cache.", identifier); + __block NSURL *compiledModelURL; [self.assetManager withTemporaryDirectory:^(NSURL * _Nonnull directoryURL) { - if (compiledModelAsset) { - return; - } - // The directory specified by `directoryURL` is unique and will be automatically cleaned up // once the enclosing block completes. - NSURL *compiledModelURL = [self compiledModelURLWithIdentifier:identifier + compiledModelURL = [self compiledModelURLWithIdentifier:identifier modelURL:modelURL inMemoryFS:inMemoryFS dstURL:directoryURL error:error]; if (compiledModelURL) { // Move the compiled model to the asset manager to transfer ownership. - ETCoreMLLogInfo("Storing compiled asset with identifier=%@ in the asset manager.", identifier); + ETCoreMLLogInfo("Successfully got compiled model with identifier=%@. 
Transferring ownership to assetManager.", identifier); compiledModelAsset = [self.assetManager storeAssetAtURL:compiledModelURL withIdentifier:identifier error:error]; } }]; + if (!compiledModelAsset) { + ETCoreMLLogInfo("Failed to transfer ownership of asset with identifier=%@ to assetManager", identifier); + if (compiledModelURL && [self.fileManager fileExistsAtPath:compiledModelURL.path]) { + // Log what error was since we now attempt backup path, and previous error is overwritten + if (error && *error) { + ETCoreMLLogInfo("error=%@", (*error).localizedDescription); + *error = nil; + } + ETCoreMLLogInfo("Attempting to fall back by loading model without transferring ownership"); + auto backingAsset = Asset::make(compiledModelURL, identifier, self.assetManager.fileManager, error); + if (backingAsset) { + compiledModelAsset = [[ETCoreMLAsset alloc] initWithBackingAsset:backingAsset.value()]; + } + } + } + + // compiledModelAsset can still be nil if our backup path failed + return compiledModelAsset; } @@ -585,10 +598,9 @@ - (nullable ETCoreMLAsset *)modelAssetWithMetadata:(const ModelMetadata&)metadat return nil; } - ETCoreMLModel *model = [ETCoreMLModelLoader loadModelWithContentsOfURL:compiledModelAsset.contentURL + ETCoreMLModel *model = [ETCoreMLModelLoader loadModelWithCompiledAsset:compiledModelAsset configuration:configuration metadata:metadata - assetManager:self.assetManager error:error]; if (!model) { return nil; diff --git a/backends/apple/coreml/runtime/delegate/ETCoreMLStrings.mm b/backends/apple/coreml/runtime/delegate/ETCoreMLStrings.mm index fb66f7b7c03..232e3297b76 100644 --- a/backends/apple/coreml/runtime/delegate/ETCoreMLStrings.mm +++ b/backends/apple/coreml/runtime/delegate/ETCoreMLStrings.mm @@ -101,39 +101,50 @@ + (NSString *)debugSymbolToHandlesKeyName { } + (nullable NSString *)assetsDirectoryPath { - static dispatch_once_t onceToken; - static NSString *result = nil; - dispatch_once(&onceToken, ^{ - NSArray *paths = NSSearchPathForDirectoriesInDomains(NSCachesDirectory, NSUserDomainMask, YES); - if (paths.count > 0) { - result = [paths.lastObject stringByAppendingPathComponent:self.productName]; - } - }); - - return result; + #if defined(EXECUTORCH_COREML_ASSETS_DIRECTORY_PATH) + return @(EXECUTORCH_COREML_ASSETS_DIRECTORY_PATH); + #else + static dispatch_once_t onceToken; + static NSString *result = nil; + dispatch_once(&onceToken, ^{ + NSArray *paths = NSSearchPathForDirectoriesInDomains(NSCachesDirectory, NSUserDomainMask, YES); + if (paths.count > 0) { + result = [paths.lastObject stringByAppendingPathComponent:self.productName]; + } + }); + + return result; + #endif } + (nullable NSString *)trashDirectoryPath { - static dispatch_once_t onceToken; - static NSString *result = nil; - dispatch_once(&onceToken, ^{ - result = [NSTemporaryDirectory() stringByAppendingPathComponent:self.productName]; - }); - - return result; + #if defined(EXECUTORCH_COREML_TRASH_DIRECTORY_PATH) + return @(EXECUTORCH_COREML_TRASH_DIRECTORY_PATH); + #else + static dispatch_once_t onceToken; + static NSString *result = nil; + dispatch_once(&onceToken, ^{ + result = [NSTemporaryDirectory() stringByAppendingPathComponent:self.productName]; + }); + + return result; + #endif } + (nullable NSString *)databaseDirectoryPath { - static dispatch_once_t onceToken; - static NSString *result = nil; - dispatch_once(&onceToken, ^{ - NSArray *paths = NSSearchPathForDirectoriesInDomains(NSApplicationSupportDirectory, NSUserDomainMask, YES); - if (paths.count > 0) { - result = [paths.lastObject 
stringByAppendingPathComponent:self.productName]; - } - }); - - return result; + #if defined(EXECUTORCH_COREML_DATABASE_DIRECTORY_PATH) + return @(EXECUTORCH_COREML_DATABASE_DIRECTORY_PATH); + #else + static dispatch_once_t onceToken; + static NSString *result = nil; + dispatch_once(&onceToken, ^{ + NSArray *paths = NSSearchPathForDirectoriesInDomains(NSApplicationSupportDirectory, NSUserDomainMask, YES); + if (paths.count > 0) { + result = [paths.lastObject stringByAppendingPathComponent:self.productName]; + } + }); + return result; + #endif } diff --git a/backends/apple/coreml/runtime/delegate/backend_delegate.mm b/backends/apple/coreml/runtime/delegate/backend_delegate.mm index 2cb274f0a89..680c5c63143 100644 --- a/backends/apple/coreml/runtime/delegate/backend_delegate.mm +++ b/backends/apple/coreml/runtime/delegate/backend_delegate.mm @@ -45,40 +45,15 @@ MLComputeUnits get_compute_units(const Buffer& buffer) { return configuration; } -NSURL * _Nullable create_directory_if_needed(NSURL *url, - NSFileManager *fileManager, - NSError * __autoreleasing *error) { - if (![fileManager fileExistsAtPath:url.path] && - ![fileManager createDirectoryAtURL:url withIntermediateDirectories:YES attributes:@{} error:error]) { - return nil; - } - - return url; -} - ETCoreMLAssetManager * _Nullable create_asset_manager(NSString *assets_directory_path, NSString *trash_directory_path, NSString *database_directory_path, NSString *database_name, NSInteger max_assets_size_in_bytes, NSError * __autoreleasing *error) { - NSFileManager *fm = [[NSFileManager alloc] init]; - NSURL *assets_directory_url = [NSURL fileURLWithPath:assets_directory_path]; - if (!create_directory_if_needed(assets_directory_url, fm, error)) { - return nil; - } - NSURL *trash_directory_url = [NSURL fileURLWithPath:trash_directory_path]; - if (!create_directory_if_needed(trash_directory_url, fm, error)) { - return nil; - } - NSURL *database_directory_url = [NSURL fileURLWithPath:database_directory_path]; - if (!create_directory_if_needed(database_directory_url, fm, error)) { - return nil; - } - NSURL *database_url = [database_directory_url URLByAppendingPathComponent:database_name]; ETCoreMLAssetManager *manager = [[ETCoreMLAssetManager alloc] initWithDatabaseURL:database_url assetsDirectoryURL:assets_directory_url diff --git a/backends/apple/coreml/scripts/build_tests.sh b/backends/apple/coreml/scripts/build_tests.sh index 190adf1f65a..0203e5027a2 100755 --- a/backends/apple/coreml/scripts/build_tests.sh +++ b/backends/apple/coreml/scripts/build_tests.sh @@ -30,7 +30,8 @@ rm -rf "$CMAKE_EXECUTORCH_BUILD_DIR_PATH" cmake "$EXECUTORCH_ROOT_PATH" -B"$CMAKE_EXECUTORCH_BUILD_DIR_PATH" \ -DCMAKE_TOOLCHAIN_FILE="$IOS_TOOLCHAIN_PATH" \ --DPLATFORM=MAC_UNIVERSAL \ +-DPLATFORM=MAC_ARM64 \ +-DCMAKE_OSX_ARCHITECTURES=arm64 \ -DDEPLOYMENT_TARGET=13.0 \ -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=OFF \ -DEXECUTORCH_BUILD_XNNPACK=OFF @@ -44,7 +45,8 @@ rm -rf "$CMAKE_PROTOBUF_BUILD_DIR_PATH" cmake "$PROTOBUF_DIR_PATH/cmake" -B"$CMAKE_PROTOBUF_BUILD_DIR_PATH" \ -DCMAKE_TOOLCHAIN_FILE="$IOS_TOOLCHAIN_PATH" \ --DPLATFORM=MAC_UNIVERSAL \ +-DPLATFORM=MAC_ARM64 \ +-DCMAKE_OSX_ARCHITECTURES=arm64 \ -DDEPLOYMENT_TARGET=13.0 \ -Dprotobuf_BUILD_TESTS=OFF \ -Dprotobuf_BUILD_EXAMPLES=OFF \ @@ -55,7 +57,8 @@ cmake --build "$CMAKE_PROTOBUF_BUILD_DIR_PATH" -j9 -t libprotobuf-lite # Copy required libraries echo "ExecuTorch: Copying libraries" -mkdir "$LIBRARIES_DIR_PATH" +rm -rf $LIBRARIES_DIR_PATH +mkdir -p "$LIBRARIES_DIR_PATH" cp -f 
"$CMAKE_EXECUTORCH_BUILD_DIR_PATH/libexecutorch.a" "$LIBRARIES_DIR_PATH" cp -f "$CMAKE_EXECUTORCH_BUILD_DIR_PATH/libexecutorch_core.a" "$LIBRARIES_DIR_PATH" cp -f "$CMAKE_PROTOBUF_BUILD_DIR_PATH/libprotobuf-lite.a" "$LIBRARIES_DIR_PATH" diff --git a/backends/apple/coreml/scripts/generate_test_models.sh b/backends/apple/coreml/scripts/generate_test_models.sh index 6a73d697379..bb5de781b5e 100755 --- a/backends/apple/coreml/scripts/generate_test_models.sh +++ b/backends/apple/coreml/scripts/generate_test_models.sh @@ -15,7 +15,9 @@ COREML_DIR_PATH="$EXECUTORCH_ROOT_PATH/backends/apple/coreml" cd "$EXECUTORCH_ROOT_PATH" -mkdir "$COREML_DIR_PATH/runtime/test/models/" +rm -rf "$COREML_DIR_PATH/runtime/test/models/" +mkdir -p "$COREML_DIR_PATH/runtime/test/models/" + #Generate models cd "$EXECUTORCH_ROOT_PATH" diff --git a/backends/apple/coreml/scripts/install_requirements.sh b/backends/apple/coreml/scripts/install_requirements.sh index 5ec1ea6a1de..f57df535d86 100755 --- a/backends/apple/coreml/scripts/install_requirements.sh +++ b/backends/apple/coreml/scripts/install_requirements.sh @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +set -euo pipefail + SCRIPT_DIR_PATH="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 pwd -P @@ -12,10 +14,16 @@ SCRIPT_DIR_PATH="$( # TODO(jathu): remove the need to fetch coremltools to build deps for coreml_executor_runner. # Keep this version in sync with: pyproject.toml -COREMLTOOLS_VERSION="9.0b1" +COREMLTOOLS_VERSION="9.0" -red=`tput setaf 1` -green=`tput setaf 2` +# Safe colors (no TERM noise in CI) +if command -v tput >/dev/null 2>&1 && [ -t 1 ] && [ -n "${TERM:-}" ]; then + red="$(tput setaf 1)" + green="$(tput setaf 2)" + reset="$(tput sgr0)" +else + red=""; green=""; reset="" +fi EXECUTORCH_ROOT_PATH=$(realpath "$SCRIPT_DIR_PATH/../../../../") COREML_DIR_PATH="$EXECUTORCH_ROOT_PATH/backends/apple/coreml" @@ -25,30 +33,79 @@ PROTOBUF_FILES_DIR_PATH="$COREMLTOOLS_DIR_PATH/build/mlmodel/format/" cd "$EXECUTORCH_ROOT_PATH" rm -rf "$COREML_DIR_PATH/third-party" -mkdir "$COREML_DIR_PATH/third-party" +mkdir -p "$COREML_DIR_PATH/third-party" -echo "${green}ExecuTorch: Cloning coremltools." -git clone --depth 1 --branch "${COREMLTOOLS_VERSION}" "https://github.com/apple/coremltools.git" $COREMLTOOLS_DIR_PATH -cd $COREMLTOOLS_DIR_PATH +echo "${green}ExecuTorch: Cloning coremltools.${reset}" +git clone --depth 1 --branch "${COREMLTOOLS_VERSION}" "https://github.com/apple/coremltools.git" "$COREMLTOOLS_DIR_PATH" +cd "$COREMLTOOLS_DIR_PATH" STATUS=$? if [ $STATUS -ne 0 ]; then - echo "${red}ExecuTorch: Failed to clone coremltools." + echo "${red}ExecuTorch: Failed to clone coremltools.${reset}" exit 1 fi -echo "${green}ExecuTorch: Installing coremltools dependencies." 
-pip install -r "$COREMLTOOLS_DIR_PATH/reqs/build.pip" +# --------------------------------------------------------------------- +# Host toolchain / SDK setup JUST for coremltools build +# --------------------------------------------------------------------- +HOST_SDKROOT="${SDKROOT:-}" +HOST_CC="${CC:-}" +HOST_CXX="${CXX:-}" +HOST_CFLAGS="${CFLAGS:-}" +HOST_CXXFLAGS="${CXXFLAGS:-}" + +if [[ "$(uname)" == "Darwin" ]]; then + # Only pick macOS SDK if nothing else is specified + if [[ -z "$HOST_SDKROOT" ]]; then + HOST_SDKROOT="$(xcrun --sdk macosx --show-sdk-path)" + fi + if [[ -z "$HOST_CC" ]]; then + HOST_CC="$(xcrun --find clang)" + fi + if [[ -z "$HOST_CXX" ]]; then + HOST_CXX="$(xcrun --find clang++)" + fi + # Only add -isysroot if caller didn't already set CFLAGS/CXXFLAGS + if [[ -z "$HOST_CFLAGS" && -n "$HOST_SDKROOT" ]]; then + HOST_CFLAGS="-isysroot ${HOST_SDKROOT}" + fi + if [[ -z "$HOST_CXXFLAGS" && -n "$HOST_SDKROOT" ]]; then + HOST_CXXFLAGS="-isysroot ${HOST_SDKROOT}" + fi +fi + +echo "${green}ExecuTorch: Installing coremltools dependencies.${reset}" +SDKROOT="$HOST_SDKROOT" \ +CC="$HOST_CC" \ +CXX="$HOST_CXX" \ +CFLAGS="$HOST_CFLAGS" \ +CXXFLAGS="$HOST_CXXFLAGS" \ +python -m pip install -r "$COREMLTOOLS_DIR_PATH/reqs/build.pip" STATUS=$? if [ $STATUS -ne 0 ]; then - echo "${red}ExecuTorch: Failed to install coremltools dependencies." + echo "${red}ExecuTorch: Failed to install coremltools dependencies.${reset}" exit 1 fi -mkdir "$COREMLTOOLS_DIR_PATH/build" +mkdir -p "$COREMLTOOLS_DIR_PATH/build" + +echo "${green}ExecuTorch: Configuring coremltools CMake build.${reset}" +SDKROOT="$HOST_SDKROOT" \ +CC="$HOST_CC" \ +CXX="$HOST_CXX" \ +CFLAGS="$HOST_CFLAGS" \ +CXXFLAGS="$HOST_CXXFLAGS" \ cmake -S "$COREMLTOOLS_DIR_PATH" -B "$COREMLTOOLS_DIR_PATH/build" + +echo "${green}ExecuTorch: Building mlmodel target.${reset}" +SDKROOT="$HOST_SDKROOT" \ +CC="$HOST_CC" \ +CXX="$HOST_CXX" \ +CFLAGS="$HOST_CFLAGS" \ +CXXFLAGS="$HOST_CXXFLAGS" \ cmake --build "$COREMLTOOLS_DIR_PATH/build" --parallel --target mlmodel -echo "${green}ExecuTorch: Copying protobuf files." +echo "${green}ExecuTorch: Copying protobuf files.${reset}" +rm -rf "$COREML_DIR_PATH/runtime/sdk/format/" mkdir -p "$COREML_DIR_PATH/runtime/sdk/format/" cp -rf "$PROTOBUF_FILES_DIR_PATH" "$COREML_DIR_PATH/runtime/sdk/format/" diff --git a/backends/apple/coreml/test/test_coreml_recipes.py b/backends/apple/coreml/test/test_coreml_recipes.py index 303d8cb78ed..98d240d74b5 100644 --- a/backends/apple/coreml/test/test_coreml_recipes.py +++ b/backends/apple/coreml/test/test_coreml_recipes.py @@ -326,7 +326,7 @@ def forward(self, x): ) self.check_fully_delegated(session) - self._compare_eager_quantized_model_outputs(session, example_inputs, atol=1e-3) + self._compare_eager_quantized_model_outputs(session, example_inputs, atol=1e-2) self._compare_eager_unquantized_model_outputs(session, model, example_inputs) def test_int8_weight_only_pt2e(self): diff --git a/backends/apple/metal/CMakeLists.txt b/backends/apple/metal/CMakeLists.txt new file mode 100644 index 00000000000..7bdf142041d --- /dev/null +++ b/backends/apple/metal/CMakeLists.txt @@ -0,0 +1,120 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Build AOTI Metal backend for runtime. 
+# +# ### Editing this file ### +# +# This file should be formatted with +# ~~~ +# cmake-format -i CMakeLists.txt +# ~~~ +# It should also be cmake-lint clean. +# +cmake_minimum_required(VERSION 3.29) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +if(NOT APPLE) + message(FATAL_ERROR "Metal backend requires macOS") +endif() + +# Source root directory for executorch. +if(NOT EXECUTORCH_ROOT) + set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..) +endif() + +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) +# Use full torch package to get library paths, but only link specific libraries +find_package_torch() + +set(_aoti_metal_sources + runtime/metal_backend.cpp + runtime/shims/memory.cpp + runtime/shims/et_metal.mm + runtime/shims/et_metal_ops.mm + runtime/shims/shim_mps.mm + runtime/shims/tensor_attribute.cpp + runtime/shims/utils.cpp +) + +add_library(metal_backend STATIC ${_aoti_metal_sources}) +target_include_directories( + metal_backend + PUBLIC $ $ + # PyTorch AOTI headers from ExecutorTorch's torch detection + ${TORCH_INCLUDE_DIRS} +) + +# Link Metal framework +find_library(METAL_LIBRARY Metal REQUIRED) +find_library(FOUNDATION_LIBRARY Foundation REQUIRED) +find_library(METALPERFORMANCESHADERS_LIBRARY MetalPerformanceShaders REQUIRED) +find_library( + METALPERFORMANCESHADERSGRAPH_LIBRARY MetalPerformanceShadersGraph REQUIRED +) +target_link_libraries( + metal_backend + PUBLIC ${METAL_LIBRARY} ${FOUNDATION_LIBRARY} + ${METALPERFORMANCESHADERS_LIBRARY} + ${METALPERFORMANCESHADERSGRAPH_LIBRARY} +) + +target_compile_options(metal_backend PUBLIC -fexceptions -frtti -fPIC) + +target_link_options(metal_backend PUBLIC -Wl,-export_dynamic) + +# Find PyTorch's OpenMP library specifically for libtorch-less AOTI +get_torch_base_path(TORCH_BASE_PATH) +find_library( + TORCH_OMP_LIBRARY + NAMES omp libomp + PATHS "${TORCH_BASE_PATH}/lib" + NO_DEFAULT_PATH +) + +if(TORCH_OMP_LIBRARY) + message(STATUS "Found PyTorch OpenMP library: ${TORCH_OMP_LIBRARY}") + # Get the directory containing the OpenMP library for rpath + get_filename_component(TORCH_OMP_LIB_DIR ${TORCH_OMP_LIBRARY} DIRECTORY) + message(STATUS "OpenMP library directory: ${TORCH_OMP_LIB_DIR}") +else() + message( + WARNING "PyTorch OpenMP library not found, may cause runtime linking issues" + ) +endif() + +# Link against appropriate backends and standard libraries +target_link_libraries( + metal_backend PUBLIC aoti_common extension_tensor ${CMAKE_DL_LIBS} + ${TORCH_OMP_LIBRARY} +) + +# Set rpath for OpenMP library to avoid runtime linking issues +if(TORCH_OMP_LIBRARY AND TORCH_OMP_LIB_DIR) + # Add the OpenMP library directory to the rpath + set_target_properties( + metal_backend PROPERTIES BUILD_RPATH "${TORCH_OMP_LIB_DIR}" + INSTALL_RPATH "${TORCH_OMP_LIB_DIR}" + ) + # Also try common OpenMP library locations + target_link_options( + metal_backend PUBLIC -Wl,-rpath,${TORCH_OMP_LIB_DIR} + -Wl,-rpath,/usr/local/opt/libomp/lib + -Wl,-rpath,/opt/homebrew/opt/libomp/lib + ) + message(STATUS "Added rpath for OpenMP library: ${TORCH_OMP_LIB_DIR}") +endif() + +executorch_target_link_options_shared_lib(metal_backend) +install( + TARGETS metal_backend + EXPORT ExecuTorchTargets + DESTINATION lib +) diff --git a/backends/apple/metal/README.md b/backends/apple/metal/README.md new file mode 100644 index 00000000000..0f010ae8920 --- /dev/null +++ b/backends/apple/metal/README.md @@ -0,0 +1,5 @@ +# Metal Backend + +⚠️ **EXPERIMENTAL BACKEND** + +This backend is currently in experimental 
development and may not be fully functional or stable. Use with caution.
diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py
new file mode 100644
index 00000000000..1d86cfb8447
--- /dev/null
+++ b/backends/apple/metal/metal_backend.py
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing
+from typing import Any, Dict, final, List
+
+from executorch.backends.aoti.aoti_backend import AotiBackend
+from executorch.exir._warnings import experimental
+from executorch.exir.backend.backend_details import BackendDetails
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+
+
+@final
+@experimental(
+    "This API and all of Metal backend related functionality are experimental."
+)
+class MetalBackend(AotiBackend, BackendDetails):
+    """
+    MetalBackend compiles a model to run on Metal/MPS devices. It uses the AOTInductor compiler to generate
+    optimized Metal kernels for the model's operators without linking against libtorch. The compiled model can
+    be executed on Metal devices using the ExecuTorch runtime.
+    """
+
+    @classmethod
+    def get_device_name(cls) -> str:
+        return "metal"
+
+    @classmethod
+    def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
+        return {
+            "aoti_torch_mps_addmm_out": None,
+            "aoti_torch_mps_convolution": None,
+            "aoti_torch_mps_mm_out": None,
+            "at::_ops::_scaled_dot_product_attention_math_for_mps::call": None,
+        }
+
+    @classmethod
+    def get_decomposition_table(cls) -> Dict[Any, Any]:
+        return {}
+
+    @classmethod
+    def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]:
+        """Return Metal-specific passes (currently none)"""
+        return []
+
+    @classmethod
+    def get_aoti_compile_options(
+        cls, compile_specs: List[CompileSpec]
+    ) -> Dict[str, typing.Any]:
+        """Get AOTI compile options for Metal backend."""
+        _ = compile_specs  # Unused, but required by interface
+        return {
+            # Do not link against the full PyTorch/libtorch library
+            "aot_inductor.link_libtorch": False,
+            # Separate weight constants from the .so file
+            "aot_inductor.package": True,
+            "aot_inductor.package_constants_in_so": False,
+            # Store weight constants on disk in a binary blob
+            "aot_inductor.package_constants_on_disk_format": "binary_blob",
+            # Enable maximum automatic tuning for optimal performance
+            "max_autotune": True,
+            # "aot_inductor.debug_compile": True,
+            # "aot_inductor.force_mmap_weights": False,
+        }
diff --git a/backends/apple/metal/metal_partitioner.py b/backends/apple/metal/metal_partitioner.py
new file mode 100644
index 00000000000..e2672f6b554
--- /dev/null
+++ b/backends/apple/metal/metal_partitioner.py
@@ -0,0 +1,25 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import final, List
+
+from executorch.backends.aoti.aoti_partitioner import AotiPartitioner
+from executorch.backends.apple.metal.metal_backend import MetalBackend  # usort: skip
+from executorch.exir._warnings import experimental
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+
+
+@final
+@experimental(
+    "This API and all of Metal backend related functionality are experimental."
+) +class MetalPartitioner(AotiPartitioner): + """ + Metal partitioner driven by AOTInductor backend. + """ + + def __init__(self, compile_spec: List[CompileSpec]) -> None: + super().__init__(MetalBackend.__name__, compile_spec) diff --git a/backends/apple/metal/runtime/metal_backend.cpp b/backends/apple/metal/runtime/metal_backend.cpp new file mode 100644 index 00000000000..f79a2a67b6f --- /dev/null +++ b/backends/apple/metal/runtime/metal_backend.cpp @@ -0,0 +1,572 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +// Include AOTI common headers (from aoti_common library) +#include +#include + +// Include our Metal-specific shim layer headers +#include +#include +#include +#include +#include + +namespace executorch::backends::metal { + +#define LOAD_SYMBOL(handle, member, name, so_handle) \ + do { \ + handle->member = reinterpret_cast(dlsym(so_handle, #name)); \ + ET_CHECK_OR_RETURN_ERROR( \ + handle->member != nullptr, AccessFailed, "Failed to load " #name); \ + } while (0) + +using namespace std; +using namespace aoti; + +using executorch::aten::ScalarType; +using executorch::runtime::ArrayRef; +using executorch::runtime::Backend; +using executorch::runtime::BackendExecutionContext; +using executorch::runtime::BackendInitContext; +using executorch::runtime::CompileSpec; +using executorch::runtime::DelegateHandle; +using executorch::runtime::Error; +using executorch::runtime::EValue; +using executorch::runtime::FreeableBuffer; +using executorch::runtime::MemoryAllocator; +using executorch::runtime::NamedDataMap; +using executorch::runtime::Result; +using executorch::runtime::Span; +using executorch::runtime::etensor::Tensor; + +class ET_EXPERIMENTAL MetalBackend final + : public ::executorch::runtime::BackendInterface { + private: + Error load_function_pointers_into_handle( + void* so_handle, + AOTIDelegateHandle* handle) const { + ET_LOG( + Debug, + "MetalBackend::load_function_pointers_into_handle - Loading symbols"); + + LOAD_SYMBOL( + handle, + create_with_device, + AOTInductorModelContainerCreateWithDevice, + so_handle); + ET_LOG( + Debug, + "MetalBackend::load_function_pointers_into_handle - Loaded AOTInductorModelContainerCreateWithDevice"); + + LOAD_SYMBOL( + handle, delete_container, AOTInductorModelContainerDelete, so_handle); + ET_LOG( + Debug, + "MetalBackend::load_function_pointers_into_handle - Loaded AOTInductorModelContainerDelete"); + + LOAD_SYMBOL( + handle, + get_num_inputs, + AOTInductorModelContainerGetNumInputs, + so_handle); + ET_LOG( + Debug, + "MetalBackend::load_function_pointers_into_handle - Loaded AOTInductorModelContainerGetNumInputs"); + + LOAD_SYMBOL( + handle, + get_num_outputs, + AOTInductorModelContainerGetNumOutputs, + so_handle); + ET_LOG( + Debug, + "MetalBackend::load_function_pointers_into_handle - Loaded AOTInductorModelContainerGetNumOutputs"); + + LOAD_SYMBOL(handle, run, AOTInductorModelContainerRun, so_handle); + ET_LOG( + Debug, + "MetalBackend::load_function_pointers_into_handle - Loaded AOTInductorModelContainerRun"); + + LOAD_SYMBOL( + handle, + update_constants_from_blob, + AOTInductorModelUpdateConstantsFromBlob, + so_handle); + ET_LOG( + Debug, + "MetalBackend::load_function_pointers_into_handle - Loaded 
AOTInductorModelUpdateConstantsFromBlob"); + + ET_LOG( + Debug, + "MetalBackend::load_function_pointers_into_handle - All symbols loaded successfully"); + return Error::Ok; + } + + public: + // Once in program + MetalBackend() { + ET_LOG(Debug, "MetalBackend ctor"); + } + + bool is_available() const override { + return 1; + } + + // Once per loaded binary blob + Result init( + BackendInitContext& context, + FreeableBuffer* processed, // This will be a empty buffer + ArrayRef compile_specs // This will be my empty list + ) const override { + ET_LOG(Info, "MetalBackend::init - Starting initialization"); + + std::string method_name; + for (const CompileSpec& spec : compile_specs) { + if (std::strcmp(spec.key, "method_name") == 0) { + method_name.assign( + static_cast(spec.value.buffer), + spec.value.nbytes); // no nullptr guarantee, so pass size + break; + } + } + + std::string so_blob_key = + method_name.empty() ? "so_blob" : method_name + "_so_blob"; + ET_LOG(Info, "MetalBackend::init - so_blob_key: %s", so_blob_key.c_str()); + + const NamedDataMap* named_data_map = context.get_named_data_map(); + ET_LOG(Info, "MetalBackend::init - Got named data map: %p", named_data_map); + + ET_LOG( + Info, + "MetalBackend::init - Looking for blob key: %s", + so_blob_key.c_str()); + + auto aoti_metal_buffer = named_data_map->get_data(so_blob_key.c_str()); + ET_CHECK_OR_RETURN_ERROR( + aoti_metal_buffer.ok(), + Internal, + "Failed to get data for key %s: 0x%x", + so_blob_key.c_str(), + static_cast(aoti_metal_buffer.error())); + + ET_LOG( + Info, + "MetalBackend::init - Buffer is OK, size: %zu", + aoti_metal_buffer->size()); + + if (aoti_metal_buffer->data() == nullptr) { + ET_LOG(Error, "MetalBackend::init - Buffer data is null"); + return Error::InvalidArgument; + } + + ET_LOG( + Info, + "MetalBackend::init - Buffer data pointer: %p", + aoti_metal_buffer->data()); + + // Generate dynamic temporary file path + filesystem::path temp_dir = filesystem::temp_directory_path(); + filesystem::path so_path = + temp_dir / (so_blob_key + to_string(getpid()) + ".so"); + + // Create a temporary file + ET_LOG( + Info, "MetalBackend::init - Creating temp file: %s", so_path.c_str()); + ofstream outfile(so_path.c_str(), ios::binary); + + // Write the ELF buffer to the temporary file + ET_LOG( + Info, + "Writing %zu bytes to %s", + aoti_metal_buffer->size(), + so_path.c_str()); + + outfile.write( + static_cast(aoti_metal_buffer->data()), + aoti_metal_buffer->size()); + + ET_CHECK_OR_RETURN_ERROR( + outfile, AccessFailed, "Failed to write to file %s", so_path.c_str()); + + // Finish writing the file to disk + outfile.close(); + ET_LOG(Info, "MetalBackend::init - File closed successfully"); + + // Free the buffer immediately after writing to disk + aoti_metal_buffer->Free(); + + // Load the ELF using dlopen + void* so_handle = dlopen(so_path.c_str(), RTLD_LAZY | RTLD_LOCAL); + ET_CHECK_OR_RETURN_ERROR( + so_handle != nullptr, + AccessFailed, + "Failed to load shared library: %s", + dlerror()); + + processed->Free(); + + // Create handle and load function pointers into it + AOTIDelegateHandle* handle = new AOTIDelegateHandle(); + handle->so_handle = so_handle; + handle->so_path = so_path.string(); + + // Load function pointers specific to this handle's shared library + ET_CHECK_OK_OR_RETURN_ERROR( + load_function_pointers_into_handle(so_handle, handle)); + + AOTInductorModelContainerHandle container_handle = nullptr; + ET_LOG( + Info, + "MetalBackend::init - About to create AOTI container with device='mps'"); + + 
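+  // Create the AOTI model container directly on the MPS device. The arguments
+  // follow AOTInductorModelContainerCreateWithDevice: one model instance, the
+  // device string "mps", and no separate kernel (cubin) directory.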
ET_CHECK_OK_OR_RETURN_ERROR( + handle->create_with_device(&container_handle, 1, "mps", nullptr)); + + ET_LOG(Info, "container_handle = %p", container_handle); + + handle->container_handle = container_handle; + + // Look into named data map for constant data + std::string weights_blob_key = + method_name.empty() ? "weights_blob" : method_name + "_weights_blob"; + auto buffer_res = named_data_map->get_data(weights_blob_key.c_str()); + if (buffer_res.ok() && handle->update_constants_from_blob != nullptr) { + ET_LOG(Info, "Found %s in named data map", weights_blob_key.c_str()); + const void* weights_blob = buffer_res->data(); + // Feed the weights blob into the container. Under the hood it's copying + // weights, so we should free the buffer immediately. + ET_CHECK_OK_OR_RETURN_ERROR(handle->update_constants_from_blob( + handle->container_handle, static_cast(weights_blob))); + buffer_res->Free(); + } + + ET_LOG(Info, "MetalBackend::init - Initialization completed successfully"); + return (DelegateHandle*)handle; // Return the handle post-processing + } + + // Once per execution + Error execute( + BackendExecutionContext& context, + DelegateHandle* handle_, + Span args) const override { + ET_LOG(Debug, "MetalBackend execute"); + + AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_; + + ET_LOG(Debug, "MetalBackend Handle generated"); + + size_t n_inputs; + handle->get_num_inputs(handle->container_handle, &n_inputs); + + size_t n_outputs; + handle->get_num_outputs(handle->container_handle, &n_outputs); + + ET_LOG(Debug, "MetalBackend n_outputs %zd generated", n_outputs); + + ET_CHECK_OR_RETURN_ERROR( + n_inputs + n_outputs == args.size(), + InvalidArgument, + "number of user input %zd and output %zd generated from AOT Inductor does not match ET runner's %zd. Exit.", + n_inputs, + n_outputs, + args.size()) + + ET_LOG( + Debug, + "number of user input %zd and output %zd generated from AOT Inductor matches ET runner's %zd.", + n_inputs, + n_outputs, + args.size()); + + int32_t mps_device_type = aoti_torch_device_type_mps(); // Returns 13 + + // NOTE: ExecutorTorch tensors are always on CPU/host memory + // We need to create GPU copies for Metal kernel execution + std::vector gpu_inputs( + n_inputs); // GPU copies for kernel execution + std::vector gpu_outputs( + n_outputs); // GPU tensors for kernel output + + ET_LOG(Debug, "MetalBackend input/output vectors generated"); + + // Process input tensors: ExecutorTorch provides CPU tensors, create GPU + // copies + for (int i = 0; i < n_inputs; i++) { + ET_LOG(Debug, "Processing input %d from args to inputs vector", i); + ET_LOG( + Debug, "is %d input a tensor input? 
%d", i, int(args[i]->isTensor())); + + // Get tensor dimensions and properties from ExecutorTorch CPU tensor + auto cpu_tensor = &(args[i]->toTensor()); + auto sizes = cpu_tensor->sizes(); + auto scalar_type = cpu_tensor->scalar_type(); + ET_LOG( + Debug, + "MetalBackend input %d scalar_type=%d", + i, + static_cast(scalar_type)); + + // Create GPU tensor with same shape + std::vector sizes_vec(sizes.begin(), sizes.end()); + + AOTITensorHandle gpu_input_handle; + Error create_err = aoti_torch_empty_strided( + sizes_vec.size(), + sizes_vec.data(), + nullptr, // use default strides + static_cast(scalar_type), + mps_device_type, // device_type = mps + 0, // device_index = 0 + &gpu_input_handle); + + if (create_err != Error::Ok) { + ET_LOG(Error, "Failed to create GPU tensor for input %d", i); + return Error::Internal; + } + + // Log the created GPU tensor scalar type + auto gpu_tensor = reinterpret_cast( + gpu_input_handle); + ET_LOG( + Debug, + "MetalBackend created GPU tensor %d scalar_type=%d", + i, + static_cast(gpu_tensor->scalar_type())); + + gpu_inputs[i] = gpu_input_handle; + + // Log the CPU tensor data before copying to GPU + void* cpu_data = cpu_tensor->mutable_data_ptr(); + if (cpu_data && cpu_tensor->numel() > 0) { + float* cpu_float_data = (float*)cpu_data; + ET_LOG( + Debug, + "CPU input %d data before copy: [%.3f, %.3f, %.3f, ...] (numel=%zd)", + i, + cpu_float_data[0], + cpu_float_data[1], + cpu_float_data[2], + cpu_tensor->numel()); + } + + // Copy data from CPU to GPU + Error copy_err = aoti_torch_copy_(gpu_inputs[i], cpu_tensor, 0); + if (copy_err != Error::Ok) { + ET_LOG(Error, "Failed to copy input %d from CPU to GPU", i); + return Error::Internal; + } + + // Log the GPU tensor scalar type after copy + auto gpu_tensor_after = + reinterpret_cast( + gpu_inputs[i]); + ET_LOG( + Debug, + "MetalBackend GPU tensor %d scalar_type after copy=%d", + i, + static_cast(gpu_tensor_after->scalar_type())); + + ET_LOG(Debug, "Successfully copied input %d from CPU to GPU", i); + } + + ET_LOG(Debug, "MetalBackend GPU inputs generated"); + + // Process output tensors: create GPU counterparts for ExecutorTorch CPU + // tensors + for (int i = 0; i < n_outputs; i++) { + // Get output tensor dimensions from ExecutorTorch CPU tensor + auto cpu_output_tensor = &(args[i + n_inputs]->toTensor()); + auto sizes = cpu_output_tensor->sizes(); + auto scalar_type = cpu_output_tensor->scalar_type(); + ET_LOG( + Debug, + "MetalBackend output %d scalar_type=%d", + i, + static_cast(scalar_type)); + + // Create GPU tensor with same shape for kernel output + std::vector sizes_vec(sizes.begin(), sizes.end()); + + AOTITensorHandle gpu_output_handle; + Error create_err = aoti_torch_empty_strided( + sizes_vec.size(), + sizes_vec.data(), + nullptr, // use default strides + static_cast(scalar_type), + mps_device_type, // device_type = mps + 0, // device_index = 0 + &gpu_output_handle); + + if (create_err != Error::Ok) { + ET_LOG(Error, "Failed to create GPU tensor for output %d", i); + return Error::Internal; + } + + gpu_outputs[i] = gpu_output_handle; + ET_LOG(Debug, "Created GPU output tensor %d", i); + } + + ET_LOG(Debug, "MetalBackend output generated"); + + // Log tensor handles before passing to AOTI container + ET_LOG(Debug, "Passing to AOTInductorModelContainerRun:"); + for (int i = 0; i < n_inputs; i++) { + void* gpu_input_data = gpu_inputs[i]->mutable_data_ptr(); + ET_LOG( + Debug, + " gpu_inputs[%d] = %p, data_ptr = %p", + i, + gpu_inputs[i], + gpu_input_data); + } + for (int i = 0; i < n_outputs; i++) { 
+ void* gpu_output_data = gpu_outputs[i]->mutable_data_ptr(); + ET_LOG( + Debug, + " gpu_outputs[%d] = %p, data_ptr = %p", + i, + gpu_outputs[i], + gpu_output_data); + } + + // Run AOTI container with GPU tensors + AOTIRuntimeError error = handle->run( + handle->container_handle, + gpu_inputs.data(), // Use GPU input tensors + n_inputs, + gpu_outputs.data(), // Use GPU output tensors + n_outputs, + nullptr, // Pass the actual Metal stream! + nullptr); // proxy_executor_handle can remain nullptr + + if (error != Error::Ok) { + ET_LOG( + Error, + "AOTInductorModelContainerRun failed with error code %d", + error); + return Error::Internal; + } + + // Ensure all GPU work is completed before reading results + try { + synchronize_metal_stream(); + } catch (const std::exception& e) { + ET_LOG( + Error, + "Failed to synchronize Metal stream after kernel execution: %s", + e.what()); + return Error::Internal; + } catch (...) { + ET_LOG( + Error, + "Failed to synchronize Metal stream after kernel execution: unknown exception"); + return Error::Internal; + } + + ET_LOG(Debug, "MetalBackend running done and synchronized"); + + // Copy GPU output results back to CPU output tensors + for (int i = 0; i < n_outputs; i++) { + auto cpu_output_tensor = &(args[i + n_inputs]->toTensor()); + // For DYNAMIC_BOUND tensors we try to resize + ET_CHECK_OK_OR_RETURN_ERROR( + resize_tensor(*cpu_output_tensor, gpu_outputs[i]->sizes()), + "Error resizing tensor at output index %d", + i); + ET_CHECK_OK_OR_RETURN_ERROR( + aoti_torch_copy_(cpu_output_tensor, gpu_outputs[i], 0), + "Failed to copy GPU output %d back to CPU", + i); + ET_LOG(Debug, "Copied GPU output %d back to CPU", i); + } + + // Clean up GPU tensors that we created (ExecutorTorch tensors are always + // CPU, so all GPU tensors are our copies) + for (int i = 0; i < n_inputs; i++) { + // All GPU input tensors were created by us, delete them + aoti_torch_delete_tensor_object(gpu_inputs[i]); + } + + for (int i = 0; i < n_outputs; i++) { + // All GPU output tensors were created by us, delete them + aoti_torch_delete_tensor_object(gpu_outputs[i]); + } + + ET_LOG(Debug, "MetalBackend execution completed successfully"); + + return Error::Ok; + } + + void destroy(DelegateHandle* handle_) const override { + if (handle_ == nullptr) { + return; + } + AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_; + + // NOTE: AOTInductorModelContainerDelete does not work correctly with + // multiple .so files. Deleting one container frees shared resources, + // which causes segmentation faults when attempting to delete other + // containers. As a workaround, we skip explicit container deletion + // and defer cleanup to the OS. + // TODO: Find a proper solution for safe container deletion. 
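+  // Until a proper fix lands, the container's resources are reclaimed only at
+  // process exit; the dlclose() and temporary-file removal below still run.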
+ // AOTInductorModelContainerDelete(handle->container_handle); + + // Now close the shared library + if (handle->so_handle != nullptr) { + dlclose(handle->so_handle); + } + + // Remove the temporary shared library file + if (!handle->so_path.empty()) { + std::error_code remove_error; + std::filesystem::remove(handle->so_path, remove_error); + ET_CHECK_OR_LOG_ERROR( + !remove_error, + "Failed to remove temporary shared library %s: %s", + handle->so_path.c_str(), + remove_error.message().c_str()); + if (!remove_error) { + ET_LOG( + Info, + "Removed temporary shared library file: %s", + handle->so_path.c_str()); + } + } + + delete handle; + cleanup_memory(); + executorch::backends::aoti::cleanup_tensor_metadata(); + ET_LOG(Debug, "MetalBackend handle %p destroy", handle_); + } +}; + +} // namespace executorch::backends::metal + +namespace executorch::backends { +namespace { +auto cls = metal::MetalBackend(); +executorch::runtime::Backend backend{"MetalBackend", &cls}; +static executorch::runtime::Error success_with_compiler = + register_backend(backend); +} // namespace +} // namespace executorch::backends diff --git a/backends/apple/metal/runtime/shims/et_metal.h b/backends/apple/metal/runtime/shims/et_metal.h new file mode 100644 index 00000000000..1c61499b242 --- /dev/null +++ b/backends/apple/metal/runtime/shims/et_metal.h @@ -0,0 +1,399 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#ifdef __OBJC__ +#import +#import +#include +// Forward declarations for MetalPerformanceShadersGraph types +@class MPSGraph; +@class MPSCommandBuffer; +// Metal type definitions for Objective-C compilation +typedef id MTLDevice_t; +typedef id MTLCommandQueue_t; +typedef id MTLCommandBuffer_t; +typedef id MTLComputeCommandEncoder_t; +typedef id MTLComputePipelineState_t; +typedef id MTLFunction_t; +typedef id MTLLibrary_t; +typedef id MTLBuffer_t; +typedef dispatch_queue_t dispatch_queue_t; +typedef MPSGraph* MPSGraph_t; +typedef MPSCommandBuffer* MPSCommandBuffer_t; +typedef NSDictionary* NSDictionary_t; +#else +// Forward declarations for C++ compilation +typedef void* MTLDevice_t; +typedef void* MTLCommandQueue_t; +typedef void* MTLCommandBuffer_t; +typedef void* MTLComputeCommandEncoder_t; +typedef void* MTLComputePipelineState_t; +typedef void* MTLFunction_t; +typedef void* MTLLibrary_t; +typedef void* MTLBuffer_t; +typedef void* dispatch_queue_t; +typedef void* MPSGraph_t; +typedef void* MPSCommandBuffer_t; +typedef void* NSDictionary_t; +#endif + +#include +#include +#include +#include +#include + +namespace executorch::runtime::etensor { +class Tensor; +} + +namespace executorch { +namespace backends { +namespace metal { + +// Forward declarations +class ETMetalKernelFunction; +class ETMetalStream; + +// ======================= +// SyncType - Metal synchronization options +// ======================= +enum class SyncType { + NONE, // no commit to command buffer + COMMIT, // commit and flush the command buffer + COMMIT_AND_WAIT, // flush and wait for command buffer execution to finish + COMMIT_AND_CONTINUE, // commit and continue with a new underlying command + // buffer + COMMIT_ADAPTIVE, // commit adaptively based on available memory +}; + +// ======================= +// ETMetalShaderLibrary - ExecuTorch Metal shader library management +// ======================= + +/** + * @class ETMetalShaderLibrary 
+ * @brief Manages Metal shader library compilation and kernel function + * retrieval. + * + * This class provides a high-level interface for compiling Metal shading + * language source code into a Metal library and creating compute pipeline + * states for kernel functions. It handles the creation and caching of Metal + * compute pipeline states and functions, which should be reused across multiple + * kernel dispatches. + * + * The class automatically compiles the provided shader source code upon + * construction and maintains an internal cache of compute pipeline states for + * different kernel functions to avoid redundant compilation. + * + * Example usage: + * @code + * std::string shaderSource = R"( + * #include + * using namespace metal; + * kernel void my_kernel(device float* data [[buffer(0)]], + * uint tid [[thread_position_in_grid]]) { + * data[tid] = data[tid] * 2.0; + * } + * )"; + * + * ETMetalShaderLibrary library(shaderSource); + * auto kernelFunction = library.getKernelFunction("my_kernel"); + * @endcode + */ +class ETMetalShaderLibrary { + public: + ETMetalShaderLibrary(const std::string& source); + ~ETMetalShaderLibrary(); + + std::shared_ptr getKernelFunction( + const std::string& name); + + private: + void compileLibrary(); + std::pair getLibraryPipelineState( + const std::string& functionName); + + friend class ETMetalKernelFunction; + + std::string shaderSource_; + MTLLibrary_t library_; + std::unordered_map< + std::string, + std::pair> + pipelineStates_; +}; + +// ======================= +// ETMetalKernelFunction - ExecuTorch Metal kernel function execution +// ======================= + +/** + * @class ETMetalKernelFunction + * @brief Represents a Metal compute kernel function ready for execution. + * + * This class encapsulates a Metal compute pipeline state and function, + * providing a high-level interface for setting kernel arguments and dispatching + * compute work to the GPU. It handles the encoding of compute commands and + * manages the interaction with Metal's compute command encoder. + * + * The class supports different dispatch patterns: + * - Single-dimension dispatch for linear workloads + * - Multi-dimensional dispatch for grid-based workloads + * - Custom thread group sizes for performance optimization + * + * Kernel arguments can be set using tensors (which will be mapped to Metal + * buffers) or scalar values. The class handles the encoding of these arguments + * into the compute command encoder. 
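+ *
+ * In addition to the thread-count dispatches shown in the example below, an
+ * explicit threadgroup-count dispatch is available; a minimal sketch (tensor
+ * and kernel names are hypothetical):
+ * @code
+ * auto tiledKernel = library.getKernelFunction("tiled_kernel");
+ * tiledKernel->startEncoding();
+ * tiledKernel->setArg(0, inputTensor);
+ * // An 8x8x1 grid of threadgroups with 16x16x1 threads each; the product of
+ * // the per-threadgroup dimensions must not exceed the pipeline's
+ * // maxTotalThreadsPerThreadgroup.
+ * tiledKernel->dispatchThreadgroups(8, 8, 1, 16, 16, 1);
+ * @endcode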
+ * + * Example usage: + * @code + * // Get kernel function from library + * auto kernelFunction = library.getKernelFunction("vector_add"); + * + * // Start encoding commands + * kernelFunction->startEncoding(); + * + * // Set tensor arguments + * kernelFunction->setArg(0, inputTensorA); + * kernelFunction->setArg(1, inputTensorB); + * kernelFunction->setArg(2, outputTensor); + * + * // Set scalar argument + * kernelFunction->setArg(3, static_cast(numElements)); + * + * // Dispatch for linear workload + * kernelFunction->dispatchSingle(numElements); + * @endcode + */ +class ETMetalKernelFunction { + public: + ETMetalKernelFunction(MTLComputePipelineState_t cps, MTLFunction_t func); + ~ETMetalKernelFunction(); + + void startEncoding(); + void setArg(unsigned idx, const executorch::runtime::etensor::Tensor& tensor); + void setArg(unsigned idx, int64_t val); + void setArg(unsigned idx, uint32_t val); + void setArg(unsigned idx, float val); + void setArg(unsigned idx, bool val); + void setArg(unsigned idx, const void* data, size_t size); + + // Helper for Metal uint3 struct + void setArgUint3(unsigned idx, uint32_t x, uint32_t y, uint32_t z); + + void dispatchSingle(uint64_t length); + void dispatchSingleWithGroupSize(uint64_t length, uint64_t group_size); + void dispatchArray(const uint64_t* length, size_t length_size); + void dispatchArrayWithGroupSize( + const uint64_t* length, + size_t length_size, + const uint64_t* group_size, + size_t group_size_size); + + // Dispatch with explicit threadgroup count (not thread count) + void dispatchThreadgroups( + uint64_t gridX, + uint64_t gridY, + uint64_t gridZ, + uint64_t threadsX, + uint64_t threadsY, + uint64_t threadsZ); + + void runCommandBlock(std::function f); + + private: + MTLComputePipelineState_t cps_; + MTLFunction_t func_; + MTLComputeCommandEncoder_t encoder_; +}; + +// ======================= +// ETMetalStream - Metal command buffer and synchronization management +// ======================= + +/** + * @class ETMetalStream + * @brief Manages Metal compute command streams and provides GPU + * synchronization. + * + * This class serves as the central management hub for Metal GPU operations, + * providing a stream-based abstraction similar to CUDA streams. It handles + * command buffer lifecycle, compute command encoder management, and various + * synchronization patterns required for efficient GPU computation. + * + * Key features: + * - Lazy command buffer and encoder creation for optimal resource usage + * - Thread-safe operations using serial dispatch queues + * - Multiple synchronization modes (COMMIT, COMMIT_AND_WAIT, + * COMMIT_AND_CONTINUE, etc.) + * - Kernel coalescing to batch multiple operations efficiently + * - MPSGraph integration for executing fall back operations (mm, conv, sdpa) + * - Memory operations (copy, fill) with GPU acceleration via blit encoders + * + * The stream follows PyTorch's MPS stream design patterns, providing similar + * semantics for command buffer management and synchronization. 
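+ *
+ * @note Operations issued with SyncType::NONE are only encoded into the
+ * current command buffer; no work reaches the GPU until a later
+ * synchronize(), flush(), or another call that commits the buffer.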
+ * + * Example usage: + * @code + * // Get current stream (typically the default stream) + * ETMetalStream* stream = getCurrentMetalStream(); + * + * // Execute kernel operations (handled automatically) + * auto kernelFunction = library.getKernelFunction("my_kernel"); + * kernelFunction->startEncoding(); + * kernelFunction->setArg(0, inputTensor); + * kernelFunction->dispatchSingle(numElements); + * + * // Synchronize to ensure completion + * stream->synchronize(SyncType::COMMIT_AND_WAIT); + * + * // Copy between GPU buffers using blit encoder + * stream->copy(srcBuffer, dstBuffer, numBytes, 0, 0, SyncType::COMMIT); + * @endcode + */ +class ETMetalStream { + public: + ETMetalStream(); + ~ETMetalStream(); + + // Get the default stream (singleton) + static ETMetalStream* getDefaultStream(); + + // Device and queue access + MTLDevice_t device() const { + return device_; + } + MTLCommandQueue_t commandQueue() const { + return commandQueue_; + } + dispatch_queue_t queue() const { + return serialQueue_; + } + + // Synchronization methods + void synchronize(SyncType syncType = SyncType::COMMIT_AND_WAIT); + void synchronize(); // Overload for backward compatibility + bool isEmpty() const; + + // Command buffer management with lazy creation + MPSCommandBuffer_t commandBuffer(); + MTLComputeCommandEncoder_t commandEncoder(); + + void endKernelCoalescing(); + + // MPSGraph execution + void executeMPSGraph( + MPSGraph_t mpsGraph, + NSDictionary_t feeds, + NSDictionary_t results, + SyncType syncType = SyncType::COMMIT_ADAPTIVE); + + // Command buffer lifecycle management + void commitCommandBuffer(MTLCommandBuffer_t commandBuffer); + void flush(); + + // Memory operations + void fill( + MTLBuffer_t buffer, + uint8_t value, + size_t length, + size_t offset, + SyncType syncType = SyncType::NONE); + void copy( + MTLBuffer_t srcBuffer, + MTLBuffer_t dstBuffer, + size_t length, + size_t srcOffset, + size_t dstOffset, + SyncType syncType = SyncType::NONE); + + private: + // Private synchronization methods + void commit(); + void commitAndWait(); + void commitAndContinue(); + + private: + // Private members + MTLDevice_t device_; + MTLCommandQueue_t commandQueue_; + MPSCommandBuffer_t commandBuffer_; + MPSCommandBuffer_t prevCommandBuffer_; // For commit-and-continue pattern + MTLComputeCommandEncoder_t commandEncoder_; + dispatch_queue_t serialQueue_; // For thread safety + + // Configuration + bool enableCommitAndContinue_; + + // Singleton instance + static ETMetalStream* defaultStream_; +}; + +// ======================= +// Global storage management functions +// ======================= +void storeFunctionHandle( + ETMetalKernelFunction* raw_function, + std::shared_ptr function_shared_ptr); +void storeLibraryHandle( + ETMetalShaderLibrary* raw_library, + std::unique_ptr library); +bool removeFunctionHandle(ETMetalKernelFunction* raw_function); +bool removeLibraryHandle(ETMetalShaderLibrary* raw_library); + +// ======================= +// Global stream access functions +// ======================= +ETMetalStream* getCurrentMetalStream(); +void setCurrentMetalStream(ETMetalStream* stream); + +// ======================= +// Metal stream synchronization functions (C++ interface with exceptions) +// ======================= +void synchronize_metal_stream(); +void synchronize_metal_stream_with_type(int sync_type); + +// ======================= +// Metal helper functions (C interface) +// ======================= +#ifdef __cplusplus +extern "C" { +#endif + +// Memory management functions for Metal +void* 
metal_allocate_buffer(long bytes); +void metal_deallocate_buffer(void* ptr); +bool metal_is_device_pointer(void* ptr); +int metal_copy_memory( + void* dst, + const void* src, + size_t nbytes, + bool src_is_device, + bool dst_is_device); +void metal_cleanup_resources(); + +// Helper functions to access Metal objects +MTLDevice_t get_metal_device(); +MTLCommandQueue_t get_metal_command_queue(); + +#ifdef __cplusplus +} + +// C++ only - expose the Metal buffer mapping +#ifdef __OBJC__ +extern std::unordered_map ptr_to_mtl_buffer; +#endif + +#endif + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/et_metal.mm b/backends/apple/metal/runtime/shims/et_metal.mm new file mode 100644 index 00000000000..f7d37c152ce --- /dev/null +++ b/backends/apple/metal/runtime/shims/et_metal.mm @@ -0,0 +1,997 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#import +#import +#import +#import +#include +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +// ======================= +// Exception-Safe Dispatch Function (similar to PyTorch MPS) +// ======================= + +void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()) { + __block std::optional block_exception; + dispatch_sync(queue, ^() { + try { + block(); + } catch (...) { + block_exception = std::current_exception(); + } + }); + if (block_exception) { + std::rethrow_exception(*block_exception); + } +} + +// ======================= +// Global Variables and Storage +// ================ + + +// Global Metal buffer mapping - accessible for MPS shim +std::unordered_map> ptr_to_mtl_buffer; + +// Global storage to keep shared_ptr alive while raw pointers are used +static std::unordered_map> function_storage; +static std::unordered_map> library_storage; + +// Static singleton instance for default stream +ETMetalStream* ETMetalStream::defaultStream_ = nullptr; + +// Thread-local current stream +static thread_local ETMetalStream* currentStream_ = nullptr; + +// ======================= +// Metal Helper Functions (C Interface) +// ======================= + +extern "C" { + +void* metal_allocate_buffer(long bytes) { + ETMetalStream* stream = getCurrentMetalStream(); + id device = stream->device(); + if (!device) { + ET_LOG(Error, "Failed to get Metal device from stream"); + return nullptr; + } + + @autoreleasepool { + id buffer = [device newBufferWithLength:bytes options:MTLResourceStorageModeShared]; + if (!buffer) { + ET_LOG(Error, "Failed to allocate %ld bytes on Metal device", bytes); + return nullptr; + } + + void* ptr = [buffer contents]; + ptr_to_mtl_buffer[ptr] = buffer; + + ET_LOG(Debug, "Allocated %ld bytes on Metal device", bytes); + return ptr; + } +} + +void metal_deallocate_buffer(void* ptr) { + @autoreleasepool { + auto it = ptr_to_mtl_buffer.find(ptr); + if (it != ptr_to_mtl_buffer.end()) { + id buffer = it->second; + [buffer release]; + ptr_to_mtl_buffer.erase(it); + ET_LOG(Debug, "Deallocated Metal buffer for pointer %p", ptr); + ptr = nullptr; + } else { + ET_LOG(Error, "Failed to find Metal buffer for pointer %p", ptr); + } + } +} + +void metal_cleanup_resources() { + if (!ptr_to_mtl_buffer.empty()) { + @autoreleasepool { + for (auto& pair : ptr_to_mtl_buffer) { + pair.second = nil; + } + 
ptr_to_mtl_buffer.clear(); + } + } +} + +bool metal_is_device_pointer(void* ptr) { + return ptr_to_mtl_buffer.find(ptr) != ptr_to_mtl_buffer.end(); +} + +int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_device, bool dst_is_device) { + if (!src || !dst || nbytes == 0) { + ET_LOG(Error, "Metal copy: Invalid parameters"); + return -1; + } + + @autoreleasepool { + // Case 1: Device-to-device copy - use GPU blit encoder (most efficient) + if (src_is_device && dst_is_device) { + auto src_it = ptr_to_mtl_buffer.find(const_cast(src)); + auto dst_it = ptr_to_mtl_buffer.find(dst); + + if (src_it != ptr_to_mtl_buffer.end() && dst_it != ptr_to_mtl_buffer.end()) { + id srcBuffer = src_it->second; + id dstBuffer = dst_it->second; + + // Calculate offsets relative to buffer base + size_t srcOffset = static_cast(src) - static_cast([srcBuffer contents]); + size_t dstOffset = static_cast(dst) - static_cast([dstBuffer contents]); + + // Use Metal's blit encoder for GPU-accelerated copy + ETMetalStream* stream = getCurrentMetalStream(); + stream->copy(srcBuffer, dstBuffer, nbytes, srcOffset, dstOffset, SyncType::NONE); + + ET_LOG(Debug, "Metal device-to-device copy (GPU blit): %zu bytes", nbytes); + return 0; + } + + ET_LOG(Error, "Metal copy: Device pointers not found in buffer map"); + return -1; + } + + // Case 2: Host-to-device or device-to-host - use memcpy with shared memory + // Since Metal uses shared storage mode, CPU and GPU access the same memory + std::memcpy(dst, src, nbytes); + + // Synchronize only if we need to ensure GPU operations complete before CPU reads + // (device-to-host case where GPU may have written data) + if (src_is_device && !dst_is_device) { + // Ensure any pending GPU writes to source complete before CPU reads + ETMetalStream* stream = getCurrentMetalStream(); + stream->synchronize(SyncType::COMMIT_AND_WAIT); + } + + ET_LOG(Debug, "Metal memory copy (memcpy): %zu bytes, src_device=%d, dst_device=%d", + nbytes, src_is_device, dst_is_device); + } + + return 0; +} + +id get_metal_device() { + // Use stream-based device access + ETMetalStream* stream = getCurrentMetalStream(); + return stream->device(); +} + +id get_metal_command_queue() { + // Use stream-based queue access + ETMetalStream* stream = getCurrentMetalStream(); + return stream->commandQueue(); +} + +} // extern "C" + +// ======================= +// ETMetalShaderLibrary Implementation +// ======================= + +ETMetalShaderLibrary::ETMetalShaderLibrary(const std::string& source) : shaderSource_(source) { + compileLibrary(); +} + +ETMetalShaderLibrary::~ETMetalShaderLibrary() { + @autoreleasepool { + if (library_) { + [library_ release]; + library_ = nil; + } + + for (auto& pair : pipelineStates_) { + [pair.second.first release]; + [pair.second.second release]; + } + pipelineStates_.clear(); + } +} + +void ETMetalShaderLibrary::compileLibrary() { + @autoreleasepool { + id device = get_metal_device(); + if (!device) { + ET_LOG(Error, "ETMetalShaderLibrary: Failed to get Metal device"); + return; + } + + NSString* sourceString = [NSString stringWithUTF8String:shaderSource_.c_str()]; + NSError* error = nil; + + library_ = [device newLibraryWithSource:sourceString options:nil error:&error]; + if (!library_ || error) { + ET_LOG(Error, "ETMetalShaderLibrary: Failed to compile shader library: %s", + error ? 
[[error localizedDescription] UTF8String] : "unknown error"); + return; + } + + [library_ retain]; + ET_LOG(Debug, "ETMetalShaderLibrary: Successfully compiled shader library"); + } +} + +std::pair, id> ETMetalShaderLibrary::getLibraryPipelineState(const std::string& functionName) { + auto it = pipelineStates_.find(functionName); + if (it != pipelineStates_.end()) { + return it->second; + } + + @autoreleasepool { + if (!library_) { + ET_LOG(Error, "ETMetalShaderLibrary: Library not compiled"); + return {nil, nil}; + } + + id device = get_metal_device(); + if (!device) { + ET_LOG(Error, "ETMetalShaderLibrary: Failed to get Metal device"); + return {nil, nil}; + } + + NSString* funcName = [NSString stringWithUTF8String:functionName.c_str()]; + id function = [library_ newFunctionWithName:funcName]; + if (!function) { + ET_LOG(Error, "ETMetalShaderLibrary: Failed to get function '%s'", functionName.c_str()); + return {nil, nil}; + } + + NSError* error = nil; + id pipelineState = [device newComputePipelineStateWithFunction:function error:&error]; + if (!pipelineState || error) { + ET_LOG(Error, "ETMetalShaderLibrary: Failed to create pipeline state for '%s': %s", + functionName.c_str(), error ? [[error localizedDescription] UTF8String] : "unknown error"); + [function release]; + return {nil, nil}; + } + + [pipelineState retain]; + [function retain]; + pipelineStates_[functionName] = {pipelineState, function}; + + ET_LOG(Debug, "ETMetalShaderLibrary: Created pipeline state for function '%s'", functionName.c_str()); + return {pipelineState, function}; + } +} + +std::shared_ptr ETMetalShaderLibrary::getKernelFunction(const std::string& name) { + auto pipelineStatePair = getLibraryPipelineState(name); + if (!pipelineStatePair.first || !pipelineStatePair.second) { + ET_LOG(Error, "ETMetalShaderLibrary::getKernelFunction: Failed to get pipeline state for '%s'", name.c_str()); + return nullptr; + } + + return std::make_shared(pipelineStatePair.first, pipelineStatePair.second); +} + +// ======================= +// ETMetalKernelFunction Implementation +// ======================= + +ETMetalKernelFunction::ETMetalKernelFunction(id cps, id func) + : cps_(cps), func_(func), encoder_(nil) { + if (cps_) [cps_ retain]; + if (func_) [func_ retain]; +} + +ETMetalKernelFunction::~ETMetalKernelFunction() { + @autoreleasepool { + // Don't release encoder_ here - the stream owns it + // Only clean up our own references + if (cps_) { + [cps_ release]; + cps_ = nil; + } + if (func_) { + [func_ release]; + func_ = nil; + } + + encoder_ = nil; // Clear reference without releasing + } +} + +void ETMetalKernelFunction::startEncoding() { + @autoreleasepool { + // Don't retain/release the encoder - just get reference from stream + ETMetalStream* stream = getCurrentMetalStream(); + encoder_ = stream->commandEncoder(); // Use stream's managed encoder + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction: Failed to get encoder from stream"); + return; + } + + // Don't retain - stream owns the encoder + [encoder_ setComputePipelineState:cps_]; + + ET_LOG(Debug, "ETMetalKernelFunction: Started encoding with stream-managed encoder"); + } +} + +void ETMetalKernelFunction::setArg(unsigned idx, const executorch::runtime::etensor::Tensor& tensor) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder"); + return; + } + + void* data_ptr = tensor.mutable_data_ptr(); + size_t totalSize = tensor.numel() * tensor.element_size(); + + auto it = ptr_to_mtl_buffer.find(data_ptr); + if (it != 
ptr_to_mtl_buffer.end()) { + // Use existing Metal buffer + id mtlBuffer = it->second; + [encoder_ setBuffer:mtlBuffer offset:0 atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set Metal buffer at index %u (size: %zu)", idx, totalSize); + } else { + // Handle CPU tensor data + if (totalSize <= 4096) { + // Use setBytes for small data (more efficient) + [encoder_ setBytes:data_ptr length:totalSize atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set CPU tensor via setBytes at index %u (size: %zu)", idx, totalSize); + } else { + // Create temporary buffer for large data (should be rare) + @autoreleasepool { + id device = get_metal_device(); + if (device) { + id tempBuffer = [device newBufferWithBytes:data_ptr + length:totalSize + options:MTLResourceStorageModeShared]; + if (tempBuffer) { + [encoder_ setBuffer:tempBuffer offset:0 atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set large CPU tensor via temporary buffer at index %u (size: %zu)", idx, totalSize); + } else { + ET_LOG(Error, "ETMetalKernelFunction::setArg: Failed to create temporary buffer for index %u", idx); + } + } else { + ET_LOG(Error, "ETMetalKernelFunction::setArg: No Metal device available for index %u", idx); + } + } + } + } +} + +void ETMetalKernelFunction::setArg(unsigned idx, int64_t val) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder"); + return; + } + + [encoder_ setBytes:&val length:sizeof(int64_t) atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set int64_t value %lld at index %u", val, idx); +} + +void ETMetalKernelFunction::setArg(unsigned idx, uint32_t val) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder"); + return; + } + + [encoder_ setBytes:&val length:sizeof(uint32_t) atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set uint32_t value %u at index %u", val, idx); +} + +void ETMetalKernelFunction::setArg(unsigned idx, float val) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder"); + return; + } + + [encoder_ setBytes:&val length:sizeof(float) atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set float value %f at index %u", val, idx); +} + +void ETMetalKernelFunction::setArg(unsigned idx, bool val) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder"); + return; + } + + [encoder_ setBytes:&val length:sizeof(bool) atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set bool value %s at index %u", val ? 
"true" : "false", idx); +} + +void ETMetalKernelFunction::setArg(unsigned idx, const void* data, size_t size) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder"); + return; + } + + [encoder_ setBytes:data length:size atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set bytes at index %u (size: %zu)", idx, size); +} + +void ETMetalKernelFunction::setArgUint3(unsigned idx, uint32_t x, uint32_t y, uint32_t z) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::setArgUint3: No active encoder"); + return; + } + + // Use SIMD library's uint3 type which matches Metal shader's uint3 layout + simd_uint3 val = {x, y, z}; + [encoder_ setBytes:&val length:sizeof(simd_uint3) atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArgUint3: Set uint3{%u, %u, %u} at index %u", x, y, z, idx); +} + +void ETMetalKernelFunction::dispatchSingle(uint64_t length) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchSingle: No active encoder"); + return; + } + + const auto maxThreadsPerGroup = static_cast([cps_ maxTotalThreadsPerThreadgroup]); + uint64_t actualGroupSize = std::min(maxThreadsPerGroup, length); + + auto size = MTLSizeMake(length, 1, 1); + auto threadGroupSize = MTLSizeMake(actualGroupSize, 1, 1); + + [encoder_ dispatchThreads:size threadsPerThreadgroup:threadGroupSize]; + ET_LOG(Debug, "ETMetalKernelFunction::dispatchSingle: Dispatched with length %llu, group size %llu", length, actualGroupSize); + +} + +void ETMetalKernelFunction::dispatchSingleWithGroupSize(uint64_t length, uint64_t group_size) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchSingleWithGroupSize: No active encoder"); + return; + } + + const auto maxThreadsPerGroup = static_cast([cps_ maxTotalThreadsPerThreadgroup]); + uint64_t actualGroupSize = group_size > 0 ? std::min(group_size, maxThreadsPerGroup) : std::min(maxThreadsPerGroup, length); + + auto size = MTLSizeMake(length, 1, 1); + auto threadGroupSize = MTLSizeMake(actualGroupSize, 1, 1); + + [encoder_ dispatchThreads:size threadsPerThreadgroup:threadGroupSize]; + ET_LOG(Debug, "ETMetalKernelFunction::dispatchSingleWithGroupSize: Dispatched with length %llu, group size %llu", length, actualGroupSize); + +} + +void ETMetalKernelFunction::dispatchArray(const uint64_t* length, size_t length_size) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchArray: No active encoder"); + return; + } + + if (!length || length_size == 0) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchArray: Invalid length array"); + return; + } + + const auto maxThreadsPerGroup = static_cast([cps_ maxTotalThreadsPerThreadgroup]); + + MTLSize size, threadGroupSize; + + if (length_size == 1) { + size = MTLSizeMake(length[0], 1, 1); + uint64_t actualGroupSize = std::min(maxThreadsPerGroup, length[0]); + threadGroupSize = MTLSizeMake(actualGroupSize, 1, 1); + } else if (length_size == 2) { + size = MTLSizeMake(length[0], length[1], 1); + uint64_t groupX = std::min(static_cast(32), length[0]); + uint64_t groupY = maxThreadsPerGroup / groupX; + threadGroupSize = MTLSizeMake(groupX, groupY, 1); + } else { + size = MTLSizeMake(length[0], length[1], length_size > 2 ? 
length[2] : 1); + uint64_t groupX = std::min(static_cast(8), length[0]); + uint64_t groupY = std::min(static_cast(8), length[1]); + uint64_t groupZ = maxThreadsPerGroup / (groupX * groupY); + threadGroupSize = MTLSizeMake(groupX, groupY, groupZ); + } + + [encoder_ dispatchThreads:size threadsPerThreadgroup:threadGroupSize]; + ET_LOG(Debug, "ETMetalKernelFunction::dispatchArray: Dispatched %zuD with size [%lu, %lu, %lu], group [%lu, %lu, %lu]", + length_size, size.width, size.height, size.depth, + threadGroupSize.width, threadGroupSize.height, threadGroupSize.depth); + +} + +void ETMetalKernelFunction::dispatchArrayWithGroupSize(const uint64_t* length, size_t length_size, + const uint64_t* group_size, size_t group_size_size) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchArrayWithGroupSize: No active encoder"); + return; + } + + if (!length || length_size == 0) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchArrayWithGroupSize: Invalid length array"); + return; + } + + const auto maxThreadsPerGroup = static_cast([cps_ maxTotalThreadsPerThreadgroup]); + + MTLSize size, threadGroupSize; + + if (length_size == 1) { + size = MTLSizeMake(length[0], 1, 1); + uint64_t actualGroupSize = maxThreadsPerGroup; + if (group_size && group_size_size > 0) { + actualGroupSize = std::min(maxThreadsPerGroup, group_size[0]); + } + threadGroupSize = MTLSizeMake(actualGroupSize, 1, 1); + } else if (length_size == 2) { + size = MTLSizeMake(length[0], length[1], 1); + uint64_t groupX = std::min(static_cast(32), length[0]); + uint64_t groupY = maxThreadsPerGroup / groupX; + if (group_size && group_size_size >= 2) { + groupX = std::min(static_cast(group_size[0]), length[0]); + groupY = std::min(static_cast(group_size[1]), length[1]); + } + threadGroupSize = MTLSizeMake(groupX, groupY, 1); + } else { + size = MTLSizeMake(length[0], length[1], length_size > 2 ? length[2] : 1); + uint64_t groupX = std::min(static_cast(8), length[0]); + uint64_t groupY = std::min(static_cast(8), length[1]); + uint64_t groupZ = maxThreadsPerGroup / (groupX * groupY); + if (group_size && group_size_size >= 3) { + groupX = std::min(static_cast(group_size[0]), length[0]); + groupY = std::min(static_cast(group_size[1]), length[1]); + groupZ = std::min(static_cast(group_size[2]), length_size > 2 ? 
length[2] : 1); + } + threadGroupSize = MTLSizeMake(groupX, groupY, groupZ); + } + + [encoder_ dispatchThreads:size threadsPerThreadgroup:threadGroupSize]; + ET_LOG(Debug, "ETMetalKernelFunction::dispatchArrayWithGroupSize: Dispatched %zuD with size [%lu, %lu, %lu], group [%lu, %lu, %lu]", + length_size, size.width, size.height, size.depth, + threadGroupSize.width, threadGroupSize.height, threadGroupSize.depth); + +} + +void ETMetalKernelFunction::dispatchThreadgroups(uint64_t gridX, uint64_t gridY, uint64_t gridZ, + uint64_t threadsX, uint64_t threadsY, uint64_t threadsZ) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchThreadgroups: No active encoder"); + return; + } + + if (!cps_) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchThreadgroups: No compute pipeline state"); + return; + } + + // Calculate total threads per threadgroup + uint64_t totalThreads = threadsX * threadsY * threadsZ; + + const auto maxThreadsPerGroup = static_cast([cps_ maxTotalThreadsPerThreadgroup]); + + // Validate total thread count + if (totalThreads > maxThreadsPerGroup) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchThreadgroups: Requested %llu total threads per threadgroup exceeds device maximum of %llu", + (unsigned long long)totalThreads, (unsigned long long)maxThreadsPerGroup); + return; + } + + MTLSize threadgroupsPerGrid = MTLSizeMake(gridX, gridY, gridZ); + MTLSize threadsPerThreadgroup = MTLSizeMake(threadsX, threadsY, threadsZ); + + [encoder_ dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadsPerThreadgroup]; + + ET_LOG(Debug, "ETMetalKernelFunction::dispatchThreadgroups: Dispatched grid [%llu, %llu, %llu] with threadgroup [%llu, %llu, %llu]", + (unsigned long long)gridX, (unsigned long long)gridY, (unsigned long long)gridZ, + (unsigned long long)threadsX, (unsigned long long)threadsY, (unsigned long long)threadsZ); +} + +void ETMetalKernelFunction::runCommandBlock(std::function f) { + // Use dispatch_sync with the stream's serial queue for thread safety and synchronization + // This matches PyTorch's approach: dispatch_sync_with_rethrow(getCurrentMPSStream()->queue(), ...) 
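+  // dispatch_sync_with_rethrow (defined at the top of this file) runs the block
+  // synchronously on the stream's serial queue and re-throws any C++ exception
+  // raised inside it on the calling thread.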
+ ETMetalStream* stream = getCurrentMetalStream(); + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + f(); + } + }); + + ET_LOG(Debug, "ETMetalKernelFunction::runCommandBlock: Executed command block with dispatch_sync"); +} + +// ======================= +// ETMetalStream Implementation +// ======================= + +ETMetalStream::ETMetalStream() + : device_(nil), commandQueue_(nil), commandBuffer_(nil), prevCommandBuffer_(nil), + commandEncoder_(nil), serialQueue_(nullptr), enableCommitAndContinue_(true) { + @autoreleasepool { + // Create device and command queue + device_ = MTLCreateSystemDefaultDevice(); + if (!device_) { + ET_LOG(Error, "ETMetalStream: Failed to create Metal device"); + return; + } + [device_ retain]; + + commandQueue_ = [device_ newCommandQueue]; + if (!commandQueue_) { + ET_LOG(Error, "ETMetalStream: Failed to create Metal command queue"); + return; + } + [commandQueue_ retain]; + + // Create serial queue for thread safety + serialQueue_ = dispatch_queue_create("metal gpu stream", nullptr); + + ET_LOG(Debug, "ETMetalStream: Created stream with device %p, queue %p", device_, commandQueue_); + } +} + +ETMetalStream::~ETMetalStream() { + @autoreleasepool { + // Synchronize before cleanup + synchronize(SyncType::COMMIT_AND_WAIT); + + // Clean up command encoder + if (commandEncoder_) { + [commandEncoder_ release]; + commandEncoder_ = nil; + } + + // Clean up command buffers + if (commandBuffer_) { + [commandBuffer_ release]; + commandBuffer_ = nil; + } + if (prevCommandBuffer_) { + [prevCommandBuffer_ release]; + prevCommandBuffer_ = nil; + } + + // Clean up command queue and device + if (commandQueue_) { + [commandQueue_ release]; + commandQueue_ = nil; + } + if (device_) { + [device_ release]; + device_ = nil; + } + + // Clean up serial queue + if (serialQueue_) { + dispatch_release(serialQueue_); + serialQueue_ = nullptr; + } + + ET_LOG(Debug, "ETMetalStream: Destroyed stream"); + } +} + +ETMetalStream* ETMetalStream::getDefaultStream() { + if (!defaultStream_) { + defaultStream_ = new ETMetalStream(); + } + return defaultStream_; +} + +// Lazy command buffer creation (use MPSCommandBuffer like PyTorch) +MPSCommandBuffer* ETMetalStream::commandBuffer() { + if (!commandBuffer_) { + if (!commandQueue_) { + ET_LOG(Error, "ETMetalStream::commandBuffer: No command queue available"); + return nil; + } + + commandBuffer_ = [MPSCommandBuffer commandBufferFromCommandQueue:commandQueue_]; + if (!commandBuffer_) { + ET_LOG(Error, "ETMetalStream::commandBuffer: Failed to create command buffer"); + return nil; + } + [commandBuffer_ retain]; + + ET_LOG(Debug, "ETMetalStream::commandBuffer: Created lazy command buffer %p", commandBuffer_); + } + + return commandBuffer_; +} + +// Lazy command encoder creation +id ETMetalStream::commandEncoder() { + if (!commandEncoder_) { + MPSCommandBuffer* cmdBuffer = commandBuffer(); + if (!cmdBuffer) { + ET_LOG(Error, "ETMetalStream::commandEncoder: Failed to get command buffer"); + return nil; + } + + commandEncoder_ = [cmdBuffer computeCommandEncoder]; + if (!commandEncoder_) { + ET_LOG(Error, "ETMetalStream::commandEncoder: Failed to create command encoder"); + return nil; + } + [commandEncoder_ retain]; + + ET_LOG(Debug, "ETMetalStream::commandEncoder: Created lazy command encoder %p", commandEncoder_); + } + + return commandEncoder_; +} + +// Synchronization with SyncType - matches PyTorch's approach (no dispatch_sync here) +void ETMetalStream::synchronize(SyncType syncType) { + endKernelCoalescing(); + + switch 
(syncType) { + case SyncType::NONE: + // Do nothing - no commit + break; + case SyncType::COMMIT: + commit(); + break; + case SyncType::COMMIT_AND_WAIT: + commitAndWait(); + break; + case SyncType::COMMIT_AND_CONTINUE: + if (enableCommitAndContinue_) { + commitAndContinue(); + } else { + ET_LOG(Error, "ETMetalStream::synchronize: CommitAndContinue requested but disabled"); + commit(); + } + break; + case SyncType::COMMIT_ADAPTIVE: + // Simple adaptive policy - could be enhanced with memory pressure detection + // TODO: Could add memory pressure detection like PyTorch does + commit(); + break; + } + + ET_LOG(Debug, "ETMetalStream::synchronize: Completed with SyncType %d", static_cast(syncType)); +} + +// Encoder coalescing management +void ETMetalStream::endKernelCoalescing() { + if (commandEncoder_) { + [commandEncoder_ endEncoding]; + [commandEncoder_ release]; + commandEncoder_ = nil; + ET_LOG(Debug, "ETMetalStream::endKernelCoalescing: Ended encoder coalescing"); + } +} + +// Commit methods +void ETMetalStream::commit() { + if (!commandBuffer_) { + ET_LOG(Error, "ETMetalStream::commit: No command buffer to commit"); + return; + } + + [commandBuffer_ commit]; + ET_LOG(Debug, "ETMetalStream::commit: Committed buffer %p", commandBuffer_); + + [commandBuffer_ release]; + commandBuffer_ = nil; +} + +void ETMetalStream::commitAndWait() { + // Handle previous command buffer first + if (prevCommandBuffer_) { + [prevCommandBuffer_ waitUntilCompleted]; + [prevCommandBuffer_ release]; + prevCommandBuffer_ = nil; + } + + // Handle current command buffer + if (commandBuffer_) { + [commandBuffer_ commit]; + [commandBuffer_ waitUntilCompleted]; + [commandBuffer_ release]; + commandBuffer_ = nil; + } + + ET_LOG(Debug, "ETMetalStream::commitAndWait: Committed and waited for completion"); +} + +void ETMetalStream::commitAndContinue() { + if (!commandBuffer_) { + ET_LOG(Error, "ETMetalStream::commitAndContinue: No command buffer to commit"); + return; + } + + // Commit buffer and allow immediate reuse for better performance + [commandBuffer_ commit]; + ET_LOG(Debug, "ETMetalStream::commitAndContinue: Committed buffer %p with continue", commandBuffer_); + + // The buffer handles synchronization internally for commit-and-continue +} + +void ETMetalStream::flush() { + if (commandBuffer_) { + [commandBuffer_ commit]; + + if (!enableCommitAndContinue_) { + // Keep the command buffer for later waiting if commit-and-continue is disabled + prevCommandBuffer_ = commandBuffer_; + } else { + [commandBuffer_ release]; + } + commandBuffer_ = nil; + + ET_LOG(Debug, "ETMetalStream::flush: Flushed command buffer"); + } +} + +// Memory operations +void ETMetalStream::fill(id buffer, uint8_t value, size_t length, size_t offset, SyncType syncType) { + if (length == 0) { + return; + } + + dispatch_sync(serialQueue_, ^{ + @autoreleasepool { + endKernelCoalescing(); + id blitEncoder = [commandBuffer() blitCommandEncoder]; + + [blitEncoder fillBuffer:buffer range:NSMakeRange(offset, length) value:value]; + [blitEncoder endEncoding]; + synchronize(syncType); + + ET_LOG(Debug, "ETMetalStream::fill: Filled buffer with value %u, length %zu, offset %zu", value, length, offset); + } + }); +} + +void ETMetalStream::copy(id srcBuffer, id dstBuffer, size_t length, + size_t srcOffset, size_t dstOffset, SyncType syncType) { + + if (length == 0) { + return; + } + + // Check that offsets are within buffer bounds before copying + if (!srcBuffer || !dstBuffer) { + ET_LOG(Error, "ETMetalStream::copy: Source or destination buffer is nil"); + 
return; + } + NSUInteger srcBufferLength = [srcBuffer length]; + NSUInteger dstBufferLength = [dstBuffer length]; + if (srcOffset + length > srcBufferLength) { + ET_LOG(Error, "ETMetalStream::copy: Source offset (%zu) + length (%zu) exceeds source buffer size (%zu)", srcOffset, length, srcBufferLength); + return; + } + if (dstOffset + length > dstBufferLength) { + ET_LOG(Error, "ETMetalStream::copy: Destination offset (%zu) + length (%zu) exceeds destination buffer size (%zu)", dstOffset, length, dstBufferLength); + return; + } + + dispatch_sync(serialQueue_, ^{ + @autoreleasepool { + endKernelCoalescing(); + id blitEncoder = [commandBuffer() blitCommandEncoder]; + + // Handle large copies in chunks + constexpr size_t max_copy_size = 0x80000000; // 2GB + size_t bytes_copied = 0; + size_t bytes_remaining = length; + + while (bytes_remaining > 0) { + NSUInteger bytes_to_copy = std::min(max_copy_size, bytes_remaining); + [blitEncoder copyFromBuffer:srcBuffer + sourceOffset:(NSUInteger)srcOffset + bytes_copied + toBuffer:dstBuffer + destinationOffset:(NSUInteger)dstOffset + bytes_copied + size:bytes_to_copy]; + bytes_copied += bytes_to_copy; + bytes_remaining -= bytes_to_copy; + } + + [blitEncoder endEncoding]; + synchronize(syncType); + + ET_LOG(Debug, "ETMetalStream::copy: Copied %zu bytes from offset %zu to offset %zu", length, srcOffset, dstOffset); + } + }); +} + + +void ETMetalStream::synchronize() { + synchronize(SyncType::COMMIT_AND_WAIT); +} + +bool ETMetalStream::isEmpty() const { + return !commandBuffer_ && !commandEncoder_; +} + +void ETMetalStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType) { + // Use dispatch_sync_with_rethrow exactly like PyTorch does for MPSGraph execution + dispatch_sync_with_rethrow(serialQueue_, ^() { + @autoreleasepool { + endKernelCoalescing(); + + [mpsGraph encodeToCommandBuffer:commandBuffer() + feeds:feeds + targetOperations:nil + resultsDictionary:results + executionDescriptor:nil]; + } + }); +} + +// ======================= +// Global Storage Management Functions +// ======================= + +void storeFunctionHandle(ETMetalKernelFunction* raw_function, std::shared_ptr function_shared_ptr) { + function_storage[raw_function] = function_shared_ptr; +} + +void storeLibraryHandle(ETMetalShaderLibrary* raw_library, std::unique_ptr library) { + library_storage[raw_library] = std::move(library); +} + +bool removeFunctionHandle(ETMetalKernelFunction* raw_function) { + auto it = function_storage.find(raw_function); + if (it != function_storage.end()) { + function_storage.erase(it); + return true; + } + return false; +} + +bool removeLibraryHandle(ETMetalShaderLibrary* raw_library) { + auto it = library_storage.find(raw_library); + if (it != library_storage.end()) { + library_storage.erase(it); + return true; + } + return false; +} + +// ======================= +// Global Stream Access Functions +// ======================= + +ETMetalStream* getCurrentMetalStream() { + if (!currentStream_) { + currentStream_ = ETMetalStream::getDefaultStream(); + } + return currentStream_; +} + +void setCurrentMetalStream(ETMetalStream* stream) { + currentStream_ = stream; +} + +// ======================= +// Metal Stream Synchronization Functions +// ======================= + +void synchronize_metal_stream() { + @autoreleasepool { + // Use the ETMetalStream for proper synchronization + ETMetalStream* stream = getCurrentMetalStream(); + stream->synchronize(SyncType::COMMIT_AND_WAIT); + + ET_LOG(Debug, 
"synchronize_metal_stream: Stream synchronized with COMMIT_AND_WAIT"); + } +} + +void synchronize_metal_stream_with_type(int sync_type) { + @autoreleasepool { + ETMetalStream* stream = getCurrentMetalStream(); + SyncType syncTypeEnum = static_cast(sync_type); + stream->synchronize(syncTypeEnum); + + ET_LOG(Debug, "synchronize_metal_stream_with_type: Stream synchronized with SyncType %d", sync_type); + } +} + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.h b/backends/apple/metal/runtime/shims/et_metal_ops.h new file mode 100644 index 00000000000..78bdb419ea4 --- /dev/null +++ b/backends/apple/metal/runtime/shims/et_metal_ops.h @@ -0,0 +1,73 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch { +namespace backends { +namespace metal { + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * ExecutorTorch implementation of aoti_torch_mps_mm_out. + * Performs simple matrix multiplication: out = self @ mat2 + */ +AOTITorchError aoti_torch_mps_mm_out( + AOTITensorHandle out, + AOTITensorHandle self, + AOTITensorHandle mat2); + +/** + * ExecutorTorch implementation of aoti_torch_mps_convolution. + * Performs 2D convolution operation - matches PyTorch AOTI signature + */ +AOTITorchError aoti_torch_mps_convolution( + AOTITensorHandle input, + AOTITensorHandle weight, + AOTITensorHandle* bias, + const int64_t* stride, + int64_t stride_len_, + const int64_t* padding, + int64_t padding_len_, + const int64_t* dilation, + int64_t dilation_len_, + int32_t transposed, + const int64_t* output_padding, + int64_t output_padding_len_, + int64_t groups, + AOTITensorHandle* ret0); + +/** + * ExecutorTorch implementation of + * aoti_torch_mps__scaled_dot_product_attention_math_for_mps. Performs scaled + * dot product attention calculation - matches PyTorch AOTI signature + */ +AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( + AOTITensorHandle query, + AOTITensorHandle key, + AOTITensorHandle value, + AOTITensorHandle* attn_mask, + double dropout_p, + int32_t is_causal, + AOTITensorHandle* dropout_mask, + double* scale, + AOTITensorHandle* ret0, + AOTITensorHandle* ret1); + +#ifdef __cplusplus +} // extern "C" +#endif + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.mm b/backends/apple/metal/runtime/shims/et_metal_ops.mm new file mode 100644 index 00000000000..da54dafb334 --- /dev/null +++ b/backends/apple/metal/runtime/shims/et_metal_ops.mm @@ -0,0 +1,1557 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#import +#import +#import +#import +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +using executorch::runtime::etensor::Tensor; + +// Forward declaration of dispatch_sync_with_rethrow from et_metal.mm +void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()); + +// Declare the global mapping from et_metal.mm +extern std::unordered_map> ptr_to_mtl_buffer; + +// ======================= +// MPSGraph Caching Infrastructure +// ======================= + +namespace { + +// Cache key structure for different operations +struct GraphCacheKey { + std::string op_name; + std::vector shape_params; + int32_t dtype; + bool transpose_flag; + + bool operator==(const GraphCacheKey& other) const { + return op_name == other.op_name && + shape_params == other.shape_params && + dtype == other.dtype && + transpose_flag == other.transpose_flag; + } +}; + +// Hash function for GraphCacheKey +struct GraphCacheKeyHash { + std::size_t operator()(const GraphCacheKey& key) const { + std::size_t hash = std::hash{}(key.op_name); + for (auto val : key.shape_params) { + hash ^= std::hash{}(val) + 0x9e3779b9 + (hash << 6) + (hash >> 2); + } + hash ^= std::hash{}(key.dtype) + 0x9e3779b9 + (hash << 6) + (hash >> 2); + hash ^= std::hash{}(key.transpose_flag) + 0x9e3779b9 + (hash << 6) + (hash >> 2); + return hash; + } +}; + +// Struct to store both the compiled graph and its tensors for reuse +struct CachedGraph { + MPSGraph* graph; + MPSGraphTensor* input1; + MPSGraphTensor* input2; + MPSGraphTensor* input3; // Optional (e.g., bias, mask) + MPSGraphTensor* output; +}; + +// Global cache for compiled MPSGraphs +// These graphs are never released - they're reused across calls +static std::unordered_map graph_cache; + +// Statistics for monitoring cache effectiveness +struct CacheStats { + size_t hits = 0; + size_t misses = 0; + + void logStats() { + if ((hits + misses) % 100 == 0 && (hits + misses) > 0) { + double hit_rate = 100.0 * hits / (hits + misses); + ET_LOG(Debug, "MPSGraph cache stats: %zu hits, %zu misses (%.1f%% hit rate)", + hits, misses, hit_rate); + } + } +}; + +static CacheStats cache_stats; + +// Helper function to get Metal buffer from the global mapping +static id get_mtl_buffer(Tensor* tensor, const char* op_name, const char* tensor_name) { + void* data_ptr = tensor->mutable_data_ptr(); + auto it = ptr_to_mtl_buffer.find(data_ptr); + if (it == ptr_to_mtl_buffer.end()) { + ET_LOG(Error, "%s: %s tensor not found in Metal buffer mapping", op_name, tensor_name); + throw std::runtime_error(std::string(tensor_name) + " tensor not found in Metal buffer mapping"); + } + return it->second; +} + +// Helper function to allocate a Metal buffer and register it in the global mapping. 
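+// The buffer is obtained through aoti_torch_mps_malloc, which is expected to register
+// it in ptr_to_mtl_buffer; the caller can then wrap the returned pointer into an output
+// tensor (via aoti_torch_create_tensor_from_blob_v2) or release it with aoti_torch_mps_free.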
+static id allocate_mtl_buffer(void** data_ptr, size_t size_bytes) { + AOTITorchError malloc_err = aoti_torch_mps_malloc(data_ptr, size_bytes); + if (malloc_err != Error::Ok) { + ET_LOG(Error, "allocate_and_register_mtl_buffer: Failed to allocate Metal buffer via aoti_torch_mps_malloc"); + throw std::runtime_error("Failed to allocate output Metal buffer"); + } + + auto it = ptr_to_mtl_buffer.find(*data_ptr); + if (it == ptr_to_mtl_buffer.end()) { + ET_LOG(Error, "allocate_and_register_mtl_buffer: aoti_torch_mps_malloc did not register buffer in map"); + throw std::runtime_error("Failed to look up allocated Metal buffer"); + } + return it->second; +} + +// Helper function to get the Metal shader source for SDPA +static std::string get_sdpa_metal_source() { + return R"( +// Ported from PyTorch's Attention.metal +// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/mps/kernels/Attention.metal +// Largely influenced by +// https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +// Modified to support floating point masks and transposed middle dimensions (dims 1 & 2) + +#include +#include +#include + +using namespace metal; + +// PyTorch's sdpa_vector kernel (one-pass variant) +template +[[kernel]] void sdpa_vector( + const device T* queries [[buffer(0)]], + const device T* keys [[buffer(1)]], + const device T* values [[buffer(2)]], + device T* out [[buffer(3)]], + constant uint& gqa_factor [[buffer(4)]], + constant uint& N [[buffer(5)]], + constant uint3& qkv_head_strides [[buffer(6)]], + constant uint3& qkv_seq_strides [[buffer(7)]], + constant float& scale [[buffer(8)]], + const device T* mask [[buffer(9)]], // Changed from bool* to T* for floating point masks + constant uint3& mask_strides [[buffer(10)]], + constant bool& has_mask [[buffer(11)]], + constant uint3& qkv_batch_strides [[buffer(12)]], // NEW: batch strides for Q, K, V + constant uint& num_q_heads [[buffer(13)]], // NEW: number of query heads + uint3 tid [[threadgroup_position_in_grid]], + uint3 tpg [[threadgroups_per_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr uint BN = 32; + constexpr uint BD = 32; + constexpr uint qk_per_thread = D / BD; + constexpr uint v_per_thread = V / BD; + const uint q_head_stride = qkv_head_strides.x; + const uint q_seq_stride = qkv_seq_strides.x; + const uint q_batch_stride = qkv_batch_strides.x; + const uint k_head_stride = qkv_head_strides.y; + const uint k_seq_stride = qkv_seq_strides.y; + const uint k_batch_stride = qkv_batch_strides.y; + const uint v_head_stride = qkv_head_strides.z; + const uint v_seq_stride = qkv_seq_strides.z; + const uint v_batch_stride = qkv_batch_strides.z; + const uint mask_head_stride = mask_strides.x; + const uint mask_kv_seq_stride = mask_strides.y; + const uint mask_q_seq_stride = mask_strides.z; + uint inner_k_stride = BN * int(k_seq_stride); + uint inner_v_stride = BN * int(v_seq_stride); + + typedef float U; + + thread U q[qk_per_thread]; + thread U k[qk_per_thread]; + thread U o[v_per_thread]; + + threadgroup U outputs[BN * BD]; + threadgroup U max_scores[BN]; + threadgroup U sum_exp_scores[BN]; + + // Adjust positions + const int head_idx = tid.x; // Flattened batch*heads index + const int q_seq_idx = tid.y; + + // Decompose flattened head_idx into batch and head indices + const int batch_idx = head_idx / num_q_heads; + const int head_in_batch = head_idx % num_q_heads; + const int kv_head_idx = head_in_batch / gqa_factor; + 
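+  // Worked example of the index decomposition above (illustrative values only):
+  // with num_q_heads = 8 and gqa_factor = 2 (8 query heads sharing 4 KV heads),
+  // tid.x = 13 gives batch_idx = 13 / 8 = 1, head_in_batch = 13 % 8 = 5,
+  // and kv_head_idx = 5 / 2 = 2.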
+ const int Q = tpg.y; + const int group_offset = head_idx * Q + q_seq_idx; + const int o_offset = group_offset; + + // Use decomposed indices with separate batch and head strides + queries += batch_idx * q_batch_stride + head_in_batch * q_head_stride + q_seq_idx * q_seq_stride + + simd_lid * qk_per_thread; + keys += batch_idx * k_batch_stride + kv_head_idx * k_head_stride + simd_gid * k_seq_stride + + simd_lid * qk_per_thread; + values += batch_idx * v_batch_stride + kv_head_idx * v_head_stride + simd_gid * v_seq_stride + + simd_lid * v_per_thread; + if (has_mask) { + mask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride + + q_seq_idx * mask_q_seq_stride; + } + + out += o_offset * V + simd_gid * v_per_thread; + + // Read the query and 0 the output accumulator + for (uint i = 0; i < qk_per_thread; i++) { + q[i] = scale * static_cast(queries[i]); + } + for (uint i = 0; i < v_per_thread; i++) { + o[i] = 0; + } + + U max_score = -INFINITY; + U sum_exp_score = 0; + + // For each key + for (uint i = simd_gid; i < N; i += BN) { + // Check mask: for floating point masks, values > -1e9 are considered valid (not masked) + // Masked positions typically have -inf or very negative values + const bool is_valid = !has_mask || (static_cast(mask[0]) > -1e9f); + + if (is_valid) { + // Read the key + for (uint j = 0; j < qk_per_thread; j++) { + k[j] = static_cast(keys[j]); + } + + // Compute the i-th score + U score = 0; + for (uint j = 0; j < qk_per_thread; j++) { + score += q[j] * k[j]; + } + score = simd_sum(score); + + // Add mask value to score if mask is present + if (has_mask) { + score += static_cast(mask[0]); + } + + // Update the accumulators + U new_max = max(max_score, score); + U factor = metal::fast::exp(max_score - new_max); + U exp_score = metal::fast::exp(score - new_max); + + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; + + // Update the output accumulator + for (uint j = 0; j < v_per_thread; j++) { + o[j] = o[j] * factor + exp_score * static_cast(values[j]); + } + } + + // Move the pointers to the next kv + keys += inner_k_stride; + values += inner_v_stride; + if (has_mask) { + mask += BN * mask_kv_seq_stride; + } + } + + // Each thread has a partial part of the output so we need to combine them. + + // First let's communicate the max and sum_exp + if (simd_lid == 0) { + max_scores[simd_gid] = max_score; + sum_exp_scores[simd_gid] = sum_exp_score; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + max_score = max_scores[simd_lid]; + U new_max = simd_max(max_score); + U factor = metal::fast::exp(max_score - new_max); + sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor); + + // Now we need to aggregate all the outputs + for (uint i = 0; i < v_per_thread; i++) { + outputs[simd_lid * BD + simd_gid] = o[i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + const U safe_sum = (sum_exp_score == 0 ? 
1e-6f : sum_exp_score); + o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / safe_sum; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // And write the output + if (simd_lid == 0) { + for (uint i = 0; i < v_per_thread; i++) { + out[i] = static_cast(o[i]); + } + } +} + +#define INSTANTIATE_SDPA_VECTOR(DTYPE, QK_DIM, VALUE_DIM) \ + template [[host_name("sdpa_vector_" #DTYPE "_" #QK_DIM \ + "_" #VALUE_DIM)]] kernel void \ + sdpa_vector( \ + const device DTYPE* queries [[buffer(0)]], \ + const device DTYPE* keys [[buffer(1)]], \ + const device DTYPE* values [[buffer(2)]], \ + device DTYPE* out [[buffer(3)]], \ + constant uint& gqa_factor [[buffer(4)]], \ + constant uint& N [[buffer(5)]], \ + constant uint3& qkv_head_strides [[buffer(6)]], \ + constant uint3& qkv_seq_strides [[buffer(7)]], \ + constant float& scale [[buffer(8)]], \ + const device DTYPE* mask [[buffer(9)]], \ + constant uint3& mask_strides [[buffer(10)]], \ + constant bool& has_mask [[buffer(11)]], \ + constant uint3& qkv_batch_strides [[buffer(12)]], \ + constant uint& num_q_heads [[buffer(13)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 tpg [[threadgroups_per_grid]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +#define INSTANTIATE_SDPA_VECTOR_HEADS(DTYPE) \ + INSTANTIATE_SDPA_VECTOR(DTYPE, 64, 64); \ + INSTANTIATE_SDPA_VECTOR(DTYPE, 96, 96); \ + INSTANTIATE_SDPA_VECTOR(DTYPE, 128, 128); + +INSTANTIATE_SDPA_VECTOR_HEADS(float); +INSTANTIATE_SDPA_VECTOR_HEADS(bfloat); +)"; +} + +// Global shader library cache for SDPA +static std::unique_ptr sdpa_shader_library = nullptr; + +static std::once_flag sdpa_shader_library_once_flag; + +static ETMetalShaderLibrary* get_sdpa_shader_library() { + std::call_once(sdpa_shader_library_once_flag, []() { + std::string source = get_sdpa_metal_source(); + sdpa_shader_library = std::make_unique(source); + }); + return sdpa_shader_library.get(); +} + +} // anonymous namespace + +extern "C" { + +AOTITorchError aoti_torch_mps_mm_out( + AOTITensorHandle out, + AOTITensorHandle self, + AOTITensorHandle mat2) { + ET_LOG(Debug, "aoti_torch_mps_mm_out: Starting with out=%p, self=%p, mat2=%p", + out, self, mat2); + + if (!out || !self || !mat2) { + ET_LOG(Error, "aoti_torch_mps_mm_out: null tensor handles"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + // Convert AOTITensorHandle to ExecutorTorch tensors + auto out_tensor = reinterpret_cast(out); + auto self_tensor = reinterpret_cast(self); + auto mat2_tensor = reinterpret_cast(mat2); + + ET_LOG(Debug, "aoti_torch_mps_mm_out: Converted tensor handles to ET tensors"); + + // Validate tensor dimensions + if (self_tensor->dim() != 2 || mat2_tensor->dim() != 2) { + std::string error_msg = "aoti_torch_mps_mm_out: tensors must be 2-D, got " + + std::to_string(self_tensor->dim()) + " and " + + std::to_string(mat2_tensor->dim()); + ET_LOG(Error, "%s", error_msg.c_str()); + throw std::runtime_error(error_msg); + } + + int64_t M = self_tensor->sizes()[0]; // rows of self + int64_t K = self_tensor->sizes()[1]; // cols of self / rows of mat2 + int64_t N = mat2_tensor->sizes()[1]; // cols of mat2 + + // Check matrix multiplication compatibility + if (self_tensor->sizes()[1] != mat2_tensor->sizes()[0]) { + std::string error_msg = "aoti_torch_mps_mm_out: incompatible matrix sizes for mm (" + + std::to_string(M) + "x" + std::to_string(K) + " and " + + std::to_string(mat2_tensor->sizes()[0]) + "x" + std::to_string(N) + ")"; + ET_LOG(Error, "%s", 
error_msg.c_str()); + throw std::runtime_error(error_msg); + } + + // Log tensor shapes for debugging + ET_LOG(Debug, "aoti_torch_mps_mm_out: self shape: [%d, %d], mat2 shape: [%d, %d], out shape: [%d, %d]", + (int)M, (int)K, (int)mat2_tensor->sizes()[0], (int)N, + out_tensor->dim() > 0 ? (int)out_tensor->sizes()[0] : 0, + out_tensor->dim() > 1 ? (int)out_tensor->sizes()[1] : 0); + + // Check if mat2 is transposed (non-contiguous due to transpose) + // A transposed matrix will have stride(-2) == 1 (column-major instead of row-major) + // For a 2D tensor with shape [K, N]: + // - Contiguous (row-major): strides = [N, 1] + // - Transposed (column-major): strides = [1, K] + bool mat2_is_transposed = false; + int64_t mat2_stride_0 = mat2_tensor->strides()[0]; // stride for dimension 0 + int64_t mat2_stride_1 = mat2_tensor->strides()[1]; // stride for dimension 1 + + // Detect transposed layout: stride(-2) == 1 indicates column-major layout + if (mat2_stride_0 == 1 && mat2_stride_1 != 1) { + mat2_is_transposed = true; + ET_LOG(Debug, "aoti_torch_mps_mm_out: mat2 is transposed (strides=[%lld, %lld])", + mat2_stride_0, mat2_stride_1); + } else { + ET_LOG(Debug, "aoti_torch_mps_mm_out: mat2 is contiguous (strides=[%lld, %lld])", + mat2_stride_0, mat2_stride_1); + } + + // Use the same dispatch pattern as other MPS operations for consistent synchronization + ETMetalStream* stream = getCurrentMetalStream(); + if (!stream) { + ET_LOG(Error, "aoti_torch_mps_mm_out: Failed to get current Metal stream"); + return Error::Internal; + } + + // Get Metal device + id device = get_metal_device(); + if (!device) { + ET_LOG(Error, "aoti_torch_mps_mm_out: Failed to get Metal device"); + throw std::runtime_error("Failed to get Metal device"); + } + + // Get Metal buffers for input and output tensors + id self_buffer = get_mtl_buffer(self_tensor, "aoti_torch_mps_mm_out", "self"); + id mat2_buffer = get_mtl_buffer(mat2_tensor, "aoti_torch_mps_mm_out", "mat2"); + id out_buffer = get_mtl_buffer(out_tensor, "aoti_torch_mps_mm_out", "out"); + + ET_LOG(Debug, "aoti_torch_mps_mm_out: Using existing Metal buffers - self=%p, mat2=%p, out=%p", + self_buffer, mat2_buffer, out_buffer); + + // End any existing kernel coalescing to ensure a clean state for MPS + stream->endKernelCoalescing(); + + // Determine data type and element size + int32_t dtype = static_cast(self_tensor->scalar_type()); + MPSDataType mps_dtype; + size_t element_size; + + ET_LOG(Debug, "aoti_torch_mps_mm_out: self_tensor scalar_type=%d, SupportedDTypes::FLOAT32=%d, SupportedDTypes::BFLOAT16=%d", + dtype, static_cast(SupportedDTypes::FLOAT32), static_cast(SupportedDTypes::BFLOAT16)); + + if (dtype == static_cast(SupportedDTypes::FLOAT32)) { + mps_dtype = MPSDataTypeFloat32; + element_size = sizeof(float); + } else if (dtype == static_cast(SupportedDTypes::BFLOAT16)) { + mps_dtype = MPSDataTypeBFloat16; + element_size = sizeof(uint16_t); // bfloat16 is 16 bits + } else { + ET_LOG(Error, "aoti_torch_mps_mm_out: Unsupported data type: %d", dtype); + throw std::runtime_error("Unsupported data type for matrix multiplication"); + } + + ET_LOG(Debug, "aoti_torch_mps_mm_out: dtype=%d, element_size=%zu", dtype, element_size); + ET_LOG(Debug, "aoti_torch_mps_mm_out: M=%lld, K=%lld, N=%lld", M, K, N); + + // Define tensor shapes for placeholders (needed for both cache hit and miss) + NSArray* selfShape = @[@(M), @(K)]; + + // For mat2, we need to handle both contiguous and transposed cases + // If mat2 is transposed, its physical layout in memory is [N, K] 
(column-major) + // but logically we need [K, N] for the matrix multiplication + NSArray* mat2PhysicalShape; + if (mat2_is_transposed) { + // Physical shape reflects the actual memory layout (transposed) + mat2PhysicalShape = @[@(N), @(K)]; + ET_LOG(Debug, "aoti_torch_mps_mm_out: mat2 physical shape (transposed): [%d,%d]", (int)N, (int)K); + } else { + // Physical shape is the logical shape (contiguous) + mat2PhysicalShape = @[@(K), @(N)]; + ET_LOG(Debug, "aoti_torch_mps_mm_out: mat2 physical shape (contiguous): [%d,%d]", (int)K, (int)N); + } + + // Create cache key for this matrix multiplication + GraphCacheKey cache_key; + cache_key.op_name = "mm"; + cache_key.shape_params = {M, K, N}; + cache_key.dtype = dtype; + cache_key.transpose_flag = mat2_is_transposed; + + // Check if we have a cached graph + MPSGraph* mpsGraph = nullptr; + MPSGraphTensor* mmOutput = nil; + MPSGraphTensor* selfPlaceholder = nil; + MPSGraphTensor* mat2Placeholder = nil; + + auto cache_it = graph_cache.find(cache_key); + if (cache_it != graph_cache.end()) { + // Cache hit - reuse compiled graph and tensor references + CachedGraph& cached = cache_it->second; + mpsGraph = cached.graph; + selfPlaceholder = cached.input1; + mat2Placeholder = cached.input2; + mmOutput = cached.output; + + cache_stats.hits++; + cache_stats.logStats(); + ET_LOG(Debug, "aoti_torch_mps_mm_out: Using cached MPSGraph (cache hit, %zu total hits)", cache_stats.hits); + + } else { + // Cache miss - create and compile new graph + mpsGraph = [MPSGraph new]; + cache_stats.misses++; + cache_stats.logStats(); + ET_LOG(Debug, "aoti_torch_mps_mm_out: Created new MPSGraph instance (cache miss, %zu total misses)", cache_stats.misses); + + ET_LOG(Debug, "aoti_torch_mps_mm_out: Creating placeholders with shapes self:[%d,%d] mat2:[%d,%d]", + (int)M, (int)K, + mat2_is_transposed ? (int)N : (int)K, + mat2_is_transposed ? 
(int)K : (int)N); + + // Create placeholders for input tensors + selfPlaceholder = [mpsGraph placeholderWithShape:selfShape + dataType:mps_dtype + name:@"self"]; + mat2Placeholder = [mpsGraph placeholderWithShape:mat2PhysicalShape + dataType:mps_dtype + name:@"mat2_physical"]; + + ET_LOG(Debug, "aoti_torch_mps_mm_out: Created input placeholders"); + + // If mat2 is transposed, apply transpose operation in the graph to get the logical shape + MPSGraphTensor* mat2Logical; + if (mat2_is_transposed) { + // Transpose from physical [N, K] to logical [K, N] + // MPSGraph transposeTensor swaps the last two dimensions for 2D tensors + mat2Logical = [mpsGraph transposeTensor:mat2Placeholder + dimension:-2 + withDimension:-1 + name:@"mat2_transposed"]; + ET_LOG(Debug, "aoti_torch_mps_mm_out: Applied transpose operation to mat2 in graph"); + } else { + // No transpose needed, use placeholder directly + mat2Logical = mat2Placeholder; + ET_LOG(Debug, "aoti_torch_mps_mm_out: Using mat2 placeholder directly (no transpose needed)"); + } + + // Perform matrix multiplication using MPSGraph with the logical mat2 tensor + mmOutput = [mpsGraph matrixMultiplicationWithPrimaryTensor:selfPlaceholder + secondaryTensor:mat2Logical + name:@"matrix_multiplication"]; + + ET_LOG(Debug, "aoti_torch_mps_mm_out: Successfully created matrix multiplication tensor"); + + // Cache the compiled graph and tensor references for reuse + CachedGraph cached_graph; + cached_graph.graph = mpsGraph; + cached_graph.input1 = selfPlaceholder; + cached_graph.input2 = mat2Placeholder; + cached_graph.input3 = nil; + cached_graph.output = mmOutput; + graph_cache[cache_key] = cached_graph; + + ET_LOG(Debug, "aoti_torch_mps_mm_out: Cached compiled MPSGraph for future reuse"); + } // End of cache miss/hit block + + // Define output shape + NSArray* outShape = @[@(M), @(N)]; + + // Create feeds dictionary for graph execution + NSMutableDictionary* feeds = [NSMutableDictionary dictionary]; + + // Create MPSGraphTensorData objects for input tensors + // Use physical shapes to match how data is actually laid out in memory + MPSGraphTensorData* selfData = [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer + shape:selfShape + dataType:mps_dtype]; + MPSGraphTensorData* mat2Data = [[MPSGraphTensorData alloc] initWithMTLBuffer:mat2_buffer + shape:mat2PhysicalShape + dataType:mps_dtype]; + + feeds[selfPlaceholder] = selfData; + feeds[mat2Placeholder] = mat2Data; + + ET_LOG(Debug, "aoti_torch_mps_mm_out: Created feeds dictionary with physical shapes"); + + // Create results dictionary + MPSGraphTensorData* outputData = [[MPSGraphTensorData alloc] initWithMTLBuffer:out_buffer + shape:outShape + dataType:mps_dtype]; + + NSDictionary* results = @{mmOutput: outputData}; + ET_LOG(Debug, "aoti_torch_mps_mm_out: Created results dictionary"); + + // Execute the MPSGraph + ET_LOG(Debug, "aoti_torch_mps_mm_out: Executing MPSGraph"); + + @try { + // Use stream helper to encode and synchronize correctly + stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT); + } @catch (NSException *exception) { + ET_LOG(Error, "aoti_torch_mps_mm_out: NSException caught during executeMPSGraph: %s - %s", + [[exception name] UTF8String], [[exception reason] UTF8String]); + throw std::runtime_error("MPSGraph execution failed with NSException"); + } + + ET_LOG(Debug, "aoti_torch_mps_mm_out: MPSGraph execution completed successfully"); + + [selfData release]; + [mat2Data release]; + [outputData release]; + + ET_LOG(Debug, "aoti_torch_mps_mm_out: Executed 
successfully"); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_mm_out exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_mm_out: unknown exception"); + return Error::Internal; + } + } +} + +AOTITorchError aoti_torch_mps_convolution( + AOTITensorHandle input, + AOTITensorHandle weight, + AOTITensorHandle* bias, + const int64_t* stride, + int64_t stride_len_, + const int64_t* padding, + int64_t padding_len_, + const int64_t* dilation, + int64_t dilation_len_, + int32_t transposed, + const int64_t* output_padding, + int64_t output_padding_len_, + int64_t groups, + AOTITensorHandle* ret0) { + ET_LOG(Debug, "aoti_torch_mps_convolution: Starting with input=%p, weight=%p, bias=%p, groups=%lld, transposed=%d", + input, weight, bias, groups, transposed); + + if (!input || !weight || !ret0) { + ET_LOG(Error, "aoti_torch_mps_convolution: null required handles (input, weight, or ret0)"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + // Convert AOTITensorHandle to ExecutorTorch tensors + auto input_tensor = reinterpret_cast(input); + auto weight_tensor = reinterpret_cast(weight); + + // bias can be null for convolutions without bias + Tensor* bias_tensor = nullptr; + if (bias && *bias) { + bias_tensor = reinterpret_cast(*bias); + ET_LOG(Debug, "aoti_torch_mps_convolution: Has bias tensor"); + } else { + ET_LOG(Debug, "aoti_torch_mps_convolution: No bias tensor"); + } + + ET_LOG(Debug, "aoti_torch_mps_convolution: Converted tensor handles to ET tensors"); + + // Log tensor shapes for debugging + ET_LOG(Debug, "aoti_torch_mps_convolution: input shape: [%d, %d, %d, %d]", + input_tensor->dim() > 0 ? (int)input_tensor->sizes()[0] : 0, + input_tensor->dim() > 1 ? (int)input_tensor->sizes()[1] : 0, + input_tensor->dim() > 2 ? (int)input_tensor->sizes()[2] : 0, + input_tensor->dim() > 3 ? (int)input_tensor->sizes()[3] : 0); + + ET_LOG(Debug, "aoti_torch_mps_convolution: weight shape: [%d, %d, %d, %d]", + weight_tensor->dim() > 0 ? (int)weight_tensor->sizes()[0] : 0, + weight_tensor->dim() > 1 ? (int)weight_tensor->sizes()[1] : 0, + weight_tensor->dim() > 2 ? (int)weight_tensor->sizes()[2] : 0, + weight_tensor->dim() > 3 ? (int)weight_tensor->sizes()[3] : 0); + + // Log convolution parameters + if (stride && stride_len_ >= 2) { + ET_LOG(Debug, "aoti_torch_mps_convolution: stride: [%lld, %lld]", stride[0], stride[1]); + } + if (padding && padding_len_ >= 2) { + ET_LOG(Debug, "aoti_torch_mps_convolution: padding: [%lld, %lld]", padding[0], padding[1]); + } + if (dilation && dilation_len_ >= 2) { + ET_LOG(Debug, "aoti_torch_mps_convolution: dilation: [%lld, %lld]", dilation[0], dilation[1]); + } + if (output_padding && output_padding_len_ >= 2) { + ET_LOG(Debug, "aoti_torch_mps_convolution: output_padding: [%lld, %lld]", output_padding[0], output_padding[1]); + } + + // Support conv1d and conv2d by inspecting weight rank. 
+ // conv1d: weight dims = [C_out, C_in, K] + // conv2d: weight dims = [C_out, C_in, Kh, Kw] + bool is_conv1d = (weight_tensor->dim() == 3); + + // Accept input ranks: + // conv1d: 2D (C,W) or 3D (N,C,W) + // conv2d: 3D (C,H,W) or 4D (N,C,H,W) + bool has_batch_dim = false; + bool is_input_4d = false; + int64_t N = 1, C_in = 0, H_in = 1, W_in = 0; + if (is_conv1d) { + if (input_tensor->dim() == 2) { + // (C, W) + has_batch_dim = false; + C_in = input_tensor->sizes()[0]; + W_in = input_tensor->sizes()[1]; + H_in = 1; + } else if (input_tensor->dim() == 3) { + // (N, C, W) + has_batch_dim = true; + N = input_tensor->sizes()[0]; + C_in = input_tensor->sizes()[1]; + W_in = input_tensor->sizes()[2]; + H_in = 1; + } else { + ET_LOG(Error, "aoti_torch_mps_convolution: conv1d expects 2D or 3D input, got %d", (int)input_tensor->dim()); + return Error::InvalidArgument; + } + } else { + is_input_4d = (input_tensor->dim() == 4); + if (is_input_4d) { + // (N, C, H, W) + has_batch_dim = true; + N = input_tensor->sizes()[0]; + C_in = input_tensor->sizes()[1]; + H_in = input_tensor->sizes()[2]; + W_in = input_tensor->sizes()[3]; + } else if (input_tensor->dim() == 3) { + // (C, H, W) + has_batch_dim = false; + N = 1; + C_in = input_tensor->sizes()[0]; + H_in = input_tensor->sizes()[1]; + W_in = input_tensor->sizes()[2]; + } else { + ET_LOG(Error, "aoti_torch_mps_convolution: conv2d expects 3D or 4D input, got %d", (int)input_tensor->dim()); + return Error::InvalidArgument; + } + } + + // Get weight dimensions + int64_t C_out = weight_tensor->sizes()[0]; // output channels + int64_t kernel_h = is_conv1d ? 1 : weight_tensor->sizes()[2]; // kernel height + int64_t kernel_w = is_conv1d ? weight_tensor->sizes()[2] : weight_tensor->sizes()[3]; // kernel width + + // Calculate output spatial dimensions + int64_t stride_h = is_conv1d ? 1 : (stride && stride_len_ > 0 ? stride[0] : 1); + int64_t stride_w = is_conv1d ? (stride && stride_len_ > 0 ? stride[0] : 1) + : (stride && stride_len_ > 1 ? stride[1] : 1); + int64_t pad_h = is_conv1d ? 0 : (padding && padding_len_ > 0 ? padding[0] : 0); + int64_t pad_w = is_conv1d ? (padding && padding_len_ > 0 ? padding[0] : 0) + : (padding && padding_len_ > 1 ? padding[1] : 0); + int64_t dil_h = is_conv1d ? 1 : (dilation && dilation_len_ > 0 ? dilation[0] : 1); + int64_t dil_w = is_conv1d ? (dilation && dilation_len_ > 0 ? dilation[0] : 1) + : (dilation && dilation_len_ > 1 ? dilation[1] : 1); + + int64_t H_out, W_out; + if (transposed) { + // For transposed convolution, output size calculation is different + int64_t output_pad_h = is_conv1d ? 0 : (output_padding && output_padding_len_ > 0 ? output_padding[0] : 0); + int64_t output_pad_w = is_conv1d ? (output_padding && output_padding_len_ > 0 ? output_padding[0] : 0) + : (output_padding && output_padding_len_ > 1 ? output_padding[1] : 0); + H_out = is_conv1d ? 1 : ((H_in - 1) * stride_h - 2 * pad_h + dil_h * (kernel_h - 1) + output_pad_h + 1); + W_out = (W_in - 1) * stride_w - 2 * pad_w + dil_w * (kernel_w - 1) + output_pad_w + 1; + } else { + // Regular convolution output size calculation + H_out = is_conv1d ? 
1 : ((H_in + 2 * pad_h - dil_h * (kernel_h - 1) - 1) / stride_h + 1); + W_out = (W_in + 2 * pad_w - dil_w * (kernel_w - 1) - 1) / stride_w + 1; + } + + if (!is_conv1d && is_input_4d) { + ET_LOG(Debug, "aoti_torch_mps_convolution: Calculated 4D output shape: [%lld, %lld, %lld, %lld]", N, C_out, H_out, W_out); + } else if (!is_conv1d) { + ET_LOG(Debug, "aoti_torch_mps_convolution: Calculated 3D output shape: [%lld, %lld, %lld]", C_out, H_out, W_out); + } else if (is_conv1d && has_batch_dim) { + ET_LOG(Debug, "aoti_torch_mps_convolution: Calculated 3D (1D conv) output shape: [%lld, %lld, %lld]", N, C_out, W_out); + } else { + ET_LOG(Debug, "aoti_torch_mps_convolution: Calculated 2D (1D conv) output shape: [%lld, %lld]", C_out, W_out); + } + + // Validate output dimensions are positive + if (N <= 0 || C_out <= 0 || H_out <= 0 || W_out <= 0) { + ET_LOG(Error, "aoti_torch_mps_convolution: Invalid output dimensions N=%lld, C_out=%lld, H_out=%lld, W_out=%lld", + N, C_out, H_out, W_out); + return Error::InvalidArgument; + } + + // Use the same dispatch pattern as other MPS operations for consistent synchronization + ETMetalStream* stream = getCurrentMetalStream(); + if (!stream) { + ET_LOG(Error, "aoti_torch_mps_convolution: Failed to get current Metal stream"); + return Error::Internal; + } + + // Get Metal device + id device = get_metal_device(); + if (!device) { + ET_LOG(Error, "aoti_torch_mps_convolution: Failed to get Metal device"); + throw std::runtime_error("Failed to get Metal device"); + } + + // End any existing kernel coalescing to ensure a clean state for MPS + stream->endKernelCoalescing(); + + // Ensure stream is ready; command buffer handled internally by stream helpers + + // Determine data type and element size + int32_t dtype = static_cast(input_tensor->scalar_type()); + MPSDataType mps_dtype; + size_t element_size; + + if (dtype == static_cast(SupportedDTypes::FLOAT32)) { + mps_dtype = MPSDataTypeFloat32; + element_size = sizeof(float); + } else if (dtype == static_cast(SupportedDTypes::BFLOAT16)) { + mps_dtype = MPSDataTypeBFloat16; + element_size = sizeof(uint16_t); // bfloat16 is 16 bits + } else { + ET_LOG(Error, "aoti_torch_mps_convolution: Unsupported data type: %d", dtype); + throw std::runtime_error("Unsupported data type for convolution"); + } + + ET_LOG(Debug, "aoti_torch_mps_convolution: mps_dtype=%d, element_size=%zu", mps_dtype, element_size); + + // Define tensor shapes for placeholders (needed for both cache hit and miss) + NSArray* inputShape = @[@(N), @(C_in), @(H_in), @(W_in)]; + NSArray* weightShape = @[@(C_out), @(C_in), @(kernel_h), @(kernel_w)]; + + // Create cache key for this convolution + GraphCacheKey cache_key; + cache_key.op_name = "conv"; + cache_key.shape_params = {N, C_in, H_in, W_in, C_out, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dil_h, dil_w, groups}; + cache_key.dtype = dtype; + cache_key.transpose_flag = (transposed != 0); + + // Check if we have a cached graph + MPSGraph* mpsGraph = nullptr; + MPSGraphTensor* convOutput = nil; + MPSGraphTensor* finalOutput = nil; + MPSGraphTensor* inputPlaceholder = nil; + MPSGraphTensor* weightPlaceholder = nil; + MPSGraphTensor* biasPlaceholder = nil; + bool has_bias = (bias_tensor != nullptr); + + auto cache_it = graph_cache.find(cache_key); + if (cache_it != graph_cache.end()) { + // Cache hit - reuse compiled graph and tensor references + CachedGraph& cached = cache_it->second; + mpsGraph = cached.graph; + inputPlaceholder = cached.input1; + weightPlaceholder = cached.input2; + 
biasPlaceholder = cached.input3; // May be nil if no bias + finalOutput = cached.output; + + cache_stats.hits++; + cache_stats.logStats(); + ET_LOG(Debug, "aoti_torch_mps_convolution: Using cached MPSGraph (cache hit, %zu total hits)", cache_stats.hits); + + } else { + // Cache miss - create and compile new graph + mpsGraph = [MPSGraph new]; + cache_stats.misses++; + cache_stats.logStats(); + ET_LOG(Debug, "aoti_torch_mps_convolution: Created new MPSGraph instance (cache miss, %zu total misses)", cache_stats.misses); + + ET_LOG(Debug, "aoti_torch_mps_convolution: Creating placeholders with shapes input:[%d,%d,%d,%d] weight:[%d,%d,%d,%d]", + (int)N, (int)C_in, (int)H_in, (int)W_in, + (int)C_out, (int)C_in, (int)kernel_h, (int)kernel_w); + + // Create placeholders for input tensors + inputPlaceholder = [mpsGraph placeholderWithShape:inputShape + dataType:mps_dtype + name:@"input"]; + weightPlaceholder = [mpsGraph placeholderWithShape:weightShape + dataType:mps_dtype + name:@"weight"]; + + ET_LOG(Debug, "aoti_torch_mps_convolution: Created input and weight placeholders"); + + // Create convolution descriptor + MPSGraphConvolution2DOpDescriptor* convDesc = [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:stride_w + strideInY:stride_h + dilationRateInX:dil_w + dilationRateInY:dil_h + groups:groups + paddingLeft:pad_w + paddingRight:pad_w + paddingTop:pad_h + paddingBottom:pad_h + paddingStyle:MPSGraphPaddingStyleExplicit + dataLayout:MPSGraphTensorNamedDataLayoutNCHW + weightsLayout:MPSGraphTensorNamedDataLayoutOIHW]; + + ET_LOG(Debug, "aoti_torch_mps_convolution: Created convolution descriptor with stride=[%lld,%lld], padding=[%lld,%lld], dilation=[%lld,%lld], groups=%lld", + stride_w, stride_h, pad_w, pad_h, dil_w, dil_h, groups); + + // Perform convolution using MPSGraph + if (transposed) { + ET_LOG(Debug, "aoti_torch_mps_convolution: Using transposed convolution"); + // For transposed convolution, we need to handle output padding + int64_t output_pad_h = output_padding && output_padding_len_ > 0 ? output_padding[0] : 0; + int64_t output_pad_w = output_padding && output_padding_len_ > 1 ? 
output_padding[1] : 0; + + // For transposed convolution, we need to adjust the padding calculation + // In transposed convolution, the effective padding is typically negative + // and we use output_padding to control the final output size + int64_t transposed_pad_h = pad_h - output_pad_h; + int64_t transposed_pad_w = pad_w - output_pad_w; + + // Create transposed convolution descriptor with adjusted padding + MPSGraphConvolution2DOpDescriptor* transposedConvDesc = [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:stride_w + strideInY:stride_h + dilationRateInX:dil_w + dilationRateInY:dil_h + groups:groups + paddingLeft:transposed_pad_w + paddingRight:transposed_pad_w + paddingTop:transposed_pad_h + paddingBottom:transposed_pad_h + paddingStyle:MPSGraphPaddingStyleExplicit + dataLayout:MPSGraphTensorNamedDataLayoutNCHW + weightsLayout:MPSGraphTensorNamedDataLayoutOIHW]; + + convOutput = [mpsGraph convolution2DWithSourceTensor:inputPlaceholder + weightsTensor:weightPlaceholder + descriptor:transposedConvDesc + name:@"transposed_convolution"]; + } else { + ET_LOG(Debug, "aoti_torch_mps_convolution: Using regular convolution"); + convOutput = [mpsGraph convolution2DWithSourceTensor:inputPlaceholder + weightsTensor:weightPlaceholder + descriptor:convDesc + name:@"convolution"]; + } + + ET_LOG(Debug, "aoti_torch_mps_convolution: Successfully created convolution tensor"); + + // Handle bias if provided + if (bias_tensor) { + ET_LOG(Debug, "aoti_torch_mps_convolution: Adding bias to convolution output"); + + // Create bias placeholder + NSArray* biasShape = @[@(C_out)]; + biasPlaceholder = [mpsGraph placeholderWithShape:biasShape + dataType:mps_dtype + name:@"bias"]; + + // Add bias to convolution output + finalOutput = [mpsGraph additionWithPrimaryTensor:convOutput + secondaryTensor:biasPlaceholder + name:@"add_bias"]; + + ET_LOG(Debug, "aoti_torch_mps_convolution: Added bias placeholder to graph"); + } else { + finalOutput = convOutput; + } + + // Cache the compiled graph and tensor references for reuse + CachedGraph cached_graph; + cached_graph.graph = mpsGraph; + cached_graph.input1 = inputPlaceholder; + cached_graph.input2 = weightPlaceholder; + cached_graph.input3 = biasPlaceholder; // May be nil if no bias + cached_graph.output = finalOutput; + graph_cache[cache_key] = cached_graph; + + ET_LOG(Debug, "aoti_torch_mps_convolution: Cached compiled MPSGraph for future reuse"); + } // End of cache miss block + + // Create feeds dictionary for graph execution + NSMutableDictionary* feeds = [NSMutableDictionary dictionary]; + + // Get Metal buffers from tensors + id input_buffer = get_mtl_buffer(input_tensor, "aoti_torch_mps_convolution", "input"); + id weight_buffer = get_mtl_buffer(weight_tensor, "aoti_torch_mps_convolution", "weight"); + + ET_LOG(Debug, "aoti_torch_mps_convolution: Using existing Metal buffers - input=%p, weight=%p", + input_buffer, weight_buffer); + + // Create MPSGraphTensorData objects for input tensors + MPSGraphTensorData* inputData = [[MPSGraphTensorData alloc] initWithMTLBuffer:input_buffer + shape:inputShape + dataType:mps_dtype]; + MPSGraphTensorData* weightData = [[MPSGraphTensorData alloc] initWithMTLBuffer:weight_buffer + shape:weightShape + dataType:mps_dtype]; + + feeds[inputPlaceholder] = inputData; + feeds[weightPlaceholder] = weightData; + + MPSGraphTensorData* biasData = nil; + + // Add bias data to feeds if provided + if (bias_tensor && biasPlaceholder) { + id bias_buffer = get_mtl_buffer(bias_tensor, "aoti_torch_mps_convolution", "bias"); + + 
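+        // The bias data is fed as a 1-D tensor of length C_out; this shape must match
+        // the [C_out] bias placeholder created when the graph was first compiled and cached.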
NSArray* biasShape = @[@(C_out)]; + biasData = [[MPSGraphTensorData alloc] initWithMTLBuffer:bias_buffer + shape:biasShape + dataType:mps_dtype]; + + feeds[biasPlaceholder] = biasData; + ET_LOG(Debug, "aoti_torch_mps_convolution: Added bias tensor to feeds"); + } + + ET_LOG(Debug, "aoti_torch_mps_convolution: Created feeds dictionary"); + + // Create Metal buffer for output tensor + size_t output_size_bytes = N * C_out * H_out * W_out * element_size; + void* output_contents_ptr = nullptr; + id output_buffer = allocate_mtl_buffer(&output_contents_ptr, output_size_bytes); + + // Create results dictionary (MPSGraph output is 4D) + NSArray* outputShape = @[@(N), @(C_out), @(H_out), @(W_out)]; + MPSGraphTensorData* outputData = [[MPSGraphTensorData alloc] initWithMTLBuffer:output_buffer + shape:outputShape + dataType:mps_dtype]; + + NSDictionary* results = @{finalOutput: outputData}; + ET_LOG(Debug, "aoti_torch_mps_convolution: Created results dictionary"); + + // Execute the MPSGraph + ET_LOG(Debug, "aoti_torch_mps_convolution: Executing MPSGraph"); + + @try { + // Use stream helper to encode and synchronize correctly + stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT); + } @catch (NSException *exception) { + ET_LOG(Error, "aoti_torch_mps_convolution: NSException caught during executeMPSGraph: %s - %s", + [[exception name] UTF8String], [[exception reason] UTF8String]); + throw std::runtime_error("MPSGraph execution failed with NSException"); + } @catch (...) { + ET_LOG(Error, "aoti_torch_mps_convolution: MPSGraph execution failed"); + throw std::runtime_error("MPSGraph execution failed"); + } + + ET_LOG(Debug, "aoti_torch_mps_convolution: MPSGraph execution completed successfully"); + + // Create output tensor handle on device (MPS) that points to GPU buffer + std::vector output_sizes_int64; + std::vector output_strides; + if (!is_conv1d && is_input_4d) { + output_sizes_int64 = {N, C_out, H_out, W_out}; + // Contiguous NCHW strides + output_strides = { + C_out * H_out * W_out, + H_out * W_out, + W_out, + 1 + }; + } else if (!is_conv1d) { + output_sizes_int64 = {C_out, H_out, W_out}; + // Contiguous CHW strides + output_strides = { + H_out * W_out, + W_out, + 1 + }; + } else if (is_conv1d && has_batch_dim) { + output_sizes_int64 = {N, C_out, W_out}; + // Contiguous NCW strides + output_strides = { + C_out * W_out, + W_out, + 1 + }; + } else { + output_sizes_int64 = {C_out, W_out}; + // Contiguous CW strides + output_strides = { + W_out, + 1 + }; + } + + // Use the GPU buffer contents pointer directly for the tensor storage + void* tensor_data = output_contents_ptr; + + AOTITensorHandle output_tensor_handle = nullptr; + + AOTITorchError create_result = aoti_torch_create_tensor_from_blob_v2( + tensor_data, + static_cast(output_sizes_int64.size()), // ndim + output_sizes_int64.data(), + output_strides.data(), + 0, // storage_offset + dtype, // dtype + 13, // device_type (MPS) + 0, // device_index + &output_tensor_handle, + 0, // layout (strided) + nullptr, // opaque_metadata + 0 // opaque_metadata_size + ); + + if (create_result != Error::Ok || !output_tensor_handle) { + ET_LOG(Error, "aoti_torch_mps_convolution: Failed to create output tensor, error code: %d", static_cast(create_result)); + aoti_torch_mps_free(tensor_data); // Free the allocated GPU memory on failure + throw std::runtime_error("Failed to create output tensor"); + } + + // Verify the tensor was created with the correct size + auto* et_tensor = reinterpret_cast(output_tensor_handle); + size_t actual_numel = 
et_tensor->numel(); + size_t expected_numel = static_cast(N * C_out * H_out * W_out); + + if (actual_numel != expected_numel) { + ET_LOG(Error, "aoti_torch_mps_convolution: Tensor size mismatch. Expected %zu, got %zu", expected_numel, actual_numel); + aoti_torch_mps_free(tensor_data); // Free the allocated GPU memory on failure + throw std::runtime_error("Tensor size mismatch"); + } + + // Store the tensor handle - mark that we own the memory since we manually allocated it + *ret0 = output_tensor_handle; + // Mark that we own the memory for these tensors + // Note: memory_to_n_tensor is managed automatically in aoti_torch_create_tensor_from_blob_v2 + // The function sets it to NOT_OWN, but we need to change it to 1 since we allocated it + extern std::unordered_map memory_to_n_tensor; + memory_to_n_tensor[tensor_data] = 1; + + [inputData release]; + [weightData release]; + if (biasData) [biasData release]; + [outputData release]; + + ET_LOG(Debug, "aoti_torch_mps_convolution: Created output tensor with %zu elements using MPSGraph", actual_numel); + + ET_LOG(Debug, "aoti_torch_mps_convolution: Executed successfully"); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_convolution exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_convolution: unknown exception"); + return Error::Internal; + } + } +} + +AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( + AOTITensorHandle query, + AOTITensorHandle key, + AOTITensorHandle value, + AOTITensorHandle* attn_mask, + double dropout_p, + int32_t is_causal, + AOTITensorHandle* dropout_mask, + double* scale, + AOTITensorHandle* ret0, + AOTITensorHandle* ret1) { + + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Starting with Metal kernel implementation"); + + if (!query || !key || !value || !ret0 || !ret1) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: null required tensor handles"); + return Error::InvalidArgument; + } + + if (is_causal) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: is_causal=True not implemented"); + return Error::NotImplemented; + } + if (dropout_p != 0.0) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: dropout_p != 0 not implemented (dropout_p=%f)", dropout_p); + return Error::NotImplemented; + } + if (dropout_mask && *dropout_mask) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: dropout_mask provided not implemented"); + return Error::NotImplemented; + } + + // Use the same dispatch pattern as other MPS operations for consistent synchronization + ETMetalStream* stream = getCurrentMetalStream(); + if (!stream) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to get current Metal stream"); + return Error::Internal; + } + + try { + @autoreleasepool { + // Convert AOTITensorHandle to ExecutorTorch tensors + auto* query_tensor = reinterpret_cast(query); + auto* key_tensor = reinterpret_cast(key); + auto* value_tensor = reinterpret_cast(value); + + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Converted tensor handles to ET tensors"); + + // Log query tensor shape and strides + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Query tensor - dim=%d, shape=[%d, %d, %d, %d], strides=[%d, %d, %d, %d]", + (int)query_tensor->dim(), + query_tensor->dim() > 0 ? 
query_tensor->sizes()[0] : 0, + query_tensor->dim() > 1 ? query_tensor->sizes()[1] : 0, + query_tensor->dim() > 2 ? query_tensor->sizes()[2] : 0, + query_tensor->dim() > 3 ? query_tensor->sizes()[3] : 0, + query_tensor->dim() > 0 ? query_tensor->strides()[0] : 0, + query_tensor->dim() > 1 ? query_tensor->strides()[1] : 0, + query_tensor->dim() > 2 ? query_tensor->strides()[2] : 0, + query_tensor->dim() > 3 ? query_tensor->strides()[3] : 0); + + // Log key tensor shape and strides + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Key tensor - dim=%d, shape=[%d, %d, %d, %d], strides=[%d, %d, %d, %d]", + (int)key_tensor->dim(), + key_tensor->dim() > 0 ? key_tensor->sizes()[0] : 0, + key_tensor->dim() > 1 ? key_tensor->sizes()[1] : 0, + key_tensor->dim() > 2 ? key_tensor->sizes()[2] : 0, + key_tensor->dim() > 3 ? key_tensor->sizes()[3] : 0, + key_tensor->dim() > 0 ? key_tensor->strides()[0] : 0, + key_tensor->dim() > 1 ? key_tensor->strides()[1] : 0, + key_tensor->dim() > 2 ? key_tensor->strides()[2] : 0, + key_tensor->dim() > 3 ? key_tensor->strides()[3] : 0); + + // Log value tensor shape and strides + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Value tensor - dim=%d, shape=[%d, %d, %d, %d], strides=[%d, %d, %d, %d]", + (int)value_tensor->dim(), + value_tensor->dim() > 0 ? value_tensor->sizes()[0] : 0, + value_tensor->dim() > 1 ? value_tensor->sizes()[1] : 0, + value_tensor->dim() > 2 ? value_tensor->sizes()[2] : 0, + value_tensor->dim() > 3 ? value_tensor->sizes()[3] : 0, + value_tensor->dim() > 0 ? value_tensor->strides()[0] : 0, + value_tensor->dim() > 1 ? value_tensor->strides()[1] : 0, + value_tensor->dim() > 2 ? value_tensor->strides()[2] : 0, + value_tensor->dim() > 3 ? value_tensor->strides()[3] : 0); + + // Validate tensor dimensions + if (query_tensor->dim() < 3 || key_tensor->dim() < 3 || value_tensor->dim() < 3) { + std::string error_msg = "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: tensors must be at least 3-D, got " + + std::to_string(query_tensor->dim()) + ", " + + std::to_string(key_tensor->dim()) + ", " + + std::to_string(value_tensor->dim()); + ET_LOG(Error, "%s", error_msg.c_str()); + throw std::runtime_error(error_msg); + } + + // Get tensor dimensions (assuming [batch, num_heads, seq_len, head_dim] format) + int64_t batchSize = query_tensor->sizes()[0]; + int64_t num_heads = query_tensor->sizes()[1]; + int64_t qSize = query_tensor->sizes()[2]; + int64_t headSize = query_tensor->sizes()[3]; + int64_t kvSeqLength = key_tensor->sizes()[2]; + + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: batchSize=%lld, num_heads=%lld, qSize=%lld, headSize=%lld, kvSeqLength=%lld", + batchSize, num_heads, qSize, headSize, kvSeqLength); + + // Determine data type and element size + int32_t dtype = static_cast(query_tensor->scalar_type()); + size_t element_size; + + if (dtype == static_cast(SupportedDTypes::FLOAT32)) { + element_size = sizeof(float); + } else if (dtype == static_cast(SupportedDTypes::BFLOAT16)) { + element_size = sizeof(uint16_t); // bfloat16 is 16 bits + } else { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Unsupported data type: %d", dtype); + throw std::runtime_error("Unsupported data type for scaled dot product attention"); + } + + // Check that headSize is not zero to avoid division by zero + if (headSize == 0) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: headSize is zero"); + throw 
std::runtime_error("headSize must be non-zero for scaled dot product attention"); + } + + // Validate key tensor head dimension to avoid division by zero in gqa_factor calculation + int64_t key_num_heads = key_tensor->sizes()[1]; + if (key_num_heads == 0) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: key tensor head dimension (sizes()[1]) is zero"); + throw std::runtime_error("key tensor must have non-zero head dimension for scaled dot product attention"); + } + + // Calculate scale factor + double scale_factor = scale ? *scale : (1.0 / sqrt(static_cast(headSize))); + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: scale_factor=%f", scale_factor); + + // Calculate output tensor dimensions + std::vector output_sizes = {batchSize, num_heads, qSize, headSize}; + std::vector attn_sizes = {batchSize, num_heads, qSize, kvSeqLength}; + + // Calculate strides for contiguous tensors + std::vector out_strides = { + num_heads * qSize * headSize, + qSize * headSize, + headSize, + 1 + }; + + std::vector attn_strides = { + num_heads * qSize * kvSeqLength, + qSize * kvSeqLength, + kvSeqLength, + 1 + }; + + // Allocate output Metal buffers via AOTI API to keep GPU residency and reuse + size_t out_size_bytes = batchSize * num_heads * qSize * headSize * element_size; + size_t attn_size_bytes = batchSize * num_heads * qSize * kvSeqLength * element_size; + + void* out_contents_ptr = nullptr; + allocate_mtl_buffer(&out_contents_ptr, out_size_bytes); + + void* attn_contents_ptr = nullptr; + allocate_mtl_buffer(&attn_contents_ptr, attn_size_bytes); + + // Use MLX-style Metal kernels instead of MPSGraph + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Implementing using MLX Metal kernels"); + + // Get shader library + ETMetalShaderLibrary* library = get_sdpa_shader_library(); + if (!library) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to get shader library"); + throw std::runtime_error("Failed to get SDPA shader library"); + } + + // Determine kernel name based on dtype and head_dim (PyTorch format) + std::string type_name; + if (dtype == static_cast(SupportedDTypes::FLOAT32)) { + type_name = "float"; + } else if (dtype == static_cast(SupportedDTypes::BFLOAT16)) { + type_name = "bfloat"; + } else { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Unsupported dtype for Metal kernel"); + throw std::runtime_error("Unsupported dtype for Metal SDPA kernel"); + } + + // Select head_dim - must match exactly one of the supported sizes (64, 96, 128) + int64_t head_dim = headSize; + if (head_dim != 64 && head_dim != 96 && head_dim != 128) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Unsupported head_dim %lld (must be 64, 96, or 128)", head_dim); + throw std::runtime_error("Unsupported head_dim for Metal SDPA kernel - must be exactly 64, 96, or 128"); + } + + std::string kernel_name = "sdpa_vector_" + type_name + "_" + std::to_string(head_dim) + "_" + std::to_string(head_dim); + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Using kernel: %s", kernel_name.c_str()); + + // Get kernel function + auto kernel_func = library->getKernelFunction(kernel_name); + if (!kernel_func) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to get kernel function: %s", kernel_name.c_str()); + throw std::runtime_error("Failed to get SDPA kernel function"); + } + + // Create output tensor handle first 
so we can use it in the kernel + AOTITensorHandle out_tensor_handle = nullptr; + AOTITorchError create_out_result = aoti_torch_create_tensor_from_blob_v2( + out_contents_ptr, + 4, // ndim + output_sizes.data(), + out_strides.data(), + 0, // storage_offset + dtype, + 13, // device_type (MPS) + 0, // device_index + &out_tensor_handle, + 0, // layout (strided) + nullptr, // opaque_metadata + 0 // opaque_metadata_size + ); + + if (create_out_result != Error::Ok || !out_tensor_handle) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to create output tensor"); + aoti_torch_mps_free(out_contents_ptr); + aoti_torch_mps_free(attn_contents_ptr); + throw std::runtime_error("Failed to create output tensor"); + } + + // Mark that we own the memory + extern std::unordered_map memory_to_n_tensor; + memory_to_n_tensor[out_contents_ptr] = 1; + + auto* out_tensor = reinterpret_cast(out_tensor_handle); + + // Prepare kernel arguments (PyTorch format) + uint gqa_factor = static_cast(num_heads / key_tensor->sizes()[1]); + uint N = static_cast(kvSeqLength); + + // Get strides for Q, K, V (all 3 stride levels: batch, head, seq) + uint q_batch_stride = static_cast(query_tensor->strides()[0]); + uint q_head_stride = static_cast(query_tensor->strides()[1]); + uint q_seq_stride = static_cast(query_tensor->strides()[2]); + uint q_dim_stride = static_cast(query_tensor->strides()[3]); + + uint k_batch_stride = static_cast(key_tensor->strides()[0]); + uint k_head_stride = static_cast(key_tensor->sizes()[1] == 1 ? key_tensor->strides()[0] : key_tensor->strides()[1]); + uint k_seq_stride = static_cast(key_tensor->strides()[2]); + uint k_dim_stride = static_cast(key_tensor->strides()[3]); + + uint v_batch_stride = static_cast(value_tensor->strides()[0]); + uint v_head_stride = static_cast(value_tensor->sizes()[1] == 1 ? value_tensor->strides()[0] : value_tensor->strides()[1]); + uint v_seq_stride = static_cast(value_tensor->strides()[2]); + uint v_dim_stride = static_cast(value_tensor->strides()[3]); + + // Log strides for debugging + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Q strides - batch:%u, head:%u, seq:%u, dim:%u", + q_batch_stride, q_head_stride, q_seq_stride, q_dim_stride); + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: K strides - batch:%u, head:%u, seq:%u, dim:%u", + k_batch_stride, k_head_stride, k_seq_stride, k_dim_stride); + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: V strides - batch:%u, head:%u, seq:%u, dim:%u", + v_batch_stride, v_head_stride, v_seq_stride, v_dim_stride); + + // Check if middle dimensions (1 and 2) are transposed + // For contiguous [batch, num_heads, seq, dim]: stride[1] > stride[2] (head_stride > seq_stride) + // For transposed [batch, seq, num_heads, dim] in memory: stride[1] < stride[2] (head_stride < seq_stride) + bool q_transposed = (q_head_stride < q_seq_stride); + bool k_transposed = (k_head_stride < k_seq_stride); + bool v_transposed = (v_head_stride < v_seq_stride); + + if (q_transposed || k_transposed || v_transposed) { + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Transposed middle dimensions detected (dims 1&2 swapped)! 
Q:%d, K:%d, V:%d", q_transposed, k_transposed, v_transposed); + ET_LOG(Debug, " For transposed layout: head_stride < seq_stride"); + ET_LOG(Debug, " Q: head_stride=%u, seq_stride=%u (transposed=%d)", q_head_stride, q_seq_stride, q_transposed); + ET_LOG(Debug, " K: head_stride=%u, seq_stride=%u (transposed=%d)", k_head_stride, k_seq_stride, k_transposed); + ET_LOG(Debug, " V: head_stride=%u, seq_stride=%u (transposed=%d)", v_head_stride, v_seq_stride, v_transposed); + ET_LOG(Debug, " The updated kernel will handle this by decomposing batch and head indices."); + } + + // Verify innermost dimension has stride=1 (required by current kernel implementation) + if (q_dim_stride != 1 || k_dim_stride != 1 || v_dim_stride != 1) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Non-unit dim stride detected!"); + ET_LOG(Error, " Q dim_stride=%u, K dim_stride=%u, V dim_stride=%u", q_dim_stride, k_dim_stride, v_dim_stride); + ET_LOG(Error, " Current kernel implementation requires innermost dimension to be contiguous (stride=1)"); + throw std::runtime_error("SDPA Metal kernel requires innermost dimension to be contiguous (dim_stride must be 1)"); + } + + bool has_mask_val = (attn_mask && *attn_mask); + + // Calculate mask strides if mask is present + uint mask_head_stride = 0; + uint mask_kv_seq_stride = 0; + uint mask_q_seq_stride = 0; + if (has_mask_val) { + auto* mask_tensor = reinterpret_cast(*attn_mask); + int nd = mask_tensor->dim(); + mask_kv_seq_stride = (nd >= 1 && mask_tensor->sizes()[nd - 1] > 1) ? static_cast(mask_tensor->strides()[nd - 1]) : 0; + mask_q_seq_stride = (nd >= 2 && mask_tensor->sizes()[nd - 2] > 1) ? static_cast(mask_tensor->strides()[nd - 2]) : 0; + mask_head_stride = (nd >= 3 && mask_tensor->sizes()[nd - 3] > 1) ? 
static_cast(mask_tensor->strides()[nd - 3]) : 0; + } + + // Execute kernel + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Preparing to execute kernel with grid [%llu, %llu, %llu], group [1024, 1, 1]", + (unsigned long long)(batchSize * num_heads), (unsigned long long)qSize, 1ULL); + + kernel_func->runCommandBlock([&]() { + kernel_func->startEncoding(); + + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Encoder started, setting arguments"); + + // Set buffer arguments (0-3: Q, K, V, out) + kernel_func->setArg(0, *query_tensor); + kernel_func->setArg(1, *key_tensor); + kernel_func->setArg(2, *value_tensor); + kernel_func->setArg(3, *out_tensor); + + // Set scalar arguments (uint values) + kernel_func->setArg(4, gqa_factor); + kernel_func->setArg(5, N); + + // Set uint3 for qkv_head_strides (buffer 6) + kernel_func->setArgUint3(6, q_head_stride, k_head_stride, v_head_stride); + + // Set uint3 for qkv_seq_strides (buffer 7) + kernel_func->setArgUint3(7, q_seq_stride, k_seq_stride, v_seq_stride); + + // Set scale as float (buffer 8) + kernel_func->setArg(8, static_cast(scale_factor)); + + // Set mask buffer (buffer 9) + if (has_mask_val) { + auto* mask_tensor = reinterpret_cast(*attn_mask); + kernel_func->setArg(9, *mask_tensor); + } else { + // Dummy buffer if no mask (won't be accessed) + kernel_func->setArg(9, *query_tensor); + } + + // Set uint3 for mask_strides (buffer 10) + kernel_func->setArgUint3(10, mask_head_stride, mask_kv_seq_stride, mask_q_seq_stride); + + // Set has_mask as bool (buffer 11) + kernel_func->setArg(11, has_mask_val); + + // Set uint3 for qkv_batch_strides (buffer 12) - NEW + kernel_func->setArgUint3(12, q_batch_stride, k_batch_stride, v_batch_stride); + + // Set num_q_heads (buffer 13) - NEW + kernel_func->setArg(13, static_cast(num_heads)); + + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: All arguments set, dispatching"); + + // Dispatch using threadgroups (PyTorch uses grid: [batch*heads, qSize, 1], group: [1024, 1, 1]) + // Note: We need to use dispatchThreadgroups, not dispatchThreads + // Each threadgroup processes one query token across all key-value tokens + kernel_func->dispatchThreadgroups( + batchSize * num_heads, // gridX + qSize, // gridY + 1, // gridZ + 1024, // threadsX + 1, // threadsY + 1); // threadsZ + }); + + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Command block completed"); + + AOTITensorHandle attn_tensor_handle = nullptr; + AOTITorchError create_attn_result = aoti_torch_create_tensor_from_blob_v2( + attn_contents_ptr, + 4, // ndim + attn_sizes.data(), + attn_strides.data(), + 0, // storage_offset + dtype, + 13, // device_type (MPS) + 0, // device_index + &attn_tensor_handle, + 0, // layout (strided) + nullptr, // opaque_metadata + 0 // opaque_metadata_size + ); + + if (create_attn_result != Error::Ok || !attn_tensor_handle) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to create attention weights tensor"); + aoti_torch_mps_free(attn_contents_ptr); + throw std::runtime_error("Failed to create attention weights tensor"); + } + + memory_to_n_tensor[attn_contents_ptr] = 1; + + // Set output tensor handles + *ret0 = out_tensor_handle; + *ret1 = attn_tensor_handle; + + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Metal kernel implementation completed successfully"); + + } // @autoreleasepool + + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: 
Executed successfully"); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: unknown exception"); + return Error::Internal; + } +} + +} // extern "C" + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/memory.cpp b/backends/apple/metal/runtime/shims/memory.cpp new file mode 100644 index 00000000000..ebb5b7642e1 --- /dev/null +++ b/backends/apple/metal/runtime/shims/memory.cpp @@ -0,0 +1,547 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include // Ensure we have int64_t, int32_t definitions +#include +#include +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +// Import all from aoti namespace +using namespace executorch::backends::aoti; + +// Global storage for tensors and their metadata +std::unordered_set> tensors; + +// Reference counting for memory addresses +// Maps memory address to number of tensors using it +// Special value: NOT_OWN (-1) means tensor never owns the memory +constexpr int32_t NOT_OWN = -1; +std::unordered_map memory_to_n_tensor; + +extern "C" { + +AOTITorchError aoti_torch_create_tensor_from_blob_v2( + void* data, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AOTITensorHandle* ret_new_tensor, + int32_t layout, + const uint8_t* opaque_metadata, + int64_t opaque_metadata_size) { + ET_LOG(Debug, "aoti_torch_create_tensor_from_blob_v2: entered"); + + (void)device_type; + (void)opaque_metadata; + (void)layout; + (void)opaque_metadata_size; + + // Validate input parameters first + ET_CHECK_OR_RETURN_ERROR( + data != nullptr, + InvalidArgument, + "aoti_torch_create_tensor_from_blob_v2 failed: data pointer is null"); + + ET_CHECK_OR_RETURN_ERROR( + !(sizes_ptr == nullptr && ndim > 0), + InvalidArgument, + "aoti_torch_create_tensor_from_blob_v2 failed: sizes_ptr is null"); + + ET_CHECK_OR_RETURN_ERROR( + ret_new_tensor != nullptr, + InvalidArgument, + "aoti_torch_create_tensor_from_blob_v2 failed: ret_new_tensor is null"); + + ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype)); + + // Handle storage offset by adjusting the data pointer + void* adjusted_data = static_cast(data) + + (storage_offset * dtype_to_element_size(dtype)); + + ET_LOG( + Debug, + "aoti_torch_create_tensor_from_blob_v2: original_data=%p, storage_offset=%lld, element_size=%zu, adjusted_data=%p", + data, + storage_offset, + dtype_to_element_size(dtype), + adjusted_data); + + // ETensor sizes + auto sizes = convert_sizes_to_vector(ndim, sizes_ptr); + + // ETensor strides + auto strides = convert_strides_to_vector(ndim, sizes_ptr, strides_ptr); + + // Log if the tensor is contiguous + if (is_contiguous_tensor(sizes, strides)) { + ET_LOG(Debug, "aoti_torch_create_tensor_from_blob_v2: contiguous tensor"); + } else { + ET_LOG( + Debug, "aoti_torch_create_tensor_from_blob_v2: non-contiguous tensor"); + } + + // ETensor creation + // Note: We're NOT copying the data, just 
wrapping it + auto tensor = executorch::extension::from_blob( + adjusted_data, sizes, strides, dtype_to_scalar_type(dtype)); + + ET_CHECK_OR_RETURN_ERROR( + tensor != nullptr, InvalidArgument, "Failed to create tensor from blob"); + + // Store the tensor so it doesn't get destroyed + tensors.insert(tensor); + *ret_new_tensor = tensor.get(); + + // Check if this memory address is already being tracked + auto memory_it = memory_to_n_tensor.find(adjusted_data); + ET_CHECK_OR_RETURN_ERROR( + memory_it == memory_to_n_tensor.end(), + InvalidArgument, + "Memory address %p is already being tracked by another tensor", + adjusted_data); + + // Mark this memory as NOT_OWN since tensor created from blob never owns + // memory + memory_to_n_tensor[adjusted_data] = NOT_OWN; + + ET_LOG(Debug, "aoti_torch_create_tensor_from_blob_v2: successfull"); + return Error::Ok; +} + +AOTITorchError aoti_torch_empty_strided( + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AOTITensorHandle* ret_new_tensor) { + ET_LOG(Debug, "aoti_torch_empty_strided: entered"); + + // This requires us to reserve device memory and put it into a ETensor + void* ptr; + int64_t numel = 1; + for (int i = 0; i < ndim; i++) { + numel *= sizes_ptr[i]; + } + + ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype)); + + size_t element_size = dtype_to_element_size(dtype); + ET_CHECK_OR_RETURN_ERROR( + element_size != 0, + InvalidArgument, + "Invalid element size for dtype: %d", + dtype); + int64_t nbytes = numel * element_size; + + int32_t mps_device_type = aoti_torch_device_type_mps(); // Returns 13 + if (device_type == mps_device_type) { + ptr = metal_allocate_buffer(nbytes); + if (!ptr) { + ET_LOG(Error, "Failed to allocate %lld bytes on Metal device", nbytes); + return Error::MemoryAllocationFailed; + } + } else if (device_type == 0) { // cpu + // Ensure 16-byte alignment for CPU memory to match device requirements + int result = posix_memalign(&ptr, 16, nbytes); + ET_CHECK_OR_RETURN_ERROR( + result == 0, + MemoryAllocationFailed, + "Failed to allocate aligned CPU memory"); + ET_CHECK_OR_RETURN_ERROR( + ptr != nullptr, + MemoryAllocationFailed, + "Failed to call posix_memalign"); + ET_LOG(Debug, "Allocated %lld bytes on CPU", nbytes); + } else { + ET_CHECK_OR_RETURN_ERROR( + false, + NotImplemented, + "Need to implement empty_strided for non-CUDA non-CPU device type %d", + device_type); + } + + // ETensor sizes + auto sizes = convert_sizes_to_vector(ndim, sizes_ptr); + + // ETensor strides + auto strides = convert_strides_to_vector(ndim, sizes_ptr, strides_ptr); + + // Log if the tensor is contiguous + if (is_contiguous_tensor(sizes, strides)) { + ET_LOG(Debug, "aoti_torch_empty_strided: contiguous tensor"); + } else { + ET_LOG(Debug, "aoti_torch_empty_strided: non-contiguous tensor"); + } + + // ETensor creation + // Note: We're NOT copying the data, just wrapping it + executorch::aten::ScalarType scalar_type = dtype_to_scalar_type(dtype); + auto tensor = + executorch::extension::from_blob(ptr, sizes, strides, scalar_type); + + // Store the tensor so it doesn't get destroyed + tensors.insert(tensor); + *ret_new_tensor = tensor.get(); + + // This tensor owns the memory it allocated, set reference count to 1 + memory_to_n_tensor[ptr] = 1; + + ET_LOG(Debug, "aoti_torch_empty_strided: successfull"); + return Error::Ok; +} + +AOTITorchError aoti_torch_delete_tensor_object(AOTITensorHandle tensor) { + ET_LOG(Debug, "aoti_torch_delete_tensor_object: entered"); 
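+
+  // Ownership bookkeeping (summary of the scheme implemented below):
+  // memory_to_n_tensor maps a data pointer to the number of tensors that own
+  // it. NOT_OWN (-1) marks memory wrapped from an external blob that this
+  // backend must never free; a count of 1 means this is the last owning
+  // tensor, so the backing Metal buffer (or the posix_memalign'd CPU
+  // allocation) is released; a larger count is simply decremented.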
+ + // Handle null tensor pointer + if (tensor == nullptr) { + ET_LOG(Debug, "aoti_torch_delete_tensor_object: null tensor"); + return Error::Ok; + } + + // Check if tensor exists in our tracking + bool found_in_tensors = false; + for (auto it = tensors.begin(); it != tensors.end(); ++it) { + if (it->get() == tensor) { + found_in_tensors = true; + break; + } + } + + // If tensor not found in our tracking, it's invalid + ET_CHECK_OR_RETURN_ERROR( + found_in_tensors, InvalidArgument, "Didn't find tensor %p", tensor); + + // Find and delete the tensor + for (auto it = tensors.begin(); it != tensors.end(); ++it) { + if (it->get() == tensor) { + // Get the tensor before erasing + auto tensor_ptr = *it; + void* data_ptr = tensor_ptr->mutable_data_ptr(); + + // Find the reference count for this memory address + auto memory_it = memory_to_n_tensor.find(data_ptr); + if (memory_it != memory_to_n_tensor.end()) { + int32_t ref_count = memory_it->second; + + if (ref_count == NOT_OWN) { + // Tensor never owned the memory, skip freeing + // Just remove tensor from tracking + tensors.erase(it); + ET_LOG( + Debug, + "aoti_torch_delete_tensor_object: tensor doesn't own memory, skipping free"); + return Error::Ok; + } else if (ref_count == 1) { + // Only current tensor using this memory, free it + // Check if it's Metal GPU memory + if (metal_is_device_pointer(data_ptr)) { + metal_deallocate_buffer(data_ptr); + } else { + // This is CPU memory - free immediately + free(data_ptr); + data_ptr = nullptr; + ET_LOG( + Debug, "aoti_torch_delete_tensor_object: freeing CPU memory"); + } + + // Remove from memory tracking + memory_to_n_tensor.erase(memory_it); + } else if (ref_count > 1) { + // Other tensors still using this memory, just decrement count + memory_to_n_tensor[data_ptr] = ref_count - 1; + ET_LOG( + Debug, + "aoti_torch_delete_tensor_object: decremented ref count from %d to %d", + ref_count, + ref_count - 1); + } + } else { + ET_CHECK_OR_RETURN_ERROR( + false, + Internal, + "Internal error: memory not found during deletion"); + } + + // Remove tensor from set (this will call the destructor if it's the last + // reference) + tensors.erase(it); + ET_LOG(Debug, "aoti_torch_delete_tensor_object: successfull"); + return Error::Ok; + } + } + + // This should never be reached since we found it above + ET_CHECK_OR_RETURN_ERROR( + false, Internal, "Internal error: tensor not found after validation"); +} + +AOTITorchError aoti_torch_copy_( + AOTITensorHandle self, + AOTITensorHandle src, + int32_t non_blocking) { + ET_LOG(Debug, "aoti_torch_copy_: entered"); + + (void)non_blocking; + + // Check for null pointers first + ET_CHECK_OR_RETURN_ERROR( + self != nullptr, + InvalidArgument, + "aoti_torch_copy_ failed: self tensor is null"); + + ET_CHECK_OR_RETURN_ERROR( + src != nullptr, + InvalidArgument, + "aoti_torch_copy_ failed: src tensor is null"); + + // Get dtype information and validate compatibility + int32_t self_dtype, src_dtype; + aoti_torch_get_dtype(self, &self_dtype); + aoti_torch_get_dtype(src, &src_dtype); + + ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(self_dtype)); + + ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(src_dtype)); + + // Check dtype compatibility - both tensors must have the same dtype + ET_CHECK_OR_RETURN_ERROR( + self_dtype == src_dtype, + InvalidArgument, + "dtype mismatch. self.dtype=%d, src.dtype=%d. 
aoti_torch_copy_ requires same dtypes", + self_dtype, + src_dtype); + + // Check total number of elements compatibility (PyTorch copy_ behavior) + int64_t self_numel = self->numel(); + int64_t src_numel = src->numel(); + + ET_CHECK_OR_RETURN_ERROR( + self_numel == src_numel, + InvalidArgument, + "numel mismatch. self.numel()=%ld, src.numel()=%ld", + self_numel, + src_numel); + + // Get tensor metadata + int64_t* self_strides; + int64_t* src_strides; + aoti_torch_get_strides(self, &self_strides); + aoti_torch_get_strides(src, &src_strides); + + int64_t* self_sizes; + int64_t* src_sizes; + aoti_torch_get_sizes(self, &self_sizes); + aoti_torch_get_sizes(src, &src_sizes); + + // Determine device locations + bool srcIsDevice = false; + bool dstIsDevice = false; + + // Check if pointers are Metal device pointers + if (!srcIsDevice) { + srcIsDevice = metal_is_device_pointer(const_cast(src->data_ptr())); + } + if (!dstIsDevice) { + dstIsDevice = metal_is_device_pointer(self->mutable_data_ptr()); + } + + // Check if tensors have the same schema (sizes, strides, dtype) for fast path + // TODO: This should be improved to catch cases like (4, 1, 5) -> (4, 5) + bool same_schema = true; + for (int i = 0; i < self->dim(); i++) { + if (self_strides[i] != src_strides[i]) { + same_schema = false; + break; + } + } + + size_t total_bytes = src->nbytes(); + int64_t total_elements = self->numel(); + + if (same_schema) { + int result = metal_copy_memory( + self->mutable_data_ptr(), + src->data_ptr(), + total_bytes, + srcIsDevice, + dstIsDevice); + if (result != 0) { + ET_LOG(Error, "metal_copy_memory failed with status %d", result); + return Error::Internal; + } + } else { + ET_LOG(Error, "Layout conversion not supported"); + return Error::NotImplemented; + } + + ET_LOG(Debug, "aoti_torch_copy_: successfull"); + return Error::Ok; +} + +AOTITorchError aoti_torch__reinterpret_tensor( + AOTITensorHandle self, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + AOTITensorHandle* ret_new_tensor) { + ET_LOG(Debug, "aoti_torch__reinterpret_tensor: entered"); + + // Validate input parameters first + ET_CHECK_OR_RETURN_ERROR( + self != nullptr, + InvalidArgument, + "aoti_torch__reinterpret_tensor failed: self tensor is null"); + + ET_CHECK_OR_RETURN_ERROR( + !(sizes_ptr == nullptr && ndim > 0), + InvalidArgument, + "aoti_torch__reinterpret_tensor failed: sizes_ptr is null"); + + ET_CHECK_OR_RETURN_ERROR( + ret_new_tensor != nullptr, + InvalidArgument, + "aoti_torch__reinterpret_tensor failed: ret_new_tensor is null"); + + // Check if storage_offset is not 0 - return error if not + ET_CHECK_OK_OR_RETURN_ERROR(validate_storage_offset(storage_offset)); + + // Get the device info from the source tensor to perform device_index + // validation + int32_t device_type = 0; + int32_t device_index = 0; + ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_type(self, &device_type)); + + ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_index(self, &device_index)); + + // Ensure device_index is always 0 + ET_CHECK_OR_RETURN_ERROR( + device_index == 0, + InvalidArgument, + "device_index must be 0, got: %d", + device_index); + + // Get the dtype from the source tensor + int32_t dtype = 0; + ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(self, &dtype)); + + // Validate dtype using SupportedDTypes + ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype)); + + // Get the original data pointer from the source tensor + void* data_ptr = self->mutable_data_ptr(); + ET_CHECK_OR_RETURN_ERROR( + 
data_ptr != nullptr, + InvalidArgument, + "Source tensor has null data pointer"); + + // Check if the given memory is in the map, if not return error + auto memory_it = memory_to_n_tensor.find(data_ptr); + ET_CHECK_OR_RETURN_ERROR( + memory_it != memory_to_n_tensor.end(), + InvalidArgument, + "Memory address %p is not being tracked by reference counting system", + data_ptr); + + // Convert sizes using utility function from utils.h + std::vector sizes = convert_sizes_to_vector(ndim, sizes_ptr); + + // Convert strides using utility function from utils.h + std::vector strides = + convert_strides_to_vector(ndim, sizes_ptr, strides_ptr); + + // Create new tensor view that reinterprets the same memory with different + // shape/strides This creates a view, not a copy - the data pointer is shared + std::shared_ptr tensor = executorch::extension::from_blob( + data_ptr, // Reuse the same memory from source tensor + sizes, // New sizes with explicit SizesType + strides, // New strides with explicit StridesType + dtype_to_scalar_type(dtype) // Convert dtype with explicit type casting + ); + + ET_CHECK_OR_RETURN_ERROR( + tensor != nullptr, + InvalidArgument, + "Failed to create reinterpreted tensor view"); + + // Store the tensor so it doesn't get destroyed + tensors.insert(tensor); + + *ret_new_tensor = tensor.get(); + + // Increment the reference count for this memory address only if it is owned + // by tensor + memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN + ? NOT_OWN + : memory_to_n_tensor[data_ptr] + 1; + + ET_LOG(Debug, "aoti_torch__reinterpret_tensor: successfull"); + return Error::Ok; +} + +AOTITorchError aoti_torch_new_tensor_handle( + Tensor* orig_handle, + Tensor** new_handle) { + (void)orig_handle; + (void)new_handle; + throw std::runtime_error("Not implemented"); + return Error::Internal; +} + +// Cleanup function for clearing global state +void cleanup_memory() { + // Use aoti_torch_delete_tensor_object to properly delete each tensor + // Note: We need to collect tensor pointers first since deletion modifies the + // set + std::vector tensor_ptrs; + tensor_ptrs.reserve(tensors.size()); + for (const auto& tensor_shared : tensors) { + tensor_ptrs.push_back(tensor_shared.get()); + } + + // Now delete each tensor - this will modify the global tensors set + for (Tensor* tensor_ptr : tensor_ptrs) { + aoti_torch_delete_tensor_object(tensor_ptr); + } + + // tensors set should now be empty, but ensure it's cleared + tensors.clear(); + + // Clean up Metal resources + metal_cleanup_resources(); + + ET_LOG(Info, "Cleared all tensors and Metal resources"); +} + +} // extern "C" + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/memory.h b/backends/apple/metal/runtime/shims/memory.h new file mode 100644 index 00000000000..dda0e6bd6c7 --- /dev/null +++ b/backends/apple/metal/runtime/shims/memory.h @@ -0,0 +1,77 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +extern "C" { + +// Global storage declarations +extern std::unordered_map memory_to_n_tensor; +extern std::unordered_set> tensors; + +// Memory-related operations +AOTITorchError aoti_torch_create_tensor_from_blob_v2( + void* data, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AOTITensorHandle* ret_new_tensor, + int32_t layout, + const uint8_t* opaque_metadata, + int64_t opaque_metadata_size); + +AOTITorchError aoti_torch_empty_strided( + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AOTITensorHandle* ret_new_tensor); + +AOTITorchError aoti_torch_delete_tensor_object(AOTITensorHandle tensor); + +AOTITorchError aoti_torch_copy_( + AOTITensorHandle self, + AOTITensorHandle src, + int32_t non_blocking); + +AOTITorchError aoti_torch__reinterpret_tensor( + AOTITensorHandle self, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + AOTITensorHandle* ret_new_tensor); + +AOTITorchError aoti_torch_new_tensor_handle( + Tensor* orig_handle, + Tensor** new_handle); + +void cleanup_memory(); + +} // extern "C" + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/shim_mps.h b/backends/apple/metal/runtime/shims/shim_mps.h new file mode 100644 index 00000000000..94611b016ae --- /dev/null +++ b/backends/apple/metal/runtime/shims/shim_mps.h @@ -0,0 +1,118 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include + +namespace executorch { +namespace backends { +namespace metal { + +struct AOTIMetalKernelFunctionOpaque; +using AOTIMetalKernelFunctionHandle = AOTIMetalKernelFunctionOpaque*; + +struct AOTIMetalShaderLibraryOpaque; +using AOTIMetalShaderLibraryHandle = AOTIMetalShaderLibraryOpaque*; + +#ifdef __cplusplus +extern "C" { +#endif + +// MetalShaderLibrary functions +AOTITorchError aoti_torch_mps_create_shader_library( + const char* metal_shader_source, + AOTIMetalShaderLibraryHandle* library_handle); + +AOTITorchError aoti_torch_mps_delete_shader_library( + AOTIMetalShaderLibraryHandle library_handle); + +AOTITorchError aoti_torch_mps_get_kernel_function( + AOTIMetalShaderLibraryHandle library_handle, + const char* kernel_name, + AOTIMetalKernelFunctionHandle* function_handle); + +// MetalKernelFunction functions +AOTITorchError aoti_torch_mps_start_encoding( + AOTIMetalKernelFunctionHandle func); + +AOTITorchError aoti_torch_mps_set_arg_tensor( + AOTIMetalKernelFunctionHandle func, + unsigned idx, + AOTITensorHandle tensor); + +AOTITorchError aoti_torch_mps_set_arg_int( + AOTIMetalKernelFunctionHandle func, + unsigned idx, + int64_t val); + +// Pure C dispatch functions - single value versions +AOTITorchError aoti_torch_mps_dispatch_single( + AOTIMetalKernelFunctionHandle func, + uint64_t length); + +AOTITorchError aoti_torch_mps_dispatch_single_with_group_size( + AOTIMetalKernelFunctionHandle func, + uint64_t length, + uint64_t group_size); + +// Pure C dispatch functions - array versions +AOTITorchError aoti_torch_mps_dispatch_array( + AOTIMetalKernelFunctionHandle func, + const uint64_t* length, + size_t length_size); + +AOTITorchError aoti_torch_mps_dispatch_array_with_group_size( + AOTIMetalKernelFunctionHandle func, + const uint64_t* length, + size_t length_size, + const uint64_t* group_size, + size_t group_size_size); + +// Memory management functions +AOTITorchError aoti_torch_mps_malloc(void** buffer, size_t num_bytes); + +AOTITorchError aoti_torch_mps_free(void* ptr); + +AOTITorchError aoti_torch_mps_memcpy( + void* buffer, + size_t constant_offset, + size_t bytes_read, + size_t data_size, + uint8_t* constants_start); + +AOTITorchError aoti_torch_mps_copy_buffer( + void* src_buffer, + void* dst_buffer, + size_t data_size, + size_t src_offset, + size_t dst_offset); + +// C callback function type for command block execution +typedef void (*aoti_torch_mps_command_block_callback_t)( + AOTIMetalKernelFunctionHandle func, + void* user_data); + +// Shared callback function for std::function trampoline +void aoti_torch_mps_shared_callback( + AOTIMetalKernelFunctionHandle func, + void* user_data); + +// Pure C version using function pointer and user data for trampoline pattern +AOTITorchError aoti_torch_mps_run_command_block( + AOTIMetalKernelFunctionHandle func, + aoti_torch_mps_command_block_callback_t callback, + void* user_data); + +#ifdef __cplusplus +} // extern "C" +#endif + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/shim_mps.mm b/backends/apple/metal/runtime/shims/shim_mps.mm new file mode 100644 index 00000000000..337e1c7176a --- /dev/null +++ b/backends/apple/metal/runtime/shims/shim_mps.mm @@ -0,0 +1,554 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#import +#import +#import +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +// Declare the global mapping from et_metal.mm +extern std::unordered_map> ptr_to_mtl_buffer; + +extern "C" { + +// MetalShaderLibrary functions +AOTITorchError aoti_torch_mps_create_shader_library( + const char* metal_shader_source, + AOTIMetalShaderLibraryHandle* library_handle) { + + if (!metal_shader_source || !library_handle) { + ET_LOG(Error, "aoti_torch_mps_create_shader_library: null arguments"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + auto library = std::make_unique(std::string(metal_shader_source)); + auto* raw_library = library.get(); + + // Store the unique_ptr to keep the object alive + storeLibraryHandle(raw_library, std::move(library)); + + // Return raw pointer to match existing API + *library_handle = reinterpret_cast(raw_library); + + ET_LOG(Debug, "aoti_torch_mps_create_shader_library: Created shader library %p", raw_library); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_create_shader_library exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_create_shader_library: unknown exception"); + return Error::Internal; + } + } +} + +AOTITorchError aoti_torch_mps_delete_shader_library( + AOTIMetalShaderLibraryHandle library_handle) { + + if (!library_handle) { + ET_LOG(Error, "aoti_torch_mps_delete_shader_library: null library handle"); + return Error::InvalidArgument; + } + + try { + auto* library = reinterpret_cast(library_handle); + if (removeLibraryHandle(library)) { + ET_LOG(Debug, "aoti_torch_mps_delete_shader_library: Deleted shader library %p", library); + } else { + ET_LOG(Error, "aoti_torch_mps_delete_shader_library: Library not found in storage"); + return Error::InvalidArgument; + } + + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_delete_shader_library exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_delete_shader_library: unknown exception"); + return Error::Internal; + } +} + +AOTITorchError aoti_torch_mps_get_kernel_function( + AOTIMetalShaderLibraryHandle library_handle, + const char* kernel_name, + AOTIMetalKernelFunctionHandle* function_handle) { + + if (!library_handle || !kernel_name || !function_handle) { + ET_LOG(Error, "aoti_torch_mps_get_kernel_function: null arguments"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + auto* library = reinterpret_cast(library_handle); + auto function_shared_ptr = library->getKernelFunction(std::string(kernel_name)); + if (!function_shared_ptr) { + ET_LOG(Error, "aoti_torch_mps_get_kernel_function: Failed to get kernel function '%s'", kernel_name); + return Error::Internal; + } + + auto* raw_function = function_shared_ptr.get(); + + // Store the shared_ptr to keep the object alive + storeFunctionHandle(raw_function, function_shared_ptr); + + // Return raw pointer to match existing API + *function_handle = reinterpret_cast(raw_function); + + ET_LOG(Debug, "aoti_torch_mps_get_kernel_function: Got kernel function '%s' -> %p", kernel_name, raw_function); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_get_kernel_function exception: %s", e.what()); + return Error::Internal; + } catch (...) 
{ + ET_LOG(Error, "aoti_torch_mps_get_kernel_function: unknown exception"); + return Error::Internal; + } + } +} + +AOTITorchError aoti_torch_mps_start_encoding( + AOTIMetalKernelFunctionHandle func) { + + if (!func) { + ET_LOG(Error, "aoti_torch_mps_start_encoding: null function handle"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + auto* function = reinterpret_cast(func); + function->startEncoding(); + + ET_LOG(Debug, "aoti_torch_mps_start_encoding: Started encoding for function %p", function); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_start_encoding exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_start_encoding: unknown exception"); + return Error::Internal; + } + } +} + +AOTITorchError aoti_torch_mps_set_arg_tensor( + AOTIMetalKernelFunctionHandle func, + unsigned idx, + AOTITensorHandle tensor) { + + if (!func || !tensor) { + ET_LOG(Error, "aoti_torch_mps_set_arg_tensor: null function handle or tensor"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + auto* function = reinterpret_cast(func); + auto* et_tensor = reinterpret_cast(tensor); + + function->setArg(idx, *et_tensor); + + ET_LOG(Debug, "aoti_torch_mps_set_arg_tensor: Set tensor argument at index %u", idx); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_set_arg_tensor exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_set_arg_tensor: unknown exception"); + return Error::Internal; + } + } +} + +AOTITorchError aoti_torch_mps_set_arg_int( + AOTIMetalKernelFunctionHandle func, + unsigned idx, + int64_t val) { + + if (!func) { + ET_LOG(Error, "aoti_torch_mps_set_arg_int: null function handle"); + return Error::InvalidArgument; + } + + try { + auto* function = reinterpret_cast(func); + function->setArg(idx, val); + + ET_LOG(Debug, "aoti_torch_mps_set_arg_int: Set int64_t value %lld at index %u", val, idx); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_set_arg_int exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_set_arg_int: unknown exception"); + return Error::Internal; + } +} + +// Pure C dispatch functions - single value versions +AOTITorchError aoti_torch_mps_dispatch_single( + AOTIMetalKernelFunctionHandle func, + uint64_t length) { + + if (!func) { + ET_LOG(Error, "aoti_torch_mps_dispatch_single: null function handle"); + return Error::InvalidArgument; + } + + try { + auto* function = reinterpret_cast(func); + function->dispatchSingle(length); + + ET_LOG(Debug, "aoti_torch_mps_dispatch_single: Dispatched function %p with length %llu", function, length); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_dispatch_single exception: %s", e.what()); + return Error::Internal; + } catch (...) 
{ + ET_LOG(Error, "aoti_torch_mps_dispatch_single: unknown exception"); + return Error::Internal; + } +} + +AOTITorchError aoti_torch_mps_dispatch_single_with_group_size( + AOTIMetalKernelFunctionHandle func, + uint64_t length, + uint64_t group_size) { + + if (!func) { + ET_LOG(Error, "aoti_torch_mps_dispatch_single_with_group_size: null function handle"); + return Error::InvalidArgument; + } + + try { + auto* function = reinterpret_cast(func); + function->dispatchSingleWithGroupSize(length, group_size); + + ET_LOG(Debug, "aoti_torch_mps_dispatch_single_with_group_size: Dispatched function %p with length %llu, group size %llu", function, length, group_size); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_dispatch_single_with_group_size exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_dispatch_single_with_group_size: unknown exception"); + return Error::Internal; + } +} + +// Pure C dispatch functions - array versions +AOTITorchError aoti_torch_mps_dispatch_array( + AOTIMetalKernelFunctionHandle func, + const uint64_t* length, + size_t length_size) { + + if (!func) { + ET_LOG(Error, "aoti_torch_mps_dispatch_array: null function handle"); + return Error::InvalidArgument; + } + + if (!length) { + ET_LOG(Error, "aoti_torch_mps_dispatch_array_with_group_size: null length pointer"); + return Error::InvalidArgument; + } + + try { + auto* function = reinterpret_cast(func); + function->dispatchArray(length, length_size); + + ET_LOG(Debug, "aoti_torch_mps_dispatch_array: Dispatched function %p with %zu dimensions", function, length_size); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_dispatch_array exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_dispatch_array: unknown exception"); + return Error::Internal; + } +} + +AOTITorchError aoti_torch_mps_dispatch_array_with_group_size( + AOTIMetalKernelFunctionHandle func, + const uint64_t* length, + size_t length_size, + const uint64_t* group_size, + size_t group_size_size) { + + if (!func) { + ET_LOG(Error, "aoti_torch_mps_dispatch_array_with_group_size: null function handle"); + return Error::InvalidArgument; + } + + if (!length) { + ET_LOG(Error, "aoti_torch_mps_dispatch_array_with_group_size: null length pointer"); + return Error::InvalidArgument; + } + + try { + auto* function = reinterpret_cast(func); + function->dispatchArrayWithGroupSize(length, length_size, group_size, group_size_size); + + ET_LOG(Debug, "aoti_torch_mps_dispatch_array_with_group_size: Dispatched function %p with %zu dimensions", function, length_size); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_dispatch_array_with_group_size exception: %s", e.what()); + return Error::Internal; + } catch (...) 
{ + ET_LOG(Error, "aoti_torch_mps_dispatch_array_with_group_size: unknown exception"); + return Error::Internal; + } +} + +AOTITorchError aoti_torch_mps_malloc(void** buffer, size_t num_bytes) { + if (num_bytes == 0) { + *buffer = nullptr; + return Error::Ok; + } + + if (!buffer) { + ET_LOG(Error, "aoti_torch_mps_malloc: null buffer pointer"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + id device = get_metal_device(); + if (!device) { + ET_LOG(Error, "aoti_torch_mps_malloc: Failed to get Metal device"); + return Error::Internal; + } + + id metal_buffer = [device newBufferWithLength:num_bytes + options:MTLResourceCPUCacheModeWriteCombined | MTLResourceStorageModeShared]; + if (!metal_buffer) { + ET_LOG(Error, "aoti_torch_mps_malloc: Failed to allocate Metal buffer of size %zu", num_bytes); + return Error::Internal; + } + + // FIX: Return contents pointer, not buffer object + void* contents_ptr = [metal_buffer contents]; + ptr_to_mtl_buffer[contents_ptr] = metal_buffer; // Map contents to buffer + *buffer = contents_ptr; // Return contents pointer + + ET_LOG(Debug, "aoti_torch_mps_malloc: Allocated Metal buffer %p with contents %p of size %zu", + metal_buffer, contents_ptr, num_bytes); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_malloc exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_malloc: unknown exception"); + return Error::Internal; + } + } +} + +AOTITorchError aoti_torch_mps_free(void* ptr) { + if (!ptr) { + return Error::Ok; // Nothing to free + } + + @autoreleasepool { + try { + // FIX: ptr is now the contents pointer, not the buffer object + // Look up the buffer from the mapping and clean up + auto it = ptr_to_mtl_buffer.find(ptr); + if (it != ptr_to_mtl_buffer.end()) { + id metal_buffer = it->second; + [metal_buffer release]; + ptr_to_mtl_buffer.erase(it); + ET_LOG(Debug, "aoti_torch_mps_free: Freed Metal buffer for contents %p", ptr); + } else { + ET_LOG(Error, "aoti_torch_mps_free: Buffer not found for contents pointer %p", ptr); + return Error::InvalidArgument; + } + + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_free exception: %s", e.what()); + return Error::Internal; + } catch (...) 
{ + ET_LOG(Error, "aoti_torch_mps_free: unknown exception"); + return Error::Internal; + } + } +} + +AOTITorchError aoti_torch_mps_memcpy( + void* buffer, + size_t constant_offset, + size_t bytes_read, + size_t data_size, + uint8_t* constants_start) { + + if (!buffer || !constants_start) { + ET_LOG(Error, "aoti_torch_mps_memcpy: null buffer or constants_start"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + // FIX: buffer is now the contents pointer, not the buffer object + auto buffer_pointer = static_cast(buffer); + + memcpy(buffer_pointer + constant_offset, constants_start + bytes_read, data_size); + + id device = get_metal_device(); + if (!device) { + ET_LOG(Error, "aoti_torch_mps_memcpy: Failed to get Metal device"); + return Error::Internal; + } + id subBuffer = [device newBufferWithBytesNoCopy:buffer_pointer + constant_offset + length:data_size + options:MTLResourceCPUCacheModeWriteCombined | MTLResourceStorageModeShared + deallocator:nil]; + + if (constant_offset != 0) { + ptr_to_mtl_buffer[buffer_pointer + constant_offset] = subBuffer; // Map contents to buffer + } + + ET_LOG(Debug, "aoti_torch_mps_memcpy: Copied %zu bytes from offset %zu to buffer offset %zu", + data_size, bytes_read, constant_offset); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_memcpy exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_memcpy: unknown exception"); + return Error::Internal; + } + } +} + +AOTITorchError aoti_torch_mps_copy_buffer( + void* src_buffer, + void* dst_buffer, + size_t data_size, + size_t src_offset, + size_t dst_offset) { + + if (!src_buffer || !dst_buffer) { + ET_LOG(Error, "aoti_torch_mps_copy_buffer: null buffer"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + auto src_mtl_buffer = (id)src_buffer; + auto dst_mtl_buffer = (id)dst_buffer; + + uint8_t* src_contents = static_cast([src_mtl_buffer contents]); + uint8_t* dst_contents = static_cast([dst_mtl_buffer contents]); + + if (!src_contents || !dst_contents) { + ET_LOG(Error, "aoti_torch_mps_copy_buffer: Failed to get buffer contents"); + return Error::Internal; + } + + memcpy(dst_contents + dst_offset, src_contents + src_offset, data_size); + + ET_LOG(Debug, "aoti_torch_mps_copy_buffer: Copied %zu bytes from src+%zu to dst+%zu", + data_size, src_offset, dst_offset); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_copy_buffer exception: %s", e.what()); + return Error::Internal; + } catch (...) 
{ + ET_LOG(Error, "aoti_torch_mps_copy_buffer: unknown exception"); + return Error::Internal; + } + } +} + +// Shared callback function for std::function trampoline +void aoti_torch_mps_shared_callback( + AOTIMetalKernelFunctionHandle func, + void* user_data) { + ET_LOG(Debug, "aoti_torch_mps_shared_callback: Called with func=%p, user_data=%p", func, user_data); + + auto* function_wrapper = static_cast*>(user_data); + if (function_wrapper) { + ET_LOG(Debug, "aoti_torch_mps_shared_callback: Calling function wrapper"); + (*function_wrapper)(func); + ET_LOG(Debug, "aoti_torch_mps_shared_callback: Function wrapper completed"); + } else { + ET_LOG(Error, "aoti_torch_mps_shared_callback: null function wrapper"); + } +} + +// Pure C version using function pointer and user data for trampoline pattern +AOTITorchError aoti_torch_mps_run_command_block( + AOTIMetalKernelFunctionHandle func, + aoti_torch_mps_command_block_callback_t callback, + void* user_data) { + + if (!func) { + ET_LOG(Error, "aoti_torch_mps_run_command_block: null function handle"); + return Error::InvalidArgument; + } + + if (!callback) { + ET_LOG(Error, "aoti_torch_mps_run_command_block: null callback"); + return Error::InvalidArgument; + } + + ET_LOG(Debug, "aoti_torch_mps_run_command_block: Starting command block for function %p, callback %p, user_data %p", + func, callback, user_data); + + try { + auto* function = reinterpret_cast(func); + function->runCommandBlock([callback, func, user_data]() { + ET_LOG(Debug, "aoti_torch_mps_run_command_block: Inside lambda, calling callback"); + callback(func, user_data); + ET_LOG(Debug, "aoti_torch_mps_run_command_block: Callback completed"); + }); + + ET_LOG(Debug, "aoti_torch_mps_run_command_block: Executed command block for function %p", function); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_run_command_block exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_run_command_block: unknown exception"); + return Error::Internal; + } +} + +} // extern "C" + + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/tensor_attribute.cpp b/backends/apple/metal/runtime/shims/tensor_attribute.cpp new file mode 100644 index 00000000000..34e0329fdc9 --- /dev/null +++ b/backends/apple/metal/runtime/shims/tensor_attribute.cpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +extern "C" { + +// Metal-specific device type constant +__attribute__((__visibility__("default"))) int32_t +aoti_torch_device_type_mps() { + return 13; // Consistent with c10/core/DeviceType.h +} + +// Override aoti_torch_get_device_type to return MPS device type +AOTITorchError aoti_torch_get_device_type( + AOTITensorHandle tensor, + int32_t* ret_device_type) { + *ret_device_type = aoti_torch_device_type_mps(); + return Error::Ok; +} + +} // extern "C" + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/tensor_attribute.h b/backends/apple/metal/runtime/shims/tensor_attribute.h new file mode 100644 index 00000000000..8d2a3dde361 --- /dev/null +++ b/backends/apple/metal/runtime/shims/tensor_attribute.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +extern "C" { + +// Metal-specific device type function +int32_t aoti_torch_device_type_mps(); + +// Override aoti_torch_get_device_type to return MPS device type +AOTITorchError aoti_torch_get_device_type( + AOTITensorHandle tensor, + int32_t* ret_device_type); + +} // extern "C" + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/types.h b/backends/apple/metal/runtime/shims/types.h new file mode 100644 index 00000000000..07d377d7499 --- /dev/null +++ b/backends/apple/metal/runtime/shims/types.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +// Common using declarations for ExecutorTorch types +using executorch::runtime::Error; +using executorch::runtime::etensor::Tensor; + +extern "C" { + +// Common AOTI type aliases +// Note: AOTITensorHandle is aliased to Tensor* for ExecutorTorch compatibility +using AOTITensorHandle = Tensor*; +using AOTIRuntimeError = Error; +using AOTITorchError = Error; + +} // extern "C" + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/utils.cpp b/backends/apple/metal/runtime/shims/utils.cpp new file mode 100644 index 00000000000..061360a4e28 --- /dev/null +++ b/backends/apple/metal/runtime/shims/utils.cpp @@ -0,0 +1,51 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +extern "C" { + +// Helper function to check if a dtype is supported in Metal backend +bool is_dtype_supported_in_et_metal(int32_t dtype) { + switch (dtype) { + case static_cast(SupportedDTypes::INT64): + case static_cast(SupportedDTypes::FLOAT32): + case static_cast(SupportedDTypes::BFLOAT16): + return true; + default: + return false; + } +} + +// Metal-specific dtype validation utility function +AOTITorchError validate_dtype(int32_t dtype) { + if (is_dtype_supported_in_et_metal(dtype)) { + return Error::Ok; + } + + ET_LOG( + Error, + "Unsupported dtype: %d. Supported dtypes: %d (int64), %d (float32), %d (bfloat16)", + dtype, + static_cast(SupportedDTypes::INT64), + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDTypes::BFLOAT16)); + return Error::InvalidArgument; +} + +} // extern "C" + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/utils.h b/backends/apple/metal/runtime/shims/utils.h new file mode 100644 index 00000000000..974832fa365 --- /dev/null +++ b/backends/apple/metal/runtime/shims/utils.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +// Enum for supported data types in et-metal backend +enum class SupportedDTypes : int32_t { + // UINT8 = 0, // PyTorch's uint8 dtype code + // INT8 = 1, // PyTorch's int8 dtype code + // INT16 = 2, // PyTorch's int16 dtype code + // INT32 = 3, // PyTorch's int32 dtype code + INT64 = 4, // PyTorch's int64 dtype code + // FLOAT16 = 5, // PyTorch's float16 dtype code + FLOAT32 = 6, // PyTorch's float32 dtype code + // FLOAT64 = 7, // PyTorch's float64 dtype code + // BOOL = 11, // PyTorch's bool dtype code + BFLOAT16 = 15 // PyTorch's bfloat16 dtype code +}; + +extern "C" { + +// Helper function to check if a dtype is supported in Metal backend +bool is_dtype_supported_in_et_metal(int32_t dtype); + +// Metal-specific dtype validation utility function +AOTITorchError validate_dtype(int32_t dtype); + +} // extern "C" + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/tests/__init__.py b/backends/apple/metal/tests/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/backends/apple/metal/tests/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/apple/metal/tests/test_metal_backend.py b/backends/apple/metal/tests/test_metal_backend.py new file mode 100644 index 00000000000..5caf7a3adc6 --- /dev/null +++ b/backends/apple/metal/tests/test_metal_backend.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
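+
+# Scope note: these tests exercise the MetalBackend compile-spec helpers.
+# generate_method_name_compile_spec() encodes a method name into a CompileSpec,
+# and method_name_from_compile_specs() recovers it, raising RuntimeError when
+# no method-name spec is present.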
+ +import unittest + +from executorch.backends.apple.metal.metal_backend import ( + COMPILE_SPEC_KEYS, + MetalBackend, +) +from executorch.exir.backend.compile_spec_schema import CompileSpec + + +class TestMetalBackend(unittest.TestCase): + """Test Metal backend utility functions.""" + + def test_generate_method_name_compile_spec(self): + """Test that compile spec is generated correctly with method name.""" + method_name = "forward" + compile_spec = MetalBackend.generate_method_name_compile_spec(method_name) + + # Verify compile spec structure + self.assertIsInstance(compile_spec, CompileSpec) + self.assertEqual(compile_spec.key, COMPILE_SPEC_KEYS.METHOD_NAME.value) + self.assertEqual(compile_spec.value, method_name.encode("utf-8")) + + def test_method_name_from_compile_specs(self): + """Test extracting method name from compile specs.""" + method_name = "forward" + compile_specs = [MetalBackend.generate_method_name_compile_spec(method_name)] + + # Extract method name + extracted_name = MetalBackend.method_name_from_compile_specs(compile_specs) + + self.assertEqual(extracted_name, method_name) + + def test_method_name_from_compile_specs_with_multiple_specs(self): + """Test extracting method name when there are multiple compile specs.""" + method_name = "forward" + compile_specs = [ + CompileSpec("other_key", b"other_value"), + MetalBackend.generate_method_name_compile_spec(method_name), + CompileSpec("another_key", b"another_value"), + ] + + # Extract method name + extracted_name = MetalBackend.method_name_from_compile_specs(compile_specs) + + self.assertEqual(extracted_name, method_name) + + def test_method_name_from_compile_specs_missing(self): + """Test that RuntimeError is raised when method name is missing.""" + compile_specs = [ + CompileSpec("other_key", b"other_value"), + ] + + # Should raise RuntimeError when method name is not found + with self.assertRaises(RuntimeError) as context: + MetalBackend.method_name_from_compile_specs(compile_specs) + + self.assertIn("Could not find method name", str(context.exception)) + + def test_compile_spec_roundtrip(self): + """Test that method name survives encode/decode roundtrip.""" + original_name = "my_custom_method" + + # Generate compile spec + compile_spec = MetalBackend.generate_method_name_compile_spec(original_name) + + # Extract from compile specs list + extracted_name = MetalBackend.method_name_from_compile_specs([compile_spec]) + + self.assertEqual(original_name, extracted_name) + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/apple/metal/tests/test_metal_partitioner.py b/backends/apple/metal/tests/test_metal_partitioner.py new file mode 100644 index 00000000000..1b29410ab6c --- /dev/null +++ b/backends/apple/metal/tests/test_metal_partitioner.py @@ -0,0 +1,172 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from typing import Tuple + +import torch +from executorch.backends.apple.metal.metal_backend import MetalBackend +from executorch.backends.apple.metal.metal_partitioner import MetalPartitioner +from executorch.exir.backend.partitioner import PartitionResult +from torch.export import export + + +class TestMetalPartitioner(unittest.TestCase): + """ + Test Metal partitioner functionality. + + After Metal partitioning, there should be exactly one partitioned graph that contains + all operators from the input graph. 
This means all operators should be tagged with + the same delegation tag, indicating they will all be executed by the Metal backend. + """ + + def _get_partition_result( + self, module: torch.nn.Module, inputs: Tuple[torch.Tensor, ...] + ) -> PartitionResult: + """Helper method to get partition result for a given module.""" + # Export the model + exported_program = export(module, inputs, strict=True) + + # Create partitioner with compile specs + compile_specs = [MetalBackend.generate_method_name_compile_spec("forward")] + partitioner = MetalPartitioner(compile_specs) + + # Get partition result + partition_result = partitioner.partition(exported_program) + + # Verify partition result structure + self.assertIsNotNone(partition_result) + self.assertTrue(hasattr(partition_result, "tagged_exported_program")) + self.assertTrue(hasattr(partition_result, "partition_tags")) + + return partition_result + + def _check_fully_partitioned(self, partition_result: PartitionResult) -> bool: + """Check if the graph is fully partitioned (all operators have the same tag).""" + tagged_nodes = [] + untagged_ops = [] + + for node in partition_result.tagged_exported_program.graph.nodes: + if node.op == "call_function": + if hasattr(node, "meta") and "delegation_tag" in node.meta: + tagged_nodes.append(node) + else: + untagged_ops.append(node) + + # Check if we have any tagged nodes + if not tagged_nodes: + return False + + # Check if all tagged nodes have the same tag + first_tag = tagged_nodes[0].meta["delegation_tag"] + all_same_tag = all( + node.meta.get("delegation_tag") == first_tag for node in tagged_nodes + ) + + # Should have no untagged operations for full partitioning + fully_partitioned = len(untagged_ops) == 0 and all_same_tag + + return fully_partitioned + + def test_simple_add_partition(self): + """ + Test that Metal partitioner creates exactly one partition containing all operators. + Simple element-wise addition should result in a single graph with all ops tagged identically. + """ + + class AddModule(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + + # Create test inputs + x = torch.randn(2, 3) + y = torch.randn(2, 3) + + # Get partition result + partition_result = self._get_partition_result(AddModule(), (x, y)) + + # Verify it's fully partitioned + self.assertTrue( + self._check_fully_partitioned(partition_result), + "Expected all operations to be in a single partition", + ) + + # Verify exactly one partition tag exists + self.assertEqual( + len(partition_result.partition_tags), + 1, + "Expected exactly one partition tag for fully delegated graph", + ) + + def test_linear_partition(self): + """ + Test Metal partitioner with a linear layer. + All matrix operations should be in a single partition. + """ + + class LinearModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + # Create test input + x = torch.randn(2, 10) + + # Get partition result + partition_result = self._get_partition_result(LinearModule(), (x,)) + + # Verify it's fully partitioned + self.assertTrue( + self._check_fully_partitioned(partition_result), + "Expected all operations to be in a single partition", + ) + + def test_ops_to_not_decompose(self): + """ + Test that ops_to_not_decompose returns all call_function ops. + Metal backend should handle decomposition via AOTInductor. 
+ """ + + class SimpleModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.relu(x + 1.0) + + # Create test input + x = torch.randn(2, 3) + + # Export the model + exported_program = export(SimpleModule(), (x,), strict=True) + + # Create partitioner + compile_specs = [MetalBackend.generate_method_name_compile_spec("forward")] + partitioner = MetalPartitioner(compile_specs) + + # Get ops to not decompose + ops_to_not_decompose, _ = partitioner.ops_to_not_decompose(exported_program) + + # Verify it returns a list + self.assertIsInstance(ops_to_not_decompose, list) + + # All call_function ops should be in the list + call_function_ops = [ + node.target + for node in exported_program.graph.nodes + if node.op == "call_function" + and isinstance(node.target, torch._ops.OpOverload) + ] + + self.assertEqual( + set(ops_to_not_decompose), + set(call_function_ops), + "ops_to_not_decompose should contain all call_function ops", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/apple/mps/CMakeLists.txt b/backends/apple/mps/CMakeLists.txt index 5a253347b01..99a8afa16ac 100644 --- a/backends/apple/mps/CMakeLists.txt +++ b/backends/apple/mps/CMakeLists.txt @@ -77,7 +77,7 @@ target_compile_options(mpsdelegate PRIVATE "-fno-objc-arc") install( TARGETS mpsdelegate mps_schema EXPORT ExecuTorchTargets - DESTINATION lib + DESTINATION ${CMAKE_INSTALL_LIBDIR} INCLUDES DESTINATION ${_common_include_directories} ) diff --git a/backends/apple/mps/operators/shape_ops.py b/backends/apple/mps/operators/shape_ops.py index 76c559018be..18b613670ea 100644 --- a/backends/apple/mps/operators/shape_ops.py +++ b/backends/apple/mps/operators/shape_ops.py @@ -242,11 +242,14 @@ def define_node( output_ids = self.define_tensor_list(node, mps_graph) split_sizes = eval_shape(cast(torch.SymInt, node.args[1])) dim = cast(int, node.args[2]) + orig_dim = dim input_shape = get_shape(get_input_node(node, 0)) + if dim < 0: + dim += len(input_shape) if dim < 0 or dim >= len(input_shape): raise RuntimeError( - f"split_copy: dim {dim} out of range for input tensor with {len(input_shape)} dimensions" + f"split_copy: dim {orig_dim} out of range for input tensor with {len(input_shape)} dimensions" ) mps_node = MPSNode( diff --git a/backends/arm/CMakeLists.txt b/backends/arm/CMakeLists.txt index ede7a96a389..8c79ce857c1 100644 --- a/backends/arm/CMakeLists.txt +++ b/backends/arm/CMakeLists.txt @@ -48,17 +48,44 @@ endif() # VGF backend builds if(EXECUTORCH_BUILD_VGF) - - # include libvgf - set(LIBVGF_PATH - "${EXECUTORCH_ROOT}/examples/arm/ethos-u-scratch/ml-sdk-for-vulkan-manifest/sw/vgf-lib/" - ) - set(VULKAN_THIRD_PARTY_PATH ${EXECUTORCH_ROOT}/backends/vulkan/third-party) set(VULKAN_HEADERS_PATH ${VULKAN_THIRD_PARTY_PATH}/Vulkan-Headers/include) set(VOLK_HEADERS_PATH ${VULKAN_THIRD_PARTY_PATH}/volk) - set(LIBVGF_STATIC "${LIBVGF_PATH}/build/src/libvgf.a") + if(APPLE + OR CMAKE_SYSTEM_PROCESSOR MATCHES "^(arm64|aarch64)$" + OR EXISTS + "${EXECUTORCH_ROOT}/examples/arm/arm-scratch/ml-sdk-for-vulkan-manifest/" + ) + message(STATUS "libvgf sourced from local scratch tree") + + # Legacy layout: libvgf sourced from local scratch tree + set(LIBVGF_PATH + "${EXECUTORCH_ROOT}/examples/arm/arm-scratch/ml-sdk-for-vulkan-manifest/sw/vgf-lib/" + ) + set(LIBVGF_STATIC "${LIBVGF_PATH}/build/src/libvgf.a") + else() + message(STATUS "libvgf installed from pip package") + + set(Python3_FIND_VIRTUALENV FIRST) + if(EXECUTORCH_ROOT AND EXISTS "${EXECUTORCH_ROOT}/env") + 
set(Python3_EXECUTABLE "${EXECUTORCH_ROOT}/env/bin/python3") + endif() + + find_package(Python3 REQUIRED COMPONENTS Interpreter) + + # Prefer arch-specific site-packages if present, else pure + set(_vgf_site_arch "${Python3_SITEARCH}/vgf_lib/binaries") + set(_vgf_site_pure "${Python3_SITELIB}/vgf_lib/binaries") + if(EXISTS "${_vgf_site_arch}") + set(LIBVGF_PATH "${_vgf_site_arch}") + else() + set(LIBVGF_PATH "${_vgf_site_pure}") + endif() + + set(LIBVGF_STATIC "${LIBVGF_PATH}/lib/libvgf.a") + endif() + set(LIBVGF_INCLUDE "${LIBVGF_PATH}/include/") add_library(vgf STATIC IMPORTED) diff --git a/backends/arm/README.md b/backends/arm/README.md index e495a8e40cb..0abf5e9bf55 100644 --- a/backends/arm/README.md +++ b/backends/arm/README.md @@ -6,7 +6,7 @@ PyTorch models to a TOSA representation. This representation is used to deploy to the following targets: - **Arm® Ethos™-U55/65/85** - Compiled using the Ethos-U Vela compiler. -- **VGF (Vulkan® Graph Format)** – SPIR-V™ representation for Vulkan-capable devices. +- **VGF Format, for ML extensions for Vulkan®** – a format containing SPIR-V™ ML operators for Vulkan-capable devices. The backend provides an ahead-of-time (AOT) flow, that produces a PTE file for your chosen target. The AOT flow supports the following development operating systems: diff --git a/backends/arm/TARGETS b/backends/arm/TARGETS index 35b16f819e5..6e81adfed6f 100644 --- a/backends/arm/TARGETS +++ b/backends/arm/TARGETS @@ -6,28 +6,6 @@ # @noautodeps load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") -runtime.python_library( - name = "ethosu_partitioner", - srcs = [ - "ethosu/__init__.py", - "ethosu/backend.py", - "ethosu/partitioner.py" - ], - deps = [ - ":arm_partitioner", - ] -) -runtime.python_library( - name = "vgf_partitioner", - srcs = [ - "vgf/__init__.py", - "vgf/backend.py", - "vgf/partitioner.py" - ], - deps = [ - ":arm_partitioner", - ] -) runtime.python_library( name = "constants", srcs = [ @@ -39,60 +17,70 @@ runtime.python_library( ) runtime.python_library( name = "common", - srcs = [ - "common/__init__.py", - "common/debug.py", - ], + srcs = glob(["common/*.py"]), deps = [ - "fbsource//third-party/tosa_tools/v0.80/serialization_lib/python/serializer:serializer", - "fbsource//third-party/tosa_tools/v1.00/serialization_lib/python/serializer:serializer", + "fbsource//third-party/tosa_tools:serializer", "//caffe2:torch", "//executorch/exir:lib", ], ) + runtime.python_library( - name = "arm_partitioner", + name = "arm_compile_spec", srcs = [ - "tosa/backend.py", - "tosa/partitioner.py", + "common/arm_compile_spec.py", ], deps = [ - ":arm_backend", - ":constants", - "//executorch/backends/arm/debug:schema", - "//executorch/backends/arm/operator_support:operator_support", + "fbsource//third-party/pypi/flatbuffers:flatbuffers", + "fbsource//third-party/pypi/ml-dtypes:ml-dtypes", + "fbsource//third-party/tosa_tools:serializer", + "fbsource//third-party/tosa_tools:tosa", + ":process_node", + "//executorch/exir/backend:compile_spec_schema", + "//executorch/backends/arm/operators:lib", + "//executorch/backends/arm/operators:node_visitor", "//executorch/backends/arm/_passes:passes", - "//executorch/exir:lib", ], ) runtime.python_library( - name = "arm_backend", + name = "ethosu", srcs = [ - "arm_backend.py", + "ethosu/__init__.py", + "ethosu/backend.py", + "ethosu/compile_spec.py", + "ethosu/partitioner.py", ], deps = [ - "fbsource//third-party/pypi/flatbuffers:flatbuffers", - "fbsource//third-party/pypi/ml-dtypes:ml-dtypes", - 
"fbsource//third-party/tosa_tools/v0.80/serialization_lib/python/serializer:serializer", - "fbsource//third-party/tosa_tools/v1.00/serialization_lib/python/serializer:serializer", - "fbsource//third-party/tosa_tools/v0.80/serialization_lib/python/tosa:tosa", - "fbsource//third-party/tosa_tools/v1.00/serialization_lib/python/tosa:tosa", + ":arm_compile_spec", ":arm_vela", - ":process_node", - "//executorch/backends/arm/operators:lib", - "//executorch/backends/arm/operators:node_visitor", - "//executorch/backends/arm/_passes:passes", + "//executorch/backends/arm/tosa:specification", + "//executorch/backends/arm/tosa:partitioner", + ], +) + +runtime.python_library( + name = "vgf", + srcs = [ + "vgf/__init__.py", + "vgf/backend.py", + "vgf/compile_spec.py", + "vgf/model_converter.py", + "vgf/partitioner.py", + ], + deps = [ + ":arm_compile_spec", + "//executorch/backends/arm/tosa:specification", + "//executorch/backends/arm/tosa:partitioner", ], ) + runtime.python_library( name = "process_node", srcs = ["process_node.py"], deps = [ - "fbsource//third-party/tosa_tools/v0.80/serialization_lib/python/tosa:tosa", - "fbsource//third-party/tosa_tools/v1.00/serialization_lib/python/tosa:tosa", + "fbsource//third-party/tosa_tools:tosa", "//executorch/backends/arm/operators:node_visitor", "//executorch/backends/arm/tosa:mapping", - "//executorch/backends/arm/tosa:quant_utils", "//executorch/backends/arm/tosa:utils", "//executorch/exir:lib", ], @@ -115,3 +103,17 @@ runtime.python_library( "//caffe2:torch", ] ) +runtime.python_library( + name = "_factory", + srcs = [ + "util/_factory.py" + ], + deps = [ + ":ethosu", + ":vgf", + ":arm_compile_spec", + "//executorch/backends/arm/quantizer:lib", + "//executorch/exir/backend:operator_support", + "//executorch/exir/backend:compile_spec_schema", + ] +) diff --git a/backends/arm/_passes/TARGETS b/backends/arm/_passes/TARGETS index bb4e992ada1..a75c63fb86e 100644 --- a/backends/arm/_passes/TARGETS +++ b/backends/arm/_passes/TARGETS @@ -6,7 +6,6 @@ runtime.python_library( deps = [ "//executorch/backends/arm:common", "//executorch/backends/arm:constants", - "//executorch/backends/arm/tosa:quant_utils", "//executorch/backends/arm/tosa:utils", "//executorch/backends/arm/tosa/dialect:lib", "//executorch/backends/transforms:fuse_view_copy", diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index f9e23f73cc5..2904e64a658 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -6,33 +6,34 @@ from . 
import arm_pass_utils # noqa from .arm_pass import ArmPass # noqa # usort: skip -from .add_bias_pass import AddBiasPass # noqa from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa from .annotate_output_dim_order_pass import AnnotateOutputDimOrderPass # noqa from .broadcast_args_pass import BroadcastArgsPass # noqa -from .cast_bool_to_int8_pass import CastBoolToInt8Pass # noqa from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa from .cast_to_int32_pass import CastToInt32Pass # noqa from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa -from .convert_any_default_dim_dims_pass import ConvertAnyDefaultDimDimsPass # noqa from .convert_elu_params import ConvertELUParamsPass # noqa from .convert_expand_copy_to_repeat import ConvertExpandCopyToRepeatPass # noqa from .convert_full_like_to_full_pass import ConvertFullLikeToFullPass # noqa from .convert_int64_const_ops_to_int32 import ConvertInt64ConstOpsToInt32Pass # noqa from .convert_int64_output_ops_to_int32 import ConvertInt64OutputOpsToInt32Pass # noqa -from .convert_int_pow_to_mul import ConvertIntPowToMuls # noqa from .convert_minmax_pass import ConvertMinMaxPass # noqa +from .convert_permute_singleton_to_view_pass import ( # noqa + ConvertPermuteSingletonToViewPass, +) from .convert_split_to_slice import ConvertSplitToSlicePass # noqa from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa -from .convert_to_clamp import ConvertToClampPass # noqa +from .convert_to_clamp_pass import ConvertToClampPass # noqa from .decompose_acosh_pass import DecomposeAcoshPass # noqa from .decompose_adaptive_avg_pool2d_pass import DecomposeAdaptiveAvgPool2dPass # noqa +from .decompose_add_sub_alpha_pass import DecomposeAddSubAlphaPass # noqa from .decompose_addmm_pass import DecomposeAddmmPass # noqa +from .decompose_any_pass import DecomposeAnyPass # noqa from .decompose_asin_and_acos_pass import DecomposeAsinAndAcosPass # noqa from .decompose_asinh_pass import DecomposeAsinhPass # noqa from .decompose_atan_pass import DecomposeAtanPass # noqa from .decompose_atanh_pass import DecomposeAtanhPass # noqa -from .decompose_avg_pool2d import DecomposeAvgPool2d # noqa +from .decompose_avg_pool2d_pass import DecomposeAvgPool2dPass # noqa from .decompose_batch_norm_no_stats import DecomposeBatchNormNoStatsPass # noqa from .decompose_cosh_pass import DecomposeCoshPass # noqa from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa @@ -42,21 +43,30 @@ from .decompose_elu_pass import DecomposeEluPass # noqa from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa from .decompose_expm1_pass import DecomposeExpm1Pass # noqa +from .decompose_floor_divide_pass import DecomposeFloorDividePass # noqa from .decompose_gelu_pass import DecomposeGeluPass # noqa from .decompose_glu_pass import DecomposeGluPass # noqa -from .decompose_grouped_conv import DecomposeGroupedConv # noqa +from .decompose_grouped_conv_pass import DecomposeGroupedConvPass # noqa from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa +from .decompose_int16_activation_conv_pass import ( # noqa + DecomposeConvWithInt16ActivationPass, +) +from .decompose_int_pow_pass import DecomposeIntPowPass # noqa from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa -from .decompose_linalg_vector_norm_pass import DecomposeLinearVectorNormPass # noqa +from .decompose_linalg_vector_norm_pass import DecomposeLinalgVectorNormPass # 
noqa from .decompose_linear_pass import DecomposeLinearPass # noqa from .decompose_logit_pass import DecomposeLogitPass # noqa -from .decompose_masked_fill import DecomposeMaskedFill # noqa -from .decompose_maxpool2d_with_dilation import DecomposeMaxPool2DPass # noqa +from .decompose_masked_fill_pass import DecomposeMaskedFillPass # noqa +from .decompose_maxpool2d_with_dilation_pass import DecomposeMaxPool2dPass # noqa from .decompose_meandim_pass import DecomposeMeanDimPass # noqa from .decompose_ne_pass import DecomposeNotEqualPass # noqa +from .decompose_quant_nodes import DecomposeQuantNodesPass # noqa +from .decompose_remainder_pass import DecomposeRemainderPass # noqa from .decompose_round_pass import DecomposeRoundPass # noqa +from .decompose_sdpa_pass import DecomposeScaledDotProductAttentionPass # noqa from .decompose_select import DecomposeSelectPass # noqa +from .decompose_select_scatter_pass import DecomposeSelectScatterPass # noqa from .decompose_sign_pass import DecomposeSignPass # noqa from .decompose_silu_pass import DecomposeSiluPass # noqa from .decompose_sinh_pass import DecomposeSinhPass # noqa @@ -64,34 +74,53 @@ from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass # noqa from .decompose_sqrt_pass import DecomposeSqrtPass # noqa from .decompose_sum_pass import DecomposeSumPass # noqa +from .decompose_tosa_unsupported_clamp_pass import ( # noqa + DecomposeTOSAUnsupportedClampPass, +) from .decompose_var_pass import DecomposeVarPass # noqa from .decorate_fp32_to_int32_casting_pass import DecorateFp32toInt32CastingPass # noqa from .fold_qdq_with_annotated_qparams_pass import ( # noqa FoldAndAnnotateQParamsPass, - QuantizeOperatorArguments, - RetraceFoldedDtypesPass, + QuantizeClampArgumentsPass, ) -from .fuse_batchnorm2d_pass import FuseBatchnorm2DPass # noqa -from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass # noqa +from .fuse_batch_norm2d_pass import FuseBatchNorm2dPass # noqa +from .fuse_constant_ops_pass import ( # noqa + ComputeConstantOpsAOTPass, + FuseConstantArgsPass, +) +from .fuse_duplicate_users_pass import FuseDuplicateUsersPass # noqa from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass # noqa from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa +from .fuse_view_copy_transform_pass import FuseViewCopyTransformPass # noqa from .insert_int32_casts_after_int64_placeholders import ( # noqa InsertInt32CastsAfterInt64PlaceholdersPass, ) -from .insert_rescales_pass import InsertRescalePass # noqa +from .insert_rescales_pass import ( # noqa + InsertControlFlowRescalesPass, + InsertRescaleInt32Pass, + InsertRescalePass, +) from .insert_table_ops import InsertTableOpsPass # noqa from .match_arg_dtype_pass import MatchArgDtypePass # noqa from .match_arg_ranks_pass import MatchArgRanksPass # noqa from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa +from .normalize_while_initial_args_pass import NormalizeWhileInitialArgsPass # noqa +from .promote_bool_operands_pass import PromoteBoolOperandsPass # noqa +from .remove_getitem_pass import RemoveGetItemPass # noqa +from .remove_graph_asserts_pass import RemoveGraphAssertsPass # noqa from .remove_noop_pass import RemoveNoopPass # noqa from .replace_scalar_with_tensor_pass import ( # noqa - ReplaceScalarWithTensorArgPassTOSABI, - ReplaceScalarWithTensorArgPassTOSAMI, + ReplaceScalarWithTensorByProfilePass, ) +from .rewrite_conv_pass import RewriteConvPass # noqa +from .rewrite_matmul import RewriteMatmulPass # noqa +from 
.rewrite_upsample import RewriteUpsamplePass # noqa from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa from .size_adjust_input_pass import SizeAdjustInputPass # noqa from .to_tosa_memory_format_pass import ToTosaMemoryFormatPass # noqa from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass # noqa from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass # noqa -from .replace_inf_values_pass import ReplaceInfValues # noqa # usort: skip +from .replace_inf_and_limit_values_pass import ( # noqa # usort: skip + ReplaceInfAndLimitValuesPass, +) from .arm_pass_manager import ArmPassManager # noqa # usort: skip diff --git a/backends/arm/_passes/_debug_passes.py b/backends/arm/_passes/_debug_passes.py index 7809885d465..caaaec8ea5e 100644 --- a/backends/arm/_passes/_debug_passes.py +++ b/backends/arm/_passes/_debug_passes.py @@ -3,17 +3,25 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import inspect +import os +from typing import Set, Type + import torch +from executorch.backends.arm._passes import ArmPass from executorch.devtools.visualization.visualization_utils import visualize_graph from executorch.exir import ExportedProgram from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx import GraphModule -class VisualizePass(ExportPass): +class VisualizePass(ArmPass): """ This pass visualizes the graph at the point of insertion in the pass manager """ + _passes_required_after: Set[Type[ExportPass]] = set() + def __init__(self, exported_program: ExportedProgram) -> None: super().__init__() self.exported_program = exported_program @@ -21,3 +29,30 @@ def __init__(self, exported_program: ExportedProgram) -> None: def call(self, graph_module: torch.fx.GraphModule) -> PassResult: visualize_graph(graph_module, self.exported_program) return PassResult(graph_module, False) + + +class PrintGraphModuleCodePass(ArmPass): + """ + This pass prints the graph module's code to stdout for debugging purposes. + + Example output: + + [arm_pass_manager.py:305] + def forward(self, x, y): + x, y, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec) + remainder = torch.ops.aten.remainder.Scalar(x, 0.25); x = None + return pytree.tree_unflatten((remainder,), self._out_spec) + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + def __init__(self, label: str | None = None): + super().__init__() + caller_frame = inspect.stack()[1] + origin = f"{os.path.basename(caller_frame.filename)}:{caller_frame.lineno}" + self.label = f"[{label}]" if label is not None else f"[{origin}]" + + def call(self, graph_module: GraphModule) -> PassResult: + gm_code = graph_module.code.strip() + print(f"\n{self.label}\n{gm_code}") + return PassResult(graph_module, False) diff --git a/backends/arm/_passes/add_bias_pass.py b/backends/arm/_passes/add_bias_pass.py deleted file mode 100644 index 31c0c0505cb..00000000000 --- a/backends/arm/_passes/add_bias_pass.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
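The new PrintGraphModuleCodePass above is a drop-in debugging aid. A hypothetical way to use it while assembling a pipeline; `pass_manager` and the label text are illustrative placeholders, not part of this patch:

from executorch.backends.arm._passes._debug_passes import PrintGraphModuleCodePass

# Print the graph module's code at chosen points in the pipeline.
pass_manager.add_pass(PrintGraphModuleCodePass())  # label defaults to the caller's "<file>:<line>"
pass_manager.add_pass(PrintGraphModuleCodePass(label="after q/dq folding"))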
- -import torch -from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor -from executorch.backends.transforms.utils import create_constant_placeholder - -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult -from torch.export.graph_signature import InputKind - - -class AddBiasPass(ArmPass): - """TOSA requires convolution nodes to have a bias input. - This pass adds a bias input to convolution nodes that do not have one. - The bias is set to zero. - """ - - targeted_ops = (exir_ops.edge.aten.convolution.default,) - - def call(self, graph_module): - modified = False - for node in graph_module.graph.nodes: - if node.op != "call_function": - continue - if node.target not in self.targeted_ops: - continue - - if len(node.all_input_nodes) < 3: - modified = True - # bias is missing - weight_node = node.all_input_nodes[1] - output_channels = get_first_fake_tensor(weight_node).shape[0] - # add a node containging zeros - # if quantized, use int32, otherwise use float32 - if ( - "output_qparams" in node.meta - and len(node.meta["output_qparams"]) > 0 - ): - bias_data = torch.zeros(size=(output_channels,), dtype=torch.int32) - else: - bias_data = torch.zeros( - size=(output_channels,), dtype=torch.float32 - ) - - with graph_module.graph.inserting_after(weight_node): - bias_node = create_constant_placeholder( - self.exported_program, - graph=graph_module.graph, - kind=InputKind.PARAMETER, - data=bias_data, - persistent_buffer=True, - name=f"{node.name}_bias", - ) - node.update_arg(2, bias_node) - - if modified: - graph_module = super().call(graph_module).graph_module - return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/annotate_decomposed_matmul.py b/backends/arm/_passes/annotate_decomposed_matmul.py index 8156ca0b89d..c8be7c7c04e 100644 --- a/backends/arm/_passes/annotate_decomposed_matmul.py +++ b/backends/arm/_passes/annotate_decomposed_matmul.py @@ -3,14 +3,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe import itertools import operator -from typing import cast, List +from typing import cast, List, Set, Type import torch +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + FoldAndAnnotateQParamsPass, +) from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.exir.dialects._ops import ops as exir_ops @@ -20,7 +23,7 @@ from torch.fx.passes.utils.source_matcher_utils import get_source_partitions -class AnnotateDecomposedMatmulPass(ExportPass): +class AnnotateDecomposedMatmulPass(ArmPass): """ torch.matmul and it's equivalent operator @ can be decomposed in many ways, for instance: dq -> matmul -> q can become @@ -29,6 +32,8 @@ class AnnotateDecomposedMatmulPass(ExportPass): matmul-op (can be mm or bmm). 
""" + _passes_required_after: Set[Type[ExportPass]] = {FoldAndAnnotateQParamsPass} + def _match_partition_to_node( self, node: torch.fx.Node, partitioned_inputs: List[torch.fx.Node] ) -> torch.fx.Node: @@ -46,7 +51,7 @@ def _match_partition_to_node( raise RuntimeError(f"Cannot find an input node which matches, {node}.") def call(self, graph_module: GraphModule) -> PassResult: - matmul_partitions = get_source_partitions( + matmul_partitions_map = get_source_partitions( graph_module.graph, [ torch.matmul, @@ -55,10 +60,11 @@ def call(self, graph_module: GraphModule) -> PassResult: None, ) matmul_partitions = list( - itertools.chain.from_iterable(matmul_partitions.values()) + itertools.chain.from_iterable(matmul_partitions_map.values()) ) matmul_targets = { exir_ops.edge.aten.bmm.default, + exir_ops.edge.aten.mm.default, } for partition in matmul_partitions: quantized_input = all( @@ -68,7 +74,10 @@ def call(self, graph_module: GraphModule) -> PassResult: node for node in partition.nodes if node.target in matmul_targets ][0] - if quantized_input: + if quantized_input and not all( + input_node.target in DQ_OPS + for input_node in matmul_node.all_input_nodes + ): matmul_args = matmul_node.all_input_nodes for node in matmul_args: # Find the dq-node connected to this mm/bmm arg @@ -80,7 +89,7 @@ def call(self, graph_module: GraphModule) -> PassResult: # Create new dq-node before matmul dq_node = create_node( graph=graph_module.graph, - op_target=cast(EdgeOpOverload, input_node.target), # type: ignore[arg-type] + op_target=cast(EdgeOpOverload, input_node.target), ) dq_node.args = (node, *input_node.args[1:]) matmul_node.replace_input_with(node, dq_node) @@ -94,12 +103,14 @@ def call(self, graph_module: GraphModule) -> PassResult: partition_output = list(partition.output_nodes[0].users)[0] quantized_output = partition_output.target in Q_OPS - if quantized_output: + if quantized_output and not all( + user.target in Q_OPS for user in matmul_node.users + ): with graph_module.graph.inserting_after(matmul_node): # Create q-node after matmul q_node = create_node( graph=graph_module.graph, - op_target=cast(EdgeOpOverload, partition_output.target), # type: ignore[arg-type] + op_target=cast(EdgeOpOverload, partition_output.target), ) matmul_node.replace_all_uses_with(q_node) q_node.args = (matmul_node, *partition_output.args[1:]) diff --git a/backends/arm/_passes/annotate_output_dim_order_pass.py b/backends/arm/_passes/annotate_output_dim_order_pass.py index 08f93383a9c..8dc13326e4a 100644 --- a/backends/arm/_passes/annotate_output_dim_order_pass.py +++ b/backends/arm/_passes/annotate_output_dim_order_pass.py @@ -3,9 +3,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + +from typing import Set, Type + from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_output_dim_orders -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import ExportPass, PassResult class AnnotateOutputDimOrderPass(ArmPass): @@ -14,6 +17,8 @@ class AnnotateOutputDimOrderPass(ArmPass): for verifying that the dim order does not change unexpectedly in later passes. 
""" + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module): output_node = graph_module.graph.output_node() output_node.meta["original_dim_orders"] = get_output_dim_orders(graph_module) diff --git a/backends/arm/_passes/arm_pass.py b/backends/arm/_passes/arm_pass.py index 085267a174e..662cd6e8d97 100644 --- a/backends/arm/_passes/arm_pass.py +++ b/backends/arm/_passes/arm_pass.py @@ -3,21 +3,52 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe import traceback -from typing import Optional +from abc import abstractmethod +from typing import Any, List, Optional, Set, Type -import torch from executorch.exir.pass_base import ExportPass, NodeMetadata +from torch.fx import GraphModule +from torch.fx.passes.infra.pass_base import PassResult class ArmPass(ExportPass): """Base class for Arm passes""" - def __init__(self, exported_program: Optional[torch.export.ExportedProgram] = None): - super(ArmPass, self).__init__() - self.exported_program = exported_program + def __init__(self) -> None: + super().__init__() + self.submodule_depth = 0 + + @property + @abstractmethod + def _passes_required_after(self) -> Set[Type[ExportPass]]: + """The subclass defines passes that must run after it""" + pass + + @staticmethod + def get_required_passes(pass_) -> List[str]: + """ + Returns the list of passes that must be run after this pass, sorted by name. + """ + if hasattr(pass_, "_passes_required_after"): + return sorted([ArmPass.get_name(p) for p in pass_._passes_required_after]) + else: + return [] + + @staticmethod + def get_name(pass_) -> str: + """ + Returns the name of the pass. + """ + if isinstance(pass_, ExportPass): + return pass_.__class__.__name__ + elif hasattr(pass_, "__name__"): + return pass_.__name__ + else: + raise ValueError( + f"Cannot get name for pass: {pass_}. It must be an instance of ExportPass or have a __name__ attribute." + ) def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False): if not updated: @@ -31,3 +62,19 @@ def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False) old_stack_trace = new_meta.get("stack_trace", "") new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}" return super().call_operator(op, args, kwargs, NodeMetadata(new_meta)) + + def call_submodule( + self, graph_module: GraphModule, inputs: tuple[Any, ...] + ) -> PassResult: + self.submodule_depth += 1 + if self.submodule_depth == 1: + result = super().call_submodule(graph_module, inputs) + else: + # When we trace a submodule, we don't want to apply the calling pass. + # Temporarily replace call_operator to avoid this. + _call_operator_fn = self.call_operator + self.call_operator = super().call_operator # type: ignore + result = super().call_submodule(graph_module, inputs) + self.call_operator = _call_operator_fn # type: ignore + self.submodule_depth -= 1 + return result diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index f49206da67e..dc418f18d27 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -5,40 +5,42 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe +import logging +from collections import defaultdict +from collections.abc import Sequence import executorch.backends.arm.tosa.dialect # noqa: unused from executorch.backends.arm._passes import ( - AddBiasPass, AnnotateDecomposedMatmulPass, AnnotateOutputDimOrderPass, BroadcastArgsPass, - CastBoolToInt8Pass, CastInt64BuffersToInt32Pass, CastToInt32Pass, - ComputeConstantOpsAOT, + ComputeConstantOpsAOTPass, Conv1dUnsqueezePass, - ConvertAnyDefaultDimDimsPass, ConvertELUParamsPass, ConvertExpandCopyToRepeatPass, ConvertFullLikeToFullPass, ConvertInt64ConstOpsToInt32Pass, ConvertInt64OutputOpsToInt32Pass, - ConvertIntPowToMuls, ConvertMinMaxPass, ConvertMmToBmmPass, + ConvertPermuteSingletonToViewPass, ConvertSplitToSlicePass, ConvertSqueezesToViewPass, ConvertToClampPass, DecomposeAcoshPass, DecomposeAdaptiveAvgPool2dPass, DecomposeAddmmPass, + DecomposeAddSubAlphaPass, + DecomposeAnyPass, DecomposeAsinAndAcosPass, DecomposeAsinhPass, DecomposeAtanhPass, DecomposeAtanPass, - DecomposeAvgPool2d, + DecomposeAvgPool2dPass, DecomposeBatchNormNoStatsPass, + DecomposeConvWithInt16ActivationPass, DecomposeCoshPass, DecomposeCosineSimilarityPass, DecomposeCumsumPass, @@ -47,21 +49,27 @@ DecomposeEluPass, DecomposeEmbeddingPass, DecomposeExpm1Pass, + DecomposeFloorDividePass, DecomposeGeluPass, DecomposeGluPass, - DecomposeGroupedConv, + DecomposeGroupedConvPass, DecomposeGroupNormPass, + DecomposeIntPowPass, DecomposeLayerNormPass, DecomposeLeakyReLUPass, + DecomposeLinalgVectorNormPass, DecomposeLinearPass, - DecomposeLinearVectorNormPass, DecomposeLogitPass, - DecomposeMaskedFill, - DecomposeMaxPool2DPass, + DecomposeMaskedFillPass, + DecomposeMaxPool2dPass, DecomposeMeanDimPass, DecomposeNotEqualPass, + DecomposeQuantNodesPass, + DecomposeRemainderPass, DecomposeRoundPass, + DecomposeScaledDotProductAttentionPass, DecomposeSelectPass, + DecomposeSelectScatterPass, DecomposeSignPass, DecomposeSiluPass, DecomposeSinhPass, @@ -69,251 +77,378 @@ DecomposeSoftmaxUnstablePass, DecomposeSqrtPass, DecomposeSumPass, + DecomposeTOSAUnsupportedClampPass, DecomposeVarPass, DecorateFp32toInt32CastingPass, FoldAndAnnotateQParamsPass, - FuseBatchnorm2DPass, + FuseBatchNorm2dPass, FuseConstantArgsPass, + FuseDuplicateUsersPass, FuseEqualPlaceholdersPass, FuseQuantizedActivationPass, + FuseViewCopyTransformPass, + InsertControlFlowRescalesPass, InsertInt32CastsAfterInt64PlaceholdersPass, + InsertRescaleInt32Pass, InsertRescalePass, InsertTableOpsPass, MatchArgDtypePass, MatchArgRanksPass, - QuantizeOperatorArguments, + NormalizeWhileInitialArgsPass, + PromoteBoolOperandsPass, + QuantizeClampArgumentsPass, + RemoveGetItemPass, + RemoveGraphAssertsPass, RemoveNoopPass, - ReplaceInfValues, - ReplaceScalarWithTensorArgPassTOSABI, - ReplaceScalarWithTensorArgPassTOSAMI, - RetraceFoldedDtypesPass, + ReplaceInfAndLimitValuesPass, + ReplaceScalarWithTensorByProfilePass, + RewriteConvPass, + RewriteMatmulPass, + RewriteUpsamplePass, ScalarsToAttributePass, SizeAdjustInputPass, ToTosaMemoryFormatPass, UnsqueezeBeforeRepeatPass, UnsqueezeScalarPlaceholdersPass, ) - +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec +from executorch.backends.arm.common.pipeline_config import ( + ArmPassPipelineConfig, + FuseDuplicateUsersConfig, + SoftmaxDecompositionConfig, +) from executorch.backends.arm.tosa.specification import ( + tosa_spec_in_set, TosaLoweringContext, TosaSpecification, ) -from 
executorch.backends.transforms.decompose_sdpa import ( - DecomposeScaledDotProductAttention, -) -from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform -from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass from executorch.exir import ExportedProgram +from executorch.exir.pass_base import ExportPass from executorch.exir.pass_manager import PassManager -from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass from torch.fx import GraphModule +from torch.fx.passes.infra.pass_base import PassResult +from torch.nn.modules import Module +logger = logging.getLogger(__name__) -class ArmPassManager(PassManager): - def __init__(self, tosa_spec: TosaSpecification) -> None: - self.tosa_spec = tosa_spec +class ArmPassManager(PassManager): + def __init__(self, compile_spec: ArmCompileSpec) -> None: + self.compile_spec = compile_spec + self.tosa_spec = compile_spec.tosa_spec + self._skip_pass_types: tuple[type, ...] = () super().__init__() + self.configure_skip_passes() + + def configure_skip_passes( + self, + override_config: ArmPassPipelineConfig | None = None, + ) -> tuple[type, ...]: + """ + Configures the pass manager to skip certain passes based on the ArmPassPipelineConfig class + found in the compile spec. + """ + skip_set: set[type] = set() + + config = override_config or self.compile_spec.get_pass_pipeline_config() + logger.debug(f"Skip Config: {config}") + + match config.softmax: + case SoftmaxDecompositionConfig.MASKED: + skip_set.add(DecomposeSoftmaxUnstablePass) + case SoftmaxDecompositionConfig.UNSTABLE: + skip_set.add(DecomposeSoftmaxPass) + skip_set.add(DecomposeMaskedFillPass) + + if config.fuse_duplicate_users is FuseDuplicateUsersConfig.DISABLED: + skip_set.add(FuseDuplicateUsersPass) + + self._skip_pass_types = tuple(skip_set) + skip_names = [skipped_pass.__name__ for skipped_pass in self._skip_pass_types] + logger.debug(f"Passes in skip list: {skip_names}") + + return self._skip_pass_types + + def validate_constraints_mandatory(self): + """ + Validates that necessary passes have run before transforming to backend. + + Note that this differs from the original validate_constraints function, which + only checks the order of passes. 
+ """ + passes_to_run = defaultdict(list) + + for current_pass in self.passes: + current_pass_name = ArmPass.get_name(current_pass) + for required_pass_name in ArmPass.get_required_passes(current_pass): + passes_to_run[required_pass_name].append(current_pass_name) + + passes_to_run.pop(current_pass_name, None) + + if len(passes_to_run) > 0: + error_msg = "The following constraints for passes are not met:\n" + for required_pass, requiring_passes in passes_to_run.items(): + for requiring_pass in requiring_passes: + error_msg += ( + f" - {required_pass} must run after {requiring_pass}\n" + ) + + raise RuntimeError(error_msg) + + def add_passes(self, passes: Sequence[ExportPass | None]): + for p in passes: + if p is not None: + self.add_pass(p) def _transform(self, graph_module: GraphModule): with TosaLoweringContext(self.tosa_spec): return self(graph_module).graph_module - def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: + def add_pass(self, pipeline_pass): + if type(pipeline_pass) in self._skip_pass_types: + return + super().add_pass(pipeline_pass) + + def _tosa_pipeline( + self, exported_program: ExportedProgram, graph_module: GraphModule + ) -> GraphModule: + # Preprocessing passes self.add_pass(AnnotateOutputDimOrderPass()) - self.add_pass(FuseQuantizedActivationPass()) - self.add_pass(RemoveGetItemPass()) - self.add_pass(ConvertSplitToSlicePass()) - self.add_pass(ConvertMmToBmmPass()) - self.add_pass(DecomposeLinearVectorNormPass()) - self.add_pass( - DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec) + + # Node transformation passes (pre q/dq folding) + self.add_passes( + [ + FuseQuantizedActivationPass(), + ConvertToClampPass(), + DecomposeTOSAUnsupportedClampPass(), + DecomposeGroupNormPass(), + DecomposeLayerNormPass(), + DecomposeVarPass(), + DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec), + AnnotateDecomposedMatmulPass(), + ConvertELUParamsPass(), + NormalizeWhileInitialArgsPass(use_exir_clone=True), + ] ) - self.add_pass(ConvertFullLikeToFullPass()) - self.add_pass(ConvertToClampPass()) - self.add_pass(ConvertMinMaxPass()) - self.add_pass(ConvertAnyDefaultDimDimsPass()) - self.add_pass(MatchArgDtypePass()) - if self.tosa_spec.is_U55_subset: - self.add_pass(CastToInt32Pass()) - - self.add_pass(CastBoolToInt8Pass()) - self.add_pass(ReplaceScalarWithTensorArgPassTOSABI()) - self.add_pass(AnnotateDecomposedMatmulPass()) - self.add_pass(QuantizeOperatorArguments()) - self.add_pass(ConvertELUParamsPass()) - self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg] - self.add_pass(RetraceFoldedDtypesPass()) - self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program)) - self.add_pass(MatchArgRanksPass(exported_program)) - if self.tosa_spec.is_U55_subset: - self.add_pass(BroadcastArgsPass()) - self.add_pass(DecomposeLinearPass()) - self.add_pass(DecomposeAdaptiveAvgPool2dPass()) - self.add_pass(DecomposeAvgPool2d()) - self.add_pass(ComputeConstantOpsAOT(exported_program)) - - self.add_pass(DecomposeGroupedConv()) - self.add_pass(ConvertExpandCopyToRepeatPass()) - self.add_pass(UnsqueezeBeforeRepeatPass()) - self.add_pass(CastInt64BuffersToInt32Pass(exported_program)) - self.add_pass(DecomposeSumPass()) - self.add_pass(DecomposeCumsumPass(exported_program)) - self.add_pass(Conv1dUnsqueezePass()) - self.add_pass(DecomposeMaxPool2DPass()) - self.add_pass(SizeAdjustInputPass()) - self.add_pass(DecomposeSelectPass()) - self.add_pass(ConvertSqueezesToViewPass()) - - 
self.add_pass(FuseViewCopyTransform()) - self.add_pass(FuseConstantArgsPass(exported_program)) - self.add_pass(AddBiasPass(exported_program)) - - self.add_pass(InsertTableOpsPass(exported_program)) - self.add_pass(FuseEqualPlaceholdersPass(exported_program)) - self.add_pass(ToTosaMemoryFormatPass(exported_program)) - self.add_pass(RemoveNoopPass()) - self.add_pass(InsertRescalePass()) - - return self._transform(exported_program.graph_module) - - def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: - self.add_pass(AnnotateOutputDimOrderPass()) - self.add_pass(DecomposeExpm1Pass()) - self.add_pass(DecomposeLogitPass()) - self.add_pass(DecomposeMaskedFill()) - self.add_pass(DecomposeRoundPass()) - self.add_pass(DecomposeAcoshPass()) - self.add_pass(DecomposeAsinhPass()) - self.add_pass(DecomposeCoshPass()) - self.add_pass(DecomposeAsinAndAcosPass()) - self.add_pass(DecomposeSqrtPass()) - self.add_pass(DecomposeAtanPass()) - self.add_pass(DecomposeAtanhPass()) - self.add_pass(DecomposeAddmmPass()) - self.add_pass(DecomposeEluPass()) - self.add_pass(DecomposeExpm1Pass()) - self.add_pass(ConvertIntPowToMuls()) - self.add_pass(CastBoolToInt8Pass()) - self.add_pass(DecomposeSinhPass()) - self.add_pass(DecomposeSignPass()) - self.add_pass(DecomposeDivTensorModePass()) - self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI()) - self.add_pass(DecomposeEmbeddingPass()) - self.add_pass(FuseQuantizedActivationPass()) - self.add_pass(RemoveGetItemPass()) - self.add_pass(ConvertSplitToSlicePass()) - self.add_pass(FuseBatchnorm2DPass(exported_program)) - self.add_pass(ConvertMmToBmmPass()) - self.add_pass(DecomposeGluPass()) - self.add_pass(DecomposeLinearPass()) - self.add_pass(DecomposeLeakyReLUPass()) - self.add_pass(DecomposeGroupNormPass()) - self.add_pass(DecomposeLayerNormPass()) - self.add_pass(DecomposeBatchNormNoStatsPass()) - self.add_pass(DecomposeVarPass()) - self.add_pass( - DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec) + + # Fold Q/DQ nodes, insert INT8/INT32 rescales, decompose quantization nodes. + self.add_passes( + [ + FoldAndAnnotateQParamsPass(exported_program), + FuseDuplicateUsersPass(), + # TODO: DecomposeLinearPass should run after InsertRescaleInt32Pass or + # before FoldAndAnnotateQParamsPass but is unable to at the moment. 
+ # Ticket: MLETORCH-1539 + DecomposeLinearPass(), + InsertRescaleInt32Pass(), + InsertControlFlowRescalesPass(), + DecomposeQuantNodesPass(), + ] + ) + + # Node transformation passes (post q/dq folding) + self.add_passes( + [ + ConvertSplitToSlicePass(), + QuantizeClampArgumentsPass(), + RemoveGetItemPass(), + DecomposeBatchNormNoStatsPass(), + DecomposeLogitPass(), + DecomposeMaskedFillPass(), + DecomposeRoundPass(), + DecomposeAcoshPass(), + DecomposeAsinhPass(), + DecomposeCoshPass(), + DecomposeAsinAndAcosPass(), + DecomposeSqrtPass(), + DecomposeAtanPass(), + DecomposeAtanhPass(), + DecomposeAddmmPass(), + DecomposeEluPass(), + DecomposeExpm1Pass(), + DecomposeIntPowPass(), + PromoteBoolOperandsPass(), + DecomposeSinhPass(), + DecomposeSignPass(), + DecomposeFloorDividePass(), + DecomposeGeluPass(), + DecomposeAddSubAlphaPass(), + DecomposeGroupedConvPass(), + Conv1dUnsqueezePass(), + ] ) - self.add_pass(DecomposeNotEqualPass()) - self.add_pass(DecomposeDivPass()) - self.add_pass(DecomposeSoftmaxPass()) - self.add_pass(DecomposeGeluPass()) - self.add_pass(ConvertFullLikeToFullPass()) - self.add_pass(ConvertToClampPass()) - self.add_pass(ConvertMinMaxPass()) - self.add_pass(ConvertAnyDefaultDimDimsPass()) - self.add_pass(MatchArgDtypePass()) - self.add_pass(AnnotateDecomposedMatmulPass()) - self.add_pass(QuantizeOperatorArguments()) - self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg] - self.add_pass(RetraceFoldedDtypesPass()) - self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program)) - self.add_pass(MatchArgRanksPass(exported_program)) - self.add_pass(DecomposeAdaptiveAvgPool2dPass()) - self.add_pass(DecomposeAvgPool2d()) - self.add_pass( - DecorateFp32toInt32CastingPass() - ) # Require that no new fp32->int32 is introduced after this pass - self.add_pass(ComputeConstantOpsAOT(exported_program)) - - self.add_pass(DecomposeGroupedConv()) - self.add_pass(ConvertExpandCopyToRepeatPass()) - self.add_pass(UnsqueezeBeforeRepeatPass()) - self.add_pass(DecomposeSumPass()) - self.add_pass(DecomposeCumsumPass(exported_program)) - self.add_pass(Conv1dUnsqueezePass()) - self.add_pass(DecomposeMaxPool2DPass()) - self.add_pass(SizeAdjustInputPass()) - self.add_pass(DecomposeSelectPass()) - self.add_pass(ConvertSqueezesToViewPass()) - - self.add_pass(FuseViewCopyTransform()) - self.add_pass(FuseConstantArgsPass(exported_program)) - self.add_pass(CastInt64BuffersToInt32Pass(exported_program)) - self.add_pass(AddBiasPass(exported_program)) - self.add_pass(InsertTableOpsPass(exported_program)) - self.add_pass(FuseEqualPlaceholdersPass(exported_program)) - self.add_pass(ToTosaMemoryFormatPass(exported_program)) - self.add_pass(RemoveNoopPass()) - self.add_pass(InsertRescalePass()) - - return self._transform(exported_program.graph_module) - - def transform_to_backend_pipeline(self, exported_program: ExportedProgram): + + # Scalars -> tensors, match tensor dtypes and ranks. + self.add_passes( + [ + ReplaceScalarWithTensorByProfilePass(), + ConvertFullLikeToFullPass(), + MatchArgDtypePass(), + UnsqueezeScalarPlaceholdersPass(exported_program), + # TODO: Move DecomposeNotEqualPass to before or after this block of + # passes. 
Ticket: MLETORCH-1540 + DecomposeNotEqualPass(), + MatchArgRanksPass(exported_program), + ] + ) + + # Node transformation passes (post scalar-removal) + self.add_passes( + [ + DecomposeRemainderPass(), + DecomposeDivTensorModePass(), + DecomposeEmbeddingPass(), + FuseBatchNorm2dPass(exported_program), + ConvertMmToBmmPass(), + DecomposeGluPass(), + DecomposeLeakyReLUPass(), + DecomposeDivPass(), + DecomposeSoftmaxPass(), + ConvertMinMaxPass(), + DecomposeAnyPass(), + DecomposeAdaptiveAvgPool2dPass(), + DecomposeAvgPool2dPass(), + DecorateFp32toInt32CastingPass(), + ComputeConstantOpsAOTPass(exported_program), + FuseConstantArgsPass(exported_program), + ConvertExpandCopyToRepeatPass(), + UnsqueezeBeforeRepeatPass(), + DecomposeCumsumPass(exported_program), + DecomposeMaxPool2dPass(), + SizeAdjustInputPass(), + DecomposeSelectPass(), + ConvertSqueezesToViewPass(), + CastToInt32Pass(), + BroadcastArgsPass(), + ConvertPermuteSingletonToViewPass(), + FuseViewCopyTransformPass(), + DecomposeConvWithInt16ActivationPass(), + DecomposeSumPass(), + InsertTableOpsPass(exported_program), + ] + ) + + # Aten -> TOSA transformation passes + self.add_passes( + [ + RewriteUpsamplePass(), + RewriteConvPass(exported_program), + RewriteMatmulPass(), + ] + ) + + # Postprocessing/cleanup passes + self.add_passes( + [ + CastInt64BuffersToInt32Pass(exported_program), + FuseEqualPlaceholdersPass(exported_program), + ToTosaMemoryFormatPass(exported_program), + RemoveNoopPass(), + InsertRescalePass(), + ] + ) + + self.validate_constraints_mandatory() + return self._transform(graph_module) + + def transform_to_backend_pipeline( + self, exported_program: ExportedProgram, graph_module: GraphModule + ): """Apply passes before transforming program to backend""" - if self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+FP"): - return self._tosa_FP_pipeline(exported_program) - elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+INT"): - return self._tosa_INT_pipeline(exported_program) - else: - raise NotImplementedError( - f"No pass pipeline implemented for {self.tosa_spec=}" + + if not tosa_spec_in_set( + self.tosa_spec, + { + TosaSpecification.create_from_string("TOSA-1.0+FP"), + TosaSpecification.create_from_string("TOSA-1.0+INT"), + }, + ): + raise RuntimeError( + f"No pass pipeline found for TOSA specification: {self.tosa_spec}" ) + return self._tosa_pipeline(exported_program, graph_module) + def transform_for_annotation_pipeline(self, graph_module: GraphModule): - self.add_pass( - RemoveGraphAssertsPass() - ) # ConvertInt64ConstOpsToInt32Pass requires this pass to remove the assertation in Graph - self.add_pass(ConvertInt64ConstOpsToInt32Pass()) - self.add_pass(ConvertInt64OutputOpsToInt32Pass()) - self.add_pass(InsertInt32CastsAfterInt64PlaceholdersPass()) - self.add_pass(DecomposeEmbeddingPass()) - self.add_pass(DecomposeScaledDotProductAttention()) - self.add_pass(DecomposeRoundPass()) - self.add_pass(DecomposeLogitPass()) - self.add_pass(CastBoolToInt8Pass()) - self.add_pass(DecomposeSignPass()) - self.add_pass(DecomposeAddmmPass()) - self.add_pass(DecomposeDivTensorModePass()) - self.add_pass(ReplaceScalarWithTensorArgPassTOSABI()) - self.add_pass(ScalarsToAttributePass()) - self.add_pass(DecomposeGroupNormPass()) - self.add_pass(DecomposeLayerNormPass()) - self.add_pass(DecomposeVarPass()) - self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec)) - self.add_pass(DecomposeNotEqualPass()) - self.add_pass(DecomposeCosineSimilarityPass()) - self.add_pass(DecomposeGluPass()) - 
self.add_pass(DecomposeDivPass()) - self.add_pass(DecomposeLeakyReLUPass()) - self.add_pass(DecomposeLinearVectorNormPass()) - self.add_pass(DecomposeSqrtPass()) - self.add_pass(DecomposeSiluPass()) - self.add_pass(DecomposeAvgPool2d()) - - if self.tosa_spec.is_U55_subset: - # Numerically stable softmax uses amax which is not supported on Ethos-U55 - self.add_pass(DecomposeSoftmaxUnstablePass()) - else: - self.add_pass(DecomposeSoftmaxPass()) - - self.add_pass(ConvertMinMaxPass()) - self.add_pass(ReplaceInfValues()) - self.add_pass(DecomposeSumPass()) - - if not self.tosa_spec.is_U55_subset: - # Uses where which is not supported on Ethos-U55 - self.add_pass(DecomposeMaskedFill()) + # Preprocessing passes + self.add_pass(RemoveGraphAssertsPass()) + + # Transformation passes (pre scalar -> tensor) + self.add_passes( + [ + DecomposeSelectScatterPass(), + ConvertInt64ConstOpsToInt32Pass(), + ConvertInt64OutputOpsToInt32Pass(), + InsertInt32CastsAfterInt64PlaceholdersPass(), + DecomposeEmbeddingPass(), + DecomposeScaledDotProductAttentionPass(), + DecomposeRoundPass(), + DecomposeLogitPass(), + PromoteBoolOperandsPass(), + DecomposeSignPass(), + DecomposeAddmmPass(), + DecomposeRemainderPass(), + DecomposeFloorDividePass(), + DecomposeDivTensorModePass(), + ] + ) + + # Scalars -> tensors + self.add_passes( + [ + ReplaceScalarWithTensorByProfilePass(), + ScalarsToAttributePass(), + ] + ) + + # Transformation passes (post scalar removal) + self.add_passes( + [ + NormalizeWhileInitialArgsPass(use_exir_clone=False), + DecomposeAddSubAlphaPass(), + DecomposeGroupNormPass(), + DecomposeLayerNormPass(), + DecomposeVarPass(), + DecomposeMeanDimPass(graph_module, self.tosa_spec), + DecomposeNotEqualPass(), + DecomposeCosineSimilarityPass(), + DecomposeGluPass(), + DecomposeDivPass(), + DecomposeLeakyReLUPass(), + DecomposeLinalgVectorNormPass(), + DecomposeSqrtPass(), + DecomposeSiluPass(), + DecomposeAvgPool2dPass(), + DecomposeSoftmaxUnstablePass(), + DecomposeSoftmaxPass(), + ConvertMinMaxPass(), + ] + ) + + # Postprocessing passes + self.add_passes( + [ + ReplaceInfAndLimitValuesPass(), + DecomposeMaskedFillPass(), + ] + ) return self._transform(graph_module) + + def __call__(self, module: Module) -> PassResult: + try: + return super().__call__(module) + except Exception as e: + first_exception = e.__cause__ or e.__context__ or e + import re + + message = e.args[0] + m = re.search(r"An error occurred when running the '([^']+)' pass", message) + if m: + pass_name = m.group(1) + first_exception.args = ( + f"{pass_name}: {first_exception.args[0]}", + *first_exception.args[1:], + ) + raise first_exception diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index 71e2030958f..006d4fff953 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -5,7 +5,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
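For reference, the configure_skip_passes method added to ArmPassManager above turns the compile spec's pipeline config into a set of pass classes to drop. A hedged sketch of overriding it; the ArmPassPipelineConfig constructor arguments are assumed from the two fields the method reads and are not shown in this diff:

from executorch.backends.arm.common.pipeline_config import (
    ArmPassPipelineConfig,
    FuseDuplicateUsersConfig,
    SoftmaxDecompositionConfig,
)

# Assumed constructor, built from the fields configure_skip_passes inspects.
config = ArmPassPipelineConfig(
    softmax=SoftmaxDecompositionConfig.UNSTABLE,              # skips DecomposeSoftmaxPass and DecomposeMaskedFillPass
    fuse_duplicate_users=FuseDuplicateUsersConfig.DISABLED,   # skips FuseDuplicateUsersPass
)
# pass_manager: an already-constructed ArmPassManager.
pass_manager.configure_skip_passes(override_config=config)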
-# pyre-unsafe import traceback from inspect import isclass @@ -14,8 +13,10 @@ import torch import torch.fx from executorch.backends.arm.common.debug import get_node_debug_info +from executorch.backends.arm.common.type import ensure_type from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload from torch._export.utils import ( get_buffer, @@ -30,11 +31,25 @@ from torch.export.graph_signature import InputKind +def is_submodule_node(node: torch.fx.Node): + if node.op not in ("get_attr", "placeholder"): + return False + try: + node.graph.owning_module.get_submodule(node.target) + except AttributeError: + return False + return True + + def is_get_attr_node(node: torch.fx.Node) -> bool: """ - Returns true if the given node is a get attr node for a tensor of the model + Returns true if the given node is a get attr node for a tensor of the model. """ - return isinstance(node, torch.fx.Node) and node.op == "get_attr" + return ( + isinstance(node, torch.fx.Node) + and node.op == "get_attr" + and not is_submodule_node(node) + ) def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool: @@ -82,22 +97,38 @@ def get_param_tensor( elif is_lifted_tensor_constant(exp_prog, node): return get_lifted_tensor_constant(exp_prog, node) elif is_get_attr_node(node): + target_node = ensure_type(str, node.target) # This is a hack to support both lifted and unlifted graph try: - return getattr(node.graph.owning_module, node.target) # type: ignore[arg-type] + return getattr(node.graph.owning_module, target_node) except AttributeError: - return getattr(exp_prog.graph_module, node.target) # type: ignore[arg-type] + return getattr(exp_prog.graph_module, target_node) raise RuntimeError(f"unsupported param type, {node.op}.") +def expand_around_channel(param: Sequence[int] | int, spatial_rank: int) -> list[int]: + """ + Expand a scalar or 1-D parameter around the channel dimension into a broadcastable + shape while preserving the channel location. + """ + if isinstance(param, int): + return [param] * spatial_rank + + param_list = list(param) + if len(param_list) == 1 and spatial_rank > 1: + param_list = param_list * spatial_rank + return param_list + + def create_node( graph: torch.fx.Graph, - op_target: OpOverload, + op_target: OpOverload | EdgeOpOverload, args: tuple = (), kwargs: Optional[dict] = None, quantize: bool = False, q_params: Optional[tuple] = None, from_node: Optional[torch.fx.Node] = None, + inherit_qparams: bool = False, ): """ Adds a node to 'graph'. graph.inserting_before/after() should be used before the call to decide where to insert the node. 
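The new expand_around_channel helper above expands a scalar or single-element parameter to one value per spatial dimension and leaves already per-dimension parameters untouched. A quick usage sketch; the inputs are illustrative, since no call sites appear in this hunk:

expand_around_channel(2, spatial_rank=3)         # -> [2, 2, 2]
expand_around_channel([1], spatial_rank=2)       # -> [1, 1]
expand_around_channel([1, 2], spatial_rank=2)    # -> [1, 2] (already one value per dimension)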
@@ -116,6 +147,14 @@ def create_node( keys = from_node.meta.keys() for key in keys: new_meta[key] = from_node.meta[key] + if not inherit_qparams: + if "input_qparams" in new_meta: + new_meta["input_qparams"] = {} + if "output_qparams" in new_meta: + new_meta["output_qparams"] = {} + elif inherit_qparams: + raise ValueError("inherit_qparams is only valid when from_node is given") + old_stack_trace = new_meta.get("stack_trace", "") new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}" node.meta = new_meta @@ -200,7 +239,7 @@ def get_node_arg(args: list | dict, key: int | str | type, default_value=None): f"Out of bounds index {key} for getting value in args (of size {len(args)})" ) elif isinstance(key, str): - return args.get(key, default_value) # type: ignore[union-attr] # pyre-ignore[16] + return args.get(key, default_value) # type: ignore[union-attr] elif isclass(key): for arg in args: if isinstance(arg, key): diff --git a/backends/arm/_passes/broadcast_args_pass.py b/backends/arm/_passes/broadcast_args_pass.py index f125ba13ff4..d11fb779280 100644 --- a/backends/arm/_passes/broadcast_args_pass.py +++ b/backends/arm/_passes/broadcast_args_pass.py @@ -3,16 +3,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( create_node, get_first_fake_tensor, ) +from executorch.backends.arm.tosa.specification import get_context_spec from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import GraphModule, Node @@ -22,6 +25,8 @@ class BroadcastArgsPass(ArmPass): This is done when more than one arg needs broadcasting. """ + _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = { exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.sub.Tensor, @@ -30,6 +35,9 @@ class BroadcastArgsPass(ArmPass): } def call(self, graph_module: GraphModule) -> PassResult: + tosa_spec = get_context_spec() + if not tosa_spec.is_U55_subset: + return PassResult(graph_module, False) for node in graph_module.graph.nodes: if node.op != "call_function" or node.target not in self.targeted_ops: continue @@ -55,6 +63,7 @@ def call(self, graph_module: GraphModule) -> PassResult: args=(arg, multiples), kwargs={}, from_node=node, + inherit_qparams=False, ) node.replace_input_with(arg, repeat) diff --git a/backends/arm/_passes/cast_bool_to_int8_pass.py b/backends/arm/_passes/cast_bool_to_int8_pass.py deleted file mode 100644 index 1352671b01e..00000000000 --- a/backends/arm/_passes/cast_bool_to_int8_pass.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# The TOSA BITWISE_AND, BITWISE_OR, and BITWISE_XOR don't handle bool as input -# If input/output is bool lest add a cast/conversion pass before/after to/from int8. 
- -import torch - -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass - - -class CastBoolToInt8Pass(ExportPass): - """Casts the input to int8 if it is not already and casts back the output to the original input dtype.""" - - targeted_ops = { - exir_ops.edge.aten.bitwise_and.Tensor, - exir_ops.edge.aten.bitwise_or.Tensor, - exir_ops.edge.aten.bitwise_xor.Tensor, - } - - def call_operator(self, op, args, kwargs, meta): - if op not in self.targeted_ops: - return super().call_operator(op, args, kwargs, meta) - - new_args: list = [] - did_cast = False - for arg in args: - if arg.data.dtype == torch.bool: - new_args.append( - super().call_operator( - exir_ops.edge.dim_order_ops._to_dim_order_copy.default, - (arg,), - {"dtype": torch.int8}, - meta, - ) - ) - did_cast = True - else: - new_args.append(arg) - - output = super().call_operator( - op, - tuple(new_args), - {}, - meta, - ) - - if did_cast: - output = super().call_operator( - exir_ops.edge.dim_order_ops._to_dim_order_copy.default, - (output,), - {"dtype": args[0].data.dtype}, - meta, - ) - return output diff --git a/backends/arm/_passes/cast_int64_pass.py b/backends/arm/_passes/cast_int64_pass.py index 8052c8fd2ce..02a9cbeceaf 100644 --- a/backends/arm/_passes/cast_int64_pass.py +++ b/backends/arm/_passes/cast_int64_pass.py @@ -3,24 +3,28 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe import logging +from typing import Set, Type import torch +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.exir.pass_base import ExportPass, PassResult from torch._export.utils import is_buffer +from torch.export import ExportedProgram logger = logging.getLogger(__name__) -class CastInt64BuffersToInt32Pass(ExportPass): +class CastInt64BuffersToInt32Pass(ArmPass): """ Cast int64 buffers to int32 if the int64 data is in int32 range. """ - def __init__(self, exported_program: torch.export.ExportedProgram): - super(CastInt64BuffersToInt32Pass, self).__init__() + _passes_required_after: Set[Type[ExportPass]] = set() + + def __init__(self, exported_program: ExportedProgram): + super().__init__() self.exported_program = exported_program def _assert_within_int32(self, tensor: torch.Tensor, node: torch.fx.Node): @@ -37,6 +41,8 @@ def _to_int32(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: if len(node.users) == 0: continue + if "val" not in node.meta: + continue fake_tensor = node.meta["val"] if not isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor): continue diff --git a/backends/arm/_passes/cast_to_int32_pass.py b/backends/arm/_passes/cast_to_int32_pass.py index c4b009e2b88..40f7e347b0f 100644 --- a/backends/arm/_passes/cast_to_int32_pass.py +++ b/backends/arm/_passes/cast_to_int32_pass.py @@ -3,20 +3,33 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
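
For context on CastInt64BuffersToInt32Pass above: narrowing an int64 buffer is only safe when every value fits in int32, which can be checked along these lines (illustrative sketch with made-up values, not part of the patch):

    import torch

    buf = torch.tensor([0, 1, 2**20], dtype=torch.int64)
    info = torch.iinfo(torch.int32)
    # Only narrow when every element is representable in int32.
    assert buf.min() >= info.min and buf.max() <= info.max
    narrowed = buf.to(torch.int32)
    assert torch.equal(narrowed.to(torch.int64), buf)
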
+from typing import Set, Type + import torch +from executorch.backends.arm._passes.arm_pass import ArmPass + +from executorch.backends.arm.tosa.specification import get_context_spec from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass +from executorch.exir.pass_base import ExportPass, PassResult -class CastToInt32Pass(ExportPass): +class CastToInt32Pass(ArmPass): """Casts the input to int32 if it is not already and casts back the output to the original input dtype.""" + _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = { exir_ops.edge.aten.bitwise_left_shift.Tensor, exir_ops.edge.aten.bitwise_right_shift.Tensor, } + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + tosa_spec = get_context_spec() + if not tosa_spec.is_U55_subset: + return PassResult(graph_module, False) + return super().call(graph_module) + def call_operator(self, op, args, kwargs, meta): if op not in self.targeted_ops: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/conv1d_unsqueeze_pass.py b/backends/arm/_passes/conv1d_unsqueeze_pass.py index 56f674e9066..f0b1026577b 100644 --- a/backends/arm/_passes/conv1d_unsqueeze_pass.py +++ b/backends/arm/_passes/conv1d_unsqueeze_pass.py @@ -6,11 +6,18 @@ # LICENSE file in the root directory of this source tree. +from typing import Set, Type + +from executorch.backends.arm._passes import ArmPass + +from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass +from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass + from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass -class Conv1dUnsqueezePass(ExportPass): +class Conv1dUnsqueezePass(ArmPass): """ This pass is used to change conv1d ops into conv2d since TOSA only supports 2d and 3d convolution. This is done by modifying the graph to do the @@ -21,6 +28,11 @@ class Conv1dUnsqueezePass(ExportPass): 3) squeeze the output back down to 3d. 
""" + _passes_required_after: Set[Type[ExportPass]] = { + RewriteConvPass, + SizeAdjustInputPass, + } + def call_operator(self, op, args, kwargs, meta): if op != exir_ops.edge.aten.convolution.default: return super().call_operator(op, args, kwargs, meta) @@ -28,10 +40,18 @@ def call_operator(self, op, args, kwargs, meta): if len(stride) != 1: return super().call_operator(op, args, kwargs, meta) + x_meta = meta.copy() + x_meta.data["input_qparams"] = {} + x_meta.data["output_qparams"] = {} + x = args[0] x_unsqueezed_shape = list(x.data.shape) + [1] x = super().call_operator( - exir_ops.edge.aten.view_copy.default, (x, x_unsqueezed_shape), {}, meta + exir_ops.edge.aten.view_copy.default, + (x, x_unsqueezed_shape), + {}, + x_meta, + updated=True, ) w_meta = meta.copy() @@ -41,7 +61,11 @@ def call_operator(self, op, args, kwargs, meta): w = args[1] w_unsqueezed_shape = list(w.data.shape) + [1] w = super().call_operator( - exir_ops.edge.aten.view_copy.default, (w, w_unsqueezed_shape), {}, w_meta + exir_ops.edge.aten.view_copy.default, + (w, w_unsqueezed_shape), + {}, + w_meta, + updated=True, ) new_args = ( @@ -56,12 +80,19 @@ def call_operator(self, op, args, kwargs, meta): args[8], ) x = super().call_operator( - exir_ops.edge.aten.convolution.default, new_args, kwargs, meta + exir_ops.edge.aten.convolution.default, new_args, kwargs, meta, updated=True ) + x_squeezed_meta = meta.copy() + x_squeezed_meta.data["input_qparams"] = {} + x_squeezed_meta.data["output_qparams"] = {} x_squeezed_shape = list(x.data.shape)[:-1] x = super().call_operator( - exir_ops.edge.aten.view_copy.default, (x, x_squeezed_shape), {}, meta + exir_ops.edge.aten.view_copy.default, + (x, x_squeezed_shape), + {}, + x_squeezed_meta, + updated=True, ) return x diff --git a/backends/arm/_passes/convert_any_default_dim_dims_pass.py b/backends/arm/_passes/convert_any_default_dim_dims_pass.py deleted file mode 100644 index 7085f17add0..00000000000 --- a/backends/arm/_passes/convert_any_default_dim_dims_pass.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import torch -from executorch.exir.dialects._ops import ( # type: ignore[import-not-found] - ops as exir_ops, -) -from executorch.exir.pass_base import ( # type: ignore[import-not-found] - ExportPass, - PassResult, -) - - -class ConvertAnyDefaultDimDimsPass(ExportPass): - """ - Converts any.default, any.dim and any.dims to a sequence of any.dim by unrolling multi-dimensional reduction. - Please refer to KeepDimsFalseToSqueezePass for an explanation of this coversion. - - Example 1 - Original: - any() # x.shape: [dim1, dim2, ..., dimn] - After pass: - any.dim(dim1, keepdim = True) - any.dim(dim2, keepdim = True) - ... 
- any.dim(dimn, keepdim = True) - squeeze(dim = [dim1, dim2, ...., dimn]) - - Example 2 - Original: - any.dim(dim1, keepdim = False) - After pass: - any.dim(dim1, keepdim = True) - squeeze(dim = [dim1]) - - Example 3 - Original: - any.dims([dim1, dim2], keepdim = False) - After pass: - any.dim(dim1, keepdim = True) - any.dim(dim2, keepdim = True) - squeeze(dim = [dim1, dim2]) - """ - - def call(self, graph_module: torch.fx.GraphModule): - modified = False - for node in graph_module.graph.nodes: - if node.op != "call_function": - continue - if node.target not in [ - exir_ops.edge.aten.any.default, - exir_ops.edge.aten.any.dim, - exir_ops.edge.aten.any.dims, - ]: - continue - - if len(node.args) == 1: - # any.default(input) - input_node = (node.args)[0] - dims = range(len(input_node.meta["val"].shape)) - keepdim = False - elif len(node.args) == 2: - # any.dim/dims(input, dims=dims) - input_node, dims = node.args - keepdim = False - elif len(node.args) == 3: - # any.dim/dims(input, dims=dims, keepdim=keepdim) - input_node, dims, keepdim = node.args - else: - raise RuntimeError( - f"Unexpected arg size {len(node.args)} in {node.name}" - ) - try: - iter(dims) - except: - dims = [dims] # type: ignore[assignment] - else: - dims = list(dims) # type: ignore[assignment] - - # Unroll multi-dimensional reduction and keep-dims arg - with graph_module.graph.inserting_before(node): - for dim in dims: - args = (input_node, dim, True) - input_node = graph_module.graph.create_node( - "call_function", exir_ops.edge.aten.any.dim, args, node.kwargs - ) - - if not keepdim: - args = (input_node, dims) # type: ignore[assignment] - input_node = graph_module.graph.create_node( - "call_function", - exir_ops.edge.aten.squeeze_copy.dims, - args, - ) - - node.replace_all_uses_with(input_node) - modified = True - - if modified: - graph_module.graph.eliminate_dead_code() - graph_module.recompile() - graph_module = super().call(graph_module).graph_module - - return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/convert_elu_params.py b/backends/arm/_passes/convert_elu_params.py index 7da58ae4bb4..737ea85a156 100644 --- a/backends/arm/_passes/convert_elu_params.py +++ b/backends/arm/_passes/convert_elu_params.py @@ -3,13 +3,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm.constants import DQ_OPS from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult -class ConvertELUParamsPass(ExportPass): +class ConvertELUParamsPass(ArmPass): """ Pass to convert the input_scale kwarg of ELU operator from float to int. @@ -18,6 +22,8 @@ class ConvertELUParamsPass(ExportPass): the value of input_scale is, as long as that value is not 1. 
""" + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module: torch.fx.GraphModule): modified_graph = False graph = graph_module.graph @@ -25,8 +31,16 @@ def call(self, graph_module: torch.fx.GraphModule): op="call_function", target=exir_ops.edge.aten.elu.default ) for node in node_list: + input_node = node.all_input_nodes[0] + is_quantized = ( + input_node.op == "call_function" and input_node.target in DQ_OPS + ) + if not is_quantized: + continue with graph.inserting_after(node): - replace_node = create_node(graph, exir_ops.edge.aten.elu.default) + replace_node = create_node( + graph, exir_ops.edge.aten.elu.default, from_node=node + ) old_args = list(node.args) alpha = old_args[1] if len(old_args) > 1 else 1.0 diff --git a/backends/arm/_passes/convert_expand_copy_to_repeat.py b/backends/arm/_passes/convert_expand_copy_to_repeat.py index ee509c7ebb5..0cd306086cb 100644 --- a/backends/arm/_passes/convert_expand_copy_to_repeat.py +++ b/backends/arm/_passes/convert_expand_copy_to_repeat.py @@ -3,13 +3,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe import logging -from typing import cast +from typing import cast, Set, Type import torch +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import ( + UnsqueezeBeforeRepeatPass, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -17,6 +20,7 @@ def calculate_multiples(args): + """Returns expand args converted to repeat args, and whether the expand changes the rank""" input_node_or_tensor = args[0] if isinstance(input_node_or_tensor, torch.fx.node.Node): @@ -42,14 +46,16 @@ def calculate_multiples(args): multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1 for i in range(expanded_rank) ] - return multiples + return multiples, expanded_rank != len(input_shape) -class ConvertExpandCopyToRepeatPass(ExportPass): +class ConvertExpandCopyToRepeatPass(ArmPass): """ Replace expand copy with repeat since it is a repeat that can only repeat singleton dimensions. """ + _passes_required_after: Set[Type[ExportPass]] = {UnsqueezeBeforeRepeatPass} + expand_copy = exir_ops.edge.aten.expand_copy.default repeat = exir_ops.edge.aten.repeat.default @@ -57,9 +63,9 @@ def call_operator(self, op, args, kwargs, meta): if op != self.expand_copy: return super().call_operator(op, args, kwargs, meta) - multiples = calculate_multiples(args) + multiples, changes_rank = calculate_multiples(args) - if all((x == 1 for x in multiples)): + if all((x == 1 for x in multiples)) and not changes_rank: # All dimensions/repetitions occur only once. Remove node # altogether since it's in practice just a copy. logger.warning("Found redundant expand node (no-op). Removing it.") diff --git a/backends/arm/_passes/convert_full_like_to_full_pass.py b/backends/arm/_passes/convert_full_like_to_full_pass.py index 234e2ecda82..becb0b7f971 100644 --- a/backends/arm/_passes/convert_full_like_to_full_pass.py +++ b/backends/arm/_passes/convert_full_like_to_full_pass.py @@ -3,11 +3,18 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from typing import Set, Type + +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) + from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass -class ConvertFullLikeToFullPass(ExportPass): +class ConvertFullLikeToFullPass(ArmPass): """As per the full_like pytorch documentation, `torch.full_like(input, fill_value)` is equivalent to `torch.full(input.size(), @@ -19,6 +26,8 @@ class ConvertFullLikeToFullPass(ExportPass): Skip layout and device since it's not relevant for our backend. """ + _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass} + def call_operator(self, op, args, kwargs, meta): if op not in [ exir_ops.edge.aten.full_like.default, diff --git a/backends/arm/_passes/convert_int64_const_ops_to_int32.py b/backends/arm/_passes/convert_int64_const_ops_to_int32.py index 704c89dbd78..85fcf715f07 100644 --- a/backends/arm/_passes/convert_int64_const_ops_to_int32.py +++ b/backends/arm/_passes/convert_int64_const_ops_to_int32.py @@ -3,13 +3,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe - import logging +from typing import Set, Type import torch -from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) from executorch.exir.pass_base import ExportPass, PassResult @@ -18,7 +20,7 @@ INT32_MAX = torch.iinfo(torch.int32).max -class ConvertInt64ConstOpsToInt32Pass(ExportPass): +class ConvertInt64ConstOpsToInt32Pass(ArmPass): """ Rewrite constant ops that produce int64 to int32 where safe. @@ -30,6 +32,8 @@ class ConvertInt64ConstOpsToInt32Pass(ExportPass): 5. `torch.tensor` """ + _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass} + torch_ops = [ torch.ops.aten.full.default, torch.ops.aten.arange.default, @@ -45,7 +49,10 @@ def call(self, graph_module: torch.fx.GraphModule): if node.op != "call_function": continue - if node.target not in ComputeConstantOpsAOT.targeted_ops + self.torch_ops: + if ( + node.target + not in ComputeConstantOpsAOTPass.targeted_ops + self.torch_ops + ): continue data = node.target(*node.args, **node.kwargs) diff --git a/backends/arm/_passes/convert_int64_output_ops_to_int32.py b/backends/arm/_passes/convert_int64_output_ops_to_int32.py index 788201be6c8..048219198b8 100644 --- a/backends/arm/_passes/convert_int64_output_ops_to_int32.py +++ b/backends/arm/_passes/convert_int64_output_ops_to_int32.py @@ -3,12 +3,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe - import logging +from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( create_node, get_first_fake_tensor, @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) -class ConvertInt64OutputOpsToInt32Pass(ExportPass): +class ConvertInt64OutputOpsToInt32Pass(ArmPass): """ Rewrites or removes operations that produce int64 outputs, converting them to int32 where possible. @@ -44,6 +44,8 @@ class ConvertInt64OutputOpsToInt32Pass(ExportPass): the int32 range. 
""" + _passes_required_after: Set[Type[ExportPass]] = set() + aten_cast_ops = ( torch.ops.aten.to.dtype, torch.ops.aten.to.dtype_layout, diff --git a/backends/arm/_passes/convert_int_pow_to_mul.py b/backends/arm/_passes/convert_int_pow_to_mul.py deleted file mode 100644 index f22a2fd0b3c..00000000000 --- a/backends/arm/_passes/convert_int_pow_to_mul.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -from executorch.backends.arm._passes import ArmPass -from executorch.exir.dialects._ops import ops as exir_ops - - -class ConvertIntPowToMuls(ArmPass): - """ - Replaces pow with integer exponent with a series of multiplications. - Only handles pow.Tensor_Scalar and not pow.Tensor_Tensor. - Needs to be run before doing scalar to tensor conversion. - """ - - def call_operator(self, op, args, kwargs, meta): - if op != exir_ops.edge.aten.pow.Tensor_Scalar: - return super().call_operator(op, args, kwargs, meta) - - x = args[0] - exp = args[1] - - # Handle zero first and return early - if exp == 0: - # return a tensor of ones with the same shape as x - return super().call_operator( - exir_ops.edge.aten.full_like.default, (x, 1), {}, meta, True - ) - - if not isinstance(exp, int): - return super().call_operator(op, args, kwargs, meta) - - # Handle negative exponent - if exp < 0: - x = super().call_operator( - exir_ops.edge.aten.reciprocal.default, (x,), {}, meta, True - ) - exp = -exp - - res = x - - # Consider exponentiation by squaring, if exp turns out to be large. - # Now we just roll out the multiplications. - for _ in range(exp - 1): - res = super().call_operator( - exir_ops.edge.aten.mul.Tensor, (res, x), {}, meta, True - ) - - return res diff --git a/backends/arm/_passes/convert_minmax_pass.py b/backends/arm/_passes/convert_minmax_pass.py index 9f409632c20..66da43c57b4 100644 --- a/backends/arm/_passes/convert_minmax_pass.py +++ b/backends/arm/_passes/convert_minmax_pass.py @@ -3,12 +3,22 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import cast, Set, Type + import torch +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import ( + create_node, + get_first_fake_tensor, +) +from executorch.backends.arm._passes.convert_squeezes_to_view import ( + ConvertSqueezesToViewPass, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult -class ConvertMinMaxPass(ExportPass): +class ConvertMinMaxPass(ArmPass): """ Converts min/max to amin/amax and unrolls multi-dimensional reduction and keep-dims arg to be TOSA compliant. @@ -29,6 +39,8 @@ class ConvertMinMaxPass(ExportPass): squeeze(dim = [dim1, dim2]) """ + _passes_required_after: Set[Type[ExportPass]] = {ConvertSqueezesToViewPass} + def check_argmax(self, node): """ Raises a RuntimeError if the argmax value returned by the min/max op is used in the graph. @@ -94,35 +106,49 @@ def call(self, graph_module: torch.fx.GraphModule): replace_node, op, squeeze_op = self.get_variables(node) # Unwrap args - if len(node.args) == 2: + if len(node.args) == 1: + # If dims is unspecified, min/max over all dims. 
+ input_node = cast(torch.fx.Node, node.args[0]) + input_shape = get_first_fake_tensor(input_node).shape + dims = range(len(input_shape)) + keepdims = False + elif len(node.args) == 2: input_node, dims = node.args keepdims = False elif len(node.args) == 3: input_node, dims, keepdims = node.args else: - raise RuntimeError(f"Unexpected arg size in {node.name}") + raise RuntimeError( + f"Unexpected arg size {len(node.args)} in {node.name}" + ) try: - iter(dims) - except: - dims = [dims] + iter(dims) # type:ignore[assignment] + except Exception: + dims = [dims] # type:ignore[assignment] else: - dims = list(dims) + dims = list(dims) # type:ignore[assignment] # Unroll multi-dimensional reduction and keep-dims arg with graph_module.graph.inserting_before(node): for dim in dims: args = (input_node, dim, True) - input_node = graph_module.graph.create_node( - "call_function", op, args, node.kwargs + input_node = create_node( + graph=graph_module.graph, + op_target=op, + args=args, + kwargs={}, + from_node=node, ) if not keepdims: - input_node = graph_module.graph.create_node( - "call_function", - squeeze_op, - (input_node, dims), + input_node = create_node( + graph=graph_module.graph, + op_target=squeeze_op, + args=(input_node, dims), + kwargs={}, + from_node=node, ) replace_node.replace_all_uses_with(input_node) diff --git a/backends/arm/_passes/convert_permute_singleton_to_view_pass.py b/backends/arm/_passes/convert_permute_singleton_to_view_pass.py new file mode 100644 index 00000000000..fe4697bc213 --- /dev/null +++ b/backends/arm/_passes/convert_permute_singleton_to_view_pass.py @@ -0,0 +1,64 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Sequence, Set, Tuple, Type + +from executorch.backends.arm._passes.arm_pass import ArmPass + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + +from torch._ops import OpOverload + + +_PERMUTE_TARGETS: Tuple[OpOverload, ...] = ( + exir_ops.edge.aten.permute.default, + exir_ops.edge.aten.permute_copy.default, +) + + +class ConvertPermuteSingletonToViewPass(ArmPass): + """Replace permutations that only move singleton axes with a reshape. + + Examples: + x = rand(1,1,1,4) + y = permute(x, (0,3,1,2)) + + becomes: + x = rand(1,1,1,4) + y = view_copy(x, (1,4,1,1)) + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + def call_operator(self, op, args, kwargs, meta): + if op not in _PERMUTE_TARGETS: + return super().call_operator(op, args, kwargs, meta) + + input_tensor = args[0].data + permutation = args[1] + if not is_singleton_permutation(input_tensor.shape, permutation): + return super().call_operator(op, args, kwargs, meta) + + output_shape = meta["val"].shape + view_args = (args[0], output_shape) + return super().call_operator( + exir_ops.edge.aten.view_copy.default, view_args, kwargs, meta + ) + + +def is_singleton_permutation(shape: Sequence[int], permutation: Sequence[int]) -> bool: + """ + Treat as a view only when non-singleton axes keep their order; singleton + axes may move freely since they carry no data volume. 
+ """ + rank = len(shape) + normalized_perm = [d % rank for d in permutation] + + non_singleton_axes = [i for i, size in enumerate(shape) if size != 1] + permuted_non_singleton_axes = [axis for axis in normalized_perm if shape[axis] != 1] + + return permuted_non_singleton_axes == non_singleton_axes diff --git a/backends/arm/_passes/convert_split_to_slice.py b/backends/arm/_passes/convert_split_to_slice.py index 67bd9d73e81..5006c3006e8 100644 --- a/backends/arm/_passes/convert_split_to_slice.py +++ b/backends/arm/_passes/convert_split_to_slice.py @@ -3,9 +3,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe + +from typing import Set, Type import torch.fx +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( create_node, get_first_fake_tensor, @@ -14,11 +16,13 @@ from executorch.exir.pass_base import ExportPass, PassResult -class ConvertSplitToSlicePass(ExportPass): +class ConvertSplitToSlicePass(ArmPass): """ Replace a split operation with many slice operations. """ + _passes_required_after: Set[Type[ExportPass]] = set() + split_ops = ( exir_ops.edge.aten.split_with_sizes_copy.default, exir_ops.edge.aten.split_copy.Tensor, @@ -42,13 +46,24 @@ def call(self, graph_module: torch.fx.GraphModule): dim = (dim + rank) % rank # Validate that split lengths cover the entire dimension - length_sum = sum(split_lengths) + dim_size = shape[dim] - if length_sum != dim_size: - raise ValueError( - f"Split sizes {split_lengths} sum to {length_sum}, " - f"but dimension {dim} has size {dim_size}" - ) + if isinstance(split_lengths, int): + if split_lengths <= 0: + raise ValueError( + f"Split size must be positive, got {split_lengths}" + ) + full_chunks, remainder = divmod(dim_size, split_lengths) + split_lengths = [split_lengths] * full_chunks + if remainder: + split_lengths.append(remainder) + else: + length_sum = sum(split_lengths) + if length_sum != dim_size: + raise ValueError( + f"Split sizes {split_lengths} sum to {length_sum}, " + f"but dimension {dim} has size {dim_size}" + ) # Convert split argument 'split_lengths' to slice arguments start and end. starts = [0] * len(split_lengths) @@ -70,11 +85,48 @@ def call(self, graph_module: torch.fx.GraphModule): graph, self.slice, (input_node, dim, starts[index], ends[index]), + from_node=node, + ) + slice_node.meta = _copy_user_node_qparams( + split_node, output_node, index ) - slice_node.meta = split_node.meta.copy() - slice_node.meta["val"] = slice_node.meta["val"][index] output_node.replace_all_uses_with(slice_node) graph.eliminate_dead_code() graph_module.recompile() graph_module = super().call(graph_module).graph_module return PassResult(graph_module, True) + + +def _copy_user_node_qparams( + split_node: torch.fx.Node, output_node: torch.fx.Node, index: int +) -> dict: + """ + Construct metadata for the slice node that will replace the split output. + + Note that output quantization parameters are copied from the user nodes + of the split node. The split node itself does not have output quantization + parameters. + + Args: + split_node: The split node being replaced. + output_node: The getitem node that is user of the split node. + index: The index of the output being processed. + Returns: + Updated metadata dictionary for the slice node. 
+ """ + + def _select_index(value): + if isinstance(value, (list, tuple)): + return value[index] + return value + + meta = split_node.meta.copy() + if "val" in meta: + meta["val"] = _select_index(meta["val"]) + if "tensor_meta" in meta: + meta["tensor_meta"] = _select_index(meta["tensor_meta"]) + if "input_qparams" in meta: + meta["input_qparams"] = dict(meta["input_qparams"]) + if "output_qparams" in meta: + meta["output_qparams"] = dict(output_node.meta["output_qparams"]) + return meta diff --git a/backends/arm/_passes/convert_squeezes_to_view.py b/backends/arm/_passes/convert_squeezes_to_view.py index 889dbe74172..9d185a8e08c 100644 --- a/backends/arm/_passes/convert_squeezes_to_view.py +++ b/backends/arm/_passes/convert_squeezes_to_view.py @@ -1,20 +1,26 @@ # Copyright 2025 Arm Limited and/or its affiliates. -# All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +from typing import Set, Type + +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.fuse_view_copy_transform_pass import ( + FuseViewCopyTransformPass, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass -class ConvertSqueezesToViewPass(ExportPass): +class ConvertSqueezesToViewPass(ArmPass): """ Replaces squeeze/unsqueeze operators with view. These are simply special cases of the view op, so removing them gives us less cases to handle in the node visitiors. """ + _passes_required_after: Set[Type[ExportPass]] = {FuseViewCopyTransformPass} + def call_operator(self, op, args, kwargs, meta): if op not in [ exir_ops.edge.aten.squeeze_copy.dims, diff --git a/backends/arm/_passes/convert_to_clamp.py b/backends/arm/_passes/convert_to_clamp.py deleted file mode 100644 index 8f2c9b16f9a..00000000000 --- a/backends/arm/_passes/convert_to_clamp.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Tuple - -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass - -edge_operators = { - exir_ops.edge.aten.hardtanh.default, - exir_ops.edge.aten.relu.default, -} - - -def get_clamp_params(op, args) -> Tuple[float | None, float | None]: - if op == exir_ops.edge.aten.hardtanh.default: - return args[1], args[2] - elif op == exir_ops.edge.aten.relu.default: - return 0.0, None - else: - raise ValueError(f"Getting clamp parameters for op {op} is not implemented.") - - -class ConvertToClampPass(ExportPass): - def call_operator(self, op, args, kwargs, meta): - if op not in edge_operators: - return super().call_operator(op, args, kwargs, meta) - - return super().call_operator( - exir_ops.edge.aten.clamp.default, - (args[0], *get_clamp_params(op, args)), - {}, - meta, - ) diff --git a/backends/arm/_passes/convert_to_clamp_pass.py b/backends/arm/_passes/convert_to_clamp_pass.py new file mode 100644 index 00000000000..4b28f993acd --- /dev/null +++ b/backends/arm/_passes/convert_to_clamp_pass.py @@ -0,0 +1,45 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Set, Tuple, Type + +from executorch.backends.arm._passes import ArmPass + +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + QuantizeClampArgumentsPass, +) + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + +edge_operators = { + exir_ops.edge.aten.hardtanh.default, + exir_ops.edge.aten.relu.default, +} + + +def get_clamp_params(op, args) -> Tuple[float | None, float | None]: + if op == exir_ops.edge.aten.hardtanh.default: + return args[1], args[2] + elif op == exir_ops.edge.aten.relu.default: + return 0.0, None + else: + raise ValueError(f"Getting clamp parameters for op {op} is not implemented.") + + +class ConvertToClampPass(ArmPass): + _passes_required_after: Set[Type[ExportPass]] = {QuantizeClampArgumentsPass} + + def call_operator(self, op, args, kwargs, meta): + if op not in edge_operators: + return super().call_operator(op, args, kwargs, meta) + + return super().call_operator( + exir_ops.edge.aten.clamp.default, + (args[0], *get_clamp_params(op, args)), + {}, + meta, + updated=True, + ) diff --git a/backends/arm/_passes/decompose_acosh_pass.py b/backends/arm/_passes/decompose_acosh_pass.py index 1d92dd68c4a..1d29986433b 100644 --- a/backends/arm/_passes/decompose_acosh_pass.py +++ b/backends/arm/_passes/decompose_acosh_pass.py @@ -3,10 +3,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe + +from typing import Set, Type from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.decompose_sqrt_pass import DecomposeSqrtPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass # noqa +from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass +from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass +from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( + ReplaceScalarWithTensorByProfilePass, +) from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For MI case edge_acosh_op = exir_ops.edge.aten.acosh.default @@ -19,11 +28,27 @@ class DecomposeAcoshPass(ArmPass): acosh(x) = log(x + sqrt((x-1)(x+1)) """ + _passes_required_after: Set[Type[ExportPass]] = { + DecomposeSqrtPass, + InsertTableOpsPass, + MatchArgRanksPass, + ReplaceScalarWithTensorByProfilePass, + MatchArgDtypePass, + } + def call_operator(self, op, args, kwargs, meta, updated=False): if op is not edge_acosh_op: return super().call_operator(op, args, kwargs, meta, updated) + is_quantized = ( + len(meta.data.get("input_qparams", {})) > 0 + and len(meta.data.get("output_qparams", {})) > 0 + ) + if is_quantized: + # If quantized, node should be replace by table op + return super().call_operator(op, args, kwargs, meta, updated) + log_op, sqrt_op, mul_op, sub_op, add_op, add_op_scalar = ( exir_ops.edge.aten.log.default, exir_ops.edge.aten.sqrt.default, diff --git a/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py b/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py index abfcc8e3945..5905e8f4496 100644 --- a/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py +++ b/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py @@ -4,12 +4,17 @@ # LICENSE file in the root directory of this source tree. 
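
A quick numerical check of the identity documented for DecomposeAcoshPass above, acosh(x) = log(x + sqrt((x - 1)(x + 1))); illustrative only, not part of the patch:

    import torch

    x = torch.linspace(1.5, 5.0, 8)  # acosh is defined for x >= 1
    decomposed = torch.log(x + torch.sqrt((x - 1.0) * (x + 1.0)))
    assert torch.allclose(torch.acosh(x), decomposed)
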
 from math import ceil, floor
+from typing import Set, Type

 import torch

 from executorch.backends.arm._passes import ArmPass
+from executorch.backends.arm._passes.decompose_avg_pool2d_pass import (
+    DecomposeAvgPool2dPass,
+)

 from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, NodeMetadata

 edge_ops = (exir_ops.edge.aten._adaptive_avg_pool2d.default,)
 aten_ops = (torch.ops.aten.adaptive_avg_pool2d.default,)
@@ -41,6 +46,8 @@ class DecomposeAdaptiveAvgPool2dPass(ArmPass):
     The output is of size output_size_h x output_size_w for any input.
     """

+    _passes_required_after: Set[Type[ExportPass]] = {DecomposeAvgPool2dPass}
+
     def call_operator(self, op, args, kwargs, meta, updated=False):
         if op not in (edge_ops + aten_ops):
             return super().call_operator(op, args, kwargs, meta, updated)
@@ -55,6 +62,11 @@ def call_operator(self, op, args, kwargs, meta, updated=False):

         # Vela currently only allows a stride in the interval of [1,3] for AvgPool2d.
         # To accommodate this, the AvgPool2d op is applied to pooling regions and the results are concatenated.
+        # Slices and concats do not require quantization parameters
+        metadata_dict = dict(meta.data)
+        metadata_dict["input_qparams"] = {}
+        metadata_dict["output_qparams"] = {}
+        meta_with_no_qparams = NodeMetadata(metadata_dict)
         res = []
         for out_i in range(output_size_h):
             row = []
@@ -67,11 +79,15 @@ def call_operator(self, op, args, kwargs, meta, updated=False):

                 # Slice along H
                 x_h = super().call_operator(
-                    slice_op, (x, 2, start_h, end_h), kwargs, meta, True
+                    slice_op, (x, 2, start_h, end_h), kwargs, meta_with_no_qparams, True
                 )
                 # Slice along W
                 x_hw = super().call_operator(
-                    slice_op, (x_h, 3, start_w, end_w), kwargs, meta, True
+                    slice_op,
+                    (x_h, 3, start_w, end_w),
+                    kwargs,
+                    meta_with_no_qparams,
+                    True,
                 )

                 # Apply avg pooling with kernel size equal to the pooling region
@@ -84,9 +100,13 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
                 row.append(pooled)

             # Concatenate row results along width (dim=3)
-            row_tensor = super().call_operator(cat_op, (row, 3), kwargs, meta, True)
+            row_tensor = super().call_operator(
+                cat_op, (row, 3), kwargs, meta_with_no_qparams, True
+            )
             res.append(row_tensor)

         # Concatenate all rows along height (dim=2)
-        out = super().call_operator(cat_op, (res, 2), kwargs, meta, True)
+        out = super().call_operator(
+            cat_op, (res, 2), kwargs, meta_with_no_qparams, True
+        )
         return out
diff --git a/backends/arm/_passes/decompose_add_sub_alpha_pass.py b/backends/arm/_passes/decompose_add_sub_alpha_pass.py
new file mode 100644
index 00000000000..c0ed1bae09b
--- /dev/null
+++ b/backends/arm/_passes/decompose_add_sub_alpha_pass.py
@@ -0,0 +1,94 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
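
To illustrate what DecomposeAdaptiveAvgPool2dPass above relies on: each adaptive output element is the mean of a floor/ceil-bounded input window, which the slice + avg_pool + cat decomposition reproduces. The sketch below uses made-up sizes and the commonly documented adaptive-pooling window formula; it is illustrative only:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 7, 9)
    out_h, out_w = 3, 4
    reference = F.adaptive_avg_pool2d(x, (out_h, out_w))
    manual = torch.empty_like(reference)
    in_h, in_w = x.shape[-2:]
    for i in range(out_h):
        # floor(i * in_h / out_h) .. ceil((i + 1) * in_h / out_h)
        h0, h1 = (i * in_h) // out_h, -((-(i + 1) * in_h) // out_h)
        for j in range(out_w):
            w0, w1 = (j * in_w) // out_w, -((-(j + 1) * in_w) // out_w)
            manual[..., i, j] = x[..., h0:h1, w0:w1].mean(dim=(-2, -1))
    assert torch.allclose(reference, manual, atol=1e-6)
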
+ +from __future__ import annotations + +import numbers +from typing import Set, Type + +import torch +from executorch.backends.arm._passes import ArmPass +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + + +_ADD_OPS = ( + exir_ops.edge.aten.add.Tensor, + torch.ops.aten.add.Tensor, +) + +_SUB_OPS = ( + exir_ops.edge.aten.sub.Tensor, + torch.ops.aten.sub.Tensor, +) + + +def _get_ops(op): + if op in _ADD_OPS: + if op is exir_ops.edge.aten.add.Tensor: + return ( + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.full.default, + exir_ops.edge.aten.add.Tensor, + ) + return ( + torch.ops.aten.mul.Tensor, + torch.ops.aten.full.default, + torch.ops.aten.add.Tensor, + ) + if op in _SUB_OPS: + if op is exir_ops.edge.aten.sub.Tensor: + return ( + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.full.default, + exir_ops.edge.aten.sub.Tensor, + ) + return ( + torch.ops.aten.mul.Tensor, + torch.ops.aten.full.default, + torch.ops.aten.sub.Tensor, + ) + raise RuntimeError(f"Unsupported operator {op}") + + +def _should_decompose(alpha) -> bool: + if isinstance(alpha, numbers.Number): + return alpha != 1 + return False + + +class DecomposeAddSubAlphaPass(ArmPass): + """Rewrite add/sub with alpha into a mul followed by add/sub.""" + + _passes_required_after: Set[Type[ExportPass]] = set() + + def call_operator(self, op, args, kwargs, meta, updated: bool | None = False): + if op not in _ADD_OPS + _SUB_OPS: + return super().call_operator(op, args, kwargs, meta, updated) + + alpha = kwargs.get("alpha", 1) + if not _should_decompose(alpha): + return super().call_operator(op, args, kwargs, meta, updated) + + mul_op, full_op, binary_op = _get_ops(op) + lhs, rhs = args + + alpha_full = super().call_operator( + full_op, ((1,), float(alpha)), {}, meta, updated=True + ) + scaled_rhs = super().call_operator( + mul_op, + (rhs, alpha_full), + {}, + meta, + updated=True, + ) + return super().call_operator( + binary_op, + (lhs, scaled_rhs), + {}, + meta, + updated=True, + ) diff --git a/backends/arm/_passes/decompose_addmm_pass.py b/backends/arm/_passes/decompose_addmm_pass.py index b59a8cb02d3..a95c1cc7fec 100644 --- a/backends/arm/_passes/decompose_addmm_pass.py +++ b/backends/arm/_passes/decompose_addmm_pass.py @@ -3,10 +3,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
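
The rewrite done by DecomposeAddSubAlphaPass above can be checked numerically (illustrative sketch with made-up tensors):

    import torch

    # A non-unit alpha is folded into a scale of the second operand before the
    # plain add/sub, matching the full -> mul -> add/sub chain emitted above.
    a, b = torch.randn(2, 3), torch.randn(2, 3)
    alpha = 2.5
    scaled = torch.full((1,), alpha) * b
    assert torch.allclose(torch.add(a, b, alpha=alpha), a + scaled)
    assert torch.allclose(torch.sub(a, b, alpha=alpha), a - scaled)
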
+from typing import Set, Type
+
 import torch
 from executorch.backends.arm._passes import ArmPass
+from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
+from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
+from executorch.backends.arm._passes.mm_to_bmm_pass import ConvertMmToBmmPass  # noqa
 from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass


 # For MI case
@@ -36,6 +42,12 @@ def get_ops(op):
 class DecomposeAddmmPass(ArmPass):
     """Decomposes the addmm operator into tensor multiplication and addition."""

+    _passes_required_after: Set[Type[ExportPass]] = {
+        ConvertMmToBmmPass,
+        MatchArgRanksPass,
+        MatchArgDtypePass,
+    }
+
     def call_operator(self, op, args, kwargs, meta):
         if op not in [edge_addmm, aten_addmm]:
             return super().call_operator(op, args, kwargs, meta)
diff --git a/backends/arm/_passes/decompose_any_pass.py b/backends/arm/_passes/decompose_any_pass.py
new file mode 100644
index 00000000000..a0487e7e139
--- /dev/null
+++ b/backends/arm/_passes/decompose_any_pass.py
@@ -0,0 +1,114 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Set, Type
+
+import torch
+from executorch.backends.arm._passes import ArmPass
+from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
+from executorch.exir.dialects._ops import (  # type: ignore[import-not-found]
+    ops as exir_ops,
+)
+from executorch.exir.pass_base import (  # type: ignore[import-not-found]
+    ExportPass,
+    PassResult,
+)
+
+
+class DecomposeAnyPass(ArmPass):
+    """
+    Converts any.default, any.dim and any.dims to a sequence of any.dim by
+    unrolling multi-dimensional reductions with keepdim=True. If keepdim=False
+    was requested, the final shape adjustment is implemented with a
+    view_copy.default to the reduced shape.
+
+    Example 1
+    Original:
+        any() # x.shape: [dim1, dim2, ..., dimn]
+    After pass:
+        any.dim(dim1, keepdim = True)
+        any.dim(dim2, keepdim = True)
+        ...
+ any.dim(dimn, keepdim = True) + view_copy(shape = squeezed_shape) + + Example 2 + Original: + any.dim(dim1, keepdim = False) + After pass: + any.dim(dim1, keepdim = True) + view_copy(shape = squeezed_shape) + + Example 3 + Original: + any.dims([dim1, dim2], keepdim = False) + After pass: + any.dim(dim1, keepdim = True) + any.dim(dim2, keepdim = True) + view_copy(shape = squeezed_shape) + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + def call(self, graph_module: torch.fx.GraphModule): + modified = False + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + if node.target not in [ + exir_ops.edge.aten.any.default, + exir_ops.edge.aten.any.dim, + exir_ops.edge.aten.any.dims, + ]: + continue + + if len(node.args) == 1: + # any.default(input) + input_node = (node.args)[0] + dims_to_reduce = range(len(input_node.meta["val"].shape)) + keepdim = False + elif len(node.args) == 2: + # any.dim/dims(input, dims=dims) + input_node, dims_to_reduce = node.args + keepdim = False + elif len(node.args) == 3: + # any.dim/dims(input, dims=dims, keepdim=keepdim) + input_node, dims_to_reduce, keepdim = node.args + else: + raise RuntimeError( + f"Unexpected arg size {len(node.args)} in {node.name}" + ) + try: + iter(dims_to_reduce) + except: + dims_to_reduce = [dims_to_reduce] # type: ignore[assignment] + else: + dims_to_reduce = list(dims_to_reduce) # type: ignore[assignment] + + # Unroll multi-dimensional reduction and keep-dims arg + with graph_module.graph.inserting_before(node): + for dim in dims_to_reduce: + args = (input_node, dim, True) + input_node = graph_module.graph.create_node( + "call_function", exir_ops.edge.aten.any.dim, args, node.kwargs + ) + + if not keepdim: + output_shape = list(get_first_fake_tensor(node).shape) + input_node = graph_module.graph.create_node( + "call_function", + exir_ops.edge.aten.view_copy.default, + (input_node, output_shape), + ) + + node.replace_all_uses_with(input_node) + modified = True + + if modified: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/decompose_asin_and_acos_pass.py b/backends/arm/_passes/decompose_asin_and_acos_pass.py index e067f17b0ca..e0da9eb9014 100644 --- a/backends/arm/_passes/decompose_asin_and_acos_pass.py +++ b/backends/arm/_passes/decompose_asin_and_acos_pass.py @@ -3,15 +3,26 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
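
A small numerical check of the unrolling described in the DecomposeAnyPass docstring above (illustrative only; shapes are made up):

    import torch

    # Reducing one dim at a time with keepdim=True and then reshaping matches
    # the all-dims reduction done by any.default.
    x = torch.rand(2, 3, 4) > 0.5
    reference = torch.any(x)
    step = x
    for dim in range(x.dim()):
        step = torch.any(step, dim=dim, keepdim=True)
    assert torch.equal(reference, step.reshape(reference.shape))
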
-# pyre-unsafe import logging from math import pi +from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.convert_full_like_to_full_pass import ( + ConvertFullLikeToFullPass, +) +from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass +from executorch.backends.arm._passes.decompose_sqrt_pass import DecomposeSqrtPass +from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass +from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass +from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( + ReplaceScalarWithTensorByProfilePass, +) from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For MI case edge_asin_op = (exir_ops.edge.aten.asin.default,) @@ -54,6 +65,15 @@ class DecomposeAsinAndAcosPass(ArmPass): """ + _passes_required_after: Set[Type[ExportPass]] = { + DecomposeSqrtPass, + DecomposeDivPass, + ConvertFullLikeToFullPass, + MatchArgRanksPass, + MatchArgDtypePass, + ReplaceScalarWithTensorByProfilePass, + } + def _build_polynomial( self, coefficients: list[float], variable: torch.Tensor, meta: dict[str, str] ) -> torch.Tensor: @@ -103,6 +123,15 @@ def _combine_branches( def call_operator(self, op, args, kwargs, meta): if op not in (edge_asin_op + edge_acos_op): return super().call_operator(op, args, kwargs, meta) + + is_quantized = ( + len(meta.data.get("input_qparams", {})) > 0 + and len(meta.data.get("output_qparams", {})) > 0 + ) + if is_quantized: + # If quantized, node should be replace by table op + return super().call_operator(op, args, kwargs, meta) + logging.info( f"Approximating {op}. This may introduce small numerical errors. For details, see {__file__}." ) diff --git a/backends/arm/_passes/decompose_asinh_pass.py b/backends/arm/_passes/decompose_asinh_pass.py index a0b78c51a77..1131feea9c6 100644 --- a/backends/arm/_passes/decompose_asinh_pass.py +++ b/backends/arm/_passes/decompose_asinh_pass.py @@ -3,11 +3,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe +from typing import Set, Type from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.decompose_sqrt_pass import DecomposeSqrtPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass +from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass +from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass +from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( + ReplaceScalarWithTensorByProfilePass, +) from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For MI case edge_asinh_op = (exir_ops.edge.aten.asinh.default,) @@ -20,10 +28,26 @@ class DecomposeAsinhPass(ArmPass): asinh(x) = log(x + sqrt(x^2 + 1)) """ + _passes_required_after: Set[Type[ExportPass]] = { + DecomposeSqrtPass, + InsertTableOpsPass, + MatchArgRanksPass, + ReplaceScalarWithTensorByProfilePass, + MatchArgDtypePass, + } + def call_operator(self, op, args, kwargs, meta): if op not in edge_asinh_op: return super().call_operator(op, args, kwargs, meta) + is_quantized = ( + len(meta.data.get("input_qparams", {})) > 0 + and len(meta.data.get("output_qparams", {})) > 0 + ) + if is_quantized: + # If quantized, node should be replace by table op + return super().call_operator(op, args, kwargs, meta) + log_op, sqrt_op, mul_op, add_op_scalar, add_op = ( exir_ops.edge.aten.log.default, exir_ops.edge.aten.sqrt.default, diff --git a/backends/arm/_passes/decompose_atan_pass.py b/backends/arm/_passes/decompose_atan_pass.py index 57b9dde5216..a3b4081755a 100644 --- a/backends/arm/_passes/decompose_atan_pass.py +++ b/backends/arm/_passes/decompose_atan_pass.py @@ -5,9 +5,17 @@ import logging from math import pi +from typing import Set, Type from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass +from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass +from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass +from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( + ReplaceScalarWithTensorByProfilePass, +) from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass edge_atan = exir_ops.edge.aten.atan.default # MI case @@ -35,6 +43,13 @@ def _get_atan_ops(op): class DecomposeAtanPass(ArmPass): """Decomposes the atan operator into a rational (Padé) approximation.""" + _passes_required_after: Set[Type[ExportPass]] = { + InsertTableOpsPass, + MatchArgRanksPass, + MatchArgDtypePass, + ReplaceScalarWithTensorByProfilePass, + } + def _rational_approximation(self, z, ops, meta): """Creates a (2,1) Padé approximation for atan(x) on [-1, 1].""" @@ -65,6 +80,14 @@ def call_operator(self, op, args, kwargs, meta): if op is not edge_atan: return super().call_operator(op, args, kwargs, meta, updated=False) + is_quantized = ( + len(meta.data.get("input_qparams", {})) > 0 + and len(meta.data.get("output_qparams", {})) > 0 + ) + if is_quantized: + # If quantized, node should be replace by table op + return super().call_operator(op, args, kwargs, meta) + logging.info( f"Approximating atan. This may introduce small numerical errors. For details, see {__file__}." 
) diff --git a/backends/arm/_passes/decompose_atanh_pass.py b/backends/arm/_passes/decompose_atanh_pass.py index dfdad41e556..789dafed9ef 100644 --- a/backends/arm/_passes/decompose_atanh_pass.py +++ b/backends/arm/_passes/decompose_atanh_pass.py @@ -3,8 +3,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass +from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass +from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass +from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( + ReplaceScalarWithTensorByProfilePass, +) from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass edge_atanh = exir_ops.edge.aten.atanh.default # MI case @@ -30,10 +39,25 @@ class DecomposeAtanhPass(ArmPass): atanh(x) = 0.5 * log((1 + x) / (1 - x)) """ + _passes_required_after: Set[Type[ExportPass]] = { + InsertTableOpsPass, + MatchArgRanksPass, + MatchArgDtypePass, + ReplaceScalarWithTensorByProfilePass, + } + def call_operator(self, op, args, kwargs, meta): if op is not edge_atanh: return super().call_operator(op, args, kwargs, meta, updated=False) + is_quantized = ( + len(meta.data.get("input_qparams", {})) > 0 + and len(meta.data.get("output_qparams", {})) > 0 + ) + if is_quantized: + # If quantized, node should be replace by table op + return super().call_operator(op, args, kwargs, meta) + ops = _get_atanh_ops(op) ( op_mul_tensor, diff --git a/backends/arm/_passes/decompose_avg_pool2d.py b/backends/arm/_passes/decompose_avg_pool2d.py deleted file mode 100644 index 21ed6b518c7..00000000000 --- a/backends/arm/_passes/decompose_avg_pool2d.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- - -import torch -from executorch.backends.arm.operators.operator_validation_utils import ( - adjust_pooling_pad_if_needed, -) -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass - -edge_div_ops = (exir_ops.edge.aten.avg_pool2d.default,) -aten_div_ops = (torch.ops.aten.avg_pool2d.default,) - - -def get_decomposition(op) -> tuple: - if op in edge_div_ops: - return ( - exir_ops.edge.aten.full.default, - exir_ops.edge.aten.cat.default, - exir_ops.edge.aten.avg_pool2d.default, - exir_ops.edge.aten.mul.Tensor, - ) - if op in aten_div_ops: - return ( - torch.ops.aten.full.default, - torch.ops.aten.cat.default, - torch.ops.aten.avg_pool2d.default, - torch.ops.aten.mul.Tensor, - ) - raise RuntimeError(f"Can't get div decomposition for op {op}") - - -class DecomposeAvgPool2d(ExportPass): - """ """ - - def call_operator(self, op, args, kwargs, meta): - if op not in (edge_div_ops + aten_div_ops): - return super().call_operator(op, args, kwargs, meta) - - full_op, cat_op, avgpool_op, mul_op = get_decomposition(op) - - x = args[0] - kernel_h, kernel_w = args[1] - kernel_size = kernel_h * kernel_w - if len(args) > 2 and args[2] is not None: - stride_h, stride_w = args[2] - else: - stride_h, stride_w = kernel_h, kernel_w - pad_h, pad_w = new_pad_h, new_pad_w = args[3] if len(args) > 3 else (0, 0) - ceil_mode = args[4] if len(args) > 4 else False - count_include_pad = args[5] if len(args) > 5 else True - divisor_override = args[6] if len(args) > 6 else None - - n, c, h, w = x.data.shape - post_pad_w, post_pad_h = (0, 0) - - # Count_include_pad == False means that we use a different divisor for edge elements - # When divisor_override is set, this will be overriden anyways. - # It is easier to replace a constant divisor, so set count_include_pad == True - if divisor_override is not None: - count_include_pad = True - - # Add width padding manually if count_include_pad - if count_include_pad and pad_w > 0: - pre_pad_shape = [n, c, h, pad_w] - pre_pad = super().call_operator(full_op, (pre_pad_shape, 0.0), kwargs, meta) - - if ceil_mode and divisor_override is None: - post_pad_w = pad_w - else: - post_pad_w = adjust_pooling_pad_if_needed( - w, kernel_w, stride_w, pad_w, ceil_mode - ) - - if post_pad_w > 0: - post_pad_shape = [n, c, h, post_pad_w] - post_pad = super().call_operator( - full_op, (post_pad_shape, 0.0), kwargs, meta - ) - cat_nodes = [pre_pad, x, post_pad] - else: - cat_nodes = [pre_pad, x] - - x = super().call_operator(cat_op, (cat_nodes, 3), kwargs, meta) - new_pad_w = 0 - - # Add height padding manually if count_include_pad - if count_include_pad and pad_h > 0: - pre_pad_shape = [n, c, pad_h, w + pad_w + post_pad_w] - pre_pad = super().call_operator(full_op, (pre_pad_shape, 0.0), kwargs, meta) - - if ceil_mode and divisor_override is None: - post_pad_h = pad_h - else: - post_pad_h = adjust_pooling_pad_if_needed( - h, kernel_h, stride_h, pad_h, ceil_mode - ) - - if post_pad_h > 0: - post_pad_shape = [n, c, post_pad_h, w + pad_w + post_pad_w] - post_pad = super().call_operator( - full_op, (post_pad_shape, 0.0), kwargs, meta - ) - cat_nodes = [pre_pad, x, post_pad] - else: - cat_nodes = [pre_pad, x] - - x = super().call_operator(cat_op, (cat_nodes, 2), kwargs, meta) - new_pad_h = 0 - - avgpool_args = ( - x, - args[1], - [stride_h, stride_w], - [new_pad_h, new_pad_w], - ceil_mode, - False, - ) - x = super().call_operator(avgpool_op, avgpool_args, kwargs, meta) - - # Multiply by factor (kernel_size / divisor_override) if divisor_override - if 
divisor_override is not None and divisor_override != kernel_size: - override_multiplier = super().call_operator( - full_op, ([1, 1, 1, 1], kernel_size / divisor_override), kwargs, meta - ) - x = super().call_operator(mul_op, (x, override_multiplier), kwargs, meta) - - return x diff --git a/backends/arm/_passes/decompose_avg_pool2d_pass.py b/backends/arm/_passes/decompose_avg_pool2d_pass.py new file mode 100644 index 00000000000..14b03cf6243 --- /dev/null +++ b/backends/arm/_passes/decompose_avg_pool2d_pass.py @@ -0,0 +1,151 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Set, Type + +import torch +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + adjust_pooling_pad_if_needed, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + +edge_div_ops = (exir_ops.edge.aten.avg_pool2d.default,) +aten_div_ops = (torch.ops.aten.avg_pool2d.default,) + + +def get_decomposition(op) -> tuple: + if op in edge_div_ops: + return ( + exir_ops.edge.aten.full.default, + exir_ops.edge.aten.cat.default, + exir_ops.edge.aten.avg_pool2d.default, + exir_ops.edge.aten.mul.Tensor, + ) + if op in aten_div_ops: + return ( + torch.ops.aten.full.default, + torch.ops.aten.cat.default, + torch.ops.aten.avg_pool2d.default, + torch.ops.aten.mul.Tensor, + ) + raise RuntimeError(f"Can't get avg_pool2d decomposition for op {op}") + + +class DecomposeAvgPool2dPass(ArmPass): + _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass} + + def call_operator(self, op, args, kwargs, meta): + if op not in (edge_div_ops + aten_div_ops): + return super().call_operator(op, args, kwargs, meta) + + full_op, cat_op, avgpool_op, mul_op = get_decomposition(op) + + x = args[0] + kernel_h, kernel_w = args[1] + kernel_size = kernel_h * kernel_w + if len(args) > 2 and args[2] is not None: + stride_h, stride_w = args[2] + else: + stride_h, stride_w = kernel_h, kernel_w + pad_h, pad_w = new_pad_h, new_pad_w = args[3] if len(args) > 3 else (0, 0) + ceil_mode = args[4] if len(args) > 4 else False + count_include_pad = args[5] if len(args) > 5 else True + divisor_override = args[6] if len(args) > 6 else None + + n, c, h, w = x.data.shape + post_pad_w, post_pad_h = (0, 0) + + # Count_include_pad == False means that we use a different divisor for edge elements + # When divisor_override is set, this will be overriden anyways. 
+ # It is easier to replace a constant divisor, so set count_include_pad == True + if divisor_override is not None: + count_include_pad = True + + # Add width padding manually if count_include_pad + if count_include_pad and pad_w > 0: + pre_pad_shape = [n, c, h, pad_w] + pre_pad = super().call_operator( + full_op, (pre_pad_shape, 0.0), kwargs, meta, updated=True + ) + + if ceil_mode and divisor_override is None: + post_pad_w = pad_w + else: + post_pad_w = adjust_pooling_pad_if_needed( + w, kernel_w, stride_w, pad_w, ceil_mode + ) + + if post_pad_w > 0: + post_pad_shape = [n, c, h, post_pad_w] + post_pad = super().call_operator( + full_op, (post_pad_shape, 0.0), kwargs, meta, updated=True + ) + cat_nodes = [pre_pad, x, post_pad] + else: + cat_nodes = [pre_pad, x] + + x = super().call_operator( + cat_op, (cat_nodes, 3), kwargs, meta, updated=True + ) + new_pad_w = 0 + + # Add height padding manually if count_include_pad + if count_include_pad and pad_h > 0: + pre_pad_shape = [n, c, pad_h, w + pad_w + post_pad_w] + pre_pad = super().call_operator( + full_op, (pre_pad_shape, 0.0), kwargs, meta, updated=True + ) + + if ceil_mode and divisor_override is None: + post_pad_h = pad_h + else: + post_pad_h = adjust_pooling_pad_if_needed( + h, kernel_h, stride_h, pad_h, ceil_mode + ) + + if post_pad_h > 0: + post_pad_shape = [n, c, post_pad_h, w + pad_w + post_pad_w] + post_pad = super().call_operator( + full_op, (post_pad_shape, 0.0), kwargs, meta, updated=True + ) + cat_nodes = [pre_pad, x, post_pad] + else: + cat_nodes = [pre_pad, x] + + x = super().call_operator( + cat_op, (cat_nodes, 2), kwargs, meta, updated=True + ) + new_pad_h = 0 + + avgpool_args = ( + x, + args[1], + [stride_h, stride_w], + [new_pad_h, new_pad_w], + ceil_mode, + False, + ) + x = super().call_operator(avgpool_op, avgpool_args, kwargs, meta, updated=True) + + # Multiply by factor (kernel_size / divisor_override) if divisor_override + if divisor_override is not None and divisor_override != kernel_size: + override_multiplier = super().call_operator( + full_op, + ([1, 1, 1, 1], kernel_size / divisor_override), + kwargs, + meta, + updated=True, + ) + x = super().call_operator( + mul_op, (x, override_multiplier), kwargs, meta, updated=True + ) + + return x diff --git a/backends/arm/_passes/decompose_batch_norm_no_stats.py b/backends/arm/_passes/decompose_batch_norm_no_stats.py index 5fdb8db2d7c..9a486376617 100644 --- a/backends/arm/_passes/decompose_batch_norm_no_stats.py +++ b/backends/arm/_passes/decompose_batch_norm_no_stats.py @@ -3,15 +3,20 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
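[Editor's note, not part of the diff] The divisor_override handling in DecomposeAvgPool2dPass above rests on a simple rescaling fact; a minimal check independent of the pass machinery:

```python
# With the default stride and no padding, overriding the divisor is the same
# as plain average pooling scaled by kernel_size / divisor_override.
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
kernel, divisor = (2, 2), 3
overridden = F.avg_pool2d(x, kernel, divisor_override=divisor)
rescaled = F.avg_pool2d(x, kernel) * (4 / divisor)  # kernel_size = 2 * 2 = 4
torch.testing.assert_close(overridden, rescaled)
```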
-# pyre-unsafe import operator +from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) + +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import ExportPass, PassResult class DecomposeBatchNormNoStatsPass(ArmPass): @@ -33,6 +38,11 @@ class DecomposeBatchNormNoStatsPass(ArmPass): Source: https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html """ + _passes_required_after: Set[Type[ExportPass]] = { + ComputeConstantOpsAOTPass, + InsertTableOpsPass, + } + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 bn_ops = ( exir_ops.edge.aten._native_batch_norm_legit.no_stats, diff --git a/backends/arm/_passes/decompose_cosh_pass.py b/backends/arm/_passes/decompose_cosh_pass.py index a94cf9ecff0..fe84f2bde9b 100644 --- a/backends/arm/_passes/decompose_cosh_pass.py +++ b/backends/arm/_passes/decompose_cosh_pass.py @@ -3,8 +3,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass +from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass +from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass +from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( + ReplaceScalarWithTensorByProfilePass, +) from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For MI case edge_cosh = exir_ops.edge.aten.cosh.default @@ -19,10 +28,25 @@ class DecomposeCoshPass(ArmPass): """ + _passes_required_after: Set[Type[ExportPass]] = { + InsertTableOpsPass, + MatchArgRanksPass, + ReplaceScalarWithTensorByProfilePass, + MatchArgDtypePass, + } + def call_operator(self, op, args, kwargs, meta, updated=False): if op is not edge_cosh: return super().call_operator(op, args, kwargs, meta, updated) + is_quantized = ( + len(meta.data.get("input_qparams", {})) > 0 + and len(meta.data.get("output_qparams", {})) > 0 + ) + if is_quantized: + # If quantized, node should be replace by table op + return super().call_operator(op, args, kwargs, meta) + x = args exp_op, mul_op, neg_op, add_op = ( diff --git a/backends/arm/_passes/decompose_cosine_similarity_pass.py b/backends/arm/_passes/decompose_cosine_similarity_pass.py index 9978e653408..96a95ee2a1c 100644 --- a/backends/arm/_passes/decompose_cosine_similarity_pass.py +++ b/backends/arm/_passes/decompose_cosine_similarity_pass.py @@ -3,13 +3,23 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
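[Editor's note, not part of the diff] The exp/neg/add/mul operators pulled in by DecomposeCoshPass are consistent with the standard identity cosh(x) = (exp(x) + exp(-x)) / 2; a standalone sketch of that identity:

```python
# Numeric check of cosh(x) = 0.5 * (exp(x) + exp(-x)), the identity a
# decomposition into exp, neg, add and mul relies on.
import torch

x = torch.linspace(-4.0, 4.0, steps=81)
decomposed = 0.5 * (torch.exp(x) + torch.exp(-x))
torch.testing.assert_close(decomposed, torch.cosh(x))
```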
+from typing import Set, Type + import torch +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.convert_full_like_to_full_pass import ( + ConvertFullLikeToFullPass, +) + +from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass +from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.pass_base import ExportPass torch_cosine_similarity = (torch.ops.aten.cosine_similarity.default,) -class DecomposeCosineSimilarityPass(ExportPass): +class DecomposeCosineSimilarityPass(ArmPass): """ Decomposition of aten.cosine_similarity: @@ -22,6 +32,13 @@ class DecomposeCosineSimilarityPass(ExportPass): out = div(dot, denom) """ + _passes_required_after: Set[Type[ExportPass]] = { + DecomposeDivPass, + DecomposeSumPass, + ConvertFullLikeToFullPass, + InsertTableOpsPass, + } + def call_operator(self, op, args, kwargs, meta): if op not in torch_cosine_similarity: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_cumsum_pass.py b/backends/arm/_passes/decompose_cumsum_pass.py index 155ccd11594..8b7d31c97ac 100644 --- a/backends/arm/_passes/decompose_cumsum_pass.py +++ b/backends/arm/_passes/decompose_cumsum_pass.py @@ -4,15 +4,18 @@ # LICENSE file in the root directory of this source tree. from math import prod +from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.backends.arm._passes.quant_args import QuantArgs +from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass from executorch.backends.transforms.utils import create_constant_placeholder +from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import ExportPass, PassResult from torch.export.graph_signature import InputKind @@ -39,6 +42,12 @@ class DecomposeCumsumPass(ArmPass): And the convolution is applied over dimension H. """ + _passes_required_after: Set[Type[ExportPass]] = {RewriteConvPass} + + def __init__(self, exported_program: ExportedProgram) -> None: + super().__init__() + self.exported_program = exported_program + def call(self, graph_module): graph = graph_module.graph targets = (exir_ops.edge.aten.cumsum.default, torch.ops.aten.cumsum.default) @@ -92,7 +101,13 @@ def call(self, graph_module): with graph.inserting_before(node): # Reshape to 4D with view_args = (input_node, conv_shape) - view_node = create_node(graph, view_op, args=view_args, from_node=node) + view_node = create_node( + graph, + view_op, + args=view_args, + from_node=node, + inherit_qparams=False, + ) conv_args = ( view_node, @@ -105,7 +120,9 @@ def call(self, graph_module): [0], 1, ) - conv_node = create_node(graph, conv_op, args=conv_args, from_node=node) + conv_node = create_node( + graph, conv_op, args=conv_args, from_node=node, inherit_qparams=True + ) # The convolution is inserted after quantization, so we need to set our # own quantization parameters for the weights here. 
However since the @@ -120,12 +137,20 @@ def call(self, graph_module): slice_args = (conv_node, 2, 0, original_shape[dim]) slice_node = create_node( - graph, slice_op, args=slice_args, from_node=node + graph, + slice_op, + args=slice_args, + from_node=node, + inherit_qparams=False, ) view_original_args = (slice_node, original_shape) view_original_node = create_node( - graph, view_op, args=view_original_args, from_node=node + graph, + view_op, + args=view_original_args, + from_node=node, + inherit_qparams=False, ) # Replace and remove original diff --git a/backends/arm/_passes/decompose_div_pass.py b/backends/arm/_passes/decompose_div_pass.py index 893531dac69..c1878e6ce0c 100644 --- a/backends/arm/_passes/decompose_div_pass.py +++ b/backends/arm/_passes/decompose_div_pass.py @@ -1,12 +1,15 @@ -# Copyright 2024 Arm Limited and/or its affiliates. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe + +from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -26,7 +29,7 @@ def get_div_decomposition(op) -> tuple: raise RuntimeError(f"Can't get div decomposition for op {op}") -class DecomposeDivPass(ExportPass): +class DecomposeDivPass(ArmPass): """ This pass decomposes div into a mul and a reciprocal node. @@ -37,6 +40,8 @@ class DecomposeDivPass(ExportPass): y = mul(a,x) """ + _passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass} + def call_operator(self, op, args, kwargs, meta): if op not in (edge_div_ops + aten_div_ops): return super().call_operator(op, args, kwargs, meta) @@ -45,6 +50,10 @@ def call_operator(self, op, args, kwargs, meta): numerator = args[0] denominator = args[1] - reciprocal = super().call_operator(reciprocal_op, (denominator,), {}, meta) + reciprocal = super().call_operator( + reciprocal_op, (denominator,), {}, meta, updated=True + ) - return super().call_operator(mul_op, (numerator, reciprocal), {}, meta) + return super().call_operator( + mul_op, (numerator, reciprocal), {}, meta, updated=True + ) diff --git a/backends/arm/_passes/decompose_div_tensor_mode.py b/backends/arm/_passes/decompose_div_tensor_mode.py index 0e6b40afbb2..cb7ffbb33b8 100644 --- a/backends/arm/_passes/decompose_div_tensor_mode.py +++ b/backends/arm/_passes/decompose_div_tensor_mode.py @@ -3,9 +3,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
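[Editor's note, not part of the diff] The rewrite performed by DecomposeDivPass above is the plain algebraic identity a / b = a * (1 / b); a one-line standalone check:

```python
# a / b == a * reciprocal(b), the rewrite DecomposeDivPass applies.
import torch

a = torch.randn(4, 8)
b = torch.randn(4, 8).abs() + 0.1  # keep the denominator away from zero
torch.testing.assert_close(a * torch.reciprocal(b), a / b)
```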
-# pyre-unsafe + +from typing import Set, Type import torch +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -39,7 +42,7 @@ def _get_opset(op): raise RuntimeError(f"div.Tensor_mode not supported for op {op}") -class DecomposeDivTensorModePass(ExportPass): +class DecomposeDivTensorModePass(ArmPass): """ Rewrites aten.div.Tensor_mode into @@ -48,6 +51,8 @@ class DecomposeDivTensorModePass(ExportPass): rounding_mode='trunc' -> where(div(a,b) < 0, ceil(div(a,b)), floor(div(a,b))) """ + _passes_required_after: Set[Type[ExportPass]] = {DecomposeDivPass} + def call_operator(self, op, args, kwargs, meta): if op not in (edge_div_mode_ops + aten_div_mode_ops): return super().call_operator(op, args, kwargs, meta) @@ -59,13 +64,13 @@ def call_operator(self, op, args, kwargs, meta): if rounding_mode is None and len(args) > 2: rounding_mode = args[2] - q = super().call_operator(opset["div"], (a, b), {}, meta) + q = super().call_operator(opset["div"], (a, b), {}, meta, updated=True) if rounding_mode is None: return q if rounding_mode == "floor": - return super().call_operator(opset["floor"], (q,), {}, meta) + return super().call_operator(opset["floor"], (q,), {}, meta, updated=True) if rounding_mode == "trunc": zero = super().call_operator( @@ -73,11 +78,14 @@ def call_operator(self, op, args, kwargs, meta): args=((1,) * len(meta["val"].size()), 0.0), kwargs={"dtype": torch.float32}, meta=meta, + updated=True, + ) + lt0 = super().call_operator(opset["lt"], (q, zero), {}, meta, updated=True) + ceilq = super().call_operator(opset["ceil"], (q,), {}, meta, updated=True) + floorq = super().call_operator(opset["floor"], (q,), {}, meta, updated=True) + return super().call_operator( + opset["where"], (lt0, ceilq, floorq), {}, meta, updated=True ) - lt0 = self.call_operator(opset["lt"], (q, zero), {}, meta) - ceilq = self.call_operator(opset["ceil"], (q,), {}, meta) - floorq = self.call_operator(opset["floor"], (q,), {}, meta) - return self.call_operator(opset["where"], (lt0, ceilq, floorq), {}, meta) raise RuntimeError( f"Unsupported rounding_mode for div.Tensor_mode: {rounding_mode!r}" diff --git a/backends/arm/_passes/decompose_elu_pass.py b/backends/arm/_passes/decompose_elu_pass.py index 743f1b46f4d..5428465c619 100644 --- a/backends/arm/_passes/decompose_elu_pass.py +++ b/backends/arm/_passes/decompose_elu_pass.py @@ -3,8 +3,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
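[Editor's note, not part of the diff] The "trunc" branch of DecomposeDivTensorModePass above uses the fact that truncation toward zero is ceil for negative quotients and floor otherwise; a standalone check with public torch ops:

```python
# trunc rounding == where(q < 0, ceil(q), floor(q)), the rewrite used for
# div.Tensor_mode with rounding_mode="trunc".
import torch

a = torch.randn(64) * 10
b = torch.randn(64).abs() + 0.5
q = a / b
expected = torch.div(a, b, rounding_mode="trunc")
rewritten = torch.where(q < 0, torch.ceil(q), torch.floor(q))
torch.testing.assert_close(rewritten, expected)
```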
+from typing import Set, Type + from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass edge_elu_ops = (exir_ops.edge.aten.elu.default,) @@ -55,10 +58,20 @@ class DecomposeEluPass(ArmPass): - exir_ops.edge.aten.mul.Scalar """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in edge_elu_ops: return super().call_operator(op, args, kwargs, meta, updated=False) + is_quantized = ( + len(meta.data.get("input_qparams", {})) > 0 + and len(meta.data.get("output_qparams", {})) > 0 + ) + if is_quantized: + # If quantized, node should be replace by table op + return super().call_operator(op, args, kwargs, meta) + ( expm1_op, ge_op, @@ -70,8 +83,17 @@ def call_operator(self, op, args, kwargs, meta): alpha = args[1] if len(args) > 1 else 1.0 if alpha == 0: - relu_op = exir_ops.edge.aten.relu.default - return super().call_operator(relu_op, (input,), {}, meta, updated=True) + relu_op = exir_ops.edge.aten.clamp.default + return super().call_operator( + relu_op, + ( + input, + 0, + ), + {}, + meta, + updated=True, + ) expm1_node = super().call_operator(expm1_op, (input,), {}, meta, updated=True) mul_node = super().call_operator( diff --git a/backends/arm/_passes/decompose_embedding_pass.py b/backends/arm/_passes/decompose_embedding_pass.py index 6de971f402f..e9c8f303cbf 100644 --- a/backends/arm/_passes/decompose_embedding_pass.py +++ b/backends/arm/_passes/decompose_embedding_pass.py @@ -3,23 +3,25 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe - import logging from math import prod +from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.fuse_view_copy_transform_pass import ( + FuseViewCopyTransformPass, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from .arm_pass_utils import create_node, get_first_fake_tensor logger = logging.getLogger(__name__) -logger.setLevel(logging.WARNING) -class DecomposeEmbeddingPass(ExportPass): +class DecomposeEmbeddingPass(ArmPass): """ This pass decomposes embedding into index_select. @@ -33,13 +35,15 @@ class DecomposeEmbeddingPass(ExportPass): i = indices is expected to be int32 before this pass """ + _passes_required_after: Set[Type[ExportPass]] = {FuseViewCopyTransformPass} + aten_ops = (torch.ops.aten.embedding.default,) edge_ops = (exir_ops.edge.aten.embedding.default,) def get_decomposition(self, op): if op in self.aten_ops: return ( - torch.ops.aten.view_copy.default, + torch.ops.aten.reshape.default, torch.ops.aten.index_select.default, ) diff --git a/backends/arm/_passes/decompose_expm1_pass.py b/backends/arm/_passes/decompose_expm1_pass.py index 5b1b90495b5..d2eb908e925 100644 --- a/backends/arm/_passes/decompose_expm1_pass.py +++ b/backends/arm/_passes/decompose_expm1_pass.py @@ -3,8 +3,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
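[Editor's note, not part of the diff] DecomposeEluPass above builds on the textbook definition elu(x) = x for x >= 0 and alpha * (exp(x) - 1) otherwise, which maps directly onto the expm1/ge/mul/where ops it lists; a standalone sketch:

```python
# elu(x) == where(x >= 0, x, alpha * expm1(x)), the definition an
# expm1/ge/mul/where decomposition relies on.
import torch
import torch.nn.functional as F

x = torch.linspace(-5.0, 5.0, steps=101)
alpha = 1.5
decomposed = torch.where(x >= 0, x, alpha * torch.expm1(x))
torch.testing.assert_close(decomposed, F.elu(x, alpha=alpha))
```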
+from typing import Set, Type + from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass +from executorch.backends.arm._passes.decompose_int_pow_pass import DecomposeIntPowPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass +from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass +from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass +from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( + ReplaceScalarWithTensorByProfilePass, +) from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass edge_expm1_ops = (exir_ops.edge.aten.expm1.default,) # MI case @@ -68,10 +79,27 @@ class DecomposeExpm1Pass(ArmPass): - exir_ops.edge.aten.logical_and.default """ + _passes_required_after: Set[Type[ExportPass]] = { + DecomposeIntPowPass, + InsertTableOpsPass, + DecomposeDivPass, + ReplaceScalarWithTensorByProfilePass, + MatchArgDtypePass, + MatchArgRanksPass, + } + def call_operator(self, op, args, kwargs, meta): if op not in edge_expm1_ops: return super().call_operator(op, args, kwargs, meta, updated=False) + is_quantized = ( + len(meta.data.get("input_qparams", {})) > 0 + and len(meta.data.get("output_qparams", {})) > 0 + ) + if is_quantized: + # If quantized, node should be replace by table op + return super().call_operator(op, args, kwargs, meta) + ( op_pow, op_div, diff --git a/backends/arm/_passes/decompose_floor_divide_pass.py b/backends/arm/_passes/decompose_floor_divide_pass.py new file mode 100644 index 00000000000..c2754f46a11 --- /dev/null +++ b/backends/arm/_passes/decompose_floor_divide_pass.py @@ -0,0 +1,75 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Set, Type + +import torch +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.decompose_div_tensor_mode import ( + DecomposeDivTensorModePass, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + +edge_floor_divide_ops = (exir_ops.edge.aten.floor_divide.default,) +aten_floor_divide_ops = (torch.ops.aten.floor_divide.default,) + + +def get_floor_divide_decomposition(op) -> tuple: + """ + Returns the decomposition of the given aten.floor_div operation into + its equivalent TOSA-supported operations + + This handles both edge dialect ops and core PyTorch ops. The decomposition strategy + is: + floor_div(x, y) → div_tensor_mode(x, y, rounding_mode="floor") + + Returns: + A tuple (div_op,) corresponding to the appropriate operator overload for the input op. + + Raises: + RuntimeError: If the provided operator is not a supported floor_divide variant. + """ + + if op in edge_floor_divide_ops: + return ( + exir_ops.edge.aten.div.Tensor_mode, + exir_ops.edge.aten.full_like.default, + ) + if op in aten_floor_divide_ops: + return ( + torch.ops.aten.div.Tensor_mode, + torch.ops.aten.full_like.default, + ) + + raise RuntimeError(f"Can't get floor_div decomposition for op {op}") + + +class DecomposeFloorDividePass(ArmPass): + """ + Decomposes aten.floor_divide into aten.div.Tensor_mode with rounding_mode="floor". 
+ """ + + _passes_required_after: Set[Type[ExportPass]] = {DecomposeDivTensorModePass} + + def call_operator(self, op, args, kwargs, meta): + if op not in (edge_floor_divide_ops + aten_floor_divide_ops): + return super().call_operator(op, args, kwargs, meta, updated=False) + + (div_op, full_op) = get_floor_divide_decomposition(op) + + input = args[0] + other = args[1] + + if isinstance(other, int): + other = super().call_operator( + full_op, (input, other), {}, meta, updated=False + ) + + div_node = super().call_operator( + div_op, (input, other), {"rounding_mode": "floor"}, meta, updated=True + ) + + return div_node diff --git a/backends/arm/_passes/decompose_gelu_pass.py b/backends/arm/_passes/decompose_gelu_pass.py index 6e72175e68b..5bf39370835 100644 --- a/backends/arm/_passes/decompose_gelu_pass.py +++ b/backends/arm/_passes/decompose_gelu_pass.py @@ -3,8 +3,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass +from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass +from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -37,7 +46,7 @@ def _get_gelu_ops(op) -> tuple: raise RuntimeError(f"Can't get GeLU decomposition ops for op {op}") -class DecomposeGeluPass(ExportPass): +class DecomposeGeluPass(ArmPass): """ This pass decomposes the GELU operator into primitive ops. Aiming to adhere closely to the reference implementations built into @@ -77,9 +86,23 @@ class DecomposeGeluPass(ExportPass): %op7 = mul(%op6, %FULL_0_5) """ + _passes_required_after: Set[Type[ExportPass]] = { + ComputeConstantOpsAOTPass, + InsertTableOpsPass, + MatchArgDtypePass, + MatchArgRanksPass, + } + def call_operator(self, op, args, kwargs, meta): if op not in torch_gelu + edge_gelu: return super().call_operator(op, args, kwargs, meta) + is_quantized = ( + len(meta.data.get("input_qparams", {})) > 0 + and len(meta.data.get("output_qparams", {})) > 0 + ) + if is_quantized: + # If quantized, node should be replace by table op + return super().call_operator(op, args, kwargs, meta) full_op, add_op, mul_op, tanh_op, erf_op = _get_gelu_ops(op) diff --git a/backends/arm/_passes/decompose_glu_pass.py b/backends/arm/_passes/decompose_glu_pass.py index 183dc89cf61..373b31c5995 100644 --- a/backends/arm/_passes/decompose_glu_pass.py +++ b/backends/arm/_passes/decompose_glu_pass.py @@ -3,9 +3,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For FP case @@ -36,6 +40,8 @@ def get_ops(op): class DecomposeGluPass(ArmPass): """Decomposes the GLU operator into hadamard product and sigmoid.""" + _passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass} + def call_operator(self, op, args, kwargs, meta): if op not in [edge_glu, aten_glu]: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_grouped_conv.py b/backends/arm/_passes/decompose_grouped_conv.py deleted file mode 100644 index ce9fe9c9937..00000000000 --- a/backends/arm/_passes/decompose_grouped_conv.py +++ /dev/null @@ -1,170 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from copy import copy - -import torch -from executorch.backends.arm._passes.quant_args import QuantArgs -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass - - -class DecomposeGroupedConv(ExportPass): - """ - Splits a grouped convolution which is not supported by TOSA into multiple - convolutions using slice->conv->cat. - - Before pass: - x = conv(input, weight, bias, groups = 2) - - After pass: - input1 = slice(input) - weight1 = slice(weight) - bias1 = slice(bias) - x1 = conv(input1, weight1, bias1) - - input2 = slice(input) - weight2 = slice(weight) - bias2 = slice(bias) - x2 = conv(input2, weight2, bias2) - - x = cat(x1, x2) - """ - - @staticmethod - def _get_decomposition(op): - match op: - case exir_ops.edge.aten.convolution.default: - return ( - exir_ops.edge.aten.slice_copy.Tensor, - exir_ops.edge.aten.convolution.default, - exir_ops.edge.aten.cat.default, - ) - case torch.ops.aten.conv2d.default: - return ( - torch.ops.aten.slice_copy.Tensor, - torch.ops.aten.conv2d.default, - torch.ops.aten.cat.default, - ) - case _: - raise RuntimeError("Invalid op for grouped conv decomposition") - - @staticmethod - def _split_per_channel_qparams(qarg, index, output_slice_size): - if qarg is not None and qarg.per_channel: - start_index = index * output_slice_size - stop_index = (index + 1) * output_slice_size - return QuantArgs( - scale=qarg.scale[start_index:stop_index], - zp=qarg.zp[start_index:stop_index], - qmin=qarg.qmin, - qmax=qarg.qmax, - dtype=qarg.dtype, - axis=qarg.axis, - per_channel=qarg.per_channel, - ) - return qarg - - @staticmethod - def _get_meta_copy(meta, i, output_slice_size): - meta_copy = meta.copy() - if "input_qparams" in meta.data and len(meta.data["input_qparams"]) > 0: - # Handle per-channel quantization by splitting quantization params - # similarly to how activations/weights/biases are split. - new_qparams = meta.data.get("input_qparams").copy() - # Get quantization params of the weights and slice them. 
- qarg = new_qparams[1] - new_qparams[1] = DecomposeGroupedConv._split_per_channel_qparams( - qarg, index=i, output_slice_size=output_slice_size - ) - - meta_copy.data["input_qparams"] = new_qparams - - return meta_copy - - def call_operator(self, op, args, kwargs, meta): - if op == exir_ops.edge.aten.convolution.default: - groups = args[8] - transposed = args[6] - elif op == torch.ops.aten.conv2d.default: - groups = args[6] - transposed = False - else: - return super().call_operator(op, args, kwargs, meta) - - if groups == 1 or transposed: - return super().call_operator(op, args, kwargs, meta) - - input_node = args[0] - if input_node.data.shape[1] == groups: - # This is a depthwise convolution which is handled elsewhere - return super().call_operator(op, args, kwargs, meta) - - weight_node = args[1] - bias_node = args[2] - - input_slice_size = weight_node.data.shape[1] - output_slice_size = weight_node.data.shape[0] // groups - - no_q_dq_meta = copy(meta) - no_q_dq_meta.data = {} - no_q_dq_meta.data = {} - - slice_op, conv_op, cat_op = DecomposeGroupedConv._get_decomposition(op) - - input_slices = [] - for i in range(groups): - start_index = i * input_slice_size - stop_index = (i + 1) * input_slice_size - slice_args = (input_node, 1, start_index, stop_index) - - input_slices.append( - super().call_operator(slice_op, slice_args, kwargs, no_q_dq_meta) - ) - - filter_slices = [] - for i in range(groups): - start_index = i * output_slice_size - stop_index = (i + 1) * output_slice_size - slice_args = (weight_node, 0, start_index, stop_index) - - filter_slices.append( - super().call_operator(slice_op, slice_args, kwargs, no_q_dq_meta) - ) - - bias_slices = [] - for i in range(groups): - if bias_node is None: - bias_slices.append(None) - else: - start_index = i * output_slice_size - stop_index = (i + 1) * output_slice_size - slice_args = (bias_node, 0, start_index, stop_index) - - bias_slices.append( - super().call_operator(slice_op, slice_args, kwargs, no_q_dq_meta) - ) - - output_slices = [] - for i, (input_slice, filter_slice, bias_slice) in enumerate( - zip(input_slices, filter_slices, bias_slices) - ): - - meta_copy = DecomposeGroupedConv._get_meta_copy(meta, i, output_slice_size) - - if op == exir_ops.edge.aten.convolution.default: - conv_args = (input_slice, filter_slice, bias_slice, *args[3:8], 1) - elif op == torch.ops.aten.conv2d.default: - conv_args = (input_slice, filter_slice, bias_slice, *args[3:6], 1) - else: - raise RuntimeError("Invalid op for grouped conv decomposition") - - output_slices.append( - super().call_operator(conv_op, conv_args, kwargs, meta_copy) - ) - - cat_args = (output_slices, 1) - # propagate original metadata (including quantization params) to the concatenated output - return super().call_operator(cat_op, cat_args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_grouped_conv_pass.py b/backends/arm/_passes/decompose_grouped_conv_pass.py new file mode 100644 index 00000000000..a0765b865fc --- /dev/null +++ b/backends/arm/_passes/decompose_grouped_conv_pass.py @@ -0,0 +1,185 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from copy import copy +from typing import Set, Type + +import torch +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.conv1d_unsqueeze_pass import Conv1dUnsqueezePass +from executorch.backends.arm._passes.quant_args import QuantArgs +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + + +class DecomposeGroupedConvPass(ArmPass): + """ + Splits a grouped convolution which is not supported by TOSA into multiple + convolutions using slice->conv->cat. + + Before pass: + x = conv(input, weight, bias, groups = 2) + + After pass: + input1 = slice(input) + weight1 = slice(weight) + bias1 = slice(bias) + x1 = conv(input1, weight1, bias1) + + input2 = slice(input) + weight2 = slice(weight) + bias2 = slice(bias) + x2 = conv(input2, weight2, bias2) + + x = cat(x1, x2) + """ + + _passes_required_after: Set[Type[ExportPass]] = {Conv1dUnsqueezePass} + + @staticmethod + def _get_decomposition(op): + match op: + case exir_ops.edge.aten.convolution.default: + return ( + exir_ops.edge.aten.slice_copy.Tensor, + exir_ops.edge.aten.convolution.default, + exir_ops.edge.aten.cat.default, + ) + case torch.ops.aten.conv2d.default: + return ( + torch.ops.aten.slice_copy.Tensor, + torch.ops.aten.conv2d.default, + torch.ops.aten.cat.default, + ) + case _: + raise RuntimeError("Invalid op for grouped conv decomposition") + + @staticmethod + def _split_per_channel_qparams(qarg, index, output_slice_size): + if qarg is not None and qarg.per_channel: + start_index = index * output_slice_size + stop_index = (index + 1) * output_slice_size + return QuantArgs( + scale=qarg.scale[start_index:stop_index], + zp=qarg.zp[start_index:stop_index], + qmin=qarg.qmin, + qmax=qarg.qmax, + dtype=qarg.dtype, + axis=qarg.axis, + per_channel=qarg.per_channel, + ) + return qarg + + @staticmethod + def _get_meta_copy(meta, i, output_slice_size): + meta_copy = meta.copy() + if "input_qparams" in meta.data and len(meta.data["input_qparams"]) > 0: + # Handle per-channel quantization by splitting quantization params + # similarly to how activations/weights/biases are split. + new_qparams = meta.data.get("input_qparams").copy() + # Get quantization params of the weights and slice them. 
+ qarg = new_qparams[1] + new_qparams[1] = DecomposeGroupedConvPass._split_per_channel_qparams( + qarg, index=i, output_slice_size=output_slice_size + ) + + meta_copy.data["input_qparams"] = new_qparams + + return meta_copy + + def call_operator(self, op, args, kwargs, meta): + if op == exir_ops.edge.aten.convolution.default: + groups = args[8] + transposed = args[6] + elif op == torch.ops.aten.conv2d.default: + groups = args[6] + transposed = False + else: + return super().call_operator(op, args, kwargs, meta) + + if groups == 1 or transposed: + return super().call_operator(op, args, kwargs, meta) + + input_node = args[0] + if input_node.data.shape[1] == groups: + # This is a depthwise convolution which is handled elsewhere + return super().call_operator(op, args, kwargs, meta) + + weight_node = args[1] + bias_node = args[2] + + input_slice_size = weight_node.data.shape[1] + output_slice_size = weight_node.data.shape[0] // groups + + no_q_dq_meta = copy(meta) + no_q_dq_meta.data = {} + no_q_dq_meta.data = {} + + slice_op, conv_op, cat_op = DecomposeGroupedConvPass._get_decomposition(op) + + input_slices = [] + for i in range(groups): + start_index = i * input_slice_size + stop_index = (i + 1) * input_slice_size + slice_args = (input_node, 1, start_index, stop_index) + + input_slices.append( + super().call_operator( + slice_op, slice_args, kwargs, no_q_dq_meta, updated=True + ) + ) + + filter_slices = [] + for i in range(groups): + start_index = i * output_slice_size + stop_index = (i + 1) * output_slice_size + slice_args = (weight_node, 0, start_index, stop_index) + + filter_slices.append( + super().call_operator( + slice_op, slice_args, kwargs, no_q_dq_meta, updated=True + ) + ) + + bias_slices = [] + for i in range(groups): + if bias_node is None: + bias_slices.append(None) + else: + start_index = i * output_slice_size + stop_index = (i + 1) * output_slice_size + slice_args = (bias_node, 0, start_index, stop_index) + + bias_slices.append( + super().call_operator( + slice_op, slice_args, kwargs, no_q_dq_meta, updated=True + ) + ) + + output_slices = [] + for i, (input_slice, filter_slice, bias_slice) in enumerate( + zip(input_slices, filter_slices, bias_slices) + ): + + meta_copy = DecomposeGroupedConvPass._get_meta_copy( + meta, i, output_slice_size + ) + + if op == exir_ops.edge.aten.convolution.default: + conv_args = (input_slice, filter_slice, bias_slice, *args[3:8], 1) + elif op == torch.ops.aten.conv2d.default: + conv_args = (input_slice, filter_slice, bias_slice, *args[3:6], 1) + else: + raise RuntimeError("Invalid op for grouped conv decomposition") + + output_slices.append( + super().call_operator( + conv_op, conv_args, kwargs, meta_copy, updated=True + ) + ) + + cat_args = (output_slices, 1) + # propagate original metadata (including quantization params) to the concatenated output + return super().call_operator(cat_op, cat_args, kwargs, meta, updated=True) diff --git a/backends/arm/_passes/decompose_groupnorm_pass.py b/backends/arm/_passes/decompose_groupnorm_pass.py index c6cb1b05e40..ecd4ecc23a4 100644 --- a/backends/arm/_passes/decompose_groupnorm_pass.py +++ b/backends/arm/_passes/decompose_groupnorm_pass.py @@ -3,15 +3,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
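[Editor's note, not part of the diff] The slice->conv->cat rewrite in DecomposeGroupedConvPass is the standard way to express a grouped convolution with group-count-1 convolutions; a minimal float64 check with arbitrarily chosen shapes:

```python
# conv2d(x, w, b, groups=2) == cat of per-group convolutions on sliced
# inputs, weights and biases, the rewrite the pass performs.
import torch
import torch.nn.functional as F

groups, in_ch, out_ch = 2, 4, 6
x = torch.randn(1, in_ch, 8, 8, dtype=torch.double)
w = torch.randn(out_ch, in_ch // groups, 3, 3, dtype=torch.double)
b = torch.randn(out_ch, dtype=torch.double)

grouped = F.conv2d(x, w, b, groups=groups)

in_slice, out_slice = in_ch // groups, out_ch // groups
parts = [
    F.conv2d(
        x[:, g * in_slice : (g + 1) * in_slice],
        w[g * out_slice : (g + 1) * out_slice],
        b[g * out_slice : (g + 1) * out_slice],
    )
    for g in range(groups)
]
torch.testing.assert_close(torch.cat(parts, dim=1), grouped)
```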
-# pyre-unsafe import operator +from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass +from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass +from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import ExportPass, PassResult def get_group_norm_decomposition(op) -> tuple: @@ -35,7 +39,7 @@ def get_group_norm_decomposition(op) -> tuple: torch.ops.aten.add.Tensor, torch.ops.aten.rsqrt.default, torch.ops.aten.mul.Tensor, - torch.ops.aten.view_copy.default, + torch.ops.aten.reshape.default, ) raise RuntimeError(f"Can't get group_norm composition for op {op}") @@ -57,6 +61,13 @@ class DecomposeGroupNormPass(ArmPass): Source: https://pytorch.org/docs/stable/generated/torch.nn.GroupNorm.html """ + _passes_required_after: Set[Type[ExportPass]] = { + InsertTableOpsPass, + DecomposeMeanDimPass, + DecomposeVarPass, + SizeAdjustInputPass, + } + def call(self, graph_module: torch.fx.GraphModule): modified = False for node in graph_module.graph.nodes: diff --git a/backends/arm/_passes/decompose_int16_activation_conv_pass.py b/backends/arm/_passes/decompose_int16_activation_conv_pass.py new file mode 100644 index 00000000000..0a8c5eea2b2 --- /dev/null +++ b/backends/arm/_passes/decompose_int16_activation_conv_pass.py @@ -0,0 +1,137 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import cast, Sequence, Set, Type + +import torch +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.quant_args import QuantArgs + +from executorch.backends.arm.tosa.specification import get_context_spec +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + + +class DecomposeConvWithInt16ActivationPass(ArmPass): + """ + This pass decomposes a convolution with input dtype int16 and bias + into a convolution without bias followed by an addition of the bias. + We also reshape the 1D bias to [1, C, 1, …] so it broadcasts along the channel + dimension. Since the TOSA op requires the bias to be int48 which is hard to represent + in torch. Instead rescale the int48 output to int16 and add the bias in int16. 
+ """ + + def __init__(self) -> None: + super().__init__() + + _passes_required_after: Set[Type[ExportPass]] = set() + + def bias_view_shape( + self, bias: torch.Tensor, activation_rank: int + ) -> Sequence[int]: + # reshape bias to match convolution output rank so addition broadcasts over channels + return [1, bias.shape[0], *([1] * (activation_rank - 2))] + + def call_operator(self, op, args, kwargs, meta): + if op != exir_ops.edge.aten.convolution.default: + return super().call_operator(op, args, kwargs, meta) + + tosa_spec = get_context_spec() + if not tosa_spec.support_integer(): + return super().call_operator(op, args, kwargs, meta) + + # return if no bias + if args[2] is None: + return super().call_operator(op, args, kwargs, meta) + + activation_tensor = args[0].data + activation_rank = activation_tensor.dim() + + if activation_rank not in (4, 5) or activation_tensor.dtype != torch.int16: + return super().call_operator(op, args, kwargs, meta) + + if not tosa_spec.support_extension("int16"): + raise ValueError( + "int16 activation for convolution requires TOSA int16 extension" + ) + + # convolution with bias and activation is int16 (expected activation rank enforced above) + # The bias is assumed to be quantized with the same quantization parameters as + # the output of the convolution + bias_arg = args[2] + bias_data = bias_arg.data + + no_bias_args = list(args) + no_bias_args[2] = None + # split up to convolution + bias + convolution = super().call_operator(op, tuple(no_bias_args), kwargs, meta) + + # create a copy of the meta without the qparams, to be used with the new nodes + new_meta = meta.copy() + new_meta.data.pop("output_qparams", None) + new_meta.data.pop("input_qparams", None) + + # reshape the tensor to the same rank as the convolution output to add the bias to the channels + channel_bias = super().call_operator( + exir_ops.edge.aten.view_copy.default, + (bias_arg, self.bias_view_shape(bias_data, activation_rank)), + {}, + new_meta, + ) + + output_dtype = meta.data["output_qparams"][0].dtype + + if output_dtype == torch.int16: + # The conv will get the output int48 scaled to int32 in serialization step. + # To be able to add the bias we need to first scale (cast?) the output to int32. + # The resulting i32 sum will then need to be scaled back to the output dtype. 
+ output_qparams = cast(QuantArgs, meta.data["output_qparams"][0]) + conv_output_scale = output_qparams.scale + + bias_qparams = cast(QuantArgs, meta.data["input_qparams"][2]) + per_channel_quant = bias_qparams.per_channel + + if per_channel_quant: + bias_scale = bias_qparams.get_scale_per_channel() + else: + bias_scale = [bias_qparams.get_scale_per_tensor()] + + conv_rescale_factors = [1.0] * len(bias_scale) + final_output_scale = [b / conv_output_scale for b in bias_scale] + + conv_output = super().call_operator( + exir_ops.backend.tosa.RESCALE.default, + (convolution, torch.int32, conv_rescale_factors, 0, 0), + {}, + new_meta, + ) + + add = super().call_operator( + exir_ops.edge.aten.add.Tensor, + (conv_output, channel_bias), + {}, + new_meta, + ) + + res_rescale = super().call_operator( + exir_ops.backend.tosa.RESCALE.default, + ( + add, + output_dtype, + final_output_scale, + 0, + 0, + ), + {}, + new_meta, + ) + + else: + raise NotImplementedError( + f"Decomposition to conv+add only implemented for activation of int16 type, not for {output_dtype}" + ) + + return res_rescale diff --git a/backends/arm/_passes/decompose_int_pow_pass.py b/backends/arm/_passes/decompose_int_pow_pass.py new file mode 100644 index 00000000000..4db5e45c120 --- /dev/null +++ b/backends/arm/_passes/decompose_int_pow_pass.py @@ -0,0 +1,64 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Set, Type + +from executorch.backends.arm._passes import ArmPass +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + + +class DecomposeIntPowPass(ArmPass): + """ + Replaces pow with integer exponent with a series of multiplications. + Only handles pow.Tensor_Scalar and not pow.Tensor_Tensor. + Needs to be run before doing scalar to tensor conversion. + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + def call_operator(self, op, args, kwargs, meta): + if op != exir_ops.edge.aten.pow.Tensor_Scalar: + return super().call_operator(op, args, kwargs, meta) + + is_quantized = ( + len(meta.data.get("input_qparams", {})) > 0 + and len(meta.data.get("output_qparams", {})) > 0 + ) + if is_quantized: + # If quantized, node should be replace by table op + return super().call_operator(op, args, kwargs, meta) + + x = args[0] + exp = args[1] + + # Handle zero first and return early + if exp == 0: + # return a tensor of ones with the same shape as x + return super().call_operator( + exir_ops.edge.aten.full_like.default, (x, 1), {}, meta, True + ) + + if not isinstance(exp, int): + return super().call_operator(op, args, kwargs, meta) + + # Handle negative exponent + if exp < 0: + x = super().call_operator( + exir_ops.edge.aten.reciprocal.default, (x,), {}, meta, True + ) + exp = -exp + + res = x + + # Consider exponentiation by squaring, if exp turns out to be large. + # Now we just roll out the multiplications. 
+ for _ in range(exp - 1): + res = super().call_operator( + exir_ops.edge.aten.mul.Tensor, (res, x), {}, meta, True + ) + + return res diff --git a/backends/arm/_passes/decompose_layernorm_pass.py b/backends/arm/_passes/decompose_layernorm_pass.py index e6cbdfb91a0..5f56de92512 100644 --- a/backends/arm/_passes/decompose_layernorm_pass.py +++ b/backends/arm/_passes/decompose_layernorm_pass.py @@ -3,15 +3,21 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe import operator +from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass +from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import ExportPass, PassResult def get_layer_norm_decomposition(op) -> tuple: @@ -35,7 +41,7 @@ def get_layer_norm_decomposition(op) -> tuple: torch.ops.aten.add.Tensor, torch.ops.aten.rsqrt.default, torch.ops.aten.mul.Tensor, - torch.ops.aten.view_copy.default, + torch.ops.aten.reshape.default, ) raise RuntimeError(f"Can't get layer_norm composition for op {op}") @@ -56,6 +62,13 @@ class DecomposeLayerNormPass(ArmPass): Source: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html """ + _passes_required_after: Set[Type[ExportPass]] = { + ComputeConstantOpsAOTPass, + DecomposeMeanDimPass, + DecomposeVarPass, + InsertTableOpsPass, + } + def call(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: if node.op != "call_function" or node.target not in ( diff --git a/backends/arm/_passes/decompose_leaky_relu_pass.py b/backends/arm/_passes/decompose_leaky_relu_pass.py index e896cc584be..61cf8d4138b 100644 --- a/backends/arm/_passes/decompose_leaky_relu_pass.py +++ b/backends/arm/_passes/decompose_leaky_relu_pass.py @@ -4,11 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe + +from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass edge_ops = (exir_ops.edge.aten.leaky_relu.default,) torch_ops = (torch.ops.aten.leaky_relu.default,) @@ -46,6 +48,8 @@ class DecomposeLeakyReLUPass(ArmPass): %op5 = add(%op1,%op4) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in (edge_ops + torch_ops): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_linalg_vector_norm_pass.py b/backends/arm/_passes/decompose_linalg_vector_norm_pass.py index 9f036c0524f..83bbc6669ef 100644 --- a/backends/arm/_passes/decompose_linalg_vector_norm_pass.py +++ b/backends/arm/_passes/decompose_linalg_vector_norm_pass.py @@ -3,11 +3,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
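[Editor's note, not part of the diff] The roll-out performed by DecomposeIntPowPass above (ones for exponent zero, reciprocal for negative exponents, repeated multiplication otherwise) can be checked against torch.pow directly; the helper name int_pow below is mine, not from the pass:

```python
# Integer pow as repeated multiplication, mirroring the DecomposeIntPowPass
# handling of zero and negative exponents.
import torch


def int_pow(x: torch.Tensor, exp: int) -> torch.Tensor:
    if exp == 0:
        return torch.full_like(x, 1)  # x ** 0 == 1 elementwise
    if exp < 0:
        x, exp = torch.reciprocal(x), -exp
    res = x
    for _ in range(exp - 1):
        res = res * x
    return res


x = torch.linspace(0.5, 2.0, steps=16, dtype=torch.double)
for exp in (-3, -1, 0, 1, 2, 5):
    torch.testing.assert_close(int_pow(x, exp), torch.pow(x, exp))
```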
+from typing import Set, Type + import torch +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.decompose_sqrt_pass import DecomposeSqrtPass +from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass from executorch.exir.pass_base import ExportPass -class DecomposeLinearVectorNormPass(ExportPass): +class DecomposeLinalgVectorNormPass(ArmPass): """ This pass decomposes aten.linalg_vector_norm.default into more primitive ops. We need to add this pass before quantization for graph annotation. @@ -28,6 +33,11 @@ class DecomposeLinearVectorNormPass(ExportPass): dtype prior, but we dont know this from FX graph. """ + _passes_required_after: Set[Type[ExportPass]] = { + DecomposeSqrtPass, + DecomposeSumPass, + } + torch_linalg_vector_norm = (torch.ops.aten.linalg_vector_norm.default,) def call_operator(self, op, args, kwargs, meta): diff --git a/backends/arm/_passes/decompose_linear_pass.py b/backends/arm/_passes/decompose_linear_pass.py index 3d154d9b81e..e1a9cfd0bfc 100644 --- a/backends/arm/_passes/decompose_linear_pass.py +++ b/backends/arm/_passes/decompose_linear_pass.py @@ -3,7 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe + +from typing import Set, Type import numpy as np from executorch.backends.arm._passes import ArmPass @@ -11,8 +12,9 @@ create_node, get_first_fake_tensor, ) +from executorch.backends.arm._passes.insert_rescales_pass import InsertRescaleInt32Pass from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import ExportPass, PassResult class DecomposeLinearPass(ArmPass): @@ -25,6 +27,8 @@ class DecomposeLinearPass(ArmPass): output = view(conv2d) """ + _passes_required_after: Set[Type[ExportPass]] = {InsertRescaleInt32Pass} + def call(self, graph_module): for node in graph_module.graph.nodes: if node.op != "call_function": @@ -51,6 +55,8 @@ def call(self, graph_module): op_target=exir_ops.edge.aten.view_copy.default, args=(input, input_reshaped_shape), kwargs={}, + from_node=node, + inherit_qparams=False, ) # Reshape weights to 4D with shape (Co, Ci, 1, 1) @@ -59,6 +65,8 @@ def call(self, graph_module): op_target=exir_ops.edge.aten.view_copy.default, args=(weights, weights_reshaped_shape), kwargs={}, + from_node=node, + inherit_qparams=False, ) conv = create_node( @@ -77,6 +85,7 @@ def call(self, graph_module): ), kwargs={}, from_node=node, + inherit_qparams=True, ) with graph_module.graph.inserting_after(conv): @@ -89,14 +98,8 @@ def call(self, graph_module): args=(conv, list(output_shape)), kwargs={}, from_node=node, + inherit_qparams=False, ) - # Quantization parameters are inherited from original linear node, but - # output reshape should use the linear node's output qparams for both input - # and output. - if "input_qparams" in output.meta: - output.meta["input_qparams"] = output.meta.get( - "output_qparams", None - ) node.replace_all_uses_with(output) graph_module.graph.erase_node(node) diff --git a/backends/arm/_passes/decompose_logit_pass.py b/backends/arm/_passes/decompose_logit_pass.py index 40e2b22cb54..69a250b41cb 100644 --- a/backends/arm/_passes/decompose_logit_pass.py +++ b/backends/arm/_passes/decompose_logit_pass.py @@ -3,10 +3,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
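[Editor's note, not part of the diff] DecomposeLinearPass expresses a fully connected layer as a 1x1 convolution between two reshapes; the equivalence it relies on, sketched outside the pass machinery:

```python
# linear(x, W, b) == view(conv2d(view(x), view(W), b)) with 1x1 kernels,
# the shape rewrite DecomposeLinearPass performs.
import torch
import torch.nn.functional as F

batch, in_features, out_features = 4, 16, 32
x = torch.randn(batch, in_features, dtype=torch.double)
w = torch.randn(out_features, in_features, dtype=torch.double)
b = torch.randn(out_features, dtype=torch.double)

linear_out = F.linear(x, w, b)
conv_out = F.conv2d(
    x.view(batch, in_features, 1, 1),         # NCHW with H = W = 1
    w.view(out_features, in_features, 1, 1),  # Co x Ci x 1 x 1
    b,
).view(batch, out_features)
torch.testing.assert_close(conv_out, linear_out)
```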
+from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass +from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass +from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass +from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( + ReplaceScalarWithTensorByProfilePass, +) from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For FP case @@ -60,6 +69,13 @@ class DecomposeLogitPass(ArmPass): log(y * reciprocal((-1) * y + 1)) """ + _passes_required_after: Set[Type[ExportPass]] = { + InsertTableOpsPass, + MatchArgRanksPass, + MatchArgDtypePass, + ReplaceScalarWithTensorByProfilePass, + } + def call_operator(self, op, args, kwargs, meta): if op not in [edge_logit, aten_logit]: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_masked_fill.py b/backends/arm/_passes/decompose_masked_fill.py deleted file mode 100644 index fbf3079c92b..00000000000 --- a/backends/arm/_passes/decompose_masked_fill.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - - -import torch - -from executorch.backends.arm._passes import ArmPass -from executorch.exir.dialects._ops import ops as exir_ops - - -edge_ops = (exir_ops.edge.aten.masked_fill.Scalar,) -aten_ops = (torch.ops.aten.masked_fill.Scalar,) - - -def _get_decomposition(op) -> tuple: - if op in edge_ops: - return ( - exir_ops.edge.aten.where.self, - exir_ops.edge.aten.full_like.default, - ) - if op in aten_ops: - return ( - torch.ops.aten.where.self, - torch.ops.aten.full_like.default, - ) - raise RuntimeError(f"Unable to get decomposition for op {op}") - - -class DecomposeMaskedFill(ArmPass): - """ - Masked fill takes in a boolean mask, a tensor and a scalar value. - Fills the tensor with the scalar value according to the boolean mask. - Decomposed to a where and a full_like operator. - """ - - def call_operator(self, op, args, kwargs, meta, updated=False): - if op not in (edge_ops + aten_ops): - return super().call_operator(op, args, kwargs, meta, updated) - - x, mask, scalar = args - - where_op, full_like_op = _get_decomposition(op) - - scalar_tensor = super().call_operator(full_like_op, (x, scalar), {}, meta, True) - - return super().call_operator( - where_op, (mask, scalar_tensor, x), kwargs, meta, True - ) diff --git a/backends/arm/_passes/decompose_masked_fill_pass.py b/backends/arm/_passes/decompose_masked_fill_pass.py new file mode 100644 index 00000000000..49a4bbb9b4b --- /dev/null +++ b/backends/arm/_passes/decompose_masked_fill_pass.py @@ -0,0 +1,58 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
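[Editor's note, not part of the diff] DecomposeLogitPass above reduces to the identity logit(y) = log(y / (1 - y)), written with a reciprocal so it maps onto the available ops; a standalone check (torch.logit's optional eps clamping is not exercised here):

```python
# logit(y) == log(y * reciprocal(1 - y)), the identity behind DecomposeLogitPass.
import torch

y = torch.linspace(0.05, 0.95, steps=19)
decomposed = torch.log(y * torch.reciprocal(1.0 - y))
torch.testing.assert_close(decomposed, torch.logit(y))
```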
+ + +from typing import Set, Type + +import torch + +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.convert_full_like_to_full_pass import ( + ConvertFullLikeToFullPass, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + + +edge_ops = (exir_ops.edge.aten.masked_fill.Scalar,) +aten_ops = (torch.ops.aten.masked_fill.Scalar, torch.ops.aten.masked_fill_.Scalar) + + +def _get_decomposition(op) -> tuple: + if op in edge_ops: + return ( + exir_ops.edge.aten.where.self, + exir_ops.edge.aten.full_like.default, + ) + elif op in aten_ops: + return ( + torch.ops.aten.where.self, + torch.ops.aten.full_like.default, + ) + raise RuntimeError(f"Unable to get decomposition for op {op}") + + +class DecomposeMaskedFillPass(ArmPass): + """ + Masked fill takes in a boolean mask, a tensor and a scalar value. + Fills the tensor with the scalar value according to the boolean mask. + Decomposed to a where and a full_like operator. + """ + + _passes_required_after: Set[Type[ExportPass]] = {ConvertFullLikeToFullPass} + + def call_operator(self, op, args, kwargs, meta, updated=False): + if op not in (*aten_ops, *edge_ops): + return super().call_operator(op, args, kwargs, meta, updated) + + x, mask, scalar = args + + where_op, full_like_op = _get_decomposition(op) + + scalar_tensor = super().call_operator(full_like_op, (x, scalar), {}, meta, True) + + return super().call_operator( + where_op, (mask, scalar_tensor, x), kwargs, meta, True + ) diff --git a/backends/arm/_passes/decompose_maxpool2d_with_dilation.py b/backends/arm/_passes/decompose_maxpool2d_with_dilation.py deleted file mode 100644 index ff6db260099..00000000000 --- a/backends/arm/_passes/decompose_maxpool2d_with_dilation.py +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -import operator - -from executorch.backends.arm._passes import ArmPass -from executorch.exir.dialects._ops import ops as exir_ops - -# We'll decompose only the EXIR edge max_pool2d ops when dilation > 1 -EDGE_MAXPOOL2D = ( - exir_ops.edge.aten.max_pool2d.default, - exir_ops.edge.aten.max_pool2d_with_indices.default, -) - - -class DecomposeMaxPool2DPass(ArmPass): - """ - Decompose dilated max_pool2d (EXIR edge ops) into space-to-batch -> maxpool -> batch-to-space. 
- """ - - def call_operator(self, op, args, kwargs, meta): - # Only intercept EXIR edge max_pool2d ops - if op not in EDGE_MAXPOOL2D: - return super().call_operator(op, args, kwargs, meta) - - # detect whether indices variant - is_with_indices = op is exir_ops.edge.aten.max_pool2d_with_indices.default - - # Normalize missing trailing args to their defaults - x = args[0] - kernel_size = args[1] - stride = args[2] - padding = args[3] if len(args) >= 4 else 0 - dilation = args[4] if len(args) >= 5 else 1 - ceil_mode = args[5] if len(args) == 6 else False - - # Normalize attributes - pad_h, pad_w = (padding, padding) if isinstance(padding, int) else padding - d_h, d_w = (dilation, dilation) if isinstance(dilation, int) else dilation - k_h, k_w = ( - (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size - ) - s_h, s_w = (stride, stride) if isinstance(stride, int) else stride - - # If no dilation: call EXIR edge op - if d_h == 1 and d_w == 1: - minimal_args = [x, kernel_size, stride, padding, dilation, ceil_mode] - return super().call_operator(op, tuple(minimal_args), {}, meta) - - # Compute padded and packed dimensions for dilation > 1 - N, C, H, W = x.data.size() - ph, pw = pad_h, pad_w - ph2, pw2 = pad_h, pad_w - H_pad = H + ph + ph2 - W_pad = W + pw + pw2 - H_pack = (H_pad + d_h - 1) // d_h - W_pack = (W_pad + d_w - 1) // d_w - extra_h = 0 if H_pack < k_h else (s_h - ((H_pack - k_h) % s_h)) % s_h - extra_w = 0 if W_pack < k_w else (s_w - ((W_pack - k_w) % s_w)) % s_w - ph2 += extra_h * d_h - pw2 += extra_w * d_w - - # 1) Pad via EXIR edge pad (preserves dtype) - pad_edge = exir_ops.edge.aten.constant_pad_nd.default - pads = [pw, pw2, ph, ph2, 0, 0, 0, 0] - x_pad = super().call_operator( - pad_edge, - (x, pads, 0), - {}, - meta, - ) - - # 2) Space-to-batch: reshape and permute - x2 = super().call_operator( - exir_ops.edge.aten.view_copy.default, - (x_pad, [N, C, H_pack, d_h, W_pack, d_w]), - {}, - meta, - ) - x2 = super().call_operator( - exir_ops.edge.aten.permute_copy.default, - (x2, [3, 5, 0, 1, 2, 4]), - {}, - meta, - ) - x2 = super().call_operator( - exir_ops.edge.aten.view_copy.default, - (x2, [N * d_h * d_w, C, H_pack, W_pack]), - {}, - meta, - ) - - # 3) Core pooling on packed tensor - pool_edge_op = ( - exir_ops.edge.aten.max_pool2d_with_indices.default - if is_with_indices - else exir_ops.edge.aten.max_pool2d.default - ) - pool_args = (x2, (k_h, k_w), (s_h, s_w), (0, 0), 1, ceil_mode) - pool_out = super().call_operator( - pool_edge_op, - pool_args, - {}, - meta, - ) - - # Unpack pooled result - if is_with_indices: - pooled_proxy = super().call_operator( - operator.getitem, - (pool_out, 0), - {}, - meta, - ) - indices_proxy = super().call_operator( - operator.getitem, - (pool_out, 1), - {}, - meta, - ) - pooled_fake, _ = pool_out.data - else: - pooled_proxy = pool_out - pooled_fake = pool_out.data - indices_proxy = None - - _, C_out, H_out, W_out = pooled_fake.shape - - # 4) Batch-to-space: reshape and permute back - out = super().call_operator( - exir_ops.edge.aten.view_copy.default, - (pooled_proxy, [d_h, d_w, N, C_out, H_out, W_out]), - {}, - meta, - ) - out = super().call_operator( - exir_ops.edge.aten.permute_copy.default, - (out, [2, 3, 4, 0, 5, 1]), - {}, - meta, - ) - # now flatten back into (N, C, H_out*d_h, W_out*d_w) - out = super().call_operator( - exir_ops.edge.aten.view_copy.default, - (out, [N, C_out, H_out * d_h, W_out * d_w]), - {}, - meta, - ) - - # 5) Final crop - S_top = ph // d_h + (1 if ph % d_h else 0) - S_left = pw // d_w + (1 if pw % d_w 
else 0) - S_top = max(0, min(S_top, H_out * d_h - H)) - S_left = max(0, min(S_left, W_out * d_w - W)) - out = super().call_operator( - exir_ops.edge.aten.slice_copy.Tensor, - (out, 2, S_top, S_top + H), - {}, - meta, - ) - out = super().call_operator( - exir_ops.edge.aten.slice_copy.Tensor, - (out, 3, S_left, S_left + W), - {}, - meta, - ) - - if is_with_indices: - # Reconstruct indices - idx = super().call_operator( - exir_ops.edge.aten.view_copy.default, - (indices_proxy, [d_h, d_w, N, C_out, H_out, W_out]), - {}, - meta, - ) - idx = super().call_operator( - exir_ops.edge.aten.permute_copy.default, - (idx, [2, 3, 4, 0, 5, 1]), - {}, - meta, - ) - idx = super().call_operator( - exir_ops.edge.aten.view_copy.default, - (idx, [N, C_out, H_out * d_h, W_out * d_w]), - {}, - meta, - ) - idx = super().call_operator( - exir_ops.edge.aten.slice_copy.Tensor, - (idx, 2, S_top, S_top + H), - {}, - meta, - ) - idx = super().call_operator( - exir_ops.edge.aten.slice_copy.Tensor, - (idx, 3, S_left, S_left + W), - {}, - meta, - ) - return out, idx - - return out diff --git a/backends/arm/_passes/decompose_maxpool2d_with_dilation_pass.py b/backends/arm/_passes/decompose_maxpool2d_with_dilation_pass.py new file mode 100644 index 00000000000..bf3f6afc418 --- /dev/null +++ b/backends/arm/_passes/decompose_maxpool2d_with_dilation_pass.py @@ -0,0 +1,218 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import operator +from typing import Set, Type + +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + +# We'll decompose only the EXIR edge max_pool2d ops when dilation > 1 +EDGE_MAXPOOL2D = ( + exir_ops.edge.aten.max_pool2d.default, + exir_ops.edge.aten.max_pool2d_with_indices.default, +) + + +class DecomposeMaxPool2dPass(ArmPass): + """ + Decompose dilated max_pool2d (EXIR edge ops) into space-to-batch -> maxpool -> batch-to-space. 
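The space-to-batch rewrite described in this docstring can be sanity-checked for the simplest configuration (stride 1, no padding, sizes divisible by the dilation); the general case additionally needs the padding and cropping arithmetic handled further down in the pass. Shapes below are illustrative:

import torch
import torch.nn.functional as F

N, C, H, W = 1, 2, 8, 8
k, d = 2, 2                                   # 2x2 kernel, dilation 2, stride 1
x = torch.randn(N, C, H, W)

ref = F.max_pool2d(x, k, stride=1, dilation=d)

# space-to-batch: pack the d*d phase grids into the batch dimension
packed = (
    x.reshape(N, C, H // d, d, W // d, d)
     .permute(3, 5, 0, 1, 2, 4)
     .reshape(N * d * d, C, H // d, W // d)
)
pooled = F.max_pool2d(packed, k, stride=1)    # dilation-1 pool on each phase grid

# batch-to-space: interleave the per-phase outputs and crop to the reference size
Ho, Wo = pooled.shape[-2:]
out = (
    pooled.reshape(d, d, N, C, Ho, Wo)
          .permute(2, 3, 4, 0, 5, 1)
          .reshape(N, C, Ho * d, Wo * d)[..., : ref.shape[-2], : ref.shape[-1]]
)
assert torch.equal(out, ref)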
+ """ + + _passes_required_after: Set[Type[ExportPass]] = { + SizeAdjustInputPass, + } + + def call_operator(self, op, args, kwargs, meta): + # Only intercept EXIR edge max_pool2d ops + if op not in EDGE_MAXPOOL2D: + return super().call_operator(op, args, kwargs, meta) + + # detect whether indices variant + is_with_indices = op is exir_ops.edge.aten.max_pool2d_with_indices.default + + # Normalize missing trailing args to their defaults + x = args[0] + kernel_size = args[1] + stride = args[2] + padding = args[3] if len(args) >= 4 else 0 + dilation = args[4] if len(args) >= 5 else 1 + ceil_mode = args[5] if len(args) == 6 else False + + # Normalize attributes + pad_h, pad_w = (padding, padding) if isinstance(padding, int) else padding + d_h, d_w = (dilation, dilation) if isinstance(dilation, int) else dilation + k_h, k_w = ( + (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + ) + s_h, s_w = (stride, stride) if isinstance(stride, int) else stride + + # If no dilation: call EXIR edge op + if d_h == 1 and d_w == 1: + minimal_args = [x, kernel_size, stride, padding, dilation, ceil_mode] + return super().call_operator(op, tuple(minimal_args), {}, meta) + + # Compute padded and packed dimensions for dilation > 1 + N, C, H, W = x.data.size() + ph, pw = pad_h, pad_w + ph2, pw2 = pad_h, pad_w + H_pad = H + ph + ph2 + W_pad = W + pw + pw2 + H_pack = (H_pad + d_h - 1) // d_h + W_pack = (W_pad + d_w - 1) // d_w + extra_h = 0 if H_pack < k_h else (s_h - ((H_pack - k_h) % s_h)) % s_h + extra_w = 0 if W_pack < k_w else (s_w - ((W_pack - k_w) % s_w)) % s_w + ph2 += extra_h * d_h + pw2 += extra_w * d_w + + meta_with_no_qparams = meta.copy() + meta_with_no_qparams.data["output_qparams"] = {} + meta_with_no_qparams.data["input_qparams"] = {} + meta_with_no_output_qparams = meta.copy() + meta_with_no_output_qparams.data["output_qparams"] = {} + + # 1) Pad via EXIR edge pad (preserves dtype) + pad_edge = exir_ops.edge.aten.constant_pad_nd.default + pads = [pw, pw2, ph, ph2, 0, 0, 0, 0] + x_pad = super().call_operator( + pad_edge, + (x, pads, 0), + {}, + meta_with_no_output_qparams, + ) + + # 2) Space-to-batch: reshape and permute + x2 = super().call_operator( + exir_ops.edge.aten.view_copy.default, + (x_pad, [N, C, H_pack, d_h, W_pack, d_w]), + {}, + meta_with_no_qparams, + ) + x2 = super().call_operator( + exir_ops.edge.aten.permute_copy.default, + (x2, [3, 5, 0, 1, 2, 4]), + {}, + meta_with_no_qparams, + ) + x2 = super().call_operator( + exir_ops.edge.aten.view_copy.default, + (x2, [N * d_h * d_w, C, H_pack, W_pack]), + {}, + meta_with_no_qparams, + ) + + # 3) Core pooling on packed tensor + pool_edge_op = ( + exir_ops.edge.aten.max_pool2d_with_indices.default + if is_with_indices + else exir_ops.edge.aten.max_pool2d.default + ) + pool_args = (x2, (k_h, k_w), (s_h, s_w), (0, 0), 1, ceil_mode) + pool_out = super().call_operator( + pool_edge_op, + pool_args, + {}, + meta, + ) + + # Unpack pooled result + if is_with_indices: + pooled_proxy = super().call_operator( + operator.getitem, + (pool_out, 0), + {}, + meta_with_no_qparams, + ) + indices_proxy = super().call_operator( + operator.getitem, + (pool_out, 1), + {}, + meta_with_no_qparams, + ) + pooled_fake, _ = pool_out.data + else: + pooled_proxy = pool_out + pooled_fake = pool_out.data + indices_proxy = None + + _, C_out, H_out, W_out = pooled_fake.shape + + # 4) Batch-to-space: reshape and permute back + out = super().call_operator( + exir_ops.edge.aten.view_copy.default, + (pooled_proxy, [d_h, d_w, N, C_out, H_out, W_out]), + {}, 
+ meta_with_no_qparams, + ) + out = super().call_operator( + exir_ops.edge.aten.permute_copy.default, + (out, [2, 3, 4, 0, 5, 1]), + {}, + meta_with_no_qparams, + ) + # now flatten back into (N, C, H_out*d_h, W_out*d_w) + out = super().call_operator( + exir_ops.edge.aten.view_copy.default, + (out, [N, C_out, H_out * d_h, W_out * d_w]), + {}, + meta_with_no_qparams, + ) + + # 5) Final crop + S_top = ph // d_h + (1 if ph % d_h else 0) + S_left = pw // d_w + (1 if pw % d_w else 0) + S_top = max(0, min(S_top, H_out * d_h - H)) + S_left = max(0, min(S_left, W_out * d_w - W)) + out = super().call_operator( + exir_ops.edge.aten.slice_copy.Tensor, + (out, 2, S_top, S_top + H), + {}, + meta_with_no_qparams, + ) + out = super().call_operator( + exir_ops.edge.aten.slice_copy.Tensor, + (out, 3, S_left, S_left + W), + {}, + meta_with_no_qparams, + ) + + if is_with_indices: + # Reconstruct indices + idx = super().call_operator( + exir_ops.edge.aten.view_copy.default, + (indices_proxy, [d_h, d_w, N, C_out, H_out, W_out]), + {}, + meta_with_no_qparams, + ) + idx = super().call_operator( + exir_ops.edge.aten.permute_copy.default, + (idx, [2, 3, 4, 0, 5, 1]), + {}, + meta, + ) + idx = super().call_operator( + exir_ops.edge.aten.view_copy.default, + (idx, [N, C_out, H_out * d_h, W_out * d_w]), + {}, + meta_with_no_qparams, + ) + idx = super().call_operator( + exir_ops.edge.aten.slice_copy.Tensor, + (idx, 2, S_top, S_top + H), + {}, + meta_with_no_qparams, + ) + idx = super().call_operator( + exir_ops.edge.aten.slice_copy.Tensor, + (idx, 3, S_left, S_left + W), + {}, + meta_with_no_qparams, + ) + return out, idx + + return out diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py index a78514b6af5..9bff06b4dfe 100644 --- a/backends/arm/_passes/decompose_meandim_pass.py +++ b/backends/arm/_passes/decompose_meandim_pass.py @@ -5,22 +5,30 @@ from copy import copy from math import prod +from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg +from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) +from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass +from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.exir.backend.utils import WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass def get_meandim_decomposition(op) -> tuple: - if op == exir_ops.edge.aten.mean.dim: + if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default): return ( exir_ops.edge.aten.sum.dim_IntList, exir_ops.edge.aten.full.default, exir_ops.edge.aten.mul.Tensor, ) - if op == torch.ops.aten.mean.dim: + if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default): return ( torch.ops.aten.sum.dim_IntList, torch.ops.aten.full.default, @@ -30,21 +38,30 @@ def get_meandim_decomposition(op) -> tuple: def get_avgpool(op): - if op == exir_ops.edge.aten.mean.dim: + if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default): return exir_ops.edge.aten.avg_pool2d.default - if op == torch.ops.aten.mean.dim: + if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default): return torch.ops.aten.avg_pool2d.default raise RuntimeError(f"Can't get meandim decomposition for op {op}") def get_view(op): - if op == 
exir_ops.edge.aten.mean.dim: + if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default): return exir_ops.edge.aten.view_copy.default - if op == torch.ops.aten.mean.dim: - return torch.ops.aten.view_copy.default + if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default): + return torch.ops.aten.reshape.default raise RuntimeError(f"Can't get meandim decomposition for op {op}") +def get_quantization(op): + """Returns quant and dequant op of same type (per_channel/ tensor) as op if op is a dequant node, None otherwise.""" + if op in DQ_OPS: + # Input of op can be placeholder, can't use that to get quant node directly. + quant_type_index = DQ_OPS.index(op) + return Q_OPS[quant_type_index], op + return None + + class DecomposeMeanDimPass(ArmPass): """ Decomposes a meandim into avg_pool and/or sum + mul (1/N) depending on which dims the mean is taken for: @@ -62,6 +79,12 @@ class DecomposeMeanDimPass(ArmPass): x = view_copy.default(x, new_shape=(h)) # Squeeze dims since keepdims = False """ + _passes_required_after: Set[Type[ExportPass]] = { + ComputeConstantOpsAOTPass, + DecomposeSumPass, + SizeAdjustInputPass, + } + def __init__(self, graph_module, tosa_spec): super().__init__() self._graph_module = graph_module @@ -76,13 +99,20 @@ def __init__(self, graph_module, tosa_spec): ) def call_operator(self, op, args, kwargs, meta): - if op not in (exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim): + if op not in ( + exir_ops.edge.aten.mean.dim, + torch.ops.aten.mean.dim, + exir_ops.edge.aten.mean.default, + torch.ops.aten.mean.default, + ): return super().call_operator(op, args, kwargs, meta) x = get_node_arg(args, 0) input_shape = list(x.data.shape) output_shape = list(meta["val"].shape) - dims_to_reduce = get_node_arg(args, 1) + dims_to_reduce = get_node_arg(args, 1, range(len(input_shape))) + if dims_to_reduce is None: + dims_to_reduce = range(len(input_shape)) dims_to_reduce = [dim % len(input_shape) for dim in dims_to_reduce] dims_to_reduce = [dim for dim in dims_to_reduce if input_shape[dim] != 1] @@ -103,6 +133,7 @@ def call_operator(self, op, args, kwargs, meta): dims_to_reduce = [dim - 1 for dim in dims_to_reduce] x = super().call_operator(view_op, (x, new_shape), {}, meta, True) + x = self._maybe_insert_q_dq_after(x, meta) # Reduce (h,w) dims by avg pool if possible x, dims_to_reduce = self._reduce_by_average_pool(op, x, dims_to_reduce, meta) @@ -115,7 +146,7 @@ def call_operator(self, op, args, kwargs, meta): dims_to_reduce = [dim + len(original_dims) - 1 for dim in dims_to_reduce] x = super().call_operator(view_op, (x, temp_shape), {}, meta, True) - + x = self._maybe_insert_q_dq_after(x, meta) # Reduce remaining dims by sum x = self._reduce_by_sum(op, x, dims_to_reduce, meta, dtype) @@ -138,6 +169,45 @@ def _reduce_by_sum(self, op, input_node, dims, meta, dtype): full = super().call_operator( full_op, ([1] * len(output_shape), 1 / N), {"dtype": dtype}, meta, True ) + if (quant_ops := get_quantization(input_node.node.target)) is not None: + # Insert Q and DQ nodes after full op. 
+ # Since the value of full is known, we can compute quant params such that dq(q_max_value) + q_op, dq_op = quant_ops + qmax = input_node.node.args[4] + full_quant_args = ( + 1 / (N * qmax), # Scale to map qmax to 1/N + 0, # Zero point + *input_node.node.args[3:], + ) + q_args = (full, *full_quant_args) + full = super().call_operator( + q_op, + q_args, + kwargs={}, + meta=meta, + updated=True, + ) + dq_args = (full, *full_quant_args) + full = super().call_operator( + dq_op, dq_args, kwargs={}, meta=meta, updated=True + ) + + # Insert Q and DQ nodes after sum op. + # Scale needs to be adjusted with N, since it was computed on data after the division with N. + sum_quant_args = (input_node.node.args[1] * N, *input_node.node.args[2:]) + q_args = (sum, *sum_quant_args) + sum = super().call_operator( + q_op, + q_args, + kwargs={}, + meta=meta, + updated=True, + ) + dq_args = (sum, *sum_quant_args) + sum = super().call_operator( + dq_op, dq_args, kwargs={}, meta=meta, updated=True + ) + return super().call_operator(mul_op, (sum, full), {}, meta, True) def _reduce_by_average_pool(self, op, input_node, dims, meta): @@ -172,10 +242,38 @@ def _reduce_by_average_pool(self, op, input_node, dims, meta): ) if is_supported: + out = super().call_operator(avgpool_op, args, {}, meta, True) + out = self._maybe_insert_q_dq_after(out, meta) return ( - super().call_operator(avgpool_op, args, {}, meta, True), + out, dims_to_reduce_by_sum, ) else: return input_node, dims + + def _maybe_insert_q_dq_after(self, op, meta): + """If the input node of op is a dequant node, insert a q-dq pair after op with identical quantization parameters.""" + + if len(op.node.all_input_nodes) > 1: + raise ValueError( + f"Expected one input to {op.node}, got inputs {op.node.all_input_nodes}" + ) + input_node = op.node.all_input_nodes[0] + if (quant_ops := get_quantization(input_node.target)) is not None: + q_op, dq_op = quant_ops + quant_args = list(input_node.args[1:]) + q_args = (op, *quant_args) + out = super().call_operator( + q_op, + q_args, + kwargs={}, + meta=meta, + updated=True, + ) + dq_args = (out, *quant_args) + return super().call_operator( + dq_op, dq_args, kwargs={}, meta=meta, updated=True + ) + else: + return op diff --git a/backends/arm/_passes/decompose_ne_pass.py b/backends/arm/_passes/decompose_ne_pass.py index 16443d5d2fb..3bd4f4540bb 100644 --- a/backends/arm/_passes/decompose_ne_pass.py +++ b/backends/arm/_passes/decompose_ne_pass.py @@ -3,9 +3,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass edge_ne_ops = (exir_ops.edge.aten.ne.Tensor,) aten_ne_ops = (torch.ops.aten.ne.Tensor, torch.ops.aten.ne_.Tensor) @@ -53,6 +56,8 @@ class DecomposeNotEqualPass(ArmPass): - followed by aten.logical_not.default or its edge equivalent """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in (edge_ne_ops + aten_ne_ops): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_quant_nodes.py b/backends/arm/_passes/decompose_quant_nodes.py new file mode 100644 index 00000000000..3cc99e7baca --- /dev/null +++ b/backends/arm/_passes/decompose_quant_nodes.py @@ -0,0 +1,156 @@ +# Copyright 2025 Arm Limited and/or its affiliates. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import cast, Set, Type + +import torch +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm._passes.decompose_round_pass import DecomposeRoundPass +from executorch.backends.arm.constants import DEQUANT_PER_TENSOR_OP, QUANT_PER_TENSOR_OP +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +class DecomposeQuantNodesPass(ArmPass): + """Decomposes quantization nodes into more primitive operations by rewriting the graph + using the two formulas: + + quantized value = clamp(round(fp32_value / scale) + zero point, qmin, qmax) + + fp32_value = (quantized value - zp) * scale + + For quantization nodes, the pass replaces them with: + + 1. Multiplying the input by the inverse of the scale factor. + 2. Rounding the result. + 3. Adding the zero point. + 4. Clamping the result to [qmin, qmax]. + 5. Casting to the target data type. + + For dequantization nodes, the pass replaces them with: + + 1. Casting the input to int32. + 2. Subtracting the zero point. + 3. Casting to float32. + 4. Multiplying by the scale factor. + + """ + + _passes_required_after: Set[Type[ExportPass]] = {DecomposeRoundPass} + + def call(self, graph_module: torch.fx.GraphModule): + modified = False + for node in list(graph_module.graph.nodes): + if node.op != "call_function" or node.target not in ( + QUANT_PER_TENSOR_OP, + DEQUANT_PER_TENSOR_OP, + ): + continue + if node.target == DEQUANT_PER_TENSOR_OP and all( + user.target == QUANT_PER_TENSOR_OP for user in node.users + ): + continue + elif ( + node.target == QUANT_PER_TENSOR_OP + and node.all_input_nodes[0].target == DEQUANT_PER_TENSOR_OP + ): + continue + modified = True + args = node.args + input_rank = args[0].meta["val"].ndim + x, scale, zero_point, qmin, qmax, dtype = args + # Instead of dividing by scale in quantization, we multiply by 1/scale + # when quantizing. 
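A quick numeric illustration of the two formulas above (scale, zero point, and input values are arbitrary); note the multiply-by-1/scale form that the code below also uses:

import torch

scale, zp, qmin, qmax = 0.05, 3, -128, 127
x = torch.tensor([-7.0, 0.0, 1.234, 42.0])

q = torch.clamp(torch.round(x * (1.0 / scale)) + zp, qmin, qmax).to(torch.int8)
x_hat = (q.to(torch.int32) - zp).to(torch.float32) * scale
# q is [-128, 3, 28, 127]; x_hat is [-6.55, 0.00, 1.25, 6.20],
# i.e. -7.0 and 42.0 saturate at the clamp bounds while 1.234 rounds to 1.25.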
+ scale = cast(float, scale) + scale = scale if node.target == DEQUANT_PER_TENSOR_OP else 1.0 / scale + with graph_module.graph.inserting_before(node): + scale_const = create_node( + graph_module.graph, + exir_ops.edge.aten.full.default, + args=((1,) * input_rank, scale), + kwargs={"dtype": torch.float32}, + ) + zp_const = create_node( + graph_module.graph, + exir_ops.edge.aten.full.default, + args=((1,) * input_rank, zero_point), + kwargs={ + "dtype": ( + torch.float32 + if node.target == QUANT_PER_TENSOR_OP + else torch.int32 + ) + }, + ) + if node.target == QUANT_PER_TENSOR_OP: + # TODO MLETORCH-1587: Decompose quantization nodes using more integer arithmetic + scaled = create_node( + graph_module.graph, + exir_ops.edge.aten.mul.Tensor, + args=(x, scale_const), + from_node=node, + ) + rounded = create_node( + graph_module.graph, + exir_ops.edge.aten.round.default, + args=(scaled,), + from_node=node, + ) + shifted = create_node( + graph_module.graph, + exir_ops.edge.aten.add.Tensor, + args=(rounded, zp_const), + from_node=node, + ) + clamped = create_node( + graph_module.graph, + exir_ops.edge.aten.clamp.default, + args=(shifted, float(qmin), float(qmax)), + from_node=node, + ) + quantized = create_node( + graph_module.graph, + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + args=(clamped,), + kwargs={"dtype": dtype}, + from_node=node, + ) + output = quantized + else: + input_casted_to_zp_dtype = create_node( + graph_module.graph, + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + args=(x,), + kwargs={"dtype": torch.int32}, + from_node=node, + ) + shifted = create_node( + graph_module.graph, + exir_ops.edge.aten.sub.Tensor, + args=(input_casted_to_zp_dtype, zp_const), + from_node=node, + ) + casted_to_float = create_node( + graph_module.graph, + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + args=(shifted,), + kwargs={"dtype": torch.float32}, + from_node=node, + ) + dequantized = create_node( + graph_module.graph, + exir_ops.edge.aten.mul.Tensor, + args=(casted_to_float, scale_const), + from_node=node, + ) + output = dequantized + node.replace_all_uses_with(output) + graph_module.graph.erase_node(node) + if modified: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, modified=modified) diff --git a/backends/arm/_passes/decompose_remainder_pass.py b/backends/arm/_passes/decompose_remainder_pass.py new file mode 100644 index 00000000000..6c11a7b600e --- /dev/null +++ b/backends/arm/_passes/decompose_remainder_pass.py @@ -0,0 +1,70 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
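The DecomposeRemainderPass added next rewrites remainder(x, y) as x - floor_div(x, y) * y; the sign combinations that usually cause trouble all check out in eager mode (values are illustrative):

import torch

x = torch.tensor([5.0, -5.0, 5.0, -5.0])
y = torch.tensor([3.0, 3.0, -3.0, -3.0])

ref = torch.remainder(x, y)                               # [2., 1., -1., -2.]
dec = x - torch.div(x, y, rounding_mode="floor") * y
assert torch.equal(ref, dec)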
+ +from typing import Dict, Set, Type + +import torch +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.decompose_div_tensor_mode import ( + DecomposeDivTensorModePass, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass +from torch._ops import OpOverload + +Op = OpOverload | EdgeOpOverload + +_decomposition_ops: Dict[Op, tuple[Op, Op, Op]] = { + exir_ops.edge.aten.remainder.Scalar: ( + exir_ops.edge.aten.div.Tensor_mode, + exir_ops.edge.aten.mul.Scalar, + exir_ops.edge.aten.sub.Tensor, + ), + torch.ops.aten.remainder.Tensor: ( + torch.ops.aten.div.Tensor_mode, + torch.ops.aten.mul.Tensor, + torch.ops.aten.sub.Tensor, + ), + torch.ops.aten.remainder.Scalar: ( + torch.ops.aten.div.Tensor_mode, + torch.ops.aten.mul.Scalar, + torch.ops.aten.sub.Tensor, + ), + exir_ops.edge.aten.remainder.Tensor: ( + exir_ops.edge.aten.div.Tensor_mode, + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.sub.Tensor, + ), +} + + +class DecomposeRemainderPass(ArmPass): + """ + Decompose the remainder operation into primitive arithmetic: + remainder(x, y) -> x - floor_div(x, y) * y + where floor_div(x, y) == div(x, y, rounding_mode="floor"). + """ + + _passes_required_after: Set[Type[ExportPass]] = {DecomposeDivTensorModePass} + + def call_operator(self, op, args, kwargs, meta, updated=False): + supported_ops = ( + exir_ops.edge.aten.remainder.Scalar, + exir_ops.edge.aten.remainder.Tensor, + torch.ops.aten.remainder.Scalar, + torch.ops.aten.remainder.Tensor, + ) + if op not in supported_ops: + return super().call_operator(op, args, kwargs, meta, updated) + + div_op, mul_op, sub_op = _decomposition_ops[op] + x, y = args[0], args[1] + + floor_div = super().call_operator( + div_op, (x, y), {"rounding_mode": "floor"}, meta, updated=True + ) + product = super().call_operator(mul_op, (floor_div, y), {}, meta, updated=True) + return super().call_operator(sub_op, (x, product), {}, meta, updated=True) diff --git a/backends/arm/_passes/decompose_round_pass.py b/backends/arm/_passes/decompose_round_pass.py index edfa3817064..35d36e80396 100644 --- a/backends/arm/_passes/decompose_round_pass.py +++ b/backends/arm/_passes/decompose_round_pass.py @@ -3,10 +3,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass from torch._ops import OpOverload @@ -56,6 +59,8 @@ class DecomposeRoundPass(ArmPass): %result = where(%is_non_negative, %floor, %ceil) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta, updated=False): if op not in (exir_ops.edge.aten.round.default, torch.ops.aten.round.default): return super().call_operator(op, args, kwargs, meta, updated) diff --git a/backends/arm/_passes/decompose_sdpa_pass.py b/backends/arm/_passes/decompose_sdpa_pass.py new file mode 100644 index 00000000000..566b43d5aa3 --- /dev/null +++ b/backends/arm/_passes/decompose_sdpa_pass.py @@ -0,0 +1,16 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Set, Type + +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.transforms import decompose_sdpa +from executorch.exir.pass_base import ExportPass + + +class DecomposeScaledDotProductAttentionPass( + ArmPass, decompose_sdpa.DecomposeScaledDotProductAttention +): + _passes_required_after: Set[Type[ExportPass]] = set() diff --git a/backends/arm/_passes/decompose_select.py b/backends/arm/_passes/decompose_select.py index 99c89f474ea..23b100ca41b 100644 --- a/backends/arm/_passes/decompose_select.py +++ b/backends/arm/_passes/decompose_select.py @@ -4,22 +4,29 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe + +from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( create_node, get_first_fake_tensor, ) +from executorch.backends.arm._passes.convert_squeezes_to_view import ( + ConvertSqueezesToViewPass, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult -class DecomposeSelectPass(ExportPass): +class DecomposeSelectPass(ArmPass): """ This pass decomposes select into slice + squeeze to ensure that Aten and TOSA outputs has the same rank (input rank -1) """ + _passes_required_after: Set[Type[ExportPass]] = {ConvertSqueezesToViewPass} + def call(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: @@ -45,10 +52,18 @@ def call(self, graph_module: torch.fx.GraphModule): with graph_module.graph.inserting_before(node): slice_node = create_node( - graph_module.graph, slice_op, (input_node, dim, index, index + 1) + graph_module.graph, + slice_op, + (input_node, dim, index, index + 1), + from_node=node, + inherit_qparams=False, ) squeeze_node = create_node( - graph_module.graph, squeeze_op, (slice_node, [dim]), from_node=node + graph_module.graph, + squeeze_op, + (slice_node, [dim]), + from_node=node, + inherit_qparams=True, ) node.replace_all_uses_with(squeeze_node) diff --git a/backends/arm/_passes/decompose_select_scatter_pass.py b/backends/arm/_passes/decompose_select_scatter_pass.py new file mode 100644 index 00000000000..f3c7ae5955b --- /dev/null +++ b/backends/arm/_passes/decompose_select_scatter_pass.py @@ -0,0 +1,143 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
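The DecomposeSelectScatterPass added below builds a boolean mask from arange/eq, broadcasts the source slice, and routes it through where; a compact eager-mode check of that recipe (shapes are illustrative):

import torch

x = torch.zeros(3, 4)
src = torch.arange(4.0)
dim, index = 0, 1

ref = torch.select_scatter(x, src, dim, index)

mask_shape = [1] * x.dim()
mask_shape[dim] = -1
mask = torch.arange(x.shape[dim]).view(mask_shape) == index
out = torch.where(mask, src.unsqueeze(dim).expand(x.shape), x)
assert torch.equal(ref, out)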
+ +from typing import Set, Type + +import torch + +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.convert_int64_const_ops_to_int32 import ( + ConvertInt64ConstOpsToInt32Pass, +) +from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( + ReplaceScalarWithTensorByProfilePass, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + +edge_scatter_ops = (exir_ops.edge.aten.select_scatter.default,) +aten_scatter_ops = (torch.ops.aten.select_scatter.default,) + + +def get_select_scatter_decomposition(op) -> tuple: + if op in edge_scatter_ops: + return ( + exir_ops.edge.aten.arange.start_step, + exir_ops.edge.aten.eq.Scalar, + exir_ops.edge.aten.where.self, + exir_ops.edge.aten.expand_copy.default, + exir_ops.edge.aten.unsqueeze_copy.default, + exir_ops.edge.aten.view_copy.default, + ) + if op in aten_scatter_ops: + return ( + torch.ops.aten.arange.start_step, + torch.ops.aten.eq.Scalar, + torch.ops.aten.where.self, + torch.ops.aten.expand_copy.default, + torch.ops.aten.unsqueeze_copy.default, + torch.ops.aten.view_copy.default, + ) + + raise RuntimeError(f"Can't get select_scatter decomposition for op {op}") + + +class DecomposeSelectScatterPass(ArmPass): + """select_scatter is decomposed into other ops during export, however this is only + suppported for the fp profile and for the int profile we need to decompose it here. + + The decomposition is as follows: + - Build a boolean mask the size of x + eq(view(arange(0, dim_size), mask_shape), index) + - Broadcast source to x + expand(unsqueeze(source, dim), shape) + - Route the updated slice while keeping the untouched lanes + where(mask, expanded_source, x) + + This reflects the decomposition for the fp profile implemented in torch._refs + """ + + _passes_required_after: Set[Type[ExportPass]] = { + ReplaceScalarWithTensorByProfilePass, + ConvertInt64ConstOpsToInt32Pass, + } + + def call_operator(self, op, args, kwargs, meta): + if op not in (edge_scatter_ops + aten_scatter_ops): + return super().call_operator(op, args, kwargs, meta, updated=False) + + ( + arange_op, + eq_op, + where_op, + expand_op, + unsqueeze_op, + view_op, + ) = get_select_scatter_decomposition(op) + + input_tensor = args[0] + src_tensor = args[1] + dim = int(args[2]) + index = int(args[3]) + + shape = input_tensor.data.size() + rank = len(shape) + dim = dim % rank if dim < 0 else dim + dim_size = shape[dim] + if index < 0: + index = index + dim_size + + mask_shape = [1] * rank + mask_shape[dim] = -1 + + arange_node = super().call_operator( + arange_op, + (0, dim_size, 1), + {}, + meta, + updated=False, + ) + + view_node = super().call_operator( + view_op, + (arange_node, mask_shape), + {}, + meta, + updated=False, + ) + + mask_node = super().call_operator( + eq_op, + (view_node, index), + {}, + meta, + updated=False, + ) + + unsqueeze_node = super().call_operator( + unsqueeze_op, + (src_tensor, dim), + {}, + meta, + updated=False, + ) + + expand_node = super().call_operator( + expand_op, + (unsqueeze_node, shape), + {}, + meta, + updated=False, + ) + + where_node = super().call_operator( + where_op, + (mask_node, expand_node, input_tensor), + {}, + meta, + updated=True, + ) + + return where_node diff --git a/backends/arm/_passes/decompose_sign_pass.py b/backends/arm/_passes/decompose_sign_pass.py index 1038ff0f3fa..c4cb964316d 100644 --- a/backends/arm/_passes/decompose_sign_pass.py +++ b/backends/arm/_passes/decompose_sign_pass.py @@ -3,10 +3,13 @@ # This 
source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For MI case @@ -42,6 +45,8 @@ def get_ops(op): class DecomposeSignPass(ArmPass): """Decomposes the sign operator into a sequence of operations that are supported by the Arm backend.""" + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in (edge_sign, aten_sign): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_silu_pass.py b/backends/arm/_passes/decompose_silu_pass.py index 68ebb3f4515..80c9413acfb 100644 --- a/backends/arm/_passes/decompose_silu_pass.py +++ b/backends/arm/_passes/decompose_silu_pass.py @@ -3,15 +3,18 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe + +from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.pass_base import ExportPass aten_silu_ops = (torch.ops.aten.silu.default, torch.ops.aten.silu_.default) -class DecomposeSiluPass(ExportPass): +class DecomposeSiluPass(ArmPass): """ This pass decomposes silu into a mul and a sigmoid node. @@ -22,6 +25,8 @@ class DecomposeSiluPass(ExportPass): y = mul(a,x) """ + _passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass} + def call_operator(self, op, args, kwargs, meta): if op not in (aten_silu_ops): return super().call_operator(op, args, kwargs, meta) @@ -29,6 +34,8 @@ def call_operator(self, op, args, kwargs, meta): mul_op = torch.ops.aten.mul.Tensor original = args[0] - sigmoid = super().call_operator(sigmoid_op, (original,), {}, meta) + sigmoid = super().call_operator(sigmoid_op, (original,), {}, meta, updated=True) - return super().call_operator(mul_op, (original, sigmoid), {}, meta) + return super().call_operator( + mul_op, (original, sigmoid), {}, meta, updated=True + ) diff --git a/backends/arm/_passes/decompose_sinh_pass.py b/backends/arm/_passes/decompose_sinh_pass.py index 7192eb9bf74..731f9a5dbf3 100644 --- a/backends/arm/_passes/decompose_sinh_pass.py +++ b/backends/arm/_passes/decompose_sinh_pass.py @@ -4,8 +4,17 @@ # LICENSE file in the root directory of this source tree. +from typing import Set, Type + from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass +from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass +from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass +from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( + ReplaceScalarWithTensorByProfilePass, +) from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For MI case @@ -24,10 +33,25 @@ class DecomposeSinhPass(ArmPass): and scalar multiplication. 
""" + _passes_required_after: Set[Type[ExportPass]] = { + InsertTableOpsPass, + MatchArgRanksPass, + ReplaceScalarWithTensorByProfilePass, + MatchArgDtypePass, + } + def call_operator(self, op, args, kwargs, meta): if op is not edge_sinh: return super().call_operator(op, args, kwargs, meta) + is_quantized = ( + len(meta.data.get("input_qparams", {})) > 0 + and len(meta.data.get("output_qparams", {})) > 0 + ) + if is_quantized: + # If quantized, node should be replace by table op + return super().call_operator(op, args, kwargs, meta) + x = args sub_op, exp_op, neg_op, mul_op = ( diff --git a/backends/arm/_passes/decompose_softmax_pass.py b/backends/arm/_passes/decompose_softmax_pass.py index a735501f711..ee841e54f26 100644 --- a/backends/arm/_passes/decompose_softmax_pass.py +++ b/backends/arm/_passes/decompose_softmax_pass.py @@ -3,7 +3,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -49,7 +54,7 @@ def _get_logsoftmax_ops(op) -> tuple: raise RuntimeError(f"Can't get logsoftmax decomposition ops for op {op}") -class DecomposeSoftmaxPass(ExportPass): +class DecomposeSoftmaxPass(ArmPass): """ This pass decomposes log_softmax or softmax into more primitive ops. Example: @@ -62,6 +67,11 @@ class DecomposeSoftmaxPass(ExportPass): (in logsoftmax case: %op7 = log(%op6)) """ + _passes_required_after: Set[Type[ExportPass]] = { + DecomposeSumPass, + InsertTableOpsPass, + } + def call_operator(self, op, args, kwargs, meta): if op not in torch_softmax + edge_softmax: return super().call_operator(op, args, kwargs, meta) @@ -70,12 +80,12 @@ def call_operator(self, op, args, kwargs, meta): ) _input = args[0] dim = [args[1]] - op1 = super().call_operator(max_op, (_input, dim, True), {}, meta) - op2 = super().call_operator(sub_op, (_input, op1), {}, meta) - op3 = super().call_operator(exp_op, (op2,), {}, meta) - op4 = super().call_operator(sum_op, (op3, dim, True), {}, meta) - op5 = super().call_operator(reciprocal_op, (op4,), {}, meta) - op6 = super().call_operator(mul_op, (op3, op5), {}, meta) + op1 = super().call_operator(max_op, (_input, dim, True), {}, meta, updated=True) + op2 = super().call_operator(sub_op, (_input, op1), {}, meta, updated=True) + op3 = super().call_operator(exp_op, (op2,), {}, meta, updated=True) + op4 = super().call_operator(sum_op, (op3, dim, True), {}, meta, updated=True) + op5 = super().call_operator(reciprocal_op, (op4,), {}, meta, updated=True) + op6 = super().call_operator(mul_op, (op3, op5), {}, meta, updated=True) if op in log_softmax: - op6 = super().call_operator(log_op, (op6,), {}, meta) + op6 = super().call_operator(log_op, (op6,), {}, meta, updated=True) return op6 diff --git a/backends/arm/_passes/decompose_softmax_unstable_pass.py b/backends/arm/_passes/decompose_softmax_unstable_pass.py index b6f5e11b66b..75cd90e4651 100644 --- a/backends/arm/_passes/decompose_softmax_unstable_pass.py +++ b/backends/arm/_passes/decompose_softmax_unstable_pass.py @@ -3,11 +3,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe + +from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass # For BI case torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int) @@ -57,6 +61,11 @@ class DecomposeSoftmaxUnstablePass(ArmPass): (in logsoftmax case: %op5 = log(%op4)) """ + _passes_required_after: Set[Type[ExportPass]] = { + DecomposeSumPass, + InsertTableOpsPass, + } + def call_operator(self, op, args, kwargs, meta): if op not in torch_softmax + edge_softmax: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_sqrt_pass.py b/backends/arm/_passes/decompose_sqrt_pass.py index 547d0091e90..6d78c70634f 100644 --- a/backends/arm/_passes/decompose_sqrt_pass.py +++ b/backends/arm/_passes/decompose_sqrt_pass.py @@ -3,10 +3,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe -from typing import Tuple, Union +from typing import Set, Tuple, Type, Union import torch +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -26,7 +27,8 @@ def get_sqrt_decomposition(op) -> Union[Tuple, torch._ops.OpOverload]: raise RuntimeError(f"Can't get sqrt decomposition for op {op}") -class DecomposeSqrtPass(ExportPass): +class DecomposeSqrtPass(ArmPass): + _passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass} def call_operator(self, op, args, kwargs, meta): """ @@ -36,6 +38,14 @@ def call_operator(self, op, args, kwargs, meta): if op not in (edge_sqrt_ops + aten_sqrt_ops): return super().call_operator(op, args, kwargs, meta) + is_quantized = ( + len(meta.data.get("input_qparams", {})) > 0 + and len(meta.data.get("output_qparams", {})) > 0 + ) + if is_quantized: + # If quantized, node should be replace by table op + return super().call_operator(op, args, kwargs, meta) + pow_op = get_sqrt_decomposition(op) - return super().call_operator(pow_op, (args[0], 0.5), {}, meta) + return super().call_operator(pow_op, (args[0], 0.5), {}, meta, updated=True) diff --git a/backends/arm/_passes/decompose_sum_pass.py b/backends/arm/_passes/decompose_sum_pass.py index 52b9c10c49f..0e63ef38669 100644 --- a/backends/arm/_passes/decompose_sum_pass.py +++ b/backends/arm/_passes/decompose_sum_pass.py @@ -3,7 +3,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
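The DecomposeSumPass hunk that follows keeps every reduction as a keepdim=True sum and restores the squeezed shape with a single reshape at the end; in eager terms the equivalence looks like this (illustrative dims):

import torch

x = torch.randn(2, 3, 4)
dims = [0, 2]

ref = x.sum(dims)                          # keepdim defaults to False

out = x
for d in dims:
    out = out.sum(d, keepdim=True)         # one reduction per dim, REDUCE_SUM style
out = out.reshape(ref.shape)               # squeeze the kept dims in one view
assert torch.allclose(ref, out, atol=1e-6)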
+from typing import Set, Type + import torch +from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -16,12 +19,12 @@ def _get_sum_decomp(op): exir_ops.edge.aten.sum.dim_IntList, ) case torch.ops.aten.sum.dim_IntList: - return (torch.ops.aten.view_copy.default, torch.ops.aten.sum.dim_IntList) + return (torch.ops.aten.reshape.default, torch.ops.aten.sum.dim_IntList) case _: raise RuntimeError("Unvalid op in DecomposeSumPass") -class DecomposeSumPass(ExportPass): +class DecomposeSumPass(ArmPass): """ In Pytorch, the default behaviour of for example Tensor.sum is to squeeze the dimension that is summed (keep_dim = False). However, in TOSA, REDUCE_SUM always @@ -40,6 +43,8 @@ class DecomposeSumPass(ExportPass): view(shape = squeezed_shape) -> squeezed_shape """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in [ exir_ops.edge.aten.sum.dim_IntList, @@ -63,8 +68,8 @@ def call_operator(self, op, args, kwargs, meta): case _: raise ValueError(f"Invalid number of arguments ({len(args)}) provided.") - # If dims is None, sum over all dimensions - if dims is None: + # If dims evaluates to False (None or []), sum over all dimensions + if not dims: shape = input_node.data.size() dims = list(range(len(shape))) @@ -72,13 +77,17 @@ def call_operator(self, op, args, kwargs, meta): for dim in dims: input_node = super().call_operator( - sum_op, (input_node, dim, True), kwargs, meta + sum_op, + (input_node, dim, True), + kwargs, + meta, + updated=True, ) if not keepdims: shape = list(meta["val"].size()) input_node = super().call_operator( - view_op, (input_node, shape), kwargs, meta + view_op, (input_node, shape), {}, meta, updated=True ) return input_node diff --git a/backends/arm/_passes/decompose_tosa_unsupported_clamp_pass.py b/backends/arm/_passes/decompose_tosa_unsupported_clamp_pass.py new file mode 100644 index 00000000000..b467f6795b3 --- /dev/null +++ b/backends/arm/_passes/decompose_tosa_unsupported_clamp_pass.py @@ -0,0 +1,97 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
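The new DecomposeTOSAUnsupportedClampPass below leans on clamp(x, lo, hi) == minimum(maximum(x, lo), hi); a quick int32 check (the pass materialises the scalar bounds as rank-matched full tensors, 0-d tensors are enough for this illustration):

import torch

x = torch.arange(-5, 6, dtype=torch.int32)
lo = torch.tensor(-2, dtype=torch.int32)
hi = torch.tensor(3, dtype=torch.int32)

ref = torch.clamp(x, -2, 3)
dec = torch.minimum(torch.maximum(x, lo), hi)
assert torch.equal(ref, dec)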
+ +from typing import Set, Type + +import torch +from executorch.backends.arm._passes import ArmPass +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + + +class DecomposeTOSAUnsupportedClampPass(ArmPass): + """Rewrite TOSA unsupported clamp into min/max chain since TOSA lacks int32 clamp support + and only supports scalar min/max values.""" + + _passes_required_after: Set[Type[ExportPass]] = set() + _supported_ops = { + exir_ops.edge.aten.clamp.default, + exir_ops.edge.aten.clamp.Tensor, + torch.ops.aten.clamp.default, + torch.ops.aten.clamp.Tensor, + } + + def _ensure_tensor( + self, + value, + ref_tensor, + dtype, + rank, + meta, + ): + if value is None: + return None + return super().call_operator( + exir_ops.edge.aten.full.default, + ((1,) * rank, value), + {"dtype": dtype}, + meta, + updated=True, + ) + + def call_operator(self, op, args, kwargs, meta): + val = meta["val"] + + is_scalar_clamp = op in { + exir_ops.edge.aten.clamp.default, + torch.ops.aten.clamp.default, + } + is_tensor_clamp = op in { + exir_ops.edge.aten.clamp.Tensor, + torch.ops.aten.clamp.Tensor, + } + + if op not in self._supported_ops: + return super().call_operator(op, args, kwargs, meta) + + # Only rewrite scalar clamp for int32 + if is_scalar_clamp and val.dtype != torch.int32: + return super().call_operator(op, args, kwargs, meta) + + input_tensor = args[0] + dtype = val.dtype + rank = len(val.shape) + min_arg = args[1] if len(args) > 1 else None + max_arg = args[2] if len(args) > 2 else None + + if is_scalar_clamp: + # Scalar min/max -> make them tensors for min/max ops + min_arg = self._ensure_tensor(min_arg, input_tensor, dtype, rank, meta) + max_arg = self._ensure_tensor(max_arg, input_tensor, dtype, rank, meta) + else: + # Tensor variant: arguments are already tensors; nothing extra to do + if not is_tensor_clamp: + raise RuntimeError( + f"DecomposeTOSAUnsupportedClampPass: unexpected op {op} in tensor clamp branch" + ) + + current = input_tensor + if min_arg is not None: + current = super().call_operator( + exir_ops.edge.aten.maximum.default, + (current, min_arg), + {}, + meta, + updated=True, + ) + if max_arg is not None: + current = super().call_operator( + exir_ops.edge.aten.minimum.default, + (current, max_arg), + {}, + meta, + updated=True, + ) + return current diff --git a/backends/arm/_passes/decompose_var_pass.py b/backends/arm/_passes/decompose_var_pass.py index 15872738f3e..bb2e2066a06 100644 --- a/backends/arm/_passes/decompose_var_pass.py +++ b/backends/arm/_passes/decompose_var_pass.py @@ -4,13 +4,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
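The DecomposeVarPass hunk below decomposes variance into a sum of squared deviations divided by max(0, N - correction), per its docstring; a quick eager-mode check (shape and correction value are illustrative):

import torch

x = torch.randn(4, 6)
dim, correction = 1, 1

ref = torch.var(x, dim, correction=correction)

mean = x.mean(dim, keepdim=True)
n = x.shape[dim]
dec = ((x - mean) ** 2).sum(dim) / max(0, n - correction)
assert torch.allclose(ref, dec, atol=1e-5)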
-# pyre-unsafe +from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg +from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass +from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass def get_var_decomposition(op) -> tuple: @@ -47,6 +53,12 @@ class DecomposeVarPass(ArmPass): y = div(sum, max(0, N-correction)) """ + _passes_required_after: Set[Type[ExportPass]] = { + ComputeConstantOpsAOTPass, + DecomposeMeanDimPass, + DecomposeSumPass, + } + def call_operator(self, op, args, kwargs, meta): if op not in ( exir_ops.edge.aten.var.correction, diff --git a/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py b/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py index 17a682c0a8e..a6f69a1fcc9 100644 --- a/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py +++ b/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py @@ -3,13 +3,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass def _get_decorated_ops(op): @@ -40,6 +41,8 @@ class DecorateFp32toInt32CastingPass(ArmPass): output = to_dim_order_copy(decorated_x, dtype=torch.int32) """ + _passes_required_after: Set[Type[ExportPass]] = set() + targets = [ exir_ops.edge.dim_order_ops._to_dim_order_copy.default, ] diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index 491b404f0a4..65d4e939524 100644 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -1,34 +1,34 @@ # Copyright 2024-2025 Arm Limited and/or its affiliates. -# All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe import copy -from typing import cast, Dict, Set, Tuple +from typing import cast, Optional, Set, Type +import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( get_param_tensor, is_param_node, + set_node_arg, ) +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.backends.arm._passes.quant_args import QuantArgs +from executorch.backends.arm._passes.remove_noop_pass import RemoveNoopPass +from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo from executorch.backends.arm.constants import DQ_OPS, Q_OPS +from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.dialects.edge._ops import EdgeOpOverload - -from executorch.exir.pass_base import ( - Argument, - ExportPass, - NodeMetadata, - PassResult, - ProxyValue, -) + +from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import GraphModule, Node @@ -100,6 +100,15 @@ class FoldAndAnnotateQParamsPass(ArmPass): """ + _passes_required_after: Set[Type[ExportPass]] = { + InsertTableOpsPass, + RemoveNoopPass, + } + + def __init__(self, exported_program: Optional[ExportedProgram] = None) -> None: + super().__init__() + self.exported_program = exported_program + def fold_and_annotate_arg( self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int ) -> None: @@ -147,15 +156,114 @@ def fold_and_annotate_arg( if len(n.users) == 0: graph_module.graph.erase_node(n) - def call(self, graph_module: GraphModule) -> PassResult: + def _handle_control_flow_node(self, node: Node, graph_module: GraphModule): + """Fold outmost quant nodes inside submodule. + placeholders => qs => dqs => ... => qs => dqs => output + becomes + placeholders => dqs => ... => qs => output, + With output_qparams meta in the placeholders, and input_qparams meta in the output node. + """ + match node.target: + case torch.ops.higher_order.cond: + submodule_nodes = cast(list[Node], node.args[1:3]) + args = cast(list[Node], node.args[-1]) + case torch.ops.higher_order.while_loop: + submodule_nodes = cast(list[Node], node.args[0:2]) + args = cast(list[Node], node.args[-2]) + case _: + raise ValueError(f"Unhandled target {node.target}") + submodules = ( + graph_module.get_submodule(str(submodule_node.target)) + for submodule_node in submodule_nodes + ) + for submodule in submodules: + submodule = cast(GraphModule, submodule) + output_node = submodule.graph.output_node() + output_node.meta["input_qparams"] = {} + nodes_to_remove = [] + arg_id = 0 + for submodule_node in submodule.graph.nodes: + # Remove initial q nodes and ending dq nodes in the module. 
+ submodule_node = cast(Node, submodule_node) + if ( + submodule_node.target in Q_OPS + and list(submodule_node.all_input_nodes)[0].op == "placeholder" + ): + input_node = cast(Node, submodule_node.args[0]) + input_node.meta["val"] = submodule_node.meta["val"] + quant_args = QuantArgs.from_operator( + submodule_node.target, submodule_node.args + ) + input_node.meta["output_qparams"] = {0: quant_args} + + submodule_node.replace_all_uses_with(input_node) + nodes_to_remove.append(submodule_node) + if submodule_node.target in DQ_OPS: + has_non_output_user = False + for user in copy.copy(submodule_node.users): + if user.op != "output": + has_non_output_user = True + else: + input_node = cast(Node, submodule_node.args[0]) + submodule_node.replace_all_uses_with(input_node) + arg_index = cast(list[Node], output_node.args[0]).index( + input_node + ) + quant_args = QuantArgs.from_operator( + submodule_node.target, submodule_node.args + ) + output_node.meta["input_qparams"][arg_index] = quant_args + + # Remove dq node if it only has the output node as its user. + if not has_non_output_user: + nodes_to_remove.append(submodule_node) + # Placeholders without users won't be retraced with correct dtype, do it manually. + # Control flow node input is matched to placeholder nodes in the submodule by index. + # This means it will break if another pass inserts a placeholder before this pass. + if submodule_node.op == "placeholder": + if len(submodule_node.users) == 0: + submodule_node.meta["val"] = args[arg_id].meta["val"] + arg_id += 1 + if arg_id > len(args): + raise RuntimeError( + "Submodule had more placeholders than calling node had inputs." + " This is probably due to a placeholder being inserted in a pass." + ) + for node_to_remove in nodes_to_remove: + submodule.graph.erase_node(node_to_remove) + return + + @staticmethod + def is_foldable(node: Node) -> bool: + if node.op != "call_function": + return False + # Don't fold chains of quant-ops into each other. + if node.target in (*Q_OPS, *DQ_OPS): + return False + + # Always fold q-dq into constant ops. + if node.target in ( + exir_ops.edge.aten.full_like.default, + *ComputeConstantOpsAOTPass.targeted_ops, + ): + return True + + # We should not fold q-dq nodes into non-quantized nodes. + if not ( + ArmAnnotationInfo.CUSTOM_META_KEY in node.meta.get("custom", {}) + and ArmAnnotationInfo( + node.meta["custom"][ArmAnnotationInfo.CUSTOM_META_KEY] + ).quantized + ): + return False + return True + + def call(self, graph_module: GraphModule) -> PassResult: # noqa: C901 # Loop over the graph nodes and find any node in the 'targeted_ops' list. for n in graph_module.graph.nodes: n = cast(Node, n) - if n.op != "call_function": - continue - # Don't fold chains of quant-ops into each other. 
- if n.target in (*Q_OPS, *DQ_OPS): + if not FoldAndAnnotateQParamsPass.is_foldable(n): continue # Make sure we haven't already set qparams meta information on the node @@ -176,8 +284,8 @@ def call(self, graph_module: GraphModule) -> PassResult: n.meta["input_qparams"] = {} n.meta["output_qparams"] = {} for i, arg in enumerate(n.args): - if isinstance(arg, list): - self.fold_and_annotate_arg(graph_module, n, arg, i) + if isinstance(arg, (list, tuple)): + self.fold_and_annotate_arg(graph_module, n, arg, i) # type: ignore elif isinstance(arg, Node): self.fold_and_annotate_arg(graph_module, n, [arg], i) @@ -196,6 +304,22 @@ def call(self, graph_module: GraphModule) -> PassResult: user.replace_all_uses_with(n) graph_module.graph.erase_node(user) + # Some op(s) contain a "dtype" key in their node kwargs. Set this + # to the type of output qparams. + output_qparams = n.meta["output_qparams"] + if ( + n.target in {exir_ops.edge.aten.sum.dim_IntList} + and len(output_qparams) > 0 + ): + output_dtype = output_qparams[0].dtype + set_node_arg(n, "dtype", output_dtype) + + if n.target in ( + torch.ops.higher_order.cond, + torch.ops.higher_order.while_loop, + ): + self._handle_control_flow_node(n, graph_module) + # retrace the graph to update the fake tensor types graph_module = super().call(graph_module).graph_module @@ -203,13 +327,15 @@ def call(self, graph_module: GraphModule) -> PassResult: return PassResult(graph_module, True) -class QuantizeOperatorArguments(ExportPass): +class QuantizeClampArgumentsPass(ArmPass): """ This pass makes sure that the arguments to clamp.default are quantized correctly. More specifically, this pass: - Makes sure the min and max values to clamp.default are quantized, if it's a quantized operator. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module: GraphModule) -> PassResult: modified = False # Loop over the graph nodes and find full.default nodes. @@ -220,12 +346,15 @@ def call(self, graph_module: GraphModule) -> PassResult: }: continue - # Make sure we have a quantized operator - user = list(n.users)[0] - if user.target not in Q_OPS: + try: + output_qparams = get_output_qparams(n) + except ValueError: + continue + if len(output_qparams) == 0: continue - qargs = QuantArgs.from_operator(user.target, user.args) + # Qparams are stored per user index; use the first entry. + qargs = next(iter(output_qparams.values())) if n.target == exir_ops.edge.aten.clamp.default: # Quantize the min and max arguments of clamp, if they are not None @@ -242,40 +371,9 @@ def call(self, graph_module: GraphModule) -> PassResult: modified = True - return PassResult(graph_module, modified) - + if modified: + # Retrace to refresh fake tensor metadata after updating clamp min/max. + graph_module = super().call(graph_module).graph_module + graph_module.recompile() -class RetraceFoldedDtypesPass(ExportPass): - """ - FoldAndAnnotateQParamsPass folds dq and q nodes. When the graph is retraced - some operators are retraced to types that cannot be handled by TOSA. One - such example is sum.dim_IntList: - q (int8) -> dq (fp32) -> sum (fp32) -> q (int8) ... - After folding it becomes: - q (int8) -> sum (int64) -> ... - This pass changes types of ops in self.targeted_ops, such as sum, so that - the output type of that matches the type of the output_qparams. 
- """ - - targeted_ops: Set[EdgeOpOverload] = { - exir_ops.edge.aten.sum.dim_IntList, - } - - def call_operator( - self, - op, # pyre-ignore - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in self.targeted_ops: - return super().call_operator(op, args, kwargs, meta) - - node_kwargs = kwargs.copy() - output_qparams = meta["output_qparams"] - if len(output_qparams) == 0: - return super().call_operator(op, args, kwargs, meta) - - output_dtype = output_qparams[0].dtype - node_kwargs["dtype"] = output_dtype - return super().call_operator(op, args, node_kwargs, meta) + return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/fuse_batch_norm2d_pass.py b/backends/arm/_passes/fuse_batch_norm2d_pass.py new file mode 100644 index 00000000000..d9ae706f503 --- /dev/null +++ b/backends/arm/_passes/fuse_batch_norm2d_pass.py @@ -0,0 +1,241 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Set, Type + +import torch +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import ( + create_node, + get_first_fake_tensor, +) +from executorch.backends.arm.common.debug import get_node_debug_info +from executorch.backends.transforms.utils import ( + create_constant_placeholder, + delete_constant_placeholder, +) +from executorch.exir import ExportedProgram +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch._export.utils import get_buffer, get_param +from torch.export.graph_signature import InputKind +from torch.fx import Node +from torch.nn.utils.fusion import fuse_conv_bn_weights + + +class FuseBatchNorm2dPass(ArmPass): + """Fuses the pattern convolution -> batchnorm by updating + the weights and bias of the convolution and removing the batchnorm. 
+ """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + def __init__(self, exported_program: ExportedProgram): + super().__init__() + self.exported_program = exported_program + + def get_bias_name(self, weight_node: Node, bias_node: Node | None) -> str: + if bias_node: + return bias_node.name + "_fused_bn" + elif "weight" in weight_node.name: + return weight_node.name.replace("weight", "bias") + "_fused_bn" + else: + return weight_node.name + "_bias_fused_bn" + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 + modified = False + constant_placeholders_to_delete = set() + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + if ( + node.target + != exir_ops.edge.aten._native_batch_norm_legit_no_training.default + ): + continue + + # Get data from batchnorm + input_node = node.all_input_nodes[0] + is_single_user = len(input_node.users) == 1 + bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = node.args[1:5] + if bn_mean_node is None: + raise RuntimeError( + "BatchNorm mean buffer missing for node: " + f"{get_node_debug_info(node, graph_module)}" + ) + if bn_var_node is None: + raise RuntimeError( + "BatchNorm variance buffer missing for node: " + f"{get_node_debug_info(node, graph_module)}" + ) + + epsilon = node.args[-1] + + bn_weight_tensor = ( + get_param(self.exported_program, bn_weight_node) + if bn_weight_node is not None + else None + ) + bn_bias_tensor = ( + get_param(self.exported_program, bn_bias_node) + if bn_bias_node is not None + else None + ) + + bn_mean_tensor = torch.Tensor( + get_buffer(self.exported_program, bn_mean_node) + ) + bn_var_tensor = torch.Tensor(get_buffer(self.exported_program, bn_var_node)) + + if ( + input_node.target != exir_ops.edge.aten.convolution.default + or not is_single_user + ): + # Insert a transparent conv2d before bn to fuse with if none is present. 
+ shape = get_first_fake_tensor(node) + if len(shape.size()) == 3: + input_weight_tensor = torch.ones((1, 1, 1)) + stride = [1] + padding = [0] + dilation = [1] + output_padding = [0] + else: + input_weight_tensor = torch.ones((1, 1, 1, 1)) + stride = [1, 1] + padding = [0, 0] + dilation = [1, 1] + output_padding = [0, 0] + + with graph_module.graph.inserting_before(bn_weight_node): + input_weight_node = create_constant_placeholder( + exp_program=self.exported_program, + graph=graph_module.graph, + kind=InputKind.PARAMETER, + name=node.name + "_conv_weight", + data=input_weight_tensor, + ) + + input_bias_tensor = input_bias_node = None + + with graph_module.graph.inserting_before(node): + channels = bn_mean_tensor.size(0) + conv_args = ( + input_node, + input_weight_node, + input_bias_node, + stride, + padding, + dilation, + False, # Transposed + output_padding, + channels, + ) + new_input_node = create_node( + graph_module.graph, + exir_ops.edge.aten.convolution.default, + conv_args, + ) + node.replace_input_with(input_node, new_input_node) + input_node = new_input_node + else: + input_weight_node, input_bias_node = input_node.args[1:3] + if not ( + isinstance(input_weight_node, Node) + and input_weight_node.op == "placeholder" + ): + raise RuntimeError( + "Parameter weight of convolution must be a placeholder" + ) + if not ( + (input_bias_node is None) + or ( + isinstance(input_weight_node, Node) + and input_weight_node.op == "placeholder" + ) + ): + raise RuntimeError( + "Parameter bias of convolution must be a placeholder or None" + ) + + input_weight_tensor = torch.Tensor( + get_param(self.exported_program, input_weight_node) + ) + + input_bias_tensor = ( + get_param(self.exported_program, input_bias_node) + if input_bias_node is not None + else None + ) + + # Fuse bn weights/bias with input weights/bias + fused_weight, fused_bias = fuse_conv_bn_weights( + input_weight_tensor, + input_bias_tensor, + bn_mean_tensor, + bn_var_tensor, + epsilon, + bn_weight_tensor, + bn_bias_tensor, + ) + + # Create fused weights and bias to conv and replace conv args + with graph_module.graph.inserting_before(input_weight_node): + fused_conv_weight_node = create_constant_placeholder( + exp_program=self.exported_program, + graph=graph_module.graph, + kind=InputKind.PARAMETER, + name=input_weight_node.name + "_fused_bn", + data=fused_weight, + ) + + if fused_bias is not None: + fused_input_bias_node = create_constant_placeholder( + exp_program=self.exported_program, + graph=graph_module.graph, + kind=InputKind.PARAMETER, + name=self.get_bias_name(input_weight_node, input_bias_node), + data=fused_bias, + ) + else: + fused_input_bias_node = None + + input_node.args = ( + input_node.args[0], + fused_conv_weight_node, + fused_input_bias_node, + *input_node.args[3:], + ) + + # Erasing batch-norm nodes is handled by dead-code elimination. 
After that we may remove their constant placeholder inputs + for user in node.users: + user.replace_all_uses_with(input_node) + + constant_placeholders_to_delete.update( + [ + bn_weight_node, + bn_bias_node, + bn_mean_node, + bn_var_node, + input_weight_node, + input_bias_node, + ] + ) + modified = True + + if modified: + graph_module.graph.eliminate_dead_code() + for constant_placeholder in constant_placeholders_to_delete: + if (constant_placeholder is not None) and ( + len(constant_placeholder.users) == 0 + ): + delete_constant_placeholder( + self.exported_program, constant_placeholder + ) + + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module=graph_module, modified=modified) diff --git a/backends/arm/_passes/fuse_batchnorm2d_pass.py b/backends/arm/_passes/fuse_batchnorm2d_pass.py deleted file mode 100644 index 2dbdfa84cec..00000000000 --- a/backends/arm/_passes/fuse_batchnorm2d_pass.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -import torch -from executorch.backends.arm._passes.arm_pass_utils import ( - create_node, - get_first_fake_tensor, -) -from executorch.backends.transforms.utils import ( - create_constant_placeholder, - delete_constant_placeholder, -) -from executorch.exir import ExportedProgram -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult -from torch._export.utils import get_buffer, get_param -from torch.export.graph_signature import InputKind -from torch.fx import Node -from torch.nn.utils.fusion import fuse_conv_bn_weights - - -class FuseBatchnorm2DPass(ExportPass): - """Fuses the pattern convolution -> batchnorm by updating - the weights and bias of the convolution and removing the batchnorm. - """ - - def __init__(self, exported_program: ExportedProgram): - self.exported_program = exported_program - super().__init__() - - def get_bias_name(self, weight_node: Node, bias_node: Node | None) -> str: - if bias_node: - return bias_node.name + "_fused_bn" - elif "weight" in weight_node.name: - return weight_node.name.replace("weight", "bias") + "_fused_bn" - else: - return weight_node.name + "_bias_fused_bn" - - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 - modified = False - constant_placeholders_to_delete = set() - for node in graph_module.graph.nodes: - if node.op != "call_function": - continue - if ( - node.target - != exir_ops.edge.aten._native_batch_norm_legit_no_training.default - ): - continue - - # Get data from batchnorm - input_node = node.all_input_nodes[0] - is_single_user = len(input_node.users) == 1 - bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = node.args[1:5] - assert bn_mean_node is not None, "Batchnorm mean node cannot be None." - assert bn_var_node is not None, "Batchnorm var node cannot be None." 
- - epsilon = node.args[-1] - - bn_weight_tensor = ( - get_param(self.exported_program, bn_weight_node) - if bn_weight_node is not None - else None - ) - bn_bias_tensor = ( - get_param(self.exported_program, bn_bias_node) - if bn_bias_node is not None - else None - ) - - bn_mean_tensor = torch.Tensor( - get_buffer(self.exported_program, bn_mean_node) - ) - bn_var_tensor = torch.Tensor(get_buffer(self.exported_program, bn_var_node)) - - if ( - input_node.target != exir_ops.edge.aten.convolution.default - or not is_single_user - ): - # Insert a transparent conv2d before bn to fuse with if none is present. - shape = get_first_fake_tensor(node) - if len(shape.size()) == 3: - input_weight_tensor = torch.ones((1, 1, 1)) - stride = [1] - padding = [0] - dilation = [1] - output_padding = [0] - else: - input_weight_tensor = torch.ones((1, 1, 1, 1)) - stride = [1, 1] - padding = [0, 0] - dilation = [1, 1] - output_padding = [0, 0] - - with graph_module.graph.inserting_before(bn_weight_node): - input_weight_node = create_constant_placeholder( - exp_program=self.exported_program, - graph=graph_module.graph, - kind=InputKind.PARAMETER, - name=node.name + "_conv_weight", - data=input_weight_tensor, - ) - - input_bias_tensor = input_bias_node = None - - with graph_module.graph.inserting_before(node): - channels = bn_mean_tensor.size(0) - conv_args = ( - input_node, - input_weight_node, - input_bias_node, - stride, - padding, - dilation, - False, # Transposed - output_padding, - channels, - ) - new_input_node = create_node( - graph_module.graph, - exir_ops.edge.aten.convolution.default, - conv_args, - ) - node.replace_input_with(input_node, new_input_node) - input_node = new_input_node - else: - input_weight_node, input_bias_node = input_node.args[1:3] - assert ( - isinstance(input_weight_node, Node) - and input_weight_node.op == "placeholder" - ), "Parameter weight of convolution must be a placeholder" - assert (input_bias_node is None) or ( - isinstance(input_weight_node, Node) - and input_weight_node.op == "placeholder" - ), "Parameter bias of convolution must be a placeholder or None" - - input_weight_tensor = torch.Tensor( - get_param(self.exported_program, input_weight_node) - ) - - input_bias_tensor = ( - get_param(self.exported_program, input_bias_node) - if input_bias_node is not None - else None - ) - - # Fuse bn weights/bias with input weights/bias - fused_weight, fused_bias = fuse_conv_bn_weights( - input_weight_tensor, - input_bias_tensor, - bn_mean_tensor, - bn_var_tensor, - epsilon, - bn_weight_tensor, - bn_bias_tensor, - ) - - # Create fused weights and bias to conv and replace conv args - with graph_module.graph.inserting_before(input_weight_node): - fused_conv_weight_node = create_constant_placeholder( - exp_program=self.exported_program, - graph=graph_module.graph, - kind=InputKind.PARAMETER, - name=input_weight_node.name + "_fused_bn", - data=fused_weight, - ) - - if fused_bias is not None: - fused_input_bias_node = create_constant_placeholder( - exp_program=self.exported_program, - graph=graph_module.graph, - kind=InputKind.PARAMETER, - name=self.get_bias_name(input_weight_node, input_bias_node), - data=fused_bias, - ) - else: - fused_input_bias_node = None - - input_node.args = ( - input_node.args[0], - fused_conv_weight_node, - fused_input_bias_node, - *input_node.args[3:], - ) - - # Erasing batch-norm nodes is handled by dead-code elimination. 
After that we may remove their constant placeholder inputs - for user in node.users: - user.replace_all_uses_with(input_node) - - constant_placeholders_to_delete.update( - [ - bn_weight_node, - bn_bias_node, - bn_mean_node, - bn_var_node, - input_weight_node, - input_bias_node, - ] - ) - modified = True - - if modified: - graph_module.graph.eliminate_dead_code() - for constant_placeholder in constant_placeholders_to_delete: - if (constant_placeholder is not None) and ( - len(constant_placeholder.users) == 0 - ): - delete_constant_placeholder( - self.exported_program, constant_placeholder - ) - - graph_module.recompile() - graph_module = super().call(graph_module).graph_module - - return PassResult(graph_module=graph_module, modified=modified) diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py index f49565e3c38..0ca3dc38f75 100644 --- a/backends/arm/_passes/fuse_constant_ops_pass.py +++ b/backends/arm/_passes/fuse_constant_ops_pass.py @@ -4,15 +4,20 @@ # LICENSE file in the root directory of this source tree. import logging +from typing import Set, Type import torch._export.utils import torch.fx +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( get_constant_placeholder_kind, get_first_fake_tensor, get_param_tensor, is_persistent_buffer, ) +from executorch.backends.arm._passes.fuse_equal_placeholders_pass import ( + FuseEqualPlaceholdersPass, +) from executorch.backends.transforms.utils import ( create_constant_placeholder, delete_constant_placeholder, @@ -25,7 +30,7 @@ logger = logging.getLogger(__name__) -class FuseConstantArgsPass(ExportPass): +class FuseConstantArgsPass(ArmPass): """ Fuses ops with only placeholder parameters into one placeholder parameter node with the op pre-calulcated on its data. @@ -41,6 +46,8 @@ def f(): return x """ + _passes_required_after: Set[Type[ExportPass]] = set() + def __init__(self, exported_program: ExportedProgram) -> None: super().__init__() self.exported_program = exported_program @@ -58,7 +65,8 @@ def resolve_arg(arg): if isinstance(arg, torch.fx.Node) and arg in input_nodes: idx = input_nodes.index(arg) t = get_param_tensor(self.exported_program, arg) - if qparams: + # Check if qparams exist for this arg + if qparams and idx in qparams.keys(): t = qparams[idx].dequantize_value(t) return t if isinstance(arg, tuple): @@ -108,8 +116,10 @@ def call(self, graph_module): if node.op != "call_function": continue if node.target in [ - exir_ops.backend.tosa.TABLE.default, + exir_ops.backend.tosa.MATMUL.default, exir_ops.backend.tosa.RESCALE.default, + exir_ops.backend.tosa.RESIZE.default, + exir_ops.backend.tosa.TABLE.default, exir_ops.backend.tosa.TRANSPOSE.default, ]: continue @@ -154,7 +164,7 @@ def call(self, graph_module): return PassResult(graph_module, True) -class ComputeConstantOpsAOT(ExportPass): +class ComputeConstantOpsAOTPass(ArmPass): """ Evaluates call_functions that produce constant tensor outputs and replaces them with placeholders. 
@@ -168,6 +178,11 @@ def f(node_name_pre_computed): return node_name_pre_computed """ + _passes_required_after: Set[Type[ExportPass]] = { + FuseEqualPlaceholdersPass, + FuseConstantArgsPass, + } + targeted_ops = [ exir_ops.edge.aten.full.default, exir_ops.edge.aten.arange.start_step, diff --git a/backends/arm/_passes/fuse_duplicate_users_pass.py b/backends/arm/_passes/fuse_duplicate_users_pass.py new file mode 100644 index 00000000000..217d93373f8 --- /dev/null +++ b/backends/arm/_passes/fuse_duplicate_users_pass.py @@ -0,0 +1,165 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from collections import deque +from typing import Any, Deque, Dict, Hashable, List, Set, Tuple, Type + +import torch +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass, PassResult +from torch._ops import OpOverload +from torch.fx import GraphModule, Node +from torch.fx.node import Argument, map_arg + + +class FuseDuplicateUsersPass(ArmPass): + """Fuse identical users of a producer node into a single operation. + + Example: + + y = producer(x) + z0 = torch.add(y, bias) + z1 = torch.add(y, bias) + + becomes a single ``torch.add`` that feeds both consumers. + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + def call(self, graph_module: GraphModule) -> PassResult: + graph = graph_module.graph + modified = False + + producers: Deque[Node] = deque(node for node in graph.nodes) + + while producers: + producer = producers.popleft() + + if producer.graph is None: + # Node was deleted by a previous rewrite while still queued. + continue + + # Only meaningful if a value is consumed by multiple users. + user_nodes = list(producer.users) + if len(user_nodes) < 2: + continue + + candidate_groups = self._get_candidate_groups(user_nodes) + + signature_to_user: Dict[Tuple[Hashable, ...], Node] = {} + for group in candidate_groups: + for user in group: + signature = self._build_user_signature(user) + if signature is None: + continue + + representative = signature_to_user.get(signature) + if representative is None: + # Check if we already encountered identical node that we can fuse with. + signature_to_user[signature] = user + continue + + if user is representative: + # The queue can enqueue the surviving node again after rewrites. + continue + + user.replace_all_uses_with(representative) + graph.erase_node(user) + modified = True + + # Revisit the current producer and the surviving user so that + # newly formed duplicate chains can be fused in later + # iterations. + producers.append(producer) + producers.append(representative) + + if modified: + graph_module.recompile() + graph_module.graph.lint() + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, modified) + + def _get_candidate_groups(self, user_nodes): + users_by_target: Dict[Tuple[str, Hashable], List[Node]] = {} + for user in user_nodes: + if user.graph is None: + # User might already have been removed by a prior rewrite. 
+ continue + + if user.op != "call_function": + continue + + target_key = self._get_target_key(user.target) + target_signature = (user.op, target_key) + users_by_target.setdefault(target_signature, []).append(user) + + candidate_groups = [ + group for group in users_by_target.values() if len(group) > 1 + ] + + return candidate_groups + + def _build_user_signature(self, node: Node) -> Tuple[Hashable, ...] | None: + try: + normalized_args = self._to_hashable( + map_arg(node.args, self._map_leaf_to_key) + ) + normalized_kwargs = self._to_hashable( + {k: map_arg(v, self._map_leaf_to_key) for k, v in node.kwargs.items()} + ) + except TypeError: + return None + + target_key = self._get_target_key(node.target) + + return (node.op, target_key, normalized_args, normalized_kwargs) + + def _map_leaf_to_key(self, node: Node) -> Argument: + return node.name + + def _to_hashable(self, value: Any) -> Hashable: + """Convert arbitrarily nested structures into hashable tuples.""" + + if isinstance(value, (list, tuple)): + return tuple(self._to_hashable(v) for v in value) + if isinstance(value, dict): + normalized_items = [(k, self._to_hashable(v)) for k, v in value.items()] + return tuple(sorted(normalized_items, key=lambda item: repr(item[0]))) + if isinstance(value, set): + hashable_values: List[Hashable] = [self._to_hashable(v) for v in value] + return tuple(sorted(hashable_values, key=repr)) + if isinstance(value, slice): + return ( + "slice", + self._to_hashable(value.start), + self._to_hashable(value.stop), + self._to_hashable(value.step), + ) + if isinstance(value, range): + return ("range", value.start, value.stop, value.step) + if isinstance(value, torch.Size): + return ("size", tuple(value)) + if isinstance(value, torch.dtype): + return ("dtype", str(value)) + if isinstance(value, torch.device): + return ("device", str(value)) + if isinstance(value, torch.memory_format): + return ("memory_format", str(value)) + if isinstance(value, torch.Tensor): + return ( + "tensor", + str(value.dtype), + tuple(value.size()), + value.device.type, + value.requires_grad, + ) + return value + + def _get_target_key(self, target: Any) -> Hashable: + if isinstance(target, (EdgeOpOverload, OpOverload)): + return str(target) + return target diff --git a/backends/arm/_passes/fuse_equal_placeholders_pass.py b/backends/arm/_passes/fuse_equal_placeholders_pass.py index 5631e2f32e9..f06ea3d5470 100644 --- a/backends/arm/_passes/fuse_equal_placeholders_pass.py +++ b/backends/arm/_passes/fuse_equal_placeholders_pass.py @@ -5,13 +5,17 @@ import hashlib from collections import defaultdict +from typing import Set, Type import torch + +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( get_constant_placeholder_kind, get_param_tensor, is_param_node, ) +from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.backends.transforms.utils import ( create_constant_placeholder, delete_constant_placeholder, @@ -20,16 +24,18 @@ from executorch.exir.pass_base import ExportPass, PassResult -class FuseEqualPlaceholdersPass(ExportPass): +class FuseEqualPlaceholdersPass(ArmPass): """ This pass optimizes memory usage by finding constant placeholders pointing to identical tensors and fusing them to one single placeholder with multiple users, using a cache for faster comparison. 
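As an aside, a minimal sketch of the bucket key this pass builds for each constant placeholder, mirroring the code below (the tensor and the int48 marker value are made up for illustration):

    import hashlib
    import torch

    t_cpu = torch.zeros(2, 2).detach().cpu().contiguous()
    is_int48 = None  # value stored under TosaSpecialDtype.meta_key() in node.meta, if any
    key = (
        is_int48,  # keeps special int48_t constants in their own bucket
        str(t_cpu.dtype),
        tuple(t_cpu.shape),
        hashlib.sha1(t_cpu.numpy().tobytes(), usedforsecurity=False).hexdigest(),
    )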
""" + _passes_required_after: Set[Type[ExportPass]] = set() + def __init__(self, exported_program: ExportedProgram): - self.exported_program = exported_program super().__init__() + self.exported_program = exported_program def call(self, graph_module: torch.fx.GraphModule) -> PassResult: modified = False @@ -44,12 +50,17 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: continue # Create a lightweight fingerprint: dtype + shape + SHA1 of raw bytes # Ensure tensor is on CPU and contiguous + + # ensure we don't merge any special case int48_t tensors with int32_t tensors + # since int48_t tensors needs to be instantiated separately. + is_int48 = node.meta.get(TosaSpecialDtype.meta_key(), None) t_cpu = tensor.detach().cpu().contiguous() data_bytes = t_cpu.numpy().tobytes() key = ( + is_int48, str(t_cpu.dtype), tuple(t_cpu.shape), - hashlib.sha1(data_bytes).hexdigest(), + hashlib.sha1(data_bytes, usedforsecurity=False).hexdigest(), ) hash_buckets[key].append((node, t_cpu)) @@ -73,6 +84,13 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: common_persistent, ) + # TBD: Find a principled way to merge node.meta across all fused node + # For now, i specifically transfer over the TosaSpecialDtype.meta_key() of the rep_node + if TosaSpecialDtype.meta_key() in rep_node.meta: + common_node.meta[TosaSpecialDtype.meta_key()] = rep_node.meta[ + TosaSpecialDtype.meta_key() + ] + # Replace uses and delete duplicates for node, _ in nodes_tensors: node.replace_all_uses_with(common_node) diff --git a/backends/arm/_passes/fuse_quantized_activation_pass.py b/backends/arm/_passes/fuse_quantized_activation_pass.py index 46a7d7f6f98..09e989cd3aa 100644 --- a/backends/arm/_passes/fuse_quantized_activation_pass.py +++ b/backends/arm/_passes/fuse_quantized_activation_pass.py @@ -3,17 +3,30 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe + +from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.convert_to_clamp_pass import ConvertToClampPass +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + FoldAndAnnotateQParamsPass, +) from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.arm.constants import Q_OPS +from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import Node -class FuseQuantizedActivationPass(ExportPass): +class FuseQuantizedActivationPass(ArmPass): + _passes_required_after: Set[Type[ExportPass]] = { + ConvertToClampPass, + FoldAndAnnotateQParamsPass, + RemoveGetItemPass, + } + @staticmethod def _is_fuseable_quantized_activation(node: Node): """Fuse activations that have a 0 lower bound and quantized with a qmin zero-point""" diff --git a/backends/arm/_passes/fuse_view_copy_transform_pass.py b/backends/arm/_passes/fuse_view_copy_transform_pass.py new file mode 100644 index 00000000000..cef3b408c24 --- /dev/null +++ b/backends/arm/_passes/fuse_view_copy_transform_pass.py @@ -0,0 +1,14 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Set, Type + +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform +from executorch.exir.pass_base import ExportPass + + +class FuseViewCopyTransformPass(ArmPass, FuseViewCopyTransform): + _passes_required_after: Set[Type[ExportPass]] = set() diff --git a/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py b/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py index 4b619af790c..de80d61bfbe 100644 --- a/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py +++ b/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py @@ -3,13 +3,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe - import logging +from typing import Set, Type + import torch +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm._passes.decompose_embedding_pass import ( + DecomposeEmbeddingPass, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import EdgeOpOverload, ExportPass, PassResult from torch._subclasses.fake_tensor import FakeTensor @@ -18,7 +22,7 @@ logger = logging.getLogger(__name__) -class InsertInt32CastsAfterInt64PlaceholdersPass(ExportPass): +class InsertInt32CastsAfterInt64PlaceholdersPass(ArmPass): """ Insert an int64->int32 cast after each int64 placeholder. @@ -26,10 +30,14 @@ class InsertInt32CastsAfterInt64PlaceholdersPass(ExportPass): the int32 range. """ + _passes_required_after: Set[Type[ExportPass]] = {DecomposeEmbeddingPass} + # Ops that require i64 inputs → positions of args to upcast. # Key: op overload; Value: zero-based indices of positional args that must be i64. I64_INPUT_ARG_POSITIONS = { torch.ops.aten.one_hot.default: (0,), + torch.ops.aten.index_copy_.default: (2,), + torch.ops.aten.index_copy.default: (2,), } def _insert_callsite_i32_to_i64_casts(self, graph_module: torch.fx.GraphModule): diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index 7f75aecf24c..9e69a1e7e53 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -3,10 +3,18 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import math from copy import copy -from typing import cast +from typing import cast, Dict, Optional, Set, Tuple, Type + +import torch +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import create_node, set_node_arg +from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_output_qparams, +) -from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.exir.dialects._ops import ops as exir_ops @@ -14,7 +22,7 @@ from torch.fx import GraphModule, Node -class InsertRescalePass(ExportPass): +class InsertRescalePass(ArmPass): """Finds patterns of dq -> q, and replaces them with backend dialect tosa::RESCALE op. 
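For reference, a minimal sketch of the argument layout that fold_dq_q_to_rescale builds for the new tosa.RESCALE node; the concrete scale and zero-point values are hypothetical, and new_scale is presumed to be the dequantize scale divided by the quantize scale:

    import torch

    dq_scale, dq_zp = 0.1, 0     # hypothetical values from the dequantize node
    q_scale, q_zp = 0.05, -128   # hypothetical values from the following quantize node

    # Argument layout per the pass:
    # (input_node, output_dtype, [new_scale], input_zp, output_zp)
    new_scale = dq_scale / q_scale
    rescale_args = (None, torch.int8, [new_scale], dq_zp, q_zp)  # None stands in for the input node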
@@ -24,6 +32,8 @@ class InsertRescalePass(ExportPass): in the fake implementation of. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule): dq_args = QuantArgs.from_operator(node.target, node.args) q_args = QuantArgs.from_operator(user.target, user.args) @@ -36,7 +46,7 @@ def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule ( node.all_input_nodes[0], q_args.dtype, - new_scale, + [new_scale], dq_args.zp, q_args.zp, ), @@ -63,3 +73,512 @@ def call(self, graph_module: GraphModule) -> PassResult: graph_module = super().call(graph_module).graph_module graph_module.recompile() return PassResult(graph_module, modified) + + +class InsertRescaleInt32Pass(ArmPass): + """Numerous TOSA ops require inputs and outputs to be 32-bit integers in their + quantized implementations. This pass treats such operator nodes by + inserting rescale ops before and after them if needed. Note that extra + logic that handles the scales and zero points are in place here because the + affected TOSA ops have naive implementations that do not account for the + quantization parameters. + """ + + # SUM must be decomposed after this pass to prevent insertion of RESCALE + # nodes between each subsequent SUM node after decomposition. RESCALE nodes + # should only be inserted before and after the SUM node prior to its + # decomposition. + _passes_required_after: Set[Type[ExportPass]] = {DecomposeSumPass} + + included_targets = [ + exir_ops.edge.aten.abs.default, + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.eq.Tensor, + exir_ops.edge.aten.ge.Tensor, + exir_ops.edge.aten.gt.Tensor, + exir_ops.edge.aten.le.Tensor, + exir_ops.edge.aten.lt.Tensor, + exir_ops.edge.aten.maximum.default, + exir_ops.edge.aten.minimum.default, + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.sub.Tensor, + exir_ops.edge.aten.sum.dim_IntList, + ] + + def _int32_qargs(self, s): + """Helper creator function for INT32-based QuantArgs""" + + return QuantArgs( + scale=s, + zp=0, + qmin=torch.iinfo(torch.int32).min, + qmax=torch.iinfo(torch.int32).max, + dtype=torch.int32, + ) + + def _get_inputs_rescaled_qparams( + self, target, input_qparams: Dict[int, QuantArgs] + ) -> Dict[int, QuantArgs]: + """Get the qparams for the INT32 operands to the op ``target`` + + Inputs to the INT32-based operator must be rescaled from INT8 to INT32. + This function computes the ``QuantArgs`` for each of the operands and returns + it as a dict, mapping tensor index to ``QuantArgs``. + """ + + if target in [ + exir_ops.edge.aten.abs.default, + exir_ops.edge.aten.eq.Tensor, + exir_ops.edge.aten.ge.Tensor, + exir_ops.edge.aten.gt.Tensor, + exir_ops.edge.aten.le.Tensor, + exir_ops.edge.aten.lt.Tensor, + exir_ops.edge.aten.minimum.default, + exir_ops.edge.aten.maximum.default, + ]: + # For these ops, use the smallest scale among the INT8 operands. + min_scale = min( + [qp.get_scale_per_tensor() for qp in input_qparams.values()] + ) + qparams = {i: self._int32_qargs(min_scale) for i in input_qparams.keys()} + elif target in [ + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.sub.Tensor, + ]: + keys = list(input_qparams) + if len(keys) < 2: + raise ValueError(f"Expected two input qparams, got: {input_qparams}.") + if input_qparams[keys[0]].dtype != input_qparams[keys[1]].dtype: + raise ValueError( + f"Mismatch in dtype args: {input_qparams[keys[0]].dtype} != {input_qparams[keys[1]].dtype}" + ) + + # We are handling two INT8 or two INT16 numbers. 
For INT8, if the + # zero point is non-null, the result will be in the range [-255; + # 255], therefore we need 9 bits for the result. We have a 32-bit + # accumulator, so we can divide the scale by (1 << 20) which is + # equivalent to shifting the INT8 operands 20 bits to the left + # before rescaling them both to 2 * max(lhs, rhs). + # + # For INT16, similary logic can be applied, but we instead end up + # with a left shift of 12. + lhs_scale, rhs_scale = ( + qp.get_scale_per_tensor() for qp in input_qparams.values() + ) + max_scale_2x = 2 * max(lhs_scale, rhs_scale) + + # Select shift based on input dtype. + shift_bits = 12 if input_qparams[keys[0]].dtype == torch.int16 else 20 + + scale = max_scale_2x / (1 << shift_bits) + qparams = {i: self._int32_qargs(scale) for i in input_qparams.keys()} + elif target in [ + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.sum.dim_IntList, + ]: + # The input scales do not need to be adjusted for these ops; they + # can remain the same. + qparams = { + i: self._int32_qargs(qp.get_scale_per_tensor()) + for i, qp in input_qparams.items() + } + else: + raise ValueError(f"Not a valid target: {target}") + + return qparams + + def _get_output_qparams( + self, target, inputs_qparams: Dict[int, QuantArgs] + ) -> Optional[QuantArgs]: + """Given an op ``target`` and the ``QuantArgs`` for each of its inputs, compute + the scale of the output based on how the operator itself affects it.""" + + if target in [ + exir_ops.edge.aten.abs.default, + exir_ops.edge.aten.maximum.default, + exir_ops.edge.aten.minimum.default, + exir_ops.edge.aten.sum.dim_IntList, + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.sub.Tensor, + ]: + # The op has not altered the scale; the output scale is equal to + # the operands' scales. + return self._int32_qargs(inputs_qparams[0].get_scale_per_tensor()) + elif target in [ + exir_ops.edge.aten.eq.Tensor, + exir_ops.edge.aten.ge.Tensor, + exir_ops.edge.aten.gt.Tensor, + exir_ops.edge.aten.le.Tensor, + exir_ops.edge.aten.lt.Tensor, + ]: + # Output is bool for these ops and thus no qparams are present + return None + elif target in [exir_ops.edge.aten.mul.Tensor]: + # Mul will cause the scales to also multiply; refer to the formula + # where we compute the output scale S_2: + # + # (Q_2 - ZP_2) * S_2 == ((Q_0 - ZP_0) * S_0) * ((Q_1 - ZP_1) * S_1) + # + # yields: + # + # (Q_2 - ZP_2) == (Q_0 - ZP_0) * (Q_1 - ZP_1) + # S_2 = S_0 * S_1 + output_scale = math.prod( + (qp.get_scale_per_tensor() for qp in inputs_qparams.values()) + ) + return self._int32_qargs(output_scale) + else: + raise ValueError(f"Not a valid target: {target}") + + def _get_rescale_qparams( + self, target, input_qparams: Dict[int, QuantArgs] + ) -> Tuple[Dict[int, QuantArgs], Optional[QuantArgs]]: + """ + Get the quantization parameters of the INT32 inputs/outputs that will + surround the node after the new RESCALE ops have been inserted. 
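A small worked example of the scale bookkeeping described above, using made-up per-tensor scales for an INT8 add (the 20-bit shift matches the comment; 12 bits would be used for INT16 inputs):

    lhs_scale, rhs_scale = 0.02, 0.05     # hypothetical INT8 input scales
    shift_bits = 20                       # 12 for INT16 inputs
    int32_scale = 2 * max(lhs_scale, rhs_scale) / (1 << shift_bits)

    # Each inserted RESCALE multiplies by (old scale / new scale):
    lhs_multiplier = lhs_scale / int32_scale
    rhs_multiplier = rhs_scale / int32_scale
    # The add itself does not change the scale, so the output RESCALE maps
    # int32_scale back to the original INT8 output scale in the same way.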
+ """ + + inputs_rescaled_qparams = self._get_inputs_rescaled_qparams( + target, input_qparams + ) + output_qparams = self._get_output_qparams(target, inputs_rescaled_qparams) + + return (inputs_rescaled_qparams, output_qparams) + + def _rescale_inputs(self, graph, node, rescale_qargs: Dict[int, QuantArgs]) -> bool: + qargs = node.meta["input_qparams"] + + args_copy = list(node.args) + seen_args = set() + modified = False + for i in qargs: + qp = qargs[i] + if qp.dtype not in (torch.int8, torch.int16): + continue + + arg_node = args_copy[i] + if arg_node in seen_args: + continue + seen_args.add(arg_node) + + with graph.inserting_after(arg_node): + rescale_node = create_node( + graph, + exir_ops.backend.tosa.RESCALE.default, + ( + arg_node, + torch.int32, + [ + qp.get_scale_per_tensor() + / rescale_qargs[i].get_scale_per_tensor() + ], # [Old scale / new scale] + qp.get_zp_per_tensor(), # Old zero point + rescale_qargs[i].get_zp_per_tensor(), # New zero point + ), + from_node=node, + ) + + node.replace_input_with(arg_node, rescale_node) + modified = True + + return modified + + def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> bool: + if "output_qparams" not in node.meta or len(node.meta["output_qparams"]) == 0: + return False + + qargs = get_output_qparams(node) + assert len(qargs) == 1 + assert rescale_qargs is not None + + qarg = qargs[0] + if qarg.dtype not in (torch.int8, torch.int16): + return False + + users_copy = list(node.users) + + with graph.inserting_after(node): + rescale_node = create_node( + graph, + exir_ops.backend.tosa.RESCALE.default, + ( + node, + qarg.dtype, + [ + rescale_qargs.get_scale_per_tensor() + / qarg.get_scale_per_tensor() + ], # [Old scale / new scale] + rescale_qargs.get_zp_per_tensor(), # Old zero point + qarg.get_zp_per_tensor(), # New zero point + ), + from_node=node, + ) + + for user in users_copy: + user.replace_input_with(node, rescale_node) + + return True + + def call(self, graph_module: GraphModule) -> PassResult: + graph = graph_module.graph + + modified = False + for node in list(graph.nodes): + node = cast(Node, node) + + if node.op != "call_function" or node.target not in self.included_targets: + continue + + if "input_qparams" not in node.meta or len(node.meta["input_qparams"]) == 0: + continue + input_qparams = node.meta["input_qparams"] + + inputs_rescale_qargs, output_rescale_qargs = self._get_rescale_qparams( + node.target, input_qparams + ) + + inputs_was_rescaled = self._rescale_inputs( + graph, node, inputs_rescale_qargs + ) + outputs_was_rescaled = False + if inputs_was_rescaled: + outputs_was_rescaled = self._rescale_outputs( + graph, node, output_rescale_qargs + ) + modified = True + + # Update node metadata + + if inputs_was_rescaled: + assert len(inputs_rescale_qargs) == len(node.meta["input_qparams"]) + node.meta["input_qparams"] = inputs_rescale_qargs + + if outputs_was_rescaled: + assert len(node.meta["output_qparams"]) == 1 + node.meta["output_qparams"] = {0: output_rescale_qargs} + + # If the output type is specified in the node, change it such + # that it matches the subsequent rescale node(s) that this node + # now has output edges to. 
+ if "dtype" in node.kwargs: + set_node_arg(node, "dtype", torch.int32) + + if modified: + # Retrace the graph to update the fake tensor types + graph_module = super().call(graph_module).graph_module + graph_module.recompile() + + return PassResult(graph_module, modified) + + +class InsertControlFlowRescalesPass(ArmPass): + """The quantization parameters for tensors going into and coming out of a submodule are not guaranteed to + match the quantization parameters for the corresponding tensors inside the submodule. For example, cond has + different annotation on input and output, while the entire graph inside the submodule could be using shared + annotation. This pass solves this by inserting rescales in the beginning and end of the submodule + that transform the tensor from one set of quantization parameters to another. + + The pass is run by the graph_module containing the control flow operator, but requires that the affected nodes + inside the submodule have been q-dq folded and have input/output_qparams meta. + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + def _get_input_nodes(self, graph_module: GraphModule): + return [node for node in graph_module.graph.nodes if node.op == "placeholder"] + + def _insert_rescale( + self, + in_qparams: QuantArgs, + out_qparams: QuantArgs, + from_node: Node, + graph_module: GraphModule, + ): + """Insert a rescale into the graph, inheriting meta from `from_node`. + The node is not connected to anything, that is up to the user.""" + + new_scales = [ + in_qparams.get_scale_per_tensor() / out_qparams.get_scale_per_tensor() + ] + + rescale_node = create_node( + graph_module.graph, + exir_ops.backend.tosa.RESCALE.default, + ( + None, + out_qparams.dtype, + new_scales, + in_qparams.get_zp_per_tensor(), # Old zero point + out_qparams.get_zp_per_tensor(), # New zero point + ), + from_node=from_node, + ) + return rescale_node + + def _rescale_submodule_inputs( + self, submodule: GraphModule, input_qparams_map: Dict[int, QuantArgs] + ) -> bool: + """Insert rescales at the inputs of `submodule` to match the qparams outside the submodule. + Matching the correct qparams gets a bit tricky: + Containing module: | submodule: + ops => cond | => placeholders => ... + + The dq->q qparam pair we want to convert to a rescale is: + (input qparams of op, output qparams of placeholder) + And the rescale is inserted after the placeholder. + + Args: + submodule: GraphModule: the GraphModule in which to rescale the inputs. + input_qparams_map: A map of input indexes mapping to QuantArgs. Not guaranteed to contain a mapping + for every submodule input. + Returns: + True if at least one rescale was inserted, False otherwise. + """ + + modified = False + input_nodes = self._get_input_nodes(submodule) + for qargs_index in input_qparams_map: + input_node = input_nodes[qargs_index] + if len(input_node.users) == 0: + continue + if len(out_qparams_map := input_node.meta.get("output_qparams", {})) != 1: + raise ValueError( + f"Expected submodule input {input_node} to have exactly one output qparam, got {out_qparams_map}" + ) + in_qparams = input_qparams_map[qargs_index] + out_qparams = cast(QuantArgs, out_qparams_map[0]) + + # Remove qparam meta to not confuse folding pass. 
+ del input_node.meta["output_qparams"] + if in_qparams == out_qparams: + continue + with submodule.graph.inserting_after(input_node): + modified = True + rescale_node = self._insert_rescale( + in_qparams, out_qparams, input_node, submodule + ) + input_node.replace_all_uses_with(replace_with=rescale_node) + rescale_node.update_arg(0, input_node) + return modified + + def _rescale_submodule_outputs( + self, submodule: GraphModule, output_qparams_map: Dict[int, QuantArgs] + ) -> bool: + """Insert rescales at the outputs of `submodule` to match the qparams outside the submodule. + Matching the correct qparams gets a bit tricky: + Submodule: | Containing module: + output_nodes => output |=> getitems => ... + + The dq->q qparam pair we want to convert to a rescale is: + (input qparam of output_node, output qparam of getitem) + And the rescale is inserted between op and output. Note that the output qparam of op is called input_qargs, + since the it is the input to the dq-q pair. + + Args: + submodule: GraphModule: the GraphModule in which to rescale the outputs. + output_qparams_map: A map of output indexes mapping to QuantArgs. Not guaranteed to contain a mapping + for every submodule output. + Returns: + True if at least one rescale was inserted, False otherwise. + """ + + modified = False + output_node = submodule.graph.output_node() + output_args = list(cast(tuple[Node], output_node.args[0])) + input_qparams_map = cast( + dict[int, QuantArgs], output_node.meta["input_qparams"] + ) + for qargs_index in output_qparams_map: + output_arg_node = output_args[qargs_index] + in_qparams = input_qparams_map[qargs_index] + out_qparams = output_qparams_map[qargs_index] + if in_qparams == out_qparams: + continue + with submodule.graph.inserting_before(output_node): + modified = True + rescale_node = self._insert_rescale( + in_qparams, out_qparams, output_arg_node, submodule + ) + output_args[qargs_index] = rescale_node + rescale_node.update_arg(0, output_arg_node) + output_node.update_arg(0, tuple(output_args)) + # Remove qparam meta to not confuse folding pass. + del output_node.meta["input_qparams"] + return modified + + def _get_input_qparams_map(self, node: Node, idx: int): + input_qparams_meta = cast( + dict[int, QuantArgs], node.meta.get("input_qparams", None) + ) + if input_qparams_meta: + input_qparams = cast(QuantArgs, input_qparams_meta.get(idx, None)) + if not input_qparams: + raise ValueError( + f"Expected entry with key {idx} in input_qparams meta, got {input_qparams_meta}" + ) + num_inputs = len(cast(list, node.args[idx])) + + # Currently, infra only supports one set of qparams for a list of inputs + # Map all inputs to the same qparams. 
+ input_qparams_map = {i: input_qparams for i in range(num_inputs)} + return input_qparams_map + return None + + def _get_output_qparams_map(self, node: Node): + output_qparams_map: dict[int, QuantArgs] = {} + for getitem_node in node.users: + idx = cast(int, getitem_node.args[1]) + qparam = getitem_node.meta.get("output_qparams", None) + if qparam: + output_qparams_map[idx] = cast(QuantArgs, qparam[0]) + return output_qparams_map + + def _rescale_cond_submodules(self, node: Node, graph_module: GraphModule) -> bool: + modified = False + if_graph: GraphModule = cast(GraphModule, graph_module.get_submodule(node.args[1].target)) # type: ignore + else_graph: GraphModule = cast(GraphModule, graph_module.get_submodule(node.args[2].target)) # type: ignore + input_qparams_map = self._get_input_qparams_map(node, 3) + if input_qparams_map: + modified |= self._rescale_submodule_inputs(if_graph, input_qparams_map) + modified |= self._rescale_submodule_inputs(else_graph, input_qparams_map) + + output_qparams_map = self._get_output_qparams_map(node) + if output_qparams_map: + modified |= self._rescale_submodule_outputs(if_graph, output_qparams_map) + modified |= self._rescale_submodule_outputs(else_graph, output_qparams_map) + return modified + + def _rescale_while_submodules(self, node: Node, graph_module: GraphModule): + modified = False + cond_graph: GraphModule = cast(GraphModule, graph_module.get_submodule(node.args[0].target)) # type: ignore + body_graph: GraphModule = cast(GraphModule, graph_module.get_submodule(node.args[1].target)) # type: ignore + + input_qparams_map = self._get_input_qparams_map(node, 2) + if input_qparams_map: + modified |= self._rescale_submodule_inputs(cond_graph, input_qparams_map) + modified |= self._rescale_submodule_inputs(body_graph, input_qparams_map) + + output_qparams_map = self._get_output_qparams_map(node) + if output_qparams_map: + modified |= self._rescale_submodule_outputs(body_graph, output_qparams_map) + return modified + + def call(self, graph_module: GraphModule) -> PassResult: + modified = False + + for node in list(graph_module.graph.nodes): + node = cast(Node, node) + if node.op != "call_function": + continue + + if node.target == torch.ops.higher_order.cond: + modified = self._rescale_cond_submodules(node, graph_module) + if node.target == torch.ops.higher_order.while_loop: + modified = self._rescale_while_submodules(node, graph_module) + + if modified: + # Retrace the graph to update the fake tensor types + graph_module = super().call(graph_module).graph_module + graph_module.recompile() + + return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index fb5d7de5e12..27de85e5ba9 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -3,12 +3,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe from itertools import chain -from typing import Callable, cast, Dict, Iterator, Set +from typing import Callable, cast, Dict, Iterator, Set, Type import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.transforms.utils import create_constant_placeholder @@ -109,7 +109,7 @@ def included_ops() -> Iterator[EdgeOpOverload]: return chain(TableOps.unary_table_ops, TableOps.special_table_ops) -class InsertTableOpsPass(ExportPass): +class InsertTableOpsPass(ArmPass): """ For ops in self.table_ops they need to be serialized as a TOSA TABLE. This pass replaces these edge ops with a tosa._table(input: Tensor, target_str: str) where target_str == str(node.target). @@ -117,6 +117,8 @@ class InsertTableOpsPass(ExportPass): which will be used to produce the table values in operators/op_table.py. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def __init__(self, exported_program: ExportedProgram) -> None: super().__init__() self.exported_program = exported_program @@ -233,8 +235,8 @@ def call(self, graph_module: GraphModule) -> PassResult: for node in graph_module.graph.nodes: if node.op != "call_function" or node not in self.table_ops: continue - input_qparams = node.meta["input_qparams"] - output_qparams = node.meta["output_qparams"] + input_qparams = node.meta.get("input_qparams", {}) + output_qparams = node.meta.get("output_qparams", {}) if len(input_qparams) == 0 or len(output_qparams) == 0: # We only want to replace the node if it's quantized continue @@ -283,7 +285,7 @@ def call(self, graph_module: GraphModule) -> PassResult: rescale_node = create_node( graph=graph_module.graph, op_target=exir_ops.backend.tosa.RESCALE.default, - args=(table_op_node, output_qparams[0].dtype, scale, 0, 0), + args=(table_op_node, output_qparams[0].dtype, [scale], 0, 0), ) output_node = rescale_node diff --git a/backends/arm/_passes/match_arg_dtype_pass.py b/backends/arm/_passes/match_arg_dtype_pass.py index e7bf3b2d60e..f0aaa0cf5f9 100644 --- a/backends/arm/_passes/match_arg_dtype_pass.py +++ b/backends/arm/_passes/match_arg_dtype_pass.py @@ -3,7 +3,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Set, Type + import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node, get_node_arg from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -26,7 +29,7 @@ def get_largest_dtype(dtype_1, dtype_2): return dtype_1 if DTYPE_RANK[dtype_1] > DTYPE_RANK[dtype_2] else dtype_2 -class MatchArgDtypePass(ExportPass): +class MatchArgDtypePass(ArmPass): """Pass to match data types of non-condition input tensors. 
Edge dialect allows different data types for non-condition tensors, while TOSA @@ -38,6 +41,8 @@ class MatchArgDtypePass(ExportPass): """ + _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = {exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.where.self} def call(self, graph_module: torch.fx.GraphModule): diff --git a/backends/arm/_passes/match_arg_ranks_pass.py b/backends/arm/_passes/match_arg_ranks_pass.py index d6cdfacb612..d9f38c951b8 100644 --- a/backends/arm/_passes/match_arg_ranks_pass.py +++ b/backends/arm/_passes/match_arg_ranks_pass.py @@ -5,14 +5,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe -from typing import cast +from typing import cast, Set, Type + +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( create_node, get_first_fake_tensor, ) +from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops @@ -20,7 +22,7 @@ from torch.fx import GraphModule, Node -class MatchArgRanksPass(ExportPass): +class MatchArgRanksPass(ArmPass): """ For ops in 'targeted_ops', make sure that the inputs share the same rank. New dimensions are inserted from the beginning of the inputs that have a @@ -36,7 +38,9 @@ class MatchArgRanksPass(ExportPass): input2 = shape(1, 3, 1) """ - def __init__(self, exported_program): + _passes_required_after: Set[Type[ExportPass]] = set() + + def __init__(self, exported_program: ExportedProgram) -> None: super().__init__() self.exported_program = exported_program @@ -45,6 +49,7 @@ def __init__(self, exported_program): exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten.div.Tensor, + exir_ops.edge.aten.div.Tensor_mode, exir_ops.edge.aten.bitwise_right_shift.Tensor, exir_ops.edge.aten.bitwise_left_shift.Tensor, exir_ops.edge.aten.eq.Tensor, @@ -53,10 +58,13 @@ def __init__(self, exported_program): exir_ops.edge.aten.lt.Tensor, exir_ops.edge.aten.le.Tensor, exir_ops.edge.aten.pow.Tensor_Tensor, + exir_ops.edge.aten.remainder.Tensor, exir_ops.edge.aten.where.self, exir_ops.edge.aten.bitwise_and.Tensor, exir_ops.edge.aten.bitwise_xor.Tensor, exir_ops.edge.aten.bitwise_or.Tensor, + exir_ops.edge.aten.maximum.default, + exir_ops.edge.aten.minimum.default, ] def _match_op_rank(self, graph_module, node, arg, max_rank): diff --git a/backends/arm/_passes/mm_to_bmm_pass.py b/backends/arm/_passes/mm_to_bmm_pass.py index 69d8573013e..34634b99712 100644 --- a/backends/arm/_passes/mm_to_bmm_pass.py +++ b/backends/arm/_passes/mm_to_bmm_pass.py @@ -4,21 +4,24 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe + +from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( create_node, get_first_fake_tensor, - insert_q_dq_pair, ) -from executorch.backends.arm.constants import DQ_OPS, Q_OPS +from executorch.backends.arm._passes.convert_squeezes_to_view import ( + ConvertSqueezesToViewPass, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import Node -class ConvertMmToBmmPass(ExportPass): +class ConvertMmToBmmPass(ArmPass): """ This pass converts a MM node to a BMM one and turns input and output tensors from rank 2 to rank 3. 
The TOSA specification requires rank 3. The graph is @@ -28,6 +31,10 @@ class ConvertMmToBmmPass(ExportPass): 3) Squeeze output tensor to rank 2. """ + _passes_required_after: Set[Type[ExportPass]] = { + ConvertSqueezesToViewPass, + } + def call(self, graph_module: torch.fx.GraphModule): modified_graph = False graph = graph_module.graph @@ -47,7 +54,10 @@ def call(self, graph_module: torch.fx.GraphModule): with graph.inserting_before(node): unsqueeze_before = create_node( - graph, exir_ops.edge.aten.unsqueeze_copy.default, from_node=node + graph, + exir_ops.edge.aten.unsqueeze_copy.default, + from_node=node, + inherit_qparams=False, ) unsqueeze_before.args = ( input_node, # Input is node's original input @@ -55,17 +65,13 @@ def call(self, graph_module: torch.fx.GraphModule): ) node.replace_input_with(input_node, unsqueeze_before) - # If Quantized we must insert unsqueeze --> q --> dq --> node - if input_node.target in DQ_OPS: - q_params = input_node.args[1:] - insert_q_dq_pair(graph, unsqueeze_before, q_params, from_node=node) - # Replace mm node with bmm with graph.inserting_before(node): bmm_node = create_node( graph, exir_ops.edge.aten.bmm.default, from_node=node, + inherit_qparams=True, ) bmm_node.args = node.args node.replace_all_uses_with(bmm_node) @@ -77,6 +83,7 @@ def call(self, graph_module: torch.fx.GraphModule): graph, exir_ops.edge.aten.squeeze_copy.dims, from_node=node, + inherit_qparams=False, ) squeeze_after.args = ( bmm_node, @@ -88,11 +95,6 @@ def call(self, graph_module: torch.fx.GraphModule): for user in original_users: user.replace_input_with(bmm_node, squeeze_after) - # If quantized, insert mm --> q --> dq --> squeeze - if all(original_user.target in Q_OPS for original_user in original_users): - q_params = original_users[0].args[1:] - insert_q_dq_pair(graph, bmm_node, q_params, from_node=node) - modified_graph = True if modified_graph: diff --git a/backends/arm/_passes/normalize_while_initial_args_pass.py b/backends/arm/_passes/normalize_while_initial_args_pass.py new file mode 100644 index 00000000000..fde8d5dd1ad --- /dev/null +++ b/backends/arm/_passes/normalize_while_initial_args_pass.py @@ -0,0 +1,113 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import cast, Sequence, Set, Type + +import torch + +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_base import PassResult + + +class NormalizeWhileInitialArgsPass(ArmPass): + """ + Normalize ``torch.ops.higher_order.while_loop`` by moving additional_args to carried_args, + making the number of outputs equal to the number of inputs which is required by the TOSA specification. + Example: + def cond(val): + return val.sum() < 10 + + def body(val): + return (val * 2,) + while_loop(cond, body, (val,), additional_args= (buffer,)) + becomes: + def cond(val, buffer): + return val.sum() < 10 + + def body(val, buffer): + return (val * 2, buffer.clone()) + while_loop(cond, body, (val, buffer), ()) + + The clone is neccessary to avoid issues with aliasing. 
+ """ + + def __init__(self, use_exir_clone: bool) -> None: + super().__init__() + if use_exir_clone: + self.clone_op = exir_ops.edge.aten.alias_copy.default + else: + self.clone_op = torch.ops.aten.clone.default + + _passes_required_after: Set[Type[ExportPass]] = set() + + def _connect_to_output( + self, body_module: GraphModule, placeholders: Sequence[Node] + ) -> list[Node]: + if not placeholders: + return [] + + cloned_placeholders = [] + with body_module.graph.inserting_after(placeholders[-1]): + for placeholder in placeholders: + clone = body_module.graph.create_node( + "call_function", + self.clone_op, + (placeholder,), + ) + cloned_placeholders.append(clone) + clone.meta = placeholder.meta + output_node = body_module.graph.output_node() + output_values = output_node.args[0] + if not isinstance(output_values, tuple): + raise RuntimeError("Output of a while should be a tuple.") + + output_node.update_arg(0, output_values + tuple(cloned_placeholders)) + body_module.recompile() + return list(cloned_placeholders) + + def _normalize_node(self, graph_module: GraphModule, node: Node) -> bool: + additional_inputs = list(cast(Sequence[Node], node.args[3])) + + if not additional_inputs: + return False + + carried_inputs = list(cast(Sequence[Node], node.args[2])) + new_carried = tuple(carried_inputs + additional_inputs) + node.update_arg(2, new_carried) + node.update_arg(3, ()) + + body_module_name = str(cast(Node, node.args[1]).target) + body_module = cast(GraphModule, graph_module.get_submodule(body_module_name)) # type: ignore + placeholders = [n for n in body_module.graph.nodes if n.op == "placeholder"] + num_inputs = len(placeholders) + old_num_inputs = len(carried_inputs) + if num_inputs != len(new_carried): + raise RuntimeError( + f"Length of loop placeholders {placeholders} is not equal length of carried inputs {new_carried}" + ) + + missing_placeholders = placeholders[old_num_inputs:] + self._connect_to_output(body_module, missing_placeholders) + + return True + + def call(self, graph_module: GraphModule) -> PassResult: + modified = False + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + if node.target != torch.ops.higher_order.while_loop: + continue + modified |= self._normalize_node(graph_module, node) + + if modified: + graph_module.recompile() + + return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/promote_bool_operands_pass.py b/backends/arm/_passes/promote_bool_operands_pass.py new file mode 100644 index 00000000000..8c45a808cb5 --- /dev/null +++ b/backends/arm/_passes/promote_bool_operands_pass.py @@ -0,0 +1,88 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# The TOSA BITWISE_AND, BITWISE_OR, and BITWISE_XOR don't handle bool inputs. +# When a targeted op receives boolean tensors, we promote them to an integer type before +# invocation and cast the result back to the expected dtype afterwards. 
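# A minimal, self-contained sketch of the dtype choices PromoteBoolOperandsPass
# makes (see the file below), reproduced with plain torch tensors so the
# promote/compute/cast-back flow is visible. `promoted_dtype_for` is a
# hypothetical helper that only mirrors the pass's selection logic; it is not
# part of the pass itself.
import torch

def promoted_dtype_for(op_is_mul: bool, dtypes: list) -> torch.dtype:
    # Prefer the first non-bool operand dtype; otherwise use the pass defaults.
    for dt in dtypes:
        if dt != torch.bool:
            return dt
    return torch.int32 if op_is_mul else torch.int8

a = torch.tensor([True, False, True])
b = torch.tensor([True, True, False])
c = torch.tensor([1, 2, 3], dtype=torch.int32)

# bitwise_and with only bool inputs: compute in int8, cast the result back to bool.
dt = promoted_dtype_for(op_is_mul=False, dtypes=[a.dtype, b.dtype])   # torch.int8
res = torch.bitwise_and(a.to(dt), b.to(dt)).to(torch.bool)            # [True, False, False]

# mul with a bool/int32 pair: only the bool operand is promoted, result stays int32.
dt = promoted_dtype_for(op_is_mul=True, dtypes=[a.dtype, c.dtype])    # torch.int32
prod = torch.mul(a.to(dt), c)                                          # [1, 0, 3]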
+ +from typing import Set, Type + +import torch + +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + + +class PromoteBoolOperandsPass(ArmPass): + """Promote boolean operands to the appropriate integer dtype for unsupported ops.""" + + _passes_required_after: Set[Type[ExportPass]] = set() + + targeted_ops = { + exir_ops.edge.aten.bitwise_and.Tensor, + exir_ops.edge.aten.bitwise_or.Tensor, + exir_ops.edge.aten.bitwise_xor.Tensor, + exir_ops.edge.aten.mul.Tensor, + } + + def call_operator(self, op, args, kwargs, meta): + if op not in self.targeted_ops: + return super().call_operator(op, args, kwargs, meta) + + original_dtypes = [arg.data.dtype for arg in args] + if torch.bool not in original_dtypes: + return super().call_operator(op, args, kwargs, meta) + + # select the first non-bool dtype, or None if all bool + promoted_dtype = next((dt for dt in original_dtypes if dt != torch.bool), None) + + # if we don't have a dtype specified by the op, promote to default choice for the op + if promoted_dtype is None: + if op == exir_ops.edge.aten.mul.Tensor: + # mul as int32 + promoted_dtype = torch.int32 + else: + # bitwise ops can be int8 + promoted_dtype = torch.int8 + + target_dtypes = [] + for dt in original_dtypes: + if dt == torch.bool: + target_dtypes.append(promoted_dtype) + else: + target_dtypes.append(dt) + + new_args = [] + for arg, original_dtype, target_dtype in zip( + args, original_dtypes, target_dtypes + ): + if original_dtype == target_dtype: + new_args.append(arg) + else: + new_args.append( + super().call_operator( + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + (arg,), + {"dtype": target_dtype}, + meta, + ) + ) + + output = super().call_operator( + op, + tuple(new_args), + kwargs, + meta, + ) + + if all(dtype == torch.bool for dtype in original_dtypes): + output = super().call_operator( + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + (output,), + {"dtype": torch.bool}, + meta, + ) + return output diff --git a/backends/arm/_passes/remove_getitem_pass.py b/backends/arm/_passes/remove_getitem_pass.py new file mode 100644 index 00000000000..3ce157d3fd8 --- /dev/null +++ b/backends/arm/_passes/remove_getitem_pass.py @@ -0,0 +1,14 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Set, Type + +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.transforms import remove_getitem_op +from executorch.exir.pass_base import ExportPass + + +class RemoveGetItemPass(ArmPass, remove_getitem_op.RemoveGetItemPass): + _passes_required_after: Set[Type[ExportPass]] = set() diff --git a/backends/arm/_passes/remove_graph_asserts_pass.py b/backends/arm/_passes/remove_graph_asserts_pass.py new file mode 100644 index 00000000000..a462c1182ee --- /dev/null +++ b/backends/arm/_passes/remove_graph_asserts_pass.py @@ -0,0 +1,18 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
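# The passes above only declare `_passes_required_after`; how that metadata is
# consumed is not shown in this diff. The validator below is one plausible,
# purely illustrative use -- a sketch, not the actual ArmPass machinery.
def check_pass_order(pipeline: list) -> None:
    """Raise if a declared follow-up pass is missing or scheduled too early."""
    position = {type(p): i for i, p in enumerate(pipeline)}
    for i, p in enumerate(pipeline):
        for required in getattr(p, "_passes_required_after", set()):
            if required not in position:
                raise RuntimeError(f"{required.__name__} missing from the pipeline")
            if position[required] <= i:
                raise RuntimeError(
                    f"{required.__name__} must be scheduled after {type(p).__name__}"
                )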
+ +from typing import Set, Type + +from executorch.backends.arm._passes.arm_pass import ArmPass + +from executorch.backends.arm._passes.convert_int64_const_ops_to_int32 import ( + ConvertInt64ConstOpsToInt32Pass, +) +from executorch.exir.pass_base import ExportPass +from executorch.exir.passes import remove_graph_asserts_pass + + +class RemoveGraphAssertsPass(remove_graph_asserts_pass.RemoveGraphAssertsPass, ArmPass): + _passes_required_after: Set[Type[ExportPass]] = {ConvertInt64ConstOpsToInt32Pass} diff --git a/backends/arm/_passes/remove_noop_pass.py b/backends/arm/_passes/remove_noop_pass.py index 623517aac59..8ac808809ef 100644 --- a/backends/arm/_passes/remove_noop_pass.py +++ b/backends/arm/_passes/remove_noop_pass.py @@ -4,9 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe import logging +from typing import Set, Type + +from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -14,13 +16,16 @@ logger = logging.getLogger(__name__) -class RemoveNoopPass(ExportPass): +class RemoveNoopPass(ArmPass): """Remove no-ops from graph_module""" + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op not in ( exir_ops.edge.dim_order_ops._clone_dim_order.default, exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + exir_ops.edge.aten.copy.default, ): return super().call_operator(op, args, kwargs, meta) @@ -30,4 +35,6 @@ def call_operator(self, op, args, kwargs, meta): if input_dtype != output_dtype: return super().call_operator(op, args, kwargs, meta) + if op == exir_ops.edge.aten.copy.default: + return args[1] return args[0] diff --git a/backends/arm/_passes/replace_inf_and_limit_values_pass.py b/backends/arm/_passes/replace_inf_and_limit_values_pass.py new file mode 100644 index 00000000000..f8dff5701da --- /dev/null +++ b/backends/arm/_passes/replace_inf_and_limit_values_pass.py @@ -0,0 +1,49 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This pass is based on backends/qualcomm/_passes/replace_inf_values.py +# with some modification to replaced inf values. + +from typing import Set, Type + +import torch +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.exir.pass_base import ExportPass, PassResult + + +class ReplaceInfAndLimitValuesPass(ArmPass): + """ + Rewrites +inf/-inf and floating-point limit values (e.g., torch.finfo(...).min/max) + to quantization-friendly values (±255 by default), improving quantizer stability + (notably for attention mask paths). 
+ """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + def call(self, graph_module: torch.fx.GraphModule): + modified = False + for buf_name, tensor in graph_module.named_buffers(): + if tensor.is_floating_point(): + modified = True + # 255 here is mainly for attention_mask in Llama for reasonable quant scale + tensor[tensor == float("inf")] = 255 + tensor[tensor == float("-inf")] = -255 + setattr(graph_module, buf_name, tensor) + + for node in graph_module.graph.nodes: + arg_list = list(node.args) + for index, arg in enumerate(arg_list): + if arg == float("-inf") or arg == torch.finfo(torch.float32).min: + modified = True + arg_list[index] = -255.0 + elif arg == float("inf") or arg == torch.finfo(torch.float32).max: + modified = True + arg_list[index] = +255.0 + node.args = tuple(arg_list) + + if modified: + graph_module.recompile() + return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/replace_inf_values_pass.py b/backends/arm/_passes/replace_inf_values_pass.py deleted file mode 100644 index 8c721eda3d8..00000000000 --- a/backends/arm/_passes/replace_inf_values_pass.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# This pass is based on backends/qualcomm/_passes/replace_inf_values.py -# with some modification to replaced inf values. - -import torch -from executorch.exir.pass_base import ExportPass, PassResult - - -class ReplaceInfValues(ExportPass): - """ - Due to limitation in Quantizer, we need to change inf/-inf to more quantizable values. - """ - - def __init__(self): - super(ReplaceInfValues, self).__init__() - - def call(self, graph_module: torch.fx.GraphModule): - modified = False - for buf_name, tensor in graph_module.named_buffers(): - if tensor.is_floating_point(): - modified = True - # 255 here is mainly for attention_mask in Llama for reasonable quant scale - tensor[tensor == float("inf")] = 255 - tensor[tensor == float("-inf")] = -255 - setattr(graph_module, buf_name, tensor) - - for node in graph_module.graph.nodes: - arg_list = list(node.args) - for index, arg in enumerate(arg_list): - if arg == float("-inf"): - modified = True - arg_list[index] = -255 - elif arg == float("inf"): - modified = True - arg_list[index] = +255 - node.args = tuple(arg_list) - - if modified: - graph_module.recompile() - return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/replace_scalar_with_tensor_pass.py b/backends/arm/_passes/replace_scalar_with_tensor_pass.py index 249eb9ffd41..9f6b672c4fa 100644 --- a/backends/arm/_passes/replace_scalar_with_tensor_pass.py +++ b/backends/arm/_passes/replace_scalar_with_tensor_pass.py @@ -3,18 +3,22 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
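# Follow-up sketch of the scalar rewrite rule in ReplaceInfAndLimitValuesPass
# above: +/-inf and the float32 finfo limits map to +/-255.0 so the quantizer
# sees a bounded range. Standalone reproduction, not the pass itself.
import torch

def clamp_limit_value(arg):
    if arg == float("-inf") or arg == torch.finfo(torch.float32).min:
        return -255.0
    if arg == float("inf") or arg == torch.finfo(torch.float32).max:
        return 255.0
    return arg

assert clamp_limit_value(float("-inf")) == -255.0
assert clamp_limit_value(torch.finfo(torch.float32).max) == 255.0
assert clamp_limit_value(1.5) == 1.5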
-# pyre-unsafe - -from typing import Dict, Union +from typing import Dict, Set, Type, Union import torch +from executorch.backends.arm._passes.insert_table_ops import TableOps + +from executorch.backends.arm.tosa.specification import get_context_spec from executorch.backends.transforms.replace_scalar_with_tensor import ( ReplaceScalarWithTensorArgPass, ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass + +from .arm_pass import ArmPass # Operators that are included for both TOSA profiles @@ -37,6 +41,7 @@ exir_ops.edge.aten.bitwise_and.Scalar: exir_ops.edge.aten.bitwise_and.Tensor, exir_ops.edge.aten.bitwise_or.Scalar: exir_ops.edge.aten.bitwise_or.Tensor, exir_ops.edge.aten.bitwise_xor.Scalar: exir_ops.edge.aten.bitwise_xor.Tensor, + exir_ops.edge.aten.remainder.Scalar: exir_ops.edge.aten.remainder.Tensor, torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor, torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor, torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor, @@ -52,21 +57,63 @@ torch.ops.aten.bitwise_and.Scalar: torch.ops.aten.bitwise_and.Tensor, torch.ops.aten.bitwise_or.Scalar: torch.ops.aten.bitwise_or.Tensor, torch.ops.aten.bitwise_xor.Scalar: torch.ops.aten.bitwise_xor.Tensor, + torch.ops.aten.remainder.Scalar: torch.ops.aten.remainder.Tensor, } +_fp_profile_ops: Dict[ + Union[EdgeOpOverload, torch._ops.OpOverload], + Union[EdgeOpOverload, torch._ops.OpOverload], +] = _common_ops | { + exir_ops.edge.aten.pow.Tensor_Scalar: exir_ops.edge.aten.pow.Tensor_Tensor, +} -class ReplaceScalarWithTensorArgPassTOSAMI(ReplaceScalarWithTensorArgPass): - scalar_to_tensor_ops = _common_ops | { - exir_ops.edge.aten.pow.Tensor_Scalar: exir_ops.edge.aten.pow.Tensor_Tensor, - torch.ops.aten.pow.Tensor_Scalar: torch.ops.aten.pow.Tensor_Tensor, - } +_int_profile_ops: Dict[ + Union[EdgeOpOverload, torch._ops.OpOverload], + Union[EdgeOpOverload, torch._ops.OpOverload], +] = _common_ops - def __init__(self): - super().__init__(self.scalar_to_tensor_ops) +_all_ops: Dict[ + Union[EdgeOpOverload, torch._ops.OpOverload], + Union[EdgeOpOverload, torch._ops.OpOverload], +] = ( + _fp_profile_ops | _int_profile_ops +) -class ReplaceScalarWithTensorArgPassTOSABI(ReplaceScalarWithTensorArgPass): - scalar_to_tensor_ops = _common_ops +class ReplaceScalarWithTensorByProfilePass(ReplaceScalarWithTensorArgPass, ArmPass): + """Profile-aware scalar-to-tensor replacement pass for binary ops.""" + + _passes_required_after: Set[Type[ExportPass]] = set() def __init__(self): - super().__init__(self.scalar_to_tensor_ops) + # Initialize base (ReplaceScalarWithTensorArgPass) with the full + # superset which will make the superclass handle ops in _all_ops. + # Actual selection is done per-call in call_operator. + super().__init__(_all_ops) + + def call_operator(self, op, args, kwargs, meta): + tosa_spec = get_context_spec() + + included_ops = {} + if tosa_spec.support_integer(): + included_ops |= _int_profile_ops + if tosa_spec.support_float(): + included_ops |= _fp_profile_ops + + if included_ops == {}: + raise ValueError("Profile must support at least INT or FP") + + if op in TableOps.included_ops(): + # Do not handle quantized table ops; forward unchanged. + input_qparams = meta.data.get("input_qparams", {}) + output_qparams = meta.data.get("input_qparams", {}) + if len(input_qparams) > 0 and len(output_qparams) > 0: + # Do not handle; forward unchanged. 
+ return ExportPass.call_operator(self, op, args, kwargs, meta) + + if op in included_ops: + # Include this op based on the current profile. + return super().call_operator(op, args, kwargs, meta) + else: + # Do not handle; forward unchanged. + return ExportPass.call_operator(self, op, args, kwargs, meta) diff --git a/backends/arm/_passes/rewrite_conv_pass.py b/backends/arm/_passes/rewrite_conv_pass.py new file mode 100644 index 00000000000..7582647eabb --- /dev/null +++ b/backends/arm/_passes/rewrite_conv_pass.py @@ -0,0 +1,339 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import itertools +from typing import Set, Type + +import torch +from executorch.backends.arm._passes import ArmPass + +from executorch.backends.arm._passes.arm_pass_utils import ( + create_node, + expand_around_channel, + get_first_fake_tensor, + get_param_tensor, + is_buffer, + is_param, +) +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, + get_output_qparams, +) +from executorch.backends.arm.constants import HWCM_ORDER, NHWC_INVERSE_ORDER +from executorch.backends.arm.tosa.mapping import TosaSpecialDtype +from executorch.backends.transforms.utils import create_constant_placeholder +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch.export.graph_signature import InputKind + + +class RewriteConvPass(ArmPass): + """Rewrites aten.convolution to tosa.CONV2D or tosa.DEPTHWISE_CONV2D.""" + + def __init__(self, exported_program: torch.export.ExportedProgram): + super().__init__() + self.exported_program = exported_program + + _passes_required_after: Set[Type[ExportPass]] = set() + + # torch.nn.Conv2d does not require the result of + # `(input + 2 * pad - dilation * (weight - 1) - 1) / stride` + # to be an integer, but tosa currently strictly require this property. + # This function adjusts the pad value to meet the requirement. + def _adjust_pad_if_needed( + self, input_len: int, input_weight: int, stride: int, pad: int, dilation: int + ) -> int: + """Adjust padding to satisfy TOSA's integer output-size requirement. + + Torch ``Conv2d`` does not require the result of + ``(input + 2 * pad - dilation * (weight - 1) - 1) / stride`` to be an + integer, but TOSA does. This helper reduces the provided padding so + that the expression becomes divisible by ``stride``. + + Args: + input_size (int): Spatial input size along the dimension (H or W). + input_weight (int): Kernel size along the same dimension. + stride (int): Stride along the same dimension. + pad (int): Padding value to adjust (bottom or right after duplication). + dilation (int): Dilation along the same dimension. + + Returns: + int: Adjusted padding value that yields an integer output size. + + Raises: + RuntimeError: If the required adjustment exceeds the provided + padding, which should be handled by the ``SizeAdjustInputPass`` + pass instead. + + """ + mod_remainder = ( + input_len + 2 * pad - dilation * (input_weight - 1) - 1 + ) % stride + + # No need to adjust + if mod_remainder == 0: + return pad + + if mod_remainder > pad: + raise RuntimeError( + "This case should be handled by the SizeAdjustInputPass, is it enabled?" 
+ ) + return pad - mod_remainder + + def _is_depthwise_conv2d(self, node: torch.fx.Node) -> bool: + if ( + node.op != "call_function" + or node.target != exir_ops.edge.aten.convolution.default + ): + return False + input_tensor = get_first_fake_tensor(node.all_input_nodes[0]) + if len(input_tensor.shape) != 4: + return False + groups = node.args[-1] + in_channels = input_tensor.shape[1] + out_channels = get_first_fake_tensor(node).shape[1] + return (in_channels == groups) and (out_channels % in_channels) == 0 + + def _is_conv3d(self, rank, groups) -> bool: + if rank == 5: + # A Conv3D is considered depthwise if Group == InChannels and + # Group * N == OutChannels, where N is a possitive integer. + # Currently we do not support depthwise or grouped conv3d. + # @TODO Add grouped/depthwise conv3d support or reject in partitioner. + if groups != 1: + raise RuntimeError( + "CONV3D with groups != 1 is not supported in the Arm backend." + ) + return True + return False + + def _reshape_weights(self, weight_node: torch.fx.Node, in_channels: int) -> None: + """Reshape the weights for depthwise convolution such that when serialized to TOSA, + the weights are in the format [H, W, in_channels, m_length] where + m_length is the number of output channels per input channel. + """ + weight_tensor = get_param_tensor(self.exported_program, weight_node) # type: ignore[arg-type] + if weight_tensor is None: + raise RuntimeError( + f"Weight node {weight_node.name} is not a parameter or buffer" + ) + + reshaped_weight_tensor = ( + weight_tensor.permute(HWCM_ORDER) + .reshape( + weight_tensor.shape[2], + weight_tensor.shape[3], + in_channels, + weight_tensor.shape[0] // in_channels, + ) + .permute(NHWC_INVERSE_ORDER) + ) + + if is_buffer(self.exported_program, weight_node): + param_name = self.exported_program.graph_signature.inputs_to_buffers[ + weight_node.name + ] + reshaped_weight_tensor = torch.nn.Buffer(reshaped_weight_tensor) + elif is_param(self.exported_program, weight_node): + param_name = self.exported_program.graph_signature.inputs_to_parameters[ + weight_node.name + ] + reshaped_weight_tensor = torch.nn.Parameter( + reshaped_weight_tensor, requires_grad=False + ) + else: + raise RuntimeError( + f"Weight node {weight_node.name} is neither a parameter nor a buffer" + ) + + self.exported_program.state_dict[param_name] = reshaped_weight_tensor + weight_node.meta["val"] = weight_node.meta["val"].reshape( + weight_tensor.shape[2], + weight_tensor.shape[0] // in_channels, + weight_tensor.shape[3], + in_channels, + ) + + def _add_bias( + self, + graph_module: torch.fx.GraphModule, + node: torch.fx.Node, + weight_node: torch.fx.Node, + ) -> torch.fx.Node: + output_channels = get_first_fake_tensor(node).shape[1] + # add a node containging zeros if quantized, use int32, otherwise use float32 + if "output_qparams" in node.meta and len(node.meta["output_qparams"]) > 0: + bias_data = torch.zeros(size=(output_channels,), dtype=torch.int32) + else: + bias_data = torch.zeros(size=(output_channels,), dtype=torch.float32) + + with graph_module.graph.inserting_after(weight_node): + bias_node = create_constant_placeholder( + self.exported_program, + graph=graph_module.graph, + kind=InputKind.PARAMETER, + data=bias_data, + persistent_buffer=True, + name=f"{node.name}_bias", + ) + if node.all_input_nodes[0].meta["val"].dtype == torch.int16: + bias_node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48 + node.update_arg(2, bias_node) + return bias_node + + def insert_output_rescale(self, graph_module, node): + 
input_qparams = get_input_qparams(node) + output_qparams = get_output_qparams(node)[0] + weight_qparams = input_qparams[1] + input_qparams = input_qparams[0] + is_per_channel = weight_qparams.per_channel + if is_per_channel: + weight_scale = weight_qparams.get_scale_per_channel() + else: + weight_scale = [weight_qparams.get_scale_per_tensor()] + input_scale = input_qparams.get_scale_per_tensor() + post_conv2d_scale = [ + (inp * w) / out + for inp, w, out in zip( + itertools.cycle([input_scale]), + weight_scale, + itertools.cycle([output_qparams.get_scale_per_tensor()]), + ) + ] + with graph_module.graph.inserting_after(node): + rescale_node = create_node( + graph=graph_module.graph, + op_target=exir_ops.backend.tosa.RESCALE.default, + args=( + node, + output_qparams.dtype, + post_conv2d_scale, + 0, + output_qparams.get_zp_per_tensor(), + ), + from_node=node, + ) + return rescale_node + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 + modified = False + for node in graph_module.graph.nodes: + if ( + node.op != "call_function" + or node.target != exir_ops.edge.aten.convolution.default + ): + continue + + modified = True + + ( + x, + weight, + bias, + stride, + pad, + dilation, + transposed, + output_pad, + group, + ) = node.args + + input_fake_tensor = get_first_fake_tensor(x) + weight_fake_tensor = get_first_fake_tensor(weight) + input_shape = input_fake_tensor.shape + weight_shape = weight_fake_tensor.shape + spatial_rank = len(input_shape) - 2 + stride_list = expand_around_channel(stride, spatial_rank) + dilation_list = expand_around_channel(dilation, spatial_rank) + pad_list = expand_around_channel(pad, spatial_rank) + + pad_attr: list[int] = [] + for value in pad_list: + pad_attr.extend([value, value]) # duplicate pad before/after per axis + + for axis_index in range(spatial_rank): + pad_index = axis_index * 2 + 1 # adjust trailing pad entry + pad_attr[pad_index] = self._adjust_pad_if_needed( + input_shape[axis_index + 2], + weight_shape[axis_index + 2], + stride_list[axis_index], + pad_attr[pad_index], + dilation_list[axis_index], + ) + + stride = tuple(stride_list) + dilation = tuple(dilation_list) + pad = pad_attr + + has_bias = bias is not None + if not has_bias: + bias = self._add_bias(graph_module, node, weight) + + if self._is_conv3d(len(input_shape), group): + target_op = exir_ops.backend.tosa.CONV3D.default + elif self._is_depthwise_conv2d(node): + target_op = exir_ops.backend.tosa.DEPTHWISE_CONV2D.default + # If there are any TOSA.DEPTHWISE_CONV2D nodes using the weights, we've already reshaped them. 
+ if all(user.target != target_op for user in weight.users): + self._reshape_weights(weight, input_fake_tensor.shape[1]) + weight_fake_tensor = get_first_fake_tensor(weight) + else: + target_op = exir_ops.backend.tosa.CONV2D.default + + conv_args = ( + x, + weight, + bias, + stride, + pad, + dilation, + transposed, + output_pad, + group, + ) + + with graph_module.graph.inserting_after(node): + tosa_op = create_node( + graph=graph_module.graph, + op_target=target_op, + args=conv_args, + from_node=node, + inherit_qparams=True, + ) + bias_fake_tensor = get_first_fake_tensor(bias) if bias else None + tosa_node_fake_tensor = target_op( + input_fake_tensor, + weight_fake_tensor, + bias_fake_tensor, + *conv_args[3:], + ) + + if ( + tosa_node_fake_tensor.dtype == torch.int32 + and input_fake_tensor.dtype == torch.int8 + ): + output_rescale = self.insert_output_rescale(graph_module, tosa_op) + node.replace_all_uses_with(output_rescale) + elif ( + tosa_node_fake_tensor.dtype == torch.int32 + and input_fake_tensor.dtype == torch.int16 + ): + has_bias = len(node.meta["input_qparams"]) > 2 + if not has_bias: + output_rescale = self.insert_output_rescale(graph_module, tosa_op) + node.replace_all_uses_with(output_rescale) + else: + node.replace_all_uses_with(tosa_op) + tosa_op.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48 + else: + node.replace_all_uses_with(tosa_op) + + graph_module.graph.erase_node(node) + + if modified: + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/rewrite_matmul.py b/backends/arm/_passes/rewrite_matmul.py new file mode 100644 index 00000000000..298cfd17f0c --- /dev/null +++ b/backends/arm/_passes/rewrite_matmul.py @@ -0,0 +1,98 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
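# Worked example of the padding rule enforced by RewriteConvPass._adjust_pad_if_needed
# above: the trailing (bottom/right) pad is reduced until
# (input + 2*pad - dilation*(kernel - 1) - 1) is divisible by the stride.
# Standalone sketch, not the pass method itself.
def adjust_pad(input_len: int, kernel: int, stride: int, pad: int, dilation: int) -> int:
    remainder = (input_len + 2 * pad - dilation * (kernel - 1) - 1) % stride
    if remainder == 0:
        return pad
    if remainder > pad:
        # Left to SizeAdjustInputPass, which slices the input instead.
        raise RuntimeError("input must be sliced first")
    return pad - remainder

# 224x224 input, 3x3 kernel, stride 2, pad 1, dilation 1:
# (224 + 2 - 2 - 1) % 2 == 1, so the trailing pad drops from 1 to 0.
assert adjust_pad(224, 3, 2, 1, 1) == 0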
+ +from typing import Set, Type + +import torch +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import ( + create_node, + get_first_fake_tensor, +) +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, + get_output_qparams, +) +from executorch.backends.arm.tosa.mapping import TosaSpecialDtype +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +class RewriteMatmulPass(ArmPass): + """Rewrites aten.bmm to tosa.MATMUL and inserts a tosa.RESCALE op if needed.""" + + _passes_required_after: Set[Type[ExportPass]] = set() + + def _insert_output_rescale(self, graph_module, node, tosa_matmul_node, dtype): + input_qparams = get_input_qparams(node) + output_qparams = get_output_qparams(node)[0] + scale = ( + input_qparams[0].get_scale_per_tensor() + * input_qparams[1].get_scale_per_tensor() + ) / output_qparams.get_scale_per_tensor() + + with graph_module.graph.inserting_after(tosa_matmul_node): + # If the input is int8, we need to cast the output to int32 + rescale_node = create_node( + graph_module.graph, + op_target=exir_ops.backend.tosa.RESCALE.default, + from_node=tosa_matmul_node, + ) + tosa_matmul_node.replace_all_uses_with(rescale_node) + rescale_node.args = ( + tosa_matmul_node, + dtype, + [scale], + 0, + output_qparams.get_zp_per_tensor(), + ) + + def call(self, graph_module): + modified = False + for node in graph_module.graph.nodes: + if ( + node.op != "call_function" + or node.target != exir_ops.edge.aten.bmm.default + ): + continue + modified = True + + x1, x2 = node.args + tosa_matmul_target = exir_ops.backend.tosa.MATMUL.default + with graph_module.graph.inserting_before(node): + tosa_matmul_node = create_node( + graph_module.graph, + op_target=tosa_matmul_target, + args=(x1, x2), + kwargs={}, + from_node=node, + inherit_qparams=True, + ) + node.replace_all_uses_with(tosa_matmul_node) + graph_module.graph.erase_node(node) + + x1_fake_tensor = get_first_fake_tensor(x1) + x2_fake_tensor = get_first_fake_tensor(x2) + output_fake_tensor = tosa_matmul_target(x1_fake_tensor, x2_fake_tensor) + node_output_fake_tensor = get_first_fake_tensor(node) + if ( + output_fake_tensor.dtype == torch.int32 + and node_output_fake_tensor.dtype in (torch.int8, torch.int16) + ): + self._insert_output_rescale( + graph_module, + node, + tosa_matmul_node, + dtype=node_output_fake_tensor.dtype, + ) + if x1_fake_tensor.dtype == torch.int16: + tosa_matmul_node.meta[TosaSpecialDtype.meta_key()] = ( + TosaSpecialDtype.INT48 + ) + + if modified: + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/rewrite_upsample.py b/backends/arm/_passes/rewrite_upsample.py new file mode 100644 index 00000000000..cff241d33cf --- /dev/null +++ b/backends/arm/_passes/rewrite_upsample.py @@ -0,0 +1,93 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
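# Follow-up sketch of the rescale factor computed in _insert_output_rescale
# above: an int8 x int8 matmul accumulates in int32, and a single RESCALE maps
# the accumulator back onto the output grid (zero point added by RESCALE).
# The scale values are made-up example numbers.
s_x1, s_x2, s_out = 0.02, 0.05, 0.004      # per-tensor quant scales (assumed)
rescale_scale = (s_x1 * s_x2) / s_out      # folded into one RESCALE, here 0.25
assert abs(rescale_scale - 0.25) < 1e-9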
+ +from typing import Set, Type + +import torch +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import ( + create_node, + get_first_fake_tensor, +) +from executorch.backends.arm.tosa.mapping import TosaSpecialDtype +from executorch.backends.arm.tosa.utils import get_resize_parameters +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +class RewriteUpsamplePass(ArmPass): + """Rewrite upsample2d nodes to TOSA.RESIZE nodes.""" + + targeted_ops = ( + exir_ops.edge.aten.upsample_nearest2d.vec, + exir_ops.edge.aten.upsample_bilinear2d.vec, + ) + + _passes_required_after: Set[Type[ExportPass]] = set() + + def call(self, graph_module): + modified = False + for node in graph_module.graph.nodes: + if node.op != "call_function" or node.target not in self.targeted_ops: + continue + modified = True + + if node.target == exir_ops.edge.aten.upsample_bilinear2d.vec: + x, output_size, align_corners, scale_factors = node.args + resize_mode = "bilinear" + else: + x, output_size, scale_factors = node.args + align_corners = False + resize_mode = "nearest" + + with graph_module.graph.inserting_before(node): + tosa_resize_node = create_node( + graph_module.graph, + op_target=exir_ops.backend.tosa.RESIZE.default, + args=(x, output_size, align_corners, scale_factors), + kwargs={"resize_mode": resize_mode}, + from_node=node, + inherit_qparams=True, + ) + node.replace_all_uses_with(tosa_resize_node) + graph_module.graph.erase_node(node) + input_dtype = get_first_fake_tensor(x).dtype + if ( + input_dtype == torch.int8 or input_dtype == torch.int16 + ) and resize_mode == "bilinear": + input_size = get_first_fake_tensor(x).shape + input_size_xy = input_size[2:] + output_size = get_first_fake_tensor(node).shape + output_size_xy = output_size[2:] + scale_n_yx, _, _, _ = get_resize_parameters( + input_size_xy=input_size_xy, + output_size_xy=output_size_xy, + resize_mode=1, + align_corners=align_corners, + ) + output_dtype = get_first_fake_tensor(node).dtype + output_scale = float(1 / (scale_n_yx[0] * scale_n_yx[1])) + with graph_module.graph.inserting_after(tosa_resize_node): + rescale_node = create_node( + graph_module.graph, + exir_ops.backend.tosa.RESCALE.default, + ) + tosa_resize_node.replace_all_uses_with(rescale_node) + if input_dtype == torch.int16: + tosa_resize_node.meta[TosaSpecialDtype.meta_key()] = ( + TosaSpecialDtype.INT48 + ) + + rescale_node.args = ( + tosa_resize_node, + output_dtype, + [output_scale], + 0, # zero point + 0, # zero point + ) + + if modified: + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/scalars_to_attribute_pass.py b/backends/arm/_passes/scalars_to_attribute_pass.py index 89468bff1ff..ddef9c75213 100644 --- a/backends/arm/_passes/scalars_to_attribute_pass.py +++ b/backends/arm/_passes/scalars_to_attribute_pass.py @@ -4,24 +4,27 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
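# Follow-up note on the quantized bilinear branch in RewriteUpsamplePass above:
# TOSA's integer BILINEAR RESIZE leaves outputs scaled up by the scale
# numerators, so the appended RESCALE divides by scale_n_y * scale_n_x. The
# numerators below are placeholder values, not ones computed by
# get_resize_parameters.
scale_n_yx = (4, 4)                                    # assumed (scale_n_y, scale_n_x)
output_scale = float(1 / (scale_n_yx[0] * scale_n_yx[1]))
assert output_scale == 0.0625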
-# pyre-unsafe -from typing import cast, Union +from typing import cast, Set, Type, Union import torch +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import GraphModule, Node from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix -class ScalarsToAttributePass(ExportPass): +class ScalarsToAttributePass(ArmPass): """ For ops in 'targeted_ops', convert inputs that are scalar values to attribute Nodes that output the same value. """ + _passes_required_after: Set[Type[ExportPass]] = {MatchArgRanksPass} + targeted_ops = [ torch.ops.aten.add.Tensor, torch.ops.aten.add_.Tensor, @@ -46,7 +49,7 @@ def call(self, graph_module: GraphModule) -> PassResult: shape = get_first_fake_tensor(arg).shape biggest_rank = max(biggest_rank, len(shape)) - new_args = [] + new_args: list[Node | int] = [] for arg in n.args: if isinstance(arg, Node): new_args.append(arg) @@ -54,7 +57,7 @@ def call(self, graph_module: GraphModule) -> PassResult: if isinstance(arg, int) and not torch.is_floating_point( get_first_fake_tensor(n) ): - new_args.append(arg) # type: ignore[arg-type] + new_args.append(arg) continue prefix = "_tensor_constant_" diff --git a/backends/arm/_passes/size_adjust_input_pass.py b/backends/arm/_passes/size_adjust_input_pass.py index e87d65c450f..642a2499deb 100644 --- a/backends/arm/_passes/size_adjust_input_pass.py +++ b/backends/arm/_passes/size_adjust_input_pass.py @@ -3,12 +3,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
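# A small sketch of what ScalarsToAttributePass above achieves for a targeted
# op: a Python scalar operand is materialised as a 0-d tensor stored on the
# module (here via register_buffer; the real pass uses
# get_new_attr_name_with_prefix), so later passes such as MatchArgRanksPass
# only ever see tensor arguments. Names are illustrative.
import torch

m = torch.nn.Module()
m.register_buffer("_tensor_constant_0", torch.tensor(2.0))    # was the literal 2.0

x = torch.ones(3)
before = x + 2.0                                              # scalar operand in the graph
after = torch.ops.aten.add.Tensor(x, m._tensor_constant_0)    # tensor operand after the pass
assert torch.equal(before, after)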
-# pyre-unsafe -from typing import cast, TypeAlias +from typing import cast, Sequence, Set, Type, TypeAlias import torch.fx -from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import ( + create_node, + expand_around_channel, +) +from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -38,19 +42,22 @@ def pooling_remainder(input_size, pad, kernel_size, stride) -> int: return (input_size + 2 * pad - kernel_size) % stride -def get_slices_conv2d(conv_node: torch.fx.Node) -> Slices: +def get_slices_convolution(conv_node: torch.fx.Node) -> Slices: slices = [] input_node, weight, _, stride_hw, pad_hw, dilation_hw, _, _, _ = conv_node.args weight_shape = cast(torch.fx.Node, weight).meta["val"].shape input_shape = cast(torch.fx.Node, input_node).meta["val"].shape + spatial_rank = len(input_shape) - 2 - for stride, pad, dilation, dim in zip( - cast(list, stride_hw), - cast(list, pad_hw), - cast(list, dilation_hw), - (2, 3), - ): + strides = expand_around_channel(cast(Sequence[int] | int, stride_hw), spatial_rank) + pads = expand_around_channel(cast(Sequence[int] | int, pad_hw), spatial_rank) + dilations = expand_around_channel( + cast(Sequence[int] | int, dilation_hw), spatial_rank + ) + + for axis_index, (stride, pad, dilation) in enumerate(zip(strides, pads, dilations)): + dim = axis_index + 2 remainder = conv_remainder( input_shape[dim], pad, dilation, weight_shape[dim], stride ) @@ -68,19 +75,16 @@ def get_slices_pooling(pooling_node: torch.fx.Node) -> Slices: input_node = pooling_node.args[0] kernel_size = pooling_node.args[1] stride = pooling_node.args[2] - padding = pooling_node.args[3] if len(pooling_node.args) >= 4 else [0, 0] - - # For the loop below, padding must be a list - if isinstance(padding, int): - padding = [padding, padding] + padding = pooling_node.args[3] if len(pooling_node.args) >= 4 else 0 input_shape = cast(torch.fx.Node, input_node).meta["val"].shape - for kernel_length, stride_length, pad_size, dim in zip( - cast(list, kernel_size), - cast(list, stride), - cast(list, padding), - (2, 3), + kernel_sizes = expand_around_channel(cast(Sequence[int] | int, kernel_size), 2) + strides = expand_around_channel(cast(Sequence[int] | int, stride), 2) + pads = expand_around_channel(cast(Sequence[int] | int, padding), 2) + + for dim, (kernel_length, stride_length, pad_size) in enumerate( + zip(kernel_sizes, strides, pads), start=2 ): remainder = pooling_remainder( input_shape[dim], pad_size, kernel_length, stride_length @@ -98,7 +102,7 @@ def get_slices(node: torch.fx.Node) -> Slices: Returns the remainder of input_length; given graph Node. """ if node.target == conv2d_op: - return get_slices_conv2d(node) + return get_slices_convolution(node) elif node.target == max_pooling_op or node.target == avg_pooling_op: return get_slices_pooling(node) else: @@ -112,17 +116,17 @@ def is_valid_operator(node: torch.fx.Node) -> bool: dilation = node.args[4] if len(node.args) >= 5 else 1 ceil_mode = node.args[5] if len(node.args) >= 6 else False - # Dilation should be handled first by DecomposeMaxPool2DPass + # Dilation should be handled first by DecomposeMaxPool2dPass if isinstance(dilation, int): if dilation > 1: raise ValueError( - "Expected max_pool2d with dilation = 1, has DecomposeMaxPool2DPass been run?" 
+ "Expected max_pool2d with dilation = 1, has DecomposeMaxPool2dPass been run?" ) else: dilation = cast(list, dilation) if dilation[0] > 1 or dilation[1] > 1: raise ValueError( - "Expected max_pool2d with dilation = [1, 1], has DecomposeMaxPool2DPass been run?" + "Expected max_pool2d with dilation = [1, 1], has DecomposeMaxPool2dPass been run?" ) # If using ceil mode for rounding, the input does not need adjusting @@ -137,7 +141,7 @@ def is_valid_operator(node: torch.fx.Node) -> bool: return False -class SizeAdjustInputPass(ExportPass): +class SizeAdjustInputPass(ArmPass): """ Adjusts the input size to Conv2D and Pooling operators. PyTorch allows the input and kernel shape to not "match", in which case the remaining @@ -185,6 +189,10 @@ class SizeAdjustInputPass(ExportPass): input. """ + _passes_required_after: Set[Type[ExportPass]] = { + RewriteConvPass, + } + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: graph = graph_module.graph modified_graph = False @@ -204,7 +212,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: with graph_module.graph.inserting_before(node): last_node = cast(torch.fx.Node, parent_node) for args in slice_args: - slice_node = create_node(graph, slice_op, (last_node,) + args) + slice_node = create_node( + graph, slice_op, (last_node,) + args, from_node=node + ) last_node = slice_node node.replace_input_with(cast(torch.fx.Node, parent_node), last_node) modified_graph = True diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py index e4436d638f4..07799a840dc 100644 --- a/backends/arm/_passes/to_tosa_memory_format_pass.py +++ b/backends/arm/_passes/to_tosa_memory_format_pass.py @@ -3,19 +3,21 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe - import logging +from typing import Set, Type import torch -from executorch.backends.arm._passes import AnnotateOutputDimOrderPass +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.annotate_decomposed_matmul import ( + AnnotateDecomposedMatmulPass, +) from executorch.backends.arm._passes.arm_pass_utils import ( create_node, get_first_fake_tensor, - get_output_dim_orders, is_param_node, ) +from executorch.backends.arm.constants import NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -30,115 +32,193 @@ def _is_input(node: torch.fx.Node, exported_program: ExportedProgram) -> bool: return node.op == "placeholder" and not is_param_node(exported_program, node) -class ToTosaMemoryFormatPass(ExportPass): +class ToTosaMemoryFormatPass(ArmPass): """ Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts backend.tosa.TRANSPOSE when a transition between 3D and 4D/5D tensors happen. The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape. + This pass also makes other values aware of spatial dimensions required by future operators by back propogating info as required. 
""" - NHWC_order = (0, 2, 3, 1) - NHWC_inverse_order = (0, 3, 1, 2) - HWCM_order = (2, 3, 0, 1) - NNHWC_order = (0, 1, 3, 4, 2) - NNHWC_inverse_order = (0, 1, 4, 2, 3) + _passes_required_after: Set[Type[ExportPass]] = set() def __init__(self, exported_program: ExportedProgram) -> None: - self.exported_program = exported_program super().__init__() + self.exported_program = exported_program @staticmethod - def _is_consumer_node_depthwise_conv2d(node: torch.fx.Node): - consumer_node = list(node.users)[0] - if consumer_node.target == exir_ops.edge.aten.convolution.default: - consumer_node_inputs = consumer_node.all_input_nodes - groups = consumer_node.args[-1] - in_channels = consumer_node_inputs[0].meta["val"].shape[1] - out_channels = consumer_node_inputs[1].meta["val"].shape[0] - if (in_channels == groups) and (out_channels % in_channels) == 0: - return True - - return False - - def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node): + def _channels_last_order(rank: int, spatial_rank: int) -> tuple[int, ...]: + """ + Compute the permutation of tensor dimensions corresponding to a + "channels_last"-style memory layout for an arbitrary tensor rank. + + In standard PyTorch convention: + - "channels_first" order is (N, C, H, W) + - "channels_last" order is (N, H, W, C) + This helper generalizes that concept beyond 4D tensors, producing an index + ordering that moves the channel dimension to the end while preserving the + relative order of batch and spatial dimensions. + + Args: + rank (int): Total number of tensor dimensions (e.g. 4 for NCHW). + spatial_rank (int): Number of spatial dimensions (e.g. 2 for HW, 3 for DHW). + Values outside [0, rank - 2] are clamped to that range. + + Returns: + tuple[int, ...]: A permutation of dimension indices that reorders the + tensor into "channels_last" format. For example: + - rank=4, spatial_rank=2 → (0, 2, 3, 1) # NCHW → NHWC + - rank=5, spatial_rank=3 → (0, 2, 3, 4, 1) # NCDHW → NDHWC + - rank=3, spatial_rank=1 → (0, 2, 1) + + Notes: + If `rank <= 2`, the function returns the identity order since there + are no distinct channel/spatial dimensions. + In practice only rank 4+ tensors will reach this function as the dim order should be fixed for those. """ - returns True for w in the following sequence; - w -> depthwise_conv2d -> ... + if rank <= 2: + return tuple(range(rank)) + spatial_rank = max(0, min(spatial_rank, rank - 2)) + channel_axis = rank - (spatial_rank + 1) + batch_axes = list(range(channel_axis)) + spatial_axes = list(range(channel_axis + 1, rank)) + return tuple(batch_axes + spatial_axes + [channel_axis]) + + @staticmethod + def _channels_last_inverse_order(rank: int, spatial_rank: int) -> tuple[int, ...]: """ - if node.op == "placeholder": - # node is an input, weight or bias node - consumer_node = list(node.users)[0] - if self.is_weight_node_for_depthwise_conv2d(consumer_node): - return True - if self._is_consumer_node_depthwise_conv2d(node): - # Check that node is the weight-argument and not input or bias - return consumer_node.args[1] == node + Return the inverse permutation of `_channels_last_order`. - return False + This provides the axis order needed to map a tensor from + "channels_last" layout back to its original layout. 
+ """ + order = ToTosaMemoryFormatPass._channels_last_order(rank, spatial_rank) + inverse = [0] * rank + for idx, axis in enumerate(order): + inverse[axis] = idx + return tuple(inverse) + + def _initial_spatial_rank(self, node: torch.fx.Node) -> int: + """ + Infer the initial spatial rank based on the current rank, input node spatial + ranks and node target. A spatial dimension includes Height, Width or Depth + fields. In most operators this will only ever be Height and Width, but for 3D + operators such as conv3d this would contain 3 spatial dims. + + Spatial rank is the max of any input node spatial ranks and the number of + trailing spatial dims we need to preserve (rank - 2, capped at 3). This + decides which axes must stay channels-last when inserting transposes. + """ + tensor = get_first_fake_tensor(node).data + # Start by assuming 2D when dealing with rank4+ to account for the base case + # of an increasing amount of batch dimensions. + rank = tensor.dim() + if rank >= 4: + spatial_rank = 2 + elif rank == 3: + spatial_rank = 1 + else: + spatial_rank = 0 + + # Look for supported 3D ops and update spatial rank if relevent. + # Currently only Conv3d is supported. + if node.target == exir_ops.backend.tosa.CONV3D.default: + spatial_rank = 3 + + # Check input spatial ranks to know what the previous node spatial ranks were. + input_ranks = [ + input_node.meta.get("tosa_spatial_rank", 0) + for input_node in node.all_input_nodes + ] + if input_ranks: + spatial_rank = max([spatial_rank, *input_ranks]) + + # The max that spatial rank can be is 3. If the current rank not capable of holding + # the current spatial rank, we clamp the max to Rank - (Channels and a singular batch dimension). + # This ensures we revert back to lower spatial ranks after we are finished processing higher spatial ops. + return min(spatial_rank, max(rank - 2, 0)) @staticmethod - def memory_format_differs(shape): - """Returns true if the shape will have a different memory layout in (N)NCHW and (N)NHWC format""" - if len(shape) >= 5: - C = shape[2] - H = shape[3] - W = shape[4] - elif len(shape) == 4: - C = shape[1] - H = shape[2] - W = shape[3] - elif len(shape) == 3: - C = shape[0] - H = shape[1] - W = shape[2] - if len(shape) <= 2: + def memory_format_differs(shape, spatial_rank): + """ + Determine whether a tensor shape would be laid out differently in + channels-first ((N)NCHW) versus channels-last ((N)NHWC) memory format. + """ + if len(shape) <= 2 or spatial_rank <= 0: return False - - return C > 1 and (H > 1 or W > 1) + channel_idx = len(shape) - (spatial_rank + 1) + channel_idx = max(0, min(channel_idx, len(shape) - 1)) + spatial_dims = shape[channel_idx + 1 :] + if not spatial_dims: + return False + channel_dim = shape[channel_idx] + return channel_dim > 1 and any(dim > 1 for dim in spatial_dims) @staticmethod - def is_channel_reshape(input_shape, output_shape): - """Returns true if the reshape changes the channel dimension""" - if not ( - (len(input_shape) == len(output_shape) and (len(output_shape) in (4, 5))) - or (len(input_shape) == 4 and len(output_shape) == 5) - or (len(input_shape) == 5 and len(output_shape) == 4) - ): + def is_channel_reshape( + input_shape, output_shape, input_spatial_rank, output_spatial_rank + ): + """ + Check whether a reshape touches the logical channel or consolidated + batch dimensions, which would invalidate dim-order annotations. 
+ """ + + valid_ranks = {4, 5, 6} + + if not (len(input_shape) in valid_ranks and len(output_shape) in valid_ranks): return False - C_old = input_shape[-3] - C_new = output_shape[-3] + def channel_index(shape, spatial_rank): + if len(shape) <= 2: + return len(shape) - 1 + idx = len(shape) - (spatial_rank + 1) + return max(0, min(idx, len(shape) - 1)) - N_new = ( - output_shape[0] - if len(output_shape) == 4 - else output_shape[0] * output_shape[1] - ) - N_old = ( - input_shape[0] if len(input_shape) == 4 else input_shape[0] * input_shape[1] - ) + C_old = input_shape[channel_index(input_shape, input_spatial_rank)] + C_new = output_shape[channel_index(output_shape, output_spatial_rank)] + + def get_batch_prod_dim(shape, spatial_rank): + product = 1 + + for dim in shape[: channel_index(shape, spatial_rank)]: + product = product * dim + + return product + + N_old = get_batch_prod_dim(input_shape, input_spatial_rank) + N_new = get_batch_prod_dim(output_shape, output_spatial_rank) return (N_old != N_new) or (C_old != C_new) @staticmethod def insert_input_transpose(node, input_node, graph_module): + """ + Ensure an input tensor is converted to channels-last ordering by + inserting (or folding) a backend `TRANSPOSE` node. + """ if input_node.target == exir_ops.backend.tosa.TRANSPOSE.default: pre_permute_node = input_node.all_input_nodes[0] node.replace_input_with(input_node, pre_permute_node) return + rank = len(get_first_fake_tensor(input_node).size()) + spatial_rank = input_node.meta["tosa_spatial_rank"] + mem_format = ToTosaMemoryFormatPass._channels_last_inverse_order( + rank, spatial_rank + ) + # Guard: mem_format must be a true permutation for the current rank + assert sorted(mem_format) == list( + range(rank) + ), f"bad perm {mem_format} for rank {rank} in insert_input_transpose" + with graph_module.graph.inserting_before(node): permute_node = create_node( graph_module.graph, exir_ops.backend.tosa.TRANSPOSE.default, args=( input_node, - list( - ToTosaMemoryFormatPass.NNHWC_inverse_order - if len(get_first_fake_tensor(input_node).size()) == 5 - else ToTosaMemoryFormatPass.NHWC_inverse_order - ), + list(mem_format), ), from_node=node, ) @@ -147,32 +227,41 @@ def insert_input_transpose(node, input_node, graph_module): permute_node.meta["tosa_dim_order"] = tuple( range(len(input_node.meta["val"].size())) ) + permute_node.meta["tosa_spatial_rank"] = spatial_rank @staticmethod def insert_output_transpose(node, graph_module): + """ + Convert a producer's output to channels-last by appending a backend + `TRANSPOSE` node and rewiring its users. 
+ """ + + rank = len(get_first_fake_tensor(node).size()) + spatial_rank = node.meta["tosa_spatial_rank"] + mem_format = ToTosaMemoryFormatPass._channels_last_order(rank, spatial_rank) + # Guard: mem_format must be a true permutation for the current rank + assert sorted(mem_format) == list( + range(rank) + ), f"bad perm {mem_format} for rank {rank} in insert_input_transpose" + with graph_module.graph.inserting_after(node): permute_node = create_node( graph_module.graph, exir_ops.backend.tosa.TRANSPOSE.default, args=( node, - list( - ToTosaMemoryFormatPass.NNHWC_order - if len(get_first_fake_tensor(node).size()) == 5 - else ToTosaMemoryFormatPass.NHWC_order - ), + list(mem_format), ), from_node=node, ) - permute_node.meta["tosa_dim_order"] = ( - ToTosaMemoryFormatPass.NNHWC_order - if len(get_first_fake_tensor(node).size()) == 5 - else ToTosaMemoryFormatPass.NHWC_order - ) + rank = len(get_first_fake_tensor(node).size()) + permute_node.meta["tosa_dim_order"] = mem_format + node.meta["tosa_dim_order"] = tuple( range(len(get_first_fake_tensor(node).size())) ) + permute_node.meta["tosa_spatial_rank"] = spatial_rank users = [user for user in node.users if user != permute_node] for user in users: @@ -182,24 +271,33 @@ def insert_output_transpose(node, graph_module): def _insert_view_transpose( input_shape, output_shape, node, input_node, graph_module ): + """ + Insert the necessary input/output transposes around reshapes that cross + the (N)NCHW -> (N)NHWC boundary or that touch channel dimensions. + """ nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) >= 4 nhwc_to_nchw = len(input_shape) >= 4 and len(output_shape) < 4 + + input_sr = input_node.meta["tosa_spatial_rank"] + output_sr = node.meta["tosa_spatial_rank"] + channel_reshape = ToTosaMemoryFormatPass.is_channel_reshape( - output_shape, input_shape + input_shape, + output_shape, + input_sr, + output_sr, ) if ( channel_reshape or nhwc_to_nchw - ) and ToTosaMemoryFormatPass.memory_format_differs(input_shape): - + ) and ToTosaMemoryFormatPass.memory_format_differs(input_shape, input_sr): ToTosaMemoryFormatPass.insert_input_transpose( node, input_node, graph_module ) if ( channel_reshape or nchw_to_nhwc - ) and ToTosaMemoryFormatPass.memory_format_differs(output_shape): - + ) and ToTosaMemoryFormatPass.memory_format_differs(output_shape, output_sr): ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module) def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): @@ -208,7 +306,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): This is relevant for the following cases: - view: <4D -> >=4D - view: >=4D -> <4D - Additionally, a 4D/5D->4D/5D view operation acting on the channel dimension currently needs to be performed in (N)NCHW format, leadning to one extra input and output transpose for this case. + Additionally, a 4D/5D->4D/5D view operation acting on the channel dimension currently needs to be performed in (N)NCHW format, leading to one extra input and output transpose for this case. 
Transposes can be avoided for shapes where there is no difference in actual memory, e.g for - H == W == 1 @@ -218,7 +316,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: # call_function and placeholder allowed due to # index.Tensor being able to come in as both - if node.op not in ["call_function", "placeholder", "output"]: + if node.op != "call_function": continue # Transpose views @@ -240,25 +338,48 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): graph_module, ) - # Transpose inputs - elif _is_input(node, self.exported_program): - input_shape = get_first_fake_tensor(node).size() - if len(input_shape) in (4, 5): - ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module) + output_node = graph_module.graph.output_node() - # Transpose outputs - elif node.op == "output": - output_shape = get_first_fake_tensor(node).size() + # Transpose inputs if they are in (N)NCHW format + inputs = [ + n for n in graph_module.graph.nodes if _is_input(n, self.exported_program) + ] + for input_node in inputs: + input_dim_order = get_first_fake_tensor(input_node).dim_order() + if input_dim_order in (NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER): + self.insert_output_transpose(input_node, graph_module) + + # Transpose outputs if they are in (N)NCHW format + outputs = output_node.args[0] + if not isinstance(outputs, (list, tuple)): + raise TypeError( + f"Expected output node args to be a list or tuple, got {type(outputs)}" + ) + output_dim_orders = output_node.meta.get("original_dim_orders") + if output_dim_orders is None: + raise RuntimeError( + f"{AnnotateDecomposedMatmulPass.__name__} is required to run at the beginning of the pass pipeline when using {ToTosaMemoryFormatPass.__name__}." + ) - if len(output_shape) in (4, 5): - for input_node in node.all_input_nodes: - ToTosaMemoryFormatPass.insert_input_transpose( - node, input_node, graph_module - ) + for output_node_input, output_dim_order in zip( + outputs, output_dim_orders, strict=True + ): + if output_dim_order in ( + NCHW_ORDER, + NNCHW_ORDER, + NNNCHW_ORDER, + ): + self.insert_input_transpose( + output_node, output_node_input, graph_module + ) def remove_dim_order_kwargs( self, graph_module: torch.fx.GraphModule, node: torch.fx.Node ): + """ + Drop any user-specified `dim_order` keyword arguments so the pass remains + the single source of truth for dim-order annotations. + """ if node.op != "call_function": return @@ -273,24 +394,31 @@ def remove_dim_order_kwargs( node.kwargs = kwargs def call(self, graph_module: torch.fx.GraphModule): - for node in graph_module.graph.nodes: - node_data = get_first_fake_tensor(node).data - + """ + Entry point for the pass: annotate spatial ranks, compute dim orders, + insert bridging transposes, and forward to child passes. 
+ """ + nodes = list(graph_module.graph.nodes) + for node in nodes: + if "val" not in node.meta: + continue + node.meta["tosa_spatial_rank"] = self._initial_spatial_rank(node) self.remove_dim_order_kwargs(graph_module, node) - # Inputs and outputs are always in (N)NCHW format + + self._propagate_spatial_ranks(nodes) + + for node in nodes: + if "val" not in node.meta: + continue + node_data = get_first_fake_tensor(node).data + spatial_rank = node.meta["tosa_spatial_rank"] if _is_input(node, self.exported_program) or node.op == "output": - dim_order = tuple(range(node_data.dim())) - elif node_data.dim() == 4: - dim_order = self.NHWC_order - if self.is_weight_node_for_depthwise_conv2d(node): - # The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M) which corresponds to - # dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d). - dim_order = self.HWCM_order - elif node_data.dim() == 5: - dim_order = self.NNHWC_order + dim_order = node_data.dim_order() else: - dim_order = tuple(range(node_data.dim())) # type: ignore[assignment] - + if node_data.dim() >= 4: + dim_order = self._channels_last_order(node_data.dim(), spatial_rank) + else: + dim_order = tuple(range(node_data.dim())) # type: ignore[assignment] node.meta["tosa_dim_order"] = dim_order # Insert TOSA transposes to convert between (N)NCHW and (N)NHWC format. @@ -301,31 +429,26 @@ def call(self, graph_module: torch.fx.GraphModule): return PassResult(graph_module, True) - def requires(self, graph_module) -> None: + def _propagate_spatial_ranks(self, nodes): """ - This is the only pass which handles dim_orders, so verify that the output dim_orders has not changed since the beginning of the lowering pipeline. + Propagate `tosa_spatial_rank` metadata backwards so earlier nodes learn + about upcoming spatial requirements from future ops. """ - - dim_orders = get_output_dim_orders(graph_module) - original_dim_orders = graph_module.graph.output_node().meta.get( - "original_dim_orders" - ) - output_node = graph_module.graph.output_node() - - if original_dim_orders is None: - raise RuntimeError( - f"{AnnotateOutputDimOrderPass.__name__} must be run in the beginning of the pass pipeline to verify that the dim order has not changed unexpectedly during its run." - ) - - if len(dim_orders) != len(original_dim_orders): - raise RuntimeError( - f"The number of outputs has changed since {AnnotateOutputDimOrderPass.__name__} was run." - ) - - for node, dim_order, original_dim_order in zip( - output_node.args[0], dim_orders, original_dim_orders - ): - if dim_order != original_dim_order: - raise RuntimeError( - f"The dim order of output {node.name} has changed from {original_dim_order} to {dim_order} since {AnnotateOutputDimOrderPass.__name__} was run." 
- ) + changed = True + while changed: + changed = False + for node in reversed(nodes): + if "val" not in node.meta: + continue + tensor = get_first_fake_tensor(node) + limit = max(tensor.dim() - 2, 0) + current = node.meta.get("tosa_spatial_rank") + propagated = current + for user in node.users: + user_rank = user.meta.get("tosa_spatial_rank") + if user_rank is None: + continue + propagated = max(propagated, min(user_rank, limit)) + if propagated != current: + node.meta["tosa_spatial_rank"] = propagated + changed = True diff --git a/backends/arm/_passes/unsqueeze_before_repeat_pass.py b/backends/arm/_passes/unsqueeze_before_repeat_pass.py index 01983baa9ab..ed6aa82aad5 100644 --- a/backends/arm/_passes/unsqueeze_before_repeat_pass.py +++ b/backends/arm/_passes/unsqueeze_before_repeat_pass.py @@ -1,11 +1,13 @@ -# Copyright 2024 Arm Limited and/or its affiliates. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +from typing import Set, Type + import torch import torch.fx +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( create_node, get_first_fake_tensor, @@ -14,7 +16,7 @@ from executorch.exir.pass_base import ExportPass, PassResult -class UnsqueezeBeforeRepeatPass(ExportPass): +class UnsqueezeBeforeRepeatPass(ArmPass): """ A TOSA TILE op only supports rank(in) == rank(out). To support Pytorch's repeat which can also add dimensions, @@ -29,6 +31,8 @@ class UnsqueezeBeforeRepeatPass(ExportPass): repeat(multiples) """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module: torch.fx.GraphModule): modified_graph = False for node in graph_module.graph.nodes: diff --git a/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py b/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py index ccae9b503cf..15a105a1fd8 100644 --- a/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py +++ b/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py @@ -3,22 +3,28 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe + +from typing import Set, Type import torch +from executorch.backends.arm._passes import ArmPass +from executorch.exir import ExportedProgram from executorch.exir.pass_base import ExportPass, PassResult from torch._export.utils import is_buffer, is_param +from torch.export.graph_signature import InputKind -class UnsqueezeScalarPlaceholdersPass(ExportPass): +class UnsqueezeScalarPlaceholdersPass(ArmPass): """ Placeholders that have node.meta["val"].shape = () cause issues later in the lowering. This pass unsqueezes the placeholders to make sure shape is at least (1,). 
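A plain-data sketch of the backwards fixed-point propagation above, with dicts standing in for fx.Node metadata; the conv3d/view node setup is hypothetical. The real pass additionally skips nodes without "val" metadata and derives the cap from the fake tensor's rank.

def propagate_spatial_ranks(nodes):
    changed = True
    while changed:
        changed = False
        for node in reversed(nodes):
            # Never count batch or channel dims as spatial.
            limit = max(node["rank"] - 2, 0)
            current = node["tosa_spatial_rank"]
            propagated = current
            for user in node["users"]:
                propagated = max(propagated, min(user["tosa_spatial_rank"], limit))
            if propagated != current:
                node["tosa_spatial_rank"] = propagated
                changed = True

conv3d = {"rank": 5, "tosa_spatial_rank": 3, "users": []}
view = {"rank": 5, "tosa_spatial_rank": 2, "users": [conv3d]}
propagate_spatial_ranks([view, conv3d])
assert view["tosa_spatial_rank"] == 3  # the 3D spatial requirement reaches the producer
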
""" - def __init__(self, exported_program): - self.exported_program = exported_program + _passes_required_after: Set[Type[ExportPass]] = set() + + def __init__(self, exported_program: ExportedProgram) -> None: super().__init__() + self.exported_program = exported_program def call(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: @@ -37,17 +43,30 @@ def call(self, graph_module: torch.fx.GraphModule): else: continue - tensor = self.exported_program.state_dict[name] + tensor = self.exported_program.state_dict.get(name) + # If we have a persistent=False buffer with no entry in state_dict + spec = next( + s + for s in self.exported_program.graph_signature.input_specs + if getattr(s.arg, "name", None) == node.name + ) + is_non_persistent_buffer = ( + spec.kind is InputKind.BUFFER and spec.persistent is False + ) + if tensor is None and is_non_persistent_buffer: + fake = node.meta["val"] + tensor = torch.ones_like(fake) + + # If we have a scalar, unsqueeze it if tensor.dim() == 0: - self.exported_program.state_dict[name] = tensor.unsqueeze(0) - node.meta["val"] = node.meta["val"].fake_mode.from_tensor( - tensor.unsqueeze(0), static_shapes=True - ) - else: - node.meta["val"] = node.meta["val"].fake_mode.from_tensor( - tensor, static_shapes=True - ) + tensor = tensor.unsqueeze(0) + + # update or create entry in state_dict, recreate fake + self.exported_program.state_dict[name] = tensor + node.meta["val"] = node.meta["val"].fake_mode.from_tensor( + tensor, static_shapes=True + ) graph_module.recompile() graph_module = super().call(graph_module).graph_module diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py deleted file mode 100644 index 2e71f91dbb6..00000000000 --- a/backends/arm/arm_backend.py +++ /dev/null @@ -1,245 +0,0 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -# -# Main implementation of AoT flow to partition and preprocess for Arm target -# backends. Converts via TOSA as an intermediate form supported by AoT and -# JIT compiler flows. -# -from enum import Enum -from typing import List, Optional - -from executorch.backends.arm.tosa import TosaSpecification - -from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-not-found] - CompileSpec, -) - - -class ArmCompileSpecBuilder: - class DebugMode(Enum): - JSON = 1 - TOSA = 2 - - def __init__(self): - self.compile_spec: List[CompileSpec] = [] - self.compiler_flags = [] - self.output_format = None - self.path_for_intermediates = None - self.tosa_spec = None - self.tosa_debug_mode = None - - def vgf_compile_spec( - self, - tosa_spec: TosaSpecification = None, # type: ignore[assignment] - compiler_flags: Optional[str] = "", - ) -> "ArmCompileSpecBuilder": - """ - Generate compile spec for VGF compatible targets - - Args: - compiler_flags: Extra compiler flags for converter_backend - """ - self.output_format = "vgf" - self.compiler_flags = [ - compiler_flags, - ] - - if tosa_spec is None: - tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+FP") - - tosa_version = tosa_spec.version # type: ignore[attr-defined] - tosa_profiles = tosa_spec.profiles # type: ignore[attr-defined] - - if tosa_version.major != 1: - raise ValueError( - "Arm backend only supports converter-backend for TOSA version 1. 
" - f"Invalid TOSA version: {tosa_version}" - ) - - if "FP" not in tosa_profiles and "INT" not in tosa_profiles: - raise ValueError( - "Arm backend only supports converter-backend for FP or INT. " - f"Invalid TOSA profile: {tosa_profiles}" - ) - - if len(tosa_profiles) != 1: - raise ValueError( - "For now Arm backend only supports converter-backend for either FP or INT. " - f"Invalid TOSA profile: {tosa_profiles}" - ) - - self.tosa_spec = tosa_spec - - return self - - def ethosu_compile_spec( - self, - target: str, - system_config: Optional[str] = None, - memory_mode: Optional[str] = None, - extra_flags: Optional[str] = None, - config_ini: Optional[str] = "Arm/vela.ini", - ) -> "ArmCompileSpecBuilder": - """ - Generate compile spec for Ethos-U NPU - - Args: - target: Ethos-U accelerator configuration, e.g. ethos-u55-128 - system_config: System configuration to select from the Vel - configuration file - memory_mode: Memory mode to select from the Vela configuration file - extra_flags: Extra flags for the Vela compiler - config_ini: Vela configuration file(s) in Python ConfigParser .ini - file format - """ - assert ( - self.output_format is None - ), f"Output format already set to f{self.output_format}" - self.output_format = "vela" - self.compiler_flags = [ - f"--accelerator-config={target}", - f"--config={config_ini}", - ] - - # default system config and memory mode - if "ethos-u55" in target: - if system_config is None: - system_config = "Ethos_U55_High_End_Embedded" - if memory_mode is None: - memory_mode = "Shared_Sram" - elif "ethos-u85" in target: - if system_config is None: - system_config = "Ethos_U85_SYS_DRAM_Mid" - if memory_mode is None: - memory_mode = "Sram_Only" - else: - raise RuntimeError(f"Unknown ethos target: {target}") - - if system_config is not None: - self.compiler_flags.append(f"--system-config={system_config}") - if memory_mode is not None: - self.compiler_flags.append(f"--memory-mode={memory_mode}") - if extra_flags is not None: - self.compiler_flags.append(extra_flags) - - # We require raw output and regor, so add these flags if absent. This - # overrides any other output setting. - self.compiler_flags.append("--output-format=raw") - self.compiler_flags.append("--debug-force-regor") - - base_tosa_version = "TOSA-1.0+INT+int16" - if "u55" in target: - # Add the Ethos-U55 extension marker - base_tosa_version += "+u55" - self.tosa_spec = TosaSpecification.create_from_string(base_tosa_version) - - return self - - def tosa_compile_spec( - self, tosa_spec: str | TosaSpecification - ) -> "ArmCompileSpecBuilder": - """ - Generate compile spec for TOSA flatbuffer output - """ - assert ( - self.output_format is None - ), f"Output format already set: {self.output_format}" - self.output_format = "tosa" - if isinstance(tosa_spec, TosaSpecification): - self.tosa_spec = tosa_spec - elif isinstance(tosa_spec, str): - self.tosa_spec = TosaSpecification.create_from_string(tosa_spec) - else: - raise RuntimeError(f"Invalid type for {tosa_spec}!") - return self - - def dump_intermediate_artifacts_to( - self, output_path: str - ) -> "ArmCompileSpecBuilder": - """ - Sets a path for dumping intermediate results during such as tosa and pte. 
- """ - self.path_for_intermediates = output_path - return self - - def dump_debug_info(self, debug_mode: DebugMode) -> "ArmCompileSpecBuilder": - """ - Dump debugging information into the intermediates path - """ - self.tosa_debug_mode = debug_mode.name - return self - - def build(self) -> List[CompileSpec]: - """ - Generate a list of compile spec objects from the builder - """ - assert self.tosa_spec - - # Always supply a TOSA version - self.compile_spec = [CompileSpec("tosa_spec", str(self.tosa_spec).encode())] - - # Add compile flags, these are backend specific, refer to the backend - # documentation. - self.compile_spec += [ - CompileSpec("compile_flags", " ".join(self.compiler_flags).encode()), - ] - - # encode output format - self.compile_spec.append( - CompileSpec("output_format", self.output_format.encode()) - ) - - if self.path_for_intermediates is not None: - self.compile_spec.append( - CompileSpec("debug_artifact_path", self.path_for_intermediates.encode()) - ) - - if self.tosa_debug_mode is not None: - if not self.path_for_intermediates: - raise ValueError( - "dump_debug_info() must be used in conjunction with dump_intermediate_artifacts_to()" - ) - - self.compile_spec.append( - CompileSpec("dump_debug_info", self.tosa_debug_mode.encode()) - ) - - return self.compile_spec - - -def is_tosa(compile_spec: List[CompileSpec]) -> bool: - has_tosa_output = False - has_tosa_spec = False - for spec in compile_spec: - if spec.key == "output_format": - has_tosa_output = spec.value.decode() == "tosa" - if spec.key == "tosa_spec": - has_tosa_spec = True - - return has_tosa_output and has_tosa_spec - - -def is_ethosu(compile_spec: List[CompileSpec]) -> bool: - for spec in compile_spec: - if spec.key == "output_format": - return spec.value.decode() == "vela" - return False - - -def is_vgf(compile_spec: List[CompileSpec]) -> bool: - for spec in compile_spec: - if spec.key == "output_format": - return spec.value.decode() == "vgf" - return False - - -def get_intermediate_path(compile_spec: List[CompileSpec]) -> Optional[str]: - for spec in compile_spec: - if spec.key == "debug_artifact_path": - return spec.value.decode() - return None diff --git a/backends/arm/arm_vela.py b/backends/arm/arm_vela.py index c47a5c58f49..1ecaca3c454 100644 --- a/backends/arm/arm_vela.py +++ b/backends/arm/arm_vela.py @@ -3,7 +3,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe import os import struct @@ -25,6 +24,8 @@ # per-io structs to simplify runtime use. 
def vela_bin_pack_io(prefix, data): vela_input_shapes = data[prefix + "_shape"] + # Vela input/output shape is fixed to 6D + vela_io_shape_dims = 6 ios = struct.pack(" bytes: tosaname = "out.tosa" - tosa_path = os.path.join(tmpdir, tosaname) + tosa_path = os.path.join(dir, tosaname) with open(tosa_path, "wb") as f: f.write(tosa_flatbuffer) # invoke vela - output_dir = os.path.join(tmpdir, "output") + output_dir = os.path.join(dir, "output") args.append(f"--output-dir={output_dir}") args.append(tosa_path) if verbose: @@ -70,9 +79,9 @@ def vela_compile(tosa_flatbuffer: bytes, args: List[str], verbose: bool = False) if any("ethos-u85" in arg for arg in args) or any( "debug-force-regor" in arg for arg in args ): - np_path = os.path.join(tmpdir, "output", "out_vela.npz") + np_path = os.path.join(dir, "output", "out_vela.npz") else: - np_path = os.path.join(tmpdir, "output", "out_sg0_vela.npz") + np_path = os.path.join(dir, "output", "out_sg0_vela.npz") blocks = b"" with np.load(np_path, allow_pickle=False) as data: @@ -120,3 +129,9 @@ def vela_compile(tosa_flatbuffer: bytes, args: List[str], verbose: bool = False) blocks = blocks + block return blocks + + if intermediate_path is not None: + return run(intermediate_path) + else: + with tempfile.TemporaryDirectory() as tmpdir: + return run(tmpdir) diff --git a/backends/arm/common/annotation_meta.py b/backends/arm/common/annotation_meta.py new file mode 100644 index 00000000000..12ef80ae70b --- /dev/null +++ b/backends/arm/common/annotation_meta.py @@ -0,0 +1,39 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Mapping, Optional + + +@dataclass(frozen=True, init=False) +class ArmAnnotationInfo(dict): + """ + Dataclass wrapper that behaves like a dict so serialization can treat it as + a plain mapping, while still exposing a typed attribute for convenience. + """ + + quantized: bool + CUSTOM_META_KEY: str = "_arm_annotation_info" + + def __init__( + self, + value: Optional[Mapping[str, Any]] = None, + *, + quantized: Optional[bool] = None, + ) -> None: + if quantized is not None: + resolved = bool(quantized) + + elif isinstance(value, Mapping): + resolved = bool(value.get("quantized", False)) + + else: + raise TypeError( + "ArmAnnotationInfo expects a mapping with a 'quantized' entry or a keyword 'quantized'." + ) + dict.__init__(self, quantized=resolved) + object.__setattr__(self, "quantized", resolved) diff --git a/backends/arm/common/arm_compile_spec.py b/backends/arm/common/arm_compile_spec.py new file mode 100644 index 00000000000..dda2930b306 --- /dev/null +++ b/backends/arm/common/arm_compile_spec.py @@ -0,0 +1,271 @@ +# Copyright 2023-2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +# +# Main implementation of AoT flow to partition and preprocess for Arm target +# backends. Converts via TOSA as an intermediate form supported by AoT and +# JIT compiler flows. 
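A minimal usage sketch of ArmAnnotationInfo, assuming it is imported from the new backends/arm/common/annotation_meta.py module above; the behaviour shown follows directly from the class definition.

from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo

info = ArmAnnotationInfo(quantized=True)
assert info.quantized is True          # typed attribute access
assert info["quantized"] is True       # plain-mapping access

# Because the class subclasses dict, it survives being flattened to a plain
# mapping (e.g. when node metadata is serialized) and can be rebuilt from it.
flattened = dict(info)
restored = ArmAnnotationInfo(flattened)
assert restored == {"quantized": True}
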
+# + +import json +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import Enum + +from executorch.backends.arm.common.pipeline_config import ArmPassPipelineConfig +from executorch.backends.arm.tosa import TosaSpecification + +from executorch.exir.backend.compile_spec_schema import CompileSpec + + +@dataclass(init=False) +class ArmCompileSpec(ABC): + class DebugMode(Enum): + JSON = 1 + TOSA = 2 + + tosa_spec: TosaSpecification + compiler_flags: list[str] = field(default_factory=list) + path_for_intermediates: str | None = None + tosa_debug_mode: DebugMode | None = None + + _TOSA_SPEC_KEY = "tosa_spec" + _COMPILE_FLAGS_KEY = "compile_flags" + _OUTPUT_FORMAT_KEY = "output_format" + _DEBUG_ARTIFACT_KEY = "debug_artifact_path" + _DEBUG_MODE_KEY = "dump_debug_info" + _OUTPUT_REORDER_KEY = "ouput_reorder_workaround" + _TRANSFORM_PIPELINE_CONFIG_KEY = "transform_pipeline_config" + + def _set_compile_specs( + self, + tosa_spec: TosaSpecification, + compiler_flags: list[str], + path_for_intermediates: str | None = None, + tosa_debug_mode: DebugMode | None = None, + output_order_workaround: bool = True, + pipeline_config: ArmPassPipelineConfig | None = None, + ): + """Set all values of dataclass directly.""" + self.tosa_spec = tosa_spec + self.compiler_flags = compiler_flags + self.path_for_intermediates = path_for_intermediates + self.tosa_debug_mode = tosa_debug_mode + self.output_order_workaround = output_order_workaround + self._pipeline_config = pipeline_config + + @classmethod + def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901 + tosa_spec: TosaSpecification | None = None + output_format: str | None = None + compiler_flags: list[str] | None = None + path_for_intermediates: str | None = None + tosa_debug_mode: ArmCompileSpec.DebugMode | None = None + output_order_workaround: bool = True + pipeline_config: ArmPassPipelineConfig | None = None + unknown_specs: dict[str, str] = {} + for spec in compile_specs: + key = spec.key + val = ( + spec.value.decode() + if isinstance(spec.value, (bytes, bytearray)) + else spec.value + ) + if key == ArmCompileSpec._TOSA_SPEC_KEY: + if tosa_spec is not None: + raise ValueError("More than one tosa_spec entry in compile spec.") + tosa_spec = TosaSpecification.create_from_string(val) + elif key == ArmCompileSpec._COMPILE_FLAGS_KEY: + if compiler_flags is not None: + raise ValueError( + "More than one compiler flags entry in compile spec." + ) + compiler_flags = val.split(" ") + elif key == ArmCompileSpec._OUTPUT_FORMAT_KEY: + if output_format is not None: + raise ValueError( + "More than one output format entry in compile spec." + ) + output_format = val + elif key == ArmCompileSpec._DEBUG_ARTIFACT_KEY: + if path_for_intermediates is not None: + raise ValueError( + "More than one debug artifact path entry in compile spec." + ) + path_for_intermediates = val + elif key == ArmCompileSpec._DEBUG_MODE_KEY: + if tosa_debug_mode is not None: + raise ValueError( + "More than one tosa_debug_mode entry in compile spec." + ) + tosa_debug_mode = ArmCompileSpec.DebugMode[val] + elif key == ArmCompileSpec._OUTPUT_REORDER_KEY: + output_order_workaround = val # type: ignore[assignment] + elif key == ArmCompileSpec._TRANSFORM_PIPELINE_CONFIG_KEY: + if pipeline_config is not None: + raise ValueError( + "More than one transform pipeline entry in compile spec." 
+ ) + pipeline_config = ArmPassPipelineConfig.from_dict(json.loads(val)) + else: + unknown_specs[key] = val + + if tosa_spec is None: + raise ValueError("No tosa_spec in compile spec.") + if output_format is None: + raise ValueError("No output_format in compile spec.") + if output_format != cls.get_output_format(): + raise ValueError( + f"Incorrect output format '{output_format}' for {cls.__name__}, expected '{cls.get_output_format()}'" + ) + if compiler_flags is None: + compiler_flags = [] + + # Create new object from class, but bypass __init__ and use _set_compile_specs instead. + compile_spec = cls.__new__(cls) + compile_spec._set_compile_specs( + tosa_spec=tosa_spec, + compiler_flags=compiler_flags, + path_for_intermediates=path_for_intermediates, + tosa_debug_mode=tosa_debug_mode, + output_order_workaround=output_order_workaround, + pipeline_config=pipeline_config, + ) + cls.from_list_hook(compile_spec, unknown_specs) + compile_spec.validate() + return compile_spec + + @classmethod + def from_list_hook(cls, compile_spec, specs: dict[str, str]): # noqa: B027 + """Allows subclasses to hook into parsing compile spec lists.""" + pass + + @abstractmethod + def validate(self): + """Throws an error if the compile spec is not valid.""" + + def to_list(self): + """Get the ArmCompileSpec in list form.""" + if not self.tosa_spec: + raise ValueError("tosa_spec must be set before calling to_list()") + + # Always supply a TOSA version + compile_spec = [ + CompileSpec(ArmCompileSpec._TOSA_SPEC_KEY, str(self.tosa_spec).encode()) + ] + + # Add compile flags, these are backend specific, refer to the backend + # documentation. + if len(self.compiler_flags) > 0: + compile_spec += [ + CompileSpec( + ArmCompileSpec._COMPILE_FLAGS_KEY, + " ".join(self.compiler_flags).encode(), + ), + ] + + # Add output format to identify kind of compile spec. + compile_spec.append( + CompileSpec( + ArmCompileSpec._OUTPUT_FORMAT_KEY, self.get_output_format().encode() + ) + ) + + if self.path_for_intermediates is not None: + compile_spec.append( + CompileSpec( + ArmCompileSpec._DEBUG_ARTIFACT_KEY, + self.path_for_intermediates.encode(), + ) + ) + + if self.tosa_debug_mode is not None: + if not self.path_for_intermediates: + raise ValueError( + "dump_debug_info() must be used in conjunction with dump_intermediate_artifacts_to()" + ) + + compile_spec.append( + CompileSpec( + ArmCompileSpec._DEBUG_MODE_KEY, self.tosa_debug_mode.name.encode() + ) + ) + + if not self.output_order_workaround: + compile_spec.append( + CompileSpec( + ArmCompileSpec._OUTPUT_REORDER_KEY, + self.output_order_workaround, + ) + ) + + if self._pipeline_config is not None and not self._pipeline_config.is_default(): + compile_spec.append( + CompileSpec( + ArmCompileSpec._TRANSFORM_PIPELINE_CONFIG_KEY, + self._pipeline_config.serialize(), + ) + ) + return compile_spec + + def get_pass_pipeline_config(self) -> ArmPassPipelineConfig: + """ + Returns configuration that controls how the Arm pass pipeline should behave. + Subclasses may override to tweak defaults for specific targets. 
+ """ + if self._pipeline_config is None: + self._pipeline_config = self._create_default_pipeline_config() + return self._pipeline_config + + def set_pass_pipeline_config(self, config: ArmPassPipelineConfig) -> None: + self._pipeline_config = config + + def _create_default_pipeline_config(self) -> ArmPassPipelineConfig: + config = ArmPassPipelineConfig() + if self.tosa_spec.is_U55_subset: + config.disable_masked_softmax() + return config + + def get_intermediate_path(self) -> str | None: + """ + Gets the path used for dumping intermediate results such as tosa and pte. + + Returns: + Path where intermediate results are saved. + """ + return self.path_for_intermediates + + def dump_intermediate_artifacts_to(self, output_path: str | None): + """ + Sets a path for dumping intermediate results during such as tosa and pte. + + Args: + output_path: Path to dump intermediate results to. + """ + self.path_for_intermediates = output_path + return self + + def dump_debug_info(self, debug_mode: DebugMode | None): + """ + Dump debugging information into the intermediates path. + + Args: + debug_mode: The debug mode to use for dumping debug information. + """ + self.tosa_debug_mode = debug_mode + return self + + def set_output_order_workaround(self, output_order_workaround: bool): + self.output_order_workaround = output_order_workaround + return self + + def get_output_order_workaround(self) -> bool: + return self.output_order_workaround + + @classmethod + @abstractmethod + def get_output_format(cls) -> str: + """Returns a constant string that is the output format of the class.""" diff --git a/backends/arm/common/debug.py b/backends/arm/common/debug.py index bca6c06d140..e5c90fe7c3d 100644 --- a/backends/arm/common/debug.py +++ b/backends/arm/common/debug.py @@ -7,8 +7,9 @@ import os from typing import Optional -import serializer.tosa_serializer as ts # type: ignore import torch + +import tosa_serializer as ts from executorch.exir.print_program import inspect_node logger = logging.getLogger(__name__) @@ -50,29 +51,20 @@ def get_node_debug_info( return output -# Output TOSA flatbuffer and test harness file -def debug_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""): +# Output TOSA flatbuffer for debugging +def debug_tosa_dump(tosa_graph: bytes, path: str, suffix: str = ""): filename = f"output{suffix}.tosa" logger.info(f"Emitting debug output to: {path=}, {suffix=}") os.makedirs(path, exist_ok=True) - fb = tosa_graph.serialize() - js = tosa_graph.writeJson(filename) - filepath_tosa_fb = os.path.join(path, filename) with open(filepath_tosa_fb, "wb") as f: - f.write(fb) + f.write(tosa_graph) if not os.path.exists(filepath_tosa_fb): raise IOError("Failed to write TOSA flatbuffer") - filepath_desc_json = os.path.join(path, f"desc{suffix}.json") - with open(filepath_desc_json, "w") as f: - f.write(js) - if not os.path.exists(filepath_desc_json): - raise IOError("Failed to write TOSA JSON") - def debug_fail( node, @@ -81,7 +73,7 @@ def debug_fail( path: Optional[str] = None, ): logger.warning("Internal error due to poorly handled node:") - if tosa_graph is not None and path is not None: - debug_tosa_dump(tosa_graph, path) + if tosa_graph is not None and path: + debug_tosa_dump(tosa_graph.serialize(), path) logger.warning(f"Debug output captured in '{path}'.") debug_node(node, graph_module) diff --git a/backends/arm/common/pipeline_config.py b/backends/arm/common/pipeline_config.py new file mode 100644 index 00000000000..bbceb3c0c60 --- /dev/null +++ b/backends/arm/common/pipeline_config.py 
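ArmCompileSpec itself is abstract; the concrete targets appear later in this diff. Below is a hypothetical minimal subclass (DemoCompileSpec and its "demo" output format are invented for illustration, as is the dump path) showing the to_list()/from_list() round trip.

from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
from executorch.backends.arm.tosa import TosaSpecification


class DemoCompileSpec(ArmCompileSpec):
    """Hypothetical subclass, only here to illustrate the round trip."""

    def __init__(self, tosa_spec: str):
        self._set_compile_specs(TosaSpecification.create_from_string(tosa_spec), [])
        self.validate()

    def validate(self):
        if self.tosa_spec is None:
            raise ValueError("tosa_spec must be set")

    @classmethod
    def get_output_format(cls) -> str:
        return "demo"


spec = DemoCompileSpec("TOSA-1.0+INT").dump_intermediate_artifacts_to("/tmp/arm-debug")
restored = DemoCompileSpec.from_list(spec.to_list())
assert restored.get_intermediate_path() == "/tmp/arm-debug"
assert restored.get_output_format() == "demo"
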
@@ -0,0 +1,59 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +from dataclasses import dataclass, fields +from enum import auto, Enum +from typing import Any + + +class SoftmaxDecompositionConfig(Enum): + MASKED = auto() + UNSTABLE = auto() + + +class FuseDuplicateUsersConfig(Enum): + ENABLED = auto() + DISABLED = auto() + + +@dataclass +class ArmPassPipelineConfig: + softmax: SoftmaxDecompositionConfig = SoftmaxDecompositionConfig.MASKED + fuse_duplicate_users: FuseDuplicateUsersConfig = FuseDuplicateUsersConfig.ENABLED + + def disable_masked_softmax(self) -> None: + self.softmax = SoftmaxDecompositionConfig.UNSTABLE + + def disable_fuse_duplicate_users(self) -> None: + self.fuse_duplicate_users = FuseDuplicateUsersConfig.DISABLED + + def is_default(self) -> bool: + return ( + self.softmax is SoftmaxDecompositionConfig.MASKED + and self.fuse_duplicate_users is FuseDuplicateUsersConfig.ENABLED + ) + + def to_dict(self) -> dict[str, str]: + return {f.name: getattr(self, f.name).name for f in fields(self)} + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ArmPassPipelineConfig": + config = cls() + for f in fields(cls): + raw_value = data.get(f.name) + if raw_value is None: + continue + enum_type = f.type + setattr(config, f.name, enum_type[raw_value]) + return config + + def serialize(self) -> bytes: + """Return a serialized representation of this config.""" + return json.dumps(self.to_dict()).encode() + + def __repr__(self): + fields = ", ".join(f"{name}={value!r}" for name, value in self.__dict__.items()) + return f"({fields})" diff --git a/backends/arm/common/type.py b/backends/arm/common/type.py new file mode 100644 index 00000000000..e53dc1ee769 --- /dev/null +++ b/backends/arm/common/type.py @@ -0,0 +1,28 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""Type checking utilities.""" + +from typing import TypeVar + +T = TypeVar("T") + + +def ensure_type(expected_type: type[T], arg: object) -> T: + """Ensure that the argument is of the expected type. + + Args: + expected_type (type[T]): The expected type. + arg (object): The argument to check. + + Returns: + T: The argument, if it is of the expected type. 
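A round-trip sketch for ArmPassPipelineConfig, assuming the module path above; the serialized JSON bytes are what end up under the transform pipeline entry of the compile spec.

import json

from executorch.backends.arm.common.pipeline_config import ArmPassPipelineConfig

config = ArmPassPipelineConfig()
assert config.is_default()

config.disable_masked_softmax()   # what the U55 default pipeline does
payload = config.serialize()      # JSON bytes stored in the compile spec

restored = ArmPassPipelineConfig.from_dict(json.loads(payload))
assert restored.softmax.name == "UNSTABLE"
assert not restored.is_default()
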
+ + """ + if isinstance(arg, expected_type): + return arg + + expected_name = getattr(expected_type, "__name__", str(expected_type)) + actual_name = type(arg).__name__ + raise TypeError(f"Expected value of type {expected_name}, got {actual_name!r}") diff --git a/backends/arm/constants.py b/backends/arm/constants.py index fd8710d3ead..0e562f12e88 100644 --- a/backends/arm/constants.py +++ b/backends/arm/constants.py @@ -29,3 +29,18 @@ DEQUANT_PER_TENSOR_OP_T, ) PER_CHANNEL_QDQ_OPS: Final = (QUANT_PER_CHANNEL_OP, DEQUANT_PER_CHANNEL_OP) + +NHWC_ORDER: Final = (0, 2, 3, 1) +NHWC_INVERSE_ORDER: Final = (0, 3, 1, 2) +NNHWC_ORDER: Final = (0, 1, 3, 4, 2) +NNHWC_INVERSE_ORDER: Final = (0, 1, 4, 2, 3) +NNNHWC_ORDER: Final = (0, 1, 2, 4, 5, 3) +NNNHWC_INVERSE_ORDER: Final = (0, 1, 2, 5, 3, 4) + +NCHW_ORDER: Final = (0, 1, 2, 3) +NNCHW_ORDER: Final = (0, 1, 2, 3, 4) +NNNCHW_ORDER: Final = (0, 1, 2, 3, 4, 5) + +HWCM_ORDER: Final = (2, 3, 0, 1) + +MAX_RANK: Final = 6 diff --git a/backends/arm/debug/TARGETS b/backends/arm/debug/TARGETS index 8ddfd9a285c..a88e3b077cd 100644 --- a/backends/arm/debug/TARGETS +++ b/backends/arm/debug/TARGETS @@ -8,7 +8,7 @@ runtime.python_library( "schema.py", ], deps = [ - "fbsource//third-party/tosa_tools/v1.00/serialization_lib/python/serializer:serializer", + "fbsource//third-party/tosa_tools:serializer", "//caffe2:torch", ], ) diff --git a/backends/arm/debug/schema.py b/backends/arm/debug/schema.py index 82f0fd6bf7e..d4df2285304 100644 --- a/backends/arm/debug/schema.py +++ b/backends/arm/debug/schema.py @@ -10,10 +10,10 @@ from dataclasses import asdict, dataclass from typing import Any, Optional -import serializer.tosa_serializer as ts # type: ignore import torch +import tosa_serializer as ts -from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec from torch.fx.traceback import NodeSource @@ -112,25 +112,20 @@ def to_dict(self) -> dict[str, Any]: class DebugHook: - def __init__(self, debug_mode: ArmCompileSpecBuilder.DebugMode) -> None: + def __init__(self, debug_mode: ArmCompileSpec.DebugMode) -> None: self._debug_events: list[DebugSchema] = [] - self.__op_id_to_name = {} self.mode = debug_mode - # Build up a mapping from TOSA 1.0 operator IDs to their names - for name, val in vars(ts.Op).items(): - self.__op_id_to_name[val] = name - - def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: int) -> DebugSchema: + def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: ts.Op) -> DebugSchema: tosa_debug_info = None # If the debug data is being embedded into the TOSA flatbuffer # do not collect TOSADebugSchema data, it's redundent - if self.mode != ArmCompileSpecBuilder.DebugMode.TOSA: + if self.mode != ArmCompileSpec.DebugMode.TOSA: tosa_debug_info = TosaDebugSchema( node_name=str(tosa_op), - operator_name=self.__op_id_to_name[tosa_op_id], - operator_id=tosa_op_id, + operator_name=str(tosa_op_id), + operator_id=int(tosa_op_id), ) aten_debug_info = ATenDebugSchema.from_node(node) diff --git a/backends/arm/ethosu/__init__.py b/backends/arm/ethosu/__init__.py index f6cc1329dfe..10b14d4a68a 100644 --- a/backends/arm/ethosu/__init__.py +++ b/backends/arm/ethosu/__init__.py @@ -3,12 +3,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
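The order/inverse pairs added to backends/arm/constants.py above are plain permutations; here is a stdlib-only check (the permute/invert helpers are local to this example, not part of the backend) that each *_INVERSE_ORDER really undoes its companion, with the constant values copied from the diff.

NHWC_ORDER, NHWC_INVERSE_ORDER = (0, 2, 3, 1), (0, 3, 1, 2)
NNHWC_ORDER, NNHWC_INVERSE_ORDER = (0, 1, 3, 4, 2), (0, 1, 4, 2, 3)
NNNHWC_ORDER, NNNHWC_INVERSE_ORDER = (0, 1, 2, 4, 5, 3), (0, 1, 2, 5, 3, 4)


def permute(shape, order):
    # Reorder a shape tuple according to a dim-order permutation.
    return tuple(shape[axis] for axis in order)


def invert(order):
    # inverse[axis] is the position that axis occupies after applying `order`.
    inverse = [0] * len(order)
    for position, axis in enumerate(order):
        inverse[axis] = position
    return tuple(inverse)


assert permute((1, 3, 224, 224), NHWC_ORDER) == (1, 224, 224, 3)
assert permute((1, 224, 224, 3), NHWC_INVERSE_ORDER) == (1, 3, 224, 224)
for order, inverse in [
    (NHWC_ORDER, NHWC_INVERSE_ORDER),
    (NNHWC_ORDER, NNHWC_INVERSE_ORDER),
    (NNNHWC_ORDER, NNNHWC_INVERSE_ORDER),
]:
    assert invert(order) == inverse
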
# -# pyre-unsafe from .backend import EthosUBackend # noqa: F401 +from .compile_spec import EthosUCompileSpec # noqa: F401 from .partitioner import EthosUPartitioner # noqa: F401 -__all__ = [ - "EthosUBackend", - "EthosUPartitioner", -] +__all__ = ["EthosUBackend", "EthosUPartitioner", "EthosUCompileSpec"] diff --git a/backends/arm/ethosu/backend.py b/backends/arm/ethosu/backend.py index c748cf96e93..bd6da08dc38 100644 --- a/backends/arm/ethosu/backend.py +++ b/backends/arm/ethosu/backend.py @@ -3,18 +3,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe # # Main implementation of AoT flow to partition and preprocess for Arm target # backends. Converts via TOSA as an intermediate form supported by AoT and # JIT compiler flows. # +"""Ahead-of-time Arm Ethos-U backend built on the shared TOSA pipeline.""" import logging from typing import final, List from executorch.backends.arm.arm_vela import vela_compile +from executorch.backends.arm.ethosu.compile_spec import EthosUCompileSpec from executorch.backends.arm.tosa.backend import TOSABackend from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult @@ -27,24 +28,30 @@ @final class EthosUBackend(BackendDetails): - """ - BackendDetails subclass for delegation to Ethos-U. Deduce the TOSA lowering from - the compile spec list by filtering out the compile spec values that are of interest - for the TOSABackend. + """BackendDetails subclass for delegation to Ethos-U. + + Deduce the TOSA lowering from the compile spec list by filtering out the + compile spec values that are of interest for the TOSABackend. + """ @staticmethod def _compile_tosa_flatbuffer( - tosa_flatbuffer: bytes, compile_spec: List[CompileSpec] + tosa_flatbuffer: bytes, compile_spec: EthosUCompileSpec ) -> bytes: + """Compile a TOSA flatbuffer into a target-specific binary stream. + + Args: + tosa_flatbuffer (bytes): Serialized TOSA graph produced by + ``TOSABackend``. + compile_spec (EthosUCompileSpec): Compile specification providing + Vela flags and intermediate paths. + + Returns: + bytes: Target-specific binary stream produced by Vela. + """ - Static helper method to do the compilation of the TOSA flatbuffer - representation to a target specific binary stream. - """ - compile_flags = [] - for spec in compile_spec: - if spec.key == "compile_flags": - compile_flags.append(spec.value.decode()) + compile_flags = compile_spec.compiler_flags if len(compile_flags) == 0: # Not testing for compile_flags correctness here, just that they are @@ -53,21 +60,43 @@ def _compile_tosa_flatbuffer( "compile_flags are required in the CompileSpec list for EthosUBackend" ) + # Vela tooling only supports flatbuffers up to 2 GiB. + max_flatbuffer_size = 2 * 1024 * 1024 * 1024 + flatbuffer_size = len(tosa_flatbuffer) + if flatbuffer_size > max_flatbuffer_size: + raise RuntimeError( + "TOSA flatbuffer is too large for Vela " + f"({flatbuffer_size} bytes > {max_flatbuffer_size} bytes limit)." + ) + # Pass on the TOSA flatbuffer to the vela compiler. 
binary = vela_compile( tosa_flatbuffer, compile_flags, - verbose=logger.getEffectiveLevel() == logging.INFO, + verbose=logger.getEffectiveLevel() <= logging.INFO, + intermediate_path=compile_spec.get_intermediate_path(), ) return binary @staticmethod def preprocess( edge_program: ExportedProgram, - compile_spec: List[CompileSpec], + compile_specs: List[CompileSpec], ) -> PreprocessResult: + """Lower the exported program and compile it for an Ethos-U target. + + Args: + edge_program (ExportedProgram): Program to lower to Ethos-U. + compile_specs (List[CompileSpec]): Serialized Ethos-U compile specs + supplied by the frontend. + + Returns: + PreprocessResult: Result containing the compiled Ethos-U binary. + + """ logger.info(f"{EthosUBackend.__name__} preprocess") + compile_spec = EthosUCompileSpec.from_list(compile_specs) # deduce TOSA compile_spec from Ethos-U compile spec. We get a new # compile spec list, containing only elements relevant for the # TOSABackend. @@ -77,7 +106,7 @@ def preprocess( # ('All backend implementation are final...'), so use composition instead. # preprocess returns the serialized TOSA flatbuffer in .processed_bytes, # which can be passed on to next compilation step. - tosa_preprocess = TOSABackend.preprocess(edge_program, tosa_compile_spec) + tosa_preprocess = TOSABackend._preprocess(edge_program, tosa_compile_spec) binary = EthosUBackend._compile_tosa_flatbuffer( tosa_preprocess.processed_bytes, compile_spec diff --git a/backends/arm/ethosu/compile_spec.py b/backends/arm/ethosu/compile_spec.py new file mode 100644 index 00000000000..e2c49840f80 --- /dev/null +++ b/backends/arm/ethosu/compile_spec.py @@ -0,0 +1,115 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec +from executorch.backends.arm.common.pipeline_config import ( # noqa: unused + ArmPassPipelineConfig, +) +from executorch.backends.arm.tosa import ( # type: ignore[import-not-found] + TosaSpecification, +) +from executorch.exir.backend.compile_spec_schema import CompileSpec + + +class EthosUCompileSpec(ArmCompileSpec): + """Compile specification for Ethos-U NPU targets.""" + + _TARGET_KEY = "target" + + def __init__( + self, + target: str, + system_config: str | None = None, + memory_mode: str | None = None, + extra_flags: list[str] | None = None, + config_ini: str | None = "Arm/vela.ini", + ): + """Normalise Ethos-U compile configuration and compiler flags. + + Args: + target (str): Ethos-U accelerator configuration (for example, + ``"ethos-u55-128"``). + system_config (str | None): System configuration name from the Vela + config file. Defaults based on ``target`` when omitted. + memory_mode (str | None): Memory mode selection from the Vela config + file. Defaults based on ``target`` when omitted. + extra_flags (list[str] | None): Additional command-line flags for + Vela. + config_ini (str | None): Path to a Vela .ini configuration file. + Defaults to ``"Arm/vela.ini"``. 
+ + """ + self.target = target + # Set vela compiler flags + if config_ini is None: + config_ini = "Arm/vela.ini" + compiler_flags = [] if extra_flags is None else extra_flags + compiler_flags.extend( + [ + f"--accelerator-config={target}", + f"--config={config_ini}", + "--output-format=raw", + "--debug-force-regor", + ] + ) + # default system config and memory mode + target_lower = self.target.lower() + if "ethos-u55" in target_lower: + if system_config is None: + system_config = "Ethos_U55_High_End_Embedded" + if memory_mode is None: + memory_mode = "Shared_Sram" + elif "ethos-u85" in target_lower: + if system_config is None: + system_config = "Ethos_U85_SYS_DRAM_Mid" + if memory_mode is None: + memory_mode = "Sram_Only" + else: + raise RuntimeError(f"Unknown ethos target: {target}") + + compiler_flags.append(f"--system-config={system_config}") + compiler_flags.append(f"--memory-mode={memory_mode}") + + # Set TOSA version. + base_tosa_version = "TOSA-1.0+INT+int16" + if "u55" in target_lower: + # Add the Ethos-U55 extension marker + base_tosa_version += "+u55" + tosa_spec = TosaSpecification.create_from_string(base_tosa_version) + + self._set_compile_specs(tosa_spec, compiler_flags) + self.validate() + + def to_list(self): + """Return compile specs including the encoded Ethos-U target.""" + compile_specs = super().to_list() + compile_specs.append(CompileSpec(self._TARGET_KEY, self.target.encode())) + return compile_specs + + @classmethod + def from_list_hook(cls, compile_spec, specs: dict[str, str]): + """Restore target-specific metadata from serialized compile specs.""" + compile_spec.target = specs.get(cls._TARGET_KEY, None) + + def validate(self): + """Validate the configuration against supported Ethos-U settings.""" + if len(self.compiler_flags) == 0: + raise ValueError( + "compile_flags are required in the CompileSpec list for EthosUBackend" + ) + if "u55" in self.target and not self.tosa_spec.is_U55_subset: + raise ValueError( + f"Target was {self.target} but tosa spec was not u55 subset." + ) + + @classmethod + def get_output_format(cls) -> str: + """Return the artifact format emitted by this compile spec.""" + return "vela" + + def _create_default_pipeline_config(self) -> ArmPassPipelineConfig: + # Any u55 subset passes are treated as tosa specification configs + # As such, they should be added to the base class default. + return super()._create_default_pipeline_config() diff --git a/backends/arm/ethosu/partitioner.py b/backends/arm/ethosu/partitioner.py index d76b29eb1d9..9acc0439171 100644 --- a/backends/arm/ethosu/partitioner.py +++ b/backends/arm/ethosu/partitioner.py @@ -3,30 +3,33 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe -from typing import final, List, Optional, Sequence +from typing import final, Optional, Sequence -from executorch.backends.arm.arm_backend import ( - is_ethosu, -) # usort: skip -from executorch.backends.arm.ethosu import EthosUBackend +from executorch.backends.arm.ethosu import EthosUBackend, EthosUCompileSpec from executorch.backends.arm.tosa.partitioner import TOSAPartitioner -from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.partitioner import DelegationSpec from torch.fx.passes.operator_support import OperatorSupportBase @final class EthosUPartitioner(TOSAPartitioner): + """ + Partitions subgraphs supported by the Arm Ethos-U backend. 
+ + Args: + compile_spec: List of CompileSpec objects for Ethos-U backend. + additional_checks: Optional sequence of additional operator support checks. + """ + def __init__( self, - compile_spec: List[CompileSpec], + compile_spec: EthosUCompileSpec, additional_checks: Optional[Sequence[OperatorSupportBase]] = None, ) -> None: - if not is_ethosu(compile_spec): - raise RuntimeError("compile spec is not targeting Ethos-U") - # Override the delegation spec for Ethos-U - self.delegation_spec = DelegationSpec(EthosUBackend.__name__, compile_spec) + self.delegation_spec = DelegationSpec( + EthosUBackend.__name__, compile_spec.to_list() + ) self.additional_checks = additional_checks + self.tosa_spec = compile_spec.tosa_spec diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py index 7b73cddad37..01d936be7ce 100644 --- a/backends/arm/operator_support/__init__.py +++ b/backends/arm/operator_support/__init__.py @@ -3,10 +3,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe from . import ( # noqa clone_dim_order_support, + control_flow_support, convolution_support, embedding_support, ethos_u55_support, @@ -16,8 +16,8 @@ pool_2d_support, reduce_sum_support, right_shift_support, - sin_cos_support, slice_copy_support, to_dim_order_copy_support, tosa_supported_operators, + where_support, ) diff --git a/backends/arm/operator_support/clone_dim_order_support.py b/backends/arm/operator_support/clone_dim_order_support.py index 1397b74bf38..ae6445c050c 100644 --- a/backends/arm/operator_support/clone_dim_order_support.py +++ b/backends/arm/operator_support/clone_dim_order_support.py @@ -2,6 +2,12 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Declare operator support for dim-order clone in TOSA. + +This module registers a support check for ``dim_order_ops._clone_dim_order`` +ensuring input/output dtypes match and the value types are FakeTensors. + +""" import logging @@ -19,6 +25,8 @@ @register_tosa_support_check class CloneSupported(SupportedTOSAOperatorCheck): + """Provide TOSA support check for ``_clone_dim_order``.""" + targets = [exir_ops.edge.dim_order_ops._clone_dim_order.default] tosa_specs = [ @@ -29,6 +37,12 @@ class CloneSupported(SupportedTOSAOperatorCheck): def is_node_tosa_supported( self, node: fx.Node, tosa_spec: TosaSpecification ) -> bool: + """Return True if the node is supported by TOSA. + + Verify the operator target, the number and types of inputs/outputs, and + check that input and output dtypes match. + + """ if node.target not in self.targets: self.reporter.report_reject(node, f"Target {node.target} is not supported.") return False diff --git a/backends/arm/operator_support/control_flow_support.py b/backends/arm/operator_support/control_flow_support.py new file mode 100644 index 00000000000..24fa34f3462 --- /dev/null +++ b/backends/arm/operator_support/control_flow_support.py @@ -0,0 +1,162 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
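A usage sketch tying the new Ethos-U compile spec and partitioner together, assuming the import paths in this diff and the usual torch.export / executorch.exir entry points; the model and the ethos-u85-256 target are illustrative only.

import torch

from executorch.backends.arm.ethosu import EthosUCompileSpec, EthosUPartitioner
from executorch.exir import to_edge_transform_and_lower


class SmallModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x) + 1


# Defaults fill in system_config / memory_mode for the chosen target and pick
# the TOSA INT profile; the spec round-trips through the CompileSpec list form.
compile_spec = EthosUCompileSpec("ethos-u85-256")
assert compile_spec.get_output_format() == "vela"
assert EthosUCompileSpec.from_list(compile_spec.to_list()).target == "ethos-u85-256"

# Real flows quantize the exported program first; unquantized float ops would
# simply be left un-delegated by the INT-only Ethos-U spec.
exported = torch.export.export(SmallModel(), (torch.randn(1, 3, 8, 8),))
edge = to_edge_transform_and_lower(
    exported, partitioner=[EthosUPartitioner(compile_spec)]
)
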
+ +import typing +from typing import cast + +import torch +import torch.fx as fx + +from executorch.backends.arm._passes.arm_pass_utils import is_submodule_node +from executorch.backends.arm.constants import DQ_OPS, Q_OPS +from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.arm.tosa.specification import Tosa_1_00 +from executorch.exir import ExportedProgram +from executorch.exir.backend.utils import WhyNoPartitionReporter + +from torch.fx.passes.operator_support import OperatorSupportBase + + +def _fully_partitioned(submodule: fx.GraphModule) -> bool: + partition_tag = None + for submodule_node in submodule.graph.nodes: + if submodule_node.op == "call_function": + # Input Q ops and output DQ ops will be de-tagged even if the submodule is fully supported. + if ( + submodule_node.target in Q_OPS + and list(submodule_node.all_input_nodes)[0].op == "placeholder" + ): + continue + if ( + submodule_node.target in DQ_OPS + and list(submodule_node.users)[0].op == "output" + ): + continue + if "delegation_tag" not in submodule_node.meta: + return False + if partition_tag is None: + partition_tag = submodule_node.meta["delegation_tag"] + elif submodule_node.meta["delegation_tag"] != partition_tag: + return False + return True + + +def _submodules_fully_partitioned( + node: fx.Node, exported_program: ExportedProgram +) -> bool: + """Returns whether the submodule arguments to a cond node were fully partitioned. + Updates "val" meta of the submodules if they are. + """ + match node.target: + case torch.ops.higher_order.cond: + submodule_args = node.args[1:3] + case torch.ops.higher_order.while_loop: + submodule_args = node.args[0:2] + case _: + raise ValueError(f"Unexpected target: {node.target}") + cond_submodules = ( + ( + exported_program.graph_module.get_submodule( + str(cast(torch.fx.Node, submodule_node).target) + ), + cast(torch.fx.Node, submodule_node), + ) + for submodule_node in submodule_args + ) + for submodule, submodule_node in cond_submodules: + submodule = cast(torch.fx.GraphModule, submodule) + + if _fully_partitioned(submodule): + submodule_node.meta["val"] = submodule.graph.output_node().meta["val"] + else: + return False + return True + + +def _tosa_spec_supports_cf(tosa_spec: TosaSpecification) -> bool: + if not isinstance(tosa_spec, Tosa_1_00): + return False + return tosa_spec.support_extension("cf") + + +class ControlFlowSubmoduleSupported(OperatorSupportBase): + """Check whether control flow submodule args should be partitioned. 
+ Applies control-flow extension constraints before allowing delegation.""" + + def __init__( + self, + exported_program: ExportedProgram, + tosa_spec: TosaSpecification, + reporter: WhyNoPartitionReporter, + ): + self.exported_program = exported_program + self.reporter = reporter + self.tosa_spec = tosa_spec + super().__init__() + + def is_node_supported( + self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node + ) -> bool: + if is_submodule_node(node): + if not _tosa_spec_supports_cf(self.tosa_spec): + self.reporter.report_reject( + node, + f"TOSA spec {self.tosa_spec} does not support control flow extension.", + ) + return False + for user in node.users: + if user.target not in ControlFlowOpSupported._targeted_ops: + self.reporter.report_reject( + node, f"Submodule had unsupported user {user}" + ) + return False + if not _submodules_fully_partitioned(user, self.exported_program): + self.reporter.report_reject( + node, "One submodule was not fully partitioned" + ) + return False + return True + return False + + +class ControlFlowOpSupported(OperatorSupportBase): + """Check whether control flow ops should be partitioned. + Applies control-flow extension constraints before allowing delegation.""" + + _targeted_ops = { + torch.ops.higher_order.cond, + torch.ops.higher_order.while_loop, + } + + def __init__( + self, + exported_program: ExportedProgram, + tosa_spec: TosaSpecification, + reporter: WhyNoPartitionReporter, + ): + self.exported_program = exported_program + self.reporter = reporter + self.tosa_spec = tosa_spec + super().__init__() + + def is_node_supported( + self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node + ) -> bool: + if node.target in self._targeted_ops: + if not _tosa_spec_supports_cf(self.tosa_spec): + self.reporter.report_reject( + node, + f"TOSA spec {self.tosa_spec} does not support control flow extension.", + ) + return False + + if not _submodules_fully_partitioned(node, self.exported_program): + self.reporter.report_reject( + node, "Submodule was not fully partitioned." + ) + return False + return True + + return False diff --git a/backends/arm/operator_support/convolution_support.py b/backends/arm/operator_support/convolution_support.py index 6e9d3b3528e..f335c5046f5 100644 --- a/backends/arm/operator_support/convolution_support.py +++ b/backends/arm/operator_support/convolution_support.py @@ -2,6 +2,12 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Declare operator support for ``aten.convolution`` in TOSA. + +Provide general checks and hardware-specific constraints (e.g., U55 subset) for +convolution nodes prior to delegation to the TOSA backend. + +""" from typing import cast @@ -18,6 +24,8 @@ @register_tosa_support_check class ConvolutionSupported(SupportedTOSAOperatorCheck): + """Provide TOSA support check for convolutions.""" + targets = [exir_ops.edge.aten.convolution.default] tosa_specs = [ @@ -25,8 +33,15 @@ class ConvolutionSupported(SupportedTOSAOperatorCheck): TosaSpecification.create_from_string("TOSA-1.0+FP"), ] - def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): + def is_node_tosa_supported( + self, node: fx.Node, tosa_spec: TosaSpecification + ) -> bool: + """Return True if the node is supported by TOSA. + Reject transposed convolutions and convolutions with non-zero output + padding. Apply additional hardware-specific constraints for U55. 
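A plain-data sketch of the _fully_partitioned rule above, with dicts standing in for FX nodes and a boundary_qdq flag standing in for the input-Q/output-DQ exemption.

def fully_partitioned(call_function_nodes):
    """Every call_function node must carry one and the same delegation tag;
    Q/DQ nodes sitting on the submodule boundary are exempt, as above."""
    partition_tag = None
    for node in call_function_nodes:
        if node.get("boundary_qdq", False):
            continue
        tag = node.get("delegation_tag")
        if tag is None:
            return False
        if partition_tag is None:
            partition_tag = tag
        elif tag != partition_tag:
            return False
    return True


assert fully_partitioned(
    [{"boundary_qdq": True}, {"delegation_tag": "tag0"}, {"delegation_tag": "tag0"}]
)
assert not fully_partitioned([{"delegation_tag": "tag0"}, {"delegation_tag": "tag1"}])
assert not fully_partitioned([{"delegation_tag": "tag0"}, {}])
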
+ + """ # Not implemented transposed = cast(bool, node.args[6]) output_padding = cast(list[int], node.args[7]) @@ -46,9 +61,19 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): else: return True - def _is_node_supported_u55(self, node: fx.Node): - """Hardware constraints for Ethos-U-55 case, Vela 4.2.0 (25.02 release)""" + def _is_node_supported_u55(self, node: fx.Node) -> bool: + """Enforce Ethos-U55-specific constraints (Vela 4.2.0). + + Check channel dimensions, kernel sizes, and stride/pad/dilation + combinations permitted on U55. + Args: + node (fx.Node): Convolution node to validate. + + Returns: + bool: True if supported; otherwise, False. + + """ shape_in = cast(torch.Tensor, node.all_input_nodes[0].meta["val"]).shape shape_out = node.meta["val"].shape kernel = cast(fx.Node, node.args[1]).meta["val"].shape @@ -98,13 +123,17 @@ def _is_node_supported_u55(self, node: fx.Node): return True def _stride_condition(self, node: fx.Node) -> bool: - """This condition is somewhat complex but boils down - to not supporting stride > 3, unless we have some special conditions. - This condition is a simplified, relaxed version of the hardware constraint, - since the actual constraint requires information not available - here (without a lot of work). + """Check a simplified stride/padding/dilation constraint. + + Disallow strides greater than 3 unless there is no padding and the + dilation is 1. For 3D convolutions, enforce ``stride_z <= 1``. + + Args: + node (fx.Node): Convolution node to evaluate. + + Returns: + bool: True if the condition is satisfied. - This means that we might accept ops that are not actually supported. """ strides = cast(list[int], node.args[3]) has_padding = any(pad > 0 for pad in cast(list[int], node.args[4])) diff --git a/backends/arm/operator_support/embedding_support.py b/backends/arm/operator_support/embedding_support.py index bf95014e575..3ad17012cbb 100644 --- a/backends/arm/operator_support/embedding_support.py +++ b/backends/arm/operator_support/embedding_support.py @@ -2,7 +2,12 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Declare operator support for ``aten.embedding`` in TOSA. +Permit embeddings with int32 indices (TOSA lacks int64 support); other dtypes +are rejected by this check. + +""" import torch @@ -17,6 +22,8 @@ @register_tosa_support_check class EmbeddingSupported(SupportedTOSAOperatorCheck): + """Provide TOSA support check for ``aten.embedding``.""" + targets = [exir_ops.edge.aten.embedding.default] tosa_specs = [ @@ -27,11 +34,20 @@ class EmbeddingSupported(SupportedTOSAOperatorCheck): def is_node_tosa_supported( self, node: fx.Node, tosa_spec: TosaSpecification ) -> bool: # type: ignore[override, misc] - # Note aten.embedding.default requires int64 indices and TOSA does not support it. - # Int32 indices here for aten.embedding.default is ok since it will be decomposed into ops that can handle it. - assert ( - len(node.all_input_nodes) == 2 - ), "Number of inputs to aten.embedding is not 2" + """Return True if the node is supported by TOSA. + + PyTorch's ``aten.embedding`` typically takes int64 indices, but for + TOSA we only allow int32 indices. The export path decomposes the op so + that int32 indices are ok. 
+ + """ + if len(node.all_input_nodes) != 2: + self.reporter.report_reject( + node, + (f"Expected exactly two input nodes, got {len(node.all_input_nodes)}"), + ) + return False + indices_val = node.all_input_nodes[1].meta["val"] indices_dtype = indices_val.dtype diff --git a/backends/arm/operator_support/ethos_u55_support.py b/backends/arm/operator_support/ethos_u55_support.py index bf9e29d5cb7..bd43233454f 100644 --- a/backends/arm/operator_support/ethos_u55_support.py +++ b/backends/arm/operator_support/ethos_u55_support.py @@ -2,8 +2,14 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Provide Ethos-U55 specific operator support checks. + +Contains dtype validation, explicit unsupported-op filtering, and shape/ +permutation constraints for view and permute operations when targeting the +Ethos-U55 subset of TOSA. + +""" -# pyre-unsafe import typing from typing import cast @@ -12,6 +18,9 @@ import torch.fx as fx from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm._passes.convert_permute_singleton_to_view_pass import ( + is_singleton_permutation, +) from executorch.backends.arm._passes.insert_table_ops import TableOps from executorch.backends.arm.operators.op_permute import transform_permutation_vector from executorch.backends.arm.tosa.utils import tosa_shape @@ -21,6 +30,19 @@ def _try_determine_dtype(node: fx.Node) -> torch.dtype | None: + """Return an inferred dtype for a node when possible. + + Uses fake tensor metadata and nearby quantize/dequantize nodes to infer the + integer dtype used by the operator. Returns ``None`` when the dtype cannot + be determined reliably. + + Args: + node (fx.Node): FX node to inspect. + + Returns: + torch.dtype | None: Inferred dtype or ``None`` if unknown. + + """ dtype = get_first_fake_tensor(node).dtype if not dtype.is_floating_point: return dtype @@ -34,17 +56,34 @@ def _try_determine_dtype(node: fx.Node) -> torch.dtype | None: class EthosU55DtypeSupport(OperatorSupportBase): + """Validate dtypes for U55-supported operators. + + Ensures operators use a supported integer dtype according to U55 + constraints, with specific rules for convolution, matmul, and table ops. + + Attributes: + reporter (WhyNoPartitionReporter): Reporter for rejection reasons. + + """ def __init__(self, reporter: WhyNoPartitionReporter): + """Initialize the check with a reporter. + + Args: + reporter (WhyNoPartitionReporter): Reporter for rejection reasons. + + """ super().__init__() self.reporter = reporter targeted_ops_i8_i16_i32 = [ exir_ops.edge.aten.cat.default, + exir_ops.edge.aten.expand_copy.default, exir_ops.edge.aten.repeat.default, exir_ops.edge.aten.constant_pad_nd.default, exir_ops.edge.aten.view.default, exir_ops.edge.aten.permute.default, + exir_ops.edge.aten.permute_copy.default, ] target_ops_i8 = tuple(TableOps.included_ops()) @@ -52,7 +91,20 @@ def __init__(self, reporter: WhyNoPartitionReporter): def is_node_supported( # noqa: C901 self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: + """Return True if the node uses supported dtypes. + + Applies per-operator dtype rules for U55, including specialized input + and weight constraints for convolution and int8-only checks for table + operations and matmul variants. + Args: + submodules (typing.Mapping[str, torch.nn.Module]): Exported modules. + node (fx.Node): FX node to check. + + Returns: + bool: True if supported; otherwise, False. 
+ + """ dtype = _try_determine_dtype(node) if dtype is None: # If we couldn't determine dtype, just return ok. @@ -66,9 +118,9 @@ def is_node_supported( # noqa: C901 return False if node.target in self.target_ops_i8: - if dtype not in (torch.int8,): + if dtype not in (torch.int8, torch.int16): self.reporter.report_reject( - node, f"Unsupported dtype {dtype} (Supports i8)." + node, f"Unsupported dtype {dtype} (Supports i8, i16)." ) return False @@ -112,10 +164,12 @@ def is_node_supported( # noqa: C901 class EthosU55NotSupported(OperatorSupportBase): - """ - Certain operators are not supported on U55. These are listed in `unsupported_ops`. - The comment mentions the unsupported TOSA operator that the aten operator maps to where it is not obvious. - For unimplemented operators, this is the anticipated mapping, and it might be incorrect. + """Reject operators not supported by Ethos-U55. + + The ``unsupported_ops`` list contains aten ops that either map to TOSA + operators the U55 cannot run or remain unimplemented. The mapping comments + capture expected TOSA equivalents when not obvious. + """ unsupported_ops = [ @@ -128,13 +182,15 @@ class EthosU55NotSupported(OperatorSupportBase): exir_ops.edge.aten.bitwise_and.Scalar, exir_ops.edge.aten.bitwise_or.Scalar, exir_ops.edge.aten.bitwise_xor.Scalar, - exir_ops.edge.aten.bitwise_not, + exir_ops.edge.aten.bitwise_not.default, exir_ops.edge.aten.logical_and.default, exir_ops.edge.aten.logical_or.default, exir_ops.edge.aten.logical_xor.default, exir_ops.edge.aten.logical_not.default, exir_ops.edge.aten.amax.default, # REDUCE_MAX exir_ops.edge.aten.amin.default, # REDUCE_MIN + exir_ops.edge.aten.conv3d.default, # CONV3D + exir_ops.edge.aten.conv3d.padding, # CONV3D (deprecated alias) exir_ops.edge.aten.eq.Tensor, exir_ops.edge.aten.eq.Scalar, exir_ops.edge.aten.ge.Tensor, @@ -165,12 +221,27 @@ class EthosU55NotSupported(OperatorSupportBase): ] def __init__(self, reporter: WhyNoPartitionReporter): + """Initialize the check with a reporter. + + Args: + reporter (WhyNoPartitionReporter): Reporter for rejection reasons. + + """ self.reporter = reporter def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: + """Return False for nodes explicitly unsupported on U55. + Args: + submodules (typing.Mapping[str, torch.nn.Module]): Exported modules. + node (fx.Node): FX node to check. + + Returns: + bool: False if ``node.target`` is in ``unsupported_ops``; else True. + + """ if node.target in self.unsupported_ops: self.reporter.report_reject(node, "Op is not supported on U55.") return False @@ -182,12 +253,37 @@ def is_node_supported( class EthosU55ViewCheck(OperatorSupportBase): + """Validate view/select shapes and dtypes for U55. + + Performs lightweight checks on output shape rank and product constraints, + with awareness that transposes may be inserted around view/select during + lowering to channels-last. + + Attributes: + reporter (WhyNoPartitionReporter): Reporter for rejection reasons. + + """ def __init__(self, reporter: WhyNoPartitionReporter): + """Initialize the check with a reporter. + + Args: + reporter (WhyNoPartitionReporter): Reporter for rejection reasons. + + """ super().__init__() self.reporter = reporter def axes_product(self, nhwc_shape: shape_t) -> int: + """Return the product of all axes in ``nhwc_shape``. + + Args: + nhwc_shape (list[int]): Shape in NHWC order. + + Returns: + int: Product of the axis sizes. 
+ + """ product = 1 for axes in nhwc_shape: product *= axes @@ -197,26 +293,27 @@ def axes_product(self, nhwc_shape: shape_t) -> int: def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - """ - Check whether a given view node is supported on U55. + """Check whether a given view/select node is U55-supported. Currently only checks dtypes and product of axes. - It is not the view operator itself that is not supported on U55. In order for the - view operator to be compatible with the channels-last format of TosaBackend, - transposes may need to be inserted before and after the view op. If that happens - and that transpose operator does not adhere to the limitations then it will - result in the following error: + It is not the view operator itself that is not supported on U55. In + order for the view operator to be compatible with the channels-last + format of TosaBackend, transposes may need to be inserted before and + after the view op. If that happens and that transpose operator does not + adhere to the limitations then it will result in the following error: CPU performance estimation for "Transpose" not implemented. ... CPU operations are not supported for GraphAPI input Args: - node: The FX node representing the view_copy operator. + submodules (typing.Mapping[str, torch.nn.Module]): Exported modules. + node (fx.Node): FX node for ``view_copy`` or ``select``. Returns: - False if the operator is not support and True if it is supported. + bool: False if rejected by constraints; otherwise, True. + """ # Select decomposes into squeeze, which in turn becomes a view. Therefore, # perform the same check on select operators as view operators. @@ -236,18 +333,20 @@ def is_node_supported( shape = input_node.meta["val"].shape rank = len(shape) if not -rank <= dim < rank: - raise IndexError( - f"Dim {dim} is outside of the range for tensor '{node.target}' of " - f"rank {rank}" + self.reporter.report_reject( + node, + (f"Dimension {dim} out of range for rank {rank}."), ) + return False dim = dim % rank size = shape[dim] if not -size <= index < size: - raise IndexError( - f"Index {index} is outside of the range for dim {dim} with size " - f"{size} for tensor {node.target}" + self.reporter.report_reject( + node, + (f"Index {index} out of range for dim {dim} with size {size}."), ) + return False index = index % size # Shape after squeeze. This may get converted into a view which may become @@ -277,14 +376,40 @@ def is_node_supported( class EthosU55TransposeCheck(OperatorSupportBase): + """Validate permute nodes against U55 reshape/transpose limits. + + Applies dtype- and rank-specific constraints to permutations. Tests both + NCHW and NHWC interpretations for rank-3/4 shapes since dim order is unknown + at partition time. + + Attributes: + reporter (WhyNoPartitionReporter): Reporter for rejection reasons. + + """ def __init__(self, reporter: WhyNoPartitionReporter): + """Initialize the check with a reporter. + + Args: + reporter (WhyNoPartitionReporter): Reporter for rejection reasons. + + """ super().__init__() self.reporter = reporter def _pad_to_rank_4( self, shape: shape_t, permutation: list[int] ) -> tuple[shape_t, shape_t]: + """Pad shape/permutation to rank 4 by prepending ones/indices. + + Args: + shape (list[int]): Original shape. + permutation (list[int]): Original permutation indices. + + Returns: + tuple[list[int], list[int]]: Padded shape and permutation. 
+ + """ diff = 4 - len(shape) padded_shape = [1] * diff + shape for i in range(len(permutation)): @@ -293,6 +418,15 @@ def _pad_to_rank_4( return padded_shape, padded_permutation def axes_product(self, nhwc_shape: shape_t) -> int: + """Return the product of all axes in ``nhwc_shape``. + + Args: + nhwc_shape (list[int]): Shape in NHWC order. + + Returns: + int: Product of the axis sizes. + + """ product = 1 for axes in nhwc_shape: product *= axes @@ -301,12 +435,19 @@ def axes_product(self, nhwc_shape: shape_t) -> int: def _permute_constraint_i8_i16( self, nhwc_shape: list[int], permutation: list[int] ) -> bool: - """Returns True if the constraints are ok.""" + """Return True if permutation meets i8/i16 constraints.""" N, H, W, C = nhwc_shape + + if is_singleton_permutation(nhwc_shape, permutation): + return True + match permutation: case (0, 1, 2, 3): # NHWC -> NHWC return True - case (0, 2, 1, 3) | (0, 1, 3, 2) | (0, 3, 1, 2): # NHWC -> NWHC, NHCW, NCWH + case ( + (0, 2, 1, 3) | (0, 1, 3, 2) | (0, 3, 1, 2) | (0, 2, 3, 1) | (0, 3, 2, 1) + ): + # NHWC -> NWHC, NHCW, NCWH, NCHW, NCHW -> NHWC return N * H <= 65536 and W <= 65536 and C <= 65536 case _: return self.axes_product(nhwc_shape) <= 65536 @@ -314,7 +455,7 @@ def _permute_constraint_i8_i16( def _permute_constraint_i32( self, nhwc_shape: list[int], permutation: list[int] ) -> bool: - """Returns True if the constraints are ok.""" + """Return True if permutation meets i32 constraints.""" N, H, W, C = nhwc_shape match permutation: case (0, 1, 2, 3): # NHWC -> NHWC @@ -327,6 +468,7 @@ def _permute_constraint_i32( return False def _permute_constraint(self, shape, permutation, dtype): + """Return True if permutation meets dtype-specific constraints.""" if dtype in (torch.int8, torch.int16): return self._permute_constraint_i8_i16(shape, permutation) if dtype == torch.int32: @@ -336,7 +478,19 @@ def _permute_constraint(self, shape, permutation, dtype): def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: + """Return True if a permute node satisfies U55 constraints. + + Tests both NCHW and NHWC interpretations for rank-3/4 shapes, and + applies dtype-specific limits to shapes and permutations. + + Args: + submodules (typing.Mapping[str, torch.nn.Module]): Exported modules. + node (fx.Node): FX node to check. + + Returns: + bool: True if supported; otherwise, False. + """ if not node.target == exir_ops.edge.aten.permute_copy.default: return True @@ -382,3 +536,63 @@ def is_node_supported( return False return True + + +class EthosU55CastCheck(OperatorSupportBase): + """Reject unsupported casts on U55. + + U55 does not support casting from INT32 or any casts involving BOOL. Note that + casting from one dtype to the same dtype is a no-op and is supported. + + + Attributes: + reporter (WhyNoPartitionReporter): Reporter for rejection reasons. + + """ + + targets = [ + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + ] + + def __init__(self, reporter: WhyNoPartitionReporter): + """Initialize the check with a reporter. + + Args: + reporter (WhyNoPartitionReporter): Reporter for rejection reasons. + + """ + super().__init__() + self.reporter = reporter + + def is_node_supported( + self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node + ) -> bool: + """Return True if the node satisfies the cast constraints of U55. + + Args: + submodules (typing.Mapping[str, torch.nn.Module]): Exported modules. + node (fx.Node): FX node to check. 
+ + Returns: + bool: True if supported; otherwise, False. + + """ + if node.target not in self.targets: + return True + input_dtype = get_first_fake_tensor(node.all_input_nodes[0]).dtype + output_dtype = get_first_fake_tensor(node).dtype + if input_dtype == output_dtype: + # This is ok as this will not result in a cast + return True + if input_dtype in (torch.bool, torch.int32): + self.reporter.report_reject( + node, f"Casting from {input_dtype} is not supported on U55." + ) + return False + if output_dtype in (torch.bool,): + self.reporter.report_reject( + node, f"Casting to {output_dtype} is not supported on U55." + ) + return False + + return True diff --git a/backends/arm/operator_support/index_select_support.py b/backends/arm/operator_support/index_select_support.py index 79f1d154a14..a83151adab7 100644 --- a/backends/arm/operator_support/index_select_support.py +++ b/backends/arm/operator_support/index_select_support.py @@ -2,7 +2,12 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Declare operator support for ``aten.index_select`` in TOSA. +Accept int32 indices and restrict supported weight shapes to 2D or 3D with a +unit batch dimension. + +""" import torch import torch.fx as fx from executorch.backends.arm.operator_support.tosa_supported_operators import ( @@ -15,6 +20,8 @@ @register_tosa_support_check class IndexSelectSupported(SupportedTOSAOperatorCheck): + """Provide TOSA support check for ``aten.index_select``.""" + targets = [exir_ops.edge.aten.index_select.default] tosa_specs = [ @@ -25,7 +32,12 @@ class IndexSelectSupported(SupportedTOSAOperatorCheck): def is_node_tosa_supported( self, node: fx.Node, tosa_spec: TosaSpecification ) -> bool: # type: ignore[override, misc] + """Return True if the node is supported by TOSA. + + Require int32 indices and limit weight shapes to 2D or 3D with a leading + dimension of 1. + """ weights_shape = node.all_input_nodes[0].meta["val"].shape indices_val = node.all_input_nodes[1].meta["val"] indices_dtype = indices_val.dtype diff --git a/backends/arm/operator_support/index_tensor_support.py b/backends/arm/operator_support/index_tensor_support.py index 4b226a9c407..5de70c0a2de 100644 --- a/backends/arm/operator_support/index_tensor_support.py +++ b/backends/arm/operator_support/index_tensor_support.py @@ -2,12 +2,19 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Provide TOSA support checks for ``aten.index.Tensor``. + +Reject unsupported patterns such as high-rank index tensors, front-positioned +slice/ellipsis/None markers, and cases that exceed ``int32`` element limits. + +""" import math import torch import torch.fx as fx from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm.common.type import ensure_type from executorch.backends.arm.operator_support.tosa_supported_operators import ( register_tosa_support_check, SupportedTOSAOperatorCheck, @@ -18,7 +25,8 @@ @register_tosa_support_check class IndexTensorSupported(SupportedTOSAOperatorCheck): - """ + """Prevent partitioning of unsupported ``index.Tensor`` usages. + This support check is intended to prevent the partitioning of currently unsupported usages of the index.Tensor operator. 
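A rough illustration of the usage patterns the guards below distinguish; the tensors, shapes, and exact lowering are assumptions made for the sketch:

```python
import torch

t = torch.rand(4, 5, 6)          # hypothetical value tensor
idx = torch.arange(3)            # rank-1 indexing tensor

t[idx]                           # indexing tensor leads: candidate for delegation
t[idx, idx]                      # several rank-1 indexing tensors: also a candidate
t[1:3, idx]                      # slice first: typically lowered with a leading None entry, rejected
t[None, idx]                     # unsqueeze before the indexing tensor: rejected
t[idx.reshape(1, 1, 1, 3)]       # rank-4 indexing tensor: rejected
```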
@@ -95,6 +103,7 @@ class IndexTensorSupported(SupportedTOSAOperatorCheck): t[1:3, torch.arange(5), 2:3, torch.arange(3).reshape(3,1)] are also possible and can result in some unintuitive behaviors where batching and indexing are mixed together. + """ targets = [exir_ops.edge.aten.index.Tensor] @@ -107,20 +116,45 @@ class IndexTensorSupported(SupportedTOSAOperatorCheck): def is_node_tosa_supported( self, node: fx.Node, tosa_spec: TosaSpecification ) -> bool: # type: ignore[override, misc] + """Return True if ``aten.index.Tensor`` usage fits supported patterns. + + Enforces the following constraints: + - No ``None`` (unsqueeze), slice, or ellipsis before an indexing tensor. + - Indexing tensors have rank <= 3. + - The value tensor element count fits in ``int32``. + + """ indices = node.args[1] for index in indices: # type: ignore[union-attr] # Usage 2 guard if index is None: + self.reporter.report_reject( + node, + ( + "None (from slice/unsqueeze/ellipsis) before an indexing tensor" + " is not supported." + ), + ) return False # Usage 1 guard - fake_tensor = get_first_fake_tensor(index) # type: ignore[arg-type] + index = ensure_type(torch.fx.Node, index) + fake_tensor = get_first_fake_tensor(index) if len(fake_tensor.size()) > 3: + self.reporter.report_reject( + node, + ("Indexing tensors of rank >= 4 is not supported."), + ) return False # Usage 3 guard - total_vals = math.prod(get_first_fake_tensor(node.args[0]).shape) # type: ignore[arg-type] + input_node = ensure_type(torch.fx.Node, node.args[0]) + total_vals = math.prod(get_first_fake_tensor(input_node).shape) if total_vals > torch.iinfo(torch.int32).max: + self.reporter.report_reject( + node, + ("Value size exceeds int32 range; would overflow flattened indexing."), + ) return False return True diff --git a/backends/arm/operator_support/minmax_support.py b/backends/arm/operator_support/minmax_support.py index edbf7f61818..8ba5d9335dc 100644 --- a/backends/arm/operator_support/minmax_support.py +++ b/backends/arm/operator_support/minmax_support.py @@ -2,6 +2,12 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Declare operator support for min/max along a dimension in TOSA. + +Provide support checks ensuring that argmax/argmin indices are not consumed, +restricting to float profiles until index quantization is supported. + +""" import torch.fx as fx from executorch.backends.arm.operator_support.tosa_supported_operators import ( @@ -14,6 +20,8 @@ @register_tosa_support_check class MinMaxSupported(SupportedTOSAOperatorCheck): + """Provide TOSA support check for ``aten.max.dim`` and ``aten.min.dim``.""" + targets = [ exir_ops.edge.aten.max.dim, exir_ops.edge.aten.min.dim, @@ -24,7 +32,16 @@ class MinMaxSupported(SupportedTOSAOperatorCheck): TosaSpecification.create_from_string("TOSA-1.0+FP"), ] - def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): + def is_node_tosa_supported( + self, node: fx.Node, tosa_spec: TosaSpecification + ) -> bool: + """Return True if the node is supported by TOSA. + + Allow max/min when the argmax/argmin output is unused or dropped (i.e., + only the value is consumed). Disallow cases where arg indices are + further used. 
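+
+        Example (illustrative): ``values, _ = x.max(dim=1)`` only consumes the
+        values and is accepted, whereas code that goes on to use the returned
+        indices is rejected.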
+ + """ if node.target in [exir_ops.edge.aten.max.dim, exir_ops.edge.aten.min.dim]: no_argmax = len(node.users) == 1 no_argmax_users = (len(node.users) == 2) and ( @@ -32,6 +49,13 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): ) if not (no_argmax or no_argmax_users): + self.reporter.report_reject( + node, + ( + "Using the indices output is not supported; only usage of the " + "values output is supported." + ), + ) return False return True diff --git a/backends/arm/operator_support/pool_2d_support.py b/backends/arm/operator_support/pool_2d_support.py index ff453741f1f..c0428e45e03 100644 --- a/backends/arm/operator_support/pool_2d_support.py +++ b/backends/arm/operator_support/pool_2d_support.py @@ -2,6 +2,12 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Provide TOSA support checks for 2D pooling. + +Validate ``avg_pool2d`` and ``max_pool2d_with_indices`` against U55 profile +constraints including kernel size, stride, padding, and dimensionality. + +""" from typing import cast @@ -20,16 +26,48 @@ def kernel_check(kernel: tuple[int, int]) -> bool: + """Check if kernel size is within U55 constraints. + + Checks that ``kernel_x * kernel_y`` is in ``[1, 65536]`` and + ``kernel_y`` is in ``[1, 256]`` as required by the U55 profile. + + Args: + kernel (tuple[int, int]): Kernel height and width ``(kh, kw)``. + + Returns: + bool: True if the kernel passes validation. + + """ if not (1 <= kernel[0] * kernel[1] <= 65536): return False return 1 <= kernel[1] <= 256 def stride_check(strides: tuple[int, int]) -> bool: + """Check if strides are within U55 constraints. + + Args: + strides (tuple[int, int]): Vertical and horizontal strides. + + Returns: + bool: True if each stride is in ``[1, 3]``. + + """ return all(1 <= stride <= 3 for stride in strides) def dim_check(shape=torch.Size) -> bool: + """Check if non-batch dims are within U55 constraints. + + Verifies that all dimensions except batch are in ``[1, 65536]``. + + Args: + shape (torch.Size): Input tensor shape. + + Returns: + bool: True if all checked dimensions pass. + + """ check = True for dim in shape[1:]: check &= 1 <= dim <= 65536 @@ -38,6 +76,13 @@ def dim_check(shape=torch.Size) -> bool: @register_tosa_support_check class AvgPool2dSupported(SupportedTOSAOperatorCheck): + """Provide TOSA support checks for ``aten.avg_pool2d``. + + Applies additional constraints when targeting the U55 subset, including + limits on kernel size, stride, padding behavior, and tensor ranks. + + """ + targets = [ exir_ops.edge.aten.avg_pool2d.default, ] @@ -48,6 +93,12 @@ class AvgPool2dSupported(SupportedTOSAOperatorCheck): ] def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): + """Return True if ``avg_pool2d`` satisfies U55 constraints. + + Computes the effective TOSA padding (depending on ``count_include_pad`` + and ``divisor_override``) and validates kernel, stride, and shape limits. + + """ if not tosa_spec.is_U55_subset: return True @@ -115,6 +166,13 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): @register_tosa_support_check class MaxPool2dSupported(SupportedTOSAOperatorCheck): + """Provide TOSA support checks for ``aten.max_pool2d_with_indices``. + + Applies additional constraints when targeting the U55 subset, including + limits on kernel size, stride, and tensor ranks. 
+ + """ + targets = [ exir_ops.edge.aten.max_pool2d_with_indices.default, ] @@ -125,6 +183,9 @@ class MaxPool2dSupported(SupportedTOSAOperatorCheck): ] def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): + """Return True if ``max_pool2d_with_indices`` satisfies U55 + constraints. + """ if not tosa_spec.is_U55_subset: return True diff --git a/backends/arm/operator_support/reduce_sum_support.py b/backends/arm/operator_support/reduce_sum_support.py index 4ff8f54ad69..02e9e0db90e 100644 --- a/backends/arm/operator_support/reduce_sum_support.py +++ b/backends/arm/operator_support/reduce_sum_support.py @@ -2,7 +2,11 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Declare operator support for ``aten.sum.dim_IntList`` in TOSA. +Provide shape constraints for U55 subsets; otherwise allow reductions. + +""" from typing import cast import torch.fx as fx @@ -16,6 +20,8 @@ @register_tosa_support_check class SumSupported(SupportedTOSAOperatorCheck): + """Provide TOSA support check for sum over dimensions.""" + targets = [exir_ops.edge.aten.sum.dim_IntList] tosa_specs = [ @@ -23,14 +29,28 @@ class SumSupported(SupportedTOSAOperatorCheck): TosaSpecification.create_from_string("TOSA-1.0+FP"), ] - def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): + def is_node_tosa_supported( + self, node: fx.Node, tosa_spec: TosaSpecification + ) -> bool: + """Return True if the node is supported by TOSA. + + On U55 subsets, enforce bounds on the reduced dimension and the products + of sizes before/after the reduction axis. On other targets, accept the + operation unconditionally. + + """ if not tosa_spec.is_U55_subset: return True # U55 case, Vela 4.2.0 (25.02 release) input_shape = node.all_input_nodes[0].meta["val"].shape - dim_list = cast(list[int], node.args[1]) - dim_list = [dim % len(input_shape) for dim in dim_list] + + if node.args[1] is None: + # Dim is allowed to be None, which means to sum all dimensions + dim_list = list(range(len(input_shape))) + else: + dim_list = cast(list[int], node.args[1]) + dim_list = [dim % len(input_shape) for dim in dim_list] for dim in dim_list: if not 1 <= input_shape[dim] <= 65536: diff --git a/backends/arm/operator_support/right_shift_support.py b/backends/arm/operator_support/right_shift_support.py index 5d3896e3643..7670edec0a9 100644 --- a/backends/arm/operator_support/right_shift_support.py +++ b/backends/arm/operator_support/right_shift_support.py @@ -2,8 +2,12 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Declare operator support for bitwise right-shift in TOSA. -# pyre-unsafe +Provide support checks for ``aten.bitwise_right_shift`` and ``__rshift__`` +targets across integer and float TOSA profiles. + +""" import logging @@ -21,6 +25,8 @@ @register_tosa_support_check class RightShiftSupported(SupportedTOSAOperatorCheck): + """Provide TOSA support check for right-shift operations.""" + targets = [ exir_ops.edge.aten.bitwise_right_shift.Tensor, exir_ops.edge.aten.__rshift__.Scalar, @@ -31,9 +37,16 @@ class RightShiftSupported(SupportedTOSAOperatorCheck): TosaSpecification.create_from_string("TOSA-1.0+FP"), ] - def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): + def is_node_tosa_supported( + self, node: fx.Node, tosa_spec: TosaSpecification + ) -> bool: + """Return True if the node is supported by TOSA. 
+ + Emit a warning on U55 subsets where one-off errors may occur. Otherwise + accept all matching targets. + """ # TODO MLETORCH-525 Remove warning if tosa_spec.is_U55_subset: - logging.warning(f"{node.target} may introduce one-off errors.") + logger.warning(f"{node.target} may introduce one-off errors.") return True diff --git a/backends/arm/operator_support/sin_cos_support.py b/backends/arm/operator_support/sin_cos_support.py deleted file mode 100644 index dcdc20f8e4a..00000000000 --- a/backends/arm/operator_support/sin_cos_support.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - - -import torch.fx as fx -from executorch.backends.arm.operator_support.tosa_supported_operators import ( - register_tosa_support_check, - SupportedTOSAOperatorCheck, -) -from executorch.backends.arm.tosa import TosaSpecification -from executorch.exir.dialects._ops import ops as exir_ops - - -@register_tosa_support_check -class SinCosSupported(SupportedTOSAOperatorCheck): - targets = [ - exir_ops.edge.aten.cos.default, - exir_ops.edge.aten.sin.default, - ] - - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-1.0+INT"), - TosaSpecification.create_from_string("TOSA-1.0+FP"), - ] - - def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): - return True diff --git a/backends/arm/operator_support/slice_copy_support.py b/backends/arm/operator_support/slice_copy_support.py index 14ca505635c..77f3e97eb39 100644 --- a/backends/arm/operator_support/slice_copy_support.py +++ b/backends/arm/operator_support/slice_copy_support.py @@ -2,7 +2,11 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Declare operator support for ``aten.slice_copy`` in TOSA. +Support slicing with unit step only; emit a warning and reject otherwise. + +""" import logging @@ -19,6 +23,8 @@ @register_tosa_support_check class SliceCopySupported(SupportedTOSAOperatorCheck): + """Provide TOSA support check for ``aten.slice_copy``.""" + targets = [exir_ops.edge.aten.slice_copy.Tensor] tosa_specs = [ @@ -26,12 +32,17 @@ class SliceCopySupported(SupportedTOSAOperatorCheck): TosaSpecification.create_from_string("TOSA-1.0+FP"), ] - def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool: # type: ignore[override, misc] - if tosa_spec not in self.tosa_specs: - return False + def is_node_tosa_supported( + self, node: fx.Node, tosa_spec: TosaSpecification + ) -> bool: # type: ignore[override, misc] + """Return True if the node is supported by TOSA. + + Accept slice_copy when the step is 1 (or unspecified). Warn and reject + non-unit step sizes. + """ args = node.args if len(args) == 5 and (step := args[4]) != 1: - logging.warning(f"{node.target} with step size of {step} not supported.") + logger.warning(f"{node.target} with step size of {step} not supported.") return False return True diff --git a/backends/arm/operator_support/to_dim_order_copy_support.py b/backends/arm/operator_support/to_dim_order_copy_support.py index e21f8a68ad6..48f0c4d8604 100644 --- a/backends/arm/operator_support/to_dim_order_copy_support.py +++ b/backends/arm/operator_support/to_dim_order_copy_support.py @@ -2,8 +2,14 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+"""Declare operator support for ``_to_dim_order_copy`` in TOSA. + +Provide dtype-compatibility checks for casting when converting to a specific +dimension order. Supported input/output dtype pairs depend on the active TOSA +profile (integer and/or float). + +""" -# pyre-unsafe import copy import logging @@ -25,6 +31,16 @@ @register_tosa_support_check class ToCopySupported(SupportedTOSAOperatorCheck): + """Provide TOSA support check for ``_to_dim_order_copy``. + + Attributes: + SUPPORTED_INT_PROFILE_DTYPES (dict[torch.dtype, list[torch.dtype]]): + Allowed output dtypes for each integer input dtype. + SUPPORTED_FP_PROFILE_DTYPES (dict[torch.dtype, list[torch.dtype]]): + Allowed output dtypes for each floating input dtype. + + """ + targets = [ exir_ops.edge.dim_order_ops._to_dim_order_copy.default, ] @@ -36,25 +52,34 @@ class ToCopySupported(SupportedTOSAOperatorCheck): @staticmethod def _merge_supported_types( - # pyre-ignore[11] dtypes1: SupportedTypeDict, dtypes2: SupportedTypeDict, ) -> SupportedTypeDict: + """Return a merged mapping of supported dtype transitions. + + Args: + dtypes1 (dict[torch.dtype, list[torch.dtype]]): Base mapping. + dtypes2 (dict[torch.dtype, list[torch.dtype]]): Mapping to merge in. + + Returns: + dict[torch.dtype, list[torch.dtype]]: Combined mapping. + + """ merged_dtypes = copy.deepcopy( dtypes1 - ) # Use deepcopy to avoid unintentionally modifying SUPPORTED_INT_TYPES + ) # Use deepcopy to avoid unintentionally modifying SUPPORTED_INT_PROFILE_DTYPES for k, v in dtypes2.items(): merged_dtypes[k] = merged_dtypes.get(k, []) + v return merged_dtypes - SUPPORTED_INT_TYPES: SupportedTypeDict = { + SUPPORTED_INT_PROFILE_DTYPES: SupportedTypeDict = { torch.bool: [torch.bool, torch.int8, torch.int16, torch.int32], torch.int8: [torch.bool, torch.int8, torch.int16, torch.int32], torch.int16: [torch.bool, torch.int8, torch.int16, torch.int32], torch.int32: [torch.bool, torch.int8, torch.int16, torch.int32], torch.int64: [torch.bool, torch.int8, torch.int16, torch.int32], } - SUPPORTED_FLOAT_TYPES: SupportedTypeDict = { + SUPPORTED_FP_PROFILE_DTYPES: SupportedTypeDict = { torch.int8: [torch.int8, torch.float16, torch.bfloat16, torch.float32], torch.int16: [torch.int16, torch.float16, torch.bfloat16, torch.float32], torch.int32: [torch.int32, torch.float16, torch.bfloat16, torch.float32], @@ -89,24 +114,28 @@ def _merge_supported_types( torch.int32, torch.bfloat16, torch.float16, + torch.float32, ], } - ALL_SUPPORTED_TYPES = _merge_supported_types( - SUPPORTED_INT_TYPES, SUPPORTED_FLOAT_TYPES - ) def is_node_tosa_supported( self, node: fx.Node, tosa_spec: TosaSpecification ) -> bool: + """Return True if the node is supported by TOSA. + + Check FakeTensor metadata, validate input dtype is supported for the + active profile, and ensure the output dtype is allowed for the given + input dtype. 
+ """ supported_dtypes: SupportedTypeDict = {} if tosa_spec.support_integer(): supported_dtypes = self._merge_supported_types( - self.SUPPORTED_INT_TYPES, supported_dtypes + self.SUPPORTED_INT_PROFILE_DTYPES, supported_dtypes ) if tosa_spec.support_float(): supported_dtypes = self._merge_supported_types( - self.SUPPORTED_FLOAT_TYPES, supported_dtypes + self.SUPPORTED_FP_PROFILE_DTYPES, supported_dtypes ) if len(node.all_input_nodes) != 1: diff --git a/backends/arm/operator_support/tosa_profile_supported_op_lists.py b/backends/arm/operator_support/tosa_profile_supported_op_lists.py index d3207c65dff..f4f72690345 100644 --- a/backends/arm/operator_support/tosa_profile_supported_op_lists.py +++ b/backends/arm/operator_support/tosa_profile_supported_op_lists.py @@ -2,6 +2,12 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Define TOSA profile support lists for INT and FP. + +Expose static sets of EXIR operator overloads used by the TOSA partitioner to +seed positive support checks for different profiles. + +""" import operator from typing import Final, Set @@ -12,6 +18,7 @@ # INT profile: ops supported via native TOSA ops, decompositions/transformations, precompute, TableOps, etc. +# Note that ops supported via pre-quantization decompositions are not included here. TOSA_PRO_INT_SupportList: Final[Set] = { exir_ops.edge.aten.abs.default, exir_ops.edge.aten.add.Tensor, @@ -24,6 +31,7 @@ exir_ops.edge.aten.bitwise_and.Scalar, exir_ops.edge.aten.bitwise_or.Scalar, exir_ops.edge.aten.bitwise_xor.Scalar, + exir_ops.edge.aten.cos.default, exir_ops.edge.aten.logical_and.default, exir_ops.edge.aten.logical_or.default, exir_ops.edge.aten.logical_xor.default, @@ -33,14 +41,13 @@ exir_ops.edge.aten.cat.default, exir_ops.edge.aten.ceil.default, exir_ops.edge.aten.clamp.default, + exir_ops.edge.aten.clamp.Tensor, exir_ops.edge.aten.cumsum.default, exir_ops.edge.aten.bmm.default, exir_ops.edge.aten.permute_copy.default, exir_ops.edge.aten.hardsigmoid.default, exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.hardswish.default, - exir_ops.edge.aten.div.Tensor, - exir_ops.edge.aten.div.Tensor_mode, exir_ops.edge.aten.eq.Tensor, exir_ops.edge.aten.eq.Scalar, exir_ops.edge.aten.erf.default, @@ -49,6 +56,7 @@ exir_ops.edge.aten.log.default, exir_ops.edge.aten.linear.default, exir_ops.edge.aten.split_with_sizes_copy.default, + exir_ops.edge.aten.split_copy.Tensor, exir_ops.edge.aten.floor.default, exir_ops.edge.aten.full.default, exir_ops.edge.aten.full_like.default, @@ -61,16 +69,7 @@ exir_ops.edge.aten.lt.Tensor, exir_ops.edge.aten.lt.Scalar, exir_ops.edge.aten.mul.Tensor, - exir_ops.edge.aten.ne.Tensor, - exir_ops.edge.aten.ne.Scalar, exir_ops.edge.aten.neg.default, - exir_ops.edge.aten.add.Scalar, - exir_ops.edge.aten.sub.Scalar, - exir_ops.edge.aten.mul.Scalar, - exir_ops.edge.aten.div.Scalar, - exir_ops.edge.aten._native_batch_norm_legit_no_training.default, - exir_ops.edge.aten.native_layer_norm.default, - exir_ops.edge.aten.native_group_norm.default, exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mm.default, @@ -79,25 +78,18 @@ exir_ops.edge.aten.repeat.default, exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.relu.default, - exir_ops.edge.aten.leaky_relu.default, - exir_ops.edge.aten.sqrt.default, + exir_ops.edge.aten.remainder.Tensor, exir_ops.edge.aten.rsqrt.default, - exir_ops.edge.aten.round.default, - exir_ops.edge.aten._softmax.default, 
exir_ops.edge.aten.select_copy.int, - exir_ops.edge.aten._log_softmax.default, exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.tanh.default, exir_ops.edge.aten.upsample_bilinear2d.vec, exir_ops.edge.aten.upsample_nearest2d.vec, - exir_ops.edge.aten.var.correction, - exir_ops.edge.aten.var.dim, exir_ops.edge.aten.view_copy.default, exir_ops.edge.aten.unsqueeze_copy.default, exir_ops.edge.aten.squeeze_copy.dims, exir_ops.edge.aten.pow.Tensor_Scalar, exir_ops.edge.aten.pow.Tensor_Tensor, - exir_ops.edge.aten.where.self, operator.getitem, exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, exir_ops.edge.quantized_decomposed.quantize_per_channel.default, @@ -113,6 +105,7 @@ torch.ops.aten.scalar_tensor.default, exir_ops.edge.aten.gelu.default, exir_ops.edge.aten.alias_copy.default, + exir_ops.edge.aten.sin.default, exir_ops.edge.aten.sinh.default, exir_ops.edge.aten.atan.default, exir_ops.edge.aten.acosh.default, @@ -120,14 +113,13 @@ exir_ops.edge.aten.sign.default, exir_ops.edge.aten.asin.default, exir_ops.edge.aten.atanh.default, - exir_ops.edge.aten.addmm.default, exir_ops.edge.aten.masked_fill.Scalar, exir_ops.edge.aten.asinh.default, exir_ops.edge.aten.cosh.default, - exir_ops.edge.aten.glu.default, - exir_ops.edge.aten.logit.default, exir_ops.edge.aten.acos.default, exir_ops.edge.aten.elu.default, + exir_ops.edge.aten.bitwise_not.default, + exir_ops.edge.aten.copy.default, } @@ -147,6 +139,8 @@ exir_ops.edge.aten.cat.default, exir_ops.edge.aten.ceil.default, exir_ops.edge.aten.clamp.default, + exir_ops.edge.aten.clamp.Tensor, + exir_ops.edge.aten.cos.default, exir_ops.edge.aten.cumsum.default, exir_ops.edge.aten.bmm.default, exir_ops.edge.aten.permute_copy.default, @@ -163,6 +157,7 @@ exir_ops.edge.aten.log.default, exir_ops.edge.aten.linear.default, exir_ops.edge.aten.split_with_sizes_copy.default, + exir_ops.edge.aten.split_copy.Tensor, exir_ops.edge.aten.floor.default, exir_ops.edge.aten.full.default, exir_ops.edge.aten.full_like.default, @@ -187,12 +182,15 @@ exir_ops.edge.aten.native_group_norm.default, exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.mean.dim, + exir_ops.edge.aten.mean.default, exir_ops.edge.aten.mm.default, exir_ops.edge.aten.minimum.default, exir_ops.edge.aten.maximum.default, exir_ops.edge.aten.repeat.default, exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.remainder.Scalar, + exir_ops.edge.aten.remainder.Tensor, exir_ops.edge.aten.leaky_relu.default, exir_ops.edge.aten.sqrt.default, exir_ops.edge.aten.rsqrt.default, @@ -211,7 +209,6 @@ exir_ops.edge.aten.squeeze_copy.dims, exir_ops.edge.aten.pow.Tensor_Scalar, exir_ops.edge.aten.pow.Tensor_Tensor, - exir_ops.edge.aten.where.self, operator.getitem, exir_ops.edge.aten.constant_pad_nd.default, exir_ops.edge.aten.amax.default, @@ -223,6 +220,7 @@ torch.ops.aten.scalar_tensor.default, exir_ops.edge.aten.gelu.default, exir_ops.edge.aten.alias_copy.default, + exir_ops.edge.aten.sin.default, exir_ops.edge.aten.sinh.default, exir_ops.edge.aten.atan.default, exir_ops.edge.aten.acosh.default, @@ -238,6 +236,8 @@ exir_ops.edge.aten.logit.default, exir_ops.edge.aten.acos.default, exir_ops.edge.aten.elu.default, + exir_ops.edge.aten.copy.default, + exir_ops.edge.aten.floor_divide.default, } diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index b580fbb9a9a..9240f14da54 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ 
b/backends/arm/operator_support/tosa_supported_operators.py @@ -2,8 +2,13 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Provide operator-support checks and registries for TOSA delegation. + +Define a base check class, a registry/dispatcher, and several generic checks +used by the TOSA partitioner to decide if FX nodes are eligible for delegation. + +""" -# pyre-unsafe import itertools import operator @@ -13,14 +18,25 @@ import torch import torch.fx as fx -from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor -from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT +from executorch.backends.arm._passes.arm_pass_utils import ( + get_first_fake_tensor, + is_submodule_node, +) +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) from executorch.backends.arm._passes.fuse_quantized_activation_pass import ( FuseQuantizedActivationPass, ) from executorch.backends.arm._passes.insert_table_ops import TableOps -from executorch.backends.arm.constants import DQ_OPS, Q_OPS +from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo +from executorch.backends.arm.constants import DQ_OPS, MAX_RANK, Q_OPS +from executorch.backends.arm.operator_support.control_flow_support import ( + ControlFlowOpSupported, + ControlFlowSubmoduleSupported, +) from executorch.backends.arm.operator_support.ethos_u55_support import ( + EthosU55CastCheck, EthosU55DtypeSupport, EthosU55NotSupported, EthosU55TransposeCheck, @@ -30,7 +46,10 @@ TOSA_PRO_FP_SupportList, TOSA_PRO_INT_SupportList, ) -from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.arm.tosa.specification import ( + TosaSpecification, + TosaSpecMapping, +) from executorch.exir import ExportedProgram from executorch.exir.backend.utils import WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops @@ -42,15 +61,31 @@ class SupportedTOSAOperatorCheck(OperatorSupportBase): - """ - Supported OP for TOSA lowering + """Provide a base operator-support check for TOSA lowering. + + Subclasses should implement :py:meth:`is_node_tosa_supported` and declare + the class attributes below to indicate what they support. + + Attributes: + targets (list[OpOverload]): Operator overloads supported by this + check. + tosa_specs (list[TosaSpecification]): TOSA specs where the check is + applicable. + """ def __init__(self, tosa_spec: TosaSpecification, reporter: WhyNoPartitionReporter): + """Initialize the check with a TOSA spec and reporter. + + Args: + tosa_spec (TosaSpecification): Active TOSA specification. + reporter (WhyNoPartitionReporter): Reporter for rejection reasons. + + """ self.tosa_spec = tosa_spec self.reporter = reporter - # Should be populated by subclass implementation + # Class attributes populated by subclasses tosa_specs: list[TosaSpecification] = [] targets: list[str] = [] @@ -58,6 +93,17 @@ def __init__(self, tosa_spec: TosaSpecification, reporter: WhyNoPartitionReporte def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: + """Return True if the node matches targets and subclass-specific checks. + + Args: + submodules (typing.Mapping[str, torch.nn.Module]): Exported program + modules. + node (fx.Node): Node to evaluate. + + Returns: + bool: True if both the target and TOSA-specific checks pass. 
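+
+        Example (illustrative): a subclass declaring
+        ``targets = [exir_ops.edge.aten.relu.default]`` returns False here for
+        any node with a different target and otherwise defers to
+        ``is_node_tosa_supported``.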
+ + """ if node.target not in self.targets: return False return self.is_node_tosa_supported(node, self.tosa_spec) @@ -65,39 +111,132 @@ def is_node_supported( def is_node_tosa_supported( self, node: fx.Node, tosa_spec: TosaSpecification ) -> bool: - """ - Checks if the fx.Node node is lowerable using the TOSA specification defined by tosa_spec. + """Check if the node is lowerable under the given TOSA spec. + + Args: + node (fx.Node): FX node to check. + tosa_spec (TosaSpecification): Active TOSA specification. + + Returns: + bool: True if supported; otherwise, False. + """ raise NotImplementedError("SupportedTOSAOperatorCheck must be extended.") # container for all SupportedTosaOperatorCheck classes -_tosa_spec_support: dict[TosaSpecification, list[Type[SupportedTOSAOperatorCheck]]] = { - TosaSpecification.create_from_string("TOSA-1.0+INT"): [], - TosaSpecification.create_from_string("TOSA-1.0+FP"): [], -} +_tosa_spec_support: TosaSpecMapping[Type[SupportedTOSAOperatorCheck]] = ( + TosaSpecMapping() +) def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]): - """ - Decorator to mark a subclass implmentation of SupportedTosaOperatorCheck - to be registered for checking if a torch.fx.Node is lowerable given - a TOSA specification. + """Register an operator-support checker for one or more TOSA specs. + + Decorate subclasses of :py:class:`SupportedTOSAOperatorCheck` so they are + picked up by the factory and partitioner for the specs declared in their + ``tosa_specs`` class attribute. + + Args: + checker (Type[SupportedTOSAOperatorCheck]): Checker class to register. + """ for tosa_spec in checker.tosa_specs: - _tosa_spec_support[tosa_spec].append(checker) + _tosa_spec_support.add(tosa_spec, checker) return checker +def _is_integer_dtype(dtype: torch.dtype) -> bool: + return not dtype.is_floating_point and not dtype.is_complex + + +def _is_quantized_constant(node: torch.fx.Node) -> bool: + if node.target not in ( + exir_ops.edge.aten.full_like.default, + *ComputeConstantOpsAOTPass.targeted_ops, + ): + return False + + users = tuple(node.users) + if users and all(user.target in Q_OPS for user in users): + # The node feeds directly into only quantized ops. + return True + + for user in users: + if user.target == exir_ops.edge.dim_order_ops._to_dim_order_copy.default: + dim_order_dtype = get_first_fake_tensor(user).dtype + if not _is_integer_dtype(dim_order_dtype): + return False + else: + return False + + return len(users) > 0 + + +def is_quantized(node: torch.fx.Node) -> bool: + """Checks if the node is quantized. + + A node is considered quantized if any of the following is true: + - Its output dtype is not floating point or complex => integer + - It is an op that produces a constant that in turn feeds only quantized users + - It has been marked as quantized in the ArmAnnotationInfo custom meta. + + Args: + node (torch.fx.Node): The FX node to check. + + Returns: + bool: True if the node is quantized, False otherwise. + """ + + try: + node_dtype = get_first_fake_tensor(node).dtype + # Integer-like dtype implies the node is already quantized as long + # as inputs are not floating-point. + if _is_integer_dtype(node_dtype): + input_nodes = node.all_input_nodes + input_nodes_dtypes = [ + get_first_fake_tensor(input_node).dtype for input_node in input_nodes + ] + if all( + _is_integer_dtype(input_node_dtype) + for input_node_dtype in input_nodes_dtypes + ): + return True + + except TypeError: + # Could not determine dtype, fall back to other checks. 
+ pass + + # Nodes introduced during lowering that exclusively feed quantized users. + if _is_quantized_constant(node): + return True + + # Finally, fall back to the explicit annotation emitted by Arm passes. + custom_meta = node.meta.get("custom", {}) + if ArmAnnotationInfo.CUSTOM_META_KEY in custom_meta: + return custom_meta[ArmAnnotationInfo.CUSTOM_META_KEY]["quantized"] + + return False + + def get_registered_tosa_support_checks( tosa_spec: TosaSpecification, ) -> list[Type[SupportedTOSAOperatorCheck]]: - if tosa_spec not in _tosa_spec_support: + """Get all registered operator-support checkers for a given spec. + + Args: + tosa_spec (TosaSpecification): TOSA spec to query. + + Returns: + list[Type[SupportedTOSAOperatorCheck]]: Registered checker classes. + + """ + checks = _tosa_spec_support.get(tosa_spec) + if not checks: raise RuntimeError( - f"TOSA specification not valid: {tosa_spec} not in {list(_tosa_spec_support.keys())}" + f"TOSA specification not valid: {tosa_spec} not in {list(_tosa_spec_support._mapping.keys())}" ) - - return _tosa_spec_support[tosa_spec] + return checks def tosa_support_factory( @@ -106,15 +245,33 @@ def tosa_support_factory( reporter: WhyNoPartitionReporter, additional_checks: Optional[Sequence[OperatorSupportBase]] = None, ) -> OperatorSupportBase: - """Generates an OperatorSupport class depending on the given `tosa_spec`. - Additional checks can be supplied to avoid partitioning additional nodes. + """Create an OperatorSupport composite for a TOSA spec. + + Combine profile-specific positive checks, registered operator checks, and + negative checks into a single :py:class:`OperatorSupportBase` chain. + + Args: + tosa_spec (TosaSpecification): Active TOSA specification. + exported_program (ExportedProgram): Program context for checks. + reporter (WhyNoPartitionReporter): Reporter for rejections. + additional_checks (Optional[Sequence[OperatorSupportBase]]): Extra + negative checks to apply. + + Returns: + OperatorSupportBase: Composite checker for the given spec. 
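+
+    Example (illustrative; ``exported_program`` and ``reporter`` are assumed
+    to already exist):
+
+        spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
+        op_support = tosa_support_factory(spec, exported_program, reporter)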
+ """ # Postive checks: Add nodes to partitioning - positive_checks: list[OperatorSupportBase] = [] + positive_checks: list[OperatorSupportBase] = [ + ControlFlowSubmoduleSupported(exported_program, tosa_spec, reporter), + ControlFlowOpSupported(exported_program, tosa_spec, reporter), + ] - if tosa_spec.support_integer(): + if tosa_spec.support_integer() and tosa_spec.support_float(): + positive_checks.append(TOSAProINTFPSupportList()) + elif tosa_spec.support_integer(): positive_checks.append(TOSAProINTSupportList()) - if tosa_spec.support_float(): + elif tosa_spec.support_float(): positive_checks.append(TOSAProFPSupportList()) # TODO: Refactor to use TOSAProSupportLists + negtive checks positive_checks += [ @@ -126,7 +283,7 @@ def tosa_support_factory( negative_checks: list[OperatorSupportBase] = [ CheckInt64InputsAndOutputs(exported_program, reporter), CheckFloat64Inputs(exported_program, reporter), - RankCheck(reporter, max_rank=5), + RankCheck(reporter, max_rank=MAX_RANK), *[ reporter.wrap_check(check, f"Rejected by {check.__class__.__name__}") for check in (additional_checks if additional_checks else []) @@ -134,13 +291,14 @@ def tosa_support_factory( ] if not tosa_spec.support_float(): - negative_checks.append(NeedsDecompositionCheck(reporter)) + negative_checks.append(CheckArmQuantized(reporter)) negative_checks.append(CheckProperQuantization(reporter)) if tosa_spec.is_U55_subset: negative_checks.append(EthosU55NotSupported(reporter)) negative_checks.append(EthosU55DtypeSupport(reporter)) negative_checks.append(EthosU55TransposeCheck(reporter)) negative_checks.append(EthosU55ViewCheck(reporter)) + negative_checks.append(EthosU55CastCheck(reporter)) return chain( reporter.wrap_check( @@ -152,36 +310,65 @@ def tosa_support_factory( class TOSAProINTSupportList(OperatorSupportBase): - """ - TOSA_PRO_INT_SupportList: - Ops supported in INT profile via native TOSA ops, decomposition/transformation, pre-compute, or TableOps + """Provide the INT profile support list for TOSA. + + TOSA_PRO_INT_SupportList enumerates ops supported in the INT profile via + native TOSA ops, decompositions, pre-compute steps, or TableOps. + + Note: + Ops supported via pre-quantization decompositions are not included + here. + """ def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - + """Return True if the node is in the INT profile support list.""" return node.op == "call_function" and node.target in TOSA_PRO_INT_SupportList class TOSAProFPSupportList(OperatorSupportBase): + """Provide the FP profile support list for TOSA. + + Includes ops supported natively, via decomposition/transformation, and pre- + compute. + """ - TOSA_PRO_FP_SupportList: - Ops supported in FP profile via native TOSA ops, decomposition/transformation, pre-compute + + def is_node_supported( + self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node + ) -> bool: + """Return True if the node is in the FP profile support list.""" + return node.op == "call_function" and node.target in TOSA_PRO_FP_SupportList + + +class TOSAProINTFPSupportList(OperatorSupportBase): + """ + TOSA_PRO_INT_FP_SupportList: + Ops supported in INT+FP profile via native TOSA ops, decomposition/transformation, pre-compute, or TableOp. 
""" def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: + if node.op != "call_function": + return False - return node.op == "call_function" and node.target in TOSA_PRO_FP_SupportList + # Select list based on whether the node is quantized. + if is_quantized(node) or node.target in (*Q_OPS, *DQ_OPS): + support_list = TOSA_PRO_INT_SupportList + else: + support_list = TOSA_PRO_FP_SupportList + + return node.target in support_list -class NeedsDecompositionCheck(OperatorSupportBase): +class CheckArmQuantized(OperatorSupportBase): """ - Targeted operators need to be decomposed prior to quantization in order to get a pair of q-dq-nodes surrounding - the operator, and to get optimal quantization parameters for each operator. This check will reject operators - that need to be decomposed. + Check if the node was marked as quantized in the Arm backend. + This is used to ensure that nodes that were quantized in the Arm backend + are only partitioned if they are supported by the TOSA backend. """ def __init__(self, reporter: WhyNoPartitionReporter): @@ -191,48 +378,23 @@ def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - if node.op != "call_function": + if node.target in (*DQ_OPS, *Q_OPS): return True - needs_decomp_dict = { - exir_ops.edge.aten.div.Tensor: None, - exir_ops.edge.aten._native_batch_norm_legit_no_training.default: "BatchNorm2D with track_running_stats==True not immediately following a convolution is not supported for quantized TOSA backends.", - exir_ops.edge.aten.native_layer_norm.default: None, - exir_ops.edge.aten.native_group_norm.default: None, - exir_ops.edge.aten._softmax.default: None, - exir_ops.edge.aten._log_softmax.default: None, - exir_ops.edge.aten.var.correction: None, - exir_ops.edge.aten.var.dim: None, - exir_ops.edge.aten.add.Scalar: None, - exir_ops.edge.aten.sqrt.default: None, - exir_ops.edge.aten.sub.Scalar: None, - exir_ops.edge.aten.mul.Scalar: None, - exir_ops.edge.aten.ne.Tensor: None, - exir_ops.edge.aten.ne.Scalar: None, - exir_ops.edge.aten.div.Scalar: None, - exir_ops.edge.aten.leaky_relu.default: None, - exir_ops.edge.aten.round.default: None, - exir_ops.edge.aten.addmm.default: None, - exir_ops.edge.aten.glu.default: None, - exir_ops.edge.aten.logit.default: None, - } - - if node.target in needs_decomp_dict: - reject_message = needs_decomp_dict[node.target] - if reject_message is None: - reject_message = "Op needs to be decomposed into other ops before quantization to get quantized properly." - - self.reporter.report_reject(node, reject_message) + if not is_quantized(node): + self.reporter.report_reject( + node, "Node was not marked as quantized in the Arm backend." + ) return False - else: - return True + return True class CheckProperQuantization(OperatorSupportBase): - """ - For targeted nodes, check that it has been quantized as expected. In most cases this means that a pair of quantize - and dequantize nodes surrounds the node. This is neccessary for table operators and operators that need to rescale - activations. + """Ensure targeted nodes are properly quantized. + + Verify that a pair of quantize/dequantize nodes surrounds targeted ops so + rescaling and table operators behave correctly. 
+ """ targeted_ops = ( @@ -258,17 +420,32 @@ class CheckProperQuantization(OperatorSupportBase): ) def __init__(self, reporter: WhyNoPartitionReporter): + """Initialize the check with a reporter.""" self.reporter = reporter def _is_matmul_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ): - """ - Find the matmul source partition containing this node and check that all its inputs and outputs are quantized. + """Check quantization for decomposed matmul partitions. + + Handles an edge case where the quantized pipeline + `dq -> torch.matmul/operator.matmul -> q` decomposes into + `dq -> expand -> view -> aten.mm -> view -> q`. + + Args: + submodules (Mapping[str, torch.nn.Module]): Map of child modules to + inspect for matmul partitions. + node (fx.Node): Node that should belong to a quantized matmul + partition. + + Returns: + bool: True if the matched partition uses quantized inputs and + outputs. + """ for graph_module in submodules.values(): graph_module = typing.cast(fx.GraphModule, graph_module) - matmul_partitions = get_source_partitions( + matmul_partitions_map = get_source_partitions( graph_module.graph, [ torch.matmul, @@ -277,7 +454,7 @@ def _is_matmul_node_supported( None, ) matmul_partitions = list( - itertools.chain.from_iterable(matmul_partitions.values()) + itertools.chain.from_iterable(matmul_partitions_map.values()) ) matched_partition = None for partition in matmul_partitions: @@ -313,6 +490,12 @@ def _is_matmul_node_supported( def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: + """Return True if the node passes constant-cast and multi-output checks. + + Ensures decomposition-specific matmul partitions keep quantized inputs + and outputs. + + """ output_quantized = False input_quantized = False if node.target not in self.targeted_ops: @@ -345,7 +528,7 @@ def is_node_supported( input_quantized = input_quantized or all( (input_node.target in DQ_OPS) - or (not get_first_fake_tensor(input_node).dtype.is_floating_point) + or _is_integer_dtype(get_first_fake_tensor(input_node).dtype) for input_node in node.all_input_nodes ) @@ -354,8 +537,10 @@ def is_node_supported( return False all_q_users = all((output_node.target in Q_OPS) for output_node in node.users) - is_floating_point = get_first_fake_tensor(node).dtype.is_floating_point - output_quantized = output_quantized or all_q_users or not is_floating_point + output_dtype = get_first_fake_tensor(node).dtype + output_quantized = ( + output_quantized or all_q_users or _is_integer_dtype(output_dtype) + ) if not output_quantized: self.reporter.report_reject(node, "One or more outputs were not quantized.") @@ -364,21 +549,22 @@ def is_node_supported( class CheckInt64InputsAndOutputs(OperatorSupportBase): - """TOSA does not support int64 tensors so in general, ops with int64 inputs or outputs should not be partitioned. - There are however some exceptions: - - Nodes with int64 output can be partitioned if they are constant, within int32, - and all users cast to something else. In this case, the int64 tensor can safely be cast to int32 AOT. - - Nodes with int64 output can be partitioned if all users are getitem with non-int64 output. - In this case, there are multiple outputs and the int64 ones are not used. - - Nodes with int64 inputs can be partitioned if the inputs are constant placeholders, or constant - ops fulfilling the criteria above. 
- Note that we don't check placeholders here, they are partitioned based on whether their users are partitioned - or not. + """Reject general int64 tensors while allowing safe exceptions. + + Exceptions are: + - Nodes with constant int64 output within int32 range that are cast away + from int64 by all users. + - Int64 output where all users are getitem nodes with non-int64 outputs. + In this case there are multiple outputs and the int64 output is unused. + - Nodes where all inputs are int64 constant placeholders or constant ops + that fulfill the above exceptions. + """ def __init__( self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter ): + """Initialize the check with program context and reporter.""" self.input_names = [ spec.arg.name for spec in exported_program.graph_signature.input_specs @@ -400,7 +586,9 @@ def inside_int32_bounds(self, node: torch.fx.Node) -> bool: def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - + """Return True when int64 use is absent or safe per exceptions.""" + if is_submodule_node(node): + return True vals = node.meta["val"] tensor_list = vals if isinstance(vals, (list, tuple)) else [vals] @@ -414,7 +602,7 @@ def is_node_supported( for output_node in node.users ) if ( - node.target in ComputeConstantOpsAOT.targeted_ops + node.target in ComputeConstantOpsAOTPass.targeted_ops and users_output_non_int64 ): if not self.inside_int32_bounds(node): @@ -440,7 +628,11 @@ def is_node_supported( # Ops with int64 inputs are only partitioned if input nodes are constant and will be partitioned. # If it is not partitioned, the partition will get an int64 input and fail. - for input_node in node.all_input_nodes: + for input_node in ( + input_node + for input_node in node.all_input_nodes + if input_node.op != "get_attr" + ): tensor_in = get_first_fake_tensor(input_node) if tensor_in.dtype != torch.int64: continue @@ -452,12 +644,10 @@ def is_node_supported( continue # Constant operator if input_node.op == "call_function": - if input_node.target in ComputeConstantOpsAOT.targeted_ops: + if input_node.target in ComputeConstantOpsAOTPass.targeted_ops: # This is not perfect since the input_node can still be rejected by other checks but # this should cover the majority of cases. - if self.is_node_supported( - None, input_node # type: ignore[arg-type] #(we don't use 'submodules') - ): + if self.is_node_supported({}, input_node): continue self.reporter.report_reject( node, f"Non-constant int64 input {input_node.name}" @@ -468,18 +658,30 @@ def is_node_supported( class CheckFloat64Inputs(OperatorSupportBase): + """Reject nodes with float64 inputs. + + Useful as a negative check for specs that do not allow float64. 
+ + """ def __init__( self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter ): + """Initialize the check with program context and reporter.""" self.reporter = reporter super().__init__() def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - - for input_node in node.all_input_nodes: + """Return True if no float64 inputs are present.""" + if is_submodule_node(node): + return True + for input_node in ( + input_node + for input_node in node.all_input_nodes + if input_node.op != "get_attr" + ): tensor = get_first_fake_tensor(input_node) if tensor.dtype == torch.float64: self.reporter.report_reject( @@ -491,9 +693,10 @@ def is_node_supported( class RankCheck(OperatorSupportBase): - """Makes sure that nodes with input or output tensors with rank > max_rank are not partitioned""" + """Reject nodes with rank greater than ``max_rank``.""" def __init__(self, reporter: WhyNoPartitionReporter, max_rank: int): + """Initialize the check with a reporter and maximum rank.""" self.reporter = reporter self.max_rank = max_rank super().__init__() @@ -501,7 +704,14 @@ def __init__(self, reporter: WhyNoPartitionReporter, max_rank: int): def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - input_nodes = node.all_input_nodes + """Return True if input/output tensor ranks are within the limit.""" + if is_submodule_node(node): + return True + input_nodes = ( + input_node + for input_node in node.all_input_nodes + if input_node.op != "get_attr" + ) # check if any input node has an unsupported rank for input_node in input_nodes: input_node_shape = get_first_fake_tensor(input_node).shape diff --git a/backends/arm/operator_support/where_support.py b/backends/arm/operator_support/where_support.py new file mode 100644 index 00000000000..2ec7c30827d --- /dev/null +++ b/backends/arm/operator_support/where_support.py @@ -0,0 +1,77 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch + +import torch.fx as fx +from executorch.backends.arm.constants import DQ_OPS +from executorch.backends.arm.operator_support.tosa_supported_operators import ( + register_tosa_support_check, + SupportedTOSAOperatorCheck, +) +from executorch.backends.arm.tosa import TosaSpecification +from executorch.exir.dialects._ops import ops as exir_ops + + +@register_tosa_support_check +class WhereSupported(SupportedTOSAOperatorCheck): + targets = [exir_ops.edge.aten.where.self] + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def is_node_tosa_supported( + self, node: fx.Node, tosa_spec: TosaSpecification + ) -> bool: # type: ignore[override, misc] + + if len(node.all_input_nodes) != 3: + self.reporter.report_reject( + node, + ( + "Expected exactly three input nodes, " + f"got {len(node.all_input_nodes)} for {node.target}." 
+ ), + ) + return False + + condition, x, y = node.all_input_nodes + if condition.meta["val"].dtype != torch.bool: + self.reporter.report_reject( + node, + f"Type of condition in {node.target} is not torch.bool", + ) + return False + + x_dtype, y_dtype = x.meta["val"].dtype, y.meta["val"].dtype + if tosa_spec.support_float(): + if x_dtype in (torch.bool, torch.float16, torch.float32) and y_dtype in ( + torch.bool, + torch.float16, + torch.float32, + ): + return True + + if tosa_spec.support_integer(): + if ( + x_dtype in (torch.bool, torch.int8, torch.int16, torch.int32) + or (x_dtype == torch.float32 and x.target in DQ_OPS) + ) and ( + y_dtype in (torch.bool, torch.int8, torch.int16, torch.int32) + or (y_dtype == torch.float32 and y.target in DQ_OPS) + ): + return True + + self.reporter.report_reject( + node, + ( + f"Tensor x dtype {x_dtype} and/or tensor y dtype {y_dtype} is not supported in {node.target} " + f"for tosa specification {tosa_spec}" + ), + ) + + return False diff --git a/backends/arm/operators/TARGETS b/backends/arm/operators/TARGETS index 2c255b3c17a..38eb9e7cad9 100644 --- a/backends/arm/operators/TARGETS +++ b/backends/arm/operators/TARGETS @@ -20,12 +20,10 @@ runtime.python_library( name = "ops", srcs = glob(["op_*.py", "ops_*.py"]), deps = [ - "fbsource//third-party/tosa_tools/v0.80/serialization_lib/python/tosa:tosa", - "fbsource//third-party/tosa_tools/v1.00/serialization_lib/python/tosa:tosa", + "fbsource//third-party/tosa_tools:tosa", ":node_visitor", ":operator_validation_utils", "//executorch/backends/arm/tosa:mapping", - "//executorch/backends/arm/tosa:quant_utils", "//executorch/backends/arm/tosa:utils", "//executorch/backends/arm/_passes:passes", "//executorch/exir:lib", diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index f7a9638254e..15be109d708 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -2,8 +2,13 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Import and register Arm TOSA operator visitors. + +Importing this package loads all visitor modules so their classes can be +registered via decorators and discovered at runtime. + +""" -# pyre-unsafe from . import ( # noqa node_visitor, @@ -13,12 +18,12 @@ op_amin, op_any, op_avg_pool2d, - op_bmm, + op_bitwise_not, op_cat, op_ceil, op_clamp, + op_cond_if, op_constant_pad_nd, - op_conv2d, op_cos, op_eq, op_erf, @@ -41,7 +46,6 @@ op_pow, op_reciprocal, op_repeat, - op_rescale, op_rshift_tensor, op_rsqrt, op_sigmoid, @@ -49,14 +53,19 @@ op_slice, op_sub, op_sum, - op_table, op_tanh, op_to_dim_order_copy, - op_transpose, - op_upsample_bilinear2d, - op_upsample_nearest2d, + op_tosa_conv2d, + op_tosa_conv3d, + op_tosa_depthwise_conv2d, + op_tosa_matmul, + op_tosa_rescale, + op_tosa_resize, + op_tosa_table, + op_tosa_transpose, op_view, op_where, + op_while, ops_binary, ops_identity, ) diff --git a/backends/arm/operators/node_visitor.py b/backends/arm/operators/node_visitor.py index 54a81bdaaff..68120d10ba7 100644 --- a/backends/arm/operators/node_visitor.py +++ b/backends/arm/operators/node_visitor.py @@ -3,23 +3,42 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +"""Provide utilities to register and apply TOSA node visitors. + +Use this module to construct and serialize TOSA operators from FX nodes. 
+- Define the NodeVisitor base class and registry +- Register concrete visitors per TOSA specification + +""" import json + +import logging from typing import Any, Dict, List, Optional import torch +import tosa_serializer as ts -from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec from executorch.backends.arm.debug.schema import DebugHook from executorch.backends.arm.tosa.mapping import TosaArg -from executorch.backends.arm.tosa.specification import TosaSpecification +from executorch.backends.arm.tosa.specification import ( + TosaSpecification, + TosaSpecMapping, +) from torch.export import ExportedProgram +logger = logging.getLogger(__name__) + class NodeVisitor: - """ - Node Visitor pattern for lowering edge IR to TOSA + """Provide a visitor pattern to lower edge IR to TOSA. + + Attributes: + _exported_program (torch.export.ExportedProgram): Source program being lowered. + tosa_spec (TosaSpecification): Active TOSA specification for lowering. + debug_hook (Optional[DebugHook]): Optional hook for debug metadata. + """ # Add the currently supported node_visitor specs as default. @@ -46,12 +65,29 @@ def _serialize_operator( self, node: torch.fx.Node, tosa_graph: Any, - tosa_op: Any, + tosa_op: ts.Op, inputs: List[str], outputs: List[str], attributes: Optional[Any] = None, ) -> None: - op_location = "" + """Serialize a TOSA operator into the graph. + + When a ``DebugHook`` is active, attach location metadata (in JSON) to + the operator for traceability. + + Args: + node (torch.fx.Node): Source FX node being lowered. + tosa_graph: Target TOSA serializer/graph object. + tosa_op: TOSA operator enum value to emit. + inputs (List[str]): Names of input tensors. + outputs (List[str]): Names of output tensors. + attributes (Optional[Any]): Optional TOSA attribute object. + + Returns: + None: Mutates ``tosa_graph`` in place. + + """ + op_location = None if self.debug_hook: debug_info = self.debug_hook.add( node, @@ -59,7 +95,7 @@ def _serialize_operator( tosa_op_id=tosa_op, ) - if self.debug_hook.mode == ArmCompileSpecBuilder.DebugMode.TOSA: + if self.debug_hook.mode == ArmCompileSpec.DebugMode.TOSA: op_location = json.dumps(debug_info.to_dict()) tosa_graph.addOperator( @@ -77,25 +113,50 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: + """Define a TOSA operator node. + + Args: + node (torch.fx.Node): FX node being lowered. + tosa_graph (serializer.tosa_serializer.TosaSerializer): Target TOSA graph. + inputs (List[TosaArg]): Input tensor arguments. + output (TosaArg): Output tensor descriptor. + + Returns: + None: Mutates ``tosa_graph`` in place. + + Raises: + ValueError: If input count or dtypes are invalid. 
+ + """ raise NotImplementedError("NodeVisitor must be extended.") # container for all node visitors -_node_visitor_dicts: Dict[TosaSpecification, Dict] = { - TosaSpecification.create_from_string("TOSA-1.0+INT"): {}, - TosaSpecification.create_from_string("TOSA-1.0+FP"): {}, -} +_node_visitor_tuples: TosaSpecMapping[tuple] = TosaSpecMapping() def register_node_visitor(visitor): + """Register a concrete ``NodeVisitor`` class for its TOSA specs.""" for tosa_spec in visitor.tosa_specs: - _node_visitor_dicts[tosa_spec][visitor.target] = visitor + # Try to get the tuple to make sure it doesn't exist + visitor_tuple = (visitor.target, visitor) + try: + tuples = _node_visitor_tuples.get(tosa_spec) + except KeyError: + tuples = [] + + if visitor_tuple in tuples: + raise RuntimeError( + f"Visitor for target {visitor.target} already registered for TOSA spec {tosa_spec}" + ) + _node_visitor_tuples.add(tosa_spec, visitor_tuple) return visitor def get_node_visitors(*args) -> Dict[str, NodeVisitor]: - node_visitors = {} - tosa_spec = None + """Return a mapping from target names to visitor instances for a spec.""" + node_visitors: Dict[str, NodeVisitor] = {} + tosa_spec: TosaSpecification | None = None for arg in args: if isinstance(arg, TosaSpecification): tosa_spec = arg @@ -104,7 +165,13 @@ def get_node_visitors(*args) -> Dict[str, NodeVisitor]: if tosa_spec is None: raise RuntimeError("No TOSA specification supplied.") - for target, visitor in _node_visitor_dicts[tosa_spec].items(): + # Use the mapping to get the dict for this spec (handles combined specs) + for node_visitor_tuple in _node_visitor_tuples.get(tosa_spec): + target, visitor = node_visitor_tuple + if target in node_visitors and node_visitors[target].__class__ != visitor: + logger.warning( + f"Target {target} already has visitor class {node_visitors[target].__class__.__name__} registered, overwriting with class: {visitor.__name__}" + ) node_visitors[target] = visitor(*args) return node_visitors diff --git a/backends/arm/operators/op_abs.py b/backends/arm/operators/op_abs.py index 625293d66e0..b5a58136395 100644 --- a/backends/arm/operators/op_abs.py +++ b/backends/arm/operators/op_abs.py @@ -3,11 +3,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
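
# A minimal, standalone sketch (not the backend's actual code) of the
# registration pattern that register_node_visitor/get_node_visitors above
# implement: visitors are keyed by TOSA spec string, duplicate registrations
# are rejected, and lookup instantiates one visitor per target. Plain dicts
# stand in for TosaSpecMapping and DemoAddVisitor is a hypothetical example.
from collections import defaultdict
from typing import Dict, List, Tuple

_registry: Dict[str, List[Tuple[str, type]]] = defaultdict(list)


def register_visitor(cls):
    """Register a visitor class for every spec it declares support for."""
    for spec in cls.tosa_specs:
        entry = (cls.target, cls)
        if entry in _registry[spec]:
            raise RuntimeError(f"{cls.target} already registered for {spec}")
        _registry[spec].append(entry)
    return cls


def get_visitors(spec: str):
    """Instantiate one visitor per target for the requested spec."""
    return {target: visitor_cls() for target, visitor_cls in _registry[spec]}


@register_visitor
class DemoAddVisitor:
    target = "aten.add.Tensor"
    tosa_specs = ("TOSA-1.0+INT", "TOSA-1.0+FP")


print(list(get_visitors("TOSA-1.0+FP")))  # ['aten.add.Tensor']
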
-# pyre-unsafe from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils -import executorch.backends.arm.tosa.utils as tutils +import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -18,22 +16,20 @@ validate_same_dtype, validate_valid_dtype, ) -from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.specification import TosaSpecification from torch.fx import Node @register_node_visitor -class AbsVisitor_INT(NodeVisitor): +class AbsVisitor(NodeVisitor): target = "aten.abs.default" tosa_specs = [ TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), ] - def __init__(self, *args): - super().__init__(*args) - def define_node( self, node: Node, @@ -41,93 +37,23 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) - # Handle int8 (quantized) and int32 validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT8, ts.DType.INT32], + [ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) - scale_back = 1.0 - if inputs[0].dtype == ts.DType.INT8: - rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_spec - ) # type: ignore[possibly-undefined] - else: - # input[0].dtype == ts.DType.INT32 - # Non quantized input, natively support by TOSA.abs - rescaled_inputs = inputs - - if output.dtype == ts.DType.INT8: - broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order) - abs_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32) - else: - # output.dtype == ts.DType.INT32 - abs_output = output - - # Do the INT32 Abs - tosa_graph.addOperator( - ts.TosaOp.Op().ABS, - [ - rescaled_inputs[0].name, - ], - [abs_output.name], - None, + attr = ts.TosaSerializerAttribute() + attr.AbsAttribute() + self._serialize_operator( + node, + tosa_graph, + ts.Op.ABS, + [inputs[0].name], + [output.name], + attr, ) - - if output.dtype == ts.DType.INT8: - # Scale output back to 8 bit - # pyre-ignore - tqutils.insert_rescale_op_to_int8( - tosa_graph, abs_output, scale_back, node, self.tosa_spec - ) # type: ignore[possibly-undefined] - - -@register_node_visitor -class AbsVisitor_FP(AbsVisitor_INT): - # inheriting 'target' from BI class - - tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - - validate_num_inputs(self.target, inputs, 1) - validate_same_dtype(self.target, [*inputs, output], ts) - - if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]: - # Call the inherited define_node for handling integers - super().define_node(node, tosa_graph, inputs, output) - else: - # FP32 Abs lowering - - validate_valid_dtype( - self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec - ) - - # MI lowering - self._serialize_operator( - node, - tosa_graph, - ts.TosaOp.Op().ABS, - [inputs[0].name], - [output.name], - None, - ) diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py index a8f0c3fe14d..6c1ff2e1449 100644 --- a/backends/arm/operators/op_add.py +++ 
b/backends/arm/operators/op_add.py @@ -3,12 +3,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils -import executorch.backends.arm.tosa.utils as tutils +import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -19,22 +17,20 @@ validate_same_dtype, validate_valid_dtype, ) -from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.specification import TosaSpecification from torch.fx import Node @register_node_visitor -class AddVisitor_INT(NodeVisitor): +class AddVisitor(NodeVisitor): target = "aten.add.Tensor" tosa_specs = [ TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), ] - def __init__(self, *args): - super().__init__(*args) - def define_node( self, node: Node, @@ -42,104 +38,23 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [*inputs, output], ts) - valid_dtypes = [] - if self.tosa_spec.support_integer(): - valid_dtypes.extend([ts.DType.INT8, ts.DType.INT16, ts.DType.INT32]) - if self.tosa_spec.support_float(): - valid_dtypes.extend([ts.DType.INT32]) - validate_valid_dtype( self.target, [*inputs, output], - valid_dtypes, + [ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) - scale_back = 1.0 - if inputs[0].dtype == ts.DType.INT8: - rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale( - tosa_graph, inputs, node, self.tosa_spec - ) - else: - # input[0].dtype == ts.DType.INT16 or ts.DType.INT32 - # Non quantized input, natively support by TOSA.ADD - rescaled_inputs = inputs - if output.dtype == ts.DType.INT8: - broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order) - add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32) - else: - # output.dtype == ts.DType.INT16 or ts.DType.INT32 - add_output = output + attr = ts.TosaSerializerAttribute() + attr.AddAttribute() - input1, input2 = rescaled_inputs - - # Do the INT32 Add self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().ADD, - [input1.name, input2.name], - [add_output.name], - None, + ts.Op.ADD, + [inputs[0].name, inputs[1].name], + [output.name], + attr, ) - - if output.dtype == ts.DType.INT8: - # Scale output back to 8 bit - # pyre-ignore - tqutils.insert_rescale_op_to_int8( - tosa_graph, - add_output, - scale_back, - node, - compute_rescale=False, - tosa_spec=self.tosa_spec, - ) # type: ignore[possibly-undefined] - - -@register_node_visitor -class AddVisitor_FP(AddVisitor_INT): - # inheriting 'target' from INT class - - tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - - validate_num_inputs(self.target, inputs, 2) - validate_same_dtype(self.target, [*inputs, output], ts) - - if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32]: - # Call the inherited define_node for handling integers - super().define_node(node, tosa_graph, inputs, output) - else: - # FP32 Add 
lowering - validate_valid_dtype( - self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec - ) - - input1, input2 = inputs - - # FP lowering - self._serialize_operator( - node, - tosa_graph, - ts.TosaOp.Op().ADD, - [input1.name, input2.name], - [output.name], - None, - ) diff --git a/backends/arm/operators/op_amax.py b/backends/arm/operators/op_amax.py index 1fb751597af..e4824fb59c2 100644 --- a/backends/arm/operators/op_amax.py +++ b/backends/arm/operators/op_amax.py @@ -4,6 +4,8 @@ # LICENSE file in the root directory of this source tree. from typing import Any, List +import tosa_serializer as ts + from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -34,8 +36,6 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts - validate_num_inputs(self.target, inputs, 3) validate_same_dtype(self.target, [inputs[0], output], ts) validate_valid_dtype( @@ -60,11 +60,12 @@ def define_node( ) attr = ts.TosaSerializerAttribute() - attr.ReduceMaxAttribute(axis=input.dim_order.index(dim), nan_mode=1) + nan_mode = ts.NanPropagationMode.PROPAGATE + attr.ReduceMaxAttribute(axis=input.dim_order.index(dim), nan_mode=nan_mode) self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().REDUCE_MAX, + ts.Op.REDUCE_MAX, [input.name], [output.name], attr, diff --git a/backends/arm/operators/op_amin.py b/backends/arm/operators/op_amin.py index 9ebe78b946d..34d4d37cdeb 100644 --- a/backends/arm/operators/op_amin.py +++ b/backends/arm/operators/op_amin.py @@ -4,6 +4,8 @@ # LICENSE file in the root directory of this source tree. from typing import Any, List +import tosa_serializer as ts + from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -34,8 +36,6 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts - validate_num_inputs(self.target, inputs, 3) validate_same_dtype(self.target, [inputs[0], output], ts) validate_valid_dtype( @@ -60,11 +60,13 @@ def define_node( ) attr = ts.TosaSerializerAttribute() - attr.ReduceMinAttribute(axis=input.dim_order.index(dim), nan_mode=1) + attr.ReduceMinAttribute( + axis=input.dim_order.index(dim), nan_mode=ts.NanPropagationMode.PROPAGATE + ) self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().REDUCE_MIN, + ts.Op.REDUCE_MIN, [input.name], [output.name], attr, diff --git a/backends/arm/operators/op_any.py b/backends/arm/operators/op_any.py index 0c47d6b190a..2a850c0cf52 100644 --- a/backends/arm/operators/op_any.py +++ b/backends/arm/operators/op_any.py @@ -3,9 +3,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
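
# A small numpy check of the axis=input.dim_order.index(dim) mapping used by
# the REDUCE_MAX/REDUCE_MIN visitors above: a reduction over a logical torch
# dim has to target the matching axis of the dim-ordered tensor. The shapes
# and the (0, 2, 3, 1) channels-last dim order are assumptions for the demo.
import numpy as np

x = np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)  # NCHW
dim_order = (0, 2, 3, 1)                 # tensor stored channels-last (NHWC)
x_nhwc = np.transpose(x, dim_order)

dim = 1                                  # reduce over channels in NCHW terms
axis = dim_order.index(dim)              # -> 3, the channel axis in NHWC

assert np.allclose(x.max(axis=dim), x_nhwc.max(axis=axis))
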
-# pyre-unsafe from typing import Any, cast, List +import tosa_serializer as ts + from executorch.backends.arm.operators.node_visitor import ( # type: ignore NodeVisitor, register_node_visitor, @@ -33,8 +34,6 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts - validate_num_inputs(self.target, inputs, 3) validate_same_dtype(self.target, [inputs[0], output], ts) validate_valid_dtype( @@ -47,7 +46,7 @@ def define_node( ) # process the negative index keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False) if not keep_dim: - raise ValueError("This case should be handled by ConvertAnyDimDimsPass") + raise ValueError("This case should be handled by DecomposeAnyPass") attr = ts.TosaSerializerAttribute() attr.ReduceAnyAttribute(inputs[0].dim_order.index(dim)) @@ -55,7 +54,7 @@ def define_node( self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().REDUCE_ANY, + ts.Op.REDUCE_ANY, [inputs[0].name], [output.name], attr, diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index d28f5f27acf..ec9d42915c1 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -3,11 +3,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe from typing import Any, List import torch +import tosa_serializer as ts + from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, get_output_qparams, @@ -32,6 +33,7 @@ class AvgPool2dVisitor(NodeVisitor): tosa_specs = [ TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), ] def __init__(self, *args): @@ -48,8 +50,6 @@ def _build_generic_avgpool2d( accumulator_type: Any, ) -> None: - import serializer.tosa_serializer as ts # type: ignore - input_tensor = inputs[0] kernel_size_list = inputs[1].special stride_size_list = inputs[2].special @@ -93,17 +93,14 @@ def _build_generic_avgpool2d( pad=pad_size_list, acc_type=accumulator_type, ) - input_zp_tensor = tosa_graph.addConst( - shape=[1], dtype=output.dtype, vals=[input_zp] - ) - output_zp_tensor = tosa_graph.addConst( - shape=[1], dtype=output.dtype, vals=[output_zp] - ) + dt: ts.DType = output.dtype + input_zp_tensor = tosa_graph.addConst(shape=[1], dtype=dt, vals=[input_zp]) + output_zp_tensor = tosa_graph.addConst(shape=[1], dtype=dt, vals=[output_zp]) self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().AVG_POOL2D, + ts.Op.AVG_POOL2D, [input_tensor.name, input_zp_tensor.name, output_zp_tensor.name], [output.name], attr, @@ -116,65 +113,30 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts # type: ignore - - validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7]) - validate_same_dtype(self.target, [inputs[0], output], ts) - validate_valid_dtype( - self.target, [inputs[0], output], ts.DType.INT8, output.tosa_spec - ) - - accumulator_type = ts.DType.INT32 - - input_qargs = get_input_qparams(node) - input_zp = input_qargs[0].get_zp_per_tensor() - - output_qargs = get_output_qparams(node) - output_zp = output_qargs[0].get_zp_per_tensor() - - self._build_generic_avgpool2d( - node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type - ) - - -@register_node_visitor -class AvgPool2dVisitor_FP(AvgPool2dVisitor): - target = "aten.avg_pool2d.default" - - tosa_specs = [ - 
TosaSpecification.create_from_string("TOSA-1.0+FP"), - ] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import serializer.tosa_serializer as ts # type: ignore - validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7]) validate_same_dtype(self.target, [inputs[0], output], ts) + supported_dtypes = [ts.DType.INT8, ts.DType.FP32] + if self.tosa_spec.support_extension("int16"): + supported_dtypes.append(ts.DType.INT16) validate_valid_dtype( self.target, [inputs[0], output], - [ts.DType.INT8, ts.DType.FP32], + supported_dtypes, output.tosa_spec, ) - if inputs[0].dtype == ts.DType.INT8: - super().define_node(node, tosa_graph, inputs, output) + if inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.INT16: + accumulator_type = ts.DType.INT32 + input_qargs = get_input_qparams(node) + input_zp = input_qargs[0].get_zp_per_tensor() - if inputs[0].dtype == ts.DType.FP32: + output_qargs = get_output_qparams(node) + output_zp = output_qargs[0].get_zp_per_tensor() + else: accumulator_type = ts.DType.FP32 - # Initilize zero point to zero. input_zp = 0 output_zp = 0 - self._build_generic_avgpool2d( - node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type - ) + self._build_generic_avgpool2d( + node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type + ) diff --git a/backends/arm/operators/op_bitwise_not.py b/backends/arm/operators/op_bitwise_not.py new file mode 100644 index 00000000000..ac0f758469d --- /dev/null +++ b/backends/arm/operators/op_bitwise_not.py @@ -0,0 +1,62 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List + +import tosa_serializer as ts + +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, + validate_same_dtype, + validate_valid_dtype, +) +from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.specification import TosaSpecification +from torch.fx import Node + + +@register_node_visitor +class BitwiseNotVisitor(NodeVisitor): + target = "aten.bitwise_not.default" + + # bitwise_not is not supported on the FP profile + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + validate_num_inputs(self.target, inputs, 1) + validate_same_dtype(self.target, [*inputs, output], ts) + validate_valid_dtype( + self.target, + [*inputs, output], + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32], + output.tosa_spec, + ) + + attr = ts.TosaSerializerAttribute() + attr.BitwiseNotAttribute() + + self._serialize_operator( + node, + tosa_graph, + ts.Op.BITWISE_NOT, + [inputs[0].name], + [output.name], + attr, + ) diff --git a/backends/arm/operators/op_bmm.py b/backends/arm/operators/op_bmm.py deleted file mode 100644 index 382386ffa26..00000000000 --- a/backends/arm/operators/op_bmm.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. -# All rights reserved. 
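
# A toy round trip showing what the per-tensor (scale, zero_point) pairs
# fetched by the INT8/INT16 branch above encode; the pool itself accumulates
# in INT32 and relies on these values to map activations in and out of the
# integer domain. The scale and zero point below are made up for the example.
import numpy as np


def quantize(x, scale, zp, qmin=-128, qmax=127):
    return np.clip(np.round(x / scale) + zp, qmin, qmax).astype(np.int8)


def dequantize(q, scale, zp):
    return (q.astype(np.int32) - zp) * scale


x = np.array([-1.0, -0.25, 0.0, 0.5], dtype=np.float32)
scale, zp = 1.0 / 128, 0
q = quantize(x, scale, zp)
print(q)                          # [-128  -32    0   64]
print(dequantize(q, scale, zp))   # recovers the inputs exactly for these values
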
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe -from typing import Any, List - -import torch - -from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_input_qparams, - get_output_qparams, -) -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.operators.operator_validation_utils import ( - validate_num_inputs, - validate_same_dtype, - validate_valid_dtype, -) -from executorch.backends.arm.tosa import TosaSpecification -from executorch.backends.arm.tosa.mapping import TosaArg -from executorch.backends.arm.tosa.quant_utils import build_rescale -from tosa.RoundingMode import RoundingMode # type: ignore - - -@register_node_visitor -class BMMVisitor(NodeVisitor): - target = "aten.bmm.default" - - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-1.0+INT"), - TosaSpecification.create_from_string("TOSA-1.0+FP"), - ] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - - validate_num_inputs(self.target, inputs, 2) - validate_same_dtype(self.target, [*inputs, output], ts) - validate_valid_dtype( - self.target, - [*inputs, output], - [ts.DType.INT8, ts.DType.INT16, ts.DType.FP32], - output.tosa_spec, - ) - - # aten.bmm maps directly to MATMUL - - # For INT8, we need to get the zero points and add an intermediate tensor - # for a later rescale. - - if inputs[0].dtype == ts.DType.INT8: - input_qparams = get_input_qparams(node) - input0_zp = input_qparams[0].get_zp_per_tensor() - input1_zp = input_qparams[1].get_zp_per_tensor() - bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) - bmm_output_name = bmm_result.name - else: - bmm_output_name = output.name - input0_zp, input1_zp = 0, 0 - - tosa_graph.addConst([1], inputs[0].dtype, [input0_zp], name=f"{node.name}_A_ZP") - tosa_graph.addConst([1], inputs[1].dtype, [input1_zp], name=f"{node.name}_B_ZP") - - # Add the MATMUL to the TOSA graph. - self._serialize_operator( - node, - tosa_graph, - ts.TosaOp.Op().MATMUL, - [ - inputs[0].name, - inputs[1].name, - f"{node.name}_A_ZP", - f"{node.name}_B_ZP", - ], - [bmm_output_name], - ) - - # As INT8 accumulates into INT32, we need to rescale it back to INT8 - if output.dtype == ts.DType.INT8: - output_qparams = get_output_qparams(node)[0] - final_output_scale = ( - input_qparams[0].get_scale_per_tensor() * input_qparams[1].get_scale_per_tensor() # type: ignore[possibly-undefined] # pyre-ignore[61] - ) / output_qparams.get_scale_per_tensor() - - build_rescale( - tosa_fb=tosa_graph, - scale=[final_output_scale], - # pyre-ignore[61]: Uninitialized local [61]: Local variable `bmm_result` is undefined, or not always defined. 
- input_node=bmm_result, # type: ignore[possibly-undefined] - output_name=output.name, - output_type=ts.DType.INT8, - input_zp=[0], - output_zp=[output_qparams.get_zp_per_tensor()], - rounding_mode=RoundingMode.SINGLE_ROUND, - ) diff --git a/backends/arm/operators/op_cat.py b/backends/arm/operators/op_cat.py index 65b3e2a9549..71c18530d55 100644 --- a/backends/arm/operators/op_cat.py +++ b/backends/arm/operators/op_cat.py @@ -3,16 +3,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe from typing import Any, List +import tosa_serializer as ts + from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.operators.operator_validation_utils import ( validate_num_inputs, + validate_same_dtype, + validate_valid_dtype, ) from executorch.backends.arm.tosa.mapping import TosaArg from torch.fx import Node @@ -34,11 +37,19 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts - + supported_dtypes = [ts.DType.BOOL, ts.DType.INT8, ts.DType.INT32, ts.DType.FP32] + if self.tosa_spec.support_extension("int16"): + supported_dtypes.append(ts.DType.INT16) validate_num_inputs(self.target, inputs, [1, 2]) + input_tosa_args = [TosaArg(arg, output.tosa_spec) for arg in inputs[0].special] + validate_same_dtype(self.target, [*input_tosa_args, output], ts) + validate_valid_dtype( + self.target, + [*input_tosa_args, output], + supported_dtypes, + output.tosa_spec, + ) - tensors = inputs[0].special dim = 0 if len(inputs) < 2 else inputs[1].number rank = len(output.shape) dim = (dim + rank) % rank @@ -50,8 +61,8 @@ def define_node( self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().CONCAT, - [tensor.name for tensor in tensors], + ts.Op.CONCAT, + [tensor.name for tensor in input_tosa_args], [output.name], attr, ) diff --git a/backends/arm/operators/op_ceil.py b/backends/arm/operators/op_ceil.py index 5cf89710436..27ee81d0abe 100644 --- a/backends/arm/operators/op_ceil.py +++ b/backends/arm/operators/op_ceil.py @@ -7,6 +7,8 @@ import torch.fx +import tosa_serializer as ts + from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -38,8 +40,6 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts # type: ignore # noqa: F401 - validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( @@ -49,6 +49,8 @@ def define_node( output.tosa_spec, ) + attr = ts.TosaSerializerAttribute() + attr.CeilAttribute() self._serialize_operator( - node, tosa_graph, ts.TosaOp.Op().CEIL, [inputs[0].name], [output.name] + node, tosa_graph, ts.Op.CEIL, [inputs[0].name], [output.name], attr ) diff --git a/backends/arm/operators/op_clamp.py b/backends/arm/operators/op_clamp.py index b0bf044a213..d90f92f5e4b 100644 --- a/backends/arm/operators/op_clamp.py +++ b/backends/arm/operators/op_clamp.py @@ -1,15 +1,14 @@ # Copyright 2025 Arm Limited and/or its affiliates. -# All rights reserved. 
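
# The BMM visitor removed above rescaled its INT32 MATMUL accumulator back to
# the output domain with scale = (s_a * s_b) / s_out. A quick numeric check of
# that identity with made-up scales and zero-point-free operands:
import numpy as np

s_a, s_b, s_out = 0.02, 0.05, 0.1
qa = np.array([[10, -4], [3, 7]], dtype=np.int32)
qb = np.array([[5, 2], [-1, 6]], dtype=np.int32)

acc = qa @ qb                            # INT32 accumulator
rescaled = acc * (s_a * s_b / s_out)     # what the trailing RESCALE applies

ref = ((qa * s_a) @ (qb * s_b)) / s_out  # float matmul, scaled at the end
assert np.allclose(rescaled, ref)
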
# # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree -# pyre-unsafe from typing import Any, List, Tuple import numpy as np import torch +import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -27,20 +26,20 @@ @register_node_visitor -class ClampVisitor_INT(NodeVisitor): +class ClampVisitor(NodeVisitor): target = "aten.clamp.default" tosa_specs = [ TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), ] def __init__(self, *args): super().__init__(*args) def _get_min_max_arguments( - self, node: Node, dtype_min: int | float, dtype_max: int | float + self, node: Node, dtype: torch.dtype ) -> Tuple[int | float, int | float]: - def cast_type(value: Any) -> int | float: if isinstance(value, int): return value @@ -48,6 +47,13 @@ def cast_type(value: Any) -> int | float: # Attempt to cast to float return float(value) + if dtype.is_floating_point: + dtype_min = torch.finfo(dtype).min + dtype_max = torch.finfo(dtype).max + else: + dtype_min = torch.iinfo(dtype).min + dtype_max = torch.iinfo(dtype).max + min_arg = dtype_min max_arg = dtype_max @@ -60,56 +66,17 @@ def cast_type(value: Any) -> int | float: return min_arg, max_arg - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import serializer.tosa_serializer as ts # type: ignore - - validate_num_inputs(self.target, inputs, [2, 3]) - validate_same_dtype(self.target, [inputs[0], output], ts) - validate_valid_dtype( - self.target, [inputs[0], output], [ts.DType.INT8], output.tosa_spec - ) - - # NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments - min_int8, max_int8 = self._get_min_max_arguments( - node, - torch.iinfo(torch.int8).min, - torch.iinfo(torch.int8).max, - ) - - attr = ts.TosaSerializerAttribute() - attr.ClampAttribute( - tosa_graph.builder, - np.int8(min_int8).tobytes(), - np.int8(max_int8).tobytes(), - nan_mode=1, - ) - - self._serialize_operator( - node, - tosa_graph, - ts.TosaOp.Op().CLAMP, - [inputs[0].name], - [output.name], - attr, - ) - - -@register_node_visitor -class ClampVisitor_FP(ClampVisitor_INT): - # inheriting 'target' from INT class - - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-1.0+FP"), - ] - - def __init__(self, *args): - super().__init__(*args) + def _to_bytes(self, value: int | float, dtype: torch.dtype) -> bytes: + if dtype == torch.float32: + return np.frombuffer(np.float32(value).tobytes(), dtype=np.uint8).tolist() + elif dtype == torch.float16: + return np.frombuffer(np.float16(value).tobytes(), dtype=np.uint8).tolist() + elif dtype == torch.int8: + return np.frombuffer(np.int8(value).tobytes(), dtype=np.uint8).tolist() + elif dtype == torch.int16: + return np.frombuffer(np.int16(value).tobytes(), dtype=np.uint8).tolist() + else: + raise ValueError(f"Unsupported dtype for to_bytes: {dtype}") def define_node( self, @@ -118,35 +85,33 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts # type: ignore - validate_num_inputs(self.target, inputs, [2, 3]) validate_same_dtype(self.target, [inputs[0], output], ts) + supported_dtypes = [ts.DType.INT8, ts.DType.FP16, ts.DType.FP32] + if self.tosa_spec.support_extension("int16"): + supported_dtypes.append(ts.DType.INT16) validate_valid_dtype( self.target, [inputs[0], output], - [ts.DType.FP16, ts.DType.FP32], + 
supported_dtypes, output.tosa_spec, ) - min_fp32, max_fp32 = self._get_min_max_arguments( - node, - torch.finfo(torch.float32).min, - torch.finfo(torch.float32).max, - ) + node_input_dtype = node.meta["val"].dtype + # NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments + min_val, max_val = self._get_min_max_arguments(node, node_input_dtype) attr = ts.TosaSerializerAttribute() attr.ClampAttribute( - tosa_graph.builder, - np.float32(min_fp32).tobytes(), - np.float32(max_fp32).tobytes(), - nan_mode=1, + self._to_bytes(min_val, node_input_dtype), + self._to_bytes(max_val, node_input_dtype), + nan_mode=ts.NanPropagationMode.PROPAGATE, ) self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().CLAMP, + ts.Op.CLAMP, [inputs[0].name], [output.name], attr, diff --git a/backends/arm/operators/op_cond_if.py b/backends/arm/operators/op_cond_if.py new file mode 100644 index 00000000000..4cf5120de31 --- /dev/null +++ b/backends/arm/operators/op_cond_if.py @@ -0,0 +1,56 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +from typing import Any, cast, List + +import tosa_serializer as ts + +from executorch.backends.arm.operators.node_visitor import ( # type: ignore + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_cf_extension, + validate_num_inputs, + validate_valid_dtype, +) +from executorch.backends.arm.tosa.mapping import TosaArg # type: ignore +from torch.fx import Node + + +@register_node_visitor +class CondVisitor(NodeVisitor): + target = "cond" + + tosa_specs = NodeVisitor.tosa_specs + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + validate_num_inputs(self.target, inputs, 4) + validate_valid_dtype(self.target, [inputs[0]], ts.DType.BOOL, self.tosa_spec) + validate_cf_extension(self.target, self.tosa_spec) + + attr = ts.TosaSerializerAttribute() + if_graph, else_graph = (cast(Node, arg).target for arg in node.args[1:3]) + attr.CondIfAttribute(if_graph, else_graph) + + self._serialize_operator( + node, + tosa_graph, + ts.Op.COND_IF, + [ + inputs[0].name, + *(subgraph_input.name for subgraph_input in inputs[-1].special), + ], + output.multiple_output_names, + attr, + ) diff --git a/backends/arm/operators/op_constant_pad_nd.py b/backends/arm/operators/op_constant_pad_nd.py index 562c4c9ea0e..47d11fb5627 100644 --- a/backends/arm/operators/op_constant_pad_nd.py +++ b/backends/arm/operators/op_constant_pad_nd.py @@ -3,12 +3,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
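
# A standalone sketch of how the ClampVisitor above derives its default
# bounds from the tensor dtype and serializes min/max as raw bytes, using
# only torch and numpy. The dtype table mirrors the _to_bytes helper but is
# an illustration, not the backend's code.
import numpy as np
import torch


def dtype_bounds(dtype: torch.dtype):
    info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
    return info.min, info.max


def to_bytes(value, dtype: torch.dtype) -> bytes:
    np_dtype = {torch.float32: np.float32, torch.float16: np.float16,
                torch.int16: np.int16, torch.int8: np.int8}[dtype]
    return np_dtype(value).tobytes()


lo, hi = dtype_bounds(torch.int8)
print(lo, hi, to_bytes(hi, torch.int8))   # -128 127 b'\x7f'
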
-# pyre-unsafe from typing import Any, List import torch +import tosa_serializer as ts + from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, ) @@ -42,8 +43,6 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts # type: ignore - validate_num_inputs(self.target, inputs, 3) validate_same_dtype(self.target, [inputs[0], output], ts) validate_valid_dtype( @@ -51,6 +50,7 @@ def define_node( [inputs[0], output], [ ts.DType.INT8, + ts.DType.INT16, ts.DType.INT32, ts.DType.FP32, ts.DType.BOOL, @@ -63,6 +63,11 @@ def define_node( qargs = input_qparams[0] pad_const_val = qargs.quantize_value(inputs[2].number).item() pad_const_dtype = ts.DType.INT8 + elif inputs[0].dtype == ts.DType.INT16: + input_qparams = get_input_qparams(node) + qargs = input_qparams[0] + pad_const_val = qargs.quantize_value(inputs[2].number).item() + pad_const_dtype = ts.DType.INT16 else: pad_const_val = inputs[2].number pad_const_dtype = inputs[0].dtype @@ -100,10 +105,13 @@ def define_node( shape=[1], dtype=pad_const_dtype, vals=[pad_const_val] ) + attr = ts.TosaSerializerAttribute() + attr.PadAttribute() self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().PAD, + ts.Op.PAD, [inputs[0].name, padding.name, pad_const.name], [output.name], + attr, ) diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py deleted file mode 100644 index 6bfe0ab21eb..00000000000 --- a/backends/arm/operators/op_conv2d.py +++ /dev/null @@ -1,239 +0,0 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe -import itertools -from typing import Any, List - -import torch - -from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_input_qparams, - get_output_qparams, -) -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.operators.operator_validation_utils import ( - validate_num_inputs, -) -from executorch.backends.arm.tosa import TosaSpecification -from executorch.backends.arm.tosa.mapping import TosaArg -from executorch.backends.arm.tosa.quant_utils import build_rescale -from executorch.backends.arm.tosa.utils import tosa_shape - - -@register_node_visitor -class Conv2dVisitor(NodeVisitor): - target = "aten.convolution.default" - - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-1.0+INT"), - TosaSpecification.create_from_string("TOSA-1.0+FP"), - ] - - def __init__(self, *args): - super().__init__(*args) - - # torch.nn.Conv2d does not require the result of - # `(input + 2 * pad - dilation * (weight - 1) - 1) / stride` - # to be an integer, but tosa currently strictly require this property. - # This function adjusts the pad value to meet the requirement. - def adjust_pad_if_needed( - self, input_size: int, input_weight: int, stride: int, pad: int, dilation: int - ) -> int: - mod_remainder = ( - input_size + 2 * pad - dilation * (input_weight - 1) - 1 - ) % stride - - # No need to adjust - if mod_remainder == 0: - return pad - - if mod_remainder > pad: - raise RuntimeError( - "This case should be handled by the SizeAdjustConv2d pass, is it enabled?" 
- ) - return pad - mod_remainder - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - from tosa.RoundingMode import RoundingMode # type: ignore - - input, weight, bias, stride, pad, dilation, _, _, group = inputs - validate_num_inputs(self.target, inputs, 9) - - # Get the attributes of convolution. - attr = ts.TosaSerializerAttribute() - pad_attr = [val for val in pad.special for _ in (0, 1)] - stride_attr = stride.special - dilation_attr = dilation.special - - # Adjust the pad value if needed to meet the - # strict convolution output shape calculation. - pad_attr[1] = self.adjust_pad_if_needed( - input.shape[2], - weight.shape[2], - stride_attr[0], - pad_attr[1], - dilation_attr[0], - ) - pad_attr[3] = self.adjust_pad_if_needed( - input.shape[3], - weight.shape[3], - stride_attr[1], - pad_attr[3], - dilation_attr[1], - ) - - input_zp = 0 - if inputs[0].dtype == ts.DType.INT8: - # int8 input requires quantization information - input_qparams = get_input_qparams(node) - input_zp = input_qparams[0].get_zp_per_tensor() - - weight_zp = 0 - if inputs[1].dtype == ts.DType.INT8: - # int8 weights requires quantization information - input_qparams = get_input_qparams(node) - weight_zp = input_qparams[1].zp # type: ignore[assignment] - - # The output type is int32 when input type is int8. - conv2d_output_name = output.name - if output.dtype == ts.DType.INT8: - conv2d_res = tosa_graph.addIntermediate( - tosa_shape(output.shape, output.dim_order), ts.DType.INT32 - ) - conv2d_output_name = conv2d_res.name - acc_type = ( - inputs[0].dtype if inputs[0].dtype == ts.DType.FP32 else ts.DType.INT32 - ) - - tosa_graph.addConst( - [1], output.dtype, [input_zp], name=f"{conv2d_output_name}_input_zp" - ) - tosa_graph.addConst( - [1], - output.dtype, - weight_zp, - name=f"{conv2d_output_name}_weight_zp", - ) - - # Given input.shape is (N, Ci, H, W), and weight.shape is (Co, Ci/G, H, W) - in_channels = input.shape[1] - out_channels = weight.shape[0] - if (in_channels == group.number) and (out_channels % in_channels) == 0: - """Depthwise convolution case""" - # Reshape torch shape format of weight tensor to tosa required format. 
- # https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d - m_length = int(out_channels / in_channels) - weight_post_shape = [ - weight.shape[2], - weight.shape[3], - in_channels, - m_length, - ] - - weight_reshaped = tosa_graph.addIntermediate( - weight_post_shape, - weight.dtype, - ) - shape = tosa_graph.addConst( - [len(weight_post_shape)], - ts.DType.SHAPE, - weight_post_shape, - name=weight_reshaped.name + "_shape", - ) - - reshape_attr = ts.TosaSerializerAttribute() - reshape_attr.ReshapeAttribute() - self._serialize_operator( - node, - tosa_graph, - ts.TosaOp.Op().RESHAPE, - [weight.name, shape.name], - [weight_reshaped.name], - reshape_attr, - ) - - attr = ts.TosaSerializerAttribute() - tosa_op = ts.TosaOp.Op().DEPTHWISE_CONV2D - weight_name = weight_reshaped.name - - attr.DepthwiseConv2dAttribute( - pad=pad_attr, - stride=stride_attr, - dilation=dilation_attr, - local_bound=False, - acc_type=acc_type, - ) - else: - """Regular convolution case""" - tosa_op = ts.TosaOp.Op().CONV2D - weight_name = weight.name - - attr.Conv2dAttribute( - pad=pad_attr, - stride=stride_attr, - dilation=dilation_attr, - local_bound=False, - acc_type=acc_type, - ) - - self._serialize_operator( - node, - tosa_graph, - tosa_op, - [ - input.name, - weight_name, - bias.name, - f"{conv2d_output_name}_input_zp", - f"{conv2d_output_name}_weight_zp", - ], - [conv2d_output_name], - attr, - ) - - # For quantized convolution, rescale the output value back to the same - # integer value domain of the next op. Otherwise return float32 output. - if inputs[0].dtype == ts.DType.INT8: - # Get scale_factor from input, weight, and output. - input_scale = input_qparams[0].get_scale_per_tensor() # type: ignore[possibly-undefined] # pyre-ignore [61] - per_channel_quant = input_qparams[1].per_channel # pyre-ignore [61] - if per_channel_quant: - weight_scale = input_qparams[1].get_scale_per_channel() - else: - weight_scale = [ - input_qparams[1].get_scale_per_tensor() - ] # pyre-ignore [61] - output_qargs = get_output_qparams(node) - post_conv2d_scale = [ - (inp * w) / out - for inp, w, out in zip( - itertools.cycle([input_scale]), - weight_scale, - itertools.cycle([output_qargs[0].get_scale_per_tensor()]), - ) - ] - build_rescale( - tosa_fb=tosa_graph, - scale=post_conv2d_scale, - input_node=conv2d_res, # type: ignore[possibly-undefined] - output_name=output.name, - output_type=output.dtype, - input_zp=[0], - output_zp=[output_qargs[0].get_zp_per_tensor()], - per_channel=per_channel_quant, - rounding_mode=RoundingMode.SINGLE_ROUND, - ) diff --git a/backends/arm/operators/op_cos.py b/backends/arm/operators/op_cos.py index 0350733190c..e6039730b69 100644 --- a/backends/arm/operators/op_cos.py +++ b/backends/arm/operators/op_cos.py @@ -3,10 +3,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
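
# A quick check of the pad-trimming rule from the convolution visitor removed
# above: the trailing pad is reduced until
# (input + 2*pad - dilation*(kernel - 1) - 1) divides evenly by the stride,
# which is what TOSA's strict output-size formula requires. Sizes below are
# made up; torch floors the output size, so the surplus padded row is unused.
import torch
import torch.nn.functional as F


def adjust_trailing_pad(size, kernel, stride, pad, dilation=1):
    rem = (size + 2 * pad - dilation * (kernel - 1) - 1) % stride
    return pad if rem == 0 else pad - rem


size, kernel, stride, pad, dilation = 8, 3, 2, 1, 1
trailing = adjust_trailing_pad(size, kernel, stride, pad, dilation)  # -> 0

out_torch = F.conv2d(
    torch.zeros(1, 1, size, size),
    torch.zeros(1, 1, kernel, kernel),
    stride=stride,
    padding=pad,
).shape[-1]
out_tosa = (size + pad + trailing - dilation * (kernel - 1) - 1) // stride + 1
assert out_torch == out_tosa == 4
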
-# pyre-unsafe from typing import List -import serializer.tosa_serializer as ts # type: ignore +import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -43,7 +42,8 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec ) - + attr = ts.TosaSerializerAttribute() + attr.CosAttribute() self._serialize_operator( - node, tosa_graph, ts.TosaOp.Op().COS, [inputs[0].name], [output.name] + node, tosa_graph, ts.Op.COS, [inputs[0].name], [output.name], attr ) diff --git a/backends/arm/operators/op_eq.py b/backends/arm/operators/op_eq.py index 2136fe2e946..bd72c9491ca 100644 --- a/backends/arm/operators/op_eq.py +++ b/backends/arm/operators/op_eq.py @@ -3,11 +3,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils +import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -43,36 +42,23 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, inputs, ts) validate_valid_dtype( self.target, inputs, - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) - input_nodes = inputs - # Handle quantization - if inputs[0].dtype == ts.DType.INT8: - # Rescale inputs to 32 bit - rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_spec - ) - - # Update IO - input_nodes = rescaled_inputs - - # Do the equal comparison + attr = ts.TosaSerializerAttribute() + attr.EqualAttribute() self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().EQUAL, - [input_nodes[0].name, input_nodes[1].name], + ts.Op.EQUAL, + [inputs[0].name, inputs[1].name], [output.name], - None, + attr, ) diff --git a/backends/arm/operators/op_erf.py b/backends/arm/operators/op_erf.py index 7797b61e562..e642a4059fe 100644 --- a/backends/arm/operators/op_erf.py +++ b/backends/arm/operators/op_erf.py @@ -2,10 +2,11 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
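
# The EQUAL lowering above validates INT32/FP32 inputs but a BOOL output;
# that mirrors the contract of torch.eq, which always yields a boolean mask:
import torch

a = torch.tensor([1, 2, 3], dtype=torch.int32)
b = torch.tensor([1, 0, 3], dtype=torch.int32)
print(torch.eq(a, b).dtype)   # torch.bool
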
-# pyre-unsafe from typing import Any, List import torch.fx + +import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -36,8 +37,6 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts - validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( @@ -48,6 +47,8 @@ def define_node( ) # MI lowering + attr = ts.TosaSerializerAttribute() + attr.ErfAttribute() self._serialize_operator( - node, tosa_graph, ts.TosaOp.Op().ERF, [inputs[0].name], [output.name] + node, tosa_graph, ts.Op.ERF, [inputs[0].name], [output.name], attr ) diff --git a/backends/arm/operators/op_exp.py b/backends/arm/operators/op_exp.py index f5d5aef2213..72e89b6906b 100644 --- a/backends/arm/operators/op_exp.py +++ b/backends/arm/operators/op_exp.py @@ -3,9 +3,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe from typing import Any, List +import tosa_serializer as ts + from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -37,8 +38,6 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts - validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( @@ -48,6 +47,8 @@ def define_node( output.tosa_spec, ) + attr = ts.TosaSerializerAttribute() + attr.ExpAttribute() self._serialize_operator( - node, tosa_graph, ts.TosaOp.Op().EXP, [inputs[0].name], [output.name] + node, tosa_graph, ts.Op.EXP, [inputs[0].name], [output.name], attr ) diff --git a/backends/arm/operators/op_floor.py b/backends/arm/operators/op_floor.py index 77d712096fa..d9f831dfb35 100644 --- a/backends/arm/operators/op_floor.py +++ b/backends/arm/operators/op_floor.py @@ -7,6 +7,8 @@ import torch.fx +import tosa_serializer as ts + from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -38,8 +40,6 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts # type: ignore # noqa: F401 - validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( @@ -49,6 +49,8 @@ def define_node( output.tosa_spec, ) + attr = ts.TosaSerializerAttribute() + attr.FloorAttribute() self._serialize_operator( - node, tosa_graph, ts.TosaOp.Op().FLOOR, [inputs[0].name], [output.name] + node, tosa_graph, ts.Op.FLOOR, [inputs[0].name], [output.name], attr ) diff --git a/backends/arm/operators/op_ge.py b/backends/arm/operators/op_ge.py index c538e735880..754778487e9 100644 --- a/backends/arm/operators/op_ge.py +++ b/backends/arm/operators/op_ge.py @@ -3,11 +3,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils +import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -43,35 +42,23 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, inputs, ts) validate_valid_dtype( self.target, inputs, - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) - input_nodes = inputs - # Handle quantization - if inputs[0].dtype == ts.DType.INT8: - # Rescale inputs to 32 bit - rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_spec - ) - - # Update IO - input_nodes = rescaled_inputs - + attr = ts.TosaSerializerAttribute() + attr.GreaterEqualAttribute() self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().GREATER_EQUAL, - [input_nodes[0].name, input_nodes[1].name], + ts.Op.GREATER_EQUAL, + [inputs[0].name, inputs[1].name], [output.name], - None, + attr, ) diff --git a/backends/arm/operators/op_gt.py b/backends/arm/operators/op_gt.py index d407e28c1b6..2a483f735a7 100644 --- a/backends/arm/operators/op_gt.py +++ b/backends/arm/operators/op_gt.py @@ -3,11 +3,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils +import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -43,35 +42,23 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, inputs, ts) validate_valid_dtype( self.target, inputs, - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) - input_nodes = inputs - # Handle quantization - if inputs[0].dtype == ts.DType.INT8: - # Rescale inputs to 32 bit - rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_spec - ) - - # Update IO - input_nodes = rescaled_inputs - + attr = ts.TosaSerializerAttribute() + attr.GreaterAttribute() self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().GREATER, - [input_nodes[0].name, input_nodes[1].name], + ts.Op.GREATER, + [inputs[0].name, inputs[1].name], [output.name], - None, + attr, ) diff --git a/backends/arm/operators/op_index_select.py b/backends/arm/operators/op_index_select.py index e357416fadb..ba2aa03c7ff 100644 --- a/backends/arm/operators/op_index_select.py +++ b/backends/arm/operators/op_index_select.py @@ -3,19 +3,23 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils # noqa: F401 +import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, + validate_same_dtype, + validate_valid_dtype, +) from executorch.backends.arm.tosa.mapping import TosaArg -from executorch.backends.arm.tosa.utils import build_reshape_tosa_1_0 +from executorch.backends.arm.tosa.utils import build_reshape_tosa from torch.fx import Node @@ -46,13 +50,16 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: + validate_num_inputs(self.target, inputs, 3) + validate_same_dtype(self.target, [inputs[0], output], ts) + validate_valid_dtype( + self.target, + [inputs[0], output], + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], + output.tosa_spec, + ) - import serializer.tosa_serializer as ts # type: ignore - - if len(inputs) != 3: - raise ValueError(f"Number of inputs are not 3: {len(inputs)}") - - weights, index, indices = inputs + weights, _, indices = inputs if len(weights.shape) == 2: weights_new_shape = [1, weights.shape[0], weights.shape[1]] @@ -60,7 +67,7 @@ def define_node( weights_new_shape, weights.dtype, ) - build_reshape_tosa_1_0( + build_reshape_tosa( tosa_graph, weights.name, weights_new_shape, weights_reshaped.name ) @@ -82,21 +89,21 @@ def define_node( indices_new_shape, indices.dtype, ) - build_reshape_tosa_1_0( + build_reshape_tosa( tosa_graph, indices.name, indices_new_shape, indices_reshaped.name ) + attr = ts.TosaSerializerAttribute() + attr.GatherAttribute() self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().GATHER, + ts.Op.GATHER, [weights_reshaped.name, indices_reshaped.name], [output_name], - None, + attr, ) if len(weights.shape) == 2: output_real_shape = [output.shape[0], output.shape[1]] - build_reshape_tosa_1_0( - tosa_graph, output_name, output_real_shape, output.name - ) + build_reshape_tosa(tosa_graph, output_name, output_real_shape, output.name) diff --git a/backends/arm/operators/op_index_tensor.py b/backends/arm/operators/op_index_tensor.py index 2ef7eac352b..cd0809df95b 100644 --- a/backends/arm/operators/op_index_tensor.py +++ b/backends/arm/operators/op_index_tensor.py @@ -3,7 +3,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe import math from typing import Any, List @@ -11,6 +10,7 @@ import executorch.backends.arm.tosa.utils as tutils import numpy as np +import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -24,7 +24,6 @@ from torch.fx import Node -@register_node_visitor class CommonIndexTensorVisitor(NodeVisitor): target = "aten.index.Tensor" @@ -127,7 +126,6 @@ def define_node( If the number of total elements in the values tensor exceeds int32 limits then this approach falls apart. """ - import serializer.tosa_serializer as ts validate_same_dtype(self.target, [inputs[0], output]) @@ -166,25 +164,28 @@ def define_node( # channels and thus the stride-shift. 
data = np.full(index_shape, int(values_strides[i] / C)) mul_const = tosa_graph.addConst(index_shape, index_dtype, data) - tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_{i}_shift") + tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{output.name}_{i}_shift") + attr = ts.TosaSerializerAttribute() + attr.MulAttribute() self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().MUL, - [index_name, mul_const.name, f"{node.name}_{i}_shift"], + ts.Op.MUL, + [index_name, mul_const.name, f"{output.name}_{i}_shift"], [stride_shifted_indices.name], + attr, ) reshaped_idxs = tosa_graph.addIntermediate( gather_idx_shape, index_dtype, ) - tutils.build_reshape_tosa_1_0( + tutils.build_reshape_tosa( tosa_graph, stride_shifted_indices.name, gather_idx_shape, reshaped_idxs.name, - shape_name_override=f"{node.name}_{i}_shape", + shape_name_override=f"{output.name}_{i}_shape", ) # Guarantees that the accumulation tensor is properly @@ -196,24 +197,27 @@ def define_node( reshaped_idxs.shape, reshaped_idxs.dtype, ) + attr = ts.TosaSerializerAttribute() + attr.AddAttribute() self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().ADD, + ts.Op.ADD, [gather_index_name, reshaped_idxs.name], [add_idxs.name], + attr, ) gather_index_name = add_idxs.name gather_vals_shape = [N, K, C] reshaped_input = tosa_graph.addIntermediate(gather_vals_shape, values.dtype) - tutils.build_reshape_tosa_1_0( + tutils.build_reshape_tosa( tosa_graph, values.name, gather_vals_shape, reshaped_input.name, - shape_name_override=f"{node.name}_index_shape", + shape_name_override=f"{output.name}_index_shape", ) gather_out_shape = (N, W, C) @@ -221,21 +225,23 @@ def define_node( gather_out_shape, output.dtype, ) + attr = ts.TosaSerializerAttribute() + attr.GatherAttribute() self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().GATHER, + ts.Op.GATHER, [reshaped_input.name, gather_index_name], [gather_out.name], - None, + attr, ) output_shape = tutils.tosa_shape(output.shape, output.dim_order) - tutils.build_reshape_tosa_1_0( + tutils.build_reshape_tosa( tosa_graph, gather_out.name, list(output_shape), output.name, - shape_name_override=f"{node.name}_output_shape", + shape_name_override=f"{output.name}_output_shape", ) diff --git a/backends/arm/operators/op_le.py b/backends/arm/operators/op_le.py index 403c6c233d3..aa6b52b9982 100644 --- a/backends/arm/operators/op_le.py +++ b/backends/arm/operators/op_le.py @@ -3,11 +3,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
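The `aten.index.Tensor` lowering above builds one GATHER out of several index tensors by stride-shifting each index (the MUL against `stride / C` and the shift constant), accumulating them with ADD, and reshaping values and output around the gather. A hedged NumPy sketch of the arithmetic this relies on; shapes and names are illustrative and do not mirror the backend's exact N/K/W/C reshapes:

```python
# Advanced indexing with several index tensors equals a single gather over a
# flattened view, using linear indices built from per-dimension strides.
import numpy as np

values = np.arange(2 * 3 * 4).reshape(2, 3, 4)   # trailing dim plays the "C" block role
idx0 = np.array([0, 1, 1])
idx1 = np.array([2, 0, 1])

reference = values[idx0, idx1]                   # aten.index.Tensor analogue

elem_strides = [s // values.itemsize for s in values.strides]
C = values.shape[-1]
# Stride-shift each index tensor (the MUL above), then accumulate (the ADD above).
linear = idx0 * (elem_strides[0] // C) + idx1 * (elem_strides[1] // C)
gathered = values.reshape(-1, C)[linear]         # the single GATHER

assert np.array_equal(reference, gathered)
```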
-# pyre-unsafe from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils +import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -43,35 +42,23 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, inputs, ts) validate_valid_dtype( self.target, inputs, - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) - input_nodes = inputs - # Handle quantization - if inputs[0].dtype == ts.DType.INT8: - # Rescale inputs to 32 bit - rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_spec - ) - - # Update IO - input_nodes = rescaled_inputs - + attr = ts.TosaSerializerAttribute() + attr.GreaterEqualAttribute() self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().GREATER_EQUAL, - [input_nodes[1].name, input_nodes[0].name], + ts.Op.GREATER_EQUAL, + [inputs[1].name, inputs[0].name], [output.name], - None, + attr, ) diff --git a/backends/arm/operators/op_log.py b/backends/arm/operators/op_log.py index 051b10af062..565d6d56027 100644 --- a/backends/arm/operators/op_log.py +++ b/backends/arm/operators/op_log.py @@ -3,9 +3,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe from typing import Any, List +import tosa_serializer as ts + from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -37,14 +38,13 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts - validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec ) - + attr = ts.TosaSerializerAttribute() + attr.LogAttribute() self._serialize_operator( - node, tosa_graph, ts.TosaOp.Op().LOG, [inputs[0].name], [output.name] + node, tosa_graph, ts.Op.LOG, [inputs[0].name], [output.name], attr ) diff --git a/backends/arm/operators/op_logical_not.py b/backends/arm/operators/op_logical_not.py index 640c3b4e44f..695af5f7a26 100644 --- a/backends/arm/operators/op_logical_not.py +++ b/backends/arm/operators/op_logical_not.py @@ -7,6 +7,8 @@ import torch.fx +import tosa_serializer as ts + from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -39,8 +41,6 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts # type: ignore # noqa: F401 - validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( @@ -50,10 +50,13 @@ def define_node( output.tosa_spec, ) + attr = ts.TosaSerializerAttribute() + attr.LogicalNotAttribute() self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().LOGICAL_NOT, + ts.Op.LOGICAL_NOT, [inputs[0].name], [output.name], + attr, ) diff --git a/backends/arm/operators/op_lt.py b/backends/arm/operators/op_lt.py index f5132dd4feb..4b2b1a1960b 100644 --- a/backends/arm/operators/op_lt.py +++ b/backends/arm/operators/op_lt.py @@ -3,11 +3,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this 
source tree. -# pyre-unsafe from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils +import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -43,35 +42,23 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, inputs, ts) validate_valid_dtype( self.target, inputs, - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) - input_nodes = inputs - # Handle quantization - if inputs[0].dtype == ts.DType.INT8: - # Rescale inputs to 32 bit - rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_spec - ) - - # Update IO - input_nodes = rescaled_inputs - + attr = ts.TosaSerializerAttribute() + attr.GreaterAttribute() self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().GREATER, - [input_nodes[1].name, input_nodes[0].name], + ts.Op.GREATER, + [inputs[1].name, inputs[0].name], [output.name], - None, + attr, ) diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index 39fcbf5cc64..bee0cc3fb0c 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -3,11 +3,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe from typing import Any, List import torch +import tosa_serializer as ts + from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -41,15 +42,15 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - validate_num_inputs(self.target, inputs, [3, 4, 5, 6]) validate_same_dtype(self.target, [inputs[0], output], ts) + supported_dtypes = [ts.DType.INT8, ts.DType.FP32] + if self.tosa_spec.support_extension("int16"): + supported_dtypes.append(ts.DType.INT16) validate_valid_dtype( self.target, [inputs[0], output], - [ts.DType.INT8, ts.DType.FP32], + supported_dtypes, output.tosa_spec, ) @@ -91,13 +92,16 @@ def define_node( attr = ts.TosaSerializerAttribute() attr.MaxPool2dAttribute( - kernel=kernel_size, stride=stride, pad=pad_size_list, nan_mode=1 + kernel=kernel_size, + stride=stride, + pad=pad_size_list, + nan_mode=ts.NanPropagationMode.PROPAGATE, ) self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().MAX_POOL2D, + ts.Op.MAX_POOL2D, [input_tensor.name], [output.name], attr, diff --git a/backends/arm/operators/op_maximum.py b/backends/arm/operators/op_maximum.py index 66437f8af1d..d3ab305ea3b 100644 --- a/backends/arm/operators/op_maximum.py +++ b/backends/arm/operators/op_maximum.py @@ -3,15 +3,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
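A note on the op_le.py and op_lt.py hunks above: there is no dedicated LESS/LESS_EQUAL opcode in use here, so both visitors emit GREATER / GREATER_EQUAL with the operand order swapped, which is why `inputs[1].name` comes before `inputs[0].name` in those hunks. A tiny check of the identity being relied on:

```python
# a <= b  <=>  b >= a   and   a < b  <=>  b > a,
# hence GREATER_EQUAL / GREATER with [inputs[1].name, inputs[0].name].
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([2.0, 2.0, 1.0])

assert np.array_equal(a <= b, np.greater_equal(b, a))
assert np.array_equal(a < b, np.greater(b, a))
```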
-# pyre-unsafe from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils - -from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_input_qparams, -) +import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -22,9 +17,8 @@ validate_same_dtype, validate_valid_dtype, ) -from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg -from executorch.backends.arm.tosa.utils import tosa_shape +from executorch.backends.arm.tosa.specification import TosaSpecification from torch.fx import Node @@ -47,60 +41,26 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - from tosa.NanPropagationMode import NanPropagationMode # type: ignore - validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) - scale_back = 1.0 - max_output = output - if inputs[0].dtype == ts.DType.INT8: - input_qparams = get_input_qparams(node) - if len(input_qparams) != 2: - raise ValueError( - f"Both inputs need to have quantization information for {node}" - ) - if input_qparams[0] != input_qparams[1]: - raise ValueError( - "Both inputs must have the same quantization parameters for MAX" - ) - - operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_spec - ) - - output.shape = tosa_shape(output.shape, output.dim_order) - max_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) - else: - operand_inputs = inputs - attr_maximum = ts.TosaSerializerAttribute() - - # Set to PROPOGATE as default - attr_maximum.MaximumAttribute(nan_mode=NanPropagationMode.PROPAGATE) + attr_maximum.MaximumAttribute(nan_mode=ts.NanPropagationMode.PROPAGATE) self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().MAXIMUM, + ts.Op.MAXIMUM, [ - operand_inputs[0].name, - operand_inputs[1].name, + inputs[0].name, + inputs[1].name, ], - [max_output.name], + [output.name], attr_maximum, ) - - if output.dtype == ts.DType.INT8: - # insert RESCALE from int32 back to int8 - tqutils.insert_rescale_op_to_int8( - tosa_graph, max_output, scale_back, node, self.tosa_spec - ) diff --git a/backends/arm/operators/op_minimum.py b/backends/arm/operators/op_minimum.py index 518366d5463..7f72d158d43 100644 --- a/backends/arm/operators/op_minimum.py +++ b/backends/arm/operators/op_minimum.py @@ -3,15 +3,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils +import tosa_serializer as ts -from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_input_qparams, -) from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -23,7 +19,6 @@ ) from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg -from executorch.backends.arm.tosa.utils import tosa_shape from torch.fx import Node @@ -47,59 +42,26 @@ def define_node( output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts # type: ignore - from tosa.NanPropagationMode import NanPropagationMode # type: ignore - validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) - scale_back = 1.0 - min_output = output - if inputs[0].dtype == ts.DType.INT8: - input_qparams = get_input_qparams(node) - if len(input_qparams) != 2: - raise ValueError( - f"Both inputs need to have quantization information for {node}" - ) - if input_qparams[0] != input_qparams[1]: - raise ValueError( - "Both inputs must have the same quantization parameters for MIN" - ) - - operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_spec - ) - - output.shape = tosa_shape(output.shape, output.dim_order) - min_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) - else: - operand_inputs = inputs - attr_minimum = ts.TosaSerializerAttribute() - - # Set to PROPOGATE as default - attr_minimum.MinimumAttribute(nan_mode=NanPropagationMode.PROPAGATE) + attr_minimum.MinimumAttribute(nan_mode=ts.NanPropagationMode.PROPAGATE) self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().MINIMUM, + ts.Op.MINIMUM, [ - operand_inputs[0].name, - operand_inputs[1].name, + inputs[0].name, + inputs[1].name, ], - [min_output.name], + [output.name], attr_minimum, ) - - if output.dtype == ts.DType.INT8: - # insert RESCALE from int32 back to int8 - tqutils.insert_rescale_op_to_int8( - tosa_graph, min_output, scale_back, node, self.tosa_spec - ) diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index 9d139c68242..0e10443e523 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -3,17 +3,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils -import executorch.backends.arm.tosa.utils as tutils import torch -from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_input_qparams, -) +import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -24,17 +19,17 @@ validate_same_dtype, validate_valid_dtype, ) -from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.specification import TosaSpecification @register_node_visitor -class MulVisitor_INT(NodeVisitor): +class MulVisitor(NodeVisitor): target = "aten.mul.Tensor" tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+FP"), TosaSpecification.create_from_string("TOSA-1.0+INT"), - TosaSpecification.create_from_string("TOSA-1.0+INT+int16"), ] def define_node( @@ -44,113 +39,23 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32], + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) - if inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.INT16: - input_A = inputs[0] - input_B = inputs[1] - input_qparams = get_input_qparams(node) - input_A_qargs = input_qparams[0] - input_B_qargs = input_qparams[1] - input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order) - input_B.shape = tutils.tosa_shape(input_B.shape, input_B.dim_order) - - # Rescale inputs to INT32 with zp=0 - input_A_rescaled = tqutils.build_rescale_to_int32( - tosa_graph, - input_A, - input_A_qargs.get_zp_per_tensor(), - 1.0, - tosa_spec=self.tosa_spec, - ) - input_B_rescaled = tqutils.build_rescale_to_int32( - tosa_graph, - input_B, - input_B_qargs.get_zp_per_tensor(), - 1.0, - tosa_spec=self.tosa_spec, - ) - else: - # input[0].dtype == ts.DType.INT16 or ts.DType.INT32 - # Non quantized input, natively support by TOSA.MUL - input_A_rescaled, input_B_rescaled = inputs[0], inputs[1] - - if output.dtype == ts.DType.INT8 or output.dtype == ts.DType.INT16: - output_shape = tutils.tosa_shape(output.shape, output.dim_order) - mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32) - else: - # output.dtype == ts.DType.INT32 (non-quantized) - mul_output = output - - # Do the INT32 Mul - tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_shift") - self._serialize_operator( - node, - tosa_graph, - ts.TosaOp.Op().MUL, - [input_A_rescaled.name, input_B_rescaled.name, f"{node.name}_shift"], - [mul_output.name], - ) - - if output.dtype == ts.DType.INT8: - # Scale output back to 8 bit - output_scale = ( - input_A_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined] - * input_B_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined] - ) - tqutils.insert_rescale_op_to_int8( - tosa_graph, mul_output, output_scale, node, self.tosa_spec - ) - elif output.dtype == ts.DType.INT16: - # Scale output back to 16 bit - output_scale = ( - input_A_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined] - * input_B_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined] - ) - tqutils.insert_rescale_op_to_int16( - tosa_graph, mul_output, output_scale, node, self.tosa_spec - ) - - 
-@register_node_visitor -class MulVisitor_FP(MulVisitor_INT): - # inheriting 'target' from INT class - - tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - - validate_num_inputs(self.target, inputs, 2) - validate_same_dtype(self.target, [*inputs, output], ts) - - if inputs[0].dtype == ts.DType.INT8: - return super().define_node(node, tosa_graph, inputs, output) - - input1, input2 = inputs - - tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_shift") + tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{output.name}_shift") + attr = ts.TosaSerializerAttribute() + attr.MulAttribute() self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().MUL, - [input1.name, input2.name, f"{node.name}_shift"], + ts.Op.MUL, + [inputs[0].name, inputs[1].name, f"{output.name}_shift"], [output.name], + attr, ) diff --git a/backends/arm/operators/op_neg.py b/backends/arm/operators/op_neg.py index 98aeea14bea..e0bb408e155 100644 --- a/backends/arm/operators/op_neg.py +++ b/backends/arm/operators/op_neg.py @@ -3,11 +3,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe from typing import Any, List import torch.fx +import tosa_serializer as ts + from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, get_output_qparams, @@ -53,8 +54,6 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts # type: ignore - supported_dtypes = [ ts.DType.INT8, ts.DType.INT16, @@ -81,11 +80,13 @@ def define_node( output_zp_tensor = tosa_graph.addConst( (1,), output.dtype, [output_zp], name=output.name + "_output_zp" ) - + attr = ts.TosaSerializerAttribute() + attr.NegateAttribute() self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().NEGATE, + ts.Op.NEGATE, [inputs[0].name, input_zp_tensor.name, output_zp_tensor.name], [output.name], + attr, ) diff --git a/backends/arm/operators/op_permute.py b/backends/arm/operators/op_permute.py index 92cc2b37479..fea0aea9298 100644 --- a/backends/arm/operators/op_permute.py +++ b/backends/arm/operators/op_permute.py @@ -3,12 +3,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
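On the `aten.mul.Tensor` hunk above: the visitor always adds a one-element INT8 constant named `f"{output.name}_shift"` with value 0 as MUL's third operand. Per the TOSA MUL definition (an assumption worth checking against the spec), that shift only takes effect for INT32 operands, where the wide product is scaled down by `2**shift`, and it must be zero for all other dtypes, so the patch simply pins it to 0. A hedged sketch of the behaviour, with the rounding rule marked as an assumption:

```python
# Why MUL carries a "shift" operand: for INT32 operands the product is scaled
# down by 2**shift; the hunk above always passes shift == 0, so the product is
# forwarded unscaled.
def mul_with_shift(a: int, b: int, shift: int) -> int:
    product = a * b
    if shift == 0:
        return product
    # Assumed round-half-up scaling for the INT32 case; the TOSA spec is the
    # normative definition.
    return (product + (1 << (shift - 1))) >> shift


assert mul_with_shift(7, 6, 0) == 42        # shift == 0: plain multiply
assert mul_with_shift(100, 100, 4) == 625   # 10000 / 16
```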
-# pyre-unsafe from typing import Any, List import torch +import tosa_serializer as ts + from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -110,14 +111,18 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts - validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [inputs[0], output], ts) validate_valid_dtype( self.target, [inputs[0], output], - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], + [ + ts.DType.BOOL, + ts.DType.INT8, + ts.DType.INT16, + ts.DType.INT32, + ts.DType.FP32, + ], output.tosa_spec, ) @@ -138,7 +143,7 @@ def define_node( self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().TRANSPOSE, + ts.Op.TRANSPOSE, [inputs[0].name], [output.name], attr, diff --git a/backends/arm/operators/op_pow.py b/backends/arm/operators/op_pow.py index 8e7cffc0770..33cbc290d2c 100644 --- a/backends/arm/operators/op_pow.py +++ b/backends/arm/operators/op_pow.py @@ -3,10 +3,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe from typing import Any, List +import tosa_serializer as ts + from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -39,8 +40,6 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts - validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( @@ -49,15 +48,16 @@ def define_node( [ts.DType.FP16, ts.DType.FP32], output.tosa_spec, ) - + attr = ts.TosaSerializerAttribute() + attr.PowAttribute() self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().POW, + ts.Op.POW, [ inputs[0].name, inputs[1].name, ], [output.name], - None, + attr, ) diff --git a/backends/arm/operators/op_reciprocal.py b/backends/arm/operators/op_reciprocal.py index 5aa45f740c2..108a4fac0fb 100644 --- a/backends/arm/operators/op_reciprocal.py +++ b/backends/arm/operators/op_reciprocal.py @@ -3,11 +3,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe from typing import Any, List import torch +import tosa_serializer as ts + from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -38,14 +39,13 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts - validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec ) - + attr = ts.TosaSerializerAttribute() + attr.ReciprocalAttribute() self._serialize_operator( - node, tosa_graph, ts.TosaOp.Op().RECIPROCAL, [inputs[0].name], [output.name] + node, tosa_graph, ts.Op.RECIPROCAL, [inputs[0].name], [output.name], attr ) diff --git a/backends/arm/operators/op_repeat.py b/backends/arm/operators/op_repeat.py index 5db7ce9347c..e44fede736d 100644 --- a/backends/arm/operators/op_repeat.py +++ b/backends/arm/operators/op_repeat.py @@ -3,11 +3,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
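Back on the `aten.neg` hunk (op_neg.py) a little further up: NEGATE now receives explicit input and output zero-point constants. Assuming NEGATE computes `output = output_zp - (input - input_zp)`, which matches the integer negate definition in the TOSA spec as read here, a quantized negation needs no rescale when input and output share the same scale. A small numeric check of that identity; the scale and zero points below are made-up example values:

```python
# With equal scales, negating a quantized value is purely integer arithmetic:
# q_out = output_zp - (q_in - input_zp).
scale, input_zp, output_zp = 0.05, -3, -3


def quantize(x: float) -> int:
    return round(x / scale) + input_zp


def dequantize(q: int, zp: int) -> float:
    return scale * (q - zp)


q = quantize(1.25)
q_neg = output_zp - (q - input_zp)   # integer-only negate
assert abs(dequantize(q_neg, output_zp) - (-1.25)) < 1e-9
```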
-# pyre-unsafe from typing import Any import torch + +import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -37,14 +38,18 @@ def define_node( inputs: list[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts # type: ignore - validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [inputs[0], output], ts) validate_valid_dtype( self.target, [inputs[0], output], - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], + [ + ts.DType.BOOL, + ts.DType.INT8, + ts.DType.INT16, + ts.DType.INT32, + ts.DType.FP32, + ], output.tosa_spec, ) @@ -57,14 +62,16 @@ def define_node( (len(multiples),), ts.DType.SHAPE, list(tosa_shape(multiples, output.dim_order)), - name=node.name + "_multiples", + name=output.name + "_multiples", ) + attr = ts.TosaSerializerAttribute() + attr.TileAttribute() self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().TILE, + ts.Op.TILE, [inputs[0].name, multiple_shapes.name], [output.name], - None, + attr, ) diff --git a/backends/arm/operators/op_rescale.py b/backends/arm/operators/op_rescale.py deleted file mode 100644 index d7be2be737c..00000000000 --- a/backends/arm/operators/op_rescale.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -from typing import Any, cast, List - -import torch -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.operators.operator_validation_utils import ( - validate_num_inputs, -) - -from executorch.backends.arm.tosa import TosaSpecification -from executorch.backends.arm.tosa.mapping import map_dtype, TosaArg -from executorch.backends.arm.tosa.quant_utils import build_rescale -from torch.fx import Node - - -@register_node_visitor -class RescaleVisitor(NodeVisitor): - target = "tosa.RESCALE.default" - - tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+INT")] - - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import serializer.tosa_serializer as ts # type: ignore - from tosa.RoundingMode import RoundingMode # type: ignore - - validate_num_inputs(self.target, inputs, 5) - - input_dtype = inputs[0].dtype - output_dtype = cast(torch.dtype, node.args[1]) - scale = cast(float, node.args[2]) - input_zp = cast(int, node.args[3]) - output_zp = cast(int, node.args[4]) - - if input_dtype != map_dtype(torch.int8, self.tosa_spec) and input_zp != 0: - raise ValueError( - f"If input dtype is not int8, input_zp must be 0. Got input_dtype{input_dtype=}, {input_zp=}" - ) - if output_dtype != torch.int8 and output_zp != 0: - raise ValueError( - f"If output dtype is not int8, output_zp must be 0. 
Got {ts.DTypeNames[output_dtype]}, {output_zp=}" - ) - - build_rescale( - tosa_graph, - scale=[scale], - input_node=inputs[0], - output_name=output.name, - output_type=output.dtype, - input_zp=[input_zp], - output_zp=[output_zp], - rounding_mode=RoundingMode.SINGLE_ROUND, - per_channel=False, - ) diff --git a/backends/arm/operators/op_rshift_tensor.py b/backends/arm/operators/op_rshift_tensor.py index 2a41d685f5d..0b5717aa403 100644 --- a/backends/arm/operators/op_rshift_tensor.py +++ b/backends/arm/operators/op_rshift_tensor.py @@ -3,12 +3,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe from typing import Any, List import torch +import tosa_serializer as ts + from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -34,8 +35,6 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts - validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( @@ -56,7 +55,7 @@ def define_node( self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().ARITHMETIC_RIGHT_SHIFT, + ts.Op.ARITHMETIC_RIGHT_SHIFT, [inputs[0].name, inputs[1].name], [output.name], attr, diff --git a/backends/arm/operators/op_rsqrt.py b/backends/arm/operators/op_rsqrt.py index 362a30f1cf5..a86eaa40985 100644 --- a/backends/arm/operators/op_rsqrt.py +++ b/backends/arm/operators/op_rsqrt.py @@ -3,11 +3,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe from typing import Any, List import torch +import tosa_serializer as ts + from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -38,14 +39,13 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts - validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec ) - + attr = ts.TosaSerializerAttribute() + attr.RsqrtAttribute() self._serialize_operator( - node, tosa_graph, ts.TosaOp.Op().RSQRT, [inputs[0].name], [output.name] + node, tosa_graph, ts.Op.RSQRT, [inputs[0].name], [output.name], attr ) diff --git a/backends/arm/operators/op_sigmoid.py b/backends/arm/operators/op_sigmoid.py index 2c4673d6b5f..908544ff00c 100644 --- a/backends/arm/operators/op_sigmoid.py +++ b/backends/arm/operators/op_sigmoid.py @@ -3,9 +3,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe from typing import Any, List +import tosa_serializer as ts + from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -37,14 +38,13 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts - validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec ) - + attr = ts.TosaSerializerAttribute() + attr.SigmoidAttribute() self._serialize_operator( - node, tosa_graph, ts.TosaOp.Op().SIGMOID, [inputs[0].name], [output.name] + node, tosa_graph, ts.Op.SIGMOID, [inputs[0].name], [output.name], attr ) diff --git a/backends/arm/operators/op_sin.py b/backends/arm/operators/op_sin.py index 76aee063555..faa249917c3 100644 --- a/backends/arm/operators/op_sin.py +++ b/backends/arm/operators/op_sin.py @@ -3,10 +3,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe from typing import List -import serializer.tosa_serializer as ts # type: ignore +import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -43,7 +42,8 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec ) - + attr = ts.TosaSerializerAttribute() + attr.SinAttribute() self._serialize_operator( - node, tosa_graph, ts.TosaOp.Op().SIN, [inputs[0].name], [output.name] + node, tosa_graph, ts.Op.SIN, [inputs[0].name], [output.name], attr ) diff --git a/backends/arm/operators/op_slice.py b/backends/arm/operators/op_slice.py index 12d38060aa6..21c86e5f7c4 100644 --- a/backends/arm/operators/op_slice.py +++ b/backends/arm/operators/op_slice.py @@ -3,10 +3,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe from typing import Any, List +import tosa_serializer as ts + from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -21,17 +22,34 @@ def _fixup_start(start, shape, dim): - if start.number < 0: - return start.number % shape[dim] - else: - return start.number + # Normalize start index and clamp into [0, shape[dim]]. + # If not a constant, default to 0. + idx = getattr(start, "number", 0) + # Handle negative wrap-around + if idx < 0: + idx = idx % shape[dim] + # Clamp into valid bounds + if idx < 0: + idx = 0 + elif idx > shape[dim]: + idx = shape[dim] + return idx def _fixup_end(end, shape, dim): - if end.number < 0: - return end.number % shape[dim] - else: - return min(end.number, shape[dim]) + # Normalize end index and clamp into [0, shape[dim]]. 
+ max_dim = shape[dim] + # If not a constant, default to the full size + idx = getattr(end, "number", max_dim) + # Handle negative wrap-around + if idx < 0: + idx = idx % max_dim + # Clamp into valid bounds + if idx < 0: + idx = 0 + elif idx > max_dim: + idx = max_dim + return idx @register_node_visitor @@ -50,14 +68,18 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts # type: ignore - validate_num_inputs(self.target, inputs, [4, 5]) validate_same_dtype(self.target, [inputs[0], output], ts) validate_valid_dtype( self.target, [inputs[0], output], - [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], + [ + ts.DType.BOOL, + ts.DType.INT8, + ts.DType.INT16, + ts.DType.INT32, + ts.DType.FP32, + ], output.tosa_spec, ) @@ -104,7 +126,7 @@ def define_node( (starts_len,), ts.DType.SHAPE, starts, - node.name + "_start_shape", + output.name + "_start_shape", ) sizes = [size if i == dim else shape[i] for i in input_node.dim_order] @@ -114,14 +136,16 @@ def define_node( sizes_len = 1 sizes = [0] sizes_tensor = tosa_graph.addConst( - (sizes_len,), ts.DType.SHAPE, sizes, node.name + "_sizes_shape" + (sizes_len,), ts.DType.SHAPE, sizes, output.name + "_sizes_shape" ) + attr = ts.TosaSerializerAttribute() + attr.SliceAttribute() self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().SLICE, + ts.Op.SLICE, [input_node.name, start_tensor.name, sizes_tensor.name], [output.name], - None, + attr, ) diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py index 9c27fddf68a..039a2f6bd68 100644 --- a/backends/arm/operators/op_sub.py +++ b/backends/arm/operators/op_sub.py @@ -3,12 +3,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
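The reworked `_fixup_start` / `_fixup_end` helpers in the slice hunk above normalize negative bounds, clamp out-of-range ones, and fall back to sensible defaults (0 and the full dimension size) when the bound is not a constant and so carries no `.number`. A standalone check of that intended behaviour; `Bound` is a hypothetical stand-in for the `TosaArg`-like objects the real code passes in, and the if/elif clamping from the hunk is condensed to an equivalent min/max:

```python
# Standalone check of the slice-bound helpers added above.
class Bound:
    def __init__(self, number=None):
        if number is not None:
            self.number = number          # only constant bounds carry .number


def _fixup_start(start, shape, dim):
    idx = getattr(start, "number", 0)     # non-constant start -> 0
    if idx < 0:
        idx = idx % shape[dim]            # negative wrap-around
    return min(max(idx, 0), shape[dim])   # clamp into [0, shape[dim]]


def _fixup_end(end, shape, dim):
    max_dim = shape[dim]
    idx = getattr(end, "number", max_dim)  # non-constant end -> full size
    if idx < 0:
        idx = idx % max_dim
    return min(max(idx, 0), max_dim)


shape, dim = [8, 4], 0
assert _fixup_start(Bound(-2), shape, dim) == 6    # negative index wraps
assert _fixup_start(Bound(100), shape, dim) == 8   # clamped to the dim size
assert _fixup_start(Bound(), shape, dim) == 0      # no constant -> 0
assert _fixup_end(Bound(-1), shape, dim) == 7
assert _fixup_end(Bound(), shape, dim) == 8        # no constant -> full slice
```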
-# pyre-unsafe from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils -import executorch.backends.arm.tosa.utils as tutils +import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -19,22 +17,20 @@ validate_same_dtype, validate_valid_dtype, ) -from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.specification import TosaSpecification from torch.fx import Node @register_node_visitor -class SubVisitor_INT(NodeVisitor): +class SubVisitor(NodeVisitor): target = "aten.sub.Tensor" tosa_specs = [ TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), ] - def __init__(self, *args): - super().__init__(*args) - def define_node( self, node: Node, @@ -42,98 +38,26 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT8, ts.DType.INT32], + [ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) - scale_back = 1.0 - if inputs[0].dtype == ts.DType.INT8: - rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale( - tosa_graph, inputs, node, self.tosa_spec - ) - else: - # input[0].dtype == ts.DType.INT32 - # Non quantized input, natively support by TOSA.SUB - rescaled_inputs = inputs + attr = ts.TosaSerializerAttribute() + attr.SubAttribute() - if output.dtype == ts.DType.INT8: - broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order) - sub_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32) - else: - # output.dtype == ts.DType.INT32 - sub_output = output - - # Do the INT32 Sub self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().SUB, + ts.Op.SUB, [ - rescaled_inputs[0].name, - rescaled_inputs[1].name, + inputs[0].name, + inputs[1].name, ], - [sub_output.name], - None, + [output.name], + attr, ) - - if output.dtype == ts.DType.INT8: - # Scale output back to 8 bit - # pyre-ignore - tqutils.insert_rescale_op_to_int8( - tosa_graph, - sub_output, - scale_back, - node, - compute_rescale=False, - tosa_spec=self.tosa_spec, - ) # type: ignore[possibly-undefined] - - -@register_node_visitor -class SubVisitor_FP(SubVisitor_INT): - # inheriting 'target' from INT class - - tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - - validate_num_inputs(self.target, inputs, 2) - validate_same_dtype(self.target, [*inputs, output], ts) - - if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]: - # Call the inherited define_node for handling integers - super().define_node(node, tosa_graph, inputs, output) - else: - # FP32 Sub lowering - validate_valid_dtype( - self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec - ) - - # MI lowering - self._serialize_operator( - node, - tosa_graph, - ts.TosaOp.Op().SUB, - [inputs[0].name, inputs[1].name], - [output.name], - None, - ) diff --git a/backends/arm/operators/op_sum.py b/backends/arm/operators/op_sum.py index 0bd152a8b8c..e956359736c 100644 --- a/backends/arm/operators/op_sum.py +++ 
b/backends/arm/operators/op_sum.py @@ -3,12 +3,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils -import executorch.backends.arm.tosa.utils as tutils +import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -17,6 +15,7 @@ from executorch.backends.arm.operators.operator_validation_utils import ( validate_num_inputs, validate_same_dtype, + validate_valid_dtype, ) from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg @@ -24,16 +23,14 @@ @register_node_visitor -class SumVisitor_INT(NodeVisitor): +class SumVisitor(NodeVisitor): target = "aten.sum.dim_IntList" tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+FP"), TosaSpecification.create_from_string("TOSA-1.0+INT"), ] - def __init__(self, *args): - super().__init__(*args) - def define_node( self, node: Node, @@ -41,78 +38,26 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - validate_num_inputs(self.target, inputs, 3) validate_same_dtype(self.target, [inputs[0], output], ts) - - tensor = inputs[0] - input_shape = list(tensor.shape) - dim = int(inputs[1].number % len(input_shape)) - - output_shape = input_shape - output_shape[dim] = 1 # Output shape is input shape with dim reduced - - # Rescale input to 32 bit - rescaled_inputs, scale = tqutils.insert_rescale_ops_to_int32( - tosa_graph, [tensor], node, self.tosa_spec - ) - - attr = ts.TosaSerializerAttribute() - attr.ReduceSumAttribute(tensor.dim_order.index(dim)) - - intermediate = tosa_graph.addIntermediate( - tutils.tosa_shape(output_shape, tensor.dim_order), - dtype=ts.DType.INT32, - ) - - tosa_graph.addOperator( - ts.TosaOp.Op().REDUCE_SUM, - [rescaled_inputs[0].name], - [intermediate.name], - attr, - ) - - tqutils.insert_rescale_op_to_int8( - tosa_graph, intermediate, scale, node, self.tosa_spec + validate_valid_dtype( + self.target, + [inputs[0], output], + [ts.DType.INT32, ts.DType.FP32], + output.tosa_spec, ) - -@register_node_visitor -class SumVisitor_FP(SumVisitor_INT): - # inheriting 'target' from INT class - - tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - - validate_num_inputs(self.target, inputs, 3) - validate_same_dtype(self.target, [inputs[0], output], ts) - tensor = inputs[0] input_shape = list(tensor.shape) dim = int(inputs[1].number % len(input_shape)) - output_shape = input_shape - output_shape[dim] = 1 # Output shape is input shape with dim reduced - attr = ts.TosaSerializerAttribute() attr.ReduceSumAttribute(tensor.dim_order.index(dim)) - tosa_graph.addOperator( - ts.TosaOp.Op().REDUCE_SUM, + self._serialize_operator( + node, + tosa_graph, + ts.Op.REDUCE_SUM, [tensor.name], [output.name], attr, diff --git a/backends/arm/operators/op_table.py b/backends/arm/operators/op_table.py deleted file mode 100644 index 41b40268f6d..00000000000 --- a/backends/arm/operators/op_table.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. 
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -from typing import Any, List - -import torch -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.operators.operator_validation_utils import ( - validate_num_inputs, - validate_valid_dtype, -) - -from executorch.backends.arm.tosa import TosaSpecification -from executorch.backends.arm.tosa.mapping import TosaArg - - -@register_node_visitor -class TableVisitor(NodeVisitor): - target = "tosa.TABLE.default" - - tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+INT")] - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import serializer.tosa_serializer as ts # type: ignore - - validate_num_inputs(self.target, inputs, 2) - validate_valid_dtype( - self.target, inputs, [ts.DType.INT8, ts.DType.INT16], output.tosa_spec - ) - if inputs[0].dtype == ts.DType.INT8: - validate_valid_dtype(self.target, output, ts.DType.INT8, output.tosa_spec) - if inputs[0].dtype == ts.DType.INT16: - validate_valid_dtype(self.target, output, ts.DType.INT32, output.tosa_spec) - - if inputs[1].name not in self._exported_program.state_dict.keys(): # type: ignore[union-attr] - raise RuntimeError( - f"Did not find key {node.name} in state_dict {self._exported_program.state_dict.keys()}." - ) - - table = self._exported_program.state_dict[inputs[1].name] # type: ignore[union-attr] - - table_tensor_name = node.name + "_table" - tosa_graph.addConst( - table.shape, - ts.DType.INT8 if inputs[0].dtype == ts.DType.INT8 else ts.DType.INT16, - table.detach().numpy(), - name=table_tensor_name, - ) - - self._serialize_operator( - node, - tosa_graph, - ts.TosaOp.Op().TABLE, - [inputs[0].name, table_tensor_name], - [output.name], - None, - ) diff --git a/backends/arm/operators/op_tanh.py b/backends/arm/operators/op_tanh.py index 5837825a6a1..c4603e90118 100644 --- a/backends/arm/operators/op_tanh.py +++ b/backends/arm/operators/op_tanh.py @@ -3,9 +3,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe from typing import Any, List +import tosa_serializer as ts + from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -38,14 +39,13 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts - validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec ) - + attr = ts.TosaSerializerAttribute() + attr.TanhAttribute() self._serialize_operator( - node, tosa_graph, ts.TosaOp.Op().TANH, [inputs[0].name], [output.name] + node, tosa_graph, ts.Op.TANH, [inputs[0].name], [output.name], attr ) diff --git a/backends/arm/operators/op_to_dim_order_copy.py b/backends/arm/operators/op_to_dim_order_copy.py index aa5873f698e..9d3aff83554 100644 --- a/backends/arm/operators/op_to_dim_order_copy.py +++ b/backends/arm/operators/op_to_dim_order_copy.py @@ -3,11 +3,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe from typing import Any, List import torch +import tosa_serializer as ts + from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -40,10 +41,9 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts # type: ignore - validate_num_inputs(self.target, inputs, 1) - + attr = ts.TosaSerializerAttribute() + attr.CastAttribute() self._serialize_operator( - node, tosa_graph, ts.TosaOp.Op().CAST, [inputs[0].name], [output.name] + node, tosa_graph, ts.Op.CAST, [inputs[0].name], [output.name], attr ) diff --git a/backends/arm/operators/op_tosa_conv2d.py b/backends/arm/operators/op_tosa_conv2d.py new file mode 100644 index 00000000000..b97242d8373 --- /dev/null +++ b/backends/arm/operators/op_tosa_conv2d.py @@ -0,0 +1,139 @@ +# Copyright 2023-2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import tosa_serializer as ts + +"""Provide a visitor for lowering 2D convolution to TOSA (INT/FP).""" + +from typing import Any, List + +import torch + +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, +) +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, + validate_valid_dtype, +) +from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.specification import TosaSpecification + + +@register_node_visitor +class Conv2dVisitor(NodeVisitor): + """Provide a visitor that serializes TOSA ``CONV2D``.""" + + target = "tosa.CONV2D.default" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def _get_tosa_op(self): + return ts.Op.CONV2D + + def _get_attr_func(self, attr): + return attr.Conv2dAttribute + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + """Define the TOSA CONV2D/DEPTHWISE_CONV2D operator.""" + + input, weight, bias, stride, pad, dilation, _, _, group = inputs + validate_num_inputs(self.target, inputs, 9) + + valid_input_dtypes = [] + if self.tosa_spec.support_float(): + valid_input_dtypes.append(ts.DType.FP32) + if self.tosa_spec.support_integer(): + valid_input_dtypes.append(ts.DType.INT8) + + if self.tosa_spec.support_extension("int16"): + valid_input_dtypes.append(ts.DType.INT16) + # Check constraints for int16 activations + if inputs[0].dtype == ts.DType.INT16: + validate_valid_dtype( + self.target, [inputs[1]], [ts.DType.INT8], self.tosa_spec + ) + validate_valid_dtype( + self.target, [inputs[2]], [ts.DType.INT48], self.tosa_spec + ) + + validate_valid_dtype( + self.target, + [inputs[0]], + valid_input_dtypes, + self.tosa_spec, + ) + + # Get the attributes of convolution. 
+ pad_attr = pad.special + stride_attr = stride.special + dilation_attr = dilation.special + + input_zp = 0 + if inputs[0].dtype in (ts.DType.INT8, ts.DType.INT16): + # int8 and int16 input requires quantization information + input_qparams = get_input_qparams(node) + input_zp = input_qparams[0].get_zp_per_tensor() + + weight_zp = 0 + if inputs[1].dtype == ts.DType.INT8: + # int8 weights requires quantization information + input_qparams = get_input_qparams(node) + weight_zp = input_qparams[1].zp # type: ignore[assignment] + + conv2d_output_name = output.name + acc_type = output.dtype + + tosa_graph.addConst( + [1], inputs[0].dtype, [input_zp], name=f"{conv2d_output_name}_input_zp" + ) + tosa_graph.addConst( + [1], + inputs[1].dtype, + weight_zp, + name=f"{conv2d_output_name}_weight_zp", + ) + + tosa_op = self._get_tosa_op() + + attr = ts.TosaSerializerAttribute() + self._get_attr_func(attr)( + pad=pad_attr, + stride=stride_attr, + dilation=dilation_attr, + local_bound=False, + acc_type=acc_type, + ) + + self._serialize_operator( + node, + tosa_graph, + tosa_op, + [ + input.name, + weight.name, + bias.name, + f"{conv2d_output_name}_input_zp", + f"{conv2d_output_name}_weight_zp", + ], + [conv2d_output_name], + attr, + ) diff --git a/backends/arm/operators/op_tosa_conv3d.py b/backends/arm/operators/op_tosa_conv3d.py new file mode 100644 index 00000000000..e0a8d2ef6ac --- /dev/null +++ b/backends/arm/operators/op_tosa_conv3d.py @@ -0,0 +1,24 @@ +# Copyright 2023-2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Provide a visitor for lowering 3D convolution to TOSA (INT/FP).""" + +from executorch.backends.arm.operators.node_visitor import register_node_visitor +from executorch.backends.arm.operators.op_tosa_conv2d import Conv2dVisitor + + +@register_node_visitor +class Conv3dVisitor(Conv2dVisitor): + """Provide a visitor that serializes TOSA ``CONV3D``.""" + + target = "tosa.CONV3D.default" + + def _get_tosa_op(self): + import serializer.tosa_serializer as ts # type: ignore + + return ts.Op.CONV3D + + def _get_attr_func(self, attr): + return attr.Conv3dAttribute diff --git a/backends/arm/operators/op_tosa_depthwise_conv2d.py b/backends/arm/operators/op_tosa_depthwise_conv2d.py new file mode 100644 index 00000000000..78e6e4424cb --- /dev/null +++ b/backends/arm/operators/op_tosa_depthwise_conv2d.py @@ -0,0 +1,32 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Provide a visitor for lowering 2D depthwise convolution to TOSA (INT/FP).""" + +import tosa_serializer as ts + +from executorch.backends.arm.operators.node_visitor import register_node_visitor +from executorch.backends.arm.operators.op_tosa_conv2d import Conv2dVisitor +from executorch.backends.arm.tosa import TosaSpecification + + +@register_node_visitor +class DepthwiseConv2dVisitor(Conv2dVisitor): + """Provide a visitor that serializes TOSA ``DEPTHWISE_CONV2D``.""" + + target = "tosa.DEPTHWISE_CONV2D.default" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def _get_tosa_op(self): + return ts.Op.DEPTHWISE_CONV2D + + def _get_attr_func(self, attr): + return attr.DepthwiseConv2dAttribute + + # Inheriting the define_node method from Conv2dVisitor diff --git a/backends/arm/operators/op_tosa_matmul.py b/backends/arm/operators/op_tosa_matmul.py new file mode 100644 index 00000000000..993caff9867 --- /dev/null +++ b/backends/arm/operators/op_tosa_matmul.py @@ -0,0 +1,102 @@ +# Copyright 2024-2025 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Provide a visitor for lowering batched matmul (BMM) to TOSA.""" + +from typing import Any, List + +import torch +import tosa_serializer as ts + +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, +) +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, + validate_same_dtype, + validate_valid_dtype, +) +from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.arm.tosa.mapping import TosaArg + + +@register_node_visitor +class MatmulVisitor(NodeVisitor): + """Provide a visitor that serializes TOSA ``MATMUL``.""" + + target = "tosa.MATMUL.default" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + """Define the TOSA ``MATMUL`` operator.""" + validate_num_inputs(self.target, inputs, 2) + validate_same_dtype(self.target, [*inputs], ts) + supported_input_dtypes = [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32] + if self.tosa_spec.support_extension("int16"): + supported_input_dtypes.append(ts.DType.INT16) + validate_valid_dtype( + self.target, + [*inputs], + supported_input_dtypes, + output.tosa_spec, + ) + supported_output_dtypes = [ts.DType.INT32, ts.DType.FP32] + if self.tosa_spec.support_extension("int16"): + supported_output_dtypes.append(ts.DType.INT48) + validate_valid_dtype( + self.target, + [output], + supported_output_dtypes, + output.tosa_spec, + ) + + # We need to get the zero points and add an intermediate tensor for INT16 case + if inputs[0].dtype in (ts.DType.INT8, ts.DType.INT16): + input_qparams = get_input_qparams(node) + input0_zp = input_qparams[0].get_zp_per_tensor() + input1_zp = input_qparams[1].get_zp_per_tensor() + else: + input0_zp, input1_zp = 0, 0 + + input_A_ZP_name = f"{output.name}_A_ZP" + input_B_ZP_name = f"{output.name}_B_ZP" + tosa_graph.addConst([1], inputs[0].dtype, [input0_zp], 
name=input_A_ZP_name) + tosa_graph.addConst([1], inputs[1].dtype, [input1_zp], name=input_B_ZP_name) + + # Add the MATMUL to the TOSA graph. + attr = ts.TosaSerializerAttribute() + attr.MatMulAttribute() + + self._serialize_operator( + node, + tosa_graph, + ts.Op.MATMUL, + [ + inputs[0].name, + inputs[1].name, + input_A_ZP_name, + input_B_ZP_name, + ], + [output.name], + attr, + ) diff --git a/backends/arm/operators/op_tosa_rescale.py b/backends/arm/operators/op_tosa_rescale.py new file mode 100644 index 00000000000..ae87dcc9c31 --- /dev/null +++ b/backends/arm/operators/op_tosa_rescale.py @@ -0,0 +1,261 @@ +# Copyright 2024-2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import math +from typing import Any, cast, List, Tuple + +import torch + +import tosa_serializer as ts +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) + +from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.arm.tosa.mapping import map_dtype, TosaArg +from torch.fx import Node + + +def _compute_multiplier_and_shift( + scales: list[float], scaleWidth: int = 32 +) -> Tuple[list[int], list[int]]: + """Derive integer multipliers and shifts from floating-point scales. + + TOSA uses the RESCALE operation to scale between values with differing + precision. The RESCALE operator is defined using an integer multiply, add, + and shift. This utility function is for calculating the multiplier and shift + given a scale. + Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling + + Args: + scales (list[float]): Scale factors to decompose into multiplier and + shift pairs. + scaleWidth (int): Bit-width of the multiplier representation; expects + ``16`` or ``32``. + + Returns: + Tuple[list[int], list[int]]: Parallel lists containing the computed + multipliers and right shifts. + + Raises: + ValueError: If ``scaleWidth`` is not supported. + + """ + if scaleWidth == 16: + offset = 15 + elif scaleWidth == 32: + offset = 31 + else: + raise ValueError( + f"Unsupported scale width: {scaleWidth}, only 16 and 32 are valid values." + ) + + multipliers = [] + shifts = [] + for scale in scales: + mantissa, exponent = math.frexp(scale) + shift = exponent + + const_2_power_15_or_31 = 1 << offset + shifted_mantissa = round(mantissa * const_2_power_15_or_31) + + assert ( + shifted_mantissa <= const_2_power_15_or_31 + ), f"Mantissa {shifted_mantissa} exceeds limit {const_2_power_15_or_31}" + + if shifted_mantissa == const_2_power_15_or_31: + shifted_mantissa = shifted_mantissa // 2 + shift += 1 + + # TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits. 
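+        # Worked example: scale = 0.25 with scaleWidth = 32 gives
+        # frexp(0.25) == (0.5, -1), so shifted_mantissa = round(0.5 * 2**31) = 2**30
+        # and, after the flip below, shift = 31 - (-1) = 32; the op then computes
+        # approximately (x * 2**30) >> 32 == x * 0.25.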
+ shift = offset - shift + + # INT32_MAX, 2^31 - 1 + assert shifted_mantissa <= (const_2_power_15_or_31 - 1), ( + f"Mantissa {shifted_mantissa} exceeds signed max " + f"{const_2_power_15_or_31 - 1}" + ) + + multiplier = shifted_mantissa + + if shift > 62: + multiplier = multiplier >> min(31, shift - 62) + shift = 62 + + assert multiplier >= 0, "Multiplier should be non-negative" + assert shift >= 2 and shift <= 62, "Shift should be in range [2, 62]" + multipliers.append(multiplier) + shifts.append(shift) + return multipliers, shifts + + +def _create_const_ops_for_rescale( + tosa_fb, + scale_32, + input_dtype, + node_name, + multipliers, + shifts, + input_zp, + output_zp, + output_dtype, + ts, +): + """Materialize constant operands required by the TOSA RESCALE op. + + For TOSA spec v1.0 RESCALE operator requires multiplier, shifts, input_zp + and output_zp to be const inputs. Create constant operators from the data + already initialized. + + Args: + tosa_fb (Any): Graph builder used to emit TOSA operators and tensors. + scale_32 (bool): Flag indicating whether multipliers use 32-bit width. + input_dtype (ts.DType): Data type of the input tensor. + node_name (str): Base name reused for created constant tensors. + multipliers (list[int]): Precomputed multiplier coefficients. + shifts (list[int]): Precomputed shift coefficients. + input_zp (list[int]): Quantization zero points for the input. + output_zp (list[int]): Quantization zero points for the output. + output_dtype (ts.DType): Data type of the output tensor. + ts (module): Reference to the ``tosa_serializer`` module. + + Returns: + list[str]: Names of the constant tensors added to ``tosa_fb`` in the + order expected by RESCALE. + + """ + + multipliers = tosa_fb.addConst( + (len(multipliers),), + ts.DType.INT32 if scale_32 else ts.DType.INT16, + multipliers, + name=node_name + "_multipliers", + ) + shifts = tosa_fb.addConst( + (len(shifts),), ts.DType.INT8, shifts, name=node_name + "_shifts" + ) + input_zp = tosa_fb.addConst( + [1], input_dtype, input_zp, name=node_name + "_input_zp" + ) + output_zp = tosa_fb.addConst( + [1], output_dtype, output_zp, name=node_name + "_output_zp" + ) + + return [multipliers.name, shifts.name, input_zp.name, output_zp.name] + + +def _build_rescale( + tosa_fb: Any, + scale: list[float], + input_node: Any, + output_name: str, + output_type: Any, + input_zp: list[int], + output_zp: list[int], + rounding_mode: ts.RoundingMode, + per_channel: bool = False, + is_scale32: bool = True, +): + """Insert a TOSA RESCALE operator configured for the quantized path. + + Args: + tosa_fb (Any): Graph builder receiving the RESCALE operator. + scale (list[float]): Scale factors applied during rescaling. + input_node (Any): Input tensor node feeding the operator. + output_name (str): Name assigned to the RESCALE output tensor. + output_type (ts.DType): Data type of the output tensor. + input_zp (list[int]): Quantization zero points for the input tensor. + output_zp (list[int]): Quantization zero points for the output tensor. + rounding_mode (ts.RoundingMode): Rounding policy for the RESCALE op. + per_channel (bool): Whether scales are applied per output channel. + is_scale32 (bool): Declared scale width; ignored when the input type is + ``ts.DType.INT48``. 
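+
+    Example:
+        Rescale a per-tensor INT32 accumulator down to INT8 (illustrative
+        values; ``acc_tensor`` stands for any tensor already added to
+        ``tosa_fb``):
+
+            _build_rescale(
+                tosa_fb,
+                scale=[0.0123],
+                input_node=acc_tensor,
+                output_name="rescaled_out",
+                output_type=ts.DType.INT8,
+                input_zp=[0],
+                output_zp=[0],
+                rounding_mode=ts.RoundingMode.SINGLE_ROUND,
+            )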
+ + """ + scaleWidth = 16 if input_node.dtype == ts.DType.INT48 else 32 + is_scale32 = False if input_node.dtype == ts.DType.INT48 else True + multipliers, shifts = _compute_multiplier_and_shift(scale, scaleWidth) + rescale_inputs = _create_const_ops_for_rescale( + tosa_fb, + is_scale32, + input_node.dtype, + output_name, + multipliers, + shifts, + input_zp, + output_zp, + output_type, + ts, + ) + attr_rescale = ts.TosaSerializerAttribute() + attr_rescale.RescaleAttribute( + scale32=is_scale32, + rounding_mode=rounding_mode, + per_channel=per_channel, + input_unsigned=False, + output_unsigned=False, + ) + + tosa_fb.addOperator( + ts.Op.RESCALE, + [input_node.name, *rescale_inputs], + [output_name], + attr_rescale, + ) + + +@register_node_visitor +class RescaleVisitor(NodeVisitor): + target = "tosa.RESCALE.default" + + tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+INT")] + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + validate_num_inputs(self.target, inputs, 5) + + input_dtype = inputs[0].dtype + output_dtype = cast(torch.dtype, node.args[1]) + scales = cast(list[float], node.args[2]) + input_zp = cast(int, node.args[3]) + output_zp = cast(int, node.args[4]) + + if ( + input_dtype + not in [ + map_dtype(torch.int8, self.tosa_spec), + map_dtype(torch.int16, self.tosa_spec), + ] + and input_zp != 0 + ): + raise ValueError( + f"If input dtype is not int8 or int16, input_zp must be 0. Got input_dtype {input_dtype=}, {input_zp=}" + ) + if output_dtype not in [torch.int8, torch.int16] and output_zp != 0: + raise ValueError( + f"If output dtype is not int8 or int16, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}" + ) + + _build_rescale( + tosa_graph, + scale=scales, + input_node=inputs[0], + output_name=output.name, + output_type=output.dtype, + input_zp=[input_zp], + output_zp=[output_zp], + rounding_mode=ts.RoundingMode.SINGLE_ROUND, + per_channel=len(scales) > 1, + ) diff --git a/backends/arm/operators/op_tosa_resize.py b/backends/arm/operators/op_tosa_resize.py new file mode 100644 index 00000000000..e7e63f155d3 --- /dev/null +++ b/backends/arm/operators/op_tosa_resize.py @@ -0,0 +1,124 @@ +# Copyright 2024-2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Any, List + +import torch + +import tosa_serializer as ts + +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, + validate_same_dtype, + validate_valid_dtype, +) +from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.utils import get_resize_parameters + + +@register_node_visitor +class ResizeVisitor(NodeVisitor): + target = "tosa.RESIZE.default" + + tosa_specs = NodeVisitor.tosa_specs + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + validate_num_inputs(self.target, inputs, [3, 4]) + supported_input_dtypes = [ts.DType.INT8, ts.DType.FP32] + if self.tosa_spec.support_extension("int16"): + supported_input_dtypes.append(ts.DType.INT16) + validate_valid_dtype( + self.target, + [inputs[0]], + supported_input_dtypes, + output.tosa_spec, + ) + supported_output_dtypes = [ts.DType.FP32] + if node.kwargs.get("resize_mode") == "bilinear": + resize_mode = ts.ResizeMode.BILINEAR + align_corners = bool(node.args[2]) + supported_output_dtypes.append(ts.DType.INT32) + if self.tosa_spec.support_extension("int16"): + supported_output_dtypes.append(ts.DType.INT48) + else: + resize_mode = ts.ResizeMode.NEAREST + align_corners = False + validate_same_dtype(self.target, [inputs[0], output], ts) + supported_output_dtypes.append(ts.DType.INT8) + if self.tosa_spec.support_extension("int16"): + supported_output_dtypes.append(ts.DType.INT16) + validate_valid_dtype( + self.target, [output], supported_output_dtypes, output.tosa_spec + ) + # tosa_shape output is NHWC, take HW + input_size_yx = tuple([inputs[0].shape[dim] for dim in inputs[0].dim_order])[ + 1:3 + ] + output_size_yx = tuple([output.shape[dim] for dim in output.dim_order])[1:3] + + # Align corners shouldn't make a difference for nearest upsampling. We set to False so + # half pixel centers are used for resize parameter logic. 
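+        # get_resize_parameters returns the per-axis scale as a rational
+        # (numerator, denominator) pair plus offset and border terms; TOSA RESIZE
+        # takes these as SHAPE constants, so they are range-checked and
+        # serialized as such below.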
+ scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters( + input_size_yx, output_size_yx, resize_mode, align_corners=align_corners + ) + + def in_int16_range(x): + return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1) + + if not in_int16_range(scale_n_yx): + raise ValueError("scale_n_yx is out of the int16 range") + if not in_int16_range(scale_d_yx): + raise ValueError("scale_d_yx is out of the int16 range") + if not in_int16_range(border_yx): + raise ValueError("border_yx is out of the int16 range") + + scale_n_vals = [int(v) for v in scale_n_yx.tolist()] + scale_d_vals = [int(v) for v in scale_d_yx.tolist()] + scales = [ + scale_n_vals[0], + scale_d_vals[0], + scale_n_vals[1], + scale_d_vals[1], + ] + scales_tensor = tosa_graph.addConst( + [len(scales)], ts.DType.SHAPE, scales, output.name + "_scales" + ) + offset = [int(v) for v in offset_yx.tolist()] + offset_tensor = tosa_graph.addConst( + [len(offset)], ts.DType.SHAPE, offset, output.name + "_offset" + ) + border = [int(v) for v in border_yx.tolist()] + border_tensor = tosa_graph.addConst( + [len(border)], ts.DType.SHAPE, border, output.name + "_border" + ) + attr = ts.TosaSerializerAttribute() + attr.ResizeAttribute(resize_mode) + + self._serialize_operator( + node, + tosa_graph, + ts.Op.RESIZE, + [ + inputs[0].name, + scales_tensor.name, + offset_tensor.name, + border_tensor.name, + ], + [output.name], + attr, + ) diff --git a/backends/arm/operators/op_tosa_table.py b/backends/arm/operators/op_tosa_table.py new file mode 100644 index 00000000000..7448898bddc --- /dev/null +++ b/backends/arm/operators/op_tosa_table.py @@ -0,0 +1,72 @@ +# Copyright 2024-2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Any, List + +import torch + +import tosa_serializer as ts +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, + validate_valid_dtype, +) + +from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.arm.tosa.mapping import TosaArg + + +@register_node_visitor +class TableVisitor(NodeVisitor): + target = "tosa.TABLE.default" + + tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+INT")] + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + validate_num_inputs(self.target, inputs, 2) + supported_input_dtypes = [ts.DType.INT8] + supported_output_dtypes = [ts.DType.INT8] + if self.tosa_spec.support_extension("int16"): + supported_input_dtypes.append(ts.DType.INT16) + supported_output_dtypes.append(ts.DType.INT32) + + validate_valid_dtype( + self.target, inputs, supported_input_dtypes, output.tosa_spec + ) + validate_valid_dtype( + self.target, output, supported_output_dtypes, output.tosa_spec + ) + + # The name of the table constant is a bit complex. + # The name of the pytorch buffer will be the target of last node argument. + # However, when it is serialized to TOSA, a submodule suffix might be added. The TOSA buffer name thus + # needs to be taken from the last TosaArg. 
+ pytorch_table_buffer_name = node.args[-1].target # type: ignore[union-attr] + tosa_table_buffer_name = inputs[-1].name + if pytorch_table_buffer_name not in self._exported_program.state_dict.keys(): + raise RuntimeError( + f"Did not find key {node.name} in state_dict {self._exported_program.state_dict.keys()}." + ) + + attr = ts.TosaSerializerAttribute() + attr.TableAttribute() + self._serialize_operator( + node, + tosa_graph, + ts.Op.TABLE, + [inputs[0].name, tosa_table_buffer_name], + [output.name], + attr, + ) diff --git a/backends/arm/operators/op_tosa_transpose.py b/backends/arm/operators/op_tosa_transpose.py new file mode 100644 index 00000000000..c5aa66a85fd --- /dev/null +++ b/backends/arm/operators/op_tosa_transpose.py @@ -0,0 +1,71 @@ +# Copyright 2024-2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Any, List + +import torch + +import tosa_serializer as ts + +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, + validate_same_dtype, + validate_valid_dtype, +) +from executorch.backends.arm.tosa.mapping import TosaArg + + +@register_node_visitor +class TransposeVisitor(NodeVisitor): + """ + This node visitor targets the tosa::TRANSPOSE op defined in the + TOSA backend dialect. Used when switching between tosa_dim_orders. + Inserts a TOSA TRANSPOSE. + """ + + target = "tosa.TRANSPOSE.default" + + tosa_specs = NodeVisitor.tosa_specs + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + validate_num_inputs(self.target, inputs, 2) + validate_same_dtype(self.target, [inputs[0], output], ts) + validate_valid_dtype( + self.target, + [inputs[0], output], + [ + ts.DType.BOOL, + ts.DType.INT8, + ts.DType.INT16, + ts.DType.INT32, + ts.DType.FP16, + ts.DType.FP32, + ], + output.tosa_spec, + ) + + output_rank = len(output.shape) + perms = [dim % output_rank for dim in inputs[1].special] + attr = ts.TosaSerializerAttribute() + attr.TransposeAttribute(perms) + self._serialize_operator( + node, + tosa_graph, + ts.Op.TRANSPOSE, + [inputs[0].name], + [output.name], + attr, + ) diff --git a/backends/arm/operators/op_transpose.py b/backends/arm/operators/op_transpose.py deleted file mode 100644 index 48766687a62..00000000000 --- a/backends/arm/operators/op_transpose.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -from typing import Any, List - -import torch - -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.operators.operator_validation_utils import ( - validate_num_inputs, - validate_same_dtype, - validate_valid_dtype, -) -from executorch.backends.arm.tosa.mapping import TosaArg - - -@register_node_visitor -class TransposeVisitor(NodeVisitor): - """ - This node visitor targets the tosa::TRANSPOSE op defined in the - TOSA backend dialect. Used when switching between tosa_dim_orders. - Inserts a TOSA TRANSPOSE. 
- """ - - target = "tosa.TRANSPOSE.default" - - tosa_specs = NodeVisitor.tosa_specs - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import serializer.tosa_serializer as ts # type: ignore - - validate_num_inputs(self.target, inputs, 2) - validate_same_dtype(self.target, [inputs[0], output], ts) - validate_valid_dtype( - self.target, - [inputs[0], output], - [ - ts.DType.INT8, - ts.DType.INT16, - ts.DType.INT32, - ts.DType.FP32, - ts.DType.BOOL, - ts.DType.FP16, - ], - output.tosa_spec, - ) - - output_rank = len(output.shape) - perms = [dim % output_rank for dim in inputs[1].special] - attr = ts.TosaSerializerAttribute() - attr.TransposeAttribute(perms) - self._serialize_operator( - node, - tosa_graph, - ts.TosaOp.Op().TRANSPOSE, - [inputs[0].name], - [output.name], - attr, - ) diff --git a/backends/arm/operators/op_upsample_bilinear2d.py b/backends/arm/operators/op_upsample_bilinear2d.py deleted file mode 100644 index 3cc620727e0..00000000000 --- a/backends/arm/operators/op_upsample_bilinear2d.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe -from typing import Any, List - -import torch - -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.operators.operator_validation_utils import ( - validate_num_inputs, - validate_same_dtype, - validate_valid_dtype, -) -from executorch.backends.arm.tosa.mapping import TosaArg -from executorch.backends.arm.tosa.quant_utils import build_rescale -from executorch.backends.arm.tosa.utils import get_resize_parameters, tosa_shape - - -@register_node_visitor -class UpsampleBilinear2dVisitor(NodeVisitor): - - target = "aten.upsample_bilinear2d.vec" - tosa_specs = NodeVisitor.tosa_specs - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import serializer.tosa_serializer as ts - from tosa.ResizeMode import ResizeMode # type: ignore - from tosa.RoundingMode import RoundingMode # type: ignore - - validate_num_inputs(self.target, inputs, 4) - validate_same_dtype(self.target, [inputs[0], output], ts) - validate_valid_dtype( - self.target, - [inputs[0], output], - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, - ) - - if inputs[0].shape is None or output.shape is None: - raise ValueError("Only static shapes are supported") - - input_dtype = inputs[0].dtype - - # tosa_shape output is NHWC, take HW - input_size_yx = tuple([inputs[0].shape[dim] for dim in inputs[0].dim_order])[ - 1:3 - ] - output_size_yx = tuple([output.shape[dim] for dim in output.dim_order])[1:3] - - # Get align_corners value from the node arguments. 
- align_corners = bool(node.args[2]) - scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters( - input_size_yx, - output_size_yx, - ResizeMode.NEAREST, - align_corners=align_corners, - ) - - def in_int16_range(x): - return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1) - - if not in_int16_range(scale_n_yx): - raise ValueError("scale_n_yx is out of the int16 range") - if not in_int16_range(scale_d_yx): - raise ValueError("scale_d_yx is out of the int16 range") - if not in_int16_range(border_yx): - raise ValueError("border_yx is out of the int16 range") - - scales = [scale_n_yx[0], scale_d_yx[0], scale_n_yx[1], scale_d_yx[1]] - - attr = ts.TosaSerializerAttribute() - attr.ResizeAttribute(mode=ResizeMode.BILINEAR) - - scales_tensor = tosa_graph.addConst( - [len(scales)], ts.DType.SHAPE, scales, node.name + "_scales" - ) - offset = offset_yx.tolist() - offset_tensor = tosa_graph.addConst( - [len(offset)], ts.DType.SHAPE, offset, node.name + "_offset" - ) - border = border_yx.tolist() - border_tensor = tosa_graph.addConst( - [len(border)], ts.DType.SHAPE, border, node.name + "_border" - ) - if input_dtype == output.dtype == ts.DType.FP32: - self._serialize_operator( - node, - tosa_graph, - ts.TosaOp.Op().RESIZE, - [ - inputs[0].name, - scales_tensor.name, - offset_tensor.name, - border_tensor.name, - ], - [output.name], - attr, - ) - return - elif input_dtype == output.dtype == ts.DType.INT8: - intermediate = tosa_graph.addIntermediate( - tosa_shape(output.shape, output.dim_order), ts.DType.INT32 - ) - self._serialize_operator( - node, - tosa_graph, - ts.TosaOp.Op().RESIZE, - [ - inputs[0].name, - scales_tensor.name, - offset_tensor.name, - border_tensor.name, - ], - [intermediate.name], - attr, - ) - - final_output_scale = float(1 / (scale_n_yx[0] * scale_n_yx[1])) - - build_rescale( - tosa_fb=tosa_graph, - scale=[final_output_scale], - input_node=intermediate, - output_name=output.name, - output_type=ts.DType.INT8, - input_zp=[0], - output_zp=[0], - rounding_mode=RoundingMode.SINGLE_ROUND, - ) - else: - raise ValueError( - "Input/output dtype not in {float32, int8}: {input_dtype=} {output.dtype=}" - ) diff --git a/backends/arm/operators/op_upsample_nearest2d.py b/backends/arm/operators/op_upsample_nearest2d.py deleted file mode 100644 index 3c3ca67c9f5..00000000000 --- a/backends/arm/operators/op_upsample_nearest2d.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-unsafe -from typing import Any, List - -import torch - -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.operators.operator_validation_utils import ( - validate_num_inputs, - validate_same_dtype, - validate_valid_dtype, -) -from executorch.backends.arm.tosa.mapping import TosaArg -from executorch.backends.arm.tosa.utils import get_resize_parameters - -from tosa.ResizeMode import ResizeMode # type: ignore - - -@register_node_visitor -class UpsampleNearest2dVisitor(NodeVisitor): - target = "aten.upsample_nearest2d.vec" - - tosa_specs = NodeVisitor.tosa_specs - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import serializer.tosa_serializer as ts - - validate_num_inputs(self.target, inputs, 3) - validate_same_dtype(self.target, [inputs[0], output], ts) - validate_valid_dtype( - self.target, - [inputs[0], output], - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, - ) - - # tosa_shape output is NHWC, take HW - input_size_yx = tuple([inputs[0].shape[dim] for dim in inputs[0].dim_order])[ - 1:3 - ] - output_size_yx = tuple([output.shape[dim] for dim in output.dim_order])[1:3] - - # Align corners shouldn't make a difference for nearest upsampling. We set to False so - # half pixel centers are used for resize parameter logic. - scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters( - input_size_yx, output_size_yx, ResizeMode.NEAREST, align_corners=False - ) - - def in_int16_range(x): - return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1) - - if not in_int16_range(scale_n_yx): - raise ValueError("scale_n_yx is out of the int16 range") - if not in_int16_range(scale_d_yx): - raise ValueError("scale_d_yx is out of the int16 range") - if not in_int16_range(border_yx): - raise ValueError("border_yx is out of the int16 range") - - scales = [scale_n_yx[0], scale_d_yx[0], scale_n_yx[1], scale_d_yx[1]] - scales_tensor = tosa_graph.addConst( - [len(scales)], ts.DType.SHAPE, scales, node.name + "_scales" - ) - offset = offset_yx.tolist() - offset_tensor = tosa_graph.addConst( - [len(offset)], ts.DType.SHAPE, offset, node.name + "_offset" - ) - border = border_yx.tolist() - border_tensor = tosa_graph.addConst( - [len(border)], ts.DType.SHAPE, border, node.name + "_border" - ) - attr = ts.TosaSerializerAttribute() - attr.ResizeAttribute( - mode=ResizeMode.NEAREST, - ) - - self._serialize_operator( - node, - tosa_graph, - ts.TosaOp.Op().RESIZE, - [ - inputs[0].name, - scales_tensor.name, - offset_tensor.name, - border_tensor.name, - ], - [output.name], - attr, - ) diff --git a/backends/arm/operators/op_view.py b/backends/arm/operators/op_view.py index 925404da917..a32cb3aac06 100644 --- a/backends/arm/operators/op_view.py +++ b/backends/arm/operators/op_view.py @@ -3,11 +3,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe from typing import Any, cast, List import torch +import tosa_serializer as ts + from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -37,8 +38,6 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts - validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [inputs[0], output], ts) validate_valid_dtype( @@ -67,7 +66,7 @@ def define_node( shape_len, ts.DType.SHAPE, shape_data, - name=node.name + "_shape", + name=output.name + "_shape", ) attr = ts.TosaSerializerAttribute() @@ -75,7 +74,7 @@ def define_node( self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().RESHAPE, + ts.Op.RESHAPE, [inputs[0].name, shape.name], [output.name], attr, diff --git a/backends/arm/operators/op_where.py b/backends/arm/operators/op_where.py index dbdbbc67944..f0b6538ac27 100644 --- a/backends/arm/operators/op_where.py +++ b/backends/arm/operators/op_where.py @@ -3,7 +3,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, List, Sequence +from typing import Any, List + +import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -21,25 +23,34 @@ @register_node_visitor -class WhereVisitor_INT(NodeVisitor): +class WhereVisitor(NodeVisitor): target = "aten.where.self" tosa_specs = [ TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), ] def __init__(self, *args): super().__init__(*args) - def _add_node_to_tosa_graph( + def define_node( self, node: Node, tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, - supported_dtypes: Sequence, ) -> None: - import serializer.tosa_serializer as ts + + supported_dtypes = [ts.DType.BOOL] + if output.tosa_spec.support_integer(): + supported_dtypes += [ + ts.DType.INT8, + ts.DType.INT16, + ts.DType.INT32, + ] + if output.tosa_spec.support_float(): + supported_dtypes += [ts.DType.FP16, ts.DType.FP32] validate_num_inputs(self.target, inputs, 3) # Not first input, which is condition tensor. 
@@ -52,62 +63,13 @@ def _add_node_to_tosa_graph( output.tosa_spec, ) + attr = ts.TosaSerializerAttribute() + attr.SelectAttribute() self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().SELECT, + ts.Op.SELECT, [inputs[0].name, inputs[1].name, inputs[2].name], [output.name], - None, - ) - - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import serializer.tosa_serializer as ts - - bi_supported_dtypes = [ - ts.DType.INT8, - ts.DType.INT16, - ts.DType.INT32, - ts.DType.BOOL, - ] - self._add_node_to_tosa_graph( - node, tosa_graph, inputs, output, bi_supported_dtypes - ) - - -@register_node_visitor -class WhereVisitor_FP(WhereVisitor_INT): - - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-1.0+FP"), - ] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import serializer.tosa_serializer as ts - - mi_supported_dtypes = [ - ts.DType.FP16, - ts.DType.FP32, - ts.DType.INT8, - ts.DType.INT16, - ts.DType.INT32, - ts.DType.BOOL, - ] - self._add_node_to_tosa_graph( - node, tosa_graph, inputs, output, mi_supported_dtypes + attr, ) diff --git a/backends/arm/operators/op_while.py b/backends/arm/operators/op_while.py new file mode 100644 index 00000000000..b4ac4f4f6f1 --- /dev/null +++ b/backends/arm/operators/op_while.py @@ -0,0 +1,103 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, cast, List + +import tosa_serializer as ts +from executorch.backends.arm._passes.arm_pass_utils import get_output_dim_orders + +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_cf_extension, + validate_num_inputs, +) +from executorch.backends.arm.tosa.mapping import map_dtype, TosaArg +from executorch.backends.arm.tosa.utils import tosa_shape + +from torch.fx import Node + + +@register_node_visitor +class WhileLoopVisitor(NodeVisitor): + target = "while_loop" + + tosa_specs = NodeVisitor.tosa_specs + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + validate_num_inputs(self.target, inputs, 4) + validate_cf_extension(self.target, self.tosa_spec) + + carried_inputs = inputs[2].special if hasattr(inputs[2], "special") else None + if carried_inputs is None: + raise ValueError(f"{self.target}: Expected loop input arguments to be set.") + + additional_inputs = inputs[3].special if hasattr(inputs[3], "special") else None + if additional_inputs: + raise ValueError( + "Additional inputs is not supported, use carried inputs instead." + ) + + attr = ts.TosaSerializerAttribute() + cond_graph, body_graph = (str(cast(Node, arg).target) for arg in node.args[:2]) + attr.WhileLoopAttribute(cond_graph, body_graph) + + input_names: list[str] = [] + for loop_input in carried_inputs: + if not isinstance(loop_input, Node): + raise ValueError( + f"{self.target}: Unsupported carried input type {type(loop_input)}." + ) + input_names.append(loop_input.name) + + num_inputs = len(input_names) + num_outputs = len(output.multiple_output_names) + if num_inputs > num_outputs: + # If we have more inputs than outputs, we can just add missing output tensors. 
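+            # For each output that is still missing a tensor, look up the matching
+            # body-graph output, create a dummy tensor with its shape, dtype and
+            # dim order, and append the new name so WHILE_LOOP ends up with
+            # matching input and output counts.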
+ body_module = getattr(node.graph.owning_module, body_graph) + output_dim_orders = get_output_dim_orders(body_module) + body_outputs = body_module.graph.output_node().args[0] + outputs_needing_tensors = body_outputs[num_outputs - num_inputs :] + output_dim_orders = output_dim_orders[num_outputs - num_inputs :] + for ( + output_needing_tensor, + dim_order, + ) in zip(outputs_needing_tensors, output_dim_orders, strict=True): + tensor_name = output_needing_tensor.name + "_dummy" + shape = output_needing_tensor.meta["val"].shape + dtype = map_dtype( + output_needing_tensor.meta["val"].dtype, self.tosa_spec + ) + + tosa_graph.currRegion.currBasicBlock.addTensor( + tensor_name, + tosa_shape(shape, dim_order), + dtype, + ) + output.multiple_output_names.append(tensor_name) + elif num_inputs < num_outputs: + # This is a strange case, if we reach it something bad has happened. + raise ValueError( + f"TOSA specifies that the number of inputs, {input_names}, need to be the " + f"same as the number of outputs, {output.multiple_output_names}." + ) + + self._serialize_operator( + node, + tosa_graph, + ts.Op.WHILE_LOOP, + input_names, + output.multiple_output_names, + attr, + ) diff --git a/backends/arm/operators/operator_validation_utils.py b/backends/arm/operators/operator_validation_utils.py index cc8317497b8..20ee10534d0 100644 --- a/backends/arm/operators/operator_validation_utils.py +++ b/backends/arm/operators/operator_validation_utils.py @@ -2,48 +2,44 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Provide validation helpers for operator inputs and dtypes. + +Use these utilities to validate input counts, ensure dtype consistency, check +allowed dtypes, and compute pooling padding adjustments. + +""" from math import ceil, floor from typing import Any, List, Optional -import serializer.tosa_serializer as ts +from executorch.backends.arm.tosa.specification import Tosa_1_00, TosaSpecification def validate_num_inputs(op_name: str, inputs: List[Any], expected: int | List[int]): - """ - Validates the number of inputs provided to an operation against expected values. - - This function checks whether the length of the input list matches the expected - number(s) of inputs. + """Validate the number of inputs against expected values. - Parameters: - ----------- - op_name : str - The name of the operation for which the inputs are being validated. - Used in the error message to provide context. + This function checks whether the length of the input list matches the + expected number(s) of inputs. - inputs : List[TosaArg] - A list of inputs to be validated, where each input is assumed to be an - instance of `TosaArg`. - - expected : int or List[int] - The expected number of inputs. Can be either an integer or a list of integers. + Args: + op_name (str): The name of the operation for which the inputs are being + validated. Used in the error message to provide context. + inputs (List[TosaArg]): A list of inputs to be validated, where each + input is assumed to be an instance of ``TosaArg``. + expected (int | List[int]): The expected number of inputs. Can be either + an integer or a list of integers. Raises: - ------- - ValueError - If the number of inputs does not match the expected value(s), a `ValueError` is - raised with a message indicating the operation name and the mismatch in expected - versus provided number of inputs. 
+ ValueError: If the number of inputs does not match the expected + value(s); the message indicates the operation name and the mismatch + in expected versus provided counts. Example: - -------- - # Example usage: - from executorch.backends.arm.operators.operator_validation_utils import ( - validate_num_inputs, - ) + from executorch.backends.arm.operators.operator_validation_utils import \ + validate_num_inputs + + validate_num_inputs(self.target, inputs, [3, 4]) - validate_num_inputs(self.target, inputs, [3, 4]) """ if isinstance(expected, int): expected = [expected] @@ -56,39 +52,28 @@ def validate_num_inputs(op_name: str, inputs: List[Any], expected: int | List[in def validate_same_dtype(op_name: str, tensors: List[Any], ts: Optional[Any] = None): - """ - Validates that all given tensors have the same dtype attribute. + """Validate that all given tensors have the same dtype. - This function checks whether all items in the `tensors` list have the same - `dtype` as the first item. + This function checks whether all items in the ``tensors`` list have the + same ``dtype`` as the first item. - Parameters: - ----------- - op_name : str - The name of the operation for which the dtype validation is being performed. - Used in the error message to provide context. - - tensors : List[Any] - A list of tensors to be validated, each is assumed to have a `dtype` attribute. - - ts: Optional[Any] - TOSA serializer. Not required but only to get clearer error messages. + Args: + op_name (str): The name of the operation for which the dtype validation + is being performed. Used in the error message to provide context. + tensors (List[Any]): A list of tensors to be validated, each assumed to + have a ``dtype`` attribute. + ts (Optional[Any]): TOSA serializer (optional) to improve readability of + dtype names in error messages. Raises: - ------- - ValueError - If the dtype of any item in the list does not match the dtype of the first item, - a `ValueError` is raised with a message indicating the operation name and the - mismatch in dtypes. + ValueError: If the dtype of any item in the list does not match the + dtype of the first item, or if the list is empty. Example: - -------- - # Example usage: - from executorch.backends.arm.operators.operator_validation_utils import ( - validate_same_dtype, - ) + from executorch.backends.arm.operators.operator_validation_utils import \ + validate_same_dtype - validate_same_dtype(self.target, [input1, input2, output]) + validate_same_dtype(self.target, [input1, input2, output]) """ if not tensors: @@ -98,67 +83,54 @@ def validate_same_dtype(op_name: str, tensors: List[Any], ts: Optional[Any] = No # Get dtype of the first tensor to reference for comparison reference_dtype = tensors[0].dtype + reference_dtype_name = str(reference_dtype) for tensor in tensors: - ref_dtype_name = ( - ts.DTypeNames[reference_dtype] if ts is not None else str(reference_dtype) - ) - inconsistent_dtype_name = ( - ts.DTypeNames[tensor.dtype] if ts is not None else str(tensor.dtype) - ) if tensor.dtype != reference_dtype: + inconsistent_dtype_name = str(tensor.dtype) raise ValueError( - f"{op_name}: Expected all tensors to have dtype {ref_dtype_name}, but " - f"found inconsistent dtype {inconsistent_dtype_name}." + f"{op_name}: Expected all tensors to have dtype {reference_dtype_name}, " + f"but found inconsistent dtype {inconsistent_dtype_name}." 
) def validate_valid_dtype( op_name: str, tensors: Any | List[Any], valid_dtypes: Any | List[Any], tosa_spec ): - """ - Validates that one or more tensors have dtypes within a set of allowed dtypes. - - This function checks whether the `dtype` attribute of the provided tensor(s) is one - of the valid dtype values. It supports checking a single tensor or a list of - tensors. - - Parameters: - ----------- - op_name : str - The name of the operation performing the validation. - tensors : Any or List[Any] - A tensor or list of tensors (each assumed to have `dtype` and `name` attributes) - whose dtype will be validated. - valid_dtypes : Any or List[Any] - A dtype enum or list of dtype enums representing allowed dtype values. - tosa_spec : Any - A TosaSpecification instance indicating which TOSA version is targeted. This - determines which serializer to use for dtype name resolution. + """Validate that one or more tensors have allowed dtypes. + + This function checks whether the ``dtype`` attribute of the provided + tensor(s) is one of the valid dtype values. It supports checking a single + tensor or a list of tensors. + + Args: + op_name (str): The name of the operation performing the validation. + tensors (Any | List[Any]): A tensor or list of tensors (each assumed to + have ``dtype`` and ``name`` attributes) whose dtype will be + validated. + valid_dtypes (Any | List[Any]): A dtype enum or list of dtype enums + representing allowed dtype values. + tosa_spec (Any): A TosaSpecification instance indicating which TOSA + version is targeted. This determines which serializer to use for + dtype name resolution. Raises: - ------- - ValueError - If no tensors are provided, or if any tensor has a dtype not in `valid_dtypes`. + ValueError: If no tensors are provided, or if any tensor has a dtype not + in ``valid_dtypes``. Example: - -------- - # Example usage: - from executorch.backends.arm.operators.operator_validation_utils import ( - validate_valid_dtype, - ) - - import serializer.tosa_serializer as ts - - validate_valid_dtype( - self.target, - [*inputs, output], - [ts.DType.INT8, ts.DType.INT32], - output.tosa_spec, - ) + from executorch.backends.arm.operators.operator_validation_utils import \ + validate_valid_dtype + import serializer.tosa_serializer as ts + + validate_valid_dtype( + self.target, + [*inputs, output], + [ts.DType.INT8, ts.DType.INT32], + output.tosa_spec, + ) """ - if not tensors: raise ValueError( f"{op_name}: Input tensor list is empty, cannot validate dtypes" @@ -172,46 +144,51 @@ def validate_valid_dtype( for tensor in tensors: if tensor.dtype not in valid_dtypes: + valid_names = [str(dtype) for dtype in valid_dtypes] + got_name = str(tensor.dtype) raise ValueError( f"Expected tensor {tensor.name} in {op_name} to have one of the " - f"following dtypes: {[ts.DTypeNames[i] for i in valid_dtypes]}, " - f"got: {ts.DTypeNames[tensor.dtype]}" + f"following dtypes: {valid_names}, got: {got_name}" ) +def validate_cf_extension(op_name: str, tosa_spec: TosaSpecification) -> None: + """Ensure that the requested control-flow operator is supported by the active TOSA spec.""" + if not isinstance(tosa_spec, Tosa_1_00): + raise ValueError( + f"Got TOSA version {tosa_spec.version}, that does not support extensions." + ) + if not tosa_spec.support_extension("cf"): + raise ValueError( + f"Trying to lower {op_name}, but TOSA specification {tosa_spec} does not " + "support the cf extension." 
+ ) + + def adjust_pooling_pad_if_needed( input_size: int, kernel_size: int, stride: int, pad: int, ceil_mode: bool ) -> int: - """ - The Aten pooling ops has one value 'pad' per dimension to specify padding, but they - do not require input and output sizes to match up perfectly. Instead, the output - size is rounded up or down depending on ceil_mode, and padding at the end of the - input is automatically added or removed. TOSA on the other hand specifies two - padding values, one for pre-padding and one for post-padding, and these must satisfy - - output_size = (input_size + pre_pad + post_pad - kernel_size) / stride + 1 + """Compute the post padding needed for pooling. - This function returns the post_pad value required to satisfy the above condition. + ATen pooling uses a single symmetric ``pad`` per dimension and rounds the + output size up or down depending on ``ceil_mode``. TOSA requires distinct + pre- and post-padding values that satisfy: - Parameters: - ----------- - input_size : int - The size of the input to the operator. + output_size == (input_size + pre_pad + post_pad - kernel_size) / stride + 1 - kernel_size : int - The size of the kernel. + This function returns the required ``post_pad`` given a symmetric ``pad``. - stride : int - The size of the stride. + Args: + input_size (int): Input size. + kernel_size (int): Kernel size. + stride (int): Stride size. + pad (int): Symmetric padding specified by ATen. + ceil_mode (bool): Use ceil when computing output size. - pad : int - The amount of padding. + Returns: + int: Post-padding to satisfy the TOSA formula. - Output: - ------- - An int, giving the post-padding to use for the """ - if ceil_mode: output_size = ceil((input_size - kernel_size + 2 * pad) / stride) + 1 else: diff --git a/backends/arm/operators/ops_binary.py b/backends/arm/operators/ops_binary.py index 4e8e393732b..3e8cda76b5a 100644 --- a/backends/arm/operators/ops_binary.py +++ b/backends/arm/operators/ops_binary.py @@ -3,13 +3,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe -from typing import Any, List +from typing import Any, Callable, List import torch import torch.fx +import tosa_serializer as ts + from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -22,7 +23,9 @@ from executorch.backends.arm.tosa.mapping import TosaArg -def binary_operator_factory(bw_target: str, tosa_op): +def binary_operator_factory( + bw_target: str, tosa_op, attr_builder: Callable[[Any], None] +): """Creates and registers NodeVisitors for operators that have two inputs and map directly to a TOSA op.""" class BinaryOperator(NodeVisitor): @@ -36,8 +39,6 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts # type: ignore # noqa: F401 - validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [*inputs, output], ts) @@ -64,26 +65,48 @@ def define_node( [ts.DType.BOOL], output.tosa_spec, ) - + attr = ts.TosaSerializerAttribute() + attr_builder(attr) self._serialize_operator( node, tosa_graph, tosa_op, [inputs[0].name, inputs[1].name], [output.name], + attr, ) register_node_visitor(BinaryOperator) -import serializer.tosa_serializer as ts # type: ignore - -binary_operator_factory("aten.bitwise_and.Tensor", ts.TosaOp.Op().BITWISE_AND) -binary_operator_factory("aten.bitwise_xor.Tensor", ts.TosaOp.Op().BITWISE_XOR) -binary_operator_factory("aten.bitwise_or.Tensor", ts.TosaOp.Op().BITWISE_OR) -binary_operator_factory("aten.logical_and.default", ts.TosaOp.Op().LOGICAL_AND) -binary_operator_factory("aten.logical_xor.default", ts.TosaOp.Op().LOGICAL_XOR) -binary_operator_factory("aten.logical_or.default", ts.TosaOp.Op().LOGICAL_OR) binary_operator_factory( - "aten.bitwise_left_shift.Tensor", ts.TosaOp.Op().LOGICAL_LEFT_SHIFT + "aten.bitwise_and.Tensor", + ts.Op.BITWISE_AND, + lambda attr: attr.BitwiseAndAttribute(), +) +binary_operator_factory( + "aten.bitwise_xor.Tensor", + ts.Op.BITWISE_XOR, + lambda attr: attr.BitwiseXorAttribute(), +) +binary_operator_factory( + "aten.bitwise_or.Tensor", ts.Op.BITWISE_OR, lambda attr: attr.BitwiseOrAttribute() +) +binary_operator_factory( + "aten.logical_and.default", + ts.Op.LOGICAL_AND, + lambda attr: attr.LogicalAndAttribute(), +) +binary_operator_factory( + "aten.logical_xor.default", + ts.Op.LOGICAL_XOR, + lambda attr: attr.LogicalXorAttribute(), +) +binary_operator_factory( + "aten.logical_or.default", ts.Op.LOGICAL_OR, lambda attr: attr.LogicalOrAttribute() +) +binary_operator_factory( + "aten.bitwise_left_shift.Tensor", + ts.Op.LOGICAL_LEFT_SHIFT, + lambda attr: attr.LogicalLeftShiftAttribute(), ) diff --git a/backends/arm/operators/ops_identity.py b/backends/arm/operators/ops_identity.py index 62a307f3012..a7ffd4eacca 100644 --- a/backends/arm/operators/ops_identity.py +++ b/backends/arm/operators/ops_identity.py @@ -3,13 +3,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe from typing import Any, List import torch import torch.fx +import tosa_serializer as ts + from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -17,6 +18,7 @@ from executorch.backends.arm.operators.operator_validation_utils import ( validate_num_inputs, validate_same_dtype, + validate_valid_dtype, ) from executorch.backends.arm.tosa.mapping import TosaArg @@ -39,22 +41,38 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts - validate_num_inputs(self.target, inputs, 1) - validate_same_dtype(self.target, [*inputs, output], ts) + validate_same_dtype(self.target, [inputs[0], output], ts) + supported_dtypes = [ + ts.DType.BOOL, + ts.DType.INT8, + ts.DType.INT16, + ts.DType.INT32, + ] + if output.tosa_spec.support_float(): + supported_dtypes += [ts.DType.FP32] + if self.tosa_spec.support_extension("int16"): + supported_dtypes += [ts.DType.INT48] + validate_valid_dtype( + self.target, + [inputs[0], output], + supported_dtypes, + output.tosa_spec, + ) # Simply add an identityOp + attr = ts.TosaSerializerAttribute() + attr.IdentityAttribute() self._serialize_operator( node, tosa_graph, - ts.TosaOp.Op().IDENTITY, + ts.Op.IDENTITY, [inputs[0].name], [output.name], + attr, ) register_node_visitor(IdentityOperatorVisitor) -identity_operator_factory("getitem") identity_operator_factory("aten.alias_copy.default") diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 9ca435c60c5..5a1d563ee0b 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -4,15 +4,15 @@ # LICENSE file in the root directory of this source tree. # -# pyre-unsafe +import operator from typing import Any, cast, Dict import numpy as np -import serializer.tosa_serializer as ts import torch import torch.fx +import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import NodeVisitor -from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.mapping import TosaArg, TosaSpecialDtype from executorch.backends.arm.tosa.specification import TosaSpecification from executorch.backends.arm.tosa.utils import tosa_shape from torch._export.utils import ( @@ -46,14 +46,18 @@ def process_call_function( f"Failed processing call_function: {node.name}. " "Is the original torch function supported?" ) from e - tosa_graph.currRegion.currBasicBlock.addTensor( - output.name, tosa_shape(output.shape, output.dim_order), output.dtype - ) + + if not output.multiple_output_names: + tosa_graph.currRegion.currBasicBlock.addTensor( + output.name, tosa_shape(output.shape, output.dim_order), output.dtype + ) + + # Get item nodes just add tensors, no node visitor is needed. + if node.target == operator.getitem: + return # Visiting each Node - # pyre-ignore[16]: Undefined attribute. if node.target.__name__ in node_visitors: # type: ignore[union-attr] - # pyre-ignore[16]: Undefined attribute. node_visitors[node.target.__name__].define_node( # type: ignore[union-attr] node, tosa_graph, @@ -70,13 +74,6 @@ def process_inputs( tosa_spec: TosaSpecification, ): """Serialize an input node""" - # inputs need to be in default dim_order (contiguous memory format) - meta = node.meta["val"] - if meta.dim_order() != tuple(range(meta.dim())): - raise RuntimeError( - f"Arm backend only supports contiguous memory format for inputs. 
" - f"Expected dim_order: {tuple(range(meta.dim()))}, but got: {meta.dim_order()} for node {node.name}" - ) try: tosa_arg = TosaArg(node, tosa_spec) except ValueError as e: @@ -92,7 +89,6 @@ def process_inputs( tosa_shape(input_shape, input_dim_order), tosa_arg.dtype, data=None, - placeholderFilename=tosa_arg.name + ".npy", ) tosa_graph.addInputTensor(tensor) @@ -113,16 +109,28 @@ def process_inputs_to_parameters( ) from e parameter_data = get_param(edge_program, node) - assert isinstance(parameter_data, torch.Tensor), "Expect Attr to be tensor" + if not isinstance(parameter_data, torch.Tensor): + raise TypeError( + f"Expected parameter '{node.name}' to be a torch.Tensor, got " + f"{type(parameter_data).__name__}" + ) parameter_values = parameter_data.detach().numpy() if tosa_arg.dtype == torch.float32: - assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float" + if not tosa_spec.support_float(): + raise ValueError(f"{tosa_spec} doesn't support float operations") + + # Handle special case for INT48 tensors + special_type = node.meta.get(TosaSpecialDtype.meta_key(), None) + if isinstance(special_type, TosaSpecialDtype): + tosa_dtype = special_type.get_tosa_dtype() + else: + tosa_dtype = tosa_arg.dtype parameter_values = np.transpose(parameter_values, tosa_arg.dim_order) tosa_graph.addConst( - parameter_values.shape, tosa_arg.dtype, parameter_values, name=tosa_arg.name + parameter_values.shape, tosa_dtype, parameter_values, name=tosa_arg.name ) @@ -142,7 +150,11 @@ def process_inputs_to_buffers( ) from e buffer_data = get_buffer(edge_program, node) - assert isinstance(buffer_data, torch.Tensor), "Expect Attr to be tensor" + if not isinstance(buffer_data, torch.Tensor): + raise TypeError( + f"Expected buffer '{node.name}' to be a torch.Tensor, got " + f"{type(buffer_data).__name__}" + ) buffer_values = buffer_data.detach().numpy() # TODO: fragile code for temporary fix @@ -151,7 +163,7 @@ def process_inputs_to_buffers( buffer_values = np.transpose(buffer_values, tosa_arg.dim_order) tosa_graph.addConst( - buffer_values.shape, tosa_arg.dtype, buffer_values, name=node.name + buffer_values.shape, tosa_arg.dtype, buffer_values, name=tosa_arg.name ) @@ -170,24 +182,41 @@ def process_inputs_to_lifted_tensor_constants( ) from e tensor = get_lifted_tensor_constant(edge_program, node) tensor_data = tensor.detach().numpy() # type: ignore[union-attr] + tensor_values = np.transpose(tensor_data, tosa_arg.dim_order) tosa_graph.addConst( - tensor_data.shape, tosa_arg.dtype, tensor_data, name=tosa_arg.name + tensor_values.shape, tosa_arg.dtype, tensor_values, name=tosa_arg.name ) +def _is_submodule_input( + node: torch.fx.Node, containing_graph_module: torch.fx.GraphModule +) -> bool: + """Determines whether 'node' is an input to a submodule of 'containing_graph_module'.""" + if node.op != "placeholder": + return False + return node.meta.get("is_input", False) + + def process_placeholder( node: torch.fx.Node, tosa_graph: Any, edge_program: ExportedProgram, + containing_graph_module: torch.fx.GraphModule | None, tosa_spec: TosaSpecification, ): """Wrapper for processing and serializing all types of placeholders""" - assert node.name == node.target, "Expect placeholder name and target to match" - assert 0 == len(node.args), "Can't handle default input values" + if node.name != node.target: + raise ValueError( + f"Placeholder name '{node.name}' does not match target '{node.target}'" + ) + if len(node.args) != 0: + raise ValueError(f"Placeholder '{node.name}' must not have default values") if 
node.name in edge_program.graph_signature.user_inputs: process_inputs(node, tosa_graph, tosa_spec) + elif containing_graph_module and _is_submodule_input(node, containing_graph_module): + process_inputs(node, tosa_graph, tosa_spec) elif is_param(edge_program, node): process_inputs_to_parameters(node, tosa_graph, edge_program, tosa_spec) elif is_buffer(edge_program, node): @@ -204,11 +233,9 @@ def process_placeholder( raise RuntimeError(f"Placeholder '{node.name}' is of unknown type.") -def process_output( - node: torch.fx.Node, - tosa_graph: Any, -): +def process_output(node: torch.fx.Node, tosa_graph: Any, tosa_spec: TosaSpecification): for output in cast(tuple[torch.fx.Node, ...], node.args[0]): + output_arg = TosaArg(output, tosa_spec) tosa_graph.addOutputTensor( - tosa_graph.currRegion.currBasicBlock.tensors[output.name] + tosa_graph.currRegion.currBasicBlock.tensors[output_arg.name] ) diff --git a/backends/arm/quantizer/__init__.py b/backends/arm/quantizer/__init__.py index 5cb5c834a98..2018b845353 100644 --- a/backends/arm/quantizer/__init__.py +++ b/backends/arm/quantizer/__init__.py @@ -2,11 +2,17 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Expose quantizer APIs and load optional quantized kernels. +Import the public quantizer classes and configuration helpers for Arm +backends. Attempt to load portable and quantized libraries; fall back to a +log message if unavailable. +""" from .quantization_config import QuantizationConfig # noqa # usort: skip from .arm_quantizer import ( # noqa EthosUQuantizer, + get_symmetric_a16w8_quantization_config, get_symmetric_quantization_config, TOSAQuantizer, VgfQuantizer, diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index ae7c8255428..a383f44890f 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -5,7 +5,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
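As a reference point for the API surface re-exported in backends/arm/quantizer/__init__.py above, a hedged usage sketch follows; the TOSA version string and the per-channel choice are assumptions, not taken from this patch.

# Hedged usage sketch of the public quantizer API; the spec string "TOSA-1.0+INT"
# is an assumed example, not something stated in this patch.
from executorch.backends.arm.quantizer import (
    TOSAQuantizer,
    get_symmetric_quantization_config,
)
from executorch.backends.arm.tosa import TosaSpecification

tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")  # assumed spec string
quantizer = TOSAQuantizer(tosa_spec)  # per this patch, an ArmCompileSpec is also accepted
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))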
-# pyre-unsafe # # Quantizer for Arm backend @@ -14,21 +13,18 @@ from __future__ import annotations import functools -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional import torch +from executorch.backends.arm.ethosu import EthosUCompileSpec from executorch.backends.arm.quantizer import QuantizationConfig from executorch.backends.arm.tosa import TosaSpecification -from executorch.backends.arm.tosa.specification import get_tosa_spec - -from .arm_quantizer_utils import is_annotated, mark_node_as_annotated -from .quantization_annotator import annotate_graph -from executorch.backends.arm.arm_backend import ( - is_ethosu, - is_vgf, -) # usort: skip -from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.backends.arm.common.arm_compile_spec import ( + ArmCompileSpec, +) # isort: skip +from executorch.backends.arm.vgf import VgfCompileSpec +from executorch.exir.graph_module import get_cond_while_submodules from torch.fx import GraphModule, Node from torchao.quantization.pt2e import ( @@ -41,18 +37,28 @@ PerChannelMinMaxObserver, PlaceholderObserver, ) +from torchao.quantization.pt2e.quantize_pt2e import ( + convert_pt2e, + prepare_pt2e, + prepare_qat_pt2e, +) from torchao.quantization.pt2e.quantizer import ( annotate_input_qspec_map, annotate_output_qspec, + get_module_name_filter, QuantizationSpec, Quantizer, ) +from .arm_quantizer_utils import is_annotated, mark_node_as_annotated +from .quantization_annotator import annotate_graph + __all__ = [ "TOSAQuantizer", "EthosUQuantizer", "VgfQuantizer", + "get_symmetric_a16w8_quantization_config", "get_symmetric_quantization_config", ] @@ -66,7 +72,25 @@ def get_symmetric_quantization_config( act_qmax: int = 127, weight_qmin: int = -127, weight_qmax: int = 127, -): +) -> QuantizationConfig: + """Create symmetric quantization config for activations and weights. + + Args: + is_per_channel (bool): Whether to use per-channel quantization for + weights. + is_qat (bool): Whether the configuration targets quantization aware + training. + is_dynamic (bool): Whether to generate dynamic activation observers. + act_qmin (int): Minimum activation quantization value. + act_qmax (int): Maximum activation quantization value. + weight_qmin (int): Minimum weight quantization value. + weight_qmax (int): Maximum weight quantization value. + + Returns: + QuantizationConfig: Quantization settings for activations, weights, and + bias. 
+ + """ extra_args: Dict[str, Any] = {"eps": 2**-12} if is_qat: if is_dynamic: @@ -105,15 +129,27 @@ def get_symmetric_quantization_config( # Determine the right observer/fake-quant constructor if is_qat: if is_per_channel: - weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver + weight_observer_or_fake_quant_ctr = FakeQuantize.with_args( + observer=PerChannelMinMaxObserver, + quant_min=weight_qmin, + quant_max=weight_qmax, + dtype=torch.qint8, + qscheme=torch.per_channel_symmetric, + reduce_range=False, + ch_axis=0, + **extra_args, + ) else: # Set plain fake-quant with true min/max - weight_observer_or_fake_quant_ctr = FakeQuantize + weight_observer_or_fake_quant_ctr = FakeQuantize.with_args(**extra_args) else: # PTQ: set min/max observer weight_observer_or_fake_quant_ctr = ( PerChannelMinMaxObserver if is_per_channel else MinMaxObserver ) + weight_observer_or_fake_quant_ctr = weight_observer_or_fake_quant_ctr.with_args( + **extra_args, + ) weight_quantization_spec = QuantizationSpec( dtype=torch.int8, @@ -122,9 +158,7 @@ def get_symmetric_quantization_config( qscheme=weight_qscheme, ch_axis=0, is_dynamic=False, - observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args( - **extra_args - ), + observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr, ) bias_quantization_spec = None @@ -152,24 +186,29 @@ def get_symmetric_a16w8_quantization_config( is_dynamic: bool = False, weight_qmin: int = -127, weight_qmax: int = 127, -): - """ - 16A8W quantization config: 16-bit activations, 8-bit weights. + epsilon: float = 2**-12, +) -> QuantizationConfig: + """16A8W quantization config: 16-bit activations, 8-bit weights. This configuration provides better accuracy than 8A8W while maintaining reasonable memory usage through 8-bit weights. Args: - is_per_channel: Whether to use per-channel quantization for weights - is_qat: Whether this is for Quantization Aware Training - is_dynamic: Whether to use dynamic quantization - weight_qmin: Minimum quantization value for weights - weight_qmax: Maximum quantization value for weights + is_per_channel (bool): Whether to use per-channel quantization for + weights. + is_qat (bool): Whether this is for quantization aware training. + is_dynamic (bool): Whether to use dynamic quantization. + weight_qmin (int): Minimum quantization value for weights. + weight_qmax (int): Maximum quantization value for weights. + epsilon (float): Value used to pad observed [qmin, qmax] before initial + zero-point and scale calculation. Returns: - QuantizationConfig with 16-bit activations and 8-bit weights + QuantizationConfig: Configuration with 16-bit activations and 8-bit + weights. + """ - extra_args: Dict[str, Any] = {"eps": 2**-12} + extra_args: Dict[str, Any] = {"eps": epsilon} # Setup observer/fake-quant for 16-bit activations if is_qat: @@ -191,7 +230,7 @@ def get_symmetric_a16w8_quantization_config( # 16-bit activation quantization spec act_quantization_spec = QuantizationSpec( dtype=torch.int16, - quant_min=torch.iinfo(torch.int16).min, # -32768 + quant_min=torch.iinfo(torch.int16).min + 1, # -32767 quant_max=torch.iinfo(torch.int16).max, # 32767 qscheme=torch.per_tensor_symmetric, is_dynamic=is_dynamic, @@ -226,54 +265,39 @@ def get_symmetric_a16w8_quantization_config( NodeFilterType = Callable[[Node], bool] -"""Type for a Node Filter used by annotators. A Node filter is a function that takes - a Node and returns whether the node should be annotated or not. -""" +"""Type for a Node Filter used by annotators. 
+A Node filter is a function that takes a Node and returns whether the node +should be annotated or not. -def _get_module_name_filter(module_name: str) -> NodeFilterType: - """Get the module_name_filter function for a given module name, the filter accepts - a node and checks if the node comes from a module that has certain module name - - For example: - node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1 - - >> module_name_filter = _get_module_name_filter("blocks.sub") - >> print(module_name_filter(node)) - True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1" - """ +""" - name_start = len("L['self'].") - def module_name_filter(n: Node) -> bool: - # node_stack example: { - # 'L__self___sub': ("L['self'].sub", ), - # 'L__self___sub_linear': ("L['self'].sub.linear", ) - # } - # get_attr nodes doesn't have nn_module_stack? - nn_module_stack = n.meta.get("nn_module_stack", {}) - names = [name[name_start:] for name, _ in nn_module_stack.values()] - return module_name in names +def _get_module_type_filter(tp: Callable) -> NodeFilterType: + """Get the module_type_filter function for a given module type. - return module_name_filter + The filter accepts a node and checks if the node comes from a module that + has a certain module type. + Args: + tp (Callable): Module class to match against the graph node metadata. -def _get_module_type_filter(tp: Callable) -> NodeFilterType: - """Get the module_type_filter function for a given module type, the filter accepts - a node and checks if the node comes from a module that has certain module type + Returns: + NodeFilterType: Predicate that returns True for nodes from the module + type. For example: - node: linear_op = call_function[...](...) # comes from a module with type Block -> Sub -> Linear - + node: linear_op = call_function[...](...) # type Block -> Sub -> Linear - >> module_type_filter = _get_module_type_filter(Sub) # submodule with type `Sub`, under the `Block` submodule + >> module_type_filter = _get_module_type_filter(Sub) >> print(module_type_filter(node)) - True # the node is from the submodule `Sub` (same for `Block` and `Linear` as well) - """ + True # the node is from the submodule `Sub` (same for `Block` and `Linear`) + """ tp_str = tp.__module__ + "." + tp.__qualname__ def module_type_filter(n: Node) -> bool: + """Return True if the node originates from the target module type.""" # node_stack example: { # 'L__self___sub': ("L['self'].sub", ), # 'L__self___sub_linear': ("L['self'].sub.linear", ) @@ -288,39 +312,43 @@ def module_type_filter(n: Node) -> bool: def _get_not_module_type_or_name_filter( tp_list: List[Callable], module_name_list: List[str] ) -> NodeFilterType: + """Create a filter that excludes provided module types and names. + + Args: + tp_list (List[Callable]): Module types to exclude from annotation. + module_name_list (List[str]): Module names to exclude from annotation. + + Returns: + NodeFilterType: Filter that returns True when the node does not match + any provided module type or name. 
+ + """ module_type_filters = [_get_module_type_filter(tp) for tp in tp_list] - module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list] + module_name_list_filters = [get_module_name_filter(m) for m in module_name_list] def not_module_type_or_name_filter(n: Node) -> bool: + """Return True when the node matches none of the blocked filters.""" return not any(f(n) for f in module_type_filters + module_name_list_filters) return not_module_type_or_name_filter class TOSAQuantizer(Quantizer): + """Manage quantization annotations for TOSA-compatible backends.""" def __init__( - self, compile_spec_or_tosa_spec: Union[TosaSpecification, List[CompileSpec]] + self, compile_spec_or_tosa_spec: TosaSpecification | ArmCompileSpec ) -> None: - super().__init__() + self.compile_spec: ArmCompileSpec if isinstance(compile_spec_or_tosa_spec, TosaSpecification): - self.tosa_spec = compile_spec_or_tosa_spec - self.compile_spec = None - elif isinstance(compile_spec_or_tosa_spec, list): + from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec + + self.compile_spec = TosaCompileSpec(compile_spec_or_tosa_spec) + self.tosa_spec = self.compile_spec.tosa_spec + elif isinstance(compile_spec_or_tosa_spec, ArmCompileSpec): self.compile_spec = compile_spec_or_tosa_spec - # find entry that is 'tosa_spec' - for cs in compile_spec_or_tosa_spec: - if cs.key == "tosa_spec": - spec_val = ( - cs.value.decode() if isinstance(cs.value, bytes) else cs.value - ) - self.tosa_spec = TosaSpecification.create_from_string(spec_val) - break - else: - raise ValueError( - "compile_spec list did not contain a 'tosa_spec' entry" - ) + self.tosa_spec = self.compile_spec.tosa_spec else: raise TypeError( f"TOSAQuantizer constructor expects " @@ -334,16 +362,30 @@ def __init__( self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {} def set_global(self, quantization_config: QuantizationConfig) -> TOSAQuantizer: - """Set quantization_config for submodules that are not already annotated by name or type filters.""" + """Set quantization_config for submodules not matched by other filters. + + Args: + quantization_config (QuantizationConfig): Configuration to apply to + modules that are not captured by name or type filters. + + """ self.global_config = quantization_config return self def set_module_type( self, module_type: Callable, quantization_config: QuantizationConfig ) -> TOSAQuantizer: - """Set quantization_config for a submodule with type: `module_type`, for example: - quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it will quantize all supported operator/operator - patterns in the submodule with this module type with the given `quantization_config` + """Set quantization_config for submodules with a given module type. + + For example, calling set_module_type(Sub) quantizes supported patterns + in each Sub instance with the provided quantization_config. + + Args: + module_type (Callable): Type whose submodules should use the + provided quantization configuration. + quantization_config (QuantizationConfig): Configuration to apply to + submodules of the given type. 
+ """ self.module_type_config[module_type] = quantization_config return self @@ -351,40 +393,61 @@ def set_module_type( def set_module_name( self, module_name: str, quantization_config: Optional[QuantizationConfig] ) -> TOSAQuantizer: - """Set quantization_config for a submodule with name: `module_name`, for example: - quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator - patterns in the submodule with this module name with the given `quantization_config` + """Set quantization_config for submodules with a given module name. + + For example, calling set_module_name("blocks.sub") quantizes supported + patterns for that submodule with the provided quantization_config. + + Args: + module_name (str): Fully qualified module name to configure. + quantization_config (QuantizationConfig): Configuration applied to + the named submodule. + """ # Validate that quantization_config is provided - if quantization_config is None: - raise ValueError("quantization_config == None is not supported yet") self.module_name_config[module_name] = quantization_config return self - def set_io(self, quantization_config): - """Set quantization_config for input and output nodes.""" + def set_io(self, quantization_config: QuantizationConfig) -> TOSAQuantizer: + """Set quantization_config for input and output nodes. + + Args: + quantization_config (QuantizationConfig): Configuration describing + activation quantization for model inputs and outputs. + + """ self.io_config = quantization_config return self def transform_for_annotation(self, model: GraphModule) -> GraphModule: - """An initial pass for transforming the graph to prepare it for annotation. + """Transform the graph to prepare it for quantization annotation. + Currently transforms scalar values to tensor attributes. - """ + Args: + model (GraphModule): Model whose graph will be transformed. + + Returns: + GraphModule: Transformed model prepared for annotation. + + """ # TODO: Fix the need to lazily import this. from executorch.backends.arm._passes import ArmPassManager - return ArmPassManager(self.tosa_spec).transform_for_annotation_pipeline( # type: ignore[arg-type] - graph_module=model - ) + pass_manager = ArmPassManager(self.compile_spec) + return pass_manager.transform_for_annotation_pipeline(graph_module=model) def annotate(self, model: GraphModule) -> GraphModule: - """Performs the quantization annotation on the graph. - Currently only does static quantization annotation. + """Annotate the graph with the configured quantization settings. + + Currently only does static quantization annotation. + Args: - model: The model to annotate statically. + model (GraphModule): Model to annotate statically. + Returns: - The annotated model. + GraphModule: Annotated model ready for export. + """ model = self._annotate_for_static_quantization_config(model) return model @@ -395,14 +458,19 @@ def _annotate_all_static_patterns( quantization_config: Optional[QuantizationConfig], filter_fn: Optional[Callable[[Node], bool]] = None, ) -> GraphModule: - """Loops over all STATIC_OPS and runs the corresponding registered annotator. + """Annotate all static patterns registered for the backend. + Args: - model: The model to annotate statically. - quantization_config: Specifies the QuantizationSpecs for the model's - input activations, output activations, weights and biases. - filter_fn: An optional filter function that takes a node and returns whether the node should be annotated. + model (GraphModule): Model to annotate statically. 
+ quantization_config (Optional[QuantizationConfig]): Quantization + specs for input activations, output activations, weights, and + biases. + filter_fn (Optional[Callable[[Node], bool]]): Optional node filter + specifying which nodes to annotate. + Returns: - The annotated model. + GraphModule: Model populated with quantization annotations. + """ # TODO: implement the support for None to be canceling out previous annotations if quantization_config is None: @@ -414,8 +482,15 @@ def _annotate_for_static_quantization_config( self, model: GraphModule ) -> GraphModule: - """Matches the correct QuantizationConfig with the correct module using a filter - when running _annotate_all_static_patterns. + """Match QuantizationConfigs to modules before annotating patterns. + + Args: + model (GraphModule): Model whose modules are being matched to + quantization configs. + + Returns: + GraphModule: Annotated model after applying configured filters. + + """ if self.io_config: self._annotate_io(model, self.io_config) @@ -423,7 +498,7 @@ def _annotate_for_static_quantization_config( module_name_list = list(self.module_name_config.keys()) for module_name, config in self.module_name_config.items(): self._annotate_all_static_patterns( - model, config, _get_module_name_filter(module_name) + model, config, get_module_name_filter(module_name) ) tp_list = list(self.module_type_config.keys()) @@ -445,6 +520,14 @@ def _annotate_io( model: GraphModule, quantization_config: QuantizationConfig, ): + """Annotate graph inputs and outputs with the provided configuration. + + Args: + model (GraphModule): GraphModule being annotated. + quantization_config (QuantizationConfig): Activation qspecs to apply + to IO nodes. + + """ for node in model.graph.nodes: if is_annotated(node): continue @@ -455,29 +538,71 @@ def _annotate_io( ) mark_node_as_annotated(node) if node.op == "output": - parent = node.all_input_nodes[0] - annotate_input_qspec_map( - node, parent, quantization_config.get_input_act_qspec() - ) + for parent in node.all_input_nodes: + annotate_input_qspec_map( + node, parent, quantization_config.get_input_act_qspec() + ) mark_node_as_annotated(node) def validate(self, model: GraphModule) -> None: + """TODO: Implement validation of annotated graph for TOSA backend.""" pass + def quantize_with_submodules( + self, + model: GraphModule, + calibration_samples: list[tuple], + is_qat: bool = False, + ): + """Quantizes a GraphModule in a way such that conditional submodules are handled properly. + + Args: + model (GraphModule): The model to quantize. + calibration_samples (list[tuple]): A list of inputs used to + calibrate the model during quantization. To properly calibrate a + model with submodules, at least one sample per code path is + needed. + is_qat (bool): Whether to do quantization aware training or not. + + Returns: + GraphModule: The quantized model. 
+ + """ + prepare_fn = prepare_qat_pt2e if is_qat else prepare_pt2e + + prepared = prepare_fn(model, self) + for name, submodule, _ in get_cond_while_submodules(prepared): + prepared.set_submodule(name, prepare_fn(submodule, self), strict=True) + for inp in calibration_samples: + prepared(*inp) + + for name, submodule, _ in get_cond_while_submodules(prepared): + prepared.set_submodule(name, convert_pt2e(submodule), strict=True) + converted = convert_pt2e(prepared) + return converted + class EthosUQuantizer(TOSAQuantizer): - def __init__(self, compile_spec: list[CompileSpec]) -> None: - if not is_ethosu(compile_spec): - raise RuntimeError("compile spec is not targeting Ethos-U") + """Quantizer supported by the Arm Ethos-U backend. - tosa_spec = get_tosa_spec(compile_spec) - super().__init__(tosa_spec) + Args: + compile_spec (EthosUCompileSpec): Backend compile specification for + Ethos-U targets. + + """ + + def __init__(self, compile_spec: EthosUCompileSpec) -> None: + super().__init__(compile_spec) class VgfQuantizer(TOSAQuantizer): - def __init__(self, compile_spec: list[CompileSpec]) -> None: - if not is_vgf(compile_spec): - raise RuntimeError("compile spec is not targeting VGF") + """Quantizer supported by the Arm Vgf backend. + + Args: + compile_spec (VgfCompileSpec): Backend compile specification for Vgf + targets. + + """ - tosa_spec = get_tosa_spec(compile_spec) - super().__init__(tosa_spec) + def __init__(self, compile_spec: VgfCompileSpec) -> None: + super().__init__(compile_spec) diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py index 838dd44733e..7bd8e00c22b 100644 --- a/backends/arm/quantizer/arm_quantizer_utils.py +++ b/backends/arm/quantizer/arm_quantizer_utils.py @@ -1,18 +1,21 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024-2025 Arm Limited and/or its affiliates. # All rights reserved. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +"""Provide utilities for quantization annotations. -# -# Utility functions for TOSAQuantizer -# +Use these helpers to check and mark annotation state when working with +``QuantizationAnnotation`` entries in FX node metadata. + +""" from typing import cast +from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo + from torch.fx import Node from torchao.quantization.pt2e.quantizer import QuantizationAnnotation @@ -20,7 +23,15 @@ def is_annotated(node: Node) -> bool: - """Given a node return whether the node is annotated.""" + """Return True if the node is annotated. + + Args: + node (Node): FX node to inspect. + + Returns: + bool: True if ``Q_ANNOTATION_KEY`` exists and ``_annotated`` is set. + + """ return ( Q_ANNOTATION_KEY in node.meta and cast(QuantizationAnnotation, node.meta[Q_ANNOTATION_KEY])._annotated @@ -28,7 +39,15 @@ def is_annotated(node: Node) -> bool: def is_output_annotated(node: Node) -> bool: - """Given a node, return whether the output of the node is annotated.""" + """Return True if the node's output is annotated. + + Args: + node (Node): FX node to inspect. + + Returns: + bool: True if annotated and an output qspec is present. 
+ + """ if Q_ANNOTATION_KEY in node.meta: annotation = cast(QuantizationAnnotation, node.meta[Q_ANNOTATION_KEY]) return annotation._annotated and annotation.output_qspec is not None @@ -37,9 +56,21 @@ def is_output_annotated(node: Node) -> bool: def mark_node_as_annotated(node: Node) -> None: - """Marks node as annotated. If needed, an empty QuantizationAnnotation is added - to the quantization_annotation node meta entry. + """Mark a node as annotated. + + Create an empty ``QuantizationAnnotation`` on the node when missing and set + its ``_annotated`` flag to True. + + Args: + node (Node): FX node to update. + """ if Q_ANNOTATION_KEY not in node.meta: node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation() + annotation_info = ArmAnnotationInfo( + quantized=True, + ) node.meta[Q_ANNOTATION_KEY]._annotated = True + meta_custom = node.meta.get("custom", {}) + meta_custom[ArmAnnotationInfo.CUSTOM_META_KEY] = dict(annotation_info) + node.meta["custom"] = meta_custom diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index ff1ad50e517..60f739c09ad 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -2,16 +2,22 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Provide quantization annotation logic for Arm backends. + +This module computes per-node quantization properties and applies input/output +annotations to FX graphs using TorchAO qspecs. + +""" import logging import operator from dataclasses import dataclass -from typing import Callable, List, Optional, Sequence +from typing import Callable, cast, List, Optional, Sequence import torch import torch.fx -import torch.nn.functional as F from executorch.backends.arm.common.debug import get_node_debug_info +from executorch.backends.arm.common.type import ensure_type from executorch.backends.arm.quantizer import QuantizationConfig from torch._subclasses import FakeTensor @@ -37,19 +43,38 @@ class _QuantProperty: """Specify how the input/output at 'index' must be quantized.""" index: int - qspec: type[QuantizationSpecBase] | List[type[QuantizationSpecBase]] + qspec: QuantizationSpecBase | List[QuantizationSpecBase] optional: bool = False mark_annotated: bool = False class _OpQuantProperties: + """Collect input/output quantization properties for a node. + + Attributes: + quant_inputs (List[_QuantProperty]): Quantization specs for inputs + indexed by argument positions. + quant_output (Optional[_QuantProperty]): Quantization spec for the + node's output when applicable. + + """ + def __init__(self): self.quant_inputs: List[_QuantProperty] = [] self.quant_output: Optional[_QuantProperty] = None def _as_list(x): - if isinstance(x, list): + """Return ``x`` wrapped as a list if needed. + + Args: + x: Value or list of values. + + Returns: + list: ``x`` if already a list; otherwise ``[x]``. + + """ + if isinstance(x, (list, tuple)): return x else: return [ @@ -65,16 +90,19 @@ def _is_ok_for_quantization( A node can be quantized if: - All inputs that are required for quantization are of type `float32` and are not large scalar values. - - The output of the node itself is of type `float32` and is not a large scalar. + - The output of the node itself is of type `float32` and is not a large + scalar. Args: node (Node): The node being analyzed. 
- quant_properties (_OpQuantProperties): Contains quantization properties for - the node, including input and output quantization specifications. - gm (torch.fx.GraphModule): The graph module containing the computational graph. + quant_properties (_OpQuantProperties): Contains quantization properties + for the node, including input and output quantization specifications. + gm (torch.fx.GraphModule): The graph module containing the computational + graph. Returns: bool: `True` if the node can be quantized, otherwise `False`. + """ # Check output if quant_properties.quant_output is not None: @@ -126,6 +154,16 @@ def _is_ok_for_quantization( def _get_node_target(module: torch.nn.Module | torch.fx.GraphModule, target_str: str): + """Get an attribute from a module by dotted path. + + Args: + module (torch.nn.Module | torch.fx.GraphModule): Root module. + target_str (str): Dotted attribute path, e.g., ``"sub.weight"``. + + Returns: + Any: Resolved attribute on the module. + + """ targets = target_str.split(".") for target in targets[:-1]: module = module.get_submodule(target) @@ -133,15 +171,24 @@ def _get_node_target(module: torch.nn.Module | torch.fx.GraphModule, target_str: def _is_large_scalar(node: Node, gm: torch.fx.GraphModule): - """Check if input is a large scalar value. So that we can skip quantization for the - node since histc op (in HistogramObserver) only works for values up to certain upper - bound. + """Return True if input is a large scalar value. + + Large scalars are skipped because ``torch.histc`` supports values only up + to a certain upper bound. + """ + HISTC_UPPER_BOUND = 3.4028235e15 if node.op == "get_attr" and isinstance(node.target, str): tensor = _get_node_target(gm, node.target) # torch.histc works until this upper bound - HISTC_UPPER_BOUND = 3.4028235e15 return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND + if node.op == "call_function" and node.target in ( + torch.ops.aten.full.default, + torch.ops.aten.full, + torch.ops.aten.fill_.Scalar, + ): + fill_value = cast(float, node.args[1]) + return abs(fill_value) > HISTC_UPPER_BOUND return False @@ -158,11 +205,12 @@ def _is_non_float_tensor(node: Node) -> bool: bool: `True` if the data type is not float32, otherwise `False`. Note: - - If `node.meta["val"]` is a `list`, the function returns `True` if **any** - element is **not** an instance of `FakeTensor` or does **not** have + - If `node.meta["val"]` is a `list`, the function returns `True` if + any element is not an instance of `FakeTensor` or does not have `torch.float32` as its data type. - - If node.meta["val"] is missing or is not an instance of `FakeTensor`, the - function returns True. + - If node.meta["val"] is missing or is not an instance of `FakeTensor`, + the function returns True. + """ if "val" in node.meta and isinstance(node.meta["val"], Sequence): return any( @@ -178,6 +226,20 @@ def _is_non_float_tensor(node: Node) -> bool: def _annotate_input(node: Node, quant_property: _QuantProperty): + """Annotate a node's input with the given qspec. + + Maps the specified input argument(s) to the provided quantization spec and + optionally marks the input node(s) as annotated. + + Args: + node (Node): Node whose input should be annotated. + quant_property (_QuantProperty): Input index and qspec(s). + + Raises: + RuntimeError: If the node is already annotated. + TypeError: If an input argument is not a ``Node`` instance. 
+ + """ if is_annotated(node): raise RuntimeError( f"Cannot annotate input: node '{node.name}' is already annotated" @@ -203,6 +265,18 @@ def _annotate_input(node: Node, quant_property: _QuantProperty): def _annotate_output(node: Node, quant_property: _QuantProperty): + """Annotate a node's output with the given qspec. + + Args: + node (Node): Node whose output should be annotated. + quant_property (_QuantProperty): Output index and qspec. + + Raises: + RuntimeError: If the node is already annotated. + ValueError: If ``mark_annotated`` is True, ``optional`` is True, or + ``index`` is not zero. + + """ if is_annotated(node): raise RuntimeError( f"Cannot annotate output: node '{node.name}' is already annotated" @@ -222,12 +296,13 @@ def _annotate_output(node: Node, quant_property: _QuantProperty): def _match_pattern( node: Node, pattern: List[List], filter_fn: Optional[Callable[[Node], bool]] = None ) -> bool: - """ - Check if there's a chain of node.ancestors? -> node -> node.descendant? that matches the - chain provided in 'pattern'. If 'filter_fn' is provided, check that all the nodes in the - chain pass the filtering. + """Check whether a node chain matches a pattern. + + Verify a chain of ancestors -> node -> descendants matches the provided + ``pattern``. If ``filter_fn`` is provided, require all nodes in the chain + to pass the filter. Each pattern element is a list of disjunctive node + targets. - Each 'pattern' element is composed of a list of disjunctive nodes types. """ if len(pattern) < 1: raise ValueError("No pattern provided") @@ -260,6 +335,14 @@ def _match_pattern( return left_condition and right_condition +_conv_ops = [ + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.conv2d.padding, + torch.ops.aten.conv3d.default, + torch.ops.aten.conv3d.padding, +] + _one_to_one = [ torch.ops.aten.abs.default, torch.ops.aten.ceil.default, @@ -276,10 +359,12 @@ def _match_pattern( torch.ops.aten.sin.default, torch.ops.aten.tanh.default, torch.ops.aten.sum.dim_IntList, + torch.ops.aten.sum.default, torch.ops.aten.hardsigmoid.default, torch.ops.aten.hardswish.default, torch.ops.aten.hardswish_.default, torch.ops.aten.full_like.default, + torch.ops.aten.zeros_like.default, torch.ops.aten.pow.Tensor_Scalar, torch.ops.aten.gelu.default, torch.ops.aten.sinh.default, @@ -317,15 +402,18 @@ def _match_pattern( torch.ops.aten.view.default, torch.ops.aten.view_as.default, torch.ops.aten.view_copy.default, + torch.ops.aten._unsafe_view.default, torch.ops.aten.select.int, torch.ops.aten.select_copy.int, torch.ops.aten.slice.Tensor, torch.ops.aten.slice_copy.Tensor, torch.ops.aten.split.Tensor, torch.ops.aten.split_with_sizes.default, + torch.ops.aten.split_copy.Tensor, torch.ops.aten.transpose.Dimname, torch.ops.aten.transpose.int, torch.ops.aten.transpose_copy.int, + torch.ops.aten.t_copy.default, torch.ops.aten.tile.default, torch.ops.aten.flip.default, torch.ops.aten.chunk.default, @@ -347,6 +435,7 @@ def _match_pattern( ] _one_to_one_shared_input_or_input_act_qspec = [ + torch.ops.aten.alias.default, torch.ops.aten.clone.default, torch.ops.aten.hardtanh.default, torch.ops.aten.hardtanh_.default, @@ -358,48 +447,71 @@ def _match_pattern( torch.ops.aten.permute_copy.default, torch.ops.aten.avg_pool2d.default, torch.ops.aten.max_pool2d.default, - torch.ops.aten.full.default, - torch.ops.aten.full, torch.ops.aten.flatten.using_ints, torch.ops.aten.dropout.default, torch.ops.aten.dropout_.default, torch.ops.aten.adaptive_avg_pool2d.default, 
torch.ops.aten.alias_copy.default, + torch.ops.aten.pixel_shuffle.default, + torch.ops.aten.pixel_unshuffle.default, ] def get_quant_properties( # noqa: C901 node: Node, gm: torch.fx.GraphModule, quantization_config ) -> _OpQuantProperties | None: + """Compute quantization properties for a node. + + Determine which inputs and/or outputs should be annotated for quantization + based on the node's operator and surrounding pattern. + + Args: + node (Node): Node to analyze. + gm (torch.fx.GraphModule): Owning graph module. + quantization_config: Source for activation/weight/bias qspecs. + + Returns: + _OpQuantProperties | None: Properties to apply, or ``None`` if the + node is unsupported or not suitable for quantization. + + """ input_act_qspec = quantization_config.get_input_act_qspec() weight_qspec = quantization_config.get_weight_qspec() output_act_qspec = quantization_config.get_output_act_qspec() bias_qspec = quantization_config.get_bias_qspec(node) + if output_act_qspec is not None: + # Check if output activation qspec is symmetric. In that case + # we avoid conv + relu fusion for quantization annotation. + is_symmetric = output_act_qspec.qscheme == torch.per_tensor_symmetric + else: + is_symmetric = False quant_properties = _OpQuantProperties() def any_or_hardtanh_min_zero(n: Node): + """Return True for any op or hardtanh with ``min_val == 0``.""" # Check that if the node is a hardtanh, its min_val is zero - return n.target != torch.ops.aten.hardtanh.default or n.args[1] == 0 + return ( + n.target + not in (torch.ops.aten.hardtanh.default, torch.ops.aten.hardtanh_.default) + or n.args[1] == 0 + ) - if _match_pattern( + if not is_symmetric and _match_pattern( node, [ + _conv_ops, + [torch.ops.aten.batch_norm.default], [ - torch.ops.aten.conv1d.default, - torch.ops.aten.conv2d.default, - torch.ops.aten.conv2d.padding, + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, ], - [torch.ops.aten.batch_norm.default, F.batch_norm], - [torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default], ], filter_fn=any_or_hardtanh_min_zero, ): - if node.target in ( - torch.ops.aten.conv1d.default, - torch.ops.aten.conv2d.default, - torch.ops.aten.conv2d.padding, - ): + if node.target in _conv_ops: quant_properties.quant_inputs = [ _QuantProperty(0, input_act_qspec), _QuantProperty(1, weight_qspec, mark_annotated=True), @@ -407,51 +519,48 @@ def any_or_hardtanh_min_zero(n: Node): ] elif node.target in ( torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, ): quant_properties.quant_output = _QuantProperty(0, output_act_qspec) elif _match_pattern( node, [ - [ - torch.ops.aten.conv1d.default, - torch.ops.aten.conv2d.default, - torch.ops.aten.conv2d.padding, - ], - [torch.ops.aten.batch_norm.default, F.batch_norm], + _conv_ops, + [torch.ops.aten.batch_norm.default], ], ): - if node.target in ( - torch.ops.aten.conv1d.default, - torch.ops.aten.conv2d.default, - torch.ops.aten.conv2d.padding, - ): + if node.target in _conv_ops: quant_properties.quant_inputs = [ _QuantProperty(0, input_act_qspec), _QuantProperty(1, weight_qspec, mark_annotated=True), _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True), ] - elif node.target in [torch.ops.aten.batch_norm.default, F.batch_norm]: + elif node.target in [ + torch.ops.aten.batch_norm.default, + ]: quant_properties.quant_output = _QuantProperty(0, output_act_qspec) - elif _match_pattern( + elif not 
is_symmetric and _match_pattern( node, [ [ - torch.ops.aten.conv1d.default, - torch.ops.aten.conv2d.default, + *_conv_ops, torch.ops.aten.linear.default, - torch.ops.aten.conv2d.padding, ], - [torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default], + [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, + ], ], any_or_hardtanh_min_zero, ): if node.target in ( - torch.ops.aten.conv1d.default, - torch.ops.aten.conv2d.default, + *_conv_ops, torch.ops.aten.linear.default, - torch.ops.aten.conv2d.padding, ): quant_properties.quant_inputs = [ _QuantProperty(0, input_act_qspec), @@ -461,10 +570,8 @@ def any_or_hardtanh_min_zero(n: Node): else: quant_properties.quant_output = _QuantProperty(0, output_act_qspec) elif node.target in ( - torch.ops.aten.conv1d.default, - torch.ops.aten.conv2d.default, + *_conv_ops, torch.ops.aten.linear.default, - torch.ops.aten.conv2d.padding, ): quant_properties.quant_inputs = [ _QuantProperty(0, input_act_qspec), @@ -492,33 +599,42 @@ def any_or_hardtanh_min_zero(n: Node): torch.ops.aten.minimum.default, torch.ops.aten.maximum.default, ): - shared_qspec = SharedQuantizationSpec((node.args[0], node)) # type: ignore[arg-type] + lhs_node = ensure_type(Node, node.args[0]) + shared_qspec = SharedQuantizationSpec((lhs_node, node)) quant_properties.quant_inputs = [ _QuantProperty(0, input_act_qspec), _QuantProperty( - 1, input_act_qspec if node.args[0] == node.args[1] else shared_qspec # type: ignore[arg-type] + 1, + input_act_qspec if node.args[0] == node.args[1] else shared_qspec, ), ] - quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type] + quant_properties.quant_output = _QuantProperty(0, shared_qspec) elif node.target in (torch.ops.aten.where.self,): - shared_qspec = SharedQuantizationSpec(node.args[1]) # type: ignore[arg-type] + true_node = ensure_type(Node, node.args[1]) + input_qspec = ( + SharedQuantizationSpec(true_node) + if is_output_annotated(true_node) + else input_act_qspec + ) quant_properties.quant_inputs = [ - _QuantProperty(1, shared_qspec), # type: ignore[arg-type] - _QuantProperty(2, shared_qspec), # type: ignore[arg-type] + _QuantProperty(1, input_qspec), + _QuantProperty(2, SharedQuantizationSpec((true_node, node))), ] - quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type] + quant_properties.quant_output = _QuantProperty( + 0, + SharedQuantizationSpec((true_node, node)), + ) elif node.target in _one_to_one_shared_input_or_input_act_qspec: - if not isinstance(node.args[0], Node): - return None - + input_node = ensure_type(Node, node.args[0]) input_qspec = ( - SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type] - if is_output_annotated(node.args[0]) # type: ignore + SharedQuantizationSpec(input_node) + if is_output_annotated(input_node) else input_act_qspec ) - quant_properties.quant_inputs = [_QuantProperty(0, input_qspec)] # type: ignore[arg-type] + quant_properties.quant_inputs = [_QuantProperty(0, input_qspec)] quant_properties.quant_output = _QuantProperty( - 0, SharedQuantizationSpec((node.args[0], node)) # type: ignore[arg-type] + 0, + SharedQuantizationSpec((input_node, node)), ) elif node.target in ( torch.ops.aten.cat.default, @@ -533,25 +649,34 @@ def any_or_hardtanh_min_zero(n: Node): ) if len(node.args[0]) == 0: raise ValueError("Expected non-empty list for node.args[0]") - - shared_qspec = SharedQuantizationSpec((node.args[0][0], node)) + inputs = [ensure_type(Node, 
element) for element in node.args[0]] + shared_qspec = SharedQuantizationSpec((inputs[0], node)) quant_properties.quant_inputs = [ _QuantProperty( 0, - [ - input_act_qspec if n == node.args[0][0] else shared_qspec # type: ignore[misc] - for n in node.args[0] - ], + [input_act_qspec if n == inputs[0] else shared_qspec for n in inputs], ) ] - quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type] + quant_properties.quant_output = _QuantProperty(0, shared_qspec) elif node.target in _one_to_one: quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)] quant_properties.quant_output = _QuantProperty(0, output_act_qspec) elif node.target in _one_to_one_shared_input_qspec: + input_node = ensure_type(Node, node.args[0]) quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)] quant_properties.quant_output = _QuantProperty( - 0, SharedQuantizationSpec((node.args[0], node)) # type: ignore[arg-type] + 0, + SharedQuantizationSpec((input_node, node)), + ) + elif node.target in [torch.ops.aten.copy_.default]: + input_node = ensure_type(Node, node.args[1]) + quant_properties.quant_inputs = [ + _QuantProperty(0, input_act_qspec), + _QuantProperty(1, input_act_qspec), + ] + quant_properties.quant_output = _QuantProperty( + 0, + SharedQuantizationSpec((input_node, node)), ) elif node.target in [ torch.ops.aten.eq.Tensor, @@ -560,23 +685,63 @@ def any_or_hardtanh_min_zero(n: Node): torch.ops.aten.le.Tensor, torch.ops.aten.lt.Tensor, ]: - shared_qspec = SharedQuantizationSpec((node.args[0], node)) # type: ignore[arg-type] + input_node = ensure_type(Node, node.args[0]) + shared_qspec = SharedQuantizationSpec((input_node, node)) quant_properties.quant_inputs = [ _QuantProperty(0, input_act_qspec), _QuantProperty( - 1, input_act_qspec if node.args[0] == node.args[1] else shared_qspec # type: ignore[arg-type] + 1, + input_act_qspec if node.args[0] == node.args[1] else shared_qspec, ), ] quant_properties.quant_output = None - elif node.target in [torch.ops.aten.scalar_tensor.default]: + elif node.target in [ + torch.ops.aten.full.default, + torch.ops.aten.full, + torch.ops.aten.zeros.default, + torch.ops.aten.ones.default, + torch.ops.aten.fill_.Scalar, + torch.ops.aten.scalar_tensor.default, + ]: quant_properties.quant_inputs = [] quant_properties.quant_output = _QuantProperty(0, output_act_qspec) elif node.target in [operator.getitem]: - if not is_output_annotated(node.args[0]): # type: ignore[attr-defined, arg-type] + input_node = ensure_type(Node, node.args[0]) + if not is_output_annotated(input_node): return None - shared_qspec = SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type] - quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)] # type: ignore[arg-type] - quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type] + shared_qspec = SharedQuantizationSpec(input_node) + quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)] + quant_properties.quant_output = _QuantProperty(0, shared_qspec) + elif node.target in ( + torch.ops.higher_order.cond, + torch.ops.higher_order.while_loop, + ): + submodule_args_pos = -1 if node.target == torch.ops.higher_order.cond else -2 + submodule_args = node.args[submodule_args_pos] + output_qspec = output_act_qspec + if len(submodule_args) > 0: # type: ignore[arg-type] + # The way the TOSA backend handles quantized inputs, arrays of input tensors (such as the input to a + # conditional graph) need shared quantization. 
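The cat/stack and cond/while handling above keeps all tensor inputs on a single set of quantization parameters by pointing every input after the first at the first input's edge. A rough standalone sketch of that pattern follows; the helper name is hypothetical and the torchao import path is assumed to match the annotator's other imports.

# Sketch of the shared-quantization pattern used for cat/stack inputs and for
# cond/while submodule arguments; illustrative only, not the patch's code.
from torch.fx import Node
from torchao.quantization.pt2e.quantizer import SharedQuantizationSpec


def shared_specs_for_cat(cat_node: Node, input_act_qspec):
    inputs = list(cat_node.args[0])
    # The first input gets a fresh activation qspec; every other input reuses the
    # first input's quantization parameters via the (producer, consumer) edge.
    shared = SharedQuantizationSpec((inputs[0], cat_node))
    return [input_act_qspec if n is inputs[0] else shared for n in inputs]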
+ shared_qspec = SharedQuantizationSpec( + (cast(list[Node], submodule_args)[0], node) + ) + quant_properties.quant_inputs = [ + _QuantProperty( + submodule_args_pos, + [ + input_act_qspec, + *([shared_qspec] * (len(submodule_args) - 1)), # type: ignore[arg-type] + ], + ) + ] + if node.target == torch.ops.higher_order.while_loop: + # The output of the while loop body can either re-enter the body, or exit the while loop. + # Therefore, A and B in the diagram below need to share the same quantization parameters. + # A -> while ( RESCALE -> ... RESCALE -> ) -> B + output_qspec = shared_qspec + + quant_properties.quant_output = _QuantProperty(0, output_qspec) + else: return None @@ -597,6 +762,21 @@ def annotate_graph( # type: ignore[return] quantization_config: QuantizationConfig, filter_fn: Optional[Callable[[Node], bool]] = None, ) -> Optional[List[List[Node]]]: + """Annotate supported nodes in a graph with quantization specs. + + Iterate through call_function nodes, computes quantization properties, and + apply input/output annotations. A filter can restrict which nodes are + considered. + + Args: + gm (torch.fx.GraphModule): Graph to annotate. + quantization_config (QuantizationConfig): Default qspecs for nodes. + filter_fn (Optional[Callable[[Node], bool]]): Optional node predicate. + + Returns: + Optional[List[List[Node]]]: Reserved for future use; currently None. + + """ for node in gm.graph.nodes: if node.op != "call_function": continue @@ -625,6 +805,9 @@ def annotate_graph( # type: ignore[return] torch.ops.aten.full_like.default, torch.ops.aten.full.default, torch.ops.aten.full, + torch.ops.aten.fill_.Scalar, torch.ops.aten.scalar_tensor.default, + torch.ops.aten.zeros.default, + torch.ops.aten.ones.default, ]: node.kwargs = {} diff --git a/backends/arm/quantizer/quantization_config.py b/backends/arm/quantizer/quantization_config.py index d5c3aab1060..b2bc4a57329 100644 --- a/backends/arm/quantizer/quantization_config.py +++ b/backends/arm/quantizer/quantization_config.py @@ -3,8 +3,14 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Provide quantization configuration helpers for the Arm backend. + +Define a small dataclass to carry activation/weight/bias specs and helper +accessors that validate specs before use. Use this module to build and validate +quantization specs consumed by the annotator. + +""" -# pyre-unsafe from dataclasses import dataclass @@ -19,13 +25,38 @@ @dataclass(eq=True, frozen=True) class QuantizationConfig: + """Provide a container for quantization specs. + + Hold optional specs for input/output activations, weights, and bias, and + expose validated accessors. + + Attributes: + input_activation (QuantizationSpec | None): Spec for input activations. + output_activation (QuantizationSpec | None): Spec for output activations. + weight (QuantizationSpec | None): Spec for weights. + bias (QuantizationSpec | None): Spec for bias values. + + """ + input_activation: QuantizationSpec | None output_activation: QuantizationSpec | None weight: QuantizationSpec | None bias: QuantizationSpec | None def get_input_act_qspec(self) -> QuantizationSpec | None: - """Returns QuantizationSpec 'input_activation' after asserting that input_activation.qscheme is valid.""" + """Get the validated input activation spec. + + Validate that the input activation qscheme is supported before + returning the spec. + + Returns: + QuantizationSpec | None: Input activation spec, or ``None`` when + unset. 
+ + Raises: + ValueError: If the qscheme is not per-tensor affine or symmetric. + + """ if self.input_activation is None: return None # Validate that input_activation uses a supported qscheme @@ -39,7 +70,19 @@ def get_input_act_qspec(self) -> QuantizationSpec | None: return self.input_activation def get_output_act_qspec(self) -> QuantizationSpec | None: - """Returns QuantizationSpec 'output_activation' after asserting that output_activation.qscheme is valid.""" + """Get the validated output activation spec. + + Validate that the output activation qscheme is supported before + returning the spec. + + Returns: + QuantizationSpec | None: Output activation spec, or ``None`` when + unset. + + Raises: + ValueError: If the qscheme is not per-tensor affine or symmetric. + + """ if self.output_activation is None: return None # Validate that output_activation uses a supported qscheme @@ -53,7 +96,18 @@ def get_output_act_qspec(self) -> QuantizationSpec | None: return self.output_activation def get_weight_qspec(self) -> QuantizationSpec | None: - """Returns QuantizationSpec 'weight' after asserting that weight.qscheme is valid.""" + """Get the validated weight spec. + + Validate that the weight qscheme is supported (per-tensor or + per-channel symmetric) before returning the spec. + + Returns: + QuantizationSpec | None: Weight spec, or ``None`` when unset. + + Raises: + ValueError: If the qscheme is not a supported symmetric scheme. + + """ if self.weight is None: return None # Validate that weight uses a supported qscheme @@ -65,11 +119,46 @@ def get_weight_qspec(self) -> QuantizationSpec | None: return self.weight def get_bias_qspec(self, node: torch.fx.Node) -> QuantizationSpec | None: - """Returns QuantizationSpec 'bias' after asserting that bias.dtype is torch.float.""" + """Get the derived or validated bias spec. + + For conv/linear ops, derive bias qparams from the input/weight observers. + Otherwise, validate a user-provided floating-point bias spec. + + Args: + node (torch.fx.Node): Node whose bias spec is requested. + + Returns: + QuantizationSpec | None: Derived or provided bias spec, or ``None`` + when unset. + + Raises: + ValueError: If deriving qparams sees an unexpected number of + observers/fake-quantizers, or if a provided bias dtype is not + floating-point. + + """ def _derive_qparams_fn( obs_or_fqs: list[ObserverOrFakeQuantize], ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute bias scale/zero-point from activation/weight observers. + + Expect two observers or fake-quantize modules: one for the input + activation and one for the weight. The bias scale is the product of + input and weight scales, and the zero-point is a tensor of zeros. + + Args: + obs_or_fqs (list[ObserverOrFakeQuantize]): Observers/fake-quant + in order ``[act, weight]``. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Bias scale tensor and + integer zero-point tensor. + + Raises: + ValueError: If the list does not contain exactly two items. 
+ + """ # Validate expected number of observers/fake-quantizes if len(obs_or_fqs) != 2: raise ValueError( @@ -88,30 +177,46 @@ def _derive_qparams_fn( torch.ops.aten.conv2d.default, torch.ops.aten.linear.default, torch.ops.aten.conv2d.padding, + torch.ops.aten.conv3d.default, + torch.ops.aten.conv3d.padding, ]: - input_act = node.args[0] - weight = node.args[1] - # If the weights are quantized per_tensor, do the same with bias - qscheme = ( - torch.per_tensor_symmetric - if self.weight is None - else self.weight.qscheme - ) - ch_axis = None - if self.weight is not None: - if qscheme == torch.per_channel_symmetric: - ch_axis = self.weight.ch_axis - - quantization_spec = DerivedQuantizationSpec( - derived_from=[(input_act, node), (weight, node)], # type: ignore[list-item] - derive_qparams_fn=_derive_qparams_fn, - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max - 1, - qscheme=qscheme, - ch_axis=ch_axis, - ) - return quantization_spec # type: ignore[return-value] + if self.input_activation is None or self.weight is None: + raise ValueError( + "Input activation and weight QuantizationConfig must be specified." + ) + + if (self.input_activation.dtype == self.weight.dtype == torch.int8) or ( + self.input_activation.dtype == torch.int16 + and self.weight.dtype == torch.int8 + ): + input_act = node.args[0] + weight = node.args[1] + + # If the weights are quantized per_tensor, do the same with bias + qscheme = ( + torch.per_tensor_symmetric + if self.weight is None + else self.weight.qscheme + ) + ch_axis = None + if self.weight is not None: + if qscheme == torch.per_channel_symmetric: + ch_axis = self.weight.ch_axis + + quantization_spec = DerivedQuantizationSpec( + derived_from=[(input_act, node), (weight, node)], # type: ignore[list-item] + derive_qparams_fn=_derive_qparams_fn, + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min + 1, + quant_max=torch.iinfo(torch.int32).max, + qscheme=qscheme, + ch_axis=ch_axis, + ) + return quantization_spec # type: ignore[return-value] + else: + raise NotImplementedError( + f"Bias quantization of types: i:{self.input_activation.dtype}, w:{self.weight.dtype} not implemented" + ) if self.bias is None: return None diff --git a/backends/arm/requirements-arm-ethos-u.txt b/backends/arm/requirements-arm-ethos-u.txt index 5fad9d2fe94..9076aa08852 100644 --- a/backends/arm/requirements-arm-ethos-u.txt +++ b/backends/arm/requirements-arm-ethos-u.txt @@ -3,4 +3,4 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -ethos-u-vela @ git+https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela@d37febc1715edf0d236c2ff555739a8a9aadcf9a +ethos-u-vela == 4.4.1 \ No newline at end of file diff --git a/backends/arm/requirements-arm-models-test.txt b/backends/arm/requirements-arm-models-test.txt index ac4e1d9bad7..238e9d07c9d 100644 --- a/backends/arm/requirements-arm-models-test.txt +++ b/backends/arm/requirements-arm-models-test.txt @@ -3,4 +3,4 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
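The docstring of _derive_qparams_fn above describes the arithmetic only; as a rough standalone sketch (names are illustrative, not the patch's code), the derivation amounts to the following.

# Sketch of the bias qparams derivation described above: with symmetric int8/int16
# activations and int8 weights, the int32 bias scale is the product of the activation
# and weight scales and the zero-point is all zeros.
import torch


def derive_bias_qparams(act_scale: torch.Tensor, weight_scale: torch.Tensor):
    bias_scale = act_scale * weight_scale  # broadcasts over channels for per-channel weights
    bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32)
    return bias_scale, bias_zero_point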
-diffusers[torch] == 0.33.1 +diffusers[torch] == 0.33.1 \ No newline at end of file diff --git a/backends/arm/requirements-arm-tosa.txt b/backends/arm/requirements-arm-tosa.txt index 4b7a3ec0273..c93e9411647 100644 --- a/backends/arm/requirements-arm-tosa.txt +++ b/backends/arm/requirements-arm-tosa.txt @@ -5,5 +5,5 @@ ml_dtypes == 0.5.1 flatbuffers == 24.3.25 - -tosa-tools @ git+https://git.gitlab.arm.com/tosa/tosa-reference-model.git@v2025.07.0 +tosa-adapter-model-explorer == 0.1.0 +ai-edge-model-explorer >= 0.1.16 diff --git a/backends/arm/requirements-arm-vgf.txt b/backends/arm/requirements-arm-vgf.txt new file mode 100644 index 00000000000..1bf4d78c995 --- /dev/null +++ b/backends/arm/requirements-arm-vgf.txt @@ -0,0 +1,8 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +ai_ml_emulation_layer_for_vulkan == 0.7.0 +ai_ml_sdk_model_converter == 0.7.0 +ai_ml_sdk_vgf_library == 0.7.0 diff --git a/backends/arm/runtime/EthosUBackend.cpp b/backends/arm/runtime/EthosUBackend.cpp index bff5ff69284..c339f4a6164 100644 --- a/backends/arm/runtime/EthosUBackend.cpp +++ b/backends/arm/runtime/EthosUBackend.cpp @@ -249,15 +249,6 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { handles.inputs->io[i].elem_size); return Error::InvalidProgram; } - supported = executorch::runtime::is_contiguous_dim_order( - tensor_in.dim_order().data(), tensor_in.dim()); - if (!supported) { - ET_LOG( - Error, - "Input %d expected contiguous dim_order, but got non-contiguous dim_order", - i); - return Error::InvalidProgram; - } // Select a compatible copy routine including checking for input layouts // which require permutation. @@ -335,7 +326,8 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { ET_LOG(Error, "Ethos-U invocation failed error (%d)", result); return Error::InvalidProgram; } - int tensor_dim = 0, io_dim = 0; + size_t tensor_bytes_total = 0; + size_t io_bytes_total = 0; // Write outputs from scratch into EValue pointers for (int i = 0; i < handles.outputs->count; i++) { int tensor_count = 1, io_count = 1; @@ -347,23 +339,39 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { calculate_dimensions( tensor_out, &handles.outputs->io[i], &tensor_count, &io_count); - // At times the topological order of the outputs may change. - // Lets instead ensure that the sum of dimensions match. - tensor_dim = tensor_dim + tensor_count; - io_dim = io_dim + io_count; + size_t tensor_bytes = tensor_out.nbytes(); + size_t io_bytes = static_cast(io_count) * + static_cast(handles.outputs->io[i].elem_size); + + if (tensor_bytes != io_bytes) { + Error status = copy_with_layout_adjustment( + handles.outputs->io[i], i, output_addr, tensor_out, tensor_bytes); + if (status != Error::Ok) { + return status; + } + io_bytes_total += tensor_bytes; + } else { + EXECUTORCH_PROF_SCOPE( + event_tracer, "+EthosUBackend::execute()handles.output.memcpy()"); - EXECUTORCH_PROF_SCOPE( - event_tracer, "+EthosUBackend::execute()handles.output.memcpy()"); + memcpy( + tensor_out.mutable_data_ptr(), + static_cast(output_addr), + tensor_bytes); + io_bytes_total += io_bytes; + } - memcpy( - tensor_out.mutable_data_ptr(), - static_cast(output_addr), - tensor_out.nbytes()); + // At times the topological order of the outputs may change. + // Lets instead ensure that the sum of output bytes match. 
+ tensor_bytes_total += tensor_bytes; } - if (tensor_dim != io_dim) { + if (tensor_bytes_total != io_bytes_total) { ET_LOG(Error, "Total output tensor sizes do not match"); ET_LOG( - Error, "Program expects size of %d but got %d", tensor_dim, io_dim); + Error, + "Program expects %zu bytes but got %zu", + io_bytes_total, + tensor_bytes_total); return Error::InvalidProgram; } return Error::Ok; @@ -374,6 +382,147 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { } private: + // Copies Vela output into the ExecuTorch tensor, adjusting for padding or + // packed layouts produced by the delegate. + Error copy_with_layout_adjustment( + const VelaIO& output_io, + int output_index, + const char* src, + executorch::aten::Tensor& tensor_out, + size_t tensor_bytes) const { + const int elem_size = output_io.elem_size; + if (elem_size == 0) { + ET_LOG( + Error, "Ethos-U output %d reports zero element size", output_index); + return Error::InvalidProgram; + } + + size_t chunk_count = 1; + for (int dim = 0; dim < shapeDim - 1; ++dim) { + const int vela_dim = output_io.shape[dim]; + chunk_count *= static_cast(vela_dim == 0 ? 1 : vela_dim); + } + const int last_dim = output_io.shape[shapeDim - 1]; + const size_t vela_chunk_elems = + static_cast(last_dim == 0 ? 1 : last_dim); + const size_t vela_chunk_size = + vela_chunk_elems * static_cast(elem_size); + + if (tensor_bytes % chunk_count != 0) { + ET_LOG( + Error, + "Ethos-U output %d tensor bytes %zu not divisible by chunk count %zu", + output_index, + tensor_bytes, + chunk_count); + return Error::InvalidProgram; + } + + const size_t chunk_size = tensor_bytes / chunk_count; + + // If Vela writes fewer bytes than the tensor expects we may need to + // expand 4-bit data to 8-bit. Ethos-U outputs may be + // packed 4-bit values but ExecuTorch tensors are at least 8-bit. 
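To make the 4-bit expansion concrete before the helpers added further down in this hunk: each byte written by Vela packs two signed 4-bit values, low nibble first, and nibble values 8 through 15 sign-extend to -8 through -1. A standalone Python sketch of the same arithmetic, for illustration only:

```python
def unpack_4bit_to_int8(packed: bytes) -> list:
    """Expand packed signed 4-bit pairs into int8 values, low nibble first."""
    out = []
    for byte in packed:
        low = byte & 0x0F
        high = (byte >> 4) & 0x0F
        # Sign-extend: nibble values 8..15 represent -8..-1.
        out.append(low - 16 if low >= 8 else low)
        out.append(high - 16 if high >= 8 else high)
    return out


# 0xF2 packs the values 2 (low nibble) and -1 (high nibble).
assert unpack_4bit_to_int8(bytes([0xF2])) == [2, -1]
```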
+ if (vela_chunk_size < chunk_size) { + if (chunk_size % vela_chunk_size != 0) { + ET_LOG( + Error, + "Ethos-U output %d chunk bytes %zu not divisible by vela chunk bytes %zu", + output_index, + chunk_size, + vela_chunk_size); + return Error::InvalidProgram; + } + + const size_t expand_factor = chunk_size / vela_chunk_size; + if (expand_factor == 2 && elem_size == 1 && + tensor_out.scalar_type() == ScalarType::Char) { + return unpack_chunks_4bit_to_int8( + reinterpret_cast(src), + tensor_out.mutable_data_ptr(), + chunk_count, + chunk_size, + vela_chunk_size); + } + + ET_LOG( + Error, + "Ethos-U output %d expansion factor %zu with element size %d not supported", + output_index, + expand_factor, + elem_size); + return Error::InvalidProgram; + } + + return strip_delegate_padding( + src, + tensor_out.mutable_data_ptr(), + chunk_count, + chunk_size, + vela_chunk_size); + } + + Error unpack_chunks_4bit_to_int8( + const uint8_t* src, + int8_t* dest, + size_t chunk_count, + size_t dest_chunk_size, + size_t src_chunk_size) const { + const uint8_t* chunk_src = src; + int8_t* chunk_dest = dest; + for (size_t chunk_idx = 0; chunk_idx < chunk_count; ++chunk_idx) { + unpack_single_chunk_4bit_to_int8(chunk_src, chunk_dest, src_chunk_size); + chunk_src += src_chunk_size; + chunk_dest += dest_chunk_size; + } + return Error::Ok; + } + + void unpack_single_chunk_4bit_to_int8( + const uint8_t* src, + int8_t* dest, + size_t chunk_size) const { + for (size_t byte_idx = 0; byte_idx < chunk_size; ++byte_idx) { + const uint8_t packed = src[byte_idx]; + int8_t low = static_cast(packed & 0x0F); + int8_t high = static_cast((packed >> 4) & 0x0F); + if (low >= 8) { + low -= 16; + } + if (high >= 8) { + high -= 16; + } + dest[2 * byte_idx] = low; + dest[2 * byte_idx + 1] = high; + } + } + + Error strip_delegate_padding( + const char* src, + char* dest, + size_t chunk_count, + size_t dest_chunk_size, + size_t src_chunk_size) const { + if (dest_chunk_size > src_chunk_size) { + ET_LOG( + Error, + "dest chunk size %zu must not exceed src chunk size %zu", + dest_chunk_size, + src_chunk_size); + return Error::InvalidProgram; + } + if (src == nullptr || dest == nullptr) { + ET_LOG(Error, "Ethos-U padded copy received null buffer"); + return Error::InvalidState; + } + for (size_t chunk_idx = 0; chunk_idx < chunk_count; ++chunk_idx) { + memcpy(dest, src, dest_chunk_size); + src += src_chunk_size; + dest += dest_chunk_size; + } + return Error::Ok; + } + void calculate_dimensions( const executorch::aten::Tensor tensor, VelaIO* io, @@ -383,19 +532,43 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { *tensor_count = *tensor_count * tensor.size(i); } - // The VelaIO type has a shape of fixed size 4 - for (int i = 0; i < 4; i++) { + // The VelaIO type has a shape of fixed size 6 + for (int i = 0; i < shapeDim; i++) { *io_count = *io_count * io->shape[i]; } } }; namespace { -auto backend = EthosUBackend(); -Backend backend_id{"EthosUBackend", &backend}; -static auto registered = register_backend(backend_id); +auto EthosUBackend_backend = EthosUBackend(); +Backend EthosUBackend_id{"EthosUBackend", &EthosUBackend_backend}; +static executorch::runtime::Error EthosUBackend_registered = + register_backend(EthosUBackend_id); + +#ifdef __ZEPHYR__ +/** + * This function serves as a linker force-include mechanism to ensure the + * EthosU backend module gets properly linked into the final executable, + * even when it might otherwise be optimized out by the linker due to + * linker options that remove unused code or 
data for example + * if you link with --gc-sections + * This function can be called from your runner to force the inclusion of + * the EthosU backend module. As a bonus it will return the status of the + * backend registration, so you can also check if the registration was + * successful. + */ + +// Warning: This should not be considered to be an API and might get removed +// without notice in a future release if a better way to solve this is +// implemented. +extern "C" executorch::runtime::Error +executorch_delegate_EthosUBackend_registered() { + return EthosUBackend_registered; +} +#endif + } // namespace } // namespace arm } // namespace backends -} // namespace executorch \ No newline at end of file +} // namespace executorch diff --git a/backends/arm/runtime/VGFSetup.cpp b/backends/arm/runtime/VGFSetup.cpp index abb4c50d8be..fd3a114c190 100644 --- a/backends/arm/runtime/VGFSetup.cpp +++ b/backends/arm/runtime/VGFSetup.cpp @@ -24,6 +24,13 @@ namespace vgf { /* static function to map format to byte count */ static uint32_t get_format_size(VkFormat format); +// SPV_ARM_tensor does not support rank-0 representations according to the spec. +// Use an unsqueezed dimension when the resource table contains an empty +// shape. Tensors are output as rank 0 when copied back from the vgf backend. +namespace { +constexpr int64_t kScalarSentinelDimension = 1; +} + // Debug function to inspect memory properties static string memory_flags_to_string(VkMemoryPropertyFlags flags) { if (flags == 0) @@ -264,7 +271,11 @@ static void debug_print_resources( the_shape.size(), the_stride.size()); for (int j = 0; j < the_shape.size(); j++) { - ET_LOG(Info, " %d: dim %ld", j, the_shape[j]); + ET_LOG( + Info, + " %d: dim %lld", + j, + static_cast(the_shape[j])); } // Allocate a tensor with bound memory break; @@ -387,6 +398,7 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef specs) { // Get tensor shape and strides auto shape = resource_decoder->getTensorShape(i); auto stride = resource_decoder->getTensorStride(i); + const auto shape_size = shape.size(); switch (resource_decoder->getCategory(i)) { case vgflib::ResourceCategory::INPUT: @@ -409,9 +421,9 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef specs) { result = allocate_tensor( vk_physical, vk_device, - vgflib::ToVkFormat(resource_decoder->getVkFormat(i)), - static_cast(shape.size()), - shape.begin(), + resource_format, + shape_size == 0 ? 1 : static_cast(shape_size), + shape_size == 0 ? &kScalarSentinelDimension : shape.begin(), static_cast(stride.size()), stride.begin(), &tensor_description, @@ -422,8 +434,7 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef specs) { ET_LOG(Error, "Failed to allocate tensor for VGF resource %d", i); return false; } - size_t e_size = get_format_size( - vgflib::ToVkFormat(resource_decoder->getVkFormat(i))); + size_t e_size = get_format_size(resource_format); if (0 == e_size) { ET_LOG(Error, "failed to get element size of VkFormat"); return false; @@ -449,9 +460,11 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef specs) { .sType = VK_STRUCTURE_TYPE_TENSOR_DESCRIPTION_ARM, .pNext = nullptr, .tiling = VK_TENSOR_TILING_LINEAR_ARM, - .format = vgflib::ToVkFormat(resource_decoder->getVkFormat(i)), - .dimensionCount = static_cast(shape.size()), - .pDimensions = shape.begin(), + .format = resource_format, + .dimensionCount = + shape_size == 0 ? 1 : static_cast(shape_size), + .pDimensions = + shape_size == 0 ? 
&kScalarSentinelDimension : shape.begin(), // Note: stride_data of 0's causes size==0, null means stride==size .pStrides = (0 == stride.size() ? nullptr : stride.begin()), .usage = VK_TENSOR_USAGE_DATA_GRAPH_BIT_ARM, @@ -694,7 +707,7 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef specs) { ); if (result != VK_SUCCESS) { ET_LOG(Error, "Failed to create DataGraphPipeline"); - return result; + return false; } // prepare the graph pipeline session @@ -708,7 +721,7 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef specs) { vk_device, &pipeline_session_info, nullptr, &vk_session); if (result != VK_SUCCESS) { ET_LOG(Error, "Failed to create DataGraphPipelineSession"); - return result; + return false; } // Allocate command buffer @@ -722,7 +735,7 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef specs) { vk_device, &buffer_allocate_info, &vk_execute_cmd); if (result != VK_SUCCESS) { ET_LOG(Error, "Failed to allocate command buffers"); - return result; + return false; } // Allocate intermediates memory based on the pipeline requirements provided @@ -740,7 +753,7 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef specs) { vk_device, &bind_point_requirements_info, &bind_point_count, nullptr); if (result != VK_SUCCESS) { ET_LOG(Error, "Failed to get session bind point count"); - return result; + return false; } vector @@ -753,7 +766,7 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef specs) { bind_point_requirements.data()); if (result != VK_SUCCESS) { ET_LOG(Error, "Failed to get session bind point requirements"); - return result; + return false; } // Given the bind points, just make individual allocations and bind them @@ -764,18 +777,18 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef specs) { ET_LOG( Error, "Expected VK_DATA_GRAPH_PIPELINE_SESSION_BIND_POINT_TYPE_MEMORY_ARM"); - return VK_ERROR_UNKNOWN; + return false; } if (bind_point_requirement.bindPoint != VK_DATA_GRAPH_PIPELINE_SESSION_BIND_POINT_TRANSIENT_ARM) { ET_LOG( Error, "Expected VK_DATA_GRAPH_PIPELINE_SESSION_BIND_POINT_TRANSIENT_ARM"); - return VK_ERROR_UNKNOWN; + return false; } if (bind_point_requirement.numObjects != 1) { ET_LOG(Error, "Expected only one object for the bindpoint"); - return VK_ERROR_UNKNOWN; + return false; } VkDataGraphPipelineSessionMemoryRequirementsInfoARM memory_requirements_info = { @@ -808,7 +821,7 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef specs) { vkAllocateMemory(vk_device, &memory_allocate_info, nullptr, &memory); if (result != VK_SUCCESS) { ET_LOG(Error, "Failed to allocate memory for intermediates"); - return result; + return false; } // so we can free this object in destructor intermediates.push_back(memory); @@ -826,7 +839,7 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef specs) { result = vkBindDataGraphPipelineSessionMemoryARM(vk_device, 1, &bind_info); if (result != VK_SUCCESS) { ET_LOG(Error, "Failed to bind intermediates memory"); - return result; + return false; } } diff --git a/backends/arm/runtime/VelaBinStream.cpp b/backends/arm/runtime/VelaBinStream.cpp index 180219c75b5..c8d568499c9 100644 --- a/backends/arm/runtime/VelaBinStream.cpp +++ b/backends/arm/runtime/VelaBinStream.cpp @@ -6,7 +6,7 @@ */ /* - * Warning: Do not change this without changing arm_backend.py::vela_compile + * Warning: Do not change this without changing arm_vela.py::vela_compile * as that function emits this format and the two need to align. 
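The header hunk that follows widens `VelaIO::shape` to six dimensions. As a rough illustration of how an element count falls out of such a descriptor, assuming unused trailing dimensions are stored as zero as handled by the copy routine above, here is a small Python sketch (not a transcription of `calculate_dimensions`):

```python
SHAPE_DIM = 6  # mirrors the widened VelaIO shape in VelaBinStream.h below


def vela_io_elem_count(shape):
    """Element count of a VelaIO descriptor, treating unused dims (0) as 1."""
    count = 1
    for dim in shape[:SHAPE_DIM]:
        count *= dim if dim else 1
    return count


assert vela_io_elem_count([1, 2, 3, 4, 0, 0]) == 24
```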
*/ diff --git a/backends/arm/runtime/VelaBinStream.h b/backends/arm/runtime/VelaBinStream.h index 04b8b2ada00..7f6606200b3 100644 --- a/backends/arm/runtime/VelaBinStream.h +++ b/backends/arm/runtime/VelaBinStream.h @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 Arm Limited and/or its affiliates. + * Copyright 2023-2025 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -7,7 +7,7 @@ /* * Minimal reading function for vela_bin_stream wire format. This is an - * implementation detail of the arm_backend AoT flow and ArmBackendEthosU + * implementation detail of the arm backend AoT flow and ArmBackendEthosU * and subject to change. * This format captures the command stream, I/O and memory layout data to * enable execution of the command stream on Ethos-U hardware. @@ -34,9 +34,11 @@ typedef struct { char data[]; // block.name specific format data } VelaBinBlock; +constexpr int shapeDim = 6; // Number of dimensions in VelaIO + // A Vela input or output descriptor in the binary stream typedef struct { - int shape[4]; // Up to 4D shape of input or output + int shape[shapeDim]; // Shape of input or output int elem_size; // Element sizeof in bytes int offset; // Offset in bytes within SRAM working data int region; // Scratch region this belongs to diff --git a/backends/arm/runtime/targets.bzl b/backends/arm/runtime/targets.bzl index 88ce410d956..b4c17acda34 100644 --- a/backends/arm/runtime/targets.bzl +++ b/backends/arm/runtime/targets.bzl @@ -14,7 +14,7 @@ def define_common_targets(): name = "arm_backend", srcs = ["EthosUBackend.cpp"], headers = [], - compatible_with = ["ovr_config//cpu:arm32-embedded"], + compatible_with = ["ovr_config//cpu:arm32-embedded", "ovr_config//cpu:arm32-embedded-fpu"], # arm_executor_runner.cpp needs to compile with executor as whole # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) link_whole = True, diff --git a/backends/arm/scripts/TOSA_minimal_example.ipynb b/backends/arm/scripts/TOSA_minimal_example.ipynb index 785affc657b..a249f03a873 100644 --- a/backends/arm/scripts/TOSA_minimal_example.ipynb +++ b/backends/arm/scripts/TOSA_minimal_example.ipynb @@ -62,7 +62,7 @@ "model = Add()\n", "model = model.eval()\n", "exported_program = torch.export.export(model, example_inputs)\n", - "graph_module = exported_program.module()\n", + "graph_module = exported_program.graph_module\n", "\n", "_ = graph_module.print_readable()" ] @@ -86,10 +86,7 @@ "metadata": {}, "outputs": [], "source": [ - "from executorch.backends.arm.arm_backend import (\n", - " ArmCompileSpecBuilder,\n", - ")\n", - "from executorch.backends.arm.tosa.specification import TosaSpecification\n", + "from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec\n", "from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e\n", "from pathlib import Path\n", "\n", @@ -99,11 +96,7 @@ "\n", "# Create a compilation spec describing the target for configuring the quantizer\n", "# Dump intermediate artifacts (in this case TOSA flat buffers) to specified location\n", - "tosa_spec = TosaSpecification.create_from_string(target)\n", - "spec_builder = (ArmCompileSpecBuilder()\n", - " .tosa_compile_spec(tosa_spec)\n", - " .dump_intermediate_artifacts_to(str(cwd_dir / base_name)))\n", - "compile_spec = spec_builder.build()\n", + "compile_spec = TosaCompileSpec(target).dump_intermediate_artifacts_to(str(cwd_dir / base_name))\n", "\n", "_ = 
graph_module.print_readable()\n", "\n", @@ -130,15 +123,11 @@ "metadata": {}, "outputs": [], "source": [ - "from executorch.backends.arm.arm_backend import (\n", - " ArmCompileSpecBuilder,\n", - " get_tosa_spec,\n", - ")\n", + "from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec\n", "from executorch.backends.arm.quantizer import (\n", " TOSAQuantizer,\n", " get_symmetric_quantization_config,\n", ")\n", - "from executorch.backends.arm.tosa.specification import TosaSpecification\n", "from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e\n", "from pathlib import Path\n", "\n", @@ -148,14 +137,10 @@ "\n", "# Create a compilation spec describing the target for configuring the quantizer\n", "# Dump intermediate artifacts (in this case TOSA flat buffers) to specified location\n", - "tosa_spec = TosaSpecification.create_from_string(target)\n", - "spec_builder = (ArmCompileSpecBuilder()\n", - " .tosa_compile_spec(tosa_spec)\n", - " .dump_intermediate_artifacts_to(str(cwd_dir / base_name)))\n", - "compile_spec = spec_builder.build()\n", + "compile_spec = TosaCompileSpec(target).dump_intermediate_artifacts_to(str(cwd_dir / base_name))\n", "\n", "# Create and configure quantizer to use a symmetric quantization config globally on all nodes\n", - "quantizer = TOSAQuantizer(get_tosa_spec(compile_spec))\n", + "quantizer = TOSAQuantizer(compile_spec)\n", "operator_config = get_symmetric_quantization_config()\n", "quantizer.set_global(operator_config)\n", "\n", @@ -216,7 +201,7 @@ " config=ExecutorchBackendConfig(extract_delegate_segments=False)\n", " )\n", "\n", - "executorch_program_manager.exported_program().module().print_readable()\n", + "executorch_program_manager.exported_program().graph_module.print_readable()\n", "\n", "# Save pte file\n", "pte_name = base_name + \".pte\"\n", diff --git a/backends/arm/scripts/build_executor_runner.sh b/backends/arm/scripts/build_executor_runner.sh index a05287ac4bf..4a14044b345 100755 --- a/backends/arm/scripts/build_executor_runner.sh +++ b/backends/arm/scripts/build_executor_runner.sh @@ -10,8 +10,9 @@ script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) et_root_dir=$(cd ${script_dir}/../../.. && pwd) et_root_dir=$(realpath ${et_root_dir}) toolchain=arm-none-eabi-gcc -setup_path_script=${et_root_dir}/examples/arm/ethos-u-scratch/setup_path.sh +setup_path_script=${et_root_dir}/examples/arm/arm-scratch/setup_path.sh _setup_msg="please refer to ${et_root_dir}/examples/arm/setup.sh to properly install necessary tools." +source "${script_dir}/utils.sh" pte_file="" target="ethos-u55-128" @@ -24,7 +25,7 @@ extra_build_flags="" output_folder_set=false output_folder="." et_build_root="${et_root_dir}/arm_test" -ethosu_tools_dir=${et_root_dir}/examples/arm/ethos-u-scratch +ethosu_tools_dir=${et_root_dir}/examples/arm/arm-scratch select_ops_list="" build_bundleio_flags=" -DET_BUNDLE_IO=OFF " @@ -44,7 +45,7 @@ help() { echo " --memory_mode= Vela memory mode, used for setting the Timing Adapter parameters of the Corstone platforms." echo " Valid values are Shared_Sram(for Ethos-U55, Ethos-U65, Ethos-85), Sram_Only(for Ethos-U55, Ethos-U65, Ethos-U85) or Dedicated_Sram(for Ethos-U65, Ethos-U85)." 
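The notebook hunk above replaces the `ArmCompileSpecBuilder` flow with a `TosaCompileSpec` built directly from the target string, and the quantizer is now constructed from that spec. A condensed Python sketch of the updated flow; the target string and artifact directory are example values only:

```python
from executorch.backends.arm.quantizer import (
    TOSAQuantizer,
    get_symmetric_quantization_config,
)
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec

# Example values; the notebook derives these from its own variables.
target = "TOSA-1.0+INT"
compile_spec = TosaCompileSpec(target).dump_intermediate_artifacts_to("./tosa_add")

# The quantizer is configured from the compile spec itself, with a symmetric
# quantization config applied globally.
quantizer = TOSAQuantizer(compile_spec)
quantizer.set_global(get_symmetric_quantization_config())
```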
echo " Default: Shared_Sram for the Ethos-U55 and Sram_Only for the Ethos-U85" - echo " --etdump Adds Devtools etdump support to track timing, etdump area will be base64 encoded in the log" + echo " --etdump Adds Devtools etdump support to track timing and output, etdump area will be base64 encoded in the log" echo " --extra_build_flags= Extra flags to pass to cmake like -DET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE=60000 Default: none " echo " --output= Output folder Default: /_.pte" echo " --et_build_root= Build output root folder to use, defaults to ${et_build_root}" @@ -161,7 +162,7 @@ if [ "$bundleio" = true ] ; then fi if [ "$build_with_etdump" = true ] ; then - build_with_etdump_flags=" -DEXECUTORCH_ENABLE_EVENT_TRACER=ON " + build_with_etdump_flags=" -DEXECUTORCH_ENABLE_EVENT_TRACER=ON -DET_DUMP_INTERMEDIATE_OUTPUTS=ON " fi echo "Building with BundleIO/etdump/extra flags: ${build_bundleio_flags} ${build_with_etdump_flags} ${extra_build_flags}" @@ -185,7 +186,9 @@ cmake \ echo "[${BASH_SOURCE[0]}] Configured CMAKE" -cmake --build ${output_folder} -j$(nproc) -- arm_executor_runner +parallel_jobs="$(get_parallel_jobs)" + +cmake --build ${output_folder} -j"${parallel_jobs}" -- arm_executor_runner echo "[${BASH_SOURCE[0]}] Generated ${toolchain} elf file:" find ${output_folder} -name "arm_executor_runner" diff --git a/backends/arm/scripts/build_executor_runner_vkml.sh b/backends/arm/scripts/build_executor_runner_vkml.sh index 1df63acc425..f443032ee6f 100755 --- a/backends/arm/scripts/build_executor_runner_vkml.sh +++ b/backends/arm/scripts/build_executor_runner_vkml.sh @@ -6,39 +6,43 @@ set -eu -script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) et_root_dir=$(cd ${script_dir}/../../.. && pwd) et_root_dir=$(realpath ${et_root_dir}) -setup_path_script=${et_root_dir}/examples/arm/ethos-u-scratch/setup_path.sh +setup_path_script=${et_root_dir}/examples/arm/arm-scratch/setup_path.sh _setup_msg="please refer to ${et_root_dir}/examples/arm/setup.sh to properly install necessary tools." - build_type="Release" build_with_etdump=false extra_build_flags="" output_folder="cmake-out-vkml" +build_with_etdump_flags="-DEXECUTORCH_ENABLE_EVENT_TRACER=OFF" +build_with_bundleio_flags="-DEXECUTORCH_ENABLE_BUNDLE_IO=OFF" + +source "${script_dir}/utils.sh" -build_with_etdump_flags=" -DEXECUTORCH_ENABLE_EVENT_TRACER=OFF " help() { echo "Usage: $(basename $0) [options]" echo "Options:" - echo " --build_type= Build with Release, Debug or RelWithDebInfo, default is ${build_type}" - echo " --etdump Adds Devtools etdump support to track timing, etdump area will be base64 encoded in the log" - echo " --extra_build_flags= Extra flags to pass to cmake. Default: none " - echo " --output= Output folder Default: $(output_folder)" + echo " --build_type= Build with Release, Debug or RelWithDebInfo, default is ${build_type}" + echo " --etdump Adds Devtools etdump support to track timing, etdump area will be base64 encoded in the log" + echo " --extra_build_flags= Extra flags to pass to cmake. 
Default: none " + echo " --output= Output folder Default: $(output_folder)" + echo " --bundleio Support BundleIO using Devtools with Input/RefOutput included" exit 0 } for arg in "$@"; do case $arg in - -h|--help) help ;; - --build_type=*) build_type="${arg#*=}";; - --etdump) build_with_etdump=true ;; - --extra_build_flags=*) extra_build_flags="${arg#*=}";; - --output=*) output_folder="${arg#*=}";; - --select_ops_list=*) select_ops_list="${arg#*=}";; - *) - ;; + -h|--help) help ;; + --build_type=*) build_type="${arg#*=}";; + --etdump) build_with_etdump=true ;; + --extra_build_flags=*) extra_build_flags="${arg#*=}";; + --output=*) output_folder="${arg#*=}";; + --select_ops_list=*) select_ops_list="${arg#*=}";; + --bundleio) build_with_bundleio_flags="-DEXECUTORCH_ENABLE_BUNDLE_IO=ON" ;; + *) + ;; esac done @@ -50,25 +54,29 @@ done source ${setup_path_script} mkdir -p "${output_folder}" -output_folder=$(realpath ${output_folder}) - -echo "--------------------------------------------------------------------------------" -echo "Build Arm VKML executor runner: '${output_folder}' with extra build flags: ${extra_build_flags}" -echo "--------------------------------------------------------------------------------" +output_folder=$(realpath "${output_folder}") cd ${et_root_dir}/examples/arm/executor_runner if [ "$build_with_etdump" = true ] ; then - build_with_etdump_flags=" -DEXECUTORCH_ENABLE_EVENT_TRACER=ON " + build_with_etdump_flags="-DEXECUTORCH_ENABLE_EVENT_TRACER=ON" fi -echo "Building with extra flags: ${build_with_etdump_flags} ${extra_build_flags}" +echo "-----------------------------------------------------------------------------------------------" +echo "Build Arm VKML executor runner: '${output_folder}' with extra build flags: " +echo "${build_with_etdump_flags} ${build_with_bundleio_flags} ${extra_build_flags}" +echo "-----------------------------------------------------------------------------------------------" + cmake \ + -S "${et_root_dir}" \ + -B "${output_folder}" \ -Wall \ -Werror \ -DCMAKE_BUILD_TYPE=${build_type} \ + -DCMAKE_CXX_FLAGS="${extra_build_flags} ${CMAKE_CXX_FLAGS:-}" \ -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DEXECUTORCH_BUILD_XNNPACK=OFF \ @@ -77,13 +85,16 @@ cmake \ -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ -DEXECUTORCH_BUILD_KERNELS_QUANTIZED_AOT=ON \ -DEXECUTORCH_ENABLE_LOGGING=ON \ - -DPYTHON_EXECUTABLE=$(which python3) \ - ${extra_build_flags} \ - -B ${output_folder} ${et_root_dir} + -DEXECUTORCH_BUILD_DEVTOOLS=ON \ + -DPYTHON_EXECUTABLE="$(which python3)" \ + ${build_with_etdump_flags} \ + ${build_with_bundleio_flags} echo "[${BASH_SOURCE[0]}] Configured CMAKE" -cmake --build ${output_folder} -j$(nproc) +parallel_jobs="$(get_parallel_jobs)" + +cmake --build "${output_folder}" --parallel "${parallel_jobs}" echo "[${BASH_SOURCE[0]}] Built VKML runner: " find ${output_folder} -name "executor_runner" diff --git a/backends/arm/scripts/build_executorch.sh b/backends/arm/scripts/build_executorch.sh index 84c675ddb4a..b0a93c9540b 100755 --- a/backends/arm/scripts/build_executorch.sh +++ b/backends/arm/scripts/build_executorch.sh @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
# Optional parameter: -# --build_type= "Release" | "Debug" | "RelWithDebInfo" +# --build_type= "Release" | "Debug" | "RelWithDebInfo" | "UndefinedSanitizer" | "AddressSanitizer" # --etdump build with devtools-etdump support set -eu @@ -14,9 +14,11 @@ script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) et_root_dir=$(cd ${script_dir}/../../.. && pwd) et_root_dir=$(realpath ${et_root_dir}) toolchain=arm-none-eabi-gcc -setup_path_script=${et_root_dir}/examples/arm/ethos-u-scratch/setup_path.sh +setup_path_script=${et_root_dir}/examples/arm/arm-scratch/setup_path.sh _setup_msg="please refer to ${et_root_dir}/examples/arm/setup.sh to properly install necessary tools." +source "${script_dir}/utils.sh" + et_build_root="${et_root_dir}/arm_test" build_type="Release" build_devtools=OFF @@ -26,7 +28,7 @@ help() { echo "Usage: $(basename $0) [options]" echo "Options:" echo " --et_build_root= Build output root folder to use, defaults to ${et_build_root}" - echo " --build_type= Build with Release, Debug or RelWithDebInfo, default is ${build_type}" + echo " --build_type= Build with Release, Debug, RelWithDebInfo, UndefinedSanitizer or AddressSanitizer, default is ${build_type}" echo " --devtools Build Devtools libs" echo " --etdump Adds Devtools etdump support to track timing, etdump area will be base64 encoded in the log" echo " --toolchain= Toolchain can be specified (e.g. bare metal as arm-none-eabi-gcc or zephyr as arm-zephyr-eabi-gcc Default: ${toolchain}" @@ -76,12 +78,14 @@ cd "${et_root_dir}" # Build cmake -DCMAKE_TOOLCHAIN_FILE=${toolchain_cmake} \ --DCMAKE_BUILD_TYPE=Release \ +-DCMAKE_BUILD_TYPE=${build_type} \ -DEXECUTORCH_BUILD_DEVTOOLS=$build_devtools \ -DEXECUTORCH_BUILD_ARM_ETDUMP=$build_with_etdump \ --preset arm-baremetal -B${et_build_dir} -cmake --build ${et_build_dir} -j$(nproc) --target install --config ${build_type} -- +parallel_jobs="$(get_parallel_jobs)" + +cmake --build ${et_build_dir} -j"${parallel_jobs}" --target install --config ${build_type} -- set +x diff --git a/backends/arm/scripts/fvp_utils.sh b/backends/arm/scripts/fvp_utils.sh index cf0774e8706..c3205074004 100644 --- a/backends/arm/scripts/fvp_utils.sh +++ b/backends/arm/scripts/fvp_utils.sh @@ -23,6 +23,9 @@ if [[ $? -ne 0 ]]; then exit 1 fi +script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source "${script_dir}/utils.sh" + if [[ "${ARCH}" == "x86_64" ]]; then # FVPs corstone300_url="https://developer.arm.com/-/media/Arm%20Developer%20Community/Downloads/OSS/FVP/Corstone-300/FVP_Corstone_SSE-300_11.22_20_Linux64.tgz?rev=018659bd574f4e7b95fa647e7836ccf4&hash=22A79103C6FA5FFA7AFF3BE0447F3FF9" @@ -42,7 +45,8 @@ elif [[ "${ARCH}" == "aarch64" ]] || [[ "${ARCH}" == "arm64" ]]; then corstone320_model_dir="Linux64_armv8l_GCC-9.3" corstone320_md5_checksum="3889f1d80a6d9861ea4aa6f1c88dd0ae" else - echo "[main] Error: only x86-64 & aarch64/arm64 architecture is supported for now!"; exit 1; + log_step "fvp" "Error: only x86-64 & aarch64/arm64 architecture is supported for now!" + exit 1 fi function install_fvp() { @@ -52,7 +56,7 @@ function install_fvp() { for fvp in "${fvps[@]}"; do cd "${root_dir}" if [[ ! -e "FVP_${fvp}.tgz" ]]; then - echo "[${FUNCNAME[0]}] Downloading FVP ${fvp}..." + log_step "fvp" "Downloading FVP ${fvp}" url_variable=${fvp}_url fvp_url=${!url_variable} curl --output "FVP_${fvp}.tgz" "${fvp_url}" @@ -61,7 +65,7 @@ function install_fvp() { verify_md5 ${fvp_md5_checksum} FVP_${fvp}.tgz || exit 1 fi - echo "[${FUNCNAME[0]}] Installing FVP ${fvp}..." 
+ log_step "fvp" "Installing FVP ${fvp}" rm -rf FVP-${fvp} mkdir -p FVP-${fvp} cd FVP-${fvp} @@ -76,7 +80,7 @@ function install_fvp() { ./FVP_Corstone_SSE-320.sh --i-agree-to-the-contained-eula --force --destination ./ --quiet --no-interactive ;; *) - echo "[${FUNCNAME[0]}] Error: Unknown FVP model ${fvp}. Exiting." + log_step "fvp" "Error: Unknown FVP model ${fvp}. Exiting." exit 1 ;; esac @@ -89,12 +93,12 @@ function check_fvp_eula () { if [[ "${eula_acceptance}" -eq 0 ]]; then if [[ ${eula_acceptance_by_variable} != "True" ]]; then - echo "Must pass argument '--i-agree-to-the-contained-eula' to agree to EULA associated with downloading the FVP." - echo "Alternativly set environment variable ARM_FVP_INSTALL_I_AGREE_TO_THE_CONTAINED_EULA=True." - echo "Exiting!" + log_step "fvp" "Must pass '--i-agree-to-the-contained-eula' to download the FVP" + log_step "fvp" "Alternatively set ARM_FVP_INSTALL_I_AGREE_TO_THE_CONTAINED_EULA=True" + log_step "fvp" "Exiting due to missing EULA acceptance" exit 1 else - echo "Arm EULA for FVP agreed to with ARM_FVP_INSTALL_I_AGREE_TO_THE_CONTAINED_EULA=True environment variable" + log_step "fvp" "Arm EULA accepted via ARM_FVP_INSTALL_I_AGREE_TO_THE_CONTAINED_EULA=True" fi fi } @@ -103,12 +107,10 @@ function setup_fvp() { if [[ "${OS}" != "Linux" ]]; then # Check if FVP is callable if command -v FVP_Corstone_SSE-300_Ethos-U55 &> /dev/null; then - echo "[${FUNCNAME[0]}] Info: FVP for MacOS seem to be installed. Continuing..." + log_step "fvp" "Detected pre-installed MacOS FVP binaries; continuing" return 0 # If true exit gracefully and proceed with setup else - echo "[${FUNCNAME[0]}] Warning: FVP only supported with Linux OS, skipping FVP setup..." - echo "[${FUNCNAME[0]}] Warning: For MacOS, using https://github.com/Arm-Examples/FVPs-on-Mac is recommended." - echo "[${FUNCNAME[0]}] Warning: Follow the instructions and make sure the path is set correctly." + log_step "fvp" "Warning: FVP setup only supported on Linux; Mac users should install via https://github.com/Arm-Examples/FVPs-on-Mac and ensure binaries are on PATH" return 1 # Throw error. User need to install FVP according to ^^^ fi fi diff --git a/backends/arm/scripts/install_models_for_test.sh b/backends/arm/scripts/install_models_for_test.sh index 001d733a014..d6a7b9cdec0 100644 --- a/backends/arm/scripts/install_models_for_test.sh +++ b/backends/arm/scripts/install_models_for_test.sh @@ -6,3 +6,16 @@ set -e pip install -r backends/arm/requirements-arm-models-test.txt + +# Install model gym repository +git clone https://github.com/arm/neural-graphics-model-gym.git +cd neural-graphics-model-gym +# Remove model-converter installation from model-gym repository (to prevent overwriting executorch version) +if [[ "$(uname)" == "Darwin" ]]; then + sed -i '' 's/^model-converter = "ng_model_gym.bin.model_converter_launcher:main"/#&/' pyproject.toml +else + sed -i 's/^model-converter = "ng_model_gym.bin.model_converter_launcher:main"/#&/' pyproject.toml +fi +pip install . --no-deps +cd .. +rm -rf neural-graphics-model-gym \ No newline at end of file diff --git a/backends/arm/scripts/mlsdk_utils.sh b/backends/arm/scripts/mlsdk_utils.sh index 3129676ec4b..95aa5cf2a4f 100755 --- a/backends/arm/scripts/mlsdk_utils.sh +++ b/backends/arm/scripts/mlsdk_utils.sh @@ -6,52 +6,124 @@ set -euo pipefail -mlsdk_manifest_url="https://github.com/arm/ai-ml-sdk-manifest.git" +# URL and tag of the MLSDK manifest repository. Can be overridden by environment variables. +# eg. 
export MLSDK_MANIFEST_URL=...; export MLSDK_MANIFEST_TAG=... +mlsdk_manifest_url="${MLSDK_MANIFEST_URL:-https://github.com/arm/ai-ml-sdk-manifest.git}" +mlsdk_manifest_tag="${MLSDK_MANIFEST_TAG:-refs/tags/v2025.10.0}" script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) source ${script_dir}/utils.sh -usage() { echo "Usage: $0 [-u ]" 1>&2; exit 1; } +function mlsdk_sync_manifest() { + local manifest_dir="$1" -while getopts ":u:" opt; do - case "${opt}" in - u) - mlsdk_manifest_url=${OPTARG} - ;; - *) - usage - ;; - esac -done + mkdir -p "${manifest_dir}" + pushd "${manifest_dir}" || return 1 + local parallel_jobs="$(get_parallel_jobs)" + + if [[ ! -f repo ]]; then + log_step "mlsdk" "Fetching repo tool" + curl https://storage.googleapis.com/git-repo-downloads/repo > repo + chmod u+x repo + fi + + ./repo init \ + --depth=1 \ + --no-repo-verify \ + --manifest-url "${mlsdk_manifest_url}" \ + --manifest-branch "${mlsdk_manifest_tag}" \ + -g model-converter,emulation-layer,vgf-library + + local default_manifest=".repo/manifests/default.xml" + + # TODO: Remove this workaround once 2GB capable mlir translator is available + # in the official MLSDK repository. + if [[ "${OSTYPE:-}" == darwin* ]]; then + sed -i '' 's|revision="refs/tags/v2025.07.1"|revision="c3b324e643b4b4e592de8a9123a58c4179649d8c"|' "${default_manifest}" + else + sed -i 's|revision="refs/tags/v2025.07.1"|revision="c3b324e643b4b4e592de8a9123a58c4179649d8c"|' "${default_manifest}" + fi + ./repo sync --force-sync -j"${parallel_jobs}" + + popd +} function download_ai_mlsdk_manifest() { - local _dada_dir="$1" + local _manifest_dir="$1" - if [[ -z "${_dada_dir}" ]]; then - echo "Error: _dada_dir parameter missing?" + if [[ -z "${_manifest_dir}" ]]; then + log_step "mlsdk" "Error: _manifest_dir parameter missing" return 1 fi if [[ -z "${mlsdk_manifest_url}" ]]; then - echo "Error: mlsdk_manifest_url parameter missing?" + log_step "mlsdk" "Error: mlsdk_manifest_url parameter missing" return 1 fi - if [[ ! -d "${_dada_dir}" ]]; then - mkdir -p "$_dada_dir" - pushd "$_dada_dir" || exit 1 + if [[ ! -d "${_manifest_dir}/sw" ]] || [[ ! -d "${_manifest_dir}/dependencies" ]]; then + log_step "mlsdk" "MLSDK checkout not found at ${_manifest_dir}; performing initial download" + mlsdk_sync_manifest "${_manifest_dir}" + return 0 + fi - curl https://storage.googleapis.com/git-repo-downloads/repo > repo - chmod u+x repo - ./repo init --no-repo-verify --depth=1 --manifest-url ${mlsdk_manifest_url} -g model-converter,emulation-layer,vgf-library - ./repo sync + # If a checkout exists, get the URL and tag of the existing checkout. + local cached_url="" + local cached_tag="" + local repo_config="${_manifest_dir}/.repo/manifests.git/config" + if [[ -f "${repo_config}" ]]; then + cached_url="$(git config --file "${repo_config}" remote.origin.url 2>/dev/null || echo "")" + cached_tag="$(git config --file "${repo_config}" branch.default.merge 2>/dev/null || echo "")" + fi - popd + # If the tag is main or refs/heads/main, always refresh the checkout. + # This allows users to track the latest main branch without needing to manually + # delete the checkout. + local tag_tracks_main=0 + if [[ "${mlsdk_manifest_tag}" == "main" ]] || [[ "${mlsdk_manifest_tag}" == "refs/heads/main" ]]; then + tag_tracks_main=1 + fi + + # If the URL and tag match, and the tag does not track main, reuse the existing checkout. + # Skip fetching updates. 
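The reuse rules implemented here and in the branches just below boil down to: re-sync whenever the manifest URL or tag changed, or whenever the tag tracks main; otherwise reuse the cached checkout. A compact Python sketch of that decision, for illustration only:

```python
def should_refresh_checkout(cached_url, cached_tag, url, tag):
    """True if the MLSDK manifest checkout must be re-synced."""
    if tag in ("main", "refs/heads/main"):
        return True  # always refresh when tracking the moving main branch
    return (cached_url, cached_tag) != (url, tag)
```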
+ if [[ "${cached_url}" == "${mlsdk_manifest_url}" ]] && [[ "${cached_tag}" == "${mlsdk_manifest_tag}" ]] && [[ "${tag_tracks_main}" -eq 0 ]]; then + log_step "mlsdk" "Reusing cached MLSDK dependencies at ${_manifest_dir}" + return 0 + fi + + # If we reach here, either the URL or tag changed, or the tag tracks main. + # In all cases, refresh the checkout. + if [[ "${tag_tracks_main}" -eq 1 ]]; then + log_step "mlsdk" "Manifest tracks branch ${mlsdk_manifest_tag}; refreshing checkout" + else + log_step "mlsdk" "Manifest changed (url=${cached_url:-} -> ${mlsdk_manifest_url}, tag=${cached_tag:-} -> ${mlsdk_manifest_tag}); refreshing checkout" + fi + + # Clean up any local manifest changes to avoid repo sync errors. + # Since we patched in a local manifest for tosa_gitlab.xml, + # remove any existing local manifests to avoid conflicts. + # TODO: we should remove this at some point in the future, but its not hurting anything for now. + if [[ -d "${_manifest_dir}/.repo/local_manifests" ]]; then + rm -rf "${_manifest_dir}/.repo/local_manifests/" + fi + + # Clean up any local changes in the manifests repository. + if [[ -d "${_manifest_dir}/.repo/manifests.git" ]]; then + git -C "${_manifest_dir}/.repo/manifests.git" reset --hard HEAD >/dev/null 2>&1 || true + git -C "${_manifest_dir}/.repo/manifests.git" clean -fd >/dev/null 2>&1 || true fi + + # Clean up any local changes in the manifests working copy. + if [[ -d "${_manifest_dir}/.repo/manifests" ]]; then + git -C "${_manifest_dir}/.repo/manifests" reset --hard HEAD >/dev/null 2>&1 || true + git -C "${_manifest_dir}/.repo/manifests" clean -fd >/dev/null 2>&1 || true + fi + + mlsdk_sync_manifest "${_manifest_dir}" } -function setup_model_converter() { +function setup_mlsdk() { local work_dir="$1" local manifest_dir="$2" local enable_model_converter="$3" @@ -59,55 +131,44 @@ function setup_model_converter() { local enable_emulation_layer="$5" if [[ -z "$work_dir" ]]; then - echo "Error: work_dir parameter is required." + log_step "mlsdk" "Error: work_dir parameter is required" return 1 fi if [[ -z "$manifest_dir" ]]; then - echo "Error: manifest_dir parameter is required." + log_step "mlsdk" "Error: manifest_dir parameter is required" return 1 fi mkdir -p "$work_dir" pushd "$work_dir" || exit 1 - download_ai_mlsdk_manifest ${manifest_dir} + log_step "mlsdk" "Syncing MLSDK manifest into ${manifest_dir}" + download_ai_mlsdk_manifest "${manifest_dir}" pushd "$manifest_dir" + local parallel_jobs="$(get_parallel_jobs)" # model-converter if [[ "${enable_model_converter}" -eq 1 ]]; then - # TODO: Remove this workaround once MLSDK has full Darwin support - # Do not indent sed command, the whitespace is significant for the patch to work. - if [[ "$(uname)" == "Darwin" ]]; then - sed -i '' '/^ *print(f"Unsupported host platform/ i\ - if system == "Darwin":\ - return True\ -\ -' sw/model-converter/scripts/build.py - fi - python sw/model-converter/scripts/build.py -j$(nproc) + log_step "mlsdk" "Building MLSDK model-converter" + python sw/model-converter/scripts/build.py -j"${parallel_jobs}" + log_step "mlsdk" "MLSDK model-converter build complete" fi # libvgf if [[ "${enable_vgf_lib}" -eq 1 ]]; then - # TODO: Remove this workaround once MLSDK has full Darwin support - # Do not indent sed command, the whitespace is significant for the patch to work. 
- if [[ "$(uname)" == "Darwin" ]]; then - sed -i '' '/^ *print(f"ERROR: Unsupported host platform/ i\ - if system == "Darwin":\ - return True\ -\ -' sw/vgf-lib/scripts/build.py - fi + log_step "mlsdk" "Building MLSDK VGF library" pushd sw/vgf-lib - python scripts/build.py -j$(nproc) + python scripts/build.py -j"${parallel_jobs}" cmake --install build --prefix deploy + log_step "mlsdk" "MLSDK VGF library build complete" popd fi # emu layer if [[ "${enable_emulation_layer}" -eq 1 ]]; then + log_step "mlsdk" "Building MLSDK Vulkan emulation layer" pushd sw/emulation-layer cmake -B build \ -DGLSLANG_PATH=../../dependencies/glslang \ @@ -116,8 +177,9 @@ function setup_model_converter() { -DSPIRV_TOOLS_PATH=../../dependencies/SPIRV-Tools \ -DVULKAN_HEADERS_PATH=../../dependencies/Vulkan-Headers - cmake --build build + cmake --build build -j"${parallel_jobs}" cmake --install build --prefix deploy + log_step "mlsdk" "MLSDK Vulkan emulation layer build complete" popd fi @@ -126,27 +188,68 @@ function setup_model_converter() { function setup_path_model_converter() { cd "${root_dir}" - model_converter_bin_path="$(cd ${mlsdk_manifest_dir}/sw/model-converter/build && pwd)" - append_env_in_setup_path PATH ${model_converter_bin_path} + model_converter_bin_path="$(cd "${mlsdk_manifest_dir}/sw/model-converter/build" && pwd)" + append_env_in_setup_path PATH "${model_converter_bin_path}" } function setup_path_vgf_lib() { cd "${root_dir}" - model_vgf_path="$(cd ${mlsdk_manifest_dir}/sw/vgf-lib/deploy && pwd)" - append_env_in_setup_path PATH ${model_vgf_path}/bin + model_vgf_path="$(cd "${mlsdk_manifest_dir}/sw/vgf-lib/deploy" && pwd)" + append_env_in_setup_path PATH "${model_vgf_path}/bin" append_env_in_setup_path LD_LIBRARY_PATH "${model_vgf_path}/lib" append_env_in_setup_path DYLD_LIBRARY_PATH "${model_vgf_path}/lib" } function setup_path_emulation_layer() { cd "${root_dir}" - model_emulation_layer_path="$(cd ${mlsdk_manifest_dir}/sw/emulation-layer/ && pwd)" + model_emulation_layer_path="$(cd "${mlsdk_manifest_dir}/sw/emulation-layer/" && pwd)" prepend_env_in_setup_path LD_LIBRARY_PATH "${model_emulation_layer_path}/deploy/lib" prepend_env_in_setup_path DYLD_LIBRARY_PATH "${model_emulation_layer_path}/deploy/lib" + prepend_env_in_setup_path VK_LAYER_PATH "${model_emulation_layer_path}/deploy/share/vulkan/explicit_layer.d" prepend_env_in_setup_path VK_INSTANCE_LAYERS VK_LAYER_ML_Tensor_Emulation prepend_env_in_setup_path VK_INSTANCE_LAYERS VK_LAYER_ML_Graph_Emulation - prepend_env_in_setup_path VK_LAYER_PATH "${model_emulation_layer_path}/deploy/share/vulkan/explicit_layer.d" } -#setup_model_converter() $1 -# `"$manifest_dir"' +function setup_path_emulation_layer_from_pip() { + if ! command -v emulation_layer >/dev/null 2>&1; then + echo "[mlsdk_utils] 'emulation_layer' command not found; skipping pip emulation layer path setup" + return + fi + + local output + if ! 
output=$(emulation_layer 2>/dev/null); then + echo "[mlsdk_utils] Failed to query emulation_layer environment; skipping" + return + fi + + local exports + exports=$(echo "$output" | grep '^export ' || true) + + local ld_line + ld_line=$(echo "$exports" | grep 'LD_LIBRARY_PATH=' || true) + if [[ -n "${ld_line}" ]]; then + local ld_value=${ld_line#export LD_LIBRARY_PATH=} + ld_value=${ld_value%%:\$LD_LIBRARY_PATH*} + if [[ -n "${ld_value}" ]]; then + prepend_env_in_setup_path LD_LIBRARY_PATH "${ld_value}" + fi + fi + + local vk_add_line + vk_add_line=$(echo "$exports" | grep 'VK_ADD_LAYER_PATH=' || true) + if [[ -n "${vk_add_line}" ]]; then + local vk_add_value=${vk_add_line#export VK_ADD_LAYER_PATH=} + if [[ -n "${vk_add_value}" ]]; then + prepend_env_in_setup_path VK_ADD_LAYER_PATH "${vk_add_value}" + fi + fi + + local vk_instance_line + vk_instance_line=$(echo "$exports" | grep 'VK_INSTANCE_LAYERS=' || true) + if [[ -n "${vk_instance_line}" ]]; then + local vk_instance_value=${vk_instance_line#export VK_INSTANCE_LAYERS=} + if [[ -n "${vk_instance_value}" ]]; then + prepend_env_in_setup_path VK_INSTANCE_LAYERS "${vk_instance_value}" + fi + fi +} diff --git a/backends/arm/scripts/parse_test_names.py b/backends/arm/scripts/parse_test_names.py index c6eaafa597b..f570fd222eb 100644 --- a/backends/arm/scripts/parse_test_names.py +++ b/backends/arm/scripts/parse_test_names.py @@ -7,6 +7,7 @@ # Add edge ops which we lower but which are not included in exir/dialects/edge/edge.yaml here. CUSTOM_EDGE_OPS = [ "linspace.default", + "cond.default", "eye.default", "expm1.default", "vector_norm.default", @@ -14,23 +15,40 @@ "hardswish.default", "linear.default", "maximum.default", + "mean.default", "multihead_attention.default", "adaptive_avg_pool2d.default", "bitwise_right_shift.Tensor", + "bitwise_right_shift.Scalar", "bitwise_left_shift.Tensor", + "bitwise_left_shift.Scalar", "native_group_norm.default", "silu.default", "sdpa.default", + "sum.default", "unbind.int", "unflatten.int", "_native_batch_norm_legit_no_training.default", "_native_batch_norm_legit.no_stats", "alias_copy.default", + "pixel_shuffle.default", + "pixel_unshuffle.default", + "while_loop.default", + "clamp.Tensor", ] ALL_EDGE_OPS = SAMPLE_INPUT.keys() | CUSTOM_EDGE_OPS # Add all targets and TOSA profiles we support here. 
-TARGETS = ["tosa_FP", "tosa_INT", "u55_INT", "u85_INT", "vgf_INT", "vgf_FP"] +TARGETS = [ + "tosa_FP", + "tosa_INT", + "u55_INT", + "u85_INT", + "vgf_INT", + "vgf_FP", + "vgf_quant", + "vgf_no_quant", +] def get_op_name_map(): @@ -94,6 +112,10 @@ def parse_test_name( # Special case for convolution op = op.removesuffix("_1d") op = op.removesuffix("_2d") + op = op.removesuffix("_3d") + + # Remove suffix for 16 bit activation and 8 bit weight test cases + op = op.removesuffix("_16a8w") assert target != "None", f"{test_name} does not contain one of {TARGETS}" assert ( diff --git a/backends/arm/scripts/pre-push b/backends/arm/scripts/pre-push index a4e877fdcfc..17c6dc04ef0 100755 --- a/backends/arm/scripts/pre-push +++ b/backends/arm/scripts/pre-push @@ -33,7 +33,7 @@ VERBS="Add|Fix|Update|Refactor|Improve|Remove|Change|Implement|Create|Modify|"\ "Handle|Ignore|Interpret|Instantiate|Invoke|Limit|Load|Modify|Permit|Print|"\ "Profile|Recalculate|Reconstruct|Redefine|Redesign|Reevaluate|Relocate|Remap|"\ "Render|Reposition|Request|Revert|Sanitize|Specify|Strengthen|Stub|Substitute|"\ -"Tag|Tweak|Unify|Unlock|Unset|Use|Validate|Verify|Rename" +"Tag|Tweak|Unify|Unlock|Unset|Use|Validate|Verify|Rename|Relax" # Remote branch REMOTE=$(git rev-parse --abbrev-ref --symbolic-full-name @{u} 2>/dev/null) @@ -68,6 +68,11 @@ if [ -z "$COMMITS" ]; then fi for COMMIT in ${COMMITS}; do + if [ -n "$REMOTE" ] && git merge-base --is-ancestor "$COMMIT" "$REMOTE"; then + echo -e "${INFO} Skipping commit ${COMMIT} (already on $REMOTE)" + continue + fi + # If commit header contains WIP, everything is ok git rev-list --format=%s --max-count=1 ${COMMIT} | grep -q WIP && \ continue @@ -95,8 +100,14 @@ for COMMIT in ${COMMITS}; do commit_files=$(git diff-tree --no-commit-id --name-only \ --diff-filter=ACMR ${COMMIT} -r) for commited_file in $commit_files; do - head $commited_file | grep -q "$current_year Arm" - if [[ $? != 0 ]]; then + file_header=$(head "$commited_file") + if ! echo "$file_header" | grep -qi "Arm"; then + echo -e "${WARNING} No Arm copyright header in ${commited_file}"\ + " (skipping license year check)" + continue + fi + + if ! echo "$file_header" | grep -q "$current_year Arm"; then echo -e "${ERROR} Header in $commited_file did not contain"\ "'$current_year Arm'" failed_license_check=true diff --git a/backends/arm/scripts/run_fvp.sh b/backends/arm/scripts/run_fvp.sh index 769b2e30282..7dcebb2d859 100755 --- a/backends/arm/scripts/run_fvp.sh +++ b/backends/arm/scripts/run_fvp.sh @@ -13,7 +13,7 @@ set -eu script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) et_root_dir=$(cd ${script_dir}/../../.. && pwd) et_root_dir=$(realpath ${et_root_dir}) -setup_path_script=${et_root_dir}/examples/arm/ethos-u-scratch/setup_path.sh +setup_path_script=${et_root_dir}/examples/arm/arm-scratch/setup_path.sh _setup_msg="please refer to ${et_root_dir}/examples/arm/setup.sh to properly install necessary tools." @@ -21,6 +21,8 @@ elf_file="" data_file="" target="ethos-u55-128" timeout="600" +etrecord_file="" +trace_file="" help() { echo "Usage: $(basename $0) [options]" @@ -29,6 +31,8 @@ help() { echo " --data=@
Place a file in memory at this address, useful to emulate a PTE flashed into memory instead as part of the code." echo " --target= Target to build and run for Default: ${target}" echo " --timeout= Maximum target runtime, used to detect hanging, might need to be higer on large models Default: ${timeout}" + echo " --etrecord= If ETDump is used you can supply a ETRecord file matching the PTE" + echo " --trace_file= File to write PMU trace output to" exit 0 } @@ -39,6 +43,8 @@ for arg in "$@"; do --data=*) data_file="--data ${arg#*=}";; --target=*) target="${arg#*=}";; --timeout=*) timeout="${arg#*=}";; + --etrecord=*) etrecord_file="${arg#*=}";; + --trace_file=*) trace_file="${arg#*=}";; *) ;; esac @@ -83,6 +89,14 @@ fi log_file=$(mktemp) +extra_args_u55=() +extra_args_u85=() + +if [[ -n "${trace_file}" ]]; then + extra_args_u55+=(-C "ethosu.extra_args=--pmu-trace ${trace_file}") + extra_args_u85+=(-C "mps4_board.subsystem.ethosu.extra_args=--pmu-trace ${trace_file}") +fi + if [[ ${target} == *"ethos-u55"* ]]; then ${nobuf} ${fvp_model} \ -C ethosu.num_macs=${num_macs} \ @@ -90,6 +104,7 @@ if [[ ${target} == *"ethos-u55"* ]]; then -C mps3_board.telnetterminal0.start_telnet=0 \ -C mps3_board.uart0.out_file='-' \ -C mps3_board.uart0.shutdown_on_eot=1 \ + "${extra_args_u55[@]}" \ -a "${elf_file}" \ ${data_file} \ --timelimit ${timeout} 2>&1 | sed 's/\r$//' | tee ${log_file} || true # seconds @@ -102,6 +117,7 @@ elif [[ ${target} == *"ethos-u85"* ]]; then -C mps4_board.telnetterminal0.start_telnet=0 \ -C mps4_board.uart0.out_file='-' \ -C mps4_board.uart0.shutdown_on_eot=1 \ + "${extra_args_u85[@]}" \ -a "${elf_file}" \ ${data_file} \ --timelimit ${timeout} 2>&1 | sed 's/\r$//' | tee ${log_file} || true # seconds @@ -115,15 +131,23 @@ echo "Checking for a etdump in log" ! grep "#\[RUN THIS\]" ${log_file} >/dev/null if [ $? != 0 ]; then echo "Found ETDump in log!" + devtools_extra_args="" echo "#!/bin/sh" > etdump_script.sh sed -n '/^#\[RUN THIS\]$/,/^#\[END\]$/p' ${log_file} >> etdump_script.sh # You can run etdump_script.sh if you do # $ chmod a+x etdump_script.sh # $ ./etdump_script.sh # But lets not trust the script as a bad patch would run bad code on your machine - grep ">etdump.bin" etdump_script.sh | cut -d\" -f2- | cut -d\" -f1 >etdump.base64 - base64 -d etdump.base64 >etdump.bin - python3 -m devtools.inspector.inspector_cli --etdump_path etdump.bin --source_time_scale cycles --target_time_scale cycles + grep ">etdump.bin" etdump_script.sh | cut -d\" -f2- | cut -d\" -f1 | base64 -d >etdump.bin + ! grep ">debug_buffer.bin" etdump_script.sh >/dev/null + if [ $? != 0 ]; then + grep ">debug_buffer.bin" etdump_script.sh | cut -d\" -f2- | cut -d\" -f1 | base64 -d >debug_buffer.bin + devtools_extra_args="${devtools_extra_args} --debug_buffer_path debug_buffer.bin" + fi + if [[ ${etrecord_file} != "" ]]; then + devtools_extra_args="${devtools_extra_args} --etrecord_path ${etrecord_file}" + fi + python3 -m devtools.inspector.inspector_cli --etdump_path etdump.bin ${devtools_extra_args} --source_time_scale cycles --target_time_scale cycles fi echo "Checking for problems in log:" diff --git a/backends/arm/scripts/run_vkml.sh b/backends/arm/scripts/run_vkml.sh index 8a64a937638..d65600e7eff 100755 --- a/backends/arm/scripts/run_vkml.sh +++ b/backends/arm/scripts/run_vkml.sh @@ -14,11 +14,12 @@ set -o pipefail script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) et_root_dir=$(cd ${script_dir}/../../.. 
&& pwd) et_root_dir=$(realpath ${et_root_dir}) -setup_path_script=${et_root_dir}/examples/arm/ethos-u-scratch/setup_path.sh +setup_path_script=${et_root_dir}/examples/arm/arm-scratch/setup_path.sh _setup_msg="please refer to ${et_root_dir}/examples/arm/setup.sh to properly install necessary tools." model="" +opt_flags="" build_path="cmake-out-vkml" converter="model-converter" @@ -33,6 +34,7 @@ help() { for arg in "$@"; do case $arg in -h|--help) help ;; + --optional_flags=*) opt_flags="${arg#*=}";; --model=*) model="${arg#*=}";; --build_path=*) build_path="${arg#*=}";; *) @@ -50,16 +52,21 @@ if [[ -z ${model} ]]; then echo "Model name needs to be provided"; exit 1; fi source ${setup_path_script} -# basic checks before we get started -hash ${converter} \ - || { echo "Could not find ${converter} on PATH, ${_setup_msg}"; exit 1; } +if ! command -v "${converter}" >/dev/null 2>&1; then + if command -v model_converter >/dev/null 2>&1; then + converter="model_converter" + fi +fi + +command -v "${converter}" >/dev/null 2>&1 \ + || { echo "Could not find a model converter executable (tried model-converter, model_converter). ${_setup_msg}"; exit 1; } +runner=$(find ${build_path} -name executor_runner -type f) -runner="${build_path}/executor_runner" echo "--------------------------------------------------------------------------------" -echo "Running ${model} with ${runner}" +echo "Running ${model} with ${runner} ${opt_flags}" echo "WARNING: The VK_ML layer driver will not provide accurate performance information" echo "--------------------------------------------------------------------------------" @@ -75,7 +82,7 @@ fi log_file=$(mktemp) -${nobuf} ${runner} -model_path ${model} | tee ${log_file} +${nobuf} ${runner} -model_path ${model} ${opt_flags} | tee ${log_file} echo "[${BASH_SOURCE[0]}] execution complete, $?" # Most of these can happen for bare metal or linx executor_runner runs. diff --git a/backends/arm/scripts/toolchain_utils.sh b/backends/arm/scripts/toolchain_utils.sh index 161bfd29cd1..d9e1cf0f0ca 100644 --- a/backends/arm/scripts/toolchain_utils.sh +++ b/backends/arm/scripts/toolchain_utils.sh @@ -17,6 +17,9 @@ if [[ $? 
-ne 0 ]]; then exit 1 fi +script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source "${script_dir}/utils.sh" + function gcc_select_toolchain() { if [[ "${ARCH}" == "x86_64" ]] ; then toolchain_url="https://armkeil.blob.core.windows.net/developer/Files/downloads/gnu/13.3.rel1/binrel/arm-gnu-toolchain-13.3.rel1-x86_64-arm-none-eabi.tar.xz" @@ -34,13 +37,15 @@ function gcc_select_toolchain() { fi else # This should never happen, it should be covered by setup.sh but catch it anyway - echo "[gcc_select_toolchain]: Unsupported arch!"; exit 1 + log_step "toolchain" "Error: Unsupported architecture ${ARCH}" + exit 1 fi } function zephyr_select_toolchain() { if [[ "${OS}" != "Linux" ]] ; then - echo "[zephyr_select_toolchain] Error: Linux is only supported for zephyr!"; exit 1; + log_step "toolchain" "Error: Linux is required for Zephyr toolchain support" + exit 1 fi if [[ "${ARCH}" == "x86_64" ]] ; then @@ -53,7 +58,8 @@ function zephyr_select_toolchain() { toolchain_md5_checksum="ef4ca56786204439a75270ba800cc64b" else # This should never happen, it should be covered by setup.sh but catch it anyway - echo "[zephyr_select_toolchain]: Unsupported arch!"; exit 1 + log_step "toolchain" "Error: Unsupported architecture ${ARCH}" + exit 1 fi } @@ -63,7 +69,7 @@ function select_toolchain() { else gcc_select_toolchain fi - echo "[main] Info selected ${toolchain_dir} for ${ARCH} - ${OS} platform" + log_step "toolchain" "Selected ${toolchain_dir} for ${ARCH}/${OS}" } function setup_toolchain() { @@ -71,12 +77,12 @@ function setup_toolchain() { # setting --target-toolchain to zephyr sets this to arm-zephyr-eabi cd "${root_dir}" if [[ ! -e "${toolchain_dir}.tar.xz" ]]; then - echo "[${FUNCNAME[0]}] Downloading ${toolchain_dir} toolchain ..." + log_step "toolchain" "Downloading ${toolchain_dir} toolchain" curl --output "${toolchain_dir}.tar.xz" -L "${toolchain_url}" verify_md5 ${toolchain_md5_checksum} "${toolchain_dir}.tar.xz" || exit 1 fi - echo "[${FUNCNAME[0]}] Installing ${toolchain_dir} toolchain ..." + log_step "toolchain" "Installing ${toolchain_dir} toolchain" rm -rf "${toolchain_dir}" tar xf "${toolchain_dir}.tar.xz" } diff --git a/backends/arm/scripts/utils.sh b/backends/arm/scripts/utils.sh index d0c1dadbb3e..1b4a9205300 100644 --- a/backends/arm/scripts/utils.sh +++ b/backends/arm/scripts/utils.sh @@ -14,6 +14,33 @@ if [[ $? -ne 0 ]]; then exit 1 fi +# Usage: +# log_step +# eg. +# log_step "step" "information message" +# outputs: +# [setup/step] information message +function log_step() { + local context="${1:-main}" + shift || true + local message="$*" + printf "[Arm Setup/%s] %s\n" "${context}" "${message}" +} + +function get_parallel_jobs() { + if command -v nproc >/dev/null 2>&1; then + nproc + elif command -v sysctl >/dev/null 2>&1 && sysctl hw.logicalcpu >/dev/null 2>&1; then + sysctl -n hw.logicalcpu + elif command -v getconf >/dev/null 2>&1; then + getconf _NPROCESSORS_ONLN + elif [[ -n "${NUMBER_OF_PROCESSORS:-}" ]]; then + echo "${NUMBER_OF_PROCESSORS}" + else + echo 1 + fi +} + function verify_md5() { # Compare the md5 of a file with a provided expected value. diff --git a/backends/arm/scripts/vulkan_utils.sh b/backends/arm/scripts/vulkan_utils.sh index c22d8f26a0c..127ef33741e 100644 --- a/backends/arm/scripts/vulkan_utils.sh +++ b/backends/arm/scripts/vulkan_utils.sh @@ -14,6 +14,9 @@ if [[ $? 
-ne 0 ]]; then exit 1 fi +script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source "${script_dir}/utils.sh" + vulkan_sdk_version="1.4.321.1" vulkan_sdk_base_dir="vulkan_sdk" @@ -32,21 +35,21 @@ elif [[ "${ARCH}" == "aarch64" ]] || [[ "${ARCH}" == "arm64" ]]; then vulkan_sdk_url="https://github.com/jakoch/vulkan-sdk-arm/releases/download/1.4.321.1/vulkansdk-ubuntu-22.04-arm-1.4.321.1.tar.xz" vulkan_sdk_sha256="c57e318d0940394d3a304034bb7ddabda788b5b0b54638e80e90f7264efe9f84" else - echo "[main] Error: only x86-64 & aarch64/arm64 architecture is supported for now!"; exit 1; + log_step "vulkan" "Error: only x86-64 & aarch64/arm64 architecture is supported for now!" + exit 1 fi function setup_vulkan_sdk() { - if command -v glslc > /dev/null 2>&1; then - echo "[${FUNCNAME[0]}] GLSL already installed, no need to get Vulkan SDK..." - enable_vulkan_sdk=0 + cd "${root_dir}" + + if command -v glslc >/dev/null 2>&1; then + log_step "vulkan" "Detected existing GLSLC; skipping Vulkan SDK download" return fi - cd "${root_dir}" - vulkan_sdk_tar_file="${vulkan_sdk_url##*/}" if [[ ! -e "${vulkan_sdk_tar_file}" ]]; then - echo "[${FUNCNAME[0]}] Downloading Vulkan SDK - ${vulkan_sdk_url}.." + log_step "vulkan" "Downloading Vulkan SDK (${vulkan_sdk_version})" curl -L --output "${vulkan_sdk_tar_file}" "${vulkan_sdk_url}" echo "${vulkan_sdk_sha256} ${vulkan_sdk_tar_file}" | sha256sum -c - rm -fr ${vulkan_sdk_base_dir} @@ -57,9 +60,9 @@ function setup_vulkan_sdk() { vulkan_sdk_bin_path="$(cd ${vulkan_sdk_bin_dir} && pwd)" if ${vulkan_sdk_bin_path}/glslc --version > /dev/null 2>&1; then - echo "[${FUNCNAME[0]}] Vulkan SDK install (GLSL) OK" + log_step "vulkan" "Vulkan SDK validation (glslc) succeeded" else - echo "[${FUNCNAME[0]}] Vulkan SDK install NOK - glslc returned error" + log_step "vulkan" "Error: Vulkan SDK validation failed" ${vulkan_sdk_bin_path}/glslc --version exit 1 fi @@ -67,6 +70,22 @@ function setup_vulkan_sdk() { function setup_path_vulkan() { cd "${root_dir}" + if [[ ! -d "${root_dir}/${vulkan_sdk_bin_dir}" ]]; then + log_step "vulkan" "Vulkan SDK not found; skipping PATH update" + return + fi + + local vulkan_sdk_arch_root="${vulkan_sdk_base_dir}/${vulkan_sdk_version}/${ARCH}" + + if [[ ! -d "${vulkan_sdk_arch_root}" ]]; then + log_step "vulkan" "Vulkan SDK arch path not found; skipping PATH update" + return + fi + + vulkan_sdk_arch_root="$(cd ${vulkan_sdk_arch_root} && pwd)" vulkan_sdk_bin_path="$(cd ${vulkan_sdk_bin_dir} && pwd)" + append_env_in_setup_path PATH ${vulkan_sdk_bin_path} + prepend_env_in_setup_path LD_LIBRARY_PATH "${vulkan_sdk_arch_root}/lib" + prepend_env_in_setup_path VULKAN_SDK "${vulkan_sdk_arch_root}" } diff --git a/backends/arm/test/TARGETS b/backends/arm/test/TARGETS index c27d00590f3..fd7d894fbf0 100644 --- a/backends/arm/test/TARGETS +++ b/backends/arm/test/TARGETS @@ -1,3 +1,8 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") load(":targets.bzl", "define_arm_tests") @@ -19,7 +24,11 @@ runtime.python_library( srcs = ["runner_utils.py"], deps = [ ":conftest", - "//executorch/backends/arm:arm_backend", + "//executorch/backends/arm:arm_compile_spec", + "//executorch/backends/arm:ethosu", + "//executorch/backends/arm/tosa:compile_spec", + "//executorch/backends/arm:vgf", + "//executorch/backends/arm/tosa:specification", "//executorch/exir:lib", "//executorch/exir/backend:compile_spec_schema", ] @@ -36,19 +45,38 @@ runtime.python_library( ) runtime.python_library( - name = "arm_tester", - srcs = glob(["tester/*.py"]), + name = "arm_tester_serialize", + srcs = ["tester/serialize.py"], + deps = [ + "//executorch/backends/xnnpack/test/tester:tester", + "//executorch/devtools/backend_debug:delegation_info", + ] +) + +runtime.python_library( + name = "arm_tester_lib", + srcs = glob(["tester/*.py"], exclude = ["tester/serialize.py"]), deps = [ ":common", "//executorch/backends/xnnpack/test/tester:tester", - "//executorch/backends/arm:ethosu_partitioner", + "//executorch/backends/arm:ethosu", "//executorch/backends/arm/quantizer:lib", "//executorch/backends/arm/tosa:mapping", - "//executorch/backends/arm:vgf_partitioner", + "//executorch/backends/arm:vgf", + "//executorch/backends/arm:_factory", "//executorch/devtools/backend_debug:delegation_info", "//executorch/exir/backend:operator_support", "fbsource//third-party/pypi/tabulate:tabulate", ] ) + +runtime.python_library( + name = "arm_tester", + deps = [ + "//executorch/backends/arm/test:arm_tester_lib", + "//executorch/backends/arm/test:arm_tester_serialize", + ] +) + define_arm_tests() diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index 608c273b2ef..c2522941215 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -10,10 +10,11 @@ from datetime import datetime from pathlib import Path -from typing import Any, Optional +from typing import Any, Callable, Optional, ParamSpec, TypeVar import pytest -from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder +from executorch.backends.arm.ethosu import EthosUCompileSpec + from executorch.backends.arm.test.runner_utils import ( arm_executor_runner_exists, corstone300_installed, @@ -22,7 +23,8 @@ vkml_emulation_layer_installed, ) from executorch.backends.arm.tosa import TosaSpecification -from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec +from executorch.backends.arm.vgf import VgfCompileSpec def get_time_formatted_path(path: str, log_prefix: str) -> str: @@ -65,116 +67,34 @@ def maybe_get_tosa_collate_path() -> str | None: def get_tosa_compile_spec( tosa_spec: str | TosaSpecification, custom_path: Optional[str] = None, - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, -) -> list[CompileSpec]: - """ - Default compile spec for TOSA tests. - """ - return get_tosa_compile_spec_unbuilt( - tosa_spec, - custom_path, - tosa_debug_mode, - ).build() - - -def get_tosa_compile_spec_unbuilt( - tosa_spec: str | TosaSpecification, - custom_path: Optional[str], - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode], -) -> ArmCompileSpecBuilder: - """Get the ArmCompileSpecBuilder for the default TOSA tests, to modify - the compile spec before calling .build() to finalize it. 
- """ + tosa_debug_mode: TosaCompileSpec.DebugMode | None = None, +) -> TosaCompileSpec: + """Get the compile spec for default TOSA tests.""" if not custom_path: custom_path = maybe_get_tosa_collate_path() - if custom_path is not None: os.makedirs(custom_path, exist_ok=True) - compile_spec_builder = ( - ArmCompileSpecBuilder() - .tosa_compile_spec(tosa_spec) + compile_spec = ( + TosaCompileSpec(tosa_spec) .dump_intermediate_artifacts_to(custom_path) + .dump_debug_info(tosa_debug_mode) ) - - if tosa_debug_mode is not None: - compile_spec_builder.dump_debug_info(tosa_debug_mode) - - return compile_spec_builder + return compile_spec def get_u55_compile_spec( macs: int = 128, system_config: str = "Ethos_U55_High_End_Embedded", memory_mode: str = "Shared_Sram", - extra_flags: str = "--debug-force-regor --output-format=raw", - custom_path: Optional[str] = None, - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, - config: Optional[str] = "Arm/vela.ini", -) -> list[CompileSpec]: - """ - Compile spec for Ethos-U55. - """ - return get_u55_compile_spec_unbuilt( - macs=macs, - system_config=system_config, - memory_mode=memory_mode, - extra_flags=extra_flags, - custom_path=custom_path, - tosa_debug_mode=tosa_debug_mode, - config=config, - ).build() - - -def get_u85_compile_spec( - macs: int = 128, - system_config: str = "Ethos_U85_SYS_DRAM_Mid", - memory_mode: str = "Shared_Sram", - extra_flags: str = "--output-format=raw", + extra_flags: str = "--debug-force-regor --output-format=raw --arena-cache-size=2097152", custom_path: Optional[str] = None, - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, - config: Optional[str] = "Arm/vela.ini", -) -> list[CompileSpec]: - """ - Compile spec for Ethos-U85. - """ - return get_u85_compile_spec_unbuilt( # type: ignore[attr-defined] - macs=macs, - system_config=system_config, - memory_mode=memory_mode, - extra_flags=extra_flags, - custom_path=custom_path, - tosa_debug_mode=tosa_debug_mode, - config=config, - ).build() - - -def get_vgf_compile_spec( - tosa_spec: str | TosaSpecification, - compiler_flags: Optional[str] = "", - custom_path: Optional[str] = "", - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, -) -> list[CompileSpec]: - """ - Default compile spec for VGF tests. - """ - return get_vgf_compile_spec_unbuilt( - tosa_spec, compiler_flags, custom_path, tosa_debug_mode - ).build() - - -def get_u55_compile_spec_unbuilt( - macs: int, - system_config: str, - memory_mode: str, - extra_flags: str, - custom_path: Optional[str], - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode], - config: Optional[str], -) -> ArmCompileSpecBuilder: - """Get the ArmCompileSpecBuilder for the Ethos-U55 tests, to modify - the compile spec before calling .build() to finalize it. 
- """ + config: Optional[str] = None, + tosa_debug_mode: EthosUCompileSpec.DebugMode | None = None, +) -> EthosUCompileSpec: + """Default compile spec for Ethos-U55 tests.""" + if not custom_path: + custom_path = maybe_get_tosa_collate_path() artifact_path = custom_path or tempfile.mkdtemp(prefix="arm_u55_") if not os.path.exists(artifact_path): os.makedirs(artifact_path, exist_ok=True) @@ -182,89 +102,105 @@ def get_u55_compile_spec_unbuilt( # https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela/-/blob/main/OPTIONS.md assert macs in [32, 64, 128, 256], "Unsupported MACs value" + if extra_flags is not None: + extra_flags_list = extra_flags.split(" ") + else: + extra_flags_list = [] compile_spec = ( - ArmCompileSpecBuilder() - .ethosu_compile_spec( + EthosUCompileSpec( f"ethos-u55-{macs}", system_config=system_config, memory_mode=memory_mode, - extra_flags=extra_flags, + extra_flags=extra_flags_list, config_ini=config, ) .dump_intermediate_artifacts_to(artifact_path) + .dump_debug_info(tosa_debug_mode) ) - - if tosa_debug_mode is not None: - compile_spec.dump_debug_info(tosa_debug_mode) - return compile_spec -def get_u85_compile_spec_unbuilt( - macs: int, - system_config: str, - memory_mode: str, - extra_flags: str, - custom_path: Optional[str], - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode], - config: Optional[str], -) -> list[CompileSpec]: - """Get the ArmCompileSpecBuilder for the Ethos-U85 tests, to modify - the compile spec before calling .build() to finalize it. - """ +def get_u85_compile_spec( + macs: int = 128, + system_config="Ethos_U85_SYS_DRAM_Mid", + memory_mode="Shared_Sram", + extra_flags="--output-format=raw --arena-cache-size=2097152", + custom_path: Optional[str] = None, + config: Optional[str] = None, + tosa_debug_mode: EthosUCompileSpec.DebugMode | None = None, +) -> EthosUCompileSpec: + """Default compile spec for Ethos-U85 tests.""" + + if not custom_path: + custom_path = maybe_get_tosa_collate_path() artifact_path = custom_path or tempfile.mkdtemp(prefix="arm_u85_") if not os.path.exists(artifact_path): os.makedirs(artifact_path, exist_ok=True) assert macs in [128, 256, 512, 1024, 2048], "Unsupported MACs value" + if extra_flags is not None: + extra_flags_list = extra_flags.split(" ") + else: + extra_flags_list = [] + compile_spec = ( - ArmCompileSpecBuilder() - .ethosu_compile_spec( + EthosUCompileSpec( f"ethos-u85-{macs}", system_config=system_config, memory_mode=memory_mode, - extra_flags=extra_flags, + extra_flags=extra_flags_list, config_ini=config, ) .dump_intermediate_artifacts_to(artifact_path) + .dump_debug_info(tosa_debug_mode) ) - - if tosa_debug_mode is not None: - compile_spec.dump_debug_info(tosa_debug_mode) - return compile_spec # type: ignore[return-value] -def get_vgf_compile_spec_unbuilt( +def get_vgf_compile_spec( tosa_spec: str | TosaSpecification, - compiler_flags: Optional[str], - custom_path: Optional[str], - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode], -) -> ArmCompileSpecBuilder: - """Get the ArmCompileSpecBuilder for the default VGF tests, to modify + compiler_flags: Optional[str] = "", + custom_path: Optional[str] = None, + tosa_debug_mode: VgfCompileSpec.DebugMode | None = None, +) -> VgfCompileSpec: + """Get the ArmCompileSpec for the default VGF tests, to modify the compile spec before calling .build() to finalize it. 
""" + + if not custom_path: + custom_path = maybe_get_tosa_collate_path() + profiles = [] if "FP" in repr(tosa_spec): - artifact_path = custom_path or tempfile.mkdtemp(prefix="arm_vgf_fp_") - elif "INT" in repr(tosa_spec): - artifact_path = custom_path or tempfile.mkdtemp(prefix="arm_vgf_int_") - else: + profiles.append("fp") + if "INT" in repr(tosa_spec): + profiles.append("int") + if len(profiles) == 0: raise ValueError(f"Unsupported vgf compile_spec: {repr(tosa_spec)}") + if custom_path is None: + artifact_path = "arm_vgf_" + for profile in profiles: + artifact_path = artifact_path + f"_{profile}" + artifact_path = tempfile.mkdtemp(artifact_path) + else: + artifact_path = custom_path + if not os.path.exists(artifact_path): os.makedirs(artifact_path, exist_ok=True) - compile_spec_builder = ( - ArmCompileSpecBuilder() - .vgf_compile_spec(tosa_spec, compiler_flags) + if compiler_flags is not None: + compiler_flags_list = compiler_flags.split(" ") + else: + compiler_flags_list = [] + + compile_spec = ( + VgfCompileSpec(tosa_spec, compiler_flags_list) .dump_intermediate_artifacts_to(artifact_path) + .dump_debug_info(tosa_debug_mode) ) - if tosa_debug_mode is not None: - compile_spec_builder.dump_debug_info(tosa_debug_mode) - - return compile_spec_builder + return compile_spec XfailIfNoCorstone300 = pytest.mark.xfail( @@ -285,7 +221,7 @@ def get_vgf_compile_spec_unbuilt( ) """Xfails a test if Corsone320 FVP is not installed, or if the executor runner is not built""" -SkipIfNoModelConverter = pytest.mark.skipif( +SkipIfNoModelConverter = pytest.mark.skipif( # type: ignore[call-arg] condition=not (model_converter_installed()), raises=FileNotFoundError, reason="Did not find model-converter on path", @@ -301,13 +237,19 @@ def get_vgf_compile_spec_unbuilt( xfail_type = str | tuple[str, type[Exception]] +_P = ParamSpec("_P") +_R = TypeVar("_R") +Decorator = Callable[[Callable[_P, _R]], Callable[_P, _R]] + def parametrize( arg_name: str, test_data: dict[str, Any], xfails: dict[str, xfail_type] | None = None, + skips: dict[str, str] | None = None, strict: bool = True, -): + flakies: dict[str, int] | None = None, +) -> Decorator: """ Custom version of pytest.mark.parametrize with some syntatic sugar and added xfail functionality - test_data is expected as a dict of (id, test_data) pairs @@ -317,12 +259,22 @@ def parametrize( """ if xfails is None: xfails = {} + if skips is None: + skips = {} + if flakies is None: + flakies = {} - def decorator_func(func): + def decorator_func(func: Callable[_P, _R]) -> Callable[_P, _R]: """Test data is transformed from a dict of (id, data) pairs to a list of pytest params to work with the native pytests parametrize function""" pytest_testsuite = [] for id, test_parameters in test_data.items(): - if id in xfails: + if id in flakies: + # Mark this parameter as flaky with given reruns + marker = (pytest.mark.flaky(reruns=flakies[id]),) + elif id in skips: + # fail markers do not work with 'buck' based ci, so use skip instead + marker = (pytest.mark.skip(reason=skips[id]),) + elif id in xfails: xfail_info = xfails[id] reason = "" raises = None @@ -335,14 +287,16 @@ def decorator_func(func): "xfail info needs to be str, or tuple[str, type[Exception]]" ) # Set up our fail marker + marker: tuple[pytest.MarkDecorator, ...] 
# type: ignore[no-redef] marker = ( pytest.mark.xfail(reason=reason, raises=raises, strict=strict), ) else: - marker = () + marker = () # type: ignore[assignment] pytest_param = pytest.param(test_parameters, id=id, marks=marker) pytest_testsuite.append(pytest_param) - return pytest.mark.parametrize(arg_name, pytest_testsuite)(func) + decorator = pytest.mark.parametrize(arg_name, pytest_testsuite) + return decorator(func) return decorator_func diff --git a/backends/arm/test/conftest.py b/backends/arm/test/conftest.py index 6fc9e7e5adc..c33f551b2a6 100644 --- a/backends/arm/test/conftest.py +++ b/backends/arm/test/conftest.py @@ -25,15 +25,11 @@ def pytest_configure(config): if getattr(config.option, "llama_inputs", False) and config.option.llama_inputs: pytest._test_options["llama_inputs"] = config.option.llama_inputs # type: ignore[attr-defined] - pytest._test_options["fast_fvp"] = False # type: ignore[attr-defined] - if getattr(config.option, "fast_fvp", False): - pytest._test_options["fast_fvp"] = config.option.fast_fvp # type: ignore[attr-defined] - pytest._test_options["tosa_version"] = "1.0" # type: ignore[attr-defined] if config.option.arm_run_tosa_version: pytest._test_options["tosa_version"] = config.option.arm_run_tosa_version - logging.basicConfig(level=logging.INFO, stream=sys.stdout) + logging.basicConfig(stream=sys.stdout) def pytest_collection_modifyitems(config, items): @@ -44,12 +40,11 @@ def pytest_addoption(parser): def try_addoption(*args, **kwargs): try: parser.addoption(*args, **kwargs) - except Exception: + except Exception: # nosec B110 - pytest redefines options, safe to ignore pass try_addoption("--arm_quantize_io", action="store_true", help="Deprecated.") try_addoption("--arm_run_corstoneFVP", action="store_true", help="Deprecated.") - try_addoption("--fast_fvp", action="store_true") try_addoption( "--llama_inputs", nargs="+", @@ -90,7 +85,7 @@ def set_random_seed(): if os.environ.get("ARM_TEST_SEED", "RANDOM") == "RANDOM": random.seed() # reset seed, in case any other test has fiddled with it - seed = random.randint(0, 2**32 - 1) + seed = random.randint(0, 2**32 - 1) # nosec B311 - non-crypto seed for tests torch.manual_seed(seed) else: seed_str = os.environ.get("ARM_TEST_SEED", "0") @@ -118,7 +113,7 @@ def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool: a RuntimeError instead of returning False. 
""" - if option in pytest._test_options and pytest._test_options[option]: # type: ignore[attr-defined] + if hasattr(pytest, "_test_options") and option in pytest._test_options and pytest._test_options[option]: # type: ignore[attr-defined] return True else: if fail_if_not_enabled: diff --git a/backends/arm/test/misc/test_bn_relu_folding_qat.py b/backends/arm/test/misc/test_bn_relu_folding_qat.py index c88c38e869d..f2452c348f6 100644 --- a/backends/arm/test/misc/test_bn_relu_folding_qat.py +++ b/backends/arm/test/misc/test_bn_relu_folding_qat.py @@ -6,13 +6,13 @@ from typing import Tuple import torch -import torch.nn.functional as F from executorch.backends.arm.quantizer.arm_quantizer import ( get_symmetric_quantization_config, TOSAQuantizer, ) -from executorch.backends.arm.test import common, conftest +from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT +from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.xnnpack.test.tester.tester import Quantize from torch import nn @@ -21,43 +21,97 @@ input_t1 = Tuple[torch.Tensor] # Input x -class ConvModule(torch.nn.Module): +class Conv2dModule(torch.nn.Module): input_shape = (1, 28, 28) batch_size = 64 test_data: input_t1 = (torch.randn(batch_size, *input_shape),) - def __init__(self, batch_norm: bool = True) -> None: + def __init__(self, batch_norm: bool = True, inplace: bool = False) -> None: super().__init__() self.conv = torch.nn.Conv2d(1, 16, 3, stride=2) self.bn = nn.BatchNorm2d(num_features=16) if batch_norm else nn.Identity() + self.relu = nn.ReLU(inplace=inplace) def forward(self, x: torch.Tensor): x = self.conv(x) x = self.bn(x) - x = F.relu(x) + x = self.relu(x) + + return x + + +class Conv1dModule(torch.nn.Module): + input_shape = (3, 10) + batch_size = 2 + test_data: input_t1 = (torch.randn(batch_size, *input_shape),) + + def __init__(self, batch_norm: bool = True, inplace: bool = False) -> None: + super().__init__() + self.conv = torch.nn.Conv1d(3, 8, 5, padding=2) + self.bn = nn.BatchNorm1d(num_features=8) if batch_norm else nn.Identity() + self.relu = nn.ReLU(inplace=inplace) + + def forward(self, x: torch.Tensor): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) return x models = { # name : (model, is_per_channel) - "conv_bn_relu_per_channel": (ConvModule(batch_norm=True), True), - "conv_relu_per_channel": (ConvModule(batch_norm=False), True), - "conv_bn_relu_per_tensor": (ConvModule(batch_norm=True), False), - "conv_relu_per_tensor": (ConvModule(batch_norm=False), False), + "conv1d_bn_relu_per_channel": (Conv1dModule(batch_norm=True), True), + "conv1d_relu_per_channel": (Conv1dModule(batch_norm=False), True), + "conv1d_bn_relu_per_tensor": (Conv1dModule(batch_norm=True), False), + "conv1d_relu_per_tensor": (Conv1dModule(batch_norm=False), False), + "conv2d_bn_relu_per_channel": (Conv2dModule(batch_norm=True), True), + "conv2d_relu_per_channel": (Conv2dModule(batch_norm=False), True), + "conv2d_bn_relu_per_tensor": (Conv2dModule(batch_norm=True), False), + "conv2d_relu_per_tensor": (Conv2dModule(batch_norm=False), False), + "conv1d_bn_relu_inplace_per_channel": ( + Conv1dModule(batch_norm=True, inplace=True), + True, + ), + "conv1d_relu_inplace_per_channel": ( + Conv1dModule(batch_norm=False, inplace=True), + True, + ), + "conv1d_bn_relu_inplace_per_tensor": ( + Conv1dModule(batch_norm=True, inplace=True), + False, + ), + "conv1d_relu_inplace_per_tensor": ( + Conv1dModule(batch_norm=False, inplace=True), + False, + ), + 
"conv2d_bn_relu_inplace_per_channel": ( + Conv2dModule(batch_norm=True, inplace=True), + True, + ), + "conv2d_relu_inplace_per_channel": ( + Conv2dModule(batch_norm=False, inplace=True), + True, + ), + "conv2d_bn_relu_inplace_per_tensor": ( + Conv2dModule(batch_norm=True, inplace=True), + False, + ), + "conv2d_relu_inplace_per_tensor": ( + Conv2dModule(batch_norm=False, inplace=True), + False, + ), } -@common.parametrize("test_data", models) +@common.parametrize( + "test_data", + models, +) def test_qat_tosa_INT(test_data): model, per_channel = test_data pipeline = TosaPipelineINT[input_t1](model, model.test_data, [], [], qtol=1) - tosa_version = conftest.get_option("tosa_version") - tosa_profiles = { - "1.0": common.TosaSpecification.create_from_string("TOSA-1.0+INT"), - } - tosa_spec = tosa_profiles[tosa_version] - quantizer = TOSAQuantizer(tosa_spec) + quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT")) pipeline.change_args( "quantize", Quantize( @@ -65,7 +119,6 @@ def test_qat_tosa_INT(test_data): quantization_config=get_symmetric_quantization_config( is_qat=True, is_per_channel=per_channel ), - is_qat=True, ), ) pipeline.run() diff --git a/backends/arm/test/misc/test_call_operator_submodule.py b/backends/arm/test/misc/test_call_operator_submodule.py new file mode 100644 index 00000000000..03201c86f59 --- /dev/null +++ b/backends/arm/test/misc/test_call_operator_submodule.py @@ -0,0 +1,72 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Optional + +import torch + +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec +from torch.fx import GraphModule +from torch.fx.passes.infra.pass_base import PassResult + + +class _DepthRecordingPass(ArmPass): + _passes_required_after = set() + + def __init__(self, initial_graph_module): + super().__init__() + self.depths: list[int] = [] + self.initial_submodule = initial_graph_module + self.submodule = None + self.num_submodules_called = 0 + + def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False): + """Should only be called from the top-level graph module.""" + self.depths.append(self.submodule_depth) + assert self.submodule == self.initial_submodule + return super().call_operator(op, args, kwargs, meta, updated) + + def call_submodule( + self, graph_module: GraphModule, inputs: tuple[Any, ...] 
+ ) -> PassResult: + """Should be called for all three graph_modules: top-level, if, and else.""" + self.submodule = graph_module + self.num_submodules_called += 1 + return super().call_submodule(graph_module, inputs) + + +class _CondModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + def _true_branch(arg: torch.Tensor) -> torch.Tensor: + return arg + 1 + + def _false_branch(arg: torch.Tensor) -> torch.Tensor: + return arg - 1 + + predicate = x.sum() > 0 + return torch.cond(predicate, _true_branch, _false_branch, [x]) + + +def test_call_operator_runs_once_for_cond_submodules() -> None: + module = _CondModule() + example_inputs = (torch.randn(2, 3),) + exported = torch.export.export(module, example_inputs) + graph_module = exported.graph_module + + recording_pass = _DepthRecordingPass(graph_module) + pass_manager = ArmPassManager(TosaCompileSpec("TOSA-1.00+FP")) + pass_manager.add_pass(recording_pass) + pass_manager._transform(graph_module) + + assert recording_pass.num_submodules_called == 3 + assert recording_pass.depths, "call_operator was never invoked" + assert ( + max(recording_pass.depths) == 1 + ), "call_operator was invoked with larger than one submodule depth." + assert ( + min(recording_pass.depths) == 1 + ), "call_operator was invoked with zero submodule depth." diff --git a/backends/arm/test/misc/test_compile_spec.py b/backends/arm/test/misc/test_compile_spec.py new file mode 100644 index 00000000000..a1b42cd22b5 --- /dev/null +++ b/backends/arm/test/misc/test_compile_spec.py @@ -0,0 +1,50 @@ +from executorch.backends.arm.ethosu import EthosUCompileSpec +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec +from executorch.backends.arm.vgf import VgfCompileSpec +from pytest import raises + + +def test_ethos_u_compile_spec(): + compile_spec = ( + EthosUCompileSpec("ethos-u55", extra_flags=["--my-flag"]) + .dump_intermediate_artifacts_to("my_path") + .dump_debug_info(EthosUCompileSpec.DebugMode.TOSA) + ) + spec_list = compile_spec.to_list() + + assert EthosUCompileSpec.from_list(spec_list) == compile_spec + assert "--my-flag" in compile_spec.compiler_flags + assert "--output-format=raw" in compile_spec.compiler_flags + with raises(ValueError, match="Incorrect output format"): + VgfCompileSpec.from_list(spec_list) + + spec_list.pop(0) + with raises(ValueError, match="No tosa_spec in compile spec."): + EthosUCompileSpec.from_list(spec_list) + + +def test_vgf_compile_spec(): + compile_spec = ( + VgfCompileSpec(compiler_flags=["--my-flag"]) + .dump_intermediate_artifacts_to("my_path") + .dump_debug_info(None) + ) + compile_spec2 = VgfCompileSpec( + compiler_flags=["--my-flag2"] + ).dump_intermediate_artifacts_to("my_path") + + spec_list = compile_spec.to_list() + + assert VgfCompileSpec.from_list(spec_list) == compile_spec + assert VgfCompileSpec.from_list(spec_list) != compile_spec2 + with raises(ValueError, match="Incorrect output format"): + EthosUCompileSpec.from_list(spec_list) + + +def test_tosa_compile_spec(): + compile_spec = TosaCompileSpec("TOSA-1.0+INT") + spec_list = compile_spec.to_list() + + assert TosaCompileSpec.from_list(spec_list) == compile_spec + with raises(ValueError, match="Incorrect output format"): + VgfCompileSpec.from_list(spec_list) diff --git a/backends/arm/test/misc/test_conv_relu_residual_add.py b/backends/arm/test/misc/test_conv_relu_residual_add.py index fdd6ec972a6..72886fb4b29 100644 --- a/backends/arm/test/misc/test_conv_relu_residual_add.py +++ b/backends/arm/test/misc/test_conv_relu_residual_add.py @@ 
-76,6 +76,13 @@ def test_tosa_INT(per_channel_quantization): pipeline.run() +# TODO: Xfail until the Ethos-U Vela compiler ships commit +# 642f7517d3a6bd053032e1942822f6e38ccd546f. That patch fixes the bug that +# causes this test to fail. +@pytest.mark.xfail( + reason=("Blocked by Vela commit 642f7517d3a6bd053032e1942822f6e38ccd546f"), + strict=True, +) @pytest.mark.slow @common.XfailIfNoCorstone300 @common.parametrize("per_channel_quantization", quant_test_data) @@ -85,7 +92,6 @@ def test_tosa_u55_INT(per_channel_quantization): model_inputs, [], [], - run_on_fvp=True, use_to_edge_transform_and_lower=True, per_channel_quantization=per_channel_quantization, qtol=0, @@ -102,7 +108,6 @@ def test_tosa_u85_INT(per_channel_quantization): model_inputs, [], [], - run_on_fvp=True, use_to_edge_transform_and_lower=True, per_channel_quantization=per_channel_quantization, qtol=0, diff --git a/backends/arm/test/misc/test_debug_feats.py b/backends/arm/test/misc/test_debug_feats.py index 3e10a9336f9..1284b956ff7 100644 --- a/backends/arm/test/misc/test_debug_feats.py +++ b/backends/arm/test/misc/test_debug_feats.py @@ -14,14 +14,15 @@ import pytest import torch -from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec from executorch.backends.arm.test import common +from executorch.backends.arm.test.runner_utils import dbg_tosa_fb_to_json from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, TosaPipelineFP, TosaPipelineINT, ) - +from executorch.backends.test.harness.stages import StageType input_t1 = Tuple[torch.Tensor] # Input x @@ -49,8 +50,9 @@ def forward(self, x): def _tosa_FP_pipeline(module: torch.nn.Module, test_data: input_t1, dump_file=None): - - pipeline = TosaPipelineFP[input_t1](module, test_data, [], []) + aten_ops: list[str] = [] + exir_ops: list[str] = [] + pipeline = TosaPipelineFP[input_t1](module, test_data, aten_ops, exir_ops) pipeline.dump_artifact("to_edge_transform_and_lower") pipeline.dump_artifact("to_edge_transform_and_lower", suffix=dump_file) pipeline.pop_stage("run_method_and_compare_outputs") @@ -58,8 +60,9 @@ def _tosa_FP_pipeline(module: torch.nn.Module, test_data: input_t1, dump_file=No def _tosa_INT_pipeline(module: torch.nn.Module, test_data: input_t1, dump_file=None): - - pipeline = TosaPipelineINT[input_t1](module, test_data, [], []) + aten_ops: list[str] = [] + exir_ops: list[str] = [] + pipeline = TosaPipelineINT[input_t1](module, test_data, aten_ops, exir_ops) pipeline.dump_artifact("to_edge_transform_and_lower") pipeline.dump_artifact("to_edge_transform_and_lower", suffix=dump_file) pipeline.pop_stage("run_method_and_compare_outputs") @@ -104,11 +107,13 @@ def test_INT_artifact(test_data: input_t1): @common.parametrize("test_data", Linear.inputs) def test_numerical_diff_print(test_data: input_t1): - pipeline = TosaPipelineFP[input_t1]( + aten_ops: list[str] = [] + exir_ops: list[str] = [] + pipeline = TosaPipelineINT[input_t1]( Linear(), test_data, - [], - [], + aten_ops, + exir_ops, custom_path="diff_print_test", ) pipeline.pop_stage("run_method_and_compare_outputs") @@ -119,7 +124,9 @@ def test_numerical_diff_print(test_data: input_t1): # not present. 
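+    # Compare the quantized output against the initial fp32 model with zero
+    # tolerance so that the quantization error surfaces as a numerical diff.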
try: # Tolerate 0 difference => we want to trigger a numerical diff - tester.run_method_and_compare_outputs(atol=0, rtol=0, qtol=0) + tester.run_method_and_compare_outputs( + stage=StageType.INITIAL_MODEL, atol=0, rtol=0, qtol=0 + ) except AssertionError: pass # Implicit pass test else: @@ -128,7 +135,9 @@ def test_numerical_diff_print(test_data: input_t1): @common.parametrize("test_data", Linear.inputs) def test_dump_ops_and_dtypes(test_data: input_t1): - pipeline = TosaPipelineINT[input_t1](Linear(), test_data, [], []) + aten_ops: list[str] = [] + exir_ops: list[str] = [] + pipeline = TosaPipelineINT[input_t1](Linear(), test_data, aten_ops, exir_ops) pipeline.pop_stage("run_method_and_compare_outputs") pipeline.add_stage_after("quantize", pipeline.tester.dump_dtype_distribution) pipeline.add_stage_after("quantize", pipeline.tester.dump_operator_distribution) @@ -146,7 +155,9 @@ def test_dump_ops_and_dtypes(test_data: input_t1): @common.parametrize("test_data", Linear.inputs) def test_dump_ops_and_dtypes_parseable(test_data: input_t1): - pipeline = TosaPipelineINT[input_t1](Linear(), test_data, [], []) + aten_ops: list[str] = [] + exir_ops: list[str] = [] + pipeline = TosaPipelineINT[input_t1](Linear(), test_data, aten_ops, exir_ops) pipeline.pop_stage("run_method_and_compare_outputs") pipeline.add_stage_after("quantize", pipeline.tester.dump_dtype_distribution, False) pipeline.add_stage_after( @@ -174,7 +185,9 @@ def test_collate_tosa_INT_tests(test_data: input_t1): # Set the environment variable to trigger the collation of TOSA tests os.environ["TOSA_TESTCASES_BASE_PATH"] = "test_collate_tosa_tests" # Clear out the directory - pipeline = TosaPipelineINT[input_t1](Linear(), test_data, [], []) + aten_ops: list[str] = [] + exir_ops: list[str] = [] + pipeline = TosaPipelineINT[input_t1](Linear(), test_data, aten_ops, exir_ops) pipeline.pop_stage("run_method_and_compare_outputs") pipeline.run() @@ -194,13 +207,15 @@ def test_collate_tosa_INT_tests(test_data: input_t1): @common.parametrize("test_data", Linear.inputs) def test_dump_tosa_debug_json(test_data: input_t1): with tempfile.TemporaryDirectory() as tmpdir: + aten_ops: list[str] = [] + exir_ops: list[str] = [] pipeline = TosaPipelineINT[input_t1]( module=Linear(), test_data=test_data, - aten_op=[], - exir_op=[], + aten_op=aten_ops, + exir_op=exir_ops, custom_path=tmpdir, - tosa_debug_mode=ArmCompileSpecBuilder.DebugMode.JSON, + tosa_debug_mode=ArmCompileSpec.DebugMode.JSON, ) pipeline.pop_stage("run_method_and_compare_outputs") @@ -224,32 +239,58 @@ def test_dump_tosa_debug_json(test_data: input_t1): @common.parametrize("test_data", Linear.inputs) def test_dump_tosa_debug_tosa(test_data: input_t1): - with tempfile.TemporaryDirectory() as tmpdir: - pipeline = TosaPipelineINT[input_t1]( - module=Linear(), - test_data=test_data, - aten_op=[], - exir_op=[], - custom_path=tmpdir, - tosa_debug_mode=ArmCompileSpecBuilder.DebugMode.TOSA, - ) + output_dir = "test_dump_tosa_debug" - pipeline.pop_stage("run_method_and_compare_outputs") - pipeline.run() + aten_ops: list[str] = [] + exir_ops: list[str] = [] + pipeline = TosaPipelineFP[input_t1]( + module=Linear(), + test_data=test_data, + use_to_edge_transform_and_lower=True, + aten_op=aten_ops, + exir_op=exir_ops, + custom_path=output_dir, + tosa_debug_mode=ArmCompileSpec.DebugMode.TOSA, + ) - json_output_path = Path(tmpdir) / "debug.json" + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() - # A JSON file should not be created when TOSA mode used - assert not 
json_output_path.exists() + output_path = Path(output_dir) + json_output_path = output_path / "debug.json" + + # A JSON file should not be created when TOSA mode used + assert not json_output_path.exists() + + # At least one TOSA file should exist + tosa_files = list(output_path.glob("*.tosa")) + assert len(tosa_files) > 0 + + tosa_file = tosa_files[0] + with tosa_file.open("rb") as f: + tosa_json = dbg_tosa_fb_to_json(f.read()) + + # Check all non-empty JSON strings are valid + ops = tosa_json["regions"][0]["blocks"][0]["operators"] + for op in ops: + if op["location"]["text"]: + try: + json.loads(op["location"]["text"]) + except json.JSONDecodeError: + pytest.fail("Failed to load debug JSON string") + + shutil.rmtree(output_dir, ignore_errors=True) @common.parametrize("test_data", Linear.inputs) -def test_dump_tosa_ops(caplog, test_data: input_t1): - pipeline = TosaPipelineINT[input_t1](Linear(), test_data, [], []) +def test_dump_tosa_ops(capsys, test_data: input_t1): + aten_ops: list[str] = [] + exir_ops: list[str] = [] + pipeline = TosaPipelineINT[input_t1](Linear(), test_data, aten_ops, exir_ops) pipeline.pop_stage("run_method_and_compare_outputs") pipeline.dump_operator_distribution("to_edge_transform_and_lower") pipeline.run() - assert "TOSA operators:" in caplog.text + assert "TOSA operators:" in capsys.readouterr().out class Add(torch.nn.Module): @@ -262,10 +303,14 @@ def forward(self, x): @common.parametrize("test_data", Add.inputs) -def test_fail_dump_tosa_ops(caplog, test_data: input_t1): +@common.XfailIfNoCorstone300 +def test_fail_dump_tosa_ops(capsys, test_data: input_t1): + aten_ops: list[str] = [] + exir_ops: list[str] = [] pipeline = EthosU55PipelineINT[input_t1]( - Add(), test_data, [], [], use_to_edge_transform_and_lower=True, run_on_fvp=False + Add(), test_data, aten_ops, exir_ops, use_to_edge_transform_and_lower=True ) pipeline.dump_operator_distribution("to_edge_transform_and_lower") - pipeline.run() - assert "Can not get operator distribution for Vela command stream." in caplog.text + error_msg = "Can not get operator distribution for Vela command stream." 
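+    # Dumping an operator distribution from a Vela command stream is not
+    # supported, so running the pipeline is expected to raise.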
+ with pytest.raises(NotImplementedError, match=error_msg): + pipeline.run() diff --git a/backends/arm/test/misc/test_debug_hook.py b/backends/arm/test/misc/test_debug_hook.py index 935f3984403..8aa6e1006e8 100644 --- a/backends/arm/test/misc/test_debug_hook.py +++ b/backends/arm/test/misc/test_debug_hook.py @@ -5,11 +5,14 @@ from dataclasses import dataclass from types import SimpleNamespace +from typing import cast -from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec from executorch.backends.arm.debug.schema import DebugHook, DebugSchema from executorch.backends.arm.test import common +from torch.fx import Node + @dataclass class DebugHookTestCase: @@ -28,7 +31,7 @@ def _get_action_str() -> str: name="convolution", target="aten.convolution.default", graph_id=6052414368, - pass_name="ExportedProgram.module()", + pass_name="ExportedProgram.module()", # nosec B106 - static test string, not a secret action="create", from_node=[], _get_action_string=_get_action_str, @@ -38,7 +41,7 @@ def _get_action_str() -> str: name="convolution", target="aten.convolution.default", graph_id=5705954832, - pass_name="Interpreter_PropagateUnbackedSymInts", + pass_name="Interpreter_PropagateUnbackedSymInts", # nosec B106 - static test string, not a secret action="create", from_node=[from_node_2], _get_action_string=_get_action_str, @@ -66,7 +69,7 @@ def _get_action_str() -> str: name="convolution", target="aten.convolution.default", graph_id=5705954832, - pass_name="Interpreter_PropagateUnbackedSymInts", + pass_name="Interpreter_PropagateUnbackedSymInts", # nosec B106 - static test string, not a secret action="create", from_node=[], _get_action_string=_get_action_str, @@ -95,9 +98,9 @@ def create_mock_node_3(): return fx_node_mock -def _compare_tosa_and_schema(debug_event: DebugSchema, tosa_op): +def _compare_tosa_and_schema(debug_event: DebugSchema, tosa_op: str) -> None: tosa_info = debug_event.tosa_info - + assert tosa_info is not None assert tosa_info.node_name == tosa_op # The mapping between op_ids to operator names could change @@ -158,8 +161,8 @@ def _compare_node_and_schema(debug_event: DebugSchema, mocked_node): @common.parametrize("test_data", TESTCASES) def test_debug_hook_add_json(test_data: DebugHookTestCase): - hook = DebugHook(ArmCompileSpecBuilder.DebugMode.JSON) - hook.add(test_data.mock_node, test_data.tosa_op, test_data.op_id) + hook = DebugHook(ArmCompileSpec.DebugMode.JSON) + hook.add(cast(Node, test_data.mock_node), test_data.tosa_op, test_data.op_id) debug_events = hook._debug_events assert len(debug_events) == test_data.expected_events @@ -171,8 +174,8 @@ def test_debug_hook_add_json(test_data: DebugHookTestCase): @common.parametrize("test_data", TESTCASES) def test_debug_hook_add_tosa(test_data: DebugHookTestCase): - hook = DebugHook(ArmCompileSpecBuilder.DebugMode.TOSA) - hook.add(test_data.mock_node, test_data.tosa_op, test_data.op_id) + hook = DebugHook(ArmCompileSpec.DebugMode.TOSA) + hook.add(cast(Node, test_data.mock_node), test_data.tosa_op, test_data.op_id) debug_events = hook._debug_events assert len(debug_events) == test_data.expected_events diff --git a/backends/arm/test/misc/test_dim_order.py b/backends/arm/test/misc/test_dim_order.py new file mode 100644 index 00000000000..14e12461652 --- /dev/null +++ b/backends/arm/test/misc/test_dim_order.py @@ -0,0 +1,127 @@ +# Copyright 2024-2025 Arm Limited and/or its affiliates. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Tuple + +import torch +from executorch.backends.arm.test import common + +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineINT, + EthosU85PipelineINT, + TosaPipelineFP, + TosaPipelineINT, +) + + +input_t1 = Tuple[torch.Tensor, ...] # Input x + + +class ChannelsLastInput(torch.nn.Module): + """ + Test a complex case with (channels last, channels first) input, + and (channels first, channels last) output. + """ + + inputs: input_t1 = ( + torch.arange(1, 25, dtype=torch.float32) + .reshape((1, 2, 3, 4)) + .to(memory_format=torch.channels_last), + torch.arange(1, 25, dtype=torch.float32).reshape((1, 2, 3, 4)), + ) + + def forward(self, x, y): + x = x * x + return y, x + + +class ChannelsFirstOutput(torch.nn.Module): + """ + Test coverting to channels_first inside the delegate. + """ + + inputs: input_t1 = ( + torch.arange(1, 25, dtype=torch.float32) + .reshape((1, 2, 3, 4)) + .to(memory_format=torch.channels_last), + ) + + def forward(self, x): + x = x.clone(memory_format=torch.contiguous_format) * x + return x + + +class ChannelsLastOutput(torch.nn.Module): + """ + Test changing of dim_order inside the delegate. + """ + + inputs: input_t1 = (torch.arange(1, 9, dtype=torch.float32).reshape((1, 2, 2, 2)),) + + def forward(self, x): + x = x * x + x = x.clone(memory_format=torch.channels_last) + return x + + +class ChannelsLastInsidePartition(torch.nn.Module): + """ + Test dim_order changes inside the partiton, but no dim_order changes at input/output. + """ + + inputs: input_t1 = (torch.randn((1, 2, 3, 3)),) + + def __init__(self): + super().__init__() + self.conv2d = torch.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=(3, 3)) + + def forward(self, x): + return ( + self.conv2d(x.clone(memory_format=torch.channels_last)).clone( + memory_format=torch.contiguous_format + ) + * 1 + ) + + +test_modules = { + "channels_last_input": ChannelsLastInput, + "channels_first_output": ChannelsFirstOutput, + "channels_last_output": ChannelsLastOutput, + "channels_last_inside_partition": ChannelsLastInsidePartition, +} + + +@common.parametrize("module", test_modules) +def test_dim_order_tosa_FP(module) -> None: + aten_ops: list[str] = [] + pipeline = TosaPipelineFP[input_t1](module(), module.inputs, aten_ops) + pipeline.run() + + +@common.parametrize("module", test_modules) +def test_dim_order_tosa_INT(module) -> None: + aten_ops: list[str] = [] + pipeline = TosaPipelineINT[input_t1]( + module(), module.inputs, aten_ops, symmetric_io_quantization=True + ) + pipeline.run() + + +@common.XfailIfNoCorstone300 +@common.parametrize("module", test_modules) +def test_dim_order_u55_INT(module) -> None: + aten_ops: list[str] = [] + pipeline = EthosU55PipelineINT[input_t1](module(), module.inputs, aten_ops) + pipeline.run() + + +@common.XfailIfNoCorstone320 +@common.parametrize("module", test_modules) +def test_dim_order_u85_INT(module) -> None: + aten_ops: list[str] = [] + pipeline = EthosU85PipelineINT[input_t1](module(), module.inputs, aten_ops) + pipeline.run() diff --git a/backends/arm/test/misc/test_dim_order_guards.py b/backends/arm/test/misc/test_dim_order_guards.py deleted file mode 100644 index 80a3c014abc..00000000000 --- a/backends/arm/test/misc/test_dim_order_guards.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. 
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -from typing import Tuple - -import pytest - -import torch -from executorch.backends.arm.test import common - -from executorch.backends.arm.test.tester.test_pipeline import ( - TosaPipelineFP, - TosaPipelineINT, -) - - -input_t1 = Tuple[torch.Tensor] # Input x - - -class Conv2D(torch.nn.Module): - inputs: dict[str, input_t1] = { - "randn": (torch.randn(1, 2, 20, 20).to(memory_format=torch.channels_last),), - } - - def __init__(self): - super().__init__() - self.conv2d = torch.nn.Conv2d(in_channels=2, out_channels=3, kernel_size=(3, 3)) - - def forward(self, x): - return self.conv2d(x) - - -@common.parametrize("test_data", Conv2D.inputs) -def test_tosa_FP_pipeline(test_data: input_t1): - module = Conv2D() - pipeline = TosaPipelineFP[input_t1]( - module, - test_data, - [], - [], - use_to_edge_transform_and_lower=False, - ) - pos = pipeline.find_pos("partition") - pipeline._stages = pipeline._stages[:pos] - pipeline.run() - with pytest.raises(RuntimeError): - pipeline.tester.partition() - - -@common.parametrize("test_data", Conv2D.inputs) -def test_tosa_INT_pipeline(test_data: input_t1): - module = Conv2D() - pipeline = TosaPipelineINT[input_t1]( - module, - test_data, - [], - [], - use_to_edge_transform_and_lower=False, - ) - pos = pipeline.find_pos("partition") - pipeline._stages = pipeline._stages[:pos] - pipeline.run() - with pytest.raises(RuntimeError): - pipeline.tester.partition() diff --git a/backends/arm/test/misc/test_dw_convs_with_shared_weights.py b/backends/arm/test/misc/test_dw_convs_with_shared_weights.py new file mode 100644 index 00000000000..8b3b99cf005 --- /dev/null +++ b/backends/arm/test/misc/test_dw_convs_with_shared_weights.py @@ -0,0 +1,58 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
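+
+# These tests reuse the same depthwise Conv2d module twice in one network so
+# that its weights are shared, and check that the FP and INT TOSA pipelines,
+# as well as RewriteConvPass on its own, handle the shared weights.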
+ +from typing import Any, Tuple + +import torch +from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass +from executorch.backends.arm.test.tester.test_pipeline import ( + PassPipeline, + TosaPipelineFP, + TosaPipelineINT, +) + +input_t = Tuple[torch.Tensor] + + +class DWConvsModule(torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + conv = torch.nn.Conv2d(6, 6, kernel_size=(2, 2), groups=6) + relu = torch.nn.ReLU() + self.sequential = torch.nn.ModuleList([conv, relu, conv]) + + def forward(self, x) -> torch.Tensor: + for m in self.sequential: + x = m(x) + return x + + def get_inputs(self) -> input_t: + return (torch.randn(1, 6, 24, 24),) + + +def test_convs_tosa_fp(): + module = DWConvsModule() + pipeline = TosaPipelineFP[input_t]( + module, module.get_inputs(), aten_op=[], exir_op=[] + ) + pipeline.run() + + +def test_convs_tosa_int(): + module = DWConvsModule() + pipeline = TosaPipelineINT[input_t]( + module, module.get_inputs(), aten_op=[], exir_op=[] + ) + pipeline.run() + + +def test_rewrite_conv_pass(): + module = DWConvsModule() + pipeline = PassPipeline( + module, module.get_inputs(), passes_with_exported_program=[RewriteConvPass] + ) + # We can't run TOSA backend dialect operators in eager mode + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() diff --git a/backends/arm/test/misc/test_extract_io_params_tosa.py b/backends/arm/test/misc/test_extract_io_params_tosa.py index da471b0bb74..90104c54899 100644 --- a/backends/arm/test/misc/test_extract_io_params_tosa.py +++ b/backends/arm/test/misc/test_extract_io_params_tosa.py @@ -7,7 +7,6 @@ import pytest import torch -from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder from executorch.backends.arm.quantizer import VgfQuantizer from executorch.backends.arm.quantizer.arm_quantizer import ( get_symmetric_quantization_config, @@ -15,9 +14,9 @@ ) from executorch.backends.arm.test.common import SkipIfNoModelConverter -from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec from executorch.backends.arm.tosa.partitioner import TOSAPartitioner -from executorch.backends.arm.vgf import VgfPartitioner +from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner from executorch.exir import to_edge_transform_and_lower from executorch.exir.passes.quantize_io_pass import extract_io_quant_params from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e @@ -29,11 +28,11 @@ def forward(self, x, y): @pytest.mark.parametrize( - "builder_method, quantizer_cls, partitioner_cls", + "compile_spec_cls, quantizer_cls, partitioner_cls", [ - ("tosa_compile_spec", TOSAQuantizer, TOSAPartitioner), + (TosaCompileSpec, TOSAQuantizer, TOSAPartitioner), pytest.param( - "vgf_compile_spec", + VgfCompileSpec, VgfQuantizer, VgfPartitioner, marks=SkipIfNoModelConverter, @@ -41,7 +40,11 @@ def forward(self, x, y): ), ], ) -def test_roundtrip_extracts_io_params(builder_method, quantizer_cls, partitioner_cls): +def test_roundtrip_extracts_io_params( + compile_spec_cls: type[TosaCompileSpec] | type[VgfCompileSpec], + quantizer_cls, + partitioner_cls, +): """ Validates that IO quantization parameters round-trip for both flows. 
""" @@ -51,10 +54,7 @@ def test_roundtrip_extracts_io_params(builder_method, quantizer_cls, partitioner ) mod = SimpleAdd().eval() - base_spec = TosaSpecification.create_from_string("TOSA-1.0+INT") - compile_spec = getattr(ArmCompileSpecBuilder(), builder_method)( - tosa_spec=base_spec - ).build() + compile_spec = compile_spec_cls("TOSA-1.0+INT") quantizer = quantizer_cls(compile_spec) operator_config = get_symmetric_quantization_config(is_qat=True) diff --git a/backends/arm/test/misc/test_int64.py b/backends/arm/test/misc/test_int64.py index d6d6d6cb39c..46a97fff1df 100644 --- a/backends/arm/test/misc/test_int64.py +++ b/backends/arm/test/misc/test_int64.py @@ -68,10 +68,6 @@ def forward(self, x: torch.Tensor): ConstAdd(torch.int64, 2**40), (torch.rand(10) - 0.5,), ), - "int64_in+float_const": ( - ConstAdd(torch.float32), - (torch.randint(0, 10, (10,)),), - ), "fp32_in+int64_buffer_chain": ( BufferChainAdd(torch.int64), (torch.rand(2, 5, 3) - 0.5,), @@ -94,7 +90,7 @@ def test_int64_tosa_FP(test_data: Tuple): ArmTester( model, inputs, - common.get_tosa_compile_spec("TOSA-1.0+FP", custom_path="tosa/int64"), + common.get_tosa_compile_spec("TOSA-1.0+FP"), ) .export() .to_edge_transform_and_lower() diff --git a/backends/arm/test/misc/test_lifted_tensor.py b/backends/arm/test/misc/test_lifted_tensor.py index 2e45a36d12a..ee9812b53fd 100644 --- a/backends/arm/test/misc/test_lifted_tensor.py +++ b/backends/arm/test/misc/test_lifted_tensor.py @@ -4,7 +4,8 @@ # LICENSE file in the root directory of this source tree. import operator -from typing import Tuple, Union +from collections.abc import Callable +from typing import Union import torch from executorch.backends.arm.test import common @@ -15,12 +16,22 @@ from executorch.backends.test.harness.stages import StageType -input_t1 = Tuple[torch.Tensor] +LiftedTensorInputs = tuple[torch.Tensor, int] +LiftedTensorCase = tuple[ + Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + LiftedTensorInputs, +] +LiftedScalarTensorInputs = tuple[torch.Tensor, ...] 
+LiftedScalarTensorCase = tuple[ + Callable[[torch.Tensor, Union[float, int, torch.Tensor]], torch.Tensor], + LiftedScalarTensorInputs, + Union[float, int, torch.Tensor], +] class LiftedTensor(torch.nn.Module): - test_data = { + test_data: dict[str, LiftedTensorCase] = { # test_name: (operator, test_data, length) "add": (operator.add, (torch.randn(2, 2), 2)), "truediv": (operator.truediv, (torch.ones(2, 2), 2)), @@ -39,7 +50,7 @@ def forward(self, x: torch.Tensor, length) -> torch.Tensor: class LiftedScalarTensor(torch.nn.Module): - test_data = { + test_data: dict[str, LiftedScalarTensorCase] = { # test_name: (operator, test_data) "add": (operator.add, (torch.randn(2, 2),), 1.0), "truediv": (operator.truediv, (torch.randn(4, 2),), 1.0), @@ -60,14 +71,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @common.parametrize("test_data", LiftedTensor.test_data) -def test_partition_lifted_tensor_tosa_FP(test_data: input_t1): - op = test_data[0] - data = test_data[1:] +def test_partition_lifted_tensor_tosa_FP(test_data: LiftedTensorCase) -> None: + op, inputs = test_data module = LiftedTensor(op) - pipeline = TosaPipelineFP[input_t1]( + aten_ops: list[str] = [] + pipeline = TosaPipelineFP[LiftedTensorInputs]( module, - *data, - [], + inputs, + aten_ops, exir_op=[], use_to_edge_transform_and_lower=False, ) @@ -81,14 +92,14 @@ def test_partition_lifted_tensor_tosa_FP(test_data: input_t1): @common.parametrize("test_data", LiftedTensor.test_data) -def test_partition_lifted_tensor_tosa_INT(test_data: input_t1): - op = test_data[0] - data = test_data[1:] +def test_partition_lifted_tensor_tosa_INT(test_data: LiftedTensorCase) -> None: + op, inputs = test_data module = LiftedTensor(op) - pipeline = TosaPipelineINT[input_t1]( + aten_ops: list[str] = [] + pipeline = TosaPipelineINT[LiftedTensorInputs]( module, - *data, - [], + inputs, + aten_ops, exir_op=[], use_to_edge_transform_and_lower=False, ) @@ -102,14 +113,16 @@ def test_partition_lifted_tensor_tosa_INT(test_data: input_t1): @common.parametrize("test_data", LiftedScalarTensor.test_data) -def test_partition_lifted_scalar_tensor_tosa_FP(test_data: input_t1): - op = test_data[0] - data = test_data[1:] - module = LiftedScalarTensor(op, data[-1]) - pipeline = TosaPipelineFP[input_t1]( +def test_partition_lifted_scalar_tensor_tosa_FP( + test_data: LiftedScalarTensorCase, +) -> None: + op, tensor_inputs, scalar_arg = test_data + module = LiftedScalarTensor(op, scalar_arg) + aten_ops: list[str] = [] + pipeline = TosaPipelineFP[LiftedScalarTensorInputs]( module, - data[0], - [], + tensor_inputs, + aten_ops, exir_op=[], use_to_edge_transform_and_lower=False, ) @@ -117,14 +130,16 @@ def test_partition_lifted_scalar_tensor_tosa_FP(test_data: input_t1): @common.parametrize("test_data", LiftedScalarTensor.test_data) -def test_partition_lifted_scalar_tensor_tosa_INT(test_data: input_t1): - op = test_data[0] - data = test_data[1:] - module = LiftedScalarTensor(op, data[-1]) - pipeline = TosaPipelineINT[input_t1]( +def test_partition_lifted_scalar_tensor_tosa_INT( + test_data: LiftedScalarTensorCase, +) -> None: + op, tensor_inputs, scalar_arg = test_data + module = LiftedScalarTensor(op, scalar_arg) + aten_ops: list[str] = [] + pipeline = TosaPipelineINT[LiftedScalarTensorInputs]( module, - data[0], - [], + tensor_inputs, + aten_ops, exir_op=[], use_to_edge_transform_and_lower=False, ) diff --git a/backends/arm/test/misc/test_mixed_type_lowering.py b/backends/arm/test/misc/test_mixed_type_lowering.py new file mode 100644 index 00000000000..7c03c8a1960 --- 
/dev/null +++ b/backends/arm/test/misc/test_mixed_type_lowering.py @@ -0,0 +1,72 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from collections import Counter, defaultdict + +import torch +from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT + + +def combine_op_dicts(*dicts): + merged = defaultdict(Counter) + for d in dicts: + for op, dtypes in d.items(): + merged[op].update(dtypes) + return {op: dict(counts) for op, counts in merged.items()} + + +# TODO Figure out how to handle multiple dq/q nodes properly +# See backends/arm/_passes/decompose_quant_nodes.py for details +dq_tosa_ops = { + "CAST": {"FP32": 1, "INT32": 1}, + "SUB": {"INT32": 1}, # zero-point subtraction + "MUL": {"FP32": 1}, # scale multiplication +} +q_tosa_ops = { + "CAST": {"INT8": 1}, + "MUL": {"FP32": 1}, # scale multiplication + "ADD": {"FP32": 2}, # zero-point addition, rounding + "SUB": {"FP32": 1}, # for rounding + "CLAMP": {"FP32": 1}, # clamp + "GREATER_EQUAL": {"BOOL": 1}, # for rounding + "SELECT": {"FP32": 1}, # for rounding + "CEIL": {"FP32": 1}, # for rounding + "FLOOR": {"FP32": 1}, # for rounding +} +q_dq_tosa_ops = combine_op_dicts(dq_tosa_ops, q_tosa_ops) + + +class AddSigmoidMul(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, x, y): + return self.sigmoid(x + y) * x + + +def test_mixed_type_lowering(): + model = AddSigmoidMul() + input_data = (torch.randn(1, 16, 16, 16), torch.randn(1, 16, 16, 16)) + + pipeline = TosaPipelineINT[type(input_data)]( + model, input_data, [], [], qtol=1, tosa_extensions=["FP"] + ) + pipeline.quantizer.set_module_type(torch.nn.Sigmoid, None) + expected_tosa_dtype_counts = combine_op_dicts( + { + "SIGMOID": {"FP32": 1}, # SIGMOID should be executed in FP32 + "ADD": {"INT32": 1}, # ADD should be executed in INT32 + "MUL": {"INT32": 1}, # MUL should be executed in INT32 + }, + q_dq_tosa_ops, + ) + + pipeline.add_stage_after( + "to_edge_transform_and_lower", + pipeline.tester.check_dtype_count, + expected_tosa_dtype_counts, + ) + pipeline.run() diff --git a/backends/arm/test/misc/test_multiple_delegates.py b/backends/arm/test/misc/test_multiple_delegates.py index f716bc45385..4928d3d7437 100644 --- a/backends/arm/test/misc/test_multiple_delegates.py +++ b/backends/arm/test/misc/test_multiple_delegates.py @@ -23,13 +23,17 @@ class MultipleDelegatesModule(torch.nn.Module): def forward(self, x: torch.Tensor, y: torch.Tensor): z = x + y - s = torch.tan(z) + s = torch.max(z) return s * z @common.parametrize("test_data", MultipleDelegatesModule.inputs) def test_tosa_FP_pipeline(test_data: input_t1): - pipeline = TosaPipelineFP[input_t1](MultipleDelegatesModule(), test_data, [], []) + aten_ops: list[str] = [] + exir_ops: list[str] = [] + pipeline = TosaPipelineFP[input_t1]( + MultipleDelegatesModule(), test_data, aten_ops, exir_ops + ) pipeline.change_args( "check_count.exir", {"torch.ops.higher_order.executorch_call_delegate": 2} ) @@ -38,8 +42,10 @@ def test_tosa_FP_pipeline(test_data: input_t1): @common.parametrize("test_data", MultipleDelegatesModule.inputs) def test_tosa_INT_pipeline(test_data: input_t1): + aten_ops: list[str] = [] + exir_ops: list[str] = [] pipeline = TosaPipelineINT[input_t1]( - MultipleDelegatesModule(), test_data, [], [], qtol=1 + MultipleDelegatesModule(), test_data, aten_ops, exir_ops, 
qtol=1 ) pipeline.change_args( "check_count.exir", {"torch.ops.higher_order.executorch_call_delegate": 2} diff --git a/backends/arm/test/misc/test_multiple_outputs.py b/backends/arm/test/misc/test_multiple_outputs.py index 45398437238..37ca3047e7d 100644 --- a/backends/arm/test/misc/test_multiple_outputs.py +++ b/backends/arm/test/misc/test_multiple_outputs.py @@ -30,14 +30,20 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): @common.parametrize("test_data", MultipleOutputsModule.inputs) def test_tosa_FP_pipeline(test_data: input_t1): - pipeline = TosaPipelineFP[input_t1](MultipleOutputsModule(), test_data, [], []) + aten_ops: list[str] = [] + exir_ops: list[str] = [] + pipeline = TosaPipelineFP[input_t1]( + MultipleOutputsModule(), test_data, aten_ops, exir_ops + ) pipeline.run() @common.parametrize("test_data", MultipleOutputsModule.inputs) def test_tosa_INT_pipeline(test_data: input_t1): + aten_ops: list[str] = [] + exir_ops: list[str] = [] pipeline = TosaPipelineINT[input_t1]( - MultipleOutputsModule(), test_data, [], [], qtol=1 + MultipleOutputsModule(), test_data, aten_ops, exir_ops, qtol=1 ) pipeline.run() @@ -45,8 +51,10 @@ def test_tosa_INT_pipeline(test_data: input_t1): @common.parametrize("test_data", MultipleOutputsModule.inputs) @common.XfailIfNoCorstone300 def test_U55_pipeline(test_data: input_t1): + aten_ops: list[str] = [] + exir_ops: list[str] = [] pipeline = EthosU55PipelineINT[input_t1]( - MultipleOutputsModule(), test_data, [], [], qtol=1 + MultipleOutputsModule(), test_data, aten_ops, exir_ops, qtol=1 ) pipeline.run() @@ -54,7 +62,9 @@ def test_U55_pipeline(test_data: input_t1): @common.parametrize("test_data", MultipleOutputsModule.inputs) @common.XfailIfNoCorstone320 def test_U85_pipeline(test_data: input_t1): + aten_ops: list[str] = [] + exir_ops: list[str] = [] pipeline = EthosU85PipelineINT[input_t1]( - MultipleOutputsModule(), test_data, [], [], qtol=1 + MultipleOutputsModule(), test_data, aten_ops, exir_ops, qtol=1 ) pipeline.run() diff --git a/backends/arm/test/misc/test_outputs_order.py b/backends/arm/test/misc/test_outputs_order.py index 43d35b6d13c..253888537f8 100644 --- a/backends/arm/test/misc/test_outputs_order.py +++ b/backends/arm/test/misc/test_outputs_order.py @@ -3,23 +3,25 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
# -# pyre-unsafe +import importlib import tempfile from pathlib import Path +from typing import Any import pytest import torch -from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder from executorch.backends.arm.quantizer.arm_quantizer import ( get_symmetric_quantization_config, TOSAQuantizer, ) +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec from executorch.backends.arm.tosa.partitioner import TOSAPartitioner from executorch.backends.arm.tosa.specification import TosaSpecification from executorch.exir import to_edge_transform_and_lower from torch import nn from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e -from tosa import TosaGraph + +_TOSA_GRAPH: Any = importlib.import_module("tosa.TosaGraph") class Network(nn.Module): @@ -58,7 +60,7 @@ def _read_tosa_outputs(tosa_path: Path): # Find output tensor names in order and return shapes buf = tosa_path.read_bytes() buf_arr = bytearray(buf) - graph = TosaGraph.TosaGraph.GetRootAsTosaGraph(buf_arr, 0) + graph = _TOSA_GRAPH.TosaGraph.GetRootAsTosaGraph(buf_arr, 0) region = graph.Regions(0) block = region.Blocks(0) # Build a dict name - tensor‑shape @@ -76,32 +78,33 @@ def _read_tosa_outputs(tosa_path: Path): return shapes +# TODO: MLETORCH-1266 Investigate output order issue @pytest.mark.parametrize("batch_size", [1, 4]) -def test_network_output_order_and_restore(tmp_path, batch_size): +@pytest.mark.parametrize("output_order_workaround", [True, False]) +def test_network_output_order_and_restore(batch_size, output_order_workaround): model = Network(batch_norm=True).eval() # Prepare spec spec = TosaSpecification.create_from_string("TOSA-1.0+INT") - compile_spec = ArmCompileSpecBuilder().tosa_compile_spec(tosa_spec=spec).build() + tosa_compile_spec = TosaCompileSpec(spec).set_output_order_workaround( + output_order_workaround + ) # Setup quantizer - quantizer = TOSAQuantizer(compile_spec) + quantizer = TOSAQuantizer(tosa_compile_spec) quantizer.set_global( get_symmetric_quantization_config(is_qat=True, is_per_channel=False) ) # Trace the model dummy = torch.randn(batch_size, 1, 28, 28) - fx_mod = torch.export.export_for_training(model, (dummy,)).module() + fx_mod = torch.export.export(model, (dummy,)).module() model = prepare_pt2e(fx_mod, quantizer) model(dummy) model = convert_pt2e(model) # Export to aten dialect aten_gm = torch.export.export(model, args=(dummy,), strict=True) - with tempfile.TemporaryDirectory() as tmpdir: + with tempfile.TemporaryDirectory(dir="") as tmpdir: art_dir = Path(tmpdir) part = TOSAPartitioner( - ArmCompileSpecBuilder() - .tosa_compile_spec(spec) - .dump_intermediate_artifacts_to(str(art_dir)) - .build() + tosa_compile_spec.dump_intermediate_artifacts_to(str(art_dir)) ) _ = to_edge_transform_and_lower(aten_gm, partitioner=[part]) # Expect exactly one .tosa file in the artefact dir diff --git a/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py b/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py index 04ecd57e7b1..0514ad5e280 100644 --- a/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py +++ b/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py @@ -8,8 +8,6 @@ # such a Softplus that is decompsed into many other ops without # surrounding q/dq nodes. -from typing import Tuple - import torch from executorch.backends.arm.test import common @@ -18,7 +16,7 @@ TosaPipelineINT, ) -input_t1 = Tuple[torch.Tensor] +input_t1 = tuple[torch.Tensor, ...] 
softplus_aten_op: list[str] = [ "torch.ops.aten.add.Tensor", "torch.ops.aten.softplus.default", @@ -44,7 +42,7 @@ ] -test_data: dict[input_t1] = { +test_data: dict[str, input_t1] = { "3d_rand": (torch.rand(1, 5, 5),), } diff --git a/backends/arm/test/misc/test_pass_pipeline_config.py b/backends/arm/test/misc/test_pass_pipeline_config.py new file mode 100644 index 00000000000..e89a235ae9a --- /dev/null +++ b/backends/arm/test/misc/test_pass_pipeline_config.py @@ -0,0 +1,35 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.backends.arm._passes import ( + DecomposeSoftmaxUnstablePass, + FuseDuplicateUsersPass, +) +from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager +from executorch.backends.arm.common.pipeline_config import ArmPassPipelineConfig +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec +from executorch.backends.arm.tosa.specification import TosaSpecification + + +def test_pipeline_config_override_outside_compile_spec(): + compile_spec = TosaCompileSpec( + TosaSpecification.create_from_string("TOSA-1.00+INT") + ) + default_manager = ArmPassManager(compile_spec) + default_skip_passes = default_manager._skip_pass_types + assert FuseDuplicateUsersPass not in default_skip_passes + assert DecomposeSoftmaxUnstablePass in default_skip_passes + + override_compile_spec = TosaCompileSpec( + TosaSpecification.create_from_string("TOSA-1.00+INT") + ) + override_config = ArmPassPipelineConfig() + override_config.disable_fuse_duplicate_users() + override_compile_spec.set_pass_pipeline_config(override_config) + override_manager = ArmPassManager(override_compile_spec) + skip_passes = override_manager._skip_pass_types + + assert FuseDuplicateUsersPass in skip_passes + assert DecomposeSoftmaxUnstablePass in skip_passes diff --git a/backends/arm/test/misc/test_pass_required_order.py b/backends/arm/test/misc/test_pass_required_order.py new file mode 100644 index 00000000000..694e1997d0f --- /dev/null +++ b/backends/arm/test/misc/test_pass_required_order.py @@ -0,0 +1,97 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
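+# The tests below exercise ArmPassManager's ordering validation: each test pass +# declares the passes that must run after it via `_passes_required_after`, and +# `validate_constraints_mandatory()` accepts schedules that satisfy those constraints +# and raises a RuntimeError listing the violated "must run after" constraints otherwise.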
+ +import re +from typing import List, Set, Type + +import pytest +from executorch.backends.arm._passes.arm_pass_manager import ArmPass, ArmPassManager +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec +from executorch.backends.arm.tosa.specification import TosaSpecification +from executorch.exir.pass_base import ExportPass + + +class PassC(ArmPass): + _passes_required_after: Set[Type[ExportPass]] = set() + + +class PassB(ArmPass): + _passes_required_after = {PassC} + + +class PassA(ArmPass): + _passes_required_after = {PassB, PassC} + + +class IndependentPass(ArmPass): + _passes_required_after: Set[Type[ExportPass]] = set() + + +def _setup_pass_manager(passes: List[ArmPass] | None = None): + tosa_spec = TosaSpecification.create_from_string("TOSA-1.00+INT") + compile_spec = TosaCompileSpec(tosa_spec) + pass_manager = ArmPassManager(compile_spec) + if passes is not None: + for p in passes: + pass_manager.add_pass(p) + return pass_manager + + +def test_no_passes(): + pass_manager = _setup_pass_manager() + pass_manager.validate_constraints_mandatory() + + +def test_correct_order(): + pass_manager = _setup_pass_manager([PassA(), PassB(), PassC()]) + pass_manager.validate_constraints_mandatory() + + +def test_run_pass_twice(): + pass_manager = _setup_pass_manager([PassA(), PassB(), PassB(), PassC()]) + pass_manager.validate_constraints_mandatory() + + +def test_independent_pass(): + pass_manager = _setup_pass_manager( + [ + IndependentPass(), + PassA(), + IndependentPass(), + PassB(), + IndependentPass(), + PassC(), + IndependentPass(), + ] + ) + pass_manager.validate_constraints_mandatory() + + +def test_duplicated_requiring_pass_put_last(): + error_msg = """The following constraints for passes are not met: + - PassC must run after PassB +""" + pass_manager = _setup_pass_manager([PassA(), PassB(), PassC(), PassB()]) + with pytest.raises(RuntimeError, match=re.escape(error_msg)): + pass_manager.validate_constraints_mandatory() + + +def test_two_passes_wrong_order(): + error_msg = """The following constraints for passes are not met: + - PassC must run after PassB +""" + pass_manager = _setup_pass_manager([PassC(), PassB()]) + with pytest.raises(RuntimeError, match=re.escape(error_msg)): + pass_manager.validate_constraints_mandatory() + + +def test_missing_passes(): + error_msg = """The following constraints for passes are not met: + - PassC must run after PassA + - PassC must run after PassB +""" + pass_manager = _setup_pass_manager([PassA(), PassB()]) + with pytest.raises(RuntimeError, match=re.escape(error_msg)): + pass_manager.validate_constraints_mandatory() diff --git a/backends/arm/test/misc/test_qat_training_loop.py b/backends/arm/test/misc/test_qat_training_loop.py new file mode 100644 index 00000000000..291b02bc8ee --- /dev/null +++ b/backends/arm/test/misc/test_qat_training_loop.py @@ -0,0 +1,100 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
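+# This test trains a small MLP to fit sin(x) in float, exports it, prepares it for +# quantization-aware training via prepare_qat_pt2e with the TOSA symmetric QAT config, +# fine-tunes it with the fake-quant observers in place, converts it with convert_pt2e, +# and finally evaluates the quantized model.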
+ +import logging + +import torch +from executorch.backends.arm.quantizer import ( + get_symmetric_quantization_config, + TOSAQuantizer, +) + +from executorch.backends.arm.tosa.specification import TosaSpecification +from torch.export import export +from torchao.quantization.pt2e import ( + move_exported_model_to_eval, + move_exported_model_to_train, +) +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e + +logger = logging.getLogger(__name__) + + +class MLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.sequential = torch.nn.Sequential( + torch.nn.Linear(1, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + + def forward(self, x): + return self.sequential(x) + + +def evaluate_model(model, inputs, expected_outputs): + with torch.no_grad(): + test_outputs = model(inputs) + loss = torch.nn.functional.mse_loss(test_outputs, expected_outputs) + logger.info(f"Mean squared error: {loss.item()}") + + +def test_qat_training_loop(): + """Test the QAT training loop with a simple MLP model. + This function creates a simple MLP model, prepares it for QAT, runs a training loop, + and evaluates the quantized model to make sure everything works as expected.""" + + model = MLP() + logger.info("Starting training loop test") + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + for epoch in range(100): + model.train() + optimizer.zero_grad() + inputs = torch.randn(100, 1).clamp(-1, 1) + outputs = model(inputs) + loss = torch.nn.functional.mse_loss(outputs, torch.sin(inputs)) + loss.backward() + optimizer.step() + if epoch % 5 == 0: + logger.info(f"Epoch {epoch}, Loss: {loss.item()}") + logger.info("Training loop test completed successfully") + + logger.info("Evaluating model before QAT") + test_inputs = torch.randn(20, 1).clamp(-1, 1) + test_outputs = torch.sin(test_inputs) + evaluate_model(model, test_inputs, test_outputs) + + exported_model = export(model, (torch.randn(1, 1),), strict=True) + quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT")) + quantizer.set_global(get_symmetric_quantization_config(is_qat=True)) + + prepared_model = prepare_qat_pt2e(exported_model.module(), quantizer) + prepared_model = move_exported_model_to_train(prepared_model) + logger.info("QAT model prepared successfully") + + logger.info("Starting QAT training loop") + + for epoch in range(25): + inputs = torch.randn(100, 1).clamp(-1, 1) + optimizer.zero_grad() + outputs = prepared_model(inputs) + loss = torch.nn.functional.mse_loss(outputs, torch.sin(inputs)) + loss.backward() + optimizer.step() + if epoch % 5 == 0: + logger.info(f"QAT Epoch {epoch}, Loss: {loss.item()}") + logger.info("QAT training loop completed successfully") + prepared_model = move_exported_model_to_eval(prepared_model) + + quantized_model = convert_pt2e(prepared_model) + logger.info("QAT model quantized successfully") + + logger.info("Evaluating quantized model") + test_inputs = torch.randn(100, 1).clamp(-1, 1) + test_outputs = torch.sin(test_inputs) + evaluate_model(quantized_model, test_inputs, test_outputs) diff --git a/backends/arm/test/misc/test_quant_custom_meta.py b/backends/arm/test/misc/test_quant_custom_meta.py new file mode 100644 index 00000000000..90948c03829 --- /dev/null +++ b/backends/arm/test/misc/test_quant_custom_meta.py @@ -0,0 +1,105 @@ +# Copyright 2025 Arm Limited and/or its affiliates. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT + + +class AddSigmoidMul(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, x, y): + return self.sigmoid(x + y) * x + + +@pytest.mark.parametrize("fp_extension", [True, False]) +def test_qdq_squeezed_fp_op(fp_extension: bool): + """Test that a float operation surrounded by quantize-dequantize pairs + is correctly handled by the partitioner and the TOSA backend. + Pattern: + q -> dq -> add -> q -> dq -> sigmoid -> q -> dq -> mul -> dq -> q + |_____unquantized_____| + """ + aten_op = "torch.ops.aten.add.Tensor" + exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor" + module = AddSigmoidMul() + x = torch.randn(2, 3, 4) + y = torch.randn(2, 3, 4) + pipeline = TosaPipelineINT( + module=module, + test_data=(x, y), + aten_op=aten_op, + exir_op=exir_op, + tosa_extensions=["FP"] if fp_extension else None, + ) + pipeline.quantizer.set_module_type(torch.nn.Sigmoid, None) # type: ignore + + if not fp_extension: + # In case we don't have the FP extension, the unquantized part of the + # graph should not be delegated to the Arm backend. Modify the op count + # checks to reflect this behavior. + pipeline.change_args( + "check_count.exir", + { + "torch.ops.higher_order.executorch_call_delegate": 2, + "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3, + }, + ) + pipeline.run() + + +class MulAddSigmoidConv(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.sigmoid = torch.nn.Sigmoid() + self.conv = torch.nn.Conv1d(3, 3, 1) + + def forward(self, x, y): + return self.conv(self.sigmoid(x + y * x)) + + +@pytest.mark.parametrize("fp_extension", [True, False]) +def test_quantized_to_float_transition(fp_extension: bool): + """Test that a model executing quantized ops followed by float ops + is correctly handled by the partitioner and the TOSA backend. + Pattern: + q -> dq -> mul -> q -> dq -> add -> q -> dq -> sigmoid -> conv + |___unquantized___| + """ + aten_op = "torch.ops.aten.add.Tensor" + exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor" + module = MulAddSigmoidConv() + x = torch.randn(2, 3, 4) + y = torch.randn(2, 3, 4) + pipeline = TosaPipelineINT( + module=module, + test_data=(x, y), + aten_op=aten_op, + exir_op=exir_op, + tosa_extensions=["FP"] if fp_extension else None, + ) + if not fp_extension: + # In case we don't have the FP extension, the unquantized part of the + # graph should not be delegated to the Arm backend. Modify the op count + # checks to reflect this behavior. 
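+        # Concretely: only the quantized mul/add region is expected to be delegated +        # (one call_delegate), while sigmoid and conv stay in the edge graph together +        # with the q/dq nodes left at the quantized-to-float boundary.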
+ pipeline.change_args( + "check_count.exir", + { + "torch.ops.higher_order.executorch_call_delegate": 1, + "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1, + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + }, + ) + pipeline.quantizer.set_module_type(torch.nn.Sigmoid, None) # type: ignore + pipeline.quantizer.set_module_type(torch.nn.Conv1d, None) # type: ignore + + pipeline.run() diff --git a/backends/arm/test/misc/test_save_exported_model.py b/backends/arm/test/misc/test_save_exported_model.py new file mode 100644 index 00000000000..791294cdb54 --- /dev/null +++ b/backends/arm/test/misc/test_save_exported_model.py @@ -0,0 +1,64 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os + +import torch +from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo +from executorch.backends.arm.quantizer import ( + get_symmetric_quantization_config, + TOSAQuantizer, +) +from executorch.backends.arm.tosa import TosaSpecification +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e + + +class SimpleModule(torch.nn.Module): + example_inputs = (torch.randn(1, 10),) + + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 10) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + +def test_save_load_exported_int_model(): + module = SimpleModule().eval() + example_inputs = module.example_inputs + exported_module = torch.export.export(module, example_inputs) + + # Set up quantizer + quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT")) + quantizer.set_global(get_symmetric_quantization_config()) + # Quantize model + prepared_module = prepare_pt2e(exported_module.module(), quantizer) + prepared_module(*example_inputs) + quantized_module = convert_pt2e(prepared_module) + quantized_exported_module = torch.export.export(quantized_module, example_inputs) + + base_path = "arm_test/misc/" + if not os.path.exists(base_path): + os.makedirs(base_path) + file_path = base_path + "exported_module.pt2" + # Verify that we can save the model + torch.export.save(quantized_exported_module, file_path) + + # Verify that we can load the model back + loaded_model = torch.export.load( + file_path + ) # nosec B614 - loads trusted test artifact + for original_node, loaded_node in zip( + quantized_exported_module.graph.nodes, loaded_model.graph.nodes + ): + # Verify that the custom metadata is preserved after save/load + assert original_node.meta.get("custom", {}) == loaded_node.meta.get( + "custom", {} + ) + if original_node.target == torch.ops.aten.linear.default: + assert ArmAnnotationInfo.CUSTOM_META_KEY in original_node.meta.get( + "custom", {} + ) diff --git a/backends/arm/test/misc/test_tosa_dialect_conv2d.py b/backends/arm/test/misc/test_tosa_dialect_conv2d.py new file mode 100644 index 00000000000..3496ca0d5b6 --- /dev/null +++ b/backends/arm/test/misc/test_tosa_dialect_conv2d.py @@ -0,0 +1,250 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
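+# These tests invoke the TOSA dialect CONV2D op directly on fake tensors inside a +# TosaLoweringContext, checking the propagated output dtype and shape for valid inputs +# and the TosaValueError raised for dtype combinations that the active TOSA profile +# (INT or FP) does not allow. The nine positional arguments appear to mirror +# aten.convolution (input, weight, bias, stride, padding, dilation, transposed, +# output_padding, groups), with what looks like a four-element TOSA-style padding.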
+ + +import executorch.backends.arm.tosa.dialect # noqa: unused +import pytest +import torch +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch._subclasses.fake_tensor import FakeTensorMode + + +def test_conv2d_tosa_INT(): + sample_inputs = [ + ( + ( + torch.randint(-128, 127, (1, 8, 20, 20), dtype=torch.int8), + torch.randint(-127, 127, (8, 2, 5, 5), dtype=torch.int8), + torch.randint(-(2**31), 2**31, (8,), dtype=torch.int32), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 4, + ), + (1, 8, 20, 20), + torch.int32, + ), + ( + ( + torch.randint(-128, 127, (1, 8, 20, 20), dtype=torch.int8), + torch.randint(-127, 127, (4, 2, 5, 5), dtype=torch.int8), + None, + [2, 2], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 4, + ), + (1, 4, 10, 10), + torch.int32, + ), + ] + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.0+INT") + ), FakeTensorMode() as mode: + for sample_input, expected_output_shape, expected_output_type in sample_inputs: + output = exir_ops.backend.tosa.CONV2D.default( + *tuple( + [ + mode.from_tensor(i) if isinstance(i, torch.Tensor) else i + for i in sample_input + ] + ) + ) + assert ( + output.dtype == expected_output_type + ), f"Expected output dtype {expected_output_type} but got {output.dtype}" + assert ( + tuple(output.shape) == expected_output_shape + ), f"Expected output shape {expected_output_shape} but got {tuple(output.shape)}" + + +def test_conv2d_invalid_tosa_INT(): + tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT") + sample_inputs = [ + ( + ( + torch.randn((1, 8, 20, 20), dtype=torch.float32), + torch.randn((8, 2, 5, 5), dtype=torch.float32), + torch.randn((8,), dtype=torch.float32), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 4, + ), + TosaValueError, + f"doesn't support {torch.float32} but found input type {torch.float32}", + ), + ( + ( + torch.randint(-128, 127, (1, 8, 20, 20), dtype=torch.int8), + torch.randn((8, 2, 5, 5), dtype=torch.float32), + torch.randn((8,), dtype=torch.float32), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 4, + ), + TosaValueError, + f"only supports {torch.int8} weights for {torch.int8} input but found {torch.float32}", + ), + ( + ( + torch.randint(-128, 127, (1, 8, 20, 20), dtype=torch.int8), + torch.randint(-127, 127, (8, 2, 5, 5), dtype=torch.int8), + torch.randn((8,), dtype=torch.float32), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 4, + ), + TosaValueError, + f"only supports {torch.int32} bias for {torch.int8} input but found {torch.float32}", + ), + ] + + with TosaLoweringContext(tosa_spec), FakeTensorMode() as mode: + for sample_input, expected_error, expected_error_str in sample_inputs: + with pytest.raises(expected_error, match=expected_error_str): + exir_ops.backend.tosa.CONV2D.default( + *tuple( + [ + mode.from_tensor(i) if isinstance(i, torch.Tensor) else i + for i in sample_input + ] + ) + ) + + +def test_conv2d_tosa_FP(): + sample_inputs = [ + ( + ( + torch.randn((1, 8, 20, 20), dtype=torch.float32), + torch.randn((8, 2, 5, 5), dtype=torch.float32), + torch.randn((8,), dtype=torch.float32), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 4, + ), + (1, 8, 20, 20), + torch.float32, + ), + ( + ( + torch.randn((1, 8, 20, 20), dtype=torch.float32), + torch.randn((4, 2, 5, 5), dtype=torch.float32), + None, + [2, 2], + [2, 2, 2, 2], + [1, 1], + 
False, + [0, 0], + 4, + ), + (1, 4, 10, 10), + torch.float32, + ), + ] + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.0+FP") + ), FakeTensorMode() as mode: + for sample_input, expected_output_shape, expected_output_type in sample_inputs: + output = exir_ops.backend.tosa.CONV2D.default( + *tuple( + [ + mode.from_tensor(i) if isinstance(i, torch.Tensor) else i + for i in sample_input + ] + ) + ) + assert ( + output.dtype == expected_output_type + ), f"Expected output dtype {expected_output_type} but got {output.dtype}" + assert ( + tuple(output.shape) == expected_output_shape + ), f"Expected output shape {expected_output_shape} but got {tuple(output.shape)}" + + +def test_conv2d_invalid_tosa_FP(): + tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+FP") + sample_inputs = [ + ( + ( + torch.randint(-127, 127, (1, 8, 20, 20), dtype=torch.int8), + torch.randn((8, 2, 5, 5), dtype=torch.float32), + torch.randn((8,), dtype=torch.float32), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 4, + ), + TosaValueError, + f"doesn't support {torch.int8} but found input type {torch.int8}", + ), + ( + ( + torch.randn((1, 8, 20, 20), dtype=torch.float32), + torch.randn((8, 2, 5, 5), dtype=torch.float16), + torch.randn((8,), dtype=torch.float32), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 4, + ), + TosaValueError, + f"requires weights {torch.float16} to be of the same type as input {torch.float32}", + ), + ( + ( + torch.randn((1, 8, 20, 20), dtype=torch.float32), + torch.randn((8, 2, 5, 5), dtype=torch.float32), + torch.randn((8,), dtype=torch.float16), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 4, + ), + TosaValueError, + f"requires bias {torch.float16} to be of the same type as input {torch.float32}", + ), + ] + + with TosaLoweringContext(tosa_spec), FakeTensorMode() as mode: + for sample_input, expected_error, expected_error_str in sample_inputs: + with pytest.raises(expected_error, match=expected_error_str): + exir_ops.backend.tosa.CONV2D.default( + *tuple( + [ + mode.from_tensor(i) if isinstance(i, torch.Tensor) else i + for i in sample_input + ] + ) + ) diff --git a/backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py b/backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py new file mode 100644 index 00000000000..8b50df20830 --- /dev/null +++ b/backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py @@ -0,0 +1,257 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
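+# Same approach as the CONV2D dialect tests, but for DEPTHWISE_CONV2D: the op is +# invoked on fake tensors under a TosaLoweringContext to check the propagated output +# dtype and shape, and the TosaValueError raised for disallowed dtype combinations. +# The samples use the depthwise weight layout [H, m_length, W, in_channels] with +# out_channels = in_channels * m_length, e.g. 8 input channels and m_length 2 give +# the 16 output channels matching the (16,) bias and (1, 16, 20, 20) expected output.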
+ + +import executorch.backends.arm.tosa.dialect # noqa: unused +import pytest +import torch +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch._subclasses.fake_tensor import FakeTensorMode + + +def test_depthwise_conv2d_tosa_INT(): + sample_inputs = [ + ( + ( + torch.randint(-128, 127, (1, 8, 20, 20), dtype=torch.int8), + # weight shape is [H, m_length, W, in_channels], where m_length = out_channels // in_channels + torch.randint(-127, 127, (5, 2, 5, 8), dtype=torch.int8), + torch.randint(-(2**31), 2**31, (16,), dtype=torch.int32), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 8, + ), + (1, 16, 20, 20), + torch.int32, + ), + ( + ( + torch.randint(-128, 127, (1, 8, 20, 20), dtype=torch.int8), + # weight shape is [H, m_length, W, in_channels], where m_length = out_channels // in_channels + torch.randint(-127, 127, (5, 4, 5, 8), dtype=torch.int8), + None, + [2, 2], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 8, + ), + (1, 32, 10, 10), + torch.int32, + ), + ] + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.0+INT") + ), FakeTensorMode() as mode: + for sample_input, expected_output_shape, expected_output_type in sample_inputs: + output = exir_ops.backend.tosa.DEPTHWISE_CONV2D.default( + *tuple( + [ + mode.from_tensor(i) if isinstance(i, torch.Tensor) else i + for i in sample_input + ] + ) + ) + assert ( + output.dtype == expected_output_type + ), f"Expected output dtype {expected_output_type} but got {output.dtype}" + assert ( + tuple(output.shape) == expected_output_shape + ), f"Expected output shape {expected_output_shape} but got {tuple(output.shape)}" + + +def test_depthwise_conv2d_invalid_tosa_INT(): + tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT") + sample_inputs = [ + ( + ( + torch.randn((1, 8, 20, 20), dtype=torch.float32), + # weight shape is [H, m_length, W, in_channels], where m_length = out_channels // in_channels + torch.randn((5, 2, 5, 8), dtype=torch.float32), + torch.randn((16,), dtype=torch.float32), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 8, + ), + TosaValueError, + f"doesn't support {torch.float32} but found input type {torch.float32}", + ), + ( + ( + torch.randint(-128, 127, (1, 8, 20, 20), dtype=torch.int8), + # weight shape is [H, m_length, W, in_channels], where m_length = out_channels // in_channels + torch.randn((5, 2, 5, 8), dtype=torch.float32), + torch.randn((16,), dtype=torch.float32), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 8, + ), + TosaValueError, + f"only supports {torch.int8} weights for {torch.int8} input but found {torch.float32}", + ), + ( + ( + torch.randint(-128, 127, (1, 8, 20, 20), dtype=torch.int8), + torch.randint(-127, 127, (5, 2, 5, 8), dtype=torch.int8), + torch.randn((16,), dtype=torch.float32), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 8, + ), + TosaValueError, + f"only supports {torch.int32} bias for {torch.int8} input but found {torch.float32}", + ), + ] + + with TosaLoweringContext(tosa_spec), FakeTensorMode() as mode: + for sample_input, expected_error, expected_error_str in sample_inputs: + with pytest.raises(expected_error, match=expected_error_str): + exir_ops.backend.tosa.DEPTHWISE_CONV2D.default( + *tuple( + [ + mode.from_tensor(i) if isinstance(i, torch.Tensor) else i + for i in sample_input + ] + ) + ) + + +def 
test_depthwise_conv2d_tosa_FP(): + sample_inputs = [ + ( + ( + torch.randn((1, 8, 20, 20), dtype=torch.float32), + # weight shape is [H, m_length, W, in_channels], where m_length = out_channels // in_channels + torch.randn((5, 2, 5, 8), dtype=torch.float32), + torch.randn((16,), dtype=torch.float32), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 8, + ), + (1, 16, 20, 20), + torch.float32, + ), + ( + ( + torch.randn((1, 8, 20, 20), dtype=torch.float32), + # weight shape is [H, m_length, W, in_channels], where m_length = out_channels // in_channels + torch.randn((5, 4, 5, 8), dtype=torch.float32), + None, + [2, 2], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 8, + ), + (1, 32, 10, 10), + torch.float32, + ), + ] + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.0+FP") + ), FakeTensorMode() as mode: + for sample_input, expected_output_shape, expected_output_type in sample_inputs: + output = exir_ops.backend.tosa.DEPTHWISE_CONV2D.default( + *tuple( + [ + mode.from_tensor(i) if isinstance(i, torch.Tensor) else i + for i in sample_input + ] + ) + ) + assert ( + output.dtype == expected_output_type + ), f"Expected output dtype {expected_output_type} but got {output.dtype}" + assert ( + tuple(output.shape) == expected_output_shape + ), f"Expected output shape {expected_output_shape} but got {tuple(output.shape)}" + + +def test_depthwise_conv2d_invalid_tosa_FP(): + + tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+FP") + + sample_inputs = [ + ( + ( + torch.randint(-127, 127, (1, 8, 20, 20), dtype=torch.int8), + torch.randn((5, 2, 5, 8), dtype=torch.float32), + torch.randn((16,), dtype=torch.float32), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 8, + ), + TosaValueError, + f"doesn't support {torch.int8} but found input type {torch.int8}", + ), + ( + ( + torch.randn((1, 8, 20, 20), dtype=torch.float32), + torch.randn((5, 2, 5, 8), dtype=torch.float16), + torch.randn((16,), dtype=torch.float32), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 8, + ), + TosaValueError, + f"requires weights {torch.float16} to be of the same type as input {torch.float32}", + ), + ( + ( + torch.randn((1, 8, 20, 20), dtype=torch.float32), + torch.randn((5, 2, 5, 8), dtype=torch.float32), + torch.randn((16,), dtype=torch.float16), + [1, 1], + [2, 2, 2, 2], + [1, 1], + False, + [0, 0], + 8, + ), + TosaValueError, + f"requires bias {torch.float16} to be of the same type as input {torch.float32}", + ), + ] + with TosaLoweringContext(tosa_spec), FakeTensorMode() as mode: + for sample_input, expected_error, expected_error_str in sample_inputs: + with pytest.raises(expected_error, match=expected_error_str): + exir_ops.backend.tosa.DEPTHWISE_CONV2D.default( + *tuple( + [ + mode.from_tensor(i) if isinstance(i, torch.Tensor) else i + for i in sample_input + ] + ) + ) diff --git a/backends/arm/test/misc/test_tosa_spec.py b/backends/arm/test/misc/test_tosa_spec.py index 968512f54c6..91a5bc19728 100644 --- a/backends/arm/test/misc/test_tosa_spec.py +++ b/backends/arm/test/misc/test_tosa_spec.py @@ -6,12 +6,11 @@ import unittest from executorch.backends.arm.tosa.specification import ( - get_tosa_spec, Tosa_1_00, TosaSpecification, + TosaSpecMapping, ) -from executorch.exir.backend.compile_spec_schema import CompileSpec from parameterized import parameterized # type: ignore[import-untyped] test_valid_strings = [ @@ -43,14 +42,6 @@ "TOSA-1.0.0+BF16+fft+int4+cf+INT", ] -test_compile_specs = [ - ([CompileSpec("tosa_spec", "TOSA-1.0.0+INT".encode())],), -]
-test_compile_specs_no_version = [ - ([CompileSpec("other_key", "some_value".encode())],), -] - class TestTosaSpecification(unittest.TestCase): """Tests the TOSA specification class""" @@ -74,21 +65,105 @@ def test_invalid_version_strings(self, version_string: str): assert tosa_spec is None - @parameterized.expand(test_compile_specs) # type: ignore[misc] - def test_create_from_compilespec(self, compile_specs: list[CompileSpec]): - tosa_spec = get_tosa_spec(compile_specs) - assert isinstance(tosa_spec, TosaSpecification) - - @parameterized.expand(test_compile_specs_no_version) # type: ignore[misc] - def test_create_from_invalid_compilespec(self, compile_specs: list[CompileSpec]): - tosa_spec = None - with self.assertRaises(ValueError): - tosa_spec = get_tosa_spec(compile_specs) - - assert tosa_spec is None - @parameterized.expand(test_valid_strings) def test_correct_string_representation(self, version_string: str): tosa_spec = TosaSpecification.create_from_string(version_string) assert isinstance(tosa_spec, Tosa_1_00) assert f"{tosa_spec}" == version_string + + +class TestTosaSpecMapping(unittest.TestCase): + """Tests the TosaSpecMapping class""" + + def test_mapping(self): + mapping = TosaSpecMapping() + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "A") + # check that the mapping is correct + vals = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT")) + + assert vals == ["A"] + assert len(vals) == 1 + + def test_mapping_multiple(self): + mapping = TosaSpecMapping() + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "A") + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "B") + # check that the mapping is correct + vals = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT")) + + assert vals == ["A", "B"] + assert len(vals) == 2 + + def test_mapping_different_profiles(self): + mapping = TosaSpecMapping() + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "A") + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+FP"), "B") + # check that the mapping is correct + vals_int = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT")) + vals_fp = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+FP")) + + assert vals_int == ["A"] + assert vals_fp == ["B"] + assert len(vals_int) == 1 + assert len(vals_fp) == 1 + + def test_mapping_different_profiles_combined_consumer(self): + mapping = TosaSpecMapping() + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "A") + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+FP"), "B") + # check that the mapping is correct + combined_vals = mapping.get( + TosaSpecification.create_from_string("TOSA-1.0+INT+FP") + ) + + assert "A" in combined_vals + assert "B" in combined_vals + assert len(combined_vals) == 2 + + def test_mapping_no_spec(self): + mapping = TosaSpecMapping() + with self.assertRaises(KeyError): + mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT")) + + def test_mapping_no_values_for_spec(self): + mapping = TosaSpecMapping() + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+FP"), "A") + with self.assertRaises(KeyError): + mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT")) + + def test_spec_with_different_profiles(self): + mapping = TosaSpecMapping() + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+FP"), "A") + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "B") + # check that the mapping is correct + vals_int = 
mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT")) + vals_fp = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+FP")) + vals_int_fp = mapping.get( + TosaSpecification.create_from_string("TOSA-1.0+INT+FP") + ) + + assert vals_fp == ["A"] + assert vals_int == ["B"] + assert len(vals_int) == 1 + assert len(vals_fp) == 1 + assert len(vals_int_fp) == 2 + + def test_combined_profiles(self): + mapping = TosaSpecMapping() + with self.assertRaises(ValueError): + # Don't allow multiple profiles in a single spec + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT+FP"), "A") + + def test_spec_add_with_extension(self): + mapping = TosaSpecMapping() + with self.assertRaises(ValueError): + mapping.add( + TosaSpecification.create_from_string("TOSA-1.0.0+INT+int16"), "A" + ) + + def test_spec_non_canonical_key(self): + mapping = TosaSpecMapping() + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "A") + + val = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT+u55")) + assert val == ["A"] diff --git a/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py b/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py index 0e99f3f5bfa..938732fa91a 100644 --- a/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py +++ b/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py @@ -4,9 +4,8 @@ # LICENSE file in the root directory of this source tree. -import unittest +from typing import Tuple -import pytest import torch from executorch.backends.arm._passes import ( ConvertInt64ConstOpsToInt32Pass, @@ -18,26 +17,41 @@ from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import ( CLIP_text_encoder_config, ) -from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.backends.arm.test.tester.test_pipeline import ( + TosaPipelineFP, + TosaPipelineINT, + VgfPipeline, +) from transformers import CLIPTextModelWithProjection +input_t = Tuple[torch.Tensor] + -class TestCLIPTextModelWithProjection(unittest.TestCase): +class TestCLIPTextModelWithProjection: """ Test class of CLIPTextModelWithProjection. CLIPTextModelWithProjection is one of the text_encoder used by Stable Diffusion 3.5 Medium """ - # Adjust nbr below as we increase op support. Note: most of the delegates - # calls are directly consecutive to each other in the .pte. The reason - # for that is some assert ops are removed by passes in the - # .to_executorch step, i.e. after Arm partitioner. - ops_after_partitioner = { + # Adjust nbr below as we increase op support. 
+ ops_after_partitioner_FP = { + "executorch_exir_dialects_edge__ops_aten_argmax_default": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2, + "torch.ops.higher_order.executorch_call_delegate": 2, + } + + ops_after_partitioner_INT = { "executorch_exir_dialects_edge__ops_aten_argmax_default": 1, + "executorch_exir_dialects_edge__ops_aten_index_select_default": 1, + "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor": 1, + "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1, "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2, "torch.ops.higher_order.executorch_call_delegate": 2, } + ops_after_partitioner_vgf_quantize = ops_after_partitioner_FP + ops_after_partitioner_vgf_no_quantize = ops_after_partitioner_FP + def _prepare_inputs( self, batch_size=12, @@ -61,46 +75,94 @@ def prepare_model_and_inputs(self): return text_encoder_model, text_encoder_model_inputs - def test_CLIPTextModelWithProjection_tosa_FP(self): - text_encoder_model, text_encoder_model_inputs = self.prepare_model_and_inputs() - with torch.no_grad(): - ( - ArmTester( - text_encoder_model, - example_inputs=text_encoder_model_inputs, - compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"), - transform_passes=[ - ConvertInt64ConstOpsToInt32Pass(), - ConvertInt64OutputOpsToInt32Pass(), - InsertInt32CastsAfterInt64PlaceholdersPass(), - ], - ) - .export() - .to_edge_transform_and_lower() - .dump_operator_distribution() - .check_count(self.ops_after_partitioner) - .to_executorch() - .run_method_and_compare_outputs( - inputs=text_encoder_model_inputs, - ) - ) - - @pytest.mark.xfail(raises=AssertionError, reason="Output difference.") - def test_CLIPTextModelWithProjection_tosa_INT(self): - text_encoder_model, text_encoder_model_inputs = self.prepare_model_and_inputs() - with torch.no_grad(): - ( - ArmTester( - text_encoder_model, - example_inputs=text_encoder_model_inputs, - compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"), - ) - .quantize() - .export() - .to_edge_transform_and_lower() - .dump_operator_distribution() - .to_executorch() - .run_method_and_compare_outputs( - inputs=text_encoder_model_inputs, - ) - ) + +def test_CLIPTextModelWithProjection_tosa_FP(): + text_encoder_model, text_encoder_model_inputs = ( + TestCLIPTextModelWithProjection().prepare_model_and_inputs() + ) + with torch.no_grad(): + pipeline = TosaPipelineFP[input_t]( + text_encoder_model, + text_encoder_model_inputs, + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + transform_passes=[ + ConvertInt64ConstOpsToInt32Pass(), + ConvertInt64OutputOpsToInt32Pass(), + InsertInt32CastsAfterInt64PlaceholdersPass(), + ], + ) + pipeline.change_args( + "check_count.exir", TestCLIPTextModelWithProjection.ops_after_partitioner_FP + ) + pipeline.run() + + +def test_CLIPTextModelWithProjection_tosa_INT(): + text_encoder_model, text_encoder_model_inputs = ( + TestCLIPTextModelWithProjection().prepare_model_and_inputs() + ) + with torch.no_grad(): + pipeline = TosaPipelineINT[input_t]( + text_encoder_model, + text_encoder_model_inputs, + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + atol=0.8, + ) + pipeline.change_args( + "check_count.exir", + TestCLIPTextModelWithProjection.ops_after_partitioner_INT, + ) + pipeline.run() + + +@common.SkipIfNoModelConverter +def test_CLIPTextModelWithProjection_vgf_no_quant(): + text_encoder_model, text_encoder_model_inputs = ( + TestCLIPTextModelWithProjection().prepare_model_and_inputs() 
+ ) + with torch.no_grad(): + pipeline = VgfPipeline[input_t]( + text_encoder_model, + text_encoder_model_inputs, + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + atol=4, + transform_passes=[ + ConvertInt64ConstOpsToInt32Pass(), + ConvertInt64OutputOpsToInt32Pass(), + InsertInt32CastsAfterInt64PlaceholdersPass(), + ], + quantize=False, + ) + pipeline.change_args( + "check_count.exir", + TestCLIPTextModelWithProjection.ops_after_partitioner_vgf_no_quantize, + ) + pipeline.run() + + +@common.SkipIfNoModelConverter +def test_CLIPTextModelWithProjection_vgf_quant(): + text_encoder_model, text_encoder_model_inputs = ( + TestCLIPTextModelWithProjection().prepare_model_and_inputs() + ) + with torch.no_grad(): + pipeline = VgfPipeline[input_t]( + text_encoder_model, + text_encoder_model_inputs, + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + atol=0.8, + quantize=True, + ) + pipeline.change_args( + "check_count.exir", + TestCLIPTextModelWithProjection.ops_after_partitioner_vgf_quantize, + ) + pipeline.run() diff --git a/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py b/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py index f9d814d044b..cd76c691c72 100644 --- a/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py +++ b/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py @@ -4,19 +4,27 @@ # LICENSE file in the root directory of this source tree. -import unittest +from typing import Tuple import torch -from diffusers.models.transformers import SD3Transformer2DModel +from diffusers.models.transformers import ( # type: ignore[import-not-found] + SD3Transformer2DModel, +) from executorch.backends.arm.test import common from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import ( SD3Transformer2DModel_init_dict, ) -from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.backends.arm.test.tester.test_pipeline import ( + TosaPipelineFP, + TosaPipelineINT, + VgfPipeline, +) + +input_t4 = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] -class TestSD3Transformer2DModel(unittest.TestCase): +class TestSD3Transformer2DModel: """ Test class of SD3Transformer2DModel. SD3Transformer2DModel is the transformer model used by Stable Diffusion 3.5 Medium @@ -24,20 +32,23 @@ class TestSD3Transformer2DModel(unittest.TestCase): # Adjust nbr below as we increase op support. 
ops_after_partitioner_FP = { - "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 1, - "executorch_exir_dialects_edge__ops_aten_view_copy_default": 2, "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1, "torch.ops.higher_order.executorch_call_delegate": 1, } ops_after_partitioner_INT = { - "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, - "executorch_exir_dialects_edge__ops_aten_view_copy_default": 2, "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2, - "torch.ops.higher_order.executorch_call_delegate": 2, + "torch.ops.higher_order.executorch_call_delegate": 3, + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, } + ops_after_partitioner_vgf_quantize = { + "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1, + "torch.ops.higher_order.executorch_call_delegate": 1, + } + ops_after_partitioner_vgf_no_quantize = ops_after_partitioner_FP + def _prepare_inputs( self, batch_size=2, @@ -93,48 +104,90 @@ def forward(self, *args, **kwargs): return sd35_transformer2D_model, sd35_transformer2D_model_inputs - def test_SD3Transformer2DModel_tosa_FP(self): - sd35_transformer2D_model, sd35_transformer2D_model_inputs = ( - self.prepare_model_and_inputs() - ) - with torch.no_grad(): - ( - ArmTester( - sd35_transformer2D_model, - example_inputs=sd35_transformer2D_model_inputs, - compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"), - ) - .export() - .to_edge_transform_and_lower() - .check_count(self.ops_after_partitioner_FP) - .to_executorch() - .run_method_and_compare_outputs( - inputs=sd35_transformer2D_model_inputs, - rtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT - atol=4.0, - ) - ) - def test_SD3Transformer2DModel_tosa_INT(self): - sd35_transformer2D_model, sd35_transformer2D_model_inputs = ( - self.prepare_model_and_inputs() +def test_SD3Transformer2DModel_tosa_FP(): + sd35_transformer2D_model, sd35_transformer2D_model_inputs = ( + TestSD3Transformer2DModel().prepare_model_and_inputs() + ) + with torch.no_grad(): + pipeline = TosaPipelineFP[input_t4]( + sd35_transformer2D_model, + sd35_transformer2D_model_inputs, + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + rtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT + atol=4.0, ) - with torch.no_grad(): - ( - ArmTester( - sd35_transformer2D_model, - example_inputs=sd35_transformer2D_model_inputs, - compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"), - ) - .quantize() - .export() - .to_edge_transform_and_lower() - .check_count(self.ops_after_partitioner_INT) - .to_executorch() - .run_method_and_compare_outputs( - inputs=sd35_transformer2D_model_inputs, - qtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT - rtol=1.0, - atol=4.0, - ) - ) + pipeline.change_args( + "check_count.exir", TestSD3Transformer2DModel.ops_after_partitioner_FP + ) + pipeline.run() + + +def test_SD3Transformer2DModel_tosa_INT(): + sd35_transformer2D_model, sd35_transformer2D_model_inputs = ( + TestSD3Transformer2DModel().prepare_model_and_inputs() + ) + with torch.no_grad(): + pipeline = TosaPipelineINT[input_t4]( + sd35_transformer2D_model, + sd35_transformer2D_model_inputs, + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + qtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of 
SD3Transformer2DModel with FP and INT + rtol=1.0, + atol=4.0, + ) + pipeline.change_args( + "check_count.exir", TestSD3Transformer2DModel.ops_after_partitioner_INT + ) + pipeline.run() + + +@common.SkipIfNoModelConverter +def test_SD3Transformer2DModel_vgf_no_quant(): + sd35_transformer2D_model, sd35_transformer2D_model_inputs = ( + TestSD3Transformer2DModel().prepare_model_and_inputs() + ) + with torch.no_grad(): + pipeline = VgfPipeline[input_t4]( + sd35_transformer2D_model, + sd35_transformer2D_model_inputs, + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + rtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT, + atol=4.0, + quantize=False, + ) + pipeline.change_args( + "check_count.exir", + TestSD3Transformer2DModel.ops_after_partitioner_vgf_no_quantize, + ) + pipeline.run() + + +@common.SkipIfNoModelConverter +def test_SD3Transformer2DModel_vgf_quant(): + sd35_transformer2D_model, sd35_transformer2D_model_inputs = ( + TestSD3Transformer2DModel().prepare_model_and_inputs() + ) + with torch.no_grad(): + pipeline = VgfPipeline[input_t4]( + sd35_transformer2D_model, + sd35_transformer2D_model_inputs, + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + qtol=1.0, + rtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT, + atol=4.0, + quantize=True, + ) + pipeline.change_args( + "check_count.exir", + TestSD3Transformer2DModel.ops_after_partitioner_vgf_quantize, + ) + pipeline.run() diff --git a/backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py b/backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py index 22a47042eb1..7ab7f86f449 100644 --- a/backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py +++ b/backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. -import unittest +from typing import Tuple import torch from executorch.backends.arm._passes import ( @@ -17,11 +17,17 @@ from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import ( T5_encoder_config, ) -from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.backends.arm.test.tester.test_pipeline import ( + TosaPipelineFP, + TosaPipelineINT, + VgfPipeline, +) from transformers import T5EncoderModel +input_t = Tuple[torch.Tensor] + -class TestT5EncoderModel(unittest.TestCase): +class TestT5EncoderModel: """ Test class of T5EncoderModel. 
T5EncoderModel is one of the text_encoder used by Stable Diffusion 3.5 Medium @@ -38,6 +44,13 @@ class TestT5EncoderModel(unittest.TestCase): "torch.ops.higher_order.executorch_call_delegate": 3, } + ops_after_partitioner_vgf_quantize = { + "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1, + "torch.ops.higher_order.executorch_call_delegate": 1, + } + + ops_after_partitioner_vgf_no_quantize = ops_after_partitioner_vgf_quantize + def _prepare_inputs( self, batch_size=12, @@ -61,46 +74,88 @@ def prepare_model_and_inputs(self): return t5_encoder_model, t5_encoder_model_inputs - def test_T5EncoderModel_tosa_FP(self): - t5_encoder_model, t5_encoder_model_inputs = self.prepare_model_and_inputs() - with torch.no_grad(): - ( - ArmTester( - t5_encoder_model, - example_inputs=t5_encoder_model_inputs, - compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"), - transform_passes=[ - ConvertInt64ConstOpsToInt32Pass(), - ConvertInt64OutputOpsToInt32Pass(), - InsertInt32CastsAfterInt64PlaceholdersPass(), - ], - ) - .export() - .to_edge_transform_and_lower() - .dump_operator_distribution() - .check_count(self.ops_after_partitioner_FP) - .to_executorch() - .run_method_and_compare_outputs( - inputs=t5_encoder_model_inputs, - ) - ) - - def test_T5EncoderModel_tosa_INT(self): - t5_encoder_model, t5_encoder_model_inputs = self.prepare_model_and_inputs() - with torch.no_grad(): - ( - ArmTester( - t5_encoder_model, - example_inputs=t5_encoder_model_inputs, - compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"), - ) - .quantize() - .export() - .to_edge_transform_and_lower() - .dump_operator_distribution() - .check_count(self.ops_after_partitioner_INT) - .to_executorch() - .run_method_and_compare_outputs( - inputs=t5_encoder_model_inputs, - ) - ) + +def test_T5EncoderModel_tosa_FP(): + t5_encoder_model, t5_encoder_model_inputs = ( + TestT5EncoderModel().prepare_model_and_inputs() + ) + with torch.no_grad(): + pipeline = TosaPipelineFP[input_t]( + t5_encoder_model, + t5_encoder_model_inputs, + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + transform_passes=[ + ConvertInt64ConstOpsToInt32Pass(), + ConvertInt64OutputOpsToInt32Pass(), + InsertInt32CastsAfterInt64PlaceholdersPass(), + ], + ) + pipeline.change_args( + "check_count.exir", TestT5EncoderModel.ops_after_partitioner_FP + ) + pipeline.run() + + +def test_T5EncoderModel_tosa_INT(): + t5_encoder_model, t5_encoder_model_inputs = ( + TestT5EncoderModel().prepare_model_and_inputs() + ) + with torch.no_grad(): + pipeline = TosaPipelineINT[input_t]( + t5_encoder_model, + t5_encoder_model_inputs, + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + ) + pipeline.change_args( + "check_count.exir", TestT5EncoderModel.ops_after_partitioner_INT + ) + pipeline.run() + + +@common.SkipIfNoModelConverter +def test_T5EncoderModel_vgf_no_quant(): + t5_encoder_model, t5_encoder_model_inputs = ( + TestT5EncoderModel().prepare_model_and_inputs() + ) + with torch.no_grad(): + pipeline = VgfPipeline[input_t]( + t5_encoder_model, + t5_encoder_model_inputs, + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + transform_passes=[ + ConvertInt64ConstOpsToInt32Pass(), + ConvertInt64OutputOpsToInt32Pass(), + InsertInt32CastsAfterInt64PlaceholdersPass(), + ], + quantize=False, + ) + pipeline.change_args( + "check_count.exir", TestT5EncoderModel.ops_after_partitioner_vgf_no_quantize + ) + pipeline.run() + + +@common.SkipIfNoModelConverter +def test_T5EncoderModel_vgf_quant(): + 
t5_encoder_model, t5_encoder_model_inputs = ( + TestT5EncoderModel().prepare_model_and_inputs() + ) + with torch.no_grad(): + pipeline = VgfPipeline[input_t]( + t5_encoder_model, + t5_encoder_model_inputs, + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + quantize=True, + ) + pipeline.change_args( + "check_count.exir", TestT5EncoderModel.ops_after_partitioner_vgf_quantize + ) + pipeline.run() diff --git a/backends/arm/test/models/stable_diffusion/test_vae_AutoencoderKL.py b/backends/arm/test/models/stable_diffusion/test_vae_AutoencoderKL.py index ab0f4892fb8..cb5f93f55d8 100644 --- a/backends/arm/test/models/stable_diffusion/test_vae_AutoencoderKL.py +++ b/backends/arm/test/models/stable_diffusion/test_vae_AutoencoderKL.py @@ -4,20 +4,30 @@ # LICENSE file in the root directory of this source tree. -import unittest +from typing import Tuple import torch -from diffusers.models.autoencoders import AutoencoderKL -from diffusers.utils.testing_utils import floats_tensor +from diffusers.models.autoencoders import ( # type: ignore[import-not-found] + AutoencoderKL, +) +from diffusers.utils.testing_utils import ( # type: ignore[import-not-found] + floats_tensor, +) from executorch.backends.arm.test import common from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import ( AutoencoderKL_config, ) -from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.backends.arm.test.tester.test_pipeline import ( + TosaPipelineFP, + TosaPipelineINT, + VgfPipeline, +) +input_t = Tuple[torch.Tensor] -class TestAutoencoderKL(unittest.TestCase): + +class TestAutoencoderKL: """ Test class of AutoencoderKL. AutoencoderKL is the encoder/decoder used by Stable Diffusion 3.5 Medium @@ -41,40 +51,68 @@ def forward(self, *args, **kwargs): return auto_encoder_model, auto_encoder_model_inputs - def test_AutoencoderKL_tosa_FP(self): - auto_encoder_model, auto_encoder_model_inputs = self.prepare_model_and_inputs() - with torch.no_grad(): - ( - ArmTester( - auto_encoder_model, - example_inputs=auto_encoder_model_inputs, - compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"), - ) - .export() - .to_edge_transform_and_lower() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .run_method_and_compare_outputs( - inputs=auto_encoder_model_inputs, - ) - ) - - def test_AutoencoderKL_tosa_INT(self): - auto_encoder_model, auto_encoder_model_inputs = self.prepare_model_and_inputs() - with torch.no_grad(): - ( - ArmTester( - auto_encoder_model, - example_inputs=auto_encoder_model_inputs, - compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"), - ) - .quantize() - .export() - .to_edge_transform_and_lower() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .run_method_and_compare_outputs( - inputs=auto_encoder_model_inputs, - atol=1.0, # TODO: MLETORCH-990 Reduce tolerance of vae(AutoencoderKL) with INT - ) - ) + +def test_AutoencoderKL_tosa_FP(): + auto_encoder_model, auto_encoder_model_inputs = ( + TestAutoencoderKL().prepare_model_and_inputs() + ) + with torch.no_grad(): + pipeline = TosaPipelineFP[input_t]( + auto_encoder_model, + auto_encoder_model_inputs, + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +def test_AutoencoderKL_tosa_INT(): + auto_encoder_model, auto_encoder_model_inputs = ( + TestAutoencoderKL().prepare_model_and_inputs() + ) + with torch.no_grad(): + pipeline = 
TosaPipelineINT[input_t]( + auto_encoder_model, + auto_encoder_model_inputs, + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + atol=1.0, # TODO: MLETORCH-990 Reduce tolerance of vae(AutoencoderKL) with INT + ) + pipeline.run() + + +@common.SkipIfNoModelConverter +def test_AutoencoderKL_vgf_no_quant(): + auto_encoder_model, auto_encoder_model_inputs = ( + TestAutoencoderKL().prepare_model_and_inputs() + ) + with torch.no_grad(): + pipeline = VgfPipeline[input_t]( + auto_encoder_model, + auto_encoder_model_inputs, + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + quantize=False, + ) + pipeline.run() + + +@common.SkipIfNoModelConverter +def test_AutoencoderKL_vgf_quant(): + auto_encoder_model, auto_encoder_model_inputs = ( + TestAutoencoderKL().prepare_model_and_inputs() + ) + with torch.no_grad(): + pipeline = VgfPipeline[input_t]( + auto_encoder_model, + auto_encoder_model_inputs, + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + atol=1.0, # TODO: MLETORCH-990 Reduce tolerance of vae(AutoencoderKL) with INT + quantize=True, + ) + pipeline.run() diff --git a/backends/arm/test/models/test_T5ForConditionalGeneration_arm.py b/backends/arm/test/models/test_T5ForConditionalGeneration_arm.py new file mode 100644 index 00000000000..d96eaae32db --- /dev/null +++ b/backends/arm/test/models/test_T5ForConditionalGeneration_arm.py @@ -0,0 +1,171 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Tuple + +import pytest +import torch +from executorch.backends.arm._passes import ( + ConvertInt64ConstOpsToInt32Pass, + ConvertInt64OutputOpsToInt32Pass, + InsertInt32CastsAfterInt64PlaceholdersPass, +) +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + TosaPipelineFP, + TosaPipelineINT, + VgfPipeline, +) + +from transformers import AutoTokenizer, T5ForConditionalGeneration + +input_t3 = Tuple[ + torch.LongTensor, torch.LongTensor, torch.LongTensor +] # (input_ids, attention_mask, decoder_input_ids) + + +class TestT5ForConditionalGeneration: + # Adjust nbr below as we increase op support. 
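+    # Expected operator counts after the partitioner stage: edge-dialect ops that
+    # remain outside the Arm delegate, plus the number of executorch_call_delegate
+    # submodules. These dicts are passed to the "check_count.exir" stage below.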
+    ops_after_partitioner_FP = {
+        "executorch_exir_dialects_edge__ops_aten_where_self": 2,
+        "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 5,
+        "torch.ops.higher_order.executorch_call_delegate": 2,
+    }
+
+    ops_after_partitioner_INT = {
+        "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 3,
+        "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 10,
+        "torch.ops.higher_order.executorch_call_delegate": 3,
+    }
+
+    ops_after_partitioner_vgf_no_quantize = {
+        "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 4,
+        "torch.ops.higher_order.executorch_call_delegate": 2,
+    }
+
+    ops_after_partitioner_vgf_quantize = ops_after_partitioner_vgf_no_quantize
+
+    def _prepare_inputs(
+        self,
+        prompt: str,
+    ):
+        tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
+        enc = tokenizer(prompt, return_tensors="pt")
+        input_ids = enc.input_ids  # (1, src_len)
+        attention_mask = enc.attention_mask  # (1, src_len)
+        # T5 uses the pad token as BOS / decoder start
+        bos_id = tokenizer.pad_token_id
+        decoder_input_ids = torch.tensor([[bos_id]], dtype=torch.long)  # (1, 1)
+        return input_ids, attention_mask, decoder_input_ids
+
+    def prepare_model_and_inputs(self, prompt):
+        class T5ForConditionalGenerationWrapper(T5ForConditionalGeneration):
+            def forward(self, input_ids, attention_mask, decoder_input_ids):
+                out = super().forward(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    decoder_input_ids=decoder_input_ids,
+                    use_cache=False,  # simpler, export-friendly
+                    return_dict=True,
+                )
+                return out.logits  # Tensor: (B, tgt_len=1, vocab)
+
+        model = T5ForConditionalGenerationWrapper.from_pretrained("google-t5/t5-small")
+        model.config.use_cache = False
+        inputs = self._prepare_inputs(prompt)
+
+        return model, inputs
+
+
+@pytest.mark.slow
+def test_T5ForConditionalGeneration_tosa_FP():
+    prompt = "summarize: studies have shown that owning a dog is good for you"
+    model, inputs = TestT5ForConditionalGeneration().prepare_model_and_inputs(prompt)
+    with torch.no_grad():
+        pipeline = TosaPipelineFP[input_t3](
+            model,
+            inputs,
+            aten_op=[],
+            exir_op=[],
+            use_to_edge_transform_and_lower=True,
+            transform_passes=[
+                ConvertInt64ConstOpsToInt32Pass(),
+                ConvertInt64OutputOpsToInt32Pass(),
+                InsertInt32CastsAfterInt64PlaceholdersPass(),
+            ],
+        )
+        pipeline.change_args(
+            "check_count.exir", TestT5ForConditionalGeneration.ops_after_partitioner_FP
+        )
+        pipeline.run()
+
+
+@pytest.mark.slow
+def test_T5ForConditionalGeneration_tosa_INT():
+    prompt = "summarize: studies have shown that owning a dog is good for you"
+    model, inputs = TestT5ForConditionalGeneration().prepare_model_and_inputs(prompt)
+    with torch.no_grad():
+        pipeline = TosaPipelineINT[input_t3](
+            model,
+            inputs,
+            aten_op=[],
+            exir_op=[],
+            use_to_edge_transform_and_lower=True,
+            atol=20,  # TODO: MLETORCH-1703: Reduce the tolerance of quantized T5ForConditionalGeneration
+        )
+        pipeline.change_args(
+            "check_count.exir",
+            TestT5ForConditionalGeneration.ops_after_partitioner_INT,
+        )
+        pipeline.run()
+
+
+@pytest.mark.slow
+@common.SkipIfNoModelConverter
+def test_T5ForConditionalGeneration_vgf_no_quant():
+    prompt = "summarize: studies have shown that owning a dog is good for you"
+    model, inputs = TestT5ForConditionalGeneration().prepare_model_and_inputs(prompt)
+    with torch.no_grad():
+        pipeline = VgfPipeline[input_t3](
+            model,
+            inputs,
+            aten_op=[],
+            exir_op=[],
+            use_to_edge_transform_and_lower=True,
+            transform_passes=[
+                ConvertInt64ConstOpsToInt32Pass(),
+
ConvertInt64OutputOpsToInt32Pass(), + InsertInt32CastsAfterInt64PlaceholdersPass(), + ], + quantize=False, + ) + pipeline.change_args( + "check_count.exir", + TestT5ForConditionalGeneration.ops_after_partitioner_vgf_no_quantize, + ) + pipeline.run() + + +@pytest.mark.slow +@common.SkipIfNoModelConverter +def test_T5ForConditionalGeneration_vgf_quant(): + prompt = "summarize: studies have shown that owning a dog is good for you" + model, inputs = TestT5ForConditionalGeneration().prepare_model_and_inputs(prompt) + with torch.no_grad(): + pipeline = VgfPipeline[input_t3]( + model, + inputs, + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + atol=20, # TODO: MLETORCH-1703: Reduce the tolerance of quantized T5ForConditionalGeneration + quantize=True, + ) + pipeline.change_args( + "check_count.exir", + TestT5ForConditionalGeneration.ops_after_partitioner_vgf_quantize, + ) + pipeline.run() diff --git a/backends/arm/test/models/test_conformer.py b/backends/arm/test/models/test_conformer.py index 3119145aef1..e04d8bd44a5 100644 --- a/backends/arm/test/models/test_conformer.py +++ b/backends/arm/test/models/test_conformer.py @@ -18,7 +18,7 @@ VgfPipeline, ) -from torchaudio.models import Conformer +from torchaudio.models import Conformer # type: ignore[import-untyped] input_t = Tuple[torch.Tensor, torch.IntTensor] # Input x, y @@ -36,6 +36,10 @@ class TestConformer: # .to_executorch step, i.e. after Arm partitioner. aten_ops = ["torch.ops.aten._assert_scalar.default"] + # TODO(MLETORCH-635): reduce tolerance + atol = 0.4 + rtol = 0.4 + dim = 16 num_examples = 10 lengths = torch.randint(1, 100, (num_examples,), dtype=torch.int32) @@ -65,18 +69,11 @@ def test_conformer_tosa_INT(): pipeline = TosaPipelineINT[input_t]( TestConformer.conformer, TestConformer.model_example_inputs, - aten_op=[], # RemoveGraphAssertsPass is added in transform_for_annotation_pipeline to remove the assert ops + aten_op=[], exir_op=[], use_to_edge_transform_and_lower=True, - ) - pipeline.pop_stage("check_count.exir") - pipeline.change_args( - "run_method_and_compare_outputs", - get_test_inputs( - TestConformer.dim, TestConformer.lengths, TestConformer.num_examples - ), - rtol=1.0, - atol=3.0, + atol=TestConformer.atol, + rtol=TestConformer.rtol, ) pipeline.run() @@ -89,76 +86,53 @@ def test_conformer_u55_INT(): pipeline = EthosU55PipelineINT[input_t]( TestConformer.conformer, TestConformer.model_example_inputs, - aten_ops=TestConformer.aten_ops, + aten_ops=[], exir_ops=[], use_to_edge_transform_and_lower=True, - run_on_fvp=True, - ) - pipeline.change_args( - "run_method_and_compare_outputs", - get_test_inputs( - TestConformer.dim, TestConformer.lengths, TestConformer.num_examples - ), - rtol=1.0, - atol=5.0, + atol=TestConformer.atol, + rtol=TestConformer.rtol, ) + pipeline.pop_stage("check_count.exir") pipeline.run() @common.XfailIfNoCorstone320 -@pytest.mark.xfail(reason="All IO needs to have the same data type (MLETORCH-635)") def test_conformer_u85_INT(): pipeline = EthosU85PipelineINT[input_t]( TestConformer.conformer, TestConformer.model_example_inputs, - aten_ops=TestConformer.aten_ops, + aten_ops=[], exir_ops=[], use_to_edge_transform_and_lower=True, - run_on_fvp=True, - ) - pipeline.change_args( - "run_method_and_compare_outputs", - get_test_inputs( - TestConformer.dim, TestConformer.lengths, TestConformer.num_examples - ), - rtol=1.0, - atol=5.0, + atol=TestConformer.atol, + rtol=TestConformer.rtol, ) pipeline.run() @common.SkipIfNoModelConverter -def test_conformer_vgf_INT(): +def 
test_conformer_vgf_quant(): pipeline = VgfPipeline[input_t]( TestConformer.conformer, TestConformer.model_example_inputs, - aten_op=[], # RemoveGraphAssertsPass is added in transform_for_annotation_pipeline to remove the assert ops + aten_op=[], exir_op=[], - tosa_version="TOSA-1.0+INT", use_to_edge_transform_and_lower=True, + atol=TestConformer.atol, + rtol=TestConformer.rtol, + quantize=True, ) - pipeline.pop_stage("check_count.exir") - - # TODO: MLETORCH-1167 Create Vulkan backend e2e tests - # pipeline.change_args( - # "run_method_and_compare_outputs", - # get_test_inputs( - # TestConformer.dim, TestConformer.lengths, TestConformer.num_examples - # ), - # rtol=1.0, - # atol=3.0, - # ) pipeline.run() @common.SkipIfNoModelConverter -def test_conformer_vgf_FP(): +def test_conformer_vgf_no_quant(): pipeline = VgfPipeline[input_t]( TestConformer.conformer, TestConformer.model_example_inputs, aten_op=TestConformer.aten_ops, exir_op=[], - tosa_version="TOSA-1.0+FP", use_to_edge_transform_and_lower=True, + quantize=False, ) pipeline.run() diff --git a/backends/arm/test/models/test_deit_tiny_arm.py b/backends/arm/test/models/test_deit_tiny_arm.py index 22685a079bd..c53ab4fa0a9 100644 --- a/backends/arm/test/models/test_deit_tiny_arm.py +++ b/backends/arm/test/models/test_deit_tiny_arm.py @@ -3,30 +3,30 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import logging - from typing import Tuple -import timm +import timm # type: ignore[import-untyped] import torch from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineINT, + EthosU85PipelineINT, TosaPipelineFP, TosaPipelineINT, VgfPipeline, ) -from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from torchvision import transforms - -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - +from timm.data import ( # type: ignore[import-untyped] + IMAGENET_INCEPTION_MEAN, + IMAGENET_INCEPTION_STD, +) +from torchvision import transforms # type: ignore[import-untyped] deit_tiny = timm.models.deit.deit_tiny_patch16_224(pretrained=True) + deit_tiny.eval() normalize = transforms.Normalize( @@ -61,29 +61,60 @@ def test_deit_tiny_tosa_INT(): pipeline.run() +def test_deit_tiny_u55_INT(): + pipeline = EthosU55PipelineINT[input_t]( + deit_tiny, + model_inputs, + aten_ops=[], + exir_ops=[], + use_to_edge_transform_and_lower=True, + atol=1.5, + qtol=1, + ) + # Multiple partitions + pipeline.pop_stage("check_count.exir") + # Don't run inference as model is too large for Corstone-300 + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() + + +@common.XfailIfNoCorstone320 +def test_deit_tiny_u85_INT(): + pipeline = EthosU85PipelineINT[input_t]( + deit_tiny, + model_inputs, + aten_ops=[], + exir_ops=[], + use_to_edge_transform_and_lower=True, + atol=1.5, + qtol=1, + ) + pipeline.run() + + @common.SkipIfNoModelConverter -def test_deit_tiny_vgf_INT(): +def test_deit_tiny_vgf_quant(): pipeline = VgfPipeline[input_t]( deit_tiny, model_inputs, aten_op=[], exir_op=[], - tosa_version="TOSA-1.0+INT", use_to_edge_transform_and_lower=True, atol=1.5, qtol=1, + quantize=True, ) pipeline.run() @common.SkipIfNoModelConverter -def test_deit_tiny_vgf_FP(): +def test_deit_tiny_vgf_no_quant(): pipeline = VgfPipeline[input_t]( deit_tiny, model_inputs, aten_op=[], exir_op=[], - tosa_version="TOSA-1.0+FP", use_to_edge_transform_and_lower=True, + quantize=False, ) pipeline.run() diff 
--git a/backends/arm/test/models/test_dl3_arm.py b/backends/arm/test/models/test_dl3_arm.py index 2000ac34794..8e10001d755 100644 --- a/backends/arm/test/models/test_dl3_arm.py +++ b/backends/arm/test/models/test_dl3_arm.py @@ -66,7 +66,6 @@ def test_dl3_u55_INT(): TestDl3.model_example_inputs, aten_ops=[], exir_ops=[], - run_on_fvp=True, ) pipeline.change_args( "run_method_and_compare_outputs", rtol=1.0, atol=1.0 @@ -82,7 +81,6 @@ def test_dl3_u85_INT(): TestDl3.model_example_inputs, aten_ops=[], exir_ops=[], - run_on_fvp=True, ) pipeline.change_args( "run_method_and_compare_outputs", rtol=1.0, atol=1.0 @@ -91,34 +89,30 @@ def test_dl3_u85_INT(): @common.SkipIfNoModelConverter -def test_dl3_vgf_INT(): +def test_dl3_vgf_quant(): pipeline = VgfPipeline[input_t]( TestDl3.dl3, TestDl3.model_example_inputs, aten_op=[], exir_op=[], - tosa_version="TOSA-1.0+INT", use_to_edge_transform_and_lower=True, + run_on_vulkan_runtime=True, + quantize=True, ) - # TODO: MLETORCH-1167 Create Vulkan backend e2e tests - # pipeline.change_args( - # "run_method_and_compare_outputs", rtol=1.0, atol=1.0 - # ) + pipeline.change_args( + "run_method_and_compare_outputs", rtol=0.1, atol=0.1 + ) # TODO: MLETORCH-1036 decrease tolerance pipeline.run() @common.SkipIfNoModelConverter -def test_dl3_vgf_FP(): +def test_dl3_vgf_no_quant(): pipeline = VgfPipeline[input_t]( TestDl3.dl3, TestDl3.model_example_inputs, aten_op=[], exir_op=[], - tosa_version="TOSA-1.0+FP", use_to_edge_transform_and_lower=True, + quantize=False, ) - # TODO: MLETORCH-1167 Create Vulkan backend e2e tests - # pipeline.change_args( - # "run_method_and_compare_outputs", rtol=1.0, atol=1.0 - # ) pipeline.run() diff --git a/backends/arm/test/models/test_inception_v3_arm.py b/backends/arm/test/models/test_inception_v3_arm.py index f973521c1fa..0614ca23036 100644 --- a/backends/arm/test/models/test_inception_v3_arm.py +++ b/backends/arm/test/models/test_inception_v3_arm.py @@ -5,11 +5,12 @@ from typing import Tuple -import common import pytest import torch +from executorch.backends.arm.test import common + from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, EthosU85PipelineINT, @@ -18,7 +19,7 @@ VgfPipeline, ) -from torchvision import models, transforms +from torchvision import models, transforms # type: ignore[import-untyped] ic3 = models.inception_v3(weights=models.Inception_V3_Weights) ic3 = ic3.eval() @@ -66,7 +67,6 @@ def test_ic3_u55_BI(): model_inputs, aten_ops=[], exir_ops=[], - run_on_fvp=True, use_to_edge_transform_and_lower=True, atol=0.6, qtol=1, @@ -83,7 +83,6 @@ def test_ic3_u85_BI(): model_inputs, aten_ops=[], exir_ops=[], - run_on_fvp=True, use_to_edge_transform_and_lower=True, atol=0.6, qtol=1, @@ -94,14 +93,14 @@ def test_ic3_u85_BI(): @pytest.mark.slow @pytest.mark.skip(reason="Takes too long to run on CI") @common.SkipIfNoModelConverter -def test_ic3_vgf_FP(): +def test_ic3_vgf_no_quant(): pipeline = VgfPipeline[input_t]( ic3, model_inputs, aten_op=[], exir_op=[], - tosa_version="TOSA-1.0+FP", use_to_edge_transform_and_lower=True, + quantize=False, ) pipeline.run() @@ -109,13 +108,13 @@ def test_ic3_vgf_FP(): @pytest.mark.slow @pytest.mark.skip(reason="Takes too long to run on CI") @common.SkipIfNoModelConverter -def test_ic3_vgf_INT(): +def test_ic3_vgf_quant(): pipeline = VgfPipeline[input_t]( ic3, model_inputs, aten_op=[], exir_op=[], - tosa_version="TOSA-1.0+INT", use_to_edge_transform_and_lower=True, + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/models/test_llama.py 
b/backends/arm/test/models/test_llama.py index d47398be3b0..5f2348dee1e 100644 --- a/backends/arm/test/models/test_llama.py +++ b/backends/arm/test/models/test_llama.py @@ -16,6 +16,7 @@ import pytest import torch from executorch.backends.arm._passes import InsertInt32CastsAfterInt64PlaceholdersPass +from executorch.backends.arm.quantizer import get_symmetric_quantization_config from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.test_pipeline import ( @@ -99,6 +100,14 @@ def prepare_model(self): return llama_model, llama_inputs, llama_meta +def _use_partial_quantizer(pipeline): + """Set the pipeline's quantizer to only include Linear layers""" + pipeline.quantizer.set_global(None) + pipeline.quantizer.set_module_type( + torch.nn.Linear, get_symmetric_quantization_config() + ) + + def test_llama_tosa_FP(): llama_model, llama_inputs, llama_meta = TestLlama().prepare_model() @@ -111,9 +120,12 @@ def test_llama_tosa_FP(): llama_inputs, aten_op=[], exir_op=[], + custom_path="llama_tosa_fb", + run_on_tosa_ref_model=False, # Just want to write TOSA FB to disk use_to_edge_transform_and_lower=True, transform_passes=[InsertInt32CastsAfterInt64PlaceholdersPass()], ) + pipeline.add_stage_after("to_executorch", pipeline.tester.serialize) pipeline.run() @@ -129,13 +141,16 @@ def test_llama_tosa_INT(): llama_inputs, aten_op=[], exir_op=[], + custom_path="llama_tosa_fb_int", + run_on_tosa_ref_model=False, # Just want to write TOSA FB to disk use_to_edge_transform_and_lower=True, ) + pipeline.add_stage_after("to_executorch", pipeline.tester.serialize) pipeline.run() @common.SkipIfNoModelConverter -def test_llama_vgf_FP(): +def test_llama_vgf_no_quant(): llama_model, llama_inputs, llama_meta = TestLlama().prepare_model() if llama_model is None or llama_inputs is None: @@ -147,15 +162,16 @@ def test_llama_vgf_FP(): llama_inputs, aten_op=[], exir_op=[], - tosa_version="TOSA-1.0+FP", use_to_edge_transform_and_lower=True, transform_passes=[InsertInt32CastsAfterInt64PlaceholdersPass()], + run_on_vulkan_runtime=True, + quantize=False, ) pipeline.run() @common.SkipIfNoModelConverter -def test_llama_vgf_INT(): +def test_llama_vgf_quant(): llama_model, llama_inputs, llama_meta = TestLlama().prepare_model() if llama_model is None or llama_inputs is None: @@ -167,7 +183,26 @@ def test_llama_vgf_INT(): llama_inputs, aten_op=[], exir_op=[], - tosa_version="TOSA-1.0+INT", use_to_edge_transform_and_lower=True, + run_on_vulkan_runtime=True, + quantize=True, + ) + pipeline.run() + + +def test_llama_partial_quant_tosa_INT_FP(): + llama_model, llama_inputs, llama_meta = TestLlama().prepare_model() + + if llama_model is None or llama_inputs is None: + pytest.skip("Missing model and/or input files") + + with torch.no_grad(): + pipeline = TosaPipelineINT[input_t]( + llama_model, + llama_inputs, + aten_op=[], + exir_op=[], + tosa_extensions=["FP"], ) + _use_partial_quantizer(pipeline) pipeline.run() diff --git a/backends/arm/test/models/test_lstm_arm.py b/backends/arm/test/models/test_lstm_arm.py index 1e63472f5f4..e9af67c13ea 100644 --- a/backends/arm/test/models/test_lstm_arm.py +++ b/backends/arm/test/models/test_lstm_arm.py @@ -5,7 +5,11 @@ from typing import Tuple +import pytest import torch +from executorch.backends.arm.quantizer.arm_quantizer import ( + get_symmetric_a16w8_quantization_config, +) from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( @@ -51,7 +55,9 @@ def test_lstm_tosa_FP(): exir_op=[], 
use_to_edge_transform_and_lower=True, ) - pipeline.change_args("run_method_and_compare_outputs", get_test_inputs(), atol=3e-1) + pipeline.change_args( + "run_method_and_compare_outputs", inputs=get_test_inputs(), atol=3e-1 + ) pipeline.run() @@ -64,7 +70,10 @@ def test_lstm_tosa_INT(): use_to_edge_transform_and_lower=True, ) pipeline.change_args( - "run_method_and_compare_outputs", get_test_inputs(), atol=3e-1, qtol=1.0 + "run_method_and_compare_outputs", + inputs=get_test_inputs(), + atol=3e-1, + qtol=1.0, ) pipeline.run() @@ -77,10 +86,12 @@ def test_lstm_u55_INT(): aten_ops=[], exir_ops=[], use_to_edge_transform_and_lower=True, - run_on_fvp=True, ) pipeline.change_args( - "run_method_and_compare_outputs", get_test_inputs(), atol=3e-1, qtol=1.0 + "run_method_and_compare_outputs", + inputs=get_test_inputs(), + atol=3e-1, + qtol=1.0, ) pipeline.run() @@ -93,43 +104,91 @@ def test_lstm_u85_INT(): aten_ops=[], exir_ops=[], use_to_edge_transform_and_lower=True, - run_on_fvp=True, ) pipeline.change_args( - "run_method_and_compare_outputs", get_test_inputs(), atol=3e-1, qtol=1.0 + "run_method_and_compare_outputs", + inputs=get_test_inputs(), + atol=3e-1, + qtol=1.0, ) pipeline.run() @common.SkipIfNoModelConverter -def test_lstm_vgf_INT(): +def test_lstm_vgf_quant(): pipeline = VgfPipeline[input_t]( TestLSTM.lstm, TestLSTM.model_example_inputs, aten_op=[], exir_op=[], - tosa_version="TOSA-1.0+INT", use_to_edge_transform_and_lower=True, + quantize=True, ) - # TODO: MLETORCH-1167 Create Vulkan backend e2e tests - # pipeline.change_args( - # "run_method_and_compare_outputs", get_test_inputs(), atol=3e-1, qtol=1.0 - # ) pipeline.run() @common.SkipIfNoModelConverter -def test_lstm_vgf_FP(): +def test_lstm_vgf_no_quant(): pipeline = VgfPipeline[input_t]( TestLSTM.lstm, TestLSTM.model_example_inputs, aten_op=[], exir_op=[], - tosa_version="TOSA-1.0+FP", use_to_edge_transform_and_lower=True, + quantize=False, + ) + pipeline.run() + + +def test_lstm_16a8w_tosa_INT(): + """Test LSTM model with 16A8W quantization (16-bit activations, 8-bit weights)""" + + pipeline = TosaPipelineINT[input_t]( + TestLSTM.lstm, + TestLSTM.model_example_inputs, + aten_op=[], + exir_op=[], + per_channel_quantization=False, + use_to_edge_transform_and_lower=True, + tosa_extensions=["int16"], ) - # TODO: MLETORCH-1167 Create Vulkan backend e2e tests - # pipeline.change_args( - # "run_method_and_compare_outputs", get_test_inputs(), atol=3e-1, qtol=1.0 - # ) + pipeline.quantizer.set_global( + get_symmetric_a16w8_quantization_config(is_per_channel=False, epsilon=2**-16) + ) + pipeline.run() + + +@pytest.mark.xfail( + reason="MLETORCH-1452: AssertionError: Output 0 does not match reference output." 
+) +@common.XfailIfNoCorstone300 +def test_lstm_16a8w_u55_INT(): + pipeline = EthosU55PipelineINT[input_t]( + TestLSTM.lstm, + TestLSTM.model_example_inputs, + aten_ops=[], + exir_ops=[], + use_to_edge_transform_and_lower=True, + ) + + pipeline.quantizer.set_global( + get_symmetric_a16w8_quantization_config(is_per_channel=False, epsilon=2**-16) + ) + + pipeline.run() + + +@common.XfailIfNoCorstone320 +def test_lstm_16a8w_u85_INT(): + pipeline = EthosU85PipelineINT[input_t]( + TestLSTM.lstm, + TestLSTM.model_example_inputs, + aten_ops=[], + exir_ops=[], + use_to_edge_transform_and_lower=True, + ) + pipeline.quantizer.set_global( + get_symmetric_a16w8_quantization_config(is_per_channel=False, epsilon=2**-16) + ) + pipeline.run() diff --git a/backends/arm/test/models/test_mobilenet_v2_arm.py b/backends/arm/test/models/test_mobilenet_v2_arm.py index d4e3bbc8e28..2c5d2cd627d 100644 --- a/backends/arm/test/models/test_mobilenet_v2_arm.py +++ b/backends/arm/test/models/test_mobilenet_v2_arm.py @@ -10,6 +10,7 @@ import pytest import torch +from executorch.backends.arm.quantizer import get_symmetric_quantization_config from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, @@ -39,6 +40,14 @@ } +def _use_partial_quantizer(pipeline): + """Set the pipeline's quantizer to only include Conv2d and ReLU6""" + quant_cfg = get_symmetric_quantization_config() + pipeline.quantizer.set_global(None) + pipeline.quantizer.set_module_type(torch.nn.Conv2d, quant_cfg) + pipeline.quantizer.set_module_type(torch.nn.ReLU6, quant_cfg) + + def test_mv2_tosa_FP(): pipeline = TosaPipelineFP[input_t]( mv2, model_inputs, aten_op=[], exir_op=[], use_to_edge_transform_and_lower=True @@ -46,6 +55,23 @@ def test_mv2_tosa_FP(): pipeline.run() +def test_mv2_tosa_FP_channels_last(): + input_tensor = model_inputs[0].to(memory_format=torch.channels_last) + pipeline = TosaPipelineFP[input_t]( + mv2, + (input_tensor,), + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + ) + # Changing memory format leads to an unsupported as_strided_copy op being inserted into the graph, + # leading to a graph break. 
+ pipeline.change_args( + "check_count.exir", {"torch.ops.higher_order.executorch_call_delegate": 2} + ) + pipeline.run() + + @common.parametrize("per_channel_quantization", quant_test_data) def test_mv2_tosa_INT(per_channel_quantization): pipeline = TosaPipelineINT[input_t]( @@ -70,7 +96,6 @@ def test_mv2_u55_INT(per_channel_quantization): model_inputs, aten_ops=[], exir_ops=[], - run_on_fvp=True, use_to_edge_transform_and_lower=True, per_channel_quantization=per_channel_quantization, atol=0.25, @@ -88,7 +113,6 @@ def test_mv2_u85_INT(per_channel_quantization): model_inputs, aten_ops=[], exir_ops=[], - run_on_fvp=True, use_to_edge_transform_and_lower=True, per_channel_quantization=per_channel_quantization, atol=0.25, @@ -99,37 +123,43 @@ def test_mv2_u85_INT(per_channel_quantization): @common.SkipIfNoModelConverter @common.parametrize("per_channel_quantization", quant_test_data) -def test_mv2_vgf_INT(per_channel_quantization): +def test_mv2_vgf_quant(per_channel_quantization): pipeline = VgfPipeline[input_t]( mv2, model_inputs, aten_op=[], exir_op=[], - tosa_version="TOSA-1.0+INT", use_to_edge_transform_and_lower=True, per_channel_quantization=per_channel_quantization, atol=0.25, qtol=1, + quantize=True, ) - # TODO: MLETORCH-1167 Create Vulkan backend e2e tests - # pipeline.change_args( - # "run_method_and_compare_outputs", get_test_inputs(), atol=3e-1, qtol=1.0 - # ) pipeline.run() @common.SkipIfNoModelConverter -def test_mv2_vgf_FP(): +def test_mv2_vgf_no_quant(): pipeline = VgfPipeline[input_t]( mv2, model_inputs, aten_op=[], exir_op=[], - tosa_version="TOSA-1.0+FP", use_to_edge_transform_and_lower=True, + quantize=False, + ) + pipeline.run() + + +def test_mv2_partial_quant_tosa_INT_FP(): + pipeline = TosaPipelineINT[input_t]( + mv2, + model_inputs, + aten_op=[], + exir_op=[], + tosa_extensions=["FP"], + use_to_edge_transform_and_lower=True, + atol=0.20, ) - # TODO: MLETORCH-1167 Create Vulkan backend e2e tests - # pipeline.change_args( - # "run_method_and_compare_outputs", get_test_inputs(), atol=3e-1, qtol=1.0 - # ) # TODO: MLETORCH-1036 decrease tolerance + _use_partial_quantizer(pipeline) pipeline.run() diff --git a/backends/arm/test/models/test_mobilenet_v3_arm.py b/backends/arm/test/models/test_mobilenet_v3_arm.py index 0dcbd9757ac..d17fc48f0e4 100644 --- a/backends/arm/test/models/test_mobilenet_v3_arm.py +++ b/backends/arm/test/models/test_mobilenet_v3_arm.py @@ -5,11 +5,12 @@ from typing import Tuple -import common import pytest import torch +from executorch.backends.arm.test import common + from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, EthosU85PipelineINT, @@ -18,7 +19,7 @@ VgfPipeline, ) -from torchvision import models, transforms +from torchvision import models, transforms # type: ignore[import-untyped] mv3 = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights) mv3 = mv3.eval() @@ -61,7 +62,6 @@ def test_mv3_u55_INT(): model_inputs, aten_ops=[], exir_ops=[], - run_on_fvp=True, use_to_edge_transform_and_lower=True, atol=0.5, qtol=1, @@ -77,7 +77,6 @@ def test_mv3_u85_INT(): model_inputs, aten_ops=[], exir_ops=[], - run_on_fvp=True, use_to_edge_transform_and_lower=True, atol=0.5, qtol=1, @@ -87,28 +86,28 @@ def test_mv3_u85_INT(): @common.SkipIfNoModelConverter @pytest.mark.slow -def test_mv3_vgf_INT(): +def test_mv3_vgf_quant(): pipeline = VgfPipeline[input_t]( mv3, model_inputs, aten_op=[], exir_op=[], - tosa_version="TOSA-1.0+INT", use_to_edge_transform_and_lower=True, atol=0.5, qtol=1, + quantize=True, ) 
pipeline.run() @common.SkipIfNoModelConverter -def test_mv3_vgf_FP(): +def test_mv3_vgf_no_quant(): pipeline = VgfPipeline[input_t]( mv3, model_inputs, aten_op=[], exir_op=[], - tosa_version="TOSA-1.0+FP", use_to_edge_transform_and_lower=True, + quantize=False, ) pipeline.run() diff --git a/backends/arm/test/models/test_nn_functional.py b/backends/arm/test/models/test_nn_functional.py index 4896074b544..7d1ae64b63e 100644 --- a/backends/arm/test/models/test_nn_functional.py +++ b/backends/arm/test/models/test_nn_functional.py @@ -102,7 +102,6 @@ def test_nn_functional_FP(test_data): @parametrize( "test_data", module_tests, - {"normalize": "MLETORCH-1255: Unsupported dtype in InsertTableOpsPass"}, ) def test_nn_functional_INT(test_data): module, inputs = test_data @@ -111,8 +110,10 @@ def test_nn_functional_INT(test_data): ) pipeline.pop_stage("check.aten") pipeline.pop_stage("check_count.exir") - pipeline.pop_stage("check.quant_nodes") - pipeline.pop_stage("check_not.quant_nodes") + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") + if pipeline.has_stage("check_not.quant_nodes"): + pipeline.pop_stage("check_not.quant_nodes") try: pipeline.run() except RuntimeError as e: diff --git a/backends/arm/test/models/test_nn_modules.py b/backends/arm/test/models/test_nn_modules.py index 0daf035a7f1..a1e1f6431d9 100644 --- a/backends/arm/test/models/test_nn_modules.py +++ b/backends/arm/test/models/test_nn_modules.py @@ -17,6 +17,8 @@ - Transformer """ +from typing import Callable + import torch from executorch.backends.arm.test.common import parametrize from executorch.backends.arm.test.tester.test_pipeline import ( @@ -24,25 +26,82 @@ TosaPipelineINT, ) + +def make_module_wrapper( + name: str, module_factory: Callable[[], torch.nn.Module] +) -> torch.nn.Module: + class ModuleWrapper(torch.nn.Module): + def __init__(self): + super().__init__() + self._module = module_factory() + + def forward(self, *args, **kwargs): + return self._module(*args, **kwargs) + + ModuleWrapper.__name__ = name + ModuleWrapper.__qualname__ = name + return ModuleWrapper() + + example_input = torch.rand(1, 6, 16, 16) module_tests = [ - (torch.nn.Embedding(10, 10), (torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]),)), - (torch.nn.LeakyReLU(), (example_input,)), - (torch.nn.BatchNorm1d(16), (torch.rand(6, 16, 16),)), - (torch.nn.AdaptiveAvgPool2d((12, 12)), (example_input,)), - (torch.nn.ConvTranspose2d(6, 3, 2), (example_input,)), - (torch.nn.GRU(10, 20, 2), (torch.randn(5, 3, 10), torch.randn(2, 3, 20))), - (torch.nn.GroupNorm(2, 6), (example_input,)), - (torch.nn.InstanceNorm2d(16), (example_input,)), - (torch.nn.PReLU(), (example_input,)), ( - torch.nn.Transformer( - d_model=64, - nhead=1, - num_encoder_layers=1, - num_decoder_layers=1, - dtype=torch.float32, + make_module_wrapper( + "EmbeddingModule", + lambda: torch.nn.Embedding(10, 10), + ), + (torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]),), + ), + ( + make_module_wrapper("LeakyReLUModule", torch.nn.LeakyReLU), + (example_input,), + ), + ( + make_module_wrapper("BatchNorm1dModule", lambda: torch.nn.BatchNorm1d(16)), + (torch.rand(6, 16, 16),), + ), + ( + make_module_wrapper( + "AdaptiveAvgPool2dModule", + lambda: torch.nn.AdaptiveAvgPool2d((12, 12)), + ), + (example_input,), + ), + ( + make_module_wrapper( + "ConvTranspose2dModule", lambda: torch.nn.ConvTranspose2d(6, 3, 2) + ), + (example_input,), + ), + ( + make_module_wrapper("GRUModule", lambda: torch.nn.GRU(10, 20, 2)), + (torch.randn(5, 3, 10), torch.randn(2, 3, 20)), + ), + ( 
+ make_module_wrapper("GroupNormModule", lambda: torch.nn.GroupNorm(2, 6)), + (example_input,), + ), + ( + make_module_wrapper( + "InstanceNorm2dModule", lambda: torch.nn.InstanceNorm2d(16) + ), + (example_input,), + ), + ( + make_module_wrapper("PReLUModule", torch.nn.PReLU), + (example_input,), + ), + ( + make_module_wrapper( + "TransformerModule", + lambda: torch.nn.Transformer( + d_model=64, + nhead=1, + num_encoder_layers=1, + num_decoder_layers=1, + dtype=torch.float32, + ), ), (torch.rand((10, 32, 64)), torch.rand((20, 32, 64))), ), @@ -78,9 +137,7 @@ def test_nn_Modules_FP(test_data): "test_data", test_parameters, xfails={ - "GRU": "RuntimeError: Node aten_linear_default with op was not decomposed or delegated.", - "PReLU": "RuntimeError: mul(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.", - "Transformer": "AssertionError: Output 0 does not match reference output.", + "TransformerModule": "AssertionError: Output 0 does not match reference output.", }, ) def test_nn_Modules_INT(test_data): @@ -90,8 +147,10 @@ def test_nn_Modules_INT(test_data): ) pipeline.pop_stage("check.aten") pipeline.pop_stage("check_count.exir") - pipeline.pop_stage("check.quant_nodes") - pipeline.pop_stage("check_not.quant_nodes") + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") + if pipeline.has_stage("check_not.quant_nodes"): + pipeline.pop_stage("check_not.quant_nodes") try: pipeline.run() except RuntimeError as e: diff --git a/backends/arm/test/models/test_nss.py b/backends/arm/test/models/test_nss.py new file mode 100644 index 00000000000..e5e381cfe66 --- /dev/null +++ b/backends/arm/test/models/test_nss.py @@ -0,0 +1,146 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Tuple + +import pytest +import torch + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineINT, + EthosU85PipelineINT, + TosaPipelineFP, + TosaPipelineINT, + VgfPipeline, +) + +from huggingface_hub import hf_hub_download + +from ng_model_gym.usecases.nss.model.model_blocks import ( # type: ignore[import-not-found,import-untyped] + AutoEncoderV1, +) + +input_t = Tuple[torch.Tensor] # Input x + + +class NSS(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.auto_encoder = AutoEncoderV1() + + +def nss() -> AutoEncoderV1: + """Get an instance of NSS with weights loaded.""" + + weights = hf_hub_download( + repo_id="Arm/neural-super-sampling", + filename="nss_v0.1.0_fp32.pt", + revision="2e9b606acd9fa25071825a12f0764f1c3bef9480", + ) + + nss_model = NSS() + nss_model.load_state_dict( + torch.load(weights, map_location=torch.device("cpu"), weights_only=True), + strict=False, + ) + return nss_model.auto_encoder + + +def example_inputs(): + return (torch.randn((1, 12, 544, 960)),) + + +def test_nss_tosa_FP(): + pipeline = TosaPipelineFP[input_t]( + nss().eval(), + example_inputs(), + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + ) + pipeline.add_stage_after("export", pipeline.tester.dump_operator_distribution) + pipeline.run() + + +def test_nss_tosa_INT(): + pipeline = TosaPipelineINT[input_t]( + nss().eval(), + example_inputs(), + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@pytest.mark.skip(reason="No support for aten_upsample_nearest2d_vec on U55") +@common.XfailIfNoCorstone300 +def test_nss_u55_INT(): + pipeline = EthosU55PipelineINT[input_t]( + nss().eval(), + example_inputs(), + aten_ops=[], + exir_ops=[], + run_on_fvp=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@pytest.mark.skip( + reason="Fails at input memory allocation for input shape: [1, 12, 544, 960]" +) +@common.XfailIfNoCorstone320 +def test_nss_u85_INT(): + pipeline = EthosU85PipelineINT[input_t]( + nss().eval(), + example_inputs(), + aten_ops=[], + exir_ops=[], + run_on_fvp=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@pytest.mark.xfail( + reason="[MLETORCH-1430]: Double types are not supported in buffers in MSL" +) +@common.SkipIfNoModelConverter +def test_nss_vgf_FP(): + pipeline = VgfPipeline[input_t]( + nss().eval(), + example_inputs(), + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + run_on_vulkan_runtime=True, + quantize=False, + # Override tosa version to test FP-only path + tosa_version="TOSA-1.0+FP", + ) + pipeline.run() + + +@common.SkipIfNoModelConverter +def test_nss_vgf_INT(): + pipeline = VgfPipeline[input_t]( + nss().eval(), + example_inputs(), + aten_op=[], + exir_op=[], + symmetric_io_quantization=True, + use_to_edge_transform_and_lower=True, + run_on_vulkan_runtime=True, + quantize=True, + # Override tosa version to test INT-only path + tosa_version="TOSA-1.0+INT", + ) + pipeline.run() + + +ModelUnderTest = nss().eval() +ModelInputs = example_inputs() diff --git a/backends/arm/test/models/test_resnet18.py b/backends/arm/test/models/test_resnet18.py index 6e965daeb8b..3a40a3dfd06 100644 --- a/backends/arm/test/models/test_resnet18.py +++ b/backends/arm/test/models/test_resnet18.py @@ -17,13 +17,17 @@ ) from torchvision import transforms # type: ignore[import-untyped] -from torchvision.models import resnet18, 
ResNet18_Weights +from torchvision.models import ( # type: ignore[import-untyped] + resnet18, + ResNet18_Weights, +) model = resnet18(weights=ResNet18_Weights) model = model.eval() normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) -model_inputs = (normalize(torch.randn((1, 3, 224, 224))),) +# Using torch.rand * 2 - 1 to generate numbers in the range [-1;1] like an RGB image +model_inputs = (normalize(torch.rand((1, 3, 224, 224)) * 2 - 1),) input_t = Tuple[torch.Tensor] @@ -54,7 +58,7 @@ def test_resnet_tosa_INT(per_channel_quantization): exir_op=[], use_to_edge_transform_and_lower=True, per_channel_quantization=per_channel_quantization, - atol=0.5, + atol=0.25, qtol=1, ) pipeline.run() @@ -69,19 +73,15 @@ def test_resnet_u55_INT(per_channel_quantization): model_inputs, aten_ops=[], exir_ops=[], - run_on_fvp=True, use_to_edge_transform_and_lower=True, per_channel_quantization=per_channel_quantization, - atol=0.5, + atol=0.25, qtol=1, ) pipeline.run() @pytest.mark.slow -@pytest.mark.xfail( - reason="For resnet18 for Ethos-U85, the SRAM memory footprint is very high. The compiler team is investigating." -) @common.XfailIfNoCorstone320 @common.parametrize("per_channel_quantization", quant_test_data) def test_resnet_u85_INT(per_channel_quantization): @@ -90,10 +90,9 @@ def test_resnet_u85_INT(per_channel_quantization): model_inputs, aten_ops=[], exir_ops=[], - run_on_fvp=True, use_to_edge_transform_and_lower=True, per_channel_quantization=per_channel_quantization, - atol=0.5, + atol=0.25, qtol=1, ) pipeline.run() diff --git a/backends/arm/test/models/test_torch_functions.py b/backends/arm/test/models/test_torch_functions.py index 580438f6da8..3632a9dd141 100644 --- a/backends/arm/test/models/test_torch_functions.py +++ b/backends/arm/test/models/test_torch_functions.py @@ -101,7 +101,6 @@ def forward(self, *args): "Requires dynamic output shape.", "topk": "NotImplementedError: No registered serialization name for found", "sort": "NotImplementedError: No registered serialization name for found", - "norm": "An error occurred when running the 'KeepDimsFalseToSqueezePass' pass after the following passes:", }, ) def test_torch_fns_FP(test_data): @@ -129,9 +128,8 @@ def test_torch_fns_FP(test_data): "Requires dynamic output shape.", "topk": "NotImplementedError: No registered serialization name for found", "sort": "NotImplementedError: No registered serialization name for found", - "t": "MLETORCH-855: Issue with Quantization folding.", }, - strict=False, + strict=True, ) def test_torch_fns_INT(test_data): module, inputs = test_data diff --git a/backends/arm/test/models/test_w2l_arm.py b/backends/arm/test/models/test_w2l_arm.py index c627cd7f887..91e7732c161 100644 --- a/backends/arm/test/models/test_w2l_arm.py +++ b/backends/arm/test/models/test_w2l_arm.py @@ -20,7 +20,7 @@ VgfPipeline, ) -from torchaudio import models +from torchaudio import models # type: ignore[import-untyped] input_t = Tuple[torch.Tensor] # Input x @@ -91,7 +91,6 @@ def test_w2l_u55_INT(): aten_ops=[], exir_ops=[], use_to_edge_transform_and_lower=True, - run_on_fvp=True, ) pipeline.run() @@ -106,33 +105,32 @@ def test_w2l_u85_INT(): aten_ops=[], exir_ops=[], use_to_edge_transform_and_lower=True, - run_on_fvp=True, ) pipeline.run() @common.SkipIfNoModelConverter @pytest.mark.slow -def test_w2l_vgf_INT(): +def test_w2l_vgf_quant(): pipeline = VgfPipeline[input_t]( TestW2L.create_model(), TestW2L.model_example_inputs, aten_op=[], exir_op=TestW2L.all_operators, - tosa_version="TOSA-1.0+INT", 
use_to_edge_transform_and_lower=True, + quantize=True, ) pipeline.run() @common.SkipIfNoModelConverter -def test_w2l_vgf_FP(): +def test_w2l_vgf_no_quant(): pipeline = VgfPipeline[input_t]( TestW2L.create_model(), TestW2L.model_example_inputs, aten_op=[], exir_op=TestW2L.all_operators, - tosa_version="TOSA-1.0+FP", use_to_edge_transform_and_lower=True, + quantize=False, ) pipeline.run() diff --git a/backends/arm/test/ops/test_abs.py b/backends/arm/test/ops/test_abs.py index 4ebcf7393c1..9e8ad2e3d03 100644 --- a/backends/arm/test/ops/test_abs.py +++ b/backends/arm/test/ops/test_abs.py @@ -55,7 +55,10 @@ def test_abs_tosa_INT(test_data: torch.Tensor): @common.XfailIfNoCorstone300 def test_abs_u55_INT(test_data: torch.Tensor): pipeline = EthosU55PipelineINT[input_t1]( - Abs(), test_data(), aten_op, exir_op, run_on_fvp=True + Abs(), + test_data(), + aten_op, + exir_op, ) pipeline.run() @@ -64,28 +67,35 @@ def test_abs_u55_INT(test_data: torch.Tensor): @common.XfailIfNoCorstone320 def test_abs_u85_INT(test_data: torch.Tensor): pipeline = EthosU85PipelineINT[input_t1]( - Abs(), test_data(), aten_op, exir_op, run_on_fvp=True + Abs(), + test_data(), + aten_op, + exir_op, ) pipeline.run() @common.parametrize("test_data", Abs.test_parameters) @common.SkipIfNoModelConverter -def test_abs_vgf_FP(test_data: input_t1): +def test_abs_vgf_no_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( - Abs(), test_data(), aten_op, exir_op, tosa_version="TOSA-1.0+FP" + Abs(), + test_data(), + aten_op, + exir_op, + quantize=False, ) pipeline.run() @common.parametrize("test_data", Abs.test_parameters) @common.SkipIfNoModelConverter -def test_abs_vgf_INT(test_data: input_t1): +def test_abs_vgf_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( Abs(), test_data(), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_acos.py b/backends/arm/test/ops/test_acos.py index 28dadcf95be..be91bd71567 100644 --- a/backends/arm/test/ops/test_acos.py +++ b/backends/arm/test/ops/test_acos.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. 
from typing import Tuple -import pytest import torch from executorch.backends.arm.test import common @@ -96,33 +95,27 @@ def test_acos_u85_INT(test_data: Tuple): @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_acos_vgf_FP(test_data: Tuple): +def test_acos_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[input_t]( Acos(), (test_data(),), [], [], - tosa_version="TOSA-1.0+FP", run_on_vulkan_runtime=True, + quantize=False, ) - try: - pipeline.run() - except FileNotFoundError as e: - pytest.skip(f"VKML executor_runner not found - not built - skip {e}") + pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_acos_vgf_INT(test_data: Tuple): +def test_acos_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[input_t]( Acos(), (test_data(),), [], [], - tosa_version="TOSA-1.0+INT", run_on_vulkan_runtime=True, + quantize=True, ) - try: - pipeline.run() - except FileNotFoundError as e: - pytest.skip(f"VKML executor_runner not found - not built - skip {e}") + pipeline.run() diff --git a/backends/arm/test/ops/test_acosh.py b/backends/arm/test/ops/test_acosh.py index 25ba2b1a83b..48490a91662 100644 --- a/backends/arm/test/ops/test_acosh.py +++ b/backends/arm/test/ops/test_acosh.py @@ -87,7 +87,6 @@ def test_acosh_u55_INT_xfail(test_data: Tuple): Acosh(), (test_data(),), aten_ops=[], - run_on_fvp=False, ) pipeline.run() @@ -110,30 +109,29 @@ def test_acosh_u85_INT_xfail(test_data: Tuple): Acosh(), (test_data(),), aten_ops=[], - run_on_fvp=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_acosh_vgf_FP(test_data: Tuple): +def test_acosh_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[input_t]( Acosh(), (test_data(),), aten_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_acosh_vgf_INT(test_data: Tuple): +def test_acosh_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[input_t]( Acosh(), (test_data(),), aten_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_adaptive_avg_pool2d.py b/backends/arm/test/ops/test_adaptive_avg_pool2d.py index 4411ce7f746..1043387fdf1 100644 --- a/backends/arm/test/ops/test_adaptive_avg_pool2d.py +++ b/backends/arm/test/ops/test_adaptive_avg_pool2d.py @@ -136,6 +136,20 @@ def test_adaptive_avg_pool2d_tosa_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_modules) +def test_adaptive_avg_pool2d_tosa_INT_a16w8(test_module): + """Test adaptive_avg_pool2d with int16 I/O quantization for TOSA INT.""" + model, input_tensor = test_module() + pipeline = TosaPipelineINT[input_t]( + model, + input_tensor, + aten_op=[], + exir_op=exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + @common.parametrize("test_module", test_modules) @common.XfailIfNoCorstone300 def test_adaptive_avg_pool2d_u55_INT(test_module): @@ -150,6 +164,27 @@ def test_adaptive_avg_pool2d_u55_INT(test_module): pipeline.run() +# Remove high_channel_count & output_1x1_from_19 due to 2MB SRAM access on U55 +u55_test_modules = test_modules +for key in ["high_channel_count", "output_1x1_from_19"]: + u55_test_modules.pop(key) + + +@common.parametrize("test_module", u55_test_modules) +@common.XfailIfNoCorstone300 +def test_adaptive_avg_pool2d_16a8w_u55_INT16(test_module): + """Test adaptive_avg_pool2d with 16A8W quantization on U55 (16-bit activations, 8-bit 
weights)""" + model, input_tensor = test_module() + pipeline = EthosU55PipelineINT[input_t]( + model, + input_tensor, + aten_ops=[], + exir_ops=exir_op, + a16w8_quantization=True, + ) + pipeline.run() + + @common.parametrize("test_module", test_modules) @common.XfailIfNoCorstone320 def test_adaptive_avg_pool2d_u85_INT(test_module): @@ -164,29 +199,44 @@ def test_adaptive_avg_pool2d_u85_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_modules) +@common.XfailIfNoCorstone320 +def test_adaptive_avg_pool2d_16a8w_u85_INT16(test_module): + """Test adaptive_avg_pool2d with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + model, input_tensor = test_module() + pipeline = EthosU85PipelineINT[input_t]( + model, + input_tensor, + aten_ops=[], + exir_ops=exir_op, + a16w8_quantization=True, + ) + pipeline.run() + + @common.parametrize("test_module", test_modules) @common.SkipIfNoModelConverter -def test_adaptive_avg_pool2d_vgf_FP(test_module): +def test_adaptive_avg_pool2d_vgf_no_quant(test_module): model, input_tensor = test_module() pipeline = VgfPipeline[input_t]( model, input_tensor, [], exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_module", test_modules) @common.SkipIfNoModelConverter -def test_adaptive_avg_pool2d_vgf_INT(test_module): +def test_adaptive_avg_pool2d_vgf_quant(test_module): model, input_tensor = test_module() pipeline = VgfPipeline[input_t]( model, input_tensor, [], exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index 2eabd302df6..31c20bdb60c 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -7,14 +7,12 @@ from typing import Tuple -import pytest import torch from executorch.backends.arm.quantizer import arm_quantizer from executorch.backends.arm.quantizer.arm_quantizer import ( get_symmetric_a16w8_quantization_config, - TOSAQuantizer, ) -from executorch.backends.arm.test import common, conftest +from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, EthosU85PipelineINT, @@ -22,9 +20,6 @@ TosaPipelineINT, VgfPipeline, ) -from executorch.backends.arm.tosa import TosaSpecification -from executorch.backends.arm.tosa.specification import get_tosa_spec -from executorch.backends.xnnpack.test.tester import Quantize from torchao.quantization.pt2e import HistogramObserver from torchao.quantization.pt2e.quantizer import QuantizationSpec @@ -79,7 +74,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): class Add3(torch.nn.Module): def forward(self, x: torch.Tensor, y: torch.Tensor): - return x + y + return torch.add(x, y, alpha=1.5) test_data: list[input_t2] = { "3d_randn_diff_rank": lambda: (torch.randn(1, 4, 5), torch.randn(4, 1)), @@ -103,15 +98,8 @@ def test_add_tensor_tosa_INT(test_data: input_t1): @common.parametrize("test_data", Add.test_data) def test_add_tensor_tosa_INT_i32(test_data: input_t1): pipeline = TosaPipelineINT[input_t1](Add(), test_data(), aten_op, exir_op) - tosa_version = conftest.get_option("tosa_version") - tosa_profiles = { - "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT"), - } - # Create a quantizer with int8 quantization on the input and output but int32 on everything else. 
- quantizer = arm_quantizer.TOSAQuantizer( - get_tosa_spec(common.get_tosa_compile_spec(tosa_profiles[tosa_version])) - ) - quantizer.set_io(arm_quantizer.get_symmetric_quantization_config()) + + pipeline.quantizer.set_io(arm_quantizer.get_symmetric_quantization_config()) observer_options = {"eps": 2**-16} observer = HistogramObserver.with_args(**observer_options) input_act_qspec = QuantizationSpec( @@ -128,12 +116,10 @@ def test_add_tensor_tosa_INT_i32(test_data: input_t1): quant_max=2**31 - 1, quant_min=-(2**31), ) - # This quantization_config will be set as global config. quantization_config = arm_quantizer.QuantizationConfig( input_act_qspec, output_act_qspec, None, None ) - quantize_stage = Quantize(quantizer, quantization_config) - pipeline.change_args("quantize", quantize_stage) + pipeline.quantizer.set_global(quantization_config) # Check that we get the additional (dq -> q pipeline.add_stage_after( @@ -146,7 +132,10 @@ def test_add_tensor_tosa_INT_i32(test_data: input_t1): @common.XfailIfNoCorstone300 def test_add_tensor_u55_INT(test_data: input_t1): pipeline = EthosU55PipelineINT[input_t1]( - Add(), test_data(), aten_op, exir_op, run_on_fvp=True + Add(), + test_data(), + aten_op, + exir_op, ) pipeline.run() @@ -155,7 +144,10 @@ def test_add_tensor_u55_INT(test_data: input_t1): @common.XfailIfNoCorstone320 def test_add_tensor_u85_INT(test_data: input_t1): pipeline = EthosU85PipelineINT[input_t1]( - Add(), test_data(), aten_op, exir_op, run_on_fvp=True + Add(), + test_data(), + aten_op, + exir_op, ) pipeline.run() @@ -188,7 +180,10 @@ def test_add_tensor_tosa_INT_2(test_data: input_t2): @common.XfailIfNoCorstone300 def test_add_tensor_u55_INT_2(test_data: input_t2): pipeline = EthosU55PipelineINT[input_t2]( - Add2(), test_data(), aten_op, exir_op, run_on_fvp=True + Add2(), + test_data(), + aten_op, + exir_op, ) pipeline.run() @@ -197,68 +192,43 @@ def test_add_tensor_u55_INT_2(test_data: input_t2): @common.XfailIfNoCorstone320 def test_add_tensor_u85_INT_2(test_data: input_t2): pipeline = EthosU85PipelineINT[input_t2]( - Add2(), test_data(), aten_op, exir_op, run_on_fvp=True + Add2(), + test_data(), + aten_op, + exir_op, ) pipeline.run() @common.parametrize("test_data", Add.test_data) @common.SkipIfNoModelConverter -def test_add_tensor_vgf_FP(test_data: input_t1): +def test_add_tensor_vgf_no_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( Add(), test_data(), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", run_on_vulkan_runtime=True, + quantize=False, ) - try: - pipeline.run() - except FileNotFoundError as e: - pytest.skip(f"VKML executor_runner not found - not built - skip {e}") + pipeline.run() @common.parametrize("test_data", Add.test_data) @common.SkipIfNoModelConverter -def test_add_tensor_vgf_INT(test_data: input_t1): +def test_add_tensor_vgf_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( Add(), test_data(), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", run_on_vulkan_runtime=True, + quantize=True, ) - try: - pipeline.run() - except FileNotFoundError as e: - pytest.skip(f"VKML executor_runner not found - not built - skip {e}") - - -def get_symmetric_a16w8_add_quantizer(per_channel_quantization=False): - tosa_version = conftest.get_option("tosa_version") - tosa_profiles = { - "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"), - } - - quantizer = TOSAQuantizer(tosa_profiles[tosa_version]) - quantizer.set_global( - get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization) - ) - - return Quantize( - quantizer, - 
get_symmetric_a16w8_quantization_config( - is_per_channel=per_channel_quantization - ), - ) + pipeline.run() @common.parametrize("test_data", Add.test_data) -@pytest.mark.xfail( - reason="missing int16 add ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13730" -) def test_add_tensor_16a8w_tosa_INT(test_data: input_t1): """Test add operation with 16A8W quantization (16-bit activations, 8-bit weights)""" per_channel_quantization = False @@ -273,20 +243,14 @@ def test_add_tensor_16a8w_tosa_INT(test_data: input_t1): tosa_extensions=["int16"], ) - pipeline.change_args( - "quantize", - get_symmetric_a16w8_add_quantizer( - per_channel_quantization=per_channel_quantization - ), + pipeline.quantizer.set_global( + get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization) ) pipeline.run() @common.parametrize("test_data", Add.test_data) @common.XfailIfNoCorstone300 -@pytest.mark.xfail( - reason="Vela compilation fails with 'Invalid arguments' for int16 add operations. See: https://github.com/pytorch/executorch/issues/13730" -) def test_add_tensor_16a8w_u55_INT16(test_data: input_t1): """Test add operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)""" per_channel_quantization = False @@ -298,23 +262,16 @@ def test_add_tensor_16a8w_u55_INT16(test_data: input_t1): exir_op, per_channel_quantization=per_channel_quantization, use_to_edge_transform_and_lower=True, - run_on_fvp=True, ) - pipeline.change_args( - "quantize", - get_symmetric_a16w8_add_quantizer( - per_channel_quantization=per_channel_quantization - ), + pipeline.quantizer.set_global( + get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization) ) pipeline.run() @common.parametrize("test_data", Add.test_data) @common.XfailIfNoCorstone320 -@pytest.mark.xfail( - reason="Vela compilation fails with 'Invalid arguments' for int16 add operations. 
See: https://github.com/pytorch/executorch/issues/13730" -) def test_add_tensor_16a8w_u85_INT16(test_data: input_t1): """Test add operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" per_channel_quantization = False @@ -326,13 +283,9 @@ def test_add_tensor_16a8w_u85_INT16(test_data: input_t1): exir_op, per_channel_quantization=per_channel_quantization, use_to_edge_transform_and_lower=True, - run_on_fvp=True, ) - pipeline.change_args( - "quantize", - get_symmetric_a16w8_add_quantizer( - per_channel_quantization=per_channel_quantization - ), + pipeline.quantizer.set_global( + get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization) ) pipeline.run() diff --git a/backends/arm/test/ops/test_addmm.py b/backends/arm/test/ops/test_addmm.py index 753cb599b2b..10bba4311bf 100644 --- a/backends/arm/test/ops/test_addmm.py +++ b/backends/arm/test/ops/test_addmm.py @@ -167,26 +167,26 @@ def test_addmm_u85_INT(test_data: Tuple): @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_addmm_vgf_FP(test_data: input_t1): +def test_addmm_vgf_no_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( Addmm(), (*test_data,), aten_op=aten_op, exir_op=exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_addmm_vgf_INT(test_data: input_t1): +def test_addmm_vgf_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( Addmm(), (*test_data,), aten_op=[], exir_op=exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @@ -211,9 +211,6 @@ def get_symmetric_a16w8_addmm_quantizer(per_channel_quantization=False): @common.parametrize("test_data", test_data_suite) -@pytest.mark.xfail( - reason="missing int16 addmm ops support; fails at TOSA reference model with Unsupported operation type or rank. 
See: https://github.com/pytorch/executorch/issues/13979" -) def test_addmm_16a8w_tosa_INT(test_data: input_t1): """Test addmm (FC layer) operation with 16A8W quantization (16-bit activations, 8-bit weights)""" per_channel_quantization = False @@ -253,7 +250,6 @@ def test_addmm_16a8w_u55_INT16(test_data: input_t1): exir_ops=[], per_channel_quantization=per_channel_quantization, use_to_edge_transform_and_lower=True, - run_on_fvp=True, ) pipeline.change_args( @@ -267,9 +263,6 @@ def test_addmm_16a8w_u55_INT16(test_data: input_t1): @common.parametrize("test_data", test_data_suite) @common.XfailIfNoCorstone320 -@pytest.mark.xfail( - reason="Vela compilation fails with 'Invalid arguments' for int16 addmm operations" -) def test_addmm_16a8w_u85_INT16(test_data: input_t1): """Test addmm (FC layer) operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" per_channel_quantization = False @@ -281,7 +274,6 @@ def test_addmm_16a8w_u85_INT16(test_data: input_t1): exir_ops=[], per_channel_quantization=per_channel_quantization, use_to_edge_transform_and_lower=True, - run_on_fvp=True, ) pipeline.change_args( diff --git a/backends/arm/test/ops/test_alias_copy.py b/backends/arm/test/ops/test_alias_copy.py index 8b951a4d856..29c68930941 100644 --- a/backends/arm/test/ops/test_alias_copy.py +++ b/backends/arm/test/ops/test_alias_copy.py @@ -90,25 +90,25 @@ def test_alias_u85_INT(test_data: input_t1): @common.parametrize("test_data", AliasCopy.test_data) @common.SkipIfNoModelConverter -def test_alias_vgf_FP(test_data: input_t1): +def test_alias_vgf_no_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( AliasCopy(), test_data(), AliasCopy.aten_op, AliasCopy.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", AliasCopy.test_data) @common.SkipIfNoModelConverter -def test_alias_vgf_INT(test_data: input_t1): +def test_alias_vgf_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( AliasCopy(), test_data(), AliasCopy.aten_op, AliasCopy.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_amax.py b/backends/arm/test/ops/test_amax.py index 080dddda92e..48a3932f80f 100644 --- a/backends/arm/test/ops/test_amax.py +++ b/backends/arm/test/ops/test_amax.py @@ -103,7 +103,6 @@ def test_amax_u85_INT(test_data: Amax.input_t): Amax(dim, keep_dims), data, Amax.aten_op, - run_on_fvp=True, ) pipeline.run() @@ -140,53 +139,53 @@ def test_max_dim_tosa_FP_not_delegated(): @common.parametrize("test_data", Amax.test_data) @common.SkipIfNoModelConverter -def test_amax_vgf_FP(test_data: Amax.input_t): +def test_amax_vgf_no_quant(test_data: Amax.input_t): data, dim, keep_dims = test_data() module = Amax(dim, keep_dims) pipeline = VgfPipeline[Amax.input_t]( module, data, Amax.aten_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", Amax.test_data) @common.SkipIfNoModelConverter -def test_amax_vgf_INT(test_data: Amax.input_t): +def test_amax_vgf_quant(test_data: Amax.input_t): data, dim, keep_dims = test_data() module = Amax(dim, keep_dims) pipeline = VgfPipeline[Amax.input_t]( module, data, Amax.aten_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @common.parametrize("test_data", Max.test_data) @common.SkipIfNoModelConverter -def test_max_dim_vgf_FP_to_amax(test_data: Max.input_t): +def test_max_dim_to_amax_vgf_no_quant(test_data: Max.input_t): data, dim = test_data() pipeline = VgfPipeline[Max.input_t]( Max(dim), 
data, "torch.ops.aten.max", - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", Max.test_data) @common.SkipIfNoModelConverter -def test_max_dim_vgf_INT_to_amax(test_data: Max.input_t): +def test_max_dim_to_amax_vgf_quant(test_data: Max.input_t): data, dim = test_data() pipeline = VgfPipeline[Max.input_t]( Max(dim), data, "torch.ops.aten.amax", - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_amin.py b/backends/arm/test/ops/test_amin.py index a24da9e1ba0..b237312b412 100644 --- a/backends/arm/test/ops/test_amin.py +++ b/backends/arm/test/ops/test_amin.py @@ -29,12 +29,16 @@ def __init__(self, dim, keep_dims): super().__init__() def forward(self, x): - return torch.amin(x, self.dim, self.keep_dims) + if self.dim is None: + return torch.amin(x, keepdim=self.keep_dims) + else: + return torch.amin(x, self.dim, self.keep_dims) - test_data: Dict[str, input_t] = { + test_data: Dict = { "rank_1_dim_0": lambda: ((torch.rand([10]),), 0, False), "rank_2_dim_1_keep_dims": lambda: ((torch.rand([2, 2]),), (1,), True), "rank_4_all_dim": lambda: ((torch.rand([1, 2, 5, 5]),), (0, 1, 2, 3), False), + "rank_4_no_dim": lambda: ((torch.rand([1, 2, 5, 5]),), None, False), "rank_4_0,3_keep_dims": lambda: ((torch.rand([1, 2, 2, 2]),), (0, 3), True), "rank_4_mult_batches": lambda: ((torch.rand([2, 2, 2, 2]),), (0), True), } @@ -52,7 +56,7 @@ def forward(self, x): x = torch.min(x, self.dim) return x[0] - test_data: Dict[str, input_t] = { + test_data: Dict = { "rank_1_dim_0": lambda: ((torch.rand([10]),), 0), "rank_2_dim_1": lambda: ((torch.rand([2, 2]),), 1), "rank_4_dim_2": lambda: ((torch.rand([2, 2, 2, 2]),), 2), @@ -112,7 +116,6 @@ def test_amin_u85_INT(test_data: Amin.input_t): Amin(dim, keep_dims), data, Amin.aten_op, - run_on_fvp=True, ) pipeline.run() @@ -152,48 +155,51 @@ def test_min_dim_tosa_FP_not_delegated(): @common.parametrize("test_data", Amin.test_data) @common.SkipIfNoModelConverter -def test_amin_vgf_FP(test_data: Amin.input_t): +def test_amin_vgf_no_quant(test_data: Amin.input_t): data, dim, keep_dims = test_data() pipeline = VgfPipeline[Amin.input_t]( - Amin(dim, keep_dims), data, Amin.aten_op, tosa_version="TOSA-1.0+FP" + Amin(dim, keep_dims), + data, + Amin.aten_op, + quantize=False, ) pipeline.run() @common.parametrize("test_data", Amin.test_data) @common.SkipIfNoModelConverter -def test_amin_vgf_INT(test_data: Amin.input_t): +def test_amin_vgf_quant(test_data: Amin.input_t): data, dim, keep_dims = test_data() pipeline = VgfPipeline[Amin.input_t]( Amin(dim, keep_dims), data, Amin.aten_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @common.parametrize("test_data", Min.test_data) @common.SkipIfNoModelConverter -def test_min_dim_vgf_FP_to_amin(test_data: Min.input_t): +def test_min_dim_to_amin_vgf_no_quant(test_data: Min.input_t): data, dim = test_data() pipeline = VgfPipeline[Min.input_t]( Min(dim), data, "torch.ops.aten.min", - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", Min.test_data) @common.SkipIfNoModelConverter -def test_min_dim_vgf_INT_to_amin(test_data: Min.input_t): +def test_min_dim_to_amin_vgf_quant(test_data: Min.input_t): data, dim = test_data() pipeline = VgfPipeline[Min.input_t]( Min(dim), data, "torch.ops.aten.amin", - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_any.py b/backends/arm/test/ops/test_any.py index ae738480048..3cc3432f766 100644 --- 
a/backends/arm/test/ops/test_any.py +++ b/backends/arm/test/ops/test_any.py @@ -149,8 +149,6 @@ def test_any_tosa_INT(test_data: input_t1): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -177,41 +175,36 @@ def test_any_u85_INT(test_data: input_t1): test_input(), op.aten_op, op.exir_op, - run_on_fvp=True, atol=0, rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @common.parametrize("test_data", test_data) @common.SkipIfNoModelConverter -def test_any_vgf_FP(test_data: input_t1): +def test_any_vgf_no_quant(test_data: input_t1): op, data_fn = test_data() pipeline = VgfPipeline[input_t1]( op, data_fn(), op.aten_op, op.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data) @common.SkipIfNoModelConverter -def test_any_vgf_INT(test_data: input_t1): +def test_any_vgf_quant(test_data: input_t1): op, data_fn = test_data() pipeline = VgfPipeline[input_t1]( op, data_fn(), op.aten_op, op.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/ops/test_arange.py b/backends/arm/test/ops/test_arange.py index 33cca542922..90ab437b9e7 100644 --- a/backends/arm/test/ops/test_arange.py +++ b/backends/arm/test/ops/test_arange.py @@ -98,7 +98,6 @@ def test_arange_start_step_tosa_INT(test_data: test_data_t): ArangeAdd.aten_op, ArangeAdd.exir_op, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -111,7 +110,6 @@ def test_arange_start_step_u55_INT(test_data: test_data_t): input_data(), ArangeAdd.aten_op, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -124,13 +122,12 @@ def test_arange_start_step_u85_INT(test_data: test_data_t): input_data(), ArangeAdd.aten_op, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @common.parametrize("test_data", ArangeAdd.test_data) @common.SkipIfNoModelConverter -def test_arange_start_step_vgf_FP(test_data: test_data_t): +def test_arange_start_step_vgf_no_quant(test_data: test_data_t): input_data, init_data = test_data module = ArangeAdd(*init_data) pipeline = VgfPipeline[input_t]( @@ -138,14 +135,14 @@ def test_arange_start_step_vgf_FP(test_data: test_data_t): input_data(), module.aten_op, module.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", ArangeAdd.test_data) @common.SkipIfNoModelConverter -def test_arange_start_step_vgf_INT(test_data: test_data_t): +def test_arange_start_step_vgf_quant(test_data: test_data_t): input_data, init_data = test_data module = ArangeAdd(*init_data) pipeline = VgfPipeline[input_t]( @@ -153,7 +150,7 @@ def test_arange_start_step_vgf_INT(test_data: test_data_t): input_data(), module.aten_op, module.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @@ -202,28 +199,28 @@ def test_linspace_tosa_INT(test_data: test_data_t): @common.parametrize("test_data", LinspaceAdd.test_data) @common.SkipIfNoModelConverter -def test_linspace_vgf_FP(test_data: test_data_t): +def test_linspace_vgf_no_quant(test_data: test_data_t): input_data, init_data = test_data pipeline = VgfPipeline[input_t]( LinspaceAdd(*init_data), input_data(), LinspaceAdd.aten_op, LinspaceAdd.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", LinspaceAdd.test_data) @common.SkipIfNoModelConverter -def 
test_linspace_vgf_INT(test_data: test_data_t): +def test_linspace_vgf_quant(test_data: test_data_t): input_data, init_data = test_data pipeline = VgfPipeline[input_t]( LinspaceAdd(*init_data), input_data(), LinspaceAdd.aten_op, LinspaceAdd.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @@ -252,10 +249,10 @@ def test_arange_u85_INT(): @pytest.mark.skip(reason=skip_str) -def test_arange_vgf_FP(): +def test_arange_vgf_no_quant(): pass @pytest.mark.skip(reason=skip_str) -def test_arange_vgf_INT(): +def test_arange_vgf_quant(): pass diff --git a/backends/arm/test/ops/test_asin.py b/backends/arm/test/ops/test_asin.py index 9c37bddbd92..e00e6364d22 100644 --- a/backends/arm/test/ops/test_asin.py +++ b/backends/arm/test/ops/test_asin.py @@ -83,23 +83,23 @@ def test_asin_u85_INT(test_data: Tuple): @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_asin_vgf_FP(test_data: Tuple): +def test_asin_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[input_t]( Asin(), (test_data(),), aten_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_asin_vgf_INT(test_data: Tuple): +def test_asin_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[input_t]( Asin(), (test_data(),), aten_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_asinh.py b/backends/arm/test/ops/test_asinh.py index 305c822601c..db902a3dc8f 100644 --- a/backends/arm/test/ops/test_asinh.py +++ b/backends/arm/test/ops/test_asinh.py @@ -82,23 +82,23 @@ def test_asinh_u85_INT(test_data: Tuple): @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_asinh_vgf_FP(test_data: Tuple): +def test_asinh_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[input_t]( Asinh(), (test_data(),), aten_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_asinh_vgf_INT(test_data: Tuple): +def test_asinh_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[input_t]( Asinh(), (test_data(),), aten_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_at.py b/backends/arm/test/ops/test_at.py index b8a20760820..4b8223cc0b7 100644 --- a/backends/arm/test/ops/test_at.py +++ b/backends/arm/test/ops/test_at.py @@ -152,105 +152,105 @@ def test_atmatmul_mixed_pattern2_tosa_INT(test_data: input_t1): @common.parametrize("test_data", AtMatMulSingleInput.test_data_generators) @common.SkipIfNoModelConverter -def test_atmatmul_single_input_vgf_FP(test_data: input_t1): +def test_atmatmul_single_input_vgf_no_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( AtMatMulSingleInput(), test_data(), aten_op_mm, exir_op_mm, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", AtMatMulDoubleInput.test_data_generators) @common.SkipIfNoModelConverter -def test_atmatmul_double_input_vgf_FP(test_data: input_t1): +def test_atmatmul_double_input_vgf_no_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( AtMatMulDoubleInput(), test_data(), aten_op_mm, exir_op_mm, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", AtMatMulMixedPattern1.test_data_generators) @common.SkipIfNoModelConverter -def test_atmatmul_mixed_pattern1_vgf_FP(test_data: input_t1): +def 
test_atmatmul_mixed_pattern1_vgf_no_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( AtMatMulMixedPattern1(), test_data(), aten_op_mm, exir_op_mm, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", AtMatMulMixedPattern2.test_data_generators) @common.SkipIfNoModelConverter -def test_atmatmul_mixed_pattern2_vgf_FP(test_data: input_t1): +def test_atmatmul_mixed_pattern2_vgf_no_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( AtMatMulMixedPattern2(), test_data(), aten_op_mm, exir_op_mm, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", AtMatMulSingleInput.test_data_generators) @common.SkipIfNoModelConverter -def test_atmatmul_single_input_vgf_INT(test_data: input_t1): +def test_atmatmul_single_input_vgf_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( AtMatMulSingleInput(), test_data(), aten_op_mm, exir_op_mm, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @common.parametrize("test_data", AtMatMulDoubleInput.test_data_generators) @common.SkipIfNoModelConverter -def test_atmatmul_double_input_vgf_INT(test_data: input_t1): +def test_atmatmul_double_input_vgf_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( AtMatMulDoubleInput(), test_data(), aten_op_mm, exir_op_mm, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @common.parametrize("test_data", AtMatMulMixedPattern1.test_data_generators) @common.SkipIfNoModelConverter -def test_atmatmul_mixed_pattern1_vgf_INT(test_data: input_t1): +def test_atmatmul_mixed_pattern1_vgf_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( AtMatMulMixedPattern1(), test_data(), aten_op_mm, exir_op_mm, qtol=1, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @common.parametrize("test_data", AtMatMulMixedPattern2.test_data_generators) @common.SkipIfNoModelConverter -def test_atmatmul_mixed_pattern2_vgf_INT(test_data: input_t1): +def test_atmatmul_mixed_pattern2_vgf_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( AtMatMulMixedPattern2(), test_data(), aten_op_mm, exir_op_mm, qtol=1, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_atan.py b/backends/arm/test/ops/test_atan.py index 51114d2800f..4e103dcaa82 100644 --- a/backends/arm/test/ops/test_atan.py +++ b/backends/arm/test/ops/test_atan.py @@ -87,25 +87,25 @@ def test_atan_u85_INT(test_data: Tuple): @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_atan_vgf_FP(test_data: Tuple): +def test_atan_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( Atan(), (test_data,), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_atan_vgf_INT(test_data: Tuple): +def test_atan_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( Atan(), (test_data,), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_atanh.py b/backends/arm/test/ops/test_atanh.py index 12754a34646..8ac270849a1 100644 --- a/backends/arm/test/ops/test_atanh.py +++ b/backends/arm/test/ops/test_atanh.py @@ -88,25 +88,25 @@ def test_atanh_u85_INT(test_data: Tuple): @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_atanh_vgf_FP(test_data: input_t1): +def test_atanh_vgf_no_quant(test_data: input_t1): 
pipeline = VgfPipeline[input_t1]( Atanh(), (test_data,), aten_op=aten_op, exir_op=exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_atanh_vgf_INT(test_data: input_t1): +def test_atanh_vgf_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( Atanh(), (test_data,), aten_op=aten_op, exir_op=exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run()
diff --git a/backends/arm/test/ops/test_avg_pool2d.py b/backends/arm/test/ops/test_avg_pool2d.py index be54c76e68b..6998cb97419 100644 --- a/backends/arm/test/ops/test_avg_pool2d.py +++ b/backends/arm/test/ops/test_avg_pool2d.py
@@ -23,7 +23,7 @@ VgfPipeline, ) -aten_op = "torch.ops.aten.avg_pool2d.default" +aten_op = "avg_pool2d.default" exir_op = "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default" input_t = Tuple[torch.Tensor]
@@ -34,6 +34,15 @@ def forward(self, *args, **kwargs): return super().forward(*args, **kwargs) + +class BecomesMeanInToEdge(torch.nn.Module): + """This average pool is converted to a mean op when lowering to edge, so the decompose_meandim pass does not + trigger until the backend pipeline, which requires extra care. + """ + + def forward(self, x: torch.Tensor): + return torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)) + + test_modules = { "zeros": lambda: (AvgPool2d(4, 2, 0, False), (torch.zeros(1, 16, 50, 32),)), "ones": lambda: (AvgPool2d(4, 2, 0, False, True), (torch.ones(1, 16, 50, 32),)),
@@ -110,6 +119,9 @@ def forward(self, *args, **kwargs): AvgPool2d(3, (1, 3), 1, count_include_pad=False), (torch.rand(1, 16, 54, 54),), ), + "becomes_mean_rank3": lambda: (BecomesMeanInToEdge(), (torch.rand(2, 8, 8),)), + "becomes_mean_rank4": lambda: (BecomesMeanInToEdge(), (torch.rand(2, 2, 8, 8),)), + "becomes_mean_rank5": lambda: (BecomesMeanInToEdge(), (torch.rand(2, 2, 8, 8),)), }
@@ -141,6 +153,21 @@ def test_avg_pool2d_tosa_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_modules) +def test_avg_pool2d_tosa_INT_a16w8(test_module): + """Test avg_pool2d operation with int16 I/O quantization for TOSA INT.""" + model, input_tensor = test_module() + pipeline = TosaPipelineINT[input_t]( + model, + input_tensor, + aten_op, + exir_op, + tosa_extensions=["int16"], + run_on_tosa_ref_model=conftest.is_option_enabled("tosa_ref_model"), + ) + pipeline.run() + + @common.parametrize("test_module", test_modules) @common.XfailIfNoCorstone300 def test_avg_pool2d_u55_INT(test_module):
@@ -151,7 +178,23 @@ def test_avg_pool2d_u55_INT(test_module): input_tensor, aten_op, exir_op, - run_on_fvp=True, + ) + pipeline.run() + + +@common.parametrize("test_module", test_modules) +@common.XfailIfNoCorstone300 +def test_avg_pool2d_16a8w_u55_INT16(test_module): + """Test avg_pool2d with 16A8W quantization on U55 (16-bit activations, 8-bit weights)""" + model, input_tensor = test_module() + pipeline = EthosU55PipelineINT[input_t]( + model, + input_tensor, + aten_op, + exir_op, + per_channel_quantization=False, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, ) pipeline.run()
@@ -166,35 +209,51 @@ def test_avg_pool2d_u85_INT(test_module): input_tensor, aten_op, exir_op, - run_on_fvp=True, + ) + pipeline.run() + + +@common.parametrize("test_module", test_modules) +@common.XfailIfNoCorstone320 +def test_avg_pool2d_16a8w_u85_INT16(test_module): + """Test avg_pool2d with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + model, input_tensor =
test_module() + pipeline = EthosU85PipelineINT[input_t]( + model, + input_tensor, + aten_op, + exir_op, + per_channel_quantization=False, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, ) pipeline.run() @common.parametrize("test_module", test_modules) @common.SkipIfNoModelConverter -def test_avg_pool2d_vgf_FP(test_module): +def test_avg_pool2d_vgf_no_quant(test_module): model, input_tensor = test_module() pipeline = VgfPipeline[input_t]( model, input_tensor, aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_module", test_modules) @common.SkipIfNoModelConverter -def test_avg_pool2d_vgf_INT(test_module): +def test_avg_pool2d_vgf_quant(test_module): model, input_tensor = test_module() pipeline = VgfPipeline[input_t]( model, input_tensor, aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_batch_norm.py b/backends/arm/test/ops/test_batch_norm.py index a28180b7b57..b4191da2421 100644 --- a/backends/arm/test/ops/test_batch_norm.py +++ b/backends/arm/test/ops/test_batch_norm.py @@ -102,20 +102,20 @@ def test_native_batch_norm_legit_no_training_tosa_INT_not_delegated(): @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_native_batch_norm_legit_no_training_vgf_FP(test_data: Tuple): +def test_native_batch_norm_legit_no_training_vgf_no_quant(test_data: Tuple): inp, model_params = test_data() pipeline = VgfPipeline[input_t1]( BatchNorm2d(*model_params), (inp,), aten_op=BatchNorm2d.aten_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_native_batch_norm_legit_no_training_vgf_INT(test_data: Tuple): +def test_native_batch_norm_legit_no_training_vgf_quant(test_data: Tuple): # TODO(MLETORCH-100: Quantized stand-alone batch norms) pass @@ -220,7 +220,6 @@ def test_native_batch_norm_legit_no_training_u55_INT_conv(test_data: Tuple): BatchNorm2dConv(*model_params), (test_data,), aten_ops=BatchNorm2dConv.aten_ops[0], # Bn is removed before check - run_on_fvp=True, qtol=1, ) pipeline.run() @@ -234,7 +233,6 @@ def test_native_batch_norm_legit_no_training_u85_INT_conv(test_data: Tuple): BatchNorm2dConv(*model_params), (test_data,), aten_ops=BatchNorm2dConv.aten_ops[0], # Bn is removed before check - run_on_fvp=True, qtol=1, ) pipeline.run() @@ -242,27 +240,27 @@ def test_native_batch_norm_legit_no_training_u85_INT_conv(test_data: Tuple): @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_native_batch_norm_legit_no_training_vgf_FP_conv(test_data: Tuple): +def test_native_batch_norm_legit_no_training_conv_vgf_no_quant(test_data: Tuple): test_data, model_params = test_data() pipeline = VgfPipeline[input_t1]( BatchNorm2dConv(*model_params), (test_data,), aten_op=BatchNorm2dConv.aten_ops, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_native_batch_norm_legit_no_training_vgf_INT_conv(test_data: Tuple): +def test_native_batch_norm_legit_no_training_conv_vgf_quant(test_data: Tuple): test_data, model_params = test_data() pipeline = VgfPipeline[input_t1]( BatchNorm2dConv(*model_params), (test_data,), - aten_op=BatchNorm2dConv.aten_ops[0], # Bn is removed before check + aten_op=BatchNorm2dConv.aten_ops[0], qtol=1, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @@ 
-336,7 +334,6 @@ def test_native_batch_norm_legit_no_stats_u55_INT(test_data: Tuple): BatchNorm2dNoStats(*model_params), (test_data,), aten_op=BatchNorm2dNoStats.aten_ops, - run_on_fvp=True, qtol=1, ) pipeline.run() @@ -353,7 +350,6 @@ def test_native_batch_norm_legit_no_stats_u85_INT(test_data: Tuple): BatchNorm2dNoStats(*model_params), (test_data,), aten_op=BatchNorm2dNoStats.aten_ops, - run_on_fvp=False, qtol=1, ) pipeline.run() @@ -361,13 +357,13 @@ def test_native_batch_norm_legit_no_stats_u85_INT(test_data: Tuple): @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_native_batch_norm_legit_no_stats_vgf_FP(test_data: Tuple): +def test_native_batch_norm_legit_no_stats_vgf_no_quant(test_data: Tuple): test_data, model_params = test_data() pipeline = VgfPipeline[input_t1]( BatchNorm2dNoStats(*model_params), (test_data,), aten_op=BatchNorm2dNoStats.aten_ops, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @@ -377,13 +373,13 @@ def test_native_batch_norm_legit_no_stats_vgf_FP(test_data: Tuple): ) @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_native_batch_norm_legit_no_stats_vgf_INT(test_data: Tuple): +def test_native_batch_norm_legit_no_stats_vgf_quant(test_data: Tuple): test_data, model_params = test_data() pipeline = VgfPipeline[input_t1]( BatchNorm2dNoStats(*model_params), (test_data,), aten_op=BatchNorm2dNoStats.aten_ops, qtol=1, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_bitwise.py b/backends/arm/test/ops/test_bitwise.py index 218f2290cab..1565fe6181e 100644 --- a/backends/arm/test/ops/test_bitwise.py +++ b/backends/arm/test/ops/test_bitwise.py @@ -109,8 +109,8 @@ def forward(self, tensor1: torch.Tensor, tensor2: torch.Tensor): class AndScalar(BitwiseBinaryScalar): - aten_op = "torch.ops.aten.bitwise_and.Scalar" # Tensor because it gets converted from Scalar -> Tensor in lowering + aten_op = "torch.ops.aten.bitwise_and.Tensor" exir_op = "executorch_exir_dialects_edge__ops_aten_bitwise_and_Tensor" exir_op_scalar = "executorch_exir_dialects_edge__ops_aten_bitwise_and_Scalar" @@ -119,8 +119,8 @@ def forward(self, tensor: torch.Tensor, scalar: int): class XorScalar(BitwiseBinaryScalar): - aten_op = "torch.ops.aten.bitwise_xor.Scalar" # Tensor because it gets converted from Scalar -> Tensor in lowering + aten_op = "torch.ops.aten.bitwise_xor.Tensor" exir_op = "executorch_exir_dialects_edge__ops_aten_bitwise_xor_Tensor" exir_op_scalar = "executorch_exir_dialects_edge__ops_aten_bitwise_xor_Scalar" @@ -129,8 +129,8 @@ def forward(self, tensor: torch.Tensor, scalar: int): class OrScalar(BitwiseBinaryScalar): - aten_op = "torch.ops.aten.bitwise_or.Scalar" # Tensor because it gets converted from Scalar -> Tensor in lowering + aten_op = "torch.ops.aten.bitwise_or.Tensor" exir_op = "executorch_exir_dialects_edge__ops_aten_bitwise_or_Tensor" exir_op_scalar = "executorch_exir_dialects_edge__ops_aten_bitwise_or_Scalar" @@ -174,8 +174,6 @@ def test_bitwise_and_tensor_tosa_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -190,8 +188,6 @@ def test_bitwise_and_scalar_tosa_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -235,13 +231,10 @@ def test_bitwise_and_scalar_u85_INT(test_data: input_t2): test_data(), AndScalar.aten_op, AndScalar.exir_op, - run_on_fvp=True, atol=0, 
rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -253,19 +246,16 @@ def test_bitwise_and_tensor_u85_INT(test_data: input_t2): test_data(), And().aten_op, And().exir_op, - run_on_fvp=True, atol=0, rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @common.parametrize("test_data", And().test_data) @common.SkipIfNoModelConverter -def test_bitwise_and_tensor_vgf_FP(test_data: input_t2): +def test_bitwise_and_tensor_vgf_no_quant(test_data: input_t2): pipeline = OpNotSupportedPipeline[input_t2]( And(), test_data(), @@ -276,7 +266,7 @@ def test_bitwise_and_tensor_vgf_FP(test_data: input_t2): @common.parametrize("test_data", AndScalar().test_data) @common.SkipIfNoModelConverter -def test_bitwise_and_scalar_vgf_FP(test_data: input_t2): +def test_bitwise_and_scalar_vgf_no_quant(test_data: input_t2): pipeline = OpNotSupportedPipeline[input_t2]( AndScalar(), test_data(), @@ -287,7 +277,7 @@ def test_bitwise_and_scalar_vgf_FP(test_data: input_t2): @common.parametrize("test_data", And().test_data) @common.SkipIfNoModelConverter -def test_bitwise_and_tensor_vgf_INT(test_data: input_t2): +def test_bitwise_and_tensor_vgf_quant(test_data: input_t2): pipeline = VgfPipeline[input_t2]( And(), test_data(), @@ -296,16 +286,14 @@ def test_bitwise_and_tensor_vgf_INT(test_data: input_t2): atol=0, rtol=0, qtol=0, - tosa_version="TOSA-1.0+INT", + quantize=True, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @common.parametrize("test_data", AndScalar().test_data) @common.SkipIfNoModelConverter -def test_bitwise_and_scalar_vgf_INT(test_data: input_t2): +def test_bitwise_and_scalar_vgf_quant(test_data: input_t2): pipeline = VgfPipeline[input_t2]( AndScalar(), test_data(), @@ -314,10 +302,8 @@ def test_bitwise_and_scalar_vgf_INT(test_data: input_t2): atol=0, rtol=0, qtol=0, - tosa_version="TOSA-1.0+INT", + quantize=True, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -357,8 +343,6 @@ def test_bitwise_xor_tensor_tosa_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -373,8 +357,6 @@ def test_bitwise_xor_scalar_tosa_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -418,13 +400,10 @@ def test_bitwise_xor_tensor_u85_INT(test_data: input_t2): test_data(), Xor().aten_op, Xor().exir_op, - run_on_fvp=True, atol=0, rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -436,19 +415,16 @@ def test_bitwise_xor_scalar_u85_INT(test_data: input_t2): test_data(), XorScalar.aten_op, XorScalar.exir_op, - run_on_fvp=True, atol=0, rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @common.parametrize("test_data", Xor().test_data) @common.SkipIfNoModelConverter -def test_bitwise_xor_tensor_vgf_FP(test_data: input_t2): +def test_bitwise_xor_tensor_vgf_no_quant(test_data: input_t2): pipeline = OpNotSupportedPipeline[input_t2]( Xor(), test_data(), @@ -459,7 +435,7 @@ def test_bitwise_xor_tensor_vgf_FP(test_data: input_t2): @common.parametrize("test_data", XorScalar().test_data) @common.SkipIfNoModelConverter -def test_bitwise_xor_scalar_vgf_FP(test_data: input_t2): +def test_bitwise_xor_scalar_vgf_no_quant(test_data: input_t2): pipeline = 
OpNotSupportedPipeline[input_t2]( XorScalar(), test_data(), @@ -470,7 +446,7 @@ def test_bitwise_xor_scalar_vgf_FP(test_data: input_t2): @common.parametrize("test_data", Xor().test_data) @common.SkipIfNoModelConverter -def test_bitwise_xor_tensor_vgf_INT(test_data: input_t2): +def test_bitwise_xor_tensor_vgf_quant(test_data: input_t2): pipeline = VgfPipeline[input_t2]( Xor(), test_data(), @@ -479,16 +455,14 @@ def test_bitwise_xor_tensor_vgf_INT(test_data: input_t2): atol=0, rtol=0, qtol=0, - tosa_version="TOSA-1.0+INT", + quantize=True, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @common.parametrize("test_data", XorScalar().test_data) @common.SkipIfNoModelConverter -def test_bitwise_xor_scalar_vgf_INT(test_data: input_t2): +def test_bitwise_xor_scalar_vgf_quant(test_data: input_t2): pipeline = VgfPipeline[input_t2]( XorScalar(), test_data(), @@ -497,10 +471,8 @@ def test_bitwise_xor_scalar_vgf_INT(test_data: input_t2): atol=0, rtol=0, qtol=0, - tosa_version="TOSA-1.0+INT", + quantize=True, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -540,8 +512,6 @@ def test_bitwise_or_tensor_tosa_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -556,8 +526,6 @@ def test_bitwise_or_scalar_tosa_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -601,13 +569,10 @@ def test_bitwise_or_tensor_u85_INT(test_data: input_t2): test_data(), Or().aten_op, Or().exir_op, - run_on_fvp=True, atol=0, rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -619,19 +584,16 @@ def test_bitwise_or_scalar_u85_INT(test_data: input_t2): test_data(), OrScalar.aten_op, OrScalar.exir_op, - run_on_fvp=True, atol=0, rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @common.parametrize("test_data", Or().test_data) @common.SkipIfNoModelConverter -def test_bitwise_or_tensor_vgf_FP(test_data: input_t2): +def test_bitwise_or_tensor_vgf_no_quant(test_data: input_t2): pipeline = OpNotSupportedPipeline[input_t2]( Or(), test_data(), @@ -642,7 +604,7 @@ def test_bitwise_or_tensor_vgf_FP(test_data: input_t2): @common.parametrize("test_data", OrScalar().test_data) @common.SkipIfNoModelConverter -def test_bitwise_or_scalar_vgf_FP(test_data: input_t2): +def test_bitwise_or_scalar_vgf_no_quant(test_data: input_t2): pipeline = OpNotSupportedPipeline[input_t2]( OrScalar(), test_data(), @@ -653,7 +615,7 @@ def test_bitwise_or_scalar_vgf_FP(test_data: input_t2): @common.parametrize("test_data", Or().test_data) @common.SkipIfNoModelConverter -def test_bitwise_or_tensor_vgf_INT(test_data: input_t2): +def test_bitwise_or_tensor_vgf_quant(test_data: input_t2): pipeline = VgfPipeline[input_t2]( Or(), test_data(), @@ -662,16 +624,14 @@ def test_bitwise_or_tensor_vgf_INT(test_data: input_t2): atol=0, rtol=0, qtol=0, - tosa_version="TOSA-1.0+INT", + quantize=True, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @common.parametrize("test_data", OrScalar().test_data) @common.SkipIfNoModelConverter -def test_bitwise_or_scalar_vgf_INT(test_data: input_t2): +def test_bitwise_or_scalar_vgf_quant(test_data: input_t2): pipeline = VgfPipeline[input_t2]( OrScalar(), test_data(), @@ -680,10 +640,8 @@ def test_bitwise_or_scalar_vgf_INT(test_data: 
input_t2): atol=0, rtol=0, qtol=0, - tosa_version="TOSA-1.0+INT", + quantize=True, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/ops/test_bitwise_not.py b/backends/arm/test/ops/test_bitwise_not.py new file mode 100644 index 00000000000..a483a376fa2 --- /dev/null +++ b/backends/arm/test/ops/test_bitwise_not.py @@ -0,0 +1,114 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU85PipelineINT, + OpNotSupportedPipeline, + TosaPipelineINT, + VgfPipeline, +) + +aten_op = "torch.ops.aten.bitwise_not.default" +exir_op = "executorch_exir_dialects_edge__ops_aten_bitwise_not_default" + +input_t1 = Tuple[torch.Tensor] + +test_data_suite = { + "zeros": torch.zeros(1, 10, 10, 10, dtype=torch.int32), + "ones": torch.ones(10, 2, 3, dtype=torch.int8), + "pattern1_int8": 0xAA * torch.ones(1, 2, 2, 2, dtype=torch.int8), + "pattern1_int16": 0xAAAA * torch.ones(1, 2, 2, 2, dtype=torch.int16), + "pattern1_int32": 0xAAAAAAAA * torch.ones(1, 2, 2, 2, dtype=torch.int32), + "pattern2_int8": 0xCC * torch.ones(1, 2, 2, 2, dtype=torch.int8), + "pattern2_int16": 0xCCCC * torch.ones(1, 2, 2, 2, dtype=torch.int16), + "pattern2_int32": 0xCCCCCCCC * torch.ones(1, 2, 2, 2, dtype=torch.int32), + "rand_rank2": torch.randint(-128, 127, (10, 10), dtype=torch.int8), + "rand_rank4": torch.randint(-128, 127, (1, 10, 10, 10), dtype=torch.int8), +} + + +class BitwiseNot(torch.nn.Module): + + def forward(self, x: torch.Tensor): + return torch.bitwise_not(x) + + +@common.parametrize("test_data", test_data_suite) +def test_bitwise_not_tosa_FP(test_data: Tuple): + # We don't delegate bitwise_not since it is not supported on the FP profile. + pipeline = OpNotSupportedPipeline[input_t1]( + BitwiseNot(), + (test_data,), + {exir_op: 1}, + quantize=False, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +def test_bitwise_not_tosa_INT(test_data: Tuple): + pipeline = TosaPipelineINT[input_t1]( + BitwiseNot(), + (test_data,), + aten_op=aten_op, + exir_op=exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +def test_bitwise_not_u55_INT(test_data: Tuple): + # We don't delegate bitwise_not since it is not supported on U55. + pipeline = OpNotSupportedPipeline[input_t1]( + BitwiseNot(), + (test_data,), + {exir_op: 1}, + quantize=True, + u55_subset=True, + ) + pipeline.run() + + +@common.XfailIfNoCorstone320 +@common.parametrize("test_data", test_data_suite) +def test_bitwise_not_u85_INT(test_data: Tuple): + pipeline = EthosU85PipelineINT[input_t1]( + BitwiseNot(), + (test_data,), + aten_ops=aten_op, + exir_ops=exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.SkipIfNoModelConverter +def test_bitwise_not_vgf_no_quant(test_data: Tuple): + # We don't delegate bitwise_not since it is not supported on the FP profile. 
+ pipeline = OpNotSupportedPipeline[input_t1]( + BitwiseNot(), + (test_data,), + {exir_op: 1}, + quantize=False, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.SkipIfNoModelConverter +def test_bitwise_not_vgf_quant(test_data: Tuple): + pipeline = VgfPipeline[input_t1]( + BitwiseNot(), + (test_data,), + aten_op, + exir_op, + quantize=True, + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_bmm.py b/backends/arm/test/ops/test_bmm.py index 7c0fc1665bb..5a905d1a1a7 100644 --- a/backends/arm/test/ops/test_bmm.py +++ b/backends/arm/test/ops/test_bmm.py @@ -97,7 +97,6 @@ def test_bmm_u55_INT(test_data: input_t1): test_data(), aten_op_bmm, exir_op_bmm, - run_on_fvp=True, ) pipeline.run() @@ -110,7 +109,6 @@ def test_bmm_u85_INT(test_data: input_t1): test_data(), aten_op_bmm, exir_op_bmm, - run_on_fvp=True, ) pipeline.run() @@ -123,7 +121,6 @@ def test_bmm_u55_INT_single_input(test_data: input_t1): test_data(), aten_op_bmm, exir_op_bmm, - run_on_fvp=True, ) pipeline.run() @@ -136,56 +133,61 @@ def test_bmm_u85_INT_single_input(test_data: input_t1): test_data(), aten_op_bmm, exir_op_bmm, - run_on_fvp=True, ) pipeline.run() @common.parametrize("test_data", BMM.test_data_generators) @common.SkipIfNoModelConverter -def test_bmm_vgf_FP(test_data: input_t1): +def test_bmm_vgf_no_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( - BMM(), test_data(), aten_op_bmm, exir_op_bmm, tosa_version="TOSA-1.0+FP" + BMM(), + test_data(), + aten_op_bmm, + exir_op_bmm, + quantize=False, ) pipeline.run() -@common.parametrize("test_data", BMMSingleInput.test_data_generators) +@common.parametrize( + "test_data", + BMMSingleInput.test_data_generators, + flakies={"rand_big_1": 3}, +) @common.SkipIfNoModelConverter -def test_bmm_vgf_FP_single_input(test_data: input_t1): +def test_bmm_single_input_vgf_no_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( BMMSingleInput(), test_data(), aten_op_bmm, exir_op_bmm, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", BMM.test_data_generators) @common.SkipIfNoModelConverter -def test_bmm_vgf_INT(test_data: input_t1): +def test_bmm_vgf_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( BMM(), test_data(), aten_op_bmm, exir_op_bmm, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @common.parametrize("test_data", BMMSingleInput.test_data_generators) @common.SkipIfNoModelConverter -def test_bmm_vgf_INT_single_input(test_data: input_t1): +def test_bmm_single_input_vgf_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( BMMSingleInput(), test_data(), aten_op_bmm, exir_op_bmm, - tosa_version="TOSA-1.0+INT", + quantize=True, ) - # TODO: MLETORCH-1136 Change args of run_method_and_compare_outputs of the vgf tests - # pipeline.change_args("run_method_and_compare_outputs", qtol=1) pipeline.run() diff --git a/backends/arm/test/ops/test_cat.py b/backends/arm/test/ops/test_cat.py index 84ecd8641b5..84765dde90b 100644 --- a/backends/arm/test/ops/test_cat.py +++ b/backends/arm/test/ops/test_cat.py @@ -8,13 +8,11 @@ from typing import Tuple -import pytest import torch from executorch.backends.arm.quantizer.arm_quantizer import ( get_symmetric_a16w8_quantization_config, - TOSAQuantizer, ) -from executorch.backends.arm.test import common, conftest +from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, @@ -23,8 +21,6 @@ TosaPipelineINT, VgfPipeline, ) -from 
executorch.backends.arm.tosa.specification import TosaSpecification -from executorch.backends.xnnpack.test.tester import Quantize input_t1 = Tuple[torch.Tensor] # Input x @@ -120,7 +116,6 @@ def test_cat_u55_INT(test_data: Tuple): test_data(), aten_op, exir_op, - run_on_fvp=True, ) pipeline.run() @@ -133,56 +128,37 @@ def test_cat_u85_INT(test_data: Tuple): test_data(), aten_op, exir_op, - run_on_fvp=True, ) pipeline.run() @common.parametrize("test_data", Cat.test_parameters) @common.SkipIfNoModelConverter -def test_cat_vgf_FP(test_data: Tuple): +def test_cat_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( - Cat(), test_data(), aten_op, exir_op, tosa_version="TOSA-1.0+FP" + Cat(), + test_data(), + aten_op, + exir_op, + quantize=False, ) pipeline.run() @common.parametrize("test_data", Cat.test_parameters) @common.SkipIfNoModelConverter -def test_cat_vgf_INT(test_data: Tuple): +def test_cat_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( Cat(), test_data(), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() -def get_symmetric_a16w8_cat_quantizer(per_channel_quantization=False): - tosa_version = conftest.get_option("tosa_version") - tosa_profiles = { - "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"), - } - - quantizer = TOSAQuantizer(tosa_profiles[tosa_version]) - quantizer.set_global( - get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization) - ) - - return Quantize( - quantizer, - get_symmetric_a16w8_quantization_config( - is_per_channel=per_channel_quantization - ), - ) - - @common.parametrize("test_data", Cat.test_parameters) -@pytest.mark.xfail( - reason="missing int16 cat ops support; fails at TOSA reference model with Unsupported operation type or rank. 
See: https://github.com/pytorch/executorch/issues/13978" -) def test_cat_16a8w_tosa_INT(test_data: Tuple): """Test cat operation with 16A8W quantization (16-bit activations, 8-bit weights)""" per_channel_quantization = False @@ -196,21 +172,14 @@ def test_cat_16a8w_tosa_INT(test_data: Tuple): use_to_edge_transform_and_lower=True, tosa_extensions=["int16"], ) - - pipeline.change_args( - "quantize", - get_symmetric_a16w8_cat_quantizer( - per_channel_quantization=per_channel_quantization - ), + pipeline.quantizer.set_global( + get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization) ) pipeline.run() @common.parametrize("test_data", Cat.test_parameters) @common.XfailIfNoCorstone300 -@pytest.mark.xfail( - reason="Vela compilation fails with 'Invalid arguments' for int16 cat operations" -) def test_cat_16a8w_u55_INT16(test_data: Tuple): """Test cat operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)""" per_channel_quantization = False @@ -222,23 +191,16 @@ def test_cat_16a8w_u55_INT16(test_data: Tuple): exir_op, per_channel_quantization=per_channel_quantization, use_to_edge_transform_and_lower=True, - run_on_fvp=True, ) - - pipeline.change_args( - "quantize", - get_symmetric_a16w8_cat_quantizer( - per_channel_quantization=per_channel_quantization - ), + pipeline.quantizer.set_global( + get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization) ) + pipeline.run() @common.parametrize("test_data", Cat.test_parameters) @common.XfailIfNoCorstone320 -@pytest.mark.xfail( - reason="Vela compilation fails with 'Invalid arguments' for int16 cat operations" -) def test_cat_16a8w_u85_INT16(test_data: Tuple): """Test cat operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" per_channel_quantization = False @@ -250,13 +212,8 @@ def test_cat_16a8w_u85_INT16(test_data: Tuple): exir_op, per_channel_quantization=per_channel_quantization, use_to_edge_transform_and_lower=True, - run_on_fvp=True, ) - - pipeline.change_args( - "quantize", - get_symmetric_a16w8_cat_quantizer( - per_channel_quantization=per_channel_quantization - ), + pipeline.quantizer.set_global( + get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization) ) pipeline.run() diff --git a/backends/arm/test/ops/test_ceil.py b/backends/arm/test/ops/test_ceil.py index 64e9040a974..93b5f9cd009 100644 --- a/backends/arm/test/ops/test_ceil.py +++ b/backends/arm/test/ops/test_ceil.py @@ -78,7 +78,6 @@ def test_ceil_u55_INT(test_data: input_t1): (data,), module.aten_op, module.exir_op, - run_on_fvp=True, ) pipeline.run() @@ -92,28 +91,27 @@ def test_ceil_u85_INT(test_data: input_t1): (data,), module.aten_op, module.exir_op, - run_on_fvp=True, ) pipeline.run() @common.parametrize("test_data", test_data) @common.SkipIfNoModelConverter -def test_ceil_vgf_FP(test_data: input_t1): +def test_ceil_vgf_no_quant(test_data: input_t1): module, data = test_data() pipeline = VgfPipeline[input_t1]( module, (data,), module.aten_op, module.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data) @common.SkipIfNoModelConverter -def test_ceil_vgf_INT(test_data: input_t1): +def test_ceil_vgf_quant(test_data: input_t1): module, data = test_data() pipeline = VgfPipeline[input_t1]( module, @@ -122,6 +120,6 @@ def test_ceil_vgf_INT(test_data: input_t1): module.exir_op, atol=0.06, rtol=0.01, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_clamp.py 
b/backends/arm/test/ops/test_clamp.py index ba490ccc0c6..60477c6cbe4 100644 --- a/backends/arm/test/ops/test_clamp.py +++ b/backends/arm/test/ops/test_clamp.py @@ -35,6 +35,25 @@ "rank_4_no_max": lambda: (torch.rand(1, 10, 10, 1) - 3, -3.3, None), } +test_data_suite_int32 = { + "int32_rank2": lambda: (torch.randint(-50, 50, (2, 3), dtype=torch.int32), -10, 10), + "int32_rank3_no_min": lambda: ( + torch.randint(-100, 100, (1, 3, 3), dtype=torch.int32), + None, + 25, + ), + "int32_rank3_no_max": lambda: ( + torch.randint(-100, 100, (1, 3, 3), dtype=torch.int32), + -25, + None, + ), + "int32_rank4_large_range": lambda: ( + torch.randint(-200, 200, (1, 2, 4, 4), dtype=torch.int32), + torch.iinfo(torch.int32).min, + torch.iinfo(torch.int32).max, + ), +} + class Clamp(torch.nn.Module): def __init__( @@ -53,7 +72,6 @@ def forward(self, x): @common.parametrize("test_data", test_data_suite) def test_clamp_tosa_FP(test_data): - input_tensor, min_val, max_val = test_data() model = Clamp(min_val, max_val) @@ -69,7 +87,20 @@ def test_clamp_tosa_FP(test_data): @common.parametrize("test_data", test_data_suite) def test_clamp_tosa_INT(test_data): + input_tensor, min_val, max_val = test_data() + model = Clamp(min_val, max_val) + + pipeline = TosaPipelineINT[input_t]( + model, + (input_tensor,), + aten_op, + exir_op, + ) + pipeline.run() + +@common.parametrize("test_data", test_data_suite_int32) +def test_clamp_tosa_INT_int32_inputs(test_data): input_tensor, min_val, max_val = test_data() model = Clamp(min_val, max_val) @@ -79,15 +110,28 @@ def test_clamp_tosa_INT(test_data): aten_op, exir_op, ) - pipeline.change_args("run_method_and_compare_outputs", qtol=1) + pipeline.pop_stage("quantize") + pipeline.run() + +@common.parametrize("test_data", test_data_suite) +def test_clamp_tosa_INT_a16w8(test_data): + """Test clamp operation with int16 I/O quantization for TOSA INT.""" + input_tensor, min_val, max_val = test_data() + model = Clamp(min_val, max_val) + pipeline = TosaPipelineINT[input_t]( + model, + (input_tensor,), + aten_op, + exir_op, + tosa_extensions=["int16"], + ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.XfailIfNoCorstone300 def test_clamp_u55_INT(test_data): - input_tensor, min_val, max_val = test_data() model = Clamp(min_val, max_val) @@ -96,17 +140,31 @@ def test_clamp_u55_INT(test_data): (input_tensor,), aten_op, exir_op, - run_on_fvp=True, ) + pipeline.run() + - pipeline.change_args("run_method_and_compare_outputs", qtol=1) +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone300 +def test_clamp_16a8w_u55_INT(test_data): + """Test clamp operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)""" + input_tensor, min_val, max_val = test_data() + model = Clamp(min_val, max_val) + pipeline = EthosU55PipelineINT[input_t]( + model, + (input_tensor,), + aten_op, + exir_op, + per_channel_quantization=False, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.XfailIfNoCorstone320 def test_clamp_u85_INT(test_data): - input_tensor, min_val, max_val = test_data() model = Clamp(min_val, max_val) @@ -115,16 +173,31 @@ def test_clamp_u85_INT(test_data): (input_tensor,), aten_op, exir_op, - run_on_fvp=True, ) - pipeline.change_args("run_method_and_compare_outputs", qtol=1) + pipeline.run() + +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone320 +def test_clamp_16a8w_u85_INT(test_data): + """Test clamp operation with 
16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + input_tensor, min_val, max_val = test_data() + model = Clamp(min_val, max_val) + pipeline = EthosU85PipelineINT[input_t]( + model, + (input_tensor,), + aten_op, + exir_op, + per_channel_quantization=False, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_clamp_vgf_FP(test_data): +def test_clamp_vgf_no_quant(test_data): input_tensor, min_val, max_val = test_data() model = Clamp(min_val, max_val) pipeline = VgfPipeline[input_t]( @@ -132,14 +205,14 @@ def test_clamp_vgf_FP(test_data): (input_tensor,), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_clamp_vgf_INT(test_data): +def test_clamp_vgf_quant(test_data): input_tensor, min_val, max_val = test_data() model = Clamp(min_val, max_val) pipeline = VgfPipeline[input_t]( @@ -147,8 +220,255 @@ def test_clamp_vgf_INT(test_data): (input_tensor,), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, + ) + pipeline.run() + + +aten_op_tensor = "torch.ops.aten.clamp.Tensor" +exir_op_tensor = "executorch_exir_dialects_edge__ops_aten_clamp_Tensor" + +test_data_suite_tensor_FP = { + # test_name: (test_data, min, max) + "rank_1": lambda: (torch.rand(10) * 2, torch.tensor(-1.0), torch.tensor(1.0)), + "rank_2": lambda: (torch.rand(1, 35), torch.tensor(0.5), torch.tensor(0.8)), + "rank_3": lambda: ( + torch.ones(1, 10, 10), + torch.rand(1, 10, 10) * 0.5, + torch.rand(1, 10, 10) * -0.5, + ), + "rank_4": lambda: ( + torch.rand(1, 10, 10, 1) * 2, + torch.tensor(-0.1), + torch.tensor(2.0), + ), + "rank_4_no_max": lambda: ( + torch.rand(10, 20, 30, 40) - 3, + torch.rand(30, 40) - 3.3, + None, + ), + "rank_4_no_min": lambda: ( + torch.rand(10, 20, 30, 40) * 10, + None, + torch.rand(10, 20, 30, 40) * 5.0, + ), +} + +test_data_suite_tensor_INT32 = { + "int32_rank2": lambda: ( + torch.randint(-50, 50, (2, 3), dtype=torch.int32), + torch.tensor(-10), + torch.tensor(10), + ), + "int32_rank3_no_min_broadcast_1_3": lambda: ( + torch.randint(0, 100, (1, 3, 3), dtype=torch.int32) + 10, + None, + torch.tensor([[3, 5, 7]], dtype=torch.int32), # torch.Size([1, 3]) + ), + "int32_rank3_no_max_broadcast_3_1": lambda: ( + torch.randint(-100, 100, (1, 3, 3), dtype=torch.int32), + torch.tensor([[3], [5], [7]], dtype=torch.int32), # torch.Size([3, 1]) + None, + ), + "int32_rank4_large_range": lambda: ( + torch.randint(-200, 200, (1, 2, 4, 4), dtype=torch.int32), + torch.tensor((torch.iinfo(torch.int32).min)), + torch.tensor((torch.iinfo(torch.int32).max)), + ), + "int32_rank4_broadcast_1_2": lambda: ( + torch.ones(1, 2, 4, 4, dtype=torch.int32) * 100, + torch.randint(-10, 10, (4,), dtype=torch.int32), # torch.Size([4]) + torch.randint( + -10, + 10, + ( + 4, + 4, + ), + dtype=torch.int32, + ), # torch.Size([4, 4]) + ), + "int32_rank4_broadcast_3_4": lambda: ( + torch.ones(1, 2, 4, 4, dtype=torch.int32) * 100, + torch.randint( + -10, + 10, + ( + 1, + 4, + 4, + ), + dtype=torch.int32, + ), # torch.Size([1, 4, 4]) + torch.randint( + -10, + 10, + ( + 1, + 1, + 4, + 4, + ), + dtype=torch.int32, + ), # torch.Size([1, 1, 4, 4]) + ), +} + +test_data_suite_tensor_INT64 = { + "int64_rank_3": lambda: ( + torch.ones(1, 10, 10, dtype=torch.int64), + torch.tensor(-1), + torch.tensor(-1), + ), + "int64_rank_4": lambda: ( + torch.randint(-100, 100, (1, 3, 3)), + 
torch.tensor(-10), + torch.tensor(20), + ), +} + + +@common.parametrize("test_data", test_data_suite_tensor_FP) +def test_clamp_tensor_tosa_FP(test_data): + input_tensor, min_val, max_val = test_data() + model = Clamp(min_val, max_val) + + pipeline = TosaPipelineFP[input_t]( + model, + (input_tensor,), + aten_op_tensor, + exir_op_tensor, + ) + + pipeline.run() + + +@common.parametrize( + "test_data", test_data_suite_tensor_INT32 | test_data_suite_tensor_INT64 +) +def test_clamp_tensor_tosa_INT(test_data): + input_tensor, min_val, max_val = test_data() + model = Clamp(min_val, max_val) + + pipeline = TosaPipelineINT[input_t]( + model, + (input_tensor,), + aten_op_tensor, + exir_op_tensor, + ) + pipeline.run() + + +@common.parametrize( + "test_data", test_data_suite_tensor_INT32 | test_data_suite_tensor_INT64 +) +def test_clamp_tensor_tosa_INT_a16w8(test_data): + """Test clamp operation with int16 I/O quantization for TOSA INT.""" + input_tensor, min_val, max_val = test_data() + model = Clamp(min_val, max_val) + pipeline = TosaPipelineINT[input_t]( + model, + (input_tensor,), + aten_op_tensor, + exir_op_tensor, + tosa_extensions=["int16"], + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite_tensor_INT32) +@common.XfailIfNoCorstone300 +def test_clamp_tensor_u55_INT(test_data): + input_tensor, min_val, max_val = test_data() + model = Clamp(min_val, max_val) + + pipeline = EthosU55PipelineINT[input_t]( + model, + (input_tensor,), + aten_op_tensor, + exir_op_tensor, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite_tensor_INT32) +@common.XfailIfNoCorstone300 +def test_clamp_tensor_16a8w_u55_INT(test_data): + """Test clamp operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)""" + input_tensor, min_val, max_val = test_data() + model = Clamp(min_val, max_val) + pipeline = EthosU55PipelineINT[input_t]( + model, + (input_tensor,), + aten_op_tensor, + exir_op_tensor, + per_channel_quantization=False, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite_tensor_INT32) +@common.XfailIfNoCorstone320 +def test_clamp_tensor_u85_INT(test_data): + input_tensor, min_val, max_val = test_data() + model = Clamp(min_val, max_val) + + pipeline = EthosU85PipelineINT[input_t]( + model, + (input_tensor,), + aten_op_tensor, + exir_op_tensor, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite_tensor_INT32) +@common.XfailIfNoCorstone320 +def test_clamp_tensor_16a8w_u85_INT(test_data): + """Test clamp operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + input_tensor, min_val, max_val = test_data() + model = Clamp(min_val, max_val) + pipeline = EthosU85PipelineINT[input_t]( + model, + (input_tensor,), + aten_op_tensor, + exir_op_tensor, + per_channel_quantization=False, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite_tensor_FP) +@common.SkipIfNoModelConverter +def test_clamp_tensor_vgf_no_quant(test_data): + input_tensor, min_val, max_val = test_data() + model = Clamp(min_val, max_val) + pipeline = VgfPipeline[input_t]( + model, + (input_tensor,), + aten_op_tensor, + exir_op_tensor, + quantize=False, + ) + pipeline.run() + + +@common.parametrize( + "test_data", test_data_suite_tensor_INT32 | test_data_suite_tensor_INT64 +) +@common.SkipIfNoModelConverter +def test_clamp_tensor_vgf_quant(test_data): + input_tensor, 
min_val, max_val = test_data() + model = Clamp(min_val, max_val) + pipeline = VgfPipeline[input_t]( + model, + (input_tensor,), + aten_op_tensor, + exir_op_tensor, + quantize=True, ) - # TODO: MLETORCH-1136 Change args of run_method_and_compare_outputs of the vgf tests - # pipeline.change_args("run_method_and_compare_outputs", qtol=1) pipeline.run() diff --git a/backends/arm/test/ops/test_clone.py b/backends/arm/test/ops/test_clone.py index b240fb1ea07..3aec2ed1950 100644 --- a/backends/arm/test/ops/test_clone.py +++ b/backends/arm/test/ops/test_clone.py @@ -102,7 +102,6 @@ def test_clone_u55_INT(input_data): input_tensor, aten_op, exir_op, - run_on_fvp=True, ) pipeline.run() @@ -118,7 +117,6 @@ def test_clone_u85_INT(input_data): input_tensor, aten_op, exir_op, - run_on_fvp=True, ) pipeline.run() @@ -126,23 +124,27 @@ def test_clone_u85_INT(input_data): @common.parametrize("test_data", delegated_clones) @common.SkipIfNoModelConverter -def test_clone_vgf_FP(test_data): +def test_clone_vgf_no_quant(test_data): module, input_tensor = test_data() pipeline = VgfPipeline[input_t]( - module(), input_tensor, aten_op, exir_op, tosa_version="TOSA-1.0+FP" + module(), + input_tensor, + aten_op, + exir_op, + quantize=False, ) pipeline.run() @common.parametrize("test_data", delegated_clones) @common.SkipIfNoModelConverter -def test_clone_vgf_INT(test_data): +def test_clone_vgf_quant(test_data): module, input_tensor = test_data() pipeline = VgfPipeline[input_t]( module(), input_tensor, aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_cond.py b/backends/arm/test/ops/test_cond.py new file mode 100644 index 00000000000..c0a6df69a2d --- /dev/null +++ b/backends/arm/test/ops/test_cond.py @@ -0,0 +1,264 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
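+#
+# Descriptive note (added for readability): these tests exercise delegation of
+# torch.cond (higher-order control flow) to the Arm TOSA backend. The modules
+# below vary the branch signatures -- zero/one/two tensor arguments, buffers,
+# scalars, multiple outputs, nested and sequential conds -- to cover
+# partitioning and INT quantization with the control-flow ("cf") extension.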
+ +from typing import Callable, Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.backends.arm.test.tester.test_pipeline import ( + TosaPipelineFP, + TosaPipelineINT, +) + +aten_op = "torch.ops.higher_order.cond" +exir_op = "torch.ops.higher_order.cond" + +input_t1 = Tuple[torch.Tensor] +input_t2 = Tuple[torch.Tensor, torch.Tensor] + + +class CondZeroArgsOneOutput(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + def true_branch() -> torch.Tensor: + return torch.zeros(10) + + def false_branch() -> torch.Tensor: + return torch.ones(10) + + predicate = x.sum() > 0 + return torch.cond(predicate, true_branch, false_branch, []) + + +class CondOneArgOneOutput(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + def true_branch(arg: torch.Tensor) -> torch.Tensor: + return torch.sin(arg) + + def false_branch(arg: torch.Tensor) -> torch.Tensor: + return torch.cos(arg) + + predicate = x.sum() > 0 + return torch.cond(predicate, true_branch, false_branch, [x]) + + +class CondOneArgBufferOneOutput(torch.nn.Module): + def __init__(self, *args: common.Any, **kwargs: common.Any) -> None: + super().__init__(*args, **kwargs) + self.buffer = torch.rand(1, 1, 2, 2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + def true_branch(arg: torch.Tensor, buffer: torch.Tensor) -> torch.Tensor: + return torch.sin(arg) + buffer + + def false_branch(arg: torch.Tensor, buffer: torch.Tensor) -> torch.Tensor: + return torch.cos(arg) + buffer + + predicate = x.sum() > 0 + return torch.cond(predicate, true_branch, false_branch, [x, self.buffer]) + + +class CondOneArgAndScalarOneOutput(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + def true_branch(arg: torch.Tensor) -> torch.Tensor: + return arg + 1.0 + + def false_branch(arg: torch.Tensor) -> torch.Tensor: + return arg - 1.0 + + predicate = x.sum() > 0 + return torch.cond(predicate, true_branch, false_branch, [x]) + + +class CondOneArgTwoOutputs(torch.nn.Module): + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def true_branch(arg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + return arg + torch.sin(arg), arg - torch.sin(arg) + + def false_branch(arg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + return arg - arg.mean(), arg + arg.mean() + + predicate = x.flatten().sum() > 0 + return torch.cond(predicate, true_branch, false_branch, [x]) + + +class CondNestedOneArgOneOutput(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + def inner_true(arg: torch.Tensor) -> torch.Tensor: + return arg + torch.full((1,), (1.0)) + + def inner_false(arg: torch.Tensor) -> torch.Tensor: + return arg - torch.full((1,), (1.0)) + + def outer_true(arg: torch.Tensor) -> torch.Tensor: + inner_predicate = arg.mean() > 0 + return torch.cond(inner_predicate, inner_true, inner_false, [arg]) + + def outer_false(arg: torch.Tensor) -> torch.Tensor: + return arg * torch.full((1,), (1.0)) + + predicate = x.sum() > 0 + return torch.cond(predicate, outer_true, outer_false, [x]) + + +class CondMultipleOneArgOneOutput(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + def first_true(arg: torch.Tensor) -> torch.Tensor: + return arg.sigmoid() + + def first_false(arg: torch.Tensor) -> torch.Tensor: + return arg.relu() + + first_predicate = x.sum() > 0 + intermediate = torch.cond(first_predicate, first_true, first_false, [x]) + + def 
second_true(arg: torch.Tensor) -> torch.Tensor: + return arg.sin() + + def second_false(arg: torch.Tensor) -> torch.Tensor: + return arg.cos() + + second_predicate = intermediate.mean() > 0 + return torch.cond(second_predicate, second_true, second_false, [intermediate]) + + +class CondTwoArgsOneOutput(torch.nn.Module): + def forward(self, lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor: + def true_branch(arg_l: torch.Tensor, arg_r: torch.Tensor) -> torch.Tensor: + return arg_l + arg_r + + def false_branch(arg_l: torch.Tensor, arg_r: torch.Tensor) -> torch.Tensor: + return arg_l - arg_r + + predicate = (lhs - rhs).sum() > 0 + return torch.cond(predicate, true_branch, false_branch, [lhs, rhs]) + + +class CondTwoArgsTwoOutputs(torch.nn.Module): + def forward( + self, lhs: torch.Tensor, rhs: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + def true_branch( + arg_l: torch.Tensor, arg_r: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + return arg_l + arg_r, arg_l * arg_r + + def false_branch( + arg_l: torch.Tensor, arg_r: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + diff = arg_l - arg_r + return diff, arg_l + diff + + predicate = (lhs * rhs).sum() > 0 + return torch.cond(predicate, true_branch, false_branch, [lhs, rhs]) + + +def _single_input_case( + module_factory: Callable[[], torch.nn.Module] +) -> Callable[[], tuple[torch.nn.Module, input_t1]]: + def _create() -> tuple[torch.nn.Module, input_t1]: + return module_factory(), (torch.randn(1, 1, 2, 2),) + + return _create + + +def _dual_input_case( + module_factory: Callable[[], torch.nn.Module] +) -> Callable[[], tuple[torch.nn.Module, input_t2]]: + def _create() -> tuple[torch.nn.Module, input_t2]: + return module_factory(), (torch.randn(2, 3, 4, 6), torch.randn(2, 3, 4, 6)) + + return _create + + +test_cases: dict[str, Callable[[], tuple[torch.nn.Module, tuple]]] = { + "zero_args_one_output": _single_input_case(CondZeroArgsOneOutput), + "one_arg_one_output": _single_input_case(CondOneArgOneOutput), + "one_arg_const_one_output": _single_input_case(CondOneArgBufferOneOutput), + "one_arg_and_scalar_one_output": _single_input_case(CondOneArgAndScalarOneOutput), + "one_arg_two_outputs": _single_input_case(CondOneArgTwoOutputs), + "two_args_one_output": _dual_input_case(CondTwoArgsOneOutput), + "two_args_two_outputs": _dual_input_case(CondTwoArgsTwoOutputs), + "nested_one_arg_one_output": _single_input_case(CondNestedOneArgOneOutput), + "multiple_one_arg_one_output": _single_input_case(CondMultipleOneArgOneOutput), +} + + +def _make_calibration_samples( + module: torch.nn.Module, example_inputs: tuple +) -> tuple[tuple[torch.Tensor, ...], ...]: + """Return one example input that triggers the if branch, and one that triggers the else branch.""" + + if isinstance(module, CondTwoArgsOneOutput): + # Predicate is sum(lhs-rhs) > 0 + lhs, rhs = example_inputs + if_example_inputs = (lhs, rhs) + else_example_inputs = (rhs, lhs) + elif isinstance(module, CondTwoArgsTwoOutputs): + # Predicate is sum(lhs*rhs) > 0 + lhs, rhs = example_inputs + if_example_inputs = (lhs, rhs) + else_example_inputs = (lhs, -rhs) + else: + # Predicate is sum(x) > 0 + (x,) = example_inputs + if_example_inputs = (x,) + else_example_inputs = (-x,) + + return (if_example_inputs, else_example_inputs) + + +@common.parametrize( + "case", + test_cases, + xfails={ + "one_arg_and_scalar_one_output": "Scalars become get_attr nodes that are not supported.", + "nested_one_arg_one_output": "Not fully delegated.", + }, +) +def test_cond_tosa_FP(case: Callable[[], 
tuple[torch.nn.Module, tuple]]): + module, example_inputs = case() + pipeline = TosaPipelineFP[tuple]( + module, example_inputs, aten_op, tosa_extensions=["cf"] + ) + + # Make sure no cond ops are left after partitioning. + pipeline.add_stage_after( + "to_edge_transform_and_lower", + ArmTester.check_not, + pipeline.tester, + ["torch.ops.higher_order.cond"], + ) + pipeline.run() + + +@common.parametrize( + "case", + test_cases, + xfails={ + "zero_args_one_output": "Since the submodules have no input, the tracer fails finding a fake tensor mode," + " and traces the graph with real tensors, which tosa.RESCALE can't handle.", + "one_arg_and_scalar_one_output": "Incorrect quantization on the scalar.", + "nested_one_arg_one_output": "Node submodule_0 target submodule_0 references nonexistent attribute submodule_0", + }, +) +def test_cond_tosa_INT(case: Callable[[], tuple[torch.nn.Module, tuple]]): + module, example_inputs = case() + pipeline = TosaPipelineINT[tuple]( + module, example_inputs, aten_op, tosa_extensions=["cf"] + ) + calibration_samples = _make_calibration_samples(module, example_inputs) + quant_stage_pos = pipeline.find_pos("quantize") + quant_stage = pipeline._stages[quant_stage_pos].args[0] + quant_stage.calibration_samples = calibration_samples + + # Make sure no cond ops are left after partitioning. + pipeline.add_stage_after( + "to_edge_transform_and_lower", + ArmTester.check_not, + pipeline.tester, + ["torch.ops.higher_order.cond"], + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_constant_pad_nd.py b/backends/arm/test/ops/test_constant_pad_nd.py index d70249c31d1..05f14e698f7 100644 --- a/backends/arm/test/ops/test_constant_pad_nd.py +++ b/backends/arm/test/ops/test_constant_pad_nd.py @@ -77,29 +77,43 @@ def test_constant_pad_nd_tosa_INT(test_data: Tuple): pipeline.run() +@common.parametrize("test_data", test_data_suite) +def test_constant_pad_nd_tosa_INT_a16w8(test_data: Tuple): + """Test constant_pad_nd op with int16 I/O quantization for TOSA INT.""" + test_data, padding, value = test_data() + pipeline = TosaPipelineINT[input_t1]( + ConstantPadND(padding, value), + (test_data,), + aten_op, + exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_constant_pad_nd_vgf_FP(test_data: Tuple): +def test_constant_pad_nd_vgf_no_quant(test_data: Tuple): inp, padding, value = test_data() pipeline = VgfPipeline[input_t1]( ConstantPadND(padding, value), (inp,), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_constant_pad_nd_vgf_INT(test_data: Tuple): +def test_constant_pad_nd_vgf_quant(test_data: Tuple): inp, padding, value = test_data() pipeline = VgfPipeline[input_t1]( ConstantPadND(padding, value), (inp,), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_conv1d.py b/backends/arm/test/ops/test_conv1d.py index ac66bc1556b..71351676f7b 100644 --- a/backends/arm/test/ops/test_conv1d.py +++ b/backends/arm/test/ops/test_conv1d.py @@ -307,7 +307,6 @@ def test_convolution_1d_u55_INT(test_data): model.get_inputs(), aten_op, exir_op, - run_on_fvp=True, per_channel_quantization=per_channel_quantization, qtol=1, ) @@ -323,7 +322,6 @@ def test_convolution_1d_u85_INT(test_data): model.get_inputs(), aten_op, exir_op, - run_on_fvp=True, per_channel_quantization=per_channel_quantization, 
qtol=1, ) @@ -332,27 +330,27 @@ def test_convolution_1d_u85_INT(test_data): @common.parametrize("test_data", test_data_FP) @common.SkipIfNoModelConverter -def test_convolution_1d_vgf_FP(test_data): +def test_convolution_1d_vgf_no_quant(test_data): pipeline = VgfPipeline[input_t]( test_data(), test_data().get_inputs(), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_INT) @common.SkipIfNoModelConverter -def test_convolution_1d_vgf_INT(test_data): +def test_convolution_1d_vgf_quant(test_data): model, per_channel_quantization = test_data() pipeline = VgfPipeline[input_t]( model, model.get_inputs(), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", per_channel_quantization=per_channel_quantization, + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_conv2d.py b/backends/arm/test/ops/test_conv2d.py index 0300f7c2049..2b86ea6a5c4 100644 --- a/backends/arm/test/ops/test_conv2d.py +++ b/backends/arm/test/ops/test_conv2d.py @@ -117,26 +117,26 @@ def forward(self, x): return x -conv2d_2x2_3x2x40x40_nobias = Conv2d( +conv2d_2x2_3x2x14x14_nobias = Conv2d( in_channels=2, out_channels=3, kernel_size=(2, 2), stride=1, bias=False, padding=0, - width=40, - height=40, - batches=3, + width=14, + height=14, + batches=2, ) -conv2d_3x3_1x3x256x256_st1 = Conv2d( +conv2d_3x3_1x3x24x24_st1 = Conv2d( in_channels=3, out_channels=10, kernel_size=(3, 3), stride=1, padding=0, - width=256, - height=256, + width=24, + height=24, batches=1, ) @@ -151,14 +151,14 @@ def forward(self, x): batches=1, ) -conv2d_1x1_1x2x128x128_st1 = Conv2d( +conv2d_1x1_1x2x16x16_st1 = Conv2d( in_channels=2, out_channels=1, kernel_size=(1, 1), stride=1, padding=0, - width=128, - height=128, + width=16, + height=16, batches=1, ) @@ -173,25 +173,25 @@ def forward(self, x): batches=1, ) -conv2d_5x5_3x2x128x128_st1 = Conv2d( +conv2d_5x5_3x2x24x24_st1 = Conv2d( in_channels=2, out_channels=3, kernel_size=(5, 5), stride=1, padding=0, - width=128, - height=128, - batches=3, + width=24, + height=24, + batches=2, ) -conv2d_3x3_1x3x224x224_st2_pd1 = Conv2d( +conv2d_3x3_1x3x28x28_st2_pd1 = Conv2d( in_channels=3, out_channels=16, kernel_size=(3, 3), stride=2, padding=1, - width=224, - height=224, + width=28, + height=28, batches=1, ) @@ -304,8 +304,8 @@ def forward(self, x): two_conv2d_nobias = Conv2d( nbr_conv=2, - width=256, - height=256, + width=32, + height=32, in_channels=[3, 10], out_channels=[10, 15], kernel_size=[(5, 5), (5, 5)], @@ -317,8 +317,8 @@ def forward(self, x): two_conv2d = Conv2d( nbr_conv=2, - width=256, - height=256, + width=32, + height=32, in_channels=[3, 10], out_channels=[10, 15], kernel_size=[(5, 5), (5, 5)], @@ -359,10 +359,10 @@ def forward(self, x): # Shenanigan to get a nicer output when test fails. 
With unittest it looks like: # FAIL: test_convolution_2d_tosa_INT_2_3x3_1x3x12x12_st2_pd1 test_data_FP = { - "2x2_3x2x40x40_nobias": lambda: conv2d_2x2_3x2x40x40_nobias, - "3x3_1x3x256x256_st1": lambda: conv2d_3x3_1x3x256x256_st1, + "2x2_3x2x14x14_nobias": lambda: conv2d_2x2_3x2x14x14_nobias, + "3x3_1x3x24x24_st1": lambda: conv2d_3x3_1x3x24x24_st1, "3x3_1x3x12x12_st2_pd1": lambda: conv2d_3x3_1x3x12x12_st2_pd1, - "1x1_1x2x128x128_st1": lambda: conv2d_1x1_1x2x128x128_st1, + "1x1_1x2x16x16_st1": lambda: conv2d_1x1_1x2x16x16_st1, "2x2_1x1x14x13_st2_needs_adjust_pass": lambda: conv2d_2x2_1x1x14x13_st2, "5x5_1x3x14x15_st3_pd1_needs_adjust_pass": lambda: conv2d_5x5_1x3x14x15_st3_pd1, "7x7_1x3x16x16_st2_pd1_dl2_needs_adjust_pass": lambda: conv2d_7x7_1x3x16x16_st2_pd1_dl2, @@ -373,8 +373,8 @@ def forward(self, x): "3x3_1x3x8x9_st3_pd0_dl1_needs_adjust_pass": lambda: conv2d_3x3_1x3x8x9_st3_pd0_dl1, "3x4_1x3x7x7_st3_pd0_dl1_needs_adjust_pass": lambda: conv2d_3x4_1x3x7x7_st3_pd0_dl1, "4x3_1x3x7x7_st3_pd0_dl1_needs_adjust_pass": lambda: conv2d_4x3_1x3x7x7_st3_pd0_dl1, - "5x5_3x2x128x128_st1": lambda: conv2d_5x5_3x2x128x128_st1, - "3x3_1x3x224x224_st2_pd1": lambda: conv2d_3x3_1x3x224x224_st2_pd1, + "5x5_3x2x24x24_st1": lambda: conv2d_5x5_3x2x24x24_st1, + "3x3_1x3x28x28_st2_pd1": lambda: conv2d_3x3_1x3x28x28_st2_pd1, "two_conv2d_nobias": lambda: two_conv2d_nobias, "two_conv2d": lambda: two_conv2d, "groups": lambda: conv2d_groups, @@ -426,7 +426,6 @@ def test_convolution_2d_u55_INT(test_data): model.get_inputs(), aten_op, exir_op, - run_on_fvp=True, per_channel_quantization=per_channel_quantization, ) pipeline.run() @@ -441,7 +440,6 @@ def test_convolution_u85_INT(test_data): model.get_inputs(), aten_op, exir_op, - run_on_fvp=True, per_channel_quantization=per_channel_quantization, ) pipeline.run() @@ -449,29 +447,29 @@ def test_convolution_u85_INT(test_data): @common.parametrize("test_data", test_data_FP) @common.SkipIfNoModelConverter -def test_convolution_2d_vgf_FP(test_data): +def test_convolution_2d_vgf_no_quant(test_data): model = test_data() pipeline = VgfPipeline[input_t]( model, model.get_inputs(), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_INT) @common.SkipIfNoModelConverter -def test_convolution_2d_vgf_INT(test_data): +def test_convolution_2d_vgf_quant(test_data): model, per_channel_quantization = test_data() pipeline = VgfPipeline[input_t]( model, model.get_inputs(), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", per_channel_quantization=per_channel_quantization, + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_conv3d.py b/backends/arm/test/ops/test_conv3d.py index b26f75daa1a..a0b5b3e4cb2 100644 --- a/backends/arm/test/ops/test_conv3d.py +++ b/backends/arm/test/ops/test_conv3d.py @@ -8,7 +8,11 @@ import pytest import torch -from executorch.backends.arm.test import common +from executorch.backends.arm.quantizer.arm_quantizer import ( + get_symmetric_a16w8_quantization_config, + TOSAQuantizer, +) +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, EthosU85PipelineINT, @@ -17,6 +21,8 @@ TosaPipelineINT, VgfPipeline, ) +from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.xnnpack.test.tester import Quantize aten_op = "torch.ops.aten.conv3d.default" exir_op = "executorch_exir_dialects_edge__ops_aten_convolution_default" @@ -109,7 +115,11 @@ def __init__( def 
get_inputs(self): return ( torch.randn( - self.batches, self.in_channels[0], self.height, self.width, self.depth + self.batches, + self.in_channels[0], + self.depth, + self.height, + self.width, ).to(self.dtype), ) @@ -120,26 +130,108 @@ def forward(self, x): return x -conv3d_2x2_3x2x40x40_nobias = Conv3d( +class Conv3dMultiOp(torch.nn.Module): + """ + Mixed Conv3d/Conv2d pipeline used to verify spatial-rank propagation across ops. + + Topology: + conv3d -> reshape -> conv2d -> reshape/permutation -> conv2d -> reshape -> add(5D) + """ + + def __init__(self, dtype=torch.float): + super().__init__() + self.dtype = dtype + self.conv3d = torch.nn.Conv3d( + in_channels=2, + out_channels=4, + kernel_size=(3, 3, 3), + stride=1, + padding=1, + ).to(dtype) + self.conv2d_main = torch.nn.Conv2d( + in_channels=4, + out_channels=4, + kernel_size=3, + stride=1, + padding=1, + ).to(dtype) + self.conv2d_pointwise = torch.nn.Conv2d( + in_channels=4, + out_channels=4, + kernel_size=1, + stride=1, + padding=0, + ).to(dtype) + self.activation = torch.nn.ReLU() + + def get_inputs(self): + return (torch.randn(1, 2, 3, 8, 8).to(self.dtype),) + + def forward(self, x): + x3d = self.conv3d(x) + batches, channels, depth, height, width = x3d.shape + + reshaped = x3d.reshape(batches * depth, channels, height, width) + conv2d_out = self.activation(self.conv2d_main(reshaped)) + + conv2d_out_5d = ( + conv2d_out.reshape(batches, depth, channels, height, width) + .permute(0, 2, 1, 3, 4) + .contiguous() + ) + + reshaped_again = conv2d_out_5d.permute(0, 2, 1, 3, 4).reshape( + batches * depth, channels, height, width + ) + conv2d_pointwise_out = self.conv2d_pointwise(reshaped_again) + conv2d_pointwise_out_5d = ( + conv2d_pointwise_out.reshape(batches, depth, channels, height, width) + .permute(0, 2, 1, 3, 4) + .contiguous() + ) + + return conv2d_pointwise_out_5d + x3d + + +class DepthwiseConv3d(torch.nn.Module): + def __init__(self, dtype=torch.float): + super().__init__() + self.dtype = dtype + self.conv = torch.nn.Conv3d( + in_channels=2, + out_channels=4, + kernel_size=(3, 3, 3), + padding=1, + groups=2, + ).to(dtype) + + def get_inputs(self): + return (torch.randn(1, 2, 3, 8, 8).to(self.dtype),) + + def forward(self, x): + return self.conv(x) + + +conv3d_2x2_3x2x14x14_nobias = Conv3d( in_channels=2, out_channels=3, kernel_size=(2, 2, 2), stride=1, bias=False, padding=0, - width=40, - height=40, - batches=3, + width=14, + height=14, + batches=2, ) -conv3d_3x3_1x3x256x256_st1 = Conv3d( +conv3d_3x3_1x3x24x24_st1 = Conv3d( in_channels=3, out_channels=10, kernel_size=(3, 3, 3), stride=1, padding=0, - width=256, - height=256, + width=24, + height=24, batches=1, ) @@ -154,14 +246,14 @@ def forward(self, x): batches=1, ) -conv3d_1x1_1x2x128x128_st1 = Conv3d( +conv3d_1x1_1x2x16x16_st1 = Conv3d( in_channels=2, out_channels=1, kernel_size=(1, 1, 1), stride=1, padding=0, - width=128, - height=128, + width=16, + height=16, batches=1, ) @@ -176,25 +268,25 @@ def forward(self, x): batches=1, ) -conv3d_5x5_3x2x128x128_st1 = Conv3d( +conv3d_5x5_3x2x24x24_st1 = Conv3d( in_channels=2, out_channels=3, kernel_size=(5, 5, 5), stride=1, padding=0, - width=128, - height=128, - batches=3, + width=24, + height=24, + batches=2, ) -conv3d_3x3_1x3x224x224_st2_pd1 = Conv3d( +conv3d_3x3_1x3x28x28_st2_pd1 = Conv3d( in_channels=3, out_channels=16, kernel_size=(3, 3, 3), stride=2, padding=1, - width=224, - height=224, + width=28, + height=28, batches=1, ) @@ -214,8 +306,8 @@ def forward(self, x): out_channels=3, kernel_size=(7, 7, 7), stride=2, - 
padding=1, - dilation=2, + padding=3, + dilation=1, width=16, height=16, batches=1, @@ -306,10 +398,10 @@ def forward(self, x): ) test_data_FP = { - "2x2_3x2x40x40_nobias": lambda: conv3d_2x2_3x2x40x40_nobias, - "3x3_1x3x256x256_st1": lambda: conv3d_3x3_1x3x256x256_st1, + "2x2_3x2x14x14_nobias": lambda: conv3d_2x2_3x2x14x14_nobias, + "3x3_1x3x24x24_st1": lambda: conv3d_3x3_1x3x24x24_st1, "3x3_1x3x12x12_st2_pd1": lambda: conv3d_3x3_1x3x12x12_st2_pd1, - "1x1_1x2x128x128_st1": lambda: conv3d_1x1_1x2x128x128_st1, + "1x1_1x2x16x16_st1": lambda: conv3d_1x1_1x2x16x16_st1, "2x2_1x1x14x13_st2_needs_adjust_pass": lambda: conv3d_2x2_1x1x14x13_st2, "5x5_1x3x14x15_st3_pd1_needs_adjust_pass": lambda: conv3d_5x5_1x3x14x15_st3_pd1, "7x7_1x3x16x16_st2_pd1_dl2_needs_adjust_pass": lambda: conv3d_7x7_1x3x16x16_st2_pd1_dl2, @@ -320,8 +412,8 @@ def forward(self, x): "3x3_1x3x8x9_st3_pd0_dl1_needs_adjust_pass": lambda: conv3d_3x3_1x3x8x9_st3_pd0_dl1, "3x4_1x3x7x7_st3_pd0_dl1_needs_adjust_pass": lambda: conv3d_3x4_1x3x7x7_st3_pd0_dl1, "4x3_1x3x7x7_st3_pd0_dl1_needs_adjust_pass": lambda: conv3d_4x3_1x3x7x7_st3_pd0_dl1, - "5x5_3x2x128x128_st1": lambda: conv3d_5x5_3x2x128x128_st1, - "3x3_1x3x224x224_st2_pd1": lambda: conv3d_3x3_1x3x224x224_st2_pd1, + "5x5_3x2x24x24_st1": lambda: conv3d_5x5_3x2x24x24_st1, + "3x3_1x3x28x28_st2_pd1": lambda: conv3d_3x3_1x3x28x28_st2_pd1, } # Generate a new test set paired with per_channel_quant=True/False. @@ -331,11 +423,36 @@ def forward(self, x): for q in [True, False] } +test_data_INT16 = { + f"{k},16a8w,per_channel_quant={q}": (lambda v=v, q=q: (v(), q)) + for (k, v) in test_data_FP.items() + for q in [True, False] +} + + +def get_symmetric_a16w8_conv3d_quantizer(per_channel_quantization: bool = False): + tosa_version = conftest.get_option("tosa_version") + tosa_profiles = { + "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"), + } + + quantizer = TOSAQuantizer(tosa_profiles[tosa_version]) + quant_config = get_symmetric_a16w8_quantization_config( + is_per_channel=per_channel_quantization + ) + quantizer.set_global(quant_config) + quantizer.set_module_type(torch.nn.Conv3d, quant_config) + + return Quantize( + quantizer, + quant_config, + ) + + input_t = Tuple[torch.Tensor] @common.parametrize("test_data", test_data_FP) -@pytest.mark.skip # Not implemented, skip until it is. def test_convolution_3d_tosa_FP(test_data): pipeline = TosaPipelineFP[input_t]( test_data(), test_data().get_inputs(), aten_op, exir_op @@ -344,7 +461,6 @@ def test_convolution_3d_tosa_FP(test_data): @common.parametrize("test_data", test_data_INT) -@pytest.mark.skip # Not implemented, skip until it is. 
def test_convolution_3d_tosa_INT(test_data): model, per_channel_quantization = test_data() pipeline = TosaPipelineINT[input_t]( @@ -358,8 +474,63 @@ def test_convolution_3d_tosa_INT(test_data): pipeline.run() +@common.parametrize("test_data", test_data_INT16) +def test_convolution_3d_tosa_INT16(test_data): + model, per_channel_quantization = test_data() + pipeline = TosaPipelineINT[input_t]( + model, + model.get_inputs(), + aten_op, + exir_op, + per_channel_quantization=per_channel_quantization, + use_to_edge_transform_and_lower=True, + tosa_extensions=["int16"], + qtol=1, + ) + pipeline.change_args( + "quantize", + get_symmetric_a16w8_conv3d_quantizer( + per_channel_quantization=per_channel_quantization + ), + ) + pipeline.run() + + +def test_convolution_3d_tosa_FP_multi_op(): + """Ensure mixed Conv3d/Conv2d graphs keep correct spatial annotations.""" + model = Conv3dMultiOp() + pipeline = TosaPipelineFP[input_t](model, model.get_inputs(), aten_op, exir_op) + pipeline.run() + + +def test_convolution_3d_tosa_INT_multi_op(): + """Ensure mixed Conv3d/Conv2d graphs keep correct spatial annotations.""" + model = Conv3dMultiOp() + pipeline = TosaPipelineINT[input_t]( + model, + model.get_inputs(), + aten_op, + exir_op, + ) + pipeline.run() + + +def test_convolution_3d_tosa_FP_depthwise(): + """Depthwise or Grouped Conv3d should be rejected until grouped support exists.""" + model = DepthwiseConv3d() + pipeline = TosaPipelineFP[input_t]( + model, + model.get_inputs(), + aten_op, + exir_op, + run_on_tosa_ref_model=False, + ) + with pytest.raises(RuntimeError, match="CONV3D with groups != 1"): + pipeline.run() + + @common.parametrize("test_data", test_data_INT) -@pytest.mark.skip # Not implemented, skip until it is. +@pytest.mark.skip(reason="Ethos-U55 does not support CONV3D yet.") def test_convolution_3d_u55_INT(test_data): model, per_channel_quantization = test_data() pipeline = EthosU55PipelineINT[input_t]( @@ -367,14 +538,13 @@ def test_convolution_3d_u55_INT(test_data): model.get_inputs(), aten_op, exir_op, - run_on_fvp=True, per_channel_quantization=per_channel_quantization, ) pipeline.run() @common.parametrize("test_data", test_data_INT) -@pytest.mark.skip # Not implemented, skip until it is. +@pytest.mark.skip(reason="Ethos-U85 does not support CONV3D yet.") def test_convolution_3d_u85_INT(test_data): model, per_channel_quantization = test_data() pipeline = EthosU85PipelineINT[input_t]( @@ -382,37 +552,62 @@ def test_convolution_3d_u85_INT(test_data): model.get_inputs(), aten_op, exir_op, - run_on_fvp=True, per_channel_quantization=per_channel_quantization, ) pipeline.run() @common.parametrize("test_data", test_data_FP) -@pytest.mark.skip # Not implemented, skip until it is. @common.SkipIfNoModelConverter -def test_convolution_3d_vgf_FP(test_data): +def test_convolution_3d_vgf_no_quant(test_data): pipeline = VgfPipeline[input_t]( test_data(), test_data().get_inputs(), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_INT) -@pytest.mark.skip # Not implemented, skip until it is. 
@common.SkipIfNoModelConverter -def test_convolution_3d_vgf_INT(test_data): +def test_convolution_3d_vgf_quant(test_data): model, per_channel_quantization = test_data() pipeline = VgfPipeline[input_t]( model, model.get_inputs(), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, + ) + pipeline.run() + + +@common.SkipIfNoModelConverter +def test_convolution_3d_vgf_no_quant_multi_op(): + """Ensure mixed Conv3d/Conv2d graphs keep correct spatial annotations.""" + model = Conv3dMultiOp() + pipeline = VgfPipeline[input_t]( + model, + model.get_inputs(), + aten_op, + exir_op, + quantize=False, + ) + pipeline.run() + + +@common.SkipIfNoModelConverter +def test_convolution_3d_vgf_quant_multi_op(): + """Ensure mixed Conv3d/Conv2d graphs keep correct spatial annotations.""" + model = Conv3dMultiOp() + pipeline = VgfPipeline[input_t]( + model, + model.get_inputs(), + aten_op, + exir_op, + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_conv_combos.py b/backends/arm/test/ops/test_conv_combos.py index a7a031468ea..45b1b31f1d3 100644 --- a/backends/arm/test/ops/test_conv_combos.py +++ b/backends/arm/test/ops/test_conv_combos.py @@ -48,7 +48,7 @@ def __init__(self): # 1. 1x1 CONV2d + ReLU6 (Pointwise) self.pointwise_conv2d = torch.nn.Conv2d( in_channels=16, out_channels=96, kernel_size=1, stride=1, groups=1 - ) ## (1, 128, 81, 81) + ) ## Example output shape (1, 96, 33, 33) self.batch_norm2d_16 = torch.nn.BatchNorm2d(96, affine=False) self.relu6 = torch.nn.ReLU6() @@ -60,15 +60,15 @@ def __init__(self): padding=1, stride=1, groups=96, - ) ## (1, 128, H, W) + ) ## Example output shape (1, 96, H, W) # 3. Linear 1x1 Conv2d self.pointwise_conv2d_linear = torch.nn.Conv2d( in_channels=96, out_channels=16, kernel_size=1, stride=1, groups=1 - ) ## (1, 32, 81, 81) + ) ## Example output shape (1, 16, 33, 33) def get_inputs(self) -> Tuple[torch.Tensor]: - return (torch.randn(1, 16, 81, 81),) + return (torch.randn(1, 16, 33, 33),) def forward(self, x): input = x @@ -106,7 +106,7 @@ def __init__(self): self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) def get_inputs(self) -> Tuple[torch.Tensor]: - return (torch.randn(1, 3, 128, 128),) + return (torch.randn(1, 3, 48, 48),) def forward(self, x): x = self.conv2d(x) @@ -145,7 +145,7 @@ def __init__(self, affine: bool): self.relu6 = torch.nn.ReLU6() def get_inputs(self) -> Tuple[torch.Tensor]: - return (torch.randn(1, 3, 256, 256),) + return (torch.randn(1, 3, 64, 64),) def forward(self, x): x = self.conv2d(x) @@ -161,11 +161,11 @@ class ComboConvRelu6(torch.nn.Module): ] test_data_FP = { - "combo_conv_relu_2_x_4d": lambda: (2 * torch.randn(1, 3, 256, 256),), - "combo_conv_relu_0_5_x_4d": lambda: (0.5 * torch.randn(1, 3, 256, 256),), - "combo_conv_relu_4d": lambda: (torch.randn(1, 3, 256, 256),), - "combo_conv_relu_neg_0_5_x_4d": lambda: (-0.5 * torch.randn(1, 3, 256, 256),), - "combo_conv_relu_neg_2_x_4d": lambda: (-2 * torch.randn(1, 3, 256, 256),), + "combo_conv_relu_2_x_4d": lambda: (2 * torch.randn(1, 3, 64, 64),), + "combo_conv_relu_0_5_x_4d": lambda: (0.5 * torch.randn(1, 3, 64, 64),), + "combo_conv_relu_4d": lambda: (torch.randn(1, 3, 64, 64),), + "combo_conv_relu_neg_0_5_x_4d": lambda: (-0.5 * torch.randn(1, 3, 64, 64),), + "combo_conv_relu_neg_2_x_4d": lambda: (-2 * torch.randn(1, 3, 64, 64),), } # Generate a new test set paired with per_channel_quant=True/False. 
@@ -196,10 +196,10 @@ class ComboConvAvgPool2d(torch.nn.Module): ] test_data_FP = { - "combo_conv_avgpool_20_x_4d": lambda: (20 * torch.randn(1, 3, 64, 32),), - "combo_conv_avgpool_4d": lambda: (torch.randn(1, 3, 100, 200),), - "combo_conv_avgpool_5_x_4d_randn": lambda: (5 * torch.randn(1, 3, 256, 256),), - "combo_conv_avgpool_2_x_4d": lambda: (torch.rand(1, 3, 512, 128),), + "combo_conv_avgpool_20_x_4d": lambda: (20 * torch.randn(1, 3, 48, 24),), + "combo_conv_avgpool_4d": lambda: (torch.randn(1, 3, 60, 120),), + "combo_conv_avgpool_5_x_4d_randn": lambda: (5 * torch.randn(1, 3, 64, 64),), + "combo_conv_avgpool_2_x_4d": lambda: (torch.rand(1, 3, 96, 32),), } # Generate a new test set paired with per_channel_quant=True/False. @@ -258,7 +258,6 @@ def test_convolution_2d_u55_INT_meandim(): model.get_inputs(), aten_ops=[], exir_ops=ComboConv2dMeandim.edge_op_list, - run_on_fvp=True, ) pipeline.run() @@ -271,33 +270,32 @@ def test_convolution_2d_u85_INT_meandim(): model.get_inputs(), aten_ops=[], exir_ops=ComboConv2dMeandim.edge_op_list, - run_on_fvp=True, ) pipeline.run() @common.SkipIfNoModelConverter -def test_convolution_2d_vgf_FP_meandim(): +def test_convolution_2d_meandim_vgf_no_quant(): model = ComboConv2dMeandim() pipeline = VgfPipeline[input_t1]( model, model.get_inputs(), aten_op=[], exir_op=ComboConv2dMeandim.edge_op_list, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.SkipIfNoModelConverter -def test_convolution_2d_vgf_INT_meandim(): +def test_convolution_2d_meandim_vgf_quant(): model = ComboConv2dMeandim() pipeline = VgfPipeline[input_t1]( model, model.get_inputs(), aten_op=[], exir_op=ComboConv2dMeandim.edge_op_list, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @@ -346,7 +344,6 @@ def test_convolution_2d_u55_INT_batchnorm_relu6(test_data): model.get_inputs(), aten_ops=[], exir_ops=[], - run_on_fvp=True, per_channel_quantization=per_channel_quantization, ) pipeline.run() @@ -362,7 +359,6 @@ def test_convolution_2d_u85_INT_batchnorm_relu6(test_data): model.get_inputs(), aten_ops=[], exir_ops=[], - run_on_fvp=True, per_channel_quantization=per_channel_quantization, ) pipeline.run() @@ -370,7 +366,7 @@ def test_convolution_2d_u85_INT_batchnorm_relu6(test_data): @common.parametrize("test_data", ComboConvBatchnormRelu6.test_data_FP) @common.SkipIfNoModelConverter -def test_convolution_2d_vgf_FP_batchnorm_relu6(test_data): +def test_convolution_2d_batchnorm_relu6_vgf_no_quant(test_data): affine = test_data model = ComboConvBatchnormRelu6(affine) pipeline = VgfPipeline[input_t1]( @@ -378,14 +374,14 @@ def test_convolution_2d_vgf_FP_batchnorm_relu6(test_data): model.get_inputs(), aten_op=[], exir_op=ComboConvBatchnormRelu6.edge_op_list, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", ComboConvBatchnormRelu6.test_data_INT) @common.SkipIfNoModelConverter -def test_convolution_2d_vgf_INT_batchnorm_relu6(test_data): +def test_convolution_2d_batchnorm_relu6_vgf_quant(test_data): affine, per_channel_quantization = test_data model = ComboConvBatchnormRelu6(affine) pipeline = VgfPipeline[input_t1]( @@ -393,8 +389,8 @@ def test_convolution_2d_vgf_INT_batchnorm_relu6(test_data): model.get_inputs(), aten_op=[], exir_op=ComboConvBatchnormRelu6.edge_op_list, - tosa_version="TOSA-1.0+INT", per_channel_quantization=per_channel_quantization, + quantize=True, ) pipeline.run() @@ -441,7 +437,6 @@ def test_convolution_2d_u55_INT_relu6(test_data): input, aten_ops=[], exir_ops=ComboConvRelu6.edge_op_list, - 
run_on_fvp=True, per_channel_quantization=per_channel_quantization, ) pipeline.run() @@ -457,7 +452,6 @@ def test_convolution_2d_u85_INT_relu6(test_data): input, aten_ops=[], exir_ops=ComboConvRelu6.edge_op_list, - run_on_fvp=True, per_channel_quantization=per_channel_quantization, ) pipeline.run() @@ -465,21 +459,21 @@ def test_convolution_2d_u85_INT_relu6(test_data): @common.parametrize("test_data", ComboConvRelu6.test_data_FP) @common.SkipIfNoModelConverter -def test_convolution_2d_vgf_FP_relu6(test_data): +def test_convolution_2d_relu6_vgf_no_quant(test_data): model = ComboConvRelu6() pipeline = VgfPipeline[input_t1]( model, test_data(), aten_op=[], exir_op=ComboConvRelu6.edge_op_list, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", ComboConvRelu6.test_data_INT) @common.SkipIfNoModelConverter -def test_convolution_2d_vgf_INT_relu6(test_data): +def test_convolution_2d_relu6_vgf_quant(test_data): input, per_channel_quantization = test_data() model = ComboConvRelu6() pipeline = VgfPipeline[input_t1]( @@ -487,8 +481,8 @@ def test_convolution_2d_vgf_INT_relu6(test_data): input, aten_op=[], exir_op=ComboConvRelu6.edge_op_list, - tosa_version="TOSA-1.0+INT", per_channel_quantization=per_channel_quantization, + quantize=True, ) pipeline.run() @@ -533,7 +527,6 @@ def test_convolution_2d_u55_INT_block_bottleneck(test_data): model.get_inputs(), aten_ops=[], exir_ops=[], - run_on_fvp=True, per_channel_quantization=per_channel_quantization, ) pipeline.run() @@ -549,28 +542,27 @@ def test_convolution_2d_u85_INT_block_bottleneck(test_data): model.get_inputs(), aten_ops=[], exir_ops=[], - run_on_fvp=True, per_channel_quantization=per_channel_quantization, ) pipeline.run() @common.SkipIfNoModelConverter -def test_convolution_2d_vgf_FP_block_bottleneck(): +def test_convolution_2d_block_bottleneck_vgf_no_quant(): model = ComboBlockBottleneckResidual() pipeline = VgfPipeline[input_t1]( model, model.get_inputs(), aten_op=[], exir_op=ComboBlockBottleneckResidual.edge_op_list, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", ComboBlockBottleneckResidual.test_data_INT) @common.SkipIfNoModelConverter -def test_convolution_2d_vgf_INT_block_bottleneck(test_data): +def test_convolution_2d_block_bottleneck_vgf_quant(test_data): per_channel_quantization = test_data model = ComboBlockBottleneckResidual() pipeline = VgfPipeline[input_t1]( @@ -578,11 +570,9 @@ def test_convolution_2d_vgf_INT_block_bottleneck(test_data): model.get_inputs(), aten_op=[], exir_op=ComboBlockBottleneckResidual.edge_op_list, - tosa_version="TOSA-1.0+INT", per_channel_quantization=per_channel_quantization, + quantize=True, ) - # TODO: MLETORCH-1136 Change args of run_method_and_compare_outputs of the vgf tests - # pipeline.change_args("run_method_and_compare_outputs", model.get_inputs(), qtol=1) pipeline.run() @@ -628,7 +618,6 @@ def test_convolution_2d_u55_INT_avgpool2d(test_data): input, aten_ops=[], exir_ops=[], - run_on_fvp=True, per_channel_quantization=per_channel_quantization, ) pipeline.run() @@ -644,7 +633,6 @@ def test_convolution_2d_u85_INT_avgpool2d(test_data): input, aten_ops=[], exir_ops=[], - run_on_fvp=True, per_channel_quantization=per_channel_quantization, ) pipeline.run() @@ -652,21 +640,21 @@ def test_convolution_2d_u85_INT_avgpool2d(test_data): @common.parametrize("test_data", ComboConvAvgPool2d.test_data_FP) @common.SkipIfNoModelConverter -def test_convolution_2d_vgf_FP_avgpool2d(test_data): +def 
test_convolution_2d_avgpool2d_vgf_no_quant(test_data): model = ComboConvAvgPool2d() pipeline = VgfPipeline[input_t1]( model, test_data(), aten_op=[], exir_op=ComboConvAvgPool2d.edge_op_list, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", ComboConvAvgPool2d.test_data_INT) @common.SkipIfNoModelConverter -def test_convolution_2d_vgf_INT_avgpool2d(test_data): +def test_convolution_2d_avgpool2d_vgf_quant(test_data): input, per_channel_quantization = test_data() model = ComboConvAvgPool2d() pipeline = VgfPipeline[input_t1]( @@ -674,7 +662,7 @@ def test_convolution_2d_vgf_INT_avgpool2d(test_data): input, aten_op=[], exir_op=ComboConvAvgPool2d.edge_op_list, - tosa_version="TOSA-1.0+INT", per_channel_quantization=per_channel_quantization, + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_conv_constant_pad_nd.py b/backends/arm/test/ops/test_conv_constant_pad_nd.py index 636c18ef753..aecce1e3a3e 100644 --- a/backends/arm/test/ops/test_conv_constant_pad_nd.py +++ b/backends/arm/test/ops/test_conv_constant_pad_nd.py @@ -119,27 +119,27 @@ def test_constant_pad_nd_tosa_INT(test_data: Tuple): @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_constant_pad_nd_vgf_FP(test_data: Tuple): +def test_constant_pad_nd_vgf_no_quant(test_data: Tuple): test_data, padding, value = test_data pipeline = VgfPipeline[input_t1]( ConstantPadND(padding, value), (test_data,), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_constant_pad_nd_vgf_INT(test_data: Tuple): +def test_constant_pad_nd_vgf_quant(test_data: Tuple): test_data, padding, value = test_data pipeline = VgfPipeline[input_t1]( ConstantPadND(padding, value), (test_data,), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_copy.py b/backends/arm/test/ops/test_copy.py new file mode 100644 index 00000000000..556e952a2ba --- /dev/null +++ b/backends/arm/test/ops/test_copy.py @@ -0,0 +1,171 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
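+#
+# Descriptive note (added for readability): these tests cover lowering of
+# aten.copy_ in the Arm backend. The modules use the copy result as the graph
+# output, as either operand of another op, and after or in parallel with other
+# ops, and are run through the TOSA FP/INT, Ethos-U55/U85 and VGF pipelines.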
+ +from typing import Tuple + +import torch + +from executorch.backends.arm.test import common + +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineINT, + EthosU85PipelineINT, + TosaPipelineFP, + TosaPipelineINT, + VgfPipeline, +) + + +class CopyOutput(torch.nn.Module): + def forward(self, x): + y = torch.zeros(x.shape) + return y.copy_(x / x) + x + + +class CopyFirstArg(torch.nn.Module): + def forward(self, x): + y = torch.zeros(x.shape) + return y.copy_(x) + x + + +class CopySecondArg(torch.nn.Module): + def forward(self, x): + y = torch.zeros(x.shape) + return x * y.copy_(x) + + +class CopyBothArgs(torch.nn.Module): + def forward(self, x): + y = torch.zeros(x.shape) + return y.copy_(x) + y.copy_(x) + + +class CopyAfterOtherOp(torch.nn.Module): + def forward(self, x): + y = torch.zeros(x.shape) + x = x * 2 + return y.copy_(x) + x + + +class CopyParallelToOtherOp(torch.nn.Module): + def forward(self, x): + y = torch.zeros(x.shape) + return x * 2 + y.copy_(x) + + +test_suite = { + "copy_output": lambda: ( + CopyOutput, + (torch.rand(1, 2, 3, 4, dtype=torch.float32),), + ), + "copy_first_arg": lambda: ( + CopyFirstArg, + (torch.rand(1, 2, 3, 4, dtype=torch.float32),), + ), + "copy_second_arg": lambda: ( + CopySecondArg, + (torch.rand(1, 2, 3, 4, dtype=torch.float32),), + ), + "copy_both_args": lambda: ( + CopyBothArgs, + (torch.rand(1, 2, 3, 4, dtype=torch.float32),), + ), + "copy_after_other_op": lambda: ( + CopyAfterOtherOp, + (torch.rand(1, 2, 3, 4, dtype=torch.float32),), + ), + "copy_parallel_to_other_op": lambda: ( + CopyParallelToOtherOp, + (torch.rand(1, 2, 3, 4, dtype=torch.float32),), + ), +} + + +aten_op = "torch.ops.aten.copy_.default" +exir_op = "executorch_exir_dialects_edge__ops_aten_copy_default" + +input_t = Tuple[torch.Tensor] + + +@common.parametrize("input_data", test_suite) +def test_copy_tosa_FP(input_data): + module, input_tensor = input_data() + pipeline = TosaPipelineFP[input_t]( + module(), + input_tensor, + aten_op=aten_op, + exir_op=exir_op, + ) + pipeline.run() + + +@common.parametrize("input_data", test_suite) +def test_copy_tosa_INT(input_data): + module, input_tensor = input_data() + + pipeline = TosaPipelineINT[input_t]( + module(), + input_tensor, + aten_op, + exir_op, + ) + pipeline.run() + + +@common.parametrize("input_data", test_suite) +@common.XfailIfNoCorstone300 +def test_copy_u55_INT(input_data): + module, input_tensor = input_data() + + pipeline = EthosU55PipelineINT[input_t]( + module(), + input_tensor, + aten_op, + exir_op, + ) + pipeline.run() + + +@common.parametrize("input_data", test_suite) +@common.XfailIfNoCorstone320 +def test_copy_u85_INT(input_data): + module, input_tensor = input_data() + + pipeline = EthosU85PipelineINT[input_t]( + module(), + input_tensor, + aten_op, + exir_op, + ) + + pipeline.run() + + +@common.parametrize("test_data", test_suite) +@common.SkipIfNoModelConverter +def test_copy_vgf_no_quant(test_data): + module, input_tensor = test_data() + pipeline = VgfPipeline[input_t]( + module(), + input_tensor, + aten_op=aten_op, + exir_op=exir_op, + quantize=False, + ) + pipeline.run() + + +@common.parametrize("test_data", test_suite) +@common.SkipIfNoModelConverter +def test_copy_vgf_quant(test_data): + module, input_tensor = test_data() + pipeline = VgfPipeline[input_t]( + module(), + input_tensor, + aten_op, + exir_op, + quantize=True, + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_cos.py b/backends/arm/test/ops/test_cos.py index acb950f2a2e..42686115189 100644 --- 
a/backends/arm/test/ops/test_cos.py +++ b/backends/arm/test/ops/test_cos.py @@ -66,50 +66,50 @@ def test_cos_tosa_INT(test_data: Tuple): @common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone300 def test_cos_u55_INT(test_data: Tuple): pipeline = EthosU55PipelineINT[input_t1]( Cos(), (test_data,), aten_op, exir_ops=[], - run_on_fvp=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone320 def test_cos_u85_INT(test_data: Tuple): pipeline = EthosU85PipelineINT[input_t1]( Cos(), (test_data,), aten_op, exir_ops=[], - run_on_fvp=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_cos_vgf_FP(test_data: Tuple): +def test_cos_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( Cos(), (test_data,), aten_op, exir_op=[], - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_cos_vgf_INT(test_data: Tuple): +def test_cos_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( Cos(), (test_data,), aten_op, exir_op=[], - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_cosh.py b/backends/arm/test/ops/test_cosh.py index 60920d03f94..a07b3cea2c6 100644 --- a/backends/arm/test/ops/test_cosh.py +++ b/backends/arm/test/ops/test_cosh.py @@ -76,9 +76,6 @@ def test_cosh_u55_INT(test_data: Tuple): @common.parametrize( "test_data", test_data_suite, - xfails={ - "ones_4D": "MLBEDSW-11046 - Incorrect output for TABLE followed by RESHAPE" - }, strict=False, ) def test_cosh_u85_INT(test_data: Tuple): @@ -90,25 +87,25 @@ def test_cosh_u85_INT(test_data: Tuple): @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_cosh_vgf_FP(test_data: Tuple): +def test_cosh_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( Cosh(), (test_data,), [], [], - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_cosh_vgf_INT(test_data: Tuple): +def test_cosh_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( Cosh(), (test_data,), [], [], - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_cumsum.py b/backends/arm/test/ops/test_cumsum.py index ce175fb37c0..09afd572fb8 100644 --- a/backends/arm/test/ops/test_cumsum.py +++ b/backends/arm/test/ops/test_cumsum.py @@ -68,28 +68,28 @@ def test_cumsum_tosa_INT(test_data: input_t1): @common.parametrize("test_data", CumsumModule.test_parameters) @common.SkipIfNoModelConverter -def test_cumsum_vgf_FP(test_data: input_t1): +def test_cumsum_vgf_no_quant(test_data: input_t1): module = CumsumModule() args = test_data() pipeline = VgfPipeline[input_t1]( module, args, aten_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", CumsumModule.test_parameters) @common.SkipIfNoModelConverter -def test_cumsum_vgf_INT(test_data: input_t1): +def test_cumsum_vgf_quant(test_data: input_t1): module = CumsumModule() args = test_data() pipeline = VgfPipeline[input_t1]( module, args, aten_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_depthwise_conv.py b/backends/arm/test/ops/test_depthwise_conv.py index 0f8b34d3d47..0c92896a5f3 100644 --- a/backends/arm/test/ops/test_depthwise_conv.py +++ 
b/backends/arm/test/ops/test_depthwise_conv.py @@ -68,68 +68,68 @@ batches=1, ) -dw_conv1d_3_1x3x256_gp3_st1 = Conv1d( +dw_conv1d_3_1x3x32_gp3_st1 = Conv1d( in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=3, padding=0, - length=256, + length=32, batches=1, ) -dw_conv2d_3x3_1x3x256x256_gp3_st1 = Conv2d( +dw_conv2d_3x3_1x3x24x24_gp3_st1 = Conv2d( in_channels=3, out_channels=3, kernel_size=(3, 3), stride=(1, 1), groups=3, padding=0, - width=256, - height=256, + width=24, + height=24, batches=1, ) -dw_conv2d_3x3_1x4x256x256_gp4_st1 = Conv2d( +dw_conv2d_3x3_1x4x24x24_gp4_st1 = Conv2d( in_channels=4, out_channels=8, kernel_size=(3, 3), stride=(1, 1), groups=4, padding=0, - width=256, - height=256, + width=24, + height=24, batches=1, ) -dw_conv2d_3x3_2x8x198x198_gp8_st3 = Conv2d( +dw_conv2d_3x3_2x8x27x27_gp8_st3 = Conv2d( in_channels=8, out_channels=16, kernel_size=(3, 3), stride=3, groups=8, padding=0, - width=198, - height=198, + width=27, + height=27, batches=2, ) -dw_conv2d_3x3_1x4x256x256_gp4_nobias = Conv2d( +dw_conv2d_3x3_1x4x24x24_gp4_nobias = Conv2d( in_channels=4, out_channels=8, kernel_size=(3, 3), stride=1, groups=4, bias=False, - width=256, - height=256, + width=24, + height=24, batches=1, ) two_dw_conv1d = Conv1d( nbr_conv=2, - length=64, + length=16, in_channels=[4, 8], out_channels=[8, 24], kernel_size=[3, 3], @@ -142,8 +142,8 @@ two_dw_conv2d = Conv2d( nbr_conv=2, - width=64, - height=64, + width=24, + height=24, in_channels=[4, 8], out_channels=[8, 24], kernel_size=[(3, 3), (3, 3)], @@ -157,10 +157,10 @@ # Shenanigan to get a nicer output when test fails. test_data_conv2d_FP = { "2x2_1x6x4x4_gp6_st1": lambda: dw_conv2d_2x2_1x6x4x4_gp6_st1, - "3x3_1x3x256x256_gp3_st1": lambda: dw_conv2d_3x3_1x3x256x256_gp3_st1, - "3x3_1x4x256x256_gp4_nobias": lambda: dw_conv2d_3x3_1x4x256x256_gp4_nobias, - "3x3_1x4x256x256_gp4_st1": lambda: dw_conv2d_3x3_1x4x256x256_gp4_st1, - "3x3_2x8x198x198_gp8_st3": lambda: dw_conv2d_3x3_2x8x198x198_gp8_st3, + "3x3_1x3x24x24_gp3_st1": lambda: dw_conv2d_3x3_1x3x24x24_gp3_st1, + "3x3_1x4x24x24_gp4_nobias": lambda: dw_conv2d_3x3_1x4x24x24_gp4_nobias, + "3x3_1x4x24x24_gp4_st1": lambda: dw_conv2d_3x3_1x4x24x24_gp4_st1, + "3x3_2x8x27x27_gp8_st3": lambda: dw_conv2d_3x3_2x8x27x27_gp8_st3, "two_dw_conv2d": lambda: two_dw_conv2d, } @@ -176,9 +176,9 @@ f"{k},per_channel_quant={q}": (lambda v=v, q=q: (v(), q)) for (k, v) in { "2x2_1x6x4x4_gp6_st1": lambda: dw_conv2d_2x2_1x6x4x4_gp6_st1, - "3x3_1x3x256x256_gp3_st1": lambda: dw_conv2d_3x3_1x3x256x256_gp3_st1, - "3x3_1x4x256x256_gp4_st1": lambda: dw_conv2d_3x3_1x4x256x256_gp4_st1, - "3x3_1x4x256x256_gp4_nobias": lambda: dw_conv2d_3x3_1x4x256x256_gp4_nobias, + "3x3_1x3x24x24_gp3_st1": lambda: dw_conv2d_3x3_1x3x24x24_gp3_st1, + "3x3_1x4x24x24_gp4_st1": lambda: dw_conv2d_3x3_1x4x24x24_gp4_st1, + "3x3_1x4x24x24_gp4_nobias": lambda: dw_conv2d_3x3_1x4x24x24_gp4_nobias, }.items() for q in [True, False] } @@ -186,7 +186,7 @@ test_data_conv1d_FP = { "2_1x6x4_gp6_st1": lambda: dw_conv1d_2_1x6x4_gp6_st1, "two_dw_conv1d": lambda: two_dw_conv1d, - "3_1x3x256_gp3_st1": lambda: dw_conv1d_3_1x3x256_gp3_st1, + "3_1x3x32_gp3_st1": lambda: dw_conv1d_3_1x3x32_gp3_st1, "3_1x3x14_gp3_st1": lambda: dw_conv1d_3_1x3x14_gp3_st1, } @@ -225,28 +225,28 @@ def test_convolution_2d_tosa_INT_depthwise(test_data): @common.parametrize("test_data", test_data_conv1d_FP | test_data_conv2d_FP) @common.SkipIfNoModelConverter -def test_convolution_2d_vgf_FP_depthwise(test_data: torch.nn.Module): +def 
test_convolution_2d_depthwise_vgf_no_quant(test_data: torch.nn.Module): model = test_data() pipeline = VgfPipeline[input_t]( model, model.get_inputs(), aten_op=[], exir_op=exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_conv1d_INT | test_data_conv2d_INT) @common.SkipIfNoModelConverter -def test_convolution_2d_vgf_INT_depthwise(test_data): +def test_convolution_2d_depthwise_vgf_quant(test_data): model, per_channel_quantization = test_data() pipeline = VgfPipeline[input_t]( model, model.get_inputs(), aten_op=[], exir_op=exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @@ -260,7 +260,6 @@ def test_convolution_2d_u55_INT_depthwise(test_data): model.get_inputs(), aten_ops=[], exir_ops=exir_op, - run_on_fvp=True, per_channel_quantization=per_channel_quantization, ) pipeline.run() @@ -275,7 +274,6 @@ def test_convolution_1d_u55_INT_depthwise(test_data): model.get_inputs(), aten_ops=[], exir_ops=exir_op, - run_on_fvp=True, per_channel_quantization=per_channel_quantization, ) pipeline.run() @@ -290,7 +288,6 @@ def test_convolution_2d_u85_INT_depthwise(test_data): model.get_inputs(), aten_ops=[], exir_ops=exir_op, - run_on_fvp=True, per_channel_quantization=per_channel_quantization, ) pipeline.run() @@ -305,7 +302,6 @@ def test_convolution_1d_u85_INT_depthwise(test_data): model.get_inputs(), aten_ops=[], exir_ops=exir_op, - run_on_fvp=True, per_channel_quantization=per_channel_quantization, ) pipeline.run() diff --git a/backends/arm/test/ops/test_div.py b/backends/arm/test/ops/test_div.py index 5bacac1c962..9e932b155b0 100644 --- a/backends/arm/test/ops/test_div.py +++ b/backends/arm/test/ops/test_div.py @@ -109,7 +109,6 @@ def test_div_tensor_u55_INT(test_data: Tuple): test_data(), aten_ops=[], exir_ops=[], - run_on_fvp=True, ) pipeline.run() @@ -122,28 +121,31 @@ def test_div_tensor_u85_INT(test_data: Tuple): test_data(), aten_ops=[], exir_ops=[], - run_on_fvp=True, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_div_tensor_vgf_FP(test_data: Tuple): +def test_div_tensor_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( - Div(), test_data(), aten_op, exir_op, tosa_version="TOSA-1.0+FP" + Div(), + test_data(), + aten_op, + exir_op, + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_div_tensor_vgf_INT(test_data: Tuple): +def test_div_tensor_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( Div(), test_data(), aten_op=[], exir_op=[], - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_div_tensor_mode.py b/backends/arm/test/ops/test_div_tensor_mode.py index 909b83bd97f..866805ac31a 100644 --- a/backends/arm/test/ops/test_div_tensor_mode.py +++ b/backends/arm/test/ops/test_div_tensor_mode.py @@ -96,7 +96,6 @@ def test_div_tensor_mode_u55_INT(data): aten_ops=model.aten_ops_int, exir_ops=[], use_to_edge_transform_and_lower=True, - run_on_fvp=True, ) pipeline.run() @@ -113,14 +112,13 @@ def test_div_tensor_mode_u85_INT(data): aten_ops=model.aten_ops_int, exir_ops=[], use_to_edge_transform_and_lower=True, - run_on_fvp=True, ) pipeline.run() @common.SkipIfNoModelConverter @common.parametrize("data", test_data) -def test_div_tensor_mode_vgf_INT(data): +def test_div_tensor_mode_vgf_quant(data): mode, inputs = data() model = DivTensorModeFloat(mode) @@ -129,8 +127,8 @@ def 
test_div_tensor_mode_vgf_INT(data): inputs, aten_op=model.aten_ops_int, exir_op=[], - tosa_version="TOSA-1.0+INT", use_to_edge_transform_and_lower=True, + quantize=True, ) pipeline.pop_stage("check_count.exir") pipeline.run() @@ -138,7 +136,7 @@ def test_div_tensor_mode_vgf_INT(data): @common.SkipIfNoModelConverter @common.parametrize("data", test_data) -def test_div_tensor_mode_vgf_FP(data): +def test_div_tensor_mode_vgf_no_quant(data): mode, inputs = data() model = DivTensorModeFloat(mode) @@ -147,7 +145,7 @@ def test_div_tensor_mode_vgf_FP(data): inputs, aten_op=model.aten_ops, exir_op=[], - tosa_version="TOSA-1.0+FP", use_to_edge_transform_and_lower=True, + quantize=False, ) pipeline.run() diff --git a/backends/arm/test/ops/test_elu.py b/backends/arm/test/ops/test_elu.py index 884f54c0202..c748f8385dc 100644 --- a/backends/arm/test/ops/test_elu.py +++ b/backends/arm/test/ops/test_elu.py @@ -107,27 +107,27 @@ def test_elu_u85_INT(test_module: input_t1): @common.SkipIfNoModelConverter @common.parametrize("test_module", test_data_suite) -def test_elu_vgf_FP(test_module: input_t1): +def test_elu_vgf_no_quant(test_module: input_t1): alpha, test_data = test_module() pipeline = VgfPipeline[input_t1]( Elu(alpha), (test_data,), aten_op=Elu.aten_op, exir_op=Elu.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.SkipIfNoModelConverter @common.parametrize("test_module", test_data_suite) -def test_elu_vgf_INT(test_module: input_t1): +def test_elu_vgf_quant(test_module: input_t1): alpha, test_data = test_module() pipeline = VgfPipeline[input_t1]( Elu(alpha), (test_data,), aten_op=Elu.aten_op, exir_op=Elu.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_embedding.py b/backends/arm/test/ops/test_embedding.py index 901fbbc0916..b70c2f6545a 100644 --- a/backends/arm/test/ops/test_embedding.py +++ b/backends/arm/test/ops/test_embedding.py @@ -27,10 +27,17 @@ def forward(self, weights: torch.Tensor, indices: torch.Tensor): return torch.embedding(weights, indices) -input_params = Tuple[torch.Tensor, torch.Tensor, torch.dtype] +class ExpandEmbedding(Embedding): + example_inputs = (torch.randn(10, 3), torch.tensor([[1, 2, 3]], dtype=torch.int32)) + + def forward(self, weights: torch.Tensor, indices: torch.Tensor): + return torch.embedding(weights, indices.expand(2, 3)) + + +input_params = Tuple[torch.Tensor, torch.Tensor] -test_input: dict[input_params] = { +test_input: dict[str, input_params] = { "test_1": ( torch.randn(10, 3), torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int32), @@ -89,34 +96,49 @@ def test_embedding_tosa_INT(test_input: input_params): pipeline.run() +def test_expand_embedding_tosa_INT(): + op = ExpandEmbedding() + pipeline = TosaPipelineINT( + op, + ExpandEmbedding.example_inputs, + ExpandEmbedding.aten_op, + ExpandEmbedding.exir_op, + use_to_edge_transform_and_lower=True, + ) + pipeline.pop_stage("check.aten") + pipeline.pop_stage("check_count.exir") + + pipeline.run() + + @pytest.mark.skip("reason=MLETORCH-1274 Improve data type checks during partitioning") @common.parametrize("test_input", test_input) @common.SkipIfNoModelConverter -def test_embedding_vgf_FP(test_input: input_params): +def test_embedding_vgf_no_quant(test_input: input_params): op = Embedding() pipeline = VgfPipeline[input_params]( op, test_input, op.aten_op, op.exir_op, - tosa_version="TOSA-1.0+FP", use_to_edge_transform_and_lower=True, transform_passes=[InsertInt32CastsAfterInt64PlaceholdersPass()], + quantize=False, ) 
pipeline.run() @common.parametrize("test_input", test_input) @common.SkipIfNoModelConverter -def test_embedding_vgf_INT(test_input: input_params): +def test_embedding_vgf_quant(test_input: input_params): op = Embedding() pipeline = VgfPipeline[input_params]( op, test_input, op.aten_op, op.exir_op, - tosa_version="TOSA-1.0+INT", use_to_edge_transform_and_lower=True, + quantize=True, ) pipeline.pop_stage("check.aten") pipeline.pop_stage("check_count.exir") diff --git a/backends/arm/test/ops/test_eq.py b/backends/arm/test/ops/test_eq.py index b840869ba48..a726a04fd6e 100644 --- a/backends/arm/test/ops/test_eq.py +++ b/backends/arm/test/ops/test_eq.py @@ -121,6 +121,30 @@ def test_eq_scalar_tosa_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_data_tensor) +def test_eq_tensor_16a8w_tosa_INT(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + Equal.aten_op_Tensor, + Equal.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +def test_eq_scalar_tosa_INT_16a8w(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + Equal.aten_op_Tensor, + Equal.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + @common.parametrize("test_module", test_data_tensor) @common.XfailIfNoCorstone300 def test_eq_scalar_u55_INT_tensor(test_module): @@ -150,14 +174,7 @@ def test_eq_scalar_u55_INT(test_module): pipeline.run() -@common.parametrize( - "test_module", - test_data_tensor, - xfails={ - "eq_tensor_rank4_randn": "MLETORCH-847: Boolean eq result unstable on U85", - }, - strict=False, -) +@common.parametrize("test_module", test_data_tensor) @common.XfailIfNoCorstone320 def test_eq_scalar_u85_INT_tensor(test_module): pipeline = EthosU85PipelineINT[input_t]( @@ -165,19 +182,11 @@ def test_eq_scalar_u85_INT_tensor(test_module): test_module().get_inputs(), Equal.aten_op_Tensor, Equal.exir_op, - run_on_fvp=True, ) pipeline.run() -@common.parametrize( - "test_module", - test_data_scalar, - xfails={ - "eq_scalar_rank4_randn": "MLETORCH-847: Boolean eq result unstable on U85", - }, - strict=False, -) +@common.parametrize("test_module", test_data_scalar) @common.XfailIfNoCorstone320 def test_eq_scalar_u85_INT(test_module): pipeline = EthosU85PipelineINT[input_t]( @@ -185,50 +194,93 @@ def test_eq_scalar_u85_INT(test_module): test_module().get_inputs(), Equal.aten_op_Tensor, Equal.exir_op, - run_on_fvp=True, + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_tensor) +@common.XfailIfNoCorstone320 +def test_eq_tensor_16a8w_u85_INT(test_module): + """Test eq operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + Equal.aten_op_Tensor, + Equal.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +@common.XfailIfNoCorstone320 +def test_eq_scalar_16a8w_u85_INT(test_module): + """Test eq operation (scalar) with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + Equal.aten_op_Tensor, + Equal.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + 
use_to_edge_transform_and_lower=True, ) pipeline.run() @common.parametrize("test_module", test_data_tensor) @common.SkipIfNoModelConverter -def test_eq_scalar_vgf_FP_tensor(test_module): +def test_eq_scalar_vgf_no_quant_tensor(test_module): pipeline = VgfPipeline[input_t]( - test_module(), test_module().get_inputs(), Equal.aten_op_Tensor, Equal.exir_op + test_module(), + test_module().get_inputs(), + Equal.aten_op_Tensor, + Equal.exir_op, + quantize=False, ) pipeline.run() @common.parametrize("test_module", test_data_scalar) @common.SkipIfNoModelConverter -def test_eq_scalar_vgf_FP(test_module): +def test_eq_scalar_vgf_no_quant(test_module): pipeline = VgfPipeline[input_t]( - test_module(), test_module().get_inputs(), Equal.aten_op_Scalar, Equal.exir_op + test_module(), + test_module().get_inputs(), + Equal.aten_op_Scalar, + Equal.exir_op, + quantize=False, ) pipeline.run() @common.parametrize("test_module", test_data_tensor) @common.SkipIfNoModelConverter -def test_eq_scalar_vgf_INT_tensor(test_module): +def test_eq_scalar_vgf_quant_tensor(test_module): pipeline = VgfPipeline[input_t]( test_module(), test_module().get_inputs(), Equal.aten_op_Tensor, Equal.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @common.parametrize("test_module", test_data_scalar) @common.SkipIfNoModelConverter -def test_eq_scalar_vgf_INT(test_module): +def test_eq_scalar_vgf_quant(test_module): pipeline = VgfPipeline[input_t]( test_module(), test_module().get_inputs(), Equal.aten_op_Tensor, Equal.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_erf.py b/backends/arm/test/ops/test_erf.py index 363b1e2d8c9..6ad9f64b121 100644 --- a/backends/arm/test/ops/test_erf.py +++ b/backends/arm/test/ops/test_erf.py @@ -50,7 +50,10 @@ def test_erf_tosa_INT(test_data: input_t1): @common.XfailIfNoCorstone300 def test_erf_u55_INT(test_data: input_t1): pipeline = EthosU55PipelineINT[input_t1]( - Erf(), test_data(), aten_op, exir_op, run_on_fvp=True + Erf(), + test_data(), + aten_op, + exir_op, ) pipeline.run() @@ -59,28 +62,35 @@ def test_erf_u55_INT(test_data: input_t1): @common.XfailIfNoCorstone320 def test_erf_u85_INT(test_data: input_t1): pipeline = EthosU85PipelineINT[input_t1]( - Erf(), test_data(), aten_op, exir_op, run_on_fvp=True + Erf(), + test_data(), + aten_op, + exir_op, ) pipeline.run() @common.parametrize("test_data", Erf.test_data) @common.SkipIfNoModelConverter -def test_erf_vgf_FP(test_data: input_t1): +def test_erf_vgf_no_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( - Erf(), test_data(), aten_op, exir_op, tosa_version="TOSA-1.0+FP" + Erf(), + test_data(), + aten_op, + exir_op, + quantize=False, ) pipeline.run() @common.parametrize("test_data", Erf.test_data) @common.SkipIfNoModelConverter -def test_erf_vgf_INT(test_data: input_t1): +def test_erf_vgf_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( Erf(), test_data(), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_exp.py b/backends/arm/test/ops/test_exp.py index 6eaacc71d86..71d5f49dc02 100644 --- a/backends/arm/test/ops/test_exp.py +++ b/backends/arm/test/ops/test_exp.py @@ -68,7 +68,6 @@ def test_exp_u55_INT(test_data: Tuple): (test_data(),), aten_op, exir_ops=[], - run_on_fvp=True, ) pipeline.run() @@ -81,32 +80,31 @@ def test_exp_u85_INT(test_data: Tuple): (test_data(),), aten_op, exir_ops=[], - run_on_fvp=True, ) pipeline.run() @common.parametrize("test_data", 
test_data_suite) @common.SkipIfNoModelConverter -def test_exp_vgf_FP(test_data: Tuple): +def test_exp_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( Exp(), (test_data(),), aten_op, exir_op=[], - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_exp_vgf_INT(test_data: Tuple): +def test_exp_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( Exp(), (test_data(),), aten_op, exir_op=[], - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_expand.py b/backends/arm/test/ops/test_expand.py index b5784c9ff93..bddf7b65ad8 100644 --- a/backends/arm/test/ops/test_expand.py +++ b/backends/arm/test/ops/test_expand.py @@ -30,7 +30,7 @@ class Expand(torch.nn.Module): # (input tensor, multiples) - test_parameters = { + base_test_set = { "rand_1d_both": lambda: (torch.rand(1), (2,)), "rand_1d": lambda: (torch.randn(1), (2, 2, 4)), "rand_4d": lambda: (torch.randn(1, 1, 1, 5), (1, 4, -1, -1)), @@ -40,10 +40,14 @@ class Expand(torch.nn.Module): "rand_small_neg": lambda: (torch.rand(1, 1, 2, 2), (4, 3, -1, 2)), } + test_u55_reject_set = { + "randbool_1d": lambda: (torch.randint(0, 2, (1,), dtype=torch.bool), (5,)), + } test_reject_set = { "rand_2d": lambda: (torch.randn(1, 4), (1, -1)), "rand_neg_mul": lambda: (torch.randn(1, 1, 192), (1, -1, -1)), } + test_parameters = base_test_set | test_u55_reject_set def forward(self, x: torch.Tensor, m: Sequence): return x.expand(m) @@ -71,16 +75,22 @@ def test_expand_tosa_INT(test_data: Tuple): pipeline.run() -@common.parametrize("test_data", Expand.test_parameters) +@common.parametrize( + "test_data", + Expand.base_test_set, +) @common.XfailIfNoCorstone300 def test_expand_u55_INT(test_data: Tuple): + inputs = test_data() pipeline = EthosU55PipelineINT[input_t1]( Expand(), - test_data(), + inputs, aten_op, exir_ops=[], - run_on_fvp=True, ) + if inputs[0].dtype == torch.bool: + pipeline.pop_stage("check_count.exir") + pipeline.tester.use_portable_ops = True pipeline.run() @@ -92,40 +102,52 @@ def test_expand_u85_INT(test_data: Tuple): test_data(), aten_op, exir_ops=[], - run_on_fvp=True, ) pipeline.run() @common.parametrize("test_data", Expand.test_parameters) @common.SkipIfNoModelConverter -def test_expand_vgf_FP(test_data: Tuple): +def test_expand_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( Expand(), test_data(), aten_op, exir_op=[], - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", Expand.test_parameters) @common.SkipIfNoModelConverter -def test_expand_vgf_INT(test_data: Tuple): +def test_expand_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( Expand(), test_data(), aten_op, exir_op=[], - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @common.parametrize("test_data", Expand.test_reject_set) +def test_expand_tosa_INT_not_delegated(test_data: Tuple): + pipeline = OpNotSupportedPipeline[input_t1]( + Expand(), test_data(), {exir_op: 1}, n_expected_delegates=0, quantize=True + ) + pipeline.run() + + +@common.parametrize("test_data", Expand.test_u55_reject_set) def test_expand_u55_INT_not_delegated(test_data: Tuple): pipeline = OpNotSupportedPipeline[input_t1]( - Expand(), test_data(), {exir_op: 1}, n_expected_delegates=0 + Expand(), + test_data(), + {exir_op: 1}, + n_expected_delegates=0, + quantize=True, + u55_subset=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_expm1.py 
b/backends/arm/test/ops/test_expm1.py index dad95b24f7b..7556d1e45a8 100644 --- a/backends/arm/test/ops/test_expm1.py +++ b/backends/arm/test/ops/test_expm1.py @@ -89,25 +89,25 @@ def test_expm1_u85_INT(test_data: Tuple): @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_expm1_vgf_FP(test_data: Tuple): +def test_expm1_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( Expm1(), (test_data,), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_expm1_vgf_INT(test_data: Tuple): +def test_expm1_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( Expm1(), (test_data,), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_eye.py b/backends/arm/test/ops/test_eye.py index eef32259c10..c004b7ca455 100644 --- a/backends/arm/test/ops/test_eye.py +++ b/backends/arm/test/ops/test_eye.py @@ -68,7 +68,8 @@ def test_eye_tosa_INT(test_data: test_data_t): input_data(), EyeAdd.aten_op, ) - pipeline.pop_stage("check.quant_nodes") + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -82,7 +83,8 @@ def test_eye_u55_INT(test_data: test_data_t): EyeAdd.aten_op, use_to_edge_transform_and_lower=True, ) - pipeline.pop_stage("check.quant_nodes") + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -95,8 +97,9 @@ def test_eye_u85_INT(test_data: test_data_t): input_data(), EyeAdd.aten_op, use_to_edge_transform_and_lower=True, - ).dump_artifact("to_edge_transform_and_lower") - pipeline.pop_stage("check.quant_nodes") + ) + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -108,13 +111,13 @@ def test_eye_u85_INT(test_data: test_data_t): EyeAdd.test_data, ) @common.SkipIfNoModelConverter -def test_eye_vgf_FP(test_data: test_data_t): +def test_eye_vgf_no_quant(test_data: test_data_t): input_data, init_data = test_data pipeline = VgfPipeline[input_t]( EyeAdd(*init_data), input_data(), EyeAdd.aten_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @@ -124,15 +127,16 @@ def test_eye_vgf_FP(test_data: test_data_t): EyeAdd.test_data, ) @common.SkipIfNoModelConverter -def test_eye_vgf_INT(test_data: test_data_t): +def test_eye_vgf_quant(test_data: test_data_t): input_data, init_data = test_data pipeline = VgfPipeline[input_t]( EyeAdd(*init_data), input_data(), EyeAdd.aten_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) - pipeline.pop_stage("check.quant_nodes") + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/ops/test_fill_scalar.py b/backends/arm/test/ops/test_fill_scalar.py new file mode 100644 index 00000000000..5ca209fbeb0 --- /dev/null +++ b/backends/arm/test/ops/test_fill_scalar.py @@ -0,0 +1,108 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineINT, + EthosU85PipelineINT, + TosaPipelineFP, + TosaPipelineINT, + VgfPipeline, +) + +aten_op = "torch.ops.aten.fill_.Scalar" +exir_op = "executorch_exir_dialects_edge__ops_aten_full_like_default" + +input_t1 = Tuple[torch.Tensor] + +test_data_suite = { + "ones_float": [torch.ones(2, 3), 5.0], + "ones_int": [torch.ones(2, 3), -3], +} + + +class FillScalar(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, y: torch.Tensor, fill_value: int | float): + mask = torch.full_like(y, 0) + mask.fill_(fill_value) + return mask * y + + +@common.parametrize("test_data", test_data_suite) +def test_fill_scalar_tosa_FP(test_data: Tuple): + pipeline = TosaPipelineFP[input_t1]( + FillScalar(), + (*test_data,), + aten_op=aten_op, + exir_op=exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +def test_fill_scalar_tosa_INT(test_data: Tuple): + pipeline = TosaPipelineINT[input_t1]( + FillScalar(), + (*test_data,), + aten_op=aten_op, + exir_op=exir_op, + ) + pipeline.run() + + +@common.XfailIfNoCorstone300 +@common.parametrize("test_data", test_data_suite) +def test_fill_scalar_u55_INT(test_data: Tuple): + pipeline = EthosU55PipelineINT[input_t1]( + FillScalar(), + (*test_data,), + aten_ops=[aten_op], + exir_ops=exir_op, + ) + pipeline.run() + + +@common.XfailIfNoCorstone320 +@common.parametrize("test_data", test_data_suite) +def test_fill_scalar_u85_INT(test_data: Tuple): + pipeline = EthosU85PipelineINT[input_t1]( + FillScalar(), + (*test_data,), + aten_ops=[aten_op], + exir_ops=exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.SkipIfNoModelConverter +def test_fill_scalar_vgf_no_quant(test_data: input_t1): + pipeline = VgfPipeline[input_t1]( + FillScalar(), + (*test_data,), + aten_op, + exir_op, + quantize=False, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.SkipIfNoModelConverter +def test_fill_scalar_vgf_quant(test_data: input_t1): + pipeline = VgfPipeline[input_t1]( + FillScalar(), + (*test_data,), + aten_op, + exir_op, + quantize=True, + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_floor.py b/backends/arm/test/ops/test_floor.py index c66ef1c5d27..d308db178d8 100644 --- a/backends/arm/test/ops/test_floor.py +++ b/backends/arm/test/ops/test_floor.py @@ -78,7 +78,6 @@ def test_floor_u55_INT(test_data: input_t1): (data,), module.aten_op, module.exir_op, - run_on_fvp=True, ) pipeline.run() @@ -92,28 +91,27 @@ def test_floor_u85_INT(test_data: input_t1): (data,), module.aten_op, module.exir_op, - run_on_fvp=True, ) pipeline.run() @common.parametrize("test_data", test_data) @common.SkipIfNoModelConverter -def test_floor_vgf_FP(test_data: input_t1): +def test_floor_vgf_no_quant(test_data: input_t1): module, data = test_data() pipeline = VgfPipeline[input_t1]( module, (data,), module.aten_op, module.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data) @common.SkipIfNoModelConverter -def test_floor_vgf_INT(test_data: input_t1): +def test_floor_vgf_quant(test_data: input_t1): module, data = test_data() pipeline = VgfPipeline[input_t1]( module, @@ -122,6 +120,6 @@ def test_floor_vgf_INT(test_data: input_t1): module.exir_op, atol=0.06, rtol=0.01, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git 
a/backends/arm/test/ops/test_floor_div.py b/backends/arm/test/ops/test_floor_div.py new file mode 100644 index 00000000000..d2b4bc46688 --- /dev/null +++ b/backends/arm/test/ops/test_floor_div.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2024-2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple, Union + +import torch +from executorch.backends.arm.test import common + +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineINT, + EthosU85PipelineINT, + TosaPipelineFP, + TosaPipelineINT, + VgfPipeline, +) + +test_data_suite = { + # (test_name, input, other) + "op_floor_div_rank1_ones": lambda: ( + torch.ones(5), + torch.ones(5), + ), + "op_floor_div_rank1_rand": lambda: ( + torch.rand(5) * 5, + torch.rand(5) * 5, + ), + "op_floor_div_rank4_negative_ones": lambda: ( + (-1) * torch.ones(5, 10, 25, 20), + torch.ones(5, 10, 25, 20), + ), + "op_floor_div_rank4_ones_div_negative": lambda: ( + torch.ones(5, 10, 25, 20), + (-1) * torch.ones(5, 10, 25, 20), + ), + "op_floor_div_rank4_randn_multiple_broadcasts": lambda: ( + torch.randn(1, 4, 4, 1), + torch.randn(1, 1, 4, 4), + ), + "op_floor_div_rank4_randn_scalar": lambda: ( + torch.randn(1, 4, 4, 1), + 2, + ), + "op_floor_div_rank4_large_rand": lambda: ( + 200 * torch.rand(5, 10, 25, 20), + torch.rand(5, 10, 25, 20), + ), +} + + +class FloorDivide(torch.nn.Module): + aten_op = "torch.ops.aten.floor_divide.default" + aten_ops_int = ["aten.mul.Tensor", "aten.reciprocal.default", "aten.floor.default"] + exir_op = "executorch_exir_dialects_edge__ops_aten_div_Tensor_mode" + exir_ops_int = [ + "executorch_exir_dialects_edge__ops_aten_reciprocal_default", + "executorch_exir_dialects_edge__ops_aten_mul_Tensor", + "executorch_exir_dialects_edge__ops_aten_floor_default", + ] + + def forward( + self, + input_: Union[torch.Tensor, torch.types.Number], + other_: Union[torch.Tensor, torch.types.Number], + ): + return torch.floor_divide(input=input_, other=other_) + + +input_t1 = Tuple[torch.Tensor, Union[torch.Tensor, int]] + + +@common.parametrize("test_data", test_data_suite) +def test_floor_divide_tosa_FP(test_data: input_t1): + pipeline = TosaPipelineFP[input_t1]( + FloorDivide(), + test_data(), + FloorDivide.aten_op, + FloorDivide.exir_op, + use_to_edge_transform_and_lower=False, + rtol=0.06, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +def test_floor_divide_tosa_INT(test_data: input_t1): + pipeline = TosaPipelineINT[input_t1]( + FloorDivide(), + test_data(), + aten_op=FloorDivide.aten_ops_int, + exir_op=FloorDivide.exir_ops_int, + use_to_edge_transform_and_lower=False, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone300 +def test_floor_divide_u55_INT(test_data: input_t1): + pipeline = EthosU55PipelineINT[input_t1]( + FloorDivide(), + test_data(), + aten_ops=FloorDivide.aten_ops_int, + exir_ops=[], + run_on_fvp=True, + use_to_edge_transform_and_lower=False, + ) + pipeline.pop_stage("check_not.exir") + pipeline.pop_stage("check_count.exir") + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone320 +def test_floor_divide_u85_INT(test_data: input_t1): + pipeline = EthosU85PipelineINT[input_t1]( + FloorDivide(), + test_data(), + aten_ops=FloorDivide.aten_ops_int, +
exir_ops=FloorDivide.exir_ops_int, + run_on_fvp=True, + use_to_edge_transform_and_lower=False, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.SkipIfNoModelConverter +def test_floor_divide_vgf_no_quant(test_data: input_t1): + pipeline = VgfPipeline[input_t1]( + FloorDivide(), + test_data(), + FloorDivide.aten_op, + FloorDivide.exir_op, + use_to_edge_transform_and_lower=False, + rtol=0.06, + quantize=False, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.SkipIfNoModelConverter +def test_floor_divide_vgf_quant(test_data: input_t1): + pipeline = VgfPipeline[input_t1]( + FloorDivide(), + test_data(), + aten_op=FloorDivide.aten_ops_int, + exir_op=FloorDivide.exir_ops_int, + use_to_edge_transform_and_lower=False, + quantize=True, + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_full.py b/backends/arm/test/ops/test_full.py index 9e2c9b4d8be..8ce7fb984bc 100644 --- a/backends/arm/test/ops/test_full.py +++ b/backends/arm/test/ops/test_full.py @@ -117,7 +117,6 @@ def test_full_like_tosa_INT(test_data: Tuple): aten_op=[], exir_op=exir_op, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -144,52 +143,52 @@ def test_full_tosa_INT(test_data: Tuple): @common.SkipIfNoModelConverter -def test_full_vgf_FP_only(): +def test_full_only_vgf_no_quant(): pipeline = VgfPipeline[input_t1]( Full(), (), aten_op=[], exir_op=exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.SkipIfNoModelConverter -def test_full_vgf_FP_const(): +def test_full_const_vgf_no_quant(): test_data = (torch.rand((2, 2, 3, 3)) * 10,) pipeline = VgfPipeline[input_t1]( AddConstFull(), test_data, aten_op=[], exir_op=exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", AddVariableFull.test_parameters) @common.SkipIfNoModelConverter -def test_full_vgf_FP(test_data: Tuple): +def test_full_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( AddVariableFull(), test_data, aten_op=[], exir_op=exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", AddVariableFull.test_parameters) @common.SkipIfNoModelConverter -def test_full_vgf_INT(test_data: Tuple): +def test_full_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( AddVariableFull(), test_data, aten_op=[], exir_op=exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @@ -202,7 +201,6 @@ def test_full_u85_INT(test_data: Tuple): test_data, aten_ops=[], exir_ops=exir_op, - run_on_fvp=True, use_to_edge_transform_and_lower=True, ) pipeline.run() @@ -216,7 +214,6 @@ def test_full_u55_INT(test_data: Tuple): test_data, aten_ops=[], exir_ops=exir_op, - run_on_fvp=True, use_to_edge_transform_and_lower=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_ge.py b/backends/arm/test/ops/test_ge.py index 94f33d28630..a5033c8c977 100644 --- a/backends/arm/test/ops/test_ge.py +++ b/backends/arm/test/ops/test_ge.py @@ -121,6 +121,30 @@ def test_ge_scalar_tosa_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_data_tensor) +def test_ge_tensor_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + GreaterEqual.aten_op_tensor, + GreaterEqual.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +def test_ge_scalar_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + 
test_module(), + test_module().get_inputs(), + GreaterEqual.aten_op_tensor, + GreaterEqual.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + @common.parametrize("test_module", test_data_tensor) @common.XfailIfNoCorstone300 def test_ge_tensor_u55_INT(test_module): @@ -161,7 +185,6 @@ def test_ge_tensor_u85_INT(test_module): test_module().get_inputs(), GreaterEqual.aten_op_tensor, GreaterEqual.exir_op, - run_on_fvp=True, ) pipeline.run() @@ -177,58 +200,93 @@ def test_ge_scalar_u85_INT(test_module): test_module().get_inputs(), GreaterEqual.aten_op_tensor, GreaterEqual.exir_op, - run_on_fvp=True, + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_tensor) +@common.XfailIfNoCorstone320 +def test_ge_tensor_16a8w_u85_INT16(test_module): + """Test ge operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + GreaterEqual.aten_op_tensor, + GreaterEqual.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +@common.XfailIfNoCorstone320 +def test_ge_scalar_16a8w_u85_INT16(test_module): + """Test ge operation (scalar) with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + GreaterEqual.aten_op_tensor, + GreaterEqual.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, ) pipeline.run() @common.parametrize("test_module", test_data_tensor) @common.SkipIfNoModelConverter -def test_ge_tensor_vgf_FP(test_module): +def test_ge_tensor_vgf_no_quant(test_module): pipeline = VgfPipeline[input_t]( test_module(), test_module().get_inputs(), GreaterEqual.aten_op_tensor, GreaterEqual.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_module", test_data_tensor) @common.SkipIfNoModelConverter -def test_ge_tensor_vgf_INT(test_module): +def test_ge_tensor_vgf_quant(test_module): pipeline = VgfPipeline[input_t]( test_module(), test_module().get_inputs(), GreaterEqual.aten_op_tensor, GreaterEqual.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @common.parametrize("test_module", test_data_scalar) @common.SkipIfNoModelConverter -def test_ge_scalar_vgf_FP(test_module): +def test_ge_scalar_vgf_no_quant(test_module): pipeline = VgfPipeline[input_t]( test_module(), test_module().get_inputs(), GreaterEqual.aten_op_scalar, GreaterEqual.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_module", test_data_scalar) @common.SkipIfNoModelConverter -def test_ge_scalar_vgf_INT(test_module): +def test_ge_scalar_vgf_quant(test_module): pipeline = VgfPipeline[input_t]( test_module(), test_module().get_inputs(), GreaterEqual.aten_op_tensor, GreaterEqual.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_gelu.py b/backends/arm/test/ops/test_gelu.py index 264f6b95e71..0936aa16c3c 100644 --- a/backends/arm/test/ops/test_gelu.py +++ b/backends/arm/test/ops/test_gelu.py @@ -130,27 +130,27 @@ def test_gelu_u85_INT(test_data: input_t1): @common.parametrize("test_data", Gelu.test_data) @common.SkipIfNoModelConverter 
-def test_gelu_vgf_FP(test_data: input_t1): +def test_gelu_vgf_no_quant(test_data: input_t1): approximate, data = test_data() pipeline = VgfPipeline[input_t1]( Gelu(approximate), (data,), Gelu.aten_op, Gelu.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", Gelu.test_data) @common.SkipIfNoModelConverter -def test_gelu_vgf_INT(test_data: input_t1): +def test_gelu_vgf_quant(test_data: input_t1): approximate, data = test_data() pipeline = VgfPipeline[input_t1]( Gelu(approximate), (data,), Gelu.aten_op, Gelu.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_glu.py b/backends/arm/test/ops/test_glu.py index c19fb892c92..c7426c01286 100644 --- a/backends/arm/test/ops/test_glu.py +++ b/backends/arm/test/ops/test_glu.py @@ -103,13 +103,13 @@ def test_glu_u85_INT(test_data: Tuple): test_data_suite, ) @common.SkipIfNoModelConverter -def test_glu_vgf_FP(test_data: input_t1): +def test_glu_vgf_no_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( Glu(), (*test_data,), [], [], - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @@ -119,12 +119,12 @@ def test_glu_vgf_FP(test_data: input_t1): test_data_suite, ) @common.SkipIfNoModelConverter -def test_glu_vgf_INT(test_data: input_t1): +def test_glu_vgf_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( Glu(), (*test_data,), [], [], - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_group_norm.py b/backends/arm/test/ops/test_group_norm.py index 0f314064548..d80b94ce786 100644 --- a/backends/arm/test/ops/test_group_norm.py +++ b/backends/arm/test/ops/test_group_norm.py @@ -118,7 +118,6 @@ def test_native_group_norm_u55_INT(test_data): test_data[1], test_data[0], "torch.ops.aten.sub.Tensor", # 'sub' op arbitrarily chosen to confirm groupnorm was decomposed - run_on_fvp=True, atol=0.1, # TODO: "MLETORCH-925: Fix numerical issue for aten.native_group_norm" ) pipeline.change_args("run_method_and_compare_outputs", atol=1, qtol=1) @@ -142,7 +141,6 @@ def test_native_group_norm_u85_INT(test_data): test_data[1], test_data[0], "torch.ops.aten.sub.Tensor", # 'sub' op arbitrarily chosen to confirm groupnorm was decomposed - run_on_fvp=True, atol=0.1, # TODO: "MLETORCH-925: Fix numerical issue for aten.native_group_norm" ) pipeline.change_args("run_method_and_compare_outputs", atol=1, qtol=1) @@ -161,7 +159,7 @@ def test_native_group_norm_u85_INT(test_data): strict=False, ) @common.SkipIfNoModelConverter -def test_native_group_norm_vgf_FP(test_data): +def test_native_group_norm_vgf_no_quant(test_data): aten_op = "torch.ops.aten.group_norm.default" exir_op = "executorch_exir_dialects_edge__ops_aten_native_group_norm_default" model, inp = test_data @@ -170,7 +168,7 @@ def test_native_group_norm_vgf_FP(test_data): model, aten_op=aten_op, exir_op=exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @@ -187,7 +185,7 @@ def test_native_group_norm_vgf_FP(test_data): strict=False, ) @common.SkipIfNoModelConverter -def test_native_group_norm_vgf_INT(test_data): +def test_native_group_norm_vgf_quant(test_data): aten_op = "torch.ops.aten.sub.Tensor" exir_op = "executorch_exir_dialects_edge__ops_aten_native_group_norm_default" model, inp = test_data @@ -196,7 +194,7 @@ def test_native_group_norm_vgf_INT(test_data): model, aten_op=aten_op, exir_op=exir_op, - tosa_version="TOSA-1.0+INT", - atol=0.1, # TODO: "MLETORCH-925: Fix numerical issue for 
aten.native_group_norm" + atol=0.1, + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_gt.py b/backends/arm/test/ops/test_gt.py index 41229397eb5..961ec638b98 100644 --- a/backends/arm/test/ops/test_gt.py +++ b/backends/arm/test/ops/test_gt.py @@ -122,6 +122,30 @@ def test_gt_scalar_tosa_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_data_tensor) +def test_gt_tensor_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + Greater.aten_op_tensor, + Greater.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +def test_gt_scalar_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + Greater.aten_op_tensor, + Greater.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + @common.parametrize("test_module", test_data_tensor) @common.XfailIfNoCorstone300 def test_gt_tensor_u55_INT(test_module): @@ -162,7 +186,6 @@ def test_gt_tensor_u85_INT(test_module): test_module().get_inputs(), Greater.aten_op_tensor, Greater.exir_op, - run_on_fvp=True, ) pipeline.run() @@ -178,58 +201,93 @@ def test_gt_scalar_u85_INT(test_module): test_module().get_inputs(), Greater.aten_op_tensor, Greater.exir_op, - run_on_fvp=True, + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_tensor) +@common.XfailIfNoCorstone320 +def test_gt_tensor_16a8w_u85_INT16(test_module): + """Test gt operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + Greater.aten_op_tensor, + Greater.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +@common.XfailIfNoCorstone320 +def test_gt_scalar_16a8w_u85_INT16(test_module): + """Test gt operation (scalar) with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + Greater.aten_op_tensor, + Greater.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, ) pipeline.run() @common.parametrize("test_module", test_data_tensor) @common.SkipIfNoModelConverter -def test_gt_tensor_vgf_FP(test_module): +def test_gt_tensor_vgf_no_quant(test_module): pipeline = VgfPipeline[input_t]( test_module(), test_module().get_inputs(), Greater.aten_op_tensor, Greater.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_module", test_data_scalar) @common.SkipIfNoModelConverter -def test_gt_scalar_vgf_FP(test_module): +def test_gt_scalar_vgf_no_quant(test_module): pipeline = VgfPipeline[input_t]( test_module(), test_module().get_inputs(), Greater.aten_op_scalar, Greater.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_module", test_data_tensor) @common.SkipIfNoModelConverter -def test_gt_tensor_vgf_INT(test_module): +def test_gt_tensor_vgf_quant(test_module): pipeline = VgfPipeline[input_t]( test_module(), test_module().get_inputs(), Greater.aten_op_tensor, Greater.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() 
@common.parametrize("test_module", test_data_scalar) @common.SkipIfNoModelConverter -def test_gt_scalar_vgf_INT(test_module): +def test_gt_scalar_vgf_quant(test_module): pipeline = VgfPipeline[input_t]( test_module(), test_module().get_inputs(), Greater.aten_op_tensor, Greater.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_hardsigmoid.py b/backends/arm/test/ops/test_hardsigmoid.py index 5f591c15617..eb10e5a79e4 100644 --- a/backends/arm/test/ops/test_hardsigmoid.py +++ b/backends/arm/test/ops/test_hardsigmoid.py @@ -70,7 +70,6 @@ def test_hardsigmoid_u55_INT(test_data: torch.Tensor): (test_data(),), aten_op, exir_ops=[], - run_on_fvp=True, use_to_edge_transform_and_lower=True, ) pipeline.run() @@ -84,7 +83,6 @@ def test_hardsigmoid_u85_INT(test_data: torch.Tensor): (test_data(),), aten_op, exir_ops=[], - run_on_fvp=True, use_to_edge_transform_and_lower=True, ) pipeline.run() @@ -92,21 +90,25 @@ def test_hardsigmoid_u85_INT(test_data: torch.Tensor): @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_hardsigmoid_vgf_FP(test_data: torch.Tensor): +def test_hardsigmoid_vgf_no_quant(test_data: torch.Tensor): pipeline = VgfPipeline[input_t1]( - Hardsigmoid(), (test_data(),), aten_op, exir_op=[], tosa_version="TOSA-1.0+FP" + Hardsigmoid(), + (test_data(),), + aten_op, + exir_op=[], + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_hardsigmoid_vgf_INT(test_data: torch.Tensor): +def test_hardsigmoid_vgf_quant(test_data: torch.Tensor): pipeline = VgfPipeline[input_t1]( Hardsigmoid(), (test_data(),), aten_op, exir_op=[], - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_hardswish.py b/backends/arm/test/ops/test_hardswish.py index 00db0cb296b..68cd249861a 100644 --- a/backends/arm/test/ops/test_hardswish.py +++ b/backends/arm/test/ops/test_hardswish.py @@ -62,7 +62,6 @@ def test_hardswish_u55_INT(test_data): (test_data(),), aten_op, exir_op, - run_on_fvp=True, use_to_edge_transform_and_lower=True, ).run() @@ -75,28 +74,31 @@ def test_hardswish_u85_INT(test_data): (test_data(),), aten_op, exir_op, - run_on_fvp=True, use_to_edge_transform_and_lower=True, ).run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_hardswish_vgf_FP(test_data): +def test_hardswish_vgf_no_quant(test_data): pipeline = VgfPipeline[input_t1]( - Hardswish(), (test_data(),), aten_op, exir_op, tosa_version="TOSA-1.0+FP" + Hardswish(), + (test_data(),), + aten_op, + exir_op, + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_hardswish_vgf_INT(test_data): +def test_hardswish_vgf_quant(test_data): pipeline = VgfPipeline[input_t1]( Hardswish(), (test_data(),), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_hardtanh.py b/backends/arm/test/ops/test_hardtanh.py index 28f7e717351..a13e70d74d0 100644 --- a/backends/arm/test/ops/test_hardtanh.py +++ b/backends/arm/test/ops/test_hardtanh.py @@ -71,7 +71,6 @@ def test_hardtanh_u55_INT(test_data: torch.Tensor): (test_data(),), aten_op, exir_op, - run_on_fvp=True, ) pipeline.run() @@ -84,28 +83,31 @@ def test_hardtanh_u85_INT(test_data: torch.Tensor): (test_data(),), aten_op, exir_op, - run_on_fvp=True, ) pipeline.run() @common.parametrize("test_data", test_data_suite) 
@common.SkipIfNoModelConverter -def test_hardtanh_vgf_FP(test_data: torch.Tensor): +def test_hardtanh_vgf_no_quant(test_data: torch.Tensor): pipeline = VgfPipeline[input_t]( - HardTanh(), (test_data(),), aten_op, exir_op, tosa_version="TOSA-1.0+FP" + HardTanh(), + (test_data(),), + aten_op, + exir_op, + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_hardtanh_vgf_INT(test_data: torch.Tensor): +def test_hardtanh_vgf_quant(test_data: torch.Tensor): pipeline = VgfPipeline[input_t]( HardTanh(), (test_data(),), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_index_select.py b/backends/arm/test/ops/test_index_select.py index 95ebaa62a38..239c27a8af6 100644 --- a/backends/arm/test/ops/test_index_select.py +++ b/backends/arm/test/ops/test_index_select.py @@ -137,45 +137,41 @@ def test_index_select_u55_INT_not_delegated(test_data: input_params): @pytest.mark.parametrize("test_data", list(test_data.values())) @common.SkipIfNoModelConverter -def test_index_select_vgf_FP(test_data: input_params): +def test_index_select_vgf_no_quant(test_data: input_params): op, inp = test_data pipeline = VgfPipeline[input_params]( op, inp, op.aten_op, op.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @pytest.mark.parametrize("test_data", list(test_data.values())[:-1]) @common.SkipIfNoModelConverter -def test_index_select_vgf_INT(test_data: input_params): +def test_index_select_vgf_quant(test_data: input_params): op, inp = test_data pipeline = VgfPipeline[input_params]( op, inp, op.aten_op, op.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @pytest.mark.parametrize("test_data", list(test_data.values())[-1:]) @common.SkipIfNoModelConverter -def test_index_select_vgf_INT_rand(test_data: input_params): +def test_index_select_vgf_quant_rand(test_data: input_params): op, inp = test_data pipeline = VgfPipeline[input_params]( op, inp, op.aten_op, op.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) - # TODO: MLETORCH-1136 Change args of run_method_and_compare_outputs of the vgf tests - # pipeline.change_args( - # "run_method_and_compare_outputs", inputs=test_input, atol=0.9, rtol=0.2, qtol=1 - # ) pipeline.run() diff --git a/backends/arm/test/ops/test_index_tensor.py b/backends/arm/test/ops/test_index_tensor.py index 557846922b8..bc19634bf30 100644 --- a/backends/arm/test/ops/test_index_tensor.py +++ b/backends/arm/test/ops/test_index_tensor.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. -from enum import IntEnum from typing import Tuple import torch @@ -25,22 +24,12 @@ class IndexTensorTestCommon: # Gathers and reshapes should result in no inaccuracies rtol = 0.0 atol = 0.0 + BEFORE = "BEFORE" + MIDDLE = "MIDDLE" + AFTER = "AFTER" - class OpPlacement(IntEnum): - """ - Simple enum used to indicate where slices or ellipsis should be placed - in tests. - IntEnum so that Dynamo does not complain about unsupported types. 
- """ - BEFORE = 1 - MIDDLE = 2 - AFTER = 3 - - -input_params_slice = Tuple[ - torch.Tensor, int, int, IndexTensorTestCommon.OpPlacement, Tuple[torch.Tensor] -] +input_params_slice = Tuple[torch.Tensor, int, int, str, Tuple[torch.Tensor]] input_params = Tuple[torch.Tensor, Tuple[torch.Tensor]] @@ -55,12 +44,12 @@ class IndexTensor_Ellipsis(torch.nn.Module): test_data_ellipsis: dict[input_params] = { "test_4d_ellipsis_before": ( torch.rand(size=(25, 5, 13, 7)), - IndexTensorTestCommon.OpPlacement.BEFORE, + IndexTensorTestCommon.BEFORE, (torch.arange(2, dtype=torch.int32),), ), "test_4d_ellipsis_middle": ( torch.rand(size=(25, 5, 13, 7)), - IndexTensorTestCommon.OpPlacement.MIDDLE, + IndexTensorTestCommon.MIDDLE, ( torch.arange(2, dtype=torch.int32), torch.arange(2, dtype=torch.int32), @@ -72,7 +61,7 @@ class IndexTensor_Ellipsis(torch.nn.Module): # partitioning is difficult and unreliable, as such # it is not xfail as the existing logic can handle it. torch.rand(size=(25, 5, 13, 7)), - IndexTensorTestCommon.OpPlacement.AFTER, + IndexTensorTestCommon.AFTER, (torch.arange(2, dtype=torch.int32),), ), } @@ -80,15 +69,15 @@ class IndexTensor_Ellipsis(torch.nn.Module): def forward( self, input_: torch.Tensor, - position: IndexTensorTestCommon.OpPlacement, + position: str, indices: Tuple[None | torch.Tensor], ): match position: - case IndexTensorTestCommon.OpPlacement.BEFORE: + case IndexTensorTestCommon.BEFORE: return input_[..., indices[0]] - case IndexTensorTestCommon.OpPlacement.MIDDLE: + case IndexTensorTestCommon.MIDDLE: return input_[indices[0], ..., indices[1]] - case IndexTensorTestCommon.OpPlacement.AFTER: + case IndexTensorTestCommon.AFTER: return input_[indices[0], ...] return input_[indices] @@ -154,7 +143,7 @@ class IndexTensor_Slice(torch.nn.Module): torch.rand(size=(5, 3, 4, 5)), 0, 2, - IndexTensorTestCommon.OpPlacement.BEFORE, + IndexTensorTestCommon.BEFORE, (torch.arange(2, dtype=torch.int32),), ), "test_3d_slice_before_2d_idx": ( @@ -164,14 +153,14 @@ class IndexTensor_Slice(torch.nn.Module): torch.arange(5 * 3 * 4, dtype=torch.float32).reshape(5, 3, 4), 0, 2, - IndexTensorTestCommon.OpPlacement.BEFORE, + IndexTensorTestCommon.BEFORE, (torch.arange(2, dtype=torch.int32).unsqueeze(0).tile(2, 1),), ), "test_4d_slice_middle": ( torch.arange(5 * 3 * 2, dtype=torch.int32).reshape(5, 3, 2), 0, 2, - IndexTensorTestCommon.OpPlacement.MIDDLE, + IndexTensorTestCommon.MIDDLE, ( torch.arange(2, dtype=torch.int32), torch.arange(2, dtype=torch.int32), @@ -185,7 +174,7 @@ class IndexTensor_Slice(torch.nn.Module): torch.rand(size=(25, 5, 13, 7)), 0, 2, - IndexTensorTestCommon.OpPlacement.AFTER, + IndexTensorTestCommon.AFTER, (torch.arange(2, dtype=torch.int32),), ), } @@ -195,15 +184,15 @@ def forward( input_: torch.Tensor, slice_start: int, slice_end: int, - position: IndexTensorTestCommon.OpPlacement, + position: str, indices: Tuple[None | torch.Tensor], ): match position: - case IndexTensorTestCommon.OpPlacement.BEFORE: + case IndexTensorTestCommon.BEFORE: return input_[slice_start:slice_end, indices[0]] - case IndexTensorTestCommon.OpPlacement.MIDDLE: + case IndexTensorTestCommon.MIDDLE: return input_[indices[0], slice_start:slice_end, indices[1]] - case IndexTensorTestCommon.OpPlacement.AFTER: + case IndexTensorTestCommon.AFTER: return input_[indices[0], slice_start:slice_end] diff --git a/backends/arm/test/ops/test_layer_norm.py b/backends/arm/test/ops/test_layer_norm.py index 2c9b83dc7e7..3b6db9f644c 100644 --- a/backends/arm/test/ops/test_layer_norm.py +++ 
b/backends/arm/test/ops/test_layer_norm.py @@ -95,7 +95,6 @@ def test_native_layer_norm_u55_INT(test_data): model, test_data, "torch.ops.aten.sub.Tensor", # Just check for sub op included in the layernorm decomposition - run_on_fvp=True, symmetric_io_quantization=True, ) pipeline.run() @@ -109,7 +108,6 @@ def test_native_layer_norm_u85_INT(test_data): model, test_data, "torch.ops.aten.sub.Tensor", # Just check for sub op included in the layernorm decomposition - run_on_fvp=True, symmetric_io_quantization=True, ) pipeline.run() @@ -117,25 +115,72 @@ def test_native_layer_norm_u85_INT(test_data): @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_native_layer_norm_vgf_FP(test_data): +def test_native_layer_norm_vgf_no_quant(test_data): test_input, model = test_data() pipeline = VgfPipeline[input_t]( model, test_input, "torch.ops.aten.layer_norm.default", - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_native_layer_norm_vgf_INT(test_data): +def test_native_layer_norm_vgf_quant(test_data): test_input, model = test_data() pipeline = VgfPipeline[input_t]( model, test_input, "torch.ops.aten.sub.Tensor", - tosa_version="TOSA-1.0+INT", + quantize=True, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +def test_native_layer_norm_tosa_INT_a16w8(test_data): + """Test layer_norm with int16 I/O quantization for TOSA INT.""" + test_input, model = test_data() + pipeline = TosaPipelineINT[input_t]( + model, + test_input, + "torch.ops.aten.sub.Tensor", # check for sub op in decomposition + symmetric_io_quantization=True, + tosa_extensions=["int16"], + epsilon=2**16, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone300 +def test_native_layer_norm_16a8w_u55_INT16(test_data): + """Test layer_norm with int16 I/O quantization for U55""" + test_input, model = test_data() + pipeline = EthosU55PipelineINT[input_t]( + model, + test_input, + "torch.ops.aten.sub.Tensor", + symmetric_io_quantization=True, + a16w8_quantization=True, + epsilon=2**16, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone320 +def test_native_layer_norm_16a8w_u85_INT16(test_data): + """Test layer_norm with int16 I/O quantization for U85""" + test_input, model = test_data() + pipeline = EthosU85PipelineINT[input_t]( + model, + test_input, + "torch.ops.aten.sub.Tensor", + symmetric_io_quantization=True, + a16w8_quantization=True, + epsilon=2**16, ) pipeline.run() diff --git a/backends/arm/test/ops/test_le.py b/backends/arm/test/ops/test_le.py index 31422302a2d..3d4cc836038 100644 --- a/backends/arm/test/ops/test_le.py +++ b/backends/arm/test/ops/test_le.py @@ -122,6 +122,30 @@ def test_le_scalar_tosa_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_data_tensor) +def test_le_tensor_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + LessEqual.aten_op_tensor, + LessEqual.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +def test_le_scalar_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + LessEqual.aten_op_tensor, + LessEqual.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + @common.parametrize("test_module", test_data_tensor) 
@common.XfailIfNoCorstone300 def test_le_tensor_u55_INT_not_delegated(test_module): @@ -163,7 +187,6 @@ def test_le_tensor_u85_INT(test_module): test_module().get_inputs(), LessEqual.aten_op_tensor, LessEqual.exir_op, - run_on_fvp=True, use_to_edge_transform_and_lower=True, ) pipeline.run() @@ -180,7 +203,42 @@ def test_le_scalar_u85_INT(test_module): test_module().get_inputs(), LessEqual.aten_op_tensor, LessEqual.exir_op, - run_on_fvp=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_tensor) +@common.XfailIfNoCorstone320 +def test_le_tensor_16a8w_u85_INT16(test_module): + """Test le operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + LessEqual.aten_op_tensor, + LessEqual.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +@common.XfailIfNoCorstone320 +def test_le_scalar_16a8w_u85_INT16(test_module): + """Test le operation (scalar) with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + LessEqual.aten_op_tensor, + LessEqual.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, use_to_edge_transform_and_lower=True, ) pipeline.run() @@ -188,51 +246,51 @@ def test_le_scalar_u85_INT(test_module): @common.parametrize("test_module", test_data_tensor) @common.SkipIfNoModelConverter -def test_le_tensor_vgf_FP(test_module): +def test_le_tensor_vgf_no_quant(test_module): pipeline = VgfPipeline[input_t]( test_module(), test_module().get_inputs(), LessEqual.aten_op_tensor, LessEqual.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_module", test_data_tensor) @common.SkipIfNoModelConverter -def test_le_tensor_vgf_INT(test_module): +def test_le_tensor_vgf_quant(test_module): pipeline = VgfPipeline[input_t]( test_module(), test_module().get_inputs(), LessEqual.aten_op_tensor, LessEqual.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @common.parametrize("test_module", test_data_scalar) @common.SkipIfNoModelConverter -def test_le_scalar_vgf_FP(test_module): +def test_le_scalar_vgf_no_quant(test_module): pipeline = VgfPipeline[input_t]( test_module(), test_module().get_inputs(), LessEqual.aten_op_scalar, LessEqual.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_module", test_data_scalar) @common.SkipIfNoModelConverter -def test_le_scalar_vgf_INT(test_module): +def test_le_scalar_vgf_quant(test_module): pipeline = VgfPipeline[input_t]( test_module(), test_module().get_inputs(), LessEqual.aten_op_tensor, LessEqual.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_leaky_relu.py b/backends/arm/test/ops/test_leaky_relu.py index 432c4da7ecc..9be24857563 100644 --- a/backends/arm/test/ops/test_leaky_relu.py +++ b/backends/arm/test/ops/test_leaky_relu.py @@ -73,7 +73,6 @@ def test_leaky_relu_u55_INT(test_data): LeakyReLU(slope), data, [], - run_on_fvp=True, use_to_edge_transform_and_lower=True, ) pipeline.add_stage_after("quantize", pipeline.tester.check_not, [aten_op]) @@ -88,7 
+87,6 @@ def test_leaky_relu_u85_INT(test_data): LeakyReLU(slope), data, [], - run_on_fvp=True, use_to_edge_transform_and_lower=True, ) pipeline.add_stage_after("quantize", pipeline.tester.check_not, [aten_op]) @@ -97,14 +95,14 @@ def test_leaky_relu_u85_INT(test_data): @common.parametrize("test_data", LeakyReLU.test_data) @common.SkipIfNoModelConverter -def test_leaky_relu_vgf_FP(test_data): +def test_leaky_relu_vgf_no_quant(test_data): data, slope = test_data() pipeline = VgfPipeline[input_t1]( LeakyReLU(slope), data, [], use_to_edge_transform_and_lower=True, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.add_stage_after( "to_edge_transform_and_lower", pipeline.tester.check_not, [aten_op] @@ -114,14 +112,14 @@ def test_leaky_relu_vgf_FP(test_data): @common.parametrize("test_data", LeakyReLU.test_data) @common.SkipIfNoModelConverter -def test_leaky_relu_vgf_INT(test_data): +def test_leaky_relu_vgf_quant(test_data): data, slope = test_data() pipeline = VgfPipeline[input_t1]( LeakyReLU(slope), data, [], use_to_edge_transform_and_lower=True, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.add_stage_after("quantize", pipeline.tester.check_not, [aten_op]) pipeline.run() diff --git a/backends/arm/test/ops/test_linalg_vector_norm.py b/backends/arm/test/ops/test_linalg_vector_norm.py index 1777cffb0a7..2723479869e 100644 --- a/backends/arm/test/ops/test_linalg_vector_norm.py +++ b/backends/arm/test/ops/test_linalg_vector_norm.py @@ -103,7 +103,6 @@ def test_vector_norm_u55_INT_fvp(test_module): input_tensor, aten_op_q_decomposed_q, exir_op_q_decomposed, - run_on_fvp=True, symmetric_io_quantization=True, ) pipeline.pop_stage("check_not.exir") @@ -121,7 +120,6 @@ def test_vector_norm_u85_INT_fvp(test_module): input_tensor, aten_op_q_decomposed_q, exir_op_q_decomposed, - run_on_fvp=True, symmetric_io_quantization=True, ) pipeline.pop_stage("check_not.exir") @@ -130,7 +128,7 @@ def test_vector_norm_u85_INT_fvp(test_module): @common.parametrize("test_module", test_modules) @common.SkipIfNoModelConverter -def test_vector_norm_vgf_FP(test_module): +def test_vector_norm_vgf_no_quant(test_module): model, input_tensor = test_module # FP VGF aten_op = "torch.ops.aten.linalg_vector_norm.default" @@ -140,14 +138,14 @@ def test_vector_norm_vgf_FP(test_module): input_tensor, aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_module", test_modules) @common.SkipIfNoModelConverter -def test_vector_norm_vgf_INT(test_module): +def test_vector_norm_vgf_quant(test_module): model, input_tensor = test_module # Should not found this op exir_op = "executorch_exir_dialects_edge__ops_aten_linalg_vector_norm_default" @@ -157,6 +155,6 @@ def test_vector_norm_vgf_INT(test_module): input_tensor, aten_op_q_decomposed_q, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_linear.py b/backends/arm/test/ops/test_linear.py index e5d00c83e9f..7f79f7c586b 100644 --- a/backends/arm/test/ops/test_linear.py +++ b/backends/arm/test/ops/test_linear.py @@ -8,7 +8,6 @@ from typing import Tuple -import pytest import torch from executorch.backends.arm.quantizer.arm_quantizer import ( get_symmetric_a16w8_quantization_config, @@ -181,7 +180,6 @@ def test_linear_u55_INT(test_data: torch.Tensor): (test_data,), aten_op, exir_ops=[], - run_on_fvp=True, per_channel_quantization=per_channel_quantization, use_to_edge_transform_and_lower=True, ).run() @@ -204,7 +202,6 @@ def test_linear_u85_INT(test_data: 
torch.Tensor): (test_data,), aten_op, exir_ops=[], - run_on_fvp=True, per_channel_quantization=per_channel_quantization, use_to_edge_transform_and_lower=True, ).run() @@ -212,39 +209,31 @@ def test_linear_u85_INT(test_data: torch.Tensor): @common.parametrize("test_data", test_data_rank1_FP | test_data_rank4_FP) @common.SkipIfNoModelConverter -def test_linear_vgf_FP(test_data: torch.Tensor): +def test_linear_vgf_no_quant(test_data: torch.Tensor): test_data, out_features, has_bias = test_data() in_features = test_data.shape[-1] pipeline = VgfPipeline[input_t1]( - Linear( - in_features=in_features, - out_features=out_features, - bias=has_bias, - ), + Linear(in_features=in_features, out_features=out_features, bias=has_bias), (test_data,), aten_op=aten_op, exir_op=[], - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT) @common.SkipIfNoModelConverter -def test_linear_vgf_INT(test_data: torch.Tensor): +def test_linear_vgf_quant(test_data: torch.Tensor): test_data, out_features, has_bias, per_channel_quantization = test_data() in_features = test_data.shape[-1] pipeline = VgfPipeline[input_t1]( - Linear( - in_features=in_features, - out_features=out_features, - bias=has_bias, - ), + Linear(in_features=in_features, out_features=out_features, bias=has_bias), (test_data,), aten_op=aten_op, exir_op=[], - tosa_version="TOSA-1.0+INT", per_channel_quantization=per_channel_quantization, + quantize=True, ) pipeline.run() @@ -276,10 +265,10 @@ def get_symmetric_a16w8_linear_quantizer( ) -@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT) -@pytest.mark.xfail( - reason="missing int16 linear ops support; fails at TOSA reference model run with Invalid TOSA graph" -) +test_data_all_16a8w = test_data_rank1_INT | test_data_rank4_INT + + +@common.parametrize("test_data", test_data_all_16a8w) def test_linear_16a8w_tosa_INT(test_data: torch.Tensor): """Test linear operation with 16A8W quantization (16-bit activations, 8-bit weights)""" test_data, out_features, has_bias, per_channel_quantization = test_data() @@ -308,3 +297,81 @@ def test_linear_16a8w_tosa_INT(test_data: torch.Tensor): ) # Run the pipeline pipeline.run() + + +x_fails = {} +x_skips = {} + +for test_name in [ + "model_linear_rank4_zeros", + "model_linear_rank4_negative_ones", + "model_linear_rank4_negative_large_rand", +]: + for set_per_chan in ["True", "False"]: + key = test_name + ",per_channel_quant={}".format(set_per_chan) + reason = ( + "MLETORCH-1452: AssertionError: Output 0 does not match reference output." + ) + x_fails[key] = reason + # TODO: Check why xfail doesn't work for this buck target. 
In the interim rely on skip + x_skips[key] = reason + + +@common.parametrize("test_data", test_data_all_16a8w, xfails=x_fails, skips=x_skips) +@common.XfailIfNoCorstone300 +def test_linear_16a8w_u55_INT16(test_data: torch.Tensor): + """Test linear operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)""" + test_data, out_features, has_bias, per_channel_quantization = test_data() + in_features = test_data.shape[-1] + + pipeline = EthosU55PipelineINT[input_t1]( + Linear( + in_features=in_features, + out_features=out_features, + bias=has_bias, + ), + (test_data,), + aten_op, + exir_ops=[], + per_channel_quantization=per_channel_quantization, + use_to_edge_transform_and_lower=True, + run_on_fvp=True, + ) + + pipeline.change_args( + "quantize", + get_symmetric_a16w8_linear_quantizer( + per_channel_quantization=per_channel_quantization + ), + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_all_16a8w) +@common.XfailIfNoCorstone320 +def test_linear_16a8w_u85_INT16(test_data: torch.Tensor): + """Test linear operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + test_data, out_features, has_bias, per_channel_quantization = test_data() + in_features = test_data.shape[-1] + + pipeline = EthosU85PipelineINT[input_t1]( + Linear( + in_features=in_features, + out_features=out_features, + bias=has_bias, + ), + (test_data,), + aten_op, + exir_ops=[], + per_channel_quantization=per_channel_quantization, + use_to_edge_transform_and_lower=True, + run_on_fvp=True, + ) + + pipeline.change_args( + "quantize", + get_symmetric_a16w8_linear_quantizer( + per_channel_quantization=per_channel_quantization + ), + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_log.py b/backends/arm/test/ops/test_log.py index 1ed5c57f1ab..3f4bfcdb17f 100644 --- a/backends/arm/test/ops/test_log.py +++ b/backends/arm/test/ops/test_log.py @@ -60,7 +60,6 @@ def test_log_u55_INT(test_data: input_t1): (test_data(),), aten_op, exir_op, - run_on_fvp=True, ).run() @@ -72,31 +71,30 @@ def test_log_u85_INT(test_data: input_t1): (test_data(),), aten_op, exir_op, - run_on_fvp=True, ).run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_log_vgf_FP(test_data: input_t1): +def test_log_vgf_no_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( Log(), (test_data(),), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_log_vgf_INT(test_data: input_t1): +def test_log_vgf_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( Log(), (test_data(),), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_logical.py b/backends/arm/test/ops/test_logical.py index bb7c5773342..a2a82793170 100644 --- a/backends/arm/test/ops/test_logical.py +++ b/backends/arm/test/ops/test_logical.py @@ -86,9 +86,6 @@ def forward(self, tensor: torch.Tensor): ################# -xfails = {"rand_rank4": "MLBEDSW-11031: Output diff on u85 bool transpose."} - - @common.parametrize("test_data", And().test_data) def test_logical_and_tosa_FP(test_data: input_t2): pipeline = TosaPipelineFP[input_t2]( @@ -114,8 +111,6 @@ def test_logical_and_tosa_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -132,7 +127,7 @@ def test_logical_and_u55_INT_not_delegated(test_data: input_t2): 
pipeline.run() -@common.parametrize("test_data", And().test_data, xfails=xfails) +@common.parametrize("test_data", And().test_data) @common.XfailIfNoCorstone320 def test_logical_and_u85_INT(test_data: input_t2): pipeline = EthosU85PipelineINT[input_t2]( @@ -140,41 +135,36 @@ def test_logical_and_u85_INT(test_data: input_t2): test_data(), And().aten_op, And().exir_op, - run_on_fvp=True, atol=0, rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @common.parametrize("test_data", And().test_data) @common.SkipIfNoModelConverter -def test_logical_and_vgf_FP(test_data: input_t2): +def test_logical_and_vgf_no_quant(test_data: input_t2): pipeline = VgfPipeline[input_t2]( And(), test_data(), And().aten_op, And().exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", And().test_data) @common.SkipIfNoModelConverter -def test_logical_and_vgf_INT(test_data: input_t2): +def test_logical_and_vgf_quant(test_data: input_t2): pipeline = VgfPipeline[input_t2]( And(), test_data(), And().aten_op, And().exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -208,8 +198,6 @@ def test_logical_xor_tosa_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -226,7 +214,7 @@ def test_logical_xor_u55_INT_not_delegated(test_data: input_t2): pipeline.run() -@common.parametrize("test_data", Xor().test_data, xfails=xfails) +@common.parametrize("test_data", Xor().test_data) @common.XfailIfNoCorstone320 def test_logical_xor_u85_INT(test_data: input_t2): pipeline = EthosU85PipelineINT[input_t2]( @@ -234,41 +222,36 @@ def test_logical_xor_u85_INT(test_data: input_t2): test_data(), Xor().aten_op, Xor().exir_op, - run_on_fvp=True, atol=0, rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @common.parametrize("test_data", Xor().test_data) @common.SkipIfNoModelConverter -def test_logical_xor_vgf_FP(test_data: input_t2): +def test_logical_xor_vgf_no_quant(test_data: input_t2): pipeline = VgfPipeline[input_t2]( Xor(), test_data(), Xor().aten_op, Xor().exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", Xor().test_data) @common.SkipIfNoModelConverter -def test_logical_xor_vgf_INT(test_data: input_t2): +def test_logical_xor_vgf_quant(test_data: input_t2): pipeline = VgfPipeline[input_t2]( Xor(), test_data(), Xor().aten_op, Xor().exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -302,8 +285,6 @@ def test_logical_or_tosa_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -320,7 +301,7 @@ def test_logical_or_u55_INT_not_delegated(test_data: input_t2): pipeline.run() -@common.parametrize("test_data", Or().test_data, xfails=xfails) +@common.parametrize("test_data", Or().test_data) @common.XfailIfNoCorstone320 def test_logical_or_u85_INT(test_data: input_t2): pipeline = EthosU85PipelineINT[input_t2]( @@ -328,41 +309,36 @@ def test_logical_or_u85_INT(test_data: input_t2): test_data(), Or().aten_op, Or().exir_op, - run_on_fvp=True, atol=0, rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() 
@common.parametrize("test_data", Or().test_data) @common.SkipIfNoModelConverter -def test_logical_or_vgf_FP(test_data: input_t2): +def test_logical_or_vgf_no_quant(test_data: input_t2): pipeline = VgfPipeline[input_t2]( Or(), test_data(), Or().aten_op, Or().exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", Or().test_data) @common.SkipIfNoModelConverter -def test_logical_or_vgf_INT(test_data: input_t2): +def test_logical_or_vgf_quant(test_data: input_t2): pipeline = VgfPipeline[input_t2]( Or(), test_data(), Or().aten_op, Or().exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -396,8 +372,6 @@ def test_logical_not_tosa_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -414,7 +388,7 @@ def test_logical_not_u55_INT_not_delegated(test_data: input_t2): pipeline.run() -@common.parametrize("test_data", Not().test_data, xfails=xfails) +@common.parametrize("test_data", Not().test_data) @common.XfailIfNoCorstone320 def test_logical_not_u85_INT(test_data: input_t2): pipeline = EthosU85PipelineINT[input_t2]( @@ -422,39 +396,34 @@ def test_logical_not_u85_INT(test_data: input_t2): test_data(), Not().aten_op, Not().exir_op, - run_on_fvp=True, atol=0, rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @common.parametrize("test_data", Not().test_data) @common.SkipIfNoModelConverter -def test_logical_not_vgf_FP(test_data: input_t2): +def test_logical_not_vgf_no_quant(test_data: input_t2): pipeline = VgfPipeline[input_t2]( Not(), test_data(), Not().aten_op, Not().exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", Not().test_data) @common.SkipIfNoModelConverter -def test_logical_not_vgf_INT(test_data: input_t2): +def test_logical_not_vgf_quant(test_data: input_t2): pipeline = VgfPipeline[input_t2]( Not(), test_data(), Not().aten_op, Not().exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/ops/test_logit.py b/backends/arm/test/ops/test_logit.py index 8915c151bb9..b628504c716 100644 --- a/backends/arm/test/ops/test_logit.py +++ b/backends/arm/test/ops/test_logit.py @@ -92,13 +92,13 @@ def test_logit_u85_INT(test_data: Tuple): test_data_suite, ) @common.SkipIfNoModelConverter -def test_logit_vgf_FP(test_data: input_t1): +def test_logit_vgf_no_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( Logit(), (*test_data,), [], [], - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @@ -108,12 +108,12 @@ def test_logit_vgf_FP(test_data: input_t1): test_data_suite, ) @common.SkipIfNoModelConverter -def test_logit_vgf_INT(test_data: input_t1): +def test_logit_vgf_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( Logit(), (*test_data,), [], [], - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_logsoftmax.py b/backends/arm/test/ops/test_logsoftmax.py index 791069aa4b0..8d090b660ae 100644 --- a/backends/arm/test/ops/test_logsoftmax.py +++ b/backends/arm/test/ops/test_logsoftmax.py @@ -64,13 +64,7 @@ def test_log_softmax_tosa_INT(test_data): pipeline.run() -@common.parametrize( - "test_data", - LogSoftmax.test_data, - xfails={ - "randn_neg_dim": 
"MLBEDSW-11032: ILLEGAL_OFM_BASE error: Base addresses must be aligned to brick depth on u55." - }, -) +@common.parametrize("test_data", LogSoftmax.test_data) @common.XfailIfNoCorstone300() def test_log_softmax_u55_INT(test_data): data, dim = test_data() @@ -78,7 +72,6 @@ def test_log_softmax_u55_INT(test_data): LogSoftmax(dim), data, [], - run_on_fvp=True, ) pipeline.add_stage_after("quantize", pipeline.tester.check_not, [aten_op]) pipeline.change_args("run_method_and_compare_outputs", qtol=1) @@ -93,7 +86,6 @@ def test_log_softmax_u85_INT(test_data): LogSoftmax(dim), data, [], - run_on_fvp=True, ) pipeline.add_stage_after("quantize", pipeline.tester.check_not, [aten_op]) pipeline.change_args("run_method_and_compare_outputs", qtol=1) @@ -102,10 +94,14 @@ def test_log_softmax_u85_INT(test_data): @common.parametrize("test_data", LogSoftmax.test_data) @common.SkipIfNoModelConverter -def test_log_softmax_vgf_FP(test_data): +def test_log_softmax_vgf_no_quant(test_data): data, dim = test_data() pipeline = VgfPipeline[input_t1]( - LogSoftmax(dim), data, [], [], tosa_version="TOSA-1.0+FP" + LogSoftmax(dim), + data, + [], + [], + quantize=False, ) pipeline.add_stage_after( "to_edge_transform_and_lower", pipeline.tester.check_not, [aten_op] @@ -115,16 +111,14 @@ def test_log_softmax_vgf_FP(test_data): @common.parametrize("test_data", LogSoftmax.test_data) @common.SkipIfNoModelConverter -def test_log_softmax_vgf_INT(test_data): +def test_log_softmax_vgf_quant(test_data): data, dim = test_data() pipeline = VgfPipeline[input_t1]( LogSoftmax(dim), data, [], [], - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.add_stage_after("quantize", pipeline.tester.check_not, [aten_op]) - # TODO: MLETORCH-1136 Change args of run_method_and_compare_outputs of the vgf tests - # pipeline.change_args("run_method_and_compare_outputs", qtol=1) pipeline.run() diff --git a/backends/arm/test/ops/test_lshift.py b/backends/arm/test/ops/test_lshift.py index bab364a4528..878b18fc805 100644 --- a/backends/arm/test/ops/test_lshift.py +++ b/backends/arm/test/ops/test_lshift.py @@ -91,7 +91,6 @@ def test_bitwise_left_shift_tensor_tosa_INT_scalar(test_data): LshiftScalar.torch_op_INT, LshiftScalar.exir_op, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -103,9 +102,7 @@ def test_bitwise_left_shift_tensor_u55_INT_scalar(test_data): test_data, LshiftScalar.torch_op_INT, LshiftScalar.exir_op, - run_on_fvp=True, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -117,36 +114,33 @@ def test_bitwise_left_shift_tensor_u85_INT_scalar(test_data): test_data, LshiftScalar.torch_op_INT, LshiftScalar.exir_op, - run_on_fvp=True, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @common.parametrize("test_data", LshiftScalar.test_data) @common.SkipIfNoModelConverter -def test_bitwise_left_shift_scalar_vgf_FP_scalar(test_data: scalar_input_t): +def test_bitwise_left_shift_scalar_scalar_vgf_no_quant(test_data: scalar_input_t): pipeline = VgfPipeline[scalar_input_t]( LshiftScalar(), test_data, LshiftScalar.torch_op_FP, LshiftScalar.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", LshiftScalar.test_data) @common.SkipIfNoModelConverter -def test_bitwise_left_shift_tensor_vgf_INT_scalar(test_data: scalar_input_t): +def test_bitwise_left_shift_tensor_scalar_vgf_quant(test_data: scalar_input_t): pipeline = VgfPipeline[scalar_input_t]( LshiftScalar(), test_data, LshiftScalar.torch_op_INT, LshiftScalar.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, 
) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -173,60 +167,54 @@ def test_bitwise_left_shift_tensor_tosa_INT(test_data): LshiftTensor.torch_op, LshiftTensor.exir_op, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @common.parametrize("test_data", LshiftTensor.test_data) -@XfailIfNoCorstone300 +@common.XfailIfNoCorstone300 def test_bitwise_left_shift_tensor_u55_INT(test_data): pipeline = EthosU55PipelineINT[scalar_input_t]( LshiftTensor(), test_data, LshiftTensor.torch_op, LshiftTensor.exir_op, - run_on_fvp=True, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @common.parametrize("test_data", LshiftTensor.test_data) -@XfailIfNoCorstone320 +@common.XfailIfNoCorstone320 def test_bitwise_left_shift_tensor_u85_INT(test_data): pipeline = EthosU85PipelineINT[scalar_input_t]( LshiftTensor(), test_data, LshiftTensor.torch_op, LshiftTensor.exir_op, - run_on_fvp=True, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @common.parametrize("test_data", LshiftTensor.test_data) @common.SkipIfNoModelConverter -def test_bitwise_left_shift_tensor_vgf_FP(test_data: tensor_input_t): +def test_bitwise_left_shift_tensor_vgf_no_quant(test_data: tensor_input_t): pipeline = VgfPipeline[tensor_input_t]( LshiftTensor(), test_data, LshiftTensor.torch_op, LshiftTensor.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", LshiftTensor.test_data) @common.SkipIfNoModelConverter -def test_bitwise_left_shift_tensor_vgf_INT(test_data: tensor_input_t): +def test_bitwise_left_shift_tensor_vgf_quant(test_data: tensor_input_t): pipeline = VgfPipeline[tensor_input_t]( LshiftTensor(), test_data, LshiftTensor.torch_op, LshiftTensor.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/ops/test_lt.py b/backends/arm/test/ops/test_lt.py index 98d0298b195..e260cd5b75d 100644 --- a/backends/arm/test/ops/test_lt.py +++ b/backends/arm/test/ops/test_lt.py @@ -122,6 +122,30 @@ def test_lt_scalar_tosa_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_data_tensor) +def test_lt_tensor_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + LessThan.aten_op_tensor, + LessThan.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +def test_lt_scalar_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + LessThan.aten_op_tensor, + LessThan.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + @common.parametrize("test_module", test_data_tensor) @common.XfailIfNoCorstone300 def test_lt_tensor_u55_INT_not_delegated(test_module): @@ -162,7 +186,6 @@ def test_lt_tensor_u85_INT(test_module): test_module().get_inputs(), LessThan.aten_op_tensor, LessThan.exir_op, - run_on_fvp=True, ) pipeline.run() @@ -178,58 +201,93 @@ def test_lt_scalar_u85_INT(test_module): test_module().get_inputs(), LessThan.aten_op_tensor, LessThan.exir_op, - run_on_fvp=True, + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_tensor) +@common.XfailIfNoCorstone320 +def test_lt_tensor_16a8w_u85_INT16(test_module): + """Test lt operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + LessThan.aten_op_tensor, + 
LessThan.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +@common.XfailIfNoCorstone320 +def test_lt_scalar_16a8w_u85_INT16(test_module): + """Test lt operation (scalar) with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + LessThan.aten_op_tensor, + LessThan.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, ) pipeline.run() @common.parametrize("test_module", test_data_tensor) @common.SkipIfNoModelConverter -def test_lt_tensor_vgf_FP(test_module): +def test_lt_tensor_vgf_no_quant(test_module): pipeline = VgfPipeline[input_t]( test_module(), test_module().get_inputs(), LessThan.aten_op_tensor, LessThan.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_module", test_data_scalar) @common.SkipIfNoModelConverter -def test_lt_scalar_vgf_FP(test_module): +def test_lt_scalar_vgf_no_quant(test_module): pipeline = VgfPipeline[input_t]( test_module(), test_module().get_inputs(), LessThan.aten_op_scalar, LessThan.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_module", test_data_tensor) @common.SkipIfNoModelConverter -def test_lt_tensor_vgf_INT(test_module): +def test_lt_tensor_vgf_quant(test_module): pipeline = VgfPipeline[input_t]( test_module(), test_module().get_inputs(), LessThan.aten_op_tensor, LessThan.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @common.parametrize("test_module", test_data_scalar) @common.SkipIfNoModelConverter -def test_lt_scalar_vgf_INT(test_module): +def test_lt_scalar_vgf_quant(test_module): pipeline = VgfPipeline[input_t]( test_module(), test_module().get_inputs(), LessThan.aten_op_tensor, LessThan.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_masked_fill.py b/backends/arm/test/ops/test_masked_fill.py index 3aab19925ec..2704fa53257 100644 --- a/backends/arm/test/ops/test_masked_fill.py +++ b/backends/arm/test/ops/test_masked_fill.py @@ -147,19 +147,25 @@ def test_masked_fill_scalar_u85_INT(test_module): @common.parametrize("test_module", test_modules) @common.SkipIfNoModelConverter -def test_masked_fill_scalar_vgf_FP(test_module): +def test_masked_fill_scalar_vgf_no_quant(test_module): module, inputs = test_module() pipeline = VgfPipeline[input_t]( - module, inputs, aten_op=[], tosa_version="TOSA-1.0+FP" + module, + inputs, + aten_op=[], + quantize=False, ) pipeline.run() @common.parametrize("test_module", test_modules) @common.SkipIfNoModelConverter -def test_masked_fill_scalar_vgf_INT(test_module): +def test_masked_fill_scalar_vgf_quant(test_module): module, inputs = test_module() pipeline = VgfPipeline[input_t]( - module, inputs, aten_op=[], tosa_version="TOSA-1.0+INT" + module, + inputs, + aten_op=[], + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_matmul.py b/backends/arm/test/ops/test_matmul.py index d1a21684325..489078a6dfa 100644 --- a/backends/arm/test/ops/test_matmul.py +++ b/backends/arm/test/ops/test_matmul.py @@ -22,6 +22,7 @@ class MatMul(torch.nn.Module): test_data_generators = { + "rand_rand_2d": lambda: (torch.rand(5, 5), torch.rand(5, 2)), "rand_rand_3d": 
lambda: (torch.rand(2, 3, 5), torch.rand(2, 5, 2)), "rand_rand_4d": lambda: (torch.rand(1, 2, 3, 5), torch.rand(1, 2, 5, 2)), } @@ -32,6 +33,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): class MatMulSingleInput(torch.nn.Module): test_data_generators = { + "rand_2d": lambda: (torch.rand(5, 5),), "rand_3d": lambda: (torch.rand(2, 5, 5),), "rand_4d": lambda: (torch.rand(1, 2, 5, 5),), } @@ -42,6 +44,11 @@ def forward(self, x: torch.Tensor): class MatMulCombo(torch.nn.Module): test_data_generators = { + "rand_rand_rand_2d": lambda: ( + torch.rand(5, 5), + torch.rand(5, 2), + torch.rand(2, 5), + ), "rand_rand_rand_3d": lambda: ( torch.rand(2, 5, 5), torch.rand(2, 5, 2), @@ -122,13 +129,18 @@ def test_matmul_u55_INT(test_data: input_t1): test_data(), aten_op_mm, exir_op_mm, - run_on_fvp=True, use_to_edge_transform_and_lower=True, ) pipeline.run() -@common.parametrize("test_data", MatMulSingleInput.test_data_generators) +@common.parametrize( + "test_data", + MatMulSingleInput.test_data_generators, + xfails={ + "rand_4d": "MLBEDSW-11228: Matmul output diff between 1 input vs 2 identical inputs" + }, +) @common.XfailIfNoCorstone300 def test_matmul_single_input_u55_INT(test_data: input_t1): pipeline = EthosU55PipelineINT[input_t1]( @@ -136,13 +148,18 @@ def test_matmul_single_input_u55_INT(test_data: input_t1): test_data(), aten_op_mm, exir_op_mm, - run_on_fvp=True, use_to_edge_transform_and_lower=True, ) pipeline.run() -@common.parametrize("test_data", MatMulCombo.test_data_generators) +@common.parametrize( + "test_data", + MatMulCombo.test_data_generators, + xfails={ + "rand_rand_rand_4d": "MLBEDSW-11228: Matmul output diff between 1 input vs 2 identical inputs" + }, +) @common.XfailIfNoCorstone300 def test_matmul_combo_u55_INT(test_data: input_t1): pipeline = EthosU55PipelineINT[input_t1]( @@ -150,7 +167,6 @@ def test_matmul_combo_u55_INT(test_data: input_t1): test_data(), aten_op_mm, exir_op_mm, - run_on_fvp=True, use_to_edge_transform_and_lower=True, ) pipeline.run() @@ -164,13 +180,18 @@ def test_matmul_u85_INT(test_data: input_t1): test_data(), aten_op_mm, exir_op_mm, - run_on_fvp=True, use_to_edge_transform_and_lower=True, ) pipeline.run() -@common.parametrize("test_data", MatMulSingleInput.test_data_generators) +@common.parametrize( + "test_data", + MatMulSingleInput.test_data_generators, + xfails={ + "rand_4d": "MLBEDSW-11228: Matmul output diff between 1 input vs 2 identical inputs" + }, +) @common.XfailIfNoCorstone320 def test_matmul_single_input_u85_INT(test_data: input_t1): pipeline = EthosU85PipelineINT[input_t1]( @@ -178,13 +199,18 @@ def test_matmul_single_input_u85_INT(test_data: input_t1): test_data(), aten_op_mm, exir_op_mm, - run_on_fvp=True, use_to_edge_transform_and_lower=True, ) pipeline.run() -@common.parametrize("test_data", MatMulCombo.test_data_generators) +@common.parametrize( + "test_data", + MatMulCombo.test_data_generators, + xfails={ + "rand_rand_rand_4d": "MLBEDSW-11228: Matmul output diff between 1 input vs 2 identical inputs" + }, +) @common.XfailIfNoCorstone320 def test_matmul_combo_u85_INT(test_data: input_t1): pipeline = EthosU85PipelineINT[input_t1]( @@ -192,7 +218,6 @@ def test_matmul_combo_u85_INT(test_data: input_t1): test_data(), aten_op_mm, exir_op_mm, - run_on_fvp=True, use_to_edge_transform_and_lower=True, ) pipeline.run() @@ -200,69 +225,77 @@ def test_matmul_combo_u85_INT(test_data: input_t1): @common.parametrize("test_data", MatMul.test_data_generators) @common.SkipIfNoModelConverter -def test_matmul_vgf_FP(test_data: input_t1): +def 
test_matmul_vgf_no_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( - MatMul(), test_data(), aten_op_mm, exir_op_mm, tosa_version="TOSA-1.0+FP" + MatMul(), + test_data(), + aten_op_mm, + exir_op_mm, + quantize=False, ) pipeline.run() @common.parametrize("test_data", MatMulSingleInput.test_data_generators) @common.SkipIfNoModelConverter -def test_matmul_single_input_vgf_FP(test_data: input_t1): +def test_matmul_single_input_vgf_no_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( MatMulSingleInput(), test_data(), aten_op_mm, exir_op_mm, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", MatMulCombo.test_data_generators) @common.SkipIfNoModelConverter -def test_matmul_combo_vgf_FP(test_data: input_t1): +def test_matmul_combo_vgf_no_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( - MatMulCombo(), test_data(), aten_op_mm, exir_op_mm, tosa_version="TOSA-1.0+FP" + MatMulCombo(), + test_data(), + aten_op_mm, + exir_op_mm, + quantize=False, ) pipeline.run() @common.parametrize("test_data", MatMul.test_data_generators) @common.SkipIfNoModelConverter -def test_matmul_vgf_INT(test_data: input_t1): +def test_matmul_vgf_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( MatMul(), test_data(), aten_op_mm, exir_op_mm, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @common.parametrize("test_data", MatMulSingleInput.test_data_generators) @common.SkipIfNoModelConverter -def test_matmul_single_input_vgf_INT(test_data: input_t1): +def test_matmul_single_input_vgf_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( MatMulSingleInput(), test_data(), aten_op_mm, exir_op_mm, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @common.parametrize("test_data", MatMulCombo.test_data_generators) @common.SkipIfNoModelConverter -def test_matmul_combo_vgf_INT(test_data: input_t1): +def test_matmul_combo_vgf_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( MatMulCombo(), test_data(), aten_op_mm, exir_op_mm, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_max_pool.py b/backends/arm/test/ops/test_max_pool.py index 7db56311837..5c780a9bcb1 100644 --- a/backends/arm/test/ops/test_max_pool.py +++ b/backends/arm/test/ops/test_max_pool.py @@ -133,6 +133,20 @@ def test_max_pool2d_tosa_INT(test_data: torch.Tensor): pipeline.run() +@common.parametrize("test_data", test_data_suite) +def test_max_pool2d_tosa_INT_a16w8(test_data: torch.Tensor): + """Test max_pool2d operation with int16 I/O quantization for TOSA INT.""" + test_data, model_params = test_data() + pipeline = TosaPipelineINT[input_t1]( + MaxPool2d(*model_params), + (test_data,), + aten_op, + exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + @common.parametrize("test_data", test_data_suite) @common.XfailIfNoCorstone300 def test_max_pool2d_u55_INT(test_data: torch.Tensor): @@ -142,10 +156,26 @@ def test_max_pool2d_u55_INT(test_data: torch.Tensor): (test_data,), aten_op, exir_ops=[], - run_on_fvp=True, ).run() +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone300 +def test_max_pool2d_16a8w_u55_INT16(test_data: torch.Tensor): + """Test max_pool2d with 16A8W quantization on U55 (16-bit activations, 8-bit weights)""" + test_data, model_params = test_data() + pipeline = EthosU55PipelineINT[input_t1]( + MaxPool2d(*model_params), + (test_data,), + aten_op, + exir_ops=[], + per_channel_quantization=False, + a16w8_quantization=True, + 
use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + @common.parametrize("test_data", test_data_suite) @common.XfailIfNoCorstone320 def test_max_pool2d_u85_INT(test_data: torch.Tensor): @@ -155,10 +185,26 @@ def test_max_pool2d_u85_INT(test_data: torch.Tensor): (test_data,), aten_op, exir_ops=[], - run_on_fvp=True, ).run() +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone320 +def test_max_pool2d_16a8w_u85_INT16(test_data: torch.Tensor): + """Test max_pool2d with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + test_data, model_params = test_data() + pipeline = EthosU85PipelineINT[input_t1]( + MaxPool2d(*model_params), + (test_data,), + aten_op, + exir_ops=[], + per_channel_quantization=False, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + reject_data_suite = { "reject_1": lambda: (MaxPool2d(1, 4, 0), torch.rand(1, 10, 10, 10)), "reject_2": lambda: (MaxPool2d((1, 257), 1, 0), torch.rand(1, 16, 5, 300)), @@ -223,35 +269,35 @@ def test_max_pool2d_tosa_INT_dilation(test_data): # VGF tests @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_max_pool2d_vgf_FP(test_data: torch.Tensor): +def test_max_pool2d_vgf_no_quant(test_data: torch.Tensor): test_data, model_params = test_data() pipeline = VgfPipeline[input_t1]( MaxPool2d(*model_params), (test_data,), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_max_pool2d_vgf_INT(test_data: torch.Tensor): +def test_max_pool2d_vgf_quant(test_data: torch.Tensor): test_data, model_params = test_data() pipeline = VgfPipeline[input_t1]( MaxPool2d(*model_params), (test_data,), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @common.parametrize("test_data", dilation_test_data) @common.SkipIfNoModelConverter -def test_max_pool2d_vgf_FP_dilation(test_data: torch.Tensor): +def test_max_pool2d_dilation_vgf_no_quant(test_data: torch.Tensor): """ VGF FP pipeline with dilation > 1 (and dilation=1 sanity cases). """ @@ -261,14 +307,14 @@ def test_max_pool2d_vgf_FP_dilation(test_data: torch.Tensor): (test_data,), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", dilation_test_data) @common.SkipIfNoModelConverter -def test_max_pool2d_vgf_INT_dilation(test_data: torch.Tensor): +def test_max_pool2d_dilation_vgf_quant(test_data: torch.Tensor): """ VGF INT pipeline with dilation > 1 (and dilation=1 sanity cases). 
""" @@ -278,6 +324,6 @@ def test_max_pool2d_vgf_INT_dilation(test_data: torch.Tensor): (test_data,), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_maximum.py b/backends/arm/test/ops/test_maximum.py index eb0d4b86efc..e213842494f 100644 --- a/backends/arm/test/ops/test_maximum.py +++ b/backends/arm/test/ops/test_maximum.py @@ -61,7 +61,6 @@ def test_maximum_u55_INT(test_data: Tuple): Maximum(), test_data(), aten_op, - run_on_fvp=True, ).run() @@ -72,29 +71,28 @@ def test_maximum_u85_INT(test_data: Tuple): Maximum(), test_data(), aten_op, - run_on_fvp=True, ).run() @common.parametrize("test_data", Maximum.test_parameters) @common.SkipIfNoModelConverter -def test_maximum_vgf_FP(test_data: Tuple): +def test_maximum_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[test_t]( Maximum(), test_data(), aten_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", Maximum.test_parameters) @common.SkipIfNoModelConverter -def test_maximum_vgf_INT(test_data: Tuple): +def test_maximum_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[test_t]( Maximum(), test_data(), aten_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_mean_dim.py b/backends/arm/test/ops/test_mean_dim.py index 061e8da14f1..5195d955a1a 100644 --- a/backends/arm/test/ops/test_mean_dim.py +++ b/backends/arm/test/ops/test_mean_dim.py @@ -4,6 +4,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Callable import torch from executorch.backends.arm.test import common @@ -66,7 +67,6 @@ def test_adaptive_avg_pool2d_u55_INT(test_data): test_data(), AdaptiveAveragePool2d.aten_op, AdaptiveAveragePool2d.exir_op, - run_on_fvp=True, symmetric_io_quantization=True, ).run() @@ -79,34 +79,33 @@ def test_adaptive_avg_pool2d_u85_INT(test_data): test_data(), AdaptiveAveragePool2d.aten_op, AdaptiveAveragePool2d.exir_op, - run_on_fvp=True, symmetric_io_quantization=True, ).run() @common.parametrize("test_data", AdaptiveAveragePool2d.test_data_suite) @common.SkipIfNoModelConverter -def test_adaptive_avg_pool2d_vgf_FP(test_data): +def test_adaptive_avg_pool2d_vgf_no_quant(test_data): pipeline = VgfPipeline[input_t]( AdaptiveAveragePool2d(), test_data(), AdaptiveAveragePool2d.aten_op, AdaptiveAveragePool2d.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", AdaptiveAveragePool2d.test_data_suite) @common.SkipIfNoModelConverter -def test_adaptive_avg_pool2d_vgf_INT(test_data): +def test_adaptive_avg_pool2d_vgf_quant(test_data): pipeline = VgfPipeline[input_t]( AdaptiveAveragePool2d(), test_data(), AdaptiveAveragePool2d.aten_op, AdaptiveAveragePool2d.exir_op, symmetric_io_quantization=True, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @@ -115,7 +114,7 @@ class MeanDim(torch.nn.Module): test_data_suite: dict[str, tuple] = { "rank_1_keepdim": lambda: ( torch.rand(7), - (0), + 0, True, ), "rank_2_keepdim": lambda: ( @@ -168,6 +167,11 @@ class MeanDim(torch.nn.Module): (0, 1, 2, 3), True, ), + "rand_none_keepdim": lambda: ( + torch.rand(1, 5, 7, 3), + None, + True, + ), "rank_1": lambda: ( torch.rand(7), (-1), @@ -280,20 +284,11 @@ def test_mean_dim_tosa_INT(test_data): (test_data,), [], # Might be sum, avgpool, or both symmetric_io_quantization=True, - custom_path="MEANDIM", ) pipeline.run() -xfails = { - 
"rank5_01234": "Rank 5 graph input currently not supported in EthosUBackend (passes since CHW are all averaged over so data order does not matter in this case)", - "rank5_234": "Rank 5 graph input currently not supported in EthosUBackend (passes since CHW are all averaged over so data order does not matter in this case)", - "rank5_12": "Rank 5 graph input currently not supported in EthosUBackend", - "rank5_2": "Rank 5 graph input currently not supported in EthosUBackend", -} - - -@common.parametrize("test_data", MeanDim.test_data_suite, xfails=xfails, strict=False) +@common.parametrize("test_data", MeanDim.test_data_suite) @common.XfailIfNoCorstone300 def test_mean_dim_u55_INT(test_data): test_data, dim, keep_dim = test_data() @@ -301,7 +296,6 @@ def test_mean_dim_u55_INT(test_data): MeanDim(dim, keep_dim), (test_data,), [], # Might be sum, avgpool, or both - run_on_fvp=True, symmetric_io_quantization=True, ) pipeline.add_stage_after( @@ -313,7 +307,7 @@ def test_mean_dim_u55_INT(test_data): pipeline.run() -@common.parametrize("test_data", MeanDim.test_data_suite, xfails=xfails, strict=False) +@common.parametrize("test_data", MeanDim.test_data_suite) @common.XfailIfNoCorstone320 def test_mean_dim_u85_INT(test_data): test_data, dim, keep_dim = test_data() @@ -321,7 +315,6 @@ def test_mean_dim_u85_INT(test_data): MeanDim(dim, keep_dim), (test_data,), [], # Might be sum, avgpool, or both - run_on_fvp=True, symmetric_io_quantization=True, ) pipeline.run() @@ -329,27 +322,67 @@ def test_mean_dim_u85_INT(test_data): @common.parametrize("test_data", MeanDim.test_data_suite) @common.SkipIfNoModelConverter -def test_mean_dim_vgf_FP(test_data): +def test_mean_dim_vgf_no_quant(test_data): test_data_val, dim, keep_dim = test_data() pipeline = VgfPipeline[input_t]( MeanDim(dim, keep_dim), (test_data_val,), MeanDim.torch_op, MeanDim.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", MeanDim.test_data_suite) @common.SkipIfNoModelConverter -def test_mean_dim_vgf_INT(test_data): +def test_mean_dim_vgf_quant(test_data): test_data_val, dim, keep_dim = test_data() pipeline = VgfPipeline[input_t]( MeanDim(dim, keep_dim), (test_data_val,), + [], + symmetric_io_quantization=True, + quantize=True, + ) + pipeline.run() + + +mean_input_t = tuple[torch.Tensor, bool] + + +class MeanDefault(torch.nn.Module): + def forward(self, tensor: torch.Tensor, keepdim: bool): + return tensor.mean() + + test_data_suite: dict[str, Callable[[], mean_input_t]] = { + "rank1": lambda: ( + torch.rand( + 1, + ), + False, + ), + "rank2": lambda: (torch.rand(5, 5), True), + "rank4": lambda: (torch.rand(5, 1, 10, 1), False), + } + + +@common.parametrize("test_data", MeanDefault.test_data_suite) +def test_mean_tosa_FP(test_data): + pipeline = TosaPipelineFP[mean_input_t]( + MeanDefault(), + test_data(), + [], # Might be sum, avgpool, or both + ) + pipeline.run() + + +@common.parametrize("test_data", MeanDefault.test_data_suite) +def test_mean_tosa_INT(test_data): + pipeline = TosaPipelineINT[mean_input_t]( + MeanDefault(), + test_data(), [], # Might be sum, avgpool, or both symmetric_io_quantization=True, - tosa_version="TOSA-1.0+INT", ) pipeline.run() diff --git a/backends/arm/test/ops/test_minimum.py b/backends/arm/test/ops/test_minimum.py index 88ae2c2b8da..ff706f7261e 100644 --- a/backends/arm/test/ops/test_minimum.py +++ b/backends/arm/test/ops/test_minimum.py @@ -61,7 +61,6 @@ def test_minimum_u55_INT(test_data: Tuple): Minimum(), test_data(), aten_op, - run_on_fvp=True, ).run() 
@@ -72,24 +71,28 @@ def test_minimum_u85_INT(test_data: Tuple): Minimum(), test_data(), aten_op, - run_on_fvp=True, ).run() @common.parametrize("test_data", Minimum.test_parameters) @common.SkipIfNoModelConverter -def test_minimum_vgf_FP(test_data: test_t): - pipeline = VgfPipeline[test_t](Minimum(), test_data(), aten_op) +def test_minimum_vgf_no_quant(test_data: test_t): + pipeline = VgfPipeline[test_t]( + Minimum(), + test_data(), + aten_op, + quantize=False, + ) pipeline.run() @common.parametrize("test_data", Minimum.test_parameters) @common.SkipIfNoModelConverter -def test_minimum_vgf_INT(test_data: test_t): +def test_minimum_vgf_quant(test_data: test_t): pipeline = VgfPipeline[test_t]( Minimum(), test_data(), aten_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_mm.py b/backends/arm/test/ops/test_mm.py index 1b76baaeff0..6d026888027 100644 --- a/backends/arm/test/ops/test_mm.py +++ b/backends/arm/test/ops/test_mm.py @@ -53,7 +53,6 @@ def test_mm_u55_INT(test_data: Tuple): MM(), test_data(), MM.aten_op, - run_on_fvp=True, ).run() @@ -65,27 +64,30 @@ def test_mm_u85_INT(test_data: Tuple): test_data(), MM.aten_op, MM.exir_op, - run_on_fvp=True, ).run() @common.parametrize("test_data", MM.test_data_generators) @common.SkipIfNoModelConverter -def test_mm_vgf_FP(test_data: Tuple): +def test_mm_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[test_t]( - MM(), test_data(), MM.aten_op, MM.exir_op, tosa_version="TOSA-1.0+FP" + MM(), + test_data(), + MM.aten_op, + MM.exir_op, + quantize=False, ) pipeline.run() @common.parametrize("test_data", MM.test_data_generators) @common.SkipIfNoModelConverter -def test_mm_vgf_INT(test_data: Tuple): +def test_mm_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[test_t]( MM(), test_data(), MM.aten_op, MM.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_mul.py b/backends/arm/test/ops/test_mul.py index b2db55d90fd..0cff1bb1d92 100644 --- a/backends/arm/test/ops/test_mul.py +++ b/backends/arm/test/ops/test_mul.py @@ -8,7 +8,6 @@ from typing import Tuple -import pytest import torch from executorch.backends.arm.quantizer.arm_quantizer import ( get_symmetric_a16w8_quantization_config, @@ -188,7 +187,6 @@ def test_mul_tensor_tosa_INT_int32(test_data: torch.Tensor): aten_op, exir_op=[], ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -200,7 +198,6 @@ def test_mul_tensor_u55_INT(test_data: torch.Tensor): test_data(), aten_op, exir_ops=[], - run_on_fvp=True, ) pipeline.run() @@ -213,7 +210,6 @@ def test_mul_tensor_u85_INT(test_data: torch.Tensor): test_data(), aten_op, exir_ops=[], - run_on_fvp=True, ) pipeline.run() @@ -226,9 +222,7 @@ def test_mul_tensor_u55_INT_int32(test_data: torch.Tensor): test_data(), aten_op, exir_ops=[], - run_on_fvp=True, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -240,9 +234,7 @@ def test_mul_tensor_u85_INT_int32(test_data: torch.Tensor): test_data(), aten_op, exir_ops=[], - run_on_fvp=True, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -256,41 +248,40 @@ def test_mul_tensor_u85_INT_int32(test_data: torch.Tensor): test_data_suite | test_data_suite_2 | test_data_int32_without_broadcasting, ) @common.SkipIfNoModelConverter -def test_mul_tensor_vgf_FP(test_data: torch.Tensor): +def test_mul_tensor_vgf_no_quant(test_data: torch.Tensor): pipeline = VgfPipeline[input_t1]( Mul(), test_data(), aten_op, exir_op=[], - tosa_version="TOSA-1.0+FP", + quantize=False, ) 
pipeline.run() @common.parametrize("test_data", test_data_suite | test_data_suite_2) @common.SkipIfNoModelConverter -def test_mul_tensor_vgf_INT(test_data: torch.Tensor): +def test_mul_tensor_vgf_quant(test_data: torch.Tensor): pipeline = VgfPipeline[input_t1]( Mul(), test_data(), aten_op, exir_op=[], - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @common.parametrize("test_data", test_data_suite_int32) @common.SkipIfNoModelConverter -def test_mul_tensor_vgf_INT_int32(test_data: torch.Tensor): +def test_mul_tensor_int32_vgf_quant(test_data: torch.Tensor): pipeline = VgfPipeline[input_t1]( Mul(), test_data(), aten_op, exir_op=[], - tosa_version="TOSA-1.0+INT", + quantize=True, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -314,9 +305,6 @@ def get_symmetric_a16w8_mul_quantizer(per_channel_quantization=False): @common.parametrize("test_data", test_data_suite) -@pytest.mark.xfail( - reason="missing int16 mul ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13947" -) def test_mul_tensor_16a8w_tosa_INT(test_data: input_t1): """Test mul operation with 16A8W quantization (16-bit activations, 8-bit weights)""" per_channel_quantization = False @@ -342,9 +330,6 @@ def test_mul_tensor_16a8w_tosa_INT(test_data: input_t1): @common.parametrize("test_data", test_data_suite) @common.XfailIfNoCorstone300 -@pytest.mark.xfail( - reason="Vela compilation fails with 'Invalid arguments' for int16 mul operations. See: https://github.com/pytorch/executorch/issues/13947" -) def test_mul_tensor_16a8w_u55_INT16(test_data: input_t1): """Test mul operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)""" per_channel_quantization = False @@ -356,7 +341,6 @@ def test_mul_tensor_16a8w_u55_INT16(test_data: input_t1): exir_ops=[], per_channel_quantization=per_channel_quantization, use_to_edge_transform_and_lower=True, - run_on_fvp=True, ) pipeline.change_args( @@ -370,9 +354,6 @@ def test_mul_tensor_16a8w_u55_INT16(test_data: input_t1): @common.parametrize("test_data", test_data_suite) @common.XfailIfNoCorstone320 -@pytest.mark.xfail( - reason="Vela compilation fails with 'Invalid arguments' for int16 mul operations. See: https://github.com/pytorch/executorch/issues/13947" -) def test_mul_tensor_16a8w_u85_INT16(test_data: input_t1): """Test mul operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" per_channel_quantization = False @@ -384,7 +365,6 @@ def test_mul_tensor_16a8w_u85_INT16(test_data: input_t1): exir_ops=[], per_channel_quantization=per_channel_quantization, use_to_edge_transform_and_lower=True, - run_on_fvp=True, ) pipeline.change_args( diff --git a/backends/arm/test/ops/test_multihead_attention.py b/backends/arm/test/ops/test_multihead_attention.py index 71cf076a157..50dcaae4635 100644 --- a/backends/arm/test/ops/test_multihead_attention.py +++ b/backends/arm/test/ops/test_multihead_attention.py @@ -3,7 +3,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-import pytest import torch from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( @@ -69,7 +68,6 @@ def test_multihead_attention_tosa_INT(test_data): "test_data", test_suite, ) -@pytest.mark.xfail(reason="MLETORCH-1102: Numerical issues on FVP") @common.XfailIfNoCorstone300 def test_multihead_attention_u55_INT(test_data: input_t1): test_data, module = test_data() @@ -79,7 +77,6 @@ def test_multihead_attention_u55_INT(test_data: input_t1): [], [], use_to_edge_transform_and_lower=True, - run_on_fvp=True, # TODO: Per-channel quantization is broken (MLETORCH-1144) per_channel_quantization=False, ) @@ -91,7 +88,6 @@ def test_multihead_attention_u55_INT(test_data: input_t1): "test_data", test_suite, ) -@pytest.mark.xfail(reason="MLETORCH-1102: Numerical issues on FVP") @common.XfailIfNoCorstone320 def test_multihead_attention_u85_INT(test_data: input_t1): test_data, module = test_data() @@ -101,7 +97,6 @@ def test_multihead_attention_u85_INT(test_data: input_t1): [], [], use_to_edge_transform_and_lower=True, - run_on_fvp=True, # TODO: Per-channel quantization is broken (MLETORCH-1144) per_channel_quantization=False, ) @@ -113,14 +108,14 @@ def test_multihead_attention_u85_INT(test_data: input_t1): test_suite, ) @common.SkipIfNoModelConverter -def test_multihead_attention_vgf_FP(test_data: input_t1): +def test_multihead_attention_vgf_no_quant(test_data: input_t1): test_data_vals, module = test_data() pipeline = VgfPipeline[input_t1]( module, (*test_data_vals, *test_data_vals, *test_data_vals), [], [], - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @@ -130,15 +125,14 @@ def test_multihead_attention_vgf_FP(test_data: input_t1): test_suite, ) @common.SkipIfNoModelConverter -def test_multihead_attention_vgf_INT(test_data: input_t1): +def test_multihead_attention_vgf_quant(test_data: input_t1): test_data_vals, module = test_data() pipeline = VgfPipeline[input_t1]( module, (*test_data_vals, *test_data_vals, *test_data_vals), [], [], - tosa_version="TOSA-1.0+INT", - # TODO: Per-channel quantization is broken (MLETORCH-1144) per_channel_quantization=False, + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_ne.py b/backends/arm/test/ops/test_ne.py index 60f07ad9fdd..9fa1b1d96eb 100644 --- a/backends/arm/test/ops/test_ne.py +++ b/backends/arm/test/ops/test_ne.py @@ -159,9 +159,6 @@ def test_ne_scalar_u55_INT(test_module): @common.parametrize( "test_module", test_data_tensor, - xfails={ - "ne_tensor_rank4_randn": "MLETORCH-517: Batch size > 1 not fully supported", - }, strict=False, ) @common.XfailIfNoCorstone320 @@ -171,7 +168,6 @@ def test_ne_tensor_u85_INT(test_module): test_module.get_inputs(), NotEqual.decomposed_ops, NotEqual.decomposed_exir_ops, - run_on_fvp=True, ) pipeline.run() @@ -180,7 +176,6 @@ def test_ne_tensor_u85_INT(test_module): "test_module", test_data_scalar, xfails={ - "ne_scalar_rank4_randn": "MLETORCH-517: Batch size > 1 not fully supported", "ne_scalar_rank4_randn_1batch": "MLETORCH-847: Boolean ne result unstable on U85", }, strict=False, @@ -192,58 +187,57 @@ def test_ne_scalar_u85_INT(test_module): test_module.get_inputs(), NotEqual.decomposed_ops, NotEqual.decomposed_exir_ops, - run_on_fvp=True, ) pipeline.run() @common.parametrize("test_module", test_data_tensor) @common.SkipIfNoModelConverter -def test_ne_tensor_vgf_FP(test_module): +def test_ne_tensor_vgf_no_quant(test_module): pipeline = VgfPipeline[input_t]( test_module, test_module.get_inputs(), NotEqual.aten_op_Tensor, 
NotEqual.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_module", test_data_tensor) @common.SkipIfNoModelConverter -def test_ne_tensor_vgf_INT(test_module): +def test_ne_tensor_vgf_quant(test_module): pipeline = VgfPipeline[input_t]( test_module, test_module.get_inputs(), NotEqual.decomposed_ops, NotEqual.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @common.parametrize("test_module", test_data_scalar) @common.SkipIfNoModelConverter -def test_ne_scalar_vgf_FP(test_module): +def test_ne_scalar_vgf_no_quant(test_module): pipeline = VgfPipeline[input_t]( test_module, test_module.get_inputs(), NotEqual.aten_op_Scalar, NotEqual.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_module", test_data_scalar) @common.SkipIfNoModelConverter -def test_ne_scalar_vgf_INT(test_module): +def test_ne_scalar_vgf_quant(test_module): pipeline = VgfPipeline[input_t]( test_module, test_module.get_inputs(), NotEqual.decomposed_ops, NotEqual.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_neg.py b/backends/arm/test/ops/test_neg.py index 395a4815b62..11d1153a171 100644 --- a/backends/arm/test/ops/test_neg.py +++ b/backends/arm/test/ops/test_neg.py @@ -53,7 +53,10 @@ def test_neg_tosa_INT(test_data: input_t1): @common.XfailIfNoCorstone300 def test_neg_u55_INT(test_data: input_t1): pipeline = EthosU55PipelineINT[input_t1]( - Neg(), test_data, Neg.aten_op, Neg.exir_op, run_on_fvp=True + Neg(), + test_data, + Neg.aten_op, + Neg.exir_op, ) pipeline.run() @@ -62,28 +65,35 @@ def test_neg_u55_INT(test_data: input_t1): @common.XfailIfNoCorstone320 def test_neg_u85_INT(test_data: input_t1): pipeline = EthosU85PipelineINT[input_t1]( - Neg(), test_data, Neg.aten_op, Neg.exir_op, run_on_fvp=True + Neg(), + test_data, + Neg.aten_op, + Neg.exir_op, ) pipeline.run() @common.parametrize("test_data", Neg.test_data) @common.SkipIfNoModelConverter -def test_neg_vgf_FP(test_data: input_t1): +def test_neg_vgf_no_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( - Neg(), test_data, Neg.aten_op, Neg.exir_op, tosa_version="TOSA-1.0+FP" + Neg(), + test_data, + Neg.aten_op, + Neg.exir_op, + quantize=False, ) pipeline.run() @common.parametrize("test_data", Neg.test_data) @common.SkipIfNoModelConverter -def test_neg_vgf_INT(test_data: input_t1): +def test_neg_vgf_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( Neg(), test_data, Neg.aten_op, Neg.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_ones.py b/backends/arm/test/ops/test_ones.py index f4dafca5e10..48c75906579 100644 --- a/backends/arm/test/ops/test_ones.py +++ b/backends/arm/test/ops/test_ones.py @@ -65,7 +65,10 @@ def test_ones_tosa_INT(test_data: test_data_t): input_data(), OnesAdd.aten_op, ) - pipeline.pop_stage("check.quant_nodes") + # Pop the quantization check stage if it exists as no + # quantization nodes will be present for int + fp inputs. + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -79,7 +82,10 @@ def test_ones_u55_INT(test_data: test_data_t): OnesAdd.aten_op, use_to_edge_transform_and_lower=True, ) - pipeline.pop_stage("check.quant_nodes") + # Pop the quantization check stage if it exists as no + # quantization nodes will be present for int + fp inputs. 
+ if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -92,8 +98,11 @@ def test_ones_u85_INT(test_data: test_data_t): input_data(), OnesAdd.aten_op, use_to_edge_transform_and_lower=True, - ).dump_artifact("to_edge_transform_and_lower") - pipeline.pop_stage("check.quant_nodes") + ) + # Pop the quantization check stage if it exists as no + # quantization nodes will be present for int + fp inputs. + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -115,23 +124,29 @@ def test_ones_tosa_INT_not_delegated(test_data: test_data_t): @common.parametrize("test_data", OnesAdd.test_data) @common.SkipIfNoModelConverter -def test_ones_vgf_FP(test_data: test_data_t): +def test_ones_vgf_no_quant(test_data: test_data_t): input_data, init_data = test_data pipeline = VgfPipeline[input_t]( - OnesAdd(*init_data), input_data(), OnesAdd.aten_op, tosa_version="TOSA-1.0+FP" + OnesAdd(*init_data), + input_data(), + OnesAdd.aten_op, + quantize=False, ) pipeline.run() @common.parametrize("test_data", OnesAdd.test_data) @common.SkipIfNoModelConverter -def test_ones_vgf_INT(test_data: test_data_t): +def test_ones_vgf_quant(test_data: test_data_t): input_data, init_data = test_data pipeline = VgfPipeline[input_t]( OnesAdd(*init_data), input_data(), OnesAdd.aten_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) - pipeline.pop_stage("check.quant_nodes") + # Pop the quantization check stage if it exists as no + # quantization nodes will be present for int + fp inputs. + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/ops/test_permute.py b/backends/arm/test/ops/test_permute.py index eb482bcee54..b507155c8f2 100644 --- a/backends/arm/test/ops/test_permute.py +++ b/backends/arm/test/ops/test_permute.py @@ -9,24 +9,29 @@ from typing import Tuple import torch - -from executorch.backends.arm.test import common +from executorch.backends.arm.quantizer.arm_quantizer import ( + get_symmetric_a16w8_quantization_config, + TOSAQuantizer, +) +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, EthosU85PipelineINT, + OpNotSupportedPipeline, TosaPipelineFP, TosaPipelineINT, VgfPipeline, ) -from torchvision.ops import Permute +from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.xnnpack.test.tester import Quantize input_t1 = Tuple[torch.Tensor] # Input x aten_op = "torch.ops.aten.permute.default" -exir_op = "executorch_exir_dialects_edge__ops_aten_permute_default" +exir_op = "executorch_exir_dialects_edge__ops_aten_permute_copy_default" -test_data_suite = { +test_data_suite_u55 = { # (test_name,test_data,dims) "rank_2": lambda: (torch.rand(10, 10), [1, 0]), "rank_3": lambda: (torch.rand(10, 10, 10), [2, 0, 1]), @@ -34,18 +39,27 @@ "rank_4": lambda: (torch.rand(1, 5, 1, 10), [0, 2, 3, 1]), "rank_4_2": lambda: (torch.rand(1, 2, 5, 10), [1, 0, 2, 3]), "rank_4_3": lambda: (torch.rand(1, 10, 10, 5), [2, 0, 1, 3]), + "rank_4_large": lambda: (torch.rand(2, 8, 64, 65), [0, 2, 3, 1]), + "rank_3_large": lambda: (torch.rand(16, 64, 65), [1, 2, 0]), + "reshape_large_1": lambda: (torch.rand(1, 1, 65537), [0, 2, 1]), + "reshape_large_2": lambda: (torch.rand(65537, 1, 1), [1, 2, 0]), } +test_data_suite_u55_reject = { + "rank2_bool": lambda: (torch.randint(0, 2, (5, 5), dtype=torch.bool), [1, 0]), +} +test_data_suite = test_data_suite_u55.copy() 
| test_data_suite_u55_reject.copy() + class SimplePermute(torch.nn.Module): def __init__(self, dims: list[int]): super().__init__() - self.permute = Permute(dims=dims) + self.dims = dims def forward(self, x): - return self.permute(x) + return torch.permute(x, self.dims) @common.parametrize("test_data", test_data_suite) @@ -72,11 +86,7 @@ def test_permute_tosa_INT(test_data: torch.Tensor): pipeline.run() -@common.parametrize( - "test_data", - test_data_suite, - xfails={"rank_4_3": "MLETORCH-955 : Permutation numerical diff for u55"}, -) +@common.parametrize("test_data", test_data_suite_u55) @common.XfailIfNoCorstone300 def test_permute_u55_INT(test_data): test_data, dims = test_data() @@ -85,7 +95,22 @@ def test_permute_u55_INT(test_data): (test_data,), aten_op, exir_ops="executorch_exir_dialects_edge__ops_aten_permute_copy_default", - run_on_fvp=True, + ) + if test_data[0].dtype == torch.bool: + pipeline.pop_stage("check_count.exir") + pipeline.tester.use_portable_ops = True + pipeline.run() + + +@common.parametrize("test_data", test_data_suite_u55_reject) +def test_permute_u55_INT_not_delegated(test_data: torch.Tensor): + test_data, dims = test_data() + pipeline = OpNotSupportedPipeline[input_t1]( + SimplePermute(dims=dims), + (test_data,), + non_delegated_ops={exir_op: 1}, + quantize=True, + u55_subset=True, ) pipeline.run() @@ -99,34 +124,133 @@ def test_permute_u85_INT(test_data: torch.Tensor): (test_data,), aten_op, exir_ops="executorch_exir_dialects_edge__ops_aten_permute_copy_default", - run_on_fvp=True, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_permute_vgf_FP(test_data): +def test_permute_vgf_no_quant(test_data): test_data, dims = test_data() pipeline = VgfPipeline[input_t1]( SimplePermute(dims=dims), (test_data,), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_permute_vgf_INT(test_data): +def test_permute_vgf_quant(test_data): test_data, dims = test_data() pipeline = VgfPipeline[input_t1]( SimplePermute(dims=dims), (test_data,), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, + ) + pipeline.run() + + +def get_symmetric_a16w8_permute_quantizer( + u55_config=False, per_channel_quantization=False +): + tosa_version = conftest.get_option("tosa_version") + tosa_profiles = { + "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"), + } + + quantizer = TOSAQuantizer(tosa_profiles[tosa_version]) + quantizer.set_global( + get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization) + ) + + return Quantize( + quantizer, + get_symmetric_a16w8_quantization_config( + is_per_channel=per_channel_quantization + ), + ) + + +@common.parametrize("test_data", test_data_suite) +def test_permute_16a8w_tosa_INT(test_data: torch.Tensor): + """Test permute operation with int16 quantization""" + test_data, dims = test_data() + pipeline = TosaPipelineINT[input_t1]( + SimplePermute(dims=dims), + (test_data,), + aten_op, + exir_op=[], + per_channel_quantization=False, + use_to_edge_transform_and_lower=True, + tosa_extensions=["int16"], + ) + + pipeline.change_args( + "quantize", + get_symmetric_a16w8_permute_quantizer(per_channel_quantization=False), + ) + # Run the pipeline + pipeline.run() + + +test_data_suite_exact = { + x: test_data_suite[x] + for x in test_data_suite + if x not in ("rank_4_3", "rank2_bool") +} + + +@common.parametrize( + "test_data", + 
test_data_suite_exact, +) +@common.XfailIfNoCorstone300 +def test_permute_16a8w_u55_INT16(test_data: torch.Tensor): + """Test permute operation with int16 quantization on U55""" + test_data, dims = test_data() + pipeline = EthosU55PipelineINT[input_t1]( + SimplePermute(dims=dims), + (test_data,), + aten_op, + exir_ops=[], + per_channel_quantization=True, + use_to_edge_transform_and_lower=True, + atol=1e-02, + rtol=1e-02, + run_on_fvp=True, + ) + + pipeline.change_args( + "quantize", + get_symmetric_a16w8_permute_quantizer(per_channel_quantization=False), + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone320 +def test_permute_16a8w_u85_INT16(test_data: torch.Tensor): + """Test permute operation with int16 quantization on U85""" + test_data, dims = test_data() + pipeline = EthosU85PipelineINT[input_t1]( + SimplePermute(dims=dims), + (test_data,), + aten_op, + exir_ops=[], + use_to_edge_transform_and_lower=True, + atol=1e-03, + rtol=1e-03, + run_on_fvp=True, + ) + + pipeline.change_args( + "quantize", + get_symmetric_a16w8_permute_quantizer(per_channel_quantization=False), ) pipeline.run() diff --git a/backends/arm/test/ops/test_pixel_shuffling.py b/backends/arm/test/ops/test_pixel_shuffling.py new file mode 100644 index 00000000000..0c3436da87e --- /dev/null +++ b/backends/arm/test/ops/test_pixel_shuffling.py @@ -0,0 +1,237 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Tuple + +import torch + +from executorch.backends.arm.constants import MAX_RANK + +from executorch.backends.arm.test import common + +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineINT, + EthosU85PipelineINT, + TosaPipelineFP, + TosaPipelineINT, + VgfPipeline, +) +from torch import nn + +aten_op_pixel_unshuffle = "torch.ops.aten.pixel_unshuffle.default" +exir_op_pixel_unshuffle = ( + "executorch_exir_dialects_edge__ops_aten_pixel_unshuffle_default" +) + +aten_op_pixel_shuffle = "torch.ops.aten.pixel_shuffle.default" +exir_op_pixel_shuffle = "executorch_exir_dialects_edge__ops_aten_pixel_shuffle_default" + +input_t1 = Tuple[torch.Tensor] # single positional input (1-tuple) + +max_rank_input_supported = MAX_RANK - 2 + + +class PixelUnShuffle(nn.Module): + + upscale_factor = 2 + test_data_generators = { + "rand_4d": lambda: (torch.randn(1, 12, 64, 64),), + "test_4d": lambda: (torch.tensor([[[[10.0, 20.0], [30.0, 40.0]]]]),), + "test_3d": lambda: (torch.tensor([[[10.0, 20.0], [30.0, 40.0]]]),), + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.space_to_depth = nn.PixelUnshuffle(self.upscale_factor) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + if inputs.dim() > max_rank_input_supported: + raise RuntimeError( + f"Max rank of input for pixel_unshuffle is currently {max_rank_input_supported}, got {inputs.dim()}" + ) + return self.space_to_depth(inputs) + + +class PixelShuffle(nn.Module): + + upscale_factor = 2 + test_data_generators = { + "rand_4d": lambda: (torch.randn(1, 12, 64, 64),), + "test_4d": lambda: (torch.tensor([[[[10.0]], [[20.0]], [[30.0]], [[40.0]]]]),), + "test_3d": lambda: (torch.tensor([[[10.0]], [[20.0]], [[30.0]], [[40.0]]]),), + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.depth_to_space = nn.PixelShuffle(self.upscale_factor) + + def forward(self, inputs: 
torch.Tensor) -> torch.Tensor: + if inputs.dim() > max_rank_input_supported: + raise RuntimeError( + f"Max rank of input for pixel_shuffle is currently {max_rank_input_supported}, got {inputs.dim()}" + ) + return self.depth_to_space(inputs) + + +@common.parametrize("test_data", PixelUnShuffle.test_data_generators) +def test_pixel_unshuffle_tosa_FP(test_data: input_t1): + pipeline = TosaPipelineFP[input_t1]( + PixelUnShuffle(), + test_data(), + aten_op_pixel_unshuffle, + exir_op_pixel_unshuffle, + ) + pipeline.run() + + +@common.parametrize("test_data", PixelUnShuffle.test_data_generators) +def test_pixel_unshuffle_tosa_INT(test_data: input_t1): + pipeline = TosaPipelineINT[input_t1]( + PixelUnShuffle(), + test_data(), + aten_op_pixel_unshuffle, + exir_op_pixel_unshuffle, + ) + pipeline.run() + + +@common.parametrize("test_data", PixelShuffle.test_data_generators) +def test_pixel_shuffle_tosa_FP(test_data: input_t1): + pipeline = TosaPipelineFP[input_t1]( + PixelShuffle(), + test_data(), + aten_op_pixel_shuffle, + exir_op_pixel_shuffle, + ) + pipeline.run() + + +@common.parametrize("test_data", PixelShuffle.test_data_generators) +def test_pixel_shuffle_tosa_INT(test_data: input_t1): + pipeline = TosaPipelineINT[input_t1]( + PixelShuffle(), + test_data(), + aten_op_pixel_shuffle, + exir_op_pixel_shuffle, + ) + pipeline.run() + + +@common.parametrize("test_data", PixelUnShuffle.test_data_generators) +@common.SkipIfNoModelConverter +def test_pixel_unshuffle_vgf_no_quant(test_data: input_t1): + pipeline = VgfPipeline[input_t1]( + PixelUnShuffle(), + test_data(), + aten_op_pixel_unshuffle, + exir_op_pixel_unshuffle, + run_on_vulkan_runtime=True, + quantize=False, + ) + pipeline.run() + + +@common.parametrize("test_data", PixelUnShuffle.test_data_generators) +@common.SkipIfNoModelConverter +def test_pixel_unshuffle_vgf_quant(test_data: input_t1): + pipeline = VgfPipeline[input_t1]( + PixelUnShuffle(), + test_data(), + aten_op_pixel_unshuffle, + exir_op_pixel_unshuffle, + run_on_vulkan_runtime=True, + quantize=True, + ) + pipeline.run() + + +@common.parametrize("test_data", PixelShuffle.test_data_generators) +@common.SkipIfNoModelConverter +def test_pixel_shuffle_vgf_no_quant(test_data: input_t1): + pipeline = VgfPipeline[input_t1]( + PixelShuffle(), + test_data(), + aten_op_pixel_shuffle, + exir_op_pixel_shuffle, + run_on_vulkan_runtime=True, + quantize=False, + ) + pipeline.run() + + +@common.parametrize("test_data", PixelShuffle.test_data_generators) +@common.SkipIfNoModelConverter +def test_pixel_shuffle_vgf_quant(test_data: input_t1): + pipeline = VgfPipeline[input_t1]( + PixelShuffle(), + test_data(), + aten_op_pixel_shuffle, + exir_op_pixel_shuffle, + run_on_vulkan_runtime=True, + quantize=True, + ) + pipeline.run() + + +@common.parametrize("test_data", PixelUnShuffle.test_data_generators) +@common.XfailIfNoCorstone300 +def test_pixel_unshuffle_u55_INT(test_data: input_t1): + pipeline = EthosU55PipelineINT[input_t1]( + PixelUnShuffle(), + test_data(), + aten_op_pixel_unshuffle, + exir_op_pixel_unshuffle, + run_on_fvp=True, + ) + pipeline.run() + + +@common.parametrize( + "test_data", + PixelUnShuffle.test_data_generators, + xfails={"rand_4d": "MLETORCH-1424: rand test fails"}, +) +@common.XfailIfNoCorstone320 +def test_pixel_unshuffle_u85_INT(test_data: input_t1): + pipeline = EthosU85PipelineINT[input_t1]( + PixelUnShuffle(), + test_data(), + aten_op_pixel_unshuffle, + exir_op_pixel_unshuffle, + run_on_fvp=True, + ) + pipeline.run() + + +@common.parametrize("test_data", 
PixelShuffle.test_data_generators) +@common.XfailIfNoCorstone300 +def test_pixel_shuffle_u55_INT(test_data: input_t1): + pipeline = EthosU55PipelineINT[input_t1]( + PixelShuffle(), + test_data(), + aten_op_pixel_shuffle, + exir_op_pixel_shuffle, + run_on_fvp=True, + ) + pipeline.run() + + +@common.parametrize( + "test_data", + PixelShuffle.test_data_generators, + xfails={"rand_4d": "MLETORCH-1424: rand test fails"}, +) +@common.XfailIfNoCorstone320 +def test_pixel_shuffle_u85_INT(test_data: input_t1): + pipeline = EthosU85PipelineINT[input_t1]( + PixelShuffle(), + test_data(), + aten_op_pixel_shuffle, + exir_op_pixel_shuffle, + run_on_fvp=True, + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_pow.py b/backends/arm/test/ops/test_pow.py index 016c3e97265..1955ff43587 100644 --- a/backends/arm/test/ops/test_pow.py +++ b/backends/arm/test/ops/test_pow.py @@ -62,10 +62,10 @@ class Pow_TensorScalar(torch.nn.Module): test_data = { # Test whole number exponents - "exp_minus_three": lambda: (torch.randn((10, 5)), -3.0), - "exp_minus_one": lambda: (torch.randn((42,)), -1.0), - "exp_zero": lambda: (torch.randn((1, 2, 3, 7)), 0.0), - "exp_one": lambda: (torch.randn((1, 4, 6, 2)), 1.0), + "exp_minus_three": lambda: (torch.randn((10, 5)).relu() + 0.1, -3.0), + "exp_minus_one": lambda: (torch.randn((42,)).relu() + 0.1, -1.0), + "exp_zero": lambda: (torch.randn((1, 2, 3, 7)).relu(), 0.0), + "exp_one": lambda: (torch.randn((1, 4, 6, 2)).relu(), 1.0), "exp_two": lambda: (torch.randn((1, 2, 3, 6)), 2.0), # Test decimal exponent (base must be non-negative) "non_neg_base_exp_pos_decimal": lambda: ( @@ -105,28 +105,25 @@ def test_pow_tensor_tensor_tosa_FP(test_data: Pow_TensorTensor.input_t): @common.parametrize("test_data", Pow_TensorTensor.test_data, x_fail, strict=False) @common.SkipIfNoModelConverter -def test_pow_tensor_tensor_vgf_FP(test_data: Pow_TensorTensor.input_t): +def test_pow_tensor_tensor_vgf_no_quant(test_data: Pow_TensorTensor.input_t): pipeline = VgfPipeline[Pow_TensorTensor.input_t]( Pow_TensorTensor(), test_data(), Pow_TensorTensor.aten_op, Pow_TensorTensor.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() x_fail = { - "exp_minus_three": "TOSA constraints: If x == 0 and y ⇐ 0, the result is undefined.", - "exp_minus_one": "TOSA constraints: If x == 0 and y ⇐ 0, the result is undefined.", - "exp_zero": "TOSA constraints: If x == 0 and y ⇐ 0, the result is undefined.", - "exp_one": "TOSA constraints: If x == 0 and y ⇐ 0, the result is undefined.", - "exp_two": "TOSA constraints: If x == 0 and y ⇐ 0, the result is undefined.", - "non_neg_base_exp_pos_decimal": "TOSA constraints: If x == 0 and y ⇐ 0, the result is undefined.", + "exp_two": "TOSA constraints: If x <0 .", } -@common.parametrize("test_data", Pow_TensorScalar.test_data, x_fail, strict=False) +@common.parametrize( + "test_data", Pow_TensorScalar.test_data, xfails=x_fail, strict=False +) def test_pow_tensor_scalar_tosa_FP(test_data: Pow_TensorScalar.input_t): base, exp = test_data() pipeline = TosaPipelineFP[Pow_TensorScalar.input_t]( @@ -138,7 +135,7 @@ def test_pow_tensor_scalar_tosa_FP(test_data: Pow_TensorScalar.input_t): pipeline.run() -@common.parametrize("test_data", Pow_TensorScalar.test_data, x_fail, strict=False) +@common.parametrize("test_data", Pow_TensorScalar.test_data, strict=False) def test_pow_tensor_scalar_tosa_INT(test_data: Pow_TensorScalar.input_t): base, exp = test_data() pipeline = TosaPipelineINT[Pow_TensorScalar.input_t]( @@ -159,7 +156,6 @@ def 
test_pow_tensor_scalar_u55_INT(test_data: Pow_TensorScalar.input_t): (base,), Pow_TensorScalar.aten_op, Pow_TensorScalar.exir_op, - run_on_fvp=True, ) pipeline.run() @@ -173,34 +169,36 @@ def test_pow_tensor_scalar_u85_INT(test_data: Pow_TensorScalar.input_t): (base,), Pow_TensorScalar.aten_op, Pow_TensorScalar.exir_op, - run_on_fvp=True, ) pipeline.run() @common.parametrize("test_data", Pow_TensorScalar.test_data, x_fail, strict=False) @common.SkipIfNoModelConverter -def test_pow_tensor_scalar_vgf_FP(test_data: Pow_TensorScalar.input_t): +def test_pow_tensor_scalar_vgf_no_quant(test_data: Pow_TensorScalar.input_t): base, exp = test_data() pipeline = VgfPipeline[Pow_TensorScalar.input_t]( Pow_TensorScalar(exp), (base,), Pow_TensorScalar.aten_op, Pow_TensorScalar.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() -@common.parametrize("test_data", Pow_TensorScalar.test_data, x_fail, strict=False) +@common.parametrize( + "test_data", + Pow_TensorScalar.test_data, +) @common.SkipIfNoModelConverter -def test_pow_tensor_scalar_vgf_INT(test_data: Pow_TensorScalar.input_t): +def test_pow_tensor_scalar_vgf_quant(test_data: Pow_TensorScalar.input_t): base, exp = test_data() pipeline = VgfPipeline[Pow_TensorScalar.input_t]( Pow_TensorScalar(exp), (base,), Pow_TensorScalar.aten_op, Pow_TensorScalar.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_reciprocal.py b/backends/arm/test/ops/test_reciprocal.py index 78edbb980e8..5d09dfd9268 100644 --- a/backends/arm/test/ops/test_reciprocal.py +++ b/backends/arm/test/ops/test_reciprocal.py @@ -71,7 +71,6 @@ def test_reciprocal_u55_INT(test_data: torch.Tensor): (test_data(),), aten_op, exir_ops=[], - run_on_fvp=False, ) pipeline.run() @@ -84,7 +83,6 @@ def test_reciprocal_u85_INT(test_data: torch.Tensor): (test_data(),), aten_op, exir_ops=[], - run_on_fvp=False, symmetric_io_quantization=True, ) pipeline.run() @@ -92,23 +90,23 @@ def test_reciprocal_u85_INT(test_data: torch.Tensor): @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_reciprocal_vgf_FP(test_data: torch.Tensor): +def test_reciprocal_vgf_no_quant(test_data: torch.Tensor): pipeline = VgfPipeline[input_t1]( Reciprocal(), (test_data(),), aten_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_reciprocal_vgf_INT(test_data: torch.Tensor): +def test_reciprocal_vgf_quant(test_data: torch.Tensor): pipeline = VgfPipeline[input_t1]( Reciprocal(), (test_data(),), aten_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_relu.py b/backends/arm/test/ops/test_relu.py index 0b29bc24e75..f659a3c86cb 100644 --- a/backends/arm/test/ops/test_relu.py +++ b/backends/arm/test/ops/test_relu.py @@ -43,6 +43,28 @@ def forward(self, x): return self.relu(x) +test_data_conv_relu = { + # (test_name, test_data) + "4d_randn_inplace=True": (lambda: (torch.randn(1, 64, 96, 96) * 1000, True)), + "4d_randn_inplace=False": (lambda: (torch.randn(1, 64, 96, 96) * 1000, False)), +} + + +class Conv2d_Relu_Add(torch.nn.Module): + def __init__(self, inplace: bool = True): + super().__init__() + self.conv1 = torch.nn.Conv2d( + in_channels=64, out_channels=64, kernel_size=7, padding="same" + ) + self.relu = torch.nn.ReLU(inplace=inplace) + + def forward(self, x: torch.Tensor): + y = self.conv1(x) + z = self.relu(y) + out = x + z + return out + + 
@common.parametrize("test_data", test_data_suite) def test_relu_tosa_FP(test_data: torch.Tensor): pipeline = TosaPipelineFP[input_t1]( @@ -54,6 +76,35 @@ def test_relu_tosa_FP(test_data: torch.Tensor): pipeline.run() +# Test the folding of Conv2D with ReLU +@common.parametrize("test_data", test_data_conv_relu) +def test_conv_relu_folding_tosa_INT(test_data: torch.Tensor): + input_data, inplace = test_data() + pipeline = TosaPipelineINT[input_t1]( + Conv2d_Relu_Add(inplace=inplace), + (input_data,), + [], + [], + ) + # We should have : + # 3 quantize_per_tensor nodes: input activation , output of the conv-relu sequence, out of the add + # 4 dequantize_per_tensor nodes: into the conv2d input, into the add, output of the conv-relu sequence, before returning + # 2 dequantize_per_channel nodes: one for the weights and another one for the bias + # In case of incorrect annotation of the ReLU, we get separate Q/DR around both the conv2d and the ReLU and + # therefore more quantize_per_tensor and dequantize_per_tensor nodes + pipeline.add_stage_after( + "quantize", + pipeline.tester.check_count, + { + "quantized_decomposed.quantize_per_tensor.default": 3, + "torch.ops.quantized_decomposed.dequantize_per_tensor.default": 4, + "quantized_decomposed.dequantize_per_channel.default": 2, + }, + suffix="quant_nodes", + ) + pipeline.run() + + @common.parametrize("test_data", test_data_suite) def test_relu_tosa_INT(test_data: torch.Tensor): pipeline = TosaPipelineINT[input_t1]( @@ -66,50 +117,50 @@ def test_relu_tosa_INT(test_data: torch.Tensor): @common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone300 def test_relu_u55_INT(test_data: torch.Tensor): pipeline = EthosU55PipelineINT[input_t1]( Relu(), (test_data(),), aten_op, exir_op, - run_on_fvp=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone320 def test_relu_u85_INT(test_data: torch.Tensor): pipeline = EthosU85PipelineINT[input_t1]( Relu(), (test_data(),), aten_op, exir_op, - run_on_fvp=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_relu_vgf_FP(test_data: torch.Tensor): +def test_relu_vgf_no_quant(test_data: torch.Tensor): pipeline = VgfPipeline[input_t1]( Relu(), (test_data(),), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_relu_vgf_INT(test_data: torch.Tensor): +def test_relu_vgf_quant(test_data: torch.Tensor): pipeline = VgfPipeline[input_t1]( Relu(), (test_data(),), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_remainder.py b/backends/arm/test/ops/test_remainder.py new file mode 100644 index 00000000000..d1874d15fdb --- /dev/null +++ b/backends/arm/test/ops/test_remainder.py @@ -0,0 +1,199 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Tuple + +import torch + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineINT, + EthosU85PipelineINT, + TosaPipelineFP, + TosaPipelineINT, + VgfPipeline, +) + + +def _nonzero_float_tensor(*shape: int) -> torch.Tensor: + return torch.rand(*shape, dtype=torch.float32) * 5 + 0.1 + + +class Remainder(torch.nn.Module): + input_t = Tuple[torch.Tensor | float, torch.Tensor | float] + + aten_op_tensor = "torch.ops.aten.remainder.Tensor" + exir_op_tensor = "executorch_exir_dialects_edge__ops_aten_remainder_Tensor" + aten_op_scalar = "torch.ops.aten.remainder.Scalar" + exir_op_scalar = "executorch_exir_dialects_edge__ops_aten_remainder_Scalar" + + test_cases_tensor = { + "rank2_tensors": lambda: ( + torch.randn(2, 3) * 7, + _nonzero_float_tensor(2, 3), + ), + "rank4_tensors": lambda: ( + torch.randn(1, 4, 2, 3) * 7, + _nonzero_float_tensor(1, 4, 2, 3), + ), + "broadcast": lambda: ( + torch.randn(4, 5, 1), + _nonzero_float_tensor(1, 5, 6), + ), + } + + test_cases_scalar = { + "scalar_pos": lambda: ( + torch.randn(1, 2, 3, 4), + 0.25, + ), + "scalar_neg": lambda: ( + torch.randn(3, 4), + -0.25, + ), + } + + def forward(self, x: torch.Tensor | float, y: torch.Tensor | float) -> torch.Tensor: + return torch.remainder(x, y) + + +@common.parametrize("test_data", Remainder.test_cases_tensor) +def test_remainder_tensor_tosa_FP(test_data): + data = test_data() + pipeline = TosaPipelineFP[Remainder.input_t]( + Remainder(), + data, + Remainder.aten_op_tensor, + Remainder.exir_op_tensor, + ) + pipeline.run() + + +@common.parametrize("test_data", Remainder.test_cases_scalar) +def test_remainder_scalar_tosa_FP(test_data): + data = test_data() + pipeline = TosaPipelineFP[Remainder.input_t]( + Remainder(), + data, + Remainder.aten_op_scalar, + Remainder.exir_op_scalar, + ) + pipeline.run() + + +@common.parametrize("test_data", Remainder.test_cases_tensor) +def test_remainder_tensor_tosa_INT(test_data): + pipeline = TosaPipelineINT[Remainder.input_t]( + Remainder(), + test_data(), + [], + ) + pipeline.run() + + +@common.parametrize("test_data", Remainder.test_cases_scalar) +def test_remainder_scalar_tosa_INT(test_data): + pipeline = TosaPipelineINT[Remainder.input_t]( + Remainder(), + test_data(), + [], + ) + pipeline.run() + + +@common.parametrize("test_data", Remainder.test_cases_tensor) +@common.XfailIfNoCorstone300 +def test_remainder_tensor_u55_INT(test_data): + pipeline = EthosU55PipelineINT[Remainder.input_t]( + Remainder(), + test_data(), + [], + ) + pipeline.run() + + +@common.parametrize("test_data", Remainder.test_cases_scalar) +@common.XfailIfNoCorstone300 +def test_remainder_scalar_u55_INT(test_data): + pipeline = EthosU55PipelineINT[Remainder.input_t]( + Remainder(), + test_data(), + [], + ) + pipeline.run() + + +@common.parametrize("test_data", Remainder.test_cases_tensor) +@common.XfailIfNoCorstone320 +def test_remainder_tensor_u85_INT(test_data): + pipeline = EthosU85PipelineINT[Remainder.input_t]( + Remainder(), + test_data(), + [], + ) + pipeline.run() + + +@common.parametrize("test_data", Remainder.test_cases_scalar) +@common.XfailIfNoCorstone320 +def test_remainder_scalar_u85_INT(test_data): + pipeline = EthosU85PipelineINT[Remainder.input_t]( + Remainder(), + test_data(), + [], + ) + pipeline.run() + + +@common.parametrize("test_data", Remainder.test_cases_tensor) +@common.SkipIfNoModelConverter +def test_remainder_tensor_vgf_no_quant(test_data): + data = test_data() + pipeline = 
VgfPipeline[Remainder.input_t]( + Remainder(), + data, + Remainder.aten_op_tensor, + Remainder.exir_op_tensor, + quantize=False, + ) + pipeline.run() + + +@common.parametrize("test_data", Remainder.test_cases_scalar) +@common.SkipIfNoModelConverter +def test_remainder_scalar_vgf_no_quant(test_data): + data = test_data() + pipeline = VgfPipeline[Remainder.input_t]( + Remainder(), + data, + Remainder.aten_op_scalar, + Remainder.exir_op_scalar, + quantize=False, + ) + pipeline.run() + + +@common.parametrize("test_data", Remainder.test_cases_tensor) +@common.SkipIfNoModelConverter +def test_remainder_tensor_vgf_quant(test_data): + pipeline = VgfPipeline[Remainder.input_t]( + Remainder(), + test_data(), + [], + quantize=True, + ) + pipeline.run() + + +@common.parametrize("test_data", Remainder.test_cases_scalar) +@common.SkipIfNoModelConverter +def test_remainder_scalar_vgf_quant(test_data): + pipeline = VgfPipeline[Remainder.input_t]( + Remainder(), + test_data(), + [], + quantize=True, + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_repeat.py b/backends/arm/test/ops/test_repeat.py index 3236515b661..0b3de3b72df 100644 --- a/backends/arm/test/ops/test_repeat.py +++ b/backends/arm/test/ops/test_repeat.py @@ -16,6 +16,7 @@ from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, EthosU85PipelineINT, + OpNotSupportedPipeline, TosaPipelineFP, TosaPipelineINT, VgfPipeline, @@ -29,6 +30,7 @@ class Repeat(torch.nn.Module): aten_op = "torch.ops.aten.repeat.default" + exir_op = "executorch_exir_dialects_edge__ops_aten_repeat_default" def __init__(self, multiples: Sequence[int]): super().__init__() @@ -50,7 +52,7 @@ def forward(self, x: torch.Tensor): return x.repeat_interleave(self.repeats, self.dim) -test_data_suite = { +test_data_suite_u55 = { # test_name : lambda: (module, test_data) "1_x_1": lambda: (Repeat((2,)), (torch.randn(3),)), "2_x_2": lambda: (Repeat((2, 1)), (torch.randn(3, 4),)), @@ -61,6 +63,13 @@ def forward(self, x: torch.Tensor): "1_x_4": lambda: (Repeat((2, 1, 2, 4)), (torch.randn((3, 3, 3)),)), "interleave_int_3_x_1": lambda: (RepeatInterleaveInt(3, 1), (torch.randn(3, 4),)), } +test_data_suite_u55_reject = { + "1_x_1_bool": lambda: ( + Repeat((2,)), + (torch.randint(0, 2, (3,), dtype=torch.bool),), + ), +} +test_data_suite = test_data_suite_u55 | test_data_suite_u55_reject @common.parametrize("test_data", test_data_suite) @@ -87,7 +96,8 @@ def test_repeat_tosa_INT(test_data: Tuple): pipeline.run() -@common.parametrize("test_data", test_data_suite) +@common.parametrize("test_data", test_data_suite_u55) +@common.XfailIfNoCorstone300 def test_repeat_u55_INT(test_data: Tuple): module, test_data = test_data() pipeline = EthosU55PipelineINT[input_t1]( @@ -95,12 +105,26 @@ def test_repeat_u55_INT(test_data: Tuple): test_data, module.aten_op, exir_ops=[], - run_on_fvp=False, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite_u55_reject) +@common.XfailIfNoCorstone300 +def test_repeat_u55_INT_not_delegated(test_data: Tuple): + module, test_data = test_data() + pipeline = OpNotSupportedPipeline[input_t1]( + module, + test_data, + non_delegated_ops={module.exir_op: 1}, + u55_subset=True, + quantize=True, ) pipeline.run() @common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone320 def test_repeat_u85_INT(test_data: Tuple): module, test_data = test_data() pipeline = EthosU85PipelineINT[input_t1]( @@ -108,32 +132,31 @@ def test_repeat_u85_INT(test_data: Tuple): test_data, module.aten_op, exir_ops=[], - 
run_on_fvp=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_repeat_vgf_FP(test_data: Tuple): +def test_repeat_vgf_no_quant(test_data: Tuple): module, args = test_data() pipeline = VgfPipeline[input_t1]( module, args, module.aten_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_repeat_vgf_INT(test_data: Tuple): +def test_repeat_vgf_quant(test_data: Tuple): module, args = test_data() pipeline = VgfPipeline[input_t1]( module, args, module.aten_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_round.py b/backends/arm/test/ops/test_round.py index a4fea455e4f..572163c250a 100644 --- a/backends/arm/test/ops/test_round.py +++ b/backends/arm/test/ops/test_round.py @@ -87,25 +87,25 @@ def test_round_u85_INT(test_data: torch.Tensor): @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_round_vgf_FP(test_data: torch.Tensor): +def test_round_vgf_no_quant(test_data: torch.Tensor): pipeline = VgfPipeline[input_t1]( Round(), (test_data(),), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_round_vgf_INT(test_data: torch.Tensor): +def test_round_vgf_quant(test_data: torch.Tensor): pipeline = VgfPipeline[input_t1]( Round(), (test_data(),), [], exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_rshift.py b/backends/arm/test/ops/test_rshift.py index e97bfb840ae..ea7cd4092fa 100644 --- a/backends/arm/test/ops/test_rshift.py +++ b/backends/arm/test/ops/test_rshift.py @@ -91,21 +91,18 @@ def test_bitwise_right_shift_tensor_tosa_INT_scalar(test_data): RshiftScalar.torch_op_INT, RshiftScalar.exir_op, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @common.parametrize("test_data", RshiftScalar.test_data) -@XfailIfNoCorstone300 +@common.XfailIfNoCorstone300 def test_bitwise_right_shift_tensor_u55_INT_scalar(test_data): pipeline = EthosU55PipelineINT[scalar_input_t]( RshiftScalar(), test_data(), RshiftScalar.torch_op_INT, RshiftScalar.exir_op, - run_on_fvp=True, ) - pipeline.pop_stage("check.quant_nodes") # Forced rounding in U55 HW causes off-by-one errors. 
pipeline.change_args("run_method_and_compare_outputs", inputs=test_data(), atol=1) @@ -113,43 +110,40 @@ def test_bitwise_right_shift_tensor_u55_INT_scalar(test_data): @common.parametrize("test_data", RshiftScalar.test_data) -@XfailIfNoCorstone320 +@common.XfailIfNoCorstone320 def test_bitwise_right_shift_tensor_u85_INT_scalar(test_data): pipeline = EthosU85PipelineINT[scalar_input_t]( RshiftScalar(), test_data(), RshiftScalar.torch_op_INT, RshiftScalar.exir_op, - run_on_fvp=True, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @common.parametrize("test_data", RshiftScalar.test_data) @common.SkipIfNoModelConverter -def test_bitwise_right_shift_scalar_vgf_FP_scalar(test_data): +def test_bitwise_right_shift_tensor_vgf_no_quant_scalar(test_data): pipeline = VgfPipeline[scalar_input_t]( RshiftScalar(), test_data(), RshiftScalar.torch_op_FP, RshiftScalar.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", RshiftScalar.test_data) @common.SkipIfNoModelConverter -def test_bitwise_right_shift_tensor_vgf_INT_scalar(test_data): +def test_bitwise_right_shift_tensor_vgf_quant_scalar(test_data): pipeline = VgfPipeline[scalar_input_t]( RshiftScalar(), test_data(), RshiftScalar.torch_op_INT, RshiftScalar.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -176,7 +170,6 @@ def test_bitwise_right_shift_tensor_tosa_INT(test_data): RshiftTensor.torch_op, RshiftTensor.exir_op, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -188,9 +181,7 @@ def test_bitwise_right_shift_tensor_u55_INT(test_data): test_data(), RshiftTensor.torch_op, RshiftTensor.exir_op, - run_on_fvp=True, ) - pipeline.pop_stage("check.quant_nodes") # Forced rounding in U55 HW causes off-by-one errors. 
pipeline.change_args("run_method_and_compare_outputs", inputs=test_data(), atol=1) @@ -205,34 +196,31 @@ def test_bitwise_right_shift_tensor_u85_INT(test_data): test_data(), RshiftTensor.torch_op, RshiftTensor.exir_op, - run_on_fvp=True, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @common.parametrize("test_data", RshiftTensor.test_data) @common.SkipIfNoModelConverter -def test_bitwise_right_shift_tensor_vgf_FP(test_data): +def test_bitwise_right_shift_tensor_vgf_no_quant(test_data): pipeline = VgfPipeline[tensor_input_t]( RshiftTensor(), test_data(), RshiftTensor.torch_op, RshiftTensor.exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", RshiftTensor.test_data) @common.SkipIfNoModelConverter -def test_bitwise_right_shift_tensor_vgf_INT(test_data): +def test_bitwise_right_shift_tensor_vgf_quant(test_data): pipeline = VgfPipeline[tensor_input_t]( RshiftTensor(), test_data(), RshiftTensor.torch_op, RshiftTensor.exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/ops/test_rsqrt.py b/backends/arm/test/ops/test_rsqrt.py index d146a83287e..8c3ee914758 100644 --- a/backends/arm/test/ops/test_rsqrt.py +++ b/backends/arm/test/ops/test_rsqrt.py @@ -8,9 +8,11 @@ from typing import Tuple +import pytest import torch from executorch.backends.arm.test import common + from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, EthosU85PipelineINT, @@ -19,7 +21,6 @@ VgfPipeline, ) - aten_op = "torch.ops.aten.rsqrt.default" input_t1 = Tuple[torch.Tensor] # Input x @@ -66,7 +67,6 @@ def test_rsqrt_u55_INT(test_tensor: torch.Tensor): test_tensor(), aten_op, exir_ops=[], - run_on_fvp=True, ) pipeline.run() @@ -79,30 +79,79 @@ def test_rsqrt_u85_INT(test_tensor: torch.Tensor): test_tensor(), aten_op, exir_ops=[], - run_on_fvp=True, ) pipeline.run() @common.parametrize("test_tensor", Rsqrt.test_parameters) @common.SkipIfNoModelConverter -def test_rsqrt_vgf_FP(test_tensor: torch.Tensor): +def test_rsqrt_vgf_no_quant(test_tensor: torch.Tensor): pipeline = VgfPipeline[input_t1]( Rsqrt(), test_tensor(), aten_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_tensor", Rsqrt.test_parameters) @common.SkipIfNoModelConverter -def test_rsqrt_vgf_INT(test_tensor: torch.Tensor): +def test_rsqrt_vgf_quant(test_tensor: torch.Tensor): pipeline = VgfPipeline[input_t1]( Rsqrt(), test_tensor(), aten_op, - tosa_version="TOSA-1.0+INT", + quantize=True, + ) + pipeline.run() + + +@common.parametrize("test_tensor", Rsqrt.test_parameters) +def test_rsqrt_tosa_INT_a16w8(test_tensor: torch.Tensor): + """Test rsqrt operation with int16 I/O quantization for TOSA INT.""" + # Use wider tolerances for int16 I/O quantization + pipeline = TosaPipelineINT[input_t1]( + Rsqrt(), + test_tensor(), + aten_op, + exir_op=[], + tosa_extensions=["int16"], + epsilon=2**16, + ) + pipeline.run() + + +@common.parametrize("test_tensor", Rsqrt.test_parameters) +@common.XfailIfNoCorstone300 +@pytest.mark.xfail( + reason="MLETORCH-707: AssertionError: Output 0 does not match reference output." 
+) +def test_rsqrt_16a8w_u55_INT16(test_tensor: torch.Tensor): + """Test rsqrt operation with int16 I/O quantization for U55""" + # Use wider tolerances for int16 I/O quantization on U55 + pipeline = EthosU55PipelineINT[input_t1]( + Rsqrt(), + test_tensor(), + aten_op, + exir_ops=[], + a16w8_quantization=True, + epsilon=2**16, + ) + pipeline.run() + + +@common.parametrize("test_tensor", Rsqrt.test_parameters) +@common.XfailIfNoCorstone320 +def test_rsqrt_16a8w_u85_INT16(test_tensor: torch.Tensor): + """Test rsqrt operation with int16 I/O quantization for U85""" + # Use wider tolerances for int16 I/O quantization on U85 + pipeline = EthosU85PipelineINT[input_t1]( + Rsqrt(), + test_tensor(), + aten_op, + exir_ops=[], + a16w8_quantization=True, + epsilon=2**16, ) pipeline.run() diff --git a/backends/arm/test/ops/test_rsub.py b/backends/arm/test/ops/test_rsub.py new file mode 100644 index 00000000000..1872521c1d7 --- /dev/null +++ b/backends/arm/test/ops/test_rsub.py @@ -0,0 +1,126 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineINT, + EthosU85PipelineINT, + TosaPipelineFP, + TosaPipelineINT, + VgfPipeline, +) + +rsub_test_data = { + "rand_2D_4x4": lambda: (torch.rand(4, 4), 2), + "rand_3D_4x4x4": lambda: (torch.rand(4, 2, 2), 1.5), + "rand_4D_2x2x4x4": lambda: (torch.rand(2, 2, 4, 4), -1.1), + "rand_4D_big_small": lambda: ( + (10e30) * torch.randn(1, 20, 30, 40), + -0.25, + ), + "zero": lambda: (torch.rand(4, 4), 0), + # "swapped": lambda: (2, torch.rand(4, 4)), # torch.rsub(Scalar, Tensor) is not supported as it is not supported in eager mode. 
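+    # Note: torch.rsub(x, y) computes y - x (reversed subtraction); the
+    # commented-out "swapped" case would pass the scalar first, which eager
+    # mode rejects, so it is left out of the suite.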
+} + + +class Rsub(torch.nn.Module): + aten_op = "torch.ops.aten.rsub.Scalar" + exir_op = "executorch_exir_dialects_edge__ops_aten_sub_Tensor" + + def forward(self, x: torch.Tensor, y: int): + return torch.rsub(x, y) + + +input_t1 = Tuple[torch.Tensor, torch.Tensor] + + +@common.parametrize("test_data", rsub_test_data) +def test_rsub_scalar_tosa_FP(test_data): + pipeline = TosaPipelineFP[input_t1]( + Rsub(), + test_data(), + aten_op=Rsub.aten_op, + exir_op=Rsub.exir_op, + use_to_edge_transform_and_lower=False, + ) + pipeline.run() + + +@common.parametrize("test_data", rsub_test_data) +def test_rsub_scalar_tosa_INT(test_data): + """Test Subtraction (TOSA INT)""" + pipeline = TosaPipelineINT[input_t1]( + Rsub(), + test_data(), + aten_op="torch.ops.aten.sub.Tensor", + exir_op=Rsub.exir_op, + use_to_edge_transform_and_lower=False, + qtol=0, + ) + pipeline.run() + + +@common.parametrize("test_data", rsub_test_data) +@common.XfailIfNoCorstone300 +def test_rsub_scalar_u55_INT(test_data): + """Test Subtraction on Ethos-U55 (FVP Mode)""" + pipeline = EthosU55PipelineINT[input_t1]( + Rsub(), + test_data(), + aten_ops="torch.ops.aten.sub.Tensor", + exir_ops=Rsub.exir_op, + run_on_fvp=True, + use_to_edge_transform_and_lower=False, + ) + pipeline.run() + + +@common.parametrize("test_data", rsub_test_data) +@common.XfailIfNoCorstone320 +def test_rsub_scalar_u85_INT(test_data): + """Test Subtraction on Ethos-U85 (FVP Mode)""" + pipeline = EthosU85PipelineINT[input_t1]( + Rsub(), + test_data(), + aten_ops="torch.ops.aten.sub.Tensor", + exir_ops=Rsub.exir_op, + run_on_fvp=True, + use_to_edge_transform_and_lower=False, + ) + pipeline.run() + + +@common.parametrize("test_data", rsub_test_data) +@common.SkipIfNoModelConverter +def test_rsub_scalar_vgf_no_quant(test_data: Tuple[torch.Tensor]): + """Test Subtraction (VGF FP)""" + pipeline = VgfPipeline[input_t1]( + Rsub(), + test_data(), + Rsub.aten_op, + Rsub.exir_op, + use_to_edge_transform_and_lower=False, + quantize=False, + ) + pipeline.run() + + +@common.parametrize("test_data", rsub_test_data) +@common.SkipIfNoModelConverter +def test_rsub_scalar_vgf_quant(test_data: Tuple[torch.Tensor]): + """Test Subtraction (VGF INT)""" + pipeline = VgfPipeline[input_t1]( + Rsub(), + test_data(), + aten_op="torch.ops.aten.sub.Tensor", + exir_op=Rsub.exir_op, + use_to_edge_transform_and_lower=False, + quantize=True, + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_scalar_tensor.py b/backends/arm/test/ops/test_scalar_tensor.py index 22c1cc0373d..bc265077f58 100644 --- a/backends/arm/test/ops/test_scalar_tensor.py +++ b/backends/arm/test/ops/test_scalar_tensor.py @@ -2,7 +2,6 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - import torch from executorch.backends.arm.test import common @@ -74,7 +73,10 @@ def test_scalar_tensor_tosa_INT(test_data): tuple(data), ScalarTensor.aten_op, ) - pipeline.pop_stage("check.quant_nodes") + # Pop the quantization check stage if it exists as no + # quantization nodes will be present for int + fp inputs. 
+ if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -86,7 +88,6 @@ def test_scalar_tensor_u55_INT(test_data): ScalarTensor(scalar, dtype), tuple(data), ScalarTensor.aten_op, - run_on_fvp=True, ).run() @@ -98,32 +99,37 @@ def test_scalar_tensor_u85_INT(test_data): ScalarTensor(scalar, dtype), tuple(data), ScalarTensor.aten_op, - run_on_fvp=True, ).run() @common.parametrize("test_data", float_test_data_suite) @common.SkipIfNoModelConverter -def test_scalar_tensor_vgf_FP(test_data): +def test_scalar_tensor_vgf_no_quant(test_data): scalar, dtype, data = test_data() pipeline = VgfPipeline( ScalarTensor(scalar, dtype), tuple(data), ScalarTensor.aten_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() -@common.parametrize("test_data", int_test_data_suite) +@common.parametrize( + "test_data", + int_test_data_suite, +) @common.SkipIfNoModelConverter -def test_scalar_tensor_vgf_INT(test_data): +def test_scalar_tensor_vgf_quant(test_data): scalar, dtype, data = test_data() pipeline = VgfPipeline( ScalarTensor(scalar, dtype), tuple(data), ScalarTensor.aten_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) - pipeline.pop_stage("check.quant_nodes") + # Pop the quantization check stage if it exists as no + # quantization nodes will be present for int + fp inputs. + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/ops/test_scalars.py b/backends/arm/test/ops/test_scalars.py index c4f371a1a14..b3704c87fb6 100644 --- a/backends/arm/test/ops/test_scalars.py +++ b/backends/arm/test/ops/test_scalars.py @@ -435,5 +435,4 @@ def test_bitwise_right_shift_tensor_tosa_INT_inplace(): (torch.IntTensor(5),), aten_op="torch.ops.aten.bitwise_right_shift.Tensor", ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/ops/test_sdpa.py b/backends/arm/test/ops/test_sdpa.py index 009e4b2ad70..201d80acaf1 100644 --- a/backends/arm/test/ops/test_sdpa.py +++ b/backends/arm/test/ops/test_sdpa.py @@ -48,22 +48,26 @@ def test_sdpa_tosa_INT(): @common.SkipIfNoModelConverter -def test_sdpa_vgf_FP(): +def test_sdpa_vgf_no_quant(): test_input = tuple(torch.randn(1, 3, 197, 64) for _ in range(3)) pipeline = VgfPipeline[input_t]( - SDPA(), test_input, [], [], tosa_version="TOSA-1.0+FP" + SDPA(), + test_input, + [], + [], + quantize=False, ) pipeline.run() @common.SkipIfNoModelConverter -def test_sdpa_vgf_INT(): +def test_sdpa_vgf_quant(): test_input = tuple(torch.randn(1, 3, 197, 64) for _ in range(3)) pipeline = VgfPipeline[input_t]( SDPA(), test_input, [], [], - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_select.py b/backends/arm/test/ops/test_select.py index 4c3887f1e18..6a82300f252 100644 --- a/backends/arm/test/ops/test_select.py +++ b/backends/arm/test/ops/test_select.py @@ -110,7 +110,6 @@ def test_select_int_u55_INT_copy(test_data: Tuple): test_data(), aten_op_copy, exir_ops=[], - run_on_fvp=True, use_to_edge_transform_and_lower=True, ) pipeline.run() @@ -124,7 +123,6 @@ def test_select_int_u55_INT(test_data: Tuple): test_data(), aten_op_int, exir_ops=[], - run_on_fvp=True, use_to_edge_transform_and_lower=True, ) pipeline.run() @@ -151,7 +149,6 @@ def test_select_int_u85_INT_copy(test_data: Tuple): test_data(), aten_op_copy, exir_ops=[], - run_on_fvp=True, use_to_edge_transform_and_lower=True, ) pipeline.run() @@ -165,7 +162,6 @@ def test_select_int_u85_INT(test_data: Tuple): 
test_data(), aten_op_int, exir_ops=[], - run_on_fvp=True, use_to_edge_transform_and_lower=True, ) pipeline.run() @@ -173,43 +169,51 @@ def test_select_int_u85_INT(test_data: Tuple): @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_select_int_vgf_FP_copy(test_data: Tuple): +def test_select_int_copy_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( - SelectCopy(), test_data(), aten_op_copy, [], tosa_version="TOSA-1.0+FP" + SelectCopy(), + test_data(), + aten_op_copy, + [], + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_select_int_vgf_FP(test_data: Tuple): +def test_select_int_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( - SelectInt(), test_data(), aten_op_int, [], tosa_version="TOSA-1.0+FP" + SelectInt(), + test_data(), + aten_op_int, + [], + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_select_int_vgf_INT_copy(test_data: Tuple): +def test_select_int_copy_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( SelectCopy(), test_data(), aten_op_copy, [], - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_select_int_vgf_INT(test_data: Tuple): +def test_select_int_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( SelectInt(), test_data(), aten_op_int, [], - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_select_scatter.py b/backends/arm/test/ops/test_select_scatter.py new file mode 100644 index 00000000000..b4df8d4ab9d --- /dev/null +++ b/backends/arm/test/ops/test_select_scatter.py @@ -0,0 +1,173 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
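+#
+# Tests for lowering aten.select_scatter. The quantized (INT) pipelines below
+# expect the op to be decomposed into arange/view/unsqueeze/expand/eq/where
+# (SelectScatter.int_aten_ops), while the U55 test only expects the eq/where
+# portion to remain outside the delegate (SelectScatter.u55_not_supported).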
+ +from typing import Tuple + +import torch + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU85PipelineINT, + OpNotSupportedPipeline, + TosaPipelineFP, + TosaPipelineINT, + VgfPipeline, +) + +test_data_suite = { + "rank2_rand": lambda: ( + torch.randint(-30, 30, (5, 9), dtype=torch.float32), + torch.randint(0, 9, (9,), dtype=torch.float32), + 0, + 2, + ), + "rank2_zeros": lambda: ( + torch.rand((3, 2), dtype=torch.float32), + torch.randint(0, 4, (2,), dtype=torch.float32), + 0, + 0, + ), + "rank3_rand": lambda: ( + torch.rand((2, 4, 5), dtype=torch.float32), + torch.randint(-5, 5, (2, 5), dtype=torch.float32), + 1, + 0, + ), + "rank3_ones": lambda: ( + torch.ones((2, 3, 3), dtype=torch.float32), + torch.rand((2, 3), dtype=torch.float32), + 2, + 2, + ), + "rank4_rand": lambda: ( + torch.rand((1, 2, 4, 5), dtype=torch.float32), + torch.rand((2, 4, 5), dtype=torch.float32), + 0, + 0, + ), + "rank4_ones": lambda: ( + torch.ones((2, 3, 3, 2), dtype=torch.float32), + torch.randint(-5, 5, (2, 3, 2), dtype=torch.float32), + 2, + -1, + ), + "rank5_ones": lambda: ( + torch.ones((3, 4, 20, 9, 5), dtype=torch.float32), + torch.randn((3, 4, 20, 9), dtype=torch.float32), + 4, + 1, + ), + "rank6_rand": lambda: ( + torch.rand((1, 2, 3, 4, 2, 1), dtype=torch.float32), + torch.randn((2, 3, 4, 2, 1), dtype=torch.float32), + 0, + 0, + ), +} + + +class SelectScatter(torch.nn.Module): + fp_aten_op = "torch.ops.aten.select_scatter.default" + int_aten_ops = [ + "torch.ops.aten.arange.start_step", + "torch.ops.aten.view_copy.default", + "torch.ops.aten.unsqueeze_copy.default", + "torch.ops.aten.expand_copy.default", + "torch.ops.aten.where.self", + "torch.ops.aten.eq.Tensor", + ] + fp_exir_op = ["executorch_exir_dialects_edge__ops_aten_select_scatter_default"] + int_exir_ops = [ + "executorch_exir_dialects_edge__ops_aten_eq_Tensor", + "executorch_exir_dialects_edge__ops_aten_where_self", + "executorch_exir_dialects_edge__ops_aten_arange_start_step", + "executorch_exir_dialects_edge__ops_aten_view_copy_default", + "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default", + "executorch_exir_dialects_edge__ops_aten_expand_copy_default", + ] + u55_not_supported = { + "executorch_exir_dialects_edge__ops_aten_eq_Tensor": 1, + "executorch_exir_dialects_edge__ops_aten_where_self": 1, + } + + def forward(self, x: torch.Tensor, y: torch.Tensor, dim: int, index: int): + return x.select_scatter(y, dim, index) + + +input_t = Tuple[torch.Tensor, torch.Tensor, int, int] + + +@common.parametrize("test_module", test_data_suite) +def test_select_scatter_tosa_FP(test_module: input_t): + pipeline = TosaPipelineFP[input_t]( + SelectScatter(), + test_module(), + aten_op=SelectScatter.fp_aten_op, + exir_op=SelectScatter.fp_exir_op, + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_suite) +def test_select_scatter_tosa_INT(test_module: input_t): + pipeline = TosaPipelineINT[input_t]( + SelectScatter(), + test_module(), + aten_op=SelectScatter.int_aten_ops, + exir_op=SelectScatter.int_exir_ops, + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_suite) +def test_select_scatter_u55_INT(test_module: input_t): + # select_scatter is not supported on U55 + pipeline = OpNotSupportedPipeline[input_t]( + SelectScatter(), + test_module(), + SelectScatter.u55_not_supported, + quantize=True, + u55_subset=True, + n_expected_delegates=1, + ) + pipeline.run() + + +@common.XfailIfNoCorstone320 +@common.parametrize("test_module", 
test_data_suite) +def test_select_scatter_u85_INT(test_module: input_t): + pipeline = EthosU85PipelineINT[input_t]( + SelectScatter(), + test_module(), + aten_ops=SelectScatter.int_aten_ops, + exir_ops=SelectScatter.int_exir_ops, + ) + pipeline.run() + + +@common.SkipIfNoModelConverter +@common.parametrize("test_module", test_data_suite) +def test_select_scatter_vgf_no_quant(test_module: input_t): + pipeline = VgfPipeline[input_t]( + SelectScatter(), + test_module(), + aten_op=SelectScatter.fp_aten_op, + exir_op=SelectScatter.fp_exir_op, + quantize=False, + ) + pipeline.run() + + +@common.SkipIfNoModelConverter +@common.parametrize("test_module", test_data_suite) +def test_select_scatter_vgf_quant(test_module: input_t): + pipeline = VgfPipeline[input_t]( + SelectScatter(), + test_module(), + aten_op=SelectScatter.int_aten_ops, + exir_op=SelectScatter.int_exir_ops, + quantize=True, + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_sigmoid.py b/backends/arm/test/ops/test_sigmoid.py index aac2ee1c9b1..bac6e376cee 100644 --- a/backends/arm/test/ops/test_sigmoid.py +++ b/backends/arm/test/ops/test_sigmoid.py @@ -34,6 +34,7 @@ "zeros": lambda: torch.zeros(10, 10, 10, 10), "ones": lambda: torch.ones(10, 10, 10), "rand": lambda: torch.rand(10, 10) - 0.5, + "rand_4d": lambda: torch.rand(1, 1, 5, 10), "randn_pos": lambda: torch.randn(10) + 10, "randn_neg": lambda: torch.randn(10) - 10, "ramp": lambda: torch.arange(-16, 16, 0.2), @@ -141,123 +142,123 @@ def test_sigmoid_tosa_INT_3(): @common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone300 def test_sigmoid_u55_INT(test_data: Tuple): pipeline = EthosU55PipelineINT[input_t1]( Sigmoid(), (test_data(),), aten_op, exir_op, - run_on_fvp=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone320 def test_sigmoid_u85_INT(test_data: Tuple): pipeline = EthosU85PipelineINT[input_t1]( Sigmoid(), (test_data(),), aten_op, exir_op, - run_on_fvp=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_sigmoid_vgf_FP(test_data: Tuple): +def test_sigmoid_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( Sigmoid(), (test_data(),), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_sigmoid_vgf_INT(test_data: Tuple): +def test_sigmoid_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( Sigmoid(), (test_data(),), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @common.SkipIfNoModelConverter -def test_sigmoid_vgf_FP_add(): +def test_sigmoid_add_vgf_no_quant(): pipeline = VgfPipeline[input_t1]( AddSigmoid(), (test_data_suite["zeros"](),), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.SkipIfNoModelConverter -def test_sigmoid_vgf_INT_add(): +def test_sigmoid_add_vgf_quant(): pipeline = VgfPipeline[input_t1]( AddSigmoid(), (test_data_suite["ramp"](),), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @common.SkipIfNoModelConverter -def test_sigmoid_vgf_FP_add_2(): +def test_sigmoid_add_2_vgf_no_quant(): pipeline = VgfPipeline[input_t1]( SigmoidAdd(), (test_data_suite["zeros"](),), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.SkipIfNoModelConverter -def test_sigmoid_vgf_INT_add_2(): +def test_sigmoid_add_2_vgf_quant(): pipeline = 
VgfPipeline[input_t1]( SigmoidAdd(), (test_data_suite["zeros"](),), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @common.SkipIfNoModelConverter -def test_sigmoid_vgf_FP_add_3(): +def test_sigmoid_add_3_vgf_no_quant(): pipeline = VgfPipeline[input_t1]( SigmoidAddSigmoid(), (test_data_suite["randn_neg"](), test_data_suite["randn_pos"]()), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.SkipIfNoModelConverter -def test_sigmoid_vgf_INT_add_3(): +def test_sigmoid_add_3_vgf_quant(): pipeline = VgfPipeline[input_t1]( SigmoidAddSigmoid(), (test_data_suite["randn_neg"](), test_data_suite["randn_pos"]()), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @@ -269,22 +270,23 @@ def get_symmetric_a16w8_sigmoid_quantizer(per_channel_quantization=False): } quantizer = TOSAQuantizer(tosa_profiles[tosa_version]) + + # Use a smaller episilon value to not greatly inflate [qmin, qmax] quantizer.set_global( - get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization) + get_symmetric_a16w8_quantization_config( + is_per_channel=per_channel_quantization, epsilon=2**-16 + ) ) return Quantize( quantizer, get_symmetric_a16w8_quantization_config( - is_per_channel=per_channel_quantization + is_per_channel=per_channel_quantization, epsilon=2**-16 ), ) @common.parametrize("test_data", test_data_suite) -@pytest.mark.xfail( - reason="missing int16 sigmoid ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13974" -) def test_sigmoid_16a8w_tosa_INT(test_data: torch.Tensor): """Test sigmoid operation with 16A8W quantization (16-bit activations, 8-bit weights)""" per_channel_quantization = False @@ -311,7 +313,7 @@ def test_sigmoid_16a8w_tosa_INT(test_data: torch.Tensor): @common.parametrize("test_data", test_data_suite) @common.XfailIfNoCorstone300 @pytest.mark.xfail( - reason="Vela compilation fails with 'Invalid arguments' for int16 sigmoid operations" + reason="MLETORCH-707: AssertionError: Output 0 does not match reference output." ) def test_sigmoid_16a8w_u55_INT16(test_data: torch.Tensor): """Test sigmoid operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)""" @@ -324,7 +326,6 @@ def test_sigmoid_16a8w_u55_INT16(test_data: torch.Tensor): exir_op, per_channel_quantization=per_channel_quantization, use_to_edge_transform_and_lower=True, - run_on_fvp=True, ) pipeline.change_args( @@ -338,9 +339,6 @@ def test_sigmoid_16a8w_u55_INT16(test_data: torch.Tensor): @common.parametrize("test_data", test_data_suite) @common.XfailIfNoCorstone320 -@pytest.mark.xfail( - reason="Vela compilation fails with 'Invalid arguments' for int16 sigmoid operations" -) def test_sigmoid_16a8w_u85_INT16(test_data: torch.Tensor): """Test sigmoid operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" per_channel_quantization = False @@ -352,7 +350,6 @@ def test_sigmoid_16a8w_u85_INT16(test_data: torch.Tensor): exir_op, per_channel_quantization=per_channel_quantization, use_to_edge_transform_and_lower=True, - run_on_fvp=True, ) pipeline.change_args( diff --git a/backends/arm/test/ops/test_sigmoid_16bit.py b/backends/arm/test/ops/test_sigmoid_16bit.py deleted file mode 100644 index ad8c49b234c..00000000000 --- a/backends/arm/test/ops/test_sigmoid_16bit.py +++ /dev/null @@ -1,190 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. 
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import pytest -import torch -from executorch.backends.arm.quantizer import ( - get_symmetric_quantization_config, - TOSAQuantizer, -) -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from executorch.backends.arm.test import common, conftest -from executorch.backends.arm.test.tester.test_pipeline import ( - EthosU85PipelineINT, - OpNotSupportedPipeline, - TosaPipelineINT, -) -from executorch.backends.arm.tosa import TosaSpecification -from executorch.backends.xnnpack.test.tester import Quantize -from torchao.quantization.pt2e import HistogramObserver -from torchao.quantization.pt2e.quantizer import QuantizationSpec - - -def _get_16_bit_quant_config(): - int16_spec = QuantizationSpec( - dtype=torch.int16, - observer_or_fake_quant_ctr=HistogramObserver, - qscheme=torch.per_tensor_symmetric, - ) - qconfig = QuantizationConfig( - input_activation=int16_spec, - output_activation=int16_spec, - weight=None, - bias=None, - ) - return qconfig - - -def get_16bit_sigmoid_quantizer(u55_config=False): - tosa_version = conftest.get_option("tosa_version") - tosa_profiles = { - "1.0": TosaSpecification.create_from_string( - "TOSA-1.0+INT+int16" + ("+u55" if u55_config else "") - ), - } - - quantizer = TOSAQuantizer(tosa_profiles[tosa_version]) - quantizer.set_global(get_symmetric_quantization_config()) - quantizer.set_module_type( - torch.nn.modules.activation.Sigmoid, _get_16_bit_quant_config() - ) - - return Quantize(quantizer, get_symmetric_quantization_config()) - - -input_t = tuple[torch.Tensor] -test_data_suite = { - "ones": lambda: torch.ones(10, 10, 10), - "rand": lambda: torch.rand(10, 10) - 0.5, - "rand_4d": lambda: torch.rand(1, 1, 5, 10), - "randn_pos": lambda: torch.randn(10) + 10, - "randn_neg": lambda: torch.randn(10) - 10, - "ramp": lambda: torch.arange(-16, 16, 0.02), -} - - -class Sigmoid(torch.nn.Module): - aten_op = "torch.ops.aten.sigmoid.default" - exir_op = "executorch_exir_dialects_edge__ops_aten_sigmoid_default" - - def __init__(self): - super().__init__() - self.sigmoid = torch.nn.Sigmoid() - - def forward(self, x): - return self.sigmoid(x) - - -class SigmoidAddSigmoid(torch.nn.Module): - def __init__(self): - super().__init__() - self.sigmoid = torch.nn.Sigmoid() - - def forward(self, x): - return self.sigmoid((self.sigmoid(x) + self.sigmoid(x))) - - -@common.parametrize("test_data", test_data_suite) -def test_sigmoid_tosa_INT(test_data): - pipeline = TosaPipelineINT( - Sigmoid(), - (test_data(),), - Sigmoid.aten_op, - Sigmoid.exir_op, - qtol=1, - tosa_extensions=["int16"], - ) - pipeline.change_args("quantize", get_16bit_sigmoid_quantizer()) - pipeline.run() - - -@common.parametrize( - "test_data", - test_data_suite, - xfails={ - "ramp": "AssertionError: Output 0 does not match reference output. 
MLETORCH-787" - }, - strict=False, -) -def test_sigmoid_tosa_INT_add_sigmoid(test_data): - pipeline = TosaPipelineINT( - SigmoidAddSigmoid(), - (test_data(),), - Sigmoid.aten_op, - Sigmoid.exir_op, - qtol=1, - tosa_extensions=["int16"], - ) - pipeline.change_args("quantize", get_16bit_sigmoid_quantizer()) - pipeline.run() - - -@common.parametrize( - "test_data", - test_data_suite, -) -def test_sigmoid_u55_INT(test_data): - pipeline = OpNotSupportedPipeline( - Sigmoid(), - (test_data(),), - {Sigmoid.exir_op: 1}, - quantize=True, - u55_subset=True, - ) - pipeline.change_args("quantize", get_16bit_sigmoid_quantizer(True)) - pipeline.run() - - -@common.parametrize( - "test_data", - test_data_suite, -) -def test_sigmoid_u55_INT_add_sigmoid(test_data): - pipeline = OpNotSupportedPipeline( - SigmoidAddSigmoid(), - (test_data(),), - {Sigmoid.exir_op: 3}, - n_expected_delegates=1, - quantize=True, - u55_subset=True, - tosa_extensions=["int16"], - ) - pipeline.change_args("quantize", get_16bit_sigmoid_quantizer(True)) - pipeline.run() - - -@common.parametrize("test_data", test_data_suite) -@common.XfailIfNoCorstone320 -def test_sigmoid_u85_INT(test_data): - pipeline = EthosU85PipelineINT( - Sigmoid(), - (test_data(),), - Sigmoid.aten_op, - Sigmoid.exir_op, - run_on_fvp=True, - ) - pipeline.change_args("quantize", get_16bit_sigmoid_quantizer()) - pipeline.run() - - -@common.parametrize( - "test_data", - test_data_suite, - xfails={ - "ramp": "AssertionError: Output 0 does not match reference output. MLETORCH-787" - }, -) -@pytest.mark.xfail # MLETORCH-787: Investigate int16-int8 rescaling precision -@common.XfailIfNoCorstone320 -def test_sigmoid_u85_INT_add_sigmoid(test_data): - pipeline = EthosU85PipelineINT( - SigmoidAddSigmoid(), - (test_data(),), - Sigmoid.aten_op, - Sigmoid.exir_op, - run_on_fvp=True, - ) - pipeline.change_args("quantize", get_16bit_sigmoid_quantizer()) - pipeline.run() diff --git a/backends/arm/test/ops/test_sigmoid_32bit.py b/backends/arm/test/ops/test_sigmoid_32bit.py index 70863cd4757..29fc90b67fc 100644 --- a/backends/arm/test/ops/test_sigmoid_32bit.py +++ b/backends/arm/test/ops/test_sigmoid_32bit.py @@ -4,16 +4,13 @@ # LICENSE file in the root directory of this source tree. 
import torch -from executorch.backends.arm.quantizer import TOSAQuantizer from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from executorch.backends.arm.test import common, conftest +from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( EthosU85PipelineINT, OpNotSupportedPipeline, TosaPipelineINT, ) -from executorch.backends.arm.tosa import TosaSpecification -from executorch.backends.xnnpack.test.tester import Quantize from torchao.quantization.pt2e import HistogramObserver from torchao.quantization.pt2e.quantizer import QuantizationSpec @@ -53,22 +50,12 @@ def _get_32_bit_quant_config(): return qconfig -def get_32bit_sigmoid_quantizer(u55_config=False): - tosa_version = conftest.get_option("tosa_version") - tosa_profiles = { - "1.0": TosaSpecification.create_from_string( - "TOSA-1.0+INT+int16" + ("+u55" if u55_config else "") - ), - } - - quantizer = TOSAQuantizer(tosa_profiles[tosa_version]) - quantizer.set_global(_get_32_bit_quant_config()) - quantizer.set_module_type( +def configure_32bit_sigmoid_quantizer(pipeline): + pipeline.quantizer.set_global(_get_32_bit_quant_config()) + pipeline.quantizer.set_module_type( torch.nn.modules.activation.Sigmoid, _get_16_bit_quant_config() ) - return Quantize(quantizer, _get_32_bit_quant_config()) - input_t = tuple[torch.Tensor] test_data_suite = { @@ -112,7 +99,7 @@ def test_sigmoid_tosa_INT(test_data): qtol=1, tosa_extensions=["int16"], ) - pipeline.change_args("quantize", get_32bit_sigmoid_quantizer()) + configure_32bit_sigmoid_quantizer(pipeline) pipeline.run() @@ -126,11 +113,12 @@ def test_sigmoid_tosa_INT_add_sigmoid(test_data): qtol=1, tosa_extensions=["int16"], ) - pipeline.change_args("quantize", get_32bit_sigmoid_quantizer()) + configure_32bit_sigmoid_quantizer(pipeline) pipeline.run() @common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone300 def test_sigmoid_u55_INT(test_data): pipeline = OpNotSupportedPipeline( Sigmoid(), @@ -140,11 +128,12 @@ def test_sigmoid_u55_INT(test_data): u55_subset=True, tosa_extensions=["int16"], ) - pipeline.change_args("quantize", get_32bit_sigmoid_quantizer(True)) + configure_32bit_sigmoid_quantizer(pipeline) pipeline.run() @common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone300 def test_sigmoid_u55_INT_add_sigmoid(test_data): pipeline = OpNotSupportedPipeline( SigmoidAddSigmoid(), @@ -155,7 +144,7 @@ def test_sigmoid_u55_INT_add_sigmoid(test_data): u55_subset=True, tosa_extensions=["int16"], ) - pipeline.change_args("quantize", get_32bit_sigmoid_quantizer(True)) + configure_32bit_sigmoid_quantizer(pipeline) pipeline.run() @@ -167,9 +156,8 @@ def test_sigmoid_u85_INT(test_data): (test_data(),), Sigmoid.aten_op, Sigmoid.exir_op, - run_on_fvp=True, ) - pipeline.change_args("quantize", get_32bit_sigmoid_quantizer()) + configure_32bit_sigmoid_quantizer(pipeline) pipeline.run() @@ -184,7 +172,6 @@ def test_sigmoid_u85_INT_add_sigmoid(test_data): (test_data(),), Sigmoid.aten_op, Sigmoid.exir_op, - run_on_fvp=True, ) - pipeline.change_args("quantize", get_32bit_sigmoid_quantizer()) + configure_32bit_sigmoid_quantizer(pipeline) pipeline.run() diff --git a/backends/arm/test/ops/test_sign.py b/backends/arm/test/ops/test_sign.py index 35ea9fc3e45..dd4f28981de 100644 --- a/backends/arm/test/ops/test_sign.py +++ b/backends/arm/test/ops/test_sign.py @@ -89,25 +89,25 @@ def test_sign_u85_INT(test_data: Tuple): @common.parametrize("test_data", test_data_suite) 
@common.SkipIfNoModelConverter -def test_sign_vgf_FP(test_data: Tuple): +def test_sign_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( Sign(), (test_data,), aten_op=aten_op, exir_op=exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_sign_vgf_INT(test_data: Tuple): +def test_sign_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( Sign(), (test_data,), aten_op=[], exir_op=exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_silu.py b/backends/arm/test/ops/test_silu.py index edc7d769be1..62e960b750e 100644 --- a/backends/arm/test/ops/test_silu.py +++ b/backends/arm/test/ops/test_silu.py @@ -79,7 +79,9 @@ def test_silu_tosa_INT_inplace(test_data: input_t): def test_silu_u55_INT(test_data: input_t): silu_data = (test_data(), False) pipeline = EthosU55PipelineINT[input_t]( - Silu(), silu_data, Silu.aten_op_INT, run_on_fvp=True + Silu(), + silu_data, + Silu.aten_op_INT, ) pipeline.run() @@ -89,7 +91,9 @@ def test_silu_u55_INT(test_data: input_t): def test_silu_u55_INT_inplace(test_data: input_t): silu_data = (test_data(), True) pipeline = EthosU55PipelineINT[input_t]( - Silu(), silu_data, Silu.aten_op_INT, run_on_fvp=True + Silu(), + silu_data, + Silu.aten_op_INT, ) pipeline.run() @@ -99,7 +103,9 @@ def test_silu_u55_INT_inplace(test_data: input_t): def test_silu_u85_INT(test_data: input_t): silu_data = (test_data(), False) pipeline = EthosU85PipelineINT[input_t]( - Silu(), silu_data, Silu.aten_op_INT, run_on_fvp=True + Silu(), + silu_data, + Silu.aten_op_INT, ) pipeline.run() @@ -109,52 +115,60 @@ def test_silu_u85_INT(test_data: input_t): def test_silu_u85_INT_inplace(test_data: input_t): silu_data = (test_data(), True) pipeline = EthosU85PipelineINT[input_t]( - Silu(), silu_data, Silu.aten_op_INT, run_on_fvp=True + Silu(), + silu_data, + Silu.aten_op_INT, ) pipeline.run() @common.parametrize("test_data", Silu.test_data) @common.SkipIfNoModelConverter -def test_silu_vgf_FP(test_data: input_t): +def test_silu_vgf_no_quant(test_data: input_t): silu_data = (test_data(), False) pipeline = VgfPipeline[input_t]( - Silu(), silu_data, Silu.aten_op_FP, tosa_version="TOSA-1.0+FP" + Silu(), + silu_data, + Silu.aten_op_FP, + quantize=False, ) pipeline.run() @common.parametrize("test_data", Silu.test_data) @common.SkipIfNoModelConverter -def test_silu_vgf_FP_inplace(test_data: input_t): +def test_silu_inplace_vgf_no_quant(test_data: input_t): silu_data = (test_data(), True) pipeline = VgfPipeline[input_t]( - Silu(), silu_data, Silu.aten_op_inplace_FP, tosa_version="TOSA-1.0+FP" + Silu(), + silu_data, + Silu.aten_op_inplace_FP, + quantize=False, ) pipeline.run() @common.parametrize("test_data", Silu.test_data) @common.SkipIfNoModelConverter -def test_silu_vgf_INT(test_data: input_t): +def test_silu_vgf_quant(test_data: input_t): silu_data = (test_data(), False) pipeline = VgfPipeline[input_t]( Silu(), silu_data, Silu.aten_op_INT, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @common.parametrize("test_data", Silu.test_data) @common.SkipIfNoModelConverter -def test_silu_vgf_INT_inplace(test_data: input_t): +def test_silu_inplace_vgf_quant(test_data: input_t): silu_data = (test_data(), True) pipeline = VgfPipeline[input_t]( Silu(), silu_data, Silu.aten_op_INT, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_sin.py 
b/backends/arm/test/ops/test_sin.py index 3ca593ad608..05cc8f5534d 100644 --- a/backends/arm/test/ops/test_sin.py +++ b/backends/arm/test/ops/test_sin.py @@ -61,45 +61,48 @@ def test_sin_tosa_INT(test_data: Tuple): @common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone300 def test_sin_u55_INT(test_data: Tuple): pipeline = EthosU55PipelineINT[input_t1]( Sin(), (test_data,), aten_op, exir_ops=[], - run_on_fvp=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone320 def test_sin_u85_INT(test_data: Tuple): pipeline = EthosU85PipelineINT[input_t1]( Sin(), (test_data,), aten_op, exir_ops=[], - run_on_fvp=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_sin_vgf_FP(test_data: Tuple): +def test_sin_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( - Sin(), (test_data,), aten_op, tosa_version="TOSA-1.0+FP" + Sin(), + (test_data,), + aten_op, + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_sin_vgf_INT(test_data: Tuple): +def test_sin_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( Sin(), (test_data,), aten_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_sinh.py b/backends/arm/test/ops/test_sinh.py index a059ce0ad26..703d3e52011 100644 --- a/backends/arm/test/ops/test_sinh.py +++ b/backends/arm/test/ops/test_sinh.py @@ -81,20 +81,23 @@ def test_sinh_u85_INT(test_data: Tuple): @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_sinh_vgf_FP(test_data: Tuple): +def test_sinh_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( - Sinh(), (test_data,), aten_op, tosa_version="TOSA-1.0+FP" + Sinh(), + (test_data,), + aten_op, + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_sinh_vgf_INT(test_data: Tuple): +def test_sinh_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( Sinh(), (test_data,), aten_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_slice.py b/backends/arm/test/ops/test_slice.py index eafeb04320e..ab5bafdef32 100644 --- a/backends/arm/test/ops/test_slice.py +++ b/backends/arm/test/ops/test_slice.py @@ -7,7 +7,6 @@ from typing import Tuple -import pytest import torch from executorch.backends.arm.quantizer.arm_quantizer import ( get_symmetric_a16w8_quantization_config, @@ -34,17 +33,16 @@ test_data_suite = { "ones_neg_3": lambda: (torch.ones(10), [(3, -3)]), "ones_neg_8": lambda: (torch.ones(10), [(-8, 3)]), - "ones_slice_2": lambda: (torch.ones(10, 10), [(1, 3), (3, None)]), - "ones_slice_3": lambda: (torch.ones(10, 10, 10), [(0, 7), (0, None), (0, 8)]), + "ones_slice_2": lambda: (torch.ones(10, 10), [(1, 3), (3, 10)]), + "ones_slice_3": lambda: (torch.ones(10, 10, 10), [(0, 7), (0, 10), (0, 8)]), "ones_slice_4": lambda: ( torch.ones((1, 12, 10, 10)), - [(None, None), (None, 5), (3, 5), (4, 10)], + [(0, 1), (0, 5), (3, 5), (4, 10)], ), } class Slice(torch.nn.Module): - def forward(self, x: torch.Tensor, s: list[tuple[int, int]]): slices = [slice(*i) for i in s] return x[slices] @@ -79,51 +77,51 @@ def test_slice_tensor_tosa_INT_nhwc(test_data: torch.Tensor): @common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone300 def test_slice_tensor_u55_INT(test_data: torch.Tensor): pipeline 
= EthosU55PipelineINT[input_t1]( Slice(), test_data(), aten_ops=[], exir_ops=[], - run_on_fvp=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone320 def test_slice_tensor_u85_INT(test_data: torch.Tensor): pipeline = EthosU85PipelineINT[input_t1]( Slice(), test_data(), aten_ops=[], exir_ops=[], - run_on_fvp=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_slice_tensor_vgf_FP(test_data: torch.Tensor): +def test_slice_tensor_vgf_no_quant(test_data: torch.Tensor): pipeline = VgfPipeline[input_t1]( Slice(), test_data(), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_slice_tensor_vgf_INT(test_data: torch.Tensor): +def test_slice_tensor_vgf_quant(test_data: torch.Tensor): pipeline = VgfPipeline[input_t1]( Slice(), test_data(), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @@ -148,9 +146,6 @@ def get_symmetric_a16w8_slice_quantizer(per_channel_quantization=False): @common.parametrize("test_data", test_data_suite) -@pytest.mark.xfail( - reason="missing int16 slice ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13976" -) def test_slice_tensor_16a8w_tosa_INT(test_data: torch.Tensor): """Test slice operation with 16A8W quantization (16-bit activations, 8-bit weights)""" per_channel_quantization = False @@ -176,9 +171,6 @@ def test_slice_tensor_16a8w_tosa_INT(test_data: torch.Tensor): @common.parametrize("test_data", test_data_suite) @common.XfailIfNoCorstone300 -@pytest.mark.xfail( - reason="Vela compilation fails with 'Invalid arguments' for int16 slice operations" -) def test_slice_tensor_16a8w_u55_INT16(test_data: torch.Tensor): """Test slice operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)""" per_channel_quantization = False @@ -190,7 +182,6 @@ def test_slice_tensor_16a8w_u55_INT16(test_data: torch.Tensor): exir_ops=[], per_channel_quantization=per_channel_quantization, use_to_edge_transform_and_lower=True, - run_on_fvp=True, ) pipeline.change_args( @@ -204,9 +195,6 @@ def test_slice_tensor_16a8w_u55_INT16(test_data: torch.Tensor): @common.parametrize("test_data", test_data_suite) @common.XfailIfNoCorstone320 -@pytest.mark.xfail( - reason="Vela compilation fails with 'Invalid arguments' for int16 slice operations" -) def test_slice_tensor_16a8w_u85_INT16(test_data: torch.Tensor): """Test slice operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" per_channel_quantization = False @@ -218,7 +206,6 @@ def test_slice_tensor_16a8w_u85_INT16(test_data: torch.Tensor): exir_ops=[], per_channel_quantization=per_channel_quantization, use_to_edge_transform_and_lower=True, - run_on_fvp=True, ) pipeline.change_args( diff --git a/backends/arm/test/ops/test_softmax.py b/backends/arm/test/ops/test_softmax.py index dc258f20ec4..0b2af23d10c 100644 --- a/backends/arm/test/ops/test_softmax.py +++ b/backends/arm/test/ops/test_softmax.py @@ -61,17 +61,15 @@ def test_softmax_tosa_INT(test_data): pipeline.run() -@common.parametrize( - "test_data", - Softmax.test_data, - { - "randn_neg_dim": "MLBEDSW-11032: ILLEGAL_OFM_BASE error: Base addresses must be aligned to brick depth on u55." 
- }, -) +@common.parametrize("test_data", Softmax.test_data) @common.XfailIfNoCorstone300 def test_softmax_u55_INT(test_data): data, dim = test_data() - pipeline = EthosU55PipelineINT[input_t1](Softmax(dim), data, [], run_on_fvp=True) + pipeline = EthosU55PipelineINT[input_t1]( + Softmax(dim), + data, + [], + ) pipeline.add_stage_after("quantize", pipeline.tester.check_not, [aten_op]) pipeline.change_args("run_method_and_compare_outputs", qtol=1) pipeline.run() @@ -81,7 +79,11 @@ def test_softmax_u55_INT(test_data): @common.XfailIfNoCorstone320 def test_softmax_u85_INT(test_data): data, dim = test_data() - pipeline = EthosU85PipelineINT[input_t1](Softmax(dim), data, [], run_on_fvp=True) + pipeline = EthosU85PipelineINT[input_t1]( + Softmax(dim), + data, + [], + ) pipeline.add_stage_after("quantize", pipeline.tester.check_not, [aten_op]) pipeline.change_args("run_method_and_compare_outputs", qtol=1) pipeline.run() @@ -89,13 +91,13 @@ def test_softmax_u85_INT(test_data): @common.parametrize("test_data", Softmax.test_data) @common.SkipIfNoModelConverter -def test_softmax_vgf_FP(test_data): +def test_softmax_vgf_no_quant(test_data): data, dim = test_data() pipeline = VgfPipeline[input_t1]( Softmax(dim), data, [], - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.add_stage_after( "to_edge_transform_and_lower", pipeline.tester.check_not, [exir_op] @@ -105,13 +107,13 @@ def test_softmax_vgf_FP(test_data): @common.parametrize("test_data", Softmax.test_data) @common.SkipIfNoModelConverter -def test_softmax_vgf_INT(test_data): +def test_softmax_vgf_quant(test_data): data, dim = test_data() pipeline = VgfPipeline[input_t1]( Softmax(dim), data, [], - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.add_stage_after("quantize", pipeline.tester.check_not, [aten_op]) # TODO: MLETORCH-1136 Change args of run_method_and_compare_outputs of the vgf tests diff --git a/backends/arm/test/ops/test_split.py b/backends/arm/test/ops/test_split.py index 388e85762af..a1028fd07ef 100644 --- a/backends/arm/test/ops/test_split.py +++ b/backends/arm/test/ops/test_split.py @@ -22,7 +22,6 @@ class Split(torch.nn.Module): - test_data = { "split_1d_2_size_0_dim": lambda: (torch.rand(10), 2, 0), "split_2d_3_size_1_dim": lambda: (torch.rand(10, 10), 3, 1), @@ -60,12 +59,24 @@ def forward( return x.split(split_size=split_size_or_sections, dim=dim)[1:3] +class SplitCopy(torch.nn.Module): + aten_op = "torch.ops.aten.split_copy.Tensor" + exir_op = "executorch_exir_dialects_edge__ops_aten_split_copy_Tensor" + + def forward( + self, + x: torch.Tensor, + split_size: int, + dim: int, + ): + return torch.split_copy(x, split_size=split_size, dim=dim) + + @common.parametrize( "test_data", (Split.test_data | Split.test_data_list), ) def test_split_with_sizes_tosa_FP(test_data: input_t1): - pipeline = TosaPipelineFP[input_t1]( Split(), test_data(), @@ -77,7 +88,6 @@ def test_split_with_sizes_tosa_FP(test_data: input_t1): @common.parametrize("test_data", Split.test_data_list) def test_split_with_sizes_tosa_FP_2(test_data: input_t1): - pipeline = TosaPipelineFP[input_t1]( SplitWithSizes(), test_data(), @@ -92,7 +102,6 @@ def test_split_with_sizes_tosa_FP_2(test_data: input_t1): (Split.test_data | Split.test_data_list), ) def test_split_with_sizes_tosa_FP_one_out(test_data: input_t1): - pipeline = TosaPipelineFP[input_t1]( SplitSingleOut(), test_data(), @@ -107,7 +116,6 @@ def test_split_with_sizes_tosa_FP_one_out(test_data: input_t1): (Split.test_data | Split.test_data_list), ) def 
test_split_with_sizes_tosa_FP_two_out(test_data: input_t1): - pipeline = TosaPipelineFP[input_t1]( SplitTwoOut(), test_data(), @@ -122,7 +130,6 @@ def test_split_with_sizes_tosa_FP_two_out(test_data: input_t1): (Split.test_data | Split.test_data_list), ) def test_split_with_sizes_tosa_INT(test_data: input_t1): - pipeline = TosaPipelineINT[input_t1]( Split(), test_data(), @@ -132,33 +139,26 @@ def test_split_with_sizes_tosa_INT(test_data: input_t1): pipeline.run() -@common.parametrize( - "test_data", - (Split.test_data | Split.test_data_list), -) +@common.parametrize("test_data", (Split.test_data | Split.test_data_list)) +@common.XfailIfNoCorstone300 def test_split_with_sizes_u55_INT(test_data: input_t1): pipeline = EthosU55PipelineINT[input_t1]( Split(), test_data(), aten_ops=[], exir_ops=exir_op, - run_on_fvp=False, ) pipeline.run() -@common.parametrize( - "test_data", - (Split.test_data | Split.test_data_list), -) +@common.parametrize("test_data", (Split.test_data | Split.test_data_list)) +@common.XfailIfNoCorstone320 def test_split_with_sizes_u85_INT(test_data: input_t1): - pipeline = EthosU85PipelineINT[input_t1]( Split(), test_data(), aten_ops=[], exir_ops=exir_op, - run_on_fvp=False, ) pipeline.run() @@ -168,27 +168,26 @@ def test_split_with_sizes_u85_INT(test_data: input_t1): (Split.test_data | Split.test_data_list), ) @common.SkipIfNoModelConverter -def test_split_with_sizes_vgf_FP(test_data: input_t1): +def test_split_with_sizes_vgf_no_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( Split(), test_data(), aten_op=[], exir_op=exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", Split.test_data_list) @common.SkipIfNoModelConverter -def test_split_with_sizes_vgf_FP_2(test_data: input_t1): - +def test_split_with_sizes_2_vgf_no_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( SplitWithSizes(), test_data(), aten_op=[], exir_op=exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @@ -198,14 +197,13 @@ def test_split_with_sizes_vgf_FP_2(test_data: input_t1): (Split.test_data | Split.test_data_list), ) @common.SkipIfNoModelConverter -def test_split_with_sizes_vgf_FP_one_out(test_data: input_t1): - +def test_split_with_sizes_one_out_vgf_no_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( SplitSingleOut(), test_data(), aten_op=[], exir_op=exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @@ -215,14 +213,13 @@ def test_split_with_sizes_vgf_FP_one_out(test_data: input_t1): (Split.test_data | Split.test_data_list), ) @common.SkipIfNoModelConverter -def test_split_with_sizes_vgf_FP_two_out(test_data: input_t1): - +def test_split_with_sizes_two_out_vgf_no_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( SplitTwoOut(), test_data(), aten_op=[], exir_op=exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @@ -232,13 +229,84 @@ def test_split_with_sizes_vgf_FP_two_out(test_data: input_t1): (Split.test_data | Split.test_data_list), ) @common.SkipIfNoModelConverter -def test_split_with_sizes_vgf_INT(test_data: input_t1): - +def test_split_with_sizes_vgf_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( Split(), test_data(), aten_op=[], exir_op=exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, + ) + pipeline.run() + + +@common.parametrize("test_data", Split.test_data) +def test_split_tensor_tosa_FP(test_data: Tuple): + pipeline = TosaPipelineFP[input_t1]( + SplitCopy(), + test_data(), + aten_op=SplitCopy.aten_op, + 
exir_op=SplitCopy.exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", Split.test_data) +def test_split_tensor_tosa_INT(test_data: Tuple): + pipeline = TosaPipelineINT[input_t1]( + SplitCopy(), + test_data(), + aten_op=SplitCopy.aten_op, + exir_op=SplitCopy.exir_op, + ) + pipeline.run() + + +@common.XfailIfNoCorstone300 +@common.parametrize("test_data", Split.test_data) +def test_split_tensor_u55_INT(test_data: Tuple): + pipeline = EthosU55PipelineINT[input_t1]( + SplitCopy(), + test_data(), + aten_ops=SplitCopy.aten_op, + exir_ops=SplitCopy.exir_op, + ) + pipeline.run() + + +@common.XfailIfNoCorstone320 +@common.parametrize("test_data", Split.test_data) +def test_split_tensor_u85_INT(test_data: Tuple): + pipeline = EthosU85PipelineINT[input_t1]( + SplitCopy(), + test_data(), + aten_ops=SplitCopy.aten_op, + exir_ops=SplitCopy.exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", Split.test_data) +@common.SkipIfNoModelConverter +def test_split_tensor_vgf_no_quant(test_data: Tuple): + pipeline = VgfPipeline[input_t1]( + SplitCopy(), + test_data(), + aten_op=SplitCopy.aten_op, + exir_op=SplitCopy.exir_op, + quantize=False, + ) + pipeline.run() + + +@common.parametrize("test_data", Split.test_data) +@common.SkipIfNoModelConverter +def test_split_tensor_vgf_quant(test_data: Tuple): + pipeline = VgfPipeline[input_t1]( + SplitCopy(), + test_data(), + aten_op=SplitCopy.aten_op, + exir_op=SplitCopy.exir_op, + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_sqrt.py b/backends/arm/test/ops/test_sqrt.py index 15e2dd45322..c3d1aae0883 100644 --- a/backends/arm/test/ops/test_sqrt.py +++ b/backends/arm/test/ops/test_sqrt.py @@ -70,7 +70,6 @@ def test_sqrt_u55_INT(test_data: Sqrt.input_t): test_data(), Sqrt.aten_op_INT, Sqrt.exir_op_INT, - run_on_fvp=True, ) pipeline.run() @@ -83,32 +82,31 @@ def test_sqrt_u85_INT(test_data: Sqrt.input_t): test_data(), Sqrt.aten_op_INT, Sqrt.exir_op_INT, - run_on_fvp=True, ) pipeline.run() @common.parametrize("test_data", Sqrt.test_data) @common.SkipIfNoModelConverter -def test_sqrt_vgf_FP(test_data: Sqrt.input_t): +def test_sqrt_vgf_no_quant(test_data: Sqrt.input_t): pipeline = VgfPipeline[Sqrt.input_t]( Sqrt(), test_data(), Sqrt.aten_op_FP, Sqrt.exir_op_FP, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", Sqrt.test_data) @common.SkipIfNoModelConverter -def test_sqrt_vgf_INT(test_data: Sqrt.input_t): +def test_sqrt_vgf_quant(test_data: Sqrt.input_t): pipeline = VgfPipeline[Sqrt.input_t]( Sqrt(), test_data(), Sqrt.aten_op_INT, Sqrt.exir_op_INT, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_squeeze.py b/backends/arm/test/ops/test_squeeze.py index 5c9f031deec..696c677b057 100644 --- a/backends/arm/test/ops/test_squeeze.py +++ b/backends/arm/test/ops/test_squeeze.py @@ -29,6 +29,7 @@ class SqueezeDim(torch.nn.Module): "squeeze3d_dim_neg_2": lambda: (torch.randn(1, 1, 5), -2), "squeeze4d_dim_pos_3": lambda: (torch.randn(1, 2, 3, 1), 3), "squeeze4d_dim_neg_2": lambda: (torch.randn(1, 5, 1, 5), -2), + "squeeze5d_dim_neg_2": lambda: (torch.randn(1, 1, 5, 1, 5), -2), } def forward(self, x: torch.Tensor, dim: int): @@ -40,6 +41,7 @@ class SqueezeDims(torch.nn.Module): "squeeze3d_dims_0_1": lambda: (torch.randn(1, 1, 5), (0, 1)), "squeeze4d_dims_0_neg_1": lambda: (torch.randn(1, 5, 5, 1), (0, -1)), "squeeze4d_dims_0_neg_2": lambda: (torch.randn(1, 5, 1, 5), (0, -2)), + "squeeze5d_dims_0_neg_2": lambda: (torch.randn(1, 1, 5, 
1, 5), (0, -2)), } def forward(self, x: torch.Tensor, dims: tuple[int]): @@ -51,6 +53,7 @@ class Squeeze(torch.nn.Module): "squeeze3d": lambda: (torch.randn(1, 1, 5),), "squeeze4d_dims": lambda: (torch.randn(1, 5, 5, 1),), "squeeze3d_dims_mix": lambda: (torch.randn(1, 5, 1, 5),), + "squeeze4d_dims_mix": lambda: (torch.randn(1, 1, 5, 1, 5),), } def forward(self, x: torch.Tensor): @@ -92,7 +95,6 @@ def test_squeeze_dim_u55_INT(test_data: Tuple): test_data(), aten_ops="torch.ops.aten.squeeze.default", exir_ops=[], - run_on_fvp=True, ) pipeline.run() @@ -105,33 +107,32 @@ def test_squeeze_dim_u85_INT(test_data: Tuple): test_data(), aten_ops="torch.ops.aten.squeeze.default", exir_ops=[], - run_on_fvp=True, ) pipeline.run() @common.parametrize("test_data", Squeeze.test_parameters) @common.SkipIfNoModelConverter -def test_squeeze_dim_vgf_FP(test_data: Tuple): +def test_squeeze_dim_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( Squeeze(), test_data(), "torch.ops.aten.squeeze.default", [], - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", Squeeze.test_parameters) @common.SkipIfNoModelConverter -def test_squeeze_dim_vgf_INT(test_data: Tuple): +def test_squeeze_dim_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( Squeeze(), test_data(), "torch.ops.aten.squeeze.default", [], - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @@ -171,7 +172,6 @@ def test_squeeze_dim_u55_INT_2(test_data: Tuple): test_data(), aten_ops="torch.ops.aten.squeeze.dim", exir_ops=[], - run_on_fvp=True, ) pipeline.run() @@ -184,33 +184,32 @@ def test_squeeze_dim_u85_INT_2(test_data: Tuple): test_data(), aten_ops="torch.ops.aten.squeeze.dim", exir_ops=[], - run_on_fvp=True, ) pipeline.run() @common.parametrize("test_data", SqueezeDim.test_parameters) @common.SkipIfNoModelConverter -def test_squeeze_dim_vgf_FP_2(test_data: Tuple): +def test_squeeze_dim_2_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( SqueezeDim(), test_data(), "torch.ops.aten.squeeze.dim", [], - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", SqueezeDim.test_parameters) @common.SkipIfNoModelConverter -def test_squeeze_dim_vgf_INT_2(test_data: Tuple): +def test_squeeze_dim_2_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( SqueezeDim(), test_data(), "torch.ops.aten.squeeze.dim", [], - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @@ -250,7 +249,6 @@ def test_squeeze_dims_u55_INT(test_data: Tuple): test_data(), aten_ops="torch.ops.aten.squeeze.dims", exir_ops=[], - run_on_fvp=True, ) pipeline.run() @@ -263,32 +261,31 @@ def test_squeeze_dims_u85_INT(test_data: Tuple): test_data(), aten_ops="torch.ops.aten.squeeze.dims", exir_ops=[], - run_on_fvp=True, ) pipeline.run() @common.parametrize("test_data", SqueezeDims.test_parameters) @common.SkipIfNoModelConverter -def test_squeeze_dims_vgf_FP(test_data: Tuple): +def test_squeeze_dims_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( SqueezeDims(), test_data(), "torch.ops.aten.squeeze.dims", [], - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", SqueezeDims.test_parameters) @common.SkipIfNoModelConverter -def test_squeeze_dims_vgf_INT(test_data: Tuple): +def test_squeeze_dims_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( SqueezeDims(), test_data(), "torch.ops.aten.squeeze.dims", [], - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff 
--git a/backends/arm/test/ops/test_stack.py b/backends/arm/test/ops/test_stack.py new file mode 100644 index 00000000000..a3911a62b01 --- /dev/null +++ b/backends/arm/test/ops/test_stack.py @@ -0,0 +1,150 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +import torch.nn as nn + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineINT, + EthosU85PipelineINT, + TosaPipelineFP, + TosaPipelineINT, + VgfPipeline, +) + +test_data_suite = { + # (test_name, test_data) + "ones_two_tensors": lambda: ((torch.ones(1), torch.ones(1)), 0), + "ones_and_rand_three_tensors": lambda: ( + (torch.ones(1, 2), torch.randn(1, 2), torch.randn(1, 2)), + 1, + ), + "ones_and_rand_four_tensors": lambda: ( + ( + torch.ones(1, 2, 5), + torch.randn(1, 2, 5), + torch.randn(1, 2, 5), + torch.randn(1, 2, 5), + ), + -1, + ), + "rand_two_tensors": lambda: ( + (torch.randn(2, 2, 4), torch.randn(2, 2, 4)), + 2, + ), + "rand_two_tensors_dim_0": lambda: ( + (torch.randn(1, 2, 4, 4), torch.randn(1, 2, 4, 4)), + ), + "rand_two_tensors_dim_2": lambda: ( + (torch.randn(2, 2, 3, 5), torch.randn(2, 2, 3, 5)), + 2, + ), + "rand_large": lambda: ( + ( + 10000 * torch.randn(2, 3, 1, 4), + torch.randn(2, 3, 1, 4), + torch.randn(2, 3, 1, 4), + ), + -3, + ), +} + + +class Stack(nn.Module): + aten_op = "torch.ops.aten.stack.default" + exir_op = "executorch_exir_dialects_edge__ops_aten_cat_default" + + def forward(self, n: tuple[torch.Tensor, ...], dim: int = 0): + return torch.stack(n, dim) + + +input_t1 = Tuple[torch.Tensor] + + +@common.parametrize("test_module", test_data_suite) +def test_stack_tosa_FP(test_module: input_t1): + test_data = test_module() + pipeline = TosaPipelineFP[input_t1]( + Stack(), + test_data, + aten_op=Stack.aten_op, + exir_op=Stack.exir_op, + use_to_edge_transform_and_lower=False, + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_suite) +def test_stack_tosa_INT(test_module: input_t1): + test_data = test_module() + pipeline = TosaPipelineINT[input_t1]( + Stack(), + test_data, + aten_op=Stack.aten_op, + exir_op=Stack.exir_op, + use_to_edge_transform_and_lower=False, + ) + pipeline.run() + + +@common.XfailIfNoCorstone300 +@common.parametrize("test_module", test_data_suite) +def test_stack_u55_INT(test_module: input_t1): + test_data = test_module() + pipeline = EthosU55PipelineINT[input_t1]( + Stack(), + test_data, + aten_ops=Stack.aten_op, + exir_ops=Stack.exir_op, + use_to_edge_transform_and_lower=False, + ) + pipeline.run() + + +@common.XfailIfNoCorstone320 +@common.parametrize("test_module", test_data_suite) +def test_stack_u85_INT(test_module: input_t1): + test_data = test_module() + pipeline = EthosU85PipelineINT[input_t1]( + Stack(), + test_data, + aten_ops=Stack.aten_op, + exir_ops=Stack.exir_op, + use_to_edge_transform_and_lower=False, + ) + pipeline.run() + + +@common.SkipIfNoModelConverter +@common.parametrize("test_module", test_data_suite) +def test_stack_vgf_no_quant(test_module: input_t1): + test_data = test_module() + pipeline = VgfPipeline[input_t1]( + Stack(), + test_data, + aten_op=Stack.aten_op, + exir_op=Stack.exir_op, + use_to_edge_transform_and_lower=False, + quantize=False, + ) + pipeline.run() + + +@common.SkipIfNoModelConverter +@common.parametrize("test_module", test_data_suite) +def 
test_stack_vgf_quant(test_module: input_t1): + test_data = test_module() + pipeline = VgfPipeline[input_t1]( + Stack(), + test_data, + aten_op=Stack.aten_op, + exir_op=Stack.exir_op, + use_to_edge_transform_and_lower=False, + quantize=True, + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_sub.py b/backends/arm/test/ops/test_sub.py index c691506beb2..f18f6525d27 100644 --- a/backends/arm/test/ops/test_sub.py +++ b/backends/arm/test/ops/test_sub.py @@ -10,8 +10,12 @@ from typing import Tuple import torch +from executorch.backends.arm.quantizer.arm_quantizer import ( + get_symmetric_a16w8_quantization_config, + TOSAQuantizer, +) -from executorch.backends.arm.test import common +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, EthosU85PipelineINT, @@ -19,6 +23,8 @@ TosaPipelineINT, VgfPipeline, ) +from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.xnnpack.test.tester import Quantize aten_op = "torch.ops.aten.sub.Tensor" exir_op = "executorch_exir_dialects_edge__ops_aten_sub_Tensor" @@ -73,6 +79,11 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): return x - y +class SubAlpha(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor): + return torch.sub(x, y, alpha=5) + + class SubTan(torch.nn.Module): def forward(self, x: torch.Tensor, y: torch.Tensor): @@ -109,6 +120,18 @@ def test_sub_tensor_tosa_FP_2(test_data: Tuple[torch.Tensor, torch.Tensor]): pipeline.run() +@common.parametrize("test_data", sub_tan_test_data) +def test_sub_tensor_tosa_FP_alpha(test_data: Tuple[torch.Tensor, torch.Tensor]): + """Test Two-Operand Subtraction with alpha (TOSA FP)""" + pipeline = TosaPipelineFP[input_t2]( + SubAlpha(), + test_data(), + aten_op, + exir_op, + ) + pipeline.run() + + @common.parametrize("test_data", sub_test_data) def test_sub_tensor_tosa_INT(test_data): """Test Subtraction (TOSA INT)""" @@ -126,12 +149,22 @@ def test_sub_tensor_tosa_INT_2(test_data: Tuple[torch.Tensor, torch.Tensor]): @common.parametrize("test_data", sub_tan_test_data) def test_sub_tensor_tosa_INT_3(test_data: Tuple[torch.Tensor, torch.Tensor]): """Test Two-Operand Subtraction (TOSA INT)""" + # This test has only been added to the tosa INT profile in order to catch quantization-induced errors. 
pipeline = TosaPipelineINT[input_t2]( SubTan(), test_data(), aten_op, exir_op, qtol=0 ) pipeline.run() +@common.parametrize("test_data", sub_tan_test_data) +def test_sub_tensor_tosa_INT_alpha(test_data: Tuple[torch.Tensor, torch.Tensor]): + """Test Two-Operand Subtraction with alpha (TOSA INT)""" + pipeline = TosaPipelineINT[input_t2]( + SubAlpha(), test_data(), aten_op, exir_op, qtol=0 + ) + pipeline.run() + + @common.parametrize("test_data", sub_test_data) @common.XfailIfNoCorstone300 def test_sub_tensor_u55_INT(test_data): @@ -141,7 +174,6 @@ def test_sub_tensor_u55_INT(test_data): test_data(), aten_op, exir_op, - run_on_fvp=True, ) pipeline.run() @@ -155,7 +187,6 @@ def test_sub_tensor_u55_INT_2(test_data: Tuple[torch.Tensor, torch.Tensor]): test_data(), aten_op, exir_op, - run_on_fvp=True, ) pipeline.run() @@ -169,7 +200,6 @@ def test_sub_tensor_u85_INT_2(test_data): test_data(), aten_op, exir_op, - run_on_fvp=True, ) pipeline.run() @@ -183,62 +213,154 @@ def test_sub_tensor_u85_INT(test_data: Tuple[torch.Tensor, torch.Tensor]): test_data(), aten_op, exir_op, - run_on_fvp=True, ) pipeline.run() @common.parametrize("test_data", sub_test_data) @common.SkipIfNoModelConverter -def test_sub_tensor_vgf_FP(test_data: Tuple[torch.Tensor]): +def test_sub_tensor_vgf_no_quant(test_data: Tuple[torch.Tensor]): """Test Subtraction (VGF FP)""" pipeline = VgfPipeline[input_t1]( Sub(), test_data(), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", sub2_test_data) @common.SkipIfNoModelConverter -def test_sub_tensor_vgf_FP_2(test_data: Tuple[torch.Tensor, torch.Tensor]): +def test_sub_tensor_2_vgf_no_quant(test_data: Tuple[torch.Tensor, torch.Tensor]): """Test Two-Operand Subtraction (VGF FP)""" pipeline = VgfPipeline[input_t2]( Sub2(), test_data(), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", sub_test_data) @common.SkipIfNoModelConverter -def test_sub_tensor_vgf_INT(test_data: Tuple[torch.Tensor]): +def test_sub_tensor_vgf_quant(test_data: Tuple[torch.Tensor]): """Test Subtraction (VGF INT)""" pipeline = VgfPipeline[input_t1]( Sub(), test_data(), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @common.parametrize("test_data", sub2_test_data) @common.SkipIfNoModelConverter -def test_sub_tensor_vgf_INT_2(test_data: Tuple[torch.Tensor, torch.Tensor]): +def test_sub_tensor_2_vgf_quant(test_data: Tuple[torch.Tensor, torch.Tensor]): """Test Two-Operand Subtraction (VGF INT)""" pipeline = VgfPipeline[input_t2]( Sub2(), test_data(), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, + ) + pipeline.run() + + +def get_symmetric_a16w8_sub_quantizer(per_channel_quantization=False): + tosa_version = conftest.get_option("tosa_version") + tosa_profiles = { + "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"), + } + + quantizer = TOSAQuantizer(tosa_profiles[tosa_version]) + quantizer.set_global( + get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization) + ) + + return Quantize( + quantizer, + get_symmetric_a16w8_quantization_config( + is_per_channel=per_channel_quantization + ), + ) + + +@common.parametrize("test_data", sub_test_data) +def test_sub_tensor_16a8w_tosa_INT(test_data: input_t1): + """Test sub operation with 16A8W quantization (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = TosaPipelineINT[input_t1]( + Sub(), + test_data(), + aten_op, + 
exir_op=[], + per_channel_quantization=per_channel_quantization, + use_to_edge_transform_and_lower=True, + tosa_extensions=["int16"], + ) + + pipeline.change_args( + "quantize", + get_symmetric_a16w8_sub_quantizer( + per_channel_quantization=per_channel_quantization + ), + ) + pipeline.run() + + +@common.parametrize("test_data", sub_test_data) +@common.XfailIfNoCorstone300 +def test_sub_tensor_16a8w_u55_INT16(test_data: input_t1): + """Test sub operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU55PipelineINT[input_t1]( + Sub(), + test_data(), + aten_op, + exir_op, + per_channel_quantization=per_channel_quantization, + use_to_edge_transform_and_lower=True, + run_on_fvp=True, + ) + + pipeline.change_args( + "quantize", + get_symmetric_a16w8_sub_quantizer( + per_channel_quantization=per_channel_quantization + ), + ) + pipeline.run() + + +@common.parametrize("test_data", sub_test_data) +@common.XfailIfNoCorstone320 +def test_sub_tensor_16a8w_u85_INT16(test_data: input_t1): + """Test sub operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t1]( + Sub(), + test_data(), + aten_op, + exir_op, + per_channel_quantization=per_channel_quantization, + use_to_edge_transform_and_lower=True, + run_on_fvp=True, + ) + + pipeline.change_args( + "quantize", + get_symmetric_a16w8_sub_quantizer( + per_channel_quantization=per_channel_quantization + ), ) pipeline.run() diff --git a/backends/arm/test/ops/test_sum.py b/backends/arm/test/ops/test_sum.py index 9308315f76d..14a6eeea14c 100644 --- a/backends/arm/test/ops/test_sum.py +++ b/backends/arm/test/ops/test_sum.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Tuple +from typing import Callable, Tuple import torch from executorch.backends.arm.test import common @@ -35,6 +35,7 @@ class Sum(torch.nn.Module): "4d_dim_3_keep": lambda: (torch.rand(1, 2, 3, 4), 3, True), "4d_dims_keep": lambda: (torch.rand(1, 2, 8, 8), [2, 3, 0], True), "dim_None": lambda: (torch.rand(10), None, True), + "dim_None_4d_tensor": lambda: (torch.rand(10, 3, 2, 1), None, True), } def forward(self, x: torch.Tensor, dim: int, keepdim: bool): @@ -60,7 +61,6 @@ def test_sum_dim_intlist_tosa_INT(test_data: input_t1): aten_op, exir_op=[], ) - pipeline.dump_artifact("export") pipeline.run() @@ -72,7 +72,6 @@ def test_view_u55_INT_1_0(test_data: Tuple): test_data(), aten_op, exir_ops=[], - run_on_fvp=True, ) pipeline.run() @@ -85,28 +84,32 @@ def test_view_u85_INT_1_0(test_data: Tuple): test_data(), aten_op, exir_ops=[], - run_on_fvp=True, ) pipeline.run() @common.parametrize("test_data", Sum.test_parameters) @common.SkipIfNoModelConverter -def test_sum_dim_intlist_vgf_FP(test_data: input_t1): +def test_sum_dim_intlist_vgf_no_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( - Sum(), test_data(), aten_op, tosa_version="TOSA-1.0+FP" + Sum(), + test_data(), + aten_op, + run_on_vulkan_runtime=True, + quantize=False, ) pipeline.run() @common.parametrize("test_data", Sum.test_parameters) @common.SkipIfNoModelConverter -def test_sum_dim_intlist_vgf_INT(test_data: input_t1): +def test_sum_dim_intlist_vgf_quant(test_data: input_t1): pipeline = VgfPipeline[input_t1]( Sum(), test_data(), aten_op, - tosa_version="TOSA-1.0+INT", + run_on_vulkan_runtime=True, + quantize=True, ) pipeline.run() @@ -119,7 +122,7 @@ def test_sum_dim_intlist_vgf_INT(test_data: input_t1): @common.parametrize("test_data", reject_inputs) -def test_view_u55_INT_not_delegated(test_data: Tuple): +def test_view_u55_INT_failure_set(test_data: Tuple): pipeline = EthosU55PipelineINT[input_t1]( Sum(), test_data(), @@ -129,3 +132,30 @@ def test_view_u55_INT_not_delegated(test_data: Tuple): ) pipeline.pop_stage("check_count.exir") pipeline.run() + + +input_t2 = tuple[torch.Tensor] + + +class SumDefault(torch.nn.Module): + test_parameters = { + "rank1": lambda: (torch.rand(10),), + "rank2": lambda: (torch.rand(10, 1, 10),), + "rank4": lambda: (torch.rand(1, 1, 5, 8),), + } + aten_op = "torch.ops.aten.sum.default" + + def forward(self, x: torch.Tensor): + return x.sum() + + +@common.parametrize("test_data", SumDefault.test_parameters) +def test_sum_tosa_FP(test_data: Callable[[], input_t2]): + pipeline = TosaPipelineFP[input_t2](SumDefault(), test_data(), SumDefault.aten_op) + pipeline.run() + + +@common.parametrize("test_data", SumDefault.test_parameters) +def test_sum_tosa_INT(test_data: Callable[[], input_t2]): + pipeline = TosaPipelineINT[input_t1](SumDefault(), test_data(), SumDefault.aten_op) + pipeline.run() diff --git a/backends/arm/test/ops/test_t_copy.py b/backends/arm/test/ops/test_t_copy.py new file mode 100644 index 00000000000..705e812cd6d --- /dev/null +++ b/backends/arm/test/ops/test_t_copy.py @@ -0,0 +1,115 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
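+
+# Tests for the aten.t_copy operator in the Arm backend. torch.t_copy is
+# expected to lower to the edge permute_copy op (see TCopy.exir_op below),
+# and the suite covers the TOSA FP/INT, Ethos-U55/U85 and VGF pipelines.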
+
+from typing import Tuple
+
+import torch
+
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU55PipelineINT,
+    EthosU85PipelineINT,
+    TosaPipelineFP,
+    TosaPipelineINT,
+    VgfPipeline,
+)
+
+test_data_suite = {
+    # test_name: test_data
+    "rand": lambda: (torch.rand(2, 3),),
+    "rand_multiplied": lambda: (torch.rand(3, 4) * 10,),
+    "ones": lambda: (torch.ones(5, 10),),
+    "randn": lambda: (torch.randn(1, 10) * 2,),
+}
+
+
+class TCopy(torch.nn.Module):
+    aten_op = "torch.ops.aten.t_copy.default"
+    exir_op = "executorch_exir_dialects_edge__ops_aten_permute_copy_default"
+
+    def forward(self, x: torch.Tensor):
+        return torch.t_copy(x)
+
+
+input_t1 = Tuple[torch.Tensor]
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_t_tosa_FP(test_data: Tuple):
+    pipeline = TosaPipelineFP[input_t1](
+        TCopy(),
+        test_data(),
+        aten_op=TCopy.aten_op,
+        exir_op=TCopy.exir_op,
+        use_to_edge_transform_and_lower=False,
+    )
+
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_t_tosa_INT(test_data: Tuple):
+    pipeline = TosaPipelineINT[input_t1](
+        TCopy(),
+        test_data(),
+        aten_op=TCopy.aten_op,
+        exir_op=TCopy.exir_op,
+        use_to_edge_transform_and_lower=False,
+    )
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone300
+@common.parametrize("test_data", test_data_suite)
+def test_t_u55_INT(test_data: Tuple):
+    pipeline = EthosU55PipelineINT[input_t1](
+        TCopy(),
+        test_data(),
+        aten_ops=TCopy.aten_op,
+        exir_ops=[],
+        use_to_edge_transform_and_lower=True,
+    )
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone320
+@common.parametrize("test_data", test_data_suite)
+def test_t_u85_INT(test_data: Tuple):
+    pipeline = EthosU85PipelineINT[input_t1](
+        TCopy(),
+        test_data(),
+        aten_ops=TCopy.aten_op,
+        exir_ops=TCopy.exir_op,
+        use_to_edge_transform_and_lower=False,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+@common.SkipIfNoModelConverter
+def test_t_vgf_no_quant(test_data: Tuple):
+    pipeline = VgfPipeline[input_t1](
+        TCopy(),
+        test_data(),
+        aten_op=TCopy.aten_op,
+        exir_op=TCopy.exir_op,
+        use_to_edge_transform_and_lower=False,
+        quantize=False,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+@common.SkipIfNoModelConverter
+def test_t_vgf_quant(test_data: Tuple):
+    pipeline = VgfPipeline[input_t1](
+        TCopy(),
+        test_data(),
+        aten_op=TCopy.aten_op,
+        exir_op=TCopy.exir_op,
+        use_to_edge_transform_and_lower=False,
+        quantize=True,
+    )
+    pipeline.run()
diff --git a/backends/arm/test/ops/test_tanh.py b/backends/arm/test/ops/test_tanh.py
index 0e74618fd2f..d03fe03452b 100644
--- a/backends/arm/test/ops/test_tanh.py
+++ b/backends/arm/test/ops/test_tanh.py
@@ -8,12 +8,8 @@
 import pytest
 import torch

-from executorch.backends.arm.quantizer.arm_quantizer import (
-    get_symmetric_a16w8_quantization_config,
-    TOSAQuantizer,
-)
-from executorch.backends.arm.test import common, conftest
+from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
     EthosU85PipelineINT,
@@ -21,8 +17,6 @@
     TosaPipelineINT,
     VgfPipeline,
 )
-from executorch.backends.arm.tosa.specification import TosaSpecification
-from executorch.backends.xnnpack.test.tester import Quantize

 aten_op = "torch.ops.aten.tanh.default"
 input_t1 = Tuple[torch.Tensor]  # Input x
@@ -70,73 +64,54 @@ def test_tanh_tosa_INT(test_data: Tuple):
 @common.parametrize("test_data", test_data_suite)
+@common.XfailIfNoCorstone300
 def
test_tanh_u55_INT(test_data: Tuple): pipeline = EthosU55PipelineINT[input_t1]( Tanh(), (test_data(),), aten_op, exir_ops=[], - run_on_fvp=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone320 def test_tanh_u85_INT(test_data: Tuple): pipeline = EthosU85PipelineINT[input_t1]( Tanh(), (test_data(),), aten_op, exir_ops=[], - run_on_fvp=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_tanh_vgf_FP(test_data: Tuple): +def test_tanh_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( - Tanh(), (test_data(),), aten_op, tosa_version="TOSA-1.0+FP" + Tanh(), + (test_data(),), + aten_op, + quantize=False, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_tanh_vgf_INT(test_data: Tuple): +def test_tanh_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( Tanh(), (test_data(),), aten_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() -def get_symmetric_a16w8_tanh_quantizer(per_channel_quantization=False): - tosa_version = conftest.get_option("tosa_version") - tosa_profiles = { - "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"), - } - - quantizer = TOSAQuantizer(tosa_profiles[tosa_version]) - quantizer.set_global( - get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization) - ) - - return Quantize( - quantizer, - get_symmetric_a16w8_quantization_config( - is_per_channel=per_channel_quantization - ), - ) - - @common.parametrize("test_data", test_data_suite) -@pytest.mark.xfail( - reason="missing int16 tanh ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13975" -) def test_tanh_16a8w_tosa_INT(test_data: torch.Tensor): """Test tanh operation with 16A8W quantization (16-bit activations, 8-bit weights)""" per_channel_quantization = False @@ -149,13 +124,8 @@ def test_tanh_16a8w_tosa_INT(test_data: torch.Tensor): per_channel_quantization=per_channel_quantization, use_to_edge_transform_and_lower=True, tosa_extensions=["int16"], - ) - - pipeline.change_args( - "quantize", - get_symmetric_a16w8_tanh_quantizer( - per_channel_quantization=per_channel_quantization - ), + epsilon=2**-16, + rtol=2e-03, ) pipeline.run() @@ -163,7 +133,7 @@ def test_tanh_16a8w_tosa_INT(test_data: torch.Tensor): @common.parametrize("test_data", test_data_suite) @common.XfailIfNoCorstone300 @pytest.mark.xfail( - reason="Vela compilation fails with 'Invalid arguments' for int16 tanh operations" + reason="MLETORCH-707: AssertionError: Output 0 does not match reference output." 
) def test_tanh_16a8w_u55_INT16(test_data: torch.Tensor): """Test tanh operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)""" @@ -176,23 +146,15 @@ def test_tanh_16a8w_u55_INT16(test_data: torch.Tensor): exir_ops=[], per_channel_quantization=per_channel_quantization, use_to_edge_transform_and_lower=True, - run_on_fvp=True, - ) - - pipeline.change_args( - "quantize", - get_symmetric_a16w8_tanh_quantizer( - per_channel_quantization=per_channel_quantization - ), + a16w8_quantization=True, + epsilon=2**-16, + rtol=2e-03, ) pipeline.run() @common.parametrize("test_data", test_data_suite) @common.XfailIfNoCorstone320 -@pytest.mark.xfail( - reason="Vela compilation fails with 'Invalid arguments' for int16 tanh operations" -) def test_tanh_16a8w_u85_INT16(test_data: torch.Tensor): """Test tanh operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" per_channel_quantization = False @@ -204,13 +166,8 @@ def test_tanh_16a8w_u85_INT16(test_data: torch.Tensor): exir_ops=[], per_channel_quantization=per_channel_quantization, use_to_edge_transform_and_lower=True, - run_on_fvp=True, - ) - - pipeline.change_args( - "quantize", - get_symmetric_a16w8_tanh_quantizer( - per_channel_quantization=per_channel_quantization - ), + a16w8_quantization=True, + epsilon=2**-16, + rtol=2e-03, ) pipeline.run() diff --git a/backends/arm/test/ops/test_to_copy.py b/backends/arm/test/ops/test_to_copy.py index 5c01788c805..17db2c3f226 100644 --- a/backends/arm/test/ops/test_to_copy.py +++ b/backends/arm/test/ops/test_to_copy.py @@ -95,24 +95,15 @@ def test_to_tosa_FP(test_data: Tuple): @common.parametrize("test_data", _TO_COPY_TEST_DATA_FP) @common.SkipIfNoModelConverter -def test_to_vgf_FP(test_data: Tuple): +def test_to_vgf_no_quant(test_data: Tuple): test_tensor, new_dtype = test_data() pipeline = VgfPipeline[input_t1]( Cast(new_dtype), (test_tensor,), aten_op=[], exir_op=[], - tosa_version="TOSA-1.0+FP", + quantize=False, ) - # int to int cast is not supported in TOSA+FP profile - if not new_dtype.is_floating_point and not torch.is_floating_point(test_tensor): - pipeline.change_args( - "check_count.exir", - { - "torch.ops.higher_order.executorch_call_delegate": 0, - "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1, - }, - ) pipeline.run() @@ -164,7 +155,7 @@ def test_to_tosa_INT_not_delegated(test_data: Tuple): @common.parametrize("test_data", _TO_COPY_TEST_DATA_INT) @common.SkipIfNoModelConverter -def test_to_vgf_INT(test_data: Tuple): +def test_to_vgf_quant(test_data: Tuple): # Op not supported pass @@ -192,20 +183,15 @@ def test_to_vgf_INT(test_data: Tuple): ), } -redundant_xfails_FP = { +redundant_xfails = { "rand_fp16_fp16": "FP16 is not supported", "rand_int8_int8": "Tracing graph with quantized input is not supported.", "rand_int16_int16": "Tracing graph with quantized input is not supported.", } -redundant_xfails_INT = { - "rand_fp16_fp16": "FP16 is not supported", - "rand_int8_int8": "Tracing graph with quantized input is not supported.", -} - @common.parametrize( - "test_data", _TO_COPY_TEST_DATA_REDUNDANT_CAST, xfails=redundant_xfails_FP + "test_data", _TO_COPY_TEST_DATA_REDUNDANT_CAST, xfails=redundant_xfails ) def test_to_tosa_FP_REDUNDANT_CAST(test_data: Tuple): test_tensor, new_dtype = test_data() @@ -220,7 +206,7 @@ def test_to_tosa_FP_REDUNDANT_CAST(test_data: Tuple): @common.parametrize( - "test_data", _TO_COPY_TEST_DATA_REDUNDANT_CAST, xfails=redundant_xfails_INT + "test_data", _TO_COPY_TEST_DATA_REDUNDANT_CAST, 
xfails=redundant_xfails
 )
 def test_to_tosa_INT_REDUNDANT_CAST(test_data: Tuple):
     test_tensor, new_dtype = test_data()
@@ -231,7 +217,6 @@ def test_to_tosa_INT_REDUNDANT_CAST(test_data: Tuple):
         exir_op=[],
     )
     pipeline.pop_stage("run_method_and_compare_outputs")
-    pipeline.pop_stage("check.quant_nodes")
     pipeline.run()
@@ -244,3 +229,32 @@ def test_to_tosa_INT_not_delegated_REDUNDANT_CAST(test_data: Tuple):
         non_delegated_ops={},  # These are removed outside of the Arm backend so the graph is empty
     )
     pipeline.run()
+
+
+_TO_COPY_DATA_INT_U55_REJECT = {
+    "rand_bool_int8": lambda: (
+        torch.randint(0, 2, (1, 2, 3, 4), dtype=torch.bool),
+        torch.int8,
+    ),
+    "rand_int16_bool": lambda: (
+        torch.randint(-1000, 1000, (1, 2, 3, 4), dtype=torch.int16),
+        torch.bool,
+    ),
+    "rand_int32_int8": lambda: (
+        torch.randint(-1000, 1000, (1, 2, 3, 4), dtype=torch.int32),
+        torch.int8,
+    ),
+}
+
+
+@common.parametrize("test_data", _TO_COPY_DATA_INT_U55_REJECT)
+def test_to_u55_INT(test_data: Tuple):
+    test_tensor, new_dtype = test_data()
+    pipeline = OpNotSupportedPipeline[input_t1](
+        Cast(new_dtype),
+        (test_tensor,),
+        u55_subset=True,
+        quantize=True,
+        non_delegated_ops={},  # These are removed outside of the Arm backend so the graph is empty
+    )
+    pipeline.run()
diff --git a/backends/arm/test/ops/test_transpose_copy.py b/backends/arm/test/ops/test_transpose_copy.py
new file mode 100644
index 00000000000..fb521eda1db
--- /dev/null
+++ b/backends/arm/test/ops/test_transpose_copy.py
@@ -0,0 +1,114 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU55PipelineINT,
+    EthosU85PipelineINT,
+    TosaPipelineFP,
+    TosaPipelineINT,
+    VgfPipeline,
+)
+
+test_data_suite = {
+    # test_name: (test_data, dim0, dim1)
+    "rank_2": lambda: (torch.rand(2, 3), 0, 1),
+    "rank_2_swapped": lambda: (torch.rand(3, 4), 1, 0),
+    "rank_3": lambda: (torch.ones(5, 10, 10), 1, 2),
+    "rank_4": lambda: (torch.rand(1, 10, 4, 2) * 2, 2, 0),
+}
+
+
+class TransposeCopy(torch.nn.Module):
+    aten_op = "torch.ops.aten.transpose_copy.int"
+    exir_op = "executorch_exir_dialects_edge__ops_aten_permute_copy_default"
+
+    def forward(self, x: torch.Tensor, dim0: int, dim1: int):
+        return torch.transpose_copy(x, dim0=dim0, dim1=dim1)
+
+
+input_t1 = Tuple[torch.Tensor]
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_transpose_int_tosa_FP(test_data: Tuple):
+    pipeline = TosaPipelineFP[input_t1](
+        TransposeCopy(),
+        test_data(),
+        aten_op=TransposeCopy.aten_op,
+        exir_op=TransposeCopy.exir_op,
+        use_to_edge_transform_and_lower=False,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_transpose_int_tosa_INT(test_data: Tuple):
+    pipeline = TosaPipelineINT[input_t1](
+        TransposeCopy(),
+        test_data(),
+        aten_op=TransposeCopy.aten_op,
+        exir_op=TransposeCopy.exir_op,
+        use_to_edge_transform_and_lower=False,
+    )
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone300
+@common.parametrize("test_data", test_data_suite)
+def test_transpose_int_u55_INT(test_data: Tuple):
+    pipeline = EthosU55PipelineINT[input_t1](
+        TransposeCopy(),
+        test_data(),
+        aten_ops=TransposeCopy.aten_op,
+        exir_ops=TransposeCopy.exir_op,
+        use_to_edge_transform_and_lower=False,
+    )
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone320
+@common.parametrize("test_data", test_data_suite) +def test_transpose_int_u85_INT(test_data: Tuple): + pipeline = EthosU85PipelineINT[input_t1]( + TransposeCopy(), + test_data(), + aten_ops=TransposeCopy.aten_op, + exir_ops=TransposeCopy.exir_op, + use_to_edge_transform_and_lower=False, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.SkipIfNoModelConverter +def test_transpose_int_vgf_no_quant(test_data: Tuple): + pipeline = VgfPipeline[input_t1]( + TransposeCopy(), + test_data(), + aten_op=TransposeCopy.aten_op, + exir_op=TransposeCopy.exir_op, + use_to_edge_transform_and_lower=False, + quantize=False, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.SkipIfNoModelConverter +def test_transpose_int_vgf_quant(test_data: Tuple): + pipeline = VgfPipeline[input_t1]( + TransposeCopy(), + test_data(), + aten_op=TransposeCopy.aten_op, + exir_op=TransposeCopy.exir_op, + use_to_edge_transform_and_lower=False, + quantize=True, + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_unary_combos.py b/backends/arm/test/ops/test_unary_combos.py index db442d2d8d0..312350ea8d3 100644 --- a/backends/arm/test/ops/test_unary_combos.py +++ b/backends/arm/test/ops/test_unary_combos.py @@ -109,7 +109,10 @@ def test_unary_combos_tosa_INT(model_cls): def test_unary_combos_u55_INT(model_cls): m, inputs, exir = _build(model_cls) p = EthosU55PipelineINT[Tensor1]( - m, inputs, aten_ops=[], exir_ops=exir, run_on_fvp=True + m, + inputs, + aten_ops=[], + exir_ops=exir, ) p.run() @@ -119,16 +122,23 @@ def test_unary_combos_u55_INT(model_cls): def test_unary_combos_u85_INT(model_cls): m, inputs, exir = _build(model_cls) p = EthosU85PipelineINT[Tensor1]( - m, inputs, aten_ops=[], exir_ops=exir, run_on_fvp=True + m, + inputs, + aten_ops=[], + exir_ops=exir, ) p.run() @common.SkipIfNoModelConverter @pytest.mark.parametrize("model_cls", MODELS, ids=lambda c: c.__name__) -def test_unary_combos_vgf_INT(model_cls): +def test_unary_combos_vgf_quant(model_cls): m, inputs, exir = _build(model_cls) p = VgfPipeline[Tensor1]( - m, inputs, aten_op=[], exir_op=exir, tosa_version="TOSA-1.0+INT" + m, + inputs, + aten_op=[], + exir_op=exir, + quantize=True, ) p.run() diff --git a/backends/arm/test/ops/test_unbind.py b/backends/arm/test/ops/test_unbind.py index cd33f8217df..ce3f769cd06 100644 --- a/backends/arm/test/ops/test_unbind.py +++ b/backends/arm/test/ops/test_unbind.py @@ -58,25 +58,25 @@ def test_unbind_int_tosa_INT(test_data: test_data_t): @common.parametrize("test_data", Unbind.test_data) @common.SkipIfNoModelConverter -def test_unbind_int_vgf_FP(test_data: test_data_t): +def test_unbind_int_vgf_no_quant(test_data: test_data_t): input_data, init_data = test_data pipeline = VgfPipeline[input_t]( Unbind(*init_data), input_data(), Unbind.aten_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", Unbind.test_data) @common.SkipIfNoModelConverter -def test_unbind_int_vgf_INT(test_data: test_data_t): +def test_unbind_int_vgf_quant(test_data: test_data_t): input_data, init_data = test_data pipeline = VgfPipeline[input_t]( Unbind(*init_data), input_data(), Unbind.aten_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_unflatten.py b/backends/arm/test/ops/test_unflatten.py index 95c68b2940d..d4730ac6dc2 100644 --- a/backends/arm/test/ops/test_unflatten.py +++ b/backends/arm/test/ops/test_unflatten.py @@ -9,6 +9,8 @@ import torch from 
executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineINT, + EthosU85PipelineINT, TosaPipelineFP, TosaPipelineINT, VgfPipeline, @@ -30,8 +32,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.unflatten(x, self.dim, self.sizes) test_data: dict[str, test_data_t] = { - "randn_4d": (lambda: (Unflatten(1, (2, 2)), (torch.randn(3, 4, 5, 1),))), - "rand_3d": (lambda: (Unflatten(1, (-1, 2)), (torch.rand(3, 4, 4),))), + "rand_3d_batch3": (lambda: (Unflatten(1, (-1, 2)), (torch.rand(3, 4, 4),))), + "rand_3d_batch1": (lambda: (Unflatten(1, (-1, 2)), (torch.rand(1, 4, 4),))), + "randn_4d_dim1": (lambda: (Unflatten(1, (2, 2)), (torch.randn(3, 4, 5, 1),))), + "randn_4d_dim3": (lambda: (Unflatten(3, (2, 2)), (torch.randn(1, 1, 5, 4),))), } @@ -49,7 +53,27 @@ def test_unflatten_int_tosa_FP(test_data: test_data_t): @common.parametrize("test_data", Unflatten.test_data) def test_unflatten_int_tosa_INT(test_data: test_data_t): module, inputs = test_data() - pipeline = TosaPipelineINT[input_t]( + pipeline = TosaPipelineINT[input_t](module, inputs, Unflatten.aten_op) + pipeline.run() + + +@common.parametrize("test_data", Unflatten.test_data, strict=False) +@common.XfailIfNoCorstone300 +def test_unflatten_int_u55_INT(test_data: test_data_t): + module, inputs = test_data() + pipeline = EthosU55PipelineINT[input_t]( + module, + inputs, + Unflatten.aten_op, + ) + pipeline.run() + + +@common.parametrize("test_data", Unflatten.test_data, strict=False) +@common.XfailIfNoCorstone320 +def test_unflatten_int_u85_INT(test_data: test_data_t): + module, inputs = test_data() + pipeline = EthosU85PipelineINT[input_t]( module, inputs, Unflatten.aten_op, @@ -59,25 +83,25 @@ def test_unflatten_int_tosa_INT(test_data: test_data_t): @common.parametrize("test_data", Unflatten.test_data) @common.SkipIfNoModelConverter -def test_unflatten_int_vgf_FP(test_data: test_data_t): +def test_unflatten_int_vgf_no_quant(test_data: test_data_t): module, inputs = test_data() pipeline = VgfPipeline[input_t]( module, inputs, Unflatten.aten_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", Unflatten.test_data) @common.SkipIfNoModelConverter -def test_unflatten_int_vgf_INT(test_data: test_data_t): +def test_unflatten_int_vgf_quant(test_data: test_data_t): module, inputs = test_data() pipeline = VgfPipeline[input_t]( module, inputs, Unflatten.aten_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_unsqueeze.py b/backends/arm/test/ops/test_unsqueeze.py index 54e1b0dd0ce..0c29d3b588c 100644 --- a/backends/arm/test/ops/test_unsqueeze.py +++ b/backends/arm/test/ops/test_unsqueeze.py @@ -25,7 +25,7 @@ class Unsqueeze(torch.nn.Module): - shapes: list[int | Sequence[int]] = [5, (5, 5), (5, 4), (5, 4, 3)] + shapes: list[int | Sequence[int]] = [5, (5, 5), (5, 4), (5, 4, 3), (1, 5, 4, 3)] test_parameters = {} for n in shapes: test_parameters[f"rand_{n}"] = (torch.randn(n),) @@ -65,7 +65,6 @@ def test_unsqueeze_u55_INT(test_tensor: torch.Tensor): (*test_tensor, 0), aten_op, exir_ops=[], - run_on_fvp=False, ) pipeline.run() @@ -78,29 +77,31 @@ def test_unsqueeze_u85_INT(test_tensor: torch.Tensor): (*test_tensor, 0), aten_op, exir_ops=[], - run_on_fvp=True, ) pipeline.run() @common.parametrize("test_tensor", Unsqueeze.test_parameters) @common.SkipIfNoModelConverter -def test_unsqueeze_vgf_FP(test_tensor: torch.Tensor): +def test_unsqueeze_vgf_no_quant(test_tensor: 
torch.Tensor): for i in range(-test_tensor[0].dim() - 1, test_tensor[0].dim() + 1): pipeline = VgfPipeline[input_t1]( - Unsqueeze(), (*test_tensor, i), aten_op, tosa_version="TOSA-1.0+FP" + Unsqueeze(), + (*test_tensor, i), + aten_op, + quantize=False, ) pipeline.run() @common.parametrize("test_tensor", Unsqueeze.test_parameters) @common.SkipIfNoModelConverter -def test_unsqueeze_vgf_INT(test_tensor: torch.Tensor): +def test_unsqueeze_vgf_quant(test_tensor: torch.Tensor): for i in range(-test_tensor[0].dim() - 1, test_tensor[0].dim() + 1): pipeline = VgfPipeline[input_t1]( Unsqueeze(), (*test_tensor, i), aten_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_upsample_bilinear2d.py b/backends/arm/test/ops/test_upsample_bilinear2d.py index 95e69bc5204..edac736981a 100644 --- a/backends/arm/test/ops/test_upsample_bilinear2d.py +++ b/backends/arm/test/ops/test_upsample_bilinear2d.py @@ -7,7 +7,6 @@ import torch from executorch.backends.arm.test import common - from executorch.backends.arm.test.tester.test_pipeline import ( EthosU85PipelineINT, OpNotSupportedPipeline, @@ -196,6 +195,24 @@ def test_upsample_bilinear2d_vec_tosa_INT_Upsample( pipeline.run() +@common.parametrize("test_data", test_data_suite_tosa) +def test_upsample_bilinear2d_vec_tosa_INT_a16w8( + test_data: torch.Tensor, +): + """Test upsample_bilinear2d vector op with int16 I/O quantization for TOSA INT.""" + test_data, size, scale_factor, compare_outputs = test_data + pipeline = TosaPipelineINT[input_t1]( + Upsample(size, scale_factor), + (test_data,), + aten_op, + exir_op=[], + tosa_extensions=["int16"], + ) + if not compare_outputs: + pipeline.pop_stage(-1) + pipeline.run() + + @common.parametrize("test_data", test_data_u55) @common.XfailIfNoCorstone300 def test_upsample_bilinear2d_vec_U55_INT_Upsample_not_delegated( @@ -259,7 +276,6 @@ def test_upsample_bilinear2d_vec_U85_INT_Upsample(test_data: input_t1): Upsample(size, scale_factor), (test_data,), aten_op, - run_on_fvp=True, qtol=1, use_to_edge_transform_and_lower=True, ) @@ -279,7 +295,6 @@ def test_upsample_bilinear2d_vec_U85_INT_Interpolate( Interpolate(size, scale_factor), (test_data,), aten_op, - run_on_fvp=True, qtol=1, use_to_edge_transform_and_lower=True, ) @@ -299,7 +314,6 @@ def test_upsample_bilinear2d_vec_U85_INT_UpsamplingBilinear2d( UpsamplingBilinear2d(size, scale_factor), (test_data,), aten_op, - run_on_fvp=True, qtol=1, use_to_edge_transform_and_lower=True, ) @@ -308,16 +322,39 @@ def test_upsample_bilinear2d_vec_U85_INT_UpsamplingBilinear2d( pipeline.run() +@common.parametrize("test_data", test_data_suite_Uxx) +@common.XfailIfNoCorstone320 +def test_upsample_bilinear2d_vec_U85_INT_a16w8( + test_data: input_t1, +): + """Test upsample_bilinear2d vec op with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + data, size, scale_factor, compare_outputs = test_data + + pipeline = EthosU85PipelineINT[input_t1]( + UpsamplingBilinear2d(size, scale_factor), + (data,), + aten_op, + per_channel_quantization=False, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + if not compare_outputs: + pipeline.pop_stage(-1) + pipeline.run() + + @common.parametrize("test_data", test_data_suite_tosa) @common.SkipIfNoModelConverter -def test_upsample_bilinear2d_vgf_FP_UpsamplingBilinear2d(test_data: torch.Tensor): +def test_upsample_bilinear2d_UpsamplingBilinear2d_vgf_no_quant( + test_data: torch.Tensor, +): data, size, scale_factor, compare = test_data pipeline = VgfPipeline[input_t1]( 
UpsamplingBilinear2d(size, scale_factor), (data,), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) if not compare: pipeline.pop_stage(-1) @@ -326,14 +363,14 @@ def test_upsample_bilinear2d_vgf_FP_UpsamplingBilinear2d(test_data: torch.Tensor @common.parametrize("test_data", test_data_suite_tosa) @common.SkipIfNoModelConverter -def test_upsample_bilinear2d_vgf_FP_Upsample(test_data: torch.Tensor): +def test_upsample_bilinear2d_Upsample_vgf_no_quant(test_data: torch.Tensor): data, size, scale_factor, compare = test_data pipeline = VgfPipeline[input_t1]( Upsample(size, scale_factor), (data,), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) if not compare: pipeline.pop_stage(-1) @@ -342,14 +379,14 @@ def test_upsample_bilinear2d_vgf_FP_Upsample(test_data: torch.Tensor): @common.parametrize("test_data", test_data_suite_tosa) @common.SkipIfNoModelConverter -def test_upsample_bilinear2d_vgf_FP_Interpolate(test_data: torch.Tensor): +def test_upsample_bilinear2d_Interpolate_vgf_no_quant(test_data: torch.Tensor): data, size, scale_factor, compare = test_data pipeline = VgfPipeline[input_t1]( Interpolate(size, scale_factor), (data,), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) if not compare: pipeline.pop_stage(-1) @@ -358,14 +395,16 @@ def test_upsample_bilinear2d_vgf_FP_Interpolate(test_data: torch.Tensor): @common.parametrize("test_data", test_data_suite_tosa) @common.SkipIfNoModelConverter -def test_upsample_bilinear2d_vgf_INT_UpsamplingBilinear2d(test_data: torch.Tensor): +def test_upsample_bilinear2d_UpsamplingBilinear2d_vgf_quant( + test_data: torch.Tensor, +): data, size, scale_factor, compare = test_data pipeline = VgfPipeline[input_t1]( UpsamplingBilinear2d(size, scale_factor), (data,), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) if not compare: pipeline.pop_stage(-1) @@ -374,14 +413,14 @@ def test_upsample_bilinear2d_vgf_INT_UpsamplingBilinear2d(test_data: torch.Tenso @common.parametrize("test_data", test_data_suite_tosa) @common.SkipIfNoModelConverter -def test_upsample_bilinear2d_vgf_INT_Upsample(test_data: torch.Tensor): +def test_upsample_bilinear2d_Upsample_vgf_quant(test_data: torch.Tensor): data, size, scale_factor, compare = test_data pipeline = VgfPipeline[input_t1]( Upsample(size, scale_factor), (data,), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) if not compare: pipeline.pop_stage(-1) @@ -390,14 +429,14 @@ def test_upsample_bilinear2d_vgf_INT_Upsample(test_data: torch.Tensor): @common.parametrize("test_data", test_data_suite_tosa) @common.SkipIfNoModelConverter -def test_upsample_bilinear2d_vgf_INT_Interpolate(test_data: torch.Tensor): +def test_upsample_bilinear2d_Interpolate_vgf_quant(test_data: torch.Tensor): data, size, scale_factor, compare = test_data pipeline = VgfPipeline[input_t1]( Interpolate(size, scale_factor), (data,), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) if not compare: pipeline.pop_stage(-1) diff --git a/backends/arm/test/ops/test_upsample_nearest2d.py b/backends/arm/test/ops/test_upsample_nearest2d.py index a39adefc168..5da590398f4 100644 --- a/backends/arm/test/ops/test_upsample_nearest2d.py +++ b/backends/arm/test/ops/test_upsample_nearest2d.py @@ -195,16 +195,32 @@ def test_upsample_nearest2d_vec_tosa_INT_interpolate(test_data: torch.Tensor): pipeline.run() +@common.parametrize("test_data", test_data_suite) +def test_upsample_nearest2d_vec_tosa_INT_a16w8(test_data: torch.Tensor): + """Test upsample_nearest2d vector op 
with int16 I/O quantization for TOSA INT.""" + test_data, size, scale_factor, compare_outputs = test_data() + pipeline = TosaPipelineINT[input_t1]( + Upsample(size, scale_factor), + (test_data,), + aten_op, + exir_op=[], + tosa_extensions=["int16"], + ) + if not compare_outputs: + pipeline.pop_stage(-1) + pipeline.run() + + @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_upsample_nearest2d_vgf_FP(test_data: torch.Tensor): +def test_upsample_nearest2d_vgf_no_quant(test_data: torch.Tensor): data, size, scale_factor, compare = test_data() pipeline = VgfPipeline[input_t1]( UpsamplingNearest2d(size, scale_factor), (data,), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) if not compare: pipeline.pop_stage(-1) @@ -213,14 +229,14 @@ def test_upsample_nearest2d_vgf_FP(test_data: torch.Tensor): @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_upsample_nearest2d_vgf_FP_nearest(test_data: torch.Tensor): +def test_upsample_nearest2d_nearest_vgf_no_quant(test_data: torch.Tensor): data, size, scale_factor, compare = test_data() pipeline = VgfPipeline[input_t1]( Upsample(size, scale_factor), (data,), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) if not compare: pipeline.pop_stage(-1) @@ -229,13 +245,15 @@ def test_upsample_nearest2d_vgf_FP_nearest(test_data: torch.Tensor): @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_upsample_nearest2d_vgf_FP_interpolate(test_data: torch.Tensor): +def test_upsample_nearest2d_interpolate_vgf_FP(test_data: torch.Tensor): data, size, scale_factor, compare = test_data() pipeline = VgfPipeline[input_t1]( Interpolate(size, scale_factor), (data,), aten_op, exir_op, + quantize=False, + # Override tosa version to test FP-only path tosa_version="TOSA-1.0+FP", ) if not compare: @@ -245,14 +263,14 @@ def test_upsample_nearest2d_vgf_FP_interpolate(test_data: torch.Tensor): @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_upsample_nearest2d_vgf_INT(test_data: torch.Tensor): +def test_upsample_nearest2d_vgf_quant(test_data: torch.Tensor): data, size, scale_factor, compare = test_data() pipeline = VgfPipeline[input_t1]( UpsamplingNearest2d(size, scale_factor), (data,), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) if not compare: pipeline.pop_stage(-1) @@ -261,14 +279,14 @@ def test_upsample_nearest2d_vgf_INT(test_data: torch.Tensor): @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_upsample_nearest2d_vgf_INT_nearest(test_data: torch.Tensor): +def test_upsample_nearest2d_nearest_vgf_quant(test_data: torch.Tensor): data, size, scale_factor, compare = test_data() pipeline = VgfPipeline[input_t1]( Upsample(size, scale_factor), (data,), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) if not compare: pipeline.pop_stage(-1) @@ -277,13 +295,15 @@ def test_upsample_nearest2d_vgf_INT_nearest(test_data: torch.Tensor): @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_upsample_nearest2d_vgf_INT_interpolate(test_data: torch.Tensor): +def test_upsample_nearest2d_interpolate_vgf_INT(test_data: torch.Tensor): data, size, scale_factor, compare = test_data() pipeline = VgfPipeline[input_t1]( Interpolate(size, scale_factor), (data,), aten_op, exir_op, + quantize=True, + # Override tosa version to test INT-only path tosa_version="TOSA-1.0+INT", ) if not compare: diff --git 
a/backends/arm/test/ops/test_var.py b/backends/arm/test/ops/test_var.py index 9567f90c480..73bf2165b23 100644 --- a/backends/arm/test/ops/test_var.py +++ b/backends/arm/test/ops/test_var.py @@ -194,7 +194,6 @@ def test_var_dim_u55_INT_no_dim(test_data: Tuple): (test_data,), aten_ops=[], exir_ops=[], - run_on_fvp=True, ) pipeline.run() @@ -208,31 +207,34 @@ def test_var_dim_u85_INT_no_dim(test_data: Tuple): (test_data,), aten_ops=[], exir_ops=[], - run_on_fvp=True, ) pipeline.run() @common.parametrize("test_data", Var.test_parameters) @common.SkipIfNoModelConverter -def test_var_dim_vgf_FP_no_dim(test_data: Tuple): +def test_var_dim_no_dim_vgf_no_quant(test_data: Tuple): data, keepdim, correction = test_data() pipeline = VgfPipeline[input_t1]( - Var(keepdim, correction), (data,), [], [], tosa_version="TOSA-1.0+FP" + Var(keepdim, correction), + (data,), + [], + [], + quantize=False, ) pipeline.run() @common.parametrize("test_data", Var.test_parameters) @common.SkipIfNoModelConverter -def test_var_dim_vgf_INT_no_dim(test_data: Tuple): +def test_var_dim_no_dim_vgf_quant(test_data: Tuple): data, keepdim, correction = test_data() pipeline = VgfPipeline[input_t1]( Var(keepdim, correction), (data,), [], [], - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @@ -276,7 +278,6 @@ def test_var_dim_u55_INT(test_data: Tuple): (test_data,), aten_ops=[], exir_ops=[], - run_on_fvp=True, ) pipeline.run() @@ -290,31 +291,34 @@ def test_var_dim_u85_INT(test_data: Tuple): (test_data,), aten_ops=[], exir_ops=[], - run_on_fvp=True, ) pipeline.run() @common.parametrize("test_data", VarDim.test_parameters) @common.SkipIfNoModelConverter -def test_var_dim_vgf_FP(test_data: Tuple): +def test_var_dim_vgf_no_quant(test_data: Tuple): data, dim, keepdim, unbiased = test_data() pipeline = VgfPipeline[input_t1]( - VarDim(dim, keepdim, unbiased), (data,), [], [], tosa_version="TOSA-1.0+FP" + VarDim(dim, keepdim, unbiased), + (data,), + [], + [], + quantize=False, ) pipeline.run() @common.parametrize("test_data", VarDim.test_parameters) @common.SkipIfNoModelConverter -def test_var_dim_vgf_INT(test_data: Tuple): +def test_var_dim_vgf_quant(test_data: Tuple): data, dim, keepdim, unbiased = test_data() pipeline = VgfPipeline[input_t1]( VarDim(dim, keepdim, unbiased), (data,), [], [], - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @@ -348,7 +352,17 @@ def test_var_dim_tosa_INT_correction(test_data: Tuple): pipeline.run() -@common.parametrize("test_data", VarCorrection.test_parameters) +# TODO: Xfail "var_3d_dims_keep_dim_0_correction" until the Ethos-U Vela compiler ships commit +# 642f7517d3a6bd053032e1942822f6e38ccd546f. That patch fixes the bug that causes the test to fail. 
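The parametrization just below attaches an xfails mapping so that one named case is reported as an expected failure with the quoted Vela reason while the remaining cases still run normally. The Arm helper's internals are not shown in this patch; as a rough plain-pytest illustration of the same idea (the helper names here are hypothetical):

import pytest
import torch

_CASES = {
    "var_3d_dims_keep_dim_0_correction": (torch.rand(1, 2, 3), (0, 1, 2)),
    "var_2d_dims": (torch.rand(4, 4), (0, 1)),
}
_XFAILS = {"var_3d_dims_keep_dim_0_correction": "Blocked by a pending Vela fix"}

_PARAMS = [
    pytest.param(
        case,
        id=name,
        marks=pytest.mark.xfail(reason=_XFAILS[name]) if name in _XFAILS else (),
    )
    for name, case in _CASES.items()
]


@pytest.mark.parametrize("case", _PARAMS)
def test_var_case(case):
    data, dims = case
    # Reducing over every listed dim leaves a scalar result.
    assert torch.var(data, dim=dims).dim() == 0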
+@common.parametrize( + "test_data", + VarCorrection.test_parameters, + xfails={ + "var_3d_dims_keep_dim_0_correction": ( + "Blocked by Vela commit 642f7517d3a6bd053032e1942822f6e38ccd546f" + ), + }, +) @common.XfailIfNoCorstone300 def test_var_dim_u55_INT_correction(test_data: Tuple): test_data, dim, keepdim, correction = test_data() @@ -357,7 +371,6 @@ def test_var_dim_u55_INT_correction(test_data: Tuple): (test_data,), aten_ops=[], exir_ops=[], - run_on_fvp=True, ) pipeline.run() @@ -371,30 +384,33 @@ def test_var_dim_u85_INT_correction(test_data: Tuple): (test_data,), aten_ops=[], exir_ops=[], - run_on_fvp=True, ) pipeline.run() @common.parametrize("test_data", VarCorrection.test_parameters) @common.SkipIfNoModelConverter -def test_var_dim_vgf_FP_correction(test_data: Tuple): +def test_var_dim_correction_vgf_no_quant(test_data: Tuple): data, dim, keepdim, corr = test_data() pipeline = VgfPipeline[input_t1]( - VarCorrection(dim, keepdim, corr), (data,), [], [], tosa_version="TOSA-1.0+FP" + VarCorrection(dim, keepdim, corr), + (data,), + [], + [], + quantize=False, ) pipeline.run() @common.parametrize("test_data", VarCorrection.test_parameters) @common.SkipIfNoModelConverter -def test_var_dim_vgf_INT_correction(test_data: Tuple): +def test_var_dim_correction_vgf_quant(test_data: Tuple): data, dim, keepdim, corr = test_data() pipeline = VgfPipeline[input_t1]( VarCorrection(dim, keepdim, corr), (data,), [], [], - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_view.py b/backends/arm/test/ops/test_view.py index fb0ba54436e..99df4f2f2f7 100644 --- a/backends/arm/test/ops/test_view.py +++ b/backends/arm/test/ops/test_view.py @@ -9,7 +9,6 @@ from typing import Tuple -import pytest import torch from executorch.backends.arm.quantizer.arm_quantizer import ( get_symmetric_a16w8_quantization_config, @@ -51,6 +50,10 @@ class View(torch.nn.Module): "rand_4d_4_3": lambda: (torch.rand(5, 10, 1, 1), (1, 25, 2)), "rand_4d_4_2": lambda: (torch.rand(2, 50, 1, 1), (1, 100)), "rand_4d_2_4_same": lambda: (torch.rand(2, 3, 2, 3), (2, 3, 3, 2)), + "rand_4d_5d": lambda: (torch.rand(1, 3, 4, 5), (1, 1, 4, 5, -1)), + "rand_5d_5d": lambda: (torch.rand(1, 1, 4, 5, 6), (1, 1, 4, -1, 6)), + "rand_5d_3d": lambda: (torch.rand(1, 1, 4, 5, 6), (2, 3, -1)), + "rand_3d_5d": lambda: (torch.rand(4, 5, 6), (1, 1, 2, -1, 3)), } rank_product_too_large = { @@ -104,26 +107,26 @@ def test_view_u55_INT(test_data: Tuple): @common.parametrize("test_data", View.needs_transpose_tests) @common.SkipIfNoModelConverter -def test_view_vgf_FP(test_data: Tuple): +def test_view_vgf_no_quant(test_data: Tuple): test_tensor, new_shape = test_data() pipeline = VgfPipeline[input_t1]( View(new_shape), (test_tensor,), aten_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_data", View.needs_transpose_tests) @common.SkipIfNoModelConverter -def test_view_vgf_INT(test_data: Tuple): +def test_view_vgf_quant(test_data: Tuple): test_tensor, new_shape = test_data() pipeline = VgfPipeline[input_t1]( View(new_shape), (test_tensor,), aten_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) pipeline.run() @@ -176,9 +179,6 @@ def get_symmetric_a16w8_view_quantizer(per_channel_quantization=False): @common.parametrize("test_data", View.needs_transpose_tests) -@pytest.mark.xfail( - reason="missing int16 view ops support; fails at TOSA reference model with Unsupported operation type or rank. 
See: https://github.com/pytorch/executorch/issues/13977" -) def test_view_16a8w_tosa_INT(test_data: Tuple): """Test view operation with 16A8W quantization (16-bit activations, 8-bit weights)""" per_channel_quantization = False @@ -205,9 +205,6 @@ def test_view_16a8w_tosa_INT(test_data: Tuple): @common.parametrize("test_data", View.needs_transpose_tests) @common.XfailIfNoCorstone300 -@pytest.mark.xfail( - reason="Vela compilation fails with 'Invalid arguments' for int16 view operations" -) def test_view_16a8w_u55_INT16(test_data: Tuple): """Test view operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)""" per_channel_quantization = False @@ -220,7 +217,6 @@ def test_view_16a8w_u55_INT16(test_data: Tuple): exir_ops=[], per_channel_quantization=per_channel_quantization, use_to_edge_transform_and_lower=True, - run_on_fvp=True, ) pipeline.change_args( @@ -234,9 +230,6 @@ def test_view_16a8w_u55_INT16(test_data: Tuple): @common.parametrize("test_data", View.needs_transpose_tests) @common.XfailIfNoCorstone320 -@pytest.mark.xfail( - reason="Vela compilation fails with 'Invalid arguments' for int16 view operations" -) def test_view_16a8w_u85_INT16(test_data: Tuple): """Test view operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" per_channel_quantization = False @@ -249,7 +242,6 @@ def test_view_16a8w_u85_INT16(test_data: Tuple): exir_ops=[], per_channel_quantization=per_channel_quantization, use_to_edge_transform_and_lower=True, - run_on_fvp=True, ) pipeline.change_args( diff --git a/backends/arm/test/ops/test_where.py b/backends/arm/test/ops/test_where.py index ea036d26361..50a7aef657d 100644 --- a/backends/arm/test/ops/test_where.py +++ b/backends/arm/test/ops/test_where.py @@ -6,7 +6,6 @@ from typing import List, Tuple import torch - from executorch.backends.arm.quantizer import ( EthosUQuantizer, get_symmetric_quantization_config, @@ -65,6 +64,30 @@ def forward( return torch.where(self.condition(input_), input_, other_) +class ConstWhere(torch.nn.Module): + + def __init__(self, buffer: torch.Tensor, dtype: torch.dtype): + super().__init__() + self.buffer = buffer + self.dtype = dtype + self.min = torch.nn.Buffer(torch.tensor(0.0, dtype=self.dtype)) + self.input_1 = torch.nn.Buffer(torch.tensor(-1.0, dtype=self.dtype)) + self.input_2 = torch.nn.Buffer(torch.tensor(1.0, dtype=self.dtype)) + + def get_inputs(self): + return (torch.rand(self.buffer.size(), dtype=self.dtype),) + + def forward(self, input: torch.Tensor): + return ( + torch.where( + self.buffer > self.min, + self.input_1, + self.input_2, + ) + + input + ) + + def tensor_condition(input: torch.Tensor): return input > torch.zeros_like(input) @@ -128,6 +151,11 @@ def scalar_condition(input: torch.Tensor): scalar_condition, ) +const_float32 = ConstWhere( + buffer=torch.tensor([[1.0, -1.0], [-1.0, 1.0]]), + dtype=torch.float32, +) + test_modules_common = { "two_dim_tensor_cond": lambda: two_dim_tensor_cond, "three_dim_tensor_cond": lambda: three_dim_tensor_cond, @@ -135,12 +163,16 @@ def scalar_condition(input: torch.Tensor): "two_dim_scalar_cond": lambda: two_dim_scalar_cond, "three_dim_scalar_cond": lambda: three_dim_scalar_cond, "float32_scalar_cond": lambda: float32_scalar_cond, + "const_float32": lambda: const_float32, } test_modules_FP = { **test_modules_common, - "float32_tensor_cond_tuple_dtype": lambda: float32_tensor_cond_tuple_dtype, "float32_tensor_cond_tuple_dtype_bool": lambda: float32_tensor_cond_tuple_dtype_bool, +} + +test_modules_FP_unsupported_dtype = { + 
"float32_tensor_cond_tuple_dtype": lambda: float32_tensor_cond_tuple_dtype, "int32_scalar_cond": lambda: int32_scalar_cond, } @@ -162,6 +194,17 @@ def test_where_self_tosa_FP(test_module): pipeline.run() +@common.parametrize("test_module", test_modules_FP_unsupported_dtype) +def test_where_self_tosa_FP_unsupported_dtype(test_module): + pipeline = OpNotSupportedPipeline[input_t]( + test_module(), + test_module().get_inputs(), + {exir_op: 1}, + n_expected_delegates=1, # condition can be delegated + ) + pipeline.run() + + @common.parametrize("test_module", test_modules_INT) def test_where_self_tosa_INT(test_module): pipeline = TosaPipelineINT[input_t]( @@ -169,7 +212,6 @@ def test_where_self_tosa_INT(test_module): test_module().get_inputs(), aten_op, exir_op, - symmetric_io_quantization=True, ) pipeline.run() @@ -212,7 +254,6 @@ def test_where_self_u85_INT(test_module): test_module().get_inputs(), aten_op, exir_op, - run_on_fvp=True, symmetric_io_quantization=True, ) pipeline.run() @@ -220,26 +261,25 @@ def test_where_self_u85_INT(test_module): @common.parametrize("test_module", test_modules_FP) @common.SkipIfNoModelConverter -def test_where_self_vgf_FP(test_module): +def test_where_self_vgf_no_quant(test_module): pipeline = VgfPipeline[input_t]( test_module(), test_module().get_inputs(), aten_op, exir_op, - tosa_version="TOSA-1.0+FP", + quantize=False, ) pipeline.run() @common.parametrize("test_module", test_modules_INT) @common.SkipIfNoModelConverter -def test_where_self_vgf_INT(test_module): +def test_where_self_vgf_quant(test_module): pipeline = VgfPipeline[input_t]( test_module(), test_module().get_inputs(), aten_op, exir_op, - tosa_version="TOSA-1.0+INT", - symmetric_io_quantization=True, + quantize=True, ) pipeline.run() diff --git a/backends/arm/test/ops/test_while.py b/backends/arm/test/ops/test_while.py new file mode 100644 index 00000000000..50b701fcbfc --- /dev/null +++ b/backends/arm/test/ops/test_while.py @@ -0,0 +1,201 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Callable, Tuple + +import torch +import torch.fx + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.backends.arm.test.tester.test_pipeline import ( + TosaPipelineFP, + TosaPipelineINT, +) + +input_single = Tuple[torch.Tensor] +input_double = Tuple[torch.Tensor, torch.Tensor] + + +class WhileTwoInputsTwoOutputs(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward( + self, lhs: torch.Tensor, rhs: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + def cond_fn(lhs_val: torch.Tensor, rhs_val: torch.Tensor) -> torch.Tensor: + total = torch.sum(rhs_val) + zero = torch.zeros_like(total) + return torch.gt(total, zero).squeeze() + + def body_fn( + lhs_val: torch.Tensor, rhs_val: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + next_lhs = torch.add(lhs_val, rhs_val) + next_rhs = torch.sub(rhs_val, torch.full((1,), 1.0)) + return (next_lhs, next_rhs) + + result = torch.ops.higher_order.while_loop( + cond_fn, + body_fn, + (lhs, rhs), + (), + ) + return result # type: ignore + + +class WhileOneInputOneBufferTwoOutputs(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_buffer("threshold", torch.tensor((30.0,))) + + def forward(self, value: torch.Tensor) -> torch.Tensor: + def cond_fn(value: torch.Tensor, limit: torch.Tensor) -> torch.Tensor: + total = value.sum() + return torch.lt(total, limit).squeeze() + + def body_fn( + value: torch.Tensor, limit: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + return (torch.add(value, value), limit.clone()) + + result = torch.ops.higher_order.while_loop( + cond_fn, + body_fn, + (value, self.threshold), + (), + ) + return result # type: ignore + + +class DecreasingOutput(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, value: torch.Tensor) -> torch.Tensor: + def cond_fn(value: torch.Tensor) -> torch.Tensor: + total = value.sum() + return torch.gt(total, torch.full((1,), 60.0)).squeeze() + + def body_fn(value: torch.Tensor) -> Tuple[torch.Tensor]: + return (torch.div(value, torch.full((1,), 2.0)),) + + result = torch.ops.higher_order.while_loop( + cond_fn, + body_fn, + (value,), + (), + ) + return result[0] # type: ignore + + +class WhileAdditionalArg(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_buffer("threshold", torch.tensor((300.0,))) + + def forward(self, value: torch.Tensor) -> torch.Tensor: + def cond_fn(value: torch.Tensor, limit: torch.Tensor) -> torch.Tensor: + total = value.sum() + return torch.lt(total, limit).squeeze() + + def body_fn(value: torch.Tensor, limit: torch.Tensor) -> tuple[torch.Tensor]: + return (torch.add(value, value),) + + result = torch.ops.higher_order.while_loop( + cond_fn, + body_fn, + (value,), + (self.threshold,), + ) + return result # type: ignore + + +class WhileSingleCapturedOutput(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_buffer("threshold", torch.tensor((200.0,))) + + def forward(self, value: torch.Tensor) -> torch.Tensor: + def cond_fn(value: torch.Tensor, limit: torch.Tensor) -> torch.Tensor: + total = value.sum() + return torch.lt(total, limit).squeeze() + + def body_fn( + value: torch.Tensor, limit: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + return (torch.add(value, value), limit.clone()) + + result = torch.ops.higher_order.while_loop( + cond_fn, + body_fn, + (value, self.threshold), + (), + ) + 
return result[0] # type: ignore + + +def _single_input_case( + module_factory: Callable[[], torch.nn.Module], +) -> Callable[[], Tuple[torch.nn.Module, input_single]]: + def _create() -> Tuple[torch.nn.Module, input_single]: + return module_factory(), (torch.ones(2, 3, 4, 6),) + + return _create + + +def _dual_input_case( + module_factory: Callable[[], torch.nn.Module], +) -> Callable[[], Tuple[torch.nn.Module, input_double]]: + def _create() -> Tuple[torch.nn.Module, input_double]: + return module_factory(), (torch.zeros(2, 3), torch.full((2, 3), -2.0)) + + return _create + + +test_cases: dict[str, Callable[[], Tuple[torch.nn.Module, Tuple]]] = { + "two_in_two_out": _dual_input_case(WhileTwoInputsTwoOutputs), + "one_in_one_buffer_two_out": _single_input_case(WhileOneInputOneBufferTwoOutputs), + "decreasing_output": _single_input_case(DecreasingOutput), + "additional_arg": _single_input_case(WhileAdditionalArg), + "two_in_one_captured_out": _single_input_case(WhileSingleCapturedOutput), +} + + +@common.parametrize( + "case", + test_cases, +) +def test_while_loop_tosa_FP(case: Callable[[], Tuple[torch.nn.Module, Tuple]]): + module, example_inputs = case() + pipeline = TosaPipelineFP[tuple]( + module, + example_inputs, + "torch.ops.higher_order.while_loop", + tosa_extensions=["cf"], + ) + pipeline.run() + + +@common.parametrize( + "case", + test_cases, +) +def test_while_loop_tosa_INT(case: Callable[[], Tuple[torch.nn.Module, Tuple]]): + module, example_inputs = case() + pipeline = TosaPipelineINT[tuple]( + module, + example_inputs, + "torch.ops.higher_order.while_loop", + tosa_extensions=["cf"], + ) + pipeline.add_stage_after( + "to_edge_transform_and_lower", + ArmTester.check_not, + pipeline.tester, + ["torch.ops.higher_order.while_loop"], + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_zeros.py b/backends/arm/test/ops/test_zeros.py index caee678282a..7e1609e8976 100644 --- a/backends/arm/test/ops/test_zeros.py +++ b/backends/arm/test/ops/test_zeros.py @@ -65,7 +65,10 @@ def test_zeros_tosa_INT(test_data: test_data_t): input_data(), ZerosAdd.aten_op, ) - pipeline.pop_stage("check.quant_nodes") + # Pop the quantization check stage if it exists as no + # quantization nodes will be present for int + fp inputs. + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -79,7 +82,10 @@ def test_zeros_u55_INT(test_data: test_data_t): ZerosAdd.aten_op, use_to_edge_transform_and_lower=True, ) - pipeline.pop_stage("check.quant_nodes") + # Pop the quantization check stage if it exists as no + # quantization nodes will be present for int + fp inputs. + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -92,8 +98,11 @@ def test_zeros_u85_INT(test_data: test_data_t): input_data(), ZerosAdd.aten_op, use_to_edge_transform_and_lower=True, - ).dump_artifact("to_edge_transform_and_lower") - pipeline.pop_stage("check.quant_nodes") + ) + # Pop the quantization check stage if it exists as no + # quantization nodes will be present for int + fp inputs. 
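Throughout these files the VgfPipeline tests now pick the lowering path with a quantize flag instead of spelling out a TOSA profile string. A hypothetical one-line helper capturing the mapping the renames imply (vgf_tosa_profile is illustrative and not part of the test infrastructure):

def vgf_tosa_profile(quantize: bool) -> str:
    # quantize=True exercises the integer profile, quantize=False the FP profile.
    return "TOSA-1.0+INT" if quantize else "TOSA-1.0+FP"


assert vgf_tosa_profile(True) == "TOSA-1.0+INT"
assert vgf_tosa_profile(False) == "TOSA-1.0+FP"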
+ if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -118,10 +127,13 @@ def test_zeros_tosa_INT_not_delegated(test_data: test_data_t): ZerosAdd.test_data, ) @common.SkipIfNoModelConverter -def test_zeros_vgf_FP(test_data: test_data_t): +def test_zeros_vgf_no_quant(test_data: test_data_t): input_data, init_data = test_data pipeline = VgfPipeline[input_t]( - ZerosAdd(*init_data), input_data(), ZerosAdd.aten_op, tosa_version="TOSA-1.0+FP" + ZerosAdd(*init_data), + input_data(), + ZerosAdd.aten_op, + quantize=False, ) pipeline.run() @@ -131,13 +143,16 @@ def test_zeros_vgf_FP(test_data: test_data_t): ZerosAdd.test_data, ) @common.SkipIfNoModelConverter -def test_zeros_vgf_INT(test_data: test_data_t): +def test_zeros_vgf_quant(test_data: test_data_t): input_data, init_data = test_data pipeline = VgfPipeline[input_t]( ZerosAdd(*init_data), input_data(), ZerosAdd.aten_op, - tosa_version="TOSA-1.0+INT", + quantize=True, ) - pipeline.pop_stage("check.quant_nodes") + # Pop the quantization check stage if it exists as no + # quantization nodes will be present for int + fp inputs. + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/passes/test_broadcast_args_pass.py b/backends/arm/test/passes/test_broadcast_args_pass.py index 719a0ddd622..e47b0d3a72b 100644 --- a/backends/arm/test/passes/test_broadcast_args_pass.py +++ b/backends/arm/test/passes/test_broadcast_args_pass.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import operator -from typing import Tuple +from typing import Callable, Tuple import torch from executorch.backends.arm._passes import BroadcastArgsPass @@ -12,17 +12,19 @@ from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import PassPipeline -input_t = Tuple[torch.Tensor] # Input x +input_t = Tuple[torch.Tensor, torch.Tensor] class NeedsMultipleBroadcastsModel(torch.nn.Module): test_data = (torch.rand(1, 10), torch.rand(10, 1)) - def __init__(self, op: operator): + def __init__( + self, op: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] + ) -> None: self.op = op super().__init__() - def forward(self, x: torch.Tensor, y: torch.Tensor): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return self.op(x, y) @@ -50,5 +52,6 @@ def test_multiple_broacasts_model(module: NeedsMultipleBroadcastsModel): ops_not_before_pass=ops_not_before_pass, ops_after_pass=ops_after_pass, pass_list=[BroadcastArgsPass], + tosa_extensions=["u55"], ) pipeline.run() diff --git a/backends/arm/test/passes/test_cast_int64_pass.py b/backends/arm/test/passes/test_cast_int64_pass.py index 7832fd87ed9..afcc0d1db36 100644 --- a/backends/arm/test/passes/test_cast_int64_pass.py +++ b/backends/arm/test/passes/test_cast_int64_pass.py @@ -21,7 +21,7 @@ class Int64Model(torch.nn.Module): "rand": (torch.rand(4),), } - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: return x + 3 diff --git a/backends/arm/test/passes/test_convert_expand_copy_to_repeat.py b/backends/arm/test/passes/test_convert_expand_copy_to_repeat.py index aa877c355bd..899472b2e8a 100644 --- a/backends/arm/test/passes/test_convert_expand_copy_to_repeat.py +++ b/backends/arm/test/passes/test_convert_expand_copy_to_repeat.py @@ -20,17 +20,17 @@ class Expand(torch.nn.Module): Basic expand model using torch.Tensor.expand function """ - def __init__(self): - super(Expand, self).__init__() 
+ def __init__(self) -> None: + super().__init__() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return x.expand(3, 4) def get_inputs(self) -> input_t: return (torch.rand(3, 1),) -def test_expand_to_repeat_tosa_INT(): +def test_expand_to_repeat_tosa_INT() -> None: module = Expand() pipeline = PassPipeline[input_t]( module, diff --git a/backends/arm/test/passes/test_convert_int64_const_ops_to_int32.py b/backends/arm/test/passes/test_convert_int64_const_ops_to_int32.py index ddb31625849..5366e5453c1 100644 --- a/backends/arm/test/passes/test_convert_int64_const_ops_to_int32.py +++ b/backends/arm/test/passes/test_convert_int64_const_ops_to_int32.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple, Union +from typing import Callable, ClassVar, Dict, Tuple, Union import pytest @@ -22,6 +22,10 @@ input_t1 = Tuple[torch.Tensor] # Input x input_t2 = Tuple[torch.Tensor, torch.Tensor] # Input x, y +Scalar = Union[bool, float, int] +ArangeNoneParam = Tuple[Callable[[], input_t1], Tuple[Scalar, Scalar, Scalar]] +FullNoneParam = Tuple[Callable[[], input_t1], Tuple[Tuple[int, ...], Scalar]] + ##################################################### ## Test arange(dtype=int64) -> arange(dtype=int32) ## @@ -29,11 +33,10 @@ class ArangeDefaultIncrementViewLessThan(torch.nn.Module): - - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: return (torch.arange(10, dtype=torch.int64) + 1).view(-1, 1) < x - test_data = { + test_data: ClassVar[Dict[str, input_t1]] = { "randint": ( torch.randint( 0, @@ -46,7 +49,9 @@ def forward(self, x: torch.Tensor): @common.parametrize("test_data", ArangeDefaultIncrementViewLessThan.test_data) -def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_FP(test_data: input_t1): +def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_FP( + test_data: input_t1, +) -> None: module = ArangeDefaultIncrementViewLessThan() aten_ops_checks = [ "torch.ops.aten.lt.Tensor", @@ -67,7 +72,9 @@ def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_FP(test_data: inp @common.parametrize("test_data", ArangeDefaultIncrementViewLessThan.test_data) -def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_INT(test_data: input_t1): +def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_INT( + test_data: input_t1, +) -> None: module = ArangeDefaultIncrementViewLessThan() aten_ops_checks = [ "torch.ops.aten.lt.Tensor", @@ -83,16 +90,14 @@ def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_INT(test_data: in aten_ops_checks, exir_ops_checks, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() class ArangeStartIncrementViewLessThan(torch.nn.Module): - - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: return (torch.arange(0, 10, dtype=torch.int64) + 1).view(-1, 1) < x - test_data = { + test_data: ClassVar[Dict[str, input_t1]] = { "randint": ( torch.randint( 0, @@ -105,7 +110,9 @@ def forward(self, x: torch.Tensor): @common.parametrize("test_data", ArangeStartIncrementViewLessThan.test_data) -def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_FP(test_data: input_t1): +def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_FP( + test_data: input_t1, +) -> None: module = ArangeStartIncrementViewLessThan() aten_ops_checks = [ "torch.ops.aten.lt.Tensor", @@ -126,7 +133,9 @@ def 
test_convert_arange_start_int64_dtype_to_int32_pass_tosa_FP(test_data: input @common.parametrize("test_data", ArangeStartIncrementViewLessThan.test_data) -def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_INT(test_data: input_t1): +def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_INT( + test_data: input_t1, +) -> None: module = ArangeStartIncrementViewLessThan() aten_ops_checks = [ "torch.ops.aten.lt.Tensor", @@ -142,16 +151,14 @@ def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_INT(test_data: inpu aten_ops_checks, exir_ops_checks, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() class ArangeStartStepIncrementViewLessThan(torch.nn.Module): - - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: return (torch.arange(0, 10, 2, dtype=torch.int64) + 1).view(-1, 1) < x - test_data = { + test_data: ClassVar[Dict[str, input_t1]] = { "randint": ( torch.randint( 0, @@ -166,7 +173,7 @@ def forward(self, x: torch.Tensor): @common.parametrize("test_data", ArangeStartStepIncrementViewLessThan.test_data) def test_convert_arange_start_step_int64_dtype_to_int32_pass_tosa_FP( test_data: input_t1, -): +) -> None: module = ArangeStartStepIncrementViewLessThan() aten_ops_checks = [ "torch.ops.aten.lt.Tensor", @@ -189,7 +196,7 @@ def test_convert_arange_start_step_int64_dtype_to_int32_pass_tosa_FP( @common.parametrize("test_data", ArangeStartStepIncrementViewLessThan.test_data) def test_convert_arange_start_step_int64_dtype_to_int32_pass_tosa_INT( test_data: input_t1, -): +) -> None: module = ArangeStartStepIncrementViewLessThan() aten_ops_checks = [ "torch.ops.aten.lt.Tensor", @@ -205,7 +212,6 @@ def test_convert_arange_start_step_int64_dtype_to_int32_pass_tosa_INT( aten_ops_checks, exir_ops_checks, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -225,7 +231,7 @@ def __init__(self, start: float, stop: float, step: float): def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.arange(*self.args) + x - test_data = { + test_data: ClassVar[Dict[str, ArangeNoneParam]] = { "int64": (lambda: (torch.randn(10, 1),), (0, 10, 1)), "float32_start": (lambda: (torch.randn(10, 1),), (0.0, 10, 1)), "float32_stop": (lambda: (torch.randn(10, 1),), (0, 10.0, 1)), @@ -238,11 +244,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @common.parametrize("test_data", ArangeAddDtypeNone.test_data) -def test_arange_dtype_none_tosa_FP(test_data): - input_data, init_data = test_data +def test_arange_dtype_none_tosa_FP(test_data: ArangeNoneParam) -> None: + input_factory, init_data = test_data pipeline = TosaPipelineFP[input_t1]( ArangeAddDtypeNone(*init_data), - input_data(), + input_factory(), ArangeAddDtypeNone.aten_op, ArangeAddDtypeNone.exir_op, ) @@ -250,11 +256,11 @@ def test_arange_dtype_none_tosa_FP(test_data): @common.parametrize("test_data", ArangeAddDtypeNone.test_data) -def test_arange_dtype_none_tosa_INT(test_data): - input_data, init_data = test_data +def test_arange_dtype_none_tosa_INT(test_data: ArangeNoneParam) -> None: + input_factory, init_data = test_data pipeline = TosaPipelineINT[input_t1]( ArangeAddDtypeNone(*init_data), - input_data(), + input_factory(), ArangeAddDtypeNone.aten_op, ArangeAddDtypeNone.exir_op, ) @@ -268,8 +274,7 @@ def test_arange_dtype_none_tosa_INT(test_data): class FullIncrementViewMulXLessThanY(torch.nn.Module): - - def forward(self, x: torch.Tensor, y: torch.Tensor): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return ( ( torch.full( @@ -286,7 +291,7 @@ 
def forward(self, x: torch.Tensor, y: torch.Tensor): * x ) < y - test_data = { + test_data: ClassVar[Dict[str, input_t2]] = { "randint": ( torch.randint( 0, @@ -305,7 +310,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): @common.parametrize("test_data", FullIncrementViewMulXLessThanY.test_data) -def test_convert_full_int64_dtype_to_int32_pass_tosa_FP(test_data: input_t1): +def test_convert_full_int64_dtype_to_int32_pass_tosa_FP( + test_data: input_t2, +) -> None: """ There are four int64 placeholders in the original graph: 1. _lifted_tensor_constant0: 1 @@ -347,7 +354,9 @@ def test_convert_full_int64_dtype_to_int32_pass_tosa_FP(test_data: input_t1): @common.parametrize("test_data", FullIncrementViewMulXLessThanY.test_data) -def test_convert_full_int64_dtype_to_int32_pass_tosa_INT(test_data: input_t1): +def test_convert_full_int64_dtype_to_int32_pass_tosa_INT( + test_data: input_t2, +) -> None: """ For INT profile, _lifted_tensor_constant0 is still int64 after applying ConvertInt64ConstOpsToInt32Pass(). And an int64->int32 cast is inserted at the beginning of the graph. @@ -375,13 +384,11 @@ def test_convert_full_int64_dtype_to_int32_pass_tosa_INT(test_data: input_t1): aten_ops_checks, exir_ops_checks, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() class RejectFullIncrementViewMulXLessThanY(torch.nn.Module): - - def forward(self, x: torch.Tensor, y: torch.Tensor): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return ( ( torch.full( @@ -398,7 +405,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): * x ) < y - test_data = { + test_data: ClassVar[Dict[str, input_t2]] = { "randint": ( torch.randint( 0, @@ -420,7 +427,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): @pytest.mark.xfail( reason="MLETORCH-1254: Add operator support check for aten.arange and aten.full" ) -def test_reject_convert_full_int64_dtype_to_int32_pass_tosa_FP(test_data: input_t1): +def test_reject_convert_full_int64_dtype_to_int32_pass_tosa_FP( + test_data: input_t2, +) -> None: module = RejectFullIncrementViewMulXLessThanY() aten_ops_checks = [ "torch.ops.aten.full.default", @@ -469,11 +478,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @common.parametrize("test_data", AddConstFullDtypeNone.test_data) -def test_full_dtype_none_tosa_FP(test_data): - input_data, init_data = test_data +def test_full_dtype_none_tosa_FP(test_data: FullNoneParam) -> None: + input_factory, init_data = test_data pipeline = TosaPipelineFP[input_t1]( AddConstFullDtypeNone(*init_data), - input_data(), + input_factory(), aten_op=[], exir_op=AddConstFullDtypeNone.exir_op, ) @@ -481,11 +490,11 @@ def test_full_dtype_none_tosa_FP(test_data): @common.parametrize("test_data", AddConstFullDtypeNone.test_data_bool) -def test_full_dtype_none_tosa_FP_bool(test_data): - input_data, init_data = test_data +def test_full_dtype_none_tosa_FP_bool(test_data: FullNoneParam) -> None: + input_factory, init_data = test_data pipeline = TosaPipelineFP[input_t1]( AddConstFullDtypeNone(*init_data), - input_data(), + input_factory(), aten_op=[], exir_op=AddConstFullDtypeNone.exir_op, ) @@ -501,9 +510,10 @@ def test_full_dtype_none_tosa_FP_bool(test_data): ) def test_full_dtype_none_tosa_INT(test_data): input_data, init_data = test_data + input_factory, init_data = test_data pipeline = TosaPipelineINT[input_t1]( AddConstFullDtypeNone(*init_data), - input_data(), + input_factory(), aten_op=[], exir_op=AddConstFullDtypeNone.exir_op, ) diff --git 
a/backends/arm/test/passes/test_convert_int64_output_ops_to_int32.py b/backends/arm/test/passes/test_convert_int64_output_ops_to_int32.py index ea7e03f8e21..bc7f8218183 100644 --- a/backends/arm/test/passes/test_convert_int64_output_ops_to_int32.py +++ b/backends/arm/test/passes/test_convert_int64_output_ops_to_int32.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import Callable, Dict, Tuple import torch from executorch.backends.arm._passes import ConvertInt64OutputOpsToInt32Pass @@ -21,20 +21,20 @@ class CastingToInt64Model(torch.nn.Module): - def __init__(self, target_dtype): + def __init__(self, target_dtype: torch.dtype) -> None: super().__init__() self.target_dtype = target_dtype - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: return x.to(dtype=self.target_dtype) -test_data_suite_convert = { +test_data_suite_convert: Dict[str, Callable[[], Tuple[torch.Tensor, torch.dtype]]] = { "fp32_input": lambda: (torch.rand((1, 2, 3, 4), dtype=torch.float32), torch.int64), "fp16_input": lambda: (torch.rand((1, 2, 3, 4), dtype=torch.float16), torch.int64), } -test_data_suite_remove = { +test_data_suite_remove: Dict[str, Callable[[], Tuple[torch.Tensor, torch.dtype]]] = { "int32_input": lambda: ( torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int32), torch.int64, @@ -42,8 +42,13 @@ def forward(self, x: torch.Tensor): } +TestDataFactory = Callable[[], Tuple[torch.Tensor, torch.dtype]] + + @common.parametrize("test_data", test_data_suite_convert) -def test_convert_or_remove_casting_to_int64_convert_tosa_FP(test_data: Tuple): +def test_convert_or_remove_casting_to_int64_convert_tosa_FP( + test_data: TestDataFactory, +) -> None: test_tensor, target_dtype = test_data() module = CastingToInt64Model(target_dtype) @@ -61,7 +66,9 @@ def test_convert_or_remove_casting_to_int64_convert_tosa_FP(test_data: Tuple): @common.parametrize("test_data", test_data_suite_remove) -def test_convert_or_remove_casting_to_int64_remove_tosa_FP(test_data: Tuple): +def test_convert_or_remove_casting_to_int64_remove_tosa_FP( + test_data: TestDataFactory, +) -> None: test_tensor, target_dtype = test_data() module = CastingToInt64Model(target_dtype) @@ -86,7 +93,7 @@ def test_convert_or_remove_casting_to_int64_remove_tosa_FP(test_data: Tuple): class Int64OutputModel(torch.nn.Module): - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: # return torch.argmax(x) # RuntimeError: Int did not match Long; But this is expected as we expect _argmax_i32 to generate int32 output # return (10 * torch.argmax(x) + 10).to(dtype=torch.int32) # [1]. This behavior is deprecated, and in a future PyTorch release outputs will not be resized unless they have zero elements. You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0). (function _resize_output_check) return (10 * torch.argmax(x, dim=-1) + 10) + 1.5 diff --git a/backends/arm/test/passes/test_convert_int_pow_to_muls.py b/backends/arm/test/passes/test_convert_int_pow_to_muls.py deleted file mode 100644 index 4eeff845749..00000000000 --- a/backends/arm/test/passes/test_convert_int_pow_to_muls.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -from typing import Tuple - -import torch -from executorch.backends.arm._passes import ConvertIntPowToMuls - -from executorch.backends.arm.test import common - -from executorch.backends.arm.test.tester.test_pipeline import PassPipeline - -input_t = Tuple[torch.nn.Module, int] # Input x - - -class Square(torch.nn.Module): - """ - Basic squaring - """ - - def forward(self, x): - return x.square() - - def get_inputs(self) -> input_t: - return (torch.rand(4, 4),) - - -class Pow(torch.nn.Module): - """ - Basic squaring - """ - - def __init__(self, exponent): - super().__init__() - self.exponent = exponent - - def forward(self, x): - return x.pow(self.exponent) - - def get_inputs(self) -> input_t: - return (torch.rand(4, 4),) - - -test_data = { - "square": (Square(), 1), - "pow_2": (Pow(2), 1), - "pow_3": (Pow(3), 2), - "pow_0": (Pow(0), 0), - "pow_neg_2": (Pow(-2), 1), -} - - -@common.parametrize("data", test_data) -def test_convert_pow_to_muls(data): - module = data[0] - nbr_muls = data[1] - pipeline = PassPipeline[input_t]( - module, - module.get_inputs(), - quantize=False, - ops_before_pass={ - "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar": 1, - }, - ops_not_before_pass=[], - ops_after_pass={ - "executorch_exir_dialects_edge__ops_aten_mul_Tensor": nbr_muls, - }, - ops_not_after_pass=["executorch_exir_dialects_edge__ops_pow_Tensor_Scalar"], - pass_list=[ConvertIntPowToMuls], - ) - pipeline.run() diff --git a/backends/arm/test/passes/test_convert_permute_singleton_to_view_pass.py b/backends/arm/test/passes/test_convert_permute_singleton_to_view_pass.py new file mode 100644 index 00000000000..eb395403e3f --- /dev/null +++ b/backends/arm/test/passes/test_convert_permute_singleton_to_view_pass.py @@ -0,0 +1,100 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Tuple + +import torch + +from executorch.backends.arm._passes import ConvertPermuteSingletonToViewPass +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline + +input_t = Tuple[torch.Tensor] + + +class PermuteSingletonAxesModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.permute(0, 2, 3, 1) + + @staticmethod + def input() -> input_t: + return (torch.randn(2, 1, 3, 4),) + + +def test_convert_permute_singleton_to_view_applies(): + module = PermuteSingletonAxesModule() + pipeline = PassPipeline[input_t]( + module, + PermuteSingletonAxesModule.input(), + quantize=False, + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + }, + ops_after_pass={ + "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1, + }, + ops_not_after_pass=[ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default", + ], + pass_list=[ConvertPermuteSingletonToViewPass], + ) + pipeline.run() + + +class PermuteNonSingletonModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.permute(0, 2, 1) + + @staticmethod + def input() -> input_t: + return (torch.randn(2, 3, 4),) + + +def test_convert_permute_singleton_to_view_skips_non_singleton(): + module = PermuteNonSingletonModule() + pipeline = PassPipeline[input_t]( + module, + PermuteNonSingletonModule.input(), + quantize=False, + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + }, + ops_after_pass={ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + }, + ops_not_after_pass=[ + "executorch_exir_dialects_edge__ops_aten_view_copy_default", + ], + pass_list=[ConvertPermuteSingletonToViewPass], + ) + pipeline.run() + + +class PermuteSameSizedNonSingletonModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.permute(2, 1, 0) + + @staticmethod + def input() -> input_t: + return (torch.randn(2, 1, 2),) + + +def test_convert_permute_singleton_to_view_skips_same_sized_non_singleton(): + module = PermuteSameSizedNonSingletonModule() + pipeline = PassPipeline[input_t]( + module, + PermuteSameSizedNonSingletonModule.input(), + quantize=False, + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + }, + ops_after_pass={ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + }, + ops_not_after_pass=[ + "executorch_exir_dialects_edge__ops_aten_view_copy_default", + ], + pass_list=[ConvertPermuteSingletonToViewPass], + ) + pipeline.run() diff --git a/backends/arm/test/passes/test_convert_split_to_slice.py b/backends/arm/test/passes/test_convert_split_to_slice.py index fba52308ff0..3321693babd 100644 --- a/backends/arm/test/passes/test_convert_split_to_slice.py +++ b/backends/arm/test/passes/test_convert_split_to_slice.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import cast, Dict, Protocol, Tuple import torch from executorch.backends.arm._passes.convert_split_to_slice import ( @@ -17,6 +17,10 @@ input_t = Tuple[torch.Tensor] # Input x +class ModuleWithInputs(Protocol): + def get_inputs(self) -> input_t: ... 
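The ConvertPermuteSingletonToViewPass tests above rest on the fact that a permutation which only relocates size-1 axes never reorders data, so it can be rewritten as a view, while moving non-singleton axes cannot. A small standalone check of that property for the shapes used in the tests:

import torch

x = torch.arange(24.0).reshape(2, 1, 3, 4)
# Only the singleton axis moves, so element order is preserved and the permute
# equals a plain reshape to the permuted shape.
assert torch.equal(x.permute(0, 2, 3, 1), x.reshape(2, 3, 4, 1))

y = torch.arange(24.0).reshape(2, 3, 4)
# Swapping two non-singleton axes reorders data, so no view rewrite applies.
assert not torch.equal(y.permute(0, 2, 1), y.reshape(2, 4, 3))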
+ + class Split(torch.nn.Module): """ Basic split model using torch.split function @@ -25,7 +29,7 @@ class Split(torch.nn.Module): def get_inputs(self) -> input_t: return (torch.rand(10),) - def forward(self, x): + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]: return torch.split(x, 2) @@ -37,17 +41,21 @@ class SplitTensor(torch.nn.Module): def get_inputs(self) -> input_t: return (torch.rand(10),) - def forward(self, x): + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]: return x.split(2) -modules = {"split_basic": Split(), "split_tensor": SplitTensor()} +modules: Dict[str, ModuleWithInputs] = { + "split_basic": Split(), + "split_tensor": SplitTensor(), +} @common.parametrize("module", modules) -def test_split_to_slice_tosa_INT(module): +def test_split_to_slice_tosa_INT(module: ModuleWithInputs) -> None: + nn_module = cast(torch.nn.Module, module) pipeline = PassPipeline[input_t]( - module, + nn_module, module.get_inputs(), quantize=True, ops_before_pass={ diff --git a/backends/arm/test/passes/test_convert_to_clamp.py b/backends/arm/test/passes/test_convert_to_clamp.py index cc854eeacd7..b54c177e52f 100644 --- a/backends/arm/test/passes/test_convert_to_clamp.py +++ b/backends/arm/test/passes/test_convert_to_clamp.py @@ -4,10 +4,10 @@ # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import ClassVar, Dict, Tuple import torch -from executorch.backends.arm._passes.convert_to_clamp import ConvertToClampPass +from executorch.backends.arm._passes.convert_to_clamp_pass import ConvertToClampPass from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import PassPipeline @@ -16,26 +16,26 @@ class HardTanh(torch.nn.Module): - test_data = {"rand": (torch.rand(1, 64, 64, 3),)} + test_data: ClassVar[Dict[str, input_t]] = {"rand": (torch.rand(1, 64, 64, 3),)} def __init__(self): super().__init__() self.hardtanh = torch.nn.Hardtanh() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.hardtanh(x) class ReLU(torch.nn.Module): - test_data = {"rand": (torch.rand(1, 64, 64, 3),)} + test_data: ClassVar[Dict[str, input_t]] = {"rand": (torch.rand(1, 64, 64, 3),)} def __init__(self): super().__init__() self.relu = torch.nn.ReLU() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.relu(x) @@ -45,7 +45,7 @@ def forward(self, x): @common.parametrize("test_data", HardTanh.test_data) -def test_tosa_FP_hardtahn(test_data: input_t): +def test_tosa_FP_hardtahn(test_data: input_t) -> None: module = HardTanh() op_checks_before_pass = { "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, @@ -69,7 +69,7 @@ def test_tosa_FP_hardtahn(test_data: input_t): @common.parametrize("test_data", ReLU.test_data) -def test_tosa_FP_relu(test_data: input_t): +def test_tosa_FP_relu(test_data: input_t) -> None: module = ReLU() op_checks_before_pass = { "executorch_exir_dialects_edge__ops_aten_relu_default": 1, diff --git a/backends/arm/test/passes/test_decompose_avg_pool2d_pass.py b/backends/arm/test/passes/test_decompose_avg_pool2d_pass.py index 4d686039456..c4aebae2292 100644 --- a/backends/arm/test/passes/test_decompose_avg_pool2d_pass.py +++ b/backends/arm/test/passes/test_decompose_avg_pool2d_pass.py @@ -3,16 +3,22 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Tuple +from typing import cast, Dict, Protocol, Tuple import torch -from executorch.backends.arm._passes.decompose_avg_pool2d import DecomposeAvgPool2d +from executorch.backends.arm._passes.decompose_avg_pool2d_pass import ( + DecomposeAvgPool2dPass, +) from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import PassPipeline input_t = Tuple[torch.Tensor] # Input x +class ModuleWithInputs(Protocol): + def get_inputs(self) -> input_t: ... + + class AvgPool2dWithStride(torch.nn.Module): """ avg_pool2d model with explicit stride parameter @@ -21,7 +27,7 @@ class AvgPool2dWithStride(torch.nn.Module): def get_inputs(self) -> input_t: return (torch.rand(1, 3, 8, 8),) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) @@ -33,7 +39,7 @@ class AvgPool2dWithoutStride(torch.nn.Module): def get_inputs(self) -> input_t: return (torch.rand(1, 3, 8, 8),) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.avg_pool2d(x, kernel_size=3) @@ -45,11 +51,11 @@ class AvgPool2dListKernel(torch.nn.Module): def get_inputs(self) -> input_t: return (torch.rand(1, 3, 8, 8),) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.avg_pool2d(x, kernel_size=[2, 3]) -modules = { +modules: Dict[str, ModuleWithInputs] = { "avg_pool2d_with_stride": AvgPool2dWithStride(), "avg_pool2d_without_stride": AvgPool2dWithoutStride(), "avg_pool2d_list_kernel": AvgPool2dListKernel(), @@ -57,10 +63,11 @@ def forward(self, x): @common.parametrize("module", modules) -def test_decompose_avg_pool2d_tosa_MI(module): +def test_decompose_avg_pool2d_tosa_MI(module: ModuleWithInputs) -> None: """Test that DecomposeAvgPool2d pass works correctly with and without stride parameters.""" + nn_module = cast(torch.nn.Module, module) pipeline = PassPipeline[input_t]( - module, + nn_module, module.get_inputs(), quantize=False, ops_before_pass={ @@ -70,6 +77,6 @@ def test_decompose_avg_pool2d_tosa_MI(module): # After decomposition, we should still see avg_pool2d (transformed) "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1, }, - pass_list=[DecomposeAvgPool2d], + pass_list=[DecomposeAvgPool2dPass], ) pipeline.run() diff --git a/backends/arm/test/passes/test_decompose_cosine_similarity_pass.py b/backends/arm/test/passes/test_decompose_cosine_similarity_pass.py index 80a328f39c6..8dec8408584 100644 --- a/backends/arm/test/passes/test_decompose_cosine_similarity_pass.py +++ b/backends/arm/test/passes/test_decompose_cosine_similarity_pass.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import cast, Dict, Protocol, Tuple import torch @@ -16,6 +16,10 @@ input_t = Tuple[torch.Tensor, torch.Tensor] +class ModuleWithInputs(Protocol): + def get_inputs(self) -> input_t: ... 
+ + class CosineSimilarityModel(torch.nn.Module): def get_inputs(self) -> input_t: return (torch.rand(2, 3, 4), torch.rand(2, 3, 4)) @@ -24,11 +28,11 @@ def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: return torch.cosine_similarity(x1, x2, dim=1, eps=1e-6) -modules = {"cosine_basic": CosineSimilarityModel()} +modules: Dict[str, ModuleWithInputs] = {"cosine_basic": CosineSimilarityModel()} @common.parametrize("module", modules) -def test_decompose_cosine_similarity_tosa_INT(module): +def test_decompose_cosine_similarity_tosa_INT(module: ModuleWithInputs) -> None: ops_after_pass = { "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 5, @@ -40,8 +44,9 @@ def test_decompose_cosine_similarity_tosa_INT(module): "executorch_exir_dialects_edge__ops_aten_reciprocal_default": 1, } + nn_module = cast(torch.nn.Module, module) pipeline = PassPipeline[input_t]( - module, + nn_module, module.get_inputs(), ops_before_pass=None, ops_not_before_pass=None, diff --git a/backends/arm/test/passes/test_decompose_div_pass.py b/backends/arm/test/passes/test_decompose_div_pass.py index b52e264bf11..3d6293b2194 100644 --- a/backends/arm/test/passes/test_decompose_div_pass.py +++ b/backends/arm/test/passes/test_decompose_div_pass.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import cast, Dict, Protocol, Tuple import torch from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass @@ -15,6 +15,10 @@ input_t = Tuple[torch.Tensor] # Input x +class ModuleWithInputs(Protocol): + def get_inputs(self) -> input_t: ... + + class Div(torch.nn.Module): """ Basic div model using torch.div @@ -23,7 +27,7 @@ class Div(torch.nn.Module): def get_inputs(self) -> input_t: return (torch.rand(10),) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.div(x, 2) @@ -35,17 +39,18 @@ class DivTensor(torch.nn.Module): def get_inputs(self) -> input_t: return (torch.rand(10),) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return x.div(2) -modules = {"div_basic": Div(), "div_tensor": DivTensor()} +modules: Dict[str, ModuleWithInputs] = {"div_basic": Div(), "div_tensor": DivTensor()} @common.parametrize("module", modules) -def test_decompose_div_tosa_FP(module): +def test_decompose_div_tosa_FP(module: ModuleWithInputs) -> None: + nn_module = cast(torch.nn.Module, module) pipeline = PassPipeline[input_t]( - module, + nn_module, module.get_inputs(), quantize=False, ops_before_pass={ diff --git a/backends/arm/test/passes/test_decompose_int_pow_pass.py b/backends/arm/test/passes/test_decompose_int_pow_pass.py new file mode 100644 index 00000000000..a9a74c633e1 --- /dev/null +++ b/backends/arm/test/passes/test_decompose_int_pow_pass.py @@ -0,0 +1,80 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import cast, Dict, Protocol, Tuple + +import torch +from executorch.backends.arm._passes import DecomposeIntPowPass + +from executorch.backends.arm.test import common + +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline + +input_t = Tuple[torch.Tensor] # Inputs to the module + + +class ModuleWithInputs(Protocol): + def get_inputs(self) -> input_t: ... 
+
+
+TestParam = Tuple[ModuleWithInputs, int]
+
+
+class Square(torch.nn.Module):
+    """
+    Basic squaring
+    """
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x.square()
+
+    def get_inputs(self) -> input_t:
+        return (torch.rand(4, 4),)
+
+
+class Pow(torch.nn.Module):
+    """
+    Basic power with an integer exponent
+    """
+
+    def __init__(self, exponent: int) -> None:
+        super().__init__()
+        self.exponent = exponent
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x.pow(self.exponent)
+
+    def get_inputs(self) -> input_t:
+        return (torch.rand(4, 4),)
+
+
+test_data: Dict[str, TestParam] = {
+    "square": (Square(), 1),
+    "pow_2": (Pow(2), 1),
+    "pow_3": (Pow(3), 2),
+    "pow_0": (Pow(0), 0),
+    "pow_neg_2": (Pow(-2), 1),
+}
+
+
+@common.parametrize("data", test_data)
+def test_decompose_int_pow(data: TestParam) -> None:
+    module_with_inputs, nbr_muls = data
+    module = cast(torch.nn.Module, module_with_inputs)
+    pipeline = PassPipeline[input_t](
+        module,
+        module_with_inputs.get_inputs(),
+        quantize=False,
+        ops_before_pass={
+            "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar": 1,
+        },
+        ops_not_before_pass=[],
+        ops_after_pass={
+            "executorch_exir_dialects_edge__ops_aten_mul_Tensor": nbr_muls,
+        },
+        ops_not_after_pass=["executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar"],
+        pass_list=[DecomposeIntPowPass],
+    )
+    pipeline.run()
diff --git a/backends/arm/test/passes/test_decompose_layernorm_pass.py b/backends/arm/test/passes/test_decompose_layernorm_pass.py
index d3c2cd6efd7..02fed874765 100644
--- a/backends/arm/test/passes/test_decompose_layernorm_pass.py
+++ b/backends/arm/test/passes/test_decompose_layernorm_pass.py
@@ -24,7 +24,7 @@ def __init__(self):
         super(LayerNorm, self).__init__()
         self.layer_norm = torch.nn.LayerNorm(10)
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.layer_norm(x)
         return x
diff --git a/backends/arm/test/passes/test_decompose_linalg_vector_norm_pass.py b/backends/arm/test/passes/test_decompose_linalg_vector_norm_pass.py
index 5b4c84edbfd..b926e15b92a 100644
--- a/backends/arm/test/passes/test_decompose_linalg_vector_norm_pass.py
+++ b/backends/arm/test/passes/test_decompose_linalg_vector_norm_pass.py
@@ -3,12 +3,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Tuple
+from typing import cast, Dict, Protocol, Tuple
 import torch
 from executorch.backends.arm._passes.decompose_linalg_vector_norm_pass import (
-    DecomposeLinearVectorNormPass,
+    DecomposeLinalgVectorNormPass,
 )
 from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
@@ -16,6 +16,12 @@ input_t = Tuple[torch.Tensor]
+class ModuleWithInputs(Protocol):
+    ord: float | None
+
+    def get_inputs(self) -> input_t: ...
+
+
 class VectorNormModel(torch.nn.Module):
     """
     A test module with torch.linalg.vector_norm.
@@ -24,7 +30,9 @@ class VectorNormModel(torch.nn.Module):
     We support only order 1 or 2.
""" - def __init__(self, ord: float = None, dim=None, keepdim: bool = False): + def __init__( + self, ord: float | None = None, dim=None, keepdim: bool = False + ) -> None: super().__init__() self.ord = ord self.dim = dim @@ -55,9 +63,9 @@ def get_inputs(self) -> input_t: @common.parametrize("module", modules) -def test_decompose_vector_norm_tosa_INT(module): +def test_decompose_vector_norm_tosa_INT(module: ModuleWithInputs) -> None: """ - This test creates a PassPipeline that applies the DecomposeLinearVectorNormPass. + This test creates a PassPipeline that applies the DecomposeLinalgVectorNormPass. The expected primitive ops vary depending on the norm order: - p == 1: should decompose to ABS and SUM. - p == 2 (default): should decompose to MUL, SUM, and SQRT. @@ -65,6 +73,7 @@ def test_decompose_vector_norm_tosa_INT(module): """ ord_val = module.ord if module.ord is not None else 2.0 + ops_after_pass: Dict[str, int] if ord_val == 1: ops_after_pass = { "executorch_exir_dialects_edge__ops_aten_abs_default": 1, @@ -75,9 +84,16 @@ def test_decompose_vector_norm_tosa_INT(module): "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar": 2, "executorch_exir_dialects_edge__ops_aten_sum_dim_IntList": 1, } + else: + ops_after_pass = { + "executorch_exir_dialects_edge__ops_aten_abs_default": 1, + "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar": 2, + "executorch_exir_dialects_edge__ops_aten_sum_dim_IntList": 1, + } + nn_module = cast(torch.nn.Module, module) pipeline = PassPipeline[input_t]( - module, + nn_module, module.get_inputs(), # The op is decomposed in legalization aten -> edge, so we are not able to check ops before ops_before_pass=None, @@ -86,6 +102,6 @@ def test_decompose_vector_norm_tosa_INT(module): ops_not_after_pass=[ "executorch_exir_dialects_edge__ops_aten_linarg_vector_norm_default", ], - pass_list=[DecomposeLinearVectorNormPass], + pass_list=[DecomposeLinalgVectorNormPass], ) pipeline.run() diff --git a/backends/arm/test/passes/test_decompose_meandim_pass.py b/backends/arm/test/passes/test_decompose_meandim_pass.py index 22dda5d9244..ac7f3f883c4 100644 --- a/backends/arm/test/passes/test_decompose_meandim_pass.py +++ b/backends/arm/test/passes/test_decompose_meandim_pass.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import cast, Dict, Protocol, Tuple import torch @@ -17,6 +17,15 @@ input_t = Tuple[torch.Tensor] # Input x +class ModuleWithMeanAttrs(Protocol): + ops_after_pass: Dict[str, int] + ops_not_after_pass: list[str] + u55_ops_after_pass: Dict[str, int] + u55_ops_not_after_pass: list[str] + + def get_inputs(self) -> input_t: ... 
+ + class MeanDim(torch.nn.Module): """ Basic mean model using torch.mean with keepdim = True @@ -28,7 +37,7 @@ class MeanDim(torch.nn.Module): } ops_not_after_pass = u55_ops_not_after_pass = [ - "torch.ops.aten.view_copy.default", + "torch.ops.aten.reshape.default", "torch.ops.aten.avg_pool2d.default", "torch.ops.aten.mean.dim", ] @@ -36,7 +45,7 @@ class MeanDim(torch.nn.Module): def __init__(self): super(MeanDim, self).__init__() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.mean(x, (0, 1), True) def get_inputs(self) -> input_t: @@ -52,7 +61,7 @@ class MeanDimTensor(torch.nn.Module): "torch.ops.aten.sum.dim_IntList": 2, "torch.ops.aten.mul.Tensor": 1, "torch.ops.aten.avg_pool2d.default": 1, - "torch.ops.aten.view_copy.default": 1, + "torch.ops.aten.reshape.default": 1, } ops_not_after_pass = [ @@ -62,7 +71,7 @@ class MeanDimTensor(torch.nn.Module): u55_ops_after_pass = { "torch.ops.aten.sum.dim_IntList": 2, "torch.ops.aten.mul.Tensor": 1, - "torch.ops.aten.view_copy.default": 1, + "torch.ops.aten.reshape.default": 1, } u55_ops_not_after_pass = [ @@ -73,25 +82,25 @@ class MeanDimTensor(torch.nn.Module): def __init__(self): super(MeanDimTensor, self).__init__() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return x.mean((0, 2), False) def get_inputs(self) -> input_t: return (torch.rand(4, 4, 4, 4),) -modules = {"meandim_basic": MeanDim(), "meandim_tensor": MeanDimTensor()} +modules: Dict[str, ModuleWithMeanAttrs] = { + "meandim_basic": MeanDim(), + "meandim_tensor": MeanDimTensor(), +} @common.parametrize("module", modules) -def test_decompose_meandim_tosa_INT(module): +def test_decompose_meandim_tosa_INT(module: ModuleWithMeanAttrs) -> None: # Decompose meandim_pass requires initiating the pas with args, which is not supported # by RunPasses in the arm_tester -> PassPipeline cannot be used. - pipeline = TosaPipelineINT[input_t]( - module, - module.get_inputs(), - [], - ) + nn_module = cast(torch.nn.Module, module) + pipeline = TosaPipelineINT[input_t](nn_module, module.get_inputs(), []) pipeline.pop_stage("check_not.exir") pipeline.pop_stage("check_count.exir") pipeline.pop_stage("to_executorch") @@ -106,11 +115,12 @@ def test_decompose_meandim_tosa_INT(module): @common.parametrize("module", modules) -def test_decompose_meandim_u55_INT(module): +def test_decompose_meandim_u55_INT(module: ModuleWithMeanAttrs) -> None: # Decompose meandim_pass requires initiating the pas with args, which is not supported # by RunPasses in the arm_tester -> PassPipeline cannot be used. + nn_module = cast(torch.nn.Module, module) pipeline = EthosU55PipelineINT[input_t]( - module, module.get_inputs(), [], run_on_fvp=False + nn_module, module.get_inputs(), [], run_on_fvp=False ) pipeline.pop_stage("check_not.exir") pipeline.pop_stage("check_count.exir") diff --git a/backends/arm/test/passes/test_decompose_quant_nodes.py b/backends/arm/test/passes/test_decompose_quant_nodes.py new file mode 100644 index 00000000000..fe216164f86 --- /dev/null +++ b/backends/arm/test/passes/test_decompose_quant_nodes.py @@ -0,0 +1,44 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Tuple + +import torch +from executorch.backends.arm._passes import DecomposeQuantNodesPass +from executorch.backends.arm.test.common import parametrize +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline + + +class Mul(torch.nn.Module): + test_data = { + "randn": (torch.randn(1, 3, 16, 16), torch.randn(1, 3, 16, 16)), + "large_randn": (10e10 * torch.randn(1, 3, 16, 16), torch.randn(1, 3, 16, 16)), + } + + def forward(self, x, y): + return x * y + + +@parametrize("test_data", Mul.test_data) +def test_decompose_quant_nodes_pass(test_data: Tuple[torch.Tensor]): + module = Mul() + q_dq_ops = { + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + # Verify that DecomposeQuantNodesPass removes quantize/dequantize nodes + # and that the output is correct. + pipeline = PassPipeline( + module, + test_data, + quantize=True, + pass_list=[ + DecomposeQuantNodesPass, + ], + ops_before_pass=q_dq_ops, + ops_not_after_pass=list(q_dq_ops.keys()), + tosa_extensions=["FP"], + ) + pipeline.run() diff --git a/backends/arm/test/passes/test_decompose_softmax_pass.py b/backends/arm/test/passes/test_decompose_softmax_pass.py index 3af1976e3f3..28d7bbb7fdf 100644 --- a/backends/arm/test/passes/test_decompose_softmax_pass.py +++ b/backends/arm/test/passes/test_decompose_softmax_pass.py @@ -22,7 +22,7 @@ def __init__(self): super(Softmax, self).__init__() self.softmax = torch.nn.Softmax(dim=1) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.softmax(x) return x @@ -39,7 +39,7 @@ def __init__(self): super(SoftmaxLog, self).__init__() self.softmax = torch.nn.LogSoftmax(dim=1) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.softmax(x) return x diff --git a/backends/arm/test/passes/test_decompose_tosa_unsupported_clamp_pass.py b/backends/arm/test/passes/test_decompose_tosa_unsupported_clamp_pass.py new file mode 100644 index 00000000000..9ceeb1b93be --- /dev/null +++ b/backends/arm/test/passes/test_decompose_tosa_unsupported_clamp_pass.py @@ -0,0 +1,73 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Tuple + +import torch +from executorch.backends.arm._passes.decompose_tosa_unsupported_clamp_pass import ( + DecomposeTOSAUnsupportedClampPass, +) +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline + +input_t = Tuple[torch.Tensor] + + +class ClampInt32(torch.nn.Module): + test_data = {"rand": (torch.randint(-50, 50, (2, 3), dtype=torch.int32),)} + + def forward(self, x: torch.Tensor): + return torch.clamp(x, -10, 5) + + +@common.parametrize("test_data", ClampInt32.test_data) +def test_decompose_int32_clamp_pass(test_data: input_t): + module = ClampInt32() + pipeline = PassPipeline[input_t]( + module, + test_data, + quantize=False, + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten_clamp_default": 1, + }, + ops_after_pass={ + "executorch_exir_dialects_edge__ops_aten_minimum_default": 1, + "executorch_exir_dialects_edge__ops_aten_maximum_default": 1, + }, + ops_not_after_pass=[ + "executorch_exir_dialects_edge__ops_aten_clamp_default", + ], + pass_list=[DecomposeTOSAUnsupportedClampPass], + ) + pipeline.run() + + +class ClampTensorInt32(torch.nn.Module): + test_data = {"rand": (torch.randint(-50, 50, (2, 3), dtype=torch.int32),)} + + def forward(self, x: torch.Tensor): + return torch.clamp(x, torch.tensor(-10), torch.tensor(5)) + + +@common.parametrize("test_data", ClampTensorInt32.test_data) +def test_decompose_int32_clamp_tensor_pass(test_data: input_t): + module = ClampTensorInt32() + pipeline = PassPipeline[input_t]( + module, + test_data, + quantize=False, + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten_clamp_Tensor": 1, + }, + ops_after_pass={ + "executorch_exir_dialects_edge__ops_aten_minimum_default": 1, + "executorch_exir_dialects_edge__ops_aten_maximum_default": 1, + }, + ops_not_after_pass=[ + "executorch_exir_dialects_edge__ops_aten_clamp_Tensor", + ], + pass_list=[DecomposeTOSAUnsupportedClampPass], + ) + pipeline.run() diff --git a/backends/arm/test/passes/test_decompose_var_pass.py b/backends/arm/test/passes/test_decompose_var_pass.py index c347a2f667c..2e31c9de817 100644 --- a/backends/arm/test/passes/test_decompose_var_pass.py +++ b/backends/arm/test/passes/test_decompose_var_pass.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import cast, Dict, Protocol, Tuple import torch from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass @@ -15,6 +15,10 @@ input_t = Tuple[torch.Tensor] # Input x +class ModuleWithInputs(Protocol): + def get_inputs(self) -> input_t: ... + + class VarDim(torch.nn.Module): """ Basic variance model using torch.Tensor.var function. 
@@ -24,7 +28,7 @@ def __init__(self, keepdim): super(VarDim, self).__init__() self.keepdim = keepdim - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return x.var(dim=-1, keepdim=self.keepdim) def get_inputs(self) -> input_t: @@ -40,14 +44,14 @@ def __init__(self, keepdim): super(VarCorrection, self).__init__() self.keepdim = keepdim - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.var(x, -1, keepdim=self.keepdim) def get_inputs(self) -> input_t: return (torch.rand(4, 4),) -modules = { +modules: Dict[str, ModuleWithInputs] = { "vardim_keepdim": VarDim(True), "vardim_no_keepdim": VarDim(False), "varcorrection_keepdim": VarCorrection(True), @@ -56,9 +60,10 @@ def get_inputs(self) -> input_t: @common.parametrize("module", modules) -def test_decompose_var_tosa_FP(module): +def test_decompose_var_tosa_FP(module: ModuleWithInputs) -> None: + nn_module = cast(torch.nn.Module, module) pipeline = PassPipeline[input_t]( - module, + nn_module, module.get_inputs(), quantize=False, ops_before_pass={ diff --git a/backends/arm/test/passes/test_decorate_fp32_to_int32_casting_pass.py b/backends/arm/test/passes/test_decorate_fp32_to_int32_casting_pass.py index 84573878aef..588428aa31b 100644 --- a/backends/arm/test/passes/test_decorate_fp32_to_int32_casting_pass.py +++ b/backends/arm/test/passes/test_decorate_fp32_to_int32_casting_pass.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import Callable, Dict, Tuple import torch from executorch.backends.arm.test import common, conftest @@ -17,15 +17,15 @@ class FP32ToINT32Casting(torch.nn.Module): - def __init__(self, target_dtype): + def __init__(self, target_dtype: torch.dtype) -> None: super().__init__() self.target_dtype = target_dtype - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: return x.to(self.target_dtype) -test_data_fp32_input = { +test_data_fp32_input: Dict[str, Callable[[], Tuple[torch.Tensor, torch.dtype]]] = { "fp32_input_rank1": lambda: ( torch.rand((4), dtype=torch.float32), torch.int32, @@ -46,7 +46,9 @@ def forward(self, x: torch.Tensor): @common.parametrize("test_data", test_data_fp32_input) -def test_decorate_fp32_to_int32_casting_tosa_FP(test_data: Tuple): +def test_decorate_fp32_to_int32_casting_tosa_FP( + test_data: Callable[[], Tuple[torch.Tensor, torch.dtype]] +) -> None: test_tensor, target_dtype = test_data() module = FP32ToINT32Casting(target_dtype) @@ -61,7 +63,9 @@ def test_decorate_fp32_to_int32_casting_tosa_FP(test_data: Tuple): @common.parametrize("test_data", test_data_fp32_input) -def test_decorate_fp32_to_int32_casting_tosa_INT(test_data: Tuple): +def test_decorate_fp32_to_int32_casting_tosa_INT( + test_data: Callable[[], Tuple[torch.Tensor, torch.dtype]] +) -> None: """ Casting operation involving floating-point dtypes will be rejected in INT/INT profile. Therefore, the DecorateFp32toInt32CastingPass is not required in this profile. diff --git a/backends/arm/test/passes/test_fold_qdq_pass.py b/backends/arm/test/passes/test_fold_qdq_pass.py index 994676ff442..2015ab61834 100644 --- a/backends/arm/test/passes/test_fold_qdq_pass.py +++ b/backends/arm/test/passes/test_fold_qdq_pass.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Tuple +from typing import ClassVar, Dict, Tuple import torch from executorch.backends.arm._passes import FoldAndAnnotateQParamsPass @@ -15,16 +15,16 @@ class SimpleQuantizeModel(torch.nn.Module): - test_data = { + test_data: ClassVar[Dict[str, input_t]] = { "rand": (torch.rand(1, 1280, 7, 7), torch.rand(1, 1280, 7, 7)), } - def forward(self, x, y): - return x + torch.max((x + x), (y + y)) + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + torch.maximum((x + x), (y + y)) @common.parametrize("test_data", SimpleQuantizeModel.test_data) -def test_fold_qdq_pass_tosa_INT(test_data: input_t): +def test_fold_qdq_pass_tosa_INT(test_data: input_t) -> None: """ Tests the FoldAndAnnotateQParamsPass which folds dq/q nodes into the node and stores the quantization parameters in meta. diff --git a/backends/arm/test/passes/test_fuse_batchnorm_pass.py b/backends/arm/test/passes/test_fuse_batchnorm_pass.py index 59fae7cafbd..eb073265a63 100644 --- a/backends/arm/test/passes/test_fuse_batchnorm_pass.py +++ b/backends/arm/test/passes/test_fuse_batchnorm_pass.py @@ -3,22 +3,29 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import cast, ClassVar, Dict, Protocol, Tuple import torch -from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass +from executorch.backends.arm._passes.fuse_batch_norm2d_pass import FuseBatchNorm2dPass from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import PassPipeline input_t = Tuple[torch.Tensor] # Input x +class ModuleWithBatchNormAttrs(Protocol): + ops_before_pass: Dict[str, int] + ops_after_pass: Dict[str, int] + + def get_inputs(self) -> input_t: ... 
+ + class MergeOneOfTwoBN(torch.nn.Module): - ops_before_pass = { + ops_before_pass: ClassVar[Dict[str, int]] = { "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2, "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, } - ops_after_pass = { + ops_after_pass: ClassVar[Dict[str, int]] = { "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 0, "executorch_exir_dialects_edge__ops_aten_convolution_default": 2, } @@ -39,7 +46,7 @@ def __init__(self, affine: bool): def get_inputs(self) -> input_t: return (torch.randn(1, 3, 256, 256),) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv2d(x) x = self.batch_norm2d(x) x = self.relu6(x) @@ -48,11 +55,11 @@ def forward(self, x): class MergeTwosOfTwoBN(torch.nn.Module): - ops_before_pass = { + ops_before_pass: ClassVar[Dict[str, int]] = { "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2, "executorch_exir_dialects_edge__ops_aten_convolution_default": 2, } - ops_after_pass = { + ops_after_pass: ClassVar[Dict[str, int]] = { "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 0, "executorch_exir_dialects_edge__ops_aten_convolution_default": 2, } @@ -76,7 +83,7 @@ def __init__(self, affine: bool): def get_inputs(self) -> input_t: return (torch.randn(1, 3, 256, 256),) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv2d(x) x = self.batch_norm2d(x) x = self.relu6(x) @@ -86,11 +93,11 @@ def forward(self, x): class MergeMultipleUsersBN(torch.nn.Module): - ops_before_pass = { + ops_before_pass: ClassVar[Dict[str, int]] = { "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2, "executorch_exir_dialects_edge__ops_aten_convolution_default": 3, } - ops_after_pass = { + ops_after_pass: ClassVar[Dict[str, int]] = { "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 0, "executorch_exir_dialects_edge__ops_aten_convolution_default": 4, } @@ -114,7 +121,7 @@ def __init__(self, affine: bool): def get_inputs(self) -> input_t: return (torch.randn(1, 3, 256, 256),) - def forward(self, x): + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: x1 = self.conv2d(x) x = self.batch_norm2d( x1 @@ -129,24 +136,29 @@ def forward(self, x): return z, a -modules = { - "merge_one_of_two_bn_affine": MergeOneOfTwoBN(True), - "merge_one_of_two_bn": MergeOneOfTwoBN(False), - "merge_two_of_two_bn_affine": MergeTwosOfTwoBN(True), - "merge_multiple_users_bn_affine": MergeMultipleUsersBN(True), +modules: Dict[str, ModuleWithBatchNormAttrs] = { + "merge_one_of_two_bn_affine": cast(ModuleWithBatchNormAttrs, MergeOneOfTwoBN(True)), + "merge_one_of_two_bn": cast(ModuleWithBatchNormAttrs, MergeOneOfTwoBN(False)), + "merge_two_of_two_bn_affine": cast( + ModuleWithBatchNormAttrs, MergeTwosOfTwoBN(True) + ), + "merge_multiple_users_bn_affine": cast( + ModuleWithBatchNormAttrs, MergeMultipleUsersBN(True) + ), } @common.parametrize("module", modules) -def test_fuse_batchnorm_tosa_FP(module: torch.nn.Module): +def test_fuse_batchnorm_tosa_FP(module: ModuleWithBatchNormAttrs) -> None: """Test various cases where the batchnorm should either be fused with a previous conv, or converted to a new conv.""" + nn_module = cast(torch.nn.Module, module) pipeline = PassPipeline[input_t]( - module, + nn_module, module.get_inputs(), quantize=False, 
ops_before_pass=module.ops_before_pass, ops_after_pass=module.ops_after_pass, - passes_with_exported_program=[FuseBatchnorm2DPass], + passes_with_exported_program=[FuseBatchNorm2dPass], ) pipeline.run() diff --git a/backends/arm/test/passes/test_fuse_constant_ops_pass.py b/backends/arm/test/passes/test_fuse_constant_ops_pass.py index 417ad7bff2a..deb017bf662 100644 --- a/backends/arm/test/passes/test_fuse_constant_ops_pass.py +++ b/backends/arm/test/passes/test_fuse_constant_ops_pass.py @@ -4,11 +4,11 @@ # LICENSE file in the root directory of this source tree. import operator -from typing import Tuple +from typing import cast, ClassVar, Dict, Protocol, Tuple import torch from executorch.backends.arm._passes.fuse_constant_ops_pass import ( - ComputeConstantOpsAOT, + ComputeConstantOpsAOTPass, FuseConstantArgsPass, ) from executorch.backends.arm.test import common @@ -22,16 +22,26 @@ input_t2 = Tuple[torch.Tensor, torch.Tensor] +class ModuleWithFuseAttrs(Protocol): + ops_before_pass: Dict[str, int] + ops_after_pass: Dict[str, int] + ops_not_after_pass: list[str] + + def get_inputs(self) -> input_t: ... + + class FuseParameter(torch.nn.Module): - ops_before_pass = { + ops_before_pass: ClassVar[Dict[str, int]] = { "executorch_exir_dialects_edge__ops_aten_full_default": 1, "executorch_exir_dialects_edge__ops_aten_view_copy_default": 2, "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, "executorch_exir_dialects_edge__ops_aten_addmm_default": 1, "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, } - ops_after_pass = {"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1} - ops_not_after_pass = [ + ops_after_pass: ClassVar[Dict[str, int]] = { + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1 + } + ops_not_after_pass: ClassVar[list[str]] = [ "executorch_exir_dialects_edge__ops_aten_full_default", "executorch_exir_dialects_edge__ops_aten_view_copy_default", "executorch_exir_dialects_edge__ops_aten_permute_copy_default", @@ -51,34 +61,38 @@ def __init__( bias=bias, ) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.fc(torch.ones(1)) + x class FuseBuffer(torch.nn.Module): - ops_before_pass = { + ops_before_pass: ClassVar[Dict[str, int]] = { "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1, } - ops_after_pass = { + ops_after_pass: ClassVar[Dict[str, int]] = { "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1, } - ops_not_after_pass = [ + ops_not_after_pass: ClassVar[list[str]] = [ "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default" ] - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: return (x + 1) * 2 class FuseLiftedTensor(torch.nn.Module): - ops_before_pass = { + ops_before_pass: ClassVar[Dict[str, int]] = { "executorch_exir_dialects_edge__ops_aten_select_copy_int": 1, "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, } - ops_after_pass = {"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1} - ops_not_after_pass = ["executorch_exir_dialects_edge__ops_aten_select_copy_int"] + ops_after_pass: ClassVar[Dict[str, int]] = { + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1 + } + ops_not_after_pass: ClassVar[list[str]] = [ + "executorch_exir_dialects_edge__ops_aten_select_copy_int" + ] def __init__( self, @@ -92,18 +106,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class CatConst(torch.nn.Module): - 
ops_before_pass = { + ops_before_pass: ClassVar[Dict[str, int]] = { "executorch_exir_dialects_edge__ops_aten_cat_default": 1, } - ops_after_pass = { + ops_after_pass: ClassVar[Dict[str, int]] = { "executorch_exir_dialects_edge__ops_aten_cat_default": 1, } - ops_not_after_pass = [] + ops_not_after_pass: ClassVar[list[str]] = [] def __init__(self): super().__init__() - def forward(self, a, b): + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: return torch.cat((a, b), dim=0) @@ -115,61 +129,70 @@ def __init__(self, in_out_features: int = 3, bias: bool = True): self.linear = torch.nn.Linear(in_out_features, in_out_features, bias=bias) self.example_input = torch.rand(in_out_features, in_out_features) - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: y = torch.full_like(x, 1.0) return self.linear(y) + x - def get_example_input(self): + def get_example_input(self) -> torch.Tensor: return self.example_input -modules = { - "fuse_parameter": FuseParameter(), - "fuse_buffer": FuseBuffer(), - "fuse_const_tensor": FuseLiftedTensor(), +modules: Dict[str, ModuleWithFuseAttrs] = { + "fuse_parameter": cast(ModuleWithFuseAttrs, FuseParameter()), + "fuse_buffer": cast(ModuleWithFuseAttrs, FuseBuffer()), + "fuse_const_tensor": cast(ModuleWithFuseAttrs, FuseLiftedTensor()), } -cat_module = { - "fuse_cat": CatConst(), +cat_module: Dict[str, ModuleWithFuseAttrs] = { + "fuse_cat": cast(ModuleWithFuseAttrs, CatConst()), } @common.parametrize("module", modules) -def test_fuse_const_ops_tosa_FP(module: torch.nn.Module): +def test_fuse_const_ops_tosa_FP(module: ModuleWithFuseAttrs) -> None: pipeline = PassPipeline[input_t]( - module=module, + module=cast(torch.nn.Module, module), test_data=(torch.rand(1),), quantize=False, ops_before_pass=module.ops_before_pass, ops_after_pass=module.ops_after_pass, ops_not_after_pass=module.ops_not_after_pass, - passes_with_exported_program=[ComputeConstantOpsAOT, FuseConstantArgsPass], + passes_with_exported_program=[ + ComputeConstantOpsAOTPass, + FuseConstantArgsPass, + ], ) pipeline.run() @common.parametrize("module", modules) -def test_fuse_const_ops_tosa_INT(module: torch.nn.Module): +def test_fuse_const_ops_tosa_INT(module: ModuleWithFuseAttrs) -> None: pipeline = PassPipeline[input_t]( - module, + cast(torch.nn.Module, module), (torch.rand(10, 10),), quantize=True, ops_before_pass=module.ops_before_pass, ops_after_pass=module.ops_after_pass, - passes_with_exported_program=[ComputeConstantOpsAOT, FuseConstantArgsPass], + passes_with_exported_program=[ + ComputeConstantOpsAOTPass, + FuseConstantArgsPass, + ], ) pipeline.run() @common.parametrize("module", cat_module) -def test_fuse_const_ops_tosa_BI_cat(module: torch.nn.Module): +def test_fuse_const_ops_tosa_BI_cat(module: ModuleWithFuseAttrs) -> None: pipeline = PassPipeline[input_t2]( - module, + cast(torch.nn.Module, module), (torch.rand(3), torch.rand(2)), quantize=True, ops_before_pass=module.ops_before_pass, ops_after_pass=module.ops_after_pass, - passes_with_exported_program=[ComputeConstantOpsAOT, FuseConstantArgsPass], + passes_with_exported_program=[ + ComputeConstantOpsAOTPass, + FuseConstantArgsPass, + ], ) pipeline.run() diff --git a/backends/arm/test/passes/test_fuse_duplicate_users_pass.py b/backends/arm/test/passes/test_fuse_duplicate_users_pass.py new file mode 100644 index 00000000000..ffe56e72691 --- /dev/null +++ b/backends/arm/test/passes/test_fuse_duplicate_users_pass.py @@ -0,0 +1,70 @@ +# Copyright 2025 Arm Limited and/or its affiliates. 
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict, Tuple
+
+import torch
+from executorch.backends.arm._passes import FuseDuplicateUsersPass
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
+
+input_t = Tuple[torch.Tensor]  # Input x
+
+
+class ModuleWithOps(torch.nn.Module):
+    ops_before_pass: Dict[str, int]
+    ops_after_pass: Dict[str, int]
+
+
+class FuseAvgPool(ModuleWithOps):
+    ops_before_pass = {
+        "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 3,
+    }
+    ops_after_pass = {"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1}
+
+    def __init__(self):
+        super().__init__()
+        self.avg = torch.nn.AvgPool2d(1)
+
+    def forward(self, x):
+        return self.avg(x) + self.avg(x) + self.avg(x)
+
+
+class FuseAvgPoolChain(ModuleWithOps):
+    ops_before_pass = {
+        "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 6,
+    }
+    ops_after_pass = {"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 2}
+
+    def __init__(self):
+        super().__init__()
+        self.avg = torch.nn.AvgPool2d(1)
+
+    def forward(self, x):
+        first = self.avg(self.avg(x))
+        second = self.avg(self.avg(x))
+        third = self.avg(self.avg(x))
+        return first + second + third
+
+
+modules: Dict[str, ModuleWithOps] = {
+    "fuse_avg_pool": FuseAvgPool(),
+    "fuse_avg_pool_chain": FuseAvgPoolChain(),
+}
+
+
+@common.parametrize("module", modules)
+def test_fuse_duplicate_ops_FP(module: ModuleWithOps):
+    pipeline = PassPipeline[input_t](
+        module=module,
+        test_data=(torch.ones(1, 1, 1, 1),),
+        quantize=False,
+        ops_before_pass=module.ops_before_pass,
+        ops_after_pass=module.ops_after_pass,
+        pass_list=[
+            FuseDuplicateUsersPass,
+        ],
+    )
+    pipeline.run()
diff --git a/backends/arm/test/passes/test_fuse_equal_placeholders_ops_pass.py b/backends/arm/test/passes/test_fuse_equal_placeholders_ops_pass.py
index f6e437ba034..22c4630d628 100644
--- a/backends/arm/test/passes/test_fuse_equal_placeholders_ops_pass.py
+++ b/backends/arm/test/passes/test_fuse_equal_placeholders_ops_pass.py
@@ -4,12 +4,14 @@
 # LICENSE file in the root directory of this source tree.
 from copy import deepcopy
-from typing import Tuple
+from typing import Callable, cast, ClassVar, Dict, Protocol, Tuple, TypeVar
 import torch
 from executorch.backends.arm._passes.fuse_equal_placeholders_pass import (
     FuseEqualPlaceholdersPass,
 )
+
+from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.test_pipeline import (
     PassPipeline,
     TosaPipelineFP,
@@ -18,10 +20,26 @@ input_t = Tuple[torch.Tensor]  # Input x
+class ModuleWithEqualPlaceholderAttrs(Protocol):
+    ops_before_pass: Dict[str, int]
+    ops_after_pass: Dict[str, int]
+    ops_not_after_pass: list[str]
+
+    def get_inputs(self) -> input_t: ...
+ + +T = TypeVar("T") +TestDecorator = Callable[[Callable[[T], None]], Callable[[T], None]] + + +def _typed_parametrize(test_data: Dict[str, T]) -> TestDecorator: + return cast(TestDecorator, common.parametrize("module", test_data)) + + class FuseWeightsConstants(torch.nn.Module): - ops_before_pass = {} - ops_after_pass = {} - ops_not_after_pass = [] + ops_before_pass: ClassVar[Dict[str, int]] = {} + ops_after_pass: ClassVar[Dict[str, int]] = {} + ops_not_after_pass: ClassVar[list[str]] = [] def __init__( self, @@ -33,18 +51,21 @@ def __init__( self.bias2 = deepcopy(self.bias1) self.bias3 = deepcopy(self.bias1) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return ( torch.conv1d(x, self.weights1, self.bias1) + torch.conv1d(x, self.weights2, self.bias2) + self.bias3 ) + def get_inputs(self) -> input_t: + return (torch.rand(1, 2, 8),) + class FuseWeightsStateDict(torch.nn.Module): - ops_before_pass = {} - ops_after_pass = {} - ops_not_after_pass = [] + ops_before_pass: ClassVar[Dict[str, int]] = {} + ops_after_pass: ClassVar[Dict[str, int]] = {} + ops_not_after_pass: ClassVar[list[str]] = [] def __init__( self, @@ -53,15 +74,18 @@ def __init__( self.fc1 = torch.nn.Linear(in_features=8, out_features=2, bias=True) self.fc2 = deepcopy(self.fc1) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.fc1(x) + self.fc2(x) + def get_inputs(self) -> input_t: + return (torch.rand(1, 2, 8),) + class NotFuseTensorWithDifferentType(torch.nn.Module): - ops_before_pass = {} - ops_after_pass = {} - ops_not_after_pass = [] + ops_before_pass: ClassVar[Dict[str, int]] = {} + ops_after_pass: ClassVar[Dict[str, int]] = {} + ops_not_after_pass: ClassVar[list[str]] = [] def forward(self, x: torch.Tensor, y: torch.Tensor): """ @@ -76,12 +100,20 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): return m, n -def test_fuse_equal_placeholders_constants_tosa_FP(): - module = FuseWeightsConstants() - data = (torch.rand(1, 2, 8),) +constants_modules: Dict[str, ModuleWithEqualPlaceholderAttrs] = { + "fuse_constants": cast(ModuleWithEqualPlaceholderAttrs, FuseWeightsConstants()), +} + +parametrize_constants = _typed_parametrize(constants_modules) + + +@parametrize_constants +def test_fuse_equal_placeholders_constants_tosa_FP( + module: ModuleWithEqualPlaceholderAttrs, +) -> None: pipeline = PassPipeline[input_t]( - module, - data, + cast(torch.nn.Module, module), + module.get_inputs(), quantize=False, ops_before_pass=module.ops_before_pass, ops_after_pass=module.ops_after_pass, @@ -97,12 +129,11 @@ def test_fuse_equal_placeholders_constants_tosa_FP(): assert "_common" in constant_keys[1], "FuseEqualPlaceholders constants failed" -def test_fuse_equal_placeholders_state_dict_tosa_FP(): +def test_fuse_equal_placeholders_state_dict_tosa_FP() -> None: module = FuseWeightsStateDict() - data = (torch.rand(1, 2, 8),) pipeline = PassPipeline[input_t]( module, - data, + module.get_inputs(), quantize=False, ops_before_pass=module.ops_before_pass, ops_after_pass=module.ops_after_pass, diff --git a/backends/arm/test/passes/test_fuse_view_copy.py b/backends/arm/test/passes/test_fuse_view_copy.py new file mode 100644 index 00000000000..7bf931349b6 --- /dev/null +++ b/backends/arm/test/passes/test_fuse_view_copy.py @@ -0,0 +1,82 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline +from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform + + +class FuseSequentialViews(torch.nn.Module): + def forward(self, x: torch.Tensor): + return x.view((1, 2, 3, 4)).view((2, 3, 4, 1)).view((2, 3, 4)) + + data = (torch.randn(2, 3, 1, 4),) + ops_before_pass = { + "executorch_exir_dialects_edge__ops_aten_view_copy": 3, + } + ops_after_pass = { + "executorch_exir_dialects_edge__ops_aten_view_copy": 1, + } + + +class FuseSequentialWithNoopsViews(torch.nn.Module): + def forward(self, x: torch.Tensor): + return ( + x.view((1, 2, 3, 4)) + .clone() + .view((2, 3, 4, 1)) + .to(dtype=torch.int32) + .view((2, 3, 4)) + .abs() + .reciprocal() + .sqrt() + .view((12, 2)) + ) + + data = (torch.randn(2, 3, 1, 4),) + ops_before_pass = { + "executorch_exir_dialects_edge__ops_aten_view_copy": 4, + } + ops_after_pass = { + "executorch_exir_dialects_edge__ops_aten_view_copy": 1, + } + + +class DontFuseBranchingViews(torch.nn.Module): + def forward(self, x: torch.Tensor): + x = x.view((1, 2, 3, 4)) + x1 = x.abs().view((2, 3, 4, 1)) + x2 = x.ceil().view((2, 3, 4, 1)) + return x1 + x2 + + data = (torch.randn(2, 3, 1, 4),) + ops_before_pass = { + "executorch_exir_dialects_edge__ops_aten_view_copy": 3, + } + ops_after_pass = { + "executorch_exir_dialects_edge__ops_aten_view_copy": 3, + } + + +tests = { + "fuse_sequential_views": FuseSequentialViews(), + "fuse_sequential_with_noops_views": FuseSequentialWithNoopsViews(), + "dont_fuse_branching_views": DontFuseBranchingViews(), +} + + +@common.parametrize("model", tests) +def test_fuse_view_copy(model): + pipeline = PassPipeline( + model, + model.data, + quantize=False, + ops_before_pass=model.ops_before_pass, + ops_after_pass=model.ops_after_pass, + pass_list=[FuseViewCopyTransform], + ) + pipeline.run() diff --git a/backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py b/backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py index efc1bebb610..2461a0e833a 100644 --- a/backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py +++ b/backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py @@ -8,9 +8,13 @@ import torch from executorch.backends.arm._passes import InsertInt32CastsAfterInt64PlaceholdersPass -from executorch.backends.arm.test.tester.test_pipeline import PassPipeline +from executorch.backends.arm.test.tester.test_pipeline import ( + PassPipeline, + TosaPipelineINT, +) -input_t = Tuple[torch.Tensor] # Input x +input_t = Tuple[torch.Tensor, torch.Tensor] # weights, indices +input_t3 = Tuple[torch.Tensor, torch.LongTensor, torch.Tensor] class Int64InputModel(torch.nn.Module): @@ -44,3 +48,67 @@ def test_int64_model_tosa_FP(): ) pipeline.pop_stage(-1) # Do not compare output pipeline.run() + + +class UpcastToInt64ForIndexCopyInplaceModel(torch.nn.Module): + aten_op = "torch.ops.aten.index_copy_.default" + + def forward(self, x: torch.Tensor, index: torch.LongTensor, y: torch.Tensor): + return x.index_copy_(0, index, y) + + def get_inputs(self) -> input_t3: + return ( + torch.zeros(5, 3), + torch.LongTensor([0, 4, 2]), + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float), + ) + + +def test_upcast_to_int64_for_index_copy_inplace_tosa_INT(): + module = UpcastToInt64ForIndexCopyInplaceModel() + pipeline = TosaPipelineINT[input_t3]( + module, + module.get_inputs(), + 
aten_op=module.aten_op, + ) + pipeline.pop_stage("check.quant_nodes") + pipeline.change_args( + "check_count.exir", + { + "torch.ops.higher_order.executorch_call_delegate": 0, + }, + ) + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() + + +class UpcastToInt64ForIndexCopyModel(torch.nn.Module): + aten_op = "torch.ops.aten.index_copy.default" + + def forward(self, x: torch.Tensor, index: torch.LongTensor, y: torch.Tensor): + return x.index_copy(0, index, y) + + def get_inputs(self) -> input_t3: + return ( + torch.zeros(5, 3), + torch.LongTensor([0, 4, 2]), + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float), + ) + + +def test_upcast_to_int64_for_index_copy_tosa_INT(): + module = UpcastToInt64ForIndexCopyModel() + pipeline = TosaPipelineINT[input_t3]( + module, + module.get_inputs(), + aten_op=module.aten_op, + ) + pipeline.pop_stage("check.quant_nodes") + pipeline.change_args( + "check_count.exir", + { + "torch.ops.higher_order.executorch_call_delegate": 0, + }, + ) + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() diff --git a/backends/arm/test/passes/test_insert_rescale_i32_pass.py b/backends/arm/test/passes/test_insert_rescale_i32_pass.py new file mode 100644 index 00000000000..4b5c16ab31a --- /dev/null +++ b/backends/arm/test/passes/test_insert_rescale_i32_pass.py @@ -0,0 +1,107 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +from executorch.backends.arm._passes import ( + FoldAndAnnotateQParamsPass, + InsertRescaleInt32Pass, +) +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline + + +class MultipleOpsModel(torch.nn.Module): + """A module containing ops that require INT32 inputs/outputs.""" + + input_t = Tuple[torch.Tensor, torch.Tensor] + + def forward(self, x, y): + a = x - y + b = x * a + c = torch.maximum(a, b) + d = torch.abs(b) + e = c + d + f = e > a + return f + + def get_inputs(self, dtype) -> input_t: + if dtype == torch.float32: + return (torch.rand(1, 3, 5, 6), torch.rand(1, 3, 5, 6)) + elif dtype == torch.int32: + return ( + torch.randint(3, 5, (3,), dtype=torch.int32), + torch.randint(3, 5, (3,), dtype=torch.int32), + ) + else: + raise ValueError("Not a valid input dtype for model") + + def get_num_expected_rescales(self): + # "number of op nodes with i8 output" + "number of i8 node inputs" + return 5 + 11 + + +class SumModel(torch.nn.Module): + input_t = Tuple[torch.Tensor] + + def forward(self, x): + a = torch.sum(x, 2, keepdim=True) # (1, 2, 1, 4) + b = torch.sum(a, [1, 3], keepdim=True) # (1, 1, 1, 1) + c = torch.sum(b, [0, 2], keepdim=False) # (1, 1) + return c + + def get_inputs(self, dtype) -> input_t: + if dtype == torch.float32: + return (torch.rand(1, 2, 3, 4),) + elif dtype == torch.int32: + return (torch.randint(0, 10, (1, 2, 3, 4), dtype=torch.int32),) + else: + raise ValueError("Not a valid input dtype for model") + + def get_num_expected_rescales(self): + # Two RESCALE nodes per SUM node + return 6 + + +def _test_model_with_f32_data(model): + ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"} + ops_after = { + "executorch_exir_dialects_backend__ops_tosa_RESCALE_default": model.get_num_expected_rescales(), + } + pipeline = PassPipeline[model.input_t]( + model, + model.get_inputs(torch.float32), + quantize=True, + ops_not_before_pass=ops_not_before, + 
ops_after_pass=ops_after, + pass_list=[FoldAndAnnotateQParamsPass, InsertRescaleInt32Pass], + ) + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() + + +def test_insert_rescales_sum_model(): + _test_model_with_f32_data(SumModel()) + + +def test_insert_rescales_multiple_ops_model(): + _test_model_with_f32_data(MultipleOpsModel()) + + +def test_dont_insert_rescales(): + module = MultipleOpsModel() + input_t = Tuple[torch.Tensor, torch.Tensor] + ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"} + # All inputs are already i32. Rescales should not be added. + ops_not_after = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"} + pipeline = PassPipeline[input_t]( + module, + module.get_inputs(torch.int32), + ops_not_before_pass=ops_not_before, + ops_not_after_pass=ops_not_after, + pass_list=[FoldAndAnnotateQParamsPass, InsertRescaleInt32Pass], + ) + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() diff --git a/backends/arm/test/passes/test_insert_table_ops_pass.py b/backends/arm/test/passes/test_insert_table_ops_pass.py index 5e695c237a0..00ff0c96de1 100644 --- a/backends/arm/test/passes/test_insert_table_ops_pass.py +++ b/backends/arm/test/passes/test_insert_table_ops_pass.py @@ -3,8 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - -from typing import Tuple +from typing import ClassVar, Dict, Tuple import torch from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( @@ -18,16 +17,16 @@ class Sigmoid(torch.nn.Module): - test_data = { + test_data: ClassVar[Dict[str, input_t]] = { "rand": (torch.rand(4),), } - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: return x.sigmoid() @common.parametrize("test_data", Sigmoid.test_data) -def test_insert_table_tosa_INT(test_data: input_t): +def test_insert_table_tosa_INT(test_data: input_t) -> None: module = Sigmoid() pipeline = PassPipeline[input_t]( module, diff --git a/backends/arm/test/passes/test_int32_cast_embedding_pass.py b/backends/arm/test/passes/test_int32_cast_embedding_pass.py index 7adca527d75..30e84fadde3 100644 --- a/backends/arm/test/passes/test_int32_cast_embedding_pass.py +++ b/backends/arm/test/passes/test_int32_cast_embedding_pass.py @@ -10,12 +10,12 @@ from executorch.backends.arm.test.tester.test_pipeline import PassPipeline -input_t = Tuple[torch.Tensor] # Input x +input_t = Tuple[torch.Tensor, torch.Tensor] class Int32Embedding(torch.nn.Module): - def forward(self, weights: torch.Tensor, indices: torch.Tensor): + def forward(self, weights: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: return torch.embedding(weights, indices) def get_inputs(self) -> input_t: diff --git a/backends/arm/test/passes/test_ioquantization_pass.py b/backends/arm/test/passes/test_ioquantization_pass.py index da3b81aa096..fc57e8fa5b0 100644 --- a/backends/arm/test/passes/test_ioquantization_pass.py +++ b/backends/arm/test/passes/test_ioquantization_pass.py @@ -14,7 +14,7 @@ from executorch.exir.passes.quantize_io_pass import QuantizeInputs, QuantizeOutputs -input_t = Tuple[torch.Tensor] +input_t = Tuple[torch.Tensor, torch.Tensor] class SimpleModel(torch.nn.Module): diff --git a/backends/arm/test/passes/test_promote_bool_operands_pass.py b/backends/arm/test/passes/test_promote_bool_operands_pass.py new file mode 100644 index 00000000000..48c9778a75c --- /dev/null +++ 
b/backends/arm/test/passes/test_promote_bool_operands_pass.py @@ -0,0 +1,103 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import ClassVar, Dict, Tuple + +import torch +from executorch.backends.arm._passes import PromoteBoolOperandsPass + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline +from executorch.backends.test.harness.stages import StageType +from executorch.exir.dialects._ops import ops as exir_ops + +tensor_pair_t = Tuple[torch.Tensor, torch.Tensor] + + +def _collect_cast_dtypes(pipeline: PassPipeline[tensor_pair_t]) -> list[torch.dtype]: + exported_program = pipeline.tester.get_artifact( + StageType.RUN_PASSES + ).exported_program() + graph_module = exported_program.graph_module + cast_dtypes: list[torch.dtype] = [] + for node in graph_module.graph.nodes: + if ( + node.op == "call_function" + and node.target == exir_ops.edge.dim_order_ops._to_dim_order_copy.default + and "dtype" in node.kwargs + ): + cast_dtypes.append(node.kwargs["dtype"]) + return cast_dtypes + + +class BoolBitwiseAndModule(torch.nn.Module): + test_data: ClassVar[Dict[str, tensor_pair_t]] = { + "bool_tensors": ( + torch.tensor([[True, False], [False, True]], dtype=torch.bool), + torch.tensor([[False, True], [True, False]], dtype=torch.bool), + ) + } + + def forward(self, lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor: + return torch.bitwise_and(lhs, rhs) + + +class MixedMulModule(torch.nn.Module): + test_data: ClassVar[Dict[str, tensor_pair_t]] = { + "mixed_tensors": ( + torch.tensor([True, False, True, False], dtype=torch.bool), + torch.tensor([1, 2, 3, 4], dtype=torch.int32), + ) + } + + def forward(self, lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor: + return torch.mul(lhs, rhs) + + +@common.parametrize("test_data", BoolBitwiseAndModule.test_data) +def test_promote_bool_operands_all_bool(test_data: tensor_pair_t) -> None: + module = BoolBitwiseAndModule() + ops_before_pass = { + "executorch_exir_dialects_edge__ops_aten_bitwise_and_Tensor": 1, + } + ops_after_pass = { + "executorch_exir_dialects_edge__ops_aten_bitwise_and_Tensor": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 3, + } + pipeline = PassPipeline[tensor_pair_t]( + module, + test_data, + quantize=False, + ops_before_pass=ops_before_pass, + ops_after_pass=ops_after_pass, + pass_list=[PromoteBoolOperandsPass], + ) + pipeline.run() + cast_dtypes = _collect_cast_dtypes(pipeline) + assert cast_dtypes.count(torch.int8) == 2 + assert cast_dtypes.count(torch.bool) == 1 + + +@common.parametrize("test_data", MixedMulModule.test_data) +def test_promote_bool_operands_mixed_types(test_data: tensor_pair_t) -> None: + module = MixedMulModule() + ops_before_pass = { + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1, + } + ops_after_pass = { + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1, + } + pipeline = PassPipeline[tensor_pair_t]( + module, + test_data, + quantize=False, + ops_before_pass=ops_before_pass, + ops_after_pass=ops_after_pass, + pass_list=[PromoteBoolOperandsPass], + ) + pipeline.run() + cast_dtypes = _collect_cast_dtypes(pipeline) + assert cast_dtypes.count(torch.int32) == 1 diff --git a/backends/arm/test/passes/test_rescale_pass.py 
b/backends/arm/test/passes/test_rescale_pass.py index 0959a0eaa25..1ab4f5b6a03 100644 --- a/backends/arm/test/passes/test_rescale_pass.py +++ b/backends/arm/test/passes/test_rescale_pass.py @@ -31,21 +31,21 @@ def test_rescale_op(): ( torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int8), torch.int32, - 0.2, + [0.2], 2, 0, ), ( torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int32), torch.int8, - 0.2, + [0.2], 0, -128, ), ( torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int8), torch.int8, - 0.8, + [0.8], 10, 127, ), @@ -71,14 +71,14 @@ def test_nonzero_zp_for_int32(): ( torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int8), torch.int32, - 0.2, + [0.2], 2, # Should be 0, expect error 1, ), ( torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int32), torch.int8, - 0.2, + [0.2], 1, 1, # Should be 0, expect error ), @@ -107,14 +107,14 @@ def test_zp_outside_range(): ( torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int8), torch.int32, - 0.2, + [0.2], 128, # Should be <128, expect error 0, ), ( torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int32), torch.int8, - 0.2, + [0.2], 0, -129, # Should be >-129m expect error ), @@ -172,16 +172,9 @@ def test_quantized_rescale_tosa_bi(test_data: tuple[torch.Tensor, torch.Tensor]) pipeline.run() -u55_xfails = { - "ones": "MLBEDSW-11032: ILLEGAL_OFM_BASE error: Base addresses must be aligned to brick depth on u55.", - "randn_ones": "MLBEDSW-11032: ILLEGAL_OFM_BASE error: Base addresses must be aligned to brick depth on u55.", - "randn_large": "MLBEDSW-11032: ILLEGAL_OFM_BASE error: Base addresses must be aligned to brick depth on u55.", -} - - -@common.parametrize("test_data", RescaleNetwork.test_data, xfails=u55_xfails) +@common.parametrize("test_data", RescaleNetwork.test_data) @common.XfailIfNoCorstone300 -def test_quantized_rescale_u55(test_data: tuple[torch.Tensor, torch.Tensor]): +def test_quantized_rescale_u55(test_data: input_t): """Tests a model with many ops that requires rescales. As more ops are quantized to int32 and need the InsertRescalesPass, make sure that they play nicely together.""" module = RescaleNetwork() @@ -190,14 +183,13 @@ def test_quantized_rescale_u55(test_data: tuple[torch.Tensor, torch.Tensor]): test_data=test_data, aten_ops=[], exir_ops=[], - run_on_fvp=True, ) pipeline.run() @common.parametrize("test_data", RescaleNetwork.test_data) @common.XfailIfNoCorstone320 -def test_quantized_rescale_u85(test_data: tuple[torch.Tensor, torch.Tensor]): +def test_quantized_rescale_u85(test_data: input_t): """Tests a model with many ops that requires rescales. As more ops are quantized to int32 and need the InsertRescalesPass, make sure that they play nicely together.""" module = RescaleNetwork() @@ -206,6 +198,5 @@ def test_quantized_rescale_u85(test_data: tuple[torch.Tensor, torch.Tensor]): test_data=test_data, aten_ops=[], exir_ops=[], - run_on_fvp=True, ) pipeline.run() diff --git a/backends/arm/test/passes/test_to_tosa_memory_format.py b/backends/arm/test/passes/test_to_tosa_memory_format.py index 1e9b8ffc63d..486a906a0ff 100644 --- a/backends/arm/test/passes/test_to_tosa_memory_format.py +++ b/backends/arm/test/passes/test_to_tosa_memory_format.py @@ -3,10 +3,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Tuple +from typing import cast, Dict, List, Protocol, Tuple import torch -from executorch.backends.arm._passes import ToTosaMemoryFormatPass +from executorch.backends.arm._passes import ( + AnnotateOutputDimOrderPass, + ToTosaMemoryFormatPass, +) from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( @@ -18,19 +21,30 @@ input_t = Tuple[torch.Tensor] # Input x +class ModuleMetadata(Protocol): + ops_before_pass: Dict[str, int] + ops_after_pass: Dict[str, int] + ops_not_after_pass: List[str] + + def get_inputs(self) -> input_t: ... + + class NoNHWC(torch.nn.Module): """ Test-module with no ops requiring NHWC mermory format. """ - ops_after_pass = {"executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 2} - ops_not_after_pass = [] + ops_before_pass: Dict[str, int] = {} + ops_after_pass: Dict[str, int] = { + "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 2 + } + ops_not_after_pass: List[str] = [] - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + x return x - def get_inputs(self): + def get_inputs(self) -> input_t: return (torch.rand(1, 2, 2, 2),) @@ -39,8 +53,11 @@ class ParallelClusters(torch.nn.Module): Test-module with multiple parallel clusters of nodes requiring different memory formats. """ - ops_after_pass = {"executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 2} - ops_not_after_pass = [] + ops_before_pass: Dict[str, int] = {} + ops_after_pass: Dict[str, int] = { + "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 2 + } + ops_not_after_pass: List[str] = [] def __init__(self): super().__init__() @@ -53,14 +70,14 @@ def __init__(self): self.maxpool = torch.nn.MaxPool2d(1, 1) self.avgpool = torch.nn.AvgPool2d(1, 1) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x1 = self.conv(x) x2 = self.maxpool(x) x3 = self.avgpool(x) x4 = x * x return x1 + x2 + x3 + x4 - def get_inputs(self): + def get_inputs(self) -> input_t: return (torch.rand(1, 2, 2, 2),) @@ -69,9 +86,11 @@ class SerialClusters(torch.nn.Module): Test-module with multiple serial clusters of nodes requring different memory formats. """ - ops_before_pass = {} - ops_after_pass = {"executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 4} - ops_not_after_pass = [] + ops_before_pass: Dict[str, int] = {} + ops_after_pass: Dict[str, int] = { + "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 4 + } + ops_not_after_pass: List[str] = [] def __init__(self): super().__init__() @@ -87,7 +106,7 @@ def __init__(self): bias=True, ) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv(x) x = x * x x = self.conv(x) @@ -97,7 +116,7 @@ def forward(self, x): x = self.conv(x) return x - def get_inputs(self): + def get_inputs(self) -> input_t: return (torch.rand(2, 2, 2, 2),) @@ -106,17 +125,17 @@ class Reshapes(torch.nn.Module): Test-module with different configurations of views requiring different memory formats. 
""" - ops_before_pass = {} - ops_after_pass = { + ops_before_pass: Dict[str, int] = {} + ops_after_pass: Dict[str, int] = { "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 16 } - ops_not_after_pass = [] + ops_not_after_pass: List[str] = [] def __init__(self): super().__init__() self.maxpool = torch.nn.MaxPool2d(1, 1) # Use maxpool to force NHWC format - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.maxpool(x) x = x.view((2, 2, 4, 16, 1)) # N-C-HW-invariant intact, no transposes needed @@ -156,11 +175,11 @@ def forward(self, x): return x - def get_inputs(self): + def get_inputs(self) -> input_t: return (torch.rand(4, 4, 4, 4),) -modules = { +modules: Dict[str, ModuleMetadata] = { "no_nhwc": NoNHWC(), "parallel_clusters": ParallelClusters(), "serial_clusters": SerialClusters(), @@ -169,14 +188,15 @@ def get_inputs(self): @common.parametrize("module", modules) -def test_to_tosa_memory_format_tosa_INT(module): +def test_to_tosa_memory_format_tosa_INT(module: ModuleMetadata) -> None: # We cannot check op counts after a specific pass with the full pipeline + module_nn = cast(torch.nn.Module, module) pipeline = PassPipeline[input_t]( - module, + module_nn, module.get_inputs(), ops_after_pass=module.ops_after_pass, ops_not_after_pass=module.ops_not_after_pass, - pass_list=[RemoveGetItemPass], + pass_list=[RemoveGetItemPass, AnnotateOutputDimOrderPass], passes_with_exported_program=[ToTosaMemoryFormatPass], ) pipeline.pop_stage( @@ -186,7 +206,8 @@ def test_to_tosa_memory_format_tosa_INT(module): @common.parametrize("module", modules) -def test_to_tosa_memory_format_tosa_INT_functional(module): +def test_to_tosa_memory_format_tosa_INT_functional(module: ModuleMetadata) -> None: # Also run the actual pass pipeline to ensure functional correctness. - pipeline = TosaPipelineINT[input_t](module, module.get_inputs(), []) + module_nn = cast(torch.nn.Module, module) + pipeline = TosaPipelineINT[input_t](module_nn, module.get_inputs(), []) pipeline.run() diff --git a/backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py b/backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py index fc405e21f2a..f6ff8b8c0bb 100644 --- a/backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py +++ b/backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py @@ -3,16 +3,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Dict, Tuple +from typing import Dict, List, Tuple import torch from executorch.backends.arm._passes import UnsqueezeBeforeRepeatPass from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import PassPipeline -input_t = Tuple[ - torch.Tensor, Dict[str, int], list[str] -] # Input x, ops_after_pass, ops_not_after_pass +pipeline_input_t = Tuple[torch.Tensor, ...] +test_case_t = Tuple[ + pipeline_input_t, + Dict[str, int], + List[str], +] class Repeat(torch.nn.Module): @@ -20,10 +23,10 @@ class Repeat(torch.nn.Module): Basic repeat model. 
""" - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: return x.repeat(2, 2, 2, 2) - test_data: Dict[str, input_t] = { + test_data: Dict[str, test_case_t] = { "insert_view": ( (torch.rand((2, 3, 4)),), {"aten_repeat_default": 3, "aten_view_copy_default": 4}, @@ -38,14 +41,14 @@ def forward(self, x: torch.Tensor): @common.parametrize("test_data", Repeat.test_data) -def test_unsqueeze_before_repeat_tosa_FP(test_data: input_t): +def test_unsqueeze_before_repeat_tosa_FP(test_data: test_case_t): """ When rank(input) != number of repeated dimensions (=4 in Repeat module), insert view. """ module = Repeat() data, ops_after_pass, ops_not_after_pass = test_data - pipeline = PassPipeline( + pipeline = PassPipeline[pipeline_input_t]( module, data, quantize=False, diff --git a/backends/arm/test/quantizer/test_conv_relu_fusing.py b/backends/arm/test/quantizer/test_conv_relu_fusing.py new file mode 100644 index 00000000000..ccc6c114efd --- /dev/null +++ b/backends/arm/test/quantizer/test_conv_relu_fusing.py @@ -0,0 +1,118 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +from executorch.backends.arm.quantizer import ( + get_symmetric_a16w8_quantization_config, + get_symmetric_quantization_config, + TOSAQuantizer, +) +from executorch.backends.arm.quantizer.quantization_config import ( + QuantizationConfig, + QuantizationSpec, +) +from executorch.backends.arm.test.tester.test_pipeline import QuantizationPipeline +from executorch.backends.arm.tosa import TosaSpecification + + +def get_symmetric_a8w8_quantization_config(): + affine_quant_config = get_symmetric_quantization_config() + output_activation = QuantizationSpec( + dtype=torch.int8, + observer_or_fake_quant_ctr=affine_quant_config.get_output_act_qspec().observer_or_fake_quant_ctr, + quant_min=-127, + quant_max=127, + qscheme=torch.per_tensor_symmetric, + ch_axis=None, + is_dynamic=False, + ) + input_activation = output_activation + symmetric_quant_config = QuantizationConfig( + input_activation=input_activation, + output_activation=output_activation, + weight=affine_quant_config.get_weight_qspec(), + bias=None, + ) + return symmetric_quant_config + + +class ConvBNRelu(torch.nn.Module): + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=3, + out_channels=4, + kernel_size=2, + ) + self.bn = torch.nn.BatchNorm2d(num_features=4) + self.relu = torch.nn.ReLU() + + def forward(self, x): + conv = self.conv(x) + bn = self.bn(conv) + relu = self.relu(bn) + return relu + + def get_example_inputs(self): + return (torch.randn(1, 3, 8, 8),) + + +def test_conv_relu_fusing_8a8w_affine(): + tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT") + quantizer = TOSAQuantizer(tosa_spec) + quant_config = get_symmetric_quantization_config() + quantizer.set_global(quant_config) + expected_annotations = { + "aten.conv2d.default": {None: 1}, + "aten.relu.default": {quant_config.get_output_act_qspec(): 1}, + } + pipeline = QuantizationPipeline[Tuple[torch.Tensor]]( + ConvBNRelu(), + ConvBNRelu().get_example_inputs(), + quantizer=quantizer, + qspecs=expected_annotations, + ) + pipeline.run() + + +def test_conv_relu_fusing_8a8w_symmetric(): + tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT") + quantizer = TOSAQuantizer(tosa_spec) + symmetric_quant_config = 
get_symmetric_a8w8_quantization_config() + + quantizer.set_global(symmetric_quant_config) + expected_annotations = { + "aten.conv2d.default": {symmetric_quant_config.get_output_act_qspec(): 1}, + "aten.relu.default": {symmetric_quant_config.get_output_act_qspec(): 1}, + } + pipeline = QuantizationPipeline[Tuple[torch.Tensor]]( + ConvBNRelu(), + ConvBNRelu().get_example_inputs(), + quantizer=quantizer, + qspecs=expected_annotations, + ) + pipeline.run() + + +def test_conv_relu_fusing_16a8w_symmetric(): + tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT+int16") + quantizer = TOSAQuantizer(tosa_spec) + quant_config = get_symmetric_a16w8_quantization_config() + + quantizer.set_global(quant_config) + expected_annotations = { + "aten.conv2d.default": {quant_config.get_output_act_qspec(): 1}, + "aten.relu.default": {quant_config.get_output_act_qspec(): 1}, + } + pipeline = QuantizationPipeline[Tuple[torch.Tensor]]( + ConvBNRelu(), + ConvBNRelu().get_example_inputs(), + quantizer=quantizer, + qspecs=expected_annotations, + ) + pipeline.run() diff --git a/backends/arm/test/quantizer/test_generic_annotater.py b/backends/arm/test/quantizer/test_generic_annotater.py index 4eaf1c205cc..4b43b6c9e50 100644 --- a/backends/arm/test/quantizer/test_generic_annotater.py +++ b/backends/arm/test/quantizer/test_generic_annotater.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import itertools -from typing import Tuple +from typing import Any, Callable, Tuple import torch from executorch.backends.arm.quantizer import is_annotated @@ -18,20 +18,25 @@ class SingleOpModel(torch.nn.Module): - def __init__(self, op, example_input, **op_kwargs) -> None: + def __init__( + self, + op: Callable[..., torch.Tensor], + example_input: Tuple[Any, ...], + **op_kwargs: Any, + ) -> None: super().__init__() - self.op = op - self._example_input = example_input - self.op_kwargs = op_kwargs + self.op: Callable[..., torch.Tensor] = op + self._example_input: Tuple[Any, ...] = example_input + self.op_kwargs: dict[str, Any] = dict(op_kwargs) - def forward(self, x): + def forward(self, x: Any) -> torch.Tensor: return self.op(x, **self.op_kwargs) - def example_inputs(self): + def example_inputs(self) -> Tuple[Any, ...]: return self._example_input -def check_annotation(model): +def check_annotation(model: SingleOpModel) -> None: pipeline = TosaPipelineINT[input_t1](model, model.example_inputs(), [], []) pipeline.pop_stage("check_count.exir") pipeline.pop_stage("run_method_and_compare_outputs") diff --git a/backends/arm/test/quantizer/test_selective_quantization.py b/backends/arm/test/quantizer/test_selective_quantization.py new file mode 100644 index 00000000000..5cdc9d18812 --- /dev/null +++ b/backends/arm/test/quantizer/test_selective_quantization.py @@ -0,0 +1,212 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
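
The selective-quantization tests added in this new file exercise TOSAQuantizer's per-module hooks (`set_module_name` / `set_module_type`), where passing `None` for a module leaves it in float. A minimal sketch of that pattern, using the same entry points as the tests below but a toy model made up purely for illustration:

```python
# Minimal sketch of the selective-quantization pattern the tests below exercise.
# TOSAQuantizer / TosaSpecification / get_symmetric_quantization_config are the
# entry points used in this diff; TinyModel is an illustrative stand-in only.
import torch
from executorch.backends.arm.quantizer import (
    get_symmetric_quantization_config,
    TOSAQuantizer,
)
from executorch.backends.arm.tosa import TosaSpecification


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 8)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.fc(x))


quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
quantizer.set_global(get_symmetric_quantization_config())
# Passing None for a named submodule keeps it in float; the AddSoftmaxAdd cases
# below do exactly this for their add_0/add_1 submodules.
quantizer.set_module_name("fc", None)
```

The tests then use QuantizationPipeline to assert that ops inside the excluded modules carry no output qspec annotation ({None: count}), while globally configured ops do.
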
+ + +from typing import Dict + +import torch +from executorch.backends.arm.quantizer import ( + get_symmetric_a16w8_quantization_config, + get_symmetric_quantization_config, + TOSAQuantizer, +) +from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import QuantizationPipeline +from executorch.backends.arm.tosa import TosaSpecification +from torchvision import models, transforms # type: ignore[import-untyped] +from torchvision.ops.misc import Conv2dNormActivation # type: ignore[import-untyped] + + +def get_quantizer(): + tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT") + quantizer = TOSAQuantizer(tosa_spec) + quantizer.set_global(get_symmetric_quantization_config()) + return quantizer + + +def get_selective_quantizer_by_module( + module_types: Dict[torch.nn.Module, QuantizationConfig] +): + quantizer = get_quantizer() + quantizer.set_global(get_symmetric_quantization_config()) + for module_type, config in module_types.items(): + quantizer.set_module_type(module_type, config) + + return quantizer + + +def get_selective_quantizer_by_module_name(module_names: Dict[str, QuantizationConfig]): + quantizer = get_quantizer() + quantizer.set_global(get_symmetric_quantization_config()) + for module_name, config in module_names.items(): + quantizer.set_module_name(module_name, config) + + return quantizer + + +class Add(torch.nn.Module): + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + + +class AddSoftmaxAdd(torch.nn.Module): + module_names = {"add_0": None, "add_1": None} + module_types = { + Add: None, + } + quantized_aten_targets = {"aten.relu.default": 1} + non_quantized_aten_targets = {"aten.add.Tensor": 2} + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.softmax = torch.nn.Softmax(dim=-1) + self.relu = torch.nn.ReLU() + self.add_0 = Add() + self.add_1 = Add() + + def get_inputs(self): + return (torch.randn(1, 10), torch.randn(1, 10)) + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + z = self.add_0(x, y) + z = self.relu(z) + z = self.softmax(z) + return self.add_1(z, y) + + +test_models = { + "add_softmax_add": AddSoftmaxAdd, +} + + +@common.parametrize("model", test_models) +def test_selective_quant_module_name_tosa_INT(model): + model = model() + inputs = model.get_inputs() + quantzed_aten_targets = model.quantized_aten_targets + non_quantized_aten_targets = model.non_quantized_aten_targets + quantization_annotations = {} + for target, count in quantzed_aten_targets.items(): + quantization_annotations[target] = { + get_symmetric_quantization_config().output_activation: count + } + for target, count in non_quantized_aten_targets.items(): + quantization_annotations[target] = {None: count} + + pipeline = QuantizationPipeline[tuple[torch.Tensor, torch.Tensor]]( + model, + inputs, + quantizer=get_selective_quantizer_by_module_name(model.module_names), + qspecs=quantization_annotations, + ) + + pipeline.run() + + +@common.parametrize("model", test_models) +def test_selective_quant_module_type_tosa_INT(model): + model = model() + inputs = model.get_inputs() + quantzed_aten_targets = model.quantized_aten_targets + non_quantized_aten_targets = model.non_quantized_aten_targets + quantization_annotations = {} + for target, count in quantzed_aten_targets.items(): + quantization_annotations[target] = { + get_symmetric_quantization_config().output_activation: 
count + } + for target, count in non_quantized_aten_targets.items(): + quantization_annotations[target] = {None: count} + + pipeline = QuantizationPipeline[tuple[torch.Tensor, torch.Tensor]]( + model, + inputs, + quantizer=get_selective_quantizer_by_module(model.module_types), + qspecs=quantization_annotations, + ) + + pipeline.run() + + +mv3 = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights) +mv3.eval() +normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + +def test_mv3_selective_quant_int16(): + model = mv3 + inputs = (normalize(torch.randn(1, 3, 224, 224)),) + + a16w8_config = get_symmetric_a16w8_quantization_config() + quantization_annotations = { + "aten.conv2d.default": { + a16w8_config.output_activation: 34, + }, + "aten.hardswish_.default": { + a16w8_config.output_activation: 18, + }, + "aten.relu_.default": { + a16w8_config.output_activation: 5, + }, + } + + pipeline = QuantizationPipeline[tuple[torch.Tensor]]( + model, + inputs, + quantizer=get_selective_quantizer_by_module( + { + Conv2dNormActivation: a16w8_config, + } + ), + qspecs=quantization_annotations, + ) + + pipeline.run() + + +def test_mv3_selective_quant_float32(): + model = mv3 + inputs = (normalize(torch.randn(1, 3, 224, 224)),) + + quantization_annotations = { + "aten.adaptive_avg_pool2d.default": { + None: 1, + }, + } + + pipeline = QuantizationPipeline[tuple[torch.Tensor]]( + model, + inputs, + quantizer=get_selective_quantizer_by_module_name( + { + "features.11.block.2.avgpool": None, + } + ), + qspecs=quantization_annotations, + ) + + pipeline.run() + + +def test_mv3_io_quant(): + model = mv3 + inputs = (normalize(torch.randn(1, 3, 224, 224)),) + + quantizer = get_quantizer() + # Workaround to disable quantization for all modules + quantizer.set_module_type(torch.nn.Module, None) + # Only quantize IO + quantizer.set_io(get_symmetric_quantization_config()) + + pipeline = QuantizationPipeline[tuple[torch.Tensor]]( + model, + inputs, + quantizer=quantizer, + input_qspecs={get_symmetric_quantization_config().input_activation: 1}, + output_qspecs={get_symmetric_quantization_config().output_activation: 1}, + ) + + pipeline.run() diff --git a/backends/arm/test/quantizer/test_set_module_name.py b/backends/arm/test/quantizer/test_set_module_name.py new file mode 100644 index 00000000000..56131a83e86 --- /dev/null +++ b/backends/arm/test/quantizer/test_set_module_name.py @@ -0,0 +1,158 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
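
The validators added in this new test_set_module_name.py unpack `node.args` of the decomposed Q/DQ nodes positionally. For readability, a small self-contained illustration of the layout those unpacks assume; the tuples are stand-ins for `node.args`, and the exact argument order of the quantized_decomposed ops should be verified against your torch/torchao version:

```python
# Stand-in tuples showing the positional layout validate_per_tensor_quant and
# validate_per_channel_quant (below) assume for node.args:
#   per-tensor:  (input, scale, zero_point, quant_min, quant_max, dtype)
#   per-channel: (input, scales, zero_points, axis, quant_min, quant_max, dtype)
import torch

per_tensor_args = ("conv2d_input", 0.02, 0, -128, 127, torch.int8)
_, _, zero_point, qmin, qmax, dtype = per_tensor_args
assert (zero_point, qmin, qmax, dtype) == (0, -128, 127, torch.int8)

per_channel_args = ("conv2d_weight", [0.01], [0], 0, -127, 127, torch.int8)
_, _, _, channel_axis, qmin, qmax, dtype = per_channel_args
assert channel_axis == 0
```
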
+ +import torch +from executorch.backends.arm.quantizer import ( + get_symmetric_a16w8_quantization_config, + get_symmetric_quantization_config, + is_annotated, + QuantizationConfig, + TOSAQuantizer, +) +from executorch.backends.arm.quantizer.quantization_config import QuantizationSpec +from executorch.backends.arm.tosa import TosaSpecification +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e + +DQ_PER_CHANNEL = torch.ops.quantized_decomposed.dequantize_per_channel.default +DQ_PER_TENSOR = torch.ops.quantized_decomposed.dequantize_per_tensor.default +Q_PER_TENSOR = torch.ops.quantized_decomposed.quantize_per_tensor.default + + +class ConvModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv0 = torch.nn.Conv2d( + 3, + 16, + kernel_size=4, + ) + self.conv1 = torch.nn.Conv2d(16, 32, kernel_size=3, bias=False) + self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3) + + def forward(self, x): + x = self.conv0(x) + x = torch.sigmoid(x) + x = self.conv1(x) + x = torch.tanh(x) + x = self.conv2(x) + return x + + +test_inputs = (torch.randn(1, 3, 64, 64),) + + +def validate_per_tensor_quant(node: torch.fx.Node, qspec: QuantizationSpec): + _, _, zero_point, qmin, qmax, dtype = node.args + if qspec.qscheme == torch.per_tensor_symmetric: + assert ( + zero_point == 0 + ), f"Zero point {zero_point} is not zero for symmetric quantization" + assert ( + qmin == qspec.quant_min + ), f"Quant min {qmin} does not match expected {qspec.quant_min}" + assert ( + qmax == qspec.quant_max + ), f"Quant max {qmax} does not match expected {qspec.quant_max}" + assert dtype == qspec.dtype, f"Dtype {dtype} does not match expected {qspec.dtype}" + + +def validate_per_channel_quant(node: torch.fx.Node, qspec: QuantizationSpec): + _, _, _, channel_axis, qmin, qmax, dtype = node.args + assert ( + channel_axis == qspec.ch_axis + ), f"Channel axis {channel_axis} does not match expected {qspec.ch_axis}" + assert ( + qmin == qspec.quant_min + ), f"Quant min {qmin} does not match expected {qspec.quant_min}" + assert ( + qmax == qspec.quant_max + ), f"Quant max {qmax} does not match expected {qspec.quant_max}" + assert dtype == qspec.dtype, f"Dtype {dtype} does not match expected {qspec.dtype}" + + +def validate_input(input_node: torch.fx.Node, qspec: QuantizationSpec | None): + if qspec is None: + return + + per_channel = qspec.qscheme == torch.per_channel_symmetric + expected_dequant_op = DQ_PER_CHANNEL if per_channel else DQ_PER_TENSOR + assert ( + input_node.target == expected_dequant_op + ), f"Input node {input_node} is not quantized as expected" + if per_channel: + validate_per_channel_quant(input_node, qspec) + else: + validate_per_tensor_quant(input_node, qspec) + + +def validate_output(node: torch.fx.Node, qspec: QuantizationSpec | None): + if qspec is None: + return + users = list(node.users) + assert len(users) == 1, f"Node {node} should have exactly one user" + assert ( + users[0].target == Q_PER_TENSOR + ), f"Output node {users[0]} is not quantized as expected" + validate_per_tensor_quant(users[0], qspec) + + +def validate_node( + node: torch.fx.Node, quantization_config: QuantizationConfig | None +) -> None: + if quantization_config is None: + assert not is_annotated(node), f"Node {node} is unexpectedly annotated" + return + + assert is_annotated(node), f"Node {node} is not annotated" + input_qspec = quantization_config.get_input_act_qspec() + output_qspec = quantization_config.get_output_act_qspec() + weight_qspec = quantization_config.get_weight_qspec() + + if 
len(node.all_input_nodes) == 3: + input_node, weight_node, bias_node = node.all_input_nodes + bias_qspec = quantization_config.get_bias_qspec(node) + validate_input(bias_node, bias_qspec) + else: + input_node, weight_node = node.all_input_nodes + + validate_input(input_node, input_qspec) + validate_input(weight_node, weight_qspec) + validate_output(node, output_qspec) + + +def test_set_module_name() -> None: + model = ConvModel() + model.eval() + + # Set up quantizer with different configs for different modules + tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT") + quantizer = TOSAQuantizer(tosa_spec) + int8_config = get_symmetric_quantization_config(is_per_channel=False) + a16w8_config = get_symmetric_a16w8_quantization_config() + # Set module-specific configurations but don't set global config to test that + # only specified modules are quantized + quantizer.set_module_name("conv0", int8_config) + quantizer.set_module_name("conv1", a16w8_config) + + # Export model + exported_model = torch.export.export(model, test_inputs) + + # Prepare, calibrate and convert model + prepared_model = prepare_pt2e(exported_model.module(), quantizer) + prepared_model(*test_inputs) + converted_model = convert_pt2e(prepared_model) + + validate_node( + [node for node in converted_model.graph.nodes if node.name == "conv2d"][0], + int8_config, + ) + validate_node( + [node for node in converted_model.graph.nodes if node.name == "conv2d_1"][0], + a16w8_config, + ) + validate_node( + [node for node in converted_model.graph.nodes if node.name == "conv2d_2"][0], + None, + ) diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index aeb0e3a56bd..45355da353e 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -8,30 +8,36 @@ import os import re import shutil -import subprocess +import subprocess # nosec B404 - invoked only for trusted toolchain binaries import tempfile from pathlib import Path -from typing import Any, cast, Dict, List, Literal, Optional, Tuple +from types import NoneType +from typing import Any, cast, Dict, List, Optional, Tuple import numpy as np import torch - -from executorch.backends.arm.arm_backend import is_tosa, is_vgf -from executorch.backends.arm.test.conftest import is_option_enabled -from executorch.backends.arm.tosa.specification import ( - get_tosa_spec, - Tosa_1_00, - TosaSpecification, +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec +from executorch.backends.arm.constants import ( + NHWC_INVERSE_ORDER, + NHWC_ORDER, + NNHWC_INVERSE_ORDER, + NNHWC_ORDER, ) + +from executorch.backends.arm.ethosu import EthosUCompileSpec +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec +from executorch.backends.arm.tosa.specification import Tosa_1_00, TosaSpecification +from executorch.backends.arm.vgf import VgfCompileSpec +from executorch.backends.arm.vgf.model_converter import find_model_converter_binary from executorch.exir import ExecutorchProgramManager, ExportedProgram -from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.lowered_backend_module import LoweredBackendModule from torch.fx.node import Node from torch.overrides import TorchFunctionMode -from tosa.TosaGraph import TosaGraph +from tosa.TosaGraph import TosaGraph # type: ignore[import-not-found, import-untyped] logger = logging.getLogger(__name__) @@ -143,22 +149,55 @@ def 
get_output_quantization_params( Raises: RuntimeError if no output quantization parameters are found. """ - quant_params = {} - for node in output_node.args[0]: - if node.target == torch.ops.quantized_decomposed.dequantize_per_tensor.default: - quant_params[node] = QuantizationParams( - node_name=node.args[0].name, - scale=node.args[1], - zp=node.args[2], - qmin=node.args[3], - qmax=node.args[4], - dtype=node.args[5], + quant_params: dict[Node, QuantizationParams | None] = {} + for node in output_node.args[0]: # type: ignore[union-attr] + if ( + node.target # type: ignore[union-attr] + == torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): + quant_params[node] = QuantizationParams( # type: ignore[index] + node_name=node.args[0].name, # type: ignore[arg-type, union-attr] + scale=node.args[1], # type: ignore[arg-type, union-attr] + zp=node.args[2], # type: ignore[arg-type, union-attr] + qmin=node.args[3], # type: ignore[arg-type, union-attr] + qmax=node.args[4], # type: ignore[arg-type, union-attr] + dtype=node.args[5], # type: ignore[arg-type, union-attr] ) else: - quant_params[node] = None + quant_params[node] = None # type: ignore[index] return quant_params +def torch_tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray: + dtype = _torch_to_numpy_dtype_dict[tensor.dtype] + array = tensor.detach().numpy().astype(dtype) # type: ignore[var-annotated] + dim_order = tensor.dim_order() + if dim_order == NHWC_ORDER: + a = array.transpose(NHWC_ORDER) + return a + elif dim_order == NNHWC_ORDER: + return array.transpose(NNHWC_ORDER) + else: + return array + + +def numpy_to_torch_tensor(array: np.ndarray, output_node: Node) -> torch.Tensor: + output_tensor = get_first_fake_tensor(output_node) + shape = output_tensor.shape + dim_order = output_tensor.dim_order() + if dim_order == NHWC_ORDER: + shape_with_dim_order = [shape[i] for i in NHWC_ORDER] + tensor = torch.from_numpy(array).reshape(shape_with_dim_order) + return tensor.permute(NHWC_INVERSE_ORDER).to(memory_format=torch.channels_last) + elif dim_order == NNHWC_ORDER: + shape_with_dim_order = [shape[i] for i in NNHWC_ORDER] + tensor = torch.from_numpy(array).reshape(shape_with_dim_order) + return tensor.permute(NNHWC_INVERSE_ORDER).to(memory_format=torch.channels_last) + else: + tensor = torch.from_numpy(array).reshape(shape) + return tensor + + class TosaReferenceModelDispatch(TorchFunctionMode): """A context manager for executing call_delegate nodes using the reference model""" @@ -168,14 +207,10 @@ def __init__(self): def _tosa_dispatch(self, lowered_backend_module: LoweredBackendModule, inputs): tosa_buffer = lowered_backend_module.processed_bytes - compile_specs = lowered_backend_module.compile_specs - if not is_tosa(compile_specs): - raise RuntimeError( - "Model needs to be compiled to tosa to run reference model." 
- ) - tosa_spec = get_tosa_spec(compile_specs) + compile_spec = TosaCompileSpec.from_list(lowered_backend_module.compile_specs) - return run_tosa_graph(tosa_buffer, tosa_spec, inputs) + output_node = lowered_backend_module.original_module.graph.output_node() + return run_tosa_graph(tosa_buffer, compile_spec.tosa_spec, inputs, output_node) def __exit__(self, exc_type, exc_val, exc_tb): super().__exit__(exc_type, exc_val, exc_tb) @@ -197,6 +232,22 @@ def __torch_function__(self, func, types, args=..., kwargs=None): ) kwargs = kwargs or {} + + # This is a hack since Q/DQ ops does not handle channels last input correctly: the simplest and most robust + # workaround is to simply run them in channels first format and then convert back to channels last. + if func in ( + torch.ops.quantized_decomposed.quantize_per_tensor.out, + torch.ops.quantized_decomposed.dequantize_per_tensor.out, + torch.ops.quantized_decomposed.quantize_per_channel.out, + torch.ops.quantized_decomposed.dequantize_per_channel.out, + ): + + input_dim_order = args[0].dim_order() + if input_dim_order in (NHWC_ORDER, NNHWC_ORDER): + args = [args[0].to(memory_format=torch.contiguous_format), *args[1:]] + res = func(*args, **kwargs) + return res.to(memory_format=torch.channels_last) + return func(*args, **kwargs) @@ -204,29 +255,28 @@ def run_target( executorch_program_manager: ExecutorchProgramManager, inputs: Tuple[torch.Tensor], intermediate_path: str | Path, - target_board: Literal["corestone-300", "corestone-320", "vkml_emulation_layer"], + target_board: str, elf_path: str | Path, timeout: int = 120, # s ): if target_board not in VALID_TARGET: raise ValueError(f"Unsupported target: {target_board}") - if target_board in ("corstone-300", "corstone-320"): - return run_corstone( - executorch_program_manager, - inputs, - intermediate_path, - target_board, - elf_path, - timeout, - ) - elif target_board == "vkml_emulation_layer": + if target_board == "vkml_emulation_layer": return run_vkml_emulation_layer( executorch_program_manager, inputs, intermediate_path, elf_path, ) + return run_corstone( + executorch_program_manager, + inputs, + intermediate_path, + target_board, + elf_path, + timeout, + ) def save_inputs_to_file( @@ -234,10 +284,10 @@ def save_inputs_to_file( inputs: Tuple[torch.Tensor], intermediate_path: str | Path, ): - input_file_paths = [] + input_file_paths: list[str] = [] input_names = get_input_names(exported_program) for input_name, input_ in zip(input_names, inputs): - input_path = save_bytes(intermediate_path, input_, input_name) + input_path = save_bytes(intermediate_path, input_, input_name) # type: ignore[arg-type] input_file_paths.append(input_path) return input_file_paths @@ -250,15 +300,14 @@ def get_output_from_file( ): output_np = [] output_node = exported_program.graph_module.graph.output_node() - for i, node in enumerate(output_node.args[0]): - output_shape = node.meta["val"].shape + for i, node in enumerate(output_node.args[0]): # type: ignore[union-attr] output_dtype = node.meta["val"].dtype - tosa_ref_output = np.fromfile( + tosa_ref_output = np.fromfile( # type: ignore[var-annotated] os.path.join(intermediate_path, f"{output_base_name}-{i}.bin"), _torch_to_numpy_dtype_dict[output_dtype], ) - output_np.append(torch.from_numpy(tosa_ref_output).reshape(output_shape)) + output_np.append(numpy_to_torch_tensor(tosa_ref_output, node)) return tuple(output_np) @@ -315,7 +364,7 @@ def run_corstone( executorch_program_manager: ExecutorchProgramManager, inputs: Tuple[torch.Tensor], intermediate_path: str | Path, 
- target_board: Literal["corestone-300", "corestone-320"], + target_board: str, elf_path: str | Path, timeout: int = 120, # s ) -> list[torch.Tensor]: @@ -337,7 +386,6 @@ def run_corstone( to figure out the shape and dtype of the buffer that was output from the FVP. """ - exported_program = executorch_program_manager.exported_program() intermediate_path = Path(intermediate_path) intermediate_path.mkdir(exist_ok=True) @@ -353,15 +401,18 @@ def run_corstone( input_paths = save_inputs_to_file(exported_program, inputs, intermediate_path) output_base_name = "out" - out_path = os.path.join(intermediate_path, output_base_name) - cmd_line = f"executor_runner -m {pte_path} -o {out_path}" + cmd_line = "executor_runner -m program.pte -o out" for input_path in input_paths: - cmd_line += f" -i {input_path}" + relative_path = os.path.relpath( + Path(input_path).resolve(), start=intermediate_path + ) + cmd_line += f" -i {relative_path}" - ethos_u_extra_args = "" - if is_option_enabled("fast_fvp"): - ethos_u_extra_args = ethos_u_extra_args + "--fast" + if len(cmd_line) > 256: + raise ValueError( + "The argument passed to the FVP should be less than 256 characters long, otherwise it gets truncated" + ) match target_board: case "corstone-300": @@ -380,10 +431,12 @@ def run_corstone( "-C", "cpu0.semihosting-stack_base=0", "-C", - f"ethosu.extra_args='{ethos_u_extra_args}'", - "-C", "cpu0.semihosting-heap_limit=0", "-C", + f"cpu0.semihosting-cwd={intermediate_path}", + "-C", + "ethosu.extra_args='--fast'", + "-C", f"cpu0.semihosting-cmd_line='{cmd_line}'", "-a", str(elf_path), @@ -414,7 +467,9 @@ def run_corstone( "-C", "mps4_board.subsystem.cpu0.semihosting-heap_limit=0", "-C", - f"mps4_board.subsystem.ethosu.extra_args='{ethos_u_extra_args}'", + f"mps4_board.subsystem.cpu0.semihosting-cwd={intermediate_path}", + "-C", + "mps4_board.subsystem.ethosu.extra_args='--fast'", "-C", f"mps4_board.subsystem.cpu0.semihosting-cmd_line='{cmd_line}'", "-a", @@ -430,10 +485,25 @@ def run_corstone( # Regex to check for error or fault messages in stdout from FVP result_stdout = result.stdout.decode() error_regex = r"(^[EF][: ].*$)|(^.*Hard fault.*$)|(^.*Assertion.*$)" - if re.compile(error_regex, re.MULTILINE).search(result_stdout): - raise RuntimeError( + pattern = re.compile(error_regex, re.MULTILINE) + regex_matches = [m.group(0) for m in pattern.finditer(result_stdout)] + + if regex_matches: + logger.error( f"Corstone simulation failed:\ncmd: {' '.join(command_args)}\nlog: \n {result_stdout}\n{result.stderr.decode()}" ) + # Pretty-print regex matches + pretty_matches = "\n".join(f"{m.strip()}" for i, m in enumerate(regex_matches)) + logger.error( + f"Corstone simulation failed. Problems: {len(regex_matches)} found:\n{pretty_matches}" + ) + raise RuntimeError( + f"Corstone simulation failed. 
Problems: {len(regex_matches)} found:\n{pretty_matches}" + ) + else: + logger.info( + f"Corstone simulation:\ncmd: {' '.join(command_args)}\nlog: \n {result_stdout}\n{result.stderr.decode()}" + ) return get_output_from_file(exported_program, intermediate_path, output_base_name) @@ -444,11 +514,14 @@ def prep_data_for_save( quant_param: Optional[QuantizationParams] = None, ): if isinstance(data, torch.Tensor): - data_np = np.array(data.detach(), order="C").astype( - _torch_to_numpy_dtype_dict[data.dtype] - ) + data_np = torch_tensor_to_numpy(data) + elif isinstance(data, (int, float, bool, NoneType)): + return np.array(data) else: - data_np = np.array(data) + raise RuntimeError( + f"Input dtype {type(data)} could not be converted to numpy array." + ) + if quant_param is not None: assert quant_param.node_name in input_name, ( f"The quantization params name '{quant_param.node_name}' does not " @@ -462,30 +535,8 @@ def prep_data_for_save( f"{quant_param.dtype}".replace("torch.", "") ) # Use string format of dtype to convert to numpy dtype ) - return data_np - -def save_npy( - path: str, - data, - input_name: str, - quant_param: Optional[QuantizationParams] = None, -) -> str: - """Serializes and saves 'data' as a .npy file, possibly quantizing it before. - - Parameters: - path: the directory where to save the data. - data: the data to save. - input_name: the name of the file, without file-ending. - quant_param: the parameters to use for quantization. - Returns: - the full file path of the output. - """ - data_np = prep_data_for_save(data, input_name, quant_param) - file_path = os.path.join(path, input_name + ".npy") - np.save(file_path, data_np, allow_pickle=False) - - return file_path + return data_np def save_bytes( @@ -521,7 +572,9 @@ def _run_cmd(cmd: List[str], check=True) -> subprocess.CompletedProcess[bytes]: cmd (List[str]): The command to run as a list. """ try: - result = subprocess.run(cmd, check=check, capture_output=True) + result = subprocess.run( # nosec B603 - cmd constructed from trusted inputs + cmd, check=check, capture_output=True + ) return result except subprocess.CalledProcessError as e: arg_string = " ".join(cmd) @@ -586,8 +639,7 @@ def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict: data = np.frombuffer(data, dtype=np.float32) data = data.reshape(tensor["shape"]) tensor["data"] = data - except Exception: - # This is just nice-to-have if it works, don't care if it fails. 
+ except Exception: # nosec B110 - best-effort casting for debug output only pass return json_out @@ -628,11 +680,15 @@ def corstone320_installed() -> bool: def model_converter_installed() -> bool: - cmd = ["model-converter", "--version"] + model_converter = find_model_converter_binary() + if model_converter is None: + return False + try: - _run_cmd(cmd, check=True) - except: + _run_cmd([model_converter, "--version"], check=True) + except Exception: return False + return True @@ -664,30 +720,36 @@ def assert_elf_path_exists(elf_path): ) -def get_elf_path(target_board): +def get_elf_path(target_board: str, use_portable_ops: bool = False) -> str: + elf_path = "" + if target_board not in VALID_TARGET: raise ValueError(f"Unsupported target: {target_board}") + if use_portable_ops: + portable_ops_str = "portable-ops_" + else: + portable_ops_str = "" + if target_board in ("corstone-300", "corstone-320"): elf_path = os.path.join( "arm_test", - f"arm_semihosting_executor_runner_{target_board}", + f"arm_semihosting_executor_runner_{portable_ops_str}{target_board}", "arm_executor_runner", ) - assert_elf_path_exists(elf_path) elif target_board == "vkml_emulation_layer": elf_path = os.path.join( - "arm_test/arm_executor_runner_vkml", + f"arm_test/arm_executor_runner_{portable_ops_str}vkml", "executor_runner", ) - assert_elf_path_exists(elf_path) + assert_elf_path_exists(elf_path) return elf_path -def arm_executor_runner_exists(target_board): +def arm_executor_runner_exists(target_board: str, use_portable_ops: bool = False): try: - get_elf_path(target_board) + get_elf_path(target_board, use_portable_ops=use_portable_ops) except: return False else: @@ -698,18 +760,21 @@ def run_tosa_graph( graph: Any, tosa_version: TosaSpecification, inputs: list[torch.Tensor], + output_node: Node, ) -> list[torch.Tensor]: """Runs the TOSA reference model with inputs and returns the result.""" - inputs_np = [input.numpy() for input in inputs] + + # Convert tensors to numpy arrays with correct dim_order + inputs_np = [torch_tensor_to_numpy(input_tensor) for input_tensor in inputs] if isinstance(tosa_version, Tosa_1_00): - import tosa_reference_model as reference_model + import tosa_reference_model as reference_model # type: ignore[import-not-found, import-untyped] - debug_mode = "ALL" if logger.level <= logging.DEBUG else None + debug_mode = "ALL" if logger.getEffectiveLevel() <= logging.DEBUG else None outputs_np, status = reference_model.run( graph, inputs_np, - verbosity=_tosa_refmodel_loglevel(logger.level), + verbosity=_tosa_refmodel_loglevel(logger.getEffectiveLevel()), initialize_variable_tensor_from_numpy=True, debug_mode=debug_mode, ) @@ -722,17 +787,21 @@ def run_tosa_graph( status == reference_model.GraphStatus.TOSA_VALID ), "Non-valid TOSA given to reference model." 
- return [torch.from_numpy(output) for output in outputs_np] + # Convert output numpy arrays to tensors with same dim_order as the output nodes + result = [ + numpy_to_torch_tensor(output_array, node) + for output_array, node in zip(outputs_np, output_node.args[0]) # type: ignore[arg-type] + ] + + return result -def get_target_board(compile_spec: list[CompileSpec]) -> str | None: - if is_vgf(compile_spec): +def get_target_board(compile_spec: ArmCompileSpec) -> str | None: + if isinstance(compile_spec, VgfCompileSpec): return "vkml_emulation_layer" - for spec in compile_spec: - if spec.key == "compile_flags": - flags = spec.value.decode() - if "u55" in flags: - return "corstone-300" - elif "u85" in flags: - return "corstone-320" + if isinstance(compile_spec, EthosUCompileSpec): + if "u55" in compile_spec.target: + return "corstone-300" + if "u85" in compile_spec.target: + return "corstone-320" return None diff --git a/backends/arm/test/setup_testing.sh b/backends/arm/test/setup_testing.sh index d1e4725d93b..bb68361c238 100755 --- a/backends/arm/test/setup_testing.sh +++ b/backends/arm/test/setup_testing.sh @@ -10,6 +10,23 @@ script_dir=$(realpath "$(dirname "${BASH_SOURCE[0]}")") et_root_dir=$(realpath "${script_dir}/../../..") build_executor_runner=${et_root_dir}/backends/arm/scripts/build_executor_runner.sh build_root_test_dir=${et_root_dir}/arm_test/arm_semihosting_executor_runner +extraflags="-DET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE=83886080" -${build_executor_runner} --pte=semihosting --target=ethos-u55-128 --output="${build_root_test_dir}_corstone-300" -${build_executor_runner} --pte=semihosting --target=ethos-u85-128 --output="${build_root_test_dir}_corstone-320" +# By default tests with an elf without any portable_ops +# If you supply use_portable_ops=True when creating the ArmTester() +# you will instead test with some portable ops compiled in, see list below. 
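
The comments above mention the `use_portable_ops=True` option on ArmTester; on the Python side the matching runner binary is resolved by `get_elf_path()` in backends/arm/test/runner_utils.py (changed earlier in this diff). A rough sketch of that path mapping, shown here only for orientation:

```python
# Orientation-only sketch of how the test harness picks the runner built by this
# script; mirrors get_elf_path() from runner_utils.py as modified in this diff.
import os


def resolve_runner(target_board: str, use_portable_ops: bool = False) -> str:
    prefix = "portable-ops_" if use_portable_ops else ""
    return os.path.join(
        "arm_test",
        f"arm_semihosting_executor_runner_{prefix}{target_board}",
        "arm_executor_runner",
    )


print(resolve_runner("corstone-300"))                         # default build above
print(resolve_runner("corstone-320", use_portable_ops=True))  # portable-ops build above
```
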
+ +#--target --system_config --memory_mode should match the ArmTester used setup see backends/arm/test/common.py + +${build_executor_runner} --pte=semihosting --target=ethos-u55-128 --system_config=Ethos_U55_High_End_Embedded --memory_mode=Shared_Sram --output="${build_root_test_dir}_corstone-300" --extra_build_flags=${extraflags} +${build_executor_runner} --pte=semihosting --target=ethos-u85-128 --system_config=Ethos_U85_SYS_DRAM_Mid --memory_mode=Dedicated_Sram_384KB --output="${build_root_test_dir}_corstone-320" --extra_build_flags=${extraflags} + +# List of portable ops used by testing, this is mainly used to test models in the flow +# test setup to make sure models that are not fully delegated can still be tested and run OK +# To use this you can set use_portable_ops=True when creating ArmTester() + +portable_ops_list_u55="aten::permute_copy.out,aten::convolution.out,aten::relu.out,aten::_native_batch_norm_legit_no_training.out,aten::as_strided_copy.out,aten::mean.out,aten::squeeze_copy.dims,dim_order_ops::_clone_dim_order.out" +portable_ops_list_u85="aten::permute_copy.out,aten::convolution.out,aten::relu.out,aten::_native_batch_norm_legit_no_training.out,aten::as_strided_copy.out,aten::mean.out,aten::full_like.out,aten::bmm.out,aten::scalar_tensor.out,aten::index.Tensor_out,aten::where.self_out" + +${build_executor_runner} --pte=semihosting --target=ethos-u55-128 --system_config=Ethos_U55_High_End_Embedded --memory_mode=Shared_Sram --select_ops_list="${portable_ops_list_u55}" --output="${build_root_test_dir}_portable-ops_corstone-300" --extra_build_flags=${extraflags} +${build_executor_runner} --pte=semihosting --target=ethos-u85-128 --system_config=Ethos_U85_SYS_DRAM_Mid --memory_mode=Dedicated_Sram_384KB --select_ops_list="${portable_ops_list_u85}" --output="${build_root_test_dir}_portable-ops_corstone-320" --extra_build_flags=${extraflags} diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl index a6181cf34ce..ffb18043536 100644 --- a/backends/arm/test/targets.bzl +++ b/backends/arm/test/targets.bzl @@ -1,9 +1,10 @@ # load("//caffe2/test/fb:defs.bzl", "define_tests") +load("@fbsource//tools/build_defs:fbsource_utils.bzl", "is_fbcode") load("@fbcode_macros//build_defs:python_pytest.bzl", "python_pytest") load("@bazel_skylib//lib:paths.bzl", "paths") def define_arm_tests(): - # TODO Add more tests + # TODO [fbonly] Add more tests test_files = [] # Passes @@ -17,13 +18,18 @@ def define_arm_tests(): "ops/test_addmm.py", "ops/test_avg_pool2d.py", "ops/test_cat.py", + "ops/test_conv2d.py", "ops/test_linear.py", "ops/test_mul.py", + "ops/test_permute.py", + "ops/test_rsqrt.py", "ops/test_slice.py", "ops/test_sigmoid.py", + "ops/test_sub.py", "ops/test_tanh.py", "ops/test_view.py", "ops/test_cos.py", + "ops/test_to_copy.py", ] # Quantization @@ -31,6 +37,17 @@ def define_arm_tests(): "quantizer/test_generic_annotater.py", ] + # Misc tests + test_files += [ + "misc/test_compile_spec.py", + "misc/test_tosa_spec.py", + "misc/test_bn_relu_folding_qat.py", + "misc/test_custom_partition.py", + "misc/test_debug_hook.py", + # "misc/test_dim_order.py", (TODO - T238390249) + "misc/test_outputs_order.py", + ] + TESTS = {} for test_file in test_files: @@ -48,8 +65,12 @@ def define_arm_tests(): "//executorch/kernels/quantized:custom_ops_generated_lib", ], deps = [ - "//executorch/backends/arm/test:arm_tester", + "//executorch/backends/arm/test/tester/fb:arm_tester_fb" if is_fbcode else "//executorch/backends/arm/test:arm_tester", "//executorch/backends/arm/test:conftest", + 
"//executorch/backends/arm:ethosu", + "//executorch/backends/arm/tosa:compile_spec", + "//executorch/backends/arm/tosa:partitioner", + "//executorch/backends/arm:vgf", "//executorch/exir:lib", "fbsource//third-party/pypi/pytest:pytest", "fbsource//third-party/pypi/parameterized:parameterized", diff --git a/backends/arm/test/test_arm_baremetal.sh b/backends/arm/test/test_arm_baremetal.sh index 53c707cad28..5a168637214 100755 --- a/backends/arm/test/test_arm_baremetal.sh +++ b/backends/arm/test/test_arm_baremetal.sh @@ -14,7 +14,7 @@ script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) et_root_dir=$(cd ${script_dir}/../../.. && pwd) cd "${et_root_dir}" pwd -setup_path_script=${et_root_dir}/examples/arm/ethos-u-scratch/setup_path.sh +setup_path_script=${et_root_dir}/examples/arm/arm-scratch/setup_path.sh _setup_msg="please refer to ${et_root_dir}/examples/arm/setup.sh to properly install necessary tools." @@ -155,17 +155,17 @@ test_pytest_ethosu_fvp() { # Same as test_pytest but also sometime verify using test_pytest_ops_vkml() { # Same as test_pytest but also sometime verify using VKML runtime - echo "${TEST_SUITE_NAME}: Run pytest with VKML" + echo "${TEST_SUITE_NAME}: Run pytest operator tests with VKML runtime" - backends/arm/scripts/build_executorch.sh backends/arm/test/setup_testing_vkml.sh - pytest --verbose --color=yes --numprocesses=auto --durations=10 backends/arm/test/ --ignore=backends/arm/test/models + pytest --verbose --color=yes --numprocesses=auto --durations=10 backends/arm/test/ \ + --ignore=backends/arm/test/models -k _vgf_ echo "${TEST_SUITE_NAME}: PASS" } test_pytest_models_vkml() { # Same as test_pytest but also sometime verify VKML runtime - echo "${TEST_SUITE_NAME}: Run pytest with VKML" + echo "${TEST_SUITE_NAME}: Run pytest model tests with VKML runtime" backends/arm/scripts/build_executorch.sh backends/arm/test/setup_testing_vkml.sh @@ -173,7 +173,7 @@ test_pytest_models_vkml() { # Same as test_pytest but also sometime verify VKML # Install model dependencies for pytest source backends/arm/scripts/install_models_for_test.sh - pytest --verbose --color=yes --numprocesses=auto --durations=0 backends/arm/test/models + pytest --verbose --color=yes --numprocesses=auto --durations=0 backends/arm/test/models -k _vgf_ echo "${TEST_SUITE_NAME}: PASS" } @@ -189,11 +189,11 @@ test_run_vkml() { # End to End model tests using run.sh echo "${TEST_SUITE_NAME}: Test VKML" out_folder="arm_test/test_run" - examples/arm/run.sh --et_build_root=${out_folder} --target=vgf --model_name=add --output=${out_folder}/runner - examples/arm/run.sh --et_build_root=${out_folder} --target=vgf --model_name=mul --output=${out_folder}/runner + examples/arm/run.sh --et_build_root=${out_folder} --target=vgf --model_name=add --output=${out_folder}/runner --bundleio + examples/arm/run.sh --et_build_root=${out_folder} --target=vgf --model_name=mul --output=${out_folder}/runner --bundleio - examples/arm/run.sh --et_build_root=${out_folder} --target=vgf --model_name=qadd --output=${out_folder}/runner - examples/arm/run.sh --et_build_root=${out_folder} --target=vgf --model_name=qops --output=${out_folder}/runner + examples/arm/run.sh --et_build_root=${out_folder} --target=vgf --model_name=qadd --output=${out_folder}/runner --bundleio + examples/arm/run.sh --et_build_root=${out_folder} --target=vgf --model_name=qops --output=${out_folder}/runner --bundleio echo "${TEST_SUITE_NAME}: PASS" } @@ -253,8 +253,8 @@ test_models_vkml() { # End to End model tests using model_test.py # VKML echo 
"${TEST_SUITE_NAME}: Test target VKML" - python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=vgf --model=mv2 - python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=vgf --no_quantize --model=mv2 + python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=vgf --model=resnet18 --extra_runtime_flags="--bundleio_atol=0.2 --bundleio_rtol=0.2" + python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=vgf --model=resnet50 --extra_runtime_flags="--bundleio_atol=0.2 --bundleio_rtol=0.2" echo "${TEST_SUITE_NAME}: PASS" } @@ -339,6 +339,19 @@ test_full_vkml() { # All End to End model tests echo "${TEST_SUITE_NAME}: PASS" } +test_model_smollm2-135M() { + echo "${TEST_SUITE_NAME}: Test SmolLM2-135M on Ethos-U85" + + # Build common libs once + python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --build_libs + + python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=ethos-u85-128 --model=smollm2 --extra_flags="-DEXECUTORCH_SELECT_OPS_LIST=dim_order_ops::_to_dim_order_copy.out" + + echo "${TEST_SUITE_NAME}: PASS" + + +} + test_smaller_stories_llama() { echo "${TEST_SUITE_NAME}: Test smaller_stories_llama" @@ -365,5 +378,41 @@ test_smaller_stories_llama() { echo "${TEST_SUITE_NAME}: PASS" } +test_memory_allocation() { + echo "${TEST_SUITE_NAME}: Test ethos-u memory allocation with run.sh" + + mkdir -p arm_test/test_run + # Ethos-U85 + echo "${TEST_SUITE_NAME}: Test target Ethos-U85" + examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=examples/arm/example_modules/add.py &> arm_test/test_run/full.log + python3 backends/arm/test/test_memory_allocator_log.py --log arm_test/test_run/full.log \ + --require "model_pte_program_size" "<= 3000 B" \ + --require "method_allocator_planned" "<= 64 B" \ + --require "method_allocator_loaded" "<= 1024 B" \ + --require "method_allocator_input" "<= 16 B" \ + --require "Total DRAM used" "<= 0.06 KiB" + echo "${TEST_SUITE_NAME}: PASS" +} + +test_undefinedbehavior_sanitizer() { + echo "${TEST_SUITE_NAME}: Test ethos-u executor_runner with UBSAN" + + mkdir -p arm_test/test_run + # Ethos-U85 + echo "${TEST_SUITE_NAME}: Test target Ethos-U85" + examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=examples/arm/example_modules/add.py --build_type=UndefinedSanitizer + echo "${TEST_SUITE_NAME}: PASS" +} + +test_address_sanitizer() { + echo "${TEST_SUITE_NAME}: Test ethos-u executor_runner with ASAN" + + mkdir -p arm_test/test_run + # Ethos-U85 + echo "${TEST_SUITE_NAME}: Test target Ethos-U85" + examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=examples/arm/example_modules/add.py --build_type=AddressSanitizer + echo "${TEST_SUITE_NAME}: PASS" +} + ${TEST_SUITE} diff --git a/backends/arm/test/test_memory_allocator_log.py b/backends/arm/test/test_memory_allocator_log.py new file mode 100644 index 00000000000..3853b60b7f6 --- /dev/null +++ b/backends/arm/test/test_memory_allocator_log.py @@ -0,0 +1,170 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +""" +Check log files for memory metrics and compare them against thresholds. 
+ +Usage example: + python3 test_memory_allocator_log.py \ + --log path/to/log.txt \ + --require "Total SRAM used" "<= 310 KiB" \ + --require "method_allocator_input" "<= 4 B" +""" + +import argparse +import re +import sys +from typing import List, Optional, Tuple + + +def unit_factor(u: str) -> float: + if not u: + return 1.0 + ul = u.strip().lower() + table = { + "b": 1, + "byte": 1, + "bytes": 1, + "kb": 1000, + "mb": 1000**2, + "gb": 1000**3, + "kib": 1024, + "mib": 1024**2, + "gib": 1024**3, + } + if ul in table: + return float(table[ul]) + return 1.0 + + +def parse_value(text_num: str, text_unit: Optional[str]) -> float: + return float(text_num) * unit_factor(text_unit or "") + + +def parse_cond(cond: str) -> Tuple[str, float, str]: + # Regexp explained. Example of things it will parse: + # "< 310 KiB", ">=10MB", "== 42", "!=3 bytes", "<=0.5 MiB" + + # The regexp explained in detail: + # ^: anchor the match to the start and end of the string (no extra chars allowed). + # \s*: optional whitespace (spaces, tabs, etc.). + # (<=|>=|==|!=|<|>): capturing group 1. One of the comparison operators: <=, >=, ==, !=, <, >. + # \s*: optional whitespace. + # ([0-9]+(?:\.[0-9]+)?): capturing group 2. A number: + # [0-9]+: one or more digits (the integer part). + # (?:\.[0-9]+)?: optional non-capturing group for a fractional part like .25. + # \s*: optional whitespace between number and unit + # ([A-Za-z]+)?: capturing group 3, optional. A unit made of letters only (e.g., B, KB, KiB, MB, MiB). Case# insensitive by class choice. + # \s*: optional trailing whitespace. + m = re.match( + r"^\s*(<=|>=|==|!=|<|>)\s*([0-9]+(?:\.[0-9]+)?)\s*([A-Za-z]+)?\s*$", cond + ) + if not m: + raise ValueError(f"Invalid condition: {cond}") + op, num, unit = m.groups() + return op, float(num), (unit or "") + + +def compare(a: float, b: float, op: str) -> bool: + return { + "<": a < b, + "<=": a <= b, + ">": a > b, + ">=": a >= b, + "==": abs(a - b) < 1e-9, + "!=": abs(a - b) >= 1e-9, + }[op] + + +def find_metric_value(line: str, label: str) -> Tuple[Optional[str], Optional[str]]: + # Same regexp as parse_cond() but without the first group of matching comparison operators + # First go, search for the pattern but escape and ignore cases + # The regexp: + # ([0-9]+(?:\.[0-9]+)?) — capturing group 1: a decimal number + # [0-9]+ — one or more digits (integer part) + # (?:\.[0-9]+)? — optional fractional part like .25 (non-capturing) + # \s* — optional whitespace between number and unit + # ([A-Za-z]+)? — capturing group 2 (optional): a unit made only of letters (e.g., B, KB, KiB, MB) + m = re.search( + re.escape(label) + r".*?([0-9]+(?:\.[0-9]+)?)\s*([A-Za-z]+)?", + line, + flags=re.IGNORECASE, + ) + if m: + return m.group(1), m.group(2) + # Second go, same regexp as above but not caring about label. 
If + # no number was tied to a label be happy just salvaging it from + # the line + m = re.search(r"([0-9]+(?:\.[0-9]+)?)\s*([A-Za-z]+)?", line) + if m: + return m.group(1), m.group(2) + return None, None + + +def first_line_with_label(lines: List[str], label: str) -> Optional[str]: + label_lc = label.lower() + return next((ln for ln in lines if label_lc in ln.lower()), None) + + +def check_requirement(label: str, cond: str, lines: List[str]) -> Optional[str]: + op, thr_num, thr_unit = parse_cond(cond) + matched = first_line_with_label(lines, label) + if matched is None: + return f"{label}: not found in log" + + num_str, unit_str = find_metric_value(matched, label) + if num_str is None: + return f"{label}: value not found on line: {matched.strip()}" + + left_bytes = parse_value(num_str, unit_str) + right_bytes = parse_value(str(thr_num), thr_unit or (unit_str or "")) + ok = compare(left_bytes, right_bytes, op) + + human_left = f"{num_str} {unit_str or 'B'}" + human_right = f"{thr_num:g} {thr_unit or (unit_str or 'B')}" + print( + f"[check] {label}: {human_left} {op} {human_right} -> {'OK' if ok else 'FAIL'}" + ) + + if ok: + return None + return f"{label}: {human_left} not {op} {human_right}" + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--log", required=True, help="Path to log file") + parser.add_argument( + "--require", + action="append", + nargs=2, + metavar=("LABEL", "COND"), + default=[], + help="""Required label and condition consisting + of a number and unit. Example: \"Total DRAM + used\" \"<= 0.06 KiB\"""", + ) + args = parser.parse_args() + + with open(args.log, "r", encoding="utf-8", errors="ignore") as f: + lines = f.readlines() + + failures: List[str] = [] + for label, cond in args.require: + msg = check_requirement(label, cond, lines) + if msg: + failures.append(msg) + + if failures: + print("Failures:") + for msg in failures: + print(" - " + msg) + return 1 + + print("All checks passed.") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/backends/arm/test/test_model.py b/backends/arm/test/test_model.py index 8833b7050e7..87a92c25ba2 100755 --- a/backends/arm/test/test_model.py +++ b/backends/arm/test/test_model.py @@ -5,9 +5,10 @@ import argparse import os -import subprocess +import subprocess # nosec B404 - launches trusted build/test scripts import sys import time +from typing import Sequence def get_args(): @@ -66,9 +67,15 @@ def get_args(): parser.add_argument( "--extra_flags", required=False, - default=None, + default="", help="Extra cmake flags to pass the when building the executor_runner", ) + parser.add_argument( + "--extra_runtime_flags", + required=False, + default="", + help="Extra runtime flags to pass the final runner/executable", + ) parser.add_argument( "--timeout", required=False, @@ -96,10 +103,12 @@ def get_args(): return args -def run_external_cmd(cmd: []): +def run_external_cmd(cmd: Sequence[str]) -> None: print("CALL:", *cmd, sep=" ") try: - subprocess.check_call(cmd) + subprocess.check_call( + cmd + ) # nosec B603 - cmd assembled from vetted scripts/flags except subprocess.CalledProcessError as err: print("ERROR called: ", *cmd, sep=" ") print(f"Failed with: {err.returncode}") @@ -129,20 +138,18 @@ def build_pte( no_intermediate: bool, no_quantize: bool, ): - pte_file_ending = "pte" command_list = [ "python3", "-m", "examples.arm.aot_arm_compiler", "--delegate", + "--bundleio", f"--model_name={model_name}", f"--target={target}", f"--output={build_output}", ] if "vgf" != target: - 
pte_file_ending = "bpte" - command_list.append("--bundleio") command_list.append(f"--system_config={system_config}") command_list.append(f"--memory_mode={memory_mode}") @@ -154,6 +161,7 @@ def build_pte( run_external_cmd(command_list) + pte_file_ending = "bpte" pte_file = os.path.join( output, f"{model_name}_arm_delegate_{args.target}.{pte_file_ending}" ) @@ -184,7 +192,7 @@ def build_ethosu_runtime( "--build_type=Release", f"--system_config={system_config}", f"--memory_mode={memory_mode}", - f"--extra_build_flags=-DET_DUMP_OUTPUT=OFF {extra_flags}", + f"--extra_build_flags=-DET_LOG_DUMP_OUTPUT=OFF {extra_flags}", f"--output={elf_build_path}", ] ) @@ -217,6 +225,7 @@ def build_vkml_runtime( os.path.join(script_path, "build_executor_runner_vkml.sh"), f"--et_build_root={et_build_root}", "--etdump", + "--bundleio", "--build_type=Release", f"--extra_build_flags=-DET_DUMP_OUTPUT=OFF {extra_flags}", f"--output={build_path}", @@ -227,13 +236,14 @@ def build_vkml_runtime( return runner -def run_vkml(script_path: str, pte_file: str, runner_build_path: str): +def run_vkml(script_path: str, pte_file: str, runner_build_path: str, extra_flags: str): run_external_cmd( [ "bash", os.path.join(script_path, "run_vkml.sh"), f"--model={pte_file}", f"--build_path={runner_build_path}", + f"--optional_flags={extra_flags}", ] ) @@ -296,7 +306,7 @@ def run_vkml(script_path: str, pte_file: str, runner_build_path: str): ) start_time = time.perf_counter() - run_vkml(script_path, pte_file, build_path) + run_vkml(script_path, pte_file, build_path, args.extra_runtime_flags) end_time = time.perf_counter() print( f"[Test model: {end_time - start_time:.2f} s] Tested VKML runner: {vkml_runner}" diff --git a/backends/arm/test/tester/analyze_output_utils.py b/backends/arm/test/tester/analyze_output_utils.py index 82d4f5d9837..527413e9d8f 100644 --- a/backends/arm/test/tester/analyze_output_utils.py +++ b/backends/arm/test/tester/analyze_output_utils.py @@ -5,9 +5,9 @@ import logging import tempfile +from typing import Any, cast, Sequence import torch -from executorch.backends.arm.arm_backend import get_intermediate_path from executorch.backends.arm.test.runner_utils import ( get_input_quantization_params, get_output_quantization_params, @@ -18,9 +18,29 @@ logger = logging.getLogger(__name__) -def _print_channels(result, reference, channels_close, C, H, W, rtol, atol): +TensorLike = torch.Tensor | tuple[torch.Tensor, ...] 
+ +def _ensure_tensor(value: TensorLike) -> torch.Tensor: + if isinstance(value, torch.Tensor): + return value + if value and isinstance(value[0], torch.Tensor): + return value[0] + raise TypeError("Expected a Tensor or a non-empty tuple of Tensors") + + +def _print_channels( + result: torch.Tensor, + reference: torch.Tensor, + channels_close: Sequence[bool], + C: int, + H: int, + W: int, + rtol: float, + atol: float, +) -> str: output_str = "" + exp = "000" booldata = False if reference.dtype == torch.bool or result.dtype == torch.bool: booldata = True @@ -63,7 +83,15 @@ def _print_channels(result, reference, channels_close, C, H, W, rtol, atol): return output_str -def _print_elements(result, reference, C, H, W, rtol, atol): +def _print_elements( + result: torch.Tensor, + reference: torch.Tensor, + C: int, + H: int, + W: int, + rtol: float, + atol: float, +) -> str: output_str = "" for y in range(H): res = "[" @@ -92,15 +120,17 @@ def _print_elements(result, reference, C, H, W, rtol, atol): return output_str -def print_error_diffs( - tester, - result: torch.Tensor | tuple, - reference: torch.Tensor | tuple, - quantization_scale=None, - atol=1e-03, - rtol=1e-03, - qtol=0, -): +def print_error_diffs( # noqa: C901 + tester_or_result: Any, + result_or_reference: TensorLike, + reference: TensorLike | None = None, + # Force remaining args to be keyword-only to keep the two positional call patterns unambiguous. + *, + quantization_scale: float | None = None, + atol: float = 1e-03, + rtol: float = 1e-03, + qtol: float = 0, +) -> None: """ Prints the error difference between a result tensor and a reference tensor in NCHW format. Certain formatting rules are applied to clarify errors: @@ -131,60 +161,81 @@ def print_error_diffs( """ - - if isinstance(reference, tuple): - reference = reference[0] - if isinstance(result, tuple): - result = result[0] - - if not result.shape == reference.shape: + if reference is None: + result = _ensure_tensor(cast(TensorLike, tester_or_result)) + reference_tensor = _ensure_tensor(result_or_reference) + else: + result = _ensure_tensor(result_or_reference) + reference_tensor = _ensure_tensor(reference) + + if result.shape != reference_tensor.shape: raise ValueError( - f"Output needs to be of same shape: {result.shape} != {reference.shape}" + f"Output needs to be of same shape: {result.shape} != {reference_tensor.shape}" ) shape = result.shape - - match len(shape): - case 4: - N, C, H, W = (shape[0], shape[1], shape[2], shape[3]) - case 3: - N, C, H, W = (1, shape[0], shape[1], shape[2]) - case 2: - N, C, H, W = (1, 1, shape[0], shape[1]) - case 1: - N, C, H, W = (1, 1, 1, shape[0]) - case 0: - N, C, H, W = (1, 1, 1, 1) - case _: - raise ValueError("Invalid tensor rank") + rank = len(shape) + + if rank == 5: + N, C, D, H, W = shape + elif rank == 4: + N, C, H, W = shape + D = 1 + elif rank == 3: + C, H, W = shape + N, D = 1, 1 + elif rank == 2: + H, W = shape + N, C, D = 1, 1, 1 + elif rank == 1: + W = shape[0] + N, C, D, H = 1, 1, 1, 1 + elif rank == 0: + N = C = D = H = W = 1 + else: + raise ValueError("Invalid tensor rank") + + if rank < 3: + C = 1 + if rank < 2: + H = 1 + if rank < 1: + W = 1 if quantization_scale is not None: atol += quantization_scale * qtol - # Reshape tensors to 4D NCHW format - result = torch.reshape(result, (N, C, H, W)) - reference = torch.reshape(reference, (N, C, H, W)) + # Reshape tensors to 4D NCHW format, optionally folding depth into batch. 
+ total_batches = N * D + result = torch.reshape(result, (total_batches, C, H, W)) + reference_tensor = torch.reshape(reference_tensor, (total_batches, C, H, W)) output_str = "" - for n in range(N): - output_str += f"BATCH {n}\n" - result_batch = result[n, :, :, :] - reference_batch = reference[n, :, :, :] + for idx in range(total_batches): + batch_idx = idx // D if D > 0 else idx + depth_idx = idx % D if D > 0 else 0 + if D > 1: + output_str += f"BATCH {batch_idx} DEPTH {depth_idx}\n" + else: + output_str += f"BATCH {batch_idx}\n" + + result_batch = result[idx, :, :, :] + reference_batch = reference_tensor[idx, :, :, :] is_close = torch.allclose(result_batch, reference_batch, rtol, atol) if is_close: output_str += ".\n" else: - channels_close = [None] * C + channels_close: list[bool] = [False] * C for c in range(C): - result_hw = result[n, c, :, :] - reference_hw = reference[n, c, :, :] + result_hw = result[idx, c, :, :] + reference_hw = reference_tensor[idx, c, :, :] channels_close[c] = torch.allclose(result_hw, reference_hw, rtol, atol) if any(channels_close) or len(channels_close) == 1: output_str += _print_channels( - result[n, :, :, :], - reference[n, :, :, :], + result[idx, :, :, :], + reference_tensor[idx, :, :, :], channels_close, C, H, @@ -194,7 +245,13 @@ def print_error_diffs( ) else: output_str += _print_elements( - result[n, :, :, :], reference[n, :, :, :], C, H, W, rtol, atol + result[idx, :, :, :], + reference_tensor[idx, :, :, :], + C, + H, + W, + rtol, + atol, ) if reference_batch.dtype == torch.bool or result_batch.dtype == torch.bool: mismatches = (reference_batch != result_batch).sum().item() @@ -202,9 +259,9 @@ def print_error_diffs( output_str += f"(BOOLEAN tensor) {mismatches} / {total} elements differ ({mismatches / total:.2%})\n" # Only compute numeric error metrics if tensor is not boolean - if reference.dtype != torch.bool and result.dtype != torch.bool: - reference_range = torch.max(reference) - torch.min(reference) - diff = torch.abs(reference - result).flatten() + if reference_tensor.dtype != torch.bool and result.dtype != torch.bool: + reference_range = torch.max(reference_tensor) - torch.min(reference_tensor) + diff = torch.abs(reference_tensor - result).flatten() diff = diff[diff.nonzero()] if not len(diff) == 0: diff_percent = diff / reference_range @@ -231,21 +288,21 @@ def print_error_diffs( def dump_error_output( - tester, - reference_output, - stage_output, - quantization_scale=None, - atol=1e-03, - rtol=1e-03, - qtol=0, -): + tester: Any, + reference_output: TensorLike, + stage_output: TensorLike, + quantization_scale: float | None = None, + atol: float = 1e-03, + rtol: float = 1e-03, + qtol: float = 0, +) -> None: """ Prints Quantization info and error tolerances, and saves the differing tensors to disc. 
""" # Capture assertion error and print more info banner = "=" * 40 + "TOSA debug info" + "=" * 40 logger.error(banner) - path_to_tosa_files = get_intermediate_path(tester.compile_spec) + path_to_tosa_files = tester.compile_spec.get_intermediate_path() if path_to_tosa_files is None: path_to_tosa_files = tempfile.mkdtemp(prefix="executorch_result_dump_") @@ -274,11 +331,7 @@ def dump_error_output( if __name__ == "__main__": - import sys - - logging.basicConfig(stream=sys.stdout, level=logging.INFO) - - """ This is expected to produce the example output of print_diff""" + """This is expected to produce the example output of print_diff""" torch.manual_seed(0) a = torch.rand(3, 3, 2, 2) * 0.01 b = a.clone().detach() diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index fe17bd3f448..033f2331ae9 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -6,16 +6,19 @@ import copy import logging +import shutil +import tempfile -import os -from collections import Counter +from collections import Counter, defaultdict from pprint import pformat from typing import ( Any, Callable, + cast, Dict, Iterable, List, + no_type_check, Optional, Sequence, Tuple, @@ -25,32 +28,19 @@ import executorch.backends.xnnpack.test.tester.tester as tester -import serializer.tosa_serializer as ts # type: ignore[import-untyped] - import torch.fx import torch.utils._pytree as pytree +import tosa_serializer as ts + from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager -from executorch.backends.arm.arm_backend import ( - get_intermediate_path, - is_ethosu, - is_tosa, - is_vgf, -) -from executorch.backends.arm.ethosu import EthosUPartitioner -from executorch.backends.arm.quantizer import ( - EthosUQuantizer, - get_symmetric_quantization_config, - TOSAQuantizer, - VgfQuantizer, -) +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec +from executorch.backends.arm.ethosu import EthosUCompileSpec +from executorch.backends.arm.quantizer import get_symmetric_quantization_config from executorch.backends.arm.test.runner_utils import ( dbg_tosa_fb_to_json, - get_elf_path, get_output_quantization_params, - get_target_board, - run_target, TosaReferenceModelDispatch, ) @@ -58,15 +48,30 @@ dump_error_output, print_error_diffs, ) +from executorch.backends.arm.test.tester.quantize import ArmQuantize as Quantize +from executorch.backends.arm.test.tester.serialize import Serialize + from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec from executorch.backends.arm.tosa.mapping import extract_tensor_meta -from executorch.backends.arm.tosa.partitioner import TOSAPartitioner -from executorch.backends.arm.tosa.specification import get_tosa_spec -from executorch.backends.arm.vgf import VgfPartitioner +from executorch.backends.arm.util._factory import ( + create_partitioner, + create_quantizer, + parse_compile_spec, +) +from executorch.backends.arm.vgf import VgfCompileSpec +from executorch.backends.test.harness.error_statistics import ErrorStatistics from executorch.backends.test.harness.stages import Stage, StageType -from executorch.backends.xnnpack.test.tester import Tester +from executorch.backends.xnnpack.test.tester import ( + Partition as XnnpackPartitionStage, + Quantize as XnnpackQuantize, + Tester, + ToEdge as XnnpackToEdge, + ToEdgeTransformAndLower as XnnpackToEdgeTransformAndLower, + ToExecutorch as XnnpackToExecutorch, +) from 
executorch.devtools.backend_debug import get_delegation_info from executorch.exir import ( @@ -77,12 +82,7 @@ to_edge_transform_and_lower, ) from executorch.exir.backend.backend_api import validation_disabled -from executorch.exir.backend.compile_spec_schema import CompileSpec -from executorch.exir.backend.operator_support import ( - DontPartition, - DontPartitionModule, - DontPartitionName, -) +from executorch.exir.backend.operator_support import OperatorSupportBase from executorch.exir.backend.partitioner import Partitioner from executorch.exir.lowered_backend_module import LoweredBackendModule from executorch.exir.pass_base import ExportPass @@ -91,33 +91,31 @@ _copy_module, _update_exported_program_graph_module, ) - -from tabulate import tabulate +from tabulate import tabulate # type: ignore[import-untyped] from torch.export.graph_signature import ExportGraphSignature, InputSpec, OutputSpec from torch.fx import Graph -from torch.utils._pytree import tree_flatten +from torchao.quantization.pt2e.quantizer import QuantizationSpec, SharedQuantizationSpec +from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY logger = logging.getLogger(__name__) def _dump_lowered_modules_artifact( path_to_dump: Optional[str], - artifact: ExecutorchProgramManager, - graph_module: torch.fx.GraphModule, -): + artifact: Union[EdgeProgramManager, ExecutorchProgramManager], + graph_module: torch.fx.GraphModule | None, +) -> None: + if graph_module is None: + logger.warning("No graph module available to dump lowered modules.") + return + output = "Formated Graph Signature:\n" output += _format_export_graph_signature( artifact.exported_program().graph_signature ) - def get_output_format(lowered_module) -> str | None: - for spec in lowered_module.compile_specs: - if spec.key == "output_format": - return spec.value.decode() - return None - for node in graph_module.graph.nodes: if node.op == "get_attr" and node.name.startswith("lowered_module_"): lowered_module = getattr(graph_module, node.name) @@ -125,15 +123,15 @@ def get_output_format(lowered_module) -> str | None: lowered_module, LoweredBackendModule ), f"Attribute {node.name} must be of type LoweredBackendModule." - output_format = get_output_format(lowered_module) - if output_format == "tosa": + compile_spec = parse_compile_spec(lowered_module.compile_specs) + if isinstance(compile_spec, TosaCompileSpec): tosa_fb = lowered_module.processed_bytes to_print = dbg_tosa_fb_to_json(tosa_fb) to_print = pformat(to_print, compact=True, indent=1) output += f"\nTOSA deserialized {node.name}: \n{to_print}\n" - elif output_format == "vela": + elif isinstance(compile_spec, EthosUCompileSpec): vela_cmd_stream = lowered_module.processed_bytes - output += f"\nVela command stream {node.name}: \n{vela_cmd_stream}\n" + output += f"\nVela command stream {node.name}: \n{vela_cmd_stream!r}\n" else: logger.warning( f"No TOSA nor Vela compile spec found in compile specs of {node.name}." @@ -150,7 +148,14 @@ def get_output_format(lowered_module) -> str | None: class Partition(tester.Partition): def dump_artifact(self, path_to_dump: Optional[str]): super().dump_artifact(path_to_dump) - _dump_lowered_modules_artifact(path_to_dump, self.artifact, self.graph_module) + artifact = cast(Optional[EdgeProgramManager], self.artifact) + graph_module = cast(Optional[torch.fx.GraphModule], self.graph_module) + if artifact is None: + logger.warning( + "Partition stage artifact missing; skipping lowered module dump." 
+ ) + return + _dump_lowered_modules_artifact(path_to_dump, artifact, graph_module) class ToEdgeTransformAndLower(tester.ToEdgeTransformAndLower): @@ -169,7 +174,14 @@ def __init__( def dump_artifact(self, path_to_dump: Optional[str]): super().dump_artifact(path_to_dump) - _dump_lowered_modules_artifact(path_to_dump, self.artifact, self.graph_module) + artifact = cast(Optional[EdgeProgramManager], self.artifact) + graph_module = cast(Optional[torch.fx.GraphModule], self.graph_module) + if artifact is None: + logger.warning( + "ToEdgeTransformAndLower stage artifact missing; skipping lowered module dump." + ) + return + _dump_lowered_modules_artifact(path_to_dump, artifact, graph_module) def run( self, artifact: ExportedProgram, inputs=None, generate_etrecord: bool = False @@ -185,43 +197,6 @@ def run( ) -class Serialize(tester.Serialize): - def __init__(self, compile_spec: list[CompileSpec], timeout): - super().__init__() - self.timeout = timeout - self.executorch_program_manager: ExecutorchProgramManager | None - self.compile_spec = compile_spec - - def run(self, artifact: ExecutorchProgramManager, inputs=None) -> None: - super().run(artifact, inputs) - # Keep the entire ExecutorchProgramManager for execution. - self.executorch_program_manager = artifact - - def run_artifact(self, inputs): - if self.executorch_program_manager is None: - raise RuntimeError( - "Tried running artifact from Serialize stage without running the stage." - ) - inputs_flattened, _ = tree_flatten(inputs) - intermediate_path = get_intermediate_path(self.compile_spec) - target_board = get_target_board(self.compile_spec) - elf_path = get_elf_path(target_board) - - if not os.path.exists(elf_path): - raise FileNotFoundError( - f"Did not find build arm_executor_runner in path {elf_path}, run setup_testing.sh?" - ) - - return run_target( - self.executorch_program_manager, - inputs_flattened, - intermediate_path, - target_board, - elf_path, - self.timeout, - ) - - class ToExecutorch(tester.ToExecutorch): def run_artifact(self, inputs): with TosaReferenceModelDispatch(): @@ -229,16 +204,18 @@ def run_artifact(self, inputs): class RunPasses(tester.RunPasses): - + @no_type_check def __init__( self, - pass_list: Optional[List[Type[ExportPass]]] = None, + pass_list: Optional[List[Type[PassType]]] = None, pass_functions: Optional[List[Callable]] = None, passes_with_exported_program: Optional[List[Type[ExportPass]]] = None, ): """Passes are run in the order they are passed: first pass_list, second pass_functions, and lastly passes_with_exported_program.""" - self.pass_with_exported_program = passes_with_exported_program + self.pass_with_exported_program: Optional[List[Type[ExportPass]]] = ( + passes_with_exported_program + ) super().__init__(pass_list, pass_functions) @@ -246,14 +223,15 @@ def run( self, artifact: Union[EdgeProgramManager, ExportedProgram], inputs=None ) -> None: if self.pass_with_exported_program is not None: - self.pass_functions = self.pass_functions or [] # type: ignore + pass_functions = list(self.pass_functions or []) # type: ignore[has-type] # pass_function list from superclass expects functions that take in # and return ExportedPrograms. # Create a wrapper to fit pass_with_exported_program into this. 
def wrap_ep_pass(ep_pass: Type[ExportPass]): def wrapped_ep_pass(ep: ExportedProgram) -> ExportedProgram: - pass_result = ep_pass(ep).call(ep.graph_module) + pass_instance = ep_pass(ep) # type: ignore[call-arg] + pass_result = pass_instance.call(ep.graph_module) with validation_disabled(): return _update_exported_program_graph_module( ep, pass_result.graph_module @@ -261,9 +239,10 @@ def wrapped_ep_pass(ep: ExportedProgram) -> ExportedProgram: return wrapped_ep_pass - self.pass_functions.extend( + pass_functions.extend( [wrap_ep_pass(ep_pass) for ep_pass in self.pass_with_exported_program] ) + self.pass_functions = pass_functions super().run(artifact, inputs) @@ -296,20 +275,22 @@ class ArmTester(Tester): def __init__( self, model: torch.nn.Module, - example_inputs: Tuple, - compile_spec: List[CompileSpec], + example_inputs: Tuple[Any, ...], + compile_spec: ArmCompileSpec, tosa_ref_model_path: str | None = None, dynamic_shapes: Optional[Tuple[Any]] = None, constant_methods: Optional[Dict[str, Any]] = None, transform_passes: Optional[ Union[Sequence[PassType], Dict[str, Sequence[PassType]]] ] = None, + use_portable_ops: bool = False, + timeout: int = 600, ): """ Args: model (torch.nn.Module): The model to test example_inputs (Tuple[torch.Tensor]): Example inputs to the model - compile_spec (List[CompileSpec]): The compile spec to use + compile_spec (ArmCompileSpec): The compile spec to use """ self.transform_passes = transform_passes @@ -320,35 +301,37 @@ def __init__( StageType.QUANTIZE, StageType.EXPORT, ] + self.original_module.requires_grad_(False) # Initial model needs to be set as a *possible* but not yet added Stage, therefore add None entry. - self.stages[StageType.INITIAL_MODEL] = None + self.stages[StageType.INITIAL_MODEL] = cast(Stage, None) self._run_stage(InitialModel(self.original_module)) + self.use_portable_ops = use_portable_ops + self.timeout = timeout + @no_type_check def quantize( self, - quantize_stage: Optional[tester.Quantize] = None, + quantize_stage: Optional[XnnpackQuantize] = None, ): + # Same stage type as parent but exposed via module alias if quantize_stage is None: - quantizer = None - if is_tosa(self.compile_spec): - tosa_spec = get_tosa_spec(self.compile_spec) - quantizer = TOSAQuantizer(tosa_spec) - elif is_ethosu(self.compile_spec): - quantizer = EthosUQuantizer(self.compile_spec) - elif is_vgf(self.compile_spec): - quantizer = VgfQuantizer(self.compile_spec) - quantize_stage = tester.Quantize( + quantizer = create_quantizer(self.compile_spec) + quantize_stage = Quantize( quantizer, get_symmetric_quantization_config(), ) return super().quantize(quantize_stage) + @no_type_check def to_edge( self, - to_edge_stage: Optional[tester.ToEdge] = None, + to_edge_stage: Optional[XnnpackToEdge] = None, + # Keep config keyword-only to avoid positional clashes with legacy calls. 
+ *, config: Optional[EdgeCompileConfig] = None, ): + # Allow optional config override beyond base signature if to_edge_stage is None: to_edge_stage = tester.ToEdge(config) else: @@ -357,49 +340,40 @@ def to_edge( return super().to_edge(to_edge_stage) - def partition(self, partition_stage: Optional[Partition] = None): + @no_type_check + def partition(self, partition_stage: Optional[XnnpackPartitionStage] = None): + # Accept Arm-specific partition stage subclass if partition_stage is None: - if is_tosa(self.compile_spec): - arm_partitioner = TOSAPartitioner(compile_spec=self.compile_spec) - elif is_ethosu(self.compile_spec): - arm_partitioner = EthosUPartitioner(compile_spec=self.compile_spec) - else: - raise ValueError("compile spec doesn't target any Arm Partitioner") + arm_partitioner = create_partitioner(self.compile_spec) partition_stage = Partition(arm_partitioner) return super().partition(partition_stage) + @no_type_check def to_edge_transform_and_lower( self, - to_edge_and_lower_stage: Optional[ToEdgeTransformAndLower] = None, + to_edge_and_lower_stage: Optional[XnnpackToEdgeTransformAndLower] = None, + generate_etrecord: bool = False, + # Force the optional tuning knobs to be keyword-only for readability/back-compat. + *, partitioners: Optional[List[Partitioner]] = None, edge_compile_config: Optional[EdgeCompileConfig] = None, - additional_checks: Optional[ - List[Union[DontPartition | DontPartitionModule | DontPartitionName]] - ] = None, + additional_checks: Optional[Sequence[OperatorSupportBase]] = None, transform_passes: Optional[ Union[Sequence[PassType], Dict[str, Sequence[PassType]]] ] = None, ): + # Arm flow exposes extra stage wiring knobs + if transform_passes is not None: + raise RuntimeError( + "transform passes are given to ArmTester at construction." 
+ ) + if to_edge_and_lower_stage is None: if partitioners is None: - arm_partitioner = None - if is_tosa(self.compile_spec): - arm_partitioner = TOSAPartitioner( - compile_spec=self.compile_spec, - additional_checks=additional_checks, - ) - elif is_ethosu(self.compile_spec): - arm_partitioner = EthosUPartitioner( - compile_spec=self.compile_spec, - additional_checks=additional_checks, - ) - elif is_vgf(self.compile_spec): - arm_partitioner = VgfPartitioner( - compile_spec=self.compile_spec, - additional_checks=additional_checks, - ) - else: - raise ValueError("compile spec doesn't target any Arm Partitioner") + operator_checks = ( + list(additional_checks) if additional_checks is not None else None + ) + arm_partitioner = create_partitioner(self.compile_spec, operator_checks) partitioners = [arm_partitioner] to_edge_and_lower_stage = ToEdgeTransformAndLower( partitioners, @@ -412,20 +386,34 @@ def to_edge_transform_and_lower( to_edge_and_lower_stage.partitioners = partitioners if edge_compile_config is not None: to_edge_and_lower_stage.edge_compile_conf = edge_compile_config - return super().to_edge_transform_and_lower(to_edge_and_lower_stage) + return super().to_edge_transform_and_lower( + to_edge_and_lower_stage, generate_etrecord=generate_etrecord + ) - def to_executorch(self, to_executorch_stage: Optional[ToExecutorch] | None = None): + @no_type_check + def to_executorch(self, to_executorch_stage: Optional[XnnpackToExecutorch] = None): + # Allow custom ExecuTorch stage subclass if to_executorch_stage is None: to_executorch_stage = ToExecutorch() return super().to_executorch(to_executorch_stage) + @no_type_check def serialize( - self, serialize_stage: Optional[Serialize] = None, timeout: int = 480 + self, + serialize_stage: Optional[Serialize] = None, + # Keep timeout keyword-only so positional usage matches the base class. + *, + timeout: int = 480, ): if serialize_stage is None: - serialize_stage = Serialize(self.compile_spec, timeout) + serialize_stage = Serialize( + compile_spec=self.compile_spec, + module=self.original_module, + use_portable_ops=self.use_portable_ops, + timeout=self.timeout, + ) assert ( - get_intermediate_path(self.compile_spec) is not None + self.compile_spec.get_intermediate_path() is not None ), "Can't dump serialized file when compile specs do not contain an artifact path." return super().serialize(serialize_stage) @@ -435,14 +423,17 @@ def is_quantized(self) -> bool: def run_method_and_compare_outputs( self, - inputs: Optional[Tuple[torch.Tensor]] = None, - stage: Optional[str] = None, - num_runs=1, - atol=1e-03, - rtol=1e-03, - qtol=0, - error_callbacks=None, - run_eager_mode=False, + stage: Optional[StageType] = None, + inputs: Optional[Tuple[torch.Tensor, ...]] = None, + num_runs: int = 1, + atol: float = 1e-03, + rtol: float = 1e-03, + qtol: int = 0, + statistics_callback: Callable[[ErrorStatistics], None] | None = None, + # Preserve positional compatibility while keeping new flags keyword-only. + *, + error_callbacks: Optional[Sequence[Callable[..., None]]] = None, + run_eager_mode: bool = False, ): """ Compares the run_artifact output of 'stage' with the output of a reference stage. @@ -459,6 +450,12 @@ def run_method_and_compare_outputs( The default is random data. 
""" + # backward-compatible ordering (accept inputs as the first positional argument) + if inputs is None and isinstance(stage, tuple): + if all(isinstance(arg, torch.Tensor) for arg in stage): + inputs = cast(Tuple[torch.Tensor, ...], stage) + stage = None + if not run_eager_mode: edge_stage = self.stages[StageType.TO_EDGE] if edge_stage is None: @@ -475,6 +472,8 @@ def run_method_and_compare_outputs( ), "To compare outputs in eager mode, the model must be at Export stage" stage = stage or self.cur + if stage is None: + raise RuntimeError("No stage has been executed yet.") test_stage = self.stages[stage] is_quantized = self.is_quantized() @@ -483,7 +482,8 @@ def run_method_and_compare_outputs( else: reference_stage = self.stages[StageType.INITIAL_MODEL] - exported_program = self.stages[StageType.EXPORT].artifact + exported_stage = self.stages[StageType.EXPORT] + exported_program = cast(ExportedProgram, exported_stage.artifact) output_node = exported_program.graph_module.graph.output_node() output_qparams = get_output_quantization_params(output_node) @@ -496,9 +496,15 @@ def run_method_and_compare_outputs( ) # Loop inputs and compare reference stage with the compared stage. - for run_iteration in range(num_runs): + number_of_runs = 1 if inputs is not None else num_runs + + for run_iteration in range(number_of_runs): reference_input = inputs if inputs else next(self.generate_random_inputs()) + # Avoid issues with inplace operators + test_input = copy.deepcopy(reference_input) + original_input = copy.deepcopy(reference_input) + input_shapes = [ generated_input.shape if hasattr(generated_input, "shape") else (1,) for generated_input in reference_input @@ -511,18 +517,17 @@ def run_method_and_compare_outputs( ) if run_eager_mode: # Run exported module directly - test_outputs, _ = pytree.tree_flatten( - self._calculate_reference_output( - exported_program.module(), reference_input - ) + eager_output, _ = self._calculate_reference_output( + exported_program, test_input ) + test_outputs, _ = pytree.tree_flatten(eager_output) else: # Run lowered model with target test_outputs, _ = pytree.tree_flatten( - test_stage.run_artifact(reference_input) + test_stage.run_artifact(test_input) ) - logger.info(f"\n Input: {reference_input}") + logger.info(f"\n Input: {original_input}") logger.info(f"\n Ref output: {reference_outputs}") logger.info(f"\nTest output: {test_outputs}") @@ -536,14 +541,203 @@ def run_method_and_compare_outputs( atol, rtol, qtol, - error_callbacks, + statistics_callback=statistics_callback, + error_callbacks=error_callbacks, ) return self - def get_graph(self, stage: str | None = None) -> Graph: + def _get_output_qspec_from_node( + self, node: torch.fx.Node + ) -> QuantizationSpec | None: + if Q_ANNOTATION_KEY not in node.meta: + return None + annotation = node.meta[Q_ANNOTATION_KEY] + # If annotation.output_qspec is a SharedQuantizationSpec, we need to find + # the actual QuantizationSpec from one of the inputs. + if isinstance(annotation.output_qspec, SharedQuantizationSpec): + # First try to find a non-shared qspec from the inputs. + annotation_qspec = [ + qspec + for qspec in annotation.input_qspec_map.values() + if not isinstance(qspec, SharedQuantizationSpec) + ] + # If none of the inputs have a non-shared qspec, we need to + # find the source node of the shared qspec. 
+ if len(annotation_qspec) == 0: + edge_or_node = annotation.output_qspec.edge_or_node + if isinstance(edge_or_node, tuple): + source_node = edge_or_node[0] + else: + source_node = edge_or_node + annotation_qspec = [source_node.meta[Q_ANNOTATION_KEY].output_qspec] + annotation_qspec = annotation_qspec[0] + else: + annotation_qspec = annotation.output_qspec + + return annotation_qspec + + def _get_input_qspecs_from_node( + self, node: torch.fx.Node + ) -> List[QuantizationSpec | None]: + if Q_ANNOTATION_KEY not in node.meta: + return [None] + annotation = node.meta[Q_ANNOTATION_KEY] + input_qspec_map = annotation.input_qspec_map + found_qspecs = [] + if len(input_qspec_map) == 0: + return [None] + for spec in input_qspec_map.values(): + # If spec is a SharedQuantizationSpec, we need to find + # the actual QuantizationSpec. + if isinstance(spec, SharedQuantizationSpec): + # First try to find a non-shared qspec from the inputs. + annotation_qspec = [ + qspec + for qspec in input_qspec_map.values() + if not isinstance(qspec, SharedQuantizationSpec) + ] + # If none of the inputs have a non-shared qspec, we need to + # find the source node of the shared qspec. + if len(annotation_qspec) == 0: + edge_or_node = annotation.output_qspec.edge_or_node + if isinstance(edge_or_node, tuple): + source_node = edge_or_node[0] + else: + source_node = edge_or_node + annotation_qspec = [source_node.meta[Q_ANNOTATION_KEY].output_qspec] + found_qspecs.append(annotation_qspec[0]) + else: + found_qspecs.append(spec) + + return found_qspecs + + def _check_input_qspecs(self, graph: Graph, input_qspecs): + if input_qspecs is None: + return + found_qspecs = [] + for node in graph.nodes: + if node.op != "placeholder": + continue + annotation_qspec = self._get_output_qspec_from_node(node) + found_qspecs.append(annotation_qspec) + found_qspecs_counter = Counter(found_qspecs) + for qspec in input_qspecs: + # check that each expected qspec is found + if qspec not in found_qspecs_counter: + raise AssertionError( + f"Expected to find input quantization annotation {qspec}, but it was not found. " + f"Found annotations: {found_qspecs_counter}" + ) + # check that number of occurrences of each qspec matches expected + if found_qspecs_counter[qspec] != input_qspecs[qspec]: + raise AssertionError( + f"Expected to find {input_qspecs[qspec]} instances of input quantization annotation {qspec}, but " + f"found {found_qspecs_counter[qspec]} instances." + ) + + def _check_output_qspecs(self, graph: Graph, output_qspecs): + if output_qspecs is None: + return + found_qspecs = [] + output_node = graph.output_node() + annotation_qspec = self._get_input_qspecs_from_node(output_node) + found_qspecs.extend(annotation_qspec) + found_qspecs_counter = Counter(found_qspecs) + for qspec in output_qspecs: + # check that each expected qspec is found + if qspec not in found_qspecs_counter: + raise AssertionError( + f"Expected to find output quantization annotation {qspec}, but it was not found. " + f"Found annotations: {found_qspecs_counter}" + ) + # check that number of occurrences of each qspec matches expected + if found_qspecs_counter[qspec] != output_qspecs[qspec]: + raise AssertionError( + f"Expected to find {output_qspecs[qspec]} instances of output quantization annotation {qspec}, but " + f"found {found_qspecs_counter[qspec]} instances." 
+ ) + + def _check_qspecs(self, graph: Graph, quantization_annotations): + if quantization_annotations is None: + return self + + quantization_annotations_found: List[Tuple[str, QuantizationSpec | None]] = [] + for node in graph.nodes: + if node.op != "call_function": + continue + quantization_annotations_found.append( + (str(node.target), self._get_output_qspec_from_node(node)) + ) + + # Counter: (target, qspec) -> count + quantization_annotations_found_counter = Counter(quantization_annotations_found) + # Convert counter to Dict[target, Dict[qspec, count]] + quantization_annotations_found_dict: Dict[ + str, Dict[QuantizationSpec | None, int] + ] = defaultdict(dict) + for (target, qspec), count in quantization_annotations_found_counter.items(): + quantization_annotations_found_dict[target][qspec] = count + + for target, qspecs in quantization_annotations.items(): + # check if target is in found annotations + if target not in quantization_annotations_found_dict: + raise AssertionError( + f"Expected to find quantization annotation for operator {target}, but it was not found." + ) + for qspec in qspecs: + # check if qspec is in found annotations for target + if qspec not in quantization_annotations_found_dict[target]: + raise AssertionError( + f"Expected to find quantization annotation {qspec} for operator {target}, but it was not found. " + f"Found annotations: {quantization_annotations_found_dict[target]}" + ) + # check that number of occurrences of each qspec matches expected + if quantization_annotations_found_dict[target][qspec] != qspecs[qspec]: + raise AssertionError( + f"Expected to find {qspecs[qspec]} instances of quantization annotation {qspec} for operator " + f"{target}, but found {quantization_annotations_found_dict[target][qspec]} instances." + ) + + def check_quantization_annotation( + self, + quantization_annotations: Optional[ + Dict[str, Dict[QuantizationSpec | None, int]] + ] = None, + input_qspecs: Optional[Dict[QuantizationSpec | None, int]] = None, + output_qspecs: Optional[Dict[QuantizationSpec | None, int]] = None, + ): + """ + Check the quantization annotations in the graph of a quantized model. + + Args: + quantization_annotations: A dictionary mapping operator names to a dictionary of + QuantizationSpecs and their expected counts. + If None, the check is skipped. + input_qspecs: A dictionary of expected input QuantizationSpecs and their counts. + If None, the check is skipped. + output_qspecs: A dictionary of expected output QuantizationSpecs and their counts. + If None, the check is skipped. + + Returns self for daisy-chaining. + """ + if not self.is_quantized(): + raise RuntimeError( + f"{self.check_quantization_annotation.__name__} should be called after quantization stage." 
+ ) + + graph = self.get_graph(StageType.QUANTIZE) + + self._check_input_qspecs(graph, input_qspecs) + self._check_output_qspecs(graph, output_qspecs) + self._check_qspecs(graph, quantization_annotations) + return self + + def get_graph(self, stage: StageType | None = None) -> Graph: if stage is None: stage = self.cur + if stage is None: + raise RuntimeError("No stage has been executed yet.") artifact = self.get_artifact(stage) if ( self.cur == StageType.TO_EDGE @@ -561,7 +755,10 @@ def get_graph(self, stage: str | None = None) -> Graph: return graph def dump_operator_distribution( - self, path_to_dump: Optional[str] = None, print_table: bool = True + self, + path_to_dump: Optional[str] = None, + print_table: bool = True, + include_dtypes: bool = True, ): """Dump the distribution of operators in the current stage. In the partition stage, additional information is included such as the number of @@ -583,25 +780,41 @@ def dump_operator_distribution( and print_table ): graph_module = self.get_artifact().exported_program().graph_module + delegation_info = get_delegation_info(graph_module) if print_table: - delegation_info = get_delegation_info(graph_module) op_dist = delegation_info.get_operator_delegation_dataframe() + op_dist = _get_tosa_operator_distribution(graph_module, include_dtypes) + if include_dtypes: + op_dist = { + "Operator": [op_type[0] for op_type, _ in op_dist], + "Dtype": [op_type[1] for op_type, _ in op_dist], + "Count": [count for _, count in op_dist], + } else: - op_dist = dict(_get_operator_distribution(graph_module.graph)) - to_print += _format_dict(op_dist, print_table) - to_print += "\n" + _get_tosa_operator_distribution( - graph_module, print_table - ) - to_print += "\n" - to_print += delegation_info.get_summary() + op_dist = { + "Operator": [op for op, _ in op_dist], + "Count": [count for _, count in op_dist], + } + to_print += "TOSA operators:\n" + _format_dict(dict(op_dist), print_table) + to_print += "\n" + delegation_info.get_summary() else: graph = self.get_graph(self.cur) - op_dist = dict(_get_operator_distribution(graph)) + if include_dtypes: + op_dist = _get_operator_dtype_distribution(graph) + else: + op_dist = _get_operator_distribution(graph) if print_table: - op_dist = { - "Operator": list(op_dist), - "Count": [op_dist[key] for key in op_dist], - } + if include_dtypes: + op_dist = { + "Operator": [op_dtype[0] for op_dtype, _ in op_dist], + "Dtype": [op_dtype[1] for op_dtype, _ in op_dist], + "Count": [count for _, count in op_dist], + } + else: + op_dist = { + "Operator": [op for op, _ in op_dist], + "Count": [count for _, count in op_dist], + } to_print += _format_dict(op_dist, print_table) + "\n" _dump_str(to_print, path_to_dump) @@ -621,13 +834,14 @@ def dump_dtype_distribution( to_print = f"{line} {self.cur} Placeholder Dtype Distribution {line}\n" graph = self.get_graph(self.cur) - tosa_spec = get_tosa_spec(self.compile_spec) + tosa_spec = self.compile_spec.tosa_spec dtype_dist_placeholders, dtype_dirst_tensors = _get_dtype_distribution( graph, tosa_spec ) all_dtypes = set(dtype_dist_placeholders.keys()) | set( dtype_dirst_tensors.keys() ) + dtype_dist: dict[str, Any] if print_table: dtype_dist = { "Dtype": all_dtypes, @@ -645,13 +859,14 @@ def dump_dtype_distribution( ], } else: - dtype_dist = dict(dtype_dist_placeholders + dtype_dirst_tensors) + combined_counts = dtype_dist_placeholders + dtype_dirst_tensors + dtype_dist = {key: combined_counts[key] for key in combined_counts} to_print += _format_dict(dtype_dist, print_table) + "\n" 
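[Editor's aside, illustrative only, not part of the patch] The dictionaries consumed by the new check_quantization_annotation helper above map a torchao QuantizationSpec, or None for "not annotated", to the number of times it is expected to appear. A hypothetical shape, with the spec construction elided:

    # Sketch of argument shapes only; int8_spec stands in for whatever
    # torchao QuantizationSpec the quantizer is expected to attach.
    int8_spec = None  # None is also a valid key and means "unannotated"
    quantization_annotations = {"aten.add.Tensor": {int8_spec: 1}}  # per-operator output specs
    input_qspecs = {int8_spec: 2}   # counted over graph placeholders
    output_qspecs = {int8_spec: 1}  # counted over the output node's inputs
    # tester.check_quantization_annotation(
    #     quantization_annotations, input_qspecs, output_qspecs
    # )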
_dump_str(to_print, path_to_dump) return self def run_transform_for_annotation_pipeline( - self, stage: str | None = None + self, stage: StageType | None = None ) -> torch.fx.GraphModule: """Run transform_for_annotation_pipeline on exported program to ensure passes do not break the initial model before quantization. @@ -665,12 +880,14 @@ def run_transform_for_annotation_pipeline( if stage is None: stage = self.cur + if stage is None: + raise RuntimeError("No stage has been executed yet.") # We need to clone the artifact in order to ensure that the state_dict is preserved after passes are run. artifact = self.get_artifact(stage) if self.cur == StageType.EXPORT: - new_gm = ArmPassManager(get_tosa_spec(self.compile_spec)).transform_for_annotation_pipeline( # type: ignore[arg-type] - graph_module=artifact.graph_module - ) + new_gm = ArmPassManager( + self.compile_spec + ).transform_for_annotation_pipeline(graph_module=artifact.graph_module) else: raise RuntimeError("Can only run passes on Export stage.") _copy_module(artifact.graph_module, new_gm) @@ -678,8 +895,8 @@ def run_transform_for_annotation_pipeline( @staticmethod def _calculate_reference_output( - module: Union[torch.fx.GraphModule, torch.nn.Module], inputs - ) -> torch.Tensor: + program: ExportedProgram, inputs: Tuple[Any, ...] + ) -> Tuple[torch.Tensor, Optional[float]]: """ Note: I'd prefer to use the base class method here, but since it use the exported program, I can't. The partitioner stage clears the state_dict @@ -687,8 +904,10 @@ def _calculate_reference_output( module. """ - return module.forward(*inputs) + module = program.module() + return module.forward(*inputs), None + @no_type_check def _compare_outputs( self, reference_output, @@ -697,36 +916,89 @@ def _compare_outputs( atol=1e-03, rtol=1e-03, qtol=0, - error_callbacks=None, + statistics_callback: Callable[[ErrorStatistics], None] | None = None, + # Extra debugging hooks are keyword-only to keep the signature stable. 
+ *, + error_callbacks: Optional[Sequence[Callable[..., None]]] = None, ): + # Accept extra error callback hook for debugging try: super()._compare_outputs( - reference_output, stage_output, quantization_scale, atol, rtol, qtol + reference_output, + stage_output, + quantization_scale, + atol, + rtol, + qtol, + statistics_callback=statistics_callback, ) except AssertionError as e: - if error_callbacks is None: - error_callbacks = [print_error_diffs, dump_error_output] - for callback in error_callbacks: + callbacks = ( + list(error_callbacks) + if error_callbacks is not None + else [print_error_diffs, dump_error_output] + ) + for callback in callbacks: callback( self, stage_output, reference_output, - quantization_scale=None, + quantization_scale=quantization_scale, atol=1e-03, rtol=1e-03, qtol=0, ) raise e + def __del__(self): + intermediate_path = self.compile_spec.get_intermediate_path() + if not intermediate_path: + return + if len(tempdir := tempfile.gettempdir()) > 0: + if intermediate_path.startswith(tempdir): + shutil.rmtree(intermediate_path, ignore_errors=True) + + def check_dtype_count(self, dtype_dict: Dict[str, Dict[str, int]]): + if self.cur in ( + StageType.PARTITION, + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + ): + graph_module = self.get_artifact().exported_program().graph_module + op_dist = _get_tosa_operator_distribution(graph_module, include_dtypes=True) + op_dist_dict: Dict[str, Dict[str, int]] = defaultdict(dict) + for op_dtype, count in op_dist: + if isinstance(op_dtype, str): + raise ValueError( + f"Expected {_get_tosa_operator_distribution.__name__} to return " + "Tuple[Tuple[str, str], int]." + ) + else: + op, dtype = op_dtype + + op_dist_dict[op].update({dtype: count}) + for op in dtype_dict.keys(): + if op not in op_dist_dict: + raise RuntimeError(f"Could not find op {op}.") + for dtype, count in dtype_dict[op].items(): + dtype_count = op_dist_dict[op].setdefault(dtype, 0) + if dtype_count != count: + raise RuntimeError( + f"Expected {count} occurrences of {op=}, {dtype=} but found {dtype_count}." + ) + + else: + + raise NotImplementedError(f"Cannot check dtypes for stage {self.cur}") + + def _get_dtype_distribution( graph: Graph, tosa_spec: TosaSpecification -) -> tuple[dict, dict]: +) -> tuple[Counter[str], Counter[str]]: """Counts the occurences of placeholder and call_function dtypes in a graph. The result is a tuple of Counters (placeholder_distribution, call_function_distribution) """ - placeholder_dtypes = [] - call_function_dtypes = [] + placeholder_dtypes: list[str] = [] + call_function_dtypes: list[str] = [] for node in graph.nodes: if node.op == "placeholder": placeholder_dtypes.append(str(node.meta["val"].dtype)) @@ -737,17 +1009,39 @@ return Counter(placeholder_dtypes), Counter(call_function_dtypes) -def _get_operator_distribution(graph: Graph) -> dict[str, int]: +def _get_operator_distribution(graph: Graph) -> List[Tuple[str, int]]: """Counts the occurences of operator names in a graph. - The result is a dict {'operator name':'number of nodes'} + The result is a sorted list [('operator name':'number of nodes')] """ - return Counter( - [str(node.target) for node in list(graph.nodes) if node.op == "call_function"] + return sorted( + Counter( + [ + str(node.target) + for node in list(graph.nodes) + if node.op == "call_function" + ] + ).items() ) +def _get_operator_dtype_distribution(graph: Graph) -> List[Tuple[Tuple[str, str], int]]: + """Counts the occurrences of operator names and dtype pairs in a graph.
+ The result is a sorted list[(('operator name','dtype'),'number of nodes')] + """ + target_dtype_pairs = [] + for node in graph.nodes: + if node.op != "call_function": + continue + if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor): + dtype = str(node.meta["val"].dtype) + else: + dtype = "UNKNOWN" + target_dtype_pairs.append((str(node.target), dtype)) + return sorted(Counter(target_dtype_pairs).items()) + + def _format_export_graph_signature(signature: ExportGraphSignature) -> str: - def specs_dict(specs: list[InputSpec | OutputSpec], title: str): + def specs_dict(specs: Sequence[InputSpec | OutputSpec], title: str): _dict: dict[str, list] = {title: [], "arg": [], "kind": [], "target": []} for i, spec in enumerate(specs): _dict[title].append(i) @@ -763,40 +1057,58 @@ def specs_dict(specs: list[InputSpec | OutputSpec], title: str): def _get_tosa_operator_distribution( - graph_module: torch.fx.GraphModule, print_table=False -) -> str: + graph_module: torch.fx.GraphModule, include_dtypes=False +) -> list[Tuple[str, int]] | list[Tuple[Tuple[str, str], int]]: """Counts the occurences of operator names of all lowered modules containing a TOSA flatbuffer. The result is a string with the operator distribution or an error message. """ - op_list = [] id = 0 + unknown_dtype_str = "UNKNOWN" + op_list = [] while lowered_module := getattr(graph_module, f"lowered_module_{id}", None): - for spec in lowered_module.compile_specs: - if spec.key != "output_format": - continue - if spec.value == b"tosa": - tosa_fb = lowered_module.processed_bytes - tosa_json = dbg_tosa_fb_to_json(tosa_fb) - for region in tosa_json["regions"]: - for block in region["blocks"]: - op_list.extend( - [operator["op"] for operator in block["operators"]] - ) - break - elif spec.value == b"vela": - return "Can not get operator distribution for Vela command stream." - else: - return f"Unknown output format '{spec.value}'." + compile_spec = parse_compile_spec(lowered_module.compile_specs) + if isinstance(compile_spec, TosaCompileSpec): + tosa_fb = lowered_module.processed_bytes + tosa_json = dbg_tosa_fb_to_json(tosa_fb) + for region in tosa_json["regions"]: + for block in region["blocks"]: + for operator in block["operators"]: + op = operator["op"] + if include_dtypes: + outputs = operator.get("outputs", []) + if outputs == []: + op_list.append((op, unknown_dtype_str)) + continue + tensor_block = block.get("tensors", {}) + tensors_with_matching_name = [ + t for t in tensor_block if t["name"] == outputs[0] + ] + dtype = ( + tensors_with_matching_name[0]["type"] + if len(tensors_with_matching_name) > 0 + else unknown_dtype_str + ) + op_list.append((op, dtype)) + else: + op_list.append(op) + + elif isinstance(compile_spec, EthosUCompileSpec): + raise NotImplementedError( + "Can not get operator distribution for Vela command stream." + ) + elif isinstance(compile_spec, VgfCompileSpec): + raise NotImplementedError("Can not get operator distribution for VGF.") + else: + raise NotImplementedError( + f"Unknown output format '{compile_spec.get_output_format()}'." + ) id += 1 if id == 0: - return "No delegate with name 'lowered_module_0 found in graph module." - op_dist = dict(Counter(op_list)) - op_dist = { - "Operator": list(op_dist.keys()), - "Count": [item[1] for item in op_dist.items()], - } - return "TOSA operators:\n" + _format_dict(dict(op_dist), print_table) + raise ValueError( + "No delegate with name 'lowered_module_0 found in graph module." 
+ ) + return sorted(Counter(op_list).items()) def _dump_str(to_print: str, path_to_dump: Optional[str] = None): @@ -804,7 +1116,7 @@ def _dump_str(to_print: str, path_to_dump: Optional[str] = None): with open(path_to_dump, "a") as fp: fp.write(to_print) else: - logger.info(to_print) + print(to_print) def _format_dict(to_print: dict, print_table: bool = True) -> str: diff --git a/backends/arm/test/tester/quantize.py b/backends/arm/test/tester/quantize.py new file mode 100644 index 00000000000..18ecd401efe --- /dev/null +++ b/backends/arm/test/tester/quantize.py @@ -0,0 +1,43 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple + +import torch +from executorch.backends.arm.quantizer import TOSAQuantizer +from executorch.backends.test.harness.stages.quantize import Quantize + +from executorch.backends.transforms.duplicate_dynamic_quant_chain import ( + DuplicateDynamicQuantChainPass, +) + +from torch.export import export + + +class ArmQuantize(Quantize): + + def run( + self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]] + ) -> None: + assert inputs is not None + if self.is_qat: + artifact.train() + captured_graph = export(artifact, inputs, strict=True).module() + + if not isinstance(self.quantizer, TOSAQuantizer): + raise ValueError("ArmQuantizer can only run with TOSAQuantizer.") + + if self.calibration_samples is not None: + converted = self.quantizer.quantize_with_submodules( + captured_graph, self.calibration_samples, bool(self.is_qat) # type: ignore + ) + else: + converted = self.quantizer.quantize_with_submodules( + captured_graph, [inputs], bool(self.is_qat) + ) + + DuplicateDynamicQuantChainPass()(converted) + + self.converted_graph = converted diff --git a/backends/arm/test/tester/serialize.py b/backends/arm/test/tester/serialize.py new file mode 100644 index 00000000000..33e57cc721d --- /dev/null +++ b/backends/arm/test/tester/serialize.py @@ -0,0 +1,78 @@ +# Copyright 2024-2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +from typing import Optional + +import executorch.backends.xnnpack.test.tester.tester as tester + +import torch.fx + +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec + +from executorch.backends.arm.test.runner_utils import ( + get_elf_path, + get_target_board, + run_target, +) + +from executorch.exir import ExecutorchProgramManager +from torch.utils._pytree import tree_flatten + + +logger = logging.getLogger(__name__) + + +class Serialize(tester.Serialize): + def __init__( + self, + compile_spec: ArmCompileSpec, + module: Optional[torch.nn.Module], + use_portable_ops: bool = False, + timeout: int = 120, + ): + """ + Args: + compile_spec: CompileSpecs to be used for serialization. + module: Original Module to be used for serialization. Optional - can be used for reference output generation. + portable_ops: If True tests with compiled in portable ops, default is to test without this to get error if not fully delegated + timeout: Timeout for fvp. Default is 120 seconds. 
+ """ + super().__init__() + self.module = module + self.timeout = timeout + self.executorch_program_manager: ExecutorchProgramManager | None + self.compile_spec = compile_spec + self.use_portable_ops = use_portable_ops + + def run(self, artifact: ExecutorchProgramManager, inputs=None) -> None: + super().run(artifact, inputs) + # Keep the entire ExecutorchProgramManager for execution. + self.executorch_program_manager = artifact + + def run_artifact(self, inputs): + if self.executorch_program_manager is None: + raise RuntimeError( + "Tried running artifact from Serialize stage without running the stage." + ) + inputs_flattened, _ = tree_flatten(inputs) + intermediate_path = self.compile_spec.get_intermediate_path() + target_board = get_target_board(self.compile_spec) + elf_path = get_elf_path(target_board, self.use_portable_ops) + + if not os.path.exists(elf_path): + raise FileNotFoundError( + f"Did not find build arm_executor_runner in path {elf_path}, run setup_testing.sh?" + ) + + return run_target( + self.executorch_program_manager, + inputs_flattened, + intermediate_path, + target_board, + elf_path, + self.timeout, + ) diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py index 102ccd209e9..86cd130d2b4 100644 --- a/backends/arm/test/tester/test_pipeline.py +++ b/backends/arm/test/tester/test_pipeline.py @@ -21,31 +21,73 @@ ) import torch +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec -from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder from executorch.backends.arm.quantizer import ( EthosUQuantizer, + get_symmetric_a16w8_quantization_config, get_symmetric_quantization_config, TOSAQuantizer, VgfQuantizer, ) from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester, RunPasses + +from executorch.backends.arm.test.tester.quantize import ArmQuantize as Quantize from executorch.backends.arm.tosa.specification import ( TosaLoweringContext, TosaSpecification, ) -from executorch.backends.xnnpack.test.tester.tester import Quantize -from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.backends.arm.util._factory import create_quantizer from executorch.exir.pass_base import ExportPass from torch._export.pass_base import PassType +from torchao.quantization.pt2e.quantizer import QuantizationSpec logger = logging.getLogger(__name__) -T = TypeVar("T") +T = TypeVar("T", bound=Tuple[Any, ...]) """ Generic type used for test data in the pipeline. 
Depends on which type the operator expects.""" +def _require_tosa_version() -> str: + version = conftest.get_option("tosa_version") + if not isinstance(version, str): + raise TypeError(f"TOSA version option must be a string, got {type(version)}.") + return version + + +def _has_quantizable_inputs(test_data: T) -> bool: + for data in test_data: + if isinstance(data, torch.Tensor) and data.is_floating_point(): + return True + return False + + +class PipelineStage: + """Container for a pipeline stage (callable plus arguments).""" + + def __init__(self, func: Callable, id: str, *args, **kwargs): + self.id: str = id + self.func: Callable = func + self.args = args + self.kwargs = kwargs + self.is_called = False + + def __call__(self): + if not self.is_called: + self.func(*self.args, **self.kwargs) + else: + raise RuntimeError(f"{self.id} called twice.") + self.is_called = True + + def update(self, *args, **kwargs): + if not self.is_called: + self.args = args + self.kwargs = kwargs + else: + raise RuntimeError(f"{self.id} args updated after being called.") + + class BasePipelineMaker(Generic[T]): """ The BasePiplineMaker defines a list of stages to be applied to a torch.nn.module for lowering it @@ -66,53 +108,27 @@ class BasePipelineMaker(Generic[T]): tester.to_edge().check(exir_ops).partition() """ - class PipelineStage: - """ - Helper class to store a pipeline stage as a function call + args for calling later on. - - Attributes: - id: name of the function to be called, used for refering to stages in the pipeline. - func: handle to the function to be called. - args: args used when called. - kwargs: kwargs used when called. - is_called: keeps track of if the function has been called. - """ - - def __init__(self, func: Callable, id: str, *args, **kwargs): - self.id: str = id - self.func: Callable = func - self.args = args - self.kwargs = kwargs - self.is_called = False - - def __call__(self): - if not self.is_called: - self.func(*self.args, **self.kwargs) - else: - raise RuntimeError(f"{self.id} called twice.") - self.is_called = True - - def update(self, *args, **kwargs): - if not self.is_called: - self.args = args - self.kwargs = kwargs - else: - raise RuntimeError(f"{self.id} args updated after being called.") + @staticmethod + def _normalize_ops(ops: str | Sequence[str] | None) -> list[str]: + if ops is None: + return [] + if isinstance(ops, str): + return [ops] + return list(ops) def __init__( self, module: torch.nn.Module, test_data: T, - aten_ops: str | List[str], - compile_spec: List[CompileSpec], - exir_ops: Optional[str | List[str]] = None, + aten_ops: str | Sequence[str] | None, + compile_spec: ArmCompileSpec, + exir_ops: str | Sequence[str] | None = None, use_to_edge_transform_and_lower: bool = True, dynamic_shapes: Optional[Tuple[Any]] = None, transform_passes: Optional[ Union[Sequence[PassType], Dict[str, Sequence[PassType]]] ] = None, ): - self.tester = ArmTester( module, example_inputs=test_data, @@ -121,15 +137,10 @@ def __init__( transform_passes=transform_passes, ) - self.aten_ops = aten_ops if isinstance(aten_ops, list) else [aten_ops] - if exir_ops is None: - self.exir_ops = [] - elif isinstance(exir_ops, list): - self.exir_ops = exir_ops - else: - self.exir_ops = [exir_ops] + self.aten_ops = self._normalize_ops(aten_ops) + self.exir_ops = self._normalize_ops(exir_ops) self.test_data = test_data - self._stages = [] + self._stages: list[PipelineStage] = [] self.add_stage(self.tester.export) self.add_stage(self.tester.check, self.aten_ops, suffix="aten") @@ -204,13 +215,30 @@ 
def add_stage(self, func: Callable, *args, **kwargs): if stage_id in id_list: raise ValueError("Suffix must be unique in pipeline") - pipeline_stage = self.PipelineStage(func, stage_id, *args, **kwargs) + pipeline_stage = PipelineStage(func, stage_id, *args, **kwargs) self._stages.insert(pos, pipeline_stage) logger.debug(f"Added stage {stage_id} to {type(self).__name__}") return self + @property + def quantizer(self) -> TOSAQuantizer: + quantize_pipeline_stage = self._stages[self.find_pos("quantize")] + quantize_stage = quantize_pipeline_stage.args[0] + if isinstance(quantize_stage, Quantize): + quantizer = quantize_stage.quantizer + if isinstance(quantizer, TOSAQuantizer): + return quantizer + else: + raise RuntimeError( + f"Quantizer in pipeline was {type(quantizer).__name__}, not TOSAQuantizer as expected." + ) + else: + raise RuntimeError( + f"First argument of quantize stage was {type(quantize_stage).__name__}, not Quantize as expected." + ) + def pop_stage(self, identifier: int | str): """Removes and returns the stage at postion pos""" if isinstance(identifier, int): @@ -218,6 +246,8 @@ def pop_stage(self, identifier: int | str): elif isinstance(identifier, str): pos = self.find_pos(identifier) stage = self._stages.pop(pos) + else: + raise TypeError("identifier must be an int or str") logger.debug(f"Removed stage {stage.id} from {type(self).__name__}") @@ -245,19 +275,24 @@ def add_stage_after(self, stage_id: str, func: Callable, *args, **kwargs): self.add_stage(func, *args, **kwargs) return self - def dump_artifact(self, stage_id: str, suffix: str = None): + def dump_artifact(self, stage_id: str, suffix: str | None = None): """Adds a dump_artifact stage after the given stage id.""" self.add_stage_after(stage_id, self.tester.dump_artifact, suffix=suffix) return self - def dump_operator_distribution(self, stage_id: str, suffix: str = None): + def dump_operator_distribution( + self, stage_id: str, suffix: str | None = None, include_dtypes: bool = False + ): """Adds a dump_operator_distribution stage after the given stage id.""" self.add_stage_after( - stage_id, self.tester.dump_operator_distribution, suffix=suffix + stage_id, + self.tester.dump_operator_distribution, + suffix=suffix, + include_dtypes=include_dtypes, ) return self - def visualize(self, stage_id: str, suffix: str = None): + def visualize(self, stage_id: str, suffix: str | None = None): """Adds a dump_operator_distribution stage after the given stage id.""" self.add_stage_after(stage_id, self.tester.visualize, suffix=suffix) return self @@ -283,14 +318,13 @@ def run(self): class TOSAPipelineMaker(BasePipelineMaker, Generic[T]): - @staticmethod def is_tosa_ref_model_available(): """Checks if the TOSA reference model is available.""" # Not all deployments of ET have the TOSA reference model available. # Make sure we don't try to use it if it's not available. 
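+        # Assumed usage (inferred from the comments above, not stated elsewhere in
+        # this patch): pipelines consult this helper and skip the stages that run on
+        # the TOSA reference model when it returns False.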
try: - import tosa_reference_model + import tosa_reference_model # type: ignore[import-not-found, import-untyped] # Check if the module has content return bool(dir(tosa_reference_model)) @@ -339,22 +373,23 @@ def __init__( symmetric_io_quantization: bool = False, per_channel_quantization: bool = True, use_to_edge_transform_and_lower: bool = True, - custom_path: str = None, - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, + custom_path: str | None = None, + tosa_debug_mode: Optional[ArmCompileSpec.DebugMode] = None, atol: float = 1e-03, rtol: float = 1e-03, qtol: int = 1, dynamic_shapes: Optional[Tuple[Any]] = None, tosa_extensions: Optional[List[str]] = None, + epsilon: float = 2**-12, ): if tosa_extensions is None: tosa_extensions = [] - tosa_profiles = { + tosa_profiles: dict[str, TosaSpecification] = { "1.0": TosaSpecification.create_from_string( "TOSA-1.0+INT" + "".join([f"+{ext}" for ext in tosa_extensions]) ), } - tosa_version = conftest.get_option("tosa_version") + tosa_version = _require_tosa_version() compile_spec = common.get_tosa_compile_spec( tosa_profiles[tosa_version], @@ -363,9 +398,15 @@ def __init__( ) quantizer = TOSAQuantizer(tosa_profiles[tosa_version]) - quantization_config = get_symmetric_quantization_config( - is_per_channel=per_channel_quantization - ) + # choose 16A8W quantization config when int16 extension is requested + if "int16" in tosa_extensions: + quantization_config = get_symmetric_a16w8_quantization_config( + is_per_channel=per_channel_quantization, epsilon=epsilon + ) + else: + quantization_config = get_symmetric_quantization_config( + is_per_channel=per_channel_quantization + ) if symmetric_io_quantization: quantizer.set_io(quantization_config) quant_stage = Quantize(quantizer, quantization_config) @@ -381,30 +422,32 @@ def __init__( ) self.add_stage(self.tester.quantize, quant_stage, pos=0) - self.add_stage_after( - "quantize", - self.tester.check, - [ - "torch.ops.quantized_decomposed.dequantize_per_tensor.default", - "torch.ops.quantized_decomposed.quantize_per_tensor.default", - ], - suffix="quant_nodes", - ) - remove_quant_nodes_stage = ( "to_edge_transform_and_lower" if use_to_edge_transform_and_lower else "partition" ) - self.add_stage_after( - remove_quant_nodes_stage, - self.tester.check_not, - [ - "torch.ops.quantized_decomposed.dequantize_per_tensor.default", - "torch.ops.quantized_decomposed.quantize_per_tensor.default", - ], - suffix="quant_nodes", - ) + + if _has_quantizable_inputs(test_data): + # only add stages if we have quantizable input + self.add_stage_after( + "quantize", + self.tester.check, + [ + "torch.ops.quantized_decomposed.dequantize_per_tensor.default", + "torch.ops.quantized_decomposed.quantize_per_tensor.default", + ], + suffix="quant_nodes", + ) + self.add_stage_after( + remove_quant_nodes_stage, + self.tester.check_not, + [ + "torch.ops.quantized_decomposed.dequantize_per_tensor.default", + "torch.ops.quantized_decomposed.quantize_per_tensor.default", + ], + suffix="quant_nodes", + ) if run_on_tosa_ref_model: self.add_stage( @@ -444,8 +487,8 @@ def __init__( exir_op: Optional[str | List[str]] = None, run_on_tosa_ref_model: bool = True, use_to_edge_transform_and_lower: bool = True, - custom_path: str = None, - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, + custom_path: str | None = None, + tosa_debug_mode: Optional[ArmCompileSpec.DebugMode] = None, atol: float = 1e-03, rtol: float = 1e-03, qtol: int = 0, @@ -457,12 +500,12 @@ def __init__( ): if tosa_extensions is None: 
tosa_extensions = [] - tosa_profiles = { + tosa_profiles: dict[str, TosaSpecification] = { "1.0": TosaSpecification.create_from_string( "TOSA-1.0+FP" + "".join([f"+{ext}" for ext in tosa_extensions]) ), } - tosa_version = conftest.get_option("tosa_version") + tosa_version = _require_tosa_version() compile_spec = common.get_tosa_compile_spec( tosa_profiles[tosa_version], @@ -524,21 +567,29 @@ def __init__( run_on_fvp: bool = True, symmetric_io_quantization: bool = False, per_channel_quantization: bool = True, + a16w8_quantization: bool = False, use_to_edge_transform_and_lower: bool = True, - custom_path: str = None, - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, + custom_path: str | None = None, + tosa_debug_mode: Optional[ArmCompileSpec.DebugMode] = None, atol: float = 1e-03, rtol: float = 1e-03, qtol: int = 1, + epsilon: float = 2**-12, ): compile_spec = common.get_u55_compile_spec( custom_path=custom_path, tosa_debug_mode=tosa_debug_mode, ) quantizer = EthosUQuantizer(compile_spec) - quantization_config = get_symmetric_quantization_config( - is_per_channel=per_channel_quantization - ) + # choose int8 or int16 activation quantization + if a16w8_quantization: + quantization_config = get_symmetric_a16w8_quantization_config( + is_per_channel=per_channel_quantization, epsilon=epsilon + ) + else: + quantization_config = get_symmetric_quantization_config( + is_per_channel=per_channel_quantization + ) if symmetric_io_quantization: quantizer.set_io(quantization_config) quant_stage = Quantize(quantizer, quantization_config) @@ -554,30 +605,32 @@ def __init__( self.add_stage(self.tester.quantize, quant_stage, pos=0) - self.add_stage_after( - "quantize", - self.tester.check, - [ - "torch.ops.quantized_decomposed.dequantize_per_tensor.default", - "torch.ops.quantized_decomposed.quantize_per_tensor.default", - ], - suffix="quant_nodes", - ) - remove_quant_nodes_stage = ( "to_edge_transform_and_lower" if use_to_edge_transform_and_lower else "partition" ) - self.add_stage_after( - remove_quant_nodes_stage, - self.tester.check_not, - [ - "torch.ops.quantized_decomposed.dequantize_per_tensor.default", - "torch.ops.quantized_decomposed.quantize_per_tensor.default", - ], - suffix="quant_nodes", - ) + + if _has_quantizable_inputs(test_data): + # only add stages if we have quantizable input + self.add_stage_after( + "quantize", + self.tester.check, + [ + "torch.ops.quantized_decomposed.dequantize_per_tensor.default", + "torch.ops.quantized_decomposed.quantize_per_tensor.default", + ], + suffix="quant_nodes", + ) + self.add_stage_after( + remove_quant_nodes_stage, + self.tester.check_not, + [ + "torch.ops.quantized_decomposed.dequantize_per_tensor.default", + "torch.ops.quantized_decomposed.quantize_per_tensor.default", + ], + suffix="quant_nodes", + ) if run_on_fvp: self.add_stage(self.tester.serialize) @@ -611,25 +664,33 @@ def __init__( module: torch.nn.Module, test_data: T, aten_ops: str | List[str], - exir_ops: str | List[str] = None, + exir_ops: str | List[str] | None = None, run_on_fvp: bool = True, symmetric_io_quantization: bool = False, per_channel_quantization: bool = True, + a16w8_quantization: bool = False, use_to_edge_transform_and_lower: bool = True, - custom_path: str = None, - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, + custom_path: str | None = None, + tosa_debug_mode: Optional[ArmCompileSpec.DebugMode] = None, atol: float = 1e-03, rtol: float = 1e-03, qtol: int = 1, + epsilon: float = 2**-12, ): compile_spec = common.get_u85_compile_spec( 
custom_path=custom_path, tosa_debug_mode=tosa_debug_mode, ) quantizer = EthosUQuantizer(compile_spec) - quantization_config = get_symmetric_quantization_config( - is_per_channel=per_channel_quantization - ) + # choose int8 or int16 activation quantization + if a16w8_quantization: + quantization_config = get_symmetric_a16w8_quantization_config( + is_per_channel=per_channel_quantization, epsilon=epsilon + ) + else: + quantization_config = get_symmetric_quantization_config( + is_per_channel=per_channel_quantization + ) if symmetric_io_quantization: quantizer.set_io(quantization_config) quant_stage = Quantize(quantizer, quantization_config) @@ -645,30 +706,32 @@ def __init__( self.add_stage(self.tester.quantize, quant_stage, pos=0) - self.add_stage_after( - "quantize", - self.tester.check, - [ - "torch.ops.quantized_decomposed.dequantize_per_tensor.default", - "torch.ops.quantized_decomposed.quantize_per_tensor.default", - ], - suffix="quant_nodes", - ) - remove_quant_nodes_stage = ( "to_edge_transform_and_lower" if use_to_edge_transform_and_lower else "partition" ) - self.add_stage_after( - remove_quant_nodes_stage, - self.tester.check_not, - [ - "torch.ops.quantized_decomposed.dequantize_per_tensor.default", - "torch.ops.quantized_decomposed.quantize_per_tensor.default", - ], - suffix="quant_nodes", - ) + + if _has_quantizable_inputs(test_data): + # only add stages if we have quantizable input + self.add_stage_after( + "quantize", + self.tester.check, + [ + "torch.ops.quantized_decomposed.dequantize_per_tensor.default", + "torch.ops.quantized_decomposed.quantize_per_tensor.default", + ], + suffix="quant_nodes", + ) + self.add_stage_after( + remove_quant_nodes_stage, + self.tester.check_not, + [ + "torch.ops.quantized_decomposed.dequantize_per_tensor.default", + "torch.ops.quantized_decomposed.quantize_per_tensor.default", + ], + suffix="quant_nodes", + ) if run_on_fvp: self.add_stage(self.tester.serialize) @@ -716,20 +779,20 @@ def __init__( pass_list: Optional[List[Type[PassType]]] = None, pass_functions: Optional[List[Callable]] = None, passes_with_exported_program: Optional[List[Type[ExportPass]]] = None, - custom_path: str = None, + custom_path: str | None = None, tosa_extensions: Optional[List[str]] = None, ): if tosa_extensions is None: tosa_extensions = [] - tosa_profiles = { + tosa_profiles: dict[str, TosaSpecification] = { "1.0": TosaSpecification.create_from_string( "TOSA-1.0+" + ("INT" if quantize else "FP") + "".join([f"+{ext}" for ext in tosa_extensions]), ), } - tosa_version = conftest.get_option("tosa_version") - self.tosa_spec = tosa_profiles[tosa_version] + tosa_version = _require_tosa_version() + self.tosa_spec: TosaSpecification = tosa_profiles[tosa_version] compile_spec = common.get_tosa_compile_spec( self.tosa_spec, custom_path=custom_path @@ -759,9 +822,9 @@ def __init__( self.add_stage(self.tester.check_count, ops_before_pass, suffix="before") if ops_not_before_pass: self.add_stage(self.tester.check_not, ops_not_before_pass, suffix="before") - test_pass_stage = RunPasses( - pass_list, pass_functions, passes_with_exported_program - ) + test_pass_stage = RunPasses( # type: ignore[arg-type] + pass_list, pass_functions, passes_with_exported_program # type: ignore[arg-type] + ) # Legacy pass APIs expose callable classes rather than ExportPass subclasses self.add_stage(self.tester.run_passes, test_pass_stage) @@ -769,7 +832,10 @@ def __init__( self.add_stage(self.tester.check_count, ops_after_pass, suffix="after") if ops_not_after_pass: self.add_stage(self.tester.check_not, 
ops_not_after_pass, suffix="after") - self.add_stage(self.tester.run_method_and_compare_outputs) + self.add_stage( + self.tester.run_method_and_compare_outputs, + inputs=self.test_data, + ) def run(self): with TosaLoweringContext(self.tosa_spec): @@ -792,17 +858,17 @@ def __init__( self, module: torch.nn.Module, test_data: T, - custom_path: str = None, + custom_path: str | None = None, tosa_extensions: Optional[List[str]] = None, ): if tosa_extensions is None: tosa_extensions = [] - tosa_profiles = { + tosa_profiles: dict[str, TosaSpecification] = { "1.0": TosaSpecification.create_from_string( "TOSA-1.0+INT" + "".join([f"+{ext}" for ext in tosa_extensions]), ), } - tosa_version = conftest.get_option("tosa_version") + tosa_version = _require_tosa_version() compile_spec = common.get_tosa_compile_spec( tosa_profiles[tosa_version], custom_path=custom_path @@ -832,6 +898,63 @@ def __init__( ) +class QuantizationPipeline(TOSAPipelineMaker, Generic[T]): + """ + Runs quantization and checks that appropriate nodes are annotated with an expected + quantization-spec. + + Attributes: + module: The module which the pipeline is applied to. + test_data: Data used for testing the module. + quantizer: The quantizer to use for quantization. + qspecs: Annotations to check for after quantization. A dict mapping + operator names to a dict mapping QuantizationSpec (or None) to the number of times + that spec should appear in the graph. A None QuantizationSpec indicates that the + operator should not be quantized. + input_qspecs: Annotations to check for after quantization on inputs. + output_qspecs: Annotations to check for after quantization on outputs. + custom_path : Path to dump intermediate artifacts to. + + """ + + def __init__( + self, + module: torch.nn.Module, + test_data: T, + quantizer: TOSAQuantizer, + qspecs: Optional[Dict[str, Dict[QuantizationSpec | None, int]]] = None, + input_qspecs: Optional[Dict[QuantizationSpec | None, int]] = None, + output_qspecs: Optional[Dict[QuantizationSpec | None, int]] = None, + custom_path: Optional[str] = None, + ): + tosa_spec = quantizer.tosa_spec + compile_spec = common.get_tosa_compile_spec(tosa_spec, custom_path=custom_path) + super().__init__( + module, + test_data, + None, + compile_spec, + None, + use_to_edge_transform_and_lower=True, + ) + # TODO sort out typing + quant_stage = Quantize(quantizer, quantization_config=quantizer.global_config) # type: ignore[arg-type] + self.add_stage(self.tester.quantize, quant_stage, pos=0) + + # Delete most of the pipeline + self.pop_stage("check_count.exir") + self.pop_stage("to_executorch") + self.pop_stage("to_edge_transform_and_lower") + self.pop_stage("check.aten") + self.add_stage_after( + "export", + self.tester.check_quantization_annotation, + qspecs, + input_qspecs, + output_qspecs, + ) + + class OpNotSupportedPipeline(TOSAPipelineMaker, Generic[T]): """ Runs the partitioner on a module and checks that ops are not delegated to test @@ -853,14 +976,14 @@ def __init__( test_data: T, non_delegated_ops: Dict[str, int], n_expected_delegates: int = 0, - custom_path: str = None, + custom_path: str | None = None, quantize: Optional[bool] = False, u55_subset: Optional[bool] = False, tosa_extensions: Optional[List[str]] = None, ): if tosa_extensions is None: tosa_extensions = [] - tosa_profiles = { + tosa_profiles: dict[str, TosaSpecification] = { "1.0": TosaSpecification.create_from_string( "TOSA-1.0+" + ("INT" if quantize else "FP") @@ -868,11 +991,14 @@ def __init__( + "".join([f"+{ext}" for ext in tosa_extensions]), ), 
} - tosa_version = conftest.get_option("tosa_version") + tosa_version = _require_tosa_version() tosa_spec = tosa_profiles[tosa_version] - compile_spec = common.get_tosa_compile_spec(tosa_spec, custom_path=custom_path) + compile_spec: ArmCompileSpec = common.get_tosa_compile_spec( + tosa_spec, + custom_path=custom_path, + ) super().__init__( module, test_data, @@ -882,7 +1008,10 @@ def __init__( ) if tosa_spec.support_integer(): - self.add_stage(self.tester.quantize, pos=0) + quantizer = create_quantizer(compile_spec) + quantizer.set_global(get_symmetric_quantization_config()) + quant_stage = Quantize(quantizer) + self.add_stage(self.tester.quantize, quant_stage, pos=0) self.change_args("check_not.exir", []) self.change_args( @@ -907,7 +1036,7 @@ class VgfPipeline(BasePipelineMaker, Generic[T]): exir_ops: Exir dialect ops expected to be found in the graph after to_edge. if not using use_edge_to_transform_and_lower. - run_on_vulkan_runtime: Set to true to test VGF output on VKML runtime. + run_on_vulkan_runtime: Whether to test VGF output on VKML runtime. vgf_compiler_flags: Optional compiler flags. @@ -923,14 +1052,15 @@ def __init__( test_data: T, aten_op: str | List[str], exir_op: Optional[str | List[str]] = None, - run_on_vulkan_runtime: bool = False, + run_on_vulkan_runtime: bool = True, vgf_compiler_flags: Optional[str] = "", - tosa_version: str = "TOSA-1.0+FP", + tosa_version: str = "TOSA-1.0+INT+FP", + quantize: bool = True, symmetric_io_quantization: bool = False, per_channel_quantization: bool = True, use_to_edge_transform_and_lower: bool = True, - custom_path: str = None, - tosa_debug_mode: Optional[ArmCompileSpecBuilder.DebugMode] = None, + custom_path: str | None = None, + tosa_debug_mode: Optional[ArmCompileSpec.DebugMode] = None, atol: float = 1e-03, rtol: float = 1e-03, qtol: int = 1, @@ -940,7 +1070,6 @@ def __init__( ] = None, tosa_extensions: Optional[List[str]] = None, ): - if tosa_extensions is None: tosa_extensions = [] tosa_spec = TosaSpecification.create_from_string( @@ -964,7 +1093,7 @@ def __init__( transform_passes=transform_passes, ) - if tosa_spec.support_integer(): + if quantize: quantizer = VgfQuantizer(compile_spec) quantization_config = get_symmetric_quantization_config( is_per_channel=per_channel_quantization @@ -975,30 +1104,32 @@ def __init__( self.add_stage(self.tester.quantize, quant_stage, pos=0) - self.add_stage_after( - "quantize", - self.tester.check, - [ - "torch.ops.quantized_decomposed.dequantize_per_tensor.default", - "torch.ops.quantized_decomposed.quantize_per_tensor.default", - ], - suffix="quant_nodes", - ) - remove_quant_nodes_stage = ( "to_edge_transform_and_lower" if use_to_edge_transform_and_lower else "partition" ) - self.add_stage_after( - remove_quant_nodes_stage, - self.tester.check_not, - [ - "torch.ops.quantized_decomposed.dequantize_per_tensor.default", - "torch.ops.quantized_decomposed.quantize_per_tensor.default", - ], - suffix="quant_nodes", - ) + + if _has_quantizable_inputs(test_data): + # only add stages if we have quantizable input + self.add_stage_after( + "quantize", + self.tester.check, + [ + "torch.ops.quantized_decomposed.dequantize_per_tensor.default", + "torch.ops.quantized_decomposed.quantize_per_tensor.default", + ], + suffix="quant_nodes", + ) + self.add_stage_after( + remove_quant_nodes_stage, + self.tester.check_not, + [ + "torch.ops.quantized_decomposed.dequantize_per_tensor.default", + "torch.ops.quantized_decomposed.quantize_per_tensor.default", + ], + suffix="quant_nodes", + ) else: self.add_stage_after( 
"export", @@ -1019,3 +1150,16 @@ def __init__( qtol=qtol, inputs=self.test_data, ) + self.run_on_vulkan_runtime = run_on_vulkan_runtime + + # TODO: Remove once CI fully working + def run(self): + import pytest + + if self.run_on_vulkan_runtime: + try: + super().run() + except FileNotFoundError as e: + pytest.skip(f"VKML executor_runner not found - not built - skip {e}") + else: + super().run() diff --git a/backends/arm/tosa/TARGETS b/backends/arm/tosa/TARGETS index b1df4f37c53..d0f7a743f53 100644 --- a/backends/arm/tosa/TARGETS +++ b/backends/arm/tosa/TARGETS @@ -6,28 +6,11 @@ runtime.python_library( "mapping.py", ], deps = [ - "fbsource//third-party/tosa_tools/v0.80/serialization_lib/python/serializer:serializer", - "fbsource//third-party/tosa_tools/v1.00/serialization_lib/python/serializer:serializer", + "fbsource//third-party/tosa_tools:serializer", "//caffe2:torch", ":specification", ], ) -runtime.python_library( - name = "quant_utils", - srcs = [ - "quant_utils.py", - ], - deps = [ - "fbsource//third-party/pypi/numpy:numpy", - "fbsource//third-party/tosa_tools/v0.80/serialization_lib/python/serializer:serializer", - "fbsource//third-party/tosa_tools/v1.00/serialization_lib/python/serializer:serializer", - "fbsource//third-party/tosa_tools/v0.80/serialization_lib/python/tosa:tosa", - "fbsource//third-party/tosa_tools/v1.00/serialization_lib/python/tosa:tosa", - "//executorch/backends/arm:constants", - ":mapping", - "//executorch/exir/dialects:lib", - ], -) runtime.python_library( name = "specification", srcs = [ @@ -44,8 +27,6 @@ runtime.python_library( "utils.py", ], deps = [ - "fbsource//third-party/tosa_tools/v0.80/serialization_lib/python/serializer:serializer", - ":quant_utils", "//executorch/backends/arm/operators:node_visitor", ], ) @@ -59,3 +40,32 @@ runtime.python_library( ":specification", ], ) + +runtime.python_library( + name = "compile_spec", + srcs = [ + "compile_spec.py", + ], + deps = [ + ":tosa", + ":specification", + "//executorch/backends/arm:arm_compile_spec", + ], +) + +runtime.python_library( + name = "partitioner", + srcs = [ + "backend.py", + "partitioner.py", + ], + deps = [ + ":compile_spec", + "//executorch/backends/arm:constants", + "//executorch/backends/arm:process_node", + "//executorch/backends/arm/debug:schema", + "//executorch/backends/arm/operator_support:operator_support", + "//executorch/backends/arm/_passes:passes", + "//executorch/exir:lib", + ], +) diff --git a/backends/arm/tosa/__init__.py b/backends/arm/tosa/__init__.py index 132d3563a43..30860642ac5 100644 --- a/backends/arm/tosa/__init__.py +++ b/backends/arm/tosa/__init__.py @@ -3,7 +3,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # -# pyre-unsafe from .specification import TosaSpecification diff --git a/backends/arm/tosa/backend.py b/backends/arm/tosa/backend.py index ce2b7a27487..913b5207767 100644 --- a/backends/arm/tosa/backend.py +++ b/backends/arm/tosa/backend.py @@ -2,21 +2,28 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Provide TOSA backend entry points for the Arm ExecuTorch integration. -# pyre-unsafe +Implement the Ahead-of-Time (AoT) preprocessing path that lowers an +``ExportedProgram`` to a TOSA flatbuffer using Arm's lowering pipeline. 
Use +this module either as a standalone backend that produces a TOSA artifact or as +part of a composed pipeline for hardware backends that consume TOSA as an +intermediate form. + +Use ``TOSABackend.preprocess`` to return the serialized TOSA flatbuffer that +subsequent stages (for example, JIT or hardware-specific compilers) consume. + +""" -# -# Main implementation of AoT flow to partition and preprocess for Arm target -# backends. Converts via TOSA as an intermediate form supported by AoT and -# JIT compiler flows. -# import logging -from collections import deque +import tempfile from itertools import count -from typing import cast, Dict, final, List, Set +from typing import cast, Dict, final, List + +import torch -import serializer.tosa_serializer as ts # type: ignore -from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder +import tosa_serializer as ts +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec from executorch.backends.arm.common.debug import debug_fail, debug_tosa_dump from executorch.backends.arm.debug.schema import DebugHook from executorch.backends.arm.process_node import ( @@ -24,44 +31,110 @@ process_output, process_placeholder, ) -from executorch.backends.arm.tosa.specification import get_tosa_spec +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec +from executorch.backends.arm.tosa.mapping import TOSA_TENSOR_NAME_META from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.dim_order_utils import get_memory_format +from executorch.exir.graph_module import get_cond_while_submodules from torch.export.exported_program import ExportedProgram -from torch.fx import Graph, Node +from torch.fx import Graph, GraphModule, Node # TOSA backend debug functionality logger = logging.getLogger(__name__) def _annotate_external_ids(ep_graph: Graph) -> Dict[str, int]: - """ - Returns dictionary: node name -> external ids + """Assign deterministic output IDs to leaf outputs. - Assign id to an output node of the model so we can trace it. + Flattens the output structure and assigns the external ID + based on the leaf position in the exported output tuple/list. + + Args: + ep_graph (Graph): FX graph produced by export preprocessing. + + Returns: + dict[str, int]: Mapping from *leaf output node name* to external output index. """ node2external_id = {} - def bfs_mark(start_nodes: List[Node], idx: int, seen: Set[Node]): - q = deque(start_nodes) - while q: - n = q.popleft() - if n in seen: - continue - seen.add(n) - node2external_id[n.name] = idx - # Walk backwards so we touch every producer - q.extend(n.all_input_nodes) - - out = next(n for n in ep_graph.nodes if n.op == "output") - seen: Set[Node] = set() - for idx, val in enumerate(out.args[0]): - bfs_mark([val], idx, seen) + def _collect_leaves(arg, nodes): + # Collect only FX Nodes that are actual outputs + # (ignore ints/None/etc inside structured outputs). 
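+        # Illustrative example with hypothetical node names: for an output
+        # structure ((add, relu), getitem) the collected leaves are
+        # [add, relu, getitem], which _annotate_external_ids then maps to
+        # {"add": 0, "relu": 1, "getitem": 2}.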
+ if isinstance(arg, Node): + nodes.append(arg) + elif isinstance(arg, (list, tuple)): + for a in arg: + _collect_leaves(a, nodes) + + out = ep_graph.output_node() + out_leaves: list[Node] = [] + # First argument of output is the structured container (tuple/list) of outputs + _collect_leaves(out.args[0], out_leaves) + + # Map each output leaf's name to its position + node2external_id = {leaf.name: idx for idx, leaf in enumerate(out_leaves)} + return node2external_id +def _sort_outputs(graph_module: GraphModule, node_to_id_map: dict[str, int]): + """Reorder graph outputs to match ascending external IDs. + + Args: + graph_module (GraphModule): Graph to reorder in place. + node_to_id_map (dict[str, int]): Mapping from node name to output index. + + Returns: + GraphModule: Updated graph module with deterministic output ordering. + + """ + + def _external_id(n: Node, node_2_id, fallback: int) -> int: + """Return the external ID for ``n`` or ``fallback`` when absent.""" + return node_2_id.get(n.name, fallback) + + out_node = graph_module.graph.output_node() + out_list = cast(tuple, out_node.args[0]) + _counter = count() + + # sort nodes by the key that is id + def _sort_key(t: Node) -> int: + """Key function that orders outputs by external ID or position.""" + return _external_id(t, node_to_id_map, next(_counter)) + + orig_ord = tuple(sorted(out_list, key=_sort_key)) + + current_order = tuple(out_list) + if orig_ord != current_order: + replacement = list(orig_ord) if isinstance(out_node.args[0], list) else orig_ord + out_node.args = (replacement,) + graph_module.graph.lint() + graph_module.recompile() + + return graph_module + + +def _get_matching_fake_tensor(node: Node): + """Return a fake tensor with the same properties as node, + but with .dim_order() == node.meta["tosa_dim_order"] + """ + fake_tensor = node.meta["val"] + desired_dim_order = node.meta["tosa_dim_order"] + return fake_tensor.to(memory_format=get_memory_format(list(desired_dim_order))) + + def arm_get_first_delegation_tag(graph_module) -> str: - """Get the first delegation tag from the graph_module or return empty string.""" + """Return the first delegation tag discovered in the FX graph. + + Args: + graph_module (GraphModule): Module produced by Arm partitioning. + + Returns: + str: First non-empty delegation tag or an empty string when no tag is + recorded. + + """ for node in graph_module.graph.nodes: tag = node.meta.get("delegation_tag") if tag: @@ -73,108 +146,252 @@ def arm_get_first_delegation_tag(graph_module) -> str: @final class TOSABackend(BackendDetails): + """Provide a backend for lowering programs to TOSA. + + Use this class standalone to produce a TOSA representation, or as part of a + composed pipeline for hardware backends that consume TOSA. + """ - BackendDetails subclass for lowering to TOSA. - Is used either by itself to get to a TOSA representation, or with composition - to be used as a separate step to target TOSA compliant hardware. - """ @staticmethod - def preprocess( # noqa: C901 + def preprocess(edge_program: ExportedProgram, compile_specs: List[CompileSpec]): + """Convert an exported program using the provided compile specs. + + Args: + edge_program (ExportedProgram): Program generated by Torch export. + compile_specs (List[CompileSpec]): Raw compile specifications from + ``executorch.apply_backend``. + + Returns: + PreprocessResult: Result containing serialized TOSA bytes. 
+ + """ + return TOSABackend._preprocess( + edge_program, TosaCompileSpec.from_list(compile_specs) + ) + + @staticmethod + def _preprocess( # noqa: C901 edge_program: ExportedProgram, - compile_spec: List[CompileSpec], + compile_spec: TosaCompileSpec, ) -> PreprocessResult: + """Lower an exported program to a TOSA flatbuffer. + + Apply Arm transformation passes to ``edge_program``, then walk the + transformed FX graph to emit a TOSA graph via the serializer. When + requested in ``compile_spec``, write additional debug artifacts. + + Args: + edge_program (ExportedProgram): Program to lower to TOSA. + compile_spec (TosaCompileSpec): Backend options. Recognized keys: + - output_format: Must be "tosa". + - tosa_spec: Target TOSA version/capabilities. + - debug_artifact_path: Directory for debug outputs. + - compile_flags: Optional backend flags. + - dump_debug_info: Enable extra debug JSON dump. + + Returns: + PreprocessResult: Result containing processed_bytes with the + serialized TOSA flatbuffer. + + Raises: + ValueError: If output_format is not "tosa" or the TOSA + specification is missing from compile_spec. + RuntimeError: If an unsupported FX node type is encountered. + + """ # if a debug/test build capture output files from TOSA stage - artifact_path = None - output_format = "" - compile_flags = [] - dump_debug_info = None - for spec in compile_spec: - if spec.key == "debug_artifact_path": - artifact_path = spec.value.decode() - if spec.key == "output_format": - output_format = spec.value.decode() - if spec.key == "compile_flags": - compile_flags.append(spec.value.decode()) - if spec.key == "dump_debug_info": - dump_debug_info = spec.value.decode() - - # Check that the output format is set correctly in the compile spec - if output_format != "tosa": - raise ValueError(f'Invalid output format {output_format}, must be "tosa"') - - # Assign to every node external id - node_2_id = _annotate_external_ids(edge_program.graph) - - tosa_spec = get_tosa_spec(compile_spec) - if tosa_spec is None: - raise ValueError( - "TOSA backend needs a TOSA version specified in the CompileSpec" - ) + artifact_path = compile_spec.get_intermediate_path() + tosa_spec = compile_spec.tosa_spec + dump_debug_info = compile_spec.tosa_debug_mode + debug_hook = None + if dump_debug_info is not None: + debug_hook = DebugHook(dump_debug_info) logger.info(f"Converting ExportedProgram to TOSA: {tosa_spec}") # Converted output for this subgraph, serializer needs path early as it emits # const data directly. Path created and data written only in debug builds. - tosa_graph = ts.TosaSerializer(artifact_path) + if not artifact_path: + artifact_path = "" + + version = tosa_spec.version + tosa_graph = ts.TosaSerializer( + artifact_path, + targetMajor=version.major, + targetMinor=version.minor, + targetPatch=version.micro, + targetDraft=False, + ) - assert ( + if not ( tosa_spec.version.major == ts.TOSA_VERSION_MAJOR and tosa_spec.version.minor == ts.TOSA_VERSION_MINOR - ), f"TOSA serializer version ({ts.TOSA_VERSION_MAJOR}.{ts.TOSA_VERSION_MINOR}) doesn't match specification {tosa_spec}" + ): + raise RuntimeError( + f"TOSA serializer version " + f"({ts.TOSA_VERSION_MAJOR}.{ts.TOSA_VERSION_MINOR}) " + f"doesn't match specification {tosa_spec}" + ) + + TOSABackend._preprocess_module( + edge_program.graph_module, + edge_program, + compile_spec, + tosa_graph, + debug_hook, + ) + # Serialize and return the TOSA flatbuffer. 
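+        # The resulting bytes are returned via PreprocessResult.processed_bytes and
+        # are consumed by later stages (hardware-specific compilers or the TOSA
+        # reference model), as described in the module docstring above.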
+ binary = tosa_graph.serialize() + + if artifact_path: + tag = arm_get_first_delegation_tag(edge_program.graph_module) + + # Only dump TOSA if we are not saving to temporary folder. + if len( + tempdir := tempfile.gettempdir() + ) > 0 and not artifact_path.startswith(tempdir): + debug_tosa_dump( + binary, + artifact_path, + suffix="{}".format(f"_{tag}" if tag else "") + (f"_{tosa_spec}"), + ) + + if debug_hook is not None: + if debug_hook.mode == ArmCompileSpec.DebugMode.JSON: + json_output = debug_hook.serialize() + with open(f"{artifact_path}/debug.json", "w") as f: + f.write(json_output) + + return PreprocessResult(processed_bytes=binary) + + @staticmethod + def _regularize_submodule(submodule: GraphModule, submodule_node: Node): + """To make a submodule fit into the normal flow of a graph_module, we need to do some regularizations. + + - Buffers created before passes are treated as input to the submodule. Buffers created during passes + are treated as "normal" buffers, i.e. gathered from the state_dict. + To make it easy to tell them apart, mark all placeholders with "is_input = True" before running passes. + - Make sure output node args[0] is always iterable. + - Match the dim_order() of the input tensors with the dim orders of the submodule_node inputs. + - Match the dim_order() of the out tensors with the dim orders of the submodule_node outputs. + """ + submodule_inputs: list[Node] = [] + for node in submodule.graph.nodes: + if node.op == "placeholder": + node.meta["is_input"] = True + submodule_inputs.append(node) + match submodule_node.target: + case torch.ops.higher_order.cond: + args = cast(list[Node], submodule_node.args[-1]) + case torch.ops.higher_order.while_loop: + args = cast(list[Node], submodule_node.args[-2]) + cast( + list, submodule_node.args[-1] + ) + case _: + raise RuntimeError( + f"Unexpected control flow target: {submodule_node.target}" + ) + + for submodule_input, submodule_arg in zip(submodule_inputs, args, strict=True): + submodule_input.meta["val"] = _get_matching_fake_tensor(submodule_arg) + + output_node = submodule.graph.output_node() + if isinstance(output_node.args[0], Node): + output_node.update_arg(0, [output_node.args[0]]) + output_args = cast(list[Node], output_node.args[0]) + + # Not all outputs might be used, causing len(users) < len(outputs) + # Therefore, strict != True in the zip + for submodule_output, submodule_user in zip(output_args, submodule_node.users): + submodule_output.meta["val"] = _get_matching_fake_tensor(submodule_user) + + @staticmethod + def _preprocess_module( # noqa: C901 + graph_module: GraphModule, + edge_program: ExportedProgram, + compile_spec: TosaCompileSpec, + tosa_graph: ts.TosaSerializer, + debug_hook: DebugHook | None, + submodule_name: str | None = None, + containing_graph_module: GraphModule | None = None, + ): + """Convert an FX ``graph_module`` to TOSA serializer calls. + + Args: + graph_module (GraphModule): Module to lower recursively. + edge_program (ExportedProgram): Original exported program. + compile_spec (TosaCompileSpec): Backend options with TOSA settings. + tosa_graph (ts.TosaSerializer): Serializer receiving operators. + debug_hook (DebugHook | None): Optional debug instrumentation. + submodule_name (str | None): Name used when visiting nested blocks. + + Raises: + RuntimeError: If an FX node with an unsupported op kind is found. 
+ + """ + tosa_spec = compile_spec.tosa_spec + node_to_id_map = _annotate_external_ids(graph_module.graph) + artifact_path = compile_spec.get_intermediate_path() + output_order_workaround = compile_spec.get_output_order_workaround() # TODO: Fix the need to lazily import this. from executorch.backends.arm._passes import ArmPassManager - graph_module = ArmPassManager(tosa_spec).transform_to_backend_pipeline( # type: ignore - exported_program=edge_program + graph_module = ArmPassManager(compile_spec).transform_to_backend_pipeline( # type: ignore + exported_program=edge_program, graph_module=graph_module ) - debug_hook = None - if dump_debug_info is not None: - debug_hook = DebugHook(ArmCompileSpecBuilder.DebugMode[dump_debug_info]) - # TODO: Fix the need to lazily import this. from executorch.backends.arm.operators.node_visitor import get_node_visitors node_visitors = get_node_visitors(edge_program, tosa_spec, debug_hook) - # Re-shuffle output nodes to preserve author's order - def _external_id(n: Node, node_2_id, fallback: int) -> int: - return node_2_id.get(n.name, fallback) - - out_node = next(n for n in graph_module.graph.nodes if n.op == "output") - _counter = count() + if output_order_workaround: + logger.debug("Re-sorting outputs during TOSA lowering.") + graph_module = _sort_outputs(graph_module, node_to_id_map) + else: + logger.debug("No re-sorting outputs (workaround) during TOSA lowering.") - # sort nodes by the key that is id - def _sort_key(t: Node) -> int: - return _external_id(t, node_2_id, next(_counter)) - - orig_ord = tuple(sorted(out_node.args[0], key=_sort_key)) - - current_order = tuple(out_node.args[0]) - if orig_ord != current_order: - replacement = ( - list(orig_ord) if isinstance(out_node.args[0], list) else orig_ord - ) - out_node.args = (replacement,) - graph_module.graph.lint() - graph_module.recompile() + if submodule_name is not None: + tosa_graph.startRegion(submodule_name) + tosa_graph.currRegion.addBasicBlock(submodule_name) + suffix = f"_{submodule_name}" + for loop_node in graph_module.graph.nodes: + loop_node.meta[TOSA_TENSOR_NAME_META] = suffix - input_count = 0 for node in graph_module.graph.nodes: node = cast(Node, node) try: if node.op == "call_function": process_call_function(node, tosa_graph, node_visitors, tosa_spec) elif node.op == "placeholder": - if len(node.users) == 0: + if len(node.users) == 0 and submodule_name is None: + # In top level module, we don't need to handle unused placeholders. + # In submodules, we do need to handle them to preserve call signature. continue - process_placeholder(node, tosa_graph, edge_program, tosa_spec) - if node.name in edge_program.graph_signature.user_inputs: - input_count += 1 + process_placeholder( + node, + tosa_graph, + edge_program, + containing_graph_module, + tosa_spec, + ) elif node.op == "output": - process_output(node, tosa_graph) + process_output(node, tosa_graph, tosa_spec) + elif node.op == "get_attr": + attr = getattr(graph_module, str(node.target), None) + if attr is None: + raise RuntimeError( + "get_attr node is not targeting anything in graph module." + ) + if not isinstance(attr, GraphModule): + raise RuntimeError( + "get_attr node is not targeting a GraphModule." + ) + + # If the above conditions are ok, we don't need to handle this node here. + # Only the string value of node.target is important. else: # This will only happen if an unpartitioned graph is passed without # any checking of compatibility. 
@@ -183,48 +400,44 @@ def _sort_key(t: Node) -> int: debug_fail(node, graph_module, tosa_graph, artifact_path) raise - if artifact_path: - tag = arm_get_first_delegation_tag(graph_module) - debug_tosa_dump( + # Recursively preprocess controlflow submodules. + for name, submodule, control_flow_node in get_cond_while_submodules( + graph_module + ): + TOSABackend._regularize_submodule(submodule, control_flow_node) + TOSABackend._preprocess_module( + submodule, + edge_program, + compile_spec, tosa_graph, - artifact_path, - suffix="{}".format(f"_{tag}" if tag else "") + (f"_{tosa_spec}"), + debug_hook, + submodule_name=name, + containing_graph_module=graph_module, ) - if debug_hook is not None: - if debug_hook.mode == ArmCompileSpecBuilder.DebugMode.JSON: - json_output = debug_hook.serialize() - with open(f"{artifact_path}/debug.json", "w") as f: - f.write(json_output) - - # Serialize and return the TOSA flatbuffer. - binary = bytes(tosa_graph.serialize()) - - return PreprocessResult(processed_bytes=binary) - @staticmethod def filter_tosa_compile_specs( - compile_spec: List[CompileSpec], - ) -> List[CompileSpec]: - """ - Filter out the CompileSpec elements relevant for the TOSA backend. - This is needed to compose a backend targetting hardware IP with the - TOSABackend, since we first want to use the TOSABackend to generate - the TOSA flatbuffer representation as an intermediate step. The TOSA - flatbuffer can then be consumed by the backend targetting specific - hardware. - """ - tosa_compile_spec = [] - tosa_compile_spec.append(CompileSpec("output_format", "tosa".encode())) + compile_spec: ArmCompileSpec, + ) -> TosaCompileSpec: + """Extract the TOSA-specific settings from a composite compile spec. + + Args: + compile_spec (ArmCompileSpec): Compile specification that may + include both TOSA and hardware-specific options. - # Copy everything that's TOSA generic - tosa_backend_compile_spec_keys = [ - "tosa_spec", - "debug_artifact_path", - ] + Returns: + TosaCompileSpec: TOSA-only specification ready for + ``TOSABackend.preprocess``. - for spec in compile_spec: - if spec.key in tosa_backend_compile_spec_keys: - tosa_compile_spec.append(CompileSpec(spec.key, spec.value)) + """ - return tosa_compile_spec + pipeline_config = compile_spec.get_pass_pipeline_config() + tosa_compile_spec = TosaCompileSpec(compile_spec.tosa_spec) + tosa_compile_spec.set_pass_pipeline_config(pipeline_config) + return ( + tosa_compile_spec.dump_intermediate_artifacts_to( + compile_spec.get_intermediate_path() + ) + .dump_debug_info(compile_spec.tosa_debug_mode) + .set_output_order_workaround(compile_spec.output_order_workaround) + ) diff --git a/backends/arm/tosa/compile_spec.py b/backends/arm/tosa/compile_spec.py new file mode 100644 index 00000000000..5cd72ce04b3 --- /dev/null +++ b/backends/arm/tosa/compile_spec.py @@ -0,0 +1,48 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec +from executorch.backends.arm.common.pipeline_config import ( # noqa: unused + ArmPassPipelineConfig, +) +from executorch.backends.arm.tosa import TosaSpecification + + +class TosaCompileSpec(ArmCompileSpec): + """Arm-specific compile spec capturing TOSA serializer requirements.""" + + def __init__(self, tosa_spec: TosaSpecification | str): + """Normalize and store the provided TOSA specification. 
+ + Args: + tosa_spec (TosaSpecification | str): Target spec object or version + string supported by :meth:`TosaSpecification.create_from_string`. + + """ + if isinstance(tosa_spec, str): + tosa_spec = TosaSpecification.create_from_string(tosa_spec) + self._set_compile_specs(tosa_spec, []) + self.validate() + + def validate(self): + """Ensure that no unsupported compiler flags were supplied.""" + if len(self.compiler_flags) != 0: + raise ValueError( + f"TosaCompileSpec can't have compiler flags, got {self.compiler_flags}" + ) + pass + + @classmethod + def get_output_format(cls) -> str: + """Return the artifact format emitted by this compile spec.""" + return "tosa" + + @classmethod + def from_list_hook(cls, compile_spec, specs: dict[str, str]): + super().from_list_hook(compile_spec, specs) + + def _create_default_pipeline_config(self): + config = super()._create_default_pipeline_config() + return config diff --git a/backends/arm/tosa/dialect/__init__.py b/backends/arm/tosa/dialect/__init__.py index 136f59beb62..152f99d4431 100644 --- a/backends/arm/tosa/dialect/__init__.py +++ b/backends/arm/tosa/dialect/__init__.py @@ -4,7 +4,12 @@ # LICENSE file in the root directory of this source tree. from executorch.backends.arm.tosa.dialect.ops import ( # noqa F401 + conv2d, + conv3d, + depthwise_conv2d, + matmul, rescale, + resize, table, transpose, ) diff --git a/backends/arm/tosa/dialect/lib.py b/backends/arm/tosa/dialect/lib.py index 4a807d682dc..ed26a21a297 100644 --- a/backends/arm/tosa/dialect/lib.py +++ b/backends/arm/tosa/dialect/lib.py @@ -15,6 +15,17 @@ def register_tosa_dialect_op(op_schema, func) -> Callable: + """Register a TOSA dialect operator with the backend op library. + + Args: + op_schema (str): Operator schema without namespace or overload name. + func (Callable): Fake implementation used for registration. + + Returns: + Callable: Backend dialect operator handle exposed via ``exir_ops`` and + marked ``not_callable`` for runtime use. + + """ if tosa_lib.ns not in _BACKEND_OP_LIB: _BACKEND_OP_LIB.append(tosa_lib.ns) @@ -43,6 +54,7 @@ def register_tosa_dialect_op(op_schema, func) -> Callable: # the op doesn't need to be callable. This can be changed in the future if needed to support # execution of TOSA ops directly. def not_callable(): + """Raise when the dialect op handle is invoked at runtime.""" raise RuntimeError("TOSA dialect op is not callable") op.__equvalent_callable__ = not_callable @@ -51,11 +63,22 @@ def not_callable(): class TosaValueError(ValueError): + """Error type that annotates failures with the originating TOSA op.""" + def __init__(self, message="A TOSA value error occurred", *args, op=None): + """Initialise the error with optional operator metadata. + + Args: + message (str): Human-readable error message. + *args: Additional arguments forwarded to ``ValueError``. + op: Optional operator identifier included in the string output. + + """ super().__init__(message, *args) self.op = op def __str__(self): + """Return the base message, appending the operator when provided.""" base_message = super().__str__() if self.op is not None: return f"{base_message} (TOSA op: {self.op})" diff --git a/backends/arm/tosa/dialect/ops/conv2d.py b/backends/arm/tosa/dialect/ops/conv2d.py new file mode 100644 index 00000000000..45afae51708 --- /dev/null +++ b/backends/arm/tosa/dialect/ops/conv2d.py @@ -0,0 +1,117 @@ +# Copyright 2025 Arm Limited and/or its affiliates. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op + +from executorch.backends.arm.tosa.specification import ( + get_context_spec, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops + + +def validate_conv2d_args_dtypes( + tosa_spec: TosaSpecification, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + op: str = "CONV2D", +) -> torch.dtype: + output_dtype = None + supported_int_types = (torch.int8, torch.int16) + supported_float_types = ( + torch.float16, + torch.float32, + ) + if x.dtype in supported_int_types: + if not tosa_spec.support_integer(): + raise TosaValueError( + f"TOSA spec {tosa_spec} doesn't support {x.dtype} but found input type {x.dtype}", + op=op, + ) + if weight.dtype not in (torch.int8,): + raise TosaValueError( + f"TOSA spec {tosa_spec} only supports {torch.int8} weights for {x.dtype} input but found {weight.dtype}", + op=op, + ) + if bias is not None and bias.dtype not in (torch.int32,): + raise TosaValueError( + f"TOSA spec {tosa_spec} only supports {torch.int32} bias for {x.dtype} input but found {bias.dtype}", + op=op, + ) + output_dtype = torch.int32 + + elif x.dtype in supported_float_types: + if not tosa_spec.support_float(): + raise TosaValueError( + f"TOSA spec {tosa_spec} doesn't support {x.dtype} but found input type {x.dtype}", + op=op, + ) + if weight.dtype != x.dtype: + raise TosaValueError( + f"TOSA spec {tosa_spec} requires weights {weight.dtype} to be of the same type as input {x.dtype}", + op=op, + ) + if bias is not None and bias.dtype != x.dtype: + raise TosaValueError( + f"TOSA spec {tosa_spec} requires bias {bias.dtype} to be of the same type as input {x.dtype}", + op=op, + ) + output_dtype = x.dtype + else: + raise TosaValueError( + f"Unsupported input dtype {x.dtype}, supported types are {supported_int_types + supported_float_types} ", + op=op, + ) + return output_dtype + + +@register_fake_tosa_op( + "CONV2D(Tensor input, " + "Tensor weight, " + "Tensor bias, " + "int[2] stride, " + "int[4] pad, " + "int[2] dialation, " + "bool transposed, " + "int[2] output_padding, " + "int groups) -> Tensor", # schema + ( + TosaSpecification.create_from_string("TOSA-1.0+FP"), + TosaSpecification.create_from_string("TOSA-1.0+INT"), + ), # target TOSA specifications +) +def CONV2D( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: list[int], + pad: list[int], + dialation: list[int], + transposed: bool, + output_padding: list[int], + groups: int, +) -> torch.Tensor: + tosa_spec = get_context_spec() + + output_dtype = validate_conv2d_args_dtypes(tosa_spec, x, weight, bias, op="CONV2D") + + torch_pad = [pad[0], pad[2]] + aten_fake_tensor = exir_ops.edge.aten.convolution.default( + x, + weight, + bias, + stride, + torch_pad, + dialation, + transposed, + output_padding, + groups, + ) + return aten_fake_tensor.to(dtype=output_dtype) diff --git a/backends/arm/tosa/dialect/ops/conv3d.py b/backends/arm/tosa/dialect/ops/conv3d.py new file mode 100644 index 00000000000..6428e091367 --- /dev/null +++ b/backends/arm/tosa/dialect/ops/conv3d.py @@ -0,0 +1,75 @@ +# Copyright 2025 Arm Limited and/or its affiliates. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.dialect.ops.conv2d import validate_conv2d_args_dtypes +from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.specification import ( + get_context_spec, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops + + +def validate_conv3d_args_dtypes( + tosa_spec: TosaSpecification, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, +) -> torch.dtype: + if len(x.shape) != 5 or len(weight.shape) != 5: + raise TosaValueError( + f"Expected 5D input/weight tensors for CONV3D, got {x.shape} and {weight.shape}", + op="CONV3D", + ) + return validate_conv2d_args_dtypes(tosa_spec, x, weight, bias, op="CONV3D") + + +@register_fake_tosa_op( + "CONV3D(Tensor input, " + "Tensor weight, " + "Tensor bias, " + "int[3] stride, " + "int[6] pad, " + "int[3] dialation, " + "bool transposed, " + "int[3] output_padding, " + "int groups) -> Tensor", + ( + TosaSpecification.create_from_string("TOSA-1.0+FP"), + TosaSpecification.create_from_string("TOSA-1.0+INT"), + ), +) +def CONV3D( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: list[int], + pad: list[int], + dialation: list[int], + transposed: bool, + output_padding: list[int], + groups: int, +) -> torch.Tensor: + tosa_spec = get_context_spec() + + output_dtype = validate_conv3d_args_dtypes(tosa_spec, x, weight, bias) + + torch_pad = [pad[0], pad[2], pad[4]] + aten_fake_tensor = exir_ops.edge.aten.convolution.default( + x, + weight, + bias, + stride, + torch_pad, + dialation, + transposed, + output_padding, + groups, + ) + return aten_fake_tensor.to(dtype=output_dtype) diff --git a/backends/arm/tosa/dialect/ops/depthwise_conv2d.py b/backends/arm/tosa/dialect/ops/depthwise_conv2d.py new file mode 100644 index 00000000000..c234a2e84a8 --- /dev/null +++ b/backends/arm/tosa/dialect/ops/depthwise_conv2d.py @@ -0,0 +1,65 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
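+#
+# Registers a fake implementation for the TOSA DEPTHWISE_CONV2D dialect op.
+# Note (editorial assumption based on the code below): the fake op is only used
+# for meta/shape and dtype inference; it reshapes the depthwise weight into
+# aten's (out_channels, in_channels // groups, kH, kW) layout so that
+# exir_ops.edge.aten.convolution can be reused to derive the output tensor.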
+ +import torch +from executorch.backends.arm.tosa.dialect.ops.conv2d import validate_conv2d_args_dtypes +from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op + +from executorch.backends.arm.tosa.specification import ( + get_context_spec, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops + + +@register_fake_tosa_op( + "DEPTHWISE_CONV2D(Tensor input, " + "Tensor weight, " + "Tensor bias, " + "int[2] stride, " + "int[4] pad, " + "int[2] dialation, " + "bool transposed, " + "int[2] output_padding, " + "int groups) -> Tensor", # schema + ( + TosaSpecification.create_from_string("TOSA-1.0+FP"), + TosaSpecification.create_from_string("TOSA-1.0+INT"), + ), # target TOSA specifications +) +def DEPTHWISE_CONV2D( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: list[int], + pad: list[int], + dialation: list[int], + transposed: bool, + output_padding: list[int], + groups: int, +) -> torch.Tensor: + tosa_spec = get_context_spec() + + output_dtype = validate_conv2d_args_dtypes( + tosa_spec, x, weight, bias, op="DEPTHWISE_CONV2D" + ) + + torch_pad = [pad[0], pad[2]] + H, W = weight.shape[0], weight.shape[2] + in_channels_group = x.shape[1] // groups + out_channels = weight.shape[1] * x.shape[1] + torch_weight = weight.reshape(out_channels, in_channels_group, H, W) + aten_fake_tensor = exir_ops.edge.aten.convolution.default( + x, + torch_weight, + bias, + stride, + torch_pad, + dialation, + transposed, + output_padding, + groups, + ) + return aten_fake_tensor.to(dtype=output_dtype) diff --git a/backends/arm/tosa/dialect/ops/matmul.py b/backends/arm/tosa/dialect/ops/matmul.py new file mode 100644 index 00000000000..1ba3821f674 --- /dev/null +++ b/backends/arm/tosa/dialect/ops/matmul.py @@ -0,0 +1,56 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op + +from executorch.backends.arm.tosa.specification import ( + get_context_spec, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops + + +@register_fake_tosa_op( + "MATMUL(Tensor input1, Tensor input2) -> Tensor", # schema + ( + TosaSpecification.create_from_string("TOSA-1.0+INT"), + ), # target TOSA specifications +) +def MATMUL(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + tosa_spec = get_context_spec() + """Performs matrix multiplication on two input tensors. + Additionally validates TOSA constraints of a MATMUL op. 
+ """ + if x1.dtype != x2.dtype: + raise TosaValueError( + f"Input tensors must have the same dtype, got {x1.dtype} and {x2.dtype}", + op="MATMUL", + ) + if x1.dtype in (torch.int8, torch.int16): + if not tosa_spec.support_integer(): + raise TosaValueError( + f"TOSA spec {tosa_spec} doesn't support integers", op="MATMUL" + ) + else: + dtype = torch.int32 + elif x1.dtype in (torch.float16, torch.float32): + if not tosa_spec.support_float(): + raise TosaValueError( + f"TOSA spec {tosa_spec} doesn't support float", op="MATMUL" + ) + else: + # float16 supports float16 accumulation as well + dtype = torch.float32 + else: + raise TosaValueError( + f"Input tensors must be of type int8, float16 or float32, got {x1.dtype}", + op="MATMUL", + ) + + aten_fake_tensor = exir_ops.edge.aten.bmm.default(x1, x2) + + return torch.empty_like(aten_fake_tensor, dtype=dtype) diff --git a/backends/arm/tosa/dialect/ops/rescale.py b/backends/arm/tosa/dialect/ops/rescale.py index 5f0cf9d15dc..f622bbf115d 100644 --- a/backends/arm/tosa/dialect/ops/rescale.py +++ b/backends/arm/tosa/dialect/ops/rescale.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import List + import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op @@ -14,13 +16,13 @@ @register_fake_tosa_op( - "RESCALE(Tensor input1, ScalarType dtype, float scale, int in_zp, int out_zp) -> Tensor", # schema + "RESCALE(Tensor input1, ScalarType dtype, float[] scale, int in_zp, int out_zp) -> Tensor", # schema ( TosaSpecification.create_from_string("TOSA-1.0+INT"), ), # target TOSA specifications ) def RESCALE( - x: torch.Tensor, dtype: torch.dtype, scale: float, in_zp: int, out_zp: int + x: torch.Tensor, dtype: torch.dtype, scales: List[float], in_zp: int, out_zp: int ) -> torch.Tensor: tosa_spec = get_context_spec() """Casts the input tensor to dtype `dtype` to produce the correct tensor meta for a _rescale op. diff --git a/backends/arm/tosa/dialect/ops/resize.py b/backends/arm/tosa/dialect/ops/resize.py new file mode 100644 index 00000000000..b40b1f74a75 --- /dev/null +++ b/backends/arm/tosa/dialect/ops/resize.py @@ -0,0 +1,66 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Literal, Optional + +import torch +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op + +from executorch.backends.arm.tosa.specification import ( + get_context_spec, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops + + +# Add kwarg instead? +@register_fake_tosa_op( + "RESIZE(Tensor input, SymInt[]? output_size, bool align_corners, float[]? 
scale_factors, *, str resize_mode) -> Tensor", # schema + ( + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ), # target TOSA specifications +) +def RESIZE( + x: torch.Tensor, + output_size: list[int] | None = None, + align_corners: Optional[bool] = False, + scale_factors: list[float] | None = None, + *, + resize_mode: Literal["nearest", "bilinear"], +) -> torch.Tensor: + tosa_spec = get_context_spec() + + if resize_mode not in ("nearest", "bilinear"): + raise TosaValueError(f"Unsupported resize mode {resize_mode} for TOSA RESIZE") + if x.dtype == torch.int8: + if not tosa_spec.support_integer(): + raise TosaValueError( + f"TOSA spec {tosa_spec} doesn't support integers", op="RESIZE" + ) + bilinear = resize_mode == "bilinear" + output_dtype = torch.int32 if bilinear else torch.int8 + elif x.dtype == torch.int16: + if not tosa_spec.support_integer(): + raise TosaValueError( + f"Context TOSA spec {tosa_spec} doesn't support int16", op="RESIZE" + ) + output_dtype = x.dtype + elif x.dtype in (torch.float16, torch.float32): + if not tosa_spec.support_float(): + raise TosaValueError( + f"TOSA spec {tosa_spec} doesn't support float", op="RESIZE" + ) + output_dtype = x.dtype + else: + raise TosaValueError(f"Unsupported input dtype {x.dtype} for TOSA RESIZE") + + # Does it matter which one to use for fake tracing? + fake_aten_tensor = exir_ops.edge.aten.upsample_nearest2d.vec( + x, output_size, scale_factors + ) + + return fake_aten_tensor.to(output_dtype) diff --git a/backends/arm/tosa/dialect/ops/transpose.py b/backends/arm/tosa/dialect/ops/transpose.py index 9c5aba05394..8d5bf8bac70 100644 --- a/backends/arm/tosa/dialect/ops/transpose.py +++ b/backends/arm/tosa/dialect/ops/transpose.py @@ -26,9 +26,9 @@ def TRANSPOSE(a, perms): # By utilizing an edge IR passthrough operator we can keep the edge program in # channels-first/contiguous and get the desired behavior in the TOSA lowering. - if len(perms) not in (4, 5): + if len(perms) not in (4, 5, 6): raise TosaValueError( - f"Only 4D and 5D tensors are supported, got {len(perms)}: {perms}", + f"Only 4D, 5D and 6D tensors are supported, got {len(perms)}: {perms}", op="TRANSPOSE", ) diff --git a/backends/arm/tosa/mapping.py b/backends/arm/tosa/mapping.py index 60ef98a37c0..ca83c6c09ea 100644 --- a/backends/arm/tosa/mapping.py +++ b/backends/arm/tosa/mapping.py @@ -2,22 +2,23 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Provide PyTorch-to-TOSA mapping helpers. -# pyre-unsafe +Use these utilities to translate PyTorch dtypes and FX node metadata into the +TOSA serializer types and shapes used during initial compilation. -# -# PyTorch to Tosa mapping - simple mapping functions and multi-type extraction -# of key information. These are used by the initial compile stage which captures -# the standardised TOSA representation. 
-# +""" +import operator +from enum import Enum from typing import Any, Optional, Sequence -import serializer.tosa_serializer as ts # type: ignore - import torch +import tosa_serializer as ts from executorch.backends.arm.tosa.specification import TosaSpecification +TOSA_TENSOR_NAME_META = "tosa_tensor_name" + UNSUPPORTED_DTYPES = ( torch.float64, torch.double, @@ -31,7 +32,45 @@ ) +class TosaSpecialDtype(Enum): + """Special TOSA dtypes not natively expressed in PyTorch.""" + + INT48 = ts.DType.INT48 + + def get_tosa_dtype(self) -> ts.DType: + """Return the underlying ``ts.DType`` enumerant. + + Returns: + ts.DType: Serializer dtype associated with the enum entry. + + """ + return self.value + + @staticmethod + def meta_key() -> str: + """Return the FX ``meta`` key that stores special dtypes. + + Returns: + str: Metadata key used to encode :class:`TosaSpecialDtype`. + + """ + return "tosa_special_dtype" + + def map_dtype(data_type: torch.dtype, tosa_spec: TosaSpecification) -> Any: + """Map a ``torch.dtype`` to a ``ts.DType``. + + Args: + data_type (torch.dtype): PyTorch dtype to convert. + tosa_spec (TosaSpecification): Active spec (reserved for future checks). + + Returns: + ts.DType: Matching serializer dtype. + + Raises: + ValueError: If the dtype is unsupported or unknown. + + """ if data_type in UNSUPPORTED_DTYPES: raise ValueError(f"Unsupported type: {data_type}") @@ -57,7 +96,22 @@ def map_dtype(data_type: torch.dtype, tosa_spec: TosaSpecification) -> Any: # TODO: other types, can be # SymInt, FakeTensor, a List[Union[FakeTensor, SymInt]], or None def extract_tensor_meta(meta, tosa_spec: TosaSpecification): - assert meta.get("val") is not None + """Extract dtype, shape, and dimension order from FX metadata. + + Args: + meta (dict): FX node ``meta`` containing a ``val`` FakeTensor (or tuple). + tosa_spec (TosaSpecification): Active TOSA spec for dtype mapping. + + Returns: + tuple[ts.DType, tuple[int, ...], tuple[int, ...]]: Tuple containing + tensor dtype, shape, and dimension order. + + Raises: + ValueError: If ``meta['val']`` is not a ``FakeTensor``. + + """ + if meta.get("val") is None: + raise ValueError("Expected node.meta['val'] to be set to a FakeTensor") val = meta["val"] if type(val) is tuple: # TODO: should use first concrete representation @@ -77,23 +131,89 @@ def extract_tensor_meta(meta, tosa_spec: TosaSpecification): return (dtype, shape, dim_order) -# Class to capture arguments and turn into tensor references for TOSA OPs class TosaArg: + """Capture and normalize TOSA operator arguments. + + Use this to convert FX nodes, sequences, and numeric literals into a + consistent structure suitable for TOSA serialization. + + Attributes: + name (str): Node name when argument is a ``torch.fx.Node``; empty + otherwise. + dtype (ts.DType | None): Inferred dtype when available. + shape (tuple[int, ...] | None): Inferred shape when available. + dim_order (tuple[int, ...] | None): Dimension order, defaulting to + ``range(len(shape))``. + special (list | None): Captured list when the argument is a sequence. + number (float | int | None): Captured numeric value when provided. + tosa_spec (TosaSpecification): Active specification used for mapping. + multiple_output_name (list[str]): Output node names when node has multiple outputs; empty otherwise. 
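A small sketch of how TosaArg and TosaSpecialDtype are meant to be used. The literal/sequence dispatch inside __init__ is not shown in full in this hunk, so treat the attribute access below as illustrative rather than definitive.

from executorch.backends.arm.tosa.mapping import TosaArg, TosaSpecialDtype
from executorch.backends.arm.tosa.specification import TosaSpecification

spec = TosaSpecification.create_from_string("TOSA-1.0+INT")

# Literal arguments land on dedicated attributes rather than tensor metadata.
assert TosaArg(7, tosa_spec=spec).number == 7
assert TosaArg([0, 2, 3, 1], tosa_spec=spec).special == [0, 2, 3, 1]

# Dtypes torch cannot express (e.g. INT48) are carried on the producing fx node and
# picked up in __process_node:
#     node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48
assert TosaSpecialDtype.meta_key() == "tosa_special_dtype"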
+ """ + def __process_node(self, argument: torch.fx.Node): - self.name: str = argument.name - self.dtype, self.shape, self.dim_order = extract_tensor_meta( - argument.meta, self.tosa_spec - ) + """Parse a ``torch.fx.Node`` and populate tensor attributes. + + Args: + argument (torch.fx.Node): FX node to inspect. + + """ + suffix = argument.meta.get(TOSA_TENSOR_NAME_META, "") + self.name = argument.name + suffix + + if "val" in argument.meta: + output_dtype, self.shape, self.dim_order = extract_tensor_meta( + argument.meta, self.tosa_spec + ) + # Handle special case of types not representable in torch (i.e. i48_t) + if special_type := argument.meta.get(TosaSpecialDtype.meta_key(), None): + output_dtype = special_type.get_tosa_dtype() + + self.dtype = output_dtype + + # If all users of the node are getitems, node visitors should connect the output of this node directly to the getitem tensors. + # Add a new attribute 'multiple_output_names' instead of making 'name' a list to avoid ambiguity regarding the type of 'name'. + # Make name of the output is the first getitem since we in most cases only handle that output. + users = list(argument.users) + if len(users) > 0 and all(user.target == operator.getitem for user in users): + self.multiple_output_names: list = [user.name + suffix for user in users] + self.name = self.multiple_output_names[0] + else: + self.multiple_output_names = [] def __process_list(self, argument): + """Capture a sequence argument as ``special``. + + Args: + argument (Sequence[Any]): Sequence to store. + + """ self.special: list = list(argument) def __process_number(self, argument: float | int): + """Capture a numeric argument as ``number``. + + Args: + argument (float | int): Numeric value. + + """ self.number: float | int = argument def __init__( self, argument: Any, tosa_spec: Optional[TosaSpecification] = None ) -> None: + """Initialize the argument wrapper and populate fields. + + Args: + argument (Any): One of ``torch.fx.Node``, ``Sequence``, ``int``, + ``float``, ``torch.dtype``, or ``None``. + tosa_spec (Optional[TosaSpecification]): Active specification; + required for metadata extraction. + + Raises: + ValueError: If ``tosa_spec`` is missing or has the wrong type. + RuntimeError: If ``argument`` is of an unsupported type. + + """ if tosa_spec is None: raise ValueError("tosa_spec is None") elif not isinstance(tosa_spec, TosaSpecification): @@ -127,6 +247,12 @@ def __init__( ) def __repr__(self): + """Return a compact representation of populated attributes. + + Returns: + str: Readable list of set attributes. + + """ attrs = [] if hasattr(self, "name"): if self.name is not None: @@ -143,4 +269,6 @@ def __repr__(self): attrs.append(f"number={self.number!r}") if hasattr(self, "tosa_spec") and self.tosa_spec is not None: attrs.append(f"tosa_spec={self.tosa_spec!r}") + if hasattr(self, "multiple_output_names"): + attrs.append(f"names={self.multiple_output_names!r}") return f"{self.__class__.__name__}({', '.join(attrs)})" diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index c0f546fe50a..3fd88b330c2 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -2,10 +2,19 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Provide a partitioner for delegating subgraphs to the TOSA backend. -# pyre-unsafe +Implement logic to identify and tag regions of an ``ExportedProgram`` that can +be delegated to the TOSA backend. 
Use this module to: + +- Partition graphs based on operator support and additional checks. +- Prune trivial no-op partitions that would lower to empty TOSA graphs. +- Tag constant data and report reasons for rejected nodes. + +""" import logging +from itertools import count from typing import Callable, List, Optional, Sequence, Tuple import torch @@ -13,13 +22,14 @@ from executorch.backends.arm._passes.convert_expand_copy_to_repeat import ( calculate_multiples, ) + +from executorch.backends.arm.common.type import ensure_type from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.backends.arm.operator_support.tosa_supported_operators import ( tosa_support_factory, ) from executorch.backends.arm.tosa.backend import TOSABackend -from executorch.backends.arm.tosa.specification import get_tosa_spec -from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec from executorch.exir.backend.partitioner import ( DelegationSpec, Partitioner, @@ -27,118 +37,227 @@ ) from executorch.exir.backend.utils import tag_constant_data, WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.graph_module import get_cond_while_submodules from torch.export.exported_program import ExportedProgram -from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch.fx import GraphModule +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition from torch.fx.passes.operator_support import OperatorSupportBase logger = logging.getLogger(__name__) def is_noop_clone(node: torch.fx.node.Node) -> bool: + """Return True if the node is a no-op ``dim_order_ops._clone_dim_order``. + + Args: + node (torch.fx.Node): FX node to inspect. + + Returns: + bool: True if the node targets ``dim_order_ops._clone_dim_order.default`` + in the Edge dialect; otherwise, False. + + """ return node.target == exir_ops.edge.dim_order_ops._clone_dim_order.default -def is_noop_alias_copy(node: torch.fx.node.Node) -> bool: +def is_noop_alias_copy(node: torch.fx.Node) -> bool: + """Return True if the node is a no-op ``aten.alias_copy``. + + Args: + node (torch.fx.Node): FX node to inspect. + + Returns: + bool: True if the node targets ``aten.alias_copy.default``; otherwise, + False. + + """ return node.target == exir_ops.edge.aten.alias_copy.default def is_noop_to_dim_order_copy(node: torch.fx.node.Node) -> bool: + """Return True if node is a no-op ``dim_order_ops._to_dim_order_copy``. + + Consider the op a no-op when the output dtype equals the input's dtype. + + Args: + node (torch.fx.Node): FX node to inspect. + + Returns: + bool: True if it targets ``_to_dim_order_copy.default`` and preserves + dtype; otherwise, False. + + """ if node.target != exir_ops.edge.dim_order_ops._to_dim_order_copy.default: return False else: - return node.meta.get("dtype") == get_first_fake_tensor(node.args[0]).dtype # type: ignore[arg-type] + input_node = ensure_type(torch.fx.Node, node.args[0]) + return node.meta.get("dtype") == get_first_fake_tensor(input_node).dtype def is_noop_expand(node: torch.fx.node.Node) -> bool: + """Return True if the node is an ``expand_copy`` with all-ones multiples. + + This corresponds to a semantic no-op, since expanding by 1 along every + dimension leaves the tensor unchanged. + + Args: + node (torch.fx.Node): FX node to inspect. 
+ + Returns: + bool: True if the node targets ``aten.expand_copy.default`` and all + computed multiples are 1; otherwise, False. + + """ if node.target != exir_ops.edge.aten.expand_copy.default: return False else: - multiples = calculate_multiples(node.args) - return all(m == 1 for m in multiples) + multiples, changes_rank = calculate_multiples(node.args) + return all(m == 1 for m in multiples) and not changes_rank + + +def is_partitioned( + node: torch.fx.Node, + tag: str, +) -> bool: + """Return True if the node currently belongs to the partition ``tag``. + + Args: + node (torch.fx.Node): FX node to check. + tag (str): Delegation tag identifying the partition. + + Returns: + bool: True if the node carries the matching delegation tag. + + """ + return "delegation_tag" in node.meta and node.meta["delegation_tag"] == tag + + +def reject_partition( + reason: str, partition: Partition, reporter: WhyNoPartitionReporter +) -> None: + """Remove a proposed partition and record the rejection reason. + + Args: + reason (str): Human-readable explanation for rejection. + partition (object): Proposed partition object from the + capability partitioner. + reporter (WhyNoPartitionReporter): used to report why nodes were rejected. + + """ + for node in partition.nodes: + if "delegation_tag" in node.meta: + del node.meta["delegation_tag"] + reporter.report_reject( + node, + reason, + ) class TOSAPartitioner(Partitioner): + """Partition an exported program into TOSA-delegable subgraphs. + + Construct this partitioner for compile specs targeting TOSA. The partition + algorithm uses capability checks and optional additional operator-support + rules to tag nodes with a delegation tag per subgraph. + + """ + def __init__( self, - compile_spec: List[CompileSpec], + compile_spec: TosaCompileSpec, additional_checks: Optional[Sequence[OperatorSupportBase]] = None, ) -> None: - from executorch.backends.arm.arm_backend import is_tosa + """Initialize the TOSAPartitioner. - if not is_tosa(compile_spec): - raise RuntimeError("compile spec is not targeting TOSA") - self.delegation_spec = DelegationSpec(TOSABackend.__name__, compile_spec) - self.additional_checks = additional_checks - - def partition(self, exported_program: ExportedProgram) -> PartitionResult: # noqa - # Run the CapabilityBasedPartitioner to return the largest possible - # subgraphs containing the nodes with the tags + Args: + compile_spec (TosaCompileSpec): Parsed compile specifications for + TOSA containing the TOSA spec and original list. + additional_checks (Optional[Sequence[OperatorSupportBase]]): Extra + operator-support checks to apply when partitioning. - logger.info("TOSAPartitioner::partition") - partition_tags: dict[str, DelegationSpec] = {} + Raises: + RuntimeError: If the provided compile spec does not target TOSA. - tosa_spec = get_tosa_spec(self.delegation_spec.compile_specs) - - logger.info(f"Partitioning for {self.delegation_spec.backend_id}: {tosa_spec}") + """ + self.delegation_spec = DelegationSpec( + TOSABackend.__name__, compile_spec.to_list() + ) + self.tosa_spec = compile_spec.tosa_spec + self.additional_checks = additional_checks + self.tosa_spec = compile_spec.tosa_spec - reporter = WhyNoPartitionReporter() + def _tag_module( # noqa + self, + module: GraphModule, + containing_program: ExportedProgram, + reporter: WhyNoPartitionReporter, + tag_iterator: count | None = None, + ) -> set[str]: + """Tag nodes in a module or submodule from the containing program. 
+ + Args: + module: A GraphModule from `containing_program` to tag nodes in. + containing_program: The ExportedProgram that contains the module. + reporter: A reporter to report why nodes were rejected. + + Returns: + A set of strings with the partition tags. + + """ + tags: set[str] = set() + if tag_iterator is None: + tag_iterator = count(0) + for _, submodule, _ in get_cond_while_submodules(module): + submodule_tags = self._tag_module( + submodule, containing_program, reporter, tag_iterator + ) + if len(tags & submodule_tags) != 0: + raise RuntimeError( + "Got overlapping tags in two different modules, this shouldn't happen." + ) + tags = tags | submodule_tags operator_support = tosa_support_factory( - tosa_spec, exported_program, reporter, self.additional_checks + self.tosa_spec, containing_program, reporter, self.additional_checks ) capability_partitioner = CapabilityBasedPartitioner( - exported_program.graph_module, + module, operator_support, allows_single_node_partition=True, ) partition_list = capability_partitioner.propose_partitions() - def reject_partition(reason: str, partition, tag) -> None: - for node in partition.nodes: - if "delegation_tag" in node.meta: - del node.meta["delegation_tag"] - reporter.report_reject( - node, - reason, - ) - partition_tags.pop(tag, None) - for partition in partition_list: - tag = f"tag{partition.id}" - - def is_partitioned(node: torch.fx.Node, tag=tag) -> bool: - return ( - "delegation_tag" in node.meta and node.meta["delegation_tag"] == tag - ) + tag = f"tag{next(tag_iterator)}" + tags.add(tag) for node in partition.nodes: node.meta["delegation_tag"] = tag - partition_tags[tag] = self.delegation_spec - # De-tag outmost q-nodes upwards and dq-nodes downwards. - # De-tag if at least one input/ output is not part of partition. - for node in exported_program.graph_module.graph.nodes: - if not is_partitioned(node): + # De-tag outermost q-nodes upwards and dq-nodes downwards. + # De-tag if at least one input/output is not part of the partition. + for node in module.graph.nodes: + if not is_partitioned(node, tag): continue if node.target in Q_OPS: for input in node.all_input_nodes: - if not is_partitioned(input): + if not is_partitioned(input, tag): del node.meta["delegation_tag"] break continue if node.target in DQ_OPS: for user in node.users: - if not is_partitioned(user): + if not is_partitioned(user, tag): del node.meta["delegation_tag"] break continue - if tosa_spec.support_float(): + if self.tosa_spec.support_float(): continue - if is_partitioned(node): + if is_partitioned(node, tag): for input in node.all_input_nodes: - if is_partitioned(input): + if is_partitioned(input, tag): continue if get_first_fake_tensor(input).dtype.is_floating_point: reporter.report_reject( @@ -161,62 +280,177 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool: reject_partition( "Partition contained only ops which are removed in the TOSA lowering, leading to an empty partition.", partition, - tag, + reporter, ) + tags.remove(tag) + return tags + + def partition(self, exported_program: ExportedProgram) -> PartitionResult: + """Partition the program and tag TOSA-compatible subgraphs. + + Run the FX capability-based partitioner to propose subgraphs, then + refine tags by removing boundary-only quantize/dequantize nodes and by + rejecting partitions that would lower to no-ops. Emit a detailed report + of rejected nodes and their reasons. + + Args: + exported_program (ExportedProgram): Program to analyze and + partition. 
+ + Returns: + PartitionResult: The input program with nodes tagged for delegation + and a mapping of partition tags to delegation specs. + + """ + logger.info("TOSAPartitioner::partition") + logger.info( + f"Partitioning for {self.delegation_spec.backend_id}: {self.tosa_spec}" + ) + + reporter = WhyNoPartitionReporter() + tags = self._tag_module( + exported_program.graph_module, exported_program, reporter + ) + partition_tags = {tag: self.delegation_spec for tag in tags} tag_constant_data(exported_program) - logger.info(f"The following nodes were rejected for {tosa_spec}:") + logger.info(f"The following nodes were rejected for {self.tosa_spec}:") logger.info("\n" + reporter.get_table_report()) logger.info("(Placeholders and outputs are not included in this list)") return PartitionResult( tagged_exported_program=exported_program, partition_tags=partition_tags ) - def ops_to_not_decompose( + def ops_to_not_decompose( # noqa: C901 self, ep: ExportedProgram, ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: - ops_to_not_decompose_if_quant_op = [ + """Return operators and a filter that should not be decomposed. + + Provide a base set of ops to preserve as-is and a predicate that keeps + certain activations whole when surrounded by quantize/dequantize ops in + a quantized graph. This helps downstream TOSA lowering and delegation. + + Args: + ep (ExportedProgram): Program used to infer target-specific policy. + + Returns: + Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: + A list of op overloads to keep intact, and an optional filter + function that returns True when an op should not be decomposed. + + """ + ops_to_not_decompose_if_quant_op = { + torch.ops.aten.eye.default, torch.ops.aten.hardsigmoid.default, torch.ops.aten.hardswish.default, - ] + torch.ops.aten.linear.default, + torch.ops.aten.linspace.default, + } + ops_to_not_decompose_if_fp = { + torch.ops.aten.eye.default, + torch.ops.aten.logit.default, + torch.ops.aten.linear.default, + torch.ops.aten.linspace.default, + } + ops_to_not_decompose_always = { + torch.ops.aten.logit.default, + } + ops_to_not_decompose_if_integer = { + torch.ops.aten.eye.default, + torch.ops.aten.linspace.default, + } def filter_fn(node: torch.fx.Node) -> bool: - # This function filters for operators to not decompose where: - # - It's target is in ops_to_not_decompose_if_quant_op list. - # - All it's inputs/outputs are quantize operators. - dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default - q = torch.ops.quantized_decomposed.quantize_per_tensor.default + """Filter function applied to ops in 'ops_to_not_decompose'. + Returns True if the op should not be decomposed. + If this function returns True, the partitioner *must* accept the node, or the lowering fails. + + Args: + node (torch.fx.Node): FX node to evaluate. + + Returns: + bool: True to keep the op intact; otherwise, False. 
+ + """ + if ( + self.tosa_spec.support_float() + and node.target in ops_to_not_decompose_if_fp + ): + return True + + dq = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + ) + q = ( + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_channel.default, + ) if node.target in ops_to_not_decompose_if_quant_op: # Assume we should not decompose the operator (it is quantized) - should_not_decompose = True + correct_output_quant = True + correct_input_quant = True input_nodes = node.all_input_nodes - ouput_nodes = node.users + output_nodes = node.users for inp in input_nodes: - if inp.target != dq: - should_not_decompose = False - - for out in ouput_nodes: - if out.target != q: - should_not_decompose = False - - return should_not_decompose - - # Be default, do not decompose the operator - return True - - ops_to_not_decompose = [ - torch.ops.aten.linear.default, - torch.ops.aten.eye.default, - torch.ops.aten.linspace.default, - torch.ops.aten.logit.default, - ] + ops_to_not_decompose_if_quant_op + if inp.target not in dq: + correct_input_quant = False + + for out in output_nodes: + if out.target not in q: + correct_output_quant = False + # In some cases, a linear is quantized together with its activation. + if ( + node.target == torch.ops.aten.linear.default + and len(output_nodes) == 1 + and list(output_nodes)[0].target + in (torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default) + ): + correct_output_quant = True + + if correct_input_quant and correct_output_quant: + return True + + if node.target in ops_to_not_decompose_if_integer: + # We only want to tag nodes as do_not_decompose if we are sure that + # we can partition them. We partition them if one or more of the + # following is true: + # 1. The node outputs an integer type. + # 2. All users cast the output to an integer type. + + dtype = get_first_fake_tensor(node).dtype + if not dtype.is_floating_point and not dtype.is_complex: + return True + + output_nodes = node.users + for user in output_nodes: + if user.target != torch.ops.aten.to.dtype: + return False + else: + cast_dtype = get_first_fake_tensor(user).dtype + if cast_dtype.is_complex or cast_dtype.is_floating_point: + return False + return True + + if node.target in ops_to_not_decompose_if_fp: + if self.tosa_spec.support_float(): + return True + if node.target in ops_to_not_decompose_always: + return True + return False + + ops_to_not_decompose = list( + ops_to_not_decompose_always + | ops_to_not_decompose_if_quant_op + | ops_to_not_decompose_if_fp + | ops_to_not_decompose_if_integer + ) - tosa_spec = get_tosa_spec(self.delegation_spec.compile_specs) - if not tosa_spec.is_U55_subset: + if not self.tosa_spec.is_U55_subset: # Tosa operator "RESIZE" is not supported on U55. Since upsample_bilinear2d # and upsample_nearest2d decompose into that it will not be possible to # delegate those operators on U55. If we have said here to not decompose diff --git a/backends/arm/tosa/quant_utils.py b/backends/arm/tosa/quant_utils.py deleted file mode 100644 index 86e8e5bad8b..00000000000 --- a/backends/arm/tosa/quant_utils.py +++ /dev/null @@ -1,445 +0,0 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-unsafe - -# Utility functions for TOSA quantized lowerings - -import math - -from typing import Any, Tuple - -import serializer.tosa_serializer as ts # type: ignore - -from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_input_qparams, - get_output_qparams, -) - -from executorch.backends.arm.tosa.mapping import TosaArg -from torch.fx import Node -from tosa.RoundingMode import RoundingMode # type: ignore - - -def insert_rescale_ops_to_int32_maxscale( - tosa_graph: Any, inputs: list[TosaArg], node: Node, tosa_spec=None -) -> tuple[list[Any], float]: - """For ADD and SUB, we rescale to int32 using a different common scale(2*max(left scale,right scale)) - compared to all the other cases. We also multiply the left and right scales by 1<<20 giving us extra precision - for the computation without overflowing. - - Returns a list of the rescaled nodes and the scale factor used, - needed by insert_rescale_op_to_int8. - """ - - if len(inputs) > 2: - raise ValueError("More than two inputs not supported") - - tensors = inputs.copy() - # Reshape tensor according to TOSA dim order - for tensor in tensors: - dim_order = tensor.dim_order - tensor.shape = [tensor.shape[i] for i in dim_order] - - input_qparams = get_input_qparams(node) - lhs_qparams, rhs_qparams = input_qparams.values() - lhs_scale = lhs_qparams.get_scale_per_tensor() - rhs_scale = rhs_qparams.get_scale_per_tensor() - # Common scale for the two numbers - max_scale_2x = 2 * max(lhs_scale, rhs_scale) - SHIFT_INT8 = 20 - # We are adding two int8 numbers. If the zero point is non-null, the result will be in the range [-255;255], therefore we need 9 bits for the result. - # We have a 32-bit accumulator, so we can shift to the left by 20 bits and not overflow. In reality, because we divide by the 2*max(lhs_scale,rhs_scale) - # we are shifting to the left by 19. - lhs_factor = (1 << SHIFT_INT8) * lhs_scale / max_scale_2x - rhs_factor = (1 << SHIFT_INT8) * rhs_scale / max_scale_2x - rescaled_lhs = build_rescale_to_int32( - tosa_graph, - tensors[0], - lhs_qparams.get_zp_per_tensor(), - lhs_factor, - tosa_spec=tosa_spec, - ) - rescaled_rhs = build_rescale_to_int32( - tosa_graph, - tensors[1], - rhs_qparams.get_zp_per_tensor(), - rhs_factor, - tosa_spec=tosa_spec, - ) - out_qparam = get_output_qparams(node)[0] - out_scale = out_qparam.get_scale_per_tensor() - back_scale = max_scale_2x / (out_scale * (1 << SHIFT_INT8)) - - return [rescaled_lhs, rescaled_rhs], back_scale - - -def insert_rescale_ops_to_int32( - tosa_graph: Any, - inputs: list[TosaArg], - node: Node, - tosa_spec=None, -) -> tuple[list[Any], float]: - """Rescales all 'nodes' to int32, adding suitable RESCALE ops to 'tosa_graph'. - The scales are adjusted using the smallest scale of all 'nodes'. - - Returns a list of the rescaled nodes and the scale factor used, - needed by insert_rescale_op_to_int8. - - This functions is used in serialization to TOSA for target ops that are - handled by the DQ/D folding pass, which stores the quantization parameters - in the node meta dict. 
- """ - - from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_input_qparams, - ) - - tensors = inputs.copy() - - # Reshape tensor according to TOSA dim order - for tensor in tensors: - dim_order = tensor.dim_order - tensor.shape = [tensor.shape[i] for i in dim_order] - - input_qparams = get_input_qparams(node) - qargs = input_qparams.values() - - # Scale the int8 quantized input to a common scale in the integer - # domain - min_scale = min([qarg.get_scale_per_tensor() for qarg in qargs]) - scales = [qarg.get_scale_per_tensor() / min_scale for qarg in qargs] - - rescaled_nodes: list[Any] = [] - for tensor, qarg, scale in zip(tensors, qargs, scales): - rescaled_nodes.append( - build_rescale_to_int32( - tosa_graph, tensor, qarg.get_zp_per_tensor(), scale, tosa_spec=tosa_spec - ) - ) - return rescaled_nodes, min_scale - - -def insert_rescale_op_to_int8( - tosa_graph: Any, - last_tensor: TosaArg, - scale: float, - node: Node, - compute_rescale=True, - tosa_spec=None, -) -> None: - """Rescales the node back to int8, adding a suitable RESCALE op to 'tosa_graph'. - Parameters: - node: The original node that is being handled by the rescales. - last_tensor:the tosa tensor to rescale back. - scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32' - compute_rescale: boolean indicating whether we need to divide the output scale by the original scale. - tosa_graph: the tosa_graph to manipulate. - - This functions is used in serialization to TOSA for target ops that are - handled by the DQ/D folding pass, which stores the quantization parameters - in the node meta dict. - """ - _insert_rescale_op_to_dtype( - tosa_graph, last_tensor, scale, node, ts.DType.INT8, compute_rescale, tosa_spec - ) - - -def insert_rescale_op_to_int16( - tosa_graph: Any, - last_tensor: TosaArg, - scale: float, - node: Node, - compute_rescale=True, - tosa_spec=None, -) -> None: - """Rescales the node back to int16, adding a suitable RESCALE op to 'tosa_graph'. - Parameters: - node: The original node that is being handled by the rescales. - last_tensor:the tosa tensor to rescale back. - scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32' - compute_rescale: boolean indicating whether we need to divide the output scale by the original scale. - tosa_graph: the tosa_graph to manipulate. - - This functions is used in serialization to TOSA for target ops that are - handled by the DQ/D folding pass, which stores the quantization parameters - in the node meta dict. - """ - _insert_rescale_op_to_dtype( - tosa_graph, last_tensor, scale, node, ts.DType.INT16, compute_rescale, tosa_spec - ) - - -def _insert_rescale_op_to_dtype( - tosa_graph: Any, - last_tensor: TosaArg, - scale: float, - node: Node, - output_dtype: Any, - compute_rescale=True, - tosa_spec=None, -) -> None: - """Common implementation for rescaling nodes back to a specific dtype. - Parameters: - node: The original node that is being handled by the rescales. - last_tensor:the tosa tensor to rescale back. - scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32' - output_dtype: The target dtype (ts.DType.INT8 or ts.DType.INT16) - compute_rescale: boolean indicating whether we need to divide the output scale by the original scale. - tosa_graph: the tosa_graph to manipulate. 
- - This functions is used in serialization to TOSA for target ops that are - handled by the DQ/D folding pass, which stores the quantization parameters - in the node meta dict. - """ - from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_output_qparams, - ) - - output_qparams = get_output_qparams(node) - if len(output_qparams) != 1: - raise ValueError("More than one output not supported") - - qargs_out = output_qparams[0] - if compute_rescale: - output_rescale_scale = scale / qargs_out.get_scale_per_tensor() - else: - output_rescale_scale = scale - - # Rescale Back to the specified dtype - build_rescale_from_int32_to_dtype( - tosa_graph, - last_tensor, - node.name, - qargs_out.get_zp_per_tensor(), - output_rescale_scale, - output_dtype, - tosa_spec=tosa_spec, - ) - - -# TOSA uses the RESCALE operation to scale between values with differing precision. -# The RESCALE operator is defined using an integer multiply, add, and shift. -# This utility function is for calculating the multiplier and shift given a scale. -# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling -def compute_multiplier_and_shift( - scales: list[float], scaleWidth: int = 32 -) -> Tuple[list[int], list[int]]: - if scaleWidth == 16: - offset = 15 - elif scaleWidth == 32: - offset = 31 - else: - raise ValueError( - f"Unsupported scale width: {scaleWidth}, only 16 and 32 are valid values." - ) - - multipliers = [] - shifts = [] - for scale in scales: - mantissa, exponent = math.frexp(scale) - shift = exponent - - const_2_power_15_or_31 = 1 << offset - shifted_mantissa = round(mantissa * const_2_power_15_or_31) - - assert shifted_mantissa <= const_2_power_15_or_31 - - if shifted_mantissa == const_2_power_15_or_31: - shifted_mantissa = shifted_mantissa // 2 - shift += 1 - - # TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits. - shift = offset - shift - - # INT32_MAX, 2^31 - 1 - assert shifted_mantissa <= (const_2_power_15_or_31 - 1) - - multiplier = shifted_mantissa - - if shift > 62: - multiplier = multiplier >> min(31, shift - 62) - shift = 62 - multipliers.append(multiplier) - shifts.append(shift) - return multipliers, shifts - - -# For TOSA spec v1.0 RESCALE operator requires multiplier, shifts, input_zp and output_zp to be -# const inputs. Create constant operators from the data already initialized. 
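The deleted compute_multiplier_and_shift helper implements TOSA's fixed-point rescaling: a real-valued scale is approximated as multiplier * 2**-shift. A standalone sketch of the same decomposition using only the standard library, so the arithmetic is easy to check (the >62-shift clamping from the original is omitted):

import math


def scale_to_multiplier_shift(scale: float, offset: int = 31) -> tuple[int, int]:
    """Approximate ``scale`` as ``multiplier * 2**-shift`` (32-bit variant)."""
    mantissa, exponent = math.frexp(scale)        # scale = mantissa * 2**exponent, 0.5 <= mantissa < 1
    multiplier = round(mantissa * (1 << offset))  # Q31 fixed-point mantissa
    shift = offset - exponent
    if multiplier == (1 << offset):               # rounding overflowed; renormalize
        multiplier //= 2
        shift -= 1
    return multiplier, shift


m, s = scale_to_multiplier_shift(2.0 / 255.0)     # a typical int8 dequantization scale
assert abs(m / (1 << s) - 2.0 / 255.0) < 1e-9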
-def create_const_ops_for_rescale( - tosa_fb, - scale_32, - input_dtype, - node_name, - multipliers, - shifts, - input_zp, - output_zp, - output_dtype, - ts, -): - - multipliers = tosa_fb.addConst( - (len(multipliers),), - ts.DType.INT32 if scale_32 else ts.DType.INT16, - multipliers, - name=node_name + "_multipliers", - ) - shifts = tosa_fb.addConst( - (len(shifts),), ts.DType.INT8, shifts, name=node_name + "_shifts" - ) - input_zp = tosa_fb.addConst( - [1], input_dtype, input_zp, name=node_name + "_input_zp" - ) - output_zp = tosa_fb.addConst( - [1], output_dtype, output_zp, name=node_name + "_output_zp" - ) - - return [multipliers.name, shifts.name, input_zp.name, output_zp.name] - - -def build_rescale( - tosa_fb: Any, - scale: list[float], - input_node: Any, - output_name: str, - output_type: Any, - input_zp: list[int], - output_zp: list[int], - rounding_mode: RoundingMode, - per_channel=False, -): - import serializer.tosa_serializer as ts # type: ignore - import tosa.Op as TosaOp # type: ignore - - scaleWidth = 32 - is_scale32 = True - multipliers, shifts = compute_multiplier_and_shift(scale, scaleWidth) - rescale_inputs = create_const_ops_for_rescale( - tosa_fb, - is_scale32, - input_node.dtype, - output_name, - multipliers, - shifts, - input_zp, - output_zp, - output_type, - ts, - ) - attr_rescale = ts.TosaSerializerAttribute() - attr_rescale.RescaleAttribute( - scale32=is_scale32, - rounding_mode=rounding_mode, - per_channel=per_channel, - input_unsigned=False, - output_unsigned=False, - ) - - tosa_fb.addOperator( - TosaOp.Op().RESCALE, - [input_node.name, *rescale_inputs], - [output_name], - attr_rescale, - ) - - return - - -def build_rescale_to_int32( - tosa_fb: Any, - input_arg: TosaArg, - input_zp: int, - rescale_scale: float, - is_scale32: bool = True, - is_double_round: bool = False, - per_channel: bool = False, - tosa_spec=None, -) -> Any: - input_A_rescaled_to_int32 = None - - input_A_rescaled_to_int32 = tosa_fb.addIntermediate(input_arg.shape, ts.DType.INT32) - - build_rescale( - tosa_fb, - [rescale_scale], - input_arg, - input_A_rescaled_to_int32.name, - ts.DType.INT32, - [input_zp], - [0], - rounding_mode=RoundingMode.SINGLE_ROUND, - ) # type: ignore[call-arg] - - return input_A_rescaled_to_int32 - - -def build_rescale_from_int32( - tosa_fb: Any, - input_node: TosaArg, - output_name: str, - output_zp: int, - rescale_scale: float, - is_scale32: bool = True, - is_double_round: bool = False, - per_channel: bool = False, - tosa_spec=None, -) -> None: - # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs - # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale - build_rescale_from_int32_to_dtype( - tosa_fb, - input_node, - output_name, - output_zp, - rescale_scale, - ts.DType.INT8, - is_scale32, - is_double_round, - per_channel, - tosa_spec, - ) - - return - - -def build_rescale_from_int32_to_dtype( - tosa_fb: Any, - input_node: TosaArg, - output_name: str, - output_zp: int, - rescale_scale: float, - output_dtype: Any, - is_scale32: bool = True, - is_double_round: bool = False, - per_channel: bool = False, - tosa_spec=None, -) -> None: - """Common implementation for rescaling from INT32 to a specific dtype (INT8 or INT16). 
- - Parameters: - tosa_fb: The TOSA serializer - input_node: Input tensor (should be INT32) - output_name: Name for the output tensor - output_zp: Output zero point - rescale_scale: Rescaling factor - output_dtype: Target dtype (ts.DType.INT8 or ts.DType.INT16) - Other parameters: Standard rescale parameters - """ - # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs - # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale - build_rescale( - tosa_fb, - [rescale_scale], - input_node, - output_name=output_name, - output_type=output_dtype, - input_zp=[0], - output_zp=[output_zp], - rounding_mode=RoundingMode.SINGLE_ROUND, - ) # type: ignore[call-arg] - - return diff --git a/backends/arm/tosa/schemas/tosa_1.0.fbs b/backends/arm/tosa/schemas/tosa_1.0.fbs index acd376daa9f..e58682da898 100644 --- a/backends/arm/tosa/schemas/tosa_1.0.fbs +++ b/backends/arm/tosa/schemas/tosa_1.0.fbs @@ -510,13 +510,31 @@ table TosaTensor { variable: bool; // is this a variable tensor is_unranked: bool; // whether this is an unranked tensor variable_name:string; // name for variable attribute + + // In a model that is larger than 2GB, then tensors instead uses the following + // attributes to find stored data, which is outside of flatbuffers + // the offset is calculated relative to the beginning of the file and is only + // valid if > 1. + offset: ulong; + size: ulong; +} + +table TosaShape { + name: string; // name of the shape + rank: uint32; // rank of the shape + data: [ubyte] (force_align: 8); // raw data array if it's a constant shape +} + +table OpLocation { + text: string; // Opaque string, interpretted by user } table TosaOperator { op:Op; // operator enum attribute:Attribute; // union structure. operator attribute - inputs:[string]; // list of input tensor names - outputs:[string]; // list of output tensor names + inputs:[string]; // list of input tensor or shape names + outputs:[string]; // list of output tensor or shape names + location: OpLocation; // location of this Op in mlir } table TosaBasicBlock { @@ -525,6 +543,7 @@ table TosaBasicBlock { tensors:[TosaTensor]; // tensors array inputs:[string]; // name of graph inputs outputs:[string]; // name of graph outputs + shapes:[TosaShape]; // shapes array } table TosaRegion { @@ -537,4 +556,4 @@ table TosaGraph { regions:[TosaRegion]; // regions array } -root_type TosaGraph; +root_type TosaGraph; \ No newline at end of file diff --git a/backends/arm/tosa/specification.py b/backends/arm/tosa/specification.py index 92b68955cdd..b5b9613c208 100644 --- a/backends/arm/tosa/specification.py +++ b/backends/arm/tosa/specification.py @@ -3,57 +3,130 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +"""Provide TOSA specification parsing and context utilities. -# -# Main implementation of AoT flow to partition and preprocess for Arm target -# backends. Converts via TOSA as an intermediate form supported by AoT and -# JIT compiler flows. -# +Use these helpers to parse and validate TOSA profile/extension strings and to +manage a lowering-time context for the active specification. 
+ +""" import contextvars import re -from typing import List - -from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-not-found] - CompileSpec, -) +from typing import Dict, Generic, List, Set, TypeVar from packaging.version import Version +T = TypeVar("T") + + +class TosaSpecMapping(Generic[T]): + def __init__(self): + self._mapping: Dict[TosaSpecification, List[T]] = {} + + def add(self, spec: "TosaSpecification", value: T) -> None: + """ + Adds a value to the mapping for the given TOSA specification. + The specification is normalized to its canonical form, which means that + only the version and profiles are considered, without extensions. + This allows for grouping of values under the same TOSA specification + regardless of the extensions they may have. + """ + + if spec.is_U55_subset or spec.extensions: + raise ValueError( + f"TosaSpecMapping does not support extensions, got: {spec}" + ) + + if isinstance(spec, Tosa_1_00) and len(spec.profiles) > 1: + raise ValueError( + f"TosaSpecMapping does not support multiple profiles, got: {spec}" + ) + + norm_spec = spec._canonical_key() + if norm_spec not in self._mapping: + self._mapping[norm_spec] = [] + self._mapping[norm_spec].append(value) + + @staticmethod + def _get_base_specs(spec: "TosaSpecification") -> List["TosaSpecification"]: + # Handles combined TOSA-1.0+FP+INT, etc. + if isinstance(spec, Tosa_1_00): + profiles: Set[str] = set(spec.profiles) + if profiles == {"FP", "INT"}: + version = spec.version + return [ + TosaSpecification.create_from_string(f"TOSA-{version}+FP"), + TosaSpecification.create_from_string(f"TOSA-{version}+INT"), + ] + return [spec] + + def get(self, spec: "TosaSpecification") -> List[T]: + """ + Returns a list of values associated with the given TOSA specification. + The specification is normalized to its canonical form, which means that + only the version and profiles are considered, without extensions. + """ + + base_specs = self._get_base_specs(spec) + result: List[T] = [] + for base in base_specs: + norm_base = base._canonical_key() + result.extend(self._mapping.get(norm_base, [])) + if len(result) == 0: + raise KeyError(f"No values found for TOSA specification: {spec}") + + return result # Do not deduplicate with set(), as values may be unhashable + class TosaSpecification: - """ - This class implements a representation of TOSA specification - (https://www.mlplatform.org/tosa/tosa_spec.html) with a version, a profile - (with extension) and a level (8k). - For 1.00 releases the profile is INT or FP, and the extensions are for - INT: int16, int4, var, cf - FP: bf16, fp8e4m3, fp8e5m2, fft, var, cf + """Represent a TOSA specification. + + A specification includes a semantic version, one or more profiles, and + optional extensions and levels (for example ``8k``). + The encoded form follows ``TOSA-..+[+][+...]``. + Profiles use uppercase (for example ``INT``, ``FP``); levels and extensions + use lowercase. - The TOSA specification is encoded in the string represenatation - TOSA-major.minor.patch+profile[+level][+extensions] + Attributes: + version (Version): Parsed TOSA semantic version. + is_U55_subset (bool): True if the ``u55`` subset is requested. - Profiles are uppercase letters and extensions and level is lowercase. """ version: Version is_U55_subset: bool + extensions: List[str] def support_integer(self) -> bool: - """ - Returns true if any integer operations are supported for the specification. 
- """ + """Return True if integer operations are supported.""" raise NotImplementedError def support_float(self) -> bool: - """ - Returns true if any float operations are supported for the specification. + """Return True if floating-point operations are supported.""" + raise NotImplementedError + + def support_extension(self, extension: str) -> bool: + """Return True if an extension is supported and enabled. + + Args: + extension (str): Extension name (for example ``int4``, ``bf16``). + + Returns: + bool: True if the extension is valid for the active profiles and selected. + """ raise NotImplementedError def __init__(self, version: Version, extras: List[str]): + """Initialize the base specification. + + Args: + version (Version): Parsed TOSA semantic version. + extras (List[str]): Remaining tokens such as profiles, levels, and extensions. + + """ self.version = version + self.extensions = [] self.is_U55_subset = "u55" in extras if self.is_U55_subset: @@ -61,11 +134,20 @@ def __init__(self, version: Version, extras: List[str]): @staticmethod def create_from_string(repr: str) -> "TosaSpecification": - """ - Creates a TOSA specification class from a string representation: - TOSA-1.00.0+INT+FP+int4+cf - """ + """Create a specification from a standard string format. + + Example: ``TOSA-1.00.0+INT+FP+int4+cf``. + + Args: + repr (str): Standard representation string. + + Returns: + TosaSpecification: Parsed specification instance. + Raises: + ValueError: If the representation is malformed or version is unsupported. + + """ pattern = r"^(TOSA)-([\d.]+)\+(.+)$" match = re.match(pattern, repr) if match: @@ -82,8 +164,26 @@ def create_from_string(repr: str) -> "TosaSpecification": raise ValueError(f"Failed to parse TOSA specification representation: {repr}") + def _canonical_key(self) -> "TosaSpecification": + """ + Returns a new TosaSpecification instance with only version and profiles (no extensions). + """ + raise NotImplementedError + class Tosa_1_00(TosaSpecification): + """Provide TOSA 1.00 profile and extension semantics. + + This variant validates profiles (``INT``, ``FP``), the optional ``8k`` level, + and allowed extensions based on the selected profiles. + + Attributes: + profiles (List[str]): Selected profiles, e.g., ``["INT"]`` or ``["INT", "FP"]``. + level_8k (bool): True if the ``8k`` level is enabled. + extensions (List[str]): Enabled extensions valid for the chosen profiles. + + """ + profiles: List[str] level_8k: bool extensions: List[str] @@ -95,6 +195,16 @@ class Tosa_1_00(TosaSpecification): } def __init__(self, version: Version, extras: List[str]): + """Initialize the 1.00 specification and validate extras. + + Args: + version (Version): Semantic version (major=1, minor=0). + extras (List[str]): Tokens including profiles, level, and extensions. + + Raises: + ValueError: If no/too many profiles are provided or extensions are invalid. + + """ super().__init__(version, extras) # Check that we have at least one profile in the extensions list @@ -133,12 +243,20 @@ def __init__(self, version: Version, extras: List[str]): self.extensions = extras def _get_profiles_string(self) -> str: + """Return the ``+``-joined profile segment (e.g., ``+INT+FP``).""" return "".join(["+" + p for p in self.profiles]) def _get_extensions_string(self) -> str: + """Return the ``+``-joined extensions segment (e.g., ``+int4+cf``).""" return "".join(["+" + e for e in self.extensions]) def __repr__(self): + """Return the standard specification string format. 
+ + Returns: + str: Standard form like ``TOSA-1.00.0+INT+8k+int4``. + + """ extensions = self._get_extensions_string() if self.level_8k: extensions += "+8k" @@ -147,9 +265,24 @@ def __repr__(self): return f"TOSA-{self.version}{self._get_profiles_string()}{extensions}" def __hash__(self) -> int: + """Return a stable hash for use in sets and dict keys. + + Returns: + int: Hash value derived from version and profiles. + + """ return hash(str(self.version) + self._get_profiles_string()) def __eq__(self, other: object) -> bool: + """Return True if another instance represents the same spec. + + Args: + other (object): Object to compare. + + Returns: + bool: True if versions and profiles match. + + """ if isinstance(other, Tosa_1_00): return (self.version == other.version) and ( self._get_profiles_string() == other._get_profiles_string() @@ -157,52 +290,117 @@ def __eq__(self, other: object) -> bool: return False def support_integer(self): + """Return True if the ``INT`` profile is present.""" return "INT" in self.profiles def support_float(self): + """Return True if the ``FP`` profile is present.""" return "FP" in self.profiles def support_extension(self, extension: str) -> bool: + """Return True if an extension is supported and enabled. + + Args: + extension (str): Extension name (for example ``int4``, ``bf16``). + + Returns: + bool: True if the extension is valid for the active profiles and selected. + + """ for p in self.profiles: if extension in self.valid_extensions[p] and extension in self.extensions: return True return False + def _canonical_key(self) -> "Tosa_1_00": + """ + Returns a new Tosa_1_00 instance with only major.minor version and profiles (no extensions). + Patch version is set to zero for normalization. + """ + from packaging.version import Version + + norm_version = Version(f"{self.version.major}.{self.version.minor}.0") + return Tosa_1_00(norm_version, self.profiles.copy()) + class TosaLoweringContext: - """ - A context manager to handle the TOSA specific aspects of the lowering process. - For now it only handles the TOSA specification context, but it can be extended - to include other policies or configurations. + """Manage the TOSA specification context for lowering. + + For now, only the active ``TosaSpecification`` is tracked, but this can be + extended to carry additional lowering policies or configuration. + + Attributes: + tosa_spec_var (contextvars.ContextVar): Context variable storing the active spec. + spec (TosaSpecification): Specification passed to the context manager. + """ # Define a context variable for the spec tosa_spec_var: contextvars.ContextVar = contextvars.ContextVar("tosa_spec") def __init__(self, spec: TosaSpecification): + """Initialize the lowering context with a specification. + + Args: + spec (TosaSpecification): Active specification to put into context. + + """ self.spec = spec def __enter__(self): + """Set the context variable and return self. + + Returns: + TosaLoweringContext: This context manager instance. + + """ # Set the spec in the context variable and store the token for later reset self.token = TosaLoweringContext.tosa_spec_var.set(self.spec) return self def __exit__(self, exc_type, exc_value, traceback): + """Reset the context variable to its previous state. + + Args: + exc_type (type | None): Exception type, if any. + exc_value (BaseException | None): Exception instance, if any. + traceback (TracebackType | None): Traceback, if any. 
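The specification string format and the lowering context behave as documented above; a minimal sketch using the example string from the docstring:

from executorch.backends.arm.tosa.specification import (
    TosaLoweringContext,
    TosaSpecification,
    get_context_spec,
)

spec = TosaSpecification.create_from_string("TOSA-1.00.0+INT+FP+int4+cf")
assert spec.support_integer() and spec.support_float()
assert spec.support_extension("int4")
assert not spec.support_extension("bf16")  # valid for FP, but not selected here

with TosaLoweringContext(spec):
    assert get_context_spec() is spec
# Outside the context manager, get_context_spec() raises RuntimeError.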
+ + """ # Reset the context variable to its previous state TosaLoweringContext.tosa_spec_var.reset(self.token) -# A helper function to retrieve the current spec anywhere in your code def get_context_spec() -> TosaSpecification: + """Get the current ``TosaSpecification`` from the lowering context. + + Returns: + TosaSpecification: Active specification retrieved from the context var. + + Raises: + RuntimeError: If called outside a ``TosaLoweringContext``. + + """ try: return TosaLoweringContext.tosa_spec_var.get() except LookupError: raise RuntimeError("Function must be executed within a TosaLoweringContext") -def get_tosa_spec(compile_spec: List[CompileSpec]) -> TosaSpecification: - for spec in compile_spec: - if spec.key == "tosa_spec": - return TosaSpecification.create_from_string(spec.value.decode()) - raise ValueError("Could not find TOSA version in CompileSpec") +def tosa_spec_in_set(spec: TosaSpecification, specs: Set[TosaSpecification]) -> bool: + """Check if a specification matches any in a set, considering base specs. + + Args: + spec (TosaSpecification): Specification to check. + specs (Set[TosaSpecification]): Set of specifications to match against. + + Returns: + bool: True if a match is found, False otherwise. + + """ + base_specs = TosaSpecMapping._get_base_specs(spec) + for base in base_specs: + if base in specs: + return True + return False diff --git a/backends/arm/tosa/utils.py b/backends/arm/tosa/utils.py index 15c8612d33f..df77153e29f 100644 --- a/backends/arm/tosa/utils.py +++ b/backends/arm/tosa/utils.py @@ -2,18 +2,17 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - -# pyre-unsafe +"""Utility helpers for building TOSA graphs in the Arm backend.""" import logging from typing import Any import numpy as np -import serializer.tosa_serializer as ts # type: ignore import sympy # type: ignore import torch +import tosa_serializer as ts from executorch.backends.arm.tosa.mapping import extract_tensor_meta from executorch.backends.arm.tosa.specification import TosaSpecification @@ -27,19 +26,21 @@ def are_fake_tensors_broadcastable( fake_tensors: list[FakeTensor], ) -> tuple[bool, list[int]]: - """ - Determines whether a list of FakeTensors can be broadcast together. + """Determine whether the fake tensors share a broadcastable shape. + Args: - fake_tensors (list[FakeTensor]): List of 2 or more FakeTensors - who's shapes to evaluate + fake_tensors (list[FakeTensor]): Fake tensors whose shapes should + be validated for broadcasting. Returns: - tuple[bool, list[int]]: First element is whether the shapes are - broadcastable. Second element is the common shape if compatible. - If not, empty list. + tuple[bool, list[int]]: Tuple where the first element indicates + whether broadcasting is possible and the second element contains + the broadcast shape. The shape list is empty when broadcasting + fails. Raises: - RuntimeError: If less than 2 tensors are passed in. + RuntimeError: Raised when fewer than two tensors are supplied. + """ if len(fake_tensors) < 1: raise RuntimeError(f"Expected 2 or more tensors got {len(fake_tensors)}") @@ -66,26 +67,27 @@ def are_fake_tensors_broadcastable( def broadcast_tensors( tosa_fb, nodes: list[Node], tosa_spec: TosaSpecification ) -> list[Any]: - """ - Given a list of nodes it determines the common shape they broadcast to - and adds the necessary reshape and tile operations to perform the broadcast. 
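For reference, a minimal usage sketch of the specification and context helpers from backends/arm/tosa/specification.py above; it assumes int4 is a valid extension for the chosen profiles and uses only the calls documented in the docstrings:

from executorch.backends.arm.tosa.specification import (
    TosaLoweringContext,
    TosaSpecification,
    get_context_spec,
)

# Parse the standard string form and query capabilities.
spec = TosaSpecification.create_from_string("TOSA-1.00.0+INT+FP+int4")
print(spec.support_integer(), spec.support_float())  # True True
print(spec.support_extension("int4"))  # True when int4 is valid for the active profiles

# Lowering code fetches the active spec through the context manager;
# calling get_context_spec() outside the context raises RuntimeError.
with TosaLoweringContext(spec):
    assert get_context_spec() is spec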
+ """Broadcast the FX nodes to a shared shape inside the TOSA graph. + + This mirrors ``reshape_for_broadcast`` but also emits the tile operators + needed to materialize the broadcast and supports any number of inputs. Args: - tosa_fb: Tosa graph to add nodes to - nodes (list[Node]): List of nodes to broadcast together - tosa_spec (TosaSpecification): Tosa spec + tosa_fb (Any): TOSA graph builder that receives the broadcast + operators. + nodes (list[Node]): FX nodes whose tensor metadata should be + broadcast. + tosa_spec (TosaSpecification): Active TOSA specification used to + decode tensor metadata. Returns: - list[Any]: List containing the fx.Nodes or TosaSerializerTensors - of the right common shape. Order of output matches order of input. + list[Any]: Broadcast versions of the inputs. Each element is either + the original FX node or a TOSA serializer tensor, ordered to match + ``nodes``. Raises: RuntimeError: If the supplied nodes are not broadcastable. - Note: - This function and `reshape_for_broadcast` both reshape the tensors - for broadcast. However this function also performs the broadcast and - does not have a limit on only two input tensors. """ index_fake_tensors = [node.meta["val"] for node in nodes] broadcastable, common_shape = are_fake_tensors_broadcastable(index_fake_tensors) @@ -108,7 +110,7 @@ def broadcast_tensors( tens_dtype, ) - build_reshape_tosa_1_0(tosa_fb, node.name, new_shape, reshaped.name) + build_reshape_tosa(tosa_fb, node.name, new_shape, reshaped.name) tiled = tosa_fb.addIntermediate(common_shape, tens_dtype) multipliers = [ @@ -121,11 +123,13 @@ def broadcast_tensors( name=f"{node.name}_multiples", ) + attr = ts.TosaSerializerAttribute() + attr.TileAttribute() tosa_fb.addOperator( - ts.TosaOp.Op().TILE, + ts.Op.TILE, [reshaped.name, multiple_shapes.name], [tiled.name], - None, + attr, ) broadcast_tensors.append(tiled) @@ -133,9 +137,20 @@ def broadcast_tensors( return broadcast_tensors -def build_reshape_tosa_1_0( +def build_reshape_tosa( tosa_graph, input_name, new_shape, output_name, shape_name_override="" ): + """Insert a TOSA reshape operator using the v1.0 semantics. + + Args: + tosa_graph (Any): Graph builder used to emit TOSA operators. + input_name (str): Name of the tensor that should be reshaped. + new_shape (list[int]): Target tensor shape. + output_name (str): Name assigned to the reshaped tensor. + shape_name_override (str): Optional override for the shape constant + name. + + """ shape = tosa_graph.addConst( np.array(new_shape).shape, ts.DType.SHAPE, @@ -146,7 +161,7 @@ def build_reshape_tosa_1_0( attr = ts.TosaSerializerAttribute() attr.ReshapeAttribute() tosa_graph.addOperator( - ts.TosaOp.Op().RESHAPE, + ts.Op.RESHAPE, [input_name, shape.name], [output_name], attr, @@ -154,13 +169,26 @@ def build_reshape_tosa_1_0( def tosa_shape(shape, dim_order): + """Reorder a shape tuple into TOSA layout while resolving symints. + + Args: + shape (Sequence[int | torch.SymInt]): Original tensor shape, + possibly containing ``torch.SymInt``. + dim_order (Sequence[int]): Desired dimension order for the output + shape. + + Returns: + list[int]: List containing the reordered dimensions where symbolic + values become ``-1``. + + """ reordered = tuple([shape[dim] for dim in dim_order]) # Dynamic shapes in executorch are represented with torch.SymInt objects in the shapes, # in TOSA we do not have this concept and instead use -1. 
removed_symints = tuple( [-1 if isinstance(d, torch.SymInt) else d for d in reordered] ) - return removed_symints + return list(removed_symints) def get_resize_parameters_1d( @@ -169,6 +197,26 @@ def get_resize_parameters_1d( resize_mode: int, align_corners: bool, ): + """Compute resize coefficients for a single spatial dimension. + + Args: + input_size (int | torch.SymInt): Input size for the axis, possibly + symbolic. + output_size (int | torch.SymInt): Output size for the axis, possibly + symbolic. + resize_mode (int): Target resize mode defined by TOSA. + align_corners (bool): Whether the resize should align the corner + pixels. + + Returns: + tuple[int, int, int, int]: Numerator, denominator, offset, and border + terms encoded as integers. + + Raises: + RuntimeError: If symbolic shapes are used with ``align_corners`` or if + the computed ratio or border is not constant. + + """ # We don't support align_corners for symbolic shapes, because handling the edge case where size == 1 is tricky. if align_corners: if (not isinstance(input_size, int)) or (not isinstance(output_size, int)): @@ -228,19 +276,23 @@ def get_resize_parameters( resize_mode: int, align_corners: bool, ) -> tuple[torch.IntTensor, ...]: - """Get the tosa.resize parameters based on the input and output size. + """Calculate 2D resize parameters for TOSA emission. Args: - input_size_xy (tuple[int | torch.SymInt]): Size of the input - output_size_xy (tuple[int | torch.SymInt]): Size of the output - resize_mode (tosa.ResizeMode): The TOSA resize mode - align_corners (bool): Align the corners pixels of the input and output + input_size_xy (tuple[int | torch.SymInt, int | torch.SymInt]): Height + and width of the input tensor. + output_size_xy (tuple[int | torch.SymInt, int | torch.SymInt]): Height + and width of the output tensor. + resize_mode (int): TOSA resize mode used for coefficient generation. + align_corners (bool): Whether to align corner pixels between input and + output. Returns: - scale_n (torch.IntTensor), scale_d (torch.IntTensor), - offset (torch.IntTensor), border (torch.IntTensor) - """ + tuple[torch.IntTensor, ...]: Four-element tuple of tensors describing + the scale numerator, scale denominator, offset, and border for Y + and X dimensions. + """ # Get the parameters for each dimension independently y_params = get_resize_parameters_1d( input_size_xy[0], output_size_xy[0], resize_mode, align_corners diff --git a/backends/arm/util/_factory.py b/backends/arm/util/_factory.py new file mode 100644 index 00000000000..23d8215fc9b --- /dev/null +++ b/backends/arm/util/_factory.py @@ -0,0 +1,59 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
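A short sketch of the tosa_shape helper from backends/arm/tosa/utils.py above; the NCHW-to-NHWC dim_order is only illustrative:

from executorch.backends.arm.tosa.utils import tosa_shape

# Reorder an NCHW shape into the requested dim_order.
print(tosa_shape((1, 3, 224, 224), (0, 2, 3, 1)))  # [1, 224, 224, 3]

# torch.SymInt entries (dynamic shapes) are emitted as -1, so a symbolic
# batch dimension would instead produce [-1, 224, 224, 3].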
+ +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec +from executorch.backends.arm.ethosu import EthosUCompileSpec, EthosUPartitioner +from executorch.backends.arm.quantizer import ( + EthosUQuantizer, + TOSAQuantizer, + VgfQuantizer, +) +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec +from executorch.backends.arm.tosa.partitioner import TOSAPartitioner +from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner +from executorch.exir.backend.compile_spec_schema import CompileSpec +from torch.fx.passes.operator_support import OperatorSupportBase + + +def parse_compile_spec(compile_specs: list[CompileSpec]) -> ArmCompileSpec: + output_format = None + for spec in compile_specs: + if spec.key == "output_format": + output_format = spec.value.decode() + break + else: + raise ValueError("Compile spec without output format.") + if output_format == TosaCompileSpec.get_output_format(): + return TosaCompileSpec.from_list(compile_specs) + if output_format == EthosUCompileSpec.get_output_format(): + return EthosUCompileSpec.from_list(compile_specs) + if output_format == VgfCompileSpec.get_output_format(): + return VgfCompileSpec.from_list(compile_specs) + raise ValueError(f"Unknown output format {output_format}") + + +def create_partitioner( + compile_spec: ArmCompileSpec, + additional_checks: list[OperatorSupportBase] | None = None, +): + if isinstance(compile_spec, TosaCompileSpec): + return TOSAPartitioner(compile_spec, additional_checks) + elif isinstance(compile_spec, EthosUCompileSpec): + return EthosUPartitioner(compile_spec, additional_checks) + elif isinstance(compile_spec, VgfCompileSpec): + return VgfPartitioner(compile_spec, additional_checks) + else: + raise ValueError("compile spec doesn't target any Arm Partitioner") + + +def create_quantizer(compile_spec: ArmCompileSpec): + if isinstance(compile_spec, TosaCompileSpec): + return TOSAQuantizer(compile_spec) + elif isinstance(compile_spec, EthosUCompileSpec): + return EthosUQuantizer(compile_spec) + elif isinstance(compile_spec, VgfCompileSpec): + return VgfQuantizer(compile_spec) + else: + raise ValueError("compile spec doesn't target any Arm Quantizer") diff --git a/backends/arm/util/arm_model_evaluator.py b/backends/arm/util/arm_model_evaluator.py index a3dcbdc5c6f..d9cbdc2a923 100644 --- a/backends/arm/util/arm_model_evaluator.py +++ b/backends/arm/util/arm_model_evaluator.py @@ -1,11 +1,10 @@ # Copyright 2024-2025 Arm Limited and/or its affiliates. -# All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +import json import logging import os import random @@ -14,7 +13,7 @@ from collections import defaultdict from pathlib import Path -from typing import Any, Optional, Tuple +from typing import Any, cast, Optional, Tuple import torch from torch.nn.modules import Module @@ -29,21 +28,155 @@ logger.setLevel(logging.INFO) +# ImageNet 224x224 transforms (Resize->CenterCrop->ToTensor->Normalize) +# If future models require different preprocessing, extend this helper accordingly. 
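A hedged sketch of the factory helpers introduced in backends/arm/util/_factory.py above; importing the private module directly is for illustration, and the round-trip assumes ArmCompileSpec.to_list() serializes the output_format entry that parse_compile_spec dispatches on:

from executorch.backends.arm.util._factory import (
    create_partitioner,
    create_quantizer,
    parse_compile_spec,
)
from executorch.backends.arm.vgf import VgfCompileSpec

compile_spec = VgfCompileSpec()                  # defaults to "TOSA-1.0+FP+INT"
partitioner = create_partitioner(compile_spec)   # -> VgfPartitioner
quantizer = create_quantizer(compile_spec)       # -> VgfQuantizer

# Serialized CompileSpec lists are mapped back to the matching subclass.
rebuilt = parse_compile_spec(compile_spec.to_list())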
+def _get_imagenet_224_transforms(): + """Return standard ImageNet 224x224 preprocessing transforms.""" + return transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.484, 0.454, 0.403], std=[0.225, 0.220, 0.220]), + ] + ) + + +def _build_calibration_loader( + dataset: datasets.ImageFolder, max_items: int +) -> DataLoader: + """Return a DataLoader over a deterministic, shuffled subset of size <= max_items. + + Shuffles with seed: ARM_EVAL_CALIB_SEED (int) or default 1337; then selects first k and + sorts indices to keep enumeration order stable while content depends on seed. + """ + k = min(max_items, len(dataset)) + seed_env = os.getenv("ARM_EVAL_CALIB_SEED") + default_seed = 1337 + if seed_env is not None: + try: + seed = int(seed_env) + except ValueError: + logger.warning( + "ARM_EVAL_CALIB_SEED is not an int (%s); using default seed %d", + seed_env, + default_seed, + ) + seed = default_seed + else: + seed = default_seed + rng = random.Random( + seed + ) # nosec B311 - deterministic shuffling for evaluation only + indices = list(range(len(dataset))) + rng.shuffle(indices) + selected = sorted(indices[:k]) + return torch.utils.data.DataLoader( + torch.utils.data.Subset(dataset, selected), batch_size=1, shuffle=False + ) + + +def _load_imagenet_folder(directory: str) -> datasets.ImageFolder: + """Shared helper to load an ImageNet-layout folder. + + Raises FileNotFoundError for a missing directory early to aid debugging. + """ + directory_path = Path(directory) + if not directory_path.exists(): + raise FileNotFoundError(f"Directory: {directory} does not exist.") + transform = _get_imagenet_224_transforms() + return datasets.ImageFolder(directory_path, transform=transform) + + class GenericModelEvaluator: + """Base evaluator computing quantization error metrics and optional compression ratio. + + Subclasses can extend: provide calibration (get_calibrator) and override evaluate() + to add domain specific metrics (e.g. top-1 / top-5 accuracy). + """ + + @staticmethod + def evaluate_topk( + model: Module, + dataset: datasets.ImageFolder, + batch_size: int, + topk: int = 5, + log_every: int = 50, + ) -> Tuple[float, float]: + """Evaluate model top-1 / top-k accuracy. + + Args: + model: Torch module (should be in eval() mode prior to call). + dataset: ImageFolder style dataset. + batch_size: Batch size for evaluation. + topk: Maximum k for accuracy (default 5). + log_every: Log running accuracy every N batches. + Returns: + (top1_accuracy, topk_accuracy) + """ + # Some exported / quantized models (torchao PT2E) disallow direct eval()/train(). + # Try to switch to eval mode, but degrade gracefully if unsupported. + try: + model.eval() + except NotImplementedError: + # Attempt to enable train/eval overrides if torchao helper is present. + try: + from torchao.quantization.pt2e.utils import ( # type: ignore + allow_exported_model_train_eval, + ) + + allow_exported_model_train_eval(model) + try: + model.eval() + except Exception: + logger.debug( + "Model eval still not supported after allow_exported_model_train_eval; proceeding without explicit eval()." + ) + except Exception: + logger.debug( + "Model eval() unsupported and torchao allow_exported_model_train_eval not available; proceeding." 
+ ) + loaded_dataset = DataLoader(dataset, batch_size=batch_size, shuffle=False) + top1_correct = 0 + topk_correct = 0 + total = 0 + with torch.inference_mode(): # disable autograd + some backend optimizations + for i, (image, target) in enumerate(loaded_dataset): + prediction = model(image) + topk_indices = torch.topk(prediction, k=topk, dim=1).indices + # target reshaped for broadcasting + target_view = target.view(-1, 1) + top1_correct += (topk_indices[:, :1] == target_view).sum().item() + topk_correct += (topk_indices == target_view).sum().item() + batch_sz = image.size(0) + total += batch_sz + if (i + 1) % log_every == 0 or total == len(dataset): + logger.info( + "Eval progress: %d / %d top1=%.4f top%d=%.4f", + total, + len(dataset), + top1_correct / total, + topk, + topk_correct / total, + ) + top1_accuracy = top1_correct / len(dataset) + topk_accuracy = topk_correct / len(dataset) + return top1_accuracy, topk_accuracy + REQUIRES_CONFIG = False def __init__( self, model_name: str, fp32_model: torch.nn.Module, - int8_model: torch.nn.Module, + quant_model: torch.nn.Module, example_input: Tuple[torch.Tensor], tosa_output_path: Optional[str], ) -> None: self.model_name = model_name self.fp32_model = fp32_model - self.int8_model = int8_model + self.quant_model = quant_model self.example_input = example_input if tosa_output_path: @@ -52,21 +185,27 @@ def __init__( self.tosa_output_path = "" def get_model_error(self) -> defaultdict: - """ - Returns a dict containing the following metrics between the outputs of the FP32 and INT8 model: - - Maximum error - - Maximum absolute error - - Maximum percentage error - - Mean absolute error + """Return per-output quantization error statistics. + + Metrics (lists per output tensor): + max_error + max_absolute_error + max_percentage_error (safe-divided; zero fp32 elements -> 0%) + mean_absolute_error """ fp32_outputs, _ = tree_flatten(self.fp32_model(*self.example_input)) - int8_outputs, _ = tree_flatten(self.int8_model(*self.example_input)) + quant_outputs, _ = tree_flatten(self.quant_model(*self.example_input)) model_error_dict = defaultdict(list) - for fp32_output, int8_output in zip(fp32_outputs, int8_outputs): - difference = fp32_output - int8_output - percentage_error = torch.div(difference, fp32_output) * 100 + for fp32_output, quant_output in zip(fp32_outputs, quant_outputs): + difference = fp32_output - quant_output + # Avoid divide by zero: elements where fp32 == 0 produce 0% contribution + percentage_error = torch.where( + fp32_output != 0, + difference / fp32_output * 100, + torch.zeros_like(difference), + ) model_error_dict["max_error"].append(torch.max(difference).item()) model_error_dict["max_absolute_error"].append( torch.max(torch.abs(difference)).item() @@ -101,7 +240,6 @@ def evaluate(self) -> dict[str, Any]: if self.tosa_output_path: # We know output_metrics["metrics"] is list since we just defined it, safe to ignore. 
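A small sketch in the terms of arm_model_evaluator.py above; the dataset paths and the quantized module are placeholders:

from executorch.backends.arm.util.arm_model_evaluator import (
    _build_calibration_loader,
    _load_imagenet_folder,
    GenericModelEvaluator,
)

quant_model = ...  # placeholder: a converted/quantized torch.nn.Module

# Top-1 / top-5 accuracy over an ImageFolder-layout validation set.
val_set = _load_imagenet_folder("datasets/imagenet/val")
top1, top5 = GenericModelEvaluator.evaluate_topk(
    quant_model, val_set, batch_size=32, topk=5
)
print(f"top-1={top1:.4f} top-5={top5:.4f}")

# Deterministic calibration subset; the shuffle seed can be overridden via
# the ARM_EVAL_CALIB_SEED environment variable (default 1337).
calib_loader = _build_calibration_loader(
    _load_imagenet_folder("datasets/imagenet/train"), max_items=1000
)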
- # pyre-ignore[16] output_metrics["metrics"][ # type: ignore[index] "compression_ratio" ] = self.get_compression_ratio() @@ -116,14 +254,14 @@ def __init__( self, model_name: str, fp32_model: Module, - int8_model: Module, + quant_model: Module, example_input: Tuple[torch.Tensor], tosa_output_path: str | None, batch_size: int, validation_dataset_path: str, ) -> None: super().__init__( - model_name, fp32_model, int8_model, example_input, tosa_output_path + model_name, fp32_model, quant_model, example_input, tosa_output_path ) self.__batch_size = batch_size @@ -131,69 +269,241 @@ def __init__( @staticmethod def __load_dataset(directory: str) -> datasets.ImageFolder: - directory_path = Path(directory) - if not directory_path.exists(): - raise FileNotFoundError(f"Directory: {directory} does not exist.") - - transform = transforms.Compose( - [ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize( - mean=[0.484, 0.454, 0.403], std=[0.225, 0.220, 0.220] - ), - ] - ) - return datasets.ImageFolder(directory_path, transform=transform) + return _load_imagenet_folder(directory) @staticmethod def get_calibrator(training_dataset_path: str) -> DataLoader: dataset = MobileNetV2Evaluator.__load_dataset(training_dataset_path) - rand_indices = random.sample(range(len(dataset)), k=1000) + return _build_calibration_loader(dataset, 1000) - # Return a subset of the dataset to be used for calibration - return torch.utils.data.DataLoader( - torch.utils.data.Subset(dataset, rand_indices), - batch_size=1, - shuffle=False, + @classmethod + def from_config( + cls, + model_name: str, + fp32_model: Module, + quant_model: Module, + example_input: Tuple[torch.Tensor], + tosa_output_path: str | None, + config: dict[str, Any], + ) -> "MobileNetV2Evaluator": + """Factory constructing evaluator from a config dict. 
+ + Expected keys: batch_size, validation_dataset_path + """ + return cls( + model_name, + fp32_model, + quant_model, + example_input, + tosa_output_path, + batch_size=config["batch_size"], + validation_dataset_path=config["validation_dataset_path"], ) - def __evaluate_mobilenet(self) -> Tuple[float, float]: + def evaluate(self) -> dict[str, Any]: + # Load dataset and compute top-1 / top-5 dataset = MobileNetV2Evaluator.__load_dataset(self.__validation_set_path) - loaded_dataset = DataLoader( - dataset, - batch_size=self.__batch_size, - shuffle=False, + top1_correct, top5_correct = GenericModelEvaluator.evaluate_topk( + self.quant_model, dataset, self.__batch_size, topk=5 ) + output = super().evaluate() + output["metrics"]["accuracy"] = {"top-1": top1_correct, "top-5": top5_correct} + return output - top1_correct = 0 - top5_correct = 0 - for i, (image, target) in enumerate(loaded_dataset): - prediction = self.int8_model(image) - top1_prediction = torch.topk(prediction, k=1, dim=1).indices - top5_prediction = torch.topk(prediction, k=5, dim=1).indices +class DeiTTinyEvaluator(GenericModelEvaluator): + REQUIRES_CONFIG = True - top1_correct += (top1_prediction == target.view(-1, 1)).sum().item() - top5_correct += (top5_prediction == target.view(-1, 1)).sum().item() + def __init__( + self, + model_name: str, + fp32_model: Module, + quant_model: Module, + example_input: Tuple[torch.Tensor], + tosa_output_path: str | None, + batch_size: int, + validation_dataset_path: str, + ) -> None: + super().__init__( + model_name, fp32_model, quant_model, example_input, tosa_output_path + ) + self.__batch_size = batch_size + self.__validation_set_path = validation_dataset_path - logger.info("Iteration: {}".format((i + 1) * self.__batch_size)) - logger.info( - "Top 1: {}".format(top1_correct / ((i + 1) * self.__batch_size)) - ) - logger.info( - "Top 5: {}".format(top5_correct / ((i + 1) * self.__batch_size)) - ) + @staticmethod + def __load_dataset(directory: str) -> datasets.ImageFolder: + return _load_imagenet_folder(directory) - top1_accuracy = top1_correct / len(dataset) - top5_accuracy = top5_correct / len(dataset) + @staticmethod + def get_calibrator(training_dataset_path: str) -> DataLoader: + dataset = DeiTTinyEvaluator.__load_dataset(training_dataset_path) + return _build_calibration_loader(dataset, 1000) - return top1_accuracy, top5_accuracy + @classmethod + def from_config( + cls, + model_name: str, + fp32_model: Module, + quant_model: Module, + example_input: Tuple[torch.Tensor], + tosa_output_path: str | None, + config: dict[str, Any], + ) -> "DeiTTinyEvaluator": + """Factory constructing evaluator from a config dict. 
+ + Expected keys: batch_size, validation_dataset_path + """ + return cls( + model_name, + fp32_model, + quant_model, + example_input, + tosa_output_path, + batch_size=config["batch_size"], + validation_dataset_path=config["validation_dataset_path"], + ) def evaluate(self) -> dict[str, Any]: - top1_correct, top5_correct = self.__evaluate_mobilenet() + # Load dataset and compute top-1 / top-5 + dataset = DeiTTinyEvaluator.__load_dataset(self.__validation_set_path) + top1, top5 = GenericModelEvaluator.evaluate_topk( + self.quant_model, dataset, self.__batch_size, topk=5 + ) output = super().evaluate() + output["metrics"]["accuracy"] = {"top-1": top1, "top-5": top5} + return output - output["metrics"]["accuracy"] = {"top-1": top1_correct, "top-5": top5_correct} + +class ResNet18Evaluator(GenericModelEvaluator): + REQUIRES_CONFIG = True + + def __init__( + self, + model_name: str, + fp32_model: Module, + quant_model: Module, + example_input: Tuple[torch.Tensor], + tosa_output_path: str | None, + batch_size: int, + validation_dataset_path: str, + ) -> None: + super().__init__( + model_name, fp32_model, quant_model, example_input, tosa_output_path + ) + self.__batch_size = batch_size + self.__validation_set_path = validation_dataset_path + + @staticmethod + def __load_dataset(directory: str) -> datasets.ImageFolder: + return _load_imagenet_folder(directory) + + @staticmethod + def get_calibrator(training_dataset_path: str) -> DataLoader: + dataset = ResNet18Evaluator.__load_dataset(training_dataset_path) + return _build_calibration_loader(dataset, 1000) + + @classmethod + def from_config( + cls, + model_name: str, + fp32_model: Module, + quant_model: Module, + example_input: Tuple[torch.Tensor], + tosa_output_path: str | None, + config: dict[str, Any], + ) -> "ResNet18Evaluator": + return cls( + model_name, + fp32_model, + quant_model, + example_input, + tosa_output_path, + batch_size=config["batch_size"], + validation_dataset_path=config["validation_dataset_path"], + ) + + def evaluate(self) -> dict[str, Any]: + dataset = ResNet18Evaluator.__load_dataset(self.__validation_set_path) + top1, top5 = GenericModelEvaluator.evaluate_topk( + self.quant_model, dataset, self.__batch_size, topk=5 + ) + output = super().evaluate() + output["metrics"]["accuracy"] = {"top-1": top1, "top-5": top5} return output + + +evaluators: dict[str, type[GenericModelEvaluator]] = { + "generic": GenericModelEvaluator, + "mv2": MobileNetV2Evaluator, + "deit_tiny": DeiTTinyEvaluator, + "resnet18": ResNet18Evaluator, +} + + +def evaluator_calibration_data( + evaluator_name: str, + evaluator_config: str | None, +): + evaluator = evaluators[evaluator_name] + + if hasattr(evaluator, "get_calibrator"): + assert evaluator_config is not None + + config_path = Path(evaluator_config) + with config_path.open() as f: + config = json.load(f) + + # All current evaluators exposing calibration implement a uniform + # static method signature: get_calibrator(training_dataset_path: str) + # so we can call it generically without enumerating classes. 
+ return evaluator.get_calibrator( + training_dataset_path=config["training_dataset_path"] + ) + + +def evaluate_model( + model_name: str, + intermediates: str, + target: str, + model_fp32: torch.nn.Module, + model_quant: torch.nn.Module, + example_inputs: Tuple[torch.Tensor], + evaluator_name: str, + evaluator_config: str | None, +) -> None: + evaluator = evaluators[evaluator_name] + + intermediates_path = Path(intermediates) + tosa_paths = list(intermediates_path.glob("*.tosa")) + + if evaluator.REQUIRES_CONFIG: + assert evaluator_config is not None + config_path = Path(evaluator_config) + with config_path.open() as f: + config = json.load(f) + + # Prefer a subclass provided from_config if available. + if hasattr(evaluator, "from_config"): + factory = cast(Any, evaluator.from_config) # type: ignore[attr-defined] + init_evaluator = factory( + model_name, + model_fp32, + model_quant, + example_inputs, + str(tosa_paths[0]), + config, + ) + else: + raise RuntimeError( + f"Evaluator {evaluator_name} requires config but does not implement from_config()" + ) + else: + init_evaluator = evaluator( + model_name, model_fp32, model_quant, example_inputs, str(tosa_paths[0]) + ) + + quant_metrics = init_evaluator.evaluate() + output_json_path = intermediates_path / f"{target}-quant_metrics.json" + + with output_json_path.open("w") as json_file: + json.dump(quant_metrics, json_file) diff --git a/backends/arm/vgf/__init__.py b/backends/arm/vgf/__init__.py index 4ab8144cbd6..88be90e084e 100644 --- a/backends/arm/vgf/__init__.py +++ b/backends/arm/vgf/__init__.py @@ -3,12 +3,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # -# pyre-unsafe from .backend import VgfBackend # noqa: F401 +from .compile_spec import VgfCompileSpec # noqa: F401 from .partitioner import VgfPartitioner # noqa: F401 -__all__ = [ - "VgfBackend", - "VgfPartitioner", -] +__all__ = ["VgfBackend", "VgfPartitioner", "VgfCompileSpec"] diff --git a/backends/arm/vgf/backend.py b/backends/arm/vgf/backend.py index 7c408748529..0e931afa10e 100644 --- a/backends/arm/vgf/backend.py +++ b/backends/arm/vgf/backend.py @@ -3,7 +3,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
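A sketch of the evaluator entry points defined at the end of arm_model_evaluator.py above; the config keys mirror what those functions read, while the target label, models, and paths are placeholders:

from executorch.backends.arm.util.arm_model_evaluator import (
    evaluate_model,
    evaluator_calibration_data,
    evaluators,
)

assert "mv2" in evaluators  # registry keys: generic, mv2, deit_tiny, resnet18

fp32_model = ...       # placeholder torch.nn.Module
quant_model = ...      # placeholder quantized module
example_inputs = ...   # placeholder tuple of example tensors

# The JSON config provides training_dataset_path (calibration) plus
# batch_size and validation_dataset_path (REQUIRES_CONFIG evaluators).
calib_loader = evaluator_calibration_data("mv2", "mv2_eval_config.json")

# Assumes intermediates/ already contains the lowered .tosa artifact(s);
# writes intermediates/<target>-quant_metrics.json.
evaluate_model(
    model_name="mv2",
    intermediates="intermediates/",
    target="ethos-u",
    model_fp32=fp32_model,
    model_quant=quant_model,
    example_inputs=example_inputs,
    evaluator_name="mv2",
    evaluator_config="mv2_eval_config.json",
)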
-# pyre-unsafe # # Main implementation of AoT flow to partition and preprocess for VGF target @@ -11,19 +10,32 @@ # this form is used where the final JIT compile is performed on target (in the # runtime delegate executorch::runtime::BackendInterface::init # +"""Ahead-of-time Arm VGF backend built on the shared TOSA pipeline.""" import logging -import os -import subprocess +import os # nosec B404 - used alongside subprocess for tool invocation +import subprocess # nosec B404 - required to drive external converter CLI import tempfile from typing import final, List -from executorch.backends.arm.tosa.backend import ( +from executorch.backends.arm.tosa.backend import ( # type: ignore[import-not-found] arm_get_first_delegation_tag, TOSABackend, ) -from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult -from executorch.exir.backend.compile_spec_schema import CompileSpec + +from executorch.backends.arm.vgf.compile_spec import ( # type: ignore[import-not-found] + VgfCompileSpec, +) +from executorch.backends.arm.vgf.model_converter import ( # type: ignore[import-not-found] + require_model_converter_binary, +) +from executorch.exir.backend.backend_details import ( # type: ignore[import-not-found] + BackendDetails, + PreprocessResult, +) +from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-not-found] + CompileSpec, +) from torch.export.exported_program import ExportedProgram # debug functionality @@ -32,29 +44,34 @@ @final class VgfBackend(BackendDetails): - """ - BackendDetails subclass for delegation to VGF compatible devices. This enables - encapsulated TOSA on target device and JIT compilation on suitable platforms. + """BackendDetails subclass for delegation to VGF compatible devices. + + This enables encapsulated TOSA on target device and JIT compilation on + suitable platforms. + """ @staticmethod def _compile_tosa_flatbuffer( tosa_flatbuffer: bytes, - compile_spec: List[CompileSpec], + compile_spec: VgfCompileSpec, tag_name: str = "", ) -> bytes: - """ - Static helper method to do the compilation of the TOSA flatbuffer - representation to a target specific binary stream. - """ - compile_flags = [] - artifact_path = None - for spec in compile_spec: - if spec.key == "compile_flags": - compile_flags.append(spec.value.decode()) - if spec.key == "debug_artifact_path": - artifact_path = spec.value.decode() + """Compile a TOSA flatbuffer into a target-specific binary stream. + + Args: + tosa_flatbuffer (bytes): Serialized TOSA graph produced by + ``TOSABackend``. + compile_spec (VgfCompileSpec): Compile specification providing + converter flags and artifact paths. + tag_name (str): Optional suffix used when producing debug outputs. + Returns: + bytes: Target-specific VGF binary stream. + + """ + compile_flags = compile_spec.compiler_flags + artifact_path = compile_spec.get_intermediate_path() # Pass on the TOSA flatbuffer to the vgf compiler. binary = vgf_compile(tosa_flatbuffer, compile_flags, artifact_path, tag_name) return binary @@ -62,10 +79,22 @@ def _compile_tosa_flatbuffer( @staticmethod def preprocess( edge_program: ExportedProgram, - compile_spec: List[CompileSpec], + compile_specs: List[CompileSpec], ) -> PreprocessResult: + """Lower the exported program and compile it for a VGF target. + + Args: + edge_program (ExportedProgram): Program to lower to VGF. + compile_specs (List[CompileSpec]): Serialized VGF compile specs + supplied by the frontend. + + Returns: + PreprocessResult: Result containing the compiled VGF binary. 
+ + """ logger.info(f"{VgfBackend.__name__} preprocess") + compile_spec = VgfCompileSpec.from_list(compile_specs) # deduce TOSA compile_spec from VGF compile spec. We get a new # compile spec list, containing only elements relevant for the # TOSABackend. @@ -75,7 +104,7 @@ def preprocess( # ('All backend implementation are final...'), so use composition instead. # preprocess returns the serialized TOSA flatbuffer in .processed_bytes, # which can be passed on to next compilation step. - tosa_preprocess = TOSABackend.preprocess(edge_program, tosa_compile_spec) + tosa_preprocess = TOSABackend._preprocess(edge_program, tosa_compile_spec) tag_name = arm_get_first_delegation_tag(edge_program.graph_module) @@ -92,6 +121,20 @@ def vgf_compile( artifact_path: str | None = None, tag_name: str = "", ): + """Invoke the VGF compiler to convert a TOSA flatbuffer. + + Args: + tosa_flatbuffer (bytes): Serialized TOSA graph produced by + ``TOSABackend``. + compile_flags (List[str]): Command-line flags forwarded to + ``model-converter``. + artifact_path (str | None): Directory where debug artifacts are saved. + tag_name (str): Optional suffix used when producing debug outputs. + + Returns: + bytes: Compiled VGF binary emitted by ``model-converter``. + + """ with tempfile.TemporaryDirectory() as tmpdir: # We currently write out a flatbuffer as input to the converter @@ -101,12 +144,13 @@ def vgf_compile( f.write(tosa_flatbuffer) additional_flags = " ".join(compile_flags) + converter_binary = require_model_converter_binary() vgf_path = tosa_path + ".vgf" conversion_command = ( - f"model-converter {additional_flags} -i {tosa_path} -o {vgf_path}" + f"{converter_binary} {additional_flags} -i {tosa_path} -o {vgf_path}" ) try: - subprocess.run( + subprocess.run( # nosec B602 - shell invocation constrained to trusted converter binary [conversion_command], shell=True, check=True, capture_output=True ) except subprocess.CalledProcessError as process_error: @@ -116,11 +160,13 @@ def vgf_compile( Stdout:\n{process_error.stdout.decode()}" ) - if artifact_path is not None: + if artifact_path: logger.info(f"Emitting debug output to: {vgf_path=}") os.makedirs(artifact_path, exist_ok=True) cp = f"cp {vgf_path} {artifact_path}" - subprocess.run(cp, shell=True, check=True, capture_output=False) + subprocess.run( # nosec B602 - shell copy of trusted artifact for debugging + cp, shell=True, check=True, capture_output=False + ) vgf_bytes = open(vgf_path, "rb").read() return vgf_bytes diff --git a/backends/arm/vgf/compile_spec.py b/backends/arm/vgf/compile_spec.py new file mode 100644 index 00000000000..b5b13f59939 --- /dev/null +++ b/backends/arm/vgf/compile_spec.py @@ -0,0 +1,73 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec +from executorch.backends.arm.common.pipeline_config import ( # noqa: unused + ArmPassPipelineConfig, +) +from executorch.backends.arm.tosa import ( # type: ignore[import-not-found] + TosaSpecification, +) + +# debug functionality +logger = logging.getLogger(__name__) + + +class VgfCompileSpec(ArmCompileSpec): + """Compile specification for VGF-compatible targets.""" + + def __init__( + self, + tosa_spec: TosaSpecification | str | None = None, + compiler_flags: list[str] | None = None, + ): + """Normalise inputs and populate the underlying Arm compile spec. 
+ + Args: + tosa_spec (TosaSpecification | str | None): TOSA specification to + target. Strings are parsed via + :meth:`TosaSpecification.create_from_string`. Defaults to + ``"TOSA-1.0+FP+INT"``. + compiler_flags (list[str] | None): Optional converter-backend flags. + """ + if tosa_spec is None: + tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+FP+INT") + elif isinstance(tosa_spec, str): + tosa_spec = TosaSpecification.create_from_string(tosa_spec) + + if compiler_flags is None: + compiler_flags = [] + self._set_compile_specs(tosa_spec, compiler_flags) + self.validate() + + def validate(self): + """Validate the configuration against VGF-supported TOSA profiles.""" + tosa_version = self.tosa_spec.version # type: ignore[attr-defined] + tosa_profiles = self.tosa_spec.profiles # type: ignore[attr-defined] + + if tosa_version.major != 1: + raise ValueError( + "Arm backend only supports converter-backend for TOSA version 1. " + f"Invalid TOSA version: {tosa_version}" + ) + + if "FP" not in tosa_profiles and "INT" not in tosa_profiles: + raise ValueError( + "Arm backend only supports converter-backend for FP and/or INT. " + f"Invalid TOSA profile: {tosa_profiles}" + ) + + @classmethod + def get_output_format(cls) -> str: + """Return the artifact format emitted by this compile spec.""" + return "vgf" + + def _create_default_pipeline_config(self) -> ArmPassPipelineConfig: + config = super()._create_default_pipeline_config() + # GRPHCOMP-3140 / MLETORCH-1529 + config.disable_fuse_duplicate_users() + return config diff --git a/backends/arm/vgf/model_converter.py b/backends/arm/vgf/model_converter.py new file mode 100644 index 00000000000..dffbf76f26a --- /dev/null +++ b/backends/arm/vgf/model_converter.py @@ -0,0 +1,34 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from shutil import which +from typing import Optional + +MODEL_CONVERTER_BINARY = "model-converter" +_MODEL_CONVERTER_FALLBACK_BINARY = "model_converter" + + +def find_model_converter_binary() -> Optional[str]: + """Return the name of the first model converter executable on PATH.""" + + for candidate in (MODEL_CONVERTER_BINARY, _MODEL_CONVERTER_FALLBACK_BINARY): + if which(candidate): + return candidate + return None + + +def require_model_converter_binary() -> str: + """Return a usable model converter executable or raise a helpful error.""" + + binary = find_model_converter_binary() + if binary is None: + tried = ", ".join((MODEL_CONVERTER_BINARY, _MODEL_CONVERTER_FALLBACK_BINARY)) + raise RuntimeError( + "Unable to locate a model converter executable. " + f"Tried: {tried}. Ensure the Model Converter is installed and on PATH." + ) + return binary diff --git a/backends/arm/vgf/partitioner.py b/backends/arm/vgf/partitioner.py index f6dab597487..96c4408b922 100644 --- a/backends/arm/vgf/partitioner.py +++ b/backends/arm/vgf/partitioner.py @@ -3,30 +3,33 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
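A short sketch tying together the VGF compile spec and converter lookup above with the partitioner shown just below; the TOSA string is one of the profiles validate() accepts:

from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner
from executorch.backends.arm.vgf.model_converter import find_model_converter_binary

spec = VgfCompileSpec("TOSA-1.0+INT")  # or VgfCompileSpec() for the FP+INT default
partitioner = VgfPartitioner(spec)

# The converter binary is resolved from PATH ("model-converter", falling back
# to "model_converter"); require_model_converter_binary() raises if neither is found.
if find_model_converter_binary() is None:
    print("Install the Model Converter and add it to PATH before lowering to VGF")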
-# pyre-unsafe -from typing import final, List, Optional, Sequence +from typing import final, Optional, Sequence -from executorch.backends.arm.arm_backend import ( - is_vgf, -) # usort: skip from executorch.backends.arm.tosa.partitioner import TOSAPartitioner -from executorch.backends.arm.vgf import VgfBackend -from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.backends.arm.vgf import VgfBackend, VgfCompileSpec from executorch.exir.backend.partitioner import DelegationSpec from torch.fx.passes.operator_support import OperatorSupportBase @final class VgfPartitioner(TOSAPartitioner): + """ + Partitions subgraphs supported by the Arm Vgf backend. + + Args: + compile_spec: The Vgf compilation specification. + additional_checks: Optional sequence of additional operator support checks. + """ + def __init__( self, - compile_spec: List[CompileSpec], + compile_spec: VgfCompileSpec, additional_checks: Optional[Sequence[OperatorSupportBase]] = None, ) -> None: - if not is_vgf(compile_spec): - raise RuntimeError("compile spec is not targeting Vgf") - # Override the delegation spec for Vgf - self.delegation_spec = DelegationSpec(VgfBackend.__name__, compile_spec) + self.delegation_spec = DelegationSpec( + VgfBackend.__name__, compile_spec.to_list() + ) self.additional_checks = additional_checks + self.tosa_spec = compile_spec.tosa_spec diff --git a/backends/backends.bzl b/backends/backends.bzl index 5ca30a83b54..42aed059f22 100644 --- a/backends/backends.bzl +++ b/backends/backends.bzl @@ -6,7 +6,6 @@ def get_all_cpu_backend_targets(): """ return [ "//executorch/backends/xnnpack:xnnpack_backend", - "//executorch/backends/fb/qnnpack:qnnpack_backend", ] def get_all_cpu_aot_and_backend_targets(): @@ -18,6 +17,4 @@ def get_all_cpu_aot_and_backend_targets(): return [ "//executorch/backends/xnnpack:xnnpack_preprocess", "//executorch/backends/xnnpack/partition:xnnpack_partitioner", - "//executorch/backends/fb/qnnpack:qnnpack_preprocess", - "//executorch/backends/fb/qnnpack/partition:qnnpack_partitioner", ] + get_all_cpu_backend_targets() diff --git a/backends/cadence/CMakeLists.txt b/backends/cadence/CMakeLists.txt index 47183bed21d..271b4806614 100644 --- a/backends/cadence/CMakeLists.txt +++ b/backends/cadence/CMakeLists.txt @@ -88,8 +88,11 @@ elseif(EXECUTORCH_FUSION_G3_OPT) ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET_DIR}/third-party/nnlib ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10 ) +elseif(EXECUTORCH_VISION_OPT) + set(TARGET_DIR vision) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/${TARGET_DIR}/kernels) else() - set(TARGET_DIR reference) + set(TARGET_DIR generic) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/${TARGET_DIR}/kernels) endif() diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index 0c39fd3d38e..e99bd9ab1dc 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -12,7 +12,6 @@ load( "CXX", ) load("@fbsource//xplat/executorch/codegen:codegen.bzl", "executorch_generated_lib") -load("@fbcode_macros//build_defs:cpp_python_extension.bzl", "cpp_python_extension") oncall("odai_jarvis") @@ -118,6 +117,7 @@ runtime.python_library( ], deps = [ "fbcode//caffe2:torch", + "fbcode//executorch/backends/cadence/aot:ref_implementations", "fbcode//executorch/backends/cadence/aot:utils", ], ) @@ -131,6 +131,7 @@ runtime.python_library( deps = [ "fbcode//caffe2:torch", "fbcode//executorch/exir:scalar_type", + "fbcode//executorch/kernels/quantized:custom_ops_generated_lib", ], ) @@ -143,8 +144,19 @@ executorch_generated_lib( 
platforms = CXX, visibility = ["PUBLIC"], deps = [ - "//executorch/backends/cadence/reference/kernels:cadence_kernels", - "//executorch/backends/cadence/reference/operators:cadence_cpu_ops", + "//executorch/backends/cadence/generic/kernels:cadence_kernels", + "//executorch/backends/cadence/generic/operators:op_requantize", + "//executorch/backends/cadence/generic/operators:op_im2row", + "//executorch/backends/cadence/generic/operators:op_dequantize_per_tensor", + "//executorch/backends/cadence/generic/operators:op_quantize_per_tensor", + "//executorch/backends/cadence/generic/operators:op_quantized_add", + "//executorch/backends/cadence/generic/operators:op_quantized_conv2d", + "//executorch/backends/cadence/generic/operators:op_quantized_conv1d", + "//executorch/backends/cadence/generic/operators:op_quantized_fully_connected", + "//executorch/backends/cadence/generic/operators:op_quantized_layer_norm", + "//executorch/backends/cadence/generic/operators:op_quantized_linear", + "//executorch/backends/cadence/generic/operators:op_quantized_matmul", + "//executorch/backends/cadence/generic/operators:op_quantized_relu", "//executorch/kernels/portable:executorch_all_ops", "//executorch/kernels/portable:operators", ], @@ -256,6 +268,7 @@ runtime.python_library( ], typing = True, deps = [ + ":ops_registrations", "//caffe2:torch", "//executorch/backends/cadence/aot:pass_utils", "//executorch/backends/cadence/aot:simplify_ops", @@ -345,6 +358,7 @@ python_unittest( typing = True, deps = [ ":ops_registrations", + ":typing_stubs", ":type_dispatch", "//caffe2:torch", "//executorch/backends/cadence/aot:graph_builder", @@ -615,7 +629,22 @@ python_unittest( deps = [ ":typing_stubs", "//executorch/backends/cadence/aot:ops_registrations", - "//executorch/backends/cadence/aot:ref_implementations", "//caffe2:torch", ] ) + +python_unittest( + name = "test_quantizer_ops", + srcs = [ + "tests/test_quantizer_ops.py", + ], + typing = True, + deps = [ + "fbsource//third-party/pypi/parameterized:parameterized", + "//caffe2:torch", + "//executorch/backends/cadence/aot:graph_builder", + "//executorch/backends/cadence/aot/quantizer:quantizer", + "//executorch/exir:pass_base", + "//pytorch/ao:torchao", + ], +) diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index 6c497d5bec4..5770b05ad1e 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -14,6 +14,7 @@ import torch from executorch.backends.cadence.aot.compiler_funcs import ( prepare as prepare_fn, + QuantizedInputWrapper, trace as trace_fn, ) from executorch.backends.cadence.aot.memory_planning import ( @@ -37,42 +38,28 @@ ExecutorchProgramManager, ) from executorch.exir.passes import ToOutVarPass -from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass -from executorch.exir.program._program import to_edge - +from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass +from executorch.exir.program._program import _transform, to_edge from torch.export.exported_program import ExportedProgram from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e from .passes import apply_exir_ops_passes, apply_torch_ops_passes - from .utils import print_ops_info default_quantizer = CadenceDefaultQuantizer() -# Note: this is not meant as a primary API since it can create inconsistencies -# if the quantizer here is different from the quantizer used to convert. 
It is -# however useful for unit tests to separate the converted model from the fused -# model, to be able to get reference numerics. -# If this does not apply, please use quantize_pt2 instead. def trace( model: torch.nn.Module, inputs: tuple[object, ...], dump_graphs: bool = False, + ops_to_keep: Optional[list[torch._ops.OpOverload]] = None, ) -> ExportedProgram: """ Trace the model with export and return an ExportedProgram. """ - - ops_to_keep = [ - torch.ops.aten.conv1d.default, - torch.ops.aten.conv2d.default, - torch.ops.aten.layer_norm.default, - torch.ops.aten.linear.default, - torch.ops.aten.matmul.default, - torch.ops.aten.rms_norm.default, - ] - + if ops_to_keep is None: + ops_to_keep = [] program = trace_fn( model, inputs, is_qat=False, strict=True, ops_to_keep=ops_to_keep ) @@ -99,7 +86,10 @@ def prepare_pt2( Returns a GraphModule with the prepared model. """ - traced_program = trace(model, inputs, dump_graphs=dump_graphs) + ops_to_keep = quantizer.get_ops_to_preserve_from_decomposition() + traced_program = trace( + model, inputs, dump_graphs=dump_graphs, ops_to_keep=ops_to_keep + ) prepared_program = prepare_traced_pt2( traced_program, quantizer, dump_graphs=dump_graphs ) @@ -153,23 +143,23 @@ def convert_pt2( # It is however useful for unit tests to separate the converted model from the # fused model, to be able to get reference numerics. # If this does not apply, please use quantize_pt2 instead. -def fuse_pt2( - converted_graph_module: torch.fx.GraphModule, +def apply_pre_edge_transform_passes( + converted_program: ExportedProgram, quantizer: CadenceQuantizer, -) -> torch.fx.GraphModule: +) -> ExportedProgram: """ - Fuse a converted graph module using the given quantizer. + Fuse a converted exported program using the given quantizer. The quantizer must be the same as the one used to convert the model. If you do not expect that behavior, please use quantize_pt2 instead, which will instantiate a default quantizer for you if needed. - Returns a GraphModule with the fused model. + Returns an ExportedProgram with the fused model. """ # Get patterns and apply fusion of dq -> op -> q to qop # pyre-ignore[16]: no attribute patterns = [q.pattern for q in quantizer.quantizers] - QuantFusion(patterns)(converted_graph_module) + fused_program = _transform(converted_program, QuantFusion(patterns)) - return converted_graph_module + return fused_program # Note: quantizer is not optional here to force the user to supply a quantizer @@ -184,14 +174,15 @@ def get_fake_quant_model( # Make the model inference mode by calling model.eval() model.eval() - program = trace(model, inputs, dump_graphs=dump_graphs) + ops_to_keep = quantizer.get_ops_to_preserve_from_decomposition() + program = trace(model, inputs, dump_graphs=dump_graphs, ops_to_keep=ops_to_keep) if dump_graphs: logging.info("Graph after trace:") logging.info(program.graph.print_tabular()) # Get prepared graph module - prepared_gm = prepare_pt2(model, inputs, quantizer, dump_graphs=dump_graphs) + prepared_gm = prepare_traced_pt2(program, quantizer, dump_graphs=dump_graphs) # Calibrate # If no calibration data is provided, use the inputs @@ -212,13 +203,14 @@ def quantize_pt2( quantizer: Optional[CadenceQuantizer] = None, calibration_data: Optional[list[tuple[object, ...]]] = None, dump_graphs: bool = False, + quant_input_args: Optional[list[str]] = None, ) -> ExportedProgram: """ Trace, prepare, convert and fuse the model using the given quantizer. If calibration data is provided, it will be used to calibrate the model. 
If not, the inputs will be used for calibration instead, which is useful for unit tests but should not be used for end-to-end use cases. - Returns a GraphModule with the quantized model. + Returns an ExportedProgram with the quantized model. Note: this function should not be called directly in general. Please use quantize_and_export_to_executorch for most needs. """ @@ -234,17 +226,18 @@ def quantize_pt2( calibration_data=calibration_data, dump_graphs=dump_graphs, ) + # Wrap the model to handle quantized inputs + wrapped_module = QuantizedInputWrapper(converted_gm, quant_input_args).module - # Get fused model - fused_gm = fuse_pt2(converted_gm, quantizer) + # Apply quant fusion to the exported program + program = torch.export.export(wrapped_module, inputs, strict=True) + fused_program = apply_pre_edge_transform_passes(program, quantizer) if dump_graphs: logging.info("Graph after quantization and fusion:") - logging.info(fused_gm.graph.print_tabular()) + logging.info(fused_program.graph_module.graph.print_tabular()) - program = torch.export.export(fused_gm, inputs, strict=True) - - return program + return fused_program TO_EDGE_OP_EXCEPTION_LIST: list[torch._ops.OpOverload] = [ @@ -452,7 +445,7 @@ def _lower_ep_to_cadence_gen_etrecord( emit_stacktrace=False, to_out_var_pass=ToOutVarPass(), extract_delegate_segments=False, - sym_shape_eval_pass=HintBasedSymShapeEvalPass(), + sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), ), ) diff --git a/backends/cadence/aot/compiler_funcs.py b/backends/cadence/aot/compiler_funcs.py index 6ff6057255c..9756602ad2d 100644 --- a/backends/cadence/aot/compiler_funcs.py +++ b/backends/cadence/aot/compiler_funcs.py @@ -6,14 +6,18 @@ # pyre-strict - -from typing import Optional +import logging +from typing import Any, Optional, Union import torch from torch._inductor.decomposition import remove_decompositions +from torch.fx import GraphModule from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, prepare_qat_pt2e from torchao.quantization.pt2e.quantizer import Quantizer +logger: logging.Logger = logging.getLogger(__name__) +QuantArgs = tuple[float, int, int, int, torch.dtype] + @torch.no_grad() def trace( @@ -52,3 +56,108 @@ def prepare( prepared_model = prepare_pt2e(traced_model, quantizer) return prepared_model + + +def extract_input_quant_params_from_graph( + module: GraphModule, + input_names: list[str], +) -> dict[int, QuantArgs]: + """ + Extract quantization parameters from the FX graph for model inputs. + """ + quant_args: dict[int, QuantArgs] = {} + found_names: set[str] = set() + + if not input_names: + return quant_args + + for idx, name in enumerate(input_names): + for node in module.graph.nodes: + if node.op != "call_function": + continue + + if ( + node.args + and isinstance(node.args[0], torch.fx.Node) + and node.args[0].name == name + and not node.name.startswith("_assert_tensor_metadata") + and "quantize_per_tensor" in str(node.target) + ): + args = node.args[1:] + if len(args) >= 5: + quant_args[idx] = ( + float(args[0]), # scale + int(args[1]), # zero_point + int(args[2]), # qmin + int(args[3]), # qmax + args[4], # dtype + ) + found_names.add(name) + break + + missing_names = set(input_names) - found_names + if missing_names: + raise ValueError( + f"Could not find quantization parameters for input(s): {sorted(missing_names)}. " + f"Make sure these input names exist in the graph and quantization parameters." 
+ ) + + return quant_args + + +class QuantizedInputWrapper(torch.nn.Module): + """ + Wrapper that allows a quantized model to accept quantized inputs. + + If no input_names or quant_args are provided, the wrapper passes inputs + through unchanged (no dequantization). + + Args: + module: The quantized GraphModule to wrap. + input_names: Optional list of input placeholder names in the graph. + If provided, extracts quant params from graph. + quant_args: Optional dict mapping input index to (scale, zero_point, qmin, qmax, dtype). + If provided, uses these directly instead of extracting from graph. + + Example: + # Extract from graph + wrapper = QuantizedInputWrapper(quantized_module, input_names=["x"]) + + # Explicit quant args + wrapper = QuantizedInputWrapper( + quantized_module, + quant_args={0: (1/255, 0, 0, 255, torch.uint8)}, + ) + """ + + def __init__( + self, + module: GraphModule, + input_args: Optional[Union[list[str], dict[int, QuantArgs]]] = None, + ) -> None: + super().__init__() + self.module: GraphModule = module + self.quant_args: dict[int, QuantArgs] = {} + + if input_args is not None: + logger.warning( + "Warning: Using pre-quantized inputs. This should only be done when calibration has been confirmed." + "Incorrect quantization parameters can lead to significant accuracy degradation." + ) + if isinstance(input_args, list): + self.quant_args = extract_input_quant_params_from_graph(module, input_args) + elif isinstance(input_args, dict): + self.quant_args = input_args + + def forward(self, *args: torch.Tensor) -> Any: + """Run inference, dequantizing configured inputs.""" + dequantized_args = [] + for index, node in enumerate(args): + if index in self.quant_args: + scale, zp, qmin, qmax, dtype = self.quant_args[index] + node = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + node, scale, zp, qmin, qmax, dtype + ) + dequantized_args.append(node) + + return self.module(*dequantized_args) diff --git a/backends/cadence/aot/export_example.py b/backends/cadence/aot/export_example.py index 6af7a88fdc2..cf4fa484997 100644 --- a/backends/cadence/aot/export_example.py +++ b/backends/cadence/aot/export_example.py @@ -18,8 +18,8 @@ from executorch.backends.cadence.aot.compiler import ( _lower_ep_to_cadence_gen_etrecord, + apply_pre_edge_transform_passes, convert_pt2, - fuse_pt2, prepare_pt2, ) @@ -63,11 +63,10 @@ def export_model( # Get reference outputs from converted model ref_outputs = converted_model(*example_inputs) - # Quantize the model (note: quantizer needs to be the same as - # the one used in prepare_and_convert_pt2) - quantized_model = fuse_pt2(converted_model, quantizer) + ep = torch.export.export(converted_model, example_inputs, strict=True) - ep = torch.export.export(quantized_model, example_inputs, strict=True) + # Fuse the quantized patterns on the exported program (note: quantizer needs to be the same as the one used in prepare_and_convert_pt2) + ep = apply_pre_edge_transform_passes(ep, quantizer) # Get edge program after Cadence specific passes exec_prog: ExecutorchProgramManager = _lower_ep_to_cadence_gen_etrecord( diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml index 196480931e0..3ba6f4700b1 100644 --- a/backends/cadence/aot/functions.yaml +++ b/backends/cadence/aot/functions.yaml @@ -182,209 +182,289 @@ variants: function kernels: - arg_meta: null - kernel_name: impl::reference::quantize_per_tensor_out + kernel_name: impl::generic::quantize_per_tensor_out + +- func: cadence::quantize_per_tensor_asym8s.out(Tensor 
input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::generic::quantize_per_tensor_asym8s_out + +- func: cadence::quantize_per_tensor_asym8u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::generic::quantize_per_tensor_asym8u_out + +- func: cadence::quantize_per_tensor_asym16s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::generic::quantize_per_tensor_asym16s_out + +- func: cadence::quantize_per_tensor_asym16u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::generic::quantize_per_tensor_asym16u_out + +- func: cadence::quantize_per_tensor_asym32s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::generic::quantize_per_tensor_asym32s_out - func: cadence::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null - kernel_name: impl::reference::dequantize_per_tensor_out + kernel_name: impl::generic::dequantize_per_tensor_out + +- func: cadence::dequantize_per_tensor_asym8s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::generic::dequantize_per_tensor_asym8s_out + +- func: cadence::dequantize_per_tensor_asym8u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::generic::dequantize_per_tensor_asym8u_out + +- func: cadence::dequantize_per_tensor_asym16s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::generic::dequantize_per_tensor_asym16s_out + +- func: cadence::dequantize_per_tensor_asym16u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::generic::dequantize_per_tensor_asym16u_out + +- func: cadence::dequantize_per_tensor_asym32s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::generic::dequantize_per_tensor_asym32s_out -- func: cadence::quantized_conv_nchw.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!) 
+- func: cadence::quantized_conv2d_nchw.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_conv_nchw_out + kernel_name: impl::generic::quantized_conv2d_nchw_out -- func: cadence::quantized_conv_nhwc.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv2d_nhwc.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_conv_nhwc_out + kernel_name: impl::generic::quantized_conv2d_nhwc_out - func: cadence::quantized_layer_norm.out(Tensor input, Tensor in_scale, Tensor in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_layer_norm_out + kernel_name: impl::generic::quantized_layer_norm_out - func: cadence::quantized_layer_norm.per_tensor_out(Tensor input, float in_scale, int in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_layer_norm_per_tensor_out + kernel_name: impl::generic::quantized_layer_norm_per_tensor_out - func: cadence::quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_linear_out + kernel_name: impl::generic::quantized_linear_out - func: cadence::quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_linear_per_tensor_out + kernel_name: impl::generic::quantized_linear_per_tensor_out - func: cadence::quantized_linear_asym8sxasym8s_asym8s.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_linear_asym8sxasym8s_asym8s_per_tensor_out + kernel_name: impl::generic::quantized_linear_asym8sxasym8s_asym8s_per_tensor_out - func: cadence::quantized_linear_asym8uxasym8u_asym8u.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) 
kernels: - arg_meta: null - kernel_name: impl::reference::quantized_linear_asym8uxasym8u_asym8u_per_tensor_out + kernel_name: impl::generic::quantized_linear_asym8uxasym8u_asym8u_per_tensor_out - func: cadence::quantized_relu.out(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_relu_out + kernel_name: impl::generic::quantized_relu_out - func: cadence::quantized_relu.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_relu_per_tensor_out + kernel_name: impl::generic::quantized_relu_per_tensor_out - func: cadence::quantized_relu_asym8s_asym8s.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_relu_asym8s_asym8s_per_tensor_out + kernel_name: impl::generic::quantized_relu_asym8s_asym8s_per_tensor_out - func: cadence::quantized_relu_asym8u_asym8u.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_relu_asym8u_asym8u_per_tensor_out + kernel_name: impl::generic::quantized_relu_asym8u_asym8u_per_tensor_out - func: cadence::quantized_add.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_add_per_tensor_out + kernel_name: impl::generic::quantized_add_per_tensor_out - func: cadence::quantized_add_asym8sxasym8s_asym8s.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_add_asym8sxasym8s_asym8s_per_tensor_out + kernel_name: impl::generic::quantized_add_asym8sxasym8s_asym8s_per_tensor_out - func: cadence::quantized_add_asym8uxasym8u_asym8u.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_add_asym8uxasym8u_asym8u_per_tensor_out + kernel_name: impl::generic::quantized_add_asym8uxasym8u_asym8u_per_tensor_out - func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_matmul_out + kernel_name: impl::generic::quantized_matmul_out - func: cadence::quantized_matmul_asym8sxasym8s_asym8s.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_matmul_asym8sxasym8s_asym8s_out + kernel_name: impl::generic::quantized_matmul_asym8sxasym8s_asym8s_out - func: cadence::quantized_matmul_asym8uxasym8u_asym8u.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? 
bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_matmul_asym8uxasym8u_asym8u_out + kernel_name: impl::generic::quantized_matmul_asym8uxasym8u_asym8u_out - func: cadence::im2row.out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, Tensor in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::im2row_out + kernel_name: impl::generic::im2row_out - func: cadence::im2row.per_tensor_out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, int in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::im2row_per_tensor_out + kernel_name: impl::generic::im2row_per_tensor_out + +- func: cadence::quantized_conv2d_nchw.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::generic::quantized_conv2d_nchw_per_tensor_out + +- func: cadence::quantized_conv2d_nhwc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::generic::quantized_conv2d_nhwc_per_tensor_out + +- func: cadence::quantized_conv2d_nchw_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::generic::quantized_conv2d_nchw_asym8sxsym8s_asym8s_per_tensor_out + +- func: cadence::quantized_conv2d_nchw_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::generic::quantized_conv2d_nchw_asym8uxsym8u_asym8u_per_tensor_out -- func: cadence::quantized_conv_nchw.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv2d_nhwc_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) 
kernels: - arg_meta: null - kernel_name: impl::reference::quantized_conv_nchw_per_tensor_out + kernel_name: impl::generic::quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor_out -- func: cadence::quantized_conv_nhwc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv2d_nhwc_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_conv_nhwc_per_tensor_out + kernel_name: impl::generic::quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor_out -- func: cadence::quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_out + kernel_name: impl::generic::quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out -- func: cadence::quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_out + kernel_name: impl::generic::quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out -- func: cadence::quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) 
out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_out + kernel_name: impl::generic::quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_out -- func: cadence::quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_out + kernel_name: impl::generic::quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out -- func: cadence::quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out + kernel_name: impl::generic::quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out -- func: cadence::quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out + kernel_name: impl::generic::quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out -- func: cadence::quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) 
+- func: cadence::quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_out + kernel_name: impl::generic::quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out -- func: cadence::quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out + kernel_name: impl::generic::quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out -- func: cadence::quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv1d_ncl_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out + kernel_name: impl::generic::quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor_out -- func: cadence::quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv1d_ncl_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) 
kernels: - arg_meta: null - kernel_name: impl::reference::quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out + kernel_name: impl::generic::quantized_conv1d_ncl_asym8uxsym8u_asym8u_per_tensor_out -- func: cadence::quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv1d_nlc_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out + kernel_name: impl::generic::quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor_out -- func: cadence::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv1d_nlc_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out + kernel_name: impl::generic::quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor_out - func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_fully_connected_out + kernel_name: impl::generic::quantized_fully_connected_out - func: cadence::quantized_fully_connected.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_fully_connected_per_tensor_out + kernel_name: impl::generic::quantized_fully_connected_per_tensor_out - func: cadence::quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out + kernel_name: impl::generic::quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out - func: cadence::quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) 
out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out + kernel_name: impl::generic::quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out - func: cadence::requantize.out(Tensor input, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::requantize_out + kernel_name: impl::generic::requantize_out - func: cadence::requantize.per_tensor_out(Tensor input, float in_scale, int in_zero_point, float out_scale, int out_zero_point, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::reference::requantize_per_tensor_out + kernel_name: impl::generic::requantize_per_tensor_out diff --git a/backends/cadence/aot/functions_fusion_g3.yaml b/backends/cadence/aot/functions_fusion_g3.yaml index 269e8a08e4b..d41c19c0b01 100644 --- a/backends/cadence/aot/functions_fusion_g3.yaml +++ b/backends/cadence/aot/functions_fusion_g3.yaml @@ -20,17 +20,17 @@ - op: _softmax.out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::_softmax_out + kernel_name: impl::G3::_softmax_out - op: add.out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::add_out + kernel_name: impl::G3::add_out - op: add.Scalar_out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::add_scalar_out + kernel_name: impl::G3::add_scalar_out - op: bmm.out kernels: @@ -40,18 +40,18 @@ - op: cat.out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::cat_out + kernel_name: impl::G3::cat_out - op: clamp.out cpp_no_default_args: ['min'] kernels: - arg_meta: null - kernel_name: cadence::impl::G3::clamp_out + kernel_name: impl::G3::clamp_out - op: clamp.Tensor_out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::clamp_Tensor_out + kernel_name: impl::G3::clamp_Tensor_out - op: clone.out kernels: @@ -61,12 +61,12 @@ - op: div.out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::div_out + kernel_name: impl::G3::div_out - op: div.out_mode kernels: - arg_meta: null - kernel_name: cadence::impl::G3::div_out_mode + kernel_name: impl::G3::div_out_mode - op: embedding.out kernels: @@ -81,41 +81,41 @@ - op: lt.Scalar_out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::lt_Scalar_out + kernel_name: impl::G3::lt_Scalar_out - op: lt.Tensor_out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::lt_Tensor_out + kernel_name: impl::G3::lt_Tensor_out - op: mul.out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::mul_out + kernel_name: impl::G3::mul_out - op: mul.Scalar_out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::mul_scalar_out + kernel_name: impl::G3::mul_scalar_out - op: permute_copy.out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::permute_copy_out + kernel_name: impl::G3::permute_copy_out - op: rsqrt.out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::rsqrt_out + kernel_name: impl::G3::rsqrt_out - op: sigmoid.out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::sigmoid_out + kernel_name: impl::G3::sigmoid_out - op: slice_copy.Tensor_out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::slice_copy_Tensor_out + kernel_name: impl::G3::slice_copy_Tensor_out - op: split_with_sizes_copy.out kernels: @@ -125,27 +125,27 @@ - op: sqrt.out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::sqrt_out + kernel_name: impl::G3::sqrt_out - op: sub.out 
kernels: - arg_meta: null - kernel_name: cadence::impl::G3::sub_out + kernel_name: impl::G3::sub_out - op: sub.Scalar_out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::sub_scalar_out + kernel_name: impl::G3::sub_scalar_out - op: tanh.out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::tanh_out + kernel_name: impl::G3::tanh_out - op: transpose_copy.int_out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::transpose_copy_int_out + kernel_name: impl::G3::transpose_copy_int_out - op: view_copy.out kernels: @@ -155,37 +155,37 @@ - op: where.self_out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::where_self_out + kernel_name: impl::G3::where_self_out - op: native_layer_norm.out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::native_layer_norm_out + kernel_name: impl::G3::native_layer_norm_out - op: mean.out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::mean_out + kernel_name: impl::G3::mean_out - op: exp.out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::exp_out - + kernel_name: impl::G3::exp_out + - op: hardtanh.out kernels: - arg_meta: null - kernel_name: cadence::impl::G3::hardtanh_out + kernel_name: impl::G3::hardtanh_out # custom ops - func: cadence::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null - kernel_name: cadence::impl::G3::native::quantize_per_tensor_out + kernel_name: impl::G3::native::quantize_per_tensor_out - func: cadence::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null - kernel_name: cadence::impl::G3::native::dequantize_per_tensor_out + kernel_name: impl::G3::native::dequantize_per_tensor_out diff --git a/backends/cadence/aot/functions_hifi.yaml b/backends/cadence/aot/functions_hifi.yaml index cf4c5a8fffb..3bdbb33d59b 100644 --- a/backends/cadence/aot/functions_hifi.yaml +++ b/backends/cadence/aot/functions_hifi.yaml @@ -20,62 +20,62 @@ - op: _softmax.out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::_softmax_out + kernel_name: impl::HiFi::_softmax_out - op: atan2.out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::atan2_out + kernel_name: impl::HiFi::atan2_out - op: add.out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::add_out + kernel_name: impl::HiFi::add_out - op: bitwise_and.Scalar_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::bitwise_and_Scalar_out + kernel_name: impl::HiFi::bitwise_and_Scalar_out - op: bitwise_and.Tensor_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::bitwise_and_Tensor_out + kernel_name: impl::HiFi::bitwise_and_Tensor_out - op: bitwise_or.Scalar_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::bitwise_or_Scalar_out + kernel_name: impl::HiFi::bitwise_or_Scalar_out - op: bitwise_or.Tensor_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::bitwise_or_Tensor_out + kernel_name: impl::HiFi::bitwise_or_Tensor_out - op: bitwise_xor.Scalar_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::bitwise_xor_Scalar_out + kernel_name: impl::HiFi::bitwise_xor_Scalar_out - op: bitwise_xor.Tensor_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::bitwise_xor_Tensor_out + kernel_name: impl::HiFi::bitwise_xor_Tensor_out - op: bmm.out kernels: - arg_meta: 
null - kernel_name: cadence::impl::HiFi::bmm_out + kernel_name: impl::HiFi::bmm_out - op: cat.out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::cat_out + kernel_name: impl::HiFi::cat_out - op: clamp.Tensor_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::clamp_Tensor_out + kernel_name: impl::HiFi::clamp_Tensor_out - op: clone.out kernels: @@ -85,47 +85,47 @@ - op: div.out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::div_out + kernel_name: impl::HiFi::div_out - op: div.out_mode kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::div_out_mode + kernel_name: impl::HiFi::div_out_mode - op: embedding.out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::embedding_out + kernel_name: impl::HiFi::embedding_out - op: eq.Tensor_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::eq_Tensor_out + kernel_name: impl::HiFi::eq_Tensor_out - op: fmod.Tensor_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::fmod_Tensor_out + kernel_name: impl::HiFi::fmod_Tensor_out - op: fmod.Scalar_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::fmod_Scalar_out + kernel_name: impl::HiFi::fmod_Scalar_out - op: full.out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::full_out + kernel_name: impl::HiFi::full_out - op: ge.Scalar_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::ge_Scalar_out + kernel_name: impl::HiFi::ge_Scalar_out - op: ge.Tensor_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::ge_Tensor_out + kernel_name: impl::HiFi::ge_Tensor_out - op: gelu.out kernels: @@ -135,42 +135,42 @@ - op: gt.Scalar_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::gt_Scalar_out + kernel_name: impl::HiFi::gt_Scalar_out - op: gt.Tensor_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::gt_Tensor_out + kernel_name: impl::HiFi::gt_Tensor_out - op: hardtanh.out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::hardtanh_out + kernel_name: impl::HiFi::hardtanh_out - op: le.Scalar_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::le_Scalar_out + kernel_name: impl::HiFi::le_Scalar_out - op: le.Tensor_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::le_Tensor_out + kernel_name: impl::HiFi::le_Tensor_out - op: lt.Scalar_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::lt_Scalar_out + kernel_name: impl::HiFi::lt_Scalar_out - op: lt.Tensor_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::lt_Tensor_out + kernel_name: impl::HiFi::lt_Tensor_out - op: masked_fill.Scalar_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::masked_fill_Scalar_out + kernel_name: impl::HiFi::masked_fill_Scalar_out - op: max_pool2d_with_indices.out kernels: @@ -180,291 +180,386 @@ - op: maximum.out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::maximum_out + kernel_name: impl::HiFi::maximum_out - op: mean.out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::mean_out + kernel_name: impl::HiFi::mean_out - op: minimum.out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::minimum_out + kernel_name: impl::HiFi::minimum_out - op: mm.out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::mm_out + kernel_name: impl::HiFi::mm_out - op: mul.out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::mul_out + kernel_name: impl::HiFi::mul_out - op: ne.Tensor_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::ne_Tensor_out + 
kernel_name: impl::HiFi::ne_Tensor_out - op: permute_copy.out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::permute_copy_out + kernel_name: impl::HiFi::permute_copy_out - op: pow.Scalar_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::pow_Scalar_out + kernel_name: impl::HiFi::pow_Scalar_out - op: pow.Tensor_Scalar_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::pow_Tensor_Scalar_out + kernel_name: impl::HiFi::pow_Tensor_Scalar_out - op: pow.Tensor_Tensor_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::pow_Tensor_Tensor_out + kernel_name: impl::HiFi::pow_Tensor_Tensor_out - op: remainder.Tensor_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::remainder_Tensor_out + kernel_name: impl::HiFi::remainder_Tensor_out - op: rsqrt.out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::rsqrt_out + kernel_name: impl::HiFi::rsqrt_out - op: select_copy.int_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::select_copy_int_out + kernel_name: impl::HiFi::select_copy_int_out - op: sigmoid.out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::sigmoid_out + kernel_name: impl::HiFi::sigmoid_out - op: slice_copy.Tensor_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::slice_copy_Tensor_out + kernel_name: impl::HiFi::slice_copy_Tensor_out - op: split_with_sizes_copy.out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::split_with_sizes_copy_out + kernel_name: impl::HiFi::split_with_sizes_copy_out - op: sub.out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::sub_out + kernel_name: impl::HiFi::sub_out - op: tanh.out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::tanh_out + kernel_name: impl::HiFi::tanh_out - op: view_copy.out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::view_copy_out + kernel_name: impl::HiFi::view_copy_out - op: where.self_out kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::where_self_out + kernel_name: impl::HiFi::where_self_out # custom ops - func: cadence::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantize_per_tensor_out + kernel_name: impl::HiFi::quantize_per_tensor_out + +- func: cadence::quantize_per_tensor_asym8s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::HiFi::quantize_per_tensor_asym8s_out + +- func: cadence::quantize_per_tensor_asym8u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::HiFi::quantize_per_tensor_asym8u_out + +- func: cadence::quantize_per_tensor_asym16s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::HiFi::quantize_per_tensor_asym16s_out + +- func: cadence::quantize_per_tensor_asym16u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) 
+ variants: function + kernels: + - arg_meta: null + kernel_name: impl::HiFi::quantize_per_tensor_asym16u_out + +- func: cadence::quantize_per_tensor_asym32s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::HiFi::quantize_per_tensor_asym32s_out - func: cadence::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::dequantize_per_tensor_out + kernel_name: impl::HiFi::dequantize_per_tensor_out -- func: cadence::quantized_conv_nchw.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::dequantize_per_tensor_asym8s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_conv_nchw_out + kernel_name: impl::HiFi::dequantize_per_tensor_asym8s_out -- func: cadence::quantized_conv_nhwc.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::dequantize_per_tensor_asym8u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_out + kernel_name: impl::HiFi::dequantize_per_tensor_asym8u_out -- func: cadence::quantized_conv_nchw.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::dequantize_per_tensor_asym16s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::HiFi::dequantize_per_tensor_asym16s_out + +- func: cadence::dequantize_per_tensor_asym16u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::HiFi::dequantize_per_tensor_asym16u_out + +- func: cadence::dequantize_per_tensor_asym32s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+ variants: function kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_conv_nchw_per_tensor_out + kernel_name: impl::HiFi::dequantize_per_tensor_asym32s_out -- func: cadence::quantized_conv_nhwc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv2d_nchw.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_per_tensor_out + kernel_name: impl::HiFi::quantized_conv2d_nchw_out -- func: cadence::quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv2d_nhwc.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_out + kernel_name: impl::HiFi::quantized_conv2d_nhwc_out -- func: cadence::quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv2d_nchw.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_out + kernel_name: impl::HiFi::quantized_conv2d_nchw_per_tensor_out -- func: cadence::quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv2d_nhwc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_out + kernel_name: impl::HiFi::quantized_conv2d_nhwc_per_tensor_out -- func: cadence::quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv2d_nchw_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_out + kernel_name: impl::HiFi::quantized_conv2d_nchw_asym8sxsym8s_asym8s_per_tensor_out -- func: cadence::quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv2d_nchw_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out + kernel_name: impl::HiFi::quantized_conv2d_nchw_asym8uxsym8u_asym8u_per_tensor_out -- func: cadence::quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv2d_nhwc_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out + kernel_name: impl::HiFi::quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor_out -- func: cadence::quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) 
+- func: cadence::quantized_conv2d_nhwc_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_out + kernel_name: impl::HiFi::quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor_out -- func: cadence::quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out + kernel_name: impl::HiFi::quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out -- func: cadence::quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out + kernel_name: impl::HiFi::quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out -- func: cadence::quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) 
kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out + kernel_name: impl::HiFi::quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_out -- func: cadence::quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out + kernel_name: impl::HiFi::quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out -- func: cadence::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out + kernel_name: impl::HiFi::quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out + +- func: cadence::quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::HiFi::quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out + +- func: cadence::quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::HiFi::quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out + +- func: cadence::quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) 
+ kernels: + - arg_meta: null + kernel_name: impl::HiFi::quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out + +- func: cadence::quantized_conv1d_ncl_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::HiFi::quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor_out + +- func: cadence::quantized_conv1d_ncl_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::HiFi::quantized_conv1d_ncl_asym8uxsym8u_asym8u_per_tensor_out + +- func: cadence::quantized_conv1d_nlc_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::HiFi::quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor_out + +- func: cadence::quantized_conv1d_nlc_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::HiFi::quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor_out - func: cadence::quantized_layer_norm.out(Tensor input, Tensor in_scale, Tensor in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_layer_norm_out + kernel_name: impl::HiFi::quantized_layer_norm_out - func: cadence::quantized_layer_norm.per_tensor_out(Tensor input, float in_scale, int in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_layer_norm_per_tensor_out + kernel_name: impl::HiFi::quantized_layer_norm_per_tensor_out - func: cadence::quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_linear_out + kernel_name: impl::HiFi::quantized_linear_out - func: cadence::quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) 
kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_linear_per_tensor_out + kernel_name: impl::HiFi::quantized_linear_per_tensor_out - func: cadence::quantized_linear_asym8sxasym8s_asym8s.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_linear_asym8sxasym8s_asym8s_per_tensor_out + kernel_name: impl::HiFi::quantized_linear_asym8sxasym8s_asym8s_per_tensor_out - func: cadence::quantized_linear_asym8uxasym8u_asym8u.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_linear_asym8uxasym8u_asym8u_per_tensor_out + kernel_name: impl::HiFi::quantized_linear_asym8uxasym8u_asym8u_per_tensor_out - func: cadence::quantized_relu_per_tensor.out(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_relu_per_tensor_out + kernel_name: impl::HiFi::quantized_relu_per_tensor_out - func: cadence::quantized_relu.out(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_relu_out + kernel_name: impl::HiFi::quantized_relu_out - func: cadence::quantized_relu.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_relu_per_tensor_out + kernel_name: impl::HiFi::quantized_relu_per_tensor_out - func: cadence::quantized_relu_asym8s_asym8s.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_relu_asym8s_asym8s_per_tensor_out + kernel_name: impl::HiFi::quantized_relu_asym8s_asym8s_per_tensor_out - func: cadence::quantized_relu_asym8u_asym8u.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_relu_asym8u_asym8u_per_tensor_out + kernel_name: impl::HiFi::quantized_relu_asym8u_asym8u_per_tensor_out - func: cadence::quantized_add_asym8sxasym8s_asym8s.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_add_asym8sxasym8s_asym8s_per_tensor_out + kernel_name: impl::HiFi::quantized_add_asym8sxasym8s_asym8s_per_tensor_out - func: cadence::quantized_add_asym8uxasym8u_asym8u.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!) 
kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_add_asym8uxasym8u_asym8u_per_tensor_out + kernel_name: impl::HiFi::quantized_add_asym8uxasym8u_asym8u_per_tensor_out - func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_matmul_out + kernel_name: impl::HiFi::quantized_matmul_out - func: cadence::quantized_matmul_asym8sxasym8s_asym8s.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_matmul_asym8sxasym8s_asym8s_out + kernel_name: impl::HiFi::quantized_matmul_asym8sxasym8s_asym8s_out - func: cadence::quantized_matmul_asym8uxasym8u_asym8u.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_matmul_asym8uxasym8u_asym8u_out + kernel_name: impl::HiFi::quantized_matmul_asym8uxasym8u_asym8u_out - func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_fully_connected_out + kernel_name: impl::HiFi::quantized_fully_connected_out - func: cadence::quantized_fully_connected.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_fully_connected_per_tensor_out + kernel_name: impl::HiFi::quantized_fully_connected_per_tensor_out - func: cadence::quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out + kernel_name: impl::HiFi::quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out - func: cadence::quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: cadence::impl::HiFi::quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out + kernel_name: impl::HiFi::quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out + +- func: cadence::quantized_w8a32_linear.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::HiFi::quantized_w8a32_linear_out + +- func: cadence::quantized_w8a32_conv.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!) 
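Note on the quantized_w8a32_* entries being registered here (linear above, conv and gru continuing below): the schemas suggest int8 weights with a single per-tensor w_scale and a separately scaled bias, applied to fp32 activations. The following is only a hedged fp32 reference sketch of that reading in plain PyTorch; the actual impl::HiFi kernels and their numerics are not part of this patch, so every detail (weight layout, bias handling) is an assumption.

import torch

def w8a32_linear_ref(
    input_fp32: torch.Tensor,   # fp32 activations, shape (..., in_features)
    weight_int8: torch.Tensor,  # int8 weights, assumed (out_features, in_features)
    w_scale: float,
    bias: torch.Tensor,
    b_scale: float,
) -> torch.Tensor:
    # Assumed semantics: dequantize weights and bias with their per-tensor
    # scales, then run an ordinary fp32 linear.
    w = weight_int8.to(torch.float32) * w_scale
    b = bias.to(torch.float32) * b_scale
    return torch.nn.functional.linear(input_fp32, w, b)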
+ kernels: + - arg_meta: null + kernel_name: impl::HiFi::quantized_w8a32_conv_out + +- func: cadence::quantized_w8a32_gru.out(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_i_scale, Tensor bias_hidden, float b_h_scale, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::HiFi::quantized_w8a32_gru_out diff --git a/backends/cadence/aot/functions_vision.yaml b/backends/cadence/aot/functions_vision.yaml new file mode 100644 index 00000000000..cae1e0dc415 --- /dev/null +++ b/backends/cadence/aot/functions_vision.yaml @@ -0,0 +1,275 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This yaml file contains operators that are also defined by the ATen library. +# For lean mode: +# - Codegen'd target `executorch_generated_lib` will be reading all the information +# from this file, including operator schema and kernel metadata. +# - Selective build target `codegen:executorch_defined_ops` now is selecting all the +# operators in this file, by dumping all the op names into `selected_operators.yaml`. +# +# See the README.md file in executorch/kernels/portable for a description of the syntax used +# by this file. + + +# aten ops +- op: _to_copy.out + kernels: + - arg_meta: null + kernel_name: torch::executor::to_copy_out + +- op: _softmax.out + kernels: + - arg_meta: null + kernel_name: impl::vision::_softmax_out + +- op: add.out + kernels: + - arg_meta: null + kernel_name: impl::vision::add_out + +- op: bmm.out + kernels: + - arg_meta: null + kernel_name: torch::executor::bmm_out + +- op: cat.out + kernels: + - arg_meta: null + kernel_name: torch::executor::cat_out + +- op: clone.out + kernels: + - arg_meta: null + kernel_name: torch::executor::clone_out + +- op: div.out + kernels: + - arg_meta: null + kernel_name: torch::executor::div_out + +- op: div.out_mode + kernels: + - arg_meta: null + kernel_name: torch::executor::div_out_mode + +- op: embedding.out + kernels: + - arg_meta: null + kernel_name: impl::vision::embedding_out + +- op: empty.out + kernels: + - arg_meta: null + kernel_name: torch::executor::empty_out + +- op: expand_copy.out + kernels: + - arg_meta: null + kernel_name: torch::executor::expand_copy_out + +- op: full.out + kernels: + - arg_meta: null + kernel_name: impl::vision::full_out + +- op: gelu.out + kernels: + - arg_meta: null + kernel_name: torch::executor::gelu_out + +- op: hardtanh.out + kernels: + - arg_meta: null + kernel_name: torch::executor::hardtanh_out + +- op: max_pool2d_with_indices.out + kernels: + - arg_meta: null + kernel_name: torch::executor::max_pool2d_with_indices_out + +- op: mean.out + kernels: + - arg_meta: null + kernel_name: torch::executor::mean_dim_out + +- op: mul.out + kernels: + - arg_meta: null + kernel_name: torch::executor::mul_out + +- op: mul.Scalar_out + kernels: + - arg_meta: null + kernel_name: torch::executor::mul_scalar_out + +- op: permute_copy.out + kernels: + - arg_meta: null + kernel_name: torch::executor::permute_copy_out + +- op: rsqrt.out + kernels: + - arg_meta: null + kernel_name: torch::executor::rsqrt_out + +- op: sigmoid.out + kernels: + - arg_meta: null + kernel_name: torch::executor::sigmoid_out + +- op: slice_copy.Tensor_out + kernels: + - arg_meta: null + kernel_name: torch::executor::slice_copy_Tensor_out + +- op: split_with_sizes_copy.out + kernels: + - arg_meta: null + kernel_name: torch::executor::split_with_sizes_copy_out + +- op: sub.out + kernels: + - arg_meta: null + kernel_name: 
torch::executor::sub_out + +- op: view_copy.out + kernels: + - arg_meta: null + kernel_name: impl::vision::view_copy_out + +- op: where.self_out + kernels: + - arg_meta: null + kernel_name: torch::executor::where_out + +- op: transpose_copy.int_out + kernels: + - arg_meta: null + kernel_name: torch::executor::transpose_copy_int_out + +- op: eq.Scalar_out + kernels: + - arg_meta: null + kernel_name: torch::executor::eq_scalar_out + +- op: logical_not.out + kernels: + - arg_meta: null + kernel_name: torch::executor::logical_not_out + +- op: any.out + kernels: + - arg_meta: null + kernel_name: torch::executor::any_out + +- op: native_group_norm.out + kernels: + - arg_meta: null + kernel_name: torch::executor::native_group_norm_out + +- op: sum.IntList_out + kernels: + - arg_meta: null + kernel_name: torch::executor::sum_dim_out + +- op: select_copy.int_out + kernels: + - arg_meta: null + kernel_name: torch::executor::select_copy_int_out + +# custom ops +- func: cadence::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::vision::quantize_per_tensor_out + +- func: cadence::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::vision::dequantize_per_tensor_out + +- func: cadence::quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::quantized_conv_out + +- func: cadence::quantized_conv2d_nchw.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::quantized_conv2d_nchw_out + +- func: cadence::quantized_conv2d_nhwc.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::quantized_conv2d_nhwc_out + +- func: cadence::quantized_layer_norm.out(Tensor input, Tensor in_scale, Tensor in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::quantized_layer_norm_out +- func: cadence::quantized_layer_norm.per_tensor_out(Tensor input, float in_scale, int in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!) 
+ kernels: + - arg_meta: null + kernel_name: impl::vision::quantized_layer_norm_per_tensor_out + +- func: cadence::quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::quantized_linear_out + +- func: cadence::quantized_relu.out(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::quantized_relu_out + +- func: cadence::quantized_relu.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::quantized_relu_per_tensor_out + +- func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::quantized_matmul_out + +- func: cadence::quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::quantized_linear_per_tensor_out + +- func: cadence::im2row.out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, Tensor in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::im2row_out + +- func: cadence::im2row.per_tensor_out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, int in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::im2row_per_tensor_out + +- func: cadence::quantized_conv.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::quantized_conv_per_tensor_out + +- func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::quantized_fully_connected_out + +- func: cadence::quantized_fully_connected.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::quantized_fully_connected_per_tensor_out + +- func: cadence::requantize.out(Tensor input, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!) 
+ kernels: + - arg_meta: null + kernel_name: impl::vision::requantize_out + +- func: cadence::requantize.per_tensor_out(Tensor input, float in_scale, int in_zero_point, float out_scale, int out_zero_point, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::vision::requantize_per_tensor_out diff --git a/backends/cadence/aot/fuse_ops.py b/backends/cadence/aot/fuse_ops.py index dbd19e1d3af..0d5c511c239 100644 --- a/backends/cadence/aot/fuse_ops.py +++ b/backends/cadence/aot/fuse_ops.py @@ -34,12 +34,12 @@ from executorch.backends.cadence.aot.pass_utils import ( CadencePassAttribute, register_cadence_pass, + RemoveOrReplacePassInterface, ) from executorch.backends.cadence.aot.utils import get_edge_overload_packet from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue -from executorch.exir.passes import dead_code_elimination_pass from executorch.exir.passes.spec_prop_pass import SpecPropPass from torch.fx.node import Argument from torch.nn.utils.fusion import fuse_conv_bn_weights @@ -454,7 +454,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class FuseCascadedTransposeOrPermuteOps(ExportPass): +class FuseCascadedTransposeOrPermuteOps(RemoveOrReplacePassInterface): """ Fuse a cascaded chain of transpose and permute ops """ @@ -464,89 +464,89 @@ class FuseCascadedTransposeOrPermuteOps(ExportPass): exir_ops.edge.aten.permute_copy.default, } - # Find a chain of transpose or permute ops, and fuse them into a single permute op. + @property + def targets(self) -> list[EdgeOpOverload]: + return list(self.transpose_or_permute_target) - def fuse_cascaded_transpose_or_permute_ops( - self, graph_module: torch.fx.GraphModule - ): - graph = graph_module.graph - for node in graph.nodes: - # We are only interested in permute/transpose ops - if node.target not in self.transpose_or_permute_target: - continue - # Get the cascaded chain of transpose/permute ops starting at node - cascaded_transpose_or_permute_ops = get_cascaded_ops( - [node], self.transpose_or_permute_target + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # Get the cascaded chain of transpose/permute ops starting at node + cascaded_transpose_or_permute_ops = get_cascaded_ops( + [node], self.transpose_or_permute_target + ) + # The chain must have more than 1 node + if len(cascaded_transpose_or_permute_ops) == 1: + return False + + # Get shape from node metadata + val = node.meta.get("val") + if val is None: + return False + out_shape = val.shape + out_dims = len(out_shape) + + # This is the trivial dimension order + dims = list(range(out_dims)) + # Compute the effect of the chain on dims + for tp in cascaded_transpose_or_permute_ops: + dims = ( + get_transposed_dims(tp, dims) + if tp.target == exir_ops.edge.aten.transpose_copy.int + else get_permuted_dims(tp, dims) ) - # The chain must have more than 1 node - if len(cascaded_transpose_or_permute_ops) == 1: - continue - out_shape = get_shape(graph_module, node) - assert out_shape is not None - out_dims = len(out_shape) - # This is the trivial dimension order - dims = list(range(out_dims)) - # Compute the effect of the chain on dims - for tp in cascaded_transpose_or_permute_ops: - dims = ( - get_transposed_dims(tp, dims) - if tp.target == exir_ops.edge.aten.transpose_copy.int 
- else get_permuted_dims(tp, dims) - ) + graph = node.graph - # In case the permute chain cancelled each other, the final dims will - # be the same as the initial order. In that case, the chain was nop. - # Otherwise create a new permute op that encompasses the effect of the - # chain. - if dims == list(range(out_dims)): - cascaded_transpose_or_permute_ops[-1].replace_all_uses_with( - node.args[0] + # In case the permute chain cancelled each other, the final dims will + # be the same as the initial order. In that case, the chain was nop. + # Otherwise create a new permute op that encompasses the effect of the + # chain. + if dims == list(range(out_dims)): + cascaded_transpose_or_permute_ops[-1].replace_all_uses_with( + cast(torch.fx.Node, node.args[0]) + ) + else: + with graph.inserting_before(cascaded_transpose_or_permute_ops[-1]): + new_permute = graph.call_function( + exir_ops.edge.aten.permute_copy.default, + args=(node.args[0], dims), ) - else: - with graph.inserting_before(cascaded_transpose_or_permute_ops[-1]): - new_permute = graph.call_function( - exir_ops.edge.aten.permute_copy.default, - args=(node.args[0], dims), - ) - cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(new_permute) + new_permute.meta = cascaded_transpose_or_permute_ops[-1].meta + cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(new_permute) - # Now erase the chain - for tp in reversed(cascaded_transpose_or_permute_ops): - graph.erase_node(tp) + # Now erase the chain (except the first node which will be handled by the interface) + for tp in reversed(cascaded_transpose_or_permute_ops[1:]): + graph.erase_node(tp) - graph_module.recompile() - - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - self.fuse_cascaded_transpose_or_permute_ops(graph_module) - result = super().call(graph_module) - return result + # Return True to indicate the first node in the chain should be removed + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class FuseCascadedViewOps(ExportPass): +class FuseCascadedViewOps(RemoveOrReplacePassInterface): """ Fuse a cascaded chain of view ops """ - def fuse_cascaded_view_ops(self, graph_module: torch.fx.GraphModule): - view_target = exir_ops.edge.aten.view_copy.default - for view_node in graph_module.graph.find_nodes( - op="call_function", target=view_target, sort=True - ): - input_view = view_node.args[0] - if input_view.op != "call_function" or input_view.target != view_target: - continue + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.view_copy.default] - view_node.replace_input_with(input_view, input_view.args[0]) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # Check if the input to this view node is also a view node + input_view = node.args[0] + if not isinstance(input_view, torch.fx.Node): + return False - graph_module.recompile() + if ( + input_view.op != "call_function" + or input_view.target != exir_ops.edge.aten.view_copy.default + ): + return False - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - self.fuse_cascaded_view_ops(graph_module) - dead_code_elimination_pass(graph_module) - result = super().call(graph_module) - return result + # Replace the input of this view node with the input of the cascaded view + # This effectively "skips" the intermediate view node + node.replace_input_with(input_view, cast(torch.fx.Node, input_view.args[0])) + return True class FuseOpPairsAcrossBranchesPass(ExportPass): diff --git 
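The rewritten FuseCascadedTransposeOrPermuteOps pass above folds a chain of transpose/permute nodes by starting from the identity dim order and applying each node's reordering in turn; if the accumulated order returns to the identity, the whole chain is a no-op, otherwise it collapses into a single permute_copy. A minimal standalone sketch of that folding step, using plain lists instead of FX nodes and a hypothetical fold_permutations helper (transposes are treated as their equivalent permutations, as get_transposed_dims does in the pass):

def fold_permutations(rank: int, perms: list[list[int]]) -> list[int] | None:
    # Compose a chain of permutations (torch.permute semantics: output dim i
    # takes input dim perm[i]). Returns the single equivalent permutation, or
    # None when the chain cancels out and the nodes can simply be removed.
    dims = list(range(rank))
    for perm in perms:
        dims = [dims[i] for i in perm]
    return None if dims == list(range(rank)) else dims

# Two identical swaps cancel, so users are rewired straight to the input:
assert fold_permutations(3, [[0, 2, 1], [0, 2, 1]]) is None
# A net reordering is replaced by one permute_copy with the folded dims:
assert fold_permutations(3, [[0, 2, 1], [1, 0, 2]]) == [2, 0, 1]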
a/backends/cadence/aot/graph_builder.py b/backends/cadence/aot/graph_builder.py index 2cfd7900e8e..f609ba55472 100644 --- a/backends/cadence/aot/graph_builder.py +++ b/backends/cadence/aot/graph_builder.py @@ -44,12 +44,12 @@ class GraphBuilder(ExportPass): gm = builder.get_graph_module() """ - def __init__(self) -> None: + def __init__(self, fake_tensor_mode: Optional[FakeTensorMode] = None) -> None: self.exporter = ExportPass() self.tracer: ExportPass.ExportTracer = self.ExportTracer( self, torch.fx.graph.CodeGen() ) - self.fake_tensor_mode = FakeTensorMode( + self.fake_tensor_mode: FakeTensorMode = fake_tensor_mode or FakeTensorMode( allow_fallback_kernels=False, allow_non_fake_inputs=True, ) diff --git a/backends/cadence/aot/memory_constraints.py b/backends/cadence/aot/memory_constraints.py index 0eaaa8987c6..4b637da8d48 100644 --- a/backends/cadence/aot/memory_constraints.py +++ b/backends/cadence/aot/memory_constraints.py @@ -417,6 +417,10 @@ def is_slice_view(self, node: torch.fx.Node) -> bool: return not self.constraint.is_alias_of(source_info.source, node) return False + def has_relative_placement_constraint(self, node: torch.fx.Node) -> bool: + """Return if `node` already has any relative placement constraint.""" + return self.constraint.get_relative_placement_source(node) is not None + # Return true if the cat node performs concatenation along outermost dimension def is_cat_along_outermost_dim( self, graph_module: torch.fx.GraphModule, cat_node: torch.fx.Node @@ -481,6 +485,17 @@ def is_removable_cat_op( if any(self.is_slice_view(arg) for arg in cat_tensors): return False + # If any of the tensors already has a relative placement constraint, + # we cannot add a new constraint for this cat without conflicting. + # This can happen when a tensor is used in multiple cat operations. + if any(self.has_relative_placement_constraint(arg) for arg in cat_tensors): + return False + + # If the same tensor appears multiple times in the cat inputs, + # we cannot place it at multiple different offsets relative to the output. + if len(cat_tensors) != len(set(cat_tensors)): + return False + # Many ops in HiFi require the input to be aligned to 8-byte boundary. # If the cat is not the graph's output, then ensure that the relative # offset of any concatenated non-placeholder tensor is a multiple of diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index 35b4cbf3902..06c5854b120 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -7,7 +7,7 @@ # pyre-strict from math import prod -from typing import Optional, Tuple +from typing import Callable, Optional, Tuple import torch from executorch.backends.cadence.aot.utils import ( @@ -21,6 +21,67 @@ lib = Library("cadence", "DEF") +# Track meta kernels that have been registered +_REGISTERED_META_KERNELS: set[str] = set() + + +# Original register_fake function to use for registrations +_register_fake_original = register_fake + +_OUTPUTS_TYPE = torch.Tensor | tuple[torch.Tensor, ...] + + +def _validate_ref_impl_exists() -> None: + """ + Validates that all registered meta kernels have corresponding reference implementations. + This is called at module initialization time after both files have been imported. 
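The GraphBuilder change earlier in this section makes the FakeTensorMode injectable so that several builders can share one mode instead of each creating its own. A hypothetical usage sketch (the import paths follow the file locations shown in this patch and current PyTorch layout; treat them as assumptions):

from torch._subclasses.fake_tensor import FakeTensorMode
from executorch.backends.cadence.aot.graph_builder import GraphBuilder

# One shared mode, mirroring the defaults GraphBuilder uses when none is passed.
shared_mode = FakeTensorMode(allow_fallback_kernels=False, allow_non_fake_inputs=True)
builder_a = GraphBuilder(fake_tensor_mode=shared_mode)
builder_b = GraphBuilder(fake_tensor_mode=shared_mode)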
+ """ + + # Import here after module initialization to ensure ref_implementations has been loaded + from executorch.backends.cadence.aot.ref_implementations import ( + get_registered_ref_implementations, + ) + + # If reference implementation should not be in + # executorch.backends.cadence.aot.ref_implementations, add here + _SKIP_OPS = { + "cadence::roi_align_box_processor", + } + + ref_impls = get_registered_ref_implementations() + error_impls = [] + for op_name in _REGISTERED_META_KERNELS: + # Strip the namespace prefix if present (e.g., "cadence::" -> "") + op_name_clean = op_name.split("::")[-1] if "::" in op_name else op_name + + if op_name_clean not in ref_impls: + if op_name not in _SKIP_OPS: + error_impls.append(op_name) + + if error_impls: + error_msg = ( + f"The following {len(error_impls)} meta kernel registrations are missing reference implementations:\n" + + "\n".join(f" - {op}" for op in error_impls) + + "\n\nPlease add reference implementations in ref_implementations.py using " + + "@impl_tracked(m, '')." + ) + + raise RuntimeError(error_msg) + + +# Wrap register_fake to track all registrations +def register_fake( + op_name: str, +) -> Callable[[Callable[..., _OUTPUTS_TYPE]], Callable[..., _OUTPUTS_TYPE]]: + """ + Wrapped version of register_fake that tracks all meta kernel registrations. + This enables validation that all meta kernels have reference implementations. + """ + global _REGISTERED_META_KERNELS + _REGISTERED_META_KERNELS.add(op_name) + return _register_fake_original(op_name) + + lib.define( "quantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)" ) @@ -28,12 +89,78 @@ "quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)" ) +lib.define( + "quantize_per_tensor_asym8s(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)" +) +lib.define( + "quantize_per_tensor_asym8s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)" +) + +lib.define( + "quantize_per_tensor_asym8u(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)" +) +lib.define( + "quantize_per_tensor_asym8u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)" +) + +lib.define( + "quantize_per_tensor_asym16s(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)" +) +lib.define( + "quantize_per_tensor_asym16s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)" +) + +lib.define( + "quantize_per_tensor_asym16u(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)" +) +lib.define( + "quantize_per_tensor_asym16u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)" +) + +lib.define( + "quantize_per_tensor_asym32s(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)" +) +lib.define( + "quantize_per_tensor_asym32s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) 
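The wrapped register_fake and _validate_ref_impl_exists above implement a "record every registration, then diff it against another registry" check. A stripped-down sketch of the same pattern, with purely illustrative names (these helpers are not the ones in ref_implementations.py):

META_KERNELS: set[str] = set()
REF_IMPLS: set[str] = set()

def tracked(registry: set[str], op_name: str):
    # Decorator factory: record the op name (namespace stripped) and return
    # the function unchanged, mirroring how the patch wraps register_fake.
    def decorator(fn):
        registry.add(op_name.split("::")[-1])
        return fn
    return decorator

def validate(skip: frozenset[str] = frozenset()) -> None:
    missing = sorted(META_KERNELS - REF_IMPLS - skip)
    if missing:
        raise RuntimeError(
            "Meta kernels without reference implementations: " + ", ".join(missing)
        )

@tracked(META_KERNELS, "cadence::my_op")
def my_op_meta(x): ...

validate(skip=frozenset({"my_op"}))  # passes only because my_op is explicitly skipped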
out) -> Tensor(a!)" +) + lib.define( "dequantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)" ) lib.define( "dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)" ) +lib.define( + "dequantize_per_tensor_asym8s(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)" +) +lib.define( + "dequantize_per_tensor_asym8s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "dequantize_per_tensor_asym8u(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)" +) +lib.define( + "dequantize_per_tensor_asym8u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "dequantize_per_tensor_asym16s(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)" +) +lib.define( + "dequantize_per_tensor_asym16s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "dequantize_per_tensor_asym16u(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)" +) +lib.define( + "dequantize_per_tensor_asym16u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)" +) + +lib.define( + "dequantize_per_tensor_asym32s(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)" +) +lib.define( + "dequantize_per_tensor_asym32s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)" +) lib.define( "quantized_layer_norm(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> (Tensor Y)" @@ -86,28 +213,28 @@ ) lib.define( - "quantized_conv_nhwc(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Z)" + "quantized_conv2d_nhwc(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Z)" ) lib.define( - "quantized_conv_nhwc.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)" + "quantized_conv2d_nhwc.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) 
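The dtype-suffixed quantize_per_tensor_* / dequantize_per_tensor_* variants defined above all carry the same (scale, zero_point, quant_min, quant_max, dtype) arguments, i.e. standard per-tensor affine quantization; the suffix presumably selects the output dtype (asym8s → int8, asym8u → uint8, asym16s → int16, asym32s → int32). A small generic reference sketch of that arithmetic in PyTorch (not the Cadence kernels):

import torch

def quantize_per_tensor_ref(x, scale, zero_point, quant_min, quant_max, dtype):
    # q = clamp(round(x / scale) + zero_point, quant_min, quant_max)
    q = torch.round(x / scale) + zero_point
    return torch.clamp(q, quant_min, quant_max).to(dtype)

def dequantize_per_tensor_ref(q, scale, zero_point, quant_min, quant_max, dtype):
    # quant_min/quant_max/dtype are part of the schema but not needed here:
    # x ≈ (q - zero_point) * scale
    return (q.to(torch.float32) - zero_point) * scale

# e.g. an asym8s-style call: int8 output with quant_min=-128, quant_max=127
x = torch.tensor([0.0, 0.1, -0.25])
q = quantize_per_tensor_ref(x, 0.05, 3, -128, 127, torch.int8)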
out) -> Tensor(a!)" ) lib.define( - "quantized_conv_nhwc.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" + "quantized_conv2d_nhwc.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" ) lib.define( - "quantized_conv_nhwc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" + "quantized_conv2d_nhwc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( - "quantized_conv_nchw(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Z)" + "quantized_conv2d_nchw(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Z)" ) lib.define( - "quantized_conv_nchw.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)" + "quantized_conv2d_nchw.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( - "quantized_conv_nchw.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" + "quantized_conv2d_nchw.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" ) lib.define( - "quantized_conv_nchw.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) 
out) -> Tensor(a!)" + "quantized_conv2d_nchw.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( "quantized_matmul(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False) -> (Tensor Z)" @@ -122,76 +249,100 @@ "quantized_matmul_asym8sxasym8s_asym8s.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( - "quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" + "quantized_conv2d_nchw_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" +) +lib.define( + "quantized_conv2d_nchw_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "quantized_conv2d_nchw_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" +) +lib.define( + "quantized_conv2d_nchw_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "quantized_conv2d_nhwc_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" ) lib.define( - "quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" + "quantized_conv2d_nhwc_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) 
out) -> Tensor(a!)" ) lib.define( - "quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" + "quantized_conv2d_nhwc_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" ) lib.define( - "quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" + "quantized_conv2d_nhwc_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( - "quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" + "quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" ) lib.define( - "quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" + "quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) 
out) -> Tensor(a!)" ) lib.define( - "quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" + "quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" ) lib.define( - "quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" + "quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( - "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" + "quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" ) lib.define( - "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" + "quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) 
out) -> Tensor(a!)" ) lib.define( - "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" + "quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" ) lib.define( - "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" + "quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( - "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" + "quantized_conv1d_ncl_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" ) lib.define( - "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" + "quantized_conv1d_ncl_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) 
out) -> Tensor(a!)" ) lib.define( - "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" + "quantized_conv1d_ncl_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" ) lib.define( - "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" + "quantized_conv1d_ncl_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( - "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" + "quantized_conv1d_nlc_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" ) lib.define( - "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" + "quantized_conv1d_nlc_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) 
out) -> Tensor(a!)" ) lib.define( - "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" + "quantized_conv1d_nlc_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" ) lib.define( - "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" + "quantized_conv1d_nlc_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( - "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" + "quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" ) lib.define( - "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" + "quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) 
out) -> Tensor(a!)" ) lib.define( - "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" + "quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" ) lib.define( - "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" + "quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" +) +lib.define( + "quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" +) +lib.define( + "quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( "quantized_matmul_asym8uxasym8u_asym8u(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False) -> (Tensor Z)" @@ -200,10 +351,6 @@ "quantized_matmul_asym8uxasym8u_asym8u.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False, *, Tensor(a!) 
out) -> Tensor(a!)" ) -lib.define( - "convolution(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, " - "int[] dilation, int groups, bool channel_last=False) -> (Tensor Y)" -) lib.define( "transposed_convolution(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, " "int[] dilation, SymInt[] output_padding, int groups, bool channel_last=False) -> (Tensor Y)" @@ -230,7 +377,7 @@ "float out_scale, int out_zero_point) -> (Tensor Z)" ) lib.define( - "quantized_embedding_byte(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, " + "quantized_embedding_byte(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, " "Tensor indices, bool pruned_weights=False) -> (Tensor X)" ) lib.define( @@ -239,7 +386,7 @@ "Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False) -> (Tensor out)" ) lib.define( - "avg_pool2d(Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, " + "avg_pool2d(Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=[], bool ceil_mode=False, " "bool count_include_pad=True, int? divisor_override=None, Tensor? in_zero_point=None, bool channel_last=False) -> (Tensor out)" ) lib.define( @@ -250,7 +397,6 @@ "im2row.per_tensor(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, " "int in_zero_point, bool channel_last=False) -> (Tensor out)" ) -lib.define("linalg_vector_norm(Tensor X) -> (Tensor Y)") lib.define( "linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)" ) @@ -300,6 +446,19 @@ "rope.out(Tensor input, Tensor sin_tensor, Tensor cos_tensor, Tensor? pos, *, Tensor(a!) out) -> Tensor(a!)" ) +lib.define( + "quantized_softmax(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point) -> (Tensor out)" +) +lib.define( + "quantized_softmax.per_tensor(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point) -> (Tensor out)" +) +lib.define( + "quantized_softmax.out(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, *, Tensor(a!) out) -> Tensor (a!)" +) +lib.define( + "quantized_softmax.per_tensor_out(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor (a!)" +) + # Load/store with iDMA. These only exist before memory planning. # Post memory planning, we check that outputs/inputs for the load/store are in # DTCM and replace idma_load/idma_store with idma_copy. @@ -326,8 +485,28 @@ # ------------------------------------ # # Migrated from the custom_ops.yaml files containing different operator variants (e.g., .out, .tensor_out) lib.define( - "convolution.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, " - "int groups, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)" + "conv1d(Tensor input, Tensor weight, Tensor bias, int[1] stride, SymInt[1] padding, int[1] dilation, " + "int groups) -> Tensor" +) +lib.define( + "conv1d.out(Tensor input, Tensor weight, Tensor bias, int[1] stride, SymInt[1] padding, int[1] dilation, " + "int groups, *, Tensor(a!) 
out) -> Tensor(a!)" +) +lib.define( + "conv2d(Tensor input, Tensor weight, Tensor bias, int[2] stride, SymInt[2] padding, int[2] dilation, " + "int groups) -> Tensor" +) +lib.define( + "conv2d.out(Tensor input, Tensor weight, Tensor bias, int[2] stride, SymInt[2] padding, int[2] dilation, " + "int groups, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "conv3d(Tensor input, Tensor weight, Tensor bias, int[3] stride, SymInt[3] padding, int[3] dilation, " + "int groups) -> Tensor" +) +lib.define( + "conv3d.out(Tensor input, Tensor weight, Tensor bias, int[3] stride, SymInt[3] padding, int[3] dilation, " + "int groups, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( "transposed_convolution.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, " @@ -393,7 +572,6 @@ lib.define( "fully_connected.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)" ) -lib.define("linalg_vector_norm.out(Tensor X, *, Tensor(a!) out) -> Tensor(a!)") lib.define( "quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, " "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" @@ -411,7 +589,7 @@ "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( - "quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, " + "quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, " "Tensor indices, bool pruned_weights=False, *, Tensor(a!) out) -> Tensor(a!)" ) @@ -422,7 +600,7 @@ "Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( - "avg_pool2d.out(Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, " + "avg_pool2d.out(Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=[], " "bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, " "Tensor? in_zero_point=None, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)" ) @@ -455,12 +633,35 @@ "int sampling_ratio, bool aligned) -> (Tensor out)" ) lib.define( - "_softmax_f32_f32(Tensor self, int dim, bool? half_to_float) -> (Tensor out)" + "_softmax_f32_f32(Tensor self, int dim, bool? half_to_float = None) -> (Tensor out)" +) +lib.define( + "_softmax_f32_f32.out(Tensor self, int dim, bool? half_to_float = None, *, Tensor(a!) out) -> Tensor(a!)" +) + +lib.define( + "quantized_w8a32_linear(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale) -> Tensor" +) +lib.define( + "quantized_w8a32_linear.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!)" +) + +lib.define( + "quantized_w8a32_conv(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale) -> Tensor" +) +lib.define( + "quantized_w8a32_conv.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!)" ) + lib.define( - "_softmax_f32_f32.out(Tensor self, int dim, bool? half_to_float, *, Tensor(a!) 
out) -> Tensor(a!)" + "quantized_w8a32_gru(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_i_scale, Tensor bias_hidden, float b_h_scale) -> Tensor" ) +lib.define( + "quantized_w8a32_gru.out(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_i_scale, Tensor bias_hidden, float b_h_scale, *, Tensor(a!) out) -> Tensor(a!)" +) + + # Custom ops with aten namespace. Need to specify the lib var as FRAGMENT type as aten library is already defined aten_lib = Library("aten", "FRAGMENT") aten_lib.define( @@ -517,6 +718,66 @@ def quantize_per_tensor_meta( return input.new_empty(input.size(), dtype=dtype) +@register_fake("cadence::quantize_per_tensor_asym8s") +def quantize_per_tensor_asym8s_meta( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return input.new_empty(input.size(), dtype=dtype) + + +@register_fake("cadence::quantize_per_tensor_asym8u") +def quantize_per_tensor_asym8u_meta( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return input.new_empty(input.size(), dtype=dtype) + + +@register_fake("cadence::quantize_per_tensor_asym16s") +def quantize_per_tensor_asym16s_meta( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return input.new_empty(input.size(), dtype=dtype) + + +@register_fake("cadence::quantize_per_tensor_asym16u") +def quantize_per_tensor_asym16u_meta( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return input.new_empty(input.size(), dtype=dtype) + + +@register_fake("cadence::quantize_per_tensor_asym32s") +def quantize_per_tensor_asym32s_meta( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return input.new_empty(input.size(), dtype=dtype) + + @register_fake("cadence::dequantize_per_tensor") def dequantize_per_tensor_meta( input: torch.Tensor, @@ -529,6 +790,66 @@ def dequantize_per_tensor_meta( return input.new_empty(input.size(), dtype=torch.float) +@register_fake("cadence::dequantize_per_tensor_asym8s") +def dequantize_per_tensor_asym8s_meta( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return input.new_empty(input.size(), dtype=torch.float) + + +@register_fake("cadence::dequantize_per_tensor_asym8u") +def dequantize_per_tensor_asym8u_meta( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return input.new_empty(input.size(), dtype=torch.float) + + +@register_fake("cadence::dequantize_per_tensor_asym16s") +def dequantize_per_tensor_asym16s_meta( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return input.new_empty(input.size(), dtype=torch.float) + + +@register_fake("cadence::dequantize_per_tensor_asym16u") +def dequantize_per_tensor_asym16u_meta( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return 
input.new_empty(input.size(), dtype=torch.float) + + +@register_fake("cadence::dequantize_per_tensor_asym32s") +def dequantize_per_tensor_asym32s_meta( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return input.new_empty(input.size(), dtype=torch.float) + + @register_fake("cadence::quantized_add") def quantized_add_meta( X: torch.Tensor, @@ -680,8 +1001,8 @@ def quantized_linear_asym8uxasym8u_asym8u_per_tensor_meta( return src.new_empty(out_size, dtype=src.dtype) -@register_fake("cadence::quantized_conv_nhwc") -def quantized_conv_nhwc_meta( +@register_fake("cadence::quantized_conv2d_nhwc") +def quantized_conv2d_nhwc_meta( input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, @@ -709,9 +1030,9 @@ def quantized_conv_nhwc_meta( get_conv1d_output_size( in_size, out_channels, - stride[1], - padding[1], - dilation[1], + stride[-1], + padding[-1], + dilation[-1], kernel_size[0], True, ) @@ -724,8 +1045,8 @@ def quantized_conv_nhwc_meta( return input.new_empty(output_size, dtype=input.dtype) -@register_fake("cadence::quantized_conv_nchw") -def quantized_conv_nchw_meta( +@register_fake("cadence::quantized_conv2d_nchw") +def quantized_conv2d_nchw_meta( input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, @@ -753,9 +1074,9 @@ def quantized_conv_nchw_meta( get_conv1d_output_size( in_size, out_channels, - stride[1], - padding[1], - dilation[1], + stride[-1], + padding[-1], + dilation[-1], kernel_size[0], False, ) @@ -768,8 +1089,8 @@ def quantized_conv_nchw_meta( return input.new_empty(output_size, dtype=input.dtype) -@register_fake("cadence::quantized_conv_nchw.per_tensor") -def quantized_conv_nchw_per_tensor_meta( +@register_fake("cadence::quantized_conv2d_nchw.per_tensor") +def quantized_conv2d_nchw_per_tensor_meta( input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, @@ -797,9 +1118,9 @@ def quantized_conv_nchw_per_tensor_meta( get_conv1d_output_size( in_size, out_channels, - stride[1], - padding[1], - dilation[1], + stride[-1], + padding[-1], + dilation[-1], kernel_size[0], False, ) @@ -812,8 +1133,8 @@ def quantized_conv_nchw_per_tensor_meta( return input.new_empty(output_size, dtype=input.dtype) -@register_fake("cadence::quantized_conv_nhwc.per_tensor") -def quantized_conv_nhwc_per_tensor_meta( +@register_fake("cadence::quantized_conv2d_nhwc.per_tensor") +def quantized_conv2d_nhwc_per_tensor_meta( input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, @@ -841,9 +1162,9 @@ def quantized_conv_nhwc_per_tensor_meta( get_conv1d_output_size( in_size, out_channels, - stride[1], - padding[1], - dilation[1], + stride[-1], + padding[-1], + dilation[-1], kernel_size[0], True, ) @@ -856,8 +1177,8 @@ def quantized_conv_nhwc_per_tensor_meta( return input.new_empty(output_size, dtype=input.dtype) -@register_fake("cadence::quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor") -def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_meta( +@register_fake("cadence::quantized_conv2d_nchw_asym8sxsym8s_asym8s.per_tensor") +def quantized_conv2d_nchw_asym8sxsym8s_asym8s_per_tensor_meta( input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, @@ -890,9 +1211,9 @@ def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_meta( get_conv1d_output_size( in_size, out_channels, - stride[1], - padding[1], - dilation[1], + stride[-1], + padding[-1], + dilation[-1], kernel_size[0], False, ) @@ -905,8 +1226,8 @@ def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_meta( return 
input.new_empty(output_size, dtype=input.dtype) -@register_fake("cadence::quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor") -def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_meta( +@register_fake("cadence::quantized_conv2d_nchw_asym8uxsym8u_asym8u.per_tensor") +def quantized_conv2d_nchw_asym8uxsym8u_asym8u_per_tensor_meta( input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, @@ -939,9 +1260,9 @@ def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_meta( get_conv1d_output_size( in_size, out_channels, - stride[1], - padding[1], - dilation[1], + stride[-1], + padding[-1], + dilation[-1], kernel_size[0], False, ) @@ -954,8 +1275,8 @@ def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_meta( return input.new_empty(output_size, dtype=input.dtype) -@register_fake("cadence::quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor") -def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_meta( +@register_fake("cadence::quantized_conv2d_nhwc_asym8sxsym8s_asym8s.per_tensor") +def quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor_meta( input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, @@ -988,9 +1309,9 @@ def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_meta( get_conv1d_output_size( in_size, out_channels, - stride[1], - padding[1], - dilation[1], + stride[-1], + padding[-1], + dilation[-1], kernel_size[0], True, ) @@ -1003,8 +1324,8 @@ def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_meta( return input.new_empty(output_size, dtype=input.dtype) -@register_fake("cadence::quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor") -def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_meta( +@register_fake("cadence::quantized_conv2d_nhwc_asym8uxsym8u_asym8u.per_tensor") +def quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor_meta( input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, @@ -1037,9 +1358,9 @@ def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_meta( get_conv1d_output_size( in_size, out_channels, - stride[1], - padding[1], - dilation[1], + stride[-1], + padding[-1], + dilation[-1], kernel_size[0], True, ) @@ -1052,8 +1373,8 @@ def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_meta( return input.new_empty(output_size, dtype=input.dtype) -@register_fake("cadence::quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor") -def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_meta( +@register_fake("cadence::quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s.per_tensor") +def quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_meta( input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, @@ -1086,9 +1407,9 @@ def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_meta( get_conv1d_output_size( in_size, out_channels, - stride[1], - padding[1], - dilation[1], + stride[-1], + padding[-1], + dilation[-1], kernel_size[0], False, ) @@ -1101,8 +1422,8 @@ def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_meta( return input.new_empty(output_size, dtype=input.dtype) -@register_fake("cadence::quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor") -def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_meta( +@register_fake("cadence::quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u.per_tensor") +def quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_meta( input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, @@ -1135,9 +1456,9 @@ def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_meta( get_conv1d_output_size( in_size, out_channels, - stride[1], - padding[1], - 
dilation[1], + stride[-1], + padding[-1], + dilation[-1], kernel_size[0], False, ) @@ -1150,8 +1471,8 @@ def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_meta( return input.new_empty(output_size, dtype=input.dtype) -@register_fake("cadence::quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor") -def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_meta( +@register_fake("cadence::quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor") +def quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_meta( input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, @@ -1184,9 +1505,9 @@ def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_meta( get_conv1d_output_size( in_size, out_channels, - stride[1], - padding[1], - dilation[1], + stride[-1], + padding[-1], + dilation[-1], kernel_size[0], True, ) @@ -1199,8 +1520,8 @@ def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_meta( return input.new_empty(output_size, dtype=input.dtype) -@register_fake("cadence::quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor") -def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_meta( +@register_fake("cadence::quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor") +def quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_meta( input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, @@ -1233,9 +1554,9 @@ def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_meta( get_conv1d_output_size( in_size, out_channels, - stride[1], - padding[1], - dilation[1], + stride[-1], + padding[-1], + dilation[-1], kernel_size[0], True, ) @@ -1248,8 +1569,10 @@ def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_meta( return input.new_empty(output_size, dtype=input.dtype) -@register_fake("cadence::quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor") -def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_meta( +@register_fake( + "cadence::quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor" +) +def quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_meta( input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, @@ -1282,9 +1605,9 @@ def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_meta( get_conv1d_output_size( in_size, out_channels, - stride[1], - padding[1], - dilation[1], + stride[-1], + padding[-1], + dilation[-1], kernel_size[0], False, ) @@ -1297,8 +1620,10 @@ def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_meta( return input.new_empty(output_size, dtype=input.dtype) -@register_fake("cadence::quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor") -def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_meta( +@register_fake( + "cadence::quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor" +) +def quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_meta( input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, @@ -1331,9 +1656,9 @@ def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_meta( get_conv1d_output_size( in_size, out_channels, - stride[1], - padding[1], - dilation[1], + stride[-1], + padding[-1], + dilation[-1], kernel_size[0], False, ) @@ -1346,8 +1671,10 @@ def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_meta( return input.new_empty(output_size, dtype=input.dtype) -@register_fake("cadence::quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor") -def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_meta( +@register_fake( + 
"cadence::quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor" +) +def quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_meta( input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, @@ -1380,9 +1707,9 @@ def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_meta( get_conv1d_output_size( in_size, out_channels, - stride[1], - padding[1], - dilation[1], + stride[-1], + padding[-1], + dilation[-1], kernel_size[0], True, ) @@ -1395,8 +1722,10 @@ def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_meta( return input.new_empty(output_size, dtype=input.dtype) -@register_fake("cadence::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor") -def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_meta( +@register_fake( + "cadence::quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor" +) +def quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_meta( input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, @@ -1429,9 +1758,9 @@ def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_meta( get_conv1d_output_size( in_size, out_channels, - stride[1], - padding[1], - dilation[1], + stride[-1], + padding[-1], + dilation[-1], kernel_size[0], True, ) @@ -1646,15 +1975,6 @@ def im2row_per_tensor_meta( return input.new_empty(output_size, dtype=input.dtype) -# Define the abstract implementations of the operators as required -@register_fake("cadence::linalg_vector_norm") -def linalg_vector_norm_meta( - X: torch.Tensor, -) -> torch.Tensor: - # Output of norm is a scalar, so we return a [] tensor - return X.new_empty([], dtype=X.dtype) - - @register_fake("cadence::linalg_svd") def linalg_svd_meta( A: torch.Tensor, @@ -1848,8 +2168,8 @@ def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_meta( return src.new_empty(out_size, dtype=src.dtype) -@register_fake("cadence::convolution") -def convolution_meta( +@register_fake("cadence::conv1d") +def conv1d_meta( input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, @@ -1857,32 +2177,109 @@ def convolution_meta( padding: Tuple[int], dilation: Tuple[int], groups: int, - channel_last: bool = False, ) -> torch.Tensor: - if channel_last: - out_channels, *kernel_size, _ = weight.shape - else: - out_channels, _, *kernel_size = weight.shape + # Validate tensor dimensions + assert len(input.shape) == 3, f"Conv1d expects 3D input, got {len(input.shape)}D" + assert len(weight.shape) == 3, f"Conv1d expects 3D weight, got {len(weight.shape)}D" + + # Extract dimensions + batch_size, in_channels, length = input.shape + out_channels, weight_in_channels, kernel_size = weight.shape + + # Validate groups parameter and channel consistency + assert groups > 0, f"groups must be positive, got {groups}" + assert ( + in_channels % groups == 0 + ), f"in_channels ({in_channels}) must be divisible by groups ({groups})" + assert ( + out_channels % groups == 0 + ), f"out_channels ({out_channels}) must be divisible by groups ({groups})" + + # Validate weight channels match input channels divided by groups + expected_weight_in_channels = in_channels // groups + assert ( + weight_in_channels == expected_weight_in_channels + ), f"Expected weight to have {expected_weight_in_channels} input channels (in_channels/groups), but got {weight_in_channels}" + + output_size = get_conv1d_output_size( + input.shape, + out_channels, + stride[0], + padding[0], + dilation[0], + kernel_size, + False, + ) + + return input.new_empty(output_size, dtype=input.dtype) + + 
+@register_fake("cadence::conv2d") +def conv2d_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: Tuple[int], + padding: Tuple[int], + dilation: Tuple[int], + groups: int, +) -> torch.Tensor: + assert ( + len(weight.shape) == 4 + ), f"Conv2d expects a 4D weight, got {len(weight.shape)}D" + out_channels, _, *kernel_size = weight.shape in_size = input.shape - # Assert that the input tensor has at least 3 dimensions, and at most 6 - assert len(in_size) > 2 - assert len(in_size) < 6 + assert len(in_size) == 4, f"conv2d expects 4D input, got {len(in_size)}D" - # Compute the output tensor size - output_size = ( - get_conv1d_output_size( - in_size, - out_channels, - stride[0], - padding[0], - dilation[0], - kernel_size[0], - channel_last, - ) - if len(in_size) == 3 - else get_conv2d_output_size( - in_size, out_channels, stride, padding, dilation, kernel_size, channel_last - ) + output_size = get_conv2d_output_size( + in_size, out_channels, stride, padding, dilation, kernel_size, False + ) + + return input.new_empty(output_size, dtype=input.dtype) + + +@register_fake("cadence::conv3d") +def conv3d_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: Tuple[int, int, int], + padding: Tuple[int, int, int], + dilation: Tuple[int, int, int], + groups: int, +) -> torch.Tensor: + assert ( + len(weight.shape) == 5 + ), f"Conv3d expects a 5D weight, got {len(weight.shape)}D" + out_channels, _, *kernel_size = weight.shape + in_size = input.shape + assert len(in_size) == 5, f"conv3d expects 5D input, got {len(in_size)}D" + + # Helper to compute 3D convolution output size + def get_conv3d_output_size( + in_size: torch.Size, + out_channels: int, + stride: Tuple[int, int, int], + padding: Tuple[int, int, int], + dilation: Tuple[int, int, int], + kernel_size: list[int], + ) -> torch.Size: + N, C, D, H, W = in_size + + dout = (D + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[ + 0 + ] + 1 + hout = (H + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[ + 1 + ] + 1 + wout = (W + 2 * padding[2] - dilation[2] * (kernel_size[2] - 1) - 1) // stride[ + 2 + ] + 1 + + return torch.Size((N, out_channels, dout, hout, wout)) + + output_size = get_conv3d_output_size( + in_size, out_channels, stride, padding, dilation, kernel_size ) return input.new_empty(output_size, dtype=input.dtype) @@ -2013,10 +2410,10 @@ def avg_pool2d_meta( kernel_size: Tuple[int], stride: Tuple[int], padding: Tuple[int], - ceil_mode: bool, - count_include_pad: Optional[bool] = True, + ceil_mode: bool = False, + count_include_pad: bool = True, divisor_override: Optional[int] = None, - in_zero_point: Optional[int] = None, + in_zero_point: Optional[torch.Tensor] = None, channel_last: bool = False, ) -> torch.Tensor: # Use torch native meta kernels when operator semantics are similar @@ -2042,14 +2439,21 @@ def transposed_im2row_meta( in_zero_point: torch.Tensor, channel_last: bool = False, ) -> torch.Tensor: + """ + Shape inference for transposed_im2row operation. 
+ + Returns shape: (N, H_out * W_out, K_h * K_w * C_in) + """ if len(input.shape) == 3: height_dim = 1 if channel_last else 2 input = input.unsqueeze(height_dim) batch_size = input.shape[0] - n_input_plane = input.shape[3] if channel_last else input.shape[1] + n_input_channels = input.shape[3] if channel_last else input.shape[1] input_height = input.shape[1] if channel_last else input.shape[2] input_width = input.shape[2] if channel_last else input.shape[3] + + # Calculate output spatial dimensions output_height = ( (input_height - 1) * stride[0] - 2 * padding[0] @@ -2064,13 +2468,37 @@ def transposed_im2row_meta( + output_padding[1] + 1 ) - n_output_plane = n_input_plane * kernel_size[0] * kernel_size[1] - output_length = output_height * output_width - output_size = torch.Size((batch_size, output_length, n_output_plane)) + + # Patch size is kernel_h * kernel_w * in_channels + patch_size = kernel_size[0] * kernel_size[1] * n_input_channels + num_patches = output_height * output_width + output_size = torch.Size((batch_size, num_patches, patch_size)) return input.new_empty(output_size, dtype=input.dtype) +@register_fake("cadence::quantized_embedding_byte") +def quantized_embedding_byte_meta( + weight: torch.Tensor, + weight_scales: torch.Tensor, + weight_zero_points: torch.Tensor | None, + indices: torch.Tensor, + pruned_weights: bool = False, +) -> torch.Tensor: + assert not pruned_weights + assert len(weight.shape) == 2 + assert 1 <= len(weight_scales.shape) <= 2 + if len(weight_scales.shape) == 2: + num_groups = weight_scales.shape[-1] + assert weight.shape[1] % num_groups == 0 + + if weight_zero_points is not None: + assert weight_zero_points.shape == weight_scales.shape + + assert 1 <= len(indices.shape) <= 2 + return torch.empty(*indices.shape, weight.shape[1], dtype=torch.float32) + + @register_fake("cadence::where_Scalar") def where_Scalar_meta( condition: torch.Tensor, @@ -2130,7 +2558,9 @@ def idma_load_impl( task_num: int = 0, channel: int = 0, ) -> torch.Tensor: - return copy_idma_copy_impl(src, task_num, channel) + res = copy_idma_copy_impl(src, task_num, channel) + assert isinstance(res, torch.Tensor) + return res @register_fake("cadence::idma_store") @@ -2139,7 +2569,9 @@ def idma_store_impl( task_num: int = 0, channel: int = 0, ) -> torch.Tensor: - return copy_idma_copy_impl(src, task_num, channel) + res = copy_idma_copy_impl(src, task_num, channel) + assert isinstance(res, torch.Tensor) + return res @register_fake("cadence::roi_align_box_processor") @@ -2153,11 +2585,257 @@ def roi_align_box_processor_meta( return rois.new_empty((rois.shape[0], 80), dtype=torch.uint8) +@register_fake("cadence::quantized_conv1d_ncl_asym8sxsym8s_asym8s.per_tensor") +def quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: Tuple[int], + padding: Tuple[int], + dilation: Tuple[int], + groups: int, + in_zero_point: int, + weight_zero_point: int, + bias_scale: float, + output_scale: float, + output_zero_point: int, + out_multiplier: int, + out_shift: int, +) -> torch.Tensor: + assert input.dim() == 3 and weight.dim() == 3 + assert ( + input.dtype == torch.int8 + and weight.dtype == torch.int8 + and bias.dtype == torch.int32 + ) + out_channels, _, kernel_size = weight.shape + output_size = get_conv1d_output_size( + input.shape, + out_channels, + stride[1], + padding[1], + dilation[1], + kernel_size, + False, + ) + return input.new_empty(output_size, dtype=input.dtype) + + 
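# --- Editorial aside, not part of the patch: transposed_im2row_meta above computes its
# spatial extents with the standard ConvTranspose2d output-size formula. The sketch
# below restates that formula and checks it against torch's conv_transpose2d;
# `tconv_out_dim` is a name invented for this illustration only.
import torch
import torch.nn.functional as F


def tconv_out_dim(size: int, kernel: int, stride: int, padding: int, dilation: int, output_padding: int) -> int:
    # (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


x = torch.randn(1, 3, 14, 10)  # (N, C_in, H, W)
w = torch.randn(3, 5, 3, 3)    # (C_in, C_out, kH, kW) layout expected by conv_transpose2d
y = F.conv_transpose2d(x, w, stride=2, padding=1, dilation=2, output_padding=1)
assert y.shape[-2:] == (tconv_out_dim(14, 3, 2, 1, 2, 1), tconv_out_dim(10, 3, 2, 1, 2, 1))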
+@register_fake("cadence::quantized_conv1d_ncl_asym8uxsym8u_asym8u.per_tensor") +def quantized_conv1d_ncl_asym8uxsym8u_asym8u_per_tensor_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: Tuple[int], + padding: Tuple[int], + dilation: Tuple[int], + groups: int, + in_zero_point: int, + weight_zero_point: int, + bias_scale: float, + output_scale: float, + output_zero_point: int, + out_multiplier: int, + out_shift: int, +) -> torch.Tensor: + assert input.dim() == 3 and weight.dim() == 3 + assert ( + input.dtype == torch.uint8 + and weight.dtype == torch.uint8 + and bias.dtype == torch.int32 + ) + out_channels, _, kernel_size = weight.shape + output_size = get_conv1d_output_size( + input.shape, + out_channels, + stride[1], + padding[1], + dilation[1], + kernel_size, + False, + ) + return input.new_empty(output_size, dtype=input.dtype) + + +@register_fake("cadence::quantized_conv1d_nlc_asym8sxsym8s_asym8s.per_tensor") +def quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: Tuple[int], + padding: Tuple[int], + dilation: Tuple[int], + groups: int, + in_zero_point: int, + weight_zero_point: int, + bias_scale: float, + output_scale: float, + output_zero_point: int, + out_multiplier: int, + out_shift: int, +) -> torch.Tensor: + assert input.dim() == 3 and weight.dim() == 3 + assert ( + input.dtype == torch.int8 + and weight.dtype == torch.int8 + and bias.dtype == torch.int32 + ) + out_channels, kernel_size, _ = weight.shape + output_size = get_conv1d_output_size( + input.shape, + out_channels, + stride[1], + padding[1], + dilation[1], + kernel_size, + True, + ) + return input.new_empty(output_size, dtype=input.dtype) + + +@register_fake("cadence::quantized_conv1d_nlc_asym8uxsym8u_asym8u.per_tensor") +def quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: Tuple[int], + padding: Tuple[int], + dilation: Tuple[int], + groups: int, + in_zero_point: int, + weight_zero_point: int, + bias_scale: float, + output_scale: float, + output_zero_point: int, + out_multiplier: int, + out_shift: int, +) -> torch.Tensor: + assert input.dim() == 3 and weight.dim() == 3 + assert ( + input.dtype == torch.uint8 + and weight.dtype == torch.uint8 + and bias.dtype == torch.int32 + ) + out_channels, kernel_size, _ = weight.shape + output_size = get_conv1d_output_size( + input.shape, + out_channels, + stride[1], + padding[1], + dilation[1], + kernel_size, + True, + ) + return input.new_empty(output_size, dtype=input.dtype) + + @register_fake("cadence::_softmax_f32_f32") def softmax_f32_f32_meta( - self: torch.Tensor, + input_tensor: torch.Tensor, dim: int, - dtype: torch.dtype, half_to_float: Optional[bool] = None, ) -> torch.Tensor: - return self.new_empty(self.size(), dtype=self.dtype) + assert input_tensor.dtype == torch.float32, "input_tensor must be float32" + assert not half_to_float, "half_to_float is not supported" + return input_tensor.new_empty(input_tensor.size(), dtype=torch.float32) + + +@register_fake("cadence::quantized_softmax") +def quantized_softmax_meta( + input: torch.Tensor, + mask: torch.Tensor, + dim: int, + in_scale: torch.Tensor, + in_zero_point: torch.Tensor, + out_scale: torch.Tensor, + out_zero_point: torch.Tensor, +) -> torch.Tensor: + return input.new_empty(input.size(), dtype=input.dtype) + + +@register_fake("cadence::quantized_softmax.per_tensor") +def quantized_softmax_per_tensor_meta( + 
input: torch.Tensor, + mask: torch.Tensor, + dim: int, + in_scale: float, + in_zero_point: int, + out_scale: float, + out_zero_point: int, +) -> torch.Tensor: + return input.new_empty(input.size(), dtype=input.dtype) + + +@register_fake("cadence::quantized_w8a32_linear") +def quantized_w8a32_linear_meta( + src: torch.Tensor, + weight: torch.Tensor, + w_scale: float, + bias: torch.Tensor, + b_scale: float, +) -> torch.Tensor: + # src comes in shape [leading_dims, in_dim] + # weight comes in shape [in_dim, out_dim] + # output comes in empty with shape [leading_dims, out_dim] + src_shape = list(src.shape) + weight_shape = weight.shape + assert (src_shape[-1] % 4) == 0 + if len(src_shape) >= 2: + assert src_shape[-2] == 1 + assert len(weight_shape) == 2 + assert src_shape[-1] == weight_shape[-1] + src_shape[-1] = weight_shape[0] + return src.new_empty(src_shape, dtype=src.dtype) + + +@register_fake("cadence::quantized_w8a32_conv") +def quantized_w8a32_conv_meta( + src: torch.Tensor, + weight: torch.Tensor, + w_scale: float, + bias: torch.Tensor, + b_scale: float, +) -> torch.Tensor: + # src comes in shape [batch, in_length, in_channels] + # weight comes in shape [kernel_dim, out_ch, in_ch] + # output comes in empty with shape [batch, out_ch, in_length - kernel_dim + 1] + assert len(src.shape) == 3 + + kernel_size, out_channels, in_channels = weight.shape + assert kernel_size == 3 + assert (out_channels % 4) == 0 + assert (in_channels % 4) == 0 + assert in_channels == src.shape[-1] + + # Compute the output tensor size + output_size = get_conv1d_output_size( + src.permute(0, 2, 1).shape, + out_channels, + stride=1, + padding=0, + dilation=1, + kernel_size=kernel_size, + channel_last=False, + ) + return src.new_empty(output_size, dtype=src.dtype) + + +@register_fake("cadence::quantized_w8a32_gru") +def quantized_w8a32_gru_meta( + inputs: torch.Tensor, + hidden: torch.Tensor, + weights_inputs: torch.Tensor, + w_i_scale: float, + weights_hidden: torch.Tensor, + w_h_scale: float, + bias_inputs: torch.Tensor, + b_i_scale: float, + bias_hidden: torch.Tensor, + b_h_scale: float, +) -> torch.Tensor: + return hidden.new_empty((2, hidden.shape[-1]), dtype=torch.float32) + + +# Validate that all meta kernels have reference implementations +# This is called at module import time to catch missing implementations early +_validate_ref_impl_exists() diff --git a/backends/cadence/aot/pass_utils.py b/backends/cadence/aot/pass_utils.py index 9aedef2ce2f..96c30bcdf59 100644 --- a/backends/cadence/aot/pass_utils.py +++ b/backends/cadence/aot/pass_utils.py @@ -6,6 +6,7 @@ # pyre-strict +from abc import abstractmethod from dataclasses import dataclass from typing import Callable, List, Optional, Set, Type, Union @@ -13,9 +14,10 @@ from executorch.backends.cadence.aot.utils import get_edge_overload_packet from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket -from executorch.exir.pass_base import PassBase, PassResult +from executorch.exir.pass_base import ExportPass, PassBase, PassResult from torch._ops import OpOverloadPacket +from torch.fx import Node # Is an overlap in tensor lifetime and storage allowed at the current opt level? @@ -229,3 +231,44 @@ def set_arg( def none_throws(x: Optional[PassResult]) -> PassResult: assert x is not None return x + + +class RemoveOrReplacePassInterface(ExportPass): + @property + @abstractmethod + def targets(self) -> list[EdgeOpOverload]: + """ + The list of targets to potentially remove or replace. 
+ """ + raise NotImplementedError("`targets` must be implemented") + + @abstractmethod + def maybe_remove_or_replace(self, node: Node) -> bool: + """ + If the node should be removed/replaced, removes/replaces from the graph. Returns + True if the graph was modified, else False. + """ + raise NotImplementedError("`maybe_remove_or_replace` must be implemented") + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + """ + For each node in targets, if the node should be removed/replaced, + removes/replaces from the graph and returns the modified graph and modified + set to True. + If no node should be removed/replaced, returns a pass result with the original + graph module and False for modified. + """ + changed = False + for target in self.targets: + for module in filter( + lambda m: isinstance(m, torch.fx.GraphModule), graph_module.modules() + ): + for node in module.graph.find_nodes(op="call_function", target=target): + changed |= self.maybe_remove_or_replace(node) + + if changed: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return super().call(graph_module) + + return PassResult(graph_module, False) diff --git a/backends/cadence/aot/program_builder.py b/backends/cadence/aot/program_builder.py index 862ba4e977c..46d730b68ff 100644 --- a/backends/cadence/aot/program_builder.py +++ b/backends/cadence/aot/program_builder.py @@ -12,6 +12,7 @@ from torch import Tensor from torch._export.verifier import Verifier from torch._ops import OpOverload +from torch._subclasses.fake_tensor import FakeTensorMode from torch.export import ExportedProgram from torch.export.exported_program import ModuleCallEntry, ModuleCallSignature from torch.export.graph_signature import ( @@ -37,6 +38,7 @@ def __init__( self, mode: Optional[IrMode] = None, _core_aten_ops_exception_list: Optional[list[OpOverload]] = None, + fake_tensor_mode: Optional[FakeTensorMode] = None, ) -> None: self.input_specs: list[InputSpec] = [] self.output_specs: list[OutputSpec] = [] @@ -46,7 +48,7 @@ def __init__( self._core_aten_ops_exception_list: list[OpOverload] = ( _core_aten_ops_exception_list or [] ) - super().__init__() + super().__init__(fake_tensor_mode=fake_tensor_mode) def insert_input_spec( self, target: str, input_kind: InputKind, value: Tensor diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py index 729056ea2c8..7093ef19c3d 100644 --- a/backends/cadence/aot/quantizer/fusion_pass.py +++ b/backends/cadence/aot/quantizer/fusion_pass.py @@ -6,23 +6,34 @@ # pyre-strict -from typing import Any, Dict, List, Tuple +from typing import Any, cast, Dict, List, Tuple import torch +from executorch.backends.cadence.aot.compiler_utils import get_shape from executorch.backends.cadence.aot.quantizer.patterns import ( AddmmPattern, AddPattern, BmmPattern, CatPattern, Conv1dPattern, + Conv1dReluPattern0, + Conv1dReluPattern1, Conv2dPattern, + Conv2dReluPattern0, + Conv2dReluPattern1, LayerNormPattern, LinearPattern, MatmulPattern, + MixedW8A32ConvPattern, + MixedW8A32GruPattern, + MixedW8A32LinearPattern, ReluPattern0, ReluPattern1, + SoftmaxPattern, ) from executorch.backends.cadence.aot.quantizer.utils import ( + check_out_zero_point_is_min_range, + copy_node_metadata, create_zero_bias_int32, find_sequential_partitions_aten, get_conv_args, @@ -41,6 +52,13 @@ # Use this part for patterns with multiple aten ops ReluPatterns = (ReluPattern0, ReluPattern1) +ConvPatterns = (Conv1dPattern, Conv2dPattern) +ConvReluPatterns = ( + Conv1dReluPattern0, + 
Conv1dReluPattern1, + Conv2dReluPattern0, + Conv2dReluPattern1, +) def get_args_and_kwargs_add( @@ -49,33 +67,18 @@ def get_args_and_kwargs_add( dequants_inputs: List[fx.Node], quant_node: fx.Node, ) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]: - X_scale_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], dequants_inputs[0].args[1]), - {"dtype": torch.float}, - ) - X_zero_point_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], dequants_inputs[0].args[2]), - {"dtype": torch.int32}, - ) - Y_scale_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], dequants_inputs[1].args[1]), - {"dtype": torch.float}, - ) - Y_zero_point_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], dequants_inputs[1].args[2]), - {"dtype": torch.int32}, - ) + X_scale = dequants_inputs[0].args[1] + + X_zero_point = dequants_inputs[0].args[2] + Y_scale = dequants_inputs[1].args[1] + Y_zero_point = dequants_inputs[1].args[2] args = ( inputs_inputs[0], - X_scale_, - X_zero_point_, + X_scale, + X_zero_point, inputs_inputs[1], - Y_scale_, - Y_zero_point_, + Y_scale, + Y_zero_point, quant_node.args[1], quant_node.args[2], ) @@ -113,31 +116,12 @@ def get_args_and_kwargs_linear( else: bias = bias_inputs[0] - # Create single element tensors for weight_zero_point, out_multiplier, out_shift. - # Note that the function expects int32_t, when it would default to int64_t, so - # we explicitly require that type. - weight_zero_point_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], dequants_weights[0].args[2]), - {"dtype": torch.int32}, - ) - out_multiplier_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], out_multiplier[0].item()), - {"dtype": torch.int32}, - ) - out_shift_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], out_shift[0].item()), - {"dtype": torch.int32}, - ) - args = tuple(inputs_inputs + weights_inputs + [bias]) kwargs = { "src_zero_point": dequants_inputs[0].args[2], - "weight_zero_point": weight_zero_point_, - "out_multiplier": out_multiplier_, - "out_shift": out_shift_, + "weight_zero_point": dequants_weights[0].args[2], + "out_multiplier": out_multiplier[0].item(), + "out_shift": out_shift[0].item(), "out_zero_point": quant_node.args[2], "offset": None, } @@ -162,22 +146,8 @@ def get_args_and_kwargs_layer_norm( ), "per-channel quantization is not supported for layer norm, both scale and zero_point should be scalars" # Make the scale and zero_point tensors - scale_tensor = graph_module.graph.call_function( - torch.ops.aten.full.default, - ( - [1], - dequants_inputs[0].args[1], - ), - {"dtype": torch.float32}, - ) - zero_point_tensor = graph_module.graph.call_function( - torch.ops.aten.full.default, - ( - [1], - dequants_inputs[0].args[2], - ), - {"dtype": torch.int32}, - ) + scale = dequants_inputs[0].args[1] + zero_point = dequants_inputs[0].args[2] weight = other_inputs[1] if len(other_inputs) > 1 else None @@ -190,6 +160,20 @@ def get_args_and_kwargs_layer_norm( ), {"dtype": torch.float32}, ) + if len(inputs_inputs) > 0: + if "val" in inputs_inputs[0].meta: + fake_mode = inputs_inputs[0].meta["val"].fake_mode + if fake_mode is not None: + with fake_mode: + fake_weight = torch.full( + other_inputs[0], 1, dtype=torch.float32 + ) + weight.meta["val"] = fake_weight + else: + weight.meta["val"] = torch.full( + other_inputs[0], 1, dtype=torch.float32 + ) + copy_node_metadata(weight, inputs_inputs[0]) bias = other_inputs[2] if 
len(other_inputs) > 2 else None @@ -202,9 +186,21 @@ def get_args_and_kwargs_layer_norm( ), {"dtype": torch.float32}, ) + if len(inputs_inputs) > 0: + if "val" in inputs_inputs[0].meta: + fake_mode = inputs_inputs[0].meta["val"].fake_mode + if fake_mode is not None: + with fake_mode: + fake_bias = torch.full(other_inputs[0], 0, dtype=torch.float32) + bias.meta["val"] = fake_bias + else: + bias.meta["val"] = torch.full( + other_inputs[0], 0, dtype=torch.float32 + ) + copy_node_metadata(bias, inputs_inputs[0]) # Make the args and kwargs for the replacement op - args = tuple(inputs_inputs + [scale_tensor] + [zero_point_tensor]) + args = tuple(inputs_inputs + [scale, zero_point]) kwargs = { "normalized_shape": other_inputs[0], "weight": weight, @@ -292,31 +288,6 @@ def get_args_and_kwargs_conv( (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t) - out_multiplier_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], out_multiplier[0].item()), - {"dtype": torch.int32}, - ) - out_shift_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], out_shift[0].item()), - {"dtype": torch.int32}, - ) - - # Create a single element tensor for the weight zero point - weight_zero_point_tensor = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], weight_zero_point), - {"dtype": torch.int32}, - ) - - # Create a single element tensor for the bias scale - bias_scale_tensor = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], bias_scale), - {"dtype": torch.float32}, - ) - # Make the args and kwargs for the replacement op args = tuple(inputs_inputs + weights_inputs + [bias]) kwargs = { @@ -325,12 +296,12 @@ def get_args_and_kwargs_conv( "dilation": dilation, "groups": groups, "input_zero_point": dequants_inputs[0].args[2], - "weight_zero_point": weight_zero_point_tensor, - "bias_scale": bias_scale_tensor, + "weight_zero_point": weight_zero_point, + "bias_scale": bias_scale, "out_scale": quant_node.args[1], "out_zero_point": quant_node.args[2], - "out_multiplier": out_multiplier_, - "out_shift": out_shift_, + "out_multiplier": out_multiplier[0].item(), + "out_shift": out_shift[0].item(), } return args, kwargs @@ -351,28 +322,168 @@ def get_args_and_kwargs_relu( # Make the args and kwargs for the replacement op args = tuple(inputs_inputs) - X_zero_point = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], dequants_inputs[0].args[2]), - {"dtype": torch.int32}, + kwargs = { + "X_zero_point": dequants_inputs[0].args[2], + "out_zero_point": quant_node.args[2], + "out_multiplier": out_multiplier[0].item(), + "out_shift": out_shift[0].item(), + } + return args, kwargs + + +def get_args_and_kwargs_mixed_w8a32_linear( + graph_module: GraphModule, + other_inputs: List[fx.Node], + weights_inputs: List[fx.Node], + dequants_weights: List[fx.Node], + bias_inputs: List[fx.Node], + dequants_biases: List[fx.Node], +) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]: + w_scale_ = dequants_weights[0].args[1] + b_scale_ = dequants_biases[0].args[1] + + args = ( + other_inputs[0], + weights_inputs[0], + w_scale_, + bias_inputs[0], + b_scale_, ) - out_multiplier_ = graph_module.graph.call_function( + kwargs = {} + + return args, kwargs + + +def get_args_and_kwargs_softmax( + graph_module: GraphModule, + inputs_inputs: List[fx.Node], + dequants_inputs: List[fx.Node], + quant_node: fx.Node, + op_node: fx.Node, +) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]: + # Make a dummy mask tensor + mask_shape 
= get_shape(graph_module, cast(fx.Node, quant_node.args[0])) + mask_shape = list(mask_shape) if mask_shape else [] + mask_shape[-1] = mask_shape[-1] // 16 + mask_tensor = graph_module.graph.call_function( torch.ops.aten.full.default, - ([1], out_multiplier[0].item()), + ( + mask_shape, + 0.0, + ), {"dtype": torch.int32}, ) - out_shift_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], out_shift[0].item()), - {"dtype": torch.int32}, + if len(inputs_inputs) > 0: + if "val" in inputs_inputs[0].meta: + fake_mode = inputs_inputs[0].meta["val"].fake_mode + if fake_mode is not None: + with fake_mode: + fake_mask = torch.full(mask_shape, 0.0, dtype=torch.int32) + mask_tensor.meta["val"] = fake_mask + else: + mask_tensor.meta["val"] = torch.full(mask_shape, 0.0, dtype=torch.int32) + copy_node_metadata(mask_tensor, inputs_inputs[0]) + # Make the scale and zero_point tensors + in_scale = dequants_inputs[0].args[1] + in_zero_point = dequants_inputs[0].args[2] + out_scale = quant_node.args[1] + out_zero_point = quant_node.args[2] + + # Make the args and kwargs for the replacement op + args = ( + inputs_inputs[0], + mask_tensor, + op_node.args[1], + in_scale, + in_zero_point, + out_scale, + out_zero_point, ) + kwargs = {} + + return args, kwargs + + +def get_args_and_kwargs_mixed_w8a32_conv( + graph_module: GraphModule, + other_inputs: List[fx.Node], + weights_inputs: List[fx.Node], + dequants_weights: List[fx.Node], + bias_inputs: List[fx.Node], + dequants_biases: List[fx.Node], + op_node: fx.Node, +) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]: + # Stride, padding, dilation, groups not supported yet + if len(op_node.args) > 3: + assert op_node.args[3] == [1] # Stride + if len(op_node.args) > 4: + assert op_node.args[4] == [0] # Padding + if len(op_node.args) > 5: + assert op_node.args[5] == [1] # Dilation + if len(op_node.args) > 6: + assert op_node.args[6] == 1 # Groups + + assert len(dequants_weights) == 1 + assert len(dequants_biases) == 1 + W_scale_ = dequants_weights[0].args[1] + B_scale_ = dequants_biases[0].args[1] + + transposed_inputs = graph_module.graph.call_function( + torch.ops.aten.permute.default, + (other_inputs[0], [0, 2, 1]), # NCL -> NLC + ) + copy_node_metadata(transposed_inputs, other_inputs[0]) + + transposed_weights = graph_module.graph.call_function( + torch.ops.aten.permute.default, + (weights_inputs[0], [2, 0, 1]), # NCL -> LNC + ) + copy_node_metadata(transposed_weights, weights_inputs[0]) + + args = ( + transposed_inputs, + transposed_weights, + W_scale_, + bias_inputs[0], + B_scale_, + ) + kwargs = {} + + return args, kwargs + + +def get_args_and_kwargs_mixed_w8a32_gru( + graph_module: GraphModule, + other_inputs: List[fx.Node], + weights_inputs: List[fx.Node], + dequants_weights: List[fx.Node], + bias_inputs: List[fx.Node], + dequants_biases: List[fx.Node], + op_node: fx.Node, +) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]: + # Stride, padding, dilation, groups not supported yet + + assert len(dequants_weights) == 2 + assert len(dequants_biases) == 2 + w_i_scale = dequants_weights[0].args[1] + w_h_scale = dequants_weights[1].args[1] + b_i_scale = dequants_biases[0].args[1] + b_h_scale = dequants_biases[1].args[1] + + args = ( + other_inputs[0], + other_inputs[1], + weights_inputs[0], + w_i_scale, + weights_inputs[1], + w_h_scale, + bias_inputs[0], + b_i_scale, + bias_inputs[1], + b_h_scale, + ) + kwargs = {} - kwargs = { - "X_zero_point": X_zero_point, - "out_zero_point": quant_node.args[2], - "out_multiplier": out_multiplier_, - 
"out_shift": out_shift_, - } return args, kwargs @@ -390,7 +501,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901 pattern.partition_types(), ) for fused_partition in fused_partitions: - anchors = pattern.get_anchors(graph_module, fused_partition) + anchors, op_node = pattern.get_anchors(graph_module, fused_partition) if not anchors or anchors.empty: continue if any(self.is_fused(p.nodes) for p in fused_partition): @@ -431,10 +542,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901 bias_inputs = [node.args[0] for node in dequants_biases] other_inputs = [node.args[idx] for node, idx in anchors.others] - # The node is the first index of the list and first of the tuple - op_node = anchors.output[0][0] - - assert len(op_node.users) == 1 + assert op_node is not None, "op_node is None" quant_node = list(op_node.users.keys())[0] with graph_module.graph.inserting_after(op_node): @@ -453,7 +561,27 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901 args, kwargs = get_args_and_kwargs_cat( inputs_inputs, other_inputs, op_node ) - elif isinstance(pattern, (Conv1dPattern, Conv2dPattern)): + elif isinstance(pattern, ConvReluPatterns): + # For ConvReLU, we are fusing Conv+ReLU + # This means that the op we want to get + # the replacement args and kwargs for is the + # *conv* op, which is the anchor input, NOT + # the anchor output (which is the ReLU) + check_out_zero_point_is_min_range( + quant_node.args[2], quant_node.args[5] + ) + anchor_input_node = anchors.inputs[0][0] + args, kwargs = get_args_and_kwargs_conv( + graph_module, + inputs_inputs, + dequants_inputs, + weights_inputs, + dequants_weights, + bias_inputs, + quant_node, + anchor_input_node, + ) + elif isinstance(pattern, ConvPatterns): args, kwargs = get_args_and_kwargs_conv( graph_module, inputs_inputs, @@ -494,6 +622,26 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901 torch.ops.aten.transpose.int, (weights_inputs[0], 0, 1), ) + if "val" in weights_inputs[0].meta: + original_val = weights_inputs[0].meta["val"] + fake_mode = original_val.fake_mode + if fake_mode is not None: + with fake_mode: + transposed_val = torch.ops.aten.transpose.int( + original_val, 0, 1 + ) + transposed_weights.meta["val"] = transposed_val + else: + transposed_shape = list(original_val.shape) + transposed_shape[0], transposed_shape[1] = ( + transposed_shape[1], + transposed_shape[0], + ) + transposed_weights.meta["val"] = torch.zeros( + transposed_shape, dtype=original_val.dtype + ) + copy_node_metadata(transposed_weights, weights_inputs[0]) + # Call linear with transposed weight args, kwargs = get_args_and_kwargs_linear( graph_module, @@ -511,18 +659,76 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901 dequants_inputs, quant_node, ) + elif isinstance(pattern, SoftmaxPattern): + args, kwargs = get_args_and_kwargs_softmax( + graph_module, + inputs_inputs, + dequants_inputs, + quant_node, + op_node, + ) + elif isinstance(pattern, MixedW8A32LinearPattern): + args, kwargs = get_args_and_kwargs_mixed_w8a32_linear( + graph_module, + other_inputs, + weights_inputs, + dequants_weights, + bias_inputs, + dequants_biases, + ) + elif isinstance(pattern, MixedW8A32ConvPattern): + args, kwargs = get_args_and_kwargs_mixed_w8a32_conv( + graph_module, + other_inputs, + weights_inputs, + dequants_weights, + bias_inputs, + dequants_biases, + op_node, + ) + elif isinstance(pattern, MixedW8A32GruPattern): + args, kwargs = get_args_and_kwargs_mixed_w8a32_gru( + 
graph_module, + other_inputs, + weights_inputs, + dequants_weights, + bias_inputs, + dequants_biases, + op_node, + ) + fused = graph_module.graph.call_function( pattern.replacement_op(), args, kwargs, ) - fused.meta = quant_node.meta - quant_node.replace_all_uses_with(fused) + + if len(anchors.output) > 0: + fused.meta = quant_node.meta + quant_node.replace_all_uses_with(fused) + else: + fused.meta = op_node.meta + op_node.replace_all_uses_with(fused) + if op_node.op == "output": + _ = graph_module.graph.output((fused,)) legalize_graph(graph_module) graph_module.graph.eliminate_dead_code() - # pyre-fixme[7]: Incompatible return type + nodes_list = list(graph_module.graph.nodes) + + if len(nodes_list) > 0 and nodes_list[-1].op != "output": + output_nodes = [n for n in nodes_list if n.op == "output"] + output_arg = output_nodes[0].args[0] + original_meta = output_nodes[0].meta.copy() + + for out_node in output_nodes: + graph_module.graph.erase_node(out_node) + + new_output_node = graph_module.graph.output(output_arg) + new_output_node.meta.update(original_meta) + graph_module.recompile() + return PassResult(graph_module, True) @classmethod # pyre-ignore[2]: Parameter `nodes` has no type specified diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py index 74987f8b38d..7a11541b601 100644 --- a/backends/cadence/aot/quantizer/patterns.py +++ b/backends/cadence/aot/quantizer/patterns.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import List, Optional, Tuple, Union +from typing import List, Tuple, Union import torch from executorch.backends.cadence.aot.quantizer.utils import get_bias_qparams @@ -67,7 +67,7 @@ def partition_types(self) -> list[OpOverload]: @abstractmethod def get_anchors( self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule] - ) -> Optional[PartitionAnchors]: + ) -> Tuple[PartitionAnchors, fx.Node]: pass @abstractmethod @@ -85,7 +85,7 @@ def partition_types(self) -> List[OpOverload]: def get_anchors( self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] - ) -> PartitionAnchors: + ) -> Tuple[PartitionAnchors, fx.Node]: # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... addmm_node = fused_partition[0].nodes[-1] @@ -101,15 +101,18 @@ def get_anchors( qscheme=torch.per_tensor_affine, ) - return PartitionAnchors( - inputs=[(addmm_node, 1)], - weights=[(addmm_node, 2)], - biases=[(addmm_node, 0, bias_qspec)], - output=[(addmm_node,)], + return ( + PartitionAnchors( + inputs=[(addmm_node, 1)], + weights=[(addmm_node, 2)], + biases=[(addmm_node, 0, bias_qspec)], + output=[(addmm_node,)], + ), + addmm_node, ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_linear.default + return torch.ops.cadence.quantized_linear.per_tensor class AddPattern(QuantizationPattern): @@ -118,7 +121,7 @@ def partition_types(self) -> List[OpOverload]: def get_anchors( self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] - ) -> PartitionAnchors: + ) -> Tuple[PartitionAnchors, fx.Node]: # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... 
add_node = fused_partition[0].nodes[-1] @@ -129,19 +132,25 @@ def get_anchors( add_node.args[1], fx.Node ) if not is_tensor_add or len(add_node.kwargs) > 0: - return PartitionAnchors( - empty=True, + return ( + PartitionAnchors( + empty=True, + ), + add_node, ) - return PartitionAnchors( - inputs=[(add_node, 0), (add_node, 1)], - weights=[], - biases=[], - output=[(add_node,)], + return ( + PartitionAnchors( + inputs=[(add_node, 0), (add_node, 1)], + weights=[], + biases=[], + output=[(add_node,)], + ), + add_node, ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_add.default + return torch.ops.cadence.quantized_add.per_tensor class BmmPattern(QuantizationPattern): @@ -150,18 +159,23 @@ def partition_types(self) -> List[OpOverload]: def get_anchors( self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] - ) -> PartitionAnchors: + ) -> Tuple[PartitionAnchors, fx.Node]: # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... bmm_node = fused_partition[0].nodes[-1] - return PartitionAnchors( - inputs=[(bmm_node, 0), (bmm_node, 1)], - weights=[], - biases=[], - output=[(bmm_node,)], + return ( + PartitionAnchors( + inputs=[(bmm_node, 0), (bmm_node, 1)], + weights=[], + biases=[], + output=[(bmm_node,)], + ), + bmm_node, ) def replacement_op(self) -> OpOverload: + # TODO: T240804887 This is actually a per-tensor variant, + # we just need to change the name of the op return torch.ops.cadence.quantized_matmul.default @@ -171,7 +185,7 @@ def partition_types(self) -> List[OpOverload]: def get_anchors( self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] - ) -> PartitionAnchors: + ) -> Tuple[PartitionAnchors, fx.Node]: # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... cat_node = fused_partition[0].nodes[-1] @@ -198,13 +212,16 @@ def get_anchors( ) ) - return PartitionAnchors( - inputs=args, - weights=[], - biases=[], - output=[ - (cat_node, SharedQuantizationSpec((cat_node.args[0][0], cat_node))) - ], + return ( + PartitionAnchors( + inputs=args, + weights=[], + biases=[], + output=[ + (cat_node, SharedQuantizationSpec((cat_node.args[0][0], cat_node))) + ], + ), + cat_node, ) def replacement_op(self) -> OpOverload: @@ -217,7 +234,7 @@ def partition_types(self) -> List[OpOverload]: def get_anchors( self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] - ) -> PartitionAnchors: + ) -> Tuple[PartitionAnchors, fx.Node]: # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... 
conv1d_node = fused_partition[0].nodes[-1] @@ -238,16 +255,19 @@ def get_anchors( if len(conv1d_node.args) > 2 and conv1d_node.args[2] is not None: bias = [(conv1d_node, 2, bias_qspec)] - return PartitionAnchors( - inputs=[(conv1d_node, 0)], - weights=[(conv1d_node, 1)], - # pyre-fixme[6]: Incompatible parameter type - biases=bias, - output=[(conv1d_node,)], + return ( + PartitionAnchors( + inputs=[(conv1d_node, 0)], + weights=[(conv1d_node, 1)], + # pyre-fixme[6]: Incompatible parameter type + biases=bias, + output=[(conv1d_node,)], + ), + conv1d_node, ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_conv_nchw.default + return torch.ops.cadence.quantized_conv2d_nchw.per_tensor class Conv2dPattern(QuantizationPattern): @@ -256,7 +276,7 @@ def partition_types(self) -> List[OpOverload]: def get_anchors( self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] - ) -> PartitionAnchors: + ) -> Tuple[PartitionAnchors, fx.Node]: # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... conv2d_node = fused_partition[0].nodes[-1] @@ -277,16 +297,19 @@ def get_anchors( if len(conv2d_node.args) > 2 and conv2d_node.args[2] is not None: bias = [(conv2d_node, 2, bias_qspec)] - return PartitionAnchors( - inputs=[(conv2d_node, 0)], - weights=[(conv2d_node, 1)], - # pyre-fixme[6]: Incompatible parameter type - biases=bias, - output=[(conv2d_node,)], + return ( + PartitionAnchors( + inputs=[(conv2d_node, 0)], + weights=[(conv2d_node, 1)], + # pyre-fixme[6]: Incompatible parameter type + biases=bias, + output=[(conv2d_node,)], + ), + conv2d_node, ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_conv_nchw.default + return torch.ops.cadence.quantized_conv2d_nchw.per_tensor class LayerNormPattern(QuantizationPattern): @@ -295,7 +318,7 @@ def partition_types(self) -> List[OpOverload]: def get_anchors( self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] - ) -> PartitionAnchors: + ) -> Tuple[PartitionAnchors, fx.Node]: # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... layer_norm_node = fused_partition[0].nodes[-1] @@ -311,17 +334,20 @@ def get_anchors( # Weights are used in quantized mode by our kernel, so they are # passed in as others here along with the normalized shape. - return PartitionAnchors( - inputs=[(layer_norm_node, 0)], - weights=[], - biases=[], - # Ordering: normalized_shape, weights, bias - others=others, - output=[(layer_norm_node,)], + return ( + PartitionAnchors( + inputs=[(layer_norm_node, 0)], + weights=[], + biases=[], + # Ordering: normalized_shape, weights, bias + others=others, + output=[(layer_norm_node,)], + ), + layer_norm_node, ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_layer_norm.default + return torch.ops.cadence.quantized_layer_norm.per_tensor class LinearPattern(QuantizationPattern): @@ -330,7 +356,7 @@ def partition_types(self) -> List[OpOverload]: def get_anchors( self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] - ) -> PartitionAnchors: + ) -> Tuple[PartitionAnchors, fx.Node]: # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... 
linear_node = fused_partition[0].nodes[-1] @@ -351,16 +377,19 @@ def get_anchors( if len(linear_node.args) > 2: bias = [(linear_node, 2, bias_qspec)] - return PartitionAnchors( - inputs=[(linear_node, 0)], - weights=[(linear_node, 1)], - # pyre-fixme[6]: Incompatible parameter type - biases=bias, - output=[(linear_node,)], + return ( + PartitionAnchors( + inputs=[(linear_node, 0)], + weights=[(linear_node, 1)], + # pyre-fixme[6]: Incompatible parameter type + biases=bias, + output=[(linear_node,)], + ), + linear_node, ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_linear.default + return torch.ops.cadence.quantized_linear.per_tensor class MatmulPattern(QuantizationPattern): @@ -369,18 +398,22 @@ def partition_types(self) -> List[OpOverload]: def get_anchors( self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] - ) -> PartitionAnchors: + ) -> Tuple[PartitionAnchors, fx.Node]: # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... matmul_node = fused_partition[0].nodes[-1] - return PartitionAnchors( - inputs=[(matmul_node, 0), (matmul_node, 1)], - weights=[], - biases=[], - output=[(matmul_node,)], + return ( + PartitionAnchors( + inputs=[(matmul_node, 0), (matmul_node, 1)], + weights=[], + biases=[], + output=[(matmul_node,)], + ), + matmul_node, ) def replacement_op(self) -> OpOverload: + # TODO: T240804887 This is actually a per-tensor variant, we just need to change the name of the op return torch.ops.cadence.quantized_matmul.default @@ -392,19 +425,22 @@ def partition_types(self) -> List[OpOverload]: def get_anchors( self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] - ) -> PartitionAnchors: + ) -> Tuple[PartitionAnchors, fx.Node]: # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... relu_node = fused_partition[0].nodes[-1] - return PartitionAnchors( - inputs=[(relu_node, 0)], - weights=[], - biases=[], - output=[(relu_node,)], + return ( + PartitionAnchors( + inputs=[(relu_node, 0)], + weights=[], + biases=[], + output=[(relu_node,)], + ), + relu_node, ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_relu.default + return torch.ops.cadence.quantized_relu.per_tensor # Regular relu op @@ -417,3 +453,286 @@ def partition_types(self) -> List[OpOverload]: class ReluPattern1(ReluBasePattern): def partition_types(self) -> List[OpOverload]: return [torch.ops.aten.relu_.default] + + +# This is a base class for Conv+ReLU fusion, since it can be used with two different relu aten ops +class ConvReluBasePattern(QuantizationPattern): + @abstractmethod + def partition_types(self) -> List[OpOverload]: + pass + + def get_anchors( + self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + ) -> Tuple[PartitionAnchors, fx.Node]: + # The first node should be conv, the second should be relu + # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... + conv_node = fused_partition[0].nodes[-1] # Second to last node + # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... 
+ relu_node = fused_partition[1].nodes[-1] # Last node + + bias_qspec = DerivedQuantizationSpec( + derived_from=[ + (conv_node.args[0], conv_node), + (conv_node.args[1], conv_node), + ], + derive_qparams_fn=get_bias_qparams, + dtype=torch.int32, + quant_min=-(2**31), + quant_max=2**31 - 1, + qscheme=torch.per_tensor_affine, + ) + + # Keep bias empty if not supplied + bias = [] + if len(conv_node.args) > 2 and conv_node.args[2] is not None: + bias = [(conv_node, 2, bias_qspec)] + + return ( + PartitionAnchors( + inputs=[(conv_node, 0)], + weights=[(conv_node, 1)], + # pyre-fixme[6]: Incompatible parameter type + biases=bias, + output=[(relu_node,)], # Output is from the relu node + ), + relu_node, + ) + + def replacement_op(self) -> OpOverload: + return torch.ops.cadence.quantized_conv2d_nchw.per_tensor + + +# Conv1d + regular relu op fusion +class Conv1dReluPattern0(ConvReluBasePattern): + def partition_types(self) -> List[OpOverload]: + return [torch.ops.aten.conv1d.default, torch.ops.aten.relu.default] + + +# Conv1d + alternate relu op fusion +class Conv1dReluPattern1(ConvReluBasePattern): + def partition_types(self) -> List[OpOverload]: + return [torch.ops.aten.conv1d.default, torch.ops.aten.relu_.default] + + +# Conv2d + regular relu op fusion +class Conv2dReluPattern0(ConvReluBasePattern): + def partition_types(self) -> List[OpOverload]: + return [torch.ops.aten.conv2d.default, torch.ops.aten.relu.default] + + +# Conv2d + alternate relu op fusion +class Conv2dReluPattern1(ConvReluBasePattern): + def partition_types(self) -> List[OpOverload]: + return [torch.ops.aten.conv2d.default, torch.ops.aten.relu_.default] + + +class SoftmaxPattern(QuantizationPattern): + def partition_types(self) -> List[OpOverload]: + return [torch.ops.aten._softmax.default] + + def get_anchors( + self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + ) -> Tuple[PartitionAnchors, fx.Node]: + # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... 
+ softmax_node = fused_partition[0].nodes[-1] + + return ( + PartitionAnchors( + inputs=[(softmax_node, 0)], + weights=[], + biases=[], + output=[(softmax_node,)], + ), + softmax_node, + ) + + def replacement_op(self) -> OpOverload: + return torch.ops.cadence.quantized_softmax.per_tensor + + +class MixedW8A32LinearPattern(QuantizationPattern): + def partition_types(self) -> List[OpOverload]: + return [torch.ops.aten.linear.default] + + def get_anchors( + self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + ) -> Tuple[PartitionAnchors, fx.Node]: + # pyre-ignore[29] + linear_layer = fused_partition[0].nodes[-1] + + # Bail if the arguments have different shapes than expected + if len(linear_layer.args) != 3 or len(linear_layer.kwargs) > 0: + return ( + PartitionAnchors( + empty=True, + ), + linear_layer, + ) + + input_node = linear_layer.args[0] + input_shape = input_node.meta["tensor_meta"].shape + + # Bail if the weights are not multiple of 4 (SIMD) + if input_shape[-1] % 4 != 0: + return ( + PartitionAnchors( + empty=True, + ), + linear_layer, + ) + # Currenly only supporting vector-matrix multiplication + if len(input_shape) > 0 and input_shape[-2] != 1: + return ( + PartitionAnchors( + empty=True, + ), + linear_layer, + ) + + return ( + PartitionAnchors( + inputs=[], + weights=[(linear_layer, 1)], + biases=[(linear_layer, 2)], + output=[], + others=[(linear_layer, 0)], + ), + linear_layer, + ) + + def replacement_op(self) -> OpOverload: + return torch.ops.cadence.quantized_w8a32_linear.default + + +class MixedW8A32ConvPattern(QuantizationPattern): + def partition_types(self) -> List[OpOverload]: + return [torch.ops.aten.conv1d.default] + + def get_anchors( + self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + ) -> Tuple[PartitionAnchors, fx.Node]: + # pyre-ignore[29] + conv_layer = fused_partition[0].nodes[-1] + + # Bail if the arguments have different shapes than expected + # Stride, padding, dilation and groups are not supported + if len(conv_layer.args) != 3 or len(conv_layer.kwargs) > 0: + return ( + PartitionAnchors( + empty=True, + ), + conv_layer, + ) + + cnn_weights = conv_layer.args[1] + if hasattr(cnn_weights.meta, "tensor_meta"): + cnn_weights_shape = cnn_weights.meta["tensor_meta"].shape + # Bail if the channels are not multiple of 4 (SIMD) + if cnn_weights_shape[0] % 4 != 0: + return ( + PartitionAnchors( + empty=True, + ), + conv_layer, + ) + if cnn_weights_shape[1] % 4 != 0: + return ( + PartitionAnchors( + empty=True, + ), + conv_layer, + ) + # Bail if the kernel size is not 3 + if cnn_weights_shape[2] != 3: + return ( + PartitionAnchors( + empty=True, + ), + conv_layer, + ) + + return ( + PartitionAnchors( + inputs=[], + weights=[(conv_layer, 1)], + biases=[(conv_layer, 2)], + output=[], + others=[(conv_layer, 0)], + ), + conv_layer, + ) + + def replacement_op(self) -> OpOverload: + return torch.ops.cadence.quantized_w8a32_conv.default + + +class MixedW8A32GruPattern(QuantizationPattern): + def partition_types(self) -> List[OpOverload]: + return [torch.ops.aten.gru.input] + + def get_anchors( + self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + ) -> Tuple[PartitionAnchors, fx.Node]: + # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... 
+ gru_layer = fused_partition[0].nodes[-1] + if len(gru_layer.kwargs) > 0: + return ( + PartitionAnchors( + empty=True, + ), + gru_layer, + ) + + # Bail if input or states are not multiple of 4 (SIMD) + if gru_layer.args[0].meta["tensor_meta"].shape[-1] % 4 != 0: + return ( + PartitionAnchors( + empty=True, + ), + gru_layer, + ) + if gru_layer.args[1].meta["tensor_meta"].shape[-1] % 4 != 0: + return ( + PartitionAnchors( + empty=True, + ), + gru_layer, + ) + + class Wrapper: # noqa: B903 + def __init__(self, args, meta): + self.args = args + self.meta = meta + + wrapper = Wrapper(tuple(gru_layer.args[2]), gru_layer.meta) + + return ( + PartitionAnchors( + inputs=[], + # pyre-fixme[6]: Expected `List[Tuple[Node, int]]` but got `List[Tuple[Wrapper, int]]`. + weights=[(wrapper, 0), (wrapper, 1)], + # pyre-fixme[6]: Expected `List[Union[Tuple[Node, int], Tuple[Node, int, DerivedQuantizationSpec]]]` but got `List[Tuple[Wrapper, int]]`. + biases=[(wrapper, 2), (wrapper, 3)], + output=[], + others=[(gru_layer, 0), (gru_layer, 1)], + ), + gru_layer, + ) + + def replacement_op(self) -> OpOverload: + return torch.ops.cadence.quantized_w8a32_gru.default + + +class RmsNormPattern(QuantizationPattern): + """Pattern that preserves rms_norm from decomposition without matching anything.""" + + def partition_types(self) -> list[torch._ops.OpOverload]: + return [torch.ops.aten.rms_norm.default] + + def get_anchors( + self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule] + ) -> Tuple[PartitionAnchors, fx.Node]: + return PartitionAnchors(empty=True), None # pyre-ignore[7] + + def replacement_op(self) -> torch._ops.OpOverload: + return torch.ops.aten.rms_norm.default diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py index 8c78ac87e58..bdd4cc810a0 100644 --- a/backends/cadence/aot/quantizer/quantizer.py +++ b/backends/cadence/aot/quantizer/quantizer.py @@ -7,7 +7,7 @@ # pyre-strict from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import final, List, Optional, Tuple, Union import torch from executorch.backends.cadence.aot.quantizer.patterns import ( @@ -16,22 +16,29 @@ BmmPattern, CatPattern, Conv1dPattern, + Conv1dReluPattern0, + Conv1dReluPattern1, Conv2dPattern, + Conv2dReluPattern0, + Conv2dReluPattern1, LayerNormPattern, LinearPattern, MatmulPattern, + MixedW8A32ConvPattern, + MixedW8A32GruPattern, + MixedW8A32LinearPattern, QuantizationPattern, ReluPattern0, ReluPattern1, + RmsNormPattern, + SoftmaxPattern, ) from executorch.backends.cadence.aot.quantizer.utils import ( find_sequential_partitions_aten, is_annotated, no_outside_users, ) - from torch import fx - from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver from torchao.quantization.pt2e.quantizer import ( ComposableQuantizer, @@ -54,6 +61,15 @@ observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12), ) +act_qspec_asym16s = QuantizationSpec( + dtype=torch.int16, + quant_min=-32768, + quant_max=32767, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12), +) + wgt_qspec_asym8s = QuantizationSpec( dtype=torch.int8, quant_min=-128, @@ -88,6 +104,20 @@ None, ) +qconfig_A16 = QuantizationConfig( + act_qspec_asym16s, + act_qspec_asym16s, + wgt_qspec_asym8s, + None, +) + +qconfig_A32W8sym = QuantizationConfig( + input_activation=None, + output_activation=None, + weight=wgt_qspec_sym8s, + bias=wgt_qspec_sym8s, +) + class 
CadenceAtenQuantizer(Quantizer): def __init__( @@ -112,7 +142,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: if not no_outside_users(fused_partition): continue - anchors = self.pattern.get_anchors(model, fused_partition) + anchors, _ = self.pattern.get_anchors(model, fused_partition) if not anchors or anchors.empty: continue if is_annotated( @@ -213,6 +243,23 @@ class for explicitly defined quantizers (like CadenceDefaultQuantizer). def __init__(self, quantizers: List[Quantizer]) -> None: super().__init__(quantizers) + @final + def get_ops_to_preserve_from_decomposition(self) -> List[torch._ops.OpOverload]: + """ + Get complete list of ops to preserve from decomposition. + + Delegates preservation choices to QuantizationPattern by aggregating + the pattern's partition_types(), which explicitly declares the root + ops that compose the pattern and should be preserved. + """ + ops: set[torch._ops.OpOverload] = set() + for q in self.quantizers: + if isinstance(q, CadenceAtenQuantizer): + ops.update(q.pattern.partition_types()) + elif isinstance(q, CadenceQuantizer): + ops.update(q.get_ops_to_preserve_from_decomposition()) + return list(ops) + class CadenceDefaultQuantizer(CadenceQuantizer): """ @@ -237,6 +284,15 @@ def __init__( super().__init__([]) +class CadenceRmsNormNopQuantizer(CadenceQuantizer): + """ + Nop quantizer that preserves rms_norm from decomposition. + """ + + def __init__(self) -> None: + super().__init__([CadenceAtenQuantizer(RmsNormPattern(), qconfig_A8W8)]) + + class CadenceWithLayerNormQuantizer(CadenceQuantizer): """ Quantizer including layer norm @@ -260,3 +316,94 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None: quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8)) quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8)) super().__init__(quantizers) + + +class CadenceFusedConvReluQuantizer(CadenceQuantizer): + """ + Quantizer using fused conv+relu patterns, and including add and cat + """ + + def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None: + if quantizers is None: + quantizers = [] + # Order matters here, perform the "fused" patterns first + quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern0(), qconfig_A8W8sym)) + quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern1(), qconfig_A8W8sym)) + quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern0(), qconfig_A8W8sym)) + quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern1(), qconfig_A8W8sym)) + quantizers = quantizers + get_cadence_default_quantizers() + quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8)) + quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8)) + super().__init__(quantizers) + + +class CadenceW8A32MixedQuantizer(CadenceQuantizer): + """ + Quantizer for mixed quantization, 8 bit weights and 32 bit activations + TODO: Experimental quantizer, not yet well supported in OSS + """ + + def __init__(self) -> None: + quantizers = [] + quantizers.append( + CadenceAtenQuantizer(MixedW8A32LinearPattern(), qconfig_A32W8sym) + ) + quantizers.append( + CadenceAtenQuantizer(MixedW8A32ConvPattern(), qconfig_A32W8sym) + ) + quantizers.append( + CadenceAtenQuantizer(MixedW8A32GruPattern(), qconfig_A32W8sym) + ) + super().__init__(quantizers) + + +class CadenceWithSoftmaxQuantizer(CadenceQuantizer): + """ + Quantizer including A16 softmax + """ + + def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None: + if quantizers is None: + quantizers = 
get_cadence_default_quantizers() + quantizers.append(CadenceAtenQuantizer(SoftmaxPattern(), qconfig_A16)) + super().__init__(quantizers) + + +class CadenceWith16BitLinearActivationsQuantizer(CadenceQuantizer): + """ + Quantizer including A16 fully_connected + """ + + def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None: + if quantizers is None: + quantizers = [] + # Add 16-bit quantizers for LinearPattern + quantizers.append(CadenceAtenQuantizer(LinearPattern(), qconfig_A16)) + super().__init__(quantizers) + + +class CadenceWith16BitConvActivationsQuantizer(CadenceQuantizer): + """ + Quantizer including A16 conv + """ + + def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None: + if quantizers is None: + quantizers = [] + # Add 16-bit quantizers for Conv patterns + quantizers.append(CadenceAtenQuantizer(Conv1dPattern(), qconfig_A16)) + quantizers.append(CadenceAtenQuantizer(Conv2dPattern(), qconfig_A16)) + super().__init__(quantizers) + + +class CadenceWith16BitMatmulActivationsQuantizer(CadenceQuantizer): + """ + Quantizer including A16 matmul + """ + + def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None: + if quantizers is None: + quantizers = [] + # Add 16-bit quantizers for MatmulPattern + quantizers.append(CadenceAtenQuantizer(MatmulPattern(), qconfig_A16)) + super().__init__(quantizers) diff --git a/backends/cadence/aot/quantizer/utils.py b/backends/cadence/aot/quantizer/utils.py index beacd1b9e86..dfc31bfac8c 100644 --- a/backends/cadence/aot/quantizer/utils.py +++ b/backends/cadence/aot/quantizer/utils.py @@ -24,6 +24,12 @@ from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY +def copy_node_metadata(dest_node: fx.Node, src_node: fx.Node) -> None: + for key in ["nn_module_stack", "stack_trace", "source_fn_stack"]: + if key in src_node.meta and src_node.meta[key]: + dest_node.meta[key] = src_node.meta[key] + + def quantize_tensor_multiplier( requantize_scale_tensor: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -114,15 +120,45 @@ def create_zero_bias_int32( """ Creates a zero bias tensor with the shape of weight[0] """ - attr_node = getattr(graph_module, weight_node.target) + try: + attr_node = getattr(graph_module, weight_node.target) + except AttributeError: + if "val" in weight_node.meta: + attr_node = weight_node.meta["val"] + else: + param_dict = dict(graph_module.named_parameters()) + if weight_node.target in param_dict: + attr_node = param_dict[weight_node.target] + else: + buffer_dict = dict(graph_module.named_buffers()) + if weight_node.target in buffer_dict: + attr_node = buffer_dict[weight_node.target] + else: + raise AttributeError( + f"Could not find weight tensor for node {weight_node.target}. 
" + f"Node metadata keys: {list(weight_node.meta.keys())}" + ) + weight_shape = list(attr_node.shape) bias_shape = weight_shape[0] - return graph_module.graph.call_function( + new_node = graph_module.graph.call_function( torch.ops.aten.full.default, ([bias_shape], 0.0), {"dtype": torch.int32}, ) + if "val" in weight_node.meta: + fake_mode = weight_node.meta["val"].fake_mode + if fake_mode is not None: + with fake_mode: + fake_bias = torch.zeros([bias_shape], dtype=torch.int32) + new_node.meta["val"] = fake_bias + else: + new_node.meta["val"] = torch.zeros([bias_shape], dtype=torch.int32) + copy_node_metadata(new_node, weight_node) + + return new_node + def get_bias_qparams( obs_or_fqs: List[ObserverOrFakeQuantize], @@ -234,3 +270,19 @@ def find_sequential_partitions_aten( if _partitions_sequential(candidate): fused_partitions.append(candidate) return fused_partitions + + +def check_out_zero_point_is_min_range( + out_zero_point: int, + out_dtype: torch.dtype, +) -> bool: + """ + Checks if the out_zero_point is the minimum range of the quant type. + """ + if out_dtype == torch.int8: + return out_zero_point == -128 + elif out_dtype == torch.int16: + return out_zero_point == -32768 + elif out_dtype == torch.uint8 or torch.uint16: + return out_zero_point == 0 + return False diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 40ae6d23085..1128ad3167c 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -6,17 +6,53 @@ # pyre-strict - -from typing import Callable +from typing import Callable, Protocol, TypeVar import torch - +import torch.nn as nn +import torch.nn.functional as F from executorch.exir.scalar_type import ScalarType from torch.library import impl, Library - m = Library("cadence", "IMPL", "CompositeExplicitAutograd") +try: + torch.ops.load_library("//executorch/kernels/quantized:custom_ops_generated_lib") +except (OSError, RuntimeError): + # Fall back to path-based loading for CMake/OSS builds + from pathlib import Path + + custom_libs: list[Path] = list( + Path(__file__) + .parent.parent.parent.resolve() + .glob("**/kernels/quantized/**/*custom_ops_generated_lib.*") + ) + if custom_libs: + torch.ops.load_library(str(custom_libs[0])) + del Path + +# Registry to track all ops with reference implementations +_REGISTERED_REF_IMPLEMENTATIONS: set[str] = set() + +T = TypeVar("T", bound=Callable[..., torch.Tensor | tuple[torch.Tensor, ...]]) + + +class MyDecorator(Protocol): + def __call__(self, __f: T) -> T: ... + + +# Custom impl wrapper that tracks registrations +def impl_tracked(lib: Library, op_name: str) -> MyDecorator: + """Wrapper around impl that tracks registered ops.""" + _REGISTERED_REF_IMPLEMENTATIONS.add(op_name) + return impl(lib, op_name) + + +def get_registered_ref_implementations() -> set[str]: + """Get all ops that have reference implementations.""" + return _REGISTERED_REF_IMPLEMENTATIONS.copy() + + qdtype_map: dict[ScalarType, torch.dtype] = { ScalarType.QINT8: torch.qint8, ScalarType.QUINT8: torch.quint8, @@ -24,8 +60,7 @@ } -@impl(m, "quantize_per_tensor") -def quantize_per_tensor( +def quantize_per_tensor_common( input_tensor: torch.Tensor, scale: float, zero_point: int, @@ -38,7 +73,7 @@ def quantize_per_tensor( Args: - input_tensor (Tensor): input tensor - - scale (float): Inverse of quantization scale. Derived from the ratio + - scale (float): Quantization scale. 
Derived from the ratio between the min/max of the floating-point tensor and the min/max of the quantized range, and then inverted. - zero_point (int): The point which represents 0 in the quantized @@ -61,18 +96,81 @@ def quantize_per_tensor( ] if dtype not in supported_quant_types: raise ValueError( - f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}" + f"Unsupported dtype to quantize to {dtype}. Supported dtypes must be one of {supported_quant_types}" ) - quantized = torch.round(input_tensor * scale + zero_point).to(dtype) - return torch.max( - torch.min(quantized, torch.tensor(quant_max)), - torch.tensor(quant_min), + return torch.ops.quantized_decomposed.quantize_per_tensor( + input_tensor, + scale, + zero_point, + quant_min, + quant_max, + dtype, ) -@impl(m, "dequantize_per_tensor") -def dequantize_per_tensor( +def quantize_per_tensor_variant( + dtype: torch.dtype | None = None, +) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]: + """Create a quantize_per_tensor variant with type checking.""" + + def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + def variant( + input_tensor: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + out_dtype: torch.dtype, + ) -> torch.Tensor: + if dtype and out_dtype != dtype: + raise ValueError(f"dtype must be {dtype}. Got {out_dtype}") + + return quantize_per_tensor_common( + input_tensor, + scale, + zero_point, + quant_min, + quant_max, + out_dtype, + ) + + return variant + + return decorator + + +@impl_tracked(m, "quantize_per_tensor") +@quantize_per_tensor_variant() +def quantize_per_tensor() -> torch.Tensor: ... + + +@impl_tracked(m, "quantize_per_tensor_asym8u") +@quantize_per_tensor_variant(torch.uint8) +def quantize_per_tensor_asym8u() -> torch.Tensor: ... + + +@impl_tracked(m, "quantize_per_tensor_asym8s") +@quantize_per_tensor_variant(torch.int8) +def quantize_per_tensor_asym8s() -> torch.Tensor: ... + + +@impl_tracked(m, "quantize_per_tensor_asym16u") +@quantize_per_tensor_variant(torch.uint16) +def quantize_per_tensor_asym16u() -> torch.Tensor: ... + + +@impl_tracked(m, "quantize_per_tensor_asym16s") +@quantize_per_tensor_variant(torch.int16) +def quantize_per_tensor_asym16s() -> torch.Tensor: ... + + +@impl_tracked(m, "quantize_per_tensor_asym32s") +@quantize_per_tensor_variant(torch.int32) +def quantize_per_tensor_asym32s() -> torch.Tensor: ... + + +def dequantize_per_tensor_common( input_tensor: torch.Tensor, scale: float, zero_point: int, @@ -97,7 +195,7 @@ def dequantize_per_tensor( is already provided. - quant_max (int): The largest value in the quantized domain. Unused since scale is already provided. - - dtype (torch.dtype): The type of the output tensor. Must be a floating point type. + - dtype (torch.dtype): The type of the input tensor. """ supported_quant_types = [ torch.int8, @@ -108,33 +206,83 @@ def dequantize_per_tensor( ] if input_tensor.dtype not in supported_quant_types: raise ValueError(f"Input dtype must be one of {supported_quant_types}") - supported_dequant_types = [ - torch.float, - torch.float32, - torch.float16, - torch.bfloat16, - ] - if dtype not in supported_dequant_types: - raise ValueError( - f"Unsupported dtype to dequantize to. Supported dtypes must be one of {supported_dequant_types}" - ) + if input_tensor.dtype != dtype: + raise ValueError("Input dtype must match dtype") - # Needed to prevent underflow in cases where the zero_point is larger than - # the quantized value. 
- if not input_tensor.dtype.is_signed: - input_tensor = input_tensor.to(torch.int32) + return torch.ops.quantized_decomposed.dequantize_per_tensor( + input_tensor, scale, zero_point, quant_min, quant_max, dtype + ) - return (input_tensor - zero_point).to(dtype) * scale +def dequantize_per_tensor_variant( + dtype: torch.dtype | None = None, +) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]: + """Create a dequantize_per_tensor variant with type checking.""" -@impl(m, "quantized_add") -def quantized_add( + def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + def variant( + input_tensor: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + in_dtype: torch.dtype, + ) -> torch.Tensor: + if dtype and in_dtype != dtype: + raise ValueError(f"dtype must be {dtype}. Got {in_dtype}") + + return dequantize_per_tensor_common( + input_tensor, + scale, + zero_point, + quant_min, + quant_max, + in_dtype, + ) + + return variant + + return decorator + + +@impl_tracked(m, "dequantize_per_tensor") +@dequantize_per_tensor_variant() +def dequantize_per_tensor() -> torch.Tensor: ... + + +@impl_tracked(m, "dequantize_per_tensor_asym8u") +@dequantize_per_tensor_variant(torch.uint8) +def dequantize_per_tensor_asym8u() -> torch.Tensor: ... + + +@impl_tracked(m, "dequantize_per_tensor_asym32s") +@dequantize_per_tensor_variant(torch.int32) +def dequantize_per_tensor_asym32s() -> torch.Tensor: ... + + +@impl_tracked(m, "dequantize_per_tensor_asym16u") +@dequantize_per_tensor_variant(torch.uint16) +def dequantize_per_tensor_asym16u() -> torch.Tensor: ... + + +@impl_tracked(m, "dequantize_per_tensor_asym8s") +@dequantize_per_tensor_variant(torch.int8) +def dequantize_per_tensor_asym8s() -> torch.Tensor: ... + + +@impl_tracked(m, "dequantize_per_tensor_asym16s") +@dequantize_per_tensor_variant(torch.int16) +def dequantize_per_tensor_asym16s() -> torch.Tensor: ... 
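A minimal sketch (values are illustrative, not part of this change) of the scale convention that quantize_per_tensor_common and dequantize_per_tensor_common now delegate to: quantization divides by scale and dequantization multiplies by it, which is why callers in this file now pass out_scale directly instead of 1 / out_scale.

import torch

x = torch.tensor([0.0, 0.05, -0.1])
scale, zero_point, qmin, qmax = 0.05, 0, -128, 127

# Essentially the arithmetic of torch.ops.quantized_decomposed.quantize_per_tensor:
# divide by scale, add the zero point, clamp to the quantized range.
q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.int8)
# q == tensor([ 0,  1, -2], dtype=torch.int8)

# And the matching dequantize: subtract the zero point, multiply by scale.
x_hat = (q.to(torch.float32) - zero_point) * scale
# x_hat == tensor([ 0.0000,  0.0500, -0.1000])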
+ + +@impl_tracked(m, "quantized_add.per_tensor") +def quantized_add_per_tensor( X: torch.Tensor, - X_scale: torch.Tensor, - X_zero_point: torch.Tensor, + X_scale: float, + X_zero_point: int, Y: torch.Tensor, - Y_scale: torch.Tensor, - Y_zero_point: torch.Tensor, + Y_scale: float, + Y_zero_point: int, out_scale: float, out_zero_point: int, ) -> torch.Tensor: @@ -149,17 +297,17 @@ def quantized_add( out = (X_scale(X - X_zero_point) + Y_scale(Y - Y_zero_point)) / out_scale + out_zero_point Args: - - X (Tensor): The first operand - - X_scale (Tensor): The ratio between the sizes of X's floating point and quantized + - X: The first operand + - X_scale: The ratio between the sizes of X's floating point and quantized ranges - - X_zero_point (Tensor): The quantized mapping of zero for X - - Y (Tensor): The second operand - - Y_scale (Tensor): The ratio between the sizes of Y's floating point and quantized + - X_zero_point: The quantized mapping of zero for X + - Y: The second operand + - Y_scale: The ratio between the sizes of Y's floating point and quantized ranges - - Y_zero_point (Tensor): The quantized mapping of zero for Y - - out_scale (float): The ratio between the sizes of the output's floating point and + - Y_zero_point: The quantized mapping of zero for Y + - out_scale: The ratio between the sizes of the output's floating point and quantized ranges - - out_zero_point (int): The quantized mapping of zero for the output + - out_zero_point: The quantized mapping of zero for the output """ supported_dtypes = [torch.int8, torch.uint8] if X.dtype != Y.dtype: @@ -180,12 +328,10 @@ def quantized_add( dequant_X = X_scale * (X - X_zero_point) dequant_Y = Y_scale * (Y - Y_zero_point) - out_scale_inv = 1 / out_scale - # q_min/q_max are unused args return quantize_per_tensor( dequant_X + dequant_Y, - out_scale_inv, + out_scale, out_zero_point, torch.iinfo(dtype).min, torch.iinfo(dtype).max, @@ -193,13 +339,78 @@ def quantized_add( ) +@impl_tracked(m, "quantized_add") +def quantized_add( + X: torch.Tensor, + X_scale: torch.Tensor, + X_zero_point: torch.Tensor, + Y: torch.Tensor, + Y_scale: torch.Tensor, + Y_zero_point: torch.Tensor, + out_scale: float, + out_zero_point: int, +) -> torch.Tensor: + return quantized_add_per_tensor( + X, + float(X_scale.item()), + int(X_zero_point.item()), + Y, + float(Y_scale.item()), + int(Y_zero_point.item()), + out_scale, + out_zero_point, + ) + + +@impl_tracked(m, "quantized_add_asym8sxasym8s_asym8s.per_tensor") +def quantized_add_asym8sxasym8s_asym8s_per_tensor( + X: torch.Tensor, + X_scale: float, + X_zero_point: int, + Y: torch.Tensor, + Y_scale: float, + Y_zero_point: int, + out_scale: float, + out_zero_point: int, +) -> torch.Tensor: + if X.dtype != torch.int8: + raise ValueError("X dtype must be torch.int8") + if Y.dtype != torch.int8: + raise ValueError("Y dtype must be torch.int8") + + return quantized_add_per_tensor( + X, X_scale, X_zero_point, Y, Y_scale, Y_zero_point, out_scale, out_zero_point + ) + + +@impl_tracked(m, "quantized_add_asym8uxasym8u_asym8u.per_tensor") +def quantized_add_asym8uxasym8u_asym8u_per_tensor( + X: torch.Tensor, + X_scale: float, + X_zero_point: int, + Y: torch.Tensor, + Y_scale: float, + Y_zero_point: int, + out_scale: float, + out_zero_point: int, +) -> torch.Tensor: + if X.dtype != torch.uint8: + raise ValueError("X dtype must be torch.int8") + if Y.dtype != torch.uint8: + raise ValueError("Y dtype must be torch.int8") + + return quantized_add_per_tensor( + X, X_scale, X_zero_point, Y, Y_scale, Y_zero_point, out_scale, 
out_zero_point + ) + + def quantized_linear_common( src: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, in_zero_point: int, weight_zero_point: torch.Tensor | int, - out_multiplier: torch.Tensor | int, + out_multiplier: int, out_shift: int, out_zero_point: int, ) -> torch.Tensor: @@ -217,8 +428,7 @@ def quantized_linear_common( - out_zero_point (int): The quantized mapping of zero for the output - offset (Tensor): Unused """ - out_scale = -out_multiplier * (1 / (1 << 31)) * (2**out_shift) - out_scale_inv = 1 / out_scale + out_scale = 1.0 / (-out_multiplier * (1 / (1 << 31)) * (2**out_shift)) N, K = weight.shape @@ -226,10 +436,10 @@ def quantized_linear_common( src = src.view(-1, K) dtype = src.dtype - supported_dtypes = [torch.int8, torch.uint8, torch.int32] + supported_dtypes = [torch.int8, torch.uint8, torch.int16, torch.int32] if dtype not in supported_dtypes: raise ValueError( - f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_dtypes}" + f"Unsupported dtype to quantize to {dtype}. Supported dtypes must be one of {supported_dtypes}" ) out = torch.nn.functional.linear( @@ -239,7 +449,7 @@ def quantized_linear_common( ) return quantize_per_tensor( out, - out_scale_inv, + out_scale, out_zero_point, torch.iinfo(dtype).min, torch.iinfo(dtype).max, @@ -287,81 +497,196 @@ def variant( assert isinstance(weight_zero_point, int) assert isinstance(out_multiplier, int) assert isinstance(out_shift, int) - return quantized_linear_common( - src, - weight, - bias, - in_zero_point, - weight_zero_point, - out_multiplier, - out_shift, - out_zero_point, - ) + _out_shift = out_shift + _out_multiplier = out_multiplier else: assert isinstance(out_shift, torch.Tensor) + assert isinstance(out_multiplier, torch.Tensor) if out_shift.numel() != 1: raise ValueError("out_shift must be a scalar") - if out_shift.dtype != torch.int64: - raise ValueError("out_shift must be an int64") - - return quantized_linear_common( - src, - weight, - bias, - in_zero_point, - weight_zero_point, - out_multiplier, - int(out_shift.item()), - out_zero_point, - ) + if out_shift.dtype != torch.int32: + raise ValueError("out_shift must be an int32") + + _out_shift = int(out_shift.item()) + _out_multiplier = int(out_multiplier[0].item()) + + return quantized_linear_common( + src, + weight, + bias, + in_zero_point, + weight_zero_point, + _out_multiplier, + _out_shift, + out_zero_point, + ) return variant return decorator -@impl(m, "quantized_linear") +@impl_tracked(m, "quantized_linear") @quantized_linear_variant(False, False) def quantized_linear() -> torch.Tensor: ... -@impl(m, "quantized_linear.per_tensor") +@impl_tracked(m, "quantized_linear.per_tensor") @quantized_linear_variant(True, False) def quantized_linear_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_linear_asym8sxasym8s_asym8s.per_tensor") +@impl_tracked(m, "quantized_linear_asym8sxasym8s_asym8s.per_tensor") @quantized_linear_variant(True, False, torch.int8, torch.int8) def quantized_linear_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_linear_asym8uxasym8u_asym8u.per_tensor") +@impl_tracked(m, "quantized_linear_asym8uxasym8u_asym8u.per_tensor") @quantized_linear_variant(True, False, torch.uint8, torch.uint8) def quantized_linear_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_fully_connected") +@impl_tracked(m, "quantized_fully_connected") @quantized_linear_variant(False, True) def quantized_fully_connected() -> torch.Tensor: ... 
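Worked example (numbers are made up) of how the fixed-point pair (out_multiplier, out_shift) used by quantized_linear_common above maps to a floating-point requantization scale, using the exact expression from the code:

out_multiplier = -1610612736  # -0.75 * 2**31, i.e. a Q31-encoded multiplier
out_shift = 0
out_scale = 1.0 / (-out_multiplier * (1 / (1 << 31)) * (2**out_shift))
# -out_multiplier * 2**-31 == 0.75, so out_scale == 1.0 / 0.75 ~= 1.3333.
# The int32 accumulator is then quantized with this scale (divided by it,
# i.e. effectively multiplied by 0.75 here) before out_zero_point is added.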
-@impl(m, "quantized_fully_connected.per_tensor") +@impl_tracked(m, "quantized_fully_connected.per_tensor") @quantized_linear_variant(True, True) def quantized_fully_connected_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor") +@impl_tracked(m, "quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor") @quantized_linear_variant(True, True, torch.int8, torch.int8) def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor") +@impl_tracked(m, "quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor") @quantized_linear_variant(True, True, torch.uint8, torch.uint8) def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_layer_norm.per_tensor") +@impl_tracked(m, "fully_connected") +def fully_connected( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, +) -> torch.Tensor: + if input_tensor.shape[0] != 1: + raise ValueError("Fully connected linear only supports batch size of 1") + return F.linear(input_tensor, weight, bias) + + +@impl_tracked(m, "quantized_matmul") +def quantized_matmul( + X: torch.Tensor, + X_zero_point: int, + Y: torch.Tensor, + Y_zero_point: int, + bias: torch.Tensor | None, + out_multiplier: int, + out_shift: int, + out_zero_point: int, + transposed: bool = False, +) -> torch.Tensor: + """ + Quantized matmul operation. + + Args: + - X (Tensor): The activations tensor + - X_zero_point (int): The quantized mapping of zero for the input + - Y (Tensor): The weight tensor + - Y_zero_point (int): The quantized mapping of zero for the weight + - bias (Tensor): The bias tensor + - out_multiplier (int): The multiplier used to scale the output + - out_shift (int): The shift used to scale the output + - out_zero_point (int): The quantized mapping of zero for the output + - transposed (bool): Whether Y is transposed. 
+ """ + if bias is not None and not torch.all(bias == 0): + raise ValueError("bias must be None or all zeros since unused in out variant") + + if transposed: + Y = Y.transpose(-1, -2) + + out_scale = 1.0 / (-out_multiplier * (1 / (1 << 31)) * (2**out_shift)) + + out = torch.matmul( + (X - X_zero_point).float(), + (Y - Y_zero_point).float(), + ) + return quantize_per_tensor( + out, + out_scale, + out_zero_point, + torch.iinfo(X.dtype).min, + torch.iinfo(X.dtype).max, + X.dtype, + ) + + +@impl_tracked(m, "quantized_matmul_asym8sxasym8s_asym8s") +def quantized_matmul_asym8sxasym8s_asym8s( + X: torch.Tensor, + X_zero_point: int, + Y: torch.Tensor, + Y_zero_point: int, + bias: torch.Tensor | None, + out_multiplier: int, + out_shift: int, + out_zero_point: int, + transposed: bool = False, +) -> torch.Tensor: + if X.dtype != torch.int8: + raise ValueError("X dtype must be torch.int8") + if Y.dtype != torch.int8: + raise ValueError("Y dtype must be torch.int8") + + return quantized_matmul( + X, + X_zero_point, + Y, + Y_zero_point, + bias, + out_multiplier, + out_shift, + out_zero_point, + transposed, + ) + + +@impl_tracked(m, "quantized_matmul_asym8uxasym8u_asym8u") +def quantized_matmul_asym8uxasym8u_asym8u( + X: torch.Tensor, + X_zero_point: int, + Y: torch.Tensor, + Y_zero_point: int, + bias: torch.Tensor | None, + out_multiplier: int, + out_shift: int, + out_zero_point: int, + transposed: bool = False, +) -> torch.Tensor: + if X.dtype != torch.uint8: + raise ValueError("X dtype must be torch.uint8") + if Y.dtype != torch.uint8: + raise ValueError("Y dtype must be torch.uint8") + + return quantized_matmul( + X, + X_zero_point, + Y, + Y_zero_point, + bias, + out_multiplier, + out_shift, + out_zero_point, + transposed, + ) + + +@impl_tracked(m, "quantized_layer_norm.per_tensor") def quantized_layer_norm_per_tensor( input_tensor: torch.Tensor, X_scale: float, @@ -394,15 +719,16 @@ def quantized_layer_norm_per_tensor( ) float_input_tensor = dequantize_per_tensor( - input_tensor, X_scale, X_zero_point, -128, 127, torch.float32 + input_tensor, X_scale, X_zero_point, -128, 127, input_tensor.dtype ) + assert isinstance(float_input_tensor, torch.Tensor) out = torch.nn.functional.layer_norm( float_input_tensor, normalized_shape, weight, bias, eps=eps ) return quantize_per_tensor( out, - 1 / output_scale, + output_scale, output_zero_point, torch.iinfo(input_tensor.dtype).min, torch.iinfo(input_tensor.dtype).max, @@ -410,6 +736,31 @@ def quantized_layer_norm_per_tensor( ) +@impl_tracked(m, "quantized_layer_norm") +def quantized_layer_norm( + input_tensor: torch.Tensor, + X_scale: torch.Tensor, + X_zero_point: torch.Tensor, + normalized_shape: list[int], + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + output_scale: float, + output_zero_point: int, +) -> torch.Tensor: + return quantized_layer_norm_per_tensor( + input_tensor, + float(X_scale.item()), + int(X_zero_point.item()), + normalized_shape, + weight, + bias, + eps, + output_scale, + output_zero_point, + ) + + def quantized_conv_per_tensor( input_tensor: torch.Tensor, weight: torch.Tensor, @@ -450,9 +801,9 @@ def quantized_conv_per_tensor( (input_tensor - in_zero_point).float(), (weight - weight_zero_point).float(), (bias * bias_scale).float(), - stride[1], - padding[1], - dilation[1], + stride[-1], + padding[-1], + dilation[-1], groups, ) @@ -471,7 +822,7 @@ def quantized_conv_per_tensor( return quantize_per_tensor( float_out, - 1.0 / output_scale, + output_scale, output_zero_point, torch.iinfo(input_tensor.dtype).min, 
torch.iinfo(input_tensor.dtype).max, @@ -479,8 +830,8 @@ def quantized_conv_per_tensor( ) -@impl(m, "quantized_conv_nchw.per_tensor") -def quantized_conv_nchw_per_tensor( +@impl_tracked(m, "quantized_conv2d_nchw.per_tensor") +def quantized_conv2d_nchw_per_tensor( input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, @@ -535,8 +886,8 @@ def quantized_conv_nchw_per_tensor( ) -@impl(m, "quantized_conv_nhwc.per_tensor") -def quantized_conv_nhwc_per_tensor( +@impl_tracked(m, "quantized_conv2d_nchw") +def quantized_conv2d_nchw( input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, @@ -545,48 +896,14 @@ def quantized_conv_nhwc_per_tensor( dilation: tuple[int, int], groups: int, in_zero_point: int, - weight_zero_point: int, - bias_scale: float, + weight_zero_point: torch.Tensor, + bias_scale: torch.Tensor, output_scale: float, output_zero_point: int, - out_multiplier: int, - out_shift: int, + out_multiplier: torch.Tensor, + out_shift: torch.Tensor, ) -> torch.Tensor: - """ - Quantized convolution operation. - - Args: - - input_tensor (Tensor): The activations tensor - - weight (Tensor): The weight tensor - - bias (Tensor): The bias tensor - - stride (Tuple[int]): The stride of the convolution - - padding (Tuple[int]): The padding of the convolution - - dilation (Tuple[int]): The dilation of the convolution - - groups (int): The number of groups - - in_zero_point (int): The quantized mapping of zero for the input - - weight_zero_point (int): The quantized mapping of zero for the weight - - bias_scale (float): The quantized bias scale - - output_scale (float): The scale of the output - - output_zero_point (int): The zero point of the output - - out_multiplier (int): Unused - - out_shift (int): Unused - """ - - # Convert to NCHW format to reuse the existing implementation - conv_is_1d = False - if len(input_tensor.shape) == 3: - conv_is_1d = True - input_tensor = input_tensor.movedim(-1, 1).contiguous() - if len(weight.shape) != 3: - raise ValueError("Weight tensor must be 3D if input is 3D") - weight = weight.movedim(-1, 1).contiguous() - else: - input_tensor = input_tensor.movedim(-1, -3) - if len(weight.shape) != 4: - raise ValueError("Weight tensor must be 4D if input is nd > 3") - weight = torch.permute(weight, (0, -1, 1, 2)).contiguous() - - nchw_out = quantized_conv_per_tensor( + return quantized_conv2d_nchw_per_tensor( input_tensor, weight, bias, @@ -595,8 +912,218 @@ def quantized_conv_nhwc_per_tensor( dilation, groups, in_zero_point, - weight_zero_point, - bias_scale, + int(weight_zero_point.item()), + float(bias_scale.item()), + output_scale, + output_zero_point, + int(out_multiplier.item()), + int(out_shift.item()), + ) + + +@impl_tracked(m, "quantized_w8a32_conv") +def quantized_w8a32_conv( + src: torch.Tensor, + weight: torch.Tensor, + w_scale: float, + bias: torch.Tensor, + b_scale: float, +) -> torch.Tensor: + + if len(weight.shape) != 3: + raise ValueError("Weight tensor must be 3D") + + kernel_size, out_channels, in_channels = weight.shape + if kernel_size != 3: + raise ValueError("Kernel size must be 3") + if (out_channels % 4) != 0: + raise ValueError("Out channels must be a multiple of 4") + if (in_channels % 4) != 0: + raise ValueError("In channels must be a multiple of 4") + + assert weight.dtype == torch.int8 + assert bias.dtype == torch.int8 + + # To make compliant with torch (LCN -> NCL format) + weight = weight.permute(1, 2, 0).contiguous() + + # channels last to channels first + src = src.permute(0, 2, 1).contiguous() + + dequant_weight = 
weight.float() * w_scale + + # Dequantize bias using scale + dequant_bias = bias.float() * b_scale + + # Perform 1D convolution + # src: [batch, in_channel, in_length] + # weight: [out_ch, in_ch, kernel_dim] + # bias: [out_ch] + output = torch.nn.functional.conv1d( + src.float(), + dequant_weight, + dequant_bias, + ) + + return output + + +@impl_tracked(m, "quantized_w8a32_linear") +def quantized_w8a32_linear( + src: torch.Tensor, + weight: torch.Tensor, + w_scale: float, + bias: torch.Tensor, + b_scale: float, +) -> torch.Tensor: + # src comes in shape [leading_dims, in_dim] + # weight comes in shape [in_dim, out_dim] + # output comes in empty with shape [leading_dims, out_dim] + assert weight.dtype == torch.int8 + assert bias.dtype == torch.int8 + if len(src.shape) >= 2: + assert src.shape[-2] == 1, "Only supporting vector-matrix multiplication" + + # need to transpose to make compliant with torch linear (in, out -> out, in) + weight = weight.transpose(1, 0).contiguous() + dequant_weight = weight.float() * w_scale + dequant_bias = bias.float() * b_scale + + output = torch.nn.functional.linear( + src.float(), + dequant_weight, + dequant_bias, + ) + + return output + + +@impl_tracked(m, "quantized_w8a32_gru") +def quantized_w8a32_gru( + inputs: torch.Tensor, + hidden: torch.Tensor, + weights_inputs: torch.Tensor, + w_i_scale: float, + weights_hidden: torch.Tensor, + w_h_scale: float, + bias_inputs: torch.Tensor, + b_i_scale: float, + bias_hidden: torch.Tensor, + b_h_scale: float, +) -> torch.Tensor: + assert weights_inputs.dtype == torch.int8 + assert weights_hidden.dtype == torch.int8 + assert bias_inputs.dtype == torch.int8 + assert bias_hidden.dtype == torch.int8 + assert inputs.dtype == torch.float32 + assert hidden.dtype == torch.float32 + + if len(hidden.shape) > 2: + raise ValueError("Hidden state must be 2D or 1D") + + if len(hidden.shape) == 2 and hidden.shape[0] != 1: + raise ValueError("Leading dimension of hidden state must be 1") + + original_hidden_shape = hidden.shape + hidden = hidden.view(-1) + + hidden_dim = hidden.shape[0] + if (hidden_dim % 4) != 0: + raise ValueError( + "Hidden dimension must be a multiple of 4 for HiFi SIMD operations" + ) + + dequant_weights_inputs = weights_inputs.float() * w_i_scale + dequant_weights_hidden = weights_hidden.float() * w_h_scale + + # C++ implementation averages the two bias scales + avg_bias_scale = (b_i_scale + b_h_scale) / 2 + dequant_bias_inputs = bias_inputs.float() * avg_bias_scale + dequant_bias_hidden = bias_hidden.float() * avg_bias_scale + + gi = F.linear(inputs, dequant_weights_inputs, dequant_bias_inputs) + gh = F.linear(hidden, dequant_weights_hidden, dequant_bias_hidden) + + i_r, i_z, i_n = gi.chunk(3, -1) + h_r, h_z, h_n = gh.chunk(3, -1) + + reset_gate = torch.sigmoid(i_r + h_r) + update_gate = torch.sigmoid(i_z + h_z) + new_gate = torch.tanh(i_n + reset_gate * h_n) + + new_hidden = (1 - update_gate) * new_gate + update_gate * hidden + + if new_hidden.shape[0] != 1: + raise ValueError("Leading dimension of hidden state must be 1") + + assert new_hidden.shape == original_hidden_shape + + new_hidden = new_hidden.view(-1) + return torch.stack([new_hidden, new_hidden], dim=0) + + +@impl_tracked(m, "quantized_conv2d_nhwc.per_tensor") +def quantized_conv2d_nhwc_per_tensor( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: tuple[int, int], + padding: tuple[int, int], + dilation: tuple[int, int], + groups: int, + in_zero_point: int, + weight_zero_point: int, + bias_scale: float, + 
output_scale: float, + output_zero_point: int, + out_multiplier: int, + out_shift: int, +) -> torch.Tensor: + """ + Quantized convolution operation. + + Args: + - input_tensor (Tensor): The activations tensor + - weight (Tensor): The weight tensor + - bias (Tensor): The bias tensor + - stride (Tuple[int]): The stride of the convolution + - padding (Tuple[int]): The padding of the convolution + - dilation (Tuple[int]): The dilation of the convolution + - groups (int): The number of groups + - in_zero_point (int): The quantized mapping of zero for the input + - weight_zero_point (int): The quantized mapping of zero for the weight + - bias_scale (float): The quantized bias scale + - output_scale (float): The scale of the output + - output_zero_point (int): The zero point of the output + - out_multiplier (int): Unused + - out_shift (int): Unused + """ + + # Convert to NCHW format to reuse the existing implementation + conv_is_1d = False + if len(input_tensor.shape) == 3: + conv_is_1d = True + input_tensor = input_tensor.movedim(-1, 1).contiguous() + if len(weight.shape) != 3: + raise ValueError("Weight tensor must be 3D if input is 3D") + weight = weight.movedim(-1, 1).contiguous() + else: + input_tensor = input_tensor.movedim(-1, -3) + if len(weight.shape) != 4: + raise ValueError("Weight tensor must be 4D if input is nd > 3") + weight = torch.permute(weight, (0, -1, 1, 2)).contiguous() + + nchw_out = quantized_conv_per_tensor( + input_tensor, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, output_scale, output_zero_point, out_multiplier, @@ -609,10 +1136,46 @@ def quantized_conv_nhwc_per_tensor( return nchw_out.movedim(-3, -1).contiguous() +@impl_tracked(m, "quantized_conv2d_nhwc") +def quantized_conv2d_nhwc( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: tuple[int, int], + padding: tuple[int, int], + dilation: tuple[int, int], + groups: int, + in_zero_point: int, + weight_zero_point: torch.Tensor, + bias_scale: torch.Tensor, + output_scale: float, + output_zero_point: int, + out_multiplier: torch.Tensor, + out_shift: torch.Tensor, +) -> torch.Tensor: + return quantized_conv2d_nhwc_per_tensor( + input_tensor, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + int(weight_zero_point.item()), + float(bias_scale.item()), + output_scale, + output_zero_point, + int(out_multiplier.item()), + int(out_shift.item()), + ) + + def quantized_conv_variant( layout: str, input_dtype: torch.dtype, weight_dtype: torch.dtype, + is_1d: bool = False, ) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]: """Create a quantized conv variant with type checking.""" @@ -644,10 +1207,18 @@ def variant( bias.dtype == torch.int32 ), f"Expected bias dtype int32, got {bias.dtype}" + if is_1d: + assert ( + len(input_tensor.shape) == 3 + ), f"1D convolution requires 3D input tensor, got {len(input_tensor.shape)}D" + assert ( + len(weight.shape) == 3 + ), f"1D convolution requires 3D weight tensor, got {len(weight.shape)}D" + # Call the appropriate base function match layout: case "nchw": - return quantized_conv_nchw_per_tensor( + return quantized_conv2d_nchw_per_tensor( input_tensor, weight, bias, @@ -664,7 +1235,7 @@ def variant( out_shift, ) case "nhwc": - return quantized_conv_nhwc_per_tensor( + return quantized_conv2d_nhwc_per_tensor( input_tensor, weight, bias, @@ -688,64 +1259,268 @@ def variant( return decorator -@impl(m, 
"quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor") +@impl_tracked(m, "quantized_conv2d_nchw_asym8sxsym8s_asym8s.per_tensor") @quantized_conv_variant("nchw", torch.int8, torch.int8) -def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... +def quantized_conv2d_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor") +@impl_tracked(m, "quantized_conv2d_nchw_asym8uxsym8u_asym8u.per_tensor") @quantized_conv_variant("nchw", torch.uint8, torch.uint8) -def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... +def quantized_conv2d_nchw_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor") +@impl_tracked(m, "quantized_conv2d_nhwc_asym8sxsym8s_asym8s.per_tensor") @quantized_conv_variant("nhwc", torch.int8, torch.int8) -def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... +def quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor") +@impl_tracked(m, "quantized_conv2d_nhwc_asym8uxsym8u_asym8u.per_tensor") @quantized_conv_variant("nhwc", torch.uint8, torch.uint8) -def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... +def quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor") +@impl_tracked(m, "quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s.per_tensor") @quantized_conv_variant("nchw", torch.int8, torch.int8) -def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... +def quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor") +@impl_tracked(m, "quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u.per_tensor") @quantized_conv_variant("nchw", torch.uint8, torch.uint8) -def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... +def quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor") +@impl_tracked(m, "quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor") @quantized_conv_variant("nhwc", torch.int8, torch.int8) -def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... +def quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor") +@impl_tracked(m, "quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor") @quantized_conv_variant("nhwc", torch.uint8, torch.uint8) -def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... +def quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor") +@impl_tracked(m, "quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor") @quantized_conv_variant("nchw", torch.int8, torch.int8) -def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... +def quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor() -> ( + torch.Tensor +): ... 
-@impl(m, "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor") +@impl_tracked(m, "quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor") @quantized_conv_variant("nchw", torch.uint8, torch.uint8) -def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... +def quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor() -> ( + torch.Tensor +): ... -@impl(m, "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor") +@impl_tracked(m, "quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor") @quantized_conv_variant("nhwc", torch.int8, torch.int8) -def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... +def quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> ( + torch.Tensor +): ... -@impl(m, "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor") +@impl_tracked(m, "quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor") @quantized_conv_variant("nhwc", torch.uint8, torch.uint8) -def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... +def quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> ( + torch.Tensor +): ... + + +@impl_tracked(m, "quantized_conv1d_ncl_asym8sxsym8s_asym8s.per_tensor") +@quantized_conv_variant("nchw", torch.int8, torch.int8, is_1d=True) +def quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... + + +@impl_tracked(m, "quantized_conv1d_ncl_asym8uxsym8u_asym8u.per_tensor") +@quantized_conv_variant("nchw", torch.uint8, torch.uint8, is_1d=True) +def quantized_conv1d_ncl_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... + + +@impl_tracked(m, "quantized_conv1d_nlc_asym8sxsym8s_asym8s.per_tensor") +@quantized_conv_variant("nhwc", torch.int8, torch.int8, is_1d=True) +def quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... + + +@impl_tracked(m, "quantized_conv1d_nlc_asym8uxsym8u_asym8u.per_tensor") +@quantized_conv_variant("nhwc", torch.uint8, torch.uint8, is_1d=True) +def quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... 
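The *.per_tensor registrations above all lean on the quantized_conv_variant factory: each decorated stub has an empty body, and the decorator returns a variant closure that runs the dtype (and, for is_1d=True, rank) checks and then dispatches to the shared quantized_conv2d_nchw_per_tensor / quantized_conv2d_nhwc_per_tensor base implementations. Below is a minimal, self-contained sketch of that registration pattern, for reference only and not part of the patch; conv_variant, variant_asym8s, and the float matmul stand-in are illustrative names, not Cadence APIs.

from typing import Callable

import torch


def conv_variant(
    input_dtype: torch.dtype, weight_dtype: torch.dtype
) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
    def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
        def variant(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
            # The dtype checks run here, on behalf of the empty stub below.
            assert x.dtype == input_dtype, f"expected {input_dtype}, got {x.dtype}"
            assert w.dtype == weight_dtype, f"expected {weight_dtype}, got {w.dtype}"
            # A float matmul stands in for the shared quantized conv kernel.
            return (x.float() @ w.float().T).to(torch.int32)

        return variant

    return decorator


@conv_variant(torch.int8, torch.int8)
def variant_asym8s() -> torch.Tensor: ...  # stub body is intentionally empty


out = variant_asym8s(
    torch.ones(2, 4, dtype=torch.int8), torch.ones(3, 4, dtype=torch.int8)
)
assert out.shape == (2, 3) and out.dtype == torch.int32

The 1D ncl/nlc variants registered just above reuse the same 2D path by shuffling layouts with Tensor.movedim, mirroring the NHWC wrapper earlier in this file. A hedged sketch of that layout round trip follows; all tensor names and sizes are arbitrary and only illustrate the dimension moves.

import torch

x_nlc = torch.randn(1, 10, 4)  # (batch, length, channels), i.e. NLC
w = torch.randn(6, 4, 3)       # (out_channels, in_channels, kernel_length)

x_ncl = x_nlc.movedim(-1, 1).contiguous()     # NLC -> NCL
y_ncl = torch.nn.functional.conv1d(x_ncl, w)  # channels-first convolution
y_nlc = y_ncl.movedim(1, -1).contiguous()     # NCL -> NLC

assert x_ncl.shape == (1, 4, 10)
assert y_nlc.shape == (1, 8, 6)  # length 10 - kernel 3 + 1 = 8; 6 out channels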
+ + +@impl_tracked(m, "conv1d") +def conv1d( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: tuple[int], + padding: tuple[int], + dilation: tuple[int], + groups: int, +) -> torch.Tensor: + conv_out = torch.nn.functional.conv1d( + input_tensor, weight, bias, stride[0], padding[0], dilation[0], groups + ) + + return conv_out + + +@impl_tracked(m, "conv2d") +def conv2d( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: tuple[int, int], + padding: tuple[int, int], + dilation: tuple[int, int], + groups: int, +) -> torch.Tensor: + conv_out = torch.nn.functional.conv2d( + input_tensor, weight, bias, stride, padding, dilation, groups + ) + + return conv_out + + +@impl_tracked(m, "conv3d") +def conv3d( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: tuple[int, int, int], + padding: tuple[int, int, int], + dilation: tuple[int, int, int], + groups: int, +) -> torch.Tensor: + conv_out = torch.nn.functional.conv3d( + input_tensor, weight, bias, stride, padding, dilation, groups + ) + + return conv_out + + +@impl_tracked(m, "transposed_convolution") +def transposed_convolution( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: tuple[int, int], + padding: tuple[int, int], + dilation: tuple[int, int], + output_padding: tuple[int, int], + groups: int, + channel_last: bool = False, +) -> torch.Tensor: + + # Cadence transposed conv receives weights that have been transformed by the pass: + # 1. Transposed (dims 0 and 1 swapped): [out_channels, in_channels, *kernel] + # 2. Flipped (spatial dimensions reversed) + # We need to reverse both transformations to call PyTorch's conv_transpose + + conv_is_1d = len(input_tensor.shape) == 3 + + # Determine flip dimensions based on weight dimensionality + weight_dim = len(weight.shape) + flip_dims = [-1] if weight_dim == 3 else [-1, -2] + + # Reverse transformation step 1: Unflip the spatial dimensions + weight = torch.flip(weight, dims=flip_dims) + + # Reverse transformation step 2: Transpose back to PyTorch format [in, out, *kernel] + weight = weight.transpose(0, 1).contiguous() + if channel_last: + if conv_is_1d: + input_tensor = input_tensor.movedim(-1, 1).contiguous() + if len(weight.shape) != 3: + raise ValueError("Weight tensor must be 3D if input is 3D") + weight = weight.movedim(-1, 1).contiguous() + else: + input_tensor = input_tensor.movedim(-1, -3) + if len(weight.shape) != 4: + raise ValueError("Weight tensor must be 4D if input is nd > 3") + weight = torch.permute(weight, (0, -1, 1, 2)).contiguous() + + _stride: tuple[int, int] | int = stride + _padding: tuple[int, int] | int = padding + _dilation: tuple[int, int] | int = dilation + _output_padding: tuple[int, int] | int = output_padding + if conv_is_1d: + conv = torch.nn.functional.conv_transpose1d + _stride = stride[0] + _padding = padding[0] + _dilation = dilation[0] + _output_padding = output_padding[0] + else: + conv = torch.nn.functional.conv_transpose2d + + conv_out = conv( + input_tensor, + weight, + bias, + _stride, + _padding, + _output_padding, + groups, + _dilation, + ) + if channel_last: + if conv_is_1d: + conv_out = conv_out.movedim(1, -1).contiguous() + else: + conv_out = conv_out.movedim(-3, -1).contiguous() + + return conv_out + + +@impl_tracked(m, "avg_pool2d") +def avg_pool2d( + input_tensor: torch.Tensor, + kernel_size: tuple[int, int], + stride: tuple[int, int], + padding: tuple[int, int], + ceil_mode: bool = False, + count_include_pad: 
bool = False, + divisor_override: int | None = None, + in_zero_point: torch.Tensor | None = None, + channel_last: bool = False, +) -> torch.Tensor: + if channel_last: + raise NotImplementedError("Channel last is not yet supported for avg_pool2d") + + in_dtype = input_tensor.dtype + pad_h, pad_w = padding + if in_zero_point is not None: + # Avg pool2d does not allow non-0 padding, + # so we manually pad the input + pad_value = in_zero_point.item() + if not count_include_pad: + # To simulate this, just pad with 0s + pad_value = 0 + + input_tensor = torch.nn.functional.pad( + input_tensor, + (pad_w, pad_w, pad_h, pad_h), + mode="constant", + value=pad_value, + ).float() + + padding = (0, 0) + + out = torch.nn.functional.avg_pool2d( + input_tensor, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + if in_zero_point is not None: + min_val = torch.iinfo(in_dtype).min + max_val = torch.iinfo(in_dtype).max + out = torch.clamp(torch.round(out), min_val, max_val) + + return out.to(in_dtype) def quantized_relu_common( @@ -769,9 +1544,11 @@ def quantized_relu_common( if X.dtype not in supported_dtypes: raise ValueError(f"X dtype must be one of {supported_dtypes}. Got {X.dtype}") - out_scale = -out_multiplier * (1 / (1 << 31)) * (2**out_shift) - dequantized_X = torch.where(X > X_zero_point, X - X_zero_point, torch.zeros_like(X)) - return quantize_per_tensor( + out_scale = 1.0 / (-out_multiplier * (1 / (1 << 31)) * (2**out_shift)) + dequantized_X = torch.where( + X > X_zero_point, X - X_zero_point, torch.zeros_like(X) + ).to(torch.float32) + out = quantize_per_tensor( dequantized_X, out_scale, out_zero_point, @@ -779,10 +1556,11 @@ def quantized_relu_common( torch.iinfo(X.dtype).max, X.dtype, ) + assert isinstance(out, torch.Tensor) + return out def quantized_relu_variant( - per_tensor: bool, dtype: torch.dtype | None = None, ) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]: """Create a quantized relu variant with type checking.""" @@ -790,43 +1568,20 @@ def quantized_relu_variant( def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: def variant( X: torch.Tensor, - X_zero_point: torch.Tensor | int, + X_zero_point: int, out_zero_point: int, - out_multiplier: torch.Tensor | int, - out_shift: torch.Tensor | int, + out_multiplier: int, + out_shift: int, ) -> torch.Tensor: - if per_tensor: - if dtype and X.dtype != dtype: - raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}") - - assert isinstance(out_shift, int) - assert isinstance(out_multiplier, int) - _out_shift = out_shift - _out_multiplier = out_multiplier - else: - assert isinstance(out_multiplier, torch.Tensor) - if out_multiplier.numel() > 1: - raise ValueError("Only scalar out_multiplier is supported") - - assert isinstance(out_shift, torch.Tensor) - if out_shift.numel() > 1: - raise ValueError("Only scalar out_shift is supported") - - assert isinstance(X_zero_point, torch.Tensor) - if X_zero_point.shape != X.shape: - raise ValueError( - f"X_zero_point shape must be {X.shape}. Got {X_zero_point.shape}" - ) - - _out_multiplier = int(out_multiplier.item()) - _out_shift = int(out_shift.item()) + if dtype and X.dtype != dtype: + raise ValueError(f"X dtype must be {dtype}. 
Got {X.dtype}") return quantized_relu_common( X, X_zero_point, out_zero_point, - _out_multiplier, - _out_shift, + out_multiplier, + out_shift, ) return variant @@ -834,33 +1589,41 @@ def variant( return decorator -@impl(m, "quantized_relu") -@quantized_relu_variant(False) -def quantized_relu() -> torch.Tensor: ... - - -@impl(m, "quantized_relu.per_tensor") -@quantized_relu_variant(True) +@impl_tracked(m, "quantized_relu.per_tensor") +@quantized_relu_variant() def quantized_relu_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_relu_asym8s_asym8s.per_tensor") -@quantized_relu_variant(True, torch.int8) +@impl_tracked(m, "quantized_relu_asym8s_asym8s.per_tensor") +@quantized_relu_variant(torch.int8) def quantized_relu_asym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_relu_asym8u_asym8u.per_tensor") -@quantized_relu_variant(True, torch.uint8) +@impl_tracked(m, "quantized_relu_asym8u_asym8u.per_tensor") +@quantized_relu_variant(torch.uint8) def quantized_relu_asym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl(m, "requantize") -def requantize( +@impl_tracked(m, "quantized_relu") +def quantized_relu( + X: torch.Tensor, + X_zero_point: torch.Tensor, + out_zero_point: int, + out_multiplier: torch.Tensor, + out_shift: torch.Tensor, +) -> torch.Tensor: + return quantized_relu_per_tensor( + X, X_zero_point.item(), out_zero_point, out_multiplier.item(), out_shift.item() + ) + + +@impl_tracked(m, "requantize.per_tensor") +def requantize_per_tensor( input: torch.Tensor, - in_scale: torch.Tensor, - in_zero_point: torch.Tensor, - out_scale: torch.Tensor, - out_zero_point: torch.Tensor, + in_scale: float, + in_zero_point: int, + out_scale: float, + out_zero_point: int, dtype: ScalarType, ) -> torch.Tensor: if dtype in qdtype_map: @@ -869,11 +1632,6 @@ def requantize( torch.dequantize(input), out_scale, out_zero_point, qdtype_map[dtype] ) - # For in_scale or out_scale other than scalar, it requires quant/dequant - # per channel, but the channel dimension value is missing - if in_scale.numel() > 1 or out_scale.numel() > 1: - raise NotImplementedError("Only scalar scales are supported") - quant_min = torch.iinfo(input.dtype).min quant_max = torch.iinfo(input.dtype).max # pyre-fixme[6]: This dtype is actually the right one. 
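For reference, requantize.per_tensor as reworked above is a dequantize-then-quantize round trip through the two sets of quantization parameters. The following is a minimal sketch of that arithmetic, assuming int8 data and using plain tensor math in place of the quantized_decomposed ops; requantize_ref is a made-up name for this example.

import torch


def requantize_ref(
    x: torch.Tensor, in_scale: float, in_zp: int, out_scale: float, out_zp: int
) -> torch.Tensor:
    fp = (x.to(torch.float32) - in_zp) * in_scale  # dequantize step
    q = torch.round(fp / out_scale) + out_zp       # quantize step
    info = torch.iinfo(torch.int8)
    return q.clamp(info.min, info.max).to(torch.int8)


x = torch.randint(-128, 127, (4,), dtype=torch.int8)
# With identical input and output qparams the op is a no-op, which is the case
# RemoveNopRequantizeOpPass (further down in this diff) eliminates.
assert torch.equal(requantize_ref(x, 0.05, 3, 0.05, 3), x)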
@@ -883,15 +1641,568 @@ def requantize( return torch.ops.quantized_decomposed.quantize_per_tensor( torch.ops.quantized_decomposed.dequantize_per_tensor( input, - in_scale.flatten()[0], - in_zero_point.flatten()[0], + in_scale, + in_zero_point, quant_min, quant_max, input.dtype, ), - out_scale.flatten()[0], - out_zero_point.flatten()[0], + out_scale, + out_zero_point, out_quant_min, out_quant_max, dtype, ) + + +@impl_tracked(m, "requantize") +def requantize( + input_tensor: torch.Tensor, + in_scale: torch.Tensor, + in_zero_point: torch.Tensor, + out_scale: torch.Tensor, + out_zero_point: torch.Tensor, + dtype: ScalarType, +) -> torch.Tensor: + return requantize_per_tensor( + input_tensor, + float(in_scale.item()), + int(in_zero_point.item()), + float(out_scale.item()), + int(out_zero_point.item()), + dtype, + ) + + +@impl_tracked(m, "rms_norm") +def rms_norm( + X: torch.Tensor, + normalized_shape: tuple[int], + W: torch.Tensor, + eps: float, +) -> torch.Tensor: + return W * nn.RMSNorm(list(normalized_shape), eps=eps, dtype=X.dtype)(X) + + +@impl_tracked(m, "where_Scalar") +def where_Scalar( + condition: torch.Tensor, + if_true: float, + if_false: float, +) -> torch.Tensor: + if condition.dtype != torch.bool: + raise ValueError("condition must be a bool tensor") + + return torch.where(condition, if_true, if_false) + + +@impl_tracked(m, "rope") +def rope( + input_tensor: torch.Tensor, + sin_tensor: torch.Tensor, + cos_tensor: torch.Tensor, + pos: torch.Tensor | None, +) -> torch.Tensor: + original_shape = input_tensor.shape + + if len(original_shape) not in [4, 5]: + raise ValueError( + f"Input tensor must be 4D or 5D. Got {len(original_shape)}D tensor" + ) + if original_shape[0] != 1: + raise ValueError("Input tensor must have batch size 1") + if len(original_shape) == 5: + input_tensor = input_tensor.view( + input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2], -1 + ) + + _, seq, _, hd = input_tensor.shape + + if hd % 2: + raise ValueError("Hidden dimension must be divisible by 2") + + if ( + sin_tensor.size(-1) * 2 != hd + or cos_tensor.size(-1) * 2 != hd + or sin_tensor.size(0) < seq + or cos_tensor.size(0) < seq + ): + raise ValueError( + f"sin_tensor and cos_tensor must have shape {seq}) x {hd // 2}>. Got {sin_tensor.shape} and {cos_tensor.shape}" + ) + + if pos is not None: + if pos.shape != (seq,): + raise ValueError( + f"pos must have shape {input_tensor.shape[1]}. Got {pos.shape}" + ) + sin_tensor = sin_tensor[pos] + cos_tensor = cos_tensor[pos] + + # seq x 1 x hd + sin_tensor = sin_tensor.unsqueeze(1) + cos_tensor = cos_tensor.unsqueeze(1) + + # batch x seq x num_heads x head_dim_by_two + x0, x1 = input_tensor[..., ::2], input_tensor[..., 1::2] + o0 = x0 * cos_tensor - x1 * sin_tensor + o1 = x0 * sin_tensor + x1 * cos_tensor + rotated = torch.cat([o0.view(-1, 1), o1.view(-1, 1)], dim=-1) + return rotated.view(original_shape) + + +@impl_tracked(m, "im2row") +def im2row( + input_tensor: torch.Tensor, + kernel_size: tuple[int, int], + dilation: tuple[int, int], + padding: tuple[int, int], + stride: tuple[int, int], + in_zero_point: torch.Tensor, + channel_last: bool = False, +) -> torch.Tensor: + """ + Converts an input tensor into a 2D matrix where each row is a flattened sliding window (patch) + from the input, suitable for use in convolution as a matrix multiplication (im2row). + + Args: + - input_tensor: Input tensor of shape (N, C, H, W) or (N, H, W, C) if channel_last. + - kernel_size: Size of the convolution kernel. 
+ - dilation: Dilation of the convolution kernel. + - padding: Padding to apply to the input. + - stride: Stride of the convolution. + - in_zero_point : Zero point for input quantization (broadcastable to input). + - channel_last: If True, input is in NHWC format, else NCHW. + + Returns: + - Tensor of shape (N, num_patches, patch_size) + """ + if len(input_tensor.shape) == 3: + height_dim = 1 if channel_last else 2 + input_tensor = input_tensor.unsqueeze(height_dim) + + if in_zero_point is not None: + if in_zero_point.numel() != 1 and in_zero_point.shape != ( + input_tensor.shape[0], + ): + raise ValueError( + f"Input zero point must be a scalar or broadcastable to input shape {input_tensor.shape}" + ) + if in_zero_point.dtype != torch.int32: + raise ValueError("Input zero point must be an int32 tensor") + + if channel_last: + input_tensor = input_tensor.movedim(-1, -3).contiguous() # NHWC -> NCHW + + N, C, H, W = input_tensor.shape + kH, kW = kernel_size + dH, dW = dilation + pH, pW = padding + sH, sW = stride + + # Handle padding with zero point values + if in_zero_point is not None and (pH > 0 or pW > 0): + # Expand zero point to (N, 1, 1, 1) for broadcasting + in_zero_point = in_zero_point.expand(N) + + # Pad input with the per-batch zero point values + input_tensor = torch.stack( + [ + torch.nn.functional.pad( + input_tensor[i], + (pW, pW, pH, pH), + mode="constant", + value=in_zero_point[i].item(), + ) + for i in range(len(input_tensor)) + ] + ) + + padding = (0, 0) # Already padded manually + + # Use unfold to extract sliding local blocks + # Unfold: (N, C, H, W) -> (N, C, L, kH, kW), where L = number of sliding windows + # torch.nn.functional.unfold returns (N, C*kH*kW, L) + patches = torch.nn.functional.unfold( + input_tensor.float(), # unfold not implemented for int + kernel_size=(kH, kW), + dilation=(dH, dW), + padding=padding, + stride=(sH, sW), + ).to( + input_tensor.dtype + ) # (N, C*kH*kW, L) + + # Transpose to (N, L, C*kH*kW) + patches = patches.transpose(1, 2).contiguous() + + # Reshape to (N*L, C*kH*kW) + patches = patches.view(N, -1, C * kH * kW) + + # If channel_last, output should be in NHWC patch order (but im2row is always row-major) + return patches + + +@impl_tracked(m, "im2row.per_tensor") +def im2row_per_tensor( + input_tensor: torch.Tensor, + kernel_size: tuple[int, int], + dilation: tuple[int, int], + padding: tuple[int, int], + stride: tuple[int, int], + in_zero_point: int, + channel_last: bool = False, +) -> torch.Tensor: + out = im2row( + input_tensor, + kernel_size, + dilation, + padding, + stride, + torch.tensor(in_zero_point, dtype=torch.int32), + channel_last, + ) + assert isinstance(out, torch.Tensor) + return out + + +@impl_tracked(m, "transposed_im2row") +def transposed_im2row( + input_tensor: torch.Tensor, + kernel_size: tuple[int, int], + dilation: tuple[int, int], + padding: tuple[int, int], + stride: tuple[int, int], + output_padding: tuple[int, int], + in_zero_point: torch.Tensor, + channel_last: bool = False, +) -> torch.Tensor: + """ + Converts input tensor into im2row format for transposed convolutions. + For each output position, extracts the kernel-sized patch of input values that + contribute to that position in a transposed convolution. + + Args: + - input_tensor: Input spatial tensor, NCHW or NHWC format (3D or 4D). + - kernel_size: Size of the convolution kernel (kernel_h, kernel_w). + - dilation: Dilation of the convolution kernel. + - padding: Padding to apply to the input. + - stride: Stride of the convolution. 
+ - output_padding: Additional output padding for transposed convolution. + - in_zero_point: Zero point for input quantization (broadcastable to input). + - channel_last: If True, input is in NHWC format, else NCHW. + + Returns: + - 3D tensor of shape (N, output_h * output_w, kernel_h * kernel_w * in_c) + """ + # Handle 1D convolution case by adding height dimension + if len(input_tensor.shape) == 3: + height_dim = 1 if channel_last else 2 + input_tensor = input_tensor.unsqueeze(height_dim) + + if in_zero_point is not None: + if in_zero_point.dtype != torch.int32: + raise ValueError("Input zero point must be an int32 tensor") + + # Move to NCHW for processing if needed + if channel_last: + input_tensor = input_tensor.movedim(-1, -3).contiguous() # NHWC -> NCHW + + N, C, H_in, W_in = input_tensor.shape + K_h, K_w = kernel_size + device = input_tensor.device + + # Calculate output spatial size + H_out = ( + (H_in - 1) * stride[0] + - 2 * padding[0] + + dilation[0] * (K_h - 1) + + output_padding[0] + + 1 + ) + W_out = ( + (W_in - 1) * stride[1] + - 2 * padding[1] + + dilation[1] * (K_w - 1) + + output_padding[1] + + 1 + ) + + # Create meshgrids for all output positions and kernel positions + h_out_grid = torch.arange(H_out, device=device).view( + -1, 1, 1, 1 + ) # [H_out, 1, 1, 1] + w_out_grid = torch.arange(W_out, device=device).view( + 1, -1, 1, 1 + ) # [1, W_out, 1, 1] + kh_grid = torch.arange(K_h, device=device).view(1, 1, -1, 1) # [1, 1, K_h, 1] + kw_grid = torch.arange(K_w, device=device).view(1, 1, 1, -1) # [1, 1, 1, K_w] + + # Compute input positions for all (h_out, w_out, kh, kw) combinations + # From C++ reference: h_im = _h - ((kernel_h - 1) * dilation_h) + _kh * dilation_h + pad_h + h_im = h_out_grid - (K_h - 1) * dilation[0] + kh_grid * dilation[0] + padding[0] + w_im = w_out_grid - (K_w - 1) * dilation[1] + kw_grid * dilation[1] + padding[1] + + # Check which positions are valid (divisible by stride and within bounds) + # From C++ reference: if (h_im < 0 || h_im >= stride_h * height || h_im % stride_h != 0) + h_valid = (h_im % stride[0] == 0) & (h_im >= 0) & (h_im < stride[0] * H_in) + w_valid = (w_im % stride[1] == 0) & (w_im >= 0) & (w_im < stride[1] * W_in) + valid = h_valid & w_valid # [H_out, W_out, K_h, K_w] + + # Actual input indices (h_im / stride_h from C++ reference) + h_in = h_im // stride[0] + w_in = w_im // stride[1] + + # Clamp indices to valid range (will be masked out anyway) + h_in_safe = h_in.clamp(0, H_in - 1) + w_in_safe = w_in.clamp(0, W_in - 1) + + # Initialize output patches with zero points (vectorized across batches) + # Layout depends on channel_last: NHWC uses [K_h, K_w, C], NCHW uses [C, K_h, K_w] + if channel_last: + # NHWC: patches layout [N, H_out, W_out, K_h, K_w, C] + patches = torch.zeros( + (N, H_out, W_out, K_h, K_w, C), + dtype=input_tensor.dtype, + device=device, + ) + else: + # NCHW: patches layout [N, H_out, W_out, C, K_h, K_w] + patches = torch.zeros( + (N, H_out, W_out, C, K_h, K_w), + dtype=input_tensor.dtype, + device=device, + ) + + # Initialize patches with zero points (vectorized) + if in_zero_point is not None: + if in_zero_point.numel() == 1: + # Scalar zero point - fill all patches + patches.fill_(in_zero_point.item()) + else: + # Per-batch zero points - expand and fill + # in_zero_point: [N] -> [N, 1, 1, 1, 1, 1] or [N, 1, 1, 1, 1, 1] + zp_shape = [N] + [1] * (patches.ndim - 1) + patches = patches + in_zero_point.view(*zp_shape) + + # Flatten the spatial and kernel dimensions for efficient gathering + # h_in_safe, w_in_safe: 
[H_out, W_out, K_h, K_w] (broadcast shape) + h_flat = h_in_safe.expand(H_out, W_out, K_h, K_w).contiguous().view(-1) + w_flat = w_in_safe.expand(H_out, W_out, K_h, K_w).contiguous().view(-1) + + # Vectorized gathering across all batches and channels using advanced indexing + # Create index tensors with appropriate broadcasting shapes + num_positions = h_flat.shape[0] + + # batch_indices: [N, 1, 1] -> broadcasts to [N, C, num_positions] + batch_indices = torch.arange(N, device=device).view(N, 1, 1) + + # channel_indices: [1, C, 1] -> broadcasts to [N, C, num_positions] + channel_indices = torch.arange(C, device=device).view(1, C, 1) + + # h_flat, w_flat: [1, 1, num_positions] -> broadcasts to [N, C, num_positions] + h_indices = h_flat.view(1, 1, num_positions) + w_indices = w_flat.view(1, 1, num_positions) + + # Advanced indexing gathers all values at once: [N, C, num_positions] + gathered = input_tensor[batch_indices, channel_indices, h_indices, w_indices] + + # Reshape based on channel_last flag + if channel_last: + # NHWC: Reshape to [N, H_out, W_out, K_h, K_w, C] + # gathered: [N, C, H_out*W_out*K_h*K_w] -> [N, H_out*W_out*K_h*K_w, C] -> [N, H_out, W_out, K_h, K_w, C] + gathered = gathered.transpose(1, 2).contiguous() # [N, num_positions, C] + gathered = gathered.view(N, H_out, W_out, K_h, K_w, C) + else: + # NCHW: Reshape to [N, H_out, W_out, C, K_h, K_w] + # gathered: [N, C, H_out*W_out*K_h*K_w] -> [N, C, H_out, W_out, K_h, K_w] -> [N, H_out, W_out, C, K_h, K_w] + gathered = gathered.view(N, C, H_out, W_out, K_h, K_w) + gathered = gathered.permute(0, 2, 3, 1, 4, 5).contiguous() + + # Apply validity mask (vectorized across batches) + # valid: [H_out, W_out, K_h, K_w] -> expand to match gathered shape + if channel_last: + # gathered: [N, H_out, W_out, K_h, K_w, C] + valid_exp = valid.unsqueeze(0).unsqueeze(-1) # [1, H_out, W_out, K_h, K_w, 1] + else: + # gathered: [N, H_out, W_out, C, K_h, K_w] + valid_exp = valid.unsqueeze(0).unsqueeze(3) # [1, H_out, W_out, 1, K_h, K_w] + + patches = torch.where(valid_exp, gathered, patches) + + # Reshape to final output format: [N, H_out * W_out, K_h * K_w * C] + # The reshaping will preserve the correct dimension ordering + if channel_last: + # patches: [N, H_out, W_out, K_h, K_w, C] -> [N, H_out*W_out, K_h*K_w*C] + patches = patches.view(N, H_out * W_out, K_h * K_w * C) + else: + # patches: [N, H_out, W_out, C, K_h, K_w] -> [N, H_out*W_out, C*K_h*K_w] + patches = patches.view(N, H_out * W_out, C * K_h * K_w) + + return patches + + +@impl_tracked(m, "quantized_embedding_byte") +def quantized_embedding_byte( + weight: torch.Tensor, + weight_scales: torch.Tensor, + weight_zero_points: torch.Tensor | None, + indices: torch.Tensor, + pruned_weights: bool = False, +) -> torch.Tensor: + if pruned_weights: + raise NotImplementedError("Pruned weights not supported") + + # Cannot use torch.ops.quantized_decomposed.embedding_byte.dtype because + # it doesn't support num_groups == 1 + num_groups = 1 + if len(weight_scales.shape) == 2: + num_groups = weight_scales.shape[1] + + group_size = weight.shape[1] // num_groups + weight = torch.ops.torchao.dequantize_affine.default( + input=weight, + block_size=(1, group_size), + scale=weight_scales, + zero_point=weight_zero_points, + input_dtype=weight.dtype, + quant_min=torch.iinfo(weight.dtype).min, + quant_max=torch.iinfo(weight.dtype).max, + ) + + return weight[indices] + + +@impl_tracked(m, "idma_copy") +def idma_copy(src: torch.Tensor, task_num: int = 0, channel: int = 0) -> torch.Tensor: + return src.clone() + + 
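The im2row and transposed_im2row ops above exist so that convolutions (and their transposed counterparts) can be lowered to matrix multiplies against the flattened weights. Here is a small sketch of the underlying identity, shown with torch.nn.functional.unfold directly; the im2row op above produces the same patches, transposed to shape (N, num_patches, C*kH*kW). The sizes below are arbitrary and purely illustrative.

import torch

x = torch.randn(1, 3, 8, 8)  # NCHW activation
w = torch.randn(5, 3, 3, 3)  # (out_channels, in_channels, kH, kW)

patches = torch.nn.functional.unfold(x, kernel_size=(3, 3))  # (1, 3*3*3, 6*6)
as_matmul = (w.view(5, -1) @ patches).view(1, 5, 6, 6)

assert torch.allclose(as_matmul, torch.nn.functional.conv2d(x, w), atol=1e-4)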
+@impl_tracked(m, "idma_store") +def idma_store(src: torch.Tensor, task_num: int = 0, channel: int = 0) -> torch.Tensor: + return src.clone() + + +@impl_tracked(m, "idma_load") +def idma_load(src: torch.Tensor, task_num: int = 0, channel: int = 0) -> torch.Tensor: + return src.clone() + + +@impl_tracked(m, "idma_wait") +def idma_wait(src: torch.Tensor, task_num: int = 0, channel: int = 0) -> torch.Tensor: + return src.clone() + + +@impl_tracked(m, "linalg_svd") +def linalg_svd( + A: torch.Tensor, + full_matrices: bool = False, + compute_uv: bool = True, + driver: str | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert compute_uv + U, S, Vh = torch.linalg.svd(A, full_matrices=full_matrices, driver=driver) + return U.contiguous(), S.contiguous(), Vh.contiguous() + + +@impl_tracked(m, "_softmax_f32_f32") +def softmax_f32_f32( + input_tensor: torch.Tensor, + dim: int, + half_to_float: bool | None = None, +) -> torch.Tensor: + assert input_tensor.dtype == torch.float32, "input_tensor must be float32" + assert not half_to_float, "half_to_float is not supported" + return torch.nn.functional.softmax(input_tensor, dim=dim, dtype=torch.float32) + + +def quantized_softmax_per_tensor_common( + input_tensor: torch.Tensor, + mask: torch.Tensor | None, + dim: int, + in_scale: float, + in_zero_point: int, + out_scale: float, + out_zero_point: int, +) -> torch.Tensor: + """ + Quantized softmax operation. + + Args: + - input_tensor (Tensor): The quantized input tensor + - mask (Tensor): Mask tensor + - dim (int): The dimension along which softmax is computed + - in_scale (float): The scale of the input quantization + - in_zero_point (int): The zero point of the input quantization + - out_scale (float): The scale of the output quantization + - out_zero_point (int): The zero point of the output quantization + """ + # TODO: T228751479 - Add support for mask parameter in softmax + assert mask is None + supported_dtypes = [torch.int8, torch.uint8, torch.int16] + if input_tensor.dtype not in supported_dtypes: + raise ValueError( + f"Input dtype must be one of {supported_dtypes}. 
Got {input_tensor.dtype}" + ) + + float_input_tensor = dequantize_per_tensor( + input_tensor, + in_scale, + in_zero_point, + torch.iinfo(input_tensor.dtype).min, + torch.iinfo(input_tensor.dtype).max, + input_tensor.dtype, + ) + + softmax_output = torch.nn.functional.softmax(float_input_tensor, dim=dim) + + return quantize_per_tensor( + softmax_output, + out_scale, + out_zero_point, + torch.iinfo(input_tensor.dtype).min, + torch.iinfo(input_tensor.dtype).max, + input_tensor.dtype, + ) + + +@impl_tracked(m, "quantized_softmax.per_tensor") +def quantized_softmax_per_tensor( + input_tensor: torch.Tensor, + mask: torch.Tensor | None, + dim: int, + in_scale: float, + in_zero_point: int, + out_scale: float, + out_zero_point: int, +) -> torch.Tensor: + return quantized_softmax_per_tensor_common( + input_tensor, + mask, + dim, + in_scale, + in_zero_point, + out_scale, + out_zero_point, + ) + + +@impl_tracked(m, "quantized_softmax") +def quantized_softmax( + input_tensor: torch.Tensor, + mask: torch.Tensor | None, + dim: int, + in_scale: torch.Tensor, + in_zero_point: torch.Tensor, + out_scale: float, + out_zero_point: int, +) -> torch.Tensor: + return quantized_softmax_per_tensor_common( + input_tensor, + mask, + dim, + float(in_scale.item()), + int(in_zero_point.item()), + out_scale, + out_zero_point, + ) diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py index 663c5825e52..f7419ff25dc 100644 --- a/backends/cadence/aot/remove_ops.py +++ b/backends/cadence/aot/remove_ops.py @@ -6,10 +6,11 @@ # pyre-strict - -import logging from dataclasses import dataclass, field -from typing import cast, List, Optional, Sequence, Set +from typing import cast, List, Optional, Set, Type + +# Import these for the cadence function signatures. 
+import executorch.backends.cadence.aot.ops_registrations # noqa: F401 import torch import torch.fx @@ -17,6 +18,7 @@ CadencePassAttribute, get_arg, register_cadence_pass, + RemoveOrReplacePassInterface, set_arg, ) @@ -36,7 +38,7 @@ class RemoveCloneOpsTransformImported(ExportPass): def call(self, graph_module: torch.fx.GraphModule) -> PassResult: finalize_passes: List[PassType] = [ - RemoveCloneOpsTransform(), + RemoveCloneOpsTransform(eliminate_quant_dequant_pairs=False), ] result = PassManager(passes=finalize_passes)(graph_module) dead_code_elimination_pass(result.graph_module) @@ -44,19 +46,16 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class RemoveDetachCopyPass(ExportPass): - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.aten.detach_copy.default: - return super().call_operator(op, args, kwargs, meta) - - assert len(args) == 1 - return cast(ProxyValue, args[0]) +class RemoveDetachCopyPass(RemoveOrReplacePassInterface): + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.detach_copy.default] + + def maybe_remove_or_replace(self, node: Node) -> bool: + input_node = node.args[0] + assert isinstance(input_node, Node) + node.replace_all_uses_with(input_node) + return True # The following class consolidates passes to remove ops that are redundant: @@ -69,184 +68,181 @@ class RemoveRedundantOps: @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class RemoveZeroSizedCatArgsPass(ExportPass): - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.aten.cat.default: - return super().call_operator(op, args, kwargs, meta) - - # Remove any zero-sized tensor arg to form a new args list. - cat_inputs: list[ProxyValue] = [] - for arg in cast(Sequence[ProxyValue], args[0]): - if arg.to_tensor().numel() > 0: - cat_inputs.append(arg) - - # If all the tensors were empty, we just return an empty tensor with - # the right shape. 
+class RemoveZeroSizedCatArgsPass(RemoveOrReplacePassInterface): + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.cat.default] + + def maybe_remove_or_replace(self, node: Node) -> bool: + # Get the cat inputs (first argument is a list of tensors) + cat_inputs_arg = node.args[0] + + # Assert that cat_inputs_arg is iterable + assert isinstance( + cat_inputs_arg, (list, tuple) + ), "cat_inputs_arg must be a sequence type" + + # Filter out zero-sized tensors + cat_inputs: list[Node] = [] + for arg in cat_inputs_arg: + if isinstance(arg, Node) and arg.meta.get("val") is not None: + if arg.meta["val"].numel() > 0: + cat_inputs.append(arg) + + # If all tensors were empty, create a full op with the right shape if not cat_inputs: - empty_shape = meta["val"].shape - dtype = meta["val"].dtype - return super().call_operator( - exir_ops.edge.aten.full.default, - (tuple(empty_shape), 0), - {"dtype": dtype}, - meta, - ) + empty_shape = node.meta["val"].shape + dtype = node.meta["val"].dtype + # Create a new full node + with node.graph.inserting_before(node): + full_node = node.graph.call_function( + exir_ops.edge.aten.full.default, + args=(tuple(empty_shape), 0), + kwargs={"dtype": dtype}, + ) + full_node.meta = node.meta.copy() + node.replace_all_uses_with(full_node) + return True - # If there was only one tensor in the cat_inputs list, - # we can safely erase this cat op. + # If only one tensor remains, replace with it if len(cat_inputs) == 1: - return cat_inputs[0] + node.replace_all_uses_with(cat_inputs[0]) + return True + + # If the number of inputs changed, update the cat args + if len(cat_inputs) < len(cat_inputs_arg): + # Update the first argument with filtered inputs + new_args = list(node.args) + new_args[0] = cat_inputs + node.args = tuple(new_args) + return True - # Otherwise, we replace args[0] with cat_inputs. - new_args = list(args) - # pyre error introduced after D66937105 - new_args[0] = cat_inputs # pyre-ignore[6] - return super().call_operator(op, tuple(new_args), kwargs, meta) + # No changes needed + return False @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class RemoveNopExpandOpPass(ExportPass): +class RemoveNopExpandOpPass(RemoveOrReplacePassInterface): """ For an expand op, if the operator shape matches the expand shape, then the expand is a nop. 
""" - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if get_edge_overload_packet(op) not in { - exir_ops.edge.aten.expand_copy, - exir_ops.edge.aten.expand, - }: - return super().call_operator(op, args, kwargs, meta) - - # Parse the args, and check for nop condition - arg0 = cast(ProxyValue, args[0]) - arg1 = cast(Sequence[int], args[1]) - in_tensor = arg0.to_tensor() - if list(in_tensor.shape) == list(arg1): - return arg0 + @property + def targets(self) -> list[EdgeOpOverload]: + return [ + exir_ops.edge.aten.expand_copy.default, + exir_ops.edge.aten.expand.default, + ] - return super().call_operator(op, args, kwargs, meta) + def maybe_remove_or_replace(self, node: Node) -> bool: + input_node = node.args[0] + assert isinstance(input_node, Node) + if input_node.meta["val"].shape == node.meta["val"].shape: + node.replace_all_uses_with(input_node) + return True + return False @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class RemoveToOpsPass(ExportPass): +class RemoveToOpsPass(RemoveOrReplacePassInterface): # aten.to.* as of now are all nops - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in ( + @property + def targets(self) -> list[EdgeOpOverload]: + return [ exir_ops.edge.aten.to.dtype, exir_ops.edge.aten.to.dtype_layout, - ): - return super().call_operator(op, args, kwargs, meta) + ] - logging.debug(f"Erasing to.dtype node (target = {op})") - return cast(ProxyValue, args[0]) + def maybe_remove_or_replace(self, node: Node) -> bool: + input_node = node.args[0] + assert isinstance(input_node, Node) + node.replace_all_uses_with(input_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class RemoveZeroSizedConstantPadNd(ExportPass): - def call_operator( - self, - op, # pyre-ignore - args: tuple[ProxyValue, tuple[int, ...], Argument], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.aten.constant_pad_nd.default: - return super().call_operator(op, args, kwargs, meta) +class RemoveZeroSizedConstantPadNd(RemoveOrReplacePassInterface): + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.constant_pad_nd.default] + + def maybe_remove_or_replace(self, node: Node) -> bool: + # Get padding argument (second argument) + if len(node.args) < 2: + return False - input_tensor = args[0] - padding = args[1] + padding = node.args[1] + if not isinstance(padding, (list, tuple)): + return False + # If any padding value is non-zero, keep the node if any(x != 0 for x in padding): - return super().call_operator(op, args, kwargs, meta) + return False - logging.debug(f"Erasing 0 sized constant pad nd node with {input_tensor}") - return input_tensor + # All padding is zero, replace with input + input_node = node.args[0] + assert isinstance(input_node, Node) + node.replace_all_uses_with(input_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class RemoveNopSliceOrViewOpPass(ExportPass): +class RemoveNopSliceOrViewOpPass(RemoveOrReplacePassInterface): """ Remove slice ops that are more like views, and view ops that do not change the shape """ - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in { + @property + def targets(self) -> 
list[EdgeOpOverload]: + return [ exir_ops.edge.aten.slice_copy.Tensor, exir_ops.edge.aten.view_copy.default, - }: - return super().call_operator(op, args, kwargs, meta) + ] - arg0 = cast(ProxyValue, args[0]) - out_shape = meta["val"].shape + def maybe_remove_or_replace(self, node: Node) -> bool: + changed = False + input_node = node.args[0] + assert isinstance(input_node, Node) + if input_node.meta["val"].shape == node.meta["val"].shape: + node.replace_all_uses_with(input_node) + changed = True - # If both arg_shape and out_shape are the same, this slice is a nop - return ( - arg0 - if arg0.to_tensor().shape == out_shape - else super().call_operator(op, args, kwargs, meta) - ) + return changed @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class RemoveNopLinalgVectorNormOpPass(ExportPass): +class RemoveNopLinalgVectorNormOpPass(RemoveOrReplacePassInterface): """ If the norm is applied over a dimension that is size 1, it can be eliminated. """ - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op is not exir_ops.edge.aten.linalg_vector_norm.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.linalg_vector_norm.default] + def maybe_remove_or_replace(self, node: Node) -> bool: # If the op has three args or less, it can't be a nop - if len(args) <= 3: - return super().call_operator(op, args, kwargs, meta) + if len(node.args) <= 3: + return False # If dim is None, or keepdim is False, it is not a nop - dim = cast(Optional[tuple[int, ...]], args[2]) - keepdim = cast(bool, args[3]) + dim = cast(Optional[tuple[int, ...]], node.args[2]) + keepdim = cast(bool, node.args[3]) if dim is None or not keepdim: - return super().call_operator(op, args, kwargs, meta) + return False # If the norm has 4 args and keepdim is True, check if dim is not None # and if the dimensions in dim are size 1. If not, the norm is not a nop. - t = cast(ProxyValue, args[0]) - shape = t.to_tensor().shape - if len(args) < 4: + input_node = node.args[0] + assert isinstance(input_node, Node) + shape = input_node.meta["val"].shape + if len(node.args) < 4: for d in dim: if shape[d] != 1: - return super().call_operator(op, args, kwargs, meta) + return False - return t + node.replace_all_uses_with(input_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) @@ -361,23 +357,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class RemoveCloneOpPass(ExportPass): - # If the op is a clone op, return the input and eliminate the op - def call_operator( - self, - op, # pyre-ignore - args: tuple[ProxyValue], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.aten.clone.default: - return super().call_operator(op, args, kwargs, meta) - - return args[0] - - -@register_cadence_pass(CadencePassAttribute(opt_level=1)) -class RemoveContiguousOpPass(ExportPass): +class RemoveContiguousOpPass(RemoveOrReplacePassInterface): """ This is based on the assumption that all tensors are contiguous in ExecuTorch and after cadence passes, and we should revisit this if that assumption is no longer true. @@ -385,43 +365,37 @@ class RemoveContiguousOpPass(ExportPass): original graph module. 
""" - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.aten.contiguous.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.contiguous.default] - assert len(args) == 1 - return cast(ProxyValue, args[0]) + def maybe_remove_or_replace(self, node: Node) -> bool: + input_node = node.args[0] + assert isinstance(input_node, Node) + node.replace_all_uses_with(input_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class RemoveAliasCopyOpPass(ExportPass): +class RemoveAliasCopyOpPass(RemoveOrReplacePassInterface): """ alias_copy is a no-op and can be removed. """ - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.aten.alias_copy.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.alias_copy.default] - assert len(args) == 1 - return cast(ProxyValue, args[0]) + def maybe_remove_or_replace(self, node: Node) -> bool: + input_node = node.args[0] + assert isinstance(input_node, Node) + node.replace_all_uses_with(input_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class RemoveNopRequantizeOpPass(ExportPass): +class RemoveNopRequantizeOpPass(RemoveOrReplacePassInterface): """ For a requantize op, if the following three conditions are satisfied: 1. the in_scale matches the out_scale @@ -430,100 +404,96 @@ class RemoveNopRequantizeOpPass(ExportPass): then the requantize op is redundant, and can be eliminated """ - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.cadence.requantize.per_tensor: - return super().call_operator(op, args, kwargs, meta) - - # Parse the args - (X, in_scale, in_zero_point, out_scale, out_zero_point, out_dtype) = cast( - tuple[ProxyValue, int, float, int, float, torch.dtype], args - ) - in_dtype = X.to_tensor().dtype + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.cadence.requantize.per_tensor] + + def maybe_remove_or_replace(self, node: Node) -> bool: + input_node = node.args[0] + assert isinstance(input_node, Node) + in_scale = node.args[1] + in_zero_point = node.args[2] + out_scale = node.args[3] + out_zero_point = node.args[4] + out_dtype = node.args[5] + in_dtype = input_node.meta["val"].dtype # Check the three conditions if ( in_scale == out_scale and in_zero_point == out_zero_point and in_dtype == out_dtype ): - return cast(ProxyValue, args[0]) - - return super().call_operator(op, args, kwargs, meta) + node.replace_all_uses_with(input_node) + return True + return False @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class RemoveNopMulOpPass(ExportPass): +class RemoveNopMulOpPass(RemoveOrReplacePassInterface): """ If a mul op is multiplying two tensors with the same shape and one of those tensors is all zeros, return the zero tensor instead. 
""" - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.aten.mul.Tensor: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.mul.Tensor] - # Parse the args - (input1, input2) = cast(tuple[ProxyValue, ProxyValue], args) + def maybe_remove_or_replace(self, node: Node) -> bool: + input1 = node.args[0] + input2 = node.args[1] + assert isinstance(input1, Node) + assert isinstance(input2, Node) # Check if both inputs have the same shape - if input1.to_tensor().shape != input2.to_tensor().shape: - return super().call_operator(op, args, kwargs, meta) + if input1.meta["val"].shape != input2.meta["val"].shape: + return False # Check if one of the inputs is a zero tensor - if input1.node.target == exir_ops.edge.aten.full.default: - if input1.node.args[1] == 0: - return input1 - elif input2.node.target == exir_ops.edge.aten.full.default: - if input2.node.args[1] == 0: - return input2 + if input1.target == exir_ops.edge.aten.full.default: + if input1.args[1] == 0: + node.replace_all_uses_with(input1) + return True + elif input2.target == exir_ops.edge.aten.full.default: + if input2.args[1] == 0: + node.replace_all_uses_with(input2) + return True - return super().call_operator(op, args, kwargs, meta) + return False @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class RemoveNopAddOpPass(ExportPass): +class RemoveNopAddOpPass(RemoveOrReplacePassInterface): """ If an add op is adding two tensors with the same shape and one of those tensors is all zeros, return the other tensor instead. """ - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.aten.add.Tensor: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.add.Tensor] - # Parse the args - (input1, input2) = cast(tuple[ProxyValue, ProxyValue], args) + def maybe_remove_or_replace(self, node: Node) -> bool: + input1 = node.args[0] + input2 = node.args[1] + assert isinstance(input1, Node) + assert isinstance(input2, Node) # Check if both inputs have the same shape - if input1.to_tensor().shape != input2.to_tensor().shape: - return super().call_operator(op, args, kwargs, meta) + if input1.meta["val"].shape != input2.meta["val"].shape: + return False # Check if one of the inputs is a zero tensor - if input1.node.target == exir_ops.edge.aten.full.default: - if input1.node.args[1] == 0: - return input2 - elif input2.node.target == exir_ops.edge.aten.full.default: - if input2.node.args[1] == 0: - return input1 + if input1.target == exir_ops.edge.aten.full.default: + if input1.args[1] == 0: + node.replace_all_uses_with(input2) + return True + elif input2.target == exir_ops.edge.aten.full.default: + if input2.args[1] == 0: + node.replace_all_uses_with(input1) + return True - return super().call_operator(op, args, kwargs, meta) + return False @register_cadence_pass(CadencePassAttribute(opt_level=2)) @@ -582,13 +552,17 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: for node in subgraph.nodes: processed_nodes.add(node) + modified = False for subgraph in subgraphs_found: self.permute_subgraph(subgraph) + modified = True - graph_module.graph.eliminate_dead_code() - graph_module.recompile() + if modified: + 
graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return super().call(graph_module) - return super().call(graph_module) + return PassResult(graph_module, False) def visit( self, @@ -752,17 +726,17 @@ def get_squeeze_indices(self, view_node: Node) -> List[int]: return squeeze_indices - def handle_squeeze(self, view_node: Node, visited_view_nodes: Set[Node]) -> None: + def handle_squeeze(self, view_node: Node, visited_view_nodes: Set[Node]) -> bool: if view_node in visited_view_nodes: - return + return False squeeze_indices = self.get_squeeze_indices(view_node) if not squeeze_indices: - return + return False # Only handle simple chains for now. if len(view_node.users) != 1: - return + return False node = next(iter(view_node.users)) # Traverse down from the node until finding another view op. @@ -770,9 +744,9 @@ def handle_squeeze(self, view_node: Node, visited_view_nodes: Set[Node]) -> None while node.target != exir_ops.edge.aten.view_copy.default: # Only handle simple chains for now if len(node.users) != 1: - return + return False if node.target not in self.intermediate_ops: - return + return False if node.target == exir_ops.edge.aten.slice_copy.Tensor: intermediate_slices.append(node) node = next(iter(node.users)) @@ -795,18 +769,22 @@ def handle_squeeze(self, view_node: Node, visited_view_nodes: Set[Node]) -> None # Skip the initial view node. input_node = cast(Node, get_arg(view_node, "input")) view_node.replace_all_uses_with(input_node) + return True def call(self, graph_module: torch.fx.GraphModule) -> PassResult: visited_view_nodes = set() + modified = False for view_node in graph_module.graph.find_nodes( op="call_function", target=exir_ops.edge.aten.view_copy.default, sort=True ): - self.handle_squeeze(view_node, visited_view_nodes) + modified |= self.handle_squeeze(view_node, visited_view_nodes) - graph_module.graph.eliminate_dead_code() - graph_module.recompile() + if modified: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return super().call(graph_module) - return super().call(graph_module) + return PassResult(graph_module, False) @register_cadence_pass(CadencePassAttribute(opt_level=1)) @@ -829,23 +807,27 @@ class RemoveBranchedQuantDequant(ExportPass): } def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - self.remove_branched( + modified = self.remove_branched( graph_module, self.quantize_op_packets, self.dequantize_op_packets ) - self.remove_branched( + modified |= self.remove_branched( graph_module, self.dequantize_op_packets, self.quantize_op_packets ) - graph_module.graph.eliminate_dead_code() - result = super().call(graph_module) - return result + if modified: + graph_module.graph.eliminate_dead_code() + result = super().call(graph_module) + return result + + return PassResult(graph_module, False) def remove_branched( self, graph_module: torch.fx.GraphModule, producer_pkts: set[EdgeOpOverloadPacket], consumer_pkts: set[EdgeOpOverloadPacket], - ) -> None: + ) -> bool: + modified = False for node in graph_module.graph.nodes: if ( node.op != "call_function" @@ -869,76 +851,84 @@ def remove_branched( continue user.replace_all_uses_with(node.args[0]) + modified = True + + return modified -class RemoveCatFromSliceCopyPass(ExportPass): +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class RemoveCatFromSliceCopyPass(RemoveOrReplacePassInterface): """ Simplifies cat->slice_copy chains where one of the cat inputs can be directly passed to the slice_copy. 
""" - def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None: - for slice_copy_node in graph_module.graph.find_nodes( - op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor - ): - cat_node = cast(Node, get_arg(slice_copy_node, "input")) - slice_dim = cast(int, get_arg(slice_copy_node, "dim")) - start_idx = cast(int, get_arg(slice_copy_node, "start")) - end_idx = cast(int, get_arg(slice_copy_node, "end")) - step = cast(int, get_arg(slice_copy_node, "step")) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.slice_copy.Tensor] - if cat_node.target != exir_ops.edge.aten.cat.default or step != 1: - continue + def maybe_remove_or_replace(self, node: Node) -> bool: + cat_node = cast(Node, get_arg(node, "input")) + slice_dim = cast(int, get_arg(node, "dim")) + start_idx = cast(int, get_arg(node, "start")) + end_idx = cast(int, get_arg(node, "end")) + step = cast(int, get_arg(node, "step")) - # Make sure cat and slice happens on the same dimension. - cat_dim = cast(Node, get_arg(cat_node, "dim")) - if cat_dim != slice_dim: - continue + if cat_node.target != exir_ops.edge.aten.cat.default or step != 1: + return False - # Canonicalize slice indices. - cat_output_shape = cat_node.meta["val"].shape - if start_idx is None: - start_idx = 0 - elif start_idx < 0: - start_idx += cat_output_shape[cat_dim] - if end_idx is None or end_idx > cat_output_shape[cat_dim]: - end_idx = cat_output_shape[cat_dim] - elif end_idx < 0: - end_idx += cat_output_shape[cat_dim] - - offset = 0 - for cat_input_node in cast(List[Node], get_arg(cat_node, "tensors")): - cat_input_shape = cat_input_node.meta["val"].shape - - # Check if the slice range overlaps with the cat input range. - if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]: - slice_copy_node.replace_input_with(cat_node, cat_input_node) - set_arg(slice_copy_node, "start", start_idx - offset) - set_arg(slice_copy_node, "end", end_idx - offset) - break - - offset += cat_input_shape[cat_dim] + # Make sure cat and slice happens on the same dimension. + cat_dim = cast(int, get_arg(cat_node, "dim")) + if cat_dim != slice_dim: + return False - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - self._remove_unused_cat(graph_module) - graph_module.recompile() - graph_module.graph.eliminate_dead_code() - return super().call(graph_module) + # Canonicalize slice indices. + cat_output_shape = cat_node.meta["val"].shape + if start_idx is None: + start_idx = 0 + elif start_idx < 0: + start_idx += cat_output_shape[cat_dim] + if end_idx is None or end_idx > cat_output_shape[cat_dim]: + end_idx = cat_output_shape[cat_dim] + elif end_idx < 0: + end_idx += cat_output_shape[cat_dim] + + offset = 0 + for cat_input_node in cast(List[Node], get_arg(cat_node, "tensors")): + cat_input_shape = cat_input_node.meta["val"].shape + + # Check if the slice range overlaps with the cat input range. 
+ if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]: + node.replace_input_with(cat_node, cat_input_node) + set_arg(node, "start", start_idx - offset) + set_arg(node, "end", end_idx - offset) + return True + + offset += cat_input_shape[cat_dim] + + return False + + +class CommonRemovePasses: + passes: List[Type[ExportPass]] = [ + RemoveAliasCopyOpPass, + RemoveNopExpandOpPass, + RemoveNopSliceOrViewOpPass, + RemoveToOpsPass, + RemoveZeroSizedCatArgsPass, + RemovePermutesAroundElementwiseOps, + RemoveSqueezeViewBeforeElementwiseOps, + RemoveCatFromSliceCopyPass, + RemoveCloneOpsTransformImported, + ] class CadenceRemoveNops: - passes = [ + passes: List[Type[ExportPass]] = CommonRemovePasses.passes + [ SimplifySliceOpPass, - RemoveCloneOpsTransformImported, - RemoveToOpsPass, RemoveNopRequantizeOpPass, - RemoveZeroSizedCatArgsPass, - RemoveNopSliceOrViewOpPass, - RemoveNopExpandOpPass, RemoveZeroSizedConstantPadNd, - RemoveCloneOpPass, RemoveContiguousOpPass, - RemoveAliasCopyOpPass, RemoveNopMulOpPass, RemoveNopAddOpPass, RemoveNopLinalgVectorNormOpPass, diff --git a/backends/cadence/aot/reorder_ops.py b/backends/cadence/aot/reorder_ops.py index 675c8e6cecd..0026c35ed57 100644 --- a/backends/cadence/aot/reorder_ops.py +++ b/backends/cadence/aot/reorder_ops.py @@ -820,9 +820,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: return super().call(self.graph_module) -# The following class consolidates functions to reoder ops (i.e., either hoist -# or sink some ops in the graph). -class CadenceReorderOpsInGraph: +class CommonReorderPasses: passes = [ # Hoist/sink nodes closer to their SSA def/use HoistOpsCloserToDefPass, @@ -832,6 +830,13 @@ class CadenceReorderOpsInGraph: # nodes closer to their def/use. AdvanceQuantizeOpAboveDefChainPass, PostponeDequantizeOpBelowUseChainPass, + ] + + +# The following class consolidates functions to reoder ops (i.e., either hoist +# or sink some ops in the graph). +class CadenceReorderOpsInGraph: + passes = CommonReorderPasses.passes + [ # These passes work on branches instead of linear chains to advance # quantize op beyond their def. 
AdvanceQuantizeOpAboveDefInBranchPass, diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 001ab95d629..5f6f162edd0 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -15,13 +15,11 @@ import math import operator from operator import neg -from typing import cast, Dict, Iterable, Optional, Sequence, Tuple +from typing import cast, Dict, Iterable, Optional, Sequence import torch import torch.fx from executorch.backends.cadence.aot.compiler_utils import ( - get_shape, - get_tensor_from_attr, get_zero_point, is_node_with_op, quantize_tensor_multiplier, @@ -32,19 +30,16 @@ ) from executorch.backends.cadence.aot.pass_utils import ( CadencePassAttribute, - none_throws, register_cadence_pass, + RemoveOrReplacePassInterface, ) from executorch.backends.cadence.aot.remove_ops import RemoveNopSelectOpPass -from executorch.backends.cadence.aot.utils import get_edge_overload_packet from executorch.backends.transforms.replace_scalar_with_tensor import ( ReplaceScalarWithTensorArgPass, ) from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket -from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue -from torch._subclasses import FakeTensor -from torch.fx.node import Argument +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass, PassResult # A map to represent ops that: # (a) are functionally equivalent; and @@ -69,424 +64,500 @@ def contains_placeholder_or_param(nodes: Iterable[torch.fx.Node]) -> bool: @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceLogicalNotBooleanWhereWithWherePass(ExportPass): +class ReplaceLogicalNotBooleanWhereWithWherePass(RemoveOrReplacePassInterface): """ A where op with a logical_not and a boolean tensor can be replaced by a where op with flipped inputs and the initial boolean tensor. """ - def replace_logical_nop_where_with_where( - self, graph_module: torch.fx.GraphModule - ) -> None: - graph = graph_module.graph - for node in graph.nodes: - # We are only interested in where nodes - if node.target != exir_ops.edge.aten.where.self: - continue + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.where.self] - # If the third arg is not a logical_not, bail. - if node.args[0].target != exir_ops.edge.aten.logical_not.default: - continue + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # If the first arg is not a logical_not, bail. + if not isinstance(node.args[0], torch.fx.Node): + return False - # Get the third arg node and its input - logical_not_node = node.args[0] - logical_not_input_tensor = ( - logical_not_node.args[0].to_tensor() - if isinstance(logical_not_node.args[0], ProxyValue) - else logical_not_node.args[0] - ) + logical_not_node = cast(torch.fx.Node, node.args[0]) + if logical_not_node.target != exir_ops.edge.aten.logical_not.default: + return False - # If the logical_not input is not a boolean tensor, bail. - if logical_not_input_tensor.meta["spec"].dtype != torch.bool: - continue + # Get the first arg node and its input + if not isinstance(logical_not_node.args[0], torch.fx.Node): + return False - # Replace the where op with another one, flipping the inputs and using the boolean - # tensor from logical_not. 
- with graph.inserting_before(node): - linear_node = graph.call_function( - exir_ops.edge.aten.where.self, - args=(logical_not_node.args[0], node.args[2], node.args[1]), - ) - # Replace all the uses - node.replace_all_uses_with(linear_node) + logical_not_input_node = cast(torch.fx.Node, logical_not_node.args[0]) - graph_module.recompile() - graph_module.graph.eliminate_dead_code() + # If the logical_not input is not a boolean tensor, bail. + if logical_not_input_node.meta["val"].dtype != torch.bool: + return False - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - self.replace_logical_nop_where_with_where(graph_module) - result = super().call(graph_module) - return result + # Replace the where op with another one, flipping the inputs and using the boolean + # tensor from logical_not. + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.where.self, + args=(logical_not_input_node, node.args[2], node.args[1]), + ) + new_node.meta = node.meta + # Replace all the uses + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceSafeSoftmaxWithSoftmax(ExportPass): # keep +class ReplaceSafeSoftmaxWithSoftmax(RemoveOrReplacePassInterface): # keep """ Replace _safe_softmax with _softmax """ - def call_operator( - self, - op, - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != torch.ops.aten._safe_softmax.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [torch.ops.aten._safe_softmax.default] + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Add False for the half_to_float argument of softmax - softmax_args = list(args) + [False] + softmax_args = tuple(list(node.args) + [False]) - return super().call_operator( - torch.ops.aten._softmax.default, - tuple(softmax_args), - kwargs, - meta, - ) + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + torch.ops.aten._softmax.default, + args=softmax_args, + kwargs=node.kwargs, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplacePT2QuantWithCadenceQuantPass(ExportPass): +class ReplacePT2QuantWithCadenceQuantPass(RemoveOrReplacePassInterface): """ Replace the pt2 quantization ops with cadence quantization ops. We do not link kernels to the PT2 quantization ops, so we need to replace them with cadence ops at all optimization levels. 
""" - def call_operator( - self, - op, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - ns = exir_ops.edge if isinstance(op, EdgeOpOverload) else torch.ops - if op != ns.quantized_decomposed.quantize_per_tensor.default: - return super().call_operator(op, args, kwargs, meta) - - return super().call_operator( - ns.cadence.quantize_per_tensor.default, - args, - kwargs, - meta, - ) + @property + def targets(self) -> list[EdgeOpOverload]: + return [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + ] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + ns = exir_ops.edge if isinstance(node.target, EdgeOpOverload) else torch.ops + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + ns.cadence.quantize_per_tensor.default, + args=node.args, + kwargs=node.kwargs, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplacePT2DequantWithCadenceDequantPass(ExportPass): +class ReplacePT2DequantWithCadenceDequantPass(RemoveOrReplacePassInterface): """ Replace the pt2 dequantization ops with cadence dequantization ops. We do not link kernels to the PT2 quantization ops, so we need to replace them with cadence ops at all optimization levels. """ - def call_operator( - self, - op, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - ns = exir_ops.edge if isinstance(op, EdgeOpOverload) else torch.ops - if op != ns.quantized_decomposed.dequantize_per_tensor.default: - return super().call_operator(op, args, kwargs, meta) - - return super().call_operator( - ns.cadence.dequantize_per_tensor.default, - args, - kwargs, - meta, - ) + @property + def targets(self) -> list[EdgeOpOverload]: + return [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + ] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + ns = exir_ops.edge if isinstance(node.target, EdgeOpOverload) else torch.ops + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + ns.cadence.dequantize_per_tensor.default, + args=node.args, + kwargs=node.kwargs, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceSqueezeAndUnsqueezeWithViewPass(ExportPass): +class ReplaceSqueezeAndUnsqueezeWithViewPass(RemoveOrReplacePassInterface): """ When the shape is static, replace squeeze_copy and unsqueeze_copy ops with view_copy op """ - def call_operator( - self, - op, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - # Instead of testing EdgeOpOverload, test EdgeOpOverloadPacket, - # which allows us to cover all overloads. 
- if get_edge_overload_packet(op) not in { - exir_ops.edge.aten.squeeze_copy, - exir_ops.edge.aten.unsqueeze_copy, - }: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [ + exir_ops.edge.aten.squeeze_copy.default, + exir_ops.edge.aten.squeeze_copy.dim, + exir_ops.edge.aten.squeeze_copy.dims, + exir_ops.edge.aten.unsqueeze_copy.default, + ] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Get the output tensor shape - out_shape = meta["val"].shape + out_shape = node.meta["val"].shape # Bail out if any dim is not an int (dynamic shape) for dim in list(out_shape): if not isinstance(dim, int): - return super().call_operator(op, args, kwargs, meta) + return False - # Return a view op with the new shape - view_args = (args[0], list(out_shape)) - return super().call_operator( - exir_ops.edge.aten.view_copy.default, view_args, kwargs, meta - ) + # Replace with view op with the new shape + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(node.args[0], list(out_shape)), + ) + # Do not remove the metadata copy! + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceFunctionallyEquivalentOpTargets(ExportPass): +class ReplaceFunctionallyEquivalentOpTargets(RemoveOrReplacePassInterface): """ Replace an op with a functionally equivalent op by just switching the op target, but without incurring any change to the op args. """ - def call_operator(self, op, args, kwargs, meta): - if op not in functionally_equivalent_op_targets: - return super().call_operator(op, args, kwargs, meta) - return super().call_operator( - functionally_equivalent_op_targets[op], args, kwargs, meta - ) + @property + def targets(self) -> list[EdgeOpOverload]: + return list(functionally_equivalent_op_targets.keys()) + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + assert isinstance(node.target, EdgeOpOverload) + target_op = functionally_equivalent_op_targets[node.target] + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + target_op, + args=node.args, + kwargs=node.kwargs, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + + # RemoveOrReplacePassInterface calls eliminate_dead_code, but this doesn't + # remove impure nodes (nodes which have side effects). Not sure if that is + # generally safe, so instead of modifying the interface, just erasing + # these nodes for this pass. + node.graph.erase_node(node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class ReplaceSelectWithViewOpPass(ExportPass): +class ReplaceSelectWithViewOpPass(RemoveOrReplacePassInterface): """ If the size along the select dim is 1, then the select op can be replaced by view op. 
""" - def call_operator(self, op, args, kwargs, meta): - if op != exir_ops.edge.aten.select_copy.int: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.select_copy.int] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # Get the input tensor and shapes + in_tensor_node = node.args[0] + assert isinstance(in_tensor_node, torch.fx.Node) + in_shape = in_tensor_node.meta["val"].shape + out_shape = node.meta["val"].shape - # Glean the shape of input and output tensor - in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0] - in_shape = in_tensor.shape - out_shape = meta["val"].shape # Get the select dimension - select_dim = args[1] if args[1] >= 0 else args[1] + len(in_shape) + select_dim = node.args[1] + assert isinstance(select_dim, int) + select_dim = select_dim if select_dim >= 0 else select_dim + len(in_shape) if in_shape[select_dim] == 1: - # Return a view op with the new shape - view_args = (args[0], list(out_shape)) - return super().call_operator( - exir_ops.edge.aten.view_copy.default, view_args, kwargs, meta - ) - return super().call_operator(op, args, kwargs, meta) + # Replace with view op with the new shape + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(node.args[0], list(out_shape)), + ) + # Important to copy metadata + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True + + return False @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceMMWithAddMMPass(ExportPass): +class ReplaceMMWithAddMMPass(RemoveOrReplacePassInterface): """ This pass replaces mm with addmm by introducing a zero bias. mm is not supported, so this is an opt_level=0 pass. """ - def call_operator(self, op, args, kwargs, meta): - if op != exir_ops.edge.aten.mm.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.mm.default] + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # The mm op has two args: input, mat2 - assert len(args) == 2 - X, mat2 = args + assert len(node.args) == 2 + X, mat2 = node.args + assert isinstance(X, torch.fx.Node) + assert isinstance(mat2, torch.fx.Node) # Create a zero bias tensor, and insert it as a graph buffer before the # current node - mat2_tensor = mat2.to_tensor() if isinstance(mat2, ProxyValue) else mat2 + mat2_tensor = mat2.meta["val"] bias_size = mat2_tensor.size(1) - zero_bias = super().call_operator( - exir_ops.edge.aten.full.default, - ([bias_size], 0.0), - {"dtype": torch.float32}, - meta, - ) + + with node.graph.inserting_before(node): + zero_bias = node.graph.call_function( + exir_ops.edge.aten.full.default, + args=([bias_size], 0.0), + kwargs={"dtype": torch.float32}, + ) + zero_bias.meta = node.meta # Replace mm with addmm new_args = (zero_bias, X, mat2) - return super().call_operator( - exir_ops.edge.aten.addmm.default, new_args, kwargs, meta - ) + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.addmm.default, + args=new_args, + ) + new_node.meta = node.meta + + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class ReplaceAddMMWithLinearPass(ExportPass): +class ReplaceAddMMWithLinearPass(RemoveOrReplacePassInterface): """ This pass replaces addmm with linear op. 
+ + AddMM computes: beta*bias + alpha*mm(mat1, mat2) + Linear computes: mat1 @ weight.T + bias + """ - def __init__(self): - super().__init__() - self.counter = 0 + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.addmm.default] - def replace_addmm_with_linear(self, graph_module: torch.fx.GraphModule): - graph = graph_module.graph - for node in graph.nodes: - # We are only interested in admm nodes - if node.target != exir_ops.edge.aten.addmm.default: - continue + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # The addmm op has three concrete args: bias, mat1, mat2 + assert len(node.args) >= 3 + (bias, mat1, mat2) = node.args[0:3] - # The addmm op has three concrete args: input, mat1, mat2 - assert len(node.args) >= 3 - (bias, mat1, mat2) = node.args[0:3] - # The other two args are optional scale args - beta = node.kwargs.get("beta", 1.0) - alpha = node.kwargs.get("alpha", 1.0) - - # AddMM performs beta*bias + alpha*mm(mat1, mat2). We can convert - # it to linear op by multiplying beta to bias, and alpha to mat2.t(). - # However, the following two conditions must hold: - # a. If bias is not a param, then beta must be 1.0 - # b. If mat2 is not a param, then mat2 must be a transpose op. Also, - # the input to the transpose must be a param, or alpha must be 1.0. - fit_bias = is_node_with_op(bias, "get_attr") or beta == 1.0 - fit_mat2 = is_node_with_op(mat2, "get_attr") - transposed_mat2 = False - if ( - not fit_mat2 - and is_node_with_op(mat2, "call_function") - and mat2.target == exir_ops.edge.aten.transpose_copy.int - ): - mat2, transposed_mat2 = mat2.args[0], True - fit_mat2 = is_node_with_op(mat2, "get_attr") or alpha == 1.0 - - if not fit_bias or not fit_mat2: - continue + # The other two args are optional scale args + beta = float(node.kwargs.get("beta", 1.0)) + alpha = float(node.kwargs.get("alpha", 1.0)) + + bias, mat1, mat2 = cast( + tuple[torch.fx.Node, torch.fx.Node, torch.fx.Node], + (bias, mat1, mat2), + ) - # Multiply bias by beta - if beta != 1.0: - assert is_node_with_op(bias, "get_attr") - bias_tensor = get_tensor_from_attr(graph_module, bias) - assert isinstance(bias_tensor, torch.Tensor) - bias_tensor = beta * bias_tensor - with graph.inserting_before(node): - bias_name = f"_bias_addmm_to_linear_{self.counter}" - graph_module.register_buffer(bias_name, bias_tensor) - bias = graph.get_attr(bias_name) - - # Use associativity of scalar multiplication, and multiply alpha to mat2 - if is_node_with_op(mat2, "get_attr"): - mat2_tensor = get_tensor_from_attr(graph_module, mat2) - assert isinstance(mat2_tensor, torch.Tensor) - mat2_tensor = alpha * mat2_tensor - # transpose mat2 - mat2_tensor = mat2_tensor if transposed_mat2 else mat2_tensor.t() - with graph.inserting_before(node): - mat2_name = f"_mat2_addmm_to_linear_{self.counter}" - graph_module.register_buffer(mat2_name, mat2_tensor) - mat2 = graph.get_attr(mat2_name) - - # Construct the linear node - linear_args = (mat1, mat2, bias) + graph = node.graph + + # Handle transpose: if mat2 is a transpose op, extract the original tensor + transposed_mat2 = False + if ( + mat2.op == "call_function" + and mat2.target == exir_ops.edge.aten.transpose_copy.int + ): + # mat2 is already transposed, so we use the input to the transpose + mat2 = cast(torch.fx.Node, mat2.args[0]) + transposed_mat2 = True + + # Multiply bias by beta if needed + if beta != 1.0: + # Create a scaled bias using element-wise multiplication in the graph with graph.inserting_before(node): - linear_node = 
graph.call_function( - exir_ops.edge.aten.linear.default, args=linear_args + beta_scalar = graph.call_function( + exir_ops.edge.aten.full.default, + args=([1], beta), + kwargs={"dtype": torch.float32}, + ) + beta_scalar.meta = node.meta + bias = graph.call_function( + exir_ops.edge.aten.mul.Tensor, + args=(bias, beta_scalar), ) - linear_node.meta = node.meta - # Replace all the uses of the addmm op with linear op - node.replace_all_uses_with(linear_node) - self.counter += 1 - graph_module.recompile() - graph_module.graph.eliminate_dead_code() + # Metadata copy important + bias.meta = node.meta - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - self.replace_addmm_with_linear(graph_module) - result = super().call(graph_module) - return result + # Multiply mat2 by alpha if needed + if alpha != 1.0: + with graph.inserting_before(node): + alpha_scalar = graph.call_function( + exir_ops.edge.aten.full.default, + args=([1], alpha), + kwargs={"dtype": torch.float32}, + ) + alpha_scalar.meta = node.meta + mat2 = graph.call_function( + exir_ops.edge.aten.mul.Tensor, + args=(mat2, alpha_scalar), + ) + + # Metadata copy important + mat2.meta = node.meta + + # Transpose mat2 if it wasn't already transposed + if not transposed_mat2: + with graph.inserting_before(node): + mat2 = graph.call_function( + exir_ops.edge.aten.transpose_copy.int, + args=(mat2, -1, -2), + ) + + # Metadata copy important + mat2.meta = node.meta + + # Construct the linear node: linear(input, weight, bias) + # linear computes: input @ weight.T + bias + linear_args = (mat1, mat2, bias) + with graph.inserting_before(node): + linear_node = graph.call_function( + exir_ops.edge.aten.linear.default, + args=linear_args, + ) + + # Metadata copy important + linear_node.meta = node.meta + + # Replace all uses of the addmm op with linear op + node.replace_all_uses_with(linear_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class ReplacePermuteWithTransposePass(ExportPass): +class ReplacePermuteWithTransposePass(RemoveOrReplacePassInterface): """ Replace permute op with transpose if the permutation is only along two dimensions. """ - def call_operator(self, op, args, kwargs, meta): - if op != exir_ops.edge.aten.permute_copy.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.permute_copy.default] + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Get the old dim and new dim order - in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0] - old_dims = tuple(range(in_tensor.dim())) - new_dims = args[1] + in_tensor = node.args[0] + assert isinstance(in_tensor, torch.fx.Node) + in_shape = in_tensor.meta["val"].shape + old_dims = tuple(range(len(in_shape))) + new_dims = cast(Sequence[int], node.args[1]) # Compute the number of positions in which the old and new order differ diff = [od for od, nd in zip(old_dims, new_dims) if od != nd] + # If the difference is zero, replace with identity (just the input) + if len(diff) == 0: + node.replace_all_uses_with(in_tensor) + return True + # If the difference is in two dimensions, we can replace this permute op # with transpose op. 
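For reference, the algebraic identity that ReplaceAddMMWithLinearPass relies on can be verified numerically in eager mode. This is an illustrative check, not part of the patch:

# addmm(bias, mat1, mat2, beta=b, alpha=a) == linear(mat1, (a * mat2).T, b * bias)
import torch

mat1 = torch.randn(4, 3)
mat2 = torch.randn(3, 5)
bias = torch.randn(5)
beta, alpha = 0.5, 2.0

ref = torch.addmm(bias, mat1, mat2, beta=beta, alpha=alpha)
via_linear = torch.nn.functional.linear(mat1, (alpha * mat2).t(), beta * bias)
assert torch.allclose(ref, via_linear, atol=1e-5)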
if len(diff) == 2: - new_args = (args[0], diff[0], diff[1]) - return super().call_operator( - exir_ops.edge.aten.transpose_copy.int, new_args, kwargs, meta - ) + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.transpose_copy.int, + args=(node.args[0], diff[0], diff[1]), + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True - return ( - args[0] if len(diff) == 0 else super().call_operator(op, args, kwargs, meta) - ) + return False @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceConvolutionOptionalArgsWithConcreteArgsPass(ExportPass): +class ReplaceConvolutionOptionalArgsWithConcreteArgsPass(RemoveOrReplacePassInterface): """ Replace optional tensors with concrete tensors. Currently, we replace the optional bias tensor with a zero tensor. """ - def call_operator(self, op, args, kwargs, meta): - if get_edge_overload_packet(op) != exir_ops.edge.aten.convolution: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [ + exir_ops.edge.cadence.conv1d.default, + exir_ops.edge.cadence.conv2d.default, + exir_ops.edge.cadence.conv3d.default, + exir_ops.edge.cadence.transposed_convolution.default, + ] - # Check if the bias is already concrete - assert len(args) == 9 - if args[2] is not None: - return super().call_operator(op, args, kwargs, meta) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # Check if this is a transposed convolution + assert isinstance(node.target, EdgeOpOverload) + is_transposed = ( + node.target == exir_ops.edge.cadence.transposed_convolution.default + ) + num_expected_args = 9 if is_transposed else 7 + assert len(node.args) == num_expected_args + # Check if the bias is concrete + if node.args[2] is not None: + return False # The bias length is the number of out channels. - out_shape = meta["val"].shape + out_shape = node.meta["val"].shape bias_size = out_shape[1] - # Create a zero bias tensor (bias is not a constant tensor, - # so it needs to be the result of a graph operation). - zero_bias = super().call_operator( - exir_ops.edge.aten.full.default, - ([bias_size], 0.0), - {"dtype": torch.float32}, - meta, - ) - # Replace bias with zero_bias - args = list(args) - args[2] = zero_bias - args = tuple(args) + # Create a zero bias tensor + with node.graph.inserting_before(node): + zero_bias = node.graph.call_function( + exir_ops.edge.aten.full.default, + args=([bias_size], 0.0), + kwargs={"dtype": torch.float32}, + ) + # Create proper metadata for the zero_bias node + zero_bias.meta = node.meta + new_args = list(node.args) + new_args[2] = zero_bias + new_args = tuple(new_args) + + new_node = node.graph.call_function( + # pyre-ignore[6]: Target is a call func, but type is union call func and str + node.target, + args=new_args, + kwargs=node.kwargs, + ) + new_node.meta = node.meta - return super().call_operator(op, args, kwargs, meta) + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceRepeatWithCatPass(ExportPass): +class ReplaceRepeatWithCatPass(RemoveOrReplacePassInterface): """ Replace repeat op as successive cat ops along different dimensions. repeat is not supported, so this is an opt_level=0 pass. 
""" - def call_operator(self, op, args, kwargs, meta): - if op != exir_ops.edge.aten.repeat.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.repeat.default] + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Extract the input tensor, and the repeats from the args - in_tensor = args[0] - repeats = args[1] + in_tensor = node.args[0] + assert isinstance(in_tensor, torch.fx.Node) + repeats = cast(Sequence[int], node.args[1]) # Glean the shapes of input tensor - in_shape = list( - in_tensor.to_tensor().shape - if isinstance(in_tensor, ProxyValue) - else in_tensor.shape - ) + in_shape = list(in_tensor.meta["val"].shape) # If the size of repeats is more than the dimensionality of the tensor, # the output of repeat will be a higher-dimensional tensor. We reshape @@ -496,17 +567,20 @@ def call_operator(self, op, args, kwargs, meta): diff >= 0 ), "Repeat arg malformed: expected a repeat along each dimension of input tensor" + graph = node.graph + result_node = in_tensor + if diff > 0: # Extend the input shape with 1's along the higher dimensions in_shape = ([1] * diff) + in_shape # Insert a view op that reshapes the input tensor to have same # dimensionality as the output tensor. - in_tensor = super().call_operator( - exir_ops.edge.aten.view_copy.default, - (in_tensor, in_shape), - kwargs, - meta, - ) + with graph.inserting_before(node): + result_node = graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(in_tensor, in_shape), + ) + result_node.meta = node.meta assert len(repeats) == len(in_shape) # Repeat op is nothing but successive cat ops along each dimension. @@ -514,37 +588,45 @@ def call_operator(self, op, args, kwargs, meta): # We do not need to do anything if repeat factor is 1 if repeat == 1: continue - cat_arg = [in_tensor] * repeat - in_tensor = super().call_operator( - exir_ops.edge.aten.cat.default, (cat_arg, dim), kwargs, meta - ) + cat_arg = [result_node] * repeat + with graph.inserting_before(node): + result_node = graph.call_function( + exir_ops.edge.aten.cat.default, args=(cat_arg, dim) + ) + result_node.meta = node.meta - return in_tensor + node.replace_all_uses_with(result_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class ReplacePadWithCatPass(ExportPass): +class ReplacePadWithCatPass(RemoveOrReplacePassInterface): """ Replace constant pad nd op that does padding on outer-most dimension with Cat(left_padding_constant_tensor, X, right_padding_constant_tensor) """ - def call_operator(self, op, args, kwargs, meta): - if op != exir_ops.edge.aten.constant_pad_nd.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.constant_pad_nd.default] - assert len(args) >= 2 - input_node, orig_padding = args[:2] + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + assert len(node.args) >= 2 + input_node, orig_padding = node.args[:2] + assert isinstance(input_node, torch.fx.Node) # if there is no padding, this op will be treated in removal pass. 
if not orig_padding: - return super().call_operator(op, args, kwargs, meta) + return False - value = 0 if len(args) == 2 else args[2] + value = 0 if len(node.args) == 2 else node.args[2] - arg_shape = input_node.to_tensor().shape + arg_shape = input_node.meta["val"].shape - padding = orig_padding + ([0] * (len(orig_padding) % 2 != 0)) + # Convert orig_padding to a list for manipulation + # pyre-ignore[6]: Argument type + padding_list = list(orig_padding) + padding = padding_list + ([0] * (len(padding_list) % 2 != 0)) assert len(padding) >= 2 (left_padding_size, right_padding_size) = padding[-2:] # Replace only if constant_pad_nd is along the innermost padding dimension. @@ -553,41 +635,47 @@ def call_operator(self, op, args, kwargs, meta): or left_padding_size < 0 or right_padding_size < 0 ): - return super().call_operator(op, args, kwargs, meta) + return False cat_tensors = [] dim = len(arg_shape) - len(padding) // 2 + graph = node.graph + # add left_padding if left_padding_size > 0: left_padding_shape = ( arg_shape[:dim] + (left_padding_size,) + arg_shape[dim + 1 :] ) - left_padding_node = super().call_operator( - exir_ops.edge.aten.full.default, - ( - left_padding_shape, - value, - ), - {"dtype": torch.float32}, - meta, - ) + with graph.inserting_before(node): + left_padding_node = graph.call_function( + exir_ops.edge.aten.full.default, + args=( + left_padding_shape, + value, + ), + kwargs={"dtype": torch.float32}, + ) + left_padding_node.meta = node.meta cat_tensors.append(left_padding_node) + # input_node cat_tensors.append(input_node) + # right_padding if right_padding_size > 0: right_padding_shape = ( arg_shape[:dim] + (right_padding_size,) + arg_shape[dim + 1 :] ) - right_padding_node = super().call_operator( - exir_ops.edge.aten.full.default, - ( - right_padding_shape, - value, - ), - {"dtype": torch.float32}, - meta, - ) + with graph.inserting_before(node): + right_padding_node = graph.call_function( + exir_ops.edge.aten.full.default, + args=( + right_padding_shape, + value, + ), + kwargs={"dtype": torch.float32}, + ) + right_padding_node.meta = node.meta cat_tensors.append(right_padding_node) assert len(cat_tensors) == 1 + (left_padding_size > 0) + ( @@ -595,55 +683,65 @@ def call_operator(self, op, args, kwargs, meta): ) new_args = (cat_tensors, dim) - return super().call_operator( - exir_ops.edge.aten.cat.default, - new_args, - kwargs, - meta, - ) + with graph.inserting_before(node): + new_node = graph.call_function( + exir_ops.edge.aten.cat.default, + args=new_args, + ) + new_node.meta = node.meta + + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class ReplaceConstantPadNdWithSlicePass(ExportPass): +class ReplaceConstantPadNdWithSlicePass(RemoveOrReplacePassInterface): """ Replace constant pad nd op that does padding on outer-most dimension with exir_ops slice(left_padding_constant_tensor, X, right_padding_constant_tensor) """ - def call_operator(self, op, args, kwargs, meta): - if op != exir_ops.edge.aten.constant_pad_nd.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.constant_pad_nd.default] - assert len(args) >= 2 - input_node, orig_padding = args[:2] + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + assert len(node.args) >= 2 + input_node = node.args[0] + orig_padding = cast(Sequence[int], node.args[1]) + assert isinstance(input_node, torch.fx.Node) # if there is no padding, this op will be 
treated in removal pass. if not orig_padding: - return super().call_operator(op, args, kwargs, meta) + return False - padding = orig_padding + ([0] * (len(orig_padding) % 2 != 0)) + padding = list(orig_padding) + ([0] * (len(orig_padding) % 2 != 0)) assert len(padding) >= 2 + + # pyre-ignore[6] (start, diff) = map(neg, padding[-2:]) # Replace only if constant_pad_nd is along the innermost padding dimension. if any(x != 0 for x in padding[0:-2]) or start < 0 or diff < 0: - return super().call_operator(op, args, kwargs, meta) + return False - arg_shape = input_node.to_tensor().shape + arg_shape = input_node.meta["val"].shape dim = len(arg_shape) - len(padding) // 2 stop = arg_shape[dim] - diff assert start <= stop - new_args = (input_node, dim, start, stop) - return super().call_operator( - exir_ops.edge.aten.slice.Tensor, - new_args, - kwargs, - meta, - ) + + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.slice.Tensor, + args=(input_node, dim, start, stop), + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True # Make that pass runnable standalone at opt level 0. @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceAtenConvolutionWithCadenceConvolutionPass(ExportPass): +class ReplaceAtenConvolutionWithCadenceConvolutionPass(RemoveOrReplacePassInterface): """ Replace aten convolution op with jarvis-specific convolution op, since the aten version is not supported by jarvis. @@ -652,11 +750,14 @@ class ReplaceAtenConvolutionWithCadenceConvolutionPass(ExportPass): for unit-stride convolutions. """ - def call_operator(self, op, args, kwargs, meta): - if get_edge_overload_packet(op) != exir_ops.edge.aten.convolution: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.convolution.default] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # There must be 9 total args. - assert len(args) == 9 + if len(node.args) != 9: + return False # Unpack the args ( @@ -669,109 +770,98 @@ def call_operator(self, op, args, kwargs, meta): transposed, output_padding, groups, - ) = args - # Currently we only handle conversion to conv1d and conv2d, therefore + ) = node.args + + # Cast to appropriate types + stride = cast(Sequence[int], stride) + padding = cast(Sequence[int], padding) + dilation = cast(Sequence[int], dilation) + output_padding = cast(Sequence[int], output_padding) + + # Currently we only handle conversion to conv1d, conv2d, and conv3d, therefore # verify that the stride, padding, dilation, and output_padding have - # len <=2. - assert ( - len(stride) == len(padding) == len(dilation) == len(output_padding) == 1 - ) or ( - len(stride) == len(padding) == len(dilation) == len(output_padding) == 2 - ), "Can only map convolution to conv1d and conv2d at present" - - target = ( - exir_ops.edge.cadence.transposed_convolution.default - if transposed - else exir_ops.edge.cadence.convolution.default - ) + # len <=3. + if not ( + (len(stride) == len(padding) == len(dilation) == len(output_padding) == 1) + or ( + len(stride) == len(padding) == len(dilation) == len(output_padding) == 2 + ) + or ( + len(stride) == len(padding) == len(dilation) == len(output_padding) == 3 + ) + ): + return False + # Determine if this is 1D, 2D, or 3D convolution based on parameter lengths if transposed: - # Flip the height and width dimensions of weight, since we apply a - # gather stencil. 
Also, the first two dimensions of weight must be - # transposed/interchanged. - # If weight is a ProxyValue, new_weight needs to be the output of a - # graph operation (in this case a transpose_copy op) to be an explicit - # ProxyValue as well. If not, the view op can be done directly on the - # tensor. - transposed_weight = ( - super().call_operator( + target = exir_ops.edge.cadence.transposed_convolution.default + elif len(stride) == 1: + target = exir_ops.edge.cadence.conv1d.default + elif len(stride) == 2: + target = exir_ops.edge.cadence.conv2d.default + else: # len(stride) == 3 + target = exir_ops.edge.cadence.conv3d.default + + with node.graph.inserting_before(node): + if transposed: + # Flip the height and width dimensions of weight, since we apply a + # gather stencil. Also, the first two dimensions of weight must be + # transposed/interchanged. + assert isinstance(weight, torch.fx.Node) + transposed_weight = node.graph.call_function( exir_ops.edge.aten.transpose_copy.int, - ( - weight, - 0, - 1, - ), - kwargs, - meta, + args=(weight, 0, 1), ) - if isinstance(weight, ProxyValue) - else weight.transpose(0, 1) - ) + transposed_weight.meta = weight.meta + + # Get the dimension for flip based on weight shape + weight_dim = len(weight.meta["val"].shape) + flip_dims = [-1] if weight_dim == 3 else [-1, -2] - flipped_weight = ( - super().call_operator( + flipped_weight = node.graph.call_function( exir_ops.edge.aten.flip.default, - ( - transposed_weight, - [-1] if transposed_weight.to_tensor().dim() == 3 else [-1, -2], - ), - kwargs, - meta, + args=(transposed_weight, flip_dims), ) - if isinstance(transposed_weight, ProxyValue) - else ( - transposed_weight.flip(-1) - if transposed_weight.dim() == 3 - else transposed_weight.flip(-1, -2) + flipped_weight.meta = transposed_weight.meta + + new_args = ( + in_tensor, + flipped_weight, + bias, + stride, + padding, + dilation, + output_padding, + groups, + False, ) - ) + else: + # Verify that output_padding is 0. + if not all(x == 0 for x in output_padding): + return False - # From the previous checks, if flipped_weight is a FakeTensor, it has to be - # a constant (if not, it would be a ProxyValue). Mark it as such. - if isinstance(flipped_weight, FakeTensor): - flipped_weight.constant = flipped_weight - new_args = ( - in_tensor, - flipped_weight, - bias, - stride, - padding, - dilation, - output_padding, - groups, - False, - ) - else: - # Verify that output_padding is 0. - assert all( - x == 0 for x in output_padding - ), "Cannot handle padded output in convolution" - - # If the innermost dim of output tensor is 1, then the stride - # should be 1. 
Note that the first dimension of output tensor is - # channel - new_stride = stride.copy() - out_shape = meta["val"].shape - assert out_shape is not None - for i, e in enumerate(out_shape[2:]): - new_stride[i] = 1 if e == 1 else stride[i] + # Keep the original stride to maintain correct output dimensions + new_stride = stride - new_args = ( - in_tensor, - weight, - bias, - new_stride, - padding, - dilation, - groups, - False, - ) + new_args = ( + in_tensor, + weight, + bias, + new_stride, + padding, + dilation, + groups, + ) - return super().call_operator(target, new_args, kwargs, meta) + new_node = node.graph.call_function(target, args=new_args) + new_node.meta = node.meta + + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=2)) -class ReplaceTrivialConvWithLinear(ExportPass): +class ReplaceTrivialConvWithLinear(RemoveOrReplacePassInterface): """ In nn.Conv1d, the operand shapes are: input - [batch, in_channels, in_length] @@ -786,41 +876,48 @@ class ReplaceTrivialConvWithLinear(ExportPass): """ trivial_conv_op_to_linear_op: Dict[EdgeOpOverload, EdgeOpOverload] = { - exir_ops.edge.cadence.convolution.default: exir_ops.edge.aten.linear.default, - exir_ops.edge.cadence.quantized_conv_nchw.default: exir_ops.edge.cadence.quantized_linear.default, - exir_ops.edge.cadence.quantized_conv_nhwc.default: exir_ops.edge.cadence.quantized_linear.default, + exir_ops.edge.cadence.conv1d.default: exir_ops.edge.aten.linear.default, + exir_ops.edge.cadence.conv2d.default: exir_ops.edge.aten.linear.default, + exir_ops.edge.cadence.conv3d.default: exir_ops.edge.aten.linear.default, + exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor, + exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor, } - def call_operator(self, op, args, kwargs, meta): - if op not in self.trivial_conv_op_to_linear_op: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return list(self.trivial_conv_op_to_linear_op.keys()) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Parse the necessary args of the convolution node. Both convolution # and quantized_conv have the same first 8 args. The quantized op has # extra args holding at least the zero point and scale of input, weight, bias, # and output tensor. 
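The reduction below hinges on the observation that a convolution whose kernel spans the entire spatial extent (stride 1, no padding or dilation, groups=1) computes one dot product per output channel, i.e. a linear layer applied to flattened operands. An illustrative eager-mode check, not part of the patch:

import math

import torch
import torch.nn.functional as F

x = torch.randn(8, 3, 5, 7)   # [batch, in_channels, H, W]
w = torch.randn(16, 3, 5, 7)  # kernel spans the full spatial extent of the input
b = torch.randn(16)

ref = F.conv2d(x, w, b)       # output shape [8, 16, 1, 1]
K = math.prod(w.shape[1:])
via_linear = F.linear(x.reshape(8, K), w.reshape(16, K), b)
assert torch.allclose(ref.reshape(8, 16), via_linear, atol=1e-4)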
+ assert isinstance(node.target, EdgeOpOverload) quantized_op = ( - op == exir_ops.edge.cadence.quantized_conv_nchw.default - or op == exir_ops.edge.cadence.quantized_conv_nhwc.default + node.target == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor + or node.target == exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor ) - assert (len(args) == 8 and not quantized_op) or ( - len(args) >= 12 and quantized_op + assert (len(node.args) == 7 and not quantized_op) or ( + len(node.args) >= 12 and quantized_op ), "Inconsistent args for convolution" - (in_tensor, weight, bias, stride, padding, dilation, groups) = args[0:7] + (in_tensor, weight, bias, stride, padding, dilation, groups) = node.args[0:7] - # Glean the shapes of input, weight, and output - in_shape = ( - in_tensor.to_tensor().shape - if isinstance(in_tensor, ProxyValue) - else in_tensor.shape - ) + assert isinstance(in_tensor, torch.fx.Node) + assert isinstance(weight, torch.fx.Node) - weight_shape = ( - weight.to_tensor().shape if isinstance(weight, ProxyValue) else weight.shape - ) - out_shape = meta["val"].shape + # Glean the shapes of input, weight, and output + in_shape = in_tensor.meta["val"].shape + weight_shape = weight.meta["val"].shape + out_shape = node.meta["val"].shape assert None not in {in_shape, weight_shape, out_shape} + # pyre-ignore[6]: Argument type for iteration + stride_list = list(stride) + # pyre-ignore[6]: Argument type for iteration + padding_list = list(padding) + # pyre-ignore[6]: Argument type for iteration + dilation_list = list(dilation) + # Check the condition under which conv can be replaced by linear: (1) this # should not be a depthwise convolution; (2) the padding, stride, and dilation # should be standard; (3) The [channels, height, width] of input must match the @@ -829,47 +926,40 @@ def call_operator(self, op, args, kwargs, meta): # by linear. if ( groups != 1 - or any(x != 0 for x in padding) - or any(x != 1 for x in stride) - or any(x != 1 for x in dilation) + or any(x != 0 for x in padding_list) + or any(x != 1 for x in stride_list) + or any(x != 1 for x in dilation_list) or (list(in_shape[1:]) != list(weight_shape[1:])) ): - return super().call_operator(op, args, kwargs, meta) + return False # Reshape the weight to [out_channels, in_channels * X] K = math.prod(weight_shape[1:]) - # If weight is a ProxyValue, linear_weight needs to be the output of a - # graph operation (in this case a view_copy op) to be an explicit ProxyValue - # as well. If not, the view op can be done directly on the tensor. - linear_weight = ( - super().call_operator( + graph = node.graph + + # Weight is always a Node, so we need a view_copy operation + with graph.inserting_before(node): + linear_weight = graph.call_function( exir_ops.edge.aten.view_copy.default, - ( + args=( weight, [weight_shape[0], K], ), - kwargs, - meta, ) - if isinstance(weight, ProxyValue) - else weight.contiguous().view(weight_shape[0], K) - ) - # From the previous check, if linear_weight is a FakeTensor, it has to be - # a constant (if not, it would be a ProxyValue). Mark it as such. 
- if isinstance(linear_weight, FakeTensor): - linear_weight.constant = linear_weight + linear_weight.meta = node.meta # Reshape the input from 3d to 2d tensor - in_view = super().call_operator( - exir_ops.edge.aten.view_copy.default, - ( - in_tensor, - [in_shape[0], K], - ), - kwargs, - meta, - ) + with graph.inserting_before(node): + in_view = graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=( + in_tensor, + [in_shape[0], K], + ), + ) + in_view.meta = node.meta + # Create the linear node, which multiplies the 2d input and weight # tensors, and adds the 1d bias to produce a 2d output. if quantized_op: @@ -879,17 +969,14 @@ def call_operator(self, op, args, kwargs, meta): bias_scale, out_scale, out_zero_point, - ) = args[7:12] + ) = node.args[7:12] # If the multiplier and shift tensors are provided, use them. - if ( - len(args) >= 14 - and isinstance(args[12], ProxyValue) - and isinstance(args[13], ProxyValue) - ): - out_multiplier = args[12] - out_shift = args[13] + if len(node.args) >= 14: + out_multiplier = node.args[12] + out_shift = node.args[13] # If not, compute them. else: + # pyre-ignore[58]: Division operands requantize_scale = bias_scale / out_scale (out_multiplier, out_shift) = quantize_tensor_multiplier( requantize_scale @@ -907,21 +994,23 @@ def call_operator(self, op, args, kwargs, meta): ) else: linear_args = (in_view, linear_weight, bias) + with graph.inserting_before(node): + linear_res = graph.call_function( + self.trivial_conv_op_to_linear_op[cast(EdgeOpOverload, node.target)], + args=linear_args, + ) + linear_res.meta = node.meta - linear_res = super().call_operator( - self.trivial_conv_op_to_linear_op[op], - linear_args, - kwargs, - meta, - ) # Reshape the output of linear from 2d to 3d tensor - out_res = super().call_operator( - exir_ops.edge.aten.view_copy.default, - (linear_res, list(out_shape)), - kwargs, - meta, - ) - return out_res + with graph.inserting_before(node): + out_res = graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(linear_res, list(out_shape)), + ) + out_res.meta = node.meta + + node.replace_all_uses_with(out_res) + return True def canonicalize_transposed_dim(dim: int, shape: Sequence[int]) -> int: @@ -932,133 +1021,202 @@ def canonicalize_transposed_dim(dim: int, shape: Sequence[int]) -> int: return dim -class ExportPassWithTransposeHelper(ExportPass): - def transpose_dims( - self: ExportPass, proxy: ProxyValue, meta: NodeMetadata, dim0: int, dim1: int - ) -> ProxyValue: - """Helper function to transpose dims of a `proxy` with given `meta`.""" - shape = proxy.data.shape +@register_cadence_pass(CadencePassAttribute(opt_level=3)) +class ReplaceConvWithChannelLastConvPass(RemoveOrReplacePassInterface): + """ + Replace NCHW convolutions with NHWC (channel-last) convolutions by adding + transpose operations before and after the convolution. 
+ """ + + @property + def targets(self) -> list[EdgeOpOverload]: + return [ + exir_ops.edge.cadence.conv1d.default, + exir_ops.edge.cadence.conv2d.default, + exir_ops.edge.cadence.conv3d.default, + exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, + ] + + def _transpose_dims( + self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int + ) -> torch.fx.Node: + """Helper function to transpose dims of a node.""" + shape = node.meta["val"].shape dim0, dim1 = ( canonicalize_transposed_dim(dim0, shape), canonicalize_transposed_dim(dim1, shape), ) dim0, dim1 = min(dim0, dim1), max(dim0, dim1) - return super().call_operator( - exir_ops.edge.aten.transpose_copy.int, (proxy, dim0, dim1), {}, meta + transpose_node = graph.call_function( + exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {} ) - - -@register_cadence_pass(CadencePassAttribute(opt_level=3)) -class ReplaceConvWithChannelLastConvPass(ExportPassWithTransposeHelper): - def change_nchw_to_nhwc(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue: - shape = proxy.to_tensor().shape + transpose_node.meta = node.meta + return transpose_node + + def _change_nchw_to_nhwc( + self, graph: torch.fx.Graph, node: torch.fx.Node + ) -> torch.fx.Node: + """Convert NCHW format to NHWC format.""" + shape = node.meta["val"].shape if len(shape) == 3: - return self.transpose_dims(proxy, meta, 1, -1) + return self._transpose_dims(graph, node, 1, -1) indices = list(range(len(shape))) permute_indices = [indices[0]] + indices[2:] + [indices[1]] - return super().call_operator( - exir_ops.edge.aten.permute_copy.default, (proxy, permute_indices), {}, meta + permute_node = graph.call_function( + exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {} ) - - def change_nhwc_to_nchw(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue: - shape = proxy.to_tensor().shape + permute_node.meta = node.meta + return permute_node + + def _change_nhwc_to_nchw( + self, graph: torch.fx.Graph, node: torch.fx.Node + ) -> torch.fx.Node: + """Convert NHWC format to NCHW format.""" + shape = node.meta["val"].shape if len(shape) == 3: - return self.transpose_dims(proxy, meta, 1, -1) + return self._transpose_dims(graph, node, 1, -1) indices = list(range(len(shape))) permute_indices = [indices[0], indices[-1]] + indices[1:-1] - return super().call_operator( - exir_ops.edge.aten.permute_copy.default, (proxy, permute_indices), {}, meta + permute_node = graph.call_function( + exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {} ) + permute_node.meta = node.meta + return permute_node - def call_operator( - self, - op, - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in { - exir_ops.edge.cadence.convolution.default, - exir_ops.edge.cadence.quantized_conv_nchw.default, - }: - return super().call_operator(op, args, kwargs, meta) - - quantized_op = op == exir_ops.edge.cadence.quantized_conv_nchw.default - - if not quantized_op and len(args) == 8 and args[-1] is True: - # Already in NHWC layout. 
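The two helpers above compute inverse permutations; for a 4-D tensor they are [0, 2, 3, 1] (NCHW to NHWC) and [0, 3, 1, 2] (NHWC back to NCHW). A short eager-mode round-trip check, illustrative only:

import torch

x = torch.randn(2, 3, 4, 5)      # NCHW
nhwc = x.permute(0, 2, 3, 1)     # [indices[0]] + indices[2:] + [indices[1]]
nchw = nhwc.permute(0, 3, 1, 2)  # [indices[0], indices[-1]] + indices[1:-1]
assert nhwc.shape == (2, 4, 5, 3)
assert torch.equal(x, nchw)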
- return super().call_operator(op, args, kwargs, meta) - - new_op = ( - exir_ops.edge.cadence.quantized_conv_nhwc.default - if quantized_op - else exir_ops.edge.cadence.convolution.default + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + assert isinstance(node.target, EdgeOpOverload) + quantized_op = ( + node.target == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor ) - input_proxy = cast(ProxyValue, args[0]) - weight_proxy = cast(ProxyValue, args[1]) - input_proxy = self.change_nchw_to_nhwc(input_proxy, meta) - weight_proxy = self.change_nchw_to_nhwc(weight_proxy, meta) + # Check if already in NHWC layout + if not quantized_op and len(node.args) == 8 and node.args[-1] is True: + return False - # Non-quantized ops still need to set the last optional argument to True. - channel_last_arg = [] if quantized_op else [True] + # Determine the new op target + if quantized_op: + new_op = exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor + else: + new_op = node.target - new_args = ( - # Transposed input/weights. - (input_proxy, weight_proxy) - # All other args (bias, quant params, etc) - + tuple(args[2:]) - + tuple(channel_last_arg) - ) - output_proxy = super().call_operator(new_op, new_args, kwargs, meta) - nchw_proxy = self.change_nhwc_to_nchw(output_proxy, meta) - return nchw_proxy + graph = node.graph + + # Get input and weight nodes + input_node = cast(torch.fx.Node, node.args[0]) + weight_node = cast(torch.fx.Node, node.args[1]) + + # Insert transpose operations before the node + with graph.inserting_before(node): + # Convert input from NCHW to NHWC + input_nhwc = self._change_nchw_to_nhwc(graph, input_node) + # Convert weight from NCHW to NHWC + weight_nhwc = self._change_nchw_to_nhwc(graph, weight_node) + + # Non-quantized ops need to set the last optional argument to True + channel_last_arg = [] if quantized_op else [True] + + # Create new args with transposed input/weights + new_args = ( + (input_nhwc, weight_nhwc) + + tuple(node.args[2:]) + + tuple(channel_last_arg) + ) + + # Create the new conv operation + new_conv = graph.call_function(new_op, new_args, node.kwargs) + new_conv.meta = node.meta + + # Convert output back from NHWC to NCHW + nchw_output = self._change_nhwc_to_nchw(graph, new_conv) + + # Replace all uses with the final output + node.replace_all_uses_with(nchw_output) + return True @register_cadence_pass(CadencePassAttribute(opt_level=3)) -class MakeSliceAndCatDimOutermostPass(ExportPassWithTransposeHelper): - def call_operator( - self, - op, - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in { +class MakeSliceAndCatDimOutermostPass(RemoveOrReplacePassInterface): + """ + Make the slice/cat dimension the outermost dimension by adding transpose + operations before and after the slice/cat operation. 
+ """ + + @property + def targets(self) -> list[EdgeOpOverload]: + return [ exir_ops.edge.aten.cat.default, exir_ops.edge.aten.slice_copy.Tensor, - }: - return super().call_operator(op, args, kwargs, meta) - dim = cast(int, args[1]) if len(args) > 1 else 0 - output_shape = meta["val"].shape + ] + + def _transpose_dims( + self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int + ) -> torch.fx.Node: + """Helper function to transpose dims of a node.""" + shape = node.meta["val"].shape + dim0, dim1 = ( + canonicalize_transposed_dim(dim0, shape), + canonicalize_transposed_dim(dim1, shape), + ) + dim0, dim1 = min(dim0, dim1), max(dim0, dim1) + transpose_node = graph.call_function( + exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {} + ) + transpose_node.meta = node.meta + return transpose_node + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # Get the dimension argument + dim = cast(int, node.args[1]) if len(node.args) > 1 else 0 + output_shape = node.meta["val"].shape + + # Canonicalize dim to be positive if dim < 0: - # Keep dim positive. dim += len(output_shape) + # Not needed if dim is already outermost or all dims before it are 1 if dim == 0 or math.prod(output_shape[:dim]) == 1: - # Not needed if dim is already outermost or all dims before it are 1. - return super().call_operator(op, (args[0], dim) + args[2:], kwargs, meta) - - if op == exir_ops.edge.aten.slice_copy.Tensor: - # Transpose -> slice. - slice_args = ( - self.transpose_dims(cast(ProxyValue, args[0]), meta, dim, 0), - 0, - ) + args[2:] - new_op = super().call_operator(op, slice_args, kwargs, meta) - else: - # (Transpose input0, Transpose input1, ...) -> cat. - cat_in_tensors = [ - self.transpose_dims(t, meta, dim, 0) - for t in cast(list[ProxyValue], args[0]) - ] - new_op = super().call_operator(op, (cat_in_tensors, 0), kwargs, meta) - # slice/cat -> transpose. - return self.transpose_dims(new_op, meta, 0, dim) + return False + + graph = node.graph + + with graph.inserting_before(node): + if node.target == exir_ops.edge.aten.slice_copy.Tensor: + # Transpose input -> slice with dim=0 -> transpose back + input_node = cast(torch.fx.Node, node.args[0]) + transposed_input = self._transpose_dims(graph, input_node, dim, 0) + + # Create slice operation with dim=0 + slice_args = (transposed_input, 0) + node.args[2:] + sliced = graph.call_function( + exir_ops.edge.aten.slice_copy.Tensor, slice_args, node.kwargs + ) + sliced.meta = node.meta + + # Transpose back + result = self._transpose_dims(graph, sliced, 0, dim) + else: + # Cat operation: transpose all inputs -> cat with dim=0 -> transpose back + cat_inputs = cast(list[torch.fx.Node], node.args[0]) + transposed_inputs = [ + self._transpose_dims(graph, t, dim, 0) for t in cat_inputs + ] + + # Create cat operation with dim=0 + catted = graph.call_function( + exir_ops.edge.aten.cat.default, (transposed_inputs, 0), node.kwargs + ) + catted.meta = node.meta + + # Transpose back + result = self._transpose_dims(graph, catted, 0, dim) + + # Replace all uses with the final result + node.replace_all_uses_with(result) + return True @register_cadence_pass(CadencePassAttribute(opt_level=2)) -class ReplaceConvWithIm2RowAndLinear(ExportPass): +class ReplaceConvWithIm2RowAndLinear(RemoveOrReplacePassInterface): """ Replace convolution where groups=1 with im2row followed by a linear op. """ @@ -1066,51 +1224,66 @@ class ReplaceConvWithIm2RowAndLinear(ExportPass): # A map from the convolution op to the linear op that it should # decompose to. 
conv_op_to_linear_op: Dict[EdgeOpOverload, EdgeOpOverload] = { - exir_ops.edge.cadence.convolution.default: exir_ops.edge.aten.linear.default, - exir_ops.edge.cadence.quantized_conv_nchw.default: exir_ops.edge.cadence.quantized_linear.default, - exir_ops.edge.cadence.quantized_conv_nhwc.default: exir_ops.edge.cadence.quantized_linear.default, + exir_ops.edge.cadence.conv1d.default: exir_ops.edge.aten.linear.default, + exir_ops.edge.cadence.conv2d.default: exir_ops.edge.aten.linear.default, + exir_ops.edge.cadence.conv3d.default: exir_ops.edge.aten.linear.default, + exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor, + exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor, } - def call_operator(self, op, args, kwargs, meta): - if op not in self.conv_op_to_linear_op: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return list(self.conv_op_to_linear_op.keys()) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Get the relevant args from convolution node. + assert isinstance(node.target, EdgeOpOverload) quantized_op = ( - op == exir_ops.edge.cadence.quantized_conv_nchw.default - or op == exir_ops.edge.cadence.quantized_conv_nhwc.default + node.target == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor + or node.target == exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor ) - assert (len(args) == 8 and not quantized_op) or ( - len(args) >= 12 and quantized_op + assert (len(node.args) == 7 and not quantized_op) or ( + len(node.args) >= 12 and quantized_op ), "Inconsistent args for convolution" - (in_tensor, weight, bias, stride, padding, dilation, groups) = args[0:7] + (in_tensor, weight, bias, stride, padding, dilation, groups) = node.args[0:7] + + assert isinstance(in_tensor, torch.fx.Node) + assert isinstance(weight, torch.fx.Node) # We do not replace depthwise convolution with gemm yet. if groups != 1: - return super().call_operator(op, args, kwargs, meta) + return False + + weight_shape = weight.meta["val"].shape + + # pyre-ignore[6]: Argument type for iteration + stride_list = list(stride) + # pyre-ignore[6]: Argument type for iteration + padding_list = list(padding) + # pyre-ignore[6]: Argument type for iteration + dilation_list = list(dilation) - weight_shape = ( - weight.to_tensor().shape if isinstance(weight, ProxyValue) else weight.shape - ) # If this is a pointwise convolution, im2col will start dominating the # runtime. So we call convolution op for this case. if ( all(x == 1 for x in weight_shape[2:]) - and all(x == 1 for x in stride) - and all(x == 0 for x in padding) - and all(x == 1 for x in dilation) + and all(x == 1 for x in stride_list) + and all(x == 0 for x in padding_list) + and all(x == 1 for x in dilation_list) ): - return super().call_operator(op, args, kwargs, meta) + return False # Get the shapes - out_shape = meta["val"].shape + out_shape = node.meta["val"].shape assert None not in {weight_shape, out_shape} # Determine if the convolution is NCHW or NHWC. The NHWC, i.e., the # channel_last layout is specified by the channel_last arg of conv # op, which is either the last argument (15th) or implicitely False # if the op is quantized, or the last argument if not. 
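The decomposition implemented below is the classic im2col/im2row lowering: gather each receptive field into a row, multiply by the flattened weight, then reshape. An illustrative eager-mode check using torch's unfold in place of the cadence im2row op, not part of the patch; note that the intermediate result is [batch, positions, out_channels], i.e. channel-last, which is why the pass inserts a transpose when channel_last is False:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
b = torch.randn(4)

ref = F.conv2d(x, w, b, stride=1, padding=1)

cols = F.unfold(x, kernel_size=3, padding=1)               # [2, 3*3*3, 64]
lin = F.linear(cols.transpose(1, 2), w.reshape(4, -1), b)  # [2, 64, 4], channel-last
out = lin.transpose(1, 2).reshape(2, 4, 8, 8)
assert torch.allclose(ref, out, atol=1e-4)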
- channel_last = op == exir_ops.edge.cadence.quantized_conv_nhwc.default + channel_last = ( + node.target == exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor + ) # The weight tensor is [out_channels, in_channels, X] for NCHW layout, # and [out_channels, X, in_channels] for NHWC layout. Here, X is the # kernel_width for conv1d, and X = kernel_height * kernel_width for @@ -1119,74 +1292,53 @@ def call_operator(self, op, args, kwargs, meta): # If the convolution op was quantized, we need the input tensor's # zero_point for im2row. Otherwise in_zero_point defaults to a zero # tensor. - in_zero_point = ( - ( - super().call_operator( - exir_ops.edge.aten.full.default, - ( - [1], - args[7], - ), - {"dtype": torch.int32}, - meta, - ) - if isinstance(in_tensor.to_tensor(), FakeTensor) - else get_zero_point(in_tensor.to_tensor()) - ) - if quantized_op - else torch.tensor(0, dtype=torch.int32) - ) + in_zero_point = node.args[7] if quantized_op else 0 + # im2row expects every kernel parameter to be 2d. So we extend the # parameters for conv1d by prepending their default values. - stride = ([1] + stride) if len(stride) == 1 else stride - padding = ([0] + padding) if len(padding) == 1 else padding - dilation = ([1] + dilation) if len(dilation) == 1 else dilation + stride_2d = ([1] + stride_list) if len(stride_list) == 1 else stride_list + padding_2d = ([0] + padding_list) if len(padding_list) == 1 else padding_list + dilation_2d = ( + ([1] + dilation_list) if len(dilation_list) == 1 else dilation_list + ) kernel_size = ([1] + kernel_size) if len(kernel_size) == 1 else kernel_size # Assert that kernel size does not have a 0 assert 0 not in kernel_size + graph = node.graph + # Create an im2row node with the input. This will create a 2d matrix of # shape [out_height*out_weight, X*in_channels]. X is as defined in the # comment above. im2row_args = ( in_tensor, kernel_size, - dilation, - padding, - stride, + dilation_2d, + padding_2d, + stride_2d, in_zero_point, channel_last, ) - im2row = super().call_operator( - exir_ops.edge.cadence.im2row.default, - im2row_args, - kwargs, - meta, - ) + with graph.inserting_before(node): + im2row = graph.call_function( + exir_ops.edge.cadence.im2row.per_tensor, + args=im2row_args, + ) + im2row.meta = node.meta # Get the product of the >2 dims of the weight K = math.prod(weight_shape[1:]) - # If weight is a ProxyValue, linear_weight needs to be the output of a - # graph operation (in this case a view_copy op) to be an explicit ProxyValue - # as well. If not, the view op can be done directly on the tensor. - linear_weight = ( - super().call_operator( + # Weight is always a Node, so we need a view_copy operation + with graph.inserting_before(node): + linear_weight = graph.call_function( exir_ops.edge.aten.view_copy.default, - ( + args=( weight, [weight_shape[0], K], ), - kwargs, - meta, ) - if isinstance(weight, ProxyValue) - else weight.contiguous().view(weight_shape[0], K) - ) - # From the previous check, if linear_weight is a FakeTensor, it has to be - # a constant (if not, it would be a ProxyValue). Mark it as such. - if isinstance(linear_weight, FakeTensor): - linear_weight.constant = linear_weight + linear_weight.meta = node.meta # Create the linear node, which multiplies the 3d input with 2d weight # tensors with bias addition. 
The outermost dimension of the input is @@ -1198,17 +1350,14 @@ def call_operator(self, op, args, kwargs, meta): bias_scale, out_scale, out_zero_point, - ) = args[7:12] + ) = node.args[7:12] # If the multiplier and shift tensors are provided, use them. - if ( - len(args) >= 14 - and isinstance(args[12], ProxyValue) - and isinstance(args[13], ProxyValue) - ): - out_multiplier = args[12] - out_shift = args[13] + if len(node.args) >= 14: + out_multiplier = node.args[12] + out_shift = node.args[13] # If not, compute them. else: + # pyre-ignore[58]: Division operands requantize_scale = bias_scale / out_scale (out_multiplier, out_shift) = quantize_tensor_multiplier( requantize_scale @@ -1226,34 +1375,40 @@ def call_operator(self, op, args, kwargs, meta): ) else: linear_args = (im2row, linear_weight, bias) - linear_res = super().call_operator( - self.conv_op_to_linear_op[op], - linear_args, - kwargs, - meta, - ) + + with graph.inserting_before(node): + linear_res = graph.call_function( + self.conv_op_to_linear_op[cast(EdgeOpOverload, node.target)], + args=linear_args, + ) + linear_res.meta = node.meta + # The output of linear is a 3D tensor. However, the output is in NHWC # layout by default, because an input vector of size X is multiplied # with the weight matrix, i.e., column values are contiguous. If the # channel_last is False, we want to transpose this output. if not channel_last: - linear_res = super().call_operator( - exir_ops.edge.aten.transpose_copy.int, - (linear_res, 1, 2), - kwargs, - meta, - ) + with graph.inserting_before(node): + linear_res = graph.call_function( + exir_ops.edge.aten.transpose_copy.int, + args=(linear_res, 1, 2), + ) + linear_res.meta = node.meta + # And finally, we want to view the 3D output of linear op as 4D tensor - return super().call_operator( - exir_ops.edge.aten.view_copy.default, - (linear_res, list(out_shape)), - kwargs, - meta, - ) + with graph.inserting_before(node): + out_res = graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(linear_res, list(out_shape)), + ) + out_res.meta = node.meta + + node.replace_all_uses_with(out_res) + return True @register_cadence_pass(CadencePassAttribute(opt_level=2)) -class ReplaceTransposedConvWithLinearPass(ExportPass): +class ReplaceTransposedConvWithLinearPass(RemoveOrReplacePassInterface): """ Replace transposed convolution where groups=1 with transposed_im2row followed by a linear op. @@ -1266,15 +1421,20 @@ class ReplaceTransposedConvWithLinearPass(ExportPass): exir_ops.edge.cadence.quantized_transposed_conv.default: exir_ops.edge.cadence.quantized_linear.default, } - def call_operator(self, op, args, kwargs, meta): - if op not in self.transposed_conv_op_to_linear_op: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return list(self.transposed_conv_op_to_linear_op.keys()) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Get the relevant args from transposed_convolution node. 
- quantized_op = op == exir_ops.edge.cadence.quantized_transposed_conv.default - assert len(args) == ( - 16 if quantized_op else 9 - ), "Inconsistent args for transposed_convolution" + assert isinstance(node.target, EdgeOpOverload) + quantized_op = ( + node.target == exir_ops.edge.cadence.quantized_transposed_conv.default + ) + expected_args = 16 if quantized_op else 9 + if len(node.args) != expected_args: + return False + ( in_tensor, weight, @@ -1284,23 +1444,23 @@ def call_operator(self, op, args, kwargs, meta): dilation, output_padding, groups, - ) = args[0:8] + ) = node.args[0:8] # We do not replace depthwise transposed_convolution with gemm yet. if groups != 1: - return super().call_operator(op, args, kwargs, meta) + return False # Get the shapes - out_shape = meta["val"].shape - weight_shape = ( - weight.to_tensor().shape if isinstance(weight, ProxyValue) else weight.shape - ) - assert None not in {weight_shape, out_shape} + assert isinstance(weight, torch.fx.Node) + out_shape = node.meta["val"].shape + weight_shape = weight.meta["val"].shape + if None in {weight_shape, out_shape}: + return False # Determine if the transposed_convolution is NCHW or NHWC. The NHWC, # i.e., the channel_last layout is specified by the channel_last arg # of transposed_conv op, which is the last argument. - channel_last = args[-1] + channel_last = node.args[-1] # The weight tensor is [out_channels, in_channels, X] for NCHW layout, # and [out_channels, X, in_channels] for NHWC layout. Here, X is the # kernel_width for conv1d, and X = kernel_height * kernel_width for @@ -1309,22 +1469,35 @@ def call_operator(self, op, args, kwargs, meta): # If the transposed_convolution op was quantized, we need the input tensor's # zero_point for im2row. Otherwise in_zero_point defaults to a zero # tensor. + assert isinstance(in_tensor, torch.fx.Node) in_zero_point = ( - get_zero_point(in_tensor.to_tensor()) + get_zero_point(in_tensor.meta["val"]) if quantized_op else torch.tensor(0, dtype=torch.int32) ) + + # Cast to appropriate types + stride = cast(Sequence[int], stride) + padding = cast(Sequence[int], padding) + dilation = cast(Sequence[int], dilation) + output_padding = cast(Sequence[int], output_padding) + # transposed_im2row expects every kernel parameter to be 2d. So we extend the # parameters for conv1d by prepending their default values. - stride = ([1] + stride) if len(stride) == 1 else stride - padding = ([0] + padding) if len(padding) == 1 else padding - dilation = ([1] + dilation) if len(dilation) == 1 else dilation - output_padding = ( - ([0] + output_padding) if len(output_padding) == 1 else output_padding + stride_list = ([1] + list(stride)) if len(stride) == 1 else list(stride) + padding_list = ([0] + list(padding)) if len(padding) == 1 else list(padding) + dilation_list = ([1] + list(dilation)) if len(dilation) == 1 else list(dilation) + output_padding_list = ( + ([0] + list(output_padding)) + if len(output_padding) == 1 + else list(output_padding) ) kernel_size = ([1] + kernel_size) if len(kernel_size) == 1 else kernel_size - # Assert that kernel size does not have a 0 - assert 0 not in kernel_size + # Check that kernel size does not have a 0 + if 0 in kernel_size: + return False + + graph = node.graph # Create a transposed_im2row node with the input. This will create a 2d # matrix of shape [out_height*out_weight, X*in_channels]. 
X is as @@ -1332,42 +1505,33 @@ def call_operator(self, op, args, kwargs, meta): transposed_im2row_args = ( in_tensor, kernel_size, - dilation, - padding, - stride, - output_padding, + dilation_list, + padding_list, + stride_list, + output_padding_list, in_zero_point, channel_last, ) - transposed_im2row = super().call_operator( - exir_ops.edge.cadence.transposed_im2row.default, - transposed_im2row_args, - kwargs, - meta, - ) + with graph.inserting_before(node): + transposed_im2row = graph.call_function( + exir_ops.edge.cadence.transposed_im2row.default, + args=transposed_im2row_args, + ) + transposed_im2row.meta = node.meta + # Reshape the weight to [out_channels, in_channels * X] K = math.prod(weight_shape[1:]) - # If weight is a ProxyValue, linear_weight needs to be the output of a - # graph operation (in this case a view_copy op) to be an explicit ProxyValue - # as well. If not, the view op can be done directly on the tensor. - linear_weight = ( - super().call_operator( + # Weight is always a Node, so we need a view_copy operation + with graph.inserting_before(node): + linear_weight = graph.call_function( exir_ops.edge.aten.view_copy.default, - ( + args=( weight, [weight_shape[0], K], ), - kwargs, - meta, ) - if isinstance(weight, ProxyValue) - else weight.contiguous().view(weight_shape[0], K) - ) - # From the previous check, if linear_weight is a FakeTensor, it has to be - # a constant (if not, it would be a ProxyValue). Mark it as such. - if isinstance(linear_weight, FakeTensor): - linear_weight.constant = linear_weight + linear_weight.meta = node.meta # Create the linear node, which multiplies the 3d input with 2d weight # tensors with bias addition. The outermost dimension of the input is @@ -1379,7 +1543,8 @@ def call_operator(self, op, args, kwargs, meta): bias_scale, out_scale, out_zero_point, - ) = args[8:13] + ) = node.args[8:13] + # pyre-ignore[58]: Division operands requantize_scale = bias_scale / out_scale (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale) linear_args = ( @@ -1395,58 +1560,67 @@ def call_operator(self, op, args, kwargs, meta): ) else: linear_args = (transposed_im2row, linear_weight, bias) - linear_res = super().call_operator( - self.transposed_conv_op_to_linear_op[op], - linear_args, - kwargs, - meta, - ) + + with graph.inserting_before(node): + linear_res = graph.call_function( + self.transposed_conv_op_to_linear_op[cast(EdgeOpOverload, node.target)], + args=linear_args, + ) + linear_res.meta = node.meta + # The output of linear is a 3D tensor. However, the output is in NHWC # layout by default, because an input vector of size X is multiplied # with the weight matrix, i.e., column values are contiguous. If the # channel_last is False, we want to transpose this output. 
if not channel_last: - linear_res = super().call_operator( - exir_ops.edge.aten.transpose_copy.int, - (linear_res, 1, 2), - kwargs, - meta, - ) + with graph.inserting_before(node): + linear_res = graph.call_function( + exir_ops.edge.aten.transpose_copy.int, + args=(linear_res, 1, 2), + ) + linear_res.meta = node.meta + # And finally, we want to view the 3D output of linear op as 4D tensor - return super().call_operator( - exir_ops.edge.aten.view_copy.default, - (linear_res, list(out_shape)), - kwargs, - meta, - ) + with graph.inserting_before(node): + out_res = graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(linear_res, list(out_shape)), + ) + out_res.meta = node.meta + + node.replace_all_uses_with(out_res) + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class ReplaceNopTransposeOrPermuteWithViewPass(ExportPass): +class ReplaceNopTransposeOrPermuteWithViewPass(RemoveOrReplacePassInterface): """ If the transpose/permute op does not change the byte order (e.g., transpose/permute from Nx1xHxW to NxHx1xW), then it can be replaced by view op. """ - def call_operator(self, op, args, kwargs, meta): - # Only proceed for transpose or permute op. - if op not in { + @property + def targets(self) -> list[EdgeOpOverload]: + return [ exir_ops.edge.aten.transpose_copy.int, exir_ops.edge.aten.permute_copy.default, - }: - return super().call_operator(op, args, kwargs, meta) + ] + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Get the input tensor and shape - in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0] - in_shape = in_tensor.shape + in_tensor_node = node.args[0] + assert isinstance(in_tensor_node, torch.fx.Node) + in_shape = in_tensor_node.meta["val"].shape # Get the output tensor shape - out_shape = meta["val"].shape + out_shape = node.meta["val"].shape - if op == exir_ops.edge.aten.transpose_copy.int: + if node.target == exir_ops.edge.aten.transpose_copy.int: # Get the two dims to be transposed - dim0 = args[1] if args[1] >= 0 else in_tensor.dim() + args[1] - dim1 = args[2] if args[2] >= 0 else in_tensor.dim() + args[2] + dim0 = cast(int, node.args[1]) + dim1 = cast(int, node.args[2]) + dim0 = dim0 if dim0 >= 0 else len(in_shape) + dim0 + dim1 = dim1 if dim1 >= 0 else len(in_shape) + dim1 # We can eliminate transpose if (a) the size at dim0 and dim1 is 1; # (b) the size at dim0 or dim1 is 1, and dim0 and dim1 are consecutive. both_one = in_shape[dim0] == 1 and in_shape[dim1] == 1 @@ -1454,17 +1628,22 @@ def call_operator(self, op, args, kwargs, meta): in_shape[dim0] == 1 or in_shape[dim1] == 1 ) if both_one or either_one_and_consecutive: - new_args = (args[0], list(out_shape)) - return super().call_operator( - exir_ops.edge.aten.view_copy.default, new_args, kwargs, meta - ) + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(in_tensor_node, list(out_shape)), + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True - elif op == exir_ops.edge.aten.permute_copy.default: - old_dims = list(range(in_tensor.dim())) - new_dims = args[1] + elif node.target == exir_ops.edge.aten.permute_copy.default: + old_dims = list(range(len(in_shape))) + new_dims = cast(Sequence[int], node.args[1]) # If the permute does not change anything, return the input as output. 
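# Editorial sketch (not part of this diff): when one of the swapped dims has
# size 1 and the dims are adjacent, as checked above, the transpose does not
# change the underlying element order, so it can be emitted as a view of the
# output shape. Plain torch check:
import torch

x = torch.arange(24).reshape(2, 1, 3, 4)    # unit dim at index 1
t = x.transpose(1, 2)                       # shape (2, 3, 1, 4)

# Same elements in the same memory order, so a reshape/view is equivalent.
assert torch.equal(t, x.reshape(2, 3, 1, 4))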
- if old_dims == new_dims: - return args[0] + if old_dims == list(new_dims): + node.replace_all_uses_with(in_tensor_node) + return True # Get the old dim order, and the permuted dim order for all dims that # are not 1. old_order = [ @@ -1475,22 +1654,30 @@ def call_operator(self, op, args, kwargs, meta): ] # If the byte ordering for non-unit dims is unchanged, this is a nop. if old_order == new_order: - new_args = (args[0], list(out_shape)) - return super().call_operator( - exir_ops.edge.aten.view_copy.default, new_args, kwargs, meta - ) + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(in_tensor_node, list(out_shape)), + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True - return super().call_operator(op, args, kwargs, meta) + return False def call(self, graph_module: torch.fx.GraphModule) -> PassResult: result = super().call(graph_module) - fuse_cascaded_result = none_throws(FuseCascadedViewOps()(result.graph_module)) - result = none_throws(ExportPass()(fuse_cascaded_result.graph_module)) + # If this pass made modifications, fuse any cascaded view ops that may have been created + if result.modified: + fuse_cascaded_result = FuseCascadedViewOps().call(result.graph_module) + + # True because we are in the 'if modified' block + return PassResult(fuse_cascaded_result.graph_module, True) return result @register_cadence_pass(CadencePassAttribute(opt_level=2)) -class ReplaceLinearWithFullyConnectedOpPass(ExportPass): +class ReplaceLinearWithFullyConnectedOpPass(RemoveOrReplacePassInterface): """ If the input of linear/quantized_linear op is a vector, replace it with fully_connected op. @@ -1501,253 +1688,199 @@ class ReplaceLinearWithFullyConnectedOpPass(ExportPass): exir_ops.edge.cadence.quantized_linear.default: exir_ops.edge.cadence.quantized_fully_connected.default, } - def call_operator(self, op, args, kwargs, meta): - # Only proceed for linear or quantized_linear ops. - if op not in self.linear_to_fc_op: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return list(self.linear_to_fc_op.keys()) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Extract the input tensor - in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0] - leading_dims = math.prod(in_tensor.shape[:-1]) + in_tensor_arg = node.args[0] + assert isinstance(in_tensor_arg, torch.fx.Node) + in_tensor_shape = in_tensor_arg.meta["val"].shape + leading_dims = math.prod(in_tensor_shape[:-1]) # If the tensor is not a vector, do nothing. if leading_dims != 1: - return super().call_operator(op, args, kwargs, meta) + return False # Replace the linear with fully connected op - return super().call_operator( - self.linear_to_fc_op[op], - args, - kwargs, - meta, - ) + assert isinstance(node.target, EdgeOpOverload) + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + self.linear_to_fc_op[cast(EdgeOpOverload, node.target)], + args=node.args, + kwargs=node.kwargs, + ) + new_node.meta = node.meta + + node.replace_all_uses_with(new_node) + return True register_cadence_pass(CadencePassAttribute(opt_level=0))(ReplaceScalarWithTensorArgPass) @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceScalarTensorWithFullPass(ExportPass): +class ReplaceScalarTensorWithFullPass(RemoveOrReplacePassInterface): """ aten.scalar_tensor can be replaced by aten.full with a shape of [1]. 
scalar_tensor is not supported, so this is an opt_level=0 pass. """ - def call_operator( - self, - op, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in { - exir_ops.edge.aten.scalar_tensor.default, + @property + def targets(self) -> list[EdgeOpOverload]: + return [ torch.ops.aten.scalar_tensor.default, - }: - return super().call_operator(op, args, kwargs, meta) + exir_ops.edge.aten.scalar_tensor.default, + ] - return super().call_operator( - exir_ops.edge.aten.full.default, - ( - [1], - args[0], - ), - {"dtype": torch.float32}, - meta, - ) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.full.default, + args=( + [1], + node.args[0], + ), + kwargs={"dtype": torch.float32}, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceFullLikeWithFullPass(ExportPass): +class ReplaceFullLikeWithFullPass(RemoveOrReplacePassInterface): """ aten.full_like can be replaced by aten.full with the shape of the arg tensor. full_like is not supported, so this is an opt_level=0 pass. """ - def call_operator(self, op, args, kwargs, meta): - if op not in { - exir_ops.edge.aten.full_like.default, - }: - return super().call_operator(op, args, kwargs, meta) - - # Get the shape of the "like" tensor, and pass that in to the full op. - return super().call_operator( - exir_ops.edge.aten.full.default, - ( - ( - args[0].to_tensor().shape - if isinstance(args[0], ProxyValue) - else args[0].shape - ), - args[1], - ), - {}, - meta, - ) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.full_like.default] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + input_arg = node.args[0] + assert isinstance(input_arg, torch.fx.Node) + shape = input_arg.meta["val"].shape + fill_value = node.args[1] + + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.full.default, + args=(shape, fill_value), + kwargs={}, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceInfArgInFullWithValuePass(ExportPass): +class ReplaceInfArgInFullWithValuePass(RemoveOrReplacePassInterface): """ aten.full allows "-inf" and "inf" as inputs. The profiler cannot handle that, so replace them with the maximum value of the type. 
""" - def call_operator(self, op, args, kwargs, meta): - if op not in { - exir_ops.edge.aten.full.default, - }: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.full.default] - new_args = list(args) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: - if args[1] == float("-inf"): + new_args = list(node.args) + fill_value = node.args[1] + if fill_value == float("-inf"): new_args[1] = torch.finfo(torch.float32).min - elif args[1] == float("inf"): + elif fill_value == float("inf"): new_args[1] = torch.finfo(torch.float32).max + else: + return False - return super().call_operator(op, tuple(new_args), kwargs, meta) - - -@register_cadence_pass(CadencePassAttribute(opt_level=1)) -class ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass(ExportPass): - """ - Replace ops with single element arguments (size = [1]) with overloads that accept scalar ints/floats. - """ - - # Keep track of which operators and arguments are being replaced. - replaced_scalar_args: dict[ - EdgeOpOverloadPacket, tuple[EdgeOpOverload, Sequence[int]] - ] = { - exir_ops.edge.cadence.quantized_add: ( - exir_ops.edge.cadence.quantized_add.per_tensor, - [1, 2, 4, 5], - ), - exir_ops.edge.cadence.quantized_conv_nchw: ( - exir_ops.edge.cadence.quantized_conv_nchw.per_tensor, - [8, 9, 12, 13], - ), - exir_ops.edge.cadence.quantized_conv_nhwc: ( - exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor, - [8, 9, 12, 13], - ), - exir_ops.edge.cadence.quantized_fully_connected: ( - exir_ops.edge.cadence.quantized_fully_connected.per_tensor, - [4, 5, 6], - ), - exir_ops.edge.cadence.quantized_layer_norm: ( - exir_ops.edge.cadence.quantized_layer_norm.per_tensor, - [1, 2], - ), - exir_ops.edge.cadence.quantized_linear: ( - exir_ops.edge.cadence.quantized_linear.per_tensor, - [4, 5, 6], - ), - exir_ops.edge.cadence.quantized_relu: ( - exir_ops.edge.cadence.quantized_relu.per_tensor, - [1, 3, 4], - ), - exir_ops.edge.cadence.im2row: ( - exir_ops.edge.cadence.im2row.per_tensor, - [5], - ), - exir_ops.edge.cadence.requantize: ( - exir_ops.edge.cadence.requantize.per_tensor, - [1, 2, 3, 4], - ), - } - - def call_operator(self, op, args, kwargs, meta): - op_edge_overload_packet = get_edge_overload_packet(op) - - if op_edge_overload_packet not in self.replaced_scalar_args: - return super().call_operator(op, args, kwargs, meta) - - # Get all the args that need to be replaced. - new_op, args_to_be_replaced = self.replaced_scalar_args[op_edge_overload_packet] - - updated_args = list(args) - for op_arg_index in args_to_be_replaced: - arg = args[op_arg_index] - if not isinstance(arg, ProxyValue): - return super().call_operator(op, args, kwargs, meta) - - if not arg.is_tensor(): - return super().call_operator(op, args, kwargs, meta) - - if not isinstance(arg.node.target, EdgeOpOverload): - return super().call_operator(op, args, kwargs, meta) - - if get_edge_overload_packet(arg.node.target) != exir_ops.edge.aten.full: - # Only replace if arg generated by a full op. - return super().call_operator(op, args, kwargs, meta) - - if tuple(arg.node.args[0]) != (1,): - # Only replace if the size of the full op is [1]. 
- return super().call_operator(op, args, kwargs, meta) - - updated_args[op_arg_index] = arg.node.args[1] + new_args = tuple(new_args) - return super().call_operator( - new_op, - tuple(updated_args), - kwargs, - meta, - ) + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.full.default, + args=new_args, + kwargs=node.kwargs, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceAtenAvgPoolWithCadenceAvgPoolPass(ExportPass): +class ReplaceAtenAvgPoolWithCadenceAvgPoolPass(RemoveOrReplacePassInterface): """ Replace the aten avg_pool op with the cadence custom avg_pool2d op. """ - def call_operator(self, op, args, kwargs, meta): - # Only continue for avg_pool op - if op not in { + @property + def targets(self) -> list[EdgeOpOverload]: + return [ exir_ops.edge.aten.avg_pool1d.default, exir_ops.edge.aten.avg_pool2d.default, - }: - return super().call_operator(op, args, kwargs, meta) + ] + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Determine if the op is avg_pool1d or avg_pool2d - avg_pool1d: bool = op == exir_ops.edge.aten.avg_pool1d.default - # Get the input tensor - in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0] + avg_pool1d: bool = node.target == exir_ops.edge.aten.avg_pool1d.default + + # Get the input tensor node + in_tensor_node = node.args[0] + assert isinstance(in_tensor_node, torch.fx.Node) # Replace avg_pool2d with custom avg_pool2d, and if the input tensor is # quantized, pass its zero_point tensor as arg to the custom avg_pool2d. # stride, padding, ceil_mode, count_include_pad, divisor_override, are # the native avg_pool2d args. 'channel_last' denotes NCHW vs NHWC layout, # and is False by default. - kernel_size = args[1] - stride = args[2] if len(args) >= 3 else [1, 1] - padding = args[3] if len(args) >= 4 else [0, 0] - ceil_mode = args[4] if len(args) >= 5 else False - count_include_pad = args[5] if len(args) >= 6 else True - divisor_override = args[6] if len(args) >= 7 else None - zero_point = torch.tensor(0, dtype=torch.int32) + kernel_size = node.args[1] + # When stride is not provided or is empty, PyTorch defaults to kernel_size + stride = node.args[2] if len(node.args) >= 3 and node.args[2] else kernel_size + padding = node.args[3] if len(node.args) >= 4 else [0, 0] + ceil_mode = node.args[4] if len(node.args) >= 5 else False + count_include_pad = node.args[5] if len(node.args) >= 6 else True + divisor_override = node.args[6] if len(node.args) >= 7 else None + zero_point = node.args[7] if len(node.args) >= 8 else None + + graph = node.graph + out_shape = node.meta["val"].shape + + kernel_size = cast(Sequence[int], kernel_size) + stride = cast(Sequence[int], stride) + padding = cast(Sequence[int], padding) # If the op is avg_pool1d, then we need to reshape the 3d input to a 4d # tensor. 
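# Editorial sketch (not part of this diff): the 4-D reshape below works
# because avg_pool1d over (N, C, L) matches avg_pool2d over the (N, C, 1, L)
# view with a (1, k) kernel, which is exactly the shape/kernel extension done
# in the avg_pool1d branch. Plain torch check:
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 10)
ref = F.avg_pool1d(x, kernel_size=2)

out = F.avg_pool2d(x.reshape(2, 3, 1, 10), kernel_size=(1, 2)).reshape(ref.shape)
assert torch.allclose(ref, out)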
if avg_pool1d: - in_shape = list(in_tensor.shape) + in_shape = list(in_tensor_node.meta["val"].shape) assert len(in_shape) == 3, "Expected 3d input for avg_pool1d" - in_shape.insert(2, 1) - out_shape = meta["val"].shape - in_view_op = super().call_operator( - exir_ops.edge.aten.view_copy.default, - (in_tensor, in_shape), - kwargs, - meta, - ) + in_shape_4d = in_shape[:2] + [1] + in_shape[2:] + + with graph.inserting_before(node): + in_view_node = graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(in_tensor_node, in_shape_4d), + ) + in_view_node.meta = node.meta + # Extend the kernel_size, stride and padding to 2d - kernel_size = [1] + kernel_size if len(kernel_size) == 1 else kernel_size - stride = [1] + stride if len(stride) == 1 else stride - padding = [0] + padding if len(padding) == 1 else padding + kernel_size = ( + [1] + list(kernel_size) if len(kernel_size) == 1 else kernel_size + ) + stride = [1] + list(stride) if len(stride) == 1 else stride + padding = [0] + list(padding) if len(padding) == 1 else padding + + input_for_pool = in_view_node + else: + input_for_pool = in_tensor_node # Create a new avg_pool node with the updated args new_args = ( - in_view_op if avg_pool1d else args[0], + input_for_pool, kernel_size, stride, padding, @@ -1757,70 +1890,66 @@ def call_operator(self, op, args, kwargs, meta): zero_point, False, ) - avg_pool2d_op = super().call_operator( - exir_ops.edge.cadence.avg_pool2d.default, - new_args, - kwargs, - meta, - ) - # If the node was avg_pool1d, we again reshape the 4d output to 3d output - return ( - super().call_operator( - exir_ops.edge.aten.view_copy.default, - (avg_pool2d_op, list(out_shape)), - kwargs, - meta, + with graph.inserting_before(node): + avg_pool2d_node = graph.call_function( + exir_ops.edge.cadence.avg_pool2d.default, + args=new_args, ) - if avg_pool1d - else avg_pool2d_op - ) + avg_pool2d_node.meta = node.meta + + # If the node was avg_pool1d, we again reshape the 4d output to 3d output + if avg_pool1d: + with graph.inserting_before(node): + result_node = graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(avg_pool2d_node, list(out_shape)), + ) + result_node.meta = node.meta + node.replace_all_uses_with(result_node) + else: + node.replace_all_uses_with(avg_pool2d_node) + + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class ReplaceIm2RowWithViewPass(ExportPass): - def can_replace(self, op, args, kwargs, meta) -> bool: - if op != exir_ops.edge.cadence.im2row.default: - return False +class ReplaceIm2RowWithViewPass(RemoveOrReplacePassInterface): + """ + Replace im2row with view when possible (no padding, no dilation, and output spatial dimensions are 1). + """ + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.cadence.im2row.default] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Check if im2row applies padding. If yes, we cannot replace it with view. - pad = cast(tuple[int, ...], args[3]) + pad = cast(Sequence[int], node.args[3]) if any(p != 0 for p in pad): return False # Check if im2row has dilation. If yes, we cannot replace it with view. - dilation = cast(tuple[int, ...], args[2]) + dilation = cast(Sequence[int], node.args[2]) if any(d != 1 for d in dilation): return False # im2row works on 3D or 4D tensors. # Output shape[1:-1] will be unit if input spatial dimensions are the same as kernel spatial dimensions. 
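# Editorial sketch (not part of this diff): with no padding or dilation and a
# kernel that spans the full spatial extent, im2row-style unfolding produces a
# single patch per sample, i.e. just a reshape of the input, which is why the
# view replacement below is safe. Shown with aten unfold as an analogue of
# cadence.im2row (the exact output layout of the cadence op is assumed, not
# verified here):
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 5, 7)
cols = F.unfold(x, kernel_size=(5, 7))      # [N, C*5*7, 1]: one patch per sample

assert torch.equal(cols, x.reshape(2, 3 * 5 * 7, 1))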
- output_shape = meta["val"].shape - if math.prod(output_shape[1:-1]) == 1: - return True + output_shape = node.meta["val"].shape + if math.prod(output_shape[1:-1]) != 1: + return False - return False + # Replace im2row with view_copy + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(node.args[0], list(output_shape)), + ) + new_node.meta = node.meta - def call_operator( - self, - op, - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.cadence.im2row.default: - return super().call_operator(op, args, kwargs, meta) - - if not self.can_replace(op, args, kwargs, meta): - return super().call_operator(op, args, kwargs, meta) - - output_shape = meta["val"].shape - return super().call_operator( - exir_ops.edge.aten.view_copy.default, - (args[0], tuple(output_shape)), - kwargs, - meta, - ) + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) @@ -1839,98 +1968,141 @@ def call_operator(self, op, args, kwargs, meta): return super().call_operator(op, args, kwargs, meta) def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - ret = super().call(graph_module) - modified = ret.graph_module.graph.eliminate_dead_code() or ret.modified - return PassResult(ret.graph_module, modified) + changed = False + for module in filter( + lambda m: isinstance(m, torch.fx.GraphModule), graph_module.modules() + ): + module = cast(torch.fx.GraphModule, module) + for node in module.graph.nodes: + if node.op != "call_function": + continue + val = node.meta.get("val", None) + if isinstance(val, torch.Tensor) and val.numel() == 0: + with module.graph.inserting_before(node): + new_node = module.graph.call_function( + exir_ops.edge.aten.full.default, + args=(val.shape, 0), + kwargs={"dtype": val.dtype}, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + changed = True + + if changed: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return super().call(graph_module) + + return PassResult(graph_module, False) @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class ReplaceWhereWithFullArgsWithWhereScalar(ExportPass): +class ReplaceWhereWithFullArgsWithWhereScalar(RemoveOrReplacePassInterface): """Replaces where ops using two full ops as tensors with a scalar version. """ - def call_operator( - self, - op, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in { - exir_ops.edge.aten.where.self, - }: - return super().call_operator(op, args, kwargs, meta) - - # If the args are not full ops, bail - # pyre-ignore[16]: `ProxyValue` has no attribute `node`. - if (args[1].node.target != exir_ops.edge.aten.full.default) or ( - args[2].node.target != exir_ops.edge.aten.full.default - ): - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.where.self] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # Check if args[1] and args[2] are full ops + arg1 = node.args[1] + arg2 = node.args[2] + + if not isinstance(arg1, torch.fx.Node) or not isinstance(arg2, torch.fx.Node): + return False - # If one of the full ops is a different size than than the cond tensor, we need to broadcast. Bail. if ( - # pyre-ignore[16]: `ProxyValue` has no attribute `node`. 
- list(args[0].to_tensor().shape) != args[1].node.args[0] - or list(args[0].to_tensor().shape) != args[2].node.args[0] + arg1.target != exir_ops.edge.aten.full.default + or arg2.target != exir_ops.edge.aten.full.default ): - return super().call_operator(op, args, kwargs, meta) + return False - # Get the scalar values from the full ops - scalar_value_1 = args[1].node.args[1] - scalar_value_2 = args[2].node.args[1] + # Get the condition tensor shape + cond_arg = node.args[0] + assert isinstance(cond_arg, torch.fx.Node) + cond_shape = list(cond_arg.meta["val"].shape) - # Replace the where op with a scalar where op - return super().call_operator( - exir_ops.edge.cadence.where_Scalar.default, - (args[0], scalar_value_1, scalar_value_2), - kwargs, - meta, - ) + # Check if the full ops have the same size as the cond tensor + full1_shape = arg1.args[0] + full2_shape = arg2.args[0] - return super().call_operator(op, args, kwargs, meta) + if cond_shape != full1_shape or cond_shape != full2_shape: + return False + # Get the scalar values from the full ops + scalar_value_1 = arg1.args[1] + scalar_value_2 = arg2.args[1] -@register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceAtenApproxGeluWithApproxGeluPass(ExportPass): - """ - Replace the aten gelu op with an approximate arg with an approximate gelu op. - """ - - def call_operator( - self, - op, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in { - exir_ops.edge.aten.gelu.default, - }: - return super().call_operator(op, args, kwargs, meta) - return super().call_operator(op, args, kwargs, meta) + # Replace the where op with a scalar where op + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.cadence.where_Scalar.default, + args=(cond_arg, scalar_value_1, scalar_value_2), + ) + new_node.meta = node.meta + + node.replace_all_uses_with(new_node) + return True # Adapted from fbcode/pyspeech/opt_passes/replace_ops.py @register_cadence_pass(CadencePassAttribute(opt_level=2)) -class ReplaceSplitWithSlicePass(ExportPass): +class ReplaceSplitWithSlicePass(RemoveOrReplacePassInterface): """ split_with_sizes() delegates to slice() op, so perform this replacement here. This avoids the expense of delegation from ATen. """ - # For split_with_sizes, return the slice dim and extent for each split. 
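# Editorial sketch (not part of this diff): each chunk produced by
# split_with_sizes is a plain slice of the input along the split dim, so every
# getitem user can be rewritten to a slice_copy whose (start, end) comes from
# the running sum of the split sizes. Plain torch check:
import torch

x = torch.randn(10, 4)
sizes = [2, 3, 5]

start = 0
for size, chunk in zip(sizes, torch.split(x, sizes, dim=0)):
    assert torch.equal(chunk, x[start:start + size])  # slice along dim 0
    start += size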
- def get_split_sizes( - self, graph_module: torch.fx.GraphModule, node: torch.fx.Node - ) -> Optional[list[tuple[int, ...]]]: + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.split_with_sizes_copy.default] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # All the users of this split_with_sizes op must be getitem ops + if any(user.target != operator.getitem for user in node.users): + return False + + # Get the slice dim and extent for each split + slice_ops = self._get_split_sizes(node) + if slice_ops is None: + return False + + graph = node.graph + + # Go over each getitem user, and replace it with slice op + for user in list(node.users.keys()): + assert user.target == operator.getitem + item_idx = int(user.args[1]) + assert item_idx < len(slice_ops) + cur_slice = slice_ops[item_idx] + with graph.inserting_before(user): + cur_slice_node = graph.call_function( + exir_ops.edge.aten.slice_copy.Tensor, + (node.args[0], cur_slice[0], cur_slice[1], cur_slice[2], 1), + ) + # Metadata copy important + cur_slice_node.meta = user.meta + user.replace_all_uses_with(cur_slice_node) + + # Return True to indicate the split node should be removed + return True + + def _get_split_sizes(self, node: torch.fx.Node) -> Optional[list[tuple[int, ...]]]: + """For split_with_sizes, return the slice dim and extent for each split.""" # Parse the args of the split_with_sizes op tensor_arg, split_sizes = node.args[0:2] assert isinstance(tensor_arg, torch.fx.Node) - in_shape = get_shape(graph_module, tensor_arg) - split_dim = 0 if len(node.args) < 3 else node.args[2] - if in_shape is None: + + # Get shape from node metadata + val = tensor_arg.meta.get("val") + if val is None: return None + in_shape = val.shape + + split_dim = 0 if len(node.args) < 3 else node.args[2] # Canonicalize the split dimension assert isinstance(split_dim, int) @@ -1948,103 +2120,69 @@ def get_split_sizes( return slice_ops - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - graph = graph_module.graph - for node in graph.nodes: - if not isinstance(node.target, EdgeOpOverload): - continue - if ( - get_edge_overload_packet(node.target) - != exir_ops.edge.aten.split_with_sizes_copy - ): - continue - # All the users of this split_with_sizes op must be getitem ops - if any(user.target != operator.getitem for user in node.users): - continue - - # Get the slice dim and extent for each split - slice_ops = self.get_split_sizes(graph_module, node) - if slice_ops is None: - continue - # Go over each getitem user, and replace it with slice op - for user in list(node.users.keys()): - assert user.target == operator.getitem - item_idx = user.args[1] - assert item_idx < len(slice_ops) - cur_slice = slice_ops[item_idx] - with graph.inserting_before(user): - cur_slice_node = graph.call_function( - exir_ops.edge.aten.slice_copy.Tensor, - (node.args[0], cur_slice[0], cur_slice[1], cur_slice[2], 1), - ) - user.replace_all_uses_with(cur_slice_node) - graph.erase_node(user) +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class ReplacePowWithMulPass(RemoveOrReplacePassInterface): + """ + Replace the pow op with successive mul ops when the exponent is an + integer between 2 and 4 (inclusive). 
+ """ - graph.erase_node(node) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.pow.Tensor_Scalar] - graph_module.recompile() - result = super().call(graph_module) - return result + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # Check if we have at least 2 args and the exponent is an int + if len(node.args) < 2 or not isinstance(node.args[1], int): + return False + exponent = cast(int, node.args[1]) -@register_cadence_pass(CadencePassAttribute(opt_level=1)) -class ReplacePowWithMulPass(ExportPass): - """ - Replace the pow op for a mul op. - """ + # Only replace if exponent is between 2 and 4 (inclusive) + if exponent < 2 or exponent > 4: + return False - def call_operator( - self, - op, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if not ( - len(args) > 1 - and isinstance(args[1], int) - and cast(int, args[1]) > 1 - and cast(int, args[1]) < 5 - and op - in { - exir_ops.edge.aten.pow.Tensor_Scalar, - } - ): - return super().call_operator(op, args, kwargs, meta) + x = node.args[0] + assert isinstance(x, torch.fx.Node) - x = args[0] - exponent = cast(int, args[1]) + graph = node.graph + result_node = x - if exponent > 2: - for _ in range(exponent, 2, -1): - x = super().call_operator( + # Create successive mul operations + # For exponent=2: x * x (1 mul) + # For exponent=3: (x * x) * x (2 muls) + # For exponent=4: ((x * x) * x) * x (3 muls) + for _ in range(exponent - 1): + with graph.inserting_before(node): + result_node = graph.call_function( exir_ops.edge.aten.mul.Tensor, - (x, args[0]), - {}, - meta, + args=(result_node, x), ) - return super().call_operator( - exir_ops.edge.aten.mul.Tensor, - (x, args[0]), - {}, - meta, - ) + result_node.meta = node.meta + + node.replace_all_uses_with(result_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceMatmulWithTransposedMatmulPass(ExportPass): +class ReplaceMatmulWithTransposedMatmulPass(RemoveOrReplacePassInterface): """ For certain backends, we have efficient kernels for transposed matmul. We replace AxB with AxB' for such backends. 
""" - def call_operator(self, op, args, kwargs, meta): - if op != exir_ops.edge.cadence.quantized_matmul.default or args[-1] is True: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.cadence.quantized_matmul.default] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # If already transposed, bail + if len(node.args) >= 9 and node.args[-1] is True: + return False # Get the args - if len(args) == 9: + if len(node.args) == 9: ( X_arg, X_zero_point, @@ -2055,8 +2193,8 @@ def call_operator(self, op, args, kwargs, meta): out_shift, out_zero_point, transposed, - ) = args - elif len(args) == 8: + ) = node.args + elif len(node.args) == 8: ( X_arg, X_zero_point, @@ -2066,42 +2204,43 @@ def call_operator(self, op, args, kwargs, meta): out_multiplier, out_shift, out_zero_point, - ) = args + ) = node.args transposed = False else: raise AssertionError( - f"Unexpected number of args for quantized_matmul: {len(args)}" + f"Unexpected number of args for quantized_matmul: {len(node.args)}" ) # If the matmul is already transposed, bail if transposed: - return super().call_operator(op, args, kwargs, meta) - - # Get the second tensor - Y_tensor = Y_arg.to_tensor() if isinstance(Y_arg, ProxyValue) else Y_arg - # Concretize the bias - zero_bias = super().call_operator( - exir_ops.edge.aten.full.default, - ([Y_tensor.size(-1)], 0), - {"dtype": torch.int32}, - meta, - ) + return False + + # Get the second tensor from metadata + assert isinstance(Y_arg, torch.fx.Node) + Y_tensor_val = Y_arg.meta.get("val") + if Y_tensor_val is None: + return False + + graph = node.graph + + # Create zero bias + with graph.inserting_before(node): + zero_bias = graph.call_function( + exir_ops.edge.aten.full.default, + args=([Y_tensor_val.size(-1)], 0), + kwargs={"dtype": torch.int32}, + ) + zero_bias.meta = node.meta - # If the arg was a ProxyValue, insert a transpose node. Otherwise we - # can simply transpose the tensor inplace. - if isinstance(Y_arg, ProxyValue): - transpose_args = (Y_arg, -1, -2) - transpose_node = super().call_operator( + # Transpose Y_arg + with graph.inserting_before(node): + Y_arg_t = graph.call_function( exir_ops.edge.aten.transpose_copy.int, - transpose_args, - {}, - meta, + args=(Y_arg, -1, -2), ) - Y_arg_t = transpose_node - else: - Y_arg_t = Y_tensor.transpose(-1, -2) + Y_arg_t.meta = node.meta - # Construct the new args, and return the transposed matmult op + # Construct the new args, and create the transposed matmul op new_args = ( X_arg, X_zero_point, @@ -2113,68 +2252,83 @@ def call_operator(self, op, args, kwargs, meta): out_zero_point, True, ) - return super().call_operator(op, new_args, kwargs, meta) + + with graph.inserting_before(node): + new_node = graph.call_function( + exir_ops.edge.cadence.quantized_matmul.default, + args=new_args, + kwargs=node.kwargs, + ) + new_node.meta = node.meta + + node.replace_all_uses_with(new_node) + return True def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = False result = super().call(graph_module) - # Fuse any inserted transpose node with transpose/permute nodes - # surrounding it. - result = FuseCascadedTransposeOrPermuteOps()(result.graph_module) - assert result is not None - # Replace permute with transpose. 
- result = ReplacePermuteWithTransposePass()(result.graph_module) - assert result is not None - return result + modified = modified or result.modified + if modified: + # Fuse any inserted transpose node with transpose/permute nodes + # surrounding it. + result = FuseCascadedTransposeOrPermuteOps().call(result.graph_module) + modified = modified or result.modified + # Replace permute with transpose. + result = ReplacePermuteWithTransposePass().call(result.graph_module) + modified = modified or result.modified + + return PassResult(result.graph_module, modified) @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class ReplaceMulTensorWithMulAndFullOpsPass(ExportPass): +class ReplaceMulTensorWithMulAndFullOpsPass(RemoveOrReplacePassInterface): """ Extracts a single value argument of mul op to a separate full op. """ - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - for mul_node in graph_module.graph.find_nodes( - op="call_function", target=torch.ops.aten.mul.Tensor - ): - x_arg, const_arg = mul_node.args + @property + def targets(self) -> list[EdgeOpOverload]: + return [torch.ops.aten.mul.Tensor] - # Swap arguments if the order is wrong - if isinstance(const_arg, torch.fx.Node): - x_arg, const_arg = const_arg, x_arg + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + x_arg, const_arg = node.args - # Skip if the const_arg is not a scalar - if not isinstance(const_arg, (float, int)) or not isinstance( - x_arg, torch.fx.Node - ): - continue + # Swap arguments if the order is wrong + if isinstance(const_arg, torch.fx.Node): + x_arg, const_arg = const_arg, x_arg - # Cast the const_arg to the dtype of the x_arg - full_arg = self.resolve_full_arg(x_arg, const_arg) + # Skip if the const_arg is not a scalar + if not isinstance(const_arg, (float, int)) or not isinstance( + x_arg, torch.fx.Node + ): + return False - full_output_dtype = ( - torch.int32 if isinstance(full_arg, int) else torch.float32 - ) + # Cast the const_arg to the dtype of the x_arg + full_arg = self.resolve_full_arg(x_arg, const_arg) - # Extract an argument to a separate full op. - with graph_module.graph.inserting_before(mul_node): - full_node = graph_module.graph.call_function( - torch.ops.aten.full.default, - args=([1], full_arg), - kwargs={"dtype": full_output_dtype}, - ) - full_node.meta = mul_node.meta - full_node.meta["val"] = [1] - new_mul_node = graph_module.graph.call_function( - torch.ops.aten.mul.Tensor, args=(x_arg, full_node) - ) - new_mul_node.meta = mul_node.meta - # Replace the old mul with a newly created mul. - mul_node.replace_all_uses_with(new_mul_node) - graph_module.graph.erase_node(mul_node) - return super().call(graph_module) + full_output_dtype = torch.int32 if isinstance(full_arg, int) else torch.float32 - def resolve_full_arg(self, x_arg, const_arg): + # Extract an argument to a separate full op. + with node.graph.inserting_before(node): + full_node = node.graph.call_function( + torch.ops.aten.full.default, + args=([1], full_arg), + kwargs={"dtype": full_output_dtype}, + ) + full_node.meta = node.meta + full_node.meta["val"] = [1] + new_mul_node = node.graph.call_function( + torch.ops.aten.mul.Tensor, args=(x_arg, full_node) + ) + new_mul_node.meta = node.meta + # Replace the old mul with a newly created mul. 
+ node.replace_all_uses_with(new_mul_node) + node.graph.erase_node(node) + return True + + def resolve_full_arg( + self, x_arg: torch.fx.Node, const_arg: float | int + ) -> float | int: if x_arg.meta["val"].dtype == torch.float32 and isinstance(const_arg, int): const_arg = float(const_arg) if x_arg.meta["val"].dtype == torch.int32 and isinstance(const_arg, float): @@ -2183,50 +2337,50 @@ def resolve_full_arg(self, x_arg, const_arg): @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(ExportPass): +class ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(RemoveOrReplacePassInterface): """ Replace the aten adaptive avg_pool op with the aten avg_pool2d op. """ - def call_operator(self, op, args, kwargs, meta): - # Only continue for avg_pool op - if op not in {exir_ops.edge.aten._adaptive_avg_pool2d.default}: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten._adaptive_avg_pool2d.default] - # Get the input tensor - in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0] - # Permute NCHW to NHWC for computation - in_tensor_permuted = in_tensor.permute(0, 2, 3, 1) - in_tensor_shape = in_tensor_permuted.shape + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # Get the input tensor node + in_tensor_node = node.args[0] + assert isinstance(in_tensor_node, torch.fx.Node) - output_size = args[1] + # Get input shape (in NCHW format) + in_shape = in_tensor_node.meta["val"].shape + output_size = cast(Sequence[int], node.args[1]) num_dims = len(output_size) + # Spatial dimensions are at indices [2:] for NCHW format # TODO: If in_tensor_shape is not a multiple of output size, # this pass will not work. T224984800 dim_multiples = [ - (in_tensor_shape[i + 1] % output_size[i]) == 0 for i in range(num_dims) + (in_shape[i + 2] % output_size[i]) == 0 for i in range(num_dims) ] if not all(dim_multiples): logging.info( - f"Unable to replace adaptive average pool with average pool. Input tensor shape of {in_tensor_shape} is not a multiple of output size: {output_size}" + f"Unable to replace adaptive average pool with average pool. 
Input tensor shape of {in_shape} is not a multiple of output size: {output_size}" ) - return super().call_operator(op, args, kwargs, meta) + return False - # Compute stride and kernel_size, then set default values for other arguments - stride = [(in_tensor_shape[i + 1] // output_size[i]) for i in range(num_dims)] + # Compute stride and kernel_size based on spatial dimensions + stride = [(in_shape[i + 2] // output_size[i]) for i in range(num_dims)] kernel_size = [ - in_tensor_shape[i + 1] - (output_size[i] - 1) * stride[i] - for i in range(num_dims) + in_shape[i + 2] - (output_size[i] - 1) * stride[i] for i in range(num_dims) ] padding = [0] * num_dims ceil_mode = False count_include_pad = True divisor_override = None - # Create a new avg_pool node with the updated args + # Create a new avg_pool2d node with the computed args new_args = ( - args[0], + in_tensor_node, kernel_size, stride, padding, @@ -2234,38 +2388,101 @@ def call_operator(self, op, args, kwargs, meta): count_include_pad, divisor_override, ) - return super().call_operator( - exir_ops.edge.aten.avg_pool2d.default, - new_args, - kwargs, - meta, - ) + + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.avg_pool2d.default, + args=new_args, + ) + new_node.meta = node.meta + + node.replace_all_uses_with(new_node) + return True + + +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class ReplaceTorchQuantizedEmbeddingWithCadenceQuantizedEmbedding( + RemoveOrReplacePassInterface +): + """ + Replace torch.ops.quantized_decomposed.embedding_byte.dtype with + torch.ops.cadence.quantized_embedding_byte + """ + + @property + def targets(self) -> list[EdgeOpOverload]: + return [ + exir_ops.edge.quantized_decomposed.embedding_byte.default, + exir_ops.edge.quantized_decomposed.embedding_byte.dtype, + ] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # Replace with cadence.quantized_embedding_byte + if len(node.args) < 6: + raise AssertionError( + f"Expected 6 arguments for embedding_byte, got {len(node.args)}" + ) + embedding = node.args[0] + scales = node.args[1] + weight_zero_points = node.args[2] + indices = node.args[5] + + if node.target == exir_ops.edge.quantized_decomposed.embedding_byte.dtype: + dtype = node.kwargs.get("dtype", None) + if dtype is not None and dtype != torch.float32: + raise AssertionError( + f"Unsupported output dtype for embedding_byte: {dtype}" + ) + + new_args = (embedding, scales, weight_zero_points, indices, False) + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.cadence.quantized_embedding_byte.default, + args=new_args, + ) + new_node.meta = node.meta + + node.replace_all_uses_with(new_node) + return True class CommonReplacePasses: passes = [ + ReplaceScalarWithTensorArgPass, ReplaceSqueezeAndUnsqueezeWithViewPass, ReplaceSplitWithSlicePass, ReplaceSelectWithViewOpPass, ReplaceMMWithAddMMPass, ReplaceRepeatWithCatPass, ReplaceFullLikeWithFullPass, + ReplaceAtenConvolutionWithCadenceConvolutionPass, + ReplacePT2QuantWithCadenceQuantPass, + ReplacePT2DequantWithCadenceDequantPass, + ReplacePowWithMulPass, + ReplaceTorchQuantizedEmbeddingWithCadenceQuantizedEmbedding, ] @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass(ExportPass): +class ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass(RemoveOrReplacePassInterface): """ Replace aten linalg svd op with cadence custom op. 
""" - def call_operator(self, op, args, kwargs, meta): - if op != exir_ops.edge.aten._linalg_svd.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten._linalg_svd.default] - return super().call_operator( - exir_ops.edge.cadence.linalg_svd.default, args, kwargs, meta - ) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.cadence.linalg_svd.default, + args=node.args, + kwargs=node.kwargs, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True # This class encapsulates all the functions that replace/switch one op in the @@ -2276,13 +2493,11 @@ class CadenceReplaceOpsInGraph: ReplaceEmptyTensorsWithFullPass, ReplaceFunctionallyEquivalentOpTargets, ReplacePermuteWithTransposePass, - ReplaceScalarWithTensorArgPass, ReplaceConvolutionOptionalArgsWithConcreteArgsPass, ReplaceAddMMWithLinearPass, RemoveNopSelectOpPass, ReplacePadWithCatPass, ReplaceConstantPadNdWithSlicePass, - ReplaceAtenConvolutionWithCadenceConvolutionPass, ReplaceConvWithChannelLastConvPass, ReplaceTrivialConvWithLinear, ReplaceConvWithIm2RowAndLinear, @@ -2296,13 +2511,8 @@ class CadenceReplaceOpsInGraph: ReplaceScalarTensorWithFullPass, ReplaceInfArgInFullWithValuePass, ReplaceLogicalNotBooleanWhereWithWherePass, - ReplacePT2QuantWithCadenceQuantPass, - ReplacePT2DequantWithCadenceDequantPass, - ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass, ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass, ReplaceAtenAvgPoolWithCadenceAvgPoolPass, ReplaceWhereWithFullArgsWithWhereScalar, - ReplaceAtenApproxGeluWithApproxGeluPass, - ReplacePowWithMulPass, ReplaceMulTensorWithMulAndFullOpsPass, ] diff --git a/backends/cadence/aot/simplify_ops.py b/backends/cadence/aot/simplify_ops.py index bf836f09044..92c14cb0f5d 100644 --- a/backends/cadence/aot/simplify_ops.py +++ b/backends/cadence/aot/simplify_ops.py @@ -19,7 +19,7 @@ from executorch.backends.cadence.aot.utils import rebind from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload -from executorch.exir.pass_base import ExportPass, ProxyValue +from executorch.exir.pass_base import ExportPass @register_cadence_pass(CadencePassAttribute(opt_level=0)) @@ -75,7 +75,7 @@ def call_operator(self, op, args, kwargs, meta): slice_scatter = op == exir_ops.edge.aten.slice_scatter.default # Parse the arguments # Extract the tensor to be sliced, and the slicing dimension - in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0] + in_tensor = args[0].to_tensor() dim = args[1 + slice_scatter] if len(args) > 1 + slice_scatter else 0 # Make dim non-negative dim = dim if dim >= 0 else dim + in_tensor.dim() diff --git a/backends/cadence/aot/tests/test_fusion_ops_passes.py b/backends/cadence/aot/tests/test_fusion_ops_passes.py index d160a02721a..d6c27f60bd3 100644 --- a/backends/cadence/aot/tests/test_fusion_ops_passes.py +++ b/backends/cadence/aot/tests/test_fusion_ops_passes.py @@ -7,6 +7,7 @@ # pyre-strict +import copy import unittest from typing import cast, Final, List, Tuple @@ -29,6 +30,46 @@ from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.pass_base import PassResult, ProxyValue +from torch.utils import _pytree as pytree + + +def validate_numerics( + original: torch.fx.GraphModule, + 
modified: torch.fx.GraphModule, + inputs: tuple[torch.Tensor, ...] | list[torch.Tensor], + pass_name: str, + rtol: float = 1e-5, + atol: float = 1e-6, +) -> None: + """Validate that two graph modules produce numerically equivalent outputs. + + Args: + original: The original graph module before the pass + modified: The modified graph module after the pass + inputs: Input tensors to run through both graphs + pass_name: Name of the pass being validated (for error messages) + rtol: Relative tolerance for allclose comparison + atol: Absolute tolerance for allclose comparison + """ + original.eval() + modified.eval() + with torch.no_grad(): + orig_out = original(*inputs) + mod_out = modified(*inputs) + + flat_orig_out, _ = pytree.tree_flatten(orig_out) + flat_mod_out, _ = pytree.tree_flatten(mod_out) + + # Check that outputs match within tolerance + for i, (orig_tensor, mod_tensor) in enumerate(zip(flat_orig_out, flat_mod_out)): + if not torch.allclose(orig_tensor, mod_tensor, rtol=rtol, atol=atol): + max_diff = torch.max(torch.abs(orig_tensor - mod_tensor)).item() + raise AssertionError( + f"Pass validation failed for pass {pass_name}. " + f"Output tensor {i} differs by max {max_diff:.6e}. " + f"Expected rtol={rtol}, atol={atol}. " + f"Original output: {orig_tensor}, Modified output: {mod_tensor}" + ) class TestFusionPassesBase(unittest.TestCase): @@ -202,7 +243,8 @@ def test_keep_mm_add_with_multiple_users(self) -> None: class TestFusionPasses(TestFusionPassesBase): def test_permute_transpose_fusion(self) -> None: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(3, 1, 3, 1, 4, dtype=torch.float32)) + x_input = torch.randn(3, 1, 3, 1, 4, dtype=torch.float32) + x = builder.placeholder("x", x_input) permute = builder.call_operator( op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 4, 1, 3]) ) @@ -212,8 +254,11 @@ def test_permute_transpose_fusion(self) -> None: ) builder.output([output]) original_graph = builder.get_graph_module() + graph_copy = copy.deepcopy(original_graph) p = FuseCascadedTransposeOrPermuteOps() - converted_graph = cast(PassResult, p(original_graph)).graph_module + result = p.call(original_graph) + self.assertTrue(result.modified) + converted_graph = result.graph_module converted_graph.graph.eliminate_dead_code() # Assert that permute op was fused with transpose op self.assertEqual( @@ -222,10 +267,14 @@ def test_permute_transpose_fusion(self) -> None: self.assertEqual( count_node(converted_graph, exir_ops.edge.aten.transpose_copy.int), 0 ) + validate_numerics( + graph_copy, converted_graph, (x_input,), "FuseCascadedTransposeOrPermuteOps" + ) def test_view_fusion(self) -> None: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(8, 5, 3, dtype=torch.float32)) + x_input = torch.randn(8, 5, 3, dtype=torch.float32) + x = builder.placeholder("x", x_input) view1 = builder.call_operator( op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 8, 15]) ) @@ -237,9 +286,17 @@ def test_view_fusion(self) -> None: ) builder.output([output]) original_graph = builder.get_graph_module() + + gm_before = copy.deepcopy(original_graph) p = FuseCascadedViewOps() - converted_graph = cast(PassResult, p(original_graph)).graph_module - converted_graph.graph.eliminate_dead_code() + result = cast(PassResult, p(original_graph)) + self.assertTrue(result.modified) + converted_graph = result.graph_module + + # Validate numerical accuracy + inputs = [x_input] + validate_numerics(gm_before, converted_graph, inputs, "FuseCascadedViewOps") + # Assert that only one view 
op remains self.assertEqual( count_node(converted_graph, exir_ops.edge.aten.view_copy.default), 1 @@ -247,7 +304,8 @@ def test_view_fusion(self) -> None: def test_view_fusion_branched(self) -> None: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(8, 5, 3, dtype=torch.float32)) + x_input = torch.randn(8, 5, 3, dtype=torch.float32) + x = builder.placeholder("x", x_input) y = builder.call_operator( op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 8, 15]) ) @@ -259,9 +317,17 @@ def test_view_fusion_branched(self) -> None: ) builder.output([z, t]) original_graph = builder.get_graph_module() + + gm_before = copy.deepcopy(original_graph) p = FuseCascadedViewOps() - converted_graph = cast(PassResult, p(original_graph)).graph_module - converted_graph.graph.eliminate_dead_code() + result = cast(PassResult, p(original_graph)) + self.assertTrue(result.modified) + converted_graph = result.graph_module + + # Validate numerical accuracy + inputs = [x_input] + validate_numerics(gm_before, converted_graph, inputs, "FuseCascadedViewOps") + # z and t should be fused and y should be eliminated. self.assertEqual( count_node(converted_graph, exir_ops.edge.aten.view_copy.default), 2 diff --git a/backends/cadence/aot/tests/test_memory_passes.py b/backends/cadence/aot/tests/test_memory_passes.py index 41f903ccf06..6c8da2202d4 100644 --- a/backends/cadence/aot/tests/test_memory_passes.py +++ b/backends/cadence/aot/tests/test_memory_passes.py @@ -947,6 +947,110 @@ def test_cat_then_cat(self) -> None: self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0) self.verify_nop_memory_alloc(graph_module) + def test_cat_with_duplicate_input_tensor(self) -> None: + """ + Test that cat is NOT optimized when the same tensor appears multiple + times in the cat input list. This is because we cannot place the same + tensor at multiple different offsets relative to the output. + """ + builder = GraphBuilder() + x = builder.placeholder("x", torch.ones(3, 6, dtype=torch.float32)) + to_add_to_x = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=([3, 6], 123.0), + kwargs={"dtype": torch.float32}, + ) + add_x = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, + args=(x, to_add_to_x), + ) + pre_created_output = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=([6, 6], 0.0), + kwargs={"dtype": torch.float32}, + ) + # Same tensor (add_x) appears twice in the cat inputs + cat = builder.call_operator( + op=torch.ops.aten.cat.out, + args=([add_x, add_x],), + kwargs={"dim": 0, "out": pre_created_output}, + ) + builder.output([cat]) + original = builder.get_graph_module() + graph_module = self.run_memory_planning(original) + graph_module.graph.eliminate_dead_code() + + # Assert that cat op is NOT optimized away since the same tensor + # appears multiple times in the input list + self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1) + self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0) + self.verify_nop_memory_alloc(graph_module) + + def test_cat_with_tensor_having_existing_constraint(self) -> None: + """ + Test that the second cat is NOT optimized when a tensor already has a + relative placement constraint from a previous cat operation. 
+ """ + builder = GraphBuilder() + x = builder.placeholder("x", torch.ones(8, 8, dtype=torch.float32)) + to_add = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=([8, 8], 1.0), + kwargs={"dtype": torch.float32}, + ) + x1 = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, + args=(x, to_add), + ) + x2 = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, + args=(x1, to_add), + ) + x3 = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, + args=(x2, to_add), + ) + # First cat: cat(x1, x2) - this will give x1 and x2 relative placement constraints + pre_created_output1 = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=([16, 8], 0.0), + kwargs={"dtype": torch.float32}, + ) + cat1 = builder.call_operator( + op=torch.ops.aten.cat.out, + args=([x1, x2],), + kwargs={"dim": 0, "out": pre_created_output1}, + ) + # Second cat: cat(x2, x3) - x2 already has a constraint from cat1, + # so this cat cannot be optimized + pre_created_output2 = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=([16, 8], 0.0), + kwargs={"dtype": torch.float32}, + ) + cat2 = builder.call_operator( + op=torch.ops.aten.cat.out, + args=([x2, x3],), + kwargs={"dim": 0, "out": pre_created_output2}, + ) + # Use both cat results to keep them alive + graph_output = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, + args=(cat1, cat2), + ) + builder.output([graph_output]) + original = builder.get_graph_module() + graph_module = self.run_memory_planning( + original, opt_level=3, alloc_graph_input=False + ) + graph_module.graph.eliminate_dead_code() + + # The first cat should be optimized to _cat_nop, but the second cat + # cannot be optimized because x2 already has a relative placement constraint + self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1) + self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1) + self.verify_nop_memory_alloc(graph_module) + def test_view_for_unallocated_output(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.ones(3, 5, dtype=torch.float32)) diff --git a/backends/cadence/aot/tests/test_quantizer_ops.py b/backends/cadence/aot/tests/test_quantizer_ops.py new file mode 100644 index 00000000000..99953346b05 --- /dev/null +++ b/backends/cadence/aot/tests/test_quantizer_ops.py @@ -0,0 +1,275 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +import inspect +import unittest +from typing import Callable + +import torch +from executorch.backends.cadence.aot.graph_builder import GraphBuilder +from executorch.backends.cadence.aot.quantizer import quantizer as quantizer_module +from executorch.backends.cadence.aot.quantizer.patterns import AddmmPattern +from executorch.backends.cadence.aot.quantizer.quantizer import ( + CadenceAtenQuantizer, + CadenceDefaultQuantizer, + CadenceFusedConvReluQuantizer, + CadenceNopQuantizer, + CadenceQuantizer, + CadenceRmsNormNopQuantizer, + CadenceW8A32MixedQuantizer, + CadenceWakeWordQuantizer, + CadenceWith16BitConvActivationsQuantizer, + CadenceWith16BitLinearActivationsQuantizer, + CadenceWith16BitMatmulActivationsQuantizer, + CadenceWithLayerNormQuantizer, + CadenceWithSoftmaxQuantizer, + qconfig_A16, + qconfig_A8W8, +) +from executorch.exir.pass_base import NodeMetadata +from parameterized import parameterized +from torch._ops import OpOverload +from torchao.quantization.pt2e.quantizer.quantizer import ( + Q_ANNOTATION_KEY, + QuantizationAnnotation, + QuantizationSpec, +) + +# Type alias for graph builder functions. +# These functions take a test instance and return a graph module and the target op node. +GraphBuilderFn = Callable[ + ["QuantizerAnnotationTest"], tuple[torch.fx.GraphModule, torch.fx.Node] +] + + +# Quantizers intentionally excluded from annotation testing. +# These should be explicitly justified when added. +EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = { + CadenceDefaultQuantizer, # TODO: T247438143 Add test coverage + CadenceFusedConvReluQuantizer, # TODO: T247438151 Add test coverage + CadenceNopQuantizer, # No-op quantizer, doesn't annotate anything + CadenceW8A32MixedQuantizer, # TODO: T247438158 Add test coverage + CadenceRmsNormNopQuantizer, # No-op quantizer, doesn't annotate anything, preserves rms_norm from decomposition + CadenceWakeWordQuantizer, # TODO: T247438162 Add test coverage + CadenceWith16BitConvActivationsQuantizer, # TODO: T247438221 Add test coverage + CadenceWithLayerNormQuantizer, # TODO: T247438410 Add test coverage + CadenceWithSoftmaxQuantizer, # TODO: T247438418 Add test coverage +} + + +# Test case definitions for quantizer annotation tests. +# Format: (name, graph_builder_fn, quantizer_instance, target_op, expected_output_qspec, expected_input_qspecs) +# Adding a new quantizer test only requires adding a tuple to this list. +QUANTIZER_ANNOTATION_TEST_CASES: list[ + tuple[ + str, + GraphBuilderFn, + CadenceQuantizer, + OpOverload, + QuantizationSpec, + list[QuantizationSpec], + ] +] = [ + ( + "matmul_A16", + lambda self: self._build_matmul_graph(), + CadenceWith16BitMatmulActivationsQuantizer(), + torch.ops.aten.matmul.default, + qconfig_A16.output_activation, + # For matmul, both inputs are activations + [qconfig_A16.input_activation, qconfig_A16.input_activation], + ), + ( + "linear_A16", + lambda self: self._build_linear_graph(), + CadenceWith16BitLinearActivationsQuantizer(), + torch.ops.aten.linear.default, + qconfig_A16.output_activation, + # For linear: [input_activation, weight] + [qconfig_A16.input_activation, qconfig_A16.weight], + ), +] + +# Derive the set of tested quantizer classes from the test cases. +# This ensures TESTED_QUANTIZER_CLASSES stays in sync with actual tests. 
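The table-plus-derived-set layout gives this new suite a completeness guard: TESTED_QUANTIZER_CLASSES is computed from the test cases rather than maintained by hand, and test_all_quantizers_have_annotation_tests (further below) fails on any CadenceQuantizer subclass that is neither covered nor deliberately listed in EXCLUDED_FROM_ANNOTATION_TESTING. A self-contained toy version of the pattern, with placeholder classes standing in for the real quantizers, looks roughly like this:

# Toy illustration of the derive-and-check pattern; the class names below are
# placeholders, not Cadence quantizers.
class BaseQuantizer: ...
class FooQuantizer(BaseQuantizer): ...
class BarQuantizer(BaseQuantizer): ...
class BazQuantizer(BaseQuantizer): ...

TEST_CASES = [("foo", FooQuantizer())]            # parameterized cases
TESTED = {type(case[1]) for case in TEST_CASES}   # derived, never hand-written
EXCLUDED = {BarQuantizer}                         # explicit, justified exclusions

untested = set(BaseQuantizer.__subclasses__()) - TESTED - EXCLUDED
assert untested == {BazQuantizer}  # trips the guard until tested or excluded

The real test discovers subclasses with inspect.getmembers on the quantizer module instead of __subclasses__(), but the effect is the same: adding a quantizer without also adding a table entry or an explicit exclusion makes the suite fail loudly.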
+TESTED_QUANTIZER_CLASSES: set[type[CadenceQuantizer]] = { + type(case[2]) for case in QUANTIZER_ANNOTATION_TEST_CASES +} + + +class QuantizerAnnotationTest(unittest.TestCase): + """Unit tests for verifying quantizer annotations are correctly applied.""" + + def _build_matmul_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]: + """Build a simple graph with a matmul operation.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(4, 8)) + y = builder.placeholder("y", torch.randn(8, 4)) + matmul = builder.call_operator( + op=torch.ops.aten.matmul.default, + args=(x, y), + meta=NodeMetadata( + {"source_fn_stack": [("matmul", torch.ops.aten.matmul.default)]} + ), + ) + builder.output([matmul]) + gm = builder.get_graph_module() + + matmul_nodes = gm.graph.find_nodes( + op="call_function", + target=torch.ops.aten.matmul.default, + ) + self.assertEqual(len(matmul_nodes), 1, "Should find exactly one matmul node") + return gm, matmul_nodes[0] + + def _build_linear_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]: + """Build a simple graph with a linear operation (no bias).""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(1, 10)) + weight = builder.placeholder("weight", torch.randn(5, 10)) + linear = builder.call_operator( + op=torch.ops.aten.linear.default, + args=(x, weight), + meta=NodeMetadata( + {"source_fn_stack": [("linear", torch.ops.aten.linear.default)]} + ), + ) + builder.output([linear]) + gm = builder.get_graph_module() + + linear_nodes = gm.graph.find_nodes( + op="call_function", + target=torch.ops.aten.linear.default, + ) + self.assertEqual(len(linear_nodes), 1, "Should find exactly one linear node") + return gm, linear_nodes[0] + + @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES) + def test_quantizer_annotation( + self, + name: str, + graph_builder_fn: GraphBuilderFn, + quantizer: CadenceQuantizer, + target: OpOverload, + expected_output_qspec: QuantizationSpec, + expected_input_qspecs: list[QuantizationSpec], + ) -> None: + """Parameterized test for quantizer annotations.""" + gm, op_node = graph_builder_fn(self) + + quantizer.annotate(gm) + + annotation: QuantizationAnnotation = op_node.meta[Q_ANNOTATION_KEY] + self.assertTrue(annotation._annotated) + + # Verify output annotation + self.assertEqual(annotation.output_qspec, expected_output_qspec) + + # Verify input annotations + self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs)) + for i, (input_node, input_qspec) in enumerate( + annotation.input_qspec_map.items() + ): + expected_arg = op_node.args[i] + assert isinstance(expected_arg, torch.fx.Node) + self.assertEqual( + input_node, + expected_arg, + f"Input node mismatch at index {i}", + ) + self.assertEqual( + input_qspec, + expected_input_qspecs[i], + f"Input qspec mismatch at index {i}", + ) + + def test_all_quantizers_have_annotation_tests(self) -> None: + """Ensure every CadenceQuantizer subclass is either tested or explicitly excluded.""" + # Get all CadenceQuantizer subclasses defined in the quantizer module + all_quantizers: set[type[CadenceQuantizer]] = set() + for _, obj in inspect.getmembers(quantizer_module, inspect.isclass): + if ( + issubclass(obj, CadenceQuantizer) + and obj is not CadenceQuantizer + and obj.__module__ == quantizer_module.__name__ + ): + all_quantizers.add(obj) + + # Check for missing tests + untested = ( + all_quantizers - TESTED_QUANTIZER_CLASSES - EXCLUDED_FROM_ANNOTATION_TESTING + ) + if untested: + untested_names = sorted(cls.__name__ for cls in untested) 
+ self.fail( + f"The following CadenceQuantizer subclasses are not tested in " + f"test_quantizer_annotation and not in EXCLUDED_FROM_ANNOTATION_TESTING: " + f"{untested_names}. Please add test cases or explicitly exclude them." + ) + + +class QuantizerOpsPreserveTest(unittest.TestCase): + def test_mixed_w8a32_ops_to_preserve(self) -> None: + q = CadenceW8A32MixedQuantizer() + actual = q.get_ops_to_preserve_from_decomposition() + expected = [ + torch.ops.aten.linear.default, + torch.ops.aten.conv1d.default, + torch.ops.aten.gru.input, + ] + self.assertCountEqual(actual, expected) + + def test_default_quantizer_ops_to_preserve(self) -> None: + q = CadenceDefaultQuantizer() + actual = q.get_ops_to_preserve_from_decomposition() + expected = [ + torch.ops.aten.addmm.default, + torch.ops.aten.bmm.default, + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.linear.default, + torch.ops.aten.matmul.default, + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + ] + self.assertCountEqual(actual, expected) + + def test_nested_quantizer_ops_to_preserve(self) -> None: + # Setup: Create a nested CadenceQuantizer-like structure by composing + # - CadenceW8A32MixedQuantizer (which preserves linear, conv1d, gru.input) + # - A CadenceAtenQuantizer with AddmmPattern (which preserves addmm) + nested = CadenceDefaultQuantizer( + quantizers=[ + CadenceW8A32MixedQuantizer(), + CadenceAtenQuantizer(AddmmPattern(), qconfig_A8W8), + ] + ) + + # Execute + actual = nested.get_ops_to_preserve_from_decomposition() + + # Assert: union of both sets without duplicates + expected = [ + torch.ops.aten.linear.default, + torch.ops.aten.conv1d.default, + torch.ops.aten.gru.input, + torch.ops.aten.addmm.default, + ] + self.assertCountEqual(actual, expected) + + def test_rms_norm_nop_quantizer_ops_to_preserve(self) -> None: + q = CadenceRmsNormNopQuantizer() + actual = q.get_ops_to_preserve_from_decomposition() + expected = [ + torch.ops.aten.rms_norm.default, + ] + self.assertCountEqual(actual, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index 04b3e8e75ba..ccee27f47a5 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -15,6 +15,8 @@ import torch from executorch.backends.cadence.aot.typing_stubs import expand +from executorch.exir.scalar_type import ScalarType + class TestRefImplementations(unittest.TestCase): @expand( @@ -36,12 +38,11 @@ def test_quantize_per_tensor( ) -> None: input_tensor = torch.tensor([input_value]) scale = (f_max - f_min) / (q_max - q_min) - inv_scale = 1.0 / scale - zero_point = round(-f_min * inv_scale) + q_min + zero_point = round(-f_min * 1 / scale) + q_min expected_output = torch.tensor([expected_value], dtype=target_dtype) output = torch.ops.cadence.quantize_per_tensor( - input_tensor, inv_scale, zero_point, q_min, q_max, target_dtype + input_tensor, scale, zero_point, q_min, q_max, target_dtype ) self.assertEqual( @@ -85,7 +86,7 @@ def test_dequantize_per_tensor( expected_output = torch.tensor([expected_value], dtype=torch.float32) output = torch.ops.cadence.dequantize_per_tensor( - input_tensor, scale, zero_point, q_min, q_max, torch.float32 + input_tensor, scale, zero_point, q_min, q_max, input_tensor.dtype ) self.assertEqual( @@ -100,11 +101,11 @@ def test_dequantize_per_tensor( [ # Only these types need to be tested as per 
ET_FORALL_JARVIS_QUANTIZED_TYPES in # on_device_ai/Assistant/Jarvis/min_runtime/operators/generic/operators.h - ("int16", 5, 0.8, 4, 5, 0.8, 4, 0.8, 4, 6, torch.int8), + ("int8", 5, 0.8, 4, 5, 0.8, 4, 0.8, 4, 6, torch.int8), ("uint8", 5, 0.8, 4, 5, 0.8, 4, 0.8, 4, 6, torch.uint8), ] ) - def test_quantized_add( + def test_quantized_add_per_tensor( self, name: str, X: int, @@ -122,13 +123,18 @@ def test_quantized_add( Y_tensor = torch.tensor([Y], dtype=dtype) expected_output = torch.tensor([expected_value], dtype=dtype) - output = torch.ops.cadence.quantized_add( + quantized_add_per_tensor = ( + torch.ops.cadence.quantized_add_asym8sxasym8s_asym8s.per_tensor + if dtype == torch.int8 + else torch.ops.cadence.quantized_add_asym8uxasym8u_asym8u.per_tensor + ) + output = quantized_add_per_tensor( X_tensor, - torch.tensor(X_scale), - torch.tensor(X_zero_point, dtype=dtype), + X_scale, + X_zero_point, Y_tensor, - torch.tensor(Y_scale), - torch.tensor(Y_zero_point, dtype=dtype), + Y_scale, + Y_zero_point, out_scale, out_zero_point, ) @@ -152,15 +158,19 @@ def test_quantized_add( torch.tensor( [1073741824], dtype=torch.int32 ), # out_multiplier (0.5 * 2^31) - torch.tensor([0], dtype=torch.int64), # out_shift + torch.tensor([0], dtype=torch.int32), # out_shift 0, # out_zero_point - torch.tensor([[-2]], dtype=dtype), # expected_output + torch.tensor([[0]], dtype=dtype), # expected_output per_tensor, + False, + False, ) for (per_tensor, dtype) in ( (False, torch.int8), (True, torch.int8), (True, torch.uint8), + (True, torch.int16), + (False, torch.int16), ) ], # Test case 2: 1x3 input, 2x3 weight (2 output features) @@ -175,14 +185,40 @@ def test_quantized_add( torch.tensor( [1073741824], dtype=torch.int32 ), # out_multiplier (0.5 * 2^31) - torch.tensor([0], dtype=torch.int64), # out_shift + torch.tensor([0], dtype=torch.int32), # out_shift 0, # out_zero_point - torch.tensor([[-10, -30]], dtype=dtype), # expected_output + torch.tensor([[-2, -8]], dtype=dtype), # expected_output per_tensor, + False, + False, ) for (per_tensor, dtype) in ( (False, torch.int8), (True, torch.int8), + (False, torch.int16), + (True, torch.int16), + ) + ], + *[ + ( + torch.Size([1, 3]), # src_shape: 1 sample, 3 input features + torch.Size( + [2, 3] + ), # weight_shape: 2 output features, 3 input features + 0, # in_zero_point + torch.tensor([0, 0, 0], dtype=dtype), # weight_zero_point + torch.tensor( + [1073741824], dtype=torch.int32 + ), # out_multiplier (0.5 * 2^31) + torch.tensor([0], dtype=torch.int32), # out_shift + 0, # out_zero_point + torch.tensor([[0, 0]], dtype=dtype), # expected_output + per_tensor, + False, + False, + ) + for (per_tensor, dtype) in ( + (False, torch.uint8), (True, torch.uint8), ) ], @@ -198,17 +234,20 @@ def test_quantized_add( torch.tensor( [1073741824], dtype=torch.int32 ), # out_multiplier (0.5 * 2^31) - torch.tensor([0], dtype=torch.int64), # out_shift + torch.tensor([0], dtype=torch.int32), # out_shift 0, # out_zero_point torch.tensor( - [[[-2, -8, -14], [-6, -28, -50]]], dtype=dtype + [[[0, -2, -4], [-2, -7, -12]]], dtype=dtype ), # expected_output per_tensor, + False, + False, ) for (per_tensor, dtype) in ( (False, torch.int8), (True, torch.int8), - (True, torch.uint8), + (False, torch.int16), + (True, torch.int16), ) ], # Test case 4: Non-zero zero points @@ -223,15 +262,19 @@ def test_quantized_add( torch.tensor( [268435456], dtype=torch.int32 ), # out_multiplier (1.0 * 2^31) - torch.tensor([0], dtype=torch.int64), # out_shift + torch.tensor([0], dtype=torch.int32), # out_shift 1, # 
out_zero_point - torch.tensor([[-15, 25]], dtype=dtype), # expected_output + torch.tensor([[1, 1]], dtype=dtype), # expected_output per_tensor, + False, + False, ) for (per_tensor, dtype) in ( (False, torch.int8), (True, torch.int8), - (True, torch.uint8), + (False, torch.int16), + (True, torch.int16), + # (True, torch.uint8), ) ], # Test case 5: Non-uniform weight zero points @@ -246,12 +289,17 @@ def test_quantized_add( torch.tensor( [268435456], dtype=torch.int32 ), # out_multiplier (1.0 * 2^31) - torch.tensor([0], dtype=torch.int64), # out_shift + torch.tensor([0], dtype=torch.int32), # out_shift 1, # out_zero_point - torch.tensor([[-23, 17]], dtype=dtype), # expected_output + torch.tensor([[1, 1]], dtype=dtype), # expected_output + False, False, + False, + ) + for dtype in ( + torch.int8, + torch.int16, ) - for dtype in (torch.int8, torch.uint8) ], # Test case 6: Non-zero out_shift (shift=1) *[ @@ -266,13 +314,66 @@ def test_quantized_add( [268435456], dtype=torch.int32 ), # out_multiplier (0.125 * 2^31) torch.tensor( - [1], dtype=torch.int64 + [1], dtype=torch.int32 + ), # out_shift (shift=1, doubles the scale) + 1, # out_zero_point + torch.tensor([[1, 2]], dtype=dtype), # expected_output + per_tensor, + False, + False, + ) + for (per_tensor, dtype) in ( + (False, torch.int8), + (True, torch.int8), + (False, torch.int16), + (True, torch.int16), + ) + ], + *[ + ( + torch.Size([1, 2]), # src_shape: 1 sample, 2 input features + torch.Size( + [2, 2] + ), # weight_shape: 2 output features, 2 input features + 2, # in_zero_point + torch.tensor([1, 1], dtype=dtype), # weight_zero_point + torch.tensor( + [268435456], dtype=torch.int32 + ), # out_multiplier (0.125 * 2^31) + torch.tensor( + [1], dtype=torch.int32 + ), # out_shift (shift=1, doubles the scale) + 1, # out_zero_point + torch.tensor([[1, 2]], dtype=dtype), # expected_output + per_tensor, + matmul, + transposed_matmul, + ) + for (matmul, transposed_matmul) in ((True, False), (True, True)) + for (per_tensor, dtype) in ((True, torch.int8), (True, torch.int16)) + ], + *[ + ( + torch.Size([2, 1, 2]), # src_shape: 1 sample, 2 input features + torch.Size( + [2, 2, 2] + ), # weight_shape: 2 output features, 2 input features + 2, # in_zero_point + torch.tensor([1, 1], dtype=dtype), # weight_zero_point + torch.tensor( + [268435456], dtype=torch.int32 + ), # out_multiplier (0.125 * 2^31) + torch.tensor( + [1], dtype=torch.int32 ), # out_shift (shift=1, doubles the scale) 1, # out_zero_point - torch.tensor([[-7, 13]], dtype=dtype), # expected_output + torch.tensor([[[1, 2]], [[0, -1]]], dtype=dtype), # expected_output per_tensor, + matmul, + transposed_matmul, ) - for (per_tensor, dtype) in ((False, torch.int8), (True, torch.int8)) + for (matmul, transposed_matmul) in ((True, False), (True, True)) + for (per_tensor, dtype) in ((True, torch.int8),) ], ] ) @@ -287,7 +388,12 @@ def test_quantized_linear( out_zero_point: int, expected_output: torch.Tensor, per_tensor: bool, + matmul: bool, + transposed_matmul: bool, ) -> None: + if not per_tensor and matmul: + self.skipTest("Only per_tensor supported for matmul") + src = ( torch.arange(np.prod(src_shape)) .reshape(src_shape) @@ -298,7 +404,9 @@ def test_quantized_linear( .reshape(weight_shape) .to(expected_output.dtype) ) - bias = torch.arange(weight_shape[0]).to(torch.int32) + if matmul and not transposed_matmul: + weight = weight.transpose(-1, -2) + if per_tensor: weight_zero_point = weight_zero_point[0] out_multiplier = out_multiplier[0] @@ -307,20 +415,34 @@ def test_quantized_linear( if 
per_tensor: match expected_output.dtype: case torch.int8: - linear_ops = ( - torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor, - torch.ops.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor, - ) + if matmul: + linear_ops = ( + # Doesn't have per tensor name, but it is per tensor + torch.ops.cadence.quantized_matmul_asym8sxasym8s_asym8s, + ) + else: + linear_ops = ( + torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor, + torch.ops.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor, + ) case torch.uint8: - linear_ops = ( - torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor, - torch.ops.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor, - ) + if matmul: + linear_ops = ( + torch.ops.cadence.quantized_matmul_asym8uxasym8u_asym8u, + ) + else: + linear_ops = ( + torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor, + torch.ops.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor, + ) case _: - linear_ops = ( - torch.ops.cadence.quantized_linear.per_tensor, - torch.ops.cadence.quantized_fully_connected.per_tensor, - ) + if matmul: + linear_ops = (torch.ops.cadence.quantized_matmul,) + else: + linear_ops = ( + torch.ops.cadence.quantized_linear.per_tensor, + torch.ops.cadence.quantized_fully_connected.per_tensor, + ) else: linear_ops = ( torch.ops.cadence.quantized_linear, @@ -328,17 +450,40 @@ def test_quantized_linear( ) for linear_op in linear_ops: - output = linear_op( - src, - weight, - bias, - in_zero_point, - weight_zero_point, - out_multiplier, - out_shift, - out_zero_point, - typing.cast(torch.Tensor, None), + # Get the function name for linear_op for debugging + op_name = ( + linear_op.__name__ if hasattr(linear_op, "__name__") else str(linear_op) ) + if matmul: + assert "quantized_matmul" in op_name + output = linear_op( + src, + in_zero_point, + weight, + weight_zero_point, + None, + out_multiplier, + out_shift, + out_zero_point, + transposed_matmul, + ) + else: + assert ( + "quantized_linear" in op_name + or "quantized_fully_connected" in op_name + ) + bias = torch.arange(weight_shape[0]).to(torch.int32) + output = linear_op( + src, + weight, + bias, + in_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + typing.cast(torch.Tensor, None), + ) self.assertTrue(output.dtype == expected_output.dtype, "Dtype mismatch") @@ -806,9 +951,9 @@ def test_quantized_conv_per_tensor( convs = [ ( - torch.ops.cadence.quantized_conv_nchw.per_tensor + torch.ops.cadence.quantized_conv2d_nchw.per_tensor if memory_format == torch.contiguous_format - else torch.ops.cadence.quantized_conv_nhwc.per_tensor + else torch.ops.cadence.quantized_conv2d_nhwc.per_tensor ) ] @@ -816,30 +961,30 @@ def test_quantized_conv_per_tensor( if input_tensor.dtype == torch.int8 and weight.dtype == torch.int8: if memory_format == torch.contiguous_format: optimized_convs = [ - torch.ops.cadence.quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor, - torch.ops.cadence.quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor, - torch.ops.cadence.quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor, + torch.ops.cadence.quantized_conv2d_nchw_asym8sxsym8s_asym8s.per_tensor, + torch.ops.cadence.quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s.per_tensor, + torch.ops.cadence.quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor, ] else: optimized_convs = [ - torch.ops.cadence.quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor, - 
torch.ops.cadence.quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor, - torch.ops.cadence.quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor, + torch.ops.cadence.quantized_conv2d_nhwc_asym8sxsym8s_asym8s.per_tensor, + torch.ops.cadence.quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor, + torch.ops.cadence.quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor, ] elif input_tensor.dtype == torch.uint8 and weight.dtype == torch.uint8: if memory_format == torch.contiguous_format: optimized_convs = [ - torch.ops.cadence.quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor, - torch.ops.cadence.quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor, - torch.ops.cadence.quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor, + torch.ops.cadence.quantized_conv2d_nchw_asym8uxsym8u_asym8u.per_tensor, + torch.ops.cadence.quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u.per_tensor, + torch.ops.cadence.quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor, ] else: optimized_convs = [ - torch.ops.cadence.quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor, - torch.ops.cadence.quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor, - torch.ops.cadence.quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor, + torch.ops.cadence.quantized_conv2d_nhwc_asym8uxsym8u_asym8u.per_tensor, + torch.ops.cadence.quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor, + torch.ops.cadence.quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor, ] convs.extend(optimized_convs) @@ -881,6 +1026,297 @@ def test_quantized_conv_per_tensor( f"Output values don't match expected. Got {output}, expected {expected_output}", ) + @expand( + [ + ( + "basic_int8_weights", + torch.tensor( + [ + [ + [1.0, 2.0, 3.0, 4.0, 5.0], + [1.0, 2.0, 3.0, 4.0, 5.0], + [1.0, 2.0, 3.0, 4.0, 5.0], + [1.0, 2.0, 3.0, 4.0, 5.0], + ] + ], + dtype=torch.float32, + ), # src: 1x4x5 + torch.tensor( + [ + [[1, -1, 2], [1, -1, 2], [1, -1, 2], [1, -1, 2]], + [[1, -1, 2], [1, -1, 2], [1, -1, 2], [1, -1, 2]], + [[1, -1, 2], [1, -1, 2], [1, -1, 2], [1, -1, 2]], + [[1, -1, 2], [1, -1, 2], [1, -1, 2], [1, -1, 2]], + ], + dtype=torch.int8, + ), # weight: 4x4x3 + 0.1, # w_scale + torch.tensor([1, 1, 1, 1], dtype=torch.int8), # bias: 4 + 0.2, # b_scale + torch.tensor( + [ + [ + [2.2, 3.0, 3.8], + [2.2, 3.0, 3.8], + [2.2, 3.0, 3.8], + [2.2, 3.0, 3.8], + ] + ], + dtype=torch.float32, + ), # expected: conv1d result + ), + ( + "batch_size_2", + torch.tensor( + [ + [ + [1.0, 2.0, 3.0, 4.0, 5.0], + [1.0, 2.0, 3.0, 4.0, 5.0], + [1.0, 2.0, 3.0, 4.0, 5.0], + [1.0, 2.0, 3.0, 4.0, 5.0], + ], + [ + [2.0, 3.0, 4.0, 5.0, 6.0], + [2.0, 3.0, 4.0, 5.0, 6.0], + [2.0, 3.0, 4.0, 5.0, 6.0], + [2.0, 3.0, 4.0, 5.0, 6.0], + ], + ], + dtype=torch.float32, + ), # src: 2x4x5 + torch.tensor( + [ + [[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]], + [[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]], + [[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]], + [[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]], + ], + dtype=torch.int8, + ), # weight: 4x4x3 + 1.0, # w_scale + torch.tensor([0, 0, 0, 0], dtype=torch.int8), # bias: 4 + 1.0, # b_scale + torch.tensor( + [ + [ + [24.0, 36.0, 48.0], + [24.0, 36.0, 48.0], + [24.0, 36.0, 48.0], + [24.0, 36.0, 48.0], + ], + [ + [36.0, 48.0, 60.0], + [36.0, 48.0, 60.0], + [36.0, 48.0, 60.0], + [36.0, 48.0, 60.0], + ], + ], + dtype=torch.float32, + ), # expected + ), + ( + "zero_weights_bias", + torch.tensor( + [ + [ + [1.0, 2.0, 3.0, 4.0, 5.0], + [1.0, 2.0, 3.0, 4.0, 5.0], + [1.0, 2.0, 3.0, 4.0, 5.0], + [1.0, 
2.0, 3.0, 4.0, 5.0], + ] + ], + dtype=torch.float32, + ), # src: 1x4x5 + torch.tensor( + [ + [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + ], + dtype=torch.int8, + ), # weight: 4x4x3 + 0.1, # w_scale + torch.tensor([0, 0, 0, 0], dtype=torch.int8), # bias: 4 + 1.0, # b_scale + torch.tensor( + [ + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ] + ], + dtype=torch.float32, + ), # expected + ), + ( + "negative_weights", + torch.tensor( + [ + [ + [2.0, 4.0, 6.0, 8.0, 10.0], + [2.0, 4.0, 6.0, 8.0, 10.0], + [2.0, 4.0, 6.0, 8.0, 10.0], + [2.0, 4.0, 6.0, 8.0, 10.0], + ] + ], + dtype=torch.float32, + ), # src: 1x4x5 + torch.tensor( + [ + [[-2, -1, 0], [-2, -1, 0], [-2, -1, 0], [-2, -1, 0]], + [[-2, -1, 0], [-2, -1, 0], [-2, -1, 0], [-2, -1, 0]], + [[-2, -1, 0], [-2, -1, 0], [-2, -1, 0], [-2, -1, 0]], + [[-2, -1, 0], [-2, -1, 0], [-2, -1, 0], [-2, -1, 0]], + ], + dtype=torch.int8, + ), # weight: 4x4x3 + 0.5, # w_scale + torch.tensor([2, 2, 2, 2], dtype=torch.int8), # bias: 4 + 1.0, # b_scale + torch.tensor( + [ + [ + [-14.0, -26.0, -38.0], + [-14.0, -26.0, -38.0], + [-14.0, -26.0, -38.0], + [-14.0, -26.0, -38.0], + ] + ], + dtype=torch.float32, + ), # expected + ), + ] + ) + def test_quantized_w8a32_conv( + self, + name: str, + src: torch.Tensor, + weight: torch.Tensor, + w_scale: float, + bias: torch.Tensor, + b_scale: float, + expected_output: torch.Tensor, + ) -> None: + + # This op takes in channels last src + src = src.permute(0, 2, 1) + + # This op takes in LNC format for weights + weight = weight.permute(2, 0, 1) + output = torch.ops.cadence.quantized_w8a32_conv( + src, weight, w_scale, bias, b_scale + ) + + # Verify output properties + self.assertEqual( + output.dtype, + torch.float32, + f"Output dtype should be float32 in {name}", + ) + self.assertEqual( + output.shape, + expected_output.shape, + f"Output shape should match expected shape in {name}", + ) + + # Verify output matches expected values + self.assertTrue( + torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4), + f"Output values don't match expected in {name}. 
Got {output}, expected {expected_output}", + ) + + @expand( + [ + ( + "multi_input_features", + torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32), # src: 1x3 + torch.tensor([[2, 1], [1, 2], [1, 1]], dtype=torch.int8), # weight: 3x2 + 0.5, # w_scale + torch.tensor([0, 1], dtype=torch.int8), # bias: 2 + 1.0, # b_scale + torch.tensor([[3.5, 5.0]], dtype=torch.float32), # expected + ), + ( + "batch_size_2", + torch.tensor( + [[[1.0, 2.0]], [[3.0, 4.0]]], dtype=torch.float32 + ), # src: 2x2 + torch.tensor([[1, 2], [1, -1]], dtype=torch.int8), # weight: 2x2 + 1.0, # w_scale + torch.tensor([0, 0], dtype=torch.int8), # bias: 2 + 1.0, # b_scale + torch.tensor( + [[[3.0, 0.0]], [[7.0, 2.0]]], dtype=torch.float32 + ), # expected + ), + ( + "shape_assertion_error", + torch.tensor( + [[[1.0, 2.0], [3.0, 4.0]]], dtype=torch.float32 + ), # src: 1x2x2 + torch.tensor([[1, 2], [1, -1]], dtype=torch.int8), # weight: 2x2 + 1.0, # w_scale + torch.tensor([0, 1], dtype=torch.int8), # bias: 2 + 1.0, # b_scale + torch.tensor( + [[[3.0, 1.0], [7.0, 3.0]]], dtype=torch.float32 + ), # expected + ), + ( + "negative_weights", + torch.tensor([[2.0, 4.0]], dtype=torch.float32), # src: 1x2 + torch.tensor([[-2, -3], [-1, -2]], dtype=torch.int8), # weight: 2x2 + 0.5, # w_scale + torch.tensor([2, 1], dtype=torch.int8), # bias: 2 + 1.0, # b_scale + torch.tensor([[-2.0, -6.0]], dtype=torch.float32), # expected + ), + ] + ) + def test_quantized_w8a32_linear( + self, + name: str, + src: torch.Tensor, + weight: torch.Tensor, + w_scale: float, + bias: torch.Tensor, + b_scale: float, + expected_output: torch.Tensor, + ) -> None: + if name == "shape_assertion_error": + with self.assertRaisesRegex( + AssertionError, "Only supporting vector-matrix multiplication" + ): + torch.ops.cadence.quantized_w8a32_linear( + src, weight, w_scale, bias, b_scale + ) + return + + output = torch.ops.cadence.quantized_w8a32_linear( + src, weight, w_scale, bias, b_scale + ) + + # Verify output properties + self.assertEqual( + output.dtype, + torch.float32, + f"Output dtype should be float32 in {name}", + ) + self.assertEqual( + output.shape, + expected_output.shape, + f"Output shape should match expected shape in {name}", + ) + + # Verify output matches expected values + self.assertTrue( + torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4), + f"Output values don't match expected in {name}. 
Got {output}, expected {expected_output}", + ) + @expand( [ # Test case 1: Basic int8 case with negative scale @@ -945,63 +1381,54 @@ def test_quantized_conv_per_tensor( [4, 2, 0, -2], dtype=dtype ), # expected: relu(1,3,5,7) = (1,3,5,7) * (-1.0) + 5 = (4,2,0,-2) ) - for dtype in [torch.int8, torch.uint8] + for dtype in [torch.int8] ], - # Test case 4: Non-per-tensor *[ ( - "non_per_tensor", - torch.tensor([-1, -2, -3, 1, 2, 3], dtype=dtype), # input - torch.tensor([0, 0, 0, 1, 1, 1]), # X_zero_point + "positive_with_shift_unsigned", + torch.tensor([2, 4, 6, 8], dtype=dtype), # input + 1, # X_zero_point 5, # out_zero_point - torch.tensor([1073741824]), # out_multiplier (0.5 * 2^31) - torch.tensor([1]), # out_shift (multiply by 2^1 = 2) + 1073741824, # out_multiplier (0.5 * 2^31) + 1, # out_shift (multiply by 2^1 = 2) dtype, # dtype - torch.tensor([5, 5, 5, 5, 4, 3], dtype=dtype), + torch.tensor([4, 2, 0, 0], dtype=dtype), ) - for dtype in [torch.int8] + for dtype in [torch.uint8] ], ] ) - def test_quantized_relu( + def test_quantized_relu_per_tensor( self, name: str, X: torch.Tensor, - X_zero_point: torch.Tensor | int, + X_zero_point: int, out_zero_point: int, - out_multiplier: torch.Tensor | int, - out_shift: torch.Tensor | int, + out_multiplier: int, + out_shift: int, dtype: torch.dtype, expected_output: torch.Tensor, ) -> None: - if isinstance(X_zero_point, int): - assert isinstance(out_multiplier, int) - assert isinstance(out_shift, int) - - match dtype: - case torch.int8: - quantized_relu = ( - torch.ops.cadence.quantized_relu_asym8s_asym8s.per_tensor - ) - case torch.uint8: - quantized_relu = ( - torch.ops.cadence.quantized_relu_asym8u_asym8u.per_tensor - ) - case _: - quantized_relu = torch.ops.cadence.quantized_relu_per_tensor + match dtype: + case torch.int8: + quantized_relu_per_tensor = ( + torch.ops.cadence.quantized_relu_asym8s_asym8s.per_tensor + ) + case torch.uint8: + quantized_relu_per_tensor = ( + torch.ops.cadence.quantized_relu_asym8u_asym8u.per_tensor + ) + case _: + quantized_relu_per_tensor = torch.ops.cadence.quantized_relu_per_tensor - output = quantized_relu( - X, - X_zero_point, - out_zero_point, - out_multiplier, - out_shift, - ) - else: - output = torch.ops.cadence.quantized_relu( - X, X_zero_point, out_zero_point, out_multiplier, out_shift - ) + output = quantized_relu_per_tensor( + X, + X_zero_point, + out_zero_point, + out_multiplier, + out_shift, + ) # Verify output properties self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}") @@ -1012,3 +1439,1713 @@ def test_quantized_relu( torch.equal(output, expected_output), f"Output values don't match expected in {name}. 
Got {output}, expected {expected_output}", ) + + def test_where_Scalar(self) -> None: + input_tensor = torch.tensor([1, 2, 3, 4], dtype=torch.int8) + out = torch.ops.cadence.where_Scalar(input_tensor > 2, 1.0, 0.0) + self.assertTrue( + torch.equal(out, torch.tensor([0.0, 0.0, 1.0, 1.0], dtype=torch.float32)) + ) + with self.assertRaises(ValueError) as context: + torch.ops.cadence.where_Scalar(input_tensor, 1.0, 0.0) + + self.assertIn("condition must be a bool tensor", str(context.exception)) + + @expand( + [ + ( + "h1xhd4", + torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]], dtype=torch.float32), + torch.tensor([[0.0, 0.0]], dtype=torch.float32), + torch.tensor([[1.0, 1.0]], dtype=torch.float32), + torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]], dtype=torch.float32), + ), + ( + "h2xhd4", + torch.tensor( + [[[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]]], + dtype=torch.float32, + ), + torch.tensor([[0.0, 1.0]], dtype=torch.float32), + torch.tensor([[1.0, 0.0]], dtype=torch.float32), + torch.tensor( + [[[[1.0, 2.0, -4.0, 3.0], [5, 6.0, -8.0, 7.0]]]], + dtype=torch.float32, + ), + ), + ( + "s2xh2xhd4", + torch.tensor( + [ + [ + [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], + [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]], + ] + ], + dtype=torch.float32, + ), + torch.tensor([[0.0, 1.0], [0.0, 1.0]], dtype=torch.float32), + torch.tensor([[1.0, 0.0], [1.0, 0.0]], dtype=torch.float32), + torch.tensor( + [ + [ + [[1.0, 2.0, -4.0, 3.0], [5.0, 6.0, -8.0, 7.0]], + [[9.0, 10.0, -12.0, 11.0], [13.0, 14.0, -16.0, 15.0]], + ] + ], + dtype=torch.float32, + ), + ), + ( + "pos_not_none", + torch.tensor( + [ + [ + [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], + [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]], + ] + ], + dtype=torch.float32, + ), + torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32), + torch.tensor([[0.0, 1.0], [1.0, 0.0]], dtype=torch.float32), + torch.tensor( + [ + [ + [[1.0, 2.0, -4.0, 3.0], [5.0, 6.0, -8.0, 7.0]], + [[-10.0, 9.0, 11.0, 12.0], [-14.0, 13.0, 15.0, 16.0]], + ] + ], + dtype=torch.float32, + ), + torch.tensor([1, 0]), + ), + ] + ) + def test_rope( + self, + name: str, + input_tensor: torch.Tensor, + sin_tensor: torch.Tensor, + cos_tensor: torch.Tensor, + expected_output: torch.Tensor, + pos: torch.Tensor | None = None, + ) -> None: + output = torch.ops.cadence.rope(input_tensor, sin_tensor, cos_tensor, pos) + + # Verify output properties + self.assertEqual( + output.dtype, + input_tensor.dtype, + f"Output dtype should match input dtype in {name}", + ) + self.assertEqual( + output.shape, + input_tensor.shape, + f"Output shape should match input shape in {name}", + ) + + # Verify output matches expected values + self.assertTrue( + torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4), + f"Output values don't match expected in {name}. 
Got {output}, expected {expected_output}", + ) + + @expand( + [ + # Test case 1: Basic 2D convolution (NCHW format) + ( + "basic_2d_nchw", + torch.tensor( + [[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32 + ), # input: 1x1x2x2 + torch.tensor( + [[[[1.0, 0.0], [0.0, 1.0]]]], dtype=torch.float32 + ), # weight: 1x1x2x2 (identity-like filter) + torch.tensor([0.0], dtype=torch.float32), # bias + (1, 1), # stride + (0, 0), # padding + (1, 1), # dilation + 1, # groups + False, # channel_last + torch.tensor( + [[[[5.0]]]], dtype=torch.float32 + ), # expected: 1*1 + 4*1 = 5 + ), + # Test case 3: 2D convolution with stride=2 + ( + "conv2d_stride2", + torch.tensor( + [ + [ + [ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [9.0, 10.0, 11.0, 12.0], + [13.0, 14.0, 15.0, 16.0], + ] + ] + ], + dtype=torch.float32, + ), # input: 1x1x4x4 + torch.tensor( + [[[[1.0, 1.0], [1.0, 1.0]]]], dtype=torch.float32 + ), # weight: 1x1x2x2 (sum filter) + torch.tensor([0.0], dtype=torch.float32), # bias + (2, 2), # stride=2 + (0, 0), # padding + (1, 1), # dilation + 1, # groups + False, # channel_last + torch.tensor([[[[14.0, 22.0], [46.0, 54.0]]]], dtype=torch.float32), + ), + # Test case 4: 2D convolution with padding=1 + ( + "conv2d_padding1", + torch.tensor( + [[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32 + ), # input: 1x1x2x2 + torch.tensor( + [[[[1.0, 0.0], [0.0, 1.0]]]], dtype=torch.float32 + ), # weight: 1x1x2x2 + torch.tensor([0.0], dtype=torch.float32), # bias + (1, 1), # stride + (1, 1), # padding=1 + (1, 1), # dilation + 1, # groups + False, # channel_last + torch.tensor( + [[[[1.0, 2.0, 0.0], [3.0, 5.0, 2.0], [0.0, 3.0, 4.0]]]], + dtype=torch.float32, + ), # expected with padding + ), + # Test case 5: 2D convolution with dilation=2 + ( + "conv2d_dilation2", + torch.tensor( + [ + [ + [ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [9.0, 10.0, 11.0, 12.0], + [13.0, 14.0, 15.0, 16.0], + ] + ] + ], + dtype=torch.float32, + ), # input: 1x1x4x4 + torch.tensor( + [[[[1.0, 1.0], [1.0, 1.0]]]], dtype=torch.float32 + ), # weight: 1x1x2x2 + torch.tensor([0.0], dtype=torch.float32), # bias + (1, 1), # stride + (0, 0), # padding + (2, 2), # dilation=2 + 1, # groups + False, # channel_last + torch.tensor([[[[24.0, 28.0], [40.0, 44.0]]]], dtype=torch.float32), + ), + # Test case 6: 2D grouped convolution (groups=2) + ( + "conv2d_groups2", + torch.tensor( + [ + [ + [[1.0, 2.0], [3.0, 4.0]], # first input channel + [[5.0, 6.0], [7.0, 8.0]], # second input channel + ] + ], + dtype=torch.float32, + ), # input: 1x2x2x2 + torch.tensor( + [ + [[[1.0, 1.0], [1.0, 1.0]]], # first group weight + [[[0.5, 0.5], [0.5, 0.5]]], # second group weight + ], + dtype=torch.float32, + ), # weight: 2x1x2x2 + torch.tensor([0.0, 1.0], dtype=torch.float32), # bias + (1, 1), # stride + (0, 0), # padding + (1, 1), # dilation + 2, # groups=2 + False, # channel_last + torch.tensor([[[[10.0]], [[14.0]]]], dtype=torch.float32), + ), + # Test case 7: 1D convolution (NCL format) + ( + "conv1d_ncl", + torch.tensor( + [[[1.0, 2.0, 3.0, 4.0]]], dtype=torch.float32 + ), # input: 1x1x4 + torch.tensor([[[1.0, 1.0]]], dtype=torch.float32), # weight: 1x1x2 + torch.tensor([0.0], dtype=torch.float32), # bias + (1, 1), # stride (only stride[1] is used for 1D) + (0, 0), # padding (only padding[1] is used for 1D) + (1, 1), # dilation (only dilation[1] is used for 1D) + 1, # groups + False, # channel_last + torch.tensor( + [[[3.0, 5.0, 7.0]]], dtype=torch.float32 + ), # expected: [1+2, 2+3, 3+4] + ), + # Test case 9: Multi-channel input and output + ( + 
"multi_channel", + torch.tensor( + [ + [ + [[1.0, 2.0], [3.0, 4.0]], # first input channel + [[0.5, 1.0], [1.5, 2.0]], # second input channel + ] + ], + dtype=torch.float32, + ), # input: 1x2x2x2 + torch.tensor( + [ + [ # first output channel + [[1.0, 0.0], [0.0, 1.0]], # weights for first input channel + [ + [2.0, 0.0], + [0.0, 2.0], + ], # weights for second input channel + ], + [ # second output channel + [[0.5, 0.5], [0.5, 0.5]], # weights for first input channel + [ + [1.0, 1.0], + [1.0, 1.0], + ], # weights for second input channel + ], + ], + dtype=torch.float32, + ), # weight: 2x2x2x2 + torch.tensor([0.0, 1.0], dtype=torch.float32), # bias + (1, 1), # stride + (0, 0), # padding + (1, 1), # dilation + 1, # groups + False, # channel_last + torch.tensor([[[[10.0]], [[11.0]]]], dtype=torch.float32), + ), + # Test case 10: Convolution with non-zero bias + ( + "conv2d_with_bias", + torch.tensor( + [[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32 + ), # input: 1x1x2x2 + torch.tensor( + [[[[1.0, 0.0], [0.0, 1.0]]]], dtype=torch.float32 + ), # weight: 1x1x2x2 + torch.tensor([10.0], dtype=torch.float32), # bias=10 + (1, 1), # stride + (0, 0), # padding + (1, 1), # dilation + 1, # groups + False, # channel_last + torch.tensor( + [[[[15.0]]]], dtype=torch.float32 + ), # expected: 5 + 10 = 15 + ), + ] + ) + def test_convolution( + self, + name: str, + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: tuple[int, int], + padding: tuple[int, int], + dilation: tuple[int, int], + groups: int, + channel_last: bool, # Keep for backward compatibility with test data, but won't use + expected_output: torch.Tensor, + ) -> None: + # Determine if 1D or 2D based on input shape + is_conv1d = len(input_tensor.shape) == 3 + if is_conv1d: + output = torch.ops.cadence.conv1d( + input_tensor, + weight, + bias, + (stride[0],), + (padding[0],), + (dilation[0],), + groups, + ) + else: + output = torch.ops.cadence.conv2d( + input_tensor, + weight, + bias, + stride, + padding, + dilation, + groups, + ) + + # Verify output properties + self.assertEqual( + output.dtype, + input_tensor.dtype, + f"Output dtype should match input dtype in {name}", + ) + self.assertEqual( + output.shape, + expected_output.shape, + f"Output shape should match expected shape in {name}", + ) + + # Verify output matches expected values + self.assertTrue( + torch.equal(output, expected_output), + f"Output values don't match expected in {name}. 
Got {output}, expected {expected_output}", + ) + + @expand( + [ + ( + "basic_2d_stride1", + torch.tensor( + [[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32 + ), # input: 1x1x2x2 + torch.tensor( + [[[[1.0, 1.0], [1.0, 1.0]]]], dtype=torch.float32 + ), # weight: 1x1x2x2 (in PyTorch format, will be transformed to Cadence format) + torch.tensor([0.0], dtype=torch.float32), # bias + (1, 1), # stride + (0, 0), # padding + (1, 1), # dilation + 1, # groups + (0, 0), # output_padding + False, # channel_last + torch.tensor( + [[[[1.0, 3.0, 2.0], [4.0, 10.0, 6.0], [3.0, 7.0, 4.0]]]], + dtype=torch.float32, + ), + ), + # 2D transposed convolution with non-zero bias + ( + "with_bias", + torch.tensor( + [[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32 + ), # input: 1x1x2x2 + torch.tensor( + [[[[1.0, 0.0], [0.0, 1.0]]]], dtype=torch.float32 + ), # weight: 1x1x2x2 (in PyTorch format, will be transformed to Cadence format) + torch.tensor([5.0], dtype=torch.float32), # bias=5.0 + (1, 1), # stride + (0, 0), # padding + (1, 1), # dilation + 1, # groups + (0, 0), # output_padding + False, # channel_last + torch.tensor( + [[[[6.0, 7.0, 5.0], [8.0, 10.0, 7.0], [5.0, 8.0, 9.0]]]], + dtype=torch.float32, + ), + ), + ] + ) + def test_transposed_convolution( + self, + name: str, + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: tuple[int, int], + padding: tuple[int, int], + dilation: tuple[int, int], + groups: int, + output_padding: tuple[int, int], + channel_last: bool, + expected_output: torch.Tensor, + ) -> None: + # Apply the same transformations that ReplaceAtenConvolutionWithCadenceConvolutionPass + # applies to weights: transpose(0,1) then flip spatial dimensions. + # This converts weights from PyTorch format to Cadence format. + weight_dim = len(weight.shape) + flip_dims = [-1] if weight_dim == 3 else [-1, -2] + + # Transform: transpose dims 0 and 1, then flip spatial dimensions + cadence_weight = weight.transpose(0, 1) + cadence_weight = torch.flip(cadence_weight, dims=flip_dims) + + output = torch.ops.cadence.transposed_convolution( + input_tensor, + cadence_weight, + bias, + stride, + padding, + dilation, + output_padding, + groups, + channel_last, + ) + + # Verify output properties + self.assertEqual( + output.dtype, + input_tensor.dtype, + f"Output dtype should match input dtype in {name}", + ) + self.assertEqual( + output.shape, + expected_output.shape, + f"Output shape should match expected shape in {name}", + ) + + # Verify output matches expected values + self.assertTrue( + torch.equal(output, expected_output), + f"Output values don't match expected in {name}. 
Got {output}, expected {expected_output}", + ) + + @expand( + [ + # Basic non-quantized average pooling + ( + "basic_non_quantized", + torch.tensor( + [ + [ + [ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [9.0, 10.0, 11.0, 12.0], + [13.0, 14.0, 15.0, 16.0], + ] + ] + ], + dtype=torch.float32, + ), # input: 1x1x4x4 + (2, 2), # kernel_size + (2, 2), # stride + (0, 0), # padding + False, # ceil_mode + False, # count_include_pad + None, # divisor_override + None, # in_zero_point (non-quantized) + False, # channel_last + torch.tensor( + [[[[3.5, 5.5], [11.5, 13.5]]]], dtype=torch.float32 + ), # expected: average of 2x2 blocks + ), + # Non-quantized with count_include_pad=True and padding + ( + "non_quantized_count_include_pad", + torch.tensor( + [[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32 + ), # input: 1x1x2x2 + (3, 3), # kernel_size (larger than input) + (1, 1), # stride + (1, 1), # padding + False, # ceil_mode + True, # count_include_pad=True + None, # divisor_override + None, # in_zero_point (non-quantized) + False, # channel_last + torch.tensor( + [[[[2.5, 2.5], [2.5, 2.5]]]], + dtype=torch.float32, + ), + ), + # Non-quantized with divisor_override + ( + "non_quantized_divisor_override", + torch.tensor( + [[[[2.0, 4.0], [6.0, 8.0]]]], dtype=torch.float32 + ), # input: 1x1x2x2 + (2, 2), # kernel_size + (1, 1), # stride + (0, 0), # padding + False, # ceil_mode + False, # count_include_pad + 2, # divisor_override (instead of 4) + None, # in_zero_point (non-quantized) + False, # channel_last + torch.tensor( + [[[[10.0]]]], dtype=torch.float32 + ), # expected: (2+4+6+8)/2 = 10 + ), + # Quantized with non-zero zero_point and padding + ( + "quantized_nonzero_zero_point", + torch.tensor( + [[[[130, 132], [134, 136]]]], dtype=torch.uint8 + ), # input: 1x1x2x2, values around zero_point=128 + (3, 3), # kernel_size + (1, 1), # stride + (1, 1), # padding + False, # ceil_mode + True, # count_include_pad=True + None, # divisor_override + 128, # in_zero_point=128 (padded areas will have this value) + False, # channel_last + torch.tensor( + [[[[130, 130], [130, 130]]]], dtype=torch.uint8 + ), # expected: averages including padded zero_point values + ), + # Quantized with divisor_override + ( + "quantized_divisor_override", + torch.tensor( + [[[[64, 96], [128, 160]]]], dtype=torch.float32 + ), # input: 1x1x2x2 + (2, 2), # kernel_size + (1, 1), # stride + (0, 0), # padding + False, # ceil_mode + False, # count_include_pad + 2, # divisor_override (instead of 4) + None, # in_zero_point=None + False, # channel_last + torch.tensor( + [[[[224]]]], dtype=torch.float32 + ), # expected: (64+96+128+160)/2 = 224 + ), + # Large values that need clamping + ( + "quantized_clamping_test", + torch.tensor( + [[[[120, 125], [125, 127]]]], dtype=torch.int8 + ), # input: 1x1x2x2, large values for int8 + (2, 2), # kernel_size + (1, 1), # stride + (0, 0), # padding + False, # ceil_mode + False, # count_include_pad + None, # divisor_override + 0, # in_zero_point=0 + False, # channel_last + torch.tensor( + [[[[124]]]], dtype=torch.int8 + ), # expected: (120+125+125+127)/4 = 124.25 -> 124, within int8 range + ), + ] + ) + def test_avg_pool2d( + self, + name: str, + input_tensor: torch.Tensor, + kernel_size: tuple[int, int], + stride: tuple[int, int], + padding: tuple[int, int], + ceil_mode: bool, + count_include_pad: bool, + divisor_override: int | None, + in_zero_point: int | None, + channel_last: bool, + expected_output: torch.Tensor, + ) -> None: + output = torch.ops.cadence.avg_pool2d( + input_tensor, + kernel_size, 
+ stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + in_zero_point if in_zero_point is None else torch.tensor([in_zero_point]), + channel_last, + ) + + # Verify output properties + self.assertEqual( + output.dtype, + input_tensor.dtype, + f"Output dtype should match input dtype in {name}", + ) + self.assertEqual( + output.shape, + expected_output.shape, + f"Output shape should match expected shape in {name}", + ) + + # Verify output matches expected values + if input_tensor.dtype.is_floating_point: + self.assertTrue( + torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4), + f"Output values don't match expected in {name}. Got {output}, expected {expected_output}", + ) + else: + self.assertTrue( + torch.equal(output, expected_output), + f"Output values don't match expected in {name}. Got {output}, expected {expected_output}", + ) + + @expand( + [ + # Basic 2x2 kernel, stride 1, no padding, NCHW + ( + "nchw_basic_2x2", + torch.tensor( + [[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float32 + ), # (N=1, C=1, H=3, W=3) + (2, 2), # kernel_size + (1, 1), # dilation + (0, 0), # padding + (1, 1), # stride + None, # in_zero_point + False, # channel_last + False, + torch.tensor( + [ + [[1, 2, 4, 5], [2, 3, 5, 6], [4, 5, 7, 8], [5, 6, 8, 9]], + ], + dtype=torch.float32, + ), + ), + # 2x2 kernel, stride 2, no padding, NCHW + ( + "nchw_stride2", + torch.tensor( + [[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float32 + ), + (2, 2), + (1, 1), + (0, 0), + (2, 2), + None, + False, + False, + torch.tensor( + [ + [[1, 2, 4, 5]], + ], + dtype=torch.float32, # Only every other patch in each dim + ), + ), + # 2x2 kernel, stride 1, padding 1, NCHW + ( + "nchw_padding1", + torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.float32), # (1,1,2,2) + (2, 2), + (1, 1), + (1, 1), + (1, 1), + None, + False, + False, + torch.tensor( + [ + [ + [0, 0, 0, 1], + [0, 0, 1, 2], + [0, 0, 2, 0], + [0, 1, 0, 3], + [1, 2, 3, 4], + [2, 0, 4, 0], + [0, 3, 0, 0], + [3, 4, 0, 0], + [4, 0, 0, 0], + ], + ], + dtype=torch.float32, + ), + ), + # 2x2 kernel, stride 1, no padding, NHWC + ( + "nhwc_basic_2x2", + torch.tensor( + [[[[1], [2], [3]], [[4], [5], [6]], [[7], [8], [9]]]], + dtype=torch.float32, + ), # (N=1, H=3, W=3, C=1) + (2, 2), + (1, 1), + (0, 0), + (1, 1), + None, + True, + False, + torch.tensor( + [ + [[1, 2, 4, 5], [2, 3, 5, 6], [4, 5, 7, 8], [5, 6, 8, 9]], + ], + dtype=torch.float32, + ), + ), + # 2x2 kernel, stride 1, no padding, NCHW, in_zero_point=1 + ( + "nchw_in_zero_point_no_padding", + torch.tensor([[[[2, 3, 4], [5, 6, 7], [8, 9, 10]]]], dtype=torch.int8), + (2, 2), + (1, 1), + (0, 0), + (1, 1), + torch.tensor(1, dtype=torch.int32), + False, + False, + torch.tensor( + [ + [[2, 3, 5, 6], [3, 4, 6, 7], [5, 6, 8, 9], [6, 7, 9, 10]], + ], + dtype=torch.int8, + ), + ), + ( + "nchw_in_zero_point_with_padding=1_and_stride=2", + torch.tensor([[[[2, 3, 4], [5, 6, 7], [8, 9, 10]]]], dtype=torch.int8), + (2, 2), + (1, 1), + (1, 1), + (2, 2), + torch.tensor(-1, dtype=torch.int32), + False, + False, + torch.tensor( + [ + [ + [-1, -1, -1, 2], + [-1, -1, 3, 4], + [-1, 5, -1, 8], + [6, 7, 9, 10], + ], + ], + dtype=torch.int8, + ), + ), + # 2x2 kernel, stride 1, no padding, NHWC, in_zero_point=2 + ( + "nhwc_in_zero_point", + torch.tensor( + [[[[3], [4], [5]], [[6], [7], [8]], [[9], [10], [11]]]], + dtype=torch.int8, + ), + (2, 2), + (1, 1), + (0, 0), + (1, 1), + torch.tensor(2, dtype=torch.int32), + True, + False, + torch.tensor( + [ + [[3, 4, 6, 7], [4, 5, 7, 8], [6, 7, 9, 10], [7, 8, 10, 11]], 
+ ], + dtype=torch.int8, + ), + ), + # Multi-channel input, 2x2 kernel, stride 1, no padding, NCHW + ( + "nchw_multi_channel", + torch.tensor( + [ + [ + [[1, 2, 3], [4, 5, 6], [7, 8, 9]], # channel 0 + [[10, 11, 12], [13, 14, 15], [16, 17, 18]], # channel 1 + ] + ], + dtype=torch.float32, + ), # (1,2,3,3) + (2, 2), + (1, 1), + (0, 0), + (1, 1), + None, + False, + False, + torch.tensor( + [ + [ + [1, 2, 4, 5, 10, 11, 13, 14], + [2, 3, 5, 6, 11, 12, 14, 15], + [4, 5, 7, 8, 13, 14, 16, 17], + [5, 6, 8, 9, 14, 15, 17, 18], + ], + ], + dtype=torch.float32, + ), + ), + # Multi-channel input and multi-channel zero-point + ( + "nchw_multi_channel_and_zero_point_no_padding", + torch.tensor([[[1, 2, 3]], [[4, 5, 6]]], dtype=torch.int32), + (1, 2), + (1, 1), + (0, 0), + (1, 1), + torch.tensor([-1, -2], dtype=torch.int32), + False, + False, + torch.tensor([[[1, 2], [2, 3]], [[4, 5], [5, 6]]], dtype=torch.int32), + ), + ( + "nchw_multi_channel_and_zero_point_with_padding=1_and_stride=(2, 1)", + torch.tensor([[[1, 2, 3]], [[4, 5, 6]]], dtype=torch.int32), + (1, 2), + (1, 1), + (2, 1), + (2, 2), + torch.tensor([-1, -2], dtype=torch.int32), + False, + False, + torch.tensor( + [ + [ + [-1, -1], + [-1, -1], + [-1, 1], + [2, 3], + [-1, -1], + [-1, -1], + ], + [ + [-2, -2], + [-2, -2], + [-2, 4], + [5, 6], + [-2, -2], + [-2, -2], + ], + ], + dtype=torch.int32, + ), + ), + ( + "per_tensor", + torch.tensor( + [[[[3], [4], [5]], [[6], [7], [8]], [[9], [10], [11]]]], + dtype=torch.int8, + ), + (2, 2), + (1, 1), + (0, 0), + (1, 1), + 2, + True, + True, + torch.tensor( + [ + [[3, 4, 6, 7], [4, 5, 7, 8], [6, 7, 9, 10], [7, 8, 10, 11]], + ], + dtype=torch.int8, + ), + ), + ] + ) + def test_im2row( + self, + name: str, + input_tensor: torch.Tensor, + kernel_size: tuple[int, int], + dilation: tuple[int, int], + padding: tuple[int, int], + stride: tuple[int, int], + in_zero_point: torch.Tensor | None, + channel_last: bool, + per_tensor: bool, + expected_output: torch.Tensor, + ) -> None: + if per_tensor: + output = torch.ops.cadence.im2row.per_tensor( + input_tensor, + kernel_size, + dilation, + padding, + stride, + in_zero_point, + channel_last, + ) + else: + output = torch.ops.cadence.im2row( + input_tensor, + kernel_size, + dilation, + padding, + stride, + in_zero_point, + channel_last, + ) + self.assertEqual( + output.shape, + expected_output.shape, + f"im2row output shape mismatch in {name}", + ) + self.assertTrue( + torch.equal(output, expected_output), + f"im2row output mismatch in {name}: got {output}, expected {expected_output}", + ) + + @expand( + [ + ( + "basic_2x2", + torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.int32), + (2, 2), + (1, 1), + (0, 0), + (1, 1), + (0, 0), + None, + False, + torch.tensor( + [ + [ + [0, 0, 0, 1], + [0, 0, 1, 2], + [0, 0, 2, 0], + [0, 1, 0, 3], + [1, 2, 3, 4], + [2, 0, 4, 0], + [0, 3, 0, 0], + [3, 4, 0, 0], + [4, 0, 0, 0], + ] + ], + dtype=torch.int32, + ), + ), + ( + "basic_2x2_with_zero_point", + torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.int32), + (2, 2), + (1, 1), + (0, 0), + (1, 1), + (0, 0), + torch.tensor(100, dtype=torch.int32), + False, + torch.tensor( + [ + [ + [100, 100, 100, 1], + [100, 100, 1, 2], + [100, 100, 2, 100], + [100, 1, 100, 3], + [1, 2, 3, 4], + [2, 100, 4, 100], + [100, 3, 100, 100], + [3, 4, 100, 100], + [4, 100, 100, 100], + ] + ], + dtype=torch.int32, + ), + ), + ( + "basic_2x2_with_stride_2", + torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.int32), + (2, 2), # kernel size + (1, 1), # dilation + (0, 0), # padding + (2, 2), # stride + (0, 0), # 
output padding + None, + False, + torch.tensor( + [ + [ + [0, 0, 0, 1], + [0, 0, 1, 0], + [0, 0, 0, 2], + [0, 0, 2, 0], + [0, 1, 0, 0], + [1, 0, 0, 0], + [0, 2, 0, 0], + [2, 0, 0, 0], + [0, 0, 0, 3], + [0, 0, 3, 0], + [0, 0, 0, 4], + [0, 0, 4, 0], + [0, 3, 0, 0], + [3, 0, 0, 0], + [0, 4, 0, 0], + [4, 0, 0, 0], + ] + ], + dtype=torch.int32, + ), + ), + ( + "batch2_with_batch2_zero_point", + torch.tensor( + [ + [[[1, 2], [3, 4]]], + [[[5, 6], [7, 8]]], + ], + dtype=torch.int32, + ), # input: (2,1,2,2) + (2, 2), # kernel_size + (1, 1), # dilation + (0, 0), # padding + (1, 1), # stride + (0, 0), # output_padding + torch.tensor([100, 200], dtype=torch.int32), # in_zero_point per batch + False, # channel_last + torch.tensor( + [ + [ + [100, 100, 100, 1], + [100, 100, 1, 2], + [100, 100, 2, 100], + [100, 1, 100, 3], + [1, 2, 3, 4], + [2, 100, 4, 100], + [100, 3, 100, 100], + [3, 4, 100, 100], + [4, 100, 100, 100], + ], + [ + [200, 200, 200, 5], + [200, 200, 5, 6], + [200, 200, 6, 200], + [200, 5, 200, 7], + [5, 6, 7, 8], + [6, 200, 8, 200], + [200, 7, 200, 200], + [7, 8, 200, 200], + [8, 200, 200, 200], + ], + ], + dtype=torch.int32, + ), + ), + ] + ) + def test_transposed_im2row( + self, + name: str, + input_tensor: torch.Tensor, + kernel_size: tuple[int, int], + dilation: tuple[int, int], + padding: tuple[int, int], + stride: tuple[int, int], + output_padding: tuple[int, int], + in_zero_point: torch.Tensor | int | None, + channel_last: bool, + expected_output: torch.Tensor, + ) -> None: + output = torch.ops.cadence.transposed_im2row( + input_tensor, + kernel_size, + dilation, + padding, + stride, + output_padding, + in_zero_point, + channel_last, + ) + + self.assertEqual( + output.shape, + expected_output.shape, + f"transposed_im2row output shape mismatch in {name}: got {output.shape}, expected {expected_output.shape}", + ) + self.assertTrue( + torch.equal(output, expected_output), + f"transposed_im2row output mismatch in {name}: got {output}, expected {expected_output}", + ) + + @expand( + [ + ( + "1_group", + torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.int8), + torch.tensor([1, 1, 1], dtype=torch.float32), + torch.tensor([0, 0, 0], dtype=torch.int8), + torch.tensor([0, 2, 1], dtype=torch.int64), + torch.tensor( + [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]], + dtype=torch.float32, + ), + ), + ( + "2_groups", + torch.tensor( + [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=torch.int8 + ), + torch.tensor([[0.5, 1.0], [1.5, 2.0], [2.5, 3.0]], dtype=torch.float32), + torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.int8), + torch.tensor([0, 2, 1], dtype=torch.int64), + torch.tensor( + [ + [0.0, 0.5, 1.0, 2.0], + [10.0, 12.5, 15.0, 18.0], + [3.0, 4.5, 6.0, 8.0], + ], + dtype=torch.float32, + ), + ), + ( + "1_group_none_zero_point", + torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.int8), + torch.tensor([1, 1, 1], dtype=torch.float32), + None, + torch.tensor([0, 2, 1], dtype=torch.int64), + torch.tensor( + [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]], + dtype=torch.float32, + ), + ), + ( + "1_group_batch2", + torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.int8), + torch.tensor([1, 1, 1], dtype=torch.float32), + torch.tensor([0, 0, 0], dtype=torch.int8), + torch.tensor([[0, 2, 1], [1, 0, 2]], dtype=torch.int64), + torch.tensor( + [ + [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]], + [[3.0, 4.0, 5.0], [0.0, 1.0, 2.0], [6.0, 7.0, 8.0]], + ], + dtype=torch.float32, + ), + ), + ( + "2_groups_batch2", + torch.tensor( + [[0, 1, 2, 3], [4, 
5, 6, 7], [8, 9, 10, 11]], dtype=torch.int8 + ), + torch.tensor([[0.5, 1.0], [1.5, 2.0], [2.5, 3.0]], dtype=torch.float32), + torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.int8), + torch.tensor([[0, 2, 1], [2, 1, 0]], dtype=torch.int64), + torch.tensor( + [ + [ + [0.0, 0.5, 1.0, 2.0], + [10.0, 12.5, 15.0, 18.0], + [3.0, 4.5, 6.0, 8.0], + ], + [ + [10.0, 12.5, 15.0, 18.0], + [3.0, 4.5, 6.0, 8.0], + [0.0, 0.5, 1.0, 2.0], + ], + ], + dtype=torch.float32, + ), + ), + ( + "1_group_none_zero_point_batch2", + torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.int8), + torch.tensor([1, 1, 1], dtype=torch.float32), + None, + torch.tensor([[0, 2, 1], [1, 0, 2]], dtype=torch.int64), + torch.tensor( + [ + [[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]], + [[3.0, 4.0, 5.0], [0.0, 1.0, 2.0], [6.0, 7.0, 8.0]], + ], + dtype=torch.float32, + ), + ), + ] + ) + def test_quantized_embedding_byte( + self, + name: str, + weight: torch.Tensor, + weight_scales: torch.Tensor, + weight_zero_points: torch.Tensor | None, + indices: torch.Tensor, + expected_out: torch.Tensor, + ) -> None: + self.assertTrue( + torch.equal( + torch.ops.cadence.quantized_embedding_byte( + weight, weight_scales, weight_zero_points, indices + ), + expected_out, + ) + ) + + @expand( + [ + *[ + ( + dtype, + (4, 4), + full_matrices, + ) + for dtype in [torch.float32, torch.float64] + for full_matrices in [True, False] + ] + ] + ) + def test_linalg_svd_outputs_are_contiguous( + self, + dtype: torch.dtype, + shape: tuple[int, int], + full_matrices: bool, + ) -> None: + m, n = shape + a = torch.eye(m, n, dtype=dtype) + + U, S, Vh = torch.ops.cadence.linalg_svd(a, full_matrices) + + self.assertTrue(U.is_contiguous(), "U not contiguous") + self.assertTrue(S.is_contiguous(), "S not contiguous") + self.assertTrue(Vh.is_contiguous(), "Vh not contiguous") + self.assertTrue(U.dtype == dtype, "U dtype mismatch") + self.assertTrue(S.dtype == dtype, "S dtype mismatch") + self.assertTrue(Vh.dtype == dtype, "Vh dtype mismatch") + + def test_quantized_add(self) -> None: + # Test quantized_add (default variant), just to make sure it runs since wrapper around per_tensor variant + X = torch.tensor([[1, 2], [3, 4]], dtype=torch.int8) + X_scale = torch.tensor([0.1]) + X_zero_point = torch.tensor([0]) + Y = torch.tensor([[5, 6], [7, 8]], dtype=torch.int8) + Y_scale = torch.tensor([0.1]) + Y_zero_point = torch.tensor([0]) + out_scale = 0.1 + out_zero_point = 0 + torch.ops.cadence.quantized_add( + X, + X_scale, + X_zero_point, + Y, + Y_scale, + Y_zero_point, + out_scale, + out_zero_point, + ) + + def test_requantize(self) -> None: + # Test requantize (default variant), just to make sure it runs since wrapper around per_tensor variant + input_tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.int8) + in_scale = torch.tensor([0.1]) + in_zero_point = torch.tensor([0]) + out_scale_tensor = torch.tensor([0.2]) + out_zero_point_tensor = torch.tensor([0]) + torch.ops.cadence.requantize( + input_tensor, + in_scale, + in_zero_point, + out_scale_tensor, + out_zero_point_tensor, + ScalarType.CHAR, + ) + + def test_quantized_conv2d_nchw(self) -> None: + # Test quantized_conv2d_nchw (default variant), just to make sure it runs since wrapper around per_tensor variant + input_conv = torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.int8) + weight_conv = torch.tensor([[[[1, 0], [0, 1]]]], dtype=torch.int8) + bias_conv = torch.tensor([0], dtype=torch.int32) + stride = [1, 1] + padding = [0, 0] + dilation = [1, 1] + groups = 1 + input_zero_point = 0 + 
weight_zero_point = torch.tensor([0]) + bias_scale = torch.tensor([1.0]) + conv_out_scale = 0.1 + conv_out_zero_point = 0 + out_multiplier = torch.tensor([1073741824], dtype=torch.int32) + out_shift = torch.tensor([0], dtype=torch.int32) + torch.ops.cadence.quantized_conv2d_nchw( + input_conv, + weight_conv, + bias_conv, + stride, + padding, + dilation, + groups, + input_zero_point, + weight_zero_point, + bias_scale, + conv_out_scale, + conv_out_zero_point, + out_multiplier, + out_shift, + ) + + def test_quantized_relu(self) -> None: + # Test quantized_relu (default variant), just to make sure it runs since wrapper around per_tensor variant + X_relu = torch.tensor([[-1, 0, 1, 3]], dtype=torch.int8) + X_zero_point_relu = torch.tensor([0]) + relu_out_zero_point = 0 + out_multiplier_relu = torch.tensor([1073741824], dtype=torch.int32) + out_shift_relu = torch.tensor([0], dtype=torch.int32) + torch.ops.cadence.quantized_relu( + X_relu, + X_zero_point_relu, + relu_out_zero_point, + out_multiplier_relu, + out_shift_relu, + ) + + def test_quantized_conv2d_nhwc(self) -> None: + # Test quantized_conv2d_nhwc (default variant), just to make sure it runs since wrapper around per_tensor variant + stride = [1, 1] + padding = [0, 0] + dilation = [1, 1] + groups = 1 + input_zero_point = 0 + weight_zero_point = torch.tensor([0]) + bias_scale = torch.tensor([1.0]) + conv_out_scale = 0.1 + conv_out_zero_point = 0 + input_nhwc = torch.tensor([[[[1], [2]], [[3], [4]]]], dtype=torch.int8) + weight_nhwc = torch.tensor([[[[1], [0]], [[0], [1]]]], dtype=torch.int8) + bias_nhwc = torch.tensor([0], dtype=torch.int32) + out_multiplier = torch.tensor([1073741824], dtype=torch.int32) + out_shift = torch.tensor([0], dtype=torch.int32) + torch.ops.cadence.quantized_conv2d_nhwc( + input_nhwc, + weight_nhwc, + bias_nhwc, + stride, + padding, + dilation, + groups, + input_zero_point, + weight_zero_point, + bias_scale, + conv_out_scale, + conv_out_zero_point, + out_multiplier, + out_shift, + ) + + def test_quantized_layer_norm(self) -> None: + # Test quantized_layer_norm (default variant), just to make sure it runs since wrapper around per_tensor variant + X_ln = torch.tensor([[-1, 1]], dtype=torch.int8) + X_scale_ln = torch.tensor([0.1]) + X_zero_point_ln = torch.tensor([0]) + normalized_shape = [2] + weight_ln = torch.tensor([1.0, 1.0]) + bias_ln = torch.tensor([0.0, 0.0]) + eps = 1e-5 + output_scale = 0.1 + output_zero_point = 0 + torch.ops.cadence.quantized_layer_norm( + X_ln, + X_scale_ln, + X_zero_point_ln, + normalized_shape, + weight_ln, + bias_ln, + eps, + output_scale, + output_zero_point, + ) + + def test_softmax_f32_f32(self) -> None: + # Just a wrapper around torch.nn.functional.softmax, so just ensure that it runs + input_tensor = torch.tensor( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32 + ) + output = torch.ops.cadence._softmax_f32_f32(input_tensor, dim=1) + self.assertEqual(output.dtype, torch.float32) + self.assertEqual(output.shape, input_tensor.shape) + + @expand( + [ + ( + "basic_hidden_dim_4", + torch.tensor([[1.0, 2.0]], dtype=torch.float32), # inputs: 1x2 + torch.tensor( + [[0.5, 0.5, 0.5, 0.5]], dtype=torch.float32 + ), # hidden: 1x4 + torch.ones( + (12, 2), dtype=torch.int8 + ), # weights_inputs: 12x2 (3*4 x input_dim=2) + 0.1, # w_i_scale + torch.ones((12, 4), dtype=torch.int8), # weights_hidden: 12x4 (3*4 x 4) + 0.1, # w_h_scale + torch.zeros(12, dtype=torch.int8), # bias_inputs: 12 + 0.1, # b_i_scale + torch.zeros(12, dtype=torch.int8), # bias_hidden: 12 + 0.1, # b_h_scale + ), + ( + 
"invalid_batch_size_2", + torch.tensor( + [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], dtype=torch.float32 + ), # inputs: 2x3 + torch.tensor( + [[0.5, 0.5, 0.5, 0.5], [0.3, 0.3, 0.3, 0.3]], dtype=torch.float32 + ), # hidden: 2x4 + torch.ones((12, 3), dtype=torch.int8), # weights_inputs: 12x3 + 0.1, # w_i_scale + torch.ones((12, 4), dtype=torch.int8), # weights_hidden: 12x4 + 0.1, # w_h_scale + torch.zeros(12, dtype=torch.int8), # bias_inputs: 12 + 0.1, # b_i_scale + torch.zeros(12, dtype=torch.int8), # bias_hidden: 12 + 0.1, # b_h_scale + ), + ( + "non_zero_biases", + torch.tensor([[1.0, 1.0]], dtype=torch.float32), # inputs: 1x2 + torch.zeros((1, 4), dtype=torch.float32), # hidden: 1x4 + torch.ones((12, 2), dtype=torch.int8), # weights_inputs: 12x2 + 0.2, # w_i_scale + torch.ones((12, 4), dtype=torch.int8), # weights_hidden: 12x4 + 0.1, # w_h_scale + torch.tensor( + [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=torch.int8 + ), # bias_inputs: 12 + 0.1, # b_i_scale + torch.tensor( + [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=torch.int8 + ), # bias_hidden: 12 + 0.1, # b_h_scale + ), + ( + "negative_weights", + torch.tensor([[1.0, -1.0]], dtype=torch.float32), # inputs: 1x2 + torch.tensor( + [[0.5, -0.5, 0.5, -0.5]], dtype=torch.float32 + ), # hidden: 1x4 + torch.tensor( + [[1, -1], [-1, 1]] * 6, dtype=torch.int8 + ), # weights_inputs: 12x2 (alternating pattern) + 0.1, # w_i_scale + torch.tensor( + [[1, -1, 1, -1], [-1, 1, -1, 1]] * 6, dtype=torch.int8 + ), # weights_hidden: 12x4 (alternating pattern) + 0.1, # w_h_scale + torch.zeros(12, dtype=torch.int8), # bias_inputs: 12 + 0.1, # b_i_scale + torch.zeros(12, dtype=torch.int8), # bias_hidden: 12 + 0.1, # b_h_scale + ), + ( + "hidden_dim_8", + torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32), # inputs: 1x3 + torch.tensor( + [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]], dtype=torch.float32 + ), # hidden: 1x8 + torch.ones((24, 3), dtype=torch.int8), # weights_inputs: 24x3 (3*8 x 3) + 0.1, # w_i_scale + torch.ones((24, 8), dtype=torch.int8), # weights_hidden: 24x8 (3*8 x 8) + 0.1, # w_h_scale + torch.zeros(24, dtype=torch.int8), # bias_inputs: 24 + 0.1, # b_i_scale + torch.zeros(24, dtype=torch.int8), # bias_hidden: 24 + 0.1, # b_h_scale + ), + ] + ) + def test_quantized_w8a32_gru( + self, + name: str, + inputs: torch.Tensor, + hidden: torch.Tensor, + weights_inputs: torch.Tensor, + w_i_scale: float, + weights_hidden: torch.Tensor, + w_h_scale: float, + bias_inputs: torch.Tensor, + b_i_scale: float, + bias_hidden: torch.Tensor, + b_h_scale: float, + ) -> None: + + if name == "invalid_batch_size_2": + with self.assertRaises(ValueError) as context: + torch.ops.cadence.quantized_w8a32_gru( + inputs, + hidden, + weights_inputs, + w_i_scale, + weights_hidden, + w_h_scale, + bias_inputs, + b_i_scale, + bias_hidden, + b_h_scale, + ) + self.assertIn( + "Leading dimension of hidden state must be 1", str(context.exception) + ) + return + + output = torch.ops.cadence.quantized_w8a32_gru( + inputs, + hidden, + weights_inputs, + w_i_scale, + weights_hidden, + w_h_scale, + bias_inputs, + b_i_scale, + bias_hidden, + b_h_scale, + ) + + # Verify output properties + self.assertEqual( + output.dtype, + torch.float32, + f"Output dtype should be float32 in {name}", + ) + self.assertEqual( + output.shape, + (2, hidden.shape[-1]), + f"Output shape should match {(2, hidden.shape[-1])} in {name}", + ) + assert isinstance(output, torch.Tensor) + + # Verify output is bounded: GRU hidden state is a convex combination of + # tanh([-1,1]) and previous hidden([-1,1]), so output 
should be in [-1,1]
+        self.assertTrue(
+            torch.all(output >= -1.0) and torch.all(output <= 1.0),
+            f"Output values should be in [-1.0, 1.0] in {name}. Got min={output.min():.4f}, max={output.max():.4f}",
+        )
+
+    def test_quantized_w8a32_gru_invalid_hidden_dim(self) -> None:
+        # Test that a hidden dimension that is not a multiple of 4 raises an error
+        inputs = torch.tensor([[1.0, 2.0]], dtype=torch.float32)  # 1x2
+        hidden = torch.tensor(
+            [[0.5, 0.5, 0.5]], dtype=torch.float32
+        )  # 1x3 (not divisible by 4)
+        weights_inputs = torch.zeros((9, 2), dtype=torch.int8)  # 9x2
+        weights_hidden = torch.zeros((9, 3), dtype=torch.int8)  # 9x3
+        bias_inputs = torch.zeros(9, dtype=torch.int8)
+        bias_hidden = torch.zeros(9, dtype=torch.int8)
+
+        with self.assertRaises(ValueError) as context:
+            torch.ops.cadence.quantized_w8a32_gru(
+                inputs,
+                hidden,
+                weights_inputs,
+                0.1,
+                weights_hidden,
+                0.1,
+                bias_inputs,
+                0.1,
+                bias_hidden,
+                0.1,
+            )
+
+        self.assertIn(
+            "Hidden dimension must be a multiple of 4", str(context.exception)
+        )
+
+    @expand(
+        [
+            (
+                "basic_int8_dim_1",
+                torch.tensor([[10, 20, 30]], dtype=torch.int8),
+                None,
+                1,
+                0.1,
+                0,
+                0.004,
+                0,
+                torch.int8,
+                torch.tensor([[23, 61, 127]], dtype=torch.int8),
+            ),
+            (
+                "uint8_with_zero_points",
+                torch.tensor([[128, 130, 132]], dtype=torch.uint8),
+                None,
+                1,
+                0.1,
+                128,
+                0.004,
+                128,
+                torch.uint8,
+                torch.tensor([[195, 210, 228]], dtype=torch.uint8),
+            ),
+            (
+                "basic_int16",
+                torch.tensor([[100, 200, 300]], dtype=torch.int16),
+                None,
+                1,
+                0.01,
+                0,
+                0.004,
+                0,
+                torch.int16,
+                torch.tensor([[23, 61, 166]], dtype=torch.int16),
+            ),
+            (
+                "multi_row_int8",
+                torch.tensor([[10, 20, 30], [5, 10, 15]], dtype=torch.int8),
+                None,
+                1,
+                0.1,
+                0,
+                0.004,
+                0,
+                torch.int8,
+                torch.tensor([[23, 61, 127], [47, 77, 127]], dtype=torch.int8),
+            ),
+            (
+                "softmax_dim_0",
+                torch.tensor([[10, 20], [30, 40]], dtype=torch.int8),
+                None,
+                0,
+                0.1,
+                0,
+                0.004,
+                0,
+                torch.int8,
+                torch.tensor([[30, 30], [127, 127]], dtype=torch.int8),
+            ),
+        ]
+    )
+    def test_quantized_softmax_per_tensor(
+        self,
+        name: str,
+        input_tensor: torch.Tensor,
+        mask: torch.Tensor | None,
+        dim: int,
+        in_scale: float,
+        in_zero_point: int,
+        out_scale: float,
+        out_zero_point: int,
+        dtype: torch.dtype,
+        expected_output: torch.Tensor,
+    ) -> None:
+        output = torch.ops.cadence.quantized_softmax.per_tensor(
+            input_tensor,
+            mask,
+            dim,
+            in_scale,
+            in_zero_point,
+            out_scale,
+            out_zero_point,
+        )
+
+        # Verify output properties
+        self.assertEqual(
+            output.dtype, dtype, f"Output dtype should be {dtype} in {name}"
+        )
+        self.assertEqual(
+            output.shape,
+            input_tensor.shape,
+            f"Output shape should match input shape in {name}",
+        )
+
+        # Verify output matches expected values (allowing for small quantization errors)
+        # For softmax, we expect outputs to be in [0, 1] range when dequantized
+        self.assertTrue(
+            torch.allclose(
+                output.to(torch.float32),
+                expected_output.to(torch.float32),
+                rtol=0.05,
+                atol=5.0,
+            ),
+            f"Output values don't match expected in {name}. 
Got {output}, expected {expected_output}", + ) + + def test_quantized_softmax(self) -> None: + # Test quantized_softmax (default variant with tensor scale/zero_point) + input_tensor = torch.tensor([[10, 20, 30]], dtype=torch.int8) + in_scale = torch.tensor([0.1]) + in_zero_point = torch.tensor([0]) + output = torch.ops.cadence.quantized_softmax( + input_tensor, + None, # mask + 1, # dim + in_scale, + in_zero_point, + 0.004, # out_scale + 0, # out_zero_point + ) + + # Verify output properties + self.assertEqual(output.dtype, torch.int8, "Output dtype should be int8") + self.assertEqual( + output.shape, + input_tensor.shape, + "Output shape should match input shape", + ) diff --git a/backends/cadence/aot/tests/test_remove_ops_passes.py b/backends/cadence/aot/tests/test_remove_ops_passes.py index a38416c0ff1..c957eb04b87 100644 --- a/backends/cadence/aot/tests/test_remove_ops_passes.py +++ b/backends/cadence/aot/tests/test_remove_ops_passes.py @@ -21,7 +21,7 @@ RemoveAliasCopyOpPass, RemoveBranchedQuantDequant, RemoveCatFromSliceCopyPass, - RemoveCloneOpPass, + RemoveCloneOpsTransformImported, RemoveContiguousOpPass, RemoveDetachCopyPass, RemoveNopAddOpPass, @@ -196,13 +196,13 @@ def test_remove_zero_sized_constant_pad_nd( ) builder.output([pad]) original = builder.get_graph_module() - graph_after_passes = cast( - PassResult, RemoveZeroSizedConstantPadNd()(original) - ).graph_module + pass_result = cast(PassResult, RemoveZeroSizedConstantPadNd()(original)) + graph_after_passes = pass_result.graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.constant_pad_nd.default), 0, ) + self.assertTrue(pass_result.modified) def test_remove_expand(self) -> None: builder = GraphBuilder() @@ -228,12 +228,12 @@ def test_remove_zero_arg_cat(self) -> None: ) builder.output([concat]) original = builder.get_graph_module() - graph_after_passes = cast( - PassResult, RemoveZeroSizedCatArgsPass()(original) - ).graph_module + pass_result = cast(PassResult, RemoveZeroSizedCatArgsPass()(original)) + graph_after_passes = pass_result.graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.cat.default), 0 ) + self.assertTrue(pass_result.modified) def test_remove_clone(self) -> None: builder = GraphBuilder() @@ -241,7 +241,7 @@ def test_remove_clone(self) -> None: clone = builder.call_operator(op=exir_ops.edge.aten.clone.default, args=(x,)) builder.output([clone]) original = builder.get_graph_module() - p = RemoveCloneOpPass() + p = RemoveCloneOpsTransformImported() graph_after_passes = cast(PassResult, p(original)).graph_module self.assertEqual( count_node(graph_after_passes, torch.ops.aten.clone.default), 0 @@ -304,6 +304,22 @@ def test_remove_nop_slice(self) -> None: count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0 ) + def test_remove_nop_slice_or_view_not_modified(self) -> None: + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32)) + abs_x = builder.call_operator( + op=exir_ops.edge.aten.abs.default, + args=(x,), + ) + builder.output([abs_x]) + original = builder.get_graph_module() + pass_result = cast(PassResult, RemoveNopSliceOrViewOpPass()(original)) + self.assertFalse(pass_result.modified) + graph_after_passes = pass_result.graph_module + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.abs.default), 1 + ) + def test_remove_nop_select_before_view(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32)) @@ -595,7 +611,9 @@ 
def test_remove_squeeze_view_before_elemwise_ops(self) -> None: original = deepcopy(model) p = RemoveSqueezeViewBeforeElementwiseOps() - transformed = cast(PassResult, p(model)).graph_module + pass_result = cast(PassResult, p(model)) + self.assertTrue(pass_result.modified) + transformed = pass_result.graph_module # First view should be eliminated and second view should be trivial. views = transformed.graph.find_nodes( @@ -856,9 +874,9 @@ def test_remove_dequant_on_branch(self) -> None: ) builder.output([x1_output, y1_output]) original = builder.get_graph_module() - graph_after_passes = cast( - PassResult, RemoveBranchedQuantDequant()(original) - ).graph_module + pass_result = cast(PassResult, RemoveBranchedQuantDequant()(original)) + self.assertTrue(pass_result.modified) + graph_after_passes = pass_result.graph_module self.assertEqual( count_node( graph_after_passes, @@ -888,9 +906,9 @@ def test_remove_cat_from_slice_copy(self) -> None: ) builder.output([output]) original = builder.get_graph_module() - graph_after_passes = cast( - PassResult, RemoveCatFromSliceCopyPass()(original) - ).graph_module + pass_result = cast(PassResult, RemoveCatFromSliceCopyPass()(original)) + self.assertTrue(pass_result.modified) + graph_after_passes = pass_result.graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.cat.default), 0 ) @@ -906,9 +924,9 @@ def test_keep_cat_from_slice_copy(self) -> None: ) builder.output([output]) original = builder.get_graph_module() - graph_after_passes = cast( - PassResult, RemoveCatFromSliceCopyPass()(original) - ).graph_module + pass_result = cast(PassResult, RemoveCatFromSliceCopyPass()(original)) + self.assertFalse(pass_result.modified) + graph_after_passes = pass_result.graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.cat.default), 1 ) diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index ca5168db2be..6c36a28b665 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -6,10 +6,13 @@ # pyre-strict +import copy import operator import unittest from typing import cast, List, Optional, Sequence, Tuple, Union +import executorch.backends.cadence.aot.ref_implementations # noqa + import torch from executorch.backends.cadence.aot.graph_builder import ( GraphBuilder, @@ -20,7 +23,7 @@ MakeSliceAndCatDimOutermostPass, ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass, ReplaceAddMMWithLinearPass, - ReplaceAtenApproxGeluWithApproxGeluPass, + ReplaceAtenAvgPoolWithCadenceAvgPoolPass, ReplaceAtenConvolutionWithCadenceConvolutionPass, ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass, ReplaceConstantPadNdWithSlicePass, @@ -31,6 +34,7 @@ ReplaceFunctionallyEquivalentOpTargets, ReplaceIm2RowWithViewPass, ReplaceLinearWithFullyConnectedOpPass, + ReplaceLogicalNotBooleanWhereWithWherePass, ReplaceMatmulWithTransposedMatmulPass, ReplaceMMWithAddMMPass, ReplaceMulTensorWithMulAndFullOpsPass, @@ -42,9 +46,9 @@ ReplaceScalarTensorWithFullPass, ReplaceScalarWithTensorArgPass, ReplaceSelectWithViewOpPass, - ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass, ReplaceSplitWithSlicePass, ReplaceSqueezeAndUnsqueezeWithViewPass, + ReplaceTorchQuantizedEmbeddingWithCadenceQuantizedEmbedding, ReplaceTransposedConvWithLinearPass, ReplaceTrivialConvWithLinear, ReplaceWhereWithFullArgsWithWhereScalar, @@ -52,9 +56,48 @@ from executorch.backends.cadence.aot.typing_stubs import expand from 
executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass -from executorch.exir.passes import dead_code_elimination_pass +from executorch.exir.pass_base import ExportPass, ProxyValue from torch.fx.passes.infra.pass_base import PassResult +from torch.utils import _pytree as pytree + + +def validate( + original: torch.fx.GraphModule, + modified: torch.fx.GraphModule, + inputs: tuple[torch.Tensor, ...] | list[torch.Tensor], + pass_name: str, + rtol: float = 1e-5, + atol: float = 1e-6, +) -> None: + """Validate that two graph modules produce numerically equivalent outputs. + + Args: + original: The original graph module before the pass + modified: The modified graph module after the pass + inputs: Input tensors to run through both graphs + pass_name: Name of the pass being validated (for error messages) + rtol: Relative tolerance for allclose comparison + atol: Absolute tolerance for allclose comparison + """ + original.eval() + modified.eval() + with torch.no_grad(): + orig_out = original(*inputs) + mod_out = modified(*inputs) + + flat_orig_out, _ = pytree.tree_flatten(orig_out) + flat_mod_out, _ = pytree.tree_flatten(mod_out) + + # Check that outputs match within tolerance + for i, (orig_tensor, mod_tensor) in enumerate(zip(flat_orig_out, flat_mod_out)): + if not torch.allclose(orig_tensor, mod_tensor, rtol=rtol, atol=atol): + max_diff = torch.max(torch.abs(orig_tensor - mod_tensor)).item() + raise AssertionError( + f"Pass validation failed for pass {pass_name}. " + f"Output tensor {i} differs by max {max_diff:.6e}. " + f"Expected rtol={rtol}, atol={atol}. " + f"Original output: {orig_tensor}, Modified output: {mod_tensor}" + ) class TestReplaceOpsPasses(unittest.TestCase): @@ -103,8 +146,10 @@ def test_replace_matmul_with_transposed_matmul( y_shape: Tuple[int], ) -> None: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(*x_shape, dtype=torch.float32)) - y = builder.placeholder("y", torch.randn(*y_shape, dtype=torch.float32)) + x_ = torch.randint(0, 100, x_shape, dtype=torch.int8) + x = builder.placeholder("x", x_) + y_ = torch.randint(0, 100, y_shape, dtype=torch.int8) + y = builder.placeholder("y", y_) matmul = builder.call_operator( op=exir_ops.edge.cadence.quantized_matmul.default, args=( @@ -121,8 +166,13 @@ def test_replace_matmul_with_transposed_matmul( ) builder.output([matmul]) original_gm = builder.get_graph_module() + + gm_before = copy.deepcopy(original_gm) p = ReplaceMatmulWithTransposedMatmulPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = p.call(original_gm) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 1, @@ -133,6 +183,12 @@ def test_replace_matmul_with_transposed_matmul( ), 1, ) + validate( + gm_before, + graph_after_passes, + (x_, y_), + "ReplaceMatmulWithTransposedMatmulPass", + ) @expand( [ @@ -145,15 +201,28 @@ def test_replace_constant_pad_nd_with_slice( self, _, shape: Tuple[int], padding: Tuple[int] ) -> None: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) + x_input = torch.randn(*shape, dtype=torch.float32) + x = builder.placeholder("x", x_input) matmul = builder.call_operator( op=exir_ops.edge.aten.constant_pad_nd.default, args=(x, [0, 0, 0, 0]), ) builder.output([matmul]) original_gm = builder.get_graph_module() + + # Deepcopy before the pass + gm_before = copy.deepcopy(original_gm) p = 
ReplaceConstantPadNdWithSlicePass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = [x_input] + validate( + gm_before, graph_after_passes, inputs, "ReplaceConstantPadNdWithSlicePass" + ) + self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.slice.Tensor), 1, @@ -298,7 +367,9 @@ def test_replace_functionally_equivalent_op_targets_relu( args=(x,), ) p = ReplaceFunctionallyEquivalentOpTargets() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.relu.default), @@ -345,6 +416,203 @@ def test_replace_functionally_equivalent_op_targets_unsafe_split( count_node(graph_after_passes, exir_ops.edge.aten.unsafe_split.Tensor), 0, x ) + def assertTensorMetadataIsSame( + self, a: Sequence[torch.Tensor], b: Sequence[torch.Tensor] + ) -> None: + for i, (_a, _b) in enumerate(zip(a, b)): + # TODO: actually compare the tensors. + self.assertTrue( + _a.shape == _b.shape, f"Tensor {i}: {_a.shape} != {_b.shape}" + ) + self.assertTrue( + _a.dtype == _b.dtype, f"Tensor {i}: {_a.dtype} != {_b.dtype}" + ) + + @expand( + [ + [(1, 8, 18), 8, 16, 3], + [(1, 8, 18), 8, 16, 5, 2], + # depthwise + bias + [(1, 8, 18), 8, 16, 5, 2, 0, 1, True], + # no bias + [(1, 8, 18), 8, 16, 3, 2, 4, 3, False, False], + # bias + transposed + [(1, 8, 18), 8, 16, 5, 2, 0, 1, False, True], + # Stride of 2 needed. + [(1, 8, 3), 8, 8, 48, 2, 23], + ] + ) + @torch.no_grad() + def test_replace_aten_conv_with_cadence_conv( + self, + shape: Tuple[int, ...], + in_channels: int, + out_channels: int, + kernel: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + depthwise: bool = False, + bias_enabled: bool = True, + output_padding: Optional[int] = None, + ) -> None: + groups = in_channels if depthwise else 1 + builder = GraphBuilder() + x_tensor = torch.randn(*shape, dtype=torch.float32) + x = builder.placeholder("x", x_tensor) + # For regular conv: weight shape is [out_channels, in_channels // groups, kernel] + weights_shape = [out_channels, in_channels // groups, kernel] + weights_tensor = torch.randn(weights_shape, dtype=torch.float32) + weights = builder.placeholder("weights", weights_tensor) + bias: Optional[ProxyValue] = None + bias_tensor: Optional[torch.Tensor] = None + if bias_enabled: + bias_tensor = torch.randn([out_channels], dtype=torch.float32) + bias = builder.placeholder("bias", bias_tensor) + convolution = builder.call_operator( + op=exir_ops.edge.aten.convolution.default, + args=( + x, + weights, + bias, + [stride], + [padding], + [dilation], + False, + [output_padding] if output_padding else [0], + groups, + ), + ) + builder.output([convolution]) + original_gm = builder.get_graph_module() + + gm_before = copy.deepcopy(original_gm) + p = ReplaceAtenConvolutionWithCadenceConvolutionPass() + replacement_pass_result = cast(PassResult, p(original_gm)) + self.assertIsNotNone(replacement_pass_result) + self.assertTrue(replacement_pass_result.modified) + graph_after_passes = replacement_pass_result.graph_module + + # Validate numerical accuracy + inputs = (x_tensor, weights_tensor) + if bias is not None: + inputs += (cast(torch.Tensor, bias_tensor),) + validate( + gm_before, + graph_after_passes, + inputs, + 
"ReplaceAtenConvolutionWithCadenceConvolutionPass", + ) + + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.convolution.default), + 0, + ) + # This is a 1D convolution (using [stride], [padding], [dilation]) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.cadence.conv1d.default), + 1, + ) + self.assertEqual( + count_node( + graph_after_passes, exir_ops.edge.cadence.transposed_convolution.default + ), + 0, + ) + + @expand( + [ + [(1, 8, 16), 8, 16, 3], + [(1, 8, 16), 8, 16, 5, 2], + # depthwise + bias + [(1, 8, 16), 8, 16, 5, 2, 0, 1, True, True], + # no bias + [(1, 8, 16), 8, 16, 3, 2, 4, 3, False, False], + # depthwise + no bias + [(1, 8, 16), 8, 16, 3, 1, 0, 1, True, False], + # bias + [(1, 8, 16), 8, 16, 5, 2, 0, 1, False, True], + ] + ) + @torch.no_grad() + def test_replace_aten_transposed_conv_with_cadence_transposed_conv( + self, + shape: Tuple[int, ...], + in_channels: int, + out_channels: int, + kernel: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + depthwise: bool = False, + bias_enabled: bool = True, + output_padding: Optional[int] = None, + ) -> None: + groups = in_channels if depthwise else 1 + builder = GraphBuilder() + x_tensor = torch.randn(*shape, dtype=torch.float32) + x = builder.placeholder("x", x_tensor) + # For transposed conv: weight shape is [in_channels, out_channels // groups, kernel] + weights_shape = [in_channels, out_channels // groups, kernel] + weights_tensor = torch.randn(weights_shape, dtype=torch.float32) + weights = builder.placeholder( + "weights", + weights_tensor, + ) + bias_tensor = ( + torch.randn([out_channels], dtype=torch.float32) if bias_enabled else None + ) + bias = ( + builder.placeholder("bias", cast(torch.Tensor, bias_tensor)) + if bias_enabled + else None + ) + convolution = builder.call_operator( + op=exir_ops.edge.aten.convolution.default, + args=( + x, + weights, + bias, + [stride], + [padding], + [dilation], + True, + [output_padding] if output_padding else [0], + groups, + ), + ) + builder.output([convolution]) + original_gm = builder.get_graph_module() + gm_before = copy.deepcopy(original_gm) + + p = ReplaceAtenConvolutionWithCadenceConvolutionPass() + replacement_pass_result = cast(PassResult, p(original_gm)) + self.assertIsNotNone(replacement_pass_result) + self.assertTrue(replacement_pass_result.modified) + graph_after_passes = replacement_pass_result.graph_module + + inputs = (x_tensor, weights_tensor) + if bias_tensor is not None: + inputs += (bias_tensor,) + + validate( + gm_before, + graph_after_passes, + inputs, + "ReplaceAtenConvolutionWithCadenceConvolutionPass", + ) + + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.convolution.default), + 0, + ) + self.assertEqual( + count_node( + graph_after_passes, exir_ops.edge.cadence.transposed_convolution.default + ), + 1, + ) + @expand( [ [(1, 8, 33), 8, 16, 3], @@ -358,7 +626,7 @@ def test_replace_functionally_equivalent_op_targets_unsafe_split( @torch.no_grad() def test_replace_transposed_conv_with_linear( self, - shape: Tuple[int], + shape: Tuple[int, ...], in_channels: int, out_channels: int, kernel: int, @@ -369,19 +637,29 @@ def test_replace_transposed_conv_with_linear( bias_enabled: bool = True, channel_last: bool = False, ) -> None: - transposed = True output_padding = [0] groups = in_channels if depthwise else 1 builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) - weights = builder.placeholder( - "weights", - torch.randn([in_channels, 
out_channels, kernel], dtype=torch.float32), + x_tensor = torch.randn(*shape, dtype=torch.float32) + x = builder.placeholder("x", x_tensor) + # For transposed conv: weight shape is [in_channels, out_channels // groups, kernel] + weights_tensor = torch.randn( + [in_channels, out_channels // groups, kernel], dtype=torch.float32 + ) + weights = builder.placeholder("weights", weights_tensor) + + transposed_weights = builder.call_operator( + op=exir_ops.edge.aten.transpose_copy.int, args=(weights, 0, 1) + ) + flipped_weights = builder.call_operator( + exir_ops.edge.aten.flip.default, + args=(transposed_weights, [-1]), + ) + bias_tensor = ( + torch.randn([out_channels], dtype=torch.float32) if bias_enabled else None ) bias = ( - builder.placeholder( - "bias", torch.randn([out_channels], dtype=torch.float32) - ) + builder.placeholder("bias", cast(torch.Tensor, bias_tensor)) if bias_enabled else None ) @@ -391,17 +669,17 @@ def test_replace_transposed_conv_with_linear( args=(x, [0, 2, 1]), ) convolution = builder.call_operator( - op=exir_ops.edge.aten.convolution.default, + op=exir_ops.edge.cadence.transposed_convolution.default, args=( x, - weights, + flipped_weights, bias, [stride], [padding], [dilation], - transposed, output_padding, groups, + False, ), ) if channel_last: @@ -412,11 +690,24 @@ def test_replace_transposed_conv_with_linear( builder.output([convolution]) original_gm = builder.get_graph_module() - p1 = ReplaceAtenConvolutionWithCadenceConvolutionPass() - p2 = ReplaceTransposedConvWithLinearPass() - graph_after_passes = cast( - PassResult, p2(cast(PassResult, p1(original_gm)).graph_module) - ).graph_module + gm_before = copy.deepcopy(original_gm) + + # Run ReplaceTransposedConvWithLinearPass + result = ReplaceTransposedConvWithLinearPass().call(original_gm) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = (x_tensor, weights_tensor) + if bias_tensor is not None: + inputs += (bias_tensor,) + validate( + gm_before, + graph_after_passes, + inputs, + "ReplaceTransposedConvWithLinearPass", + ) + self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.linear.default), 1, @@ -426,19 +717,23 @@ def test_replace_transposed_conv_with_linear( 0, ) self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default), + count_node(graph_after_passes, exir_ops.edge.cadence.conv1d.default) + + count_node(graph_after_passes, exir_ops.edge.cadence.conv2d.default), + 0, + ) + self.assertEqual( + count_node( + graph_after_passes, exir_ops.edge.cadence.transposed_convolution.default + ), 0, ) @expand( [ - [(1, 8, 33), 8, 16, 3, 2, 4, 3, False, False, False], + [(1, 8, 33), 8, 16, 3, 2, 4, 3, False, False], # # depthwise - [(1, 8, 33), 8, 16, 3, 1, 0, 1, True, False, False], - [(1, 8, 33), 8, 16, 3, 2, 4, 3, True, False, False], - # channel last (uses a permute op before calling conv1d) - [(1, 33, 8), 8, 16, 3, 1, 0, 1, False, False, True], - [(1, 33, 8), 8, 16, 3, 2, 4, 3, True, False, True], + [(1, 8, 33), 8, 16, 3, 1, 0, 1, True, False], + [(1, 8, 33), 8, 16, 3, 2, 4, 3, True, False], ] ) @torch.no_grad() @@ -453,31 +748,24 @@ def test_replace_convolution_optional_args_with_concrete_args( dilation: int = 1, depthwise: bool = False, bias_enabled: bool = True, - channel_last: bool = False, ) -> None: - transposed = True - output_padding = [0] groups = in_channels if depthwise else 1 builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) - weights = 
builder.placeholder( - "weights", - torch.randn([in_channels, out_channels, kernel], dtype=torch.float32), - ) - bias = ( - builder.placeholder( - "bias", torch.randn([out_channels], dtype=torch.float32) - ) - if bias_enabled - else None - ) - if channel_last: - x = builder.call_operator( - op=exir_ops.edge.aten.permute_copy.default, - args=(x, [0, 2, 1]), - ) + x_input = torch.randn(*shape, dtype=torch.float32) + weights_input = torch.randn( + [out_channels, in_channels // groups, kernel], dtype=torch.float32 + ) + x = builder.placeholder("x", x_input) + weights = builder.placeholder("weights", weights_input) + bias_input = None + if bias_enabled: + bias_input = torch.randn([out_channels], dtype=torch.float32) + bias = builder.placeholder("bias", bias_input) + else: + bias = None + convolution = builder.call_operator( - op=exir_ops.edge.aten.convolution.default, + op=exir_ops.edge.cadence.conv1d.default, args=( x, weights, @@ -485,26 +773,34 @@ def test_replace_convolution_optional_args_with_concrete_args( [stride], [padding], [dilation], - transposed, - output_padding, groups, ), ) - if channel_last: - convolution = builder.call_operator( - op=exir_ops.edge.aten.permute_copy.default, - args=(convolution, [0, 2, 1]), - ) builder.output([convolution]) original_gm = builder.get_graph_module() + + gm_before = copy.deepcopy(original_gm) p = ReplaceConvolutionOptionalArgsWithConcreteArgsPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + inputs = [x_input, weights_input] + ( + [bias_input] if bias_input is not None else [] + ) + validate( + gm_before, + graph_after_passes, + inputs, + "ReplaceConvolutionOptionalArgsWithConcreteArgsPass", + ) + self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.full.default), 1, ) self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.aten.convolution.default), + count_node(graph_after_passes, exir_ops.edge.cadence.conv1d.default), 1, ) @@ -525,8 +821,17 @@ def test_replace_pad_with_cat(self, shape: Tuple[int], padding: Tuple[int]) -> N op=exir_ops.edge.aten.constant_pad_nd.default, args=(x, padding), ) + + gm_before = copy.deepcopy(original_gm) p = ReplacePadWithCatPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = [x] + validate(gm_before, graph_after_passes, inputs, "ReplacePadWithCatPass") + self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.cat.default), 1, @@ -544,8 +849,16 @@ def test_replace_repeat_with_cat(self) -> None: op=exir_ops.edge.aten.repeat.default, args=(x, [1, 2]), ) + + gm_before = copy.deepcopy(original_gm) p = ReplaceRepeatWithCatPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + inputs = [x] + validate(gm_before, graph_after_passes, inputs, "ReplaceRepeatWithCatPass") + self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.cat.default), 1, @@ -597,7 +910,9 @@ def test_replace_masked_scalar_tensor_with_full( builder.output([aten_where_self]) original_gm = builder.get_graph_module() p = ReplaceScalarTensorWithFullPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = 
cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.full.default), 1, @@ -621,7 +936,9 @@ def test_replace_scalar_tensor_with_full( args=(0.123,), ) p = ReplaceScalarTensorWithFullPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.full.default), 1, @@ -635,10 +952,10 @@ def test_replace_scalar_tensor_with_full( def test_replace_linear_with_fully_connected(self) -> None: shape, in_channels, out_channels = (1, 14), 14, 128 builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) - weights = builder.placeholder( - "weights", torch.randn([out_channels, in_channels], dtype=torch.float32) - ) + x_input = torch.randn(*shape, dtype=torch.float32) + weights_input = torch.randn([out_channels, in_channels], dtype=torch.float32) + x = builder.placeholder("x", x_input) + weights = builder.placeholder("weights", weights_input) permute_copy = builder.call_operator( op=exir_ops.edge.aten.permute_copy.default, args=(weights, [1, 0]), @@ -649,14 +966,31 @@ def test_replace_linear_with_fully_connected(self) -> None: ) builder.output([mm]) original_gm = builder.get_graph_module() + gm = cast( PassResult, ReplacePermuteWithTransposePass()(original_gm) ).graph_module gm = cast(PassResult, ReplaceMMWithAddMMPass()(gm)).graph_module - gm = cast(PassResult, ReplaceAddMMWithLinearPass()(gm)).graph_module + + gm_before_linear = copy.deepcopy(gm) + pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(gm)) + self.assertTrue(pass_result.modified) + gm = pass_result.graph_module + + inputs = [x_input, weights_input] + validate(gm_before_linear, gm, inputs, "ReplaceAddMMWithLinearPass") + gm_before_fc = copy.deepcopy(gm) graph_after_passes = cast( PassResult, ReplaceLinearWithFullyConnectedOpPass()(gm) ).graph_module + + validate( + gm_before_fc, + graph_after_passes, + inputs, + "ReplaceLinearWithFullyConnectedOpPass", + ) + self.assertIsNotNone(graph_after_passes) self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.full.default), @@ -673,21 +1007,17 @@ def test_replace_linear_with_fully_connected(self) -> None: 0, ) - @expand( - [ - [(4, 16, 256), 256, 512, True], - [(7, 17, 12), 12, 34, False], - ] - ) + @expand([[1.0, 1.0], [2.0, 3.0]]) @torch.no_grad() - def test_replace_addmm_with_linear( - self, shape: Tuple[int], in_features: int, out_features: int, bias: bool - ) -> None: - M, K, N, alpha, beta = 14, 48, 24, 1.0, 1.0 + def test_replace_addmm_with_linear(self, alpha: float, beta: float) -> None: + M, K, N = 14, 12, 10 builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(N, dtype=torch.float32)) - y = builder.placeholder("y", torch.randn([M, K], dtype=torch.float32)) - z = builder.placeholder("z", torch.randn([N, K], dtype=torch.float32)) + x_input = torch.randn(N, dtype=torch.float32) + y_input = torch.randn([M, K], dtype=torch.float32) + z_input = torch.randn([N, K], dtype=torch.float32) + x = builder.placeholder("x", x_input) + y = builder.placeholder("y", y_input) + z = builder.placeholder("z", z_input) permute_copy = builder.call_operator( op=exir_ops.edge.aten.permute_copy.default, args=(z, [1, 0]), @@ -699,12 +1029,21 @@ def test_replace_addmm_with_linear( ) builder.output([addmm]) original_gm 
= builder.get_graph_module() + gm = cast( PassResult, ReplacePermuteWithTransposePass()(original_gm) ).graph_module - graph_after_passes = cast( - PassResult, ReplaceAddMMWithLinearPass()(gm) - ).graph_module + + gm_before_linear = copy.deepcopy(gm) + pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(gm)) + self.assertTrue(pass_result.modified) + graph_after_passes = pass_result.graph_module + + inputs = [x_input, y_input, z_input] + validate( + gm_before_linear, graph_after_passes, inputs, "ReplaceAddMMWithLinearPass" + ) + self.assertIsNotNone(graph_after_passes) self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.linear.default), @@ -725,8 +1064,17 @@ def test_replace_mm_with_addmm(self) -> None: op=exir_ops.edge.aten.mm.default, args=(x, y), ) + + gm_before = copy.deepcopy(original_gm) p = ReplaceMMWithAddMMPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = [x, y] + validate(gm_before, graph_after_passes, inputs, "ReplaceMMWithAddMMPass") + self.assertIsNotNone(graph_after_passes) self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.addmm.default), @@ -767,7 +1115,12 @@ def test_replace_squeeze_with_view( args=(x,), ) p = ReplaceSqueezeAndUnsqueezeWithViewPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + + # Assert: Verify the pass modified the graph + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + self.assertIsNotNone(graph_after_passes) self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), @@ -802,7 +1155,12 @@ def test_replace_unsqueeze_with_view(self, shape: Tuple[int], dim: int) -> None: args=(x, dim), ) p = ReplaceSqueezeAndUnsqueezeWithViewPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + + # Assert: Verify the pass modified the graph + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + self.assertIsNotNone(graph_after_passes) self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), @@ -814,206 +1172,87 @@ def test_replace_unsqueeze_with_view(self, shape: Tuple[int], dim: int) -> None: ) @torch.no_grad() - def test_replace_single_element_tensor_arguments_from_full_op_with_scalar( - self, - in_features: int = 16, - out_features: int = 16, - ) -> None: - src_zero_point = 0 - out_zero_point = 0 - builder = GraphBuilder() - x = builder.placeholder("x", torch.randn([1, in_features])) - weights = builder.placeholder( - "weights", torch.randn([in_features, out_features], dtype=torch.float32) - ) - bias = builder.placeholder( - "bias", torch.randn([out_features], dtype=torch.float32) + def test_replace_squeeze_and_unsqueeze_with_view_no_modification(self) -> None: + """Negative test: pass doesn't modify graphs without squeeze/unsqueeze ops.""" + x = torch.randn(2, 3, 4) + original_gm = single_op_builder( + placeholders=(x,), + op=exir_ops.edge.aten.view_copy.default, + args=(x, [2, 12]), ) - quantized_input = builder.call_operator( - op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - args=(x, 0.01431146077811718, 57, -128, 127, torch.int8), + p = ReplaceSqueezeAndUnsqueezeWithViewPass() + result = cast(PassResult, p(original_gm)) + + # Assert: Verify the pass did NOT modify the 
graph + self.assertFalse(result.modified) + graph_after_passes = result.graph_module + + # Verify the original view_copy operation is still there + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), + 1, ) - weight_zero_point = builder.call_operator( - op=exir_ops.edge.aten.full.default, - args=([1], 0), + + @torch.no_grad() + def test_replace_conv1d_with_linear(self) -> None: + x = torch.randn(1, 96, 7) + weights = torch.randn(192, 96, 7) + bias = torch.randn(192) + original_gm = single_op_builder( + placeholders=(x, weights, bias), + op=exir_ops.edge.cadence.conv1d.default, + args=(x, weights, bias, [1], [0], [1], 1), ) - out_multiplier = builder.call_operator( - op=exir_ops.edge.aten.full.default, - args=([1], 0), + + gm_before = copy.deepcopy(original_gm) + p2 = ReplaceTrivialConvWithLinear() + result = cast(PassResult, p2(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = [x, weights, bias] + validate(gm_before, graph_after_passes, inputs, "ReplaceTrivialConvWithLinear") + + # Assert that conv1d is trivially converted to linear + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.cadence.conv1d.default), 0 ) - out_shift = builder.call_operator( - op=exir_ops.edge.aten.full.default, - args=([1], 0), + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.cadence.im2row.default), 0 ) - output = builder.call_operator( - op=exir_ops.edge.cadence.quantized_linear.default, - args=( - quantized_input, - weights, - bias, - src_zero_point, - weight_zero_point, - out_multiplier, - out_shift, - out_zero_point, - None, - ), - ) - dequantized_output = builder.call_operator( - op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - args=(output, 0.010696045123040676, -31, -128, 127, torch.int8), - ) - builder.output([dequantized_output]) - original_gm = builder.get_graph_module() - p = ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module - self.assertIsNotNone(graph_after_passes) - gm = dead_code_elimination_pass(graph_after_passes).graph_module - # By default, the quantized linear op should have constant scalar attributes. - self.assertTargetCountsEqual( - gm, - [ - # No default quantized linear op. - (exir_ops.edge.cadence.quantized_linear.default, 0), - # The default quantized linear op will be replaced with quantized_linear.per_tensor. - (exir_ops.edge.cadence.quantized_linear.per_tensor, 1), - # No aten.full ops. 
- (exir_ops.edge.aten.full.default, 0), - ], - ) - - @torch.no_grad() - def test_replace_single_element_tensor_arguments_from_full_op_with_scalar_tuple_args( - self, - in_features: int = 16, - out_features: int = 16, - ) -> None: - src_zero_point = 0 - out_zero_point = 0 - builder = GraphBuilder() - x = builder.placeholder("x", torch.randn([1, in_features])) - weights = builder.placeholder( - "weights", torch.randn([in_features, out_features], dtype=torch.float32) - ) - bias = builder.placeholder( - "bias", torch.randn([out_features], dtype=torch.float32) - ) - quantized_input = builder.call_operator( - op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - args=(x, 0.01431146077811718, 57, -128, 127, torch.int8), - ) - weight_zero_point = builder.call_operator( - op=exir_ops.edge.aten.full.default, - args=([1], 0), - ) - out_multiplier = builder.call_operator( - op=exir_ops.edge.aten.full.default, - args=([1], 0), - ) - out_shift = builder.call_operator( - op=exir_ops.edge.aten.full.default, - args=([1], 0), - ) - output = builder.call_operator( - op=exir_ops.edge.cadence.quantized_linear.default, - args=( - quantized_input, - weights, - bias, - src_zero_point, - weight_zero_point, - out_multiplier, - out_shift, - out_zero_point, - None, - ), - ) - dequantized_output = builder.call_operator( - op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - args=(output, 0.010696045123040676, -31, -128, 127, torch.int8), - ) - builder.output([dequantized_output]) - original_gm = builder.get_graph_module() - - for node in original_gm.graph.nodes: - # Replace the `shape` argument for aten.full op with a tuple. - if node.target == exir_ops.edge.aten.full.default: - node.args = (tuple(node.args[0]), node.args[1]) - - # Apply replacement pass. - p = ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module - self.assertIsNotNone(graph_after_passes) - gm = dead_code_elimination_pass(graph_after_passes).graph_module - - # By default, the quantized linear op should have constant scalar attributes. - self.assertTargetCountsEqual( - gm, - [ - # No default quantized linear op. - (exir_ops.edge.cadence.quantized_linear.default, 0), - # The default quantized linear op will be replaced with quantized_linear.per_tensor. - (exir_ops.edge.cadence.quantized_linear.per_tensor, 1), - # No aten.full ops. 
- (exir_ops.edge.aten.full.default, 0), - ], - ) - - @torch.no_grad() - def test_replace_conv1d_with_linear(self) -> None: - x = torch.randn(1, 96, 7) - weights = torch.randn(192, 96, 7) - bias = torch.randn(192) - original_gm = single_op_builder( - placeholders=(x, weights, bias), - op=exir_ops.edge.cadence.convolution.default, - args=(x, weights, bias, [1], [0], [1], 1, False), - ) - # First, replace the aten convolution with a cadence.convolution op - p1 = ReplaceAtenConvolutionWithCadenceConvolutionPass() - temp_graph = cast(PassResult, p1(original_gm)).graph_module - # temp_graph = p1(original_gm).graph_module - self.assertIsNotNone(temp_graph) - - p2 = ReplaceTrivialConvWithLinear() - graph_after_passes = cast(PassResult, p2(temp_graph)).graph_module - - # Assert that conv1d is trivially converted to linear - self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default), 0 - ) - self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.cadence.im2row.default), 0 - ) - self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.aten.linear.default) - + count_node( - graph_after_passes, exir_ops.edge.cadence.fully_connected.default + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.linear.default) + + count_node( + graph_after_passes, exir_ops.edge.cadence.fully_connected.default ), 1, ) @torch.no_grad() def test_replace_conv2d_with_linear(self) -> None: - x = torch.randn(1, 96, 7, 7) - weights = torch.randn(192, 96, 7, 7) - bias = torch.randn(192) + x = torch.randn(1, 6, 7, 7) + weights = torch.randn(12, 6, 7, 7) + bias = torch.randn(12) original_gm = single_op_builder( placeholders=(x, weights, bias), - op=exir_ops.edge.cadence.convolution.default, - args=(x, weights, bias, [1, 1], [0, 0], [1, 1], 1, False), + op=exir_ops.edge.cadence.conv2d.default, + args=(x, weights, bias, [1, 1], [0, 0], [1, 1], 1), ) - # First, replace the aten convolution with a cadence.convolution op - p1 = ReplaceAtenConvolutionWithCadenceConvolutionPass() - temp_graph = cast(PassResult, p1(original_gm)).graph_module - self.assertIsNotNone(temp_graph) + gm_before = copy.deepcopy(original_gm) p2 = ReplaceTrivialConvWithLinear() - graph_after_passes = cast(PassResult, p2(temp_graph)).graph_module + result = cast(PassResult, p2(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = [x, weights, bias] + validate(gm_before, graph_after_passes, inputs, "ReplaceTrivialConvWithLinear") # Assert that conv2d is trivially converted to linear self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default), 0 + count_node(graph_after_passes, exir_ops.edge.cadence.conv2d.default), 0 ) self.assertEqual( count_node(graph_after_passes, exir_ops.edge.cadence.im2row.default), 0 @@ -1028,23 +1267,33 @@ def test_replace_conv2d_with_linear(self) -> None: @torch.no_grad() def test_replace_conv2d_with_im2row_and_linear(self) -> None: - x = torch.randn(1, 96, 47, 37) - weights = torch.randn(192, 96, 7, 7) - bias = torch.randn(192) + x = torch.randn(1, 2, 5, 5) + weights = torch.randn(3, 2, 4, 4) + bias = torch.randn(3) original_gm = single_op_builder( placeholders=(x, weights, bias), - op=exir_ops.edge.cadence.convolution.default, - args=(x, weights, bias, [1, 1], [0, 0], [1, 1], 1, False), + op=exir_ops.edge.cadence.conv2d.default, + args=(x, weights, bias, [1, 1], [0, 0], [1, 1], 1), ) + + gm_before = copy.deepcopy(original_gm) p = 
ReplaceConvWithIm2RowAndLinear() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = [x, weights, bias] + validate( + gm_before, graph_after_passes, inputs, "ReplaceConvWithIm2RowAndLinear" + ) # Assert that the convolution is converted to im2row + linear self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default), 0 + count_node(graph_after_passes, exir_ops.edge.cadence.conv2d.default), 0 ) self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.cadence.im2row.default), 1 + count_node(graph_after_passes, exir_ops.edge.cadence.im2row.per_tensor), 1 ) self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.linear.default), 1 @@ -1066,8 +1315,17 @@ def test_replace_select_with_view( op=exir_ops.edge.aten.select_copy.int, args=(x, dim, index), ) + + gm_before = copy.deepcopy(original_gm) p = ReplaceSelectWithViewOpPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = [x] + validate(gm_before, graph_after_passes, inputs, "ReplaceSelectWithViewOpPass") + # Assert that select op was replaced with view op self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0 @@ -1100,8 +1358,21 @@ def test_replace_nop_transpose_with_view( op=exir_ops.edge.aten.transpose_copy.int, args=(x, dim0, dim1), ) + + gm_before = copy.deepcopy(original_gm) p = ReplaceNopTransposeOrPermuteWithViewPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = [x] + validate( + gm_before, + graph_after_passes, + inputs, + "ReplaceNopTransposeOrPermuteWithViewPass", + ) # Assert that transpose op was removed, and a view op was placed instead self.assertEqual( @@ -1128,8 +1399,21 @@ def test_replace_nop_permute_with_view( op=exir_ops.edge.aten.permute_copy.default, args=(x, dims), ) + + gm_before = copy.deepcopy(original_gm) p = ReplaceNopTransposeOrPermuteWithViewPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = [x] + validate( + gm_before, + graph_after_passes, + inputs, + "ReplaceNopTransposeOrPermuteWithViewPass", + ) # Assert that permute op was removed, and a view op was placed instead self.assertEqual( @@ -1156,8 +1440,16 @@ def test_replace_permute_with_transpose( op=exir_ops.edge.aten.permute_copy.default, args=(x, dims), ) + + gm_before = copy.deepcopy(original_gm) p = ReplacePermuteWithTransposePass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + inputs = [x] + validate( + gm_before, graph_after_passes, inputs, "ReplacePermuteWithTransposePass" + ) # Assert that permute op was replaced by a transpose op self.assertEqual( @@ -1188,9 +1480,12 @@ def test_replace_permute_with_transpose_nop( count_node(graph_after_passes, 
exir_ops.edge.aten.transpose_copy.int), 0 ) + +class TestReplaceWhereWithFullArgsWithWhereScalar(unittest.TestCase): def test_replace_aten_where_with_cadence(self) -> None: builder = GraphBuilder() - cond = builder.placeholder("cond", torch.randn(4, 8)) + cond_input = torch.randn(4, 8) + cond = builder.placeholder("cond", cond_input) aten_gt_scalar = builder.call_operator( op=exir_ops.edge.aten.gt.Scalar, args=(cond, 0), @@ -1209,8 +1504,24 @@ def test_replace_aten_where_with_cadence(self) -> None: ) builder.output([aten_where_self]) original_gm = builder.get_graph_module() + + # Deepcopy before the pass + gm_before = copy.deepcopy(original_gm) + p = ReplaceWhereWithFullArgsWithWhereScalar() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = [cond_input] + validate( + gm_before, + graph_after_passes, + inputs, + "ReplaceWhereWithFullArgsWithWhereScalar", + ) + self.assertEqual( count_node( graph_after_passes, @@ -1238,9 +1549,9 @@ def test_replace_aten_where_with_cadence_broadcast( val1: float, val2: float, ) -> None: - # cond_shape, a_shape, b_shape, val1, val2 = builder = GraphBuilder() - cond = builder.placeholder("cond", torch.randn(cond_shape)) + cond_input = torch.randn(cond_shape) + cond = builder.placeholder("cond", cond_input) aten_gt_scalar = builder.call_operator( op=exir_ops.edge.aten.gt.Scalar, args=(cond, 0), @@ -1259,41 +1570,37 @@ def test_replace_aten_where_with_cadence_broadcast( ) builder.output([aten_where_self]) original_gm = builder.get_graph_module() - p = ReplaceWhereWithFullArgsWithWhereScalar() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module - self.assertEqual( - count_node( - graph_after_passes, - exir_ops.edge.aten.where.self, - ), - 1, - ) - def test_no_replace_aten_gelu_with_approximate_gelu(self) -> None: - inputs = torch.randn(2, 1, 64) + # Deepcopy before the pass + gm_before = copy.deepcopy(original_gm) - gm = single_op_builder( - placeholders=(inputs,), - op=exir_ops.edge.aten.gelu.default, - args=(inputs,), - ) - gm = ExportPass().call(gm).graph_module + p = ReplaceWhereWithFullArgsWithWhereScalar() + result = cast(PassResult, p(original_gm)) + # Broadcast case should not be replaced + self.assertFalse(result.modified) + graph_after_passes = result.graph_module - p = ReplaceAtenApproxGeluWithApproxGeluPass() - graph_after_passes = p.call(gm).graph_module + # Validate numerical accuracy (should be same since not modified) + inputs = [cond_input] + validate( + gm_before, + graph_after_passes, + inputs, + "ReplaceWhereWithFullArgsWithWhereScalar", + ) - # Assert that aten.gelu op was not decomposed, since it didn't have an approximate argument self.assertEqual( count_node( graph_after_passes, - exir_ops.edge.aten.gelu.default, + exir_ops.edge.aten.where.self, ), 1, ) def test_replace_split_with_sizes_with_slice(self) -> None: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(1, 16, 8, 4)) + x_input = torch.randn(1, 16, 8, 4) + x = builder.placeholder("x", x_input) split = builder.call_operator( exir_ops.edge.aten.split_with_sizes_copy.default, (x, [8, 8], 1) ) @@ -1303,8 +1610,18 @@ def test_replace_split_with_sizes_with_slice(self) -> None: builder.output([out0, out1]) graph_module = builder.get_graph_module() + gm_before = copy.deepcopy(graph_module) p = ReplaceSplitWithSlicePass() - graph_after_passes = cast(PassResult, 
p(graph_module)).graph_module + result = cast(PassResult, p(graph_module)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + validate( + gm_before, + graph_after_passes, + [x_input], + "ReplaceSplitWithSlicePass", + ) self.assertEqual( count_node( @@ -1319,14 +1636,22 @@ def test_replace_split_with_sizes_with_slice(self) -> None: @expand([[2], [3], [4]]) def test_replace_pow_with_mul(self, exponent: int) -> None: - x = torch.randn(2, 1, 64) + x_input = torch.randn(2, 1, 64) + x = x_input original_gm = single_op_builder( placeholders=(x,), op=exir_ops.edge.aten.pow.Tensor_Scalar, args=(x, exponent), ) + + gm_before = copy.deepcopy(original_gm) p = ReplacePowWithMulPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + validate(gm_before, graph_after_passes, [x_input], "ReplacePowWithMulPass") + self.assertEqual( count_node( graph_after_passes, @@ -1379,7 +1704,7 @@ class TestReplaceIm2rowWithViewPass(unittest.TestCase): def test_no_replacement_for_conv(self) -> None: # Create a graph with a single im2row node. x = torch.randn(1, 3, 224, 224) - pad_value = torch.randn(1) + pad_value = torch.tensor(0, dtype=torch.int32) channels_last = False gm = single_op_builder( placeholders=(x, pad_value), @@ -1391,9 +1716,19 @@ def test_no_replacement_for_conv(self) -> None: self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1) self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 0) + # Deepcopy before the pass + gm_before = copy.deepcopy(gm) + # Apply replacement pass. p = ReplaceIm2RowWithViewPass() - gm_after_replacement = p.call(gm).graph_module + result = p.call(gm) + self.assertFalse(result.modified) + gm_after_replacement = result.graph_module + + # Validate numerical accuracy + inputs = [x, pad_value] + validate(gm_before, gm_after_replacement, inputs, "ReplaceIm2RowWithViewPass") + # Check that no replacement was made. self.assertEqual( count_node(gm_after_replacement, exir_ops.edge.cadence.im2row.default), 1 @@ -1405,7 +1740,7 @@ def test_no_replacement_for_conv(self) -> None: def test_no_replace_for_dilation(self) -> None: # Create a graph with a single im2row node. x = torch.randn(1, 3, 5, 7) - pad_value = torch.randn(1) + pad_value = torch.tensor(0, dtype=torch.int32) channels_last = False gm = single_op_builder( placeholders=(x, pad_value), @@ -1417,9 +1752,19 @@ def test_no_replace_for_dilation(self) -> None: self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1) self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 0) + # Deepcopy before the pass + gm_before = copy.deepcopy(gm) + # Apply replacement pass. p = ReplaceIm2RowWithViewPass() - gm_after_replacement = p.call(gm).graph_module + result = p.call(gm) + self.assertFalse(result.modified) + gm_after_replacement = result.graph_module + + # Validate numerical accuracy + inputs = [x, pad_value] + validate(gm_before, gm_after_replacement, inputs, "ReplaceIm2RowWithViewPass") + self.assertEqual( count_node(gm_after_replacement, exir_ops.edge.cadence.im2row.default), 1 ) @@ -1431,7 +1776,7 @@ def test_replace_linear_like_conv(self) -> None: # Create a graph with a single im2row node. 
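# (Context for the test below: with no padding or dilation and a kernel that spans the
#  full input height/width, im2row extracts a single patch per batch element, so the op
#  reduces to a view/reshape of its input; that is the case ReplaceIm2RowWithViewPass
#  is expected to rewrite here.)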
in_h, in_w = 13, 15 x = torch.randn(1, 3, in_h, in_w) - pad_value = torch.randn(1) + pad_value = torch.tensor(0, dtype=torch.int32) channels_last = False gm = single_op_builder( placeholders=(x, pad_value), @@ -1443,9 +1788,19 @@ def test_replace_linear_like_conv(self) -> None: self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1) self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 0) + # Deepcopy before the pass + gm_before = copy.deepcopy(gm) + # Apply replacement pass. p = ReplaceIm2RowWithViewPass() - gm_after_replacement = p.call(gm).graph_module + result = p.call(gm) + self.assertTrue(result.modified) + gm_after_replacement = result.graph_module + + # Validate numerical accuracy + inputs = [x, pad_value] + validate(gm_before, gm_after_replacement, inputs, "ReplaceIm2RowWithViewPass") + # In this test, the kernel width/height is the same as the input width/height. self.assertEqual( count_node(gm_after_replacement, exir_ops.edge.cadence.im2row.default), 0 @@ -1477,62 +1832,10 @@ def create_conv1d_graphmodule( args = args + (channels_last,) return single_op_builder( placeholders=(x, w, b), - op=exir_ops.edge.cadence.convolution.default, + op=exir_ops.edge.cadence.conv1d.default, args=args, ) - def test_conv1d_default_channel_last(self) -> None: - # Create a graph with a single convolution node. - # Check if graph module is valid by running exportpass on it. - gm = self.create_conv1d_graphmodule() - gm = ExportPass().call(gm).graph_module - self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1) - self.assertEqual(count_node(gm, exir_ops.edge.aten.transpose_copy.int), 0) - - # Apply replacement pass. - p = ReplaceConvWithChannelLastConvPass() - gm_after_replacement = p.call(gm).graph_module - # Check that no replacement was made. - self.assertEqual( - count_node(gm_after_replacement, exir_ops.edge.cadence.convolution.default), - 1, - ) - self.assertEqual( - count_node(gm_after_replacement, exir_ops.edge.aten.transpose_copy.int), - # Two transposes are added, one for the input and one for the output. - 3, - ) - for node in gm_after_replacement.graph.nodes: - if node.target != exir_ops.edge.cadence.convolution.default: - continue - # Check that the channel_last argument is set to True. - self.assertEqual(len(node.args), 8, f"{node=}") - self.assertTrue(node.args[7]) - - def test_conv1d_no_transpose_if_already_channel_last(self) -> None: - gm = self.create_conv1d_graphmodule(channels_last=True) - gm = ExportPass().call(gm).graph_module - self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1) - - # Apply replacement pass. - p = ReplaceConvWithChannelLastConvPass() - gm_after_replacement = p.call(gm).graph_module - # Check that no replacement was made. - self.assertEqual( - count_node(gm_after_replacement, exir_ops.edge.cadence.convolution.default), - 1, - ) - self.assertEqual( - count_node(gm_after_replacement, exir_ops.edge.aten.transpose_copy.int), - 0, - ) - for node in gm_after_replacement.graph.nodes: - if node.target != exir_ops.edge.cadence.convolution.default: - continue - # Check that the channel_last argument is set to True. 
- self.assertEqual(len(node.args), 8, f"{node=}") - self.assertTrue(node.args[7]) - def create_convolution_graph_module( self, channels_last: Optional[bool] = None ) -> torch.fx.GraphModule: @@ -1554,91 +1857,39 @@ def create_convolution_graph_module( args = args + (channels_last,) return single_op_builder( placeholders=(x, w, b), - op=exir_ops.edge.cadence.convolution.default, + op=exir_ops.edge.cadence.conv2d.default, args=args, ) - def test_convolution_default_channel_last(self) -> None: - # Create a graph with a single convolution node. - # Check if graph module is valid by running exportpass on it. - gm = self.create_convolution_graph_module() - gm = ExportPass().call(gm).graph_module - self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1) - self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0) - - # Apply replacement pass. - p = ReplaceConvWithChannelLastConvPass() - gm_after_replacement = p.call(gm).graph_module - # Check that no replacement was made. - self.assertEqual( - count_node(gm_after_replacement, exir_ops.edge.cadence.convolution.default), - 1, - ) - self.assertEqual( - count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default), - # Three permutes are added, two for the input/weights and one for the output. - 3, - ) - for node in gm_after_replacement.graph.nodes: - if node.target != exir_ops.edge.cadence.convolution.default: - continue - # Check that the channel_last argument is set to True. - self.assertEqual(len(node.args), 8, f"{node=}") - self.assertTrue(node.args[7]) - - def test_no_transpose_if_already_channel_last(self) -> None: - gm = self.create_convolution_graph_module(channels_last=True) - gm = ExportPass().call(gm).graph_module - self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1) - - # Apply replacement pass. - p = ReplaceConvWithChannelLastConvPass() - gm_after_replacement = p.call(gm).graph_module - # Check that no replacement was made. - self.assertEqual( - count_node(gm_after_replacement, exir_ops.edge.cadence.convolution.default), - 1, - ) - self.assertEqual( - count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default), - 0, - ) - for node in gm_after_replacement.graph.nodes: - if node.target != exir_ops.edge.cadence.convolution.default: - continue - # Check that the channel_last argument is set to True. - self.assertEqual(len(node.args), 8, f"{node=}") - self.assertTrue(node.args[7]) - def create_quantized_convolution_graph_module( self, channels_last: Optional[bool] = None - ) -> torch.fx.GraphModule: + ) -> tuple[tuple[torch.Tensor, ...], torch.fx.GraphModule]: """Helper to create a quantized conv node. 
- quantized_conv( + quantized_conv_per_tensor( Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, - int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, - Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, - Tensor out_shift, bool channel_last=False) -> (Tensor Z)" + int[] dilation, int groups, int input_zero_point, int weight_zero_point, + Tensor bias_scale, float out_scale, int out_zero_point, int out_multiplier, + int out_shift, bool channel_last=False) -> (Tensor Z)" """ if channels_last: - x = torch.randn(1, 224, 56, 3) - w = torch.randn(16, 16, 16, 3) + x = torch.randint(0, 100, (1, 224, 56, 3), dtype=torch.int32) + w = torch.randint(0, 100, (16, 16, 16, 3), dtype=torch.int32) else: - x = torch.randn(1, 3, 224, 56) - w = torch.randn(16, 3, 16, 16) + x = torch.randint(0, 100, (1, 3, 224, 56), dtype=torch.int32) + w = torch.randint(0, 100, (16, 3, 16, 16), dtype=torch.int32) b = torch.randn(16) stride = (2, 2) padding = (0, 0) dilation = (1, 1) groups = 1 input_zero_point = 0 - w_zero_point = torch.randn(1) - b_scale = torch.randn(1) + w_zero_point = 100 + b_scale = 10 out_scale = 1 out_zero_point = 0 - out_multiplier = torch.randn(1) - out_shift = torch.randn(1) + out_multiplier = 5 + out_shift = 5 args = ( x, w, @@ -1656,49 +1907,38 @@ def create_quantized_convolution_graph_module( out_shift, ) if channels_last is not None: - return single_op_builder( - placeholders=( - x, - w, - b, - w_zero_point, - b_scale, - out_multiplier, - out_shift, - ), - op=exir_ops.edge.cadence.quantized_conv_nhwc.default, - args=args, - ) + op = exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor else: - return single_op_builder( - placeholders=( - x, - w, - b, - w_zero_point, - b_scale, - out_multiplier, - out_shift, - ), - op=exir_ops.edge.cadence.quantized_conv_nchw.default, - args=args, - ) + op = exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor + + placeholders = (x, w, b) + + return placeholders, single_op_builder( + placeholders=placeholders, + op=op, + args=args, + ) def test_quantized_convolution_default_channel_last(self) -> None: # Create a graph with a single convolution node. - gm = self.create_quantized_convolution_graph_module() + placeholders, gm = self.create_quantized_convolution_graph_module() self.assertEqual( - count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.default), 1 + count_node(gm, exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor), 1 ) self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0) # Apply replacement pass. p = ReplaceConvWithChannelLastConvPass() - gm_after_replacement = p.call(gm).graph_module - # Check that no replacement was made. + original = copy.deepcopy(gm) + result = p.call(gm) + self.assertTrue(result.modified) + gm_after_replacement = result.graph_module + + # Check that replacement was made. self.assertEqual( count_node( - gm_after_replacement, exir_ops.edge.cadence.quantized_conv_nhwc.default + gm_after_replacement, + exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor, ), 1, ) @@ -1708,13 +1948,24 @@ def test_quantized_convolution_default_channel_last(self) -> None: 3, ) + # Validate numerical accuracy + validate( + original, + gm_after_replacement, + placeholders, + "ReplaceConvWithChannelLastConvPass", + ) + def test_no_transpose_if_already_quantized_conv_channel_last(self) -> None: # Create a graph with a single im2row node. 
- gm = self.create_quantized_convolution_graph_module(channels_last=True) + placeholders, gm = self.create_quantized_convolution_graph_module( + channels_last=True + ) # Check if graph module is valid by running exportpass on it. + original = copy.deepcopy(gm) gm = ExportPass().call(gm).graph_module self.assertEqual( - count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.default), 1 + count_node(gm, exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor), 1 ) # Apply replacement pass. @@ -1723,11 +1974,18 @@ def test_no_transpose_if_already_quantized_conv_channel_last(self) -> None: # Check that no replacement was made. self.assertEqual( count_node( - gm_after_replacement, exir_ops.edge.cadence.quantized_conv_nhwc.default + gm_after_replacement, + exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor, ), 1, ) self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0) + validate( + gm_after_replacement, + original, + placeholders, + "ReplaceConvWithChannelLastConvPass", + ) class TestMakeSliceAndCatDimOutermostPass(unittest.TestCase): @@ -1747,14 +2005,23 @@ def create_slice_graph( def test_slice_no_transpose_if_already_outermost(self) -> None: # Create a graph with a single slice node. + x = torch.randn(3, 224, 224) gm = self.create_slice_graph((3, 224, 224), 0, 1, 2) # Check if graph module is valid by running exportpass on it. gm = ExportPass().call(gm).graph_module self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 1) + # Deepcopy before the pass + original = copy.deepcopy(gm) + # Apply replacement pass. p = MakeSliceAndCatDimOutermostPass() - gm_after_pass = cast(PassResult, p(gm)).graph_module + result = cast(PassResult, p(gm)) + self.assertFalse(result.modified) + gm_after_pass = result.graph_module + + # Validate numerical accuracy + validate(original, gm_after_pass, [x], "MakeSliceAndCatDimOutermostPass") # Assert that no transpose ops were added. self.assertEqual( @@ -1764,14 +2031,23 @@ def test_slice_no_transpose_if_already_outermost(self) -> None: def test_slice_no_transpose_if_outermost_dimensions_are_one(self) -> None: # Create a graph with a single slice node on second outermost dimension. + x = torch.randn(1, 3, 4, 6) gm = self.create_slice_graph((1, 3, 4, 6), 1, 1, 2) # Check if graph module is valid by running exportpass on it. gm = ExportPass().call(gm).graph_module self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 1) + # Deepcopy before the pass + original = copy.deepcopy(gm) + # Apply replacement pass. p = MakeSliceAndCatDimOutermostPass() - gm_after_pass = cast(PassResult, p(gm)).graph_module + result = cast(PassResult, p(gm)) + self.assertFalse(result.modified) + gm_after_pass = result.graph_module + + # Validate numerical accuracy + validate(original, gm_after_pass, [x], "MakeSliceAndCatDimOutermostPass") # Assert that no transpose ops were added. The slice is on the second # outermost dimension, but the outermost dimension is already 1. @@ -1782,14 +2058,23 @@ def test_slice_no_transpose_if_outermost_dimensions_are_one(self) -> None: def test_slice_insert_transpose(self) -> None: # Create a graph with a single slice node. + x = torch.randn(1, 3, 4, 6) gm = self.create_slice_graph((1, 3, 4, 6), 2, 1, 2) # Check if graph module is valid by running exportpass on it. gm = ExportPass().call(gm).graph_module self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 1) + # Deepcopy before the pass + original = copy.deepcopy(gm) + # Apply replacement pass. 
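# NOTE: `validate` is a shared test helper whose definition is not shown here. Judging
# from its call sites, `p(gm)` / `p.call(gm)` return a PassResult whose `.modified`
# flag and `.graph_module` field the tests check, and
# `validate(gm_before, gm_after, inputs, pass_name)` presumably runs both graph
# modules on the same inputs and asserts their outputs match. A minimal illustrative
# sketch, assuming list-style graph outputs and default tolerances, could look like:
#
#     def validate(gm_before, gm_after, inputs, pass_name: str) -> None:
#         expected = gm_before(*inputs)
#         actual = gm_after(*inputs)
#         for exp, act in zip(expected, actual):
#             torch.testing.assert_close(
#                 act, exp, msg=f"{pass_name} changed numerical results"
#             )
#
# The real helper's signature and tolerances may differ; only its usage appears in
# this patch.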
p = MakeSliceAndCatDimOutermostPass() - gm_after_pass = cast(PassResult, p(gm)).graph_module + result = cast(PassResult, p(gm)) + self.assertTrue(result.modified) + gm_after_pass = result.graph_module + + # Validate numerical accuracy + validate(original, gm_after_pass, [x], "MakeSliceAndCatDimOutermostPass") # Assert that there are two transpose ops added. self.assertEqual( @@ -1811,14 +2096,26 @@ def create_cat_graph( def test_cat_no_transpose_if_already_outermost(self) -> None: # Create a graph with a single slice node on second outermost dimension. + input1 = torch.randn(1, 3, 5) + input2 = torch.randn(2, 3, 5) gm = self.create_cat_graph(input_shapes=((1, 3, 5), (2, 3, 5)), cat_dim=0) # Check if graph module is valid by running exportpass on it. gm = ExportPass().call(gm).graph_module self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1) + # Deepcopy before the pass + original = copy.deepcopy(gm) + # Apply replacement pass. p = MakeSliceAndCatDimOutermostPass() - gm_after_pass = cast(PassResult, p(gm)).graph_module + result = cast(PassResult, p(gm)) + self.assertFalse(result.modified) + gm_after_pass = result.graph_module + + # Validate numerical accuracy + validate( + original, gm_after_pass, [input1, input2], "MakeSliceAndCatDimOutermostPass" + ) # Assert that no transpose ops were added. The slice is on the second # outermost dimension, but the outermost dimension is already 1. @@ -1829,14 +2126,26 @@ def test_cat_no_transpose_if_already_outermost(self) -> None: def test_cat_no_transpose_if_outermost_dimensions_are_one(self) -> None: # Create a graph with a single slice node on second outermost dimension. + input1 = torch.randn(1, 1, 3, 5) + input2 = torch.randn(1, 2, 3, 5) gm = self.create_cat_graph(input_shapes=((1, 1, 3, 5), (1, 2, 3, 5)), cat_dim=1) # Check if graph module is valid by running exportpass on it. gm = ExportPass().call(gm).graph_module self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1) + # Deepcopy before the pass + original = copy.deepcopy(gm) + # Apply replacement pass. p = MakeSliceAndCatDimOutermostPass() - gm_after_pass = cast(PassResult, p(gm)).graph_module + result = cast(PassResult, p(gm)) + self.assertFalse(result.modified) + gm_after_pass = result.graph_module + + # Validate numerical accuracy + validate( + original, gm_after_pass, [input1, input2], "MakeSliceAndCatDimOutermostPass" + ) # Assert that no transpose ops were added. The slice is on the second # outermost dimension, but the outermost dimension is already 1. @@ -1847,6 +2156,8 @@ def test_cat_no_transpose_if_outermost_dimensions_are_one(self) -> None: def test_cat_insert_transpose(self) -> None: # Create a graph with a single slice node on second outermost dimension. + input1 = torch.randn(1, 1, 3, 5) + input2 = torch.randn(1, 1, 3, 3) gm = self.create_cat_graph( input_shapes=((1, 1, 3, 5), (1, 1, 3, 3)), cat_dim=-1 ) @@ -1854,9 +2165,19 @@ def test_cat_insert_transpose(self) -> None: gm = ExportPass().call(gm).graph_module self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1) + # Deepcopy before the pass + original = copy.deepcopy(gm) + # Apply replacement pass. 
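# (Context: MakeSliceAndCatDimOutermostPass rewrites only when the slice/cat dimension
#  is not already effectively outermost. Here the cat is on the innermost dimension,
#  so transpose_copy ops are expected to be inserted around it, as asserted below.)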
p = MakeSliceAndCatDimOutermostPass() - gm_after_pass = cast(PassResult, p(gm)).graph_module + result = cast(PassResult, p(gm)) + self.assertTrue(result.modified) + gm_after_pass = result.graph_module + + # Validate numerical accuracy + validate( + original, gm_after_pass, [input1, input2], "MakeSliceAndCatDimOutermostPass" + ) # Assert that transpose ops were added to make cat on outermost dimension. self.assertEqual( @@ -1866,9 +2187,10 @@ def test_cat_insert_transpose(self) -> None: class TestReplaceEmptyTensorsWithFullPass(unittest.TestCase): - def _get_slice_empty_gm(self) -> torch.fx.GraphModule: + def _get_slice_empty_gm(self) -> tuple[torch.fx.GraphModule, torch.Tensor]: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(4)) + x_input = torch.randn(4) + x = builder.placeholder("x", x_input) # This is empty (numel == 0). slice0 = builder.call_operator( exir_ops.edge.aten.slice_copy.Tensor, (x, 0, 0, 0) @@ -1880,10 +2202,10 @@ def _get_slice_empty_gm(self) -> torch.fx.GraphModule: ((slice0, slice1),), ) builder.output([cat]) - return builder.get_graph_module() + return builder.get_graph_module(), x_input def test_empty_slice(self) -> None: - gm = self._get_slice_empty_gm() + gm, x_input = self._get_slice_empty_gm() self.assertEqual( len( gm.graph.find_nodes( @@ -1900,8 +2222,18 @@ def test_empty_slice(self) -> None: ), 0, ) - p = ReplaceEmptyTensorsWithFullPass() - updated_gm = cast(PassResult, p(gm)).graph_module + + # Deepcopy before the pass + gm_before = copy.deepcopy(gm) + + result = ReplaceEmptyTensorsWithFullPass().call(gm) + self.assertTrue(result.modified) + updated_gm = result.graph_module + + # Validate numerical accuracy + inputs = [x_input] + validate(gm_before, updated_gm, inputs, "ReplaceEmptyTensorsWithFullPass") + self.assertEqual( len( updated_gm.graph.find_nodes( @@ -1929,15 +2261,35 @@ def test_empty_slice(self) -> None: def test_extract_mul_argument_to_full( self, _: str, value: Union[int, float] ) -> None: - x = torch.randn(2, 1, 64) + if isinstance(value, int): + x_input = torch.randint(0, 100, (1,), dtype=torch.int32) + else: + x_input = torch.randn((1,), dtype=torch.float32) + gm = single_op_builder( - placeholders=(x,), + placeholders=(x_input,), op=torch.ops.aten.mul.Tensor, - args=(x, value), + args=(x_input, value), kwargs={}, ) + + # Deepcopy before the pass + gm_before = copy.deepcopy(gm) + p = ReplaceMulTensorWithMulAndFullOpsPass() - graph_after_passes = p.call(gm).graph_module + result = p.call(gm) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = [x_input] + validate( + gm_before, + graph_after_passes, + inputs, + "ReplaceMulTensorWithMulAndFullOpsPass", + ) + self.assertTrue( op_counts_match( graph_after_passes, @@ -1952,17 +2304,18 @@ def test_extract_mul_argument_to_full( class TestReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(unittest.TestCase): def _get_adaptive_avg_pool_gm( self, input_shape: Tuple[int, int, int, int], output_shape: Tuple[int, int] - ) -> torch.fx.GraphModule: + ) -> tuple[torch.Tensor, torch.fx.GraphModule]: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(*input_shape)) + x_input = torch.randn(*input_shape) + x = builder.placeholder("x", x_input) adaptive_avg_pool2d = builder.call_operator( exir_ops.edge.aten._adaptive_avg_pool2d.default, (x, output_shape) ) builder.output([adaptive_avg_pool2d]) - return builder.get_graph_module() + return x_input, builder.get_graph_module() def 
test_replace_adaptive_avg_pool_with_aten_avg_pool(self) -> None: - gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (8, 8)) + x_input, gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (8, 8)) self.assertEqual( len( gm.graph.find_nodes( @@ -1981,8 +2334,24 @@ def test_replace_adaptive_avg_pool_with_aten_avg_pool(self) -> None: ), 0, ) + + # Deepcopy before the pass + gm_before = copy.deepcopy(gm) + p = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass() - updated_gm = p.call(gm).graph_module + result = p.call(gm) + self.assertTrue(result.modified) + updated_gm = result.graph_module + + # Validate numerical accuracy + inputs = [x_input] + validate( + gm_before, + updated_gm, + inputs, + "ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass", + ) + self.assertEqual( len( updated_gm.graph.find_nodes( @@ -2009,7 +2378,7 @@ def test_replace_adaptive_avg_pool_with_aten_avg_pool(self) -> None: self.assertEqual(avg_pool2d_node.args[6], None) # divisor_override is None def test_replace_adaptive_avg_pool_with_aten_avg_pool_irregular(self) -> None: - gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (9, 9)) + x_input, gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (9, 9)) self.assertEqual( len( gm.graph.find_nodes( @@ -2027,9 +2396,25 @@ def test_replace_adaptive_avg_pool_with_aten_avg_pool_irregular(self) -> None: ), 0, ) + + # Deepcopy before the pass + gm_before = copy.deepcopy(gm) + # Shapes are not multiples of each other, so pass will not trigger p = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass() - updated_gm = p.call(gm).graph_module + result = p.call(gm) + self.assertFalse(result.modified) + updated_gm = result.graph_module + + # Validate numerical accuracy (should be same since not modified) + inputs = [x_input] + validate( + gm_before, + updated_gm, + inputs, + "ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass", + ) + self.assertEqual( len( updated_gm.graph.find_nodes( @@ -2048,6 +2433,113 @@ def test_replace_adaptive_avg_pool_with_aten_avg_pool_irregular(self) -> None: ) +class TestReplaceAtenAvgPoolWithCadenceAvgPoolPass(unittest.TestCase): + def _get_aten_avg_pool1d_gm( + self, input_shape: Tuple[int, int, int], kernel_size: int + ) -> tuple[torch.Tensor, torch.fx.GraphModule]: + builder = GraphBuilder() + x_input = torch.randn(*input_shape) + x = builder.placeholder("x", x_input) + avg_pool1d = builder.call_operator( + exir_ops.edge.aten.avg_pool1d.default, (x, [kernel_size]) + ) + builder.output([avg_pool1d]) + return x_input, builder.get_graph_module() + + def _get_aten_avg_pool2d_gm( + self, input_shape: Tuple[int, int, int, int], kernel_size: Tuple[int, int] + ) -> tuple[torch.Tensor, torch.fx.GraphModule]: + builder = GraphBuilder() + x_input = torch.randn(*input_shape) + x = builder.placeholder("x", x_input) + avg_pool2d = builder.call_operator( + exir_ops.edge.aten.avg_pool2d.default, (x, list(kernel_size)) + ) + builder.output([avg_pool2d]) + return x_input, builder.get_graph_module() + + def test_replace_aten_avg_pool1d_with_cadence(self) -> None: + x_input, gm = self._get_aten_avg_pool1d_gm((1, 32, 64), 3) + self.assertEqual( + count_node(gm, exir_ops.edge.aten.avg_pool1d.default), + 1, + ) + self.assertEqual( + count_node(gm, exir_ops.edge.cadence.avg_pool2d.default), + 0, + ) + + # Deepcopy before the pass + gm_before = copy.deepcopy(gm) + + p = ReplaceAtenAvgPoolWithCadenceAvgPoolPass() + result = p.call(gm) + self.assertTrue(result.modified) + updated_gm = result.graph_module + + # Validate numerical accuracy + inputs = [x_input] + validate( + gm_before, + updated_gm, + inputs, 
+ "ReplaceAtenAvgPoolWithCadenceAvgPoolPass", + ) + + # avg_pool1d should be replaced with view operations and avg_pool2d + self.assertEqual( + count_node(updated_gm, exir_ops.edge.aten.avg_pool1d.default), + 0, + ) + self.assertEqual( + count_node(updated_gm, exir_ops.edge.cadence.avg_pool2d.default), + 1, + ) + # Should have view operations for reshaping + self.assertGreater( + count_node(updated_gm, exir_ops.edge.aten.view_copy.default), + 0, + ) + + def test_replace_aten_avg_pool2d_with_cadence(self) -> None: + x_input, gm = self._get_aten_avg_pool2d_gm((1, 32, 64, 64), (3, 3)) + self.assertEqual( + count_node(gm, exir_ops.edge.aten.avg_pool2d.default), + 1, + ) + self.assertEqual( + count_node(gm, exir_ops.edge.cadence.avg_pool2d.default), + 0, + ) + + # Deepcopy before the pass + gm_before = copy.deepcopy(gm) + + p = ReplaceAtenAvgPoolWithCadenceAvgPoolPass() + result = p.call(gm) + self.assertTrue(result.modified) + updated_gm = result.graph_module + + # Validate numerical accuracy + inputs = [x_input] + validate( + gm_before, + updated_gm, + inputs, + "ReplaceAtenAvgPoolWithCadenceAvgPoolPass", + ) + + # avg_pool2d should be replaced with cadence avg_pool2d + self.assertEqual( + count_node(updated_gm, exir_ops.edge.aten.avg_pool2d.default), + 0, + ) + self.assertEqual( + count_node(updated_gm, exir_ops.edge.cadence.avg_pool2d.default), + 1, + ) + + class TestReplaceLinalgSvdPass(unittest.TestCase): @expand( [ @@ -2070,7 +2562,9 @@ def test_replace_aten_linalg_svd_with_cadence_linalg_svd( ) p = ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module # Assert that the aten linalg_svd op was replaced with cadence linalg_svd op self.assertEqual( @@ -2081,3 +2575,171 @@ def test_replace_aten_linalg_svd_with_cadence_linalg_svd( count_node(graph_after_passes, exir_ops.edge.cadence.linalg_svd.default), 1, ) + + @expand([("dtype",), ("default",)]) + @torch.no_grad() + def test_replace_quantized_embedding( + self, + name: str, + ) -> None: + embedding = torch.ones(5, 6, dtype=torch.int8) + indices = torch.tensor([0, 2], dtype=torch.int32) + scales = torch.ones(5, 2, dtype=torch.float32) + zero_points = None + + original_gm = single_op_builder( + placeholders=(embedding, scales, indices), + op=( + exir_ops.edge.quantized_decomposed.embedding_byte.dtype + if name == "dtype" + else exir_ops.edge.quantized_decomposed.embedding_byte.default + ), + args=(embedding, scales, zero_points, -128, 127, indices), + kwargs={"dtype": torch.float32} if name == "dtype" else {}, + ) + + gm_before = copy.deepcopy(original_gm) + p = ReplaceTorchQuantizedEmbeddingWithCadenceQuantizedEmbedding() + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = [embedding, scales, indices] + validate( + gm_before, + graph_after_passes, + inputs, + "ReplaceTorchQuantizedEmbeddingWithCadenceQuantizedEmbedding", + ) + + self.assertEqual( + count_node( + graph_after_passes, + ( + exir_ops.edge.quantized_decomposed.embedding_byte.dtype + if name == "dtype" + else exir_ops.edge.quantized_decomposed.embedding_byte.default + ), + ), + 0, + ) + + self.assertEqual( + count_node( + graph_after_passes, + exir_ops.edge.cadence.quantized_embedding_byte.default, + ), + 1, + ) + + +class 
TestReplaceLogicalNotBooleanWhereWithWherePass(unittest.TestCase): + """Tests for the ReplaceLogicalNotBooleanWhereWithWherePass.""" + + def test_replace_where_with_logical_not_boolean(self) -> None: + """Test that where(logical_not(bool_cond), x, y) is replaced with where(bool_cond, y, x).""" + # Setup: Create a graph with where(logical_not(bool_cond), x, y) + builder = GraphBuilder() + bool_cond_ = torch.randn(4, 8) > 0 + x_ = torch.randn(4, 8) + y_ = torch.randn(4, 8) + + bool_cond = builder.placeholder("bool_cond", bool_cond_) + x = builder.placeholder("x", x_) + y = builder.placeholder("y", y_) + + # Create logical_not node + logical_not = builder.call_operator( + op=exir_ops.edge.aten.logical_not.default, + args=(bool_cond,), + ) + + # Create where node using logical_not + where_node = builder.call_operator( + op=exir_ops.edge.aten.where.self, + args=(logical_not, x, y), + ) + builder.output([where_node]) + original_gm = builder.get_graph_module() + + # Make a copy of the original graph before applying the pass + original_gm_copy = copy.deepcopy(original_gm) + + # Execute: Apply the replacement pass + p = ReplaceLogicalNotBooleanWhereWithWherePass() + result = cast(PassResult, p(original_gm)) + + # Assert: Verify the pass modified the graph + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Assert: Verify logical_not is removed (dead code elimination) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.logical_not.default), + 0, + ) + + # Assert: Verify where node still exists + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.where.self), + 1, + ) + + # Assert: Verify the arguments are flipped (condition uses original bool_cond, x and y are swapped) + where_nodes = list( + graph_after_passes.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.where.self + ) + ) + for node in where_nodes: + # First arg should be the original bool_cond (not the logical_not) + self.assertEqual(node.args[0].name, "bool_cond") + # Second and third args should be swapped (y, x instead of x, y) + self.assertEqual(node.args[1].name, "y") + self.assertEqual(node.args[2].name, "x") + + # Assert: Verify outputs match exactly by running both graphs + validate( + original_gm_copy, + graph_after_passes, + (bool_cond_, x_, y_), + "ReplaceLogicalNotBooleanWhereWithWherePass", + ) + + def test_no_replacement_without_logical_not(self) -> None: + """Test that the pass does NOT apply when there's no logical_not.""" + # Setup: Create a graph with where(bool_cond, x, y) without logical_not + builder = GraphBuilder() + bool_cond = builder.placeholder("bool_cond", torch.randn(4, 8) > 0) + x = builder.placeholder("x", torch.randn(4, 8)) + y = builder.placeholder("y", torch.randn(4, 8)) + + # Create where node directly without logical_not + where_node = builder.call_operator( + op=exir_ops.edge.aten.where.self, + args=(bool_cond, x, y), + ) + builder.output([where_node]) + original_gm = builder.get_graph_module() + + # Execute: Apply the replacement pass + p = ReplaceLogicalNotBooleanWhereWithWherePass() + result = cast(PassResult, p(original_gm)) + + # Assert: Verify the pass did NOT modify the graph + self.assertFalse(result.modified) + graph_after_passes = result.graph_module + + # Assert: Verify where node still exists unchanged + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.where.self), + 1, + ) + + for node in graph_after_passes.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.where.self + 
): + self.assertEqual(node.args[0].name, "bool_cond") + self.assertEqual(node.args[1].name, "x") + self.assertEqual(node.args[2].name, "y") diff --git a/backends/cadence/aot/tests/test_type_dispatch_passes.py b/backends/cadence/aot/tests/test_type_dispatch_passes.py index 52904aecb41..870735aad1a 100644 --- a/backends/cadence/aot/tests/test_type_dispatch_passes.py +++ b/backends/cadence/aot/tests/test_type_dispatch_passes.py @@ -13,41 +13,36 @@ from executorch.backends.cadence.aot.graph_builder import single_op_builder from executorch.backends.cadence.aot.pass_utils import count_node from executorch.backends.cadence.aot.type_dispatch import CompileTimeTypeDispatchPass +from executorch.backends.cadence.aot.typing_stubs import expand from executorch.exir.dialects._ops import ops as exir_ops from torch.fx.passes.infra.pass_base import PassResult class TestTypeDispatchPasses(unittest.TestCase): - def test_int8_dispatch_quantized_fully_connected(self) -> None: - """Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant""" - x = torch.randint(-128, 127, (1, 3), dtype=torch.int8) - w = torch.randint(-128, 127, (4, 3), dtype=torch.int8) - b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32) - gm = single_op_builder( - placeholders=(x, w, b), - op=exir_ops.edge.cadence.quantized_fully_connected.per_tensor, - args=(x, w, b, 0, 0, 1, 0, 0, None), - ) - p = CompileTimeTypeDispatchPass() - gm = cast(PassResult, p(gm)).graph_module - # Original op should be replaced - self.assertEqual( - count_node(gm, exir_ops.edge.cadence.quantized_fully_connected.per_tensor), - 0, - ) - # Should be replaced with int8 specific variant - self.assertEqual( - count_node( - gm, + @expand( + [ + ( + "int8", + torch.int8, exir_ops.edge.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor, ), - 1, - ) - - def test_uint8_dispatch_quantized_fully_connected(self) -> None: - """Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant""" - x = torch.randint(0, 255, (1, 3), dtype=torch.uint8) - w = torch.randint(0, 255, (4, 3), dtype=torch.uint8) + ( + "uint8", + torch.uint8, + exir_ops.edge.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor, + ), + ] + ) + def test_dispatch_quantized_fully_connected( + self, + _: str, + dtype: torch.dtype, + expected_op: torch._ops.OpOverload, + ) -> None: + """Test quantized_fully_connected dispatches to correct dtype-specific variant""" + min_val, max_val = torch.iinfo(dtype).min, torch.iinfo(dtype).max + x = torch.randint(min_val, max_val, (1, 3), dtype=dtype) + w = torch.randint(min_val, max_val, (4, 3), dtype=dtype) b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32) gm = single_op_builder( placeholders=(x, w, b), @@ -61,45 +56,33 @@ def test_uint8_dispatch_quantized_fully_connected(self) -> None: count_node(gm, exir_ops.edge.cadence.quantized_fully_connected.per_tensor), 0, ) - # Should be replaced with uint8 specific variant - self.assertEqual( - count_node( - gm, - exir_ops.edge.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor, - ), - 1, - ) + # Should be replaced with dtype-specific variant + self.assertEqual(count_node(gm, expected_op), 1) - def test_int8_dispatch_quantized_linear(self) -> None: - """Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant for quantized_linear""" - x = torch.randint(-128, 127, (2, 3), dtype=torch.int8) - w = torch.randint(-128, 127, (4, 3), dtype=torch.int8) - b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32) - 
gm = single_op_builder( - placeholders=(x, w, b), - op=exir_ops.edge.cadence.quantized_linear.per_tensor, - args=(x, w, b, 0, 0, 1, 0, 0, None), - ) - p = CompileTimeTypeDispatchPass() - gm = cast(PassResult, p(gm)).graph_module - # Original op should be replaced - self.assertEqual( - count_node(gm, exir_ops.edge.cadence.quantized_linear.per_tensor), - 0, - ) - # Should be replaced with int8 specific variant - self.assertEqual( - count_node( - gm, + @expand( + [ + ( + "int8", + torch.int8, exir_ops.edge.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor, ), - 1, - ) - - def test_uint8_quantized_linear_dispatch(self) -> None: - """Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant for quantized_linear""" - x = torch.randint(0, 255, (2, 3), dtype=torch.uint8) - w = torch.randint(0, 255, (4, 3), dtype=torch.uint8) + ( + "uint8", + torch.uint8, + exir_ops.edge.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor, + ), + ] + ) + def test_dispatch_quantized_linear( + self, + _: str, + dtype: torch.dtype, + expected_op: torch._ops.OpOverload, + ) -> None: + """Test quantized_linear dispatches to correct dtype-specific variant""" + min_val, max_val = torch.iinfo(dtype).min, torch.iinfo(dtype).max + x = torch.randint(min_val, max_val, (2, 3), dtype=dtype) + w = torch.randint(min_val, max_val, (4, 3), dtype=dtype) b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32) gm = single_op_builder( placeholders=(x, w, b), @@ -113,14 +96,8 @@ def test_uint8_quantized_linear_dispatch(self) -> None: count_node(gm, exir_ops.edge.cadence.quantized_linear.per_tensor), 0, ) - # Should be replaced with uint8 specific variant - self.assertEqual( - count_node( - gm, - exir_ops.edge.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor, - ), - 1, - ) + # Should be replaced with dtype-specific variant + self.assertEqual(count_node(gm, expected_op), 1) def test_mixed_types_error(self) -> None: """Test mixed int8/uint8 inputs should raise RuntimeError""" @@ -138,33 +115,29 @@ def test_mixed_types_error(self) -> None: cast(PassResult, p(gm)).graph_module self.assertIn("Unsupported input types", str(context.exception)) - def test_int8_dispatch_quantized_relu(self) -> None: - """Test int8 input should dispatch to asym8s_asym8s variant for quantized_relu""" - x = torch.randint(-128, 127, (2, 3), dtype=torch.int8) - gm = single_op_builder( - placeholders=(x,), - op=exir_ops.edge.cadence.quantized_relu.per_tensor, - args=(x, 0, 0, 1, 0), - ) - p = CompileTimeTypeDispatchPass() - gm = cast(PassResult, p(gm)).graph_module - # Original op should be replaced - self.assertEqual( - count_node(gm, exir_ops.edge.cadence.quantized_relu.per_tensor), - 0, - ) - # Should be replaced with int8 specific variant - self.assertEqual( - count_node( - gm, + @expand( + [ + ( + "int8", + torch.int8, exir_ops.edge.cadence.quantized_relu_asym8s_asym8s.per_tensor, ), - 1, - ) - - def test_uint8_dispatch_quantized_relu(self) -> None: - """Test uint8 input should dispatch to asym8u_asym8u variant for quantized_relu""" - x = torch.randint(0, 255, (2, 3), dtype=torch.uint8) + ( + "uint8", + torch.uint8, + exir_ops.edge.cadence.quantized_relu_asym8u_asym8u.per_tensor, + ), + ] + ) + def test_dispatch_quantized_relu( + self, + _: str, + dtype: torch.dtype, + expected_op: torch._ops.OpOverload, + ) -> None: + """Test quantized_relu dispatches to correct dtype-specific variant""" + min_val, max_val = torch.iinfo(dtype).min, torch.iinfo(dtype).max + x = torch.randint(min_val, max_val, (2, 3), dtype=dtype) gm = 
single_op_builder( placeholders=(x,), op=exir_ops.edge.cadence.quantized_relu.per_tensor, @@ -177,45 +150,33 @@ def test_uint8_dispatch_quantized_relu(self) -> None: count_node(gm, exir_ops.edge.cadence.quantized_relu.per_tensor), 0, ) - # Should be replaced with uint8 specific variant - self.assertEqual( - count_node( - gm, - exir_ops.edge.cadence.quantized_relu_asym8u_asym8u.per_tensor, - ), - 1, - ) + # Should be replaced with dtype-specific variant + self.assertEqual(count_node(gm, expected_op), 1) - def test_int8_dispatch_quantized_matmul(self) -> None: - """Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant for quantized_matmul""" - x = torch.randint(-128, 127, (2, 3), dtype=torch.int8) - y = torch.randint(-128, 127, (3, 4), dtype=torch.int8) - bias = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32) - gm = single_op_builder( - placeholders=(x, y, bias), - op=exir_ops.edge.cadence.quantized_matmul.default, - args=(x, 0, y, 0, bias, 1, 0, 0, False), - ) - p = CompileTimeTypeDispatchPass() - gm = cast(PassResult, p(gm)).graph_module - # Original op should be replaced - self.assertEqual( - count_node(gm, exir_ops.edge.cadence.quantized_matmul.default), - 0, - ) - # Should be replaced with int8 specific variant - self.assertEqual( - count_node( - gm, + @expand( + [ + ( + "int8", + torch.int8, exir_ops.edge.cadence.quantized_matmul_asym8sxasym8s_asym8s.default, ), - 1, - ) - - def test_uint8_dispatch_quantized_matmul(self) -> None: - """Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant for quantized_matmul""" - x = torch.randint(0, 255, (2, 3), dtype=torch.uint8) - y = torch.randint(0, 255, (3, 4), dtype=torch.uint8) + ( + "uint8", + torch.uint8, + exir_ops.edge.cadence.quantized_matmul_asym8uxasym8u_asym8u.default, + ), + ] + ) + def test_dispatch_quantized_matmul( + self, + _: str, + dtype: torch.dtype, + expected_op: torch._ops.OpOverload, + ) -> None: + """Test quantized_matmul dispatches to correct dtype-specific variant""" + min_val, max_val = torch.iinfo(dtype).min, torch.iinfo(dtype).max + x = torch.randint(min_val, max_val, (2, 3), dtype=dtype) + y = torch.randint(min_val, max_val, (3, 4), dtype=dtype) bias = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32) gm = single_op_builder( placeholders=(x, y, bias), @@ -229,252 +190,204 @@ def test_uint8_dispatch_quantized_matmul(self) -> None: count_node(gm, exir_ops.edge.cadence.quantized_matmul.default), 0, ) - # Should be replaced with uint8 specific variant - self.assertEqual( - count_node( - gm, - exir_ops.edge.cadence.quantized_matmul_asym8uxasym8u_asym8u.default, - ), - 1, - ) + # Should be replaced with dtype-specific variant + self.assertEqual(count_node(gm, expected_op), 1) - def test_int8_dispatch_quantized_conv_nchw(self) -> None: - """Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant for quantized_conv_nchw""" - x = torch.randint(-128, 127, (1, 3, 8, 8), dtype=torch.int8) - w = torch.randint(-128, 127, (16, 3, 3, 3), dtype=torch.int8) - b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32) - gm = single_op_builder( - placeholders=(x, w, b), - op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor, - args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1), - ) - p = CompileTimeTypeDispatchPass() - gm = cast(PassResult, p(gm)).graph_module - # Original op should be replaced - self.assertEqual( - count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor), - 0, - ) - # Should be replaced with int8 
specific variant - self.assertEqual( - count_node( - gm, - exir_ops.edge.cadence.quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor, + @expand( + [ + ( + "int8_nchw", + torch.int8, + (1, 3, 8, 8), # x_shape + exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, + exir_ops.edge.cadence.quantized_conv2d_nchw_asym8sxsym8s_asym8s.per_tensor, ), - 1, - ) - - def test_uint8_dispatch_quantized_conv_nchw(self) -> None: - """Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant for quantized_conv_nchw""" - x = torch.randint(0, 255, (1, 3, 8, 8), dtype=torch.uint8) - w = torch.randint(0, 255, (16, 3, 3, 3), dtype=torch.uint8) - b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32) - gm = single_op_builder( - placeholders=(x, w, b), - op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor, - args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1), - ) - p = CompileTimeTypeDispatchPass() - gm = cast(PassResult, p(gm)).graph_module - # Original op should be replaced - self.assertEqual( - count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor), - 0, - ) - # Should be replaced with uint8 specific variant - self.assertEqual( - count_node( - gm, - exir_ops.edge.cadence.quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor, + ( + "uint8_nchw", + torch.uint8, + (1, 3, 8, 8), # x_shape + exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, + exir_ops.edge.cadence.quantized_conv2d_nchw_asym8uxsym8u_asym8u.per_tensor, ), - 1, - ) - - def test_int8_dispatch_quantized_conv_nhwc(self) -> None: - """Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant for quantized_conv_nhwc""" - x = torch.randint(-128, 127, (1, 8, 8, 3), dtype=torch.int8) - w = torch.randint(-128, 127, (16, 3, 3, 3), dtype=torch.int8) - b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32) - gm = single_op_builder( - placeholders=(x, w, b), - op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor, - args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1), - ) - p = CompileTimeTypeDispatchPass() - gm = cast(PassResult, p(gm)).graph_module - # Original op should be replaced - self.assertEqual( - count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor), - 0, - ) - # Should be replaced with int8 specific variant - self.assertEqual( - count_node( - gm, - exir_ops.edge.cadence.quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor, + ( + "int8_nhwc", + torch.int8, + (1, 8, 8, 3), # x_shape + exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor, + exir_ops.edge.cadence.quantized_conv2d_nhwc_asym8sxsym8s_asym8s.per_tensor, ), - 1, - ) - - def test_uint8_dispatch_quantized_conv_nhwc(self) -> None: - """Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant for quantized_conv_nhwc""" - x = torch.randint(0, 255, (1, 8, 8, 3), dtype=torch.uint8) - w = torch.randint(0, 255, (16, 3, 3, 3), dtype=torch.uint8) + ( + "uint8_nhwc", + torch.uint8, + (1, 8, 8, 3), # x_shape + exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor, + exir_ops.edge.cadence.quantized_conv2d_nhwc_asym8uxsym8u_asym8u.per_tensor, + ), + ] + ) + def test_dispatch_quantized_conv_2d( + self, + _: str, + dtype: torch.dtype, + x_shape: tuple[int, ...], + original_op: torch._ops.OpOverload, + expected_op: torch._ops.OpOverload, + ) -> None: + """Test quantized_conv_2d (nchw/nhwc) dispatches to correct dtype-specific variant""" + min_val, max_val = torch.iinfo(dtype).min, torch.iinfo(dtype).max + x = torch.randint(min_val, max_val, x_shape, dtype=dtype) + w = 
torch.randint(min_val, max_val, (16, 3, 3, 3), dtype=dtype) b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32) gm = single_op_builder( placeholders=(x, w, b), - op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor, + op=original_op, args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1), ) p = CompileTimeTypeDispatchPass() gm = cast(PassResult, p(gm)).graph_module # Original op should be replaced - self.assertEqual( - count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor), - 0, - ) - # Should be replaced with uint8 specific variant - self.assertEqual( - count_node( - gm, - exir_ops.edge.cadence.quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor, - ), - 1, - ) + self.assertEqual(count_node(gm, original_op), 0) + # Should be replaced with dtype-specific variant + self.assertEqual(count_node(gm, expected_op), 1) - def test_int8_dispatch_quantized_conv_nchw_dilated(self) -> None: - """Test int8 x int8 inputs with dilation should dispatch to dilated_asym8sxasym8s_asym8s variant for quantized_conv_nchw_dilated""" - x = torch.randint(-128, 127, (1, 3, 8, 8), dtype=torch.int8) - w = torch.randint(-128, 127, (16, 3, 3, 3), dtype=torch.int8) - b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32) - gm = single_op_builder( - placeholders=(x, w, b), - op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor, - args=(x, w, b, [1, 1], [0, 0], [2, 2], 1, 0, 0, 1.0, 1.0, 0, 1, 1), - ) - p = CompileTimeTypeDispatchPass() - gm = cast(PassResult, p(gm)).graph_module - # Original op should be replaced - self.assertEqual( - count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor), - 0, - ) - # Should be replaced with int8 specific variant - self.assertEqual( - count_node( - gm, - exir_ops.edge.cadence.quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor, + @expand( + [ + ( + "int8_nchw_dilated", + torch.int8, + (1, 3, 8, 8), # x_shape + exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, + exir_ops.edge.cadence.quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s.per_tensor, ), - 1, - ) - - def test_uint8_dispatch_quantized_conv_nchw_dilated(self) -> None: - """Test uint8 x uint8 inputs with dilation should dispatch to dilated_asym8uxasym8u_asym8u variant for quantized_conv_nchw""" - x = torch.randint(0, 255, (1, 3, 8, 8), dtype=torch.uint8) - w = torch.randint(0, 255, (16, 3, 3, 3), dtype=torch.uint8) - b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32) - gm = single_op_builder( - placeholders=(x, w, b), - op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor, - args=(x, w, b, [1, 1], [0, 0], [2, 2], 1, 0, 0, 1.0, 1.0, 0, 1, 1), - ) - p = CompileTimeTypeDispatchPass() - gm = cast(PassResult, p(gm)).graph_module - # Original op should be replaced - self.assertEqual( - count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor), - 0, - ) - # Should be replaced with uint8 specific variant - self.assertEqual( - count_node( - gm, - exir_ops.edge.cadence.quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor, + ( + "uint8_nchw_dilated", + torch.uint8, + (1, 3, 8, 8), # x_shape + exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, + exir_ops.edge.cadence.quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u.per_tensor, ), - 1, - ) - - def test_int8_dispatch_quantized_conv_nhwc_dilated(self) -> None: - """Test int8 x int8 inputs with dilation should dispatch to dilated_asym8sxasym8s_asym8s variant for quantized_conv_nhwc""" - x = torch.randint(-128, 127, (1, 8, 8, 3), dtype=torch.int8) - w = 
torch.randint(-128, 127, (16, 3, 3, 3), dtype=torch.int8) + ( + "int8_nhwc_dilated", + torch.int8, + (1, 8, 8, 3), # x_shape + exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor, + exir_ops.edge.cadence.quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor, + ), + ( + "uint8_nhwc_dilated", + torch.uint8, + (1, 8, 8, 3), # x_shape + exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor, + exir_ops.edge.cadence.quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor, + ), + ] + ) + def test_dispatch_quantized_conv_2d_dilated( + self, + _: str, + dtype: torch.dtype, + x_shape: tuple[int, ...], + original_op: torch._ops.OpOverload, + expected_op: torch._ops.OpOverload, + ) -> None: + """Test quantized_conv_2d with dilation dispatches to correct dtype-specific variant""" + min_val, max_val = torch.iinfo(dtype).min, torch.iinfo(dtype).max + x = torch.randint(min_val, max_val, x_shape, dtype=dtype) + w = torch.randint(min_val, max_val, (16, 3, 3, 3), dtype=dtype) b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32) gm = single_op_builder( placeholders=(x, w, b), - op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor, + op=original_op, args=(x, w, b, [1, 1], [0, 0], [2, 2], 1, 0, 0, 1.0, 1.0, 0, 1, 1), ) p = CompileTimeTypeDispatchPass() gm = cast(PassResult, p(gm)).graph_module # Original op should be replaced - self.assertEqual( - count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor), - 0, - ) - # Should be replaced with int8 specific variant - self.assertEqual( - count_node( - gm, - exir_ops.edge.cadence.quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor, - ), - 1, - ) + self.assertEqual(count_node(gm, original_op), 0) + # Should be replaced with dtype-specific variant + self.assertEqual(count_node(gm, expected_op), 1) - def test_uint8_dispatch_quantized_conv_nhwc_dilated(self) -> None: - """Test uint8 x uint8 inputs with dilation should dispatch to dilated_asym8uxasym8u_asym8u variant for quantized_conv_nhwc""" - x = torch.randint(0, 255, (1, 8, 8, 3), dtype=torch.uint8) - w = torch.randint(0, 255, (16, 3, 3, 3), dtype=torch.uint8) + @expand( + [ + ( + "int8_nchw_1d", + torch.int8, + (1, 3, 8), # x_shape + exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, + exir_ops.edge.cadence.quantized_conv1d_ncl_asym8sxsym8s_asym8s.per_tensor, + ), + ( + "uint8_nchw_1d", + torch.uint8, + (1, 3, 8), # x_shape + exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, + exir_ops.edge.cadence.quantized_conv1d_ncl_asym8uxsym8u_asym8u.per_tensor, + ), + ( + "int8_nhwc_1d", + torch.int8, + (1, 8, 3), # x_shape + exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor, + exir_ops.edge.cadence.quantized_conv1d_nlc_asym8sxsym8s_asym8s.per_tensor, + ), + ( + "uint8_nhwc_1d", + torch.uint8, + (1, 8, 3), # x_shape + exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor, + exir_ops.edge.cadence.quantized_conv1d_nlc_asym8uxsym8u_asym8u.per_tensor, + ), + ] + ) + def test_dispatch_quantized_conv_1d( + self, + _: str, + dtype: torch.dtype, + x_shape: tuple[int, ...], + original_op: torch._ops.OpOverload, + expected_op: torch._ops.OpOverload, + ) -> None: + """Test quantized_conv_1d (nchw/nhwc) dispatches to correct dtype-specific variant""" + min_val, max_val = torch.iinfo(dtype).min, torch.iinfo(dtype).max + x = torch.randint(min_val, max_val, x_shape, dtype=dtype) + w = torch.randint(min_val, max_val, (16, 3, 3), dtype=dtype) b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32) gm = single_op_builder( placeholders=(x, w, b), - 
op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor, - args=(x, w, b, [1, 1], [0, 0], [2, 2], 1, 0, 0, 1.0, 1.0, 0, 1, 1), + op=original_op, + args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1), ) p = CompileTimeTypeDispatchPass() gm = cast(PassResult, p(gm)).graph_module # Original op should be replaced - self.assertEqual( - count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor), - 0, - ) - # Should be replaced with uint8 specific variant - self.assertEqual( - count_node( - gm, - exir_ops.edge.cadence.quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor, - ), - 1, - ) + self.assertEqual(count_node(gm, original_op), 0) + # Should be replaced with dtype-specific variant + self.assertEqual(count_node(gm, expected_op), 1) - def test_int8_dispatch_quantized_add(self) -> None: - """Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant for quantized_add""" - x = torch.randint(-128, 127, (2, 3), dtype=torch.int8) - y = torch.randint(-128, 127, (2, 3), dtype=torch.int8) - gm = single_op_builder( - placeholders=(x, y), - op=exir_ops.edge.cadence.quantized_add.per_tensor, - args=(x, 1.0, 0, y, 1.0, 0, 1.0, 0), - ) - p = CompileTimeTypeDispatchPass() - gm = cast(PassResult, p(gm)).graph_module - # Original op should be replaced - self.assertEqual( - count_node(gm, exir_ops.edge.cadence.quantized_add.per_tensor), - 0, - ) - # Should be replaced with int8 specific variant - self.assertEqual( - count_node( - gm, + @expand( + [ + ( + "int8", + torch.int8, exir_ops.edge.cadence.quantized_add_asym8sxasym8s_asym8s.per_tensor, ), - 1, - ) - - def test_uint8_dispatch_quantized_add(self) -> None: - """Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant for quantized_add""" - x = torch.randint(0, 255, (2, 3), dtype=torch.uint8) - y = torch.randint(0, 255, (2, 3), dtype=torch.uint8) + ( + "uint8", + torch.uint8, + exir_ops.edge.cadence.quantized_add_asym8uxasym8u_asym8u.per_tensor, + ), + ] + ) + def test_dispatch_quantized_add( + self, + _: str, + dtype: torch.dtype, + expected_op: torch._ops.OpOverload, + ) -> None: + """Test quantized_add dispatches to correct dtype-specific variant""" + min_val, max_val = torch.iinfo(dtype).min, torch.iinfo(dtype).max + x = torch.randint(min_val, max_val, (2, 3), dtype=dtype) + y = torch.randint(min_val, max_val, (2, 3), dtype=dtype) gm = single_op_builder( placeholders=(x, y), op=exir_ops.edge.cadence.quantized_add.per_tensor, @@ -487,158 +400,62 @@ def test_uint8_dispatch_quantized_add(self) -> None: count_node(gm, exir_ops.edge.cadence.quantized_add.per_tensor), 0, ) - # Should be replaced with uint8 specific variant - self.assertEqual( - count_node( - gm, - exir_ops.edge.cadence.quantized_add_asym8uxasym8u_asym8u.per_tensor, - ), - 1, - ) + # Should be replaced with dtype-specific variant + self.assertEqual(count_node(gm, expected_op), 1) - def test_int8_dispatch_quantized_conv_nchw_depthwise(self) -> None: - """Test int8 x int8 inputs with depthwise should dispatch to depthwise_asym8sxsym8s_asym8s variant for quantized_conv_nchw""" - # Depthwise convolution: groups == input_channels - x = torch.randint(-128, 127, (1, 3, 8, 8), dtype=torch.int8) - w = torch.randint( - -128, 127, (3, 1, 3, 3), dtype=torch.int8 - ) # groups=3, input_channels=3 - b = torch.randint(-2147483648, 2147483647, (3,), dtype=torch.int32) - gm = single_op_builder( - placeholders=(x, w, b), - op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor, - args=( - x, - w, - b, - [1, 1], - [0, 0], - [1, 1], - 3, - 0, - 0, - 
1.0, - 1.0, - 0, - 1, - 1, - ), # groups=3 - ) - p = CompileTimeTypeDispatchPass() - gm = cast(PassResult, p(gm)).graph_module - # Original op should be replaced - self.assertEqual( - count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor), - 0, - ) - # Should be replaced with int8 depthwise specific variant - self.assertEqual( - count_node( - gm, - exir_ops.edge.cadence.quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor, + @expand( + [ + ( + "int8_nchw_depthwise", + torch.int8, + (1, 3, 8, 8), # x_shape + (3, 1, 3, 3), # w_shape (groups=3, input_channels=3) + exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, + exir_ops.edge.cadence.quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor, ), - 1, - ) - - def test_uint8_dispatch_quantized_conv_nchw_depthwise(self) -> None: - """Test uint8 x uint8 inputs with depthwise should dispatch to depthwise_asym8uxasym8u_asym8u variant for quantized_conv_nchw""" - # Depthwise convolution: groups == input_channels - x = torch.randint(0, 255, (1, 3, 8, 8), dtype=torch.uint8) - w = torch.randint( - 0, 255, (3, 1, 3, 3), dtype=torch.uint8 - ) # groups=3, input_channels=3 - b = torch.randint(-2147483648, 2147483647, (3,), dtype=torch.int32) - gm = single_op_builder( - placeholders=(x, w, b), - op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor, - args=( - x, - w, - b, - [1, 1], - [0, 0], - [1, 1], - 3, - 0, - 0, - 1.0, - 1.0, - 0, - 1, - 1, - ), # groups=3 - ) - p = CompileTimeTypeDispatchPass() - gm = cast(PassResult, p(gm)).graph_module - # Original op should be replaced - self.assertEqual( - count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor), - 0, - ) - # Should be replaced with uint8 depthwise specific variant - self.assertEqual( - count_node( - gm, - exir_ops.edge.cadence.quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor, + ( + "uint8_nchw_depthwise", + torch.uint8, + (1, 3, 8, 8), # x_shape + (3, 1, 3, 3), # w_shape (groups=3, input_channels=3) + exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, + exir_ops.edge.cadence.quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor, ), - 1, - ) - - def test_int8_dispatch_quantized_conv_nhwc_depthwise(self) -> None: - """Test int8 x int8 inputs with depthwise should dispatch to depthwise_asym8sxsym8s_asym8s variant for quantized_conv_nhwc""" - # Depthwise convolution: groups == input_channels - x = torch.randint(-128, 127, (1, 8, 8, 3), dtype=torch.int8) - w = torch.randint( - -128, 127, (3, 3, 3, 1), dtype=torch.int8 - ) # groups=3, input_channels=3 - b = torch.randint(-2147483648, 2147483647, (3,), dtype=torch.int32) - gm = single_op_builder( - placeholders=(x, w, b), - op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor, - args=( - x, - w, - b, - [1, 1], - [0, 0], - [1, 1], - 3, - 0, - 0, - 1.0, - 1.0, - 0, - 1, - 1, - ), # groups=3 - ) - p = CompileTimeTypeDispatchPass() - gm = cast(PassResult, p(gm)).graph_module - # Original op should be replaced - self.assertEqual( - count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor), - 0, - ) - # Should be replaced with int8 depthwise specific variant - self.assertEqual( - count_node( - gm, - exir_ops.edge.cadence.quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor, + ( + "int8_nhwc_depthwise", + torch.int8, + (1, 8, 8, 3), # x_shape + (3, 3, 3, 1), # w_shape (groups=3, input_channels=3) + exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor, + exir_ops.edge.cadence.quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor, ), - 1, - ) - - def 
test_uint8_dispatch_quantized_conv_nhwc_depthwise(self) -> None: - """Test uint8 x uint8 inputs with depthwise should dispatch to depthwise_asym8uxasym8u_asym8u variant for quantized_conv_nhwc""" - # Depthwise convolution: groups == input_channels - x = torch.randint(0, 255, (1, 8, 8, 3), dtype=torch.uint8) - w = torch.randint( - 0, 255, (3, 3, 3, 1), dtype=torch.uint8 - ) # groups=3, input_channels=3 + ( + "uint8_nhwc_depthwise", + torch.uint8, + (1, 8, 8, 3), # x_shape + (3, 3, 3, 1), # w_shape (groups=3, input_channels=3) + exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor, + exir_ops.edge.cadence.quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor, + ), + ] + ) + def test_dispatch_quantized_conv_depthwise( + self, + _: str, + dtype: torch.dtype, + x_shape: tuple[int, ...], + w_shape: tuple[int, ...], + original_op: torch._ops.OpOverload, + expected_op: torch._ops.OpOverload, + ) -> None: + """Test quantized_conv depthwise (groups == input_channels) dispatches to correct dtype-specific variant""" + min_val, max_val = torch.iinfo(dtype).min, torch.iinfo(dtype).max + x = torch.randint(min_val, max_val, x_shape, dtype=dtype) + w = torch.randint(min_val, max_val, w_shape, dtype=dtype) b = torch.randint(-2147483648, 2147483647, (3,), dtype=torch.int32) gm = single_op_builder( placeholders=(x, w, b), - op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor, + op=original_op, args=( x, w, @@ -654,20 +471,11 @@ def test_uint8_dispatch_quantized_conv_nhwc_depthwise(self) -> None: 0, 1, 1, - ), # groups=3 + ), ) p = CompileTimeTypeDispatchPass() gm = cast(PassResult, p(gm)).graph_module # Original op should be replaced - self.assertEqual( - count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor), - 0, - ) - # Should be replaced with uint8 depthwise specific variant - self.assertEqual( - count_node( - gm, - exir_ops.edge.cadence.quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor, - ), - 1, - ) + self.assertEqual(count_node(gm, original_op), 0) + # Should be replaced with dtype-specific variant + self.assertEqual(count_node(gm, expected_op), 1) diff --git a/backends/cadence/aot/type_dispatch.py b/backends/cadence/aot/type_dispatch.py index 108c4fb1a92..37f753767e9 100644 --- a/backends/cadence/aot/type_dispatch.py +++ b/backends/cadence/aot/type_dispatch.py @@ -27,6 +27,7 @@ class OpConfig: base_name: str type_dispatch_suffixes: dict[tuple[torch.dtype, ...], str] weight_arg_idx: Optional[int] = None + is_quant_op: bool = False variant: str = "per_tensor" @@ -62,16 +63,16 @@ class CompileTimeTypeDispatchPass(ExportPass): weight_arg_idx=2, variant="default", ), - exir_ops.edge.cadence.quantized_conv_nchw.per_tensor: OpConfig( - "quantized_conv_nchw", + exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor: OpConfig( + "quantized_conv2d_nchw", type_dispatch_suffixes={ (torch.int8, torch.int8): "asym8sxsym8s_asym8s", (torch.uint8, torch.uint8): "asym8uxsym8u_asym8u", }, weight_arg_idx=1, ), - exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor: OpConfig( - "quantized_conv_nhwc", + exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor: OpConfig( + "quantized_conv2d_nhwc", type_dispatch_suffixes={ (torch.int8, torch.int8): "asym8sxsym8s_asym8s", (torch.uint8, torch.uint8): "asym8uxsym8u_asym8u", @@ -100,6 +101,29 @@ class CompileTimeTypeDispatchPass(ExportPass): }, variant="default", ), + exir_ops.edge.cadence.quantize_per_tensor.default: OpConfig( + "quantize_per_tensor", + type_dispatch_suffixes={ + (torch.int8,): "asym8s", + (torch.uint8,): "asym8u", + 
(torch.int16,): "asym16s", + (torch.uint16,): "asym16s", + (torch.int32,): "asym32s", + }, + variant="default", + is_quant_op=True, + ), + exir_ops.edge.cadence.dequantize_per_tensor.default: OpConfig( + "dequantize_per_tensor", + type_dispatch_suffixes={ + (torch.int8,): "asym8s", + (torch.uint8,): "asym8u", + (torch.int16,): "asym16s", + (torch.uint16,): "asym16s", + (torch.int32,): "asym32s", + }, + variant="default", + ), } def call_operator( @@ -120,6 +144,8 @@ def call_operator( if config.weight_arg_idx is not None: weight_dtype = args[config.weight_arg_idx].to_tensor().dtype dtype_key = (input_dtype, weight_dtype) + elif config.is_quant_op: + dtype_key = (args[5],) else: dtype_key = (input_dtype,) @@ -129,28 +155,33 @@ def call_operator( type_suffix = config.type_dispatch_suffixes[dtype_key] base_name = config.base_name + typed_op_name = f"{base_name}_{type_suffix}" + if op in [ - exir_ops.edge.cadence.quantized_conv_nchw.per_tensor, - exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor, + exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, + exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor, ]: groups = args[6] input_channels = ( args[0].to_tensor().shape[1] - if op == exir_ops.edge.cadence.quantized_conv_nchw.per_tensor + if op == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor else args[0].to_tensor().shape[-1] ) is_depthwise = groups == input_channels - - dilation = args[5] # pyre-ignore[16]: None has no attribute '__iter__'. - is_dilated = any(d > 1 for d in dilation) - - if is_dilated: - type_suffix = f"dilated_{type_suffix}" - elif is_depthwise: - type_suffix = f"depthwise_{type_suffix}" - - typed_op_name = f"{base_name}_{type_suffix}" + is_dilated = any(d > 1 for d in args[5]) + is_1d = len(args[0].to_tensor().shape) == 3 + + if is_depthwise: + typed_op_name = f"{base_name}_depthwise_{type_suffix}" + elif is_dilated: + typed_op_name = f"{base_name}_dilated_{type_suffix}" + elif is_1d and groups == 1: + if "nchw" in base_name: + layout_suffix = "ncl" + else: + layout_suffix = "nlc" + typed_op_name = f"quantized_conv1d_{layout_suffix}_{type_suffix}" typed_op = getattr( getattr(exir_ops.edge.cadence, typed_op_name), config.variant diff --git a/backends/cadence/build_cadence_fusionG3.sh b/backends/cadence/build_cadence_fusionG3.sh index 93295bc9aa5..47a0f9ff9bb 100644 --- a/backends/cadence/build_cadence_fusionG3.sh +++ b/backends/cadence/build_cadence_fusionG3.sh @@ -9,7 +9,7 @@ set -euo pipefail unset CMAKE_PREFIX_PATH unset XTENSA_CORE -export XTENSA_CORE=FCV_FG3GP +export XTENSA_CORE=VANILLA_G3 git submodule sync git submodule update --init ./backends/cadence/install_requirements.sh @@ -33,6 +33,7 @@ if $STEPWISE_BUILD; then -DEXECUTORCH_USE_DL=OFF \ -DEXECUTORCH_BUILD_CADENCE=OFF \ -DHAVE_FNMATCH_H=OFF \ + -DFLATCC_ALLOW_WERROR=OFF \ -Bcmake-out . 
echo "Building any Cadence-specific binaries on top" @@ -51,6 +52,7 @@ if $STEPWISE_BUILD; then -DPYTHON_EXECUTABLE=python3 \ -DEXECUTORCH_FUSION_G3_OPT=ON \ -DHAVE_FNMATCH_H=OFF \ + -DFLATCC_ALLOW_WERROR=OFF \ -Bcmake-out/backends/cadence \ backends/cadence cmake --build cmake-out/backends/cadence -j8 @@ -76,6 +78,7 @@ else -DPYTHON_EXECUTABLE=python3 \ -DEXECUTORCH_FUSION_G3_OPT=ON \ -DHAVE_FNMATCH_H=OFF \ + -DFLATCC_ALLOW_WERROR=OFF \ -Bcmake-out cmake --build cmake-out --target install --config Release -j8 fi diff --git a/backends/cadence/build_cadence_hifi4.sh b/backends/cadence/build_cadence_hifi4.sh index 33078b7ff2f..22775af7082 100644 --- a/backends/cadence/build_cadence_hifi4.sh +++ b/backends/cadence/build_cadence_hifi4.sh @@ -9,7 +9,7 @@ set -euo pipefail unset CMAKE_PREFIX_PATH unset XTENSA_CORE -export XTENSA_CORE=nxp_rt600_RI23_11_newlib +export XTENSA_CORE=VANILLA_HIFI git submodule sync git submodule update --init ./backends/cadence/install_requirements.sh @@ -32,6 +32,7 @@ if $STEPWISE_BUILD; then -DEXECUTORCH_ENABLE_LOGGING=ON \ -DEXECUTORCH_USE_DL=OFF \ -DEXECUTORCH_BUILD_CADENCE=OFF \ + -DFLATCC_ALLOW_WERROR=OFF \ -Bcmake-out . echo "Building any Cadence-specific binaries on top" @@ -50,6 +51,7 @@ if $STEPWISE_BUILD; then -DPYTHON_EXECUTABLE=python3 \ -DEXECUTORCH_NNLIB_OPT=ON \ -DHAVE_FNMATCH_H=OFF \ + -DFLATCC_ALLOW_WERROR=OFF \ -Bcmake-out/backends/cadence \ backends/cadence cmake --build cmake-out/backends/cadence -j8 @@ -74,6 +76,7 @@ else -DPYTHON_EXECUTABLE=python3 \ -DEXECUTORCH_NNLIB_OPT=ON \ -DHAVE_FNMATCH_H=OFF \ + -DFLATCC_ALLOW_WERROR=OFF \ -Bcmake-out cmake --build cmake-out --target install --config Release -j8 fi diff --git a/backends/cadence/build_cadence_vision.sh b/backends/cadence/build_cadence_vision.sh new file mode 100755 index 00000000000..b3972db4f31 --- /dev/null +++ b/backends/cadence/build_cadence_vision.sh @@ -0,0 +1,86 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -euo pipefail + +unset CMAKE_PREFIX_PATH +unset XTENSA_CORE +export XTENSA_CORE=VANILLA_VISION +git submodule sync +git submodule update --init --recursive +./install_requirements.sh +./install_executorch.sh + +rm -rf cmake-out + +STEPWISE_BUILD=false + +if $STEPWISE_BUILD; then + echo "Building ExecuTorch" + CXXFLAGS="-fno-exceptions -fno-rtti" cmake -DCMAKE_INSTALL_PREFIX=cmake-out \ + -DCMAKE_TOOLCHAIN_FILE=./backends/cadence/cadence.cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DEXECUTORCH_ENABLE_EVENT_TRACER=OFF \ + -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ + -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=OFF \ + -DEXECUTORCH_BUILD_PTHREADPOOL=OFF \ + -DEXECUTORCH_BUILD_CPUINFO=OFF \ + -DEXECUTORCH_ENABLE_LOGGING=ON \ + -DEXECUTORCH_USE_DL=OFF \ + -DEXECUTORCH_BUILD_CADENCE=OFF \ + -DFLATCC_ALLOW_WERROR=OFF \ + -Bcmake-out . 
+ + echo "Building any Cadence-specific binaries on top" + CXXFLAGS="-fno-exceptions -fno-rtti" cmake -DBUCK2="$BUCK" \ + -DCMAKE_TOOLCHAIN_FILE=./backends/cadence/cadence.cmake \ + -DCMAKE_INSTALL_PREFIX=cmake-out \ + -DCMAKE_BUILD_TYPE=Release \ + -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \ + -DEXECUTORCH_BUILD_PTHREADPOOL=OFF \ + -DEXECUTORCH_BUILD_CADENCE=ON \ + -DEXECUTORCH_ENABLE_LOGGING=ON \ + -DEXECUTORCH_ENABLE_PROGRAM_VERIFICATION=ON \ + -DEXECUTORCH_USE_DL=OFF \ + -DEXECUTORCH_BUILD_PORTABLE_OPS=ON \ + -DEXECUTORCH_BUILD_KERNELS_LLM=OFF \ + -DPYTHON_EXECUTABLE=python3 \ + -DEXECUTORCH_VISION_OPT=ON \ + -DHAVE_FNMATCH_H=OFF \ + -DFLATCC_ALLOW_WERROR=OFF \ + -Bcmake-out/backends/cadence \ + backends/cadence + cmake --build cmake-out/backends/cadence -j8 +else + echo "Building Cadence toolchain with ExecuTorch packages" + cmake_prefix_path="${PWD}/cmake-out/lib/cmake/ExecuTorch;${PWD}/cmake-out/third-party/gflags" + CXXFLAGS="-fno-exceptions -fno-rtti" cmake -DBUCK2="$BUCK" \ + -DCMAKE_PREFIX_PATH="${cmake_prefix_path}" \ + -DCMAKE_TOOLCHAIN_FILE=./backends/cadence/cadence.cmake \ + -DCMAKE_INSTALL_PREFIX=cmake-out \ + -DCMAKE_BUILD_TYPE=Release \ + -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \ + -DEXECUTORCH_BUILD_PTHREADPOOL=OFF \ + -DEXECUTORCH_BUILD_CPUINFO=OFF \ + -DEXECUTORCH_BUILD_CADENCE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ + -DEXECUTORCH_ENABLE_LOGGING=ON \ + -DEXECUTORCH_ENABLE_PROGRAM_VERIFICATION=ON \ + -DEXECUTORCH_USE_DL=OFF \ + -DEXECUTORCH_BUILD_PORTABLE_OPS=ON \ + -DEXECUTORCH_BUILD_KERNELS_LLM=OFF \ + -DPYTHON_EXECUTABLE=python3 \ + -DEXECUTORCH_VISION_OPT=ON \ + -DHAVE_FNMATCH_H=OFF \ + -DFLATCC_ALLOW_WERROR=OFF \ + -Bcmake-out + cmake --build cmake-out --target install --config Release -j8 +fi + +echo "Run simple model to verify cmake build" +python3 -m examples.portable.scripts.export --model_name="add" +xt-run --turbo cmake-out/executor_runner --model_path=add.pte diff --git a/backends/cadence/cadence.cmake b/backends/cadence/cadence.cmake index a0e5ea86da1..1bd5a8db0e0 100644 --- a/backends/cadence/cadence.cmake +++ b/backends/cadence/cadence.cmake @@ -41,8 +41,12 @@ set(CMAKE_CROSSCOMPILING TRUE) set(CMAKE_C_COMPILER ${TOOLCHAIN_HOME}/bin/${CROSS_COMPILE_TARGET}-clang) set(CMAKE_CXX_COMPILER ${TOOLCHAIN_HOME}/bin/${CROSS_COMPILE_TARGET}-clang++) -set(CMAKE_C_FLAGS_INIT "-stdlib=libc++ -mtext-section-literals -mlongcalls") -set(CMAKE_CXX_FLAGS_INIT "-stdlib=libc++ -mtext-section-literals -mlongcalls") +set(CMAKE_C_FLAGS_INIT + "-stdlib=libc++ -mtext-section-literals -mlongcalls -DET_ENABLE_ENUM_STRINGS=0" +) +set(CMAKE_CXX_FLAGS_INIT + "-stdlib=libc++ -mtext-section-literals -mlongcalls -DET_ENABLE_ENUM_STRINGS=0" +) # workaround for larger compilation time set(CMAKE_CXX_FLAGS_INIT "${CMAKE_CXX_FLAGS_INIT} -fno-strict-aliasing") diff --git a/backends/cadence/common/xt_macros.h b/backends/cadence/common/xt_macros.h new file mode 100644 index 00000000000..b7a49c96e16 --- /dev/null +++ b/backends/cadence/common/xt_macros.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#define XT_KERNEL_CHECK(ctx, out, kernel, ...) \ + { \ + const auto ret = kernel(__VA_ARGS__); \ + ET_KERNEL_CHECK_MSG( \ + ctx, \ + ret == 0, \ + InvalidArgument, \ + out, \ + "Failed to run kernel: " #kernel "(" #__VA_ARGS__ \ + "). 
Returned code %d", \ + static_cast(ret)); \ + } diff --git a/backends/cadence/fusion_g3/operators/op_add.cpp b/backends/cadence/fusion_g3/operators/op_add.cpp index 409c4cc5104..7ef7bdde3b1 100644 --- a/backends/cadence/fusion_g3/operators/op_add.cpp +++ b/backends/cadence/fusion_g3/operators/op_add.cpp @@ -10,7 +10,7 @@ #include -#include +#include #include #include #include @@ -24,7 +24,6 @@ using ::executorch::runtime::canCast; using ::executorch::runtime::Error; using ::executorch::runtime::KernelRuntimeContext; -namespace cadence { namespace impl { namespace G3 { namespace native { @@ -163,7 +162,7 @@ Tensor& add_out( float alpha_val; torch::executor::native::utils::extract_scalar(alpha, &alpha_val); - if ((a.numel() == 1) && (alpha_val == 1.0)) { + if ((a.numel() == 1) && (alpha_val == 1.0f)) { XT_KERNEL_CHECK( ctx, out, @@ -368,4 +367,3 @@ Tensor& add_scalar_out( } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/op_cat.cpp b/backends/cadence/fusion_g3/operators/op_cat.cpp index 7f8e1ee8710..0c83ebaf0ad 100644 --- a/backends/cadence/fusion_g3/operators/op_cat.cpp +++ b/backends/cadence/fusion_g3/operators/op_cat.cpp @@ -13,7 +13,7 @@ #include -#include +#include #include #include @@ -29,7 +29,6 @@ using ::executorch::runtime::KernelRuntimeContext; * operator need to be updated accordingly */ -namespace cadence { namespace impl { namespace G3 { namespace native { @@ -170,4 +169,3 @@ Tensor& cat_out( } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/op_clamp.cpp b/backends/cadence/fusion_g3/operators/op_clamp.cpp index 92fb97b1260..8eed6b681c2 100644 --- a/backends/cadence/fusion_g3/operators/op_clamp.cpp +++ b/backends/cadence/fusion_g3/operators/op_clamp.cpp @@ -15,7 +15,7 @@ #include -#include +#include #include #include #include @@ -29,7 +29,6 @@ using ::executorch::runtime::Error; using ::executorch::runtime::KernelRuntimeContext; using std::optional; -namespace cadence { namespace impl { namespace G3 { namespace native { @@ -663,4 +662,3 @@ Tensor& clamp_Tensor_out( } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/op_dequantize.cpp b/backends/cadence/fusion_g3/operators/op_dequantize.cpp index c3fca3bb7d4..537e3f04ae0 100644 --- a/backends/cadence/fusion_g3/operators/op_dequantize.cpp +++ b/backends/cadence/fusion_g3/operators/op_dequantize.cpp @@ -14,7 +14,7 @@ #include -#include +#include #include #include @@ -36,7 +36,6 @@ enum datatype { Bits4u = 21, Bits4 = 22 }; /** * For an input tensor, use the scale and zero_point arguments to quantize it. 
*/ -namespace cadence { namespace impl { namespace G3 { namespace native { @@ -784,4 +783,3 @@ Tensor& dequantize_per_token_out( } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/op_div.cpp b/backends/cadence/fusion_g3/operators/op_div.cpp index a16e8ed02ba..62ebf303ebd 100644 --- a/backends/cadence/fusion_g3/operators/op_div.cpp +++ b/backends/cadence/fusion_g3/operators/op_div.cpp @@ -12,7 +12,7 @@ #include -#include +#include #include #include #include @@ -28,7 +28,6 @@ using ::executorch::runtime::KernelRuntimeContext; using std::optional; using std::string_view; -namespace cadence { namespace impl { namespace G3 { namespace native { @@ -686,4 +685,3 @@ Tensor& div_scalar_mode_out( } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/op_exp.cpp b/backends/cadence/fusion_g3/operators/op_exp.cpp index 84d2ac0b94e..51d53067668 100644 --- a/backends/cadence/fusion_g3/operators/op_exp.cpp +++ b/backends/cadence/fusion_g3/operators/op_exp.cpp @@ -12,7 +12,7 @@ #include -#include +#include #include #include @@ -21,7 +21,6 @@ using ::executorch::aten::Tensor; using ::executorch::runtime::Error; using ::executorch::runtime::KernelRuntimeContext; -namespace cadence { namespace impl { namespace G3 { namespace native { @@ -67,4 +66,3 @@ Tensor& exp_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/op_hardtanh.cpp b/backends/cadence/fusion_g3/operators/op_hardtanh.cpp index 09a2535c0dc..b930098fb24 100644 --- a/backends/cadence/fusion_g3/operators/op_hardtanh.cpp +++ b/backends/cadence/fusion_g3/operators/op_hardtanh.cpp @@ -11,7 +11,7 @@ #include -#include +#include #include #include #include @@ -25,7 +25,6 @@ using ::executorch::runtime::KernelRuntimeContext; using ::torch::executor::native::utils::extract_scalar; using ::torch::executor::native::utils::get_scalar_dtype; -namespace cadence { namespace impl { namespace G3 { namespace native { @@ -113,4 +112,3 @@ Tensor& hardtanh_out( } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/op_lt.cpp b/backends/cadence/fusion_g3/operators/op_lt.cpp index 08783860271..850552f1d3b 100644 --- a/backends/cadence/fusion_g3/operators/op_lt.cpp +++ b/backends/cadence/fusion_g3/operators/op_lt.cpp @@ -10,7 +10,7 @@ #include -#include +#include #include using ::executorch::aten::Scalar; @@ -19,7 +19,6 @@ using ::executorch::aten::Tensor; using ::executorch::runtime::Error; using ::executorch::runtime::KernelRuntimeContext; -namespace cadence { namespace impl { namespace G3 { namespace native { @@ -198,4 +197,3 @@ Tensor& lt_Scalar_out( } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/op_mean.cpp b/backends/cadence/fusion_g3/operators/op_mean.cpp index 85a8f482aac..cefd45f6ef8 100644 --- a/backends/cadence/fusion_g3/operators/op_mean.cpp +++ b/backends/cadence/fusion_g3/operators/op_mean.cpp @@ -10,7 +10,7 @@ #include -#include +#include #include #include #include @@ -23,7 +23,6 @@ using ::executorch::runtime::Error; using ::executorch::runtime::KernelRuntimeContext; using std::optional; -namespace cadence { namespace impl { namespace G3 { namespace native { @@ -192,4 +191,3 @@ Tensor& 
mean_out( } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/op_mul.cpp b/backends/cadence/fusion_g3/operators/op_mul.cpp index bee6ac9cbda..a4a230a374f 100644 --- a/backends/cadence/fusion_g3/operators/op_mul.cpp +++ b/backends/cadence/fusion_g3/operators/op_mul.cpp @@ -10,7 +10,7 @@ #include -#include +#include #include #include #include @@ -23,7 +23,6 @@ using ::executorch::runtime::canCast; using ::executorch::runtime::Error; using ::executorch::runtime::KernelRuntimeContext; -namespace cadence { namespace impl { namespace G3 { namespace native { @@ -314,4 +313,3 @@ Tensor& mul_scalar_out( } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp b/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp index 11c3edbb6a2..aa25cec9230 100644 --- a/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp +++ b/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp @@ -13,7 +13,7 @@ #include -#include +#include #include #include #include @@ -25,7 +25,6 @@ using ::executorch::runtime::Error; using ::executorch::runtime::KernelRuntimeContext; using std::optional; -namespace cadence { namespace impl { namespace G3 { namespace native { @@ -292,4 +291,3 @@ std::tuple native_layer_norm_out( } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/op_permute_copy.cpp b/backends/cadence/fusion_g3/operators/op_permute_copy.cpp index 204882f3da9..5b1d079f92e 100644 --- a/backends/cadence/fusion_g3/operators/op_permute_copy.cpp +++ b/backends/cadence/fusion_g3/operators/op_permute_copy.cpp @@ -11,7 +11,7 @@ #include -#include +#include #include #include @@ -29,7 +29,6 @@ using ::executorch::runtime::KernelRuntimeContext; * operator need to be updated accordingly */ -namespace cadence { namespace impl { namespace G3 { namespace native { @@ -157,4 +156,3 @@ Tensor& permute_copy_out( } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/op_quantize.cpp b/backends/cadence/fusion_g3/operators/op_quantize.cpp index 3ad399bca8b..26f90ddf5d1 100644 --- a/backends/cadence/fusion_g3/operators/op_quantize.cpp +++ b/backends/cadence/fusion_g3/operators/op_quantize.cpp @@ -14,7 +14,7 @@ #include -#include +#include #include #include @@ -33,7 +33,6 @@ enum datatype { Bits4u = 21, Bits4 = 22 }; /** * For an input tensor, use the scale and zero_point arguments to quantize it. 
*/ -namespace cadence { namespace impl { namespace G3 { namespace native { @@ -840,4 +839,3 @@ Tensor& quantize_per_token_out( } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/op_rsqrt.cpp b/backends/cadence/fusion_g3/operators/op_rsqrt.cpp index 59f9094aa29..a9017397687 100644 --- a/backends/cadence/fusion_g3/operators/op_rsqrt.cpp +++ b/backends/cadence/fusion_g3/operators/op_rsqrt.cpp @@ -10,7 +10,7 @@ #include -#include +#include #include #include #include @@ -20,7 +20,6 @@ using ::executorch::aten::Tensor; using ::executorch::runtime::Error; using ::executorch::runtime::KernelRuntimeContext; -namespace cadence { namespace impl { namespace G3 { namespace native { @@ -69,4 +68,3 @@ Tensor& rsqrt_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/op_sigmoid.cpp b/backends/cadence/fusion_g3/operators/op_sigmoid.cpp index 00149ab7e85..0ded70926eb 100644 --- a/backends/cadence/fusion_g3/operators/op_sigmoid.cpp +++ b/backends/cadence/fusion_g3/operators/op_sigmoid.cpp @@ -12,7 +12,7 @@ #include -#include +#include #include #include #include @@ -22,7 +22,6 @@ using ::executorch::aten::Tensor; using ::executorch::runtime::Error; using ::executorch::runtime::KernelRuntimeContext; -namespace cadence { namespace impl { namespace G3 { namespace native { @@ -95,4 +94,3 @@ Tensor& sigmoid_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence \ No newline at end of file diff --git a/backends/cadence/fusion_g3/operators/op_slice_copy.cpp b/backends/cadence/fusion_g3/operators/op_slice_copy.cpp index b7fd37fd1ee..504a00fcaee 100644 --- a/backends/cadence/fusion_g3/operators/op_slice_copy.cpp +++ b/backends/cadence/fusion_g3/operators/op_slice_copy.cpp @@ -13,7 +13,7 @@ #include -#include +#include #include #include @@ -28,7 +28,6 @@ using ::executorch::runtime::KernelRuntimeContext; * operator need to be updated accordingly */ -namespace cadence { namespace impl { namespace G3 { namespace native { @@ -124,7 +123,7 @@ Tensor& slice_copy_Tensor_out( InvalidArgument, out); - torch::executor::compute_slice(in, dim, start, length, step, out); + torch::executor::compute_slice(ctx, in, dim, start, length, step, out); } return out; @@ -133,4 +132,3 @@ Tensor& slice_copy_Tensor_out( } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/op_softmax.cpp b/backends/cadence/fusion_g3/operators/op_softmax.cpp index 14b128e9281..1faf41c94a8 100644 --- a/backends/cadence/fusion_g3/operators/op_softmax.cpp +++ b/backends/cadence/fusion_g3/operators/op_softmax.cpp @@ -12,7 +12,7 @@ #include -#include +#include #include #include #include @@ -24,7 +24,6 @@ using ::executorch::aten::Tensor; using ::executorch::runtime::Error; using ::executorch::runtime::KernelRuntimeContext; -namespace cadence { namespace impl { namespace G3 { namespace native { @@ -127,4 +126,3 @@ Tensor& _softmax_out( } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/op_sqrt.cpp b/backends/cadence/fusion_g3/operators/op_sqrt.cpp index 4b0de889a39..584d94d78a1 100644 --- a/backends/cadence/fusion_g3/operators/op_sqrt.cpp +++ 
b/backends/cadence/fusion_g3/operators/op_sqrt.cpp @@ -12,7 +12,7 @@ #include -#include +#include #include #include #include @@ -22,7 +22,6 @@ using ::executorch::aten::Tensor; using ::executorch::runtime::Error; using ::executorch::runtime::KernelRuntimeContext; -namespace cadence { namespace impl { namespace G3 { namespace native { @@ -63,4 +62,3 @@ Tensor& sqrt_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/op_sub.cpp b/backends/cadence/fusion_g3/operators/op_sub.cpp index 9f4c2c3a5c3..0b5bee9a651 100644 --- a/backends/cadence/fusion_g3/operators/op_sub.cpp +++ b/backends/cadence/fusion_g3/operators/op_sub.cpp @@ -10,7 +10,7 @@ #include -#include +#include #include #include #include @@ -23,7 +23,6 @@ using ::executorch::runtime::canCast; using ::executorch::runtime::Error; using ::executorch::runtime::KernelRuntimeContext; -namespace cadence { namespace impl { namespace G3 { namespace native { @@ -582,4 +581,3 @@ Tensor& sub_scalar_out( } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence \ No newline at end of file diff --git a/backends/cadence/fusion_g3/operators/op_tanh.cpp b/backends/cadence/fusion_g3/operators/op_tanh.cpp index 14a21066632..9686dc7caa9 100644 --- a/backends/cadence/fusion_g3/operators/op_tanh.cpp +++ b/backends/cadence/fusion_g3/operators/op_tanh.cpp @@ -12,7 +12,7 @@ #include -#include +#include #include #include #include @@ -22,7 +22,6 @@ using ::executorch::aten::Tensor; using ::executorch::runtime::Error; using ::executorch::runtime::KernelRuntimeContext; -namespace cadence { namespace impl { namespace G3 { namespace native { @@ -63,4 +62,3 @@ Tensor& tanh_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/op_transpose_copy.cpp b/backends/cadence/fusion_g3/operators/op_transpose_copy.cpp index 734fdcb2cd8..4bff24cbfe5 100644 --- a/backends/cadence/fusion_g3/operators/op_transpose_copy.cpp +++ b/backends/cadence/fusion_g3/operators/op_transpose_copy.cpp @@ -11,7 +11,7 @@ #include -#include +#include #include #include @@ -21,7 +21,6 @@ using ::executorch::aten::Tensor; using ::executorch::runtime::Error; using ::executorch::runtime::KernelRuntimeContext; -namespace cadence { namespace impl { namespace G3 { namespace native { @@ -129,4 +128,3 @@ Tensor& transpose_copy_int_out( } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence \ No newline at end of file diff --git a/backends/cadence/fusion_g3/operators/op_where.cpp b/backends/cadence/fusion_g3/operators/op_where.cpp index 54966c4574b..4351e8bd684 100644 --- a/backends/cadence/fusion_g3/operators/op_where.cpp +++ b/backends/cadence/fusion_g3/operators/op_where.cpp @@ -10,7 +10,7 @@ #include -#include +#include #include #include @@ -19,7 +19,6 @@ using ::executorch::aten::Tensor; using ::executorch::runtime::Error; using ::executorch::runtime::KernelRuntimeContext; -namespace cadence { namespace impl { namespace G3 { namespace native { @@ -176,4 +175,3 @@ Tensor& where_self_out( } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/operators.h b/backends/cadence/fusion_g3/operators/operators.h index 641bb82f035..5b39d382f3a 100644 --- 
a/backends/cadence/fusion_g3/operators/operators.h +++ b/backends/cadence/fusion_g3/operators/operators.h @@ -11,7 +11,6 @@ #include #include -namespace cadence { namespace impl { namespace G3 { namespace native { @@ -245,4 +244,3 @@ ::executorch::aten::Tensor& transpose_copy_int_out( } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/targets.bzl b/backends/cadence/fusion_g3/operators/targets.bzl index bc0a01b4fe8..dd04bd1223b 100644 --- a/backends/cadence/fusion_g3/operators/targets.bzl +++ b/backends/cadence/fusion_g3/operators/targets.bzl @@ -10,8 +10,11 @@ def define_operator(name: str, deps: list[str] | None = None) -> None: "//executorch/kernels/portable/cpu/pattern:all_deps", "//executorch/runtime/kernel:kernel_includes", "//executorch/kernels/portable/cpu:scalar_utils", + "//executorch/backends/cadence/common:xt_macros", "fbsource//third-party/nnlib-FusionG3/xa_nnlib:libxa_nnlib_common", "fbsource//third-party/nnlib-FusionG3/xa_nnlib:libxa_nnlib", + ":operators_header", + ":xt_utils", ] if deps == None: deps = [] @@ -26,11 +29,6 @@ def define_operator(name: str, deps: list[str] | None = None) -> None: ], compatible_with = ["ovr_config//cpu:xtensa"], deps = deps + common_deps, - exported_deps = [ - ":operators_header", - ":xt_macros", - ":xt_utils", - ], ) OPERATORS = [ @@ -79,18 +77,6 @@ def define_common_targets(): ], ) - runtime.cxx_library( - name = "xt_macros", - exported_headers = ["xt_macros.h"], - visibility = [ - "//executorch/backends/cadence/...", - ], - exported_deps = [ - "//executorch/runtime/core/exec_aten:lib", - "//executorch/runtime/kernel:kernel_runtime_context", - ], - ) - runtime.cxx_library( name = "xt_utils", exported_headers = ["xt_utils.h"], diff --git a/backends/cadence/fusion_g3/operators/tests/test_op_add.cpp b/backends/cadence/fusion_g3/operators/tests/test_op_add.cpp index bba778035b6..c0768dea0fb 100644 --- a/backends/cadence/fusion_g3/operators/tests/test_op_add.cpp +++ b/backends/cadence/fusion_g3/operators/tests/test_op_add.cpp @@ -19,7 +19,6 @@ #include #include -namespace cadence { namespace impl { namespace G3 { namespace native { @@ -39,7 +38,7 @@ class FusionG3OperatorTest : public OperatorTest { protected: Tensor& add_out(const Tensor& a, const Tensor& b, const Scalar& alpha, Tensor& out) { - return cadence::impl::G3::native::add_out(context_, a, b, alpha, out); + return impl::G3::native::add_out(context_, a, b, alpha, out); } }; @@ -100,4 +99,3 @@ TEST_F(FusionG3OperatorTest, KernelCheckTest) { } // namespace native } // namespace G3 } // namespace impl -} // namespace cadence diff --git a/backends/cadence/fusion_g3/operators/xt_macros.h b/backends/cadence/fusion_g3/operators/xt_macros.h deleted file mode 100644 index 4ab99380a2d..00000000000 --- a/backends/cadence/fusion_g3/operators/xt_macros.h +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#include - -#define XT_KERNEL_CHECK(ctx, out, kernel, ...) 
\ - const auto ret = kernel(__VA_ARGS__); \ - ET_KERNEL_CHECK_MSG( \ - ctx, \ - ret == 0, \ - InvalidArgument, \ - out, \ - "Failed to run kernel: " #kernel "(" #__VA_ARGS__ ")"); diff --git a/backends/cadence/reference/kernels/CMakeLists.txt b/backends/cadence/generic/kernels/CMakeLists.txt similarity index 100% rename from backends/cadence/reference/kernels/CMakeLists.txt rename to backends/cadence/generic/kernels/CMakeLists.txt diff --git a/backends/cadence/reference/kernels/TARGETS b/backends/cadence/generic/kernels/TARGETS similarity index 100% rename from backends/cadence/reference/kernels/TARGETS rename to backends/cadence/generic/kernels/TARGETS diff --git a/backends/cadence/generic/kernels/kernels.cpp b/backends/cadence/generic/kernels/kernels.cpp new file mode 100644 index 00000000000..28961d0faf1 --- /dev/null +++ b/backends/cadence/generic/kernels/kernels.cpp @@ -0,0 +1,119 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include + +namespace impl { +namespace generic { +namespace kernels { + +// Quantize a fp32 value to an int8_t/uint8_t value +template +T quantize(const float x, float scale, int32_t zero_point) { + // constexpr float min_val = std::numeric_limits::min(); + // constexpr float max_val = std::numeric_limits::max(); + // float tmp = roundf(x * scale + zero_point); + // return std::max(std::min(tmp, max_val), min_val); + // Match Executorch CPU kernel implementation at + // https://fburl.com/code/fxizw6u6 + int64_t qvalue; + qvalue = static_cast(zero_point + std::nearbyint(scale * x)); + + qvalue = std::max(qvalue, std::numeric_limits::min()); + qvalue = std::min(qvalue, std::numeric_limits::max()); + return static_cast(qvalue); +} + +// Quantize an fp32 array to an int8_t/uint8_t array +template +void quantize( + T* __restrict__ y, + const float* __restrict__ x, + float inv_scale, + int32_t zero_point, + size_t size) { + for (size_t i = 0; i < size; ++i) { + y[i] = quantize(x[i], inv_scale, zero_point); + } +} + +// Dequantize an int8_t/uint8_t value to an fp32 value +template +float dequantize(const T x, float scale, int32_t zero_point) { + return scale * (x - zero_point); +} + +// Dequantize an int8_t/uint8_t/int16_t array to an fp32 array +template +void dequantize( + float* __restrict__ y, + const T* __restrict__ x, + float scale, + int32_t zero_point, + size_t size) { + for (size_t i = 0; i < size; ++i) { + y[i] = dequantize(x[i], scale, zero_point); + } +} + +// explicit template instantiation + +#define typed_quantize_val(dtype) \ + template dtype quantize(const float x, float inv_scale, int32_t zero_point); +typed_quantize_val(int8_t); +typed_quantize_val(uint8_t); +typed_quantize_val(int16_t); +typed_quantize_val(uint16_t); +typed_quantize_val(int32_t); +#undef typed_quantize_val + +#define typed_quantize_vec(dtype) \ + template void quantize( \ + dtype* __restrict__ y, \ + const float* __restrict__ x, \ + float inv_scale, \ + int32_t zero_point, \ + size_t size); +typed_quantize_vec(int8_t); +typed_quantize_vec(uint8_t); +typed_quantize_vec(int16_t); +typed_quantize_vec(uint16_t); +typed_quantize_vec(int32_t); +#undef typed_quantize_vec + +#define typed_dequantize_val(dtype) \ + template float dequantize(const dtype x, float scale, int32_t zero_point); +typed_dequantize_val(int8_t); +typed_dequantize_val(uint8_t); 
+typed_dequantize_val(int16_t); +typed_dequantize_val(uint16_t); +typed_dequantize_val(int32_t); +#undef typed_dequantize_val + +#define typed_dequantize_vec(dtype) \ + template void dequantize( \ + float* __restrict__ y, \ + const dtype* __restrict__ x, \ + float scale, \ + int32_t zero_point, \ + size_t size); +typed_dequantize_vec(int8_t); +typed_dequantize_vec(uint8_t); +typed_dequantize_vec(int16_t); +typed_dequantize_vec(uint16_t); +typed_dequantize_vec(int32_t); +#undef typed_dequantize_vec + +} // namespace kernels +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/kernels/kernels.h b/backends/cadence/generic/kernels/kernels.h new file mode 100644 index 00000000000..4b37eeb45d0 --- /dev/null +++ b/backends/cadence/generic/kernels/kernels.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +namespace impl { +namespace generic { +namespace kernels { + +template +T quantize(const float x, float scale, int32_t zero_point); + +template +float dequantize(const T x, float scale, int32_t zero_point); + +template +void quantize( + T* __restrict__ y, + const float* __restrict__ x, + float scale, + int32_t zero_point, + size_t size); + +// Dequantize an int8_t/uint8_t/int16_t array to an fp32 array +template +void dequantize( + float* __restrict__ y, + const T* __restrict__ x, + float scale, + int32_t zero_point, + size_t size); + +template +OT requantize( + const IT in, + float in_scale, + int32_t in_zero_point, + float inv_out_scale, + int32_t out_zero_point); + +template +void requantize( + OT* __restrict__ out, + const IT* __restrict__ in, + float in_scale, + int32_t in_zero_point, + float inv_out_scale, + int32_t out_zero_point, + size_t size); + +} // namespace kernels +} // namespace generic +} // namespace impl diff --git a/backends/cadence/reference/kernels/targets.bzl b/backends/cadence/generic/kernels/targets.bzl similarity index 100% rename from backends/cadence/reference/kernels/targets.bzl rename to backends/cadence/generic/kernels/targets.bzl diff --git a/backends/cadence/generic/operators/CMakeLists.txt b/backends/cadence/generic/operators/CMakeLists.txt new file mode 100644 index 00000000000..63d8902ac89 --- /dev/null +++ b/backends/cadence/generic/operators/CMakeLists.txt @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.19) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +if(NOT CMAKE_CXX_STANDARD) + set(CMAKE_CXX_STANDARD 17) +endif() + +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) +include(${EXECUTORCH_ROOT}/tools/cmake/Codegen.cmake) + +# ATen compliant ops that are needed to run this model.
+set(_aten_ops__srcs + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/activation_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/copy_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/broadcast_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/dtype_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/index_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/kernel_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/matmul_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/reduce_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/repeat_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/slice_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_floathbf16.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_add.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_cat.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_div.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_embedding.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_full.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_hardtanh.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_max_pool2d_with_indices.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mean.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mul.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_permute_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_rsqrt.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sigmoid.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_slice_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_softmax.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_split_with_sizes_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sub.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_to_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_where.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_expand_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_gelu.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_empty.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_transpose_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_eq.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_logical_not.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_any.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_native_group_norm.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sum.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_select_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_view_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/dtype_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/normalization_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/select_copy_util.cpp" +) +add_library(aten_ops_cadence ${_aten_ops__srcs}) +target_link_libraries(aten_ops_cadence PUBLIC executorch) +target_link_libraries(aten_ops_cadence PRIVATE cadence_kernels) + +# Let files say "include ". +set(_common_include_directories + ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10 +) + +target_include_directories( + aten_ops_cadence PUBLIC ${ROOT_DIR}/.. ${CMAKE_BINARY_DIR} + ${_common_include_directories} +) + +# Custom ops that are needed to run the test model. 
+add_library( + custom_ops + "quantized_linear_out.cpp" + "quantized_conv2d_nchw_out.cpp" + "quantized_conv2d_nhwc_out.cpp" + "quantized_relu_out.cpp" + "quantized_layer_norm.cpp" + "quantize_per_tensor.cpp" + "quantized_fully_connected_out.cpp" + "dequantize_per_tensor.cpp" + "quantized_matmul_out.cpp" + "op_requantize_out.cpp" + "im2row_out.cpp" +) +target_include_directories( + custom_ops PUBLIC ${ROOT_DIR}/.. ${CMAKE_BINARY_DIR} + ${_common_include_directories} +) + +target_link_libraries(custom_ops PUBLIC executorch) +target_link_libraries(custom_ops PRIVATE cadence_kernels) + +# Generate C++ bindings to register kernels into both PyTorch (for AOT) and +# Executorch (for runtime). Here select all ops in functions.yaml +gen_selected_ops( + LIB_NAME "cadence_ops_lib" OPS_SCHEMA_YAML + "${CMAKE_CURRENT_LIST_DIR}/../../aot/functions.yaml" "" "" +) +generate_bindings_for_kernels( + LIB_NAME "cadence_ops_lib" OPS_SCHEMA_YAML FUNCTIONS_YAML + ${CMAKE_CURRENT_SOURCE_DIR}/../../aot/functions.yaml +) +message("Generated cadence x86 files ${gen_command_sources}") + +gen_operators_lib( + LIB_NAME "cadence_ops_lib" KERNEL_LIBS custom_ops DEPS aten_ops_cadence +) diff --git a/backends/cadence/generic/operators/cadence_type_util.h b/backends/cadence/generic/operators/cadence_type_util.h new file mode 100644 index 00000000000..43852277031 --- /dev/null +++ b/backends/cadence/generic/operators/cadence_type_util.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +/** + * @file cadence_type_util.h + * @brief Common type macros for Cadence quantized operators + * + * This header provides utility macros for iterating over supported quantized + * data types in Cadence operators. These macros are used with switch statements + * to dispatch to type-specific implementations. + */ + +/** + * Macro to iterate over standard Cadence quantized types (uint8_t, int8_t) + * + * Usage: + * ET_FORALL_CADENCE_QUANTIZED_TYPES(MACRO) + * + * Where MACRO is defined as: #define MACRO(ctype, name) ... + * - ctype: C++ type (uint8_t or int8_t) + * - name: ExecuTorch ScalarType name suffix (Byte or Char) + * + * Example: + * #define HANDLE_TYPE(ctype, name) \ + * case ScalarType::name: \ + * return process(tensor); \ + * break; + * + * ScalarType dtype = tensor.scalar_type(); + * switch (dtype) { + * ET_FORALL_CADENCE_QUANTIZED_TYPES(HANDLE_TYPE) + * default: + * ET_CHECK_MSG(false, "Unsupported dtype"); + * } + */ +#define ET_FORALL_CADENCE_QUANTIZED_TYPES(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) + +/** + * Macro to iterate over extended Cadence quantized types including int16_t + * + * Usage: + * ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(MACRO) + * + * Where MACRO is defined as: #define MACRO(ctype, name) ... + * - ctype: C++ type (uint8_t, int8_t, or int16_t) + * - name: ExecuTorch ScalarType name suffix (Byte, Char, or Short) + * + * This macro includes int16_t support for operators that can handle 16-bit + * quantized values (e.g., quantized_linear, quantized_fully_connected).
+ */ +#define ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) diff --git a/backends/cadence/generic/operators/op_avg_pool2d.cpp b/backends/cadence/generic/operators/op_avg_pool2d.cpp new file mode 100644 index 00000000000..936a5583d8c --- /dev/null +++ b/backends/cadence/generic/operators/op_avg_pool2d.cpp @@ -0,0 +1,156 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::optional; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::getLeadingDims; +using ::executorch::runtime::KernelRuntimeContext; + +// Compute the avg_pool2d for in_data in NCHW layout. IT is the input datatype, +// and AT is the accumulation datatype. 'quantized' is true when the input is +// quantized tensor. +template +void avg_pool2d_nchw( + const IT* __restrict__ in_data, + const int32_t in_zero_point, + IT* __restrict__ out_data, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + bool count_include_pad, + int64_t divisor, + int leading_dims, + int ih, + int iw, + int oh, + int ow) { + int kh = kernel_size[0]; + int kw = kernel_size[1]; + int s0 = stride[0]; + int s1 = stride[1]; + int p0 = padding[0]; + int p1 = padding[1]; + + for (int _n = 0; _n < leading_dims; ++_n) { + for (int _ih = 0, _oh = 0; _oh < oh; ++_oh, _ih += s0) { + int input_offset = _n * ih * iw; + int output_offset = _n * oh * ow + _oh * ow; + for (int _iw = 0, _ow = 0; _ow < ow; ++_ow, _iw += s1) { + int kh_lo = std::max(0, _ih - p0); + int kh_hi = std::min(ih, _ih + kh - p0); + int kw_lo = std::max(0, _iw - p1); + int kw_hi = std::min(iw, _iw + kw - p1); + // Count the number of contributions sans padding + int count = (kh_hi - kh_lo) * (kw_hi - kw_lo); + // Set the accumulator + AT acc = count_include_pad ? in_zero_point * (kh * kw - count) : 0; + // Accumulate values + for (int _kh = kh_lo; _kh < kh_hi; ++_kh) { + for (int _kw = kw_lo; _kw < kw_hi; ++_kw) { + int input_addr = input_offset + _kh * iw + _kw; + acc += in_data[input_addr]; + } + } + // The divisor changes depending on whether the count includes + // padded cells or not. + float inv_divisor = 1. / (count_include_pad ? divisor : count); + float val = acc * inv_divisor; + if (quantized) { + int32_t min_val = + static_cast(std::numeric_limits::min()); + int32_t max_val = + static_cast(std::numeric_limits::max()); + out_data[output_offset + _ow] = std::min( + std::max(int32_t(std::nearbyint(val)), min_val), max_val); + } else { + out_data[output_offset + _ow] = val; + } + } + } + } +} + +Tensor& avg_pool2d_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + bool ceil_mode, + bool count_include_pad, + optional divisor_override, + const optional& in_zero_point_t, + bool channel_last, + Tensor& out) { + ET_DCHECK_MSG(!channel_last, "NHWC layout for avg_pool2d not yet supported"); + const int32_t in_zero_point = in_zero_point_t.has_value() + ? in_zero_point_t.value().const_data_ptr()[0] + : 0; + const int64_t divisor = divisor_override.has_value() + ? 
divisor_override.value() + : kernel_size[0] * kernel_size[1]; + + const int odim = out.dim(); + const int on = getLeadingDims(out, odim - 2); + const int oh = out.size(odim - 2); + const int ow = out.size(odim - 1); + const int ih = input.size(odim - 2); + const int iw = input.size(odim - 1); + + // We generate the kernel for float and uint8_t types. The operator also + // works for double, but does not support other dtypes. +#define typed_avg_pool2d(btype, ctype, quantized, dtype) \ + case ScalarType::dtype: { \ + avg_pool2d_nchw( \ + input.const_data_ptr(), \ + in_zero_point, \ + out.mutable_data_ptr(), \ + kernel_size, \ + stride, \ + padding, \ + count_include_pad, \ + divisor, \ + on, \ + ih, \ + iw, \ + oh, \ + ow); \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + typed_avg_pool2d(float, float, false, Float); + typed_avg_pool2d(uint8_t, int32_t, true, Byte); + default: + ET_DCHECK_MSG( + false, + "avg_pool2d not implemented for dtype %s", + torch::executor::toString(dtype)); + } + + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_avg_pool2d.h b/backends/cadence/generic/operators/op_avg_pool2d.h new file mode 100644 index 00000000000..05f1810bb61 --- /dev/null +++ b/backends/cadence/generic/operators/op_avg_pool2d.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& avg_pool2d_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + ::executorch::aten::IntArrayRef kernel_size, + ::executorch::aten::IntArrayRef stride, + ::executorch::aten::IntArrayRef padding, + bool ceil_mode, + bool count_include_pad, + ::executorch::aten::optional divisor_override, + const ::executorch::aten::optional<::executorch::aten::Tensor>& + in_zero_point_t, + bool channel_last, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_conv1d.cpp b/backends/cadence/generic/operators/op_conv1d.cpp new file mode 100644 index 00000000000..9a69c9fd549 --- /dev/null +++ b/backends/cadence/generic/operators/op_conv1d.cpp @@ -0,0 +1,104 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +// This implements a generic 1D float32 convolution kernel. 
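+// Note (informal): the expected output width follows the usual convolution
+// size formula, ow = (w + 2 * p - d * (ww - 1) - 1) / s + 1, the same formula
+// used by im2row_out further below; e.g. w = 8, ww = 3, s = 1, p = 0, d = 1
+// gives ow = 6. The kernel itself simply trusts out.size(2).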
+// The input is of shape [n x c x w] (batch x channels x width) +// The weight is of shape [oc x wc x ww] (out_channels x weight_channels x +// weight_width) The output is of shape [n x oc x ow] (batch x out_channels x +// out_width) The bias is of shape [oc] + +Tensor& conv1d_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + Tensor& out) { + // Extract dimensions + const int n = input.size(0); + const int c = input.size(1); + const int w = input.size(2); + const int oc = weight.size(0); + const int wc = weight.size(1); + const int ww = weight.size(2); + const int ow = out.size(2); + + const int16_t s = static_cast(stride[0]); + const int16_t p = static_cast(padding[0]); + const int16_t d = static_cast(dilation[0]); + const int16_t g = static_cast(groups); + + const float* p_in = input.const_data_ptr(); + const float* p_weight = weight.const_data_ptr(); + const float* p_bias = bias.const_data_ptr(); + float* p_out = out.mutable_data_ptr(); + + const bool zero_pad_unit_dilation = d == 1 && p == 0; + const int ocpg = oc / g; + const int icpg = c / g; + + for (int _n = 0; _n < n; ++_n) { + const float* in_batch = p_in + _n * c * w; + float* out_batch = p_out + _n * oc * ow; + for (int _g = 0; _g < g; ++_g) { + int sic = _g * icpg; + int soc = _g * ocpg; + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + float* out_plane = out_batch + _oc * ow; + const float* weight_batch = p_weight + _oc * wc * ww; + for (int _w = 0, _ow = 0; _ow < ow; _w += s, ++_ow) { + float acc = p_bias[_oc]; + if (zero_pad_unit_dilation) { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const float* in_plane = in_batch + _ic * w; + const float* weight_plane = weight_batch + (_ic - sic) * ww; + for (int _ww = 0; _ww < ww; ++_ww) { + int ioff = _w + _ww; + acc += in_plane[ioff] * weight_plane[_ww]; + } + } + } else { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const float* in_plane = in_batch + _ic * w; + const float* weight_plane = weight_batch + (_ic - sic) * ww; + for (int _ww = 0; _ww < ww; ++_ww) { + int w_pos = _w + d * _ww - p; + if (w_pos >= 0 && w_pos < w) { + acc += in_plane[w_pos] * weight_plane[_ww]; + } + } + } + } + out_plane[_ow] = acc; + } + } + } + } + + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_conv1d.h b/backends/cadence/generic/operators/op_conv1d.h new file mode 100644 index 00000000000..457c46ea358 --- /dev/null +++ b/backends/cadence/generic/operators/op_conv1d.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +Tensor& conv1d_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_conv2d.cpp b/backends/cadence/generic/operators/op_conv2d.cpp new file mode 100644 index 00000000000..b71695fe367 --- /dev/null +++ b/backends/cadence/generic/operators/op_conv2d.cpp @@ -0,0 +1,122 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +// This implements a generic 2D float32 convolution kernel. +// The input is of shape [n x c x h x w] (batch x channels x height x width) +// The weight is of shape [oc x wc x wh x ww] (out_channels x weight_channels x +// weight_height x weight_width) The output is of shape [n x oc x oh x ow] +// (batch x out_channels x out_height x out_width) The bias is of shape [oc] + +Tensor& conv2d_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + Tensor& out) { + // Extract dimensions + const int n = input.size(0); + const int c = input.size(1); + const int h = input.size(2); + const int w = input.size(3); + const int oc = weight.size(0); + const int wc = weight.size(1); + const int wh = weight.size(2); + const int ww = weight.size(3); + const int oh = out.size(2); + const int ow = out.size(3); + + const int16_t s0 = static_cast(stride[0]); + const int16_t s1 = static_cast(stride[1]); + const int16_t p0 = static_cast(padding[0]); + const int16_t p1 = static_cast(padding[1]); + const int16_t d0 = static_cast(dilation[0]); + const int16_t d1 = static_cast(dilation[1]); + const int16_t g = static_cast(groups); + + const float* p_in = input.const_data_ptr(); + const float* p_weight = weight.const_data_ptr(); + const float* p_bias = bias.const_data_ptr(); + float* p_out = out.mutable_data_ptr(); + + const bool zero_pad_unit_dilation = d0 == 1 && d1 == 1 && p0 == 0 && p1 == 0; + const int ocpg = oc / g; + const int icpg = c / g; + + for (int _n = 0; _n < n; ++_n) { + const float* in_batch = p_in + _n * c * h * w; + float* out_batch = p_out + _n * oc * oh * ow; + for (int _g = 0; _g < g; ++_g) { + int sic = _g * icpg; + int soc = _g * ocpg; + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + float* out_plane = out_batch + _oc * oh * ow; + const float* weight_batch = p_weight + _oc * wc * wh * ww; + for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { + for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { + float acc = p_bias[_oc]; + if (zero_pad_unit_dilation) { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const float* in_plane = in_batch + _ic * h * w; + const float* weight_plane = + weight_batch + (_ic - sic) 
* wh * ww; + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + int ioff = (_h + _wh) * w + (_w + _ww); + int woff = _wh * ww + _ww; + acc += in_plane[ioff] * weight_plane[woff]; + } + } + } + } else { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const float* in_plane = in_batch + _ic * h * w; + const float* weight_plane = + weight_batch + (_ic - sic) * wh * ww; + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + int h_pos = _h + d0 * _wh - p0; + int w_pos = _w + d1 * _ww - p1; + if (h_pos >= 0 && h_pos < h && w_pos >= 0 && w_pos < w) { + int ioff = h_pos * w + w_pos; + int woff = _wh * ww + _ww; + acc += in_plane[ioff] * weight_plane[woff]; + } + } + } + } + } + out_plane[_oh * ow + _ow] = acc; + } + } + } + } + } + + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_conv2d.h b/backends/cadence/generic/operators/op_conv2d.h new file mode 100644 index 00000000000..576cb5d5cb5 --- /dev/null +++ b/backends/cadence/generic/operators/op_conv2d.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +Tensor& conv2d_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_conv3d.cpp b/backends/cadence/generic/operators/op_conv3d.cpp new file mode 100644 index 00000000000..0ea6b063311 --- /dev/null +++ b/backends/cadence/generic/operators/op_conv3d.cpp @@ -0,0 +1,139 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +// This implements a generic 3D float32 convolution kernel. 
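+// Note (informal): as in the 1D and 2D kernels above, grouped convolution is
+// handled by partitioning channels per group (icpg = c / groups input
+// channels, ocpg = oc / groups output channels), so group _g reads input
+// channels [_g * icpg, (_g + 1) * icpg) and writes output channels
+// [_g * ocpg, (_g + 1) * ocpg); e.g. c = 8, oc = 16, groups = 2 gives
+// icpg = 4 and ocpg = 8.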
+// The input is of shape [n x c x d x h x w] (batch x channels x depth x height +// x width) The weight is of shape [oc x wc x wd x wh x ww] (out_channels x +// weight_channels x weight_depth x weight_height x weight_width) The output is +// of shape [n x oc x od x oh x ow] (batch x out_channels x out_depth x +// out_height x out_width) The bias is of shape [oc] + +Tensor& conv3d_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + Tensor& out) { + // Extract dimensions + const int n = input.size(0); + const int c = input.size(1); + const int d = input.size(2); + const int h = input.size(3); + const int w = input.size(4); + const int oc = weight.size(0); + const int wc = weight.size(1); + const int wd = weight.size(2); + const int wh = weight.size(3); + const int ww = weight.size(4); + const int od = out.size(2); + const int oh = out.size(3); + const int ow = out.size(4); + + const int16_t s0 = static_cast(stride[0]); + const int16_t s1 = static_cast(stride[1]); + const int16_t s2 = static_cast(stride[2]); + const int16_t p0 = static_cast(padding[0]); + const int16_t p1 = static_cast(padding[1]); + const int16_t p2 = static_cast(padding[2]); + const int16_t d0 = static_cast(dilation[0]); + const int16_t d1 = static_cast(dilation[1]); + const int16_t d2 = static_cast(dilation[2]); + const int16_t g = static_cast(groups); + + const float* p_in = input.const_data_ptr(); + const float* p_weight = weight.const_data_ptr(); + const float* p_bias = bias.const_data_ptr(); + float* p_out = out.mutable_data_ptr(); + + const bool zero_pad_unit_dilation = + d0 == 1 && d1 == 1 && d2 == 1 && p0 == 0 && p1 == 0 && p2 == 0; + const int ocpg = oc / g; + const int icpg = c / g; + + for (int _n = 0; _n < n; ++_n) { + const float* in_batch = p_in + _n * c * d * h * w; + float* out_batch = p_out + _n * oc * od * oh * ow; + for (int _g = 0; _g < g; ++_g) { + int sic = _g * icpg; + int soc = _g * ocpg; + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + float* out_volume = out_batch + _oc * od * oh * ow; + const float* weight_batch = p_weight + _oc * wc * wd * wh * ww; + for (int _d = 0, _od = 0; _od < od; _d += s0, ++_od) { + for (int _h = 0, _oh = 0; _oh < oh; _h += s1, ++_oh) { + for (int _w = 0, _ow = 0; _ow < ow; _w += s2, ++_ow) { + float acc = p_bias[_oc]; + if (zero_pad_unit_dilation) { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const float* in_volume = in_batch + _ic * d * h * w; + const float* weight_volume = + weight_batch + (_ic - sic) * wd * wh * ww; + for (int _wd = 0; _wd < wd; ++_wd) { + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + int ioff = + (_d + _wd) * h * w + (_h + _wh) * w + (_w + _ww); + int woff = _wd * wh * ww + _wh * ww + _ww; + acc += in_volume[ioff] * weight_volume[woff]; + } + } + } + } + } else { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const float* in_volume = in_batch + _ic * d * h * w; + const float* weight_volume = + weight_batch + (_ic - sic) * wd * wh * ww; + for (int _wd = 0; _wd < wd; ++_wd) { + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + int d_pos = _d + d0 * _wd - p0; + int h_pos = _h + d1 * _wh - p1; + int w_pos = _w + d2 * _ww - p2; + if (d_pos >= 0 && d_pos < d && h_pos >= 0 && + h_pos < h && w_pos >= 0 && w_pos < w) { + int ioff = d_pos * h * w + h_pos * w + w_pos; + int woff = _wd * wh * ww + _wh * ww + _ww; + acc += in_volume[ioff] * 
weight_volume[woff]; + } + } + } + } + } + } + out_volume[_od * oh * ow + _oh * ow + _ow] = acc; + } + } + } + } + } + } + + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_conv3d.h b/backends/cadence/generic/operators/op_conv3d.h new file mode 100644 index 00000000000..db896f5f318 --- /dev/null +++ b/backends/cadence/generic/operators/op_conv3d.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +Tensor& conv3d_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_dequantize_per_tensor.cpp b/backends/cadence/generic/operators/op_dequantize_per_tensor.cpp new file mode 100644 index 00000000000..a17fa48778f --- /dev/null +++ b/backends/cadence/generic/operators/op_dequantize_per_tensor.cpp @@ -0,0 +1,143 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; +using ::impl::generic::kernels::dequantize; + +Tensor& dequantize_per_tensor_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + float* out_data = out.mutable_data_ptr(); + size_t numel = out.numel(); + + if (input.scalar_type() == ScalarType::Byte) { + const uint8_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Char) { + const int8_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); + } else if ( + input.scalar_type() == ScalarType::Bits16 || + input.scalar_type() == ScalarType::UInt16) { + const uint16_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Short) { + const int16_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Int) { + const int32_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); + } else { + ET_CHECK_MSG( + false, + "Unhandled input dtype %hhd", + static_cast(input.scalar_type())); + } + return out; +} + +Tensor& dequantize_per_tensor_asym8s_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + float* out_data = 
out.mutable_data_ptr(); + size_t numel = out.numel(); + const int8_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); + return out; +} + +Tensor& dequantize_per_tensor_asym8u_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + float* out_data = out.mutable_data_ptr(); + size_t numel = out.numel(); + const uint8_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); + return out; +} + +Tensor& dequantize_per_tensor_asym16s_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + float* out_data = out.mutable_data_ptr(); + size_t numel = out.numel(); + const int16_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); + return out; +} + +Tensor& dequantize_per_tensor_asym16u_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + float* out_data = out.mutable_data_ptr(); + size_t numel = out.numel(); + const uint16_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); + return out; +} + +Tensor& dequantize_per_tensor_asym32s_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + float* out_data = out.mutable_data_ptr(); + size_t numel = out.numel(); + const int32_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_dequantize_per_tensor.h b/backends/cadence/generic/operators/op_dequantize_per_tensor.h new file mode 100644 index 00000000000..b4965c4e7ab --- /dev/null +++ b/backends/cadence/generic/operators/op_dequantize_per_tensor.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& dequantize_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& context, + const ::executorch::aten::Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ::executorch::aten::ScalarType dtype, + ::executorch::aten::Tensor& out); + +} +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_fully_connected.cpp b/backends/cadence/generic/operators/op_fully_connected.cpp new file mode 100644 index 00000000000..f1e53ad5f76 --- /dev/null +++ b/backends/cadence/generic/operators/op_fully_connected.cpp @@ -0,0 +1,67 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::optional; +using ::executorch::aten::Tensor; +using ::executorch::runtime::getLeadingDims; +using ::executorch::runtime::KernelRuntimeContext; + +void linear( + const Tensor& input, + const Tensor& weight, + const optional& bias, + Tensor& output) { + const float* __restrict__ input_data = input.const_data_ptr(); + const float* __restrict__ weight_data = weight.const_data_ptr(); + const float* __restrict__ bias_data = bias.value().const_data_ptr(); + float* __restrict__ output_data = output.mutable_data_ptr(); + + // input comes in shape [batch_size, in_dim] + // weight comes in shape [out_dim, in_dim] + // output comes in empty with shape [batch_size, out_dim] + // Perform matrix multiply (M x N) x (N x P) => M x P + int64_t M = weight.size(0); // = out_dim + int64_t N = weight.size(1); // = in_dim + + // Given an N-dimensional input [d0, d1, d2, ..., d_{N-2}, d_{N-1}], the + // leading dimensions is d0 * d1 * ... * d_{N-2} + int64_t leading_dims = getLeadingDims(input, input.dim() - 1); + + for (int i = 0; i < leading_dims; ++i) { + for (int j = 0; j < M; ++j) { + float sum = bias_data[j]; + for (int k = 0; k < N; ++k) { + sum += input_data[i * N + k] * weight_data[j * N + k]; + } + output_data[i * M + j] = sum; + } + } +} + +Tensor& fully_connected_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const optional& bias, + Tensor& output) { + linear(input, weight, bias, output); + return output; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_fully_connected.h b/backends/cadence/generic/operators/op_fully_connected.h new file mode 100644 index 00000000000..d23bcbeb70c --- /dev/null +++ b/backends/cadence/generic/operators/op_fully_connected.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::optional; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +Tensor& fully_connected_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const optional& bias, + Tensor& output); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_idma_copy.cpp b/backends/cadence/generic/operators/op_idma_copy.cpp new file mode 100644 index 00000000000..760592a10d5 --- /dev/null +++ b/backends/cadence/generic/operators/op_idma_copy.cpp @@ -0,0 +1,53 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include // For std::memcpy + +#include +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +// CPU implementation of idma_copy_out using std::memcpy +// This function performs a direct memory copy between tensors +Tensor& idma_copy_out( + KernelRuntimeContext& ctx, + const Tensor& src, + const int64_t + task_num, // Unused in CPU implementation but kept for API compatibility + const int64_t + channel, // Unused in CPU implementation but kept for API compatibility + Tensor& out) { + ET_KERNEL_CHECK( + ctx, + src.dtype() == out.dtype() && src.numel() == out.numel(), + InvalidArgument, + out); + + // Use std::memcpy for direct memory copy + std::memcpy( + out.mutable_data_ptr(), + src.const_data_ptr(), + out.nbytes()); + + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_idma_copy.h b/backends/cadence/generic/operators/op_idma_copy.h new file mode 100644 index 00000000000..06e470bb80c --- /dev/null +++ b/backends/cadence/generic/operators/op_idma_copy.h @@ -0,0 +1,27 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& idma_copy_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& src, + const int64_t task_num, + const int64_t channel, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_idma_wait.cpp b/backends/cadence/generic/operators/op_idma_wait.cpp new file mode 100644 index 00000000000..1f90ec1bba7 --- /dev/null +++ b/backends/cadence/generic/operators/op_idma_wait.cpp @@ -0,0 +1,48 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "executorch/backends/cadence/generic/operators/op_idma_wait.h" + +#include + +#include "executorch/runtime/core/exec_aten/exec_aten.h" +#include "executorch/runtime/core/exec_aten/util/tensor_util.h" +#include "executorch/runtime/kernel/kernel_runtime_context.h" + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +// CPU implementation of idma_wait_out +// Since there's no actual DMA operation in the CPU implementation, +// this is essentially a no-op function that just ensures the output tensor +// has the same content as the input tensor +Tensor& idma_wait_out( + KernelRuntimeContext& ctx, + const Tensor& src, + const int64_t + task_num, // Unused in CPU implementation but kept for API compatibility + Tensor& out) { + ET_KERNEL_CHECK(ctx, src.numel() == out.numel(), InvalidArgument, out); + ET_KERNEL_CHECK(ctx, src.dtype() == out.dtype(), InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, + src.const_data_ptr() == out.const_data_ptr(), + InvalidArgument, + out); + + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_idma_wait.h b/backends/cadence/generic/operators/op_idma_wait.h new file mode 100644 index 00000000000..2426d98dc37 --- /dev/null +++ b/backends/cadence/generic/operators/op_idma_wait.h @@ -0,0 +1,26 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& idma_wait_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& src, + const int64_t task_num, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_im2row.cpp b/backends/cadence/generic/operators/op_im2row.cpp new file mode 100644 index 00000000000..8c939c7ad5c --- /dev/null +++ b/backends/cadence/generic/operators/op_im2row.cpp @@ -0,0 +1,309 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include + +#include + +#ifndef DISABLE_ALWAYS_INLINE +#define ALWAYS_INLINE __attribute__((always_inline)) +#else +#define ALWAYS_INLINE inline +#endif + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +template +ALWAYS_INLINE void im2row_( + const T* __restrict__ data_im, + const int32_t in_zero_point, + /* input parameters*/ + const int32_t channels, + const int32_t height, + const int32_t width, + /* output parameters */ + const int32_t out_height, + const int32_t out_width, + /* convolution parameters */ + const int32_t kernel_h, + const int32_t kernel_w, + const int32_t pad_h, + const int32_t pad_w, + const int32_t stride_h, + const int32_t stride_w, + const int32_t dilation_h, + const int32_t dilation_w, + T* __restrict__ data_col, + bool channels_last) { + // Consider convolving the input image of dimensions channels * height * width + // (or height * width * channels for NHWC layout) with a filter of dimensions + // channels * kernels_h * kernels_w. Assume that this convolution will produce + // an output of dimensinos out_height x out_width. For each point the output, + // im2row takes the data from the input that is used in the computation of + // that output point, and flattens it into a vector of size channels_col = + // channels * kernel_h * kernel_w. The output of im2row will therefore be a 2D + // array of size (out_height * out_width) x channels_col + const int32_t channels_col = channels * kernel_h * kernel_w; + + // If the layout is NHWC, we can copy 'channels' worth of contiguous data + // points when performing im2row. + if (channels_last) { + // Iterate over the output domain + for (int _h = 0; _h < out_height; ++_h) { + for (int _w = 0; _w < out_width; ++_w) { + int32_t i_col = _h * out_width + _w; + // Each point in the output domain is the result of applying a filter of + // size kernel_h x kernel_w x channels on the input. But since channels + // is contiguous, we will not explicitly have a loop for it. + for (int _kh = 0; _kh < kernel_h; ++_kh) { + int32_t h_im = _h * stride_h - pad_h + _kh * dilation_h; + for (int _kw = 0; _kw < kernel_w; ++_kw) { + int32_t w_im = _w * stride_w - pad_w + _kw * dilation_w; + + // h_im and w_im are the actual height and width coordinates of the + // input tensor from where we need to copy 'channels' points. + const T* __restrict__ slice_im = + data_im + (h_im * width + w_im) * channels; + T* __restrict__ slice_col = data_col + i_col * channels_col + + (_kh * kernel_w + _kw) * channels; + // If the coordinates were within the input domain, we copy + // 'channels' contiguous values. Otherwise we will fill the output + // with 0's. + if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + memcpy(slice_col, slice_im, channels * sizeof(T)); + } else { + std::fill_n(slice_col, channels, T(in_zero_point)); + } + } + } + } + } + } else { + // Iterate over the output domain + for (int _h = 0; _h < out_height; ++_h) { + for (int _w = 0; _w < out_width; ++_w) { + int32_t i_col = _h * out_width + _w; + + // Each point in the output domain is the result of applying a filter + // of size chanenls * kernel_h x kernel_w on the input + for (int _c = 0; _c < channels; ++_c) { + for (int _kh = 0; _kh < kernel_h; ++_kh) { + for (int _kw = 0; _kw < kernel_w; ++_kw) { + // c_col is the linearized access in the channels_col vector. 
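+            // For example, with channels = 2 and a 3x3 kernel,
+            // channels_col = 2 * 3 * 3 = 18, and (_c, _kh, _kw) = (1, 0, 2)
+            // maps to c_col = (1 * 3 + 0) * 3 + 2 = 11.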
+ int32_t c_col = (_c * kernel_h + _kh) * kernel_w + _kw; + // h_im and w_im are the actual height and width coordinates of + // the input tensor that we need to copy to the output. + int32_t h_im = _h * stride_h - pad_h + _kh * dilation_h; + int32_t w_im = _w * stride_w - pad_w + _kw * dilation_w; + // If the current data access is within the input tensor, copy the + // value + data_col[i_col * channels_col + c_col] = + (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) + ? data_im[(_c * height + h_im) * width + w_im] + : static_cast(in_zero_point); + } + } + } + } + } + } +} + +Tensor& im2row_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride, + const Tensor& in_zero_point, + bool channel_last, + Tensor& out) { + // Compute the input tensor's dims + bool unit_height = input.dim() == 3; + const int32_t batch_size = input.size(0); + const int32_t in_c = + channel_last ? input.size(3 - unit_height) : input.size(1); + const int32_t in_h = + unit_height ? 1 : (channel_last ? input.size(1) : input.size(2)); + const int32_t in_w = + channel_last ? input.size(2 - unit_height) : input.size(3 - unit_height); + + // Get the kernel parameters + int32_t kernel_h = kernel_size[0]; + int32_t kernel_w = kernel_size[1]; + int32_t dilation_h = dilation[0]; + int32_t dilation_w = dilation[1]; + int32_t pad_h = padding[0]; + int32_t pad_w = padding[1]; + int32_t stride_h = stride[0]; + int32_t stride_w = stride[1]; + + // If we were to apply a convolution on the input tensor, compute the output + // height and width. + int32_t out_h = + (in_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) / stride_h + 1; + int32_t out_w = + (in_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) / stride_w + 1; + + ET_DCHECK_MSG( + (out_h * out_w) == out.size(1), "dimension mismatch for output"); + ET_DCHECK_MSG( + (kernel_h * kernel_w * in_c) == out.size(2), + "dimension mismatch for output"); + + // Check if the input is per-tensor quantized or per-channel quantized. The + // zero point for each batch could differ for per-channel quantized input. + bool per_tensor_quantized = in_zero_point.numel() == 1; + +#define typed_im2row(dtype, ctype) \ + case ScalarType::dtype: { \ + const ctype* __restrict__ in_data = input.const_data_ptr(); \ + ctype* __restrict__ out_data = out.mutable_data_ptr(); \ + const int32_t* __restrict__ zero_point = \ + in_zero_point.const_data_ptr(); \ + int32_t in_plane = in_c * in_h * in_w; \ + int32_t out_plane = kernel_h * kernel_w * in_c * out_h * out_w; \ + for (int32_t n = 0; n < batch_size; ++n) { \ + im2row_( \ + &in_data[n * in_plane], \ + per_tensor_quantized ? 
zero_point[0] : zero_point[n], \ + in_c, \ + in_h, \ + in_w, \ + out_h, \ + out_w, \ + kernel_h, \ + kernel_w, \ + pad_h, \ + pad_w, \ + stride_h, \ + stride_w, \ + dilation_h, \ + dilation_w, \ + &out_data[n * out_plane], \ + channel_last); \ + } \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + typed_im2row(Float, float); + typed_im2row(Byte, uint8_t); + typed_im2row(Char, int8_t); + default: + ET_DCHECK_MSG( + false, + "im2row not implemented for dtype %s", + torch::executor::toString(dtype)); + } +#undef typed_im2row + + return out; +} + +Tensor& im2row_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride, + int64_t in_zero_point, + bool channel_last, + Tensor& out) { + // Compute the input tensor's dims + bool unit_height = input.dim() == 3; + const int32_t batch_size = input.size(0); + const int32_t in_c = + channel_last ? input.size(3 - unit_height) : input.size(1); + const int32_t in_h = + unit_height ? 1 : (channel_last ? input.size(1) : input.size(2)); + const int32_t in_w = + channel_last ? input.size(2 - unit_height) : input.size(3 - unit_height); + + // Get the kernel parameters + int32_t kernel_h = kernel_size[0]; + int32_t kernel_w = kernel_size[1]; + int32_t dilation_h = dilation[0]; + int32_t dilation_w = dilation[1]; + int32_t pad_h = padding[0]; + int32_t pad_w = padding[1]; + int32_t stride_h = stride[0]; + int32_t stride_w = stride[1]; + + // If we were to apply a convolution on the input tensor, compute the output + // height and width. + int32_t out_h = + (in_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) / stride_h + 1; + int32_t out_w = + (in_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) / stride_w + 1; + + ET_DCHECK_MSG( + (out_h * out_w) == out.size(1), "dimension mismatch for output"); + ET_DCHECK_MSG( + (kernel_h * kernel_w * in_c) == out.size(2), + "dimension mismatch for output"); + +#define typed_im2row_per_tensor(dtype, ctype) \ + case ScalarType::dtype: { \ + const ctype* __restrict__ in_data = input.const_data_ptr(); \ + ctype* __restrict__ out_data = out.mutable_data_ptr(); \ + int32_t in_plane = in_c * in_h * in_w; \ + int32_t out_plane = kernel_h * kernel_w * in_c * out_h * out_w; \ + for (size_t n = 0; n < batch_size; ++n) { \ + im2row_( \ + &in_data[n * in_plane], \ + in_zero_point, \ + in_c, \ + in_h, \ + in_w, \ + out_h, \ + out_w, \ + kernel_h, \ + kernel_w, \ + pad_h, \ + pad_w, \ + stride_h, \ + stride_w, \ + dilation_h, \ + dilation_w, \ + &out_data[n * out_plane], \ + channel_last); \ + } \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + typed_im2row_per_tensor(Float, float); + typed_im2row_per_tensor(Byte, uint8_t); + typed_im2row_per_tensor(Char, int8_t); + default: + ET_DCHECK_MSG( + false, + "im2row.per_tensor not implemented for dtype %s", + torch::executor::toString(dtype)); + } +#undef typed_im2row_per_tensor + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_im2row.h b/backends/cadence/generic/operators/op_im2row.h new file mode 100644 index 00000000000..0eed36c9139 --- /dev/null +++ b/backends/cadence/generic/operators/op_im2row.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& im2row_out( + __ET_UNUSED ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + ::executorch::aten::IntArrayRef kernel_size, + ::executorch::aten::IntArrayRef dilation, + ::executorch::aten::IntArrayRef padding, + ::executorch::aten::IntArrayRef stride, + const ::executorch::aten::Tensor& in_zero_point, + bool channel_last, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& im2row_per_tensor_out( + __ET_UNUSED ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + ::executorch::aten::IntArrayRef kernel_size, + ::executorch::aten::IntArrayRef dilation, + ::executorch::aten::IntArrayRef padding, + ::executorch::aten::IntArrayRef stride, + int64_t in_zero_point, + bool channel_last, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_linalg_svd.cpp b/backends/cadence/generic/operators/op_linalg_svd.cpp new file mode 100644 index 00000000000..e84a8914f35 --- /dev/null +++ b/backends/cadence/generic/operators/op_linalg_svd.cpp @@ -0,0 +1,365 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +#include +#include +#include + +const float EPSILON = 1e-10; +#ifndef M_PI +#define M_PI 3.14159265358979323846 +#endif + +namespace impl { +namespace generic { +namespace native { +namespace { + +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::Error; +using ::executorch::runtime::KernelRuntimeContext; + +// A simple 3x3 matrix struct. +struct Matrix3x3 { + float m[3][3]; +}; + +// Returns the 3x3 identity matrix. +Matrix3x3 identityMatrix() { + Matrix3x3 I{}; + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { + I.m[i][j] = (i == j) ? 1.0 : 0.0; + } + } + return I; +} + +// Transposes matrix A. +Matrix3x3 transpose(const Matrix3x3& A) { + Matrix3x3 At{}; + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { + At.m[i][j] = A.m[j][i]; + } + } + return At; +} + +// Multiplies matrices A and B. +Matrix3x3 multiply(const Matrix3x3& A, const Matrix3x3& B) { + Matrix3x3 C{}; + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { + C.m[i][j] = 0.0; + for (int k = 0; k < 3; k++) { + C.m[i][j] += A.m[i][k] * B.m[k][j]; + } + } + } + return C; +} + +// Jacobi method to compute the eigen-decomposition of a symmetric 3x3 matrix A. +// It outputs the eigenvalues (in 'diag') and the eigenvectors as columns in V. +void jacobiEigenDecomposition(const Matrix3x3& A, float diag[3], Matrix3x3& V) { + Matrix3x3 D = A; // Make a copy; D will be transformed into a diagonal matrix. + V = identityMatrix(); + + // Iterate until convergence (or max iterations) + for (int iter = 0; iter < 100; iter++) { + // Find the largest off-diagonal element in D. 
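+    // For a symmetric 3x3 matrix the only off-diagonal candidates are
+    // (0,1), (0,2) and (1,2). Each iteration rotates in the (p,q) plane to
+    // zero the largest of them, and the loop exits once the largest remaining
+    // off-diagonal magnitude drops below EPSILON.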
+ int p = 0, q = 1; + float maxOff = std::fabs(D.m[0][1]); + if (std::fabs(D.m[0][2]) > maxOff) { + maxOff = std::fabs(D.m[0][2]); + p = 0; + q = 2; + } + if (std::fabs(D.m[1][2]) > maxOff) { + maxOff = std::fabs(D.m[1][2]); + p = 1; + q = 2; + } + + if (maxOff < EPSILON) { + break; + } + + // Compute the Jacobi rotation angle. + float theta = 0.0; + if (std::fabs(D.m[p][p] - D.m[q][q]) < EPSILON) { + theta = M_PI / 4.0; + } else { + theta = 0.5 * std::atan2(2 * D.m[p][q], D.m[q][q] - D.m[p][p]); + } + + float c = std::cos(theta); + float s = std::sin(theta); + + // Update the diagonal elements. + float D_pp = c * c * D.m[p][p] - 2 * s * c * D.m[p][q] + s * s * D.m[q][q]; + float D_qq = s * s * D.m[p][p] + 2 * s * c * D.m[p][q] + c * c * D.m[q][q]; + D.m[p][p] = D_pp; + D.m[q][q] = D_qq; + D.m[p][q] = D.m[q][p] = 0.0; + + // Update the remaining elements. + for (int j = 0; j < 3; j++) { + if (j != p && j != q) { + float D_pj = c * D.m[p][j] - s * D.m[q][j]; + float D_qj = s * D.m[p][j] + c * D.m[q][j]; + D.m[p][j] = D.m[j][p] = D_pj; + D.m[q][j] = D.m[j][q] = D_qj; + } + } + + // Update the eigenvector matrix V. + for (int i = 0; i < 3; i++) { + float V_ip = c * V.m[i][p] - s * V.m[i][q]; + float V_iq = s * V.m[i][p] + c * V.m[i][q]; + V.m[i][p] = V_ip; + V.m[i][q] = V_iq; + } + } + + diag[0] = D.m[0][0]; + diag[1] = D.m[1][1]; + diag[2] = D.m[2][2]; +} + +// Sorts the eigenvalues (and the corresponding eigenvectors in V) in descending +// order. +void sortEigenDecomposition(float eigenvalues[3], Matrix3x3& V) { + int indices[3] = {0, 1, 2}; + std::sort(indices, indices + 3, [&](int a, int b) { + return eigenvalues[a] > eigenvalues[b]; + }); + + float sortedEigenvalues[3]; + Matrix3x3 sortedV{}; + for (int i = 0; i < 3; i++) { + sortedEigenvalues[i] = eigenvalues[indices[i]]; + for (int j = 0; j < 3; j++) { + sortedV.m[j][i] = V.m[j][indices[i]]; + } + } + for (int i = 0; i < 3; i++) { + eigenvalues[i] = sortedEigenvalues[i]; + for (int j = 0; j < 3; j++) { + V.m[j][i] = sortedV.m[j][i]; + } + } +} + +// Computes the cross product of two 3D vectors. +void crossProduct(const float a[3], const float b[3], float result[3]) { + result[0] = a[1] * b[2] - a[2] * b[1]; + result[1] = a[2] * b[0] - a[0] * b[2]; + result[2] = a[0] * b[1] - a[1] * b[0]; +} + +// Normalizes a 3D vector. +void normalize(float v[3]) { + float norm = std::sqrt(v[0] * v[0] + v[1] * v[1] + v[2] * v[2]); + if (norm > EPSILON) { + v[0] /= norm; + v[1] /= norm; + v[2] /= norm; + } +} + +// Computes the singular value decomposition of A such that A = U * S * Vt. +// U and Vt are orthogonal matrices and S is a diagonal matrix with singular +// values. +std::tuple svd(const Matrix3x3& A) { + // Compute A^T * A (which is symmetric). + Matrix3x3 At = transpose(A); + Matrix3x3 ATA = multiply(At, A); + + // Compute the eigen-decomposition of ATA. + float eigenvalues[3]; + Matrix3x3 V{}; + jacobiEigenDecomposition(ATA, eigenvalues, V); + sortEigenDecomposition(eigenvalues, V); + + // The singular values are the square roots of the eigenvalues. + float sigma[3]; + for (int i = 0; i < 3; i++) { + sigma[i] = (eigenvalues[i] > 0.0) ? std::sqrt(eigenvalues[i]) : 0.0; + } + + // Compute U = A * V * S^{-1}. + Matrix3x3 U{}; + for (int i = 0; i < 3; i++) { + float av[3] = {0, 0, 0}; + // Multiply A by the i-th eigenvector (column of V). 
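+    // Since A * v_i = sigma_i * u_i, the i-th column of U is A * v_i scaled
+    // by 1 / sigma_i (done just below); when sigma_i is (near) zero, the
+    // column is instead completed to an orthonormal basis via cross products.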
+ for (int r = 0; r < 3; r++) { + for (int c = 0; c < 3; c++) { + av[r] += A.m[r][c] * V.m[c][i]; + } + } + if (sigma[i] > EPSILON) { + for (int r = 0; r < 3; r++) { + U.m[r][i] = av[r] / sigma[i]; + } + } else { + // If sigma[i] is nearly zero, choose a vector orthogonal to the previous + // ones. + float vec[3] = {0, 0, 0}; + if (i == 1) { + float u0[3] = {U.m[0][0], U.m[1][0], U.m[2][0]}; + float tmp[3] = {1, 0, 0}; + float dot = u0[0] * tmp[0] + u0[1] * tmp[1] + u0[2] * tmp[2]; + if (std::fabs(dot) > 0.9) { + tmp[0] = 0; + tmp[1] = 1; + tmp[2] = 0; + } + crossProduct(u0, tmp, vec); + } else if (i == 2) { + float u0[3] = {U.m[0][0], U.m[1][0], U.m[2][0]}; + float u1[3] = {U.m[0][1], U.m[1][1], U.m[2][1]}; + crossProduct(u0, u1, vec); + } + normalize(vec); + for (int r = 0; r < 3; r++) { + U.m[r][i] = vec[r]; + } + } + } + + // Construct the diagonal S matrix. + Matrix3x3 S{}; + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { + S.m[i][j] = (i == j) ? sigma[i] : 0.0; + } + } + + // Vt is the transpose of V. + Matrix3x3 Vt = transpose(V); + + return std::make_tuple(U, S, Vt); +} +} // namespace + +std::tuple linalg_svd_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& A, + bool full_matrices, + bool compute_uv, + ::executorch::aten::optional<::executorch::aten::string_view> driver, + Tensor& U, + Tensor& S, + Tensor& Vh) { + std::tuple ret_val(U, S, Vh); + + ET_KERNEL_CHECK_MSG( + ctx, + A.scalar_type() == ScalarType::Float, + InvalidArgument, + ret_val, + "input.dtype(): %s must be %s", + ::torch::executor::toString(A.scalar_type()), + ::torch::executor::toString(ScalarType::Float)); + + ET_KERNEL_CHECK_MSG( + ctx, A.numel() > 0, InvalidArgument, ret_val, "input.size() must be > 0"); + + ET_KERNEL_CHECK_MSG( + ctx, + A.numel() % 9 == 0, + InvalidArgument, + ret_val, + "SVD of only 3x3 matrix is supported! 
Expected the input to have (batch_size x 9) number of elements, but received %d elements instead", + int(A.numel())); + + int batch_size = A.numel() / 9; + + ET_KERNEL_CHECK_MSG( + ctx, + U.numel() / 9 == batch_size, + InvalidArgument, + ret_val, + "Output tensor U must have the same batch_size as input: %d, but got: %d instead", + batch_size, + int(U.numel() / 9)); + + ET_KERNEL_CHECK_MSG( + ctx, + S.numel() / 3 == batch_size, + InvalidArgument, + ret_val, + "Output tensor S must have the same batch_size as input: %d, but got: %d instead", + batch_size, + int(S.numel() / 3)); + + ET_KERNEL_CHECK_MSG( + ctx, + Vh.numel() / 9 == batch_size, + InvalidArgument, + ret_val, + "Output tensor Vh must have the same batch_size as input: %d, but got: %d instead", + batch_size, + int(Vh.numel() / 9)); + + const float* A_data = A.const_data_ptr(); + float* U_data = U.mutable_data_ptr(); + float* S_data = S.mutable_data_ptr(); + float* Vh_data = Vh.mutable_data_ptr(); + + for (int i = 0; i < batch_size; i++) { + int offset = i * 9; + Matrix3x3 A_mat = {{ + {A_data[offset + 0], A_data[offset + 1], A_data[offset + 2]}, + {A_data[offset + 3], A_data[offset + 4], A_data[offset + 5]}, + {A_data[offset + 6], A_data[offset + 7], A_data[offset + 8]}, + }}; + + Matrix3x3 U_mat{}, S_mat{}, Vh_mat{}; + std::tie(U_mat, S_mat, Vh_mat) = svd(A_mat); + + U_data[offset + 0] = U_mat.m[0][0]; + U_data[offset + 1] = U_mat.m[0][1]; + U_data[offset + 2] = U_mat.m[0][2]; + U_data[offset + 3] = U_mat.m[1][0]; + U_data[offset + 4] = U_mat.m[1][1]; + U_data[offset + 5] = U_mat.m[1][2]; + U_data[offset + 6] = U_mat.m[2][0]; + U_data[offset + 7] = U_mat.m[2][1]; + U_data[offset + 8] = U_mat.m[2][2]; + + S_data[offset + 0] = S_mat.m[0][0]; + S_data[offset + 1] = S_mat.m[1][1]; + S_data[offset + 2] = S_mat.m[2][2]; + + Vh_data[offset + 0] = Vh_mat.m[0][0]; + Vh_data[offset + 1] = Vh_mat.m[0][1]; + Vh_data[offset + 2] = Vh_mat.m[0][2]; + Vh_data[offset + 3] = Vh_mat.m[1][0]; + Vh_data[offset + 4] = Vh_mat.m[1][1]; + Vh_data[offset + 5] = Vh_mat.m[1][2]; + Vh_data[offset + 6] = Vh_mat.m[2][0]; + Vh_data[offset + 7] = Vh_mat.m[2][1]; + Vh_data[offset + 8] = Vh_mat.m[2][2]; + } + + return ret_val; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_linalg_svd.h b/backends/cadence/generic/operators/op_linalg_svd.h new file mode 100644 index 00000000000..7635276c4f5 --- /dev/null +++ b/backends/cadence/generic/operators/op_linalg_svd.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +std::tuple< + ::executorch::aten::Tensor&, + ::executorch::aten::Tensor&, + ::executorch::aten::Tensor&> +linalg_svd_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& A, + bool full_matrices, + bool compute_uv, + ::executorch::aten::optional<::executorch::aten::string_view> driver, + ::executorch::aten::Tensor& U, + ::executorch::aten::Tensor& S, + ::executorch::aten::Tensor& Vh); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantize_per_tensor.cpp b/backends/cadence/generic/operators/op_quantize_per_tensor.cpp new file mode 100644 index 00000000000..51a0a1d0f9d --- /dev/null +++ b/backends/cadence/generic/operators/op_quantize_per_tensor.cpp @@ -0,0 +1,145 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using executorch::runtime::KernelRuntimeContext; +using ::impl::generic::kernels::quantize; + +// Quantize the input tensor (PT2 version). Note that quant_ are not +// used in any computation. +Tensor& quantize_per_tensor_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + const float* input_data = input.const_data_ptr(); + size_t numel = out.numel(); + + if (out.scalar_type() == ScalarType::Byte) { + uint8_t* out_data = out.mutable_data_ptr(); + quantize(out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Char) { + int8_t* out_data = out.mutable_data_ptr(); + quantize(out_data, input_data, 1. / scale, zero_point, numel); + } else if ( + out.scalar_type() == ScalarType::Bits16 || + out.scalar_type() == ScalarType::UInt16) { + uint16_t* out_data = out.mutable_data_ptr(); + quantize(out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Short) { + int16_t* out_data = out.mutable_data_ptr(); + quantize(out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Int) { + int32_t* out_data = out.mutable_data_ptr(); + quantize(out_data, input_data, 1. / scale, zero_point, numel); + } else { + ET_CHECK_MSG( + false, + "Unhandled input dtype %hhd", + static_cast(out.scalar_type())); + } + return out; +} + +Tensor& quantize_per_tensor_asym8s_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + const float* input_data = input.const_data_ptr(); + size_t numel = out.numel(); + int8_t* out_data = out.mutable_data_ptr(); + quantize(out_data, input_data, 1. 
/ scale, zero_point, numel); + return out; +} + +Tensor& quantize_per_tensor_asym8u_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + const float* input_data = input.const_data_ptr(); + size_t numel = out.numel(); + uint8_t* out_data = out.mutable_data_ptr(); + quantize(out_data, input_data, 1. / scale, zero_point, numel); + return out; +} + +Tensor& quantize_per_tensor_asym16s_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + const float* input_data = input.const_data_ptr(); + size_t numel = out.numel(); + int16_t* out_data = out.mutable_data_ptr(); + quantize(out_data, input_data, 1. / scale, zero_point, numel); + return out; +} + +Tensor& quantize_per_tensor_asym16u_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + const float* input_data = input.const_data_ptr(); + size_t numel = out.numel(); + uint16_t* out_data = out.mutable_data_ptr(); + quantize(out_data, input_data, 1. / scale, zero_point, numel); + return out; +} + +Tensor& quantize_per_tensor_asym32s_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + const float* input_data = input.const_data_ptr(); + size_t numel = out.numel(); + int32_t* out_data = out.mutable_data_ptr(); + quantize(out_data, input_data, 1. / scale, zero_point, numel); + return out; +} + +}; // namespace native +}; // namespace generic +}; // namespace impl diff --git a/backends/cadence/generic/operators/op_quantize_per_tensor.h b/backends/cadence/generic/operators/op_quantize_per_tensor.h new file mode 100644 index 00000000000..c1e826a7cf9 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantize_per_tensor.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& quantize_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& context, + const ::executorch::aten::Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ::executorch::aten::ScalarType dtype, + ::executorch::aten::Tensor& out); +} +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_add.cpp b/backends/cadence/generic/operators/op_quantized_add.cpp new file mode 100644 index 00000000000..393a553a253 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_add.cpp @@ -0,0 +1,216 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include +#include +#include + +namespace impl::generic::native { + +using ::executorch::aten::Scalar; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; +using ::impl::generic::kernels::dequantize; +using ::impl::generic::kernels::quantize; + +DECLARE_POINTWISE_TENSOR_QUANTIZED_BINARY_OP(quantized_add_, +); + +#define DECLARE_POINTWISE_SCALAR_QUANTIZED_BINARY_OP(BINARY_FUNC_NAME, OP) \ + template \ + void BINARY_FUNC_NAME( \ + const Tensor& X, \ + float X_scale, \ + int32_t X_zero_point, \ + const float Y, \ + float out_scale, \ + int32_t out_zero_point, \ + Tensor& out) { \ + const T* __restrict__ X_data = X.const_data_ptr(); \ + T* __restrict__ out_data = out.mutable_data_ptr(); \ + float inv_out_scale = 1.0f / out_scale; \ + for (size_t i = 0, e = X.numel(); i < e; ++i) { \ + float x = dequantize(X_data[i], X_scale, X_zero_point); \ + float z = x OP Y; \ + out_data[i] = quantize(z, inv_out_scale, out_zero_point); \ + } \ + } + +DECLARE_POINTWISE_SCALAR_QUANTIZED_BINARY_OP(quantized_add_Scalar_, +); + +Tensor& quantized_add_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& X, + const Tensor& X_scale_t, + const Tensor& X_zero_point_t, + const Tensor& Y, + const Tensor& Y_scale_t, + const Tensor& Y_zero_point_t, + double out_scale, + int64_t out_zero_point, + Tensor& out) { + float X_scale = X_scale_t.const_data_ptr()[0]; + int32_t X_zero_point = X_zero_point_t.const_data_ptr()[0]; + float Y_scale = Y_scale_t.const_data_ptr()[0]; + int32_t Y_zero_point = Y_zero_point_t.const_data_ptr()[0]; + +#define typed_quantized_add(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_add_( \ + X, \ + X_scale, \ + X_zero_point, \ + Y, \ + Y_scale, \ + Y_zero_point, \ + static_cast(out_scale), \ + static_cast(out_zero_point), \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_add); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_add + + return out; +} + +Tensor& quantized_add_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& X, + double X_scale, + int64_t X_zero_point, + const Tensor& Y, + double Y_scale, + int64_t Y_zero_point, + double out_scale, + int64_t out_zero_point, + Tensor& out) { +#define typed_quantized_add(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_add_( \ + X, \ + static_cast(X_scale), \ + static_cast(X_zero_point), \ + Y, \ + static_cast(Y_scale), \ + static_cast(Y_zero_point), \ + static_cast(out_scale), \ + static_cast(out_zero_point), \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_add); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_add + return out; +} + +Tensor& quantized_add_asym8sxasym8s_asym8s_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& X, + double X_scale, + int64_t X_zero_point, + const Tensor& Y, + double Y_scale, + int64_t Y_zero_point, + double out_scale, + int64_t out_zero_point, + Tensor& out) { + quantized_add_( + X, + static_cast(X_scale), + static_cast(X_zero_point), + Y, + static_cast(Y_scale), + static_cast(Y_zero_point), + static_cast(out_scale), + static_cast(out_zero_point), + out); + return out; +} + +Tensor& quantized_add_asym8uxasym8u_asym8u_per_tensor_out( 
+ ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& X, + double X_scale, + int64_t X_zero_point, + const Tensor& Y, + double Y_scale, + int64_t Y_zero_point, + double out_scale, + int64_t out_zero_point, + Tensor& out) { + quantized_add_( + X, + static_cast(X_scale), + static_cast(X_zero_point), + Y, + static_cast(Y_scale), + static_cast(Y_zero_point), + static_cast(out_scale), + static_cast(out_zero_point), + out); + return out; +} + +Tensor& quantized_add_Scalar_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& X, + const Tensor& X_scale_t, + const Tensor& X_zero_point_t, + const Scalar& Y_scalar, + double out_scale, + int64_t out_zero_point, + Tensor& out) { + float X_scale = X_scale_t.const_data_ptr()[0]; + int32_t X_zero_point = X_zero_point_t.const_data_ptr()[0]; + float Y = static_cast( + ::torch::executor::native::utils::scalar_to(Y_scalar)); +#define typed_quantized_add_Scalar(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_add_Scalar_( \ + X, \ + X_scale, \ + X_zero_point, \ + Y, \ + static_cast(out_scale), \ + static_cast(out_zero_point), \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_add_Scalar) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_add_Scalar + return out; +} + +#undef DECLARE_POINTWISE_SCALAR_QUANTIZED_BINARY_OP + +} // namespace impl::generic::native diff --git a/backends/cadence/generic/operators/op_quantized_add.h b/backends/cadence/generic/operators/op_quantized_add.h new file mode 100644 index 00000000000..3f87ddcf5b9 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_add.h @@ -0,0 +1,78 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& quantized_add_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& X, + const ::executorch::aten::Tensor& X_scale, + const ::executorch::aten::Tensor& X_zero_point, + const ::executorch::aten::Tensor& Y, + const ::executorch::aten::Tensor& Y_scale, + const ::executorch::aten::Tensor& Y_zero_point, + double out_scale, + int64_t out_zero_point, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& quantized_add_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& X, + double X_scale, + int64_t X_zero_point, + const ::executorch::aten::Tensor& Y, + double Y_scale, + int64_t Y_zero_point, + double out_scale, + int64_t out_zero_point, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& quantized_add_Scalar_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& X, + const ::executorch::aten::Tensor& X_scale, + const ::executorch::aten::Tensor& X_zero_point, + const ::executorch::aten::Scalar& Y, + double out_scale, + int64_t out_zero_point, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& quantized_add_asym8sxasym8s_asym8s_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& X, + double X_scale, + int64_t X_zero_point, + const ::executorch::aten::Tensor& Y, + double Y_scale, + int64_t Y_zero_point, + double out_scale, + int64_t out_zero_point, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& quantized_add_asym8uxasym8u_asym8u_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& X, + double X_scale, + int64_t X_zero_point, + const ::executorch::aten::Tensor& Y, + double Y_scale, + int64_t Y_zero_point, + double out_scale, + int64_t out_zero_point, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_conv1d.cpp b/backends/cadence/generic/operators/op_quantized_conv1d.cpp new file mode 100644 index 00000000000..6ae3a6613fb --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_conv1d.cpp @@ -0,0 +1,514 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include + +namespace impl { +namespace generic { +namespace native { + +namespace { +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; +using ::impl::generic::kernels::quantize; + +// This implements a generic 1d conv kernel that operates on raw pointers. +// The quantized version handles both quantized convolutions for 1D inputs. 
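// For reference, the output width is expected to follow the standard
// convolution arithmetic (this kernel trusts the sizes of `out` and does not
// recompute them): ow = (w + 2*p - d*(ww - 1) - 1) / s + 1.
// Worked example: w = 8, ww = 3, s = 1, p = 0, d = 1 gives ow = 6.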
+// The input is of shape [n x c x w] +// The weight is of shape [oc x wc x ww], where wc == c +// The output is of shape [n x oc x ow] +// The bias is of shape [oc] + +template < + typename IT = float, + typename WT = IT, + typename BT = IT, + typename OT = IT, + bool quantized = false> +__attribute__((noinline)) void conv1d_ncl_core_generic( + // All the arrays + const IT* __restrict__ p_in, + const WT* __restrict__ p_weight, + const BT* __restrict__ p_bias, + OT* __restrict__ p_out, + // The array sizes + int32_t n, + int32_t c, + int32_t w, + int32_t oc, + int32_t wc, + int32_t ww, + int32_t ow, + // Stride + int16_t s, + // Padding + int16_t p, + // Dilation + int16_t d, + // Group for depthwise conv + int16_t groups, + // Optional args that are only relevant for quantized convolution + // input zero point + IT in_zero_point = 0, + // weight zero point + int32_t weight_zero_point = 0, + float bias_scale = 1, + float out_scale = 1, + OT out_zero_point = 0) { + float inv_out_scale = 1. / out_scale; + bool zero_pad_unit_dilation = d == 1 && p == 0; + + // Compute the number of in and out channels per group + const int ocpg = oc / groups; + const int icpg = c / groups; + + // Iterate over all the output batches (i.e., n) + for (int _n = 0; _n < n; ++_n) { + const IT* in_batch = p_in + _n * c * w; + OT* out_batch = p_out + _n * oc * ow; + // Compute separable convolution for each group + for (int _g = 0; _g < groups; ++_g) { + // Identify the input and output channels involved in the computation + // of this group + int sic = _g * icpg; + int soc = _g * ocpg; + // Populate all the output channels in the group + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + OT* out_plane = out_batch + _oc * ow; + const WT* weight_batch = p_weight + _oc * wc * ww; + // We compute one output channel at a time. The computation can be + // thought of as a stencil computation: we iterate over an input of size + // icpg x w, with a stencil of size icpg x ww, to compute an + // output channel of size 1 x ow. + for (int _w = 0, _ow = 0; _ow < ow; _w += s, ++_ow) { + float acc = p_bias[_oc]; + // Below is the stencil computation that performs the hadamard + // product+accumulation of each input channel (contributing to the + // output channel being computed) with the corresponding weight + // channel. + // If the padding is 0, and dilation is 1, then we can remove the + // unnecessary checks, and simplify the code so that it can be + // vectorized by Tensilica compiler. + if (zero_pad_unit_dilation) { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const IT* in_plane = in_batch + _ic * w; + const WT* weight_plane = weight_batch + (_ic - sic) * ww; + for (int _ww = 0; _ww < ww; ++_ww) { + int ioff = _w + _ww; + int woff = _ww; + float lhs = in_plane[ioff] - in_zero_point; + float rhs = + weight_plane[woff] - (quantized ? weight_zero_point : 0); + acc += lhs * rhs; + } + } + } else { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const IT* in_plane = in_batch + _ic * w; + const WT* weight_plane = weight_batch + (_ic - sic) * ww; + for (int _ww = 0; _ww < ww; ++_ww) { + if (((_w + d * _ww - p) >= 0) && ((_w + d * _ww - p) < w)) { + int ioff = _w + d * _ww - p; + int woff = _ww; + float lhs = in_plane[ioff] - in_zero_point; + float rhs = + weight_plane[woff] - (quantized ? 
weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } + if (quantized) { + float val = bias_scale * acc; + out_plane[_ow] = quantize(val, inv_out_scale, out_zero_point); + } else { + out_plane[_ow] = acc; + } + } + } + } + } +} + +template < + typename IT = float, + typename WT = IT, + typename BT = IT, + typename OT = IT, + bool quantized = false> +__attribute__((noinline)) void conv1d_nlc_core_generic( + // All the arrays + const IT* __restrict__ p_in, + const WT* __restrict__ p_weight, + const BT* __restrict__ p_bias, + OT* __restrict__ p_out, + // The array sizes + int32_t n, + int32_t w, + int32_t c, + int32_t oc, + int32_t ww, + int32_t wc, + int32_t ow, + // Stride + int16_t s, + // Padding + int16_t p, + // Dilation + int16_t d, + // Group for depthwise conv + int16_t groups, + // Optional args that are only relevant for quantized convolution + // input zero point + IT in_zero_point = 0, + // weight zero point + int32_t weight_zero_point = 0, + float bias_scale = 1, + float out_scale = 1, + OT out_zero_point = 0) { + float inv_out_scale = 1. / out_scale; + bool zero_pad_unit_dilation = d == 1 && p == 0; + + // Compute the number of in and out channels per group + const int ocpg = oc / groups; + const int icpg = c / groups; + + // Iterate over all the output batches (i.e., n) + for (int _n = 0; _n < n; ++_n) { + const IT* in_batch = p_in + _n * w * c; + OT* out_batch = p_out + _n * ow * oc; + for (int _w = 0, _ow = 0; _ow < ow; _w += s, ++_ow) { + OT* out_line = out_batch + _ow * oc; + // Compute separable convolution for each group + for (int _g = 0; _g < groups; ++_g) { + // Identify the input and output channels involved in the computation + // of this group + int sic = _g * icpg; + int soc = _g * ocpg; + // Populate all the output channels in the group + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + const WT* weight_batch = p_weight + _oc * ww * wc; + // We compute one output channel at a time. The computation can be + // thought of as a stencil computation: we iterate over an input of + // size w x icpg, with a stencil of size ww x icpg, to + // compute an output channel of size ow x 1. + float acc = p_bias[_oc]; + // Below is the stencil computation that performs the hadamard + // product+accumulation of each input channel (contributing to + // the output channel being computed) with the corresponding + // weight channel. If the padding is 0, and dilation is 1, then + // we can remove the unnecessary checks, and simplify the code + // so that it can be vectorized by Tensilica compiler. + if (zero_pad_unit_dilation) { + for (int _ww = 0; _ww < ww; ++_ww) { + const IT* in_line = in_batch + (_w + _ww) * c; + const WT* weight_line = weight_batch + _ww * wc; + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + float lhs = in_line[_ic] - in_zero_point; + float rhs = weight_line[_ic - sic] - + (quantized ? weight_zero_point : 0); + acc += lhs * rhs; + } + } + } else { + for (int _ww = 0; _ww < ww; ++_ww) { + if (((_w + d * _ww - p) >= 0) && ((_w + d * _ww - p) < w)) { + const IT* in_line = in_batch + (_w + d * _ww - p) * c; + const WT* weight_line = weight_batch + _ww * wc; + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + float lhs = in_line[_ic] - in_zero_point; + float rhs = weight_line[_ic - sic] - + (quantized ? 
weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } + if (quantized) { + float val = bias_scale * acc; + out_line[_oc] = quantize(val, inv_out_scale, out_zero_point); + } else { + out_line[_oc] = acc; + } + } + } + } + } +} + +void quantized_conv1d_ncl( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + // input = [n, c, w] + const int n = input.size(0); + const int c = input.size(1); + const int w = input.size(2); + // weight = [oc, wc, ww] + const int oc = weight.size(0); + const int wc = weight.size(1); + const int ww = weight.size(2); + // output = [n, oc, ow] + const int ow = out.size(2); + +#define typed_quantized_conv1d_ncl(ctype, dtype) \ + case ScalarType::dtype: { \ + conv1d_ncl_core_generic( \ + input.const_data_ptr(), \ + weight.const_data_ptr(), \ + bias.const_data_ptr(), \ + out.mutable_data_ptr(), \ + n, \ + c, \ + w, \ + oc, \ + wc, \ + ww, \ + ow, \ + stride[0], \ + padding[0], \ + dilation[0], \ + groups, \ + in_zero_point, \ + weight_zero_point, \ + bias_scale, \ + output_scale, \ + (ctype)output_zero_point); \ + break; \ + } + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_conv1d_ncl); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_conv1d_ncl +} + +void quantized_conv1d_nlc( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + // input = [n, w, c] + const int n = input.size(0); + const int w = input.size(1); + const int c = input.size(2); + // weight = [oc, ww, wc] + const int oc = weight.size(0); + const int ww = weight.size(1); + const int wc = weight.size(2); + // output = [n, ow, oc] + const int ow = out.size(1); + +#define typed_quantized_conv1d_nlc(ctype, dtype) \ + case ScalarType::dtype: { \ + conv1d_nlc_core_generic( \ + input.const_data_ptr(), \ + weight.const_data_ptr(), \ + bias.const_data_ptr(), \ + out.mutable_data_ptr(), \ + n, \ + w, \ + c, \ + oc, \ + ww, \ + wc, \ + ow, \ + stride[0], \ + padding[0], \ + dilation[0], \ + groups, \ + in_zero_point, \ + weight_zero_point, \ + bias_scale, \ + output_scale, \ + (ctype)output_zero_point); \ + break; \ + } + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_conv1d_nlc); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_conv1d_nlc +} + +} // namespace + +Tensor& quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv1d_ncl( + input, + weight, + bias, + stride, + padding, + dilation, + static_cast(groups), + 
static_cast(in_zero_point), + static_cast(weight_zero_point), + static_cast(bias_scale), + static_cast(output_scale), + static_cast(output_zero_point), + out); + return out; +} + +Tensor& quantized_conv1d_ncl_asym8uxsym8u_asym8u_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv1d_ncl( + input, + weight, + bias, + stride, + padding, + dilation, + static_cast(groups), + static_cast(in_zero_point), + static_cast(weight_zero_point), + static_cast(bias_scale), + static_cast(output_scale), + static_cast(output_zero_point), + out); + return out; +} + +Tensor& quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv1d_nlc( + input, + weight, + bias, + stride, + padding, + dilation, + static_cast(groups), + static_cast(in_zero_point), + static_cast(weight_zero_point), + static_cast(bias_scale), + static_cast(output_scale), + static_cast(output_zero_point), + out); + return out; +} + +Tensor& quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv1d_nlc( + input, + weight, + bias, + stride, + padding, + dilation, + static_cast(groups), + static_cast(in_zero_point), + static_cast(weight_zero_point), + static_cast(bias_scale), + static_cast(output_scale), + static_cast(output_zero_point), + out); + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_conv1d.h b/backends/cadence/generic/operators/op_quantized_conv1d.h new file mode 100644 index 00000000000..5cb79ab09fa --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_conv1d.h @@ -0,0 +1,96 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +executorch::aten::Tensor& +quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor_out( + executorch::runtime::KernelRuntimeContext& ctx, + const executorch::aten::Tensor& input, + const executorch::aten::Tensor& weight, + const executorch::aten::Tensor& bias, + executorch::aten::IntArrayRef stride, + executorch::aten::IntArrayRef padding, + executorch::aten::IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + executorch::aten::Tensor& out); + +executorch::aten::Tensor& +quantized_conv1d_ncl_asym8uxsym8u_asym8u_per_tensor_out( + executorch::runtime::KernelRuntimeContext& ctx, + const executorch::aten::Tensor& input, + const executorch::aten::Tensor& weight, + const executorch::aten::Tensor& bias, + executorch::aten::IntArrayRef stride, + executorch::aten::IntArrayRef padding, + executorch::aten::IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + executorch::aten::Tensor& out); + +executorch::aten::Tensor& +quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor_out( + executorch::runtime::KernelRuntimeContext& ctx, + const executorch::aten::Tensor& input, + const executorch::aten::Tensor& weight, + const executorch::aten::Tensor& bias, + executorch::aten::IntArrayRef stride, + executorch::aten::IntArrayRef padding, + executorch::aten::IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + executorch::aten::Tensor& out); + +executorch::aten::Tensor& +quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor_out( + executorch::runtime::KernelRuntimeContext& ctx, + const executorch::aten::Tensor& input, + const executorch::aten::Tensor& weight, + const executorch::aten::Tensor& bias, + executorch::aten::IntArrayRef stride, + executorch::aten::IntArrayRef padding, + executorch::aten::IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_conv2d.cpp b/backends/cadence/generic/operators/op_quantized_conv2d.cpp new file mode 100644 index 00000000000..ca701957866 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_conv2d.cpp @@ -0,0 +1,1051 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; +using ::impl::generic::kernels::quantize; + +/* This implements a generic 2d conv kernel that operates on raw pointers. + * The quantized version handles quantized convolutions for 2D inputs. 
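 * For reference, the output spatial sizes are expected to follow the standard
 * convolution arithmetic (the kernel trusts the sizes of `out`):
 *   oh = (h + 2*p0 - d0*(wh - 1) - 1) / s0 + 1
 *   ow = (w + 2*p1 - d1*(ww - 1) - 1) / s1 + 1
 * Worked example: h = w = 5, wh = ww = 3, s0 = s1 = 1, p0 = p1 = 1,
 * d0 = d1 = 1 gives oh = ow = 5.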
+ * The input is of shape [n x c x h x w] + * The weight is of shape [oc x wc x wh x ww], where wc == c + * The output is of shape [n x oc x oh x ow] + * The bias is of shape [oc] + */ +template < + typename IT = float, + typename WT = IT, + typename BT = IT, + typename OT = IT, + bool quantized = false> +__attribute__((noinline)) void conv2d_nchw_core_generic( + // All the arrays + const IT* __restrict__ p_in, + const WT* __restrict__ p_weight, + const BT* __restrict__ p_bias, + OT* __restrict__ p_out, + // The array sizes + int32_t n, + int32_t c, + int32_t h, + int32_t w, + int32_t oc, + int32_t wc, + int32_t wh, + int32_t ww, + int32_t oh, + int32_t ow, + // Stride + int16_t s0, + int16_t s1, + // Padding + int16_t p0, + int16_t p1, + // Dilation + int16_t d0, + int16_t d1, + // Group for depthwise conv + int16_t groups, + // Optional args that are only relevant for quantized convolution + // input zero point + IT in_zero_point = 0, + // weight zero point + int32_t weight_zero_point = 0, + float bias_scale = 1, + float out_scale = 1, + OT out_zero_point = 0) { + const float inv_out_scale = 1.f / out_scale; + bool zero_pad_unit_dilation = d0 == 1 && d1 == 1 && p0 == 0 && p1 == 0; + + // Compute the number of in and out channels per group + const int ocpg = oc / groups; + const int icpg = c / groups; + + // Iterate over all the output batches (i.e., n) + for (int _n = 0; _n < n; ++_n) { + const IT* in_batch = p_in + _n * c * h * w; + OT* out_batch = p_out + _n * oc * oh * ow; + // Compute separable convolution for each group + for (int _g = 0; _g < groups; ++_g) { + // Identify the input and output channels involved in the computation + // of this group + int sic = _g * icpg; + int soc = _g * ocpg; + // Populate all the output channels in the group + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + OT* out_plane = out_batch + _oc * oh * ow; + const WT* weight_batch = p_weight + _oc * wc * wh * ww; + // We compute one output channel at a time. The computation can be + // thought of as a stencil computation: we iterate over an input of size + // icpg x h x w, with a stencil of size icpg x wh x ww, to compute an + // output channel of size 1 x oh x ow. + for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { + for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { + float acc = p_bias[_oc]; + // Below is the stencil computation that performs the hadamard + // product+accumulation of each input channel (contributing to the + // output channel being computed) with the corresponding weight + // channel. + // If the padding is 0, and dilation is 1, then we can remove the + // unnecessary checks, and simplify the code so that it can be + // vectorized by Tensilica compiler. + if (zero_pad_unit_dilation) { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const IT* in_plane = in_batch + _ic * h * w; + const WT* weight_plane = weight_batch + (_ic - sic) * wh * ww; + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + int ioff = (_h + _wh) * w + (_w + _ww); + int woff = _wh * ww + _ww; + float lhs = in_plane[ioff] - in_zero_point; + float rhs = weight_plane[woff] - + (quantized ? 
weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } else { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const IT* in_plane = in_batch + _ic * h * w; + const WT* weight_plane = weight_batch + (_ic - sic) * wh * ww; + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + if (((_h + d0 * _wh - p0) >= 0) && + ((_h + d0 * _wh - p0) < h) && + ((_w + d1 * _ww - p1) >= 0) && + ((_w + d1 * _ww - p1) < w)) { + int ioff = + (_h + d0 * _wh - p0) * w + (_w + d1 * _ww - p1); + int woff = _wh * ww + _ww; + float lhs = in_plane[ioff] - in_zero_point; + float rhs = weight_plane[woff] - + (quantized ? weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } + } + if (quantized) { + float val = bias_scale * acc; + out_plane[_oh * ow + _ow] = + quantize(val, inv_out_scale, out_zero_point); + } else { + out_plane[_oh * ow + _ow] = acc; + } + } + } + } + } + } +} + +template < + typename IT = float, + typename WT = IT, + typename BT = IT, + typename OT = IT, + bool quantized = false> +__attribute__((noinline)) void conv2d_nhwc_core_generic( + // All the arrays + const IT* __restrict__ p_in, + const WT* __restrict__ p_weight, + const BT* __restrict__ p_bias, + OT* __restrict__ p_out, + // The array sizes + int32_t n, + int32_t h, + int32_t w, + int32_t c, + int32_t oc, + int32_t wh, + int32_t ww, + int32_t wc, + int32_t oh, + int32_t ow, + // Stride + int16_t s0, + int16_t s1, + // Padding + int16_t p0, + int16_t p1, + // Dilation + int16_t d0, + int16_t d1, + // Group for depthwise conv + int16_t groups, + // Optional args that are only relevant for quantized convolution + // input zero point + IT in_zero_point = 0, + // weight zero point + int32_t weight_zero_point = 0, + float bias_scale = 1, + float out_scale = 1, + OT out_zero_point = 0) { + float inv_out_scale = 1.f / out_scale; + bool zero_pad_unit_dilation = d0 == 1 && d1 == 1 && p0 == 0 && p1 == 0; + + // Compute the number of in and out channels per group + const int ocpg = oc / groups; + const int icpg = c / groups; + + // Iterate over all the output batches (i.e., n) + for (int _n = 0; _n < n; ++_n) { + const IT* in_batch = p_in + _n * h * w * c; + OT* out_batch = p_out + _n * oh * ow * oc; + for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { + for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { + OT* out_line = out_batch + (_oh * ow + _ow) * oc; + // Compute separable convolution for each group + for (int _g = 0; _g < groups; ++_g) { + // Identify the input and output channels involved in the computation + // of this group + int sic = _g * icpg; + int soc = _g * ocpg; + // Populate all the output channels in the group + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + const WT* weight_batch = p_weight + _oc * wh * ww * wc; + // We compute one output channel at a time. The computation can be + // thought of as a stencil computation: we iterate over an input of + // size h x w x icpg, with a stencil of size wh x ww x icpg, to + // compute an output channel of size oh x ow x 1. + float acc = p_bias[_oc]; + // Below is the stencil computation that performs the hadamard + // product+accumulation of each input channel (contributing to + // the output channel being computed) with the corresponding + // weight channel. If the padding is 0, and dilation is 1, then + // we can remove the unnecessary checks, and simplify the code + // so that it can be vectorized by Tensilica compiler. 
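                  // Once the stencil accumulation below completes, the
                  // quantized path rescales and requantizes the float
                  // accumulator:
                  //   val = bias_scale * acc;
                  //   out_line[_oc] = quantize(val, inv_out_scale, out_zero_point);
                  // Assuming quantize() rounds val * inv_out_scale and adds
                  // the zero point (as its call sites here suggest), e.g.
                  // val = 3.2, out_scale = 0.5, out_zero_point = 10 yields
                  // round(3.2 / 0.5) + 10 = 16, clamped to the output type's
                  // range.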
+ if (zero_pad_unit_dilation) { + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + const IT* in_line = + in_batch + (_h + _wh) * w * c + (_w + _ww) * c; + const WT* weight_line = + weight_batch + _wh * ww * wc + _ww * wc; + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + float lhs = in_line[_ic] - in_zero_point; + float rhs = weight_line[_ic - sic] - + (quantized ? weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } else { + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + if (((_h + d0 * _wh - p0) >= 0) && + ((_h + d0 * _wh - p0) < h) && + ((_w + d1 * _ww - p1) >= 0) && + ((_w + d1 * _ww - p1) < w)) { + const IT* in_line = in_batch + + (_h + d0 * _wh - p0) * w * c + (_w + d1 * _ww - p1) * c; + const WT* weight_line = + weight_batch + _wh * ww * wc + _ww * wc; + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + float lhs = in_line[_ic] - in_zero_point; + float rhs = weight_line[_ic - sic] - + (quantized ? weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } + } + if (quantized) { + float val = bias_scale * acc; + out_line[_oc] = quantize(val, inv_out_scale, out_zero_point); + } else { + out_line[_oc] = acc; + } + } + } + } + } + } +} + +void quantized_conv2d_nchw( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + bool conv1d = input.dim() == 3; + // input = [n, c, h, w] + const int n = input.size(0); + const int c = input.size(1); + const int h = conv1d ? 1 : input.size(2); + const int w = conv1d ? input.size(2) : input.size(3); + // weight = [oc, wc, wh, ww] + const int oc = weight.size(0); + const int wc = weight.size(1); + const int wh = conv1d ? 1 : weight.size(2); + const int ww = conv1d ? weight.size(2) : weight.size(3); + // output = [n, oc, oh, ow] + const int oh = conv1d ? 1 : out.size(2); + const int ow = conv1d ? 
out.size(2) : out.size(3); + + ET_CHECK_MSG( + weight_zero_point >= -128 && weight_zero_point <= 127, + "weight_zero_point %" PRId32 + " must be in range [-128, 127] for int8 cast", + weight_zero_point); + + // Handle W8A16 heterogeneous type (int16_t activations, int8_t weights) + if (out.scalar_type() == ScalarType::Short && + input.scalar_type() == ScalarType::Short && + weight.scalar_type() == ScalarType::Char) { + conv2d_nchw_core_generic( + input.const_data_ptr(), + weight.const_data_ptr(), + bias.const_data_ptr(), + out.mutable_data_ptr(), + n, + c, + h, + w, + oc, + wc, + wh, + ww, + oh, + ow, + stride[0], + stride[1], + padding[0], + padding[1], + dilation[0], + dilation[1], + groups, + static_cast(in_zero_point), + static_cast(weight_zero_point), + bias_scale, + output_scale, + static_cast(output_zero_point)); + return; + } + +#define typed_quantized_conv2d_nchw(ctype, dtype) \ + case ScalarType::dtype: { \ + conv2d_nchw_core_generic( \ + input.const_data_ptr(), \ + weight.const_data_ptr(), \ + bias.const_data_ptr(), \ + out.mutable_data_ptr(), \ + n, \ + c, \ + h, \ + w, \ + oc, \ + wc, \ + wh, \ + ww, \ + oh, \ + ow, \ + stride[0], \ + stride[1], \ + padding[0], \ + padding[1], \ + dilation[0], \ + dilation[1], \ + groups, \ + in_zero_point, \ + weight_zero_point, \ + bias_scale, \ + output_scale, \ + (ctype)output_zero_point); \ + break; \ + } + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(typed_quantized_conv2d_nchw); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_conv2d_nchw +} + +void quantized_conv2d_nhwc( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + // input = [n, h, w, c] + const int n = input.size(0); + const int h = input.size(1); + const int w = input.size(2); + const int c = input.size(3); + // weight = [oc, wh, ww, wc] + const int oc = weight.size(0); + const int wh = weight.size(1); + const int ww = weight.size(2); + const int wc = weight.size(3); + // output = [n, oh, ow, oc] + const int oh = out.size(1); + const int ow = out.size(2); + + // Handle W8A16 heterogeneous type (int16_t activations, int8_t weights) + if (out.scalar_type() == ScalarType::Short && + input.scalar_type() == ScalarType::Short && + weight.scalar_type() == ScalarType::Char) { + conv2d_nhwc_core_generic( + input.const_data_ptr(), + weight.const_data_ptr(), + bias.const_data_ptr(), + out.mutable_data_ptr(), + n, + h, + w, + c, + oc, + wh, + ww, + wc, + oh, + ow, + stride[0], + stride[1], + padding[0], + padding[1], + dilation[0], + dilation[1], + groups, + static_cast(in_zero_point), + static_cast(weight_zero_point), + bias_scale, + output_scale, + static_cast(output_zero_point)); + return; + } + +#define typed_quantized_conv2d_nhwc(ctype, dtype) \ + case ScalarType::dtype: { \ + conv2d_nhwc_core_generic( \ + input.const_data_ptr(), \ + weight.const_data_ptr(), \ + bias.const_data_ptr(), \ + out.mutable_data_ptr(), \ + n, \ + h, \ + w, \ + c, \ + oc, \ + wh, \ + ww, \ + wc, \ + oh, \ + ow, \ + stride[0], \ + stride[1], \ + padding[0], \ + padding[1], \ + dilation[0], \ + dilation[1], \ + groups, \ + in_zero_point, \ + weight_zero_point, \ + bias_scale, \ + output_scale, \ + (ctype)output_zero_point); \ + break; 
\ + } + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(typed_quantized_conv2d_nhwc); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_conv2d_nhwc +} + +Tensor& quantized_conv2d_nchw_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED const Tensor& out_multiplier, + ET_UNUSED const Tensor& out_shift, + Tensor& out) { + const float bias_scale_float = bias_scale.const_data_ptr()[0]; + const int32_t weight_zero_point_int = + weight_zero_point.const_data_ptr()[0]; + quantized_conv2d_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point_int, + bias_scale_float, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nhwc_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED const Tensor& out_multiplier, + ET_UNUSED const Tensor& out_shift, + Tensor& out) { + const float bias_scale_float = bias_scale.const_data_ptr()[0]; + const int32_t weight_zero_point_int = + weight_zero_point.const_data_ptr()[0]; + quantized_conv2d_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point_int, + bias_scale_float, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nchw_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nchw_asym8sxsym8s_asym8s_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nchw_asym8uxsym8u_asym8u_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t 
groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nhwc_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t 
output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + 
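  // Thin dtype-specialized wrapper: like the other *_per_tensor_out variants
  // in this file, it forwards the scalar quantization parameters (implicitly
  // narrowed to the int32_t/float parameters of the shared helper) to
  // quantized_conv2d_nhwc(), which dispatches on out.scalar_type().
  // out_multiplier and out_shift are accepted only for interface
  // compatibility and are unused here (ET_UNUSED).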
quantized_conv2d_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_conv2d.h b/backends/cadence/generic/operators/op_quantized_conv2d.h new file mode 100644 index 00000000000..07678b0600c --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_conv2d.h @@ -0,0 +1,326 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +// Quantized Conv2D operators - NCHW layout +::executorch::aten::Tensor& quantized_conv2d_nchw_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& bias_scale, + double output_scale, + int64_t output_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + Tensor& out); + +::executorch::aten::Tensor& quantized_conv2d_nchw_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +::executorch::aten::Tensor& +quantized_conv2d_nchw_asym8sxsym8s_asym8s_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +::executorch::aten::Tensor& +quantized_conv2d_nchw_asym8uxsym8u_asym8u_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t 
out_shift, + Tensor& out); + +::executorch::aten::Tensor& +quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +::executorch::aten::Tensor& +quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +::executorch::aten::Tensor& +quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +::executorch::aten::Tensor& +quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +// Quantized Conv2D operators - NHWC layout +::executorch::aten::Tensor& quantized_conv2d_nhwc_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& bias_scale, + double output_scale, + int64_t output_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + Tensor& out); + +::executorch::aten::Tensor& quantized_conv2d_nhwc_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +::executorch::aten::Tensor& +quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +::executorch::aten::Tensor& +quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + 
IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +::executorch::aten::Tensor& +quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +::executorch::aten::Tensor& +quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +::executorch::aten::Tensor& +quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +::executorch::aten::Tensor& +quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_embedding_byte.cpp b/backends/cadence/generic/operators/op_quantized_embedding_byte.cpp new file mode 100644 index 00000000000..55ca67648ca --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_embedding_byte.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::optional; +using ::executorch::aten::Scalar; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +#define ET_FORALL_CADENCE_QUANTIZED_TYPES(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) + +Tensor& quantized_embedding_byte_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& weight, + const Tensor& weight_scales, + const optional& weight_zero_points, + const Tensor& indices, + ET_UNUSED bool pruned_weights, + Tensor& out) { + size_t embedding_dim = weight.size(1); + + size_t num_groups = 1; + if (weight_scales.dim() == 2) { + num_groups = weight_scales.size(1); + } + + float* out_data = out.mutable_data_ptr(); + const int64_t* indices_ptr = indices.const_data_ptr(); + + const float* scales = weight_scales.const_data_ptr(); + + ScalarType dtype = weight.scalar_type(); + +#define typed_quantized_embedding_byte(ctype, dtype) \ + case ScalarType::dtype: { \ + ctype zp = 0; \ + if (weight_zero_points.has_value()) { \ + zp = weight_zero_points \ + ->const_data_ptr()[index * num_groups + group]; \ + } \ + const size_t output_group_start_offset = \ + embedding_dim * index + group * embedding_group_size; \ + const ctype* w_group = \ + weight.const_data_ptr() + output_group_start_offset; \ + for (size_t j = 0; j < embedding_group_size; ++j) { \ + float val = ((float)w_group[j] - zp) * scale; \ + *out_data++ = val; \ + } \ + break; \ + } + + size_t embedding_group_size = embedding_dim / num_groups; + for (size_t i = 0, e = indices.numel(); i < e; i++) { + int64_t index = indices_ptr[i]; + for (size_t group = 0; group < num_groups; group++) { + float scale = scales[index * num_groups + group]; + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_embedding_byte) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + } + } + +#undef typed_quantized_embedding_byte + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_embedding_byte.h b/backends/cadence/generic/operators/op_quantized_embedding_byte.h new file mode 100644 index 00000000000..a46bebe09df --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_embedding_byte.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& quantized_embedding_byte_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& weight_scales, + const ::executorch::aten::optional<::executorch::aten::Tensor>& + weight_zero_points, + const ::executorch::aten::Tensor& indices, + bool pruned_weights, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_fully_connected.cpp b/backends/cadence/generic/operators/op_quantized_fully_connected.cpp new file mode 100644 index 00000000000..55e29cb7f52 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_fully_connected.cpp @@ -0,0 +1,178 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::optional; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +Tensor& quantized_fully_connected_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& in, + const Tensor& weight, + const Tensor& bias, + int64_t in_zero_point, + const Tensor& weight_zero_point_t, + const Tensor& out_multiplier, + const Tensor& out_shift, + int64_t out_zero_point, + ET_UNUSED const optional& offset, + Tensor& out) { +#define typed_quantized_linear(ctype, dtype) \ + case ScalarType::dtype: { \ + ::impl::generic::quantized::quantized_linear_( \ + in, \ + weight, \ + bias, \ + in_zero_point, \ + weight_zero_point_t, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(typed_quantized_linear); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_linear + return out; +} + +Tensor& quantized_fully_connected_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& in, + const Tensor& weight, + const Tensor& bias, + int64_t in_zero_point, + int64_t weight_zero_point, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + ET_UNUSED const optional& offset, + Tensor& out) { +#define typed_quantized_linear(ctype, dtype) \ + case ScalarType::dtype: { \ + ::impl::generic::quantized::quantized_linear_per_tensor_( \ + in, \ + weight, \ + bias, \ + in_zero_point, \ + weight_zero_point, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(typed_quantized_linear); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_linear + return out; +} + +Tensor& quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& in, + const Tensor& weight, + const Tensor& bias, + int64_t in_zero_point, + int64_t weight_zero_point, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + ET_UNUSED const optional& offset, + Tensor& out) 
{ +#define typed_quantized_linear(ctype, dtype) \ + case ScalarType::dtype: { \ + ::impl::generic::quantized::quantized_linear_per_tensor_( \ + in, \ + weight, \ + bias, \ + in_zero_point, \ + weight_zero_point, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(typed_quantized_linear); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_linear + return out; +} + +Tensor& quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& in, + const Tensor& weight, + const Tensor& bias, + int64_t in_zero_point, + int64_t weight_zero_point, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + ET_UNUSED const optional& offset, + Tensor& out) { +#define typed_quantized_linear(ctype, dtype) \ + case ScalarType::dtype: { \ + ::impl::generic::quantized::quantized_linear_per_tensor_( \ + in, \ + weight, \ + bias, \ + in_zero_point, \ + weight_zero_point, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(typed_quantized_linear); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_linear + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_fully_connected.h b/backends/cadence/generic/operators/op_quantized_fully_connected.h new file mode 100644 index 00000000000..a7510fba95f --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_fully_connected.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& quantized_fully_connected_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& in, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + int64_t in_zero_point, + const ::executorch::aten::Tensor& weight_zero_point_t, + const ::executorch::aten::Tensor& out_multiplier, + const ::executorch::aten::Tensor& out_shift, + int64_t out_zero_point, + const ::executorch::aten::optional<::executorch::aten::Tensor>& offset, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& quantized_fully_connected_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& in, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + int64_t in_zero_point, + int64_t weight_zero_point, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + const ::executorch::aten::optional<::executorch::aten::Tensor>& offset, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& +quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& in, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + int64_t in_zero_point, + int64_t weight_zero_point, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + const ::executorch::aten::optional<::executorch::aten::Tensor>& offset, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& +quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& in, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + int64_t in_zero_point, + int64_t weight_zero_point, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + const ::executorch::aten::optional<::executorch::aten::Tensor>& offset, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_layer_norm.cpp b/backends/cadence/generic/operators/op_quantized_layer_norm.cpp new file mode 100644 index 00000000000..e34ed342d22 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_layer_norm.cpp @@ -0,0 +1,206 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::optional; +using ::executorch::aten::Scalar; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::getLeadingDims; +using ::executorch::runtime::KernelRuntimeContext; +using ::impl::generic::kernels::dequantize; +using ::impl::generic::kernels::quantize; + +// Compute quantized layer_norm. The current implementation assumes that the +// input is per-tensor quantized. 
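Aside for reviewers: the per-tensor layer-norm kernel that follows computes the mean and variance from integer sums (with a zero-point correction) and only dequantizes per element when applying the fp32 gamma/beta. As a cross-check of the arithmetic, here is a minimal standalone sketch that does the same computation directly in fp32 for one int8 row; the function name, explicit clamp, and use of std::lround are illustrative assumptions, not part of this diff.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<int8_t> layer_norm_i8_ref(
    const std::vector<int8_t>& x_q,   // one row, int8, per-tensor quantized
    float in_scale, int32_t in_zp,
    const std::vector<float>& gamma,  // fp32 weight, length == x_q.size()
    const std::vector<float>& beta,   // fp32 bias, length == x_q.size()
    float eps, float out_scale, int32_t out_zp) {
  const size_t d = x_q.size();
  std::vector<float> x(d);
  float mean = 0.0f;
  for (size_t j = 0; j < d; ++j) {
    x[j] = (static_cast<int32_t>(x_q[j]) - in_zp) * in_scale;  // dequantize
    mean += x[j];
  }
  mean /= static_cast<float>(d);
  float var = 0.0f;
  for (size_t j = 0; j < d; ++j) {
    var += (x[j] - mean) * (x[j] - mean);
  }
  var /= static_cast<float>(d);
  const float inv_std = 1.0f / std::sqrt(var + eps);
  std::vector<int8_t> y(d);
  for (size_t j = 0; j < d; ++j) {
    const float v = (x[j] - mean) * inv_std * gamma[j] + beta[j];
    const int32_t q = static_cast<int32_t>(std::lround(v / out_scale)) + out_zp;
    y[j] = static_cast<int8_t>(std::clamp<int32_t>(q, -128, 127));  // requantize
  }
  return y;
}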
+template +void quantized_layer_norm_per_tensor_( + const Tensor& input, + double input_scale, + int64_t input_zero_point, + const Tensor& weight, + const Tensor& bias, + double eps, + double output_scale, + int64_t output_zero_point, + Tensor& out) { + // Get the raw pointers to input, output, weight, and bias + const T* __restrict__ in_data = input.const_data_ptr(); + T* __restrict__ out_data = out.mutable_data_ptr(); + const float* __restrict__ weight_data = weight.const_data_ptr(); + const float* __restrict__ bias_data = bias.const_data_ptr(); + + float output_inv_scale = 1.0f / output_scale; + + size_t last_dim = input.size(input.dim() - 1); + size_t leading_dims = + ::executorch::runtime::getLeadingDims(input, input.dim() - 1); + + // Visualize the input tensor as a set of 1d vectors, and compute the + // layer_norm for each vector. + for (size_t i = 0; i < leading_dims; ++i) { + const T* x = in_data + i * last_dim; + T* y = out_data + i * last_dim; + + // compute sum and squared sum. The fp32 sum can be approximated as: + // (X_1 - in_zero_point) * in_scale + (X_2 - in_zero_point) * in_scale + ... + // (X_N - in_zero_point) * in_scale. + int32_t sum = 0; + int32_t sq_sum = last_dim * input_zero_point * input_zero_point; + for (size_t j = 0; j < last_dim; ++j) { + int32_t val = x[j]; + sum += val; + sq_sum += val * val; + } + sq_sum -= (2 * sum * input_zero_point); + sum -= (last_dim * input_zero_point); + + float mean = (input_scale * sum) / last_dim; + float variance = + (sq_sum * input_scale * input_scale) / last_dim - mean * mean; + float inv_std = 1.0f / std::sqrt(variance + eps); + + // y = (x - mean) / std * kGamma + kBeta + for (size_t j = 0; j < last_dim; ++j) { + // y[j] = (x[j] - mean) / std * kGamma + kBeta; + // Since X is quantized, we dequantize it, compute fp32 result, and + // quantize the result to an int8/uint8 value. + float val = dequantize(x[j], input_scale, input_zero_point); + + val = (val - mean) * inv_std * weight_data[j] + bias_data[j]; + y[j] = quantize(val, output_inv_scale, output_zero_point); + } + } +} + +// Compute quantized layer_norm. The current implementation assumes that the +// input is per-tensor quantized. +template +void quantized_layer_norm_( + const Tensor& input, + const Tensor& in_scale, + const Tensor& in_zero_point, + const Tensor& weight, + const Tensor& bias, + double eps, + double output_scale, + int64_t output_zero_point, + Tensor& out) { + // Extract the zero point and scale for input tensor. 
+ float input_scale = in_scale.const_data_ptr()[0]; + int64_t input_zero_point = in_zero_point.const_data_ptr()[0]; + quantized_layer_norm_per_tensor_( + input, + input_scale, + input_zero_point, + weight, + bias, + eps, + output_scale, + output_zero_point, + out); +} + +Tensor& quantized_layer_norm_out( + ET_UNUSED ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& in_scale, + const Tensor& in_zero_point, + ET_UNUSED const IntArrayRef normalized_shape, + const Tensor& weight, + const Tensor& bias, + double eps, + double output_scale, + int64_t output_zero_point, + Tensor& out) { +#define typed_quantized_layer_norm(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_layer_norm_( \ + input, \ + in_scale, \ + in_zero_point, \ + weight, \ + bias, \ + eps, \ + output_scale, \ + output_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_layer_norm) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_layer_norm + return out; +} + +Tensor& quantized_layer_norm_per_tensor_out( + ET_UNUSED ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + double in_scale, + int64_t in_zero_point, + ET_UNUSED const IntArrayRef normalized_shape, + const Tensor& weight, + const Tensor& bias, + double eps, + double output_scale, + int64_t output_zero_point, + Tensor& out) { +#define typed_quantized_layer_norm(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_layer_norm_per_tensor_( \ + input, \ + in_scale, \ + in_zero_point, \ + weight, \ + bias, \ + eps, \ + output_scale, \ + output_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_layer_norm) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_layer_norm + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_layer_norm.h b/backends/cadence/generic/operators/op_quantized_layer_norm.h new file mode 100644 index 00000000000..ed642559248 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_layer_norm.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& quantized_layer_norm_out( + __ET_UNUSED ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + const ::executorch::aten::Tensor& in_scale, + const ::executorch::aten::Tensor& in_zero_point, + __ET_UNUSED const ::executorch::aten::IntArrayRef normalized_shape, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + double eps, + double output_scale, + int64_t output_zero_point, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& quantized_layer_norm_per_tensor_out( + __ET_UNUSED ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + double in_scale, + int64_t in_zero_point, + __ET_UNUSED const ::executorch::aten::IntArrayRef normalized_shape, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + double eps, + double output_scale, + int64_t output_zero_point, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_linear.cpp b/backends/cadence/generic/operators/op_quantized_linear.cpp new file mode 100644 index 00000000000..87f990a855b --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_linear.cpp @@ -0,0 +1,220 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::optional; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; +using ::executorch::runtime::toString; + +Tensor& quantized_linear_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& src, + const Tensor& weight, + const Tensor& bias, + int64_t src_zero_point, + const Tensor& weight_zero_point_t, + const Tensor& out_multiplier, + const Tensor& out_shift, + int64_t out_zero_point, + ET_UNUSED const optional& offset, + Tensor& out) { +#define typed_quantized_linear(ctype, dtype) \ + case ScalarType::dtype: { \ + ::impl::generic::quantized::quantized_linear_( \ + src, \ + weight, \ + bias, \ + src_zero_point, \ + weight_zero_point_t, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + + // Handle W8A16 heterogeneous type (int16_t activations, int8_t weights) + if (dtype == ScalarType::Short && src.scalar_type() == ScalarType::Short && + weight.scalar_type() == ScalarType::Char) { + ::impl::generic::quantized::quantized_linear_( + src, + weight, + bias, + src_zero_point, + weight_zero_point_t, + out_multiplier, + out_shift, + out_zero_point, + out); + return out; + } + + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(typed_quantized_linear); + default: + ET_DCHECK_MSG(false, "Unhandled dtype %s", toString(dtype)); + } +#undef typed_quantized_linear + return out; +} + +Tensor& quantized_linear_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& src, + const Tensor& weight, + const Tensor& bias, + const int64_t src_zero_point, + const int64_t weight_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + 
const int64_t out_zero_point, + ET_UNUSED const optional& offset, + Tensor& out) { +#define typed_quantized_linear_per_tensor(ctype, dtype) \ + case ScalarType::dtype: { \ + ::impl::generic::quantized::quantized_linear_per_tensor_( \ + src, \ + weight, \ + bias, \ + src_zero_point, \ + weight_zero_point, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + // Handle W8A16 heterogeneous type (int16_t activations, int8_t weights) + if (dtype == ScalarType::Short && src.scalar_type() == ScalarType::Short && + weight.scalar_type() == ScalarType::Char) { + ::impl::generic::quantized::quantized_linear_per_tensor_( + src, + weight, + bias, + src_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + out); + return out; + } + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16( + typed_quantized_linear_per_tensor); + default: + ET_KERNEL_CHECK_MSG( + ctx, + false, + InvalidArgument, + out, + "Unhandled dtype %s", + toString(dtype)); + } +#undef typed_quantized_linear_per_tensor + return out; +} + +Tensor& quantized_linear_asym8sxasym8s_asym8s_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& src, + const Tensor& weight, + const Tensor& bias, + const int64_t src_zero_point, + const int64_t weight_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + const int64_t out_zero_point, + ET_UNUSED const std::optional& offset, + Tensor& out) { +#define typed_quantized_linear_per_tensor(ctype, dtype) \ + case ScalarType::dtype: { \ + ::impl::generic::quantized::quantized_linear_per_tensor_( \ + src, \ + weight, \ + bias, \ + src_zero_point, \ + weight_zero_point, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16( + typed_quantized_linear_per_tensor); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", executorch::runtime::toString(dtype)); + } +#undef typed_quantized_linear_per_tensor + return out; +} + +Tensor& quantized_linear_asym8uxasym8u_asym8u_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& src, + const Tensor& weight, + const Tensor& bias, + const int64_t src_zero_point, + const int64_t weight_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + const int64_t out_zero_point, + ET_UNUSED const std::optional& offset, + Tensor& out) { +#define typed_quantized_linear_per_tensor(ctype, dtype) \ + case ScalarType::dtype: { \ + ::impl::generic::quantized::quantized_linear_per_tensor_( \ + src, \ + weight, \ + bias, \ + src_zero_point, \ + weight_zero_point, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16( + typed_quantized_linear_per_tensor); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", executorch::runtime::toString(dtype)); + } +#undef typed_quantized_linear_per_tensor + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_linear.h b/backends/cadence/generic/operators/op_quantized_linear.h new file mode 100644 index 00000000000..b5396cb9701 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_linear.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& quantized_linear_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& src, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + int64_t src_zero_point, + const ::executorch::aten::Tensor& weight_zero_point_t, + const ::executorch::aten::Tensor& out_multiplier, + const ::executorch::aten::Tensor& out_shift, + int64_t out_zero_point, + const ::executorch::aten::optional<::executorch::aten::Tensor>& offset, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& quantized_linear_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& src, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + const int64_t src_zero_point, + const int64_t weight_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + const int64_t out_zero_point, + const ::executorch::aten::optional<::executorch::aten::Tensor>& offset, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& +quantized_linear_asym8sxasym8s_asym8s_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& src, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + const int64_t src_zero_point, + const int64_t weight_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + const int64_t out_zero_point, + const std::optional<::executorch::aten::Tensor>& offset, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& +quantized_linear_asym8uxasym8u_asym8u_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& src, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + const int64_t src_zero_point, + const int64_t weight_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + const int64_t out_zero_point, + const std::optional<::executorch::aten::Tensor>& offset, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_matmul.cpp b/backends/cadence/generic/operators/op_quantized_matmul.cpp new file mode 100644 index 00000000000..e3fb0f00fdc --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_matmul.cpp @@ -0,0 +1,286 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +#include +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::optional; +using ::executorch::aten::Scalar; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; +using ::impl::generic::kernels::quantize; + +// The quantized matmul. The quantized matmul accumulates in a wider register, +// whose type is TA. 
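For context on the multiplier/shift convention: the qmatmul kernel below recovers an effective fp32 output scale from (Z_multiplier, Z_shift) as -Z_multiplier / 2^31 * 2^Z_shift, multiplies the widened accumulator by it, and adds the output zero point. The sketch here shows the same accumulation for the non-transposed int8 case with an int32 accumulator, taking the effective scale directly as a parameter; the function name and explicit clamping are illustrative assumptions, not the kernel's API.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Z[m x p] = requantize((X[m x n] - x_zp) * (Y[n x p] - y_zp)), row-major int8.
void qmatmul_i8_ref(
    int8_t* Z, float z_scale, int32_t z_zp,  // z_scale: effective output scale
    const int8_t* X, int32_t x_zp,
    const int8_t* Y, int32_t y_zp,
    size_t m, size_t n, size_t p) {
  for (size_t i = 0; i < m; ++i) {
    for (size_t j = 0; j < p; ++j) {
      int32_t acc = 0;  // widened accumulator (TA in the kernel's template)
      for (size_t k = 0; k < n; ++k) {
        acc += (static_cast<int32_t>(X[i * n + k]) - x_zp) *
               (static_cast<int32_t>(Y[k * p + j]) - y_zp);
      }
      const int32_t q = static_cast<int32_t>(std::lround(acc * z_scale)) + z_zp;
      Z[i * p + j] = static_cast<int8_t>(std::clamp<int32_t>(q, -128, 127));
    }
  }
}

The transposed branch of the kernel only changes the right-hand indexing to Y[j * n + k]; the accumulation and requantization are otherwise identical.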
+template < + typename TZ, + typename TA = float, + bool transposed = false, + typename TX = TZ, + typename TY = TZ> +__attribute__((noinline)) void qmatmul( + TZ* __restrict__ Z, + int32_t Z_multiplier, + int32_t Z_shift, + int32_t Z_zero_point, + const TX* __restrict__ X, + int32_t X_zero_point, + const TY* __restrict__ y, + int32_t Y_zero_point, + size_t m, + size_t n, + size_t p) { + // Compute the Z_scale from Z_multiplier and Z_shift + const float Z_scale = -Z_multiplier * 1.0 / (1 << 31) * pow(2, Z_shift); + for (size_t i = 0; i < m; ++i) { + for (size_t j = 0; j < p; ++j) { + TA sum = 0; + for (size_t k = 0; k < n; ++k) { + if (transposed) { + sum += (X[i * n + k] - X_zero_point) * (y[j * n + k] - Y_zero_point); + } else { + sum += (X[i * n + k] - X_zero_point) * (y[k * p + j] - Y_zero_point); + } + } + Z[i * p + j] = quantize(sum, Z_scale, Z_zero_point); + } + } +} + +Tensor& quantized_matmul_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& X, + int64_t X_zero_point, + const Tensor& Y, + int64_t Y_zero_point, + ET_UNUSED const optional& bias, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + bool transposed, + Tensor& out) { + size_t batch_size = ::executorch::runtime::getLeadingDims(X, X.dim() - 2); + size_t leading_dim = X.size(X.dim() - 2); + size_t out_dim = Y.size(Y.dim() - 1 - transposed); + size_t in_dim = X.size(X.dim() - 1); + + // Handle W8A16 heterogeneous type (int16_t activations, int8_t weights) + if (out.scalar_type() == ScalarType::Short && + X.scalar_type() == ScalarType::Short && + Y.scalar_type() == ScalarType::Char) { + int16_t* __restrict__ out_data = out.mutable_data_ptr(); + const int16_t* __restrict__ X_data = X.const_data_ptr(); + const int8_t* __restrict__ Y_data = Y.const_data_ptr(); + for (size_t i = 0; i < batch_size; ++i) { + const int16_t* x = X_data + i * leading_dim * in_dim; + const int8_t* y = Y_data + i * in_dim * out_dim; + int16_t* z = out_data + i * leading_dim * out_dim; + if (transposed) { + qmatmul( + z, + static_cast(out_multiplier), + static_cast(out_shift), + static_cast(out_zero_point), + x, + static_cast(X_zero_point), + y, + static_cast(Y_zero_point), + leading_dim, + in_dim, + out_dim); + } else { + qmatmul( + z, + static_cast(out_multiplier), + static_cast(out_shift), + static_cast(out_zero_point), + x, + static_cast(X_zero_point), + y, + static_cast(Y_zero_point), + leading_dim, + in_dim, + out_dim); + } + } + return out; + } + +#define typed_quantized_matmul(ctype, dtype) \ + case ScalarType::dtype: { \ + ctype* __restrict__ out_data = out.mutable_data_ptr(); \ + const ctype* __restrict__ X_data = X.const_data_ptr(); \ + const ctype* __restrict__ Y_data = Y.const_data_ptr(); \ + for (size_t i = 0; i < batch_size; ++i) { \ + const ctype* x = X_data + i * leading_dim * in_dim; \ + const ctype* y = Y_data + i * in_dim * out_dim; \ + ctype* z = out_data + i * leading_dim * out_dim; \ + if (transposed) { \ + qmatmul( \ + z, \ + static_cast(out_multiplier), \ + static_cast(out_shift), \ + static_cast(out_zero_point), \ + x, \ + static_cast(X_zero_point), \ + y, \ + static_cast(Y_zero_point), \ + leading_dim, \ + in_dim, \ + out_dim); \ + } else { \ + qmatmul( \ + z, \ + static_cast(out_multiplier), \ + static_cast(out_shift), \ + static_cast(out_zero_point), \ + x, \ + static_cast(X_zero_point), \ + y, \ + static_cast(Y_zero_point), \ + leading_dim, \ + in_dim, \ + out_dim); \ + } \ + } \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + 
ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(typed_quantized_matmul); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_matmul + return out; +} + +template +void _typed_quantized_matmul( + const Tensor& X, + int64_t X_zero_point, + const Tensor& Y, + int64_t Y_zero_point, + ET_UNUSED const optional& bias, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + bool transposed, + Tensor& out) { + size_t batch_size = ::executorch::runtime::getLeadingDims(X, X.dim() - 2); + size_t leading_dim = X.size(X.dim() - 2); + size_t out_dim = Y.size(Y.dim() - 1 - transposed); + size_t in_dim = X.size(X.dim() - 1); + + T* __restrict__ out_data = out.mutable_data_ptr(); + const T* __restrict__ X_data = X.const_data_ptr(); + const T* __restrict__ Y_data = Y.const_data_ptr(); + for (size_t i = 0; i < batch_size; ++i) { + const T* x = X_data + i * leading_dim * in_dim; + const T* y = Y_data + i * in_dim * out_dim; + T* z = out_data + i * leading_dim * out_dim; + if (transposed) { + qmatmul( + z, + static_cast(out_multiplier), + static_cast(out_shift), + static_cast(out_zero_point), + x, + static_cast(X_zero_point), + y, + static_cast(Y_zero_point), + leading_dim, + in_dim, + out_dim); + } else { + qmatmul( + z, + static_cast(out_multiplier), + static_cast(out_shift), + static_cast(out_zero_point), + x, + static_cast(X_zero_point), + y, + static_cast(Y_zero_point), + leading_dim, + in_dim, + out_dim); + } + } +} + +Tensor& quantized_matmul_asym8sxasym8s_asym8s_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& X, + int64_t X_zero_point, + const Tensor& Y, + int64_t Y_zero_point, + const optional& bias, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + bool transposed, + Tensor& out) { + _typed_quantized_matmul( + X, + X_zero_point, + Y, + Y_zero_point, + bias, + out_multiplier, + out_shift, + out_zero_point, + transposed, + out); + return out; +} + +Tensor& quantized_matmul_asym8uxasym8u_asym8u_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& X, + int64_t X_zero_point, + const Tensor& Y, + int64_t Y_zero_point, + const optional& bias, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + bool transposed, + Tensor& out) { + _typed_quantized_matmul( + X, + X_zero_point, + Y, + Y_zero_point, + bias, + out_multiplier, + out_shift, + out_zero_point, + transposed, + out); + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_matmul.h b/backends/cadence/generic/operators/op_quantized_matmul.h new file mode 100644 index 00000000000..70775380aac --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_matmul.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::optional; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +Tensor& quantized_matmul_out( + KernelRuntimeContext& ctx, + const Tensor& X, + int64_t X_zero_point, + const Tensor& Y, + int64_t Y_zero_point, + const optional& bias, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + bool transposed, + Tensor& out); + +Tensor& quantized_matmul_asym8sxasym8s_asym8s_out( + KernelRuntimeContext& ctx, + const Tensor& X, + int64_t X_zero_point, + const Tensor& Y, + int64_t Y_zero_point, + const optional& bias, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + bool transposed, + Tensor& out); + +Tensor& quantized_matmul_asym8uxasym8u_asym8u_out( + KernelRuntimeContext& ctx, + const Tensor& X, + int64_t X_zero_point, + const Tensor& Y, + int64_t Y_zero_point, + const optional& bias, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + bool transposed, + Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_mul.cpp b/backends/cadence/generic/operators/op_quantized_mul.cpp new file mode 100644 index 00000000000..89fb2a5250d --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_mul.cpp @@ -0,0 +1,141 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +#include +#include +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::optional; +using ::executorch::aten::Scalar; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; +using ::impl::generic::kernels::dequantize; +using ::impl::generic::kernels::quantize; + +DECLARE_POINTWISE_TENSOR_QUANTIZED_BINARY_OP(quantized_mul_, *); + +Tensor& quantized_mul_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& X, + const Tensor& X_scale_t, + const Tensor& X_zero_point_t, + const Tensor& Y, + const Tensor& Y_scale_t, + const Tensor& Y_zero_point_t, + double out_scale, + int64_t out_zero_point, + Tensor& out) { + float X_scale = X_scale_t.const_data_ptr()[0]; + int32_t X_zero_point = X_zero_point_t.const_data_ptr()[0]; + float Y_scale = Y_scale_t.const_data_ptr()[0]; + int32_t Y_zero_point = Y_zero_point_t.const_data_ptr()[0]; +#define typed_quantized_mul(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_mul_( \ + X, \ + X_scale, \ + X_zero_point, \ + Y, \ + Y_scale, \ + Y_zero_point, \ + static_cast(out_scale), \ + static_cast(out_zero_point), \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_mul) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_mul + return out; +} + +// Generate kernels that perform elementwise arithmetic on a quantized tensor, +// and a scalar. 
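The macro defined below stamps out one such kernel per binary operator. Conceptually, each generated function is a per-element dequantize, fp32 op, requantize loop. A non-macro sketch of the multiply-by-scalar case for int8 (illustrative name; the explicit clamp is an assumption added for clarity):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

void quantized_mul_scalar_i8_ref(
    const int8_t* x, size_t n,
    float x_scale, int32_t x_zp,  // quantization params of x
    float y,                      // fp32 scalar operand
    float out_scale, int32_t out_zp,
    int8_t* out) {
  const float inv_out_scale = 1.0f / out_scale;
  for (size_t i = 0; i < n; ++i) {
    const float xf = (static_cast<int32_t>(x[i]) - x_zp) * x_scale;  // dequantize
    const float zf = xf * y;                                         // fp32 op
    const int32_t q =
        static_cast<int32_t>(std::lround(zf * inv_out_scale)) + out_zp;
    out[i] = static_cast<int8_t>(std::clamp<int32_t>(q, -128, 127));  // requantize
  }
}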
+#define DECLARE_POINTWISE_SCALAR_QUANTIZED_BINARY_OP(BINARY_FUNC_NAME, OP) \ + template \ + void BINARY_FUNC_NAME( \ + const Tensor& X, \ + float X_scale, \ + int32_t X_zero_point, \ + const float Y, \ + float out_scale, \ + int32_t out_zero_point, \ + Tensor& out) { \ + const T* __restrict__ X_data = X.const_data_ptr(); \ + T* __restrict__ out_data = out.mutable_data_ptr(); \ + float inv_out_scale = 1.0f / out_scale; \ + for (size_t i = 0, e = X.numel(); i < e; ++i) { \ + float x = dequantize(X_data[i], X_scale, X_zero_point); \ + float z = x OP Y; \ + out_data[i] = quantize(z, inv_out_scale, out_zero_point); \ + } \ + } + +DECLARE_POINTWISE_SCALAR_QUANTIZED_BINARY_OP(quantized_mul_Scalar_, *); + +Tensor& quantized_mul_Scalar_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& X, + const Tensor& X_scale_t, + const Tensor& X_zero_point_t, + const Scalar& Y_scalar, + double out_scale, + int64_t out_zero_point, + Tensor& out) { + float X_scale = X_scale_t.const_data_ptr()[0]; + int32_t X_zero_point = X_zero_point_t.const_data_ptr()[0]; + float Y = static_cast( + ::torch::executor::native::utils::scalar_to(Y_scalar)); + +#define typed_quantized_mul_Scalar(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_mul_Scalar_( \ + X, \ + X_scale, \ + X_zero_point, \ + Y, \ + static_cast(out_scale), \ + static_cast(out_zero_point), \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_mul_Scalar) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_mul_Scalar + return out; +} + +#undef DECLARE_POINTWISE_SCALAR_QUANTIZED_BINARY_OP + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_mul.h b/backends/cadence/generic/operators/op_quantized_mul.h new file mode 100644 index 00000000000..7ca8b2f1db0 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_mul.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& quantized_mul_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& X, + const ::executorch::aten::Tensor& X_scale_t, + const ::executorch::aten::Tensor& X_zero_point_t, + const ::executorch::aten::Tensor& Y, + const ::executorch::aten::Tensor& Y_scale_t, + const ::executorch::aten::Tensor& Y_zero_point_t, + double out_scale, + int64_t out_zero_point, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& quantized_mul_Scalar_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& X, + const ::executorch::aten::Tensor& X_scale_t, + const ::executorch::aten::Tensor& X_zero_point_t, + const ::executorch::aten::Scalar& Y_scalar, + double out_scale, + int64_t out_zero_point, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_relu.cpp b/backends/cadence/generic/operators/op_quantized_relu.cpp new file mode 100644 index 00000000000..9430951f65b --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_relu.cpp @@ -0,0 +1,184 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +#include +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::optional; +using ::executorch::aten::Scalar; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; +using ::impl::generic::kernels::quantize; + +template +void quantized_relu_per_tensor_out_( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const int64_t in_zero_point, + const int64_t out_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + Tensor& output) { + const T* __restrict__ in = input.const_data_ptr(); + T* __restrict__ out = output.mutable_data_ptr(); + + // Compute the out_scale from out_multiplier and out_shift + const float out_scale = -out_multiplier * 1.0 / (1 << 31) * pow(2, out_shift); + + for (size_t i = 0, e = input.numel(); i < e; ++i) { + const float temp = in[i] > in_zero_point ? 
(in[i] - in_zero_point) : 0; + out[i] = quantize(temp, out_scale, out_zero_point); + } +} + +Tensor& quantized_relu_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const int64_t in_zero_point, + const int64_t out_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + Tensor& output) { +#define typed_quantized_relu(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_relu_per_tensor_out_( \ + ctx, \ + input, \ + in_zero_point, \ + out_zero_point, \ + out_multiplier, \ + out_shift, \ + output); \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_relu) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_relu + return output; +} + +template +void quantized_relu_( + const Tensor& input, + const Tensor& in_zero_point, + const int64_t out_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + Tensor& output) { + T q_zero_point = in_zero_point.const_data_ptr()[0]; + const T* __restrict__ in = input.const_data_ptr(); + T* __restrict__ out = output.mutable_data_ptr(); + + const int32_t* __restrict__ out_multiplier_data = + out_multiplier.const_data_ptr(); + const int32_t* __restrict__ out_shift_data = + out_shift.const_data_ptr(); + + // Compute the out_scale from out_multiplier and out_shift + const float out_scale = + -out_multiplier_data[0] * 1.0 / (1 << 31) * pow(2, out_shift_data[0]); + + for (size_t i = 0, e = input.numel(); i < e; ++i) { + const T temp = in[i] > q_zero_point ? (in[i] - q_zero_point) : 0; + out[i] = quantize(temp, out_scale, out_zero_point); + } +} + +Tensor& quantized_relu_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& in_zero_point, + const int64_t out_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + Tensor& output) { +#define typed_quantized_relu(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_relu_( \ + input, \ + in_zero_point, \ + out_zero_point, \ + out_multiplier, \ + out_shift, \ + output); \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_relu) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_relu + return output; +} + +Tensor& quantized_relu_asym8s_asym8s_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const int64_t in_zero_point, + const int64_t out_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + Tensor& output) { + quantized_relu_per_tensor_out_( + ctx, + input, + in_zero_point, + out_zero_point, + out_multiplier, + out_shift, + output); + return output; +} + +Tensor& quantized_relu_asym8u_asym8u_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const int64_t in_zero_point, + const int64_t out_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + Tensor& output) { + quantized_relu_per_tensor_out_( + ctx, + input, + in_zero_point, + out_zero_point, + out_multiplier, + out_shift, + output); + return output; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_relu.h b/backends/cadence/generic/operators/op_quantized_relu.h new file mode 100644 index 00000000000..6241b2ddfcf --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_relu.h @@ 
-0,0 +1,56 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& quantized_relu_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + const ::executorch::aten::Tensor& in_zero_point, + const int64_t out_zero_point, + const ::executorch::aten::Tensor& out_multiplier, + const ::executorch::aten::Tensor& out_shift, + ::executorch::aten::Tensor& output); + +::executorch::aten::Tensor& quantized_relu_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + const int64_t in_zero_point, + const int64_t out_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + ::executorch::aten::Tensor& output); + +::executorch::aten::Tensor& quantized_relu_asym8s_asym8s_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + const int64_t in_zero_point, + const int64_t out_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + ::executorch::aten::Tensor& output); + +::executorch::aten::Tensor& quantized_relu_asym8u_asym8u_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + const int64_t in_zero_point, + const int64_t out_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + ::executorch::aten::Tensor& output); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_softmax.cpp b/backends/cadence/generic/operators/op_quantized_softmax.cpp new file mode 100644 index 00000000000..61037f22167 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_softmax.cpp @@ -0,0 +1,221 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace impl { +namespace generic { +namespace native { +namespace { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; +using ::impl::generic::kernels::dequantize; +using ::impl::generic::kernels::quantize; + +template +void quantized_softmax_per_tensor_( + const Tensor& input, + ET_UNUSED const Tensor& mask, + int64_t dim, + const float in_scale, + const int64_t in_zero_point, + const float out_scale, + const int64_t out_zero_point, + Tensor& out) { + const T* __restrict__ in_data = input.const_data_ptr(); + T* __restrict__ out_data = out.mutable_data_ptr(); + + float out_inv_scale = 1.0f / out_scale; + if (dim < 0) { + dim += input.dim(); + } + const int64_t input_size = input.numel(); + float* x = new float[input_size]; + + torch::executor::apply_over_dim( + [in_data, + out_data, + x, + in_scale, + in_zero_point, + out_inv_scale, + out_zero_point]( + const size_t size, const size_t stride, const size_t base) { + // Dequantize the input tensor + torch::executor::apply_unary_map_fn( + [in_scale, in_zero_point](const float val_in) { + return dequantize( + val_in, in_scale, static_cast(in_zero_point)); + }, + in_data + base, + x + base, + size, + stride); + + // Subtract max(X) from input tensor + float max_in = torch::executor::apply_unary_reduce_fn( + [](const float val_in, float val_accum) { + return std::max(val_in, val_accum); + }, + x + base, + size, + stride); + + // Compute exp(X - max(X)) + torch::executor::apply_unary_map_fn( + [max_in](const float val_in) { return std::exp(val_in - max_in); }, + x + base, + x + base, + size, + stride); + + // Compute sum(exp(X - max(X)) + float temp_sum = torch::executor::apply_unary_reduce_fn( + [](const float val_in, float val_accum) { + return val_accum + val_in; + }, + x + base, + size, + stride); + + // Compute exp(X - max(X)) / sum(exp(X - max(X)) and quantize the + float recip = 1.0 / temp_sum; + torch::executor::apply_unary_map_fn( + [recip, out_inv_scale, out_zero_point](const float val_in) { + float res = val_in * recip; + return quantize( + res, out_inv_scale, static_cast(out_zero_point)); + }, + x + base, + out_data + base, + size, + stride); + }, + input, + dim); + + delete[] x; +} + +// Compute quantized softmax. The current implementation assumes that the +// input is per-tensor quantized. +template +void quantized_softmax_( + const Tensor& input, + const Tensor& mask, + const int64_t dim, + const Tensor& in_scale, + const Tensor& in_zero_point, + const Tensor& out_scale, + const Tensor& out_zero_point, + Tensor& out) { + // Extract the zero point and scale for input tensor. 
+ float input_scale = in_scale.const_data_ptr()[0]; + int64_t input_zero_point = in_zero_point.const_data_ptr()[0]; + float output_scale = out_scale.const_data_ptr()[0]; + int64_t output_zero_point = out_zero_point.const_data_ptr()[0]; + quantized_softmax_per_tensor_( + input, + mask, + dim, + input_scale, + input_zero_point, + output_scale, + output_zero_point, + out); +} + +} // namespace + +Tensor& quantized_softmax_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& mask, + int64_t dim, + const Tensor& in_scale, + const Tensor& in_zero_point, + const Tensor& out_scale, + const Tensor& out_zero_point, + Tensor& out) { +#define typed_quantized_softmax(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_softmax_( \ + input, \ + mask, \ + dim, \ + in_scale, \ + in_zero_point, \ + out_scale, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(typed_quantized_softmax) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_softmax + return out; +} + +Tensor& quantized_softmax_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& mask, + int64_t dim, + double in_scale, + int64_t in_zero_point, + double out_scale, + int64_t out_zero_point, + Tensor& out) { +#define typed_quantized_softmax(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_softmax_per_tensor_( \ + input, \ + mask, \ + dim, \ + in_scale, \ + in_zero_point, \ + out_scale, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(typed_quantized_softmax) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_softmax + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_softmax.h b/backends/cadence/generic/operators/op_quantized_softmax.h new file mode 100644 index 00000000000..485113851a3 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_softmax.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& quantized_softmax_out( + __ET_UNUSED ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + const ::executorch::aten::Tensor& mask, + int64_t dim, + const ::executorch::aten::Tensor& in_scale, + const ::executorch::aten::Tensor& in_zero_point, + const ::executorch::aten::Tensor& out_scale, + const ::executorch::aten::Tensor& out_zero_point, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& quantized_softmax_per_tensor_out( + __ET_UNUSED ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + const ::executorch::aten::Tensor& mask, + int64_t dim, + double in_scale, + int64_t in_zero_point, + double out_scale, + int64_t out_zero_point, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_requantize.cpp b/backends/cadence/generic/operators/op_requantize.cpp new file mode 100644 index 00000000000..f846a1964a3 --- /dev/null +++ b/backends/cadence/generic/operators/op_requantize.cpp @@ -0,0 +1,268 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::optional; +using ::executorch::aten::Scalar; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; +using ::impl::generic::kernels::dequantize; +using ::impl::generic::kernels::quantize; + +// Requantize the int8_t/uint8_t input tensor to a uint8_t/int8_t out tensor. +// The scale and zero_point for requantization are in the args. 
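Both requantize entry points below reduce to the same per-element rescaling: dequantize with the input (scale, zero_point), then quantize with the reciprocal of the output scale and the output zero point. A standalone sketch for the int8 to uint8 case, with an illustrative name and explicit saturation added (the kernel's quantize helper handles the conversion itself):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

void requantize_i8_to_u8_ref(
    const int8_t* in, size_t n,
    float in_scale, int32_t in_zp,
    float out_scale, int32_t out_zp,
    uint8_t* out) {
  for (size_t i = 0; i < n; ++i) {
    const float x = (static_cast<int32_t>(in[i]) - in_zp) * in_scale;  // dequantize
    const int32_t q =
        static_cast<int32_t>(std::lround(x / out_scale)) + out_zp;     // quantize
    out[i] = static_cast<uint8_t>(std::clamp<int32_t>(q, 0, 255));
  }
}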
+Tensor& requantize_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& in_scale_t, + const Tensor& in_zero_point_t, + const Tensor& out_scale_t, + const Tensor& out_zero_point_t, + const ScalarType out_dtype, + Tensor& out) { + ET_KERNEL_CHECK_MSG( + ctx, + in_scale_t.scalar_type() == ScalarType::Float, + InvalidArgument, + out, + "In scale is not a float: %s", + torch::executor::toString(in_scale_t.scalar_type())); + float in_scale = in_scale_t.const_data_ptr()[0]; + + ET_KERNEL_CHECK_MSG( + ctx, + in_zero_point_t.scalar_type() == ScalarType::Int, + InvalidArgument, + out, + "In zero point is not an int: %s", + torch::executor::toString(in_zero_point_t.scalar_type())); + int32_t in_zero_point = in_zero_point_t.const_data_ptr()[0]; + + ET_KERNEL_CHECK_MSG( + ctx, + out_scale_t.scalar_type() == ScalarType::Float, + InvalidArgument, + out, + "Out scale is not a float: %s", + torch::executor::toString(out_scale_t.scalar_type())); + float out_scale = out_scale_t.const_data_ptr()[0]; + + ET_KERNEL_CHECK_MSG( + ctx, + out_zero_point_t.scalar_type() == ScalarType::Int, + InvalidArgument, + out, + "Out zero point is not an int: %s", + torch::executor::toString(out_zero_point_t.scalar_type())); + int32_t out_zero_point = out_zero_point_t.const_data_ptr()[0]; + + ET_KERNEL_CHECK_MSG( + ctx, + out.scalar_type() == out_dtype, + InvalidArgument, + out, + "Out tensor dtype (%s) does not match the passed in out dtype (%s)", + torch::executor::toString(out.scalar_type()), + torch::executor::toString(out_dtype)); + + const size_t numel = out.numel(); + ScalarType in_dtype = input.scalar_type(); + + // Assert that the output tensor's dtype is same as out_dtype. + ET_KERNEL_CHECK_MSG( + ctx, + out_dtype == out.scalar_type(), + InvalidArgument, + out, + "Out dtype %s does not match requant dtype %s", + torch::executor::toString(out.scalar_type()), + torch::executor::toString(out_dtype)); + +#define typed_requantize(ctype, dtype) \ + const ctype* input_data = input.const_data_ptr(); \ + dtype* out_data = out.mutable_data_ptr(); \ + for (size_t i = 0; i < numel; ++i) { \ + float dequant = dequantize(input_data[i], in_scale, in_zero_point); \ + out_data[i] = quantize(dequant, 1 / out_scale, out_zero_point); \ + }; + +#define typed_requantize_in(ctype) \ + switch (out_dtype) { \ + case ScalarType::Byte: { \ + typed_requantize(ctype, uint8_t); \ + break; \ + } \ + case ScalarType::Char: { \ + typed_requantize(ctype, int8_t); \ + break; \ + } \ + case ScalarType::UInt16: { \ + typed_requantize(ctype, uint16_t); \ + break; \ + } \ + case ScalarType::Short: { \ + typed_requantize(ctype, int16_t); \ + break; \ + } \ + default: \ + ET_KERNEL_CHECK_MSG( \ + ctx, \ + false, \ + InvalidArgument, \ + out, \ + "Unhandled output dtype %s", \ + torch::executor::toString(out_dtype)); \ + } + + switch (in_dtype) { + case ScalarType::Byte: { + typed_requantize_in(uint8_t); + break; + } + case ScalarType::Char: { + typed_requantize_in(int8_t); + break; + } + case ScalarType::UInt16: { + typed_requantize_in(uint16_t); + break; + } + case ScalarType::Short: { + typed_requantize_in(int16_t); + break; + } + default: + ET_KERNEL_CHECK_MSG( + ctx, + false, + InvalidArgument, + out, + "Unhandled input dtype %s", + torch::executor::toString(in_dtype)); + } +#undef typed_requantize_in +#undef typed_requantize + return out; +} + +// Requantize the int8_t/uint8_t input tensor to a uint8_t/int8_t out tensor. +// The scale and zero_point for requantization are in the args. 
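The tensor-argument variant above and the per-tensor variant below share the nested typed_requantize_in / typed_requantize dispatch: an outer switch on the input dtype and an inner switch on the requested output dtype, covering Byte, Char, UInt16, and Short. Expanded without macros and trimmed to two dtypes, the structure is a pair of nested switches over a templated element loop; everything in this sketch (the enum, the names, the absence of saturation) is illustrative only, not the kernel's API.

#include <cmath>
#include <cstddef>
#include <cstdint>

enum class DType { Char, Byte };  // stand-ins for ScalarType::Char / ScalarType::Byte

template <typename InT, typename OutT>
void requantize_typed(const InT* in, OutT* out, size_t n,
                      float in_scale, int32_t in_zp,
                      float out_scale, int32_t out_zp) {
  for (size_t i = 0; i < n; ++i) {
    const float x = (static_cast<int32_t>(in[i]) - in_zp) * in_scale;
    out[i] = static_cast<OutT>(std::lround(x / out_scale) + out_zp);
  }
}

template <typename InT>
void dispatch_on_out_dtype(DType out_dtype, const InT* in, void* out, size_t n,
                           float in_scale, int32_t in_zp,
                           float out_scale, int32_t out_zp) {
  switch (out_dtype) {
    case DType::Char:
      requantize_typed(in, static_cast<int8_t*>(out), n, in_scale, in_zp,
                       out_scale, out_zp);
      break;
    case DType::Byte:
      requantize_typed(in, static_cast<uint8_t*>(out), n, in_scale, in_zp,
                       out_scale, out_zp);
      break;
  }
}

void dispatch_on_in_dtype(DType in_dtype, DType out_dtype,
                          const void* in, void* out, size_t n,
                          float in_scale, int32_t in_zp,
                          float out_scale, int32_t out_zp) {
  switch (in_dtype) {
    case DType::Char:
      dispatch_on_out_dtype(out_dtype, static_cast<const int8_t*>(in), out, n,
                            in_scale, in_zp, out_scale, out_zp);
      break;
    case DType::Byte:
      dispatch_on_out_dtype(out_dtype, static_cast<const uint8_t*>(in), out, n,
                            in_scale, in_zp, out_scale, out_zp);
      break;
  }
}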
+Tensor& requantize_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + double in_scale, + int64_t in_zero_point, + double out_scale, + int64_t out_zero_point, + const ScalarType out_dtype, + Tensor& out) { + ET_KERNEL_CHECK_MSG( + ctx, + out.scalar_type() == out_dtype, + InvalidArgument, + out, + "Out tensor dtype (%s) does not match the passed in out dtype (%s)", + torch::executor::toString(out.scalar_type()), + torch::executor::toString(out_dtype)); + + const size_t numel = out.numel(); + ScalarType in_dtype = input.scalar_type(); + + // Assert that the output tensor's dtype is same as out_dtype. + ET_KERNEL_CHECK_MSG( + ctx, + out_dtype == out.scalar_type(), + InvalidArgument, + out, + "Out dtype %s does not match requant dtype %s", + torch::executor::toString(out.scalar_type()), + torch::executor::toString(out_dtype)); + +#define typed_requantize(ctype, dtype) \ + const ctype* input_data = input.const_data_ptr(); \ + dtype* out_data = out.mutable_data_ptr(); \ + for (size_t i = 0; i < numel; ++i) { \ + float dequant = dequantize(input_data[i], in_scale, in_zero_point); \ + out_data[i] = quantize(dequant, 1 / out_scale, out_zero_point); \ + }; + +#define typed_requantize_in(ctype) \ + switch (out_dtype) { \ + case ScalarType::Byte: { \ + typed_requantize(ctype, uint8_t); \ + break; \ + } \ + case ScalarType::Char: { \ + typed_requantize(ctype, int8_t); \ + break; \ + } \ + case ScalarType::UInt16: { \ + typed_requantize(ctype, uint16_t); \ + break; \ + } \ + case ScalarType::Short: { \ + typed_requantize(ctype, int16_t); \ + break; \ + } \ + default: \ + ET_KERNEL_CHECK_MSG( \ + ctx, \ + false, \ + InvalidArgument, \ + out, \ + "Unhandled output dtype %s", \ + torch::executor::toString(out_dtype)); \ + } + + switch (in_dtype) { + case ScalarType::Byte: { + typed_requantize_in(uint8_t); + break; + } + case ScalarType::Char: { + typed_requantize_in(int8_t); + break; + } + case ScalarType::UInt16: { + typed_requantize_in(uint16_t); + break; + } + case ScalarType::Short: { + typed_requantize_in(int16_t); + break; + } + default: + ET_KERNEL_CHECK_MSG( + ctx, + false, + InvalidArgument, + out, + "Unhandled input dtype %s", + torch::executor::toString(in_dtype)); + } +#undef typed_requantize_in +#undef typed_requantize + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_requantize.h b/backends/cadence/generic/operators/op_requantize.h new file mode 100644 index 00000000000..8ce4bc39bb6 --- /dev/null +++ b/backends/cadence/generic/operators/op_requantize.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& requantize_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + const ::executorch::aten::Tensor& in_scale_t, + const ::executorch::aten::Tensor& in_zero_point_t, + const ::executorch::aten::Tensor& out_scale_t, + const ::executorch::aten::Tensor& out_zero_point_t, + const ::executorch::aten::ScalarType out_dtype, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& requantize_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + double in_scale, + int64_t in_zero_point, + double out_scale, + int64_t out_zero_point, + const ::executorch::aten::ScalarType out_dtype, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_roi_align_box_processor.cpp b/backends/cadence/generic/operators/op_roi_align_box_processor.cpp new file mode 100644 index 00000000000..a2f7fdfde7c --- /dev/null +++ b/backends/cadence/generic/operators/op_roi_align_box_processor.cpp @@ -0,0 +1,179 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +namespace impl { +namespace generic { +namespace native { +namespace { + +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +using UnpackedVec = std::array; +using PackedVec = std::array; +using IterVec = std::array; + +IterVec computeAddrIncr(const IterVec& shape, const IterVec& strides) { + auto rank = shape.size(); + auto inc = strides; + for (int n = 1; n < static_cast(rank); ++n) { + inc[n] = strides[n] - strides[n - 1] * shape[n - 1] + inc[n - 1]; + } + return inc; +} + +template +PackedVec packTuringVals(const UnpackedVec& vals, bool is_signed) { + PackedVec result{}; + int bitPos = 0; // bit position in output vector + for (int v : vals) { + assert(is_signed || v >= 0); + if (is_signed) { + assert( + v >= -(1 << (perItemBitWidth - 1)) && + v < (1 << (perItemBitWidth - 1))); + } else { + assert(v < (1 << perItemBitWidth)); + } + + if (v < 0) { + v = (1 << perItemBitWidth) + v; + } + + // Extract bit by bit and store in the output array + for (int bit = 0; bit < perItemBitWidth; ++bit) { + auto outBitIndex = bitPos + bit; + auto byteIndex = outBitIndex / 8; + auto bitInByte = outBitIndex % 8; + // Extract bit from val + uint8_t bitVal = (v >> bit) & 1; + // Set bit in output byte + result[byteIndex] |= (bitVal << bitInByte); + } + bitPos += perItemBitWidth; + } + assert(bitPos == vals.size() * perItemBitWidth); + return result; +} + +template +constexpr int get_fp_scale() { + return 1 << frac_bits; +} + +template +int convert_to_S13(float fp) { + return int(std::round(fp * get_fp_scale())); +} + +PackedVec convertBoxPosToTuringConfig( + float topLeftX, + float topLeftY, + float bottomRightX, + float bottomRightY, + int roiAlignNumBoxes, + int output_size_h, + int output_size_w, + int sampling_ratio, + bool aligned) { + constexpr int precisionMode = 0; + auto dstImgH = output_size_h * sampling_ratio; + auto dstImgW = output_size_w * sampling_ratio; + auto dstTileH = dstImgH; + auto dstTileW = dstImgW; + + float stepX = (bottomRightX - topLeftX) / dstImgW; + float stepY = (bottomRightY 
- topLeftY) / dstImgH; + + if (aligned) { + topLeftX -= 0.5; + topLeftY -= 0.5; + } + + auto anchorX = convert_to_S13(topLeftX + stepX / 2); + auto anchorY = convert_to_S13(topLeftY + stepY / 2); + + UnpackedVec vals{}; + vals[0] = anchorX; + vals[1] = anchorY; + + IterVec shape = {dstTileW, dstTileH, 1, 1, 1, roiAlignNumBoxes}; + auto addrIncrementX = computeAddrIncr( + shape, + {convert_to_S13(stepX), + 0, + convert_to_S13(stepX * dstTileW), + 0, + 0, + 0}); + auto addrIncrementY = computeAddrIncr( + shape, + {0, + convert_to_S13(stepY), + 0, + convert_to_S13(stepY * dstTileH), + 0, + 0}); + + for (int i = 0; i < 10; ++i) { + vals[i + 2] = i < addrIncrementX.size() + ? addrIncrementX[i] + : addrIncrementX[addrIncrementX.size() - 1]; + vals[i + 12] = i < addrIncrementY.size() + ? addrIncrementY[i] + : addrIncrementY[addrIncrementY.size() - 1]; + } + + return packTuringVals(vals, true); +} + +} // namespace + +Tensor& roi_align_box_processor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& rois, + int64_t output_size_h, + int64_t output_size_w, + int64_t sampling_ratio, + bool aligned, + Tensor& out) { + int K = static_cast(rois.size(0)); + auto roi = rois.const_data_ptr(); + for (int i = 0; i < K; ++i) { + assert( + static_cast(roi[i * 5]) == 0 && "Only support 1 image for now."); + auto x1 = roi[i * 5 + 1]; + auto y1 = roi[i * 5 + 2]; + auto x2 = roi[i * 5 + 3]; + auto y2 = roi[i * 5 + 4]; + auto turing_roi = convertBoxPosToTuringConfig( + x1, + y1, + x2, + y2, + static_cast(K), + static_cast(output_size_h), + static_cast(output_size_w), + static_cast(sampling_ratio), + aligned); + static_assert(turing_roi.size() == 80); + + auto out_ptr = out.mutable_data_ptr() + i * turing_roi.size(); + for (auto val : turing_roi) { + *out_ptr++ = val; + } + } + return out; +} +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_roi_align_box_processor.h b/backends/cadence/generic/operators/op_roi_align_box_processor.h new file mode 100644 index 00000000000..34d07d7b700 --- /dev/null +++ b/backends/cadence/generic/operators/op_roi_align_box_processor.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& roi_align_box_processor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& rois, + int64_t output_size_h, + int64_t output_size_w, + int64_t sampling_ratio, + bool aligned, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_rope.cpp b/backends/cadence/generic/operators/op_rope.cpp new file mode 100644 index 00000000000..4a392bed1ee --- /dev/null +++ b/backends/cadence/generic/operators/op_rope.cpp @@ -0,0 +1,72 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "executorch/backends/cadence/generic/operators/op_rope.h" + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::optional; +using ::executorch::aten::Tensor; + +Tensor& rope_out( + ET_UNUSED ::executorch::runtime::KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& sin_tensor, + const Tensor& cos_tensor, + const optional& pos, + Tensor& out) { + // Input shape is [1, seq, h, hd / 2, 2] or [1, seq, h, hd] + const ssize_t seq_length = input.size(1); + const ssize_t num_heads = input.size(2); + const ssize_t head_dimension = input.numel() / (seq_length * num_heads); + const ssize_t head_dimension_by_two = head_dimension / 2; + for (int32_t s = 0; s < seq_length; ++s) { + for (int32_t h = 0; h < num_heads; ++h) { + for (int32_t hd_o = 0; hd_o < head_dimension_by_two; ++hd_o) { + // Process 2 elements in head dimension at a time. + const float x_0 = input.const_data_ptr() + [s * num_heads * head_dimension + + h * head_dimension + hd_o * 2]; + const float x_1 = input.const_data_ptr() + [s * num_heads * head_dimension + + h * head_dimension + hd_o * 2 + 1]; + int64_t token_id = s; + if (pos.has_value()) { + if (pos->scalar_type() == ::executorch::aten::ScalarType::Int) { + token_id = pos.has_value() ? pos->const_data_ptr()[s] : s; + } else { + token_id = pos.has_value() ? pos->const_data_ptr()[s] : s; + } + } + + const float sin = sin_tensor.const_data_ptr< + float>()[token_id * head_dimension_by_two + hd_o]; + const float cos = cos_tensor.const_data_ptr< + float>()[token_id * head_dimension_by_two + hd_o]; + + const float out_0 = x_0 * cos - x_1 * sin; + out.mutable_data_ptr() + [s * num_heads * head_dimension + h * head_dimension + hd_o * 2] = + out_0; + + const float out_1 = x_0 * sin + x_1 * cos; + out.mutable_data_ptr() + [s * num_heads * head_dimension + h * head_dimension + hd_o * 2 + + 1] = out_1; + } + } + } + + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_rope.h b/backends/cadence/generic/operators/op_rope.h new file mode 100644 index 00000000000..cdd4db1be0f --- /dev/null +++ b/backends/cadence/generic/operators/op_rope.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& rope_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + const ::executorch::aten::Tensor& sin_tensor, + const ::executorch::aten::Tensor& cos_tensor, + const ::executorch::aten::optional<::executorch::aten::Tensor>& pos, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_softmax.cpp b/backends/cadence/generic/operators/op_softmax.cpp new file mode 100644 index 00000000000..97c64a22511 --- /dev/null +++ b/backends/cadence/generic/operators/op_softmax.cpp @@ -0,0 +1,137 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +namespace { + +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +void vec_softmax_f32_f32( + float* __restrict__ y, + const float* __restrict__ x, + int n) { + // compute softmax(x, x+n) and returns in y + // y = e ^ (x - max(x)) / sum(e^(x - max(x)) + float max_x = *std::max_element(x, x + n); + float sum = 0; + + for (int i = 0; i < n; ++i) { + y[i] = expf(x[i] - max_x); + sum += y[i]; + } + + for (int i = 0; i < n; ++i) { + y[i] /= sum; + } +} + +// This function is borrowed from the portable kernel implementation, with only +// float type supported. +void _softmax_portable(const Tensor& in, int64_t dim, Tensor& out) { + const float* const in_data = in.const_data_ptr(); + float* const out_data = out.mutable_data_ptr(); + + torch::executor::apply_over_dim( + [in_data, out_data]( + const size_t size, const size_t stride, const size_t base) { + // calculate max in log_softmax dim. During log_softmax + // computation each value is subtracted by the maximum in + // value before calling exp to preserve numerical stability. + const float max_in = torch::executor::apply_unary_reduce_fn( + [](const float val_in, float val_accum) { + return std::max(val_in, val_accum); + }, + in_data + base, + size, + stride); + + float temp_sum = + torch::executor::apply_unary_map_reduce_fn( + [max_in](const float val_in) { + return std::exp(val_in - max_in); + }, + [](const float mapped_in, float val_accum) { + return val_accum + mapped_in; + }, + in_data + base, + size, + stride); + + torch::executor::apply_unary_map_fn( + [max_in, temp_sum](const float val_in) { + return std::exp(val_in - max_in) / temp_sum; + }, + in_data + base, + out_data + base, + size, + stride); + }, + in, + dim); +} + +} // namespace + +Tensor& _softmax_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& X, + int64_t dim, + ET_UNUSED bool half_to_float, + Tensor& Y) { + if (dim < 0) { + dim += X.dim(); + } + + // If dim is not the last dimension, we cannot use the kernel below. + // Falling back on a more generic kernel. + if (dim < X.dim() - 1) { + _softmax_portable(X, dim, Y); + return Y; + } + + const float* __restrict__ x_data = X.const_data_ptr(); + float* __restrict__ y_data = Y.mutable_data_ptr(); + + size_t K = X.size(X.dim() - 1); + size_t leading_dim = ::executorch::runtime::getLeadingDims(X, X.dim() - 1); + + for (size_t i = 0; i < leading_dim; ++i) { + const float* x = x_data + i * K; + float* y = y_data + i * K; + vec_softmax_f32_f32(y, x, K); + } + + return Y; +} + +Tensor& _softmax_f32_f32_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& X, + int64_t dim, + __ET_UNUSED ::executorch::aten::optional half_to_float, + Tensor& Y) { + _softmax_out(ctx, X, dim, false, Y); + + return Y; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_softmax.h b/backends/cadence/generic/operators/op_softmax.h new file mode 100644 index 00000000000..ec51b1d00c0 --- /dev/null +++ b/backends/cadence/generic/operators/op_softmax.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& _softmax_out( + __ET_UNUSED ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& X, + int64_t dim, + __ET_UNUSED bool half_to_float, + ::executorch::aten::Tensor& Y); + +::executorch::aten::Tensor& _softmax_f32_f32_out( + __ET_UNUSED ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& X, + int64_t dim, + __ET_UNUSED ::executorch::aten::optional half_to_float, + ::executorch::aten::Tensor& Y); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_transposed_convolution.cpp b/backends/cadence/generic/operators/op_transposed_convolution.cpp new file mode 100644 index 00000000000..121b479e65f --- /dev/null +++ b/backends/cadence/generic/operators/op_transposed_convolution.cpp @@ -0,0 +1,639 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "executorch/backends/cadence/generic/operators/op_transposed_convolution.h" + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::optional; +using ::executorch::aten::Scalar; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; +using ::impl::generic::kernels::quantize; + +// This implements a generic 2d transposed_conv kernel that operates on raw +// pointers. The version handles both quantized and fp32 convolutions. +// The input is of shape [n x c x h x w] +// The weight is of shape [oc/groups x wc x wh x ww], where wc == c +// The output is of shape [n x oc x oh x ow] +// The bias is of shape [oc] +template < + typename IT = float, + typename WT = IT, + typename BT = IT, + typename OT = IT, + bool quantized = false> +__attribute__((noinline)) void transposed_conv2d_nchw_core_generic( + // All the arrays + const IT* __restrict__ p_in, + const WT* __restrict__ p_weight, + const BT* __restrict__ p_bias, + OT* __restrict__ p_out, + // The array sizes + int32_t n, + int32_t c, + int32_t h, + int32_t w, + int32_t oc, + int32_t wc, + int32_t wh, + int32_t ww, + int32_t oh, + int32_t ow, + // Stride + int16_t s0, + int16_t s1, + // Padding + int16_t p0, + int16_t p1, + // Dilation + int16_t d0, + int16_t d1, + // Group for depthwise conv + int16_t groups, + // Optional args that are only relevant for quantized convolution + // input zero point + IT in_zero_point = 0, + // weight zero point + const int32_t* __restrict__ weight_zero_point = nullptr, + const float* __restrict__ bias_scale = nullptr, + float out_scale = 1, + OT out_zero_point = 0, + bool per_tensor_quantized = true) { + float inv_out_scale = 1. 
/ out_scale; + + // Compute the number of in and out channels per group + const int ocpg = oc / groups; + const int icpg = c / groups; + + // Iterate over all the output batches (i.e., n) + for (int _n = 0; _n < n; ++_n) { + const IT* in_batch = p_in + _n * c * h * w; + OT* out_batch = p_out + _n * oc * oh * ow; + // Compute separable transposed_convolution for each group + for (int _g = 0; _g < groups; ++_g) { + // Identify the input and output channels involved in the computation + // of this group + int sic = _g * icpg; + int soc = _g * ocpg; + + // Populate all the output channels in the group + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + OT* out_plane = out_batch + _oc * oh * ow; + const WT* weight_batch = p_weight + (_oc - soc) * wc * wh * ww; + // We compute one output channel at a time. + for (int _oh = 0; _oh < oh; ++_oh) { + for (int _ow = 0; _ow < ow; ++_ow) { + float acc = p_bias[_oc]; + // Below is the stencil computation that performs the hadamard + // product+accumulation of each input channel (contributing to the + // output channel being computed) with the corresponding weight + // channel. + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const IT* in_plane = in_batch + _ic * h * w; + const WT* weight_plane = weight_batch + _ic * wh * ww; + for (int _wh = 0; _wh < wh; ++_wh) { + int _ih = _oh - ((wh - 1) * d0) + _wh * d0 + p0; + if (_ih < 0 || _ih >= s0 * h || _ih % s0 != 0) { + continue; + } + _ih = _ih / s0; + for (int _ww = 0; _ww < ww; ++_ww) { + int _iw = _ow - ((ww - 1) * d1) + _ww * d1 + p1; + if (_iw < 0 || _iw >= s1 * w || _iw % s1 != 0) { + continue; + } + _iw = _iw / s1; + int ioff = _ih * w + _iw; + int woff = _wh * ww + _ww; + float lhs = in_plane[ioff] - in_zero_point; + float rhs = weight_plane[woff] - + (quantized ? weight_zero_point[0] : 0); + acc += lhs * rhs; + } + } + } + if (quantized) { + float val = + (per_tensor_quantized ? bias_scale[0] : bias_scale[_oc]) * + acc; + out_plane[_oh * ow + _ow] = + quantize(val, inv_out_scale, out_zero_point); + } else { + out_plane[_oh * ow + _ow] = acc; + } + } + } + } + } + } +} + +template < + typename IT = float, + typename WT = IT, + typename BT = IT, + typename OT = IT, + bool quantized = false> +__attribute__((noinline)) void transposed_conv2d_nhwc_core_generic( + // All the arrays + const IT* __restrict__ p_in, + const WT* __restrict__ p_weight, + const BT* __restrict__ p_bias, + OT* __restrict__ p_out, + // The array sizes + int32_t n, + int32_t h, + int32_t w, + int32_t c, + int32_t oc, + int32_t wh, + int32_t ww, + int32_t wc, + int32_t oh, + int32_t ow, + // Stride + int16_t s0, + int16_t s1, + // Padding + int16_t p0, + int16_t p1, + // Dilation + int16_t d0, + int16_t d1, + // Group for depthwise conv + int16_t groups, + // Optional args that are only relevant for quantized convolution + // input zero point + IT in_zero_point = 0, + // weight zero point + const int32_t* __restrict__ weight_zero_point = nullptr, + const float* __restrict__ bias_scale = nullptr, + float out_scale = 1, + OT out_zero_point = 0, + bool per_tensor_quantized = true) { + float inv_out_scale = 1. 
/ out_scale; + + // Compute the number of in and out channels per group + const int ocpg = oc / groups; + const int icpg = c / groups; + + // Iterate over all the output batches (i.e., n) + for (int _n = 0; _n < n; ++_n) { + const IT* in_batch = p_in + _n * h * w * c; + OT* out_batch = p_out + _n * oh * ow * oc; + for (int _oh = 0; _oh < oh; ++_oh) { + for (int _ow = 0; _ow < ow; ++_ow) { + OT* out_line = out_batch + (_oh * ow + _ow) * oc; + // Compute separable convolution for each group + for (int _g = 0; _g < groups; ++_g) { + // Identify the input and output channels involved in the computation + // of this group + int sic = _g * icpg; + int soc = _g * ocpg; + // Populate all the output channels in the group + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + const WT* weight_batch = p_weight + (_oc - soc) * wh * ww * wc; + // We compute one output channel at a time. + float acc = p_bias[_oc]; + // Below is the stencil computation that performs the hadamard + // product+accumulation of each input channel (contributing to + // the output channel being computed) with the corresponding + // weight channel. + for (int _wh = 0; _wh < wh; ++_wh) { + int _ih = _oh - ((wh - 1) * d0) + _wh * d0 + p0; + if (_ih < 0 || _ih >= s0 * h || _ih % s0 != 0) { + continue; + } + _ih = _ih / s0; + for (int _ww = 0; _ww < ww; ++_ww) { + int _iw = _ow - ((ww - 1) * d1) + _ww * d1 + p1; + if (_iw < 0 || _iw >= s1 * w || _iw % s1 != 0) { + continue; + } + _iw = _iw / s1; + const IT* in_line = in_batch + _ih * w * c + _iw * c; + const WT* weight_line = weight_batch + _wh * ww * wc + _ww * wc; + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + float lhs = in_line[_ic] - in_zero_point; + float rhs = + weight_line[_ic] - (quantized ? weight_zero_point[0] : 0); + acc += lhs * rhs; + } + } + } + if (quantized) { + float val = + (per_tensor_quantized ? bias_scale[0] : bias_scale[_oc]) * + acc; + out_line[_oc] = quantize(val, inv_out_scale, out_zero_point); + } else { + out_line[_oc] = acc; + } + } + } + } + } + } +} + +void transposed_convolution_nchw( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + Tensor& output) { + bool conv1d = input.dim() == 3; + // input = [n, c, h, w] + const int n = input.size(0); + const int c = input.size(1); + const int h = conv1d ? 1 : input.size(2); + const int w = conv1d ? input.size(2) : input.size(3); + // weight = [oc/groups, wc, wh, ww] + const int wc = weight.size(1); + const int wh = conv1d ? 1 : weight.size(2); + const int ww = conv1d ? weight.size(2) : weight.size(3); + // output = [n, oc, oh, ow] + const int oc = output.size(1); + const int oh = conv1d ? 1 : output.size(2); + const int ow = conv1d ? output.size(2) : output.size(3); + + float* __restrict__ p_out = output.mutable_data_ptr(); + const float* __restrict__ p_in = input.const_data_ptr(); + const float* __restrict__ p_weight = weight.const_data_ptr(); + const float* __restrict__ p_bias = bias.const_data_ptr(); + + transposed_conv2d_nchw_core_generic<>( + p_in, + p_weight, + p_bias, + p_out, + n, + c, + h, + w, + oc, + wc, + wh, + ww, + oh, + ow, + conv1d ? 1 : stride[0], + conv1d ? stride[0] : stride[1], + conv1d ? 0 : padding[0], + conv1d ? padding[0] : padding[1], + conv1d ? 1 : dilation[0], + conv1d ? 
dilation[0] : dilation[1], + groups); +} + +void transposed_convolution_nhwc( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + Tensor& output) { + bool conv1d = input.dim() == 3; + // input = [n, h, w, c] + const int n = input.size(0); + const int h = conv1d ? 1 : input.size(1); + const int w = conv1d ? input.size(1) : input.size(2); + const int c = conv1d ? input.size(2) : input.size(3); + + // weight = [oc/groups, wh, ww, wc] + const int wh = conv1d ? 1 : weight.size(1); + const int ww = conv1d ? weight.size(1) : weight.size(2); + const int wc = conv1d ? weight.size(2) : weight.size(3); + + // output = [n, oh, ow, oc] + const int oc = conv1d ? output.size(2) : output.size(3); + const int oh = conv1d ? 1 : output.size(1); + const int ow = conv1d ? output.size(1) : output.size(2); + + float* __restrict__ p_out = output.mutable_data_ptr(); + const float* __restrict__ p_in = input.const_data_ptr(); + const float* __restrict__ p_weight = weight.const_data_ptr(); + const float* __restrict__ p_bias = bias.const_data_ptr(); + + transposed_conv2d_nhwc_core_generic<>( + p_in, + p_weight, + p_bias, + p_out, + n, + h, + w, + c, + oc, + wh, + ww, + wc, + oh, + ow, + conv1d ? 1 : stride[0], + conv1d ? stride[0] : stride[1], + conv1d ? 0 : padding[0], + conv1d ? padding[0] : padding[1], + conv1d ? 1 : dilation[0], + conv1d ? dilation[0] : dilation[1], + groups); +} + +Tensor& transposed_convolution_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + ET_UNUSED IntArrayRef output_padding, + int64_t groups, + bool channel_last, + Tensor& output) { + if (channel_last) { + transposed_convolution_nhwc( + input, weight, bias, stride, padding, dilation, groups, output); + } else { + transposed_convolution_nchw( + input, weight, bias, stride, padding, dilation, groups, output); + } + + return output; +} + +// The quantized transposed_convolution kernel. in_scale and weight_scale are +// implicit in bias_scale, since it is a product of the two. The kernel will +// branch to quantized::conv1d or quantized::conv2d based on the dimensionality +// of activation tensor. +void quantized_transposed_conv_nchw( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& bias_scale, + float output_scale, + int32_t output_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + Tensor& out) { + bool conv1d = input.dim() == 3; + // input = [n, c, h, w] + const int n = input.size(0); + const int c = input.size(1); + const int h = conv1d ? 1 : input.size(2); + const int w = conv1d ? input.size(2) : input.size(3); + // weight = [oc/groups, wc, wh, ww] + const int wc = weight.size(1); + const int wh = conv1d ? 1 : weight.size(2); + const int ww = conv1d ? weight.size(2) : weight.size(3); + // output = [n, oc, oh, ow] + const int oc = out.size(1); + const int oh = conv1d ? 1 : out.size(2); + const int ow = conv1d ? 
out.size(2) : out.size(3); + + ScalarType out_dtype = out.scalar_type(); + ScalarType weight_dtype = weight.scalar_type(); + // Bool flag to check if weight tensor is quantized per-tensor or + // per-channel + bool per_tensor_quantized = bias_scale.numel() == 1; + +#define typed_quantized_conv2d_core(w_type, o_type) \ + transposed_conv2d_nchw_core_generic( \ + input.const_data_ptr(), \ + weight.const_data_ptr(), \ + bias.const_data_ptr(), \ + out.mutable_data_ptr(), \ + n, \ + c, \ + h, \ + w, \ + oc, \ + wc, \ + wh, \ + ww, \ + oh, \ + ow, \ + stride[0], \ + stride[1], \ + padding[0], \ + padding[1], \ + dilation[0], \ + dilation[1], \ + groups, \ + in_zero_point, \ + weight_zero_point.const_data_ptr(), \ + bias_scale.const_data_ptr(), \ + output_scale, \ + (o_type)output_zero_point, \ + per_tensor_quantized); + +#define typed_weight_dtype(out_dtype) \ + switch (weight_dtype) { \ + case ScalarType::Byte: { \ + typed_quantized_conv2d_core(uint8_t, out_dtype); \ + break; \ + } \ + default: \ + ET_DCHECK_MSG( \ + false, \ + "Unhandled weight dtype %s", \ + torch::executor::toString(weight_dtype)); \ + } + + switch (out_dtype) { + case ScalarType::Byte: { + typed_weight_dtype(uint8_t); + break; + } + default: + ET_DCHECK_MSG( + false, + "Unhandled out dtype %s", + torch::executor::toString(out_dtype)); + } + +#undef typed_weight_dtype +#undef typed_quantized_conv2d_core +} + +void quantized_transposed_conv_nhwc( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& bias_scale, + float output_scale, + int32_t output_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + Tensor& out) { + bool conv1d = input.dim() == 3; + // input = [n, h, w, c] + const int n = input.size(0); + const int h = conv1d ? 1 : input.size(1); + const int w = conv1d ? input.size(1) : input.size(2); + const int c = conv1d ? input.size(2) : input.size(3); + // weight = [oc/groups, wh, ww, wc] + const int wh = conv1d ? 1 : weight.size(1); + const int ww = conv1d ? weight.size(1) : weight.size(2); + const int wc = conv1d ? weight.size(2) : weight.size(3); + // output = [n, oh, ow, oc] + const int oc = conv1d ? out.size(2) : out.size(3); + const int oh = conv1d ? 1 : out.size(1); + const int ow = conv1d ? 
out.size(1) : out.size(2); + + ScalarType out_dtype = out.scalar_type(); + ScalarType weight_dtype = weight.scalar_type(); + // Bool flag to check if weight tensor is quantized per-tensor or + // per-channel + bool per_tensor_quantized = bias_scale.numel() == 1; + +#define typed_quantized_conv2d_core(w_type, o_type) \ + transposed_conv2d_nhwc_core_generic( \ + input.const_data_ptr(), \ + weight.const_data_ptr(), \ + bias.const_data_ptr(), \ + out.mutable_data_ptr(), \ + n, \ + h, \ + w, \ + c, \ + oc, \ + wh, \ + ww, \ + wc, \ + oh, \ + ow, \ + stride[0], \ + stride[1], \ + padding[0], \ + padding[1], \ + dilation[0], \ + dilation[1], \ + groups, \ + in_zero_point, \ + weight_zero_point.const_data_ptr(), \ + bias_scale.const_data_ptr(), \ + output_scale, \ + (o_type)output_zero_point, \ + per_tensor_quantized); + +#define typed_weight_dtype(out_dtype) \ + switch (weight_dtype) { \ + case ScalarType::Byte: { \ + typed_quantized_conv2d_core(uint8_t, out_dtype); \ + break; \ + } \ + default: \ + ET_DCHECK_MSG( \ + false, \ + "Unhandled weight dtype %s", \ + torch::executor::toString(weight_dtype)); \ + } + + switch (out_dtype) { + case ScalarType::Byte: { + typed_weight_dtype(uint8_t); + break; + } + default: + ET_DCHECK_MSG( + false, + "Unhandled out dtype %s", + torch::executor::toString(out_dtype)); + } + +#undef typed_weight_dtype +#undef typed_quantized_conv2d_core +} + +Tensor& quantized_transposed_conv_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + IntArrayRef output_padding, + int64_t groups, + int64_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& bias_scale, + double output_scale, + int64_t output_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + bool channel_last, + Tensor& out) { + if (channel_last) { + quantized_transposed_conv_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + out); + } else { + quantized_transposed_conv_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + out); + } + + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_transposed_convolution.h b/backends/cadence/generic/operators/op_transposed_convolution.h new file mode 100644 index 00000000000..7c6a632fbf6 --- /dev/null +++ b/backends/cadence/generic/operators/op_transposed_convolution.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& transposed_convolution_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + ::executorch::aten::IntArrayRef stride, + ::executorch::aten::IntArrayRef padding, + ::executorch::aten::IntArrayRef dilation, + ::executorch::aten::IntArrayRef output_padding, + int64_t groups, + bool channel_last, + ::executorch::aten::Tensor& output); + +::executorch::aten::Tensor& quantized_transposed_conv_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + ::executorch::aten::IntArrayRef stride, + ::executorch::aten::IntArrayRef padding, + ::executorch::aten::IntArrayRef dilation, + ::executorch::aten::IntArrayRef output_padding, + int64_t groups, + int64_t in_zero_point, + const ::executorch::aten::Tensor& weight_zero_point, + const ::executorch::aten::Tensor& bias_scale, + double output_scale, + int64_t output_zero_point, + const ::executorch::aten::Tensor& out_multiplier, + const ::executorch::aten::Tensor& out_shift, + bool channel_last, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_transposed_im2row.cpp b/backends/cadence/generic/operators/op_transposed_im2row.cpp new file mode 100644 index 00000000000..2266a46952d --- /dev/null +++ b/backends/cadence/generic/operators/op_transposed_im2row.cpp @@ -0,0 +1,252 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "executorch/backends/cadence/generic/operators/op_transposed_im2row.h" + +#include + +#include +#include + +#ifndef DISABLE_ALWAYS_INLINE +#define ALWAYS_INLINE __attribute__((always_inline)) +#else +#define ALWAYS_INLINE inline +#endif + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +template +ALWAYS_INLINE void transposed_im2row_( + const T* __restrict__ data_im, + const int32_t in_zero_point, + /* input parameters*/ + const int32_t channels, + const int32_t height, + const int32_t width, + /* output parameters */ + const int32_t out_height, + const int32_t out_width, + /* convolution parameters */ + const int32_t kernel_h, + const int32_t kernel_w, + const int32_t pad_h, + const int32_t pad_w, + const int32_t stride_h, + const int32_t stride_w, + const int32_t dilation_h, + const int32_t dilation_w, + T* __restrict__ data_col, + bool channels_last) { + // Consider convolving the input image of dimensions channels * height * width + // (or height * width * channels for NHWC layout) with a filter of dimensions + // channels * kernels_h * kernels_w. Assume that this convolution will produce + // an output of dimensions out_height x out_width. For each point the output, + // im2row takes the data from the input that is used in the computation of + // that output point, and flattens it into a vector of size channels_col = + // channels * kernel_h * kernel_w. 
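To make the layout just described concrete, here is a tiny worked example of the transposed-im2row output geometry (the sizes are arbitrary and only for illustration):

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical sizes: 3 input channels, 3x3 kernel, 8x8 output grid.
  const int32_t channels = 3, kernel_h = 3, kernel_w = 3;
  const int32_t out_height = 8, out_width = 8;
  // Each output point gathers one channels x kernel_h x kernel_w patch,
  // flattened into a row of length channels_col.
  const int32_t channels_col = channels * kernel_h * kernel_w;  // 27
  const int32_t rows = out_height * out_width;                  // 64
  std::printf("im2row buffer: %d rows x %d columns\n", rows, channels_col);
  return 0;
}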
The output of im2row will therefore be a 2D + // array of size (out_height * out_width) x channels_col + const int32_t channels_col = channels * kernel_h * kernel_w; + + // If the layout is NHWC, we can copy 'channels' worth of contiguous data + // points when performing im2row. + if (channels_last) { + // Iterate over the output domain + for (int _h = 0; _h < out_height; ++_h) { + for (int _w = 0; _w < out_width; ++_w) { + int32_t i_col = _h * out_width + _w; + T* __restrict__ out_seg = data_col + i_col * channels_col; + // Each point in the output domain is the result of applying a filter of + // size kernel_h x kernel_w x channels on the input. But since channels + // is contiguous, we will not explicitly have a loop for it. + for (int _kh = 0; _kh < kernel_h; ++_kh) { + // h_im and w_im are the actual height and width coordinates of + // the input tensor that we need to copy to the output. + int32_t h_im = + _h - ((kernel_h - 1) * dilation_h) + _kh * dilation_h + pad_h; + if (h_im < 0 || h_im >= stride_h * height || h_im % stride_h != 0) { + // ET_DCHECK_MSG(_kh * kernel_w * channels + kernel_w * channels <= + // channels_col, "Access out of bounds"); + std::fill_n( + out_seg + _kh * kernel_w * channels, + kernel_w * channels, + T(in_zero_point)); + continue; + } + for (int _kw = 0; _kw < kernel_w; ++_kw) { + int32_t w_im = + _w - ((kernel_w - 1) * dilation_w) + _kw * dilation_w + pad_w; + // h_im and w_im are the actual height and width coordinates of the + // input tensor from where we need to copy 'channels' points. + const T* __restrict__ slice_im = data_im + + ((h_im / stride_h) * width + (w_im / stride_w)) * channels; + T* __restrict__ slice_col = + out_seg + (_kh * kernel_w + _kw) * channels; + // If the coordinates were within the input domain, we copy + // 'channels' contiguous values. Otherwise we will fill the output + // with 0's. + // ET_DCHECK_MSG((_kh * kernel_w + _kw + 1) * channels <= + // channels_col, "Access out of bounds"); + if (w_im < 0 || w_im >= stride_w * width || w_im % stride_w != 0) { + std::fill_n(slice_col, channels, T(in_zero_point)); + } else { + memcpy(slice_col, slice_im, channels * sizeof(T)); + } + } + } + } + } + } else { + // Iterate over the output domain + for (int _h = 0; _h < out_height; ++_h) { + for (int _w = 0; _w < out_width; ++_w) { + int32_t i_col = _h * out_width + _w; + T* __restrict__ slice_col = data_col + i_col * channels_col; + + // Each point in the output domain is the result of applying a + // filter of size chanenls * kernel_h x kernel_w on the input + for (int _c = 0; _c < channels; ++_c) { + for (int _kh = 0; _kh < kernel_h; ++_kh) { + // c_col is the linearized access in the channels_col length vector. + int32_t c_col = (_c * kernel_h + _kh) * kernel_w; + // h_im and w_im are the actual height and width coordinates of + // the input tensor that we need to copy to the output. 
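The h_im / w_im checks used in both layout branches decide whether an output coordinate maps back onto a real input sample or onto one of the zeros implicitly inserted by the transposed stride: positions that fall out of range, or are not divisible by the stride, are filled with the zero point. A small standalone sketch of that mapping (parameters chosen purely for illustration):

#include <cstdio>

int main() {
  // Hypothetical parameters: stride 2, 3-tap kernel, no dilation or padding.
  const int stride_h = 2, dilation_h = 1, pad_h = 0, kernel_h = 3, height = 4;
  for (int _h = 0; _h < 6; ++_h) {
    for (int _kh = 0; _kh < kernel_h; ++_kh) {
      // Same back-projection as in transposed_im2row_ above.
      const int h_im =
          _h - ((kernel_h - 1) * dilation_h) + _kh * dilation_h + pad_h;
      const bool hits_input =
          !(h_im < 0 || h_im >= stride_h * height || h_im % stride_h != 0);
      if (hits_input) {
        std::printf("out row %d, tap %d -> input row %d\n", _h, _kh, h_im / stride_h);
      } else {
        std::printf("out row %d, tap %d -> zero-point fill\n", _h, _kh);
      }
    }
  }
  return 0;
}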
+ int32_t h_im = + _h - ((kernel_h - 1) * dilation_h) + _kh * dilation_h + pad_h; + if (h_im < 0 || h_im >= stride_h * height || h_im % stride_h != 0) { + // ET_CHECK_MSG(c_col + kernel_w <= channels_col, "Access out of + // bounds"); + std::fill_n(slice_col + c_col, kernel_w, T(in_zero_point)); + continue; + } + for (int _kw = 0; _kw < kernel_w; ++_kw) { + int32_t w_im = + _w - ((kernel_w - 1) * dilation_w) + _kw * dilation_w + pad_w; + // If the current data access is within the input tensor, copy + // the value + // ET_CHECK_MSG(c_col + _kw <= channels_col, "Access out of + // bounds"); + slice_col[c_col + _kw] = + (w_im < 0 || w_im >= stride_w * width || w_im % stride_w != 0) + ? static_cast(in_zero_point) + : data_im + [(_c * height + (h_im / stride_h)) * width + + (w_im / stride_w)]; + } + } + } + } + } + } +} + +Tensor& transposed_im2row_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef output_padding, + const Tensor& in_zero_point, + bool channel_last, + Tensor& out) { + // Compute the input tensor's dims + bool unit_height = input.dim() == 3; + const int32_t batch_size = input.size(0); + const int32_t in_c = + channel_last ? input.size(3 - unit_height) : input.size(1); + const int32_t in_h = + unit_height ? 1 : (channel_last ? input.size(1) : input.size(2)); + const int32_t in_w = + channel_last ? input.size(2 - unit_height) : input.size(3 - unit_height); + + // Get the kernel parameters + int32_t kernel_h = kernel_size[0]; + int32_t kernel_w = kernel_size[1]; + int32_t dilation_h = dilation[0]; + int32_t dilation_w = dilation[1]; + int32_t pad_h = padding[0]; + int32_t pad_w = padding[1]; + int32_t stride_h = stride[0]; + int32_t stride_w = stride[1]; + int32_t out_pad_h = output_padding[0]; + int32_t out_pad_w = output_padding[1]; + + // If we were to apply a transposed convolution on the input tensor, compute + // the output height and width. + int32_t out_h = (in_h - 1) * stride_h - 2 * pad_h + + dilation_h * (kernel_h - 1) + out_pad_h + 1; + int32_t out_w = (in_w - 1) * stride_w - 2 * pad_w + + dilation_w * (kernel_w - 1) + out_pad_w + 1; + + ET_DCHECK_MSG( + (out_h * out_w) == out.size(1), "dimension mismatch for output"); + ET_DCHECK_MSG( + (kernel_h * kernel_w * in_c) == out.size(2), + "dimension mismatch for output"); + + // Check if the input is per-tensor quantized or per-channel quantized. The + // zero point for each batch could differ for per-channel quantized input. + bool per_tensor_quantized = in_zero_point.numel() == 1; + +#define typed_transposed_im2row(dtype, ctype) \ + case ScalarType::dtype: { \ + const ctype* __restrict__ in_data = input.const_data_ptr(); \ + ctype* __restrict__ out_data = out.mutable_data_ptr(); \ + const int32_t* __restrict__ zero_point = \ + in_zero_point.const_data_ptr(); \ + int32_t in_plane = in_c * in_h * in_w; \ + int32_t out_plane = kernel_h * kernel_w * in_c * out_h * out_w; \ + for (int32_t n = 0; n < batch_size; ++n) { \ + transposed_im2row_( \ + &in_data[n * in_plane], \ + per_tensor_quantized ? 
zero_point[0] : zero_point[n], \ + in_c, \ + in_h, \ + in_w, \ + out_h, \ + out_w, \ + kernel_h, \ + kernel_w, \ + pad_h, \ + pad_w, \ + stride_h, \ + stride_w, \ + dilation_h, \ + dilation_w, \ + &out_data[n * out_plane], \ + channel_last); \ + } \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + typed_transposed_im2row(Float, float); + typed_transposed_im2row(Byte, uint8_t); + default: + ET_DCHECK_MSG( + false, + "transposed im2row not implemented for dtype %s", + torch::executor::toString(dtype)); + } +#undef typed_transposed_im2row + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_transposed_im2row.h b/backends/cadence/generic/operators/op_transposed_im2row.h new file mode 100644 index 00000000000..d86c02e2442 --- /dev/null +++ b/backends/cadence/generic/operators/op_transposed_im2row.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& transposed_im2row_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + ::executorch::aten::IntArrayRef kernel_size, + ::executorch::aten::IntArrayRef dilation, + ::executorch::aten::IntArrayRef padding, + ::executorch::aten::IntArrayRef stride, + ::executorch::aten::IntArrayRef output_padding, + const ::executorch::aten::Tensor& in_zero_point, + bool channel_last, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_where_scalar.cpp b/backends/cadence/generic/operators/op_where_scalar.cpp new file mode 100644 index 00000000000..33c3afdc9f0 --- /dev/null +++ b/backends/cadence/generic/operators/op_where_scalar.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& where_Scalar_out( + ET_UNUSED ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& condition, + const double val1, + const double val2, + ::executorch::aten::Tensor& out) { + const float val1_f = static_cast(val1); + const float val2_f = static_cast(val2); + for (int i = 0; i < out.numel(); ++i) { + out.mutable_data_ptr()[i] = + condition.const_data_ptr()[i] ? val1_f : val2_f; + } + + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_where_scalar.h b/backends/cadence/generic/operators/op_where_scalar.h new file mode 100644 index 00000000000..80a0b8054d5 --- /dev/null +++ b/backends/cadence/generic/operators/op_where_scalar.h @@ -0,0 +1,27 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& where_Scalar_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& condition, + double val1, + double val2, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/quantized_linear.h b/backends/cadence/generic/operators/quantized_linear.h new file mode 100644 index 00000000000..1a7d0390dd4 --- /dev/null +++ b/backends/cadence/generic/operators/quantized_linear.h @@ -0,0 +1,203 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include +#include + +namespace impl::generic::quantized { + +constexpr size_t kTensorDimensionLimit = 16; + +template +inline __attribute__((always_inline)) void quantized_linear_per_tensor_( + const ::executorch::aten::Tensor& src, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + const int64_t src_zero_point, + const int64_t weight_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + const int64_t out_zero_point, + ::executorch::aten::Tensor& out) { + // input comes in shape [leading_dims, in_dim] + // weight comes in shape [out_dim, in_dim] + // output comes in empty with shape [leading_dims, out_dim] + // Perform matrix multiply (M x N) x (N x P)' => M x P + const int64_t leading_dims = + ::executorch::runtime::getLeadingDims(src, src.dim() - 1); + const int64_t out_dim = weight.size(0); // = out_dim + const int64_t in_dim = weight.size(1); // = in_dim + + const IT* __restrict__ in_data = src.const_data_ptr(); + const WT* __restrict__ weight_data = weight.const_data_ptr(); + const int32_t* __restrict__ bias_data = bias.const_data_ptr(); + IT* __restrict__ out_data = out.mutable_data_ptr(); + // Compute the requant_scale from out_multiplier and out_shift + const float requant_scale = + -out_multiplier * 1.0 / (1 << 31) * pow(2, out_shift); + for (size_t i = 0; i < leading_dims; ++i) { + for (size_t j = 0; j < out_dim; ++j) { + int32_t sum = bias_data[j]; + for (size_t k = 0; k < in_dim; ++k) { + int32_t x = (int32_t)in_data[i * in_dim + k] - src_zero_point; + int32_t w = + (int32_t)weight_data[j * in_dim + k] - (int32_t)weight_zero_point; + sum += x * w; + } + out_data[i * out_dim + j] = ::impl::generic::kernels::quantize( + sum, requant_scale, out_zero_point); + } + } +} + +template +inline __attribute__((always_inline)) void quantized_linear_per_tensor_( + const ::executorch::aten::Tensor& src, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + int64_t src_zero_point, + const ::executorch::aten::Tensor& weight_zero_point_t, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + ::executorch::aten::Tensor& out) { + // Get the zero_point of weight. 
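In the per-tensor linear kernel above, the floating-point requantization scale is reconstructed from a fixed-point multiplier and shift as -out_multiplier / 2^31 * 2^out_shift. A standalone illustration of that expression follows; the multiplier and shift values are made up, and 1LL << 31 is used so the shift stays well defined:

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical fixed-point requantization parameters.
  const int64_t out_multiplier = -1518500250;  // roughly -0.7071 in Q0.31
  const int64_t out_shift = -7;
  // Mirrors the requant_scale expression in quantized_linear_per_tensor_.
  const float requant_scale = static_cast<float>(
      -out_multiplier * 1.0 / (1LL << 31) * std::pow(2.0, out_shift));
  std::printf("requant_scale = %.9f\n", requant_scale);  // ~0.00552
  return 0;
}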
+ int32_t weight_zero_point = weight_zero_point_t.const_data_ptr()[0]; + quantized_linear_per_tensor_( + src, + weight, + bias, + src_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + out); +} + +template +inline __attribute__((always_inline)) void quantized_linear_per_channel_( + const ::executorch::aten::Tensor& src, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + int64_t src_zero_point, + int64_t weight_zero_point, + const ::executorch::aten::Tensor& out_multiplier, + const ::executorch::aten::Tensor& out_shift, + int64_t out_zero_point, + ::executorch::aten::Tensor& out) { + // input comes in shape [leading_dims, in_dim] + // weight comes in shape [out_dim, in_dim] + // output comes in empty with shape [leading_dims, out_dim] + // Perform matrix multiply (M x N) x (N x P)' => M x P + const int64_t leading_dims = + ::executorch::runtime::getLeadingDims(src, src.dim() - 1); + const int64_t out_dim = weight.size(0); // = out_dim + const int64_t in_dim = weight.size(1); // = in_dim + + const T* __restrict__ in_data = src.const_data_ptr(); + const WT* __restrict__ weight_data = weight.const_data_ptr(); + const int32_t* __restrict__ bias_data = bias.const_data_ptr(); + T* __restrict__ out_data = out.mutable_data_ptr(); + const int32_t* __restrict__ out_multiplier_data = + out_multiplier.const_data_ptr(); + const int32_t* __restrict__ out_shift_data = + out_shift.const_data_ptr(); + + for (size_t i = 0; i < leading_dims; ++i) { + for (size_t j = 0; j < out_dim; ++j) { + int32_t sum = bias_data[j]; + for (size_t k = 0; k < in_dim; ++k) { + int32_t x = (int32_t)in_data[i * in_dim + k] - src_zero_point; + int32_t w = + (int32_t)weight_data[j * in_dim + k] - (int32_t)weight_zero_point; + sum += x * w; + } + // Compute the out_scale from out_multiplier and out_shift + const float out_scale = + -out_multiplier_data[j] * 1.0 / (1 << 31) * pow(2, out_shift_data[j]); + out_data[i * out_dim + j] = + ::impl::generic::kernels::quantize(sum, out_scale, out_zero_point); + } + } +} + +template +inline __attribute__((always_inline)) void quantized_linear_( + const ::executorch::aten::Tensor& src, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + int64_t src_zero_point, + int64_t weight_zero_point, + const ::executorch::aten::Tensor& out_multiplier, + const ::executorch::aten::Tensor& out_shift, + int64_t out_zero_point, + ::executorch::aten::Tensor& out) { + if (out_multiplier.numel() == 1) { + // Use per-tensor quantization kernel. + const int32_t* __restrict__ out_multiplier_data = + out_multiplier.const_data_ptr(); + const int32_t* __restrict__ out_shift_data = + out_shift.const_data_ptr(); + quantized_linear_per_tensor_( + src, + weight, + bias, + src_zero_point, + weight_zero_point, + out_multiplier_data[0], + out_shift_data[0], + out_zero_point, + out); + return; + } + + // Use per-channel quantization kernel. 
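Both branches of the dispatcher above share the same integer inner product; they differ only in whether a single (multiplier, shift) pair or one pair per output channel is applied during requantization. A minimal reference of that inner accumulation, detached from the Tensor API (all names here are illustrative):

#include <cstdint>

// Reference accumulation shared by the per-tensor and per-channel paths:
// int32 sum of zero-point-corrected products on top of a per-channel bias.
int32_t quantized_dot(
    const uint8_t* x, const uint8_t* w, int64_t in_dim,
    int32_t x_zero_point, int32_t w_zero_point, int32_t bias) {
  int32_t sum = bias;
  for (int64_t k = 0; k < in_dim; ++k) {
    sum += (static_cast<int32_t>(x[k]) - x_zero_point) *
           (static_cast<int32_t>(w[k]) - w_zero_point);
  }
  // The caller then scales `sum` by the requant scale and adds the output
  // zero point, either per tensor or per output channel.
  return sum;
}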
+ quantized_linear_per_channel_( + src, + weight, + bias, + src_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + out); +} + +template +inline __attribute__((always_inline)) void quantized_linear_( + const ::executorch::aten::Tensor& src, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + int64_t src_zero_point, + const ::executorch::aten::Tensor& weight_zero_point_t, + const ::executorch::aten::Tensor& out_multiplier, + const ::executorch::aten::Tensor& out_shift, + int64_t out_zero_point, + ::executorch::aten::Tensor& out) { + // Get the zero_point of weight. + int32_t weight_zero_point = weight_zero_point_t.const_data_ptr()[0]; + quantized_linear_( + src, + weight, + bias, + src_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + out); +} + +} // namespace impl::generic::quantized diff --git a/backends/cadence/generic/operators/quantized_op_macros.h b/backends/cadence/generic/operators/quantized_op_macros.h new file mode 100644 index 00000000000..4ac64c6f49d --- /dev/null +++ b/backends/cadence/generic/operators/quantized_op_macros.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#define DECLARE_POINTWISE_TENSOR_QUANTIZED_BINARY_OP(BINARY_FUNC_NAME, OP) \ + template \ + void BINARY_FUNC_NAME( \ + const ::executorch::aten::Tensor& X, \ + float X_scale, \ + int32_t X_zero_point, \ + const ::executorch::aten::Tensor& Y, \ + float Y_scale, \ + int32_t Y_zero_point, \ + float out_scale, \ + int32_t out_zero_point, \ + ::executorch::aten::Tensor& out) { \ + float inv_out_scale = 1.0f / out_scale; \ + ::torch::executor::apply_binary_elementwise_fn( \ + [X_scale, \ + X_zero_point, \ + Y_scale, \ + Y_zero_point, \ + inv_out_scale, \ + out_zero_point](const T x_val, const T y_val) { \ + float x = ::impl::generic::kernels::dequantize( \ + x_val, X_scale, X_zero_point); \ + float y = ::impl::generic::kernels::dequantize( \ + y_val, Y_scale, Y_zero_point); \ + float z = x OP y; \ + return ::impl::generic::kernels::quantize( \ + z, inv_out_scale, out_zero_point); \ + }, \ + X, \ + Y, \ + out); \ + } diff --git a/backends/cadence/generic/operators/targets.bzl b/backends/cadence/generic/operators/targets.bzl new file mode 100644 index 00000000000..af271ace5ac --- /dev/null +++ b/backends/cadence/generic/operators/targets.bzl @@ -0,0 +1,500 @@ +load("@fbsource//tools/build_defs:platform_defs.bzl", "CXX") +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + # Individual operator targets with optimized dependencies + + # Type utilities for Cadence quantized operators + runtime.cxx_library( + name = "cadence_type_util", + exported_headers = ["cadence_type_util.h"], + ) + + runtime.cxx_library( + name = "quantized_op_macros", + exported_headers = ["quantized_op_macros.h"], + exported_deps = [ + ":cadence_type_util", + "//executorch/kernels/portable/cpu/util:broadcast_util", + "//executorch/runtime/kernel:kernel_includes", + ] + ) + + runtime.cxx_library( + name = "quantized_linear", + exported_headers = ["quantized_linear.h"], + exported_deps = [ + "//executorch/runtime/kernel:kernel_includes", + "//executorch/backends/cadence/generic/kernels:cadence_kernels", + ] + ) + + runtime.cxx_library( + name = 
"op_dequantize_per_tensor", + srcs = ["op_dequantize_per_tensor.cpp"], + exported_headers = ["op_dequantize_per_tensor.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + "//executorch/backends/cadence/generic/kernels:cadence_kernels", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_quantize_per_tensor", + srcs = ["op_quantize_per_tensor.cpp"], + exported_headers = ["op_quantize_per_tensor.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + "//executorch/backends/cadence/generic/kernels:cadence_kernels", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_where_scalar", + srcs = ["op_where_scalar.cpp"], + exported_headers = ["op_where_scalar.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/kernel:kernel_runtime_context", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_rope", + srcs = ["op_rope.cpp"], + exported_headers = ["op_rope.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/kernel:kernel_runtime_context", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_linalg_svd", + srcs = ["op_linalg_svd.cpp"], + headers = ["op_linalg_svd.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/core/exec_aten/util:tensor_util", + "//executorch/runtime/kernel:kernel_runtime_context", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_roi_align_box_processor", + srcs = ["op_roi_align_box_processor.cpp"], + exported_headers = ["op_roi_align_box_processor.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/kernel:kernel_runtime_context", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_quantized_add", + srcs = ["op_quantized_add.cpp"], + exported_headers = ["op_quantized_add.h"], + platforms = CXX, + deps = [ + "//executorch/backends/cadence/generic/kernels:cadence_kernels", + "//executorch/kernels/portable/cpu:scalar_utils", + "//executorch/runtime/kernel:kernel_includes", + ":quantized_op_macros", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_quantized_conv1d", + srcs = ["op_quantized_conv1d.cpp"], + exported_headers = ["op_quantized_conv1d.h"], + platforms = CXX, + deps = [ + ":cadence_type_util", + "//executorch/backends/cadence/generic/kernels:cadence_kernels", + "//executorch/runtime/kernel:kernel_includes", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_quantized_conv2d", + srcs = ["op_quantized_conv2d.cpp"], + exported_headers = ["op_quantized_conv2d.h"], + platforms = CXX, + deps = [ + ":cadence_type_util", + 
"//executorch/backends/cadence/generic/kernels:cadence_kernels", + "//executorch/runtime/kernel:kernel_includes", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_quantized_fully_connected", + srcs = ["op_quantized_fully_connected.cpp"], + exported_headers = ["op_quantized_fully_connected.h"], + platforms = CXX, + deps = [ + "//executorch/backends/cadence/generic/kernels:cadence_kernels", + "//executorch/runtime/kernel:kernel_includes", + ":quantized_linear", + ":quantized_op_macros", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_quantized_layer_norm", + srcs = ["op_quantized_layer_norm.cpp"], + exported_headers = ["op_quantized_layer_norm.h"], + platforms = CXX, + deps = [ + "//executorch/backends/cadence/generic/kernels:cadence_kernels", + "//executorch/runtime/kernel:kernel_includes", + ":quantized_op_macros", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_quantized_linear", + srcs = ["op_quantized_linear.cpp"], + exported_headers = ["op_quantized_linear.h"], + platforms = CXX, + deps = [ + "//executorch/backends/cadence/generic/kernels:cadence_kernels", + "//executorch/runtime/kernel:kernel_includes", + ":quantized_op_macros", + ":quantized_linear", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_quantized_relu", + srcs = ["op_quantized_relu.cpp"], + exported_headers = ["op_quantized_relu.h"], + platforms = CXX, + deps = [ + "//executorch/backends/cadence/generic/kernels:cadence_kernels", + "//executorch/runtime/kernel:kernel_includes", + ":quantized_op_macros", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_quantized_matmul", + srcs = ["op_quantized_matmul.cpp"], + exported_headers = ["op_quantized_matmul.h"], + platforms = CXX, + deps = [ + "//executorch/backends/cadence/generic/kernels:cadence_kernels", + "//executorch/runtime/kernel:kernel_includes", + ":quantized_op_macros", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_quantized_mul", + srcs = ["op_quantized_mul.cpp"], + exported_headers = ["op_quantized_mul.h"], + platforms = CXX, + deps = [ + "//executorch/backends/cadence/generic/kernels:cadence_kernels", + "//executorch/runtime/kernel:kernel_includes", + "//executorch/kernels/portable/cpu:scalar_utils", + ":quantized_op_macros", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_quantized_softmax", + srcs = ["op_quantized_softmax.cpp"], + exported_headers = ["op_quantized_softmax.h"], + platforms = CXX, + deps = [ + "//executorch/backends/cadence/generic/kernels:cadence_kernels", + "//executorch/kernels/portable/cpu/util:reduce_util", + "//executorch/kernels/portable/cpu/util:functional_util", + "//executorch/runtime/kernel:kernel_includes", + ":quantized_op_macros", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_quantized_embedding_byte", + srcs = ["op_quantized_embedding_byte.cpp"], + exported_headers = ["op_quantized_embedding_byte.h"], + platforms = CXX, + deps = [ + 
"//executorch/runtime/kernel:kernel_includes", + ":quantized_op_macros", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_requantize", + srcs = ["op_requantize.cpp"], + exported_headers = ["op_requantize.h"], + platforms = CXX, + deps = [ + "//executorch/backends/cadence/generic/kernels:cadence_kernels", + "//executorch/runtime/kernel:kernel_includes", + ":quantized_op_macros", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + + runtime.cxx_library( + name = "op_softmax", + srcs = ["op_softmax.cpp"], + exported_headers = ["op_softmax.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + "//executorch/kernels/portable/cpu/util:functional_util", + "//executorch/kernels/portable/cpu/util:reduce_util", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/kernel:kernel_runtime_context", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_conv1d", + srcs = ["op_conv1d.cpp"], + exported_headers = ["op_conv1d.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_conv2d", + srcs = ["op_conv2d.cpp"], + exported_headers = ["op_conv2d.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_conv3d", + srcs = ["op_conv3d.cpp"], + exported_headers = ["op_conv3d.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_avg_pool2d", + srcs = ["op_avg_pool2d.cpp"], + exported_headers = ["op_avg_pool2d.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/kernel:kernel_runtime_context", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_fully_connected", + srcs = ["op_fully_connected.cpp"], + exported_headers = ["op_fully_connected.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_idma_copy", + srcs = ["op_idma_copy.cpp"], + exported_headers = ["op_idma_copy.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/kernel:kernel_runtime_context", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_idma_wait", + srcs = ["op_idma_wait.cpp"], + exported_headers = ["op_idma_wait.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/kernel:kernel_runtime_context", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_im2row", + srcs = ["op_im2row.cpp"], + exported_headers = ["op_im2row.h"], + 
platforms = CXX, + deps = [ + "//executorch/backends/cadence/generic/kernels:cadence_kernels", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/kernel:kernel_runtime_context", + ], + exported_deps = [ + "//executorch/runtime/kernel:kernel_includes", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_transposed_im2row", + srcs = ["op_transposed_im2row.cpp"], + exported_headers = ["op_transposed_im2row.h"], + platforms = CXX, + deps = [ + "//executorch/backends/cadence/generic/kernels:cadence_kernels", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/kernel:kernel_runtime_context", + ], + exported_deps = [ + "//executorch/runtime/kernel:kernel_includes", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + runtime.cxx_library( + name = "op_transposed_convolution", + srcs = ["op_transposed_convolution.cpp"], + exported_headers = ["op_transposed_convolution.h"], + platforms = CXX, + deps = [ + "//executorch/backends/cadence/generic/kernels:cadence_kernels", + ], + exported_deps = [ + "//executorch/runtime/kernel:kernel_includes", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + ) diff --git a/backends/cadence/hifi/kernels/kernels.cpp b/backends/cadence/hifi/kernels/kernels.cpp index feabe6e1828..07f0ac960b1 100644 --- a/backends/cadence/hifi/kernels/kernels.cpp +++ b/backends/cadence/hifi/kernels/kernels.cpp @@ -10,7 +10,6 @@ #include #include -namespace cadence { namespace impl { namespace HiFi { namespace kernels { @@ -23,17 +22,9 @@ memcpy(void* dst, const void* src, size_t num_bytes) { void* allocate_temp_memory(KernelRuntimeContext& ctx, size_t size) { constexpr size_t kAlignment = 16; // 16-byte alignment for vectorized operations - ET_LOG( - Info, - "Attempting to allocate %zu bytes of temp memory (16-byte aligned)", - size); Result temp_mem_res = ctx.allocate_temp(size, kAlignment); if (temp_mem_res.ok()) { void* ptr = temp_mem_res.get(); - ET_LOG( - Info, - "Successfully allocated temp memory at %p (16-byte aligned)", - ptr); return ptr; } else { ET_LOG( @@ -48,8 +39,8 @@ void* allocate_temp_memory(KernelRuntimeContext& ctx, size_t size) { template __attribute__((always_inline)) T quantize(const float x, float scale, int32_t zero_point) { - constexpr float min_val = std::numeric_limits::min(); - constexpr float max_val = std::numeric_limits::max(); + constexpr float min_val = static_cast(std::numeric_limits::min()); + constexpr float max_val = static_cast(std::numeric_limits::max()); float tmp = roundf(x * scale + zero_point); return std::max(std::min(tmp, max_val), min_val); } @@ -65,8 +56,8 @@ void quantize( xtfloatx2 scale_vec = (xtfloatx2)scale; xtfloatx2 zero_vec = XT_FLOAT_SX2(zero_point, 0); - constexpr float min_val = std::numeric_limits::min(); - constexpr float max_val = std::numeric_limits::max(); + constexpr float min_val = static_cast(std::numeric_limits::min()); + constexpr float max_val = static_cast(std::numeric_limits::max()); const xtfloatx2* __restrict__ p0 = (const xtfloatx2* __restrict__)x; ae_valign va0 = XT_LASX2PP(p0); @@ -136,6 +127,7 @@ typed_quantize_val(int8_t); typed_quantize_val(uint8_t); typed_quantize_val(int16_t); typed_quantize_val(uint16_t); +typed_quantize_val(int32_t); #undef typed_quantize_val #define typed_quantize_vec(dtype) \ @@ -159,6 +151,7 @@ typed_dequantize_val(int8_t); typed_dequantize_val(uint8_t); typed_dequantize_val(int16_t); 
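The scalar quantize helper above computes round(x * scale + zero_point) and clamps to the output type's range (the change adds explicit float casts on the limits and instantiates the template for int32_t as well). A self-contained sketch of the same logic, as an approximation of the HiFi helper:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

// Round-and-clamp quantization: clamp(round(x * scale + zero_point)) to the
// representable range of the output type T, with the limits cast to float up
// front. Note: for 32-bit T the float cast of max() rounds up to 2^31,
// matching the helper above.
template <typename T>
T quantize_value(float x, float scale, int32_t zero_point) {
  const float min_val = static_cast<float>(std::numeric_limits<T>::min());
  const float max_val = static_cast<float>(std::numeric_limits<T>::max());
  const float tmp = std::round(x * scale + static_cast<float>(zero_point));
  return static_cast<T>(std::max(std::min(tmp, max_val), min_val));
}

// With the new instantiation, int32_t joins int8/uint8/int16/uint16, e.g.:
//   int32_t q = quantize_value<int32_t>(0.5f, 1.0f / 0.02f, 0);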
typed_dequantize_val(uint16_t); +typed_dequantize_val(int32_t); #undef typed_dequantize_val #define typed_dequantize_vec(dtype) \ @@ -175,7 +168,6 @@ typed_dequantize_vec(uint16_t); typed_dequantize_vec(int32_t); #undef typed_dequantize_vec -}; // namespace kernels -}; // namespace HiFi -}; // namespace impl -}; // namespace cadence +} // namespace kernels +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/kernels/kernels.h b/backends/cadence/hifi/kernels/kernels.h index 2574b9d60ee..08343e2528b 100644 --- a/backends/cadence/hifi/kernels/kernels.h +++ b/backends/cadence/hifi/kernels/kernels.h @@ -216,7 +216,6 @@ extern "C" WORD32 xa_nn_transpose_32_32( WORD32 num_out_dims, WORD32 num_inp_dims); -namespace cadence { namespace impl { namespace HiFi { namespace kernels { @@ -285,7 +284,6 @@ void dequantize( int32_t zero_point, size_t size); -}; // namespace kernels -}; // namespace HiFi -}; // namespace impl -}; // namespace cadence +} // namespace kernels +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/CMakeLists.txt b/backends/cadence/hifi/operators/CMakeLists.txt index 6bd63c6d9f6..26555da9760 100644 --- a/backends/cadence/hifi/operators/CMakeLists.txt +++ b/backends/cadence/hifi/operators/CMakeLists.txt @@ -96,8 +96,8 @@ add_library( "op_quantize_per_tensor.cpp" "op_quantized_relu_out.cpp" "op_dequantize_per_tensor.cpp" - "op_quantized_conv_nchw_out.cpp" - "op_quantized_conv_nhwc_out.cpp" + "op_quantized_conv2d_nchw_out.cpp" + "op_quantized_conv2d_nhwc_out.cpp" "op_quantized_fully_connected_out" ) target_include_directories( diff --git a/backends/cadence/hifi/operators/op_add.cpp b/backends/cadence/hifi/operators/op_add.cpp index 9823844af7f..445cf3d9f2b 100644 --- a/backends/cadence/hifi/operators/op_add.cpp +++ b/backends/cadence/hifi/operators/op_add.cpp @@ -16,6 +16,8 @@ #include #include +#include + using executorch::aten::Scalar; using executorch::aten::ScalarType; using executorch::aten::Tensor; @@ -24,7 +26,6 @@ using executorch::runtime::CppTypeToScalarType; using executorch::runtime::KernelRuntimeContext; using torch::executor::Error; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -185,10 +186,25 @@ Tensor& add_out( for (int i = 0; i < b.dim(); i++) inp2_shape[i + off_b] = b.size(i); - xa_nn_elm_add_broadcast_4D_f32xf32_f32( - out_data, out_shape, a_data, inp1_shape, b_data, inp2_shape); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_add_broadcast_4D_f32xf32_f32, + out_data, + out_shape, + a_data, + inp1_shape, + b_data, + inp2_shape); } else { - xa_nn_elm_add_f32xf32_f32(out_data, a_data, b_data, out.numel()); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_add_f32xf32_f32, + out_data, + a_data, + b_data, + out.numel()); } return out; @@ -221,4 +237,3 @@ Tensor& add_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_atan2.cpp b/backends/cadence/hifi/operators/op_atan2.cpp index fd595a935cb..1546c1e3a7f 100644 --- a/backends/cadence/hifi/operators/op_atan2.cpp +++ b/backends/cadence/hifi/operators/op_atan2.cpp @@ -12,6 +12,8 @@ #include #include +#include + using executorch::aten::ScalarType; using executorch::aten::Tensor; using executorch::runtime::isFloatingType; @@ -24,7 +26,6 @@ using torch::executor::native::utils::apply_bitensor_elementwise_fn; using torch::executor::native::utils::get_compute_type; using torch::executor::native::utils::SupportedTensorDtypes; -namespace cadence { namespace impl { 
namespace HiFi { namespace native { @@ -182,7 +183,15 @@ Tensor& atan2_out( for (int i = 0; i < b_dim; i++) p_inp1_shape[i] = b.size(i); - xa_nn_broadcast_32_32(ptr1, p_out_shape, pin1, p_inp1_shape, out_dim); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_broadcast_32_32, + ptr1, + p_out_shape, + pin1, + p_inp1_shape, + out_dim); FLOAT32* __restrict__ p_out = (FLOAT32* __restrict__)out.mutable_data_ptr(); @@ -225,4 +234,3 @@ Tensor& atan2_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence \ No newline at end of file diff --git a/backends/cadence/hifi/operators/op_bitwise_and.cpp b/backends/cadence/hifi/operators/op_bitwise_and.cpp index a6cf17aa4d8..82b29b8bcd1 100644 --- a/backends/cadence/hifi/operators/op_bitwise_and.cpp +++ b/backends/cadence/hifi/operators/op_bitwise_and.cpp @@ -14,6 +14,8 @@ #include #include +#include + using exec_aten::Scalar; using exec_aten::ScalarType; using exec_aten::Tensor; @@ -26,7 +28,6 @@ using executorch::runtime::tensors_have_same_dim_order; using torch::executor::Error; using torch::executor::resize_to_broadcast_target_size; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -97,14 +98,37 @@ Tensor& bitwise_and_Tensor_out( for (int i = 0; i < b_dim; i++) p_inp2_shape[i] = b.size(i); - xa_nn_broadcast_8_8(ptr1, p_out_shape, pin1, p_inp1_shape, out_dim); - - xa_nn_broadcast_8_8(ptr2, p_out_shape, pin2, p_inp2_shape, out_dim); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_broadcast_8_8, + ptr1, + p_out_shape, + pin1, + p_inp1_shape, + out_dim); + + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_broadcast_8_8, + ptr2, + p_out_shape, + pin2, + p_inp2_shape, + out_dim); const WORD8* __restrict__ p_inp1 = (const WORD8* __restrict__)ptr1; const WORD8* __restrict__ p_inp2 = (const WORD8* __restrict__)ptr2; - xa_nn_elm_logicaland_boolxbool_bool(p_out, p_inp1, p_inp2, num_elm); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_logicaland_boolxbool_bool, + p_out, + p_inp1, + p_inp2, + num_elm); } else if (a_is_broadcasted && !b_is_broadcasted) { WORD8* __restrict__ ptr1 = (WORD8* __restrict__)kernels::allocate_temp_memory(ctx, num_elm); @@ -125,11 +149,26 @@ Tensor& bitwise_and_Tensor_out( for (int i = 0; i < a_dim; i++) p_inp1_shape[i] = a.size(i); - xa_nn_broadcast_8_8(ptr1, p_out_shape, pin1, p_inp1_shape, out_dim); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_broadcast_8_8, + ptr1, + p_out_shape, + pin1, + p_inp1_shape, + out_dim); const WORD8* __restrict__ p_inp1 = (const WORD8* __restrict__)ptr1; - xa_nn_elm_logicaland_boolxbool_bool(p_out, p_inp1, p_inp2, num_elm); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_logicaland_boolxbool_bool, + p_out, + p_inp1, + p_inp2, + num_elm); } else if (!a_is_broadcasted && b_is_broadcasted) { WORD8* __restrict__ ptr1 = (WORD8* __restrict__)kernels::allocate_temp_memory(ctx, num_elm); @@ -150,11 +189,26 @@ Tensor& bitwise_and_Tensor_out( for (int i = 0; i < b_dim; i++) p_inp2_shape[i] = b.size(i); - xa_nn_broadcast_8_8(ptr1, p_out_shape, pinp2, p_inp2_shape, out_dim); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_broadcast_8_8, + ptr1, + p_out_shape, + pinp2, + p_inp2_shape, + out_dim); const WORD8* __restrict__ p_inp2 = (const WORD8* __restrict__)ptr1; - xa_nn_elm_logicaland_boolxbool_bool(p_out, p_inp1, p_inp2, num_elm); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_logicaland_boolxbool_bool, + p_out, + p_inp1, + p_inp2, + num_elm); } else { const WORD8* __restrict__ p_inp1 = (const WORD8* __restrict__)a.const_data_ptr(); @@ -187,4 +241,3 @@ Tensor& bitwise_and_Scalar_out( } // namespace 
native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_bitwise_or.cpp b/backends/cadence/hifi/operators/op_bitwise_or.cpp index b8e03b43bfd..9a9722aa6a0 100644 --- a/backends/cadence/hifi/operators/op_bitwise_or.cpp +++ b/backends/cadence/hifi/operators/op_bitwise_or.cpp @@ -14,6 +14,8 @@ #include #include +#include + using exec_aten::Scalar; using exec_aten::ScalarType; using exec_aten::Tensor; @@ -26,7 +28,6 @@ using executorch::runtime::tensors_have_same_dim_order; using torch::executor::Error; using torch::executor::resize_to_broadcast_target_size; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -97,14 +98,37 @@ Tensor& bitwise_or_Tensor_out( for (int i = 0; i < b_dim; i++) p_inp2_shape[i] = b.size(i); - xa_nn_broadcast_8_8(ptr1, p_out_shape, pin1, p_inp1_shape, out_dim); - - xa_nn_broadcast_8_8(ptr2, p_out_shape, pin2, p_inp2_shape, out_dim); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_broadcast_8_8, + ptr1, + p_out_shape, + pin1, + p_inp1_shape, + out_dim); + + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_broadcast_8_8, + ptr2, + p_out_shape, + pin2, + p_inp2_shape, + out_dim); const WORD8* __restrict__ p_inp1 = (const WORD8* __restrict__)ptr1; const WORD8* __restrict__ p_inp2 = (const WORD8* __restrict__)ptr2; - xa_nn_elm_logicalor_boolxbool_bool(p_out, p_inp1, p_inp2, num_elm); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_logicalor_boolxbool_bool, + p_out, + p_inp1, + p_inp2, + num_elm); } else if (a_is_broadcasted && !b_is_broadcasted) { WORD8* __restrict__ ptr1 = (WORD8* __restrict__)kernels::allocate_temp_memory(ctx, num_elm); @@ -125,11 +149,26 @@ Tensor& bitwise_or_Tensor_out( for (int i = 0; i < a_dim; i++) p_inp1_shape[i] = a.size(i); - xa_nn_broadcast_8_8(ptr1, p_out_shape, pin1, p_inp1_shape, out_dim); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_broadcast_8_8, + ptr1, + p_out_shape, + pin1, + p_inp1_shape, + out_dim); const WORD8* __restrict__ p_inp1 = (const WORD8* __restrict__)ptr1; - xa_nn_elm_logicalor_boolxbool_bool(p_out, p_inp1, p_inp2, num_elm); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_logicalor_boolxbool_bool, + p_out, + p_inp1, + p_inp2, + num_elm); } else if (!a_is_broadcasted && b_is_broadcasted) { WORD8* __restrict__ ptr1 = (WORD8* __restrict__)kernels::allocate_temp_memory(ctx, num_elm); @@ -150,11 +189,26 @@ Tensor& bitwise_or_Tensor_out( for (int i = 0; i < b_dim; i++) p_inp2_shape[i] = b.size(i); - xa_nn_broadcast_8_8(ptr1, p_out_shape, pinp2, p_inp2_shape, out_dim); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_broadcast_8_8, + ptr1, + p_out_shape, + pinp2, + p_inp2_shape, + out_dim); const WORD8* __restrict__ p_inp2 = (const WORD8* __restrict__)ptr1; - xa_nn_elm_logicalor_boolxbool_bool(p_out, p_inp1, p_inp2, num_elm); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_logicalor_boolxbool_bool, + p_out, + p_inp1, + p_inp2, + num_elm); } else { const WORD8* __restrict__ p_inp1 = (const WORD8* __restrict__)a.const_data_ptr(); @@ -187,4 +241,3 @@ Tensor& bitwise_or_Scalar_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_bitwise_xor.cpp b/backends/cadence/hifi/operators/op_bitwise_xor.cpp index 2b0595e2d1d..66b9e8cc7fe 100644 --- a/backends/cadence/hifi/operators/op_bitwise_xor.cpp +++ b/backends/cadence/hifi/operators/op_bitwise_xor.cpp @@ -14,6 +14,8 @@ #include #include +#include + using exec_aten::Scalar; using exec_aten::ScalarType; using exec_aten::Tensor; @@ -26,7 +28,6 @@ using 
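Across the HiFi operator changes above, every bare xa_nn_* call is wrapped in XT_KERNEL_CHECK(ctx, out, <kernel>, <args...>) so that a failing NNLib status aborts the operator instead of being silently ignored. The macro itself is not shown in this diff; the following is only a hypothetical sketch of the checked-call pattern, not the real definition:

#include <cstdint>
#include <utility>

// Run an NNLib kernel and report whether it succeeded (assumed convention:
// a return value of 0 means success). The real XT_KERNEL_CHECK also takes
// ctx and out, presumably to flag the error and early-return from the op.
template <typename Kernel, typename... Args>
bool checked_kernel_call(Kernel kernel, Args&&... args) {
  const int32_t status = kernel(std::forward<Args>(args)...);
  return status == 0;
}

// Usage shape mirroring the diff:
//   if (!checked_kernel_call(xa_nn_elm_add_f32xf32_f32,
//                            out_data, a_data, b_data, out.numel())) {
//     return out;
//   }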
executorch::runtime::tensors_have_same_dim_order; using torch::executor::Error; using torch::executor::resize_to_broadcast_target_size; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -97,14 +98,37 @@ Tensor& bitwise_xor_Tensor_out( for (int i = 0; i < b_dim; i++) p_inp2_shape[i] = b.size(i); - xa_nn_broadcast_8_8(ptr1, p_out_shape, pin1, p_inp1_shape, out_dim); - - xa_nn_broadcast_8_8(ptr2, p_out_shape, pin2, p_inp2_shape, out_dim); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_broadcast_8_8, + ptr1, + p_out_shape, + pin1, + p_inp1_shape, + out_dim); + + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_broadcast_8_8, + ptr2, + p_out_shape, + pin2, + p_inp2_shape, + out_dim); const WORD8* __restrict__ p_inp1 = (const WORD8* __restrict__)ptr1; const WORD8* __restrict__ p_inp2 = (const WORD8* __restrict__)ptr2; - xa_nn_elm_logicalxor_boolxbool_bool(p_out, p_inp1, p_inp2, num_elm); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_logicalxor_boolxbool_bool, + p_out, + p_inp1, + p_inp2, + num_elm); } else if (a_is_broadcasted && !b_is_broadcasted) { WORD8* __restrict__ ptr1 = (WORD8* __restrict__)kernels::allocate_temp_memory(ctx, num_elm); @@ -125,11 +149,26 @@ Tensor& bitwise_xor_Tensor_out( for (int i = 0; i < a_dim; i++) p_inp1_shape[i] = a.size(i); - xa_nn_broadcast_8_8(ptr1, p_out_shape, pin1, p_inp1_shape, out_dim); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_broadcast_8_8, + ptr1, + p_out_shape, + pin1, + p_inp1_shape, + out_dim); const WORD8* __restrict__ p_inp1 = (const WORD8* __restrict__)ptr1; - xa_nn_elm_logicalxor_boolxbool_bool(p_out, p_inp1, p_inp2, num_elm); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_logicalxor_boolxbool_bool, + p_out, + p_inp1, + p_inp2, + num_elm); } else if (!a_is_broadcasted && b_is_broadcasted) { WORD8* __restrict__ ptr1 = (WORD8* __restrict__)kernels::allocate_temp_memory(ctx, num_elm); @@ -150,11 +189,26 @@ Tensor& bitwise_xor_Tensor_out( for (int i = 0; i < b_dim; i++) p_inp2_shape[i] = b.size(i); - xa_nn_broadcast_8_8(ptr1, p_out_shape, pinp2, p_inp2_shape, out_dim); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_broadcast_8_8, + ptr1, + p_out_shape, + pinp2, + p_inp2_shape, + out_dim); const WORD8* __restrict__ p_inp2 = (const WORD8* __restrict__)ptr1; - xa_nn_elm_logicalxor_boolxbool_bool(p_out, p_inp1, p_inp2, num_elm); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_elm_logicalxor_boolxbool_bool, + p_out, + p_inp1, + p_inp2, + num_elm); } else { const WORD8* __restrict__ p_inp1 = (const WORD8* __restrict__)a.const_data_ptr(); @@ -187,4 +241,3 @@ Tensor& bitwise_xor_Scalar_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_bmm.cpp b/backends/cadence/hifi/operators/op_bmm.cpp index 0262703bb73..68c5764db91 100644 --- a/backends/cadence/hifi/operators/op_bmm.cpp +++ b/backends/cadence/hifi/operators/op_bmm.cpp @@ -22,7 +22,6 @@ using torch::executor::check_bmm_args; using torch::executor::Error; using torch::executor::get_bmm_out_target_size; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -168,4 +167,3 @@ Tensor& bmm_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_cat.cpp b/backends/cadence/hifi/operators/op_cat.cpp index d4fd51871ce..f88b61b0ad8 100644 --- a/backends/cadence/hifi/operators/op_cat.cpp +++ b/backends/cadence/hifi/operators/op_cat.cpp @@ -23,7 +23,6 @@ using torch::executor::check_cat_args; using torch::executor::Error; using 
torch::executor::get_cat_out_target_size; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -152,4 +151,3 @@ Tensor& cat_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_clamp.cpp b/backends/cadence/hifi/operators/op_clamp.cpp index 88930a36827..e3d5c8914a4 100644 --- a/backends/cadence/hifi/operators/op_clamp.cpp +++ b/backends/cadence/hifi/operators/op_clamp.cpp @@ -20,6 +20,8 @@ #include #include +#include + using executorch::aten::RuntimeContext; using executorch::aten::Scalar; using executorch::aten::ScalarType; @@ -43,7 +45,6 @@ using torch::executor::native::utils::promote_type_with_scalar; using torch::executor::native::utils::scalar_to; using torch::executor::native::utils::SupportedTensorDtypes; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -248,8 +249,15 @@ Tensor& clamp_Tensor_out( ctx, p_scratch != nullptr, MemoryAllocationFailed, out); const FLOAT32* p_brd_cond = (const FLOAT32*)p_scratch; - xa_nn_broadcast_32_32( - (WORD32*)p_brd_cond, out_shape, (WORD32*)inp_data, inp_shape, 4); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_broadcast_32_32, + (WORD32*)p_brd_cond, + out_shape, + (WORD32*)inp_data, + inp_shape, + 4); for (int i = 0; i < 4; i++) { inp_shape[i] = out_shape[i]; @@ -325,4 +333,3 @@ Tensor& clamp_Tensor_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_dequantize_per_tensor.cpp b/backends/cadence/hifi/operators/op_dequantize_per_tensor.cpp index 7dce0050d7f..c091d216556 100644 --- a/backends/cadence/hifi/operators/op_dequantize_per_tensor.cpp +++ b/backends/cadence/hifi/operators/op_dequantize_per_tensor.cpp @@ -10,23 +10,24 @@ #include #include -namespace cadence { +#include + namespace impl { namespace HiFi { namespace native { -using ::cadence::impl::HiFi::kernels::dequantize; using ::executorch::aten::ScalarType; using ::executorch::aten::Tensor; using ::executorch::runtime::KernelRuntimeContext; +using ::impl::HiFi::kernels::dequantize; void dequantize_per_tensor_out( KernelRuntimeContext& ctx, const Tensor& input, double scale, int64_t zero_point, - __ET_UNUSED int64_t quant_min, - __ET_UNUSED int64_t quant_max, + ET_UNUSED int64_t quant_min, + ET_UNUSED int64_t quant_max, ScalarType dtype, Tensor& out) { float* out_data = out.mutable_data_ptr(); @@ -36,8 +37,15 @@ void dequantize_per_tensor_out( dequantize(out_data, input_data, scale, zero_point, numel); } else if (input.scalar_type() == ScalarType::Char) { const int8_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_asym8s_f32( - out_data, input_data, zero_point, scale, numel); + XT_KERNEL_CHECK( + ctx, + , + xa_nn_elm_dequantize_asym8s_f32, + out_data, + input_data, + zero_point, + scale, + numel); } else if (input.scalar_type() == ScalarType::Short) { const int16_t* input_data = input.const_data_ptr(); dequantize(out_data, input_data, scale, zero_point, numel); @@ -46,6 +54,9 @@ void dequantize_per_tensor_out( input.scalar_type() == ScalarType::UInt16) { const uint16_t* input_data = input.const_data_ptr(); dequantize(out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Int) { + const int32_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); } else { ET_CHECK_MSG( false, @@ -54,7 +65,66 @@ void dequantize_per_tensor_out( } } -}; // namespace native -}; // namespace HiFi 
-}; // namespace impl -}; // namespace cadence +void dequantize_per_tensor_asym8u_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + float* out_data = out.mutable_data_ptr(); + size_t numel = out.numel(); + const uint8_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); +} + +void dequantize_per_tensor_asym16s_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + float* out_data = out.mutable_data_ptr(); + size_t numel = out.numel(); + const int16_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); +} + +void dequantize_per_tensor_asym16u_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + float* out_data = out.mutable_data_ptr(); + size_t numel = out.numel(); + const uint16_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); +} + +void dequantize_per_tensor_asym32s_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + float* out_data = out.mutable_data_ptr(); + size_t numel = out.numel(); + const int32_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/op_dequantize_per_tensor_asym8s.cpp b/backends/cadence/hifi/operators/op_dequantize_per_tensor_asym8s.cpp new file mode 100644 index 00000000000..d1099b1a4db --- /dev/null +++ b/backends/cadence/hifi/operators/op_dequantize_per_tensor_asym8s.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
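The new per-dtype entry points above (asym8u, asym16s, asym16u, asym32s) differ only in the input pointer type; the body is the same dequantize call in every case. A templated sketch of that shared pattern, assuming the usual (q - zero_point) * scale affine dequantization (the real code keeps one explicit function per dtype, matching the registered operator names):

#include <cstddef>
#include <cstdint>

// Shared body of the per-dtype dequantize wrappers: read quantized values of
// type InT and write float outputs.
template <typename InT>
void dequantize_per_tensor_sketch(
    float* out_data,
    const InT* input_data,
    double scale,
    int64_t zero_point,
    size_t numel) {
  for (size_t i = 0; i < numel; ++i) {
    out_data[i] =
        static_cast<float>(static_cast<int64_t>(input_data[i]) - zero_point) *
        static_cast<float>(scale);
  }
}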
+ */ + +#include + +#include +#include + +namespace impl { +namespace HiFi { +namespace native { + +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +void dequantize_per_tensor_asym8s_out( + KernelRuntimeContext& ctx, + const Tensor& input, + double scale, + int64_t zero_point, + __ET_UNUSED int64_t quant_min, + __ET_UNUSED int64_t quant_max, + ScalarType dtype, + Tensor& out) { + float* out_data = out.mutable_data_ptr(); + const size_t numel = out.numel(); + const int8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym8s_f32( + out_data, input_data, zero_point, scale, numel); +} + +}; // namespace native +}; // namespace HiFi +}; // namespace impl diff --git a/backends/cadence/hifi/operators/op_div.cpp b/backends/cadence/hifi/operators/op_div.cpp index 2c689ae4350..057147a2a75 100644 --- a/backends/cadence/hifi/operators/op_div.cpp +++ b/backends/cadence/hifi/operators/op_div.cpp @@ -23,7 +23,6 @@ using executorch::aten::ScalarType; using executorch::aten::Tensor; using torch::executor::Error; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -322,4 +321,3 @@ Tensor& div_out_mode( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_embedding.cpp b/backends/cadence/hifi/operators/op_embedding.cpp index d6932d9daa3..21ac3398adf 100644 --- a/backends/cadence/hifi/operators/op_embedding.cpp +++ b/backends/cadence/hifi/operators/op_embedding.cpp @@ -27,7 +27,6 @@ using torch::executor::Error; using torch::executor::KernelRuntimeContext; using torch::executor::resize_embedding_output; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -132,4 +131,3 @@ Tensor& embedding_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_eq.cpp b/backends/cadence/hifi/operators/op_eq.cpp index 124eb007f05..a76b910e379 100644 --- a/backends/cadence/hifi/operators/op_eq.cpp +++ b/backends/cadence/hifi/operators/op_eq.cpp @@ -24,7 +24,6 @@ using executorch::runtime::promoteTypes; using torch::executor::Error; using torch::executor::resize_to_broadcast_target_size; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -139,4 +138,3 @@ Tensor& eq_Scalar_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_fmod.cpp b/backends/cadence/hifi/operators/op_fmod.cpp index 42cea062942..25865966649 100644 --- a/backends/cadence/hifi/operators/op_fmod.cpp +++ b/backends/cadence/hifi/operators/op_fmod.cpp @@ -33,7 +33,6 @@ using torch::executor::native::utils::extract_scalar; using torch::executor::native::utils::get_scalar_dtype; using torch::executor::native::utils::promote_type_with_scalar; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -287,4 +286,3 @@ Tensor& fmod_Scalar_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_full.cpp b/backends/cadence/hifi/operators/op_full.cpp index 3d30433d378..b7c54f4fb87 100644 --- a/backends/cadence/hifi/operators/op_full.cpp +++ b/backends/cadence/hifi/operators/op_full.cpp @@ -20,7 +20,6 @@ using torch::executor::Error; using torch::executor::native::utils::extract_scalar; using torch::executor::native::utils::get_scalar_dtype; -namespace cadence { 
namespace impl { namespace HiFi { namespace native { @@ -97,4 +96,3 @@ Tensor& full_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence \ No newline at end of file diff --git a/backends/cadence/hifi/operators/op_ge.cpp b/backends/cadence/hifi/operators/op_ge.cpp index 4d9c186e773..5d9111b5312 100644 --- a/backends/cadence/hifi/operators/op_ge.cpp +++ b/backends/cadence/hifi/operators/op_ge.cpp @@ -24,7 +24,6 @@ using executorch::runtime::promoteTypes; using torch::executor::Error; using torch::executor::resize_to_broadcast_target_size; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -148,4 +147,3 @@ Tensor& ge_Scalar_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_gt.cpp b/backends/cadence/hifi/operators/op_gt.cpp index 4a731e75c19..5995dba3bed 100644 --- a/backends/cadence/hifi/operators/op_gt.cpp +++ b/backends/cadence/hifi/operators/op_gt.cpp @@ -24,7 +24,6 @@ using executorch::runtime::promoteTypes; using torch::executor::Error; using torch::executor::resize_to_broadcast_target_size; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -142,4 +141,3 @@ Tensor& gt_Scalar_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_hardtanh.cpp b/backends/cadence/hifi/operators/op_hardtanh.cpp index 3c88e33922b..b25c5e2f87f 100644 --- a/backends/cadence/hifi/operators/op_hardtanh.cpp +++ b/backends/cadence/hifi/operators/op_hardtanh.cpp @@ -23,7 +23,6 @@ using torch::executor::native::utils::get_scalar_dtype; using torch::executor::native::utils::max_override; using torch::executor::native::utils::min_override; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -100,4 +99,3 @@ Tensor& hardtanh_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_le.cpp b/backends/cadence/hifi/operators/op_le.cpp index eec95c00bea..fb224b84369 100644 --- a/backends/cadence/hifi/operators/op_le.cpp +++ b/backends/cadence/hifi/operators/op_le.cpp @@ -23,7 +23,6 @@ using executorch::runtime::promoteTypes; using torch::executor::Error; using torch::executor::resize_to_broadcast_target_size; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -139,4 +138,3 @@ Tensor& le_Scalar_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_lt.cpp b/backends/cadence/hifi/operators/op_lt.cpp index ed21a7434c5..bbff9cc0aee 100644 --- a/backends/cadence/hifi/operators/op_lt.cpp +++ b/backends/cadence/hifi/operators/op_lt.cpp @@ -24,7 +24,6 @@ using executorch::runtime::promoteTypes; using torch::executor::Error; using torch::executor::resize_to_broadcast_target_size; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -137,4 +136,3 @@ Tensor& lt_Scalar_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_masked_fill.cpp b/backends/cadence/hifi/operators/op_masked_fill.cpp index 39b99c937a4..8c0a3ea3236 100644 --- a/backends/cadence/hifi/operators/op_masked_fill.cpp +++ b/backends/cadence/hifi/operators/op_masked_fill.cpp @@ -24,7 +24,6 @@ using torch::executor::resize_to_broadcast_target_size; using 
torch::executor::native::utils::extract_scalar; using torch::executor::native::utils::get_scalar_dtype; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -76,4 +75,3 @@ Tensor& masked_fill_Scalar_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_maximum.cpp b/backends/cadence/hifi/operators/op_maximum.cpp index 592ea3bc1e1..1882967f81a 100644 --- a/backends/cadence/hifi/operators/op_maximum.cpp +++ b/backends/cadence/hifi/operators/op_maximum.cpp @@ -23,7 +23,6 @@ using torch::executor::apply_binary_elementwise_fn; using torch::executor::Error; using torch::executor::resize_to_broadcast_target_size; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -171,4 +170,3 @@ Tensor& maximum_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_mean.cpp b/backends/cadence/hifi/operators/op_mean.cpp index 4b93e55047b..514813fbe05 100644 --- a/backends/cadence/hifi/operators/op_mean.cpp +++ b/backends/cadence/hifi/operators/op_mean.cpp @@ -20,7 +20,6 @@ using executorch::runtime::ArrayRef; using torch::executor::Error; using torch::executor::optional; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -181,4 +180,3 @@ Tensor& mean_dim_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_minimum.cpp b/backends/cadence/hifi/operators/op_minimum.cpp index b78ee64882a..1f069b362fd 100644 --- a/backends/cadence/hifi/operators/op_minimum.cpp +++ b/backends/cadence/hifi/operators/op_minimum.cpp @@ -23,7 +23,6 @@ using torch::executor::apply_binary_elementwise_fn; using torch::executor::Error; using torch::executor::resize_to_broadcast_target_size; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -170,4 +169,3 @@ Tensor& minimum_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_mm.cpp b/backends/cadence/hifi/operators/op_mm.cpp index 9cf922cbf56..edfc8bb7548 100644 --- a/backends/cadence/hifi/operators/op_mm.cpp +++ b/backends/cadence/hifi/operators/op_mm.cpp @@ -22,7 +22,6 @@ using torch::executor::check_mm_args; using torch::executor::Error; using torch::executor::get_mm_out_target_size; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -156,4 +155,3 @@ Tensor& mm_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_mul.cpp b/backends/cadence/hifi/operators/op_mul.cpp index 6eb79545be7..253ad6646e4 100644 --- a/backends/cadence/hifi/operators/op_mul.cpp +++ b/backends/cadence/hifi/operators/op_mul.cpp @@ -23,7 +23,6 @@ using executorch::runtime::can_cast; using executorch::runtime::CppTypeToScalarType; using torch::executor::Error; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -187,4 +186,3 @@ mul_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) { } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_ne.cpp b/backends/cadence/hifi/operators/op_ne.cpp index 8bbb0c64906..f183a42452a 100644 --- a/backends/cadence/hifi/operators/op_ne.cpp +++ b/backends/cadence/hifi/operators/op_ne.cpp @@ -24,7 +24,6 @@ using 
executorch::runtime::promoteTypes; using torch::executor::Error; using torch::executor::resize_to_broadcast_target_size; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -139,4 +138,3 @@ Tensor& ne_Scalar_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_permute_copy.cpp b/backends/cadence/hifi/operators/op_permute_copy.cpp index c5f33435733..fc162d6c7f1 100644 --- a/backends/cadence/hifi/operators/op_permute_copy.cpp +++ b/backends/cadence/hifi/operators/op_permute_copy.cpp @@ -22,7 +22,6 @@ using torch::executor::check_permute_copy_args; using torch::executor::Error; using torch::executor::get_permute_copy_out_target_size; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -171,4 +170,3 @@ Tensor& permute_copy_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_pow.cpp b/backends/cadence/hifi/operators/op_pow.cpp index 6ca7ccdebe9..e5b31cc7731 100644 --- a/backends/cadence/hifi/operators/op_pow.cpp +++ b/backends/cadence/hifi/operators/op_pow.cpp @@ -9,13 +9,14 @@ #include #include -#include #include #include #include #include #include +#include + using executorch::aten::Scalar; using executorch::aten::ScalarType; using executorch::aten::Tensor; @@ -34,7 +35,6 @@ using torch::executor::native::utils::promote_type_with_scalar; using torch::executor::native::utils::scalar_to; using torch::executor::native::utils::SupportedTensorDtypes; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -122,9 +122,25 @@ Tensor& pow_Tensor_Tensor_out( for (int i = 0; i < b_dim; i++) p_inp2_shape[i] = b.size(i); - xa_nn_broadcast_32_32(ptr1, p_out_shape, pin1, p_inp1_shape, out_dim); - - xa_nn_broadcast_32_32(ptr2, p_out_shape, pin2, p_inp2_shape, out_dim); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_broadcast_32_32, + ptr1, + p_out_shape, + pin1, + p_inp1_shape, + out_dim); + + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_broadcast_32_32, + ptr2, + p_out_shape, + pin2, + p_inp2_shape, + out_dim); FLOAT32* __restrict__ p_out = (FLOAT32* __restrict__)out.mutable_data_ptr(); @@ -151,8 +167,15 @@ Tensor& pow_Tensor_Tensor_out( for (int i = 0; i < a_dim; i++) p_inp1_shape[i] = a.size(i); - xa_nn_broadcast_32_32( - (WORD32*)ptr1, p_out_shape, (WORD32*)pin1, p_inp1_shape, out_dim); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_broadcast_32_32, + (WORD32*)ptr1, + p_out_shape, + (WORD32*)pin1, + p_inp1_shape, + out_dim); FLOAT32* __restrict__ p_out = (FLOAT32* __restrict__)out.mutable_data_ptr(); @@ -180,7 +203,15 @@ Tensor& pow_Tensor_Tensor_out( for (int i = 0; i < b_dim; i++) p_inp1_shape[i] = b.size(i); - xa_nn_broadcast_32_32(ptr1, p_out_shape, pin1, p_inp1_shape, out_dim); + XT_KERNEL_CHECK( + ctx, + out, + xa_nn_broadcast_32_32, + ptr1, + p_out_shape, + pin1, + p_inp1_shape, + out_dim); FLOAT32* __restrict__ p_out = (FLOAT32* __restrict__)out.mutable_data_ptr(); @@ -320,4 +351,3 @@ Tensor& pow_Scalar_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantize_per_tensor.cpp b/backends/cadence/hifi/operators/op_quantize_per_tensor.cpp index ec649a29c5b..579a4533057 100644 --- a/backends/cadence/hifi/operators/op_quantize_per_tensor.cpp +++ b/backends/cadence/hifi/operators/op_quantize_per_tensor.cpp @@ -16,14 +16,16 @@ #include #include -namespace cadence { namespace impl { namespace 
HiFi { namespace native { + namespace { + using ::executorch::aten::ScalarType; using ::executorch::aten::Tensor; using ::executorch::runtime::KernelRuntimeContext; +using ::impl::HiFi::kernels::quantize; // Add checks for dtype quant min/max bounds. template @@ -93,22 +95,22 @@ void quantize_per_tensor_out( const size_t numel = out.numel(); if (out.scalar_type() == ScalarType::Byte) { uint8_t* out_data = out.mutable_data_ptr(); - cadence::impl::HiFi::kernels::quantize( - out_data, input_data, 1. / scale, zero_point, numel); + quantize(out_data, input_data, 1. / scale, zero_point, numel); } else if (out.scalar_type() == ScalarType::Char) { int8_t* out_data = out.mutable_data_ptr(); xa_nn_elm_quantize_f32_asym8s( out_data, input_data, scale, zero_point, numel); } else if (out.scalar_type() == ScalarType::Short) { int16_t* out_data = out.mutable_data_ptr(); - cadence::impl::HiFi::kernels::quantize( - out_data, input_data, 1. / scale, zero_point, numel); + quantize(out_data, input_data, 1. / scale, zero_point, numel); } else if ( out.scalar_type() == ScalarType::Bits16 || out.scalar_type() == ScalarType::UInt16) { uint16_t* out_data = out.mutable_data_ptr(); - cadence::impl::HiFi::kernels::quantize( - out_data, input_data, 1. / scale, zero_point, numel); + quantize(out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Int) { + int32_t* out_data = out.mutable_data_ptr(); + quantize(out_data, input_data, 1. / scale, zero_point, numel); } else { ET_KERNEL_CHECK_MSG( ctx, @@ -120,7 +122,66 @@ void quantize_per_tensor_out( } } +void quantize_per_tensor_asym8u_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + const float* input_data = input.const_data_ptr(); + size_t numel = out.numel(); + uint8_t* out_data = out.mutable_data_ptr(); + quantize(out_data, input_data, 1. / scale, zero_point, numel); +} + +void quantize_per_tensor_asym16s_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + const float* input_data = input.const_data_ptr(); + size_t numel = out.numel(); + int16_t* out_data = out.mutable_data_ptr(); + quantize(out_data, input_data, 1. / scale, zero_point, numel); +} + +void quantize_per_tensor_asym16u_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + const float* input_data = input.const_data_ptr(); + size_t numel = out.numel(); + uint16_t* out_data = out.mutable_data_ptr(); + quantize(out_data, input_data, 1. / scale, zero_point, numel); +} + +void quantize_per_tensor_asym32s_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + const float* input_data = input.const_data_ptr(); + size_t numel = out.numel(); + int32_t* out_data = out.mutable_data_ptr(); + quantize(out_data, input_data, 1. 
/ scale, zero_point, numel); +} + }; // namespace native }; // namespace HiFi }; // namespace impl -}; // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantize_per_tensor_asym8s.cpp b/backends/cadence/hifi/operators/op_quantize_per_tensor_asym8s.cpp new file mode 100644 index 00000000000..552b6acf150 --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantize_per_tensor_asym8s.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include +#include +#include + +namespace impl { +namespace HiFi { +namespace native { + +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +void quantize_per_tensor_asym8s_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + const float* input_data = input.const_data_ptr(); + size_t numel = out.numel(); + int8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym8s(out_data, input_data, scale, zero_point, numel); +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/op_quantized_add_asym8sxasym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_add_asym8sxasym8s_asym8s_per_tensor_out.cpp index 71ca15636ef..0aed0045b2f 100644 --- a/backends/cadence/hifi/operators/op_quantized_add_asym8sxasym8s_asym8s_per_tensor_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_add_asym8sxasym8s_asym8s_per_tensor_out.cpp @@ -9,15 +9,14 @@ #include #include -namespace cadence { namespace impl { namespace HiFi { namespace native { using ::executorch::aten::Tensor; using ::executorch::runtime::KernelRuntimeContext; -using ::impl::reference::kernels::dequantize; -using ::impl::reference::kernels::quantize; +using ::impl::generic::kernels::dequantize; +using ::impl::generic::kernels::quantize; void quantized_add_asym8sxasym8s_asym8s_per_tensor_out( KernelRuntimeContext& ctx, @@ -169,4 +168,3 @@ void quantized_add_asym8sxasym8s_asym8s_per_tensor_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_add_asym8uxasym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_add_asym8uxasym8u_asym8u_per_tensor_out.cpp index 60f4a98ec76..39f58727702 100644 --- a/backends/cadence/hifi/operators/op_quantized_add_asym8uxasym8u_asym8u_per_tensor_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_add_asym8uxasym8u_asym8u_per_tensor_out.cpp @@ -9,15 +9,14 @@ #include #include -namespace cadence { namespace impl { namespace HiFi { namespace native { using ::executorch::aten::Tensor; using ::executorch::runtime::KernelRuntimeContext; -using ::impl::reference::kernels::dequantize; -using ::impl::reference::kernels::quantize; +using ::impl::generic::kernels::dequantize; +using ::impl::generic::kernels::quantize; void quantized_add_asym8uxasym8u_asym8u_per_tensor_out( KernelRuntimeContext& ctx, @@ -169,4 +168,3 @@ void quantized_add_asym8uxasym8u_asym8u_per_tensor_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git 
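Note the two scale conventions visible above: the backend's quantize helper is handed a precomputed reciprocal (1. / scale) and only multiplies per element, while xa_nn_elm_quantize_f32_asym8s is handed the original scale. A sketch of the helper-style convention for the int16 path (hypothetical name; the real helper also handles vectorized stretches):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Quantize float input to int16 using a precomputed inverse scale:
// round(x * (1/scale) + zero_point), clamped to the int16 range.
inline void quantize_to_int16_sketch(
    int16_t* out_data,
    const float* input_data,
    double scale,
    int32_t zero_point,
    size_t numel) {
  const float inv_scale = static_cast<float>(1.0 / scale);
  for (size_t i = 0; i < numel; ++i) {
    const float q = std::round(input_data[i] * inv_scale + zero_point);
    out_data[i] =
        static_cast<int16_t>(std::max(-32768.0f, std::min(32767.0f, q)));
  }
}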
a/backends/cadence/hifi/operators/op_quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor_out.cpp new file mode 100644 index 00000000000..b5ab0cdbaa2 --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor_out.cpp @@ -0,0 +1,181 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) + +using Tensor = executorch::aten::Tensor; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +using ScalarType = executorch::aten::ScalarType; +using ::executorch::aten::IntArrayRef; + +namespace impl { +namespace HiFi { +namespace native { + +// Optimized NCHW 1D convolution for int8 x int8 -> int8 +void xa_opt_quantized_conv1d_ncl_asym8sxsym8s_asym8s( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + constexpr int kNnlibMaxDim = 3; + + WORD8* __restrict__ p_out = + (WORD8* __restrict__)out.mutable_data_ptr(); + WORD8* __restrict__ p_inp = + (WORD8* __restrict__)input.const_data_ptr(); + WORD8* __restrict__ p_kernel = + (WORD8* __restrict__)weight.const_data_ptr(); + WORD32* __restrict__ p_bias = + (WORD32* __restrict__)bias.const_data_ptr(); + + WORD32 batches = input.size(0); + WORD32 input_channels = input.size(1); + WORD32 input_width = input.size(2); + WORD32 out_channels = weight.size(0); + WORD32 kernel_channels = weight.size(1); + WORD32 kernel_width = weight.size(2); + WORD32 out_width = out.size(2); + WORD32 x_stride = stride[1]; + WORD32 x_padding = padding[1]; + WORD32 input_zero_bias = -in_zero_point; + WORD32 out_multiplier32 = bias_scale * (1. 
/ output_scale) * 2147483648; + WORD32 out_shift32 = 0; + WORD32 kernel_zero_bias = -weight_zero_point; + + WORD32 out_zero_bias = output_zero_point; + WORD32 out_data_format = 1; + WORD8* ptr1 = (WORD8*)kernels::allocate_temp_memory( + ctx, ((batches * input_channels * input_width) + 8) * sizeof(WORD8)); + WORD8* ptr2 = (WORD8*)kernels::allocate_temp_memory( + ctx, + ((out_channels * kernel_channels * kernel_width) + 8) * sizeof(WORD8)); + WORD8* pin = (WORD8*)ALIGN_PTR(ptr1, 8); + WORD8* pkernel = (WORD8*)ALIGN_PTR(ptr2, 8); + + WORD32 p_inp_shape[kNnlibMaxDim]; + p_inp_shape[0] = batches; + p_inp_shape[1] = input_channels; + p_inp_shape[2] = input_width; + + WORD32 p_out_shape[kNnlibMaxDim]; + p_out_shape[0] = batches; + p_out_shape[1] = input_width; + p_out_shape[2] = input_channels; + + WORD32 p_permute_vec[kNnlibMaxDim] = {0, 2, 1}; + + xa_nn_transpose_8_8( + pin, + p_out_shape, + p_inp, + p_inp_shape, + p_permute_vec, + kNnlibMaxDim, + kNnlibMaxDim); + + WORD32 p_inp_shape1[kNnlibMaxDim]; + p_inp_shape1[0] = out_channels; + p_inp_shape1[1] = kernel_channels; + p_inp_shape1[2] = kernel_width; + + WORD32 p_out_shape1[kNnlibMaxDim]; + p_out_shape1[0] = out_channels; + p_out_shape1[1] = kernel_width; + p_out_shape1[2] = kernel_channels; + + xa_nn_transpose_8_8( + pkernel, + p_out_shape1, + p_kernel, + p_inp_shape1, + p_permute_vec, + kNnlibMaxDim, + kNnlibMaxDim); + + WORD32 scratch_size = + xa_nn_conv1d_std_getsize(kernel_width, input_width, input_channels, 8); + scratch_size = scratch_size < 0 ? 0 : scratch_size; + WORD32* ptr_scratch = + (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + pVOID p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + for (int _n = 0; _n < batches; _n++) { + WORD8* in_batch = pin + _n * input_channels * input_width; + WORD8* out_batch = p_out + _n * out_channels * out_width; + + xa_nn_conv1d_std_asym8xasym8( + (UWORD8*)out_batch, + (UWORD8*)in_batch, + (UWORD8*)pkernel, + p_bias, + 1, + input_width, + input_channels, + kernel_width, + out_channels, + x_stride, + x_padding, + out_width, + input_zero_bias, + kernel_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + out_data_format, + p_scratch); + } +} + +void quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + __ET_UNUSED IntArrayRef dilation, + __ET_UNUSED int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + xa_opt_quantized_conv1d_ncl_asym8sxsym8s_asym8s( + ctx, + input, + weight, + bias, + stride, + padding, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/op_quantized_conv1d_ncl_asym8uxsym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv1d_ncl_asym8uxsym8u_asym8u_per_tensor_out.cpp new file mode 100644 index 00000000000..60e700f563b --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv1d_ncl_asym8uxsym8u_asym8u_per_tensor_out.cpp @@ -0,0 +1,181 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) + +using Tensor = executorch::aten::Tensor; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +using ScalarType = executorch::aten::ScalarType; +using ::executorch::aten::IntArrayRef; + +namespace impl { +namespace HiFi { +namespace native { + +// Optimized NCHW 1D convolution for uint8 x uint8 -> uint8 +void xa_opt_quantized_conv1d_ncl_asym8uxsym8u_asym8u( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + constexpr int kNnlibMaxDim = 3; + + UWORD8* __restrict__ p_out = + (UWORD8* __restrict__)out.mutable_data_ptr(); + UWORD8* __restrict__ p_inp = + (UWORD8* __restrict__)input.const_data_ptr(); + UWORD8* __restrict__ p_kernel = + (UWORD8* __restrict__)weight.const_data_ptr(); + WORD32* __restrict__ p_bias = + (WORD32* __restrict__)bias.const_data_ptr(); + + WORD32 batches = input.size(0); + WORD32 input_channels = input.size(1); + WORD32 input_width = input.size(2); + WORD32 out_channels = weight.size(0); + WORD32 kernel_channels = weight.size(1); + WORD32 kernel_width = weight.size(2); + WORD32 out_width = out.size(2); + WORD32 x_stride = stride[1]; + WORD32 x_padding = padding[1]; + WORD32 input_zero_bias = -in_zero_point; + WORD32 out_multiplier32 = bias_scale * (1. / output_scale) * 2147483648; + WORD32 out_shift32 = 0; + WORD32 kernel_zero_bias = -weight_zero_point; + + WORD32 out_zero_bias = output_zero_point; + WORD32 out_data_format = 1; + UWORD8* ptr1 = (UWORD8*)kernels::allocate_temp_memory( + ctx, ((batches * input_channels * input_width) + 8) * sizeof(UWORD8)); + UWORD8* ptr2 = (UWORD8*)kernels::allocate_temp_memory( + ctx, + ((out_channels * kernel_channels * kernel_width) + 8) * sizeof(UWORD8)); + UWORD8* pin = (UWORD8*)ALIGN_PTR(ptr1, 8); + UWORD8* pkernel = (UWORD8*)ALIGN_PTR(ptr2, 8); + + WORD32 p_inp_shape[kNnlibMaxDim]; + p_inp_shape[0] = batches; + p_inp_shape[1] = input_channels; + p_inp_shape[2] = input_width; + + WORD32 p_out_shape[kNnlibMaxDim]; + p_out_shape[0] = batches; + p_out_shape[1] = input_width; + p_out_shape[2] = input_channels; + + WORD32 p_permute_vec[kNnlibMaxDim] = {0, 2, 1}; + + xa_nn_transpose_8_8( + (WORD8*)pin, + p_out_shape, + (WORD8*)p_inp, + p_inp_shape, + p_permute_vec, + kNnlibMaxDim, + kNnlibMaxDim); + + WORD32 p_inp_shape1[kNnlibMaxDim]; + p_inp_shape1[0] = out_channels; + p_inp_shape1[1] = kernel_channels; + p_inp_shape1[2] = kernel_width; + + WORD32 p_out_shape1[kNnlibMaxDim]; + p_out_shape1[0] = out_channels; + p_out_shape1[1] = kernel_width; + p_out_shape1[2] = kernel_channels; + + xa_nn_transpose_8_8( + (WORD8*)pkernel, + p_out_shape1, + (WORD8*)p_kernel, + p_inp_shape1, + p_permute_vec, + kNnlibMaxDim, + kNnlibMaxDim); + + WORD32 scratch_size = + xa_nn_conv1d_std_getsize(kernel_width, input_width, input_channels, 8); + scratch_size = scratch_size < 0 ? 
0 : scratch_size; + WORD32* ptr_scratch = + (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + pVOID p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + for (int _n = 0; _n < batches; _n++) { + UWORD8* in_batch = pin + _n * input_channels * input_width; + UWORD8* out_batch = p_out + _n * out_channels * out_width; + + xa_nn_conv1d_std_asym8uxasym8u( + out_batch, + in_batch, + pkernel, + p_bias, + 1, + input_width, + input_channels, + kernel_width, + out_channels, + x_stride, + x_padding, + out_width, + input_zero_bias, + kernel_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + out_data_format, + p_scratch); + } +} + +void quantized_conv1d_ncl_asym8uxsym8u_asym8u_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + __ET_UNUSED IntArrayRef dilation, + __ET_UNUSED int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + xa_opt_quantized_conv1d_ncl_asym8uxsym8u_asym8u( + ctx, + input, + weight, + bias, + stride, + padding, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/op_quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor_out.cpp new file mode 100644 index 00000000000..c9a3d2b58de --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor_out.cpp @@ -0,0 +1,130 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) + +using Tensor = executorch::aten::Tensor; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +using ScalarType = executorch::aten::ScalarType; +using ::executorch::aten::IntArrayRef; + +namespace impl { +namespace HiFi { +namespace native { + +// Optimized NHWC 1D convolution for int8 x int8 -> int8 +void xa_opt_quantized_conv1d_nlc_asym8sxsym8s_asym8s( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + WORD8* __restrict__ p_out = + (WORD8* __restrict__)out.mutable_data_ptr(); + WORD8* __restrict__ p_inp = + (WORD8* __restrict__)input.const_data_ptr(); + WORD8* __restrict__ p_kernel = + (WORD8* __restrict__)weight.const_data_ptr(); + WORD32* __restrict__ p_bias = + (WORD32* __restrict__)bias.const_data_ptr(); + + WORD32 batches = input.size(0); + WORD32 input_channels = input.size(1); + WORD32 input_width = input.size(2); + WORD32 out_channels = weight.size(0); + WORD32 kernel_width = weight.size(2); + WORD32 out_width = out.size(2); + WORD32 x_stride = stride[1]; + WORD32 x_padding = padding[1]; + WORD32 input_zero_bias = -in_zero_point; + WORD32 out_multiplier32 = bias_scale * (1. 
/ output_scale) * 2147483648; + WORD32 out_shift32 = 0; + WORD32 kernel_zero_bias = -weight_zero_point; + + WORD32 out_zero_bias = output_zero_point; + WORD32 out_data_format = 0; + WORD32 scratch_size = + xa_nn_conv1d_std_getsize(kernel_width, input_width, input_channels, 8); + scratch_size = scratch_size < 0 ? 0 : scratch_size; + WORD32* ptr_scratch = + (WORD32*)::impl::HiFi::kernels::allocate_temp_memory(ctx, scratch_size); + pVOID p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + for (int _n = 0; _n < batches; _n++) { + WORD8* in_batch = p_inp + _n * input_channels * input_width; + WORD8* out_batch = p_out + _n * out_channels * out_width; + + xa_nn_conv1d_std_asym8xasym8( + (UWORD8*)out_batch, + (UWORD8*)in_batch, + (UWORD8*)p_kernel, + p_bias, + 1, + input_width, + input_channels, + kernel_width, + out_channels, + x_stride, + x_padding, + out_width, + input_zero_bias, + kernel_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + out_data_format, + p_scratch); + } +} + +void quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + __ET_UNUSED IntArrayRef dilation, + __ET_UNUSED int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + xa_opt_quantized_conv1d_nlc_asym8sxsym8s_asym8s( + ctx, + input, + weight, + bias, + stride, + padding, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/op_quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor_out.cpp new file mode 100644 index 00000000000..2d7a4cba509 --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor_out.cpp @@ -0,0 +1,130 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include + +#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) + +using Tensor = executorch::aten::Tensor; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +using ScalarType = executorch::aten::ScalarType; +using ::executorch::aten::IntArrayRef; + +namespace impl { +namespace HiFi { +namespace native { + +// Optimized NHWC 1D convolution for uint8 x uint8 -> uint8 +void xa_opt_quantized_conv1d_nlc_asym8uxsym8u_asym8u( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + UWORD8* __restrict__ p_out = + (UWORD8* __restrict__)out.mutable_data_ptr(); + UWORD8* __restrict__ p_inp = + (UWORD8* __restrict__)input.const_data_ptr(); + UWORD8* __restrict__ p_kernel = + (UWORD8* __restrict__)weight.const_data_ptr(); + WORD32* __restrict__ p_bias = + (WORD32* __restrict__)bias.const_data_ptr(); + + WORD32 batches = input.size(0); + WORD32 input_channels = input.size(1); + WORD32 input_width = input.size(2); + WORD32 out_channels = weight.size(0); + WORD32 kernel_width = weight.size(2); + WORD32 out_width = out.size(2); + WORD32 x_stride = stride[1]; + WORD32 x_padding = padding[1]; + WORD32 input_zero_bias = -in_zero_point; + WORD32 out_multiplier32 = bias_scale * (1. / output_scale) * 2147483648; + WORD32 out_shift32 = 0; + WORD32 kernel_zero_bias = -weight_zero_point; + + WORD32 out_zero_bias = output_zero_point; + WORD32 out_data_format = 0; + WORD32 scratch_size = + xa_nn_conv1d_std_getsize(kernel_width, input_width, input_channels, 8); + scratch_size = scratch_size < 0 ? 
0 : scratch_size; + WORD32* ptr_scratch = + (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + pVOID p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + for (int _n = 0; _n < batches; _n++) { + UWORD8* in_batch = p_inp + _n * input_channels * input_width; + UWORD8* out_batch = p_out + _n * out_channels * out_width; + + xa_nn_conv1d_std_asym8uxasym8u( + out_batch, + in_batch, + p_kernel, + p_bias, + 1, + input_width, + input_channels, + kernel_width, + out_channels, + x_stride, + x_padding, + out_width, + input_zero_bias, + kernel_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + out_data_format, + p_scratch); + } +} + +void quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + __ET_UNUSED IntArrayRef dilation, + __ET_UNUSED int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + xa_opt_quantized_conv1d_nlc_asym8uxsym8u_asym8u( + ctx, + input, + weight, + bias, + stride, + padding, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_asym8sxsym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_asym8sxsym8s_asym8s_per_tensor_out.cpp new file mode 100644 index 00000000000..e2584485686 --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_asym8sxsym8s_asym8s_per_tensor_out.cpp @@ -0,0 +1,246 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) + +using Tensor = executorch::aten::Tensor; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +using ScalarType = executorch::aten::ScalarType; +using ::executorch::aten::IntArrayRef; + +namespace impl { +namespace HiFi { +namespace native { + +// Optimized NCHW convolution for int8 x int8 -> int8 +void xa_opt_quantized_conv2d_nchw_asym8sxsym8s_asym8s( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + bool conv1d = input.dim() == 3; + constexpr int kNnlibMaxDim = 4; + + WORD8* __restrict__ p_out = + (WORD8* __restrict__)out.mutable_data_ptr(); + WORD8* __restrict__ p_inp = + (WORD8* __restrict__)input.const_data_ptr(); + WORD8* __restrict__ p_kernel = + (WORD8* __restrict__)weight.const_data_ptr(); + WORD32* __restrict__ p_bias = + (WORD32* __restrict__)bias.const_data_ptr(); + + WORD32 input_height = conv1d ? 1 : input.size(2); + WORD32 input_width = conv1d ? input.size(2) : input.size(3); + WORD32 input_channels = input.size(1); + WORD32 kernel_height = conv1d ? 1 : weight.size(2); + WORD32 kernel_width = conv1d ? 
weight.size(2) : weight.size(3); + WORD32 kernel_channels = weight.size(1); + WORD32 out_channels = weight.size(0); + WORD32 out_height = conv1d ? 1 : out.size(2); + WORD32 out_width = conv1d ? out.size(2) : out.size(3); + WORD32 batches = input.size(0); + + WORD32 x_stride = stride[1]; + WORD32 y_stride = stride[0]; + WORD32 x_padding = padding[1]; + WORD32 y_padding = padding[0]; + WORD32 dilation_width = dilation[1]; + WORD32 dilation_height = dilation[0]; + + WORD32 input_zero_bias = -in_zero_point; + WORD32 kernel_zero_bias = -weight_zero_point; + + WORD32 out_multiplier32[out_channels]; + WORD32 out_shift32[out_channels]; + + float out_scale = 1. / output_scale; + + for (int i = 0; i < out_channels; i++) { + out_multiplier32[i] = bias_scale * out_scale * 2147483648; + out_shift32[i] = 0; + } + + WORD32 out_zero_bias = output_zero_point; + WORD32 inp_precision = 8; + WORD32 kernel_precision = 8; + pVOID p_scratch = nullptr; + WORD32* ptr_scratch; + + WORD32 scratch_size = 0; + + ET_CHECK_MSG(groups == 1, "Only groups=1 supported for regular convolution"); + WORD32 out_data_format = 1; + + WORD8* ptr1 = (WORD8*)kernels::allocate_temp_memory( + ctx, + ((batches * input_channels * input_height * input_width) + 8) * + sizeof(WORD8)); + + WORD8* ptr2 = (WORD8*)kernels::allocate_temp_memory( + ctx, + ((out_channels * kernel_channels * kernel_height * kernel_width) + 8) * + sizeof(WORD8)); + + WORD8* pin = (WORD8*)ALIGN_PTR(ptr1, 8); + WORD8* pkernel = (WORD8*)ALIGN_PTR(ptr2, 8); + + WORD32 p_inp_shape[kNnlibMaxDim]; + p_inp_shape[0] = input.size(0); + p_inp_shape[1] = input_channels; + p_inp_shape[2] = input_height; + p_inp_shape[3] = input_width; + + WORD32 p_out_shape[kNnlibMaxDim]; + p_out_shape[0] = input.size(0); + p_out_shape[1] = input_height; + p_out_shape[2] = input_width; + p_out_shape[3] = input_channels; + + WORD32 p_permute_vec[kNnlibMaxDim] = {0, 2, 3, 1}; + + xa_nn_transpose_8_8( + pin, + p_out_shape, + p_inp, + p_inp_shape, + p_permute_vec, + kNnlibMaxDim, + kNnlibMaxDim); + + WORD32 p_inp_shape1[kNnlibMaxDim]; + p_inp_shape1[0] = out_channels; + p_inp_shape1[1] = kernel_channels; + p_inp_shape1[2] = kernel_height; + p_inp_shape1[3] = kernel_width; + + WORD32 p_out_shape1[kNnlibMaxDim]; + p_out_shape1[0] = out_channels; + p_out_shape1[1] = kernel_height; + p_out_shape1[2] = kernel_width; + p_out_shape1[3] = kernel_channels; + + xa_nn_transpose_8_8( + pkernel, + p_out_shape1, + p_kernel, + p_inp_shape1, + p_permute_vec, + kNnlibMaxDim, + kNnlibMaxDim); + + scratch_size = xa_nn_conv2d_getsize( + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + kernel_channels, + dilation_height, + dilation_width, + y_stride, + y_padding, + x_stride, + x_padding, + out_height, + out_width, + out_channels, + inp_precision, + kernel_precision, + out_data_format); + + scratch_size = scratch_size < 0 ? 
0 : scratch_size; + + ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + + p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + for (int _n = 0; _n < batches; _n++) { + WORD8* in_batch = pin + _n * input_channels * input_height * input_width; + WORD8* out_batch = p_out + _n * out_channels * out_height * out_width; + + xa_nn_conv2d_per_chan_sym8sxasym8s( + out_batch, + in_batch, + pkernel, + p_bias, + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + kernel_channels, + dilation_height, + dilation_width, + out_channels, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + input_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + out_data_format, + p_scratch); + } +} + +void quantized_conv2d_nchw_asym8sxsym8s_asym8s_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + xa_opt_quantized_conv2d_nchw_asym8sxsym8s_asym8s( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_asym8uxsym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_asym8uxsym8u_asym8u_per_tensor_out.cpp new file mode 100644 index 00000000000..8444fef6bd1 --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_asym8uxsym8u_asym8u_per_tensor_out.cpp @@ -0,0 +1,246 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) + +using Tensor = executorch::aten::Tensor; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +using ScalarType = executorch::aten::ScalarType; +using ::executorch::aten::IntArrayRef; + +namespace impl { +namespace HiFi { +namespace native { + +// Optimized NCHW convolution for uint8 x uint8 -> uint8 +void xa_opt_quantized_conv2d_nchw_asym8uxsym8u_asym8u( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + bool conv1d = input.dim() == 3; + constexpr int kNnlibMaxDim = 4; + + UWORD8* __restrict__ p_out = + (UWORD8* __restrict__)out.mutable_data_ptr(); + UWORD8* __restrict__ p_inp = + (UWORD8* __restrict__)input.const_data_ptr(); + UWORD8* __restrict__ p_kernel = + (UWORD8* __restrict__)weight.const_data_ptr(); + WORD32* __restrict__ p_bias = + (WORD32* __restrict__)bias.const_data_ptr(); + + WORD32 input_height = conv1d ? 1 : input.size(2); + WORD32 input_width = conv1d ? 
input.size(2) : input.size(3); + WORD32 input_channels = input.size(1); + WORD32 kernel_height = conv1d ? 1 : weight.size(2); + WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3); + WORD32 kernel_channels = weight.size(1); + WORD32 out_channels = weight.size(0); + WORD32 out_height = conv1d ? 1 : out.size(2); + WORD32 out_width = conv1d ? out.size(2) : out.size(3); + WORD32 batches = input.size(0); + + WORD32 x_stride = stride[1]; + WORD32 y_stride = stride[0]; + WORD32 x_padding = padding[1]; + WORD32 y_padding = padding[0]; + WORD32 dilation_width = dilation[1]; + WORD32 dilation_height = dilation[0]; + + WORD32 input_zero_bias = -in_zero_point; + WORD32 kernel_zero_bias = -weight_zero_point; + + WORD32 out_multiplier32[out_channels]; + WORD32 out_shift32[out_channels]; + + float out_scale = 1. / output_scale; + + for (int i = 0; i < out_channels; i++) { + out_multiplier32[i] = bias_scale * out_scale * 2147483648; + out_shift32[i] = 0; + } + + WORD32 out_zero_bias = output_zero_point; + WORD32 inp_precision = 8; + WORD32 kernel_precision = 8; + pVOID p_scratch = nullptr; + WORD32* ptr_scratch; + + WORD32 scratch_size = 0; + + ET_CHECK_MSG(groups == 1, "Only groups=1 supported for regular convolution"); + WORD32 out_data_format = 1; + + UWORD8* ptr1 = (UWORD8*)kernels::allocate_temp_memory( + ctx, + ((batches * input_channels * input_height * input_width) + 8) * + sizeof(UWORD8)); + + UWORD8* ptr2 = (UWORD8*)kernels::allocate_temp_memory( + ctx, + ((out_channels * kernel_channels * kernel_height * kernel_width) + 8) * + sizeof(UWORD8)); + + UWORD8* pin = (UWORD8*)ALIGN_PTR(ptr1, 8); + UWORD8* pkernel = (UWORD8*)ALIGN_PTR(ptr2, 8); + + WORD32 p_inp_shape[kNnlibMaxDim]; + p_inp_shape[0] = input.size(0); + p_inp_shape[1] = input_channels; + p_inp_shape[2] = input_height; + p_inp_shape[3] = input_width; + + WORD32 p_out_shape[kNnlibMaxDim]; + p_out_shape[0] = input.size(0); + p_out_shape[1] = input_height; + p_out_shape[2] = input_width; + p_out_shape[3] = input_channels; + + WORD32 p_permute_vec[kNnlibMaxDim] = {0, 2, 3, 1}; + + xa_nn_transpose_8_8( + (WORD8*)pin, + p_out_shape, + (WORD8*)p_inp, + p_inp_shape, + p_permute_vec, + kNnlibMaxDim, + kNnlibMaxDim); + + WORD32 p_inp_shape1[kNnlibMaxDim]; + p_inp_shape1[0] = out_channels; + p_inp_shape1[1] = kernel_channels; + p_inp_shape1[2] = kernel_height; + p_inp_shape1[3] = kernel_width; + + WORD32 p_out_shape1[kNnlibMaxDim]; + p_out_shape1[0] = out_channels; + p_out_shape1[1] = kernel_height; + p_out_shape1[2] = kernel_width; + p_out_shape1[3] = kernel_channels; + + xa_nn_transpose_8_8( + (WORD8*)pkernel, + p_out_shape1, + (WORD8*)p_kernel, + p_inp_shape1, + p_permute_vec, + kNnlibMaxDim, + kNnlibMaxDim); + + scratch_size = xa_nn_conv2d_getsize( + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + kernel_channels, + dilation_height, + dilation_width, + y_stride, + y_padding, + x_stride, + x_padding, + out_height, + out_width, + out_channels, + inp_precision, + kernel_precision, + out_data_format); + + scratch_size = scratch_size < 0 ? 
0 : scratch_size; + + ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + + p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + for (int _n = 0; _n < batches; _n++) { + UWORD8* in_batch = pin + _n * input_channels * input_height * input_width; + UWORD8* out_batch = p_out + _n * out_channels * out_height * out_width; + + xa_nn_conv2d_per_chan_sym8sxasym8s( + (WORD8*)out_batch, + (WORD8*)in_batch, + (WORD8*)pkernel, + p_bias, + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + kernel_channels, + dilation_height, + dilation_width, + out_channels, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + input_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + out_data_format, + p_scratch); + } +} + +void quantized_conv2d_nchw_asym8uxsym8u_asym8u_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + xa_opt_quantized_conv2d_nchw_asym8uxsym8u_asym8u( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out.cpp new file mode 100644 index 00000000000..787984e52db --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out.cpp @@ -0,0 +1,201 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) + +using Tensor = executorch::aten::Tensor; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +using ScalarType = executorch::aten::ScalarType; +using ::executorch::aten::IntArrayRef; + +namespace impl { +namespace HiFi { +namespace native { + +// Specialized depthwise NCHW convolution for int8 x int8 -> int8 +void xa_opt_quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + bool conv1d = input.dim() == 3; + constexpr int kNnlibMaxDim = 4; + + WORD8* __restrict__ p_out = + (WORD8* __restrict__)out.mutable_data_ptr(); + WORD8* __restrict__ p_inp = + (WORD8* __restrict__)input.const_data_ptr(); + WORD8* __restrict__ p_kernel = + (WORD8* __restrict__)weight.const_data_ptr(); + WORD32* __restrict__ p_bias = + (WORD32* __restrict__)bias.const_data_ptr(); + + WORD32 input_height = conv1d ? 1 : input.size(2); + WORD32 input_width = conv1d ? 
input.size(2) : input.size(3); + WORD32 input_channels = input.size(1); + WORD32 kernel_height = conv1d ? 1 : weight.size(2); + WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3); + WORD32 out_channels = weight.size(0); + WORD32 out_height = conv1d ? 1 : out.size(2); + WORD32 out_width = conv1d ? out.size(2) : out.size(3); + WORD32 batches = input.size(0); + + WORD32 x_stride = stride[1]; + WORD32 y_stride = stride[0]; + WORD32 x_padding = padding[1]; + WORD32 y_padding = padding[0]; + + WORD32 input_zero_bias = -in_zero_point; + WORD32 out_zero_bias = output_zero_point; + WORD32 inp_precision = 8; + + WORD32 channels_multiplier = out_channels / input_channels; + + WORD32 out_multiplier32[out_channels]; + WORD32 out_shift32[out_channels]; + + float out_scale = 1. / output_scale; + + for (int i = 0; i < out_channels; i++) { + out_multiplier32[i] = bias_scale * out_scale * 2147483648; + out_shift32[i] = 0; + } + + WORD32 scratch_size = xa_nn_conv2d_depthwise_getsize( + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + channels_multiplier, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + inp_precision, + 1); // NCHW + + scratch_size = scratch_size < 0 ? 0 : scratch_size; + + WORD32* ptr_scratch = + (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + pVOID p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + WORD8* ptr1 = (WORD8*)kernels::allocate_temp_memory( + ctx, + ((batches * out_channels * out_height * out_width) + 8) * sizeof(WORD8)); + + WORD8* p_out_temp = (WORD8*)ALIGN_PTR(ptr1, 8); + + for (int _n = 0; _n < batches; _n++) { + WORD8* in_batch = p_inp + _n * input_channels * input_height * input_width; + WORD8* out_batch = p_out_temp + _n * out_channels * out_height * out_width; + + xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s( + out_batch, + p_kernel, + in_batch, + p_bias, + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + channels_multiplier, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + input_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + 1, // NCHW + 0, // NHWC + p_scratch); + } + + WORD32 p_inp_shape[kNnlibMaxDim]; + p_inp_shape[0] = batches; + p_inp_shape[1] = out_height; + p_inp_shape[2] = out_width; + p_inp_shape[3] = out_channels; + + WORD32 p_out_shape[kNnlibMaxDim]; + p_out_shape[0] = batches; + p_out_shape[1] = out_channels; + p_out_shape[2] = out_height; + p_out_shape[3] = out_width; + + WORD32 p_permute_vec[kNnlibMaxDim] = {0, 3, 1, 2}; + + xa_nn_transpose_8_8( + p_out, + p_out_shape, + p_out_temp, + p_inp_shape, + p_permute_vec, + kNnlibMaxDim, + kNnlibMaxDim); +} + +void quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + xa_opt_quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git 
a/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out.cpp new file mode 100644 index 00000000000..219eaf44ad7 --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out.cpp @@ -0,0 +1,201 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) + +using Tensor = executorch::aten::Tensor; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +using ScalarType = executorch::aten::ScalarType; +using ::executorch::aten::IntArrayRef; + +namespace impl { +namespace HiFi { +namespace native { + +// Specialized depthwise NCHW convolution for uint8 x uint8 -> uint8 +void xa_opt_quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + bool conv1d = input.dim() == 3; + constexpr int kNnlibMaxDim = 4; + + UWORD8* __restrict__ p_out = + (UWORD8* __restrict__)out.mutable_data_ptr(); + UWORD8* __restrict__ p_inp = + (UWORD8* __restrict__)input.const_data_ptr(); + UWORD8* __restrict__ p_kernel = + (UWORD8* __restrict__)weight.const_data_ptr(); + WORD32* __restrict__ p_bias = + (WORD32* __restrict__)bias.const_data_ptr(); + + WORD32 input_height = conv1d ? 1 : input.size(2); + WORD32 input_width = conv1d ? input.size(2) : input.size(3); + WORD32 input_channels = input.size(1); + WORD32 kernel_height = conv1d ? 1 : weight.size(2); + WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3); + WORD32 out_channels = weight.size(0); + WORD32 out_height = conv1d ? 1 : out.size(2); + WORD32 out_width = conv1d ? out.size(2) : out.size(3); + WORD32 batches = input.size(0); + + WORD32 x_stride = stride[1]; + WORD32 y_stride = stride[0]; + WORD32 x_padding = padding[1]; + WORD32 y_padding = padding[0]; + + WORD32 input_zero_bias = -in_zero_point; + WORD32 out_zero_bias = output_zero_point; + WORD32 inp_precision = 8; + + WORD32 channels_multiplier = out_channels / input_channels; + + WORD32 out_multiplier32[out_channels]; + WORD32 out_shift32[out_channels]; + + float out_scale = 1. / output_scale; + + for (int i = 0; i < out_channels; i++) { + out_multiplier32[i] = bias_scale * out_scale * 2147483648; + out_shift32[i] = 0; + } + + WORD32 scratch_size = xa_nn_conv2d_depthwise_getsize( + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + channels_multiplier, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + inp_precision, + 1); // NCHW + + scratch_size = scratch_size < 0 ? 
0 : scratch_size; + + WORD32* ptr_scratch = + (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + pVOID p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + UWORD8* ptr1 = (UWORD8*)kernels::allocate_temp_memory( + ctx, + ((batches * out_channels * out_height * out_width) + 8) * sizeof(UWORD8)); + + UWORD8* p_out_temp = (UWORD8*)ALIGN_PTR(ptr1, 8); + + for (int _n = 0; _n < batches; _n++) { + UWORD8* in_batch = p_inp + _n * input_channels * input_height * input_width; + UWORD8* out_batch = p_out_temp + _n * out_channels * out_height * out_width; + + xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s( + (WORD8*)out_batch, + (WORD8*)p_kernel, + (WORD8*)in_batch, + p_bias, + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + channels_multiplier, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + input_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + 1, // NCHW + 0, // NHWC + p_scratch); + } + + WORD32 p_inp_shape[kNnlibMaxDim]; + p_inp_shape[0] = batches; + p_inp_shape[1] = out_height; + p_inp_shape[2] = out_width; + p_inp_shape[3] = out_channels; + + WORD32 p_out_shape[kNnlibMaxDim]; + p_out_shape[0] = batches; + p_out_shape[1] = out_channels; + p_out_shape[2] = out_height; + p_out_shape[3] = out_width; + + WORD32 p_permute_vec[kNnlibMaxDim] = {0, 3, 1, 2}; + + xa_nn_transpose_8_8( + (WORD8*)p_out, + p_out_shape, + (WORD8*)p_out_temp, + p_inp_shape, + p_permute_vec, + kNnlibMaxDim, + kNnlibMaxDim); +} + +void quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + xa_opt_quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out.cpp new file mode 100644 index 00000000000..fc279f2bbdf --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out.cpp @@ -0,0 +1,188 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +using Tensor = executorch::aten::Tensor; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +using ScalarType = executorch::aten::ScalarType; +using ::executorch::aten::IntArrayRef; + +namespace impl { +namespace HiFi { +namespace native { + +// Dilated fallback implementation for int8 x int8 -> int8 quantized 2d conv +// kernel for NCHW layout. This variant is optimized for asymmetric int8 inputs, +// weights, and outputs. 
The input is of shape [n x c x h x w] The weight is of +// shape [oc x wc x wh x ww], where wc == c The output is of shape [n x oc x oh +// x ow] The bias is of shape [oc] +template +__attribute__((noinline)) void conv2d_nchw_dilated_asym8sxsym8s_asym8s_core( + // All the arrays + const int8_t* __restrict__ p_in, + const int8_t* __restrict__ p_weight, + const int32_t* __restrict__ p_bias, + int8_t* __restrict__ p_out, + // The array sizes + int32_t n, + int32_t c, + int32_t h, + int32_t w, + int32_t oc, + int32_t wc, + int32_t wh, + int32_t ww, + int32_t oh, + int32_t ow, + // Stride + int16_t s0, + int16_t s1, + // Padding + int16_t p0, + int16_t p1, + // Dilation + int16_t d0, + int16_t d1, + // Group for depthwise conv + int16_t groups, + // Quantization parameters + int8_t in_zero_point = 0, + int32_t weight_zero_point = 0, + float bias_scale = 1, + float out_scale = 1, + int8_t out_zero_point = 0) { + float inv_out_scale = 1. / out_scale; + + // Compute the number of in and out channels per group + const int ocpg = oc / groups; + const int icpg = c / groups; + + // Iterate over all the output batches (i.e., n) + for (int _n = 0; _n < n; ++_n) { + const int8_t* in_batch = p_in + _n * c * h * w; + int8_t* out_batch = p_out + _n * oc * oh * ow; + // Compute separable convolution for each group + for (int _g = 0; _g < groups; ++_g) { + // Identify the input and output channels involved in the computation + // of this group + int sic = _g * icpg; + int soc = _g * ocpg; + // Populate all the output channels in the group + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + int8_t* out_plane = out_batch + _oc * oh * ow; + const int8_t* weight_batch = p_weight + _oc * wc * wh * ww; + // We compute one output channel at a time. The computation can be + // thought of as a stencil computation: we iterate over an input of size + // icpg x h x w, with a stencil of size icpg x wh x ww, to compute an + // output channel of size 1 x oh x ow. + for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { + for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { + float acc = p_bias[_oc]; + // Below is the stencil computation that performs the hadamard + // product+accumulation of each input channel (contributing to the + // output channel being computed) with the corresponding weight + // channel. 
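+ // In plain terms, for one output element the nested loops below compute
+ //   acc = bias[_oc]
+ //       + sum over (_ic, _wh, _ww) of
+ //           (in[_ic][_h + d0*_wh - p0][_w + d1*_ww - p1] - in_zero_point)
+ //         * (weight[_oc][_ic - sic][_wh][_ww] - weight_zero_point)
+ // with taps that fall outside the input skipped, and then requantize:
+ //   out[_oc][_oh][_ow] = quantize(bias_scale * acc, 1/out_scale, out_zero_point).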
+ // General path for dilated convolutions with padding support + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const int8_t* in_plane = in_batch + _ic * h * w; + const int8_t* weight_plane = weight_batch + (_ic - sic) * wh * ww; + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + int input_h = _h + d0 * _wh - p0; + int input_w = _w + d1 * _ww - p1; + if ((input_h >= 0) && (input_h < h) && (input_w >= 0) && + (input_w < w)) { + int ioff = input_h * w + input_w; + int woff = _wh * ww + _ww; + float lhs = static_cast(in_plane[ioff]) - + static_cast(in_zero_point); + float rhs = static_cast(weight_plane[woff]) - + static_cast(weight_zero_point); + acc += lhs * rhs; + } + } + } + } + // Quantize the accumulated result + float val = bias_scale * acc; + out_plane[_oh * ow + _ow] = + kernels::quantize(val, inv_out_scale, out_zero_point); + } + } + } + } + } +} + +void quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + bool conv1d = input.dim() == 3; + // input = [n, c, h, w] + const int n = input.size(0); + const int c = input.size(1); + const int h = conv1d ? 1 : input.size(2); + const int w = conv1d ? input.size(2) : input.size(3); + // weight = [oc, wc, wh, ww] + const int oc = weight.size(0); + const int wc = weight.size(1); + const int wh = conv1d ? 1 : weight.size(2); + const int ww = conv1d ? weight.size(2) : weight.size(3); + // output = [n, oc, oh, ow] + const int oh = conv1d ? 1 : out.size(2); + const int ow = conv1d ? out.size(2) : out.size(3); + + conv2d_nchw_dilated_asym8sxsym8s_asym8s_core( + input.const_data_ptr(), + weight.const_data_ptr(), + bias.const_data_ptr(), + out.mutable_data_ptr(), + n, + c, + h, + w, + oc, + wc, + wh, + ww, + oh, + ow, + stride[0], + stride[1], + padding[0], + padding[1], + dilation[0], + dilation[1], + groups, + static_cast(in_zero_point), + weight_zero_point, + bias_scale, + output_scale, + static_cast(output_zero_point)); +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out.cpp new file mode 100644 index 00000000000..08ca4657c75 --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out.cpp @@ -0,0 +1,189 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +using Tensor = executorch::aten::Tensor; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +using ScalarType = executorch::aten::ScalarType; +using ::executorch::aten::IntArrayRef; + +namespace impl { +namespace HiFi { +namespace native { + +// Dilated fallback implementation for uint8 x uint8 -> uint8 quantized 2d conv +// kernel for NCHW layout. This variant is optimized for asymmetric uint8 +// inputs, weights, and outputs. 
The input is of shape [n x c x h x w] The +// weight is of shape [oc x wc x wh x ww], where wc == c The output is of shape +// [n x oc x oh x ow] The bias is of shape [oc] +template +__attribute__((noinline)) void conv2d_nchw_dilated_asym8uxsym8u_asym8u_core( + // All the arrays + const uint8_t* __restrict__ p_in, + const uint8_t* __restrict__ p_weight, + const int32_t* __restrict__ p_bias, + uint8_t* __restrict__ p_out, + // The array sizes + int32_t n, + int32_t c, + int32_t h, + int32_t w, + int32_t oc, + int32_t wc, + int32_t wh, + int32_t ww, + int32_t oh, + int32_t ow, + // Stride + int16_t s0, + int16_t s1, + // Padding + int16_t p0, + int16_t p1, + // Dilation + int16_t d0, + int16_t d1, + // Group for depthwise conv + int16_t groups, + // Quantization parameters + uint8_t in_zero_point = 0, + int32_t weight_zero_point = 0, + float bias_scale = 1, + float out_scale = 1, + uint8_t out_zero_point = 0) { + float inv_out_scale = 1. / out_scale; + + // Compute the number of in and out channels per group + const int ocpg = oc / groups; + const int icpg = c / groups; + + // Iterate over all the output batches (i.e., n) + for (int _n = 0; _n < n; ++_n) { + const uint8_t* in_batch = p_in + _n * c * h * w; + uint8_t* out_batch = p_out + _n * oc * oh * ow; + // Compute separable convolution for each group + for (int _g = 0; _g < groups; ++_g) { + // Identify the input and output channels involved in the computation + // of this group + int sic = _g * icpg; + int soc = _g * ocpg; + // Populate all the output channels in the group + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + uint8_t* out_plane = out_batch + _oc * oh * ow; + const uint8_t* weight_batch = p_weight + _oc * wc * wh * ww; + // We compute one output channel at a time. The computation can be + // thought of as a stencil computation: we iterate over an input of size + // icpg x h x w, with a stencil of size icpg x wh x ww, to compute an + // output channel of size 1 x oh x ow. + for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { + for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { + float acc = p_bias[_oc]; + // Below is the stencil computation that performs the hadamard + // product+accumulation of each input channel (contributing to the + // output channel being computed) with the corresponding weight + // channel. 
+ // General path for dilated convolutions with padding support + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const uint8_t* in_plane = in_batch + _ic * h * w; + const uint8_t* weight_plane = + weight_batch + (_ic - sic) * wh * ww; + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + int input_h = _h + d0 * _wh - p0; + int input_w = _w + d1 * _ww - p1; + if ((input_h >= 0) && (input_h < h) && (input_w >= 0) && + (input_w < w)) { + int ioff = input_h * w + input_w; + int woff = _wh * ww + _ww; + float lhs = static_cast(in_plane[ioff]) - + static_cast(in_zero_point); + float rhs = static_cast(weight_plane[woff]) - + static_cast(weight_zero_point); + acc += lhs * rhs; + } + } + } + } + // Quantize the accumulated result + float val = bias_scale * acc; + out_plane[_oh * ow + _ow] = + kernels::quantize(val, inv_out_scale, out_zero_point); + } + } + } + } + } +} + +void quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + bool conv1d = input.dim() == 3; + // input = [n, c, h, w] + const int n = input.size(0); + const int c = input.size(1); + const int h = conv1d ? 1 : input.size(2); + const int w = conv1d ? input.size(2) : input.size(3); + // weight = [oc, wc, wh, ww] + const int oc = weight.size(0); + const int wc = weight.size(1); + const int wh = conv1d ? 1 : weight.size(2); + const int ww = conv1d ? weight.size(2) : weight.size(3); + // output = [n, oc, oh, ow] + const int oh = conv1d ? 1 : out.size(2); + const int ow = conv1d ? out.size(2) : out.size(3); + + conv2d_nchw_dilated_asym8uxsym8u_asym8u_core( + input.const_data_ptr(), + weight.const_data_ptr(), + bias.const_data_ptr(), + out.mutable_data_ptr(), + n, + c, + h, + w, + oc, + wc, + wh, + ww, + oh, + ow, + stride[0], + stride[1], + padding[0], + padding[1], + dilation[0], + dilation[1], + groups, + static_cast(in_zero_point), + weight_zero_point, + bias_scale, + output_scale, + static_cast(output_zero_point)); +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_out.cpp new file mode 100644 index 00000000000..a17f1e6a3bc --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv2d_nchw_out.cpp @@ -0,0 +1,693 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) + +namespace impl { +namespace HiFi { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::impl::HiFi::kernels::quantize; +using ::torch::executor::KernelRuntimeContext; + +// This implements a generic 2d conv kernel that operates on raw pointers. +// The version handles both quantized and fp32 convolutions. 
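+// The kernel reads oh/ow from the output tensor; for the loop bounds below to be
+// consistent these are expected to follow the standard convolution size relation,
+//   oh = (h + 2*p0 - d0*(wh - 1) - 1) / s0 + 1   (integer division),
+// and similarly for ow with p1, d1, ww, s1.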
+// The input is of shape [n x c x h x w] +// The weight is of shape [oc x wc x wh x ww], where wc == c +// The output is of shape [n x oc x oh x ow] +// The bias is of shape [oc] +template < + typename IT = float, + typename WT = IT, + typename BT = IT, + typename OT = IT, + bool quantized = false> +__attribute__((noinline)) void conv2d_nchw_core_generic( + // All the arrays + const IT* __restrict__ p_in, + const WT* __restrict__ p_weight, + const BT* __restrict__ p_bias, + OT* __restrict__ p_out, + // The array sizes + int32_t n, + int32_t c, + int32_t h, + int32_t w, + int32_t oc, + int32_t wc, + int32_t wh, + int32_t ww, + int32_t oh, + int32_t ow, + // Stride + int16_t s0, + int16_t s1, + // Padding + int16_t p0, + int16_t p1, + // Dilation + int16_t d0, + int16_t d1, + // Group for depthwise conv + int16_t groups, + // Optional args that are only relevant for quantized convolution + // input zero point + IT in_zero_point = 0, + // weight zero point + int32_t weight_zero_point = 0, + float bias_scale = 1, + float out_scale = 1, + OT out_zero_point = 0) { + float inv_out_scale = 1. / out_scale; + bool zero_pad_unit_dilation = d0 == 1 && d1 == 1 && p0 == 0 && p1 == 0; + + // Compute the number of in and out channels per group + const int ocpg = oc / groups; + const int icpg = c / groups; + + // Iterate over all the output batches (i.e., n) + for (int _n = 0; _n < n; ++_n) { + const IT* in_batch = p_in + _n * c * h * w; + OT* out_batch = p_out + _n * oc * oh * ow; + // Compute separable convolution for each group + for (int _g = 0; _g < groups; ++_g) { + // Identify the input and output channels involved in the computation + // of this group + int sic = _g * icpg; + int soc = _g * ocpg; + // Populate all the output channels in the group + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + OT* out_plane = out_batch + _oc * oh * ow; + const WT* weight_batch = p_weight + _oc * wc * wh * ww; + // We compute one output channel at a time. The computation can be + // thought of as a stencil computation: we iterate over an input of size + // icpg x h x w, with a stencil of size icpg x wh x ww, to compute an + // output channel of size 1 x oh x ow. + for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { + for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { + float acc = p_bias[_oc]; + // Below is the stencil computation that performs the hadamard + // product+accumulation of each input channel (contributing to the + // output channel being computed) with the corresponding weight + // channel. + // If the padding is 0, and dilation is 1, then we can remove the + // unnecessary checks, and simplify the code so that it can be + // vectorized by Tensilica compiler. + if (zero_pad_unit_dilation) { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const IT* in_plane = in_batch + _ic * h * w; + const WT* weight_plane = weight_batch + (_ic - sic) * wh * ww; + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + int ioff = (_h + _wh) * w + (_w + _ww); + int woff = _wh * ww + _ww; + float lhs = in_plane[ioff] - in_zero_point; + float rhs = weight_plane[woff] - + (quantized ? 
weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } else { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const IT* in_plane = in_batch + _ic * h * w; + const WT* weight_plane = weight_batch + (_ic - sic) * wh * ww; + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + if (((_h + d0 * _wh - p0) >= 0) && + ((_h + d0 * _wh - p0) < h) && + ((_w + d1 * _ww - p1) >= 0) && + ((_w + d1 * _ww - p1) < w)) { + int ioff = + (_h + d0 * _wh - p0) * w + (_w + d1 * _ww - p1); + int woff = _wh * ww + _ww; + float lhs = in_plane[ioff] - in_zero_point; + float rhs = weight_plane[woff] - + (quantized ? weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } + } + if (quantized) { + float val = bias_scale * acc; + out_plane[_oh * ow + _ow] = + quantize(val, inv_out_scale, out_zero_point); + } else { + out_plane[_oh * ow + _ow] = acc; + } + } + } + } + } + } +} + +void xa_opt_quantized_conv2d_nchw( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + bool conv1d = input.dim() == 3; + constexpr int kNnlibMaxDim = 4; + + if (input.scalar_type() == ScalarType::Char) { + WORD8* __restrict__ p_out = + (WORD8* __restrict__)out.mutable_data_ptr(); + WORD8* __restrict__ p_inp = + (WORD8* __restrict__)input.const_data_ptr(); + WORD8* __restrict__ p_kernel = + (WORD8* __restrict__)weight.const_data_ptr(); + WORD32* __restrict__ p_bias = + (WORD32* __restrict__)bias.const_data_ptr(); + + WORD32 input_height = conv1d ? 1 : input.size(2); + WORD32 input_width = conv1d ? input.size(2) : input.size(3); + WORD32 input_channels = input.size(1); + WORD32 kernel_height = conv1d ? 1 : weight.size(2); + WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3); + WORD32 kernel_channels = weight.size(1); + WORD32 out_channels = weight.size(0); + WORD32 out_height = conv1d ? 1 : out.size(2); + WORD32 out_width = conv1d ? out.size(2) : out.size(3); + WORD32 batches = input.size(0); + + WORD32 x_stride = stride[1]; + WORD32 y_stride = stride[0]; + WORD32 x_padding = padding[1]; + WORD32 y_padding = padding[0]; + WORD32 dilation_width = dilation[1]; + WORD32 dilation_height = dilation[0]; + + // WORD32* kernel_bias_ptr = + // (WORD32*)weight_zero_point.const_data_ptr(); + + WORD32 input_zero_bias = -in_zero_point; + WORD32 kernel_zero_bias = -weight_zero_point; + + WORD32 out_multiplier32[out_channels]; + WORD32 out_shift32[out_channels]; + + float out_scale = 1. 
/ output_scale; + + for (int i = 0; i < out_channels; i++) { + out_multiplier32[i] = bias_scale * out_scale * 2147483648; + out_shift32[i] = 0; + } + + WORD32 out_zero_bias = output_zero_point; + WORD32 inp_precision = 8; + WORD32 kernel_precision = 8; + pVOID p_scratch = nullptr; + WORD32* ptr_scratch; + + WORD32 scratch_size = 0; + + if (groups == 1) { + WORD32 out_data_format = 1; + + WORD8* ptr1 = (WORD8*)kernels::allocate_temp_memory( + ctx, + ((batches * input_channels * input_height * input_width) + 8) * + sizeof(WORD8)); + + WORD8* ptr2 = (WORD8*)kernels::allocate_temp_memory( + ctx, + ((out_channels * kernel_channels * kernel_height * kernel_width) + + 8) * + sizeof(WORD8)); + + WORD8* pin = (WORD8*)ALIGN_PTR(ptr1, 8); + WORD8* pkernel = (WORD8*)ALIGN_PTR(ptr2, 8); + + WORD32 p_inp_shape[kNnlibMaxDim]; + p_inp_shape[0] = input.size(0); + p_inp_shape[1] = input_channels; + p_inp_shape[2] = input_height; + p_inp_shape[3] = input_width; + + WORD32 p_out_shape[kNnlibMaxDim]; + p_out_shape[0] = input.size(0); + p_out_shape[1] = input_height; + p_out_shape[2] = input_width; + p_out_shape[3] = input_channels; + + WORD32 p_permute_vec[kNnlibMaxDim] = {0, 2, 3, 1}; + + xa_nn_transpose_8_8( + pin, + p_out_shape, + p_inp, + p_inp_shape, + p_permute_vec, + kNnlibMaxDim, // input dimensions + kNnlibMaxDim); // output dimensions + + WORD32 p_inp_shape1[kNnlibMaxDim]; + p_inp_shape1[0] = out_channels; + p_inp_shape1[1] = kernel_channels; + p_inp_shape1[2] = kernel_height; + p_inp_shape1[3] = kernel_width; + + WORD32 p_out_shape1[kNnlibMaxDim]; + p_out_shape1[0] = out_channels; + p_out_shape1[1] = kernel_height; + p_out_shape1[2] = kernel_width; + p_out_shape1[3] = kernel_channels; + + xa_nn_transpose_8_8( + pkernel, + p_out_shape1, + p_kernel, + p_inp_shape1, + p_permute_vec, + kNnlibMaxDim, // input dimensions + kNnlibMaxDim); // output dimensions + + scratch_size = xa_nn_conv2d_getsize( + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + kernel_channels, + dilation_height, + dilation_width, + y_stride, + y_padding, + x_stride, + x_padding, + out_height, + out_width, + out_channels, + inp_precision, + kernel_precision, + out_data_format); + + scratch_size = scratch_size < 0 ? 0 : scratch_size; + + ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + + p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + for (int _n = 0; _n < batches; _n++) { + WORD8* in_batch = + pin + _n * input_channels * input_height * input_width; + WORD8* out_batch = p_out + _n * out_channels * out_height * out_width; + + xa_nn_conv2d_per_chan_sym8sxasym8s( + out_batch, + in_batch, + pkernel, + p_bias, + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + kernel_channels, + dilation_height, + dilation_width, + out_channels, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + input_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + out_data_format, + p_scratch); + } + return; + } + + if (groups == input_channels) { + WORD32 channels_multiplier = out_channels / input_channels; + + scratch_size = xa_nn_conv2d_depthwise_getsize( + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + channels_multiplier, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + inp_precision, + 1); // NCHW + + scratch_size = scratch_size < 0 ? 
0 : scratch_size; + + ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + + p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + WORD8* ptr1 = (WORD8*)kernels::allocate_temp_memory( + ctx, + ((batches * out_channels * out_height * out_width) + 8) * + sizeof(WORD8)); + + WORD8* p_out_temp = (WORD8*)ALIGN_PTR(ptr1, 8); + + for (int _n = 0; _n < batches; _n++) { + WORD8* in_batch = + p_inp + _n * input_channels * input_height * input_width; + WORD8* out_batch = + p_out_temp + _n * out_channels * out_height * out_width; + + xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s( + out_batch, + p_kernel, + in_batch, + p_bias, + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + channels_multiplier, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + input_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + 1, // NCHW + 0, // NHWC + p_scratch); + } + + WORD32 p_inp_shape[kNnlibMaxDim]; + p_inp_shape[0] = batches; + p_inp_shape[1] = out_height; + p_inp_shape[2] = out_width; + p_inp_shape[3] = out_channels; + + WORD32 p_out_shape[kNnlibMaxDim]; + p_out_shape[0] = batches; + p_out_shape[1] = out_channels; + p_out_shape[2] = out_height; + p_out_shape[3] = out_width; + + WORD32 p_permute_vec[kNnlibMaxDim] = {0, 3, 1, 2}; + + xa_nn_transpose_8_8( + p_out, + p_out_shape, + p_out_temp, + p_inp_shape, + p_permute_vec, + kNnlibMaxDim, // input dimensions + kNnlibMaxDim); // output dimensions + + return; + } + } +} + +// The quantized convolution kernel. in_scale and weight_scale are implicit in +// bias_scale, since it is a product of the two. The kernel will branch to +// quantized::conv1d or quantized::conv2d based on the dimensionality of +// activation tensor. +void quantized_conv2d_nchw( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + bool conv1d = input.dim() == 3; + // input = [n, c, h, w] + const int n = input.size(0); + const int c = input.size(1); + const int h = conv1d ? 1 : input.size(2); + const int w = conv1d ? input.size(2) : input.size(3); + // weight = [oc, wc, wh, ww] + const int oc = weight.size(0); + const int wc = weight.size(1); + const int wh = conv1d ? 1 : weight.size(2); + const int ww = conv1d ? weight.size(2) : weight.size(3); + // output = [n, oc, oh, ow] + const int oh = conv1d ? 1 : out.size(2); + const int ow = conv1d ? 
out.size(2) : out.size(3); + +#define typed_quantized_conv2d_nchw(ctype, dtype) \ + case ScalarType::dtype: { \ + conv2d_nchw_core_generic( \ + input.const_data_ptr(), \ + weight.const_data_ptr(), \ + bias.const_data_ptr(), \ + out.mutable_data_ptr(), \ + n, \ + c, \ + h, \ + w, \ + oc, \ + wc, \ + wh, \ + ww, \ + oh, \ + ow, \ + stride[0], \ + stride[1], \ + padding[0], \ + padding[1], \ + dilation[0], \ + dilation[1], \ + groups, \ + in_zero_point, \ + weight_zero_point, \ + bias_scale, \ + output_scale, \ + (ctype)output_zero_point); \ + break; \ + } + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_conv2d_nchw); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_conv2d_nchw +} + +void quantized_conv2d_nchw_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED const Tensor& out_multiplier, + __ET_UNUSED const Tensor& out_shift, + Tensor& out) { + // Handle W8A16 heterogeneous type (int16_t activations, int8_t weights) + if (out.scalar_type() == ::executorch::aten::ScalarType::Short && + input.scalar_type() == ::executorch::aten::ScalarType::Short && + weight.scalar_type() == ::executorch::aten::ScalarType::Char) { + ::impl::generic::native::quantized_conv2d_nchw_out( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + out); + return; + } + + const float bias_scale_float = bias_scale.const_data_ptr()[0]; + const int32_t weight_zero_point_int = + weight_zero_point.const_data_ptr()[0]; + + bool optimized = 0; + + if ((input.scalar_type() == ScalarType::Char) || + (input.scalar_type() == ScalarType::Byte)) + optimized = 1; + + if ((dilation[0] != 1) || (dilation[1] != 1)) + optimized = 0; + + if (optimized) { + xa_opt_quantized_conv2d_nchw( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point_int, + bias_scale_float, + output_scale, + output_zero_point, + out); + } else { + quantized_conv2d_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point_int, + bias_scale_float, + output_scale, + output_zero_point, + out); + } +} + +void quantized_conv2d_nchw_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + // Handle W8A16 heterogeneous type (int16_t activations, int8_t weights) + if (out.scalar_type() == ::executorch::aten::ScalarType::Short && + input.scalar_type() == ::executorch::aten::ScalarType::Short && + weight.scalar_type() == ::executorch::aten::ScalarType::Char) { + ::impl::generic::native::quantized_conv2d_nchw_per_tensor_out( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + 
weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + out); + return; + } + + bool optimized = 0; + + if ((input.scalar_type() == ScalarType::Char) || + (input.scalar_type() == ScalarType::Byte)) + optimized = 1; + + if ((dilation[0] != 1) || (dilation[1] != 1)) + optimized = 0; + + if (optimized) { + xa_opt_quantized_conv2d_nchw( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + } else { + quantized_conv2d_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + } +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor_out.cpp new file mode 100644 index 00000000000..9bd7e641144 --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor_out.cpp @@ -0,0 +1,189 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) + +using Tensor = executorch::aten::Tensor; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +using ScalarType = executorch::aten::ScalarType; +using ::executorch::aten::IntArrayRef; + +namespace impl { +namespace HiFi { +namespace native { + +// Optimized NHWC convolution for int8 x int8 -> int8 +void xa_opt_quantized_conv2d_nhwc_asym8sxsym8s_asym8s( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + bool conv1d = input.dim() == 3; + constexpr int kNnlibMaxDim = 4; + + WORD8* __restrict__ p_out = + (WORD8* __restrict__)out.mutable_data_ptr(); + WORD8* __restrict__ p_inp = + (WORD8* __restrict__)input.const_data_ptr(); + WORD8* __restrict__ p_kernel = + (WORD8* __restrict__)weight.const_data_ptr(); + WORD32* __restrict__ p_bias = + (WORD32* __restrict__)bias.const_data_ptr(); + + WORD32 input_height = conv1d ? 1 : input.size(2); + WORD32 input_width = conv1d ? input.size(2) : input.size(3); + WORD32 input_channels = input.size(1); + WORD32 kernel_height = conv1d ? 1 : weight.size(2); + WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3); + WORD32 kernel_channels = weight.size(1); + WORD32 out_channels = weight.size(0); + WORD32 out_height = conv1d ? 1 : out.size(2); + WORD32 out_width = conv1d ? 
out.size(2) : out.size(3); + WORD32 batches = input.size(0); + + WORD32 x_stride = stride[1]; + WORD32 y_stride = stride[0]; + WORD32 x_padding = padding[1]; + WORD32 y_padding = padding[0]; + WORD32 dilation_width = dilation[1]; + WORD32 dilation_height = dilation[0]; + + WORD32 input_zero_bias = -in_zero_point; + WORD32 kernel_zero_bias = -weight_zero_point; + + WORD32 out_multiplier32[out_channels]; + WORD32 out_shift32[out_channels]; + + float out_scale = 1. / output_scale; + + for (int i = 0; i < out_channels; i++) { + out_multiplier32[i] = bias_scale * out_scale * 2147483648; + out_shift32[i] = 0; + } + + WORD32 out_zero_bias = output_zero_point; + WORD32 inp_precision = 8; + WORD32 kernel_precision = 8; + pVOID p_scratch = nullptr; + WORD32* ptr_scratch; + + WORD32 scratch_size = 0; + + ET_CHECK_MSG(groups == 1, "Only groups=1 supported for regular convolution"); + WORD32 out_data_format = 1; + + scratch_size = xa_nn_conv2d_getsize( + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + kernel_channels, + dilation_height, + dilation_width, + y_stride, + y_padding, + x_stride, + x_padding, + out_height, + out_width, + out_channels, + inp_precision, + kernel_precision, + out_data_format); + + scratch_size = scratch_size < 0 ? 0 : scratch_size; + + ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + + p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + for (int _n = 0; _n < batches; _n++) { + WORD8* in_batch = p_inp + _n * input_channels * input_height * input_width; + WORD8* out_batch = p_out + _n * out_channels * out_height * out_width; + + xa_nn_conv2d_per_chan_sym8sxasym8s( + out_batch, + in_batch, + p_kernel, + p_bias, + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + kernel_channels, + dilation_height, + dilation_width, + out_channels, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + input_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + out_data_format, + p_scratch); + } +} + +void quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + xa_opt_quantized_conv2d_nhwc_asym8sxsym8s_asym8s( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor_out.cpp new file mode 100644 index 00000000000..433cbf76fce --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor_out.cpp @@ -0,0 +1,189 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include + +#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) + +using Tensor = executorch::aten::Tensor; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +using ScalarType = executorch::aten::ScalarType; +using ::executorch::aten::IntArrayRef; + +namespace impl { +namespace HiFi { +namespace native { + +// Optimized NHWC convolution for uint8 x uint8 -> uint8 +void xa_opt_quantized_conv2d_nhwc_asym8uxsym8u_asym8u( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + bool conv1d = input.dim() == 3; + constexpr int kNnlibMaxDim = 4; + + UWORD8* __restrict__ p_out = + (UWORD8* __restrict__)out.mutable_data_ptr(); + UWORD8* __restrict__ p_inp = + (UWORD8* __restrict__)input.const_data_ptr(); + UWORD8* __restrict__ p_kernel = + (UWORD8* __restrict__)weight.const_data_ptr(); + WORD32* __restrict__ p_bias = + (WORD32* __restrict__)bias.const_data_ptr(); + + WORD32 input_height = conv1d ? 1 : input.size(2); + WORD32 input_width = conv1d ? input.size(2) : input.size(3); + WORD32 input_channels = input.size(1); + WORD32 kernel_height = conv1d ? 1 : weight.size(2); + WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3); + WORD32 kernel_channels = weight.size(1); + WORD32 out_channels = weight.size(0); + WORD32 out_height = conv1d ? 1 : out.size(2); + WORD32 out_width = conv1d ? out.size(2) : out.size(3); + WORD32 batches = input.size(0); + + WORD32 x_stride = stride[1]; + WORD32 y_stride = stride[0]; + WORD32 x_padding = padding[1]; + WORD32 y_padding = padding[0]; + WORD32 dilation_width = dilation[1]; + WORD32 dilation_height = dilation[0]; + + WORD32 input_zero_bias = -in_zero_point; + WORD32 kernel_zero_bias = -weight_zero_point; + + WORD32 out_multiplier32[out_channels]; + WORD32 out_shift32[out_channels]; + + float out_scale = 1. / output_scale; + + for (int i = 0; i < out_channels; i++) { + out_multiplier32[i] = bias_scale * out_scale * 2147483648; + out_shift32[i] = 0; + } + + WORD32 out_zero_bias = output_zero_point; + WORD32 inp_precision = 8; + WORD32 kernel_precision = 8; + pVOID p_scratch = nullptr; + WORD32* ptr_scratch; + + WORD32 scratch_size = 0; + + ET_CHECK_MSG(groups == 1, "Only groups=1 supported for regular convolution"); + WORD32 out_data_format = 1; + + scratch_size = xa_nn_conv2d_getsize( + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + kernel_channels, + dilation_height, + dilation_width, + y_stride, + y_padding, + x_stride, + x_padding, + out_height, + out_width, + out_channels, + inp_precision, + kernel_precision, + out_data_format); + + scratch_size = scratch_size < 0 ? 
0 : scratch_size; + + ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + + p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + for (int _n = 0; _n < batches; _n++) { + UWORD8* in_batch = p_inp + _n * input_channels * input_height * input_width; + UWORD8* out_batch = p_out + _n * out_channels * out_height * out_width; + + xa_nn_conv2d_per_chan_sym8sxasym8s( + (WORD8*)out_batch, + (WORD8*)in_batch, + (WORD8*)p_kernel, + p_bias, + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + kernel_channels, + dilation_height, + dilation_width, + out_channels, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + input_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + out_data_format, + p_scratch); + } +} + +void quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + xa_opt_quantized_conv2d_nhwc_asym8uxsym8u_asym8u( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out.cpp new file mode 100644 index 00000000000..384ebbb4f48 --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out.cpp @@ -0,0 +1,171 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) + +using Tensor = executorch::aten::Tensor; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +using ScalarType = executorch::aten::ScalarType; +using ::executorch::aten::IntArrayRef; + +namespace impl { +namespace HiFi { +namespace native { + +// Specialized depthwise NHWC convolution for int8 x int8 -> int8 +void xa_opt_quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + bool conv1d = input.dim() == 3; + + WORD8* __restrict__ p_out = + (WORD8* __restrict__)out.mutable_data_ptr(); + WORD8* __restrict__ p_inp = + (WORD8* __restrict__)input.const_data_ptr(); + WORD8* __restrict__ p_kernel = + (WORD8* __restrict__)weight.const_data_ptr(); + WORD32* __restrict__ p_bias = + (WORD32* __restrict__)bias.const_data_ptr(); + + WORD32 input_height = conv1d ? 1 : input.size(2); + WORD32 input_width = conv1d ? 
input.size(2) : input.size(3); + WORD32 input_channels = input.size(1); + WORD32 kernel_height = conv1d ? 1 : weight.size(2); + WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3); + WORD32 out_channels = weight.size(0); + WORD32 out_height = conv1d ? 1 : out.size(2); + WORD32 out_width = conv1d ? out.size(2) : out.size(3); + WORD32 batches = input.size(0); + + WORD32 x_stride = stride[1]; + WORD32 y_stride = stride[0]; + WORD32 x_padding = padding[1]; + WORD32 y_padding = padding[0]; + + WORD32 input_zero_bias = -in_zero_point; + WORD32 out_zero_bias = output_zero_point; + WORD32 inp_precision = 8; + + WORD32 channels_multiplier = out_channels / input_channels; + + WORD32 out_multiplier32[out_channels]; + WORD32 out_shift32[out_channels]; + + float out_scale = 1. / output_scale; + + for (int i = 0; i < out_channels; i++) { + out_multiplier32[i] = bias_scale * out_scale * 2147483648; + out_shift32[i] = 0; + } + + WORD32 scratch_size = xa_nn_conv2d_depthwise_getsize( + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + channels_multiplier, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + inp_precision, + 0); // NHWC + + scratch_size = scratch_size < 0 ? 0 : scratch_size; + + WORD32* ptr_scratch = + (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + pVOID p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + for (int _n = 0; _n < batches; _n++) { + WORD8* in_batch = p_inp + _n * input_channels * input_height * input_width; + WORD8* out_batch = p_out + _n * out_channels * out_height * out_width; + + xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s( + out_batch, + p_kernel, + in_batch, + p_bias, + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + channels_multiplier, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + input_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + 0, // NHWC + 0, // NHWC + p_scratch); + } +} + +void quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + xa_opt_quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out.cpp new file mode 100644 index 00000000000..07df1a416d7 --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out.cpp @@ -0,0 +1,171 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include + +#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) + +using Tensor = executorch::aten::Tensor; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +using ScalarType = executorch::aten::ScalarType; +using ::executorch::aten::IntArrayRef; + +namespace impl { +namespace HiFi { +namespace native { + +// Specialized depthwise NHWC convolution for uint8 x uint8 -> uint8 +void xa_opt_quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + bool conv1d = input.dim() == 3; + + UWORD8* __restrict__ p_out = + (UWORD8* __restrict__)out.mutable_data_ptr(); + UWORD8* __restrict__ p_inp = + (UWORD8* __restrict__)input.const_data_ptr(); + UWORD8* __restrict__ p_kernel = + (UWORD8* __restrict__)weight.const_data_ptr(); + WORD32* __restrict__ p_bias = + (WORD32* __restrict__)bias.const_data_ptr(); + + WORD32 input_height = conv1d ? 1 : input.size(2); + WORD32 input_width = conv1d ? input.size(2) : input.size(3); + WORD32 input_channels = input.size(1); + WORD32 kernel_height = conv1d ? 1 : weight.size(2); + WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3); + WORD32 out_channels = weight.size(0); + WORD32 out_height = conv1d ? 1 : out.size(2); + WORD32 out_width = conv1d ? out.size(2) : out.size(3); + WORD32 batches = input.size(0); + + WORD32 x_stride = stride[1]; + WORD32 y_stride = stride[0]; + WORD32 x_padding = padding[1]; + WORD32 y_padding = padding[0]; + + WORD32 input_zero_bias = -in_zero_point; + WORD32 out_zero_bias = output_zero_point; + WORD32 inp_precision = 8; + + WORD32 channels_multiplier = out_channels / input_channels; + + WORD32 out_multiplier32[out_channels]; + WORD32 out_shift32[out_channels]; + + float out_scale = 1. / output_scale; + + for (int i = 0; i < out_channels; i++) { + out_multiplier32[i] = bias_scale * out_scale * 2147483648; + out_shift32[i] = 0; + } + + WORD32 scratch_size = xa_nn_conv2d_depthwise_getsize( + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + channels_multiplier, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + inp_precision, + 0); // NHWC + + scratch_size = scratch_size < 0 ? 
0 : scratch_size; + + WORD32* ptr_scratch = + (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + pVOID p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + for (int _n = 0; _n < batches; _n++) { + UWORD8* in_batch = p_inp + _n * input_channels * input_height * input_width; + UWORD8* out_batch = p_out + _n * out_channels * out_height * out_width; + + xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s( + (WORD8*)out_batch, + (WORD8*)p_kernel, + (WORD8*)in_batch, + p_bias, + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + channels_multiplier, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + input_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + 0, // NHWC + 0, // NHWC + p_scratch); + } +} + +void quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + xa_opt_quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_out.cpp new file mode 100644 index 00000000000..91965594a5d --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_out.cpp @@ -0,0 +1,188 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +using Tensor = executorch::aten::Tensor; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +using ScalarType = executorch::aten::ScalarType; +using ::executorch::aten::IntArrayRef; + +namespace impl { +namespace HiFi { +namespace native { + +// Dilated fallback implementation for int8 x int8 -> int8 quantized 2d conv +// kernel for NHWC layout. This variant is optimized for asymmetric int8 inputs, +// weights, and outputs. 
The input is of shape [n x h x w x c] The weight is of +// shape [oc x wh x ww x wc] The output is of shape [n x oh x ow x oc] The bias +// is of shape [oc] +template +__attribute__((noinline)) void conv2d_nhwc_dilated_asym8sxsym8s_asym8s_core( + // All the arrays + const int8_t* __restrict__ p_in, + const int8_t* __restrict__ p_weight, + const int32_t* __restrict__ p_bias, + int8_t* __restrict__ p_out, + // The array sizes + int32_t n, + int32_t h, + int32_t w, + int32_t c, + int32_t oc, + int32_t wh, + int32_t ww, + int32_t wc, + int32_t oh, + int32_t ow, + // Stride + int16_t s0, + int16_t s1, + // Padding + int16_t p0, + int16_t p1, + // Dilation + int16_t d0, + int16_t d1, + // Group for depthwise conv + int16_t groups, + // Quantization parameters + int8_t in_zero_point = 0, + int32_t weight_zero_point = 0, + float bias_scale = 1, + float out_scale = 1, + int8_t out_zero_point = 0) { + float inv_out_scale = 1. / out_scale; + + // Compute the number of in and out channels per group + const int ocpg = oc / groups; + const int icpg = c / groups; + + // Iterate over all the output batches (i.e., n) + for (int _n = 0; _n < n; ++_n) { + const int8_t* in_batch = p_in + _n * h * w * c; + int8_t* out_batch = p_out + _n * oh * ow * oc; + for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { + for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { + int8_t* out_line = out_batch + (_oh * ow + _ow) * oc; + // Compute separable convolution for each group + for (int _g = 0; _g < groups; ++_g) { + // Identify the input and output channels involved in the computation + // of this group + int sic = _g * icpg; + int soc = _g * ocpg; + // Populate all the output channels in the group + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + const int8_t* weight_batch = p_weight + _oc * wh * ww * wc; + // We compute one output channel at a time. The computation can be + // thought of as a stencil computation: we iterate over an input of + // size h x w x icpg, with a stencil of size wh x ww x icpg, to + // compute an output channel of size oh x ow x 1. + float acc = p_bias[_oc]; + // Below is the stencil computation that performs the hadamard + // product+accumulation of each input channel (contributing to + // the output channel being computed) with the corresponding + // weight channel. 
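+ // For clarity, a sketch of the dilated index math used below (comment
+ // only, nothing here is executed): with the loop variables above,
+ // _h == _oh * s0 and _w == _ow * s1, so the input element paired with
+ // weight tap (_wh, _ww) lives at
+ //   input_h = _h + d0 * _wh - p0
+ //   input_w = _w + d1 * _ww - p1
+ // Taps whose coordinates fall outside [0, h) x [0, w) are skipped; since
+ // a padded input value would equal in_zero_point, its contribution after
+ // the zero-point subtraction is zero, so skipping the tap is equivalent
+ // to zero padding in the quantized domain.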
+ // General path for dilated convolutions with padding support + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + int input_h = _h + d0 * _wh - p0; + int input_w = _w + d1 * _ww - p1; + if ((input_h >= 0) && (input_h < h) && (input_w >= 0) && + (input_w < w)) { + const int8_t* in_line = + in_batch + input_h * w * c + input_w * c; + const int8_t* weight_line = + weight_batch + _wh * ww * wc + _ww * wc; + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + float lhs = static_cast(in_line[_ic]) - + static_cast(in_zero_point); + float rhs = static_cast(weight_line[_ic - sic]) - + static_cast(weight_zero_point); + acc += lhs * rhs; + } + } + } + } + // Quantize the accumulated result + float val = bias_scale * acc; + out_line[_oc] = + kernels::quantize(val, inv_out_scale, out_zero_point); + } + } + } + } + } +} + +void quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + bool conv1d = input.dim() == 3; + // input = [n, h, w, c] + const int n = input.size(0); + const int h = conv1d ? 1 : input.size(1); + const int w = conv1d ? input.size(1) : input.size(2); + const int c = conv1d ? input.size(2) : input.size(3); + // weight = [oc, wh, ww, wc] + const int oc = weight.size(0); + const int wh = conv1d ? 1 : weight.size(1); + const int ww = conv1d ? weight.size(1) : weight.size(2); + const int wc = conv1d ? weight.size(2) : weight.size(3); + // output = [n, oh, ow, oc] + const int oh = conv1d ? 1 : out.size(1); + const int ow = conv1d ? out.size(1) : out.size(2); + + conv2d_nhwc_dilated_asym8sxsym8s_asym8s_core( + input.const_data_ptr(), + weight.const_data_ptr(), + bias.const_data_ptr(), + out.mutable_data_ptr(), + n, + h, + w, + c, + oc, + wh, + ww, + wc, + oh, + ow, + stride[0], + stride[1], + padding[0], + padding[1], + dilation[0], + dilation[1], + groups, + static_cast(in_zero_point), + weight_zero_point, + bias_scale, + output_scale, + static_cast(output_zero_point)); +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out.cpp new file mode 100644 index 00000000000..14dc31a719f --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out.cpp @@ -0,0 +1,188 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +using Tensor = executorch::aten::Tensor; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +using ScalarType = executorch::aten::ScalarType; +using ::executorch::aten::IntArrayRef; + +namespace impl { +namespace HiFi { +namespace native { + +// Dilated fallback implementation for uint8 x uint8 -> uint8 quantized 2d conv +// kernel for NHWC layout. This variant is optimized for asymmetric uint8 +// inputs, weights, and outputs. 
The input is of shape [n x h x w x c] The +// weight is of shape [oc x wh x ww x wc] The output is of shape [n x oh x ow x +// oc] The bias is of shape [oc] +template +__attribute__((noinline)) void conv2d_nhwc_dilated_asym8uxsym8u_asym8u_core( + // All the arrays + const uint8_t* __restrict__ p_in, + const uint8_t* __restrict__ p_weight, + const int32_t* __restrict__ p_bias, + uint8_t* __restrict__ p_out, + // The array sizes + int32_t n, + int32_t h, + int32_t w, + int32_t c, + int32_t oc, + int32_t wh, + int32_t ww, + int32_t wc, + int32_t oh, + int32_t ow, + // Stride + int16_t s0, + int16_t s1, + // Padding + int16_t p0, + int16_t p1, + // Dilation + int16_t d0, + int16_t d1, + // Group for depthwise conv + int16_t groups, + // Quantization parameters + uint8_t in_zero_point = 0, + int32_t weight_zero_point = 0, + float bias_scale = 1, + float out_scale = 1, + uint8_t out_zero_point = 0) { + float inv_out_scale = 1. / out_scale; + + // Compute the number of in and out channels per group + const int ocpg = oc / groups; + const int icpg = c / groups; + + // Iterate over all the output batches (i.e., n) + for (int _n = 0; _n < n; ++_n) { + const uint8_t* in_batch = p_in + _n * h * w * c; + uint8_t* out_batch = p_out + _n * oh * ow * oc; + for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { + for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { + uint8_t* out_line = out_batch + (_oh * ow + _ow) * oc; + // Compute separable convolution for each group + for (int _g = 0; _g < groups; ++_g) { + // Identify the input and output channels involved in the computation + // of this group + int sic = _g * icpg; + int soc = _g * ocpg; + // Populate all the output channels in the group + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + const uint8_t* weight_batch = p_weight + _oc * wh * ww * wc; + // We compute one output channel at a time. The computation can be + // thought of as a stencil computation: we iterate over an input of + // size h x w x icpg, with a stencil of size wh x ww x icpg, to + // compute an output channel of size oh x ow x 1. + float acc = p_bias[_oc]; + // Below is the stencil computation that performs the hadamard + // product+accumulation of each input channel (contributing to + // the output channel being computed) with the corresponding + // weight channel. 
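+ // For clarity, a sketch of the requantization applied after the
+ // accumulation below (comment only; it assumes kernels::quantize does the
+ // usual round-and-clamp to the output dtype): acc holds the sum of
+ // (input - in_zero_point) * (weight - weight_zero_point) products,
+ // val = bias_scale * acc rescales it to the real-valued domain (the input
+ // and weight scales are folded into bias_scale), and the stored uint8
+ // result is approximately
+ //   clamp(round(val * inv_out_scale) + out_zero_point, 0, 255).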
+ // General path for dilated convolutions with padding support + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + int input_h = _h + d0 * _wh - p0; + int input_w = _w + d1 * _ww - p1; + if ((input_h >= 0) && (input_h < h) && (input_w >= 0) && + (input_w < w)) { + const uint8_t* in_line = + in_batch + input_h * w * c + input_w * c; + const uint8_t* weight_line = + weight_batch + _wh * ww * wc + _ww * wc; + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + float lhs = static_cast(in_line[_ic]) - + static_cast(in_zero_point); + float rhs = static_cast(weight_line[_ic - sic]) - + static_cast(weight_zero_point); + acc += lhs * rhs; + } + } + } + } + // Quantize the accumulated result + float val = bias_scale * acc; + out_line[_oc] = + kernels::quantize(val, inv_out_scale, out_zero_point); + } + } + } + } + } +} + +void quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + bool conv1d = input.dim() == 3; + // input = [n, h, w, c] + const int n = input.size(0); + const int h = conv1d ? 1 : input.size(1); + const int w = conv1d ? input.size(1) : input.size(2); + const int c = conv1d ? input.size(2) : input.size(3); + // weight = [oc, wh, ww, wc] + const int oc = weight.size(0); + const int wh = conv1d ? 1 : weight.size(1); + const int ww = conv1d ? weight.size(1) : weight.size(2); + const int wc = conv1d ? weight.size(2) : weight.size(3); + // output = [n, oh, ow, oc] + const int oh = conv1d ? 1 : out.size(1); + const int ow = conv1d ? out.size(1) : out.size(2); + + conv2d_nhwc_dilated_asym8uxsym8u_asym8u_core( + input.const_data_ptr(), + weight.const_data_ptr(), + bias.const_data_ptr(), + out.mutable_data_ptr(), + n, + h, + w, + c, + oc, + wh, + ww, + wc, + oh, + ow, + stride[0], + stride[1], + padding[0], + padding[1], + dilation[0], + dilation[1], + groups, + static_cast(in_zero_point), + weight_zero_point, + bias_scale, + output_scale, + static_cast(output_zero_point)); +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_out.cpp new file mode 100644 index 00000000000..b2a7c341997 --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_out.cpp @@ -0,0 +1,597 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include + +#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) + +using Tensor = executorch::aten::Tensor; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +using ScalarType = executorch::aten::ScalarType; +using ::executorch::aten::IntArrayRef; + +namespace impl { +namespace HiFi { +namespace native { + +template < + typename IT = float, + typename WT = IT, + typename BT = IT, + typename OT = IT, + bool quantized = false> +__attribute__((noinline)) void conv2d_nhwc_core_generic( + // All the arrays + const IT* __restrict__ p_in, + const WT* __restrict__ p_weight, + const BT* __restrict__ p_bias, + OT* __restrict__ p_out, + // The array sizes + int32_t n, + int32_t h, + int32_t w, + int32_t c, + int32_t oc, + int32_t wh, + int32_t ww, + int32_t wc, + int32_t oh, + int32_t ow, + // Stride + int16_t s0, + int16_t s1, + // Padding + int16_t p0, + int16_t p1, + // Dilation + int16_t d0, + int16_t d1, + // Group for depthwise conv + int16_t groups, + // Optional args that are only relevant for quantized convolution + // input zero point + IT in_zero_point = 0, + // weight zero point + int32_t weight_zero_point = 0, + float bias_scale = 1, + float out_scale = 1, + OT out_zero_point = 0) { + float inv_out_scale = 1. / out_scale; + bool zero_pad_unit_dilation = d0 == 1 && d1 == 1 && p0 == 0 && p1 == 0; + + // Compute the number of in and out channels per group + const int ocpg = oc / groups; + const int icpg = c / groups; + + // Iterate over all the output batches (i.e., n) + for (int _n = 0; _n < n; ++_n) { + const IT* in_batch = p_in + _n * h * w * c; + OT* out_batch = p_out + _n * oh * ow * oc; + for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { + for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { + OT* out_line = out_batch + (_oh * ow + _ow) * oc; + // Compute separable convolution for each group + for (int _g = 0; _g < groups; ++_g) { + // Identify the input and output channels involved in the computation + // of this group + int sic = _g * icpg; + int soc = _g * ocpg; + // Populate all the output channels in the group + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + const WT* weight_batch = p_weight + _oc * wh * ww * wc; + // We compute one output channel at a time. The computation can be + // thought of as a stencil computation: we iterate over an input of + // size h x w x icpg, with a stencil of size wh x ww x icpg, to + // compute an output channel of size oh x ow x 1. + float acc = p_bias[_oc]; + // Below is the stencil computation that performs the hadamard + // product+accumulation of each input channel (contributing to + // the output channel being computed) with the corresponding + // weight channel. If the padding is 0, and dilation is 1, then + // we can remove the unnecessary checks, and simplify the code + // so that it can be vectorized by Tensilica compiler.x`` + if (zero_pad_unit_dilation) { + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + const IT* in_line = + in_batch + (_h + _wh) * w * c + (_w + _ww) * c; + const WT* weight_line = + weight_batch + _wh * ww * wc + _ww * wc; + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + float lhs = in_line[_ic] - in_zero_point; + float rhs = weight_line[_ic - sic] - + (quantized ? 
weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } else { + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + if (((_h + d0 * _wh - p0) >= 0) && + ((_h + d0 * _wh - p0) < h) && + ((_w + d1 * _ww - p1) >= 0) && + ((_w + d1 * _ww - p1 < w))) { + const IT* in_line = in_batch + + (_h + d0 * _wh - p0) * w * c + (_w + d1 * _ww - p1) * c; + const WT* weight_line = + weight_batch + _wh * ww * wc + _ww * wc; + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + float lhs = in_line[_ic] - in_zero_point; + float rhs = weight_line[_ic - sic] - + (quantized ? weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } + } + if (quantized) { + float val = bias_scale * acc; + out_line[_oc] = + kernels::quantize(val, inv_out_scale, out_zero_point); + } else { + out_line[_oc] = acc; + } + } + } + } + } + } +} + +void xa_opt_quantized_conv2d_nhwc( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + bool conv1d = input.dim() == 3; + constexpr int kNnlibMaxDim = 4; + + if (input.scalar_type() == ScalarType::Char) { + WORD8* __restrict__ p_out = + (WORD8* __restrict__)out.mutable_data_ptr(); + WORD8* __restrict__ p_inp = + (WORD8* __restrict__)input.const_data_ptr(); + WORD8* __restrict__ p_kernel = + (WORD8* __restrict__)weight.const_data_ptr(); + WORD32* __restrict__ p_bias = + (WORD32* __restrict__)bias.const_data_ptr(); + + WORD32 input_height = conv1d ? 1 : input.size(2); + WORD32 input_width = conv1d ? input.size(2) : input.size(3); + WORD32 input_channels = input.size(1); + WORD32 kernel_height = conv1d ? 1 : weight.size(2); + WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3); + WORD32 kernel_channels = weight.size(1); + WORD32 out_channels = weight.size(0); + WORD32 out_height = conv1d ? 1 : out.size(2); + WORD32 out_width = conv1d ? out.size(2) : out.size(3); + WORD32 batches = input.size(0); + + WORD32 x_stride = stride[1]; + WORD32 y_stride = stride[0]; + WORD32 x_padding = padding[1]; + WORD32 y_padding = padding[0]; + WORD32 dilation_width = dilation[1]; + WORD32 dilation_height = dilation[0]; + + // WORD32* kernel_bias_ptr = + // (WORD32*)weight_zero_point.const_data_ptr(); + + WORD32 input_zero_bias = -in_zero_point; + WORD32 kernel_zero_bias = -weight_zero_point; + + WORD32 out_multiplier32[out_channels]; + WORD32 out_shift32[out_channels]; + + float out_scale = 1. / output_scale; + + for (int i = 0; i < out_channels; i++) { + out_multiplier32[i] = bias_scale * out_scale * 2147483648; + out_shift32[i] = 0; + } + + WORD32 out_zero_bias = output_zero_point; + WORD32 inp_precision = 8; + WORD32 kernel_precision = 8; + pVOID p_scratch = nullptr; + WORD32* ptr_scratch; + + WORD32 scratch_size = 0; + + if (groups == 1) { + WORD32 out_data_format = 1; + + scratch_size = xa_nn_conv2d_getsize( + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + kernel_channels, + dilation_height, + dilation_width, + y_stride, + y_padding, + x_stride, + x_padding, + out_height, + out_width, + out_channels, + inp_precision, + kernel_precision, + out_data_format); + + scratch_size = scratch_size < 0 ? 
0 : scratch_size; + + ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + + p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + for (int _n = 0; _n < batches; _n++) { + WORD8* in_batch = + p_inp + _n * input_channels * input_height * input_width; + WORD8* out_batch = p_out + _n * out_channels * out_height * out_width; + + xa_nn_conv2d_per_chan_sym8sxasym8s( + out_batch, + in_batch, + p_kernel, + p_bias, + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + kernel_channels, + dilation_height, + dilation_width, + out_channels, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + input_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + out_data_format, + p_scratch); + } + return; + } + + if (groups == input_channels) { + WORD32 channels_multiplier = out_channels / input_channels; + + scratch_size = xa_nn_conv2d_depthwise_getsize( + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + channels_multiplier, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + inp_precision, + 0); // NHWC + + scratch_size = scratch_size < 0 ? 0 : scratch_size; + + ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + + p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + WORD8* ptr1 = (WORD8*)kernels::allocate_temp_memory( + ctx, + ((batches * out_channels * out_height * out_width) + 8) * + sizeof(WORD8)); + + WORD8* p_out_temp = (WORD8*)ALIGN_PTR(ptr1, 8); + + for (int _n = 0; _n < batches; _n++) { + WORD8* in_batch = + p_inp + _n * input_channels * input_height * input_width; + WORD8* out_batch = + p_out_temp + _n * out_channels * out_height * out_width; + + xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s( + out_batch, + p_kernel, + in_batch, + p_bias, + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + channels_multiplier, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + input_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + 0, // NHWC + 0, // NHWC + p_scratch); + } + + return; + } + } +} + +void quantized_conv2d_nhwc( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + bool conv1d = input.dim() == 3; + // input = [n, h, w, c] + const int n = input.size(0); + const int h = conv1d ? 1 : input.size(1); + const int w = conv1d ? input.size(1) : input.size(2); + const int c = conv1d ? input.size(2) : input.size(3); + // weight = [oc, wh, ww, wc] + const int oc = weight.size(0); + const int wh = conv1d ? 1 : weight.size(1); + const int ww = conv1d ? weight.size(1) : weight.size(2); + const int wc = conv1d ? weight.size(2) : weight.size(3); + // output = [n, oh, ow, oc] + const int oh = conv1d ? 1 : out.size(1); + const int ow = conv1d ? 
out.size(1) : out.size(2); + +#define typed_quantized_conv2d_nhwc(ctype, dtype) \ + case ScalarType::dtype: { \ + conv2d_nhwc_core_generic( \ + input.const_data_ptr(), \ + weight.const_data_ptr(), \ + bias.const_data_ptr(), \ + out.mutable_data_ptr(), \ + n, \ + h, \ + w, \ + c, \ + oc, \ + wh, \ + ww, \ + wc, \ + oh, \ + ow, \ + stride[0], \ + stride[1], \ + padding[0], \ + padding[1], \ + dilation[0], \ + dilation[1], \ + groups, \ + in_zero_point, \ + weight_zero_point, \ + bias_scale, \ + output_scale, \ + (ctype)output_zero_point); \ + break; \ + } + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_conv2d_nhwc); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_conv2d_nhwc +} + +void quantized_conv2d_nhwc_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& bias_scale, + double output_scale, + int64_t output_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + Tensor& out) { + // Handle W8A16 heterogeneous type (int16_t activations, int8_t weights) + if (out.scalar_type() == ::executorch::aten::ScalarType::Short && + input.scalar_type() == ::executorch::aten::ScalarType::Short && + weight.scalar_type() == ::executorch::aten::ScalarType::Char) { + ::impl::generic::native::quantized_conv2d_nhwc_out( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + out); + return; + } + const float bias_scale_float = bias_scale.const_data_ptr()[0]; + const int32_t weight_zero_point_int = + weight_zero_point.const_data_ptr()[0]; + + bool optimized = 0; + + if ((input.scalar_type() == ScalarType::Char) || + (input.scalar_type() == ScalarType::Byte)) + optimized = 1; + + if ((dilation[0] != 1) || (dilation[1] != 1)) + optimized = 0; + + if (optimized) { + xa_opt_quantized_conv2d_nhwc( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point_int, + bias_scale_float, + output_scale, + output_zero_point, + out); + } else { + quantized_conv2d_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point_int, + bias_scale_float, + output_scale, + output_zero_point, + out); + } +} + +void quantized_conv2d_nhwc_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + // Handle W8A16 heterogeneous type (int16_t activations, int8_t weights) + if (out.scalar_type() == ::executorch::aten::ScalarType::Short && + input.scalar_type() == ::executorch::aten::ScalarType::Short && + weight.scalar_type() == ::executorch::aten::ScalarType::Char) { + ::impl::generic::native::quantized_conv2d_nhwc_per_tensor_out( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + 
bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + out); + return; + } + + bool optimized = 0; + if ((input.scalar_type() == ScalarType::Char) || + (input.scalar_type() == ScalarType::Byte)) + optimized = 1; + + if ((dilation[0] != 1) || (dilation[1] != 1)) + optimized = 0; + + if (optimized) { + xa_opt_quantized_conv2d_nhwc( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + } else { + quantized_conv2d_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + } +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/op_quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_out.cpp deleted file mode 100644 index 2788de589cf..00000000000 --- a/backends/cadence/hifi/operators/op_quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_out.cpp +++ /dev/null @@ -1,248 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include - -#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) - -using Tensor = executorch::aten::Tensor; -using KernelRuntimeContext = torch::executor::KernelRuntimeContext; -using ScalarType = executorch::aten::ScalarType; -using ::executorch::aten::IntArrayRef; - -namespace cadence { -namespace impl { -namespace HiFi { -namespace native { - -// Optimized NCHW convolution for int8 x int8 -> int8 -void xa_opt_quantized_conv_nchw_asym8sxsym8s_asym8s( - KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int16_t groups, - int32_t in_zero_point, - int32_t weight_zero_point, - float bias_scale, - float output_scale, - int32_t output_zero_point, - Tensor& out) { - bool conv1d = input.dim() == 3; - constexpr int kNnlibMaxDim = 4; - - WORD8* __restrict__ p_out = - (WORD8* __restrict__)out.mutable_data_ptr(); - WORD8* __restrict__ p_inp = - (WORD8* __restrict__)input.const_data_ptr(); - WORD8* __restrict__ p_kernel = - (WORD8* __restrict__)weight.const_data_ptr(); - WORD32* __restrict__ p_bias = - (WORD32* __restrict__)bias.const_data_ptr(); - - WORD32 input_height = conv1d ? 1 : input.size(2); - WORD32 input_width = conv1d ? input.size(2) : input.size(3); - WORD32 input_channels = input.size(1); - WORD32 kernel_height = conv1d ? 1 : weight.size(2); - WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3); - WORD32 kernel_channels = weight.size(1); - WORD32 out_channels = weight.size(0); - WORD32 out_height = conv1d ? 1 : out.size(2); - WORD32 out_width = conv1d ? out.size(2) : out.size(3); - WORD32 batches = input.size(0); - - WORD32 x_stride = stride[1]; - WORD32 y_stride = stride[0]; - WORD32 x_padding = padding[1]; - WORD32 y_padding = padding[0]; - WORD32 dilation_width = dilation[1]; - WORD32 dilation_height = dilation[0]; - - WORD32 input_zero_bias = -in_zero_point; - WORD32 kernel_zero_bias = -weight_zero_point; - - WORD32 out_multiplier32[out_channels]; - WORD32 out_shift32[out_channels]; - - float out_scale = 1. 
/ output_scale; - - for (int i = 0; i < out_channels; i++) { - out_multiplier32[i] = bias_scale * out_scale * 2147483648; - out_shift32[i] = 0; - } - - WORD32 out_zero_bias = output_zero_point; - WORD32 inp_precision = 8; - WORD32 kernel_precision = 8; - pVOID p_scratch = nullptr; - WORD32* ptr_scratch; - - WORD32 scratch_size = 0; - - ET_CHECK_MSG(groups == 1, "Only groups=1 supported for regular convolution"); - WORD32 out_data_format = 1; - - WORD8* ptr1 = (WORD8*)kernels::allocate_temp_memory( - ctx, - ((batches * input_channels * input_height * input_width) + 8) * - sizeof(WORD8)); - - WORD8* ptr2 = (WORD8*)kernels::allocate_temp_memory( - ctx, - ((out_channels * kernel_channels * kernel_height * kernel_width) + 8) * - sizeof(WORD8)); - - WORD8* pin = (WORD8*)ALIGN_PTR(ptr1, 8); - WORD8* pkernel = (WORD8*)ALIGN_PTR(ptr2, 8); - - WORD32 p_inp_shape[kNnlibMaxDim]; - p_inp_shape[0] = input.size(0); - p_inp_shape[1] = input_channels; - p_inp_shape[2] = input_height; - p_inp_shape[3] = input_width; - - WORD32 p_out_shape[kNnlibMaxDim]; - p_out_shape[0] = input.size(0); - p_out_shape[1] = input_height; - p_out_shape[2] = input_width; - p_out_shape[3] = input_channels; - - WORD32 p_permute_vec[kNnlibMaxDim] = {0, 2, 3, 1}; - - xa_nn_transpose_8_8( - pin, - p_out_shape, - p_inp, - p_inp_shape, - p_permute_vec, - kNnlibMaxDim, - kNnlibMaxDim); - - WORD32 p_inp_shape1[kNnlibMaxDim]; - p_inp_shape1[0] = out_channels; - p_inp_shape1[1] = kernel_channels; - p_inp_shape1[2] = kernel_height; - p_inp_shape1[3] = kernel_width; - - WORD32 p_out_shape1[kNnlibMaxDim]; - p_out_shape1[0] = out_channels; - p_out_shape1[1] = kernel_height; - p_out_shape1[2] = kernel_width; - p_out_shape1[3] = kernel_channels; - - xa_nn_transpose_8_8( - pkernel, - p_out_shape1, - p_kernel, - p_inp_shape1, - p_permute_vec, - kNnlibMaxDim, - kNnlibMaxDim); - - scratch_size = xa_nn_conv2d_getsize( - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - kernel_channels, - dilation_height, - dilation_width, - y_stride, - y_padding, - x_stride, - x_padding, - out_height, - out_width, - out_channels, - inp_precision, - kernel_precision, - out_data_format); - - scratch_size = scratch_size < 0 ? 
0 : scratch_size; - - ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); - - p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); - - for (int _n = 0; _n < batches; _n++) { - WORD8* in_batch = pin + _n * input_channels * input_height * input_width; - WORD8* out_batch = p_out + _n * out_channels * out_height * out_width; - - xa_nn_conv2d_per_chan_sym8sxasym8s( - out_batch, - in_batch, - pkernel, - p_bias, - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - kernel_channels, - dilation_height, - dilation_width, - out_channels, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - input_zero_bias, - out_multiplier32, - out_shift32, - out_zero_bias, - out_data_format, - p_scratch); - } -} - -void quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - xa_opt_quantized_conv_nchw_asym8sxsym8s_asym8s( - ctx, - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -} // namespace native -} // namespace HiFi -} // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_out.cpp deleted file mode 100644 index 9fd2d69dda9..00000000000 --- a/backends/cadence/hifi/operators/op_quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_out.cpp +++ /dev/null @@ -1,248 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include - -#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) - -using Tensor = executorch::aten::Tensor; -using KernelRuntimeContext = torch::executor::KernelRuntimeContext; -using ScalarType = executorch::aten::ScalarType; -using ::executorch::aten::IntArrayRef; - -namespace cadence { -namespace impl { -namespace HiFi { -namespace native { - -// Optimized NCHW convolution for uint8 x uint8 -> uint8 -void xa_opt_quantized_conv_nchw_asym8uxsym8u_asym8u( - KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int16_t groups, - int32_t in_zero_point, - int32_t weight_zero_point, - float bias_scale, - float output_scale, - int32_t output_zero_point, - Tensor& out) { - bool conv1d = input.dim() == 3; - constexpr int kNnlibMaxDim = 4; - - UWORD8* __restrict__ p_out = - (UWORD8* __restrict__)out.mutable_data_ptr(); - UWORD8* __restrict__ p_inp = - (UWORD8* __restrict__)input.const_data_ptr(); - UWORD8* __restrict__ p_kernel = - (UWORD8* __restrict__)weight.const_data_ptr(); - WORD32* __restrict__ p_bias = - (WORD32* __restrict__)bias.const_data_ptr(); - - WORD32 input_height = conv1d ? 1 : input.size(2); - WORD32 input_width = conv1d ? 
input.size(2) : input.size(3); - WORD32 input_channels = input.size(1); - WORD32 kernel_height = conv1d ? 1 : weight.size(2); - WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3); - WORD32 kernel_channels = weight.size(1); - WORD32 out_channels = weight.size(0); - WORD32 out_height = conv1d ? 1 : out.size(2); - WORD32 out_width = conv1d ? out.size(2) : out.size(3); - WORD32 batches = input.size(0); - - WORD32 x_stride = stride[1]; - WORD32 y_stride = stride[0]; - WORD32 x_padding = padding[1]; - WORD32 y_padding = padding[0]; - WORD32 dilation_width = dilation[1]; - WORD32 dilation_height = dilation[0]; - - WORD32 input_zero_bias = -in_zero_point; - WORD32 kernel_zero_bias = -weight_zero_point; - - WORD32 out_multiplier32[out_channels]; - WORD32 out_shift32[out_channels]; - - float out_scale = 1. / output_scale; - - for (int i = 0; i < out_channels; i++) { - out_multiplier32[i] = bias_scale * out_scale * 2147483648; - out_shift32[i] = 0; - } - - WORD32 out_zero_bias = output_zero_point; - WORD32 inp_precision = 8; - WORD32 kernel_precision = 8; - pVOID p_scratch = nullptr; - WORD32* ptr_scratch; - - WORD32 scratch_size = 0; - - ET_CHECK_MSG(groups == 1, "Only groups=1 supported for regular convolution"); - WORD32 out_data_format = 1; - - UWORD8* ptr1 = (UWORD8*)kernels::allocate_temp_memory( - ctx, - ((batches * input_channels * input_height * input_width) + 8) * - sizeof(UWORD8)); - - UWORD8* ptr2 = (UWORD8*)kernels::allocate_temp_memory( - ctx, - ((out_channels * kernel_channels * kernel_height * kernel_width) + 8) * - sizeof(UWORD8)); - - UWORD8* pin = (UWORD8*)ALIGN_PTR(ptr1, 8); - UWORD8* pkernel = (UWORD8*)ALIGN_PTR(ptr2, 8); - - WORD32 p_inp_shape[kNnlibMaxDim]; - p_inp_shape[0] = input.size(0); - p_inp_shape[1] = input_channels; - p_inp_shape[2] = input_height; - p_inp_shape[3] = input_width; - - WORD32 p_out_shape[kNnlibMaxDim]; - p_out_shape[0] = input.size(0); - p_out_shape[1] = input_height; - p_out_shape[2] = input_width; - p_out_shape[3] = input_channels; - - WORD32 p_permute_vec[kNnlibMaxDim] = {0, 2, 3, 1}; - - xa_nn_transpose_8_8( - (WORD8*)pin, - p_out_shape, - (WORD8*)p_inp, - p_inp_shape, - p_permute_vec, - kNnlibMaxDim, - kNnlibMaxDim); - - WORD32 p_inp_shape1[kNnlibMaxDim]; - p_inp_shape1[0] = out_channels; - p_inp_shape1[1] = kernel_channels; - p_inp_shape1[2] = kernel_height; - p_inp_shape1[3] = kernel_width; - - WORD32 p_out_shape1[kNnlibMaxDim]; - p_out_shape1[0] = out_channels; - p_out_shape1[1] = kernel_height; - p_out_shape1[2] = kernel_width; - p_out_shape1[3] = kernel_channels; - - xa_nn_transpose_8_8( - (WORD8*)pkernel, - p_out_shape1, - (WORD8*)p_kernel, - p_inp_shape1, - p_permute_vec, - kNnlibMaxDim, - kNnlibMaxDim); - - scratch_size = xa_nn_conv2d_getsize( - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - kernel_channels, - dilation_height, - dilation_width, - y_stride, - y_padding, - x_stride, - x_padding, - out_height, - out_width, - out_channels, - inp_precision, - kernel_precision, - out_data_format); - - scratch_size = scratch_size < 0 ? 
0 : scratch_size; - - ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); - - p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); - - for (int _n = 0; _n < batches; _n++) { - UWORD8* in_batch = pin + _n * input_channels * input_height * input_width; - UWORD8* out_batch = p_out + _n * out_channels * out_height * out_width; - - xa_nn_conv2d_per_chan_sym8sxasym8s( - (WORD8*)out_batch, - (WORD8*)in_batch, - (WORD8*)pkernel, - p_bias, - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - kernel_channels, - dilation_height, - dilation_width, - out_channels, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - input_zero_bias, - out_multiplier32, - out_shift32, - out_zero_bias, - out_data_format, - p_scratch); - } -} - -void quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - xa_opt_quantized_conv_nchw_asym8uxsym8u_asym8u( - ctx, - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -} // namespace native -} // namespace HiFi -} // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out.cpp deleted file mode 100644 index 3e2c9c58401..00000000000 --- a/backends/cadence/hifi/operators/op_quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out.cpp +++ /dev/null @@ -1,203 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include - -#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) - -using Tensor = executorch::aten::Tensor; -using KernelRuntimeContext = torch::executor::KernelRuntimeContext; -using ScalarType = executorch::aten::ScalarType; -using ::executorch::aten::IntArrayRef; - -namespace cadence { -namespace impl { -namespace HiFi { -namespace native { - -// Specialized depthwise NCHW convolution for int8 x int8 -> int8 -void xa_opt_quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s( - KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int16_t groups, - int32_t in_zero_point, - int32_t weight_zero_point, - float bias_scale, - float output_scale, - int32_t output_zero_point, - Tensor& out) { - bool conv1d = input.dim() == 3; - constexpr int kNnlibMaxDim = 4; - - WORD8* __restrict__ p_out = - (WORD8* __restrict__)out.mutable_data_ptr(); - WORD8* __restrict__ p_inp = - (WORD8* __restrict__)input.const_data_ptr(); - WORD8* __restrict__ p_kernel = - (WORD8* __restrict__)weight.const_data_ptr(); - WORD32* __restrict__ p_bias = - (WORD32* __restrict__)bias.const_data_ptr(); - - WORD32 input_height = conv1d ? 
1 : input.size(2); - WORD32 input_width = conv1d ? input.size(2) : input.size(3); - WORD32 input_channels = input.size(1); - WORD32 kernel_height = conv1d ? 1 : weight.size(2); - WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3); - WORD32 out_channels = weight.size(0); - WORD32 out_height = conv1d ? 1 : out.size(2); - WORD32 out_width = conv1d ? out.size(2) : out.size(3); - WORD32 batches = input.size(0); - - WORD32 x_stride = stride[1]; - WORD32 y_stride = stride[0]; - WORD32 x_padding = padding[1]; - WORD32 y_padding = padding[0]; - - WORD32 input_zero_bias = -in_zero_point; - WORD32 out_zero_bias = output_zero_point; - WORD32 inp_precision = 8; - - WORD32 channels_multiplier = out_channels / input_channels; - - WORD32 out_multiplier32[out_channels]; - WORD32 out_shift32[out_channels]; - - float out_scale = 1. / output_scale; - - for (int i = 0; i < out_channels; i++) { - out_multiplier32[i] = bias_scale * out_scale * 2147483648; - out_shift32[i] = 0; - } - - WORD32 scratch_size = xa_nn_conv2d_depthwise_getsize( - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - channels_multiplier, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - inp_precision, - 1); // NCHW - - scratch_size = scratch_size < 0 ? 0 : scratch_size; - - WORD32* ptr_scratch = - (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); - pVOID p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); - - WORD8* ptr1 = (WORD8*)kernels::allocate_temp_memory( - ctx, - ((batches * out_channels * out_height * out_width) + 8) * sizeof(WORD8)); - - WORD8* p_out_temp = (WORD8*)ALIGN_PTR(ptr1, 8); - - for (int _n = 0; _n < batches; _n++) { - WORD8* in_batch = p_inp + _n * input_channels * input_height * input_width; - WORD8* out_batch = p_out_temp + _n * out_channels * out_height * out_width; - - xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s( - out_batch, - p_kernel, - in_batch, - p_bias, - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - channels_multiplier, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - input_zero_bias, - out_multiplier32, - out_shift32, - out_zero_bias, - 1, // NCHW - 0, // NHWC - p_scratch); - } - - WORD32 p_inp_shape[kNnlibMaxDim]; - p_inp_shape[0] = batches; - p_inp_shape[1] = out_height; - p_inp_shape[2] = out_width; - p_inp_shape[3] = out_channels; - - WORD32 p_out_shape[kNnlibMaxDim]; - p_out_shape[0] = batches; - p_out_shape[1] = out_channels; - p_out_shape[2] = out_height; - p_out_shape[3] = out_width; - - WORD32 p_permute_vec[kNnlibMaxDim] = {0, 3, 1, 2}; - - xa_nn_transpose_8_8( - p_out, - p_out_shape, - p_out_temp, - p_inp_shape, - p_permute_vec, - kNnlibMaxDim, - kNnlibMaxDim); -} - -void quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - xa_opt_quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s( - ctx, - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -} // namespace native -} // namespace HiFi -} // namespace impl -} // namespace 
cadence diff --git a/backends/cadence/hifi/operators/op_quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out.cpp deleted file mode 100644 index 103ce9568c5..00000000000 --- a/backends/cadence/hifi/operators/op_quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out.cpp +++ /dev/null @@ -1,203 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include - -#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) - -using Tensor = executorch::aten::Tensor; -using KernelRuntimeContext = torch::executor::KernelRuntimeContext; -using ScalarType = executorch::aten::ScalarType; -using ::executorch::aten::IntArrayRef; - -namespace cadence { -namespace impl { -namespace HiFi { -namespace native { - -// Specialized depthwise NCHW convolution for uint8 x uint8 -> uint8 -void xa_opt_quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u( - KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int16_t groups, - int32_t in_zero_point, - int32_t weight_zero_point, - float bias_scale, - float output_scale, - int32_t output_zero_point, - Tensor& out) { - bool conv1d = input.dim() == 3; - constexpr int kNnlibMaxDim = 4; - - UWORD8* __restrict__ p_out = - (UWORD8* __restrict__)out.mutable_data_ptr(); - UWORD8* __restrict__ p_inp = - (UWORD8* __restrict__)input.const_data_ptr(); - UWORD8* __restrict__ p_kernel = - (UWORD8* __restrict__)weight.const_data_ptr(); - WORD32* __restrict__ p_bias = - (WORD32* __restrict__)bias.const_data_ptr(); - - WORD32 input_height = conv1d ? 1 : input.size(2); - WORD32 input_width = conv1d ? input.size(2) : input.size(3); - WORD32 input_channels = input.size(1); - WORD32 kernel_height = conv1d ? 1 : weight.size(2); - WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3); - WORD32 out_channels = weight.size(0); - WORD32 out_height = conv1d ? 1 : out.size(2); - WORD32 out_width = conv1d ? out.size(2) : out.size(3); - WORD32 batches = input.size(0); - - WORD32 x_stride = stride[1]; - WORD32 y_stride = stride[0]; - WORD32 x_padding = padding[1]; - WORD32 y_padding = padding[0]; - - WORD32 input_zero_bias = -in_zero_point; - WORD32 out_zero_bias = output_zero_point; - WORD32 inp_precision = 8; - - WORD32 channels_multiplier = out_channels / input_channels; - - WORD32 out_multiplier32[out_channels]; - WORD32 out_shift32[out_channels]; - - float out_scale = 1. / output_scale; - - for (int i = 0; i < out_channels; i++) { - out_multiplier32[i] = bias_scale * out_scale * 2147483648; - out_shift32[i] = 0; - } - - WORD32 scratch_size = xa_nn_conv2d_depthwise_getsize( - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - channels_multiplier, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - inp_precision, - 1); // NCHW - - scratch_size = scratch_size < 0 ? 
0 : scratch_size; - - WORD32* ptr_scratch = - (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); - pVOID p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); - - UWORD8* ptr1 = (UWORD8*)kernels::allocate_temp_memory( - ctx, - ((batches * out_channels * out_height * out_width) + 8) * sizeof(UWORD8)); - - UWORD8* p_out_temp = (UWORD8*)ALIGN_PTR(ptr1, 8); - - for (int _n = 0; _n < batches; _n++) { - UWORD8* in_batch = p_inp + _n * input_channels * input_height * input_width; - UWORD8* out_batch = p_out_temp + _n * out_channels * out_height * out_width; - - xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s( - (WORD8*)out_batch, - (WORD8*)p_kernel, - (WORD8*)in_batch, - p_bias, - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - channels_multiplier, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - input_zero_bias, - out_multiplier32, - out_shift32, - out_zero_bias, - 1, // NCHW - 0, // NHWC - p_scratch); - } - - WORD32 p_inp_shape[kNnlibMaxDim]; - p_inp_shape[0] = batches; - p_inp_shape[1] = out_height; - p_inp_shape[2] = out_width; - p_inp_shape[3] = out_channels; - - WORD32 p_out_shape[kNnlibMaxDim]; - p_out_shape[0] = batches; - p_out_shape[1] = out_channels; - p_out_shape[2] = out_height; - p_out_shape[3] = out_width; - - WORD32 p_permute_vec[kNnlibMaxDim] = {0, 3, 1, 2}; - - xa_nn_transpose_8_8( - (WORD8*)p_out, - p_out_shape, - (WORD8*)p_out_temp, - p_inp_shape, - p_permute_vec, - kNnlibMaxDim, - kNnlibMaxDim); -} - -void quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - xa_opt_quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u( - ctx, - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -} // namespace native -} // namespace HiFi -} // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out.cpp deleted file mode 100644 index cdc1ecd8526..00000000000 --- a/backends/cadence/hifi/operators/op_quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out.cpp +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include - -using Tensor = executorch::aten::Tensor; -using KernelRuntimeContext = torch::executor::KernelRuntimeContext; -using ScalarType = executorch::aten::ScalarType; -using ::executorch::aten::IntArrayRef; - -namespace cadence { -namespace impl { -namespace HiFi { -namespace native { - -// Dilated fallback implementation for int8 x int8 -> int8 quantized 2d conv -// kernel for NCHW layout. This variant is optimized for asymmetric int8 inputs, -// weights, and outputs. 
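For context, this dilated fallback is only reached when the wrapper functions elsewhere in this change decide the NNLIB-optimized path cannot be used: those wrappers require 8-bit activations and unit dilation, and everything else drops to the generic reference kernel. A minimal, self-contained sketch of that predicate (the helper name `use_optimized_conv` is illustrative, not part of this change) is:

```cpp
#include <cstdint>

// Sketch only: mirrors the routing logic in the quantized_conv*_out wrappers
// in this change. ScalarTypeId stands in for executorch's ScalarType.
enum class ScalarTypeId { Char /*int8*/, Byte /*uint8*/, Short /*int16*/, Other };

// True when the NNLIB-optimized kernel applies: 8-bit activations and unit
// dilation in both spatial dimensions. Dilated or 16-bit cases fall back to
// the generic reference implementation such as the one below.
inline bool use_optimized_conv(
    ScalarTypeId input_dtype,
    int64_t dilation_h,
    int64_t dilation_w) {
  const bool eight_bit =
      input_dtype == ScalarTypeId::Char || input_dtype == ScalarTypeId::Byte;
  const bool unit_dilation = dilation_h == 1 && dilation_w == 1;
  return eight_bit && unit_dilation;
}
```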
The input is of shape [n x c x h x w] The weight is of -// shape [oc x wc x wh x ww], where wc == c The output is of shape [n x oc x oh -// x ow] The bias is of shape [oc] -template -__attribute__((noinline)) void conv2d_nchw_dilated_asym8sxsym8s_asym8s_core( - // All the arrays - const int8_t* __restrict__ p_in, - const int8_t* __restrict__ p_weight, - const int32_t* __restrict__ p_bias, - int8_t* __restrict__ p_out, - // The array sizes - int32_t n, - int32_t c, - int32_t h, - int32_t w, - int32_t oc, - int32_t wc, - int32_t wh, - int32_t ww, - int32_t oh, - int32_t ow, - // Stride - int16_t s0, - int16_t s1, - // Padding - int16_t p0, - int16_t p1, - // Dilation - int16_t d0, - int16_t d1, - // Group for depthwise conv - int16_t groups, - // Quantization parameters - int8_t in_zero_point = 0, - int32_t weight_zero_point = 0, - float bias_scale = 1, - float out_scale = 1, - int8_t out_zero_point = 0) { - float inv_out_scale = 1. / out_scale; - - // Compute the number of in and out channels per group - const int ocpg = oc / groups; - const int icpg = c / groups; - - // Iterate over all the output batches (i.e., n) - for (int _n = 0; _n < n; ++_n) { - const int8_t* in_batch = p_in + _n * c * h * w; - int8_t* out_batch = p_out + _n * oc * oh * ow; - // Compute separable convolution for each group - for (int _g = 0; _g < groups; ++_g) { - // Identify the input and output channels involved in the computation - // of this group - int sic = _g * icpg; - int soc = _g * ocpg; - // Populate all the output channels in the group - for (int _oc = soc; _oc < soc + ocpg; ++_oc) { - int8_t* out_plane = out_batch + _oc * oh * ow; - const int8_t* weight_batch = p_weight + _oc * wc * wh * ww; - // We compute one output channel at a time. The computation can be - // thought of as a stencil computation: we iterate over an input of size - // icpg x h x w, with a stencil of size icpg x wh x ww, to compute an - // output channel of size 1 x oh x ow. - for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { - for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { - float acc = p_bias[_oc]; - // Below is the stencil computation that performs the hadamard - // product+accumulation of each input channel (contributing to the - // output channel being computed) with the corresponding weight - // channel. 
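The loop that follows implements this stencil directly. As a compact restatement of the per-tap arithmetic for a single input/weight channel pair (all names here are local to the illustration, not part of the kernel), the accumulation with dilation, padding, and zero-point handling can be written as:

```cpp
#include <cstdint>

// Sketch only: accumulate one output element's contribution from one input
// channel of the group, mirroring the dilated NCHW fallback's inner loops.
inline float dilated_conv_acc_one_channel(
    const int8_t* in_plane,      // one input channel, h x w
    const int8_t* weight_plane,  // matching weight channel, wh x ww
    int32_t h, int32_t w, int32_t wh, int32_t ww,
    int32_t out_y, int32_t out_x,  // _h, _w after the stride has been applied
    int16_t d0, int16_t d1,        // dilation
    int16_t p0, int16_t p1,        // padding
    int8_t in_zero_point, int32_t weight_zero_point) {
  float acc = 0.f;
  for (int32_t ky = 0; ky < wh; ++ky) {
    for (int32_t kx = 0; kx < ww; ++kx) {
      // Dilation spreads the taps apart; padding shifts them into the input.
      const int32_t iy = out_y + d0 * ky - p0;
      const int32_t ix = out_x + d1 * kx - p1;
      if (iy >= 0 && iy < h && ix >= 0 && ix < w) {  // implicit zero padding
        const float lhs = static_cast<float>(in_plane[iy * w + ix]) -
            static_cast<float>(in_zero_point);
        const float rhs = static_cast<float>(weight_plane[ky * ww + kx]) -
            static_cast<float>(weight_zero_point);
        acc += lhs * rhs;
      }
    }
  }
  // The caller sums this over the group's input channels, starts the
  // accumulator at the per-channel bias, and then requantizes the result.
  return acc;
}
```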
- // General path for dilated convolutions with padding support - for (int _ic = sic; _ic < sic + icpg; ++_ic) { - const int8_t* in_plane = in_batch + _ic * h * w; - const int8_t* weight_plane = weight_batch + (_ic - sic) * wh * ww; - for (int _wh = 0; _wh < wh; ++_wh) { - for (int _ww = 0; _ww < ww; ++_ww) { - int input_h = _h + d0 * _wh - p0; - int input_w = _w + d1 * _ww - p1; - if ((input_h >= 0) && (input_h < h) && (input_w >= 0) && - (input_w < w)) { - int ioff = input_h * w + input_w; - int woff = _wh * ww + _ww; - float lhs = static_cast(in_plane[ioff]) - - static_cast(in_zero_point); - float rhs = static_cast(weight_plane[woff]) - - static_cast(weight_zero_point); - acc += lhs * rhs; - } - } - } - } - // Quantize the accumulated result - float val = bias_scale * acc; - out_plane[_oh * ow + _ow] = - kernels::quantize(val, inv_out_scale, out_zero_point); - } - } - } - } - } -} - -void quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - bool conv1d = input.dim() == 3; - // input = [n, c, h, w] - const int n = input.size(0); - const int c = input.size(1); - const int h = conv1d ? 1 : input.size(2); - const int w = conv1d ? input.size(2) : input.size(3); - // weight = [oc, wc, wh, ww] - const int oc = weight.size(0); - const int wc = weight.size(1); - const int wh = conv1d ? 1 : weight.size(2); - const int ww = conv1d ? weight.size(2) : weight.size(3); - // output = [n, oc, oh, ow] - const int oh = conv1d ? 1 : out.size(2); - const int ow = conv1d ? out.size(2) : out.size(3); - - conv2d_nchw_dilated_asym8sxsym8s_asym8s_core( - input.const_data_ptr(), - weight.const_data_ptr(), - bias.const_data_ptr(), - out.mutable_data_ptr(), - n, - c, - h, - w, - oc, - wc, - wh, - ww, - oh, - ow, - stride[0], - stride[1], - padding[0], - padding[1], - dilation[0], - dilation[1], - groups, - static_cast(in_zero_point), - weight_zero_point, - bias_scale, - output_scale, - static_cast(output_zero_point)); -} - -} // namespace native -} // namespace HiFi -} // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out.cpp deleted file mode 100644 index 9281dcea496..00000000000 --- a/backends/cadence/hifi/operators/op_quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out.cpp +++ /dev/null @@ -1,191 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include - -using Tensor = executorch::aten::Tensor; -using KernelRuntimeContext = torch::executor::KernelRuntimeContext; -using ScalarType = executorch::aten::ScalarType; -using ::executorch::aten::IntArrayRef; - -namespace cadence { -namespace impl { -namespace HiFi { -namespace native { - -// Dilated fallback implementation for uint8 x uint8 -> uint8 quantized 2d conv -// kernel for NCHW layout. 
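Note that these kernels read the output spatial extents (oh, ow) from the pre-sized output tensor rather than recomputing them. For reference, and assuming the usual convolution arithmetic (which this diff does not restate), the extents they are expected to satisfy per spatial dimension are:

```cpp
#include <cstdint>

// Reference only: expected output extent for one spatial dimension under the
// standard convolution size formula. The kernels in this file never call
// anything like this; they trust the shape of the out tensor.
inline int32_t conv_out_extent(
    int32_t in, int32_t kernel, int32_t stride, int32_t pad, int32_t dilation) {
  const int32_t effective_kernel = dilation * (kernel - 1) + 1;
  return (in + 2 * pad - effective_kernel) / stride + 1;
}
```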
This variant is optimized for asymmetric uint8 -// inputs, weights, and outputs. The input is of shape [n x c x h x w] The -// weight is of shape [oc x wc x wh x ww], where wc == c The output is of shape -// [n x oc x oh x ow] The bias is of shape [oc] -template -__attribute__((noinline)) void conv2d_nchw_dilated_asym8uxsym8u_asym8u_core( - // All the arrays - const uint8_t* __restrict__ p_in, - const uint8_t* __restrict__ p_weight, - const int32_t* __restrict__ p_bias, - uint8_t* __restrict__ p_out, - // The array sizes - int32_t n, - int32_t c, - int32_t h, - int32_t w, - int32_t oc, - int32_t wc, - int32_t wh, - int32_t ww, - int32_t oh, - int32_t ow, - // Stride - int16_t s0, - int16_t s1, - // Padding - int16_t p0, - int16_t p1, - // Dilation - int16_t d0, - int16_t d1, - // Group for depthwise conv - int16_t groups, - // Quantization parameters - uint8_t in_zero_point = 0, - int32_t weight_zero_point = 0, - float bias_scale = 1, - float out_scale = 1, - uint8_t out_zero_point = 0) { - float inv_out_scale = 1. / out_scale; - - // Compute the number of in and out channels per group - const int ocpg = oc / groups; - const int icpg = c / groups; - - // Iterate over all the output batches (i.e., n) - for (int _n = 0; _n < n; ++_n) { - const uint8_t* in_batch = p_in + _n * c * h * w; - uint8_t* out_batch = p_out + _n * oc * oh * ow; - // Compute separable convolution for each group - for (int _g = 0; _g < groups; ++_g) { - // Identify the input and output channels involved in the computation - // of this group - int sic = _g * icpg; - int soc = _g * ocpg; - // Populate all the output channels in the group - for (int _oc = soc; _oc < soc + ocpg; ++_oc) { - uint8_t* out_plane = out_batch + _oc * oh * ow; - const uint8_t* weight_batch = p_weight + _oc * wc * wh * ww; - // We compute one output channel at a time. The computation can be - // thought of as a stencil computation: we iterate over an input of size - // icpg x h x w, with a stencil of size icpg x wh x ww, to compute an - // output channel of size 1 x oh x ow. - for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { - for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { - float acc = p_bias[_oc]; - // Below is the stencil computation that performs the hadamard - // product+accumulation of each input channel (contributing to the - // output channel being computed) with the corresponding weight - // channel. 
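After the accumulation, the result is scaled by bias_scale and handed to kernels::quantize together with the inverse output scale and output zero point. The internals of that helper are not shown in this diff; a plain C++ approximation of the intended behaviour, assuming round-to-nearest with saturation, is:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

// Approximation only: requantize a float accumulator to an asymmetric uint8
// value. kernels::quantize may use a different rounding mode.
inline uint8_t requantize_to_u8(
    float acc, float bias_scale, float output_scale, uint8_t out_zero_point) {
  const float val = bias_scale * acc;       // fold in the bias/weight scales
  const float scaled = val / output_scale;  // same as val * inv_out_scale
  const int32_t q = static_cast<int32_t>(std::nearbyint(scaled)) +
      static_cast<int32_t>(out_zero_point);
  const int32_t lo = std::numeric_limits<uint8_t>::min();
  const int32_t hi = std::numeric_limits<uint8_t>::max();
  return static_cast<uint8_t>(std::min(std::max(q, lo), hi));
}
```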
- // General path for dilated convolutions with padding support - for (int _ic = sic; _ic < sic + icpg; ++_ic) { - const uint8_t* in_plane = in_batch + _ic * h * w; - const uint8_t* weight_plane = - weight_batch + (_ic - sic) * wh * ww; - for (int _wh = 0; _wh < wh; ++_wh) { - for (int _ww = 0; _ww < ww; ++_ww) { - int input_h = _h + d0 * _wh - p0; - int input_w = _w + d1 * _ww - p1; - if ((input_h >= 0) && (input_h < h) && (input_w >= 0) && - (input_w < w)) { - int ioff = input_h * w + input_w; - int woff = _wh * ww + _ww; - float lhs = static_cast(in_plane[ioff]) - - static_cast(in_zero_point); - float rhs = static_cast(weight_plane[woff]) - - static_cast(weight_zero_point); - acc += lhs * rhs; - } - } - } - } - // Quantize the accumulated result - float val = bias_scale * acc; - out_plane[_oh * ow + _ow] = - kernels::quantize(val, inv_out_scale, out_zero_point); - } - } - } - } - } -} - -void quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - bool conv1d = input.dim() == 3; - // input = [n, c, h, w] - const int n = input.size(0); - const int c = input.size(1); - const int h = conv1d ? 1 : input.size(2); - const int w = conv1d ? input.size(2) : input.size(3); - // weight = [oc, wc, wh, ww] - const int oc = weight.size(0); - const int wc = weight.size(1); - const int wh = conv1d ? 1 : weight.size(2); - const int ww = conv1d ? weight.size(2) : weight.size(3); - // output = [n, oc, oh, ow] - const int oh = conv1d ? 1 : out.size(2); - const int ow = conv1d ? out.size(2) : out.size(3); - - conv2d_nchw_dilated_asym8uxsym8u_asym8u_core( - input.const_data_ptr(), - weight.const_data_ptr(), - bias.const_data_ptr(), - out.mutable_data_ptr(), - n, - c, - h, - w, - oc, - wc, - wh, - ww, - oh, - ow, - stride[0], - stride[1], - padding[0], - padding[1], - dilation[0], - dilation[1], - groups, - static_cast(in_zero_point), - weight_zero_point, - bias_scale, - output_scale, - static_cast(output_zero_point)); -} - -} // namespace native -} // namespace HiFi -} // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_conv_nchw_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv_nchw_out.cpp deleted file mode 100644 index fbc97a4c37b..00000000000 --- a/backends/cadence/hifi/operators/op_quantized_conv_nchw_out.cpp +++ /dev/null @@ -1,646 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include - -#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) - -namespace cadence { -namespace impl { -namespace HiFi { -namespace native { - -using ::cadence::impl::HiFi::kernels::quantize; -using ::executorch::aten::IntArrayRef; -using ::executorch::aten::ScalarType; -using ::executorch::aten::Tensor; -using ::torch::executor::KernelRuntimeContext; - -// This implements a generic 2d conv kernel that operates on raw pointers. -// The version handles both quantized and fp32 convolutions. 
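Because the core is templated on the input, weight, bias, and output element types plus a compile-time `quantized` flag, one body serves both the fp32 and the quantized paths. The toy analogue below shows the pattern in isolation (it is a self-contained illustration, not a call into the kernel above):

```cpp
#include <cstdint>
#include <cstdio>

// Minimal analogue of the conv2d_nchw_core_generic pattern: a single templated
// body, with quantization selected by a compile-time flag so the float path
// never touches the scale/zero-point arguments.
template <typename IT, typename OT = IT, bool quantized = false>
OT accumulate_and_store(
    IT in, IT weight, float scale = 1.f, OT zero_point = OT{0}) {
  float acc = static_cast<float>(in) * static_cast<float>(weight);
  if (quantized) {
    // Quantized path: scale and shift into the output domain.
    return static_cast<OT>(static_cast<OT>(acc * scale) + zero_point);
  }
  // fp32 path: pass the accumulator through unchanged.
  return static_cast<OT>(acc);
}

int main() {
  float f = accumulate_and_store<float>(1.5f, 2.0f);
  int8_t q = accumulate_and_store<int8_t, int8_t, true>(
      int8_t{3}, int8_t{4}, 0.5f, int8_t{2});
  std::printf("fp32: %f, quantized: %d\n", f, static_cast<int>(q));
  return 0;
}
```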
-// The input is of shape [n x c x h x w] -// The weight is of shape [oc x wc x wh x ww], where wc == c -// The output is of shape [n x oc x oh x ow] -// The bias is of shape [oc] -template < - typename IT = float, - typename WT = IT, - typename BT = IT, - typename OT = IT, - bool quantized = false> -__attribute__((noinline)) void conv2d_nchw_core_generic( - // All the arrays - const IT* __restrict__ p_in, - const WT* __restrict__ p_weight, - const BT* __restrict__ p_bias, - OT* __restrict__ p_out, - // The array sizes - int32_t n, - int32_t c, - int32_t h, - int32_t w, - int32_t oc, - int32_t wc, - int32_t wh, - int32_t ww, - int32_t oh, - int32_t ow, - // Stride - int16_t s0, - int16_t s1, - // Padding - int16_t p0, - int16_t p1, - // Dilation - int16_t d0, - int16_t d1, - // Group for depthwise conv - int16_t groups, - // Optional args that are only relevant for quantized convolution - // input zero point - IT in_zero_point = 0, - // weight zero point - int32_t weight_zero_point = 0, - float bias_scale = 1, - float out_scale = 1, - OT out_zero_point = 0) { - float inv_out_scale = 1. / out_scale; - bool zero_pad_unit_dilation = d0 == 1 && d1 == 1 && p0 == 0 && p1 == 0; - - // Compute the number of in and out channels per group - const int ocpg = oc / groups; - const int icpg = c / groups; - - // Iterate over all the output batches (i.e., n) - for (int _n = 0; _n < n; ++_n) { - const IT* in_batch = p_in + _n * c * h * w; - OT* out_batch = p_out + _n * oc * oh * ow; - // Compute separable convolution for each group - for (int _g = 0; _g < groups; ++_g) { - // Identify the input and output channels involved in the computation - // of this group - int sic = _g * icpg; - int soc = _g * ocpg; - // Populate all the output channels in the group - for (int _oc = soc; _oc < soc + ocpg; ++_oc) { - OT* out_plane = out_batch + _oc * oh * ow; - const WT* weight_batch = p_weight + _oc * wc * wh * ww; - // We compute one output channel at a time. The computation can be - // thought of as a stencil computation: we iterate over an input of size - // icpg x h x w, with a stencil of size icpg x wh x ww, to compute an - // output channel of size 1 x oh x ow. - for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { - for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { - float acc = p_bias[_oc]; - // Below is the stencil computation that performs the hadamard - // product+accumulation of each input channel (contributing to the - // output channel being computed) with the corresponding weight - // channel. - // If the padding is 0, and dilation is 1, then we can remove the - // unnecessary checks, and simplify the code so that it can be - // vectorized by Tensilica compiler. - if (zero_pad_unit_dilation) { - for (int _ic = sic; _ic < sic + icpg; ++_ic) { - const IT* in_plane = in_batch + _ic * h * w; - const WT* weight_plane = weight_batch + (_ic - sic) * wh * ww; - for (int _wh = 0; _wh < wh; ++_wh) { - for (int _ww = 0; _ww < ww; ++_ww) { - int ioff = (_h + _wh) * w + (_w + _ww); - int woff = _wh * ww + _ww; - float lhs = in_plane[ioff] - in_zero_point; - float rhs = weight_plane[woff] - - (quantized ? 
weight_zero_point : 0); - acc += lhs * rhs; - } - } - } - } else { - for (int _ic = sic; _ic < sic + icpg; ++_ic) { - const IT* in_plane = in_batch + _ic * h * w; - const WT* weight_plane = weight_batch + (_ic - sic) * wh * ww; - for (int _wh = 0; _wh < wh; ++_wh) { - for (int _ww = 0; _ww < ww; ++_ww) { - if (((_h + d0 * _wh - p0) >= 0) && - ((_h + d0 * _wh - p0) < h) && - ((_w + d1 * _ww - p1) >= 0) && - ((_w + d1 * _ww - p1) < w)) { - int ioff = - (_h + d0 * _wh - p0) * w + (_w + d1 * _ww - p1); - int woff = _wh * ww + _ww; - float lhs = in_plane[ioff] - in_zero_point; - float rhs = weight_plane[woff] - - (quantized ? weight_zero_point : 0); - acc += lhs * rhs; - } - } - } - } - } - if (quantized) { - float val = bias_scale * acc; - out_plane[_oh * ow + _ow] = - quantize(val, inv_out_scale, out_zero_point); - } else { - out_plane[_oh * ow + _ow] = acc; - } - } - } - } - } - } -} - -void xa_opt_quantized_conv_nchw( - KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int16_t groups, - int32_t in_zero_point, - int32_t weight_zero_point, - float bias_scale, - float output_scale, - int32_t output_zero_point, - Tensor& out) { - bool conv1d = input.dim() == 3; - constexpr int kNnlibMaxDim = 4; - - if (input.scalar_type() == ScalarType::Char) { - WORD8* __restrict__ p_out = - (WORD8* __restrict__)out.mutable_data_ptr(); - WORD8* __restrict__ p_inp = - (WORD8* __restrict__)input.const_data_ptr(); - WORD8* __restrict__ p_kernel = - (WORD8* __restrict__)weight.const_data_ptr(); - WORD32* __restrict__ p_bias = - (WORD32* __restrict__)bias.const_data_ptr(); - - WORD32 input_height = conv1d ? 1 : input.size(2); - WORD32 input_width = conv1d ? input.size(2) : input.size(3); - WORD32 input_channels = input.size(1); - WORD32 kernel_height = conv1d ? 1 : weight.size(2); - WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3); - WORD32 kernel_channels = weight.size(1); - WORD32 out_channels = weight.size(0); - WORD32 out_height = conv1d ? 1 : out.size(2); - WORD32 out_width = conv1d ? out.size(2) : out.size(3); - WORD32 batches = input.size(0); - - WORD32 x_stride = stride[1]; - WORD32 y_stride = stride[0]; - WORD32 x_padding = padding[1]; - WORD32 y_padding = padding[0]; - WORD32 dilation_width = dilation[1]; - WORD32 dilation_height = dilation[0]; - - // WORD32* kernel_bias_ptr = - // (WORD32*)weight_zero_point.const_data_ptr(); - - WORD32 input_zero_bias = -in_zero_point; - WORD32 kernel_zero_bias = -weight_zero_point; - - WORD32 out_multiplier32[out_channels]; - WORD32 out_shift32[out_channels]; - - float out_scale = 1. 
/ output_scale; - - for (int i = 0; i < out_channels; i++) { - out_multiplier32[i] = bias_scale * out_scale * 2147483648; - out_shift32[i] = 0; - } - - WORD32 out_zero_bias = output_zero_point; - WORD32 inp_precision = 8; - WORD32 kernel_precision = 8; - pVOID p_scratch = nullptr; - WORD32* ptr_scratch; - - WORD32 scratch_size = 0; - - if (groups == 1) { - WORD32 out_data_format = 1; - - WORD8* ptr1 = (WORD8*)kernels::allocate_temp_memory( - ctx, - ((batches * input_channels * input_height * input_width) + 8) * - sizeof(WORD8)); - - WORD8* ptr2 = (WORD8*)kernels::allocate_temp_memory( - ctx, - ((out_channels * kernel_channels * kernel_height * kernel_width) + - 8) * - sizeof(WORD8)); - - WORD8* pin = (WORD8*)ALIGN_PTR(ptr1, 8); - WORD8* pkernel = (WORD8*)ALIGN_PTR(ptr2, 8); - - WORD32 p_inp_shape[kNnlibMaxDim]; - p_inp_shape[0] = input.size(0); - p_inp_shape[1] = input_channels; - p_inp_shape[2] = input_height; - p_inp_shape[3] = input_width; - - WORD32 p_out_shape[kNnlibMaxDim]; - p_out_shape[0] = input.size(0); - p_out_shape[1] = input_height; - p_out_shape[2] = input_width; - p_out_shape[3] = input_channels; - - WORD32 p_permute_vec[kNnlibMaxDim] = {0, 2, 3, 1}; - - xa_nn_transpose_8_8( - pin, - p_out_shape, - p_inp, - p_inp_shape, - p_permute_vec, - kNnlibMaxDim, // input dimensions - kNnlibMaxDim); // output dimensions - - WORD32 p_inp_shape1[kNnlibMaxDim]; - p_inp_shape1[0] = out_channels; - p_inp_shape1[1] = kernel_channels; - p_inp_shape1[2] = kernel_height; - p_inp_shape1[3] = kernel_width; - - WORD32 p_out_shape1[kNnlibMaxDim]; - p_out_shape1[0] = out_channels; - p_out_shape1[1] = kernel_height; - p_out_shape1[2] = kernel_width; - p_out_shape1[3] = kernel_channels; - - xa_nn_transpose_8_8( - pkernel, - p_out_shape1, - p_kernel, - p_inp_shape1, - p_permute_vec, - kNnlibMaxDim, // input dimensions - kNnlibMaxDim); // output dimensions - - scratch_size = xa_nn_conv2d_getsize( - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - kernel_channels, - dilation_height, - dilation_width, - y_stride, - y_padding, - x_stride, - x_padding, - out_height, - out_width, - out_channels, - inp_precision, - kernel_precision, - out_data_format); - - scratch_size = scratch_size < 0 ? 0 : scratch_size; - - ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); - - p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); - - for (int _n = 0; _n < batches; _n++) { - WORD8* in_batch = - pin + _n * input_channels * input_height * input_width; - WORD8* out_batch = p_out + _n * out_channels * out_height * out_width; - - xa_nn_conv2d_per_chan_sym8sxasym8s( - out_batch, - in_batch, - pkernel, - p_bias, - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - kernel_channels, - dilation_height, - dilation_width, - out_channels, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - input_zero_bias, - out_multiplier32, - out_shift32, - out_zero_bias, - out_data_format, - p_scratch); - } - return; - } - - if (groups == input_channels) { - WORD32 channels_multiplier = out_channels / input_channels; - - scratch_size = xa_nn_conv2d_depthwise_getsize( - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - channels_multiplier, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - inp_precision, - 1); // NCHW - - scratch_size = scratch_size < 0 ? 
0 : scratch_size; - - ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); - - p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); - - WORD8* ptr1 = (WORD8*)kernels::allocate_temp_memory( - ctx, - ((batches * out_channels * out_height * out_width) + 8) * - sizeof(WORD8)); - - WORD8* p_out_temp = (WORD8*)ALIGN_PTR(ptr1, 8); - - for (int _n = 0; _n < batches; _n++) { - WORD8* in_batch = - p_inp + _n * input_channels * input_height * input_width; - WORD8* out_batch = - p_out_temp + _n * out_channels * out_height * out_width; - - xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s( - out_batch, - p_kernel, - in_batch, - p_bias, - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - channels_multiplier, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - input_zero_bias, - out_multiplier32, - out_shift32, - out_zero_bias, - 1, // NCHW - 0, // NHWC - p_scratch); - } - - WORD32 p_inp_shape[kNnlibMaxDim]; - p_inp_shape[0] = batches; - p_inp_shape[1] = out_height; - p_inp_shape[2] = out_width; - p_inp_shape[3] = out_channels; - - WORD32 p_out_shape[kNnlibMaxDim]; - p_out_shape[0] = batches; - p_out_shape[1] = out_channels; - p_out_shape[2] = out_height; - p_out_shape[3] = out_width; - - WORD32 p_permute_vec[kNnlibMaxDim] = {0, 3, 1, 2}; - - xa_nn_transpose_8_8( - p_out, - p_out_shape, - p_out_temp, - p_inp_shape, - p_permute_vec, - kNnlibMaxDim, // input dimensions - kNnlibMaxDim); // output dimensions - - return; - } - } -} - -// The quantized convolution kernel. in_scale and weight_scale are implicit in -// bias_scale, since it is a product of the two. The kernel will branch to -// quantized::conv1d or quantized::conv2d based on the dimensionality of -// activation tensor. -void quantized_conv_nchw( - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int16_t groups, - int32_t in_zero_point, - int32_t weight_zero_point, - float bias_scale, - float output_scale, - int32_t output_zero_point, - Tensor& out) { - bool conv1d = input.dim() == 3; - // input = [n, c, h, w] - const int n = input.size(0); - const int c = input.size(1); - const int h = conv1d ? 1 : input.size(2); - const int w = conv1d ? input.size(2) : input.size(3); - // weight = [oc, wc, wh, ww] - const int oc = weight.size(0); - const int wc = weight.size(1); - const int wh = conv1d ? 1 : weight.size(2); - const int ww = conv1d ? weight.size(2) : weight.size(3); - // output = [n, oc, oh, ow] - const int oh = conv1d ? 1 : out.size(2); - const int ow = conv1d ? 
out.size(2) : out.size(3); - -#define typed_quantized_conv2d_nchw(ctype, dtype) \ - case ScalarType::dtype: { \ - conv2d_nchw_core_generic( \ - input.const_data_ptr(), \ - weight.const_data_ptr(), \ - bias.const_data_ptr(), \ - out.mutable_data_ptr(), \ - n, \ - c, \ - h, \ - w, \ - oc, \ - wc, \ - wh, \ - ww, \ - oh, \ - ow, \ - stride[0], \ - stride[1], \ - padding[0], \ - padding[1], \ - dilation[0], \ - dilation[1], \ - groups, \ - in_zero_point, \ - weight_zero_point, \ - bias_scale, \ - output_scale, \ - (ctype)output_zero_point); \ - break; \ - } - ScalarType dtype = out.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_conv2d_nchw); - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", torch::executor::toString(dtype)); - } - -#undef typed_quantized_conv2d_nchw -} - -void quantized_conv_nchw_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - const Tensor& weight_zero_point, - const Tensor& bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED const Tensor& out_multiplier, - __ET_UNUSED const Tensor& out_shift, - Tensor& out) { - const float bias_scale_float = bias_scale.const_data_ptr()[0]; - const int32_t weight_zero_point_int = - weight_zero_point.const_data_ptr()[0]; - - bool optimized = 0; - - if ((input.scalar_type() == ScalarType::Char) || - (input.scalar_type() == ScalarType::Byte)) - optimized = 1; - - if ((dilation[0] != 1) || (dilation[1] != 1)) - optimized = 0; - - if (optimized) { - xa_opt_quantized_conv_nchw( - ctx, - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point_int, - bias_scale_float, - output_scale, - output_zero_point, - out); - } else { - quantized_conv_nchw( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point_int, - bias_scale_float, - output_scale, - output_zero_point, - out); - } -} - -void quantized_conv_nchw_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - bool optimized = 0; - - if ((input.scalar_type() == ScalarType::Char) || - (input.scalar_type() == ScalarType::Byte)) - optimized = 1; - - if ((dilation[0] != 1) || (dilation[1] != 1)) - optimized = 0; - - if (optimized) { - xa_opt_quantized_conv_nchw( - ctx, - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); - } else { - quantized_conv_nchw( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); - } -} - -} // namespace native -} // namespace HiFi -} // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_out.cpp deleted file mode 100644 index b1e023736cf..00000000000 --- 
a/backends/cadence/hifi/operators/op_quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_out.cpp +++ /dev/null @@ -1,191 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include - -#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) - -using Tensor = executorch::aten::Tensor; -using KernelRuntimeContext = torch::executor::KernelRuntimeContext; -using ScalarType = executorch::aten::ScalarType; -using ::executorch::aten::IntArrayRef; - -namespace cadence { -namespace impl { -namespace HiFi { -namespace native { - -// Optimized NHWC convolution for int8 x int8 -> int8 -void xa_opt_quantized_conv_nhwc_asym8sxsym8s_asym8s( - KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int16_t groups, - int32_t in_zero_point, - int32_t weight_zero_point, - float bias_scale, - float output_scale, - int32_t output_zero_point, - Tensor& out) { - bool conv1d = input.dim() == 3; - constexpr int kNnlibMaxDim = 4; - - WORD8* __restrict__ p_out = - (WORD8* __restrict__)out.mutable_data_ptr(); - WORD8* __restrict__ p_inp = - (WORD8* __restrict__)input.const_data_ptr(); - WORD8* __restrict__ p_kernel = - (WORD8* __restrict__)weight.const_data_ptr(); - WORD32* __restrict__ p_bias = - (WORD32* __restrict__)bias.const_data_ptr(); - - WORD32 input_height = conv1d ? 1 : input.size(2); - WORD32 input_width = conv1d ? input.size(2) : input.size(3); - WORD32 input_channels = input.size(1); - WORD32 kernel_height = conv1d ? 1 : weight.size(2); - WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3); - WORD32 kernel_channels = weight.size(1); - WORD32 out_channels = weight.size(0); - WORD32 out_height = conv1d ? 1 : out.size(2); - WORD32 out_width = conv1d ? out.size(2) : out.size(3); - WORD32 batches = input.size(0); - - WORD32 x_stride = stride[1]; - WORD32 y_stride = stride[0]; - WORD32 x_padding = padding[1]; - WORD32 y_padding = padding[0]; - WORD32 dilation_width = dilation[1]; - WORD32 dilation_height = dilation[0]; - - WORD32 input_zero_bias = -in_zero_point; - WORD32 kernel_zero_bias = -weight_zero_point; - - WORD32 out_multiplier32[out_channels]; - WORD32 out_shift32[out_channels]; - - float out_scale = 1. / output_scale; - - for (int i = 0; i < out_channels; i++) { - out_multiplier32[i] = bias_scale * out_scale * 2147483648; - out_shift32[i] = 0; - } - - WORD32 out_zero_bias = output_zero_point; - WORD32 inp_precision = 8; - WORD32 kernel_precision = 8; - pVOID p_scratch = nullptr; - WORD32* ptr_scratch; - - WORD32 scratch_size = 0; - - ET_CHECK_MSG(groups == 1, "Only groups=1 supported for regular convolution"); - WORD32 out_data_format = 1; - - scratch_size = xa_nn_conv2d_getsize( - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - kernel_channels, - dilation_height, - dilation_width, - y_stride, - y_padding, - x_stride, - x_padding, - out_height, - out_width, - out_channels, - inp_precision, - kernel_precision, - out_data_format); - - scratch_size = scratch_size < 0 ? 
0 : scratch_size; - - ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); - - p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); - - for (int _n = 0; _n < batches; _n++) { - WORD8* in_batch = p_inp + _n * input_channels * input_height * input_width; - WORD8* out_batch = p_out + _n * out_channels * out_height * out_width; - - xa_nn_conv2d_per_chan_sym8sxasym8s( - out_batch, - in_batch, - p_kernel, - p_bias, - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - kernel_channels, - dilation_height, - dilation_width, - out_channels, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - input_zero_bias, - out_multiplier32, - out_shift32, - out_zero_bias, - out_data_format, - p_scratch); - } -} - -void quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - xa_opt_quantized_conv_nhwc_asym8sxsym8s_asym8s( - ctx, - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -} // namespace native -} // namespace HiFi -} // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_out.cpp deleted file mode 100644 index 0678cb1b821..00000000000 --- a/backends/cadence/hifi/operators/op_quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_out.cpp +++ /dev/null @@ -1,191 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include - -#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) - -using Tensor = executorch::aten::Tensor; -using KernelRuntimeContext = torch::executor::KernelRuntimeContext; -using ScalarType = executorch::aten::ScalarType; -using ::executorch::aten::IntArrayRef; - -namespace cadence { -namespace impl { -namespace HiFi { -namespace native { - -// Optimized NHWC convolution for uint8 x uint8 -> uint8 -void xa_opt_quantized_conv_nhwc_asym8uxsym8u_asym8u( - KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int16_t groups, - int32_t in_zero_point, - int32_t weight_zero_point, - float bias_scale, - float output_scale, - int32_t output_zero_point, - Tensor& out) { - bool conv1d = input.dim() == 3; - constexpr int kNnlibMaxDim = 4; - - UWORD8* __restrict__ p_out = - (UWORD8* __restrict__)out.mutable_data_ptr(); - UWORD8* __restrict__ p_inp = - (UWORD8* __restrict__)input.const_data_ptr(); - UWORD8* __restrict__ p_kernel = - (UWORD8* __restrict__)weight.const_data_ptr(); - WORD32* __restrict__ p_bias = - (WORD32* __restrict__)bias.const_data_ptr(); - - WORD32 input_height = conv1d ? 1 : input.size(2); - WORD32 input_width = conv1d ? 
input.size(2) : input.size(3); - WORD32 input_channels = input.size(1); - WORD32 kernel_height = conv1d ? 1 : weight.size(2); - WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3); - WORD32 kernel_channels = weight.size(1); - WORD32 out_channels = weight.size(0); - WORD32 out_height = conv1d ? 1 : out.size(2); - WORD32 out_width = conv1d ? out.size(2) : out.size(3); - WORD32 batches = input.size(0); - - WORD32 x_stride = stride[1]; - WORD32 y_stride = stride[0]; - WORD32 x_padding = padding[1]; - WORD32 y_padding = padding[0]; - WORD32 dilation_width = dilation[1]; - WORD32 dilation_height = dilation[0]; - - WORD32 input_zero_bias = -in_zero_point; - WORD32 kernel_zero_bias = -weight_zero_point; - - WORD32 out_multiplier32[out_channels]; - WORD32 out_shift32[out_channels]; - - float out_scale = 1. / output_scale; - - for (int i = 0; i < out_channels; i++) { - out_multiplier32[i] = bias_scale * out_scale * 2147483648; - out_shift32[i] = 0; - } - - WORD32 out_zero_bias = output_zero_point; - WORD32 inp_precision = 8; - WORD32 kernel_precision = 8; - pVOID p_scratch = nullptr; - WORD32* ptr_scratch; - - WORD32 scratch_size = 0; - - ET_CHECK_MSG(groups == 1, "Only groups=1 supported for regular convolution"); - WORD32 out_data_format = 1; - - scratch_size = xa_nn_conv2d_getsize( - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - kernel_channels, - dilation_height, - dilation_width, - y_stride, - y_padding, - x_stride, - x_padding, - out_height, - out_width, - out_channels, - inp_precision, - kernel_precision, - out_data_format); - - scratch_size = scratch_size < 0 ? 0 : scratch_size; - - ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); - - p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); - - for (int _n = 0; _n < batches; _n++) { - UWORD8* in_batch = p_inp + _n * input_channels * input_height * input_width; - UWORD8* out_batch = p_out + _n * out_channels * out_height * out_width; - - xa_nn_conv2d_per_chan_sym8sxasym8s( - (WORD8*)out_batch, - (WORD8*)in_batch, - (WORD8*)p_kernel, - p_bias, - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - kernel_channels, - dilation_height, - dilation_width, - out_channels, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - input_zero_bias, - out_multiplier32, - out_shift32, - out_zero_bias, - out_data_format, - p_scratch); - } -} - -void quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - xa_opt_quantized_conv_nhwc_asym8uxsym8u_asym8u( - ctx, - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -} // namespace native -} // namespace HiFi -} // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out.cpp deleted file mode 100644 index 6512622f221..00000000000 --- 
a/backends/cadence/hifi/operators/op_quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out.cpp +++ /dev/null @@ -1,173 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include - -#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) - -using Tensor = executorch::aten::Tensor; -using KernelRuntimeContext = torch::executor::KernelRuntimeContext; -using ScalarType = executorch::aten::ScalarType; -using ::executorch::aten::IntArrayRef; - -namespace cadence { -namespace impl { -namespace HiFi { -namespace native { - -// Specialized depthwise NHWC convolution for int8 x int8 -> int8 -void xa_opt_quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s( - KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int16_t groups, - int32_t in_zero_point, - int32_t weight_zero_point, - float bias_scale, - float output_scale, - int32_t output_zero_point, - Tensor& out) { - bool conv1d = input.dim() == 3; - - WORD8* __restrict__ p_out = - (WORD8* __restrict__)out.mutable_data_ptr(); - WORD8* __restrict__ p_inp = - (WORD8* __restrict__)input.const_data_ptr(); - WORD8* __restrict__ p_kernel = - (WORD8* __restrict__)weight.const_data_ptr(); - WORD32* __restrict__ p_bias = - (WORD32* __restrict__)bias.const_data_ptr(); - - WORD32 input_height = conv1d ? 1 : input.size(2); - WORD32 input_width = conv1d ? input.size(2) : input.size(3); - WORD32 input_channels = input.size(1); - WORD32 kernel_height = conv1d ? 1 : weight.size(2); - WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3); - WORD32 out_channels = weight.size(0); - WORD32 out_height = conv1d ? 1 : out.size(2); - WORD32 out_width = conv1d ? out.size(2) : out.size(3); - WORD32 batches = input.size(0); - - WORD32 x_stride = stride[1]; - WORD32 y_stride = stride[0]; - WORD32 x_padding = padding[1]; - WORD32 y_padding = padding[0]; - - WORD32 input_zero_bias = -in_zero_point; - WORD32 out_zero_bias = output_zero_point; - WORD32 inp_precision = 8; - - WORD32 channels_multiplier = out_channels / input_channels; - - WORD32 out_multiplier32[out_channels]; - WORD32 out_shift32[out_channels]; - - float out_scale = 1. / output_scale; - - for (int i = 0; i < out_channels; i++) { - out_multiplier32[i] = bias_scale * out_scale * 2147483648; - out_shift32[i] = 0; - } - - WORD32 scratch_size = xa_nn_conv2d_depthwise_getsize( - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - channels_multiplier, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - inp_precision, - 0); // NHWC - - scratch_size = scratch_size < 0 ? 
0 : scratch_size; - - WORD32* ptr_scratch = - (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); - pVOID p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); - - for (int _n = 0; _n < batches; _n++) { - WORD8* in_batch = p_inp + _n * input_channels * input_height * input_width; - WORD8* out_batch = p_out + _n * out_channels * out_height * out_width; - - xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s( - out_batch, - p_kernel, - in_batch, - p_bias, - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - channels_multiplier, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - input_zero_bias, - out_multiplier32, - out_shift32, - out_zero_bias, - 0, // NHWC - 0, // NHWC - p_scratch); - } -} - -void quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - xa_opt_quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s( - ctx, - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -} // namespace native -} // namespace HiFi -} // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out.cpp deleted file mode 100644 index d41a9c8d4b7..00000000000 --- a/backends/cadence/hifi/operators/op_quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out.cpp +++ /dev/null @@ -1,173 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include - -#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) - -using Tensor = executorch::aten::Tensor; -using KernelRuntimeContext = torch::executor::KernelRuntimeContext; -using ScalarType = executorch::aten::ScalarType; -using ::executorch::aten::IntArrayRef; - -namespace cadence { -namespace impl { -namespace HiFi { -namespace native { - -// Specialized depthwise NHWC convolution for uint8 x uint8 -> uint8 -void xa_opt_quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u( - KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int16_t groups, - int32_t in_zero_point, - int32_t weight_zero_point, - float bias_scale, - float output_scale, - int32_t output_zero_point, - Tensor& out) { - bool conv1d = input.dim() == 3; - - UWORD8* __restrict__ p_out = - (UWORD8* __restrict__)out.mutable_data_ptr(); - UWORD8* __restrict__ p_inp = - (UWORD8* __restrict__)input.const_data_ptr(); - UWORD8* __restrict__ p_kernel = - (UWORD8* __restrict__)weight.const_data_ptr(); - WORD32* __restrict__ p_bias = - (WORD32* __restrict__)bias.const_data_ptr(); - - WORD32 input_height = conv1d ? 1 : input.size(2); - WORD32 input_width = conv1d ? 
input.size(2) : input.size(3); - WORD32 input_channels = input.size(1); - WORD32 kernel_height = conv1d ? 1 : weight.size(2); - WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3); - WORD32 out_channels = weight.size(0); - WORD32 out_height = conv1d ? 1 : out.size(2); - WORD32 out_width = conv1d ? out.size(2) : out.size(3); - WORD32 batches = input.size(0); - - WORD32 x_stride = stride[1]; - WORD32 y_stride = stride[0]; - WORD32 x_padding = padding[1]; - WORD32 y_padding = padding[0]; - - WORD32 input_zero_bias = -in_zero_point; - WORD32 out_zero_bias = output_zero_point; - WORD32 inp_precision = 8; - - WORD32 channels_multiplier = out_channels / input_channels; - - WORD32 out_multiplier32[out_channels]; - WORD32 out_shift32[out_channels]; - - float out_scale = 1. / output_scale; - - for (int i = 0; i < out_channels; i++) { - out_multiplier32[i] = bias_scale * out_scale * 2147483648; - out_shift32[i] = 0; - } - - WORD32 scratch_size = xa_nn_conv2d_depthwise_getsize( - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - channels_multiplier, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - inp_precision, - 0); // NHWC - - scratch_size = scratch_size < 0 ? 0 : scratch_size; - - WORD32* ptr_scratch = - (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); - pVOID p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); - - for (int _n = 0; _n < batches; _n++) { - UWORD8* in_batch = p_inp + _n * input_channels * input_height * input_width; - UWORD8* out_batch = p_out + _n * out_channels * out_height * out_width; - - xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s( - (WORD8*)out_batch, - (WORD8*)p_kernel, - (WORD8*)in_batch, - p_bias, - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - channels_multiplier, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - input_zero_bias, - out_multiplier32, - out_shift32, - out_zero_bias, - 0, // NHWC - 0, // NHWC - p_scratch); - } -} - -void quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - xa_opt_quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u( - ctx, - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -} // namespace native -} // namespace HiFi -} // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_out.cpp deleted file mode 100644 index be661334acf..00000000000 --- a/backends/cadence/hifi/operators/op_quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_out.cpp +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
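(Reviewer note on the requantization parameters in the deleted NNLib paths above: they fold the combined rescale factor, bias_scale / output_scale, into a single Q31 value via out_multiplier32[i] = bias_scale * out_scale * 2147483648 with out_shift32[i] = 0, which can only represent combined scales below 1.0 in a signed 32-bit value. A more conventional decomposition splits the factor into a Q31 mantissa plus an exponent. The helper below is a minimal sketch of that idea, not code from this diff; its name and rounding policy are assumptions.)

#include <cmath>
#include <cstdint>

// Sketch: decompose a positive real rescale factor into a Q31 fixed-point
// multiplier and a power-of-two shift, the form per-channel requantization
// kernels commonly consume. Illustrative only.
static void quantize_multiplier(
    double real_multiplier, int32_t* out_multiplier, int32_t* out_shift) {
  int exponent = 0;
  // mantissa in [0.5, 1); real_multiplier == mantissa * 2^exponent
  const double mantissa = std::frexp(real_multiplier, &exponent);
  int64_t q = static_cast<int64_t>(std::llround(mantissa * 2147483648.0));
  if (q == (int64_t{1} << 31)) {  // rounding pushed the mantissa up to 1.0
    q >>= 1;
    ++exponent;
  }
  *out_multiplier = static_cast<int32_t>(q);
  *out_shift = exponent;  // > 0: shift left, < 0: shift right
}

With m and e obtained from quantize_multiplier(bias_scale / output_scale, &m, &e), (m / 2^31) * 2^e approximates the same factor while also covering combined scales at or above 1.0, which the single-multiplier form in the removed code cannot.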
- */ - -#include -#include -#include - -using Tensor = executorch::aten::Tensor; -using KernelRuntimeContext = torch::executor::KernelRuntimeContext; -using ScalarType = executorch::aten::ScalarType; -using ::executorch::aten::IntArrayRef; - -namespace cadence { -namespace impl { -namespace HiFi { -namespace native { - -// Dilated fallback implementation for int8 x int8 -> int8 quantized 2d conv -// kernel for NHWC layout. This variant is optimized for asymmetric int8 inputs, -// weights, and outputs. The input is of shape [n x h x w x c] The weight is of -// shape [oc x wh x ww x wc] The output is of shape [n x oh x ow x oc] The bias -// is of shape [oc] -template -__attribute__((noinline)) void conv2d_nhwc_dilated_asym8sxsym8s_asym8s_core( - // All the arrays - const int8_t* __restrict__ p_in, - const int8_t* __restrict__ p_weight, - const int32_t* __restrict__ p_bias, - int8_t* __restrict__ p_out, - // The array sizes - int32_t n, - int32_t h, - int32_t w, - int32_t c, - int32_t oc, - int32_t wh, - int32_t ww, - int32_t wc, - int32_t oh, - int32_t ow, - // Stride - int16_t s0, - int16_t s1, - // Padding - int16_t p0, - int16_t p1, - // Dilation - int16_t d0, - int16_t d1, - // Group for depthwise conv - int16_t groups, - // Quantization parameters - int8_t in_zero_point = 0, - int32_t weight_zero_point = 0, - float bias_scale = 1, - float out_scale = 1, - int8_t out_zero_point = 0) { - float inv_out_scale = 1. / out_scale; - - // Compute the number of in and out channels per group - const int ocpg = oc / groups; - const int icpg = c / groups; - - // Iterate over all the output batches (i.e., n) - for (int _n = 0; _n < n; ++_n) { - const int8_t* in_batch = p_in + _n * h * w * c; - int8_t* out_batch = p_out + _n * oh * ow * oc; - for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { - for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { - int8_t* out_line = out_batch + (_oh * ow + _ow) * oc; - // Compute separable convolution for each group - for (int _g = 0; _g < groups; ++_g) { - // Identify the input and output channels involved in the computation - // of this group - int sic = _g * icpg; - int soc = _g * ocpg; - // Populate all the output channels in the group - for (int _oc = soc; _oc < soc + ocpg; ++_oc) { - const int8_t* weight_batch = p_weight + _oc * wh * ww * wc; - // We compute one output channel at a time. The computation can be - // thought of as a stencil computation: we iterate over an input of - // size h x w x icpg, with a stencil of size wh x ww x icpg, to - // compute an output channel of size oh x ow x 1. - float acc = p_bias[_oc]; - // Below is the stencil computation that performs the hadamard - // product+accumulation of each input channel (contributing to - // the output channel being computed) with the corresponding - // weight channel. 
- // General path for dilated convolutions with padding support - for (int _wh = 0; _wh < wh; ++_wh) { - for (int _ww = 0; _ww < ww; ++_ww) { - int input_h = _h + d0 * _wh - p0; - int input_w = _w + d1 * _ww - p1; - if ((input_h >= 0) && (input_h < h) && (input_w >= 0) && - (input_w < w)) { - const int8_t* in_line = - in_batch + input_h * w * c + input_w * c; - const int8_t* weight_line = - weight_batch + _wh * ww * wc + _ww * wc; - for (int _ic = sic; _ic < sic + icpg; ++_ic) { - float lhs = static_cast(in_line[_ic]) - - static_cast(in_zero_point); - float rhs = static_cast(weight_line[_ic - sic]) - - static_cast(weight_zero_point); - acc += lhs * rhs; - } - } - } - } - // Quantize the accumulated result - float val = bias_scale * acc; - out_line[_oc] = - kernels::quantize(val, inv_out_scale, out_zero_point); - } - } - } - } - } -} - -void quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - bool conv1d = input.dim() == 3; - // input = [n, h, w, c] - const int n = input.size(0); - const int h = conv1d ? 1 : input.size(1); - const int w = conv1d ? input.size(1) : input.size(2); - const int c = conv1d ? input.size(2) : input.size(3); - // weight = [oc, wh, ww, wc] - const int oc = weight.size(0); - const int wh = conv1d ? 1 : weight.size(1); - const int ww = conv1d ? weight.size(1) : weight.size(2); - const int wc = conv1d ? weight.size(2) : weight.size(3); - // output = [n, oh, ow, oc] - const int oh = conv1d ? 1 : out.size(1); - const int ow = conv1d ? out.size(1) : out.size(2); - - conv2d_nhwc_dilated_asym8sxsym8s_asym8s_core( - input.const_data_ptr(), - weight.const_data_ptr(), - bias.const_data_ptr(), - out.mutable_data_ptr(), - n, - h, - w, - c, - oc, - wh, - ww, - wc, - oh, - ow, - stride[0], - stride[1], - padding[0], - padding[1], - dilation[0], - dilation[1], - groups, - static_cast(in_zero_point), - weight_zero_point, - bias_scale, - output_scale, - static_cast(output_zero_point)); -} - -} // namespace native -} // namespace HiFi -} // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out.cpp deleted file mode 100644 index cab4897f5f0..00000000000 --- a/backends/cadence/hifi/operators/op_quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out.cpp +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include - -using Tensor = executorch::aten::Tensor; -using KernelRuntimeContext = torch::executor::KernelRuntimeContext; -using ScalarType = executorch::aten::ScalarType; -using ::executorch::aten::IntArrayRef; - -namespace cadence { -namespace impl { -namespace HiFi { -namespace native { - -// Dilated fallback implementation for uint8 x uint8 -> uint8 quantized 2d conv -// kernel for NHWC layout. 
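(The reference fallbacks above and below finish each output element with kernels::quantize(val, inv_out_scale, out_zero_point); the template argument and header names do not survive in this rendering of the diff. The following is a generic stand-in written under the assumption that the helper rescales, rounds to nearest, adds the zero point, and clamps to the output type; the real kernels::quantize helper may differ, e.g. in rounding mode.)

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

// Generic stand-in for a quantize<T>(value, inv_scale, zero_point) helper:
// rescale, round to nearest, shift by the zero point, clamp to T's range.
template <typename T>
T quantize_ref(float value, float inv_scale, int32_t zero_point) {
  float q = std::nearbyint(value * inv_scale) + static_cast<float>(zero_point);
  q = std::min(q, static_cast<float>(std::numeric_limits<T>::max()));
  q = std::max(q, static_cast<float>(std::numeric_limits<T>::min()));
  return static_cast<T>(q);
}

For the int8 paths this means clamping to [-128, 127]; the uint8 variants clamp to [0, 255], matching the activation bounds passed to the NNLib relu kernels later in this diff.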
This variant is optimized for asymmetric uint8 -// inputs, weights, and outputs. The input is of shape [n x h x w x c] The -// weight is of shape [oc x wh x ww x wc] The output is of shape [n x oh x ow x -// oc] The bias is of shape [oc] -template -__attribute__((noinline)) void conv2d_nhwc_dilated_asym8uxsym8u_asym8u_core( - // All the arrays - const uint8_t* __restrict__ p_in, - const uint8_t* __restrict__ p_weight, - const int32_t* __restrict__ p_bias, - uint8_t* __restrict__ p_out, - // The array sizes - int32_t n, - int32_t h, - int32_t w, - int32_t c, - int32_t oc, - int32_t wh, - int32_t ww, - int32_t wc, - int32_t oh, - int32_t ow, - // Stride - int16_t s0, - int16_t s1, - // Padding - int16_t p0, - int16_t p1, - // Dilation - int16_t d0, - int16_t d1, - // Group for depthwise conv - int16_t groups, - // Quantization parameters - uint8_t in_zero_point = 0, - int32_t weight_zero_point = 0, - float bias_scale = 1, - float out_scale = 1, - uint8_t out_zero_point = 0) { - float inv_out_scale = 1. / out_scale; - - // Compute the number of in and out channels per group - const int ocpg = oc / groups; - const int icpg = c / groups; - - // Iterate over all the output batches (i.e., n) - for (int _n = 0; _n < n; ++_n) { - const uint8_t* in_batch = p_in + _n * h * w * c; - uint8_t* out_batch = p_out + _n * oh * ow * oc; - for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { - for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { - uint8_t* out_line = out_batch + (_oh * ow + _ow) * oc; - // Compute separable convolution for each group - for (int _g = 0; _g < groups; ++_g) { - // Identify the input and output channels involved in the computation - // of this group - int sic = _g * icpg; - int soc = _g * ocpg; - // Populate all the output channels in the group - for (int _oc = soc; _oc < soc + ocpg; ++_oc) { - const uint8_t* weight_batch = p_weight + _oc * wh * ww * wc; - // We compute one output channel at a time. The computation can be - // thought of as a stencil computation: we iterate over an input of - // size h x w x icpg, with a stencil of size wh x ww x icpg, to - // compute an output channel of size oh x ow x 1. - float acc = p_bias[_oc]; - // Below is the stencil computation that performs the hadamard - // product+accumulation of each input channel (contributing to - // the output channel being computed) with the corresponding - // weight channel. 
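(A note on the loop bounds in these dilated fallbacks: _h advances by the stride while _oh walks the output height taken from the out tensor, so correctness relies on the standard convolution output-size relation. The helper below only spells that relation out; it is not part of this diff.)

// Expected spatial output size of a padded, dilated convolution:
// oh == conv_out_dim(h, wh, s0, p0, d0) and ow == conv_out_dim(w, ww, s1, p1, d1).
inline int conv_out_dim(int in, int kernel, int stride, int pad, int dilation) {
  const int effective_kernel = dilation * (kernel - 1) + 1;
  return (in + 2 * pad - effective_kernel) / stride + 1;
}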
- // General path for dilated convolutions with padding support - for (int _wh = 0; _wh < wh; ++_wh) { - for (int _ww = 0; _ww < ww; ++_ww) { - int input_h = _h + d0 * _wh - p0; - int input_w = _w + d1 * _ww - p1; - if ((input_h >= 0) && (input_h < h) && (input_w >= 0) && - (input_w < w)) { - const uint8_t* in_line = - in_batch + input_h * w * c + input_w * c; - const uint8_t* weight_line = - weight_batch + _wh * ww * wc + _ww * wc; - for (int _ic = sic; _ic < sic + icpg; ++_ic) { - float lhs = static_cast(in_line[_ic]) - - static_cast(in_zero_point); - float rhs = static_cast(weight_line[_ic - sic]) - - static_cast(weight_zero_point); - acc += lhs * rhs; - } - } - } - } - // Quantize the accumulated result - float val = bias_scale * acc; - out_line[_oc] = - kernels::quantize(val, inv_out_scale, out_zero_point); - } - } - } - } - } -} - -void quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - bool conv1d = input.dim() == 3; - // input = [n, h, w, c] - const int n = input.size(0); - const int h = conv1d ? 1 : input.size(1); - const int w = conv1d ? input.size(1) : input.size(2); - const int c = conv1d ? input.size(2) : input.size(3); - // weight = [oc, wh, ww, wc] - const int oc = weight.size(0); - const int wh = conv1d ? 1 : weight.size(1); - const int ww = conv1d ? weight.size(1) : weight.size(2); - const int wc = conv1d ? weight.size(2) : weight.size(3); - // output = [n, oh, ow, oc] - const int oh = conv1d ? 1 : out.size(1); - const int ow = conv1d ? out.size(1) : out.size(2); - - conv2d_nhwc_dilated_asym8uxsym8u_asym8u_core( - input.const_data_ptr(), - weight.const_data_ptr(), - bias.const_data_ptr(), - out.mutable_data_ptr(), - n, - h, - w, - c, - oc, - wh, - ww, - wc, - oh, - ow, - stride[0], - stride[1], - padding[0], - padding[1], - dilation[0], - dilation[1], - groups, - static_cast(in_zero_point), - weight_zero_point, - bias_scale, - output_scale, - static_cast(output_zero_point)); -} - -} // namespace native -} // namespace HiFi -} // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_conv_nhwc_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv_nhwc_out.cpp deleted file mode 100644 index 8af7c0da3ef..00000000000 --- a/backends/cadence/hifi/operators/op_quantized_conv_nhwc_out.cpp +++ /dev/null @@ -1,552 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
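(The generic NHWC core deleted below indexes tensors with flattened expressions such as in_batch + input_h * w * c + input_w * c. For readers less used to flattened NHWC addressing, the equivalent offset computation is shown here as a sketch, not as code from the tree.)

#include <cstdint>

// Flat offset of element (n, y, x, ch) in a contiguous [N, H, W, C] (NHWC) tensor.
inline int64_t nhwc_offset(
    int64_t n, int64_t y, int64_t x, int64_t ch, int64_t H, int64_t W, int64_t C) {
  return ((n * H + y) * W + x) * C + ch;
}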
- */ - -#include -#include -#include - -#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) - -using Tensor = executorch::aten::Tensor; -using KernelRuntimeContext = torch::executor::KernelRuntimeContext; -using ScalarType = executorch::aten::ScalarType; -using ::executorch::aten::IntArrayRef; - -namespace cadence { -namespace impl { -namespace HiFi { -namespace native { - -template < - typename IT = float, - typename WT = IT, - typename BT = IT, - typename OT = IT, - bool quantized = false> -__attribute__((noinline)) void conv2d_nhwc_core_generic( - // All the arrays - const IT* __restrict__ p_in, - const WT* __restrict__ p_weight, - const BT* __restrict__ p_bias, - OT* __restrict__ p_out, - // The array sizes - int32_t n, - int32_t h, - int32_t w, - int32_t c, - int32_t oc, - int32_t wh, - int32_t ww, - int32_t wc, - int32_t oh, - int32_t ow, - // Stride - int16_t s0, - int16_t s1, - // Padding - int16_t p0, - int16_t p1, - // Dilation - int16_t d0, - int16_t d1, - // Group for depthwise conv - int16_t groups, - // Optional args that are only relevant for quantized convolution - // input zero point - IT in_zero_point = 0, - // weight zero point - int32_t weight_zero_point = 0, - float bias_scale = 1, - float out_scale = 1, - OT out_zero_point = 0) { - float inv_out_scale = 1. / out_scale; - bool zero_pad_unit_dilation = d0 == 1 && d1 == 1 && p0 == 0 && p1 == 0; - - // Compute the number of in and out channels per group - const int ocpg = oc / groups; - const int icpg = c / groups; - - // Iterate over all the output batches (i.e., n) - for (int _n = 0; _n < n; ++_n) { - const IT* in_batch = p_in + _n * h * w * c; - OT* out_batch = p_out + _n * oh * ow * oc; - for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { - for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { - OT* out_line = out_batch + (_oh * ow + _ow) * oc; - // Compute separable convolution for each group - for (int _g = 0; _g < groups; ++_g) { - // Identify the input and output channels involved in the computation - // of this group - int sic = _g * icpg; - int soc = _g * ocpg; - // Populate all the output channels in the group - for (int _oc = soc; _oc < soc + ocpg; ++_oc) { - const WT* weight_batch = p_weight + _oc * wh * ww * wc; - // We compute one output channel at a time. The computation can be - // thought of as a stencil computation: we iterate over an input of - // size h x w x icpg, with a stencil of size wh x ww x icpg, to - // compute an output channel of size oh x ow x 1. - float acc = p_bias[_oc]; - // Below is the stencil computation that performs the hadamard - // product+accumulation of each input channel (contributing to - // the output channel being computed) with the corresponding - // weight channel. If the padding is 0, and dilation is 1, then - // we can remove the unnecessary checks, and simplify the code - // so that it can be vectorized by Tensilica compiler.x`` - if (zero_pad_unit_dilation) { - for (int _wh = 0; _wh < wh; ++_wh) { - for (int _ww = 0; _ww < ww; ++_ww) { - const IT* in_line = - in_batch + (_h + _wh) * w * c + (_w + _ww) * c; - const WT* weight_line = - weight_batch + _wh * ww * wc + _ww * wc; - for (int _ic = sic; _ic < sic + icpg; ++_ic) { - float lhs = in_line[_ic] - in_zero_point; - float rhs = weight_line[_ic - sic] - - (quantized ? 
weight_zero_point : 0); - acc += lhs * rhs; - } - } - } - } else { - for (int _wh = 0; _wh < wh; ++_wh) { - for (int _ww = 0; _ww < ww; ++_ww) { - if (((_h + d0 * _wh - p0) >= 0) && - ((_h + d0 * _wh - p0) < h) && - ((_w + d1 * _ww - p1) >= 0) && - ((_w + d1 * _ww - p1 < w))) { - const IT* in_line = in_batch + - (_h + d0 * _wh - p0) * w * c + (_w + d1 * _ww - p1) * c; - const WT* weight_line = - weight_batch + _wh * ww * wc + _ww * wc; - for (int _ic = sic; _ic < sic + icpg; ++_ic) { - float lhs = in_line[_ic] - in_zero_point; - float rhs = weight_line[_ic - sic] - - (quantized ? weight_zero_point : 0); - acc += lhs * rhs; - } - } - } - } - } - if (quantized) { - float val = bias_scale * acc; - out_line[_oc] = - kernels::quantize(val, inv_out_scale, out_zero_point); - } else { - out_line[_oc] = acc; - } - } - } - } - } - } -} - -void xa_opt_quantized_conv_nhwc( - KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int16_t groups, - int32_t in_zero_point, - int32_t weight_zero_point, - float bias_scale, - float output_scale, - int32_t output_zero_point, - Tensor& out) { - bool conv1d = input.dim() == 3; - constexpr int kNnlibMaxDim = 4; - - if (input.scalar_type() == ScalarType::Char) { - WORD8* __restrict__ p_out = - (WORD8* __restrict__)out.mutable_data_ptr(); - WORD8* __restrict__ p_inp = - (WORD8* __restrict__)input.const_data_ptr(); - WORD8* __restrict__ p_kernel = - (WORD8* __restrict__)weight.const_data_ptr(); - WORD32* __restrict__ p_bias = - (WORD32* __restrict__)bias.const_data_ptr(); - - WORD32 input_height = conv1d ? 1 : input.size(2); - WORD32 input_width = conv1d ? input.size(2) : input.size(3); - WORD32 input_channels = input.size(1); - WORD32 kernel_height = conv1d ? 1 : weight.size(2); - WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3); - WORD32 kernel_channels = weight.size(1); - WORD32 out_channels = weight.size(0); - WORD32 out_height = conv1d ? 1 : out.size(2); - WORD32 out_width = conv1d ? out.size(2) : out.size(3); - WORD32 batches = input.size(0); - - WORD32 x_stride = stride[1]; - WORD32 y_stride = stride[0]; - WORD32 x_padding = padding[1]; - WORD32 y_padding = padding[0]; - WORD32 dilation_width = dilation[1]; - WORD32 dilation_height = dilation[0]; - - // WORD32* kernel_bias_ptr = - // (WORD32*)weight_zero_point.const_data_ptr(); - - WORD32 input_zero_bias = -in_zero_point; - WORD32 kernel_zero_bias = -weight_zero_point; - - WORD32 out_multiplier32[out_channels]; - WORD32 out_shift32[out_channels]; - - float out_scale = 1. / output_scale; - - for (int i = 0; i < out_channels; i++) { - out_multiplier32[i] = bias_scale * out_scale * 2147483648; - out_shift32[i] = 0; - } - - WORD32 out_zero_bias = output_zero_point; - WORD32 inp_precision = 8; - WORD32 kernel_precision = 8; - pVOID p_scratch = nullptr; - WORD32* ptr_scratch; - - WORD32 scratch_size = 0; - - if (groups == 1) { - WORD32 out_data_format = 1; - - scratch_size = xa_nn_conv2d_getsize( - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - kernel_channels, - dilation_height, - dilation_width, - y_stride, - y_padding, - x_stride, - x_padding, - out_height, - out_width, - out_channels, - inp_precision, - kernel_precision, - out_data_format); - - scratch_size = scratch_size < 0 ? 
0 : scratch_size; - - ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); - - p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); - - for (int _n = 0; _n < batches; _n++) { - WORD8* in_batch = - p_inp + _n * input_channels * input_height * input_width; - WORD8* out_batch = p_out + _n * out_channels * out_height * out_width; - - xa_nn_conv2d_per_chan_sym8sxasym8s( - out_batch, - in_batch, - p_kernel, - p_bias, - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - kernel_channels, - dilation_height, - dilation_width, - out_channels, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - input_zero_bias, - out_multiplier32, - out_shift32, - out_zero_bias, - out_data_format, - p_scratch); - } - return; - } - - if (groups == input_channels) { - WORD32 channels_multiplier = out_channels / input_channels; - - scratch_size = xa_nn_conv2d_depthwise_getsize( - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - channels_multiplier, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - inp_precision, - 0); // NHWC - - scratch_size = scratch_size < 0 ? 0 : scratch_size; - - ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); - - p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); - - WORD8* ptr1 = (WORD8*)kernels::allocate_temp_memory( - ctx, - ((batches * out_channels * out_height * out_width) + 8) * - sizeof(WORD8)); - - WORD8* p_out_temp = (WORD8*)ALIGN_PTR(ptr1, 8); - - for (int _n = 0; _n < batches; _n++) { - WORD8* in_batch = - p_inp + _n * input_channels * input_height * input_width; - WORD8* out_batch = - p_out_temp + _n * out_channels * out_height * out_width; - - xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s( - out_batch, - p_kernel, - in_batch, - p_bias, - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - channels_multiplier, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - input_zero_bias, - out_multiplier32, - out_shift32, - out_zero_bias, - 0, // NHWC - 0, // NHWC - p_scratch); - } - - return; - } - } -} - -void quantized_conv_nhwc( - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int16_t groups, - int32_t in_zero_point, - int32_t weight_zero_point, - float bias_scale, - float output_scale, - int32_t output_zero_point, - Tensor& out) { - bool conv1d = input.dim() == 3; - // input = [n, h, w, c] - const int n = input.size(0); - const int h = conv1d ? 1 : input.size(1); - const int w = conv1d ? input.size(1) : input.size(2); - const int c = conv1d ? input.size(2) : input.size(3); - // weight = [oc, wh, ww, wc] - const int oc = weight.size(0); - const int wh = conv1d ? 1 : weight.size(1); - const int ww = conv1d ? weight.size(1) : weight.size(2); - const int wc = conv1d ? weight.size(2) : weight.size(3); - // output = [n, oh, ow, oc] - const int oh = conv1d ? 1 : out.size(1); - const int ow = conv1d ? 
out.size(1) : out.size(2); - -#define typed_quantized_conv2d_nhwc(ctype, dtype) \ - case ScalarType::dtype: { \ - conv2d_nhwc_core_generic( \ - input.const_data_ptr(), \ - weight.const_data_ptr(), \ - bias.const_data_ptr(), \ - out.mutable_data_ptr(), \ - n, \ - h, \ - w, \ - c, \ - oc, \ - wh, \ - ww, \ - wc, \ - oh, \ - ow, \ - stride[0], \ - stride[1], \ - padding[0], \ - padding[1], \ - dilation[0], \ - dilation[1], \ - groups, \ - in_zero_point, \ - weight_zero_point, \ - bias_scale, \ - output_scale, \ - (ctype)output_zero_point); \ - break; \ - } - ScalarType dtype = out.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_conv2d_nhwc); - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", torch::executor::toString(dtype)); - } - -#undef typed_quantized_conv2d_nhwc -} - -void quantized_conv_nhwc_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - const Tensor& weight_zero_point, - const Tensor& bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED const Tensor& out_multiplier, - __ET_UNUSED const Tensor& out_shift, - Tensor& out) { - const float bias_scale_float = bias_scale.const_data_ptr()[0]; - const int32_t weight_zero_point_int = - weight_zero_point.const_data_ptr()[0]; - - bool optimized = 0; - - if ((input.scalar_type() == ScalarType::Char) || - (input.scalar_type() == ScalarType::Byte)) - optimized = 1; - - if ((dilation[0] != 1) || (dilation[1] != 1)) - optimized = 0; - - if (optimized) { - xa_opt_quantized_conv_nhwc( - ctx, - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point_int, - bias_scale_float, - output_scale, - output_zero_point, - out); - } else { - quantized_conv_nhwc( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point_int, - bias_scale_float, - output_scale, - output_zero_point, - out); - } -} - -void quantized_conv_nhwc_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - bool optimized = 0; - - if ((input.scalar_type() == ScalarType::Char) || - (input.scalar_type() == ScalarType::Byte)) - optimized = 1; - - if ((dilation[0] != 1) || (dilation[1] != 1)) - optimized = 0; - - if (optimized) { - xa_opt_quantized_conv_nhwc( - ctx, - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); - } else { - quantized_conv_nhwc( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); - } -} - -} // namespace native -} // namespace HiFi -} // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out.cpp index 5e3a5173f32..1c5e725a023 100644 --- 
a/backends/cadence/hifi/operators/op_quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out.cpp @@ -9,7 +9,6 @@ #include #include -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -61,4 +60,3 @@ void quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out.cpp index 80509fdd5db..fe94700ebef 100644 --- a/backends/cadence/hifi/operators/op_quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out.cpp @@ -9,7 +9,6 @@ #include #include -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -61,4 +60,3 @@ void quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_fully_connected_out.cpp b/backends/cadence/hifi/operators/op_quantized_fully_connected_out.cpp index 66c2e997142..64df7981ad1 100644 --- a/backends/cadence/hifi/operators/op_quantized_fully_connected_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_fully_connected_out.cpp @@ -12,7 +12,6 @@ #include #include -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -271,4 +270,3 @@ void quantized_fully_connected_per_tensor_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_layer_norm.cpp b/backends/cadence/hifi/operators/op_quantized_layer_norm.cpp index 7906f245b03..5510385b14f 100644 --- a/backends/cadence/hifi/operators/op_quantized_layer_norm.cpp +++ b/backends/cadence/hifi/operators/op_quantized_layer_norm.cpp @@ -13,15 +13,14 @@ #include #include -using ::cadence::impl::HiFi::kernels::dequantize; -using ::cadence::impl::HiFi::kernels::quantize; using ::executorch::aten::IntArrayRef; using ::executorch::aten::ScalarType; using ::executorch::aten::Tensor; using ::executorch::runtime::getLeadingDims; using ::executorch::runtime::KernelRuntimeContext; +using ::impl::HiFi::kernels::dequantize; +using ::impl::HiFi::kernels::quantize; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -195,7 +194,6 @@ void quantized_layer_norm_per_tensor_out( #undef typed_quantized_layer_norm } -}; // namespace native -}; // namespace HiFi -}; // namespace impl -}; // namespace cadence +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/op_quantized_linear_asym8sxasym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_linear_asym8sxasym8s_asym8s_per_tensor_out.cpp index 7b8ab8e91b9..76ab229d606 100644 --- a/backends/cadence/hifi/operators/op_quantized_linear_asym8sxasym8s_asym8s_per_tensor_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_linear_asym8sxasym8s_asym8s_per_tensor_out.cpp @@ -10,7 +10,6 @@ #include #include -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -72,4 +71,3 @@ void quantized_linear_asym8sxasym8s_asym8s_per_tensor_out( } // namespace native } // 
namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_linear_asym8uxasym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_linear_asym8uxasym8u_asym8u_per_tensor_out.cpp index e9632e77eeb..be5246b8cc8 100644 --- a/backends/cadence/hifi/operators/op_quantized_linear_asym8uxasym8u_asym8u_per_tensor_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_linear_asym8uxasym8u_asym8u_per_tensor_out.cpp @@ -10,7 +10,6 @@ #include #include -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -72,4 +71,3 @@ void quantized_linear_asym8uxasym8u_asym8u_per_tensor_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_linear_out.cpp b/backends/cadence/hifi/operators/op_quantized_linear_out.cpp index 4bf71cd8838..e5d63b87a1b 100644 --- a/backends/cadence/hifi/operators/op_quantized_linear_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_linear_out.cpp @@ -6,19 +6,20 @@ * LICENSE file in the root directory of this source tree. */ -#include #include -#include -#include -#include + #include #include #include -namespace cadence { -namespace impl { -namespace HiFi { -namespace native { +#include +#include + +#include +#include +#include + +namespace impl::HiFi::native { using ::executorch::aten::ScalarType; using ::executorch::aten::Tensor; @@ -219,7 +220,22 @@ void quantized_linear_out( int64_t out_zero_point, __ET_UNUSED const optional& offset, Tensor& out) { - if (out.scalar_type() == executorch::aten::ScalarType::Byte) { + if (out.scalar_type() == ::executorch::aten::ScalarType::Short && + in.scalar_type() == ::executorch::aten::ScalarType::Short && + weight.scalar_type() == ::executorch::aten::ScalarType::Char) { + ::impl::generic::native::quantized_linear_out( + ctx, + in, + weight, + bias, + in_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + offset, + out); + } else if (out.scalar_type() == executorch::aten::ScalarType::Byte) { _quantized_linear_asym8u( in, weight, @@ -250,18 +266,33 @@ void quantized_linear_out( } void quantized_linear_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, + KernelRuntimeContext& ctx, const Tensor& in, const Tensor& weight, const Tensor& bias, - int64_t in_zero_point, - int64_t weight_zero_point, - int64_t out_multiplier, - int64_t out_shift, - int64_t out_zero_point, - __ET_UNUSED const optional& offset, + const int64_t in_zero_point, + const int64_t weight_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + const int64_t out_zero_point, + const optional& offset, Tensor& out) { - if (out.scalar_type() == executorch::aten::ScalarType::Byte) { + if (out.scalar_type() == ::executorch::aten::ScalarType::Short && + in.scalar_type() == ::executorch::aten::ScalarType::Short && + weight.scalar_type() == ::executorch::aten::ScalarType::Char) { + ::impl::generic::native::quantized_linear_per_tensor_out( + ctx, + in, + weight, + bias, + in_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + offset, + out); + } else if (out.scalar_type() == executorch::aten::ScalarType::Byte) { _quantized_linear_per_tensor_asym8u( in, weight, @@ -291,7 +322,4 @@ void quantized_linear_per_tensor_out( } } -}; // namespace native -}; // namespace HiFi -}; // namespace impl -}; // namespace cadence +} // namespace impl::HiFi::native diff --git 
a/backends/cadence/hifi/operators/op_quantized_matmul_asym8sxasym8s_asym8s_out.cpp b/backends/cadence/hifi/operators/op_quantized_matmul_asym8sxasym8s_asym8s_out.cpp index 0e7b3f1a2aa..3fc9de4697d 100644 --- a/backends/cadence/hifi/operators/op_quantized_matmul_asym8sxasym8s_asym8s_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_matmul_asym8sxasym8s_asym8s_out.cpp @@ -15,7 +15,6 @@ using executorch::aten::Tensor; using executorch::runtime::getLeadingDims; using torch::executor::RuntimeContext; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -132,4 +131,3 @@ void quantized_matmul_asym8sxasym8s_asym8s_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_matmul_asym8uxasym8u_asym8u_out.cpp b/backends/cadence/hifi/operators/op_quantized_matmul_asym8uxasym8u_asym8u_out.cpp index 7016e6635dc..43567ff0d11 100644 --- a/backends/cadence/hifi/operators/op_quantized_matmul_asym8uxasym8u_asym8u_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_matmul_asym8uxasym8u_asym8u_out.cpp @@ -15,7 +15,6 @@ using executorch::aten::Tensor; using executorch::runtime::getLeadingDims; using torch::executor::RuntimeContext; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -132,4 +131,3 @@ void quantized_matmul_asym8uxasym8u_asym8u_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_matmul_out.cpp b/backends/cadence/hifi/operators/op_quantized_matmul_out.cpp index 024558b7c85..7a4c4229b35 100644 --- a/backends/cadence/hifi/operators/op_quantized_matmul_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_matmul_out.cpp @@ -6,19 +6,18 @@ * LICENSE file in the root directory of this source tree. */ +#include + +#include #include #include -#include -using executorch::aten::ScalarType; -using executorch::aten::Tensor; -using executorch::runtime::getLeadingDims; -using torch::executor::RuntimeContext; +namespace impl::HiFi::native { -namespace cadence { -namespace impl { -namespace HiFi { -namespace native { +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::getLeadingDims; +using ::torch::executor::RuntimeContext; // The quantized matmul. The quantized matmul accumulates in a wider register, // whose type is TA. @@ -193,8 +192,20 @@ void quantized_matmul_out( size_t leading_dim = X.size(X.dim() - 2); size_t out_dim = Y.size(Y.dim() - 1 - transposed); size_t in_dim = X.size(X.dim() - 1); - - if (out.scalar_type() == exec_aten::ScalarType::Byte) { + if (out.scalar_type() == exec_aten::ScalarType::Short) { + ::impl::generic::native::quantized_matmul_out( + ctx, + X, + X_zero_point, + Y, + Y_zero_point, + bias, + out_multiplier, + out_shift, + out_zero_point, + transposed, + out); + } else if (out.scalar_type() == exec_aten::ScalarType::Byte) { _typed_quantized_matmul( ctx, X, @@ -229,7 +240,4 @@ void quantized_matmul_out( } } -} // namespace native -} // namespace HiFi -} // namespace impl -} // namespace cadence \ No newline at end of file +} // namespace impl::HiFi::native diff --git a/backends/cadence/hifi/operators/op_quantized_matmul_out.h b/backends/cadence/hifi/operators/op_quantized_matmul_out.h new file mode 100644 index 00000000000..c53a07b58aa --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_matmul_out.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) Meta Platforms, Inc. 
and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include "executorch/runtime/core/exec_aten/exec_aten.h" +#include "executorch/runtime/kernel/kernel_runtime_context.h" + +namespace impl { +namespace HiFi { +namespace native { + +::executorch::aten::Tensor& quantized_matmul_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& X, + int64_t X_zero_point, + const ::executorch::aten::Tensor& Y, + int64_t Y_zero_point, + const ::executorch::aten::optional<::executorch::aten::Tensor>& bias, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + bool transposed, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/op_quantized_relu_asym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_relu_asym8s_asym8s_per_tensor_out.cpp index deae48d4411..7280a2038f0 100644 --- a/backends/cadence/hifi/operators/op_quantized_relu_asym8s_asym8s_per_tensor_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_relu_asym8s_asym8s_per_tensor_out.cpp @@ -10,7 +10,6 @@ #include #include -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -49,4 +48,3 @@ void quantized_relu_asym8s_asym8s_per_tensor_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_relu_asym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_relu_asym8u_asym8u_per_tensor_out.cpp index 8aaca463cf9..382f5e6b679 100644 --- a/backends/cadence/hifi/operators/op_quantized_relu_asym8u_asym8u_per_tensor_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_relu_asym8u_asym8u_per_tensor_out.cpp @@ -10,7 +10,6 @@ #include #include -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -49,4 +48,3 @@ void quantized_relu_asym8u_asym8u_per_tensor_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_relu_out.cpp b/backends/cadence/hifi/operators/op_quantized_relu_out.cpp index 9b65751da71..80c02a79e93 100644 --- a/backends/cadence/hifi/operators/op_quantized_relu_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_relu_out.cpp @@ -6,18 +6,18 @@ * LICENSE file in the root directory of this source tree. 
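(In the op_quantized_relu_out.cpp hunk below, the manual ret_val = xa_nn_vec_relu_... / ET_CHECK_MSG pattern is replaced by XT_KERNEL_CHECK(ctx, <retval>, <kernel>, args...), with the return-value slot left empty because the surrounding functions return void; the macro comes from the //executorch/backends/cadence/common:xt_macros dependency added to targets.bzl later in this diff. Its exact expansion is not shown here, so the sketch below is only a plausible approximation with a made-up name, relying on ExecuTorch's ET_KERNEL_CHECK macro.)

// Plausible shape of an XT_KERNEL_CHECK-style wrapper (illustrative only; the
// real macro in xt_macros may differ): invoke the NNLib kernel and fail the op
// through the runtime context when the kernel reports a nonzero status.
#define XT_KERNEL_CHECK_SKETCH(ctx, retval, kernel, ...)      \
  do {                                                        \
    const int xt_ret_ = kernel(__VA_ARGS__);                  \
    ET_KERNEL_CHECK(ctx, xt_ret_ == 0, Internal, retval);     \
  } while (false)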
*/ +#include #include #include -using executorch::aten::ScalarType; -using executorch::aten::Tensor; -using torch::executor::KernelRuntimeContext; - -namespace cadence { namespace impl { namespace HiFi { namespace native { +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + void quantized_relu_per_tensor_out( KernelRuntimeContext& ctx, const Tensor& input, @@ -35,7 +35,10 @@ void quantized_relu_per_tensor_out( const uint8_t* p_in = input.const_data_ptr(); uint8_t* p_out = output.mutable_data_ptr(); - WORD32 ret_val = xa_nn_vec_relu_asym8u_asym8u( + XT_KERNEL_CHECK( + ctx, + , + xa_nn_vec_relu_asym8u_asym8u, p_out, p_in, _in_zero_point, @@ -46,15 +49,16 @@ void quantized_relu_per_tensor_out( 255, input.numel()); - ET_CHECK_MSG(ret_val == 0, "An internal error occured"); - } else if (input.scalar_type() == executorch::aten::ScalarType::Char) { - const int8_t _in_zero_point = static_cast(in_zero_point); - const int8_t _out_zero_point = static_cast(out_zero_point); + const int _in_zero_point = static_cast(in_zero_point); + const int _out_zero_point = static_cast(out_zero_point); const int8_t* p_in = input.const_data_ptr(); int8_t* p_out = output.mutable_data_ptr(); - WORD32 ret_val = xa_nn_vec_relu_asym8s_asym8s( + XT_KERNEL_CHECK( + ctx, + , + xa_nn_vec_relu_asym8s_asym8s, p_out, p_in, _in_zero_point, @@ -65,8 +69,6 @@ void quantized_relu_per_tensor_out( 127, input.numel()); - ET_CHECK_MSG(ret_val == 0, "An internal error occured"); - } else { ET_CHECK_MSG( false, @@ -118,4 +120,3 @@ void quantized_relu_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_remainder.cpp b/backends/cadence/hifi/operators/op_remainder.cpp index 99cd6ad544e..f73fbb5c3cc 100644 --- a/backends/cadence/hifi/operators/op_remainder.cpp +++ b/backends/cadence/hifi/operators/op_remainder.cpp @@ -40,7 +40,6 @@ using torch::executor::native::utils::remainder_override; using torch::executor::native::utils::scalar_to; using torch::executor::native::utils::SupportedTensorDtypes; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -280,4 +279,3 @@ Tensor& remainder_Scalar_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_rsqrt.cpp b/backends/cadence/hifi/operators/op_rsqrt.cpp index 81a20398087..4a1734b355a 100644 --- a/backends/cadence/hifi/operators/op_rsqrt.cpp +++ b/backends/cadence/hifi/operators/op_rsqrt.cpp @@ -15,7 +15,6 @@ using executorch::aten::RuntimeContext; using executorch::aten::ScalarType; using executorch::aten::Tensor; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -53,4 +52,3 @@ Tensor& rsqrt_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) { } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_select_copy.cpp b/backends/cadence/hifi/operators/op_select_copy.cpp index 520cd4103d1..6bb80dc0c50 100644 --- a/backends/cadence/hifi/operators/op_select_copy.cpp +++ b/backends/cadence/hifi/operators/op_select_copy.cpp @@ -16,7 +16,6 @@ using torch::executor::Error; using torch::executor::KernelRuntimeContext; using torch::executor::select_copy_util; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -37,4 +36,3 @@ Tensor& select_copy_int_out( } // namespace native } // namespace HiFi } // namespace 
impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_sigmoid.cpp b/backends/cadence/hifi/operators/op_sigmoid.cpp index 872d9255bd7..dd46f13b784 100644 --- a/backends/cadence/hifi/operators/op_sigmoid.cpp +++ b/backends/cadence/hifi/operators/op_sigmoid.cpp @@ -19,7 +19,6 @@ using executorch::aten::ScalarType; using executorch::aten::Tensor; using torch::executor::Error; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -89,4 +88,3 @@ Tensor& sigmoid_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) { } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_slice_copy.cpp b/backends/cadence/hifi/operators/op_slice_copy.cpp index 680bae4630c..ff447461d6e 100644 --- a/backends/cadence/hifi/operators/op_slice_copy.cpp +++ b/backends/cadence/hifi/operators/op_slice_copy.cpp @@ -21,7 +21,6 @@ using torch::executor::Error; using torch::executor::get_slice_copy_out_target_size; using torch::executor::KernelRuntimeContext; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -65,7 +64,7 @@ Tensor& slice_copy_Tensor_out( InvalidArgument, out); - compute_slice(in, dim, start, length, step, out); + compute_slice(ctx, in, dim, start, length, step, out); return out; } @@ -73,4 +72,3 @@ Tensor& slice_copy_Tensor_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_softmax.cpp b/backends/cadence/hifi/operators/op_softmax.cpp index be496813ce8..bf42030e35d 100644 --- a/backends/cadence/hifi/operators/op_softmax.cpp +++ b/backends/cadence/hifi/operators/op_softmax.cpp @@ -19,7 +19,6 @@ using executorch::aten::Tensor; using executorch::runtime::KernelRuntimeContext; using torch::executor::Error; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -234,4 +233,3 @@ Tensor& softmax_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_softmax_f32_f32.cpp b/backends/cadence/hifi/operators/op_softmax_f32_f32.cpp index bbcf2c66c3d..074ff29b301 100644 --- a/backends/cadence/hifi/operators/op_softmax_f32_f32.cpp +++ b/backends/cadence/hifi/operators/op_softmax_f32_f32.cpp @@ -14,7 +14,6 @@ using executorch::aten::Tensor; using executorch::runtime::KernelRuntimeContext; using torch::executor::Error; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -155,4 +154,3 @@ Tensor& softmax_f32_f32_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_split_with_sizes_copy.cpp b/backends/cadence/hifi/operators/op_split_with_sizes_copy.cpp index 31a8922d91b..f85b81bb74a 100644 --- a/backends/cadence/hifi/operators/op_split_with_sizes_copy.cpp +++ b/backends/cadence/hifi/operators/op_split_with_sizes_copy.cpp @@ -28,7 +28,6 @@ using torch::executor::Error; using torch::executor::KernelRuntimeContext; using torch::executor::linearize_access_indexes; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -157,4 +156,3 @@ void split_with_sizes_copy_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_sub.cpp b/backends/cadence/hifi/operators/op_sub.cpp index c62a04b7b28..08502cdf39e 100644 --- a/backends/cadence/hifi/operators/op_sub.cpp +++ 
b/backends/cadence/hifi/operators/op_sub.cpp @@ -24,7 +24,6 @@ using executorch::runtime::can_cast; using executorch::runtime::CppTypeToScalarType; using torch::executor::Error; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -222,4 +221,3 @@ Tensor& sub_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_tanh.cpp b/backends/cadence/hifi/operators/op_tanh.cpp index 1132efee3d8..a80dbe9a60b 100644 --- a/backends/cadence/hifi/operators/op_tanh.cpp +++ b/backends/cadence/hifi/operators/op_tanh.cpp @@ -16,7 +16,6 @@ using executorch::aten::ScalarType; using executorch::aten::Tensor; using torch::executor::Error; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -41,4 +40,3 @@ Tensor& tanh_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) { } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_view_copy.cpp b/backends/cadence/hifi/operators/op_view_copy.cpp index 03824ffc9ad..bc2cb7f28e6 100644 --- a/backends/cadence/hifi/operators/op_view_copy.cpp +++ b/backends/cadence/hifi/operators/op_view_copy.cpp @@ -20,7 +20,6 @@ using torch::executor::Error; using torch::executor::get_view_copy_target_size; using torch::executor::KernelRuntimeContext; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -70,4 +69,3 @@ Tensor& view_copy_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_where.cpp b/backends/cadence/hifi/operators/op_where.cpp index 94c1684fe09..af06c17f50f 100644 --- a/backends/cadence/hifi/operators/op_where.cpp +++ b/backends/cadence/hifi/operators/op_where.cpp @@ -23,7 +23,6 @@ using torch::executor::native::utils::apply_tritensor_elementwise_fn; using torch::executor::native::utils::get_compute_type; using torch::executor::native::utils::SupportedTensorDtypes; -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -195,4 +194,3 @@ Tensor& where_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/operators.h b/backends/cadence/hifi/operators/operators.h index 5b8a1e253c1..90028535848 100644 --- a/backends/cadence/hifi/operators/operators.h +++ b/backends/cadence/hifi/operators/operators.h @@ -15,7 +15,11 @@ _(uint8_t, Byte) \ _(int8_t, Char) -namespace cadence { +#define ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) + namespace impl { namespace HiFi { namespace native { @@ -84,7 +88,7 @@ void quantized_linear_per_tensor_out( const ::executorch::aten::optional<::executorch::aten::Tensor>& offset, ::executorch::aten::Tensor& out); -void quantized_conv_nhwc_out( +void quantized_conv2d_nhwc_out( ::executorch::runtime::KernelRuntimeContext& ctx, const ::executorch::aten::Tensor& input, const ::executorch::aten::Tensor& weight, @@ -102,7 +106,7 @@ void quantized_conv_nhwc_out( const ::executorch::aten::Tensor& out_shift, ::executorch::aten::Tensor& out); -void quantized_conv_nchw_out( +void quantized_conv2d_nchw_out( ::executorch::runtime::KernelRuntimeContext& ctx, const ::executorch::aten::Tensor& input, const ::executorch::aten::Tensor& weight, @@ -120,7 +124,7 @@ void quantized_conv_nchw_out( const ::executorch::aten::Tensor& out_shift, ::executorch::aten::Tensor& out); -void 
quantized_conv_nchw_per_tensor_out( +void quantized_conv2d_nchw_per_tensor_out( ::executorch::runtime::KernelRuntimeContext& ctx, const ::executorch::aten::Tensor& input, const ::executorch::aten::Tensor& weight, @@ -138,7 +142,7 @@ void quantized_conv_nchw_per_tensor_out( int64_t out_shift, ::executorch::aten::Tensor& out); -void quantized_conv_nhwc_per_tensor_out( +void quantized_conv2d_nhwc_per_tensor_out( ::executorch::runtime::KernelRuntimeContext& ctx, const ::executorch::aten::Tensor& input, const ::executorch::aten::Tensor& weight, @@ -195,4 +199,3 @@ void quantized_add_asym8uxasym8u_asym8u_per_tensor_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/targets.bzl b/backends/cadence/hifi/operators/targets.bzl index d310396c262..5a92a37237f 100644 --- a/backends/cadence/hifi/operators/targets.bzl +++ b/backends/cadence/hifi/operators/targets.bzl @@ -2,7 +2,7 @@ load("@fbsource//tools/build_defs:platform_defs.bzl", "CXX") load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") -def define_operator(name: str, deps: list[str] | None = None) -> None: +def define_operator(name: str, deps: list[str] | None = None, exported_headers: list[str] | None = None) -> None: op_name = "op_{}".format(name) # Deps used by all operators. @@ -16,10 +16,13 @@ def define_operator(name: str, deps: list[str] | None = None) -> None: "//executorch/kernels/portable/cpu/util:elementwise_util", "//executorch/kernels/portable/cpu/pattern:bitwise_op", "//executorch/backends/cadence/hifi/third-party/nnlib:nnlib-extensions", - "//executorch/kernels/portable/cpu/pattern:comparison_op" + "//executorch/kernels/portable/cpu/pattern:comparison_op", + "//executorch/backends/cadence/common:xt_macros" ] if deps == None: deps = [] + if exported_headers == None: + exported_headers = ["operators.h"] runtime.cxx_library( name = op_name, @@ -31,7 +34,7 @@ def define_operator(name: str, deps: list[str] | None = None) -> None: ], compatible_with = ["ovr_config//cpu:xtensa"], deps = deps + common_deps, - exported_headers = ["operators.h"], + exported_headers = exported_headers, ) OPERATORS = [ @@ -44,6 +47,7 @@ OPERATORS = [ "cat", "clamp", "dequantize_per_tensor", + "dequantize_per_tensor_asym8s", "div", "embedding", "eq", @@ -63,34 +67,35 @@ OPERATORS = [ "ne", "permute_copy", "pow", - "quantized_conv_nchw_out", - "quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_out", - "quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_out", - "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out", - "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out", - "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out", - "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out", - "quantized_conv_nhwc_out", - "quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_out", - "quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_out", - "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out", - "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out", - "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_out", - "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out", + "quantized_conv2d_nchw_asym8sxsym8s_asym8s_per_tensor_out", + "quantized_conv2d_nchw_asym8uxsym8u_asym8u_per_tensor_out", + "quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor_out", + "quantized_conv1d_ncl_asym8uxsym8u_asym8u_per_tensor_out", + "quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out", 
+ "quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out", + "quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out", + "quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out", + "quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor_out", + "quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor_out", + "quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor_out", + "quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor_out", + "quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out", + "quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out", + "quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_out", + "quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out", "quantized_fully_connected_out", "quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out", "quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out", "quantized_layer_norm", - "quantized_linear_out", "quantized_linear_asym8sxasym8s_asym8s_per_tensor_out", "quantized_linear_asym8uxasym8u_asym8u_per_tensor_out", - "quantized_matmul_out", "quantized_matmul_asym8sxasym8s_asym8s_out", "quantized_matmul_asym8uxasym8u_asym8u_out", "quantized_relu_out", "quantized_relu_asym8s_asym8s_per_tensor_out", "quantized_relu_asym8u_asym8u_per_tensor_out", "quantize_per_tensor", + "quantize_per_tensor_asym8s", "remainder", "rsqrt", "select_copy", @@ -115,3 +120,14 @@ def define_common_targets(): # Define build targets for all operators registered in the tables above. for op in OPERATORS: define_operator(op) + + # quantized_linear_out and quantized_linear_per_tensor_out needs additional dependency for int16 support + define_operator("quantized_linear_out", deps=["//executorch/backends/cadence/generic/operators:op_quantized_linear"]) + define_operator("quantized_linear_per_tensor_out", deps=["//executorch/backends/cadence/generic/operators:op_quantized_linear"]) + + # quantized_conv2d_nchw_out and quantized_conv2d_nhwc_out need additional dependency for int16 support + define_operator("quantized_conv2d_nchw_out", deps=["//executorch/backends/cadence/generic/operators:op_quantized_conv2d"]) + define_operator("quantized_conv2d_nhwc_out", deps=["//executorch/backends/cadence/generic/operators:op_quantized_conv2d"]) + + # quantized_matmul_out needs additional dependency for int16 support + define_operator("quantized_matmul_out", deps=["//executorch/backends/cadence/generic/operators:op_quantized_matmul"], exported_headers=["op_quantized_matmul_out.h"]) diff --git a/backends/cadence/hifi/operators/tests/test_op_cat.cpp b/backends/cadence/hifi/operators/tests/test_op_cat.cpp index 2f012ed6c81..f0b76e7a555 100644 --- a/backends/cadence/hifi/operators/tests/test_op_cat.cpp +++ b/backends/cadence/hifi/operators/tests/test_op_cat.cpp @@ -19,7 +19,6 @@ #include -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -38,7 +37,7 @@ class HiFiCatTest : public OperatorTest { public: protected: Tensor& cat_out(ArrayRef tensors, int64_t dim, Tensor& out) { - return ::cadence::impl::HiFi::native::cat_out(context_, tensors, dim, out); + return ::impl::HiFi::native::cat_out(context_, tensors, dim, out); } }; @@ -133,4 +132,3 @@ TEST_F(HiFiCatTest, ThreeDimensionalCatTest) { } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/tests/test_op_dequantize_per_tensor_out.cpp b/backends/cadence/hifi/operators/tests/test_op_dequantize_per_tensor_out.cpp index d6f02501be2..c4370f3b572 100644 --- 
a/backends/cadence/hifi/operators/tests/test_op_dequantize_per_tensor_out.cpp +++ b/backends/cadence/hifi/operators/tests/test_op_dequantize_per_tensor_out.cpp @@ -18,7 +18,6 @@ #include -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -46,7 +45,7 @@ class HiFiDequantizePerTensorTest : public OperatorTest { int64_t quant_max, ScalarType dtype, Tensor& out) { - return ::cadence::impl::HiFi::native::dequantize_per_tensor_out( + return ::impl::HiFi::native::dequantize_per_tensor_out( context_, input, scale, zero_point, quant_min, quant_max, dtype, out); } }; @@ -101,4 +100,3 @@ TEST_F(HiFiDequantizePerTensorTest, OneDimensionalTest) { } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/tests/test_op_div.cpp b/backends/cadence/hifi/operators/tests/test_op_div.cpp index 790319d2db4..4fee8dbf874 100644 --- a/backends/cadence/hifi/operators/tests/test_op_div.cpp +++ b/backends/cadence/hifi/operators/tests/test_op_div.cpp @@ -19,7 +19,6 @@ #include -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -44,8 +43,7 @@ class HiFiDivTest : public OperatorTest { const Tensor& b, optional mode, Tensor& out) { - return ::cadence::impl::HiFi::native::div_out_mode( - context_, a, b, mode, out); + return ::impl::HiFi::native::div_out_mode(context_, a, b, mode, out); } }; @@ -70,4 +68,3 @@ TEST_F(HiFiDivTest, DISABLED_Int32FloorDivideTest) { } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/tests/test_op_permute_copy.cpp b/backends/cadence/hifi/operators/tests/test_op_permute_copy.cpp index a549fac786e..981c4934c78 100644 --- a/backends/cadence/hifi/operators/tests/test_op_permute_copy.cpp +++ b/backends/cadence/hifi/operators/tests/test_op_permute_copy.cpp @@ -19,7 +19,6 @@ #include -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -38,8 +37,7 @@ class HiFiPermuteCopyTest : public OperatorTest { public: protected: Tensor& permute_copy_out(const Tensor& in, IntArrayRef dims, Tensor& out) { - return ::cadence::impl::HiFi::native::permute_copy_out( - context_, in, dims, out); + return ::impl::HiFi::native::permute_copy_out(context_, in, dims, out); } }; @@ -229,4 +227,3 @@ TEST_F(HiFiPermuteCopyTest, MixedDataTypesTest) { } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/tests/test_op_quantize_per_tensor.cpp b/backends/cadence/hifi/operators/tests/test_op_quantize_per_tensor.cpp index 6f910cb76a8..cac85fcbef8 100644 --- a/backends/cadence/hifi/operators/tests/test_op_quantize_per_tensor.cpp +++ b/backends/cadence/hifi/operators/tests/test_op_quantize_per_tensor.cpp @@ -19,7 +19,6 @@ #include -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -45,7 +44,7 @@ class HiFiQuantizePerTensorTest : public OperatorTest { __ET_UNUSED int64_t quant_max, ScalarType dtype, Tensor& out) { - ::cadence::impl::HiFi::native::quantize_per_tensor_out( + ::impl::HiFi::native::quantize_per_tensor_out( context_, input, scale, zero_point, quant_min, quant_max, dtype, out); } }; @@ -162,4 +161,3 @@ TEST_F(HiFiQuantizePerTensorTest, CheckSingleElementUInt16Quantize) { } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/operators/tests/test_op_quantized_conv2d_out.cpp 
b/backends/cadence/hifi/operators/tests/test_op_quantized_conv2d_out.cpp new file mode 100644 index 00000000000..70afc030b4c --- /dev/null +++ b/backends/cadence/hifi/operators/tests/test_op_quantized_conv2d_out.cpp @@ -0,0 +1,304 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace impl { +namespace HiFi { +namespace native { +namespace { + +using ::executorch::aten::Scalar; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::aten::TensorImpl; +using ::executorch::runtime::Error; +using ::executorch::runtime::KernelRuntimeContext; +using ::executorch::runtime::runtime_init; +using ::executorch::runtime::testing::TensorFactory; + +class HiFiQuantizedConv2dTest : public OperatorTest { + public: + protected: + void quantized_conv2d_nchw_out( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + ::executorch::aten::IntArrayRef stride, + ::executorch::aten::IntArrayRef padding, + ::executorch::aten::IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& bias_scale, + double output_scale, + int64_t output_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + Tensor& output) { + return ::impl::HiFi::native::quantized_conv2d_nchw_out( + context_, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + output); + } + + void quantized_conv2d_nhwc_out( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + ::executorch::aten::IntArrayRef stride, + ::executorch::aten::IntArrayRef padding, + ::executorch::aten::IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& bias_scale, + double output_scale, + int64_t output_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + Tensor& output) { + return ::impl::HiFi::native::quantized_conv2d_nhwc_out( + context_, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + output); + } +}; + +// Test quantized_conv2d_nchw_out with int16 activations and int8 weights +TEST_F(HiFiQuantizedConv2dTest, QuantizedConv2dNchwInt16Test) { + TensorFactory tf_int16; + TensorFactory tf_int32; + TensorFactory tf_int8; + TensorFactory tf_float; + + // Minimal test case: input [1, 2, 3, 3], kernel [1, 2, 2, 2] -> output [1, 1, + // 2, 2] Small enough to verify by hand calculation + // + // Input Channel 0 (3x3): Input Channel 1 (3x3): + // 1 2 3 1 1 1 + // 4 5 6 1 1 1 + // 7 8 9 1 1 1 + // + // Weight Out Ch 0, In Ch 0: Weight Out Ch 0, In Ch 1: + // 1 0 1 1 + // 0 1 1 1 + // + // Hand calculation for each output position: + // (0,0): Ch0: 1*1+2*0+4*0+5*1=6, Ch1: 1*1+1*1+1*1+1*1=4 -> 10 + // (0,1): Ch0: 2*1+3*0+5*0+6*1=8, Ch1: 1*1+1*1+1*1+1*1=4 -> 12 + // (1,0): Ch0: 4*1+5*0+7*0+8*1=12, Ch1: 1*1+1*1+1*1+1*1=4 -> 16 + // (1,1): Ch0: 5*1+6*0+8*0+9*1=14, Ch1: 1*1+1*1+1*1+1*1=4 -> 18 + Tensor input = tf_int16.make( + {1, 2, 3, 3}, + {1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, // Channel 0 + 1, + 1, + 1, + 
1, + 1, + 1, + 1, + 1, + 1}); // Channel 1 + Tensor weight = tf_int8.make( + {1, 2, 2, 2}, + {1, + 0, + 0, + 1, // Out Ch 0, In Ch 0: diagonal pattern + 1, + 1, + 1, + 1}); // Out Ch 0, In Ch 1: all ones + Tensor bias = tf_int32.zeros({1}); + + // Output dimensions: (3-2)/1+1=2 for each spatial dimension + Tensor output = tf_int16.zeros({1, 1, 2, 2}); + + int64_t in_zero_point = 0; + Tensor weight_zero_point = tf_int32.make({1}, {0}); + Tensor bias_scale = tf_float.make({1}, {1.0f}); + double output_scale = 1.0; + int64_t output_zero_point = 0; + Tensor out_multiplier = tf_int32.make({1}, {1073741824}); // 0.5 * 2^31 + Tensor out_shift = tf_int32.make({1}, {0}); + + std::array stride_arr = {1, 1}; + std::array padding_arr = {0, 0}; + std::array dilation_arr = {1, 1}; + + ::executorch::aten::ArrayRef stride(stride_arr.data(), 2); + ::executorch::aten::ArrayRef padding(padding_arr.data(), 2); + ::executorch::aten::ArrayRef dilation(dilation_arr.data(), 2); + + quantized_conv2d_nchw_out( + input, + weight, + bias, + stride, + padding, + dilation, + 1, // groups + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + output); + + Tensor expected = tf_int16.make({1, 1, 2, 2}, {10, 12, 16, 18}); + EXPECT_TENSOR_EQ(output, expected); +} + +// Test quantized_conv2d_nhwc_out with int16 activations and int8 weights +TEST_F(HiFiQuantizedConv2dTest, QuantizedConv2dNhwcInt16Test) { + TensorFactory tf_int16; + TensorFactory tf_int32; + TensorFactory tf_int8; + TensorFactory tf_float; + + // Minimal test case in NHWC format: input [1, 3, 3, 2], kernel [1, 2, 2, 2] + // -> output [1, 2, 2, 1] Same values as NCHW test, just different layout + // + // Input (H=3, W=3, C=2): + // Position (h,w): [Ch0, Ch1] + // (0,0): [1, 1] (0,1): [2, 1] (0,2): [3, 1] + // (1,0): [4, 1] (1,1): [5, 1] (1,2): [6, 1] + // (2,0): [7, 1] (2,1): [8, 1] (2,2): [9, 1] + // + // Weight (Out=1, H=2, W=2, In=2): + // For output channel 0: + // Position (h,w): [In0, In1] + // (0,0): [1, 1] (0,1): [0, 1] + // (1,0): [0, 1] (1,1): [1, 1] + // + // Hand calculation matches NCHW test: + // Output (0,0): 10, (0,1): 12, (1,0): 16, (1,1): 18 + Tensor input = tf_int16.make( + {1, 3, 3, 2}, + {1, + 1, + 2, + 1, + 3, + 1, // Row 0: (Ch0,Ch1) pairs + 4, + 1, + 5, + 1, + 6, + 1, // Row 1 + 7, + 1, + 8, + 1, + 9, + 1}); // Row 2 + Tensor weight = tf_int8.make( + {1, 2, 2, 2}, + {1, + 1, + 0, + 1, // Row 0: (In0,In1) pairs + 0, + 1, + 1, + 1}); // Row 1 + Tensor bias = tf_int32.zeros({1}); + + // Output dimensions: (3-2)/1+1=2 for each spatial dimension + Tensor output = tf_int16.zeros({1, 2, 2, 1}); + + int64_t in_zero_point = 0; + Tensor weight_zero_point = tf_int32.make({1}, {0}); + Tensor bias_scale = tf_float.make({1}, {1.0f}); + double output_scale = 1.0; + int64_t output_zero_point = 0; + Tensor out_multiplier = tf_int32.make({1}, {1073741824}); // 0.5 * 2^31 + Tensor out_shift = tf_int32.make({1}, {0}); + + std::array stride_arr = {1, 1}; + std::array padding_arr = {0, 0}; + std::array dilation_arr = {1, 1}; + + ::executorch::aten::ArrayRef stride(stride_arr.data(), 2); + ::executorch::aten::ArrayRef padding(padding_arr.data(), 2); + ::executorch::aten::ArrayRef dilation(dilation_arr.data(), 2); + + quantized_conv2d_nhwc_out( + input, + weight, + bias, + stride, + padding, + dilation, + 1, // groups + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + output); + + Tensor expected = tf_int16.make({1, 2, 2, 1}, 
{10, 12, 16, 18}); + EXPECT_TENSOR_EQ(output, expected); +} + +} // namespace +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/tests/test_op_quantized_linear_out.cpp b/backends/cadence/hifi/operators/tests/test_op_quantized_linear_out.cpp new file mode 100644 index 00000000000..fddf373290f --- /dev/null +++ b/backends/cadence/hifi/operators/tests/test_op_quantized_linear_out.cpp @@ -0,0 +1,132 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +namespace impl { +namespace HiFi { +namespace native { +namespace { + +using ::executorch::aten::Scalar; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::aten::TensorImpl; +using ::executorch::runtime::Error; +using ::executorch::runtime::KernelRuntimeContext; +using ::executorch::runtime::runtime_init; +using ::executorch::runtime::testing::TensorFactory; +using std::optional; +using std::string_view; + +class HiFiQuantizedLinearTest : public OperatorTest { + public: + protected: + void quantized_linear_out( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + int64_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + int64_t out_zero_point, + const optional& offset, + Tensor& output) { + return ::impl::HiFi::native::quantized_linear_out( + context_, + input, + weight, + bias, + in_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + offset, + output); + } + + void quantized_linear_per_tensor_out( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + int64_t in_zero_point, + int64_t weight_zero_point, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + const optional& offset, + Tensor& output) { + return ::impl::HiFi::native::quantized_linear_per_tensor_out( + context_, + input, + weight, + bias, + in_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + offset, + output); + } +}; + +// Test quantized_linear_out with int16 activations (asym8s) +TEST_F(HiFiQuantizedLinearTest, QuantizedLinearInt16Test) { + TensorFactory tf_int16; + TensorFactory tf_int32; + TensorFactory tf_int8; + + // Simple 2D case: input [2, 3] x weight [4, 3] = output [2, 4] + // Values captured from e2e test with + // CadenceWith16BitLinearActivationsQuantizer + Tensor input = + tf_int16.make({2, 3}, {-28170, -26389, -32768, -31474, -32266, -29076}); + Tensor weight = tf_int8.make( + {4, 3}, {1, 87, -128, -114, -59, 44, -1, 127, -12, 44, -46, -29}); + Tensor bias = tf_int32.zeros({4}); + Tensor output = tf_int16.zeros({2, 4}); + + int64_t in_zero_point = -29822; + Tensor weight_zero_point = tf_int32.make({1}, {2}); + Tensor out_multiplier = tf_int32.make({1}, {2011373824}); + Tensor out_shift = tf_int32.make({1}, {-8}); + int64_t out_zero_point = -30847; + quantized_linear_out( + input, + weight, + bias, + in_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + std::nullopt, + output); + // Expected output from e2e test + Tensor expected_output = tf_int16.make( + {2, 4}, {-28384, -32767, -29144, -30862, -31956, -29486, -31985, -30756}); + EXPECT_TENSOR_CLOSE(output, expected_output); +} + +} // 
namespace +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/tests/test_op_quantized_matmul_out.cpp b/backends/cadence/hifi/operators/tests/test_op_quantized_matmul_out.cpp new file mode 100644 index 00000000000..3286913f055 --- /dev/null +++ b/backends/cadence/hifi/operators/tests/test_op_quantized_matmul_out.cpp @@ -0,0 +1,165 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace impl { +namespace HiFi { +namespace native { +namespace { + +using ::executorch::aten::Scalar; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::aten::TensorImpl; +using ::executorch::runtime::Error; +using ::executorch::runtime::KernelRuntimeContext; +using ::executorch::runtime::runtime_init; +using ::executorch::runtime::testing::TensorFactory; + +class HiFiQuantizedMatmulTest : public OperatorTest { + public: + protected: + Tensor& quantized_matmul_out( + const Tensor& X, + int64_t X_zero_point, + const Tensor& Y, + int64_t Y_zero_point, + const std::optional& bias, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + bool transposed, + Tensor& output) { + return impl::HiFi::native::quantized_matmul_out( + context_, + X, + X_zero_point, + Y, + Y_zero_point, + bias, + out_multiplier, + out_shift, + out_zero_point, + transposed, + output); + } +}; + +// Test quantized_matmul_out with int16 activations and int8 weights +TEST_F(HiFiQuantizedMatmulTest, QuantizedMatmulInt16Test) { + TensorFactory tf_int16; + TensorFactory tf_int32; + TensorFactory tf_int8; + + // Minimal test case: X [2, 2] x Y [2, 2] = output [2, 2] + // Small enough to verify by hand calculation + // + // X (2x2): Y (2x2): + // 2 4 1 2 + // 6 8 1 0 + // + // Hand calculation for matmul (before scaling): + // (0,0): 2*1 + 4*1 = 6 + // (0,1): 2*2 + 4*0 = 4 + // (1,0): 6*1 + 8*1 = 14 + // (1,1): 6*2 + 8*0 = 12 + // + // Raw result: [[6, 4], [14, 12]] + // After 0.5 scaling: [[3, 2], [7, 6]] + Tensor X = tf_int16.make({2, 2}, {2, 4, 6, 8}); + Tensor Y = tf_int8.make({2, 2}, {1, 2, 1, 0}); + Tensor bias = tf_int32.zeros({2}); + Tensor output = tf_int16.zeros({2, 2}); + + int64_t X_zero_point = 0; + int64_t Y_zero_point = 0; + int64_t out_multiplier = 1073741824; // 0.5 * 2^31 + int64_t out_shift = 0; + int64_t out_zero_point = 0; + + quantized_matmul_out( + X, + X_zero_point, + Y, + Y_zero_point, + bias, + out_multiplier, + out_shift, + out_zero_point, + false, // transposed + output); + + Tensor expected = tf_int16.make({2, 2}, {3, 2, 7, 6}); + EXPECT_TENSOR_EQ(output, expected); +} + +// Test quantized_matmul_out with transposed Y (int16 activations and int8 +// weights) +TEST_F(HiFiQuantizedMatmulTest, QuantizedMatmulInt16TransposedTest) { + TensorFactory tf_int16; + TensorFactory tf_int32; + TensorFactory tf_int8; + + // Minimal test case with transposed Y: X [2, 2] x Y^T [2, 2] = output [2, 2] + // Y is stored transposed, so we compute X @ Y^T + // + // X (2x2): Y_stored (2x2, which is Y^T): + // 2 4 1 1 + // 6 8 2 0 + // + // When transposed=true, we compute X @ Y_stored^T = X @ Y + // Y = Y_stored^T = [[1, 2], [1, 0]] + // + // Hand calculation for matmul (before scaling): + // (0,0): 2*1 + 4*1 = 6 + // (0,1): 2*2 + 4*0 = 4 + // (1,0): 6*1 + 8*1 = 14 
+ // (1,1): 6*2 + 8*0 = 12 + // + // Raw result: [[6, 4], [14, 12]] + // After 0.5 scaling: [[3, 2], [7, 6]] + Tensor X = tf_int16.make({2, 2}, {2, 4, 6, 8}); + Tensor Y = tf_int8.make({2, 2}, {1, 1, 2, 0}); // Stored as Y^T + Tensor bias = tf_int32.zeros({2}); + Tensor output = tf_int16.zeros({2, 2}); + + int64_t X_zero_point = 0; + int64_t Y_zero_point = 0; + int64_t out_multiplier = 1073741824; // 0.5 * 2^31 + int64_t out_shift = 0; + int64_t out_zero_point = 0; + + quantized_matmul_out( + X, + X_zero_point, + Y, + Y_zero_point, + bias, + out_multiplier, + out_shift, + out_zero_point, + true, // transposed + output); + + Tensor expected = tf_int16.make({2, 2}, {3, 2, 7, 6}); + EXPECT_TENSOR_EQ(output, expected); +} + +} // namespace +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/tests/test_op_quantized_relu_out.cpp b/backends/cadence/hifi/operators/tests/test_op_quantized_relu_out.cpp index 3a2ef85087c..a599d73ccc8 100644 --- a/backends/cadence/hifi/operators/tests/test_op_quantized_relu_out.cpp +++ b/backends/cadence/hifi/operators/tests/test_op_quantized_relu_out.cpp @@ -18,7 +18,6 @@ #include -namespace cadence { namespace impl { namespace HiFi { namespace native { @@ -45,7 +44,7 @@ class HiFiQuantizedReluTest : public OperatorTest { const Tensor& out_multiplier, const Tensor& out_shift, Tensor& output) { - return ::cadence::impl::HiFi::native::quantized_relu_out( + return ::impl::HiFi::native::quantized_relu_out( context_, input, in_zero_point, @@ -58,14 +57,14 @@ class HiFiQuantizedReluTest : public OperatorTest { TEST_F(HiFiQuantizedReluTest, MultiDimensionalTest) { TensorFactory tf_chars; + TensorFactory tf_ints; const std::vector sizes{2, 3, 5, 6}; Tensor quantized_input = tf_chars.full(sizes, -128); Tensor quantized_output = tf_chars.full(sizes, 100); Tensor in_zero_point = tf_chars.full({1}, 127); int64_t out_zero_point = -128; - Tensor out_multiplier = - TensorFactory().full({1}, 1077952640); - Tensor out_shift = TensorFactory().full({1}, 5); + Tensor out_multiplier = tf_ints.full({1}, 1077952640); + Tensor out_shift = tf_ints.full({1}, 5); quantized_relu_out( quantized_input, @@ -81,14 +80,14 @@ TEST_F(HiFiQuantizedReluTest, MultiDimensionalTest) { TEST_F(HiFiQuantizedReluTest, OneDimensionalTest) { TensorFactory tf_chars; + TensorFactory tf_ints; const std::vector sizes{56}; Tensor quantized_input = tf_chars.full(sizes, -128); Tensor quantized_output = tf_chars.full(sizes, 100); Tensor in_zero_point = tf_chars.full({1}, 127); int64_t out_zero_point = -128; - Tensor out_multiplier = - TensorFactory().full({1}, 1077952640); - Tensor out_shift = TensorFactory().full({1}, 5); + Tensor out_multiplier = tf_ints.full({1}, 1077952640); + Tensor out_shift = tf_ints.full({1}, 5); quantized_relu_out( quantized_input, @@ -106,4 +105,3 @@ TEST_F(HiFiQuantizedReluTest, OneDimensionalTest) { } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence diff --git a/backends/cadence/hifi/third-party/nnlib/matmul_asym8uxasym8u_asym8u.cpp b/backends/cadence/hifi/third-party/nnlib/matmul_asym8uxasym8u_asym8u.cpp index fb944a66431..a984a1a1faf 100644 --- a/backends/cadence/hifi/third-party/nnlib/matmul_asym8uxasym8u_asym8u.cpp +++ b/backends/cadence/hifi/third-party/nnlib/matmul_asym8uxasym8u_asym8u.cpp @@ -43,7 +43,6 @@ /*----------------------------Main function---------------------------------*/ -namespace cadence { namespace impl { namespace HiFi { namespace kernels { @@ -434,7 +433,6 @@ WORD32 
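The quantized_matmul tests added above encode the requantization scale as out_multiplier ~= scale * 2^31 (1073741824 for a 0.5 scale) with out_shift = 0 and all zero points zero. A minimal standalone sketch of that arithmetic, using a plain arithmetic right shift and ignoring out_shift since it is zero in the test (the HiFi kernel's exact rounding and shift handling may differ), reproduces the expected output {3, 2, 7, 6}:

// Standalone check of the hand calculation in QuantizedMatmulInt16Test.
// Sketch only: simple int32 accumulation plus a 2^31-scaled multiplier.
#include <cstdint>
#include <cstdio>

int main() {
  const int16_t X[2][2] = {{2, 4}, {6, 8}};
  const int8_t Y[2][2] = {{1, 2}, {1, 0}};
  const int64_t out_multiplier = 1073741824;  // 0.5 * 2^31
  int16_t out[2][2];

  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 2; ++j) {
      int32_t acc = 0;  // bias is zero in the test
      for (int k = 0; k < 2; ++k) {
        acc += static_cast<int32_t>(X[i][k]) * Y[k][j];
      }
      // Requantize: acc * (scale * 2^31) >> 31, i.e. 6->3, 4->2, 14->7, 12->6.
      out[i][j] = static_cast<int16_t>((acc * out_multiplier) >> 31);
    }
  }
  printf("%d %d %d %d\n", out[0][0], out[0][1], out[1][0], out[1][1]);
  return 0;  // expected: 3 2 7 6
}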
matmul_asym8uxasym8u_asym8u( return 0; } -}; // namespace kernels -}; // namespace HiFi -}; // namespace impl -}; // namespace cadence +} // namespace kernels +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/third-party/nnlib/targets.bzl b/backends/cadence/hifi/third-party/nnlib/targets.bzl index a63a4dd3954..2ad9d6568ac 100644 --- a/backends/cadence/hifi/third-party/nnlib/targets.bzl +++ b/backends/cadence/hifi/third-party/nnlib/targets.bzl @@ -13,6 +13,10 @@ def define_common_targets(): "@EXECUTORCH_CLIENTS", ], compatible_with = ["ovr_config//cpu:xtensa"], + compiler_flags = [ + "-Wno-pointer-sign", + "-Wno-incompatible-pointer-types-discards-qualifiers", + ], deps = [ "fbsource//third-party/nnlib-hifi4/xa_nnlib:libxa_nnlib", ], diff --git a/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_atan2_f32.c b/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_atan2_f32.c index 2f1d2071777..68a51223cde 100644 --- a/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_atan2_f32.c +++ b/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_atan2_f32.c @@ -21,7 +21,7 @@ ******************************************************************************/ #include -#include "../include/NatureDSP_Signal_math.h" +#include "NatureDSP_Signal_math.h" #include "NatureDSP_types.h" #include "xa_nn_common.h" diff --git a/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_pow_f32.c b/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_pow_f32.c index aa81d695784..5fb69113ee7 100644 --- a/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_pow_f32.c +++ b/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_pow_f32.c @@ -20,7 +20,7 @@ ******************************************************************************/ -#include "../include/NatureDSP_Signal_math.h" +#include "NatureDSP_Signal_math.h" #include "NatureDSP_types.h" #include "xa_nn_common.h" diff --git a/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_where_f32xf32_f32.c b/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_where_f32xf32_f32.c index e7e83846484..840a027f7a7 100644 --- a/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_where_f32xf32_f32.c +++ b/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_where_f32xf32_f32.c @@ -117,6 +117,7 @@ WORD32 xa_nn_elm_where_f32xf32_f32(FLOAT32 * __restrict__ p_out, XT_MOVF_S(a, a2, s); XT_SSI(a, (xtfloat *)out, 0); } + return 0; } static void internal_elm_where_broadcast_f32xf32_f32(FLOAT32 * __restrict__ p_out, diff --git a/backends/cadence/reference/kernels/kernels.cpp b/backends/cadence/reference/kernels/kernels.cpp deleted file mode 100644 index ad8746f51eb..00000000000 --- a/backends/cadence/reference/kernels/kernels.cpp +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include -#include -#include -#include -#include - -namespace impl { -namespace reference { -namespace kernels { - -// Quantize a fp32 value to an int8_t/uint8_t value -template -T quantize(const float x, float scale, int32_t zero_point) { - // constexpr float min_val = std::numeric_limits::min(); - // constexpr float max_val = std::numeric_limits::max(); - // float tmp = roundf(x * scale + zero_point); - // return std::max(std::min(tmp, max_val), min_val); - // Match Executorch CPU kernel implementation at - // https://fburl.com/code/fxizw6u6 - int64_t qvalue; - qvalue = static_cast(zero_point + std::nearbyint(scale * x)); - - qvalue = std::max(qvalue, std::numeric_limits::min()); - qvalue = std::min(qvalue, std::numeric_limits::max()); - return static_cast(qvalue); -} - -// Quantize an fp32 array to an int8_t/uint8_t array -template -void quantize( - T* __restrict__ y, - const float* __restrict__ x, - float inv_scale, - int32_t zero_point, - size_t size) { - for (size_t i = 0; i < size; ++i) { - y[i] = quantize(x[i], inv_scale, zero_point); - } -} - -// Dequantize an int8_t/uint8_t value to an fp32 value -template -float dequantize(const T x, float scale, int32_t zero_point) { - return scale * (x - zero_point); -} - -// Dequantize an int8_t/uint8_t/int16_t array to an fp32 array -template -void dequantize( - float* __restrict__ y, - const T* __restrict__ x, - float scale, - int32_t zero_point, - size_t size) { - for (size_t i = 0; i < size; ++i) { - y[i] = dequantize(x[i], scale, zero_point); - } -} - -// explicit template instantiation - -#define typed_quantize_val(dtype) \ - template dtype quantize(const float x, float inv_scale, int32_t zero_point); -typed_quantize_val(int8_t); -typed_quantize_val(uint8_t); -typed_quantize_val(int16_t); -typed_quantize_val(uint16_t); -#undef typed_quantize_val - -#define typed_quantize_vec(dtype) \ - template void quantize( \ - dtype* __restrict__ y, \ - const float* __restrict__ x, \ - float inv_scale, \ - int32_t zero_point, \ - size_t size); -typed_quantize_vec(int8_t); -typed_quantize_vec(uint8_t); -typed_quantize_vec(int16_t); -typed_quantize_vec(uint16_t); -#undef typed_quantize_vec - -#define typed_dequantize_val(dtype) \ - template float dequantize(const dtype x, float scale, int32_t zero_point); -typed_dequantize_val(int8_t); -typed_dequantize_val(uint8_t); -typed_dequantize_val(int16_t); -typed_dequantize_val(uint16_t); -#undef typed_dequantize_val - -#define typed_dequantize_vec(dtype) \ - template void dequantize( \ - float* __restrict__ y, \ - const dtype* __restrict__ x, \ - float scale, \ - int32_t zero_point, \ - size_t size); -typed_dequantize_vec(int8_t); -typed_dequantize_vec(uint8_t); -typed_dequantize_vec(int16_t); -typed_dequantize_vec(uint16_t); -#undef typed_dequantize_vec - -}; // namespace kernels -}; // namespace reference -}; // namespace impl diff --git a/backends/cadence/reference/kernels/kernels.h b/backends/cadence/reference/kernels/kernels.h deleted file mode 100644 index de6ae9486f5..00000000000 --- a/backends/cadence/reference/kernels/kernels.h +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include "inttypes.h" -#include "stddef.h" - -namespace impl { -namespace reference { -namespace kernels { - -template -T quantize(const float x, float scale, int32_t zero_point); - -template -float dequantize(const T x, float scale, int32_t zero_point); - -template -void quantize( - T* __restrict__ y, - const float* __restrict__ x, - float scale, - int32_t zero_point, - size_t size); - -// Deuantize an int8_t/uint8_t/int16_t array to an fp32 array -template -void dequantize( - float* __restrict__ y, - const T* __restrict__ x, - float scale, - int32_t zero_point, - size_t size); - -template -OT requantize( - const IT in, - float in_scale, - int32_t in_zero_point, - float inv_out_scale, - int32_t out_zero_point); - -template -void requantize( - OT* __restrict__ out, - const IT* __restrict__ in, - float in_scale, - int32_t in_zero_point, - float inv_out_scale, - int32_t out_zero_point, - size_t size); - -}; // namespace kernels -}; // namespace reference -}; // namespace impl diff --git a/backends/cadence/reference/operators/CMakeLists.txt b/backends/cadence/reference/operators/CMakeLists.txt deleted file mode 100644 index ea5b699f441..00000000000 --- a/backends/cadence/reference/operators/CMakeLists.txt +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -cmake_minimum_required(VERSION 3.19) - -set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -if(NOT CMAKE_CXX_STANDARD) - set(CMAKE_CXX_STANDARD 17) -endif() - -include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) -include(${EXECUTORCH_ROOT}/tools/cmake/Codegen.cmake) - -# ATen compliant ops that are needed to run this model. 
-set(_aten_ops__srcs - "${CMAKE_CURRENT_SOURCE_DIR}/op_add.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/op_embedding.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/op_full.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/op_view_copy.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/activation_ops_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/copy_ops_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/broadcast_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/dtype_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/index_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/kernel_ops_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/matmul_ops_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/reduce_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/repeat_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/slice_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_floathbf16.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_cat.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_div.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_hardtanh.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_max_pool2d_with_indices.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mean.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mul.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_permute_copy.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_rsqrt.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sigmoid.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_slice_copy.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_softmax.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_split_with_sizes_copy.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sub.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_to_copy.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_where.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_expand_copy.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_gelu.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_empty.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_transpose_copy.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_eq.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_logical_not.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_any.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_native_group_norm.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sum.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_select_copy.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/dtype_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/normalization_ops_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/select_copy_util.cpp" -) -add_library(aten_ops_cadence ${_aten_ops__srcs}) -target_link_libraries(aten_ops_cadence PUBLIC executorch) -target_link_libraries(aten_ops_cadence PRIVATE cadence_kernels) - -# Let files say "include ". -set(_common_include_directories - ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10 -) - -target_include_directories( - aten_ops_cadence PUBLIC ${ROOT_DIR}/.. ${CMAKE_BINARY_DIR} - ${_common_include_directories} -) - -# Custom ops that are needed to run the test model. 
-add_library( - custom_ops - "quantized_linear_out.cpp" - "quantized_conv_nchw_out.cpp" - "quantized_conv_nhwc_out.cpp" - "quantized_relu_out.cpp" - "quantized_layer_norm.cpp" - "quantize_per_tensor.cpp" - "quantized_fully_connected_out.cpp" - "dequantize_per_tensor.cpp" - "quantized_matmul_out.cpp" - "requantize_out.cpp" - "im2row_out.cpp" -) -target_include_directories( - custom_ops PUBLIC ${ROOT_DIR}/.. ${CMAKE_BINARY_DIR} - ${_common_include_directories} -) - -target_link_libraries(custom_ops PUBLIC executorch) -target_link_libraries(custom_ops PRIVATE cadence_kernels) - -# Generate C++ bindings to register kernels into both PyTorch (for AOT) and -# Executorch (for runtime). Here select all ops in functions.yaml -gen_selected_ops( - LIB_NAME "cadence_ops_lib" OPS_SCHEMA_YAML - "${CMAKE_CURRENT_LIST_DIR}/../../aot/functions.yaml" "" "" -) -generate_bindings_for_kernels( - LIB_NAME "cadence_ops_lib" OPS_SCHEMA_YAML FUNCTIONS_YAML - ${CMAKE_CURRENT_SOURCE_DIR}/../../aot/functions.yaml -) -message("Generated cadence x86 files ${gen_command_sources}") - -gen_operators_lib( - LIB_NAME "cadence_ops_lib" KERNEL_LIBS custom_ops DEPS aten_ops_cadence -) diff --git a/backends/cadence/reference/operators/TARGETS b/backends/cadence/reference/operators/TARGETS deleted file mode 100644 index 67f2bab681a..00000000000 --- a/backends/cadence/reference/operators/TARGETS +++ /dev/null @@ -1,5 +0,0 @@ -load("targets.bzl", "define_common_targets") - -oncall("odai_jarvis") - -define_common_targets() diff --git a/backends/cadence/reference/operators/dequantize_per_tensor.cpp b/backends/cadence/reference/operators/dequantize_per_tensor.cpp deleted file mode 100644 index f53292e312d..00000000000 --- a/backends/cadence/reference/operators/dequantize_per_tensor.cpp +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include -#include - -namespace impl { -namespace reference { -namespace native { - -using ::executorch::aten::ScalarType; -using ::executorch::aten::Tensor; -using ::executorch::runtime::KernelRuntimeContext; -using ::impl::reference::kernels::dequantize; - -void dequantize_per_tensor_out( - KernelRuntimeContext& context, - const Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out) { - float* out_data = out.mutable_data_ptr(); - size_t numel = out.numel(); - - if (input.scalar_type() == ScalarType::Byte) { - const uint8_t* input_data = input.const_data_ptr(); - dequantize(out_data, input_data, scale, zero_point, numel); - } else if (input.scalar_type() == ScalarType::Char) { - const int8_t* input_data = input.const_data_ptr(); - dequantize(out_data, input_data, scale, zero_point, numel); - } else if ( - input.scalar_type() == ScalarType::Bits16 || - input.scalar_type() == ScalarType::UInt16) { - const uint16_t* input_data = input.const_data_ptr(); - dequantize(out_data, input_data, scale, zero_point, numel); - } else if (input.scalar_type() == ScalarType::Short) { - const int16_t* input_data = input.const_data_ptr(); - dequantize(out_data, input_data, scale, zero_point, numel); - } else { - ET_CHECK_MSG( - false, - "Unhandled input dtype %hhd", - static_cast(input.scalar_type())); - } -} - -}; // namespace native -}; // namespace reference -}; // namespace impl diff --git a/backends/cadence/reference/operators/im2row_out.cpp b/backends/cadence/reference/operators/im2row_out.cpp deleted file mode 100644 index 0cd2e338e6e..00000000000 --- a/backends/cadence/reference/operators/im2row_out.cpp +++ /dev/null @@ -1,298 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include - -#include - -namespace impl { -namespace reference { -namespace native { - -using ::executorch::aten::IntArrayRef; -using ::executorch::aten::ScalarType; -using ::executorch::aten::Tensor; -using ::executorch::runtime::KernelRuntimeContext; - -template -__attribute__((always_inline)) void im2row_( - const T* __restrict__ data_im, - const int32_t in_zero_point, - /* input parameters*/ - const int32_t channels, - const int32_t height, - const int32_t width, - /* output parameters */ - const int32_t out_height, - const int32_t out_width, - /* convolution parameters */ - const int32_t kernel_h, - const int32_t kernel_w, - const int32_t pad_h, - const int32_t pad_w, - const int32_t stride_h, - const int32_t stride_w, - const int32_t dilation_h, - const int32_t dilation_w, - T* __restrict__ data_col, - bool channels_last) { - // Consider convolving the input image of dimensions channels * height * width - // (or height * width * channels for NHWC layout) with a filter of dimensions - // channels * kernels_h * kernels_w. Assume that this convolution will produce - // an output of dimensinos out_height x out_width. For each point the output, - // im2row takes the data from the input that is used in the computation of - // that output point, and flattens it into a vector of size channels_col = - // channels * kernel_h * kernel_w. 
The output of im2row will therefore be a 2D - // array of size (out_height * out_width) x channels_col - const int32_t channels_col = channels * kernel_h * kernel_w; - - // If the layout is NHWC, we can copy 'channels' worth of contiguous data - // points when performing im2row. - if (channels_last) { - // Iterate over the output domain - for (int _h = 0; _h < out_height; ++_h) { - for (int _w = 0; _w < out_width; ++_w) { - int32_t i_col = _h * out_width + _w; - // Each point in the output domain is the result of applying a filter of - // size kernel_h x kernel_w x channels on the input. But since channels - // is contiguous, we will not explicitly have a loop for it. - for (int _kh = 0; _kh < kernel_h; ++_kh) { - int32_t h_im = _h * stride_h - pad_h + _kh * dilation_h; - for (int _kw = 0; _kw < kernel_w; ++_kw) { - int32_t w_im = _w * stride_w - pad_w + _kw * dilation_w; - - // h_im and w_im are the actual height and width coordinates of the - // input tensor from where we need to copy 'channels' points. - const T* __restrict__ slice_im = - data_im + (h_im * width + w_im) * channels; - T* __restrict__ slice_col = data_col + i_col * channels_col + - (_kh * kernel_w + _kw) * channels; - // If the coordinates were within the input domain, we copy - // 'channels' contiguous values. Otherwise we will fill the output - // with 0's. - if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { - std::memcpy(slice_col, slice_im, channels * sizeof(T)); - } else { - std::fill_n(slice_col, channels, T(in_zero_point)); - } - } - } - } - } - } else { - // Iterate over the output domain - for (int _h = 0; _h < out_height; ++_h) { - for (int _w = 0; _w < out_width; ++_w) { - int32_t i_col = _h * out_width + _w; - - // Each point in the output domain is the result of applying a filter - // of size chanenls * kernel_h x kernel_w on the input - for (int _c = 0; _c < channels; ++_c) { - for (int _kh = 0; _kh < kernel_h; ++_kh) { - for (int _kw = 0; _kw < kernel_w; ++_kw) { - // c_col is the linearized access in the channels_col vector. - int32_t c_col = (_c * kernel_h + _kh) * kernel_w + _kw; - // h_im and w_im are the actual height and width coordinates of - // the input tensor that we need to copy to the output. - int32_t h_im = _h * stride_h - pad_h + _kh * dilation_h; - int32_t w_im = _w * stride_w - pad_w + _kw * dilation_w; - // If the current data access is within the input tensor, copy the - // value - data_col[i_col * channels_col + c_col] = - (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) - ? data_im[(_c * height + h_im) * width + w_im] - : static_cast(in_zero_point); - } - } - } - } - } - } -} - -void im2row_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - IntArrayRef kernel_size, - IntArrayRef dilation, - IntArrayRef padding, - IntArrayRef stride, - const Tensor& in_zero_point, - bool channel_last, - Tensor& out) { - // Compute the input tensor's dims - bool unit_height = input.dim() == 3; - const int32_t batch_size = input.size(0); - const int32_t in_c = - channel_last ? input.size(3 - unit_height) : input.size(1); - const int32_t in_h = - unit_height ? 1 : (channel_last ? input.size(1) : input.size(2)); - const int32_t in_w = - channel_last ? 
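The removed reference im2row_ above flattens, for every output position, the channels * kernel_h * kernel_w input values feeding that position into one row of a (out_height * out_width) x channels_col matrix. A minimal NCHW-only sketch of the same indexing with stride 1, no padding, and no dilation (sizes and names below are illustrative, not the kernel's):

// Tiny NCHW im2row: 2x3x3 input, 2x2 kernel -> 4 x 8 column matrix.
#include <cstdio>
#include <vector>

int main() {
  const int C = 2, H = 3, W = 3, kH = 2, kW = 2;   // no padding, stride 1
  const int outH = H - kH + 1, outW = W - kW + 1;  // 2 x 2 output positions
  const int channels_col = C * kH * kW;            // 8 values per position
  std::vector<float> im(C * H * W);
  for (int i = 0; i < C * H * W; ++i) im[i] = static_cast<float>(i);

  std::vector<float> col(outH * outW * channels_col, 0.0f);
  for (int oh = 0; oh < outH; ++oh)
    for (int ow = 0; ow < outW; ++ow)
      for (int c = 0; c < C; ++c)
        for (int kh = 0; kh < kH; ++kh)
          for (int kw = 0; kw < kW; ++kw) {
            const int row = oh * outW + ow;                  // output position
            const int ccol = (c * kH + kh) * kW + kw;        // flattened patch index
            col[row * channels_col + ccol] =
                im[(c * H + oh + kh) * W + ow + kw];         // h_im = oh+kh, w_im = ow+kw
          }
  printf("im2row output: %d x %d\n", outH * outW, channels_col);  // 4 x 8
  return 0;
}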
input.size(2 - unit_height) : input.size(3 - unit_height); - - // Get the kernel parameters - int32_t kernel_h = kernel_size[0]; - int32_t kernel_w = kernel_size[1]; - int32_t dilation_h = dilation[0]; - int32_t dilation_w = dilation[1]; - int32_t pad_h = padding[0]; - int32_t pad_w = padding[1]; - int32_t stride_h = stride[0]; - int32_t stride_w = stride[1]; - - // If we were to apply a convolution on the input tensor, compute the output - // height and width. - int32_t out_h = - (in_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) / stride_h + 1; - int32_t out_w = - (in_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) / stride_w + 1; - - ET_DCHECK_MSG( - (out_h * out_w) == out.size(1), "dimension mismatch for output"); - ET_DCHECK_MSG( - (kernel_h * kernel_w * in_c) == out.size(2), - "dimension mismatch for output"); - - // Check if the input is per-tensor quantized or per-channel quantized. The - // zero point for each batch could differ for per-channel quantized input. - bool per_tensor_quantized = in_zero_point.numel() == 1; - -#define typed_im2row(dtype, ctype) \ - case ScalarType::dtype: { \ - const ctype* __restrict__ in_data = input.const_data_ptr(); \ - ctype* __restrict__ out_data = out.mutable_data_ptr(); \ - const int32_t* __restrict__ zero_point = \ - in_zero_point.const_data_ptr(); \ - int32_t in_plane = in_c * in_h * in_w; \ - int32_t out_plane = kernel_h * kernel_w * in_c * out_h * out_w; \ - for (size_t n = 0; n < batch_size; ++n) { \ - im2row_( \ - &in_data[n * in_plane], \ - per_tensor_quantized ? zero_point[0] : zero_point[n], \ - in_c, \ - in_h, \ - in_w, \ - out_h, \ - out_w, \ - kernel_h, \ - kernel_w, \ - pad_h, \ - pad_w, \ - stride_h, \ - stride_w, \ - dilation_h, \ - dilation_w, \ - &out_data[n * out_plane], \ - channel_last); \ - } \ - break; \ - } - - ScalarType dtype = input.scalar_type(); - switch (dtype) { - typed_im2row(Float, float); - typed_im2row(Byte, uint8_t); - typed_im2row(Char, int8_t); - default: - ET_DCHECK_MSG( - false, - "im2row not implemented for dtype %s", - torch::executor::toString(dtype)); - } -#undef typed_im2row -} - -void im2row_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - IntArrayRef kernel_size, - IntArrayRef dilation, - IntArrayRef padding, - IntArrayRef stride, - int64_t in_zero_point, - bool channel_last, - Tensor& out) { - // Compute the input tensor's dims - bool unit_height = input.dim() == 3; - const int32_t batch_size = input.size(0); - const int32_t in_c = - channel_last ? input.size(3 - unit_height) : input.size(1); - const int32_t in_h = - unit_height ? 1 : (channel_last ? input.size(1) : input.size(2)); - const int32_t in_w = - channel_last ? input.size(2 - unit_height) : input.size(3 - unit_height); - - // Get the kernel parameters - int32_t kernel_h = kernel_size[0]; - int32_t kernel_w = kernel_size[1]; - int32_t dilation_h = dilation[0]; - int32_t dilation_w = dilation[1]; - int32_t pad_h = padding[0]; - int32_t pad_w = padding[1]; - int32_t stride_h = stride[0]; - int32_t stride_w = stride[1]; - - // If we were to apply a convolution on the input tensor, compute the output - // height and width. 
- int32_t out_h = - (in_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) / stride_h + 1; - int32_t out_w = - (in_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) / stride_w + 1; - - ET_DCHECK_MSG( - (out_h * out_w) == out.size(1), "dimension mismatch for output"); - ET_DCHECK_MSG( - (kernel_h * kernel_w * in_c) == out.size(2), - "dimension mismatch for output"); - -#define typed_im2row_per_tensor(dtype, ctype) \ - case ScalarType::dtype: { \ - const ctype* __restrict__ in_data = input.const_data_ptr(); \ - ctype* __restrict__ out_data = out.mutable_data_ptr(); \ - int32_t in_plane = in_c * in_h * in_w; \ - int32_t out_plane = kernel_h * kernel_w * in_c * out_h * out_w; \ - for (size_t n = 0; n < batch_size; ++n) { \ - im2row_( \ - &in_data[n * in_plane], \ - in_zero_point, \ - in_c, \ - in_h, \ - in_w, \ - out_h, \ - out_w, \ - kernel_h, \ - kernel_w, \ - pad_h, \ - pad_w, \ - stride_h, \ - stride_w, \ - dilation_h, \ - dilation_w, \ - &out_data[n * out_plane], \ - channel_last); \ - } \ - break; \ - } - - ScalarType dtype = input.scalar_type(); - switch (dtype) { - typed_im2row_per_tensor(Float, float); - typed_im2row_per_tensor(Byte, uint8_t); - typed_im2row_per_tensor(Char, int8_t); - default: - ET_DCHECK_MSG( - false, - "im2row.per_tensor not implemented for dtype %s", - torch::executor::toString(dtype)); - } -#undef typed_im2row_per_tensor -} - -} // namespace native -} // namespace reference -} // namespace impl diff --git a/backends/cadence/reference/operators/op_add.cpp b/backends/cadence/reference/operators/op_add.cpp deleted file mode 100644 index 89b67467605..00000000000 --- a/backends/cadence/reference/operators/op_add.cpp +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include -#include - -namespace torch { -namespace executor { -namespace native { - -Tensor& add_out( - KernelRuntimeContext& ctx, - const Tensor& a, - const Tensor& b, - const Scalar& alpha, - Tensor& out) { - (void)ctx; - - ScalarType a_type = a.scalar_type(); - ScalarType b_type = b.scalar_type(); - ScalarType common_type = promoteTypes(a_type, b_type); - ScalarType out_type = out.scalar_type(); - - ET_CHECK_MSG(a_type == ScalarType::Float, "Input tensor not a float.\n"); - ET_CHECK_MSG(b_type == ScalarType::Float, "Input tensor not a float.\n"); - ET_CHECK_MSG(out_type == ScalarType::Float, "Output tensor not a float.\n"); - - ET_CHECK(canCast(common_type, out_type)); - - using CTYPE_A = float; - using CTYPE_B = float; - using CTYPE_IN = float; - using CTYPE_OUT = float; - CTYPE_IN alpha_val; - ET_EXTRACT_SCALAR(alpha, alpha_val); - - apply_binary_elementwise_fn( - [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) { - CTYPE_IN a_casted = static_cast(val_a); - CTYPE_IN b_casted = static_cast(val_b); - CTYPE_IN value = a_casted + alpha_val * b_casted; - - return static_cast(value); - }, - a, - b, - out); - - return out; -} - -} // namespace native -} // namespace executor -} // namespace torch diff --git a/backends/cadence/reference/operators/op_full.cpp b/backends/cadence/reference/operators/op_full.cpp deleted file mode 100644 index 21d5fc56299..00000000000 --- a/backends/cadence/reference/operators/op_full.cpp +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. 
- * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include - -namespace torch { -namespace executor { -namespace native { - -using executorch::aten::ScalarType; -using executorch::aten::Tensor; - -Tensor& full_out( - KernelRuntimeContext& ctx, - const IntArrayRef sizes, - const Scalar& fill_value, - Tensor& out) { - (void)ctx; - - ScalarType val_type = utils::get_scalar_dtype(fill_value); - ScalarType out_type = out.scalar_type(); - - Error err = resize_tensor(out, sizes); - ET_CHECK_MSG(err == Error::Ok, "Could not resize out"); - - ET_SWITCH_REAL_TYPES_AND(Bool, val_type, ctx, "full", CTYPE_VAL, [&] { - CTYPE_VAL val; - ET_EXTRACT_SCALAR(fill_value, val); - - ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "full", CTYPE_OUT, [&] { - CTYPE_OUT val_casted = static_cast(val); - auto data_out = out.mutable_data_ptr(); - for (size_t i = 0; i < out.numel(); ++i) { - data_out[i] = val_casted; - } - }); - }); - - return out; -} - -} // namespace native -} // namespace executor -} // namespace torch diff --git a/backends/cadence/reference/operators/operators.h b/backends/cadence/reference/operators/operators.h deleted file mode 100644 index 637f38f8fec..00000000000 --- a/backends/cadence/reference/operators/operators.h +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#include -#include -#include -#include - -namespace cadence { -namespace impl { -namespace cpu { -namespace native { -namespace { -using ::executorch::runtime::getLeadingDims; - -#define ET_FORALL_CADENCE_QUANTIZED_TYPES(_) \ - _(uint8_t, Byte) \ - _(int8_t, Char) - -inline __attribute__((always_inline)) void linear_( - const ::executorch::aten::Tensor& input, - const ::executorch::aten::Tensor& weight, - const std::optional<::executorch::aten::Tensor>& bias, - ::executorch::aten::Tensor& output) { - const float* __restrict__ input_data = input.const_data_ptr(); - const float* __restrict__ weight_data = weight.const_data_ptr(); - const float* __restrict__ bias_data = bias.value().const_data_ptr(); - float* __restrict__ output_data = output.mutable_data_ptr(); - - // input comes in shape [batch_size, in_dim] - // weight comes in shape [out_dim, in_dim] - // output comes in empty with shape [batch_size, out_dim] - // Perform matrix multiply (M x N) x (N x P) => M x P - int64_t M = weight.size(0); // = out_dim - int64_t N = weight.size(1); // = in_dim - - // Given an N-dimensional input [d0, d1, d2, ..., d_{N-2}, d_{N-1}], the - // leading dimensions is d0 * d1 * ... * d_{N-2} - int64_t leading_dims = getLeadingDims(input, input.dim() - 1); - - for (int i = 0; i < leading_dims; ++i) { - for (int j = 0; j < M; ++j) { - float sum = bias_data[j]; - for (int k = 0; k < N; ++k) { - sum += input_data[i * N + k] * weight_data[j * N + k]; - } - output_data[i * M + j] = sum; - } - } -} - -} // namespace -} // namespace native -} // namespace cpu -} // namespace impl -} // namespace cadence diff --git a/backends/cadence/reference/operators/quantize_per_tensor.cpp b/backends/cadence/reference/operators/quantize_per_tensor.cpp deleted file mode 100644 index 8f8cc961dd8..00000000000 --- a/backends/cadence/reference/operators/quantize_per_tensor.cpp +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. 
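The removed operators.h linear_ helper treats an N-dimensional input [d0, ..., d_{N-2}, in_dim] as a (d0 * ... * d_{N-2}) x in_dim matrix via getLeadingDims before the matrix multiply. A small sketch of that convention (the helper below is illustrative, not the ExecuTorch API):

// Leading-dims convention used by the removed linear_ helper.
#include <cstdint>
#include <cstdio>
#include <vector>

int64_t leading_dims(const std::vector<int64_t>& sizes) {
  int64_t prod = 1;
  for (size_t i = 0; i + 1 < sizes.size(); ++i) prod *= sizes[i];  // all but last dim
  return prod;
}

int main() {
  // A [2, 5, 8] input with in_dim = 8 behaves like a 10 x 8 matrix, so a
  // [4, 8] weight produces a [2, 5, 4] (i.e. 10 x 4) output.
  printf("%lld\n", static_cast<long long>(leading_dims({2, 5, 8})));  // 10
  return 0;
}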
and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include - -namespace impl { -namespace reference { -namespace native { - -using executorch::aten::ScalarType; -using executorch::aten::Tensor; -using executorch::runtime::KernelRuntimeContext; - -// Quantize the input tensor (PT2 version). Note that quant_ are not -// used in any computation. -void quantize_per_tensor_out( - KernelRuntimeContext& context, - const Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out) { - const float* input_data = input.const_data_ptr(); - size_t numel = out.numel(); - - if (out.scalar_type() == ScalarType::Byte) { - uint8_t* out_data = out.mutable_data_ptr(); - impl::reference::kernels::quantize( - out_data, input_data, 1. / scale, zero_point, numel); - } else if (out.scalar_type() == ScalarType::Char) { - int8_t* out_data = out.mutable_data_ptr(); - impl::reference::kernels::quantize( - out_data, input_data, 1. / scale, zero_point, numel); - } else if ( - out.scalar_type() == ScalarType::Bits16 || - out.scalar_type() == ScalarType::UInt16) { - uint16_t* out_data = out.mutable_data_ptr(); - impl::reference::kernels::quantize( - out_data, input_data, 1. / scale, zero_point, numel); - } else if (out.scalar_type() == ScalarType::Short) { - int16_t* out_data = out.mutable_data_ptr(); - impl::reference::kernels::quantize( - out_data, input_data, 1. / scale, zero_point, numel); - } else { - ET_CHECK_MSG( - false, - "Unhandled input dtype %hhd", - static_cast(out.scalar_type())); - } -} - -}; // namespace native -}; // namespace reference -}; // namespace impl diff --git a/backends/cadence/reference/operators/quantized_add_out.cpp b/backends/cadence/reference/operators/quantized_add_out.cpp deleted file mode 100644 index 7e5834de7bf..00000000000 --- a/backends/cadence/reference/operators/quantized_add_out.cpp +++ /dev/null @@ -1,192 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include -#include -#include - -namespace impl { -namespace reference { -namespace native { - -using ::executorch::aten::Tensor; -using ::executorch::runtime::KernelRuntimeContext; -using ::impl::reference::kernels::dequantize; -using ::impl::reference::kernels::quantize; - -template -void quantized_add_per_tensor_impl( - const Tensor& X, - double X_scale, - int64_t X_zero_point, - const Tensor& Y, - double Y_scale, - int64_t Y_zero_point, - double out_scale, - int64_t out_zero_point, - Tensor& out) { - const T* __restrict__ X_data = X.const_data_ptr(); - const T* __restrict__ Y_data = Y.const_data_ptr(); - T* __restrict__ out_data = out.mutable_data_ptr(); - - ssize_t Y_numel = Y.numel(); - ssize_t X_numel = X.numel(); - ssize_t out_numel = out.numel(); - - float X_scale_f = static_cast(X_scale); - float Y_scale_f = static_cast(Y_scale); - float out_scale_f = static_cast(out_scale); - int32_t X_zero_point_i32 = static_cast(X_zero_point); - int32_t Y_zero_point_i32 = static_cast(Y_zero_point); - int32_t out_zero_point_i32 = static_cast(out_zero_point); - - float inv_out_scale = 1.0f / out_scale_f; - - // Simple case: tensors have the same shape, no broadcasting - if (X_numel == Y_numel && Y_numel == out_numel) { - for (size_t i = 0; i < X_numel; ++i) { - float x = dequantize(X_data[i], X_scale_f, X_zero_point_i32); - float y = dequantize(Y_data[i], Y_scale_f, Y_zero_point_i32); - float z = x + y; - out_data[i] = quantize(z, inv_out_scale, out_zero_point_i32); - } - } - // Y is a scalar tensor - else if (Y_numel == 1) { - float y = dequantize(Y_data[0], Y_scale_f, Y_zero_point_i32); - for (size_t i = 0; i < X_numel; ++i) { - float x = dequantize(X_data[i], X_scale_f, X_zero_point_i32); - float z = x + y; - out_data[i] = quantize(z, inv_out_scale, out_zero_point_i32); - } - } - // X is a scalar tensor - else if (X_numel == 1) { - float x = dequantize(X_data[0], X_scale_f, X_zero_point_i32); - for (size_t i = 0; i < Y_numel; ++i) { - float y = dequantize(Y_data[i], Y_scale_f, Y_zero_point_i32); - float z = x + y; - out_data[i] = quantize(z, inv_out_scale, out_zero_point_i32); - } - } - // General broadcasting case - simplified implementation - else { - for (ssize_t i = 0; i < out_numel; ++i) { - // Simple broadcasting: repeat elements as needed - size_t x_idx = (X_numel == 1) ? 0 : i % X_numel; - size_t y_idx = (Y_numel == 1) ? 
0 : i % Y_numel; - - float x = dequantize(X_data[x_idx], X_scale_f, X_zero_point_i32); - float y = dequantize(Y_data[y_idx], Y_scale_f, Y_zero_point_i32); - float z = x + y; - out_data[i] = quantize(z, inv_out_scale, out_zero_point_i32); - } - } -} - -// Generic quantized add with type dispatch -void quantized_add_per_tensor_out( - KernelRuntimeContext& ctx, - const Tensor& X, - double X_scale, - int64_t X_zero_point, - const Tensor& Y, - double Y_scale, - int64_t Y_zero_point, - double out_scale, - int64_t out_zero_point, - Tensor& out) { - (void)ctx; - - executorch::aten::ScalarType dtype = X.scalar_type(); - switch (dtype) { - case executorch::aten::ScalarType::Byte: - quantized_add_per_tensor_impl( - X, - X_scale, - X_zero_point, - Y, - Y_scale, - Y_zero_point, - out_scale, - out_zero_point, - out); - break; - case executorch::aten::ScalarType::Char: - quantized_add_per_tensor_impl( - X, - X_scale, - X_zero_point, - Y, - Y_scale, - Y_zero_point, - out_scale, - out_zero_point, - out); - break; - default: - ET_CHECK_MSG( - false, "Unhandled input dtype %hhd", static_cast(dtype)); - } -} - -// int8-specific quantized add -void quantized_add_asym8sxasym8s_asym8s_per_tensor_out( - KernelRuntimeContext& ctx, - const Tensor& X, - double X_scale, - int64_t X_zero_point, - const Tensor& Y, - double Y_scale, - int64_t Y_zero_point, - double out_scale, - int64_t out_zero_point, - Tensor& out) { - (void)ctx; - - quantized_add_per_tensor_impl( - X, - X_scale, - X_zero_point, - Y, - Y_scale, - Y_zero_point, - out_scale, - out_zero_point, - out); -} - -// uint8-specific quantized add -void quantized_add_asym8uxasym8u_asym8u_per_tensor_out( - KernelRuntimeContext& ctx, - const Tensor& X, - double X_scale, - int64_t X_zero_point, - const Tensor& Y, - double Y_scale, - int64_t Y_zero_point, - double out_scale, - int64_t out_zero_point, - Tensor& out) { - (void)ctx; - - quantized_add_per_tensor_impl( - X, - X_scale, - X_zero_point, - Y, - Y_scale, - Y_zero_point, - out_scale, - out_zero_point, - out); -} - -} // namespace native -} // namespace reference -} // namespace impl diff --git a/backends/cadence/reference/operators/quantized_conv_nchw_out.cpp b/backends/cadence/reference/operators/quantized_conv_nchw_out.cpp deleted file mode 100644 index aefa75d7047..00000000000 --- a/backends/cadence/reference/operators/quantized_conv_nchw_out.cpp +++ /dev/null @@ -1,501 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include - -namespace impl { -namespace reference { -namespace native { - -using ::executorch::aten::IntArrayRef; -using ::executorch::aten::ScalarType; -using ::executorch::aten::Tensor; -using ::executorch::runtime::KernelRuntimeContext; - -// This implements a generic 2d conv kernel that operates on raw pointers. -// The version handles both quantized and fp32 convolutions. 
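For reference, the quantized path of this kernel reduces to a simple per-output-element recipe: start the accumulator at the bias, add (x - in_zero_point) * (w - weight_zero_point) over the receptive field, scale the sum by bias_scale (which folds in the input and weight scales), and requantize with 1/out_scale and out_zero_point. A minimal, self-contained sketch of that arithmetic follows; the requantize helper and the concrete numbers are illustrative only, not the Cadence kernels' implementation.

// Sketch of the per-output-element quantized-conv arithmetic (illustrative,
// not the reference kernel): accumulate zero-point-corrected products in
// float, scale by bias_scale, then requantize to int8.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative requantization helper; the reference kernels call
// impl::reference::kernels::quantize instead.
static int8_t requantize(float v, float inv_out_scale, int32_t out_zp) {
  float q = std::nearbyint(v * inv_out_scale) + static_cast<float>(out_zp);
  return static_cast<int8_t>(std::clamp(q, -128.0f, 127.0f));
}

int main() {
  // One output element: a 3x3 patch of a single-channel input against a
  // 3x3 weight, stride 1, no padding, no dilation.
  std::vector<int8_t> x = {1, 2, 3, 4, 5, 6, 7, 8, 9};    // quantized input patch
  std::vector<int8_t> w = {1, 0, -1, 1, 0, -1, 1, 0, -1}; // quantized weights
  const float bias = 10.0f;       // bias, already in accumulator units
  const int32_t in_zp = 2, weight_zp = 0, out_zp = -5;
  const float bias_scale = 0.05f; // = in_scale * weight_scale (example values)
  const float out_scale = 0.1f;

  float acc = bias;
  for (size_t i = 0; i < x.size(); ++i) {
    acc += static_cast<float>(x[i] - in_zp) * static_cast<float>(w[i] - weight_zp);
  }
  const int8_t y = requantize(bias_scale * acc, 1.0f / out_scale, out_zp);
  std::printf("acc=%.1f -> quantized output=%d\n", acc, y);
  return 0;
}

The same recipe underlies both the NCHW and NHWC variants below; they differ only in how the input and weight pointers are strided.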
-// The input is of shape [n x c x h x w] -// The weight is of shape [oc x wc x wh x ww], where wc == c -// The output is of shape [n x oc x oh x ow] -// The bias is of shape [oc] -template < - typename IT = float, - typename WT = IT, - typename BT = IT, - typename OT = IT, - bool quantized = false> -__attribute__((noinline)) void conv2d_nchw_core_generic( - // All the arrays - const IT* __restrict__ p_in, - const WT* __restrict__ p_weight, - const BT* __restrict__ p_bias, - OT* __restrict__ p_out, - // The array sizes - int32_t n, - int32_t c, - int32_t h, - int32_t w, - int32_t oc, - int32_t wc, - int32_t wh, - int32_t ww, - int32_t oh, - int32_t ow, - // Stride - int16_t s0, - int16_t s1, - // Padding - int16_t p0, - int16_t p1, - // Dilation - int16_t d0, - int16_t d1, - // Group for depthwise conv - int16_t groups, - // Optional args that are only relevant for quantized convolution - // input zero point - IT in_zero_point = 0, - // weight zero point - int32_t weight_zero_point = 0, - float bias_scale = 1, - float out_scale = 1, - OT out_zero_point = 0) { - float inv_out_scale = 1. / out_scale; - bool zero_pad_unit_dilation = d0 == 1 && d1 == 1 && p0 == 0 && p1 == 0; - - // Compute the number of in and out channels per group - const int ocpg = oc / groups; - const int icpg = c / groups; - - // Iterate over all the output batches (i.e., n) - for (int _n = 0; _n < n; ++_n) { - const IT* in_batch = p_in + _n * c * h * w; - OT* out_batch = p_out + _n * oc * oh * ow; - // Compute separable convolution for each group - for (int _g = 0; _g < groups; ++_g) { - // Identify the input and output channels involved in the computation - // of this group - int sic = _g * icpg; - int soc = _g * ocpg; - // Populate all the output channels in the group - for (int _oc = soc; _oc < soc + ocpg; ++_oc) { - OT* out_plane = out_batch + _oc * oh * ow; - const WT* weight_batch = p_weight + _oc * wc * wh * ww; - // We compute one output channel at a time. The computation can be - // thought of as a stencil computation: we iterate over an input of size - // icpg x h x w, with a stencil of size icpg x wh x ww, to compute an - // output channel of size 1 x oh x ow. - for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { - for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { - float acc = p_bias[_oc]; - // Below is the stencil computation that performs the hadamard - // product+accumulation of each input channel (contributing to the - // output channel being computed) with the corresponding weight - // channel. - // If the padding is 0, and dilation is 1, then we can remove the - // unnecessary checks, and simplify the code so that it can be - // vectorized by Tensilica compiler. - if (zero_pad_unit_dilation) { - for (int _ic = sic; _ic < sic + icpg; ++_ic) { - const IT* in_plane = in_batch + _ic * h * w; - const WT* weight_plane = weight_batch + (_ic - sic) * wh * ww; - for (int _wh = 0; _wh < wh; ++_wh) { - for (int _ww = 0; _ww < ww; ++_ww) { - int ioff = (_h + _wh) * w + (_w + _ww); - int woff = _wh * ww + _ww; - float lhs = in_plane[ioff] - in_zero_point; - float rhs = weight_plane[woff] - - (quantized ? 
weight_zero_point : 0); - acc += lhs * rhs; - } - } - } - } else { - for (int _ic = sic; _ic < sic + icpg; ++_ic) { - const IT* in_plane = in_batch + _ic * h * w; - const WT* weight_plane = weight_batch + (_ic - sic) * wh * ww; - for (int _wh = 0; _wh < wh; ++_wh) { - for (int _ww = 0; _ww < ww; ++_ww) { - if (((_h + d0 * _wh - p0) >= 0) && - ((_h + d0 * _wh - p0) < h) && - ((_w + d1 * _ww - p1) >= 0) && - ((_w + d1 * _ww - p1) < w)) { - int ioff = - (_h + d0 * _wh - p0) * w + (_w + d1 * _ww - p1); - int woff = _wh * ww + _ww; - float lhs = in_plane[ioff] - in_zero_point; - float rhs = weight_plane[woff] - - (quantized ? weight_zero_point : 0); - acc += lhs * rhs; - } - } - } - } - } - if (quantized) { - float val = bias_scale * acc; - out_plane[_oh * ow + _ow] = - ::impl::reference::kernels::quantize( - val, inv_out_scale, out_zero_point); - } else { - out_plane[_oh * ow + _ow] = acc; - } - } - } - } - } - } -} - -// The quantized convolution kernel. in_scale and weight_scale are implicit in -// bias_scale, since it is a product of the two. The kernel will branch to -// quantized::conv1d or quantized::conv2d based on the dimensionality of -// activation tensor. -void quantized_conv_nchw( - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int16_t groups, - int32_t in_zero_point, - int32_t weight_zero_point, - float bias_scale, - float output_scale, - int32_t output_zero_point, - Tensor& out) { - bool conv1d = input.dim() == 3; - // input = [n, c, h, w] - const int n = input.size(0); - const int c = input.size(1); - const int h = conv1d ? 1 : input.size(2); - const int w = conv1d ? input.size(2) : input.size(3); - // weight = [oc, wc, wh, ww] - const int oc = weight.size(0); - const int wc = weight.size(1); - const int wh = conv1d ? 1 : weight.size(2); - const int ww = conv1d ? weight.size(2) : weight.size(3); - // output = [n, oc, oh, ow] - const int oh = conv1d ? 1 : out.size(2); - const int ow = conv1d ? 
out.size(2) : out.size(3); - -#define typed_quantized_conv2d_nchw(ctype, dtype) \ - case ScalarType::dtype: { \ - conv2d_nchw_core_generic( \ - input.const_data_ptr(), \ - weight.const_data_ptr(), \ - bias.const_data_ptr(), \ - out.mutable_data_ptr(), \ - n, \ - c, \ - h, \ - w, \ - oc, \ - wc, \ - wh, \ - ww, \ - oh, \ - ow, \ - stride[0], \ - stride[1], \ - padding[0], \ - padding[1], \ - dilation[0], \ - dilation[1], \ - groups, \ - in_zero_point, \ - weight_zero_point, \ - bias_scale, \ - output_scale, \ - (ctype)output_zero_point); \ - break; \ - } - ScalarType dtype = out.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_conv2d_nchw); - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", torch::executor::toString(dtype)); - } - -#undef typed_quantized_conv2d_nchw -} - -void quantized_conv_nchw_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - const Tensor& weight_zero_point, - const Tensor& bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED const Tensor& out_multiplier, - __ET_UNUSED const Tensor& out_shift, - Tensor& out) { - const float bias_scale_float = bias_scale.const_data_ptr()[0]; - const int32_t weight_zero_point_int = - weight_zero_point.const_data_ptr()[0]; - quantized_conv_nchw( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point_int, - bias_scale_float, - output_scale, - output_zero_point, - out); -} - -void quantized_conv_nchw_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - bool channel_last, - Tensor& out) { - quantized_conv_nchw( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv_nchw( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv_nchw( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - 
in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv_nchw( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv_nchw( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv_nchw( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv_nchw( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -} // namespace native -} // namespace reference -} // namespace impl diff --git a/backends/cadence/reference/operators/quantized_conv_nhwc_out.cpp b/backends/cadence/reference/operators/quantized_conv_nhwc_out.cpp deleted file mode 100644 index 26fbc86d5b0..00000000000 --- a/backends/cadence/reference/operators/quantized_conv_nhwc_out.cpp +++ /dev/null @@ -1,488 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include -#include - -namespace impl { -namespace reference { -namespace native { - -using ::executorch::aten::IntArrayRef; -using ::executorch::aten::ScalarType; -using ::executorch::aten::Tensor; -using ::executorch::runtime::KernelRuntimeContext; - -template < - typename IT = float, - typename WT = IT, - typename BT = IT, - typename OT = IT, - bool quantized = false> -__attribute__((noinline)) void conv2d_nhwc_core_generic( - // All the arrays - const IT* __restrict__ p_in, - const WT* __restrict__ p_weight, - const BT* __restrict__ p_bias, - OT* __restrict__ p_out, - // The array sizes - int32_t n, - int32_t h, - int32_t w, - int32_t c, - int32_t oc, - int32_t wh, - int32_t ww, - int32_t wc, - int32_t oh, - int32_t ow, - // Stride - int16_t s0, - int16_t s1, - // Padding - int16_t p0, - int16_t p1, - // Dilation - int16_t d0, - int16_t d1, - // Group for depthwise conv - int16_t groups, - // Optional args that are only relevant for quantized convolution - // input zero point - IT in_zero_point = 0, - // weight zero point - int32_t weight_zero_point = 0, - float bias_scale = 1, - float out_scale = 1, - OT out_zero_point = 0) { - float inv_out_scale = 1. / out_scale; - bool zero_pad_unit_dilation = d0 == 1 && d1 == 1 && p0 == 0 && p1 == 0; - - // Compute the number of in and out channels per group - const int ocpg = oc / groups; - const int icpg = c / groups; - - // Iterate over all the output batches (i.e., n) - for (int _n = 0; _n < n; ++_n) { - const IT* in_batch = p_in + _n * h * w * c; - OT* out_batch = p_out + _n * oh * ow * oc; - for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { - for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { - OT* out_line = out_batch + (_oh * ow + _ow) * oc; - // Compute separable convolution for each group - for (int _g = 0; _g < groups; ++_g) { - // Identify the input and output channels involved in the computation - // of this group - int sic = _g * icpg; - int soc = _g * ocpg; - // Populate all the output channels in the group - for (int _oc = soc; _oc < soc + ocpg; ++_oc) { - const WT* weight_batch = p_weight + _oc * wh * ww * wc; - // We compute one output channel at a time. The computation can be - // thought of as a stencil computation: we iterate over an input of - // size h x w x icpg, with a stencil of size wh x ww x icpg, to - // compute an output channel of size oh x ow x 1. - float acc = p_bias[_oc]; - // Below is the stencil computation that performs the hadamard - // product+accumulation of each input channel (contributing to - // the output channel being computed) with the corresponding - // weight channel. If the padding is 0, and dilation is 1, then - // we can remove the unnecessary checks, and simplify the code - // so that it can be vectorized by Tensilica compiler.x`` - if (zero_pad_unit_dilation) { - for (int _wh = 0; _wh < wh; ++_wh) { - for (int _ww = 0; _ww < ww; ++_ww) { - const IT* in_line = - in_batch + (_h + _wh) * w * c + (_w + _ww) * c; - const WT* weight_line = - weight_batch + _wh * ww * wc + _ww * wc; - for (int _ic = sic; _ic < sic + icpg; ++_ic) { - float lhs = in_line[_ic] - in_zero_point; - float rhs = weight_line[_ic - sic] - - (quantized ? 
weight_zero_point : 0); - acc += lhs * rhs; - } - } - } - } else { - for (int _wh = 0; _wh < wh; ++_wh) { - for (int _ww = 0; _ww < ww; ++_ww) { - if (((_h + d0 * _wh - p0) >= 0) && - ((_h + d0 * _wh - p0) < h) && - ((_w + d1 * _ww - p1) >= 0) && - ((_w + d1 * _ww - p1 < w))) { - const IT* in_line = in_batch + - (_h + d0 * _wh - p0) * w * c + (_w + d1 * _ww - p1) * c; - const WT* weight_line = - weight_batch + _wh * ww * wc + _ww * wc; - for (int _ic = sic; _ic < sic + icpg; ++_ic) { - float lhs = in_line[_ic] - in_zero_point; - float rhs = weight_line[_ic - sic] - - (quantized ? weight_zero_point : 0); - acc += lhs * rhs; - } - } - } - } - } - if (quantized) { - float val = bias_scale * acc; - out_line[_oc] = ::impl::reference::kernels::quantize( - val, inv_out_scale, out_zero_point); - } else { - out_line[_oc] = acc; - } - } - } - } - } - } -} - -void quantized_conv_nhwc( - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int16_t groups, - int32_t in_zero_point, - int32_t weight_zero_point, - float bias_scale, - float output_scale, - int32_t output_zero_point, - Tensor& out) { - bool conv1d = input.dim() == 3; - // input = [n, h, w, c] - const int n = input.size(0); - const int h = conv1d ? 1 : input.size(1); - const int w = conv1d ? input.size(1) : input.size(2); - const int c = conv1d ? input.size(2) : input.size(3); - // weight = [oc, wh, ww, wc] - const int oc = weight.size(0); - const int wh = conv1d ? 1 : weight.size(1); - const int ww = conv1d ? weight.size(1) : weight.size(2); - const int wc = conv1d ? weight.size(2) : weight.size(3); - // output = [n, oh, ow, oc] - const int oh = conv1d ? 1 : out.size(1); - const int ow = conv1d ? out.size(1) : out.size(2); - -#define typed_quantized_conv2d_nhwc(ctype, dtype) \ - case ScalarType::dtype: { \ - conv2d_nhwc_core_generic( \ - input.const_data_ptr(), \ - weight.const_data_ptr(), \ - bias.const_data_ptr(), \ - out.mutable_data_ptr(), \ - n, \ - h, \ - w, \ - c, \ - oc, \ - wh, \ - ww, \ - wc, \ - oh, \ - ow, \ - stride[0], \ - stride[1], \ - padding[0], \ - padding[1], \ - dilation[0], \ - dilation[1], \ - groups, \ - in_zero_point, \ - weight_zero_point, \ - bias_scale, \ - output_scale, \ - (ctype)output_zero_point); \ - break; \ - } - ScalarType dtype = out.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_conv2d_nhwc); - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", torch::executor::toString(dtype)); - } - -#undef typed_quantized_conv2d_nhwc -} - -void quantized_conv_nhwc_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - const Tensor& weight_zero_point, - const Tensor& bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED const Tensor& out_multiplier, - __ET_UNUSED const Tensor& out_shift, - Tensor& out) { - const float bias_scale_float = bias_scale.const_data_ptr()[0]; - const int32_t weight_zero_point_int = - weight_zero_point.const_data_ptr()[0]; - quantized_conv_nhwc( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point_int, - bias_scale_float, - output_scale, - output_zero_point, - out); -} - -void quantized_conv_nhwc_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& 
bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - bool channel_last, - Tensor& out) { - quantized_conv_nhwc( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv_nhwc( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv_nhwc( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv_nhwc( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv_nhwc( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double 
bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv_nhwc( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv_nhwc( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -} // namespace native -} // namespace reference -} // namespace impl diff --git a/backends/cadence/reference/operators/quantized_fully_connected_out.cpp b/backends/cadence/reference/operators/quantized_fully_connected_out.cpp deleted file mode 100644 index 136055de70a..00000000000 --- a/backends/cadence/reference/operators/quantized_fully_connected_out.cpp +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ -#include -#include -#include - -namespace impl { -namespace reference { -namespace native { - -using ::executorch::aten::ScalarType; -using ::executorch::aten::Tensor; -using ::executorch::runtime::KernelRuntimeContext; -using std::optional; - -void quantized_fully_connected_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& in, - const Tensor& weight, - const Tensor& bias, - int64_t in_zero_point, - const Tensor& weight_zero_point_t, - const Tensor& out_multiplier, - const Tensor& out_shift, - int64_t out_zero_point, - __ET_UNUSED const optional& offset, - Tensor& out) { -#define typed_quantized_linear(ctype, dtype) \ - case ScalarType::dtype: { \ - quantized_linear_( \ - in, \ - weight, \ - bias, \ - in_zero_point, \ - weight_zero_point_t, \ - out_multiplier, \ - out_shift, \ - out_zero_point, \ - out); \ - break; \ - } - - ScalarType dtype = out.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear); - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", torch::executor::toString(dtype)); - } -#undef typed_quantized_linear -} - -void quantized_fully_connected_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& in, - const Tensor& weight, - const Tensor& bias, - int64_t in_zero_point, - int64_t weight_zero_point, - int64_t out_multiplier, - int64_t out_shift, - int64_t out_zero_point, - __ET_UNUSED const optional& offset, - Tensor& out) { -#define typed_quantized_linear(ctype, dtype) \ - case ScalarType::dtype: { \ - quantized_linear_per_tensor_( \ - in, \ - weight, \ - bias, \ - in_zero_point, \ - weight_zero_point, \ - out_multiplier, \ - out_shift, \ - out_zero_point, \ - out); \ - break; \ - } - - ScalarType dtype = out.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear); - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", 
torch::executor::toString(dtype)); - } -#undef typed_quantized_linear -} - -void quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& in, - const Tensor& weight, - const Tensor& bias, - int64_t in_zero_point, - int64_t weight_zero_point, - int64_t out_multiplier, - int64_t out_shift, - int64_t out_zero_point, - __ET_UNUSED const optional& offset, - Tensor& out) { -#define typed_quantized_linear(ctype, dtype) \ - case ScalarType::dtype: { \ - quantized_linear_per_tensor_( \ - in, \ - weight, \ - bias, \ - in_zero_point, \ - weight_zero_point, \ - out_multiplier, \ - out_shift, \ - out_zero_point, \ - out); \ - break; \ - } - - ScalarType dtype = out.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear); - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", torch::executor::toString(dtype)); - } -#undef typed_quantized_linear -} - -void quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& in, - const Tensor& weight, - const Tensor& bias, - int64_t in_zero_point, - int64_t weight_zero_point, - int64_t out_multiplier, - int64_t out_shift, - int64_t out_zero_point, - __ET_UNUSED const optional& offset, - Tensor& out) { -#define typed_quantized_linear(ctype, dtype) \ - case ScalarType::dtype: { \ - quantized_linear_per_tensor_( \ - in, \ - weight, \ - bias, \ - in_zero_point, \ - weight_zero_point, \ - out_multiplier, \ - out_shift, \ - out_zero_point, \ - out); \ - break; \ - } - - ScalarType dtype = out.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear); - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", torch::executor::toString(dtype)); - } -#undef typed_quantized_linear -} - -}; // namespace native -}; // namespace reference -}; // namespace impl diff --git a/backends/cadence/reference/operators/quantized_layer_norm.cpp b/backends/cadence/reference/operators/quantized_layer_norm.cpp deleted file mode 100644 index 64dcb000cc1..00000000000 --- a/backends/cadence/reference/operators/quantized_layer_norm.cpp +++ /dev/null @@ -1,203 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include - -#include - -using ::executorch::aten::IntArrayRef; -using ::executorch::aten::ScalarType; -using ::executorch::aten::Tensor; -using ::executorch::runtime::getLeadingDims; -using ::executorch::runtime::KernelRuntimeContext; -using ::impl::reference::kernels::dequantize; -using ::impl::reference::kernels::quantize; - -namespace impl { -namespace reference { -namespace native { - -// Compute quantized layer_norm. The current implementation assumes that the -// input is per-tensor quantized. 
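The zero-point-corrected sums in this kernel follow from expanding the dequantized statistics: with per-tensor scale s and zero point z, the mean is s * (sum(x) - N*z) / N and the variance is s^2 * (sum(x^2) - 2*z*sum(x) + N*z^2) / N - mean^2, so only two integer reductions over each row are needed before the per-element normalize-and-requantize pass. A small self-contained check of that identity (illustrative values, not the Cadence code):

// Verifies that the zero-point-corrected integer sums give the same mean and
// variance as dequantizing every element first.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int8_t> x = {12, -3, 7, 0, 25, -14, 9, 4};
  const float s = 0.05f;   // per-tensor input scale (example)
  const int32_t z = 2;     // per-tensor input zero point (example)
  const int64_t N = static_cast<int64_t>(x.size());

  // Integer-sum path, as in the kernel: two reductions over the row.
  int64_t sum = 0, sq_sum = 0;
  for (int8_t v : x) {
    sum += v;
    sq_sum += static_cast<int64_t>(v) * v;
  }
  const float mean_q = s * (sum - N * z) / N;
  const float var_q =
      s * s * (sq_sum - 2 * z * sum + N * z * z) / N - mean_q * mean_q;

  // Dequantize-first reference.
  float mean_f = 0.f;
  for (int8_t v : x) mean_f += s * (v - z);
  mean_f /= N;
  float var_f = 0.f;
  for (int8_t v : x) {
    const float d = s * (v - z) - mean_f;
    var_f += d * d;
  }
  var_f /= N;

  std::printf("mean: %f vs %f, var: %f vs %f\n", mean_q, mean_f, var_q, var_f);
  return 0;
}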
-template -void quantized_layer_norm_per_tensor_( - const Tensor& input, - double input_scale, - int64_t input_zero_point, - const Tensor& weight, - const Tensor& bias, - double eps, - double output_scale, - int64_t output_zero_point, - Tensor& out) { - // Get the raw pointers to input, output, weight, and bias - const T* __restrict__ in_data = input.const_data_ptr(); - T* __restrict__ out_data = out.mutable_data_ptr(); - const float* __restrict__ weight_data = weight.const_data_ptr(); - const float* __restrict__ bias_data = bias.const_data_ptr(); - - float output_inv_scale = 1.0f / output_scale; - - size_t last_dim = input.size(input.dim() - 1); - size_t leading_dims = getLeadingDims(input, input.dim() - 1); - - // Visualize the input tensor as a set of 1d vectors, and compute the - // layer_norm for each vector. - for (size_t i = 0; i < leading_dims; ++i) { - const T* x = in_data + i * last_dim; - T* y = out_data + i * last_dim; - - // compute sum and squared sum. The fp32 sum can be approximated as: - // (X_1 - in_zero_point) * in_scale + (X_2 - in_zero_point) * in_scale + ... - // (X_N - in_zero_point) * in_scale. - int32_t sum = 0; - int32_t sq_sum = last_dim * input_zero_point * input_zero_point; - for (size_t j = 0; j < last_dim; ++j) { - int32_t val = x[j]; - sum += val; - sq_sum += val * val; - } - sq_sum -= (2 * sum * input_zero_point); - sum -= (last_dim * input_zero_point); - - float mean = (input_scale * sum) / last_dim; - float variance = - (sq_sum * input_scale * input_scale) / last_dim - mean * mean; - float inv_std = 1.0f / std::sqrt(variance + eps); - - // y = (x - mean) / std * kGamma + kBeta - for (int j = 0; j < last_dim; ++j) { - // y[j] = (x[j] - mean) / std * kGamma + kBeta; - // Since X is quantized, we dequantize it, compute fp32 result, and - // quantize the result to an int8/uint8 value. - float val = dequantize(x[j], input_scale, input_zero_point); - - val = (val - mean) * inv_std * weight_data[j] + bias_data[j]; - y[j] = quantize(val, output_inv_scale, output_zero_point); - } - } -} - -// Compute quantized layer_norm. The current implementation assumes that the -// input is per-tensor quantized. -template -void quantized_layer_norm_( - const Tensor& input, - const Tensor& in_scale, - const Tensor& in_zero_point, - const Tensor& weight, - const Tensor& bias, - double eps, - double output_scale, - int64_t output_zero_point, - Tensor& out) { - // Extract the zero point and scale for input tensor. 
- float input_scale = in_scale.const_data_ptr()[0]; - int64_t input_zero_point = in_zero_point.const_data_ptr()[0]; - - // Call other overload - quantized_layer_norm_per_tensor_( - input, - input_scale, - input_zero_point, - weight, - bias, - eps, - output_scale, - output_zero_point, - out); -} - -void quantized_layer_norm_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& in_scale, - const Tensor& in_zero_point, - __ET_UNUSED const executorch::aten::IntArrayRef normalized_shape, - const Tensor& weight, - const Tensor& bias, - double eps, - double output_scale, - int64_t output_zero_point, - Tensor& out) { - if (input.scalar_type() == executorch::aten::ScalarType::Byte) { - quantized_layer_norm_( - input, - in_scale, - in_zero_point, - weight, - bias, - eps, - output_scale, - output_zero_point, - out); - } else if (input.scalar_type() == executorch::aten::ScalarType::Char) { - quantized_layer_norm_( - input, - in_scale, - in_zero_point, - weight, - bias, - eps, - output_scale, - output_zero_point, - out); - } else { - ET_CHECK_MSG( - false, - "Unhandled input dtype %hhd", - static_cast(input.scalar_type())); - } -} - -void quantized_layer_norm_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - double in_scale, - int64_t in_zero_point, - __ET_UNUSED const executorch::aten::IntArrayRef normalized_shape, - const Tensor& weight, - const Tensor& bias, - double eps, - double output_scale, - int64_t output_zero_point, - Tensor& out) { - if (input.scalar_type() == executorch::aten::ScalarType::Byte) { - quantized_layer_norm_per_tensor_( - input, - in_scale, - in_zero_point, - weight, - bias, - eps, - output_scale, - output_zero_point, - out); - } else if (input.scalar_type() == executorch::aten::ScalarType::Char) { - quantized_layer_norm_per_tensor_( - input, - in_scale, - in_zero_point, - weight, - bias, - eps, - output_scale, - output_zero_point, - out); - } else { - ET_CHECK_MSG( - false, - "Unhandled input dtype %hhd", - static_cast(input.scalar_type())); - } -} - -}; // namespace native -}; // namespace reference -}; // namespace impl diff --git a/backends/cadence/reference/operators/quantized_linear_out.cpp b/backends/cadence/reference/operators/quantized_linear_out.cpp deleted file mode 100644 index f60c98e5875..00000000000 --- a/backends/cadence/reference/operators/quantized_linear_out.cpp +++ /dev/null @@ -1,233 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include -#include -#include - -namespace impl { -namespace reference { -namespace native { - -using executorch::aten::Tensor; -using executorch::runtime::getLeadingDims; -using executorch::runtime::KernelRuntimeContext; - -template -void inline _typed_quantized_linear( - const Tensor& src, - const Tensor& weight, - const Tensor& bias, - int64_t src_zero_point, - const Tensor& weight_zero_point_t, - const Tensor& out_multiplier, - const Tensor& out_shift, - int64_t out_zero_point, - Tensor& out) { - const T* __restrict__ src_data = src.const_data_ptr(); - const T* __restrict__ weight_data = weight.const_data_ptr(); - const int32_t* __restrict__ bias_data = bias.const_data_ptr(); - T* __restrict__ out_data = out.mutable_data_ptr(); - - int32_t weight_zero_point = weight_zero_point_t.const_data_ptr()[0]; - - // input comes in shape [batch_size, in_dim] - // weight comes in shape [out_dim, in_dim] - // output comes in empty with shape [batch_size, out_dim] - // Perform matrix multiply (M x N) x (N x P) => M x P - const auto M = weight.size(0); // = out_dim - const auto N = weight.size(1); // = in_dim - - // Given an N-dimensional input [d0, d1, d2, ..., d_{N-2}, d_{N-1}], the - // leading dimensions is d0 * d1 * ... * d_{N-2} - const auto leading_dims = getLeadingDims(src, src.dim() - 1); - - ET_CHECK_MSG( - out_multiplier.numel() == 1, "out_multiplier should have one element"); - ET_CHECK_MSG( - out_shift.numel() == 1, "out_multiplier should have one element"); - - const int32_t* __restrict__ out_multiplier_data = - out_multiplier.const_data_ptr(); - const int32_t* __restrict__ out_shift_data = - out_shift.const_data_ptr(); - - // Compute the out_scale from out_multiplier and out_shift - const float out_scale = - -out_multiplier_data[0] * 1.0 / (1 << 31) * pow(2, out_shift_data[0]); - - for (int i = 0; i < leading_dims; ++i) { - for (int j = 0; j < M; ++j) { - float sum = bias_data[j]; - for (int k = 0; k < N; ++k) { - sum += (src_data[i * N + k] - src_zero_point) * - (weight_data[j * N + k] - weight_zero_point); - } - out_data[i * M + j] = - kernels::quantize(sum, out_scale, out_zero_point); - } - } -} - -void quantized_linear_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& src, - const Tensor& weight, - const Tensor& bias, - int64_t src_zero_point, - const Tensor& weight_zero_point_t, - const Tensor& out_multiplier, - const Tensor& out_shift, - int64_t out_zero_point, - __ET_UNUSED const std::optional& offset, - Tensor& out) { - // TODO: refactor to use switch case as quantized_linear_per_tensor_out - if (out.scalar_type() == executorch::aten::ScalarType::Byte) { - _typed_quantized_linear( - src, - weight, - bias, - src_zero_point, - weight_zero_point_t, - out_multiplier, - out_shift, - out_zero_point, - out); - } else if (out.scalar_type() == executorch::aten::ScalarType::Char) { - _typed_quantized_linear( - src, - weight, - bias, - src_zero_point, - weight_zero_point_t, - out_multiplier, - out_shift, - out_zero_point, - out); - } else { - ET_CHECK_MSG( - false, - "Unhandled input dtype %hhd", - static_cast(src.scalar_type())); - } -} - -void quantized_linear_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& src, - const Tensor& weight, - const Tensor& bias, - const int64_t src_zero_point, - const int64_t weight_zero_point, - const int64_t out_multiplier, - const int64_t out_shift, - const int64_t out_zero_point, - __ET_UNUSED const std::optional& offset, - Tensor& out) { -#define typed_quantized_linear_per_tensor(ctype, dtype) \ - case 
executorch::aten::ScalarType::dtype: { \ - quantized_linear_per_tensor_( \ - src, \ - weight, \ - bias, \ - src_zero_point, \ - weight_zero_point, \ - out_multiplier, \ - out_shift, \ - out_zero_point, \ - out); \ - break; \ - } - - executorch::aten::ScalarType dtype = out.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear_per_tensor); - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", executorch::runtime::toString(dtype)); - } -#undef typed_quantized_linear_per_tensor -} - -void quantized_linear_asym8sxasym8s_asym8s_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& src, - const Tensor& weight, - const Tensor& bias, - const int64_t src_zero_point, - const int64_t weight_zero_point, - const int64_t out_multiplier, - const int64_t out_shift, - const int64_t out_zero_point, - __ET_UNUSED const std::optional& offset, - Tensor& out) { -#define typed_quantized_linear_per_tensor(ctype, dtype) \ - case executorch::aten::ScalarType::dtype: { \ - quantized_linear_per_tensor_( \ - src, \ - weight, \ - bias, \ - src_zero_point, \ - weight_zero_point, \ - out_multiplier, \ - out_shift, \ - out_zero_point, \ - out); \ - break; \ - } - - executorch::aten::ScalarType dtype = out.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear_per_tensor); - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", executorch::runtime::toString(dtype)); - } -#undef typed_quantized_linear_per_tensor -} - -void quantized_linear_asym8uxasym8u_asym8u_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& src, - const Tensor& weight, - const Tensor& bias, - const int64_t src_zero_point, - const int64_t weight_zero_point, - const int64_t out_multiplier, - const int64_t out_shift, - const int64_t out_zero_point, - __ET_UNUSED const std::optional& offset, - Tensor& out) { -#define typed_quantized_linear_per_tensor(ctype, dtype) \ - case executorch::aten::ScalarType::dtype: { \ - quantized_linear_per_tensor_( \ - src, \ - weight, \ - bias, \ - src_zero_point, \ - weight_zero_point, \ - out_multiplier, \ - out_shift, \ - out_zero_point, \ - out); \ - break; \ - } - - executorch::aten::ScalarType dtype = out.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear_per_tensor); - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", executorch::runtime::toString(dtype)); - } -#undef typed_quantized_linear_per_tensor -} - -}; // namespace native -}; // namespace reference -}; // namespace impl diff --git a/backends/cadence/reference/operators/quantized_matmul_out.cpp b/backends/cadence/reference/operators/quantized_matmul_out.cpp deleted file mode 100644 index 3c2070c70dc..00000000000 --- a/backends/cadence/reference/operators/quantized_matmul_out.cpp +++ /dev/null @@ -1,207 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include - -namespace impl { -namespace reference { -namespace native { - -using executorch::aten::Tensor; -using executorch::runtime::getLeadingDims; -using executorch::runtime::KernelRuntimeContext; - -// The quantized matmul. The quantized matmul accumulates in a wider register, -// whose type is TA. 
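Throughout these kernels the requantization scale arrives as an (out_multiplier, out_shift) pair and is decoded as -multiplier / 2^31 * 2^shift, in effect a sign-flipped Q31 mantissa plus a power-of-two exponent. The sketch below produces a pair that round-trips through that decode expression; the encoding side shown here is an assumption for illustration only, not the actual ahead-of-time quantizer used by the Cadence flow.

// Sketch: encode a real-valued requantization scale into a (multiplier, shift)
// pair that satisfies  scale == -multiplier / 2^31 * 2^shift, which is the
// decode expression used by the reference kernels above and below.
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // Desired requantization scale, e.g. in_scale * weight_scale / out_scale.
  const double scale = 0.0005678;

  // Split scale into mantissa * 2^shift with mantissa in [0.5, 1), then store
  // the (negated) mantissa in Q31.
  int shift = 0;
  const double mantissa = std::frexp(scale, &shift);
  const int32_t multiplier =
      static_cast<int32_t>(-std::llround(mantissa * (1LL << 31)));

  // Decode with the same expression the kernels use.
  const double decoded =
      -static_cast<double>(multiplier) / (1LL << 31) * std::pow(2.0, shift);

  std::printf("scale=%.10f decoded=%.10f (multiplier=%d, shift=%d)\n",
              scale, decoded, multiplier, shift);
  return 0;
}

qmatmul then applies that decoded scale to each widened accumulator in the same way quantized_linear does.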
-template < - typename TZ, - typename TA = float, - bool transposed = false, - typename TX = TZ, - typename TY = TZ> -__attribute__((noinline)) void qmatmul( - TZ* __restrict__ Z, - int32_t Z_multiplier, - int32_t Z_shift, - int32_t Z_zero_point, - const TX* __restrict__ X, - int32_t X_zero_point, - const TY* __restrict__ y, - int32_t Y_zero_point, - size_t m, - size_t n, - size_t p) { - // Compute the Z_scale from Z_multiplier and Z_shift - const float Z_scale = -Z_multiplier * 1.0 / (1 << 31) * pow(2, Z_shift); - for (size_t i = 0; i < m; ++i) { - for (size_t j = 0; j < p; ++j) { - TA sum = 0; - for (size_t k = 0; k < n; ++k) { - if (transposed) { - sum += (X[i * n + k] - X_zero_point) * (y[j * n + k] - Y_zero_point); - } else { - sum += (X[i * n + k] - X_zero_point) * (y[k * p + j] - Y_zero_point); - } - } - Z[i * p + j] = kernels::quantize(sum, Z_scale, Z_zero_point); - } - } -} - -template -void inline _typed_quantized_matmul( - const Tensor& X, - int64_t X_zero_point, - const Tensor& Y, - int64_t Y_zero_point, - const std::optional& bias, - int64_t out_multiplier, - int64_t out_shift, - int64_t out_zero_point, - bool transposed, - Tensor& out) { - size_t batch_size = getLeadingDims(X, X.dim() - 2); - size_t leading_dim = X.size(X.dim() - 2); - size_t out_dim = Y.size(Y.dim() - 1 - transposed); - size_t in_dim = X.size(X.dim() - 1); - - T* __restrict__ out_data = out.mutable_data_ptr(); - const T* __restrict__ X_data = X.const_data_ptr(); - const T* __restrict__ Y_data = Y.const_data_ptr(); - for (size_t i = 0; i < batch_size; ++i) { - const T* x = X_data + i * leading_dim * in_dim; - const T* y = Y_data + i * in_dim * out_dim; - T* z = out_data + i * leading_dim * out_dim; - if (transposed) { - qmatmul( - z, - static_cast(out_multiplier), - static_cast(out_shift), - static_cast(out_zero_point), - x, - static_cast(X_zero_point), - y, - static_cast(Y_zero_point), - leading_dim, - in_dim, - out_dim); - } else { - qmatmul( - z, - static_cast(out_multiplier), - static_cast(out_shift), - static_cast(out_zero_point), - x, - static_cast(X_zero_point), - y, - static_cast(Y_zero_point), - leading_dim, - in_dim, - out_dim); - } - } -} - -void quantized_matmul_out( - KernelRuntimeContext& ctx, - const Tensor& X, - int64_t X_zero_point, - const Tensor& Y, - int64_t Y_zero_point, - const std::optional& bias, - int64_t out_multiplier, - int64_t out_shift, - int64_t out_zero_point, - bool transposed, - Tensor& out) { - if (out.scalar_type() == executorch::aten::ScalarType::Byte) { - _typed_quantized_matmul( - X, - X_zero_point, - Y, - Y_zero_point, - bias, - out_multiplier, - out_shift, - out_zero_point, - transposed, - out); - } else if (out.scalar_type() == executorch::aten::ScalarType::Char) { - _typed_quantized_matmul( - X, - X_zero_point, - Y, - Y_zero_point, - bias, - out_multiplier, - out_shift, - out_zero_point, - transposed, - out); - } else { - ET_CHECK_MSG( - false, - "Unhandled input dtype %hhd", - static_cast(X.scalar_type())); - } -} - -void quantized_matmul_asym8sxasym8s_asym8s_out( - KernelRuntimeContext& ctx, - const Tensor& X, - int64_t X_zero_point, - const Tensor& Y, - int64_t Y_zero_point, - const std::optional& bias, - int64_t out_multiplier, - int64_t out_shift, - int64_t out_zero_point, - bool transposed, - Tensor& out) { - _typed_quantized_matmul( - X, - X_zero_point, - Y, - Y_zero_point, - bias, - out_multiplier, - out_shift, - out_zero_point, - transposed, - out); -} - -void quantized_matmul_asym8uxasym8u_asym8u_out( - KernelRuntimeContext& ctx, - const Tensor& X, - 
int64_t X_zero_point, - const Tensor& Y, - int64_t Y_zero_point, - const std::optional& bias, - int64_t out_multiplier, - int64_t out_shift, - int64_t out_zero_point, - bool transposed, - Tensor& out) { - _typed_quantized_matmul( - X, - X_zero_point, - Y, - Y_zero_point, - bias, - out_multiplier, - out_shift, - out_zero_point, - transposed, - out); -} - -}; // namespace native -}; // namespace reference -}; // namespace impl diff --git a/backends/cadence/reference/operators/quantized_relu_out.cpp b/backends/cadence/reference/operators/quantized_relu_out.cpp deleted file mode 100644 index 8dab01cf982..00000000000 --- a/backends/cadence/reference/operators/quantized_relu_out.cpp +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include - -namespace impl { -namespace reference { -namespace native { - -using executorch::aten::Tensor; -using executorch::runtime::KernelRuntimeContext; - -template -void quantized_relu_( - const Tensor& input, - const Tensor& in_zero_point, - const int64_t out_zero_point, - const Tensor& out_multiplier, - const Tensor& out_shift, - Tensor& output) { - T q_zero_point = in_zero_point.const_data_ptr()[0]; - const T* __restrict__ in = input.const_data_ptr(); - T* __restrict__ out = output.mutable_data_ptr(); - - const int32_t* __restrict__ out_multiplier_data = - out_multiplier.const_data_ptr(); - const int32_t* __restrict__ out_shift_data = - out_shift.const_data_ptr(); - - // Compute the out_scale from out_multiplier and out_shift - const float out_scale = - -out_multiplier_data[0] * 1.0 / (1 << 31) * pow(2, out_shift_data[0]); - - for (size_t i = 0, e = input.numel(); i < e; ++i) { - const T temp = in[i] > q_zero_point ? (in[i] - q_zero_point) : 0; - out[i] = kernels::quantize(temp, out_scale, out_zero_point); - } -} - -void quantized_relu_out( - KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& in_zero_point, - const int64_t out_zero_point, - const Tensor& out_multiplier, - const Tensor& out_shift, - Tensor& output) { - if (input.scalar_type() == executorch::aten::ScalarType::Byte) { - quantized_relu_( - input, - in_zero_point, - out_zero_point, - out_multiplier, - out_shift, - output); - } else if (input.scalar_type() == executorch::aten::ScalarType::Char) { - quantized_relu_( - input, - in_zero_point, - out_zero_point, - out_multiplier, - out_shift, - output); - } else { - ET_CHECK_MSG( - false, - "Unhandled input dtype %hhd", - static_cast(input.scalar_type())); - } -} - -template -void quantized_relu_per_tensor_out_( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const int64_t in_zero_point, - const int64_t out_zero_point, - const int64_t out_multiplier, - const int64_t out_shift, - Tensor& output) { - const T* __restrict__ in = input.const_data_ptr(); - T* __restrict__ out = output.mutable_data_ptr(); - - // Compute the out_scale from out_multiplier and out_shift - const float out_scale = -out_multiplier * 1.0 / (1 << 31) * pow(2, out_shift); - - for (size_t i = 0, e = input.numel(); i < e; ++i) { - const float temp = in[i] > in_zero_point ? 
(in[i] - in_zero_point) : 0; - out[i] = kernels::quantize(temp, out_scale, out_zero_point); - } -} - -void quantized_relu_per_tensor_out( - KernelRuntimeContext& ctx, - const Tensor& input, - const int64_t in_zero_point, - const int64_t out_zero_point, - const int64_t out_multiplier, - const int64_t out_shift, - Tensor& output) { -#define typed_quantized_relu(ctype, dtype) \ - case executorch::aten::ScalarType::dtype: { \ - quantized_relu_per_tensor_out_( \ - ctx, \ - input, \ - in_zero_point, \ - out_zero_point, \ - out_multiplier, \ - out_shift, \ - output); \ - break; \ - } - - executorch::aten::ScalarType dtype = input.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_relu) - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", torch::executor::toString(dtype)); - } - -#undef typed_quantized_relu -} - -void quantized_relu_asym8s_asym8s_per_tensor_out( - KernelRuntimeContext& ctx, - const Tensor& input, - const int64_t in_zero_point, - const int64_t out_zero_point, - const int64_t out_multiplier, - const int64_t out_shift, - Tensor& output) { -#define typed_quantized_relu(ctype, dtype) \ - case executorch::aten::ScalarType::dtype: { \ - quantized_relu_per_tensor_out_( \ - ctx, \ - input, \ - in_zero_point, \ - out_zero_point, \ - out_multiplier, \ - out_shift, \ - output); \ - break; \ - } - - executorch::aten::ScalarType dtype = input.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_relu) - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", torch::executor::toString(dtype)); - } - -#undef typed_quantized_relu -} - -void quantized_relu_asym8u_asym8u_per_tensor_out( - KernelRuntimeContext& ctx, - const Tensor& input, - const int64_t in_zero_point, - const int64_t out_zero_point, - const int64_t out_multiplier, - const int64_t out_shift, - Tensor& output) { -#define typed_quantized_relu(ctype, dtype) \ - case executorch::aten::ScalarType::dtype: { \ - quantized_relu_per_tensor_out_( \ - ctx, \ - input, \ - in_zero_point, \ - out_zero_point, \ - out_multiplier, \ - out_shift, \ - output); \ - break; \ - } - - executorch::aten::ScalarType dtype = input.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_relu) - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", torch::executor::toString(dtype)); - } - -#undef typed_quantized_relu -} - -}; // namespace native -}; // namespace reference -}; // namespace impl diff --git a/backends/cadence/reference/operators/targets.bzl b/backends/cadence/reference/operators/targets.bzl deleted file mode 100644 index 488aeebb82a..00000000000 --- a/backends/cadence/reference/operators/targets.bzl +++ /dev/null @@ -1,23 +0,0 @@ -load("@fbsource//tools/build_defs:platform_defs.bzl", "CXX") -load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") - -def define_common_targets(): - runtime.cxx_library( - name = "cadence_cpu_ops", - srcs = glob([ - "*.cpp", - ]), - exported_headers =glob([ - "*.h", - ]), - platforms = CXX, - deps = [ - "//executorch/kernels/portable/cpu/util:broadcast_util", - "//executorch/runtime/kernel:kernel_includes", - "//executorch/kernels/portable/cpu:scalar_utils", - "//executorch/backends/cadence/reference/kernels:cadence_kernels", - ], - visibility = [ - "//executorch/backends/cadence/...", - ], - ) diff --git a/backends/cadence/runtime/TARGETS b/backends/cadence/runtime/TARGETS index 9c65c469280..65a578f4751 100644 --- a/backends/cadence/runtime/TARGETS +++ b/backends/cadence/runtime/TARGETS @@ -21,6 
+21,7 @@ runtime.python_library( "//executorch/devtools/bundled_program/serialize:lib", "//executorch/devtools:lib", "//executorch/exir:lib", + ":etdump", ], ) diff --git a/backends/cadence/runtime/etdump.py b/backends/cadence/runtime/etdump.py new file mode 100644 index 00000000000..4ef5d28285a --- /dev/null +++ b/backends/cadence/runtime/etdump.py @@ -0,0 +1,173 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import logging +import os +from typing import cast, Optional, Tuple + +import torch +from executorch.devtools import Inspector +from executorch.devtools.inspector import Event, EventBlock, PerfData +from executorch.devtools.inspector._inspector_utils import TimeScale +from tabulate import tabulate + + +class CadenceETDump: + def __init__(self, output_dir: str) -> None: + self.tensor_dump_dir: str = os.path.join(output_dir, "tensors") + self.etdump_path: str = os.path.join(output_dir, "etdump.etdp") + self.etrecord_path: Optional[str] = os.path.join(output_dir, "etrecord.bin") + self.debug_buffer_path: Optional[str] = os.path.join( + output_dir, "debug_output.bin" + ) + + if not os.path.exists(self.etdump_path): + raise RuntimeError(f"{self.etdump_path} does not exist") + # pyre-ignore[6]: os.path.exists expects str, but got Optional[str] + if not os.path.exists(self.etrecord_path): + logging.warning( + "ETRecord not found, intermediate tensors will not be dumped" + ) + self.etrecord_path = None + # pyre-ignore[6]: os.path.exists expects str, but got Optional[str] + if not os.path.exists(self.debug_buffer_path): + logging.warning( + "Debug buffer not found, intermediate tensors will not be dumped" + ) + self.debug_buffer_path = None + + self.et_inspector: Inspector = Inspector( + etdump_path=self.etdump_path, + debug_buffer_path=self.debug_buffer_path, + etrecord=self.etrecord_path, + source_time_scale=TimeScale.CYCLES, + target_time_scale=TimeScale.CYCLES, + ) + + def get_outputs(self, log_to_stdout: bool = False) -> Tuple[torch.Tensor]: + output = [ + event_block.run_output + for event_block in self.et_inspector.event_blocks + if event_block.name == "Execute" + ] + logging.debug(f"[CadenceETDump] output: {output}") + return output[0] + + def get_execute_event_block(self) -> EventBlock: + exec_blocks = [ + eb for eb in self.et_inspector.event_blocks if eb.name == "Execute" + ] + return exec_blocks[0] + + def should_include_event(self, event: Event) -> bool: + # exclude duplicate events + if event.name in ("OPERATOR_CALL", "Method::execute"): + return False + + # exclude custom multi-zion events + if event.name.startswith("DELEGATE_ZION"): + return False + + return True + + def print_summary( + self, + bundled_prog_size: Optional[int] = None, + external_link: Optional[str] = None, + ) -> None: + """ + Print performance summary with optional program size and external link. 
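+        The table lists total cycles for Method::execute, the cycles spent in
+        operator events, the remaining cycles, and the framework tax percentage;
+        bundled program size and the external link are appended as extra columns
+        when provided.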
+ + Args: + bundled_prog_size: Size of the bundled program in bytes (optional) + external_link: External analytics/monitoring link (optional, e.g., Scuba link for Meta internal use) + """ + block = self.get_execute_event_block() + op_events = [e for e in block.events if self.should_include_event(e)] + op_time_sum = sum([cast(PerfData, e.perf_data).avg for e in op_events]) + + overall_event = [ev for ev in block.events if ev.name == "Method::execute"] + if not len(overall_event) == 1: + logging.warning( + f"Expected one 'Method::execute' event, found {len(overall_event)}" + ) + + total_cycles = cast(PerfData, overall_event[0].perf_data).avg + op_cycles = op_time_sum + + # Build table data and headers dynamically based on what's provided + table_data = [ + "{:,.0f}".format(total_cycles), + "{:,.0f}".format(op_cycles), + "{:,.0f}".format(total_cycles - op_cycles), + "{:.2%}".format((total_cycles - op_cycles) / total_cycles), + ] + headers = [ + "Total Cycles", + "Cycles in Ops", + "Other Cycles", + "Framework Tax (%)", + ] + + # Add optional fields if provided + if bundled_prog_size is not None: + table_data.append("{:,.0f}".format(bundled_prog_size)) + headers.append("Bundled Program Size (bytes)") + + if external_link is not None: + table_data.append(external_link) + headers.append("External Link") + + logging.info( + "Performance Summary:\n%s", + tabulate( + [table_data], + headers=headers, + tablefmt="outline", + ), + ) + + def print_event_block(self) -> None: + logging.info("Profiled events:") + if logging.getLogger().level <= logging.INFO: + self.et_inspector.print_data_tabular() + + def dump_intermediate_tensors(self) -> None: + if self.etrecord_path is None: + logging.info("[CadenceETDump] Intermediate tensors not available") + return + + logging.info( + f"[CadenceETDump] Dumping intermediate tensors to {self.tensor_dump_dir}" + ) + os.makedirs(self.tensor_dump_dir, exist_ok=True) + exec_blocks = [ + eb for eb in self.et_inspector.event_blocks if eb.name == "Execute" + ] + if len(exec_blocks) > 1: + logging.warning( + f'Found {len(exec_blocks)} "Execute" blocks, using the first one and ignoring the rest.' + ) + block = exec_blocks[0] + + # OPERATOR_CALL events are duplicates that contain framework tax data. 
We don't need them + op_events = [e for e in block.events if e.name != "OPERATOR_CALL"] + torch.set_printoptions(profile="full") + + for event in op_events: + instr_id = event._instruction_id + if not event.debug_data: + logging.debug( + f"Missing intermediate tensor data for {event.name} ({instr_id=})" + ) + continue + + with open(f"{self.tensor_dump_dir}/{instr_id}.txt", "w") as f: + for dd in event.debug_data: + f.write(f"{str(dd)}\n\n") + torch.set_printoptions(profile="default") diff --git a/backends/cadence/runtime/runtime.py b/backends/cadence/runtime/runtime.py index 4d1c876bcdb..3a139e415ea 100644 --- a/backends/cadence/runtime/runtime.py +++ b/backends/cadence/runtime/runtime.py @@ -9,9 +9,8 @@ import logging import numbers -import os import tempfile -from typing import Any, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Sequence, Union import executorch.exir.schema as et_schema @@ -19,8 +18,8 @@ import torch from executorch.backends.cadence.runtime import utils +from executorch.backends.cadence.runtime.etdump import CadenceETDump from executorch.backends.cadence.runtime.executor import Executor -from executorch.devtools import Inspector from executorch.exir import ExecutorchProgramManager from executorch.exir._serialize._program import deserialize_pte_binary from executorch.exir.schema import DataLocation @@ -30,90 +29,6 @@ from torch.utils._pytree import TreeSpec -class CadenceETDump: - def __init__(self, output_dir: str) -> None: - self.tensor_dump_dir: str = os.path.join(output_dir, "tensors") - self.etdump_path: str = os.path.join(output_dir, "etdump.etdp") - self.etrecord_path: Optional[str] = os.path.join(output_dir, "etrecord.bin") - self.debug_buffer_path: Optional[str] = os.path.join( - output_dir, "debug_output.bin" - ) - - if not os.path.exists(self.etdump_path): - raise RuntimeError(f"{self.etdump_path} does not exist") - # pyre-ignore[6]: os.path.exists expects str, but got Optional[str] - if not os.path.exists(self.etrecord_path): - logging.warning( - "ETRecord not found, intermediate tensors will not be dumped" - ) - self.etrecord_path = None - # pyre-ignore[6]: os.path.exists expects str, but got Optional[str] - if not os.path.exists(self.debug_buffer_path): - logging.warning( - "Debug buffer not found, intermediate tensors will not be dumped" - ) - self.debug_buffer_path = None - - self.et_inspector: Inspector = Inspector( - etdump_path=self.etdump_path, - debug_buffer_path=self.debug_buffer_path, - etrecord=self.etrecord_path, - ) - - def get_outputs(self, log_to_stdout: bool = False) -> Tuple[torch.Tensor]: - output = [ - event_block.run_output - for event_block in self.et_inspector.event_blocks - if event_block.name == "Execute" - ] - logging.debug(f"[ETdump] output: {output}") - return output[0] - - def print_event_block(self) -> None: - logging.debug("[ETdump] data tabular:") - if logging.getLogger().level <= logging.DEBUG: - self.et_inspector.print_data_tabular() - - def print_event_data(self) -> None: - logging.debug("[ETdump] event data ") - for event_block in self.et_inspector.event_blocks: - for event in event_block.events: - logging.debug(event) - - def dump_intermediate_tensors(self) -> None: - if self.etrecord_path is None: - logging.info("[ETdump] Intermediate tensors not available") - return - - logging.info(f"[ETdump] Dumping intermediate tensors to {self.tensor_dump_dir}") - os.makedirs(self.tensor_dump_dir, exist_ok=True) - exec_blocks = [ - eb for eb in self.et_inspector.event_blocks if eb.name == "Execute" - ] - if 
len(exec_blocks) > 1: - logging.warning( - f'Found {len(exec_blocks)} "Execute" blocks, using the first one and ignoring the rest.' - ) - block = exec_blocks[0] - - # OPERATOR_CALL events are duplicates that contain framework tax data. We don't need them - op_events = [e for e in block.events if e.name != "OPERATOR_CALL"] - torch.set_printoptions(profile="full") - - for event in op_events: - instr_id = event._instruction_id - if not event.debug_data: - logging.debug( - f"Missing intermediate tensor data for {event.name} ({instr_id=})" - ) - continue - - with open(f"{self.tensor_dump_dir}/{instr_id}.txt", "w") as f: - for dd in event.debug_data: - f.write(f"{str(dd)}\n\n") - torch.set_printoptions(profile="default") - - def get_op_names(program: et_schema.Program, execution_plan_id: int = 0) -> set[str]: """ Get the list of operators from a Program @@ -130,7 +45,7 @@ def get_op_names(program: et_schema.Program, execution_plan_id: int = 0) -> set[ op_names |= get_op_names( deserialize_pte_binary( program.backend_delegate_data[delegate.processed.index].data - ) + ).program ) return op_names @@ -162,6 +77,9 @@ def run( etdump = CadenceETDump(output_dir=working_dir) outputs = etdump.get_outputs() + # Print performance summary + etdump.print_summary() + assert isinstance(out_spec, TreeSpec) outputs = torch.utils._pytree.tree_unflatten(outputs, out_spec) diff --git a/backends/cadence/runtime/targets.bzl b/backends/cadence/runtime/targets.bzl index dabe42ad824..09a116764c2 100644 --- a/backends/cadence/runtime/targets.bzl +++ b/backends/cadence/runtime/targets.bzl @@ -13,3 +13,17 @@ def define_common_targets(): "//executorch/runtime/platform:platform", ], ) + + runtime.python_library( + name = "etdump", + srcs = ["etdump.py"], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS" + ], + deps = [ + "fbcode//executorch/devtools:lib", + "fbcode//executorch/devtools/inspector:inspector_utils", + "fbsource//third-party/pypi/tabulate:tabulate", + ], + ) diff --git a/backends/cadence/utils/facto_util.py b/backends/cadence/utils/facto_util.py index fd056cd08cc..b5c5683ab5d 100644 --- a/backends/cadence/utils/facto_util.py +++ b/backends/cadence/utils/facto_util.py @@ -15,6 +15,7 @@ import torch from facto.inputgen.argtuple.gen import ArgumentTupleGenerator from facto.inputgen.specs.model import ConstraintProducer as cp +from facto.inputgen.utils.random_manager import seeded_random_manager as rm from facto.inputgen.variable.type import ScalarDtype from facto.specdb.db import SpecDictDB @@ -22,7 +23,123 @@ MAX_CASES = 50 +# Global cache to store generated shapes per tensor to ensure consistency +_shape_cache: dict[str, list[int]] = {} + + +def _positive_valid_dim_list(tensor: torch.Tensor, length: int) -> set[tuple[int, ...]]: + """ + Generate valid permutations using only positive dimension indices. + This is required for Cadence/Xtensa kernels that don't support negative indexing. 
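+    For example, a rank-3 input yields permutations drawn from (0, 1, 2), such
+    as (2, 0, 1); negative forms such as (-1, 0, 1) are never produced.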
+ + Args: + tensor: Input tensor to generate permutations for + length: Number of dimensions in the permutation (must equal tensor.dim()) + + Returns: + Set of valid permutation tuples containing only positive indices [0, rank-1] + """ + if length > tensor.dim(): + return set() + + n = tensor.dim() + pool = list(range(n)) + + # Generate multiple valid permutations (only positive indices) + permutations: set[tuple[int, ...]] = set() + for _ in range(3): # Generate 3 different permutations for diversity + perm = tuple(rm.get_random().sample(pool, length)) + permutations.add(perm) + + return permutations + + def apply_tensor_contraints(op_name: str, index: int) -> list[object]: + # Constraint to limit tensor size to < 4000 bytes with fully randomized shapes + import random + + def get_dtype_bytes(dtype: torch.dtype) -> int: + """Get the number of bytes per element for a given dtype""" + dtype_bytes = { + torch.int8: 1, + torch.uint8: 1, + torch.int16: 2, + torch.uint16: 2, + torch.int32: 4, + torch.float32: 4, + torch.int64: 8, + torch.float64: 8, + torch.bool: 1, + torch.float: 4, # alias for float32 + torch.int: 4, # alias for int32 + torch.long: 8, # alias for int64 + } + return dtype_bytes.get(dtype, 4) # Default to 4 bytes if dtype not found + + def generate_random_shape_with_byte_limit( + rank: int, dtype: torch.dtype, max_bytes: int = 3999, seed_base: int = 42 + ) -> list[int]: + """Generate a random shape with given rank ensuring total byte size < max_bytes""" + random.seed(seed_base + rank) + + bytes_per_element = get_dtype_bytes(dtype) + max_elements = max_bytes // bytes_per_element + + # Start with all dimensions as 1 + shape = [1] * rank + remaining_elements = ( + max_elements - 1 + ) # Leave room since we start with product=1 + + # Randomly distribute the remaining capacity across dimensions + for i in range(rank): + if remaining_elements <= 1: + break + + # Calculate maximum size this dimension can have without exceeding limit + current_product = 1 + for j in range(rank): + if j != i: + current_product *= shape[j] + + max_size_for_dim = min( + remaining_elements // current_product, 50 + ) # Cap at 50 + if max_size_for_dim > shape[i]: + # Randomly choose a size between current and max + new_size = random.randint(shape[i], max_size_for_dim) + shape[i] = new_size + remaining_elements = max_elements // (current_product * new_size) + remaining_elements = max(1, remaining_elements) + + # Final random shuffle of the dimensions to make it more random + random.shuffle(shape) + return shape + + def random_size_constraint(deps: object, r: int, d: int) -> int: + """Generate random sizes ensuring total byte size < 4000 bytes""" + # Use conservative approach: assume worst case is 4 bytes per element (float32/int32) + # This ensures we never exceed 4000 bytes regardless of actual dtype + worst_case_dtype = torch.float32 # 4 bytes per element + + # Create a unique key for this tensor configuration + cache_key = f"{r}_{d}_conservative" + + if cache_key not in _shape_cache: + # Generate a new random shape for this rank using worst-case byte estimation + shape = generate_random_shape_with_byte_limit( + r, worst_case_dtype, max_bytes=3999, seed_base=42 + r * 10 + d + ) + _shape_cache[cache_key] = shape + + # Return the size for dimension d, ensuring we don't go out of bounds + cached_shape = _shape_cache[cache_key] + return cached_shape[d] if d < len(cached_shape) else 1 + + max_size_constraint = cp.Size.Le( + lambda deps, r, d: random_size_constraint(deps, r, d) + ) + tensor_constraints = ( [ 
cp.Dtype.In( @@ -39,7 +156,7 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]: cp.Value.Le(lambda deps, dtype, struct: 2**4), cp.Rank.Ge(lambda deps: 1), cp.Size.Ge(lambda deps, r, d: 1), - cp.Size.Le(lambda deps, r, d: 2**9), + max_size_constraint, cp.Rank.Le(lambda deps: 2**3), ] if op_name @@ -62,7 +179,7 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]: cp.Value.Le(lambda deps, dtype, struct: 2**4), cp.Rank.Ge(lambda deps: 1), cp.Size.Ge(lambda deps, r, d: 1), - cp.Size.Le(lambda deps, r, d: 2**9), + max_size_constraint, cp.Rank.Le(lambda deps: 2**3), ] ) @@ -72,29 +189,38 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]: if index == 0: # condition tensor_constraints = [ cp.Dtype.In(lambda deps: [torch.bool]), + cp.Value.Ge(lambda deps, dtype, struct: 0), + cp.Value.Le(lambda deps, dtype, struct: 1), + cp.Rank.Ge(lambda deps: 1), + cp.Size.Ge(lambda deps, r, d: 1), + max_size_constraint, + ] + elif index == 1: # input tensor(a) + tensor_constraints = [ + cp.Dtype.In(lambda deps: [torch.float32]), cp.Value.Ge(lambda deps, dtype, struct: -(2**4)), cp.Value.Le(lambda deps, dtype, struct: 2**4), cp.Rank.Ge(lambda deps: 1), cp.Size.Ge(lambda deps, r, d: 1), - cp.Size.Le(lambda deps, r, d: 2**9), + cp.Size.In( + lambda deps, r, d: fn.broadcast_with(deps[0].shape, r, d) + ), + max_size_constraint, ] - else: + else: # input tensor(b) tensor_constraints = [ - cp.Dtype.In( - lambda deps: [ - torch.int8, - torch.int16, - torch.uint8, - torch.uint16, - torch.int32, - torch.float32, - ] - ), + cp.Dtype.In(lambda deps: [torch.float32]), + cp.Dtype.Eq(lambda deps: deps[1].dtype), cp.Value.Ge(lambda deps, dtype, struct: -(2**4)), cp.Value.Le(lambda deps, dtype, struct: 2**4), cp.Rank.Ge(lambda deps: 1), cp.Size.Ge(lambda deps, r, d: 1), - cp.Size.Le(lambda deps, r, d: 2**9), + cp.Size.In( + lambda deps, r, d: fn.broadcast_with( + fn.broadcasted_shape(deps[0].shape, deps[1].shape), r, d + ) + ), + max_size_constraint, ] case "embedding.default": tensor_constraints = [ @@ -104,7 +230,7 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]: cp.Value.Le(lambda deps, dtype, struct: 2**4), cp.Rank.Ge(lambda deps: 1), cp.Size.Ge(lambda deps, r, d: 1), - cp.Size.Le(lambda deps, r, d: 2**9), + max_size_constraint, ] case "sigmoid.default": tensor_constraints.extend( @@ -114,6 +240,37 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]: cp.Value.Le(lambda deps, dtype, struct: 2), ] ) + case "transpose_copy.int": + tensor_constraints.extend( + [ + cp.Dtype.In(lambda deps: [torch.float32, torch.int32]), + ] + ) + case "permute_copy.default": + tensor_constraints.extend( + [ + cp.Dtype.In(lambda deps: [torch.float32, torch.int8, torch.uint8]), + cp.Rank.Le( + lambda deps: 5 + ), # xa_nn_transpose only supports up to 5D + cp.Rank.Ge(lambda deps: 1), # Must have at least 1 dimension + ] + ) + case "sqrt.default": + tensor_constraints.extend( + [ + cp.Dtype.In(lambda deps: [torch.float32, torch.int32]), + ] + ) + case "clamp.default": + tensor_constraints.extend( + [ + cp.Dtype.In(lambda deps: [torch.float32, torch.int32]), + # Avoid NaN/Inf values that expose clamp NaN handling bugs + cp.Value.Ge(lambda deps, dtype, struct: -(2**4)), + cp.Value.Le(lambda deps, dtype, struct: 2**4), + ] + ) case "rsqrt.default": tensor_constraints.extend( [ @@ -124,6 +281,12 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]: cp.Value.Le(lambda deps, dtype, struct: 2**2), ] ) + case "relu.default": + 
tensor_constraints.extend( + [ + cp.Dtype.In(lambda deps: [torch.float32]), + ] + ) case "mean.dim": tensor_constraints.extend( [ @@ -133,10 +296,17 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]: case "exp.default": tensor_constraints.extend( [ + cp.Dtype.In(lambda deps: [torch.float32]), cp.Value.Ge(lambda deps, dtype, struct: -(2**2)), cp.Value.Le(lambda deps, dtype, struct: 2**2), ] ) + case "tanh.default": + tensor_constraints.extend( + [ + cp.Dtype.In(lambda deps: [torch.float32]), + ] + ) case "slice_copy.Tensor": tensor_constraints.extend( [ @@ -145,13 +315,44 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]: cp.Value.Le(lambda deps, dtype, struct: 2), ] ) - case "constant_pad_nd.default": + case "div.Scalar" | "add.Tensor" | "mul.Tensor" | "sub.Tensor": tensor_constraints.extend( [ - cp.Dtype.In(lambda deps: [torch.float32]), + cp.Dtype.In( + lambda deps: [ + torch.int32, + torch.int64, + torch.float32, + ] + ), + ] + ) + case "split_copy.Tensor": + tensor_constraints.extend( + [ + cp.Dtype.In( + lambda deps: [ + torch.int32, + torch.int64, + torch.float32, + ] + ), + cp.Value.Ge(lambda deps, dtype, struct: 1), + cp.Value.Le(lambda deps, dtype, struct: 2**3), + cp.Rank.Le(lambda deps: 3), cp.Size.Le(lambda deps, r, d: 2**2), ] ) + case "constant_pad_nd.default": + tensor_constraints = [ + cp.Dtype.In(lambda deps: [torch.float32]), + cp.Value.Ge(lambda deps, dtype, struct: -(2**4)), + cp.Value.Le(lambda deps, dtype, struct: 2**4), + cp.Rank.Ge(lambda deps: 1), + cp.Rank.Le(lambda deps: 2), # Reduced from 3 to 2 (max 2D tensors) + cp.Size.Ge(lambda deps, r, d: 1), + cp.Size.Le(lambda deps, r, d: 3), # Max dimension size of 3 + ] case "avg_pool2d.default": tensor_constraints.extend( [ @@ -167,12 +368,29 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]: ] ) case "div.Tensor": + if index == 1: # Only apply zero-prevention to divisor + tensor_constraints.extend( + [ + cp.Value.Ne( + lambda deps, dtype, struct: 0 + ), # Prevent division by zero + cp.Value.Le(lambda deps, dtype, struct: 2**3), + cp.Size.Le(lambda deps, r, d: 2**3), + cp.Rank.Le(lambda deps: 2**2), + ] + ) + else: + tensor_constraints.extend( + [ + cp.Value.Le(lambda deps, dtype, struct: 2**3), + cp.Size.Le(lambda deps, r, d: 2**3), + cp.Rank.Le(lambda deps: 2**2), + ] + ) + case "pow.Tensor_Scalar": tensor_constraints.extend( [ - cp.Value.Ne(lambda deps, dtype, struct: 0), - cp.Value.Le(lambda deps, dtype, struct: 2**3), - cp.Size.Le(lambda deps, r, d: 2**3), - cp.Rank.Le(lambda deps: 2**2), + cp.Dtype.In(lambda deps: [torch.float32, torch.int32]), ] ) case "div.Tensor_mode" | "minimum.default": @@ -190,6 +408,9 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]: cp.Dtype.In(lambda deps: [torch.int64, torch.int32, torch.float32]), cp.Value.Ge(lambda deps, dtype, struct: -(2**4)), cp.Value.Le(lambda deps, dtype, struct: 2**4), + cp.Value.Ne( + lambda deps, dtype, struct: 0 + ), # Prevent division by zero cp.Rank.Ge(lambda deps: 1), cp.Rank.Eq(lambda deps: deps[0].dim()), cp.Size.Eq(lambda deps, r, d: fn.safe_size(deps[0], d)), @@ -206,6 +427,12 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]: cp.Value.Le(lambda deps, dtype, struct: 2**2), cp.Size.Le(lambda deps, r, d: 2**3), ] + case "leaky_relu.default": + tensor_constraints.extend( + [ + cp.Dtype.In(lambda deps: [torch.float32]), + ] + ) case "_softmax.default": tensor_constraints.extend( [ @@ -213,6 +440,12 @@ def apply_tensor_contraints(op_name: str, 
index: int) -> list[object]: cp.Size.Le(lambda deps, r, d: 2**2), ] ) + case "flip.default": + tensor_constraints.extend( + [ + cp.Dtype.In(lambda deps: [torch.float32]), + ] + ) case _: pass return tensor_constraints @@ -226,6 +459,7 @@ def apply_scalar_contraints(op_name: str) -> list[ScalarDtype]: | "mul.Scalar" | "div.Scalar" | "constant_pad_nd.default" + | "clamp.default" ): return [ScalarDtype.int] case "full.default": @@ -253,11 +487,44 @@ def facto_testcase_gen( # noqa: C901 cp.Size.Le(lambda deps, r, d: 2**2), ] ) - if in_spec.name == "max_val": # hardtanh + # Special handling for clamp.default to ensure min < max with sufficient gap (at least 2) and never None + if op_name == "clamp.default": + if in_spec.name == "min": + # min must always be provided (not None) and bounded, leave room for max + spec.inspec[index].constraints.extend( + [ + cp.Optional.Eq(lambda deps: False), # Never None + cp.Value.Ge(lambda deps, dtype: -(2**4)), + cp.Value.Le( + lambda deps, dtype: 2**4 - 2 + ), # Leave room for max (at least 2 units) + ] + ) + elif in_spec.name == "max": + # max must always be provided (not None), be >= min + 2 (sufficient gap), and bounded + spec.inspec[index].deps = [0, 1] # deps on input tensor and min + spec.inspec[index].constraints.extend( + [ + cp.Optional.Eq(lambda deps: False), # Never None + cp.Value.Ge( + lambda deps, dtype: deps[1] + 2 + ), # max >= min + 2 (sufficient gap) + cp.Value.Le(lambda deps, dtype: 2**4), + ] + ) + elif in_spec.name == "max_val": # hardtanh spec.inspec[index].deps = [0, 1] spec.inspec[index].constraints.extend( [cp.Value.Ge(lambda deps, _: deps[1])] ) + elif in_spec.name == "negative_slope" and op_name == "leaky_relu.default": + # For leaky_relu, negative_slope should be in typical range (0, 1] + spec.inspec[index].constraints.extend( + [ + cp.Value.Gt(lambda deps, dtype: 0), + cp.Value.Le(lambda deps, dtype: 1.0), + ] + ) else: spec.inspec[index].constraints.extend( [ @@ -282,12 +549,32 @@ def facto_testcase_gen( # noqa: C901 apply_tensor_contraints(op_name, index) ) elif in_spec.type.is_dim_list(): - spec.inspec[index].constraints.extend( - [ - cp.Length.Ge(lambda deps: 1), - cp.Optional.Eq(lambda deps: False), - ] - ) + # Special handling for permute_copy.default to ensure valid permutation + if op_name == "permute_copy.default": + spec.inspec[index].constraints.extend( + [ + cp.Length.Ge(lambda deps: 1), + cp.Length.Eq( + lambda deps: deps[0].dim() + ), # Must be a complete permutation + cp.Optional.Eq(lambda deps: False), + # Generate valid permutations using only positive indices + # Cadence/Xtensa hardware kernels do not support negative dimension indices + cp.Value.Gen( + lambda deps, length: ( + _positive_valid_dim_list(deps[0], length), + fn.invalid_dim_list(deps[0], length), + ) + ), + ] + ) + else: + spec.inspec[index].constraints.extend( + [ + cp.Length.Ge(lambda deps: 1), + cp.Optional.Eq(lambda deps: False), + ] + ) elif in_spec.type.is_bool(): spec.inspec[index].constraints.extend( [ diff --git a/backends/cadence/vision/kernels/CMakeLists.txt b/backends/cadence/vision/kernels/CMakeLists.txt new file mode 100644 index 00000000000..fa7b2b5203b --- /dev/null +++ b/backends/cadence/vision/kernels/CMakeLists.txt @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
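+
+# cadence_kernels combines the generic quantize/dequantize/requantize helpers in
+# kernels.cpp with the third-party vision library transpose, softmax and
+# floating-point table sources listed below.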
+ +# lint_cmake: -linelength +add_library( + cadence_kernels + kernels.cpp + ${EXECUTORCH_ROOT}/backends/cadence/vision/third-party/library/api/tensor_transposef.c + ${EXECUTORCH_ROOT}/backends/cadence/vision/third-party/library/api/vsoftmaxf.c + ${EXECUTORCH_ROOT}/backends/cadence/vision/third-party/library/tables/expf_tbl.c + ${EXECUTORCH_ROOT}/backends/cadence/vision/third-party/library/tables/nanf_tbl.c + ${EXECUTORCH_ROOT}/backends/cadence/vision/third-party/library/tables/inff_tbl.c +) + +# Let files say "include ". +set(_common_include_directories + ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10 +) + +target_include_directories( + cadence_kernels + PUBLIC . ${EXECUTORCH_ROOT}/backends/cadence/vision/third-party/include + ${EXECUTORCH_ROOT}/backends/cadence/vision/third-party/include_private + ${_common_include_directories} +) + +target_link_libraries(cadence_kernels PRIVATE idma) diff --git a/backends/cadence/vision/kernels/kernels.cpp b/backends/cadence/vision/kernels/kernels.cpp new file mode 100644 index 00000000000..66cfcadbf13 --- /dev/null +++ b/backends/cadence/vision/kernels/kernels.cpp @@ -0,0 +1,197 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +namespace impl { +namespace vision { +namespace kernels { + +void* allocate_temp_memory(KernelRuntimeContext& ctx, size_t size) { + Result temp_mem_res = ctx.allocate_temp(size); + return temp_mem_res.ok() ? temp_mem_res.get() : nullptr; +} + +// Quantize a fp32 value to an int8_t/uint8_t value +template +T quantize(const float x, float scale, int32_t zero_point) { + constexpr float kMinValue = static_cast(std::numeric_limits::min()); + constexpr float kMaxValue = static_cast(std::numeric_limits::max()); + float tmp = roundf(x * scale + zero_point); + return std::max(std::min(tmp, kMaxValue), kMinValue); +} + +// Quantize an fp32 array to an int8_t/uint8_t array +template +void quantize( + T* __restrict__ y, + const float* __restrict__ x, + float inv_scale, + int32_t zero_point, + size_t size) { + for (size_t i = 0; i < size; ++i) { + y[i] = quantize(x[i], inv_scale, zero_point); + } +} + +// Dequantize an int8_t/uint8_t value to an fp32 value +template +float dequantize(const T x, float scale, int32_t zero_point) { + return scale * (x - zero_point); +} + +// Dequantize an int8_t/uint8_t/int16_t array to an fp32 array +template +void dequantize( + float* __restrict__ y, + const T* __restrict__ x, + float scale, + int32_t zero_point, + size_t size) { + for (size_t i = 0; i < size; ++i) { + y[i] = dequantize(x[i], scale, zero_point); + } +} + +// Requantize the int8_t/uint8_t in value to a uint8_t/int8_t out value. +// The scale and zero_point for requantization are in the args. +template +OT requantize( + const IT in, + float in_scale, + int32_t in_zero_point, + float inv_out_scale, + int32_t out_zero_point) { + float dequant = dequantize(in, in_scale, in_zero_point); + return quantize(dequant, inv_out_scale, out_zero_point); +} + +// Requantize the int8_t/uint8_t in array to a uint8_t/int8_t out array. +// The scale and zero_point for requantization are in the args. 
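+// For illustration (arbitrary values): requantizing the int8_t value -5 with
+// in_scale = 0.5 and in_zero_point = 0 to uint8_t with out_scale = 0.25
+// (inv_out_scale = 4) and out_zero_point = 128 dequantizes to -2.5 and then
+// quantizes to roundf(-2.5 * 4 + 128) = 118.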
+template +void requantize( + OT* __restrict__ out, + const IT* __restrict__ in, + float in_scale, + int32_t in_zero_point, + float inv_out_scale, + int32_t out_zero_point, + size_t size) { + for (size_t i = 0; i < size; ++i) { + out[i] = requantize( + in[i], in_scale, in_zero_point, inv_out_scale, out_zero_point); + } +} + +// explicit template instantiation + +#define typed_quantize_val(dtype) \ + template dtype quantize(const float x, float inv_scale, int32_t zero_point); +typed_quantize_val(int8_t); +typed_quantize_val(uint8_t); +typed_quantize_val(int16_t); +typed_quantize_val(uint16_t); +typed_quantize_val(int32_t); +#undef typed_quantize_val + +#define typed_quantize_vec(dtype) \ + template void quantize( \ + dtype* __restrict__ y, \ + const float* __restrict__ x, \ + float inv_scale, \ + int32_t zero_point, \ + size_t size); +typed_quantize_vec(int8_t); +typed_quantize_vec(uint8_t); +typed_quantize_vec(int16_t); +typed_quantize_vec(uint16_t); +typed_quantize_vec(int32_t); +#undef typed_quantize_vec + +#define typed_dequantize_val(dtype) \ + template float dequantize(const dtype x, float scale, int32_t zero_point); +typed_dequantize_val(int8_t); +typed_dequantize_val(uint8_t); +typed_dequantize_val(int16_t); +typed_dequantize_val(uint16_t); +typed_dequantize_val(int32_t); +#undef typed_dequantize_val + +#define typed_dequantize_vec(dtype) \ + template void dequantize( \ + float* __restrict__ y, \ + const dtype* __restrict__ x, \ + float scale, \ + int32_t zero_point, \ + size_t size); +typed_dequantize_vec(int8_t); +typed_dequantize_vec(uint8_t); +typed_dequantize_vec(int16_t); +typed_dequantize_vec(uint16_t); +typed_dequantize_vec(int32_t); +#undef typed_dequantize_vec + +#define typed_requantize_val(itype, otype) \ + template otype requantize( \ + const itype in, \ + float in_scale, \ + int32_t in_zero_point, \ + float inv_out_scale, \ + int32_t out_zero_point); +typed_requantize_val(int8_t, int8_t); +typed_requantize_val(int8_t, uint8_t); +typed_requantize_val(int8_t, int16_t); +typed_requantize_val(int8_t, uint16_t); +typed_requantize_val(uint8_t, int8_t); +typed_requantize_val(uint8_t, uint8_t); +typed_requantize_val(uint8_t, int16_t); +typed_requantize_val(uint8_t, uint16_t); +typed_requantize_val(int16_t, int8_t); +typed_requantize_val(int16_t, uint8_t); +typed_requantize_val(int16_t, int16_t); +typed_requantize_val(int16_t, uint16_t); +typed_requantize_val(uint16_t, int8_t); +typed_requantize_val(uint16_t, uint8_t); +typed_requantize_val(uint16_t, int16_t); +typed_requantize_val(uint16_t, uint16_t); +#undef typed_requantize_val + +#define typed_requantize_vec(itype, otype) \ + template void requantize( \ + otype* __restrict__ out, \ + const itype* __restrict__ in, \ + float in_scale, \ + int32_t in_zero_point, \ + float inv_out_scale, \ + int32_t out_zero_point, \ + size_t size); +typed_requantize_vec(int8_t, int8_t); +typed_requantize_vec(int8_t, uint8_t); +typed_requantize_vec(int8_t, int16_t); +typed_requantize_vec(int8_t, uint16_t); +typed_requantize_vec(uint8_t, int8_t); +typed_requantize_vec(uint8_t, uint8_t); +typed_requantize_vec(uint8_t, int16_t); +typed_requantize_vec(uint8_t, uint16_t); +typed_requantize_vec(int16_t, int8_t); +typed_requantize_vec(int16_t, uint8_t); +typed_requantize_vec(int16_t, int16_t); +typed_requantize_vec(int16_t, uint16_t); +typed_requantize_vec(uint16_t, int8_t); +typed_requantize_vec(uint16_t, uint8_t); +typed_requantize_vec(uint16_t, int16_t); +typed_requantize_vec(uint16_t, uint16_t); +#undef typed_requantize_vec + +}; // namespace 
kernels +}; // namespace vision +}; // namespace impl diff --git a/backends/cadence/vision/kernels/kernels.h b/backends/cadence/vision/kernels/kernels.h new file mode 100644 index 00000000000..e86a36515ec --- /dev/null +++ b/backends/cadence/vision/kernels/kernels.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include "inttypes.h" +#include "stddef.h" + +using executorch::runtime::KernelRuntimeContext; +using executorch::runtime::Result; + +namespace impl { +namespace vision { +namespace kernels { + +void* allocate_temp_memory(KernelRuntimeContext& ctx, size_t size); + +template +T quantize(const float x, float scale, int32_t zero_point); + +template +float dequantize(const T x, float scale, int32_t zero_point); + +template +void quantize( + T* __restrict__ y, + const float* __restrict__ x, + float scale, + int32_t zero_point, + size_t size); + +// Deuantize an int8_t/uint8_t/int16_t array to an fp32 array +template +void dequantize( + float* __restrict__ y, + const T* __restrict__ x, + float scale, + int32_t zero_point, + size_t size); + +template +OT requantize( + const IT in, + float in_scale, + int32_t in_zero_point, + float inv_out_scale, + int32_t out_zero_point); + +template +void requantize( + OT* __restrict__ out, + const IT* __restrict__ in, + float in_scale, + int32_t in_zero_point, + float inv_out_scale, + int32_t out_zero_point, + size_t size); + +}; // namespace kernels +}; // namespace vision +}; // namespace impl diff --git a/backends/cadence/vision/kernels/targets.bzl b/backends/cadence/vision/kernels/targets.bzl new file mode 100644 index 00000000000..02136c872b3 --- /dev/null +++ b/backends/cadence/vision/kernels/targets.bzl @@ -0,0 +1,25 @@ +load("@fbsource//tools/build_defs:platform_defs.bzl", "CXX") +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + runtime.cxx_library( + name = "cadence_kernels", + srcs = ["kernels.cpp"], + exported_headers = [ + "kernels.h", + ], + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + platforms = CXX, + compatible_with = select({ + "DEFAULT": [], + "ovr_config//cpu:xtensa": ["ovr_config//cpu:xtensa"], + }), + define_static_target = True, + deps = [ + "//executorch/backends/cadence/vision/third-party:vision-nnlib", + "//executorch/runtime/kernel:kernel_includes", + ], + ) diff --git a/backends/cadence/vision/operators/CMakeLists.txt b/backends/cadence/vision/operators/CMakeLists.txt new file mode 100644 index 00000000000..38e4f97f841 --- /dev/null +++ b/backends/cadence/vision/operators/CMakeLists.txt @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.19) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +if(NOT CMAKE_CXX_STANDARD) + set(CMAKE_CXX_STANDARD 17) +endif() + +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) +include(${EXECUTORCH_ROOT}/tools/cmake/Codegen.cmake) + +if(NOT PYTHON_EXECUTABLE) + resolve_python_executable() +endif() + +# ATen compliant ops that are needed to run this model. 
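+# The first few entries are the vision-specific implementations in this
+# directory; the rest reuse the portable CPU kernels and utility sources.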
+set(_aten_ops__srcs + "${CMAKE_CURRENT_SOURCE_DIR}/op_add.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_embedding.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_full.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_view_copy.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_softmax.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/activation_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/copy_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/broadcast_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/dtype_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/index_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/kernel_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/matmul_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/reduce_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/repeat_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/slice_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_floathbf16.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_cat.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_div.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_hardtanh.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_max_pool2d_with_indices.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mean.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mul.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_permute_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_rsqrt.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sigmoid.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_slice_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_split_with_sizes_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sub.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_to_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_where.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_expand_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_gelu.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_empty.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_transpose_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_eq.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_logical_not.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_any.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_native_group_norm.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sum.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_select_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/delinearize_index.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/dtype_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/normalization_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/select_copy_util.cpp" +) +add_library(aten_ops_cadence ${_aten_ops__srcs}) +target_link_libraries(aten_ops_cadence PUBLIC executorch) +target_link_libraries(aten_ops_cadence PRIVATE cadence_kernels) + +# Let files say "include ". +set(_common_include_directories + ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10 +) + +target_include_directories( + aten_ops_cadence + PUBLIC ${ROOT_DIR}/.. ${CMAKE_BINARY_DIR} ${_common_include_directories} + ${CMAKE_CURRENT_SOURCE_DIR}/../third-party +) + +# Custom ops that are needed to run the test model. 
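+# These Cadence-specific quantized operators are registered through
+# functions_vision.yaml below.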
+add_library( + custom_ops + "op_quantized_linear_out.cpp" + "op_quantized_conv_out.cpp" + "op_quantized_relu_out.cpp" + "op_quantized_layer_norm.cpp" + "op_quantize_per_tensor.cpp" + "op_quantized_fully_connected_out.cpp" + "op_dequantize_per_tensor.cpp" + "op_quantized_matmul_out.cpp" + "op_requantize_out.cpp" + "op_im2row_out.cpp" +) +target_include_directories( + custom_ops PUBLIC ${ROOT_DIR}/.. ${CMAKE_BINARY_DIR} + ${_common_include_directories} +) + +target_link_libraries(custom_ops PUBLIC executorch) +target_link_libraries(custom_ops PRIVATE cadence_kernels) + +# Generate C++ bindings to register kernels into both PyTorch (for AOT) and +# Executorch (for runtime). Here select all ops in functions_vision.yaml +gen_selected_ops( + LIB_NAME "cadence_ops_lib" OPS_SCHEMA_YAML + "${CMAKE_CURRENT_LIST_DIR}/../../aot/functions_vision.yaml" "" "" +) +generate_bindings_for_kernels( + LIB_NAME "cadence_ops_lib" OPS_SCHEMA_YAML FUNCTIONS_YAML + ${CMAKE_CURRENT_SOURCE_DIR}/../../aot/functions_vision.yaml +) +message("Generated cadence x86 files ${gen_command_sources}") + +gen_operators_lib( + LIB_NAME "cadence_ops_lib" KERNEL_LIBS custom_ops DEPS aten_ops_cadence +) + +# Link custom_ops to the generated library to ensure the symbols are available +target_link_libraries(cadence_ops_lib PUBLIC custom_ops) diff --git a/backends/cadence/vision/operators/op_add.cpp b/backends/cadence/vision/operators/op_add.cpp new file mode 100644 index 00000000000..81014143275 --- /dev/null +++ b/backends/cadence/vision/operators/op_add.cpp @@ -0,0 +1,75 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include + +using executorch::aten::Scalar; +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using executorch::runtime::canCast; +using executorch::runtime::KernelRuntimeContext; +using executorch::runtime::promoteTypes; +using torch::executor::apply_binary_elementwise_fn; +using torch::executor::Error; +using torch::executor::native::utils::extract_scalar; + +namespace impl { +namespace vision { +namespace native { + +Tensor& add_out( + KernelRuntimeContext& ctx, + const Tensor& a, + const Tensor& b, + const Scalar& alpha, + Tensor& out) { + (void)ctx; + + using namespace torch::executor::native::utils; + + ScalarType a_type = a.scalar_type(); + ScalarType b_type = b.scalar_type(); + ScalarType common_type = promoteTypes(a_type, b_type); + ScalarType out_type = out.scalar_type(); + + ET_CHECK_MSG(a_type == ScalarType::Float, "Input tensor not a float.\n"); + ET_CHECK_MSG(b_type == ScalarType::Float, "Input tensor not a float.\n"); + ET_CHECK_MSG(out_type == ScalarType::Float, "Output tensor not a float.\n"); + + ET_CHECK(canCast(common_type, out_type)); + + using CTYPE_A = float; + using CTYPE_B = float; + using CTYPE_IN = float; + using CTYPE_OUT = float; + CTYPE_IN alpha_val; + ET_CHECK_MSG( + extract_scalar(alpha, &alpha_val), + "Could not be extracted: wrong type or out of range"); + + apply_binary_elementwise_fn( + [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) { + CTYPE_IN a_casted = static_cast(val_a); + CTYPE_IN b_casted = static_cast(val_b); + CTYPE_IN value = a_casted + alpha_val * b_casted; + + return static_cast(value); + }, + a, + b, + out); + + return out; +} + +} // namespace native +} // namespace vision +} // namespace impl diff --git a/backends/cadence/vision/operators/op_dequantize_per_tensor.cpp b/backends/cadence/vision/operators/op_dequantize_per_tensor.cpp new file mode 100644 index 00000000000..daffecda1bf --- /dev/null +++ b/backends/cadence/vision/operators/op_dequantize_per_tensor.cpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include + +namespace impl { +namespace vision { +namespace native { + +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using executorch::runtime::KernelRuntimeContext; + +void dequantize_per_tensor_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + float* out_data = out.mutable_data_ptr(); + size_t numel = out.numel(); + + if (input.scalar_type() == ScalarType::Byte) { + const uint8_t* input_data = input.const_data_ptr(); + kernels::dequantize( + out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Char) { + const int8_t* input_data = input.const_data_ptr(); + kernels::dequantize(out_data, input_data, scale, zero_point, numel); + } else if ( + input.scalar_type() == ScalarType::Bits16 || + input.scalar_type() == ScalarType::UInt16) { + const uint16_t* input_data = input.const_data_ptr(); + kernels::dequantize( + out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Short) { + const int16_t* input_data = input.const_data_ptr(); + kernels::dequantize( + out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Int) { + const int32_t* input_data = input.const_data_ptr(); + kernels::dequantize( + out_data, input_data, scale, zero_point, numel); + } else { + ET_CHECK_MSG( + false, + "Unhandled input dtype %hhd", + static_cast(input.scalar_type())); + } +} + +}; // namespace native +}; // namespace vision +}; // namespace impl diff --git a/backends/cadence/reference/operators/op_embedding.cpp b/backends/cadence/vision/operators/op_embedding.cpp similarity index 92% rename from backends/cadence/reference/operators/op_embedding.cpp rename to backends/cadence/vision/operators/op_embedding.cpp index ce28789a156..5273cb083e8 100644 --- a/backends/cadence/reference/operators/op_embedding.cpp +++ b/backends/cadence/vision/operators/op_embedding.cpp @@ -8,13 +8,13 @@ #include -namespace torch { -namespace executor { -namespace native { - using executorch::aten::Tensor; using executorch::runtime::KernelRuntimeContext; +namespace impl { +namespace vision { +namespace native { + void embedding_out( KernelRuntimeContext& ctx, const Tensor& weight, @@ -37,5 +37,5 @@ void embedding_out( } } // namespace native -} // namespace executor -} // namespace torch +} // namespace vision +} // namespace impl diff --git a/backends/cadence/vision/operators/op_full.cpp b/backends/cadence/vision/operators/op_full.cpp new file mode 100644 index 00000000000..afc29718a2b --- /dev/null +++ b/backends/cadence/vision/operators/op_full.cpp @@ -0,0 +1,58 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include + +using executorch::aten::IntArrayRef; +using executorch::aten::Scalar; +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using executorch::runtime::KernelRuntimeContext; +using torch::executor::Error; +using torch::executor::native::utils::extract_scalar; +using torch::executor::native::utils::get_scalar_dtype; + +namespace impl { +namespace vision { +namespace native { + +Tensor& full_out( + KernelRuntimeContext& ctx, + const IntArrayRef sizes, + const Scalar& fill_value, + Tensor& out) { + (void)ctx; + + ScalarType val_type = get_scalar_dtype(fill_value); + ScalarType out_type = out.scalar_type(); + + Error err = resize_tensor(out, sizes); + ET_CHECK_MSG(err == Error::Ok, "Could not resize out"); + + ET_SWITCH_REAL_TYPES_AND(Bool, val_type, ctx, "full", CTYPE_VAL, [&] { + CTYPE_VAL val; + ET_CHECK_MSG( + extract_scalar(fill_value, &val), + "Could not be extracted: wrong type or out of range"); + + ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "full", CTYPE_OUT, [&] { + CTYPE_OUT val_casted = static_cast(val); + auto data_out = out.mutable_data_ptr(); + for (size_t i = 0; i < out.numel(); ++i) { + data_out[i] = val_casted; + } + }); + }); + + return out; +} + +} // namespace native +} // namespace vision +} // namespace impl diff --git a/backends/cadence/vision/operators/op_im2row_out.cpp b/backends/cadence/vision/operators/op_im2row_out.cpp new file mode 100644 index 00000000000..501f8ce5376 --- /dev/null +++ b/backends/cadence/vision/operators/op_im2row_out.cpp @@ -0,0 +1,298 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +namespace impl { +namespace vision { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +template +__attribute__((always_inline)) void im2row_( + const T* __restrict__ data_im, + const int32_t in_zero_point, + /* input parameters*/ + const int32_t channels, + const int32_t height, + const int32_t width, + /* output parameters */ + const int32_t out_height, + const int32_t out_width, + /* convolution parameters */ + const int32_t kernel_h, + const int32_t kernel_w, + const int32_t pad_h, + const int32_t pad_w, + const int32_t stride_h, + const int32_t stride_w, + const int32_t dilation_h, + const int32_t dilation_w, + T* __restrict__ data_col, + bool channels_last) { + // Consider convolving the input image of dimensions channels * height * width + // (or height * width * channels for NHWC layout) with a filter of dimensions + // channels * kernels_h * kernels_w. Assume that this convolution will produce + // an output of dimensinos out_height x out_width. For each point the output, + // im2row takes the data from the input that is used in the computation of + // that output point, and flattens it into a vector of size channels_col = + // channels * kernel_h * kernel_w. The output of im2row will therefore be a 2D + // array of size (out_height * out_width) x channels_col + const int32_t channels_col = channels * kernel_h * kernel_w; + + // If the layout is NHWC, we can copy 'channels' worth of contiguous data + // points when performing im2row. 
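+  // Illustrative sizing (arbitrary numbers): channels = 3, height = width = 4,
+  // a 2x2 kernel with stride 1 and no padding gives out_height = out_width = 3,
+  // so data_col holds 3 * 3 = 9 rows of channels_col = 3 * 2 * 2 = 12 entries.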
+ if (channels_last) { + // Iterate over the output domain + for (int _h = 0; _h < out_height; ++_h) { + for (int _w = 0; _w < out_width; ++_w) { + int32_t i_col = _h * out_width + _w; + // Each point in the output domain is the result of applying a filter of + // size kernel_h x kernel_w x channels on the input. But since channels + // is contiguous, we will not explicitly have a loop for it. + for (int _kh = 0; _kh < kernel_h; ++_kh) { + int32_t h_im = _h * stride_h - pad_h + _kh * dilation_h; + for (int _kw = 0; _kw < kernel_w; ++_kw) { + int32_t w_im = _w * stride_w - pad_w + _kw * dilation_w; + + // h_im and w_im are the actual height and width coordinates of the + // input tensor from where we need to copy 'channels' points. + const T* __restrict__ slice_im = + data_im + (h_im * width + w_im) * channels; + T* __restrict__ slice_col = data_col + i_col * channels_col + + (_kh * kernel_w + _kw) * channels; + // If the coordinates were within the input domain, we copy + // 'channels' contiguous values. Otherwise we will fill the output + // with 0's. + if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + std::memcpy(slice_col, slice_im, channels * sizeof(T)); + } else { + std::fill_n(slice_col, channels, T(in_zero_point)); + } + } + } + } + } + } else { + // Iterate over the output domain + for (int _h = 0; _h < out_height; ++_h) { + for (int _w = 0; _w < out_width; ++_w) { + int32_t i_col = _h * out_width + _w; + + // Each point in the output domain is the result of applying a filter + // of size chanenls * kernel_h x kernel_w on the input + for (int _c = 0; _c < channels; ++_c) { + for (int _kh = 0; _kh < kernel_h; ++_kh) { + for (int _kw = 0; _kw < kernel_w; ++_kw) { + // c_col is the linearized access in the channels_col vector. + int32_t c_col = (_c * kernel_h + _kh) * kernel_w + _kw; + // h_im and w_im are the actual height and width coordinates of + // the input tensor that we need to copy to the output. + int32_t h_im = _h * stride_h - pad_h + _kh * dilation_h; + int32_t w_im = _w * stride_w - pad_w + _kw * dilation_w; + // If the current data access is within the input tensor, copy the + // value + data_col[i_col * channels_col + c_col] = + (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) + ? data_im[(_c * height + h_im) * width + w_im] + : static_cast(in_zero_point); + } + } + } + } + } + } +} + +void im2row_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride, + const Tensor& in_zero_point, + bool channel_last, + Tensor& out) { + // Compute the input tensor's dims + bool unit_height = input.dim() == 3; + const int32_t batch_size = input.size(0); + const int32_t in_c = + channel_last ? input.size(3 - unit_height) : input.size(1); + const int32_t in_h = + unit_height ? 1 : (channel_last ? input.size(1) : input.size(2)); + const int32_t in_w = + channel_last ? input.size(2 - unit_height) : input.size(3 - unit_height); + + // Get the kernel parameters + int32_t kernel_h = kernel_size[0]; + int32_t kernel_w = kernel_size[1]; + int32_t dilation_h = dilation[0]; + int32_t dilation_w = dilation[1]; + int32_t pad_h = padding[0]; + int32_t pad_w = padding[1]; + int32_t stride_h = stride[0]; + int32_t stride_w = stride[1]; + + // If we were to apply a convolution on the input tensor, compute the output + // height and width. 
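+  // That is, out = (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;
+  // e.g. in = 5, pad = 1, kernel = 3, dilation = 1, stride = 2 gives
+  // (5 + 2 - 2 - 1) / 2 + 1 = 3.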
+ int32_t out_h = + (in_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) / stride_h + 1; + int32_t out_w = + (in_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) / stride_w + 1; + + ET_DCHECK_MSG( + (out_h * out_w) == out.size(1), "dimension mismatch for output"); + ET_DCHECK_MSG( + (kernel_h * kernel_w * in_c) == out.size(2), + "dimension mismatch for output"); + + // Check if the input is per-tensor quantized or per-channel quantized. The + // zero point for each batch could differ for per-channel quantized input. + bool per_tensor_quantized = in_zero_point.numel() == 1; + +#define typed_im2row(dtype, ctype) \ + case ScalarType::dtype: { \ + const ctype* __restrict__ in_data = input.const_data_ptr(); \ + ctype* __restrict__ out_data = out.mutable_data_ptr(); \ + const int32_t* __restrict__ zero_point = \ + in_zero_point.const_data_ptr(); \ + int32_t in_plane = in_c * in_h * in_w; \ + int32_t out_plane = kernel_h * kernel_w * in_c * out_h * out_w; \ + for (size_t n = 0; n < batch_size; ++n) { \ + im2row_( \ + &in_data[n * in_plane], \ + per_tensor_quantized ? zero_point[0] : zero_point[n], \ + in_c, \ + in_h, \ + in_w, \ + out_h, \ + out_w, \ + kernel_h, \ + kernel_w, \ + pad_h, \ + pad_w, \ + stride_h, \ + stride_w, \ + dilation_h, \ + dilation_w, \ + &out_data[n * out_plane], \ + channel_last); \ + } \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + typed_im2row(Float, float); + typed_im2row(Byte, uint8_t); + typed_im2row(Char, int8_t); + default: + ET_DCHECK_MSG( + false, + "im2row not implemented for dtype %s", + torch::executor::toString(dtype)); + } +#undef typed_im2row +} + +void im2row_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride, + int64_t in_zero_point, + bool channel_last, + Tensor& out) { + // Compute the input tensor's dims + bool unit_height = input.dim() == 3; + const int32_t batch_size = input.size(0); + const int32_t in_c = + channel_last ? input.size(3 - unit_height) : input.size(1); + const int32_t in_h = + unit_height ? 1 : (channel_last ? input.size(1) : input.size(2)); + const int32_t in_w = + channel_last ? input.size(2 - unit_height) : input.size(3 - unit_height); + + // Get the kernel parameters + int32_t kernel_h = kernel_size[0]; + int32_t kernel_w = kernel_size[1]; + int32_t dilation_h = dilation[0]; + int32_t dilation_w = dilation[1]; + int32_t pad_h = padding[0]; + int32_t pad_w = padding[1]; + int32_t stride_h = stride[0]; + int32_t stride_w = stride[1]; + + // If we were to apply a convolution on the input tensor, compute the output + // height and width. 
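+  // Same output-extent computation as im2row_out above; this per-tensor variant
+  // only differs in taking a single scalar zero point instead of a per-batch
+  // tensor.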
+ int32_t out_h = + (in_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) / stride_h + 1; + int32_t out_w = + (in_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) / stride_w + 1; + + ET_DCHECK_MSG( + (out_h * out_w) == out.size(1), "dimension mismatch for output"); + ET_DCHECK_MSG( + (kernel_h * kernel_w * in_c) == out.size(2), + "dimension mismatch for output"); + +#define typed_im2row_per_tensor(dtype, ctype) \ + case ScalarType::dtype: { \ + const ctype* __restrict__ in_data = input.const_data_ptr(); \ + ctype* __restrict__ out_data = out.mutable_data_ptr(); \ + int32_t in_plane = in_c * in_h * in_w; \ + int32_t out_plane = kernel_h * kernel_w * in_c * out_h * out_w; \ + for (size_t n = 0; n < batch_size; ++n) { \ + im2row_( \ + &in_data[n * in_plane], \ + in_zero_point, \ + in_c, \ + in_h, \ + in_w, \ + out_h, \ + out_w, \ + kernel_h, \ + kernel_w, \ + pad_h, \ + pad_w, \ + stride_h, \ + stride_w, \ + dilation_h, \ + dilation_w, \ + &out_data[n * out_plane], \ + channel_last); \ + } \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + typed_im2row_per_tensor(Float, float); + typed_im2row_per_tensor(Byte, uint8_t); + typed_im2row_per_tensor(Char, int8_t); + default: + ET_DCHECK_MSG( + false, + "im2row.per_tensor not implemented for dtype %s", + torch::executor::toString(dtype)); + } +#undef typed_im2row_per_tensor +} + +} // namespace native +} // namespace vision +} // namespace impl diff --git a/backends/cadence/vision/operators/op_quantize_per_tensor.cpp b/backends/cadence/vision/operators/op_quantize_per_tensor.cpp new file mode 100644 index 00000000000..cd72d2de2b5 --- /dev/null +++ b/backends/cadence/vision/operators/op_quantize_per_tensor.cpp @@ -0,0 +1,66 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace impl { +namespace vision { +namespace native { + +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using executorch::runtime::KernelRuntimeContext; + +// Quantize the input tensor (PT2 version). Note that quant_ are not +// used in any computation. +void quantize_per_tensor_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + const float* input_data = input.const_data_ptr(); + size_t numel = out.numel(); + + if (out.scalar_type() == ScalarType::Byte) { + uint8_t* out_data = out.mutable_data_ptr(); + kernels::quantize( + out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Char) { + int8_t* out_data = out.mutable_data_ptr(); + kernels::quantize( + out_data, input_data, 1. / scale, zero_point, numel); + } else if ( + out.scalar_type() == ScalarType::Bits16 || + out.scalar_type() == ScalarType::UInt16) { + uint16_t* out_data = out.mutable_data_ptr(); + kernels::quantize( + out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Short) { + int16_t* out_data = out.mutable_data_ptr(); + kernels::quantize( + out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Int) { + int32_t* out_data = out.mutable_data_ptr(); + kernels::quantize( + out_data, input_data, 1. 
/ scale, zero_point, numel); + } else { + ET_CHECK_MSG( + false, + "Unhandled input dtype %hhd", + static_cast(out.scalar_type())); + } +} + +}; // namespace native +}; // namespace vision +}; // namespace impl diff --git a/backends/cadence/vision/operators/op_quantized_conv_out.cpp b/backends/cadence/vision/operators/op_quantized_conv_out.cpp new file mode 100644 index 00000000000..b632f0931c2 --- /dev/null +++ b/backends/cadence/vision/operators/op_quantized_conv_out.cpp @@ -0,0 +1,682 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace impl { +namespace vision { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +// This implements a generic 2d conv kernel that operates on raw pointers. +// The version handles both quantized and fp32 convolutions. +// The input is of shape [n x c x h x w] +// The weight is of shape [oc x wc x wh x ww], where wc == c +// The output is of shape [n x oc x oh x ow] +// The bias is of shape [oc] +template < + typename IT = float, + typename WT = IT, + typename BT = IT, + typename OT = IT, + bool quantized = false> +__attribute__((noinline)) void conv2d_nchw_core_generic( + // All the arrays + const IT* __restrict__ p_in, + const WT* __restrict__ p_weight, + const BT* __restrict__ p_bias, + OT* __restrict__ p_out, + // The array sizes + int32_t n, + int32_t c, + int32_t h, + int32_t w, + int32_t oc, + int32_t wc, + int32_t wh, + int32_t ww, + int32_t oh, + int32_t ow, + // Stride + int16_t s0, + int16_t s1, + // Padding + int16_t p0, + int16_t p1, + // Dilation + int16_t d0, + int16_t d1, + // Group for depthwise conv + int16_t groups, + // Optional args that are only relevant for quantized convolution + // input zero point + IT in_zero_point = 0, + // weight zero point + int32_t weight_zero_point = 0, + float bias_scale = 1, + float out_scale = 1, + OT out_zero_point = 0) { + float inv_out_scale = 1. / out_scale; + bool zero_pad_unit_dilation = d0 == 1 && d1 == 1 && p0 == 0 && p1 == 0; + + // Compute the number of in and out channels per group + const int ocpg = oc / groups; + const int icpg = c / groups; + + // Iterate over all the output batches (i.e., n) + for (int _n = 0; _n < n; ++_n) { + const IT* in_batch = p_in + _n * c * h * w; + OT* out_batch = p_out + _n * oc * oh * ow; + // Compute separable convolution for each group + for (int _g = 0; _g < groups; ++_g) { + // Identify the input and output channels involved in the computation + // of this group + int sic = _g * icpg; + int soc = _g * ocpg; + // Populate all the output channels in the group + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + OT* out_plane = out_batch + _oc * oh * ow; + const WT* weight_batch = p_weight + _oc * wc * wh * ww; + // We compute one output channel at a time. The computation can be + // thought of as a stencil computation: we iterate over an input of size + // icpg x h x w, with a stencil of size icpg x wh x ww, to compute an + // output channel of size 1 x oh x ow. 
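// As a concrete illustration of the stencil described above (hypothetical
// sizes): with icpg = 3 input channels per group and a 3x3 kernel, each
// output element accumulates 3 * 3 * 3 = 27 products. In the quantized path
// both operands have their zero points subtracted before the
// multiply-accumulate, so `acc` holds the offset-corrected integer sum that
// is later scaled by bias_scale and requantized.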
+ for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { + for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { + float acc = p_bias[_oc]; + // Below is the stencil computation that performs the hadamard + // product+accumulation of each input channel (contributing to the + // output channel being computed) with the corresponding weight + // channel. + // If the padding is 0, and dilation is 1, then we can remove the + // unnecessary checks, and simplify the code so that it can be + // vectorized by Tensilica compiler. + if (zero_pad_unit_dilation) { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const IT* in_plane = in_batch + _ic * h * w; + const WT* weight_plane = weight_batch + (_ic - sic) * wh * ww; + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + int ioff = (_h + _wh) * w + (_w + _ww); + int woff = _wh * ww + _ww; + float lhs = in_plane[ioff] - in_zero_point; + float rhs = weight_plane[woff] - + (quantized ? weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } else { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const IT* in_plane = in_batch + _ic * h * w; + const WT* weight_plane = weight_batch + (_ic - sic) * wh * ww; + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + if (((_h + d0 * _wh - p0) >= 0) && + ((_h + d0 * _wh - p0) < h) && + ((_w + d1 * _ww - p1) >= 0) && + ((_w + d1 * _ww - p1) < w)) { + int ioff = + (_h + d0 * _wh - p0) * w + (_w + d1 * _ww - p1); + int woff = _wh * ww + _ww; + float lhs = in_plane[ioff] - in_zero_point; + float rhs = weight_plane[woff] - + (quantized ? weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } + } + if (quantized) { + float val = bias_scale * acc; + out_plane[_oh * ow + _ow] = + kernels::quantize(val, inv_out_scale, out_zero_point); + } else { + out_plane[_oh * ow + _ow] = acc; + } + } + } + } + } + } +} + +template < + typename IT = float, + typename WT = IT, + typename BT = IT, + typename OT = IT, + bool quantized = false> +__attribute__((noinline)) void conv2d_nhwc_core_generic( + // All the arrays + const IT* __restrict__ p_in, + const WT* __restrict__ p_weight, + const BT* __restrict__ p_bias, + OT* __restrict__ p_out, + // The array sizes + int32_t n, + int32_t h, + int32_t w, + int32_t c, + int32_t oc, + int32_t wh, + int32_t ww, + int32_t wc, + int32_t oh, + int32_t ow, + // Stride + int16_t s0, + int16_t s1, + // Padding + int16_t p0, + int16_t p1, + // Dilation + int16_t d0, + int16_t d1, + // Group for depthwise conv + int16_t groups, + // Optional args that are only relevant for quantized convolution + // input zero point + IT in_zero_point = 0, + // weight zero point + int32_t weight_zero_point = 0, + float bias_scale = 1, + float out_scale = 1, + OT out_zero_point = 0) { + float inv_out_scale = 1. 
/ out_scale; + bool zero_pad_unit_dilation = d0 == 1 && d1 == 1 && p0 == 0 && p1 == 0; + + // Compute the number of in and out channels per group + const int ocpg = oc / groups; + const int icpg = c / groups; + + // Iterate over all the output batches (i.e., n) + for (int _n = 0; _n < n; ++_n) { + const IT* in_batch = p_in + _n * h * w * c; + OT* out_batch = p_out + _n * oh * ow * oc; + for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { + for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { + OT* out_line = out_batch + (_oh * ow + _ow) * oc; + // Compute separable convolution for each group + for (int _g = 0; _g < groups; ++_g) { + // Identify the input and output channels involved in the computation + // of this group + int sic = _g * icpg; + int soc = _g * ocpg; + // Populate all the output channels in the group + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + const WT* weight_batch = p_weight + _oc * wh * ww * wc; + // We compute one output channel at a time. The computation can be + // thought of as a stencil computation: we iterate over an input of + // size h x w x icpg, with a stencil of size wh x ww x icpg, to + // compute an output channel of size oh x ow x 1. + float acc = p_bias[_oc]; + // Below is the stencil computation that performs the hadamard + // product+accumulation of each input channel (contributing to + // the output channel being computed) with the corresponding + // weight channel. If the padding is 0, and dilation is 1, then + // we can remove the unnecessary checks, and simplify the code + // so that it can be vectorized by Tensilica compiler.x`` + if (zero_pad_unit_dilation) { + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + const IT* in_line = + in_batch + (_h + _wh) * w * c + (_w + _ww) * c; + const WT* weight_line = + weight_batch + _wh * ww * wc + _ww * wc; + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + float lhs = in_line[_ic] - in_zero_point; + float rhs = weight_line[_ic - sic] - + (quantized ? weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } else { + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + if (((_h + d0 * _wh - p0) >= 0) && + ((_h + d0 * _wh - p0) < h) && + ((_w + d1 * _ww - p1) >= 0) && + ((_w + d1 * _ww - p1 < w))) { + const IT* in_line = in_batch + + (_h + d0 * _wh - p0) * w * c + (_w + d1 * _ww - p1) * c; + const WT* weight_line = + weight_batch + _wh * ww * wc + _ww * wc; + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + float lhs = in_line[_ic] - in_zero_point; + float rhs = weight_line[_ic - sic] - + (quantized ? weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } + } + if (quantized) { + float val = bias_scale * acc; + out_line[_oc] = + kernels::quantize(val, inv_out_scale, out_zero_point); + } else { + out_line[_oc] = acc; + } + } + } + } + } + } +} + +// The quantized convolution kernel. in_scale and weight_scale are implicit in +// bias_scale, since it is a product of the two. The kernel will branch to +// quantized::conv1d or quantized::conv2d based on the dimensionality of +// activation tensor. 
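// A hedged sketch of the requantization the kernels below perform. The
// accumulator starts at the bias value and collects zero-point-corrected
// integer products, all assumed to live in the bias_scale =
// in_scale * weight_scale domain (per the note above). The result is brought
// back to real values and requantized in one step:
//   float val = bias_scale * acc;
//   out = kernels::quantize(val, 1.0f / out_scale, out_zero_point);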
+void quantized_conv_nchw( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + bool conv1d = input.dim() == 3; + // input = [n, c, h, w] + const int n = input.size(0); + const int c = input.size(1); + const int h = conv1d ? 1 : input.size(2); + const int w = conv1d ? input.size(2) : input.size(3); + // weight = [oc, wc, wh, ww] + const int oc = weight.size(0); + const int wc = weight.size(1); + const int wh = conv1d ? 1 : weight.size(2); + const int ww = conv1d ? weight.size(2) : weight.size(3); + // output = [n, oc, oh, ow] + const int oh = conv1d ? 1 : out.size(2); + const int ow = conv1d ? out.size(2) : out.size(3); + +#define typed_quantized_conv2d_nchw(ctype, dtype) \ + case ScalarType::dtype: { \ + conv2d_nchw_core_generic( \ + input.const_data_ptr(), \ + weight.const_data_ptr(), \ + bias.const_data_ptr(), \ + out.mutable_data_ptr(), \ + n, \ + c, \ + h, \ + w, \ + oc, \ + wc, \ + wh, \ + ww, \ + oh, \ + ow, \ + stride[0], \ + stride[1], \ + padding[0], \ + padding[1], \ + dilation[0], \ + dilation[1], \ + groups, \ + in_zero_point, \ + weight_zero_point, \ + bias_scale, \ + output_scale, \ + (ctype)output_zero_point); \ + break; \ + } + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_conv2d_nchw); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_conv2d_nchw +} + +void quantized_conv_nhwc( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + bool conv1d = input.dim() == 3; + // input = [n, h, w, c] + const int n = input.size(0); + const int h = conv1d ? 1 : input.size(1); + const int w = conv1d ? input.size(1) : input.size(2); + const int c = conv1d ? input.size(2) : input.size(3); + // weight = [oc, wh, ww, wc] + const int oc = weight.size(0); + const int wh = conv1d ? 1 : weight.size(1); + const int ww = conv1d ? weight.size(1) : weight.size(2); + const int wc = conv1d ? weight.size(2) : weight.size(3); + // output = [n, oh, ow, oc] + const int oh = conv1d ? 1 : out.size(1); + const int ow = conv1d ? 
out.size(1) : out.size(2); + +#define typed_quantized_conv2d_nhwc(ctype, dtype) \ + case ScalarType::dtype: { \ + conv2d_nhwc_core_generic( \ + input.const_data_ptr(), \ + weight.const_data_ptr(), \ + bias.const_data_ptr(), \ + out.mutable_data_ptr(), \ + n, \ + h, \ + w, \ + c, \ + oc, \ + wh, \ + ww, \ + wc, \ + oh, \ + ow, \ + stride[0], \ + stride[1], \ + padding[0], \ + padding[1], \ + dilation[0], \ + dilation[1], \ + groups, \ + in_zero_point, \ + weight_zero_point, \ + bias_scale, \ + output_scale, \ + (ctype)output_zero_point); \ + break; \ + } + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_conv2d_nhwc); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_conv2d_nhwc +} + +void quantized_conv_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED const Tensor& out_multiplier, + __ET_UNUSED const Tensor& out_shift, + bool channel_last, + Tensor& out) { + const float bias_scale_float = bias_scale.const_data_ptr()[0]; + const int32_t weight_zero_point_int = + weight_zero_point.const_data_ptr()[0]; + if (channel_last) { + quantized_conv_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point_int, + bias_scale_float, + output_scale, + output_zero_point, + out); + } else { + quantized_conv_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point_int, + bias_scale_float, + output_scale, + output_zero_point, + out); + } +} + +void quantized_conv_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + bool channel_last, + Tensor& out) { + if (channel_last) { + quantized_conv_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + } else { + quantized_conv_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + } +} + +void quantized_conv2d_nchw_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out) { + quantized_conv_per_tensor_out( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + false, // channel_last = false for NCHW + out); +} + +void quantized_conv2d_nhwc_per_tensor_out( + 
KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out) { + quantized_conv_per_tensor_out( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + true, // channel_last = true for NHWC + out); +} + +void quantized_conv2d_nchw_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& bias_scale, + double output_scale, + int64_t output_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + Tensor& out) { + quantized_conv_out( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + false, // channel_last = false for NCHW + out); +} + +void quantized_conv2d_nhwc_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& bias_scale, + double output_scale, + int64_t output_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + Tensor& out) { + quantized_conv_out( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + true, // channel_last = true for NHWC + out); +} + +} // namespace native +} // namespace vision +} // namespace impl diff --git a/backends/cadence/vision/operators/op_quantized_fully_connected_out.cpp b/backends/cadence/vision/operators/op_quantized_fully_connected_out.cpp new file mode 100644 index 00000000000..29aa8906414 --- /dev/null +++ b/backends/cadence/vision/operators/op_quantized_fully_connected_out.cpp @@ -0,0 +1,97 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include +#include +#include + +namespace impl { +namespace vision { +namespace native { + +using ::executorch::aten::optional; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +void quantized_fully_connected_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& in, + const Tensor& weight, + const Tensor& bias, + int64_t in_zero_point, + const Tensor& weight_zero_point_t, + const Tensor& out_multiplier, + const Tensor& out_shift, + int64_t out_zero_point, + __ET_UNUSED const optional& offset, + Tensor& out) { +#define typed_quantized_linear(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_linear_( \ + in, \ + weight, \ + bias, \ + in_zero_point, \ + weight_zero_point_t, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_linear +} + +void quantized_fully_connected_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& in, + const Tensor& weight, + const Tensor& bias, + int64_t in_zero_point, + int64_t weight_zero_point, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + __ET_UNUSED const optional& offset, + Tensor& out) { +#define typed_quantized_linear(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_linear_per_tensor_( \ + in, \ + weight, \ + bias, \ + in_zero_point, \ + weight_zero_point, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_linear +} + +}; // namespace native +}; // namespace vision +}; // namespace impl diff --git a/backends/cadence/vision/operators/op_quantized_layer_norm.cpp b/backends/cadence/vision/operators/op_quantized_layer_norm.cpp new file mode 100644 index 00000000000..a9685eddedb --- /dev/null +++ b/backends/cadence/vision/operators/op_quantized_layer_norm.cpp @@ -0,0 +1,201 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::getLeadingDims; +using ::executorch::runtime::KernelRuntimeContext; + +namespace impl { +namespace vision { +namespace native { + +// Compute quantized layer_norm. The current implementation assumes that the +// input is per-tensor quantized. 
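// In effect, for each 1-d slice x along the last dimension the loops below
// compute (a sketch of the math, matching the code that follows):
//   x_f[j] = (x[j] - input_zero_point) * input_scale
//   y_f[j] = (x_f[j] - mean(x_f)) / sqrt(var(x_f) + eps) * weight[j] + bias[j]
//   y[j]   = quantize(y_f[j], 1 / output_scale, output_zero_point)
// with mean and variance derived from integer sums of x, so each element is
// dequantized only once, inside the final normalization loop.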
+template +void quantized_layer_norm_per_tensor_( + const Tensor& input, + double input_scale, + int64_t input_zero_point, + const Tensor& weight, + const Tensor& bias, + double eps, + double output_scale, + int64_t output_zero_point, + Tensor& out) { + // Get the raw pointers to input, output, weight, and bias + const T* __restrict__ in_data = input.const_data_ptr(); + T* __restrict__ out_data = out.mutable_data_ptr(); + const float* __restrict__ weight_data = weight.const_data_ptr(); + const float* __restrict__ bias_data = bias.const_data_ptr(); + + float output_inv_scale = 1.0f / output_scale; + + size_t last_dim = input.size(input.dim() - 1); + size_t leading_dims = getLeadingDims(input, input.dim() - 1); + + // Visualize the input tensor as a set of 1d vectors, and compute the + // layer_norm for each vector. + for (size_t i = 0; i < leading_dims; ++i) { + const T* x = in_data + i * last_dim; + T* y = out_data + i * last_dim; + + // compute sum and squared sum. The fp32 sum can be approximated as: + // (X_1 - in_zero_point) * in_scale + (X_2 - in_zero_point) * in_scale + ... + // (X_N - in_zero_point) * in_scale. + int32_t sum = 0; + int32_t sq_sum = last_dim * input_zero_point * input_zero_point; + for (size_t j = 0; j < last_dim; ++j) { + int32_t val = x[j]; + sum += val; + sq_sum += val * val; + } + sq_sum -= (2 * sum * input_zero_point); + sum -= (last_dim * input_zero_point); + + float mean = (input_scale * sum) / last_dim; + float variance = + (sq_sum * input_scale * input_scale) / last_dim - mean * mean; + float inv_std = 1.0f / std::sqrt(variance + eps); + + // y = (x - mean) / std * kGamma + kBeta + for (int j = 0; j < last_dim; ++j) { + // y[j] = (x[j] - mean) / std * kGamma + kBeta; + // Since X is quantized, we dequantize it, compute fp32 result, and + // quantize the result to an int8/uint8 value. + float val = kernels::dequantize(x[j], input_scale, input_zero_point); + + val = (val - mean) * inv_std * weight_data[j] + bias_data[j]; + y[j] = kernels::quantize(val, output_inv_scale, output_zero_point); + } + } +} + +// Compute quantized layer_norm. The current implementation assumes that the +// input is per-tensor quantized. +template +void quantized_layer_norm_( + const Tensor& input, + const Tensor& in_scale, + const Tensor& in_zero_point, + const Tensor& weight, + const Tensor& bias, + double eps, + double output_scale, + int64_t output_zero_point, + Tensor& out) { + // Extract the zero point and scale for input tensor. 
+ float input_scale = in_scale.const_data_ptr()[0]; + int64_t input_zero_point = in_zero_point.const_data_ptr()[0]; + + // Call other overload + quantized_layer_norm_per_tensor_( + input, + input_scale, + input_zero_point, + weight, + bias, + eps, + output_scale, + output_zero_point, + out); +} + +void quantized_layer_norm_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& in_scale, + const Tensor& in_zero_point, + __ET_UNUSED const executorch::aten::IntArrayRef normalized_shape, + const Tensor& weight, + const Tensor& bias, + double eps, + double output_scale, + int64_t output_zero_point, + Tensor& out) { + if (input.scalar_type() == executorch::aten::ScalarType::Byte) { + quantized_layer_norm_( + input, + in_scale, + in_zero_point, + weight, + bias, + eps, + output_scale, + output_zero_point, + out); + } else if (input.scalar_type() == executorch::aten::ScalarType::Char) { + quantized_layer_norm_( + input, + in_scale, + in_zero_point, + weight, + bias, + eps, + output_scale, + output_zero_point, + out); + } else { + ET_CHECK_MSG( + false, + "Unhandled input dtype %hhd", + static_cast(input.scalar_type())); + } +} + +void quantized_layer_norm_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + double in_scale, + int64_t in_zero_point, + __ET_UNUSED const executorch::aten::IntArrayRef normalized_shape, + const Tensor& weight, + const Tensor& bias, + double eps, + double output_scale, + int64_t output_zero_point, + Tensor& out) { + if (input.scalar_type() == executorch::aten::ScalarType::Byte) { + quantized_layer_norm_per_tensor_( + input, + in_scale, + in_zero_point, + weight, + bias, + eps, + output_scale, + output_zero_point, + out); + } else if (input.scalar_type() == executorch::aten::ScalarType::Char) { + quantized_layer_norm_per_tensor_( + input, + in_scale, + in_zero_point, + weight, + bias, + eps, + output_scale, + output_zero_point, + out); + } else { + ET_CHECK_MSG( + false, + "Unhandled input dtype %hhd", + static_cast(input.scalar_type())); + } +} + +}; // namespace native +}; // namespace vision +}; // namespace impl diff --git a/backends/cadence/vision/operators/op_quantized_linear_out.cpp b/backends/cadence/vision/operators/op_quantized_linear_out.cpp new file mode 100644 index 00000000000..b6b7cdd17bc --- /dev/null +++ b/backends/cadence/vision/operators/op_quantized_linear_out.cpp @@ -0,0 +1,159 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include + +namespace impl { +namespace vision { +namespace native { + +using executorch::aten::Tensor; +using executorch::runtime::getLeadingDims; +using executorch::runtime::KernelRuntimeContext; + +template +void inline _typed_quantized_linear( + const Tensor& src, + const Tensor& weight, + const Tensor& bias, + int64_t src_zero_point, + const Tensor& weight_zero_point_t, + const Tensor& out_multiplier, + const Tensor& out_shift, + int64_t out_zero_point, + Tensor& out) { + const T* __restrict__ src_data = src.const_data_ptr(); + const T* __restrict__ weight_data = weight.const_data_ptr(); + const int32_t* __restrict__ bias_data = bias.const_data_ptr(); + T* __restrict__ out_data = out.mutable_data_ptr(); + + int32_t weight_zero_point = weight_zero_point_t.const_data_ptr()[0]; + + // input comes in shape [batch_size, in_dim] + // weight comes in shape [out_dim, in_dim] + // output comes in empty with shape [batch_size, out_dim] + // Perform matrix multiply (M x N) x (N x P) => M x P + const auto M = weight.size(0); // = out_dim + const auto N = weight.size(1); // = in_dim + + // Given an N-dimensional input [d0, d1, d2, ..., d_{N-2}, d_{N-1}], the + // leading dimensions is d0 * d1 * ... * d_{N-2} + const auto leading_dims = getLeadingDims(src, src.dim() - 1); + + ET_CHECK_MSG( + out_multiplier.numel() == 1, "out_multiplier should have one element"); + ET_CHECK_MSG( + out_shift.numel() == 1, "out_multiplier should have one element"); + + const int32_t* __restrict__ out_multiplier_data = + out_multiplier.const_data_ptr(); + const int32_t* __restrict__ out_shift_data = + out_shift.const_data_ptr(); + + // Compute the out_scale from out_multiplier and out_shift + const float out_scale = + -out_multiplier_data[0] * 1.0 / (1 << 31) * pow(2, out_shift_data[0]); + + for (int i = 0; i < leading_dims; ++i) { + for (int j = 0; j < M; ++j) { + float sum = bias_data[j]; + for (int k = 0; k < N; ++k) { + sum += (src_data[i * N + k] - src_zero_point) * + (weight_data[j * N + k] - weight_zero_point); + } + out_data[i * M + j] = + kernels::quantize(sum, out_scale, out_zero_point); + } + } +} + +void quantized_linear_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& src, + const Tensor& weight, + const Tensor& bias, + int64_t src_zero_point, + const Tensor& weight_zero_point_t, + const Tensor& out_multiplier, + const Tensor& out_shift, + int64_t out_zero_point, + __ET_UNUSED const executorch::aten::optional& offset, + Tensor& out) { + // TODO: refactor to use switch case as quantized_linear_per_tensor_out + if (out.scalar_type() == executorch::aten::ScalarType::Byte) { + _typed_quantized_linear( + src, + weight, + bias, + src_zero_point, + weight_zero_point_t, + out_multiplier, + out_shift, + out_zero_point, + out); + } else if (out.scalar_type() == executorch::aten::ScalarType::Char) { + _typed_quantized_linear( + src, + weight, + bias, + src_zero_point, + weight_zero_point_t, + out_multiplier, + out_shift, + out_zero_point, + out); + } else { + ET_CHECK_MSG( + false, + "Unhandled input dtype %hhd", + static_cast(src.scalar_type())); + } +} + +void quantized_linear_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& src, + const Tensor& weight, + const Tensor& bias, + const int64_t src_zero_point, + const int64_t weight_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + const int64_t out_zero_point, + __ET_UNUSED const executorch::aten::optional& offset, + Tensor& out) { +#define typed_quantized_linear_per_tensor(ctype, 
dtype) \ + case executorch::aten::ScalarType::dtype: { \ + quantized_linear_per_tensor_( \ + src, \ + weight, \ + bias, \ + src_zero_point, \ + weight_zero_point, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + executorch::aten::ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear_per_tensor); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", executorch::runtime::toString(dtype)); + } +#undef typed_quantized_linear_per_tensor +} + +}; // namespace native +}; // namespace vision +}; // namespace impl diff --git a/backends/cadence/vision/operators/op_quantized_matmul_out.cpp b/backends/cadence/vision/operators/op_quantized_matmul_out.cpp new file mode 100644 index 00000000000..54a303288c3 --- /dev/null +++ b/backends/cadence/vision/operators/op_quantized_matmul_out.cpp @@ -0,0 +1,157 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace impl { +namespace vision { +namespace native { + +using executorch::aten::Tensor; +using executorch::runtime::getLeadingDims; +using executorch::runtime::KernelRuntimeContext; + +// The quantized matmul. The quantized matmul accumulates in a wider register, +// whose type is TA. +template < + typename TZ, + typename TA = float, + bool transposed = false, + typename TX = TZ, + typename TY = TZ> +__attribute__((noinline)) void qmatmul( + TZ* __restrict__ Z, + int32_t Z_multiplier, + int32_t Z_shift, + int32_t Z_zero_point, + const TX* __restrict__ X, + int32_t X_zero_point, + const TY* __restrict__ y, + int32_t Y_zero_point, + size_t m, + size_t n, + size_t p) { + // Compute the Z_scale from Z_multiplier and Z_shift + const float Z_scale = -Z_multiplier * 1.0 / (1 << 31) * pow(2, Z_shift); + for (size_t i = 0; i < m; ++i) { + for (size_t j = 0; j < p; ++j) { + TA sum = 0; + for (size_t k = 0; k < n; ++k) { + if (transposed) { + sum += (X[i * n + k] - X_zero_point) * (y[j * n + k] - Y_zero_point); + } else { + sum += (X[i * n + k] - X_zero_point) * (y[k * p + j] - Y_zero_point); + } + } + Z[i * p + j] = kernels::quantize(sum, Z_scale, Z_zero_point); + } + } +} + +template +void inline _typed_quantized_matmul( + const Tensor& X, + int64_t X_zero_point, + const Tensor& Y, + int64_t Y_zero_point, + const executorch::aten::optional& bias, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + bool transposed, + Tensor& out) { + size_t batch_size = getLeadingDims(X, X.dim() - 2); + size_t leading_dim = X.size(X.dim() - 2); + size_t out_dim = Y.size(Y.dim() - 1 - transposed); + size_t in_dim = X.size(X.dim() - 1); + + T* __restrict__ out_data = out.mutable_data_ptr(); + const T* __restrict__ X_data = X.const_data_ptr(); + const T* __restrict__ Y_data = Y.const_data_ptr(); + for (size_t i = 0; i < batch_size; ++i) { + const T* x = X_data + i * leading_dim * in_dim; + const T* y = Y_data + i * in_dim * out_dim; + T* z = out_data + i * leading_dim * out_dim; + if (transposed) { + qmatmul( + z, + static_cast(out_multiplier), + static_cast(out_shift), + static_cast(out_zero_point), + x, + static_cast(X_zero_point), + y, + static_cast(Y_zero_point), + leading_dim, + in_dim, + out_dim); + } else { + qmatmul( + z, + static_cast(out_multiplier), + static_cast(out_shift), + static_cast(out_zero_point), + x, + static_cast(X_zero_point), + y, + 
static_cast(Y_zero_point), + leading_dim, + in_dim, + out_dim); + } + } +} + +void quantized_matmul_out( + KernelRuntimeContext& ctx, + const Tensor& X, + int64_t X_zero_point, + const Tensor& Y, + int64_t Y_zero_point, + const executorch::aten::optional& bias, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + bool transposed, + Tensor& out) { + if (out.scalar_type() == executorch::aten::ScalarType::Byte) { + _typed_quantized_matmul( + X, + X_zero_point, + Y, + Y_zero_point, + bias, + out_multiplier, + out_shift, + out_zero_point, + transposed, + out); + } else if (out.scalar_type() == executorch::aten::ScalarType::Char) { + _typed_quantized_matmul( + X, + X_zero_point, + Y, + Y_zero_point, + bias, + out_multiplier, + out_shift, + out_zero_point, + transposed, + out); + } else { + ET_CHECK_MSG( + false, + "Unhandled input dtype %hhd", + static_cast(X.scalar_type())); + } +} + +}; // namespace native +}; // namespace vision +}; // namespace impl diff --git a/backends/cadence/vision/operators/op_quantized_relu_out.cpp b/backends/cadence/vision/operators/op_quantized_relu_out.cpp new file mode 100644 index 00000000000..45b9e09b1dd --- /dev/null +++ b/backends/cadence/vision/operators/op_quantized_relu_out.cpp @@ -0,0 +1,134 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +namespace impl { +namespace vision { +namespace native { + +using executorch::aten::Tensor; +using executorch::runtime::KernelRuntimeContext; + +template +void quantized_relu_( + const Tensor& input, + const Tensor& in_zero_point, + const int64_t out_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + Tensor& output) { + T q_zero_point = in_zero_point.const_data_ptr()[0]; + const T* __restrict__ in = input.const_data_ptr(); + T* __restrict__ out = output.mutable_data_ptr(); + + const int32_t* __restrict__ out_multiplier_data = + out_multiplier.const_data_ptr(); + const int32_t* __restrict__ out_shift_data = + out_shift.const_data_ptr(); + + // Compute the out_scale from out_multiplier and out_shift + const float out_scale = + -out_multiplier_data[0] * 1.0 / (1 << 31) * pow(2, out_shift_data[0]); + + for (size_t i = 0, e = input.numel(); i < e; ++i) { + const T temp = in[i] > q_zero_point ? 
(in[i] - q_zero_point) : 0; + out[i] = kernels::quantize(temp, out_scale, out_zero_point); + } +} + +void quantized_relu_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& in_zero_point, + const int64_t out_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + Tensor& output) { + if (input.scalar_type() == executorch::aten::ScalarType::Byte) { + quantized_relu_( + input, + in_zero_point, + out_zero_point, + out_multiplier, + out_shift, + output); + } else if (input.scalar_type() == executorch::aten::ScalarType::Char) { + quantized_relu_( + input, + in_zero_point, + out_zero_point, + out_multiplier, + out_shift, + output); + } else { + ET_CHECK_MSG( + false, + "Unhandled input dtype %hhd", + static_cast(input.scalar_type())); + } +} + +template +void quantized_relu_per_tensor_out_( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const int64_t in_zero_point, + const int64_t out_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + Tensor& output) { + const T* __restrict__ in = input.const_data_ptr(); + T* __restrict__ out = output.mutable_data_ptr(); + + // Compute the out_scale from out_multiplier and out_shift + const float out_scale = -out_multiplier * 1.0 / (1 << 31) * pow(2, out_shift); + + for (size_t i = 0, e = input.numel(); i < e; ++i) { + const float temp = in[i] > in_zero_point ? (in[i] - in_zero_point) : 0; + out[i] = kernels::quantize(temp, out_scale, out_zero_point); + } +} + +void quantized_relu_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const int64_t in_zero_point, + const int64_t out_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + Tensor& output) { +#define typed_quantized_relu(ctype, dtype) \ + case executorch::aten::ScalarType::dtype: { \ + quantized_relu_per_tensor_out_( \ + ctx, \ + input, \ + in_zero_point, \ + out_zero_point, \ + out_multiplier, \ + out_shift, \ + output); \ + break; \ + } + + executorch::aten::ScalarType dtype = input.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_relu) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_relu +} + +}; // namespace native +}; // namespace vision +}; // namespace impl diff --git a/backends/cadence/reference/operators/op_requantize_out.cpp b/backends/cadence/vision/operators/op_requantize_out.cpp similarity index 84% rename from backends/cadence/reference/operators/op_requantize_out.cpp rename to backends/cadence/vision/operators/op_requantize_out.cpp index 5cb9ee3943d..ef538bf4045 100644 --- a/backends/cadence/reference/operators/op_requantize_out.cpp +++ b/backends/cadence/vision/operators/op_requantize_out.cpp @@ -6,11 +6,11 @@ * LICENSE file in the root directory of this source tree. 
*/ -#include +#include #include namespace impl { -namespace reference { +namespace vision { namespace native { using executorch::aten::ScalarType; @@ -86,15 +86,18 @@ Tensor& requantize_out( torch::executor::toString(out.scalar_type()), torch::executor::toString(out_dtype)); -#define typed_requantize(ctype, dtype) \ - const ctype* input_data = input.const_data_ptr(); \ - dtype* out_data = out.mutable_data_ptr(); \ - for (size_t i = 0; i < numel; ++i) { \ - float dequant = \ - kernels::dequantize(input_data[i], in_scale, in_zero_point); \ - out_data[i] = \ - kernels::quantize(dequant, 1 / out_scale, out_zero_point); \ - }; +#define typed_requantize(ctype, dtype) \ + const ctype* input_data = input.const_data_ptr(); \ + dtype* out_data = out.mutable_data_ptr(); \ + kernels::requantize( \ + out_data, \ + input_data, \ + in_scale, \ + in_zero_point, \ + 1.0 / out_scale, \ + out_zero_point, \ + numel); + #define typed_requantize_in(ctype) \ switch (out_dtype) { \ case ScalarType::Byte: { \ @@ -187,15 +190,17 @@ Tensor& requantize_per_tensor_out( torch::executor::toString(out.scalar_type()), torch::executor::toString(out_dtype)); -#define typed_requantize(ctype, dtype) \ - const ctype* input_data = input.const_data_ptr(); \ - dtype* out_data = out.mutable_data_ptr(); \ - for (size_t i = 0; i < numel; ++i) { \ - float dequant = \ - kernels::dequantize(input_data[i], in_scale, in_zero_point); \ - out_data[i] = \ - kernels::quantize(dequant, 1 / out_scale, out_zero_point); \ - }; +#define typed_requantize(ctype, dtype) \ + const ctype* input_data = input.const_data_ptr(); \ + dtype* out_data = out.mutable_data_ptr(); \ + kernels::requantize( \ + out_data, \ + input_data, \ + static_cast(in_scale), \ + static_cast(in_zero_point), \ + 1.0 / static_cast(out_scale), \ + static_cast(out_zero_point), \ + numel); #define typed_requantize_in(ctype) \ switch (out_dtype) { \ @@ -256,6 +261,6 @@ Tensor& requantize_per_tensor_out( return out; } -}; // namespace native -}; // namespace reference -}; // namespace impl +} // namespace native +} // namespace vision +} // namespace impl diff --git a/backends/cadence/vision/operators/op_softmax.cpp b/backends/cadence/vision/operators/op_softmax.cpp new file mode 100644 index 00000000000..58ca33c6a0b --- /dev/null +++ b/backends/cadence/vision/operators/op_softmax.cpp @@ -0,0 +1,303 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using executorch::runtime::KernelRuntimeContext; +using torch::executor::Error; + +namespace impl { +namespace vision { +namespace native { + +Tensor& _softmax_out( + KernelRuntimeContext& ctx, + const Tensor& in, + int64_t dim, + bool half_to_float, + Tensor& out) { + (void)ctx; + + ET_KERNEL_CHECK( + ctx, + torch::executor::check_softmax_args(in, dim, half_to_float, out), + InvalidArgument, + out); + + ET_KERNEL_CHECK( + ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out); + + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(in, out), + InvalidArgument, + out); + + // Adjust for negative dim + dim = dim < 0 ? 
dim + executorch::runtime::nonzero_dim(in) : dim; + + const executorch::aten::optional& dim_t = dim; + const size_t d = ET_NORMALIZE_IX(dim_t.value(), in.dim()); + const size_t size = in.size(d); + + size_t stride = 1, outer_size = 1; + + size_t outer_stride = 1; + + constexpr auto name = "_softmax.out"; + constexpr int MaxDim = 5; + + bool optimized = true; + bool ping_pong_process = false; + bool ping_process_pong = false; + + if ((d == in.dim() - 1)) { + if (size <= IDMA_BUFF_SIZE / 4 && in.dim() != 1) { + ping_pong_process = true; + } else if (size <= IDMA_BUFF_SIZE / 2) { + ping_process_pong = true; + } + } + + if (out.scalar_type() != ScalarType::Float) + optimized = false; + + if (in.dim() > MaxDim) + optimized = false; + + if (optimized) { + const float* ptr_inp = (float*)in.const_data_ptr(); + float* out_data = (float*)out.mutable_data_ptr(); + + /* Channel 0*/ + idma_init(0, 0, MAX_BLOCK_16, 8, TICK_CYCLES_1, 0, NULL); + idma_init_loop(0, descbuf[0], IDMA_2D_DESC, 1, NULL, NULL); + + /* Channel 1*/ + idma_init(1, 0, MAX_BLOCK_16, 8, TICK_CYCLES_1, 0, NULL); + idma_init_loop(1, descbuf[1], IDMA_2D_DESC, 1, NULL, NULL); + + if (ping_pong_process) { + for (int i = 0; i < in.dim(); i++) { + if (i != d) + outer_size *= in.size(i); + } + + outer_stride = size; + stride = size; + + int pp_swap = 0; + + float32_t* ptr_out = out_data; + float32_t* ptr_in = (float32_t*)ptr_inp; + + idma_copy_2d_desc( + 0, inpData[pp_swap], ptr_in, 4 * stride, DESC_IDMA_PRIOR_H, 1, 0, 0); + pp_swap = 1; + + for (int i = 0; i < (outer_size - 1); i++) { + IDMA_HW_WAIT_ALL(0); + ptr_in += outer_stride; + idma_copy_2d_desc( + 0, + inpData[pp_swap], + ptr_in, + 4 * stride, + DESC_IDMA_PRIOR_H, + 1, + 0, + 0); + pp_swap = pp_swap ^ 1; + + /* PROCESS CALL */ + vsoftmaxf(outData[pp_swap], inpData[pp_swap], stride); + + IDMA_HW_WAIT_ALL(1); + idma_copy_2d_desc( + 1, + ptr_out, + outData[pp_swap], + 4 * stride, + DESC_IDMA_PRIOR_H, + 1, + 0, + 0); + ptr_out += outer_stride; + } + + IDMA_HW_WAIT_ALL(0); + pp_swap = pp_swap ^ 1; + + /* PROCESS CALL */ + vsoftmaxf(outData[pp_swap], inpData[pp_swap], stride); + + IDMA_HW_WAIT_ALL(1); + idma_copy_2d_desc( + 1, ptr_out, outData[pp_swap], 4 * stride, DESC_IDMA_PRIOR_H, 1, 0, 0); + + IDMA_HW_WAIT_ALL(1); + + return out; + } else if (ping_process_pong) { + for (int i = 0; i < in.dim(); i++) { + if (i != d) + outer_size *= in.size(i); + } + + outer_stride = size; + stride = size; + + float32_t* ptr_out = out_data; + float32_t* ptr_in = (float32_t*)ptr_inp; + + for (int i = 0; i < outer_size; i++) { + idma_copy_2d_desc( + 0, data_dram0, ptr_in, 4 * stride, DESC_IDMA_PRIOR_H, 1, 0, 0); + IDMA_HW_WAIT_ALL(0); + + vsoftmaxf(data_dram1, data_dram0, stride); + + idma_copy_2d_desc( + 1, ptr_out, data_dram1, 4 * stride, DESC_IDMA_PRIOR_H, 1, 0, 0); + IDMA_HW_WAIT_ALL(1); + + ptr_in += outer_stride; + ptr_out += outer_stride; + } + + return out; + } else { + int num_inp_dims = in.dim(); + int num_out_dims = num_inp_dims; + + int ptr_inp_shape[MaxDim]; + int ptr_out_shape[MaxDim]; + int ptr_permute_vec[MaxDim]; + + for (int i = 0; i < num_inp_dims; i++) + ptr_inp_shape[i] = in.size(i); + + for (int i = 0; i < num_inp_dims; i++) { + if (i == d) + ptr_permute_vec[i] = num_inp_dims - 1; + else if (i == (num_inp_dims - 1)) + ptr_permute_vec[num_inp_dims - 1] = d; + else + ptr_permute_vec[i] = i; + + ptr_out_shape[i] = ptr_inp_shape[ptr_permute_vec[i]]; + + if (i != d) + outer_size = outer_size * ptr_inp_shape[i]; + } + + outer_stride = size; + + float* ptr_out = 
(float*)kernels::allocate_temp_memory( + ctx, out.numel() * sizeof(float)); + + ET_KERNEL_CHECK(ctx, ptr_out != nullptr, MemoryAllocationFailed, out); + + float* ptr_out1 = (float*)kernels::allocate_temp_memory( + ctx, out.numel() * sizeof(float)); + + ET_KERNEL_CHECK(ctx, ptr_out1 != nullptr, MemoryAllocationFailed, out); + + tensor_transposef( + ptr_out, + ptr_out_shape, + ptr_inp, + ptr_inp_shape, + ptr_permute_vec, + num_out_dims, + num_inp_dims); + + for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) { + size_t outer = outer_idx * outer_stride; + for (size_t inner_idx = 0; inner_idx < stride; ++inner_idx) { + size_t base = outer + inner_idx; + + float* ptr_in_data = &ptr_out[base]; + float* ptr_out_data = &ptr_out1[base]; + + vsoftmaxf(ptr_out_data, ptr_in_data, size); + } + } + + tensor_transposef( + out_data, + ptr_inp_shape, + ptr_out1, + ptr_out_shape, + ptr_permute_vec, + num_out_dims, + num_inp_dims); + + return out; + } + } + + ET_SWITCH_FLOATHBF16_TYPES( + in.scalar_type(), ctx, "_softmax.out", CTYPE, [&]() { + const CTYPE* const in_data = in.const_data_ptr(); + CTYPE* const out_data = out.mutable_data_ptr(); + + torch::executor::apply_over_dim( + [in_data, out_data]( + const size_t size, const size_t stride, const size_t base) { + // calculate max in softmax dim. During softmax computation each + // value is subtracted by the maximum in value before calling exp + // to preserve numerical stability. + const CTYPE max_in = torch::executor::apply_unary_reduce_fn( + [](const CTYPE val_in, CTYPE val_accum) { + return std::max(val_in, val_accum); + }, + in_data + base, + size, + stride); + + const CTYPE temp_sum = + torch::executor::apply_unary_map_reduce_fn( + [max_in](const CTYPE val_in) { + return std::exp(val_in - max_in); + }, + [](const CTYPE mapped_in, CTYPE val_accum) { + return val_accum + mapped_in; + }, + in_data + base, + size, + stride); + + torch::executor::apply_unary_map_fn( + [max_in, temp_sum](const CTYPE val_in) { + return std::exp(val_in - max_in) / temp_sum; + }, + in_data + base, + out_data + base, + size, + stride); + }, + in, + dim); + }); + + return out; +} + +} // namespace native +} // namespace vision +} // namespace impl diff --git a/backends/cadence/reference/operators/op_view_copy.cpp b/backends/cadence/vision/operators/op_view_copy.cpp similarity index 80% rename from backends/cadence/reference/operators/op_view_copy.cpp rename to backends/cadence/vision/operators/op_view_copy.cpp index 162e9ee201b..6d4d3a8a5e0 100644 --- a/backends/cadence/reference/operators/op_view_copy.cpp +++ b/backends/cadence/vision/operators/op_view_copy.cpp @@ -8,10 +8,12 @@ #include -namespace torch { -namespace executor { +namespace impl { +namespace vision { namespace native { +using executorch::aten::IntArrayRef; +using ::executorch::aten::IntArrayRef; using executorch::aten::Tensor; using executorch::runtime::KernelRuntimeContext; @@ -25,5 +27,5 @@ Tensor& view_copy_out( } } // namespace native -} // namespace executor -} // namespace torch +} // namespace vision +} // namespace impl diff --git a/backends/cadence/vision/operators/operators.h b/backends/cadence/vision/operators/operators.h new file mode 100644 index 00000000000..8b5db4161eb --- /dev/null +++ b/backends/cadence/vision/operators/operators.h @@ -0,0 +1,101 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include +#include +#include + +namespace impl { +namespace vision { +namespace native { + +using ::executorch::runtime::getLeadingDims; + +#define ET_FORALL_CADENCE_QUANTIZED_TYPES(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) + +#define ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) + +inline __attribute__((always_inline)) void linear_( + const ::executorch::aten::Tensor& input, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::optional<::executorch::aten::Tensor>& bias, + ::executorch::aten::Tensor& output) { + const float* __restrict__ input_data = input.const_data_ptr(); + const float* __restrict__ weight_data = weight.const_data_ptr(); + const float* __restrict__ bias_data = bias.value().const_data_ptr(); + float* __restrict__ output_data = output.mutable_data_ptr(); + + // input comes in shape [batch_size, in_dim] + // weight comes in shape [out_dim, in_dim] + // output comes in empty with shape [batch_size, out_dim] + // Perform matrix multiply (M x N) x (N x P) => M x P + int64_t M = weight.size(0); // = out_dim + int64_t N = weight.size(1); // = in_dim + + // Given an N-dimensional input [d0, d1, d2, ..., d_{N-2}, d_{N-1}], the + // leading dimensions is d0 * d1 * ... * d_{N-2} + int64_t leading_dims = getLeadingDims(input, input.dim() - 1); + + for (int i = 0; i < leading_dims; ++i) { + for (int j = 0; j < M; ++j) { + float sum = bias_data[j]; + for (int k = 0; k < N; ++k) { + sum += input_data[i * N + k] * weight_data[j * N + k]; + } + output_data[i * M + j] = sum; + } + } +} + +void quantized_conv2d_nchw_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + ::executorch::aten::IntArrayRef stride, + ::executorch::aten::IntArrayRef padding, + ::executorch::aten::IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + const ::executorch::aten::Tensor& weight_zero_point, + const ::executorch::aten::Tensor& bias_scale, + double output_scale, + int64_t output_zero_point, + const ::executorch::aten::Tensor& out_multiplier, + const ::executorch::aten::Tensor& out_shift, + ::executorch::aten::Tensor& out); + +void quantized_conv2d_nhwc_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + ::executorch::aten::IntArrayRef stride, + ::executorch::aten::IntArrayRef padding, + ::executorch::aten::IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + const ::executorch::aten::Tensor& weight_zero_point, + const ::executorch::aten::Tensor& bias_scale, + double output_scale, + int64_t output_zero_point, + const ::executorch::aten::Tensor& out_multiplier, + const ::executorch::aten::Tensor& out_shift, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace vision +} // namespace impl diff --git a/backends/cadence/reference/operators/quantized_ops.h b/backends/cadence/vision/operators/quantized_ops.h similarity index 95% rename from backends/cadence/reference/operators/quantized_ops.h rename to backends/cadence/vision/operators/quantized_ops.h index f42d66bed3c..a7251724c53 100644 --- a/backends/cadence/reference/operators/quantized_ops.h +++ b/backends/cadence/vision/operators/quantized_ops.h @@ -8,8 +8,8 @@ #pragma once -#include -#include +#include +#include template 
inline __attribute__((always_inline)) void quantized_linear_per_tensor_( @@ -49,7 +49,7 @@ inline __attribute__((always_inline)) void quantized_linear_per_tensor_( (int32_t)weight_data[j * in_dim + k] - (int32_t)weight_zero_point; sum += x * w; } - out_data[i * out_dim + j] = ::impl::reference::kernels::quantize( + out_data[i * out_dim + j] = impl::vision::kernels::quantize( sum, requant_scale, out_zero_point); } } @@ -121,8 +121,8 @@ inline __attribute__((always_inline)) void quantized_linear_per_channel_( // Compute the out_scale from out_multiplier and out_shift const float out_scale = -out_multiplier_data[j] * 1.0 / (1 << 31) * pow(2, out_shift_data[j]); - out_data[i * out_dim + j] = ::impl::reference::kernels::quantize( - sum, out_scale, out_zero_point); + out_data[i * out_dim + j] = + impl::vision::kernels::quantize(sum, out_scale, out_zero_point); } } } diff --git a/backends/cadence/vision/operators/targets.bzl b/backends/cadence/vision/operators/targets.bzl new file mode 100644 index 00000000000..2dd47e12bd2 --- /dev/null +++ b/backends/cadence/vision/operators/targets.bzl @@ -0,0 +1,83 @@ +load("@fbsource//tools/build_defs:platform_defs.bzl", "CXX") +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + + +def define_operator(name: str, deps: list[str] | None = None) -> None: + op_name = "op_{}".format(name) + + # Deps used by all operators. + common_deps = [ + "//executorch/kernels/portable/cpu/util:all_deps", + "//executorch/kernels/portable/cpu/pattern:all_deps", + "//executorch/runtime/kernel:kernel_includes", + "//executorch/kernels/portable/cpu:scalar_utils", + "//executorch/backends/cadence/vision/kernels:cadence_kernels", + "//executorch/kernels/portable/cpu/util:dtype_util", + "//executorch/kernels/portable/cpu/util:elementwise_util", + "//executorch/kernels/portable/cpu/pattern:bitwise_op", + "//executorch/backends/cadence/vision/third-party:vision-nnlib", + "//executorch/kernels/portable/cpu/pattern:comparison_op" + ] + if deps == None: + deps = [] + + # Determine which headers to export based on operator name + exported_headers = ["operators.h"] + + # Add quantized_ops.h header for quantized operators + quantized_ops = [ + "quantized_fully_connected_out", + "quantized_matmul_out", + "quantized_layer_norm", + "quantized_relu_out", + "quantized_conv_out", + "quantized_linear_out", + "quantize_per_tensor", + "dequantize_per_tensor", + "requantize_out" + ] + + if name in quantized_ops: + exported_headers.append("quantized_ops.h") + + runtime.cxx_library( + name = op_name, + srcs = [op_name + ".cpp"], + platforms = CXX, + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + compatible_with = ["ovr_config//cpu:xtensa"], + deps = deps + common_deps, + exported_headers = exported_headers, + ) + +OPERATORS = [ + "add", + "full", + "quantized_fully_connected_out", + "quantized_matmul_out", + "requantize_out", + "dequantize_per_tensor", + "im2row_out", + "quantized_layer_norm", + "quantized_relu_out", + "softmax", + "embedding", + "quantized_conv_out", + "quantized_linear_out", + "quantize_per_tensor", + "view_copy" +] + +def define_common_targets(): + """Defines targets that should be shared between fbcode and xplat. + + The directory containing this targets.bzl file should also contain both + TARGETS and BUCK files that call this function. + """ + + # Define build targets for all operators registered in the tables above. 
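    # For example, the loop below turns the "softmax" entry into a
    # runtime.cxx_library named "op_softmax" built from op_softmax.cpp with
    # the common deps listed in define_operator; entries in quantized_ops
    # additionally export quantized_ops.h.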
+ for op in OPERATORS: + define_operator(op) diff --git a/backends/cadence/vision/third-party/dummy.c b/backends/cadence/vision/third-party/dummy.c new file mode 100644 index 00000000000..52fb7c18c38 --- /dev/null +++ b/backends/cadence/vision/third-party/dummy.c @@ -0,0 +1,17 @@ +/* Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* Dummy source file for non-Xtensa builds + * This file is used when building the vision-nnlib library on platforms + * other than Xtensa, providing empty stubs for compatibility. + * The actual function implementations are provided as stubs via DISCARD_FUN + * in headers when COMPILER_XTENSA is not defined. + */ + +// This file intentionally contains no function definitions and no includes. +// When COMPILER_XTENSA is not defined, all functions are stubbed out +// using the DISCARD_FUN macro in the header files. diff --git a/backends/cadence/vision/third-party/include/api.h b/backends/cadence/vision/third-party/include/api.h new file mode 100644 index 00000000000..efb80c3d76d --- /dev/null +++ b/backends/cadence/vision/third-party/include/api.h @@ -0,0 +1,83 @@ +/* ------------------------------------------------------------------------ */ +/* Copyright (c) 2024 by Cadence Design Systems, Inc. ALL RIGHTS RESERVED. */ +/* These coded instructions, statements, and computer programs ('Cadence */ +/* Libraries') are the copyrighted works of Cadence Design Systems Inc. */ +/* Cadence IP is licensed for use with Cadence processor cores only and */ +/* must not be used for any other processors and platforms. Your use of the */ +/* Cadence Libraries is subject to the terms of the license agreement you */ +/* have entered into with Cadence Design Systems, or a sublicense granted */ +/* to you by a direct Cadence licensee. */ +/* ------------------------------------------------------------------------ */ +/* IntegrIT, Ltd. www.integrIT.com, info@integrIT.com */ +/* */ +/* NatureDSP_Baseband Library */ +/* */ +/* This library contains copyrighted materials, trade secrets and other */ +/* proprietary information of IntegrIT, Ltd. This software is licensed for */ +/* use with Cadence processor cores only and must not be used for any other */ +/* processors and platforms. The license to use these sources was given to */ +/* Cadence, Inc. under Terms and Condition of a Software License Agreement */ +/* between Cadence, Inc. and IntegrIT, Ltd. */ +/* ------------------------------------------------------------------------ */ +/* Copyright (C) 2009-2022 IntegrIT, Limited. */ +/* All Rights Reserved. */ +/* ------------------------------------------------------------------------ */ +/* + * API + */ + +#ifndef __API_H__ +#define __API_H__ + +#include "dtypes.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*------------------------------------------------------------------------- +Softmax + +Description: The function computes the softmax (normalized exponential +function) of input data. 16-bit fixed-point functions accept inputs in +Q3.12 and form outputs in Q7.8 format. + +vsoftmax 16-bit +vsoftmax_fp16 IEEE-754 Std. half precision floating-point. +vsoftmaxf IEEE-754 Std. single precision floating-point. + +Accuracy: +2 LSB for fixed point API +2 ULP for floating point API +NOTE: Accuracy of function may depend on amount of data and their +distribution. 
Given accuracy is achieved for N=2 for any pair of +data from input domain. + + +Parameters: +Input: +x[N] input data, Q3.12 floating point +N Length of input/output data vectors +Output: +y[N] result, Q7.8 or floating point + +Restrictions: +x,y aligned on 2*BBE_SIMD_WIDTH-bytes boundary (vsoftmax) +x,y Must not overlap +N multiple of BBE_SIMD_WIDTH (vsoftmax) +-------------------------------------------------------------------------*/ +void vsoftmaxf(float32_t *y, const float32_t *x, int N); + +void tensor_transposef(float32_t *restrict ptr_out + ,const int *const ptr_out_shape + ,const float32_t *restrict ptr_inp + ,const int *const ptr_inp_shape + ,const int *restrict ptr_permute_vec + ,int num_out_dims + ,int num_inp_dims); + +#ifdef __cplusplus +}; +#endif + +#endif /* __API_H__ */ diff --git a/backends/cadence/vision/third-party/include/dtypes.h b/backends/cadence/vision/third-party/include/dtypes.h new file mode 100644 index 00000000000..c12bbf23ac2 --- /dev/null +++ b/backends/cadence/vision/third-party/include/dtypes.h @@ -0,0 +1,380 @@ +/* ------------------------------------------------------------------------ */ +/* Copyright (c) 2024 by Cadence Design Systems, Inc. ALL RIGHTS RESERVED. */ +/* These coded instructions, statements, and computer programs ('Cadence */ +/* Libraries') are the copyrighted works of Cadence Design Systems Inc. */ +/* Cadence IP is licensed for use with Cadence processor cores only and */ +/* must not be used for any other processors and platforms. Your use of the */ +/* Cadence Libraries is subject to the terms of the license agreement you */ +/* have entered into with Cadence Design Systems, or a sublicense granted */ +/* to you by a direct Cadence licensee. */ +/* ------------------------------------------------------------------------ */ +/* IntegrIT, Ltd. www.integrIT.com, info@integrIT.com */ +/* */ +/* NatureDSP_Baseband Library */ +/* */ +/* This library contains copyrighted materials, trade secrets and other */ +/* proprietary information of IntegrIT, Ltd. This software is licensed for */ +/* use with Cadence processor cores only and must not be used for any other */ +/* processors and platforms. The license to use these sources was given to */ +/* Cadence, Inc. under Terms and Condition of a Software License Agreement */ +/* between Cadence, Inc. and IntegrIT, Ltd. */ +/* ------------------------------------------------------------------------ */ +/* Copyright (C) 2009-2022 IntegrIT, Limited. */ +/* All Rights Reserved. 
*/ +/* ------------------------------------------------------------------------ */ +/* + * Cross-platform data type definitions and utility macros + */ + +#ifndef __DTYPES_H__ +#define __DTYPES_H__ + +#include + +#ifndef COMPILER_ANSI +/* ---------------------------------------------------------- + Compilers autodetection + ----------------------------------------------------------*/ +#define ___UNKNOWN_COMPILER_YET +#ifdef ___UNKNOWN_COMPILER_YET +#ifdef _MSC_VER + +#ifdef _ARM_ +#define COMPILER_CEARM9E /* Microsoft Visual C++,ARM9E */ +#else +#define COMPILER_MSVC /* Microsoft Visual C++ */ +#endif + +#undef ___UNKNOWN_COMPILER_YET +#endif +#endif + +#ifdef ___UNKNOWN_COMPILER_YET +#ifdef _TMS320C6X +#if defined(_TMS320C6400) +#define COMPILER_C64 +#undef ___UNKNOWN_COMPILER_YET +#endif +#if defined(_TMS320C6400_PLUS) +#define COMPILER_C64PLUS +#undef ___UNKNOWN_COMPILER_YET +#endif +#endif +#endif + +#ifdef ___UNKNOWN_COMPILER_YET +#ifdef __TMS320C55X__ +#define COMPILER_C55 +#undef ___UNKNOWN_COMPILER_YET +#endif +#endif + +#ifdef ___UNKNOWN_COMPILER_YET +#ifdef __ADSPBLACKFIN__ +#define COMPILER_ADSP_BLACKFIN +#undef ___UNKNOWN_COMPILER_YET +#endif +#endif + +#ifdef ___UNKNOWN_COMPILER_YET +#ifdef __XCC__ +#define COMPILER_XTENSA +#undef ___UNKNOWN_COMPILER_YET +#endif +#endif + +#ifdef ___UNKNOWN_COMPILER_YET +#ifdef __GNUC__ +#ifdef __arm__ +#ifndef COMPILER_GNU_ARM +#endif +#define COMPILER_GNUARM /* GNU C/C++ compiler*/ +#else +/* GNU GCC x86 compiler */ +#ifndef COMPILER_GNU +#endif +#define COMPILER_GNU /* GNU C/C++ */ +#endif +#undef ___UNKNOWN_COMPILER_YET +#endif +#endif + +#ifdef ___UNKNOWN_COMPILER_YET +#error Unknown compiler +#endif + +#endif /* #ifndef COMPILER_ANSI */ + +/* ---------------------------------------------------------- + Language-dependent definitions + ----------------------------------------------------------*/ +#ifdef __cplusplus + +#undef extern_C +#define extern_C extern "C" + +#else + +#undef extern_C +#define extern_C + +#ifndef false +#define false 0 +#endif +#ifndef true +#define true 1 +#endif + +#endif + +/* Assertion support */ +#if !defined(_ASSERT) +#include +#if defined(_DEBUG) /*&& defined(COMPILER_MSVC)*/ +#define ASSERT(x) \ + { assert(x); } +#else + +/*#undef ASSERT*/ +#ifndef ASSERT +#define ASSERT(_ignore) ((void)0) +#endif + +#endif /* _DEBUG */ +#else /* ASSERT*/ +#define ASSERT(exp) \ + { \ + extern void ExternalAssertHandler(void *, void *, unsigned); \ + (void)((exp) || (ExternalAssertHandler(#exp, __FILE__, __LINE__), 0)); \ + } +#endif /* ASSERT */ + +/*** Inline methods definition ***/ +#undef inline_ +#if (defined COMPILER_MSVC) || (defined COMPILER_CEARM9E) +#define inline_ __inline +#elif defined(COMPILER_ADSP_BLACKFIN) +#define inline_ inline +#elif defined(COMPILER_ANSI) +#define inline_ +#elif (defined COMPILER_GNU) || (defined COMPILER_GNUARM) || \ + (defined COMPILER_ARM) +#define inline_ static inline +#else +#define inline_ static inline +#endif + +#ifndef MAX_INT16 +#define MAX_INT16 ((int16_t)0x7FFF) +#endif +#ifndef MIN_INT16 +#define MIN_INT16 ((int16_t)0x8000) +#endif +#ifndef MAX_INT32 +#define MAX_INT32 ((int32_t)0x7FFFFFFFL) +#endif +#ifndef MIN_INT32 +#define MIN_INT32 ((int32_t)0x80000000L) +#endif +#ifndef MIN_INT64 +#define MIN_INT64 ((int64_t)0x8000000000000000LL) +#endif +#ifndef MAX_INT64 +#define MAX_INT64 ((int64_t)0x7fffffffffffffffLL) +#endif + +/* size of variables in bytes */ +#ifdef COMPILER_C55 +#define SIZEOF_BYTE(x) (sizeof(x) << 1) +#else +#define SIZEOF_BYTE(x) sizeof(x) +#endif + 
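+/* Example (informal): SIZEOF_BYTE(x) reports the size of x in 8-bit bytes.
+   On byte-addressable targets it is plain sizeof(x); on C55, where sizeof()
+   counts 16-bit words, the result is doubled to stay byte-accurate. */
+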
+/*--------------------------------------- + special keywords definition + restrict keyword means that the memory + is addressed exclusively via + this pointer + onchip keyword means that the memory + is on-chip and can not be + accessed via external bus +---------------------------------------*/ +#if defined(COMPILER_C55) +#define NASSERT _nassert +#elif defined(COMPILER_C64) +#define onchip +#define NASSERT _nassert +#elif defined(COMPILER_ADSP_BLACKFIN) +#define onchip +#define NASSERT(x) __builtin_assert(x) +#elif defined(COMPILER_GNUARM) +#define onchip +#define NASSERT(x) \ + { (void)__builtin_expect((x) != 0, 1); } +#define restrict __restrict +#elif defined(COMPILER_GNU) +#define onchip +#define NASSERT(x) \ + { \ + (void)__builtin_expect((x) != 0, 1); \ + ASSERT(x); \ + } +#define restrict __restrict +#elif defined(COMPILER_CEARM9E) +#define onchip +#define NASSERT(x) +#define restrict +#elif defined(COMPILER_XTENSA) +#ifndef restrict +#define restrict __restrict +#endif +#define onchip +#define NASSERT(x) \ + { \ + (void)__builtin_expect((x) != 0, 1); \ + ASSERT(x); \ + } +#else +#define restrict +#define onchip +#define NASSERT ASSERT +#endif +#if defined(COMPILER_ADSP_BLACKFIN) +#define NASSERT_ALIGN(addr, align) __builtin_aligned(addr, align) +#else +#define NASSERT_ALIGN(addr, align) NASSERT(((uintptr_t)(addr)) % (align) == 0) +#endif +#define NASSERT_ALIGN2(addr) NASSERT_ALIGN(addr, 2) +#define NASSERT_ALIGN4(addr) NASSERT_ALIGN(addr, 4) +#define NASSERT_ALIGN8(addr) NASSERT_ALIGN(addr, 8) +#define NASSERT_ALIGN16(addr) NASSERT_ALIGN(addr, 16) +#define NASSERT_ALIGN32(addr) NASSERT_ALIGN(addr, 32) +#define NASSERT_ALIGN64(addr) NASSERT_ALIGN(addr, 64) +#define NASSERT_ALIGN128(addr) NASSERT_ALIGN(addr, 128) +/* ---------------------------------------------------------- + Common types + ----------------------------------------------------------*/ +#if defined(COMPILER_GNU) | defined(COMPILER_GNUARM) | defined(COMPILER_XTENSA) +/* + typedef signed char int8_t; + typedef unsigned char uint8_t; +*/ +#include +#elif defined(COMPILER_C64) +#include +#elif defined(COMPILER_C55) +#include +typedef signed char int8_t; +typedef unsigned char uint8_t; +#elif defined(COMPILER_ADSP_BLACKFIN) +typedef signed char int8_t; +typedef unsigned char uint8_t; +typedef unsigned long uint32_t; +typedef unsigned short uint16_t; +typedef long int32_t; +typedef short int16_t; +typedef long long int64_t; +typedef unsigned long long uint64_t; +typedef uint32_t uintptr_t; +#else +typedef signed char int8_t; +typedef unsigned char uint8_t; +typedef unsigned long uint32_t; +typedef unsigned short uint16_t; +typedef long int32_t; +typedef short int16_t; +typedef __int64 int64_t; +typedef unsigned __int64 uint64_t; +#endif + +#if defined(COMPILER_CEARM9E) +typedef uint32_t uintptr_t; +#endif + +#if defined(COMPILER_ARM) +typedef uint32_t uintptr_t; +#endif + +typedef int16_t float16_t; +typedef float float32_t; +typedef double float64_t; +typedef int16_t fract16; +typedef int32_t fract32; + +typedef union tag_complex_fract16 { + struct { + int16_t re, im; + } s; + uint32_t a; /* just for 32-bit alignment */ +} complex_fract16; + +typedef union tag_complex_fract32 { + struct { + int32_t re, im; + } s; + uint64_t a; /* just for 64-bit alignment */ +} complex_fract32; + +#if defined(COMPILER_MSVC) +#if 0 +/* Note: Visual Studio does not support C99 compatible complex types yet */ +typedef union tag_complex_float { + struct { + float32_t re, im; + } s; + uint64_t a; /* just for 64-bit alignment */ +} 
complex_float; +typedef union tag_complex_double { + struct { + float64_t re, im; + } s; + uint64_t a[2]; /* only 64-bit alignment under Visual Studio :(( */ +} complex_double; + +inline_ float32_t crealf(complex_float x) { return x.s.re; } +inline_ float32_t cimagf(complex_float x) { return x.s.im; } +inline_ float64_t creal(complex_double x) { return x.s.re; } +inline_ float64_t cimag(complex_double x) { return x.s.im; } +#else +#include +#define complex_float _Fcomplex +#define complex_double _Dcomplex +#endif + +#else +/* C99 compatible type */ +#include +#define complex_float __complex__ float +#define complex_double __complex__ double +#endif + +/* complex half-precision datatype */ +typedef union tag_complex_float16 { + struct { + float16_t re, im; + } s; + uint32_t a; /* just for 32-bit alignment */ +} complex_float16; + +inline_ float16_t crealh(complex_float16 x) { return x.s.re; } +inline_ float16_t cimagh(complex_float16 x) { return x.s.im; } +/* union data type for writing float32_t/float64_t constants in a bitexact + * form */ +union ufloat32uint32 { + uint32_t u; + float32_t f; +}; +union ufloat64uint64 { + uint64_t u; + float64_t f; +}; +union ufloat16uint16 { + uint16_t u; + float16_t f; +}; + +#if defined(__RENAMING__) +#include "__renaming__.h" +#endif + +#endif /* __DTYPE_H__ */ diff --git a/backends/cadence/vision/third-party/include_private/common.h b/backends/cadence/vision/third-party/include_private/common.h new file mode 100644 index 00000000000..4fc07d8b4d1 --- /dev/null +++ b/backends/cadence/vision/third-party/include_private/common.h @@ -0,0 +1,199 @@ +/* ------------------------------------------------------------------------ */ +/* Copyright (c) 2024 by Cadence Design Systems, Inc. ALL RIGHTS RESERVED. */ +/* These coded instructions, statements, and computer programs ('Cadence */ +/* Libraries') are the copyrighted works of Cadence Design Systems Inc. */ +/* Cadence IP is licensed for use with Cadence processor cores only and */ +/* must not be used for any other processors and platforms. Your use of the */ +/* Cadence Libraries is subject to the terms of the license agreement you */ +/* have entered into with Cadence Design Systems, or a sublicense granted */ +/* to you by a direct Cadence licensee. */ +/* ------------------------------------------------------------------------ */ +/* IntegrIT, Ltd. www.integrIT.com, info@integrIT.com */ +/* */ +/* NatureDSP_Baseband Library */ +/* */ +/* This library contains copyrighted materials, trade secrets and other */ +/* proprietary information of IntegrIT, Ltd. This software is licensed for */ +/* use with Cadence processor cores only and must not be used for any other */ +/* processors and platforms. The license to use these sources was given to */ +/* Cadence, Inc. under Terms and Condition of a Software License Agreement */ +/* between Cadence, Inc. and IntegrIT, Ltd. */ +/* ------------------------------------------------------------------------ */ +/* Copyright (C) 2009-2022 IntegrIT, Limited. */ +/* All Rights Reserved. 
*/ +/* ------------------------------------------------------------------------ */ + +#ifndef __COMMON_H__ +#define __COMMON_H__ + +#if defined COMPILER_XTENSA +#include +#include +#include +#include +#include +#include +#if XCHAL_HAVE_IDMA +#ifndef IDMA_USE_MULTICHANNEL + #define IDMA_USE_MULTICHANNEL 1 +#endif +#include +#endif +#define IVP_SIMD_WIDTH XCHAL_IVPN_SIMD_WIDTH + +#include "xtensa/config/core-isa.h" +#include "xtensa/tie/xt_ivpn.h" +#if XCHAL_HAVE_IDMA +#include "xtensa/idma.h" +#endif + +#ifdef _MSC_VER +#define ALIGN(x) _declspec(align(x)) +#else +#define ALIGN(x) __attribute__((aligned(x))) +#endif + +#ifdef COMPILER_XTENSA +#define ATTRIBUTE_ALWAYS_INLINE __attribute__((always_inline)) +#define ATTRIBUTE_NEVER_INLINE __attribute__((noinline)) +#define ATTRIBUTE_UNUSED __attribute__((unused)) +#else +#define ATTRIBUTE_ALWAYS_INLINE +#define ATTRIBUTE_NEVER_INLINE +#define ATTRIBUTE_UNUSED +#endif + +/* 'restrict' qualifier, is applied to pointers only under clang compiler */ +#ifdef __clang__ +#define restrict_clang restrict +#else +#define restrict_clang +#endif + +// Performance measurement macros +#define XTPERF_PRINTF(...) printf(__VA_ARGS__) +#define TIME_DECL(test) long start_time_##test, end_time_##test; +#define TIME_START(test) { start_time_##test = 0; XT_WSR_CCOUNT(0); } +#define TIME_END(test) { end_time_##test = XT_RSR_CCOUNT(); } +#define TIME_DISPLAY(test, opcnt, opname) { long long cycles_##test = end_time_##test - start_time_##test; \ + XTPERF_PRINTF("PERF_LOG : %s : %d : %s : %lld : cycles : %.2f : %s/cycle : %.2f : cycles/%s\n", \ + #test, opcnt, opname, cycles_##test, cycles_##test == 0 ? 0 : (double)(opcnt)/cycles_##test, \ + opname, cycles_##test == 0 ? 0 : 1/((double)(opcnt)/cycles_##test), opname); } + +//----------------------------------------------------- +// log2(BBE_SIMD_WIDTH) +//----------------------------------------------------- +#define LOG2_IVP_SIMD_WIDTH 5 +#define ALIGN_SIMD ALIGN(64) +#define ALIGN_2SIMD ALIGN(128) + +#define LOG2_SIMD_N_2 (LOG2_IVP_SIMD_WIDTH - 1) +#define LOG2_SIMD_2N (LOG2_IVP_SIMD_WIDTH + 1) +//----------------------------------------------------- +// some C++ support +//----------------------------------------------------- + +// special XCC type casting of pointers +#ifdef __cplusplus +#define castxcc(type_, ptr) (ptr) +#else +#define castxcc(type_, ptr) (type_ *)(ptr) +#endif + +//----------------------------------------------------- +// C99 pragma wrapper +//----------------------------------------------------- + +#ifdef COMPILER_XTENSA +#define __Pragma(a) _Pragma(a) +#else +#define __Pragma(a) +#endif + +//----------------------------------------------------- +// Conditionalization support +//----------------------------------------------------- +/* place DISCARD_FUN(retval_type,name) instead of function definition for + functions to be discarded from the executable THIS WORKS only for external + library functions declared as extern "C" and not supported for internal + references without "C" qualifier! 
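+
+   For example, vsoftmaxf() in this library is compiled out on cores without
+   a vector FPU via
+     DISCARD_FUN(void, vsoftmaxf, (float32_t *y, const float32_t *x, int N))
+   which emits a small tagged placeholder object in place of a function body.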
+*/ +#ifdef COMPILER_MSVC +#pragma section("$DISCARDED_FUNCTIONS", execute, discard) +#pragma section("$$$$$$$$$$", execute, discard) +#define DISCARD_FUN(retval_type, name, arglist) \ + __pragma(alloc_text("$DISCARDED_FUNCTIONS", name)) \ + __pragma(section("$DISCARDED_FUNCTIONS", execute, discard)) \ + __pragma(warning(push)) __pragma(warning(disable : 4026 4716)) \ + retval_type name arglist {} \ + __pragma(warning(pop)) +#endif + +#if defined(COMPILER_XTENSA) || defined(COMPILER_GNU) +#define DISCARD_FUN(retval_type, name, arglist) \ + __asm__(".type " #name ", @object\n\t.global " #name \ + "\n\t.align 4\n\t" #name ":\n\t.long 0x49438B96,0x4D73F192\n\t"); +#endif + +/*------ LIST OF DEFINES DEPENDING ON ISA OPTIONS ------*/ + +/* Single-precision Extended Vector Floating-point option */ +#if ((XCHAL_HAVE_VISION_SP_VFPU)) +#define HAVE_SPX_VFPU 1 +#else +#define HAVE_SPX_VFPU 0 +#endif + +/* all vector single precision/Extended vector floating point instructions */ +#if ((XCHAL_HAVE_VISION_SP_VFPU)) +#define HAVE_SPX_VFPU 1 +#define HAVE_VFPU 1 +#else +#define HAVE_SPX_VFPU 0 +#define HAVE_VFPU 0 +#endif + +/* all scalar single precision floating point instructions */ +#if ((XCHAL_HAVE_VISION_SP_VFPU) || (XCHAL_HAVE_FP)) +#define HAVE_FPU 1 +#else +#define HAVE_FPU 0 +#endif + +#else +#define HAVE_VFPU 0 +#define HAVE_FPU 0 +#endif + +/* detect if half precision FPU is present in a core */ +#if ((XCHAL_HAVE_VISION_HP_VFPU)) +#define HAVE_HPFPU 1 +#include +#else +#define HAVE_HPFPU 0 +#endif + +/* detect if double precision FPU is present in a core */ +#if ((XCHAL_HAVE_VISION_DP_VFPU)) +#define HAVE_DPFPU 1 +#include +#else +#define HAVE_DPFPU 0 +#endif + +/* + 32x32 multiplier +*/ +#if defined(BBE_MULN_2X32) +#define HAVE_32X32 1 +#else +#define HAVE_32X32 0 +#endif + +#ifdef __cplusplus +#define externC extern "C" +#else +#define externC extern +#endif + +#endif // __COMMON_H__ diff --git a/backends/cadence/vision/third-party/include_private/expf_tbl.h b/backends/cadence/vision/third-party/include_private/expf_tbl.h new file mode 100644 index 00000000000..702164aba11 --- /dev/null +++ b/backends/cadence/vision/third-party/include_private/expf_tbl.h @@ -0,0 +1,53 @@ +/* ------------------------------------------------------------------------ */ +/* Copyright (c) 2024 by Cadence Design Systems, Inc. ALL RIGHTS RESERVED. */ +/* These coded instructions, statements, and computer programs ('Cadence */ +/* Libraries') are the copyrighted works of Cadence Design Systems Inc. */ +/* Cadence IP is licensed for use with Cadence processor cores only and */ +/* must not be used for any other processors and platforms. Your use of the */ +/* Cadence Libraries is subject to the terms of the license agreement you */ +/* have entered into with Cadence Design Systems, or a sublicense granted */ +/* to you by a direct Cadence licensee. */ +/* ------------------------------------------------------------------------ */ +/* IntegrIT, Ltd. www.integrIT.com, info@integrIT.com */ +/* */ +/* NatureDSP_Baseband Library */ +/* */ +/* This library contains copyrighted materials, trade secrets and other */ +/* proprietary information of IntegrIT, Ltd. This software is licensed for */ +/* use with Cadence processor cores only and must not be used for any other */ +/* processors and platforms. The license to use these sources was given to */ +/* Cadence, Inc. under Terms and Condition of a Software License Agreement */ +/* between Cadence, Inc. and IntegrIT, Ltd. 
*/ +/* ------------------------------------------------------------------------ */ +/* Copyright (C) 2009-2022 IntegrIT, Limited. */ +/* All Rights Reserved. */ +/* ------------------------------------------------------------------------ */ + +/* + tables for expf(x) approximation +*/ +#ifndef __EXPF_TBL_H__ +#define __EXPF_TBL_H__ + +/* Portable data types. */ +#include "dtypes.h" +#include "common.h" + +/* + polynomial coefficients for 2^x in range 0...1 + + derived by MATLAB code: + order=6; + x=(0:pow2(1,-16):1); + y=2.^x; + p=polyfit(x,y,6); + p(order+1)=1; + p(order)=p(order)-(sum(p)-2); +*/ +externC const int32_t expftbl_Q30[8]; +externC const union ufloat32uint32 + expfminmax[2]; /* minimum and maximum arguments of expf() input */ +externC const int32_t invln2_Q30; /* 1/ln(2), Q30 */ +externC const union ufloat32uint32 expftblf[7]; +externC const union ufloat32uint32 log2_e[2]; +#endif /* __EXPF_TBL_H__ */ diff --git a/backends/cadence/vision/third-party/include_private/idma_init.h b/backends/cadence/vision/third-party/include_private/idma_init.h new file mode 100644 index 00000000000..a885bdf6086 --- /dev/null +++ b/backends/cadence/vision/third-party/include_private/idma_init.h @@ -0,0 +1,36 @@ +#ifndef __IDMA__INIT_H__ +#define __IDMA__INIT_H__ + +#include "../include/dtypes.h" +#include "common.h" + + // 4 kb x sizeof(float32_t) = 16 kb DRAM storage. Assume 4 buffers (2 input and 2 output) +#define IDMA_BUFF_SIZE 4096 + +#ifndef PLACE_IN_DRAM0 +#define PLACE_IN_DRAM0 \ + __attribute__((aligned(2 * IVP_SIMD_WIDTH), section(".dram0.data"))) +#endif + +#ifndef PLACE_IN_DRAM1 +#define PLACE_IN_DRAM1 \ + __attribute__((aligned(2 * IVP_SIMD_WIDTH), section(".dram1.data"))) +#endif + +float32_t data_dram0[IDMA_BUFF_SIZE / 2] PLACE_IN_DRAM0; +float32_t data_dram1[IDMA_BUFF_SIZE / 2] PLACE_IN_DRAM1; + +float32_t* inpData[2] = {&data_dram0[0], &data_dram1[0]}; +float32_t* outData[2] = { + &data_dram0[IDMA_BUFF_SIZE / 4], + &data_dram1[IDMA_BUFF_SIZE / 4]}; + +IDMA_BUFFER_DEFINE(buffer_idma_ch0, 1, IDMA_2D_DESC); +IDMA_BUFFER_DEFINE(buffer_idma_ch1, 1, IDMA_2D_DESC); + +idma_buffer_t* descbuf[] = { + buffer_idma_ch0, + buffer_idma_ch1, +}; + +#endif // __IDMA__INIT_H__ diff --git a/backends/cadence/vision/third-party/include_private/inff_tbl.h b/backends/cadence/vision/third-party/include_private/inff_tbl.h new file mode 100644 index 00000000000..1326e92a3c1 --- /dev/null +++ b/backends/cadence/vision/third-party/include_private/inff_tbl.h @@ -0,0 +1,39 @@ +/* ------------------------------------------------------------------------ */ +/* Copyright (c) 2024 by Cadence Design Systems, Inc. ALL RIGHTS RESERVED. */ +/* These coded instructions, statements, and computer programs ('Cadence */ +/* Libraries') are the copyrighted works of Cadence Design Systems Inc. */ +/* Cadence IP is licensed for use with Cadence processor cores only and */ +/* must not be used for any other processors and platforms. Your use of the */ +/* Cadence Libraries is subject to the terms of the license agreement you */ +/* have entered into with Cadence Design Systems, or a sublicense granted */ +/* to you by a direct Cadence licensee. */ +/* ------------------------------------------------------------------------ */ +/* IntegrIT, Ltd. www.integrIT.com, info@integrIT.com */ +/* */ +/* NatureDSP_Baseband Library */ +/* */ +/* This library contains copyrighted materials, trade secrets and other */ +/* proprietary information of IntegrIT, Ltd. 
This software is licensed for */ +/* use with Cadence processor cores only and must not be used for any other */ +/* processors and platforms. The license to use these sources was given to */ +/* Cadence, Inc. under Terms and Condition of a Software License Agreement */ +/* between Cadence, Inc. and IntegrIT, Ltd. */ +/* ------------------------------------------------------------------------ */ +/* Copyright (C) 2009-2022 IntegrIT, Limited. */ +/* All Rights Reserved. */ +/* ------------------------------------------------------------------------ */ + +/* + Infinities for single precision routines +*/ +#ifndef __INFF_TBL_H__ +#define __INFF_TBL_H__ + +#include "dtypes.h" +#include "common.h" + +externC const union ufloat32uint32 minusInff; /* -Inf */ +externC const union ufloat32uint32 plusInff; /* +Inf */ +externC const union ufloat32uint32 realmaxf; /* maximum floating point number */ +externC const union ufloat32uint32 realminf; /* minimum floating point number */ +#endif /* __INFF_TBL_H__ */ diff --git a/backends/cadence/vision/third-party/include_private/nanf_tbl.h b/backends/cadence/vision/third-party/include_private/nanf_tbl.h new file mode 100644 index 00000000000..4881b99f070 --- /dev/null +++ b/backends/cadence/vision/third-party/include_private/nanf_tbl.h @@ -0,0 +1,42 @@ +/* ------------------------------------------------------------------------ */ +/* Copyright (c) 2024 by Cadence Design Systems, Inc. ALL RIGHTS RESERVED. */ +/* These coded instructions, statements, and computer programs ('Cadence */ +/* Libraries') are the copyrighted works of Cadence Design Systems Inc. */ +/* Cadence IP is licensed for use with Cadence processor cores only and */ +/* must not be used for any other processors and platforms. Your use of the */ +/* Cadence Libraries is subject to the terms of the license agreement you */ +/* have entered into with Cadence Design Systems, or a sublicense granted */ +/* to you by a direct Cadence licensee. */ +/* ------------------------------------------------------------------------ */ +/* IntegrIT, Ltd. www.integrIT.com, info@integrIT.com */ +/* */ +/* NatureDSP_Baseband Library */ +/* */ +/* This library contains copyrighted materials, trade secrets and other */ +/* proprietary information of IntegrIT, Ltd. This software is licensed for */ +/* use with Cadence processor cores only and must not be used for any other */ +/* processors and platforms. The license to use these sources was given to */ +/* Cadence, Inc. under Terms and Condition of a Software License Agreement */ +/* between Cadence, Inc. and IntegrIT, Ltd. */ +/* ------------------------------------------------------------------------ */ +/* Copyright (C) 2009-2022 IntegrIT, Limited. */ +/* All Rights Reserved. */ +/* ------------------------------------------------------------------------ */ +/* + NaN values for single precision routines +*/ + +#ifndef __NANF_TBL_H__ +#define __NANF_TBL_H__ + +/* Portable data types. */ +#include "dtypes.h" +/* Common utility macros. 
*/ +#include "common.h" + +extern const union ufloat32uint32 sNaNf; /* Signalling NaN */ +extern const union ufloat32uint32 qNaNf; /* Quiet NaN */ +extern const union ufloat32uint32 minus_sNaNf; /* Negative Signalling NaN */ +extern const union ufloat32uint32 minus_qNaNf; /* Negative Quiet NaN */ + +#endif /* __NANF_TBL_H__ */ diff --git a/backends/cadence/vision/third-party/library/api/tensor_transposef.c b/backends/cadence/vision/third-party/library/api/tensor_transposef.c new file mode 100644 index 00000000000..e6865033740 --- /dev/null +++ b/backends/cadence/vision/third-party/library/api/tensor_transposef.c @@ -0,0 +1,167 @@ +#include "api.h" +#include "common.h" + +/* + * Currently only supports upto 5D input tensors. + * 1/2/3/4 D input tensors will be scaled up to 5D. + * For example, 2x3 -> 1x1x1x2x3. + */ + +void tensor_transposef(float32_t *restrict ptr_out + ,const int *const ptr_out_shape + ,const float32_t *restrict ptr_inp + ,const int *const ptr_inp_shape + ,const int *restrict ptr_permute_vec + ,int num_out_dims + ,int num_inp_dims) +{ + + /* Shift all dim with 1 in the outer part */ + int eff_output_shape[5]; + int eff_permute_vec[5]; + + for (int i = 0; i < num_out_dims; i++){ + eff_output_shape[i] = ptr_out_shape[i]; + eff_permute_vec[i] = ptr_permute_vec[i]; + } + + int one_i = num_out_dims - 1, non_one_i = num_out_dims - 1; + while (one_i > 0 && non_one_i >= 0){ + while (one_i > 0 && eff_output_shape[one_i] != 1){ + one_i--; + } + non_one_i = one_i; + while (non_one_i >= 0 && eff_output_shape[non_one_i]==1){ + non_one_i--; + } + if (one_i > 0 && non_one_i >= 0){ + int temp; + /*swap output_shape*/ + { + temp = eff_output_shape[one_i]; + eff_output_shape[one_i] = eff_output_shape[non_one_i]; + eff_output_shape[non_one_i] = temp; + } + /*swap permute_vec*/ + { + temp = eff_permute_vec[one_i]; + eff_permute_vec[one_i] = eff_permute_vec[non_one_i]; + eff_permute_vec[non_one_i] = temp; + } + } + } + + /* Promoting lesser dim tensors to 5D tensors. + * Also updating the permute_vec and shapes as needed for optimization */ + int ptr_5D_inp_shape[5] = {1, 1, 1, 1, 1}; + int ptr_5D_out_shape[5] = {1, 1, 1, 1, 1}; + int ptr_5D_permute_vec[5] = {0, 1, 2, 3, 4}; + + /* Check if any inner inp dimension is same in the output */ + int last_dim_same = 1, last_n_same_dim = 0; + int itr = num_inp_dims - 1; + while(itr >= 0){ + last_n_same_dim = (last_dim_same && (eff_permute_vec[itr] == itr)) ? (last_n_same_dim + 1) : last_n_same_dim; + last_dim_same = (eff_permute_vec[itr] == itr) ? last_dim_same & 1 : last_dim_same & 0; + itr--; + } + + int dims_added = 5 - num_inp_dims; + itr = num_inp_dims - 1; + int same_count = last_n_same_dim; + int count = 4; + while(itr >= 0){ + ptr_5D_inp_shape[count] = (same_count > 0) ? ptr_5D_inp_shape[count] * ptr_inp_shape[itr] : ptr_inp_shape[itr]; + ptr_5D_out_shape[count] = (same_count > 0) ? ptr_5D_out_shape[count] * eff_output_shape[itr] : eff_output_shape[itr]; + same_count--; + itr--; + count = (same_count > 0) ? count : count - 1; + } + + itr = num_inp_dims - 1; + same_count = (last_n_same_dim) ? num_inp_dims - (last_n_same_dim - 1) : 0; + count = 4; + while(itr >= 0){ + ptr_5D_permute_vec[count] = (same_count > 0) ? 
eff_permute_vec[itr-(last_n_same_dim - 1)] + dims_added + last_n_same_dim - 1 : eff_permute_vec[itr] + dims_added; + same_count--; + itr--; + count--; + } + + int out_dim0, out_dim1, out_dim2, out_dim3, out_dim4; + int inp_dim1, inp_dim2, inp_dim3, inp_dim4; + int inp_stride[5]; + + out_dim0 = ptr_5D_out_shape[0]; + out_dim1 = ptr_5D_out_shape[1]; + out_dim2 = ptr_5D_out_shape[2]; + out_dim3 = ptr_5D_out_shape[3]; + out_dim4 = ptr_5D_out_shape[4]; + + inp_dim1 = ptr_5D_inp_shape[1]; + inp_dim2 = ptr_5D_inp_shape[2]; + inp_dim3 = ptr_5D_inp_shape[3]; + inp_dim4 = ptr_5D_inp_shape[4]; + + inp_stride[0] = inp_dim1 * inp_dim2 * inp_dim3 * inp_dim4; + inp_stride[1] = inp_dim2 * inp_dim3 * inp_dim4; + inp_stride[2] = inp_dim3 * inp_dim4; + inp_stride[3] = inp_dim4; + inp_stride[4] = 1; + + if (last_n_same_dim){ + int itr0, itr1, itr2, itr3, itr4; + float32_t *ptr_inp0 = (float32_t *)ptr_inp; + for (itr0 = 0; itr0 < out_dim0; itr0++){ + float32_t *ptr_inp1 = ptr_inp0 + (itr0 * inp_stride[ptr_5D_permute_vec[0]]); +#pragma looptr_count min=1 + for (itr1 = 0; itr1 < out_dim1; itr1++){ + float32_t *ptr_inp2 = ptr_inp1 + (itr1 * inp_stride[ptr_5D_permute_vec[1]]); +#pragma looptr_count min=1 + for (itr2 = 0; itr2 < out_dim2; itr2++){ + float32_t *ptr_inp3 = ptr_inp2 + (itr2 * inp_stride[ptr_5D_permute_vec[2]]); +#pragma looptr_count min=1 + for (itr3 = 0; itr3 < out_dim3; itr3++, ptr_out += out_dim4){ + float32_t *ptr_inp4 = ptr_inp3 + (itr3 * inp_stride[ptr_5D_permute_vec[3]]); + xb_vecN_2xf32 *restrict pae_i = (xb_vecN_2xf32 *)(ptr_inp4); + xb_vecN_2xf32 *restrict pae_o = (xb_vecN_2xf32 *)(ptr_out); + valign a_inp = IVP_LAN_2XF32_PP(pae_i); + valign a_out = IVP_ZALIGN(); + xb_vecN_2xf32 d0; + for(itr4 = 0; itr4 < (out_dim4 >> (LOG2_IVP_SIMD_WIDTH - 1)); itr4++){ + IVP_LAN_2XF32_IP(d0, a_inp, pae_i); + IVP_SAN_2XF32_IP(d0, a_out, pae_o); + } + IVP_SAPOSN_2XF32_FP(a_out, pae_o); + float32_t *restrict puae_i = (float32_t *)(pae_i); + float32_t *restrict puae_o = (float32_t *)(pae_o); +#pragma looptr_count max = 17 + for(itr4 = 0; itr4 < (out_dim4 & (IVP_SIMD_WIDTH / 2 - 1)); itr4++){ + puae_o[itr4] = puae_i[itr4]; + } + } + } + } + } + } + else{ + int itr0, itr1, itr2, itr3, itr4; + float32_t *ptr_inp0 = (float32_t *)ptr_inp; + for(itr0 = 0; itr0 < out_dim0; itr0++){ + float32_t *ptr_inp1 = ptr_inp0 + (itr0 * inp_stride[ptr_5D_permute_vec[0]]); + for(itr1 = 0; itr1 < out_dim1; itr1++){ + float32_t *ptr_inp2 = ptr_inp1 + (itr1 * inp_stride[ptr_5D_permute_vec[1]]); + for(itr2 = 0; itr2 < out_dim2; itr2++){ + float32_t *ptr_inp3 = ptr_inp2 + (itr2 * inp_stride[ptr_5D_permute_vec[2]]); + for(itr3 = 0; itr3 < out_dim3; itr3++){ + float32_t *ptr_inp4 = ptr_inp3 + (itr3 * inp_stride[ptr_5D_permute_vec[3]]); + for(itr4 = 0; itr4 < out_dim4; itr4++){ + *ptr_out++ = *ptr_inp4; + ptr_inp4 = ptr_inp4 + inp_stride[ptr_5D_permute_vec[4]]; + } + } + } + } + } + } +} diff --git a/backends/cadence/vision/third-party/library/api/vsoftmaxf.c b/backends/cadence/vision/third-party/library/api/vsoftmaxf.c new file mode 100644 index 00000000000..27487c75d6c --- /dev/null +++ b/backends/cadence/vision/third-party/library/api/vsoftmaxf.c @@ -0,0 +1,241 @@ +/* ------------------------------------------------------------------------ */ +/* Copyright (c) 2024 by Cadence Design Systems, Inc. ALL RIGHTS RESERVED. */ +/* These coded instructions, statements, and computer programs ('Cadence */ +/* Libraries') are the copyrighted works of Cadence Design Systems Inc. 
*/ +/* Cadence IP is licensed for use with Cadence processor cores only and */ +/* must not be used for any other processors and platforms. Your use of the */ +/* Cadence Libraries is subject to the terms of the license agreement you */ +/* have entered into with Cadence Design Systems, or a sublicense granted */ +/* to you by a direct Cadence licensee. */ +/* ------------------------------------------------------------------------ */ +/* IntegrIT, Ltd. www.integrIT.com, info@integrIT.com */ +/* */ +/* NatureDSP_Baseband Library */ +/* */ +/* This library contains copyrighted materials, trade secrets and other */ +/* proprietary information of IntegrIT, Ltd. This software is licensed for */ +/* use with Cadence processor cores only and must not be used for any other */ +/* processors and platforms. The license to use these sources was given to */ +/* Cadence, Inc. under Terms and Condition of a Software License Agreement */ +/* between Cadence, Inc. and IntegrIT, Ltd. */ +/* ------------------------------------------------------------------------ */ +/* Copyright (C) 2009-2022 IntegrIT, Limited. */ +/* All Rights Reserved. */ +/* ------------------------------------------------------------------------ */ +/* + NatureDSP_Baseband library. Vector Mathematics. + Softmax, floating-point data +*/ +#include "api.h" +#include "common.h" +#include "expf_tbl.h" +#include "inff_tbl.h" +#include "nanf_tbl.h" + +/*------------------------------------------------------------------------- +Softmax + +Description: The function computes the softmax (normalized exponential +function) of input data. 16-bit fixed-point functions accept inputs in +Q3.12 and form outputs in Q7.8 format. + +vsoftmax 16-bit +vsoftmax_fp16 IEEE-754 Std. half precision floating-point. +vsoftmaxf IEEE-754 Std. single precision floating-point. + +Accuracy: +2 LSB for fixed point API +2 ULP for floating point API +NOTE: Accuracy of function may depend on amount of data and their +distribution. Given accuracy is achieved for N=2 for any pair of +data from input domain. 
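+
+Example (informal): for a vector of N single-precision scores,
+  vsoftmaxf(y, x, N);
+computes y[i] = exp(x[i]) / sum_j exp(x[j]) (evaluated with the maximum
+subtracted for numerical stability), with x and y non-overlapping.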
+ + +Parameters: +Input +: +x[N] input data, Q3.12 floating point +N Length of input/output data vectors +Output: +y[N] result, Q7.8 or floating point + +Restrictions: +x,y Must not overlap +-------------------------------------------------------------------------*/ + +#define IVP_ADDSN_2X32(b_, c_) \ + ({ \ + xb_vecN_2x32v a_; \ + xb_vecN_2x64w tmp_a_; \ + tmp_a_ = IVP_MULN_2X32(b_, 1); \ + IVP_MULAN_2X32(tmp_a_, c_, 1); \ + a_ = IVP_PACKVRN_2X64W(tmp_a_, 0); \ + a_; \ + }) + +#if !HAVE_VFPU +DISCARD_FUN(void, vsoftmaxf, (float32_t * y, const float32_t* x, int N)) +#else +void vsoftmaxf(float32_t* y, const float32_t* x, int N) { +#if !defined(IVP_MULN_2X32) +#else + const int* pTbl = (const int*)expftbl_Q30; +#endif + const xb_vecN_2xf32* restrict pX; + xb_vecN_2xf32* restrict pY; + xb_vecN_2xf32 norm, ysum, xmax; + int n; + valign al_X, al_R, al_Y; + if (N < 0) + return; + xmax = minusInff.f; + pX = (const xb_vecN_2xf32*)x; + al_X = IVP_LAN_2XF32_PP(pX); + al_Y = IVP_ZALIGN(); + for (n = 0; n < (N >> (LOG2_IVP_SIMD_WIDTH - 1)); n++) { + xb_vecN_2xf32 x; + IVP_LAN_2XF32_IP(x, al_X, pX); + xmax = IVP_MAXNUMN_2XF32(xmax, x); + } + if (N & (IVP_SIMD_WIDTH / 2 - 1)) { + xb_vecN_2xf32 x; + IVP_LAVN_2XF32_XP( + x, al_X, pX, sizeof(float32_t) * (N & (IVP_SIMD_WIDTH / 2 - 1))); + IVP_MAXNUMN_2XF32T( + xmax, xmax, x, IVP_LTRSN_2((N & (IVP_SIMD_WIDTH / 2 - 1)))); + } + + xmax = IVP_REPN_2XF32(IVP_RMAXNUMN_2XF32(xmax), 0); + __Pragma("no_reorder"); + ysum = 0.f; + pX = (const xb_vecN_2xf32*)x; + pY = (xb_vecN_2xf32*)y; + al_X = IVP_LAN_2XF32_PP(pX); + { + vboolN_2 bnan; + bnan = IVP_LTRN_2I(0); + for (n = 0; n < (N >> (LOG2_IVP_SIMD_WIDTH - 1)); n++) { + xb_vecN_2xf32 x; + IVP_LAN_2XF32_IP(x, al_X, pX); + x = IVP_SUBN_2XF32(x, xmax); + bnan |= IVP_UNN_2XF32(x, x); + { + xb_vecN_2xf32 gf, zout; + xb_vecN_2x32v xin_i, fr, exp, t; + xb_vecN_2x32v y, y1, y2, c1, c2, f2; + xb_vecN_2x64w w; + xin_i = IVP_TRUNCN_2XF32(x, 24); + /* Multiply by 1/ln2, extract the integer and fractional (Q32) + * components. 
*/ + /* Q54 <- Q24*Q30 */ + w = IVP_MULN_2X32(xin_i, invln2_Q30); + exp = IVP_PACKVRNRN_2X64W(w, 54); + fr = IVP_SRLN_2X32(IVP_PACKVRNRN_2X64W(w, 22), 1); + /* polynomial for 2^x */ + f2 = IVP_PACKVRN_2X64W(IVP_MULN_2X32(fr, fr), 31); + y1 = IVP_LSRN_2X32_I(pTbl, 0 * sizeof(int32_t)); + y2 = IVP_LSRN_2X32_I(pTbl, 1 * sizeof(int32_t)); + c1 = IVP_LSRN_2X32_I(pTbl, 2 * sizeof(int32_t)); + t = IVP_PACKVRN_2X64W(IVP_MULN_2X32(f2, y1), 31); + y1 = IVP_ADDSN_2X32(c1, t); + c2 = IVP_LSRN_2X32_I(pTbl, 3 * sizeof(int32_t)); + t = IVP_PACKVRN_2X64W(IVP_MULN_2X32(f2, y2), 31); + y2 = IVP_ADDSN_2X32(c2, t); + c1 = IVP_LSRN_2X32_I(pTbl, 4 * sizeof(int32_t)); + t = IVP_PACKVRN_2X64W(IVP_MULN_2X32(f2, y1), 31); + y1 = IVP_ADDSN_2X32(c1, t); + c2 = IVP_LSRN_2X32_I(pTbl, 5 * sizeof(int32_t)); + t = IVP_PACKVRN_2X64W(IVP_MULN_2X32(f2, y2), 31); + y2 = IVP_ADDSN_2X32(c2, t); + c1 = IVP_LSRN_2X32_I(pTbl, 6 * sizeof(int32_t)); + t = IVP_PACKVRN_2X64W(IVP_MULN_2X32(f2, y1), 31); + y1 = IVP_ADDSN_2X32(c1, t); + t = IVP_PACKVRN_2X64W(IVP_MULN_2X32(fr, y2), 31); + y = IVP_ADDSN_2X32(y1, t); + /* scale result to original exponent ignoring very low items */ + gf = IVP_FLOATN_2X32(y, 30); + exp = IVP_SLLIN_2X32(IVP_MAXN_2X32(IVP_ADDN_2X32(127, exp), 0), 23); + zout = IVP_MULN_2XF32(gf, IVP_MOVN_2XF32_FROMN_2X32(exp)); + x = zout; + } + ysum = IVP_ADDN_2XF32(ysum, x); + IVP_SAN_2XF32_IP(x, al_Y, pY); + } + if (N & (IVP_SIMD_WIDTH / 2 - 1)) { + xb_vecN_2xf32 x; + IVP_LAVN_2XF32_XP( + x, al_X, pX, sizeof(float32_t) * (N & (IVP_SIMD_WIDTH / 2 - 1))); + x = IVP_SUBN_2XF32(x, xmax); + bnan |= IVP_UNN_2XF32(x, x); + { + xb_vecN_2xf32 gf, zout; + xb_vecN_2x32v xin_i, fr, exp, t; + xb_vecN_2x32v y, y1, y2, c1, c2, f2; + xb_vecN_2x64w w; + xin_i = IVP_TRUNCN_2XF32(x, 24); + /* Multiply by 1/ln2, extract the integer and fractional (Q32) + * components. 
*/ + /* Q54 <- Q24*Q30 */ + w = IVP_MULN_2X32(xin_i, invln2_Q30); + exp = IVP_PACKVRNRN_2X64W(w, 54); + fr = IVP_SRLN_2X32(IVP_PACKVRNRN_2X64W(w, 22), 1); + /* polynomial for 2^x */ + f2 = IVP_PACKVRN_2X64W(IVP_MULN_2X32(fr, fr), 31); + y1 = IVP_LSRN_2X32_I(pTbl, 0 * sizeof(int32_t)); + y2 = IVP_LSRN_2X32_I(pTbl, 1 * sizeof(int32_t)); + c1 = IVP_LSRN_2X32_I(pTbl, 2 * sizeof(int32_t)); + t = IVP_PACKVRN_2X64W(IVP_MULN_2X32(f2, y1), 31); + y1 = IVP_ADDSN_2X32(c1, t); + c2 = IVP_LSRN_2X32_I(pTbl, 3 * sizeof(int32_t)); + t = IVP_PACKVRN_2X64W(IVP_MULN_2X32(f2, y2), 31); + y2 = IVP_ADDSN_2X32(c2, t); + c1 = IVP_LSRN_2X32_I(pTbl, 4 * sizeof(int32_t)); + t = IVP_PACKVRN_2X64W(IVP_MULN_2X32(f2, y1), 31); + y1 = IVP_ADDSN_2X32(c1, t); + c2 = IVP_LSRN_2X32_I(pTbl, 5 * sizeof(int32_t)); + t = IVP_PACKVRN_2X64W(IVP_MULN_2X32(f2, y2), 31); + y2 = IVP_ADDSN_2X32(c2, t); + c1 = IVP_LSRN_2X32_I(pTbl, 6 * sizeof(int32_t)); + t = IVP_PACKVRN_2X64W(IVP_MULN_2X32(f2, y1), 31); + y1 = IVP_ADDSN_2X32(c1, t); + t = IVP_PACKVRN_2X64W(IVP_MULN_2X32(fr, y2), 31); + y = IVP_ADDSN_2X32(y1, t); + /* scale result to original exponent ignoring very low items */ + gf = IVP_FLOATN_2X32(y, 30); + exp = IVP_SLLIN_2X32(IVP_MAXN_2X32(IVP_ADDN_2X32(127, exp), 0), 23); + zout = IVP_MULN_2XF32(gf, IVP_MOVN_2XF32_FROMN_2X32(exp)); + x = zout; + } + IVP_ADDN_2XF32T( + ysum, ysum, x, IVP_LTRSN_2((N & (IVP_SIMD_WIDTH / 2 - 1)))); + IVP_SAVN_2XF32_XP( + x, al_Y, pY, sizeof(float32_t) * (N & (IVP_SIMD_WIDTH / 2 - 1))); + } + IVP_SAPOSN_2XF32_FP(al_Y, pY); + ysum = IVP_MOVN_2XF32T(qNaNf.f, ysum, bnan); + } + norm = XT_RECIP_S(IVP_RADDN_2XF32(ysum)); + __Pragma("no_reorder"); + pX = (const xb_vecN_2xf32*)y; + pY = (xb_vecN_2xf32*)y; + + al_R = IVP_LAN_2XF32_PP(pX); + + for (n = 0; n < (N >> (LOG2_IVP_SIMD_WIDTH - 1)); n++) { + xb_vecN_2xf32 x; + IVP_LAN_2XF32_IP(x, al_R, pX); + x = IVP_MULN_2XF32(x, norm); + IVP_SAN_2XF32_IP(x, al_Y, pY); + } + if (N & (IVP_SIMD_WIDTH / 2 - 1)) { + xb_vecN_2xf32 x; + IVP_LAVN_2XF32_XP( + x, al_R, pX, sizeof(float32_t) * (N & (IVP_SIMD_WIDTH / 2 - 1))); + x = IVP_MULN_2XF32(x, norm); + IVP_SAVN_2XF32_XP( + x, al_Y, pY, sizeof(float32_t) * (N & (IVP_SIMD_WIDTH / 2 - 1))); + } + IVP_SAPOSN_2XF32_FP(al_Y, pY); + +} /* vsoftmaxf() */ +#endif diff --git a/backends/cadence/vision/third-party/library/tables/expf_tbl.c b/backends/cadence/vision/third-party/library/tables/expf_tbl.c new file mode 100644 index 00000000000..f1c6f3d44ae --- /dev/null +++ b/backends/cadence/vision/third-party/library/tables/expf_tbl.c @@ -0,0 +1,85 @@ +/* ------------------------------------------------------------------------ */ +/* Copyright (c) 2024 by Cadence Design Systems, Inc. ALL RIGHTS RESERVED. */ +/* These coded instructions, statements, and computer programs ('Cadence */ +/* Libraries') are the copyrighted works of Cadence Design Systems Inc. */ +/* Cadence IP is licensed for use with Cadence processor cores only and */ +/* must not be used for any other processors and platforms. Your use of the */ +/* Cadence Libraries is subject to the terms of the license agreement you */ +/* have entered into with Cadence Design Systems, or a sublicense granted */ +/* to you by a direct Cadence licensee. */ +/* ------------------------------------------------------------------------ */ +/* IntegrIT, Ltd. www.integrIT.com, info@integrIT.com */ +/* */ +/* NatureDSP_Baseband Library */ +/* */ +/* This library contains copyrighted materials, trade secrets and other */ +/* proprietary information of IntegrIT, Ltd. 
This software is licensed for */ +/* use with Cadence processor cores only and must not be used for any other */ +/* processors and platforms. The license to use these sources was given to */ +/* Cadence, Inc. under Terms and Condition of a Software License Agreement */ +/* between Cadence, Inc. and IntegrIT, Ltd. */ +/* ------------------------------------------------------------------------ */ +/* Copyright (C) 2009-2022 IntegrIT, Limited. */ +/* All Rights Reserved. */ +/* ------------------------------------------------------------------------ */ + +/* + tables for expf(x) approximation +*/ +/* Portable data types. */ +#include "expf_tbl.h" +#include "dtypes.h" + +/* + polynomial coefficients for 2^x in range 0...1 + + derived by MATLAB code: + order=6; + x=(0:pow2(1,-16):1); + y=2.^x; + p=polyfit(x,y,6); + p(order+1)=1; + p(order)=p(order)-(sum(p)-2); +*/ +const int32_t ALIGN_2SIMD expftbl_Q30[8] = { + 234841, + 1329551, + 10400465, + 59570027, + 257946177, + 744260763, + 1073741824, + 0 /* Padding to allow for vector loads */ +}; + +const union ufloat32uint32 ALIGN_2SIMD + expfminmax[2] = /* minimum and maximum arguments of expf() input */ + { + {0xc2ce8ed0}, /*-1.0327893066e+002f */ + {0x42b17218} /* 8.8722839355e+001f */ +}; + +const int32_t invln2_Q30 = 1549082005L; /* 1/ln(2), Q30 */ + +const union ufloat32uint32 ALIGN_2SIMD log2_e[2] = { + {0x3fb8aa3b}, /* 1.4426950216 */ + {0x32a57060} /* 1.9259629891e-008 */ +}; + +/* +order=6; +x=(0:pow2(1,-16):1); +y=2.^x; +p=polyfit(x,y,order); +p(order+1)=1; +p(order)=p(order)-(sum(p)-2); +num2hex(single(p)); +*/ +const union ufloat32uint32 ALIGN_2SIMD expftblf[] = { + {0x39655635}, + {0x3aa24c7a}, + {0x3c1eb2d1}, + {0x3d633ddb}, + {0x3e75ff24}, + {0x3f317212}, + {0x3f800000}}; diff --git a/backends/cadence/vision/third-party/library/tables/inff_tbl.c b/backends/cadence/vision/third-party/library/tables/inff_tbl.c new file mode 100644 index 00000000000..8464ee9f549 --- /dev/null +++ b/backends/cadence/vision/third-party/library/tables/inff_tbl.c @@ -0,0 +1,38 @@ +/* ------------------------------------------------------------------------ */ +/* Copyright (c) 2024 by Cadence Design Systems, Inc. ALL RIGHTS RESERVED. */ +/* These coded instructions, statements, and computer programs ('Cadence */ +/* Libraries') are the copyrighted works of Cadence Design Systems Inc. */ +/* Cadence IP is licensed for use with Cadence processor cores only and */ +/* must not be used for any other processors and platforms. Your use of the */ +/* Cadence Libraries is subject to the terms of the license agreement you */ +/* have entered into with Cadence Design Systems, or a sublicense granted */ +/* to you by a direct Cadence licensee. */ +/* ------------------------------------------------------------------------ */ +/* IntegrIT, Ltd. www.integrIT.com, info@integrIT.com */ +/* */ +/* NatureDSP_Baseband Library */ +/* */ +/* This library contains copyrighted materials, trade secrets and other */ +/* proprietary information of IntegrIT, Ltd. This software is licensed for */ +/* use with Cadence processor cores only and must not be used for any other */ +/* processors and platforms. The license to use these sources was given to */ +/* Cadence, Inc. under Terms and Condition of a Software License Agreement */ +/* between Cadence, Inc. and IntegrIT, Ltd. */ +/* ------------------------------------------------------------------------ */ +/* Copyright (C) 2009-2022 IntegrIT, Limited. */ +/* All Rights Reserved. 
*/ +/* ------------------------------------------------------------------------ */ + +/* + infinities for single precision routines +*/ + +#include "inff_tbl.h" +#include "dtypes.h" + +const union ufloat32uint32 minusInff = {0xff800000}; /* -Inf */ +const union ufloat32uint32 plusInff = {0x7f800000}; /* +Inf */ +const union ufloat32uint32 realmaxf = { + 0x7f7fffff}; /* maximum floating point number */ +const union ufloat32uint32 realminf = { + 0x00800000}; /* minimum floating point number */ diff --git a/backends/cadence/vision/third-party/library/tables/nanf_tbl.c b/backends/cadence/vision/third-party/library/tables/nanf_tbl.c new file mode 100644 index 00000000000..f165234fce4 --- /dev/null +++ b/backends/cadence/vision/third-party/library/tables/nanf_tbl.c @@ -0,0 +1,38 @@ +/* ------------------------------------------------------------------------ */ +/* Copyright (c) 2024 by Cadence Design Systems, Inc. ALL RIGHTS RESERVED. */ +/* These coded instructions, statements, and computer programs ('Cadence */ +/* Libraries') are the copyrighted works of Cadence Design Systems Inc. */ +/* Cadence IP is licensed for use with Cadence processor cores only and */ +/* must not be used for any other processors and platforms. Your use of the */ +/* Cadence Libraries is subject to the terms of the license agreement you */ +/* have entered into with Cadence Design Systems, or a sublicense granted */ +/* to you by a direct Cadence licensee. */ +/* ------------------------------------------------------------------------ */ +/* IntegrIT, Ltd. www.integrIT.com, info@integrIT.com */ +/* */ +/* NatureDSP_Baseband Library */ +/* */ +/* This library contains copyrighted materials, trade secrets and other */ +/* proprietary information of IntegrIT, Ltd. This software is licensed for */ +/* use with Cadence processor cores only and must not be used for any other */ +/* processors and platforms. The license to use these sources was given to */ +/* Cadence, Inc. under Terms and Condition of a Software License Agreement */ +/* between Cadence, Inc. and IntegrIT, Ltd. */ +/* ------------------------------------------------------------------------ */ +/* Copyright (C) 2009-2022 IntegrIT, Limited. */ +/* All Rights Reserved. */ +/* ------------------------------------------------------------------------ */ +/* + NaN values for single precision routines +*/ + +/* Portable data types. */ +/* NaN values for single precision routines. */ +#include "nanf_tbl.h" +#include "dtypes.h" + +const union ufloat32uint32 sNaNf = {0x7f800001}; /* Signalling NaN */ +const union ufloat32uint32 qNaNf = {0x7fc00000}; /* Quiet NaN */ +const union ufloat32uint32 minus_sNaNf = { + 0xff800001}; /* Negative Signalling NaN */ +const union ufloat32uint32 minus_qNaNf = {0xffc00000}; /* Negative Quiet NaN */ diff --git a/backends/cadence/vision/third-party/targets.bzl b/backends/cadence/vision/third-party/targets.bzl new file mode 100644 index 00000000000..26a097010d5 --- /dev/null +++ b/backends/cadence/vision/third-party/targets.bzl @@ -0,0 +1,38 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
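+#
+# Defines the vision-nnlib third-party library target. On Xtensa builds it
+# compiles the kernel sources under library/ with COMPILER_XTENSA defined;
+# on other platforms only dummy.c is built so that dependents still link
+# (see dummy.c).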
+ +load("@fbsource//tools/build_defs:platform_defs.bzl", "CXX") +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +load("@fbsource//arvr/tools/build_defs:oxx.bzl", "oxx_binary", "oxx_static_library") + + +def define_common_targets(): + runtime.cxx_library( + name = "vision-nnlib", + srcs = select({ + "DEFAULT": ["dummy.c"], # Use dummy file for non-Xtensa builds + "ovr_config//cpu:xtensa": glob(["library/**/*.c"]), + }), + exported_headers = glob([ + "include/*.h", + "include_private/*.h" + ]), + header_namespace = "", + visibility = [ + "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", + ], + platforms = CXX, + compatible_with = select({ + "DEFAULT": [], + "ovr_config//cpu:xtensa": ["ovr_config//cpu:xtensa"], + }), + compiler_flags = select({ + "DEFAULT": ["-UCOMPILER_XTENSA"], # Ensure COMPILER_XTENSA is not defined for non-Xtensa builds + "ovr_config//cpu:xtensa": [ + "-DCOMPILER_XTENSA", + "-Ixplat/executorch/backends/cadence/vision/third-party/include", + "-Ixplat/executorch/backends/cadence/vision/third-party/include_private", + ], + }), + define_static_target = True, + ) diff --git a/backends/cortex_m/CMakeLists.txt b/backends/cortex_m/CMakeLists.txt index 1567b8b5e1c..ac330d4b015 100644 --- a/backends/cortex_m/CMakeLists.txt +++ b/backends/cortex_m/CMakeLists.txt @@ -12,7 +12,7 @@ if(NOT CMAKE_CXX_STANDARD) set(CMAKE_CXX_STANDARD 17) endif() -# Source root directory for executorch. +# Source root directory for executorch if(NOT EXECUTORCH_ROOT) set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..) endif() @@ -21,71 +21,81 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) include(${EXECUTORCH_ROOT}/tools/cmake/Codegen.cmake) include(FetchContent) -# CMSIS-NN version to download +# CMSIS-NN configuration with dynamic path detection set(CMSIS_NN_VERSION - "v4.1.0" + "v7.0.0" CACHE STRING "CMSIS-NN version to download" ) - -# Declare CMSIS-NN as a FetchContent project -FetchContent_Declare( - cmsis_nn - GIT_REPOSITORY https://github.com/ARM-software/CMSIS-NN.git - GIT_TAG ${CMSIS_NN_VERSION} +set(CMSIS_NN_LOCAL_PATH + "" + CACHE PATH "Path to existing local CMSIS-NN installation" ) -# Download and make CMSIS-NN available -FetchContent_MakeAvailable(cmsis_nn) - -# Print paths for debugging -message(STATUS "CMSIS-NN source dir: ${cmsis_nn_SOURCE_DIR}") -message(STATUS "CMSIS-NN binary dir: ${cmsis_nn_BINARY_DIR}") +# Try to find existing / local CMSIS-NN installation. This is useful for +# debugging and testing with local changes. This is not common, as the CMSIS-NN +# library is downloaded via FetchContent in the default/regular case. 
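+# Example (hypothetical path): configure with
+#   -DCMSIS_NN_LOCAL_PATH=/path/to/CMSIS-NN
+# to reuse a local checkout; when unset or missing, CMSIS-NN
+# ${CMSIS_NN_VERSION} is fetched automatically below.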
+if(CMSIS_NN_LOCAL_PATH AND EXISTS "${CMSIS_NN_LOCAL_PATH}") + message(STATUS "Using CMSIS-NN from specified path: ${CMSIS_NN_LOCAL_PATH}") + add_subdirectory(${CMSIS_NN_LOCAL_PATH} _deps/cmsis_nn-build) +else() + # Use FetchContent with automatic fallback + message(STATUS "Using CMSIS-NN via FetchContent") + + FetchContent_Declare( + cmsis_nn + GIT_REPOSITORY https://github.com/ARM-software/CMSIS-NN.git + GIT_TAG ${CMSIS_NN_VERSION} + GIT_SHALLOW TRUE + ) + + FetchContent_MakeAvailable(cmsis_nn) +endif() # Cortex-M ops kernel sources set(_cortex_m_kernels__srcs ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantize_per_tensor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_add.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_conv2d.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_linear.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_mul.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_minimum.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_maximum.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_transpose.cpp ) -# Generate C++ bindings to register kernels into Executorch (for runtime) +# Generate C++ bindings to register kernels into Executorch set(_yaml_file ${CMAKE_CURRENT_LIST_DIR}/ops/operators.yaml) gen_selected_ops(LIB_NAME "cortex_m_ops_lib" OPS_SCHEMA_YAML "${_yaml_file}") - generate_bindings_for_kernels( LIB_NAME "cortex_m_ops_lib" CUSTOM_OPS_YAML "${_yaml_file}" ) -message("Generated files ${gen_command_sources}") -# Build a library for cortex_m_kernels +# Build library for cortex_m_kernels add_library(cortex_m_kernels ${_cortex_m_kernels__srcs}) -target_compile_options(cortex_m_kernels PUBLIC ${_common_compile_options}) -# Include directories for cortex_m_kernels -target_include_directories( +# Use PRIVATE for implementation dependencies to avoid INTERFACE pollution +target_link_libraries( cortex_m_kernels - PRIVATE ${EXECUTORCH_ROOT}/.. - ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10 - ${cmsis_nn_SOURCE_DIR}/Include + PRIVATE cmsis-nn + PRIVATE executorch ) -# Link directly to the CMSIS-NN static library file -target_link_libraries( - cortex_m_kernels PUBLIC ${cmsis_nn_BINARY_DIR}/libcmsis-nn.a executorch +# Include directories for cortex_m_kernels +target_include_directories( + cortex_m_kernels PRIVATE ${EXECUTORCH_ROOT}/.. + ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10 ) -# Add dependency to ensure CMSIS-NN builds before we try to link. Use the actual -# CMSIS-NN target name (usually 'cmsis-nn') -add_dependencies(cortex_m_kernels cmsis-nn) - # cortex_m_ops_lib: Register Cortex-M ops kernels into Executorch runtime gen_operators_lib( LIB_NAME "cortex_m_ops_lib" KERNEL_LIBS cortex_m_kernels DEPS executorch ) install( - TARGETS cortex_m_kernels cortex_m_ops_lib + TARGETS cortex_m_kernels cortex_m_ops_lib cmsis-nn EXPORT ExecuTorchTargets - DESTINATION lib - PUBLIC_HEADER DESTINATION include/executorch/backends/cortex_m/ops/ + DESTINATION ${CMAKE_INSTALL_LIBDIR} + PUBLIC_HEADER + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/executorch/backends/cortex_m/ops/ ) diff --git a/backends/cortex_m/ops/cmsis_scratch_buffer_context.h b/backends/cortex_m/ops/cmsis_scratch_buffer_context.h new file mode 100644 index 00000000000..4b9fdaebdf7 --- /dev/null +++ b/backends/cortex_m/ops/cmsis_scratch_buffer_context.h @@ -0,0 +1,187 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#pragma once + +#include "cortex_m_ops_common.h" +extern "C" { +#include "arm_nnfunctions.h" +} + +namespace cortex_m { +namespace native { + +// During AOT phase, quantized_linear_fusion_pass allocates total buffer +// and passes in as 'Tensor'. (Total buffer = 8-byte header + x bytes) +// ┌─────────────────┬─────────────────────────────────────┐ +// │ KernelSum Header│ CMSIS Workspace │ +// │ (8 bytes) │ (x bytes) │ +// └─────────────────┴─────────────────────────────────────┘ +// │ │ +// │ └─> Passed to CMSIS API +// │ +// └─> State for kernel sum + +// C++ Runtime: +// ┌─────────────────┬─────────────────────────────────────┐ +// │ KernelSum Header│ CMSIS Workspace │ +// │ (8 bytes) │ (x bytes) │ +// └─────────────────┴─────────────────────────────────────┘ +// ^ ^ +// │ │ +// scratch_ptr cmsis_workspace_ptr +// │ │ +// ▼ ▼ +// arm_vector_sum_s8() writes kernel sums (with bias if avail): +// [sum₀+bias₀][sum₁+bias₁][sum₂+bias₂]...[sum_{n-1}+bias_{n-1}] +// (n * 4-byte int32_t values = x bytes) +// +// - n = out_features (number of output features) +// - x = n * 4 bytes (total CMSIS buffer size) +// - Total buffer = 8 + x bytes + +class CMSISScratchBufferContext final { + public: + CMSISScratchBufferContext( + Tensor& scratch_buffer, + const Tensor& weights, + const Tensor& weight_zero_point, + const torch::executor::optional& bias) + : scratch_ptr_(scratch_buffer.mutable_data_ptr()), + total_size_(scratch_buffer.size(0)), + base_ptr_(reinterpret_cast(scratch_ptr_)), + in_features_(weights.size(1)), + out_features_(weights.size(0)), + is_per_channel_(weight_zero_point.numel() > 1), + weight_data_offset_(calculate_offset(weights.const_data_ptr())), + weight_zp_data_offset_( + calculate_offset(weight_zero_point.const_data_ptr())), + bias_data_offset_( + bias.has_value() + ? calculate_offset(bias.value().const_data_ptr()) + : 0), + header_(reinterpret_cast(scratch_ptr_)), + cmsis_workspace_ptr_(scratch_ptr_ + KERNEL_SUM_HEADER_SIZE) { + cmsis_nn_dims filter_dims = {in_features_, 1, 1, out_features_}; + validate_size(filter_dims); + } + + cmsis_nn_context get_cmsis_ctx() const { + cmsis_nn_context ctx; + ET_CHECK_MSG( + reinterpret_cast(cmsis_workspace_ptr_) % 4 == 0, + "CMSIS workspace not 4-byte aligned"); + ctx.buf = cmsis_workspace_ptr_; + ctx.size = get_cmsis_workspace_size(); + return ctx; + } + + bool is_kernel_sum_updated() const { + return header_->updated; + } + + void compute_kernel_sums_if_needed() { + if (!header_->updated) { + arm_vector_sum_s8( + reinterpret_cast(cmsis_workspace_ptr_), + in_features_, + out_features_, + get_weight_data(), + get_weight_zp_data()[0], + 0, + get_bias_data()); + header_->updated = true; + ET_LOG( + Info, + "Computed kernel sums. [required_bytes : %d]", + header_->required_size); + } + } + + const int8_t* get_weight_data() const { + return reinterpret_cast(base_ptr_ + weight_data_offset_); + } + + const int32_t* get_weight_zp_data() const { + return reinterpret_cast(base_ptr_ + weight_zp_data_offset_); + } + + const int32_t* get_bias_data() const { + return bias_data_offset_ == 0 + ? 
nullptr + : reinterpret_cast(base_ptr_ + bias_data_offset_); + } + + bool is_per_channel_quant() const { + return is_per_channel_; + } + int32_t get_in_features() const { + return in_features_; + } + int32_t get_out_features() const { + return out_features_; + } + + private: + static constexpr size_t KERNEL_SUM_HEADER_SIZE = 8; + + // Header for kernel sum computation state only + struct KernelSumHeader { + bool updated = false; + int32_t required_size = 0; + }; + static_assert( + sizeof(KernelSumHeader) == KERNEL_SUM_HEADER_SIZE, + "KernelSumHeader must be exactly 8 bytes"); + + int8_t* scratch_ptr_; + size_t total_size_; + uint8_t* base_ptr_; + + // Context members + const int32_t in_features_; + const int32_t out_features_; + const bool is_per_channel_; + const uint32_t weight_data_offset_; + const uint32_t weight_zp_data_offset_; + const uint32_t bias_data_offset_; + + KernelSumHeader* header_; + int8_t* cmsis_workspace_ptr_; + + uint32_t calculate_offset(const void* ptr) const { + if (ptr == nullptr) + return 0; + + const uint8_t* ptr_bytes = reinterpret_cast(ptr); + ET_CHECK_MSG(ptr_bytes >= base_ptr_, "Pointer is before base address"); + + const std::ptrdiff_t offset = ptr_bytes - base_ptr_; + ET_CHECK_MSG( + offset >= 0 && offset <= UINT32_MAX, "Offset out of valid range"); + return static_cast(offset); + } + + size_t get_cmsis_workspace_size() const { + return total_size_ - KERNEL_SUM_HEADER_SIZE; + } + + void validate_size(const cmsis_nn_dims& filter_dims) const { + header_->required_size = + arm_fully_connected_s8_get_buffer_size(&filter_dims); + + ET_CHECK_MSG( + get_cmsis_workspace_size() >= + static_cast(header_->required_size), + "Scratch buffer size %zu insufficient for required size %d", + get_cmsis_workspace_size(), + header_->required_size); + } +}; + +} // namespace native +} // namespace cortex_m diff --git a/backends/cortex_m/ops/cortex_m_ops_common.h b/backends/cortex_m/ops/cortex_m_ops_common.h index 5ef2d9d4bf9..71cf718c9a8 100644 --- a/backends/cortex_m/ops/cortex_m_ops_common.h +++ b/backends/cortex_m/ops/cortex_m_ops_common.h @@ -1,6 +1,7 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. + * Copyright 2025 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
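+
+// Shared glue between the Cortex-M kernels and CMSIS-NN: common type aliases,
+// tensor/layout validation helpers and requantization range checks used by the
+// quantized operators in this backend.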
@@ -17,10 +18,23 @@ #include #include +#include +#include + +extern "C" { +#include "arm_nn_types.h" +} + using Tensor = torch::executor::Tensor; using ScalarType = executorch::aten::ScalarType; using Scalar = torch::executor::Scalar; using Error = executorch::runtime::Error; +using IntArrayRef = executorch::aten::ArrayRef; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; + +// From arm_nn_math_types.h +#define ARM_NN_Q31_MAX ((int32_t)(0x7FFFFFFFL)) +#define ARM_NN_Q31_MIN ((int32_t)(0x80000000L)) // Basic tensor type / layout validation and dimension order checking inline void validate_cmsis_nn_tensor_requirements( @@ -28,26 +42,34 @@ inline void validate_cmsis_nn_tensor_requirements( const Tensor& input2, Tensor& output, ScalarType expected_dtype = ScalarType::Char, - bool require_channels_last = false) { + bool require_channels_last = false, + bool require_same_sizes = true) { // Basic dtype validation ET_CHECK_MSG( input1.scalar_type() == expected_dtype, - "Input1 dtype must be %hhd", - expected_dtype); + "Input1 dtype must be %hhd, got %hhd", + expected_dtype, + input1.scalar_type()); ET_CHECK_MSG( input2.scalar_type() == expected_dtype, - "Input2 dtype must be %hhd", - expected_dtype); + "Input2 dtype must be %hhd, got %hhd", + expected_dtype, + input2.scalar_type()); ET_CHECK_MSG( output.scalar_type() == expected_dtype, - "Output dtype must be %hhd", - expected_dtype); - - // Dim order consistency - ET_CHECK_MSG( - executorch::runtime::tensors_have_same_dim_order(input1, input2, output), - "Tensors must have same dimension order"); - + "Output dtype must be %hhd, got %hhd", + expected_dtype, + output.scalar_type()); + if (require_same_sizes) { + ET_CHECK_MSG( + input1.sizes() == input2.sizes(), + "Input1 and Input2 must have the same sizes"); + ET_CHECK_MSG( + output.sizes() == input1.sizes(), + "Output must have the same sizes as inputs"); + } + + // TBD (#16032): Validate dim_order // TBD: Validate memory alignment (CMSIS-NN requirement) } @@ -60,13 +82,6 @@ inline void validate_single_quant_params( int64_t mult_val = multiplier.to(); int64_t shift_val = shift.to(); - ET_CHECK_MSG( - zp_val >= std::numeric_limits::min() && - zp_val <= std::numeric_limits::max(), - "%s zero point must be in int8 range [Value: %d]", - param_name, - zp_val); - ET_CHECK_MSG( mult_val >= std::numeric_limits::min() && mult_val <= std::numeric_limits::max(), @@ -114,6 +129,70 @@ inline void validate_quantization_params( "Single quant Output"); } +inline bool is_channels_last_tensor(const Tensor& tensor) { + if (tensor.dim() != 4) { + return false; + } + + // When channels or spatial dims are 1 the layout information is ambiguous. 
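+  // For example, a (N, 1, H, W) or (N, C, 1, 1) tensor occupies the same byte
+  // sequence whether it is interpreted as NCHW or NHWC, so either dim_order
+  // describes the data and it is safe to treat it as channels_last.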
+ if (tensor.size(1) == 1 || (tensor.size(2) == 1 && tensor.size(3) == 1)) { + return true; + } + + constexpr executorch::aten::DimOrderType kChannelsLastDimOrder[] = { + 0, 2, 3, 1}; + executorch::aten::ArrayRef + channels_last_order(kChannelsLastDimOrder, 4); + + return tensor.dim_order() == channels_last_order; +} + +inline bool is_channel_broadcast(const Tensor& tensor1, const Tensor& tensor2) { + if (tensor1.dim() != tensor2.dim()) { + return false; + } + + if (tensor1.dim() != 4) { + return false; + } + + if (tensor1.size(1) != tensor2.size(1)) { + return false; + } + + const bool tensor1_channels_only = tensor1.numel() == tensor1.size(1); + const bool tensor2_channels_only = tensor2.numel() == tensor2.size(1); + + return tensor1_channels_only || tensor2_channels_only; +} + +// Refer to CMSIS-NN 'arm_nn_requantize' implementation for details: +// https://github.com/ARM-software/CMSIS-NN/blob/main/Include/arm_nnsupportfunctions.h#L1625 +// multiplier: Range {ARM_NN_Q31_MIN + 1, Q32_MAX} +// shift : Range {-31, 30} +inline bool validate_per_channel_quant_params( + const int32_t* multipliers, + const int32_t* shifts, + int num_channels) { + for (int i = 0; i < num_channels; ++i) { + // Multiplier: {ARM_NN_Q31_MIN + 1, ARM_NN_Q31_MAX} + if (multipliers[i] <= ARM_NN_Q31_MIN || multipliers[i] > ARM_NN_Q31_MAX) { + ET_LOG( + Error, + "weight_multiplier[%d] out of CMSIS-NN range: %d", + i, + multipliers[i]); + return false; + } + // Shift: {-31, 30} for arm_nn_requantize + if (shifts[i] < -31 || shifts[i] > 30) { + ET_LOG(Error, "weight_shift[%d] out of range: %d", i, shifts[i]); + return false; + } + } + return true; +} + inline Error resize_to_broadcast_target_size( const Tensor& input1, const Tensor& input2, diff --git a/backends/cortex_m/ops/op_maximum.cpp b/backends/cortex_m/ops/op_maximum.cpp new file mode 100644 index 00000000000..71a907f12ea --- /dev/null +++ b/backends/cortex_m/ops/op_maximum.cpp @@ -0,0 +1,102 @@ +/* + * Copyright 2025 Arm Limited and/or its affiliates. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "cortex_m_ops_common.h" + +// Include CMSIS-NN headers with C linkage +extern "C" { +#include "arm_nnfunctions.h" +} + +namespace cortex_m { +namespace native { + +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; + +Tensor& maximum_out( + KernelRuntimeContext& context, + const Tensor& input1, + const Tensor& input2, + Tensor& out) { + validate_cmsis_nn_tensor_requirements( + input1, + input2, + out, + ScalarType::Char, + /*require_channels_last=*/false, + /*require_same_sizes=*/false); + + auto resize_error = resize_to_broadcast_target_size(input1, input2, out); + if (resize_error != Error::Ok) { + ET_LOG(Error, "maximum_out: broadcast shape mismatch between inputs"); + context.fail(resize_error); + return out; + } + + const int8_t* input1_data = input1.const_data_ptr(); + const int8_t* input2_data = input2.const_data_ptr(); + int8_t* output_data = out.mutable_data_ptr(); + + // Create CMSIS-NN dims directly from tensor sizes + const auto input1_rank = input1.dim(); + const auto input1_sizes = input1.sizes(); + const cmsis_nn_dims input1_dims{ + static_cast( + input1_rank >= 4 ? input1_sizes[input1_rank - 4] : 1), + static_cast( + input1_rank >= 3 ? input1_sizes[input1_rank - 3] : 1), + static_cast( + input1_rank >= 2 ? input1_sizes[input1_rank - 2] : 1), + static_cast( + input1_rank >= 1 ? 
input1_sizes[input1_rank - 1] : 1)}; + + const auto input2_rank = input2.dim(); + const auto input2_sizes = input2.sizes(); + const cmsis_nn_dims input2_dims{ + static_cast( + input2_rank >= 4 ? input2_sizes[input2_rank - 4] : 1), + static_cast( + input2_rank >= 3 ? input2_sizes[input2_rank - 3] : 1), + static_cast( + input2_rank >= 2 ? input2_sizes[input2_rank - 2] : 1), + static_cast( + input2_rank >= 1 ? input2_sizes[input2_rank - 1] : 1)}; + + const auto output_rank = out.dim(); + const auto output_sizes = out.sizes(); + const cmsis_nn_dims output_dims{ + static_cast( + output_rank >= 4 ? output_sizes[output_rank - 4] : 1), + static_cast( + output_rank >= 3 ? output_sizes[output_rank - 3] : 1), + static_cast( + output_rank >= 2 ? output_sizes[output_rank - 2] : 1), + static_cast( + output_rank >= 1 ? output_sizes[output_rank - 1] : 1)}; + + const arm_cmsis_nn_status status = arm_maximum_s8( + /* ctx */ nullptr, + input1_data, + &input1_dims, + input2_data, + &input2_dims, + output_data, + &output_dims); + + if (status != ARM_CMSIS_NN_SUCCESS) { + ET_LOG( + Error, + "maximum_out: arm_maximum_s8 failed with status [%d]", + static_cast(status)); + context.fail(Error::Internal); + } + + return out; +} + +} // namespace native +} // namespace cortex_m diff --git a/backends/cortex_m/ops/op_minimum.cpp b/backends/cortex_m/ops/op_minimum.cpp new file mode 100644 index 00000000000..f220aa2664b --- /dev/null +++ b/backends/cortex_m/ops/op_minimum.cpp @@ -0,0 +1,104 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * Copyright 2025 Arm Limited and/or its affiliates. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "cortex_m_ops_common.h" + +// Include CMSIS-NN headers with C linkage +extern "C" { +#include "arm_nnfunctions.h" +} + +namespace cortex_m { +namespace native { + +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; + +Tensor& minimum_out( + KernelRuntimeContext& context, + const Tensor& input1, + const Tensor& input2, + Tensor& out) { + validate_cmsis_nn_tensor_requirements( + input1, + input2, + out, + ScalarType::Char, + /*require_channels_last=*/false, + /*require_same_sizes=*/false); + + auto resize_error = resize_to_broadcast_target_size(input1, input2, out); + if (resize_error != Error::Ok) { + ET_LOG(Error, "minimum_out: broadcast shape mismatch between inputs"); + context.fail(resize_error); + return out; + } + + const int8_t* input1_data = input1.const_data_ptr(); + const int8_t* input2_data = input2.const_data_ptr(); + int8_t* output_data = out.mutable_data_ptr(); + + // Create CMSIS-NN dims directly from tensor sizes + const auto input1_rank = input1.dim(); + const auto input1_sizes = input1.sizes(); + const cmsis_nn_dims input1_dims{ + static_cast( + input1_rank >= 4 ? input1_sizes[input1_rank - 4] : 1), + static_cast( + input1_rank >= 3 ? input1_sizes[input1_rank - 3] : 1), + static_cast( + input1_rank >= 2 ? input1_sizes[input1_rank - 2] : 1), + static_cast( + input1_rank >= 1 ? input1_sizes[input1_rank - 1] : 1)}; + + const auto input2_rank = input2.dim(); + const auto input2_sizes = input2.sizes(); + const cmsis_nn_dims input2_dims{ + static_cast( + input2_rank >= 4 ? input2_sizes[input2_rank - 4] : 1), + static_cast( + input2_rank >= 3 ? input2_sizes[input2_rank - 3] : 1), + static_cast( + input2_rank >= 2 ? input2_sizes[input2_rank - 2] : 1), + static_cast( + input2_rank >= 1 ? 
input2_sizes[input2_rank - 1] : 1)}; + + const auto output_rank = out.dim(); + const auto output_sizes = out.sizes(); + const cmsis_nn_dims output_dims{ + static_cast( + output_rank >= 4 ? output_sizes[output_rank - 4] : 1), + static_cast( + output_rank >= 3 ? output_sizes[output_rank - 3] : 1), + static_cast( + output_rank >= 2 ? output_sizes[output_rank - 2] : 1), + static_cast( + output_rank >= 1 ? output_sizes[output_rank - 1] : 1)}; + + const arm_cmsis_nn_status status = arm_minimum_s8( + /* ctx */ nullptr, + input1_data, + &input1_dims, + input2_data, + &input2_dims, + output_data, + &output_dims); + + if (status != ARM_CMSIS_NN_SUCCESS) { + ET_LOG( + Error, + "minimum_out: arm_minimum_s8 failed with status [%d]", + static_cast(status)); + context.fail(Error::Internal); + } + + return out; +} + +} // namespace native +} // namespace cortex_m diff --git a/backends/cortex_m/ops/op_quantized_add.cpp b/backends/cortex_m/ops/op_quantized_add.cpp index 044c2bd92d5..019ab4cfb58 100644 --- a/backends/cortex_m/ops/op_quantized_add.cpp +++ b/backends/cortex_m/ops/op_quantized_add.cpp @@ -1,6 +1,7 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. + * Copyright 2025 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -32,7 +33,14 @@ Tensor& quantized_add_out( const Scalar& output_shift, Tensor& out) { // Validate tensor types and dim order - validate_cmsis_nn_tensor_requirements(input1_int8, input2_int8, out); + bool channel_broadcast = is_channel_broadcast(input1_int8, input2_int8); + validate_cmsis_nn_tensor_requirements( + input1_int8, + input2_int8, + out, + ScalarType::Char, + /*require_channels_last=*/channel_broadcast, + /*require_same_sizes=*/!channel_broadcast); // Validate quantization parameters validate_quantization_params( @@ -47,13 +55,6 @@ Tensor& quantized_add_out( output_shift, out); - // Broadcast if needed - auto result = resize_to_broadcast_target_size(input1_int8, input2_int8, out); - ET_CHECK_MSG( - (result == Error::Ok), - "Failed to resize output tensor. 
Status: [%d]", - result); - ET_LOG( Info, "quantized_add_out: input1_int8.sizes() = %zu", @@ -68,8 +69,10 @@ Tensor& quantized_add_out( int32_t out_zp = extractScalarToInt32(output_zero_point); int32_t output_mult = extractScalarToInt32(output_multiplier); int output_shift_val = extractScalarToInt(output_shift); + int8_t* input1_ptr = input1_int8.data_ptr(); + int8_t* input2_ptr = input2_int8.data_ptr(); - // Left shift to maximize precision (tune as needed) + // Left shift to maximize precision const int32_t left_shift = 20; const int32_t activation_min = std::numeric_limits::min(); const int32_t activation_max = std::numeric_limits::max(); @@ -84,33 +87,58 @@ Tensor& quantized_add_out( output_mult, output_shift_val); - // Call CMSIS-NN kernel with precomputed parameters - arm_cmsis_nn_status status = arm_elementwise_add_s8( - input1_int8.const_data_ptr(), - input2_int8.const_data_ptr(), - static_cast(zp1), - input1_mult, - input1_shift_val, - static_cast(zp2), - input2_mult, - input2_shift_val, - left_shift, - out.mutable_data_ptr(), - static_cast(out_zp), - output_mult, - output_shift_val, - static_cast(out.numel()), - activation_min, - activation_max); - - if (status != ARM_CMSIS_NN_SUCCESS) { - ET_LOG( - Error, - "quantized_add_out: arm_elementwise_add_s8 failed with status [%d]", - status); - - context.fail(Error::Internal); // Fail the execution context - return out; + // Note 1: The CMSIS-NN kernel implementation uses offsets which are always + // added to the data, whereas zero_points are subtracted when dequantizing + // (for the inputs) and added when quantizing (for the output). Hence the + // negative signs required here. + + // Note 2: It is not possible to perform the same rewrite as for mul for + // addition. To preserve precision when rescaling the inputs, they are first + // upscaled as much as possible, Hence the left_shift parameter required here. + + int32_t adds_per_loop = 0; + if (channel_broadcast) { + if (input1_int8.numel() < input2_int8.numel()) { + std::swap(zp1, zp2); + std::swap(input1_mult, input2_mult); + std::swap(input1_shift_val, input2_shift_val); + std::swap(input1_ptr, input2_ptr); + } + adds_per_loop = input1_int8.size(1); + } else { + adds_per_loop = out.numel(); + } + + for (int32_t broadcast_offset = 0; broadcast_offset < out.numel(); + broadcast_offset += adds_per_loop) { + // Call CMSIS-NN kernel with precomputed parameters + arm_cmsis_nn_status status = arm_elementwise_add_s8( + input1_ptr + broadcast_offset, + input2_ptr, + -static_cast(zp1), + input1_mult, + input1_shift_val, + -static_cast(zp2), + input2_mult, + input2_shift_val, + left_shift, + out.mutable_data_ptr() + broadcast_offset, + static_cast(out_zp), + output_mult, + output_shift_val, + activation_min, + activation_max, + adds_per_loop); + + if (status != ARM_CMSIS_NN_SUCCESS) { + ET_LOG( + Error, + "quantized_add_out: arm_elementwise_add_s8 failed with status [%d]", + status); + + context.fail(Error::Internal); // Fail the execution context + return out; + } } ET_LOG( Info, @@ -119,32 +147,5 @@ Tensor& quantized_add_out( return out; } -// Stub Implementation: Non-out variant for compatibility (functional variant) -// EXIR/ExecuTorch runs an out-variant pass that converts -// .default operations to .out variants before memory planning. -// In the pass we are calling quantized_add's default variant -// but ExecuTorch's kernel dispatch mechanism will end up calling the out -// variant. This stub is to make sure that compiler doesn't complain. 
-Tensor quantized_add( - KernelRuntimeContext& context, - const Tensor& input1_int8, - const Scalar& input1_zero_point, - const Scalar& input1_multiplier, - const Scalar& input1_shift, - const Tensor& input2_int8, - const Scalar& input2_zero_point, - const Scalar& input2_multiplier, - const Scalar& input2_shift, - const Scalar& output_zero_point, - const Scalar& output_multiplier, - const Scalar& output_shift) { - ET_LOG(Info, "quantized_add: input1_int8.sizes() = %zu", input1_int8.sizes()); - - // Crash on Debug builds if invoked - assert(False); - // This is to make sure compiler doesn't complain. - return const_cast(input1_int8); -} - } // namespace native } // namespace cortex_m diff --git a/backends/cortex_m/ops/op_quantized_conv2d.cpp b/backends/cortex_m/ops/op_quantized_conv2d.cpp new file mode 100644 index 00000000000..ad14af98865 --- /dev/null +++ b/backends/cortex_m/ops/op_quantized_conv2d.cpp @@ -0,0 +1,236 @@ +/* + * Copyright 2025 Arm Limited and/or its affiliates. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "cortex_m_ops_common.h" + +extern "C" { +#include "arm_nnfunctions.h" +} + +namespace cortex_m { +namespace native { + +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; + +namespace { +constexpr int64_t kConvDim = 4; + +bool validate_conv2d_arguments( + KernelRuntimeContext& context, + const Tensor& input, + const Tensor& weight, + const torch::executor::optional& bias, + const Tensor& output, + const IntArrayRef& stride, + const IntArrayRef& padding, + const IntArrayRef& dilation, + const Tensor& requantize_multipliers, + const Tensor& requantize_shifts) { + if (input.dim() != kConvDim || weight.dim() != kConvDim || + output.dim() != kConvDim) { + ET_LOG(Error, "quantized_conv2d_out: tensors must be 4-D"); + context.fail(Error::InvalidArgument); + return false; + } + + // Check for channels_last dim_order (NHWC: 0, 2, 3, 1) + // Skip check if channels == 1, as dim_order is ambiguous in that case + constexpr executorch::aten::DimOrderType kChannelsLastDimOrder[] = { + 0, 2, 3, 1}; + executorch::aten::ArrayRef + channels_last_order(kChannelsLastDimOrder, 4); + + if (input.size(1) > 1 && input.dim_order() != channels_last_order) { + ET_LOG( + Error, + "quantized_conv2d_out: input must have channels_last dim_order (NHWC)"); + context.fail(Error::InvalidArgument); + return false; + } + + if (output.size(1) > 1 && output.dim_order() != channels_last_order) { + ET_LOG( + Error, + "quantized_conv2d_out: output must have channels_last dim_order (NHWC)"); + context.fail(Error::InvalidArgument); + return false; + } + + if (input.scalar_type() != ScalarType::Char || + output.scalar_type() != ScalarType::Char) { + ET_LOG(Error, "quantized_conv2d_out: input and output must be int8"); + context.fail(Error::InvalidArgument); + return false; + } + + if (weight.scalar_type() != ScalarType::Char) { + ET_LOG(Error, "quantized_conv2d_out: weight must be int8"); + context.fail(Error::InvalidArgument); + return false; + } + + if (bias.has_value() && bias.value().scalar_type() != ScalarType::Int) { + ET_LOG(Error, "quantized_conv2d_out: bias must be int32 if provided"); + context.fail(Error::InvalidArgument); + return false; + } + + if (stride.size() != 2 || padding.size() != 2 || dilation.size() != 2) { + ET_LOG( + Error, + "quantized_conv2d_out: stride, padding, and dilation must have length 2"); + context.fail(Error::InvalidArgument); + return false; + } + + const 
int64_t out_channels = output.size(1); + if (requantize_multipliers.size(0) != out_channels || + requantize_shifts.size(0) != out_channels) { + ET_LOG( + Error, + "quantized_conv2d_out: per-channel params must match output channels (%zd)", + out_channels); + context.fail(Error::InvalidArgument); + return false; + } + + return true; +} +} // namespace + +Tensor& quantized_conv2d_out( + KernelRuntimeContext& context, + const Tensor& input, + const Tensor& weight, + const torch::executor::optional& bias, + const IntArrayRef stride, + const IntArrayRef padding, + const IntArrayRef dilation, + const int64_t input_offset, + const int64_t output_offset, + const Tensor& requantize_multipliers, + const Tensor& requantize_shifts, + const int64_t activation_min, + const int64_t activation_max, + Tensor& out) { + if (!validate_conv2d_arguments( + context, + input, + weight, + bias, + out, + stride, + padding, + dilation, + requantize_multipliers, + requantize_shifts)) { + return out; + } + + const int32_t batch = static_cast(input.size(0)); + const int32_t input_channels = static_cast(input.size(1)); + const int32_t input_height = static_cast(input.size(2)); + const int32_t input_width = static_cast(input.size(3)); + + const int32_t kernel_output_channels = static_cast(weight.size(0)); + const int32_t kernel_height = static_cast(weight.size(1)); + const int32_t kernel_width = static_cast(weight.size(2)); + const int32_t kernel_input_channels = static_cast(weight.size(3)); + + const int32_t output_channels = static_cast(out.size(1)); + const int32_t output_height = static_cast(out.size(2)); + const int32_t output_width = static_cast(out.size(3)); + + const int32_t input_offset_val = static_cast(input_offset); + const int32_t output_offset_val = static_cast(output_offset); + const int32_t activation_min_val = static_cast(activation_min); + const int32_t activation_max_val = static_cast(activation_max); + + const cmsis_nn_dims input_dims{ + batch, input_height, input_width, input_channels}; + const cmsis_nn_dims filter_dims{ + kernel_output_channels, + kernel_height, + kernel_width, + kernel_input_channels}; + const cmsis_nn_dims output_dims{ + batch, output_height, output_width, output_channels}; + const cmsis_nn_dims bias_dims{1, 1, 1, output_channels}; + const cmsis_nn_dims upscale_dims{1, 1, 1, 1}; + + cmsis_nn_conv_params conv_params; + conv_params.input_offset = input_offset_val; + conv_params.output_offset = output_offset_val; + conv_params.stride.h = static_cast(stride[0]); + conv_params.stride.w = static_cast(stride[1]); + conv_params.padding.h = static_cast(padding[0]); + conv_params.padding.w = static_cast(padding[1]); + conv_params.dilation.h = static_cast(dilation[0]); + conv_params.dilation.w = static_cast(dilation[1]); + conv_params.activation.min = activation_min_val; + conv_params.activation.max = activation_max_val; + + cmsis_nn_per_channel_quant_params quant_params; + quant_params.multiplier = requantize_multipliers.data_ptr(); + quant_params.shift = requantize_shifts.data_ptr(); + + const int8_t* input_data = input.const_data_ptr(); + const int8_t* weight_data = weight.const_data_ptr(); + int8_t* output_data = out.mutable_data_ptr(); + const int32_t* bias_data = + bias.has_value() ? 
bias.value().const_data_ptr() : nullptr; + + cmsis_nn_context cmsis_context; + cmsis_context.buf = nullptr; + cmsis_context.size = 0; + + const size_t buffer_bytes = static_cast( + arm_convolve_s8_get_buffer_size(&input_dims, &filter_dims)); + if (buffer_bytes > 0) { + auto buffer_or_error = + context.allocate_temp(buffer_bytes, alignof(int16_t)); + if (!buffer_or_error.ok()) { + if (buffer_or_error.error() != Error::NotFound) { + ET_LOG( + Error, + "quantized_conv2d_out: failed to allocate scratch buffer (%d)", + static_cast(buffer_or_error.error())); + context.fail(buffer_or_error.error()); + return out; + } + } else { + cmsis_context.buf = buffer_or_error.get(); + cmsis_context.size = buffer_bytes; + } + } + + const arm_cmsis_nn_status status = arm_convolve_wrapper_s8( + &cmsis_context, + &conv_params, + &quant_params, + &input_dims, + input_data, + &filter_dims, + weight_data, + &bias_dims, + bias_data, + &output_dims, + output_data); + + if (status != ARM_CMSIS_NN_SUCCESS) { + ET_LOG( + Error, + "quantized_conv2d_out: arm_convolve_s8 failed with status %d", + status); + context.fail(Error::Internal); + } + + return out; +} + +} // namespace native +} // namespace cortex_m diff --git a/backends/cortex_m/ops/op_quantized_linear.cpp b/backends/cortex_m/ops/op_quantized_linear.cpp new file mode 100644 index 00000000000..015fa805134 --- /dev/null +++ b/backends/cortex_m/ops/op_quantized_linear.cpp @@ -0,0 +1,110 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * Copyright 2025 Arm Limited and/or its affiliates. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "cortex_m_ops_common.h" + +extern "C" { +#include "arm_nnfunctions.h" +} + +namespace cortex_m { +namespace native { +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; + +Tensor& quantized_linear_out( + KernelRuntimeContext& context, + const Tensor& input, + const Tensor& weights, + const torch::executor::optional& bias, + const torch::executor::optional& kernel_sum, + const Scalar& input_offset, + const Scalar& filter_offset, + const Scalar& output_offset, + const IntArrayRef requantize_multipliers, + const IntArrayRef requantize_shifts, + const Scalar& activation_max, + const Scalar& activation_min, + Tensor& out) { + ET_LOG(Info, "quantized_linear_out: called"); + + const int8_t* input_data = input.const_data_ptr(); + const int8_t* weight_data = weights.const_data_ptr(); + const int32_t* bias_data = + bias.has_value() ? bias.value().const_data_ptr() : nullptr; + int32_t* kernel_sum_data = + kernel_sum.has_value() ? 
kernel_sum.value().data_ptr() : nullptr; + int8_t* output_data = out.mutable_data_ptr(); + + cmsis_nn_context ctx; + ctx.size = 0; // Not used in CMSIS-NN + ctx.buf = kernel_sum_data; + + // Setup CMSIS-NN parameters + cmsis_nn_fc_params fc_params; + fc_params.input_offset = static_cast(input_offset.to()); + fc_params.filter_offset = static_cast(filter_offset.to()); + fc_params.output_offset = static_cast(output_offset.to()); + fc_params.activation.min = static_cast(activation_min.to()); + fc_params.activation.max = static_cast(activation_max.to()); + + cmsis_nn_per_tensor_quant_params per_tensor_quant_params; + per_tensor_quant_params.multiplier = + static_cast(requantize_multipliers.at(0)); + per_tensor_quant_params.shift = static_cast(requantize_shifts.at(0)); + + auto in_feat = input.size(input.dim() - 1); + auto out_feat = out.size(out.dim() - 1); + auto batches = 1; + for (size_t i = 0; i < input.dim() - 1; i++) { + batches *= input.size(i); + } + ET_LOG( + Info, + "in features: %d, out_features: %d, batches: %d, kernel_sum_size: %d", + in_feat, + out_feat, + batches, + kernel_sum.has_value() ? kernel_sum.value().numel() : 0); + ET_LOG( + Info, + "kernel_sum[0]: %d, kernel_sum[1]: %d", + kernel_sum_data != nullptr ? kernel_sum_data[0] : -1, + kernel_sum_data != nullptr ? kernel_sum_data[1] : -1); + cmsis_nn_dims input_dims = {batches, 1, 1, in_feat}; + cmsis_nn_dims filter_dims = {in_feat, 1, 1, out_feat}; + cmsis_nn_dims bias_dims = {1, 1, 1, out_feat}; + cmsis_nn_dims output_dims = {batches, 1, 1, out_feat}; + + arm_cmsis_nn_status status = arm_fully_connected_s8( + &ctx, + &fc_params, + &per_tensor_quant_params, + &input_dims, + input_data, + &filter_dims, + weight_data, + &bias_dims, + bias_data, + &output_dims, + output_data); + + if (status != ARM_CMSIS_NN_SUCCESS) { + ET_LOG( + Error, + "quantized_linear_out: CMSIS-NN failed with status [%d]", + status); + context.fail(Error::Internal); + return out; + } + + return out; +} + +} // namespace native +} // namespace cortex_m diff --git a/backends/cortex_m/ops/op_quantized_mul.cpp b/backends/cortex_m/ops/op_quantized_mul.cpp new file mode 100644 index 00000000000..3d2d7657e36 --- /dev/null +++ b/backends/cortex_m/ops/op_quantized_mul.cpp @@ -0,0 +1,126 @@ +/* + * Copyright 2025 Arm Limited and/or its affiliates. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "cortex_m_ops_common.h" + +// Include CMSIS-NN headers with C linkage +extern "C" { +#include "arm_nnfunctions.h" +} + +namespace cortex_m { +namespace native { +namespace { + +constexpr int32_t kInt8ActivationMin = std::numeric_limits::min(); +constexpr int32_t kInt8ActivationMax = std::numeric_limits::max(); + +} // namespace + +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; + +Tensor& quantized_mul_out( + KernelRuntimeContext& context, + const Tensor& input1_int8, + const Scalar& input1_zero_point, + const Tensor& input2_int8, + const Scalar& input2_zero_point, + const Scalar& output_zero_point, + const Scalar& output_multiplier, + const Scalar& output_shift, + Tensor& out) { + // Validate tensor types and quantization parameters + + bool channel_broadcast = is_channel_broadcast(input1_int8, input2_int8); + validate_cmsis_nn_tensor_requirements( + input1_int8, + input2_int8, + out, + ScalarType::Char, + /*require_channels_last=*/channel_broadcast, + /*require_same_sizes=*/!channel_broadcast); + + const Scalar kIdentityMultiplier(/*value=*/1); + const Scalar kZeroShift(/*value=*/0); + validate_quantization_params( + input1_zero_point, + kIdentityMultiplier, + kZeroShift, + input2_zero_point, + kIdentityMultiplier, + kZeroShift, + output_zero_point, + output_multiplier, + output_shift, + out); + + // Extract quantization parameters + int8_t* input1_ptr = input1_int8.data_ptr(); + int8_t* input2_ptr = input2_int8.data_ptr(); + int32_t zp1 = extractScalarToInt32(input1_zero_point); + int32_t zp2 = extractScalarToInt32(input2_zero_point); + const int32_t out_zp = extractScalarToInt32(output_zero_point); + const int32_t output_mult = extractScalarToInt32(output_multiplier); + const int32_t output_shift_val = extractScalarToInt32(output_shift); + + int32_t muls_per_loop = 0; + + if (channel_broadcast) { + if (input1_int8.numel() < input2_int8.numel()) { + std::swap(zp1, zp2); + std::swap(input1_ptr, input2_ptr); + } + + muls_per_loop = input1_int8.size(1); + } else { + muls_per_loop = out.numel(); + } + // Note 1: The CMSIS-NN kernel implementation uses offsets which are always + // added to the data, whereas zero_points are subtracted when dequantizing + // (for the inputs) and added when quantizing (for the output). Hence the + // negative signs required here. + + // Note 2: The following rewrite is used + // yq = y / scale_out + zp_out + // y = x_1*x_2 + // x_i = scale_in_i * (xq_i - xq_i), i = 1, 2 + // ==> + // yq = (xq_1 - zp_in1) * (xq_2 - zp_in_2) * effective_scale + zp_out + // where + // effective_scale = (scale_in1 * scale_in2 / scale_out) + // Hence no input quantization params required here. 
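+  // Illustrative example (hypothetical values; the multiplier/shift pair is
+  // computed ahead of time, e.g. via the AOT helper quantize_multiplier_aot):
+  // scale_in1 = scale_in2 = 0.02 and scale_out = 0.05 give
+  // effective_scale = 0.02 * 0.02 / 0.05 = 0.008, which can be encoded as
+  // multiplier ~= round(0.512 * 2^31) = 1099511628 with shift = -6, since
+  // 0.008 = 0.512 * 2^-6 and arm_nn_requantize effectively computes
+  // round(val * multiplier * 2^shift / 2^31).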
+ + for (int32_t broadcast_offset = 0; broadcast_offset < out.numel(); + broadcast_offset += muls_per_loop) { + // Call CMSIS-NN elementwise multiply kernel + arm_cmsis_nn_status status = arm_elementwise_mul_s8( + input1_ptr + broadcast_offset, + input2_ptr, + -static_cast(zp1), + -static_cast(zp2), + out.mutable_data_ptr() + broadcast_offset, + static_cast(out_zp), + output_mult, + output_shift_val, + kInt8ActivationMin, + kInt8ActivationMax, + muls_per_loop); + + if (status != ARM_CMSIS_NN_SUCCESS) { + ET_LOG( + Error, + "quantized_mul_out: arm_elementwise_mul_s8 failed with status [%d]", + status); + context.fail(Error::Internal); + return out; + } + } + return out; +} + +} // namespace native +} // namespace cortex_m diff --git a/backends/cortex_m/ops/op_transpose.cpp b/backends/cortex_m/ops/op_transpose.cpp new file mode 100644 index 00000000000..7befafc3791 --- /dev/null +++ b/backends/cortex_m/ops/op_transpose.cpp @@ -0,0 +1,124 @@ +/* + * Copyright 2025 Arm Limited and/or its affiliates. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "cortex_m_ops_common.h" + +#include +#include +#include + +// Include CMSIS-NN headers with C linkage +extern "C" { +#include "arm_nnfunctions.h" +} + +namespace cortex_m { +namespace native { + +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; + +namespace { + +constexpr size_t kMaxSupportedDims = 4; + +} // namespace + +Tensor& transpose_out( + KernelRuntimeContext& context, + const Tensor& input, + const IntArrayRef perm, + Tensor& out) { + if (input.scalar_type() != ScalarType::Char || + out.scalar_type() != ScalarType::Char) { + ET_LOG( + Error, + "transpose_out: only int8 tensors are supported (input=%d, out=%d)", + static_cast(input.scalar_type()), + static_cast(out.scalar_type())); + context.fail(Error::InvalidArgument); + return out; + } + + const size_t rank = input.dim(); + if (rank == 0 || rank > kMaxSupportedDims) { + ET_LOG( + Error, + "transpose_out: expected tensor rank in [1, %zu], got %zu", + kMaxSupportedDims, + rank); + context.fail(Error::InvalidArgument); + return out; + } + + if (perm.size() != static_cast(rank)) { + ET_LOG( + Error, + "transpose_out: permutation length %zd does not match tensor rank %zu", + perm.size(), + rank); + context.fail(Error::InvalidArgument); + return out; + } + + std::array input_dims_arr{1, 1, 1, 1}; + std::array output_dims_arr{1, 1, 1, 1}; + for (size_t i = 0; i < rank; ++i) { + const auto in_size = input.size(i); + const auto out_size = out.size(i); + if (in_size > std::numeric_limits::max() || + out_size > std::numeric_limits::max()) { + ET_LOG( + Error, + "transpose_out: dimension size exceeds int32_t range (input=%lld, output=%lld)", + static_cast(in_size), + static_cast(out_size)); + context.fail(Error::InvalidArgument); + return out; + } + input_dims_arr[i] = static_cast(in_size); + output_dims_arr[i] = static_cast(out_size); + } + + cmsis_nn_dims input_dims = { + input_dims_arr[0], + input_dims_arr[1], + input_dims_arr[2], + input_dims_arr[3]}; + cmsis_nn_dims output_dims = { + output_dims_arr[0], + output_dims_arr[1], + output_dims_arr[2], + output_dims_arr[3]}; + + std::array perm_buffer{0, 1, 2, 3}; + for (size_t i = 0; i < rank; ++i) { + perm_buffer[i] = static_cast(perm[i]); + } + + const cmsis_nn_transpose_params transpose_params{ + static_cast(rank), perm_buffer.data()}; + + const int8_t* input_data = input.const_data_ptr(); + int8_t* output_data = 
out.mutable_data_ptr(); + + const arm_cmsis_nn_status status = arm_transpose_s8( + input_data, output_data, &input_dims, &output_dims, &transpose_params); + + if (status != ARM_CMSIS_NN_SUCCESS) { + ET_LOG( + Error, + "transpose_out: arm_transpose_s8 failed with status [%d]", + static_cast(status)); + context.fail(Error::Internal); + return out; + } + + return out; +} + +} // namespace native +} // namespace cortex_m diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py index 926dcd85e4b..291615f613a 100644 --- a/backends/cortex_m/ops/operators.py +++ b/backends/cortex_m/ops/operators.py @@ -1,13 +1,19 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from math import prod +from typing import Sequence + import torch +import torch.nn.functional as F from executorch.backends.cortex_m.passes.passes_utils import ( - dequantize_per_tensor_cmsis, - quantize_per_tensor_cmsis, + is_channel_broadcast, + requantize_cmsis, + SHIFT_INT8, ) from executorch.exir.dialects._ops import ops as exir_ops @@ -111,6 +117,15 @@ def dequantize_per_tensor_impl( "Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift) -> Tensor" ) +# Define the operator schema with multipliers and shifts (11 args + out tensor) +lib.define( + "quantized_add.out(" + "Tensor self, Scalar self_zero_point, Scalar self_multiplier, Scalar self_shift, " + "Tensor other, Scalar other_zero_point, Scalar other_multiplier, Scalar other_shift, " + "Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, " + "*, Tensor(a!) out) -> Tensor(a!)" +) + @register_fake("cortex_m::quantized_add") def quantized_add_meta( @@ -126,8 +141,15 @@ def quantized_add_meta( output_multiplier: int, output_shift: int, ) -> torch.Tensor: - broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape) - return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device) + assert self.shape == other.shape or is_channel_broadcast(self, other), ( + "Cortex-M quantized_add: broadcasting is not yet supported except for channel dim — " + f"got self.shape={self.shape}, other.shape={other.shape}" + ) + if self.numel() > other.numel(): + output_tensor = self + else: + output_tensor = other + return torch.empty_like(output_tensor) @impl(lib, "quantized_add", "CompositeExplicitAutograd") @@ -144,82 +166,414 @@ def quantized_add_impl( output_multiplier: int, output_shift: int, ) -> torch.Tensor: - self_fp = dequantize_per_tensor_cmsis( - self, self_zero_point, self_multiplier, self_shift - ) - other_fp = dequantize_per_tensor_cmsis( - other, other_zero_point, other_multiplier, other_shift + assert self.shape == other.shape or is_channel_broadcast(self, other), ( + "Cortex-M quantized_add: broadcasting is not yet supported except for channel dim — " + f"got self.shape={self.shape}, other.shape={other.shape}" ) + self_shifted = (self.to(torch.int32) - self_zero_point) << SHIFT_INT8 + self_fp = requantize_cmsis(self_shifted, self_multiplier, self_shift) + + other_shifted = (other.to(torch.int32) - other_zero_point) << SHIFT_INT8 + other_fp = requantize_cmsis(other_shifted, other_multiplier, other_shift) + result_fp = self_fp + other_fp - result_quantized = quantize_per_tensor_cmsis( - result_fp, output_zero_point, output_multiplier, output_shift - ) - return result_quantized + result_quantized 
= requantize_cmsis(result_fp, output_multiplier, output_shift) + result = torch.clamp(result_quantized + output_zero_point, -128, 127).to(torch.int8) + return result -# Define the operator schema with multipliers and shifts (11 args + out tensor) +# =================================================================== +# QUANTIZED MUL OPERATION DEFINITION +# =================================================================== lib.define( - "quantized_add.out(" - "Tensor self, Scalar self_zero_point, Scalar self_multiplier, Scalar self_shift, " - "Tensor other, Scalar other_zero_point, Scalar other_multiplier, Scalar other_shift, " + "quantized_mul(" + "Tensor self, Scalar self_zero_point, " + "Tensor other, Scalar other_zero_point, " + "Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift) -> Tensor" +) +lib.define( + "quantized_mul.out(" + "Tensor self, Scalar self_zero_point, " + "Tensor other, Scalar other_zero_point, " "Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, " "*, Tensor(a!) out) -> Tensor(a!)" ) -# Fake meta function for shape and dtype inference during compilation -@register_fake("cortex_m::quantized_add.out") -def quantized_add_out_meta( +@register_fake("cortex_m::quantized_mul") +def quantized_mul_meta( self: torch.Tensor, self_zero_point: int, - self_multiplier: int, - self_shift: int, other: torch.Tensor, other_zero_point: int, - other_multiplier: int, - other_shift: int, output_zero_point: int, output_multiplier: int, output_shift: int, - out: torch.Tensor, ) -> torch.Tensor: - # Validate against correct broadcasted shape - expected_shape = torch.broadcast_shapes(self.shape, other.shape) - assert ( - out.shape == expected_shape - ), f"Output shape {out.shape} must match broadcasted shape {expected_shape}" - return out + # Broadcast to output shape + assert self.shape == other.shape or is_channel_broadcast(self, other), ( + "Cortex-M quantized_mul: broadcasting is not yet supported except for channel dim — " + f"got self.shape={self.shape}, other.shape={other.shape}" + ) + if self.numel() > other.numel(): + output_tensor = self + else: + output_tensor = other + return torch.empty_like(output_tensor) -# Actual implementation delegating to backend or custom kernel -@impl(lib, "quantized_add.out", "CompositeExplicitAutograd") -def quantized_add_out_impl( +@impl(lib, "quantized_mul", "CompositeExplicitAutograd") +def quantized_mul_impl( self: torch.Tensor, self_zero_point: int, - self_multiplier: int, - self_shift: int, other: torch.Tensor, other_zero_point: int, - other_multiplier: int, - other_shift: int, output_zero_point: int, output_multiplier: int, output_shift: int, - *, - out: torch.Tensor, ) -> torch.Tensor: - self_fp = dequantize_per_tensor_cmsis( - self, self_zero_point, self_multiplier, self_shift + # CMSIS-NN kernel multiplies raw int8 tensors (after zero-point offset) and + # only uses the output multiplier/shift for rescaling. Mirror that here to + # keep the composite implementation numerically aligned with the backend. 
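+    # Illustrative sketch (hypothetical values): with self_zero_point = -1,
+    # other_zero_point = 2 and one int8 element pair (q1, q2) = (5, 7), the
+    # accumulator term is (5 - (-1)) * (7 - 2) = 30; requantize_cmsis rescales
+    # it with the output multiplier/shift, the output zero point is added back
+    # and the result is clamped to the int8 range [-128, 127].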
+ assert self.shape == other.shape or is_channel_broadcast(self, other), ( + "Cortex-M quantized_mul: broadcasting is not yet supported except for channel dim — " + f"got self.shape={self.shape}, other.shape={other.shape}" ) - other_fp = dequantize_per_tensor_cmsis( - other, other_zero_point, other_multiplier, other_shift + self_int = self.to(torch.int32) - self_zero_point + other_int = other.to(torch.int32) - other_zero_point + result_fp = self_int * other_int + result_quantized = requantize_cmsis(result_fp, output_multiplier, output_shift) + result = torch.clamp(result_quantized + output_zero_point, -128, 127).to(torch.int8) + return result + + +# =================================================================== +# MINIMUM/MAXIMUM OPERATION DEFINITIONS +# =================================================================== +lib.define("minimum(Tensor self, Tensor other) -> Tensor") +lib.define("minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)") + + +@register_fake("cortex_m::minimum") +def minimum_meta(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor: + assert self.dtype == other.dtype, ( + "Cortex-M minimum: dtype mismatch — " + f"got self.dtype={self.dtype}, other.dtype={other.dtype}" ) - result_fp = self_fp + other_fp - result_quantized = quantize_per_tensor_cmsis( - result_fp, output_zero_point, output_multiplier, output_shift + broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape) + return torch.empty(broadcasted_shape, dtype=self.dtype, device=self.device) + + +@impl(lib, "minimum", "CompositeExplicitAutograd") +def minimum_impl(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor: + return torch.minimum(self, other) + + +lib.define("maximum(Tensor self, Tensor other) -> Tensor") +lib.define("maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)") + + +@register_fake("cortex_m::maximum") +def maximum_meta(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor: + assert self.dtype == other.dtype, ( + "Cortex-M maximum: dtype mismatch — " + f"got self.dtype={self.dtype}, other.dtype={other.dtype}" ) + broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape) + return torch.empty(broadcasted_shape, dtype=self.dtype, device=self.device) + + +@impl(lib, "maximum", "CompositeExplicitAutograd") +def maximum_impl(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor: + return torch.maximum(self, other) + + +# =================================================================== +# QUANTIZED LINEAR OPERATION DEFINITION +# =================================================================== + +lib.define( + "quantized_linear.out(" + "Tensor input, " + "Tensor weights, " + "Tensor? bias, " + "Tensor? kernel_sum, " + "Scalar input_offset, " + "Scalar filter_offset, " + "Scalar output_offset, " + "int[] requantize_multipliers, " + "int[] requantize_shifts, " + "Scalar activation_max, " + "Scalar activation_min, " + "*, Tensor(a!) out" + ") -> Tensor(a!)" +) + +# Define functional variant (non-out version) +lib.define( + "quantized_linear(" + "Tensor input, " + "Tensor weights, " + "Tensor? bias, " + "Tensor? 
kernel_sum, " + "Scalar input_offset, " + "Scalar filter_offset, " + "Scalar output_offset, " + "int[] requantize_multipliers, " + "int[] requantize_shifts, " + "Scalar activation_max, " + "Scalar activation_min" + ") -> Tensor" +) + + +# Fake meta function for shape inference (functional variant) +@register_fake("cortex_m::quantized_linear") +def quantized_linear_meta( + input, + weights, + bias, + kernel_sum, + input_offset, + filter_offset, + output_offset, + requantize_multipliers, + requantize_shifts, + activation_max, + activation_min, +) -> torch.Tensor: + + shape = (*input.shape[:-1], weights.shape[0]) + return torch.empty(shape, dtype=input.dtype, device=input.device) + + +# Functional variant implementation +@impl(lib, "quantized_linear", "CompositeExplicitAutograd") +def quantized_linear_impl( + input: torch.Tensor, + weights: torch.Tensor, + bias: torch.Tensor, + kernel_sum: torch.Tensor, + input_offset: int, + filter_offset: int, + output_offset: int, + requantize_multipliers: torch.Tensor, + requantize_shifts: torch.Tensor, + activation_max: int, + activation_min: int, +) -> torch.Tensor: + """ + Functional variant - creates output tensor and calls out variant + """ + + # Leaving both implementations for debugging purposes. + compute_using_kernel_sum = True + + if compute_using_kernel_sum: + weights_int32 = weights.to(torch.int32) + + input_int32 = input.to(torch.int32) + new_shape = (prod(input.shape[:-1]), input.shape[-1]) + input_reshaped = input_int32.reshape(new_shape) + + lhs_sum = torch.sum(input_reshaped, dim=-1, keepdim=True) * filter_offset + output = torch.mm(input_reshaped, weights_int32.T) + lhs_sum + kernel_sum + output_shape = (*input.shape[:-1], output.shape[-1]) + output_reshaped = output.reshape(output_shape) + else: + weights_int32 = weights.to(torch.int32) + filter_offset + + input_int32 = input.to(torch.int32) + input_offset + new_shape = (prod(input.shape[:-1]), input.shape[-1]) + input_reshaped = input_int32.reshape(new_shape) + + output = torch.mm(input_reshaped, weights_int32.T) + if bias is not None: + output = output + bias + output_shape = (*input.shape[:-1], output.shape[-1]) + output_reshaped = output.reshape(output_shape) + + output = requantize_cmsis( + output_reshaped, requantize_multipliers[0], requantize_shifts[0] + ) + output += output_offset + output = torch.clamp(output, activation_min, activation_max).to(torch.int8) + return output + + +# =================================================================== +# TRANSPOSE OPERATION DEFINITION +# =================================================================== +lib.define("transpose(Tensor input, int[] perm) -> Tensor") +lib.define("transpose.out(Tensor input, int[] perm, *, Tensor(a!) out) -> Tensor(a!)") + + +@register_fake("cortex_m::transpose") +def transpose_meta(input: torch.Tensor, perm) -> torch.Tensor: + output_shape = [input.shape[idx] for idx in perm] + return torch.empty(output_shape, dtype=input.dtype, device=input.device) + + +@impl(lib, "transpose", "CompositeExplicitAutograd") +def transpose_impl(input: torch.Tensor, perm) -> torch.Tensor: + return input.permute(tuple(perm)).contiguous() + + +# =================================================================== +# QUANTIZED CONV2D OPERATION DEFINITION +# =================================================================== + +lib.define( + "quantized_conv2d(" + "Tensor input, " + "Tensor weight, " + "Tensor? 
bias, " + "int[] stride, " + "int[] padding, " + "int[] dilation, " + "int input_offset, " + "int output_offset, " + "Tensor requantize_multipliers, " + "Tensor requantize_shifts, " + "int activation_min, " + "int activation_max" + ") -> Tensor" +) + + +lib.define( + "quantized_conv2d.out(" + "Tensor input, " + "Tensor weight, " + "Tensor? bias, " + "int[] stride, " + "int[] padding, " + "int[] dilation, " + "int input_offset, " + "int output_offset, " + "Tensor requantize_multipliers, " + "Tensor requantize_shifts, " + "int activation_min, " + "int activation_max, " + "*, Tensor(a!) out" + ") -> Tensor(a!)" +) + + +def _compute_conv2d_output_shape( + input_shape: torch.Size, + weight_shape: torch.Size, + stride: Sequence[int], + padding: Sequence[int], + dilation: Sequence[int], +) -> torch.Size: + batch = input_shape[0] + in_height = input_shape[2] + in_width = input_shape[3] + # We store the weights in OHWI layout (out, kernel_h, kernel_w, in) + kernel_height = weight_shape[1] + kernel_width = weight_shape[2] + + stride_h, stride_w = stride + pad_h, pad_w = padding + dilation_h, dilation_w = dilation + + out_channels = weight_shape[0] + out_height = ( + in_height + 2 * pad_h - dilation_h * (kernel_height - 1) - 1 + ) // stride_h + 1 + out_width = ( + in_width + 2 * pad_w - dilation_w * (kernel_width - 1) - 1 + ) // stride_w + 1 + return torch.Size([batch, out_channels, out_height, out_width]) + + +@register_fake("cortex_m::quantized_conv2d") +def quantized_conv2d_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + stride: Sequence[int], + padding: Sequence[int], + dilation: Sequence[int], + input_offset: int, + output_offset: int, + requantize_multipliers: torch.Tensor, + requantize_shifts: torch.Tensor, + activation_min: int, + activation_max: int, +) -> torch.Tensor: + stride_vals = list(stride) + padding_vals = list(padding) + dilation_vals = list(dilation) + output_shape = _compute_conv2d_output_shape( + input.shape, weight.shape, stride_vals, padding_vals, dilation_vals + ) + return torch.empty( + output_shape, + dtype=torch.int8, + device=input.device, + memory_format=torch.channels_last, + ) + + +@impl(lib, "quantized_conv2d", "CompositeExplicitAutograd") +def quantized_conv2d_impl( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + stride: Sequence[int], + padding: Sequence[int], + dilation: Sequence[int], + input_offset: int, + output_offset: int, + requantize_multipliers: torch.Tensor, + requantize_shifts: torch.Tensor, + activation_min: int, + activation_max: int, +) -> torch.Tensor: + if input.dim() != 4 or weight.dim() != 4: + raise RuntimeError("quantized_conv2d expects 4D input and weight tensors") + # Convert to int32 for accumulation and apply offsets + input_int32 = input.to(torch.int32) + int(input_offset) + weight_int32 = weight.to(torch.int32) + + if bias is None: + bias_int32 = torch.zeros( + weight.shape[0], dtype=torch.int32, device=input.device + ) + else: + bias_int32 = bias.to(torch.int32) + + input_channels = input.shape[1] + kernel_input_channels = weight.shape[3] + groups = input_channels // kernel_input_channels + + # Convert weights back to OIHW layout expected by torch.nn.functional.conv2d + weight_oi_hw = weight_int32.permute(0, 3, 1, 2).contiguous() + + conv_acc = F.conv2d( + input_int32, + weight_oi_hw, + bias_int32, + stride=tuple(stride), + padding=tuple(padding), + dilation=tuple(dilation), + groups=groups, + ) + + result_channels = [] + for output_channel_i in range(conv_acc.shape[1]): + 
result_channel = requantize_cmsis( + conv_acc[:, output_channel_i, :, :], + int(requantize_multipliers[output_channel_i]), + int(requantize_shifts[output_channel_i]), + ) + result_channels.append(result_channel) + + result = torch.stack(result_channels, dim=1) - # Write into the provided output tensor - out.copy_(result_quantized) + result += output_offset + result = torch.clamp(result, activation_min, activation_max) - return out + return result.to(torch.int8) diff --git a/backends/cortex_m/ops/operators.yaml b/backends/cortex_m/ops/operators.yaml index f2615a1f525..0b0b2f5c715 100644 --- a/backends/cortex_m/ops/operators.yaml +++ b/backends/cortex_m/ops/operators.yaml @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -16,14 +17,44 @@ - arg_meta: null kernel_name: cortex_m::dequantize_per_tensor_out -- func: cortex_m::quantized_add(Tensor self, Scalar self_zero_point, Scalar self_multiplier, Scalar self_shift, Tensor other, Scalar other_zero_point, Scalar other_multiplier, Scalar other_shift, Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift) -> Tensor +- func: cortex_m::quantized_add.out(Tensor self, Scalar self_zero_point, Scalar self_multiplier, Scalar self_shift, Tensor other, Scalar other_zero_point, Scalar other_multiplier, Scalar other_shift, Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null - kernel_name: cortex_m::quantized_add + kernel_name: cortex_m::quantized_add_out -- func: cortex_m::quantized_add.out(Tensor self, Scalar self_zero_point, Scalar self_multiplier, Scalar self_shift, Tensor other, Scalar other_zero_point, Scalar other_multiplier, Scalar other_shift, Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cortex_m::quantized_mul.out(Tensor self, Scalar self_zero_point, Tensor other, Scalar other_zero_point, Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null - kernel_name: cortex_m::quantized_add_out + kernel_name: cortex_m::quantized_mul_out + +- func: cortex_m::minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: cortex_m::minimum_out + +- func: cortex_m::maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: cortex_m::maximum_out + +- func: cortex_m::quantized_linear.out(Tensor input, Tensor weights, Tensor? bias, Tensor? kernel_sum, Scalar input_offset, Scalar filter_offset, Scalar output_offset, int[] requantize_multipliers, int[] requantize_shifts, Scalar activation_max, Scalar activation_min, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: cortex_m::quantized_linear_out + +- func: cortex_m::transpose.out(Tensor input, int[] perm, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: cortex_m::transpose_out + +- func: cortex_m::quantized_conv2d.out(Tensor input, Tensor weight, Tensor? 
bias, int[] stride, int[] padding, int[] dilation, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: cortex_m::quantized_conv2d_out diff --git a/backends/cortex_m/passes/__init__.py b/backends/cortex_m/passes/__init__.py new file mode 100644 index 00000000000..c8bb743e278 --- /dev/null +++ b/backends/cortex_m/passes/__init__.py @@ -0,0 +1,12 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .activation_fusion_pass import ActivationFusionPass # noqa +from .clamp_hardswish_pass import ClampHardswishPass # noqa +from .convert_to_cortex_m_pass import ConvertToCortexMPass # noqa +from .decompose_hardswish_pass import DecomposeHardswishPass # noqa +from .quantized_op_fusion_pass import QuantizedOpFusionPass # noqa +from .replace_quant_nodes_pass import ReplaceQuantNodesPass # noqa +from .cortex_m_pass_manager import CortexMPassManager # noqa # usort: skip diff --git a/backends/cortex_m/passes/activation_fusion_pass.py b/backends/cortex_m/passes/activation_fusion_pass.py new file mode 100644 index 00000000000..864f9e47ec8 --- /dev/null +++ b/backends/cortex_m/passes/activation_fusion_pass.py @@ -0,0 +1,181 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import logging + +import executorch.backends.cortex_m.ops.operators # noqa: F401 +from executorch.backends.arm._passes.quant_args import QuantArgs +from executorch.backends.cortex_m.passes.passes_utils import quantize_val + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_manager import PassResult + +logger = logging.getLogger(__name__) + + +class ActivationFusionPass(ExportPass): + """Fuse activations into preceding Cortex-M quantized operators. + + Supported activation patterns: + q-> [conv2d, linear] -> [relu, hardtanh, hardsigmoid] -> dq + + Fusing works by clamping the quantized output range (and zero-point when + required) of the preceding Cortex-M operator, then removing the activation + node from the graph. + """ + + TARGETS = { + exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.hardtanh.default, + exir_ops.edge.aten.hardsigmoid.default, + exir_ops.edge.aten.clamp.default, + } + + FUSE_OPS = { + exir_ops.edge.aten.linear.default, + exir_ops.edge.aten.convolution.default, + } + + def _get_validated_qparams(self, node, input_node): + + if "input_qparams" not in input_node.meta or "output_qparams" not in node.meta: + logger.warning( + f"Cannot fuse activation for {input_node.name}->{node.name} as the pattern wasn't quantized properly." + ) + return None + + qparams_dict = node.meta["output_qparams"][0]._asdict() + zp = qparams_dict["zp"] + scale = qparams_dict["scale"] + qmin = qparams_dict["qmin"] + qmax = qparams_dict["qmax"] + + if not isinstance(scale, float) or not isinstance(zp, int): + logger.warning( + f"Cannot fuse activation {node.name} as quantization parameters are not per tensor." 
+ ) + return None + + match node.target: + case exir_ops.edge.aten.relu.default: + quantized_min_val = quantize_val(0, scale, zp, qmin, qmax) + quantized_max_val = qmax + case exir_ops.edge.aten.hardtanh.default: + quantized_min_val = quantize_val(node.args[1], scale, zp, qmin, qmax) + quantized_max_val = quantize_val(node.args[2], scale, zp, qmin, qmax) + case exir_ops.edge.aten.hardsigmoid.default: + quantized_min_val = quantize_val(0, scale, zp, qmin, qmax) + quantized_max_val = quantize_val(1, scale, zp, qmin, qmax) + case exir_ops.edge.aten.clamp.default: + quantized_min_val = ( + quantize_val(node.args[1], scale, zp, qmin, qmax) + if node.args[1] is not None + else qmin + ) + # Last arg is removed if none, so check length of args here + quantized_max_val = ( + quantize_val(node.args[2], scale, zp, qmin, qmax) + if len(node.args) == 3 + else qmax + ) + case _: + raise RuntimeError("Unexpected target {node.target}.") + + # If the minimal quantized value is larger than the qmin, it means that the quantized range contains + # invalid values [qmin, ..., quantized_min_val-1], indicating bad quantization parameters. + if qparams_dict["qmin"] != quantized_min_val: + logger.warning( + f"Cannot fuse activation {node.name} as qmin is out of range." + ) + return None + + # If the maximal quantized value is smaller than the qmax, it means that the quantized range contains + # invalid values [quantized_max_val + 1, ... , qmax], indicating bad quantization parameters. + if quantized_max_val != qparams_dict["qmax"]: + logger.warning( + f"Cannot fuse activation {node.name} as qmax is out of range." + ) + return None + + return qparams_dict + + def _update_qparams_hardsigmoid(self, quant_dict): + """ + Returns quant_dict with scale and zp updated to match hardsigmoid activation. + + The quantized output from the hard sigmoid is defined by + Q(y) = clamp(round(y/scale + zp), qmin, qmax) + y = clamp(x/6 + 1/2, 0, 1) + where x is the output of the fused activation op, conv or linear. + + Q(y) can be rewritten as a function of only x: + Q(y) = clamp(round(clamp(x/6 + 1/2, 0, 1)/scale + zp), qmin, qmax) + Q(y) = clamp(round(clamp((x/(6*scale) + 1/(2*scale) + zp, zp, 1/scale + zp)), qmin, qmax) + + From definition of the qparams mapping the output in the range [0,1] to quantized range + [qmin, qmax], we have: + zp = Q(0) <= qmin + 1/scale + zp = Q(1) >= qmax + which makes the inner clamp redundant. + + Therefore, hardsigmoid is equivalent to a quantization with modified parameters + new_scale := 6*scale + new_zp = zp + 1/(2*scale) ~= zp + round(1/(2*scale)) + """ + + new_scale = quant_dict["scale"] * 6 + + new_zp = quant_dict["zp"] + round(1 / (2 * quant_dict["scale"])) + clamped_new_zp = max(quant_dict["qmin"], min(quant_dict["qmax"], new_zp)) + + quant_dict["scale"] = new_scale + quant_dict["zp"] = clamped_new_zp + + def call(self, graph_module: GraphModule) -> PassResult: + modified = False + nodes_to_erase: list[Node] = [] + + for node in list(graph_module.graph.nodes): + if node.op != "call_function" or node.target not in self.TARGETS: + continue + + input_node = node.args[0] + if ( + input_node.op != "call_function" + or input_node.target not in self.FUSE_OPS + ): + logger.warning( + f"Cannot fuse activation {node.name} as input node {input_node.name} is not a supported fused activation op." + ) + continue + if len(input_node.users.values()) > 1: + logger.warning( + f"Cannot fuse activation {node.name} as input node {input_node.name} has multiple users." 
+ ) + continue + + if (qparams_dict := self._get_validated_qparams(node, input_node)) is None: + continue + + if node.target == exir_ops.edge.aten.hardsigmoid.default: + self._update_qparams_hardsigmoid(qparams_dict) + + input_node.meta["output_qparams"][0] = QuantArgs(**qparams_dict) + + node.replace_all_uses_with(input_node) + nodes_to_erase.append(node) + modified = True + + for node in nodes_to_erase: + graph_module.graph.erase_node(node) + + if modified: + graph_module.recompile() + + return PassResult(graph_module, modified) diff --git a/backends/cortex_m/passes/clamp_hardswish_pass.py b/backends/cortex_m/passes/clamp_hardswish_pass.py new file mode 100644 index 00000000000..d257520499e --- /dev/null +++ b/backends/cortex_m/passes/clamp_hardswish_pass.py @@ -0,0 +1,37 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict + +import torch + +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue +from torch.fx.node import Argument + + +class ClampHardswishPass(ExportPass): + """ + Adds a clamp operation before hardswish to ensure input is in the range [-3, inf). + + By doing this before quantization the output range of the preceeding op is minimized, + potentially improving accuracy. + """ + + def call_operator( + self, + op: EdgeOpOverload, + args: tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op == torch.ops.aten.hardswish.default: + clamped_args = (args[0], -3) + clamped_input = super().call_operator( + torch.ops.aten.clamp.default, clamped_args, {}, meta + ) + args = (clamped_input,) + + return super().call_operator(op, args, kwargs, meta) diff --git a/backends/cortex_m/passes/convert_to_cortex_m_pass.py b/backends/cortex_m/passes/convert_to_cortex_m_pass.py new file mode 100644 index 00000000000..5a142efd639 --- /dev/null +++ b/backends/cortex_m/passes/convert_to_cortex_m_pass.py @@ -0,0 +1,243 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import executorch.backends.cortex_m.ops.operators # noqa + +import torch +import torch.fx +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.cortex_m.passes.passes_utils import quantize_multiplier_aot + +from executorch.backends.transforms.utils import ( + create_constant_placeholder, + get_param_tensor, +) + +from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass +from executorch.exir.dialects._ops import ops as exir_ops +from torch.export.graph_signature import InputKind +from torch.fx.passes.infra.pass_manager import PassResult + + +class ConvertToCortexMPass(XNNPACKPass): + """ + Cortex-M backend pass for replacing supported quantized kernels with Cortex-M + accelerated kernels. + + Used for ops which require changes to input tensors which is not supported + by call_operator. + """ + + def _compute_kernel_sum(self, weights, bias, input_offset, weight_offset): + """ + Computes the precomputed kernel sum term (bias optional) + a * sum_j(wij + b) + ci + + for i = (1, ..., n), where j indexes the input activations. 
+ """ + weights_transposed = weights.T + weights_int32 = weights_transposed.to(torch.int32) + offset_weights = weights_int32 + weight_offset + kernel_sum = torch.sum(offset_weights, dim=0, keepdim=True, dtype=torch.int32) + kernel_sum_offset = kernel_sum * input_offset + + if bias is not None: + kernel_sum_offset += bias + + return kernel_sum_offset + + def _get_linear_replacement(self, node): + """ + Let + - yi be the output activations (y1, ... yn) + - xj be the input activations (x1, ... xm) + - wij be the weights (w11, ... wnm) + - a be the input offset + - b be the weight offset + - ci be the bias + + Then the linear operation can be written as: + yi = sum_j((xj + a) * (wij + b)) + ci + = sum_j(xj*wij + xj*b + a*wij + a*b) + ci + = sum_j(xj*wij) + sum_j(xj)*b + (a * sum_j(wij + b) + ci) + = sum_j(xj*wij) + sum_j(xj)*b + kernel_sum + + where kernel_sum is precomputed aot. + """ + input_scale = node.meta["input_qparams"][0].scale + input_zp = node.meta["input_qparams"][0].zp + weight_scale = node.meta["input_qparams"][1].scale + weight_zp = node.meta["input_qparams"][1].zp + output_scale = node.meta["output_qparams"][0].scale + output_zp = node.meta["output_qparams"][0].zp + output_min = node.meta["output_qparams"][0].qmin + output_max = node.meta["output_qparams"][0].qmax + + quantized_multiplier, quantized_shift = quantize_multiplier_aot( + (input_scale * weight_scale) / output_scale + ) + + # TODO: Add support for configuring the backend to support other extensions. + # Kernel sum is only used in the CMSIS-NN implementation for the MVE extension, + # so this should be optional. + weights = node.args[1] + weights_tensor = get_param_tensor(self.exported_program, weights) + bias_tensor = ( + get_param_tensor(self.exported_program, node.args[2]) + if len(node.args) > 2 + else None + ) + kernel_sum_tensor = self._compute_kernel_sum( + weights_tensor, bias_tensor, -input_zp, -weight_zp + ) + with node.graph.inserting_after(weights): + kernel_sum = create_constant_placeholder( + self.exported_program, + node.graph, + node.name + "_kernel_sum", + InputKind.PARAMETER, + kernel_sum_tensor, + ) + + args = ( + node.args[0], + weights, + None, + kernel_sum, + -input_zp, + -weight_zp, + output_zp, + [quantized_multiplier], + [quantized_shift], + output_max, + output_min, + ) + + return exir_ops.edge.cortex_m.quantized_linear.default, args + + def _get_convolution_replacement(self, node) -> int: + ( + x, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) = node.args + + # Extract values + input_scale = node.meta["input_qparams"][0].scale + input_zero_point = node.meta["input_qparams"][0].zp + weight_scales = node.meta["input_qparams"][1].scale + if not isinstance(weight_scales, list): + weight_tensor = get_first_fake_tensor(weight) + weight_scales = [weight_scales] * weight_tensor.shape[0] + + output_qparams = node.meta["output_qparams"][0] + output_scale = output_qparams.scale + output_zero_point = output_qparams.zp + output_qmin = output_qparams.qmin + output_qmax = output_qparams.qmax + + quantized_multipliers = [] + quantized_shifts = [] + for weight_scale in weight_scales: + quantized_multiplier, quantized_shift = quantize_multiplier_aot( + input_scale * weight_scale / output_scale + ) + quantized_multipliers.append(quantized_multiplier) + quantized_shifts.append(quantized_shift) + + # Permute the weight tensor to the OHWI layout expected by CMSIS-NN. 
+ weight_tensor = get_param_tensor(self.exported_program, weight) + weight_permuted = weight_tensor.permute(0, 2, 3, 1).contiguous( + memory_format=torch.channels_last + ) + + with node.graph.inserting_after(weight): + weight_nhwc = create_constant_placeholder( + self.exported_program, + node.graph, + node.name + "_weight_nhwc", + InputKind.PARAMETER, + weight_permuted, + ) + + quantized_multiplier_tensor = create_constant_placeholder( + self.exported_program, + node.graph, + node.name + "_quantized_multiplier", + InputKind.PARAMETER, + torch.tensor(quantized_multipliers, dtype=torch.int32), + ) + + quantized_shift_tensor = create_constant_placeholder( + self.exported_program, + node.graph, + node.name + "_quantized_shift", + InputKind.PARAMETER, + torch.tensor(quantized_shifts, dtype=torch.int32), + ) + + new_args = ( + x, + weight_nhwc, + bias, + stride, + padding, + dilation, + -input_zero_point, + output_zero_point, + quantized_multiplier_tensor, + quantized_shift_tensor, + output_qmin, + output_qmax, + ) + return exir_ops.edge.cortex_m.quantized_conv2d.default, new_args + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = False + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + if ( + node.meta.get("input_qparams", {}) == {} + or node.meta.get("output_qparams", {}) == {} + ): + continue + + match node.target: + case exir_ops.edge.aten.linear.default: + op, args = self._get_linear_replacement(node) + case exir_ops.edge.aten.convolution.default: + op, args = self._get_convolution_replacement(node) + case _: + continue + + with graph_module.graph.inserting_before(node): + cortex_m_op = graph_module.graph.create_node( + "call_function", + target=op, + args=args, + kwargs={}, + ) + + node.replace_all_uses_with(cortex_m_op) + graph_module.graph.erase_node(node) + + modified = True + + if modified: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, modified) diff --git a/backends/cortex_m/passes/cortex_m_pass_manager.py b/backends/cortex_m/passes/cortex_m_pass_manager.py new file mode 100644 index 00000000000..bd3fad1cf94 --- /dev/null +++ b/backends/cortex_m/passes/cortex_m_pass_manager.py @@ -0,0 +1,74 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
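For readers unfamiliar with the CMSIS-NN requantization convention, the following sketch shows how a float rescale factor becomes an int32 multiplier plus a power-of-two shift, mirroring the frexp convention used by quantize_multiplier_aot in this patch (int32 saturation omitted, scales made up for the example). Plain truncation is used when applying it; arm_nn_requantize uses a rounding doubling-high multiply, so the last bit can differ.

import math

def to_multiplier_and_shift(scale: float) -> tuple[int, int]:
    # scale == mantissa * 2**shift with mantissa in [0.5, 1), stored as a Q31 integer.
    mantissa, shift = math.frexp(scale)
    q_fixed = round(mantissa * (1 << 31))
    if q_fixed == (1 << 31):            # mantissa rounded up to exactly 1.0
        q_fixed //= 2
        shift += 1
    return q_fixed, shift

# Effective rescale of a quantized linear/conv accumulator back to int8 range.
input_scale, weight_scale, output_scale = 0.02, 0.005, 0.1
scale = input_scale * weight_scale / output_scale            # 0.001

multiplier, shift = to_multiplier_and_shift(scale)
assert abs(multiplier / (1 << 31) * 2.0**shift - scale) < 1e-9

acc = 12345                                                  # an int32 accumulator value
approx = (acc * multiplier) >> (31 - shift)                  # integer-only, approximates acc * scale
print(approx, acc * scale)                                   # 12 vs 12.345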
+ + +import inspect + +from executorch.backends.arm._passes import ( + FoldAndAnnotateQParamsPass, + ScalarsToAttributePass, +) +from executorch.backends.cortex_m.passes import ( + ActivationFusionPass, + ClampHardswishPass, + ConvertToCortexMPass, + DecomposeHardswishPass, + QuantizedOpFusionPass, + ReplaceQuantNodesPass, +) +from executorch.backends.transforms.replace_scalar_with_tensor import ( + ReplaceScalarWithTensorArgPass, +) +from executorch.exir.pass_base import ExportPass +from executorch.exir.pass_manager import PassManager +from executorch.exir.program._program import _transform +from torch.export import ExportedProgram + + +class CortexMPassManager(PassManager): + + pass_list: list[ExportPass] = [ + FoldAndAnnotateQParamsPass, + ReplaceScalarWithTensorArgPass, + ReplaceQuantNodesPass, + ActivationFusionPass, + DecomposeHardswishPass, + QuantizedOpFusionPass, + ConvertToCortexMPass, + ] + + pass_list_transform_for_annotation: list[ExportPass] = [ + ScalarsToAttributePass, + ReplaceScalarWithTensorArgPass, + ClampHardswishPass, + ] + + def __init__(self, exported_program, passes=None): + self.exported_program = exported_program + if passes is not None: + self.passes = passes + else: + self.passes = self.pass_list + + def transform_for_annotation(self, model): + passes = self.pass_list_transform_for_annotation + for p in passes: + model = p().call(model).graph_module + return model + + def transform(self) -> ExportedProgram: + ep = self.exported_program + for pass_ in self.passes: + signature = inspect.signature(pass_.__init__) + if "exported_program" in signature.parameters: + transform_pass = pass_(ep) + elif issubclass(pass_, ExportPass): + transform_pass = pass_() + else: + raise RuntimeError( + f"Expecting ExportPass or ExportPass(), but got pass: {pass_} with type: {type(pass_)}" + ) + ep = _transform(ep, transform_pass) + return ep diff --git a/backends/cortex_m/passes/decompose_hardswish_pass.py b/backends/cortex_m/passes/decompose_hardswish_pass.py new file mode 100644 index 00000000000..36ca6bd759d --- /dev/null +++ b/backends/cortex_m/passes/decompose_hardswish_pass.py @@ -0,0 +1,127 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import logging + +import executorch.backends.cortex_m.ops.operators # noqa: F401 + +import torch +from executorch.backends.arm._passes.quant_args import QuantArgs + +from executorch.backends.cortex_m.passes.passes_utils import quantize_val + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_manager import PassResult + +logger = logging.getLogger(__name__) + + +class DecomposeHardswishPass(ExportPass): + """ + Decomposes hardswish like + + hardswish(x) = x * (clamp(x, -3, 3) + 3)/6 + + where the add and division is implemented by modifying the quantization parameters similar + to hardsigmoid in the activation_fusion_pass. Note that this pass assumes + that the output range of the preceding op is already clamped to [-3, inf] during + quantization by the clamp_hardswish_pass, removing the need for the negative clamp. 
+ """ + + TARGETS = { + exir_ops.edge.aten.hardswish.default, + } + + FUSE_OPS = { + exir_ops.edge.aten.linear.default, + exir_ops.edge.aten.convolution.default, + } + + def call(self, graph_module: GraphModule) -> PassResult: + modified = False + nodes_to_erase: list[Node] = [] + + for node in list(graph_module.graph.nodes): + if node.op != "call_function" or node.target not in self.TARGETS: + continue + + input_node = node.args[0] + if ( + input_node.op != "call_function" + or input_node.target not in self.FUSE_OPS + ): + logger.warning( + f"Cannot fuse activation {node.name} as input node {input_node.name} is not a supported fused activation op." + ) + continue + if len(input_node.users.values()) > 1: + logger.warning( + f"Cannot fuse activation {node.name} as input node {input_node.name} has multiple users." + ) + continue + + input_quant_dict = input_node.meta.get("output_qparams", [None])[ + 0 + ]._asdict() + scale = input_quant_dict["scale"] + zero_point = input_quant_dict["zp"] + qmin = input_quant_dict["qmin"] + qmax = input_quant_dict["qmax"] + + # Create min node + with graph_module.graph.inserting_after(input_node): + clamp_node = graph_module.graph.create_node( + "call_function", + target=exir_ops.edge.aten.minimum.default, + args=( + input_node, + torch.tensor( + quantize_val(3, scale, zero_point, qmin, qmax), + dtype=torch.int8, + ), + ), + kwargs={}, + ) + clamp_node.meta = input_node.meta.copy() + + # Create mul node + with graph_module.graph.inserting_after(clamp_node): + mul_node = graph_module.graph.create_node( + "call_function", + target=exir_ops.edge.aten.mul.Tensor, + args=(input_node, clamp_node), + kwargs={}, + ) + mul_node.meta = node.meta.copy() + + mul_quant_dict = node.meta["input_qparams"][0]._asdict() + + mul_quant_dict_shifted = mul_quant_dict.copy() + mul_quant_dict_shifted["zp"] = mul_quant_dict_shifted["zp"] - round( + 3 / (mul_quant_dict_shifted["scale"]) + ) + + output_quant_dict = node.meta["output_qparams"][0]._asdict() + output_quant_dict["scale"] = output_quant_dict["scale"] * 6 + + node.meta["input_qparams"][0] = QuantArgs(**mul_quant_dict) + mul_node.meta["input_qparams"][1] = QuantArgs(**mul_quant_dict_shifted) + mul_node.meta["output_qparams"][0] = QuantArgs(**output_quant_dict) + + node.replace_all_uses_with(mul_node) + nodes_to_erase.append(node) + modified = True + + for node in nodes_to_erase: + graph_module.graph.erase_node(node) + + if modified: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + + return PassResult(graph_module, modified) diff --git a/backends/cortex_m/passes/passes_utils.py b/backends/cortex_m/passes/passes_utils.py index 3f6e05fc4de..131541fcb75 100644 --- a/backends/cortex_m/passes/passes_utils.py +++ b/backends/cortex_m/passes/passes_utils.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
@@ -8,6 +9,17 @@ import torch +from executorch.exir.dialects._ops import ops as exir_ops + +from torch.fx import Node + +# L-shift value used in CMSIS-NN for int8 operations +SHIFT_INT8 = 20 + + +def quantize_val(val, scale, zp, qmin, qmax): + return min(max(round(val / scale + zp), qmin), qmax) + def dequantize_per_tensor_cmsis( qtensor: torch.Tensor, zero_point: int, multiplier: int, shift: int @@ -37,6 +49,39 @@ def quantize_per_tensor_cmsis( return quantized.clamp(qmin, qmax).to(torch.int8) +def requantize_cmsis( + tensor: torch.Tensor, + multiplier: int, + shift: int, +) -> torch.Tensor: + """Simulate CMSIS-NN's arm_nn_requantize helper.""" + + tensor_64 = tensor.to(torch.int64) + left_shift = max(shift, 0) + right_shift = max(-shift, 0) + + # Equivalent to val * (1 << LEFT_SHIFT(shift)) + value = tensor_64 << left_shift + + # arm_nn_doubling_high_mult_no_sat(value, multiplier) + product = value * int(multiplier) + product = product + (1 << 30) + result = product >> 31 + + if right_shift: + remainder_mask = (1 << right_shift) - 1 + remainder = torch.bitwise_and(result, remainder_mask) + result = result >> right_shift + threshold = remainder_mask >> 1 + threshold_tensor = torch.full_like(result, threshold, dtype=torch.int64) + threshold_tensor = torch.where( + result < 0, threshold_tensor + 1, threshold_tensor + ) + result = result + torch.where(remainder > threshold_tensor, 1, 0) + + return result.to(torch.int32) + + def extract_scalar_value(node_arg) -> float: """ Extract scalar value from various PyTorch scalar representations. @@ -79,16 +124,103 @@ def is_qualified_int8_node(args) -> bool: def quantize_multiplier_aot(scale: float) -> tuple[int, int]: if scale == 0.0: return 0, 0 - mantissa, exponent = math.frexp(scale) - shift = -exponent + mantissa, shift = math.frexp(scale) q_fixed = int(round(mantissa * (1 << 31))) if q_fixed == (1 << 31): q_fixed //= 2 - shift -= 1 - multiplier = max(-2147483648, min(2147483647, q_fixed)) + shift += 1 + multiplier = max( + torch.iinfo(torch.int32).min, min(torch.iinfo(torch.int32).max, q_fixed) + ) return multiplier, shift def cleanup_erased_nodes(graph_module: torch.fx.GraphModule): # Placeholder for any additional cleanup if needed pass + + +def transfer_metadata( + new_node: Node, source_node: Node, pass_name: str = "QuantizedPass" +) -> None: + """Transfer metadata with proper provenance tracking.""" + if hasattr(source_node, "meta") and source_node.meta: + new_node.meta = source_node.meta.copy() + if "from_node" in new_node.meta: + from_node_list = new_node.meta.get("from_node", []).copy() + from_node_list.append( + {"source": source_node.name, "pass": pass_name, "op": "fuse"} + ) + new_node.meta["from_node"] = from_node_list + for field in ["tensor_meta", "stack_trace"]: + if field in source_node.meta: + new_node.meta[field] = source_node.meta[field] + + +def is_dequant_node(node: Node) -> bool: + """Check if node is a dequantize operation.""" + dequant_targets = { + exir_ops.edge.cortex_m.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + } + return node.op == "call_function" and node.target in dequant_targets + + +def is_quant_node(node: Node) -> bool: + """Check if node is a quantize operation.""" + quant_targets = { + exir_ops.edge.cortex_m.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + } + return node.op == "call_function" and node.target in quant_targets + + +def 
cleanup_nodes(nodes_to_erase, graph): + """Clean up marked nodes from graph.""" + failed_nodes = [] + + for node in reversed(nodes_to_erase): + if node in graph.nodes and len(node.users) == 0: + try: + graph.erase_node(node) + except Exception as e: + print(f"Warning: Failed to erase node {node}: {e}") + failed_nodes.append(node) + continue + + if failed_nodes: + print(f"Warning: {len(failed_nodes)} nodes could not be erased") + + return failed_nodes + + +def is_channels_last(tensor: torch.Tensor) -> bool: + """Check if a 4D tensor is in channels last format.""" + if tensor.ndim != 4: + return False + + if tensor.shape[1] == 1 or tensor.shape[2] == tensor.shape[3] == 1: + return True + + dim_order = list(tensor.dim_order()) + return dim_order[0:2] == [0, 2] + + +def is_channel_broadcast(tensor1: torch.Tensor, tensor2: torch.Tensor) -> bool: + """ + Check if tensor1 is broadcasted to tensor2 along channel dimension. + Assumes tensor2 has shape [N, C, ...] and tensor1 has shape [N, 1, ...] or [1, C, ...]. + """ + if tensor1.dim() != tensor2.dim(): + return False + if not is_channels_last(tensor1): + return False + if not is_channels_last(tensor2): + return False + + channel_match = tensor1.size(1) == tensor2.size(1) + tensor1_channels_only = tensor1.numel() == tensor1.size(1) + tensor2_channels_only = tensor2.numel() == tensor2.size(1) + + return channel_match and (tensor1_channels_only or tensor2_channels_only) diff --git a/backends/cortex_m/passes/quantized_op_fusion_pass.py b/backends/cortex_m/passes/quantized_op_fusion_pass.py index ca6d8b97795..c84e66dd7d9 100644 --- a/backends/cortex_m/passes/quantized_op_fusion_pass.py +++ b/backends/cortex_m/passes/quantized_op_fusion_pass.py @@ -1,25 +1,23 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import logging -from typing import Set +from typing import Dict -import executorch.backends.cortex_m.ops.operators # noqa import torch from executorch.backends.cortex_m.passes.passes_utils import ( - extract_scalar_value, quantize_multiplier_aot, + SHIFT_INT8, ) -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass -from torch.fx.passes.infra.pass_manager import PassResult -logger = logging.getLogger("quant_op_fusion_pass") -logger.setLevel(logging.INFO) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue +from torch.fx.node import Argument class QuantizedOpFusionPass(ExportPass): @@ -33,223 +31,117 @@ class QuantizedOpFusionPass(ExportPass): Supports multiple binary operations with backward compatibility for add. 
""" - # Generic operation mapping - SUPPORTED_OPS_MAPPING = { - exir_ops.edge.aten.add.Tensor: exir_ops.edge.cortex_m.quantized_add.default, - # Future ops to be added here: - } - - def __init__(self): - super().__init__() - - def _get_dequant_targets(self) -> Set: - """Support both decomposed and cortex_m dequant targets for flexible pass ordering.""" - return { - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - exir_ops.edge.cortex_m.dequantize_per_tensor.default, - } - - def _get_quant_targets(self) -> Set: - """Support both decomposed and cortex_m quant targets for flexible pass ordering.""" - return { - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.cortex_m.quantize_per_tensor.default, - } - - def _is_supported_binary_op(self, node: torch.fx.Node) -> bool: - """Check if node is a supported binary operation.""" - return node.op == "call_function" and node.target in self.SUPPORTED_OPS_MAPPING - - def _is_dequant_node(self, node: torch.fx.Node) -> bool: - """Check if node is a dequantize operation.""" - return ( - hasattr(node, "op") - and node.op == "call_function" - and node.target in self._get_dequant_targets() + def _get_add_replacement(self, args, meta): + if ( + meta.data.get("input_qparams", {}) == {} + or meta.data.get("output_qparams", {}) == {} + ): + return exir_ops.edge.aten.add.Tensor, args + + # Extract values + scale1 = meta["input_qparams"][0].scale + zero_point1 = meta["input_qparams"][0].zp + scale2 = meta["input_qparams"][1].scale + zero_point2 = meta["input_qparams"][1].zp + output_scale = meta["output_qparams"][0].scale + output_zero_point = meta["output_qparams"][0].zp + + # AoT COMPUTATION: Calculate multipliers and shifts + max_scale_2x = 2 * max(scale1, scale2) + + input1_mult, input1_shift = quantize_multiplier_aot(scale1 / max_scale_2x) + input2_mult, input2_shift = quantize_multiplier_aot(scale2 / max_scale_2x) + output_mult, output_shift = quantize_multiplier_aot( + max_scale_2x / (output_scale * (1 << SHIFT_INT8)) + ) + + args = ( + args[0], + zero_point1, + input1_mult, + input1_shift, + args[1], + zero_point2, + input2_mult, + input2_shift, + output_zero_point, + output_mult, + output_shift, ) - def _is_quant_node(self, node: torch.fx.Node) -> bool: - """Check if node is a quantize operation.""" - return ( - hasattr(node, "op") - and node.op == "call_function" - and node.target in self._get_quant_targets() + return exir_ops.edge.cortex_m.quantized_add.default, args + + def _get_mul_replacement(self, args, meta): + if ( + meta.data.get("input_qparams", {}) == {} + or meta.data.get("output_qparams", {}) == {} + ): + return exir_ops.edge.aten.mul.Tensor, args + + # Extract values + scale1 = meta["input_qparams"][0].scale + zero_point1 = meta["input_qparams"][0].zp + scale2 = meta["input_qparams"][1].scale + zero_point2 = meta["input_qparams"][1].zp + output_scale = meta["output_qparams"][0].scale + output_zero_point = meta["output_qparams"][0].zp + + scale_factor = (scale1 * scale2) / output_scale + output_mult, output_shift = quantize_multiplier_aot(scale_factor) + + args = ( + args[0], + zero_point1, + args[1], + zero_point2, + output_zero_point, + output_mult, + output_shift, ) - def _transfer_metadata( + return exir_ops.edge.cortex_m.quantized_mul.default, args + + def _get_minimum_replacement(self, args, meta): + if args[0].data.dtype != torch.int8: + return exir_ops.edge.aten.minimum.default, args + + return exir_ops.edge.cortex_m.minimum.default, args + + def _get_maximum_replacement(self, args, 
meta): + if args[0].data.dtype != torch.int8: + return exir_ops.edge.aten.maximum.default, args + + return exir_ops.edge.cortex_m.maximum.default, args + + def _get_permute_replacement(self, args, meta): + if args[0].data.dtype != torch.int8: + return exir_ops.edge.aten.permute_copy.default, args + + rank = len(args[0].data.shape) + perms = [p % rank for p in args[1]] + args = (args[0], perms) + return exir_ops.edge.cortex_m.transpose.default, args + + def call_operator( self, - new_node: torch.fx.Node, - source_node: torch.fx.Node, - pass_name: str = "QuantizedOpFusionPass", - ) -> None: - """Metadata transfer with proper provenance tracking.""" - if hasattr(source_node, "meta") and source_node.meta: - new_node.meta = source_node.meta.copy() - - if "from_node" in new_node.meta: - from_node_list = new_node.meta.get("from_node", []).copy() - from_node_list.append( - {"source": source_node.name, "pass": pass_name, "op": "fuse"} - ) - new_node.meta["from_node"] = from_node_list - - # Copy essential fields - for field in ["tensor_meta", "stack_trace"]: - if field in source_node.meta: - new_node.meta[field] = source_node.meta[field] - - def _normalize_to_cortex_m_targets(self, graph_module: torch.fx.GraphModule) -> int: - """Convert decomposed targets to cortex_m equivalents for consistent handling.""" - target_mapping = { - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: exir_ops.edge.cortex_m.dequantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: exir_ops.edge.cortex_m.quantize_per_tensor.default, - } - - normalization_count = 0 - for node in list(graph_module.graph.nodes): - if node.op == "call_function" and node.target in target_mapping: - logger.info(f"Normalizing {node.target} to cortex_m equivalent") - node.target = target_mapping[node.target] - normalization_count += 1 - - return normalization_count - - def _fuse_quantized_binary_patterns( - self, graph_module: torch.fx.GraphModule - ) -> int: - """Generic fusion for quantized binary operation patterns.""" - fusion_count = 0 - nodes_to_erase = [] - - for node in list(graph_module.graph.nodes): - if not self._is_quant_node(node): - continue - - quantize_node = node - if not quantize_node.args: - continue - - binary_op_node = quantize_node.args[0] - if not self._is_supported_binary_op(binary_op_node): - continue - - if len(binary_op_node.args) < 2: - continue - - dequant_node1, dequant_node2 = binary_op_node.args[:2] - if not ( - self._is_dequant_node(dequant_node1) - and self._is_dequant_node(dequant_node2) - ): - continue - - # Get the target quantized operation - quantized_target = self.SUPPORTED_OPS_MAPPING[binary_op_node.target] - # Extract op name (e.g., 'Tensor' -> 'add') - op_name = str(binary_op_node.target).split(".")[-1] - logger.info(f"✅ Found complete cortex_m Q/DQ + {op_name} pattern!") - - try: - # Extract values - int8_tensor1, scale1, zero_point1 = dequant_node1.args[:3] - int8_tensor2, scale2, zero_point2 = dequant_node2.args[:3] - output_scale, output_zero_point = quantize_node.args[1:3] - - # Convert to Python floats - scale1_val = extract_scalar_value(scale1) - scale2_val = extract_scalar_value(scale2) - output_scale_val = extract_scalar_value(output_scale) - zp1_val = int(extract_scalar_value(zero_point1)) - zp2_val = int(extract_scalar_value(zero_point2)) - output_zp_val = int(extract_scalar_value(output_zero_point)) - - # AoT COMPUTATION: Calculate multipliers and shifts - input1_mult, input1_shift = quantize_multiplier_aot( - scale1_val / output_scale_val 
- ) - input2_mult, input2_shift = quantize_multiplier_aot( - scale2_val / output_scale_val - ) - output_mult, output_shift = quantize_multiplier_aot( - 1.0 - ) # Output multiplier is 1 - - logger.info("AoT computed parameters:") - logger.info(f" Input1: mult={input1_mult}, shift={input1_shift}") - logger.info(f" Input2: mult={input2_mult}, shift={input2_shift}") - logger.info(f" Output: mult={output_mult}, shift={output_shift}") - - with graph_module.graph.inserting_after(quantize_node): - fused = graph_module.graph.create_node( - "call_function", - target=quantized_target, - args=( - int8_tensor1, - zp1_val, - input1_mult, - input1_shift, - int8_tensor2, - zp2_val, - input2_mult, - input2_shift, - output_zp_val, - output_mult, - output_shift, - ), - kwargs={}, - ) - - # metadata transfer - self._transfer_metadata(fused, quantize_node) - - logger.info(f"✅ Created fused quantized_{op_name} node: {fused}") - - # Replace all uses - quantize_node.replace_all_uses_with(fused) - binary_op_node.replace_all_uses_with(fused) - dequant_node1.replace_all_uses_with(fused) - dequant_node2.replace_all_uses_with(fused) - - nodes_to_erase.extend( - [quantize_node, binary_op_node, dequant_node1, dequant_node2] - ) - fusion_count += 1 - logger.info(f"Pattern fused, total so far: {fusion_count}") - - except Exception as e: - logger.info(f"❌ Error during AoT computation: {e}") - logger.info(" Skipping fusion for this pattern") - continue - - for old_node in reversed(nodes_to_erase): - if old_node in graph_module.graph.nodes and len(old_node.users) == 0: - logger.info(f"🗑️ Erasing node: {old_node}") - graph_module.graph.erase_node(old_node) - - return fusion_count - - def call(self, graph_module: torch.fx.GraphModule): - logger.info("QuantizedOpFusionPass.call() started") - - # Normalize targets for flexible pass ordering - normalization_count = self._normalize_to_cortex_m_targets(graph_module) - - # Generic fusion for supported binary operations - fusion_count = self._fuse_quantized_binary_patterns(graph_module) - - total_changes = normalization_count + fusion_count - logger.info(f"Total changes: {total_changes}") - - if total_changes > 0: - graph_module.graph.eliminate_dead_code() - graph_module.graph.lint() - graph_module.recompile() - - logger.debug("=== AFTER FUSION: All nodes in the graph ===") - for i, node in enumerate(graph_module.graph.nodes): - logger.debug(f"Node {i}: op={node.op}, target={node.target}") - if "quantized_" in str(node.target) and "add" in str(node.target): - logger.debug(" ⭐ FOUND QUANTIZED BINARY OP NODE! 
⭐") - logger.debug("=== END DEBUG ===") - - return PassResult(graph_module, total_changes > 0) + op: EdgeOpOverload, + args: tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + + match op: + case exir_ops.edge.aten.add.Tensor: + op, args = self._get_add_replacement(args, meta) + case exir_ops.edge.aten.mul.Tensor: + op, args = self._get_mul_replacement(args, meta) + case exir_ops.edge.aten.minimum.default: + op, args = self._get_minimum_replacement(args, meta) + case exir_ops.edge.aten.maximum.default: + op, args = self._get_maximum_replacement(args, meta) + case exir_ops.edge.aten.permute_copy.default: + op, args = self._get_permute_replacement(args, meta) + case _: + pass + + return super().call_operator(op, args, {}, meta) diff --git a/backends/cortex_m/quantizer/operator_configs.py b/backends/cortex_m/quantizer/operator_configs.py new file mode 100644 index 00000000000..dadee30fa41 --- /dev/null +++ b/backends/cortex_m/quantizer/operator_configs.py @@ -0,0 +1,63 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +""" +Operator configs maps a list of operators/operator patterns to a quantization configuration. +These can be used with the OperatorConfigQuantizer to quantize models based on operator patterns. +""" + +import torch + +from executorch.backends.cortex_m.quantizer.quantization_configs import ( + INT8_PER_CHANNEL_CONFIG, + INT8_PER_TENSOR_CONFIG, +) +from torchao.quantization.pt2e.quantizer import OperatorConfig + +# ----------------- OPERATOR PATTERN PRESETS ----------------- +BINARY_OP_PATTERNS = [ + [torch.ops.aten.add.Tensor], + [torch.ops.aten.mul.Tensor], + [torch.ops.aten.hardswish.default], + [torch.ops.aten.hardswish_.default], +] + +LINEAR_OP_PATTERNS = [ + [torch.ops.aten.linear.default], + [torch.ops.aten.linear.default, torch.ops.aten.relu.default], + [torch.ops.aten.linear.default, torch.ops.aten.relu_.default], + [torch.ops.aten.linear.default, torch.ops.aten.hardtanh.default], + [torch.ops.aten.linear.default, torch.ops.aten.hardtanh_.default], + [torch.ops.aten.linear.default, torch.ops.aten.hardsigmoid.default], + [torch.ops.aten.linear.default, torch.ops.aten.hardsigmoid_.default], + [torch.ops.aten.linear.default, torch.ops.aten.clamp.default], + [torch.ops.aten.linear.default, torch.ops.aten.clamp_.default], +] + +CONV_OP_PATTERNS = [ + [torch.ops.aten.conv2d.default], + [torch.ops.aten.conv2d.default, torch.ops.aten.relu.default], + [torch.ops.aten.conv2d.default, torch.ops.aten.relu_.default], + [torch.ops.aten.conv2d.default, torch.ops.aten.hardtanh.default], + [torch.ops.aten.conv2d.default, torch.ops.aten.hardtanh_.default], + [torch.ops.aten.conv2d.default, torch.ops.aten.hardsigmoid.default], + [torch.ops.aten.conv2d.default, torch.ops.aten.hardsigmoid_.default], + [torch.ops.aten.conv2d.default, torch.ops.aten.clamp.default], + [torch.ops.aten.conv2d.default, torch.ops.aten.clamp_.default], +] + +# ----------------- OPERATOR CONFIG PRESETS ----------------- +INT8_BINARY_OPS_OPERATOR_CONFIG = OperatorConfig( + INT8_PER_TENSOR_CONFIG, BINARY_OP_PATTERNS +) + +INT8_LINEAR_OPERATOR_CONFIG = OperatorConfig( + INT8_PER_TENSOR_CONFIG, + LINEAR_OP_PATTERNS, +) + +INT8_CONV_OPERATOR_CONFIG = OperatorConfig( + INT8_PER_CHANNEL_CONFIG, + CONV_OP_PATTERNS, +) diff --git a/backends/cortex_m/quantizer/quantization_configs.py b/backends/cortex_m/quantizer/quantization_configs.py new 
file mode 100644 index 00000000000..c6600241b6d --- /dev/null +++ b/backends/cortex_m/quantizer/quantization_configs.py @@ -0,0 +1,99 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from torchao.quantization.pt2e import ( + HistogramObserver, + MinMaxObserver, + PerChannelMinMaxObserver, +) +from torchao.quantization.pt2e.quantizer import ( + DerivedQuantizationSpec, + QuantizationConfig, + QuantizationSpec, +) + +# ----------------- QUANTIZATION SPEC PRESETS ----------------- +INT8_WEIGHT_PER_TENSOR_QSPEC = QuantizationSpec( + dtype=torch.int8, + observer_or_fake_quant_ctr=MinMaxObserver, + qscheme=torch.per_tensor_symmetric, +) + +INT8_WEIGHT_PER_CHANNEL_QSPEC = QuantizationSpec( + dtype=torch.int8, + observer_or_fake_quant_ctr=PerChannelMinMaxObserver, + qscheme=torch.per_channel_symmetric, + ch_axis=0, +) + +INT8_ACTIVATION_PER_TENSOR_QSPEC = QuantizationSpec( + dtype=torch.int8, + observer_or_fake_quant_ctr=HistogramObserver, + qscheme=torch.per_tensor_affine, +) + +INT8_ACTIVATION_PER_CHANNEL_QSPEC = QuantizationSpec( + dtype=torch.int8, + observer_or_fake_quant_ctr=PerChannelMinMaxObserver, + qscheme=torch.per_channel_affine, + ch_axis=0, +) + + +def _derive_bias_qparams_fn( + obs_or_fqs, +) -> tuple[torch.Tensor, torch.Tensor]: + if len(obs_or_fqs) != 2: + raise ValueError( + f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}" + ) + act_obs_or_fq = obs_or_fqs[0] + weight_obs_or_fq = obs_or_fqs[1] + act_scale, _ = act_obs_or_fq.calculate_qparams() + weight_scale, _ = weight_obs_or_fq.calculate_qparams() + return act_scale * weight_scale, torch.full_like( + weight_scale, fill_value=0, dtype=torch.int32 + ) + + +def _get_int32_bias_qspec(node): + return DerivedQuantizationSpec( + derived_from=[(node.args[0], node), (node.args[1], node)], # type: ignore[list-item] + derive_qparams_fn=_derive_bias_qparams_fn, + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max - 1, + ) + + +def _get_int32_per_channel_bias_qspec(node): + return DerivedQuantizationSpec( + derived_from=[(node.args[0], node), (node.args[1], node)], # type: ignore[list-item] + derive_qparams_fn=_derive_bias_qparams_fn, + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max - 1, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + ) + + +# ----------------- QUANTIZATION CONFIG PRESETS ----------------- +INT8_PER_TENSOR_CONFIG = QuantizationConfig( + INT8_ACTIVATION_PER_TENSOR_QSPEC, + INT8_ACTIVATION_PER_TENSOR_QSPEC, + INT8_WEIGHT_PER_TENSOR_QSPEC, + _get_int32_bias_qspec, +) + + +INT8_PER_CHANNEL_CONFIG = QuantizationConfig( + INT8_ACTIVATION_PER_TENSOR_QSPEC, + INT8_ACTIVATION_PER_TENSOR_QSPEC, + INT8_WEIGHT_PER_CHANNEL_QSPEC, + _get_int32_per_channel_bias_qspec, +) diff --git a/backends/cortex_m/quantizer/quantizer.py b/backends/cortex_m/quantizer/quantizer.py new file mode 100644 index 00000000000..185a39b9eae --- /dev/null +++ b/backends/cortex_m/quantizer/quantizer.py @@ -0,0 +1,410 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
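To make the derived bias spec concrete: the bias is quantized symmetrically to int32 with scale equal to input_scale * weight_scale and zero point 0, which is what _derive_bias_qparams_fn returns. A per-tensor toy example with made-up observer scales:

import torch

input_scale = 0.02                                  # from the activation observer
weight_scale = 0.005                                # from the weight observer (per-tensor case)
bias = torch.tensor([0.37, -1.25, 0.0])

bias_scale = input_scale * weight_scale             # derived scale, zero point fixed at 0
q_bias = torch.clamp(
    torch.round(bias / bias_scale),
    torch.iinfo(torch.int32).min,
    torch.iinfo(torch.int32).max - 1,               # quant_max matches the spec above
).to(torch.int32)

print(q_bias)                                       # values: 3700, -12500, 0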
+ + +from typing import Callable, List, Optional + +import torch +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo +from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig +from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager +from executorch.backends.cortex_m.passes.passes_utils import ( + is_channel_broadcast, + is_channels_last, +) +from executorch.backends.cortex_m.quantizer.operator_configs import ( + BINARY_OP_PATTERNS, + CONV_OP_PATTERNS, + INT8_BINARY_OPS_OPERATOR_CONFIG, + INT8_CONV_OPERATOR_CONFIG, + INT8_LINEAR_OPERATOR_CONFIG, +) +from executorch.backends.cortex_m.quantizer.quantization_configs import ( + INT8_PER_TENSOR_CONFIG, + QuantizationSpec, +) +from torch._ops import OpOverload +from torch.fx import GraphModule, Node +from torchao.quantization.pt2e.quantizer import ( + ComposableQuantizer, + QuantizationAnnotation, + Quantizer, + SharedQuantizationSpec, +) +from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY + + +def mark_node_as_annotated( + node: Node, + input_qspec_map: dict[Node, Optional[QuantizationSpec]], + output_qspec: Optional[QuantizationSpec], +) -> None: + node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(input_qspec_map, output_qspec) + annotation_info = ArmAnnotationInfo( + quantized=True, + ) + meta_custom = node.meta.get("custom", {}) + meta_custom[ArmAnnotationInfo.CUSTOM_META_KEY] = dict(annotation_info) + node.meta["custom"] = meta_custom + + +class CortexMQuantizer(ComposableQuantizer): + + def broadcasting_filter(self, node: Optional[Node]) -> bool: + """ + Filter function to exclude nodes that perform broadcasting. + """ + if node is None: + return False + if [node.target] not in BINARY_OP_PATTERNS: + return False + + if len(node.all_input_nodes) == 2: + t1 = get_first_fake_tensor(node.all_input_nodes[0]) + t2 = get_first_fake_tensor(node.all_input_nodes[1]) + return t1.shape != t2.shape and not ( + is_channel_broadcast(t1, t2) and is_channels_last(t1) + ) + + return False + + def nchw_filter(self, node: Optional[Node]) -> bool: + """ + Filter function to exclude nodes that use NCHW memory format. + """ + if node is None: + return False + if [node.target] not in CONV_OP_PATTERNS: + return False + + tensor = get_first_fake_tensor(node) + if tensor is None: + return False + + return not is_channels_last(tensor) + + def __init__(self) -> None: + quantizers: List[Quantizer] = [ + OperatorConfigQuantizer( + INT8_BINARY_OPS_OPERATOR_CONFIG, filter_fn=self.broadcasting_filter + ), + OperatorConfigQuantizer(INT8_LINEAR_OPERATOR_CONFIG), + OperatorConfigQuantizer( + INT8_CONV_OPERATOR_CONFIG, filter_fn=self.nchw_filter + ), + InputQuantizer(INT8_PER_TENSOR_CONFIG), + OutputQuantizer(INT8_PER_TENSOR_CONFIG), + SharedQspecQuantizer(), + ] + super().__init__(quantizers) + + def validate(self, model: GraphModule) -> bool: + return True + + def transform_for_annotation(self, model: GraphModule) -> GraphModule: + pass_manager = CortexMPassManager(None) + return pass_manager.transform_for_annotation(model) + + +class OperatorConfigQuantizer(Quantizer): + """ + Quantizes a graph according to an OperatorConfig. + + Args: + operator_config (OperatorConfig): The operator config to use for quantization. + filter_fn (Callable): Negative filter function. If it returns True on any node in the pattern, the pattern is + skipped. 
Used to match for example particular targets or modules. + """ + + def __init__( + self, + operator_config: QuantizationConfig, + filter_fn: Callable[[Node], bool] = lambda node: False, + ) -> None: + self.operator_config = operator_config + self.filter_fn = filter_fn + + def check_node(self, node: Optional[Node], target: str) -> bool: + """ + Return true if the node is a valid match for the given target. + """ + if node is None: + return False + if not node.target == target: + return False + if node.meta.get("quantizer_matched", False): + return False + if self.filter_fn(node): + return False + + return True + + def check_pattern( + self, node: Optional[Node], pattern: List[OpOverload] + ) -> Optional[List[Node]]: + """ + Returns the matched nodes if the given node matches the given pattern, otherwise None. + """ + match: List[Node] = [] + + for pattern_target in pattern: + if self.check_node(node, pattern_target): + match.append(node) + node = list(node.users)[0] if len(node.users) > 0 else None + else: + return None + + return match + + def match_patterns( + self, model: GraphModule, patterns: List[List[str]] + ) -> List[List[Node]]: + """ + Match all given patterns in the graph and return list of matches. + Each node can only be part of one match, larger patterns are prioritized. + Currently only linear patterns (single chain) are supported. + """ + patterns.sort(key=len, reverse=True) + matches: List[List[Node]] = [] + for pattern in patterns: + for node in model.graph.nodes: + potential_match = self.check_pattern(node, pattern) + if potential_match: + matches.append(potential_match) + for node in potential_match: + node.meta["quantizer_matched"] = True + + return matches + + def is_parameter(self, node: Node, model: GraphModule) -> bool: + """Returns True if the given node is a parameter of the model.""" + try: + _ = model.get_parameter(node.target) + return True + except Exception: + return False + + def is_weight(self, node: Node, params: List[Node], model: GraphModule) -> bool: + """Returns True if node is the first parameter of the given parameters""" + return len(params) > 0 and node == params[0] + + def is_bias(self, node: Node, params: List[Node], model: GraphModule) -> bool: + """Returns True if node is the second parameter of the given parameters""" + return len(params) == 2 and node == params[1] + + def annotate_match( + self, match: List[Node], config: QuantizationConfig, model: GraphModule + ) -> None: + """ + Annotates a matched pattern according to the given quantization config. The + following assumptions are made: + + - All operators have either no parameters, only weights, or weights and biases + - Tensors which are the first parameter of an operator are annotated as weights + - Tensors which are the second parameter of an operator are annotated as biases + - All other tensors going into the matched pattern are annotated as input activations. + - All other outputs coming out of the matched pattern are annotated as output activations. + + """ + for node in match: + input_qspec_map = {} + output_qspec = None + + params = [n for n in node.all_input_nodes if self.is_parameter(n, model)] + # Check that the assumptions on number of parameters hold to avoid silent errors + assert ( + 0 <= len(params) <= 2 + ), f"{self.__class__.__name__} expected 0 params, 1 params (weight) or 2 params (weight, bias), but got {len(params)} for node {node}." 
+ + for input_node in node.all_input_nodes: + if self.is_weight(input_node, params, model): + input_qspec_map[input_node] = config.weight if config else None + elif self.is_bias(input_node, params, model): + # Bias qspec is derived from input + weight qspecs + input_qspec_map[input_node] = config.bias(node) if config else None + elif input_node not in match: + input_qspec_map[input_node] = ( + config.input_activation if config else None + ) + + if all(node not in match for node in node.users) and output_qspec is None: + output_qspec = config.output_activation if config else None + + mark_node_as_annotated(node, input_qspec_map, output_qspec) + + def annotate(self, model: GraphModule) -> None: + matches = self.match_patterns(model, self.operator_config.operators) + for match in matches: + self.annotate_match(match, self.operator_config.config, model) + + def validate(self, model: GraphModule) -> bool: + return True + + +class InputQuantizer(Quantizer): + """ + Quantizes only the input activations of the graph. + """ + + def __init__( + self, + quantization_config: QuantizationConfig, + filter_fn: Callable[[Node], bool] = lambda node: False, + ) -> None: + self.quantization_config = quantization_config + self.filter_fn = filter_fn + + def annotate(self, model: GraphModule) -> None: + for node in model.graph.nodes: + is_placeholder = node.op == "placeholder" + is_filtered = self.filter_fn(node) + if is_placeholder and not is_filtered: + mark_node_as_annotated( + node, {}, self.quantization_config.output_activation + ) + + def validate(self, model: GraphModule) -> bool: + return True + + +class OutputQuantizer(Quantizer): + """ + Quantizes only the output activations of the graph. + """ + + def __init__( + self, + quantization_config: QuantizationConfig, + filter_fn: Callable[[Node], bool] = lambda node: False, + ) -> None: + self.quantization_config = quantization_config + self.filter_fn = filter_fn + + def annotate(self, model: GraphModule) -> None: + output_node = model.graph.output_node() + input_qspec_map = { + n: self.quantization_config.input_activation + for n in output_node.all_input_nodes + if not self.filter_fn(n) + } + output_qspec = self.quantization_config.output_activation + mark_node_as_annotated(output_node, input_qspec_map, output_qspec) + + def validate(self, model: GraphModule) -> bool: + return True + + +class SharedQspecQuantizer(Quantizer): + """ + Special quantizer for assuring that given ops share the same quantization parameters on all input and outputs, + i.e. ops which does not change the scale such as clone, min/max, transposes and so on. + + Args: + targets (Optional[List[OpOverload]]): List of operator overloads to apply shared quantization spec to. + If None, a default list of supported ops is used. 
+ """ + + SHARED_QSPEC_OPS_DEFAULT: List[OpOverload] = [ + # Clone + torch.ops.aten.clone.default, + torch.ops.aten.lift_fresh_copy.default, + torch.ops.aten.detach_.default, + # Min/Max/Mean + torch.ops.aten.minimum.default, + torch.ops.aten.maximum.default, + # Data shuffling + torch.ops.aten.permute.default, + torch.ops.aten.permute_copy.default, + torch.ops.aten.transpose.Dimname, + torch.ops.aten.transpose.int, + torch.ops.aten.transpose_copy.int, + torch.ops.aten.t_copy.default, + torch.ops.aten.t.default, + # Change shape + torch.ops.aten.squeeze.default, + torch.ops.aten.squeeze_copy.default, + torch.ops.aten.squeeze_copy.dim, + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze.dims, + torch.ops.aten.unsqueeze.default, + torch.ops.aten.unsqueeze_copy.default, + torch.ops.aten.reshape.default, + torch.ops.aten.view.default, + torch.ops.aten.view_as.default, + torch.ops.aten.view_copy.default, + torch.ops.aten._unsafe_view.default, + torch.ops.aten.unflatten.int, + torch.ops.aten.flatten.using_ints, + ] + + def __init__(self, targets: Optional[List[OpOverload]] = None) -> None: + super().__init__() + if targets is None: + self.targets = self.SHARED_QSPEC_OPS_DEFAULT + else: + self.targets = targets + + def _is_annotated(self, node: Node) -> bool: + return Q_ANNOTATION_KEY in node.meta + + def _annotate_shared_cluster(self, root_node: Node) -> None: + """ + Finds a cluster of unannotated nodes starting in root_node and annotates them with a common + SharedQuantizationSpec. + """ + + shared_nodes = set() + leaf_nodes = set() + bfs_queue = [root_node] + + while bfs_queue: + node = bfs_queue.pop(0) + + if self._is_annotated(node): + leaf_nodes.add(node) + continue + if node.op == "get_attr": + continue + + if node.target not in self.targets: + raise NotImplementedError( + ( + f"{SharedQspecQuantizer.__name__} found unannoted node '{node.name}' in neighbour_nodes " + "which is not in the supported target list. This might be the case either because:\n" + "1) The op should have shared qspec but is not in the target list. " + "In this case, try modifying the list using the targets field in the initializer.\n" + "2) The op should not be quantized, which is not currently supported by the SharedQspecQuantizer." + ) + ) + + shared_nodes.add(node) + neighbour_nodes = list(node.all_input_nodes) + list(node.users) + for n in neighbour_nodes: + if n not in shared_nodes: + bfs_queue.append(n) + + # The selection of root node for the shared_qspec is important for + # torchao.quantization.pt2e.prepare._create_obs_or_fq_from_qspec: + # 1. For regular QuantizationSpecs, it creates a new observer + # 2. For SharedQuantizationSpecs, it returns the observer created for it's root node + # 3. It handles nodes in the order they appear in graph.nodes + # This means that the root node of the shared group needs to be the first annotated node that appears in graph.nodes. 
+ shared_root_node = next(n for n in root_node.graph.nodes if n in leaf_nodes) + shared_qspec = SharedQuantizationSpec(shared_root_node) + + for node in shared_nodes: + input_qspec_map: dict[Node, Optional[QuantizationSpec]] = { + n: shared_qspec for n in node.all_input_nodes + } + mark_node_as_annotated(node, input_qspec_map, shared_qspec) + + def annotate(self, model: GraphModule) -> None: + for node in model.graph.nodes: + if node.target in self.targets and not self._is_annotated(node): + self._annotate_shared_cluster(node) + + def validate(self, model: GraphModule) -> bool: + return True diff --git a/backends/cortex_m/test/TARGETS b/backends/cortex_m/test/TARGETS index b7a04f3efab..292a087a88a 100644 --- a/backends/cortex_m/test/TARGETS +++ b/backends/cortex_m/test/TARGETS @@ -8,13 +8,11 @@ load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest") load("targets.bzl", "define_common_targets") oncall("executorch") - python_unittest( name="test_replace_quant_nodes", srcs=[ "test_helpers_passes_utils.py", "test_replace_quant_nodes.py", - "test_quantize_op_fusion_pass.py", ], deps=[ "//pytorch/ao:torchao", # @manual diff --git a/backends/cortex_m/test/build_test_runner.sh b/backends/cortex_m/test/build_test_runner.sh new file mode 100755 index 00000000000..bf29b21d310 --- /dev/null +++ b/backends/cortex_m/test/build_test_runner.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# TODO: More separation from the regular arm executor runner and testing. + +set -eu + +# Always rebuild executorch in case the cortex-m kernels has been updated. +script_dir=$(realpath "$(dirname "${BASH_SOURCE[0]}")") +et_root_dir=$(realpath "${script_dir}/../../..") +build_executorch="${et_root_dir}/backends/arm/scripts/build_executorch.sh" +${build_executorch} + +# Build executor runner with selected aten ops and semi hosting +build_dir="${et_root_dir}/arm_test" +build_executor_runner="${et_root_dir}/backends/arm/scripts/build_executor_runner.sh" +build_root_test_dir="${et_root_dir}/arm_test/arm_semihosting_executor_runner_corstone-300" + +select_ops_list="\ +aten::add.out,\ +aten::clamp.out,\ +aten::convolution.out,\ +aten::div.out,\ +aten::mean.out,\ +aten::mul.out,\ +aten::relu.out,\ +aten::view_copy.out,\ +dim_order_ops::_to_dim_order_copy.out" + +${build_executor_runner} --pte=semihosting --target=ethos-u55-128 --output="${build_root_test_dir}" --select_ops_list="${select_ops_list}" diff --git a/backends/cortex_m/test/misc/test_quantization.py b/backends/cortex_m/test/misc/test_quantization.py new file mode 100644 index 00000000000..d4f84e4f075 --- /dev/null +++ b/backends/cortex_m/test/misc/test_quantization.py @@ -0,0 +1,359 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
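For context, a minimal sketch of how CortexMQuantizer might be exercised end to end, assuming it plugs into the standard torchao PT2E flow the way the existing Arm quantizers do; the model, import paths, and export entry point are illustrative and depend on the torch/torchao versions pinned by the repo.

import torch
from executorch.backends.cortex_m.quantizer.quantizer import CortexMQuantizer
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


class SmallModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 8)

    def forward(self, x):
        return torch.relu(self.linear(x))


example_inputs = (torch.randn(1, 16),)
exported = torch.export.export_for_training(SmallModel().eval(), example_inputs).module()

quantizer = CortexMQuantizer()
prepared = prepare_pt2e(exported, quantizer)        # annotate + insert observers
prepared(*example_inputs)                           # calibration
quantized = convert_pt2e(prepared)                  # materialize q/dq nodes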
+ + +import torch +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm.test.common import parametrize +from executorch.backends.cortex_m.test.tester import ( + CortexMTester, + McuTestCase, + ramp_tensor, +) +from executorch.exir.dialects._ops import ops as exir_ops + + +class SharedQspecMulipleClusters(torch.nn.Module): + """Three linear shared qspec clusters.""" + + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 2, + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 4, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 8, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 8, + } + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 2, + "executorch_exir_dialects_edge__ops_cortex_m_transpose_default": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 4, + } + + def forward(self, x): + x1 = torch.clone(x) + x2 = x1 + x1 + x3 = torch.clone(x2) + x3 = torch.clone(x3) + x3 = torch.clone(x3) + x4 = x3 + x3 + x5 = torch.transpose(x4, 2, 1) + return x5 + + +class SharedQspecInputForkNonShared(torch.nn.Module): + """Shared qspec cluster with an input fork with both inputs as non-shared qspecs.""" + + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_maximum_default": 1, + "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 4, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 4, + } + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_maximum_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2, + } + + def forward(self, x, y): + z = torch.maximum(x, y) + return torch.flatten(z) + + +class SharedQspecInputForkShared(torch.nn.Module): + """Shared qspec cluster with an input fork with both inputs as shared qspecs.""" + + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_minimum_default": 1, + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 5, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 5, + } + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_minimum_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_cortex_m_transpose_default": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1, + } + + def forward(self, x, y): + x = torch.clone(x) + y = torch.permute(y, (0, 1, 3, 2)) + z = torch.minimum(x, y) + return z + + +class SharedQspecInputForkXShared(torch.nn.Module): + """Shared 
qspec cluster with an input fork with left input as shared qspec."""
+
+    ops_before_transforms = {
+        "executorch_exir_dialects_edge__ops_aten_maximum_default": 1,
+        "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
+        "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 4,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 4,
+    }
+    ops_after_transforms = {
+        "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1,
+        "executorch_exir_dialects_edge__ops_cortex_m_maximum_default": 1,
+        "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2,
+        "executorch_exir_dialects_edge__ops_cortex_m_transpose_default": 1,
+        "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1,
+    }
+
+    def forward(self, x, y):
+        x = torch.t_copy(x)
+        z = torch.maximum(x, y)
+        return z
+
+
+class SharedQspecInputForkYShared(torch.nn.Module):
+    """Shared qspec cluster with an input fork with right input as shared qspec."""
+
+    ops_before_transforms = {
+        "executorch_exir_dialects_edge__ops_aten_minimum_default": 1,
+        "executorch_exir_dialects_edge__ops_aten_squeeze_copy_dims": 1,
+        "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 5,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 5,
+    }
+    ops_after_transforms = {
+        "executorch_exir_dialects_edge__ops_aten_squeeze_copy_dims": 1,
+        "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1,
+        "executorch_exir_dialects_edge__ops_cortex_m_minimum_default": 1,
+        "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2,
+        "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1,
+    }
+
+    def forward(self, x, y):
+        y = torch.clone(y)
+        z = torch.minimum(x, y)
+        return torch.squeeze(z)
+
+
+class SharedQspecInputForkXConstant(torch.nn.Module):
+    """Shared qspec cluster with an input fork with left input as global constant."""
+
+    ops_before_transforms = {}
+    ops_after_transforms = {}
+    constant = torch.tensor(5.0)
+
+    def forward(self, x):
+        return torch.minimum(self.constant, x)
+
+
+class SharedQspecInputForkYConstant(torch.nn.Module):
+    """Shared qspec cluster with an input fork with right input as local constant."""
+
+    ops_before_transforms = {}
+    ops_after_transforms = {}
+
+    def forward(self, x):
+        return torch.maximum(x, torch.tensor(5.0))
+
+
+class SharedQspecOutputForkNonShared(torch.nn.Module):
+    """Shared qspec cluster with an output fork with both outputs as non-shared qspecs."""
+
+    ops_before_transforms = {
+        "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1,
+        "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 1,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 4,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3,
+    }
+    ops_after_transforms = {
+        "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 1,
+        "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 2,
+        "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1,
+        "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1,
+    }
+
+    def forward(self, x):
+        x = torch.unsqueeze(x, 0)
+        y
= x + x + return x, y + + +class SharedQspecOutputForkShared(torch.nn.Module): + """Shared qspec cluster with an output fork with both outputs as shared qspecs.""" + + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 6, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 4, + } + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_transpose_default": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1, + } + + def forward(self, x): + x = torch.unsqueeze(x, 0) + y = torch.clone(x) + z = torch.permute_copy(x, (0, 2, 1, 3)) + return y, z, x + + +class SharedQspecManyForks(torch.nn.Module): + """Shared qspec cluster with a number of forks to testmore complex structures.""" + + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_maximum_default": 2, + "executorch_exir_dialects_edge__ops_aten_minimum_default": 1, + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 9, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 6, + } + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_maximum_default": 2, + "executorch_exir_dialects_edge__ops_cortex_m_minimum_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_transpose_default": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1, + } + + def forward(self, x): + x1 = torch.clone(x) + x2 = torch.maximum(x, x1) + x3 = torch.maximum(x, torch.t(x2)) + x4 = torch.minimum(x2, x3) + + return x4 + + +class SharedQspecSurroundedQuantizedOp(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, + "executorch_exir_dialects_edge__ops_aten_maximum_default": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 5, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 4, + } + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_maximum_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1, + } + + def forward(self, x): + x1 = torch.clone(x) + x2 = torch.add(x1, x1) + x3 = torch.maximum(x1, x2) + return x3 + + +class SharedQspecSurroundedQuantizedOpConstant(torch.nn.Module): + ops_before_transforms = {} + 
ops_after_transforms = {}
+
+    def forward(self, x):
+        x1 = torch.clone(x)
+        x2 = torch.add(x1, torch.ones(2, 2))
+        x3 = torch.maximum(x1, x2)
+        return x3
+
+
+class SharedQspecSub(torch.nn.Module):
+    ops_before_transforms = {}
+    ops_after_transforms = {}
+
+    def forward(self, x, y):
+        return torch.clone(x - y)
+
+
+test_cases = {
+    "multiple_clusters": McuTestCase(
+        SharedQspecMulipleClusters(),
+        (ramp_tensor(-2, 2, (2, 3, 4)),),
+    ),
+    "input_fork_non_shared": McuTestCase(
+        SharedQspecInputForkNonShared(),
+        (ramp_tensor(-2, 2, (2, 3, 4)), ramp_tensor(-1, 3, (2, 3, 4))),
+    ),
+    "input_fork_shared": McuTestCase(
+        SharedQspecInputForkShared(),
+        (ramp_tensor(-2, 2, (2, 3, 4, 5)), ramp_tensor(-1, 3, (2, 3, 5, 4))),
+    ),
+    "input_fork_x_shared": McuTestCase(
+        SharedQspecInputForkXShared(),
+        (ramp_tensor(-2, 2, (3, 4)), ramp_tensor(-1, 3, (4, 3))),
+    ),
+    "input_fork_y_shared": McuTestCase(
+        SharedQspecInputForkYShared(),
+        (ramp_tensor(-2, 2, (2, 3, 4)), ramp_tensor(-1, 3, (2, 3, 4))),
+    ),
+    "input_fork_x_constant": McuTestCase(
+        SharedQspecInputForkXConstant(),
+        (ramp_tensor(-2, 2, (2, 3, 4)),),
+    ),
+    "input_fork_y_constant": McuTestCase(
+        SharedQspecInputForkYConstant(),
+        (ramp_tensor(-2, 2, (2, 3, 4)),),
+    ),
+    "surrounded_quantized_op": McuTestCase(
+        SharedQspecSurroundedQuantizedOp(),
+        (ramp_tensor(-128, 2, (2, 3, 4)),),
+    ),
+    "surrounded_quantized_op_constant": McuTestCase(
+        SharedQspecSurroundedQuantizedOpConstant(),
+        (ramp_tensor(-2, 2, (2, 2)),),
+    ),
+    "output_fork_non_shared": McuTestCase(
+        SharedQspecOutputForkNonShared(),
+        (ramp_tensor(-2, 2, (2, 3, 4)),),
+    ),
+    "output_fork_shared": McuTestCase(
+        SharedQspecOutputForkShared(),
+        (ramp_tensor(-2, 2, (2, 3, 4)),),
+    ),
+    "many_forks": McuTestCase(
+        SharedQspecManyForks(),
+        (ramp_tensor(-20, 2, (4, 4)),),
+    ),
+    "non-quantized_op": McuTestCase(
+        SharedQspecSub(),
+        (ramp_tensor(0, 10, (5, 5)), ramp_tensor(0, 1, (5, 5))),
+    ),
+}
+
+xfails = {
+    "surrounded_quantized_op_constant": "Numerical error since the add is forced to have incorrect qparams.",
+    "non-quantized_op": "Non-quantized ops are not currently supported in SharedQspecQuantizer.",
+}
+
+
+@parametrize("test_case", test_cases, xfails=xfails)
+def test_shared_qspec_quantizer(test_case):
+    """
+    Test that ops which do not change dynamic range are able to use int8 portable kernels.
+    """
+    tester = CortexMTester(test_case.model, test_case.example_inputs)
+    tester.test_dialect(
+        test_case.model.ops_before_transforms, test_case.model.ops_after_transforms
+    )
+
+    # Check that all nodes in the graph are in int8
+    artifact = tester.get_artifact()
+    for node in artifact.exported_program().module().graph.nodes:
+        if node.op != "call_function":
+            continue
+        if node.target == exir_ops.edge.cortex_m.dequantize_per_tensor.default:
+            continue
+
+        assert get_first_fake_tensor(node).dtype == torch.int8, f"{node.name}"
diff --git a/backends/cortex_m/test/models/__init__.py b/backends/cortex_m/test/models/__init__.py
new file mode 100644
index 00000000000..c8d1c683da3
--- /dev/null
+++ b/backends/cortex_m/test/models/__init__.py
@@ -0,0 +1,4 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
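For reference, the `targets` argument defined on the quantizer (and exercised implicitly by the tests above via its default list) controls which data-movement ops may join a shared-qspec cluster. The sketch below is illustrative only; the import path of `SharedQspecQuantizer` is an assumption, not something stated in this diff.

```python
import torch

# Assumed import path; the quantizer class itself is added elsewhere in this PR.
from executorch.backends.cortex_m.quantizer import SharedQspecQuantizer

# Default behaviour: every op in SHARED_QSPEC_OPS_DEFAULT propagates a shared qspec.
default_quantizer = SharedQspecQuantizer()

# Restricted behaviour: only clone/view/transpose clusters are annotated; any other
# unannotated neighbour op raises NotImplementedError during annotate().
custom_quantizer = SharedQspecQuantizer(
    targets=[
        torch.ops.aten.clone.default,
        torch.ops.aten.view.default,
        torch.ops.aten.transpose.int,
    ]
)
```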
diff --git a/backends/cortex_m/test/models/test_mobilenet_v3.py b/backends/cortex_m/test/models/test_mobilenet_v3.py
new file mode 100644
index 00000000000..598d71ed212
--- /dev/null
+++ b/backends/cortex_m/test/models/test_mobilenet_v3.py
@@ -0,0 +1,74 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+import torch
+
+from executorch.backends.cortex_m.test.tester import CortexMTester, McuTestCase
+from torchvision import models
+
+
+# TODO: Update as more ops are converted to CMSIS-NN kernels.
+ops_before_transforms: dict[str, int] = {
+    "executorch_exir_dialects_edge__ops_aten_add_Tensor": 34,
+    "executorch_exir_dialects_edge__ops_aten_addmm_default": 2,
+    "executorch_exir_dialects_edge__ops_aten_clamp_default": 56,
+    "executorch_exir_dialects_edge__ops_aten_convolution_default": 52,
+    "executorch_exir_dialects_edge__ops_aten_div_Tensor": 28,
+    "executorch_exir_dialects_edge__ops_aten_mean_dim": 10,
+    "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 28,
+    "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 2,
+    "executorch_exir_dialects_edge__ops_aten_relu_default": 14,
+    "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
+    "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 56,
+    "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 178,
+    "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 109,
+}
+ops_after_transforms: dict[str, int] = {
+    "executorch_exir_dialects_edge__ops_aten_add_Tensor": 28,  # Not lowered due to broadcasting
+    "executorch_exir_dialects_edge__ops_aten_addmm_default": 0,
+    "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 6,
+    "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 2,
+    "executorch_exir_dialects_edge__ops_aten_clamp_default": 56,
+    "executorch_exir_dialects_edge__ops_aten_convolution_default": 52,
+    "executorch_exir_dialects_edge__ops_aten_div_Tensor": 28,
+    "executorch_exir_dialects_edge__ops_aten_mean_dim": 10,
+    "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 28,
+    "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 0,
+    "executorch_exir_dialects_edge__ops_aten_relu_default": 14,
+    "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
+    "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 56,
+    "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 0,
+    "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 0,
+    "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 162,
+    "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 101,
+}
+
+model = models.mobilenet_v3_small(weights=None)
+example_input = torch.randn(1, 3, 224, 224)
+
+
+test_cases = {
+    "mobilenet_v3_small": McuTestCase(
+        model=models.mobilenet_v3_small(weights=None),
+        example_inputs=(example_input,),
+    ),
+}
+
+
+@pytest.mark.skip("Skip until the add + linear fixes are upstreamed.")
+def test_dialect_mv3(test_case):
+    tester = CortexMTester(test_case.model, test_case.example_inputs)
+    tester.test_dialect(
+        ops_before_transforms,
+        ops_after_transforms,
+        qtol=1,
+    )
+
+
+@pytest.mark.skip("Skip until the add + linear fixes are upstreamed.")
+def test_implementation_mv3(test_case):
+    tester =
CortexMTester(test_case.model, test_case.example_inputs) + tester.test_implementation(qtol=1) diff --git a/backends/cortex_m/test/ops/__init__.py b/backends/cortex_m/test/ops/__init__.py new file mode 100644 index 00000000000..c8d1c683da3 --- /dev/null +++ b/backends/cortex_m/test/ops/__init__.py @@ -0,0 +1,4 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/cortex_m/test/ops/test_activation.py b/backends/cortex_m/test/ops/test_activation.py new file mode 100644 index 00000000000..407966521f3 --- /dev/null +++ b/backends/cortex_m/test/ops/test_activation.py @@ -0,0 +1,528 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm.test.common import parametrize +from executorch.backends.cortex_m.test.tester import ( + CortexMTester, + McuTestCase, + ramp_tensor, +) + + +class CortexMLinearReLU(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, + "executorch_exir_dialects_edge__ops_aten_relu_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, in_features=4, out_features=3): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=False) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(self.linear(x)) + + +class CortexMLinearHardtanh(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, min_val=-0.25, max_val=0.75): + super().__init__() + self.linear = torch.nn.Linear(4, 4, bias=False) + self.act = torch.nn.Hardtanh(min_val=min_val, max_val=max_val) + self.min_val = min_val + self.max_val = max_val + + def forward(self, x): + return self.act(self.linear(x)) + + +class CortexMLinearReLU6(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + 
"executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, in_features=8, out_features=8): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=False) + self.relu6 = torch.nn.ReLU6() + + def forward(self, x): + return self.relu6(self.linear(x)) + + +class CortexMLinearReLUInplace(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, + "executorch_exir_dialects_edge__ops_aten_relu_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, in_features=8, out_features=8): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=False) + self.relu = torch.nn.ReLU(inplace=True) + + def forward(self, x): + return self.relu(self.linear(x)) + + +class CortexMLinearHardtanhInplace(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, min_val=-1.0, max_val=1.0): + super().__init__() + self.linear = torch.nn.Linear(8, 8, bias=False) + self.act = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=True) + + def forward(self, x): + return self.act(self.linear(x)) + + +class CortexMLinearHardsigmoid(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardsigmoid_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, in_features=6, out_features=6): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=False) + self.act = torch.nn.Hardsigmoid() + + def forward(self, x): + return self.act(self.linear(x)) + + +class CortexMLinearHardswish(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, + "executorch_exir_dialects_edge__ops_aten_clamp_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardswish_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3, + 
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 4, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_minimum_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantized_mul_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, in_features=8, out_features=8): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=False) + self.act = torch.nn.Hardswish() + + def forward(self, x): + return self.act(self.linear(x)) + + +class CortexMConv2DReLU(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_relu_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 1, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(4, 8, 3, padding=1, bias=False) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(self.conv(x)) + + +class CortexMConv2DReLU6(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 1, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 3, stride=2, padding=1, bias=False) + self.relu6 = torch.nn.ReLU6() + + def forward(self, x): + return self.relu6(self.conv(x)) + + +class CortexMConv2DHardtanh(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, min_val=-2.0, max_val=2.0): + super().__init__() + self.conv = 
torch.nn.Conv2d(4, 8, 3, padding=1, bias=True) + self.act = torch.nn.Hardtanh(min_val=min_val, max_val=max_val) + + def forward(self, x): + return self.act(self.conv(x)) + + +class CortexMConv2DHardswish(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_clamp_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardswish_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 1, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_minimum_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantized_mul_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, in_channels=1, out_channels=1): + super().__init__() + self.conv = torch.nn.Conv2d(1, 1, 1, padding=0, bias=False) + self.act = torch.nn.Hardswish() + self.conv.weight.data.fill_(1) + + def forward(self, x): + return self.act(self.conv(x)) + + +class CortexMConv2DReLUInplace(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_relu_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 1, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(4, 8, 3, padding=1, bias=False) + self.relu = torch.nn.ReLU(inplace=True) + + def forward(self, x): + return self.relu(self.conv(x)) + + +class CortexMConv2DHardtanhInplace(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 1, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, min_val=-0.5, max_val=0.5): + super().__init__() + self.conv = torch.nn.Conv2d(4, 8, 3, padding=1, bias=False) + self.act = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=True) + torch.nn.init.ones_(self.conv.weight) + + def forward(self, x): + return self.act(self.conv(x)) + + +class CortexMConv2DHardsigmoid(torch.nn.Module): + ops_before_transforms = { + 
"executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardsigmoid_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 1, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(1, 1, 1, bias=False) + self.act = torch.nn.Hardsigmoid(inplace=True) + self.conv.weight.data.fill_(1) + + def forward(self, x): + return self.act(self.conv(x)) + + +class CortexMConv2DClampInplace(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_clamp_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 1, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(1, 1, 1, bias=False) + self.conv.weight.data.fill_(1) + + def forward(self, x): + return torch.clamp_(self.conv(x), min=0.0, max=None) + + +class CortexMLinearClamp(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, + "executorch_exir_dialects_edge__ops_aten_clamp_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, in_features=4, out_features=3): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=False) + + def forward(self, x): + return torch.clamp(self.linear(x), min=None, max=6.0) + + +test_cases = { + # Linear + activation tests with various data ranges + "linear_relu_small_range": McuTestCase( + model=CortexMLinearReLU(), + example_inputs=(ramp_tensor(-10, 10, (1, 4)),), + ), + "linear_relu_large_range": McuTestCase( + model=CortexMLinearReLU(in_features=16, out_features=16), + example_inputs=(ramp_tensor(-100, 100, (2, 16)),), + ), + "linear_relu_negative": McuTestCase( + model=CortexMLinearReLU(in_features=8, out_features=8), + example_inputs=(ramp_tensor(-50, 0, (1, 8)),), + ), + "linear_relu6": McuTestCase( + model=CortexMLinearReLU6(), + example_inputs=(ramp_tensor(-2, 10, (1, 8)),), + ), + "linear_relu_inplace": McuTestCase( + model=CortexMLinearReLUInplace(), + example_inputs=(ramp_tensor(-5, 5, (2, 8)),), + ), + 
"linear_hardtanh_symmetric": McuTestCase( + model=CortexMLinearHardtanh(min_val=-0.5, max_val=0.5), + example_inputs=(ramp_tensor(-1, 1, (2, 1, 4)),), + ), + "linear_hardtanh_asymmetric": McuTestCase( + model=CortexMLinearHardtanh(min_val=-1.5, max_val=0.25), + example_inputs=(ramp_tensor(-2, 1, (1, 4)),), + ), + "linear_hardtanh_large_range": McuTestCase( + model=CortexMLinearHardtanh(min_val=-10.0, max_val=10.0), + example_inputs=(ramp_tensor(-20, 20, (2, 4)),), + ), + "linear_hardtanh_inplace": McuTestCase( + model=CortexMLinearHardtanhInplace(min_val=-0.75, max_val=0.75), + example_inputs=(ramp_tensor(-2, 2, (1, 8)),), + ), + # Convolution + activation tests with various configurations + "conv2d_relu_small_kernel": McuTestCase( + model=CortexMConv2DReLU(), + example_inputs=( + ramp_tensor(-5, 5, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_relu_large_range": McuTestCase( + model=CortexMConv2DReLU(), + example_inputs=( + ramp_tensor(-50, 50, (2, 4, 16, 16)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_relu6_stride": McuTestCase( + model=CortexMConv2DReLU6(), + example_inputs=( + ramp_tensor(-10, 20, (1, 3, 12, 12)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_relu_inplace": McuTestCase( + model=CortexMConv2DReLUInplace(), + example_inputs=( + ramp_tensor(-3, 3, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_hardtanh_narrow": McuTestCase( + model=CortexMConv2DHardtanh(min_val=-0.5, max_val=0.5), + example_inputs=( + ramp_tensor(-2, 2, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_hardtanh_wide": McuTestCase( + model=CortexMConv2DHardtanh(min_val=-5.0, max_val=5.0), + example_inputs=( + ramp_tensor(-10, 10, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_hardtanh_inplace": McuTestCase( + model=CortexMConv2DHardtanhInplace(min_val=-10.0, max_val=10.0), + example_inputs=( + ramp_tensor(-15, 15, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "linear_hardsigmoid": McuTestCase( + model=CortexMLinearHardsigmoid(in_features=6, out_features=4), + example_inputs=(ramp_tensor(-8, 8, (2, 6)),), + ), + "linear_hardswish": McuTestCase( + model=CortexMLinearHardswish(in_features=12, out_features=6), + example_inputs=(ramp_tensor(-2, 0, (1, 12)),), + ), + "conv2d_hardsigmoid_inplace": McuTestCase( + model=CortexMConv2DHardsigmoid(), + example_inputs=( + ramp_tensor(-4, 4, (1, 1, 6, 6)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_hardswish": McuTestCase( + model=CortexMConv2DHardswish(in_channels=1, out_channels=1), + example_inputs=( + ramp_tensor(-3, 0, (1, 1, 1, 100)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_clamp_inplace": McuTestCase( + model=CortexMConv2DClampInplace(), + example_inputs=( + ramp_tensor(-4, 4, (1, 1, 1, 10)).to(memory_format=torch.channels_last), + ), + ), + "linear_clamp": McuTestCase( + model=CortexMLinearClamp(in_features=4, out_features=3), + example_inputs=(ramp_tensor(-10, 10, (1, 4)),), + ), +} + + +@parametrize("test_case", test_cases) +def test_dialect_activation(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_dialect( + test_case.model.ops_before_transforms, + test_case.model.ops_after_transforms, + qtol=1, + ) + + +@parametrize("test_case", test_cases) +def test_implementation_activation(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_implementation(qtol=1) diff --git 
a/backends/cortex_m/test/ops/test_add.py b/backends/cortex_m/test/ops/test_add.py new file mode 100644 index 00000000000..ad5f276b544 --- /dev/null +++ b/backends/cortex_m/test/ops/test_add.py @@ -0,0 +1,182 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from executorch.backends.arm.test.common import parametrize +from executorch.backends.cortex_m.test.tester import ( + CortexMTester, + McuTestCase, + ramp_tensor, +) +from executorch.backends.test.suite.operators.test_add import Model, ModelAlpha + + +class CortexMSelfAdd(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def forward(self, x): + return x + x + + +class CortexMScalarAdd(Model): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + +class CortexMTensorAdd(Model): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + +class CortexMAlphaAdd(ModelAlpha): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + +test_cases = { + "self_rank_1": McuTestCase( + CortexMSelfAdd(), + (torch.linspace(-5, 5, 10),), + ), + "self_rank_2_pos": McuTestCase( + CortexMSelfAdd(), + (ramp_tensor(0, 1000, (10, 1)),), + ), + "self_rank_3_neg": McuTestCase( + CortexMSelfAdd(), + (ramp_tensor(-100, 0, (2, 2, 2)),), + ), + "self_rank_4_small": McuTestCase( + CortexMSelfAdd(), + (ramp_tensor(-0.1, 0.1, (2, 2, 2, 2)),), + ), + "self_rank_5": McuTestCase( + CortexMSelfAdd(), + (ramp_tensor(-5, 5, (2, 2, 2, 2, 2)),), + ), + 
"tensor_scalar": McuTestCase( + CortexMScalarAdd(), + (torch.ones(1), 1.1), + ), + "scalar_tensor": McuTestCase( + CortexMScalarAdd(), + (1000.1, torch.ones(1)), + ), + "tensor_tensor": McuTestCase( + CortexMTensorAdd(), + (torch.rand(2, 2) * 10, torch.rand(2, 2)), + ), + "broadcast_1": McuTestCase( + CortexMTensorAdd(), + (torch.ones(1), torch.ones(2, 2, 2, 2)), + ), + "broadcast_2": McuTestCase( + CortexMTensorAdd(), + (torch.ones((2, 1, 1, 1)), torch.ones(1)), + ), + "broadcast_3": McuTestCase( + CortexMTensorAdd(), + ( + ramp_tensor(-2, 2, (2, 1, 2, 1)), + ramp_tensor(-5, 5, (1, 2, 1, 2)), + ), + ), + "broadcast_channels_1": McuTestCase( + CortexMTensorAdd(), + ( + ramp_tensor(-2, 2, (1, 8, 1, 1)).to(memory_format=torch.channels_last), + ramp_tensor(-5, 5, (1, 8, 5, 5)).to(memory_format=torch.channels_last), + ), + ), + "broadcast_channels_2": McuTestCase( + CortexMTensorAdd(), + ( + ramp_tensor(-5, 5, (2, 8, 5, 5)).to(memory_format=torch.channels_last), + ramp_tensor(-2, 2, (1, 8, 1, 1)).to(memory_format=torch.channels_last), + ), + ), + "broadcast_channels_continous": McuTestCase( + CortexMTensorAdd(), + ( + ramp_tensor(-5, 5, (2, 8, 5, 5)), + ramp_tensor(-2, 2, (1, 8, 1, 1)), + ), + ), + "alpha": McuTestCase( + CortexMAlphaAdd(0.5), + ( + ramp_tensor(-10, 10, (4, 5)), + ramp_tensor(-20, 20, (4, 5)), + ), + ), +} + + +xfails_implementation = { + "alpha": ( + "Expecting kwargs for aten op IR to be empty - alpha arg not supported.", + AssertionError, + ), +} +xfails_dialect = xfails_implementation | { + # Cortex-M quantizer will not quantize additions that require broadcasting + # leading to the add op not being replaced by a cortex-m specific implementation + "broadcast_1": "Broadcasting is not supported in Cortex-M backend", + "broadcast_2": "Broadcasting is not supported in Cortex-M backend", + "broadcast_3": "Broadcasting is not supported in Cortex-M backend", + "broadcast_channels_continous": "Broadcasting channels is not supported in continous memory_format in Cortex-M backend.", +} + + +@parametrize("test_case", test_cases, xfails=xfails_dialect) +def test_dialect_add(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_dialect( + test_case.model.ops_before_transforms, test_case.model.ops_after_transforms + ) + + +@parametrize("test_case", test_cases, xfails=xfails_implementation) +def test_implementation_add(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_implementation() diff --git a/backends/cortex_m/test/ops/test_conv.py b/backends/cortex_m/test/ops/test_conv.py new file mode 100644 index 00000000000..5630abbdab3 --- /dev/null +++ b/backends/cortex_m/test/ops/test_conv.py @@ -0,0 +1,210 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import torch +from executorch.backends.arm.test.common import parametrize +from executorch.backends.cortex_m.test.tester import ( + CortexMTester, + McuTestCase, + ramp_tensor, +) + + +class CortexMConv1D(torch.nn.Module): + ops_before_transforms = {} + ops_after_transforms = {} + + def __init__(self, *args, **kwargs): + super().__init__() + self.conv = torch.nn.Conv1d(*args, **kwargs, bias=False) + + def forward(self, x): + return self.conv(x) + + +class CortexMConv2D(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 1, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, *args, **kwargs): + super().__init__() + self.conv = torch.nn.Conv2d(*args, **kwargs, bias=False) + self.conv.weight.data.fill_(1.0) + + def forward(self, x): + return self.conv(x) + + +class CortexMConv2DBias(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, *args, **kwargs): + super().__init__() + self.conv = torch.nn.Conv2d(*args, **kwargs, bias=True) + + def forward(self, x): + + return self.conv(x) + + +class CortexMConv3D(torch.nn.Module): + ops_before_transforms = {} + + ops_after_transforms = {} + + def __init__(self, *args, **kwargs): + super().__init__() + self.conv = torch.nn.Conv3d(*args, **kwargs, bias=False) + self.conv.weight.data.fill_(2.0) + + def forward(self, x): + return self.conv(x) + + +class CortexMConv2Dx3(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 4, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 4, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 3, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 8, 3, padding=1, bias=False) + self.conv2 = torch.nn.Conv2d(8, 16, 3, padding=1, bias=False) + self.conv3 = torch.nn.Conv2d(16, 8, 3, padding=1, bias=False) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + 
return x + + +# in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode +test_cases = { + "conv2d": McuTestCase( + model=CortexMConv2D(2, 4, 3), + example_inputs=( + ramp_tensor(1, 5, (1, 2, 5, 5)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_stride": McuTestCase( + model=CortexMConv2D(3, 4, (1, 2), stride=2), + example_inputs=( + ramp_tensor(-100, 10, (3, 3, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_padding": McuTestCase( + model=CortexMConv2D(3, 2, 3, padding=(4, 1)), + example_inputs=( + ramp_tensor(0, 1, (2, 3, 5, 5)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_dilation": McuTestCase( + model=CortexMConv2D(1, 4, 3, dilation=(2, 2)), + example_inputs=( + ramp_tensor(0, 10, (3, 1, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_groups": McuTestCase( + model=CortexMConv2D(4, 4, 1, groups=2), + example_inputs=( + ramp_tensor(0, 10, (1, 4, 1, 1)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_bias_ch_out_1": McuTestCase( + model=CortexMConv2DBias(5, 1, 1), + example_inputs=( + ramp_tensor(0, 10, (2, 5, 3, 3)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_bias_ch_out_4": McuTestCase( + model=CortexMConv2DBias(5, 4, (1, 2)), + example_inputs=( + ramp_tensor(-3, 3, (2, 5, 10, 10)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_nchw": McuTestCase( + model=CortexMConv2D(5, 5, 1), + example_inputs=(ramp_tensor(0, 10, (1, 5, 8, 8)),), + ), + "conv1d": McuTestCase( + model=CortexMConv1D(1, 1, 1), + example_inputs=(ramp_tensor(0, 10, (1, 3, 2)),), + ), + "conv3d": McuTestCase( + model=CortexMConv3D(1, 1, 1), + example_inputs=( + ramp_tensor(-1000, 1000, (2, 1, 3, 3, 3)).to( + memory_format=torch.channels_last_3d + ), + ), + ), + "conv2d_x3": McuTestCase( + model=CortexMConv2Dx3(), + example_inputs=( + ramp_tensor(0, 10, (1, 3, 8, 8)).to(memory_format=torch.channels_last), + ), + ), +} + + +xfails_dialect = { + "conv2d_dilation": "NotImplementedError: 'slow_conv_dilated<>' not implemented for 'Int'", + "conv1d": "Currently not supported.", + "conv2d_nchw": "Currently not supported.", +} + + +@parametrize("test_case", test_cases, xfails=xfails_dialect) +def test_dialect_conv2d(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_dialect( + test_case.model.ops_before_transforms, + test_case.model.ops_after_transforms, + qtol=1, + ) + + +xfails_implementation = { + "conv1d": "Currently not supported.", + "conv3d": "Currently not supported.", +} + + +@parametrize("test_case", test_cases, xfails=xfails_implementation) +def test_implementation_conv2d(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_implementation(qtol=2) diff --git a/backends/cortex_m/test/ops/test_linear.py b/backends/cortex_m/test/ops/test_linear.py new file mode 100644 index 00000000000..e81daa7e83e --- /dev/null +++ b/backends/cortex_m/test/ops/test_linear.py @@ -0,0 +1,130 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
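Each test class in these files pins expected op counts before and after the Cortex-M passes. As a rough, assumed analogue of what `CortexMTester.test_dialect` compares against those dictionaries (the real tester may normalize op names differently), one can count the `call_function` targets in a graph module:

```python
from collections import Counter

import torch


def count_call_function_targets(graph_module: torch.fx.GraphModule) -> Counter:
    """Count call_function nodes per target. Keys here are op objects, whereas the
    dictionaries in these tests use mangled edge-dialect names such as
    "executorch_exir_dialects_edge__ops_aten_convolution_default"."""
    return Counter(
        node.target
        for node in graph_module.graph.nodes
        if node.op == "call_function"
    )
```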
+ + +import torch +from executorch.backends.arm.test.common import parametrize +from executorch.backends.cortex_m.test.tester import ( + CortexMTester, + McuTestCase, + ramp_tensor, +) + + +class CortexMLinear(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, *args, **kwargs): + super().__init__() + self.linear = torch.nn.Linear(*args, bias=False) + self.linear.weight.data.fill_(1.0) + + def forward(self, x): + return self.linear(x) + + +class CortexMLinearX3(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_linear_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 4, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 7, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 3, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, *args, **kwargs): + super().__init__() + self.linear = torch.nn.Linear(*args, bias=False) + self.linear.weight.data.fill_(1.0) + + def forward(self, x): + x = self.linear(x) + x = self.linear(x) + x = self.linear(x) + return x + + +class CortexMLinearBias(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 4, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, *args, **kwargs): + super().__init__() + self.linear = torch.nn.Linear(*args, bias=True) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.linear(x) + + +test_cases = { + "linear_rank1": McuTestCase( + model=CortexMLinear(1, 2), + example_inputs=(torch.Tensor([1]),), + ), + "linear_rank2_pos": McuTestCase( + model=CortexMLinear(1, 2), + example_inputs=(ramp_tensor(-1, 1, (1, 1)),), + ), + "linear_rank3_neg": McuTestCase( + model=CortexMLinear(5, 3), + example_inputs=(ramp_tensor(-40, 0, (4, 2, 5)),), + ), + "linear_rank4": McuTestCase( + model=CortexMLinear(16, 32), + example_inputs=(ramp_tensor(-100, 100, (2, 1, 2, 16)),), + ), + "linear_rank5": McuTestCase( + model=CortexMLinear(4, 3), + example_inputs=(ramp_tensor(-2, 2, (5, 2, 1, 2, 4)),), + ), + "linear_bias": McuTestCase( + model=CortexMLinearBias(61, 37), + example_inputs=(ramp_tensor(0, 10, (8, 61)),), + ), + "linear_x3": McuTestCase( + model=CortexMLinearX3(4, 4), + example_inputs=(ramp_tensor(0, 10, (2, 4)),), + ), +} + + +@parametrize("test_case", test_cases) +def test_dialect_linear(test_case): + tester 
= CortexMTester(test_case.model, test_case.example_inputs) + tester.test_dialect( + test_case.model.ops_before_transforms, + test_case.model.ops_after_transforms, + qtol=1, + ) + + +@parametrize("test_case", test_cases) +def test_implementation_linear(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_implementation(qtol=1) diff --git a/backends/cortex_m/test/ops/test_lstm.py b/backends/cortex_m/test/ops/test_lstm.py new file mode 100644 index 00000000000..60d7aba4271 --- /dev/null +++ b/backends/cortex_m/test/ops/test_lstm.py @@ -0,0 +1,98 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import pytest +import torch +from executorch.backends.cortex_m.test.tester import ( + CortexMTester, + McuTestCase, + ramp_tensor, +) + + +class CortexMLSTM(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_full_default": 2, + "executorch_exir_dialects_edge__ops_aten_squeeze_copy_dims": 4, + "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 2, + "executorch_exir_dialects_edge__ops_aten_view_copy_default": 6, + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 3, + "executorch_exir_dialects_edge__ops_aten_addmm_default": 3, + "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor": 2, + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 4, + "executorch_exir_dialects_edge__ops_aten_split_with_sizes_copy_default": 2, + "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 6, + "executorch_exir_dialects_edge__ops_aten_tanh_default": 4, + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 6, + "executorch_exir_dialects_edge__ops_aten_cat_default": 1, + } + + ops_after_transforms = {} + + def __init__(self, input_size: int = 4, hidden_size: int = 3) -> None: + super().__init__() + self.lstm = torch.nn.LSTM(input_size=input_size, hidden_size=hidden_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + y, _ = self.lstm(x) + return y + + +class CortexMQuantizableLSTM(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 4, + "executorch_exir_dialects_edge__ops_aten_addmm_default": 4, + "executorch_exir_dialects_edge__ops_aten_cat_default": 1, + "executorch_exir_dialects_edge__ops_aten_full_default": 1, + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 6, + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 4, + "executorch_exir_dialects_edge__ops_aten_select_copy_int": 2, + "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 6, + "executorch_exir_dialects_edge__ops_aten_split_with_sizes_copy_default": 2, + "executorch_exir_dialects_edge__ops_aten_squeeze_copy_dims": 1, + "executorch_exir_dialects_edge__ops_aten_tanh_default": 4, + "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 34, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 27, + } + + ops_after_transforms = {} + + def __init__(self, input_size: int = 4, hidden_size: int = 3) -> None: + super().__init__() + self.lstm = torch.ao.nn.quantizable.LSTM( + input_size=input_size, hidden_size=hidden_size + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + y, _ = self.lstm(x) + return y + + +test_cases = { + "lstm_fp32": McuTestCase( + model=CortexMLSTM(), + 
example_inputs=(ramp_tensor(-1, 1, (2, 1, 4)),), + ), + "lstm_quantizable": McuTestCase( + model=CortexMQuantizableLSTM(), + example_inputs=(ramp_tensor(-1, 1, (2, 1, 4)),), + ), +} + + +@pytest.mark.skip("Not implemented yet.") +def test_dialect_lstm(test_case: McuTestCase) -> None: + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_dialect( + test_case.model.ops_before_transforms, test_case.model.ops_after_transforms + ) + + +@pytest.mark.skip("Not implemented yet.") +def test_implementation_lstm(test_case: McuTestCase) -> None: + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_implementation() diff --git a/backends/cortex_m/test/ops/test_maximum.py b/backends/cortex_m/test/ops/test_maximum.py new file mode 100644 index 00000000000..58d477a9516 --- /dev/null +++ b/backends/cortex_m/test/ops/test_maximum.py @@ -0,0 +1,83 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from executorch.backends.arm.test.common import parametrize +from executorch.backends.cortex_m.test.tester import ( + CortexMTester, + McuTestCase, + ramp_tensor, +) + + +class CortexMTensorMaximum(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_maximum_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_maximum_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def forward(self, x, y): + return torch.maximum(x, y) + + +test_cases = { + "tensor_small": McuTestCase( + CortexMTensorMaximum(), + ( + torch.tensor([[1.0, -2.0], [3.5, -4.5]]), + torch.tensor([[0.5, -1.0], [4.0, -3.5]]), + ), + ), + "tensor_rand": McuTestCase( + CortexMTensorMaximum(), + ( + torch.rand(2, 2, 2) * 4 - 2, + torch.rand(2, 2, 2) * 4 - 2, + ), + ), + "broadcast": McuTestCase( + CortexMTensorMaximum(), + ( + ramp_tensor(-2, 2, (2, 1, 2)), + ramp_tensor(-3, 3, (1, 2, 1)), + ), + ), + "broadcast_rank4": McuTestCase( + CortexMTensorMaximum(), + ( + ramp_tensor(-4, 4, (1, 2, 3, 1)), + ramp_tensor(-6, 6, (4, 1, 1, 3)), + ), + ), + "broadcast_scalar": McuTestCase( + CortexMTensorMaximum(), + ( + torch.tensor(1.0), + ramp_tensor(-6, 6, (4, 1, 1, 3)), + ), + ), +} + + +@parametrize("test_case", test_cases) +def test_dialect_maximum(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_dialect( + test_case.model.ops_before_transforms, test_case.model.ops_after_transforms + ) + + +@parametrize("test_case", test_cases) +def test_implementation_maximum(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_implementation() diff --git a/backends/cortex_m/test/ops/test_minimum.py b/backends/cortex_m/test/ops/test_minimum.py new file mode 100644 index 00000000000..633ccdbf483 --- /dev/null +++ b/backends/cortex_m/test/ops/test_minimum.py @@ -0,0 +1,104 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
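For orientation, the ramp_tensor helper used to build the inputs in these test modules (added in backends/cortex_m/test/tester.py later in this diff) fills a tensor of the requested shape with evenly spaced values between two endpoints. Below is a minimal standalone sketch of the same idea; ramp_tensor_sketch is an illustrative name and not part of the change itself.

import torch


def ramp_tensor_sketch(start: float, end: float, shape: tuple) -> torch.Tensor:
    # One evenly spaced value per element, reshaped to the requested shape.
    numel = int(torch.prod(torch.tensor(shape)).item())
    return torch.linspace(start, end, steps=numel).reshape(shape)


# Example: ramp_tensor_sketch(-2, 2, (2, 1, 2)) gives
# tensor([[[-2.0000, -0.6667]],
#         [[ 0.6667,  2.0000]]])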
+ + +import torch +from executorch.backends.arm.test.common import parametrize +from executorch.backends.cortex_m.test.tester import ( + CortexMTester, + McuTestCase, + ramp_tensor, +) + + +class CortexMSelfMinimum(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_minimum_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_minimum_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def forward(self, x): + return torch.minimum(x, x) + + +class CortexMTensorMinimum(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_minimum_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_minimum_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def forward(self, x, y): + return torch.minimum(x, y) + + +test_cases = { + "self_rank_1": McuTestCase( + CortexMSelfMinimum(), + (ramp_tensor(-5, 5, (10,)),), + ), + "self_rank_3": McuTestCase( + CortexMSelfMinimum(), + (ramp_tensor(-10, 10, (2, 3, 4)),), + ), + "tensor_small": McuTestCase( + CortexMTensorMinimum(), + ( + torch.tensor([[1.0, -2.0], [3.5, -4.5]]), + torch.tensor([[0.5, -3.0], [3.0, -4.0]]), + ), + ), + "tensor_rand": McuTestCase( + CortexMTensorMinimum(), + ( + torch.rand(2, 2, 2) * 4 - 2, + torch.rand(2, 2, 2) * 4 - 2, + ), + ), + "broadcast": McuTestCase( + CortexMTensorMinimum(), + ( + ramp_tensor(-2, 2, (2, 1, 2)), + ramp_tensor(-3, 3, (1, 2, 1)), + ), + ), + "broadcast_rank4": McuTestCase( + CortexMTensorMinimum(), + ( + ramp_tensor(-4, 4, (1, 2, 3, 1)), + ramp_tensor(-6, 6, (4, 1, 1, 3)), + ), + ), +} + + +xfails = {} + + +@parametrize("test_case", test_cases, xfails=xfails) +def test_dialect_minimum(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_dialect( + test_case.model.ops_before_transforms, test_case.model.ops_after_transforms + ) + + +@parametrize("test_case", test_cases, xfails=xfails) +def test_implementation_minimum(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_implementation() diff --git a/backends/cortex_m/test/ops/test_mul.py b/backends/cortex_m/test/ops/test_mul.py new file mode 100644 index 00000000000..88dd904eb6e --- /dev/null +++ b/backends/cortex_m/test/ops/test_mul.py @@ -0,0 +1,156 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
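The ops_before_transforms / ops_after_transforms dictionaries in these modules encode the expected quantize/dequantize bookkeeping: broadly, before the Cortex-M passes each quantized operator input and the output carry their own quantize/dequantize nodes, and after the passes only the input quantize nodes, the fused cortex_m op, and a single output dequantize remain. As a rough illustration of how one case is driven end to end, here is a hedged sketch reusing the CortexMTester and ramp_tensor helpers added in backends/cortex_m/test/tester.py in this diff; SelfMul is an illustrative stand-in for the models defined below, not an exact copy of the parametrized tests.

import torch

from executorch.backends.cortex_m.test.tester import CortexMTester, ramp_tensor


class SelfMul(torch.nn.Module):
    # Mirrors the self-multiplication model defined below in this file.
    def forward(self, x):
        return x * x


tester = CortexMTester(SelfMul(), (ramp_tensor(-5, 5, (10,)),))
# Quantize, export, lower to edge, run the Cortex-M passes, convert to an
# ExecuTorch program, serialize, and compare outputs against the eager reference.
tester.test_implementation(qtol=1)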
+ + +import torch +from executorch.backends.arm.test.common import parametrize +from executorch.backends.cortex_m.test.tester import ( + CortexMTester, + McuTestCase, + ramp_tensor, +) +from executorch.backends.test.suite.operators.test_mul import Model + + +class CortexMSelfMul(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_mul_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def forward(self, x): + return x * x + + +class CortexMScalarMul(Model): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_mul_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + +class CortexMTensorMul(Model): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_mul_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + +test_cases = { + "self_rank_1": McuTestCase( + CortexMSelfMul(), + (ramp_tensor(-5, 5, (10,)),), + ), + "self_rank_2_pos": McuTestCase( + CortexMSelfMul(), + (ramp_tensor(0, 1000, (10, 1)),), + ), + "self_rank_3_neg": McuTestCase( + CortexMSelfMul(), + (ramp_tensor(-100, 0, (2, 2, 2)),), + ), + "self_rank_4_small": McuTestCase( + CortexMSelfMul(), + (ramp_tensor(-0.1, 0.1, (2, 2, 2, 2)),), + ), + "self_rank_5": McuTestCase( + CortexMSelfMul(), + (ramp_tensor(-5, 5, (2, 2, 2, 2, 2)),), + ), + "tensor_scalar": McuTestCase( + CortexMScalarMul(), + (torch.ones(1), 1.0), + ), + "scalar_tensor": McuTestCase( + CortexMScalarMul(), + (1000.0, torch.ones(1)), + ), + "broadcast_1": McuTestCase( + CortexMTensorMul(), + (torch.ones(1), torch.ones(2, 2, 2, 2)), + ), + "broadcast_2": McuTestCase( + CortexMTensorMul(), + (torch.ones((2, 1, 1, 1)), torch.ones(1)), + ), + "broadcast_3": McuTestCase( + CortexMTensorMul(), + ( + ramp_tensor(-2, 2, (2, 1, 2, 1)), + ramp_tensor(-5, 5, (1, 2, 1, 2)), + ), + ), + "broadcast_channels_1": McuTestCase( + CortexMTensorMul(), + ( + ramp_tensor(-2, 2, (1, 8, 1, 1)).to(memory_format=torch.channels_last), + ramp_tensor(-5, 5, (1, 8, 5, 5)).to(memory_format=torch.channels_last), + ), + ), + "broadcast_channels_2": McuTestCase( + CortexMTensorMul(), + ( + ramp_tensor(-5, 5, (2, 8, 5, 5)).to(memory_format=torch.channels_last), + ramp_tensor(-2, 2, (1, 8, 1, 1)).to(memory_format=torch.channels_last), + ), + ), + "broadcast_channels_continous": McuTestCase( + 
CortexMTensorMul(), + ( + ramp_tensor(-5, 5, (2, 8, 5, 5)), + ramp_tensor(-2, 2, (1, 8, 1, 1)), + ), + ), +} + + +xfail_cases_dialect = { + # Cortex-M quantizer will not quantize multiplications that require broadcasting, + # leading to the mul op not being replaced by a Cortex-M-specific implementation + "broadcast_1": "Broadcasting is not supported in Cortex-M backend", + "broadcast_2": "Broadcasting is not supported in Cortex-M backend", + "broadcast_3": "Broadcasting is not supported in Cortex-M backend", + "broadcast_channels_continous": "Broadcasting channels is not supported in contiguous memory_format in Cortex-M backend.", +} + + +@parametrize("test_case", test_cases, xfails=xfail_cases_dialect) +def test_dialect_mul(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_dialect( + test_case.model.ops_before_transforms, + test_case.model.ops_after_transforms, + qtol=1, + ) + + +@parametrize( + "test_case", + test_cases, +) +def test_implementation_mul(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_implementation(qtol=1) diff --git a/backends/cortex_m/test/ops/test_transpose.py b/backends/cortex_m/test/ops/test_transpose.py new file mode 100644 index 00000000000..de16c2f81ad --- /dev/null +++ b/backends/cortex_m/test/ops/test_transpose.py @@ -0,0 +1,102 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from executorch.backends.arm.test.common import parametrize +from executorch.backends.cortex_m.test.tester import ( + CortexMTester, + McuTestCase, + ramp_tensor, +) + +OPS_BEFORE_PASSES = { + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, +} + +OPS_AFTER_PASSES = { + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_transpose_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, +} + + +class CortexMPermute(torch.nn.Module): + ops_before_transforms = OPS_BEFORE_PASSES + ops_after_transforms = OPS_AFTER_PASSES + + def __init__(self, perms): + super().__init__() + self.perms = perms + + def forward(self, x): + return x.permute(self.perms) + + +class CortexMTranspose(torch.nn.Module): + ops_before_transforms = OPS_BEFORE_PASSES + ops_after_transforms = OPS_AFTER_PASSES + + def __init__(self, dim0, dim1): + super().__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + return x.transpose(self.dim0, self.dim1) + + +class CortexMT(torch.nn.Module): + ops_before_transforms = OPS_BEFORE_PASSES + ops_after_transforms = OPS_AFTER_PASSES + + def forward(self, x): + return x.t() + + +test_cases = { + "permute_nhwc_to_nchw": McuTestCase( + CortexMPermute((0, 3, 1, 2)), + (ramp_tensor(-0.5, 0.5, (2, 3, 4, 2)),), + ), + "permute_nchw_to_nhwc_neg_index": McuTestCase( + CortexMPermute((0, -2, -1, -3)), + (ramp_tensor(10, 100, (2, 3, 4, 2)),), + ), + "permute_rank_1": McuTestCase( + CortexMPermute((0,)), + (ramp_tensor(10, 100, (3)),), + ), + "transpose_1_2": McuTestCase( + CortexMTranspose(1, 2), + (ramp_tensor(-1.0, 1.0, (1, 3, 4)),), + ), + "transpose_0_1": McuTestCase( + CortexMTranspose(0, 1), + (ramp_tensor(-2.0, 2.0,
(2, 3, 4, 3)),), + ), + "t_operator": McuTestCase( + CortexMT(), + (ramp_tensor(-0.5, 0.5, (4, 2)),), + ), +} + + +@parametrize("test_case", test_cases) +def test_dialect_transpose(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_dialect( + test_case.model.ops_before_transforms, + test_case.model.ops_after_transforms, + qtol=1, + ) + + +@parametrize("test_case", test_cases) +def test_implementation_transpose(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_implementation(qtol=1) diff --git a/backends/cortex_m/test/targets.bzl b/backends/cortex_m/test/targets.bzl index 5a83be49c58..49cd4579ad6 100644 --- a/backends/cortex_m/test/targets.bzl +++ b/backends/cortex_m/test/targets.bzl @@ -23,7 +23,8 @@ def define_operator_test_target(op): "//executorch/runtime/kernel:kernel_includes", "//executorch/kernels/test:test_util", "//executorch/backends/cortex_m/ops:op_{}".format(op), - "//executorch/backends/cortex_m/ops:cortex_m_generated_lib", + "//executorch/backends/cortex_m/ops:op_quantize_per_tensor", + "//executorch/backends/cortex_m/ops:op_dequantize_per_tensor", "//executorch/backends/cortex_m/ops:cortex_m_generated_lib_headers", ] ) diff --git a/backends/cortex_m/test/test_quantize_op_fusion_pass.py b/backends/cortex_m/test/test_quantize_op_fusion_pass.py deleted file mode 100644 index 1595b0cfbc3..00000000000 --- a/backends/cortex_m/test/test_quantize_op_fusion_pass.py +++ /dev/null @@ -1,369 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import unittest - -import executorch -import executorch.backends.cortex_m.ops.operators # noqa - -import torch - -from executorch.backends.cortex_m.passes.quantized_op_fusion_pass import ( - QuantizedOpFusionPass, -) -from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import ( - ReplaceQuantNodesPass, -) -from executorch.backends.cortex_m.test.test_helpers_passes_utils import ( - AddQuantizer, - check_count, - get_node_args, -) -from executorch.exir.dialects._ops import ops as exir_ops -from torch.export import export -from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e - - -class TestQuantizedOpFusionPass(unittest.TestCase): - """ - Test suite for the QuantizedOpFusionPass which fuses dequantize->add->quantize patterns - into a single quantized_add operation with AoT-computed parameters. 
- """ - - def setUp(self): - """Set up common test fixtures""" - self.example_inputs = (torch.randn(4, 8), torch.randn(4, 8)) - - def _prepare_quantized_model(self, model_class): - """Helper to prepare a quantized model for testing""" - model = model_class() - - # Export and quantize - exported_model = export(model.eval(), self.example_inputs, strict=True).module() - prepared_model = prepare_pt2e(exported_model, AddQuantizer()) - quantized_model = convert_pt2e(prepared_model) - - # Export to EXIR Edge - exported = export(quantized_model, self.example_inputs, strict=True) - edge_program = executorch.exir.to_edge( - exported, - compile_config=executorch.exir.EdgeCompileConfig(_check_ir_validity=False), - ) - return edge_program - - def _apply_passes(self, edge_program): - """Apply both ReplaceQuantNodesPass and QuantizedOpFusionPass""" - passes = [QuantizedOpFusionPass(), ReplaceQuantNodesPass()] - final_program = edge_program.transform(passes) - return final_program - - def test_single_add_fusion(self): - """Single add with full Q/DQ pattern should fuse into one quantized_add node""" - - class SingleAddModel(torch.nn.Module): - def forward(self, x, y): - return x + y - - # Prepare model - edge_program = self._prepare_quantized_model(SingleAddModel) - edge_graph = edge_program.exported_program().graph_module - - # Get reference output - reference_output = edge_graph(*self.example_inputs) - - # Apply passes - transformed_program = self._apply_passes(edge_program) - transformed_graph = transformed_program.exported_program().graph_module - - # Verify fusion occurred - check_count( - transformed_graph, - exir_ops.edge.cortex_m.quantized_add.default, - 1, # Should have exactly 1 fused quantized_add - ) - - # Verify the following - # Before fusion: - # x --> quantize_per_tensor --> dequantize_per_tensor --> add --> quantize_per_tensor --> - # dequantize_per_tensor --> output y --> quantize_per_tensor --> dequantize_per_tensor --^ - # After fusion: - # x --> quantize_per_tensor --> quantized_add --> dequantize_per_tensor --> output - # y --> quantize_per_tensor --^ - check_count( - transformed_graph, exir_ops.edge.cortex_m.quantize_per_tensor.default, 2 - ) - check_count( - transformed_graph, exir_ops.edge.cortex_m.dequantize_per_tensor.default, 1 - ) - check_count(transformed_graph, exir_ops.edge.cortex_m.quantized_add.default, 1) - - # Verify numerical equivalence - fused_output = transformed_graph(*self.example_inputs) - torch.testing.assert_close(reference_output, fused_output, rtol=1e-3, atol=1e-3) - - def test_multiple_add_fusion(self): - """Multiple independent adds should create multiple quantized_add nodes""" - - class MultipleAddModel(torch.nn.Module): - def forward(self, x, y): - z1 = x + y # First add - z2 = x + z1 # Second add - return z2 - - # Prepare model - edge_program = self._prepare_quantized_model(MultipleAddModel) - edge_graph = edge_program.exported_program().graph_module - - # Get reference output - reference_output = edge_graph(*self.example_inputs) - - # Apply passes - transformed_program = self._apply_passes(edge_program) - transformed_graph = transformed_program.exported_program().graph_module - - # Verify multiple fusions occurred - check_count( - transformed_graph, - exir_ops.edge.cortex_m.quantized_add.default, - 2, # Should have 2 fused quantized_add nodes - ) - - # Verify numerical equivalence - fused_output = transformed_graph(*self.example_inputs) - torch.testing.assert_close(reference_output, fused_output, rtol=1e-3, atol=1e-3) - - def 
test_no_fusion_without_pattern(self): - """Add without proper Q/DQ pattern should not be fused""" - - class NonQuantizedAddModel(torch.nn.Module): - def forward(self, x, y): - # This will have add but not the full Q/DQ pattern after quantization - return torch.relu(x + y) # ReLU breaks the pattern - - # For this test, we'll create a model that doesn't have the complete pattern - # We need to manually construct a graph that has add without full Q/DQ - - model = NonQuantizedAddModel() - exported = export(model, self.example_inputs, strict=True) - edge_program = executorch.exir.to_edge( - exported, - compile_config=executorch.exir.EdgeCompileConfig(_check_ir_validity=False), - ) - # Apply passes - transformed_program = self._apply_passes(edge_program) - transformed_graph = transformed_program.exported_program().graph_module - - # Verify no fusion occurred - check_count( - transformed_graph, - exir_ops.edge.cortex_m.quantized_add.default, - 0, # Should have no fused quantized_add nodes - ) - - def test_precomputed_parameters(self): - """Fused node should have precomputed multipliers/shifts instead of scales""" - - class SingleAddModel(torch.nn.Module): - def forward(self, x, y): - return x + y - - # Prepare model - edge_program = self._prepare_quantized_model(SingleAddModel) - - # Apply passes - transformed_program = self._apply_passes(edge_program) - transformed_graph = transformed_program.exported_program().graph_module - - # Get arguments of the fused quantized_add node - quantized_add_args = get_node_args( - transformed_graph, exir_ops.edge.cortex_m.quantized_add.default - ) - - # Should have exactly one quantized_add node - self.assertEqual(len(quantized_add_args), 1) - args = quantized_add_args[0] - - # Verify argument structure: (tensor1, zp1, mult1, shift1, tensor2, zp2, mult2, shift2, out_zp, out_mult, out_shift) - self.assertEqual(len(args), 11, "quantized_add should have 11 arguments") - - # Check that multipliers and shifts are integers (not floats/scales) - # args[2], args[3] = input1 multiplier, shift - # args[6], args[7] = input2 multiplier, shift - # args[9], args[10] = output multiplier, shift - for i in [2, 3, 6, 7, 9, 10]: # multiplier and shift positions - self.assertIsInstance( - args[i], int, f"Argument {i} should be an integer (precomputed)" - ) - - def test_mixed_fusion_pattern(self): - """Mixed pattern (some fusable, some not) should partially fuse""" - - class MixedModel(torch.nn.Module): - def forward(self, x, y): - z1 = x + y # This should fuse - z2 = torch.relu(z1) # ReLU breaks next fusion - z3 = z2 + x # This won't have full Q/DQ pattern - return z3 - - # Prepare model - edge_program = self._prepare_quantized_model(MixedModel) - - # Apply passes - transformed_program = self._apply_passes(edge_program) - transformed_graph = transformed_program.exported_program().graph_module - - # Should have partial fusion (at least 1, but not necessarily all adds) - quantized_add_count = sum( - 1 - for node in transformed_graph.graph.nodes - if node.op == "call_function" - and node.target == exir_ops.edge.cortex_m.quantized_add.default - ) - - self.assertGreaterEqual( - quantized_add_count, 1, "Should have at least 1 fused operation" - ) - - def test_different_tensor_shapes(self): - """Different tensor shapes should still fuse correctly""" - - class SingleAddModel(torch.nn.Module): - def forward(self, x, y): - return x + y - - # Test with different input shapes - for shape in [(2, 3), (10, 20, 30), (1,)]: - with self.subTest(shape=shape): - inputs = (torch.randn(shape), 
torch.randn(shape)) - - model = SingleAddModel() - exported_model = export(model.eval(), inputs, strict=True).module() - prepared_model = prepare_pt2e(exported_model, AddQuantizer()) - quantized_model = convert_pt2e(prepared_model) - - exported = export(quantized_model, inputs, strict=True) - edge_program = executorch.exir.to_edge( - exported, - compile_config=executorch.exir.EdgeCompileConfig( - _check_ir_validity=False - ), - ) - - # Apply passes - transformed_program = self._apply_passes(edge_program) - transformed_graph = transformed_program.exported_program().graph_module - - # Verify fusion occurred regardless of shape - check_count( - transformed_graph, exir_ops.edge.cortex_m.quantized_add.default, 1 - ) - - def test_aot_parameter_computation_accuracy(self): - """Verify that AoT-computed parameters match runtime computation""" - - class SingleAddModel(torch.nn.Module): - def forward(self, x, y): - return x + y - - # Prepare model - edge_program = self._prepare_quantized_model(SingleAddModel) - - # Apply passes - transformed_program = self._apply_passes(edge_program) - transformed_graph = transformed_program.exported_program().graph_module - - # Get the fused node arguments - quantized_add_args = get_node_args( - transformed_graph, exir_ops.edge.cortex_m.quantized_add.default - )[0] - - # Extract the computed multipliers and shifts - input1_mult, input1_shift = quantized_add_args[2], quantized_add_args[3] - input2_mult, input2_shift = quantized_add_args[6], quantized_add_args[7] - output_mult, output_shift = quantized_add_args[9], quantized_add_args[10] - - # Verify they are reasonable values - # Multipliers should be in int32 range - self.assertTrue(-(2**31) <= input1_mult < 2**31) - self.assertTrue(-(2**31) <= input2_mult < 2**31) - self.assertTrue(-(2**31) <= output_mult < 2**31) - - # Shifts should be reasonable (typically -31 to 31) - self.assertTrue(-50 <= input1_shift <= 50) - self.assertTrue(-50 <= input2_shift <= 50) - self.assertTrue(-50 <= output_shift <= 50) - - # Output multiplier should be close to 2^30 (for 1.0 scale) - self.assertAlmostEqual(output_mult, 2**30, delta=1000) - self.assertEqual(output_shift, -1) - - def test_executorch_program_generation(self): - """Verify ExecuTorch program generation with fused ops""" - - class SingleAddModel(torch.nn.Module): - def forward(self, x, y): - return x + y - - # Prepare model - edge_program = self._prepare_quantized_model(SingleAddModel) - - # Apply passes - transformed_program = self._apply_passes(edge_program) - - # Generate ExecutorTorch program - executorch_program = transformed_program.to_executorch() - - # Verify the program contains the expected fused operator - operator_names = [ - op.name - for op in executorch_program.executorch_program.execution_plan[0].operators - ] - - self.assertIn("cortex_m::quantized_add", operator_names) - self.assertIn("cortex_m::quantize_per_tensor", operator_names) - self.assertIn("cortex_m::dequantize_per_tensor", operator_names) - # quantize_per_tensor --> dequantize_per_tensor --> add --> quantize_per_tensor --> dequantize_per_tensor - # (input quant) (dequant) (fp32 add) (re-quant) (dequant) - # ↓ - # Fusion Pass detects pattern: - # dequantize_per_tensor --> quantized_add (Fused node) --> quantize_per_tensor - - def test_broadcastable_shapes(self): - """Verify that broadcastable shapes are supported""" - - class BroadcastAddModel(torch.nn.Module): - def forward(self, x, y): - return x + y - - # input broadcastable shapes - inputs = (torch.randn(4, 1), torch.randn(4, 8)) - 
print(inputs) - - # Prepare quantized model - edge_program = self._prepare_quantized_model(BroadcastAddModel) - - # Get unfused output - unfused_graph = edge_program.exported_program().graph_module - unfused_output = unfused_graph(*inputs) - if isinstance(unfused_output, tuple): - unfused_output = unfused_output[0] - - # Apply fusion pass - fused_program = self._apply_passes(edge_program) - fused_graph = fused_program.exported_program().graph_module - fused_output = fused_graph(*inputs) - if isinstance(fused_output, tuple): - fused_output = fused_output[0] - - # Check fusion occurred - check_count(fused_graph, exir_ops.edge.cortex_m.quantized_add.default, 1) - - # Compare fused vs unfused (both quantized) - torch.testing.assert_close(fused_output, unfused_output, rtol=1e-3, atol=1e-3) - - -if __name__ == "__main__": - unittest.main() diff --git a/backends/cortex_m/test/tester.py b/backends/cortex_m/test/tester.py new file mode 100644 index 00000000000..ce5f16195c0 --- /dev/null +++ b/backends/cortex_m/test/tester.py @@ -0,0 +1,112 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import dataclass +from typing import Any + +import torch +from executorch.backends.arm.test.common import get_u55_compile_spec +from executorch.backends.arm.test.tester.arm_tester import Serialize +from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager + +from executorch.backends.cortex_m.quantizer.quantizer import CortexMQuantizer +from executorch.backends.test.harness import Tester as TesterBase +from executorch.backends.test.harness.stages import ( + Export, + Quantize, + RunPasses, + StageType, + ToEdge, + ToExecutorch, +) + +from executorch.exir import EdgeCompileConfig + + +class CortexMQuantize(Quantize): + def __init__(self): + quantizer = CortexMQuantizer() + super().__init__(quantizer) + + +class CortexMToEdge(ToEdge): + def __init__(self): + config = EdgeCompileConfig( + preserve_ops=[ + torch.ops.aten.linear.default, + torch.ops.aten.hardsigmoid.default, + torch.ops.aten.hardsigmoid_.default, + torch.ops.aten.hardswish.default, + torch.ops.aten.hardswish_.default, + ] + ) + super().__init__(config) + + +class CortexMRunPasses(RunPasses): + def __init__(self): + super().__init__( + CortexMPassManager, + CortexMPassManager.pass_list, + ) + + +class CortexMSerialize(Serialize): + def __init__(self): + compile_spec = get_u55_compile_spec() + super().__init__(compile_spec, 1024) + + +cortex_m_stage_classes = { + StageType.EXPORT: Export, + StageType.QUANTIZE: CortexMQuantize, + StageType.RUN_PASSES: CortexMRunPasses, + StageType.SERIALIZE: Serialize, + StageType.TO_EDGE: CortexMToEdge, + StageType.TO_EXECUTORCH: ToExecutorch, + StageType.SERIALIZE: CortexMSerialize, +} + + +class CortexMTester(TesterBase): + def __init__(self, module, example_inputs): + super().__init__(module, example_inputs, cortex_m_stage_classes) + + def test_dialect(self, ops_before_transforms, ops_after_transforms, qtol=0): + """ + Test the python dialect op implementation. 
+ """ + self.quantize() + self.export() + self.to_edge() + self.check_count(ops_before_transforms) + self.run_passes() + self.check_count(ops_after_transforms) + self.run_method_and_compare_outputs(inputs=self.example_inputs, qtol=qtol) + + def test_implementation(self, qtol=0): + """ + Test the optimized op implementation in simulation + """ + self.quantize() + self.export() + self.to_edge() + self.run_passes() + self.to_executorch() + self.serialize() + self.run_method_and_compare_outputs(inputs=self.example_inputs, qtol=qtol) + + +@dataclass +class McuTestCase: + model: torch.nn.Module + example_inputs: tuple[Any] + + +def ramp_tensor(start: int, end: int, shape: tuple[int]) -> torch.Tensor: + return torch.linspace(start, end, steps=torch.prod(torch.tensor(shape))).reshape( + shape + ) diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt new file mode 100644 index 00000000000..c85e07d4b59 --- /dev/null +++ b/backends/cuda/CMakeLists.txt @@ -0,0 +1,197 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Build AOTI CUDA backend for runtime. +# +# ### Editing this file ### +# +# This file should be formatted with +# ~~~ +# cmake-format -i CMakeLists.txt +# ~~~ +# It should also be cmake-lint clean. +# +cmake_minimum_required(VERSION 3.29) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CUDA_STANDARD_REQUIRED ON) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +# Source root directory for executorch. +if(NOT EXECUTORCH_ROOT) + set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..) +endif() + +# Use dynamic linking for CUDA runtime +set(CUDA_USE_STATIC_CUDA_RUNTIME OFF) + +find_package(CUDAToolkit REQUIRED) + +# Use ExecutorTorch's standard way to find PyTorch libraries for AOTI +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) +find_package_torch() + +# CUDA tensor maker for backends that support incontiguous tensors +set(_tensor_maker_sources runtime/tensor/tensor_maker.cpp) +add_library(cuda_tensor_maker STATIC ${_tensor_maker_sources}) +target_include_directories( + cuda_tensor_maker + PUBLIC $ $ + $ +) +target_compile_options( + cuda_tensor_maker + PUBLIC $<$:/EHsc /GR> + $<$>:-fexceptions -frtti -fPIC> +) +# Ensure symbols are exported properly +if(APPLE) + target_link_options(cuda_tensor_maker PUBLIC -Wl,-export_dynamic) +else() + target_link_options( + cuda_tensor_maker PUBLIC + $<$>:-Wl,--export-dynamic> + ) +endif() + +# Link against ExecuTorch core libraries +target_link_libraries( + cuda_tensor_maker PRIVATE executorch_core ${CMAKE_DL_LIBS} +) +executorch_target_link_options_shared_lib(cuda_tensor_maker) + +install( + TARGETS cuda_tensor_maker + EXPORT ExecuTorchTargets + DESTINATION lib +) + +# Platform utilities (load_library, close_library, etc.) 
+set(_cuda_platform_sources runtime/platform/platform.cpp) +add_library(cuda_platform STATIC ${_cuda_platform_sources}) + +target_include_directories( + cuda_platform + PUBLIC $ $ + $ +) + +target_compile_options( + cuda_platform + PUBLIC $<$:/EHsc /GR> + $<$>:-fexceptions -frtti -fPIC> +) + +# Link against ExecuTorch core libraries +target_link_libraries(cuda_platform PRIVATE executorch_core ${CMAKE_DL_LIBS}) + +install( + TARGETS cuda_platform + EXPORT ExecuTorchTargets + DESTINATION lib +) + +# CUDA-specific AOTI shim symbols (dynamically linked) +set(_aoti_cuda_shim_sources + runtime/shims/memory.cpp runtime/shims/tensor_attribute.cpp + runtime/guard.cpp runtime/shims/cuda_guard.cpp runtime/shims/int4mm.cu + ${EXECUTORCH_ROOT}/backends/aoti/common_shims.cpp +) + +add_library(aoti_cuda_shims SHARED ${_aoti_cuda_shim_sources}) + +# Define export macros for shared library +if(MSVC) + target_compile_definitions(aoti_cuda_shims PRIVATE EXPORT_AOTI_FUNCTIONS) + + # Ensure proper DLL import/export library naming on Windows + set_target_properties( + aoti_cuda_shims PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS OFF + ) +endif() + +target_include_directories( + aoti_cuda_shims + PUBLIC ${CUDAToolkit_INCLUDE_DIRS} $ + $ +) + +target_compile_options( + aoti_cuda_shims + PUBLIC $<$:/EHsc /GR> + $<$>:-fexceptions -frtti -fPIC> +) + +# Ensure symbols are exported properly +target_link_options( + aoti_cuda_shims PUBLIC $<$>:-Wl,--export-dynamic> +) + +# Link against CUDA::cudart, common AOTI library, cuda_tensor_maker, and +# platform utilities +target_link_libraries( + aoti_cuda_shims + PRIVATE cuda_platform + PUBLIC extension_tensor cuda_tensor_maker CUDA::cudart ${CMAKE_DL_LIBS} +) + +if(NOT MSVC) + executorch_target_link_options_shared_lib(aoti_cuda_shims) +endif() + +install( + TARGETS aoti_cuda_shims + EXPORT ExecuTorchTargets + DESTINATION lib +) + +# CUDA backend implementation +set(_aoti_cuda_backend_sources runtime/cuda_backend.cpp) + +# CUDA backend implementation +add_library(aoti_cuda_backend STATIC ${_aoti_cuda_backend_sources}) + +target_include_directories( + aoti_cuda_backend + PUBLIC ${CUDAToolkit_INCLUDE_DIRS} $ + $ +) +target_compile_options( + aoti_cuda_backend + PUBLIC $<$:/EHsc /GR> + $<$>:-fexceptions -frtti -fPIC> +) +# Ensure symbols are exported properly +target_link_options( + aoti_cuda_backend PUBLIC + $<$>:-Wl,--export-dynamic> +) + +# Link against shims library and other dependencies On Windows (MSVC), use +# PRIVATE linkage for aoti_cuda_shims since the DLL is copied to the executable +# directory. On other platforms, use PUBLIC so the dependency propagates to +# consumers. 
+target_link_libraries( + aoti_cuda_backend PUBLIC cuda_platform extension_tensor cuda_tensor_maker + CUDA::cudart ${CMAKE_DL_LIBS} +) + +if(MSVC) + target_link_libraries(aoti_cuda_backend PRIVATE aoti_cuda_shims) +else() + target_link_libraries(aoti_cuda_backend PUBLIC aoti_cuda_shims) +endif() + +executorch_target_link_options_shared_lib(aoti_cuda_backend) + +install( + TARGETS aoti_cuda_backend + EXPORT ExecuTorchTargets + DESTINATION lib +) diff --git a/backends/cuda/TARGETS b/backends/cuda/TARGETS new file mode 100644 index 00000000000..3ae4eec6680 --- /dev/null +++ b/backends/cuda/TARGETS @@ -0,0 +1,66 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +oncall("executorch") + +runtime.python_library( + name = "cuda_backend", + srcs = [ + "cuda_backend.py", + ], + visibility = [ + "//executorch/...", + ], + deps = [ + ":triton_replacement_pass", + "//caffe2:torch", + "//executorch/backends/aoti/passes:passes", + "//executorch/exir/_serialize:lib", + "//executorch/exir/backend:backend_details", + "//executorch/exir/backend:compile_spec_schema", + "//executorch/backends/aoti:aoti_backend", + ], +) + +runtime.python_library( + name = "cuda_partitioner", + srcs = [ + "cuda_partitioner.py", + ], + visibility = [ + "//executorch/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/backends/aoti:aoti_partitioner", + ], +) + +runtime.python_library( + name = "triton_kernels", + srcs = [ + "triton/kernels/__init__.py", + "triton/kernels/sdpa.py", + ], + visibility = [ + "//executorch/backends/cuda/...", + ], + deps = [ + "//caffe2:torch", + ], +) + +runtime.python_library( + name = "triton_replacement_pass", + srcs = [ + "triton/__init__.py", + "triton/replacement_pass.py", + ], + visibility = [ + "//executorch/...", + ], + deps = [ + ":triton_kernels", + "//caffe2:torch", + "//executorch/exir/dialects:lib", + ], +) diff --git a/backends/cuda/__init__.py b/backends/cuda/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/backends/cuda/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py new file mode 100644 index 00000000000..dbbd79f4881 --- /dev/null +++ b/backends/cuda/cuda_backend.py @@ -0,0 +1,244 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import shutil +import typing +from importlib import resources +from typing import Any, Dict, final, List, Optional + +import torch +from executorch.backends.aoti.aoti_backend import AotiBackend +from executorch.backends.cuda.triton.replacement_pass import ( + ReplaceEdgeOpWithTritonOpPass, +) +from executorch.exir._warnings import experimental +from executorch.exir.backend.backend_details import BackendDetails +from executorch.exir.backend.compile_spec_schema import CompileSpec +from torch._inductor.decomposition import conv1d_to_conv2d +from torch.nn.attention import SDPBackend + + +@final +@experimental( + "This API and all of cuda backend related functionality are experimental." +) +class CudaBackend(AotiBackend, BackendDetails): + """ + CudaBackend is a backend that compiles a model to run on CUDA devices. 
It uses the AOTInductor compiler to generate + optimized CUDA kernels for the model's operators with libtorch-free. The compiled model can be executed on CUDA devices + using the Executorch runtime. + """ + + @classmethod + def get_device_name(cls) -> str: + return "cuda" + + @staticmethod + def _find_ptxas_for_version(cuda_version: str) -> Optional[str]: # noqa: C901 + """ + Find ptxas binary that matches the expected CUDA version. + Returns the path to ptxas if found and version matches, None otherwise. + """ + expected_version_marker = f"/cuda-{cuda_version}/" + + def _validate_ptxas_version(path: str) -> bool: + """Check if ptxas at given path matches expected CUDA version.""" + if not os.path.exists(path): + return False + resolved = os.path.realpath(path) + return expected_version_marker in resolved + + # 1. Try PyTorch's CUDA_HOME + try: + from torch.utils.cpp_extension import CUDA_HOME + + if CUDA_HOME: + ptxas_path = os.path.join(CUDA_HOME, "bin", "ptxas") + if _validate_ptxas_version(ptxas_path): + return ptxas_path + except ImportError: + pass + + # 2. Try CUDA_HOME / CUDA_PATH environment variables + for env_var in ("CUDA_HOME", "CUDA_PATH", "CUDA_ROOT"): + cuda_home = os.environ.get(env_var) + if cuda_home: + ptxas_path = os.path.join(cuda_home, "bin", "ptxas") + if _validate_ptxas_version(ptxas_path): + return ptxas_path + + # 3. Try versioned path directly + versioned_path = f"/usr/local/cuda-{cuda_version}/bin/ptxas" + if os.path.exists(versioned_path): + return versioned_path + + # 4. Try system PATH via shutil.which + ptxas_in_path = shutil.which("ptxas") + if ptxas_in_path and _validate_ptxas_version(ptxas_in_path): + return ptxas_in_path + + # 5. Try default symlink path as last resort + default_path = "/usr/local/cuda/bin/ptxas" + if _validate_ptxas_version(default_path): + return default_path + + return None + + @staticmethod + def _setup_cuda_environment_for_fatbin() -> bool: + """ + Configure CUDA environment variables based on detected CUDA version and GPU architecture. + These are needed to compile fatbin kernels for more portable binaries on older CUDA versions. + Returns True if setup succeeded or if setup was skipped (CUDA >= 12.9), false otherwise. + """ + try: + # Detect CUDA version from torch + cuda_version = torch.version.cuda + if cuda_version is None: + return False + + major, minor = map(int, cuda_version.split(".")[:2]) + + # Only set up environment variables for CUDA < 12.9 + if major > 12 or (major == 12 and minor >= 9): + return True + + # Set TRITON_PTXAS_PATH for CUDA 12.6+ + if major == 12 and minor >= 6: + ptxas_path = CudaBackend._find_ptxas_for_version(cuda_version) + if ptxas_path is None: + return False + os.environ["TRITON_PTXAS_PATH"] = ptxas_path + + # Get compute capability of current CUDA device + device = torch.cuda.current_device() + capability = torch.cuda.get_device_capability(device) + os.environ["TORCH_CUDA_ARCH_LIST"] = f"{capability[0]}.{capability[1]}" + return True + except Exception: + return False + + @classmethod + def get_supported_fallback_kernels(cls) -> Dict[str, Any]: + return { + "at::_ops::_weight_int4pack_mm::call": None, + } + + @classmethod + def get_decomposition_table(cls) -> Dict[Any, Any]: + return { + torch.ops.aten.conv1d.default: conv1d_to_conv2d, + } + + @classmethod + def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]: + """ + Return CUDA-specific passes: ReplaceEdgeOpWithTritonOpPass. 
+ + The Triton kernel replacement behavior can be controlled via compile_specs: + - triton_kernel_mode="ON": Always use Triton kernels + - triton_kernel_mode="OFF": Never use Triton kernels and fallback to other implementations like cuda or decomposed operator. + """ + # Parse compile_specs for triton_kernel_mode + triton_kernel_mode = "ON" # Default mode + for spec in compile_specs: + if spec.key == "triton_kernel_mode": + mode = spec.value.decode("utf-8").upper() + if mode not in ["ON", "OFF"]: + raise ValueError( + f"Invalid triton_kernel_mode: {mode}. " + f"Expected 'ON' or 'OFF'." + ) + triton_kernel_mode = mode + + return [ReplaceEdgeOpWithTritonOpPass()] if triton_kernel_mode == "ON" else [] + + @classmethod + def get_aoti_compile_options( + cls, compile_specs: List[CompileSpec] + ) -> Dict[str, typing.Any]: + """ + Get AOTI compile options for CUDA backend. + Options may vary based on platform (Linux vs Windows). + """ + + # Configure CUDA environment variables based on detected version + emit_multi_arch_kernel = CudaBackend._setup_cuda_environment_for_fatbin() + # Base options for all platforms + options: Dict[str, typing.Any] = { + # Disable this to support sdpa decomposition + # TODO(gasoonjia): remove it after pin bump to latest pytorch + "loop_ordering_after_fusion": False, + # Better model precision + "emulate_precision_casts": True, + # Embed CUDA kernel binaries directly into the compiled shared object + "aot_inductor.embed_kernel_binary": True, + # Do not link against the full PyTorch/libtorch library + "aot_inductor.link_libtorch": False, + # Separate weight constants from the .so file + "aot_inductor.package": True, + "aot_inductor.package_constants_in_so": False, + # Store weight constants on disk in a binary blob + "aot_inductor.package_constants_on_disk_format": "binary_blob", + # Enable maximum automatic tuning for optimal performance + "max_autotune": True, + # Use TRITON for GEMM (General Matrix Multiply) operations tuning only to avoid using operators in libtorch + "max_autotune_gemm_backends": "TRITON", + # Use TRITON backend for convolution operations tuning only to avoid using operators in libtorch + "max_autotune_conv_backends": "TRITON", + "aot_inductor.emit_multi_arch_kernel": emit_multi_arch_kernel, + } + + # Parse compile_specs to check for platform + platform = "linux" + shim_library_path = None + for spec in compile_specs: + if spec.key == "platform": + platform = spec.value.decode("utf-8") + if spec.key == "shim_library_path": + shim_library_path = spec.value.decode("utf-8") + + # Add platform-specific options + if platform == "windows": + # For Windows, get default shim library path if not provided + if shim_library_path is None: + lib_dir = resources.files("executorch").joinpath("data/lib") + shim_library_path = str(lib_dir) + + options.update( + { + "aot_inductor.cross_target_platform": "windows", + "aot_inductor.aoti_shim_library": "aoti_cuda_shims", + "aot_inductor.aoti_shim_library_path": shim_library_path, + "aot_inductor.precompile_headers": False, + } + ) + else: + # Linux platform + assert ( + shim_library_path is None + ), "shim_library_path should not be set for Linux" + + return options + + @classmethod + def get_extra_aoti_compile_context_manager(cls): + """ + Return SDPA MATH backend context manager for CUDA compilation. + + This context manager plays as a fallback solution for any remaining PyTorch SDPA + operations to use the MATH backend (decomposed SDPA) during AOTInductor compilation. 
+ + Note: + - If SDPA ops are replaced with Triton kernels by ReplaceEdgeOpWithTritonOpPass, + this context manager will have no effect on those ops (they are no longer + PyTorch SDPA ops). + - If SDPA ops are NOT replaced (e.g., when triton_kernel_mode="OFF"), this + context manager will force them to use the MATH backend, causing them to + be automatically decomposed during compilation. + """ + return torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) diff --git a/backends/cuda/cuda_partitioner.py b/backends/cuda/cuda_partitioner.py new file mode 100644 index 00000000000..e8f1276d5eb --- /dev/null +++ b/backends/cuda/cuda_partitioner.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import final, List + +from executorch.backends.aoti.aoti_partitioner import AotiPartitioner +from executorch.backends.cuda.cuda_backend import CudaBackend # usort: skip +from executorch.exir._warnings import experimental +from executorch.exir.backend.compile_spec_schema import CompileSpec + + +@final +@experimental( + "This API and all of cuda backend related functionality are experimental." +) +class CudaPartitioner(AotiPartitioner): + """ + CUDA partitioner driven by AOTInductor backend. + """ + + def __init__(self, compile_spec: List[CompileSpec]) -> None: + super().__init__(CudaBackend.__name__, compile_spec) diff --git a/backends/cuda/runtime/TARGETS b/backends/cuda/runtime/TARGETS new file mode 100644 index 00000000000..a85f3a7e6a3 --- /dev/null +++ b/backends/cuda/runtime/TARGETS @@ -0,0 +1,110 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +load("//tools/build/buck:nvcc_flags.bzl", "get_nvcc_arch_args") + +oncall("executorch") + +runtime.cxx_library( + name = "cuda_platform", + srcs = [ + "platform/platform.cpp", + ], + headers = [ + "platform/platform.h", + ], + # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) + link_whole = True, + supports_python_dlopen = True, + visibility = ["@EXECUTORCH_CLIENTS"], + deps = [ + "//executorch/runtime/core:core", + ], + nvcc_flags = get_nvcc_arch_args() + [ + "-_NVCC_HOST_COMPILER_FLAG_", + "gcc", + ], + external_deps = [ + ("cuda", None, "cuda-lazy"), + ], +) + +runtime.cxx_library( + name = "tensor_maker", + srcs = [ + "tensor/tensor_maker.cpp", + ], + headers = [ + "tensor/tensor_maker.h", + ], + # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) + link_whole = True, + supports_python_dlopen = True, + visibility = ["@EXECUTORCH_CLIENTS"], + deps = [ + "//executorch/runtime/core:core", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/core/exec_aten/util:tensor_util", + ], +) + +runtime.cxx_library( + name = "runtime_shims", + srcs = [ + "guard.cpp", + "shims/cuda_guard.cpp", + "shims/int4mm.cu", + "shims/memory.cpp", + "shims/tensor_attribute.cpp", + ], + headers = [ + "guard.h", + "shims/cuda_guard.h", + "shims/int4mm.cuh", + "shims/int4mm.h", + "shims/memory.h", + "shims/tensor_attribute.h", + "utils.h", + ], + # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) + link_whole = True, + supports_python_dlopen = True, + # Constructor needed for backend registration. 
+ compiler_flags = ["-Wno-global-constructors"], + visibility = ["@EXECUTORCH_CLIENTS"], + deps = [ + ":tensor_maker", + "//executorch/backends/aoti:common_shims", + "//executorch/runtime/core:core", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/platform:platform", + "//executorch/backends/cuda/runtime:cuda_platform", + ], + nvcc_flags = get_nvcc_arch_args() + [ + "-_NVCC_HOST_COMPILER_FLAG_", + "gcc", + ], + external_deps = [ + ("cuda", None, "cuda-lazy"), + ], +) + +runtime.cxx_library( + name = "cuda_backend", + srcs = [ + "cuda_backend.cpp", + ], + # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) + link_whole = True, + supports_python_dlopen = True, + # Constructor needed for backend registration. + compiler_flags = ["-Wno-global-constructors"], + visibility = ["@EXECUTORCH_CLIENTS"], + deps = [ + ":runtime_shims", + "//executorch/backends/aoti:aoti_common", + "//executorch/runtime/backend:interface", + "//executorch/runtime/core/exec_aten/util:tensor_util", + ], + external_deps = [ + ("cuda", None, "cuda-lazy"), + ], +) diff --git a/backends/cuda/runtime/aoti_cuda_shims.lib b/backends/cuda/runtime/aoti_cuda_shims.lib new file mode 100644 index 00000000000..bd6cc53bf07 Binary files /dev/null and b/backends/cuda/runtime/aoti_cuda_shims.lib differ diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp new file mode 100644 index 00000000000..0cef859ddfb --- /dev/null +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -0,0 +1,379 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +// Include our shim layer headers +#include +#include +#include +#include +#include + +namespace executorch::backends::cuda { + +using namespace std; +using namespace aoti; + +using executorch::aten::ScalarType; +using executorch::runtime::ArrayRef; +using executorch::runtime::Backend; +using executorch::runtime::BackendExecutionContext; +using executorch::runtime::BackendInitContext; +using executorch::runtime::CompileSpec; +using executorch::runtime::DelegateHandle; +using executorch::runtime::Error; +using executorch::runtime::EValue; +using executorch::runtime::FreeableBuffer; +using executorch::runtime::MemoryAllocator; +using executorch::runtime::NamedDataMap; +using executorch::runtime::Result; +using executorch::runtime::Span; +using executorch::runtime::etensor::Tensor; + +class ET_EXPERIMENTAL CudaBackend final + : public ::executorch::runtime::BackendInterface { + private: + Error load_function_pointers_into_handle( + void* so_handle, + AOTIDelegateHandle* handle) const { +#define LOAD_SYMBOL(member, name) \ + do { \ + auto symbol_res = get_function(so_handle, #name); \ + if (!symbol_res.ok()) { \ + return symbol_res.error(); \ + } \ + handle->member = reinterpret_cast(symbol_res.get()); \ + } while (0) + + LOAD_SYMBOL(create_with_device, AOTInductorModelContainerCreateWithDevice); + + LOAD_SYMBOL(delete_container, AOTInductorModelContainerDelete); + + LOAD_SYMBOL(get_num_inputs, AOTInductorModelContainerGetNumInputs); + + LOAD_SYMBOL(get_num_outputs, AOTInductorModelContainerGetNumOutputs); + + LOAD_SYMBOL(run, AOTInductorModelContainerRun); +#undef LOAD_SYMBOL + + auto symbol_res = + get_function(so_handle, 
"AOTInductorModelUpdateConstantsFromBlob"); + if (symbol_res.ok()) { + handle->update_constants_from_blob = + reinterpret_cast( + symbol_res.get()); + } else { + ET_LOG( + Info, + "Failed to load AOTInductorModelUpdateConstantsFromBlob. This .so is probably compiled on an old version of torch (<2.9.0)"); + } + return Error::Ok; + } + + public: + bool is_available() const override { + return 1; + } + + // Once per loaded binary blob + Result init( + BackendInitContext& context, + FreeableBuffer* processed, // This will be a empty buffer + ArrayRef compile_specs // This will be my empty list + ) const override { + std::string method_name; + for (const CompileSpec& spec : compile_specs) { + if (std::strcmp(spec.key, "method_name") == 0) { + method_name.assign( + static_cast(spec.value.buffer), + spec.value.nbytes); // no nullptr guarantee, so pass size + break; + } + } + + std::string so_blob_key = + method_name.empty() ? "so_blob" : method_name + "_so_blob"; + + const NamedDataMap* named_data_map = context.get_named_data_map(); + auto aoti_dso_buffer = named_data_map->get_data(so_blob_key.c_str()); + ET_CHECK_OR_RETURN_ERROR( + aoti_dso_buffer.ok(), + Internal, + "Failed to get data for key %s: 0x%x", + so_blob_key.c_str(), + static_cast(aoti_dso_buffer.error())); + + // Generate dynamic temporary file path + filesystem::path temp_dir = filesystem::temp_directory_path(); + filesystem::path so_path = + temp_dir / (so_blob_key + to_string(get_process_id()) + ".so"); + + // Create a temporary file + ofstream outfile(so_path, ios::binary); + + // Write the ELF buffer to the temporary file + ET_LOG( + Info, + "Writing %zu bytes to %s", + aoti_dso_buffer->size(), + so_path.c_str()); + + outfile.write( + static_cast(aoti_dso_buffer->data()), + aoti_dso_buffer->size()); + + ET_CHECK_OR_RETURN_ERROR( + outfile, AccessFailed, "Failed to write to file %s", so_path.c_str()); + + // Finish writing the file to disk + outfile.close(); + + // Free the buffer immediately after writing to disk + aoti_dso_buffer->Free(); + // Load the lib + Result lib_handle_res = load_library(so_path); + if (!lib_handle_res.ok()) { + return lib_handle_res.error(); + } + void* lib_handle = lib_handle_res.get(); + + processed->Free(); + + // Create handle and load function pointers into it + AOTIDelegateHandle* handle = new AOTIDelegateHandle(); + handle->so_handle = lib_handle; + handle->so_path = so_path.string(); + + // Load function pointers specific to this handle's shared library + ET_CHECK_OK_OR_RETURN_ERROR( + load_function_pointers_into_handle(lib_handle, handle)); + + AOTInductorModelContainerHandle container_handle = nullptr; + + ET_CHECK_OK_OR_RETURN_ERROR( + handle->create_with_device(&container_handle, 1, "cuda", nullptr)); + + ET_LOG(Info, "container_handle = %p", container_handle); + + handle->container_handle = container_handle; + + // Look into named data map for constant data + std::string weights_blob_key = + method_name.empty() ? "weights_blob" : method_name + "_weights_blob"; + auto buffer_res = named_data_map->get_data(weights_blob_key.c_str()); + if (buffer_res.ok() && handle->update_constants_from_blob != nullptr) { + ET_LOG(Info, "Found %s in named data map", weights_blob_key.c_str()); + const void* weights_blob = buffer_res->data(); + // Feed the weights blob into the container. Under the hood it's copying + // weights, so we should free the buffer immediately. 
+ ET_CHECK_OK_OR_RETURN_ERROR(handle->update_constants_from_blob( + handle->container_handle, static_cast(weights_blob))); + buffer_res->Free(); + } + // Create a CUDA stream for asynchronous execution + cudaStream_t cuda_stream; + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&cuda_stream)); + handle->cuda_stream = static_cast(cuda_stream); + + return (DelegateHandle*)handle; // Return the handle post-processing + } + + // Once per execution + Error execute( + BackendExecutionContext& context, + DelegateHandle* handle_, + Span args) const override { + AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_; + + size_t n_inputs; + handle->get_num_inputs(handle->container_handle, &n_inputs); + + size_t n_outputs; + handle->get_num_outputs(handle->container_handle, &n_outputs); + + ET_CHECK_OR_RETURN_ERROR( + n_inputs + n_outputs == args.size(), + InvalidArgument, + "number of user input %zd and output %zd generated from AOT Inductor does not match ET runner's %zd. Exit.", + n_inputs, + n_outputs, + args.size()) + + // NOTE: ExecuTorch tensors are always on CPU/host memory + // We need to create GPU copies for CUDA kernel execution + std::vector gpu_inputs( + n_inputs); // GPU copies for kernel execution + std::vector gpu_outputs( + n_outputs); // GPU tensors for kernel output + + // Process input tensors: ExecuTorch provides CPU tensors, create GPU + // copies + for (int i = 0; i < n_inputs; i++) { + // Get tensor dimensions and properties from ExecuTorch CPU tensor + auto cpu_tensor = &(args[i]->toTensor()); + auto sizes = cpu_tensor->sizes(); + auto scalar_type = cpu_tensor->scalar_type(); + + // Create GPU tensor with same shape + std::vector sizes_vec(sizes.begin(), sizes.end()); + + AOTITensorHandle gpu_input_handle; + Error create_err = aoti_torch_empty_strided( + sizes_vec.size(), + sizes_vec.data(), + nullptr, // use default strides + static_cast(scalar_type), + 1, // device_type = cuda + 0, // device_index = 0 + &gpu_input_handle); + + ET_CHECK_OR_RETURN_ERROR( + create_err == Error::Ok, + Internal, + "Failed to create GPU tensor for input %d", + i); + + gpu_inputs[i] = gpu_input_handle; + + // Copy data from CPU to GPU + ET_CHECK_OR_RETURN_ERROR( + aoti_torch_copy_(gpu_inputs[i], cpu_tensor, 0) == Error::Ok, + Internal, + "Failed to copy input %d from CPU to GPU", + i); + } + // Process output tensors: create GPU counterparts for ExecuTorch CPU + // tensors + for (int i = 0; i < n_outputs; i++) { + // Get output tensor dimensions from ExecuTorch CPU tensor + auto cpu_output_tensor = &(args[i + n_inputs]->toTensor()); + auto sizes = cpu_output_tensor->sizes(); + auto scalar_type = cpu_output_tensor->scalar_type(); + + // Create GPU tensor with same shape for kernel output + std::vector sizes_vec(sizes.begin(), sizes.end()); + + AOTITensorHandle gpu_output_handle; + Error create_err = aoti_torch_empty_strided( + sizes_vec.size(), + sizes_vec.data(), + nullptr, // use default strides + static_cast(scalar_type), + 1, // device_type = cuda + 0, // device_index = 0 + &gpu_output_handle); + + ET_CHECK_OR_RETURN_ERROR( + create_err == Error::Ok, + Internal, + "Failed to create GPU tensor for output %d", + i); + + gpu_outputs[i] = gpu_output_handle; + } + // Run AOTI container with GPU tensors + AOTIRuntimeError error = handle->run( + handle->container_handle, + gpu_inputs.data(), // Use GPU input tensors + n_inputs, + gpu_outputs.data(), // Use GPU output tensors + n_outputs, + handle->cuda_stream, // Pass the actual CUDA stream + nullptr); // proxy_executor_handle can remain nullptr + 
+ ET_CHECK_OR_RETURN_ERROR( + error == Error::Ok, + Internal, + "AOTInductorModelContainerRun failed with error code %d", + error); + + // Copy GPU output results back to CPU output tensors + for (int i = 0; i < n_outputs; i++) { + auto cpu_output_tensor = &(args[i + n_inputs]->toTensor()); + // For DYNAMIC_BOUND tensors we try to resize + ET_CHECK_OK_OR_RETURN_ERROR( + resize_tensor(*cpu_output_tensor, gpu_outputs[i]->sizes()), + "Error resizing tensor at output index %d", + i); + ET_CHECK_OK_OR_RETURN_ERROR( + aoti_torch_copy_(cpu_output_tensor, gpu_outputs[i], 0), + "Failed to copy GPU output %d back to CPU", + i); + } + + return Error::Ok; + } + + void destroy(DelegateHandle* handle_) const override { + if (handle_ == nullptr) { + return; + } + AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_; + + // Destroy the CUDA stream if it exists + if (handle->cuda_stream != nullptr) { + cudaStream_t cuda_stream = static_cast(handle->cuda_stream); + cudaError_t stream_err = cudaStreamDestroy(cuda_stream); + ET_CHECK_OR_LOG_ERROR( + stream_err == cudaSuccess, + "Failed to destroy CUDA stream: %s", + cudaGetErrorString(stream_err)); + handle->cuda_stream = nullptr; + } + + // NOTE: AOTInductorModelContainerDelete does not work correctly with + // multiple .so files. Deleting one container frees shared resources, + // which causes segmentation faults when attempting to delete other + // containers. As a workaround, we skip explicit container deletion + // and defer cleanup to the OS. + // TODO(gasoonjia): Find a proper solution for safe container deletion. + // AOTInductorModelContainerDelete(handle->container_handle); + + // Now close the shared library + auto err = Error::Ok; + if (handle->so_handle != nullptr) { + err = close_library(handle->so_handle); + } + + // Remove the temporary shared library file + if (!handle->so_path.empty()) { + std::error_code remove_error; + std::filesystem::remove(handle->so_path, remove_error); + ET_CHECK_OR_LOG_ERROR( + !remove_error, + "Failed to remove temporary shared library %s: %s", + handle->so_path.c_str(), + remove_error.message().c_str()); + } + + delete handle; + clear_all_tensors(); + } +}; + +} // namespace executorch::backends::cuda + +namespace executorch::backends { +namespace { +auto cls = cuda::CudaBackend(); +executorch::runtime::Backend backend{"CudaBackend", &cls}; +static executorch::runtime::Error success_with_compiler = + register_backend(backend); +} // namespace +} // namespace executorch::backends diff --git a/backends/cuda/runtime/guard.cpp b/backends/cuda/runtime/guard.cpp new file mode 100644 index 00000000000..8de959b6c8a --- /dev/null +++ b/backends/cuda/runtime/guard.cpp @@ -0,0 +1,148 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include + +namespace executorch::backends::cuda { + +namespace { +// Thread-local stream storage (private to this file) +thread_local std::unordered_map current_streams_; +} // namespace + +Error setCurrentCUDAStream(cudaStream_t stream, DeviceIndex device_index) { + if (device_index == -1) { + // Get current device if not specified + int current_device; + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetDevice(¤t_device)); + device_index = current_device; + } + + current_streams_[device_index] = stream; + return Error::Ok; +} + +Result getCurrentCUDAStream(DeviceIndex device_index) { + if (device_index == -1) { + int current_device; + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetDevice(¤t_device)); + device_index = current_device; + } + + auto it = current_streams_.find(device_index); + if (it != current_streams_.end()) { + return it->second; + } + + cudaStream_t stream; + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&stream)); + setCurrentCUDAStream(stream, device_index); + return stream; +} + +CUDAGuard::CUDAGuard(CUDAGuard&& other) noexcept + : original_device_index_(other.original_device_index_), + current_device_index_(other.current_device_index_) { + // Mark the moved-from object as "already restored" so its destructor doesn't + // try to restore the device + other.original_device_index_ = other.current_device_index_; +} + +CUDAGuard::~CUDAGuard() { + if (original_device_index_ != current_device_index_) { + cudaError_t err = cudaSetDevice(original_device_index_); + if (err != cudaSuccess) { + ET_LOG( + Error, + "~CUDAGuard: Failed to restore device to %d: %s", + original_device_index_, + cudaGetErrorString(err)); + } + } +} + +Error CUDAGuard::set_index(DeviceIndex device_index) { + int orig_index = -1; + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetDevice(&orig_index)); + + original_device_index_ = orig_index; + current_device_index_ = device_index; + + if (current_device_index_ != original_device_index_) { + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaSetDevice(current_device_index_)); + } + + return Error::Ok; +} + +Result CUDAGuard::create(DeviceIndex device_index) { + CUDAGuard guard; // Fixed: Removed () to create a variable, not a function + ET_CHECK_OK_OR_RETURN_ERROR(guard.set_index(device_index)); + return guard; +} + +CUDAStreamGuard::CUDAStreamGuard(CUDAStreamGuard&& other) noexcept + : device_guard_(std::move(other.device_guard_)), + original_stream_(other.original_stream_), + current_stream_(other.current_stream_), + device_index_(other.device_index_) { + // Mark the moved-from object as "already restored" so its destructor doesn't + // try to restore the stream + other.original_stream_ = other.current_stream_; +} + +CUDAStreamGuard::~CUDAStreamGuard() { + // Restore the original stream unless this object was moved-from. + // After a move, original_stream_ == current_stream_, which indicates + // the moved-from object should not restore. + // Note: nullptr is a valid stream value (represents the default stream), + // so we must restore even if original_stream_ is nullptr. 
+ if (original_stream_ != current_stream_) { + Error err = setCurrentCUDAStream(original_stream_, device_index_); + if (err != Error::Ok) { + ET_LOG( + Error, + "~CUDAStreamGuard: Failed to restore stream for device %d", + device_index_); + } + } +} + +Error CUDAStreamGuard::set_stream( + cudaStream_t stream, + DeviceIndex device_index) { + auto result = getCurrentCUDAStream(device_index); + if (!result.ok()) { + ET_LOG(Error, "Failed to get current stream for device %d", device_index); + return result.error(); + } + + original_stream_ = result.get(); + current_stream_ = stream; + device_index_ = device_index; + + ET_CHECK_OK_OR_RETURN_ERROR(setCurrentCUDAStream(stream, device_index)); + + return Error::Ok; +} + +Result CUDAStreamGuard::create( + cudaStream_t stream, + DeviceIndex device_index) { + auto guard_result = CUDAGuard::create(device_index); + ET_CHECK_OK_OR_RETURN_ERROR(guard_result.error()); + + CUDAStreamGuard stream_guard(std::move(guard_result.get())); + ET_CHECK_OK_OR_RETURN_ERROR(stream_guard.set_stream(stream, device_index)); + + return stream_guard; +} + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/guard.h b/backends/cuda/runtime/guard.h new file mode 100644 index 00000000000..3f187000f90 --- /dev/null +++ b/backends/cuda/runtime/guard.h @@ -0,0 +1,191 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace executorch::backends::cuda { + +using executorch::runtime::Error; +using executorch::runtime::Result; + +// Type alias for device index +using DeviceIndex = int32_t; + +/** + * Set the current CUDA stream for the specified device. + * + * @param stream The CUDA stream to set as current + * @param device_index The device index (-1 to use current device) + * @return Error code indicating success or failure + */ +Error setCurrentCUDAStream(cudaStream_t stream, DeviceIndex device_index = -1); + +/** + * Get the current CUDA stream for the specified device. + * If no stream has been set, creates a new stream and sets it as current. + * + * @param device_index The device index (-1 to use current device) + * @return Result containing the current stream on success, or an error code on + * failure + */ +Result getCurrentCUDAStream(DeviceIndex device_index = -1); + +/** + * RAII guard that sets the current CUDA device and restores it on destruction. + * This ensures that the device is properly restored even if an exception + * occurs. + * + */ +class CUDAGuard { + private: + /** + * Private constructor - use create() factory method instead. + */ + explicit CUDAGuard() + : original_device_index_(-1), current_device_index_(-1) {} + + public: + /** + * Factory method to create a CUDAGuard. + * + * @param device_index The device index to set as current + * @return Result containing the guard on success, or an error code on failure + */ + static Result create(DeviceIndex device_index); + + // Copy is not allowed + CUDAGuard(const CUDAGuard&) = delete; + CUDAGuard& operator=(const CUDAGuard&) = delete; + + // Move constructor and assignment + CUDAGuard(CUDAGuard&& other) noexcept; + CUDAGuard& operator=(CUDAGuard&& other) = delete; + + /** + * Destructor that restores the original device if necessary. + */ + ~CUDAGuard(); + + /** + * Sets the CUDA device to the given device index. 
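As a usage note for the guard types implemented above (their header declarations continue below): both are created through fallible factory methods returning Result, so a caller checks the result and then relies on scope exit to restore the previous device and stream. The snippet is a hedged sketch of that calling pattern; the device index, the stream argument, the include path, and the work done inside the scope are illustrative assumptions.

#include <cuda_runtime.h>
#include <executorch/backends/cuda/runtime/guard.h>  // assumed include path for this new header
// ET_CHECK_OK_OR_RETURN_ERROR comes from the ExecuTorch runtime headers pulled in above (assumed).

using executorch::backends::cuda::CUDAStreamGuard;
using executorch::runtime::Error;

Error run_on_device_one(cudaStream_t my_stream) {
  // Switch to device 1 and make my_stream its current stream for this scope.
  auto stream_guard = CUDAStreamGuard::create(my_stream, /*device_index=*/1);
  ET_CHECK_OK_OR_RETURN_ERROR(stream_guard.error());

  // ... enqueue kernels here; getCurrentCUDAStream(1) now returns my_stream ...

  return Error::Ok;
  // On return, ~CUDAStreamGuard restores the previous stream and its embedded
  // CUDAGuard restores the original device, exactly as the destructors above do.
}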
+ * + * @param device_index The device index to set as current + * @return Error code indicating success or failure + */ + Error set_index(DeviceIndex device_index); + + /** + * Get the original device index before the guard was created. + * + * @return The original device index + */ + DeviceIndex original_device() const { + return original_device_index_; + } + + /** + * Get the current device index. + * + * @return The current device index + */ + DeviceIndex current_device() const { + return current_device_index_; + } + + private: + /// The original device before this guard was created + DeviceIndex original_device_index_; + /// The current device managed by this guard + DeviceIndex current_device_index_; +}; + +/** + * RAII guard that sets the current CUDA device and stream, restoring both on + * destruction. This is useful for temporarily switching to a different device + * and stream. + * + */ +class CUDAStreamGuard { + private: + // Private constructor that takes a CUDAGuard + explicit CUDAStreamGuard(CUDAGuard&& guard) + : device_guard_(std::move(guard)), + original_stream_(nullptr), + current_stream_(nullptr), + device_index_(-1) {} + + public: + /** + * Factory method to create a CUDAStreamGuard. + * + * @param stream The CUDA stream to set as current + * @param device_index The device index for the stream + * @return Result containing the guard on success, or an error code on failure + */ + static Result create( + cudaStream_t stream, + DeviceIndex device_index); + + // Copy is not allowed + CUDAStreamGuard(const CUDAStreamGuard&) = delete; + CUDAStreamGuard& operator=(const CUDAStreamGuard&) = delete; + + // Move constructor and assignment + CUDAStreamGuard(CUDAStreamGuard&& other) noexcept; + CUDAStreamGuard& operator=(CUDAStreamGuard&& other) noexcept = delete; + + /** + * Destructor that restores the original stream and device. + */ + ~CUDAStreamGuard(); + + /** + * Sets the CUDA stream to the given stream on the specified device. + * + * @param stream The CUDA stream to set as current + * @param device_index The device index for the stream + * @return Error code indicating success or failure + */ + Error set_stream(cudaStream_t stream, DeviceIndex device_index); + + /** + * Get the current guarded stream. + * + * @return The current stream + */ + cudaStream_t stream() const { + return current_stream_; + } + + /** + * Get the device index being guarded. + * + * @return The device index + */ + DeviceIndex device_index() const { + return device_index_; + } + + private: + /// The device guard that handles device switching + CUDAGuard device_guard_; + /// The original stream that was current before this guard + cudaStream_t original_stream_ = nullptr; + /// The current stream being guarded + cudaStream_t current_stream_ = nullptr; + /// The device index for this stream guard + DeviceIndex device_index_; +}; + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/memory_tracker.h b/backends/cuda/runtime/memory_tracker.h new file mode 100644 index 00000000000..e09a96da6a6 --- /dev/null +++ b/backends/cuda/runtime/memory_tracker.h @@ -0,0 +1,192 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include +#include + +#include + +namespace executorch::backends::cuda { + +/** + * @class CudaMemoryTracker + * @brief Tracks CUDA memory usage and logs memory state at key points + * + * This class provides utilities to query and track CUDA memory usage, + * including peak memory usage and detailed memory state logging. + */ +class CudaMemoryTracker { + public: + /** + * @brief Constructor - initializes tracker and logs startup memory state + */ + CudaMemoryTracker() { + if (!query(&last_free_bytes_, &total_bytes_)) { + return; + } + available_ = true; + // Record the initial free bytes observed at startup. We'll use this as a + // baseline so reported "peak usage" reflects additional memory used + // since the tracker was created (instead of the absolute device usage, + // which may include other processes). + initial_free_bytes_ = last_free_bytes_; + min_free_bytes_ = last_free_bytes_; + log_state("startup", last_free_bytes_, total_bytes_); + } + + /** + * @brief Logs current memory state at a tagged checkpoint + * @param tag Descriptive tag for this memory sample (e.g., "after_load") + */ + void log_sample(const char* tag) { + if (!available_) { + return; + } + size_t free_bytes = 0; + size_t total_bytes = 0; + if (!query(&free_bytes, &total_bytes)) { + return; + } + min_free_bytes_ = std::min(min_free_bytes_, free_bytes); + total_bytes_ = total_bytes; + last_free_bytes_ = free_bytes; + log_state(tag, free_bytes, total_bytes); + } + + /** + * @brief Destructor - logs final memory state and peak usage summary + */ + ~CudaMemoryTracker() { + if (!available_) { + return; + } + size_t free_bytes = 0; + size_t total_bytes = 0; + if (!query(&free_bytes, &total_bytes)) { + return; + } + min_free_bytes_ = std::min(min_free_bytes_, free_bytes); + total_bytes_ = total_bytes; + last_free_bytes_ = free_bytes; + // Compute peak usage relative to the initial free baseline so that + // allocations by other processes present at startup are not attributed + // to this process. If for some reason initial_free_bytes_ was not set, + // fall back to absolute device usage. 
+ double peak_mb = 0.0; + if (initial_free_bytes_ != std::numeric_limits::max()) { + size_t used_delta = 0; + if (initial_free_bytes_ > min_free_bytes_) { + used_delta = initial_free_bytes_ - min_free_bytes_; + } + peak_mb = static_cast(used_delta) / (1024.0 * 1024.0); + } else { + peak_mb = static_cast(total_bytes_ - min_free_bytes_) / + (1024.0 * 1024.0); + } + const double total_mb = + static_cast(total_bytes_) / (1024.0 * 1024.0); + ET_LOG( + Info, + "CUDA memory peak usage (since startup): %.2f MB, device total: %.2f MB", + peak_mb, + total_mb); + } + + private: + /** + * @brief Queries current CUDA memory info + * @param free_bytes Output parameter for free memory in bytes + * @param total_bytes Output parameter for total memory in bytes + * @return true if query succeeded, false otherwise + */ + bool query(size_t* free_bytes, size_t* total_bytes) { + cudaError_t err = cudaMemGetInfo(free_bytes, total_bytes); + if (err != cudaSuccess) { + if (!error_logged_) { + error_logged_ = true; + ET_LOG( + Error, + "cudaMemGetInfo failed with error: %s", + cudaGetErrorString(err)); + } + available_ = false; + return false; + } + return true; + } + + /** + * @brief Logs the current memory state + * @param tag Tag describing this log point + * @param free_bytes Current free memory in bytes + * @param total_bytes Current total memory in bytes + */ + void log_state(const char* tag, size_t free_bytes, size_t total_bytes) const { + const double used_mb = + static_cast(total_bytes - free_bytes) / (1024.0 * 1024.0); + const double free_mb = static_cast(free_bytes) / (1024.0 * 1024.0); + const double total_mb = + static_cast(total_bytes) / (1024.0 * 1024.0); + ET_LOG( + Info, + "CUDA memory (%s): used %.2f MB, free %.2f MB, total %.2f MB", + tag, + used_mb, + free_mb, + total_mb); + } + + bool available_{false}; + bool error_logged_{false}; + size_t last_free_bytes_{0}; + size_t total_bytes_{0}; + size_t min_free_bytes_{std::numeric_limits::max()}; + // Baseline free bytes observed at tracker construction. Used to compute + // peak usage attributable to this process since the tracker started. + size_t initial_free_bytes_{std::numeric_limits::max()}; + + public: + // Simple accessors to allow other components to read last-sampled values. + // These are safe to call after a successful log_sample() invocation. + uint64_t last_free_bytes() const { + return static_cast(last_free_bytes_); + } + uint64_t total_bytes() const { + return static_cast(total_bytes_); + } + uint64_t min_free_bytes() const { + return static_cast(min_free_bytes_); + } + uint64_t initial_free_bytes() const { + return static_cast(initial_free_bytes_); + } + double peak_usage_mb() const { + // Prefer peak relative to the initial free baseline; fall back to + // absolute device peak if baseline isn't available. 
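To make the baseline arithmetic above concrete: peak usage is reported as initial_free_bytes_ minus min_free_bytes_, i.e. only the memory consumed after the tracker was constructed, not whatever other processes already held. A hedged usage sketch follows; the tag strings, the surrounding load/run steps, and the example numbers are illustrative.

{
  CudaMemoryTracker tracker;          // logs "startup" and records the free-memory baseline
  // ... load the AOTI container and its weights ...
  tracker.log_sample("after_load");   // free memory has dropped; min_free_bytes_ is updated
  // ... run inference ...
  tracker.log_sample("after_run");
}                                     // ~CudaMemoryTracker logs the summary, e.g. with a
                                      // 20000 MB baseline and a 17500 MB minimum it reports
                                      // "CUDA memory peak usage (since startup): 2500.00 MB"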
+ if (min_free_bytes_ == std::numeric_limits::max()) { + return 0.0; + } + if (initial_free_bytes_ != std::numeric_limits::max()) { + size_t used_delta = 0; + if (initial_free_bytes_ > min_free_bytes_) { + used_delta = initial_free_bytes_ - min_free_bytes_; + } + return static_cast(used_delta) / (1024.0 * 1024.0); + } + if (total_bytes_ == 0) { + return 0.0; + } + return static_cast(total_bytes_ - min_free_bytes_) / + (1024.0 * 1024.0); + } +}; + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/platform/platform.cpp b/backends/cuda/runtime/platform/platform.cpp new file mode 100644 index 00000000000..5264dcbd03a --- /dev/null +++ b/backends/cuda/runtime/platform/platform.cpp @@ -0,0 +1,125 @@ + +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#include +#else // Posix +#include +#include +#include +#endif + +namespace executorch { +namespace backends { +namespace cuda { + +executorch::runtime::Result load_library( + const std::filesystem::path& path) { +#ifdef _WIN32 + std::string utf8 = path.u8string(); + auto lib_handle = LoadLibrary(utf8.c_str()); + if (lib_handle == NULL) { + ET_LOG( + Error, + "Failed to load %s with error: %lu", + utf8.c_str(), + GetLastError()); + return executorch::runtime::Error::AccessFailed; + } + +#else + std::string path_str = path.string(); + void* lib_handle = dlopen(path_str.c_str(), RTLD_LAZY | RTLD_LOCAL); + if (lib_handle == nullptr) { + ET_LOG( + Error, "Failed to load %s with error: %s", path_str.c_str(), dlerror()); + return executorch::runtime::Error::AccessFailed; + } +#endif + return (void*)lib_handle; +} + +executorch::runtime::Error close_library(void* lib_handle) { +#ifdef _WIN32 + if (!FreeLibrary((HMODULE)lib_handle)) { + printf("FreeLibrary failed with error %lu\n", GetLastError()); + return executorch::runtime::Error::Internal; + } +#else + if (dlclose(lib_handle) != 0) { + ET_LOG(Error, "dlclose failed: %s\n", dlerror()); + return executorch::runtime::Error::Internal; + } +#endif + return executorch::runtime::Error::Ok; +} + +executorch::runtime::Result get_function( + void* lib_handle, + const std::string& fn_name) { +#ifdef _WIN32 + auto fn = GetProcAddress((HMODULE)lib_handle, fn_name.c_str()); + if (!fn) { + ET_LOG( + Error, + "Failed loading symbol %s with error %lu\n", + fn_name.c_str(), + GetLastError()); + return executorch::runtime::Error::Internal; + } +#else + auto fn = dlsym(lib_handle, fn_name.c_str()); + if (fn == nullptr) { + ET_LOG( + Error, + "Failed loading symbol %s with error %s\n", + fn_name.c_str(), + dlerror()); + return executorch::runtime::Error::Internal; + } +#endif + + return (void*)fn; // This I think is technically ub on windows. We should + // probably explicitly pack the bytes. 
+} + +int32_t get_process_id() { +#ifdef _WIN32 + return GetCurrentProcessId(); +#else + return getpid(); +#endif +} + +void* aligned_alloc(size_t alignment, size_t size) { +#ifdef _WIN32 + return _aligned_malloc(size, alignment); +#else + return std::aligned_alloc(alignment, size); +#endif +} + +void aligned_free(void* ptr) { +#ifdef _WIN32 + _aligned_free(ptr); +#else + std::free(ptr); +#endif +} + +} // namespace cuda +} // namespace backends +} // namespace executorch diff --git a/backends/cuda/runtime/platform/platform.h b/backends/cuda/runtime/platform/platform.h new file mode 100644 index 00000000000..00f278ef85e --- /dev/null +++ b/backends/cuda/runtime/platform/platform.h @@ -0,0 +1,38 @@ + +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace cuda { + +executorch::runtime::Result load_library( + const std::filesystem::path& path); + +executorch::runtime::Error close_library(void* lib_handle); + +executorch::runtime::Result get_function( + void* lib_handle, + const std::string& fn_name); + +int32_t get_process_id(); + +void* aligned_alloc(size_t alignment, size_t size); + +void aligned_free(void* ptr); + +} // namespace cuda +} // namespace backends +} // namespace executorch diff --git a/backends/cuda/runtime/shims/cuda_guard.cpp b/backends/cuda/runtime/shims/cuda_guard.cpp new file mode 100644 index 00000000000..bb07acc7ffa --- /dev/null +++ b/backends/cuda/runtime/shims/cuda_guard.cpp @@ -0,0 +1,105 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
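The platform helpers above wrap dlopen/dlsym (or LoadLibrary/GetProcAddress on Windows) behind Result/Error so the backend can resolve AOTI entry points from the extracted shared library by name. A minimal sketch of that flow follows; the include path, the symbol name, and the function signature are illustrative assumptions rather than the exact set the backend resolves.

#include <filesystem>
#include <executorch/backends/cuda/runtime/platform/platform.h>  // assumed include path

using namespace executorch::backends::cuda;
using executorch::runtime::Error;

Error call_into_shared_object(const std::filesystem::path& so_path) {
  auto lib = load_library(so_path);
  ET_CHECK_OK_OR_RETURN_ERROR(lib.error());

  // Resolve a symbol by name and cast it to the signature the caller expects.
  auto sym = get_function(lib.get(), "AOTInductorModelContainerGetNumInputs");  // illustrative symbol
  ET_CHECK_OK_OR_RETURN_ERROR(sym.error());
  using GetNumInputsFn = int (*)(void* container, size_t* num_inputs);          // illustrative signature
  auto get_num_inputs = reinterpret_cast<GetNumInputsFn>(sym.get());
  (void)get_num_inputs;  // ... call through the pointer as needed ...

  return close_library(lib.get());
}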
+ */ + +#include + +namespace executorch::backends::cuda { + +extern "C" { + +AOTITorchError aoti_torch_create_cuda_guard( + int32_t device_index, + CUDAGuardHandle* ret_guard) { + ET_CHECK_OR_RETURN_ERROR( + ret_guard != nullptr, + InvalidArgument, + "aoti_torch_create_cuda_guard failed: ret_guard is null"); + + auto result = CUDAGuard::create(device_index); + if (!result.ok()) { + return result.error(); + } + *ret_guard = new CUDAGuard(std::move(result.get())); + return Error::Ok; +} + +AOTITorchError aoti_torch_delete_cuda_guard(CUDAGuardHandle guard) { + ET_CHECK_OR_RETURN_ERROR( + guard != nullptr, + InvalidArgument, + "aoti_torch_delete_cuda_guard failed: guard is null"); + + delete guard; + return Error::Ok; +} + +AOTITorchError aoti_torch_cuda_guard_set_index( + CUDAGuardHandle guard, + int32_t device_index) { + ET_CHECK_OR_RETURN_ERROR( + guard != nullptr, + InvalidArgument, + "aoti_torch_cuda_guard_set_index failed: guard is null"); + + ET_CHECK_OK_OR_RETURN_ERROR(guard->set_index(device_index)); + return Error::Ok; +} + +AOTITorchError aoti_torch_create_cuda_stream_guard( + void* stream, + int32_t device_index, + CUDAStreamGuardHandle* ret_guard) { + ET_CHECK_OR_RETURN_ERROR( + ret_guard != nullptr, + InvalidArgument, + "aoti_torch_create_cuda_stream_guard failed: ret_guard is null"); + + ET_CHECK_OR_RETURN_ERROR( + stream != nullptr, + InvalidArgument, + "aoti_torch_create_cuda_stream_guard failed: stream is null"); + + auto result = + CUDAStreamGuard::create(static_cast(stream), device_index); + if (!result.ok()) { + return result.error(); + } + *ret_guard = new CUDAStreamGuard(std::move(result.get())); + return Error::Ok; +} + +AOTITorchError aoti_torch_delete_cuda_stream_guard( + CUDAStreamGuardHandle guard) { + ET_CHECK_OR_RETURN_ERROR( + guard != nullptr, + InvalidArgument, + "aoti_torch_delete_cuda_stream_guard failed: guard is null"); + + delete guard; + return Error::Ok; +} + +AOTITorchError aoti_torch_get_current_cuda_stream( + int32_t device_index, + void** ret_stream) { + ET_CHECK_OR_RETURN_ERROR( + ret_stream != nullptr, + InvalidArgument, + "aoti_torch_get_current_cuda_stream failed: ret_stream is null"); + + auto result = getCurrentCUDAStream(device_index); + if (!result.ok()) { + return result.error(); + } + *ret_stream = static_cast(result.get()); + return Error::Ok; +} + +} // extern "C" + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/cuda_guard.h b/backends/cuda/runtime/shims/cuda_guard.h new file mode 100644 index 00000000000..83fceabb98f --- /dev/null +++ b/backends/cuda/runtime/shims/cuda_guard.h @@ -0,0 +1,100 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace executorch::backends::cuda { + +using executorch::backends::aoti::AOTITorchError; + +extern "C" { + +// Handle types for CUDA guards +using CUDAGuardHandle = CUDAGuard*; +using CUDAStreamGuardHandle = CUDAStreamGuard*; + +/** + * Creates a CUDA device guard that sets the current device and restores it + * upon destruction. 
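These extern "C" wrappers exist so that AOTInductor-generated code, which only speaks a flat C shim ABI, can drive the RAII guards defined earlier. A hedged sketch of the call sequence such generated code would emit; the device index and the abbreviated error handling are illustrative.

// Sketch of the shim call sequence (error handling abbreviated).
void* stream = nullptr;
AOTITorchError err = aoti_torch_get_current_cuda_stream(/*device_index=*/0, &stream);
// stream now holds the cudaStream_t that guard.cpp tracks for device 0.

CUDAStreamGuardHandle guard = nullptr;
err = aoti_torch_create_cuda_stream_guard(stream, /*device_index=*/0, &guard);
// ... generated kernels launch on stream while the guard is alive ...
err = aoti_torch_delete_cuda_stream_guard(guard);  // restores the prior device and stream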
+ * + * @param device_index The device index to set as current + * @param ret_guard Output parameter for the created guard handle (must not be + * null) + * @return AOTITorchError error code (Error::Ok on success, or an error code on + * failure) + */ +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_create_cuda_guard(int32_t device_index, CUDAGuardHandle* ret_guard); + +/** + * Deletes a CUDA device guard and frees its associated resources. + * + * @param guard Handle to the guard to be deleted + * @return AOTITorchError error code (Error::Ok on success, or an error code on + * failure) + */ +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_delete_cuda_guard(CUDAGuardHandle guard); + +/** + * Sets the CUDA device to a new index for an existing guard. + * + * @param guard Handle to the guard + * @param device_index The device index to set as current + * @return AOTITorchError error code (Error::Ok on success, or an error code on + * failure) + */ +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_cuda_guard_set_index(CUDAGuardHandle guard, int32_t device_index); + +/** + * Creates a CUDA stream guard that sets the current device and stream, + * restoring both upon destruction. + * + * @param stream The CUDA stream to set as current + * @param device_index The device index for the stream + * @param ret_guard Output parameter for the created guard handle (must not be + * null) + * @return AOTITorchError error code (Error::Ok on success, or an error code on + * failure) + */ +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_cuda_stream_guard( + void* stream, + int32_t device_index, + CUDAStreamGuardHandle* ret_guard); + +/** + * Deletes a CUDA stream guard and frees its associated resources. + * + * @param guard Handle to the stream guard to be deleted + * @return AOTITorchError error code (Error::Ok on success, or an error code on + * failure) + */ +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_delete_cuda_stream_guard(CUDAStreamGuardHandle guard); + +/** + * Gets the current CUDA stream for a specified device. + * + * @param device_index The device index (-1 to use current device) + * @param ret_stream Output parameter for the current stream (must not be null) + * @return AOTITorchError error code (Error::Ok on success, or an error code on + * failure) + */ +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream); + +} // extern "C" + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/int4mm.cu b/backends/cuda/runtime/shims/int4mm.cu new file mode 100644 index 00000000000..c1896f4eec0 --- /dev/null +++ b/backends/cuda/runtime/shims/int4mm.cu @@ -0,0 +1,59 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include + +#include +#include +#include +#include + +namespace executorch::backends::cuda { +#ifdef __cplusplus +extern "C" { +#endif + +AOTITorchError aoti_torch_cuda__weight_int4pack_mm( + Tensor* self, + Tensor* mat2, + int64_t qGroupSize, + Tensor* qScaleAndZeros, + Tensor** ret0) { + // Validate input parameters first + // Only check for null pointers here, as the actual validation of tensor + // properties is done in _weight_int4pack_mm_cuda + ET_CHECK_OR_RETURN_ERROR( + self != nullptr, + InvalidArgument, + "aoti_torch_cuda__weight_int4pack_mm failed: self tensor is null"); + + ET_CHECK_OR_RETURN_ERROR( + mat2 != nullptr, + InvalidArgument, + "aoti_torch_cuda__weight_int4pack_mm failed: mat2 tensor is null"); + + ET_CHECK_OR_RETURN_ERROR( + qScaleAndZeros != nullptr, + InvalidArgument, + "aoti_torch_cuda__weight_int4pack_mm failed: qScaleAndZeros tensor is null"); + + ET_CHECK_OR_RETURN_ERROR( + ret0 != nullptr, + InvalidArgument, + "aoti_torch_cuda__weight_int4pack_mm failed: ret0 is null"); + + *ret0 = _weight_int4pack_mm_cuda(*self, *mat2, qGroupSize, *qScaleAndZeros); + ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR(); + return Error::Ok; +} + +#ifdef __cplusplus +} +#endif +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/int4mm.cuh b/backends/cuda/runtime/shims/int4mm.cuh new file mode 100644 index 00000000000..ee12fb51004 --- /dev/null +++ b/backends/cuda/runtime/shims/int4mm.cuh @@ -0,0 +1,1334 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + + + // This file is a port of PyTorch's int4mm.cu kernel implementation +// (aten/src/ATen/native/cuda/int4mm.cu) adapted for the ExecuTorch runtime. +// +// In the future, we should consider making the PyTorch code generic enough +// to be reusable in executorch. +// +// PORTING NOTES: +// -------------- +// 1. KERNEL CODE (lines 36-1067): Identical to PyTorch - preserved 100% +// - All utility templates, vector types, and conversion logic unchanged +// - Tensor core kernels (tinygemm_m16n8k16_chunk_kernel) byte-for-byte identical +// - Same inline PTX assembly for mma.sync.aligned instructions +// - Identical performance characteristics and register allocation +// +// 2. API ADAPTATIONS: +// - Replaced at::Tensor with executorch::backends::aoti::Tensor +// - Output returned via pointer-to-pointer instead of by-value +// +// 3. REMOVED FEATURES: +// - _convert_weight_to_int4pack_cuda(): Weight conversion happens offline +// during model export via optimum-executorch. Runtime only consumes +// pre-packed weights. +// - isCDNA2orLater() runtime check: Removed dependency on ATen GPU detection +// hooks. ROCm support relies on compile-time guards only. +// +// 4. 
INFRASTRUCTURE CHANGES: +// - Removed c10::cuda::CUDAGuard: Device management handled by AOTI backend +// - Removed at::cuda::getCurrentCUDAStream(): Stream passed explicitly + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))) +#include +#include +#include +#if !defined(USE_ROCM) +#include +#endif +#endif + +namespace executorch::backends::cuda { +using executorch::backends::aoti::Tensor; + +template +constexpr __host__ __device__ auto divDown(U a, V b) -> decltype(a + b) { + static_assert(std::is_integral_v && std::is_integral_v, ""); + return (a / b); +} + +template +constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) { + static_assert(std::is_integral_v && std::is_integral_v, ""); + // Overflow safe variant of (a + b - 1) / b + const uint64_t blocks = a / b + (a % b != 0); + return blocks; +} + +template +constexpr __host__ __device__ auto roundDown(U a, V b) -> decltype(a + b) { + static_assert(std::is_integral_v && std::is_integral_v, ""); + return divDown(a, b) * b; +} + +template +constexpr __host__ __device__ auto roundUp(U a, V b) -> decltype(a + b) { + static_assert(std::is_integral_v && std::is_integral_v, ""); + return divUp(a, b) * b; +} + +template +constexpr __host__ __device__ bool isEvenDivisor(U a, V b) { + static_assert(std::is_integral_v && std::is_integral_v, ""); + return (a % V(b) == 0) && ((a / V(b)) >= 1); +} + +template +constexpr __host__ __device__ T pow(T n, int power) { + return (power > 0 ? n * pow(n, power - 1) : 1); +} + +template +constexpr __host__ __device__ T pow2(int power) { + return pow(2, power); +} + +static_assert(pow2(8) == 256, "pow2"); + +template +constexpr __host__ __device__ int log2(T n, int p = 0) { + return (n <= 1) ? p : log2(n / 2, p + 1); +} + +static_assert(log2(2) == 1, "log2"); +static_assert(log2(3) == 1, "log2"); +static_assert(log2(4) == 2, "log2"); + +template +constexpr __host__ __device__ bool isPowerOf2(T v) { + static_assert(std::is_integral_v, ""); + return (v && !(v & (v - 1))); +} + +static_assert(isPowerOf2(2048), "isPowerOf2"); +static_assert(!isPowerOf2(3333), "isPowerOf2"); + +template +constexpr __host__ __device__ T nextHighestPowerOf2(T v) { + static_assert(std::is_integral_v, ""); + return (isPowerOf2(v) ? (T)2 * v : ((T)1 << (log2(v) + 1))); +} + +static_assert(nextHighestPowerOf2(1) == 2, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(2) == 4, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(3) == 4, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(4) == 8, "nextHighestPowerOf2"); + +static_assert(nextHighestPowerOf2(15) == 16, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(16) == 32, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(17) == 32, "nextHighestPowerOf2"); + +static_assert( + nextHighestPowerOf2(1536000000u) == 2147483648u, + "nextHighestPowerOf2"); +static_assert( + nextHighestPowerOf2((size_t)2147483648ULL) == (size_t)4294967296ULL, + "nextHighestPowerOf2"); + +template +constexpr __host__ __device__ T nextLowestPowerOf2(T v) { + static_assert(std::is_integral_v, ""); + return (isPowerOf2(v) ? 
v / (T)2 : ((T)1 << (log2(v)))); +} + +static_assert(nextLowestPowerOf2(1) == 0, "nextLowestPowerOf2"); +static_assert(nextLowestPowerOf2(2) == 1, "nextLowestPowerOf2"); +static_assert(nextLowestPowerOf2(3) == 2, "nextLowestPowerOf2"); +static_assert(nextLowestPowerOf2(4) == 2, "nextLowestPowerOf2"); + +static_assert(nextLowestPowerOf2(15) == 8, "nextLowestPowerOf2"); +static_assert(nextLowestPowerOf2(16) == 8, "nextLowestPowerOf2"); +static_assert(nextLowestPowerOf2(17) == 16, "nextLowestPowerOf2"); + +inline __host__ __device__ bool isPointerAligned(const void* p, int align) { + return reinterpret_cast(p) % align == 0; +} + +// Returns the increment needed to aligned the pointer to the next highest +// aligned address +template +inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) { + static_assert(isPowerOf2(Align), ""); + const uint32_t diff = uint32_t(uintptr_t(p) & uintptr_t(Align - 1)); + return diff == 0 ? 0 : uint32_t(Align) - diff; +} + +#if defined (__gfx90a__) || defined(__gfx942__) +#define CDNA2_OR_LATER 1 +#else +#define CDNA2_OR_LATER 0 +#endif + +#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))) + +#if defined(USE_ROCM) +// TODO: Support RDNA +constexpr int32_t kWarpSize = 64; + +template +using VecT = T __attribute__((ext_vector_type(Rank))); + +/* + * Not used by ET +static bool isCDNA2orLater(int index) { + return at::detail::getCUDAHooks().isGPUArch({"gfx90a", "gfx942"}, index); +} +*/ + +#else +constexpr int32_t kWarpSize = 32; +#endif + +// f16 vector types +struct __align__(2) f16x1 { + __half vals[1]; +}; + +struct __align__(4) f16x2 { + __half vals[2]; +}; + +struct __align__(8) f16x4 { + __half vals[4]; +}; + +struct __align__(16) f16x8 { + __half vals[8]; +}; + +// bf16 vector types +struct __align__(2) bf16x1 { + __nv_bfloat16 vals[1]; +}; + +struct __align__(4) bf16x2 { + __nv_bfloat16 vals[2]; +}; + +struct __align__(8) bf16x4 { + __nv_bfloat16 vals[4]; +}; + +struct __align__(16) bf16x8 { + __nv_bfloat16 vals[8]; +}; + +// bf162 vector types +struct __align__(4) bf16x2x1 { + __nv_bfloat162 vals[1]; +}; + +struct __align__(8) bf16x2x2 { + __nv_bfloat162 vals[2]; +}; + +struct __align__(16) bf16x2x4 { + __nv_bfloat162 vals[4]; +}; + +struct __align__(16) bf16x2x4_u32 { +#if defined(USE_ROCM) + VecT val[2]; +#else + uint32_t vals[4]; +#endif +}; + +struct __align__(8) bf16x2x2_u32 { +#if defined(USE_ROCM) + VecT val; +#else + uint32_t vals[2]; +#endif +}; + +struct __align__(4) bf16x2x1_u32 { + uint32_t vals[1]; +}; + +template +struct __align__(sizeof(T) * N) VectorType { + T vals[N]; +}; + +// from +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) { + bf16x2x4 result; + constexpr int kElements = 8; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const source_i4s = source; + + // First, we extract the i4s and construct an intermediate fp16 number. +#if !defined(USE_ROCM) + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; +#endif + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + + // We don't have enough mantissa to remove as much shift overhead as FP16, so + // we must loop. No shift needed for first item. 
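The lop3/fma sequence that follows is easier to read once the underlying bit trick is spelled out: OR-ing a 4-bit nibble into the low mantissa bits of the bf16 constant 0x4300 (which encodes 128.0, whose ulp is exactly 1.0) produces the bf16 value 128 + nibble, and an fma against the bias constant 0xC308 (-136.0) then maps the stored nibble n in [0, 15] to n - 8 in [-8, 7]. The scalar reference below is purely illustrative; the kernel does eight values at a time in packed 32-bit registers.

#include <cstdint>
#include <cstring>

// Scalar reference for one nibble: 0x4300 | n is the bf16 encoding of 128 + n,
// and subtracting 136 recovers the value n - 8 used by the kernel.
float dequant_nibble(uint32_t nibble /* 0..15 */) {
  const uint16_t bf16_bits = static_cast<uint16_t>(0x4300u | (nibble & 0xFu));
  const uint32_t f32_bits = static_cast<uint32_t>(bf16_bits) << 16;  // bf16 is truncated fp32
  float value;
  std::memcpy(&value, &f32_bits, sizeof(value));
  return value - 136.0f;  // nibble 0 -> -8.0f, nibble 15 -> 7.0f
}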
+ uint32_t i4s = source_i4s; + +#if defined(USE_ROCM) + asm volatile("v_and_or_b32 %0, %1, %2, %3" + : "=v"(h[0]) + : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM)); +#else + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); +#endif + +#pragma unroll + for (int ii = 1; ii < kElements / 2; ++ii) { + i4s >>= 4; // or is it 8? + // (i4s & 0x000f000f) | 0x43004300 +#if defined(USE_ROCM) + asm volatile("v_and_or_b32 %0, %1, %2, %3" + : "=v"(h[ii]) + : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM)); +#else + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[ii]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); +#endif + } + + // This is the BF16 {-136, -136} represented as an integer. +#if defined(USE_ROCM) +#if ROCM_VERSION >= 60200 + auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308})); + auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80})); +#else + auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16{0xC308}); + auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16{0x3F80}); +#endif +#else + static constexpr uint32_t BF16_BIAS = 0xC308C308; + static constexpr uint32_t BF16_ONE = 0x3F803F80; +#endif + +// Finally, we construct the output numbers. +#pragma unroll + for (int ii = 0; ii < kElements / 2; ++ii) { + // Since this section is for Ampere+, we use bf16 fma to do the bias + // subtraction +#if defined(USE_ROCM) + result.vals[ii] = __hfma2(result.vals[ii], BF16_ONE, BF16_BIAS); +#else + asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[ii]) + : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); +#endif + } + + return result; +} + + + +enum class KReductionType { + // No k-reduction is needed between blocks as the number of k-tiles processed + // per block are exact and we can directly write the output + None, +}; + +// Loads the A matrix in 16-bit standard m x k row major layout, and writes +// the C matrix in 16-bit standard m x n row major layout: +// +// size [m][k] +template +struct ALayout_RM { + static constexpr int32_t kMTileSize = 16; +#if defined(USE_ROCM) + static constexpr int32_t kNTileSize = 16; +#else + static constexpr int32_t kNTileSize = 8; +#endif + static constexpr int32_t kKTileSize = 16; + + template + static __device__ void load( + const void* A, + int32_t m, + int32_t k, + int32_t mTiles, + int32_t mTile, + int32_t kTiles, + int32_t kTileStart, + int32_t laneId, +#if defined(USE_ROCM) + bf16x2x2_u32 out[KTilesToLoad] +#else + bf16x2x4_u32 out[KTilesToLoad] +#endif + ) { +#if defined(USE_ROCM) + const auto mLane = mTile * kMTileSize + (laneId % kMTileSize); + const auto kLane = kTileStart * kKTileSize + (laneId / kMTileSize) * 4; +#else + const auto mLane = mTile * kMTileSize + (laneId / 4); + const auto kLane = kTileStart * kKTileSize + (laneId % 4) * 2; +#endif + + // access + // [mTile * kMTileSize + (laneId / 4)] + // [kTileStart * kKTileSize + (laneId % 4) * 2] + auto aPtr = reinterpret_cast(A) + mLane * k + kLane; + bool m0InBounds = mLane < m; + +#if !defined(USE_ROCM) + auto aPtrPlus8Rows = aPtr + 8 * k; + + bool m1InBounds = (mLane + 8) < m; +#endif + +#pragma unroll + for (int i = 0; i < KTilesToLoad; ++i) { +#if defined(USE_ROCM) + out[i].val = m0InBounds ? *((VecT *)(aPtr + i * kKTileSize)) : VecT{0, 0, 0, 0}; +#else + out[i].vals[0] = m0InBounds + ? *reinterpret_cast(aPtr + i * kKTileSize) + : uint32_t(0); + out[i].vals[1] = m1InBounds + ? 
*reinterpret_cast(aPtrPlus8Rows + i * kKTileSize) + : uint32_t(0); + + out[i].vals[2] = m0InBounds + ? *reinterpret_cast(aPtr + i * kKTileSize + 8) + : uint32_t(0); + out[i].vals[3] = m1InBounds ? *reinterpret_cast( + aPtrPlus8Rows + i * kKTileSize + 8) + : uint32_t(0); +#endif + } + } + + static __device__ void store( + void* C, + int32_t m, + int32_t n, + int32_t mOutTiles, + int32_t mTile, + int32_t nOutTiles, + int32_t nTile, + int32_t laneId, + const float4& out) { + static_assert(ReduceType == KReductionType::None, ""); + + if constexpr (ReduceType == KReductionType::None) { +#if defined(USE_ROCM) + const int outRow = mTile * kMTileSize + (laneId / kNTileSize) * 4; + const int outCol = nTile * kNTileSize + (laneId % kNTileSize); +#else + // sum.x / sum.y are written at + // [laneId / 4], [(laneId % 4) * 2, (laneId % 4) * 2 + 1] + // sum.z / sum.w are written at + // [8 + (laneId / 4)], [(laneId % 4) * 2, (laneId % 4) * 2 + 1] + // i.e., same columns, different row. + const int outRow = mTile * kMTileSize + (laneId / 4); + const int outCol = nTile * kNTileSize + (laneId % 4) * 2; +#endif + + // Pointer where sum.x / sum.y is written + auto cPtr = reinterpret_cast<__nv_bfloat16*>(C) + outRow * n + outCol; + +#if defined(USE_ROCM) + if (outRow < m) + cPtr[0] = __float2bfloat16(out.x); + if ((outRow + 1) < m) + cPtr[n] = __float2bfloat16(out.y); + if ((outRow + 2) < m) + cPtr[2*n] = __float2bfloat16(out.z); + if ((outRow + 3) < m) + cPtr[3*n] = __float2bfloat16(out.w); +#else + auto v01 = __float22bfloat162_rn(float2{out.x, out.y}); + auto v23 = __float22bfloat162_rn(float2{out.z, out.w}); + + if (outRow < m) { + *reinterpret_cast<__nv_bfloat162*>(cPtr) = v01; + } + + // sum.z, sum.w at +8 rows from cPtr + if (outRow + 8 < m) { + *reinterpret_cast<__nv_bfloat162*>(cPtr + 8 * n) = v23; + } +#endif + } + } +}; + +template +struct BLayout_TC_int4 { + static constexpr int32_t kInnerKTiles = InnerKTiles; + static constexpr int32_t kMTileSize = 16; +#if defined(USE_ROCM) + static constexpr int32_t kNTileSize = 16; +#else + static constexpr int32_t kNTileSize = 8; +#endif + static constexpr int32_t kKTileSize = 16; + + template + static __device__ void load( + // type uint32, size [n / 8][k / (InnerKTiles * 16)][32][InnerKTiles / 2] + // n-tiles: n / 8 for NV, n /16 for AMD + // k / (InnerKTiles * 16): TC size per k-tile is 16 (m16n8k16 for NV, m16n16k16 for AMD) + // value per warp lane: 32 for NV, 64 for AMD + // (InnerKTiles / 2): B layout has 4 values per lane (16 bits) per k-tile. 
+ // 2 k-tiles packed is a uint32 (hence InnerKTiles == 2 is our smallest + // value) 4 k-tiles packed is a uint32x2 (64 bits) 8 k-tiles packed is a + // uint32x4 (128 bits) + const void* ET_RESTRICT B, + // size [k / qGroupSize][n][2] + // Contains the scale and zero point of each of the quantized int4 values + // within B + // v_reconstructed = (bf16(B_int4_val) * scale) - zero + const void* ET_RESTRICT quantizationInfo, + int32_t n, + int32_t k, + int32_t nTiles, + int32_t nTile, + int32_t kTiles, + int32_t kTileStart, + int32_t laneId, + bf16x2x4_u32 out[KTilesToLoad / InnerKTiles][InnerKTiles / 2]) { + // offset [nTile][kTileStart / InnerKTiles][laneId][0] + auto bPtr = reinterpret_cast(B) + + (((nTile * (kTiles / InnerKTiles) + (kTileStart / InnerKTiles)) * + kWarpSize) + + laneId) * + (InnerKTiles / 2); + + int32_t b_int4[KTilesToLoad / InnerKTiles][InnerKTiles / 2]; + +#pragma unroll + for (int i = 0; i < KTilesToLoad / InnerKTiles; ++i) { + auto bPtrCur = bPtr + i * kWarpSize * (InnerKTiles / 2); + + if constexpr (InnerKTiles == 2) { + b_int4[i][0] = bPtrCur[0]; + } + + if constexpr (InnerKTiles == 4) { + // asm volatile("ld.global.cs.v2.u32 {%0, %1}, [%2];\n" + // : "=r"(b_int4[i][0]), "=r"(b_int4[i][1]) + // : "l"(bPtrCur)); + + int2 load8 = reinterpret_cast(bPtrCur)[0]; + b_int4[i][0] = load8.x; + b_int4[i][1] = load8.y; + } + + if constexpr (InnerKTiles == 8) { + // asm volatile("ld.global.cs.v4.u32 {%0, %1, %2, %3}, [%4];\n" + // : "=r"(b_int4[i][0]), "=r"(b_int4[i][1]), + // "=r"(b_int4[i][2]), "=r"(b_int4[i][3]) : "l"(bPtrCur)); + + int4 load16 = reinterpret_cast(bPtrCur)[0]; + b_int4[i][0] = load16.x; + b_int4[i][1] = load16.y; + b_int4[i][2] = load16.z; + b_int4[i][3] = load16.w; + } + } + + // Load needed info for dequantization + + static_assert(isPowerOf2(QGroupSize), ""); + static_assert(isEvenDivisor(QGroupSize, kKTileSize), ""); + // smallest quantization group size is 32 (2 k-tiles are packed in an int32) + static_assert(QGroupSize >= kKTileSize * 2, ""); + constexpr int kKTilesPerQGroup = (QGroupSize / kKTileSize); + // a q-group could be larger than what we are handling in a single warp + constexpr int kNumQGroups = (KTilesToLoad / kKTilesPerQGroup) < 1 + ? 1 + : (KTilesToLoad / kKTilesPerQGroup); + + __nv_bfloat162 qScaleAndZero[kNumQGroups]; + { +#if defined(USE_ROCM) + int32_t laneN = nTile * kNTileSize + (laneId % kNTileSize); +#else + int32_t laneN = nTile * kNTileSize + (laneId / 4); +#endif + int32_t groupStart = (kTileStart * kKTileSize) / QGroupSize; + + int32_t n = nTiles * kNTileSize; + + // offset [qScale_kGroup][qScale_n][0] + auto qInfoPtr = reinterpret_cast(quantizationInfo) + + (groupStart * n + laneN) * 2; + +#pragma unroll + for (int i = 0; i < kNumQGroups; ++i) { + qScaleAndZero[i] = + *reinterpret_cast(qInfoPtr + i * n * 2); + } + } + + // + // De-quantize int4 values to bf16. Values are dequantized as truly int4 + // [-8, 7] range; dequant = (bf16(int4_value) * bf16_scale) + bf16_zero + // + { + // FIXME: does this negatively affect register counts, or will nvcc + // move this expansion (and data loads above) closer to the point of use? 
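For clarity on what the __hfma2 in the next block computes per element: every run of qGroupSize consecutive k values in a given n column shares one (scale, zero) pair from qScaleAndZeros, and the dequantized weight is (stored nibble - 8) * scale + zero. The scalar reference below writes that indexing out in fp32 against the logical [k / qGroupSize][n][2] layout described above; it is an illustration, not the packed tensor-core addressing the kernel actually uses.

#include <cstdint>

// Scalar reference: dequantize logical element (k_idx, n_idx) of the int4 weight.
// qScaleAndZeros is logically [k / qGroupSize][n][2], storing (scale, zero) pairs;
// stored_nibble is the unpacked 4-bit value in [0, 15] (offset-8 encoding).
float dequant_element(uint32_t stored_nibble,
                      int64_t k_idx,
                      int64_t n_idx,
                      int64_t qGroupSize,
                      int64_t n,
                      const float* qScaleAndZeros) {
  const int64_t group = k_idx / qGroupSize;  // quantization group along k
  const float scale = qScaleAndZeros[(group * n + n_idx) * 2 + 0];
  const float zero = qScaleAndZeros[(group * n + n_idx) * 2 + 1];
  const float q = static_cast<float>(stored_nibble) - 8.0f;  // [-8, 7]
  return q * scale + zero;  // mirrors __hfma2(v, qScale, qZero) below
}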
+ __nv_bfloat162 qScale[kNumQGroups]; + __nv_bfloat162 qZero[kNumQGroups]; + +#pragma unroll + for (int i = 0; i < kNumQGroups; ++i) { + qScale[i] = __bfloat162bfloat162(qScaleAndZero[i].x); + qZero[i] = __bfloat162bfloat162(qScaleAndZero[i].y); + } + +#pragma unroll + for (int i = 0; i < KTilesToLoad / InnerKTiles; ++i) { +#pragma unroll + for (int j = 0; j < InnerKTiles / 2; ++j) { + bf16x2x4 v = convert_i4x8_to_bf16x2x4(b_int4[i][j]); + + int curKTile = i * InnerKTiles + j * 2; + int curQGroup = (curKTile * kKTileSize) / QGroupSize; + + // The dequantized values in `v` for a given lane have the same n + // dimension (the B tensor core layout has all values in the same + // thread along the same n) but different k dimension, but all are + // guaranteed to occur within the same quantization group, so we need + // only load a single scale + zero to cover what this lane has +#pragma unroll + for (int k = 0; k < 4; ++k) { + v.vals[k] = __hfma2(v.vals[k], qScale[curQGroup], qZero[curQGroup]); + } + + // type pun, the __nv_bfloat162 value in bf16x2x4 is a struct and + // can't be used as a 32-bit asm register argument for `mma` + static_assert(sizeof(bf16x2x4) == sizeof(out[0][0]), ""); + std::memcpy(&out[i][j], &v, sizeof(bf16x2x4_u32)); + } + } + } + } +}; + +template < + typename ALayout, + typename BLayout, + typename CLayout, + int Warps, + int KTilesPerIteration> +__global__ +__launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( + // Data for the A matrix, loaded as per ALayout + const void* const ET_RESTRICT A, + + // Data for the B matrix, loaded as per BLayout + const void* const ET_RESTRICT B, + + // Optional quantization data for dequantizing B, loaded as per BLayout + const void* const ET_RESTRICT B_quantizationInfo, + + // Output data for the C matrix, stored as per CLayout + void* ET_RESTRICT C, + + // The size of the matrix multiplication + int32_t m, + int32_t n, + int32_t k, + + // The size of the matrix multiplication, in multiples of our TC tile size + int32_t mTiles, + int32_t nTiles, + int32_t kTiles) { + constexpr int32_t kMTileSize = 16; +#if defined(USE_ROCM) + constexpr int32_t kNTileSize = 16; +#else + constexpr int32_t kNTileSize = 8; +#endif + constexpr int32_t kKTileSize = 16; + +#if !defined(USE_ROCM) || CDNA2_OR_LATER + + static_assert( + ALayout::kMTileSize == kMTileSize && ALayout::kNTileSize == kNTileSize && + ALayout::kKTileSize == kKTileSize, + ""); + + static_assert( + BLayout::kMTileSize == kMTileSize && BLayout::kNTileSize == kNTileSize && + BLayout::kKTileSize == kKTileSize, + ""); + + static_assert( + CLayout::kMTileSize == kMTileSize && CLayout::kNTileSize == kNTileSize && + CLayout::kKTileSize == kKTileSize, + ""); + + constexpr int kInnerKTiles = BLayout::kInnerKTiles; + + // 2/4/8 inner k-tiles correspond to 4, 8 and 16 byte innermost loads + static_assert( + kInnerKTiles == 2 || kInnerKTiles == 4 || kInnerKTiles == 8, ""); + + // We always process at least kInnerKTiles k-tiles back to back in a warp + static_assert( + KTilesPerIteration >= kInnerKTiles && + isEvenDivisor(KTilesPerIteration, kInnerKTiles), + ""); + + auto warpId = threadIdx.y; + auto laneId = threadIdx.x; + + int32_t mTile = blockIdx.z; + int32_t nTile = blockIdx.y; + +#if defined(USE_ROCM) + VecT c{0.0f, 0.0f, 0.0f, 0.0f}; +#else + float4 c{0.0f, 0.0f, 0.0f, 0.0f}; +#endif + + // First, handle whole multiples of KTilesPerIteration + auto kTilesLimit = roundDown(kTiles, KTilesPerIteration); + + // Each warp handles a set of KTilesPerIteration under the above limit 
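Concretely, with gridDim.x fixed to 1 by the launcher later in this file, warp w of a block starts at k-tile base w * KTilesPerIteration and strides by Warps * KTilesPerIteration until kTilesLimit, handling KTilesPerIteration consecutive k-tiles per step. A small worked example of that striding, with the tile counts and warp index chosen purely for illustration:

#include <cstdio>

int main() {
  const int Warps = 8, KTilesPerIteration = 8, kTiles = 256;  // illustrative values
  const int kTilesLimit = (kTiles / KTilesPerIteration) * KTilesPerIteration;  // roundDown
  const int warpId = 3, blockIdxX = 0;  // gridDim.x == 1 in this launch scheme
  // Prints k-tile ranges [24, 32), [88, 96), [152, 160), [216, 224).
  for (int kTileBase = (blockIdxX * Warps + warpId) * KTilesPerIteration;
       kTileBase < kTilesLimit;
       kTileBase += Warps * KTilesPerIteration) {
    std::printf("warp %d: k-tiles [%d, %d)\n",
                warpId, kTileBase, kTileBase + KTilesPerIteration);
  }
  return 0;
}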
+ for (int32_t kTileBase = (blockIdx.x * Warps + warpId) * KTilesPerIteration; + kTileBase < kTilesLimit; + kTileBase += Warps * KTilesPerIteration) { + // + // Load data from A + // +#if defined(USE_ROCM) + bf16x2x2_u32 a[KTilesPerIteration]; +#else + bf16x2x4_u32 a[KTilesPerIteration]; +#endif + ALayout::template load( + A, m, k, mTiles, mTile, kTiles, kTileBase, laneId, a); + + // + // Load data from B and de-quantize as needed + // Each k-tile is bf16x2x2 + // + bf16x2x4_u32 b[KTilesPerIteration / kInnerKTiles][kInnerKTiles / 2]; + BLayout::template load( + B, + B_quantizationInfo, + n, + k, + nTiles, + nTile, + kTiles, + kTileBase, + laneId, + b); + + // + // Now, perform the matrix multiplication + // + + // We accumulate across k-tiles here +#pragma unroll + for (int i = 0; i < KTilesPerIteration / kInnerKTiles; ++i) { + static_assert(isEvenDivisor(kInnerKTiles, 2) && kInnerKTiles >= 2, ""); +#pragma unroll + for (int j = 0; j < kInnerKTiles / 2; ++j) { + // We don't simply accumulate into `c` as this creates a too-strong + // execution dependency. Instead, we only periodically accumulate into + // `c` +#if defined(USE_ROCM) + VecT cTmp[2]; +#else + float4 cTmp[2]; +#endif + +#pragma unroll + for (int k = 0; k < 2; ++k) { +#if defined(USE_ROCM) + cTmp[k] = VecT{0.0f, 0.0f, 0.0f, 0.0f}; +#else + cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f}; +#endif + } + +#pragma unroll + for (int k = 0; k < 2; ++k) { +#if defined(USE_ROCM) + cTmp[k] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k( + a[i * kInnerKTiles + j * 2 + k].val, + b[i][(j * 2 + k) / 2].val[((j * 2 + k) % 2)], + cTmp[k], 0, 0, 0); +#else + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};" + : "=f"(cTmp[k].x), + "=f"(cTmp[k].y), + "=f"(cTmp[k].z), + "=f"(cTmp[k].w) + : "r"(a[i * kInnerKTiles + j * 2 + k].vals[0]), + "r"(a[i * kInnerKTiles + j * 2 + k].vals[1]), + "r"(a[i * kInnerKTiles + j * 2 + k].vals[2]), + "r"(a[i * kInnerKTiles + j * 2 + k].vals[3]), + "r"(b[i][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 0]), + "r"(b[i][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 1]), + "f"(cTmp[k].x), + "f"(cTmp[k].y), + "f"(cTmp[k].z), + "f"(cTmp[k].w)); +#endif + } + +#pragma unroll + for (int k = 0; k < 2; ++k) { +#if defined(USE_ROCM) + c[0] += cTmp[k][0]; + c[1] += cTmp[k][1]; + c[2] += cTmp[k][2]; + c[3] += cTmp[k][3]; +#else + c.x += cTmp[k].x; + c.y += cTmp[k].y; + c.z += cTmp[k].z; + c.w += cTmp[k].w; +#endif + } + } + } + } // for all tiles under kTilesLimit + + // Now, there could be a remainder of 1 to KTilesPerIteration - 1 k-tiles + // remaining. 
We guarantee that the number of warps is >= KTilesPerIteration / + // kInnerKTiles, so that each warp can simply load kInnerKTiles and do its + // thing without needing more warps + static_assert(Warps >= KTilesPerIteration / kInnerKTiles, ""); + + auto kTileBaseRemaining = kTilesLimit + warpId * kInnerKTiles; + + // If we have any remainder k-tiles, some warps will handle them, processing + // kInnerKTiles k-tiles at a time + if (kTileBaseRemaining < kTiles) { +#if defined(USE_ROCM) + bf16x2x2_u32 a[kInnerKTiles]; +#else + bf16x2x4_u32 a[kInnerKTiles]; +#endif + ALayout::template load( + A, m, k, mTiles, mTile, kTiles, kTileBaseRemaining, laneId, a); + + bf16x2x4_u32 b[1][kInnerKTiles / 2]; + BLayout::template load( + B, + B_quantizationInfo, + n, + k, + nTiles, + nTile, + kTiles, + kTileBaseRemaining, + laneId, + b); + +#pragma unroll + for (int j = 0; j < kInnerKTiles / 2; ++j) { + // We don't simply accumulate into `c` as this creates a too-strong + // execution dependency. Instead, we only periodically accumulate into + // `c` +#if defined(USE_ROCM) + VecT cTmp[2]; +#else + float4 cTmp[2]; +#endif + +#pragma unroll + for (int k = 0; k < 2; ++k) { +#if defined(USE_ROCM) + cTmp[k] = VecT{0.0f, 0.0f, 0.0f, 0.0f}; +#else + cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f}; +#endif + } + +#pragma unroll + for (int k = 0; k < 2; ++k) { +#if defined(USE_ROCM) + cTmp[k] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k( + a[j * 2 + k].val, + b[0][(j * 2 + k) / 2].val[((j * 2 + k) % 2)], + cTmp[k], 0, 0, 0); +#else + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};" + : "=f"(cTmp[k].x), "=f"(cTmp[k].y), "=f"(cTmp[k].z), "=f"(cTmp[k].w) + : "r"(a[j * 2 + k].vals[0]), + "r"(a[j * 2 + k].vals[1]), + "r"(a[j * 2 + k].vals[2]), + "r"(a[j * 2 + k].vals[3]), + "r"(b[0][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 0]), + "r"(b[0][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 1]), + "f"(cTmp[k].x), + "f"(cTmp[k].y), + "f"(cTmp[k].z), + "f"(cTmp[k].w)); +#endif + } + +#pragma unroll + for (int k = 0; k < 2; ++k) { +#if defined(USE_ROCM) + c[0] += cTmp[k][0]; + c[1] += cTmp[k][1]; + c[2] += cTmp[k][2]; + c[3] += cTmp[k][3]; +#else + c.x += cTmp[k].x; + c.y += cTmp[k].y; + c.z += cTmp[k].z; + c.w += cTmp[k].w; +#endif + } + } + } + + // + // Reduce independent k-tiles (same m/n) across warps + // + __shared__ float4 smem_sum[Warps][kWarpSize]; + + // FIXME: this likely doesn't need to be a true reduction tree, can just be a + // serial sum, maybe (unless nvcc/ptxas goes back to its old ways) + // smem_sum[warpId][laneId] = TreeReduce4::reduce(c); +#if defined(USE_ROCM) + smem_sum[warpId][laneId].x = c[0]; + smem_sum[warpId][laneId].y = c[1]; + smem_sum[warpId][laneId].z = c[2]; + smem_sum[warpId][laneId].w = c[3]; +#else + smem_sum[warpId][laneId] = c; +#endif + + __syncthreads(); + + if (warpId == 0) { + float4 sum_f32{0.0f, 0.0f, 0.0f, 0.0f}; + + // Reduce across the block in the first warp + for (int i = 0; i < Warps; ++i) { + float4 v = smem_sum[i][laneId]; + sum_f32.x += v.x; + sum_f32.y += v.y; + sum_f32.z += v.z; + sum_f32.w += v.w; + } + + // Write the reduced result (in the first warp) into the output + CLayout::store( + C, + m, + n, + mTiles, + mTile, + // n for C output becomes k for A input, so for m16n8k16, + // we need to halve the tiles + nTiles / 2, + nTile, + laneId, + sum_f32); + } +#else + printf("__builtin_amdgcn_mfma_f32_16x16x16bf16_1k is only supported on AMD gpu arch greater than or equal to CDNA2\n"); +#endif +} 
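The launcher that follows fixes the launch geometry rather than deriving it from k: the grid is dim3(1, nTiles, mTiles) with a block of dim3(kWarpSize, Warps), so one block owns each (mTile, nTile) pair and its warps loop over all k-tiles, meeting in the smem_sum reduction shown above. Below is a worked example of that geometry for the NVIDIA path (kMTileSize = 16, kNTileSize = 8, kKTileSize = 16); the problem sizes are illustrative.

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t m = 1, n = 4096, k = 4096;  // illustrative decode-style GEMM
  const int64_t kMTileSize = 16, kNTileSize = 8, kKTileSize = 16;
  auto divUp = [](int64_t a, int64_t b) { return (a + b - 1) / b; };

  const int64_t mTiles = divUp(m, kMTileSize);  // 1
  const int64_t nTiles = divUp(n, kNTileSize);  // 512
  const int64_t kTiles = divUp(k, kKTileSize);  // 256

  // grid = dim3(1, nTiles, mTiles), block = dim3(kWarpSize, Warps):
  // 512 blocks, each looping over 256 k-tiles for its (mTile, nTile) pair.
  std::printf("grid = (1, %lld, %lld), k-tiles per block = %lld\n",
              static_cast<long long>(nTiles),
              static_cast<long long>(mTiles),
              static_cast<long long>(kTiles));
  return 0;
}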
+ +template < + typename ALayout, + typename BLayout, + typename CLayout, + int Warps, + int KTilesPerWarp> +void launch_tinygemm_kernel( + const Tensor& A, + const Tensor& B, + const Tensor* qScaleAndZeros, /* optional */ + Tensor& C_final, + int32_t mTiles, + int32_t nTiles, + int32_t kTiles, + int32_t m, + int32_t n, + int32_t k, + cudaStream_t stream) { + // The chunking kernel requires that kTiles is a multiple of kInnerKTiles + ET_CHECK( + kTiles >= BLayout::kInnerKTiles && + isEvenDivisor(kTiles, BLayout::kInnerKTiles)); + + ET_CHECK( + KTilesPerWarp >= BLayout::kInnerKTiles && + isEvenDivisor(KTilesPerWarp, BLayout::kInnerKTiles)); + + // After intra-block reduction across the k dimension, we are left with this + // many tiles + // int32_t postKernelKTiles = kTiles / (Warps * KTilesPerWarp); + int32_t postKernelKTiles = 1; // we loop + + auto grid = dim3(postKernelKTiles, nTiles, mTiles); + auto block = dim3(kWarpSize, Warps); + + auto func = + tinygemm_m16n8k16_chunk_kernel; + + func<<>>( + A.data_ptr(), + B.data_ptr(), + qScaleAndZeros ? qScaleAndZeros->data_ptr() : nullptr, + C_final.data_ptr(), + m, + n, + k, + mTiles, + nTiles, + kTiles); + + ET_CUDA_KERNEL_LAUNCH_CHECK(); + + cudaFuncAttributes funcAttr; +#if defined(USE_ROCM) + ET_CUDA_CHECK(cudaFuncGetAttributes(&funcAttr, (void *)func)); +#else + ET_CUDA_CHECK(cudaFuncGetAttributes(&funcAttr, func)); +#endif +} + +/* + * Not used by ET +// FIXME: parallelize better, smem staging etc? +template +__global__ void matrix_to_m16n8k16_Bint4_layout( + // size [n][k / 2] + const at::PackedTensorAccessor32 in, + // size [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2] + at::PackedTensorAccessor32 out) { + // int4 values are packed into int32 values, which require at least 8. Given + // m16n8k16 B layout requires 4 scalar values/lane, the minimum number of + // innermost k-tiles that we can use is 2. + static_assert(InnerKTiles >= 2 && isPowerOf2(InnerKTiles), ""); + +#if defined(USE_ROCM) + constexpr int32_t kNTileSize = 16; +#else + constexpr int32_t kNTileSize = 8; +#endif + constexpr int32_t kKTileSize = 16; + + // gridDim.x corresponds to the number of k-tiles divided by InnerKTiles + auto kOuterTile = blockIdx.x; + auto nTile = blockIdx.y; + auto t = threadIdx.x; + + // Two k-tiles are packed into an int32 at a time +#pragma unroll + for (int innerKTile = 0; innerKTile < InnerKTiles; innerKTile += 2) { + // n dimension that this lane loads from +#if defined(USE_ROCM) + auto n0 = nTile * kNTileSize + (t % kNTileSize); +#else + auto n0 = nTile * kNTileSize + (t / 4); +#endif + + bool n0Valid = n0 < in.size(0); + + // Four uint8 are packed into an int32 + int32_t ks[4]; + + auto kBase0 = (kOuterTile * InnerKTiles + innerKTile) * kKTileSize / 2; + +#if defined(USE_ROCM) + ks[0] = kBase0 + (t / kNTileSize) * 2; + ks[1] = ks[0] + 1; + + auto kBase1 = kBase0 + kKTileSize / 2; + ks[2] = kBase1 + (t / kNTileSize) * 2; + ks[3] = ks[2] + 1; +#else + ks[0] = kBase0 + t % 4; + ks[1] = ks[0] + 4; + + auto kBase1 = kBase0 + kKTileSize / 2; + ks[2] = kBase1 + t % 4; + ks[3] = ks[2] + 4; +#endif + + auto pIn = &in[n0][0]; + + uint8_t v[4]; +#pragma unroll + for (int i = 0; i < 4; ++i) { + v[i] = (n0Valid && ks[i] < in.size(1)) ? 
pIn[ks[i]] : uint8_t(0); + } + + // To clearly explain the packed result with 8 int4 values (4 uint8) + // into one int32, we use the follow figure: + // [n][k] int32: v[0] v[1] v[2] v[3] v[4] v[5] v[6] v[7] + // [n][k / 2] uint8: v[0] v[1] v[2] v[3] + // When using int32 weight as input, the packed result is consisted of + // v[7] | v[5] | v[3] | v[1] | v[6] | v[4] | v[2] | v[0], + // which epuals to + // v[3]L | v[2]L | v[1]L | v[0]L | v[3]H | v[2]H | v[1]H | v[0]H + // when using uint8 weight as input. + int32_t pack = ((uint32_t)(v[3] & 0xF) << 28) | + ((uint32_t)(v[2] & 0xF) << 24) | ((uint32_t)(v[1] & 0xF) << 20) | + ((uint32_t)(v[0] & 0xF) << 16) | ((uint32_t)(v[3] & 0xF0) << 8) | + ((uint32_t)(v[2] & 0xF0) << 4) | ((uint32_t)(v[1] & 0xF0)) | + ((uint32_t)(v[0] & 0xF0) >> 4); + + // inner k-tiles pack two at a time +#if defined(USE_ROCM) + // The output tensor shape is [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2], which is specific to Nvidia + // But AMD needs [ceil(n / 16)][ceil(k / (InnerKTiles * 16))][64][InnerKTiles / 2] + // So construct the pointer accordingly + auto bPtr = out.data() + + ((nTile * out.size(1) * kWarpSize * (InnerKTiles / 2)) + + (kOuterTile * kWarpSize * (InnerKTiles / 2)) + + (t * (InnerKTiles / 2)) + + (innerKTile / 2)); + *bPtr = pack; +#else + out[nTile][kOuterTile][t][innerKTile / 2] = pack; +#endif + } +} +*/ + +#endif // defined(USE_ROCM) || CUDA_VERSION >= 12000 + + +Tensor* _weight_int4pack_mm_cuda( + const Tensor& A, + const Tensor& B, + int64_t qGroupSize, + const Tensor& qScaleAndZeros) { + // Skip CUDAGuard because ETensor doesn't carry device information + // auto result = CUDAGuard::create(0); + + // Skip device check because ETensor doesn't carry device information + // ET_CHECK( + // A.device() == B.device() && A.device() == qScaleAndZeros.device()); + +#if defined(USE_ROCM) + if (!isCDNA2orLater(A.device().index())) { + ET_CHECK(false, "_weight_int4pack_mm_cuda is only supported on AMD gpu arch greater than or equal to CDNA2"); + } +#endif + + constexpr int32_t kMTileSize = 16; +#if defined(USE_ROCM) + constexpr int32_t kNTileSize = 16; +#else + constexpr int32_t kNTileSize = 8; +#endif + constexpr int32_t kKTileSize = 16; + + // row major layout + auto m = A.size(0); + auto mTiles = divUp(m, kMTileSize); + + // To convert the nTiles from tensor storage layout to the actual matrix core layout + constexpr int32_t kNTileSizeTensor = 8; + auto nTileScaleFactor = (kNTileSize / kNTileSizeTensor); + + // tensor core layout + auto nTiles = (B.size(0) / nTileScaleFactor); + auto n = nTiles * kNTileSize; + + // row major layout + auto k = A.size(1); + auto kTiles = divUp(k, kKTileSize); + + // The number of inner k tiles is the innermost dimension of times 2 + // 2 k-tiles (4 values per lane per tile, 8 values total) quantized to int4 + // packed into 1 int32 for int4 B + auto B_innerKTiles = B.size(3) * 2; + ET_CHECK(B_innerKTiles == 2 || B_innerKTiles == 4 || B_innerKTiles == 8); + + // A is standard row major + ET_CHECK(A.dtype() == executorch::aten::ScalarType::BFloat16); + // ET only supports contiguous tensors for now + // ET_CHECK(A.is_contiguous()); + ET_CHECK(A.dim() == 2); + + // B has B_innerKTiles k-tiles in the innermost dimension + ET_CHECK(B.dtype() == executorch::aten::ScalarType::Int); + // ET only supports contiguous tensors for now + // ET_CHECK(B.is_contiguous()); + ET_CHECK(B.dim() == 4); + ET_CHECK(B.size(1) == k / (B_innerKTiles * kKTileSize)); + ET_CHECK(B.size(2) == 32); + + // Validate the scale and zero 
point tensor for dequantization + // These are the only versions handled at the moment + ET_CHECK( + qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 || + qGroupSize == 256); + + ET_CHECK(qScaleAndZeros.dim() == 3); + auto numQGroups = qScaleAndZeros.size(0); + ET_CHECK( + kTiles * kKTileSize >= qGroupSize && + isEvenDivisor(kTiles * kKTileSize, qGroupSize)); + ET_CHECK(qScaleAndZeros.size(1) == n); + ET_CHECK(qScaleAndZeros.size(2) == 2); + + // Output is a standard row-major matrix + Tensor* C_final = nullptr; + std::array shape = {m, n}; + std::array stride = {n, 1}; + aoti_torch_empty_strided( + 2, + shape.data(), + stride.data(), + static_cast(SupportedDTypes::BFLOAT16), + static_cast(SupportedDevices::CUDA), + 0, + &C_final + ); + +#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))) + auto stream_result = getCurrentCUDAStream(0); + ET_CHECK_MSG(stream_result.ok(), "Failed to get CUDA stream"); + cudaStream_t stream = stream_result.get(); +#define RUN_GEMM(WARPS, K_TILES_PER_WARP, Q_GROUP_SIZE, REDUCE_TYPE) \ + do { \ + using ACLayout = ALayout_RM; \ + \ + ET_CHECK( \ + K_TILES_PER_WARP >= B_innerKTiles && \ + isEvenDivisor(K_TILES_PER_WARP, B_innerKTiles)); \ + \ + switch (B_innerKTiles) { \ + case 2: \ + if constexpr (K_TILES_PER_WARP >= 2) { \ + using BLayout = BLayout_TC_int4<2, Q_GROUP_SIZE>; \ + launch_tinygemm_kernel< \ + ACLayout, \ + BLayout, \ + ACLayout, \ + WARPS, \ + K_TILES_PER_WARP>( \ + A, \ + B, \ + &qScaleAndZeros, \ + *C_final, \ + mTiles, \ + nTiles, \ + kTiles, \ + m, \ + n, \ + k, \ + stream); \ + } \ + break; \ + case 4: \ + if constexpr (K_TILES_PER_WARP >= 4) { \ + using BLayout = BLayout_TC_int4<4, Q_GROUP_SIZE>; \ + launch_tinygemm_kernel< \ + ACLayout, \ + BLayout, \ + ACLayout, \ + WARPS, \ + K_TILES_PER_WARP>( \ + A, \ + B, \ + &qScaleAndZeros, \ + *C_final, \ + mTiles, \ + nTiles, \ + kTiles, \ + m, \ + n, \ + k, \ + stream); \ + } \ + break; \ + case 8: \ + if constexpr (K_TILES_PER_WARP >= 8) { \ + using BLayout = BLayout_TC_int4<8, Q_GROUP_SIZE>; \ + launch_tinygemm_kernel< \ + ACLayout, \ + BLayout, \ + ACLayout, \ + WARPS, \ + K_TILES_PER_WARP>( \ + A, \ + B, \ + &qScaleAndZeros, \ + *C_final, \ + mTiles, \ + nTiles, \ + kTiles, \ + m, \ + n, \ + k, \ + stream); \ + } \ + break; \ + default: \ + break; \ + } \ + } while (false) + +#define HANDLE_Q_GROUP(WARPS, K_TILES_PER_WARP, REDUCE_TYPE) \ + do { \ + switch (qGroupSize) { \ + case 32: \ + RUN_GEMM(WARPS, K_TILES_PER_WARP, 32, REDUCE_TYPE); \ + break; \ + case 64: \ + RUN_GEMM(WARPS, K_TILES_PER_WARP, 64, REDUCE_TYPE); \ + break; \ + case 128: \ + RUN_GEMM(WARPS, K_TILES_PER_WARP, 128, REDUCE_TYPE); \ + break; \ + case 256: \ + RUN_GEMM(WARPS, K_TILES_PER_WARP, 256, REDUCE_TYPE); \ + break; \ + } \ + } while (false) + + HANDLE_Q_GROUP(8, 8, KReductionType::None); + +#undef HANDLE_Q_GROUP +#undef RUN_GEMM + + return C_final; +#endif + ET_CHECK_MSG(false, "_weight_int4pack_mm_cuda is not available for build."); + return C_final; +} + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/int4mm.h b/backends/cuda/runtime/shims/int4mm.h new file mode 100644 index 00000000000..87a9916b0aa --- /dev/null +++ b/backends/cuda/runtime/shims/int4mm.h @@ -0,0 +1,84 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +namespace executorch::backends::cuda { + +using executorch::backends::aoti::AOTITorchError; +using executorch::backends::aoti::Tensor; + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * Performs quantized INT4 matrix multiplication. + * + * INT4 weights are stored in a packed tensor core layout optimized for + * NVIDIA Ampere+ GPUs (sm_80+) using m16n8k16 tensor core tiles. + * + * HARDWARE REQUIREMENTS: + * - CUDA Compute Capability >= 8.0 (Ampere or later) + * - BFloat16 support (native on sm_80+) + * + * TENSOR REQUIREMENTS: + * @param self Input activation matrix [m, k] + * - Must be BFloat16 dtype + * - Must be 2D + * - Must be on CUDA device + * - Row-major layout (contiguous) + * + * @param mat2 Quantized weight matrix in packed tensor core layout + * - Must be Int32 dtype (contains packed INT4 values) + * - Must be 4D: [n/8][k/(InnerKTiles*16)][32][InnerKTiles/2] + * where InnerKTiles = 2, 4, or 8 + * - Each Int32 contains 8 packed INT4 values + * - Layout optimized for tensor core access patterns + * - Must be on CUDA device + * + * @param qGroupSize Quantization group size (number of values sharing + * scale/zero) + * - Must be one of: 32, 64, 128, or 256 + * - Smaller groups = higher accuracy but more metadata + * - Must evenly divide k dimension + * + * @param qScaleAndZeros Dequantization parameters [k/qGroupSize][n][2] + * - Must be BFloat16 dtype + * - Must be 3D + * - [:, :, 0] contains scales + * - [:, :, 1] contains zero points + * - Must be on CUDA device + * + * @param ret0 Output parameter for result matrix [m, n] + * - Allocated by this function as BFloat16 + * - Must not be null + * - Caller is responsible for freeing via aoti_torch_delete_tensor_object() + * + * @return AOTITorchError error code: + * - Error::Ok: Success + * - Error::InvalidArgument: Null pointer, wrong dtype, wrong dimensions, + * or invalid qGroupSize + * - Error::Internal: CUDA kernel launch failure + */ +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_cuda__weight_int4pack_mm( + Tensor* self, + Tensor* mat2, + int64_t qGroupSize, + Tensor* qScaleAndZeros, + Tensor** ret0); + +#ifdef __cplusplus +} +#endif + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp new file mode 100644 index 00000000000..aaaf3913381 --- /dev/null +++ b/backends/cuda/runtime/shims/memory.cpp @@ -0,0 +1,776 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace executorch::backends::cuda { + +using executorch::aten::SizesType; +using executorch::aten::StridesType; +using executorch::backends::aoti::aoti_torch_get_device_index; +using executorch::backends::aoti::aoti_torch_get_dtype; +using executorch::backends::aoti::aoti_torch_get_sizes; +using executorch::backends::aoti::aoti_torch_get_strides; +using executorch::backends::aoti::convert_sizes_to_vector; +using executorch::backends::aoti::convert_strides_to_vector; +using executorch::backends::aoti::dtype_to_element_size; +using executorch::backends::aoti::dtype_to_scalar_type; +using executorch::backends::aoti::validate_storage_offset; + +// Global storage for tensors and their metadata +std::unordered_set> tensors; + +// Reference counting for memory addresses +// Maps memory address to number of tensors using it +// Special value: NOT_OWN (-1) means tensor never owns the memory +constexpr int32_t NOT_OWN = -1; +std::unordered_map memory_to_n_tensor; + +namespace { + +// Calculate linear offset from strides and indices +int64_t calculate_linear_offset( + const int64_t* indices, + const int64_t* strides, + int64_t ndim) { + int64_t offset = 0; + for (int64_t i = 0; i < ndim; ++i) { + offset += indices[i] * strides[i]; + } + return offset; +} + +// Convert linear index to multi-dimensional indices based on sizes +void linear_to_indices( + int64_t linear_idx, + const int64_t* sizes, + int64_t ndim, + int64_t* indices) { + for (int64_t i = ndim - 1; i >= 0; --i) { + indices[i] = linear_idx % sizes[i]; + linear_idx /= sizes[i]; + } +} + +// Generic pointwise copy function that handles arbitrary strides +template +AOTITorchError pointwise_copy_generic( + T* dst_data, + const T* src_data, + const int64_t* dst_sizes, + const int64_t* dst_strides, + const int64_t* src_sizes, + const int64_t* src_strides, + int64_t dst_ndim, + int64_t src_ndim, + int64_t total_elements) { + std::vector dst_indices(dst_ndim); + std::vector src_indices(src_ndim); + + for (int64_t linear_idx = 0; linear_idx < total_elements; ++linear_idx) { + // Convert linear index to multi-dimensional indices for both tensors + linear_to_indices(linear_idx, dst_sizes, dst_ndim, dst_indices.data()); + linear_to_indices(linear_idx, src_sizes, src_ndim, src_indices.data()); + + // Calculate offsets for both source and destination + int64_t src_offset = + calculate_linear_offset(src_indices.data(), src_strides, src_ndim); + int64_t dst_offset = + calculate_linear_offset(dst_indices.data(), dst_strides, dst_ndim); + + // Copy element + dst_data[dst_offset] = src_data[src_offset]; + } + + return Error::Ok; +} + +} // anonymous namespace + +extern "C" { + +AOTITorchError aoti_torch_create_tensor_from_blob_v2( + void* data, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + int32_t dtype, + int32_t device_type, + int32_t device_index, + Tensor** ret_new_tensor, + int32_t layout, + const uint8_t* opaque_metadata, + int64_t opaque_metadata_size) { + // TODO(gasoonjia): verify given data is on the target device + (void)device_type; + (void)opaque_metadata; + (void)layout; + (void)opaque_metadata_size; + + // Validate input parameters first + ET_CHECK_OR_RETURN_ERROR( + data != nullptr, + InvalidArgument, + "aoti_torch_create_tensor_from_blob_v2 failed: data pointer is null"); + + ET_CHECK_OR_RETURN_ERROR( + !(sizes_ptr == nullptr && ndim > 0), + 
InvalidArgument, + "aoti_torch_create_tensor_from_blob_v2 failed: sizes_ptr is null"); + + ET_CHECK_OR_RETURN_ERROR( + ret_new_tensor != nullptr, + InvalidArgument, + "aoti_torch_create_tensor_from_blob_v2 failed: ret_new_tensor is null"); + + // Check that device_index is always 0 + ET_CHECK_OR_RETURN_ERROR( + device_index == 0, + InvalidArgument, + "device_index must be 0, got: %d", + device_index); + + // Validate dtype using SupportedDTypes from utils.h + ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype)); + + // Storage offset must be 0 since from_blob cannot handle different offsets + ET_CHECK_OK_OR_RETURN_ERROR(validate_storage_offset(storage_offset)); + + // Convert sizes to the format expected by ExecutorTorch using SizesType + std::vector sizes = + convert_sizes_to_vector(ndim, sizes_ptr); + + // Convert strides using the common helper function with StridesType + std::vector strides = + convert_strides_to_vector(ndim, sizes_ptr, strides_ptr); + + // Create ExecutorTorch tensor that wraps the existing memory + // Note: We're NOT copying the data, just wrapping it + // Using CUDA-specific tensor maker that supports incontiguous tensors + auto tensor = make_tensor( + sizes, // tensor dimensions + data, // existing memory (don't copy!) + {}, // dim_order (empty, will be auto-generated) + strides, // tensor strides (allows different strides) + dtype_to_scalar_type(dtype) // map int32_t dtype to ScalarType + ); + + ET_CHECK_OR_RETURN_ERROR( + tensor != nullptr, InvalidArgument, "Failed to create tensor from blob"); + + // Store the tensor so it doesn't get destroyed + tensors.insert(tensor); + + *ret_new_tensor = tensor.get(); + + // Check if this memory address is already being tracked + auto memory_it = memory_to_n_tensor.find(data); + ET_CHECK_OR_RETURN_ERROR( + memory_it == memory_to_n_tensor.end(), + InvalidArgument, + "Memory address %p is already being tracked by another tensor", + data); + + // Mark this memory as NOT_OWN since tensor created from blob never owns + // memory + memory_to_n_tensor[data] = NOT_OWN; + + return Error::Ok; +} + +AOTITorchError aoti_torch_empty_strided( + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int32_t dtype, + int32_t device_type, + int32_t device_index, + Tensor** ret_new_tensor) { + // Check that device_index is always 0 + ET_CHECK_OR_RETURN_ERROR( + device_index == 0, + InvalidArgument, + "device_index must be 0, got: %d", + device_index); + + // This requires us to reserve CUDA memory and put it into a ETensor + void* ptr; + + ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype)); + + size_t element_size = dtype_to_element_size(dtype); + ET_CHECK_OR_RETURN_ERROR( + element_size != 0, + InvalidArgument, + "Invalid element size for dtype: %d", + dtype); + + // Calculate storage size based on strides, matching PyTorch's behavior + // This is critical when sizes and strides don't match the expected contiguous + // layout Reference: PyTorch's computeStorageNbytes in EmptyTensor.cpp + int64_t storage_size = 1; // storage offset (0) + 1 + for (int64_t i = 0; i < ndim; i++) { + if (sizes_ptr[i] == 0) { + storage_size = 0; + break; + } + // For each dimension, add stride[i] * (size[i] - 1) + // This gives us the maximum offset in that dimension + int64_t stride_i = (strides_ptr != nullptr) ? 
strides_ptr[i] : 1; + if (strides_ptr == nullptr) { + // Calculate contiguous stride if not provided + for (int64_t j = i + 1; j < ndim; j++) { + stride_i *= sizes_ptr[j]; + } + } + storage_size += stride_i * (sizes_ptr[i] - 1); + } + int64_t nbytes = storage_size * element_size; + + if (device_type == static_cast(SupportedDevices::CUDA)) { + ET_CUDA_CHECK_OR_RETURN_ERROR( + cudaMallocAsync(&ptr, static_cast(nbytes), cudaStreamDefault)); + } else if (device_type == static_cast(SupportedDevices::CPU)) { + // Ensure 16-byte alignment for CPU memory to match CUDA requirements + ptr = aligned_alloc(16, nbytes); + ET_CHECK_OR_RETURN_ERROR( + ptr != nullptr, + MemoryAllocationFailed, + "Failed to allocate aligned CPU memory"); + } else { + ET_CHECK_OR_RETURN_ERROR( + false, + NotImplemented, + "Need to implement empty_strided for non-CUDA non-CPU device type %d", + device_type); + } + + // ETensor sizes + auto sizes = convert_sizes_to_vector(ndim, sizes_ptr); + + // ETensor strides + auto strides = convert_strides_to_vector(ndim, sizes_ptr, strides_ptr); + + // ETensor creation with dynamic shape support for edge cases + // Using CUDA-specific tensor maker that supports incontiguous tensors + auto tensor = make_tensor( + sizes, + ptr, + {}, // dim_order (empty, will be auto-generated) + strides, + dtype_to_scalar_type(dtype)); + + // Store the tensor so it doesn't get destroyed + tensors.insert(tensor); + *ret_new_tensor = tensor.get(); + + // This tensor owns the memory it allocated, set reference count to 1 + memory_to_n_tensor[ptr] = 1; + return Error::Ok; +} + +void clear_all_tensors() { + // Use aoti_torch_delete_tensor_object to properly delete each tensor + // Note: We need to collect tensor pointers first since deletion modifies the + // set + std::vector tensor_ptrs; + tensor_ptrs.reserve(tensors.size()); + for (const auto& tensor_shared : tensors) { + tensor_ptrs.push_back(tensor_shared.get()); + } + + // Now delete each tensor - this will modify the global tensors set + for (Tensor* tensor_ptr : tensor_ptrs) { + aoti_torch_delete_tensor_object(tensor_ptr); + } + + // tensors set should now be empty, but ensure it's cleared + tensors.clear(); + + ET_LOG(Info, "Cleared all tensors"); +} + +AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) { + // Handle null tensor pointer + ET_CHECK_OR_RETURN_ERROR( + tensor != nullptr, InvalidArgument, "Cannot delete null tensor"); + + // Check if tensor exists in our tracking + bool found_in_tensors = false; + for (auto it = tensors.begin(); it != tensors.end(); ++it) { + if (it->get() == tensor) { + found_in_tensors = true; + break; + } + } + + // If tensor not found in our tracking, it's invalid + ET_CHECK_OR_RETURN_ERROR( + found_in_tensors, InvalidArgument, "Didn't find tensor %p", tensor); + + // Find and delete the tensor + for (auto it = tensors.begin(); it != tensors.end(); ++it) { + if (it->get() == tensor) { + // Get the tensor before erasing + auto tensor_ptr = *it; + void* data_ptr = tensor_ptr->mutable_data_ptr(); + + // Find the reference count for this memory address + auto memory_it = memory_to_n_tensor.find(data_ptr); + if (memory_it != memory_to_n_tensor.end()) { + int32_t ref_count = memory_it->second; + + if (ref_count == NOT_OWN) { + // Tensor never owned the memory, skip freeing + // Just remove tensor from tracking + tensors.erase(it); + return Error::Ok; + } else if (ref_count == 1) { + // Only current tensor using this memory, free it + // Determine if it's GPU memory + cudaPointerAttributes attributes{}; + 
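        // Editor's note: cudaPointerGetAttributes reports where this
        // allocation lives. attributes.type == cudaMemoryTypeDevice means it
        // came from cudaMallocAsync and must be released with cudaFreeAsync;
        // host allocations made by aoti_torch_empty_strided were created with
        // aligned_alloc and are released with aligned_free below. Managed
        // memory is not expected here and is treated as an internal error.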
ET_CUDA_CHECK_OR_RETURN_ERROR( + cudaPointerGetAttributes(&attributes, data_ptr)); + + if (attributes.type == cudaMemoryTypeDevice) { + ET_CUDA_CHECK_OR_RETURN_ERROR( + cudaFreeAsync(data_ptr, cudaStreamDefault)); + } else { + ET_CHECK_OR_RETURN_ERROR( + attributes.type != cudaMemoryTypeManaged, + Internal, + "Expected host memory but got managed!") + // This is CPU memory - free immediately + aligned_free(data_ptr); + data_ptr = nullptr; + } + + // Remove from memory tracking + memory_to_n_tensor.erase(memory_it); + } else if (ref_count > 1) { + // Other tensors still using this memory, just decrement count + memory_to_n_tensor[data_ptr] = ref_count - 1; + } + } else { + ET_CHECK_OR_RETURN_ERROR( + false, + Internal, + "Internal error: memory not found during deletion"); + } + + // Remove tensor from set (this will call the destructor if it's the last + // reference) + tensors.erase(it); + return Error::Ok; + } + } + + // This should never be reached since we found it above + ET_CHECK_OR_RETURN_ERROR( + false, Internal, "Internal error: tensor not found after validation"); +} + +AOTITorchError +aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking) { + (void)non_blocking; + + // Check for null pointers first + ET_CHECK_OR_RETURN_ERROR( + self != nullptr, + InvalidArgument, + "aoti_torch_copy_ failed: self tensor is null"); + + ET_CHECK_OR_RETURN_ERROR( + src != nullptr, + InvalidArgument, + "aoti_torch_copy_ failed: src tensor is null"); + + // Get dtype information and validate compatibility + int32_t self_dtype, src_dtype; + aoti_torch_get_dtype(self, &self_dtype); + aoti_torch_get_dtype(src, &src_dtype); + + ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(self_dtype)); + + ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(src_dtype)); + + // Check dtype compatibility - both tensors must have the same dtype + ET_CHECK_OR_RETURN_ERROR( + self_dtype == src_dtype, + InvalidArgument, + "dtype mismatch. self.dtype=%d, src.dtype=%d. aoti_torch_copy_ requires same dtypes", + self_dtype, + src_dtype); + + // Check total number of elements compatibility (PyTorch copy_ behavior) + int64_t self_numel = self->numel(); + int64_t src_numel = src->numel(); + + ET_CHECK_OR_RETURN_ERROR( + self_numel == src_numel, + InvalidArgument, + "numel mismatch. 
self.numel()=%ld, src.numel()=%ld", + self_numel, + src_numel); + + // Get tensor metadata + int64_t* self_strides; + int64_t* src_strides; + aoti_torch_get_strides(self, &self_strides); + aoti_torch_get_strides(src, &src_strides); + + int64_t* self_sizes; + int64_t* src_sizes; + aoti_torch_get_sizes(self, &self_sizes); + aoti_torch_get_sizes(src, &src_sizes); + + // Determine device locations + cudaPointerAttributes srcAttributes{}; + cudaPointerAttributes dstAttributes{}; + + ET_CUDA_CHECK_OR_RETURN_ERROR( + cudaPointerGetAttributes(&srcAttributes, src->data_ptr())); + + ET_CUDA_CHECK_OR_RETURN_ERROR( + cudaPointerGetAttributes(&dstAttributes, self->data_ptr())); + + bool srcIsDevice = srcAttributes.type == cudaMemoryTypeDevice; + bool dstIsDevice = dstAttributes.type == cudaMemoryTypeDevice; + + // Check if tensors have the same schema (sizes, strides, dtype) for fast path + bool same_schema = true; + for (int i = 0; i < self->dim(); i++) { + if (self_strides[i] != src_strides[i]) { + same_schema = false; + break; + } + } + + size_t total_bytes = src->nbytes(); + int64_t total_elements = self->numel(); + + if (same_schema) { + // Fast path: Direct memory copy since layouts match exactly + if (srcIsDevice && dstIsDevice) { + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy( + self->mutable_data_ptr(), + src->data_ptr(), + total_bytes, + cudaMemcpyDeviceToDevice)); + } else if (srcIsDevice && !dstIsDevice) { + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy( + self->mutable_data_ptr(), + src->data_ptr(), + total_bytes, + cudaMemcpyDeviceToHost)); + } else if (!srcIsDevice && dstIsDevice) { + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy( + self->mutable_data_ptr(), + src->data_ptr(), + total_bytes, + cudaMemcpyHostToDevice)); + } else { + std::memcpy(self->mutable_data_ptr(), src->data_ptr(), total_bytes); + } + } else { + // Fallback path: Pointwise copy with stride-aware indexing + // This handles arbitrary tensor layouts and strides + + size_t element_size = dtype_to_element_size(self_dtype); + ET_CHECK_OR_RETURN_ERROR( + element_size != 0, + InvalidArgument, + "Invalid element size for dtype: %d", + self_dtype); + + // Allocate temporary host memory for GPU tensors + float* src_host_data = nullptr; + float* dst_host_data = nullptr; + bool need_free_src = false; + bool need_free_dst = false; + + if (srcIsDevice) { + src_host_data = + static_cast(malloc(total_elements * sizeof(float))); + ET_CHECK_OR_RETURN_ERROR( + src_host_data != nullptr, + MemoryAllocationFailed, + "Failed to allocate memory for src_host_data"); + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy( + src_host_data, src->data_ptr(), total_bytes, cudaMemcpyDeviceToHost)); + need_free_src = true; + } else { + src_host_data = static_cast(src->data_ptr()); + } + + if (dstIsDevice) { + dst_host_data = + static_cast(malloc(total_elements * sizeof(float))); + if (dst_host_data == nullptr) { + if (need_free_src) { + free(src_host_data); + } + ET_CHECK_OR_RETURN_ERROR( + false, + MemoryAllocationFailed, + "Failed to allocate memory for dst_host_data"); + } + need_free_dst = true; + } else { + dst_host_data = static_cast(self->mutable_data_ptr()); + } + + // Perform pointwise copy with stride calculation + AOTITorchError copy_err = pointwise_copy_generic( + dst_host_data, + src_host_data, + self_sizes, + self_strides, + src_sizes, + src_strides, + self->dim(), + src->dim(), + total_elements); + + if (copy_err != Error::Ok) { + // Clean up temporary buffers before returning + if (need_free_src) { + free(src_host_data); + } + if (need_free_dst) { + 
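        // Editor's note: dst_host_data is the temporary host staging buffer
        // allocated above for a device-resident destination; release it
        // before propagating the pointwise-copy error.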
free(dst_host_data); + } + return copy_err; + } + + // Copy result back to device if needed + if (dstIsDevice) { + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy( + self->mutable_data_ptr(), + dst_host_data, + total_bytes, + cudaMemcpyHostToDevice)); + } + + // Clean up temporary buffers + if (need_free_src) { + free(src_host_data); + } + if (need_free_dst) { + free(dst_host_data); + } + } + + return Error::Ok; +} + +AOTITorchError aoti_torch__reinterpret_tensor( + Tensor* self, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + Tensor** ret_new_tensor) { + // Validate input parameters first + ET_CHECK_OR_RETURN_ERROR( + self != nullptr, + InvalidArgument, + "aoti_torch__reinterpret_tensor failed: self tensor is null"); + + ET_CHECK_OR_RETURN_ERROR( + !(sizes_ptr == nullptr && ndim > 0), + InvalidArgument, + "aoti_torch__reinterpret_tensor failed: sizes_ptr is null"); + + ET_CHECK_OR_RETURN_ERROR( + ret_new_tensor != nullptr, + InvalidArgument, + "aoti_torch__reinterpret_tensor failed: ret_new_tensor is null"); + + // Check if storage_offset is not 0 - return error if not + ET_CHECK_OK_OR_RETURN_ERROR(validate_storage_offset(storage_offset)); + + // Get the device info from the source tensor to perform device_index + // validation + int32_t device_type = 0; + int32_t device_index = 0; + ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_type(self, &device_type)); + + ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_index(self, &device_index)); + + // Ensure device_index is always 0 + ET_CHECK_OR_RETURN_ERROR( + device_index == 0, + InvalidArgument, + "device_index must be 0, got: %d", + device_index); + + // Get the dtype from the source tensor + int32_t dtype = 0; + ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(self, &dtype)); + + // Validate dtype using SupportedDTypes + ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype)); + + // Get the original data pointer from the source tensor + void* data_ptr = self->mutable_data_ptr(); + ET_CHECK_OR_RETURN_ERROR( + data_ptr != nullptr, + InvalidArgument, + "Source tensor has null data pointer"); + + // Check if the given memory is in the map, if not return error + auto memory_it = memory_to_n_tensor.find(data_ptr); + ET_CHECK_OR_RETURN_ERROR( + memory_it != memory_to_n_tensor.end(), + InvalidArgument, + "Memory address %p is not being tracked by reference counting system", + data_ptr); + + // Convert sizes using utility function from utils.h + std::vector sizes = convert_sizes_to_vector(ndim, sizes_ptr); + + // Convert strides using utility function from utils.h + std::vector strides = + convert_strides_to_vector(ndim, sizes_ptr, strides_ptr); + + // Create new tensor view that reinterprets the same memory with different + // shape/strides This creates a view, not a copy - the data pointer is shared + // Using CUDA-specific tensor maker that supports incontiguous tensors + std::shared_ptr tensor = make_tensor( + sizes, // New sizes with explicit SizesType + data_ptr, // Reuse the same memory from source tensor + {}, // dim_order (empty, will be auto-generated) + strides, // New strides with explicit StridesType + dtype_to_scalar_type(dtype) // Convert dtype with explicit type casting + ); + + ET_CHECK_OR_RETURN_ERROR( + tensor != nullptr, + InvalidArgument, + "Failed to create reinterpreted tensor view"); + + // Store the tensor so it doesn't get destroyed + tensors.insert(tensor); + + *ret_new_tensor = tensor.get(); + + // Increment the reference count for this memory address only if it is owned + // 
by tensor + memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN + ? NOT_OWN + : memory_to_n_tensor[data_ptr] + 1; + + return Error::Ok; +} + +AOTITorchError aoti_torch_new_tensor_handle( + Tensor* orig_handle, + Tensor** new_handle) { + // Validate input parameters + ET_CHECK_OR_RETURN_ERROR( + orig_handle != nullptr, + InvalidArgument, + "aoti_torch_new_tensor_handle failed: orig_handle is null"); + + ET_CHECK_OR_RETURN_ERROR( + new_handle != nullptr, + InvalidArgument, + "aoti_torch_new_tensor_handle failed: new_handle is null"); + + // Get metadata from the original tensor + int64_t* sizes_ptr; + int64_t* strides_ptr; + int32_t dtype; + int32_t device_type; + int32_t device_index; + + ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_sizes(orig_handle, &sizes_ptr)); + ET_CHECK_OK_OR_RETURN_ERROR( + aoti_torch_get_strides(orig_handle, &strides_ptr)); + ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(orig_handle, &dtype)); + ET_CHECK_OK_OR_RETURN_ERROR( + aoti_torch_get_device_type(orig_handle, &device_type)); + ET_CHECK_OK_OR_RETURN_ERROR( + aoti_torch_get_device_index(orig_handle, &device_index)); + + int64_t ndim = orig_handle->dim(); + + // Validate dtype + ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype)); + + // Ensure device_index is always 0 + ET_CHECK_OR_RETURN_ERROR( + device_index == 0, + InvalidArgument, + "device_index must be 0, got: %d", + device_index); + + // Get the original data pointer from the source tensor + void* data_ptr = orig_handle->mutable_data_ptr(); + ET_CHECK_OR_RETURN_ERROR( + data_ptr != nullptr, + InvalidArgument, + "Source tensor has null data pointer"); + + // Check if the given memory is in the map + auto memory_it = memory_to_n_tensor.find(data_ptr); + ET_CHECK_OR_RETURN_ERROR( + memory_it != memory_to_n_tensor.end(), + InvalidArgument, + "Memory address %p is not being tracked by reference counting system", + data_ptr); + + // Convert sizes and strides to vectors + std::vector sizes = convert_sizes_to_vector(ndim, sizes_ptr); + std::vector strides = + convert_strides_to_vector(ndim, sizes_ptr, strides_ptr); + + // Create new tensor that shares the same memory as the original + // This is similar to PyTorch's Tensor copy constructor - creates a new + // tensor object that shares the same underlying storage + std::shared_ptr tensor = make_tensor( + sizes, // Same sizes as original + data_ptr, // Share the same memory from source tensor + {}, // dim_order (empty, will be auto-generated) + strides, // Same strides as original + dtype_to_scalar_type(dtype) // Same dtype as original + ); + + ET_CHECK_OR_RETURN_ERROR( + tensor != nullptr, InvalidArgument, "Failed to create new tensor handle"); + + // Store the tensor so it doesn't get destroyed + tensors.insert(tensor); + + *new_handle = tensor.get(); + + // Increment the reference count for this memory address only if it is owned + // by tensor + memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN + ? NOT_OWN + : memory_to_n_tensor[data_ptr] + 1; + + return Error::Ok; +} +} // extern "C" + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/memory.h b/backends/cuda/runtime/shims/memory.h new file mode 100644 index 00000000000..1a89d8b782c --- /dev/null +++ b/backends/cuda/runtime/shims/memory.h @@ -0,0 +1,172 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include +#include +#include + +namespace executorch::backends::cuda { + +using executorch::backends::aoti::AOTITorchError; +using executorch::backends::aoti::Tensor; + +extern "C" { + +/** + * Creates a tensor object from an existing memory blob without copying the + * data. The tensor will wrap the provided memory and will not take ownership of + * it. When the tensor is deleted, the original memory will remain valid and + * must be freed by the caller. + * + * @param data Pointer to the memory blob to wrap (must not be null) + * @param ndim Number of dimensions in the tensor + * @param sizes_ptr Pointer to array of dimension sizes (using SizesType) + * @param strides_ptr Pointer to array of strides for each dimension (using + * StridesType, can be null for contiguous) + * @param storage_offset Storage offset (must be 0 for current implementation) + * @param dtype Data type identifier (supports FLOAT32 and BFLOAT16 from + * SupportedDTypes) + * @param device_type Device type (CPU=0, CUDA=1 from SupportedDevices) + * @param device_index Device index (must be 0 for current implementation) + * @param ret_new_tensor Output parameter for the created tensor (must not be + * null) + * @param layout Tensor layout identifier (0=strided) + * @param opaque_metadata Optional metadata pointer (can be null) + * @param opaque_metadata_size Size of opaque metadata in bytes + * @return AOTITorchError error code (Error::Ok on success, or an error code on + * failure) + */ +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob_v2( + void* data, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + int32_t dtype, + int32_t device_type, + int32_t device_index, + Tensor** ret_new_tensor, + int32_t layout, + const uint8_t* opaque_metadata, + int64_t opaque_metadata_size); + +/** + * Creates an uninitialized tensor with specified dimensions, strides, and + * dtyper on either CPU or CUDA device. + * + * @param ndim Number of dimensions in the tensor + * @param sizes_ptr Pointer to array of dimension sizes + * @param strides_ptr Pointer to array of strides for each dimension + * @param dtype Data type identifier (matches PyTorch scalar types) + * @param device_type Device type (0=CPU, 1=CUDA) + * @param device_index Device index (must be 0 for current implementation) + * @param ret_new_tensor Output parameter for the created tensor + * @return AOTITorchError error code (Error::Ok on success, or an error code on + * failure) + */ +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_empty_strided( + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int32_t dtype, + int32_t device_type, + int32_t device_index, + Tensor** ret_new_tensor); + +/** + * Deletes a tensor object and frees its associated memory. + * + * @param tensor Pointer to the tensor object to be deleted + * @return AOTITorchError error code (Error::Ok on success, or an error code on + * failure) + */ +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor); + +/** + * Creates a tensor view that reinterprets the same underlying memory with + * different shape and strides without copying data. + * + * Note that the new tensor will not have the ownership of the underlying + * memory. 
+ * + * @param self Input tensor whose memory will be reinterpreted + * @param ndim Number of dimensions for the new tensor view + * @param sizes_ptr Array of sizes for each dimension + * @param strides_ptr Array of strides for each dimension (or nullptr for + * contiguous) + * @param storage_offset Storage offset (must be 0) + * @param ret_new_tensor Output pointer to store the new tensor view + * + * @return Error::Ok on success, appropriate error code on failure + */ +AOTI_SHIM_EXPORT AOTITorchError aoti_torch__reinterpret_tensor( + Tensor* self, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + Tensor** ret_new_tensor); + +/** + * Copies data from source tensor to destination tensor. + * + * This function implements copy function for tensors living in CUDA AOTI + * backend. It supports copying between tensors with different shapes (as long + * as they have the same total number of elements) and different memory + * layouts/strides. + * + * Note that currently this function does not support copying between tensors + * with different dtypes. + * + * @param self Destination tensor (data will be overwritten) + * @param src Source tensor (data will be copied from this tensor) + * @param non_blocking Whether the copy should be non-blocking (currently + * ignored) + * + * @return Error::Ok on success, appropriate error code on failure: + * - Error::InvalidArgument: null pointers, dtype mismatch, numel + * mismatch + * - Error::MemoryAllocationFailed: failed to allocate temporary memory + * - Error::Internal: CUDA operation failures + */ +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking); + +/** + * Creates a new tensor handle from an existing one. + * + * This function creates a new tensor object that shares the same underlying + * memory as the original tensor. Similar to PyTorch's Tensor copy constructor, + * it creates a new handle/reference to the same data without performing a deep + * copy. + * + * The new tensor will: + * - Share the same memory/storage as the original tensor + * - Have the same shape, strides, and dtype as the original + * - Increment the reference count for the underlying memory (if owned) + * + * @param orig_handle Original tensor to create a new handle from (must not be + * null) + * @param new_handle Output pointer to store the new tensor handle (must not be + * null) + * + * @return Error::Ok on success, appropriate error code on failure: + * - Error::InvalidArgument: null pointers or invalid parameters + */ +AOTITorchError aoti_torch_new_tensor_handle( + Tensor* orig_handle, + Tensor** new_handle); + +// Function to clear all tensors from internal storage +AOTI_SHIM_EXPORT void clear_all_tensors(); +} // extern "C" + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/tensor_attribute.cpp b/backends/cuda/runtime/shims/tensor_attribute.cpp new file mode 100644 index 00000000000..1a14c79f9f2 --- /dev/null +++ b/backends/cuda/runtime/shims/tensor_attribute.cpp @@ -0,0 +1,32 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +namespace executorch::backends::cuda { + +extern "C" { + +// Device type functions for tensor attributes +AOTITorchError aoti_torch_get_device_type( + Tensor* tensor, + int32_t* ret_device_type) { + // All tensors in aoti-cuda delegate are on CUDA + *ret_device_type = aoti_torch_device_type_cuda(); + return Error::Ok; +} + +// Device type constants +int32_t aoti_torch_device_type_cuda() { + // Let's say cuda is 1 for ET as well + return 1; +} + +} // extern "C" + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/tensor_attribute.h b/backends/cuda/runtime/shims/tensor_attribute.h new file mode 100644 index 00000000000..683f270ccda --- /dev/null +++ b/backends/cuda/runtime/shims/tensor_attribute.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +namespace executorch::backends::cuda { + +// Common using declarations for ExecutorTorch types +using executorch::runtime::Error; +using executorch::runtime::etensor::Tensor; + +extern "C" { + +// Common AOTI type aliases +using AOTITorchError = Error; + +// Device type functions for tensor attributes +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_device_type(Tensor* tensor, int32_t* ret_device_type); + +// Device type constants +AOTI_SHIM_EXPORT int32_t aoti_torch_device_type_cuda(); + +} // extern "C" + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/tests/TARGETS b/backends/cuda/runtime/shims/tests/TARGETS new file mode 100644 index 00000000000..9ff3e83a8bd --- /dev/null +++ b/backends/cuda/runtime/shims/tests/TARGETS @@ -0,0 +1,6 @@ +load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest") +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl new file mode 100644 index 00000000000..b274ecf3675 --- /dev/null +++ b/backends/cuda/runtime/shims/tests/targets.bzl @@ -0,0 +1,37 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest") +load("@fbcode_macros//build_defs/lib:re_test_utils.bzl", "re_test_utils") + +def cuda_shim_cpp_unittest(name): + cpp_unittest( + name = "test_" + name, + srcs = [ + "test_" + name + ".cpp", + ], + deps = [ + "//executorch/backends/aoti:common_shims", + "//executorch/backends/cuda/runtime:runtime_shims", + "//executorch/extension/tensor:tensor", + "//executorch/runtime/core:core", + "//executorch/runtime/platform:platform", + "//executorch/runtime/core/exec_aten:lib", + ], + external_deps = [ + ("cuda", None, "cuda-lazy"), + ], + ) + +def define_common_targets(): + """Defines targets that should be shared between fbcode and xplat. + + The directory containing this targets.bzl file should also contain both + TARGETS and BUCK files that call this function. 
+ """ + cuda_shim_cpp_unittest("aoti_torch_empty_strided") + cuda_shim_cpp_unittest("aoti_torch_delete_tensor_object") + cuda_shim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2") + cuda_shim_cpp_unittest("aoti_torch__reinterpret_tensor") + cuda_shim_cpp_unittest("aoti_torch_copy_") + cuda_shim_cpp_unittest("aoti_torch_cuda_guard") + cuda_shim_cpp_unittest("aoti_torch_cuda__weight_int4pack_mm") + cuda_shim_cpp_unittest("aoti_torch_new_tensor_handle") diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp new file mode 100644 index 00000000000..d3044810b15 --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp @@ -0,0 +1,812 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace executorch::backends::aoti; +using namespace executorch::backends::cuda; +using namespace executorch::runtime; +using executorch::runtime::etensor::Tensor; + +// Test fixture for aoti_torch__reinterpret_tensor tests +class AOTITorchReinterpretTensorTest : public ::testing::Test { + protected: + void SetUp() override { + // Initialize ExecuTorch Platform Abstraction Layer + et_pal_init(); + + // Check if CUDA is available + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available, skipping CUDA tests"; + } + + // Clean up any existing cached metadata before each test + cleanup_tensor_metadata(); + + // Clear any remaining tensors from previous tests + clear_all_tensors(); + } + + void TearDown() override { + // Clean up metadata + cleanup_tensor_metadata(); + + // Clear the global tensor storage using the provided function + clear_all_tensors(); + } + + // Helper to calculate number of elements from sizes + int64_t calculate_numel(const std::vector& sizes) { + int64_t numel = 1; + for (int64_t size : sizes) { + numel *= size; + } + return numel; + } + + // Helper to calculate contiguous strides from sizes + std::vector calculate_contiguous_strides( + const std::vector& sizes) { + std::vector strides(sizes.size()); + if (sizes.empty()) { + return strides; + } + + strides[sizes.size() - 1] = 1; + for (int64_t i = static_cast(sizes.size()) - 2; i >= 0; i--) { + strides[i] = strides[i + 1] * sizes[i + 1]; + } + return strides; + } + + // Helper to create a source tensor using empty_strided (which allocates new + // memory) + Tensor* create_source_tensor( + const std::vector& sizes, + int32_t dtype = 6, // float32 + int32_t device_type = 1, // CUDA + int32_t device_index = 0) { + std::vector strides = calculate_contiguous_strides(sizes); + + Tensor* tensor; + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + strides.data(), + dtype, + device_type, + device_index, + &tensor); + + if (error != Error::Ok) { + return nullptr; + } + + return tensor; + } + + private: + std::vector cuda_memory_buffers_; + std::vector cpu_memory_buffers_; +}; + +// Test basic functionality: reinterpret tensor with different shapes +TEST_F(AOTITorchReinterpretTensorTest, BasicReinterpretation) { + // Create a source tensor with shape [12] (1D with 12 elements) 
+ std::vector source_sizes = {12}; + Tensor* source_tensor = create_source_tensor(source_sizes); + ASSERT_NE(source_tensor, nullptr); + + // Store the original data pointer + void* original_data_ptr = source_tensor->mutable_data_ptr(); + ASSERT_NE(original_data_ptr, nullptr); + + // Reinterpret as [3, 4] (2D with same number of elements) + std::vector new_sizes = {3, 4}; + std::vector new_strides = calculate_contiguous_strides(new_sizes); + + Tensor* reinterpreted_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + source_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, // storage_offset + &reinterpreted_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(reinterpreted_tensor, nullptr); + + // Check that the reinterpreted tensor has the new shape + EXPECT_EQ(reinterpreted_tensor->dim(), 2); + EXPECT_EQ(reinterpreted_tensor->size(0), 3); + EXPECT_EQ(reinterpreted_tensor->size(1), 4); + + // CRITICAL: Check that the reinterpreted tensor uses the SAME memory + void* reinterpreted_data_ptr = reinterpreted_tensor->mutable_data_ptr(); + EXPECT_EQ(reinterpreted_data_ptr, original_data_ptr) + << "Reinterpreted tensor should use the same memory as the source tensor"; + + // Write data through the original tensor and verify it's visible through the + // reinterpreted tensor + std::vector test_data = { + 1.0f, + 2.0f, + 3.0f, + 4.0f, + 5.0f, + 6.0f, + 7.0f, + 8.0f, + 9.0f, + 10.0f, + 11.0f, + 12.0f}; + cudaError_t cuda_err = cudaMemcpy( + original_data_ptr, + test_data.data(), + test_data.size() * sizeof(float), + cudaMemcpyHostToDevice); + EXPECT_EQ(cuda_err, cudaSuccess); + + // Read back through the reinterpreted tensor + std::vector readback_data(12); + cuda_err = cudaMemcpy( + readback_data.data(), + reinterpreted_data_ptr, + readback_data.size() * sizeof(float), + cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess); + + // Verify the data matches + for (size_t i = 0; i < test_data.size(); i++) { + EXPECT_EQ(readback_data[i], test_data[i]) + << "Data should be the same through both tensors at index " << i; + } +} + +// Test reinterpreting with different strides +TEST_F(AOTITorchReinterpretTensorTest, ReinterpretWithCustomStrides) { + // Create a source tensor with shape [2, 6] (contiguous) + std::vector source_sizes = {2, 6}; + Tensor* source_tensor = create_source_tensor(source_sizes); + ASSERT_NE(source_tensor, nullptr); + + void* original_data_ptr = source_tensor->mutable_data_ptr(); + ASSERT_NE(original_data_ptr, nullptr); + + // Reinterpret as [3, 4] with custom strides (still valid for the same memory) + std::vector new_sizes = {3, 4}; + std::vector new_strides = {4, 1}; // Row-major strides for [3, 4] + + Tensor* reinterpreted_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + source_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, // storage_offset + &reinterpreted_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(reinterpreted_tensor, nullptr); + + // Check shape + EXPECT_EQ(reinterpreted_tensor->dim(), 2); + EXPECT_EQ(reinterpreted_tensor->size(0), 3); + EXPECT_EQ(reinterpreted_tensor->size(1), 4); + + // CRITICAL: Check that the reinterpreted tensor uses the SAME memory + void* reinterpreted_data_ptr = reinterpreted_tensor->mutable_data_ptr(); + EXPECT_EQ(reinterpreted_data_ptr, original_data_ptr) + << "Reinterpreted tensor should use the same memory as the source tensor"; + + // Verify strides were set correctly + int64_t* tensor_strides; + error = 
aoti_torch_get_strides(reinterpreted_tensor, &tensor_strides); + EXPECT_EQ(error, Error::Ok); + EXPECT_EQ(tensor_strides[0], 4); + EXPECT_EQ(tensor_strides[1], 1); +} + +// Test error cases: null input tensor +TEST_F(AOTITorchReinterpretTensorTest, NullInputTensor) { + std::vector new_sizes = {2, 3}; + std::vector new_strides = calculate_contiguous_strides(new_sizes); + + Tensor* reinterpreted_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + nullptr, // null input tensor + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, // storage_offset + &reinterpreted_tensor); + + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test error cases: null sizes pointer +TEST_F(AOTITorchReinterpretTensorTest, NullSizesPointer) { + std::vector source_sizes = {6}; + Tensor* source_tensor = create_source_tensor(source_sizes); + ASSERT_NE(source_tensor, nullptr); + + std::vector new_strides = {2, 1}; + + Tensor* reinterpreted_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + source_tensor, + 2, // ndim > 0 + nullptr, // null sizes pointer + new_strides.data(), + 0, // storage_offset + &reinterpreted_tensor); + + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test error cases: null return tensor pointer +TEST_F(AOTITorchReinterpretTensorTest, NullReturnTensorPointer) { + std::vector source_sizes = {6}; + Tensor* source_tensor = create_source_tensor(source_sizes); + ASSERT_NE(source_tensor, nullptr); + + std::vector new_sizes = {2, 3}; + std::vector new_strides = calculate_contiguous_strides(new_sizes); + + AOTITorchError error = aoti_torch__reinterpret_tensor( + source_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, // storage_offset + nullptr); // null return tensor pointer + + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test error cases: non-zero storage offset (should fail) +TEST_F(AOTITorchReinterpretTensorTest, NonZeroStorageOffset) { + std::vector source_sizes = {6}; + Tensor* source_tensor = create_source_tensor(source_sizes); + ASSERT_NE(source_tensor, nullptr); + + std::vector new_sizes = {2, 3}; + std::vector new_strides = calculate_contiguous_strides(new_sizes); + + Tensor* reinterpreted_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + source_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 1, // non-zero storage_offset (should fail) + &reinterpreted_tensor); + + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test reinterpreting CPU tensor +TEST_F(AOTITorchReinterpretTensorTest, ReinterpretCPUTensor) { + // Create a CPU tensor with shape [8] + std::vector source_sizes = {8}; + Tensor* source_tensor = create_source_tensor( + source_sizes, + 6, // float32 + 0, // CPU device + 0); + ASSERT_NE(source_tensor, nullptr); + + void* original_data_ptr = source_tensor->mutable_data_ptr(); + ASSERT_NE(original_data_ptr, nullptr); + + // Reinterpret as [2, 4] + std::vector new_sizes = {2, 4}; + std::vector new_strides = calculate_contiguous_strides(new_sizes); + + Tensor* reinterpreted_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + source_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, // storage_offset + &reinterpreted_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(reinterpreted_tensor, nullptr); + + // Check that the reinterpreted tensor uses the SAME memory + void* reinterpreted_data_ptr = reinterpreted_tensor->mutable_data_ptr(); + EXPECT_EQ(reinterpreted_data_ptr, original_data_ptr) + << "Reinterpreted CPU tensor should 
use the same memory as the source tensor"; + + // Test direct memory access for CPU tensors + float* original_float_ptr = reinterpret_cast(original_data_ptr); + float* reinterpreted_float_ptr = + reinterpret_cast(reinterpreted_data_ptr); + + // Write through original and read through reinterpreted + original_float_ptr[0] = 42.0f; + EXPECT_EQ(reinterpreted_float_ptr[0], 42.0f) + << "Changes through original tensor should be visible through reinterpreted tensor"; +} + +// Test that deleting source tensor doesn't affect reinterpreted tensor (they +// share memory) +TEST_F(AOTITorchReinterpretTensorTest, DeletionBehavior) { + std::vector source_sizes = {6}; + Tensor* source_tensor = create_source_tensor(source_sizes); + ASSERT_NE(source_tensor, nullptr); + + void* shared_data_ptr = source_tensor->mutable_data_ptr(); + + // Reinterpret as [2, 3] + std::vector new_sizes = {2, 3}; + std::vector new_strides = calculate_contiguous_strides(new_sizes); + + Tensor* reinterpreted_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + source_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, + &reinterpreted_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(reinterpreted_tensor, nullptr); + + // Verify they share the same memory + EXPECT_EQ(reinterpreted_tensor->mutable_data_ptr(), shared_data_ptr); + + // Delete the source tensor (which owns the memory) + error = aoti_torch_delete_tensor_object(source_tensor); + EXPECT_EQ(error, Error::Ok); + + // The reinterpreted tensor should still be valid but the memory might be + // freed Since the source tensor owned the memory, the reinterpreted tensor + // becomes invalid This is expected behavior - the user needs to manage the + // lifecycle properly + + // Clean up the reinterpreted tensor + error = aoti_torch_delete_tensor_object(reinterpreted_tensor); + EXPECT_EQ(error, Error::Ok); +} + +// Test scalar tensor reinterpretation +TEST_F(AOTITorchReinterpretTensorTest, ReinterpretScalarTensor) { + // Create a scalar tensor (0D) + std::vector source_sizes = {}; + Tensor* source_tensor = create_source_tensor(source_sizes); + ASSERT_NE(source_tensor, nullptr); + + void* original_data_ptr = source_tensor->mutable_data_ptr(); + + // Try to reinterpret scalar as [1] (1D with 1 element) + std::vector new_sizes = {1}; + std::vector new_strides = {1}; + + Tensor* reinterpreted_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + source_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, + &reinterpreted_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(reinterpreted_tensor, nullptr); + + // Check that the reinterpreted tensor uses the SAME memory + EXPECT_EQ(reinterpreted_tensor->mutable_data_ptr(), original_data_ptr); + + // Check new shape + EXPECT_EQ(reinterpreted_tensor->dim(), 1); + EXPECT_EQ(reinterpreted_tensor->size(0), 1); +} + +// Test reinterpreting tensor with zero-sized dimension +// TODO: This test is disabled because zero-sized tensors have complex stride +// validation requirements that need further investigation +TEST_F(AOTITorchReinterpretTensorTest, DISABLED_ReinterpretZeroSizedTensor) { + // Create a tensor with shape [0, 5] (zero elements) + std::vector source_sizes = {0, 5}; + Tensor* source_tensor = create_source_tensor(source_sizes); + ASSERT_NE(source_tensor, nullptr); + + void* original_data_ptr = source_tensor->mutable_data_ptr(); + + // Reinterpret as [5, 0] (still zero elements) + std::vector new_sizes = {5, 0}; + std::vector new_strides = 
calculate_contiguous_strides(new_sizes); + + Tensor* reinterpreted_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + source_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, + &reinterpreted_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(reinterpreted_tensor, nullptr); + + // Check that the reinterpreted tensor uses the SAME memory + EXPECT_EQ(reinterpreted_tensor->mutable_data_ptr(), original_data_ptr); + + // Check new shape + EXPECT_EQ(reinterpreted_tensor->dim(), 2); + EXPECT_EQ(reinterpreted_tensor->size(0), 5); + EXPECT_EQ(reinterpreted_tensor->size(1), 0); +} + +// Test with nullptr strides (should use contiguous strides) +TEST_F(AOTITorchReinterpretTensorTest, NullStridesPointer) { + std::vector source_sizes = {12}; + Tensor* source_tensor = create_source_tensor(source_sizes); + ASSERT_NE(source_tensor, nullptr); + + void* original_data_ptr = source_tensor->mutable_data_ptr(); + + // Reinterpret as [3, 4] with null strides (should calculate contiguous + // strides) + std::vector new_sizes = {3, 4}; + + Tensor* reinterpreted_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + source_tensor, + new_sizes.size(), + new_sizes.data(), + nullptr, // null strides - should calculate contiguous strides + 0, + &reinterpreted_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(reinterpreted_tensor, nullptr); + + // Check that the reinterpreted tensor uses the SAME memory + EXPECT_EQ(reinterpreted_tensor->mutable_data_ptr(), original_data_ptr); + + // Check that contiguous strides were calculated correctly + int64_t* tensor_strides; + error = aoti_torch_get_strides(reinterpreted_tensor, &tensor_strides); + EXPECT_EQ(error, Error::Ok); + EXPECT_EQ(tensor_strides[0], 4); // stride for dimension 0 should be 4 + EXPECT_EQ(tensor_strides[1], 1); // stride for dimension 1 should be 1 +} + +// Test bf16 tensor reinterpretation +TEST_F(AOTITorchReinterpretTensorTest, ReinterpretBF16Tensor) { + // Create a bf16 source tensor with shape [6] + std::vector source_sizes = {6}; + Tensor* source_tensor = create_source_tensor( + source_sizes, + static_cast( + SupportedDTypes::BFLOAT16), // bf16 dtype from SupportedDTypes + static_cast( + SupportedDevices::CUDA), // CUDA device from SupportedDevices + 0); // device_index must be 0 + ASSERT_NE(source_tensor, nullptr); + + void* original_data_ptr = source_tensor->mutable_data_ptr(); + ASSERT_NE(original_data_ptr, nullptr); + + // Verify the tensor is actually bf16 + int32_t actual_dtype = 0; + AOTITorchError dtype_check_error = + aoti_torch_get_dtype(source_tensor, &actual_dtype); + EXPECT_EQ(dtype_check_error, Error::Ok); + EXPECT_EQ(actual_dtype, static_cast(SupportedDTypes::BFLOAT16)) + << "Source tensor should have bfloat16 dtype"; + + // Reinterpret as [2, 3] (same number of elements) + std::vector new_sizes = {2, 3}; + std::vector new_strides = calculate_contiguous_strides(new_sizes); + + Tensor* reinterpreted_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + source_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, // storage_offset + &reinterpreted_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(reinterpreted_tensor, nullptr); + + // Check that the reinterpreted tensor has the new shape + EXPECT_EQ(reinterpreted_tensor->dim(), 2); + EXPECT_EQ(reinterpreted_tensor->size(0), 2); + EXPECT_EQ(reinterpreted_tensor->size(1), 3); + + // Verify the dtype is preserved as bf16 + int32_t reinterpreted_dtype = 0; + dtype_check_error = + 
aoti_torch_get_dtype(reinterpreted_tensor, &reinterpreted_dtype); + EXPECT_EQ(dtype_check_error, Error::Ok); + EXPECT_EQ( + reinterpreted_dtype, static_cast(SupportedDTypes::BFLOAT16)) + << "Reinterpreted tensor should preserve bfloat16 dtype"; + + // CRITICAL: Check that the reinterpreted tensor uses the SAME memory + void* reinterpreted_data_ptr = reinterpreted_tensor->mutable_data_ptr(); + EXPECT_EQ(reinterpreted_data_ptr, original_data_ptr) + << "Reinterpreted tensor should use the same memory as the source tensor"; + + // Test memory sharing by writing data through the original tensor + // and verifying it's visible through the reinterpreted tensor + // Note: bf16 has 2 bytes per element + std::vector test_data_bf16 = { + 0x3F80, 0x4000, 0x4040, 0x4080, 0x40A0, 0x40C0}; // bf16 values + cudaError_t cuda_err = cudaMemcpy( + original_data_ptr, + test_data_bf16.data(), + test_data_bf16.size() * sizeof(uint16_t), + cudaMemcpyHostToDevice); + EXPECT_EQ(cuda_err, cudaSuccess); + + // Read back through the reinterpreted tensor + std::vector readback_data_bf16(6); + cuda_err = cudaMemcpy( + readback_data_bf16.data(), + reinterpreted_data_ptr, + readback_data_bf16.size() * sizeof(uint16_t), + cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess); + + // Verify the data matches + for (size_t i = 0; i < test_data_bf16.size(); i++) { + EXPECT_EQ(readback_data_bf16[i], test_data_bf16[i]) + << "BF16 data should be the same through both tensors at index " << i; + } +} + +// Test reference counting behavior - memory not in map should fail +TEST_F(AOTITorchReinterpretTensorTest, MemoryNotInMapShouldFail) { + // Create a tensor directly without using our allocation functions + // This should NOT be in the reference counting map + void* external_memory; + ASSERT_EQ( + cudaMallocManaged(&external_memory, 12 * sizeof(float)), cudaSuccess); + + // Create a tensor by manually wrapping this memory without going through our + // APIs + std::vector sizes = {12}; + std::vector strides = calculate_contiguous_strides(sizes); + + // Create the tensor directly using ExecutorTorch extension + auto tensor_shared = executorch::extension::from_blob( + external_memory, + convert_sizes_to_vector(sizes.size(), sizes.data()), + convert_strides_to_vector(sizes.size(), sizes.data(), strides.data()), + executorch::runtime::etensor::ScalarType::Float); + + ASSERT_TRUE(tensor_shared); + Tensor* external_tensor = tensor_shared.get(); + + // Try to reinterpret this tensor - should fail because memory is not in map + std::vector new_sizes = {3, 4}; + std::vector new_strides = calculate_contiguous_strides(new_sizes); + + Tensor* reinterpreted_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + external_tensor, + new_sizes.size(), + new_sizes.data(), + new_strides.data(), + 0, // storage_offset + &reinterpreted_tensor); + + // Should fail because memory is not being tracked by reference counting + // system + EXPECT_EQ(error, Error::InvalidArgument); + + // Clean up the external memory + ASSERT_EQ(cudaFree(external_memory), cudaSuccess); +} + +// Test reference counting behavior - creating view increments reference count +TEST_F(AOTITorchReinterpretTensorTest, ViewCreationIncrementsReferenceCount) { + // Create a source tensor that owns memory (reference count = 1) + std::vector source_sizes = {12}; + Tensor* source_tensor = create_source_tensor(source_sizes); + ASSERT_NE(source_tensor, nullptr); + + void* shared_data_ptr = source_tensor->mutable_data_ptr(); + ASSERT_NE(shared_data_ptr, nullptr); + + // Create 
first view - should increment reference count to 2 + std::vector view1_sizes = {3, 4}; + std::vector view1_strides = + calculate_contiguous_strides(view1_sizes); + + Tensor* view1_tensor; + AOTITorchError error = aoti_torch__reinterpret_tensor( + source_tensor, + view1_sizes.size(), + view1_sizes.data(), + view1_strides.data(), + 0, + &view1_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(view1_tensor, nullptr); + EXPECT_EQ(view1_tensor->mutable_data_ptr(), shared_data_ptr); + + // Create second view - should increment reference count to 3 + std::vector view2_sizes = {2, 6}; + std::vector view2_strides = + calculate_contiguous_strides(view2_sizes); + + Tensor* view2_tensor; + error = aoti_torch__reinterpret_tensor( + source_tensor, + view2_sizes.size(), + view2_sizes.data(), + view2_strides.data(), + 0, + &view2_tensor); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(view2_tensor, nullptr); + EXPECT_EQ(view2_tensor->mutable_data_ptr(), shared_data_ptr); + + // Now delete the source tensor - memory should NOT be freed (reference count + // = 2) + error = aoti_torch_delete_tensor_object(source_tensor); + EXPECT_EQ(error, Error::Ok); + + // Both views should still be valid - test by accessing memory + float test_value = 42.0f; + cudaError_t cuda_err = cudaMemcpy( + shared_data_ptr, &test_value, sizeof(float), cudaMemcpyHostToDevice); + EXPECT_EQ(cuda_err, cudaSuccess); + + float readback_value = 0.0f; + cuda_err = cudaMemcpy( + &readback_value, + view1_tensor->mutable_data_ptr(), + sizeof(float), + cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess); + EXPECT_EQ(readback_value, test_value); + + // Delete first view - memory should still NOT be freed (reference count = 1) + error = aoti_torch_delete_tensor_object(view1_tensor); + EXPECT_EQ(error, Error::Ok); + + // Second view should still be valid + readback_value = 0.0f; + cuda_err = cudaMemcpy( + &readback_value, + view2_tensor->mutable_data_ptr(), + sizeof(float), + cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess); + EXPECT_EQ(readback_value, test_value); + + // Delete second view - NOW memory should be freed (reference count = 0) + error = aoti_torch_delete_tensor_object(view2_tensor); + EXPECT_EQ(error, Error::Ok); +} + +// Test reference counting behavior with NOT_OWN memory (from blob) - should +// SUCCEED and keep NOT_OWN +TEST_F(AOTITorchReinterpretTensorTest, ViewOfNotOwnMemoryKeepsNotOwnStatus) { + // Allocate external memory + void* external_memory; + cudaError_t cuda_err = + cudaMallocManaged(&external_memory, 12 * sizeof(float)); + ASSERT_EQ(cuda_err, cudaSuccess); + + // Create tensor from blob (which marks memory as NOT_OWN) + std::vector blob_sizes = {12}; + std::vector blob_strides = calculate_contiguous_strides(blob_sizes); + + Tensor* blob_tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + external_memory, + blob_sizes.size(), + blob_sizes.data(), + blob_strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device_index + &blob_tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(blob_tensor, nullptr); + + // Create view of NOT_OWN memory - should SUCCEED and keep NOT_OWN status + std::vector view_sizes = {3, 4}; + std::vector view_strides = calculate_contiguous_strides(view_sizes); + + Tensor* view_tensor; + error = aoti_torch__reinterpret_tensor( + blob_tensor, + view_sizes.size(), + view_sizes.data(), + view_strides.data(), + 0, 
+ &view_tensor); + + // Should succeed - NOT_OWN memory can be reinterpreted but stays NOT_OWN + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(view_tensor, nullptr); + EXPECT_EQ(view_tensor->mutable_data_ptr(), external_memory); + + // Verify both tensors share the same memory + EXPECT_EQ(blob_tensor->mutable_data_ptr(), view_tensor->mutable_data_ptr()); + + // Test memory sharing by writing data through one tensor and reading through + // the other + float test_value = 42.0f; + cuda_err = cudaMemcpy( + external_memory, &test_value, sizeof(float), cudaMemcpyHostToDevice); + EXPECT_EQ(cuda_err, cudaSuccess); + + float readback_value = 0.0f; + cuda_err = cudaMemcpy( + &readback_value, + view_tensor->mutable_data_ptr(), + sizeof(float), + cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess); + EXPECT_EQ(readback_value, test_value); + + // Delete the blob tensor - external memory should NOT be freed (NOT_OWN + // behavior) + error = aoti_torch_delete_tensor_object(blob_tensor); + EXPECT_EQ(error, Error::Ok); + + // View tensor should still be valid - test by accessing memory + readback_value = 0.0f; + cuda_err = cudaMemcpy( + &readback_value, + view_tensor->mutable_data_ptr(), + sizeof(float), + cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess); + EXPECT_EQ(readback_value, test_value); + + // Delete view tensor - external memory should still NOT be freed (NOT_OWN + // behavior) + error = aoti_torch_delete_tensor_object(view_tensor); + EXPECT_EQ(error, Error::Ok); + + // External memory should still be accessible (proves neither tensor freed it) + readback_value = 0.0f; + cuda_err = cudaMemcpy( + &readback_value, external_memory, sizeof(float), cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess); + EXPECT_EQ(readback_value, test_value); + + // Clean up external memory manually (as expected for NOT_OWN memory) + ASSERT_EQ(cudaFree(external_memory), cudaSuccess); +} diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_copy_.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_copy_.cpp new file mode 100644 index 00000000000..9fca0f92cf8 --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_copy_.cpp @@ -0,0 +1,398 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace executorch::backends::cuda; +using namespace executorch::backends::aoti; +using namespace executorch::runtime; + +// Test fixture for aoti_torch_copy_ tests +class AOTITorchCopyTest : public ::testing::Test { + protected: + void SetUp() override { + // Initialize ExecuTorch Platform Abstraction Layer + et_pal_init(); + + // Check if CUDA is available + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available, skipping CUDA tests"; + } + + // Clean up any existing cached metadata before each test + cleanup_tensor_metadata(); + + // Clear any remaining tensors from previous tests + clear_all_tensors(); + } + + void TearDown() override { + // Clean up metadata + cleanup_tensor_metadata(); + + // Clear the global tensor storage using the provided function + clear_all_tensors(); + } + + // Helper to create test tensors with specific data + Tensor* create_test_tensor_with_data( + const std::vector& sizes, + const std::vector& data, + const std::vector& strides = {}, + int32_t dtype = static_cast(SupportedDTypes::FLOAT32), + int32_t device_type = static_cast(SupportedDevices::CUDA), + int32_t device_index = 0) { + Tensor* tensor; + + const int64_t* strides_ptr = strides.empty() ? nullptr : strides.data(); + + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + strides_ptr, + dtype, + device_type, + device_index, + &tensor); + + if (error != Error::Ok || tensor == nullptr) { + return nullptr; + } + + // Fill tensor with data + size_t total_bytes = data.size() * sizeof(float); + if (device_type == static_cast(SupportedDevices::CUDA)) { + cudaError_t memcpy_err = cudaMemcpy( + tensor->mutable_data_ptr(), + data.data(), + total_bytes, + cudaMemcpyHostToDevice); + // Note: Error is checked but we don't fail the function + // This allows tests to proceed and handle errors as needed + (void)memcpy_err; // Suppress unused variable warning + } else { // CPU + std::memcpy(tensor->mutable_data_ptr(), data.data(), total_bytes); + } + + return tensor; + } + + // Helper to get data from tensor + std::vector get_tensor_data(Tensor* tensor) { + if (!tensor) { + return {}; + } + + size_t num_elements = tensor->numel(); + std::vector data(num_elements); + + // Determine if this is a CUDA tensor + cudaPointerAttributes attributes{}; + cudaError_t err = cudaPointerGetAttributes(&attributes, tensor->data_ptr()); + bool is_device = + (err == cudaSuccess && attributes.type == cudaMemoryTypeDevice); + + if (is_device) { + cudaError_t memcpy_err = cudaMemcpy( + data.data(), + tensor->data_ptr(), + num_elements * sizeof(float), + cudaMemcpyDeviceToHost); + // Note: Error is checked but we don't fail the function + // This allows tests to proceed and handle errors as needed + (void)memcpy_err; // Suppress unused variable warning + } else { + std::memcpy( + data.data(), tensor->data_ptr(), num_elements * sizeof(float)); + } + + return data; + } + + // Helper to verify two tensors have same data + bool tensors_equal(Tensor* a, Tensor* b, float tolerance = 1e-6f) { + if (!a || !b) { + return false; + } + if (a->numel() != b->numel()) { + return false; + } + + auto data_a = get_tensor_data(a); + auto data_b = get_tensor_data(b); + + for (size_t i = 0; i < data_a.size(); ++i) { + if (std::abs(data_a[i] - data_b[i]) > tolerance) { + return false; + } + } + return true; 
+ } +}; + +// Test basic copy functionality - same schema (fast path) +TEST_F(AOTITorchCopyTest, BasicCopySameSchema) { + // Create source tensor with test data + std::vector sizes = {2, 3}; + std::vector src_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + Tensor* src = create_test_tensor_with_data(sizes, src_data); + EXPECT_NE(src, nullptr); + + // Create destination tensor with same schema + Tensor* dst = + create_test_tensor_with_data(sizes, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}); + EXPECT_NE(dst, nullptr); + + // Perform copy + AOTITorchError error = aoti_torch_copy_(dst, src, 0); + EXPECT_EQ(error, Error::Ok); + + // Verify copy was successful + EXPECT_TRUE(tensors_equal(dst, src)); +} + +// Test copy with different strides (pointwise fallback) +TEST_F(AOTITorchCopyTest, CopyDifferentStrides) { + // Create source tensor (2x3) with contiguous layout + std::vector src_sizes = {2, 3}; + std::vector src_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + Tensor* src = create_test_tensor_with_data(src_sizes, src_data); + EXPECT_NE(src, nullptr); + + // Create destination tensor with transposed strides + std::vector dst_strides = {1, 2}; // Column-major layout + Tensor* dst = create_test_tensor_with_data( + src_sizes, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, dst_strides); + EXPECT_NE(dst, nullptr); + + // Perform copy - this should use pointwise fallback + AOTITorchError error = aoti_torch_copy_(dst, src, 0); + EXPECT_EQ(error, Error::Ok); + + // Verify the copy worked correctly by checking specific elements + auto dst_data = get_tensor_data(dst); + auto src_data_check = get_tensor_data(src); + + // For transposed layout, the data should be rearranged + EXPECT_EQ(dst_data.size(), 6); + EXPECT_EQ(src_data_check.size(), 6); +} + +// Test copy between CPU and CUDA tensors +TEST_F(AOTITorchCopyTest, CopyCPUToCUDA) { + std::vector sizes = {2, 2}; + std::vector data = {1.0f, 2.0f, 3.0f, 4.0f}; + + // Create CPU tensor + Tensor* cpu_tensor = create_test_tensor_with_data( + sizes, + data, + {}, + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CPU)); // CPU + EXPECT_NE(cpu_tensor, nullptr); + + // Create CUDA tensor + Tensor* cuda_tensor = create_test_tensor_with_data( + sizes, + {0.0f, 0.0f, 0.0f, 0.0f}, + {}, + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA)); // CUDA + EXPECT_NE(cuda_tensor, nullptr); + + // Copy from CPU to CUDA + AOTITorchError error = aoti_torch_copy_(cuda_tensor, cpu_tensor, 0); + EXPECT_EQ(error, Error::Ok); + + // Verify copy + EXPECT_TRUE(tensors_equal(cuda_tensor, cpu_tensor)); +} + +// Test copy between CUDA and CPU tensors +TEST_F(AOTITorchCopyTest, CopyCUDAToCPU) { + std::vector sizes = {2, 2}; + std::vector data = {1.0f, 2.0f, 3.0f, 4.0f}; + + // Create CUDA tensor + Tensor* cuda_tensor = create_test_tensor_with_data( + sizes, + data, + {}, + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA)); // CUDA + EXPECT_NE(cuda_tensor, nullptr); + + // Create CPU tensor + Tensor* cpu_tensor = create_test_tensor_with_data( + sizes, + {0.0f, 0.0f, 0.0f, 0.0f}, + {}, + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CPU)); // CPU + EXPECT_NE(cpu_tensor, nullptr); + + // Copy from CUDA to CPU + AOTITorchError error = aoti_torch_copy_(cpu_tensor, cuda_tensor, 0); + EXPECT_EQ(error, Error::Ok); + + // Verify copy + EXPECT_TRUE(tensors_equal(cpu_tensor, cuda_tensor)); +} + +// Test copy with bf16 dtype support +TEST_F(AOTITorchCopyTest, CopyBf16Tensors) { + // Test that bf16 tensors can be 
created and copied + std::vector sizes = {2, 3}; + std::vector src_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // Note: We create float32 data but the tensor will be created with bf16 dtype + // This simulates creating bf16 tensors + Tensor* src = create_test_tensor_with_data( + sizes, + src_data, + {}, // default strides + static_cast(SupportedDTypes::BFLOAT16), // bf16 dtype + static_cast(SupportedDevices::CUDA), // CUDA device + 0 // device_index = 0 + ); + EXPECT_NE(src, nullptr); + + // Create destination tensor with bf16 dtype + std::vector dst_init(6, 0.0f); + Tensor* dst = create_test_tensor_with_data( + sizes, + dst_init, + {}, // default strides + static_cast(SupportedDTypes::BFLOAT16), // bf16 dtype + static_cast(SupportedDevices::CUDA), // CUDA device + 0 // device_index = 0 + ); + EXPECT_NE(dst, nullptr); + + // Perform copy between bf16 tensors + AOTITorchError error = aoti_torch_copy_(dst, src, 0); + EXPECT_EQ(error, Error::Ok); + + // Verify that both tensors have the expected dtype + int32_t src_dtype, dst_dtype; + aoti_torch_get_dtype(src, &src_dtype); + aoti_torch_get_dtype(dst, &dst_dtype); + + EXPECT_EQ(src_dtype, static_cast(SupportedDTypes::BFLOAT16)); + EXPECT_EQ(dst_dtype, static_cast(SupportedDTypes::BFLOAT16)); + + // Verify copy was successful by checking numel matches + EXPECT_EQ(src->numel(), dst->numel()); + EXPECT_EQ(src->numel(), 6); +} + +// Test copy between different dtypes should fail +TEST_F(AOTITorchCopyTest, CopyDTypeMismatchError) { + std::vector sizes = {2, 2}; + std::vector data = {1.0f, 2.0f, 3.0f, 4.0f}; + + // Create float32 tensor + Tensor* float32_tensor = create_test_tensor_with_data( + sizes, + data, + {}, // default strides + static_cast(SupportedDTypes::FLOAT32), // float32 dtype + static_cast(SupportedDevices::CUDA), // CUDA device + 0 // device_index = 0 + ); + EXPECT_NE(float32_tensor, nullptr); + + // Create bf16 tensor + Tensor* bf16_tensor = create_test_tensor_with_data( + sizes, + {0.0f, 0.0f, 0.0f, 0.0f}, + {}, // default strides + static_cast(SupportedDTypes::BFLOAT16), // bf16 dtype + static_cast(SupportedDevices::CUDA), // CUDA device + 0 // device_index = 0 + ); + EXPECT_NE(bf16_tensor, nullptr); + + // Attempting to copy between different dtypes should fail + AOTITorchError error = aoti_torch_copy_(bf16_tensor, float32_tensor, 0); + EXPECT_EQ(error, Error::InvalidArgument); + + // Reverse direction should also fail + error = aoti_torch_copy_(float32_tensor, bf16_tensor, 0); + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test error conditions +TEST_F(AOTITorchCopyTest, ErrorHandling) { + std::vector sizes = {2, 3}; + std::vector data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + Tensor* valid_tensor = create_test_tensor_with_data(sizes, data); + EXPECT_NE(valid_tensor, nullptr); + + // Test null pointers + AOTITorchError error = aoti_torch_copy_(nullptr, valid_tensor, 0); + EXPECT_NE(error, Error::Ok); + + error = aoti_torch_copy_(valid_tensor, nullptr, 0); + EXPECT_NE(error, Error::Ok); + + // Test numel mismatch (different total number of elements) + std::vector different_numel_sizes = { + 2, 3, 4}; // 24 elements vs 6 elements + std::vector different_data(24, 1.0f); + Tensor* different_numel = + create_test_tensor_with_data(different_numel_sizes, different_data); + EXPECT_NE(different_numel, nullptr); + + error = aoti_torch_copy_(valid_tensor, different_numel, 0); + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test copy from 1D to 3D with same total elements +TEST_F(AOTITorchCopyTest, Copy1DTo3DSameNumel) { + 
// Source tensor: 8 elements in 1D + std::vector src_sizes = {8}; + std::vector src_data = { + 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + + Tensor* src = create_test_tensor_with_data(src_sizes, src_data); + EXPECT_NE(src, nullptr); + + // Destination tensor: 2x2x2 = 8 elements (different shape, same total) + std::vector dst_sizes = {2, 2, 2}; + std::vector dst_init(8, 0.0f); + Tensor* dst = create_test_tensor_with_data(dst_sizes, dst_init); + EXPECT_NE(dst, nullptr); + + // This should work - same total number of elements + AOTITorchError error = aoti_torch_copy_(dst, src, 0); + EXPECT_EQ(error, Error::Ok); + + // Verify the data was copied correctly + auto dst_data = get_tensor_data(dst); + EXPECT_EQ(dst_data.size(), 8); + + // Check some specific elements to verify correct copying + EXPECT_FLOAT_EQ(dst_data[0], 1.0f); + EXPECT_FLOAT_EQ(dst_data[7], 8.0f); +} diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp new file mode 100644 index 00000000000..d9b785a5a78 --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp @@ -0,0 +1,754 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace executorch::backends::aoti; +using namespace executorch::backends::cuda; +using namespace executorch::runtime; +using executorch::runtime::etensor::Tensor; + +// Test fixture for aoti_torch_create_tensor_from_blob_v2 tests +class AOTITorchCreateTensorFromBlobV2Test : public ::testing::Test { + protected: + void SetUp() override { + // Initialize ExecuTorch Platform Abstraction Layer + et_pal_init(); + + // Check if CUDA is available + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available, skipping CUDA tests"; + } + + // Clean up any existing cached metadata before each test + cleanup_tensor_metadata(); + + // Clear any remaining tensors from previous tests + clear_all_tensors(); + } + + void TearDown() override { + // Clean up metadata + cleanup_tensor_metadata(); + + // Clear the global tensor storage using the provided function + clear_all_tensors(); + + // Clean up any allocated memory buffers + for (void* ptr : cuda_memory_buffers_) { + if (ptr) { + cudaError_t cuda_err = cudaFree(ptr); + EXPECT_EQ(cuda_err, cudaSuccess) + << "Failed to free CUDA memory: " << cudaGetErrorString(cuda_err); + } + } + cuda_memory_buffers_.clear(); + + for (void* ptr : cpu_memory_buffers_) { + if (ptr) { + free(ptr); + } + } + cpu_memory_buffers_.clear(); + } + + // Helper to allocate CUDA memory and track it for cleanup + void* allocate_cuda_memory(size_t bytes) { + void* ptr; + cudaError_t err = cudaMallocManaged(&ptr, bytes); + if (err == cudaSuccess) { + cuda_memory_buffers_.push_back(ptr); + return ptr; + } + return nullptr; + } + + // Helper to allocate CPU memory and track it for cleanup + void* allocate_cpu_memory(size_t bytes) { + void* ptr; + int result = posix_memalign(&ptr, 16, bytes); // 16-byte aligned + if (result == 0 && ptr != nullptr) { + cpu_memory_buffers_.push_back(ptr); + return ptr; + } + return nullptr; + } + + // Helper to calculate 
number of elements from sizes + int64_t calculate_numel(const std::vector& sizes) { + int64_t numel = 1; + for (int64_t size : sizes) { + numel *= size; + } + return numel; + } + + // Helper to calculate contiguous strides from sizes + std::vector calculate_contiguous_strides( + const std::vector& sizes) { + std::vector strides(sizes.size()); + if (sizes.empty()) { + return strides; + } + + strides[sizes.size() - 1] = 1; + // Use int64_t and check for underflow to avoid unsigned integer wraparound + for (int64_t i = static_cast(sizes.size()) - 2; i >= 0; i--) { + strides[i] = strides[i + 1] * sizes[i + 1]; + } + return strides; + } + + private: + std::vector cuda_memory_buffers_; + std::vector cpu_memory_buffers_; +}; + +// Test basic functionality with CUDA memory +TEST_F(AOTITorchCreateTensorFromBlobV2Test, BasicFunctionalityCUDA) { + // Test 1D tensor + std::vector sizes_1d = {5}; + std::vector strides_1d = calculate_contiguous_strides(sizes_1d); + + // Allocate CUDA memory + size_t bytes = calculate_numel(sizes_1d) * sizeof(float); + void* cuda_data = allocate_cuda_memory(bytes); + ASSERT_NE(cuda_data, nullptr); + + Tensor* tensor_1d; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + cuda_data, + sizes_1d.size(), + sizes_1d.data(), + strides_1d.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor_1d, + 0, // layout (strided) + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor_1d, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor_1d->dim(), 1); + EXPECT_EQ(tensor_1d->size(0), 5); + + // Verify the tensor uses the same data pointer + void* tensor_data = tensor_1d->mutable_data_ptr(); + EXPECT_EQ(tensor_data, cuda_data); + + // Delete the tensor - this should NOT free the original memory + error = aoti_torch_delete_tensor_object(tensor_1d); + EXPECT_EQ(error, Error::Ok); + + // Test that the original memory is still accessible (proves tensor didn't own + // it) For CUDA memory, check that we can still access it (synchronously) + // after tensor deletion + float pattern_value = 42.0f; + cudaError_t cuda_err = cudaMemcpy( + cuda_data, &pattern_value, sizeof(float), cudaMemcpyHostToDevice); + EXPECT_EQ(cuda_err, cudaSuccess) + << "Should be able to write to original CUDA memory after tensor deletion"; + + float readback_value = 0.0f; + cuda_err = cudaMemcpy( + &readback_value, cuda_data, sizeof(float), cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess) + << "Should be able to read from original CUDA memory after tensor deletion"; + EXPECT_EQ(readback_value, pattern_value) + << "Original CUDA memory should still contain our test pattern"; +} + +// Test basic functionality with CPU memory +TEST_F(AOTITorchCreateTensorFromBlobV2Test, BasicFunctionalityCPU) { + // Test 2D tensor + std::vector sizes_2d = {3, 4}; + std::vector strides_2d = calculate_contiguous_strides(sizes_2d); + + // Allocate CPU memory + size_t bytes = calculate_numel(sizes_2d) * sizeof(float); + void* cpu_data = allocate_cpu_memory(bytes); + ASSERT_NE(cpu_data, nullptr); + + Tensor* tensor_2d; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + cpu_data, + sizes_2d.size(), + sizes_2d.data(), + strides_2d.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CPU), + 0, // device index + &tensor_2d, + 0, // layout (strided) + nullptr, // opaque_metadata + 0); // opaque_metadata_size + 
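+  // Note on the call above: aoti_torch_create_tensor_from_blob_v2 wraps the
+  // caller-owned cpu_data buffer without copying it, so the resulting tensor
+  // is expected to alias cpu_data and deleting the tensor must not free the
+  // buffer. Both properties are checked below and at the end of this test.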
+ EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor_2d, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor_2d->dim(), 2); + EXPECT_EQ(tensor_2d->size(0), 3); + EXPECT_EQ(tensor_2d->size(1), 4); + + // Verify the tensor uses the same data pointer + void* tensor_data = tensor_2d->mutable_data_ptr(); + EXPECT_EQ(tensor_data, cpu_data); + + // Delete the tensor - this should NOT free the original memory + error = aoti_torch_delete_tensor_object(tensor_2d); + EXPECT_EQ(error, Error::Ok); + + // Test that the original memory is still accessible (proves tensor didn't own + // it) For CPU memory, directly write and read to verify accessibility + float* float_ptr = reinterpret_cast(cpu_data); + float pattern_value = 42.0f; + *float_ptr = pattern_value; + EXPECT_EQ(*float_ptr, pattern_value) + << "Original CPU memory should still be accessible after tensor deletion"; +} + +// Test with invalid dtype +TEST_F(AOTITorchCreateTensorFromBlobV2Test, InvalidDtype) { + std::vector sizes = {2, 3}; + std::vector strides = calculate_contiguous_strides(sizes); + + size_t bytes = calculate_numel(sizes) * sizeof(float); + void* data = allocate_cuda_memory(bytes); + ASSERT_NE(data, nullptr); + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + 999, // invalid dtype + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test with non-zero storage offset (should fail since from_blob cannot handle +// offsets) +TEST_F(AOTITorchCreateTensorFromBlobV2Test, NonZeroStorageOffset) { + std::vector sizes = {2, 3}; + std::vector strides = calculate_contiguous_strides(sizes); + + size_t bytes = calculate_numel(sizes) * sizeof(float); + void* data = allocate_cuda_memory(bytes); + ASSERT_NE(data, nullptr); + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + strides.data(), + 1, // non-zero storage_offset (should fail since from_blob cannot handle + // offsets) + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test with custom strides (using stride parameter but still contiguous) +TEST_F(AOTITorchCreateTensorFromBlobV2Test, CustomContiguousStrides) { + std::vector sizes = {2, 3}; + // Use the correct contiguous strides but pass them explicitly + std::vector contiguous_strides = {3, 1}; // Proper contiguous strides + + size_t bytes = calculate_numel(sizes) * sizeof(float); + void* data = allocate_cuda_memory(bytes); + ASSERT_NE(data, nullptr); + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + contiguous_strides.data(), // Explicitly pass contiguous strides + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 2); + EXPECT_EQ(tensor->size(1), 3); + + // Verify the tensor uses the same data pointer + void* tensor_data = 
tensor->mutable_data_ptr(); + EXPECT_EQ(tensor_data, data); + + // Verify strides were properly set (we can check via aoti_torch_get_strides) + int64_t* tensor_strides; + error = aoti_torch_get_strides(tensor, &tensor_strides); + EXPECT_EQ(error, Error::Ok); + EXPECT_EQ(tensor_strides[0], 3); + EXPECT_EQ(tensor_strides[1], 1); + + // Delete the tensor - this should NOT free the original memory + error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); + + // Test that the original memory is still accessible (proves tensor didn't own + // it) + float pattern_value = 42.0f; + cudaError_t cuda_err = + cudaMemcpy(data, &pattern_value, sizeof(float), cudaMemcpyHostToDevice); + EXPECT_EQ(cuda_err, cudaSuccess) + << "Should be able to write to original CUDA memory after tensor deletion"; + + float readback_value = 0.0f; + cuda_err = + cudaMemcpy(&readback_value, data, sizeof(float), cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess) + << "Should be able to read from original CUDA memory after tensor deletion"; + EXPECT_EQ(readback_value, pattern_value) + << "Original CUDA memory should still contain our test pattern"; +} + +// Test with null data pointer +TEST_F(AOTITorchCreateTensorFromBlobV2Test, NullDataPointer) { + std::vector sizes = {2, 3}; + std::vector strides = calculate_contiguous_strides(sizes); + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + nullptr, // null data pointer + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test scalar tensor (0D) +TEST_F(AOTITorchCreateTensorFromBlobV2Test, ScalarTensor) { + std::vector sizes = {}; // 0D tensor + std::vector strides = {}; // Empty strides for scalar + + size_t bytes = sizeof(float); // Single element + void* data = allocate_cuda_memory(bytes); + ASSERT_NE(data, nullptr); + + Tensor* tensor = nullptr; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(tensor, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor->dim(), 0); + + // Verify the tensor uses the same data pointer + void* tensor_data = tensor->mutable_data_ptr(); + EXPECT_EQ(tensor_data, data); + + // Delete the tensor - this should NOT free the original memory + error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); + + // Test that the original memory is still accessible (proves tensor didn't own + // it) + float pattern_value = 42.0f; + cudaError_t cuda_err = + cudaMemcpy(data, &pattern_value, sizeof(float), cudaMemcpyHostToDevice); + EXPECT_EQ(cuda_err, cudaSuccess) + << "Should be able to write to original CUDA memory after tensor deletion"; + + float readback_value = 0.0f; + cuda_err = + cudaMemcpy(&readback_value, data, sizeof(float), cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess) + << "Should be able to read from original CUDA memory after tensor deletion"; + EXPECT_EQ(readback_value, pattern_value) + << "Original CUDA memory should still contain our test pattern"; +} + +// Test 
zero-sized tensor +TEST_F(AOTITorchCreateTensorFromBlobV2Test, ZeroSizedTensor) { + std::vector sizes = {0, 5}; // Zero elements + std::vector strides = calculate_contiguous_strides(sizes); + + // Even for zero-sized tensor, we need some memory allocated + size_t bytes = sizeof(float); // Minimum allocation + void* data = allocate_cuda_memory(bytes); + ASSERT_NE(data, nullptr); + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 0); + EXPECT_EQ(tensor->size(1), 5); + + // Verify the tensor uses the same data pointer + void* tensor_data = tensor->mutable_data_ptr(); + EXPECT_EQ(tensor_data, data); + + // Delete the tensor - this should NOT free the original memory + error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); + + // Test that the original memory is still accessible (proves tensor didn't own + // it) + float pattern_value = 42.0f; + cudaError_t cuda_err = + cudaMemcpy(data, &pattern_value, sizeof(float), cudaMemcpyHostToDevice); + EXPECT_EQ(cuda_err, cudaSuccess) + << "Should be able to write to original CUDA memory after tensor deletion"; + + float readback_value = 0.0f; + cuda_err = + cudaMemcpy(&readback_value, data, sizeof(float), cudaMemcpyDeviceToHost); + EXPECT_EQ(cuda_err, cudaSuccess) + << "Should be able to read from original CUDA memory after tensor deletion"; + EXPECT_EQ(readback_value, pattern_value) + << "Original CUDA memory should still contain our test pattern"; +} + +// Test multi-dimensional tensors +TEST_F(AOTITorchCreateTensorFromBlobV2Test, MultiDimensionalTensors) { + // Test 3D tensor + std::vector sizes_3d = {2, 3, 4}; + std::vector strides_3d = calculate_contiguous_strides(sizes_3d); + + size_t bytes_3d = calculate_numel(sizes_3d) * sizeof(float); + void* data_3d = allocate_cuda_memory(bytes_3d); + ASSERT_NE(data_3d, nullptr); + + Tensor* tensor_3d; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data_3d, + sizes_3d.size(), + sizes_3d.data(), + strides_3d.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor_3d, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor_3d, nullptr); + EXPECT_EQ(tensor_3d->dim(), 3); + EXPECT_EQ(tensor_3d->size(0), 2); + EXPECT_EQ(tensor_3d->size(1), 3); + EXPECT_EQ(tensor_3d->size(2), 4); + + // Test 4D tensor + std::vector sizes_4d = {2, 3, 4, 5}; + std::vector strides_4d = calculate_contiguous_strides(sizes_4d); + + size_t bytes_4d = calculate_numel(sizes_4d) * sizeof(float); + void* data_4d = allocate_cuda_memory(bytes_4d); + ASSERT_NE(data_4d, nullptr); + + Tensor* tensor_4d; + error = aoti_torch_create_tensor_from_blob_v2( + data_4d, + sizes_4d.size(), + sizes_4d.data(), + strides_4d.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor_4d, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor_4d, nullptr); + 
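+  // For reference, the contiguous strides produced by
+  // calculate_contiguous_strides() are {12, 4, 1} for sizes {2, 3, 4} and
+  // {60, 20, 5, 1} for sizes {2, 3, 4, 5}: each stride is the product of the
+  // sizes to its right, with the innermost stride equal to 1.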
EXPECT_EQ(tensor_4d->dim(), 4); + EXPECT_EQ(tensor_4d->size(0), 2); + EXPECT_EQ(tensor_4d->size(1), 3); + EXPECT_EQ(tensor_4d->size(2), 4); + EXPECT_EQ(tensor_4d->size(3), 5); +} + +// Test tensor data pointer consistency +TEST_F(AOTITorchCreateTensorFromBlobV2Test, DataPointerConsistency) { + std::vector sizes = {2, 3}; + std::vector strides = calculate_contiguous_strides(sizes); + + size_t bytes = calculate_numel(sizes) * sizeof(float); + void* original_data = allocate_cuda_memory(bytes); + ASSERT_NE(original_data, nullptr); + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + original_data, + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + + // Check that the tensor uses the same data pointer + void* tensor_data = tensor->mutable_data_ptr(); + EXPECT_EQ(tensor_data, original_data); +} + +// Test creating multiple tensors from different blobs +TEST_F(AOTITorchCreateTensorFromBlobV2Test, MultipleTensorsFromBlobs) { + const int num_tensors = 5; + std::vector tensors; + std::vector data_ptrs; + + for (int i = 0; i < num_tensors; i++) { + std::vector sizes = {i + 1, i + 2}; + std::vector strides = calculate_contiguous_strides(sizes); + + size_t bytes = calculate_numel(sizes) * sizeof(float); + void* data = allocate_cuda_memory(bytes); + ASSERT_NE(data, nullptr); + data_ptrs.push_back(data); + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + tensors.push_back(tensor); + + // Verify dimensions + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), i + 1); + EXPECT_EQ(tensor->size(1), i + 2); + + // Verify the tensor uses the correct data pointer + EXPECT_EQ(tensor->mutable_data_ptr(), data); + } + + // Verify all tensors have different data pointers + for (int i = 0; i < num_tensors; i++) { + EXPECT_EQ(tensors[i]->mutable_data_ptr(), data_ptrs[i]); + for (int j = i + 1; j < num_tensors; j++) { + EXPECT_NE(tensors[i]->mutable_data_ptr(), tensors[j]->mutable_data_ptr()); + } + } +} + +// Test deletion of tensor created from blob (should not free the original +// memory) +TEST_F(AOTITorchCreateTensorFromBlobV2Test, DeletionDoesNotFreeOriginalMemory) { + std::vector sizes = {2, 3}; + std::vector strides = calculate_contiguous_strides(sizes); + + size_t bytes = calculate_numel(sizes) * sizeof(float); + void* data = allocate_cuda_memory(bytes); + ASSERT_NE(data, nullptr); + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + + // Delete the tensor - this should NOT free the original memory + error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); + + // 
The original memory should still be valid (we'll free it in teardown) + // We can't easily test if the memory is still valid without risking crashes, + // but the test should pass without issues if memory management is correct +} + +// Test with opaque metadata +TEST_F(AOTITorchCreateTensorFromBlobV2Test, WithOpaqueMetadata) { + std::vector sizes = {2, 3}; + std::vector strides = calculate_contiguous_strides(sizes); + + size_t bytes = calculate_numel(sizes) * sizeof(float); + void* data = allocate_cuda_memory(bytes); + ASSERT_NE(data, nullptr); + + // Create some opaque metadata + std::vector metadata = {0x01, 0x02, 0x03, 0x04}; + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + metadata.data(), // opaque_metadata + metadata.size()); // opaque_metadata_size + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 2); + EXPECT_EQ(tensor->size(1), 3); +} + +// Test stress test with many small tensors from blobs +TEST_F(AOTITorchCreateTensorFromBlobV2Test, StressTestManySmallTensors) { + const int num_tensors = 50; // Reduced for reasonable test time + std::vector tensors; + + for (int i = 0; i < num_tensors; i++) { + std::vector sizes = {1, 1}; // Minimal size + std::vector strides = calculate_contiguous_strides(sizes); + + size_t bytes = calculate_numel(sizes) * sizeof(float); + void* data = allocate_cuda_memory(bytes); + if (data == nullptr) { + // Skip if we run out of memory + continue; + } + + Tensor* tensor; + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + data, + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + if (error == Error::Ok && tensor != nullptr) { + tensors.push_back(tensor); + + // Verify the tensor uses the correct data pointer + EXPECT_EQ(tensor->mutable_data_ptr(), data); + } + } + + // Delete all created tensors + for (Tensor* tensor : tensors) { + AOTITorchError error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); + } +} diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda__weight_int4pack_mm.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda__weight_int4pack_mm.cpp new file mode 100644 index 00000000000..19fc4dad685 --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda__weight_int4pack_mm.cpp @@ -0,0 +1,333 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace executorch::backends::cuda; +using namespace executorch::backends::aoti; +using namespace executorch::runtime; + +// Test fixture for aoti_torch_cuda__weight_int4pack_mm tests +class AOTITorchInt4MMTest : public ::testing::Test { + protected: + void SetUp() override { + // Initialize ExecuTorch Platform Abstraction Layer + et_pal_init(); + + // Check if CUDA is available + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available, skipping CUDA tests"; + } + + // Check if GPU supports sm_80+ (required for int4mm) + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, 0); + int compute_capability = prop.major * 10 + prop.minor; + if (compute_capability < 80) { + GTEST_SKIP() << "GPU compute capability " << compute_capability + << " < 80 (Ampere+), int4mm requires sm_80+"; + } + + // Clean up any existing cached metadata before each test + cleanup_tensor_metadata(); + + // Clear any remaining tensors from previous tests + clear_all_tensors(); + } + + void TearDown() override { + // Clean up metadata + cleanup_tensor_metadata(); + + // Clear the global tensor storage using the provided function + clear_all_tensors(); + } + + // Helper to create a BFloat16 tensor + Tensor* create_bfloat16_tensor(const std::vector& sizes) { + Tensor* tensor; + + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, // default strides + static_cast(SupportedDTypes::BFLOAT16), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor); + + return (error == Error::Ok) ? tensor : nullptr; + } + + // Helper to create an Int32 tensor + Tensor* create_int32_tensor(const std::vector& sizes) { + Tensor* tensor; + + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, // default strides + static_cast(SupportedDTypes::INT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor); + + return (error == Error::Ok) ? 
tensor : nullptr; + } +}; + +// Test basic int4mm functionality with minimal valid inputs +TEST_F(AOTITorchInt4MMTest, BasicFunctionality) { + // Create input tensor A: [m, k] = [2, 128] in BFloat16 + int64_t m = 2; + int64_t k = 128; + int64_t n = 64; + int64_t qGroupSize = 128; + + Tensor* A = create_bfloat16_tensor({m, k}); + ASSERT_NE(A, nullptr) << "Failed to create input tensor A"; + + // Create weight tensor B (int4 packed): [n/8, k/(innerKTiles*16), 32, 4] in + // Int32 For int4mm, innerKTiles is typically 8, so k/(8*16) = 128/128 = 1 + int64_t B_innerKTiles = 8; + int64_t B_kTiles = k / (B_innerKTiles * 16); + Tensor* B = create_int32_tensor({n / 8, B_kTiles, 32, 4}); + ASSERT_NE(B, nullptr) << "Failed to create weight tensor B"; + + // Create scale and zeros tensor: [k/qGroupSize, n, 2] in BFloat16 + // For k=128, qGroupSize=128, k/qGroupSize=1 + Tensor* qScaleAndZeros = create_bfloat16_tensor({k / qGroupSize, n, 2}); + ASSERT_NE(qScaleAndZeros, nullptr) + << "Failed to create qScaleAndZeros tensor"; + + // Create output tensor: [m, n] in BFloat16 + Tensor* output = create_bfloat16_tensor({m, n}); + ASSERT_NE(output, nullptr) << "Failed to create output tensor"; + + printf("Testing int4mm with shapes:\n"); + printf(" A: [%ldx%ld] BFloat16\n", m, k); + printf(" B: [%ldx%ldx32x4] Int32\n", n / 8, B_kTiles); + printf(" qScaleAndZeros: [%ldx%ldx2] BFloat16\n", k / qGroupSize, n); + printf(" qGroupSize: %ld\n", qGroupSize); + printf(" Output: [%ldx%ld] BFloat16\n", m, n); + + // Call int4mm + AOTITorchError error = aoti_torch_cuda__weight_int4pack_mm( + A, B, qGroupSize, qScaleAndZeros, &output); + + // Check if the function succeeded + EXPECT_EQ(error, Error::Ok) << "int4mm operation should succeed"; + + // Verify output tensor properties + EXPECT_EQ(output->dim(), 2); + EXPECT_EQ(output->size(0), m); + EXPECT_EQ(output->size(1), n); + + printf("int4mm test passed successfully!\n"); +} + +// Test with different qGroupSize values +TEST_F(AOTITorchInt4MMTest, DifferentQGroupSizes) { + int64_t m = 4; + int64_t k = 256; + int64_t n = 128; + int64_t B_innerKTiles = 8; + + // Test qGroupSize = 64 + { + int64_t qGroupSize = 64; + + Tensor* A = create_bfloat16_tensor({m, k}); + ASSERT_NE(A, nullptr); + + Tensor* B = create_int32_tensor({n / 8, k / (B_innerKTiles * 16), 32, 4}); + ASSERT_NE(B, nullptr); + + Tensor* qScaleAndZeros = create_bfloat16_tensor({k / qGroupSize, n, 2}); + ASSERT_NE(qScaleAndZeros, nullptr); + + Tensor* output = create_bfloat16_tensor({m, n}); + ASSERT_NE(output, nullptr); + + printf("Testing int4mm with qGroupSize=64\n"); + + AOTITorchError error = aoti_torch_cuda__weight_int4pack_mm( + A, B, qGroupSize, qScaleAndZeros, &output); + EXPECT_EQ(error, Error::Ok) << "int4mm with qGroupSize=64 should succeed"; + } + + // Test qGroupSize = 128 + { + int64_t qGroupSize = 128; + + Tensor* A = create_bfloat16_tensor({m, k}); + ASSERT_NE(A, nullptr); + + Tensor* B = create_int32_tensor({n / 8, k / (B_innerKTiles * 16), 32, 4}); + ASSERT_NE(B, nullptr); + + Tensor* qScaleAndZeros = create_bfloat16_tensor({k / qGroupSize, n, 2}); + ASSERT_NE(qScaleAndZeros, nullptr); + + Tensor* output = create_bfloat16_tensor({m, n}); + ASSERT_NE(output, nullptr); + + printf("Testing int4mm with qGroupSize=128\n"); + + AOTITorchError error = aoti_torch_cuda__weight_int4pack_mm( + A, B, qGroupSize, qScaleAndZeros, &output); + EXPECT_EQ(error, Error::Ok) << "int4mm with qGroupSize=128 should succeed"; + } + + // Test qGroupSize = 256 + { + int64_t qGroupSize = 256; + + Tensor* A = 
create_bfloat16_tensor({m, k}); + ASSERT_NE(A, nullptr); + + Tensor* B = create_int32_tensor({n / 8, k / (B_innerKTiles * 16), 32, 4}); + ASSERT_NE(B, nullptr); + + Tensor* qScaleAndZeros = create_bfloat16_tensor({k / qGroupSize, n, 2}); + ASSERT_NE(qScaleAndZeros, nullptr); + + Tensor* output = create_bfloat16_tensor({m, n}); + ASSERT_NE(output, nullptr); + + printf("Testing int4mm with qGroupSize=256\n"); + + AOTITorchError error = aoti_torch_cuda__weight_int4pack_mm( + A, B, qGroupSize, qScaleAndZeros, &output); + EXPECT_EQ(error, Error::Ok) << "int4mm with qGroupSize=256 should succeed"; + } +} + +// Test error handling with null inputs +TEST_F(AOTITorchInt4MMTest, NullInputHandling) { + int64_t m = 2; + int64_t k = 128; + int64_t n = 64; + int64_t qGroupSize = 128; + int64_t B_innerKTiles = 8; + + Tensor* A = create_bfloat16_tensor({m, k}); + Tensor* B = create_int32_tensor({n / 8, k / (B_innerKTiles * 16), 32, 4}); + Tensor* qScaleAndZeros = create_bfloat16_tensor({k / qGroupSize, n, 2}); + Tensor* output = create_bfloat16_tensor({m, n}); + + // Test null A + { + AOTITorchError error = aoti_torch_cuda__weight_int4pack_mm( + nullptr, B, qGroupSize, qScaleAndZeros, &output); + EXPECT_EQ(error, Error::InvalidArgument) + << "Should fail with null A tensor"; + } + + // Test null B + { + AOTITorchError error = aoti_torch_cuda__weight_int4pack_mm( + A, nullptr, qGroupSize, qScaleAndZeros, &output); + EXPECT_EQ(error, Error::InvalidArgument) + << "Should fail with null B tensor"; + } + + // Test null qScaleAndZeros + { + AOTITorchError error = + aoti_torch_cuda__weight_int4pack_mm(A, B, qGroupSize, nullptr, &output); + EXPECT_EQ(error, Error::InvalidArgument) + << "Should fail with null qScaleAndZeros tensor"; + } + + // Test null output pointer + { + AOTITorchError error = aoti_torch_cuda__weight_int4pack_mm( + A, B, qGroupSize, qScaleAndZeros, nullptr); + EXPECT_EQ(error, Error::InvalidArgument) + << "Should fail with null output pointer"; + } +} + +// Test with larger batch size +TEST_F(AOTITorchInt4MMTest, LargerBatchSize) { + int64_t m = 16; // Batch size + int64_t k = 256; + int64_t n = 128; + int64_t qGroupSize = 128; + int64_t B_innerKTiles = 8; + + Tensor* A = create_bfloat16_tensor({m, k}); + ASSERT_NE(A, nullptr); + + Tensor* B = create_int32_tensor({n / 8, k / (B_innerKTiles * 16), 32, 4}); + ASSERT_NE(B, nullptr); + + Tensor* qScaleAndZeros = create_bfloat16_tensor({k / qGroupSize, n, 2}); + ASSERT_NE(qScaleAndZeros, nullptr); + + Tensor* output = create_bfloat16_tensor({m, n}); + ASSERT_NE(output, nullptr); + + printf("Testing int4mm with larger batch: m=%ld\n", m); + + AOTITorchError error = aoti_torch_cuda__weight_int4pack_mm( + A, B, qGroupSize, qScaleAndZeros, &output); + + EXPECT_EQ(error, Error::Ok) << "int4mm with larger batch should succeed"; + EXPECT_EQ(output->size(0), m); + EXPECT_EQ(output->size(1), n); +} + +// Test with larger tensors +TEST_F(AOTITorchInt4MMTest, LargerTensors) { + int64_t m = 8; + int64_t k = 512; + int64_t n = 256; + int64_t qGroupSize = 128; + int64_t B_innerKTiles = 8; + + Tensor* A = create_bfloat16_tensor({m, k}); + ASSERT_NE(A, nullptr); + + Tensor* B = create_int32_tensor({n / 8, k / (B_innerKTiles * 16), 32, 4}); + ASSERT_NE(B, nullptr); + + Tensor* qScaleAndZeros = create_bfloat16_tensor({k / qGroupSize, n, 2}); + ASSERT_NE(qScaleAndZeros, nullptr); + + Tensor* output = create_bfloat16_tensor({m, n}); + ASSERT_NE(output, nullptr); + + printf( + "Testing int4mm with larger tensors: [%ldx%ld] x [weight] -> [%ldx%ld]\n", + m, + k, + m, + 
n); + + AOTITorchError error = aoti_torch_cuda__weight_int4pack_mm( + A, B, qGroupSize, qScaleAndZeros, &output); + + EXPECT_EQ(error, Error::Ok) << "int4mm with larger tensors should succeed"; + EXPECT_EQ(output->dim(), 2); + EXPECT_EQ(output->size(0), m); + EXPECT_EQ(output->size(1), n); +} diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_guard.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_guard.cpp new file mode 100644 index 00000000000..7527965cdb8 --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_guard.cpp @@ -0,0 +1,199 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include + +using namespace executorch::backends::aoti; +using namespace executorch::backends::cuda; +using namespace executorch::runtime; + +// TODO(gasoonjia): Multiple device tests were not included due to test +// environment limitations. Will be added in the future. +class AOTITorchCUDAGuardTest : public ::testing::Test { + protected: + void SetUp() override { + et_pal_init(); + + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available, skipping CUDA tests"; + } + + ASSERT_EQ(cudaGetDevice(&original_device_), cudaSuccess); + } + + void TearDown() override { + if (cudaGetDeviceCount(&original_device_) == cudaSuccess) { + ASSERT_EQ(cudaGetDevice(&original_device_), cudaSuccess); + } + } + + int original_device_ = 0; +}; + +TEST_F(AOTITorchCUDAGuardTest, CreateAndDeleteCUDAGuard) { + CUDAGuardHandle guard = nullptr; + AOTITorchError error = aoti_torch_create_cuda_guard(0, &guard); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(guard, nullptr); + + int current_device = -1; + ASSERT_EQ(cudaGetDevice(¤t_device), cudaSuccess); + EXPECT_EQ(current_device, 0); + + error = aoti_torch_delete_cuda_guard(guard); + EXPECT_EQ(error, Error::Ok); +} + +TEST_F(AOTITorchCUDAGuardTest, CreateCUDAGuardNullReturnPointer) { + AOTITorchError error = aoti_torch_create_cuda_guard(0, nullptr); + EXPECT_EQ(error, Error::InvalidArgument); +} + +TEST_F(AOTITorchCUDAGuardTest, DeleteCUDAGuardNullHandle) { + AOTITorchError error = aoti_torch_delete_cuda_guard(nullptr); + EXPECT_EQ(error, Error::InvalidArgument); +} + +TEST_F(AOTITorchCUDAGuardTest, CUDAGuardSetIndexNullHandle) { + AOTITorchError error = aoti_torch_cuda_guard_set_index(nullptr, 0); + EXPECT_EQ(error, Error::InvalidArgument); +} + +TEST_F(AOTITorchCUDAGuardTest, CUDAGuardSetIndexInvalidDevice) { + CUDAGuardHandle guard = nullptr; + AOTITorchError error = aoti_torch_create_cuda_guard(0, &guard); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(guard, nullptr); + + error = aoti_torch_cuda_guard_set_index(guard, 999); + EXPECT_NE(error, Error::Ok); + + error = aoti_torch_delete_cuda_guard(guard); + EXPECT_EQ(error, Error::Ok); +} + +TEST_F(AOTITorchCUDAGuardTest, CreateAndDeleteCUDAStreamGuard) { + cudaStream_t stream; + ASSERT_EQ(cudaStreamCreate(&stream), cudaSuccess); + + CUDAStreamGuardHandle guard = nullptr; + AOTITorchError error = aoti_torch_create_cuda_stream_guard(stream, 0, &guard); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(guard, nullptr); + + error = aoti_torch_delete_cuda_stream_guard(guard); + EXPECT_EQ(error, Error::Ok); + + ASSERT_EQ(cudaStreamDestroy(stream), cudaSuccess); +} + 
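Taken together, the guard shims exercised above are meant to be driven in create/restore pairs by AOTI-generated code. The sketch below is caller-side illustration only, not part of the patch; it assumes device 0, the same shim headers these tests include, and elides error-code checks for brevity.

// Minimal sketch of the intended guard usage pattern (hypothetical caller code).
void run_with_stream_guard() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Pin the current device while this scope of work runs.
  CUDAGuardHandle device_guard = nullptr;
  aoti_torch_create_cuda_guard(/*device_index=*/0, &device_guard);

  // Make `stream` the current stream for device 0.
  CUDAStreamGuardHandle stream_guard = nullptr;
  aoti_torch_create_cuda_stream_guard(stream, /*device_index=*/0, &stream_guard);

  // Queries now report `stream` until the stream guard is deleted.
  void* current = nullptr;
  aoti_torch_get_current_cuda_stream(/*device_index=*/0, &current);

  // Deleting the guards restores the previously current stream and device.
  aoti_torch_delete_cuda_stream_guard(stream_guard);
  aoti_torch_delete_cuda_guard(device_guard);
  cudaStreamDestroy(stream);
}

Deleting a stream guard restores whichever stream was current before it was created, which is exactly what the StreamGuardWithSameDevice test below verifies by nesting two guards and checking the reported stream after each deletion.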
+TEST_F(AOTITorchCUDAGuardTest, CreateCUDAStreamGuardNullReturnPointer) { + cudaStream_t stream; + ASSERT_EQ(cudaStreamCreate(&stream), cudaSuccess); + + AOTITorchError error = + aoti_torch_create_cuda_stream_guard(stream, 0, nullptr); + EXPECT_EQ(error, Error::InvalidArgument); + + ASSERT_EQ(cudaStreamDestroy(stream), cudaSuccess); +} + +TEST_F(AOTITorchCUDAGuardTest, CreateCUDAStreamGuardNullStream) { + CUDAStreamGuardHandle guard = nullptr; + AOTITorchError error = + aoti_torch_create_cuda_stream_guard(nullptr, 0, &guard); + EXPECT_EQ(error, Error::InvalidArgument); +} + +TEST_F(AOTITorchCUDAGuardTest, DeleteCUDAStreamGuardNullHandle) { + AOTITorchError error = aoti_torch_delete_cuda_stream_guard(nullptr); + EXPECT_EQ(error, Error::InvalidArgument); +} + +TEST_F(AOTITorchCUDAGuardTest, GetCurrentCUDAStream) { + void* ret_stream = nullptr; + AOTITorchError error = aoti_torch_get_current_cuda_stream(0, &ret_stream); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(ret_stream, nullptr); +} + +TEST_F(AOTITorchCUDAGuardTest, GetCurrentCUDAStreamNullReturnPointer) { + AOTITorchError error = aoti_torch_get_current_cuda_stream(0, nullptr); + EXPECT_EQ(error, Error::InvalidArgument); +} + +TEST_F(AOTITorchCUDAGuardTest, StreamGuardWithSameDevice) { + ASSERT_EQ(cudaSetDevice(0), cudaSuccess); + + cudaStream_t stream1, stream2; + ASSERT_EQ(cudaStreamCreate(&stream1), cudaSuccess); + ASSERT_EQ(cudaStreamCreate(&stream2), cudaSuccess); + + CUDAStreamGuardHandle guard1 = nullptr; + AOTITorchError error = + aoti_torch_create_cuda_stream_guard(stream1, 0, &guard1); + EXPECT_EQ(error, Error::Ok); + + void* ret_stream = nullptr; + error = aoti_torch_get_current_cuda_stream(0, &ret_stream); + EXPECT_EQ(error, Error::Ok); + EXPECT_EQ(static_cast(ret_stream), stream1); + + CUDAStreamGuardHandle guard2 = nullptr; + error = aoti_torch_create_cuda_stream_guard(stream2, 0, &guard2); + EXPECT_EQ(error, Error::Ok); + + ret_stream = nullptr; + error = aoti_torch_get_current_cuda_stream(0, &ret_stream); + EXPECT_EQ(error, Error::Ok); + EXPECT_EQ(static_cast(ret_stream), stream2); + + error = aoti_torch_delete_cuda_stream_guard(guard2); + EXPECT_EQ(error, Error::Ok); + + ret_stream = nullptr; + error = aoti_torch_get_current_cuda_stream(0, &ret_stream); + EXPECT_EQ(error, Error::Ok); + EXPECT_EQ(static_cast(ret_stream), stream1); + + error = aoti_torch_delete_cuda_stream_guard(guard1); + EXPECT_EQ(error, Error::Ok); + + ASSERT_EQ(cudaStreamDestroy(stream1), cudaSuccess); + ASSERT_EQ(cudaStreamDestroy(stream2), cudaSuccess); +} + +TEST_F(AOTITorchCUDAGuardTest, GetCurrentStreamAfterSetStream) { + cudaStream_t new_stream; + ASSERT_EQ(cudaStreamCreate(&new_stream), cudaSuccess); + + CUDAStreamGuardHandle guard = nullptr; + AOTITorchError error = + aoti_torch_create_cuda_stream_guard(new_stream, 0, &guard); + EXPECT_EQ(error, Error::Ok); + + void* ret_stream = nullptr; + error = aoti_torch_get_current_cuda_stream(0, &ret_stream); + EXPECT_EQ(error, Error::Ok); + EXPECT_EQ(static_cast(ret_stream), new_stream); + + error = aoti_torch_delete_cuda_stream_guard(guard); + EXPECT_EQ(error, Error::Ok); + + ASSERT_EQ(cudaStreamDestroy(new_stream), cudaSuccess); +} diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_delete_tensor_object.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_delete_tensor_object.cpp new file mode 100644 index 00000000000..10c8d8c1a31 --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_delete_tensor_object.cpp @@ -0,0 +1,454 @@ +/* + * Copyright (c) Meta Platforms, Inc. 
and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace executorch::backends::aoti; +using namespace executorch::backends::cuda; +using namespace executorch::runtime; +using executorch::runtime::etensor::Tensor; + +// Test fixture for aoti_torch_delete_tensor_object tests +class AOTITorchDeleteTensorObjectTest : public ::testing::Test { + protected: + void SetUp() override { + // Initialize ExecuTorch Platform Abstraction Layer + et_pal_init(); + + // Check if CUDA is available + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available, skipping CUDA tests"; + } + + // Clean up any existing cached metadata before each test + cleanup_tensor_metadata(); + + // Clear any remaining tensors from previous tests + clear_all_tensors(); + } + + void TearDown() override { + // Clean up metadata + cleanup_tensor_metadata(); + + // Clear the global tensor storage using the provided function + clear_all_tensors(); + } + + // Helper to create test tensors + Tensor* create_test_tensor( + const std::vector& sizes, + const std::vector& strides = {}, + int32_t dtype = 6, // float32 + int32_t device_type = 1, // CUDA + int32_t device_index = 0) { + Tensor* tensor; + + const int64_t* strides_ptr = strides.empty() ? nullptr : strides.data(); + + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + strides_ptr, + dtype, + device_type, + device_index, + &tensor); + + return (error == Error::Ok) ? tensor : nullptr; + } +}; + +// Test basic deletion of CUDA tensor +TEST_F(AOTITorchDeleteTensorObjectTest, DeleteCudaTensorBasic) { + // Create a CUDA tensor + std::vector sizes = {2, 3}; + Tensor* tensor = create_test_tensor(sizes, {}, 6, 1, 0); // CUDA device + ASSERT_NE(tensor, nullptr); + + // Verify tensor properties before deletion + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 2); + EXPECT_EQ(tensor->size(1), 3); + + // Delete the tensor + AOTITorchError error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); +} + +// Test basic deletion of CPU tensor +TEST_F(AOTITorchDeleteTensorObjectTest, DeleteCpuTensorBasic) { + // Create a CPU tensor + std::vector sizes = {3, 4}; + Tensor* tensor = create_test_tensor(sizes, {}, 6, 0, 0); // CPU device + ASSERT_NE(tensor, nullptr); + + // Verify tensor properties before deletion + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 3); + EXPECT_EQ(tensor->size(1), 4); + + // Delete the tensor + AOTITorchError error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); +} + +// Test deletion of null tensor pointer +TEST_F(AOTITorchDeleteTensorObjectTest, DeleteNullTensor) { + AOTITorchError error = aoti_torch_delete_tensor_object(nullptr); + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test deletion of tensor not in tracking system +TEST_F(AOTITorchDeleteTensorObjectTest, DeleteUntrackedTensor) { + // Create a tensor and then clear the tracking system + std::vector sizes = {2, 3}; + Tensor* tensor = create_test_tensor(sizes); + ASSERT_NE(tensor, nullptr); + + // Clear the tracking system (simulating an untracked tensor) + clear_all_tensors(); + + // Try to delete the tensor - should fail + AOTITorchError error = 
aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test deletion of multiple tensors +TEST_F(AOTITorchDeleteTensorObjectTest, DeleteMultipleTensors) { + // Create multiple tensors + std::vector tensors; + + for (int i = 1; i <= 5; i++) { + std::vector sizes = {i, i + 1}; + Tensor* tensor = create_test_tensor(sizes); + ASSERT_NE(tensor, nullptr); + tensors.push_back(tensor); + } + + // Delete all tensors + for (Tensor* tensor : tensors) { + AOTITorchError error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); + } +} + +// Test deletion of zero-sized tensors +TEST_F(AOTITorchDeleteTensorObjectTest, DeleteZeroSizedTensor) { + // Create a zero-sized tensor + std::vector sizes = {0, 5}; + Tensor* tensor = create_test_tensor(sizes); + ASSERT_NE(tensor, nullptr); + + // Verify tensor properties + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 0); + EXPECT_EQ(tensor->size(1), 5); + + // Delete the tensor + AOTITorchError error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); +} + +// Test deletion of scalar (0D) tensors +TEST_F(AOTITorchDeleteTensorObjectTest, DeleteScalarTensor) { + // Create a scalar tensor + std::vector sizes = {}; + Tensor* tensor = create_test_tensor(sizes); + ASSERT_NE(tensor, nullptr); + + // Verify tensor properties + EXPECT_EQ(tensor->dim(), 0); + + // Delete the tensor + AOTITorchError error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); +} + +// Test deletion of large multi-dimensional tensors +TEST_F(AOTITorchDeleteTensorObjectTest, DeleteLargeTensor) { + // Create a large multi-dimensional tensor + std::vector sizes = {10, 20, 30}; + Tensor* tensor = create_test_tensor(sizes); + ASSERT_NE(tensor, nullptr); + + // Verify tensor properties + EXPECT_EQ(tensor->dim(), 3); + EXPECT_EQ(tensor->size(0), 10); + EXPECT_EQ(tensor->size(1), 20); + EXPECT_EQ(tensor->size(2), 30); + + // Delete the tensor + AOTITorchError error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); +} + +// Test deletion of tensors with custom strides +TEST_F(AOTITorchDeleteTensorObjectTest, DeleteTensorWithCustomStrides) { + // Create tensor with custom strides + std::vector sizes = {3, 4}; + std::vector strides = {4, 1}; // Row-major strides + Tensor* tensor = create_test_tensor(sizes, strides); + ASSERT_NE(tensor, nullptr); + + // Verify tensor properties + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 3); + EXPECT_EQ(tensor->size(1), 4); + + // Delete the tensor + AOTITorchError error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); +} + +// Test deletion after accessing tensor data +TEST_F(AOTITorchDeleteTensorObjectTest, DeleteAfterDataAccess) { + // Create a tensor + std::vector sizes = {2, 3}; + Tensor* tensor = create_test_tensor(sizes); + ASSERT_NE(tensor, nullptr); + + // Access tensor data (this should not prevent deletion) + void* data_ptr = tensor->mutable_data_ptr(); + EXPECT_NE(data_ptr, nullptr); + + // Delete the tensor + AOTITorchError error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); +} + +// Test double deletion (should fail on second attempt) +TEST_F(AOTITorchDeleteTensorObjectTest, DoubleDeletion) { + // Create a tensor + std::vector sizes = {2, 3}; + Tensor* tensor = create_test_tensor(sizes); + ASSERT_NE(tensor, nullptr); + + // First deletion should succeed + AOTITorchError error1 = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error1, Error::Ok); + + 
// Second deletion should fail (tensor no longer tracked) + AOTITorchError error2 = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error2, Error::InvalidArgument); +} + +// Test deletion of tensors on both CUDA and CPU devices +TEST_F(AOTITorchDeleteTensorObjectTest, DeleteMixedDeviceTensors) { + // Create CUDA tensor + std::vector sizes = {2, 3}; + Tensor* cuda_tensor = create_test_tensor(sizes, {}, 6, 1, 0); + ASSERT_NE(cuda_tensor, nullptr); + + // Create CPU tensor + Tensor* cpu_tensor = create_test_tensor(sizes, {}, 6, 0, 0); + ASSERT_NE(cpu_tensor, nullptr); + + // Delete both tensors + AOTITorchError cuda_error = aoti_torch_delete_tensor_object(cuda_tensor); + EXPECT_EQ(cuda_error, Error::Ok); + + AOTITorchError cpu_error = aoti_torch_delete_tensor_object(cpu_tensor); + EXPECT_EQ(cpu_error, Error::Ok); +} + +// Test memory consistency after deletion +TEST_F(AOTITorchDeleteTensorObjectTest, MemoryConsistencyAfterDeletion) { + // Create multiple tensors + std::vector tensors; + const int num_tensors = 10; + + for (int i = 0; i < num_tensors; i++) { + std::vector sizes = {i + 1, i + 2}; + Tensor* tensor = create_test_tensor(sizes); + ASSERT_NE(tensor, nullptr); + tensors.push_back(tensor); + } + + // Delete every other tensor + for (int i = 0; i < num_tensors; i += 2) { + AOTITorchError error = aoti_torch_delete_tensor_object(tensors[i]); + EXPECT_EQ(error, Error::Ok); + } + + // Delete remaining tensors + for (int i = 1; i < num_tensors; i += 2) { + AOTITorchError error = aoti_torch_delete_tensor_object(tensors[i]); + EXPECT_EQ(error, Error::Ok); + } +} + +// Test stress deletion with many small tensors +TEST_F(AOTITorchDeleteTensorObjectTest, StressDeletionManySmallTensors) { + const int num_tensors = 100; + std::vector tensors; + + // Create many small tensors + for (int i = 0; i < num_tensors; i++) { + std::vector sizes = {1, 1}; // Minimal size + Tensor* tensor = create_test_tensor(sizes); + if (tensor != nullptr) { + tensors.push_back(tensor); + } + } + + // Delete all created tensors + for (Tensor* tensor : tensors) { + AOTITorchError error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); + } +} + +// Test CUDA synchronization during deletion +TEST_F(AOTITorchDeleteTensorObjectTest, CudaSynchronizationDuringDeletion) { + // Create a larger CUDA tensor to ensure memory allocation + std::vector sizes = {100, 100}; + Tensor* tensor = create_test_tensor(sizes, {}, 6, 1, 0); // CUDA device + ASSERT_NE(tensor, nullptr); + + // Delete the tensor (should handle synchronization internally) + AOTITorchError error = aoti_torch_delete_tensor_object(tensor); + EXPECT_EQ(error, Error::Ok); + + // Verify CUDA state is still good + cudaError_t cuda_error = cudaGetLastError(); + EXPECT_EQ(cuda_error, cudaSuccess); +} + +// Test specific deletion of bfloat16 tensors +TEST_F(AOTITorchDeleteTensorObjectTest, DeleteBFloat16Tensor) { + // Test 1D bfloat16 tensor deletion + std::vector sizes_1d = {10}; + Tensor* tensor_bf16_1d = create_test_tensor( + sizes_1d, + {}, + static_cast(SupportedDTypes::BFLOAT16), + 1, // CUDA device + 0); + ASSERT_NE(tensor_bf16_1d, nullptr); + + // Verify it's bfloat16 before deletion + int32_t actual_dtype; + EXPECT_EQ(aoti_torch_get_dtype(tensor_bf16_1d, &actual_dtype), Error::Ok); + EXPECT_EQ(actual_dtype, static_cast(SupportedDTypes::BFLOAT16)) + << "Expected bfloat16 dtype (" + << static_cast(SupportedDTypes::BFLOAT16) << "), got " + << actual_dtype; + + // Verify element size (bfloat16 should be 2 bytes per element) + 
EXPECT_EQ(tensor_bf16_1d->element_size(), 2); + + // Delete the bfloat16 tensor + AOTITorchError error = aoti_torch_delete_tensor_object(tensor_bf16_1d); + EXPECT_EQ(error, Error::Ok); + + // Test 2D bfloat16 tensor deletion with custom strides + std::vector sizes_2d = {4, 6}; + std::vector strides_2d = {6, 1}; // Row-major strides + Tensor* tensor_bf16_2d = create_test_tensor( + sizes_2d, + strides_2d, + static_cast(SupportedDTypes::BFLOAT16), + 1, // CUDA device + 0); + ASSERT_NE(tensor_bf16_2d, nullptr); + + // Verify tensor properties + EXPECT_EQ(tensor_bf16_2d->dim(), 2); + EXPECT_EQ(tensor_bf16_2d->size(0), 4); + EXPECT_EQ(tensor_bf16_2d->size(1), 6); + EXPECT_EQ(tensor_bf16_2d->element_size(), 2); + + // Verify it's bfloat16 + int32_t dtype_2d; + EXPECT_EQ(aoti_torch_get_dtype(tensor_bf16_2d, &dtype_2d), Error::Ok); + EXPECT_EQ(dtype_2d, static_cast(SupportedDTypes::BFLOAT16)); + + // Delete the 2D bfloat16 tensor + error = aoti_torch_delete_tensor_object(tensor_bf16_2d); + EXPECT_EQ(error, Error::Ok); + + // Test 3D bfloat16 tensor deletion + std::vector sizes_3d = {2, 3, 4}; + Tensor* tensor_bf16_3d = create_test_tensor( + sizes_3d, + {}, + static_cast(SupportedDTypes::BFLOAT16), + 1, // CUDA device + 0); + ASSERT_NE(tensor_bf16_3d, nullptr); + + // Verify tensor properties + EXPECT_EQ(tensor_bf16_3d->dim(), 3); + EXPECT_EQ(tensor_bf16_3d->size(0), 2); + EXPECT_EQ(tensor_bf16_3d->size(1), 3); + EXPECT_EQ(tensor_bf16_3d->size(2), 4); + EXPECT_EQ(tensor_bf16_3d->element_size(), 2); + + // Verify memory size (2 * 3 * 4 * 2 bytes = 48 bytes) + size_t expected_memory = 2 * 3 * 4 * 2; + size_t actual_memory = + tensor_bf16_3d->numel() * tensor_bf16_3d->element_size(); + EXPECT_EQ(actual_memory, expected_memory); + + // Delete the 3D bfloat16 tensor + error = aoti_torch_delete_tensor_object(tensor_bf16_3d); + EXPECT_EQ(error, Error::Ok); + + // Test bfloat16 scalar tensor (0D) deletion + std::vector scalar_sizes = {}; + Tensor* tensor_bf16_scalar = create_test_tensor( + scalar_sizes, + {}, + static_cast(SupportedDTypes::BFLOAT16), + 1, // CUDA device + 0); + ASSERT_NE(tensor_bf16_scalar, nullptr); + + // Verify scalar tensor properties + EXPECT_EQ(tensor_bf16_scalar->dim(), 0); + EXPECT_EQ(tensor_bf16_scalar->numel(), 1); + EXPECT_EQ(tensor_bf16_scalar->element_size(), 2); + + // Delete the scalar bfloat16 tensor + error = aoti_torch_delete_tensor_object(tensor_bf16_scalar); + EXPECT_EQ(error, Error::Ok); + + // Test zero-element bfloat16 tensor deletion + std::vector zero_sizes = {0, 5}; + Tensor* tensor_bf16_zero = create_test_tensor( + zero_sizes, + {}, + static_cast(SupportedDTypes::BFLOAT16), + 1, // CUDA device + 0); + ASSERT_NE(tensor_bf16_zero, nullptr); + + // Verify zero-element tensor properties + EXPECT_EQ(tensor_bf16_zero->dim(), 2); + EXPECT_EQ(tensor_bf16_zero->size(0), 0); + EXPECT_EQ(tensor_bf16_zero->size(1), 5); + EXPECT_EQ(tensor_bf16_zero->numel(), 0); + EXPECT_EQ(tensor_bf16_zero->element_size(), 2); + + // Delete the zero-element bfloat16 tensor + error = aoti_torch_delete_tensor_object(tensor_bf16_zero); + EXPECT_EQ(error, Error::Ok); +} + +// Test deletion of mixed dtype tensors (float32 and bfloat16) diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp new file mode 100644 index 00000000000..799a8d1221b --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp @@ -0,0 +1,667 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. 
+ * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace executorch::backends::cuda; +using namespace executorch::backends::aoti; +using namespace executorch::runtime; +using executorch::runtime::etensor::Tensor; + +// Test fixture for aoti_torch_empty_strided tests +class AOTITorchEmptyStridedTest : public ::testing::Test { + protected: + void SetUp() override { + // Initialize ExecuTorch Platform Abstraction Layer + et_pal_init(); + + // Check if CUDA is available + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available, skipping CUDA tests"; + } + + // Clean up any existing cached metadata before each test + cleanup_tensor_metadata(); + + // Clear any remaining tensors from previous tests + clear_all_tensors(); + } + + void TearDown() override { + // Clean up metadata + cleanup_tensor_metadata(); + + // Clear the global tensor storage using the provided function + clear_all_tensors(); + } + + // Helper to create test tensors + Tensor* create_tracked_tensor( + const std::vector& sizes, + const std::vector& strides = {}, + int32_t dtype = static_cast(SupportedDTypes::FLOAT32), + int32_t device_type = static_cast(SupportedDevices::CUDA), + int32_t device_index = 0) { + Tensor* tensor; + + const int64_t* strides_ptr = strides.empty() ? nullptr : strides.data(); + + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + strides_ptr, + dtype, + device_type, + device_index, + &tensor); + + return (error == Error::Ok) ? 
tensor : nullptr; + } +}; + +// Test aoti_torch_empty_strided basic functionality +TEST_F(AOTITorchEmptyStridedTest, BasicFunctionality) { + // Test 1D tensor + std::vector sizes_1d = {5}; + Tensor* tensor_1d; + AOTITorchError error = aoti_torch_empty_strided( + sizes_1d.size(), + sizes_1d.data(), + nullptr, // Let function compute strides + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor_1d); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor_1d, nullptr); + + // CRITICAL: Verify the tensor is actually float32 + int32_t actual_dtype; + EXPECT_EQ(aoti_torch_get_dtype(tensor_1d, &actual_dtype), Error::Ok); + EXPECT_EQ(actual_dtype, static_cast(SupportedDTypes::FLOAT32)) + << "Expected float32 dtype (" + << static_cast(SupportedDTypes::FLOAT32) << "), got " + << actual_dtype; + + // Verify element size (float32 should be 4 bytes per element) + size_t element_size = tensor_1d->element_size(); + EXPECT_EQ(element_size, 4) + << "Expected float32 element size to be 4 bytes, got " << element_size; + + // Verify total number of elements and memory usage + int64_t expected_numel = 5; // 5 elements + EXPECT_EQ(tensor_1d->numel(), expected_numel) + << "Expected " << expected_numel << " elements, got " + << tensor_1d->numel(); + + // Verify total memory size (numel * element_size) + size_t expected_memory_size = expected_numel * 4; // 5 * 4 = 20 bytes + size_t actual_memory_size = tensor_1d->numel() * tensor_1d->element_size(); + EXPECT_EQ(actual_memory_size, expected_memory_size) + << "Expected " << expected_memory_size << " bytes, got " + << actual_memory_size; + + // Check tensor properties + EXPECT_EQ(tensor_1d->dim(), 1); + EXPECT_EQ(tensor_1d->size(0), 5); + + // Test 2D tensor with explicit strides + std::vector sizes_2d = {3, 4}; + std::vector strides_2d = {4, 1}; + Tensor* tensor_2d; + error = aoti_torch_empty_strided( + sizes_2d.size(), + sizes_2d.data(), + strides_2d.data(), + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor_2d); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor_2d, nullptr); + + // Verify 2D tensor is also float32 + int32_t dtype_2d; + EXPECT_EQ(aoti_torch_get_dtype(tensor_2d, &dtype_2d), Error::Ok); + EXPECT_EQ(dtype_2d, static_cast(SupportedDTypes::FLOAT32)) + << "Expected float32 dtype (" + << static_cast(SupportedDTypes::FLOAT32) << "), got " + << dtype_2d; + + // Verify element size for 2D tensor + EXPECT_EQ(tensor_2d->element_size(), 4); + + // Check tensor properties + EXPECT_EQ(tensor_2d->dim(), 2); + EXPECT_EQ(tensor_2d->size(0), 3); + EXPECT_EQ(tensor_2d->size(1), 4); + + // Verify memory size for 2D tensor + int64_t expected_numel_2d = 3 * 4; // 12 elements + size_t expected_memory_2d = expected_numel_2d * 4; // 12 * 4 = 48 bytes + EXPECT_EQ(tensor_2d->numel() * tensor_2d->element_size(), expected_memory_2d); +} + +// Test aoti_torch_empty_strided with CPU device +TEST_F(AOTITorchEmptyStridedTest, CPUDevice) { + std::vector sizes = {2, 3}; + Tensor* tensor; + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, // Let function compute strides + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CPU), + 0, // device index + &tensor); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 2); + EXPECT_EQ(tensor->size(1), 3); +} + +// Test aoti_torch_empty_strided with invalid dtype 
+TEST_F(AOTITorchEmptyStridedTest, InvalidDtype) { + std::vector sizes = {2, 3}; + Tensor* tensor; + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, + 999, // invalid dtype + 1, // CUDA device + 0, // device index + &tensor); + + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test aoti_torch_empty_strided with unsupported device +TEST_F(AOTITorchEmptyStridedTest, UnsupportedDevice) { + std::vector sizes = {2, 3}; + Tensor* tensor; + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, + 6, // float32 + 2, // unsupported device type + 0, // device index + &tensor); + + EXPECT_EQ(error, Error::NotImplemented); +} + +// Test aoti_torch_empty_strided with zero-sized tensor +TEST_F(AOTITorchEmptyStridedTest, ZeroSized) { + std::vector sizes = {0, 5}; + Tensor* tensor; + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, + 6, // float32 + 1, // CUDA device + 0, // device index + &tensor); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 0); + EXPECT_EQ(tensor->size(1), 5); +} + +// Test aoti_torch_empty_strided scalar tensor (0D) +TEST_F(AOTITorchEmptyStridedTest, Scalar) { + std::vector sizes = {}; + Tensor* tensor; + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, + 6, // float32 + 1, // CUDA device + 0, // device index + &tensor); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor->dim(), 0); +} + +// Test aoti_torch_empty_strided with large tensor +TEST_F(AOTITorchEmptyStridedTest, LargeTensor) { + std::vector sizes = {100, 200, 50}; + Tensor* tensor; + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, + 6, // float32 + 1, // CUDA device + 0, // device index + &tensor); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + + // Check tensor properties + EXPECT_EQ(tensor->dim(), 3); + EXPECT_EQ(tensor->size(0), 100); + EXPECT_EQ(tensor->size(1), 200); + EXPECT_EQ(tensor->size(2), 50); +} + +// Test aoti_torch_empty_strided with bfloat16 dtype +TEST_F(AOTITorchEmptyStridedTest, BFloat16Tensor) { + // Test creating bfloat16 tensor on CUDA + std::vector sizes = {2, 3, 4}; + Tensor* tensor_bf16; + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, // Let function compute strides + static_cast(SupportedDTypes::BFLOAT16), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor_bf16); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor_bf16, nullptr); + + // CRITICAL: Verify the tensor is actually bfloat16 + int32_t actual_dtype; + EXPECT_EQ(aoti_torch_get_dtype(tensor_bf16, &actual_dtype), Error::Ok); + EXPECT_EQ(actual_dtype, static_cast(SupportedDTypes::BFLOAT16)) + << "Expected bfloat16 dtype (" + << static_cast(SupportedDTypes::BFLOAT16) << "), got " + << actual_dtype; + + // Verify element size (bfloat16 should be 2 bytes per element) + size_t element_size = tensor_bf16->element_size(); + EXPECT_EQ(element_size, 2) + << "Expected bfloat16 element size to be 2 bytes, got " << element_size; + + // Verify total number of elements and memory usage + int64_t expected_numel = 2 * 3 * 4; // 24 elements + EXPECT_EQ(tensor_bf16->numel(), expected_numel) + << "Expected " << expected_numel << " elements, got " + << tensor_bf16->numel(); + + // Verify total memory 
size (numel * element_size) + size_t expected_memory_size = expected_numel * 2; // 24 * 2 = 48 bytes + size_t actual_memory_size = + tensor_bf16->numel() * tensor_bf16->element_size(); + EXPECT_EQ(actual_memory_size, expected_memory_size) + << "Expected " << expected_memory_size << " bytes, got " + << actual_memory_size; + + // Check tensor properties + EXPECT_EQ(tensor_bf16->dim(), 3); + EXPECT_EQ(tensor_bf16->size(0), 2); + EXPECT_EQ(tensor_bf16->size(1), 3); + EXPECT_EQ(tensor_bf16->size(2), 4); + + // Verify we can get tensor metadata + int64_t* sizes_ptr; + int64_t* strides_ptr; + EXPECT_EQ(aoti_torch_get_sizes(tensor_bf16, &sizes_ptr), Error::Ok); + EXPECT_EQ(aoti_torch_get_strides(tensor_bf16, &strides_ptr), Error::Ok); + + // Check sizes match + EXPECT_EQ(sizes_ptr[0], 2); + EXPECT_EQ(sizes_ptr[1], 3); + EXPECT_EQ(sizes_ptr[2], 4); + + // Check that strides are computed correctly (row-major order) + EXPECT_EQ(strides_ptr[0], 12); // 3 * 4 + EXPECT_EQ(strides_ptr[1], 4); // 4 + EXPECT_EQ(strides_ptr[2], 1); // 1 + + // Test bfloat16 tensor with custom strides + std::vector sizes_2d = {3, 2}; + std::vector strides_2d = {2, 1}; // Row-major strides + Tensor* tensor_bf16_custom; + error = aoti_torch_empty_strided( + sizes_2d.size(), + sizes_2d.data(), + strides_2d.data(), + static_cast(SupportedDTypes::BFLOAT16), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor_bf16_custom); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor_bf16_custom, nullptr); + + // Verify custom stride tensor is also bfloat16 + int32_t custom_dtype; + EXPECT_EQ(aoti_torch_get_dtype(tensor_bf16_custom, &custom_dtype), Error::Ok); + EXPECT_EQ(custom_dtype, static_cast(SupportedDTypes::BFLOAT16)) + << "Expected bfloat16 dtype (" + << static_cast(SupportedDTypes::BFLOAT16) << "), got " + << custom_dtype; + + // Verify element size for custom stride tensor + EXPECT_EQ(tensor_bf16_custom->element_size(), 2); + + // Check tensor properties + EXPECT_EQ(tensor_bf16_custom->dim(), 2); + EXPECT_EQ(tensor_bf16_custom->size(0), 3); + EXPECT_EQ(tensor_bf16_custom->size(1), 2); + + // Verify memory size for custom stride tensor + int64_t custom_expected_numel = 3 * 2; // 6 elements + size_t custom_expected_memory = custom_expected_numel * 2; // 6 * 2 = 12 bytes + EXPECT_EQ( + tensor_bf16_custom->numel() * tensor_bf16_custom->element_size(), + custom_expected_memory); + + // Check custom strides + int64_t* custom_strides_ptr; + EXPECT_EQ( + aoti_torch_get_strides(tensor_bf16_custom, &custom_strides_ptr), + Error::Ok); + EXPECT_EQ(custom_strides_ptr[0], 2); + EXPECT_EQ(custom_strides_ptr[1], 1); + + // Test bfloat16 scalar tensor (0D) + std::vector scalar_sizes = {}; + Tensor* tensor_bf16_scalar; + error = aoti_torch_empty_strided( + scalar_sizes.size(), + scalar_sizes.data(), + nullptr, + static_cast(SupportedDTypes::BFLOAT16), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor_bf16_scalar); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor_bf16_scalar, nullptr); + EXPECT_EQ(tensor_bf16_scalar->dim(), 0); + + // Verify scalar tensor is also bfloat16 + int32_t scalar_dtype; + EXPECT_EQ(aoti_torch_get_dtype(tensor_bf16_scalar, &scalar_dtype), Error::Ok); + EXPECT_EQ(scalar_dtype, static_cast(SupportedDTypes::BFLOAT16)) + << "Expected bfloat16 dtype (" + << static_cast(SupportedDTypes::BFLOAT16) << "), got " + << scalar_dtype; + + // Verify scalar tensor properties + EXPECT_EQ(tensor_bf16_scalar->element_size(), 2); + EXPECT_EQ(tensor_bf16_scalar->numel(), 1); // Scalar tensor has 1 
element
+  EXPECT_EQ(
+      tensor_bf16_scalar->numel() * tensor_bf16_scalar->element_size(),
+      2); // 1 * 2 = 2 bytes
+}
+
+// Test custom strides functionality
+TEST_F(AOTITorchEmptyStridedTest, CustomStrides) {
+  // Create tensor with valid custom strides (contiguous layout)
+  std::vector<int64_t> sizes = {2, 3};
+  std::vector<int64_t> strides = {3, 1}; // Standard row-major strides
+
+  Tensor* tensor = create_tracked_tensor(sizes, strides);
+  EXPECT_NE(tensor, nullptr);
+
+  // Verify the tensor was created correctly
+  EXPECT_EQ(tensor->dim(), 2);
+  EXPECT_EQ(tensor->size(0), 2);
+  EXPECT_EQ(tensor->size(1), 3);
+
+  // Check strides through AOTI interface
+  int64_t* strides_ptr;
+  EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr), Error::Ok);
+  EXPECT_EQ(strides_ptr[0], 3);
+  EXPECT_EQ(strides_ptr[1], 1);
+
+  // Test another valid stride pattern - transpose-like
+  std::vector<int64_t> sizes_2 = {3, 2};
+  std::vector<int64_t> strides_2 = {1, 3}; // Column-major strides
+
+  Tensor* tensor_2 = create_tracked_tensor(sizes_2, strides_2);
+  EXPECT_NE(tensor_2, nullptr);
+
+  // Verify the tensor properties
+  EXPECT_EQ(tensor_2->dim(), 2);
+  EXPECT_EQ(tensor_2->size(0), 3);
+  EXPECT_EQ(tensor_2->size(1), 2);
+
+  // Check strides
+  int64_t* strides_ptr_2;
+  EXPECT_EQ(aoti_torch_get_strides(tensor_2, &strides_ptr_2), Error::Ok);
+  EXPECT_EQ(strides_ptr_2[0], 1);
+  EXPECT_EQ(strides_ptr_2[1], 3);
+}
+
+// Test edge case: zero-element tensor with non-zero dimensions
+TEST_F(AOTITorchEmptyStridedTest, ZeroElementTensor) {
+  std::vector<int64_t> sizes = {2, 0, 3}; // Total elements = 0
+  Tensor* tensor = create_tracked_tensor(sizes);
+  EXPECT_NE(tensor, nullptr);
+
+  // Verify the tensor properties
+  EXPECT_EQ(tensor->dim(), 3);
+  EXPECT_EQ(tensor->size(0), 2);
+  EXPECT_EQ(tensor->size(1), 0);
+  EXPECT_EQ(tensor->size(2), 3);
+
+  // Should be able to get metadata
+  int64_t* sizes_ptr;
+  int64_t* strides_ptr;
+  EXPECT_EQ(aoti_torch_get_sizes(tensor, &sizes_ptr), Error::Ok);
+  EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr), Error::Ok);
+
+  EXPECT_EQ(sizes_ptr[0], 2);
+  EXPECT_EQ(sizes_ptr[1], 0);
+  EXPECT_EQ(sizes_ptr[2], 3);
+}
+
+// Test different data types (currently we support bf16, fp32 and int32)
+TEST_F(AOTITorchEmptyStridedTest, DifferentDataTypes) {
+  std::vector<int64_t> sizes = {2, 3};
+
+  // Test float32 (dtype 6) - one of the supported types
+  Tensor* tensor_float32;
+  AOTITorchError error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      nullptr,
+      6, // float32
+      1, // CUDA device
+      0, // device index
+      &tensor_float32);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(tensor_float32, nullptr);
+
+  // Test int32 (dtype 3) - one of the supported types
+  Tensor* tensor_int32;
+  error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      nullptr,
+      3, // int32 - supported
+      1, // CUDA device
+      0, // device index
+      &tensor_int32);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(tensor_int32, nullptr);
+
+  // Test an unsupported data type
+  Tensor* tensor_float64;
+  error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      nullptr,
+      7, // float64 - unsupported
+      1, // CUDA device
+      0, // device index
+      &tensor_float64);
+
+  EXPECT_EQ(error, Error::InvalidArgument); // Should fail for unsupported dtype
+}
+
+// Test multi-dimensional tensors with various shapes
+TEST_F(AOTITorchEmptyStridedTest, MultiDimensionalTensors) {
+  // Test 3D tensor
+  std::vector<int64_t> sizes_3d = {2, 3, 4};
+  Tensor* tensor_3d = create_tracked_tensor(sizes_3d);
+  EXPECT_NE(tensor_3d, nullptr);
+  EXPECT_EQ(tensor_3d->dim(), 3);
+ 
EXPECT_EQ(tensor_3d->size(0), 2); + EXPECT_EQ(tensor_3d->size(1), 3); + EXPECT_EQ(tensor_3d->size(2), 4); + + // Test 4D tensor + std::vector sizes_4d = {2, 3, 4, 5}; + Tensor* tensor_4d = create_tracked_tensor(sizes_4d); + EXPECT_NE(tensor_4d, nullptr); + EXPECT_EQ(tensor_4d->dim(), 4); + EXPECT_EQ(tensor_4d->size(0), 2); + EXPECT_EQ(tensor_4d->size(1), 3); + EXPECT_EQ(tensor_4d->size(2), 4); + EXPECT_EQ(tensor_4d->size(3), 5); + + // Test 5D tensor + std::vector sizes_5d = {1, 2, 3, 4, 5}; + Tensor* tensor_5d = create_tracked_tensor(sizes_5d); + EXPECT_NE(tensor_5d, nullptr); + EXPECT_EQ(tensor_5d->dim(), 5); + EXPECT_EQ(tensor_5d->size(0), 1); + EXPECT_EQ(tensor_5d->size(1), 2); + EXPECT_EQ(tensor_5d->size(2), 3); + EXPECT_EQ(tensor_5d->size(3), 4); + EXPECT_EQ(tensor_5d->size(4), 5); +} + +// Test incontiguous tensor creation - transpose-like layout +TEST_F(AOTITorchEmptyStridedTest, IncontiguousTransposeLayout) { + // Create a tensor with transpose-like strides (column-major) + // For a 3x4 tensor in column-major order, strides should be [1, 3] + // This means each row step is 1, and each column step is 3 + std::vector sizes = {3, 4}; + std::vector strides = {1, 3}; // Column-major (incontiguous) + + Tensor* tensor; + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + strides.data(), + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + + // Verify tensor properties + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 3); + EXPECT_EQ(tensor->size(1), 4); + + // Verify the strides are what we specified + int64_t* strides_ptr; + EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr), Error::Ok); + EXPECT_EQ(strides_ptr[0], 1); // Column-major stride for dimension 0 + EXPECT_EQ(strides_ptr[1], 3); // Column-major stride for dimension 1 + + // Verify that memory was allocated correctly for incontiguous layout + // Storage size should be: stride[0] * (size[0] - 1) + stride[1] * (size[1] - + // 1) + 1 = 1 * (3 - 1) + 3 * (4 - 1) + 1 = 1 * 2 + 3 * 3 + 1 = 2 + 9 + 1 = 12 + // elements Total bytes = 12 * 4 = 48 bytes (for float32) + EXPECT_EQ(tensor->numel(), 12); // numel is still 3*4=12 for logical shape + + // The tensor should be accessible and writable + void* data_ptr = tensor->mutable_data_ptr(); + EXPECT_NE(data_ptr, nullptr); + + // Verify we can use CUDA to write to the memory + std::vector test_data(12, 1.0f); + cudaError_t cuda_err = cudaMemcpy( + data_ptr, test_data.data(), 12 * sizeof(float), cudaMemcpyHostToDevice); + EXPECT_EQ(cuda_err, cudaSuccess); +} + +// Test incontiguous tensor creation - expanded/broadcasted stride pattern +TEST_F(AOTITorchEmptyStridedTest, IncontiguousExpandedStrides) { + // Create a tensor with expanded strides (simulating broadcasting) + // A 2x3x4 tensor where the first dimension has stride 0 (expanded) + // This creates a tensor where the first dimension is "broadcasted" + std::vector sizes = {2, 3, 4}; + std::vector strides = {0, 4, 1}; // First dimension has stride 0 + + Tensor* tensor; + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + strides.data(), + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA), + 0, // device index + &tensor); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(tensor, nullptr); + + // Verify tensor properties + EXPECT_EQ(tensor->dim(), 3); + EXPECT_EQ(tensor->size(0), 2); + EXPECT_EQ(tensor->size(1), 3); + 
EXPECT_EQ(tensor->size(2), 4); + + // Verify the strides are what we specified + int64_t* strides_ptr; + EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr), Error::Ok); + EXPECT_EQ(strides_ptr[0], 0); // Expanded dimension stride + EXPECT_EQ(strides_ptr[1], 4); + EXPECT_EQ(strides_ptr[2], 1); + + // Verify that memory was allocated correctly for this incontiguous layout + // Storage size should be: stride[0] * (size[0] - 1) + stride[1] * (size[1] - + // 1) + stride[2] * (size[2] - 1) + 1 = 0 * (2 - 1) + 4 * (3 - 1) + 1 * (4 - + // 1) + 1 = 0 + 8 + 3 + 1 = 12 elements Note: numel() returns logical number + // of elements (2*3*4=24), not storage size + EXPECT_EQ(tensor->numel(), 24); // Logical numel is 2*3*4=24 + + // The tensor should be accessible and writable + void* data_ptr = tensor->mutable_data_ptr(); + EXPECT_NE(data_ptr, nullptr); + + // Verify we can use CUDA to write to the allocated memory + // We only need to allocate 12 elements (storage size), not 24 + std::vector test_data(12, 2.0f); + cudaError_t cuda_err = cudaMemcpy( + data_ptr, test_data.data(), 12 * sizeof(float), cudaMemcpyHostToDevice); + EXPECT_EQ(cuda_err, cudaSuccess); +} diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_new_tensor_handle.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_new_tensor_handle.cpp new file mode 100644 index 00000000000..d123443cbfa --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_new_tensor_handle.cpp @@ -0,0 +1,560 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace executorch::backends::aoti; +using namespace executorch::backends::cuda; +using namespace executorch::runtime; +using executorch::runtime::etensor::Tensor; + +// Test fixture for aoti_torch_new_tensor_handle tests +class AOTITorchNewTensorHandleTest : public ::testing::Test { + protected: + void SetUp() override { + // Initialize ExecuTorch Platform Abstraction Layer + et_pal_init(); + + // Check if CUDA is available + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available, skipping CUDA tests"; + } + + // Clean up any existing cached metadata before each test + cleanup_tensor_metadata(); + + // Clear any remaining tensors from previous tests + clear_all_tensors(); + } + + void TearDown() override { + // Clean up metadata + cleanup_tensor_metadata(); + + // Clear the global tensor storage using the provided function + clear_all_tensors(); + } + + // Helper to create test tensors + Tensor* create_test_tensor( + const std::vector& sizes, + const std::vector& strides = {}, + int32_t dtype = static_cast(SupportedDTypes::FLOAT32), + int32_t device_type = static_cast(SupportedDevices::CUDA), + int32_t device_index = 0) { + Tensor* tensor; + + const int64_t* strides_ptr = strides.empty() ? nullptr : strides.data(); + + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + strides_ptr, + dtype, + device_type, + device_index, + &tensor); + + return (error == Error::Ok) ? 
tensor : nullptr; + } +}; + +// Test basic functionality of creating a new tensor handle +TEST_F(AOTITorchNewTensorHandleTest, BasicFunctionality) { + // Create an original tensor + std::vector sizes = {2, 3}; + Tensor* orig_tensor = create_test_tensor(sizes); + ASSERT_NE(orig_tensor, nullptr); + + // Create a new handle from the original tensor + Tensor* new_tensor; + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(new_tensor, nullptr); + + // Verify the new tensor has the same properties + EXPECT_EQ(new_tensor->dim(), orig_tensor->dim()); + EXPECT_EQ(new_tensor->size(0), orig_tensor->size(0)); + EXPECT_EQ(new_tensor->size(1), orig_tensor->size(1)); + EXPECT_EQ(new_tensor->numel(), orig_tensor->numel()); + + // Verify they share the same memory + EXPECT_EQ(new_tensor->mutable_data_ptr(), orig_tensor->mutable_data_ptr()); + + // Clean up + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(new_tensor), Error::Ok); +} + +// Test creating new handle from null tensor +TEST_F(AOTITorchNewTensorHandleTest, NullOriginalTensor) { + Tensor* new_tensor; + AOTITorchError error = aoti_torch_new_tensor_handle(nullptr, &new_tensor); + + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test passing null pointer for new handle +TEST_F(AOTITorchNewTensorHandleTest, NullNewHandle) { + std::vector sizes = {2, 3}; + Tensor* orig_tensor = create_test_tensor(sizes); + ASSERT_NE(orig_tensor, nullptr); + + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, nullptr); + + EXPECT_EQ(error, Error::InvalidArgument); + + // Clean up + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); +} + +// Test memory sharing between original and new tensor handle +TEST_F(AOTITorchNewTensorHandleTest, MemorySharing) { + // Create an original tensor + std::vector sizes = {3, 4}; + Tensor* orig_tensor = create_test_tensor(sizes); + ASSERT_NE(orig_tensor, nullptr); + + // Get original memory pointer + void* orig_ptr = orig_tensor->mutable_data_ptr(); + ASSERT_NE(orig_ptr, nullptr); + + // Create a new handle + Tensor* new_tensor; + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(new_tensor, nullptr); + + // Verify both tensors point to the same memory + void* new_ptr = new_tensor->mutable_data_ptr(); + EXPECT_EQ(orig_ptr, new_ptr); + + // Clean up - deleting one should not affect the other's validity + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + + // New tensor should still be valid and accessible + void* still_valid_ptr = new_tensor->mutable_data_ptr(); + EXPECT_EQ(still_valid_ptr, new_ptr); + + EXPECT_EQ(aoti_torch_delete_tensor_object(new_tensor), Error::Ok); +} + +// Test creating multiple handles from the same tensor +TEST_F(AOTITorchNewTensorHandleTest, MultipleHandles) { + // Create an original tensor + std::vector sizes = {2, 3}; + Tensor* orig_tensor = create_test_tensor(sizes); + ASSERT_NE(orig_tensor, nullptr); + + void* orig_ptr = orig_tensor->mutable_data_ptr(); + + // Create multiple handles + std::vector handles; + const int num_handles = 5; + + for (int i = 0; i < num_handles; i++) { + Tensor* new_tensor; + AOTITorchError error = + aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(new_tensor, nullptr); + EXPECT_EQ(new_tensor->mutable_data_ptr(), orig_ptr); + handles.push_back(new_tensor); + } + + // 
Delete original tensor + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + + // All handles should still be valid + for (Tensor* handle : handles) { + EXPECT_EQ(handle->mutable_data_ptr(), orig_ptr); + EXPECT_EQ(handle->dim(), 2); + EXPECT_EQ(handle->size(0), 2); + EXPECT_EQ(handle->size(1), 3); + } + + // Delete all handles + for (Tensor* handle : handles) { + EXPECT_EQ(aoti_torch_delete_tensor_object(handle), Error::Ok); + } +} + +// Test creating handle from tensor with custom strides +TEST_F(AOTITorchNewTensorHandleTest, CustomStrides) { + std::vector sizes = {3, 4}; + std::vector strides = {4, 1}; // Row-major strides + Tensor* orig_tensor = create_test_tensor(sizes, strides); + ASSERT_NE(orig_tensor, nullptr); + + // Create new handle + Tensor* new_tensor; + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(new_tensor, nullptr); + + // Verify strides are preserved + int64_t* orig_strides_ptr; + int64_t* new_strides_ptr; + EXPECT_EQ(aoti_torch_get_strides(orig_tensor, &orig_strides_ptr), Error::Ok); + EXPECT_EQ(aoti_torch_get_strides(new_tensor, &new_strides_ptr), Error::Ok); + + EXPECT_EQ(orig_strides_ptr[0], new_strides_ptr[0]); + EXPECT_EQ(orig_strides_ptr[1], new_strides_ptr[1]); + + // Clean up + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(new_tensor), Error::Ok); +} + +// Test creating handle from bfloat16 tensor +TEST_F(AOTITorchNewTensorHandleTest, BFloat16Tensor) { + std::vector sizes = {2, 3, 4}; + Tensor* orig_tensor = create_test_tensor( + sizes, + {}, + static_cast(SupportedDTypes::BFLOAT16), + static_cast(SupportedDevices::CUDA)); + ASSERT_NE(orig_tensor, nullptr); + + // Verify original is bfloat16 + int32_t orig_dtype; + EXPECT_EQ(aoti_torch_get_dtype(orig_tensor, &orig_dtype), Error::Ok); + EXPECT_EQ(orig_dtype, static_cast(SupportedDTypes::BFLOAT16)); + + // Create new handle + Tensor* new_tensor; + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(new_tensor, nullptr); + + // Verify new tensor is also bfloat16 + int32_t new_dtype; + EXPECT_EQ(aoti_torch_get_dtype(new_tensor, &new_dtype), Error::Ok); + EXPECT_EQ(new_dtype, static_cast(SupportedDTypes::BFLOAT16)); + + // Verify element size (bfloat16 should be 2 bytes) + EXPECT_EQ(new_tensor->element_size(), 2); + + // Clean up + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(new_tensor), Error::Ok); +} + +// Test creating handle from scalar (0D) tensor +TEST_F(AOTITorchNewTensorHandleTest, ScalarTensor) { + std::vector sizes = {}; + Tensor* orig_tensor = create_test_tensor(sizes); + ASSERT_NE(orig_tensor, nullptr); + EXPECT_EQ(orig_tensor->dim(), 0); + + // Create new handle + Tensor* new_tensor; + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(new_tensor, nullptr); + + // Verify scalar properties + EXPECT_EQ(new_tensor->dim(), 0); + EXPECT_EQ(new_tensor->numel(), 1); + + // Clean up + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(new_tensor), Error::Ok); +} + +// Test creating handle from zero-sized tensor +TEST_F(AOTITorchNewTensorHandleTest, ZeroSizedTensor) { + std::vector sizes = {0, 5}; + Tensor* orig_tensor = create_test_tensor(sizes); + ASSERT_NE(orig_tensor, nullptr); + 
EXPECT_EQ(orig_tensor->numel(), 0); + + // Attempt to create new handle - should fail because zero-sized tensors have + // null data pointers + Tensor* new_tensor = nullptr; + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + + // Zero-sized tensors are not currently supported + EXPECT_EQ(error, Error::InvalidArgument); + EXPECT_EQ(new_tensor, nullptr); + + // Clean up original tensor + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); +} + +// Test creating handle from large multi-dimensional tensor +TEST_F(AOTITorchNewTensorHandleTest, LargeMultiDimensionalTensor) { + std::vector sizes = {10, 20, 30}; + Tensor* orig_tensor = create_test_tensor(sizes); + ASSERT_NE(orig_tensor, nullptr); + + // Create new handle + Tensor* new_tensor; + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(new_tensor, nullptr); + + // Verify dimensions + EXPECT_EQ(new_tensor->dim(), 3); + EXPECT_EQ(new_tensor->size(0), 10); + EXPECT_EQ(new_tensor->size(1), 20); + EXPECT_EQ(new_tensor->size(2), 30); + EXPECT_EQ(new_tensor->numel(), 6000); + + // Clean up + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(new_tensor), Error::Ok); +} + +// Test creating handle preserves tensor metadata +TEST_F(AOTITorchNewTensorHandleTest, MetadataPreservation) { + std::vector sizes = {2, 3, 4}; + std::vector strides = {12, 4, 1}; + Tensor* orig_tensor = create_test_tensor( + sizes, + strides, + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA)); + ASSERT_NE(orig_tensor, nullptr); + + // Create new handle + Tensor* new_tensor; + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(new_tensor, nullptr); + + // Get and compare all metadata + int64_t* orig_sizes_ptr; + int64_t* new_sizes_ptr; + int64_t* orig_strides_ptr; + int64_t* new_strides_ptr; + int32_t orig_dtype, new_dtype; + int32_t orig_device_type, new_device_type; + int32_t orig_device_index, new_device_index; + + EXPECT_EQ(aoti_torch_get_sizes(orig_tensor, &orig_sizes_ptr), Error::Ok); + EXPECT_EQ(aoti_torch_get_sizes(new_tensor, &new_sizes_ptr), Error::Ok); + EXPECT_EQ(aoti_torch_get_strides(orig_tensor, &orig_strides_ptr), Error::Ok); + EXPECT_EQ(aoti_torch_get_strides(new_tensor, &new_strides_ptr), Error::Ok); + EXPECT_EQ(aoti_torch_get_dtype(orig_tensor, &orig_dtype), Error::Ok); + EXPECT_EQ(aoti_torch_get_dtype(new_tensor, &new_dtype), Error::Ok); + EXPECT_EQ( + aoti_torch_get_device_type(orig_tensor, &orig_device_type), Error::Ok); + EXPECT_EQ( + aoti_torch_get_device_type(new_tensor, &new_device_type), Error::Ok); + EXPECT_EQ( + aoti_torch_get_device_index(orig_tensor, &orig_device_index), Error::Ok); + EXPECT_EQ( + aoti_torch_get_device_index(new_tensor, &new_device_index), Error::Ok); + + // Verify all metadata matches + for (int i = 0; i < 3; i++) { + EXPECT_EQ(orig_sizes_ptr[i], new_sizes_ptr[i]); + EXPECT_EQ(orig_strides_ptr[i], new_strides_ptr[i]); + } + EXPECT_EQ(orig_dtype, new_dtype); + EXPECT_EQ(orig_device_type, new_device_type); + EXPECT_EQ(orig_device_index, new_device_index); + + // Clean up + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(new_tensor), Error::Ok); +} + +// Test creating handle chain: orig -> handle1 -> handle2 +TEST_F(AOTITorchNewTensorHandleTest, HandleChain) { + std::vector sizes = {2, 3}; + 
Tensor* orig_tensor = create_test_tensor(sizes); + ASSERT_NE(orig_tensor, nullptr); + + void* orig_ptr = orig_tensor->mutable_data_ptr(); + + // Create first handle + Tensor* handle1; + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, &handle1); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(handle1, nullptr); + EXPECT_EQ(handle1->mutable_data_ptr(), orig_ptr); + + // Create second handle from the first handle + Tensor* handle2; + error = aoti_torch_new_tensor_handle(handle1, &handle2); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(handle2, nullptr); + EXPECT_EQ(handle2->mutable_data_ptr(), orig_ptr); + + // Delete in reverse order + EXPECT_EQ(aoti_torch_delete_tensor_object(handle2), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(handle1), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); +} + +// Test creating handle and verifying reference counting +TEST_F(AOTITorchNewTensorHandleTest, ReferenceCountingTest) { + std::vector sizes = {2, 3}; + Tensor* orig_tensor = create_test_tensor(sizes); + ASSERT_NE(orig_tensor, nullptr); + + void* orig_ptr = orig_tensor->mutable_data_ptr(); + + // Create multiple handles + Tensor* handle1; + Tensor* handle2; + Tensor* handle3; + + EXPECT_EQ(aoti_torch_new_tensor_handle(orig_tensor, &handle1), Error::Ok); + EXPECT_EQ(aoti_torch_new_tensor_handle(orig_tensor, &handle2), Error::Ok); + EXPECT_EQ(aoti_torch_new_tensor_handle(orig_tensor, &handle3), Error::Ok); + + // Delete original + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + + // All handles should still be valid + EXPECT_EQ(handle1->mutable_data_ptr(), orig_ptr); + EXPECT_EQ(handle2->mutable_data_ptr(), orig_ptr); + EXPECT_EQ(handle3->mutable_data_ptr(), orig_ptr); + + // Delete handles one by one + EXPECT_EQ(aoti_torch_delete_tensor_object(handle1), Error::Ok); + + // Remaining handles should still be valid + EXPECT_EQ(handle2->mutable_data_ptr(), orig_ptr); + EXPECT_EQ(handle3->mutable_data_ptr(), orig_ptr); + + EXPECT_EQ(aoti_torch_delete_tensor_object(handle2), Error::Ok); + + // Last handle should still be valid + EXPECT_EQ(handle3->mutable_data_ptr(), orig_ptr); + + EXPECT_EQ(aoti_torch_delete_tensor_object(handle3), Error::Ok); +} + +// Test creating handle from int32 tensor +TEST_F(AOTITorchNewTensorHandleTest, Int32Tensor) { + std::vector sizes = {2, 3}; + Tensor* orig_tensor = create_test_tensor( + sizes, + {}, + 3, // int32 + static_cast(SupportedDevices::CUDA)); + ASSERT_NE(orig_tensor, nullptr); + + // Create new handle + Tensor* new_tensor; + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(new_tensor, nullptr); + + // Verify dtype + int32_t new_dtype; + EXPECT_EQ(aoti_torch_get_dtype(new_tensor, &new_dtype), Error::Ok); + EXPECT_EQ(new_dtype, 3); // int32 + + // Clean up + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(new_tensor), Error::Ok); +} + +// Test creating handle with incontiguous tensor (transpose-like layout) +TEST_F(AOTITorchNewTensorHandleTest, IncontiguousTransposeLayout) { + std::vector sizes = {3, 4}; + std::vector strides = {1, 3}; // Column-major (incontiguous) + Tensor* orig_tensor = create_test_tensor(sizes, strides); + ASSERT_NE(orig_tensor, nullptr); + + // Create new handle + Tensor* new_tensor; + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(new_tensor, nullptr); + + // Verify 
strides are preserved + int64_t* new_strides_ptr; + EXPECT_EQ(aoti_torch_get_strides(new_tensor, &new_strides_ptr), Error::Ok); + EXPECT_EQ(new_strides_ptr[0], 1); + EXPECT_EQ(new_strides_ptr[1], 3); + + // Verify both tensors share the same memory + EXPECT_EQ(new_tensor->mutable_data_ptr(), orig_tensor->mutable_data_ptr()); + + // Clean up + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(new_tensor), Error::Ok); +} + +// Test creating handle with expanded strides (broadcasted dimension) +TEST_F(AOTITorchNewTensorHandleTest, ExpandedStrides) { + std::vector sizes = {2, 3, 4}; + std::vector strides = {0, 4, 1}; // First dimension has stride 0 + Tensor* orig_tensor = create_test_tensor(sizes, strides); + ASSERT_NE(orig_tensor, nullptr); + + // Create new handle + Tensor* new_tensor; + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(new_tensor, nullptr); + + // Verify expanded strides are preserved + int64_t* new_strides_ptr; + EXPECT_EQ(aoti_torch_get_strides(new_tensor, &new_strides_ptr), Error::Ok); + EXPECT_EQ(new_strides_ptr[0], 0); + EXPECT_EQ(new_strides_ptr[1], 4); + EXPECT_EQ(new_strides_ptr[2], 1); + + // Clean up + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(new_tensor), Error::Ok); +} + +// Stress test: create many handles +TEST_F(AOTITorchNewTensorHandleTest, StressTestManyHandles) { + std::vector sizes = {2, 3}; + Tensor* orig_tensor = create_test_tensor(sizes); + ASSERT_NE(orig_tensor, nullptr); + + void* orig_ptr = orig_tensor->mutable_data_ptr(); + + // Create many handles + const int num_handles = 100; + std::vector handles; + + for (int i = 0; i < num_handles; i++) { + Tensor* new_tensor; + AOTITorchError error = + aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(new_tensor, nullptr); + EXPECT_EQ(new_tensor->mutable_data_ptr(), orig_ptr); + handles.push_back(new_tensor); + } + + // Delete original + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + + // All handles should still be valid + for (Tensor* handle : handles) { + EXPECT_EQ(handle->mutable_data_ptr(), orig_ptr); + } + + // Delete all handles + for (Tensor* handle : handles) { + EXPECT_EQ(aoti_torch_delete_tensor_object(handle), Error::Ok); + } +} diff --git a/backends/cuda/runtime/tensor/tensor_maker.cpp b/backends/cuda/runtime/tensor/tensor_maker.cpp new file mode 100644 index 00000000000..01252082bfc --- /dev/null +++ b/backends/cuda/runtime/tensor/tensor_maker.cpp @@ -0,0 +1,126 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +namespace executorch::backends::cuda { + +namespace { +#ifndef USE_ATEN_LIB +/** + * A structure that consolidates the metadata (sizes, dim_order, strides) and + * the data buffer associated with a Tensor. Since Tensor does not own + * the memory for these metadata arrays or the data itself, this structure + * ensures that they are managed together and have the same lifetime as the + * Tensor. When the Tensor is destroyed, the Storage structure ensures + * proper cleanup of the associated metadata and data if needed. 
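+ *
+ * Note: make_tensor() below hands the Storage to the std::shared_ptr
+ * aliasing constructor, so the returned TensorPtr points at storage->tensor
+ * while co-owning the Storage; the metadata vectors and the user-supplied
+ * deleter therefore live exactly as long as the returned tensor handle.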
+ */ +struct Storage final { + executorch::aten::TensorImpl tensor_impl; + executorch::aten::Tensor tensor; + std::vector sizes; + std::vector dim_order; + std::vector strides; + std::function deleter; + + Storage( + executorch::aten::TensorImpl&& tensor_impl, + std::vector&& sizes, + std::vector&& dim_order, + std::vector&& strides, + std::function&& deleter) + : tensor_impl(std::move(tensor_impl)), + tensor(&this->tensor_impl), + sizes(std::move(sizes)), + dim_order(std::move(dim_order)), + strides(std::move(strides)), + deleter(std::move(deleter)) {} + + ~Storage() { + if (deleter) { + deleter(tensor_impl.mutable_data()); + } + } +}; +#endif // USE_ATEN_LIB +} // namespace + +TensorPtr make_tensor( + std::vector sizes, + void* data, + std::vector dim_order, + std::vector strides, + executorch::aten::ScalarType type, + executorch::aten::TensorShapeDynamism dynamism, + std::function deleter) { + const auto dim = sizes.size(); + ET_CHECK_MSG( + dim_order.empty() || dim_order.size() == dim, + "dim_order size must match sizes or be empty."); + ET_CHECK_MSG( + strides.empty() || strides.size() == dim, + "strides size must match sizes or be empty."); + + if (dim_order.empty()) { + dim_order.resize(dim); + std::iota(dim_order.begin(), dim_order.end(), 0); + if (!strides.empty()) { + std::sort(dim_order.begin(), dim_order.end(), [&](size_t a, size_t b) { + return strides[a] > strides[b]; + }); + } + } + + // AOTI backends (like AOTI-CUDA) handle both contiguous and incontiguous + // tensors, so we skip stride calculation and incontiguous tensor checks. + // Strides are passed through as-is without validation. + +#ifndef USE_ATEN_LIB + executorch::aten::TensorImpl tensor_impl( + type, + dim, + sizes.data(), + data, + dim_order.data(), + strides.data(), + dim > 0 ? dynamism : executorch::aten::TensorShapeDynamism::STATIC); + auto storage = std::make_shared( + std::move(tensor_impl), + std::move(sizes), + std::move(dim_order), + std::move(strides), + std::move(deleter)); + const auto tensor_ptr = &storage->tensor; + return std::shared_ptr( + std::move(storage), tensor_ptr); +#else + auto options = c10::TensorOptions() + .dtype(c10::scalarTypeToTypeMeta(type)) + .device(c10::kCPU); + auto storage = c10::Storage( + c10::Storage::use_byte_size_t(), + at::detail::computeStorageNbytes( + sizes, strides, options.dtype().itemsize()), + c10::InefficientStdFunctionContext::makeDataPtr( + data, std::move(deleter), options.device()), + nullptr, + false); + auto tensor_impl = c10::make_intrusive( + std::move(storage), + c10::DispatchKeySet(c10::DispatchKey::CPU), + options.dtype()); + tensor_impl->set_sizes_and_strides(sizes, strides); + return std::make_shared(std::move(tensor_impl)); +#endif // USE_ATEN_LIB +} + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/tensor/tensor_maker.h b/backends/cuda/runtime/tensor/tensor_maker.h new file mode 100644 index 00000000000..92cdec60bb4 --- /dev/null +++ b/backends/cuda/runtime/tensor/tensor_maker.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include +#include + +namespace executorch::backends::cuda { + +/** + * A smart pointer type for managing the lifecycle of a Tensor. + * This is compatible with executorch::extension::TensorPtr. 
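+ *
+ * Illustrative sketch (not part of this change): wrapping an externally
+ * owned, column-major 3x4 float buffer without copying it, via the
+ * make_tensor() factory declared below ("data_ptr" here is a hypothetical
+ * device buffer pointer):
+ *
+ *   auto tensor = make_tensor(
+ *       {3, 4}, data_ptr, /*dim_order=*/{}, /*strides=*/{1, 3},
+ *       executorch::aten::ScalarType::Float);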
+ */ +using TensorPtr = std::shared_ptr; + +/** + * Creates a TensorPtr for AOTI backends that skips stride calculation and + * incontiguous tensor checks. This is specifically designed for AOTI-CUDA + * which handles both contiguous and incontiguous tensors. + * + * This function is similar to executorch::extension::make_tensor_ptr but + * bypasses the stride validation that assumes contiguous tensors, making it + * suitable for AOTI backends that support arbitrary strides. + * + * @param sizes A vector specifying the size of each dimension. + * @param data A pointer to the data buffer. + * @param dim_order A vector specifying the order of dimensions. + * @param strides A vector specifying the strides of the tensor. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies the mutability of the tensor's shape. + * @param deleter A custom deleter function for managing the lifetime of the + * data buffer. If provided, this deleter will be called when the managed Tensor + * object is destroyed. + * @return A TensorPtr that manages the newly created Tensor. + */ +TensorPtr make_tensor( + std::vector sizes, + void* data, + std::vector dim_order, + std::vector strides, + executorch::aten::ScalarType type = executorch::aten::ScalarType::Float, + executorch::aten::TensorShapeDynamism dynamism = + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND, + std::function deleter = nullptr); + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/tests/TARGETS b/backends/cuda/runtime/tests/TARGETS new file mode 100644 index 00000000000..9ff3e83a8bd --- /dev/null +++ b/backends/cuda/runtime/tests/TARGETS @@ -0,0 +1,6 @@ +load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest") +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/backends/cuda/runtime/tests/targets.bzl b/backends/cuda/runtime/tests/targets.bzl new file mode 100644 index 00000000000..37e8d876526 --- /dev/null +++ b/backends/cuda/runtime/tests/targets.bzl @@ -0,0 +1,27 @@ +load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest") + +def cuda_runtime_cpp_unittest(name): + cpp_unittest( + name = "test_" + name, + srcs = [ + "test_" + name + ".cpp", + ], + deps = [ + "//executorch/backends/cuda/runtime:runtime_shims", + "//executorch/runtime/core:core", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/platform:platform", + ], + external_deps = [ + ("cuda", None, "cuda-lazy"), + ], + ) + +def define_common_targets(): + """Defines targets that should be shared between fbcode and xplat. + + The directory containing this targets.bzl file should also contain both + TARGETS and BUCK files that call this function. + """ + cuda_runtime_cpp_unittest("cuda_guard") + cuda_runtime_cpp_unittest("cuda_stream_guard") diff --git a/backends/cuda/runtime/tests/test_cuda_guard.cpp b/backends/cuda/runtime/tests/test_cuda_guard.cpp new file mode 100644 index 00000000000..a364ae98484 --- /dev/null +++ b/backends/cuda/runtime/tests/test_cuda_guard.cpp @@ -0,0 +1,113 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +using namespace executorch::backends::cuda; +using namespace executorch::runtime; + +// TODO(gasoonjia): Multiple device tests were not included due to test +// environment limitations. 
These tests should be added in the future when +// multi-GPU test environments are available, + +class CUDAGuardTest : public ::testing::Test { + protected: + void SetUp() override { + et_pal_init(); + + int device_count = 0; + cudaError_t error = cudaGetDeviceCount(&device_count); + if (error != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available or no CUDA devices found"; + } + device_count_ = device_count; + + ASSERT_EQ(cudaGetDevice(&original_device_), cudaSuccess); + } + + void TearDown() override { + if (device_count_ > 0) { + ASSERT_EQ(cudaSetDevice(original_device_), cudaSuccess); + } + } + + int device_count_ = 0; + int original_device_ = 0; +}; + +TEST_F(CUDAGuardTest, BasicDeviceSwitching) { + int current_device; + ASSERT_EQ(cudaGetDevice(¤t_device), cudaSuccess); + + { + auto guard_result = CUDAGuard::create(0); + ASSERT_TRUE(guard_result.ok()); + CUDAGuard guard = std::move(guard_result.get()); + + int device_after_guard; + ASSERT_EQ(cudaGetDevice(&device_after_guard), cudaSuccess); + EXPECT_EQ(device_after_guard, 0); + EXPECT_EQ(guard.current_device(), 0); + EXPECT_EQ(guard.original_device(), current_device); + } + + int device_after_destruction; + ASSERT_EQ(cudaGetDevice(&device_after_destruction), cudaSuccess); + EXPECT_EQ(device_after_destruction, current_device); +} + +TEST_F(CUDAGuardTest, SameDeviceNoSwitching) { + ASSERT_EQ(cudaSetDevice(0), cudaSuccess); + + { + auto guard_result = CUDAGuard::create(0); + ASSERT_TRUE(guard_result.ok()); + CUDAGuard guard = std::move(guard_result.get()); + + int current_device; + ASSERT_EQ(cudaGetDevice(¤t_device), cudaSuccess); + EXPECT_EQ(current_device, 0); + EXPECT_EQ(guard.current_device(), 0); + EXPECT_EQ(guard.original_device(), 0); + } + + int final_device; + ASSERT_EQ(cudaGetDevice(&final_device), cudaSuccess); + EXPECT_EQ(final_device, 0); +} + +TEST_F(CUDAGuardTest, InvalidDeviceIndex) { + auto guard_result = CUDAGuard::create(999); + EXPECT_FALSE(guard_result.ok()); +} + +TEST_F(CUDAGuardTest, NegativeDeviceIndex) { + auto guard_result = CUDAGuard::create(-2); + EXPECT_FALSE(guard_result.ok()); +} + +TEST_F(CUDAGuardTest, CopyConstructorDeleted) { + static_assert( + !std::is_copy_constructible_v, + "CUDAGuard should not be copy constructible"); +} + +TEST_F(CUDAGuardTest, CopyAssignmentDeleted) { + static_assert( + !std::is_copy_assignable_v, + "CUDAGuard should not be copy assignable"); +} + +TEST_F(CUDAGuardTest, MoveAssignmentDeleted) { + static_assert( + !std::is_move_assignable_v, + "CUDAGuard should not be move assignable"); +} diff --git a/backends/cuda/runtime/tests/test_cuda_stream_guard.cpp b/backends/cuda/runtime/tests/test_cuda_stream_guard.cpp new file mode 100644 index 00000000000..68a050a69be --- /dev/null +++ b/backends/cuda/runtime/tests/test_cuda_stream_guard.cpp @@ -0,0 +1,264 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +using namespace executorch::backends::cuda; +using namespace executorch::runtime; + +// TODO(gasoonjia): Multiple device tests were not included due to test +// environment limitations. 
These tests should be added in the future when +// multi-GPU test environments are available, + +class CUDAStreamGuardTest : public ::testing::Test { + protected: + void SetUp() override { + et_pal_init(); + + int device_count = 0; + cudaError_t error = cudaGetDeviceCount(&device_count); + if (error != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available or no CUDA devices found"; + } + device_count_ = device_count; + + ASSERT_EQ(cudaGetDevice(&original_device_), cudaSuccess); + + ASSERT_EQ(cudaStreamCreate(&test_stream1_), cudaSuccess); + ASSERT_EQ(cudaStreamCreate(&test_stream2_), cudaSuccess); + } + + void TearDown() override { + if (test_stream1_) { + ASSERT_EQ(cudaStreamDestroy(test_stream1_), cudaSuccess); + } + if (test_stream2_) { + ASSERT_EQ(cudaStreamDestroy(test_stream2_), cudaSuccess); + } + + if (device_count_ > 0) { + ASSERT_EQ(cudaSetDevice(original_device_), cudaSuccess); + } + } + + int device_count_ = 0; + int original_device_ = 0; + cudaStream_t test_stream1_ = nullptr; + cudaStream_t test_stream2_ = nullptr; +}; + +TEST_F(CUDAStreamGuardTest, BasicStreamSwitching) { + auto guard_result = CUDAStreamGuard::create(test_stream1_, 0); + ASSERT_TRUE(guard_result.ok()); + CUDAStreamGuard guard = std::move(guard_result.get()); + + EXPECT_EQ(guard.stream(), test_stream1_); + EXPECT_EQ(guard.device_index(), 0); + + auto current_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(current_stream_result.ok()); + EXPECT_EQ(current_stream_result.get(), test_stream1_); + + int current_device; + ASSERT_EQ(cudaGetDevice(¤t_device), cudaSuccess); + EXPECT_EQ(current_device, 0); +} + +TEST_F(CUDAStreamGuardTest, StreamSwitchingOnSameDevice) { + Error err = setCurrentCUDAStream(test_stream1_, 0); + ASSERT_EQ(err, Error::Ok); + + auto current_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(current_stream_result.ok()); + EXPECT_EQ(current_stream_result.get(), test_stream1_); + + { + auto guard_result = CUDAStreamGuard::create(test_stream2_, 0); + ASSERT_TRUE(guard_result.ok()); + CUDAStreamGuard guard = std::move(guard_result.get()); + + auto new_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(new_stream_result.ok()); + EXPECT_EQ(new_stream_result.get(), test_stream2_); + EXPECT_EQ(guard.stream(), test_stream2_); + } + + auto restored_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(restored_stream_result.ok()); + EXPECT_EQ(restored_stream_result.get(), test_stream1_); +} + +TEST_F(CUDAStreamGuardTest, NestedStreamGuards) { + cudaStream_t initial_stream; + ASSERT_EQ(cudaStreamCreate(&initial_stream), cudaSuccess); + + Error err = setCurrentCUDAStream(initial_stream, 0); + ASSERT_EQ(err, Error::Ok); + + { + auto guard1_result = CUDAStreamGuard::create(test_stream1_, 0); + ASSERT_TRUE(guard1_result.ok()); + CUDAStreamGuard guard1 = std::move(guard1_result.get()); + + auto stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(stream_result.ok()); + EXPECT_EQ(stream_result.get(), test_stream1_); + + { + auto guard2_result = CUDAStreamGuard::create(test_stream2_, 0); + ASSERT_TRUE(guard2_result.ok()); + CUDAStreamGuard guard2 = std::move(guard2_result.get()); + + auto stream_result2 = getCurrentCUDAStream(0); + ASSERT_TRUE(stream_result2.ok()); + EXPECT_EQ(stream_result2.get(), test_stream2_); + } + + auto stream_result3 = getCurrentCUDAStream(0); + ASSERT_TRUE(stream_result3.ok()); + EXPECT_EQ(stream_result3.get(), test_stream1_); + } + + auto final_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(final_stream_result.ok()); + 
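+    // With both nested guards destroyed, the stream that was current before
+    // the guards were created should be active again.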
EXPECT_EQ(final_stream_result.get(), initial_stream); + + ASSERT_EQ(cudaStreamDestroy(initial_stream), cudaSuccess); +} + +TEST_F(CUDAStreamGuardTest, SameStreamNoChange) { + Error err = setCurrentCUDAStream(test_stream1_, 0); + ASSERT_EQ(err, Error::Ok); + + { + auto guard_result = CUDAStreamGuard::create(test_stream1_, 0); + ASSERT_TRUE(guard_result.ok()); + CUDAStreamGuard guard = std::move(guard_result.get()); + + auto stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(stream_result.ok()); + EXPECT_EQ(stream_result.get(), test_stream1_); + EXPECT_EQ(guard.stream(), test_stream1_); + } + + auto final_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(final_stream_result.ok()); + EXPECT_EQ(final_stream_result.get(), test_stream1_); +} + +TEST_F(CUDAStreamGuardTest, StreamAccessor) { + auto guard_result = CUDAStreamGuard::create(test_stream1_, 0); + ASSERT_TRUE(guard_result.ok()); + CUDAStreamGuard guard = std::move(guard_result.get()); + + EXPECT_EQ(guard.stream(), test_stream1_); + EXPECT_EQ(guard.device_index(), 0); +} + +TEST_F(CUDAStreamGuardTest, SetStreamMethod) { + auto guard_result = CUDAStreamGuard::create(test_stream1_, 0); + ASSERT_TRUE(guard_result.ok()); + CUDAStreamGuard guard = std::move(guard_result.get()); + + EXPECT_EQ(guard.stream(), test_stream1_); + + Error err = guard.set_stream(test_stream2_, 0); + EXPECT_EQ(err, Error::Ok); + + EXPECT_EQ(guard.stream(), test_stream2_); + + auto current_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(current_stream_result.ok()); + EXPECT_EQ(current_stream_result.get(), test_stream2_); +} + +TEST_F(CUDAStreamGuardTest, MoveConstructor) { + auto guard1_result = CUDAStreamGuard::create(test_stream1_, 0); + ASSERT_TRUE(guard1_result.ok()); + CUDAStreamGuard guard1 = std::move(guard1_result.get()); + + EXPECT_EQ(guard1.stream(), test_stream1_); + EXPECT_EQ(guard1.device_index(), 0); + + CUDAStreamGuard guard2 = std::move(guard1); + + EXPECT_EQ(guard2.stream(), test_stream1_); + EXPECT_EQ(guard2.device_index(), 0); + + auto current_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(current_stream_result.ok()); + EXPECT_EQ(current_stream_result.get(), test_stream1_); +} + +TEST_F(CUDAStreamGuardTest, MoveConstructorRestoresOnlyOnce) { + cudaStream_t initial_stream; + ASSERT_EQ(cudaStreamCreate(&initial_stream), cudaSuccess); + + Error err = setCurrentCUDAStream(initial_stream, 0); + ASSERT_EQ(err, Error::Ok); + + { + auto guard1_result = CUDAStreamGuard::create(test_stream1_, 0); + ASSERT_TRUE(guard1_result.ok()); + CUDAStreamGuard guard1 = std::move(guard1_result.get()); + + { CUDAStreamGuard guard2 = std::move(guard1); } + + auto stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(stream_result.ok()); + EXPECT_EQ(stream_result.get(), initial_stream); + } + + auto final_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(final_stream_result.ok()); + EXPECT_EQ(final_stream_result.get(), initial_stream); + + ASSERT_EQ(cudaStreamDestroy(initial_stream), cudaSuccess); +} + +TEST_F(CUDAStreamGuardTest, InvalidDeviceIndex) { + auto guard_result = CUDAStreamGuard::create(test_stream1_, 999); + EXPECT_FALSE(guard_result.ok()); +} + +TEST_F(CUDAStreamGuardTest, NegativeDeviceIndex) { + auto guard_result = CUDAStreamGuard::create(test_stream1_, -2); + EXPECT_FALSE(guard_result.ok()); +} + +TEST_F(CUDAStreamGuardTest, CopyConstructorDeleted) { + static_assert( + !std::is_copy_constructible_v, + "CUDAStreamGuard should not be copy constructible"); +} + +TEST_F(CUDAStreamGuardTest, CopyAssignmentDeleted) { + static_assert( + 
!std::is_copy_assignable_v, + "CUDAStreamGuard should not be copy assignable"); +} + +TEST_F(CUDAStreamGuardTest, MoveAssignmentDeleted) { + static_assert( + !std::is_move_assignable_v, + "CUDAStreamGuard should not be move assignable"); +} + +TEST_F(CUDAStreamGuardTest, NullStreamPointer) { + auto guard_result = CUDAStreamGuard::create(nullptr, 0); + ASSERT_TRUE(guard_result.ok()); + CUDAStreamGuard guard = std::move(guard_result.get()); + + EXPECT_EQ(guard.stream(), nullptr); + + auto current_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(current_stream_result.ok()); +} diff --git a/backends/cuda/runtime/utils.h b/backends/cuda/runtime/utils.h new file mode 100644 index 00000000000..4474f8cf57e --- /dev/null +++ b/backends/cuda/runtime/utils.h @@ -0,0 +1,115 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include + +// CUDA error checking macro (with return) +#define ET_CUDA_CHECK_OR_RETURN_ERROR(EXPR) \ + do { \ + const cudaError_t err = EXPR; \ + if (err == cudaSuccess) { \ + break; \ + } \ + ET_LOG( \ + Error, \ + "%s:%d CUDA error: %s", \ + __FILE__, \ + __LINE__, \ + cudaGetErrorString(err)); \ + return Error::Internal; \ + } while (0) + +// CUDA error checking macro (without return, for use in void functions) +#define ET_CUDA_CHECK(EXPR) \ + do { \ + const cudaError_t err = EXPR; \ + if (err == cudaSuccess) { \ + break; \ + } \ + ET_LOG( \ + Error, \ + "%s:%d CUDA error: %s", \ + __FILE__, \ + __LINE__, \ + cudaGetErrorString(err)); \ + ET_CHECK_MSG(false, "CUDA error: %s", cudaGetErrorString(err)); \ + } while (0) + +// Kernel launch check macro (with return) +#define ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR() \ + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetLastError()) + +// Kernel launch check macro (without return, for use in void functions) +#define ET_CUDA_KERNEL_LAUNCH_CHECK() ET_CUDA_CHECK(cudaGetLastError()) + +namespace executorch::backends::cuda { + +// Enum for supported data types in et-cuda backend +enum class SupportedDTypes : int32_t { + INT8 = 1, // PyTorch's int8 dtype code + INT16 = 2, // PyTorch's int16 dtype code + INT32 = 3, // PyTorch's int32 dtype code + INT64 = 4, // PyTorch's int64 dtype code + FLOAT32 = 6, // PyTorch's float32 dtype code + BOOL = 11, // PyTorch's bool dtype code + BFLOAT16 = 15, // PyTorch's bfloat16 dtype code +}; + +// Enum for supported device types in et-cuda backend +enum class SupportedDevices : int32_t { + CPU = 0, // CPU device + CUDA = 1, // CUDA device +}; + +extern "C" { +using executorch::runtime::Error; +// Common AOTI type aliases +using AOTITorchError = Error; + +// Helper function to check if a dtype is supported in ET CUDA backend +inline bool is_dtype_supported_in_et_cuda(int32_t dtype) { + switch (dtype) { + case static_cast(SupportedDTypes::INT8): + case static_cast(SupportedDTypes::INT16): + case static_cast(SupportedDTypes::INT32): + case static_cast(SupportedDTypes::INT64): + case static_cast(SupportedDTypes::FLOAT32): + case static_cast(SupportedDTypes::BOOL): + case static_cast(SupportedDTypes::BFLOAT16): + return true; + default: + return false; + } +} + +// Dtype validation utility function +inline AOTITorchError validate_dtype(int32_t dtype) { + ET_CHECK_OR_RETURN_ERROR( + is_dtype_supported_in_et_cuda(dtype), + InvalidArgument, + "Unsupported dtype: %d. 
Supported dtypes: %d (int8), %d (int16), %d (int32), %d (int64), %d (float32), %d (bool), %d (bfloat16)", + dtype, + static_cast<int32_t>(SupportedDTypes::INT8), + static_cast<int32_t>(SupportedDTypes::INT16), + static_cast<int32_t>(SupportedDTypes::INT32), + static_cast<int32_t>(SupportedDTypes::INT64), + static_cast<int32_t>(SupportedDTypes::FLOAT32), + static_cast<int32_t>(SupportedDTypes::BOOL), + static_cast<int32_t>(SupportedDTypes::BFLOAT16)); + + return Error::Ok; +} +} // extern "C" + +} // namespace executorch::backends::cuda
diff --git a/backends/cuda/tests/TARGETS b/backends/cuda/tests/TARGETS new file mode 100644 index 00000000000..974086cd4c5 --- /dev/null +++ b/backends/cuda/tests/TARGETS @@ -0,0 +1,42 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest") +load("@fbcode_macros//build_defs:python_unittest_remote_gpu.bzl", "python_unittest_remote_gpu") + +oncall("executorch") + +python_unittest_remote_gpu( + name = "test_cuda_export", + srcs = [ + "test_cuda_export.py", + ], + visibility = [ + "//executorch/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/backends/cuda:cuda_backend", + "//executorch/backends/cuda:cuda_partitioner", + "//executorch/exir:lib", + "//executorch/exir/backend:backend_api", + "//executorch/exir/backend:compile_spec_schema", + "//executorch/examples/models/toy_model:toy_model", + ], + keep_gpu_sections = True, +) + +python_unittest( + name = "test_cuda_partitioner", + srcs = [ + "test_cuda_partitioner.py", + ], + visibility = [ + "//executorch/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/backends/cuda:cuda_partitioner", + "//executorch/backends/cuda:cuda_backend", + "//executorch/exir:lib", + "//executorch/exir/backend:compile_spec_schema", + ], +)
diff --git a/backends/cuda/tests/__init__.py b/backends/cuda/tests/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/backends/cuda/tests/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree.
diff --git a/backends/cuda/tests/test_cuda_export.py b/backends/cuda/tests/test_cuda_export.py new file mode 100644 index 00000000000..ff4a9313545 --- /dev/null +++ b/backends/cuda/tests/test_cuda_export.py @@ -0,0 +1,327 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree.
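+
+# Overview of the lowering pipeline exercised by these tests (a sketch of the
+# _export_to_cuda_with_lower helper defined below):
+#
+#   ep = export(module, example_inputs, strict=True)
+#   partitioner = CudaPartitioner(
+#       [CudaBackend.generate_method_name_compile_spec("forward")]
+#   )
+#   edge = to_edge_transform_and_lower(
+#       ep,
+#       partitioner=[partitioner],
+#       compile_config=EdgeCompileConfig(_check_ir_validity=False),
+#   )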
+ +import unittest +from typing import Tuple + +import torch +from executorch.backends.cuda.cuda_backend import CudaBackend +from executorch.backends.cuda.cuda_partitioner import CudaPartitioner +from executorch.examples.models.toy_model import SdpaModule +from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower +from executorch.exir.backend.compile_spec_schema import CompileSpec +from torch.export import export + + +class TestCudaExport(unittest.TestCase): + """Test CUDA export functionality for various operations using to_edge_transform_and_lower.""" + + def setUp(self): + """Set up test environment.""" + # Skip tests if CUDA is not available + if not torch.cuda.is_available(): + self.skipTest("CUDA is not available") + + def _export_to_cuda_with_lower( + self, + module: torch.nn.Module, + inputs: Tuple[torch.Tensor, ...], + compile_specs: list[CompileSpec] | None = None, + ) -> None: + """Helper method to export a module to CUDA backend using to_edge_transform_and_lower. + + Args: + module: The torch.nn.Module to export + inputs: The example inputs for the module + compile_specs: Optional list of compile specs. If not provided, defaults to + only the method name compile spec for "forward" + """ + # Export the model + exported_program = export(module, inputs, strict=True) + + # Create partitioner with compile specs + if compile_specs is None: + compile_specs = [CudaBackend.generate_method_name_compile_spec("forward")] + + partitioner = CudaPartitioner(compile_specs) + + # Use to_edge_transform_and_lower for complete pipeline + edge_program_manager = to_edge_transform_and_lower( + exported_program, + partitioner=[partitioner], + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + ), + ) + + # Verify that the pipeline succeeded + self.assertIsNotNone(edge_program_manager) + self.assertTrue(hasattr(edge_program_manager, "exported_program")) + + # Verify that the final exported program contains delegated calls + exported_program = edge_program_manager.exported_program() + has_delegate_call = False + for node in exported_program.graph.nodes: + if node.op == "call_function" and "executorch_call_delegate" in str( + node.target + ): + has_delegate_call = True + break + + self.assertTrue( + has_delegate_call, "No delegate calls found in final exported program" + ) + + return edge_program_manager + + def test_simple_add(self): + """Test CUDA export for simple element-wise addition.""" + + class AddModule(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + + module = AddModule() + module.eval() + inputs = (torch.randn(3, 4), torch.randn(3, 4)) + + # Test export + edge_program_manager = self._export_to_cuda_with_lower(module, inputs) + self.assertIsNotNone(edge_program_manager, "Simple add operation export failed") + + def test_conv2d(self): + """Test CUDA export for 2D convolution.""" + + class Conv2dModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + module = Conv2dModule() + module.eval() + inputs = (torch.randn(1, 3, 32, 32),) + + # Test export + edge_program_manager = self._export_to_cuda_with_lower(module, inputs) + self.assertIsNotNone(edge_program_manager, "Conv2d operation export failed") + + def test_linear(self): + """Test CUDA export for linear layer.""" + + class LinearModule(torch.nn.Module): + def __init__(self): + super().__init__() + 
self.linear = torch.nn.Linear(128, 64) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + module = LinearModule() + module.eval() + inputs = (torch.randn(8, 128),) + + # Test export + edge_program_manager = self._export_to_cuda_with_lower(module, inputs) + self.assertIsNotNone(edge_program_manager, "Linear operation export failed") + + def test_resnet_block(self): + """Test CUDA export for a ResNet-style block.""" + + class ResNetBlock(torch.nn.Module): + def __init__(self, in_channels: int, out_channels: int, stride: int = 1): + super().__init__() + self.conv1 = torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + ) + # Use eval mode to avoid batch norm mutations during export + self.bn1 = torch.nn.BatchNorm2d(out_channels) + self.relu = torch.nn.ReLU(inplace=True) + self.conv2 = torch.nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ) + self.bn2 = torch.nn.BatchNorm2d(out_channels) + + # Shortcut connection + self.shortcut = torch.nn.Sequential() + if stride != 1 or in_channels != out_channels: + self.shortcut = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=stride, + bias=False, + ), + torch.nn.BatchNorm2d(out_channels), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + identity = self.shortcut(x) + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + out += identity + out = self.relu(out) + + return out + + module = ResNetBlock(64, 64) + # Set module to eval mode to avoid batch norm running statistics mutations + module.eval() + inputs = (torch.randn(1, 64, 32, 32),) + + # Test export + edge_program_manager = self._export_to_cuda_with_lower(module, inputs) + self.assertIsNotNone(edge_program_manager, "ResNet block export failed") + + def test_multi_operation_module(self): + """Test CUDA export for a module with multiple operations.""" + + class MultiOpModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 32, kernel_size=3, padding=1) + self.relu = torch.nn.ReLU() + self.pool = torch.nn.AdaptiveAvgPool2d((1, 1)) + self.linear = torch.nn.Linear(32, 10) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) + x = self.relu(x) + x = self.pool(x) + x = x.view(x.size(0), -1) + x = self.linear(x) + return x + + module = MultiOpModule() + module.eval() + inputs = (torch.randn(2, 3, 16, 16),) + + # Test export + edge_program_manager = self._export_to_cuda_with_lower(module, inputs) + self.assertIsNotNone( + edge_program_manager, "Multi-operation module export failed" + ) + + def test_activation_functions(self): + """Test CUDA export for various activation functions.""" + + class ActivationModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Test multiple activation functions + x1 = torch.relu(x) + x2 = torch.sigmoid(x) + x3 = torch.tanh(x) + return x1 + x2 + x3 + + module = ActivationModule() + module.eval() + inputs = (torch.randn(4, 8),) + + # Test export + edge_program_manager = self._export_to_cuda_with_lower(module, inputs) + self.assertIsNotNone(edge_program_manager, "Activation functions export failed") + + def test_mathematical_operations(self): + """Test CUDA export for mathematical operations.""" + + class MathOpsModule(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # Test various 
mathematical operations + add_result = x + y + mul_result = x * y + sub_result = x - y + div_result = x / (y + 1e-8) # Add epsilon to avoid division by zero + return add_result + mul_result + sub_result + div_result + + module = MathOpsModule() + module.eval() + inputs = (torch.randn(4, 4), torch.randn(4, 4)) + + # Test export + edge_program_manager = self._export_to_cuda_with_lower(module, inputs) + self.assertIsNotNone( + edge_program_manager, "Mathematical operations export failed" + ) + + def test_conv1d(self): + """Test CUDA export for 1D convolution.""" + + class Conv1dModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d(3, 16, kernel_size=3, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + module = Conv1dModule() + module.eval() + inputs = (torch.randn(1, 3, 10),) + + # Test export + edge_program_manager = self._export_to_cuda_with_lower(module, inputs) + self.assertIsNotNone(edge_program_manager, "Conv1d operation export failed") + + def test_sdpa_single_kernel(self): + """ + Test CUDA export for model containing single SDPA kernel. + SDPA: Scaled Dot Product Attention + """ + + sdpa = SdpaModule() + + # Test export + edge_program_manager = self._export_to_cuda_with_lower( + sdpa.get_eager_model(), sdpa.get_example_inputs() + ) + self.assertIsNotNone( + edge_program_manager, + "SDPA single kernel operation export failed", + ) + + def test_triton_kernel_mode_off(self): + """ + Test CUDA export with triton_kernel_mode set to OFF for SDPA kernel. + This validates that the backend correctly processes the triton_kernel_mode + compile spec and can export SDPA operations without Triton kernel replacements. + When triton_kernel_mode is OFF, SDPA should be decomposed using the MATH backend. + """ + + sdpa = SdpaModule() + + # Create compile specs with triton_kernel_mode set to OFF + compile_specs = [ + CudaBackend.generate_method_name_compile_spec("forward"), + CompileSpec(key="triton_kernel_mode", value=b"OFF"), + ] + + # Test export with triton_kernel_mode=OFF + edge_program_manager = self._export_to_cuda_with_lower( + sdpa.get_eager_model(), sdpa.get_example_inputs(), compile_specs + ) + self.assertIsNotNone( + edge_program_manager, + "SDPA kernel export with triton_kernel_mode=OFF failed", + ) diff --git a/backends/cuda/tests/test_cuda_partitioner.py b/backends/cuda/tests/test_cuda_partitioner.py new file mode 100644 index 00000000000..c08c0e6ff56 --- /dev/null +++ b/backends/cuda/tests/test_cuda_partitioner.py @@ -0,0 +1,224 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from typing import Tuple + +import torch +from executorch.backends.cuda.cuda_partitioner import CudaPartitioner +from executorch.exir.backend.partitioner import PartitionResult +from torch.export import export + + +class TestCudaPartitioner(unittest.TestCase): + """ + Test CUDA partitioner functionality. + + After CUDA partitioning, there should be exactly one partitioned graph that contains + all operators from the input graph. This means all operators should be tagged with + the same delegation tag, indicating they will all be executed by the CUDA backend. + """ + + def _get_partition_result( + self, module: torch.nn.Module, inputs: Tuple[torch.Tensor, ...] 
+ ) -> PartitionResult: + """Helper method to get partition result for a given module.""" + # Export the model + exported_program = export(module, inputs, strict=True) + + # Create partitioner and compile specs + partitioner = CudaPartitioner([]) + + # Get partition result + partition_result = partitioner.partition(exported_program) + + # Verify partition result structure + self.assertIsNotNone(partition_result) + self.assertTrue(hasattr(partition_result, "tagged_exported_program")) + self.assertTrue(hasattr(partition_result, "partition_tags")) + + return partition_result + + def _check_fully_partitioned(self, partition_result: PartitionResult) -> bool: + """Check if the graph is fully partitioned (all operators have the same tag).""" + tagged_nodes = [] + untagged_ops = [] + + for node in partition_result.tagged_exported_program.graph.nodes: + if node.op == "call_function": + if hasattr(node, "meta") and "delegation_tag" in node.meta: + tagged_nodes.append(node) + else: + untagged_ops.append(node) + + # Check if we have any tagged nodes + if not tagged_nodes: + return False + + # Check if all tagged nodes have the same tag + first_tag = tagged_nodes[0].meta["delegation_tag"] + all_same_tag = all( + node.meta.get("delegation_tag") == first_tag for node in tagged_nodes + ) + + # Should have no untagged operations for full partitioning + fully_partitioned = len(untagged_ops) == 0 and all_same_tag + + return fully_partitioned + + def test_simple_add_partition(self): + """ + Test that CUDA partitioner creates exactly one partition containing all operators. + Simple element-wise addition should result in a single graph with all ops tagged identically. + """ + + class AddModule(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + + module = AddModule() + inputs = (torch.randn(3, 4), torch.randn(3, 4)) + + partition_result = self._get_partition_result(module, inputs) + fully_partitioned = self._check_fully_partitioned(partition_result) + + self.assertTrue( + fully_partitioned, + "Graph should be fully partitioned with all operators having the same tag", + ) + + def test_conv2d_partition(self): + """ + Test that CUDA partitioner creates exactly one partition containing all operators. + Conv2D operation should result in a single graph with all ops tagged identically. + """ + + class Conv2dModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + module = Conv2dModule() + inputs = (torch.randn(1, 3, 32, 32),) + + partition_result = self._get_partition_result(module, inputs) + fully_partitioned = self._check_fully_partitioned(partition_result) + + self.assertTrue( + fully_partitioned, + "Graph should be fully partitioned with all operators having the same tag", + ) + + def test_linear_partition(self): + """ + Test that CUDA partitioner creates exactly one partition containing all operators. + Linear layer operation should result in a single graph with all ops tagged identically. 
+ """ + + class LinearModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(128, 64) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + module = LinearModule() + inputs = (torch.randn(8, 128),) + + partition_result = self._get_partition_result(module, inputs) + fully_partitioned = self._check_fully_partitioned(partition_result) + + self.assertTrue( + fully_partitioned, + "Graph should be fully partitioned with all operators having the same tag", + ) + + def test_unused_constant_tagging(self): + """ + Test that constant nodes without users are properly tagged with delegation_tag. + + When a graph contains constants (parameters, buffers, or lifted tensor constants) + that are not used by any operations, the CUDA partitioner should still tag them + with the delegation_tag. This ensures all constant data is properly handled during + delegation, even if they have no users in the graph. + """ + + class ModuleWithUnusedConst(torch.nn.Module): + def __init__(self): + super().__init__() + # Register a buffer that won't be used in forward + self.register_buffer("unused_buffer", torch.randn(10, 10)) + # Also register a used parameter + self.weight = torch.nn.Parameter(torch.randn(5, 5)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Only use the weight parameter, not the unused_buffer + return x + self.weight + + module = ModuleWithUnusedConst() + inputs = (torch.randn(5, 5),) + + # Get partition result + partition_result = self._get_partition_result(module, inputs) + + # Find all placeholder nodes (these represent constants, parameters, buffers, and inputs) + constant_placeholders = [] + input_placeholders = [] + + for node in partition_result.tagged_exported_program.graph.nodes: + if node.op == "placeholder": + # Check if this is a constant (param, buffer, or lifted tensor constant) + from torch._export.utils import ( + is_buffer, + is_lifted_tensor_constant, + is_param, + ) + + is_constant = ( + is_param(partition_result.tagged_exported_program, node) + or is_buffer(partition_result.tagged_exported_program, node) + or is_lifted_tensor_constant( + partition_result.tagged_exported_program, node + ) + ) + + if is_constant: + constant_placeholders.append(node) + else: + input_placeholders.append(node) + + # Verify we have constant placeholders + self.assertGreater( + len(constant_placeholders), + 0, + "Expected to find constant placeholders in the graph", + ) + + # Check that all constant placeholders are tagged, including unused ones + untagged_constants = [] + for node in constant_placeholders: + if "delegation_tag" not in node.meta: + untagged_constants.append(node.name) + + self.assertEqual( + len(untagged_constants), + 0, + f"All constant placeholders should be tagged. Found untagged constants: {untagged_constants}", + ) + + # Verify all tagged constants have the expected tag + expected_tag = "tag0" + for node in constant_placeholders: + actual_tag = node.meta.get("delegation_tag") + self.assertEqual( + actual_tag, + expected_tag, + f"Constant placeholder {node.name} has tag '{actual_tag}' but expected '{expected_tag}'", + ) diff --git a/backends/cuda/triton/__init__.py b/backends/cuda/triton/__init__.py new file mode 100644 index 00000000000..4b9c36249ac --- /dev/null +++ b/backends/cuda/triton/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Import all kernels to ensure @triton_op decorators are executed +# and ops are registered to torch.ops.triton namespace +from executorch.backends.cuda.triton import kernels # noqa: F401 + +from executorch.backends.cuda.triton.replacement_pass import ( + ReplaceEdgeOpWithTritonOpPass, +) + +__all__ = [ + "ReplaceEdgeOpWithTritonOpPass", +] diff --git a/backends/cuda/triton/kernels/__init__.py b/backends/cuda/triton/kernels/__init__.py new file mode 100644 index 00000000000..5bd582679c4 --- /dev/null +++ b/backends/cuda/triton/kernels/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.backends.cuda.triton.kernels.sdpa import sdpa + +__all__ = [ + "sdpa", +] diff --git a/backends/cuda/triton/kernels/sdpa.py b/backends/cuda/triton/kernels/sdpa.py new file mode 100644 index 00000000000..e05dcdbbd28 --- /dev/null +++ b/backends/cuda/triton/kernels/sdpa.py @@ -0,0 +1,827 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Triton SDPA Kernel for ExecuTorch CUDA Backend. + +This module provides a Triton-optimized implementation of scaled dot-product attention +that can replace the default ATen/Edge SDPA operator during graph transformation to allow +us export the model without decomposing the SDPA operator under libtorch free environment +and have better performance. +""" + +import math +from typing import Optional + +import torch +import triton +import triton.language as tl +from torch.library import triton_op, wrap_triton + + +def _is_power_of_2(n: int) -> bool: + """Check if n is a power of 2.""" + return n > 0 and (n & (n - 1)) == 0 + + +def _next_power_of_2(x: int) -> int: + """Get the next power of 2 >= x, clamped to [16, 256].""" + if x <= 16: + return 16 + if x <= 32: + return 32 + if x <= 64: + return 64 + if x <= 128: + return 128 + return 256 + + +def _validate_qkv_shapes( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, +) -> tuple[int, int, int, int, int, int]: + """ + Validate dimensions and return shape info. + Args: + query: Query tensor [B, H, L_q, D] + key: Key tensor [B, H, L_kv, D] + value: Value tensor [B, H, L_kv, D] + Returns: + Tuple of (B, H, L_q, L_kv, D_q, D_kv) + Raises: + RuntimeError: If dimensions are incompatible + """ + B_q, H_q, L_q, D_q = query.shape + B_k, H_k, L_kv_k, D_k = key.shape + B_v, H_v, L_kv_v, D_v = value.shape + # Validate batch and head dimensions + if not (B_q == B_k == B_v): + raise RuntimeError( + f"Batch dimension must match; got B_q={B_q}, B_k={B_k}, B_v={B_v}." + ) + + if not (H_q == H_k == H_v): + raise RuntimeError( + f"Head dimension must match; got H_q={H_q}, H_k={H_k}, H_v={H_v}." + ) + # Head dimension must match + if not (D_q == D_k == D_v): + raise RuntimeError( + f"Head dimension must match across Q, K, V; got D_q={D_q}, D_k={D_k}, D_v={D_v}." + ) + # Key and Value sequence lengths must match + if L_kv_k != L_kv_v: + raise RuntimeError( + f"Key and Value must have the same sequence length; got L_k={L_kv_k}, L_v={L_kv_v}." 
+ ) + return B_q, H_q, L_q, L_kv_k, D_q, D_k + + +# ============================================================================== +# Non-power-of-2 HEAD_DIM kernel +# ============================================================================== +@triton.jit +def _sdpa_fwd_kernel_non_pow2( + q_ptr, + k_ptr, + v_ptr, + o_ptr, + mask_ptr, + B, + H, + LQ, + LK, + HEAD_DIM, + stride_qb, + stride_qh, + stride_ql, + stride_qd, + stride_kb, + stride_kh, + stride_kl, + stride_kd, + stride_vb, + stride_vh, + stride_vl, + stride_vd, + stride_ob, + stride_oh, + stride_ol, + stride_od, + stride_mb, + stride_mh, + stride_mlq, + stride_mlk, + scale, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + HAS_MASK: tl.constexpr, + IS_CAUSAL: tl.constexpr, +): + """ + SDPA forward kernel for non-power-of-2 HEAD_DIM. + Uses dynamic masking to handle arbitrary head dimensions. + """ + pid_m = tl.program_id(axis=0) + pid_bh = tl.program_id(axis=1) + + b = pid_bh // H + h = pid_bh % H + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + + d_mask = offs_d < HEAD_DIM + q_row_mask = offs_m < LQ + + q_base = q_ptr + b * stride_qb + h * stride_qh + k_base = k_ptr + b * stride_kb + h * stride_kh + v_base = v_ptr + b * stride_vb + h * stride_vh + o_base = o_ptr + b * stride_ob + h * stride_oh + + q_ptrs = q_base + (offs_m[:, None] * stride_ql + offs_d[None, :] * stride_qd) + q = tl.load(q_ptrs, mask=q_row_mask[:, None] & d_mask[None, :], other=0.0) + + acc = tl.zeros((BLOCK_M, BLOCK_D), dtype=tl.float32) + m_i = tl.full((BLOCK_M,), -float("inf"), dtype=tl.float32) + l_i = tl.full((BLOCK_M,), 1.0, dtype=tl.float32) + + qk_scale_log2 = scale * 1.4426950408889634 + + if HAS_MASK: + mask_b_base = mask_ptr + b * stride_mb + + for start_n in tl.range(0, LK, BLOCK_N, num_stages=2): + kn = start_n + offs_n + kv_col_mask = kn < LK + + k_ptrs = k_base + (kn[:, None] * stride_kl + offs_d[None, :] * stride_kd) + k = tl.load(k_ptrs, mask=kv_col_mask[:, None] & d_mask[None, :], other=0.0) + + qk = tl.dot(q, tl.trans(k)) + qk = qk * qk_scale_log2 + + if IS_CAUSAL: + row_abs = offs_m[:, None] + col_abs = kn[None, :] + causal_mask = col_abs > row_abs + qk = tl.where(causal_mask, -float("inf"), qk) + + if HAS_MASK: + mask_ptrs = ( + mask_b_base + offs_m[:, None] * stride_mlq + kn[None, :] * stride_mlk + ) + tile_valid = q_row_mask[:, None] & kv_col_mask[None, :] + keep = tl.load(mask_ptrs, mask=tile_valid, other=True) + qk = tl.where(keep, qk, -float("inf")) + + qk = tl.where(kv_col_mask[None, :], qk, -float("inf")) + + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + p = tl.math.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + alpha = tl.math.exp2(m_i - m_ij) + + acc = acc * alpha[:, None] + + v_ptrs = v_base + (kn[:, None] * stride_vl + offs_d[None, :] * stride_vd) + v = tl.load(v_ptrs, mask=kv_col_mask[:, None] & d_mask[None, :], other=0.0) + + acc = tl.dot(p.to(v.dtype), v, acc) + + l_i = l_i * alpha + l_ij + m_i = m_ij + + out = acc / l_i[:, None] + o_ptrs = o_base + (offs_m[:, None] * stride_ol + offs_d[None, :] * stride_od) + tl.store(o_ptrs, out.to(tl.bfloat16), mask=q_row_mask[:, None] & d_mask[None, :]) + + +# ============================================================================== +# Power-of-2 HEAD_DIM kernels +# ============================================================================== +@triton.jit +def _sdpa_fwd_kernel_body( + Q_ptr, + K_ptr, + V_ptr, + O_ptr, + Mask_ptr, + B, + H, + Lq, + Lk, + stride_qb, + stride_qh, + 
stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale: tl.float32, + HAS_MASK: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + HEAD_DIM: tl.constexpr, +): + """ + Shared kernel body for SDPA forward pass. + """ + pid_m = tl.program_id(axis=0) + pid_bh = tl.program_id(axis=1) + b = pid_bh // H + h = pid_bh % H + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n_init = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, HEAD_DIM) + + q_ptrs = Q_ptr + ( + b * stride_qb + + h * stride_qh + + (offs_m[:, None] * stride_qm) + + (offs_d[None, :] * stride_qd) + ) + q_mask = (offs_m[:, None] < Lq) & (offs_d[None, :] < HEAD_DIM) + q = tl.load(q_ptrs, mask=q_mask, other=0.0).to(tl.bfloat16) + + m_i = tl.full([BLOCK_M], -float("inf"), dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + + for start_n in tl.range(0, Lk, BLOCK_N): + offs_n = start_n + offs_n_init + + k_ptrs = K_ptr + ( + b * stride_kb + + h * stride_kh + + (offs_n[:, None] * stride_kn) + + (offs_d[None, :] * stride_kd) + ) + k_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM) + k = tl.load(k_ptrs, mask=k_mask, other=0.0).to(tl.bfloat16) + + qk = tl.dot(q, tl.trans(k)).to(tl.float32) * sm_scale + + if HAS_MASK: + mask_ptrs = Mask_ptr + ( + b * stride_mb + + (offs_m[:, None] * stride_mq) + + (offs_n[None, :] * stride_mk) + ) + mn_mask = (offs_m[:, None] < Lq) & (offs_n[None, :] < Lk) + mask_block = tl.load(mask_ptrs, mask=mn_mask, other=False) + qk = tl.where(mask_block, qk, -float("inf")) + + if IS_CAUSAL: + abs_m = offs_m[:, None] + abs_n = offs_n[None, :] + causal = abs_n > abs_m + qk = tl.where(causal, -float("inf"), qk) + + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p_f32 = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p_f32, axis=1) + alpha = tl.exp(m_i - m_ij) + + v_ptrs = V_ptr + ( + b * stride_vb + + h * stride_vh + + (offs_n[:, None] * stride_vn) + + (offs_d[None, :] * stride_vd) + ) + v_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM) + v = tl.load(v_ptrs, mask=v_mask, other=0.0).to(tl.bfloat16) + + p_bf16 = p_f32.to(tl.bfloat16) + acc = acc * alpha[:, None] + tl.dot(p_bf16, v) + l_i = l_i * alpha + l_ij + m_i = m_ij + + inv_l_i = tl.where(l_i > 0, 1.0 / l_i, 0.0) + acc = acc * inv_l_i[:, None] + + o_ptrs = O_ptr + ( + b * stride_ob + + h * stride_oh + + (offs_m[:, None] * stride_om) + + (offs_d[None, :] * stride_od) + ) + o_mask = (offs_m[:, None] < Lq) & (offs_d[None, :] < HEAD_DIM) + tl.store(o_ptrs, acc.to(tl.bfloat16), mask=o_mask) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=4, num_stages=2), + ], + key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL"], +) +@triton.jit +def _sdpa_fwd_kernel_m64( + Q_ptr, + K_ptr, + V_ptr, + O_ptr, + Mask_ptr, + B, + H, + Lq, + Lk, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + 
stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale: tl.float32, + HAS_MASK: tl.constexpr, + IS_CAUSAL: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """ + SDPA kernel with BLOCK_M=64 optimizations. + """ + _sdpa_fwd_kernel_body( + Q_ptr, + K_ptr, + V_ptr, + O_ptr, + Mask_ptr, + B, + H, + Lq, + Lk, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale, + HAS_MASK=HAS_MASK, + IS_CAUSAL=IS_CAUSAL, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + HEAD_DIM=HEAD_DIM, + ) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=2), + ], + key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL"], +) +@triton.jit +def _sdpa_fwd_kernel_m32( + Q_ptr, + K_ptr, + V_ptr, + O_ptr, + Mask_ptr, + B, + H, + Lq, + Lk, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale: tl.float32, + HAS_MASK: tl.constexpr, + IS_CAUSAL: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """ + SDPA kernel with BLOCK_M=32 optimizations for small workloads. + """ + _sdpa_fwd_kernel_body( + Q_ptr, + K_ptr, + V_ptr, + O_ptr, + Mask_ptr, + B, + H, + Lq, + Lk, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale, + HAS_MASK=HAS_MASK, + IS_CAUSAL=IS_CAUSAL, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + HEAD_DIM=HEAD_DIM, + ) + + +def _validate_sdpa_inputs( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float, + enable_gqa: bool, +) -> None: + """Validate SDPA input tensors and unsupported feature flags.""" + if not (query.is_cuda and key.is_cuda and value.is_cuda): + raise RuntimeError("Q, K, V must be CUDA tensors.") + if ( + query.dtype != torch.bfloat16 + or key.dtype != torch.bfloat16 + or value.dtype != torch.bfloat16 + ): + raise RuntimeError("Expected bfloat16 inputs") + if query.dim() != 4 or key.dim() != 4 or value.dim() != 4: + raise RuntimeError( + f"Expected 4D tensors shaped [B, H, L, D]; got " + f"query.dim()={query.dim()}, key.dim()={key.dim()}, " + f"value.dim()={value.dim()}." + ) + if dropout_p != 0.0: + raise RuntimeError( + "dropout_p must be 0.0 (not supported in this implementation)." + ) + if enable_gqa is not False: + raise RuntimeError( + "enable_gqa must be False (not supported in this implementation)." 
+ ) + + +def _prepare_mask_params( + attn_mask: Optional[torch.Tensor], + B: int, + L_q: int, + L_kv: int, +) -> tuple[bool, torch.Tensor, int, int, int]: + """Prepare attention mask parameters for kernel invocation.""" + if attn_mask is None: + return False, 0, 0, 0, 0 + + if attn_mask.dtype != torch.bool: + raise RuntimeError("attn_mask must have dtype torch.bool") + if not attn_mask.is_cuda: + raise RuntimeError("attn_mask must be a CUDA tensor") + if ( + attn_mask.shape[0] != B + or attn_mask.shape[2] != L_q + or attn_mask.shape[3] != L_kv + ): + raise RuntimeError( + f"attn_mask shape mismatch: expected [B={B}, H, L_q={L_q}, L_kv={L_kv}], " + f"got {attn_mask.shape}" + ) + return ( + True, + attn_mask, + attn_mask.stride(0), + attn_mask.stride(2), + attn_mask.stride(3), + ) + + +def _launch_pow2_kernel( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + B: int, + H: int, + L_q: int, + L_kv: int, + D: int, + sm_scale: float, + HAS_MASK: bool, + Mask_ptr: torch.Tensor, + stride_mb: int, + stride_mq: int, + stride_mk: int, + is_causal: bool, +) -> None: + """Launch power-of-2 optimized SDPA kernel.""" + stride_qb, stride_qh, stride_qm, stride_qd = query.stride() + stride_kb, stride_kh, stride_kn, stride_kd = key.stride() + stride_vb, stride_vh, stride_vn, stride_vd = value.stride() + stride_ob, stride_oh, stride_om, stride_od = out.stride() + + def grid(meta): + return (triton.cdiv(L_q, meta["BLOCK_M"]), B * H) + + total_ctas_m64 = ((L_q + 63) // 64) * (B * H) + threshold = 4 * 84 + kernel = ( + _sdpa_fwd_kernel_m32 if total_ctas_m64 < threshold else _sdpa_fwd_kernel_m64 + ) + + wrap_triton(kernel)[grid]( + query, + key, + value, + out, + Mask_ptr if HAS_MASK else 0, + B, + H, + L_q, + L_kv, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale, + HAS_MASK=HAS_MASK, + IS_CAUSAL=is_causal, + HEAD_DIM=D, + ) + + +def _launch_non_pow2_kernel( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + attn_mask: Optional[torch.Tensor], + B: int, + H: int, + L_q: int, + L_kv: int, + D: int, + sm_scale: float, + HAS_MASK: bool, + is_causal: bool, +) -> None: + """Launch non-power-of-2 SDPA kernel with dynamic HEAD_DIM masking.""" + stride_qb, stride_qh, stride_qm, stride_qd = query.stride() + stride_kb, stride_kh, stride_kn, stride_kd = key.stride() + stride_vb, stride_vh, stride_vn, stride_vd = value.stride() + stride_ob, stride_oh, stride_om, stride_od = out.stride() + + BLOCK_D = _next_power_of_2(D) + BLOCK_N = 64 if BLOCK_D >= 256 else 128 + BLOCK_M = 32 + num_warps = 4 + num_stages = 2 + + if HAS_MASK: + mask_ptr = attn_mask + stride_mb_np2 = attn_mask.stride(0) + stride_mh_np2 = attn_mask.stride(1) + stride_mlq_np2 = attn_mask.stride(2) + stride_mlk_np2 = attn_mask.stride(3) + else: + mask_ptr = torch.empty((1,), device=query.device, dtype=torch.bool) + stride_mb_np2 = stride_mh_np2 = stride_mlq_np2 = stride_mlk_np2 = 0 + + def grid_non_pow2(meta): + return (triton.cdiv(L_q, meta["BLOCK_M"]), B * H) + + wrap_triton(_sdpa_fwd_kernel_non_pow2)[grid_non_pow2]( + query, + key, + value, + out, + mask_ptr, + B, + H, + L_q, + L_kv, + D, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + 
stride_om, + stride_od, + stride_mb_np2, + stride_mh_np2, + stride_mlq_np2, + stride_mlk_np2, + sm_scale, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_D=BLOCK_D, + HAS_MASK=HAS_MASK, + IS_CAUSAL=is_causal, + num_warps=num_warps, + num_stages=num_stages, + ) + + +@triton_op("triton::sdpa", mutates_args={}) +def sdpa( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float = 0.0, + enable_gqa: bool = False, +) -> torch.Tensor: + """ + Triton fused Scaled Dot-Product Attention with optimized dual-kernel approach. + + Args: + query: Query tensor with size [B, H, L_q, D] and dtype torch.bfloat16 + key: Key tensor [B, H, L_kv, D] and dtype torch.bfloat16 + value: Value tensor [B, H, L_kv, D] and dtype torch.bfloat16 + attn_mask: Optional attention mask [B, H, L_q, L_kv] with dtype torch.bool + dropout_p: must be 0.0 (others are not supported) + is_causal: whether to apply causal masking + scale: attention scale (default: 1/sqrt(D)) + enable_gqa: must be False (True is not supported) + Returns: + Output tensor [B, H, L_q, D] with dtype torch.bfloat16 + """ + _validate_sdpa_inputs(query, key, value, dropout_p, enable_gqa) + + B, H, L_q, L_kv, D_q, _ = _validate_qkv_shapes(query, key, value) + D = D_q + + if is_causal and L_q != L_kv: + raise RuntimeError( + f"Causal masking requires L_q == L_kv; got L_q={L_q}, L_kv={L_kv}." + ) + + out = torch.empty((B, H, L_q, D), device=query.device, dtype=query.dtype) + sm_scale = 1.0 / math.sqrt(D) if scale == 0.0 else scale + HAS_MASK, Mask_ptr, stride_mb, stride_mq, stride_mk = _prepare_mask_params( + attn_mask, B, L_q, L_kv + ) + + if _is_power_of_2(D): + _launch_pow2_kernel( + query, + key, + value, + out, + B, + H, + L_q, + L_kv, + D, + sm_scale, + HAS_MASK, + Mask_ptr, + stride_mb, + stride_mq, + stride_mk, + is_causal, + ) + else: + _launch_non_pow2_kernel( + query, + key, + value, + out, + attn_mask, + B, + H, + L_q, + L_kv, + D, + sm_scale, + HAS_MASK, + is_causal, + ) + + return out + + +# Register the abstract/fake implementation for torch.export +# This is critical to avoid accessing real tensor data during export +@sdpa.register_fake +def _sdpa_abstract( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float = 0.0, + enable_gq: bool = False, +) -> torch.Tensor: + """ + Abstract/fake implementation for torch.export. + This just returns an empty tensor with the correct shape/dtype/device. + """ + # Validate dtypes match + assert query.dtype == key.dtype == value.dtype, "Q, K, V must have the same dtype" + # Validate kqv's shape and get the output shape + B, H, L_q, _, D_q, _ = _validate_qkv_shapes(query, key, value) + + return torch.empty(B, H, L_q, D_q, dtype=query.dtype, device=query.device) diff --git a/backends/cuda/triton/replacement_pass.py b/backends/cuda/triton/replacement_pass.py new file mode 100644 index 00000000000..bfa3838296b --- /dev/null +++ b/backends/cuda/triton/replacement_pass.py @@ -0,0 +1,130 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Graph Transformation Pass for Triton Kernel Replacement. + +This pass replaces ATen operators with optimized Triton kernels in the graph. 
+""" + +import logging + +import torch +from executorch.exir.dialects._ops import ops as exir_ops + +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_base import PassBase, PassResult + +logger = logging.getLogger(__name__) +triton = torch.ops.triton + +# Global mapping from edge dialect operators to Triton kernel functions +EDGE_TO_TRITON_KERNELS = { + exir_ops.edge.aten.scaled_dot_product_attention.default: triton.sdpa, +} + + +class ReplaceEdgeOpWithTritonOpPass(PassBase): + """ + Pass to replace ATen operators with Triton kernels. + + This pass scans the graph for Edge operators that have registered Triton + replacements using EDGE_TO_TRITON_KERNELS and replaces them with the + optimized Triton implementations. + """ + + def __init__(self): + """Initialize the pass.""" + super().__init__() + self._replacement_count = 0 + + def call(self, graph_module: GraphModule) -> PassResult: + """ + Execute the pass on the graph module. + + Args: + graph_module: The graph module to transform + + Returns: + PassResult indicating success/failure and the modified graph module + """ + self._replacement_count = 0 + modified = False + + if not EDGE_TO_TRITON_KERNELS: + return PassResult(graph_module, False) + + # Iterate through all nodes in the graph + for node in graph_module.graph.nodes: + if self._should_replace_node(node): + try: + self._replace_node_with_triton(graph_module, node) + modified = True + self._replacement_count += 1 + except Exception as e: + logger.warning(f"Failed to replace node {node.name}: {e}") + # Continue with other replacements even if one fails + + if modified: + # Recompile the graph module after modifications + graph_module.recompile() + + # logger.info(f"Replaced {self._replacement_count} nodes with Triton kernels") + print(f"Replaced {self._replacement_count} nodes with Triton kernels") + + return PassResult(graph_module, modified) + + def _should_replace_node(self, node: Node) -> bool: + """ + Check if a node should be replaced with a Triton kernel. + + Args: + node: The node to check + + Returns: + True if the node should be replaced + """ + # Only consider call_function nodes + if node.op != "call_function": + return False + + return node.target in EDGE_TO_TRITON_KERNELS + + def _replace_node_with_triton(self, graph_module: GraphModule, node: Node) -> None: + """ + Replace an edge dialect node with a Triton kernel call. 
+ + Args: + graph_module: The graph module containing the node + node: The node to replace + """ + # Get the target operator (should be an exir_ops edge dialect op) + target = node.target + + # Get the replacement kernel + if target not in EDGE_TO_TRITON_KERNELS: + raise ValueError(f"No replacement kernel found for {target}") + + triton_kernel_fn = EDGE_TO_TRITON_KERNELS[target] + + # Create a new node with the Triton kernel + with graph_module.graph.inserting_before(node): + # The triton_kernel_fn is already registered as a custom op via @triton_op + # We can call it directly + new_node = graph_module.graph.call_function( + triton_kernel_fn, + args=node.args, + kwargs=node.kwargs, + ) + + # Copy metadata from original node + new_node.meta = node.meta.copy() + + # Replace all uses of the old node with the new node + node.replace_all_uses_with(new_node) + + # Remove the old node + graph_module.graph.erase_node(node) diff --git a/backends/mediatek/CMakeLists.txt b/backends/mediatek/CMakeLists.txt index ed9b37e1998..10c28be0053 100644 --- a/backends/mediatek/CMakeLists.txt +++ b/backends/mediatek/CMakeLists.txt @@ -46,5 +46,5 @@ executorch_target_link_options_shared_lib(neuron_backend) install( TARGETS neuron_backend EXPORT ExecuTorchTargets - DESTINATION lib + DESTINATION ${CMAKE_INSTALL_LIBDIR} ) diff --git a/backends/mediatek/README.md b/backends/mediatek/README.md index e8a535b3fde..6ff751f8408 100644 --- a/backends/mediatek/README.md +++ b/backends/mediatek/README.md @@ -28,7 +28,7 @@ To get started with MediaTek's ExecuTorch libraries, download the [NeuroPilot Ex - **`mtk_converter-8.13.0+public-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl`**: This library preprocesses the model into a MediaTek representation. -- **`mtk_neuron-8.2.19-py3-none-linux_x86_64.whl`**: This library converts the model to binaries. +- **`mtk_neuron-8.2.23-py3-none-linux_x86_64`**: This library converts the model to binaries. Additionally, make sure to copy `NeuronAdapter.h` to the following directory: `backends/mediatek/runtime/include/api/`. 
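As a usage note for the `triton::sdpa` custom op and `ReplaceEdgeOpWithTritonOpPass` introduced above: once the module defining the op has been imported, `torch.ops.triton.sdpa` is registered and can be sanity-checked against PyTorch's reference SDPA. The snippet below is an illustrative sketch, not code from this change; it assumes a CUDA device, bfloat16 inputs, and a loose comparison tolerance appropriate for bf16 accumulation.

```python
# Illustrative sketch (not part of this change): exercising torch.ops.triton.sdpa.
# Assumes the module defining the @triton_op above has been imported and CUDA is available.
import torch
import torch.nn.functional as F

B, H, L, D = 2, 8, 256, 64  # D is a power of two, so the autotuned pow2 kernels are used
q, k, v = (
    torch.randn(B, H, L, D, device="cuda", dtype=torch.bfloat16) for _ in range(3)
)

out = torch.ops.triton.sdpa(q, k, v, attn_mask=None, is_causal=True)

# Compare against the eager reference; tolerances are loose because of bf16 accumulation.
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
torch.testing.assert_close(out, ref, atol=2e-2, rtol=2e-2)
```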
@@ -45,7 +45,7 @@ Follow the steps below to setup your build environment: ``` - Install the two .whl downloaded from NeuroPilot Portal ```bash - pip3 install mtk_neuron-8.2.19-py3-none-linux_x86_64.whl + pip3 install mtk_neuron-8.2.23-py3-none-linux_x86_64.whl pip3 install mtk_converter-8.13.0+public-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl ``` diff --git a/backends/mediatek/preprocess.py b/backends/mediatek/preprocess.py index b2a79dafabe..0e2b68335e0 100644 --- a/backends/mediatek/preprocess.py +++ b/backends/mediatek/preprocess.py @@ -26,7 +26,7 @@ HEADER_SIZE = 13 HEADER_VERSION = 1 REQUIRED_COMPILE_SPEC_KEYS = {"platform-config"} -SUPPORTED_PLATFORM_CONFIGS = {"mt6989", "mt6991"} +SUPPORTED_PLATFORM_CONFIGS = {"mt6989", "mt6991", "mt6993"} def assert_default_dim_order(edge_graph_module: torch.fx.GraphModule) -> None: diff --git a/backends/mediatek/scripts/mtk_build.sh b/backends/mediatek/scripts/mtk_build.sh index 599f754d7bc..d42e5f7e10a 100755 --- a/backends/mediatek/scripts/mtk_build.sh +++ b/backends/mediatek/scripts/mtk_build.sh @@ -30,6 +30,7 @@ cmake -DCMAKE_INSTALL_PREFIX="${build_dir}" \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ + -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ -DEXECUTORCH_BUILD_NEURON=ON \ -B"${build_dir}" diff --git a/backends/nxp/CMakeLists.txt b/backends/nxp/CMakeLists.txt index 43fcaa24d19..bfc4c046be6 100644 --- a/backends/nxp/CMakeLists.txt +++ b/backends/nxp/CMakeLists.txt @@ -17,5 +17,5 @@ target_include_directories( install( TARGETS executorch_delegate_neutron EXPORT ExecuTorchTargets - DESTINATION lib + DESTINATION ${CMAKE_INSTALL_LIBDIR} ) diff --git a/backends/nxp/README.md b/backends/nxp/README.md index 10eb1290a8b..8b76d1e276b 100644 --- a/backends/nxp/README.md +++ b/backends/nxp/README.md @@ -15,7 +15,6 @@ networks, as well as the ability to adapt and scale to new model architectures, to AI workloads. ML application development with the eIQ Neutron NPU is fully supported by the [eIQ machine learning software development environment](https://www.nxp.com/design/design-center/software/eiq-ml-development-environment/eiq-toolkit-for-end-to-end-model-development-and-deployment:EIQ-TOOLKIT). The eIQ AI SW Stack provides a streamlined development experience for developers and end-users of NXP products. -eIQ extensions connect broader AI ecosystems to the edge, such as the NVIDIA TAO extension, which enables developers to bring AI models trained and fine-tuned with TAO to NXP-powered edge devices. ## Supported NXP platforms @@ -35,37 +34,28 @@ improvements. NXP and the ExecuTorch community is actively developing this codeb ## Neutron Backend implementation and SW architecture Neutron Backend uses the eIQ Neutron Converter as ML compiler to compile the delegated subgraph to Neutron microcode. -The Neutron Converter accepts the ML model in LiteRT format, for the **eIQ Neutron N3** class therefore the Neutron Backend uses the LiteRT flatbuffers format as IR between the ExecuTorch and Neutron Converter ML compiler. - -The Neutron Backend in its early prototype phase, is based on existing NXP products, such as -onnx2tflite, known from the NXP's eIQ Toolkit. -The **onnx2tflite** is a converter from the ONNX format to LiteRT (formerly known as TFLite). 
-It consists of 3 stages: -* ONNX Model Parsing -* Tensor Format Inference, to identify tensors using channel-first layer -* ONNX to LiteRT Conversion -* Optimization Passes, which operate on top of the LiteRT format -* LiteRT Serialization - -Due to the similarities between ONNX to LiteRT and Edge to LiteRT conversion, the Neutron Backend's -currently leverages the Tensor format Inference and LiteRT Optimizer. -This shall be considered as temporary solution, intended to be replaced with: -* Dim Order (https://github.com/pytorch/executorch/issues/4873) -* Corresponding ExecuTorch/ATen passes - -before reaching higher maturity status by the end of 2025. +The Neutron Converter accepts the ML model in LiteRT format, for the **eIQ Neutron N3** class therefore the Neutron Backend +uses the LiteRT flatbuffers format as IR between the ExecuTorch and Neutron Converter ML compiler. ## Layout -The current code base is as follows: * `backend/ir/` - TFLite/LiteRT based IR to represent the Edge Subgraph, taken from onnx2tflite code base and extended to support Edge Dialect to LiteRT conversion. * `backend/ir/converter` - Neutron Backends conversion from Edge (ATen) Dialect to LiteRT, TFLite. The subfolder `node_conveters` is structured as single module for each Edge operator. - * `backend/ir/lib` - automatically generated handlers from LiteRT flatbuffers schema + * `backend/ir/lib` - automatically generated handlers from LiteRT flatbuffers schema. * `backend/ir/tflite_generator` and `backend/ir/tflite_optimizer` handle the serialization of the in-memory built subgraph for delegation into LiteRT/TFLite flatbuffers representation. Code taken from the onnx2tflite tool. -* `quantizer` - Neutron Backends quantizer implementation. +* `edge_passes` - Various passes operating on Edge dialect level. +* `quantizer` - Neutron Backend quantizer implementation. +* `runtime` - Neutron Backend runtime implementation. For running compiled on device. +* `tests/` - Unit tests for Neutron backend. + * `tests/converter/node_converter` - Operator level unit tests. + +* `examples/nxp/` - Example models and scripts for running them. + +## Examples +Please see this [README.md](https://github.com/pytorch/executorch/blob/main/examples/nxp/README.md). 
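For orientation, the pieces listed in the layout above (quantizer, partitioner, Neutron Converter) are typically tied together as sketched below. This is an illustrative outline rather than code from this change; the import paths and the `NeutronQuantizer`/`NeutronPartitioner`/`generate_neutron_compile_spec` names are assumptions based on the layout above, and `examples/nxp/README.md` remains the authoritative walkthrough.

```python
# Rough sketch of lowering a model to the Neutron backend; names and import
# paths below are assumptions -- see examples/nxp/ for the supported flow.
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner  # assumed path
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec  # assumed path
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer  # assumed path
from executorch.exir import to_edge_transform_and_lower

model = MyModel().eval()  # hypothetical eager model
example_inputs = (torch.randn(1, 3, 32, 32),)

# 1. Quantize with the Neutron quantizer (Neutron executes int8 kernels).
exported = torch.export.export_for_training(model, example_inputs).module()
prepared = prepare_pt2e(exported, NeutronQuantizer())
prepared(*example_inputs)  # calibration
quantized = convert_pt2e(prepared)

# 2. Export and delegate supported subgraphs to Neutron, then serialize.
program = torch.export.export(quantized, example_inputs)
edge = to_edge_transform_and_lower(
    program,
    partitioner=[NeutronPartitioner(generate_neutron_compile_spec("imxrt700"))],
)
pte = edge.to_executorch()
```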
## Help & Improvements If you have problems or questions or have suggestions for ways to make diff --git a/backends/nxp/TARGETS b/backends/nxp/TARGETS index d56ac60242c..a5a0508b33c 100644 --- a/backends/nxp/TARGETS +++ b/backends/nxp/TARGETS @@ -32,6 +32,18 @@ runtime.python_library( ], ) +runtime.python_library( + name = "_passes", + srcs = glob([ + "_passes/*.py", + ]), + deps = [ + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/exir:pass_manager", + ], +) + runtime.python_library( name = "quantizer", srcs = [ @@ -50,7 +62,7 @@ runtime.python_library( name = "neutron_sdk", srcs = glob(["backend/**/*.py"]), deps = [ - "fbsource//third-party/pypi/neutron_converter:neutron_converter", + "fbsource//third-party/pypi/neutron_converter:neutron_converter", ], ) @@ -65,10 +77,10 @@ runtime.python_library( deps = [ ":neutron_sdk", ":aten_passes", + ":_passes", ":quantizer", "fbsource//third-party/pypi/flatbuffers:flatbuffers", "fbsource//third-party/pypi/ml-dtypes:ml-dtypes", - "fbsource//third-party/tosa_tools/v0.80/serialization_lib/python/serializer:serializer", "//executorch/exir:lib", "//executorch/backends/transforms:remove_getitem_op", "//caffe2:torch", diff --git a/backends/nxp/_passes/remove_getitem_pass.py b/backends/nxp/_passes/remove_getitem_pass.py new file mode 100644 index 00000000000..316cc13f49c --- /dev/null +++ b/backends/nxp/_passes/remove_getitem_pass.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2025 NXP +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from executorch.backends.nxp.backend.node_format import NodeFormat, NXP_NODE_FORMAT +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +class RemoveGetItemPass(ExportPass): + """ + This remove item is used to remove getitem operator for max_pool2d_with_indices.default operator, and replace it with a single operator, + that extracts the first output. More specifically, we are only getting the first output from aten::maxpool2d operator. + Before Pass: + MaxPool2d ---> GetItem[max_values, max_indexes] + After Pass: + MaxPool2d -> max_values + """ + + def call(self, graph_module: torch.fx.GraphModule): + module = graph_module + for node in module.graph.nodes: + if node.op == "call_function": + if ( + node.target.__name__ == "aten.max_pool2d_with_indices.default" + or node.target.__name__ == "aten.max.dim" + ): + users = list(node.users.keys()) + + if len(users) != 1: + if len(users) == 2 and node.target.__name__ == "aten.max.dim": + # Two users is allowed for max.dim. For that case, + # rather than removing the getitem node in this + # pass, we handle the getitem nodes in the op's + # visitor when serializing + continue + else: + raise AssertionError( + f"Invalid number of users for {node.target.__name__}: {len(users)}" + ) + + getitem_node = list(node.users.keys())[0] + + if getitem_node.target.__name__ != "getitem": + raise AssertionError( + f"Expected max node's user to be getitem, got {getitem_node.target.__name__}" + ) + + getitem_index = getitem_node.args[1] + + with module.graph.inserting_before(node): + if ( + node.target.__name__ + == "aten.max_pool2d_with_indices.default" + ): + if getitem_index != 0: + raise AssertionError( + f"Expected second argument of getitem node for {node.target.__name__} to be 0, got " + f"{getitem_index}. 
XNNPACK delegate currently only supports getting just the max " + "values from the op but not getting the corresponding indices." + ) + new_max_wd = module.graph.create_node( + "call_function", + exir_ops.edge.aten.max_pool2d.default, + args=node.args, + kwargs=node.kwargs, + ) + + else: + if getitem_index != 0: + raise AssertionError( + f"Expected second argument of getitem node for {node.target.__name__} to be 0, got " + f"{getitem_index}. XNNPACK delegate currently only supports getting just the max " + "values or getting both the max values and their corresponding indices from the " + "op, but not getting the indices alone." + ) + new_max_wd = module.graph.create_node( + "call_function", + exir_ops.edge.aten.amax.default, + args=node.args, + kwargs=node.kwargs, + ) + + # MODIFIED PART START + # Make sure to preserve the inferred node format. + new_max_wd.meta[NXP_NODE_FORMAT] = node.meta.get( + NXP_NODE_FORMAT, NodeFormat.NONE + ) + # MODIFIED PART END + + getitem_node.replace_all_uses_with(new_max_wd) + + module.graph.erase_node(getitem_node) + module.graph.erase_node(node) + + graph_module.recompile() + # Propagate metadata and retrace module + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, True) diff --git a/backends/nxp/aten_passes/fuse_linear_and_add_pass.py b/backends/nxp/aten_passes/fuse_linear_and_add_pass.py new file mode 100644 index 00000000000..20a32c1bcac --- /dev/null +++ b/backends/nxp/aten_passes/fuse_linear_and_add_pass.py @@ -0,0 +1,204 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch + +from executorch.backends.nxp.backend.edge_helper import ( + try_get_tensor_constant_from_node, +) +from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix +from torch.export.unflatten import _assign_attr, _AttrKind +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_base import PassBase, PassResult + + +class FuseLinearAndAddPass(PassBase): + """Replace a sequence of `linear` and `add` nodes in the following pattern by a single `linear` node when possible. + │ + ┌──────▼──────┐ + │ aten.linear │ + └──────┬──────┘ │ + │ replace with ┌──────▼──────┐ + ┌─────▼────┐ ───────────► │ aten.linear │ + │ aten.add │ └──────┬──────┘ + └─────┬────┘ + ▼ + """ + + def _fuse_with_existing_bias( + self, + linear_node: Node, + other_add_input: Node, + graph_module: GraphModule, + alpha: float, + ) -> bool: + """Fuse the `linear` and `add` nodes provided the `linear` already has a bias. + The fusion can only be done if both the "biases" have static data, which can be added together to get a + single bias. + + :return: True, if the nodes were successfully merged. False, otherwise. + """ + + linear_bias = linear_node.args[2] + if other_add_input.meta["val"].shape != linear_bias.meta["val"].shape: + # The biases cannot be added together due to their different shapes. + # Shape broadcasting is not applicable, as the only allowed `linear` bias shape is 1D ([output_features]). + return False + + bias_data = [ + try_get_tensor_constant_from_node(graph_module, linear_bias), + try_get_tensor_constant_from_node(graph_module, other_add_input), + ] + if any(data is None for data in bias_data): + return ( + False # Fusion is not possible because at least 1 bias is not static. + ) + + # Add the bias data together, to obtain the combined bias. 
Take the `alpha` attribute into account. + combined_bias = bias_data[0] + bias_data[1] * alpha + + # Create a new node containing the combined bias data. + combined_bias_name = get_new_attr_name_with_prefix( + linear_bias.name + "combined" + )(graph_module) + _assign_attr( + torch.nn.Parameter(combined_bias), + graph_module, + combined_bias_name, + _AttrKind.PARAMETER, + ) + with graph_module.graph.inserting_before(linear_node): + new_bias_node = graph_module.graph.get_attr(combined_bias_name) + + # Use the combined bias as the new bias for the `Linear`. + linear_node.args = ( + linear_node.args[:2] + (new_bias_node,) + linear_node.args[3:] + ) + return True + + def _fuse_without_existing_bias( + self, + linear_node: Node, + other_add_input: Node, + graph_module: GraphModule, + alpha: float, + ) -> bool: + """Fuse the `linear` and `add` provided the `linear` does not already have a bias. + + :return: True, if the nodes were successfully merged. False, otherwise. + """ + + # The weights have shape (out_features, in_features). + output_features = linear_node.args[1].meta["val"].shape[0] + new_bias_shape = other_add_input.meta["val"].shape + if list(new_bias_shape) != [output_features]: + return False # The `Add` is adding a tensor with shape that is not supported for the `Linear` bias. + + bias_data = try_get_tensor_constant_from_node(graph_module, other_add_input) + + if bias_data is None: + return False # Neutron doesn't support a dynamic bias, so fusion would be counterproductive. + + # It is possible that the `linear` comes before the `other_add_input` in the graph, so it cannot use it as an + # input directly. If the nodes are ordered as [linear, ..., other_add_input, ... add] (which is valid), using + # `other_add_input` directly as an input to `Linear` would not follow topological order. + # Rearranging the nodes is not trivial, as the graph could be complex (ultimately, the + # `other_add_input` could even originate from the `Linear` node...). + # Since the `other_add_input` has static data, we can create a new node with the data just before the `Linear` + # to ensure topological order. + # Regardless of the node ordering, the `add.Tensor` attribute `alpha` multiplies the second `add` input. If + # `alpha != 1`, we would have to insert a `mul` operator if we wanted to keep the original parameter node. + # Therefore, it is better to create a new static parameter node for the multiplied data in this case as well. + nodes = list(graph_module.graph.nodes) + if nodes.index(linear_node) < nodes.index(other_add_input) or alpha != 1.0: + # Problematic order, or required multiplication. + + # Handle the `aten.add.Tensor` attribute `alpha`. + bias_data *= alpha + + # Create a unique name. + new_bias_name = get_new_attr_name_with_prefix(linear_node.name + "_bias")( + graph_module + ) + _assign_attr(bias_data, graph_module, new_bias_name, _AttrKind.PARAMETER) + with graph_module.graph.inserting_before(linear_node): + new_bias_node = graph_module.graph.get_attr(new_bias_name) + + # Use the added tensor as the new `Linear` bias. + linear_node.args = ( + linear_node.args[:2] + (new_bias_node,) + linear_node.args[2:] + ) + return True + + else: + # Use the `other_add_input` directly as the new bias. 
+ linear_node.args = ( + linear_node.args[:2] + (other_add_input,) + linear_node.args[2:] + ) + return True + + def call(self, graph_module: GraphModule) -> Optional[PassResult]: + def _is_applicable_linear_node(node_: Node): + is_linear = ( + node_.op == "call_function" + and node_.target == torch.ops.aten.linear.default + ) + has_single_user = len(node.users) == 1 + + return is_linear and has_single_user + + def _is_add(node_: Node): + return ( + node_.op == "call_function" + and node_.target == torch.ops.aten.add.Tensor + ) + + made_changes = False + for node in graph_module.graph.nodes: + if not _is_applicable_linear_node( + linear_node := node + ): # Also ensures a single user. + continue + + if not _is_add(add_node := list(linear_node.users.keys())[0]): + continue # Not the `Linear` -> `Add` case. + + if len(add_node.args) != 2: + continue # Unexpected case. + + # The `aten.add.Tensor` carries out the expression `out = input[0] + alpha × input[1]`. + # https://docs.pytorch.org/docs/stable/generated/torch.add.html + alpha = add_node.kwargs.get("alpha", 1.0) + if add_node.args[0] == linear_node: + other_add_input = add_node.args[1] + + else: + # The fusion is not implemented. The `other_add_input` would have to be divided by `alpha` before the + # fusion, and a `mul` operator would have to be added after the `linear` to multiply its output by + # `alpha`. + continue + + if len(linear_node.args) > 2: + if not self._fuse_with_existing_bias( + linear_node, other_add_input, graph_module, alpha + ): + continue # The nodes could not be fused. + + else: + # The `Linear` doesn't have a bias yet. + if not self._fuse_without_existing_bias( + linear_node, other_add_input, graph_module, alpha + ): + continue # The nodes could not be fused. + + # Use the output of the `Linear` instead of the `Add`, and remove the now unused `Add` node. + add_node.replace_all_uses_with(linear_node) + graph_module.graph.erase_node(add_node) + + made_changes = True + + return PassResult(graph_module, made_changes) diff --git a/backends/nxp/aten_passes/move_activation_before_concat.py b/backends/nxp/aten_passes/move_activation_before_concat.py new file mode 100644 index 00000000000..8ba306d42e2 --- /dev/null +++ b/backends/nxp/aten_passes/move_activation_before_concat.py @@ -0,0 +1,102 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch + +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec + +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_base import PassBase, PassResult + + +class MoveActivationBeforeConcat(PassBase): + """Move some operators around in the following pattern. + This is a common pattern that emerges from the conversion of separable convolutions. + This optimization works together with joint quantization of compute nodes and activations. Without it, + it is not beneficial. + + │ │ │ │ + ┌──────▼──────┐ ┌──────▼──────┐ ┌──────▼──────┐ ┌──────▼──────┐ + │ aten.conv2d │ ... │ aten.conv2d │ │ aten.conv2d │ ... │ aten.conv2d │ + └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ + └───────┐ ┌──────┘ │ │ + ┌──▼─────▼─┐ replace with ┌─────▼─────┐ ┌─────▼─────┐ + │ aten.cat │ ──────────────► │ aten.relu │ ... 
│ aten.relu │ + └────┬─────┘ └─────┬─────┘ └─────┬─────┘ + │ └───────┐ ┌───────┘ + ┌─────▼─────┐ ┌──▼─────▼─┐ + │ aten.relu │ │ aten.cat │ + └─────┬─────┘ └────┬─────┘ + │ │ + """ + + def __init__(self, neutron_target_spec: NeutronTargetSpec): + self.neutron_target_spec = neutron_target_spec + + def call(self, module: GraphModule) -> bool: + def _is_concat(node_: Node) -> bool: + return ( + node_.op == "call_function" + and node_.target == torch.ops.aten.cat.default + ) + + made_changes = False + + for node in module.graph.nodes: + if not _is_concat(node): + continue # Not cat node. + + cat_node = node + activation = next(iter(cat_node.users)) + + # Check if all cat inputs nodes are conv 2D or linear 2D type and their only user is cat. + if not all( + self.neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + input_node + ) + and len(input_node.users) == 1 + for input_node in cat_node.all_input_nodes + ): + continue + + # Check if following activation is supported on Neutron as fused activation. + if not ( + len(cat_node.users) == 1 + and self.neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + activation + ) + ): + continue + + # Loop all Cat input nodes and insert new activation after node. + for input_node in cat_node.all_input_nodes: + with module.graph.inserting_after(input_node): + new_activation = module.graph.call_function( + activation.target, + args=(), + kwargs=activation.kwargs, + ) + + new_activation.meta["source_fn_stack"] = [ + ( + new_activation.name, + activation.meta["source_fn_stack"][-1][-1], + ) + ] + new_activation.meta["val"] = input_node.meta["val"] + + # Replace the uses of the input node with the new activation node. + input_node.replace_all_uses_with(new_activation) + new_activation.args = (input_node, *activation.args[1:]) + + # Replace the uses of the activation node with the cat node. 
+ activation.replace_all_uses_with(cat_node) + + module.graph.erase_node(activation) + + made_changes = True + + return PassResult(module, made_changes) diff --git a/backends/nxp/aten_passes/neutron_aten_pass_manager.py b/backends/nxp/aten_passes/neutron_aten_pass_manager.py index f6e3c374b19..35205c76c68 100644 --- a/backends/nxp/aten_passes/neutron_aten_pass_manager.py +++ b/backends/nxp/aten_passes/neutron_aten_pass_manager.py @@ -13,6 +13,12 @@ from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import ( FuseBatchNormWithLinearPass, ) +from executorch.backends.nxp.aten_passes.fuse_linear_and_add_pass import ( + FuseLinearAndAddPass, +) +from executorch.backends.nxp.aten_passes.move_activation_before_concat import ( + MoveActivationBeforeConcat, +) from executorch.backends.nxp.aten_passes.remove_nodes_with_known_outputs import ( RemoveNodesWithKnownOutputs, ) @@ -22,6 +28,7 @@ from executorch.backends.nxp.aten_passes.split_gru_based_on_num_layers import ( SplitGRUBasedOnNumLayers, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from executorch.exir.pass_manager import PassManager from torch import nn from torch.fx.passes.infra.pass_base import PassResult @@ -31,13 +38,17 @@ class NeutronAtenPassManager(PassManager): - def __init__(self, passes: list[PassType] = None): + def __init__( + self, neutron_target_spec: NeutronTargetSpec, passes: list[PassType] = None + ): passes: list[PassType] = passes or [ FuseBatchNormWithConvPass(), FuseBatchNormWithLinearPass(), SplitGroupConvolution(), SplitGRUBasedOnNumLayers(), RemoveNodesWithKnownOutputs(), + FuseLinearAndAddPass(), + MoveActivationBeforeConcat(neutron_target_spec), ] super().__init__(passes) diff --git a/backends/nxp/backend/edge_helper.py b/backends/nxp/backend/edge_helper.py index 061295ead79..d78997ea4a6 100644 --- a/backends/nxp/backend/edge_helper.py +++ b/backends/nxp/backend/edge_helper.py @@ -1,13 +1,27 @@ -# Copyright 2024 NXP +# Copyright 2024-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
import torch + +from executorch.exir.dialects._ops import ops as exir_ops + from torch.fx import GraphModule, Node from torch.nn import Parameter +QUANTIZE_OPERATORS = [ + exir_ops.edge.quantized_decomposed.quantize_per_channel.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, +] + +DEQUANTIZE_OPERATORS = [ + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, +] + + def input_tensor(node: Node, input_index: int) -> torch.Tensor: if len(node.all_input_nodes) <= input_index: raise IndexError @@ -62,12 +76,6 @@ def node_is_effectively_static_tensor( if node_is_static_tensor(node, parameters_mapping): return True - def _is_dequantize(node_: Node) -> bool: - return node_.target.__name__ in { - "quantized_decomposed.dequantize_per_tensor.default", - "quantized_decomposed.dequantize_per_channel.default", - } - return _is_dequantize(node) and node_is_static_tensor( node.args[0], parameters_mapping ) @@ -87,3 +95,99 @@ def try_get_tensor_constant_from_node( return None attr_itr = getattr(attr_itr, atom) return attr_itr + + +def _is_dequantize(node_: Node) -> bool: + return node_.op == "call_function" and node_.target in [ + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + ] + + +def _is_quantize(node_: Node) -> bool: + return node_.op == "call_function" and node_.target in [ + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_channel.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_channel.default, + ] + + +def previous_non_qdq_node(node: Node, input_index: int = 0) -> Node | None: + """Return the first node which is not a `quantize` or `dequantize`, found by traversing the graph backwards + starting with the `node.args[input_index]`, + """ + current_node = node.args[input_index] + while True: + if _is_quantize(current_node) or _is_dequantize(current_node): + current_node = current_node.args[0] + else: + return current_node + + +Scale = list[float] | float +ZeroPoint = list[int] | int + + +def get_quantization_parameters_for(node: Node) -> tuple[Scale, ZeroPoint] | None: + if "quantize" not in node.target.__name__ or len(node.args) < 3: + return None + + return node.args[1], node.args[2] # Scale and zero_point + + +def get_non_qdq_users(node: Node) -> list[Node]: + """Return a list of nodes which consume the output of `node`, but Quantize/Dequantize nodes from QDQ clusters are + ignored. Meaning, the list of nodes [, ..., ] from the illustration below is returned. + + If the graph does not follow the QDQ pattern, an empty list is returned. + + │ + ┌───▼────┐ + │ `node` │ + └───┬────┘ + ┌────▼─────┐ + │ Quantize │ + └────┬─────┘ + ├─────── ... ──────┐ + ┌─────▼──────┐ ┌─────▼──────┐ + │ Dequantize │ ... │ Dequantize │ + └─────┬──────┘ └─────┬──────┘ + ┌────▼─────┐ ┌────▼─────┐ + │ │ ... 
│ │ + └────┬─────┘ └────┬─────┘ + + """ + + quant_nodes = list(node.users) + if len(quant_nodes) != 1 or quant_nodes[0].target not in [ + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_channel.default, + ]: + return [] + + dequant_nodes = list(quant_nodes[0].users) + if any( + dequant_node.target + not in [ + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + ] + for dequant_node in dequant_nodes + ): + return [] + + res = [] + for dequant_node in dequant_nodes: + res.extend(list(dequant_node.users)) + + return res + + +def is_channels_last_dim_order(dim_order: list[int]) -> bool: + if len(dim_order) < 3: + return False + + return list(dim_order) == [0] + list(range(2, len(dim_order))) + [1] diff --git a/backends/nxp/backend/edge_program_converter.py b/backends/nxp/backend/edge_program_converter.py index ddbbf5b2e3a..fdfa4a31bc6 100644 --- a/backends/nxp/backend/edge_program_converter.py +++ b/backends/nxp/backend/edge_program_converter.py @@ -18,10 +18,8 @@ from torch.fx import Node from torch.nn.parameter import Parameter from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403 -from executorch.backends.nxp.backend.node_format_inference import ( - NodeFormat, - NodeFormatInference, -) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec +from executorch.backends.nxp.backend.node_format import NodeFormat, NXP_NODE_FORMAT from executorch.exir.dialects._ops import ops as exir_ops # noinspection PyProtectedMember @@ -33,15 +31,19 @@ exir_ops.edge.aten.avg_pool2d.default: AvgPool2dConverter, # noqa F405 exir_ops.edge.aten.cat.default: CatConverter, # noqa F405 exir_ops.edge.aten.clone.default: CloneConverter, # noqa F405 + exir_ops.edge.dim_order_ops._clone_dim_order.default: CloneConverter, # noqa F405 exir_ops.edge.aten.constant_pad_nd.default: ConstantPadNDConverter, # noqa F405 exir_ops.edge.aten.convolution.default: ConvolutionConverter, # noqa F405 exir_ops.edge.aten.hardtanh.default: HardTanhConverter, # noqa F405 exir_ops.edge.aten.max_pool2d.default: MaxPool2dConverter, # noqa F405 exir_ops.edge.aten.mean.dim: MeanDimConverter, # noqa F405 exir_ops.edge.aten.mm.default: MMConverter, # noqa F405 + exir_ops.edge.aten.mul.Tensor: MulTensorConverter, # noqa F405 exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter, # noqa F405 exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405 + exir_ops.edge.aten.slice_copy.Tensor: SliceTensorConverter, # noqa F405 exir_ops.edge.aten._softmax.default: SoftmaxConverter, # noqa F405 + exir_ops.edge.aten.sub.Tensor: SubTensorConverter, # noqa F405 exir_ops.edge.aten.tanh.default: TanhConverter, # noqa F405 exir_ops.edge.aten.view_copy.default: ViewCopyConverter, # noqa F405 exir_ops.edge.aten.sigmoid.default: SigmoidConverter, # noqa F405 @@ -54,28 +56,32 @@ class EdgeProgramToIRConverter: """ _default_conversion_config = ConversionConfig() + _default_target_spec = NeutronTargetSpec("imxrt700", "SDK_25_09") _default_delegation_options = CustomDelegationOptions() def convert_program( self, edge_program: ExportedProgram, - conversion_config=_default_conversion_config, + conversion_config: ConversionConfig = _default_conversion_config, + neutron_target_spec: NeutronTargetSpec = _default_target_spec, custom_delegation_options: CustomDelegationOptions = _default_delegation_options, - ) -> (bytes, dict): + ) -> (bytes, 
dict[str, NodeFormat]): """ Convert ExportedProgram in Edge dialect to IR (TFLite flatbuffers) as bytes. :param edge_program: Converter ExportedProgram. :param conversion_config: ConversionConfig instance. + :param neutron_target_spec: Object for querying the target platform to retrieve its properties. :param custom_delegation_options: Custom user options which affect node delegation. :return: TFLite flatbuffers as bytes. """ - node_formats = NodeFormatInference(edge_program).identify_node_formats() parameters_mapping = self.map_inputs_to_parameters(edge_program) + dim_order_map = self.map_nodes_to_dim_order(edge_program) cc = self.build_conversion_context( parameters_mapping, - node_formats, + dim_order_map, + neutron_target_spec, conversion_config, custom_delegation_options, ) @@ -85,13 +91,16 @@ def convert_program( self._convert_qdq_cluster_q_dq_nodes(edge_program.graph.nodes, cc) self._process_nodes(edge_program.graph.nodes, cc) - # Assign output - io_formats = cc.tflite_builder.assign_model_io_to_subgraph_and_get_io_formats( - edge_program.graph_signature - ) + # Assign the model its inputs and outputs. + cc.tflite_builder.assign_model_io_to_subgraph(edge_program.graph_signature) - # TFLite model generation + # Apply optimizations and finalize the model. internal_tflite_model = cc.tflite_builder.finish() + + # Extract the formats of the model's inputs and outputs. + io_formats = cc.tflite_builder.get_io_formats(edge_program.graph_signature) + + # TFLite model generation flatbuffers_builder = flatbuffers.Builder() internal_tflite_model.gen_tflite(flatbuffers_builder) @@ -101,7 +110,7 @@ def convert_program( def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContext): for node in nodes: if node.op == "placeholder": - node_format = context.node_formats[node] + node_format = node.meta[NXP_NODE_FORMAT] if node.name in context.parameters_mapping: # Node is placeholder and has data -> append as static tensor with data @@ -114,7 +123,7 @@ def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContex context.tflite_builder.append_as_fake_tensor(node, node_format) elif node.op == "call_function": # Node is call function -> append only output as a tensor - node_format = context.node_formats[node] + node_format = node.meta[NXP_NODE_FORMAT] context.tflite_builder.append_as_fake_tensor(node, node_format) elif node.op == "output": # Nothing to do @@ -134,6 +143,7 @@ def _process_nodes(self, nodes: list[Node], conversion_context: ConversionContex qdq_related_functions = [ exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, ] @@ -168,15 +178,35 @@ def map_inputs_to_parameters(edge_program: ExportedProgram) -> dict[str, Paramet return result_map + @staticmethod + def map_nodes_to_dim_order(edge_program: ExportedProgram) -> dict[str, Parameter]: + """ + Create mapping between node names and their dim-orders. + + :param edge_program: EdgeProgram instance. + :return: Mapping from node name to dim-order. 
+ """ + + return { + n.name: val.dim_order() + for n in edge_program.graph.nodes + if hasattr(val := n.meta.get("val", None), "dim_order") + } + @staticmethod def build_conversion_context( parameters_mapping: dict, - node_formats: dict[Node, NodeFormat], + dim_order_map: dict[str, ...], + neutron_target_spec: NeutronTargetSpec, conversion_config: ConversionConfig = _default_conversion_config, custom_delegation_options: CustomDelegationOptions = _default_delegation_options, ) -> ConversionContext: tflite_builder = AtenModelBuilderDirector( - 3, "TFLite from EdgeProgram", conversion_config + 3, + "TFLite from EdgeProgram", + neutron_target_spec, + dim_order_map, + conversion_config, ) # Add "sentinel" buffer (defined in schema.fbs) @@ -186,7 +216,6 @@ def build_conversion_context( tflite_builder, conversion_config, parameters_mapping, - node_formats, custom_delegation_options, ) @@ -203,7 +232,8 @@ def _convert_qdq_cluster_q_dq_nodes( :param conversion_context: ConversionContext instance. """ qdq_q_ops_converters = { - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: QDQDequantizeConverter, # noqa F405 + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: QDQPerTensorDequantizeConverter, # noqa F405 + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default: QDQPerChannelDequantizeConverter, # noqa F405 exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: QDQQuantizeConverter, # noqa F405 } diff --git a/backends/nxp/backend/ir/conversion_config.py b/backends/nxp/backend/ir/conversion_config.py index 4ac88eb467c..4ba66adc942 100644 --- a/backends/nxp/backend/ir/conversion_config.py +++ b/backends/nxp/backend/ir/conversion_config.py @@ -1,4 +1,4 @@ -# Copyright 2024 NXP +# Copyright 2024-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -13,8 +13,7 @@ def __init__(self, args: dict | None = None): :param args: Optional dictionary with conversion arguments. Unknown arguments are ignored. """ - self.keep_io_format: bool = False - self.skip_shape_inference: bool = False + self.use_neutron_for_format_conversion: bool = True self.allow_inputs_stripping: bool = True self.qdq_aware_conversion: bool = True self.symbolic_dimensions_mapping: dict[str, int] | None = None @@ -46,15 +45,6 @@ def __repr__(self): return "ConversionConfig[" + ", ".join(attrs) + "]" -class SkipShapeInferenceConfig(ConversionConfig): - - def __init__(self): - """ - Conversion config shortcut with disabled shape inference. 
- """ - super().__init__({"skip_shape_inference": True}) - - class QDQAwareConfig(ConversionConfig): def __init__(self): diff --git a/backends/nxp/backend/ir/conversion_context.py b/backends/nxp/backend/ir/conversion_context.py index 6fb7e98424e..d4746fbde01 100644 --- a/backends/nxp/backend/ir/conversion_context.py +++ b/backends/nxp/backend/ir/conversion_context.py @@ -10,8 +10,6 @@ from executorch.backends.nxp.backend.ir.converter.builder.aten_model_builder_director import ( AtenModelBuilderDirector, ) -from executorch.backends.nxp.backend.node_format_inference import NodeFormat -from torch import Node from torch.nn import Parameter @@ -19,7 +17,6 @@ class ConversionContext: tflite_builder: AtenModelBuilderDirector conversion_config: ConversionConfig parameters_mapping: dict[str, Parameter] - node_formats: dict[Node, NodeFormat] custom_delegation_options: CustomDelegationOptions def __init__( @@ -27,7 +24,6 @@ def __init__( tflite_builder: AtenModelBuilderDirector, conversion_config: ConversionConfig, parameters_mapping: dict, - node_formats: dict[Node, NodeFormat], custom_delegation_options: CustomDelegationOptions, ): """ @@ -39,5 +35,4 @@ def __init__( self.tflite_builder = tflite_builder self.conversion_config = conversion_config self.parameters_mapping = parameters_mapping - self.node_formats = node_formats self.custom_delegation_options = custom_delegation_options diff --git a/backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py b/backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py index a420cea9aa7..658b4fc93f7 100644 --- a/backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py +++ b/backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py @@ -9,7 +9,7 @@ from executorch.backends.nxp.backend.ir.converter.conversion import translator from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model -from executorch.backends.nxp.backend.node_format_inference import NodeFormat +from executorch.backends.nxp.backend.node_format import NodeFormat from torch.fx import Node from torch.nn import Parameter @@ -88,19 +88,40 @@ def append_operators(self, ops_to_add: list[tflite_model.Operator]): self.check_and_append_operator(op) - def assign_model_io_to_subgraph_and_get_io_formats( - self, graph_signature - ) -> dict[str, dict]: - """ - Assign model's inputs/outputs to SubGraph. + def get_io_formats(self, graph_signature) -> dict[str, dict[str, TensorFormat]]: + """Get a mapping from tensor names to their formats. - :param graph_signature: Instance of GraphSignature. + :param graph_signature: Instance of GraphSignature. :returns: Mapping between IO tensors' names and their formats. """ io_formats = { "inputs": {}, "outputs": {}, } + for input_name in graph_signature.user_inputs: + tensor = self.tensor_for_name(input_name) + assert input_name == tensor.name, ( + "Program's input name doesn't match with tensor name in TFLite. " + "Input was probably redirected." + ) + io_formats["inputs"][tensor.name] = tensor.tensor_format + + for output_name in graph_signature.user_outputs: + tensor = self.tensor_for_name(output_name) + assert output_name == tensor.name, ( + "Program's output name doesn't match with tensor name in TFLite. " + "Output was probably redirected." 
+ ) + io_formats["outputs"][tensor.name] = tensor.tensor_format + + return io_formats + + def assign_model_io_to_subgraph(self, graph_signature): + """ + Assign model's inputs/outputs to SubGraph. + + :param graph_signature: Instance of GraphSignature. + """ self.get_sub_graph().inputs = tflite_model.SubGraphInputs() for input_name in graph_signature.user_inputs: @@ -110,7 +131,6 @@ def assign_model_io_to_subgraph_and_get_io_formats( "Input was probably redirected." ) self.get_sub_graph().inputs.tmp_inputs.append(tensor) - io_formats["inputs"][tensor.name] = tensor.tensor_format self.get_sub_graph().outputs = tflite_model.SubGraphOutputs() for output_name in graph_signature.user_outputs: @@ -120,7 +140,3 @@ def assign_model_io_to_subgraph_and_get_io_formats( "Output was probably redirected." ) self.get_sub_graph().outputs.tmp_outputs.append(tensor) - - io_formats["outputs"][tensor.name] = tensor.tensor_format - - return io_formats diff --git a/backends/nxp/backend/ir/converter/builder/model_builder.py b/backends/nxp/backend/ir/converter/builder/model_builder.py index 1ca46237814..87b1e55bcf9 100755 --- a/backends/nxp/backend/ir/converter/builder/model_builder.py +++ b/backends/nxp/backend/ir/converter/builder/model_builder.py @@ -1,18 +1,20 @@ # # Copyright 2023 Martin Pavella -# Copyright 2023-2024 NXP +# Copyright 2023-2025 NXP # # License: MIT # See the LICENSE_MIT for more details. # + from copy import deepcopy -from typing import Dict, List, Optional, Union +from itertools import chain +from typing import List, Optional, Union import executorch.backends.nxp.backend.ir.converter.conversion.translator as translator import executorch.backends.nxp.backend.ir.logger as logger import executorch.backends.nxp.backend.ir.tflite_generator.tflite_model as tflite_model - import numpy as np +from executorch.backends.nxp.backend.edge_helper import is_channels_last_dim_order from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.backend.ir.converter.builder import ( quantization_verification, @@ -48,6 +50,10 @@ FlexTranspose, ) from executorch.backends.nxp.backend.ir.tflite_optimizer import optimizer +from executorch.backends.nxp.backend.neutron_operator_support import ( + transposition_is_supported_on_neutron, +) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec class ModelBuilder: @@ -59,33 +65,41 @@ class ModelBuilder: _tfl_model: tflite_model.Model - _tensor_name_map: Dict # Mapping 'str' to 'tflT.Tensor' + _tensor_name_map: dict # Mapping 'str' to 'tflT.Tensor' - # Maps BuiltinOperator to a Dict, mapping version to index. Operators of type 'BuiltinOperator.CUSTOM' + # Maps BuiltinOperator to a dict, mapping version to index. Operators of type 'BuiltinOperator.CUSTOM' # have their 'version' prepended with its name, for example "FlexErf_1". 
- op_code_type_index_map: Dict[BuiltinOperator, Dict[Union[str, int], int]] + op_code_type_index_map: dict[BuiltinOperator, dict[Union[str, int], int]] - _nchw_tensor_version: Dict # Mapping 'tflT.Tensor' to 'tflT.Tensor' which is + _nchw_tensor_version: dict # Mapping 'tflT.Tensor' to 'tflT.Tensor' which is # equal, but in NCHW format - _skipped_output_map: Dict # Mapping 'tflT.Tensor' objects that were outputs + _skipped_output_map: dict # Mapping 'tflT.Tensor' objects that were outputs # of skipped operators, to 'tflT.Tensor' outputs of # previous operators - _zeros_tensor_map: Dict # Mapping 'string' shapes to 'tflT.Tensor' objects + _zeros_tensor_map: dict # Mapping 'string' shapes to 'tflT.Tensor' objects - _default_conversion_config = ConversionConfig() + neutron_target_spec: NeutronTargetSpec + + dim_order_map: dict # Mapping tensor names to their ExecuTorch `dim_order`. conversion_config: ConversionConfig + _default_conversion_config = ConversionConfig() + def __init__( self, model_version: int, model_description: str, + neutron_target_spec: NeutronTargetSpec, + dim_order_map: dict[str, ...], conversion_config: ConversionConfig = _default_conversion_config, ) -> None: self._tfl_model = tflite_model.Model(model_version, model_description) + self.neutron_target_spec = neutron_target_spec self.conversion_config = conversion_config + self.dim_order_map = dim_order_map self.op_code_type_index_map = {} self._tensor_name_map = {} @@ -213,7 +227,7 @@ def channels_first_version_of(self, t_tensor: tflite_model.Tensor): new_tensor.shape = translator.channels_last_shape_to_channels_first( t_tensor.shape ) - new_tensor.tensor_format = new_tensor.tensor_format.to_node_format() + new_tensor.tensor_format = TensorFormat.CHANNELS_FIRST perm = translator.create_channels_last_to_channels_first_permutation( t_tensor.rank @@ -348,8 +362,31 @@ def _make_inputs_channels_first(self): for input_tensor in self.get_sub_graph().inputs.tmp_inputs: if input_tensor.tensor_format.is_channels_last(): + # The input must be permuted. + + if is_channels_last_dim_order( + self.dim_order_map.get(input_tensor.name, []) + ): + # Do NOT insert a Transpose, as the input will already be provided in the channels last format + # during runtime. 
+ new_inputs.append(input_tensor) + continue + # Create a Transpose operator and replace the graph input + new_input_shape = translator.channels_last_shape_to_channels_first( + input_tensor.shape + ) + perm = translator.create_channels_first_to_channels_last_permutation( + input_tensor.rank + ) + + if not transposition_is_supported_on_neutron( + new_input_shape.vector, list(perm), self.neutron_target_spec + ): + new_inputs.append(input_tensor) + continue + if input_tensor.rank > 6: msg = ( f"Couldn't preserve the shape of input tensor '{input_tensor.name}', because it has " @@ -360,14 +397,9 @@ def _make_inputs_channels_first(self): new_input = self.duplicate_tensor( input_tensor, input_tensor.name + "_channels_first" ) - new_input.shape = translator.channels_last_shape_to_channels_first( - input_tensor.shape - ) - new_input.tensor_format = input_tensor.tensor_format.to_node_format() + new_input.shape = new_input_shape + new_input.tensor_format = TensorFormat.CHANNELS_FIRST - perm = translator.create_channels_first_to_channels_last_permutation( - input_tensor.rank - ) transpose = self._create_transpose_operator( new_input, input_tensor, perm ) @@ -390,8 +422,28 @@ def _make_outputs_channels_first(self): for output_tensor in self.get_sub_graph().outputs.tmp_outputs: if output_tensor.tensor_format.is_channels_last(): + # The output must be permuted. + + if is_channels_last_dim_order( + self.dim_order_map.get(output_tensor.name, []) + ): + # Do NOT insert a Transpose, as the output will be required to be in the channels last format + # during runtime. + new_outputs.append(output_tensor) + continue + # Add a Transpose operator, to make the output channels first + shape = output_tensor.shape.vector + perm = translator.create_channels_last_to_channels_first_permutation( + len(shape), True + ) + if not transposition_is_supported_on_neutron( + shape, perm, self.neutron_target_spec + ): + new_outputs.append(output_tensor) + continue + if output_tensor.rank > 6: logger.e( logger.Code.IO_PRESERVATION_ERROR, @@ -412,6 +464,34 @@ def _make_outputs_channels_first(self): self.get_sub_graph().outputs.tmp_outputs = new_outputs + def _keep_one_empty_buffer(self): + """Create a single empty `Buffer` object and assign it to all tensors in the model that don't have static data.""" + empty_buffer = self.get_first_empty_buffer() + + for t in self.get_tensors().vector: + if tensor_has_data(t): + # The buffer of `t` is not empty. + continue + + if t.tmp_buffer == empty_buffer: + # Already optimized. + continue + + if t.is_variable: + # The data of the tensor will change at runtime, so it shouldn't share the buffer with other tensors. + continue + + # It's safe to replace the buffer. + t.tmp_buffer = empty_buffer + + def replace_io_tensor_format_with_node_format(self): + for t in chain( + self.get_sub_graph().inputs.tmp_inputs, + self.get_sub_graph().outputs.tmp_outputs, + ): + if isinstance(t.tensor_format, TensorFormat): + t.tensor_format = t.tensor_format.to_equal_node_format() + def finish(self) -> tflite_model.Model: """Finalize and optimize the converted TFLite model. Then return it. @@ -419,17 +499,23 @@ def finish(self) -> tflite_model.Model: :return: The final TFLite model. """ - if self.conversion_config.keep_io_format: + if self.conversion_config.use_neutron_for_format_conversion: # If the input or output is channels last, add a Transpose operator, to make is channels first. self._make_inputs_channels_first() self._make_outputs_channels_first() # Apply optimizations to the internal TFLite model. 
- optimizer.Optimizer(self, self.conversion_config).optimize( + optimizer.Optimizer( + self, self.conversion_config, self.neutron_target_spec + ).optimize( self.conversion_config.optimization_whitelist, self.conversion_config.optimization_blacklist, ) + self._keep_one_empty_buffer() + + self.replace_io_tensor_format_with_node_format() + # Remove outputs, which are not produced by any node. Otherwise, there would be errors after inference. operator_outputs = [] for op in self.get_operators().vector: @@ -449,31 +535,7 @@ def finish(self) -> tflite_model.Model: return self._tfl_model - def _assign_tensor_and_buffer_indices( # noqa C901 - self, allow_inputs_stripping: bool - ): - """Correctly initialize all references via indices in all tensors and buffers.""" - - # Assign each buffer its index - for i, buffer in enumerate(self.get_buffers().vector): - buffer.tmp_index = i - - # Assign each tensor its index and its buffer index - for i, tensor in enumerate(self.get_tensors().vector): - if tensor.tmp_null_tensor: - # Using -1 as the index to the 'tensors' vector is way of telling the TFLite inference engine, that - # this tensor should not be used. - # https://github.com/tensorflow/tensorflow/blob/05404d959119d41a8ffb8a75c6f232cfd8540d45/tensorflow/lite/kernels/kernel_util.cc#L79-L98 - tensor.tmp_index = -1 - else: - tensor.tmp_index = i - - tensor.buffer = tensor.tmp_buffer.tmp_index - - # TODO Remove inputs and outputs that are not in the tensors collection - - # Assign 'Outputs' and 'Inputs' their tensor indices - outputs = self.get_sub_graph().outputs + def _assign_io_tensor_indices(self, inputs, outputs, allow_inputs_stripping: bool): for tensor in outputs.tmp_outputs: try: outputs.append(tensor.tmp_index) @@ -483,7 +545,6 @@ def _assign_tensor_and_buffer_indices( # noqa C901 f"The tensor '{tensor.name}' is among the model outputs, but does NOT appear in the graph!", ) - inputs = self.get_sub_graph().inputs for tensor in inputs.tmp_inputs: try: inputs.append(tensor.tmp_index) @@ -498,14 +559,46 @@ def _assign_tensor_and_buffer_indices( # noqa C901 f"The tensor '{tensor.name}' is among the model inputs, but does NOT appear in the graph!", ) - # Assign each operator its inputs and outputs indices - for operator in self.get_sub_graph().operators.vector: + def _assign_operators_io_tensor_indices(self, operators): + for operator in operators.vector: for inputTensor in operator.tmp_inputs: operator.inputs.append(inputTensor.tmp_index) for outputTensor in operator.tmp_outputs: operator.outputs.append(outputTensor.tmp_index) + def _assign_tensor_and_buffer_indices(self, allow_inputs_stripping: bool): + """Correctly initialize all references via indices in all tensors and buffers.""" + + # Assign each buffer its index + for i, buffer in enumerate(self.get_buffers().vector): + buffer.tmp_index = i + + # Assign each tensor its index and its buffer index + for i, tensor in enumerate(self.get_tensors().vector): + if tensor.tmp_null_tensor: + # Using -1 as the index to the 'tensors' vector is way of telling the TFLite inference engine, that + # this tensor should not be used. 
+ # https://github.com/tensorflow/tensorflow/blob/05404d959119d41a8ffb8a75c6f232cfd8540d45/tensorflow/lite/kernels/kernel_util.cc#L79-L98 + tensor.tmp_index = -1 + else: + tensor.tmp_index = i + + tensor.buffer = tensor.tmp_buffer.tmp_index + + # TODO Remove inputs and outputs that are not in the tensors collection + + subgraph = self.get_sub_graph() + + # Assign 'Outputs' and 'Inputs' their tensor indices + self._assign_io_tensor_indices( + inputs=subgraph.inputs, + outputs=subgraph.outputs, + allow_inputs_stripping=allow_inputs_stripping, + ) + # Assign each operator its inputs and outputs indices + self._assign_operators_io_tensor_indices(operators=subgraph.operators) + def _build_operator_code( self, op_type: BuiltinOperator, version, custom_code: str = None ): @@ -773,29 +866,8 @@ def _remove_tensor_with_name(self, name): def append_new_tensor(self, t_tensor: tflite_model.Tensor, overwrite: bool = False): """Append the TFLite tensor 't_tensor' to the 'SubGraph.tensors' and register it.""" - - if t_tensor.name in self._tensor_name_map.keys(): - """Tensor has already been added. Sometimes however, ONNX models - will have tensors in their 'inputs' or 'outputs', which don't - belong there and are in fact static. I this case we need to - overwrite the existing tensors.""" - - if overwrite: - self._remove_tensor_with_name(t_tensor.name) - - # If the tenor previously appeared in ONNX 'inputs' or 'outputs', - # the old version MUST be removed from there. - self._remove_input_with_name(t_tensor.name) - self._remove_output_with_name(t_tensor.name) - - self.get_tensors().append(t_tensor) - self._tensor_name_map[t_tensor.name] = t_tensor - else: - logger.w(f"Tensor '{t_tensor.name}' is already in the tensors!") - - else: - self._tensor_name_map[t_tensor.name] = t_tensor - self.get_tensors().append(t_tensor) + self._tensor_name_map[t_tensor.name] = t_tensor + self.get_tensors().append(t_tensor) def append_new_buffer(self, buffer: tflite_model.Buffer): """Append the 'buffer' to the 'model.buffers'.""" @@ -1493,7 +1565,7 @@ def prepare_dynamic_tensor_for_correct_broadcasting_with_channels_first_tensors( # Prepend a partial identity, to keep leading dimensions unchanged. 
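# --- Editorial note, not part of the patch: a small, self-contained illustration of how a static
# tensor is prepared for broadcasting against a channels-first tensor (see the
# `prepare_*_tensor_for_correct_broadcasting_with_channels_first_tensors` hunks around this
# point). The shapes are hypothetical; the helper only mirrors what `dims_to_channels_last` is
# documented to do, namely turn an NCHW-style dimension list into an NHWC-style one. ---
import numpy as np

def _dims_to_channels_last(dims: list[int]) -> list[int]:
    return [dims[0], *dims[2:], dims[1]]  # Move the channels (index 1) to the end.

rank_diff = 2                                          # Rank-2 tensor broadcast against a rank-4 NCHW tensor.
original_shape = [3, 5]
extended = [1] * rank_diff + original_shape            # [1, 1, 3, 5] in the channels-first world.
tflite_shape = _dims_to_channels_last(extended)        # [1, 3, 5, 1] in the channels-last world.

data = np.zeros(original_shape, dtype=np.float32).reshape(extended)
data = np.moveaxis(data, 1, -1)                        # Same axis move applied to the data itself.
assert list(data.shape) == tflite_shape == [1, 3, 5, 1]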
revert_perm = list(range(rank_diff)) + list(revert_perm) - # Now add a permutation to convert the extended ONNX shape to a TFLite shape + # Now add a permutation to convert the extended ExecuTorch shape to a TFLite shape to_tflite_perm = ( translator.create_channels_first_to_channels_last_permutation( output_rank @@ -1557,20 +1629,20 @@ def prepare_static_tensor_for_correct_broadcasting_with_channels_first_tensors( original_shape = translator.dims_to_channels_first( shape - ) # Same shape as in the ONNX model + ) # Same shape as in the ExecuTorch model # Prepend 1s to the shape - extended_onnx_shape = [1] * rank_diff + original_shape + extended_executorch_shape = [1] * rank_diff + original_shape # Convert the full shape to TFLite format - tflite_shape = translator.dims_to_channels_last(extended_onnx_shape) + tflite_shape = translator.dims_to_channels_last(extended_executorch_shape) tensor.shape = tflite_model.Shape(tflite_shape) # Statically transpose the data data = translator.convert_data_to_channels_first( data - ) # To the same shape as in the ONNX model - data = data.reshape(extended_onnx_shape) # Extend with leading 1s + ) # To the same shape as in the ExecuTorch model + data = data.reshape(extended_executorch_shape) # Extend with leading 1s tensor.tmp_buffer.data = translator.convert_data_to_channels_last( data ) # Convert to TFLite format @@ -1578,16 +1650,16 @@ def prepare_static_tensor_for_correct_broadcasting_with_channels_first_tensors( assert tflite_shape == list(tensor.tmp_buffer.data.shape) else: - # The tensor is the same as in the ONNX model. + # The tensor is the same as in the ExecuTorch model. - extended_onnx_shape = [1] * rank_diff + shape + extended_executorch_shape = [1] * rank_diff + shape # Convert the full shape to TFLite format - tflite_shape = translator.dims_to_channels_last(extended_onnx_shape) + tflite_shape = translator.dims_to_channels_last(extended_executorch_shape) tensor.shape = tflite_model.Shape(tflite_shape) # Statically transpose the data - data = data.reshape(extended_onnx_shape) # Extend with leading 1s + data = data.reshape(extended_executorch_shape) # Extend with leading 1s tensor.tmp_buffer.data = translator.convert_data_to_channels_last( data ) # Convert to TFLite format diff --git a/backends/nxp/backend/ir/converter/conversion/common.py b/backends/nxp/backend/ir/converter/conversion/common.py index 8230e39a7fa..318fe66dfbd 100755 --- a/backends/nxp/backend/ir/converter/conversion/common.py +++ b/backends/nxp/backend/ir/converter/conversion/common.py @@ -1,6 +1,6 @@ # # Copyright 2023 Martin Pavella -# Copyright 2023-2024 NXP +# Copyright 2023-2025 NXP # # License: MIT # See the LICENSE_MIT for more details. @@ -12,7 +12,7 @@ 'conversion/builtin/' directory. 
""" -from typing import Any, List, MutableSequence, Optional +from typing import List, MutableSequence, Optional import executorch.backends.nxp.backend.ir.logger as logger from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model @@ -22,28 +22,8 @@ max_pool_2d_options, transpose_conv_options, ) -from torch.fx import Node - - -def exactly_one_is_none(obj1: Optional, obj2: Optional) -> bool: - """Determine if exactly 1 of the arguments is None, or not.""" - return (obj1 is None and obj2 is not None) or (obj1 is not None and obj2 is None) - - -def contains_duplicates(list_to_check: List[Any]) -> bool: - """Determine if given list has duplicate elements or not.""" - return len(list_to_check) != len(set(list_to_check)) - - -def clamp(val: int, start: int, end: int) -> int: - """Clamp an int value between start and end (inclusive) and return it.""" - if val < start: - return start - - elif val > end: - return end - return val +from torch.fx import Node def try_get_input(t_op: tflite_model.Operator, idx: int) -> tflite_model.Tensor | None: @@ -62,11 +42,6 @@ def try_get_input(t_op: tflite_model.Operator, idx: int) -> tflite_model.Tensor tensor = t_op.tmp_inputs[idx] - if tensor.name == "": - # ONNX allows the name "" for optional tensors. It indicates that the tensor should be ignored, and a default - # value should be used. Just like if the tensor was omitted altogether. - return None - return tensor @@ -101,7 +76,7 @@ def assign_2d_strides(options: StridedOptions, strides: Optional[List[int]]): If 'strides' is None, assign 1s. :param options: TFLite AveragePool2D, Conv2D, MaxPool2D or TransposeConv options object. - :param strides: An optional list of ONNX strides attribute. + :param strides: An optional list of ExecuTorch strides attribute. """ if strides is None: @@ -115,8 +90,8 @@ def assign_2d_strides(options: StridedOptions, strides: Optional[List[int]]): else: logger.e( - logger.Code.INVALID_ONNX_OPERATOR_ATTRIBUTE, - f"ONNX operator has invalid 'strides' attribute! ('{strides}')", + logger.Code.INVALID_OPERATOR_ATTRIBUTE, + f"ExecuTorch operator has invalid 'strides' attribute! ('{strides}')", ) @@ -188,32 +163,6 @@ def node_uses_shape_broadcasting(node: Node) -> bool: ) -def uses_multiple_input_types(t_op: tflite_model.Operator) -> bool: - """Determine if the input tensors of given TFLite operator use different data types or not. - - :param t_op: TFLite operator with 'tmp_inputs' initialized. - :return: True, if any two input tensors have a different data type. - False, if all input tensors use the same data type. 
- """ - - if t_op.tmp_inputs is None: - logger.e( - logger.Code.INTERNAL_ERROR, - "common.uses_multiple_input_types(): 'tmp_inputs' are None!", - ) - - if len(t_op.tmp_inputs) == 0: - logger.e( - logger.Code.INTERNAL_ERROR, - "common.uses_multiple_input_types(): Operator has no inputs!", - ) - - first_input_type = t_op.tmp_inputs[0].type - return any( - input_tensor.type != first_input_type for input_tensor in t_op.tmp_inputs[1:] - ) - - class OpsList: """ Holder of TFLite operator (middle_op) that can be prefixed (pre_ops) of suffixed (post_ops) diff --git a/backends/nxp/backend/ir/converter/conversion/translator.py b/backends/nxp/backend/ir/converter/conversion/translator.py index 4f327c6ac80..1fe195843c0 100755 --- a/backends/nxp/backend/ir/converter/conversion/translator.py +++ b/backends/nxp/backend/ir/converter/conversion/translator.py @@ -1,6 +1,5 @@ -# # Copyright 2023 Martin Pavella -# Copyright 2023-2024 NXP +# Copyright 2023-2025 NXP # # License: MIT # See the LICENSE_MIT for more details. @@ -9,10 +8,10 @@ translator Module contains functions for context-free conversion of various -things from ONNX to TFLite. +things from ExecuTorch to NeutronIR. """ -from typing import Any, Collection, List, Optional, Sequence, Tuple +from typing import Any, Collection, List, Optional, Sequence import executorch.backends.nxp.backend.ir.lib.tflite.Padding as tflPadding import executorch.backends.nxp.backend.ir.logger as logger @@ -21,16 +20,12 @@ import numpy as np import torch from executorch.backends.nxp.backend.ir.lib.tflite.TensorType import TensorType -from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat -from executorch.backends.nxp.backend.ir.tflite_generator.meta.types import ( - TensorFlowDataType, -) def permute_static_tensor(tensor: tflite_model.Tensor, perm: list[int]): - """Take a static TFLite tensor and permute its shape and data according to the permutation in 'perm'. + """Take a static NeutronIR tensor and permute its shape and data according to the permutation in 'perm'. - :param tensor: Static TFLite tensor to permute. + :param tensor: Static NeutronIR tensor to permute. :param perm: Permutation to apply to the tensor. """ @@ -53,7 +48,7 @@ def permute_static_tensor(tensor: tflite_model.Tensor, perm: list[int]): def get_tflite_tensor_shape_with_explicit_padding( tflite_shape: List[int], explicit_padding: List[List[int]] ) -> List[int]: - """Get the resulting shape of a tensor with shape 'tflite_shape' (in TFLite format), after 'explicit_padding' is + """Get the resulting shape of a tensor with shape 'tflite_shape' (in NeutronIR format), after 'explicit_padding' is applied to it. """ @@ -62,7 +57,7 @@ def get_tflite_tensor_shape_with_explicit_padding( ): logger.e( logger.Code.INTERNAL_ERROR, - f"Cannot apply padding '{explicit_padding}' to TFLite shape '{tflite_shape}'!", + f"Cannot apply padding '{explicit_padding}' to NeutronIR shape '{tflite_shape}'!", ) total_padding = [ @@ -90,24 +85,9 @@ def get_tflite_tensor_shape_with_explicit_padding( return padded_shape -def convert_tensor_format_to_tflite(tensor_format: TensorFormat) -> TensorFormat: - """Convert the format of a tensor from ONNX to TFLite. - :return: The tensor_format converted to TFLite. - """ - if tensor_format is TensorFormat.CHANNELS_FIRST: - return TensorFormat.CHANNELS_LAST - - elif tensor_format not in (TensorFormat.FORMATLESS, TensorFormat.NONE): - logger.d( - f"translator.convert_tensor_format(): Got unexpected format '{tensor_format}'." 
- ) - - return tensor_format - - def dims_to_channels_first(channels_last_dimensions: List[int]) -> List[int]: - """Convert a list of ints which represent dimensions in the channels last (TFLite) format to the channels first - (ONNX) format. + """Convert a list of ints which represent dimensions in the channels last (NeutronIR) format to the channels first + (ExecuTorch) format. """ assert len(channels_last_dimensions) > 0, "Dimensions list is empty!" @@ -122,8 +102,8 @@ def dims_to_channels_first(channels_last_dimensions: List[int]) -> List[int]: def dims_to_channels_last(channels_first_dimensions: List[int]) -> List[int]: - """Convert a list of ints which represent dimensions in the channels first (ONNX) format to the channels last - (TFLite) format. + """Convert a list of ints which represent dimensions in the channels first (ExecuTorch) format to the channels last + (NeutronIR) format. """ assert len(channels_first_dimensions) > 0, "Dimensions list is empty!" @@ -171,7 +151,7 @@ def _same_upper_equals_same_lower( o_strides: Optional[List[int]] = None, o_dilations: Optional[List[int]] = None, ) -> bool: - """Determine if in a given particular setting, the values of the ONNX `auto_pads` attribute SAME_UPPER and + """Determine if in a given particular setting, the values of the ExecuTorch `auto_pads` attribute SAME_UPPER and SAME_LOWER represent the exact same padding. """ @@ -193,7 +173,7 @@ def _tflite_padding_compute_output_size( """ Calculates the output shape of the tensor with particular setting as tflite would. Implementation corresponds to tensorflow/lite/kernels/padding.h:ComputeOutSize() - :param padding: TFLite Padding value - 'Same' or 'Valid' + :param padding: NeutronIR Padding value - 'Same' or 'Valid' :param tflite_spatial_input_shape: input tensor shape :param tflite_kernel_shape: convolution kernel shape :param strides: strides (default is 1) @@ -229,7 +209,7 @@ def tflite_compute_padding_with_offset( dilations: Optional[List[int]] = None, ) -> (List[int], List[int]): """ - Calculate padding and offset for each dimension for particular convolution setting as TFLite. + Calculate padding and offset for each dimension for particular convolution setting as NeutronIR. Implementation corresponds to tensorflow/lite/kernels/padding.h:ComputePaddingWithOffset() :param tflite_input_shape: tensorflow lite input shape :param tflite_kernel_shape: tensorflow lite kernel shape @@ -272,14 +252,14 @@ def _is_same_padding( o_strides: Optional[List[int]] = None, o_dilations: Optional[List[int]] = None, ) -> bool: - """Determine if given ONNX 'pads' padding can be represented exactly with the TFLite 'SAME' padding type. - - :param o_pads: ONNX 'pads' attribute. - :param tflite_input_shape: The shape of the main input of the operator in TFLite format. - :param tflite_output_shape: The shape of the main output of the operator in TFLite format. - :param o_kernel_shape: ONNX 'kernel_shape' attribute. - :param o_strides: ONNX 'strides' attribute. Can be omitted. - :param o_dilations: ONNX 'dilations' attribute. Can be omitted. + """Determine if given ExecuTorch 'pads' padding can be represented exactly with the NeutronIR 'SAME' padding type. + + :param o_pads: ExecuTorch 'pads' attribute. + :param tflite_input_shape: The shape of the main input of the operator in NeutronIR format. + :param tflite_output_shape: The shape of the main output of the operator in NeutronIR format. + :param o_kernel_shape: ExecuTorch 'kernel_shape' attribute. + :param o_strides: ExecuTorch 'strides' attribute. 
Can be omitted. + :param o_dilations: ExecuTorch 'dilations' attribute. Can be omitted. """ if len(tflite_input_shape) == 0 or len(tflite_output_shape) == 0: @@ -289,7 +269,7 @@ def _is_same_padding( f"'{tflite_input_shape}' and output shape '{tflite_output_shape}'.", ) - # Calculate if the output shape corresponds to Same padding setting in TFLite + # Calculate if the output shape corresponds to Same padding setting in NeutronIR tflite_spatial_input_shape = tflite_input_shape[1:-1] tmp_spatial_output_shape = _tflite_padding_compute_output_size( tflPadding.Padding.SAME, @@ -302,10 +282,10 @@ def _is_same_padding( return False # For every dimension, the padding is added to the start and end of the dimension. - # TFLite padding 'SAME' tries to split it evenly, but in case of odd padding, 'SAME' adds the excess 1 at the end. - # TFLite represents this in the offset. The offset is added to the end of particular dimension, + # NeutronIR padding 'SAME' tries to split it evenly, but in case of odd padding, 'SAME' adds the excess 1 at the end. + # NeutronIR represents this in the offset. The offset is added to the end of particular dimension, # i.e. bottom for H dim, right for W dim and so on. - # ONNX represents this in 'pads' as [x1_begin, x2_begin,... , x1_end, x2_end,...]. + # ExecuTorch represents this in 'pads' as [x1_begin, x2_begin,... , x1_end, x2_end,...]. padding, offset = tflite_compute_padding_with_offset( tflite_input_shape, o_kernel_shape, tflite_output_shape, o_strides, o_dilations ) @@ -319,30 +299,6 @@ def _is_same_padding( return True -def permutations_are_inverse( - permutation1: Sequence[int], permutation2: Sequence[int] -) -> bool: - """Determine if given Transpose permutations are inverse of each other. - i.e. when applied back to back, there will be no effect. - - Example: - 0 3 1 2 - 0 2 3 1 - """ - - if len(permutation1) != len(permutation2): - logger.e( - logger.Code.INTERNAL_ERROR, - "translator.permutations_are_inverse(): permutations have different size!", - ) - - for i, perm2 in enumerate(permutation2): - if i != permutation1[perm2]: - return False - - return True - - def combine_permutations( permutation1: Sequence[int], permutation2: Sequence[int] ) -> List[int]: @@ -375,31 +331,35 @@ def shape_from_numpy(numpy_array): return tflite_model.Shape(dims) -def onnx_explicit_padding_to_tflite(onnx_pads: list[int]) -> list[list[int]]: - """Convert the attribute or input 'pads' of the ONNX 'Pad' operator to the 'paddings' input of the TFLite 'Pad' +def executorch_explicit_padding_to_tflite( + executorch_pads: list[int], +) -> list[list[int]]: + """Convert the attribute or input 'pads' of the ExecuTorch 'Pad' operator to the 'paddings' input of the NeutronIR 'Pad' class of operators. This function does NOT take tensor formats into consideration. """ - start_padding = onnx_pads[ - : len(onnx_pads) // 2 + start_padding = executorch_pads[ + : len(executorch_pads) // 2 ] # Padding at the start of each dimension - end_padding = onnx_pads[ - len(onnx_pads) // 2 : + end_padding = executorch_pads[ + len(executorch_pads) // 2 : ] # Padding at the end of each dimension return list(zip(start_padding, end_padding)) -def onnx_pads_to_tflite_explicit_padding(onnx_pads: List[int]) -> List[List[int]]: - """Convert an ONNX attribute 'pads' of operators such as Conv, MaxPool or AveragePool, to a list of ints which is - compatible with the TFLite 'Pad' operator. 
+def executorch_pads_to_tflite_explicit_padding( + executorch_pads: List[int], +) -> List[List[int]]: + """Convert an ExecuTorch attribute 'pads' of operators such as Conv, MaxPool or AveragePool, to a list of ints which is + compatible with the NeutronIR 'Pad' operator. """ - tflite_padding = onnx_explicit_padding_to_tflite(onnx_pads) + tflite_padding = executorch_explicit_padding_to_tflite(executorch_pads) - # TFLite also allows padding to the 'batch' and 'channels'. ONNX does not + # NeutronIR also allows padding to the 'batch' and 'channels'. ExecuTorch does not tflite_padding.insert(0, [0, 0]) tflite_padding.append([0, 0]) @@ -413,15 +373,15 @@ def _get_explicit_tflite_padding_for_same_lower( o_strides: Optional[List[int]] = None, o_dilations: Optional[List[int]] = None, ) -> List[List[int]]: - """Get the TFLite explicit padding required to represent ONNX 'SAME_LOWER' auto_pad for a particular setting. + """Get the NeutronIR explicit padding required to represent ExecuTorch 'SAME_LOWER' auto_pad for a particular setting. - :param tflite_input_shape: TFLite (NHWC) shape of the input tensor of the operator. - :param tflite_output_shape: TFLite (NHWC) shape of the output tensor of the operator. - :param o_kernel_shape: ONNX 'kernel_shape' attribute. - :param o_strides: Optional ONNX 'o_strides' attribute. - :param o_dilations: Optional ONNX 'o_dilations' attribute. + :param tflite_input_shape: NeutronIR (NHWC) shape of the input tensor of the operator. + :param tflite_output_shape: NeutronIR (NHWC) shape of the output tensor of the operator. + :param o_kernel_shape: ExecuTorch 'kernel_shape' attribute. + :param o_strides: Optional ExecuTorch 'o_strides' attribute. + :param o_dilations: Optional ExecuTorch 'o_dilations' attribute. - :return: A TFLite style explicit padding, compatible with the TFLite 'Pad' operator. + :return: A NeutronIR style explicit padding, compatible with the NeutronIR 'Pad' operator. """ padding, offset = tflite_compute_padding_with_offset( @@ -433,102 +393,15 @@ def _get_explicit_tflite_padding_for_same_lower( ] # In case of odd padding, the excess is added at the start end_padding = padding - onnx_explicit_padding = start_padding + end_padding - - # Return explicit ONNX padding converted to TFLite padding - return onnx_pads_to_tflite_explicit_padding(onnx_explicit_padding) - - -def convert_padding( - o_auto_pad: str, - o_pads: List[int], - tflite_input_shape: List[int], - tflite_output_shape: List[int], - o_kernel_shape: List[int], - o_strides: Optional[List[int]], - o_dilations: Optional[List[int]] = None, -) -> Tuple[tflPadding.Padding, Optional[List[List[int]]]]: - """Convert ONNX operator attributes 'pads' and 'auto_pad' to TFLite. - - :param o_auto_pad: ONNX operator attribute 'auto_pad' - :param o_pads: ONNX operator attribute 'pads' - :param tflite_input_shape: The shape of the main input tensor in the TFLite format. - :param tflite_output_shape: The shape of the main output tensor in the TFLite format. - :param o_kernel_shape: ONNX operator attribute 'kernel_shape' - :param o_strides: ONNX operator attribute 'strides' - :param o_dilations: ONNX operator attribute 'dilations' - - :return: A tuple. - The first element is the converted TFLite padding. - The second is None, if conversion is finished. Or it is a list of ints representing the explicit - padding in TFLite format (compatible with the 'Pad' operator), which needs to be provided by a - 'Pad' operator. Caller must add this operator using model_builder! 
- """ - - if o_auto_pad == "SAME_UPPER": - return tflPadding.Padding.SAME, None - - elif o_auto_pad == "SAME_LOWER": - if _same_upper_equals_same_lower( - tflite_input_shape, - tflite_output_shape, - o_kernel_shape, - o_strides, - o_dilations, - ): - return tflPadding.Padding.SAME, None - - else: - logger.d( - "'SAME_LOWER' auto_pad cannot be exactly represented in TFLite as padding 'SAME' or 'VALID'. " - "Inserting an extra 'Pad' operator." - ) - tflite_explicit_padding = _get_explicit_tflite_padding_for_same_lower( - tflite_input_shape, - tflite_output_shape, - o_kernel_shape, - o_strides, - o_dilations, - ) - return tflPadding.Padding.VALID, tflite_explicit_padding - - elif o_auto_pad == "VALID": - return tflPadding.Padding.VALID, None - - # auto_pad is NOTSET -> use explicit padding - elif o_pads is None or all(val == 0 for val in o_pads): - # No padding in any direction - return tflPadding.Padding.VALID, None - - elif _is_same_padding( - o_pads, - tflite_input_shape, - tflite_output_shape, - o_kernel_shape, - o_strides, - o_dilations, - ): - # Explicit padding can be represented with TFLite 'SAME' padding. - return tflPadding.Padding.SAME, None - - else: - # 'pads' cannot be converted directly. Return 'VALID' and the required explicit padding and caller must - # implement conversion by adding a 'Pad' operator. - - logger.d( - "Explicit ONNX 'pads' cannot be represented directly as 'SAME' or 'VALID'. " - "Inserting an extra 'Pad' operator." - ) - - # ONNX 'pads' uses different format than TFLite 'Pad' operator. Convert the explicit padding. - tflite_explicit_padding = onnx_pads_to_tflite_explicit_padding(o_pads) + executorch_explicit_padding = start_padding + end_padding - return tflPadding.Padding.VALID, tflite_explicit_padding + # Return explicit ExecuTorch padding converted to NeutronIR padding + return executorch_pads_to_tflite_explicit_padding(executorch_explicit_padding) def convert_data_to_channels_first(array: np.ndarray) -> np.ndarray: - """Convert a numpy array representing the data of a tensor from the channels last format (TFLite), to channels - first format (ONNX). + """Convert a numpy array representing the data of a tensor from the channels last format (NeutronIR), to channels + first format (ExecuTorch). :param array: Numpy array holding the tensor's data. :return: The transformed data. @@ -543,8 +416,8 @@ def convert_data_to_channels_first(array: np.ndarray) -> np.ndarray: def convert_data_to_channels_last(array: np.ndarray) -> np.ndarray: - """Convert a numpy array representing the data of a tensor from the channels first format (ONNX), to channels last - format (TFLite). + """Convert a numpy array representing the data of a tensor from the channels first format (ExecuTorch), to channels last + format (NeutronIR). :param array: Numpy array holding the tensor's data. :return: The transformed data. 
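# --- Editorial note, not part of the patch: concrete values for the two permutations used
# throughout this file, plus the `np.moveaxis` call they correspond to. For a rank-4 tensor the
# channels-first -> channels-last permutation is [0, 2, 3, 1] and its inverse is [0, 3, 1, 2];
# an axis index can be remapped between the two layouts with `perm.index(axis)`, which is how the
# concat `dim` is translated later in this patch. The shape below is hypothetical. ---
import numpy as np

to_nhwc = [0, 2, 3, 1]          # create_channels_first_to_channels_last_permutation(4)
to_nchw = [0, 3, 1, 2]          # create_channels_last_to_channels_first_permutation(4)

x = np.zeros((1, 3, 8, 8), dtype=np.float32)           # NCHW data.
assert np.moveaxis(x, 1, -1).shape == (1, 8, 8, 3)     # What convert_data_to_channels_last does.
assert x.transpose(to_nhwc).shape == (1, 8, 8, 3)      # Same effect, expressed as a Transpose perm.

nchw_dim = 1                                           # Channels axis in the ExecuTorch model...
assert to_nhwc.index(nchw_dim) == 3                    # ...becomes the last axis after conversion.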
@@ -558,17 +431,6 @@ def convert_data_to_channels_last(array: np.ndarray) -> np.ndarray: return np.moveaxis(array, 1, -1) # Move the second axis (C), to the end -def channels_first_shape_to_channels_last( - channels_first_shape: tflite_model.Shape, -) -> tflite_model.Shape: - """Create a channels last version of a channels first 'tflite_model.Shape' object.""" - - dims = channels_first_shape.vector.copy() - dims = dims_to_channels_last(dims) - - return tflite_model.Shape(dims) - - def channels_last_shape_to_channels_first( nhwc_shape: tflite_model.Shape, ) -> tflite_model.Shape: @@ -580,23 +442,13 @@ def channels_last_shape_to_channels_first( return tflite_model.Shape(dims) -def convert_onnx_dimensions_to_tflite_shape(o_dims: List[int]) -> tflite_model.Shape: - """Convert list of ints representing the shape of an ONNX channels first Tensor to a TFLite 'Shape' object.""" - - dims = list(o_dims) # Copy just in case - - dims = dims_to_channels_last(dims) - - return tflite_model.Shape(dims) - - def create_channels_last_to_channels_first_permutation( rank: int, return_list: bool = False ) -> np.ndarray | list[int]: """Return a numpy array with data that describes the permutation, which would change a tensor from the channels - last (TFLite) format to the channels first (ONNX) format. + last (NeutronIR) format to the channels first (ExecuTorch) format. - This permutation is compatible with the TFLite `Transpose` operator. + This permutation is compatible with the NeutronIR `Transpose` operator. :param rank: The rank of the required permutation. :param return_list: If True, the function returns a list of ints. If False, a numpy array is returned. @@ -615,9 +467,9 @@ def create_channels_first_to_channels_last_permutation( rank: int, return_list: bool = False ) -> np.ndarray | list[int]: """Return a numpy array with data that describes the permutation, which would change a tensor from the channels - first (ONNX) format to the channels last (TFLite) format. + first (ExecuTorch) format to the channels last (NeutronIR) format. - This permutation is compatible with the TFLite `Transpose` operator. + This permutation is compatible with the NeutronIR `Transpose` operator. :param rank: The rank of the required permutation. :param return_list: If True, the function returns a list of ints. If False, a numpy array is returned. @@ -632,35 +484,8 @@ def create_channels_first_to_channels_last_permutation( return np.asarray(perm, np.int32) -def create_axis_to_last_perm(axis, num_dims): - """Create a numpy array representing the transpose permutations needed, to - make the 'axis' dimension, the last dimension. - """ - - dims = list(range(num_dims)) - - if axis == num_dims - 1: - return dims - elif axis >= num_dims or axis < 0: - logger.e( - logger.Code.INTERNAL_ERROR, - f"translator.create_axis_to_last_perm({axis},{num_dims}). Inputs don't make sense!", - ) - - # Remember axis dimension - axis_dim = dims[axis] - - # Move dimensions after 'axis' to the left - dims[axis:-1] = dims[axis + 1 : -1] - - # Add axis dimension to the end - dims.append(axis_dim) - - return np.asarray(dims, np.int32) - - def apply_permutation_to(target: List[Any], permutation: Collection[int]) -> List: - """Permute a list according to a permutation. Uses the same permutation format as the TFLite Transpose operator. + """Permute a list according to a permutation. Uses the same permutation format as the NeutronIR Transpose operator. :param target: A list of any types, to permute. Must be same size as the permutation. 
:param permutation: The permutation to apply to the target. @@ -678,7 +503,7 @@ def apply_permutation_to(target: List[Any], permutation: Collection[int]) -> Lis def create_inverse_permutation(permutation: List[int]) -> List[int]: """Create and return a permutation, that is the inverse of the given 'permutation' parameter. - Uses the same permutation format as the TFLite Transpose operator. + Uses the same permutation format as the NeutronIR Transpose operator. :param permutation: The permutation to create the inverse of. :return: Inverse permutation. @@ -694,38 +519,8 @@ def create_inverse_permutation(permutation: List[int]) -> List[int]: return [permutation.index(perm) for perm in range(len(permutation))] -def get_max_value_for_type(dtype: np.dtype) -> any: - """Return the maximum possible value for given numpy type.""" - if dtype.kind in ("i", "u"): - return np.iinfo(dtype).max - - elif dtype.kind == "f": - return np.finfo(dtype).max - - else: - logger.e( - logger.Code.INTERNAL_ERROR, - f"translator.get_max_value_for_type(): unexpected type {dtype.name}.", - ) - - -def get_min_value_for_type(dtype: np.dtype) -> any: - """Return the minimum possible value for given numpy type.""" - if dtype.kind in ("i", "u"): - return np.iinfo(dtype).min - - elif dtype.kind == "f": - return np.finfo(dtype).min - - else: - logger.e( - logger.Code.INTERNAL_ERROR, - f"translator.get_min_value_for_type(): unexpected type {dtype.name}.", - ) - - def convert_data_type(torch_type: torch.TensorType) -> TensorType: - """Convert Torch DataType to TFLite TensorType""" + """Convert Torch DataType to NeutronIR TensorType""" if torch_type == torch.float32: return TensorType.FLOAT32 @@ -753,7 +548,7 @@ def convert_data_type(torch_type: torch.TensorType) -> TensorType: def torch_type_to_numpy_type(torch_type: torch.TensorType) -> np.ScalarType: - """Convert Torch DataType to TFLite TensorType""" + """Convert Torch DataType to NeutronIR TensorType""" if torch_type == torch.float32: return np.dtype(np.float32) @@ -778,10 +573,10 @@ def torch_type_to_numpy_type(torch_type: torch.TensorType) -> np.ScalarType: def numpy_type_to_tf_lite(numpy_type: np.dtype) -> TensorType: # noqa C901 - """Convert the numpy data type to a corresponding TFLite 'TensorType'. + """Convert the numpy data type to a corresponding NeutronIR 'TensorType'. :param numpy_type: Numpy dtype to convert. - :return: Corresponding TFLite TensorType. + :return: Corresponding NeutronIR TensorType. 
""" numpy_type = numpy_type.type @@ -835,12 +630,12 @@ def numpy_type_to_tf_lite(numpy_type: np.dtype) -> TensorType: # noqa C901 else: logger.e( logger.Code.CONVERSION_IMPOSSIBLE, - f"Cannot convert numpy data type '{numpy_type}' to TFLite.", + f"Cannot convert numpy data type '{numpy_type}' to NeutronIR.", ) def tf_lite_type_to_numpy(tfl_type: TensorType) -> np.ScalarType: # noqa C901 - """Convert TFLite TensorType to numpy dtype""" + """Convert NeutronIR TensorType to numpy dtype""" if tfl_type == TensorType.FLOAT32: return np.dtype(np.float32) @@ -890,72 +685,5 @@ def tf_lite_type_to_numpy(tfl_type: TensorType) -> np.ScalarType: # noqa C901 else: logger.e( logger.Code.CONVERSION_IMPOSSIBLE, - f"Cannot convert TFLite type '{tfl_type}' to numpy dtype.", + f"Cannot convert NeutronIR type '{tfl_type}' to numpy dtype.", ) - - -def tflite_type_to_tensor_flow_data_type(tfl_type: TensorType) -> TensorFlowDataType: - """Convert TFLite TensorType to the internal type of TensorFlow.""" - match tfl_type: - case TensorType.FLOAT16: - # There seems to be no counterpart in the TF DataType. - logger.e( - logger.Code.INTERNAL_ERROR, - "tflite_type_to_tensor_flow_data_type(): float16.", - ) - case TensorType.FLOAT32: - return TensorFlowDataType.DT_FLOAT.value - case TensorType.FLOAT64: - return TensorFlowDataType.DT_DOUBLE.value - - case TensorType.INT4: - return TensorFlowDataType.DT_INT4.value - case TensorType.INT8: - return TensorFlowDataType.DT_INT8.value - case TensorType.INT16: - return TensorFlowDataType.DT_INT16.value - case TensorType.INT32: - return TensorFlowDataType.DT_INT32.value - case TensorType.INT64: - return TensorFlowDataType.DT_INT64.value - - case TensorType.UINT8: - return TensorFlowDataType.DT_UINT8.value - case TensorType.UINT16: - return TensorFlowDataType.DT_UINT16.value - case TensorType.UINT32: - return TensorFlowDataType.DT_UINT32.value - case TensorType.UINT64: - return TensorFlowDataType.DT_UINT64.value - - case TensorType.COMPLEX64: - return TensorFlowDataType.DT_COMPLEX64.value - case TensorType.COMPLEX128: - return TensorFlowDataType.DT_COMPLEX128.value - - case TensorType.STRING: - return TensorFlowDataType.DT_STRING.value - - case TensorType.BOOL: - return TensorFlowDataType.DT_BOOL.value - - case TensorType.RESOURCE: - return TensorFlowDataType.DT_RESOURCE.value - case TensorType.VARIANT: - return TensorFlowDataType.DT_VARIANT.value - - case _: - # All TFLite types are covered. Must be an invalid type. - logger.e( - logger.Code.INTERNAL_ERROR, - f"tflite_type_to_tensor_flow_data_type(): invalid TFLite type `{tfl_type}`.", - ) - - -def infer_kernel_shape(weight_tensor: tflite_model.Tensor) -> list[int]: - """Returns the kernel shape inferred from the weight tensor. - - Weight tensors shape expected in TFlite Format, where the 0th index is output channels count, last is input channels - count. - """ - return weight_tensor.shape.vector[1:-1] diff --git a/backends/nxp/backend/ir/converter/node_converter.py b/backends/nxp/backend/ir/converter/node_converter.py index ed624aaa411..b653718e643 100755 --- a/backends/nxp/backend/ir/converter/node_converter.py +++ b/backends/nxp/backend/ir/converter/node_converter.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. 
from abc import ABC, abstractmethod -from enum import Enum import torch @@ -16,8 +15,10 @@ AtenModelBuilderDirector, ) from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from executorch.exir.dialects._ops import ops as exir_ops from torch.fx import Node +from torch.fx.passes.infra.partitioner import Partition from torch.nn import Parameter @@ -37,15 +38,8 @@ def _is_dequant_node(node: torch.fx.Node) -> bool: ] -class Target(Enum): - IGNORE = "ignore" # No target platform. Any target specific restrictions will be ignored. - - RT700 = "imxrt700" - IMX95 = "imx95" - - @classmethod - def values(cls) -> list[str]: - return [elt.value for elt in cls] +def is_not_qdq_node(node: torch.fx.Node) -> bool: + return not (_is_quant_node(node) or _is_dequant_node(node)) class NodeConverter(ABC): @@ -89,7 +83,7 @@ def _is_supported_in_IR( @staticmethod def _is_supported_on_target( node: Node, - target: Target, + neutron_target_spec: NeutronTargetSpec, parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: @@ -98,33 +92,55 @@ def _is_supported_on_target( can be used by operators with no target specific requirements. :param node: The node (edge operator) to check. - :param target: Value of the `Target` enum representing the target platform to check for. + :param neutron_target_spec: Object for querying the target platform to retrieve its properties. :param parameters_mapping: Dictionary mapping tensor names to their static data (if they have it). :param custom_delegation_options: Custom options which affect delegation. """ - return target == Target.RT700 + return True @classmethod def is_supported( cls, node: Node, - target: Target, + neutron_target_spec: NeutronTargetSpec, parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: """Check if the given `node` is supported in the IR and on the given `target` platform. :param node: torch.Node to check. - :param target: Value of the `Target` enum representing the target platform to check for. + :param neutron_target_spec: Object for querying the target platform to retrieve its properties. :param parameters_mapping: Dict mapping tensor names to their data. :param custom_delegation_options: Custom user options which affect node delegation. """ return cls._is_supported_in_IR( node, parameters_mapping, custom_delegation_options ) and cls._is_supported_on_target( - node, target, parameters_mapping, custom_delegation_options + node, neutron_target_spec, parameters_mapping, custom_delegation_options ) + @classmethod + def supports_partitioning_result( + cls, + node: Node, + partition_list: list[Partition], + custom_delegation_options: CustomDelegationOptions, + neutron_target_spec: NeutronTargetSpec, + parameters_mapping: dict[str, Parameter], + ) -> bool: + """Check if the given `node` supports the assigned partitioning, which is stored the `partition_list`. Child + classes can overwrite this method in case they have delegation restrictions based on the context defined by + the partitioning result. + + :param node: torch.Node to check. + :param partition_list: List of proposed partitions. + :param custom_delegation_options: Custom user options which affect node delegation. + :param neutron_target_spec: NeutronTargetSpec instance. + :param parameters_mapping: Dictionary mapping tensor names to their static data. 
+ :return: Boolean indicating whether the node supports the current partitioning. + """ + return True + @staticmethod def _has_shared_q_params_if_quantized(node: Node) -> bool: """Check if node has shared quantization parameters if it's quantized.""" @@ -174,6 +190,14 @@ def builder(self) -> AtenModelBuilderDirector: """ return self.context.tflite_builder + @property + def neutron_target_spec(self) -> NeutronTargetSpec: + """ + Get an instance of NeutronTargetSpec from the conversion context. + :return: NeutronTargetSpec instance. + """ + return self.builder.neutron_target_spec + def _create_tflite_op_with_io_tensors(self, node: Node) -> tflite_model.Operator: """ Create TFLite op wrapper with input/output tensors added into 'tmp_inputs' and 'tmp_outputs'. diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py index d1674e16a9f..3b8b9bf9b3f 100755 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py @@ -37,11 +37,15 @@ from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.mm_converter import ( MMConverter, ) +from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.mul_tensor_converter import ( + MulTensorConverter, +) from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.permute_copy_converter import ( PermuteCopyConverter, ) from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.qdq_dequantize_converter import ( - QDQDequantizeConverter, + QDQPerChannelDequantizeConverter, + QDQPerTensorDequantizeConverter, ) from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.qdq_quantize_converter import ( QDQQuantizeConverter, @@ -52,9 +56,15 @@ from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.sigmoid_converter import ( SigmoidConverter, ) +from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.slice_tensor_converter import ( + SliceTensorConverter, +) from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.softmax_converter import ( SoftmaxConverter, ) +from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.sub_tensor_converter import ( + SubTensorConverter, +) from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.tanh_converter import ( TanhConverter, ) @@ -63,25 +73,29 @@ ) __all__ = [ + "AbsConverter", + "AdaptiveAvgPool2dConverter", "AddMMConverter", + "AddTensorConverter", + "AvgPool2dConverter", "CatConverter", + "CloneConverter", + "ConstantPadNDConverter", "ConvolutionConverter", + "HardTanhConverter", + "MaxPool2dConverter", + "MeanDimConverter", "MMConverter", + "MulTensorConverter", "PermuteCopyConverter", - "SoftmaxConverter", - "ViewCopyConverter", - "QDQDequantizeConverter", + "QDQPerChannelDequantizeConverter", + "QDQPerTensorDequantizeConverter", "QDQQuantizeConverter", - "ConstantPadNDConverter", "ReLUConverter", - "MeanDimConverter", - "MaxPool2dConverter", - "AvgPool2dConverter", - "AddTensorConverter", - "CloneConverter", - "AbsConverter", - "AdaptiveAvgPool2dConverter", - "HardTanhConverter", "SigmoidConverter", + "SliceTensorConverter", + "SoftmaxConverter", + "SubTensorConverter", "TanhConverter", + "ViewCopyConverter", ] diff --git 
a/backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py index c74baa61f67..cd5aa2ead81 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py @@ -9,11 +9,11 @@ from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, NodeConverter, - Target, ) from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import ( add_options, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.fx import Node from torch.nn import Parameter @@ -22,20 +22,15 @@ class AddTensorConverter(NodeConverter): @staticmethod def _is_supported_on_target( node: Node, - target: Target, + neutron_target_spec: NeutronTargetSpec, parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - match target: - case Target.RT700: - if node_uses_shape_broadcasting(node): - # Shape broadcasting may require the addition of `Transpose` ops during conversion. - return False - - return True + if node_uses_shape_broadcasting(node): + # Shape broadcasting may require the addition of `Transpose` ops during conversion. + return False - case _: - return False + return True @staticmethod def _is_supported_in_IR( diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py index 4f7f00fe5ba..bb8ab1048eb 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py @@ -8,17 +8,24 @@ from executorch.backends.nxp.backend.custom_delegation_options import ( CustomDelegationOptions, ) +from executorch.backends.nxp.backend.edge_helper import previous_non_qdq_node from executorch.backends.nxp.backend.ir.converter.conversion import translator +from executorch.backends.nxp.backend.ir.converter.conversion.translator import ( + apply_permutation_to, + create_channels_first_to_channels_last_permutation, +) from executorch.backends.nxp.backend.ir.converter.node_converter import ( _is_dequant_node, _is_quant_node, NodeConverter, - Target, ) from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options.concatenation_options import ( Concatenation, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec +from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT from torch.fx import Node +from torch.fx.passes.infra.partitioner import Partition from torch.nn import Parameter @@ -72,51 +79,63 @@ def _all_io_shares_quantization_parameters(node: Node) -> bool: @staticmethod def _is_supported_on_target( node: Node, - target: Target, + neutron_target_spec: NeutronTargetSpec, parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: if custom_delegation_options.force_delegate_cat: return True - match target: - case Target.RT700: - dim = CatConverter._get_normalized_dim(node) - - # neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1491 - if dim == 0: - return False + dim = CatConverter._get_normalized_dim(node) - # Neutron requires the channels to be a multiple of `8`. 
The channels could either be the second or the - # last dimension, depending on the formats of the node. The format, however, cannot be determined - # during conversion, as it depends on what other nodes are delegated. - input_channels = [ - # The second dimension is the channels in PyTorch. If the inputs/output are not channels first, it - # will still be the channels in the IR. - _get_shape(input_)[1] - for input_ in node.all_input_nodes - ] + [ - # If the inputs/outputs are channels first, the last dimension will be the channels. - _get_shape(input_)[-1] - for input_ in node.all_input_nodes - ] - if any((input_channel % 8) != 0 for input_channel in input_channels): - # neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1492 - return False + # Neutron requires the channels to be a multiple of `num_macs`. The channels could either be the second or the + # last dimension, depending on the formats of the node. + if node.meta[NXP_NODE_FORMAT].is_channels_first(): + # During conversion to IR, the shape will be permuted to channels last, and the dimension on index + # `1` will end up being the channels (last dim in NHWC). + channels_index = 1 + to_nhwc_perm = create_channels_first_to_channels_last_permutation( + len(node.meta["val"].shape), True + ) + dim = to_nhwc_perm.index( + dim + ) # Make sure the dim points to the NHWC dimension. + else: + # The shape will not be permuted during conversion, so the channels will remain the last dimension. + channels_index = -1 + + input_channels = [ + _get_shape(input_)[channels_index] for input_ in node.all_input_nodes + ] + output_channels = _get_shape(node)[channels_index] + + num_macs = neutron_target_spec.get_num_macs() + input_shapes = [_get_shape(input_) for input_ in node.all_input_nodes] + if any((input_channel % num_macs) != 0 for input_channel in input_channels): + # neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1492 + + # If all input shapes are equal, the neutron is able to pad the last dimension of the inputs. + if not ( + input_shapes.count(input_shapes[0]) == len(input_shapes) + and dim == len(input_shapes[0]) - 1 + ): + return False - output_channels = [_get_shape(node)[1], _get_shape(node)[-1]] - # neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493 - if any((out_c % 8) != 0 for out_c in output_channels): - return False + if (output_channels % num_macs) != 0: + # neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493 - if len(node.all_input_nodes) < 2: # Not supported on Neutron - # TODO Try to skip the operator if this case is realistic. - return False + # If all input shapes are equal, the neutron is able to pad the last dimension of the output. + if not ( + input_shapes.count(input_shapes[0]) == len(input_shapes) + and dim == len(input_shapes[0]) - 1 + ): + return False - return True + if len(node.all_input_nodes) < 2: # Not supported on Neutron + # TODO Try to skip the operator if this case is realistic. + return False - case _: - return False + return True @staticmethod def _is_supported_in_IR( @@ -131,6 +150,48 @@ def _is_supported_in_IR( return True + @classmethod + def supports_partitioning_result( + cls, + node: Node, + partition_list: list[Partition], + custom_delegation_options: CustomDelegationOptions, + neutron_target_spec: NeutronTargetSpec, + parameters_mapping: dict[str, Parameter], + ) -> bool: + # There is a bug in the NeutronConverter, where if none of the input dimensions before the one referenced by + # `dim` are `!= 1`, the `Concat` is not delegated. 
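# --- Editorial note, not part of the patch: a minimal illustration of the channel-multiple checks
# in the new `_is_supported_on_target` above. The real code queries
# `neutron_target_spec.get_num_macs()`; the stub spec and the concrete channel counts here are
# assumptions for the sake of the example. ---
class _StubTargetSpec:
    def get_num_macs(self) -> int:
        return 8  # Hypothetical target that processes 8 channels per MAC step.

def _channels_ok(channels: int, spec: _StubTargetSpec) -> bool:
    # Concat inputs/outputs are rejected when their channel count is not a multiple of num_macs,
    # unless the equal-shape / last-dimension padding exception above applies.
    return channels % spec.get_num_macs() == 0

spec = _StubTargetSpec()
assert _channels_ok(16, spec) and not _channels_ok(12, spec)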
+ # This only happens when the inputs to the `Concat` are model inputs, and not outputs of other + # operators. + cat_partition = [p for p in partition_list if node in p.nodes][0] + cat_inputs = map(previous_non_qdq_node, node.args[0]) + + if not all( + input_.op == "call_function" and input_ in cat_partition.nodes + for input_ in cat_inputs + ): + # Some inputs of the `cat` are NOT in the same partition as `cat`. + dim = CatConverter._get_normalized_dim(node) + input_shapes = [list(n.meta["val"].shape) for n in node.args[0]] + if node.meta[NXP_NODE_FORMAT].is_channels_first(): + # Transform the shapes to channels last. + to_nhwc_perm = create_channels_first_to_channels_last_permutation( + len(node.meta["val"].shape), True + ) + input_shapes = [ + apply_permutation_to(shape, to_nhwc_perm) for shape in input_shapes + ] + + # Transform the `dim` to refer to a channels last dimension. + dim = to_nhwc_perm.index(dim) + + for input_shape in input_shapes: + if not any(d != 1 for d in input_shape[:dim]): + # Do not delegate if there are no "non-1" dimensions in the shape before the `dim` dimension. + return False + + return True + def convert(self, node: Node): """Convert the 'aten.cat' operator to TFLite 'Concatenation'.""" self.assert_convertible(node) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/clone_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/clone_converter.py index 1d370ab8c48..17b2cee9874 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/clone_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/clone_converter.py @@ -20,6 +20,11 @@ def _has_supported_memory_format(node: Node) -> bool: class CloneConverter(NodeConverter): + """ + This converter is responsible for converting both edge operators: + - aten.clone.default + - dim_order_ops._clone_dim_order.default + """ @staticmethod def _is_supported_in_IR( diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py index f58df1a88d9..29a8f7d51bb 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py @@ -17,7 +17,6 @@ from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, NodeConverter, - Target, ) from executorch.backends.nxp.backend.ir.converter.quantization_utils import ( quantize_int8, @@ -27,6 +26,9 @@ pad_options, pad_v2_options, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec + +from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT from torch.fx import Node from torch.nn import Parameter @@ -35,23 +37,24 @@ class ConstantPadNDConverter(NodeConverter): @staticmethod def _is_supported_on_target( node: Node, - target: Target, + neutron_target_spec: NeutronTargetSpec, parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - match target: - case Target.RT700: - # TODO: Consider different tensor formats (dim-order) - paddings = node.args[1] - if len(paddings) > 4 and paddings[4:6] != [0, 0]: - # Attempt to Pad channels dimension, which is not supported on Neutron. 
- return False - - return True - - case _: + paddings = node.args[1] + if node.meta[NXP_NODE_FORMAT].is_channels_first(): + # Dim `1` will end up being the channels. It is padded by paddings[4:6]. + if len(paddings) > 4 and paddings[4:6] != [0, 0]: + # Attempt to Pad channels dimension -> currently not supported + return False + else: + # Dim `-1` will end up being the channels. It is padded by paddings[:2]. + if len(paddings) > 0 and paddings[:2] != [0, 0]: + # Attempt to Pad channels dimension -> currently not supported return False + return True + @staticmethod def _is_supported_in_IR( node: Node, @@ -71,10 +74,6 @@ def _is_supported_in_IR( if not NodeConverter._has_shared_q_params_if_quantized(node): return False - if len(paddings) > 4 and paddings[4:6] != [0, 0]: - # Attempt to Pad channels dimension -> currently not supported - return False - return True # noinspection PyMethodMayBeStatic diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py index 0f3a4b9bb5a..645274c7870 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py @@ -3,8 +3,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import logging - import numpy as np import torch @@ -25,7 +23,6 @@ from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, NodeConverter, - Target, ) from executorch.backends.nxp.backend.ir.converter.node_converters.shared import ( conv_utils, @@ -33,18 +30,22 @@ from executorch.backends.nxp.backend.ir.converter.node_converters.shared.conv_utils import ( ConvConversionResult, ConvParameters, + get_node_tensor_params, ) from executorch.backends.nxp.backend.ir.converter.quantization_utils import ( set_quantization_parameters_to_tensor, ) from executorch.backends.nxp.backend.ir.converter.tensor_utils import tensor_has_data from executorch.backends.nxp.backend.ir.lib.tflite.TensorType import TensorType +from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import ( conv_2d_options, depthwise_conv_2d_options, reshape_options, + transpose_conv_options, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.fx import Node from torch.nn import Parameter @@ -53,45 +54,72 @@ class ConvolutionConverter(NodeConverter): @staticmethod def _is_supported_on_target( node: Node, - target: Target, + neutron_target_spec: NeutronTargetSpec, parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - match target: - case Target.RT700: - activations = node.args[0] - weights = node.args[1] - groups = node.args[8] - - if activations.meta["val"].shape[0] != 1: - # Only batch size 1 is supported on neutron. - return False - - if groups == 1: # Regular convolution. - pass - elif conv_utils.group_conv_convertible_as_depthwise( - node, groups - ): # Depthwise convolution. - # Only supported if the weights are static, because TFLite `DepthwiseConv2D` uses permuted - # weights. 
In case the weights are dynamic, a Transpose operator would have to be added, which - # is not supported on Neutron. - if not node_is_effectively_static_tensor( - weights, parameters_mapping - ): - return False - elif conv_utils.group_conv_convertible_into_multiple_convolutions( - node, groups - ): # Separable conv. This should never be reached, as the node should have been decomposed into - # multiple parallel convolutions by the `SplitGroupConvolution` pre-processing pass. - logging.warning("Group convolution was not decomposed.") - return False - else: # Unexpected case (should never happen). - return False - - return True - - case _: + num_macs = neutron_target_spec.get_num_macs() + node_t_params = get_node_tensor_params(node) + weights = node.args[1] + conv_params = ConvParameters( + *ConvolutionConverter._get_convolution_arguments(node) + ) + + if node_t_params["batch_size"] != 1: + # Only batch size 1 is supported on neutron. + return False + + if conv_params.transposed: + # TransposeConv1d is not supported on Neutron + if len(conv_params.dilation) == 1: + return False + if not node_is_effectively_static_tensor(weights, parameters_mapping): + # Only supported if the weights are static, because TFLite `TransposeConv` uses permuted + # weights. In case the weights are dynamic, a Transpose operator would have to be added, which + # is not supported on Neutron. return False + # neutron-library/src/utils/NeutronLibraryInterrogation.cpp#876 TransposeConv2DKernelKind + if ( + conv_params.dilation != [1, 1] + or conv_params.padding[0] != 0 + or conv_params.padding[1] >= node_t_params["kernel_width"] + or ( + conv_params.padding[1] != 0 and node_t_params["inp_height"] != 1 + ) # Slice added by explicit padding + or conv_params.stride[0] != 1 + or ( + ( + conv_params.stride[1] != node_t_params["kernel_width"] / 2 + or node_t_params["out_height"] != 1 + ) + and conv_params.stride[1] != node_t_params["kernel_width"] + ) + or conv_params.stride[1] % 2 != 0 + or node_t_params["inp_channels"] % num_macs != 0 + or node_t_params["out_channels"] % num_macs != 0 + or node_t_params["kernel_width"] % 2 != 0 + or node_t_params["kernel_height"] != 1 + ): + return False + elif conv_params.groups == 1: # Regular convolution. + pass + elif conv_utils.group_conv_convertible_as_depthwise( + node, conv_params.groups + ): # Depthwise convolution. + # Only supported if the weights are static, because TFLite `DepthwiseConv2D` uses permuted + # weights. In case the weights are dynamic, a Transpose operator would have to be added, which + # is not supported on Neutron. + if not node_is_effectively_static_tensor(weights, parameters_mapping): + return False + elif conv_utils.group_conv_convertible_into_multiple_convolutions( + node, conv_params.groups + ): # Separable conv. + # Requires addition of `Split` and `Concatenation` operators, which are not supported on Neutron. + return False + else: # Unexpected case (should never happen). 
+ return False + + return True @staticmethod def _is_supported_in_IR( @@ -103,11 +131,15 @@ def _is_supported_in_IR( dimensions = input_tensor_rank - 2 is_transposed = node.args[6] output_padding = node.args[7] + groups = node.args[8] - if is_transposed: + if is_transposed and conv_utils.group_conv_convertible_as_depthwise( + node, groups + ): + # TFLite does not support transposed depthwise convolution return False - if output_padding != [0] * dimensions: + if not is_transposed and output_padding != [0] * dimensions: return False if input_tensor_safe(node, 2) is None: @@ -122,6 +154,20 @@ def _is_supported_in_IR( Transposed = bool Groups = int + def _compute_slicing_params( + self, output_shape, explicit_padding + ) -> tuple[list[int], list[int]]: + begins = [] + sizes = [] + + for axis in range(len(output_shape)): + (start, end) = explicit_padding[axis] + + begins.append(start) + sizes.append(output_shape[axis] - start - end) + + return begins, sizes + @staticmethod def _get_convolution_arguments( conv_node: Node, @@ -137,7 +183,7 @@ def _get_convolution_arguments( list(padding), list(dilation), transposed, - out_padding, + list(out_padding), groups, ) @@ -238,7 +284,7 @@ def _convert_1d_conv( def _convert_unpadded_2D( self, t_op: tflite_model.Operator, conv_params: ConvParameters ) -> conv_utils.ConvConversionResult: - """Convert the `aten.convolution` into TFLite. The `padding` and `builtin_options` must be converter by the + """Convert the `aten.convolution` into TFLite. The `padding` and `builtin_options` must be converted by the caller. """ common.assign_2d_strides(t_op.builtin_options, conv_params.stride) @@ -266,15 +312,16 @@ def _convert_unpadded_2D( [output_channels], "zero_bias", bias_type, False ) - # Compute scale and zero point for bias tensor - input_scale = np.array(x.quantization.scale.vector) - weight_scale = np.array(w.quantization.scale.vector) - bias_scale = input_scale * weight_scale - bias_zero_point = np.zeros(weight_scale.shape, dtype=np.int64) + if w.type in [TensorType.INT8, TensorType.UINT8]: + # Compute scale and zero point for bias tensor + input_scale = np.array(x.quantization.scale.vector) + weight_scale = np.array(w.quantization.scale.vector) + bias_scale = input_scale * weight_scale + bias_zero_point = np.zeros(weight_scale.shape, dtype=np.int64) - set_quantization_parameters_to_tensor( - b, bias_scale, bias_zero_point, quantized_dimension=0 - ) + set_quantization_parameters_to_tensor( + b, bias_scale, bias_zero_point, quantized_dimension=0 + ) # Assign the operator its TFLite inputs and outputs t_op.tmp_inputs = [x, w, b] @@ -285,83 +332,195 @@ def _convert_unpadded_2D( return conversion_result - def _convert_2d_conv( + def _convert_transpose_conv( self, t_op: tflite_model.Operator, conv_params: ConvParameters - ) -> list[tflite_model.Operator]: - if conv_utils.group_conv_convertible_as_depthwise( - t_op, conv_params.groups - ): # Convert to `DepthwiseConv2D`. - t_op.builtin_options = depthwise_conv_2d_options.DepthwiseConv2D() - - conversion_result = self._convert_unpadded_2D(t_op, conv_params) - t_op.builtin_options.padding, explicit_padding = ( - aten_translator.convert_padding(conv_params.padding) - ) - if explicit_padding is not None: - # Need to prepend a 'Pad' operator, which adds 0s (or `zero_point` for the quantized case). 
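The new `_compute_slicing_params` helper above turns the explicit padding added around a TransposeConv output back into a `Slice` begin/size pair. A minimal standalone sketch of that arithmetic, reimplemented here only for illustration (not imported from the backend); `explicit_padding` is assumed to be the per-axis list of (before, after) pads produced by the padding translation used above:

# Standalone sketch of the begin/size computation used for the post-TransposeConv Slice.
def compute_slicing_params(output_shape, explicit_padding):
    begins, sizes = [], []
    for axis, (start, end) in enumerate(explicit_padding):
        begins.append(start)                             # skip the padded rows at the front
        sizes.append(output_shape[axis] - start - end)   # and drop the padded rows at the back
    return begins, sizes

# Padded NHWC output of shape [1, 10, 10, 8] with 1 row/column of padding on H and W:
print(compute_slicing_params([1, 10, 10, 8], [(0, 0), (1, 1), (1, 1), (0, 0)]))
# -> ([0, 1, 1, 0], [1, 8, 8, 8])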
- input_quantization = t_op.tmp_inputs[0].quantization - pad_value = ( - None - if input_quantization is None - else np.array(input_quantization.zero_point[0]).astype( - tf_lite_type_to_numpy(t_op.tmp_inputs[0].type) - ) - ) - conversion_result.ops_list.add_pre( - self.builder.create_pad_operator_before( - t_op, 0, explicit_padding, constant_value=pad_value - ) + ) -> conv_utils.ConvConversionResult: + """Convert the `aten.convolution` into TFLite TransposeConv. The `builtin_options` must be + converted by the caller. + """ + common.assign_2d_strides(t_op.builtin_options, conv_params.stride) + + x: tflite_model.Tensor = t_op.tmp_inputs[0] + w: tflite_model.Tensor = t_op.tmp_inputs[1] + y: tflite_model.Tensor = t_op.tmp_outputs[0] + + if (b := try_get_input(t_op, 2)) is None: + # Operator has no bias. Convolution aten op can omit it, TFLite can't. + # Weight tensor format in TFLite: [C, kH, kW, O] + # (C = input channels, O = output channels, kW = kernel width, kH = kernel height) + output_channels = w.shape.vector[-1] + + if w.type == TensorType.FLOAT32: + bias_type = np.dtype(np.float32) + elif w.type in [TensorType.INT8, TensorType.UINT8]: + bias_type = np.dtype(np.int32) + else: + # Should never happen. + raise NotImplementedError( + f"Convolution node with unsupported weight type: {w.type}" ) - # DepthwiseConv2D expects weights in format [kernel_channels, kernel_height, kernel_width, output_channels] - perm = [3, 1, 2, 0] - weight_tensor = conversion_result.conv_weight_tensor - if tensor_has_data(weight_tensor): - # Transpose cloned tensor statically - t_op.tmp_inputs[1] = self.builder.create_transposed_tensor( - weight_tensor, perm + b = self.builder.create_zeros_tensor( + [output_channels], "zero_bias", bias_type, True + ) + + if w.type in [TensorType.INT8, TensorType.UINT8]: + # Compute scale and zero point for bias tensor + input_scale = np.array(x.quantization.scale.vector) + weight_scale = np.array(w.quantization.scale.vector) + bias_scale = input_scale * weight_scale + bias_zero_point = np.zeros(weight_scale.shape, dtype=np.int64) + + set_quantization_parameters_to_tensor( + b, bias_scale, bias_zero_point, quantized_dimension=0 ) - else: - raise NotImplementedError("Dynamic Depthwise Conv weights.") - elif conv_utils.group_conv_convertible_into_multiple_convolutions( - t_op, conv_params.groups - ): - # This case should have been rejected in the `is_supported_on_target()` method. - raise RuntimeError("Group convolution was not decomposed.") + # TransposeConv weight tensor format in TFLite: [O, kH, kW, C] + # (C = input channels, O = output channels, kW = kernel width, kH = kernel height) + if tensor_has_data(w): + # Transpose cloned tensor statically + w = self.builder.create_transposed_tensor(w, [3, 1, 2, 0]) + if w.quantization is not None: + # Model is quantized + w.quantization.quantized_dimension = 0 else: - # Convert to regular `Conv2D`. 
- t_op.builtin_options = conv_2d_options.Conv2D() - conversion_result = self._convert_unpadded_2D(t_op, conv_params) - t_op.builtin_options.padding, explicit_padding = ( - aten_translator.convert_padding(conv_params.padding) + raise NotImplementedError("Dynamic Transpose Conv weights.") + w.tensor_format = TensorFormat.TRANSPOSE_CONV_2D_WEIGHT_FORMAT + + output_shape_tensor_data = np.asarray(y.shape.vector, dtype=np.int32) + o = self.builder.create_tensor_for_data( + output_shape_tensor_data, "output_shape" + ) + + # Assign the operator its TFLite inputs and outputs + t_op.tmp_inputs = [o, w, x, b] + t_op.tmp_outputs = [y] + conversion_result = ConvConversionResult(x, w, b, y, o) + t_op.builtin_options.padding, explicit_padding = ( + aten_translator.convert_padding(conv_params.padding) + ) + if explicit_padding is not None: + # Add padding to output shape to make sure we have computed all the data we need + for idx, padding in enumerate(explicit_padding): + output_shape_tensor_data[idx] += padding[0] + padding[1] + y.shape = tflite_model.Shape(output_shape_tensor_data.tolist()) + + # We need to "cut" produced tensor by size of explicit padding + begins, sizes = self._compute_slicing_params( + output_shape_tensor_data.tolist(), explicit_padding ) - if explicit_padding is not None: - # Need to prepend a 'Pad' operator, which adds 0s (or `zero_point` for the quantized case). - input_quantization = t_op.tmp_inputs[0].quantization - pad_value = ( - None - if input_quantization is None - else np.array(input_quantization.zero_point[0]).astype( - tf_lite_type_to_numpy(t_op.tmp_inputs[0].type) - ) + slice_op = self.builder.create_slice_after(t_op, 0, begins, sizes) + conversion_result.ops_list.add_post(slice_op) + + conversion_result.ops_list.middle_op = t_op + + return conversion_result + + def _convert_2d_conv( + self, t_op: tflite_model.Operator, conv_params: ConvParameters + ) -> list[tflite_model.Operator]: + if conv_params.transposed: + t_op.builtin_options = transpose_conv_options.TransposeConv() + if conv_utils.group_conv_convertible_into_multiple_convolutions( + t_op, conv_params.groups + ): + # Convert to separated `TransposeConv`. + raise NotImplementedError("Separated TransposeConv not implemented.") + else: + # Convert to `TransposeConv`. + conversion_result = self._convert_transpose_conv(t_op, conv_params) + + else: + if conv_utils.group_conv_convertible_as_depthwise( + t_op, conv_params.groups + ): # Convert to `DepthwiseConv2D`. + t_op.builtin_options = depthwise_conv_2d_options.DepthwiseConv2D() + + conversion_result = self._convert_unpadded_2D(t_op, conv_params) + t_op.builtin_options.padding, explicit_padding = ( + aten_translator.convert_padding(conv_params.padding) ) - conversion_result.ops_list.add_pre( - self.builder.create_pad_operator_before( - t_op, 0, explicit_padding, constant_value=pad_value + if explicit_padding is not None: + # Need to prepend a 'Pad' operator, which adds 0s (or `zero_point` for the quantized case). 
+ input_quantization = t_op.tmp_inputs[0].quantization + pad_value = ( + None + if input_quantization is None + else np.array(input_quantization.zero_point[0]).astype( + tf_lite_type_to_numpy(t_op.tmp_inputs[0].type) + ) + ) + conversion_result.ops_list.add_pre( + self.builder.create_pad_operator_before( + t_op, 0, explicit_padding, constant_value=pad_value + ) ) + + # DepthwiseConv2D expects weights in format [kernel_channels, kernel_height, kernel_width, output_channels] + perm = [3, 1, 2, 0] + weight_tensor = conversion_result.conv_weight_tensor + if tensor_has_data(weight_tensor): + # Transpose cloned tensor statically + t_op.tmp_inputs[1] = self.builder.create_transposed_tensor( + weight_tensor, perm + ) + + if t_op.tmp_inputs[1].quantization is not None: + # Model is quantized + t_op.tmp_inputs[1].quantization.quantized_dimension = 3 + else: + raise NotImplementedError("Dynamic Depthwise Conv weights.") + + elif conv_utils.group_conv_convertible_into_multiple_convolutions( + t_op, conv_params.groups + ): # Convert to separated `Conv2D`. + t_op.builtin_options = conv_2d_options.Conv2D() + + return conv_utils.create_separated_convolutions_based_on_group( + t_op, + conv_params, + self.builder, + self._convert_unpadded_2D, + conv_utils.conv_op_factory, + ) + + else: + # Convert to regular `Conv2D`. + t_op.builtin_options = conv_2d_options.Conv2D() + conversion_result = self._convert_unpadded_2D(t_op, conv_params) + t_op.builtin_options.padding, explicit_padding = ( + aten_translator.convert_padding(conv_params.padding) ) + if explicit_padding is not None: + # Need to prepend a 'Pad' operator, which adds 0s (or `zero_point` for the quantized case). + input_quantization = t_op.tmp_inputs[0].quantization + pad_value = ( + None + if input_quantization is None + else np.array(input_quantization.zero_point[0]).astype( + tf_lite_type_to_numpy(t_op.tmp_inputs[0].type) + ) + ) + conversion_result.ops_list.add_pre( + self.builder.create_pad_operator_before( + t_op, 0, explicit_padding, constant_value=pad_value + ) + ) return conversion_result.ops_list.flatten() def convert(self, node: Node): self.assert_convertible(node) - stride, padding, dilation, _, _, groups = self._get_convolution_arguments(node) + stride, padding, dilation, transposed, out_padding, groups = ( + self._get_convolution_arguments(node) + ) t_op = self._create_tflite_op_with_io_tensors(node) - conv_params = ConvParameters(stride, padding, dilation, groups) + conv_params = ConvParameters( + stride, padding, dilation, transposed, out_padding, groups + ) rank = t_op.tmp_inputs[1].shape.len() if rank == 3: # Conv1D diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/mean_dim_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/mean_dim_converter.py index f03c403876f..ac09e564eb8 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/mean_dim_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/mean_dim_converter.py @@ -1,5 +1,4 @@ -# Copyright (c) 2025 NXP -# All rights reserved. +# Copyright 2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
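The depthwise branch of ConvolutionConverter above statically transposes the weights with perm [3, 1, 2, 0] and then moves the per-channel `quantized_dimension` to 3. The bookkeeping rule being applied, assuming the usual numpy transpose convention: after transposing with `perm`, an axis that used to sit at index `j` ends up at index `perm.index(j)`. A small check with plain numpy (shapes chosen arbitrarily; axis 0 is the per-channel quantized axis):

import numpy as np

perm = [3, 1, 2, 0]
weights = np.zeros((8, 3, 3, 16))          # illustrative shape, quantized along axis 0
permuted = np.transpose(weights, perm)     # -> shape (16, 3, 3, 8)

old_quantized_axis = 0
new_quantized_axis = perm.index(old_quantized_axis)

print(permuted.shape, new_quantized_axis)  # (16, 3, 3, 8) 3  -- matching quantized_dimension = 3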
@@ -12,7 +11,6 @@ from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, NodeConverter, - Target, ) from executorch.backends.nxp.backend.ir.converter.node_converters.shared.reduce_utils import ( convert_axes_from_attribute, @@ -20,6 +18,8 @@ from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import ( mean_options, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec +from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT from torch.fx import Node from torch.nn import Parameter @@ -28,34 +28,38 @@ class MeanDimConverter(NodeConverter): @staticmethod def _is_supported_on_target( node: Node, - target: Target, + neutron_target_spec: NeutronTargetSpec, parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - match target: - case Target.RT700: - # TODO: Consider different tensor formats (dim-order) - dim = node.args[1] - keepdim = node.args[2] if len(node.args) >= 3 else False - rank = len(node.args[0].meta["val"].shape) - dim = [MeanDimConverter._to_neg_dim(d, rank) for d in dim] + keepdim = node.args[2] if len(node.args) >= 3 else False + rank = len(node.args[0].meta["val"].shape) + dim = [MeanDimConverter._to_pos_dim(d, rank) for d in node.args[1]] - # Only last 2 dimensions (H, W) and keepdim=True with rank=4 are supported on Neutron. - if rank != 4 or dim not in [[-1, -2], [-2, -1]] or not keepdim: - return False + if rank != 4 or not keepdim: + # neutron-converter/src/OperatorC/GlobalAvgPoolPlugin.cpp#74-77 + return False - return True + # The `mean.dim` gets converted to AveragePool by the NeutronConverter, so the channels must be a + # multiple of `num_macs`. + # neutron-converter/src/OperatorC/GlobalAvgPoolPlugin.cpp#59-85 + num_macs = neutron_target_spec.get_num_macs() + channels_dim = 1 if node.meta[NXP_NODE_FORMAT].is_channels_first() else -1 + if (node.meta["val"].shape[channels_dim] % num_macs) != 0: + return False - case _: + # Neutron only supports reduction over the spatial dimensions H, W. + if node.meta[NXP_NODE_FORMAT].is_channels_first(): + # The input is NCHW. H and W are at indices 2 and 3. + if dim not in [[2, 3], [3, 2]]: + return False + else: + # The input is formatless. It can be considered as NHWC, as this is the way Neutron will look at + # the dimensions. So H and W are the middle dimensions. + if dim not in [[1, 2], [2, 1]]: return False - @staticmethod - def _to_pos_dim(d, rank): - return d + rank if d < 0 else d - - @staticmethod - def _to_neg_dim(d, rank): - return d - rank if d > 0 else d + return True @staticmethod def _is_supported_in_IR( @@ -75,6 +79,10 @@ def _is_supported_in_IR( return True + @staticmethod + def _to_pos_dim(d: int, rank: int): + return d + rank if d < 0 else d + @staticmethod def _normalize_and_to_channel_last_dim(dim: list[int], rank: int) -> list[int]: # convert negative index to positive diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/mul_tensor_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/mul_tensor_converter.py new file mode 100644 index 00000000000..d67b0aa4bcb --- /dev/null +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/mul_tensor_converter.py @@ -0,0 +1,61 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
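To make the reworked MeanDimConverter target check above concrete, here is a standalone restatement of the delegation predicate with a few example inputs. The value of `num_macs` is taken as 8 purely for illustration; on a real target it comes from `NeutronTargetSpec.get_num_macs()`:

# Illustrative restatement of the mean.dim delegation predicate above.
NUM_MACS = 8  # assumed value for this sketch

def mean_dim_delegatable(shape, dim, keepdim, channels_first):
    rank = len(shape)
    dim = [d + rank if d < 0 else d for d in dim]          # normalize negative dims
    if rank != 4 or not keepdim:
        return False
    channels = shape[1] if channels_first else shape[-1]   # channel axis depends on the node format
    if channels % NUM_MACS != 0:
        return False
    spatial = [[2, 3], [3, 2]] if channels_first else [[1, 2], [2, 1]]
    return dim in spatial

print(mean_dim_delegatable([1, 32, 14, 14], [-1, -2], True, channels_first=True))  # True
print(mean_dim_delegatable([1, 32, 14, 14], [1], True, channels_first=True))       # False: not H, W
print(mean_dim_delegatable([1, 30, 14, 14], [2, 3], True, channels_first=True))    # False: 30 % 8 != 0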
+ +from executorch.backends.nxp.backend.ir.converter.conversion.common import ( + node_uses_shape_broadcasting, +) +from executorch.backends.nxp.backend.ir.converter.node_converter import ( + CustomDelegationOptions, + NodeConverter, +) +from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import ( + mul_options, +) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec +from torch.fx import Node +from torch.nn import Parameter + + +class MulTensorConverter(NodeConverter): + @staticmethod + def _is_supported_on_target( + node: Node, + neutron_target_spec: NeutronTargetSpec, + parameters_mapping: dict[str, Parameter], + custom_delegation_options: CustomDelegationOptions, + ) -> bool: + if node_uses_shape_broadcasting(node): + # Shape broadcasting may require the addition of `Transpose` ops during conversion. + return False + + node_shape = node.meta["val"].shape + + # Check that at least one dimension is divisible by number of MACS + # or all dimensions are equal to one + # Otherwise Neutron cannot convert it + dim_divisible = any(s % 8 == 0 for s in node_shape) or all( + s == 1 for s in node_shape + ) + return dim_divisible + + @staticmethod + def _is_supported_in_IR( + node: Node, + parameters_mapping: dict[str, Parameter], + custom_delegation_options: CustomDelegationOptions, + ) -> bool: + if len(node.args) != 2: + return False + + return True + + # mul.Tensor Node format: (Tensor self, Tensor other, *) + def convert(self, node: Node): + """Convert 'mul_tensor' operator to NeutronIR 'Mul'.""" + self.assert_convertible(node) + t_op = self._create_tflite_op_with_io_tensors(node) + t_op.builtin_options = mul_options.Mul() + + self.builder.append_operators([t_op]) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/permute_copy_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/permute_copy_converter.py index f0150b4bc1f..35bef6c8035 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/permute_copy_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/permute_copy_converter.py @@ -4,28 +4,438 @@ # LICENSE file in the root directory of this source tree. 
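A quick illustration of the dimension check in the new MulTensorConverter above; the constant 8 mirrors the literal used in that check:

# Restatement of the Mul delegation shape check above, for a few example shapes.
def mul_shape_ok(shape):
    return any(s % 8 == 0 for s in shape) or all(s == 1 for s in shape)

print(mul_shape_ok([1, 8, 128]))   # True  -- 8 and 128 are both multiples of 8
print(mul_shape_ok([1, 1, 1]))     # True  -- all dimensions are 1
print(mul_shape_ok([3, 5, 7]))     # False -- no dimension is a multiple of 8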
import numpy as np +import torch +from executorch.backends.nxp.backend.edge_helper import ( + node_is_effectively_static_tensor, +) +from executorch.backends.nxp.backend.ir.conversion_context import ConversionContext from executorch.backends.nxp.backend.ir.converter import quantization_utils +from executorch.backends.nxp.backend.ir.converter.conversion import translator from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, + NeutronTargetSpec, NodeConverter, ) +from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat +from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import ( transpose_options, ) +from executorch.backends.nxp.backend.neutron_operator_support import ( + is_tensor_invariant_permutation, + transposition_is_supported_on_neutron, +) +from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT from torch.fx import Node from torch.nn import Parameter +Permutation = list[int] +PermutationSupportDict = dict[str, dict[str, bool | Permutation]] + + +def _get_shape(node: torch.fx.Node) -> list[int]: + return list(node.meta["val"].shape) + + +def get_supported_transpositions( + node: Node, neutron_target_spec: NeutronTargetSpec +) -> PermutationSupportDict: + """Since ExecuTorch and NeutronIR use different tensor formats, we must consider the different possible cases + which may occur. The main permutation is always done on channels_first/formatless data, and the output is + channels_first/formatless as well. If this is not the case, a `Transpose` is inserted before and/or + after the main `Transpose`, to make the input/output channels_first. These additional `Transpose` + ops must be supported by Neutron as well. Alternatively, consecutive `Transpose` ops can be fused + together. It is possible for a pair of unsupported permutation to result in a supported one. + Therefore, the merged permutations must also be considered. + + This function identifies which of these permutations are supported on neutron, and returns a dictionary with the + support summary and the corresponding permutations. + + :param node: The `permute_copy` node to base the support analysis from/ + :param neutron_target_spec: NeutronTagetSpec instance. + :return: A dictionary containing the support status and permutation, for all the possible permutations which may be + used during the conversion of the `node`. + """ + + input_shape = node.args[0].meta["val"].shape + output_shape = node.meta["val"].shape + perm = list(node.args[1]) + + to_nchw_perm = translator.create_channels_last_to_channels_first_permutation( + len(input_shape), True + ) + to_nhwc_perm = translator.create_channels_first_to_channels_last_permutation( + len(input_shape), True + ) + channels_last_input_shape = translator.apply_permutation_to( + input_shape, to_nhwc_perm + ) + + main_perm_supported = transposition_is_supported_on_neutron( + input_shape, perm, neutron_target_spec + ) + + # "To NCHW" permutation, in case the input is channels last. + separate_pre_transpose_supported = transposition_is_supported_on_neutron( + channels_last_input_shape, to_nchw_perm, neutron_target_spec + ) + # The main permutation and the previous one merged. 
+ merged_pre_transpose_supported = transposition_is_supported_on_neutron( + channels_last_input_shape, + merged_pre_transpose_permutation := translator.combine_permutations( + to_nchw_perm, perm + ), + neutron_target_spec, + ) + + # "To NHWC" permutation after the main `Transpose`. + separate_post_transpose_supported = transposition_is_supported_on_neutron( + output_shape, to_nhwc_perm, neutron_target_spec + ) + + # The main permutation and the previous one merged. + merged_post_transpose_supported = transposition_is_supported_on_neutron( + input_shape, + merged_post_transpose_permutation := translator.combine_permutations( + perm, to_nhwc_perm + ), + neutron_target_spec, + ) + + # "To NCHW", main permutation, and "to NHWC" all merged. + everything_merged_supported = transposition_is_supported_on_neutron( + input_shape, + everything_merged_permutation := translator.combine_permutations( + translator.combine_permutations(to_nchw_perm, perm), to_nhwc_perm + ), + neutron_target_spec, + ) + + return { + "main": {"supported": main_perm_supported, "perm": perm}, + "separate_pre": { + "supported": separate_pre_transpose_supported, + "perm": to_nchw_perm, + }, + "merged_pre": { + "supported": merged_pre_transpose_supported, + "perm": merged_pre_transpose_permutation, + }, + "separate_post": { + "supported": separate_post_transpose_supported, + "perm": to_nhwc_perm, + }, + "merged_post": { + "supported": merged_post_transpose_supported, + "perm": merged_post_transpose_permutation, + }, + "everything_merged": { + "supported": everything_merged_supported, + "perm": everything_merged_permutation, + }, + } + + +class PermuteCopyFormatHandler: + def __init__(self, context: ConversionContext): + self.context = context + + @property + def neutron_target_spec(self): + return self.context.tflite_builder.neutron_target_spec + + @property + def builder(self): + return self.context.tflite_builder + + def _handle_channels_first_input_and_formatless_output( + self, perm_dict, node, t_op, ops + ) -> Permutation: + # The input must be permuted. + # Either combine the permutations, or prepend a `Transpose` operator. + + if node_is_effectively_static_tensor( + node.args[0], self.context.parameters_mapping + ): + # The input is static, so the operator will be removed by an optimization. + perm = perm_dict["main"]["perm"] + + elif perm_dict["merged_pre"]["supported"]: + # Use the combined permutation. + perm = perm_dict["merged_pre"]["perm"] + + elif perm_dict["separate_pre"]["supported"] and perm_dict["main"]["supported"]: + # Prepend a `Transpose` operator to make the input channels first. + ops.add_pre( + self.builder.create_transpose_operator_before( + t_op, 0, perm_dict["separate_pre"]["perm"] + ) + ) + perm = perm_dict["main"]["perm"] + + else: + # The `permute_copy` cannot be represented in Neutron. This should never happen. + raise RuntimeError( + "A `permute_copy` node was incorrectly selected for delegation. Please report this." + ) + + t_op.tmp_inputs[0].tensor_format = TensorFormat.CHANNELS_FIRST + + return perm + + def _handle_formatless_input_and_channels_first_output( + self, perm_dict, node, t_op, ops + ) -> Permutation: + # The output must be permuted. + # Either combine the permutations, or append a `Transpose` operator. + + if node_is_effectively_static_tensor( + node.args[0], self.context.parameters_mapping + ): + # The input is static, so the operator will be removed by an optimization. 
+ perm = perm_dict["main"]["perm"] + + elif perm_dict["merged_post"]["supported"]: + # Use the combined permutation. + perm = perm_dict["merged_post"]["perm"] + + elif perm_dict["main"]["supported"] and perm_dict["separate_post"]["supported"]: + # Append a `Transpose` operator to make the output channels first. + perm = perm_dict["main"]["perm"] + ops.add_post( + self.builder.create_transpose_operator_after( + t_op, 0, perm_dict["separate_post"]["perm"] + ) + ) + + else: + # The `permute_copy` cannot be represented in Neutron. This should never happen. + raise RuntimeError( + "A `permute_copy` node was incorrectly selected for delegation. Please report this." + ) + + t_op.tmp_outputs[0].tensor_format = TensorFormat.CHANNELS_FIRST + + return perm + + def _handle_channels_first_input_and_output( + self, perm_dict, node, t_op, ops + ) -> Permutation: + # Both input and output must be permuted, or some merged permutations must be supported. + if perm_dict["everything_merged"]["supported"]: + # Combine all 3 permutations into 1. + perm = perm_dict["everything_merged"]["perm"] + + elif ( + perm_dict["merged_pre"]["supported"] + and perm_dict["separate_post"]["supported"] + ): + # Combine the input and main permutations, and append a `Transpose` to handle the output permutation. + perm = perm_dict["merged_pre"]["perm"] + ops.add_post( + self.builder.create_transpose_operator_after( + t_op, 0, perm_dict["separate_post"]["perm"] + ) + ) + + elif ( + perm_dict["separate_pre"]["supported"] + and perm_dict["merged_post"]["supported"] + ): + # Prepend a `Transpose` to handle the input permutation, and combine the main and output permutations. + ops.add_pre( + self.builder.create_transpose_operator_before( + t_op, 0, perm_dict["separate_pre"]["perm"] + ) + ) + perm = perm_dict["everything_merged"]["supported"] + + elif ( + perm_dict["separate_pre"]["supported"] + and perm_dict["main"]["supported"] + and perm_dict["separate_post"]["supported"] + ): + # Handle each permutation separately. + ops.add_pre( + self.builder.create_transpose_operator_before( + t_op, 0, perm_dict["separate_pre"]["perm"] + ) + ) + perm = perm_dict["main"]["perm"] + ops.add_post( + self.builder.create_transpose_operator_after( + t_op, 0, perm_dict["separate_post"]["perm"] + ) + ) + + elif node_is_effectively_static_tensor( + node.args[0], self.context.parameters_mapping + ): + perm = perm_dict["main"]["perm"] + + else: + # The `permute_copy` cannot be represented in Neutron. This should never happen. + raise RuntimeError( + "A `permute_copy` node was incorrectly selected for delegation. Please report this." + ) + + t_op.tmp_inputs[0].tensor_format = TensorFormat.CHANNELS_FIRST + t_op.tmp_outputs[0].tensor_format = TensorFormat.CHANNELS_FIRST + + return perm + + def _handle_formatless_input_and_output( + self, perm_dict, node, t_op, ops + ) -> Permutation: + # Neither the input nor the output have to be permuted. + if perm_dict["main"]["supported"]: + perm = perm_dict["main"]["perm"] + + elif node_is_effectively_static_tensor( + node.args[0], self.context.parameters_mapping + ): + perm = perm_dict["main"]["perm"] + + else: + # The `permute_copy` cannot be represented in Neutron. This should never happen. + raise RuntimeError( + "A `permute_copy` node was incorrectly selected for delegation. Please report this." 
+ ) + + return perm + + def handle_tensor_formats(self, t_op: tflite_model.Operator, node: Node) -> OpsList: + """Due to the different tensor formats used by ExecuTorch and NeutronIR, it may be necessary to modify the + permutation, or insert extra permutations to equalize the tensor formats. + This method identifies the four possible cases of input/output formats, and finds the conversion solution + which minimizes the number of necessary `Transpose` operators. + """ + perm_dict = get_supported_transpositions(node, self.neutron_target_spec) + + ops = OpsList(middle_op=t_op) + input_format, output_format = ( + node.args[0].meta[NXP_NODE_FORMAT], + node.meta[NXP_NODE_FORMAT], + ) + if input_format.is_channels_first() and (not output_format.is_channels_first()): + perm = self._handle_channels_first_input_and_formatless_output( + perm_dict, node, t_op, ops + ) + + elif ( + not input_format.is_channels_first() + ) and output_format.is_channels_first(): + perm = self._handle_formatless_input_and_channels_first_output( + perm_dict, node, t_op, ops + ) + + elif input_format.is_channels_first() and output_format.is_channels_first(): + perm = self._handle_channels_first_input_and_output( + perm_dict, node, t_op, ops + ) + + else: + perm = self._handle_formatless_input_and_output(perm_dict, node, t_op, ops) + + perm_tensor = self.builder.create_tensor_for_data( + np.array(perm, "int32"), "perm" + ) + + # Use the final permutation as the operator's second input. + t_op.tmp_inputs = [t_op.tmp_inputs[0], perm_tensor] + + return ops + class PermuteCopyConverter(NodeConverter): + @staticmethod + def _is_supported_on_target( + node: Node, + neutron_target_spec: NeutronTargetSpec, + parameters_mapping: dict[str, Parameter], + custom_delegation_options: CustomDelegationOptions, + ) -> bool: + if node_is_effectively_static_tensor(node.args[0], parameters_mapping): + return ( + True # The operator computes on static data. It will be removed later. + ) + + input_shape = _get_shape(node.args[0]) + perm = list(node.args[1]) + + to_nhwc_perm = translator.create_channels_first_to_channels_last_permutation( + len(input_shape), True + ) + channels_last_input_shape = translator.apply_permutation_to( + input_shape, to_nhwc_perm + ) + + if is_tensor_invariant_permutation( + input_shape, perm + ) and is_tensor_invariant_permutation(channels_last_input_shape, perm): + # The `permute_copy` can always be represented as a Reshape. + return True + + perm_dict = get_supported_transpositions(node, neutron_target_spec) + + input_format, output_format = ( + node.args[0].meta[NXP_NODE_FORMAT], + node.meta[NXP_NODE_FORMAT], + ) + if input_format.is_channels_first() and (not output_format.is_channels_first()): + # Just the input must be permuted. + return ( + perm_dict["separate_pre"]["supported"] + and perm_dict["main"]["supported"] + ) or perm_dict["merged_pre"]["supported"] + + elif ( + not input_format.is_channels_first() + ) and output_format.is_channels_first(): + # Just the output must be permuted. + return ( + perm_dict["separate_post"]["supported"] + and perm_dict["main"]["supported"] + ) or perm_dict["merged_post"]["supported"] + + elif input_format.is_channels_first() and output_format.is_channels_first(): + # Both input and output must be permuted. + return ( + # Separate IO transpositions. + ( + perm_dict["separate_pre"]["supported"] + and perm_dict["main"]["supported"] + and perm_dict["separate_post"]["supported"] + ) + # Separate input, merged output. 
+ or ( + perm_dict["separate_pre"]["supported"] + and perm_dict["merged_post"]["supported"] + ) + # Merged input, separate output. + or ( + perm_dict["merged_pre"]["supported"] + and perm_dict["separate_post"]["supported"] + ) + # Merged input and output. + or perm_dict["everything_merged"]["supported"] + ) + else: + # Simplest case. No format changes required. + return perm_dict["main"]["supported"] + @staticmethod def _is_supported_in_IR( node: Node, parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: + if not NodeConverter._has_shared_q_params_if_quantized(node): + return False + return True def convert(self, node: Node): @@ -53,13 +463,6 @@ def convert(self, node: Node): "match. This indicates error in quantizer." ) - perm = np.array(node.args[1], "int32") - perm_tensor = self.builder.create_tensor_for_data(perm, "perm") - - # Assign the operator its TFLite inputs and outputs - t_op.tmp_inputs = [x, perm_tensor] - t_op.tmp_outputs = [y] - - ops_to_add = OpsList(middle_op=t_op) + ops = PermuteCopyFormatHandler(self.context).handle_tensor_formats(t_op, node) - self.builder.append_operators(ops_to_add.flatten()) + self.builder.append_operators(ops.flatten()) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/qdq_dequantize_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/qdq_dequantize_converter.py index c6ea7f90042..3e20e504e8a 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/qdq_dequantize_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/qdq_dequantize_converter.py @@ -3,9 +3,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
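The merged_pre / merged_post / everything_merged entries used by PermuteCopyConverter above rely on folding two back-to-back transpositions into one. Under the convention assumed here (applying a permutation to a shape means indexing the shape with it, which matches how the helpers are used elsewhere in this diff), the folded permutation is obtained by indexing one permutation with the other. The functions below are local stand-ins, not the backend's `combine_permutations`, which may use this or the mirrored convention:

import numpy as np

def apply_perm(shape, perm):
    # Assumed convention: output axis i takes the size of input axis perm[i].
    return [shape[i] for i in perm]

def combine(first, second):
    # Permutation equivalent to applying `first`, then `second`.
    return [first[i] for i in second]

shape = [2, 3, 5, 7]
to_nchw = [0, 3, 1, 2]   # channels-last -> channels-first (rank 4)
main = [0, 2, 1, 3]      # some main permutation of the permute_copy node

two_steps = apply_perm(apply_perm(shape, to_nchw), main)
one_step = apply_perm(shape, combine(to_nchw, main))
assert two_steps == one_step == [2, 3, 7, 5]

# The same identity holds for the data, not just the shape:
x = np.arange(np.prod(shape)).reshape(shape)
np.testing.assert_array_equal(
    np.transpose(np.transpose(x, to_nchw), main),
    np.transpose(x, combine(to_nchw, main)),
)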
+from abc import ABC, abstractmethod + import numpy as np from executorch.backends.nxp.backend.ir.converter.conversion.translator import ( + create_channels_last_to_channels_first_permutation, torch_type_to_numpy_type, ) from executorch.backends.nxp.backend.ir.converter.node_converter import ( @@ -15,11 +18,21 @@ from executorch.backends.nxp.backend.ir.converter.quantization_utils import ( set_quantization_parameters_to_tensor, ) +from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat +from executorch.backends.nxp.backend.ir.tflite_generator.tflite_model import Tensor from torch.fx import Node from torch.nn import Parameter -class QDQDequantizeConverter(NodeConverter): +class QDQDequantizeConverterBase(NodeConverter, ABC): + + @abstractmethod + def get_zero_point(self, node: Node) -> np.ndarray: + pass + + @abstractmethod + def get_scale(self, node: Node) -> np.ndarray: + pass @staticmethod def _is_supported_in_IR( @@ -27,7 +40,7 @@ def _is_supported_in_IR( parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - zero_point_type = torch_type_to_numpy_type(node.args[5]) + zero_point_type = torch_type_to_numpy_type(node.args[-1]) if "cluster" not in node.meta or zero_point_type not in [np.int8, np.int32]: return False @@ -39,10 +52,11 @@ def convert(self, node: Node): from_tensor = self.builder.tensor_for_name(node.name) to_tensor = self.builder.tensor_for_name(node.args[0].name) - zero_point_type = torch_type_to_numpy_type(node.args[5]) - - scale = np.array(node.args[1], dtype=np.float32) - zero_point = np.array(node.args[2], dtype=zero_point_type) + scale = self.get_scale(node) + zero_point = self.get_zero_point(node) + quantized_dimension = 0 + if isinstance(self, QDQPerChannelDequantizeConverter): + quantized_dimension = self.get_quantization_dimension(from_tensor, node) if self.context.parameters_mapping.get(node.args[0].name, None) is None: # Convert dequantize as identity op (Transpose that will be removed) because @@ -50,16 +64,53 @@ def convert(self, node: Node): # here we will change input name of the model. 
t_op = self._create_tflite_op_with_io_tensors(node) - set_quantization_parameters_to_tensor(to_tensor, scale, zero_point, 0) - set_quantization_parameters_to_tensor(from_tensor, scale, zero_point, 0) + set_quantization_parameters_to_tensor( + to_tensor, scale, zero_point, quantized_dimension + ) + set_quantization_parameters_to_tensor( + from_tensor, scale, zero_point, quantized_dimension + ) from_tensor.type = to_tensor.type self.builder.turn_operator_to_identity(t_op) self.builder.append_operators([t_op]) else: # Dequantize consumes tensor with static data -> convert as a tensor - set_quantization_parameters_to_tensor(to_tensor, scale, zero_point, 0) + set_quantization_parameters_to_tensor( + to_tensor, scale, zero_point, quantized_dimension + ) # Change type so we pass check tensor similarity check when redirecting from_tensor.type = to_tensor.type self.builder.redirect_tensor(from_tensor, to_tensor) + + +class QDQPerTensorDequantizeConverter(QDQDequantizeConverterBase): + + def get_zero_point(self, node: Node) -> np.ndarray: + zero_point_type = torch_type_to_numpy_type(node.args[5]) + return np.array(node.args[2], dtype=zero_point_type) + + def get_scale(self, node: Node) -> np.ndarray: + return np.array(node.args[1], dtype=np.float32) + + +class QDQPerChannelDequantizeConverter(QDQDequantizeConverterBase): + + def get_zero_point(self, node: Node) -> np.ndarray: + return self.context.parameters_mapping[node.args[2].name].numpy() + + def get_scale(self, node: Node) -> np.ndarray: + return self.context.parameters_mapping[node.args[1].name].numpy() + + def get_quantization_dimension(self, from_tensor: Tensor, node: Node) -> int: + quantization_dimension = node.args[3] + + # Quantization dimension is affected by tensor format + if from_tensor.tensor_format == TensorFormat.CHANNELS_LAST: + tensor_rank = len(from_tensor.shape.vector) + perm = create_channels_last_to_channels_first_permutation( + tensor_rank, return_list=True + ) + quantization_dimension = perm[quantization_dimension] + return quantization_dimension diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/slice_tensor_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/slice_tensor_converter.py new file mode 100644 index 00000000000..fd2aec7b8a0 --- /dev/null +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/slice_tensor_converter.py @@ -0,0 +1,158 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import numpy as np +from executorch.backends.nxp.backend.edge_helper import input_tensor +from executorch.backends.nxp.backend.ir.converter.conversion import translator +from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList +from executorch.backends.nxp.backend.ir.converter.node_converter import ( + CustomDelegationOptions, + NodeConverter, +) +from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import ( + slice_options, +) +from executorch.backends.nxp.backend.neutron_operator_support import ( + transposition_is_supported_on_neutron, +) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec +from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT +from torch.fx import Node +from torch.nn import Parameter + + +class SliceTensorConverter(NodeConverter): + @staticmethod + def _is_supported_on_target( + node: Node, + neutron_target_spec: NeutronTargetSpec, + parameters_mapping: dict[str, Parameter], + custom_delegation_options: CustomDelegationOptions, + ) -> bool: + # Provisional solution - slice conversion works for neutron software 2.2.1+ + neutron_flavor = neutron_target_spec.neutron_target.__module__.split(".")[0] + if neutron_flavor != "neutron_converter_SDK_25_12": + return False + + input_shape = input_tensor(node, 0).shape + dim = node.args[1] + if node.args[0].meta[NXP_NODE_FORMAT].is_channels_first(): + dim = translator.create_channels_last_to_channels_first_permutation( + len(input_shape) + )[dim] + input_shape = translator.apply_permutation_to( + input_shape, + translator.create_channels_first_to_channels_last_permutation( + len(input_shape) + ), + ) + input_rank = len(input_shape) + + # Slicing is only allowed along the channel dimension. + # Therefore, we must verify that Neutron supports swapping the channel dimension + # with the dimension intended for slicing. 
+ if dim != -1 and dim != input_rank - 1: + perm = list(range(0, input_rank)) + perm[dim], perm[-1] = perm[-1], perm[dim] + + if not transposition_is_supported_on_neutron( + list(input_shape), perm, neutron_target_spec + ): + return False + + # The shape of dimension that we want to slice must be divisible by num_macs + num_macs = neutron_target_spec.get_num_macs() + return input_shape[dim] % num_macs == 0 + + @staticmethod + def _is_supported_in_IR( + node: Node, + parameters_mapping: dict[str, Parameter], + custom_delegation_options: CustomDelegationOptions, + ) -> bool: + args = node.args + if len(args) != 4: + return False + + dim, start, end = SliceTensorConverter._get_clipped_slice_args(node) + input_rank = len(input_tensor(node, 0).shape) + + # Check "dim" out of bounds + if dim >= input_rank or abs(dim) > input_rank: + return False + + # Check invalid combination of "start" and "end" parameters + if start >= end: + return False + + return True + + def _convert_to_slice(self, t_op, main_input, input_rank, dim, start, end) -> None: + # Prepare the TFLite parameters 'begin' and 'size' tensors + begin = [0] * input_rank # By default, start the slice at 0 + size = ( + main_input.shape.vector.copy() + ) # By default, end the slice at the end of the dimension + + size[dim] = max(end - start, 0) + begin[dim] = start + + # We can slice only the channels dimension + # So we swap the sliced dimension with the channels dimension + begin[-1], begin[dim] = begin[dim], begin[-1] + size[-1], size[dim] = size[dim], size[-1] + + begin_tensor = self.builder.create_tensor_for_data( + np.asarray(begin, np.int32), "begin" + ) + size_tensor = self.builder.create_tensor_for_data( + np.asarray(size, np.int32), "size" + ) + + t_op.tmp_inputs = [main_input, begin_tensor, size_tensor] + t_op.builtin_options = slice_options.Slice() + ops = OpsList(middle_op=t_op) + + # If slicing along non-channels dimension, we need to swap it with channels dimension. + # Otherwise Neutron will not convert it. 
+ if dim != -1 and dim != input_rank - 1: + # Create permutation for swapping + perm = list(range(0, input_rank)) + perm[dim], perm[-1] = perm[-1], perm[dim] + + # Insert forward and backward transpose + ops.add_pre(self.builder.create_transpose_operator_before(t_op, 0, perm)) + ops.add_post(self.builder.create_transpose_operator_after(t_op, 0, perm)) + + self.builder.append_operators(ops.flatten()) + + Dim = Start = End = int + + @staticmethod + def _get_clipped_slice_args(node: Node) -> tuple[Dim, Start, End]: + input_shape = input_tensor(node, 0).shape + _, dim, start, end = node.args + sliced_tensor_rank = input_shape[dim] + + end = int(np.clip(end, 0, sliced_tensor_rank)) + start = int(np.clip(start, 0, sliced_tensor_rank)) + + return dim, start, end + + def convert(self, node: Node): + """Convert 'slice_tensor' operator to NeutronIR 'Slice'.""" + self.assert_convertible(node) + t_op = self._create_tflite_op_with_io_tensors(node) + inputs = t_op.tmp_inputs[0] + rank = inputs.rank + + dim, start, end = self._get_clipped_slice_args(node) + + if t_op.tmp_inputs[0].tensor_format.is_channels_last(): + dim = translator.create_channels_last_to_channels_first_permutation( + t_op.tmp_inputs[0].rank + )[dim] + + self._convert_to_slice(t_op, inputs, rank, dim, start, end) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/softmax_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/softmax_converter.py index aa74c78ca24..5e4404d8476 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/softmax_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/softmax_converter.py @@ -7,13 +7,11 @@ CustomDelegationOptions, ) from executorch.backends.nxp.backend.edge_helper import input_rank -from executorch.backends.nxp.backend.ir.converter.node_converter import ( - NodeConverter, - Target, -) +from executorch.backends.nxp.backend.ir.converter.node_converter import NodeConverter from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import ( softmax_options, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.fx import Node from torch.nn import Parameter @@ -22,18 +20,11 @@ class SoftmaxConverter(NodeConverter): @staticmethod def _is_supported_on_target( node: Node, - target: Target, + neutron_target_spec: NeutronTargetSpec, parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - match target: - case Target.RT700: - # The eIQ Neutron NPU runtime software has a known issue with the SoftMax operation. - # As long as the issue is present, return False for the i.MX RT700 target also. - return False - - case _: - return False + return False @staticmethod def _is_supported_in_IR( diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/sub_tensor_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/sub_tensor_converter.py new file mode 100644 index 00000000000..e9522c87114 --- /dev/null +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/sub_tensor_converter.py @@ -0,0 +1,59 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
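To illustrate the begin/size construction in SliceTensorConverter's `_convert_to_slice` above: slicing along a non-channel dimension is rewritten as a channel slice wrapped in two transposes, so the begin/size entries of the sliced dimension and the channel dimension are swapped. A worked example with concrete numbers, mirroring that arithmetic:

# Slice an NHWC tensor of shape [1, 14, 14, 32] along dim = 1 (H), keeping rows 2..10.
input_shape = [1, 14, 14, 32]
rank = len(input_shape)
dim, start, end = 1, 2, 10

begin = [0] * rank
size = input_shape.copy()
size[dim] = max(end - start, 0)   # 8 rows survive
begin[dim] = start

# Neutron can only slice the channel dimension, so the sliced dimension is swapped
# with the channels, and the Slice is wrapped in a Transpose before and after.
begin[-1], begin[dim] = begin[dim], begin[-1]
size[-1], size[dim] = size[dim], size[-1]
perm = list(range(rank))
perm[dim], perm[-1] = perm[-1], perm[dim]

print(begin)  # [0, 0, 0, 2]
print(size)   # [1, 32, 14, 8]
print(perm)   # [0, 3, 2, 1]  -- applied before and after the Slice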
+ +from executorch.backends.nxp.backend.ir.converter.conversion.common import ( + node_uses_shape_broadcasting, +) +from executorch.backends.nxp.backend.ir.converter.node_converter import ( + CustomDelegationOptions, + NodeConverter, +) +from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import ( + sub_options, +) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec +from torch.fx import Node +from torch.nn import Parameter + + +class SubTensorConverter(NodeConverter): + @staticmethod + def _is_supported_on_target( + node: Node, + neutron_target_spec: NeutronTargetSpec, + parameters_mapping: dict[str, Parameter], + custom_delegation_options: CustomDelegationOptions, + ) -> bool: + if node_uses_shape_broadcasting(node): + # Shape broadcasting may require the addition of `Transpose` ops during conversion. + return False + + return True + + @staticmethod + def _is_supported_in_IR( + node: Node, + parameters_mapping: dict[str, Parameter], + custom_delegation_options: CustomDelegationOptions, + ) -> bool: + if len(node.args) != 2: + return False + + # The `alpha` attribute can be represented by adding an extra `Mul` operator. + # However, this is not implemented as `alpha` is rarely used. + if hasattr(node.kwargs, "alpha"): + return False + + return True + + # sub.Tensor Node format: (Tensor self, Tensor other, *, Scalar alpha=1) + def convert(self, node: Node): + """Convert 'sub_tensor' operator to NeutronIR 'Sub'.""" + self.assert_convertible(node) + + t_op = self._create_tflite_op_with_io_tensors(node) + + t_op.builtin_options = sub_options.Sub() + self.builder.append_operators([t_op]) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/view_copy_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/view_copy_converter.py index 95a42d5d078..1c8a0086c72 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/view_copy_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/view_copy_converter.py @@ -6,14 +6,21 @@ import numpy as np from executorch.backends.nxp.backend.edge_helper import ( + get_non_qdq_users, input_tensor, output_tensor, tensor_rank, ) from executorch.backends.nxp.backend.ir.converter import quantization_utils from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList +from executorch.backends.nxp.backend.ir.converter.conversion.translator import ( + apply_permutation_to, + create_channels_first_to_channels_last_permutation, + create_channels_last_to_channels_first_permutation, +) from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, + is_not_qdq_node, NodeConverter, ) from executorch.backends.nxp.backend.ir.converter.node_converters.shared.reshape_transposition import ( @@ -22,7 +29,14 @@ from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import ( reshape_options, ) +from executorch.backends.nxp.backend.neutron_operator_support import ( + transposition_is_supported_on_neutron, +) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec +from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT +from executorch.exir.dialects._ops import ops as exir_ops from torch.fx import Node +from torch.fx.passes.infra.partitioner import Partition from torch.nn import Parameter @@ -45,6 +59,99 @@ def _is_supported_in_IR( return True + @classmethod + def supports_partitioning_result( + cls, + node: Node, + 
partition_list: list[Partition], + custom_delegation_options: CustomDelegationOptions, + neutron_target_spec: NeutronTargetSpec, + parameters_mapping: dict[str, Parameter], + ): + view_copy_partitions = [ + partition for partition in partition_list if node in partition.nodes + ] + assert len(view_copy_partitions) == 1 + non_q_dq_partition_nodes = list( + filter(is_not_qdq_node, view_copy_partitions[0].nodes) + ) + + if len(non_q_dq_partition_nodes) == 1: + # The `view_copy` cannot be the only node in a partition. + return False + + input_format = node.args[0].meta[NXP_NODE_FORMAT] + output_format = node.meta[NXP_NODE_FORMAT] + input_shape = list(node.args[0].meta["val"].shape) + output_shape = list(node.meta["val"].shape) + to_nchw_perm = create_channels_last_to_channels_first_permutation( + len(input_shape), True + ) + to_nhwc_perm = create_channels_first_to_channels_last_permutation( + len(output_shape), True + ) + channels_last_input_shape = apply_permutation_to( + input_shape, + create_channels_first_to_channels_last_permutation(len(input_shape), True), + ) + + if input_format.is_channels_first() and (not output_format.is_channels_first()): + # The `view_copy` removes node format. Conversion will require the addition of a `Transpose` operator. + # Make sure the `Transpose` will be supported. + + if not transposition_is_supported_on_neutron( + channels_last_input_shape, to_nchw_perm, neutron_target_spec + ): + # The `Transpose` would have to be removed by the `PermuteFullyConnectedWeightsAfterReshape` pass. + # Make sure it will be applied. + users = get_non_qdq_users(node) + if len(users) != 1 or (linear_node := users[0]).target not in [ + exir_ops.edge.aten.addmm.default, + exir_ops.edge.aten.mm.default, + ]: + return False + + if linear_node not in view_copy_partitions[0].nodes: + # The `mm` / `addmm` node will not be delegated within this partition. + return False + + # Make sure the specific requirements of the `PermuteFullyConnectedWeightsAfterReshape` are satisfied. + weights_index = ( + 2 if linear_node.target == exir_ops.edge.aten.addmm.default else 1 + ) + if not ( + input_shape[0] == output_shape[0] # Preserve batch. + and len(output_shape) == 2 + and output_shape[1] + == linear_node.args[weights_index].meta["val"].shape[0] + ): + return False + + elif ( + not input_format.is_channels_first() + ) and output_format.is_channels_first(): + # The `view_copy` introduces node format. Conversion will require the addition of a `Transpose` operator. + # Make sure the `Transpose` will be supported. + if not transposition_is_supported_on_neutron( + output_shape, to_nhwc_perm, neutron_target_spec + ): + return False + + elif input_format.is_channels_first() and output_format.is_channels_first(): + # The `view_copy` works with the channels first format, so both tensors will end up being transposed. + # Make sure these transpositions are supported. + if not ( + transposition_is_supported_on_neutron( + channels_last_input_shape, to_nchw_perm, neutron_target_spec + ) + and transposition_is_supported_on_neutron( + output_shape, to_nhwc_perm, neutron_target_spec + ) + ): + return False + + return True + @staticmethod def _safe_compute_flat_size(shape: list[int | str]) -> int: """Compute the flat size of a tensor with given shape. Strings and negative dimensions are treated as '1'. 
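The view_copy partitioning check above, like the cat, mean.dim and per-channel dequantize changes earlier in this diff, leans on the pair of layout permutations returned by `create_channels_first_to_channels_last_permutation` and its inverse. Assuming they return the usual NCHW/NHWC permutations, which is consistent with how they are indexed throughout the diff, the rank-4 values are simply:

# Assumed rank-4 layout permutations (illustrative, not imported from the backend):
TO_CHANNELS_LAST = [0, 2, 3, 1]    # NCHW shape indexed with this gives NHWC
TO_CHANNELS_FIRST = [0, 3, 1, 2]   # NHWC shape indexed with this gives NCHW

nchw = [1, 32, 14, 14]
nhwc = [nchw[i] for i in TO_CHANNELS_LAST]
print(nhwc)                                  # [1, 14, 14, 32]
print([nhwc[i] for i in TO_CHANNELS_FIRST])  # [1, 32, 14, 14] -- round trip

# Dimension indices move with the layout too: the NCHW channel axis (1) lands at
print(TO_CHANNELS_LAST.index(1))             # 3 in NHWC, as used by CatConverter above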
diff --git a/backends/nxp/backend/ir/converter/node_converters/shared/conv_utils.py b/backends/nxp/backend/ir/converter/node_converters/shared/conv_utils.py index 5817fd127b3..2012ecc8640 100755 --- a/backends/nxp/backend/ir/converter/node_converters/shared/conv_utils.py +++ b/backends/nxp/backend/ir/converter/node_converters/shared/conv_utils.py @@ -16,6 +16,8 @@ class ConvParameters: stride: list[int] padding: list[int] dilation: list[int] + transposed: bool + out_padding: list[int] groups: int @@ -35,6 +37,29 @@ def _get_IO_channels(node: Node | tflite_model.Operator) -> (int, int): return input_channels, output_channels +def get_node_tensor_params(node: Node) -> dict: + node_tensor_params = {} + + input_tensor = node.args[0] + assert len(input_tensor.meta["val"].shape) in [3, 4], "Supports only Conv 1D, 2D." + node_tensor_params["batch_size"] = input_tensor.meta["val"].shape[0] + node_tensor_params["inp_channels"] = input_tensor.meta["val"].shape[1] + node_tensor_params["inp_height"] = input_tensor.meta["val"].shape[2] + if len(input_tensor.meta["val"].shape) == 4: + node_tensor_params["inp_width"] = input_tensor.meta["val"].shape[3] + + weights = node.args[1] + node_tensor_params["out_channels"] = node.meta["val"].shape[1] + node_tensor_params["out_height"] = node.meta["val"].shape[2] + if len(node.meta["val"].shape) == 4: + node_tensor_params["out_width"] = node.meta["val"].shape[3] + node_tensor_params["kernel_height"] = weights.meta["val"].shape[2] + if len(weights.meta["val"].shape) == 4: + node_tensor_params["kernel_width"] = weights.meta["val"].shape[3] + + return node_tensor_params + + def group_conv_convertible_as_depthwise(node: Node | tflite_model.Operator, group: int): input_channels, output_channels = _get_IO_channels(node) @@ -70,9 +95,11 @@ def __init__( weight_tensor: tflite_model.Tensor, bias_tensor: tflite_model.Tensor, output_tensor: tflite_model.Tensor, + output_shape_tensor: tflite_model.Tensor | None = None, ): self.conv_input_tensor = input_tensor self.conv_weight_tensor = weight_tensor self.conv_bias_tensor = bias_tensor self.conv_output_tensor = output_tensor + self.output_shape_tensor = output_shape_tensor self.ops_list = OpsList() diff --git a/backends/nxp/backend/ir/converter/node_converters/shared/recurrent_utils.py b/backends/nxp/backend/ir/converter/node_converters/shared/recurrent_utils.py index 50b9aef6d18..52b895d60cd 100755 --- a/backends/nxp/backend/ir/converter/node_converters/shared/recurrent_utils.py +++ b/backends/nxp/backend/ir/converter/node_converters/shared/recurrent_utils.py @@ -1,19 +1,12 @@ -# Copyright 2024 NXP +# Copyright 2024-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
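For reference, the dictionary built by the new `get_node_tensor_params` in conv_utils above, written out for a typical 2D convolution. The shapes are chosen only for illustration; on a real node they come from the `meta["val"]` entries of the input, weight and output:

# Shapes of a hypothetical aten.convolution node, in the channels-first layout of the node meta:
input_shape = [1, 16, 32, 32]    # [N, C_in, H, W]
weight_shape = [8, 16, 3, 3]     # [C_out, C_in, kH, kW]
output_shape = [1, 8, 30, 30]    # [N, C_out, H_out, W_out]

node_tensor_params = {
    "batch_size": input_shape[0],      # 1
    "inp_channels": input_shape[1],    # 16
    "inp_height": input_shape[2],      # 32
    "inp_width": input_shape[3],       # 32
    "out_channels": output_shape[1],   # 8
    "out_height": output_shape[2],     # 30
    "out_width": output_shape[3],      # 30
    "kernel_height": weight_shape[2],  # 3
    "kernel_width": weight_shape[3],   # 3
}
print(node_tensor_params)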
-from executorch.backends.nxp.backend.ir import logger from executorch.backends.nxp.backend.ir.converter.builder import model_builder from executorch.backends.nxp.backend.ir.converter.conversion import translator -from executorch.backends.nxp.backend.ir.converter.conversion.common import ( - OpsList, - try_get_input, -) +from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList from executorch.backends.nxp.backend.ir.converter.tensor_utils import tensor_has_data -from executorch.backends.nxp.backend.ir.lib.tflite.ActivationFunctionType import ( - ActivationFunctionType, -) from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model @@ -25,12 +18,12 @@ def ensure_correct_tensor_formatting( or RNN operator. The LSTM/RNN may be using channels last tensors, because of the surrounding operators. LSTM/RNN requires its own - format, however I think the input tensors should be marked as 'FORMATLESS', because the main inputs of TFLite - and ONNX version of the operators have the same shape. + format, however I think the input tensors should be marked as 'FORMATLESS', because the main inputs of the + NeutronIR and the ExecuTorch version of the operators have the same shape. I believe that the cleanest and most robust way to solve this, is to mark LSTM/RNN as an operator which can change the formats of its tensors, and solve any format related issues in this module. - :param t_op: TFLite operator with inputs and outputs corresponding to the ONNX LSTM/RNN operator. + :param t_op: NeutronIR operator with inputs and outputs corresponding to the ExecuTorch LSTM/RNN operator. :param builder: ModelBuilder object. :param ops: OpsList object, with operators to add to the model. May already contain some operators. """ @@ -69,44 +62,3 @@ def ensure_correct_tensor_formatting( ops.post_ops.append(transpose) t_op.tmp_outputs[idx].tensor_format = TensorFormat.FORMATLESS - - -def get_activation_function_for_name( - name: str, op_type: str = "LSTM" -) -> ActivationFunctionType: - get_activation_function_for_name.map = { - "Tanh": ActivationFunctionType.TANH, - "Relu": ActivationFunctionType.RELU, - } - - if act_fun := get_activation_function_for_name.map.get(name, None): - return act_fun - - # Couldn't find a corresponding activation function - logger.e( - logger.Code.CONVERSION_IMPOSSIBLE, - f"Conversion of ONNX {op_type} with activation function '{name}' is not possible.", - ) - - -def check_sequence_lens( - t_op: tflite_model.Operator, seq_length: int, op_type: str = "LSTM" -): - """Check if the 'sequence_lens' operand of ONNX LSTM/RNN has an effect. If it does, exit with error. - - :param t_op: TFLite operator with inputs and outputs corresponding to the ONNX operator. - :param seq_length: The first dimension of the main LSTM input. - :param op_type: Operator type of 't_op'. Used only for printing a specific error message. - """ - if sequence_lens := try_get_input(t_op, 4): - # 'sequence_lens' allows each sequence to have a different length. As far as I can tell, TFLite doesn't support - # this. - if (not tensor_has_data(sequence_lens)) or any( - elt != seq_length for elt in sequence_lens.tmp_buffer.data - ): - # The 'sequence_lens' is either dynamic, or static with at least one value different from 'seq_length'. - # Conversion most likely impossible. 
- logger.e( - logger.Code.CONVERSION_IMPOSSIBLE, - f"Conversion of ONNX {op_type} with 'sequence_lens' input is not possible.", - ) diff --git a/backends/nxp/backend/ir/converter/node_converters/shared/reduce_utils.py b/backends/nxp/backend/ir/converter/node_converters/shared/reduce_utils.py index 1dca3acea74..da92e359f1e 100755 --- a/backends/nxp/backend/ir/converter/node_converters/shared/reduce_utils.py +++ b/backends/nxp/backend/ir/converter/node_converters/shared/reduce_utils.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import numpy as np + from executorch.backends.nxp.backend.ir.converter.builder.model_builder import ( ModelBuilder, ) @@ -16,7 +17,7 @@ def convert_axes_from_attribute( t_op: tflite_model.Operator, builder: ModelBuilder, axes: list[int] | None ): - """Create an `axes` tensor and assign it as an input to the `t_op`, which is expected to represent an ONNX + """Create an `axes` tensor and assign it as an input to the `t_op`, which is expected to represent an ExecuTorch reduction operator. """ x = t_op.tmp_inputs[0] @@ -52,15 +53,15 @@ def ensure_reduce_transposition(builder, ops: OpsList): output_format = output_tensor.tensor_format if input_format.is_channels_last() and output_format.is_channels_last(): - to_onnx_perm = translator.create_channels_last_to_channels_first_permutation( - input_rank + to_executorch_perm = ( + translator.create_channels_last_to_channels_first_permutation(input_rank) ) to_tflite_perm = translator.create_channels_first_to_channels_last_permutation( output_rank, return_list=True ) transpose_before = builder.create_transpose_operator_before( - t_op, 0, to_onnx_perm + t_op, 0, to_executorch_perm ) transpose_before.tmp_outputs[0].tensor_format = TensorFormat.CHANNELS_FIRST ops.add_pre(transpose_before) @@ -72,7 +73,7 @@ def ensure_reduce_transposition(builder, ops: OpsList): ops.post_ops.insert(0, transpose_after) elif input_format.is_channels_last() and not output_format.is_channels_last(): - # The dimensions of the tensor lose their meaning! Insert a transpose op, to change input to match ONNX. + # The dimensions of the tensor lose their meaning! Insert a transpose op, to change input to match ExecuTorch. permutation = list( translator.create_channels_last_to_channels_first_permutation(input_rank) @@ -83,9 +84,9 @@ def ensure_reduce_transposition(builder, ops: OpsList): ops.add_pre(transpose) elif not input_format.is_channels_last() and output_format.is_channels_last(): - # The ReduceX introduces format to the tensor - # The ONNX ReduceX outputs a 'channels first' tensor. This has to stay the same, and then a Transpose operator - # must be added, to change the tensor to 'channels last'. + # The reduction operator introduces format to the tensor. + # The ExecuTorch reduction operator outputs a 'channels first' tensor. This has to stay the same, and then a + # Transpose operator must be added, to change the tensor to 'channels last'. 
permutation = list( translator.create_channels_first_to_channels_last_permutation(output_rank) diff --git a/backends/nxp/backend/ir/converter/node_converters/shared/reshape_transposition.py b/backends/nxp/backend/ir/converter/node_converters/shared/reshape_transposition.py index 0e55c27684b..55056614684 100755 --- a/backends/nxp/backend/ir/converter/node_converters/shared/reshape_transposition.py +++ b/backends/nxp/backend/ir/converter/node_converters/shared/reshape_transposition.py @@ -1,4 +1,4 @@ -# Copyright 2023 NXP +# Copyright 2023-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -158,7 +158,7 @@ def ensure_reshape_transposition(builder, ops: OpsList) -> list[int]: new_shape = output_tensor.shape.vector if input_format.is_channels_last() and not output_format.is_channels_last(): - # The dimensions of the tensor lose their meaning! Insert a transpose op, to change input to match ONNX. + # The dimensions of the tensor lose their meaning! Insert a transpose op, to change input to match ExecuTorch. permutation = list( translator.create_channels_last_to_channels_first_permutation(input_rank) @@ -170,7 +170,7 @@ def ensure_reshape_transposition(builder, ops: OpsList) -> list[int]: elif not input_format.is_channels_last() and output_format.is_channels_last(): # The Reshape introduces format to the tensor (2D -> 4D for example) - # The ONNX Reshape outputs a 'channels first' tensor. This has to stay the same, and then a Transpose operator + # The `view_copy` outputs a 'channels first' tensor. This has to stay the same, and then a Transpose operator # must be added, to change the tensor to 'channels last'. permutation = list( diff --git a/backends/nxp/backend/ir/converter/quantization_utils.py b/backends/nxp/backend/ir/converter/quantization_utils.py index d9e7674d953..11de4eec13c 100755 --- a/backends/nxp/backend/ir/converter/quantization_utils.py +++ b/backends/nxp/backend/ir/converter/quantization_utils.py @@ -1,111 +1,19 @@ -# Copyright 2023 NXP +# Copyright 2023-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import copy -from typing import Iterable, List, Optional - -import executorch.backends.nxp.backend.ir.converter.builder.model_builder as model_builder +from typing import List import numpy as np + from executorch.backends.nxp.backend.ir import logger as logger -from executorch.backends.nxp.backend.ir.converter.conversion.translator import ( - tf_lite_type_to_numpy, -) -from executorch.backends.nxp.backend.ir.lib.tflite import TensorType as tflTensorType -from executorch.backends.nxp.backend.ir.lib.tflite.TensorType import TensorType from executorch.backends.nxp.backend.ir.tflite_generator import ( tflite_model as tflite_model, ) -def quantization_is_equal( - x_scale: np.ndarray, - x_zp: np.ndarray, - x_type: TensorType, - y_scale: np.ndarray, - y_zp: np.ndarray, - y_type: TensorType, -) -> bool: - """Determine if provided quantization parameters of tensors 'x' and 'y' are the same. - - :param x_scale: Scale of the 'x' tensor. - :param x_zp: Zero point of the 'x' tensor. - :param x_type: TFLite data type of the 'x' tensor. - :param y_scale: Scale of the 'y' tensor. - :param y_zp: Zero point of the 'y' tensor. - :param y_type: TFLite data type of the 'y' tensor. - :return: True, if the quantization parameters are equal. 
- """ - if x_type != y_type: - return False - - if not (x_scale.size == x_zp.size == y_scale.size == y_zp.size): - return False - - x_scale, x_zp = quantization_params_to_lists(x_scale, x_zp) - y_scale, y_zp = quantization_params_to_lists(y_scale, y_zp) - - return all( - x_s == y_s and x_z == y_z - for x_s, y_s, x_z, y_z in zip(x_scale, y_scale, x_zp, y_zp) - ) - - -def quantization_params_to_lists( - scale: np.ndarray, zero_point: np.ndarray -) -> (List[float], List[int]): - if (scale is None) or (zero_point is None): - logger.e( - logger.Code.INTERNAL_ERROR, - "Missing zero_point and/or scale quantization params when converting to list!", - ) - - if (scale.size == 1) and (zero_point.size == 1): - # Per tensor quantization - scale = [scale.item()] - zero_point = [zero_point.item()] - elif (scale.size != 1) and (zero_point.size != 1): - # Per channel quantization - scale = scale.tolist() - zero_point = zero_point.tolist() - else: - logger.e( - logger.Code.CONVERSION_IMPOSSIBLE, - "TFLite doesn't support combination of per-channel and per-tensor quantization params.", - ) - - return scale, zero_point - - -def is_quantization_valid(scale, zero_point): - return scale.size == zero_point.size - - -def is_per_tensor_quantized(scale, zero_point): - return (scale.size == 1) and (zero_point.size == 1) - - -def is_per_channel_quantized(scale, zero_point): - return is_quantization_valid(scale, zero_point) and not is_per_tensor_quantized( - scale, zero_point - ) - - -def get_symmetric_zero_point_for_type(tensor_type: TensorType): - match tensor_type: - case TensorType.INT8: - return 0 - case TensorType.UINT8: - return 128 - case _: - logger.e( - logger.Code.INTERNAL_ERROR, - f"Attempt to get zero point definition for type: {tensor_type}", - ) - - def _validate_or_set_quant_params( tensor: tflite_model.Tensor, quant: tflite_model.Quantization ) -> bool: @@ -130,7 +38,7 @@ def propagate_quantization( """ Propagates quantization parameters from from_tensor to to_tensor. If to_tensor already has the params set checks the consistency. - :raises: logger.Error - INVALID_ONNX_MODEL + :raises: logger.Error - INVALID_INPUT_MODEL """ if ( @@ -147,7 +55,7 @@ def propagate_quantization( # noinspection PyTypeChecker if not _validate_or_set_quant_params(to_tensor, from_tensor.quantization): logger.e( - logger.Code.INVALID_ONNX_MODEL, + logger.Code.INVALID_INPUT_MODEL, f'Mismatched quantization parameters between tensors "{from_tensor.name}" and "{to_tensor.name}"', ) @@ -161,16 +69,16 @@ def set_quantization_parameters_to_tensor( """Create a TFLite QuantizationParameters object, initialize it from given parameters and add it to the 'tflite_tensor'. :param tflite_tensor: The TFLite tensor in the model, to add the quantization to. - :param scale: The data of the tensor, which is an input of a quantized ONNX operator and represents the + :param scale: The data of the tensor, which is an input of a quantized ExecuTorch operator and represents the quantization scale. - :param zero_point: The data of the tensor, which is an input of a quantized ONNX operator and represents the + :param zero_point: The data of the tensor, which is an input of a quantized ExecuTorch operator and represents the quantization zero point. :param quantized_dimension: The quantized dimension attribute of TFLite QuantizationParameters. 
""" if (scale is None) or (zero_point is None): logger.e( logger.Code.NOT_IMPLEMENTED, - "Conversion of ONNX quantized operators is only supported when " + "Conversion of ExecuTorch quantized operators is only supported when " "the quantization parameters are static!", ) @@ -184,8 +92,8 @@ def set_quantization_parameters_to_tensor( if scale.size != zero_point.size: logger.e( - logger.Code.INVALID_ONNX_MODEL, - f"The per channel quantization parameters of ONNX tensor " + logger.Code.INVALID_INPUT_MODEL, + f"The per channel quantization parameters of ExecuTorch tensor " f"'{tflite_tensor.name}' are of different sizes! ('{scale.size}'" f" != '{zero_point.size}')", ) @@ -193,8 +101,8 @@ def set_quantization_parameters_to_tensor( quantized_dimension_size = tflite_tensor.shape.get(quantized_dimension) if scale.size != quantized_dimension_size: logger.e( - logger.Code.INVALID_ONNX_MODEL, - f"The ONNX per channel quantization parameter vectors do not " + logger.Code.INVALID_INPUT_MODEL, + f"The ExecuTorch per channel quantization parameter vectors do not " f"match the size of the quantized dimension! ('{scale.size}' != " f"'{quantized_dimension_size}')", ) @@ -205,8 +113,8 @@ def set_quantization_parameters_to_tensor( else: # Combination of per tensor and per channel quantization parameters logger.e( - logger.Code.INVALID_ONNX_MODEL, - f"ONNX tensor '{tflite_tensor.name}' uses a combination of per " + logger.Code.INVALID_INPUT_MODEL, + f"ExecuTorch node '{tflite_tensor.name}' uses a combination of per " f"tensor and per channel quantization parameters. Conversion to " f"TFLite is not possible!", ) @@ -218,33 +126,12 @@ def set_quantization_parameters_to_tensor( ) if not _validate_or_set_quant_params(tflite_tensor, quant): logger.e( - logger.Code.INVALID_ONNX_MODEL, + logger.Code.INVALID_INPUT_MODEL, f'Mismatched quantization parameters between tensors: "{tflite_tensor.name}" already ' f"has the quantization params set", ) -def calculate_uint_to_int_re_quantization_zero_point( - data_type_byte_size: int, old_zero_point: Iterable[int] -) -> np.ndarray: - """ - Calculate the new zero points, after a quantized tensor with an unsigned int data type is re-quantized to - a signed type. - :param data_type_byte_size: Size of the data type that is used, in Bytes. For example 1 for INT8. - :param old_zero_point: The zero point quantisation parameter, of the original data, before re-quantization. - :return: The new zero point quantisation parameter, after re-quantization. 
- """ - data_type_bit_size = 8 * data_type_byte_size - zero_point_shift = 2 ** (data_type_bit_size - 1) - return np.asarray(np.subtract(np.array(old_zero_point, np.int32), zero_point_shift)) - - -def _re_quantize_uint8_to_int8(tensor_data: np.ndarray) -> np.ndarray: - """Re-quantize static uint8 data to int8.""" - int16_data = np.asarray(tensor_data, np.int16) - return np.array(int16_data - 128, np.int8) - - def quantize_int8( data: np.ndarray, scale: List[float], zero_point: List[int] ) -> np.ndarray: @@ -252,20 +139,6 @@ def quantize_int8( return np.clip(new_data, -128, 127).astype(np.int8) -def quantize_uint8( - data: np.ndarray, scale: List[float], zero_point: List[int] -) -> np.ndarray: - new_data = np.add(np.round(np.divide(data, scale)), zero_point) - return np.clip(new_data, 0, 255).astype(np.uint8) - - -def quantize_int32( - data: np.ndarray, scale: List[float], zero_point: List[int] -) -> np.ndarray: - new_data = np.add(np.round(np.divide(data, scale)), zero_point) - return np.clip(new_data, -2_147_483_648, 2_147_483_648).astype(np.int32) - - def dequantize( data: np.ndarray, scale: List[float], zero_point: List[int] ) -> np.ndarray: @@ -274,211 +147,3 @@ def dequantize( scale, dtype=np.float32, ) - - -def re_quantize_static_tensor( - builder: "model_builder.ModelBuilder", - tflite_tensor: tflite_model.Tensor, - to_type: tflTensorType.TensorType, - new_scale: Optional[List[float]] = None, - new_zero_point: Optional[List[int]] = None, -) -> tflite_model.Tensor: - """Create a new TFLite Tensor with new quantization parameters, type and data. - - :param builder: A ModelBuilder instance. - :param tflite_tensor: TFLite tensor to re-quantize. - :param to_type: The TFLite TensorType, that the tensor will be re-quantized to. - :param new_scale: New scale quantization parameter. Used only when re-quantizing to the same type. - :param new_zero_point: New zero point quantization parameter. Used only when re-quantizing to the same type. - :return: A new re-quantized tensor. 
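Aside (illustrative, not part of the diff): the `quantize_int8` and `dequantize` helpers retained above implement the usual affine quantization scheme. A self-contained sketch of the round trip with per-tensor parameters chosen for the example:

import numpy as np

def quantize_int8(data, scale, zero_point):
    # Mirrors the retained helper: q = clip(round(x / scale) + zero_point, -128, 127).
    return np.clip(np.round(np.divide(data, scale)) + zero_point, -128, 127).astype(np.int8)

def dequantize(data, scale, zero_point):
    # Mirrors the retained helper: x = (q - zero_point) * scale.
    return np.multiply(np.subtract(data, zero_point), scale, dtype=np.float32)

x = np.array([-1.0, -0.25, 0.0, 0.5, 1.0], dtype=np.float32)
scale, zero_point = [1.0 / 128.0], [0]  # example per-tensor parameters
q = quantize_int8(x, scale, zero_point)
x_hat = dequantize(q, scale, zero_point)
assert q.dtype == np.int8
assert np.max(np.abs(x_hat - x)) <= 1.0 / 128.0  # at most one quantization step off for these in-range values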
- """ - if tflite_tensor.quantization is None: - logger.e( - logger.Code.INTERNAL_ERROR, - "translator.re_quantize_static_tensor(): Got tensor without quantization!", - ) - - if tflite_tensor.tmp_buffer.data is None: - logger.e( - logger.Code.INTERNAL_ERROR, - "translator.re_quantize_static_tensor(): Got tensor without static data!", - ) - - new_dtype = tf_lite_type_to_numpy(to_type) - re_quantized_tensor = builder.duplicate_tensor(tflite_tensor) - tensor_data = re_quantized_tensor.tmp_buffer.data - - if tensor_data.dtype == np.uint8 and new_dtype == np.int8: # INT8 -> UINT8 - re_quantized_tensor.tmp_buffer.data = _re_quantize_uint8_to_int8(tensor_data) - re_quantized_tensor.type = tflTensorType.TensorType.INT8 - calculated_zero_point = calculate_uint_to_int_re_quantization_zero_point( - 1, re_quantized_tensor.quantization.zero_point.vector - ) - re_quantized_tensor.quantization.zero_point = tflite_model.ZeroPoint( - list(calculated_zero_point) - ) - - elif tensor_data.dtype == np.int32 and new_dtype == np.int8: # INT32 -> INT8 - if new_zero_point is None or new_scale is None: - logger.e( - logger.Code.INTERNAL_ERROR, - "Missing new zero_point or new scale when re-quantizing tensor.", - ) - - old_zp = re_quantized_tensor.quantization.zero_point.vector - old_scale = re_quantized_tensor.quantization.scale.vector - float_data = dequantize(tensor_data, old_scale, old_zp) - int8_data = quantize_int8(float_data, new_scale, new_zero_point) - - re_quantized_tensor.tmp_buffer.data = int8_data - re_quantized_tensor.type = tflTensorType.TensorType.INT8 - re_quantized_tensor.quantization.zero_point = tflite_model.ZeroPoint( - list(new_zero_point) - ) - re_quantized_tensor.quantization.scale = tflite_model.Scale(list(new_scale)) - - elif tensor_data.dtype == np.int32 and new_dtype == np.uint8: # INT32 -> UINT8 - if new_zero_point is None or new_scale is None: - logger.e( - logger.Code.INTERNAL_ERROR, - "Missing new zero_point or new scale when re-quantizing tensor.", - ) - - old_zp = re_quantized_tensor.quantization.zero_point.vector - old_scale = re_quantized_tensor.quantization.scale.vector - float_data = dequantize(tensor_data, old_scale, old_zp) - uint8_data = quantize_uint8(float_data, new_scale, new_zero_point) - - re_quantized_tensor.tmp_buffer.data = uint8_data - re_quantized_tensor.type = tflTensorType.TensorType.UINT8 - re_quantized_tensor.quantization.zero_point = tflite_model.ZeroPoint( - list(new_zero_point) - ) - re_quantized_tensor.quantization.scale = tflite_model.Scale(list(new_scale)) - - elif tensor_data.dtype == np.int8 and new_dtype == np.int8: # INT8 -> INT8 - # Re-quantizing int8 tensor data with different quantization parameters - if new_zero_point is None or new_scale is None: - logger.e( - logger.Code.INTERNAL_ERROR, - "Missing new zero_point or new scale when re-quantizing tensor.", - ) - - zero_point_data = re_quantized_tensor.quantization.zero_point.vector - scale_data = re_quantized_tensor.quantization.scale.vector - new_tensor_data = dequantize(tensor_data, scale_data, zero_point_data) - - re_quantized_tensor.tmp_buffer.data = quantize_int8( - new_tensor_data, new_scale, new_zero_point - ) - re_quantized_tensor.quantization.scale = tflite_model.Scale(new_scale) - re_quantized_tensor.quantization.zero_point = tflite_model.ZeroPoint( - new_zero_point - ) - - elif tensor_data.dtype == np.int32 and new_dtype == np.int32: # INT32 -> INT32 - if new_zero_point is None or new_scale is None: - logger.e( - logger.Code.INTERNAL_ERROR, - "Missing new zero_point or new scale when 
re-quantizing tensor.", - ) - - old_zp = re_quantized_tensor.quantization.zero_point.vector - old_scale = re_quantized_tensor.quantization.scale.vector - float_data = dequantize(tensor_data, old_scale, old_zp) - int32_data = quantize_int32(float_data, new_scale, new_zero_point) - - re_quantized_tensor.tmp_buffer.data = int32_data - re_quantized_tensor.quantization.zero_point = tflite_model.ZeroPoint( - list(new_zero_point) - ) - re_quantized_tensor.quantization.scale = tflite_model.Scale(list(new_scale)) - - else: - logger.e( - logger.Code.NOT_IMPLEMENTED, - f"Re-quantization of static tensors from type '{tensor_data.dtype}' " - f"to type '{to_type}' is not yet implemented!", - ) - - return re_quantized_tensor - - -def quantize_static_float_tensor( - builder: "model_builder.ModelBuilder", - tflite_tensor: tflite_model.Tensor, - to_type: tflTensorType.TensorType, - scale: List[float], - zero_point: List[int], - quantized_dimension: int = 0, -) -> tflite_model.Tensor: - """Quantize tensor 'tflite_tensor' with passed quantization params. - - :param builder: A ModelBuilder instance. - :param tflite_tensor: TFLite tensor to quantize. - :param to_type: The TFLite TensorType, that the tensor will be quantized to. - :param scale: Scale quantization parameter. - :param zero_point: Zero point quantization parameter. - :param quantized_dimension: Quantized dimension. - """ - if tflite_tensor.quantization is not None: - logger.e(logger.Code.INTERNAL_ERROR, "Got tensor with quantization!") - - if tflite_tensor.tmp_buffer.data is None: - logger.e(logger.Code.INTERNAL_ERROR, "Got tensor without static data!") - - quantized_tensor = builder.duplicate_tensor(tflite_tensor) - tensor_data = quantized_tensor.tmp_buffer.data - - if zero_point is None or scale is None: - logger.e( - logger.Code.INTERNAL_ERROR, - "Missing new zero_point or new scale when quantizing tensor.", - ) - - new_dtype = tf_lite_type_to_numpy(to_type) - - if tensor_data.dtype == np.float32 and new_dtype == np.int8: - int8_data = quantize_int8(tensor_data, scale, zero_point) - - quantized_tensor.tmp_buffer.data = int8_data - quantized_tensor.type = tflTensorType.TensorType.INT8 - quantized_tensor.quantization = tflite_model.Quantization() - quantized_tensor.quantization.zero_point = tflite_model.ZeroPoint( - list(zero_point) - ) - quantized_tensor.quantization.scale = tflite_model.Scale(list(scale)) - quantized_tensor.quantization.quantized_dimension = quantized_dimension - - elif tensor_data.dtype == np.float32 and new_dtype == np.uint8: - uint8_data = quantize_uint8(tensor_data, scale, zero_point) - - quantized_tensor.tmp_buffer.data = uint8_data - quantized_tensor.type = tflTensorType.TensorType.UINT8 - quantized_tensor.quantization = tflite_model.Quantization() - quantized_tensor.quantization.zero_point = tflite_model.ZeroPoint( - list(zero_point) - ) - quantized_tensor.quantization.scale = tflite_model.Scale(list(scale)) - quantized_tensor.quantization.quantized_dimension = quantized_dimension - - elif tensor_data.dtype == np.float32 and new_dtype == np.int32: - int32_data = quantize_int32(tensor_data, scale, zero_point) - - quantized_tensor.tmp_buffer.data = int32_data - quantized_tensor.type = tflTensorType.TensorType.INT32 - quantized_tensor.quantization = tflite_model.Quantization() - quantized_tensor.quantization.zero_point = tflite_model.ZeroPoint( - list(zero_point) - ) - quantized_tensor.quantization.scale = tflite_model.Scale(list(scale)) - quantized_tensor.quantization.quantized_dimension = quantized_dimension - - else: - 
logger.e( - logger.Code.NOT_IMPLEMENTED, - f"Quantization of static tensors from type '{tensor_data.dtype}' " - f"to type '{to_type}' is not yet implemented!", - ) - - return quantized_tensor diff --git a/backends/nxp/backend/ir/logger.py b/backends/nxp/backend/ir/logger.py index ce8da2a31df..8019fb4d780 100644 --- a/backends/nxp/backend/ir/logger.py +++ b/backends/nxp/backend/ir/logger.py @@ -1,6 +1,6 @@ # # Copyright 2023 Martin Pavella -# Copyright 2023 NXP +# Copyright 2023-2025 NXP # # License: MIT # See the LICENSE_MIT for more details. @@ -85,18 +85,18 @@ class Code(Enum): PREPROCESSING_ERROR = 4 UNSUPPORTED_OPERATOR = 21 - UNSUPPORTED_ONNX_TYPE = 22 + # Code 22 was removed. UNSUPPORTED_OPERATOR_ATTRIBUTES = 23 NOT_IMPLEMENTED = 24 INVALID_TYPE = 31 INVALID_TENSOR_SHAPE = 32 - INVALID_ONNX_OPERATOR = 33 - INVALID_ONNX_OPERATOR_ATTRIBUTE = 34 - INVALID_ONNX_MODEL = 35 + # Code 33 was removed. + INVALID_OPERATOR_ATTRIBUTE = 34 + INVALID_INPUT_MODEL = 35 CONVERSION_IMPOSSIBLE = 41 - SHAPE_INFERENCE_ERROR = 42 + # Code 42 was removed. IO_PRESERVATION_ERROR = 43 INVALID_INPUT = 51 @@ -142,8 +142,6 @@ class BasicLoggingContext(LoggingContext): """ GLOBAL = LoggingContext("global") - SHAPE_INFERENCE = LoggingContext("shape_inference") - ONNX_PARSER = LoggingContext("onnx_parser") OPERATOR_CONVERSION = LoggingContext("operator_conversion") TFLITE_GENERATOR = LoggingContext("tflite_generator") QDQ_QUANTIZER = LoggingContext("qdq_quantizer") @@ -151,7 +149,7 @@ class BasicLoggingContext(LoggingContext): class NodeLoggingContext(LoggingContext): """ - ONNX node specific context. Logs reported within this context are related to node with index 'node_id'. + ExecuTorch node specific context. Logs reported within this context are related to node with index 'node_id'. """ def __init__(self, node_id): @@ -213,7 +211,7 @@ def _get_node_error(self, node_id: int, dict_item: str) -> Code | str | None: Return first error log item that belong to node with id 'node_id'. If no error is present None is returned instead. - :param node_id: ONNX node id. + :param node_id: ExecuTorch node id. :param dict_item: Dictionary item to return from `log` :return: Error code or None if there's no error related to node. """ @@ -230,7 +228,7 @@ def get_node_error_code(self, node_id: int) -> Code | None: Return first error code that belong to node with id 'node_id'. If no error is present None is returned instead. - :param node_id: ONNX node id. + :param node_id: ExecuTorch node id. :return: Error code or None if there's no error related to node. """ @@ -241,7 +239,7 @@ def get_node_error_message(self, node_id: int) -> str | None: Return first error message that belong to node with id 'node_id'. If no error is present None is returned instead. - :param node_id: ONNX node id + :param node_id: ExecuTorch node id :return: Error message or None if there is no error related to node. """ @@ -256,7 +254,7 @@ class loggingContext: Context manager used to nest logging contexts. 
Usage: with loggingContext(BasicLoggingContext.GLOBAL): - with loggingContext(BasicLoggingContext.ONNX_PARSER): + with loggingContext(BasicLoggingContext.OPERATOR_CONVERSION): logger.i("My log") # this log is automatically assigned to both parent contexts """ diff --git a/backends/nxp/backend/ir/tensor_formatting.py b/backends/nxp/backend/ir/tensor_formatting.py index aab22c3c368..71b697a0eba 100644 --- a/backends/nxp/backend/ir/tensor_formatting.py +++ b/backends/nxp/backend/ir/tensor_formatting.py @@ -1,13 +1,12 @@ -# # Copyright 2023 Martin Pavella -# Copyright 2023-2024 NXP +# Copyright 2023-2025 NXP # # License: MIT # See the LICENSE_MIT for more details. # from enum import Enum -from executorch.backends.nxp.backend.node_format_inference import NodeFormat +from executorch.backends.nxp.backend.node_format import NodeFormat class TensorFormat(Enum): @@ -26,7 +25,7 @@ class TensorFormat(Enum): TRANSPOSE_CONV_2D_WEIGHT_FORMAT = 13 # No special format (matrices, vectors, shapes etc.). All tensors with the FORMATLESS format MUST have EXACTLY - # the same shape and data in the TFLite model and in the ONNX model. + # the same shape and data in the NeutronIR model and in the ExecuTorch model. FORMATLESS = 20 NONE = 30 # Format has not been identified @@ -39,8 +38,10 @@ def is_channels_last(self) -> bool: @staticmethod def from_node_format(node_format: NodeFormat): - if node_format.is_channels_first(): - return TensorFormat.CHANNELS_LAST + if node_format == NodeFormat.CHANNELS_FIRST: + return TensorFormat.CHANNELS_LAST # Format is swapped. + elif node_format == NodeFormat.CHANNELS_LAST: + return TensorFormat.CHANNELS_FIRST # Format is swapped. elif node_format == NodeFormat.FORMATLESS: return TensorFormat.FORMATLESS else: @@ -48,8 +49,21 @@ def from_node_format(node_format: NodeFormat): def to_node_format(self): if self == TensorFormat.CHANNELS_LAST: - return NodeFormat.CHANNELS_FIRST + return NodeFormat.CHANNELS_FIRST # Format is swapped. elif self == TensorFormat.FORMATLESS: return NodeFormat.FORMATLESS + elif self == TensorFormat.CHANNELS_FIRST: + return NodeFormat.CHANNELS_LAST # Format is swapped. else: return NodeFormat.NONE + + def to_equal_node_format(self): + match self: + case TensorFormat.CHANNELS_FIRST: + return NodeFormat.CHANNELS_FIRST + case TensorFormat.CHANNELS_LAST: + return NodeFormat.CHANNELS_LAST + case TensorFormat.FORMATLESS: + return NodeFormat.FORMATLESS + case _: + return NodeFormat.NONE diff --git a/backends/nxp/backend/ir/tflite_generator/tflite_model.py b/backends/nxp/backend/ir/tflite_generator/tflite_model.py index a9384861178..76a50a2e177 100755 --- a/backends/nxp/backend/ir/tflite_generator/tflite_model.py +++ b/backends/nxp/backend/ir/tflite_generator/tflite_model.py @@ -1,6 +1,5 @@ -# # Copyright 2023 Martin Pavella -# Copyright 2023-2024 NXP +# Copyright 2023-2025 NXP # # License: MIT # See the LICENSE_MIT for more details. @@ -272,8 +271,7 @@ def is_per_tensor(self) -> bool: return False def gen_tflite(self, builder: fb.Builder): - # Sometimes 1D per-tensor quantized tensors can have quantized_dimension != 0 - # (residue from badly defined ONNX models). This would cause TFLite inference to crash. + # Sometimes 1D per-tensor quantized tensors can have quantized_dimension != 0. if not self.is_per_channel(): self.quantized_dimension = 0 @@ -513,7 +511,7 @@ class Operator(meta.TFLiteObject): tmp_outputs: List[Tensor] tmp_version: int # OperatorConverter uses this to assign the corresponding operator code with correct version. 
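Aside (illustrative, not part of the diff): the `from_node_format` / `to_node_format` changes in tensor_formatting.py above encode the convention that a channels-first (NCHW) ExecuTorch node corresponds to a channels-last (NHWC) NeutronIR tensor, so the mapping is intentionally swapped, while the new `to_equal_node_format` maps each name to itself. A stand-in sketch using hypothetical enums, not the real classes:

from enum import Enum

class NodeFormat(Enum):
    CHANNELS_FIRST = 0
    CHANNELS_LAST = 1
    FORMATLESS = 2
    NONE = 3

class TensorFormat(Enum):
    CHANNELS_FIRST = 0
    CHANNELS_LAST = 1
    FORMATLESS = 2
    NONE = 3

def from_node_format(nf: NodeFormat) -> TensorFormat:
    # Channels-first nodes become channels-last tensors and vice versa ("format is swapped").
    swapped = {
        NodeFormat.CHANNELS_FIRST: TensorFormat.CHANNELS_LAST,
        NodeFormat.CHANNELS_LAST: TensorFormat.CHANNELS_FIRST,
        NodeFormat.FORMATLESS: TensorFormat.FORMATLESS,
    }
    return swapped.get(nf, TensorFormat.NONE)  # NONE default assumed for this sketch

assert from_node_format(NodeFormat.CHANNELS_FIRST) is TensorFormat.CHANNELS_LAST
assert from_node_format(NodeFormat.FORMATLESS) is TensorFormat.FORMATLESS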
- # If `True`, this is an extra operator added during conversion. It was not present in the original ONNX model. + # If `True`, this is an extra operator added during conversion. It was not present in the original input model. tmp_added_extra: bool def __init__( diff --git a/backends/nxp/backend/ir/tflite_optimizer/operator_rules.py b/backends/nxp/backend/ir/tflite_optimizer/operator_rules.py index 253dc9c69a1..e861eff0d18 100755 --- a/backends/nxp/backend/ir/tflite_optimizer/operator_rules.py +++ b/backends/nxp/backend/ir/tflite_optimizer/operator_rules.py @@ -1,4 +1,4 @@ -# Copyright 2024 NXP +# Copyright 2024-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -100,23 +100,3 @@ def __call__( operator_is_type(preceding_op, self.single_preceding_op_type, builder) for preceding_op in preceding_ops ) - - -@dataclass -class WasNotInTheOriginalONNXModel(OpRule): - """Assures that this operator wasn't created by converting an ONNX operator from the original model, but instead - was added extra in order to convert a different operator. - - This rule is currently only satisfied for operators added by ModelBuilder methods `create_..._before()` and - `create_..._after()`. - """ - - def __call__( - self, - op: tflite_model.Operator, - tensor_map: NameToTensorMap, - input_to_ops_map: InputTensorToOpsMap, - output_to_op_map: OutputTensorToOpMap, - builder: "model_builder.ModelBuilder", - ) -> bool: - return op.tmp_added_extra diff --git a/backends/nxp/backend/ir/tflite_optimizer/optimizations/base_optimization.py b/backends/nxp/backend/ir/tflite_optimizer/optimizations/base_optimization.py index 6001ca961b8..18e397cc1bd 100755 --- a/backends/nxp/backend/ir/tflite_optimizer/optimizations/base_optimization.py +++ b/backends/nxp/backend/ir/tflite_optimizer/optimizations/base_optimization.py @@ -12,16 +12,21 @@ InputTensorToOpsMap, OutputTensorToOpMap, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec class BaseOptimization(ABC): _builder: "model_builder.ModelBuilder" def __init__( - self, builder: "model_builder.ModelBuilder", conversion_config: ConversionConfig + self, + builder: "model_builder.ModelBuilder", + conversion_config: ConversionConfig, + neutron_target_spec: NeutronTargetSpec, ): self._builder = builder self._conversion_config = conversion_config + self.neutron_target_spec = neutron_target_spec def _create_tensor_to_operator_dictionaries( self, diff --git a/backends/nxp/backend/ir/tflite_optimizer/optimizations/combine_hard_sigmoid_and_mul_to_hard_swish.py b/backends/nxp/backend/ir/tflite_optimizer/optimizations/combine_hard_sigmoid_and_mul_to_hard_swish.py deleted file mode 100755 index dddabfe87f1..00000000000 --- a/backends/nxp/backend/ir/tflite_optimizer/optimizations/combine_hard_sigmoid_and_mul_to_hard_swish.py +++ /dev/null @@ -1,256 +0,0 @@ -# Copyright 2024 NXP -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import ( - BuiltinOperator, -) -from executorch.backends.nxp.backend.ir.lib.tflite.TensorType import TensorType -from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model -from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options.hard_swish_options import ( - HardSwish, -) -from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.base_optimization import ( - BaseOptimization, -) -from executorch.backends.nxp.backend.ir.tflite_optimizer.pattern_matcher import ( - OneOf, - Op, - PatternMatcher, -) -from executorch.backends.nxp.backend.ir.tflite_optimizer.tensor_rules import ( - RuleOr, - TensorHasNConsumers, - TensorHasStaticValue, - TensorHasType, - TensorsAreQuantized, - TensorsHaveOneConsumer, - TensorsHaveType, -) - - -class CombineHardSigmoidAndMulIntoHardSwish(BaseOptimization): - - def __call__(self) -> bool: - made_changes = self._combine_float_variant() - made_changes |= self._combine_quantized_variant() - - return made_changes - - def _combine_float_variant(self) -> bool: - """Fuse some operators in the following pattern. The ops `Mul`, `Add` `Minimum` and `Relu` compute the - `HardSigmoid` operation, as there is no `HardSigmoid` operator in TFLite. - - ┌─────┴─────┐ `x` - ┌──▼──┐ │ - 1/6 ──► Mul │ │ - └──┬──┘ │ - ┌──▼──┐ │ - 1/2 ──► Add │ │ │ - └──┬──┘ │ ┌─────▼─────┐ - ┌────▼────┐ │ ─────► │ HardSwish │ - 1 ──► Minimum │ │ └─────┬─────┘ - └────┬────┘ │ - ┌──▼───┐ │ - │ Relu │ │ - └──┬───┘ │ - └───┐ ┌───┘ - ┌▼───▼┐ - │ Mul │ - └──┬──┘ - """ - - matcher = PatternMatcher( - self._builder, - [ - Op(["Mul"], ["x", "alpha"], ["mul_o"]), - OneOf( - [ - Op(["Add"], ["mul_o", "beta"], ["add_o"]), - Op(["Add"], ["beta", "mul_o"], ["add_o"]), - ] - ), - OneOf( - [ - Op(["Minimum"], ["add_o", "one"], ["min_o"]), - Op(["Minimum"], ["one", "add_o"], ["min_o"]), - ] - ), - Op(["Relu"], ["min_o"], ["relu_o"]), - OneOf( - [ - Op(["Mul"], ["x", "relu_o"], ["y"]), - Op(["Mul"], ["relu_o", "x"], ["y"]), - ] - ), - ], - [ - TensorHasNConsumers("x", 2), - TensorsHaveOneConsumer(["mul_o", "add_o", "min_o", "relu_o"]), - TensorHasStaticValue("alpha", 1 / 6), - TensorHasStaticValue("beta", 0.5), - TensorHasStaticValue("one", 1), - # `HardSwishConverter` and `HardSigmoidConverter` both only support float32. - TensorHasType("x", TensorType.FLOAT32), - ], - ) - - # The mapped operator (value) will be inserted into the model later, at the position of the `key` operator. - to_add: dict[tflite_model.Operator, tflite_model.Operator] = {} - to_remove = [] - for pattern_ops, tensor_map, _, _ in matcher.match_patterns(): - x, y = tensor_map["x"], tensor_map["y"] - hard_swish = tflite_model.Operator( - builtin_options=HardSwish(), - opcode_index=self._builder.op_code_index_for_op_type( - BuiltinOperator.HARD_SWISH - ), - ) - hard_swish.tmp_inputs = [x] - hard_swish.tmp_outputs = [y] - - to_add[pattern_ops[0]] = hard_swish - - to_remove.extend(pattern_ops) - - ops = self._builder.get_operators() - for k, v in to_add.items(): - idx = ops.index(k) - ops.insert(idx, v) - - for op in to_remove: - ops.remove(op) - - return len(to_remove) != 0 - - def _combine_quantized_variant(self) -> bool: - """Fuse some operators in the following pattern. The ops `Mul`, `Add` `Minimum` and `Relu` compute the - `HardSigmoid` operation, as there is no `HardSigmoid` operator in TFLite. - - The following pattern arises from using the `onnx2quant` on a model with `HardSwish`. 
The quantizer always - runs a pre-processing step which splits the ONNX `HardSwish` into `HardSigmoid` and `Mul`. It seems like it - cannot be turned off. Therefore, we cannot add QDQ quantization of `HardSwish`. But since `HardSigmoid` - gets converted to multiple TFLite operators, we also cannot really add QDQ quantization for that operator. - This means that `HardSwish` will never get fully quantized by the `onnx2quant`, and the following pattern - will be created. - We can, however, convert the entire pattern into a quantized `HardSwish` using this optimization. - - │ (u)int8 `x` - ┌─────▼──────┐ - │ Dequantize │ - └─────┬──────┘ - ┌─────┴─────┐ float32 - ┌──▼──┐ │ - 1/6 ──► Mul │ │ - └──┬──┘ │ - ┌──▼──┐ │ - 1/2 ──► Add │ │ - └──┬──┘ │ - ┌────▼────┐ │ - 1 ──► Minimum │ │ │ (u)int8 `x` - └────┬────┘ │ ┌─────▼─────┐ - ┌──▼───┐ │ ─────► │ HardSwish │ - │ Relu │ │ └─────┬─────┘ - └──┬───┘ │ │ (u)int8 `y` - ┌────▼─────┐ │ - │ Quantize │ │ - └────┬─────┘ │ - ┌─────▼──────┐ │ - │ Dequantize │ │ - └─────┬──────┘ │ - └───┐ ┌───┘ - ┌▼───▼┐ - │ Mul │ - └──┬──┘ - │ float32 - ┌────▼─────┐ - │ Quantize │ - └────┬─────┘ - │ (u)int8 `y` - """ - matcher = PatternMatcher( - self._builder, - [ - Op(["Dequantize"], ["x"], ["deq1_o"]), - OneOf( - [ - Op(["Mul"], ["deq1_o", "alpha"], ["mul1_o"]), - Op(["Mul"], ["alpha", "deq1_o"], ["mul1_o"]), - ] - ), - OneOf( - [ - Op(["Add"], ["mul1_o", "beta"], ["add_o"]), - Op(["Add"], ["beta", "mul1_o"], ["add_o"]), - ] - ), - OneOf( - [ - Op(["Minimum"], ["add_o", "one"], ["min_o"]), - Op(["Minimum"], ["one", "add_o"], ["min_o"]), - ] - ), - Op(["Relu"], ["min_o"], ["relu_o"]), - Op(["Quantize"], ["relu_o"], ["quant1_o"]), - Op(["Dequantize"], ["quant1_o"], ["deq2_o"]), - OneOf( - [ - Op(["Mul"], ["deq1_o", "deq2_o"], ["mul2_o"]), - Op(["Mul"], ["deq2_o", "deq1_o"], ["mul2_o"]), - ] - ), - Op(["Quantize"], ["mul2_o"], ["y"]), - ], - [ - TensorHasNConsumers("deq1_o", 2), - TensorsHaveOneConsumer( - [ - "mul1_o", - "add_o", - "min_o", - "relu_o", - "quant1_o", - "deq2_o", - "mul2_o", - ] - ), - TensorHasStaticValue("alpha", 1 / 6), - TensorHasStaticValue("beta", 0.5), - TensorHasStaticValue("one", 1), - TensorHasType("deq1_o", TensorType.FLOAT32), - TensorsAreQuantized(["x", "y"]), - RuleOr( - TensorsHaveType(["x", "y"], TensorType.INT8), - TensorsHaveType(["x", "y"], TensorType.UINT8), - ), - ], - ) - - # The mapped operator (value) will be inserted into the model later, at the position of the `key` operator. 
- to_add: dict[tflite_model.Operator, tflite_model.Operator] = {} - to_remove = [] - for pattern_ops, tensor_map, _, _ in matcher.match_patterns(): - x, y = tensor_map["x"], tensor_map["y"] - hard_swish = tflite_model.Operator( - builtin_options=HardSwish(), - opcode_index=self._builder.op_code_index_for_op_type( - BuiltinOperator.HARD_SWISH - ), - ) - hard_swish.tmp_inputs = [x] - hard_swish.tmp_outputs = [y] - - to_add[pattern_ops[0]] = hard_swish - - to_remove.extend(pattern_ops) - - ops = self._builder.get_operators() - for k, v in to_add.items(): - idx = ops.index(k) - ops.insert(idx, v) - - for op in to_remove: - ops.remove(op) - - return len(to_remove) != 0 diff --git a/backends/nxp/backend/ir/tflite_optimizer/optimizations/fuse_activation_functions.py b/backends/nxp/backend/ir/tflite_optimizer/optimizations/fuse_activation_functions.py deleted file mode 100755 index 6b657c4d5b1..00000000000 --- a/backends/nxp/backend/ir/tflite_optimizer/optimizations/fuse_activation_functions.py +++ /dev/null @@ -1,235 +0,0 @@ -# Copyright 2024 NXP -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from executorch.backends.nxp.backend.ir import logger -from executorch.backends.nxp.backend.ir.lib.tflite.ActivationFunctionType import ( - ActivationFunctionType, -) -from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import ( - BuiltinOperator, -) -from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model -from executorch.backends.nxp.backend.ir.tflite_optimizer.graph_utils import ( - operator_is_type, -) -from executorch.backends.nxp.backend.ir.tflite_optimizer.operator_rules import ( - NoFusedActivationFunction, -) -from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.base_optimization import ( - BaseOptimization, -) -from executorch.backends.nxp.backend.ir.tflite_optimizer.pattern_matcher import ( - Op, - PatternMatcher, -) -from executorch.backends.nxp.backend.ir.tflite_optimizer.tensor_rules import ( - TensorHasOneConsumer, -) - - -class FuseActivationFunctions(BaseOptimization): - ops_with_fused_activation_function = [ - "Conv2D", - "Conv3D", - "DepthwiseConv2D", - "TransposeConv", - "MaxPool2D", - "AveragePool2D", - "SVDF", - "FullyConnected", - "Add", - "Mul", - "Sub", - "Div", - # 'Concatenation', # currently disabled - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/concatenation.cc#L139 - # 'L2Norm', # currently disabled - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/l2norm.cc#L72 - # LSTM operators will always already have fused activation functions. They are assigned in `convert_lstm.py`. - # 'LSTM', 'UnidirectionalSequenceLSTM', 'BidirectionalSequenceLSTM' - # RNN operators will always already have fused activation functions. They are assigned in `convert_rnn.py`. 
- # 'RNN', 'SequenceRNN', 'BidirectionalSequenceRNN', - ] - - activation_functions = ["Relu", "ReluN1To1", "Relu6", "Tanh", "Sign"] - - supported_activations_for_op: dict[ - BuiltinOperator, list[ActivationFunctionType] - ] = { - BuiltinOperator.CONV_2D: [ - ActivationFunctionType.RELU, - ActivationFunctionType.RELU_N1_TO_1, - ActivationFunctionType.RELU6, - ], - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/conv.cc#L912 - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/kernel_util.h#L285-L300 - BuiltinOperator.CONV_3D: [ - ActivationFunctionType.RELU, - ActivationFunctionType.RELU_N1_TO_1, - ActivationFunctionType.RELU6, - ], - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/conv3d.cc#L213 - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/kernel_util.h#L285-L300 - BuiltinOperator.DEPTHWISE_CONV_2D: [ - ActivationFunctionType.RELU, - ActivationFunctionType.RELU_N1_TO_1, - ActivationFunctionType.RELU6, - ], - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/depthwise_conv.cc#L307 - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/kernel_util.h#L285-L300 - BuiltinOperator.TRANSPOSE_CONV: [ - ActivationFunctionType.RELU, - ActivationFunctionType.RELU_N1_TO_1, - ActivationFunctionType.RELU6, - ], - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/transpose_conv.cc#L516 - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/kernel_util.h#L285-L300 - BuiltinOperator.MAX_POOL_2D: [ - ActivationFunctionType.RELU, - ActivationFunctionType.RELU_N1_TO_1, - ActivationFunctionType.RELU6, - ], - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/pooling.cc#L247 - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/kernel_util.h#L285-L300 - BuiltinOperator.AVERAGE_POOL_2D: [ - ActivationFunctionType.RELU, - ActivationFunctionType.RELU_N1_TO_1, - ActivationFunctionType.RELU6, - ], - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/pooling.cc#L124 - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/kernel_util.h#L285-L300 - BuiltinOperator.FULLY_CONNECTED: [ - ActivationFunctionType.RELU, - ActivationFunctionType.RELU_N1_TO_1, - ActivationFunctionType.RELU6, - ], - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/fully_connected.cc#L627-L630 - BuiltinOperator.ADD: [ - ActivationFunctionType.RELU, - ActivationFunctionType.RELU_N1_TO_1, - ActivationFunctionType.RELU6, - ], - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/add.cc#L246 - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/kernel_util.h#L285-L300 - BuiltinOperator.MUL: [ - ActivationFunctionType.RELU, - ActivationFunctionType.RELU_N1_TO_1, - ActivationFunctionType.RELU6, - ], - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/mul.cc#L159 - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/kernel_util.h#L285-L300 - BuiltinOperator.SUB: [ - ActivationFunctionType.RELU, - ActivationFunctionType.RELU_N1_TO_1, - ActivationFunctionType.RELU6, - ], - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/sub.cc#L306 - # 
https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/kernel_util.h#L285-L300 - BuiltinOperator.DIV: [ - ActivationFunctionType.RELU, - ActivationFunctionType.RELU_N1_TO_1, - ActivationFunctionType.RELU6, - ], - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/div.cc#L180 - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/kernel_util.h#L285-L300 - BuiltinOperator.SVDF: [ActivationFunctionType.RELU], - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/svdf.cc#L394 - BuiltinOperator.RNN: [ - ActivationFunctionType.RELU, - ActivationFunctionType.RELU_N1_TO_1, - ActivationFunctionType.RELU6, - ActivationFunctionType.TANH, - ActivationFunctionType.SIGN_BIT, - ], - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/basic_rnn.cc#L222 - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/internal/kernel_utils.cc#L71 - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/internal/tensor_utils.h#L58-L77 - BuiltinOperator.UNIDIRECTIONAL_SEQUENCE_RNN: [ - ActivationFunctionType.RELU, - ActivationFunctionType.RELU_N1_TO_1, - ActivationFunctionType.RELU6, - ActivationFunctionType.TANH, - ActivationFunctionType.SIGN_BIT, - ], - # https://github.com/tensorflow/tensorflow/blob/6887368d6d46223f460358323c4b76d61d1558a8/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc#L239 - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/internal/kernel_utils.cc#L71 - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/internal/tensor_utils.h#L58-L77 - BuiltinOperator.BIDIRECTIONAL_SEQUENCE_RNN: [ - ActivationFunctionType.RELU, - ActivationFunctionType.RELU_N1_TO_1, - ActivationFunctionType.RELU6, - ActivationFunctionType.TANH, - ActivationFunctionType.SIGN_BIT, - ], - # https://github.com/tensorflow/tensorflow/blob/6887368d6d46223f460358323c4b76d61d1558a8/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc#L433 - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/internal/kernel_utils.cc#L71 - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/internal/tensor_utils.h#L58-L77 - } - - ops_that_need_equal_io_quantization = [ - # Documented restrictions from https://www.tensorflow.org/lite/performance/quantization_spec - BuiltinOperator.AVERAGE_POOL_2D, - BuiltinOperator.MAX_POOL_2D, - BuiltinOperator.CONCATENATION, - ] - - def _act_fun_type_for_op(self, op: tflite_model.Operator) -> ActivationFunctionType: - if operator_is_type(op, "Relu", self._builder): - return ActivationFunctionType.RELU - elif operator_is_type(op, "ReluN1To1", self._builder): - return ActivationFunctionType.RELU_N1_TO_1 - elif operator_is_type(op, "Relu6", self._builder): - return ActivationFunctionType.RELU6 - elif operator_is_type(op, "Tanh", self._builder): - return ActivationFunctionType.TANH - elif operator_is_type(op, "Sign", self._builder): - return ActivationFunctionType.SIGN_BIT - - def __call__(self) -> bool: - matcher = PatternMatcher( - self._builder, - [ - Op( - self.ops_with_fused_activation_function, - ["x"], - ["x1"], - [NoFusedActivationFunction()], - ), - Op(self.activation_functions, ["x1"], ["y"]), - ], - [TensorHasOneConsumer("x1")], - ) - - to_remove = [] - for [leading_op, act_fun_op], tensor_map, _, _ in matcher.match_patterns(): - builtin_leading_op = leading_op.builtin_options.operator_type - 
logger.internal_assert( - builtin_leading_op in self.supported_activations_for_op.keys(), - f"FuseActivationFunctions: supported activations for operator `{builtin_leading_op}`" - "are not known.", - ) - - act_fun = self._act_fun_type_for_op(act_fun_op) - if act_fun not in self.supported_activations_for_op[builtin_leading_op]: - # The leading op doesn't support this activation function. - continue - - x, y = tensor_map["x"], tensor_map["y"] - if ( - x.quantization != y.quantization - and builtin_leading_op in self.ops_that_need_equal_io_quantization - ): - # The fusion would result in different input and output quantization of `leading_op`, which would cause - # runtime issues for that particular operator. - continue - - leading_op.builtin_options.fused_activation_function = act_fun - leading_op.tmp_outputs[0] = act_fun_op.tmp_outputs[0] - to_remove.append(act_fun_op) - - for op in to_remove: - self._builder.get_operators().remove(op) - - return len(to_remove) != 0 diff --git a/backends/nxp/backend/ir/tflite_optimizer/optimizations/fuse_fully_connected_and_add_operators.py b/backends/nxp/backend/ir/tflite_optimizer/optimizations/fuse_fully_connected_and_add_operators.py deleted file mode 100755 index b6fd5849551..00000000000 --- a/backends/nxp/backend/ir/tflite_optimizer/optimizations/fuse_fully_connected_and_add_operators.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2024 NXP -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from executorch.backends.nxp.backend.ir.lib.tflite.TensorType import TensorType -from executorch.backends.nxp.backend.ir.tflite_optimizer.operator_rules import ( - NoFusedActivationFunction, -) -from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.base_optimization import ( - BaseOptimization, -) -from executorch.backends.nxp.backend.ir.tflite_optimizer.pattern_matcher import ( - OneOf, - Op, - PatternMatcher, -) -from executorch.backends.nxp.backend.ir.tflite_optimizer.tensor_rules import ( - RuleAnd, - RuleIf, - RuleOr, - TensorDimensionsMatch, - TensorHasDimensionOfSize, - TensorHasOneConsumer, - TensorHasRank, - TensorHasType, - TensorIsQuantized, -) - - -class FuseFullyConnectedAndAddOperators(BaseOptimization): - - def __call__(self) -> bool: - """ - FullyConnected -> Add sequence can handle more complicated shapes than just FullyConnected with bias - (due to shape broadcasting). - The bias can have shape [N] or [1, N], where N is the first dimension of the FC weights tensor. - It could also have shape [1, ..., 1, N], but then the TFLite FullyConnected removes the leading ones, - even if 'keep_num_dims' is True. In ONNX, the output tensor has the leading ones, - In this case, a Reshape would have to be added, so we do not perform the fusion. - - # https://github.com/tensorflow/tensorflow/blob/v2.15.0/tensorflow/lite/kernels/fully_connected.cc#L398 - """ - matcher = PatternMatcher( - self._builder, - [ - # Require exactly 2 inputs. 
- Op( - ["FullyConnected"], ["x", "w"], ["y"], [NoFusedActivationFunction()] - ), - OneOf([Op(["Add"], ["y", "b"]), Op(["Add"], ["b", "y"])]), - ], - [ - TensorHasOneConsumer("y"), - TensorHasRank("w", 2), - RuleOr( - TensorHasRank("b", 1), - RuleAnd(TensorHasRank("b", 2), TensorHasDimensionOfSize("b", 0, 1)), - ), - TensorDimensionsMatch("w", 0, "b", -1), - RuleIf(TensorIsQuantized("x"), TensorHasType("b", TensorType.INT32)), - ], - ) - - to_remove = [] - for (fc, add), tensor_map, _, _ in matcher.match_patterns(): - b = tensor_map["b"] - fc.tmp_inputs.append(b) - - # Remove the 'Add' operator. - fc.tmp_outputs[0] = add.tmp_outputs[0] - fc.builtin_options.fused_activation_function = ( - add.builtin_options.fused_activation_function - ) - to_remove.append(add) - - for op in to_remove: - self._builder.get_operators().remove(op) - - return len(to_remove) != 0 diff --git a/backends/nxp/backend/ir/tflite_optimizer/optimizations/keep_one_empty_buffer.py b/backends/nxp/backend/ir/tflite_optimizer/optimizations/keep_one_empty_buffer.py deleted file mode 100755 index 9809719fad4..00000000000 --- a/backends/nxp/backend/ir/tflite_optimizer/optimizations/keep_one_empty_buffer.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2024 NXP -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from executorch.backends.nxp.backend.ir.converter.tensor_utils import tensor_has_data -from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.base_optimization import ( - BaseOptimization, -) - - -class KeepOneEmptyBuffer(BaseOptimization): - - def __call__(self) -> bool: - """Create a single empty `Buffer` object and assign it to all tensors in the model that don't have static data. - :return: True, if any tensors had their buffer changed. Otherwise, False. - """ - - made_changes = False - empty_buffer = self._builder.get_first_empty_buffer() - - for t in self._builder.get_tensors().vector: - if tensor_has_data(t): - # The buffer of `t` is not empty. - continue - - if t.tmp_buffer == empty_buffer: - # Already optimized. - continue - - if t.is_variable: - # The data of the tensor will change at runtime, so it shouldn't share the buffer with other tensors. - continue - - # It's safe to replace the buffer. - t.tmp_buffer = empty_buffer - made_changes = True - - return made_changes diff --git a/backends/nxp/backend/ir/tflite_optimizer/optimizations/move_relu_before_concat.py b/backends/nxp/backend/ir/tflite_optimizer/optimizations/move_relu_before_concat.py deleted file mode 100755 index 4d10b7c80ae..00000000000 --- a/backends/nxp/backend/ir/tflite_optimizer/optimizations/move_relu_before_concat.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2024 NXP -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -from collections import defaultdict -from copy import deepcopy - -from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model -from executorch.backends.nxp.backend.ir.tflite_optimizer.operator_rules import ( - AllInputsComeFrom, -) -from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.base_optimization import ( - BaseOptimization, -) -from executorch.backends.nxp.backend.ir.tflite_optimizer.pattern_matcher import ( - Op, - PatternMatcher, -) -from executorch.backends.nxp.backend.ir.tflite_optimizer.tensor_rules import ( - TensorHasOneConsumer, - TensorsHaveSameQuantization, -) - - -class MoveActivationBeforeConcatenation(BaseOptimization): - """ - Move some operators around in the following pattern. - This is a common pattern that emerges from the conversion of separable convolutions. - - │ │ │ │ - ┌───▼────┐ ┌───▼────┐ ┌───▼────┐ ┌───▼────┐ - │ Conv2D │ ... │ Conv2D │ │ Conv2D │ ... │ Conv2D │ - └───┬────┘ └───┬────┘ └───┬────┘ └───┬────┘ - └──┐ ┌──┘ │ │ - ┌──▼──────────▼─┐ ┌──▼───┐ ┌──▼───┐ - │ Concatenation │ ─────► │ Relu │ ... │ Relu │ - └───────┬───────┘ └──┬───┘ └──┬───┘ - │ 'x' └──┐ ┌──┘ - ┌──▼───┐ ┌──▼──────────▼─┐ - │ Relu │ │ Concatenation │ - └──┬───┘ └───────┬───────┘ - │ 'y' │ - """ - - activations = ["Relu", "ReluN1To1", "Relu6", "Tanh", "Sign"] - - def __call__(self) -> bool: - matcher = PatternMatcher( - self._builder, - [ - Op(["Concatenation"], None, ["x"], [AllInputsComeFrom("Conv2D")]), - Op(self.activations, ["x"], ["y"]), - ], - [ - TensorHasOneConsumer("x"), - # If the activation function is not changing the quantization parameters, it can be moved without - # messing with the quantization elsewhere. - TensorsHaveSameQuantization(["x", "y"]), - ], - ) - - to_remove = [] - - # Mapping an operator to a list of operators. These operators (value) will later be added into the TFLite - # model's `operators` in front of the specified operator (key). - to_add: dict[tflite_model.Operator, list[tflite_model.Operator]] = defaultdict( - lambda: [] - ) - - for [concat, activation], _, _, _ in matcher.match_patterns(): - new_concat_inputs = [] - for concat_input in concat.tmp_inputs: - # Create a new operator for the activation function. - new_activation = deepcopy(activation) - new_activation.tmp_inputs = [concat_input] - new_activation_output = self._builder.duplicate_tensor(concat_input) - new_activation.tmp_outputs = [new_activation_output] - - to_add[concat].append( - new_activation - ) # Insert the new activation into the model later. - - new_concat_inputs.append( - new_activation_output - ) # Connect the activation with the `Concatenation`. - - concat.tmp_inputs = new_concat_inputs - - # Tensor rule ensures that only the activation functions is using the output of the `Concatenation`. - # It is safe to bypass. - concat.tmp_outputs[0] = activation.tmp_outputs[0] - to_remove.append(activation) - - operators = self._builder.get_operators() - - # Add the new activations into the model. - for concat, activations in to_add.items(): - idx = operators.index(concat) - for activation in activations: - operators.insert(idx, activation) - - # Remove the old activations. 
- for activation in to_remove: - operators.remove(activation) - - return len(to_remove) != 0 diff --git a/backends/nxp/backend/ir/tflite_optimizer/optimizations/permute_fully_connected_weights_after_reshape.py b/backends/nxp/backend/ir/tflite_optimizer/optimizations/permute_fully_connected_weights_after_reshape.py index 42eefc1ab56..ef76fad90de 100755 --- a/backends/nxp/backend/ir/tflite_optimizer/optimizations/permute_fully_connected_weights_after_reshape.py +++ b/backends/nxp/backend/ir/tflite_optimizer/optimizations/permute_fully_connected_weights_after_reshape.py @@ -1,4 +1,4 @@ -# Copyright 2024 NXP +# Copyright 2024-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -50,7 +50,7 @@ def __call__(self) -> bool: How it works: - The original model doesn't have the `Transpose`. It just has `Reshape` into `MatMul` (or `Gemm`...). - The `Transpose` is added, because the `Reshape` has a channels last input, which was originally - channels first (in the ONNX model), and so the 2D output of the `Reshape` would have the same data. + channels first (in the ExecuTorch model), and so the 2D output of the `Reshape` would have the same data. but at different locations. The `Transpose` makes the input channels first, which ensures correct output of the `Reshape`. - In the scenario in the graph above, it is possible to omit the `Transpose`, which causes the `Reshape` @@ -85,12 +85,12 @@ def __call__(self) -> bool: for (transpose, reshape, fc), tensor_map, _, _ in matcher.match_patterns(): # Make sure the `Transpose` is applying the expected permutation. y = tensor_map["y"] - to_onnx_perm = ( + to_executorch_perm = ( translator.create_channels_last_to_channels_first_permutation( y.shape.len() ) ) - if not np.allclose(to_onnx_perm, tensor_map["perm"].tmp_buffer.data): + if not np.allclose(to_executorch_perm, tensor_map["perm"].tmp_buffer.data): continue # The `Transpose` has an unexpected permutation. w = tensor_map["w"] diff --git a/backends/nxp/backend/ir/tflite_optimizer/optimizations/prune_transpose_operators.py b/backends/nxp/backend/ir/tflite_optimizer/optimizations/prune_transpose_operators.py index dc9ad9999b4..053e53d9df8 100755 --- a/backends/nxp/backend/ir/tflite_optimizer/optimizations/prune_transpose_operators.py +++ b/backends/nxp/backend/ir/tflite_optimizer/optimizations/prune_transpose_operators.py @@ -1,4 +1,4 @@ -# Copyright 2024 NXP +# Copyright 2024-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -24,10 +24,14 @@ TensorIsNotModelOutput, TensorsHaveData, ) +from executorch.backends.nxp.backend.neutron_operator_support import ( + transposition_is_supported_on_neutron, +) class FuseTransposeOperators(BaseOptimization): - """Remove some `Transpose` operators in the following pattern. + """Remove some `Transpose` operators in the following pattern. This is only done if the resulting permutation is + supported on Neutron. │ 'x' ┌─────▼─────┐ @@ -61,12 +65,27 @@ def __call__(self) -> bool: ) in matcher.match_patterns(): x = tensor_map["x"] perm1 = tensor_map["perm1"].tmp_buffer.data + combined_perms = [] # Remove the leading transpose. for second_transpose in following_transposes: # Combine the permutations for a new permutation of the second `Transpose`. 
perm2 = second_transpose.tmp_inputs[1].tmp_buffer.data - combined_perm = np.array(combine_permutations(perm1, perm2), np.int32) + combined_perms.append( + np.array(combine_permutations(perm1, perm2), np.int32) + ) + + if not all( + transposition_is_supported_on_neutron( + x.shape.vector, list(perm), self.neutron_target_spec + ) + for perm in combined_perms + ): + continue # Avoid creating an unsupported permutation. + + for second_transpose, combined_perm in zip( + following_transposes, combined_perms + ): second_transpose.tmp_inputs[1] = self._builder.create_tensor_for_data( combined_perm, "perm" ) diff --git a/backends/nxp/backend/ir/tflite_optimizer/optimizer.py b/backends/nxp/backend/ir/tflite_optimizer/optimizer.py index eb4ce6a5992..1a96422e377 100755 --- a/backends/nxp/backend/ir/tflite_optimizer/optimizer.py +++ b/backends/nxp/backend/ir/tflite_optimizer/optimizer.py @@ -11,21 +11,6 @@ from executorch.backends.nxp.backend.ir import logger from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig -from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.combine_hard_sigmoid_and_mul_to_hard_swish import ( - CombineHardSigmoidAndMulIntoHardSwish, -) -from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.fuse_activation_functions import ( - FuseActivationFunctions, -) -from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.fuse_fully_connected_and_add_operators import ( - FuseFullyConnectedAndAddOperators, -) -from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.keep_one_empty_buffer import ( - KeepOneEmptyBuffer, -) -from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.move_relu_before_concat import ( - MoveActivationBeforeConcatenation, -) from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.permute_fully_connected_weights_after_reshape import ( PermuteFullyConnectedWeightsAfterReshape, ) @@ -33,21 +18,15 @@ FuseTransposeOperators, RemoveIdentityTransposeOperators, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec class Optimization(Enum): - KEEP_ONE_EMPTY_BUFFER = 0 - FUSE_ACTIVATION_FUNCTIONS = 1 - FUSE_FULLY_CONNECTED_AND_ADD = 2 - FUSE_TRANSPOSE_OPERATORS = 5 REMOVE_IDENTITY_TRANSPOSE_OPERATORS = 6 PERMUTE_FULLY_CONNECTED_WEIGHTS_AFTER_RESHAPE = 12 - MOVE_ACTIVATION_BEFORE_CONCAT = 15 - COMBINE_HARD_SIGMOID_AND_MUL_INTO_HARD_SWISH = 16 - class Optimizer: """ @@ -72,33 +51,19 @@ def __init__( self, builder: "model_builder.ModelBuilder", # noqa F821 conversion_config: ConversionConfig, + neutron_target_spec: NeutronTargetSpec, ): self._builder = builder self.optimization_map = { - Optimization.KEEP_ONE_EMPTY_BUFFER: KeepOneEmptyBuffer( - builder, conversion_config - ), - Optimization.FUSE_ACTIVATION_FUNCTIONS: FuseActivationFunctions( - builder, conversion_config - ), - Optimization.FUSE_FULLY_CONNECTED_AND_ADD: FuseFullyConnectedAndAddOperators( - builder, conversion_config - ), Optimization.FUSE_TRANSPOSE_OPERATORS: FuseTransposeOperators( - builder, conversion_config + builder, conversion_config, neutron_target_spec ), Optimization.REMOVE_IDENTITY_TRANSPOSE_OPERATORS: RemoveIdentityTransposeOperators( - builder, conversion_config + builder, conversion_config, neutron_target_spec ), Optimization.PERMUTE_FULLY_CONNECTED_WEIGHTS_AFTER_RESHAPE: PermuteFullyConnectedWeightsAfterReshape( - builder, conversion_config - ), - Optimization.MOVE_ACTIVATION_BEFORE_CONCAT: MoveActivationBeforeConcatenation( - builder, 
conversion_config - ), - Optimization.COMBINE_HARD_SIGMOID_AND_MUL_INTO_HARD_SWISH: CombineHardSigmoidAndMulIntoHardSwish( - builder, conversion_config + builder, conversion_config, neutron_target_spec ), } diff --git a/backends/nxp/backend/neutron_converter_manager.py b/backends/nxp/backend/neutron_converter_manager.py index 99d7715ebf0..0e110ee9b9f 100644 --- a/backends/nxp/backend/neutron_converter_manager.py +++ b/backends/nxp/backend/neutron_converter_manager.py @@ -1,11 +1,23 @@ -# Copyright 2024 NXP +# Copyright 2024-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + import importlib +import logging +import multiprocessing import pkgutil -from executorch.backends.nxp.backend.ir.converter.node_converter import Target + +def convert_unsafe(neutron_converter, tflite_model, cctx, queue): + """ + Run neutron_converter on given tflite_model with compilation context cctx. + This routine is supposed to run in a separate process. + If properly finished, the output queue contains the converted model, + otherwise the neutron_converter exits and the output queue is empty. + """ + model_converted = neutron_converter.convertModel(list(tflite_model), cctx) + queue.put(model_converted) class NeutronConverterManager: @@ -14,44 +26,100 @@ class NeutronConverterManager: contains NeutronGraph nodes. """ - _supported_target_names = [Target.RT700.value] - - def convert( - self, tflite_model: bytes, target: str, neutron_converter_flavor: str - ) -> bytes: - # Neutron converter crashes if we provide invalid target -> verify. - if target not in self._supported_target_names: - raise RuntimeError( - f"Target '{target}' is not supported by NeutronConverterManager." - ) + def __init__( + self, + neutron_converter_flavor: str = "SDK_25_09", + ): neutron_converter_modules = [ module.name for module in pkgutil.iter_modules() if module.name.startswith("neutron_converter") + or module.name == "eiq_neutron_sdk" ] - requested_module_name = f"neutron_converter_{neutron_converter_flavor}" + if neutron_converter_flavor: + requested_module_name = f"neutron_converter_{neutron_converter_flavor}" + print( + "Warning: The use of converter flavors will be deprecated. Use empty string to select 'eiq_neutron_sdk' module." + ) + else: + requested_module_name = "eiq_neutron_sdk" + if requested_module_name not in neutron_converter_modules: if len(neutron_converter_modules) > 0: raise RuntimeError( - f"Neutron Converter module with flavor '{neutron_converter_flavor}' " + f"Neutron Converter module '{requested_module_name}' " f"not found. Available modules: {neutron_converter_modules}." ) else: raise RuntimeError( - f"Neutron Converter module with flavor '{neutron_converter_flavor}' " - f"not found. Install 'neutron_converter_[flavor]' Python package." + f"Neutron Converter module '{requested_module_name}' " + f"not found. Install 'eiq_neutron_sdk' or 'neutron_converter_[flavor]' Python package." 
) - neutron_converter = importlib.import_module( + self.neutron_converter = importlib.import_module( f"{requested_module_name}.neutron_converter" ) + self.neutron_library_utils = importlib.import_module( + f"{requested_module_name}.neutron_library_utils" + ) + + def get_converter(self): + return self.neutron_converter + + def get_library_utils(self): + return self.neutron_library_utils + + def verify_target(self, target: str): + if not self.neutron_library_utils.isNeutronTarget(target): + valid_targets = [ + target.name for target in self.neutron_library_utils.getNeutronTargets() + ] + raise ValueError( + f"Target `{target}` is not a valid target. Must be one of `{valid_targets}`." + ) + + def convert(self, tflite_model: bytes, target: str) -> bytes: + # Neutron converter crashes if we provide invalid target -> verify. + self.verify_target(target) - cctx = neutron_converter.CompilationContext() - cctx.targetOpts = neutron_converter.getNeutronTarget(target) - # New switch since Neutron Converter SDK_25.06 + cctx = self.neutron_converter.CompilationContext() + cctx.targetOpts = self.neutron_converter.getNeutronTarget(target) cctx.compilationOpts.minNumOpsPerGraph = 1 - model_converted = neutron_converter.convertModel(list(tflite_model), cctx) + cctx.compilationOpts.excludeGraphPasses = ( + "HoistSliceAboveTranspose,MergeTranspose" + ) + + # Try to use multiprocessing for isolation, but fall back to direct execution + # if the environment doesn't support it (e.g., in sandcastle/build environments) + try: + logger = multiprocessing.log_to_stderr() + logger.setLevel(logging.WARNING) + queue = multiprocessing.Manager().Queue() + + process = multiprocessing.Process( + target=convert_unsafe, + args=(self.neutron_converter, tflite_model, cctx, queue), + ) + process.start() + process.join() # waits until the subprocess is complete + + if queue.empty(): # signals the unsafe task did not run till the end + raise RuntimeError( + f"Neutron converter module terminated unexpectedly with exit code {process.exitcode}" + ) + + model_converted = queue.get() + process.close() + except (EOFError, OSError) as e: + # Multiprocessing failed (likely due to environment restrictions) + # Fall back to direct execution + logging.warning( + f"Multiprocessing not available ({e}), running neutron converter directly" + ) + model_converted = self.neutron_converter.convertModel( + list(tflite_model), cctx + ) return bytes(model_converted) diff --git a/backends/nxp/backend/neutron_operator_support.py b/backends/nxp/backend/neutron_operator_support.py new file mode 100644 index 00000000000..cdb46870b2e --- /dev/null +++ b/backends/nxp/backend/neutron_operator_support.py @@ -0,0 +1,79 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec + + +def is_tensor_invariant_permutation( + input_shape: list[int], permutation: list[int] +) -> bool: + def input_dim_is_not_one(index): + return input_shape[index] != 1 + + new_permutation = list(filter(input_dim_is_not_one, permutation)) + + return new_permutation == sorted(new_permutation) + + +def transposition_is_supported_on_neutron( + input_shape: list[int], + permutation: list[int], + neutron_target_spec: NeutronTargetSpec, +) -> bool: + """This function determines if the current NeutronSoftware properly supports a `Transpose` operator with given + `input_shape` and `permutation`. 
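Note: the converter invocation above (backends/nxp/backend/neutron_converter_manager.py) isolates the native Neutron converter in a child process so a hard crash cannot take down the caller, and falls back to in-process execution where multiprocessing is restricted. A minimal, self-contained sketch of that pattern, with a hypothetical `risky_work` function standing in for `convert_unsafe`:

import logging
import multiprocessing


def risky_work(data, queue):
    # Hypothetical stand-in for convert_unsafe(): may crash the whole process.
    queue.put([x * 2 for x in data])


def run_isolated(data):
    try:
        queue = multiprocessing.Manager().Queue()
        process = multiprocessing.Process(target=risky_work, args=(data, queue))
        process.start()
        process.join()  # Wait for the subprocess to finish (or die).

        if queue.empty():  # The worker crashed before producing a result.
            raise RuntimeError(f"worker exited with code {process.exitcode}")

        result = queue.get()
        process.close()
        return result
    except (EOFError, OSError) as e:
        # Restricted environments (no fork/spawn) -> run the work directly instead.
        logging.warning("multiprocessing unavailable (%s), running inline", e)
        return [x * 2 for x in data]


if __name__ == "__main__":
    print(run_isolated([1, 2, 3]))  # -> [2, 4, 6]
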
+ + :param input_shape: The shape of the main input tensor of the `Transpose` operator. + :param permutation: The permutation the `Transpose` operator is computing. + :param neutron_target_spec: Object for querying the target platform to retrieve its properties. + """ + num_macs = neutron_target_spec.get_num_macs() + + if is_tensor_invariant_permutation(input_shape, permutation): + # The `Transpose` will be turned into a `Reshape` by Neutron. The check includes the identity permutation. + return True + + if permutation == [0, 3, 1, 2]: + # NHWC -> NCHW + n, h, w, c = input_shape + + if h * w * c % num_macs != 0: # Official Neutron requirement. + return False + + if not ( + c % num_macs == 0 and h * w % num_macs == 0 + ): # Neutron would produce incorrect outputs. + return False + + if n != 1: + # Neutron only supports `Transpose` operators where the dimensions can be combined into 2 consecutive + # groups. These 2 groups are then transposed like a matrix, and the result is reshaped. Therefore, for the + # [0, 3, 1, 2] permutation, when h * w != 1 and c != 1, batch size must be 1. + return False + + return True + + elif permutation == [0, 2, 3, 1]: + # NCHW -> NHWC + + n, c, h, w = input_shape + + if w % num_macs != 0: # Official Neutron requirement. + return False + + if not ( + c % num_macs == 0 and h * w % num_macs == 0 + ): # Neutron would produce incorrect outputs. + return False + + if n != 1: + # Neutron only supports `Transpose` operators where the dimensions can be combined into 2 consecutive + # groups. These 2 groups are then transposed like a matrix, and the result is reshaped. Therefore, for the + # [0, 2, 3, 1] permutation, when h * w != 1 and c != 1, batch size must be 1. + return False + + return True + + return False diff --git a/backends/nxp/backend/neutron_target_spec.py b/backends/nxp/backend/neutron_target_spec.py new file mode 100644 index 00000000000..ce187be5982 --- /dev/null +++ b/backends/nxp/backend/neutron_target_spec.py @@ -0,0 +1,143 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
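Note: a stand-alone usage sketch of the permutation-support rules above (backends/nxp/backend/neutron_operator_support.py). The `num_macs = 8` value is purely an illustrative assumption; the real value comes from `NeutronTargetSpec.get_num_macs()`. `is_tensor_invariant_permutation` is re-implemented locally so the snippet runs on its own:

def is_tensor_invariant_permutation(input_shape, permutation):
    # Axes of size 1 carry no data movement; the permutation collapses to a
    # Reshape iff the remaining axes stay in ascending order.
    kept = [p for p in permutation if input_shape[p] != 1]
    return kept == sorted(kept)


num_macs = 8  # Illustrative assumption only.

# [1, 1, 1, 64] under NHWC -> NCHW degenerates to a Reshape -> always supported.
print(is_tensor_invariant_permutation([1, 1, 1, 64], [0, 3, 1, 2]))  # True

# [1, 4, 4, 8] under NHWC -> NCHW: all three divisibility rules above hold.
n, h, w, c = 1, 4, 4, 8
print(
    h * w * c % num_macs == 0  # official Neutron requirement
    and c % num_macs == 0 and h * w % num_macs == 0  # output-correctness requirement
    and n == 1  # batch restriction
)  # True
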
+# +# Target Spec for the NXP Neutron NPU + +from enum import Enum + +import torch + +from executorch.backends.nxp.backend.neutron_converter_manager import ( + NeutronConverterManager, +) +from executorch.exir.dialects._ops import ops as exir_ops + +from torch.fx import Node + + +class NeutronHWVersion(Enum): + N1 = 1 + N3 = 2 + + +class NeutronTargetNeutronC: + @staticmethod + def is_supported_fused_activation__aten(node_: Node) -> bool: + """Node operator is supported fused activation on Neutron for Linear and Conv2D.""" + return node_.op == "call_function" and ( + node_.target + in ( + torch.ops.aten.relu.default, # TODO Add torch.ops.aten.leaky_relu.default once it is supported + torch.ops.aten.relu_.default, + torch.ops.aten.sigmoid.default, + torch.ops.aten.sigmoid_.default, + torch.ops.aten.tanh.default, + torch.ops.aten.tanh_.default, + ) + or ( + ( + node_.target == torch.ops.aten.hardtanh.default + or node_.target == torch.ops.aten.hardtanh_.default + ) + and ( + node_.args[1:3] == (0.0, 6.0) # is converted to Relu6 + or node_.args[1:3] == (0.0, float("inf")) # is converted to Relu + ) + ) + ) + + @staticmethod + def is_supported_fused_activation__edge(node_: Node) -> bool: + """Node operator is supported fused activation on Neutron for Linear and Conv2D.""" + return node_.op == "call_function" and ( + node_.target + in ( + exir_ops.edge.aten.relu.default, # TODO Add torch.ops.aten.leaky_relu.default once it is supported + exir_ops.edge.aten.sigmoid.default, + exir_ops.edge.aten.tanh.default, + ) + or ( + (node_.target == exir_ops.edge.aten.hardtanh.default) + and ( + node_.args[1:3] == (0.0, 6.0) # is converted to Relu6 + or node_.args[1:3] == (0.0, float("inf")) # is converted to Relu + ) + ) + ) + + @staticmethod + def is_fusable_conv_or_linear__aten(node_: Node) -> bool: + """Node operator is supported fusable Linear or Conv2D on Neutron.""" + return node_.op == "call_function" and ( + node_.target == torch.ops.aten.conv2d.default + or node_.target == torch.ops.aten.addmm.default + or node_.target == torch.ops.aten.mm.default + or ( + node_.target == torch.ops.aten.linear.default + and len(node_.meta["val"].shape) == 2 + ) + ) + + @staticmethod + def is_fusable_conv_or_linear__edge(node_: Node) -> bool: + """Node operator in edge dialect is supported fusable Linear or Conv2D on Neutron.""" + return node_.op == "call_function" and ( + node_.target == exir_ops.edge.aten.addmm.default + or node_.target == exir_ops.edge.aten.mm.default + or ( + node_.target == exir_ops.edge.aten.convolution.default + and len(node_.meta["val"].shape) == 4 + ) + ) + + +class NeutronTargetSpec: + """ + The functionality for probing the properties of Neutron Target. + """ + + def __init__(self, target: str, neutron_converter_flavor: str): + + converter_manager = NeutronConverterManager(neutron_converter_flavor) + converter_manager.verify_target(target) + neutron_converter = converter_manager.get_converter() + self.neutron_target = neutron_converter.getNeutronTarget(target) + + if self.is_subsystem(): + raise ValueError( + f"Target `{target}` is not a neutron-C target. Only MCU targets are supported at the moment." + ) + + if self.get_hw_version() != NeutronHWVersion.N3: + raise ValueError( + f"Target `{target}` contains unsupported HW version. Only N3/N3+ targets are supported at the moment." + ) + + # Now only Neutron-C is supported + self.neutron_target_info = NeutronTargetNeutronC() + + # Target name. 
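Note: the hardtanh handling in `NeutronTargetNeutronC` above only accepts the two bound pairs that map onto Neutron fused activations. A tiny illustrative sketch of that classification, operating on plain min/max values instead of fx nodes:

def classify_hardtanh(min_val, max_val):
    # Mirrors the bounds accepted above: only these two combinations
    # correspond to activations Neutron can fuse into Conv2D/Linear.
    if (min_val, max_val) == (0.0, 6.0):
        return "Relu6"
    if (min_val, max_val) == (0.0, float("inf")):
        return "Relu"
    return None  # Not fusable as-is.


print(classify_hardtanh(0.0, 6.0))           # Relu6
print(classify_hardtanh(0.0, float("inf")))  # Relu
print(classify_hardtanh(-1.0, 1.0))          # None
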
+ def get_name(self) -> str: + return self.neutron_target.name + + # Whether the target has subsystem (Neutron-S) or not (Neutron-C). + def is_subsystem(self) -> bool: + return self.neutron_target.subsystem + + # Number of compute units. + def get_num_units(self) -> int: + return self.neutron_target.numUnits + + # Number of compute pipelines. + def get_num_pipes(self) -> int: + return self.neutron_target.numPipes + + # Number of compute MACs. + def get_num_macs(self) -> int: + return self.neutron_target.numMacs + + # Neutron compute block hardware version. + def get_hw_version(self) -> NeutronHWVersion: + return NeutronHWVersion(self.neutron_target.hwVersion) diff --git a/backends/nxp/backend/node_format.py b/backends/nxp/backend/node_format.py new file mode 100644 index 00000000000..fd54e2365ed --- /dev/null +++ b/backends/nxp/backend/node_format.py @@ -0,0 +1,26 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum + +# Key into the `meta` attribute of nodes, which is mapped to their inferred node format. +NXP_NODE_FORMAT = "nxp_node_format" + + +class NodeFormat(Enum): + # Node's output in NCHW format + CHANNELS_FIRST = 0 + + # Node's output format has no meaning + FORMATLESS = 1 + + # Format has not been identified + NONE = 2 + + # NHWC + CHANNELS_LAST = 3 + + def is_channels_first(self) -> bool: + return self == NodeFormat.CHANNELS_FIRST diff --git a/backends/nxp/backend/node_format_inference.py b/backends/nxp/backend/node_format_inference.py index 76b05d172a4..244fd76d588 100644 --- a/backends/nxp/backend/node_format_inference.py +++ b/backends/nxp/backend/node_format_inference.py @@ -1,33 +1,22 @@ -# Copyright 2024 NXP +# Copyright 2024-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import logging -from enum import Enum +import operator +from executorch.backends.nxp.backend.edge_program_converter import functions_converters +from executorch.backends.nxp.backend.node_format import NodeFormat, NXP_NODE_FORMAT from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload -from torch import Node from torch.export import ExportedProgram +from torch.fx import Node logger = logging.getLogger(__name__) -class NodeFormat(Enum): - # Node's output in NCHW format - CHANNELS_FIRST = 0 - - # Node's output format has no meaning - FORMATLESS = 1 - - # Format has not been identified - NONE = 2 - - def is_channels_first(self) -> bool: - return self == NodeFormat.CHANNELS_FIRST - - class NodeFormatInference: # Dictionary with Edge Aten ops that always use channels first format. # The op in the dictionary is mapped to a dictionary, which holds indices to input nodes @@ -41,9 +30,10 @@ class NodeFormatInference: # A set of Edge Aten ops, which have the ability to change the format (for example - input nodes # are channels first but output is formatless). 
- ops_that_can_change_tensor_format = {exir_ops.edge.aten.view_copy.default} - - _node_format_mapping: dict[Node, NodeFormat] + ops_that_can_change_tensor_format = { + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.permute_copy.default, + } _type_changed_during_last_run: bool @@ -53,11 +43,13 @@ class NodeFormatInference: # Mapping between Node and its children (outputs) _node_outputs: dict[Node, list[Node]] + # List of all edge operations, which are supported by the converter. + _known_targets: list[EdgeOpOverload] + def __init__(self, edge_program: ExportedProgram): self._edge_program = edge_program self._nodes = edge_program.graph.nodes - self._node_format_mapping = {} self._node_inputs = { node: node.all_input_nodes for node in edge_program.graph.nodes } @@ -67,7 +59,14 @@ def __init__(self, edge_program: ExportedProgram): self._type_changed_during_last_run = False - def identify_node_formats(self) -> dict[Node, NodeFormat]: + self._known_targets = list(functions_converters) + [ + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + operator.getitem, + ] + + def identify_node_formats(self): self._type_changed_during_last_run = True # Re-run format inference until there are no changes @@ -77,25 +76,55 @@ def identify_node_formats(self) -> dict[Node, NodeFormat]: for node in self._nodes: self._infer_format_of_nodes(node) - return self._node_format_mapping + for node in self._nodes: + if self._get_node_op_type(node) is None: + continue + if not hasattr(node, "meta"): + logging.warning(f"Node `{node}` does not have the `meta` attribute.") + node.meta = {} + if NXP_NODE_FORMAT not in node.meta: + logging.warning(f"Node `{node}` does not have inferred format.") + node.meta[NXP_NODE_FORMAT] = NodeFormat.NONE def _infer_format_of_nodes(self, node: Node): op_type = self._get_node_op_type(node) if op_type in self.ops_with_channels_first_nodes: self._handle_node_which_uses_channels_first_format(node) + elif op_type in self.ops_that_can_change_tensor_format: - if op_type == exir_ops.edge.aten.view_copy.default: # view_copy + if op_type in [ + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.permute_copy.default, + ]: + # Try to assign the `formatless` format to the input and output. The converter will then handle the + # transition. + # Note: If the format for the input/output has already been assigned as channels first, it will NOT be + # overwritten. self._assign_format_to_node( self._node_outputs[node][0], NodeFormat.FORMATLESS ) + self._assign_format_to_node( + self._node_inputs[node][0], NodeFormat.FORMATLESS + ) + else: logger.error( f"Node format inference for node type: {op_type} not found!" ) - else: + elif node.op != "call_function" or ( + hasattr(node, "target") and node.target in self._known_targets + ): + # Generic node, or tensor. self._handle_node_which_can_use_any_node_format(node) + else: + # Don't infer the format for unknown nodes. These nodes will never be delegated, so they will divide + # delegated partitions. Propagating the format here could unnecessarily enforce the format in one of these + # partitions, which would require extra transpositions. 
+ for processed_node in self._node_inputs[node] + [node]: + self._assign_format_to_node(processed_node, NodeFormat.NONE) + def _infer_format_based_on_io_ranks(self, node: Node): """Determine the format of the output tensor of given "reshape style operator" based on the ranks of its input and output. @@ -148,10 +177,14 @@ def _assign_format_to_node(self, node: Node, node_format: NodeFormat): # Once CHANNEL_FIRST was assigned, we don't want to reassign return + if node_format is NodeFormat.NONE and old_node_format is not NodeFormat.NONE: + # A format has already been assigned to the node before. Don't replace it with `NONE`. + return + if old_node_format != node_format: self._type_changed_during_last_run = True - self._node_format_mapping[node] = node_format + node.meta[NXP_NODE_FORMAT] = node_format def _get_node_op_type(self, node: Node) -> str | None: """ @@ -252,8 +285,10 @@ def _node_produces_or_consumes_channels_first_format(self, node) -> bool: for ancestor_node in input_nodes ) - def _get_node_format(self, node): - return self._node_format_mapping.get(node, NodeFormat.NONE) + def _get_node_format(self, node) -> NodeFormat: + if not hasattr(node, "meta"): + node.meta = {} + return node.meta.get(NXP_NODE_FORMAT, NodeFormat.NONE) - def _node_is_placeholder(self, node: Node): + def _node_is_placeholder(self, node: Node) -> bool: return node.op == "placeholder" diff --git a/backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py b/backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py index d88684b86f0..14c4890a202 100644 --- a/backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py +++ b/backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py @@ -15,6 +15,12 @@ AddMM = exir_ops.edge.aten.addmm.default ViewCopy = exir_ops.edge.aten.view_copy.default MM = exir_ops.edge.aten.mm.default +Conv = exir_ops.edge.aten.convolution.default +HardTanh = exir_ops.edge.aten.hardtanh.default +Relu = exir_ops.edge.aten.relu.default +Sigmoid = exir_ops.edge.aten.sigmoid.default +Tanh = exir_ops.edge.aten.tanh.default +CloneDimOrder = exir_ops.edge.dim_order_ops._clone_dim_order.default def insert_qdq_pair_after_node( @@ -97,6 +103,9 @@ class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass): MM: [ ViewCopy, ], + ViewCopy: [ + CloneDimOrder, + ], } def run(self, graph_module: torch.fx.GraphModule) -> PassResult: @@ -175,9 +184,23 @@ class MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass): main_cluster_node_to_auxiliary_nodes = { AddMM: [ ViewCopy, + HardTanh, + Relu, + Sigmoid, + Tanh, ], MM: [ ViewCopy, + HardTanh, + Relu, + Sigmoid, + Tanh, + ], + Conv: [ + HardTanh, + Relu, + Sigmoid, + Tanh, ], } diff --git a/backends/nxp/edge_passes/neutron_edge_pass_manager.py b/backends/nxp/edge_passes/neutron_edge_pass_manager.py index ec46070ac31..a4953d74b78 100644 --- a/backends/nxp/edge_passes/neutron_edge_pass_manager.py +++ b/backends/nxp/edge_passes/neutron_edge_pass_manager.py @@ -3,22 +3,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
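Note: the refactored format inference above stores its result directly in `node.meta[NXP_NODE_FORMAT]`. A minimal sketch of the assignment rules (CHANNELS_FIRST is sticky, NONE never overwrites a previously inferred format), using a plain dict in place of `node.meta`:

from enum import Enum


class NodeFormat(Enum):
    CHANNELS_FIRST = 0
    FORMATLESS = 1
    NONE = 2
    CHANNELS_LAST = 3


NXP_NODE_FORMAT = "nxp_node_format"


def assign_format(meta: dict, new_format: NodeFormat) -> None:
    old = meta.get(NXP_NODE_FORMAT, NodeFormat.NONE)
    if old is NodeFormat.CHANNELS_FIRST:
        return  # CHANNELS_FIRST is sticky; never reassigned.
    if new_format is NodeFormat.NONE and old is not NodeFormat.NONE:
        return  # Never replace an already inferred format with NONE.
    meta[NXP_NODE_FORMAT] = new_format


meta = {}
assign_format(meta, NodeFormat.FORMATLESS)
assign_format(meta, NodeFormat.NONE)            # Ignored.
assign_format(meta, NodeFormat.CHANNELS_FIRST)  # Sticks from now on.
assign_format(meta, NodeFormat.FORMATLESS)      # Ignored.
print(meta[NXP_NODE_FORMAT])                    # NodeFormat.CHANNELS_FIRST
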
-import copy - from executorch.backends.nxp.edge_passes.move_auxiliary_operator_into_separate_qdq_cluster_pass import ( MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass, MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass, ) from executorch.backends.nxp.edge_passes.neutron_edge_pass import NeutronEdgePass -from executorch.exir import EdgeProgramManager -from executorch.exir.program._program import ( - _get_updated_graph_signature, - _get_updated_range_constraints, -) - -from torch import nn -from torch.export import ExportedProgram -from torch.fx.passes.infra.pass_base import PassResult from torch.fx.passes.infra.pass_manager import PassManager @@ -34,56 +23,3 @@ def __init__(self, passes: list[NeutronEdgePass] = None): passes, steps=10, # Empirical value. At most 10 cycles of passes will be run. ) - - def _transform_graph_module(self, module: nn.Module) -> PassResult: - """Apply the passes to a single graph module.""" - pass_result: PassResult = super().__call__(module) - - graph_module = pass_result.graph_module - graph_module.graph.eliminate_dead_code() - graph_module.recompile() - - return pass_result - - def __call__(self, epm: EdgeProgramManager) -> EdgeProgramManager: - """Apply the passes to all graph modules in the edge program.""" - new_programs: dict[str, ExportedProgram] = {} - - for name, program in epm._edge_programs.items(): - pass_result = self._transform_graph_module(program.graph_module) - - if pass_result.modified: - # Create a new exported program. - new_program = ExportedProgram( - root=pass_result.graph_module, - graph=pass_result.graph_module.graph, - graph_signature=_get_updated_graph_signature( - program.graph_signature, pass_result.graph_module - ), - state_dict=program.state_dict, - range_constraints=_get_updated_range_constraints( - pass_result.graph_module - ), - module_call_graph=copy.deepcopy(program._module_call_graph), - example_inputs=program.example_inputs, - constants=program.constants, - verifiers=[program.verifier], - ) - new_program.graph_module.meta.update(program.graph_module.meta) - new_program.graph_module.meta.update(pass_result.graph_module.meta) - - else: - # Keep the old exported program. - new_program = program - - new_programs[name] = new_program - - if len(new_programs) == 0: - # No passes were run, return the old EdgeProgramManager. - return epm - - else: - # Return a new EdgeProgramManager with the updated programs. - return EdgeProgramManager( - new_programs, copy.deepcopy(epm._config_methods), epm.compile_config - ) diff --git a/backends/nxp/edge_passes/remove_additional_quantize_dequantize_nodes_pass.py b/backends/nxp/edge_passes/remove_additional_quantize_dequantize_nodes_pass.py new file mode 100644 index 00000000000..4edcc0b0e97 --- /dev/null +++ b/backends/nxp/edge_passes/remove_additional_quantize_dequantize_nodes_pass.py @@ -0,0 +1,111 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import numpy as np +import torch + +from executorch.backends.nxp.backend.edge_helper import get_quantization_parameters_for +from executorch.backends.nxp.edge_passes.neutron_edge_pass import NeutronEdgePass +from executorch.backends.nxp.neutron_partitioner import QDQClusterRecognizer +from executorch.exir.dialects._ops import ops as exir_ops +from torch.fx.passes.infra.pass_base import PassResult + + +class RemoveAdditionalQDQClustersPass(NeutronEdgePass): + """ + After delegation of partitions, there may be additional dequantize quantize nodes for QDQ clusters that were + not delegated. If dequantize quantize nodes are quantized per tensor and quantization parameters of dequantize + and quantize nodes in a QDQ cluster are equal, the nodes can be removed and thus the inner nodes computed in int8. + + │ + ┌────────────▼──────────┐ + │ dequantize_per_tensor │ + └────────────┬──────────┘ + │ │ + ┌───▼──┐ replace with ┌───▼──┐ + │ node │ ──────────────► │ node │ + └───┬──┘ └───┬──┘ + │ ▼ + ┌───────────▼─────────┐ + │ quantize_per_tensor │ + └───────────┬─────────┘ + ▼ + + """ + + qdq_per_channel_nodes = ( + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + exir_ops.edge.quantized_decomposed.quantize_per_channel.default, + ) + + qdq_per_tensor_nodes = ( + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, + ) + + def run(self, graph_module: torch.fx.GraphModule) -> PassResult: + nodes = list(graph_module.graph.nodes) + qdq_clusterer = QDQClusterRecognizer() + qdq_clusterer.tag_qdq_clusters(nodes) + + for cluster in qdq_clusterer.cluster_map.values(): + # For now, enable only permute_copy and cat. + if cluster.compute_node.target not in [ + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.cat.default, + ]: + continue + + # Ensure cluster doesn't contain dequantize/quantize per channel nodes. + if any( + node + for node in cluster.ops + if node.target in self.qdq_per_channel_nodes + ): + continue + + qdq_nodes = [ + node for node in cluster.ops if node.target in self.qdq_per_tensor_nodes + ] + + qdq_nodes_quant_params = [ + get_quantization_parameters_for(node) for node in qdq_nodes + ] + + equal_quant_scales = [ + np.allclose( + qdq_nodes_quant_params[idx][0], qdq_nodes_quant_params[idx + 1][0] + ) + for idx in range(len(qdq_nodes_quant_params[:-1])) + ] + + equal_quant_zero_points = [ + np.allclose( + qdq_nodes_quant_params[idx][1], qdq_nodes_quant_params[idx + 1][1] + ) + for idx in range(len(qdq_nodes_quant_params[:-1])) + ] + + # Check if all quantization params are equal to ensure that QDQ cluster can be removed. + if not all(equal_quant_scales + equal_quant_zero_points): + continue + + # Replace the uses of each dequantize/quantize node with its arg node. + for qdq_node in qdq_nodes: + qdq_node.replace_all_uses_with(qdq_node.args[0]) + graph_module.graph.erase_node(qdq_node) + + # Remove compute node cluster info from node meta. + cluster.compute_node.meta.pop("cluster") + + graph_module = self.recompile_module(graph_module) + + # The graph has now changed, and we cannot keep iterating through it. Return the new graph and the parent + # class will call this pass again. 
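Note: the gate for removing a QDQ cluster above is that scales and zero points agree across all per-tensor quantize/dequantize nodes of the cluster. A small stand-alone sketch of that check, with made-up (scale, zero_point) pairs in place of `get_quantization_parameters_for` results:

import numpy as np


def all_qdq_params_equal(params):
    # params: list of (scale, zero_point) tuples, one per q/dq node in the cluster.
    scales_equal = [
        np.allclose(params[i][0], params[i + 1][0]) for i in range(len(params) - 1)
    ]
    zero_points_equal = [
        np.allclose(params[i][1], params[i + 1][1]) for i in range(len(params) - 1)
    ]
    return all(scales_equal + zero_points_equal)


# Equal parameters -> the dequantize/quantize pair can be dropped.
print(all_qdq_params_equal([(0.02, -128), (0.02, -128)]))  # True
# Different scale -> keep the QDQ cluster as-is.
print(all_qdq_params_equal([(0.02, -128), (0.05, -128)]))  # False
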
+ return PassResult(graph_module, True) + + return PassResult(graph_module, False) diff --git a/backends/nxp/backend/ir/edge_passes/remove_io_quant_ops_pass.py b/backends/nxp/edge_passes/remove_io_quant_ops_pass.py similarity index 100% rename from backends/nxp/backend/ir/edge_passes/remove_io_quant_ops_pass.py rename to backends/nxp/edge_passes/remove_io_quant_ops_pass.py diff --git a/backends/nxp/neutron_node_extraction.py b/backends/nxp/neutron_node_extraction.py index 9d2431d29ed..b1ff4ae7310 100644 --- a/backends/nxp/neutron_node_extraction.py +++ b/backends/nxp/neutron_node_extraction.py @@ -21,6 +21,7 @@ class NeutronNodeArtifacts: microcode: np.ndarray weights: np.ndarray kernels: np.ndarray + payload_version: int def extract_artifacts_from_neutron_node( @@ -123,7 +124,12 @@ def extract_artifacts_from_neutron_node( output_names = [] output_indices = [] graph_outputs = sub_graph.OutputsAsNumpy() + payload_version = 0 + # Ignore the extra outputs: scratch and eventually also profile and debug node_outputs = neutron_node.OutputsAsNumpy()[:-1] + if len(graph_outputs) == len(node_outputs) - 2: + payload_version = 1 + node_outputs = node_outputs[:-2] for tensor_idx in node_outputs: which_graph_output = np.where(graph_outputs == tensor_idx)[0] assert ( @@ -142,4 +148,5 @@ def extract_artifacts_from_neutron_node( microcode, weights, kernels, + payload_version, ) diff --git a/backends/nxp/neutron_partitioner.py b/backends/nxp/neutron_partitioner.py index 5bcdee0f8b6..e74a79b3e7b 100644 --- a/backends/nxp/neutron_partitioner.py +++ b/backends/nxp/neutron_partitioner.py @@ -8,22 +8,28 @@ import logging import operator from dataclasses import dataclass -from typing import Dict, final, List, Mapping +from typing import final, Mapping import torch from executorch.backends.nxp.backend.custom_delegation_options import ( CustomDelegationOptions, ) +from executorch.backends.nxp.backend.edge_helper import ( + DEQUANTIZE_OPERATORS, + QUANTIZE_OPERATORS, +) from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) -from executorch.backends.nxp.backend.ir.converter.node_converter import Target from torch.export.exported_program import ExportedProgram -from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch.fx import Graph +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition from torch.fx.passes.operator_support import OperatorSupportBase from torch.nn import Parameter from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403 +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec +from executorch.backends.nxp.backend.node_format_inference import NodeFormatInference from executorch.backends.nxp.nxp_backend import NeutronBackend from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.partitioner import ( @@ -34,6 +40,9 @@ from executorch.exir.backend.utils import tag_constant_data from executorch.exir.dialects._ops import ops as exir_ops +NXP_DO_NOT_DELEGATE = "NXP_DO_NOT_DELEGATE" +NXP_DELEGATION_TAG = "delegation_tag" + class QDQClusterRecognizer: """ @@ -60,22 +69,17 @@ class QDQCluster: """ compute_node: torch.fx.Node - ops: List[torch.fx.Node] - - QUANTIZE_OPERATORS = [ - exir_ops.edge.quantized_decomposed.quantize_per_channel.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - ] - - DEQUANTIZE_OPERATORS = [ - 
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - ] + ops: list[torch.fx.Node] AUXILIARY_OPS = [ operator.getitem, exir_ops.edge.aten.view_copy.default, exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.hardtanh.default, + exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.sigmoid.default, + exir_ops.edge.aten.tanh.default, + exir_ops.edge.dim_order_ops._clone_dim_order.default, ] def __init__(self): @@ -83,17 +87,17 @@ def __init__(self): @staticmethod def is_quant_node(node: torch.fx.Node) -> bool: - return node.target in QDQClusterRecognizer.QUANTIZE_OPERATORS + return node.target in QUANTIZE_OPERATORS @staticmethod def is_dequant_node(node: torch.fx.Node) -> bool: - return node.target in QDQClusterRecognizer.DEQUANTIZE_OPERATORS + return node.target in DEQUANTIZE_OPERATORS @staticmethod def is_auxiliary_node(node: torch.fx.Node) -> bool: return node.target in QDQClusterRecognizer.AUXILIARY_OPS - def get_qdq_cluster_input_part(self, node: torch.fx.Node) -> List[torch.fx.Node]: + def get_qdq_cluster_input_part(self, node: torch.fx.Node) -> list[torch.fx.Node]: """ Return the list of nodes representing the input part of the QDQ cluster of the node `node`. Those are various dequantization nodes (see DEQUANTIZE_OPERATORS) optionally followed by auxiliary @@ -121,7 +125,7 @@ def get_qdq_cluster_input_part(self, node: torch.fx.Node) -> List[torch.fx.Node] logging.debug(f"Dequant Cluster for {node} is: {qdq_cluster}") return qdq_cluster - def get_qdq_cluster_output_part(self, node: torch.fx.Node) -> List[torch.fx.Node]: + def get_qdq_cluster_output_part(self, node: torch.fx.Node) -> list[torch.fx.Node]: """ Returns the list of nodes representing the output part of the QDQ cluster of the `node`. Those are various quantize nodes (see QUANTIZE_OPERATORS) preceded by auxiliary nodes. @@ -151,7 +155,7 @@ def get_qdq_cluster_output_part(self, node: torch.fx.Node) -> List[torch.fx.Node logging.debug(f"Quant Cluster for {node} is {qdq_cluster}") return qdq_cluster - def get_qdq_cluster(self, node: torch.fx.Node) -> List[torch.fx.Node]: + def get_qdq_cluster(self, node: torch.fx.Node) -> list[torch.fx.Node]: """ Returns the QDQ cluster of the operator, if quantized. If operator is not quantized, returns empty list. """ @@ -163,7 +167,7 @@ def get_qdq_cluster(self, node: torch.fx.Node) -> List[torch.fx.Node]: else: return [] - def tag_nodes(self, nodes: List[torch.fx.Node], cluster_name: str) -> None: + def tag_nodes(self, nodes: list[torch.fx.Node], cluster_name: str) -> None: """ Tags a node and its related dequant and quant nodes with a specified cluster name """ @@ -171,7 +175,7 @@ def tag_nodes(self, nodes: List[torch.fx.Node], cluster_name: str) -> None: logging.info(f"Tagging node {node} as {cluster_name}") node.meta["cluster"] = cluster_name - def tag_qdq_clusters(self, nodes: List[torch.fx.Node]): + def tag_qdq_clusters(self, nodes: list[torch.fx.Node]): """ Identifies QDQ clusters and tag them based on compute operation inside. 
""" @@ -197,6 +201,7 @@ def tag_qdq_clusters(self, nodes: List[torch.fx.Node]): exir_ops.edge.aten.avg_pool2d.default: AvgPool2dConverter, # noqa F405 exir_ops.edge.aten.cat.default: CatConverter, # noqa F405 exir_ops.edge.aten.clone.default: CloneConverter, # noqa F405 + exir_ops.edge.dim_order_ops._clone_dim_order.default: CloneConverter, # noqa F405 exir_ops.edge.aten.constant_pad_nd.default: ConstantPadNDConverter, # noqa F405 exir_ops.edge.aten.convolution.default: ConvolutionConverter, # noqa F405 exir_ops.edge.aten.hardtanh.default: HardTanhConverter, # noqa F405 @@ -204,11 +209,15 @@ def tag_qdq_clusters(self, nodes: List[torch.fx.Node]): exir_ops.edge.aten.max_pool2d_with_indices.default: MaxPool2dConverter, # noqa F405 exir_ops.edge.aten.mean.dim: MeanDimConverter, # noqa F405 exir_ops.edge.aten.mm.default: MMConverter, # noqa F405 + exir_ops.edge.aten.mul.Tensor: MulTensorConverter, # noqa F405 + exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter, # noqa F405 exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405 + exir_ops.edge.aten.sigmoid.default: SigmoidConverter, # noqa F405 + exir_ops.edge.aten.slice_copy.Tensor: SliceTensorConverter, # noqa F405 exir_ops.edge.aten._softmax.default: SoftmaxConverter, # noqa F405 + exir_ops.edge.aten.sub.Tensor: SubTensorConverter, # noqa F405 exir_ops.edge.aten.tanh.default: TanhConverter, # noqa F405 exir_ops.edge.aten.view_copy.default: ViewCopyConverter, # noqa F405 - exir_ops.edge.aten.sigmoid.default: SigmoidConverter, # noqa F405 } @@ -216,14 +225,14 @@ class NeutronSupportedOperators(OperatorSupportBase): def __init__( self, - qdq_clusters: Dict[str, QDQClusterRecognizer.QDQCluster], - target: Target, - operators_not_to_delegate: List[str], + qdq_clusters: dict[str, QDQClusterRecognizer.QDQCluster], + neutron_target_spec: NeutronTargetSpec, + operators_not_to_delegate: list[str], parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ): self.qdq_clusters = qdq_clusters - self.target = target + self.neutron_target_spec = neutron_target_spec self.operators_not_to_delegate = operators_not_to_delegate self.parameters_mapping = parameters_mapping self.custom_delegation_options = custom_delegation_options @@ -246,6 +255,11 @@ def _is_node_supported_compute(self, node: torch.fx.node.Node) -> bool: """ Operator checking function for compute nodes. """ + + if hasattr(node, "meta") and node.meta.get(NXP_DO_NOT_DELEGATE, False): + # The delegation of this node has been prohibited. + return False + if not self.is_node_delegatable(node): return False @@ -260,7 +274,7 @@ def _is_node_supported_compute(self, node: torch.fx.node.Node) -> bool: # TODO: `view_copy` node should be delegated only if it's not the only operator in the cluster. 
node_converter.is_supported( node, - self.target, + self.neutron_target_spec, self.parameters_mapping, self.custom_delegation_options, ) @@ -296,35 +310,62 @@ def is_node_supported( class NeutronPartitioner(Partitioner): def __init__( self, - compile_spec: List[CompileSpec], + compile_spec: list[CompileSpec], + neutron_target_spec: NeutronTargetSpec, custom_delegation_options: CustomDelegationOptions | None = None, ) -> None: self.delegation_spec = DelegationSpec(NeutronBackend.__name__, compile_spec) self.custom_delegation_options = ( custom_delegation_options or CustomDelegationOptions() ) + self.neutron_target_spec = neutron_target_spec + + def validate_partitioning_result( + self, + graph: Graph, + partition_list: list[Partition], + custom_delegation_options: CustomDelegationOptions, + parameters_mapping: dict[str, Parameter], + ) -> bool: + all_delegated_nodes = { + node for partition in partition_list for node in partition.nodes + } + partitioning_valid = True + for node in graph.nodes: + if ( + node in all_delegated_nodes + and hasattr(node, "target") + and node.target in supported_ops + ): + if not supported_ops[node.target].supports_partitioning_result( + node, + partition_list, + custom_delegation_options, + self.neutron_target_spec, + parameters_mapping, + ): + # This node is not supported within its partition. Exclude it from delegation in the future. + partitioning_valid = False + node.meta[NXP_DO_NOT_DELEGATE] = True + + return partitioning_valid def partition(self, exported_program: ExportedProgram) -> PartitionResult: # Run the CapabilityBasedPartitioner to return the largest possible # subgraphs containing the nodes with the tags logging.info("NeutronPartitioner::partition") partition_tags = {} + partition_list = [] graph_module = exported_program.graph_module nodes = list(graph_module.graph.nodes) qdq_cluster_recognizer = QDQClusterRecognizer() qdq_cluster_recognizer.tag_qdq_clusters(nodes) + graph_module.recompile() - target = None - operators_not_to_delegate = "" - for spec in self.delegation_spec.compile_specs: - if spec.key == "target": - target = Target(spec.value.decode()) - if spec.key == "operators_not_to_delegate": - operators_not_to_delegate = spec.value.decode().split(",") - assert target is not None + operators_not_to_delegate = self.delegation_spec[1][4].value.decode().split(",") logging.info(f"Operators not to delegate: {operators_not_to_delegate}") parameters_mapping = EdgeProgramToIRConverter.map_inputs_to_parameters( @@ -334,7 +375,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: exported_program.graph_module, NeutronSupportedOperators( qdq_cluster_recognizer.cluster_map, - target, + self.neutron_target_spec, operators_not_to_delegate, parameters_mapping, self.custom_delegation_options, @@ -342,11 +383,35 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: allows_single_node_partition=True, ) - partition_list = capability_partitioner.propose_partitions() + # Identify the format (NCHW/NHWC/...) for all nodes in the graph, and store it in the `node.meta`. + # This format will be used by the `CapabilityBasedPartitioner` to determine which nodes will be delegated. + NodeFormatInference(exported_program).identify_node_formats() + + parameters_mapping = EdgeProgramToIRConverter.map_inputs_to_parameters( + exported_program + ) + + iteration_limit = len(exported_program.graph.nodes) + for _ in range(iteration_limit): + # Run the partitioning. 
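Note: the partitioning step that follows retries `propose_partitions()` until every delegated node accepts its partition, marking offenders with `NXP_DO_NOT_DELEGATE`. A schematic, runnable sketch of that retry loop using toy dict "nodes" and hypothetical propose/validate helpers:

NXP_DO_NOT_DELEGATE = "NXP_DO_NOT_DELEGATE"

# Toy "nodes": dicts standing in for torch.fx nodes and their `meta` attribute.
nodes = [{"name": "conv"}, {"name": "transpose"}, {"name": "linear"}]


def propose_partitions():
    # Hypothetical partitioner: delegate everything not explicitly excluded.
    return [[n for n in nodes if not n.get(NXP_DO_NOT_DELEGATE)]]


def node_is_valid_in(node, partition):
    # Hypothetical per-node check; here "transpose" refuses delegation.
    return node["name"] != "transpose"


partitions = []
for _ in range(len(nodes)):  # Bounded number of retries, as in the partitioner above.
    partitions = propose_partitions()
    invalid = [n for p in partitions for n in p if not node_is_valid_in(n, p)]
    if not invalid:
        break  # Every delegated node accepts its partition.
    for node in invalid:
        node[NXP_DO_NOT_DELEGATE] = True  # Excluded from the next proposal round.

print([n["name"] for n in partitions[0]])  # ['conv', 'linear']
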
+ partition_list = capability_partitioner.propose_partitions() + + # Check if the nodes support the partitioning result. Mark the problematic nodes with `NXP_DO_NOT_DELEGATE`. + partitioning_valid = self.validate_partitioning_result( + exported_program.graph, + partition_list, + self.custom_delegation_options, + parameters_mapping, + ) + if partitioning_valid: + # The result of the partitioning is fine + break + + # Mark the partitions in the node `meta` attribute. for partition in partition_list: for node in partition.nodes: delegation_tag = f"tag{partition.id}" - node.meta["delegation_tag"] = delegation_tag + node.meta[NXP_DELEGATION_TAG] = delegation_tag partition_tags[delegation_tag] = self.delegation_spec tag_constant_data(exported_program) diff --git a/backends/nxp/nxp_backend.py b/backends/nxp/nxp_backend.py index c801eefec81..4419f05aa99 100644 --- a/backends/nxp/nxp_backend.py +++ b/backends/nxp/nxp_backend.py @@ -1,4 +1,4 @@ -# Copyright 2024 NXP +# Copyright 2024-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -15,20 +15,21 @@ import numpy as np import torch +from executorch.backends.nxp._passes.remove_getitem_pass import RemoveGetItemPass from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) -from executorch.backends.nxp.backend.ir.converter.node_converter import Target -from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat +from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.backend.neutron_converter_manager import ( NeutronConverterManager, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec +from executorch.backends.nxp.backend.node_format import NodeFormat from executorch.backends.nxp.neutron_node_extraction import ( extract_artifacts_from_neutron_node, NeutronNodeArtifacts, ) from executorch.backends.nxp.neutron_pass_manager import NeutronPassManager -from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.verification.verifier import EXIREdgeDialectVerifier @@ -36,14 +37,15 @@ class NeutronCompileSpecBuilder: + config: NeutronTargetSpec def __init__(self): - self.config: Target = None self.compile_spec: List[CompileSpec] = [] self.compiler_flags = [] self.output_format = None self.operators_not_to_delegate: List[str] = [] self.neutron_converter_flavor = None + self.use_neutron_for_format_conversion = True def _replace_colons(self, operator: str) -> str: """ @@ -57,6 +59,7 @@ def neutron_compile_spec( neutron_converter_flavor: str, extra_flags: Optional[str] = None, operators_not_to_delegate: Optional[List[str]] = None, + use_neutron_for_format_conversion: bool = True, ): """ Generate compile spec for Neutron NPU @@ -64,18 +67,16 @@ def neutron_compile_spec( Args: config: Neutron accelerator configuration, e.g. "imxrt700" neutron_converter_flavor: Flavor of the neutron-converter module to use. Neutron-converter module named " - "'neutron_converter_SDK_25_06' has flavor 'SDK_25_06'. + "'neutron_converter_SDK_25_09' has flavor 'SDK_25_09'. 
extra_flags: Extra flags for the Neutron compiler operators_not_to_delegate: List of operators that should not be delegated + use_neutron_for_format_conversion: If True, the EdgeProgramToIRConverter will insert `Transpose` ops to + ensure that the IO matches the executorch partition, which will be + delegated to Neutron. """ - try: - self.config = Target(config) - except ValueError: - raise ValueError( - f"Config `{config}` is not a valid target. Must be one of `{Target.values()}`." - ) self.neutron_converter_flavor = neutron_converter_flavor + self.config = NeutronTargetSpec(config, neutron_converter_flavor) assert ( self.output_format is None @@ -91,6 +92,8 @@ def neutron_compile_spec( self._replace_colons(op) for op in operators_not_to_delegate ] + self.use_neutron_for_format_conversion = use_neutron_for_format_conversion + return self def build(self): @@ -101,7 +104,7 @@ def build(self): self.compile_spec += [ CompileSpec("output_format", "tflite".encode()), CompileSpec("compile_flags", " ".join(self.compiler_flags).encode()), - CompileSpec("target", self.config.value.encode()), + CompileSpec("target", self.config.get_name().encode()), CompileSpec( "neutron_converter_flavor", self.neutron_converter_flavor.encode() ), @@ -109,6 +112,10 @@ def build(self): "operators_not_to_delegate", ",".join(self.operators_not_to_delegate).encode(), ), + CompileSpec( + "use_neutron_for_format_conversion", + f"{self.use_neutron_for_format_conversion}".encode(), + ), ] return self.compile_spec @@ -120,6 +127,7 @@ def generate_neutron_compile_spec( system_config: Optional[str] = None, extra_flags: Optional[str] = None, operators_not_to_delegate: Optional[List[str]] = None, + use_neutron_for_format_conversion: bool = True, ) -> List[CompileSpec]: return ( NeutronCompileSpecBuilder() @@ -128,6 +136,7 @@ def generate_neutron_compile_spec( neutron_converter_flavor, extra_flags=extra_flags, operators_not_to_delegate=operators_not_to_delegate, + use_neutron_for_format_conversion=use_neutron_for_format_conversion, ) .build() ) @@ -150,6 +159,7 @@ def preprocess( # noqa C901 binary = bytes() target = "" neutron_converter_flavor = "" + use_neutron_for_format_conversion = None for spec in compile_spec: if spec.key == "output_format": output_format = spec.value.decode() @@ -159,6 +169,8 @@ def preprocess( # noqa C901 compile_flags.append(spec.value.decode()) if spec.key == "neutron_converter_flavor": neutron_converter_flavor = spec.value.decode() + if spec.key == "use_neutron_for_format_conversion": + use_neutron_for_format_conversion = spec.value.decode() == "True" # Check that the output format is set in the compile spec if not output_format: @@ -185,12 +197,19 @@ def preprocess( # noqa C901 ).transform() # Convert the edge program to TFLite. 
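Note: the new `use_neutron_for_format_conversion` option travels through the compile spec as a string; a short sketch of the encode/decode round trip used above:

use_neutron_for_format_conversion = True

encoded = f"{use_neutron_for_format_conversion}".encode()  # b'True', as stored in the CompileSpec
decoded = encoded.decode() == "True"                       # back to a bool in preprocess()
print(encoded, decoded)                                    # b'True' True
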
+ conversion_config = ConversionConfig( + {"use_neutron_for_format_conversion": use_neutron_for_format_conversion} + if use_neutron_for_format_conversion is not None + else {} + ) tflite_model, io_formats = EdgeProgramToIRConverter().convert_program( edge_program, + neutron_target_spec=NeutronTargetSpec(target, neutron_converter_flavor), + conversion_config=conversion_config, ) - neutron_model = NeutronConverterManager().convert( - tflite_model, target, neutron_converter_flavor + neutron_model = NeutronConverterManager(neutron_converter_flavor).convert( + tflite_model, target ) # Dump the tflite file if logging level is enabled @@ -245,7 +264,9 @@ def _format_string_for_array(self, array: np.ndarray) -> str: return f"{array.size}s{self._padding_format_string_for_array(array)}" - def _create_payload_header(self, io_formats, neutron_artifacts) -> np.ndarray: + def _create_payload_header( + self, io_formats: dict[str, list[NodeFormat]], neutron_artifacts + ) -> np.ndarray: """ Create bytes header for returned payload. It contains information about input and output tensor formats. Tensors are ordered based on graph signature @@ -262,6 +283,8 @@ def _create_payload_header(self, io_formats, neutron_artifacts) -> np.ndarray: +----------------------------------------+------------------------------------------+ | 1st output map (1B) | [nth* output map (1B)] | +----------------------------------------+------------------------------------------+ + | Payload version (1B) | + +-----------------------------------------------------------------------------------+ :param io_formats: IO tensors formats. :return: Bytes representation of payload header. @@ -283,9 +306,7 @@ def _create_payload_header(self, io_formats, neutron_artifacts) -> np.ndarray: for input_name in neutron_artifacts.input_names: try: header_data.append( - 1 - if inputs[input_name.decode()] == TensorFormat.CHANNELS_LAST - else 0 + 1 if inputs[input_name.decode()] == NodeFormat.CHANNELS_LAST else 0 ) except KeyError: raise AssertionError( @@ -296,7 +317,7 @@ def _create_payload_header(self, io_formats, neutron_artifacts) -> np.ndarray: try: header_data.append( 1 - if outputs[output_name.decode()] == TensorFormat.CHANNELS_LAST + if outputs[output_name.decode()] == NodeFormat.CHANNELS_LAST else 0 ) except KeyError: @@ -306,6 +327,7 @@ def _create_payload_header(self, io_formats, neutron_artifacts) -> np.ndarray: header_data.extend(neutron_artifacts.input_indices) header_data.extend(neutron_artifacts.output_indices) + header_data.append(neutron_artifacts.payload_version) # noinspection PyTypeChecker return np.array(header_data, dtype=np.uint8) @@ -335,7 +357,9 @@ def _pack_with_alignment( neutron_artifacts.kernels.tobytes(), ) - def get_binary_payload(self, io_formats, neutron_model) -> bytes: + def get_binary_payload( + self, io_formats: dict[str, list[NodeFormat]], neutron_model + ) -> bytes: """ Get binary payload for provided input/output tensor formats and neutron_model. Returned data have following structure: @@ -355,7 +379,7 @@ def get_binary_payload(self, io_formats, neutron_model) -> bytes: Tensor format definition: '0x1' == CHANNELS_LAST, '0x0' == FORMATLESS (no format). :param io_formats: Dictionary with keys 'inputs' and 'outputs' that contains dictionaries - mapping tensor name to TensorFormat. + mapping tensor name to NodeFormat. :param neutron_model: Neutron model with single NeutronGraph node. :return: 16 bytes aligned binary payload. 
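Note: a worked sketch of the payload header layout described above, for a hypothetical model with two inputs (the first channels-last) and one formatless output, payload version 1. The trailing padding is only illustrative of the 16-byte alignment; the real builder computes per-section padding itself:

import numpy as np

header = []
header.append(2)        # number of inputs (1B)
header.append(1)        # number of outputs (1B)
header.extend([1, 0])   # per-input tensor format: 0x1 = CHANNELS_LAST, 0x0 = FORMATLESS
header.extend([0])      # per-output tensor format
header.extend([0, 1])   # input map: Neutron node input tensor indices
header.extend([2])      # output map: Neutron node output tensor indices
header.append(1)        # payload version (1B)

header = np.array(header, dtype=np.uint8)
padding = (16 - header.size % 16) % 16  # illustrative 16-byte alignment
print(header.tobytes() + b"\x00" * padding)
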
""" diff --git a/backends/nxp/quantizer/neutron_quantizer.py b/backends/nxp/quantizer/neutron_quantizer.py index d3f84144aa3..b9186884d5e 100644 --- a/backends/nxp/quantizer/neutron_quantizer.py +++ b/backends/nxp/quantizer/neutron_quantizer.py @@ -4,15 +4,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List, Optional, Tuple, Union - import torch - from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import ( NeutronAtenPassManager, ) + +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from executorch.backends.nxp.quantizer.patterns import ( AbsPattern, + ActivationsConcatClusterPattern, AdaptiveAvgPoolPattern, AddmmPattern, AddTensorPattern, @@ -20,6 +20,7 @@ CatPattern, Conv1dPattern, Conv2dPattern, + ConvTranspose2dPattern, DropoutPattern, FlattenPattern, HardTanhInPlacePattern, @@ -27,6 +28,9 @@ LinearPattern, MaxPoolPattern, MeanDimPattern, + MmPattern, + MulTensorPattern, + NodeArgsIdx, PadPattern, PermutePattern, QuantizationPattern, @@ -35,9 +39,12 @@ ReshapePattern, SharedSpecPattern, SigmoidPattern, + SliceTensorPattern, SoftMaxPattern, + SubTensorPattern, TanhInPlacePattern, TanhPattern, + TransposeIntPattern, ViewPattern, ) from executorch.backends.nxp.quantizer.utils import ( @@ -47,7 +54,13 @@ ) from torch import fx from torch.ao.quantization.quantizer.utils import _annotate_output_qspec -from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver +from torchao.quantization.pt2e import ( + FakeQuantize, + FusedMovingAvgObsFakeQuantize, + HistogramObserver, + MinMaxObserver, + MovingAverageMinMaxObserver, +) from torchao.quantization.pt2e.quantizer import ( ComposableQuantizer, DerivedQuantizationSpec, @@ -106,13 +119,13 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: ) def annotate_inputs( - inputs: Union[ - List[Tuple[fx.Node, int]], - List[Tuple[fx.Node, int, DerivedQuantizationSpec],], - ], - spec: Optional[QuantizationSpec], + inputs: ( + list[tuple[fx.Node, NodeArgsIdx]] + | list[tuple[fx.Node, NodeArgsIdx, DerivedQuantizationSpec]] + ), + spec: QuantizationSpec | None, ) -> None: - for node, idx, *custom_spec in inputs: + for node, args_idx, *custom_spec in inputs: # pyre-ignore[16]: no attribute annotation = node.meta.get( Q_ANNOTATION_KEY, @@ -120,10 +133,10 @@ def annotate_inputs( ) arg = ( # pyre-ignore[16]: no attribute - node.args[idx] - if isinstance(idx, int) + node.args[args_idx.idx] + if args_idx.inner_idx is None # pyre-ignore[16]: no attribute - else node.args[idx[0]][idx[1]] + else node.args[args_idx.idx][args_idx.inner_idx] ) annotation.input_qspec_map[arg] = ( custom_spec[0] if custom_spec else spec @@ -131,106 +144,136 @@ def annotate_inputs( # pyre-ignore[16]: no attribute node.meta[Q_ANNOTATION_KEY] = annotation - def annotate_weights_or_biases( - weights_or_biases: List[Tuple[fx.Node, int]], - spec: Optional[QuantizationSpec], - ) -> None: - for node, idx, *custom_spec in weights_or_biases: - annotation = node.meta.get( - Q_ANNOTATION_KEY, - QuantizationAnnotation(_annotated=True), - ) - annotation.input_qspec_map[node.args[idx]] = ( - custom_spec[0] if custom_spec else spec - ) - node.meta[Q_ANNOTATION_KEY] = annotation - # pyre-ignore[6]: incompatible parameter type annotate_inputs(anchors.inputs, input_act_qspec) - annotate_weights_or_biases(anchors.weights, weight_qspec) + annotate_inputs(anchors.weights, weight_qspec) # pyre-ignore[6]: incompatible 
parameter type - annotate_weights_or_biases(anchors.biases, bias_qspec) + annotate_inputs(anchors.biases, bias_qspec) return model def validate(self, model: fx.GraphModule) -> None: pass @classmethod - def get_supported_operators(cls) -> List[OperatorConfig]: + def get_supported_operators(cls) -> list[OperatorConfig]: return [] # Quantization Specification used by Neutron NPU -act_qspec = QuantizationSpec( - dtype=torch.int8, - quant_min=-128, - quant_max=127, - qscheme=torch.per_tensor_affine, - is_dynamic=False, - observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12), -) - -wgt_qspec = QuantizationSpec( - dtype=torch.int8, - quant_min=-127, - quant_max=127, - qscheme=torch.per_tensor_symmetric, - is_dynamic=False, - observer_or_fake_quant_ctr=MinMaxObserver, - ch_axis=0, -) +def act_qspec(is_qat: bool): + eps = 2**-12 + observer_or_fake_quant_ctr = ( + FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, eps=eps + ) + if is_qat + else HistogramObserver.with_args(eps=eps) + ) + + return QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=observer_or_fake_quant_ctr, + ) + + +def wgt_qspec(is_qat: bool): + observer_or_fake_quant_ctr = ( + FakeQuantize.with_args(observer=MovingAverageMinMaxObserver) + if is_qat + else MinMaxObserver + ) + + return QuantizationSpec( + dtype=torch.int8, + quant_min=-127, + quant_max=127, + qscheme=torch.per_tensor_symmetric, + is_dynamic=False, + observer_or_fake_quant_ctr=observer_or_fake_quant_ctr, + ch_axis=0, + ) + + +def wgt_fc_qspec(is_qat: bool): + observer_or_fake_quant_ctr = ( + FakeQuantize.with_args(observer=MovingAverageMinMaxObserver) + if is_qat + else MinMaxObserver + ) + + return QuantizationSpec( + dtype=torch.int8, + quant_min=-127, + quant_max=127, + qscheme=torch.per_tensor_symmetric, + is_dynamic=False, + observer_or_fake_quant_ctr=observer_or_fake_quant_ctr, + ) -wgt_fc_qspec = QuantizationSpec( - dtype=torch.int8, - quant_min=-127, - quant_max=127, - qscheme=torch.per_tensor_symmetric, - is_dynamic=False, - observer_or_fake_quant_ctr=MinMaxObserver, -) # Is set by the *PatternQuantizer directly. 
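In short, the specs above are now factories keyed on is_qat; a brief illustration of what each call returns (the observer choices are the ones hard-coded in the functions, nothing new is introduced here):

    ptq_act = act_qspec(is_qat=False)  # int8 affine spec observed with HistogramObserver(eps=2**-12)
    qat_act = act_qspec(is_qat=True)   # same spec, but FusedMovingAvgObsFakeQuantize over MovingAverageMinMaxObserver
    qat_wgt = wgt_qspec(is_qat=True)   # symmetric int8 weights via FakeQuantize(MovingAverageMinMaxObserver)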
bias_qspec = None class NeutronQuantizer(ComposableQuantizer): - def __init__(self): + def __init__(self, neutron_target_spec: NeutronTargetSpec, is_qat: bool = False): + self.neutron_target_spec = neutron_target_spec + self.is_qat = is_qat + static_qconfig = QuantizationConfig( - act_qspec, - act_qspec, - wgt_qspec, + act_qspec(is_qat=is_qat), + act_qspec(is_qat=is_qat), + wgt_qspec(is_qat=is_qat), + None, + ) + static_fc_qconfig = QuantizationConfig( + act_qspec(is_qat=is_qat), + act_qspec(is_qat=is_qat), + wgt_fc_qspec(is_qat=is_qat), None, ) - static_fc_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_fc_qspec, None) + + OpQuantizer = NeutronAtenQuantizer super().__init__( [ - NeutronAtenQuantizer(AbsPattern(), static_qconfig), - NeutronAtenQuantizer(AdaptiveAvgPoolPattern(), static_qconfig), - NeutronAtenQuantizer(AddTensorPattern(), static_qconfig), - NeutronAtenQuantizer(AddmmPattern(), static_fc_qconfig), - NeutronAtenQuantizer(AvgPoolPattern(), static_qconfig), - NeutronAtenQuantizer(CatPattern(), static_qconfig), - NeutronAtenQuantizer(Conv1dPattern(), static_qconfig), - NeutronAtenQuantizer(Conv2dPattern(), static_qconfig), - NeutronAtenQuantizer(DropoutPattern(), static_qconfig), - NeutronAtenQuantizer(FlattenPattern(), static_qconfig), - NeutronAtenQuantizer(HardTanhPattern(), static_qconfig), - NeutronAtenQuantizer(HardTanhInPlacePattern(), static_qconfig), - NeutronAtenQuantizer(LinearPattern(), static_fc_qconfig), - NeutronAtenQuantizer(MaxPoolPattern(), static_qconfig), - NeutronAtenQuantizer(MeanDimPattern(), static_qconfig), - NeutronAtenQuantizer(PadPattern(), static_qconfig), - NeutronAtenQuantizer(PermutePattern(), static_qconfig), - NeutronAtenQuantizer(ReluPattern(), static_qconfig), - NeutronAtenQuantizer(ReluInPlacePattern(), static_qconfig), - NeutronAtenQuantizer(ReshapePattern(), static_qconfig), - NeutronAtenQuantizer(SigmoidPattern(), static_qconfig), - NeutronAtenQuantizer(SoftMaxPattern(), static_qconfig), - NeutronAtenQuantizer(TanhPattern(), static_qconfig), - NeutronAtenQuantizer(TanhInPlacePattern(), static_qconfig), - NeutronAtenQuantizer(ViewPattern(), static_qconfig), + OpQuantizer(AbsPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(AdaptiveAvgPoolPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(AddTensorPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(AddmmPattern(self, is_qat=is_qat), static_fc_qconfig), + OpQuantizer(AvgPoolPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(CatPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(Conv1dPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(Conv2dPattern(self, is_qat=is_qat), static_qconfig), + OpQuantizer(ConvTranspose2dPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(DropoutPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(FlattenPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(HardTanhPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(HardTanhInPlacePattern(is_qat=is_qat), static_qconfig), + OpQuantizer(LinearPattern(self, is_qat=is_qat), static_fc_qconfig), + OpQuantizer(MaxPoolPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(MeanDimPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(MmPattern(self, is_qat=is_qat), static_qconfig), + OpQuantizer(MulTensorPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(PadPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(PermutePattern(is_qat=is_qat), static_qconfig), + OpQuantizer(ReluPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(ReluInPlacePattern(is_qat=is_qat), static_qconfig), + 
OpQuantizer(ReshapePattern(is_qat=is_qat), static_qconfig), + OpQuantizer(SigmoidPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(SliceTensorPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(SoftMaxPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(SubTensorPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(TanhPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(TanhInPlacePattern(is_qat=is_qat), static_qconfig), + OpQuantizer(TransposeIntPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(ViewPattern(is_qat=is_qat), static_qconfig), ] ) + # Mapping ops defined in quantizer partition types to its quantizer self.op_to_quantizer = { pt: q for q in self.quantizers for pt in q.pattern.partition_types() @@ -239,13 +282,18 @@ def __init__(self): self.op_to_applied_quantizer = { pt: False for q in self.quantizers for pt in q.pattern.partition_types() } + self.cluster_quantizers = [ + NeutronAtenQuantizer( + ActivationsConcatClusterPattern(self, is_qat=is_qat), static_qconfig + ) + ] def transform_for_annotation( self, model: torch.fx.GraphModule ) -> torch.fx.GraphModule: model.graph.eliminate_dead_code() # Remove dead code to simplify the graph for the passes. - model = NeutronAtenPassManager()(model).graph_module + model = NeutronAtenPassManager(self.neutron_target_spec)(model).graph_module model.graph.eliminate_dead_code() # Remove dead code again, in case it was created by the passes. @@ -254,6 +302,10 @@ def transform_for_annotation( def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: self._annotate_inputs(model) + # Annotate node clusters in model + for cluster_quantizer in self.cluster_quantizers: + cluster_quantizer.annotate(model) + nodes = list(model.graph.nodes) for node in nodes: if ( @@ -286,7 +338,7 @@ def _annotate_inputs(self, model: fx.GraphModule): continue if node.op == "placeholder" and len(node.users) > 0: - _annotate_output_qspec(node, act_qspec) + _annotate_output_qspec(node, act_qspec(self.is_qat)) self._mark_input_node_as_annotated(node) def validate(self, model: torch.fx.GraphModule) -> None: diff --git a/backends/nxp/quantizer/patterns.py b/backends/nxp/quantizer/patterns.py index 651f995d570..e8f247d4bbc 100644 --- a/backends/nxp/quantizer/patterns.py +++ b/backends/nxp/quantizer/patterns.py @@ -7,26 +7,49 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import List, Optional, Tuple, Type, Union import torch from executorch.backends.nxp.quantizer.utils import get_bias_qparams from torch import fx from torch._ops import OpOverload +from torch.fx import Node +from torchao.quantization.pt2e import ( + FakeQuantize, + MovingAveragePerChannelMinMaxObserver, + PerChannelMinMaxObserver, +) from torchao.quantization.pt2e.quantizer import ( DerivedQuantizationSpec, FixedQParamsQuantizationSpec, + QuantizationSpec, SharedQuantizationSpec, ) + from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY +@dataclass +class NodeArgsIdx: + """ + Specifies indexes to args paramater of Node in node input annotation. + + + Attributes: + idx (int): Index to Node's args paramater (list). Selects an input Node or a list of Nodes at the index. + inner_idx (int): If specified, index to a list pointed by 'idx' attribute. Selects an input Node at the index. + Default: None. 
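As a quick orientation for the new index type, an illustrative fragment (not from the patch) in the style of the patterns later in this file: NodeArgsIdx(0) addresses node.args[0] directly, while NodeArgsIdx(0, i) addresses element i of a list argument, as aten.cat needs.

    def example_cat_anchors(cat_node: fx.Node) -> PartitionAnchors:
        # cat_node is a matched aten.cat call; args[0] holds the list of input tensors.
        inputs = [(cat_node, NodeArgsIdx(0, i)) for i in range(len(cat_node.args[0]))]
        return PartitionAnchors(
            inputs=inputs, weights=[], biases=[], output=[(cat_node,)]
        )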
+ """ + + idx: int + inner_idx: int = None + + @dataclass class PartitionAnchors: """ - All fields except output are lists of (node, args_index) pair, where node is from - the given partition and node.args[args_index] is an input to the partition. Assumes + All fields except output are lists of (node, node_args_idx) or (node, node_args_idx, quantization_spec) tuples, + where node is from the given partition and node.args[node_args_idx] is an input to the partition. Assumes a single output. Quantizer uses inputs, weights and biases for quantization annotation. The others @@ -35,29 +58,34 @@ class PartitionAnchors: """ # Inputs can share quantization parameters - inputs: List[ - Union[ - Tuple[fx.Node, Union[int, Tuple[int, int]]], - Tuple[ - fx.Node, - Union[int, Tuple[int, int]], - SharedQuantizationSpec, - ], - ] + inputs: list[ + tuple[fx.Node, NodeArgsIdx] + | tuple[fx.Node, NodeArgsIdx, SharedQuantizationSpec], ] = field(default_factory=list) - weights: List[Tuple[fx.Node, int]] = field(default_factory=list) - biases: List[ - Union[Tuple[fx.Node, int], Tuple[fx.Node, int, DerivedQuantizationSpec]] + weights: list[ + tuple[fx.Node, NodeArgsIdx] + | tuple[fx.Node, NodeArgsIdx, QuantizationSpec | FakeQuantize], + ] = field(default_factory=list) + biases: list[ + tuple[fx.Node, NodeArgsIdx] + | tuple[fx.Node, NodeArgsIdx, DerivedQuantizationSpec], + ] = field(default_factory=list) + others: list[tuple[fx.Node, NodeArgsIdx]] = field(default_factory=list) + literals: list[tuple[fx.Node, NodeArgsIdx]] = field(default_factory=list) + output: list[ + tuple[fx.Node] + | tuple[ + fx.Node, + FixedQParamsQuantizationSpec | SharedQuantizationSpec, + ], ] = field(default_factory=list) - others: List[Tuple[fx.Node, int]] = field(default_factory=list) - literals: List[Tuple[fx.Node, int]] = field(default_factory=list) - output: List[Union[Tuple[fx.Node], Tuple[fx.Node, SharedQuantizationSpec]]] = field( - default_factory=list - ) empty: bool = False class QuantizationPattern(ABC): + def __init__(self, is_qat: bool = False): + self.is_qat = is_qat + @abstractmethod def partition_types(self) -> list[OpOverload]: """ @@ -67,8 +95,8 @@ def partition_types(self) -> list[OpOverload]: @abstractmethod def get_anchors( - self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule] - ) -> Optional[PartitionAnchors]: + self, gm: torch.fx.GraphModule, fused_partition: list[fx.GraphModule] + ) -> PartitionAnchors | None: pass @@ -80,11 +108,12 @@ class SharedSpecPattern(QuantizationPattern): quantization parameters (scale and zero-point). 
""" - def partition_types(self) -> List[Type[torch.nn.Module]]: + @abstractmethod + def partition_types(self) -> list[torch.nn.Module]: pass def get_anchors( - self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors | None: node = fused_partition[0].nodes[-1] assert len(fused_partition[0].input_nodes) == 1 @@ -97,7 +126,7 @@ def get_anchors( qspec = SharedQuantizationSpec(prev_node) return PartitionAnchors( - inputs=[(node, 0)], + inputs=[(node, NodeArgsIdx(0))], weights=[], biases=[], output=[ @@ -106,17 +135,36 @@ def get_anchors( ) +class SingleInputBasicPattern(QuantizationPattern): + @abstractmethod + def partition_types(self) -> list[OpOverload]: + pass + + def get_anchors( + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] + ) -> PartitionAnchors | None: + node = fused_partition[0].nodes[-1] + + return PartitionAnchors( + inputs=[(node, NodeArgsIdx(0))], + weights=[], + biases=[], + output=[(node,)], + ) + + def get_anchors_for_fixed_quant_specs( fused_partition: list[fx.GraphModule], scale: float, zero_point: int, quant_min: int = -128, quant_max: int = 127, + is_qat: bool = False, ) -> PartitionAnchors: node = fused_partition[0].nodes[-1] assert len(fused_partition[0].input_nodes) == 1 - qspec = FixedQParamsQuantizationSpec( + qspec_or_fake_quantize = FixedQParamsQuantizationSpec( dtype=torch.int8, scale=scale, zero_point=zero_point, @@ -126,11 +174,11 @@ def get_anchors_for_fixed_quant_specs( ) return PartitionAnchors( - inputs=[(node, 0)], + inputs=[(node, NodeArgsIdx(0))], weights=[], biases=[], output=[ - (node, qspec), + (node, qspec_or_fake_quantize), ], ) @@ -154,13 +202,20 @@ def partition_types(self): class AddmmPattern(QuantizationPattern): - def partition_types(self) -> List[OpOverload]: + def __init__(self, neutron_quantizer, is_qat: bool): + super().__init__(is_qat=is_qat) + + self.neutron_quantizer = neutron_quantizer + self.neutron_target_info = ( + self.neutron_quantizer.neutron_target_spec.neutron_target_info + ) + + def partition_types(self) -> list[OpOverload]: return [torch.ops.aten.addmm.default] def get_anchors( - self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors: - # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... addmm_node = fused_partition[0].nodes[-1] bias_qspec = DerivedQuantizationSpec( @@ -175,11 +230,25 @@ def get_anchors( qscheme=torch.per_tensor_affine, ) + # If the following node is a fusable activation, quantize together with activation + output = [(addmm_node,)] + if len( + addmm_node.users + ) == 1 and self.neutron_target_info.is_supported_fused_activation__aten( + activation := next(iter(addmm_node.users)) + ): + activation_quantizer = self.neutron_quantizer.op_to_quantizer[ + activation.target + ] + activation_quantizer.annotate(gm) + output = [] + activation.meta["quantization_annotation"].input_qspec_map = {} + return PartitionAnchors( - inputs=[(addmm_node, 1)], - weights=[(addmm_node, 2)], - biases=[(addmm_node, 0, bias_qspec)], - output=[(addmm_node,)], + inputs=[(addmm_node, NodeArgsIdx(1))], + weights=[(addmm_node, NodeArgsIdx(2))], + biases=[(addmm_node, NodeArgsIdx(0), bias_qspec)], + output=output, ) @@ -190,16 +259,42 @@ class AddTensorPattern(QuantizationPattern): Basic quantization for all inputs and output. 
""" - def partition_types(self) -> List[Type[torch.nn.Module]]: + def partition_types(self) -> list[torch.nn.Module]: return [torch.ops.aten.add.Tensor] def get_anchors( - self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors | None: node = fused_partition[0].nodes[-1] - inputs = [(node, 0)] + inputs = [(node, NodeArgsIdx(0))] if len(fused_partition[0].input_nodes) == 2: - inputs = [(node, 0), (node, 1)] + inputs = [(node, NodeArgsIdx(0)), (node, NodeArgsIdx(1))] + + return PartitionAnchors( + inputs=inputs, + weights=[], + biases=[], + output=[(node,)], + ) + + +class SubTensorPattern(QuantizationPattern): + """ + Quantization pattern for Sub Tensor quantization. Accepts 1 or 2 input nodes. + + Basic quantization for all inputs and output. + """ + + def partition_types(self) -> list[torch.nn.Module]: + return [torch.ops.aten.sub.Tensor] + + def get_anchors( + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] + ) -> PartitionAnchors | None: + node = fused_partition[0].nodes[-1] + inputs = [(node, NodeArgsIdx(0))] + if len(fused_partition[0].input_nodes) == 2: + inputs = [(node, NodeArgsIdx(0)), (node, NodeArgsIdx(1))] return PartitionAnchors( inputs=inputs, @@ -242,13 +337,15 @@ def get_anchors( if quantized_input is not None: inputs = [] for idx, _ in enumerate(node.args[0]): - inputs.append((node, (0, idx), SharedQuantizationSpec(quantized_input))) + inputs.append( + (node, NodeArgsIdx(0, idx), SharedQuantizationSpec(quantized_input)) + ) outputs = [(node, SharedQuantizationSpec(quantized_input))] else: # No previous node was quantized => we are not able to share q-params. The conversion to IR will have to # re-quantize the inputs if necessary. - inputs = [(node, (0, idx)) for idx in range(len(node.args[0]))] + inputs = [(node, NodeArgsIdx(0, idx)) for idx in range(len(node.args[0]))] outputs = [(node,)] return PartitionAnchors( @@ -259,75 +356,179 @@ def get_anchors( ) -class Conv1dPattern(QuantizationPattern): - def partition_types(self) -> List[OpOverload]: - return [torch.ops.aten.conv1d.default] +class ConvPattern(QuantizationPattern): + @abstractmethod + def partition_types(self) -> list[OpOverload]: + pass def get_anchors( - self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors: - # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... 
- conv1d_node = fused_partition[0].nodes[-1] + conv_node = fused_partition[0].nodes[-1] - bias_qspec = DerivedQuantizationSpec( + bias_quantization_qspec = DerivedQuantizationSpec( derived_from=[ - (conv1d_node.args[0], conv1d_node), - (conv1d_node.args[1], conv1d_node), + (conv_node.args[0], conv_node), + (conv_node.args[1], conv_node), ], derive_qparams_fn=get_bias_qparams, dtype=torch.int32, - quant_min=-(2**31), + quant_min=-(2**31) + 1, quant_max=2**31 - 1, - qscheme=torch.per_tensor_affine, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + ) + + weight_observer_or_fake_quant_ctr = ( + FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver) + if self.is_qat + else PerChannelMinMaxObserver + ) + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr, + quant_min=-127, + quant_max=127, + qscheme=torch.per_channel_symmetric, + ch_axis=0, ) # Keep bias empty if not supplied bias = [] - if len(conv1d_node.args) > 2 and conv1d_node.args[2] is not None: - bias = [(conv1d_node, 2, bias_qspec)] + if len(conv_node.args) > 2 and conv_node.args[2] is not None: + bias = [(conv_node, NodeArgsIdx(2), bias_quantization_qspec)] return PartitionAnchors( - inputs=[(conv1d_node, 0)], - weights=[(conv1d_node, 1)], - # pyre-fixme[6]: Incompatible parameter type + inputs=[(conv_node, NodeArgsIdx(0))], + weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)], biases=bias, - output=[(conv1d_node,)], + output=[(conv_node,)], ) -class Conv2dPattern(QuantizationPattern): - def partition_types(self) -> List[OpOverload]: +class Conv1dPattern(ConvPattern): + def partition_types(self) -> list[OpOverload]: + return [torch.ops.aten.conv1d.default] + + +class ConvTranspose1dPattern(ConvPattern): + def partition_types(self) -> list[OpOverload]: + return [torch.ops.aten.conv_transpose1d.default] + + +class Conv2dPattern(ConvPattern): + def __init__(self, neutron_quantizer, is_qat: bool = False): + super().__init__(is_qat=is_qat) + + self.neutron_quantizer = neutron_quantizer + self.neutron_target_info = ( + self.neutron_quantizer.neutron_target_spec.neutron_target_info + ) + + def partition_types(self) -> list[OpOverload]: return [torch.ops.aten.conv2d.default] def get_anchors( - self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors: - # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... 
- conv2d_node = fused_partition[0].nodes[-1] + conv_node = fused_partition[0].nodes[-1] - bias_qspec = DerivedQuantizationSpec( + bias_quantization_qspec = DerivedQuantizationSpec( derived_from=[ - (conv2d_node.args[0], conv2d_node), - (conv2d_node.args[1], conv2d_node), + (conv_node.args[0], conv_node), + (conv_node.args[1], conv_node), ], derive_qparams_fn=get_bias_qparams, dtype=torch.int32, - quant_min=-(2**31), + quant_min=-(2**31) + 1, quant_max=2**31 - 1, - qscheme=torch.per_tensor_affine, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + ) + + weight_observer_or_fake_quant_ctr = ( + FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver) + if self.is_qat + else PerChannelMinMaxObserver + ) + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr, + quant_min=-127, + quant_max=127, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + ) + + # Keep bias empty if not supplied + bias = [] + if len(conv_node.args) > 2 and conv_node.args[2] is not None: + bias = [(conv_node, NodeArgsIdx(2), bias_quantization_qspec)] + + # If the following node is a fusable activation, quantize together with activation + output = [(conv_node,)] + if len( + conv_node.users + ) == 1 and self.neutron_target_info.is_supported_fused_activation__aten( + activation := next(iter(conv_node.users)) + ): + activation_quantizer = self.neutron_quantizer.op_to_quantizer[ + activation.target + ] + activation_quantizer.annotate(gm) + output = [] + activation.meta["quantization_annotation"].input_qspec_map = {} + + return PartitionAnchors( + inputs=[(conv_node, NodeArgsIdx(0))], + weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)], + biases=bias, + output=output, + ) + + +class ConvTranspose2dPattern(QuantizationPattern): + def partition_types(self) -> list[OpOverload]: + return [torch.ops.aten.conv_transpose2d.input] + + def get_anchors( + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] + ) -> PartitionAnchors: + conv_node = fused_partition[0].nodes[-1] + + bias_quantization_qspec = DerivedQuantizationSpec( + derived_from=[ + (conv_node.args[0], conv_node), + (conv_node.args[1], conv_node), + ], + derive_qparams_fn=get_bias_qparams, + dtype=torch.int32, + quant_min=-(2**31) + 1, + quant_max=2**31 - 1, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + ) + + weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr, + quant_min=-127, + quant_max=127, + qscheme=torch.per_channel_symmetric, + ch_axis=1, ) # Keep bias empty if not supplied bias = [] - if len(conv2d_node.args) > 2 and conv2d_node.args[2] is not None: - bias = [(conv2d_node, 2, bias_qspec)] + if len(conv_node.args) > 2 and conv_node.args[2] is not None: + bias = [(conv_node, NodeArgsIdx(2), bias_quantization_qspec)] return PartitionAnchors( - inputs=[(conv2d_node, 0)], - weights=[(conv2d_node, 1)], - # pyre-fixme[6]: Incompatible parameter type + inputs=[(conv_node, NodeArgsIdx(0))], + weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)], biases=bias, - output=[(conv2d_node,)], + output=[(conv_node,)], ) @@ -349,47 +550,33 @@ def partition_types(self): return [torch.ops.aten.flatten.using_ints] -class HardTanhPattern(QuantizationPattern): +class HardTanhPattern(SingleInputBasicPattern): """ - Quantizer for HardTanh operator. 
Shared quantization spec is selected, as activation functions usually follows - computation layer. + Quantizer for HardTanh operator. """ def partition_types(self): return [torch.ops.aten.hardtanh.default] - def get_anchors( - self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] - ) -> PartitionAnchors | None: - node = fused_partition[0].nodes[-1] - - return PartitionAnchors( - inputs=[(node, 0)], - weights=[], - biases=[], - output=[(node,)], - ) - def replacement_op(self): raise AssertionError() -class HardTanhInPlacePattern(QuantizationPattern): +class HardTanhInPlacePattern(SingleInputBasicPattern): """ - Quantizer for HardTanh operator with param inplace=True. Shared quantization spec is selected, as activation - functions usually follows computation layer. + Quantizer for HardTanh operator with param inplace=True. """ def partition_types(self): return [torch.ops.aten.hardtanh_.default] def get_anchors( - self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors | None: node = fused_partition[0].nodes[-1] return PartitionAnchors( - inputs=[(node, 0)], + inputs=[(node, NodeArgsIdx(0))], weights=[], biases=[], output=[(node,)], @@ -400,13 +587,20 @@ def replacement_op(self): class LinearPattern(QuantizationPattern): - def partition_types(self) -> List[OpOverload]: + def __init__(self, neutron_quantizer, is_qat: bool = False): + super().__init__(is_qat=is_qat) + + self.neutron_quantizer = neutron_quantizer + self.neutron_target_info = ( + self.neutron_quantizer.neutron_target_spec.neutron_target_info + ) + + def partition_types(self) -> list[OpOverload]: return [torch.ops.aten.linear.default] def get_anchors( - self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors: - # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... 
linear_node = fused_partition[0].nodes[-1] bias_qspec = DerivedQuantizationSpec( @@ -424,14 +618,29 @@ def get_anchors( # Keep bias empty if not supplied bias = [] if len(linear_node.args) > 2: - bias = [(linear_node, 2, bias_qspec)] + bias = [(linear_node, NodeArgsIdx(2), bias_qspec)] + + # If the following node is a fusable activation, quantize together with activation + output = [(linear_node,)] + if ( + len(linear_node.users) == 1 + and len(linear_node.meta["val"].shape) <= 2 + and self.neutron_target_info.is_supported_fused_activation__aten( + activation := next(iter(linear_node.users)) + ) + ): + activation_quantizer = self.neutron_quantizer.op_to_quantizer[ + activation.target + ] + activation_quantizer.annotate(gm) + output = [] + activation.meta["quantization_annotation"].input_qspec_map = {} return PartitionAnchors( - inputs=[(linear_node, 0)], - weights=[(linear_node, 1)], - # pyre-fixme[6]: Incompatible parameter type + inputs=[(linear_node, NodeArgsIdx(0))], + weights=[(linear_node, NodeArgsIdx(1))], biases=bias, - output=[(linear_node,)], + output=output, ) @@ -453,6 +662,88 @@ def partition_types(self): return [torch.ops.aten.mean.dim] +class MmPattern(QuantizationPattern): + def __init__(self, neutron_quantizer, is_qat: bool = False): + super().__init__(is_qat=is_qat) + + self.neutron_quantizer = neutron_quantizer + self.neutron_target_info = ( + self.neutron_quantizer.neutron_target_spec.neutron_target_info + ) + + def partition_types(self) -> list[OpOverload]: + return [torch.ops.aten.mm.default] + + def get_anchors( + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] + ) -> PartitionAnchors: + mm_node = fused_partition[0].nodes[-1] + + # If the following node is a fusable activation, quantize together with activation + output = [(mm_node,)] + if len( + mm_node.users + ) == 1 and self.neutron_target_info.is_supported_fused_activation__aten( + activation := next(iter(mm_node.users)) + ): + activation_quantizer = self.neutron_quantizer.op_to_quantizer[ + activation.target + ] + activation_quantizer.annotate(gm) + output = [] + activation.meta["quantization_annotation"].input_qspec_map = {} + + return PartitionAnchors( + inputs=[(mm_node, NodeArgsIdx(0))], + weights=[(mm_node, NodeArgsIdx(1))], + biases=[], + output=output, + ) + + +class MulTensorPattern(QuantizationPattern): + """ + Quantization pattern for Mul Tensor quantization. Accepts 1 or 2 input nodes. + + Basic quantization for all inputs and output. + """ + + def partition_types(self) -> list[torch.nn.Module]: + return [torch.ops.aten.mul.Tensor] + + def get_anchors( + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] + ) -> PartitionAnchors | None: + node = fused_partition[0].nodes[-1] + input_nodes = node.all_input_nodes + + qspec = FixedQParamsQuantizationSpec( + dtype=torch.int8, + scale=1.0 / 256.0, + zero_point=0, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_affine, + ) + + # The "Mul" operator in Neutron IR requires a specific scale and zero_point + # (defined above) for its inputs. + # Since these input nodes have already been annotated by their own patterns + # which didn't take the requirements of "Mul" into account, we need to overwrite + # the existing "quantization_annotation". 
+ for input_node in input_nodes: + input_node.meta["quantization_annotation"].output_qspec = qspec + + return PartitionAnchors( + inputs=[(node, NodeArgsIdx(0), qspec), (node, NodeArgsIdx(1), qspec)], + weights=[], + biases=[], + output=[ + (node,), + ], + ) + + class PadPattern(SharedSpecPattern): """ Quantizer for Pad operator. @@ -471,19 +762,27 @@ def partition_types(self): return [torch.ops.aten.permute.default] -class ReluPattern(SharedSpecPattern): +class TransposeIntPattern(SharedSpecPattern): + """ + Quantizer for Transpose Int operator. + """ + + def partition_types(self) -> list[OpOverload]: + return [torch.ops.aten.transpose.int] + + +class ReluPattern(SingleInputBasicPattern): """ - Quantizer for Relu operator. Shared quantization spec is selected, as ReLU usually follows computation layer. + Quantizer for Relu operator. """ def partition_types(self): return [torch.ops.aten.relu.default] -class ReluInPlacePattern(SharedSpecPattern): +class ReluInPlacePattern(SingleInputBasicPattern): """ - Quantizer for Relu operator with param inplace=True. Shared quantization spec is selected, as ReLU usually - follows computation layer. + Quantizer for Relu operator with param inplace=True. """ def partition_types(self): @@ -508,6 +807,15 @@ def partition_types(self): return [torch.ops.aten.view.default] +class SliceTensorPattern(SharedSpecPattern): + """ + Quantizer for Slice operator. + """ + + def partition_types(self): + return [torch.ops.aten.slice.Tensor] + + class SoftMaxPattern(QuantizationPattern): """ Quantizer for Softmax operator. @@ -515,14 +823,32 @@ class SoftMaxPattern(QuantizationPattern): The quantization of Softmax output is fixed to scale 1/256, zero point -128, dtype int8. """ - def partition_types(self) -> List[OpOverload]: + def partition_types(self) -> list[OpOverload]: return [torch.ops.aten.softmax.int] def get_anchors( self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors: return get_anchors_for_fixed_quant_specs( - fused_partition, scale=1.0 / 256.0, zero_point=-128 + fused_partition, scale=1.0 / 256.0, zero_point=-128, is_qat=self.is_qat + ) + + +class SigmoidPattern(QuantizationPattern): + """ + Quantizer for Sigmoid operator. + + The quantization of Sigmoid output is fixed to scale 1/256, zero point -128, dtype int8. + """ + + def partition_types(self) -> list[OpOverload]: + return [torch.ops.aten.sigmoid.default] + + def get_anchors( + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] + ) -> PartitionAnchors: + return get_anchors_for_fixed_quant_specs( + fused_partition, scale=1.0 / 256.0, zero_point=-128, is_qat=self.is_qat ) @@ -540,7 +866,7 @@ def get_anchors( self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors: return get_anchors_for_fixed_quant_specs( - fused_partition, scale=1.0 / 128.0, zero_point=0 + fused_partition, scale=1.0 / 128.0, zero_point=0, is_qat=self.is_qat ) @@ -558,23 +884,151 @@ def get_anchors( self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors: return get_anchors_for_fixed_quant_specs( - fused_partition, scale=1.0 / 128.0, zero_point=0 + fused_partition, scale=1.0 / 128.0, zero_point=0, is_qat=self.is_qat ) -class SigmoidPattern(QuantizationPattern): +class ActivationsConcatClusterPattern(QuantizationPattern): """ - Quantizer for Sigmoid operator. - - The quantization of Sigmoid output is fixed to scale 1/256, zero point -128, dtype int8. + Quantizer for activations concat cluster pattern. 
+ + The quantizer matches a pattern where concat node is preceded by activation nodes preceded by Conv 2D or Linear. + All activation nodes quantization parameters must be the same. Only activations, that have support for fusion + to preceding compute node on Neutron are allowed. This cluster is usually produced by MoveActivationBeforeConcat + pass. Cluster schema: + + │ │ + ┌──────▼──────┐ ┌──────▼──────┐ + │ aten.conv2d │ ... │ aten.conv2d │ + └──────┬──────┘ └──────┬──────┘ + │ │ + ┌─────▼─────┐ ┌─────▼─────┐ + │ aten.relu │ ... │ aten.relu │ + └─────┬─────┘ └─────┬─────┘ + └───────┐ ┌───────┘ + ┌──▼─────▼─┐ + │ aten.cat │ + └────┬─────┘ + │ """ - def partition_types(self) -> List[OpOverload]: - return [torch.ops.aten.sigmoid.default] + def __init__(self, neutron_quantizer, is_qat: bool = False): + super().__init__(is_qat=is_qat) + + self.neutron_quantizer = neutron_quantizer + self.neutron_target_info = ( + self.neutron_quantizer.neutron_target_spec.neutron_target_info + ) + + @staticmethod + def _all_activations_are_equal(activations: list[Node]) -> bool: + first_input_node = activations[0] + hardtanh_t = [ + torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, + ] + relu_t = [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + ] + tanh_t = [ + torch.ops.aten.tanh.default, + torch.ops.aten.tanh_.default, + ] + + def _activations_are_equal(activation1: Node, activation2: Node) -> bool: + if ( # Targets are equal also with their inplace variants + (activation1.target in hardtanh_t and activation2.target in hardtanh_t) + or (activation1.target in relu_t and activation2.target in relu_t) + or (activation1.target in tanh_t and activation2.target in tanh_t) + or ( + activation1.target == torch.ops.aten.sigmoid.default + and activation2.target == torch.ops.aten.sigmoid.default + ) + ): + return True + elif ( # Hardtanh with min_val 0 and max_val 'inf' is equal to Relu + activation1.target in hardtanh_t + and activation1.args[1:] == (0.0, float("inf")) + and activation2.target in relu_t + ) or ( + activation1.target in relu_t + and activation2.target in hardtanh_t + and activation2.args[1:] == (0.0, float("inf")) + ): + return True + else: + return False + + return all( + _activations_are_equal(activation, first_input_node) + for activation in activations + ) + + def partition_types(self) -> list[OpOverload]: + return [torch.ops.aten.cat.default] def get_anchors( - self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] - ) -> PartitionAnchors: - return get_anchors_for_fixed_quant_specs( - fused_partition, scale=1.0 / 256.0, zero_point=-128 + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] + ) -> PartitionAnchors | None: + cat_node = fused_partition[0].nodes[-1] + + # Check all cat inputs are supported activations + if not all( + self.neutron_target_info.is_supported_fused_activation__aten(input_node) + for input_node in cat_node.all_input_nodes + ): + return None + + # Check all cat inputs are equal activations + if not self._all_activations_are_equal(cat_node.all_input_nodes): + return None + + # Check compute nodes are Conv 2D or Linear + if not all( + self.neutron_target_info.is_fusable_conv_or_linear__aten(compute_node) + for input_node in cat_node.all_input_nodes + for compute_node in input_node.all_input_nodes + ): + return None + + # Annotate compute nodes + for input_node in cat_node.all_input_nodes: + for compute_node in input_node.all_input_nodes: + if compute_node.target not in self.neutron_quantizer.op_to_quantizer: + return 
None + compute_node_quantizer = self.neutron_quantizer.op_to_quantizer[ + compute_node.target + ] + compute_node_quantizer.annotate(gm) + del compute_node.meta["quantization_annotation"].output_qspec + + # Annotate activations + for input_node in cat_node.all_input_nodes: + if input_node.target not in self.neutron_quantizer.op_to_quantizer: + return None + activation_quantizer = self.neutron_quantizer.op_to_quantizer[ + input_node.target + ] + activation_quantizer.annotate(gm) + input_node.meta["quantization_annotation"].input_qspec_map = {} + + # Annotate cat node + inputs = [] + first_input_node = cat_node.all_input_nodes[0] + for idx in range(len(cat_node.all_input_nodes)): + inputs.append( + ( + cat_node, + NodeArgsIdx(0, idx), + SharedQuantizationSpec(first_input_node), + ) + ) + outputs = [(cat_node, SharedQuantizationSpec(first_input_node))] + + return PartitionAnchors( + inputs=inputs, + weights=[], + biases=[], + output=outputs, ) diff --git a/backends/nxp/quantizer/utils.py b/backends/nxp/quantizer/utils.py index ed94183c2db..6dc58e8114a 100644 --- a/backends/nxp/quantizer/utils.py +++ b/backends/nxp/quantizer/utils.py @@ -9,17 +9,25 @@ import itertools from collections import OrderedDict +from collections.abc import Iterable from typing import Any, Dict, List, Tuple, Type import torch from torch import fx from torch._ops import OpOverload +from torch.ao.quantization import move_exported_model_to_eval +from torch.export import ExportedProgram from torch.fx.passes.utils.source_matcher_utils import ( check_subgraphs_connected, SourcePartition, ) from torchao.quantization.pt2e import ObserverOrFakeQuantize -from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY +from torchao.quantization.pt2e.quantize_pt2e import ( + convert_pt2e, + prepare_pt2e, + prepare_qat_pt2e, +) +from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY, Quantizer def is_annotated(nodes: List[fx.Node]) -> bool: @@ -49,7 +57,7 @@ def get_bias_qparams( act_scale, _ = obs_or_fqs[0].calculate_qparams() weight_scale, _ = obs_or_fqs[1].calculate_qparams() bias_scale = act_scale * weight_scale - bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32) + bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int64) return bias_scale, bias_zero_point @@ -149,3 +157,37 @@ def find_sequential_partitions_aten( if _partitions_sequential(candidate): fused_partitions.append(candidate) return fused_partitions + + +def calibrate_and_quantize( + model: ExportedProgram | fx.GraphModule, + calibration_inputs: Iterable[tuple[torch.Tensor, ...]], + quantizer: Quantizer, + is_qat: bool = False, +) -> fx.GraphModule: + """Quantize the provided model. + + :param model: Aten model (or it's GraphModule representation) to quantize. + :param calibration_inputs: Either a tuple of calibration input tensors where each element corresponds to a model + input. Or an iterator over such tuples. + :param quantizer: Quantizer to use. + :param is_qat: Whether quantization is done using Quantization Aware Training (QAT) or not. + Note: In QAT mode, training is not performed. Only calibration (in eval mode) is done. + + :return: Quantized GraphModule. 
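A usage sketch for this helper (the module and calibration data are placeholders; the target and converter flavor mirror the defaults added to backends/nxp/tests/executorch_pipeline.py in this diff):

    import torch
    from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
    from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
    from executorch.backends.nxp.quantizer.utils import calibrate_and_quantize

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
    example_input = (torch.randn(1, 3, 32, 32),)
    exported = torch.export.export(model, example_input, strict=True)

    quantizer = NeutronQuantizer(
        NeutronTargetSpec("imxrt700", "SDK_25_09"), is_qat=False
    )

    # PTQ path: prepare_pt2e -> calibration loop -> convert_pt2e, all inside the helper.
    quantized_gm = calibrate_and_quantize(
        exported,
        calibration_inputs=[example_input],
        quantizer=quantizer,
        is_qat=False,
    )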
+ """ + + if isinstance(model, ExportedProgram): + model = model.module() + + if is_qat: + m = prepare_qat_pt2e(model, quantizer) + m = move_exported_model_to_eval(m) + else: + m = prepare_pt2e(model, quantizer) + + for data in calibration_inputs: + m(*data) + m = convert_pt2e(m) + + return m diff --git a/backends/nxp/requirements-tests-eiq.txt b/backends/nxp/requirements-tests-eiq.txt index 896d2b8c07e..1fccf010e86 100644 --- a/backends/nxp/requirements-tests-eiq.txt +++ b/backends/nxp/requirements-tests-eiq.txt @@ -1,2 +1,2 @@ --index-url https://eiq.nxp.com/repository -neutron_converter_SDK_25_06 +neutron_converter_SDK_25_09 diff --git a/backends/nxp/runtime/NeutronBackend.cpp b/backends/nxp/runtime/NeutronBackend.cpp index 3568ab72580..4bf23324ef5 100644 --- a/backends/nxp/runtime/NeutronBackend.cpp +++ b/backends/nxp/runtime/NeutronBackend.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2024 NXP + * Copyright 2024-2025 NXP * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -10,6 +10,7 @@ #include #include #include +#include #include "NeutronDriver.h" #include "NeutronErrors.h" @@ -19,7 +20,6 @@ using namespace std; namespace torch { namespace executor { namespace neutron { - // All the memory need to be aligned with 16 #define BUFFER_ALIGNMENT 16 #define ALIGN_SIZE(size) \ @@ -38,6 +38,8 @@ namespace neutron { +----------------------------------------+------------------------------------------+ | 1st output map (1B) | [nth* output map (1B)] | +----------------------------------------+------------------------------------------+ + | Payload version (1B) | + +-----------------------------------------------------------------------------------+ */ // clang-format on #define ITEM_SIZE 1 // 1 Byte @@ -53,10 +55,13 @@ namespace neutron { #define OUTPUT_TENSOR_MAP_ARRAY_ADDR(base) \ (base + 3 * ITEM_SIZE + 2 * base[INPUT_TENSOR_FORMAT_LEN_POS] + \ 1 * base[OUTPUT_TENSOR_FORMAT_LEN_POS]) +#define PAYLOAD_VERSION_ADDR(base) \ + (base + 3 * ITEM_SIZE + 2 * base[INPUT_TENSOR_FORMAT_LEN_POS] + \ + 2 * base[OUTPUT_TENSOR_FORMAT_LEN_POS]) #define PAYLOAD_ADDR(base) \ (base + \ ALIGN_SIZE( \ - 3 * ITEM_SIZE + 2 * base[INPUT_TENSOR_FORMAT_LEN_POS] + \ + 4 * ITEM_SIZE + 2 * base[INPUT_TENSOR_FORMAT_LEN_POS] + \ 2 * base[OUTPUT_TENSOR_FORMAT_LEN_POS])) // Aggregate neutron model handle and data structures into one. 
@@ -65,6 +70,8 @@ typedef struct { int numOutputs = 0; int numInputArgs = 0; uint32_t scratchSize = 0; + uint32_t profileSize = 0; + uint32_t debugSize = 0; NeutronModelConfig mcfg; NeutronDataConfig dcfg; NeutronModelHandle nmh = NULL; @@ -269,6 +276,7 @@ class NeutronBackend final : public PyTorchBackendInterface { OUTPUT_TENSOR_FORMAT_ARRAY_ADDR(payloadFlags); cfg->inputMap = INPUT_TENSOR_MAP_ARRAY_ADDR(payloadFlags); cfg->outputMap = OUTPUT_TENSOR_MAP_ARRAY_ADDR(payloadFlags); + uint8_t payloadVersion = *PAYLOAD_VERSION_ADDR(payloadFlags); const uint32_t* buffer = static_cast( static_cast PAYLOAD_ADDR(payloadFlags)); @@ -282,9 +290,28 @@ class NeutronBackend final : public PyTorchBackendInterface { } uint32_t microcodeSize = buffer[6]; uint32_t weightsSize = buffer[7]; - cfg->scratchSize = buffer[9]; - cfg->numInputs = buffer[11]; - cfg->numOutputs = buffer[12]; + switch (payloadVersion) { + case 0: + cfg->scratchSize = buffer[9]; + cfg->profileSize = 0; + cfg->debugSize = 0; + cfg->numInputs = buffer[11]; + cfg->numOutputs = buffer[12]; + break; + case 1: + cfg->scratchSize = buffer[9]; + cfg->profileSize = buffer[10]; + cfg->debugSize = buffer[11]; + cfg->numInputs = buffer[13]; + cfg->numOutputs = buffer[14]; + break; + default: + ET_LOG( + Error, + "Unknown payload version %d. Please update the backend", + payloadVersion); + return Error::InvalidProgram; + } if (cfg->numInputs != numInputs) { ET_LOG( Error, @@ -336,27 +363,60 @@ class NeutronBackend final : public PyTorchBackendInterface { // Allocate place for input and output pointers. cfg->dcfg.inputs = static_cast( context.allocate(cfg->numInputs * sizeof(void*))); - cfg->dcfg.outputs = - static_cast(context.allocate(cfg->numOutputs * sizeof(void*))); + // There are 3 extra entries: scratch, profile and debug. The scratch + // pointer was allocated implicitly in the previous versions. + cfg->dcfg.outputs = static_cast( + context.allocate((cfg->numOutputs + 3) * sizeof(void*))); cfg->dcfg.outputs[cfg->numOutputs] = static_cast(context.allocate(cfg->scratchSize, 16)); + cfg->dcfg.outputs[cfg->numOutputs + 1] = + static_cast(context.allocate(cfg->profileSize, 16)); + cfg->dcfg.outputs[cfg->numOutputs + 2] = + static_cast(context.allocate(cfg->debugSize, 16)); // Set inputs from args. // Transpose inputs if needed. for (int i = 0; i < cfg->numInputs; i++) { auto arg = args[cfg->inputMap[i]]->toTensor(); + auto dim_order = arg.dim_order().data(); + if (cfg->inputTranspositionFlags[i] && multipleChannelsPresent(arg.sizes())) { + // The input must be transposed. if (arg.sizes().size() < 3) { ET_LOG(Error, "Unable to transpose 1D and 2D input to channel last"); return Error::InvalidProgram; } - // Allocate buffer, the allocator is reset after each PTE instruction. - void* buffer = context.allocate(arg.nbytes()); - transposeInput( - arg.const_data_ptr(), buffer, arg.sizes(), arg.element_size()); - cfg->dcfg.inputs[i] = buffer; + + if (is_channels_last_dim_order(dim_order, arg.dim())) { + // The tensor is already permuted. + ET_LOG(Info, "Using channels last dim order for input %d.\n", i); + cfg->dcfg.inputs[i] = arg.const_data_ptr(); + } else if (is_contiguous_dim_order(dim_order, arg.dim())) { + // Transpose the data to channels last. + + ET_LOG(Info, "Transposing input %d to channels last.\n", i); + + // Allocate buffer, the allocator is reset after each PTE instruction. 
+ void* buffer = context.allocate(arg.nbytes(), 16); + transposeInput( + arg.const_data_ptr(), buffer, arg.sizes(), arg.element_size()); + cfg->dcfg.inputs[i] = buffer; + } else { + // Unexpected dim-order. + ET_LOG(Error, "Input %d uses unsupported dim-order.", i); + return Error::InvalidProgram; + } } else { + // The input matches the ExecuTorch format, so no transposition is + // needed. + + if (!is_contiguous_dim_order(dim_order, arg.dim())) { + // Unexpected dim-order. + ET_LOG(Error, "Input %d uses unsupported dim-order.", i); + return Error::InvalidProgram; + } + cfg->dcfg.inputs[i] = arg.const_data_ptr(); } } @@ -365,12 +425,35 @@ class NeutronBackend final : public PyTorchBackendInterface { // Redirect outputs if needed before transposition. for (int i = 0; i < cfg->numOutputs; i++) { auto arg = args[cfg->numInputArgs + cfg->outputMap[i]]->toTensor(); + auto dim_order = arg.dim_order().data(); + if (cfg->outputTranspositionFlags[i] && multipleChannelsPresent(arg.sizes())) { - // Allocate buffer, the allocator is reset after each PTE instruction. - void* buffer = context.allocate(arg.nbytes()); - cfg->dcfg.outputs[i] = buffer; + // The output will have to be transposed. + + if (is_channels_last_dim_order(dim_order, arg.dim())) { + // The tensor will already be correctly permuted. No transposition + // needed. + cfg->dcfg.outputs[i] = arg.mutable_data_ptr(); + } else if (is_contiguous_dim_order(dim_order, arg.dim())) { + // Allocate buffer, the allocator is reset after each PTE instruction. + void* buffer = context.allocate(arg.nbytes(), 16); + cfg->dcfg.outputs[i] = buffer; + } else { + // Unexpected dim-order. + ET_LOG(Error, "Output %d uses unsupported dim-order.", i); + return Error::InvalidProgram; + } } else { + // The tensor should match the ExecuTorch required format, so no + // transposition is needed. + + if (!is_contiguous_dim_order(dim_order, arg.dim())) { + // Unexpected dim-order. + ET_LOG(Error, "Output %d uses unsupported dim-order.", i); + return Error::InvalidProgram; + } + cfg->dcfg.outputs[i] = arg.mutable_data_ptr(); } } @@ -394,18 +477,35 @@ class NeutronBackend final : public PyTorchBackendInterface { // Transpose outputs. for (int i = 0; i < cfg->numOutputs; i++) { auto arg = args[cfg->numInputArgs + cfg->outputMap[i]]->toTensor(); + if (cfg->outputTranspositionFlags[i] && multipleChannelsPresent(arg.sizes())) { + // The output must be transposed. + if (arg.sizes().size() < 3) { ET_LOG( Error, "Unable to transpose 1D and 2D output to channel first"); return Error::InvalidProgram; } - transposeOutput( - cfg->dcfg.outputs[i], - arg.mutable_data_ptr(), - arg.sizes(), - arg.element_size()); + + auto dim_order = arg.dim_order().data(); + if (is_channels_last_dim_order(dim_order, arg.dim())) { + // The rest of the model expects the `channels_last` dim order, which + // the data already matches. + ET_LOG(Info, "Using channels last dim order for output %d.\n", i); + } else if (is_contiguous_dim_order(dim_order, arg.dim())) { + // Transpose the data to channels first. + ET_LOG(Info, "Transposing output %d to channels first.\n", i); + transposeOutput( + cfg->dcfg.outputs[i], + arg.mutable_data_ptr(), + arg.sizes(), + arg.element_size()); + } else { + // Unexpected dim-order. 
+ ET_LOG(Error, "Output %d uses unsupported dim-order.", i); + return Error::InvalidProgram; + } } } @@ -434,7 +534,6 @@ auto backend = NeutronBackend(); Backend backend_id{"NeutronBackend", &backend}; static auto registered = register_backend(backend_id); } // namespace - } // namespace neutron } // namespace executor -} // namespace torch +} // namespace torch \ No newline at end of file diff --git a/backends/nxp/runtime/NeutronDriver.h b/backends/nxp/runtime/NeutronDriver.h index 5ae4c3a3ff9..5c47bd74eab 100644 --- a/backends/nxp/runtime/NeutronDriver.h +++ b/backends/nxp/runtime/NeutronDriver.h @@ -18,22 +18,6 @@ extern "C" { #include "NeutronErrors.h" -/* Neutron Driver error category codes */ -typedef enum ERROR_CATEGORY_DRIVER { - ERROR_CATEGORY_DRIVER_GENERIC, /* Generic error category */ - ERROR_CATEGORY_DRIVER_UNSUPPORTED, /* Unsupported function */ - ERROR_CATEGORY_DRIVER_UCODE, /* Microcode bad magic or version incompatible. - */ - ERROR_CATEGORY_DRIVER_INVALID, /* Invalid arguments */ - ERROR_CATEGORY_DRIVER_BAD_HANDLE, /* Bad inference handle */ - ERROR_CATEGORY_DRIVER_NO_MEMORY, /* Not enough memory */ - ERROR_CATEGORY_DRIVER_INTERNAL_FAULT, /* Internal error */ - ERROR_CATEGORY_DRIVER_UNKNOWN_ARCH, /* Unknown architecture */ - ERROR_CATEGORY_DRIVER_TRACE_NOT_RUN, /* Tracing did not run, but trace buffer - was requested. */ - ERROR_CATEGORY_DRIVER_TIMEOUT /* Timeout error. */ -} ERROR_CATEGORY_DRIVER; - /// Trace configuration to enable kernel level tracing. #define TRACE_CONFIG_KERNEL_LEVEL (1U << 0) @@ -169,6 +153,12 @@ NeutronError neutronCustomExec( NeutronModelHandle hdl, const NeutronDataConfig* neutron_dcfg); +/// - Setup the input and output data ptr to use Neutron memory area. +/// - The input and ouput data ptr is stored in neutron_dcfg. +NeutronError neutronDataSetup( + NeutronModelHandle hdl, + NeutronDataConfig* neutron_dcfg); + /// - Prepare Neutron execution for a model with the given configuration. /// - This function only prepares the execution by transferring the parameters /// to the firmware. @@ -245,6 +235,29 @@ void* neutronMemAlloc(size_t alignment, size_t size); /// - This function is only available for Neutron-S in the Linux environment. void neutronMemFree(void* ptr); +/// - Allocates size bytes large buffer in DDR to be used for specialized +/// kernels (e.g. batch matmul) +/// Uses Linux CMA allocator +NeutronError allocateBuffer(uint64_t size, void** pBuffer, bool userspace); + +/// - Frees buffer allocated via allocateBuffer function +NeutronError releaseBuffer(void* buffer); + +/// - Clean/flush cache for DDR allocated buffer +/// TODO: rename function as "cleanCache" to satisfy neutron-software naming +/// convention +NeutronError clean_cache(const void* addr, int size); + +/// - Function for calling firmware for specialized kernel (matmul) +NeutronError matmul( + const void* info, + int sizeInfo, + const void* in, + int sizeIn, + const void* out, + int sizeOut, + int idxSlot); + /// Other functions to control the state of driver/firmware. 
#ifdef __cplusplus } diff --git a/backends/nxp/runtime/NeutronErrors.h b/backends/nxp/runtime/NeutronErrors.h index 5141c4bb4c5..071db8b44be 100644 --- a/backends/nxp/runtime/NeutronErrors.h +++ b/backends/nxp/runtime/NeutronErrors.h @@ -39,6 +39,32 @@ typedef enum ERROR_COMPONENT_ID { ERROR_COMPONENT_DRIVER = 0x3 } ERROR_COMPONENT_ID; +/* Neutron Firmware error category codes */ +typedef enum ERROR_CATEGORY_FW { + ERROR_CATEGORY_FW_GENERIC, /* Generic error category */ + ERROR_CATEGORY_FW_UCODE, /* Microcode bad magic or version incompatible. */ + ERROR_CATEGORY_FW_BUFFER_OVERFLOW, /* Buffer overflow error category */ + ERROR_CATEGORY_FW_NULL_POINTER, /* Pointer is null */ + ERROR_CATEGORY_FW_INTR_ERROR, /* Interrupt triggering error */ + ERROR_CATEGORY_FW_DMAPI_ERROR, /* DM API parameter error */ +} ERROR_CATEGORY_FW; + +/* Neutron Driver error category codes */ +typedef enum ERROR_CATEGORY_DRIVER { + ERROR_CATEGORY_DRIVER_GENERIC, /* Generic error category */ + ERROR_CATEGORY_DRIVER_UNSUPPORTED, /* Unsupported function */ + ERROR_CATEGORY_DRIVER_UCODE, /* Microcode bad magic or version incompatible. + */ + ERROR_CATEGORY_DRIVER_INVALID, /* Invalid arguments */ + ERROR_CATEGORY_DRIVER_BAD_HANDLE, /* Bad inference handle */ + ERROR_CATEGORY_DRIVER_NO_MEMORY, /* Not enough memory */ + ERROR_CATEGORY_DRIVER_INTERNAL_FAULT, /* Internal error */ + ERROR_CATEGORY_DRIVER_UNKNOWN_ARCH, /* Unknown architecture */ + ERROR_CATEGORY_DRIVER_TRACE_NOT_RUN, /* Tracing did not run, but trace buffer + was requested. */ + ERROR_CATEGORY_DRIVER_TIMEOUT /* Timeout error. */ +} ERROR_CATEGORY_DRIVER; + /// Retrieve component name as string from NeutronError code. char* getNeutronErrorComponent(NeutronError ne); diff --git a/backends/nxp/runtime/targets.bzl b/backends/nxp/runtime/targets.bzl index 1eacbbe0a2b..3214761a9cb 100644 --- a/backends/nxp/runtime/targets.bzl +++ b/backends/nxp/runtime/targets.bzl @@ -1,20 +1,25 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +load("@fbsource//tools/target_determinator/macros:ci.bzl", "ci") def define_common_targets(): runtime.cxx_library( - name = "nxp_backend", + name = "nxp_backend_base", srcs = ["NeutronBackend.cpp"], - headers = ["NeutronDriver.h", "NeutronErrors.h"], - compatible_with = ["ovr_config//cpu:arm32-embedded", "@fbsource//arvr/firmware/projects/smartglasses/config:embedded-mcu-rtos"], - # Neutron runtime needs to compile with executor as whole - # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) + exported_headers = [ + "NeutronDriver.h", + "NeutronErrors.h", + ], link_whole = True, # Constructor needed for backend registration. compiler_flags = ["-Wno-global-constructors", "-fno-rtti", "-DNO_HEAP_USAGE"], - visibility = ["@EXECUTORCH_CLIENTS"], + labels = [ci.skip_target()], + visibility = [ + "//executorch/backends/nxp/runtime/fb:nxp_fb_backend", + "//executorch/backends/nxp/runtime/fb:nxp_hifi_fb_backend", + "@EXECUTORCH_CLIENTS", + ], deps = [ "//executorch/runtime/backend:interface", "//executorch/runtime/core:core", - "fbsource//arvr/third-party/toolchains/nxp-sdk/2.16.0/middleware/eiq/executorch/third-party/neutron/rt700:libNeutron", ], ) diff --git a/backends/nxp/tests/executorch_pipeline.py b/backends/nxp/tests/executorch_pipeline.py index 881bdeeec7b..61af7b5c67f 100644 --- a/backends/nxp/tests/executorch_pipeline.py +++ b/backends/nxp/tests/executorch_pipeline.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. 
from dataclasses import dataclass +from functools import partial from typing import Callable import torch @@ -12,24 +13,35 @@ from executorch.backends.nxp.backend.custom_delegation_options import ( CustomDelegationOptions, ) -from executorch.backends.nxp.backend.ir.edge_passes.remove_io_quant_ops_pass import ( - RemoveIOQuantOpsPass, -) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from executorch.backends.nxp.edge_passes.neutron_edge_pass_manager import ( NeutronEdgePassManager, ) +from executorch.backends.nxp.edge_passes.remove_additional_quantize_dequantize_nodes_pass import ( + RemoveAdditionalQDQClustersPass, +) +from executorch.backends.nxp.edge_passes.remove_io_quant_ops_pass import ( + RemoveIOQuantOpsPass, +) from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer +from executorch.backends.nxp.quantizer.utils import calibrate_and_quantize from executorch.exir import ( EdgeCompileConfig, EdgeProgramManager, ExecutorchBackendConfig, ExecutorchProgramManager, + to_edge_transform_and_lower, ) -from executorch.extension.export_util.utils import export_to_edge from torch import nn -from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torch.export import export +from torchao.quantization.pt2e.quantizer import Quantizer + +neutron_converter_flavor = "SDK_25_09" +neutron_target_spec = NeutronTargetSpec( + target="imxrt700", neutron_converter_flavor=neutron_converter_flavor +) @dataclass @@ -38,17 +50,6 @@ class ModelInputSpec: dtype: torch.dtype = torch.float32 -def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor, ...]]): - quantizer = NeutronQuantizer() - - m = prepare_pt2e(model, quantizer) - for data in calibration_inputs: - m(*data) - m = convert_pt2e(m) - - return m - - def get_random_calibration_inputs( input_spec: tuple[ModelInputSpec, ...] ) -> list[tuple[torch.Tensor, ...]]: @@ -58,10 +59,13 @@ def get_random_calibration_inputs( ] +def _get_default_quantizer(target_spec: NeutronTargetSpec, use_qat: bool) -> Quantizer: + return NeutronQuantizer(target_spec, is_qat=use_qat) + + def to_model_input_spec( input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]] ) -> tuple[ModelInputSpec, ...]: - if isinstance(input_spec, tuple) and all( isinstance(spec, ModelInputSpec) for spec in input_spec ): @@ -88,12 +92,20 @@ def to_quantized_edge_program( [tuple[ModelInputSpec, ...]], list[tuple[torch.Tensor, ...]] ] = get_random_calibration_inputs, target="imxrt700", - neutron_converter_flavor="SDK_25_06", + neutron_converter_flavor=neutron_converter_flavor, + use_qat=False, remove_quant_io_ops=False, custom_delegation_options=CustomDelegationOptions(), # noqa B008 + get_quantizer_fn=None, + use_neutron_for_format_conversion=True, ) -> EdgeProgramManager: - calibration_inputs = get_calibration_inputs_fn(to_model_input_spec(input_spec)) + _neutron_target_spec = NeutronTargetSpec(target, neutron_converter_flavor) + if get_quantizer_fn is None: + get_quantizer_fn = partial( + _get_default_quantizer, _neutron_target_spec, use_qat + ) + calibration_inputs = get_calibration_inputs_fn(to_model_input_spec(input_spec)) example_input = calibration_inputs[0] # Make sure the model is in the evaluation mode. 
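# A minimal usage sketch for the quantizer hook introduced above (illustrative only;
# `SomeQuantizer` is a hypothetical torchao `Quantizer` implementation, not part of
# this module):
#
#     def my_quantizer_fn() -> Quantizer:
#         return SomeQuantizer()
#
#     edge_pm = to_quantized_edge_program(
#         model, (1, 3, 32, 32), get_quantizer_fn=my_quantizer_fn
#     )
#
# When `get_quantizer_fn` is left as None, `_get_default_quantizer` supplies a
# NeutronQuantizer for the selected target, with `is_qat` driven by the `use_qat` flag.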
@@ -101,40 +113,56 @@ def to_quantized_edge_program( exir_program_aten = torch.export.export(model, example_input, strict=True) - exir_program_aten__module_quant = _quantize_model( - exir_program_aten.module(), calibration_inputs + exir_program_aten__module_quant = calibrate_and_quantize( + model=exir_program_aten, + calibration_inputs=calibration_inputs, + quantizer=get_quantizer_fn(), + is_qat=use_qat, ) - edge_compile_config = EdgeCompileConfig(_check_ir_validity=False) - edge_program_manager = export_to_edge( - exir_program_aten__module_quant, - example_input, - edge_compile_config=edge_compile_config, - ) - - edge_program_manager = NeutronEdgePassManager()(edge_program_manager) - compile_spec = generate_neutron_compile_spec( target, operators_not_to_delegate=operators_not_to_delegate, neutron_converter_flavor=neutron_converter_flavor, + use_neutron_for_format_conversion=use_neutron_for_format_conversion, + ) + partitioners = [ + NeutronPartitioner( + compile_spec, _neutron_target_spec, custom_delegation_options + ) + ] + + edge_program_manager = to_edge_transform_and_lower( + export(exir_program_aten__module_quant, example_input, strict=True), + transform_passes=NeutronEdgePassManager(), + partitioner=partitioners, + compile_config=EdgeCompileConfig(_check_ir_validity=False), ) - partitioner = NeutronPartitioner(compile_spec, custom_delegation_options) - edge_program_manager = edge_program_manager.to_backend(partitioner) if remove_quant_io_ops: edge_program_manager = edge_program_manager.transform( [RemoveIOQuantOpsPass(edge_program_manager=edge_program_manager)] ) + edge_program_manager = edge_program_manager.transform( + NeutronEdgePassManager([RemoveAdditionalQDQClustersPass()]) + ) + return edge_program_manager def to_quantized_executorch_program( model: torch.nn.Module, input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]], + use_qat: bool = False, + use_neutron_for_format_conversion: bool = True, ) -> ExecutorchProgramManager: - edge_program_manager = to_quantized_edge_program(model, input_spec) + edge_program_manager = to_quantized_edge_program( + model, + input_spec, + use_qat=use_qat, + use_neutron_for_format_conversion=use_neutron_for_format_conversion, + ) return edge_program_manager.to_executorch( config=ExecutorchBackendConfig(extract_delegate_segments=False) diff --git a/backends/nxp/tests/executors.py b/backends/nxp/tests/executors.py index afdb15af106..fa99046ff33 100644 --- a/backends/nxp/tests/executors.py +++ b/backends/nxp/tests/executors.py @@ -1,4 +1,4 @@ -# Copyright 2023-2024 NXP +# Copyright 2023-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -18,15 +18,14 @@ create_channels_first_to_channels_last_permutation, create_channels_last_to_channels_first_permutation, ) -from executorch.backends.nxp.backend.ir.converter.node_converter import ( - NodeConverter, - Target, -) +from executorch.backends.nxp.backend.ir.converter.node_converter import NodeConverter +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec + +from executorch.backends.nxp.backend.node_format_inference import NodeFormatInference from torch.export import ExportedProgram from torch.fx import Node from torch.fx.graph import Graph - # If executed on i.MX platform, there is no tensorflow module. 
And typically the intention is to use the tflite python # interpreter available in tflite_runtime try: @@ -196,6 +195,11 @@ def compare_output_arrays( assert tfl_output.shape == edge_output.shape, "Output shapes don't match!" + if (max_diff := np.abs(np.max(tfl_output - edge_output))) > 0.0: + logger.w( + f"Maximum absolute difference of the tensor '{output_name}': '{max_diff}'" + ) + assert np.allclose( tfl_output, edge_output, rtol=rtol, atol=atol, equal_nan=True ), f"Output values of the `{output_name}` tensor don't match!" @@ -305,6 +309,7 @@ def convert_run_compare( ) -> (TFLiteExecutor, EdgeProgramExecutor): if tfl_model is None: + NodeFormatInference(edge_program).identify_node_formats() tfl_model, _ = EdgeProgramToIRConverter().convert_program( edge_program, conversion_config ) @@ -365,10 +370,16 @@ def convert_run_compare( def graph_contains_any_of_ops(graph: Graph, ops: list) -> bool: - return any(node.target in ops for node in graph.nodes) + return graph_contains_any( + graph, condition=lambda n: hasattr(n, "target") and n.target in ops + ) + + +def graph_contains_any(graph: Graph, condition: Callable[[Node], bool]) -> bool: + return any(map(condition, graph.nodes)) -target_support_check_function = Callable[[Node, Target], bool] +target_support_check_function = Callable[[Node, NeutronTargetSpec], bool] class OverrideTargetSupportCheck: diff --git a/backends/nxp/tests/exported_program_vizualize.py b/backends/nxp/tests/exported_program_vizualize.py deleted file mode 100644 index 0f4b8db697c..00000000000 --- a/backends/nxp/tests/exported_program_vizualize.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2024 NXP -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import random - -from gvgen import GvGen -from torch.export import ExportedProgram - - -def exported_program_to_dot( # noqa C901 - exported_program: ExportedProgram, dot_file_name="graph.dot", show_tags=True -): - """ - Generate dot file for tagged exported program. - - :param exported_program: Exported program with optional meta values: 'delegation_tag' and 'cluster'. - :param dot_file_name: Produced .dot file name. - :param show_tags: If True, nodes will be shown as a subcomponent of tag nodes. 
- """ - graph = GvGen() - - def name_color(string): # pseudo-randomization function - h = hash(string) # hash string and int together - if h < 0: # ensure positive number - h = h * -1 - random.seed(h) # set the seed to use for randomization - r = int(random.random() * 255) - g = int(random.random() * 255) - b = int(random.random() * 255) - return "#%02x%02x%02x" % (r, g, b) - - graph_items = {} - delegation_tags = {} - - # Find tags (parent objects) - for node in exported_program.graph.nodes: - if "delegation_tag" in node.meta and show_tags: - tag = node.meta["delegation_tag"] - if tag not in delegation_tags: - item = graph.newItem(tag) - delegation_tags[tag] = item - - for node in exported_program.graph.nodes: - if "delegation_tag" in node.meta and show_tags: - # Delegated node -> add color - tag = node.meta["delegation_tag"] - item = graph.newItem(node.name, delegation_tags[tag]) - - graph.propertyAppend(item, "fillcolor", name_color(tag)) - graph.propertyAppend(item, "style", "filled") - else: - item = graph.newItem(node.name) - - label = graph.propertyGet(item, "label") - if "cluster" in node.meta: - graph.propertyAppend( - item, "label", label + "\n QDQ Cluster: " + node.meta["cluster"] - ) - - # Change shape of node for (de)quantize and rest of nodes - if any(q in label for q in ["_quantize_per_tensor_", "_quantize_per_channel_"]): - graph.propertyAppend(item, "shape", "invhouse") - elif any( - dq in label - for dq in ["_dequantize_per_tensor_", "_dequantize_per_channel_"] - ): - graph.propertyAppend(item, "shape", "house") - else: - graph.propertyAppend(item, "shape", "box") - - graph_items[node.name] = item - - # Add connections between nodes - for node in exported_program.graph.nodes: - for user in node.users: - link = graph.newLink(graph_items[node.name], graph_items[user.name]) - - label = "" - if "val" in node.meta: - tensor = node.meta["val"] - if isinstance(tensor, tuple): - tensor = tensor[0] # Fake tensor - label = f" ({list(tensor.shape)} | {tensor.dtype})" - - graph.propertyAppend(link, "label", label) - - with open(dot_file_name, "w") as f: - graph.dot(f) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py index 315c76a7614..2e9a1b393ff 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py @@ -14,11 +14,13 @@ from executorch.backends.nxp.tests.executors import ( convert_run_compare, graph_contains_any_of_ops, - ToNCHWPreprocess, - ToNHWCPreprocess, + ToChannelFirstPreprocess, + ToChannelLastPreprocess, ) + from executorch.exir.dialects._ops import ops as exir_ops from torch.export import ExportedProgram +from executorch.backends.nxp.tests.use_qat import * # noqa F403 @pytest.fixture(autouse=True) @@ -62,12 +64,14 @@ def forward(self, x): return x.abs() -def test_conv_abs(mocker, input_shape: tuple[int] = (1, 3, 112, 112)): +def test_conv_abs(mocker, use_qat, input_shape: tuple[int] = (1, 3, 112, 112)): model = ConvBlocksWithAbs(conv_in_channels=input_shape[1]) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - quantized_program = to_quantized_edge_program(model, input_shape).exported_program() + quantized_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat, use_neutron_for_format_conversion=False + ).exported_program() tflite_flatbuffers_model, io_formats = converter_spy.spy_return exported_program: ExportedProgram = 
converter_spy.call_args.args[1] @@ -80,8 +84,8 @@ def test_conv_abs(mocker, input_shape: tuple[int] = (1, 3, 112, 112)): convert_run_compare( exported_program, tfl_model=tflite_flatbuffers_model, - tflite_input_preprocess=ToNHWCPreprocess(), - tflite_output_preprocess=ToNCHWPreprocess(), + tflite_input_preprocess=ToChannelLastPreprocess(), + tflite_output_preprocess=ToChannelFirstPreprocess(), input_data=input_data, atol=1.0, ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_adaptive_avg_pool2d_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_adaptive_avg_pool2d_converter.py index 9c8235f7eda..db5cbdcbb5e 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_adaptive_avg_pool2d_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_adaptive_avg_pool2d_converter.py @@ -16,6 +16,7 @@ AdaptiveAvgPool2dConvModule, ) from torch.export import ExportedProgram +from executorch.backends.nxp.tests.use_qat import * # noqa F403 @pytest.fixture(autouse=True) @@ -40,14 +41,16 @@ def reseed_model_per_test_run(): ], ) def test_adaptive_avg_pool_2d_delegated_quant_conversion( - mocker, input_shape, output_size + mocker, input_shape, output_size, use_qat ): model = AdaptiveAvgPool2dConvModule(output_size) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - edge_program = to_quantized_edge_program(model, input_shape).exported_program() + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat, use_neutron_for_format_conversion=False + ).exported_program() nodes = [str(node) for node in edge_program.graph.nodes] # Input size is a multiple of output size, can be converted to AveragePool, node is delegated @@ -84,14 +87,16 @@ def test_adaptive_avg_pool_2d_delegated_quant_conversion( ], ) def test_adaptive_avg_pool_2d_non_delegated_quant_conversion( - mocker, input_shape, output_size + mocker, input_shape, output_size, use_qat ): model = AdaptiveAvgPool2dConvModule(output_size) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - edge_program = to_quantized_edge_program(model, input_shape).exported_program() + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat, use_neutron_for_format_conversion=False + ).exported_program() nodes = list(edge_program.graph.nodes) # Input size is not a multiple of output size, cannot be converted to AveragePool, node is not delegated @@ -115,14 +120,16 @@ def test_adaptive_avg_pool_2d_non_delegated_quant_conversion( ) -def test_adaptive_avg_pool_2d_mean_dim_quant_conversion(mocker): +def test_adaptive_avg_pool_2d_mean_dim_quant_conversion(mocker, use_qat): input_shape = (1, 4, 16, 16) model = AdaptiveAvgPool2dConvMeanDimModule() converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program( + model, input_shape, use_qat=use_qat, use_neutron_for_format_conversion=False + ) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return diff --git a/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py index 567b593e05b..1aa58ab5d95 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py @@ -1,3 +1,7 @@ +# Copyright 2025 NXP +# +# This source code is 
licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. import numpy as np import pytest import torch @@ -17,6 +21,7 @@ AddTensorOneInputModule, ) from torch.export import ExportedProgram +from executorch.backends.nxp.tests.use_qat import * # noqa F403 @pytest.fixture(autouse=True) @@ -34,13 +39,13 @@ def reseed_model_per_test_run(): pytest.param((1, 4, 8, 8), id="4D."), ], ) -def test_add_tensor_quant_conversion(mocker, input_shape): +def test_add_tensor_quant_conversion(mocker, input_shape, use_qat): model = AddTensorModule() converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, [input_shape, input_shape]) + _ = to_quantized_edge_program(model, [input_shape, input_shape], use_qat=use_qat) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -65,13 +70,13 @@ def test_add_tensor_quant_conversion(mocker, input_shape): pytest.param((1, 4, 8, 8), id="4D."), ], ) -def test_add_tensor_one_input_quant_conversion(mocker, input_shape): +def test_add_tensor_one_input_quant_conversion(mocker, input_shape, use_qat): model = AddTensorOneInputModule() converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program(model, input_shape, use_qat=use_qat) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -93,13 +98,15 @@ def test_add_tensor_one_input_quant_conversion(mocker, input_shape): pytest.param((1, 4, 5, 5), id="4D, product of dims is not a multiple of 8."), ], ) -def test_add_tensor_w_conv_quant_conversion(mocker, input_shape): +def test_add_tensor_w_conv_quant_conversion(mocker, input_shape, use_qat): model = AddTensorConvModule() converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program( + model, input_shape, use_qat=use_qat, use_neutron_for_format_conversion=False + ) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -131,13 +138,13 @@ def test_add_tensor_w_conv_quant_conversion(mocker, input_shape): ], ) def test_add_tensor_broadcasting_unsupported_quant_conversion( - x_input_shape, y_input_shape + x_input_shape, y_input_shape, use_qat ): model = AddTensorModule() # Run conversion edge_program = to_quantized_edge_program( - model, [x_input_shape, y_input_shape] + model, [x_input_shape, y_input_shape], use_qat=use_qat ).exported_program() nodes = list(edge_program.graph.nodes) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_addmm_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_addmm_converter.py new file mode 100644 index 00000000000..a8cdee41830 --- /dev/null +++ b/backends/nxp/tests/ir/converter/node_converter/test_addmm_converter.py @@ -0,0 +1,96 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
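+# Tests for converting `aten.addmm` (produced by `AddmmModule` and `LinearModule`)
+# to the Neutron backend, parameterized over QAT and PTQ quantization.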
+ +import unittest + +import kgb +import numpy as np +import torch + +from executorch.backends.nxp.backend.edge_program_converter import ( + EdgeProgramToIRConverter, +) +from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program +from executorch.backends.nxp.tests.executors import ( + convert_run_compare, + graph_contains_any_of_ops, +) +from executorch.backends.nxp.tests.models import AddmmModule, LinearModule +from executorch.exir.dialects._ops import ops as exir_ops +from parameterized import parameterized +from torch.export import ExportedProgram + + +class TestAddmmConversion(unittest.TestCase): + @classmethod + def setUpClass(cls): + torch.manual_seed(23) + np.random.seed(42) + + @parameterized.expand([("QAT", True), ("PTQ", False)]) + def test_addmm_conversion(self, _, use_qat: bool): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, + call_original=True, + owner=EdgeProgramToIRConverter, + ) as converter_spy: + input_shape = (1, 32) + model = AddmmModule(input_shape[1]) + + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + # Make sure that all nodes were delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.addmm.default] + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + ) + + @parameterized.expand([("QAT", True), ("PTQ", False)]) + def test_linear_conversion__with_bias(self, _, use_qat: bool): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, + call_original=True, + owner=EdgeProgramToIRConverter, + ) as converter_spy: + input_shape = (10, 32) + model = LinearModule(bias=True) + + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + # Make sure that all nodes were delegated. 
+ assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.addmm.default] + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py index bcdbd955c71..b6083d1e816 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py @@ -6,10 +6,11 @@ import numpy as np import pytest import torch - from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) + +from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.backend.ir.converter.builder.model_builder import ( ModelBuilder, ) @@ -27,6 +28,7 @@ ) from executorch.backends.nxp.tests.models import AvgPool2dConvModule, AvgPool2dModule from torch.export import ExportedProgram +from executorch.backends.nxp.tests.use_qat import * # noqa F403 @pytest.fixture(autouse=True) @@ -91,6 +93,9 @@ def test_avg_pool_2d_conversion(input_shape, padding, count_include_pad): input_data, tflite_input_preprocess=ToNHWCPreprocess(), tflite_output_preprocess=ToNCHWPreprocess(), + conversion_config=ConversionConfig( + {"use_neutron_for_format_conversion": False} + ), ) @@ -139,13 +144,17 @@ def test_avg_pool_2d_conversion(input_shape, padding, count_include_pad): ), ], ) -def test_avg_pool_2d_quant_conversion(mocker, input_shape, padding, count_include_pad): +def test_avg_pool_2d_quant_conversion( + mocker, input_shape, padding, count_include_pad, use_qat +): model = AvgPool2dConvModule(padding=padding, count_include_pad=count_include_pad) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program( + model, input_shape, use_qat=use_qat, use_neutron_for_format_conversion=False + ) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -164,7 +173,7 @@ def test_avg_pool_2d_quant_conversion(mocker, input_shape, padding, count_includ ) -def test_avg_pool_2d_quant_conversion__padded(mocker): +def test_avg_pool_2d_quant_conversion__padded(mocker, use_qat): input_shape = (1, 8, 8, 8) model = AvgPool2dModule(True, 1) @@ -172,7 +181,9 @@ def test_avg_pool_2d_quant_conversion__padded(mocker): ops_spy = mocker.spy(ModelBuilder, "finish") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program( + model, input_shape, use_qat=use_qat, use_neutron_for_format_conversion=False + ) # Capture the converter operators. 
ops = ops_spy.spy_return.sub_graphs[0].operators.vector diff --git a/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py index 3df703f5bba..e3ee2fff90b 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py @@ -17,9 +17,12 @@ from executorch.backends.nxp.tests.executors import ( convert_run_compare, graph_contains_any_of_ops, + ToNCHWPreprocess, + ToNHWCPreprocess, ) from executorch.exir.dialects._ops import ops as exir_ops from torch.export import ExportedProgram +from executorch.backends.nxp.tests.use_qat import * # noqa F403 def _normalized_dim(dim, rank): @@ -42,6 +45,18 @@ def forward(self, *inputs: torch.Tensor): return torch.cat(list(inputs), self.dim) +class AddCatModule(torch.nn.Module): + + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, *inputs: torch.Tensor): + inputs = [input_ + input_ for input_ in inputs] + + return torch.cat(list(inputs), self.dim) + + class CatConvModule(torch.nn.Module): def __init__(self, dim: int, channels: int = 4): @@ -70,13 +85,13 @@ def forward(self, *inputs: torch.Tensor): pytest.param(4, 5, -3, id="4D, 5 inputs, dim=-3"), ], ) -def test_cat__same_shapes(dim, num_inputs, rank, mocker): - input_shape = tuple([2, 8, 8, 8, 8][-rank:]) +def test_cat__same_shapes(dim, num_inputs, rank, mocker, use_qat): + input_shape = tuple([8, 8, 8, 8][:rank]) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") quantized_program = to_quantized_edge_program( - CatModule(dim), [input_shape] * num_inputs + CatModule(dim), [input_shape] * num_inputs, use_qat=use_qat ).exported_program() # Make sure the `Cat` was delegated. @@ -101,13 +116,13 @@ def test_cat__same_shapes(dim, num_inputs, rank, mocker): @pytest.mark.parametrize("dim", [3, -2, -3]) @pytest.mark.parametrize("num_inputs", [2, 5]) -def test_cat__channels_first__same_shapes(dim, num_inputs, mocker): +def test_cat__channels_first__same_shapes(dim, num_inputs, mocker, use_qat): input_shape = (2, 8, 6, 8) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") channels = input_shape[1] if dim not in {1, -3} else input_shape[1] * num_inputs quantized_program = to_quantized_edge_program( - CatConvModule(dim, channels), [input_shape] * num_inputs + CatConvModule(dim, channels), [input_shape] * num_inputs, use_qat=use_qat ).exported_program() # Make sure the `Cat` was delegated. 
@@ -126,17 +141,31 @@ def test_cat__channels_first__same_shapes(dim, num_inputs, mocker): exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data, + tflite_input_preprocess=ToNHWCPreprocess(), + tflite_output_preprocess=ToNCHWPreprocess(), atol=1, ) -@pytest.mark.parametrize("dim", [0, -4]) -@pytest.mark.parametrize("num_inputs", [2]) -def test_cat__unsupported_dim__imxrt700(dim, num_inputs): - input_shape = (2, 8, 6, 8) - +@pytest.mark.parametrize( + "dim, input_shape", + [ + pytest.param(0, (1, 8, 8, 8), id="axis = 0"), + pytest.param(0, (8, 8, 8, 8), id="axis = 0, no `1s` in the shape."), + pytest.param(-4, (1, 8, 8, 8), id="axis = -4"), + pytest.param(1, (1, 1, 8, 8), id="axis = 1"), + pytest.param(-3, (1, 1, 8, 8), id="axis = -3"), + pytest.param(2, (1, 1, 1, 8), id="axis = 2"), + pytest.param(-2, (1, 1, 1, 8), id="axis = -2"), + ], +) +def test_cat__unsupported__imxrt700(dim, input_shape, use_qat): + """This test is conjoined with the one below (`test_cat__context_dependent__imxrt700`). + In this case, the inputs of the `cat` are NOT compute ops, so the `cat` is NOT delegated. + """ + num_inputs = 2 quantized_program = to_quantized_edge_program( - CatModule(dim), [input_shape] * num_inputs, target="imxrt700" + CatModule(dim), [input_shape] * num_inputs, target="imxrt700", use_qat=use_qat ).exported_program() # Make sure the `Cat` was NOT delegated. @@ -148,6 +177,35 @@ def test_cat__unsupported_dim__imxrt700(dim, num_inputs): ) +@pytest.mark.parametrize( + "dim, input_shape", + [ + pytest.param(0, (1, 8, 8, 8), id="axis = 0"), + pytest.param(0, (8, 8, 8, 8), id="axis = 0, no `1s` in the shape."), + pytest.param(-4, (1, 8, 8, 8), id="axis = -4"), + pytest.param(1, (1, 1, 8, 8), id="axis = 1"), + pytest.param(-3, (1, 1, 8, 8), id="axis = -3"), + pytest.param(2, (1, 1, 1, 8), id="axis = 2"), + pytest.param(-2, (1, 1, 1, 8), id="axis = -2"), + ], +) +def test_cat__context_dependent__imxrt700(dim, input_shape, use_qat): + """This test is conjoined with the one above (`test_cat__unsupported__imxrt700`). + In this case, the inputs of the `cat` are compute ops, so the `cat` is delegated. + """ + num_inputs = 2 + ep = to_quantized_edge_program( + AddCatModule(dim), + [input_shape] * num_inputs, + target="imxrt700", + use_qat=use_qat, + ).exported_program() + + # Make sure the `Cat` was delegated. + assert not graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.cat.default]) + assert any("lowered_module" in node.name for node in ep.graph.nodes) + + @pytest.mark.parametrize( "rank, num_inputs, dim", [ @@ -164,7 +222,7 @@ def test_cat__unsupported_dim__imxrt700(dim, num_inputs): pytest.param(4, 5, -3, id="4D, 5 inputs, dim=-3"), ], ) -def test_cat__different_shapes(dim, num_inputs, rank, mocker): +def test_cat__different_shapes(dim, num_inputs, rank, mocker, use_qat): input_shape = tuple([2, 8, 8, 8, 8][-rank:]) # The shape of every input will be different along the concatenated dimension. @@ -177,7 +235,7 @@ def test_cat__different_shapes(dim, num_inputs, rank, mocker): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") quantized_program = to_quantized_edge_program( - CatModule(dim), input_shapes + CatModule(dim), input_shapes, use_qat=use_qat ).exported_program() # Make sure the `Cat` was delegated. 
@@ -204,7 +262,7 @@ def test_cat__different_shapes(dim, num_inputs, rank, mocker): @pytest.mark.parametrize( "num_inputs", [2, 5], ids=lambda num_inputs: f"num_inputs = {num_inputs}" ) -def test_cat__channels_first__different_shapes(dim, num_inputs, mocker): +def test_cat__channels_first__different_shapes(dim, num_inputs, mocker, use_qat): input_shape = (2, 8, 6, 8) # The shape of every input will be different along the concatenated dimension. @@ -222,7 +280,7 @@ def test_cat__channels_first__different_shapes(dim, num_inputs, mocker): sum(shape[1] for shape in input_shapes) if dim in [1, -3] else input_shape[1] ) quantized_program = to_quantized_edge_program( - CatConvModule(dim, channels), input_shapes + CatConvModule(dim, channels), input_shapes, use_qat=use_qat ).exported_program() # Make sure the `Cat` was delegated. @@ -241,11 +299,13 @@ def test_cat__channels_first__different_shapes(dim, num_inputs, mocker): exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data, + tflite_input_preprocess=ToNHWCPreprocess(), + tflite_output_preprocess=ToNCHWPreprocess(), atol=1, ) -def test_cat__different_shapes__unsupported_channels__imxrt700(): +def test_cat__different_shapes__unsupported_channels__imxrt700(use_qat): input_shape = (2, 4, 6, 7) # (channels % 8) != 0 num_inputs = 2 @@ -259,7 +319,7 @@ def test_cat__different_shapes__unsupported_channels__imxrt700(): input_shapes.append(tuple(tmp_shape)) quantized_program = to_quantized_edge_program( - CatModule(dim), input_shapes, target="imxrt700" + CatModule(dim), input_shapes, target="imxrt700", use_qat=use_qat ).exported_program() # Make sure the `Cat` was NOT delegated. @@ -271,7 +331,7 @@ def test_cat__different_shapes__unsupported_channels__imxrt700(): ) -def test_cat__force_delegate(): +def test_cat__force_delegate(use_qat): target = "imxrt700" # The Partitioner doesn't know if the `8` or the `1` will become the channels in the IR. Therefore, it would @@ -283,6 +343,51 @@ def test_cat__force_delegate(): [input_shape, input_shape], target=target, custom_delegation_options=CustomDelegationOptions(force_delegate_cat=True), + use_qat=use_qat, + ).exported_program() + + # Make sure the `Cat` was delegated. + assert not graph_contains_any_of_ops( + graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default] + ) + assert any("lowered_module" in node.name for node in quantized_program.graph.nodes) + + +def test_cat__same_shapes_converter_padding_last_dimension(use_qat): + target = "imxrt700" + + # The Converter is capable of padding the last dimension of `cat` with the same input shapes. + input_shape = (3, 1, 3) + + quantized_program = to_quantized_edge_program( + CatModule(2), + [input_shape, input_shape], + target=target, + neutron_converter_flavor="SDK_25_09", + custom_delegation_options=CustomDelegationOptions(), + use_qat=use_qat, + ).exported_program() + + # Make sure the `Cat` was delegated. + assert not graph_contains_any_of_ops( + graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default] + ) + assert any("lowered_module" in node.name for node in quantized_program.graph.nodes) + + +def test_cat__same_shapes__channels_first__padding_channels(use_qat): + target = "imxrt700" + + # The Converter is capable of padding the last dimension of `cat` with the same input shapes. 
+ input_shape = (1, 2, 3, 4) + + quantized_program = to_quantized_edge_program( + CatConvModule(1), + [input_shape, input_shape], + target=target, + neutron_converter_flavor="SDK_25_09", + custom_delegation_options=CustomDelegationOptions(), + use_qat=use_qat, ).exported_program() # Make sure the `Cat` was delegated. @@ -290,3 +395,101 @@ def test_cat__force_delegate(): graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default] ) assert any("lowered_module" in node.name for node in quantized_program.graph.nodes) + + +def test_cat__same_shapes_converter_padding_middle_dimension(use_qat): + target = "imxrt700" + + # The Converter is not capable of padding the middle dimensions of `cat` with the same input shapes. + input_shape = (3, 1, 3) + + quantized_program = to_quantized_edge_program( + CatModule(1), + [input_shape, input_shape], + target=target, + custom_delegation_options=CustomDelegationOptions(), + use_qat=use_qat, + ).exported_program() + + # Make sure the `Cat` was NOT delegated. + assert graph_contains_any_of_ops( + graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default] + ) + assert not any( + "lowered_module" in node.name for node in quantized_program.graph.nodes + ) + + +def test_cat__format_specific_support__formatless(mocker, use_qat): + # The last dim will end up being the channels, as the format is `formatless`. + # Only the last dim satisfies the Neutron requirements for the channels. + input_shape = (3, 3, 3, 8) + num_inputs = 2 + dim = 2 + + input_shapes = [input_shape] * num_inputs + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + quantized_program = to_quantized_edge_program( + CatModule(dim), input_shapes, use_qat=use_qat + ).exported_program() + + # Make sure the `Cat` was delegated. + assert not graph_contains_any_of_ops( + graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default] + ) + assert any("lowered_module" in node.name for node in quantized_program.graph.nodes) + + tflite_flatbuffers_model, io_formats = converter_spy.spy_return + exported_program: ExportedProgram = converter_spy.call_args.args[1] + input_data = { + i: (np.random.random(shape) * 50).astype(np.int8) + for i, shape in enumerate(input_shapes) + } + convert_run_compare( + exported_program, + tfl_model=tflite_flatbuffers_model, + input_data=input_data, + atol=1, + ) + + +def test_cat__format_specific_support__channels_first(mocker, use_qat): + # The second dim will end up being the channels, as the format is `formatless`. + # Only the second dim satisfies the Neutron requirements for the channels. + input_shape = (3, 8, 3, 3) + num_inputs = 2 + dim = 2 + + input_shapes = [input_shape] * num_inputs + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + channels = ( + sum(shape[1] for shape in input_shapes) if dim in [1, -3] else input_shape[1] + ) + quantized_program = to_quantized_edge_program( + CatConvModule(dim, channels), input_shapes, use_qat=use_qat + ).exported_program() + + # Make sure the `Cat` was delegated. 
+ assert not graph_contains_any_of_ops( + graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default] + ) + assert any("lowered_module" in node.name for node in quantized_program.graph.nodes) + + tflite_flatbuffers_model, io_formats = converter_spy.spy_return + exported_program: ExportedProgram = converter_spy.call_args.args[1] + input_data = { + i: (np.random.random(shape) * 50).astype(np.int8) + for i, shape in enumerate(input_shapes) + } + convert_run_compare( + exported_program, + tfl_model=tflite_flatbuffers_model, + input_data=input_data, + tflite_input_preprocess=ToNHWCPreprocess(), + tflite_output_preprocess=ToNCHWPreprocess(), + atol=1, + ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py index f5945607f1b..250ddb88212 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py @@ -4,31 +4,33 @@ # LICENSE file in the root directory of this source tree. +import itertools +import unittest + +import kgb import numpy as np -import pytest import torch from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) -from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program +from executorch.backends.nxp.tests.executorch_pipeline import ( + to_edge_program, + to_quantized_edge_program, +) from executorch.backends.nxp.tests.executors import ( convert_run_compare, + graph_contains_any, graph_contains_any_of_ops, - ToNCHWPreprocess, - ToNHWCPreprocess, + ToChannelFirstPreprocess, + ToChannelLastPreprocess, ) from executorch.exir.dialects._ops import ops as exir_ops +from parameterized import parameterized from torch import nn from torch.export import ExportedProgram -@pytest.fixture(autouse=True) -def reseed_model_per_test_run(): - torch.manual_seed(23) - np.random.seed(23) - - class SingleConvBlockWithDropout(torch.nn.Module): def __init__( self, conv_in_channels: int = 3, perform_inplace_dropout: bool = False @@ -74,57 +76,124 @@ def forward(self, x): return self.block(x) -@pytest.mark.parametrize("inplace_dropout", [False, True]) -@pytest.mark.parametrize("input_shape", [(1, 3, 128, 128), (1, 3, 256, 256)]) -def test_conv_dropout_quant(mocker, inplace_dropout: bool, input_shape: tuple[int]): - model = SingleConvBlockWithDropout( - conv_in_channels=input_shape[1], perform_inplace_dropout=inplace_dropout - ).eval() - - converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") +class TestCloneConverter(unittest.TestCase): + __test__ = False # Prevent interfering with PyTest tests - quantized_program = to_quantized_edge_program(model, input_shape).exported_program() - - tflite_flatbuffers_model, io_formats = converter_spy.spy_return - exported_program: ExportedProgram = converter_spy.call_args.args[1] - - assert not graph_contains_any_of_ops( - graph=quantized_program.graph, ops=[exir_ops.edge.aten.clone.default] - ) - - input_data = (np.random.random(input_shape) * 50).astype(np.int8) - convert_run_compare( - exported_program, - tfl_model=tflite_flatbuffers_model, - tflite_input_preprocess=ToNHWCPreprocess(), - tflite_output_preprocess=ToNCHWPreprocess(), - input_data=input_data, - atol=1.0, - ) + @classmethod + def setUpClass(cls): + torch.manual_seed(23) + np.random.seed(23) + @staticmethod + def _node_is_clone(node) -> bool: + clone_ops = [ + exir_ops.edge.aten.clone.default, + 
exir_ops.edge.dim_order_ops._clone_dim_order.default, + ] -@pytest.mark.parametrize("inplace_dropout", [False, True]) -def test_clone_pool_view_copy_quant( - mocker, inplace_dropout: bool, input_shape: tuple[int] = (1, 64, 25, 5) -): - model = KWSFinalBlock(input_shape).eval() + def target_can_be_clone(node): + if hasattr(node, "op") and node.op == "call_function": + return "clone" in node.target.__name__ - converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + return False - quantized_program = to_quantized_edge_program(model, input_shape).exported_program() + return node in clone_ops or target_can_be_clone(node) - tflite_flatbuffers_model, io_formats = converter_spy.spy_return - exported_program: ExportedProgram = converter_spy.call_args.args[1] - - assert not graph_contains_any_of_ops( - graph=quantized_program.graph, ops=[exir_ops.edge.aten.clone.default] + @parameterized.expand( + list( + itertools.product( + [True, False], [(1, 3, 128, 128), (1, 3, 256, 256)], [True, False] + ) + ) ) - - input_data = (np.random.random(input_shape) * 50).astype(np.int8) - convert_run_compare( - exported_program, - tfl_model=tflite_flatbuffers_model, - tflite_input_preprocess=ToNHWCPreprocess(), - input_data=input_data, - atol=1.0, + def test_conv_dropout_quant( + self, inplace_dropout: bool, input_shape: tuple[int], use_qat: bool + ): + model = SingleConvBlockWithDropout( + conv_in_channels=input_shape[1], perform_inplace_dropout=inplace_dropout + ).eval() + + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, + call_original=True, + owner=EdgeProgramToIRConverter, + ) as converter_spy: + quantized_program = to_quantized_edge_program( + model, + input_shape, + use_qat=use_qat, + use_neutron_for_format_conversion=False, + ).exported_program() + + tflite_flatbuffers_model, _ = converter_spy.calls[-1].return_value + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + + assert not graph_contains_any( + graph=quantized_program.graph, + condition=TestCloneConverter._node_is_clone, + ) + + input_data = (np.random.random(input_shape) * 50).astype(np.int8) + convert_run_compare( + exported_program, + tfl_model=tflite_flatbuffers_model, + tflite_input_preprocess=ToChannelLastPreprocess(), + tflite_output_preprocess=ToChannelFirstPreprocess(), + input_data=input_data, + atol=1.0, + ) + + @parameterized.expand( + list(itertools.product([True, False], [(1, 3, 128, 128), (1, 3, 256, 256)])) ) + def test_conv_dropout_no_quant( + self, inplace_dropout: bool, input_shape: tuple[int] + ): + model = SingleConvBlockWithDropout( + conv_in_channels=input_shape[1], perform_inplace_dropout=inplace_dropout + ).eval() + + edge_program = to_edge_program(model, input_shape).exported_program() + + has_clone = graph_contains_any_of_ops( + graph=edge_program.graph, + ops=[ + exir_ops.edge.aten.clone.default, + exir_ops.edge.dim_order_ops._clone_dim_order.default, + ], + ) + + # Clone with inplace=True should not produce clone edge op and vice versa + assert inplace_dropout ^ has_clone + + @parameterized.expand([("QAT", True), ("PTQ", False)]) + def test_clone_pool_view_copy_quant( + self, _, use_qat: bool, input_shape: tuple[int] = (1, 64, 25, 5) + ): + model = KWSFinalBlock(input_shape).eval() + + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, + call_original=True, + owner=EdgeProgramToIRConverter, + ) as converter_spy: + quantized_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + tflite_flatbuffers_model, _ = 
converter_spy.calls[-1].return_value + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + + assert not graph_contains_any( + graph=quantized_program.graph, + condition=TestCloneConverter._node_is_clone, + ) + + input_data = (np.random.random(input_shape) * 50).astype(np.int8) + convert_run_compare( + exported_program, + tfl_model=tflite_flatbuffers_model, + tflite_input_preprocess=ToChannelLastPreprocess(), + input_data=input_data, + atol=1.0, + ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py index 47cd54c4efb..a2c9526a508 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py @@ -1,4 +1,4 @@ -# Copyright 2024 NXP +# Copyright 2024-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -7,12 +7,14 @@ import pytest import torch +from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.tests.executorch_pipeline import ( to_edge_program, to_quantized_edge_program, ) from executorch.backends.nxp.tests.executors import ( convert_run_compare, + graph_contains_any_of_ops, ToNCHWPreprocess, ToNHWCPreprocess, ) @@ -20,6 +22,8 @@ ConstantPadNDConvModule, ConstantPadNDModule, ) +from executorch.backends.nxp.tests.use_qat import * # noqa F403 +from executorch.exir.dialects._ops import ops as exir_ops @pytest.fixture(autouse=True) @@ -99,6 +103,9 @@ def test_constant_pad_nd_conversion__channels_first(input_shape, paddings): input_data, tflite_input_preprocess=ToNHWCPreprocess(), tflite_output_preprocess=ToNCHWPreprocess(), + conversion_config=ConversionConfig( + {"use_neutron_for_format_conversion": False} + ), ) @@ -114,10 +121,68 @@ def test_constant_pad_nd_conversion__channels_first(input_shape, paddings): pytest.param((1, 1, 6, 8), (1, 2, 3, 4, 2, 1), id="4D, padding C, H, W"), ], ) -def test_constant_pad_nd__unsupported_paddings(input_shape, paddings): +def test_constant_pad_nd__unsupported_paddings(input_shape, paddings, use_qat): model = ConstantPadNDModule(paddings) - exec_program = to_quantized_edge_program(model, input_shape).exported_program() + exec_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() nodes = list(exec_program.graph.nodes) # There is at least one non-delegated Pad node assert any(node.name == "aten_constant_pad_nd_default" for node in nodes) + + +def test_constant_pad_nd__delegation__formatless__supported_padding(use_qat): + input_shape = (2, 4, 6, 8) # Formatless -> the last dim (8) will be padded. + paddings = [0, 0, 1, 2, 3, 4] # The last dim is padded using the first 2 paddings. + model = ConstantPadNDModule(paddings) + exec_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + # Make sure the `pad` was delegated. + assert not graph_contains_any_of_ops( + exec_program.graph, [exir_ops.edge.aten.constant_pad_nd.default] + ) + + +def test_constant_pad_nd__delegation__formatless__unsupported_padding(use_qat): + input_shape = (2, 4, 6, 8) # Formatless -> the last dim (8) will be padded. + paddings = [0, 1] # The last dim is padded using the first 2 paddings. 
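+ # (`constant_pad_nd` applies padding values in (left, right) pairs starting from the last dimension, so only the last dim is padded here.)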
+ model = ConstantPadNDModule(paddings) + exec_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + # Make sure the `pad` was NOT delegated. + assert graph_contains_any_of_ops( + exec_program.graph, [exir_ops.edge.aten.constant_pad_nd.default] + ) + + +def test_constant_pad_nd__delegation__channels_first__supported_padding(use_qat): + input_shape = (2, 4, 6, 8) # Channels first -> the second dim (4) will be padded. + paddings = [1, 2, 3, 4, 0, 0] # The second dim is padded using the paddings[4:6]. + model = ConstantPadNDConvModule(paddings) + exec_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + # Make sure the `pad` was delegated. + assert not graph_contains_any_of_ops( + exec_program.graph, [exir_ops.edge.aten.constant_pad_nd.default] + ) + + +def test_constant_pad_nd__delegation__channels_first__unsupported_padding(use_qat): + input_shape = (2, 3, 6, 8) # Channels first -> the second dim (3) will be padded. + paddings = [0, 0, 0, 0, 1, 0] # The second dim is padded using the paddings[4:6]. + model = ConstantPadNDConvModule(paddings) + exec_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + # Make sure the `pad` was NOT delegated. + assert graph_contains_any_of_ops( + exec_program.graph, [exir_ops.edge.aten.constant_pad_nd.default] + ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py index 745b26ef8ff..56fdf1a2e0c 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py @@ -10,6 +10,7 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) +from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.backend.ir.converter.builder.model_builder import ( ModelBuilder, ) @@ -22,11 +23,14 @@ ) from executorch.backends.nxp.tests.executors import ( convert_run_compare, + graph_contains_any_of_ops, ToChannelFirstPreprocess, ToChannelLastPreprocess, ) from executorch.backends.nxp.tests.models import Conv1dModule, Conv2dModule +from executorch.exir.dialects._ops import ops as exir_ops from torch.export import ExportedProgram +from executorch.backends.nxp.tests.use_qat import * # noqa F403 @pytest.fixture(autouse=True) @@ -35,17 +39,20 @@ def reseed_model_per_test_run(): np.random.seed(23) +@pytest.mark.parametrize("bias", [False, True]) @pytest.mark.parametrize("stride", [1, 2]) @pytest.mark.parametrize("dilation", [2, 1]) @pytest.mark.parametrize("kernel_size", [(1,), (3,)]) -def test_conv1d_quant_conversion(stride, dilation, kernel_size, mocker): +def test_conv1d_quant_conversion(bias, stride, dilation, kernel_size, mocker, use_qat): input_shape = (1, 4, 16) - model = Conv1dModule(stride=stride, dilation=dilation, kernel_size=kernel_size) + model = Conv1dModule( + bias=bias, stride=stride, dilation=dilation, kernel_size=kernel_size + ) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") ops_spy = mocker.spy(ModelBuilder, "finish") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program(model, input_shape, use_qat=use_qat) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -76,10 +83,21 @@ def test_conv1d_quant_conversion(stride, dilation, 
kernel_size, mocker): @pytest.mark.parametrize("stride", [1, 2]) @pytest.mark.parametrize("dilation", [2, 1]) -@pytest.mark.parametrize("kernel_size", [(1,), (3,)]) +@pytest.mark.parametrize( + "kernel_size", + [ + pytest.param( + (1,), + marks=pytest.mark.xfail( + reason="Regression in Neutron SW 2.1.x (AIR-13336)", strict=True + ), + ), + (3,), + ], +) @pytest.mark.parametrize("padding", [(1,), 2]) def test_conv1d_quant_conversion__padded( - stride, dilation, kernel_size, padding, mocker + stride, dilation, kernel_size, padding, mocker, use_qat ): input_shape = (1, 4, 16) model = Conv1dModule( @@ -89,7 +107,7 @@ def test_conv1d_quant_conversion__padded( ops_spy = mocker.spy(ModelBuilder, "finish") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program(model, input_shape, use_qat=use_qat) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -131,13 +149,17 @@ def test_conv1d_quant_conversion__padded( ) # `Conv` input zp. +@pytest.mark.parametrize("bias", [False, True]) @pytest.mark.parametrize("stride", [1, 2]) @pytest.mark.parametrize("dilation", [2, 1]) @pytest.mark.parametrize("kernel_size", [(1,), (3,)]) -def test_conv1d_quant_conversion__depthwise(stride, dilation, kernel_size, mocker): +def test_conv1d_quant_conversion__depthwise( + bias, stride, dilation, kernel_size, mocker, use_qat +): input_shape = (1, 4, 16) group = input_shape[1] model = Conv1dModule( + bias=bias, group=group, in_channels=group, out_channels=group, @@ -149,7 +171,7 @@ def test_conv1d_quant_conversion__depthwise(stride, dilation, kernel_size, mocke ops_spy = mocker.spy(ModelBuilder, "finish") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program(model, input_shape, use_qat=use_qat) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -179,10 +201,21 @@ def test_conv1d_quant_conversion__depthwise(stride, dilation, kernel_size, mocke @pytest.mark.parametrize("stride", [1, 2]) @pytest.mark.parametrize("dilation", [2, 1]) -@pytest.mark.parametrize("kernel_size", [(1,), (3,)]) +@pytest.mark.parametrize( + "kernel_size", + [ + pytest.param( + (1,), + marks=pytest.mark.xfail( + reason="Regression in Neutron SW 2.1.x (AIR-13336)", strict=True + ), + ), + (3,), + ], +) @pytest.mark.parametrize("padding", [(1,), 2]) def test_conv1d_quant_conversion__depthwise__padded( - stride, dilation, kernel_size, padding, mocker + stride, dilation, kernel_size, padding, mocker, use_qat ): input_shape = (1, 4, 16) group = input_shape[1] @@ -199,7 +232,7 @@ def test_conv1d_quant_conversion__depthwise__padded( ops_spy = mocker.spy(ModelBuilder, "finish") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program(model, input_shape, use_qat=use_qat) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -347,13 +380,35 @@ def test_conv1d_quant_conversion__depthwise__padded( (1, 32, 32, 32), id="In ch 32, out ch 32, kernel 4, padding (0, 2), dilation (1, 2)", ), + pytest.param( + Conv2dModule( + in_channels=8, out_channels=32, kernel_size=5, padding=3, bias=False + ), + (1, 8, 32, 32), + id="In ch 8, out ch 32, kernel 5, padding 3, no bias", + ), + pytest.param( + Conv2dModule( + in_channels=32, + out_channels=32, + kernel_size=3, + padding=(1, 0), + dilation=(3, 1), + bias=False, + ), + (1, 32, 35, 35), + id="In ch 32, out ch 32, kernel 3, padding (1, 0), dilation (3, 1)," + "no bias", + 
), ], ) -def test_conv2d_quant_conversion(mocker, model: torch.nn.Module, input_shape): +def test_conv2d_quant_conversion(mocker, model: torch.nn.Module, input_shape, use_qat): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program( + model, input_shape, use_qat=use_qat, use_neutron_for_format_conversion=False + ) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -373,47 +428,12 @@ def test_conv2d_quant_conversion(mocker, model: torch.nn.Module, input_shape): ) -@pytest.mark.parametrize("stride", [1, 2]) -@pytest.mark.parametrize("dilation", [1, 2]) -@pytest.mark.parametrize("kernel_shape", [[1, 2], [3, 3], [4, 1]]) -def test_conv2d_conversion__depthwise(stride, dilation, kernel_shape, mocker): - input_shape = (1, 3, 12, 16) - group = input_shape[1] - edge_program = to_edge_program( - Conv2dModule( - group=group, - in_channels=group, - out_channels=group, - stride=stride, - dilation=dilation, - kernel_size=kernel_shape, - ), - input_shape, - ).exported_program() - - input_data = np.random.random(input_shape).astype(np.float32) - - spy = mocker.spy(ModelBuilder, "finish") - - convert_run_compare( - edge_program, - input_data, - tflite_input_preprocess=ToChannelLastPreprocess(), - tflite_output_preprocess=ToChannelFirstPreprocess(), - atol=4e-7, - ) - conversion_result = spy.spy_return - ops = conversion_result.sub_graphs[0].operators.vector - - assert len(ops) == 1 - assert ops[0].builtin_options.operator_type == BuiltinOperator.DEPTHWISE_CONV_2D - - +@pytest.mark.parametrize("bias", [False, True]) @pytest.mark.parametrize("stride", [1, 2]) @pytest.mark.parametrize("dilation", [1, 2]) @pytest.mark.parametrize("kernel_shape", [[1, 2], [3, 3], [4, 1]]) def test_conv2d_conversion__depthwise__quantized( - stride, dilation, kernel_shape, mocker + bias, stride, dilation, kernel_shape, mocker, use_qat ): input_shape = (1, 4, 12, 12) group = input_shape[1] @@ -421,6 +441,7 @@ def test_conv2d_conversion__depthwise__quantized( edge_program = to_quantized_edge_program( Conv2dModule( + bias=bias, group=group, in_channels=group, out_channels=group, @@ -429,6 +450,8 @@ def test_conv2d_conversion__depthwise__quantized( kernel_size=kernel_shape, ), tuple(input_shape), + use_qat=use_qat, + use_neutron_for_format_conversion=False, ).exported_program() ops = spy.spy_return.sub_graphs[0].operators.vector @@ -463,6 +486,9 @@ def test_conv2d_conversion__depthwise__padded(padding, mocker): tflite_input_preprocess=ToChannelLastPreprocess(), tflite_output_preprocess=ToChannelFirstPreprocess(), atol=4e-7, + conversion_config=ConversionConfig( + {"use_neutron_for_format_conversion": False} + ), ) conversion_result = spy.spy_return ops = conversion_result.sub_graphs[0].operators.vector @@ -473,7 +499,7 @@ def test_conv2d_conversion__depthwise__padded(padding, mocker): @pytest.mark.parametrize("padding", [1, 2]) -def test_conv2d_conversion__depthwise__padded__quantized(padding, mocker): +def test_conv2d_conversion__depthwise__padded__quantized(padding, mocker, use_qat): input_shape = (1, 4, 12, 12) group = input_shape[1] spy = mocker.spy(ModelBuilder, "finish") @@ -483,6 +509,8 @@ def test_conv2d_conversion__depthwise__padded__quantized(padding, mocker): group=group, in_channels=group, out_channels=group, padding=padding ), tuple(input_shape), + use_qat=use_qat, + use_neutron_for_format_conversion=False, ).exported_program() ops = 
spy.spy_return.sub_graphs[0].operators.vector @@ -495,3 +523,158 @@ def test_conv2d_conversion__depthwise__padded__quantized(padding, mocker): len(nodes) == 7 ) # input, Quant, lowered_module, delegate_call, getitem, Deq, output assert nodes[2].target == "lowered_module_0" + + +@pytest.mark.parametrize( + "model, input_shape", + [ + pytest.param( + torch.nn.ConvTranspose2d(8, 16, (1, 4), stride=(1, 2)), + (1, 8, 1, 16), + id="In ch 8, out ch 16, kernel (1, 4), stride (1, 2)", + ), + pytest.param( + torch.nn.ConvTranspose2d(64, 64, (1, 2), stride=(1, 2)), + (1, 64, 3, 12), + id="In ch 64, out ch 64, kernel (1, 2), stride (1, 2)", + ), + pytest.param( + torch.nn.ConvTranspose2d( + 16, 24, (1, 6), stride=(1, 6), output_padding=(0, 3) + ), + (1, 16, 7, 15), + id="In ch 16, out ch 24, kernel (1, 6), stride (1, 6), output_padding (0, 3)", + ), + pytest.param( + torch.nn.ConvTranspose2d(16, 40, (1, 4), stride=(1, 4), padding=(0, 1)), + (1, 16, 1, 27), + id="In ch 16, out ch 40, kernel (1, 4), stride (1, 4), padding (0, 1)", + ), + pytest.param( + torch.nn.ConvTranspose2d(8, 16, (1, 4), stride=(1, 2), padding=(0, 1)), + (1, 8, 1, 16), + id="In ch 8, out ch 16, kernel (1, 4), stride (1, 2), padding (0, 1)", + ), + pytest.param( + torch.nn.ConvTranspose2d( + 8, 16, (1, 8), stride=(1, 4), output_padding=(0, 2) + ), + (1, 8, 1, 16), + id="In ch 8, out ch 16, kernel (1, 8), stride (1, 4), output_padding (0, 2)", + ), + pytest.param( + torch.nn.ConvTranspose2d(16, 16, (1, 4), stride=(1, 2)), + (1, 16, 1, 16), + id="In ch 16, out ch 16, kernel (1, 4), stride (1, 2)", + ), + pytest.param( + torch.nn.ConvTranspose2d(8, 16, (1, 4), stride=(1, 2), bias=False), + (1, 8, 1, 16), + id="In ch 8, out ch 16, kernel (1, 4), stride (1, 2), no bias", + ), + pytest.param( + torch.nn.ConvTranspose2d( + 8, 16, (1, 4), stride=(1, 2), padding=(0, 1), bias=False + ), + (1, 8, 1, 16), + id="In ch 8, out ch 16, kernel (1, 4), stride (1, 2)," + "padding (0, 1), no bias", + ), + ], +) +def test_conv_transpose2d_conversion__quantized( + mocker, model: torch.nn.Module, input_shape, use_qat +): + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat, use_neutron_for_format_conversion=False + ).exported_program() + + # Make sure the `TransposeConv` was delegated. 
+ assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.convolution.default] + ) + assert any("lowered_module" in node.name for node in edge_program.graph.nodes) + + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.spy_return + + # Capture converted program + exported_program: ExportedProgram = converter_spy.call_args.args[1] + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) + + convert_run_compare( + exported_program, + tflite_input_preprocess=ToChannelLastPreprocess(), + tfl_model=tflite_flatbuffers_model, + tflite_output_preprocess=ToChannelFirstPreprocess(), + input_data=input_data, + atol=1.0, + ) + + +@pytest.mark.parametrize( + "model, input_shape", + [ + pytest.param( + torch.nn.ConvTranspose2d(8, 16, (1, 4), stride=(1, 2), dilation=(1, 2)), + (1, 8, 1, 16), + id="In ch 8, out ch 16, kernel (1, 4), stride (1, 2), " + "dilation (1, 2) - Dilation != (1, 1)", + ), + pytest.param( + torch.nn.ConvTranspose2d(6, 16, (1, 4), stride=(1, 2)), + (1, 6, 1, 16), + id="In ch 6, out ch 16, kernel (1, 4), stride (1, 2) - In channels % num_macs != 0", + ), + pytest.param( + torch.nn.ConvTranspose2d(8, 16, (1, 4), stride=(1, 2)), + (1, 8, 4, 16), + id="In ch 8, out ch 16, kernel (1, 4), stride (1, 2) - Out height != 1, stride width" + " != kernel width", + ), + pytest.param( + torch.nn.ConvTranspose2d(8, 16, (2, 4), stride=(1, 2), padding=(0, 1)), + (1, 8, 1, 16), + id="In ch 8, out ch 16, kernel (2, 4), stride (1, 2), padding " + "(0, 1) - Out height != 1, stride width != kernel width", + ), + pytest.param( + torch.nn.ConvTranspose2d(8, 16, (1, 5), stride=(1, 4)), + (1, 8, 1, 16), + id="In ch 8, out ch 16, kernel (1, 5), stride (1, 4) - Stride width != kernel width / 2" + ", stride width != kernel width", + ), + pytest.param( + torch.nn.ConvTranspose2d(16, 12, (1, 4), stride=(3, 3)), + (1, 16, 1, 16), + id="In ch 16, out ch 12, kernel (1, 4), stride (3, 3) - Out channels % num_macs != 0", + ), + pytest.param( + torch.nn.ConvTranspose2d(64, 64, (1, 4), stride=(1, 2)), + (1, 64, 3, 12), + id="In ch 64, out ch 64, kernel (1, 4), stride (1, 2) - Out height != 1, stride width" + " != kernel width", + ), + pytest.param( + torch.nn.ConvTranspose2d(16, 40, (1, 4), stride=(1, 4), padding=(0, 1)), + (1, 16, 4, 27), + id="In ch 16, out ch 40, kernel (1, 4), stride (1, 4), padding (0, 1) - Padding width " + "!= 1 and input height != 1", + ), + ], +) +def test_conv_transpose2d_non_delegated_conversion__quantized( + model: torch.nn.Module, input_shape, use_qat +): + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + nodes = list(edge_program.graph.nodes) + assert len(nodes) == 15 + assert ( + nodes[11].target.__name__ == "aten.convolution.default" + ) # TransposeConv not delegated. 
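Note: the `use_qat` argument threaded through the tests above is provided by the star import of `executorch.backends.nxp.tests.use_qat` that appears in the hunks below; that module itself is not part of this diff. A minimal sketch of what such a fixture presumably looks like — the parameter ids and the exact module contents are assumptions, not taken from the patch:

import pytest


@pytest.fixture(params=[False, True], ids=["PTQ", "QAT"])
def use_qat(request) -> bool:
    # Hypothetical fixture: each importing test runs twice, once with
    # post-training quantization and once with quantization-aware training,
    # forwarding the flag to to_quantized_edge_program(..., use_qat=...).
    return request.param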
diff --git a/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py index e17868d16e2..fb272a2c650 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py @@ -23,6 +23,7 @@ from executorch.backends.nxp.tests.models import Conv2dWithActivation from executorch.exir.dialects._ops import ops as exir_ops from torch.export import ExportedProgram +from executorch.backends.nxp.tests.use_qat import * # noqa F403 @pytest.fixture(autouse=True) @@ -33,7 +34,7 @@ def reseed_model_per_test_run(): @pytest.mark.parametrize("input_shape", [(1, 3, 128, 128)]) @pytest.mark.parametrize("inplace", [True, False]) -def test_relu6_quant(mocker, input_shape: tuple[int], inplace: bool): +def test_relu6_quant(mocker, input_shape: tuple[int], inplace: bool, use_qat: bool): # The torch.nn.Relu6 inherits from torch.nn.Hardtanh, and hence represented as HardTanh in ATen. # Testing the hardtanh originated from torch.nn.Relu6 op. model = Conv2dWithActivation( @@ -42,7 +43,9 @@ def test_relu6_quant(mocker, input_shape: tuple[int], inplace: bool): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - quantized_program = to_quantized_edge_program(model, input_shape).exported_program() + quantized_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat, use_neutron_for_format_conversion=False + ).exported_program() tflite_flatbuffers_model, io_formats = converter_spy.spy_return exported_program: ExportedProgram = converter_spy.call_args.args[1] @@ -57,7 +60,7 @@ def test_relu6_quant(mocker, input_shape: tuple[int], inplace: bool): tflite_input_preprocess=ToNHWCPreprocess(), tflite_output_preprocess=ToNCHWPreprocess(), input_data=input_data, - atol=1.0, + atol=2.0, ) @@ -67,7 +70,11 @@ def test_relu6_quant(mocker, input_shape: tuple[int], inplace: bool): ) @pytest.mark.parametrize("inplace", [True, False]) def test_custom_hardtanh_quant( - mocker, input_shape: tuple[int], activation_range: tuple[int, int], inplace: bool + mocker, + input_shape: tuple[int], + activation_range: tuple[int, int], + inplace: bool, + use_qat: bool, ): # TODO(13063): This test suffers from non-ideal testing random quantization, because we always use range <0,1>. # We should update (decrease atol) when the Conv/Linear + Activation fuse at quantization is in place. @@ -79,7 +86,9 @@ def test_custom_hardtanh_quant( converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - quantized_program = to_quantized_edge_program(model, input_shape).exported_program() + quantized_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat, use_neutron_for_format_conversion=False + ).exported_program() tflite_flatbuffers_model, io_formats = converter_spy.spy_return exported_program: ExportedProgram = converter_spy.call_args.args[1] diff --git a/backends/nxp/tests/ir/converter/node_converter/test_linear_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_linear_converter.py deleted file mode 100644 index 858724522cd..00000000000 --- a/backends/nxp/tests/ir/converter/node_converter/test_linear_converter.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2024 NXP -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import numpy as np -import pytest -import torch - -from executorch.backends.nxp.tests.executorch_pipeline import to_edge_program -from executorch.backends.nxp.tests.executors import convert_run_compare -from executorch.backends.nxp.tests.models import LinearModule -from executorch.exir.dialects._ops import ops as exir_ops - - -@pytest.fixture(autouse=True) -def reseed_model_per_test_run(): - torch.manual_seed(23) - np.random.seed(23) - - -def test_linear_conversion__with_bias(): - input_shape = (10, 32) - edge_program = to_edge_program( - LinearModule(bias=True), input_shape - ).exported_program() - - input_data = np.random.random(input_shape).astype(np.float32) - - nodes = list(edge_program.graph.nodes) - assert nodes[4].target == exir_ops.edge.aten.addmm.default - assert len(nodes[4].args) == 3 # Has bias. - - convert_run_compare(edge_program, input_data=input_data) - - -def test_linear_conversion__without_bias(): - input_shape = (10, 32) - edge_program = to_edge_program( - LinearModule(bias=False), input_shape - ).exported_program() - - input_data = np.random.random(input_shape).astype(np.float32) - - nodes = list(edge_program.graph.nodes) - assert nodes[3].target == exir_ops.edge.aten.mm.default - assert len(nodes[3].args) == 2 # No bias. - - convert_run_compare(edge_program, input_data=input_data) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py index 50bbf100980..569ad571dbc 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py @@ -6,10 +6,11 @@ import numpy as np import pytest import torch - from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) + +from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.neutron_pass_manager import NeutronPassManager from executorch.backends.nxp.tests.executorch_pipeline import ( to_edge_program, @@ -24,6 +25,7 @@ from executorch.backends.xnnpack._passes import RemoveGetItemPass from executorch.exir.verification.verifier import EXIREdgeDialectVerifier from torch.export import ExportedProgram +from executorch.backends.nxp.tests.use_qat import * # noqa F403 @pytest.fixture(autouse=True) @@ -76,6 +78,9 @@ def test_max_pool_2d_conversion(input_shape, padding): input_data, tflite_input_preprocess=ToNHWCPreprocess(), tflite_output_preprocess=ToNCHWPreprocess(), + conversion_config=ConversionConfig( + {"use_neutron_for_format_conversion": False} + ), ) @@ -99,11 +104,16 @@ def test_max_pool_2d_conversion(input_shape, padding): ), ], ) -def test_max_pool_2d_quant_conversion(mocker, input_shape, padding): +def test_max_pool_2d_quant_conversion(mocker, input_shape, padding, use_qat): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(MaxPool2dConvModule(padding=padding), input_shape) + _ = to_quantized_edge_program( + MaxPool2dConvModule(padding=padding), + input_shape, + use_qat=use_qat, + use_neutron_for_format_conversion=False, + ) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return diff --git a/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py index 0032eae5c1a..7c0a5e8ffcf 100644 --- 
a/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py @@ -1,3 +1,8 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import numpy as np import pytest import torch @@ -8,10 +13,13 @@ from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program from executorch.backends.nxp.tests.executors import ( convert_run_compare, + graph_contains_any_of_ops, ToChannelFirstPreprocess, ToChannelLastPreprocess, ) from executorch.backends.nxp.tests.models import MeanDimConvModule, MeanDimLinearModule +from executorch.backends.nxp.tests.use_qat import * # noqa F403 +from executorch.exir.dialects._ops import ops as exir_ops from torch.export import ExportedProgram @@ -21,19 +29,39 @@ def reseed_model_per_test_run(): np.random.seed(23) +class MeanDimModule(torch.nn.Module): + def __init__(self, dim, keepdim): + super().__init__() + self.dim = dim + self.keepdim = keepdim + + def forward(self, x): + return torch.mean(x, dim=self.dim, keepdim=self.keepdim) + + @pytest.mark.parametrize( "input_shape, dim", [ pytest.param((1, 4, 8, 8), (-1, -2), id="Dim -1, -2."), + pytest.param((1, 4, 8, 8), (-2, -1), id="Dim -2, -1."), + pytest.param((1, 4, 8, 8), (2, 3), id="Dim 2, 3."), + pytest.param((1, 4, 8, 8), (3, 2), id="Dim 3, 2."), ], ) -def test_mean_dim_conv_quant_conversion(mocker, input_shape, dim, keeepdim=True): - model = MeanDimConvModule(dim, keeepdim) +def test_mean_dim_conv_quant_conversion( + mocker, input_shape, dim, use_qat, keepdim=True +): + model = MeanDimConvModule(dim, keepdim) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + ep = to_quantized_edge_program( + model, input_shape, use_qat=use_qat, use_neutron_for_format_conversion=False + ).exported_program() + # Make sure the `mean.dim` was delegated. 
+ assert not graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim]) + assert any("lowered_module" in n.name for n in ep.graph.nodes) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -49,6 +77,7 @@ def test_mean_dim_conv_quant_conversion(mocker, input_shape, dim, keeepdim=True) input_data=input_data, tflite_output_preprocess=ToChannelFirstPreprocess(), tfl_model=tflite_flatbuffers_model, + atol=1.0, ) @@ -60,21 +89,23 @@ def test_mean_dim_conv_quant_conversion(mocker, input_shape, dim, keeepdim=True) ], ) @pytest.mark.parametrize( - "keeepdim", + "keepdim", [ pytest.param(False, id="Don't keep dim."), pytest.param(True, id="Keep dim."), ], ) def test_mean_dim_linear_unsupported_quant_conversion( - mocker, input_shape, dim, keeepdim + mocker, input_shape, dim, use_qat, keepdim ): - model = MeanDimLinearModule(dim, keeepdim) + model = MeanDimLinearModule(dim, keepdim) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - edge_program = to_quantized_edge_program(model, input_shape).exported_program() + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() nodes = list(edge_program.graph.nodes) # Last 2 dimensions are not used or keepdim is False, cannot be converted to MeanDim, node is not delegated @@ -106,19 +137,23 @@ def test_mean_dim_linear_unsupported_quant_conversion( ], ) @pytest.mark.parametrize( - "keeepdim", + "keepdim", [ pytest.param(False, id="Don't keep dim."), pytest.param(True, id="Keep dim."), ], ) -def test_mean_dim_conv_unsupported_quant_conversion(mocker, input_shape, dim, keeepdim): - model = MeanDimConvModule(dim, keeepdim) +def test_mean_dim_conv_unsupported_quant_conversion( + mocker, input_shape, dim, use_qat, keepdim +): + model = MeanDimConvModule(dim, keepdim) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - edge_program = to_quantized_edge_program(model, input_shape).exported_program() + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat, use_neutron_for_format_conversion=False + ).exported_program() nodes = list(edge_program.graph.nodes) # Last 2 dimensions are not used or keepdim is False, cannot be converted to MeanDim, node is not delegated @@ -139,3 +174,107 @@ def test_mean_dim_conv_unsupported_quant_conversion(mocker, input_shape, dim, ke tflite_output_preprocess=ToChannelFirstPreprocess(), tfl_model=tflite_flatbuffers_model, ) + + +@pytest.mark.parametrize( + "input_shape, dim", + [ + pytest.param((1, 2, 3, 8), (1, 2), id="Dim 1, 2."), + pytest.param((1, 2, 3, 8), (2, 1), id="Dim 2, 1."), + pytest.param((1, 2, 3, 8), (-3, -2), id="Dim -3, -2."), + pytest.param((1, 2, 3, 8), (-2, -3), id="Dim -2, -3."), + ], +) +def test_mean_dim__formatless__supported( + mocker, input_shape, dim, use_qat, keepdim=True +): + model = MeanDimModule(dim, keepdim) + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + ep = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + # Make sure the `mean.dim` was delegated. 
+ assert not graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim]) + assert any("lowered_module" in n.name for n in ep.graph.nodes) + + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.spy_return + + # Capture converted program + exported_program: ExportedProgram = converter_spy.call_args.args[1] + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) + + convert_run_compare( + exported_program, + input_data=input_data, + tfl_model=tflite_flatbuffers_model, + atol=1, + ) + + +@pytest.mark.parametrize( + "input_shape, dim", + [ + pytest.param((1, 2, 3, 8), (2, 3), id="Dim 2, 3."), + ], +) +def test_mean_dim__formatless__unsupported(input_shape, dim, use_qat, keepdim=True): + model = MeanDimModule(dim, keepdim) + + ep = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + # Make sure the `mean.dim` was NOT delegated. + assert graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim]) + assert not any("lowered_module" in n.name for n in ep.graph.nodes) + + +@pytest.mark.parametrize( + "input_shape, dim", + [ + pytest.param( + (1, 8, 8, 4), (1, 2), id="Dim 1, 2 (supported), channels = 4 (unsupported)." + ), + ], +) +def test_mean_dim__formatless__unsupported_channels( + input_shape, dim, use_qat, keepdim=True +): + model = MeanDimModule(dim, keepdim) + + ep = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + # Make sure the `mean.dim` was NOT delegated. + assert graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim]) + assert not any("lowered_module" in n.name for n in ep.graph.nodes) + + +@pytest.mark.parametrize( + "input_shape, dim", + [ + pytest.param( + (1, 4, 8, 8), (2, 3), id="Dim 2, 3 (supported), channels = 5 (unsupported)." + ), + ], +) +def test_mean_dim__channels_first__unsupported_channels( + input_shape, dim, use_qat, keepdim=True +): + model = MeanDimConvModule( + dim, keepdim, out_channels=5 + ) # Only multiples of 8 (num_macs) are supported. + + # Run conversion + ep = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + # Make sure the `mean.dim` was NOT delegated. + assert graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim]) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_mm_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_mm_converter.py new file mode 100644 index 00000000000..962a4f4b0c1 --- /dev/null +++ b/backends/nxp/tests/ir/converter/node_converter/test_mm_converter.py @@ -0,0 +1,96 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import kgb +import numpy as np +import torch + +from executorch.backends.nxp.backend.edge_program_converter import ( + EdgeProgramToIRConverter, +) +from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program +from executorch.backends.nxp.tests.executors import ( + convert_run_compare, + graph_contains_any_of_ops, +) +from executorch.backends.nxp.tests.models import LinearModule, MmModule +from executorch.exir.dialects._ops import ops as exir_ops +from parameterized import parameterized +from torch.export import ExportedProgram + + +class TestMmConversion(unittest.TestCase): + @classmethod + def setUpClass(cls): + torch.manual_seed(23) + np.random.seed(42) + + @parameterized.expand([("QAT", True), ("PTQ", False)]) + def test_mm_conversion(self, _, use_qat: bool): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, + call_original=True, + owner=EdgeProgramToIRConverter, + ) as converter_spy: + input_shape = (1, 32) + model = MmModule(input_shape[1]) + + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + # Make sure that all nodes were delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.mm.default] + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + ) + + @parameterized.expand([("QAT", True), ("PTQ", False)]) + def test_linear_conversion__without_bias(self, _, use_qat: bool): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, + call_original=True, + owner=EdgeProgramToIRConverter, + ) as converter_spy: + input_shape = (10, 32) + model = LinearModule(bias=False) + + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + # Make sure that all nodes were delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.mm.default] + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_mul_tensor_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_mul_tensor_converter.py new file mode 100644 index 00000000000..053cd96944d --- /dev/null +++ b/backends/nxp/tests/ir/converter/node_converter/test_mul_tensor_converter.py @@ -0,0 +1,212 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import numpy as np +import pytest +import torch + +from executorch.backends.nxp.backend.edge_program_converter import ( + EdgeProgramToIRConverter, +) +from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program +from executorch.backends.nxp.tests.executors import ( + convert_run_compare, + ToChannelFirstPreprocess, + ToChannelLastPreprocess, +) +from executorch.backends.nxp.tests.models import ( + MulTensorConvModule, + MulTensorModule, + MulTensorOneInputModule, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch.export import ExportedProgram + + +@pytest.fixture(autouse=True) +def reseed_model_per_test_run(): + torch.manual_seed(23) + np.random.seed(23) + + +@pytest.mark.parametrize( + "x_input_shape", + [ + pytest.param((1,), id="1D."), + pytest.param((6, 8), id="2D."), + pytest.param((1, 4, 8), id="3D."), + pytest.param((1, 4, 8, 8), id="4D."), + ], +) +def test_mul_tensor_quant_conversion(mocker, x_input_shape): + model = MulTensorModule() + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + # Run conversion + edge_program = to_quantized_edge_program( + model, [x_input_shape, x_input_shape] + ).exported_program() + edge_nodes = list(edge_program.graph.nodes) + + # Check "Mul" was delegated + assert not any("mul" in n.name for n in edge_nodes) + + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.spy_return + + # Capture converted program + exported_program: ExportedProgram = converter_spy.call_args.args[1] + + input_data_1 = (np.random.random(x_input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + input_data_2 = (np.random.random(x_input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + input_data = {0: input_data_1, 1: input_data_2} + + exported_nodes = list(exported_program.graph.nodes) + assert exported_nodes[4].target == exir_ops.edge.aten.mul.Tensor + + convert_run_compare( + exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data + ) + + +@pytest.mark.parametrize( + "x_input_shape", + [ + pytest.param((11,), id="1D."), + pytest.param((4, 4), id="2D."), + pytest.param((1, 4, 7), id="3D."), + pytest.param((1, 4, 4, 20), id="4D."), + ], +) +def test_mul_tensor_shape_unsupported_quant_conversion(x_input_shape): + model = MulTensorOneInputModule() + + # Run conversion + edge_program = to_quantized_edge_program(model, x_input_shape).exported_program() + nodes = list(edge_program.graph.nodes) + + # Input tensor shape is not supported, node is not converted + assert ( + nodes[3].target == exir_ops.edge.aten.mul.Tensor + ) # Mul Tensor is not delegated. 
+ + +@pytest.mark.parametrize( + "input_shape", + [ + pytest.param((16,), id="1D."), + pytest.param((6, 8), id="2D."), + pytest.param((1, 4, 8), id="3D."), + pytest.param((1, 4, 8, 8), id="4D."), + ], +) +def test_mul_tensor_one_input_quant_conversion(mocker, input_shape): + model = MulTensorOneInputModule() + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + # Run conversion + edge_program = to_quantized_edge_program(model, input_shape).exported_program() + edge_nodes = list(edge_program.graph.nodes) + + # Check "Mul" was delegated + assert not any("mul" in n.name for n in edge_nodes) + + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.spy_return + + # Capture converted program + exported_program: ExportedProgram = converter_spy.call_args.args[1] + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) + + exported_nodes = list(exported_program.graph.nodes) + assert exported_nodes[2].target == exir_ops.edge.aten.mul.Tensor + + convert_run_compare( + exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data + ) + + +@pytest.mark.parametrize( + "x_input_shape", + [ + pytest.param((1, 4, 16, 16), id="4D."), + pytest.param((1, 4, 5, 5), id="4D, product of dims is not a multiple of 8."), + ], +) +def test_mul_tensor_w_conv_quant_conversion(mocker, x_input_shape): + model = MulTensorConvModule() + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + n, c, h, w = x_input_shape + y_input_shape = (n, 8, h, w) + + # Run conversion + edge_program = to_quantized_edge_program( + model, [x_input_shape, y_input_shape], use_neutron_for_format_conversion=False + ).exported_program() + edge_nodes = list(edge_program.graph.nodes) + + # Check "Mul" was delegated + assert not any("mul" in n.name for n in edge_nodes) + + # Check "Convolution" was delegated + assert not any("convolution" in n.name for n in edge_nodes) + + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.spy_return + + # Capture converted program + exported_program: ExportedProgram = converter_spy.call_args.args[1] + + input_data_1 = (np.random.random(x_input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + input_data_2 = (np.random.random(y_input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + input_data = {0: input_data_1, 1: input_data_2} + + exported_nodes = list(exported_program.graph.nodes) + assert exported_nodes[12].target == exir_ops.edge.aten.convolution.default + assert exported_nodes[15].target == exir_ops.edge.aten.mul.Tensor + + convert_run_compare( + exported_program, + input_data=input_data, + tfl_model=tflite_flatbuffers_model, + tflite_input_preprocess=ToChannelLastPreprocess(), + tflite_output_preprocess=ToChannelFirstPreprocess(), + ) + + +@pytest.mark.parametrize( + "x_input_shape, y_input_shape", + [ + pytest.param((4, 4, 8), (1, 4, 4, 8), id="3D -> 4D."), + pytest.param((1, 6), (6,), id="2D -> 1D."), + ], +) +def test_mul_tensor_broadcasting_unsupported_quant_conversion( + x_input_shape, y_input_shape +): + model = MulTensorModule() + + # Run conversion + edge_program = to_quantized_edge_program( + model, [x_input_shape, y_input_shape] + ).exported_program() + nodes = list(edge_program.graph.nodes) + + # Broadcast is not supported, node is not converted + assert ( + nodes[6].target == exir_ops.edge.aten.mul.Tensor + ) # Mul Tensor is not delegated. 
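The `MulTensorModule` and `MulTensorOneInputModule` helpers exercised above are imported from `executorch.backends.nxp.tests.models`, which is outside this diff. A hypothetical sketch consistent with how the tests call them (two same-shaped graph inputs for the former, a single input for the latter); the exact forward formulas are assumptions:

import torch


class MulTensorModule(torch.nn.Module):
    def forward(self, x, y):
        # Two graph inputs of identical shape -> a single aten.mul.Tensor node.
        return x * y


class MulTensorOneInputModule(torch.nn.Module):
    def forward(self, x):
        # One graph input multiplied with itself -> still a single aten.mul.Tensor node.
        return x * x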
diff --git a/backends/nxp/tests/ir/converter/node_converter/test_permute_copy_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_permute_copy_converter.py index d25e2759cc8..d32de7241e5 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_permute_copy_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_permute_copy_converter.py @@ -3,8 +3,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import unittest + +import kgb import numpy as np -import pytest import torch from executorch.backends.nxp.backend.edge_program_converter import ( @@ -13,52 +15,412 @@ from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program from executorch.backends.nxp.tests.executors import ( convert_run_compare, - ToNCHWPreprocess, - ToNHWCPreprocess, + graph_contains_any_of_ops, ) from executorch.backends.nxp.tests.models import Conv2dModule +from executorch.exir.dialects._ops import ops as exir_ops +from parameterized import parameterized from torch.export import ExportedProgram +from executorch.backends.nxp.tests.use_qat import * # noqa F403 + + +class Conv2dTransposeModule(torch.nn.Module): + def __init__(self, in_channels: int, dim0: int, dim1: int): + super().__init__() + self.dim0 = dim0 + self.dim1 = dim1 + self.conv = Conv2dModule( + in_channels=in_channels, out_channels=in_channels, kernel_size=(1, 1) + ) + + def forward(self, x): + x = self.conv(x) + return torch.transpose(x, self.dim0, self.dim1) + + +class Conv2dPermuteModule(torch.nn.Module): + def __init__(self, in_channels: int, perm: tuple[int, ...]): + super().__init__() + self.perm = perm + self.conv = Conv2dModule( + in_channels=in_channels, + out_channels=in_channels, + stride=1, + kernel_size=3, + padding=1, + ) + + def forward(self, x): + x = self.conv(x) + return torch.permute(x, self.perm) -@pytest.fixture(autouse=True) -def reseed_model_per_test_run(): - torch.manual_seed(23) - np.random.seed(23) +class PermuteConv2dModule(torch.nn.Module): + def __init__(self, in_channels: int, perm: tuple[int, ...]): + super().__init__() + self.perm = perm + self.conv = Conv2dModule( + in_channels=in_channels, + out_channels=in_channels, + stride=1, + kernel_size=3, + padding=1, + ) + + def forward(self, x): + x = torch.permute(x, self.perm) + return self.conv(x) -class Conv2dPermuteCopyModule(torch.nn.Module): - def __init__(self, new_dims: tuple[int, ...]): +class PermuteConv2dPermuteModule(torch.nn.Module): + def __init__( + self, in_channels: int, perm1: tuple[int, ...], perm2: tuple[int, ...] 
+ ): super().__init__() - self.new_dims = new_dims - self.conv = Conv2dModule() + self.perm1 = perm1 + self.perm2 = perm2 + self.conv = Conv2dModule( + in_channels=in_channels, + out_channels=in_channels, + stride=1, + kernel_size=3, + padding=1, + ) def forward(self, x): + x = torch.permute(x, self.perm1) x = self.conv(x) - return torch.permute(x, self.new_dims) + x = torch.permute(x, self.perm2) + return x -def test_permute_copy_quant_conversion__with_bias(mocker): - input_shape = (1, 4, 8, 8) - new_dims = (0, 2, 3, 1) +class LinearPermuteModule(torch.nn.Module): + def __init__(self, in_features: int, perm: tuple[int, ...]): + super().__init__() + self.perm = perm + self.fc = torch.nn.Linear(in_features, in_features) + + def forward(self, x): + x = self.fc(x) + return torch.permute(x, self.perm) - converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - # Run conversion - _ = to_quantized_edge_program(Conv2dPermuteCopyModule(new_dims), input_shape) +class TestPermuteCopyConversion(unittest.TestCase): + @classmethod + def setUpClass(cls): + torch.manual_seed(23) + np.random.seed(42) - # Capture generated model - tflite_flatbuffers_model, io_formats = converter_spy.spy_return + @parameterized.expand( + [ + ["QAT; To channel first permutation", (1, 16, 8, 8), (0, 3, 1, 2), True], + ["PTQ; To channel first permutation", (1, 16, 8, 8), (0, 3, 1, 2), False], + ["QAT; To channel last permutation", (1, 16, 8, 8), (0, 2, 3, 1), True], + ["PTQ; To channel last permutation", (1, 16, 8, 8), (0, 2, 3, 1), False], + ] + ) + def test_permute_copy_conversion__from_permute_4D__quantized__channels_first_input( + self, _: str, input_shape, perm, use_qat + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, call_original=True + ) as converter_spy: + model = Conv2dPermuteModule(input_shape[1], perm) - # Capture converted program - edge_program: ExportedProgram = converter_spy.call_args.args[1] + # Run conversion + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() - input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) + # Make sure the `Permute_copy` was delegated. 
+ assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.permute_copy.default] + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) - convert_run_compare( - edge_program, - input_data, - tfl_model=tflite_flatbuffers_model, - atol=1.0, - tflite_input_preprocess=ToNHWCPreprocess(), - tflite_output_preprocess=ToNCHWPreprocess(), + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + + # Capture converted program + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + atol=1.0, + ) + + @parameterized.expand( + [ + ["QAT; To channel first permutation", (1, 8, 8, 8), (0, 3, 1, 2), True], + ["PTQ; To channel first permutation", (1, 8, 8, 8), (0, 3, 1, 2), False], + ["QAT; To channel last permutation", (1, 8, 8, 8), (0, 2, 3, 1), True], + ["PTQ; To channel last permutation", (1, 8, 8, 8), (0, 2, 3, 1), False], + ] ) + def test_permute_copy_conversion__from_permute_4D__quantized__channels_first_output( + self, _: str, input_shape, perm, use_qat + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, call_original=True + ) as converter_spy: + model = PermuteConv2dModule(input_shape[1], perm) + + # Run conversion + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + # Make sure the `Permute_copy` was delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.permute_copy.default] + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + + # Capture converted program + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + atol=1.0, + ) + + @parameterized.expand( + [ + [ + "QAT; nchw->nhwc ... nchw->nhwc", + (1, 8, 8, 8), + (0, 2, 3, 1), + (0, 2, 3, 1), + True, + ], + [ + "PTQ; nchw->nhwc ... nchw->nhwc", + (1, 8, 8, 8), + (0, 2, 3, 1), + (0, 2, 3, 1), + False, + ], + [ + "QAT; nchw->nhwc ... nhwc->nchw", + (1, 8, 8, 8), + (0, 2, 3, 1), + (0, 3, 1, 2), + True, + ], + [ + "PTQ; nchw->nhwc ... nhwc->nchw", + (1, 8, 8, 8), + (0, 2, 3, 1), + (0, 3, 1, 2), + False, + ], + [ + "QAT; nhwc->nchw ... nhwc->nchw", + (1, 8, 8, 8), + (0, 3, 1, 2), + (0, 3, 1, 2), + True, + ], + [ + "PTQ; nhwc->nchw ... nhwc->nchw", + (1, 8, 8, 8), + (0, 3, 1, 2), + (0, 3, 1, 2), + False, + ], + [ + "QAT; nhwc->nchw ... nchw->nhwc", + (1, 8, 8, 8), + (0, 3, 1, 2), + (0, 2, 3, 1), + True, + ], + [ + "PTQ; nhwc->nchw ... 
nchw->nhwc", + (1, 8, 8, 8), + (0, 3, 1, 2), + (0, 2, 3, 1), + False, + ], + ] + ) + def test_permute_copy_conversion__from_permute_4D__quantized__channels_first_io( + self, _: str, input_shape, perm1, perm2, use_qat + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, call_original=True + ) as converter_spy: + model = PermuteConv2dPermuteModule(input_shape[1], perm1, perm2) + + # Run conversion + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + # Make sure the `Permute_copy` was delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.permute_copy.default] + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + + # Capture converted program + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + atol=1.0, + ) + + @parameterized.expand( + [ + [ + "QAT; Permutation can be replaced by reshapes", + (10, 1, 8), + (0, 2, 1), + True, + ], + [ + "PTQ; Permutation can be replaced by reshapes", + (10, 1, 8), + (0, 2, 1), + False, + ], + [ + "QAT; Permutation can be replaced by reshapes", + (10, 1, 1), + (2, 1, 0), + True, + ], + [ + "PTQ; Permutation can be replaced by reshapes", + (10, 1, 1), + (2, 1, 0), + False, + ], + [ + "QAT; Permutation is identical and can be removed", + (10, 1, 8), + (0, 1, 2), + True, + ], + [ + "PTQ; Permutation is identical and can be removed", + (10, 1, 8), + (0, 1, 2), + False, + ], + ] + ) + def test_permute_copy_conversion__from_permute_3D__quantized( + self, _: str, input_shape, perm, use_qat + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, call_original=True + ) as converter_spy: + # Run conversion + edge_program = to_quantized_edge_program( + LinearPermuteModule(input_shape[2], perm), input_shape, use_qat=use_qat + ).exported_program() + + # Make sure the `Permute_copy` was delegated. 
+ assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.permute_copy.default] + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + + # Capture converted program + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + atol=1.0, + ) + + @parameterized.expand( + [ + ["QAT; Transpose dims 1 and 2", (1, 16, 8, 8), (0, 2, 1, 3), True], + ["PTQ; Transpose dims 1 and 2", (1, 16, 8, 8), (0, 2, 1, 3), False], + ["QAT; To (2, 0, 1, 3) permutation", (1, 16, 8, 8), (2, 0, 1, 3), True], + ["PTQ; To (2, 0, 1, 3) permutation", (1, 16, 8, 8), (2, 0, 1, 3), False], + ["QAT; To (3, 1, 2, 0) permutation", (1, 16, 8, 8), (3, 1, 2, 0), True], + ["PTQ; To (3, 1, 2, 0) permutation", (1, 16, 8, 8), (3, 1, 2, 0), False], + ["QAT; To (3, 1, 0, 2) permutation", (1, 16, 8, 8), (3, 1, 0, 2), True], + ["PTQ; To (3, 1, 0, 2) permutation", (1, 16, 8, 8), (3, 1, 0, 2), False], + ] + ) + def test_permute_copy_non_delegated_conversion__from_permute_4D__quantized( + self, _: str, input_shape, perm, use_qat + ): + model = Conv2dPermuteModule(input_shape[1], perm) + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + nodes = list(edge_program.graph.nodes) + assert len(nodes) == 8 + assert ( + nodes[5].target == exir_ops.edge.aten.permute_copy.default + ) # PermuteCopy not delegated. + + @parameterized.expand( + [ + ["QAT; Transpose dims 1 and 2", (1, 16, 8, 8), 1, 2, True], + ["PTQ; Transpose dims 1 and 2", (1, 16, 8, 8), 1, 2, False], + ["QAT; Transpose dims 2 and 3", (1, 16, 8, 8), 2, 3, True], + ["PTQ; Transpose dims 2 and 3", (1, 16, 8, 8), 2, 3, False], + ] + ) + def test_permute_copy_non_delegated_conversion__from_transpose_4D__quantized( + self, _: str, input_shape, dim0, dim1, use_qat + ): + model = Conv2dTransposeModule(input_shape[1], dim0, dim1) + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + nodes = list(edge_program.graph.nodes) + assert len(nodes) == 8 + assert ( + nodes[5].target == exir_ops.edge.aten.permute_copy.default + ) # PermuteCopy not delegated. 
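The unittest-style suites in this patch (test_mm_converter.py and the permute-copy tests above) repeat the same `kgb.spy_on` capture pattern. As an illustrative condensation of that pattern — not code from the patch, and assuming a model and input shape are already in hand:

import kgb

from executorch.backends.nxp.backend.edge_program_converter import (
    EdgeProgramToIRConverter,
)
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program


def convert_and_capture(model, input_shape, use_qat):
    # call_original=True keeps the real conversion running while every call is recorded.
    with kgb.spy_on(
        EdgeProgramToIRConverter.convert_program,
        call_original=True,
        owner=EdgeProgramToIRConverter,
    ) as converter_spy:
        edge_program = to_quantized_edge_program(
            model, input_shape, use_qat=use_qat
        ).exported_program()

    # The last recorded call yields the generated flatbuffer and the converted
    # ExportedProgram, mirroring the converter_spy.calls[-1] usage in the tests above.
    tflite_model, io_formats = converter_spy.calls[-1].return_value
    exported_program = converter_spy.calls[-1].args[0]
    return edge_program, tflite_model, exported_program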
diff --git a/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py index 8d903e3e0b5..b91720324f2 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py @@ -21,6 +21,7 @@ ) from executorch.backends.nxp.tests.models import Conv2dModule, LinearModule, ReLUModule from torch.export import ExportedProgram +from executorch.backends.nxp.tests.use_qat import * # noqa F403 @pytest.fixture(autouse=True) @@ -62,12 +63,17 @@ def test_relu_conversion(): convert_run_compare(edge_program, input_data=input_data) -def test_relu_with_conv_quant_conversion(mocker): +def test_relu_with_conv_quant_conversion(mocker, use_qat): input_shape = (1, 4, 32, 32) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(ConvReLUModule(), input_shape) + _ = to_quantized_edge_program( + ConvReLUModule(), + input_shape, + use_qat=use_qat, + use_neutron_for_format_conversion=False, + ) # Capture generated model tflite_flatbuffers_model, _ = converter_spy.spy_return @@ -88,12 +94,12 @@ def test_relu_with_conv_quant_conversion(mocker): ) -def test_relu_with_linear_quant_conversion(mocker): +def test_relu_with_linear_quant_conversion(mocker, use_qat): input_shape = (256, 32) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(LinearReLUModule(), input_shape) + _ = to_quantized_edge_program(LinearReLUModule(), input_shape, use_qat=use_qat) # Capture generated model tflite_flatbuffers_model, _ = converter_spy.spy_return diff --git a/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py index c5d7d4d6a38..ad03aa18ded 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py @@ -20,6 +20,7 @@ from executorch.backends.nxp.tests.models import ConvWithSigmoid from torch import nn from torch.export import ExportedProgram +from executorch.backends.nxp.tests.use_qat import * # noqa F403 @pytest.fixture(autouse=True) @@ -28,12 +29,14 @@ def reseed_model_per_test_run(): np.random.seed(23) -def test_conv_sigmoid(mocker, input_shape: tuple[int] = (1, 3, 112, 112)): +def test_conv_sigmoid(mocker, use_qat, input_shape: tuple[int] = (1, 3, 112, 112)): model = ConvWithSigmoid(conv_in_channels=input_shape[1]) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - to_quantized_edge_program(model, input_shape).exported_program() + to_quantized_edge_program( + model, input_shape, use_qat=use_qat, use_neutron_for_format_conversion=False + ).exported_program() tflite_flatbuffers_model, io_formats = converter_spy.spy_return exported_program: ExportedProgram = converter_spy.call_args.args[1] @@ -59,12 +62,12 @@ def test_conv_sigmoid(mocker, input_shape: tuple[int] = (1, 3, 112, 112)): pytest.param((10, 3, 25, 25, 25), id="4D"), ], ) -def test_sigmoid_only(mocker, input_shape): +def test_sigmoid_only(mocker, use_qat, input_shape): model = nn.Sigmoid() converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - to_quantized_edge_program(model, input_shape).exported_program() + to_quantized_edge_program(model, input_shape, use_qat=use_qat).exported_program() tflite_flatbuffers_model, io_formats = 
converter_spy.spy_return exported_program: ExportedProgram = converter_spy.call_args.args[1] diff --git a/backends/nxp/tests/ir/converter/node_converter/test_slice_tensor_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_slice_tensor_converter.py new file mode 100644 index 00000000000..e5516e8a254 --- /dev/null +++ b/backends/nxp/tests/ir/converter/node_converter/test_slice_tensor_converter.py @@ -0,0 +1,244 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import numpy as np +import pytest +import torch +from executorch.backends.nxp.backend.edge_program_converter import ( + EdgeProgramToIRConverter, +) +from executorch.backends.nxp.tests.executorch_pipeline import ( + neutron_converter_flavor, + to_quantized_edge_program, +) +from executorch.backends.nxp.tests.executors import ( + convert_run_compare, + ToChannelFirstPreprocess, + ToChannelLastPreprocess, +) + +from executorch.backends.nxp.tests.models import ( + SliceTensorConvModule, + SliceTensorModule, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch.export import ExportedProgram + + +@pytest.fixture(autouse=True) +def reseed_model_per_test_run(): + torch.manual_seed(23) + np.random.seed(23) + + +@pytest.mark.parametrize( + "x_input_shape, dims, starts, ends", + [ + pytest.param((24, 32), (0, 1), (0, 16), (24, 32), id="2D, no transpose"), + pytest.param( + (24, 32, 64), (0, 1, 2), (0, 0, 8), (24, 32, 64), id="3D, no transpose" + ), + pytest.param( + (24, 32, 64, 48), + (0, 1, 2, 3), + (0, 0, 0, 8), + (24, 32, 64, 48), + id="4D, no transpose", + ), + pytest.param((24, 32), (0, 1), (8, 0), (24, 32), id="2D, one transpose"), + pytest.param( + (24, 32, 64), (0, 1, 2), (0, 8, 0), (24, 32, 64), id="3D, one transpose" + ), + pytest.param( + (24, 32, 64, 48), + (0, 1, 2, 3), + (0, 0, 8, 0), + (24, 32, 64, 48), + id="4D, one transpose", + ), + pytest.param( + (24, 32, 64), (0, 1, 2), (8, 8, 0), (24, 32, 64), id="3D, two transposes" + ), + # bug in neutron-converter will not properly convert models in these test cases + # pytest.param((24, 32, 64, 48), (0, 1, 2, 3), (16, 0, 8, 0), (24, 32, 64, 48), id="4D, two transposes"), + # pytest.param((24, 32, 64, 48), (0, 1, 2, 3), (16, 0, 8, 0), (24, 24, 56, 48), id="4D, three transposes"), + pytest.param( + (24, 32), + (0, 1), + (0, 13), + (24, 32), + id="2D, start arg not divisible by num_macs", + ), + pytest.param( + (24, 32), + (0, 1), + (0, 0), + (24, 31), + id="2D, end arg not divisible by num_macs", + ), + pytest.param((24, 32), (1, 0), (16, 0), (32, 24), id="2D, mixed dim args"), + pytest.param((24, 32), (0, -1), (0, 16), (24, 32), id="2D, negative dim arg"), + ], +) +def test_slice_tensor_quant_conversion(mocker, x_input_shape, dims, starts, ends): + model = SliceTensorModule( + dims=dims, + starts=starts, + ends=ends, + ) + + if neutron_converter_flavor == "SDK_25_09": + pytest.skip("Neutron Software must be version 2.2.1 or higher.") + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + # Run conversion + edge_program = to_quantized_edge_program(model, x_input_shape).exported_program() + edge_nodes = list(edge_program.graph.nodes) + + # Check if slices were delegated + assert not any("slice" in n.name for n in edge_nodes) + + # Capture generated model + tflite_flatbuffers_model, _ = converter_spy.spy_return + + # Capture converted program + exported_program: ExportedProgram = converter_spy.call_args.args[1] + + 
input_data = (np.random.random(x_input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + input_data = {0: input_data} + + convert_run_compare( + exported_program, + input_data=input_data, + tfl_model=tflite_flatbuffers_model, + ) + + +@pytest.mark.parametrize( + "x_input_shape, dims, starts, ends", + [ + pytest.param( + (1, 4, 34, 50), + (0, 1, 2, 3), + (0, 0, 8, 0), + (1, 8, 32, 32), + id="4D, handle channel order swap", + ) + ], +) +def test_slice_tensor_w_conv_quant_conversion( + mocker, x_input_shape, dims, starts, ends +): + if neutron_converter_flavor == "SDK_25_09": + pytest.skip("Neutron Software must be version 2.2.1 or higher.") + + model = SliceTensorConvModule(dims=dims, starts=starts, ends=ends) + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + # Run conversion + edge_program = to_quantized_edge_program( + model, x_input_shape, use_neutron_for_format_conversion=False + ).exported_program() + edge_nodes = list(edge_program.graph.nodes) + + # Check if slices were delegated + assert not any("slice" in n.name for n in edge_nodes) + + # Capture generated model + tflite_flatbuffers_model, _ = converter_spy.spy_return + + # Capture converted program + exported_program: ExportedProgram = converter_spy.call_args.args[1] + + input_data = (np.random.random(x_input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + input_data = {0: input_data} + + convert_run_compare( + exported_program, + input_data=input_data, + tflite_input_preprocess=ToChannelLastPreprocess(), + tfl_model=tflite_flatbuffers_model, + tflite_output_preprocess=ToChannelFirstPreprocess(), + ) + + +@pytest.mark.parametrize( + "x_input_shape, dims, starts, ends", + [ + pytest.param( + (24, 32), (0, 1), (0, 16), (24, 8), id="2D, start is higher than end" + ), + pytest.param( + (24, 32), (0, 1), (0, 16), (24, 16), id="2D, start is equal to end" + ), + pytest.param( + (24, 32), (0, 1), (0, 32), (24, 32), id="2D, start is equal to size" + ), + pytest.param( + (24, 32), (0, 1), (0, 0), (24, -5), id="2D, clipped end equal to zero" + ), + pytest.param( + (24, 32), (0, 1), (64, 0), (24, 32), id="2D, clipped start equal to size" + ), + ], +) +def test_invalid_slice(mocker, x_input_shape, dims, starts, ends): + model = SliceTensorModule( + dims=dims, + starts=starts, + ends=ends, + ) + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + # Run conversion + _ = to_quantized_edge_program(model, x_input_shape).exported_program() + + # Capture generated model, should be None because the model is invalid + assert converter_spy.spy_return is None + + +@pytest.mark.parametrize( + "x_input_shape, dims, starts, ends", + [ + pytest.param( + (24, 31), + (0, 1), + (0, 0), + (24, 16), + id="2D, input shape not divisible by num_macs", + ), + pytest.param( + (24, 26, 64), + (0, 1, 2), + (0, 4, 0), + (24, 26, 64), + id="3D, input shape not divisible by num_macs", + ), + ], +) +def test_slice_not_delegated(mocker, x_input_shape, dims, starts, ends): + model = SliceTensorModule( + dims=dims, + starts=starts, + ends=ends, + ) + + edge_program = to_quantized_edge_program(model, x_input_shape).exported_program() + nodes = list(edge_program.graph.nodes) + + num_slice_ops = 0 + for i in range(len(x_input_shape)): + if starts[i] != 0 or ends[i] != x_input_shape[i]: + num_slice_ops += 1 + + for i in range(0, num_slice_ops): + slice_idx = (i + 1) * 3 + assert nodes[slice_idx].target == exir_ops.edge.aten.slice_copy.Tensor diff --git 
a/backends/nxp/tests/ir/converter/node_converter/test_softmax_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_softmax_converter.py index 92af90b923d..5953b9dcac3 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_softmax_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_softmax_converter.py @@ -1,4 +1,4 @@ -# Copyright 2024 NXP +# Copyright 2024-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -11,6 +11,7 @@ EdgeProgramToIRConverter, ) from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig +from executorch.backends.nxp.backend.node_format_inference import NodeFormatInference from executorch.backends.nxp.tests.executorch_pipeline import to_edge_program from executorch.backends.nxp.tests.executors import convert_run_compare from executorch.backends.nxp.tests.models import SoftmaxConvModule, SoftmaxModule @@ -56,6 +57,7 @@ def test_softmax_conversion__unknown_input_format(input_shape, dim: int): model = SoftmaxModule(dim) edge_program = to_edge_program(model, input_shape).exported_program() + NodeFormatInference(edge_program).identify_node_formats() # Currently this test not pass because the convertibility checker doesn't use tensor formats. with pytest.raises( @@ -78,6 +80,7 @@ def test_softmax_conversion_channel_last(input_shape, dim: int): model = SoftmaxConvModule(dim) edge_program = to_edge_program(model, input_shape).exported_program() + NodeFormatInference(edge_program).identify_node_formats() # TODO (Robert Kalmar) Currently this test not pass because the convertibility checker doesn't use tensor formats. with pytest.raises( @@ -104,6 +107,7 @@ def test_softmax_conversion_unsupported_dims(input_shape, dim: int): model = SoftmaxModule(dim) edge_program = to_edge_program(model, input_shape).exported_program() + NodeFormatInference(edge_program).identify_node_formats() with pytest.raises( AssertionError, match="`aten__softmax_default` is not convertible" diff --git a/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py new file mode 100644 index 00000000000..9ce3e93f39b --- /dev/null +++ b/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py @@ -0,0 +1,181 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+import numpy as np +import pytest +import torch + +from executorch.backends.nxp.backend.edge_program_converter import ( + EdgeProgramToIRConverter, +) +from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program +from executorch.backends.nxp.tests.executors import ( + convert_run_compare, + ToChannelFirstPreprocess, + ToChannelLastPreprocess, +) +from executorch.backends.nxp.tests.models import ( + SubTensorConvModule, + SubTensorModule, + SubTensorOneInputModule, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch.export import ExportedProgram +from executorch.backends.nxp.tests.use_qat import * # noqa F403 + + +@pytest.fixture(autouse=True) +def reseed_model_per_test_run(): + torch.manual_seed(23) + np.random.seed(23) + + +@pytest.mark.parametrize( + "input_shape", + [ + pytest.param((4,), id="1D."), + pytest.param((6, 6), id="2D."), + pytest.param((1, 4, 8), id="3D."), + pytest.param((1, 4, 8, 8), id="4D."), + ], +) +def test_sub_tensor_quant_conversion(mocker, input_shape, use_qat): + model = SubTensorModule() + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + # Run conversion + _ = to_quantized_edge_program(model, [input_shape, input_shape], use_qat=use_qat) + + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.spy_return + + # Capture converted program + exported_program: ExportedProgram = converter_spy.call_args.args[1] + + input_data_1 = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + input_data_2 = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + input_data = {0: input_data_1, 1: input_data_2} + + nodes = list(exported_program.graph.nodes) + assert nodes[4].target == exir_ops.edge.aten.sub.Tensor + + convert_run_compare( + exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data + ) + + +@pytest.mark.parametrize( + "input_shape", + [ + pytest.param((4,), id="1D."), + pytest.param((6, 6), id="2D."), + pytest.param((1, 4, 8), id="3D."), + pytest.param((1, 4, 8, 8), id="4D."), + ], +) +def test_sub_tensor_one_input_quant_conversion(mocker, input_shape, use_qat): + model = SubTensorOneInputModule() + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + # Run conversion + _ = to_quantized_edge_program(model, input_shape, use_qat=use_qat) + + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.spy_return + + # Capture converted program + exported_program: ExportedProgram = converter_spy.call_args.args[1] + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) + + nodes = list(exported_program.graph.nodes) + assert nodes[2].target == exir_ops.edge.aten.sub.Tensor + + convert_run_compare( + exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data + ) + + +@pytest.mark.parametrize( + "x_input_shape", + [ + pytest.param((1, 4, 8, 8), id="4D."), + pytest.param((1, 4, 5, 5), id="4D, product of dims is not a multiple of 8."), + ], +) +def test_sub_tensor_w_conv_quant_conversion(mocker, x_input_shape, use_qat): + model = SubTensorConvModule() + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + n, c, h, w = x_input_shape + y_input_shape = (n, 8, h, w) + + # Run conversion + _ = to_quantized_edge_program( + model, + [x_input_shape, y_input_shape], + use_qat=use_qat, + use_neutron_for_format_conversion=False, + ) + + # Capture generated model + tflite_flatbuffers_model, 
io_formats = converter_spy.spy_return + + # Capture converted program + exported_program: ExportedProgram = converter_spy.call_args.args[1] + + input_data_1 = (np.random.random(x_input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + input_data_2 = (np.random.random(y_input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + input_data = {0: input_data_1, 1: input_data_2} + + nodes = list(exported_program.graph.nodes) + assert nodes[15].target == exir_ops.edge.aten.sub.Tensor + + convert_run_compare( + exported_program, + input_data=input_data, + tflite_input_preprocess=ToChannelLastPreprocess(), + tfl_model=tflite_flatbuffers_model, + tflite_output_preprocess=ToChannelFirstPreprocess(), + ) + + +@pytest.mark.parametrize( + "x_input_shape, y_input_shape", + [ + pytest.param((1, 4, 7), (4, 7), id="3D -> 2D."), + pytest.param((1, 4, 8), (1, 4, 4, 8), id="3D -> 4D."), + pytest.param((1, 1, 4, 4, 8), (1, 4, 4, 8), id="5D -> 4D."), + pytest.param((4,), (4, 4), id="1D -> 2D."), + pytest.param((4,), (4, 4, 4), id="1D -> 3D."), + pytest.param((6, 6), (1, 8, 6, 6), id="2D -> 4D."), + pytest.param((6, 6), (6,), id="2D -> 1D."), + ], +) +def test_sub_tensor_broadcasting_unsupported_quant_conversion( + x_input_shape, y_input_shape, use_qat +): + model = SubTensorModule() + + # Run conversion + edge_program = to_quantized_edge_program( + model, [x_input_shape, y_input_shape], use_qat=use_qat + ).exported_program() + nodes = list(edge_program.graph.nodes) + + # Broadcast is not supported, node is not converted + assert ( + nodes[6].target == exir_ops.edge.aten.sub.Tensor + ) # Sub Tensor is not delegated. diff --git a/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py index 40857d18eb8..10892d28e38 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py @@ -27,23 +27,30 @@ class TestTanhConverter(unittest.TestCase): __test__ = False # Prevent interfering with PyTest tests + @classmethod + def setUpClass(cls): + torch.manual_seed(23) + np.random.seed(23) + @parameterized.expand( input=[ - ( - "inplace", - True, - ), - ( - "not_inplace", - False, - ), + ("QAT inplace", True, True), + ("PTQ inplace", True, False), + ("QAT not-inplace", False, True), + ("PTQ not-inplace", False, False), ] ) def test_conv_tanh( - self, _: str, inplace: bool, input_shape: tuple[int] = (1, 3, 112, 112) + self, + _: str, + inplace: bool, + use_qat: bool, + input_shape: tuple[int] = (1, 3, 112, 112), ): with kgb.spy_on( - EdgeProgramToIRConverter.convert_program, call_original=True + EdgeProgramToIRConverter.convert_program, + call_original=True, + owner=EdgeProgramToIRConverter, ) as converter_spy: if inplace: model = Conv2dWithActivation( @@ -55,7 +62,10 @@ def test_conv_tanh( ) quantized_program = to_quantized_edge_program( - model, input_shape + model, + input_shape, + use_qat=use_qat, + use_neutron_for_format_conversion=False, ).exported_program() tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value exported_program: ExportedProgram = converter_spy.calls[-1].args[0] @@ -76,10 +86,5 @@ def test_conv_tanh( tflite_input_preprocess=ToChannelLastPreprocess(), tflite_output_preprocess=ToChannelFirstPreprocess(), input_data=input_data, - atol=1.0, + atol=2.0, ) - - @classmethod - def setUpClass(cls): - torch.manual_seed(23) - np.random.seed(23) diff --git 
a/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py index 448a9753000..ce9fecb049b 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py @@ -1,4 +1,4 @@ -# Copyright 2024 NXP +# Copyright 2024-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -12,6 +12,7 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) +from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.backend.ir.converter.builder.model_builder import ( ModelBuilder, ) @@ -30,11 +31,14 @@ ) from executorch.backends.nxp.tests.executors import ( convert_run_compare, - ToNCHWPreprocess, - ToNHWCPreprocess, + graph_contains_any_of_ops, + ToChannelFirstPreprocess, + ToChannelLastPreprocess, ) +from executorch.exir.dialects._ops import ops as exir_ops from torch import nn from torch.export import ExportedProgram +from executorch.backends.nxp.tests.use_qat import * # noqa F403 @pytest.fixture(autouse=True) @@ -107,7 +111,35 @@ def forward(self, x): return x -def test__channels_first_to_2d(mocker): +class ConvViewLinearModule(torch.nn.Module): + def __init__(self, view_new_shape: list[int], channels: int, bias: bool): + super().__init__() + self.view_new_shape = view_new_shape + self.conv = nn.Conv2d(channels, channels, 1, 1) + self.linear = nn.Linear(view_new_shape[1], 8, bias=bias) + + def forward(self, x): + x = self.conv(x) + x = x.view(self.view_new_shape) + x = self.linear(x) + return x + + +class ConvViewConvModule(torch.nn.Module): + def __init__(self, view_new_shape: list[int], channels: int): + super().__init__() + self.view_new_shape = view_new_shape + self.conv1 = nn.Conv2d(channels, channels, 1, 1) + self.conv2 = nn.Conv2d(channels, channels, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = x.view(self.view_new_shape) + x = self.conv2(x) + return x + + +def test__view_copy__channels_first_to_2d(mocker): input_shape = (1, 4, 7, 9) new_shape = (6, 32) # Mix up the dimensions for a thorough test. 
@@ -119,7 +151,7 @@ def test__channels_first_to_2d(mocker): converter_spy = mocker.spy(ModelBuilder, "finish") convert_run_compare( - edge_program, input_data, tflite_input_preprocess=ToNHWCPreprocess() + edge_program, input_data, tflite_input_preprocess=ToChannelLastPreprocess() ) tflite_model = converter_spy.spy_return @@ -130,7 +162,7 @@ def test__channels_first_to_2d(mocker): assert isinstance(ops[2].builtin_options, Reshape) -def test__channels_first_to_4d(mocker): +def test__view_copy__channels_first_to_4d(mocker): input_shape = (1, 8, 6, 8) new_shape = (7, 4, 2, 5) @@ -144,8 +176,11 @@ def test__channels_first_to_4d(mocker): convert_run_compare( edge_program, input_data, - tflite_input_preprocess=ToNHWCPreprocess(), + tflite_input_preprocess=ToChannelLastPreprocess(), atol=2.0e-7, + conversion_config=ConversionConfig( + {"use_neutron_for_format_conversion": False} + ), ) tflite_model = converter_spy.spy_return @@ -156,7 +191,7 @@ def test__channels_first_to_4d(mocker): assert isinstance(ops[2].builtin_options, Reshape) -def test__formatless_to_channels_first(mocker): +def test__view_copy__formatless_to_channels_first(mocker): input_shape = (12, 32) new_shape = (1, 4, 12, 8) # Mix up the dimensions for a thorough test. @@ -172,7 +207,7 @@ def test__formatless_to_channels_first(mocker): convert_run_compare( edge_program, input_data, - tflite_output_preprocess=ToNCHWPreprocess(), + tflite_output_preprocess=ToChannelFirstPreprocess(), atol=2.0e-7, ) @@ -184,7 +219,7 @@ def test__formatless_to_channels_first(mocker): assert isinstance(ops[2].builtin_options, Conv2D) -def test__formatless_to_formatless(mocker): +def test__view_copy__formatless_to_formatless(mocker): input_shape = (12, 32) new_shape = (1, 4, 6, 16) @@ -209,11 +244,13 @@ def test__formatless_to_formatless(mocker): pytest.param((8, 64), (1, 16, 4, 4), id="2D"), ], ) -def test_view_copy_w_linear_quant_conversion(mocker, input_shape, new_shape): +def test_view_copy_w_linear_quant_conversion(mocker, input_shape, new_shape, use_qat): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(LinearReshapeModule(new_shape=new_shape), input_shape) + _ = to_quantized_edge_program( + LinearReshapeModule(new_shape=new_shape), input_shape, use_qat=use_qat + ) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -234,7 +271,9 @@ def test_view_copy_w_linear_quant_conversion(mocker, input_shape, new_shape): pytest.param((1, 4, 16, 16), 196, id="4D"), ], ) -def test_view_w_conv_linear_quant_conversion(mocker, input_shape, channels_view_out): +def test_view_w_conv_linear_quant_conversion( + mocker, input_shape, channels_view_out, use_qat +): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion @@ -243,6 +282,8 @@ def test_view_w_conv_linear_quant_conversion(mocker, input_shape, channels_view_ channels=input_shape[1], channels_view_out=channels_view_out ), input_shape, + use_qat=use_qat, + use_neutron_for_format_conversion=False, ) # Capture generated model @@ -256,7 +297,213 @@ def test_view_w_conv_linear_quant_conversion(mocker, input_shape, channels_view_ convert_run_compare( edge_program, input_data, - tflite_input_preprocess=ToNHWCPreprocess(), + tflite_input_preprocess=ToChannelLastPreprocess(), tfl_model=tflite_flatbuffers_model, atol=1.0, ) + + +@pytest.mark.parametrize( + "bias", + [True, False], +) +def test__view_copy__context_dependent__channels_first_to_formatless__transpose_fused( + bias, 
mocker +): + input_shape = (1, 2, 3, 4) + new_shape = [1, 2 * 3 * 4] + module = ConvViewLinearModule(new_shape, 2, bias) + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + ep = to_quantized_edge_program( + module, + input_shape, + use_neutron_for_format_conversion=False, + ).exported_program() + + # Make sure all 3 nodes were delegated + assert any(n.name == "executorch_call_delegate" for n in ep.graph.nodes) + assert not graph_contains_any_of_ops( + ep.graph, + [ + exir_ops.edge.aten.convolution.default, + exir_ops.edge.aten.mm.default, + exir_ops.edge.aten.addmm.default, + exir_ops.edge.aten.view_copy.default, + ], + ) + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) + + converted_edge_program = converter_spy.call_args.args[1] + neutron_ir_model = converter_spy.spy_return[0] + convert_run_compare( + converted_edge_program, + input_data, + tfl_model=neutron_ir_model, + tflite_input_preprocess=ToChannelLastPreprocess(), + ) + + +@pytest.mark.parametrize( + "bias", + [True, False], +) +def test__view_copy__context_dependent__channels_first_to_formatless__transpose_not_fusable( + bias, +): + input_shape = (1, 2, 3, 4) + new_shape = [ + 2, + 3 * 4, + ] # The batch size changes, which makes the optimization not applicable. + module = ConvViewLinearModule(new_shape, 2, bias) + + ep = to_quantized_edge_program( + module, + input_shape, + use_neutron_for_format_conversion=False, + ).exported_program() + + # Make sure the convolution and the linear were delegated, but not the view_copy. + assert any(n.name == "executorch_call_delegate" for n in ep.graph.nodes) + assert not graph_contains_any_of_ops( + ep.graph, + [ + exir_ops.edge.aten.convolution.default, + exir_ops.edge.aten.mm.default, + exir_ops.edge.aten.addmm.default, + ], + ) + assert graph_contains_any_of_ops( + ep.graph, + [ + exir_ops.edge.aten.view_copy.default, + ], + ) + + +def test__view_copy__formatless_to_channels_first__transpose_supported(mocker): + input_shape = (1, 8 * 3 * 8) + new_shape = [1, 8, 3, 8] + module = FormatlessToChannelsFirstModule(8, new_shape) + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + ep = to_quantized_edge_program( + module, + input_shape, + use_neutron_for_format_conversion=False, + ).exported_program() + + # Make sure both nodes were delegated + assert any(n.name == "executorch_call_delegate" for n in ep.graph.nodes) + assert not graph_contains_any_of_ops( + ep.graph, + [ + exir_ops.edge.aten.convolution.default, + exir_ops.edge.aten.view_copy.default, + ], + ) + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) + + converted_edge_program = converter_spy.call_args.args[1] + neutron_ir_model = converter_spy.spy_return[0] + convert_run_compare( + converted_edge_program, + input_data, + tfl_model=neutron_ir_model, + tflite_output_preprocess=ToChannelFirstPreprocess(), + ) + + +def test__view_copy__formatless_to_channels_first__transpose_not_supported(): + input_shape = (1, 8 * 3 * 4) + new_shape = [1, 8, 3, 4] # The last dim is not a multiple of num_macs. + module = FormatlessToChannelsFirstModule(8, new_shape) + + ep = to_quantized_edge_program( + module, + input_shape, + use_neutron_for_format_conversion=False, + ).exported_program() + + # Make sure the view_copy was not delegated. 
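    # Across the supported/not-supported pairs in this file, delegation of the
    # format-changing reshape appears to depend on the innermost dimension being
    # a multiple of the target's `num_macs`; here it is not, so the `view_copy`
    # is expected to stay in the graph while the convolution is still delegated.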
+ assert any(n.name == "executorch_call_delegate" for n in ep.graph.nodes) + assert not graph_contains_any_of_ops( + ep.graph, + [ + exir_ops.edge.aten.convolution.default, + ], + ) + assert graph_contains_any_of_ops( + ep.graph, + [ + exir_ops.edge.aten.view_copy.default, + ], + ) + + +def test__view_copy__channels_first_to_channels_first__transpose_supported(mocker): + input_shape = (1, 8, 3, 8) + new_shape = [1, 8, 1, 24] + module = ConvViewConvModule(new_shape, 8) + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + ep = to_quantized_edge_program( + module, + input_shape, + use_neutron_for_format_conversion=False, + ).exported_program() + + # Make sure all nodes were delegated + assert any(n.name == "executorch_call_delegate" for n in ep.graph.nodes) + assert not graph_contains_any_of_ops( + ep.graph, + [ + exir_ops.edge.aten.convolution.default, + exir_ops.edge.aten.view_copy.default, + ], + ) + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) + + converted_edge_program = converter_spy.call_args.args[1] + neutron_ir_model = converter_spy.spy_return[0] + convert_run_compare( + converted_edge_program, + input_data, + tfl_model=neutron_ir_model, + tflite_input_preprocess=ToChannelLastPreprocess(), + tflite_output_preprocess=ToChannelFirstPreprocess(), + ) + + +def test__view_copy__channels_first_to_channels_first__transpose_not_supported(): + input_shape = (1, 8, 3, 5) # The last dimension is not a multiple of num_macs. + new_shape = [1, 8, 1, 15] + module = ConvViewConvModule(new_shape, 8) + + ep = to_quantized_edge_program( + module, + input_shape, + use_neutron_for_format_conversion=False, + ).exported_program() + + # Make sure the view_copy was NOT delegated + assert any(n.name == "executorch_call_delegate" for n in ep.graph.nodes) + assert not graph_contains_any_of_ops( + ep.graph, + [ + exir_ops.edge.aten.convolution.default, + ], + ) + assert graph_contains_any_of_ops( + ep.graph, + [ + exir_ops.edge.aten.view_copy.default, + ], + ) diff --git a/backends/nxp/tests/ir/edge_passes/test_remove_io_quant_ops_pass.py b/backends/nxp/tests/ir/edge_passes/test_remove_io_quant_ops_pass.py index 17b040fbc3d..b5e701ab239 100644 --- a/backends/nxp/tests/ir/edge_passes/test_remove_io_quant_ops_pass.py +++ b/backends/nxp/tests/ir/edge_passes/test_remove_io_quant_ops_pass.py @@ -51,7 +51,10 @@ def test_remove_io_quant_ops_pass__cifarnet(): model = CifarNet().get_eager_model() input_shape = (1, 3, 32, 32) edge_program_manager = to_quantized_edge_program( - model, input_shape, remove_quant_io_ops=True + model, + input_shape, + remove_quant_io_ops=True, + use_neutron_for_format_conversion=False, ) exec_prog = edge_program_manager.to_executorch( diff --git a/backends/nxp/tests/models.py b/backends/nxp/tests/models.py index bdad9ddc4b4..e2b41aab8de 100644 --- a/backends/nxp/tests/models.py +++ b/backends/nxp/tests/models.py @@ -4,10 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
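# The additions below extend the shared test models with single-operator modules
# (slice, addmm, mm, mul, sub), activation wrappers built around `get_activation`,
# and two small end-to-end networks (`MiniConvNetWithRegressionHead`, `MLP`) used
# by the new converter and edge-pass tests.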
+import math from typing import Callable, Collection, Union import torch +from torch import nn + class Conv1dModule(torch.nn.Module): def __init__( @@ -169,6 +172,64 @@ def forward(self, x): return self.linear(x) +class SliceTensorModule(torch.nn.Module): + def __init__(self, dims, starts, ends): + super().__init__() + self.dims = dims + self.starts = starts + self.ends = ends + + def do_slice(self, x): + slices = [slice(None)] * x.dim() + for i, dim in enumerate(self.dims): + slices[dim] = slice(self.starts[i], self.ends[i]) + return x[tuple(slices)] + + def forward(self, x): + x = self.do_slice(x) + + return x + + +class SliceTensorConvModule(torch.nn.Module): + def __init__(self, dims, starts, ends): + super().__init__() + self.conv = Conv2dModule(in_channels=4, out_channels=8, kernel_size=3, stride=1) + self.slice = SliceTensorModule(dims, starts, ends) + + def forward(self, x): + x = self.conv(x) + x = self.slice(x) + + return x + + +class AddmmModule(torch.nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(in_channels, in_channels)) + self.bias = torch.nn.Parameter(torch.empty(in_channels)) + torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + torch.nn.init.uniform_(self.bias, -bound, bound) + self.eval() + + def forward(self, x): + return torch.addmm(self.bias, x, self.weight) + + +class MmModule(torch.nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(in_channels, in_channels)) + torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + self.eval() + + def forward(self, x): + return torch.mm(x, self.weight) + + class LinearSoftmaxModule(torch.nn.Module): def __init__(self): super().__init__() @@ -396,6 +457,34 @@ def forward(self, x): return self.pool(x) +class MulTensorModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @staticmethod + def forward(x, y): + return x * y + + +class MulTensorConvModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = Conv2dModule(padding=1, stride=1) + + def forward(self, x, y): + x = self.conv(x) + return x * y + + +class MulTensorOneInputModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @staticmethod + def forward(x): + return x * x + + class AddTensorModule(torch.nn.Module): def __init__(self): super().__init__() @@ -424,6 +513,34 @@ def forward(x): return x + x +class SubTensorModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @staticmethod + def forward(x, y): + return x - y + + +class SubTensorConvModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = Conv2dModule(padding=1, stride=1) + + def forward(self, x, y): + x = self.conv(x) + return x - y + + +class SubTensorOneInputModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @staticmethod + def forward(x): + return x - x + + class MeanDimLinearModule(torch.nn.Module): def __init__(self, dim, keepdim): super().__init__() @@ -437,12 +554,119 @@ def forward(self, x): class MeanDimConvModule(torch.nn.Module): - def __init__(self, dim, keepdim): + def __init__(self, dim, keepdim, out_channels=8): super().__init__() - self.conv = Conv2dModule(stride=1, padding=1) + self.conv = Conv2dModule(stride=1, padding=1, out_channels=out_channels) self.dim = dim self.keepdim = keepdim def forward(self, x): x = self.conv(x) 
return torch.mean(x, dim=self.dim, keepdim=self.keepdim) + + +def get_activation(activation, inplace): + match activation: + case "relu": + return nn.ReLU(inplace=inplace) + case "relu_hardtanh": + return nn.Hardtanh(inplace=inplace, min_val=0.0, max_val=float("inf")) + case "relu6": + return nn.ReLU6(inplace=inplace) + case "tanh": + if inplace: + return torch.tanh + else: + return torch.tanh_ + case "sigmoid": + return nn.Sigmoid() + case _: + raise ValueError + + +class LinearActivationModule(torch.nn.Module): + def __init__( + self, activation: str, inplace: bool, in_channels: int, mode: str = "linear" + ): + super().__init__() + self.mode = mode.lower() + assert self.mode in [ + "linear", + "addmm", + "mm", + ], "Mode must be 'linear', 'addmm', or 'mm'" + + if self.mode == "linear": + self.linear = torch.nn.Linear(in_channels, in_channels) + else: + # Manual weight and bias for addmm/mm + self.weight = torch.nn.Parameter(torch.empty(in_channels, in_channels)) + self.bias = torch.nn.Parameter(torch.empty(in_channels)) + torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + torch.nn.init.uniform_(self.bias, -bound, bound) + + self.activation = get_activation(activation, inplace) + self.eval() + + def forward(self, x): + if self.mode == "linear": + x = self.linear(x) + if self.mode == "addmm": + x = torch.addmm(self.bias, x, self.weight) + elif self.mode == "mm": + x = torch.mm(x, self.weight) + return self.activation(x) + + +class ConvActivationModule(torch.nn.Module): + def __init__(self, activation: str, inplace: bool, in_channels: int): + super().__init__() + + self.conv = Conv2dModule(in_channels=in_channels) + self.activation = get_activation(activation, inplace) + self.eval() + + def forward(self, x): + x = self.conv(x) + return self.activation(x) + + +class MiniConvNetWithRegressionHead(torch.nn.Module): + def __init__(self): + super().__init__() + + self.conv1 = Conv2dModule(in_channels=3, out_channels=16, stride=1, padding=1) + self.relu = torch.nn.ReLU() + self.pool = torch.nn.MaxPool2d(2, 2) + self.conv2 = Conv2dModule(in_channels=16, out_channels=32, stride=1, padding=1) + self.relu2 = torch.nn.ReLU() + self.pool = torch.nn.MaxPool2d(2, 2) + self.linear = torch.nn.Linear(32 * 8 * 8, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.pool(x) + x = self.conv2(x) + x = self.relu2(x) + x = self.pool(x) + x = x.flatten() + x = self.linear(x) + return x + + +class MLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.sequential = torch.nn.Sequential( + torch.nn.Linear(1, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + + def forward(self, x): + return self.sequential(x) diff --git a/backends/nxp/tests/test_batch_norm_fusion.py b/backends/nxp/tests/test_batch_norm_fusion.py index 3f1106c6d24..eeb4b03d7a6 100644 --- a/backends/nxp/tests/test_batch_norm_fusion.py +++ b/backends/nxp/tests/test_batch_norm_fusion.py @@ -18,7 +18,10 @@ from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.view_copy_converter import ( ViewCopyConverter, ) -from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program +from executorch.backends.nxp.tests.executorch_pipeline import ( + neutron_target_spec, + to_quantized_edge_program, +) from executorch.backends.nxp.tests.executors import OverrideTargetSupportCheck from torch import nn @@ -98,17 
+101,18 @@ def test_batch_norm_conv_fusing(bias: bool, input_shape: list[int]): program = torch.export.export(module, example_input, strict=True) og_module = program.module() - pm = NeutronAtenPassManager() + pm = NeutronAtenPassManager(neutron_target_spec) graph_module_out = pm(deepcopy(program.module())).graph_module # Make sure the fusion worked. og_nodes = list(program.graph.nodes) transformed_nodes = list(graph_module_out.graph.nodes) - assert len(og_nodes) == (11 if bias else 10) - assert og_nodes[9 if bias else 8].target.__name__ == "batch_norm.default" + assert any( + node.op == "call_function" and node.target.__name__ == "batch_norm.default" + for node in og_nodes + ) - assert len(transformed_nodes) == 5 assert not any( node.op == "call_function" and "batch_norm" in node.target.__name__ for node in transformed_nodes @@ -118,7 +122,7 @@ def test_batch_norm_conv_fusing(bias: bool, input_shape: list[int]): input_data = torch.randn(input_shape, dtype=torch.float32) out1 = og_module(input_data).detach().numpy() out2 = graph_module_out(input_data).detach().numpy() - assert np.allclose(out1, out2, atol=3.0e-7) + torch.testing.assert_close(out1, out2) @pytest.mark.parametrize( @@ -132,17 +136,18 @@ def test_batch_norm_linear_fusing(bias: bool): program = torch.export.export(module, example_input, strict=True) og_module = program.module() - pm = NeutronAtenPassManager() + pm = NeutronAtenPassManager(neutron_target_spec) graph_module_out = pm(deepcopy(program.module())).graph_module # Make sure the fusion worked. og_nodes = list(og_module.graph.nodes) transformed_nodes = list(graph_module_out.graph.nodes) - assert len(og_nodes) == (11 if bias else 10) - assert og_nodes[8 if bias else 7].target.__name__ == "linear.default" + assert any( + node.op == "call_function" and node.target.__name__ == "linear.default" + for node in og_nodes + ) - assert len(transformed_nodes) == 5 assert not any( node.op == "call_function" and "batch_norm" in node.target.__name__ for node in transformed_nodes @@ -152,7 +157,7 @@ def test_batch_norm_linear_fusing(bias: bool): input_data = torch.randn(input_shape, dtype=torch.float32) out1 = og_module(input_data).detach().numpy() out2 = graph_module_out(input_data).detach().numpy() - assert np.allclose(out1, out2, atol=1.2e-7) + torch.testing.assert_close(out1, out2) @pytest.mark.parametrize( @@ -168,7 +173,7 @@ def test_batch_norm_conv_fusing__full_pipeline__1d(bias: bool): nodes = list(edge_program.graph.nodes) assert ( - len(nodes) == 13 + len(nodes) == 17 ) # 1D Conv currently isn't delegated, because it doesn't get quantized. assert not any( node.op == "call_function" and "batch_norm" in node.target.__name__ diff --git a/backends/nxp/tests/test_context_sensitive_delegation.py b/backends/nxp/tests/test_context_sensitive_delegation.py new file mode 100644 index 00000000000..1919bc63d82 --- /dev/null +++ b/backends/nxp/tests/test_context_sensitive_delegation.py @@ -0,0 +1,71 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
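# This module checks context-sensitive partitioning: a lone `view_copy` is not
# delegated on its own, and forcing delegation by overriding
# `ViewCopyConverter.supports_partitioning_result` is expected to make the
# neutron-converter fail, since the resulting model contains no NeutronGraph node.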
+ +import unittest + +import numpy as np +import torch + +from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import ( + ViewCopyConverter, +) +from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program +from executorch.backends.nxp.tests.executors import graph_contains_any_of_ops +from executorch.exir.dialects._ops import ops as exir_ops + + +class SingleViewCopyModule(torch.nn.Module): + def __init__(self, new_shape: list[int]): + super().__init__() + self.new_shape = new_shape + + def forward(self, x): + return torch.reshape(x, self.new_shape) + + +class TestContextSensitiveDelegation(unittest.TestCase): + __test__ = False # Prevent interfering with PyTest tests. + + @classmethod + def setUpClass(cls): + torch.manual_seed(23) + np.random.seed(42) + + def test_single_view_copy_partition(self): + input_shape = (2, 10) + module = SingleViewCopyModule([1, 20]) + + ep = to_quantized_edge_program(module, input_shape).exported_program() + + # Make sure the `view_copy` was not delegated. + assert graph_contains_any_of_ops( + ep.graph, [exir_ops.edge.aten.view_copy.default] + ) + assert not any("delegate" in n.name for n in ep.graph.nodes) + + def test_single_view_copy_partition__forced_delegation(self): + input_shape = (2, 10) + module = SingleViewCopyModule([1, 20]) + + def _supported_partitioning(*_): + return True + + # Replace the partition support check function, to accept anything. + original_supports_partitioning_result = ( + ViewCopyConverter.supports_partitioning_result + ) + ViewCopyConverter.supports_partitioning_result = _supported_partitioning + + with self.assertRaises(RuntimeError) as e: + to_quantized_edge_program(module, input_shape).exported_program() + assert ( + str(e.exception) + == "Model converted with neutron-converter does not contain a NeutronGraph node." + ) + + # Return to the original partition support check function. + ViewCopyConverter.supports_partitioning_result = ( + original_supports_partitioning_result + ) diff --git a/backends/nxp/tests/test_edge_passes.py b/backends/nxp/tests/test_edge_passes.py index a189299be52..d93b1ae69ff 100644 --- a/backends/nxp/tests/test_edge_passes.py +++ b/backends/nxp/tests/test_edge_passes.py @@ -1,14 +1,58 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
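# The reworked tests below verify that the edge passes move `view_copy` and
# fusable activations (relu, relu6, tanh, sigmoid) into their own
# quantize/dequantize clusters around `addmm`, `mm`, `linear` and convolution,
# and that `RemoveAdditionalQDQClustersPass` drops the redundant Q/DQ pair around
# a delegated `permute_copy` without changing the numerical outputs.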
+ +import copy +import unittest + +import kgb import numpy as np +import torch + +from executorch.backends.nxp.backend.custom_delegation_options import ( + CustomDelegationOptions, +) +from executorch.backends.nxp.backend.edge_helper import _is_dequantize, _is_quantize +from executorch.backends.nxp.backend.edge_program_converter import ( + EdgeProgramToIRConverter, +) from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import ( ViewCopyConverter, ) -from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program +from executorch.backends.nxp.edge_passes.neutron_edge_pass_manager import ( + NeutronEdgePassManager, +) +from executorch.backends.nxp.edge_passes.remove_additional_quantize_dequantize_nodes_pass import ( + RemoveAdditionalQDQClustersPass, +) +from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner +from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec +from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer +from executorch.backends.nxp.quantizer.utils import calibrate_and_quantize +from executorch.backends.nxp.tests.executorch_pipeline import ( + get_random_calibration_inputs, + neutron_target_spec, + to_model_input_spec, + to_quantized_edge_program, +) from executorch.backends.nxp.tests.executors import ( + compare_output_arrays, EdgeProgramExecutor, OverrideTargetSupportCheck, ) -from executorch.backends.nxp.tests.models import ConvFCFCSoftmaxModuleWithoutReshape +from executorch.backends.nxp.tests.ir.converter.node_converter.test_permute_copy_converter import ( + Conv2dPermuteModule, +) +from executorch.backends.nxp.tests.models import ( + ConvActivationModule, + ConvFCFCSoftmaxModuleWithoutReshape, + LinearActivationModule, +) from executorch.exir.dialects._ops import ops as exir_ops +from executorch.extension.export_util.utils import export_to_edge +from parameterized import parameterized +from torch.export import ExportedProgram from torch.fx import Graph, Node @@ -19,21 +63,6 @@ def _is_view_copy(node_: Node) -> bool: ) -def _is_dequantize(node_: Node) -> bool: - return ( - node_.op == "call_function" - and node_.target.__name__ - == "quantized_decomposed.dequantize_per_tensor.default" - ) - - -def _is_quantize(node_: Node) -> bool: - return ( - node_.op == "call_function" - and node_.target.__name__ == "quantized_decomposed.quantize_per_tensor.default" - ) - - def _find_view_copy_node_indices(graph_nodes: list[Node]) -> list[int]: view_copy_nodes_indices = [] @@ -57,32 +86,297 @@ def _assert_nodes_form_a_view_copy_qdq_cluster(graph: Graph, node_indices: list[ assert quantize.args[0] == view_copy -def test_moving_view_copy_into_separate_qdq_clusters(): - model = ConvFCFCSoftmaxModuleWithoutReshape() - input_shape = (1, 4, 3, 33) +class TestEdgePasses(unittest.TestCase): + __test__ = False # Prevent interfering with PyTest tests + + @classmethod + def setUpClass(cls): + torch.manual_seed(23) + np.random.seed(42) + + def test_moving_view_copy_into_separate_qdq_clusters(self): + model = ConvFCFCSoftmaxModuleWithoutReshape() + input_shape = (1, 4, 3, 33) + + # Prohibit `view_copy` conversion for the testing purposes. 
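        # `OverrideTargetSupportCheck` (used below) temporarily replaces the
        # converter's target-support check, so `view_copy` is treated as
        # unsupported and has to stay outside the delegate, which is what
        # produces the separate dq/view_copy/q clusters asserted further down.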
+ def unsupported_target(*_): + return False + + with OverrideTargetSupportCheck( + ViewCopyConverter, new_target_support_check=unsupported_target + ): + epm = to_quantized_edge_program(model, input_shape, target="imxrt700") + exported_program = epm.exported_program() + + nodes = list(exported_program.graph_module.graph.nodes) + assert len(nodes) == 28 - # Prohibit `view_copy` conversion for the testing purposes. - def unsupported_target(*_): - return False + view_copy_indices = _find_view_copy_node_indices(nodes) - with OverrideTargetSupportCheck( - ViewCopyConverter, new_target_support_check=unsupported_target + assert len(view_copy_indices) == 4 + for idx in view_copy_indices: + _assert_nodes_form_a_view_copy_qdq_cluster( + exported_program.graph, node_indices=[idx - 1, idx, idx + 1] + ) + + # Make sure the program is runnable. + input_data = np.random.random(input_shape).astype("float32") + program_executor = EdgeProgramExecutor(exported_program) + program_executor.inference(input_data) + + @parameterized.expand( + [ + ["relu"], + ["relu6"], + ["tanh"], + ["sigmoid"], + ] + ) + def test_moving_fusable_activations_into_separate_qdq_clusters__addmm( + self, activation ): - epm = to_quantized_edge_program(model, input_shape, target="imxrt700") - exported_program = epm.exported_program() + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, + call_original=True, + owner=EdgeProgramToIRConverter, + ) as converter_spy: + input_shape = (1, 4) + model = LinearActivationModule( + activation=activation, + inplace=True, + in_channels=input_shape[1], + mode="addmm", + ) - nodes = list(exported_program.graph_module.graph.nodes) - assert len(nodes) == 28 + _ = to_quantized_edge_program(model, input_shape) + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] - view_copy_indices = _find_view_copy_node_indices(nodes) + # Check linear and activation are in separate QDQ clusters + nodes = list(exported_program.graph.nodes) + assert len(nodes) == 12 + assert _is_dequantize(nodes[5]) + assert ( + neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__edge( + nodes[6] + ) + ) + assert _is_quantize(nodes[7]) + assert _is_dequantize(nodes[8]) + assert neutron_target_spec.neutron_target_info.is_supported_fused_activation__edge( + nodes[9] + ) + assert _is_quantize(nodes[10]) - assert len(view_copy_indices) == 4 - for idx in view_copy_indices: - _assert_nodes_form_a_view_copy_qdq_cluster( - exported_program.graph, node_indices=[idx - 1, idx, idx + 1] + @parameterized.expand( + [ + ["relu"], + ["relu6"], + ["tanh"], + ["sigmoid"], + ] + ) + def test_moving_fusable_activations_into_separate_qdq_clusters__mm( + self, activation + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, + call_original=True, + owner=EdgeProgramToIRConverter, + ) as converter_spy: + input_shape = (1, 4) + model = LinearActivationModule( + activation=activation, + inplace=True, + in_channels=input_shape[1], + mode="mm", ) - # Make sure the program is runnable. 
- input_data = np.random.random(input_shape).astype("float32") - program_executor = EdgeProgramExecutor(exported_program) - program_executor.inference(input_data) + _ = to_quantized_edge_program(model, input_shape) + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + + # Check linear and activation are in separate QDQ clusters + nodes = list(exported_program.graph.nodes) + assert len(nodes) == 10 + assert _is_dequantize(nodes[3]) + assert ( + neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__edge( + nodes[4] + ) + ) + assert _is_quantize(nodes[5]) + assert _is_dequantize(nodes[6]) + assert neutron_target_spec.neutron_target_info.is_supported_fused_activation__edge( + nodes[7] + ) + assert _is_quantize(nodes[8]) + + @parameterized.expand( + [ + ["relu"], + ["relu6"], + ["tanh"], + ["sigmoid"], + ] + ) + def test_moving_fusable_activations_into_separate_qdq_clusters__linear( + self, activation + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, + call_original=True, + owner=EdgeProgramToIRConverter, + ) as converter_spy: + input_shape = (1, 4) + model = LinearActivationModule( + activation=activation, + inplace=True, + in_channels=input_shape[1], + mode="linear", + ) + + _ = to_quantized_edge_program(model, input_shape) + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + + # Check linear and activation are in separate QDQ clusters + nodes = list(exported_program.graph.nodes) + assert len(nodes) == 13 + assert _is_dequantize(nodes[5]) + assert ( + neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__edge( + nodes[7] + ) + ) + assert _is_quantize(nodes[8]) + assert _is_dequantize(nodes[9]) + assert neutron_target_spec.neutron_target_info.is_supported_fused_activation__edge( + nodes[10] + ) + assert _is_quantize(nodes[11]) + + @parameterized.expand( + [ + ["relu"], + ["relu6"], + ["tanh"], + ["sigmoid"], + ] + ) + def test_moving_fusable_activations_into_separate_qdq_clusters__conv( + self, activation + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, + call_original=True, + owner=EdgeProgramToIRConverter, + ) as converter_spy: + input_shape = (1, 4, 8, 8) + model = ConvActivationModule( + activation=activation, inplace=True, in_channels=input_shape[1] + ) + + _ = to_quantized_edge_program(model, input_shape) + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + + # Check linear and activation are in separate QDQ clusters + nodes = list(exported_program.graph.nodes) + assert len(nodes) == 16 + assert _is_dequantize(nodes[9]) + assert ( + neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__edge( + nodes[10] + ) + ) + assert _is_quantize(nodes[11]) + assert _is_dequantize(nodes[12]) + assert neutron_target_spec.neutron_target_info.is_supported_fused_activation__edge( + nodes[13] + ) + assert _is_quantize(nodes[14]) + + def test_remove_additional_quantize_dequantize_nodes_pass(self): + input_shape = (1, 3, 8, 16) + new_dims = (3, 2, 1, 0) + model = Conv2dPermuteModule(input_shape[1], new_dims) + target = "imxrt700" + custom_delegation_options = CustomDelegationOptions() + + calibration_inputs = get_random_calibration_inputs( + to_model_input_spec(input_shape) + ) + + example_input = calibration_inputs[0] + exir_program_aten = torch.export.export(model, example_input, strict=True) + + exir_program_aten_quant = calibrate_and_quantize( + exir_program_aten, + calibration_inputs, + NeutronQuantizer(neutron_target_spec), + ) + edge_program_manager = export_to_edge( + 
exir_program_aten_quant, + example_input, + ) + + edge_program_manager = edge_program_manager.transform(NeutronEdgePassManager()) + + compile_spec = generate_neutron_compile_spec(target, "SDK_25_09") + partitioner = NeutronPartitioner( + compile_spec, neutron_target_spec, custom_delegation_options + ) + + edge_program_manager = edge_program_manager.to_backend(partitioner) + + # Make sure QDQ cluster for permute_copy is present. + edge_program_with_qdq_cluster = copy.deepcopy( + edge_program_manager.exported_program() + ) + nodes = list(edge_program_with_qdq_cluster.graph.nodes) + assert len(nodes) == 10 + assert ( + nodes[5].target + == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default + ) + assert nodes[6].target == exir_ops.edge.aten.permute_copy.default + assert "cluster" in nodes[6].meta + assert ( + nodes[7].target + == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default + ) + + # Run pass for removal of additional QDQ nodes and compute in non-float types where possible + edge_program_manager = edge_program_manager.transform( + NeutronEdgePassManager([RemoveAdditionalQDQClustersPass()]) + ) + + # Make sure QDQ cluster for permute_copy is removed. + edge_program_without_qdq_cluster = edge_program_manager.exported_program() + nodes = list(edge_program_without_qdq_cluster.graph.nodes) + assert len(nodes) == 8 + assert nodes[4].name == "getitem" + assert nodes[5].target == exir_ops.edge.aten.permute_copy.default + assert "cluster" not in nodes[5].meta + assert ( + nodes[6].target + == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default + ) + + edge_program_executor_without_qdq_cluster = EdgeProgramExecutor( + edge_program_without_qdq_cluster + ) + edge_program_executor_with_qdq_cluster = EdgeProgramExecutor( + edge_program_with_qdq_cluster + ) + + input_data = np.random.random(input_shape).astype(np.float32) + edge_program_output_without_qdq_cluster = ( + edge_program_executor_without_qdq_cluster.inference(input_data) + ) + edge_program_output_with_qdq_cluster = ( + edge_program_executor_with_qdq_cluster.inference(input_data) + ) + + compare_output_arrays( + edge_program_output_without_qdq_cluster, + edge_program_output_with_qdq_cluster, + "main output", + ) diff --git a/backends/nxp/tests/test_gru_splitting.py b/backends/nxp/tests/test_gru_splitting.py index a2e9d324f69..297f9677fb2 100644 --- a/backends/nxp/tests/test_gru_splitting.py +++ b/backends/nxp/tests/test_gru_splitting.py @@ -13,6 +13,7 @@ from executorch.backends.nxp.aten_passes.split_gru_based_on_num_layers import ( SplitGRUBasedOnNumLayers, ) +from executorch.backends.nxp.tests.executorch_pipeline import neutron_target_spec @pytest.fixture(autouse=True) @@ -94,7 +95,9 @@ def test_gru_splitting__with_bias(num_layers): ) # Just 1 `GRU` in the model. # Run pre-processing passes of the float32 aten dialect program. - pytorch_pass_manager = NeutronAtenPassManager([SplitGRUBasedOnNumLayers()]) + pytorch_pass_manager = NeutronAtenPassManager( + neutron_target_spec, [SplitGRUBasedOnNumLayers()] + ) pytorch_pass_manager(exir_program_aten) post_pass_output = [t.detach() for t in exir_program_aten(*example_input)] @@ -143,7 +146,9 @@ def test_gru_splitting__no_bias(num_layers): ) # Just 1 `GRU` in the model. # Run pre-processing passes of the float32 aten dialect program. 
- pytorch_pass_manager = NeutronAtenPassManager([SplitGRUBasedOnNumLayers()]) + pytorch_pass_manager = NeutronAtenPassManager( + neutron_target_spec, [SplitGRUBasedOnNumLayers()] + ) pytorch_pass_manager(exir_program_aten) post_pass_output = [t.detach() for t in exir_program_aten(*example_input)] @@ -193,7 +198,9 @@ def test_gru_splitting__bidirectional__no_bias(num_layers): ) # Just 1 `GRU` in the model. # Run pre-processing passes of the float32 aten dialect program. - pytorch_pass_manager = NeutronAtenPassManager([SplitGRUBasedOnNumLayers()]) + pytorch_pass_manager = NeutronAtenPassManager( + neutron_target_spec, [SplitGRUBasedOnNumLayers()] + ) pytorch_pass_manager(exir_program_aten) nodes = list(exir_program_aten.graph.nodes) @@ -239,7 +246,9 @@ def test_gru_splitting__bidirectional__with_bias(num_layers): ) # Just 1 `GRU` in the model. # Run pre-processing passes of the float32 aten dialect program. - pytorch_pass_manager = NeutronAtenPassManager([SplitGRUBasedOnNumLayers()]) + pytorch_pass_manager = NeutronAtenPassManager( + neutron_target_spec, [SplitGRUBasedOnNumLayers()] + ) pytorch_pass_manager(exir_program_aten) nodes = list(exir_program_aten.graph.nodes) diff --git a/backends/nxp/tests/test_integration.py b/backends/nxp/tests/test_integration.py index d31b22c9ce9..fe157b44c48 100644 --- a/backends/nxp/tests/test_integration.py +++ b/backends/nxp/tests/test_integration.py @@ -5,6 +5,7 @@ import executorch.extension.pybindings.portable_lib import executorch.kernels.quantized # noqa F401 +from executorch.backends.nxp.tests.use_qat import * # noqa F401 from executorch.backends.nxp.tests.executorch_pipeline import ( to_quantized_executorch_program, @@ -14,11 +15,11 @@ from executorch.examples.nxp.experimental.cifar_net.cifar_net import CifarNet -def test_conv_fc_softmax__to_executorch_program(): +def test_conv_fc_softmax__to_executorch_program(use_qat): model = ConvFCSoftmaxModule() input_shape = (1, 4, 5, 5) - exec_prog = to_quantized_executorch_program(model, input_shape) + exec_prog = to_quantized_executorch_program(model, input_shape, use_qat) program = exec_prog.exported_program() assert ( @@ -36,10 +37,12 @@ def test_conv_fc_softmax__to_executorch_program(): assert "addmm" not in node.name -def test_cifarnet(): +def test_cifarnet(use_qat): model = CifarNet().get_eager_model().eval() input_shape = (1, 3, 32, 32) - exec_prog = to_quantized_executorch_program(model, input_shape) + exec_prog = to_quantized_executorch_program( + model, input_shape, use_qat=use_qat, use_neutron_for_format_conversion=False + ) delegation_info = get_delegation_info(exec_prog.exported_program().graph_module) assert delegation_info.num_delegated_subgraphs == 1 diff --git a/backends/nxp/tests/test_linear_and_add_fusion.py b/backends/nxp/tests/test_linear_and_add_fusion.py new file mode 100644 index 00000000000..222d748001c --- /dev/null +++ b/backends/nxp/tests/test_linear_and_add_fusion.py @@ -0,0 +1,653 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
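# This module exercises `FuseLinearAndAddPass`: a static tensor added to a
# `linear` output is folded into the linear bias (with or without an existing
# bias, and with a non-default `alpha`), while dynamic addends, addends whose
# shape cannot act as a bias, and `add` calls where the linear output is the
# second operand are expected to be left untouched.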
+ +import unittest +from copy import deepcopy + +import numpy as np +import torch + +from executorch.backends.nxp.aten_passes.fuse_linear_and_add_pass import ( + FuseLinearAndAddPass, +) +from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import ( + NeutronAtenPassManager, +) +from executorch.backends.nxp.aten_passes.remove_nodes_with_known_outputs import ( + RemoveNodesWithKnownOutputs, +) +from executorch.backends.nxp.tests.executorch_pipeline import neutron_target_spec +from executorch.backends.nxp.tests.executors import graph_contains_any_of_ops +from parameterized import parameterized + + +class LinearAddModule(torch.nn.Module): + def __init__( + self, + fc_in_features: int, + fc_out_features: int, + bias: bool, + artificial_bias_shape: list[int], + alpha=1.0, + ): + super().__init__() + self.fc_in_features = fc_in_features + self.fc_out_features = fc_out_features + self.bias = bias + self.artificial_bias_shape = artificial_bias_shape + self.alpha = alpha + self.linear = torch.nn.Linear(fc_in_features, fc_out_features, bias=bias) + self.eval() + + def forward(self, x): + artificial_bias = torch.ones(self.artificial_bias_shape, dtype=torch.float32) + x = self.linear(x) + return torch.add(x, artificial_bias, alpha=self.alpha) + + +class LinearAddModuleReverseNodeOrder(torch.nn.Module): + """The `ones` added by the `add` are only generated after the `linear` node.""" + + def __init__( + self, + fc_in_features: int, + fc_out_features: int, + bias: bool, + artificial_bias_shape: list[int], + ): + super().__init__() + self.fc_in_features = fc_in_features + self.fc_out_features = fc_out_features + self.bias = bias + self.artificial_bias_shape = artificial_bias_shape + self.linear = torch.nn.Linear(fc_in_features, fc_out_features, bias=bias) + self.eval() + + def forward(self, x): + # The `ones` are generated after the `linear` call. + x = self.linear(x) + artificial_bias = torch.ones(self.artificial_bias_shape, dtype=torch.float32) + return torch.add(x, artificial_bias) + + +class LinearAddModuleReverseInputOrder(torch.nn.Module): + """The `add` has the output of the `linear` as its second input (which is the input multiplied by `alpha`).""" + + def __init__( + self, + fc_in_features: int, + fc_out_features: int, + bias: bool, + artificial_bias_shape: list[int], + alpha=1.0, + ): + super().__init__() + self.fc_in_features = fc_in_features + self.fc_out_features = fc_out_features + self.bias = bias + self.artificial_bias_shape = artificial_bias_shape + self.alpha = alpha + self.linear = torch.nn.Linear(fc_in_features, fc_out_features, bias=bias) + self.eval() + + def forward(self, x): + artificial_bias = torch.ones(self.artificial_bias_shape, dtype=torch.float32) + x = self.linear(x) + return torch.add(artificial_bias, x, alpha=self.alpha) # Reversed input order. + + +class TestLinearAndAddFusing(unittest.TestCase): + __test__ = False # Prevent interfering with PyTest tests. 
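    # The fusion relies on the identity
    #   torch.add(linear(x), c, alpha=a) == x @ W.T + (b + a * c),
    # so a statically known `c` that is broadcastable to the bias shape can be
    # folded into the bias. Illustrative check only (W, b, c, x are placeholders):
    #   torch.allclose(
    #       torch.add(torch.nn.functional.linear(x, W, b), c, alpha=2.0),
    #       torch.nn.functional.linear(x, W, b + 2.0 * c),
    #   )
    # With reversed inputs, torch.add(c, linear(x), alpha=a) == c + a * linear(x),
    # i.e. `alpha` scales the linear output instead of the addend, so that case
    # is left unfused (see the reversed-input test below).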
+ + @classmethod + def setUpClass(cls): + torch.manual_seed(23) + np.random.seed(42) + + @parameterized.expand( + [ + ["2D", [4, 6]], + ["4D", [4, 6, 8, 10]], + ] + ) + def test_linear_add_fusing__static__no_bias__valid_shape( + self, _, input_shape: list[int] + ): + example_input = (torch.ones(input_shape),) + + module = LinearAddModule(input_shape[-1], 5, False, [5]) + program = torch.export.export(module, example_input, strict=True) + original_module = program.module() + + modified_module = NeutronAtenPassManager( + neutron_target_spec, + [ + RemoveNodesWithKnownOutputs(), # Make the added tensor static. + FuseLinearAndAddPass(), + ], + )(deepcopy(program.module())).graph_module + + # Make sure the module wasn't broken. + original_nodes = list(original_module.graph.nodes) + modified_nodes = list(modified_module.graph.nodes) + + assert len(original_nodes) == 6 + assert original_nodes[3].target == torch.ops.aten.linear.default + assert original_nodes[4].target == torch.ops.aten.add.Tensor + + # The `add` has been removed. + assert len(modified_nodes) == 5 + assert modified_nodes[3].target == torch.ops.aten.linear.default + assert len(modified_nodes[3].args) == 3 + assert "ones" in modified_nodes[3].args[2].name + assert not graph_contains_any_of_ops( + modified_module.graph, [torch.ops.aten.add.Tensor] + ) + + # Verify that the behavior has not changed. + input_data = torch.randn(input_shape, dtype=torch.float32) + out1 = original_module(input_data).detach().numpy() + out2 = modified_module(input_data).detach().numpy() + assert np.allclose(out1, out2) + + @parameterized.expand( + [ + ["2D", [8, 10]], + ] + ) + def test_linear_add_fusing__static__no_bias__invalid_shape( + self, _, input_shape: list[int] + ): + example_input = (torch.ones(input_shape),) + + module = LinearAddModule( + input_shape[-1], 5, False, [8, 5] # Unsupported `linear` bias shape. + ) + program = torch.export.export(module, example_input, strict=True) + original_module = program.module() + + modified_module = NeutronAtenPassManager( + neutron_target_spec, + [ + RemoveNodesWithKnownOutputs(), # Make the added tensor static. + FuseLinearAndAddPass(), + ], + )(deepcopy(program.module())).graph_module + + # Make sure the module wasn't broken. + original_nodes = list(original_module.graph.nodes) + modified_nodes = list(modified_module.graph.nodes) + + assert len(original_nodes) == 6 + assert original_nodes[3].target == torch.ops.aten.linear.default + assert len(original_nodes[3].args) == 2 + assert original_nodes[4].target == torch.ops.aten.add.Tensor + + # Nothing changed. + assert len(modified_nodes) == 6 + assert modified_nodes[3].target == torch.ops.aten.linear.default + assert modified_nodes[4].target == torch.ops.aten.add.Tensor + + # Verify that the behavior has not changed. + input_data = torch.randn(input_shape, dtype=torch.float32) + out1 = original_module(input_data).detach().numpy() + out2 = modified_module(input_data).detach().numpy() + assert np.allclose(out1, out2) + + @parameterized.expand( + [ + ["2D", [4, 6]], + ["4D", [2, 3, 4, 5]], + ] + ) + def test_linear_add_fusing__static__bias__valid_shape( + self, _, input_shape: list[int] + ): + example_input = (torch.ones(input_shape),) + + module = LinearAddModule(input_shape[-1], 5, True, [5]) + program = torch.export.export(module, example_input, strict=True) + original_module = program.module() + + modified_module = NeutronAtenPassManager( + neutron_target_spec, + [ + RemoveNodesWithKnownOutputs(), # Make the added tensor static. 
+ FuseLinearAndAddPass(), + ], + )(deepcopy(program.module())).graph_module + + # Make sure the module wasn't broken. + original_nodes = list(original_module.graph.nodes) + modified_nodes = list(modified_module.graph.nodes) + + assert len(original_nodes) == 7 + assert original_nodes[3].target == torch.ops.aten.ones.default + assert original_nodes[4].target == torch.ops.aten.linear.default + assert len(original_nodes[4].args) == 3 + assert original_nodes[5].target == torch.ops.aten.add.Tensor + + # make sure the `add` and the `ones` were removed. + assert len(modified_nodes) == 5 + assert not graph_contains_any_of_ops( + modified_module.graph, [torch.ops.aten.ones.default] + ) + assert modified_nodes[3].target == torch.ops.aten.linear.default + assert len(modified_nodes[3].args) == 3 + assert "combined" in modified_nodes[3].args[2].name + assert not graph_contains_any_of_ops( + modified_module.graph, [torch.ops.aten.add.Tensor] + ) + + # Verify that the behavior has not changed. + input_data = torch.randn(input_shape, dtype=torch.float32) + out1 = original_module(input_data).detach().numpy() + out2 = modified_module(input_data).detach().numpy() + assert np.allclose(out1, out2) + + def test_linear_add_fusing__static__no_bias__reverse_order(self): + input_shape = [4, 8] + example_input = (torch.ones(input_shape),) + + # Use a module where the `bias` is generated after the `linear` node, which prevents the change. + module = LinearAddModuleReverseNodeOrder(input_shape[-1], 5, False, [5]) + program = torch.export.export(module, example_input, strict=True) + original_module = program.module() + + modified_module = NeutronAtenPassManager( + neutron_target_spec, + [ + RemoveNodesWithKnownOutputs(), # Make the added tensor static. + FuseLinearAndAddPass(), + ], + )(deepcopy(program.module())).graph_module + + # Make sure the module wasn't broken. + original_nodes = list(original_module.graph.nodes) + modified_nodes = list(modified_module.graph.nodes) + + assert len(original_nodes) == 6 + assert original_nodes[2].target == torch.ops.aten.linear.default + assert len(original_nodes[2].args) == 2 + assert ( + original_nodes[3].target == torch.ops.aten.ones.default + ) # `ones` after `linear`. + assert original_nodes[4].target == torch.ops.aten.add.Tensor + + # The `add` has been removed. + assert len(modified_nodes) == 5 + assert modified_nodes[3].target == torch.ops.aten.linear.default + assert len(modified_nodes[3].args) == 3 + assert not graph_contains_any_of_ops( + modified_module.graph, [torch.ops.aten.add.Tensor] + ) + + # Verify that the behavior has not changed. + input_data = torch.randn(input_shape, dtype=torch.float32) + out1 = original_module(input_data).detach().numpy() + out2 = modified_module(input_data).detach().numpy() + assert np.allclose(out1, out2) + + def test_linear_add_fusing__static__bias__reverse_order(self): + input_shape = [4, 8] + example_input = (torch.ones(input_shape),) + + # Use a module where the `bias` is generated after the `linear` node, which prevents the change. + module = LinearAddModuleReverseNodeOrder(input_shape[-1], 5, True, [5]) + program = torch.export.export(module, example_input, strict=True) + original_module = program.module() + + modified_module = NeutronAtenPassManager( + neutron_target_spec, + [ + RemoveNodesWithKnownOutputs(), # Make the added tensor static. + FuseLinearAndAddPass(), + ], + )(deepcopy(program.module())).graph_module + + # Make sure the module wasn't broken. 
+ original_nodes = list(original_module.graph.nodes) + modified_nodes = list(modified_module.graph.nodes) + + assert len(original_nodes) == 7 + assert original_nodes[3].target == torch.ops.aten.linear.default + assert len(original_nodes[3].args) == 3 + assert ( + original_nodes[4].target == torch.ops.aten.ones.default + ) # `ones` after `linear`. + assert original_nodes[5].target == torch.ops.aten.add.Tensor + + # The `add` and `ones` have been removed. + assert len(modified_nodes) == 5 + assert not graph_contains_any_of_ops( + modified_module.graph, [torch.ops.aten.ones.default] + ) + assert modified_nodes[3].target == torch.ops.aten.linear.default + assert len(modified_nodes[3].args) == 3 + assert not graph_contains_any_of_ops( + modified_module.graph, [torch.ops.aten.add.Tensor] + ) + + # Verify that the behavior has not changed. + input_data = torch.randn(input_shape, dtype=torch.float32) + out1 = original_module(input_data).detach().numpy() + out2 = modified_module(input_data).detach().numpy() + assert np.allclose(out1, out2) + + def test_linear_add_fusing__static__alpha__no_bias(self): + alpha = 2.34 + input_shape = [4, 8] + example_input = (torch.ones(input_shape),) + + module = LinearAddModule(input_shape[-1], 5, False, [5], alpha=alpha) + program = torch.export.export(module, example_input, strict=True) + original_module = program.module() + + modified_module = NeutronAtenPassManager( + neutron_target_spec, + [ + RemoveNodesWithKnownOutputs(), # Make the added tensor static. + FuseLinearAndAddPass(), + ], + )(deepcopy(program.module())).graph_module + + # Make sure the module wasn't broken. + original_nodes = list(original_module.graph.nodes) + modified_nodes = list(modified_module.graph.nodes) + + assert len(original_nodes) == 6 + assert original_nodes[2].target == torch.ops.aten.ones.default + assert original_nodes[3].target == torch.ops.aten.linear.default + assert len(original_nodes[3].args) == 2 + assert original_nodes[4].target == torch.ops.aten.add.Tensor + assert original_nodes[4].kwargs["alpha"] == alpha + + # The `add` has been removed. + assert len(modified_nodes) == 5 + assert modified_nodes[3].target == torch.ops.aten.linear.default + assert len(modified_nodes[3].args) == 3 + assert not graph_contains_any_of_ops( + modified_module.graph, [torch.ops.aten.add.Tensor] + ) + + # Verify that the behavior has not changed. + input_data = torch.randn(input_shape, dtype=torch.float32) + out1 = original_module(input_data).detach().numpy() + out2 = modified_module(input_data).detach().numpy() + assert np.allclose(out1, out2) + + def test_linear_add_fusing__static__alpha__bias(self): + alpha = 2.34 + input_shape = [4, 8] + example_input = (torch.ones(input_shape),) + + module = LinearAddModule(input_shape[-1], 5, True, [5], alpha=alpha) + program = torch.export.export(module, example_input, strict=True) + original_module = program.module() + + modified_module = NeutronAtenPassManager( + neutron_target_spec, + [ + RemoveNodesWithKnownOutputs(), # Make the added tensor static. + FuseLinearAndAddPass(), + ], + )(deepcopy(program.module())).graph_module + + # Make sure the module wasn't broken. 
+ original_nodes = list(original_module.graph.nodes) + modified_nodes = list(modified_module.graph.nodes) + + assert len(original_nodes) == 7 + assert original_nodes[3].target == torch.ops.aten.ones.default + assert original_nodes[4].target == torch.ops.aten.linear.default + assert len(original_nodes[4].args) == 3 + assert original_nodes[5].target == torch.ops.aten.add.Tensor + assert original_nodes[5].kwargs["alpha"] == alpha + + # The `add` has been removed. + assert len(modified_nodes) == 5 + assert modified_nodes[3].target == torch.ops.aten.linear.default + assert len(modified_nodes[3].args) == 3 + assert not graph_contains_any_of_ops( + modified_module.graph, [torch.ops.aten.add.Tensor] + ) + + # Verify that the behavior has not changed. + input_data = torch.randn(input_shape, dtype=torch.float32) + out1 = original_module(input_data).detach().numpy() + out2 = modified_module(input_data).detach().numpy() + assert np.allclose(out1, out2) + + def test_linear_add_fusing__static__alpha__reversed_add_inputs(self): + alpha = 2.34 + input_shape = [4, 8] + example_input = (torch.ones(input_shape),) + + module = LinearAddModuleReverseInputOrder( + input_shape[-1], 5, True, [5], alpha=alpha + ) + program = torch.export.export(module, example_input, strict=True) + original_module = program.module() + + modified_module = NeutronAtenPassManager( + neutron_target_spec, + [ + RemoveNodesWithKnownOutputs(), # Make the added tensor static. + FuseLinearAndAddPass(), + ], + )(deepcopy(program.module())).graph_module + + # Make sure the module wasn't broken. + original_nodes = list(original_module.graph.nodes) + modified_nodes = list(modified_module.graph.nodes) + + assert len(original_nodes) == 7 + assert original_nodes[3].target == torch.ops.aten.ones.default + assert original_nodes[4].target == torch.ops.aten.linear.default + assert len(original_nodes[4].args) == 3 + assert original_nodes[5].target == torch.ops.aten.add.Tensor + assert ( + original_nodes[5].args[1] == original_nodes[4] + ) # `linear` is the second input. + assert original_nodes[5].kwargs["alpha"] == alpha + + # Nothing changed (except the `ones` was replaced by static data). + assert len(modified_nodes) == 7 + assert modified_nodes[4].target == torch.ops.aten.linear.default + assert len(modified_nodes[4].args) == 3 + assert modified_nodes[5].target == torch.ops.aten.add.Tensor + assert ( + modified_nodes[5].args[1] == modified_nodes[4] + ) # `linear` is the second input. + assert modified_nodes[5].kwargs["alpha"] == alpha + + # Verify that the behavior has not changed. + input_data = torch.randn(input_shape, dtype=torch.float32) + out1 = original_module(input_data).detach().numpy() + out2 = modified_module(input_data).detach().numpy() + assert np.allclose(out1, out2) + + @parameterized.expand( + [ + ["2D", [4, 6]], + ] + ) + def test_linear_add_fusing__dynamic__no_bias__valid_shape( + self, _, input_shape: list[int] + ): + example_input = (torch.ones(input_shape),) + + module = LinearAddModule(input_shape[-1], 5, False, [5]) + program = torch.export.export(module, example_input, strict=True) + original_module = program.module() + + modified_module = NeutronAtenPassManager( + neutron_target_spec, [FuseLinearAndAddPass()] + )(deepcopy(program.module())).graph_module + + # Make sure the module wasn't broken. 
+ original_nodes = list(original_module.graph.nodes) + modified_nodes = list(modified_module.graph.nodes) + + assert len(original_nodes) == 6 + assert original_nodes[3].target == torch.ops.aten.linear.default + assert original_nodes[4].target == torch.ops.aten.add.Tensor + + # Nothing changed. + assert len(modified_nodes) == 6 + assert modified_nodes[3].target == torch.ops.aten.linear.default + assert modified_nodes[4].target == torch.ops.aten.add.Tensor + + # Verify that the behavior has not changed. + input_data = torch.randn(input_shape, dtype=torch.float32) + out1 = original_module(input_data).detach().numpy() + out2 = modified_module(input_data).detach().numpy() + assert np.allclose(out1, out2) + + @parameterized.expand( + [ + ["2D", [8, 10]], + ] + ) + def test_linear_add_fusing__dynamic__no_bias__invalid_shape( + self, _, input_shape: list[int] + ): + example_input = (torch.ones(input_shape),) + + module = LinearAddModule( + input_shape[-1], 5, False, [8, 5] # Unsupported `linear` bias shape. + ) + program = torch.export.export(module, example_input, strict=True) + original_module = program.module() + + modified_module = NeutronAtenPassManager( + neutron_target_spec, [FuseLinearAndAddPass()] + )(deepcopy(program.module())).graph_module + + # Make sure the module wasn't broken. + original_nodes = list(original_module.graph.nodes) + modified_nodes = list(modified_module.graph.nodes) + + assert len(original_nodes) == 6 + assert original_nodes[3].target == torch.ops.aten.linear.default + assert original_nodes[4].target == torch.ops.aten.add.Tensor + + # Nothing changed. + assert len(modified_nodes) == 6 + assert modified_nodes[3].target == torch.ops.aten.linear.default + assert modified_nodes[4].target == torch.ops.aten.add.Tensor + + # Verify that the behavior has not changed. + input_data = torch.randn(input_shape, dtype=torch.float32) + out1 = original_module(input_data).detach().numpy() + out2 = modified_module(input_data).detach().numpy() + assert np.allclose(out1, out2) + + @parameterized.expand( + [ + ["2D", [4, 6]], + ] + ) + def test_linear_add_fusing__dynamic__bias__valid_shape( + self, _, input_shape: list[int] + ): + example_input = (torch.ones(input_shape),) + + module = LinearAddModule(input_shape[-1], 5, True, [5]) + program = torch.export.export(module, example_input, strict=True) + original_module = program.module() + + modified_module = NeutronAtenPassManager( + neutron_target_spec, [FuseLinearAndAddPass()] + )(deepcopy(program.module())).graph_module + + # Make sure the module wasn't broken. + original_nodes = list(original_module.graph.nodes) + modified_nodes = list(modified_module.graph.nodes) + + assert len(original_nodes) == 7 + assert original_nodes[3].target == torch.ops.aten.ones.default + assert original_nodes[4].target == torch.ops.aten.linear.default + assert original_nodes[5].target == torch.ops.aten.add.Tensor + + # Nothing has changed, as the second bias is dynamic, so it cannot be added together with the first bias. + assert len(modified_nodes) == 7 + assert modified_nodes[3].target == torch.ops.aten.ones.default + assert modified_nodes[4].target == torch.ops.aten.linear.default + assert modified_nodes[5].target == torch.ops.aten.add.Tensor + + # Verify that the behavior has not changed. 
+ input_data = torch.randn(input_shape, dtype=torch.float32) + out1 = original_module(input_data).detach().numpy() + out2 = modified_module(input_data).detach().numpy() + assert np.allclose(out1, out2) + + def test_linear_add_fusing__dynamic__reverse_order(self): + input_shape = [4, 8] + example_input = (torch.ones(input_shape),) + + # Use a module where the `bias` is generated after the `linear` node, which prevents the change. + module = LinearAddModuleReverseNodeOrder(input_shape[-1], 5, False, [5]) + program = torch.export.export(module, example_input, strict=True) + original_module = program.module() + + modified_module = NeutronAtenPassManager( + neutron_target_spec, [FuseLinearAndAddPass()] + )(deepcopy(program.module())).graph_module + + # Make sure the module wasn't broken. + original_nodes = list(original_module.graph.nodes) + modified_nodes = list(modified_module.graph.nodes) + + assert len(original_nodes) == 6 + assert original_nodes[2].target == torch.ops.aten.linear.default + assert original_nodes[3].target == torch.ops.aten.ones.default + assert original_nodes[4].target == torch.ops.aten.add.Tensor + + # Nothing has changed. + assert len(modified_nodes) == 6 + assert modified_nodes[2].target == torch.ops.aten.linear.default + assert modified_nodes[3].target == torch.ops.aten.ones.default + assert modified_nodes[4].target == torch.ops.aten.add.Tensor + + # Verify that the behavior has not changed. + input_data = torch.randn(input_shape, dtype=torch.float32) + out1 = original_module(input_data).detach().numpy() + out2 = modified_module(input_data).detach().numpy() + assert np.allclose(out1, out2) + + def test_linear_add_fusing__dynamic__alpha(self): + alpha = 2.34 + input_shape = [4, 8] + example_input = (torch.ones(input_shape),) + + module = LinearAddModule(input_shape[-1], 5, False, [5], alpha=alpha) + program = torch.export.export(module, example_input, strict=True) + original_module = program.module() + + modified_module = NeutronAtenPassManager( + neutron_target_spec, [FuseLinearAndAddPass()] + )(deepcopy(program.module())).graph_module + + # Make sure the module wasn't broken. + original_nodes = list(original_module.graph.nodes) + modified_nodes = list(modified_module.graph.nodes) + + assert len(original_nodes) == 6 + assert original_nodes[2].target == torch.ops.aten.ones.default + assert original_nodes[3].target == torch.ops.aten.linear.default + assert original_nodes[4].target == torch.ops.aten.add.Tensor + + # Nothing has changed. + assert len(modified_nodes) == 6 + assert modified_nodes[2].target == torch.ops.aten.ones.default + assert modified_nodes[3].target == torch.ops.aten.linear.default + assert modified_nodes[4].target == torch.ops.aten.add.Tensor + + # Verify that the behavior has not changed. + input_data = torch.randn(input_shape, dtype=torch.float32) + out1 = original_module(input_data).detach().numpy() + out2 = modified_module(input_data).detach().numpy() + assert np.allclose(out1, out2) diff --git a/backends/nxp/tests/test_move_activation_before_concatenation.py b/backends/nxp/tests/test_move_activation_before_concatenation.py new file mode 100644 index 00000000000..27bd675a487 --- /dev/null +++ b/backends/nxp/tests/test_move_activation_before_concatenation.py @@ -0,0 +1,905 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
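For context, the FuseLinearAndAddPass tests above rely on the identity linear(x, W, b) + alpha * c == linear(x, W, b + alpha * c) for a static tensor c that broadcasts to the bias shape; the pass folds c into a new "combined" bias. A minimal sketch of that identity (illustrative only; none of the names below come from the patch):

import torch

x = torch.randn(4, 8)
W = torch.randn(5, 8)
b = torch.randn(5)
c = torch.ones(5)  # The static "second bias", like the aten.ones output in the tests.
alpha = 2.34

unfused = torch.nn.functional.linear(x, W, b) + alpha * c
fused = torch.nn.functional.linear(x, W, b + alpha * c)  # Bias folded into one "combined" tensor.
assert torch.allclose(unfused, fused, atol=1e-6)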
+
+import itertools
+import math
+import unittest
+
+import kgb
+import numpy as np
+import torch
+from executorch.backends.nxp.aten_passes.move_activation_before_concat import (
+    MoveActivationBeforeConcat,
+)
+from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
+    NeutronAtenPassManager,
+)
+from executorch.backends.nxp.backend.edge_program_converter import (
+    EdgeProgramToIRConverter,
+)
+from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
+from executorch.backends.nxp.quantizer.utils import calibrate_and_quantize
+from executorch.backends.nxp.tests.executorch_pipeline import (
+    get_random_calibration_inputs,
+    neutron_target_spec,
+    to_model_input_spec,
+    to_quantized_edge_program,
+)
+from executorch.backends.nxp.tests.executors import (
+    convert_run_compare,
+    graph_contains_any_of_ops,
+    ToChannelFirstPreprocess,
+    ToChannelLastPreprocess,
+)
+from executorch.backends.nxp.tests.models import get_activation
+from executorch.exir.dialects._ops import ops as exir_ops
+from parameterized import parameterized
+from torch import nn
+from torch.export import ExportedProgram
+from torch.fx import GraphModule
+
+concat_cluster_ops = [
+    exir_ops.edge.aten.addmm.default,
+    exir_ops.edge.aten.convolution.default,
+    exir_ops.edge.aten.hardtanh.default,
+    exir_ops.edge.aten.relu.default,
+    exir_ops.edge.aten.sigmoid.default,
+    exir_ops.edge.aten.tanh.default,
+    exir_ops.edge.aten.cat.default,
+]
+
+
+# Permutation of all supported combinations of:
+# <activation>, <inplace>, <is_qat>
+all_activation_cases = list(
+    itertools.product(
+        ["relu", "relu6", "tanh"],
+        [True, False],
+        [True, False],
+    )
+) + [
+    ("sigmoid", False, True),
+    ("sigmoid", False, False),
+]
+
+
+# <activation1>, <activation2>, <act1_inplace>, <act2_inplace>, <use_qat>
+all_concat_cluster_cases = [
+    ("relu", "relu", True, False, True),
+    ("relu", "relu", True, False, False),
+    ("relu6", "relu6", False, True, True),
+    ("relu6", "relu6", False, True, False),
+    ("tanh", "tanh", True, False, True),
+    ("tanh", "tanh", True, False, False),
+    ("sigmoid", "sigmoid", False, True, True),
+    ("sigmoid", "sigmoid", False, True, False),
+    ("relu", "relu_hardtanh", True, True, True),
+    ("relu", "relu_hardtanh", True, True, False),
+]
+
+
+class ConvConcatActivationModule(torch.nn.Module):
+    def __init__(self, activation: str, inplace: bool, in_channels: int):
+        super().__init__()
+        self.conv = nn.Conv2d(
+            in_channels,
+            in_channels,
+            (3, 3),
+            padding=1,
+        )
+
+        self.activation = get_activation(activation, inplace)
+        self.eval()
+
+    def forward(self, x):
+        x1 = self.conv(x)
+        x2 = self.conv(x)
+        x = torch.cat((x1, x2), dim=1)
+        return self.activation(x)
+
+
+class LinearConcatActivationModule(nn.Module):
+    def __init__(
+        self, activation: str, inplace: bool, in_channels: int, mode: str = "linear"
+    ):
+        super().__init__()
+        self.mode = mode.lower()
+        assert self.mode in [
+            "linear",
+            "addmm",
+            "mm",
+        ], "Mode must be 'linear', 'addmm', or 'mm'"
+
+        if self.mode == "linear":
+            self.linear = nn.Linear(in_channels, in_channels)
+        else:
+            # Manual weight and bias for addmm/mm.
+ self.weight = nn.Parameter(torch.empty(in_channels, in_channels)) + self.bias = nn.Parameter(torch.empty(in_channels)) + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self.bias, -bound, bound) + + self.activation = get_activation(activation, inplace) + self.eval() + + def forward(self, x): + x1, x2 = None, None + + if self.mode == "linear": + x1 = self.linear(x) + x2 = self.linear(x) + if self.mode == "addmm": + x1 = torch.addmm(self.bias, x, self.weight) + x2 = torch.addmm(self.bias, x, self.weight) + elif self.mode == "mm": + x1 = torch.mm(x, self.weight) + x2 = torch.mm(x, self.weight) + + x = torch.cat((x1, x2), dim=1) + return self.activation(x) + + +class ConvActivationConcatModule(torch.nn.Module): + def __init__( + self, + activation1: str, + activation2: str, + act1_inplace: bool, + act2_inplace: bool, + in_channels: int, + ): + super().__init__() + self.conv = nn.Conv2d( + in_channels, + in_channels, + (3, 3), + padding=1, + ) + + self.activation1 = get_activation(activation1, act1_inplace) + self.activation2 = get_activation(activation2, act2_inplace) + self.eval() + + def forward(self, x): + x1 = self.conv(x) + x1 = self.activation1(x1) + x2 = self.conv(x) + x2 = self.activation2(x2) + return torch.cat((x1, x2), dim=1) + + +class LinearActivationConcatModule(torch.nn.Module): + def __init__( + self, + activation1: str, + activation2: str, + act1_inplace: bool, + act2_inplace: bool, + in_channels: int, + ): + super().__init__() + self.linear = nn.Linear(in_channels, in_channels) + + self.activation1 = get_activation(activation1, act1_inplace) + self.activation2 = get_activation(activation2, act2_inplace) + self.eval() + + def forward(self, x): + x1 = self.linear(x) + x1 = self.activation1(x1) + x2 = self.linear(x) + x2 = self.activation2(x2) + return torch.cat((x1, x2), dim=1) + + +class TestMoveActivationBeforeConcat(unittest.TestCase): + __test__ = False # Prevent interfering with PyTest tests. + + @classmethod + def setUpClass(cls): + torch.manual_seed(23) + np.random.seed(42) + + @parameterized.expand(all_activation_cases) + def test_move_activation_before_concat__conv(self, activation, inplace, is_qat): + input_shape = (1, 3, 8, 8) + model = ConvConcatActivationModule( + activation=activation, inplace=inplace, in_channels=3 + ) + + calibration_inputs = get_random_calibration_inputs( + to_model_input_spec(input_shape) + ) + example_input = calibration_inputs[0] + + exir_program_aten = torch.export.export( + model, example_input, strict=True + ).module() + + outputs_before = [o.detach().numpy() for o in exir_program_aten(*example_input)] + nodes = list(exir_program_aten.graph.nodes) + assert len(nodes) == 8 + cat_node = nodes[5] + assert cat_node.target == torch.ops.aten.cat.default + assert all( + neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + input_node + ) + and len(input_node.users) == 1 + for input_node in cat_node.all_input_nodes + ) + assert ( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[6] + ) + ) + + # Apply the optimization. + NeutronAtenPassManager( + neutron_target_spec, + [MoveActivationBeforeConcat(neutron_target_spec)], + )(exir_program_aten) + + nodes = list(exir_program_aten.graph.nodes) + + # Make sure the optimization was applied. 
+ assert len(nodes) == 9 + cat_node = nodes[7] + assert cat_node.target == torch.ops.aten.cat.default + assert all( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + input_node + ) + and len(input_node.users) == 1 + for input_node in cat_node.all_input_nodes + ) + assert nodes[8].target == "output" + + outputs_after = [o.detach().numpy() for o in exir_program_aten(*example_input)] + + # Make sure the model still produces the exact same output. + assert np.allclose(outputs_before[0], outputs_after[0]) + + # Run pre-processing passes of the float32 aten dialect program. + neutron_aten_pass_manager = NeutronAtenPassManager(neutron_target_spec) + neutron_aten_pass_manager(exir_program_aten) # All passes by default. + + exir_program_aten_quant = calibrate_and_quantize( + exir_program_aten, + calibration_inputs, + NeutronQuantizer(neutron_target_spec), + is_qat=is_qat, + ) + + # Check convolution and activation are in same QDQ cluster. + nodes = list(exir_program_aten_quant.graph.nodes) + assert len(nodes) == 26 + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[14] + ) + assert ( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[15] + ) + ) + assert ( + nodes[16].target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[18] + ) + assert ( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[19] + ) + ) + assert ( + nodes[20].target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + + @parameterized.expand(all_activation_cases) + def test_move_activation_before_concat__linear(self, activation, inplace, is_qat): + input_shape = (1, 8) + model = LinearConcatActivationModule( + activation=activation, inplace=inplace, in_channels=8, mode="linear" + ) + + calibration_inputs = get_random_calibration_inputs( + to_model_input_spec(input_shape) + ) + example_input = calibration_inputs[0] + + exir_program_aten = torch.export.export( + model, example_input, strict=True + ).module() + + outputs_before = [o.detach().numpy() for o in exir_program_aten(*example_input)] + nodes = list(exir_program_aten.graph.nodes) + assert len(nodes) == 8 + cat_node = nodes[5] + assert cat_node.target == torch.ops.aten.cat.default + assert all( + neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + input_node + ) + and len(input_node.users) == 1 + for input_node in cat_node.all_input_nodes + ) + assert ( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[6] + ) + ) + + # Apply the optimization. + NeutronAtenPassManager( + neutron_target_spec, + [MoveActivationBeforeConcat(neutron_target_spec)], + )(exir_program_aten) + + nodes = list(exir_program_aten.graph.nodes) + + # Make sure the optimization was applied. + assert len(nodes) == 9 + cat_node = nodes[7] + assert cat_node.target == torch.ops.aten.cat.default + assert all( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + input_node + ) + and len(input_node.users) == 1 + for input_node in cat_node.all_input_nodes + ) + assert nodes[8].target == "output" + + outputs_after = [o.detach().numpy() for o in exir_program_aten(*example_input)] + + # Make sure the model still produces the exact same output. + assert np.allclose(outputs_before[0], outputs_after[0]) + + # Run pre-processing passes of the float32 aten dialect program. 
+ neutron_aten_pass_manager = NeutronAtenPassManager(neutron_target_spec) + neutron_aten_pass_manager(exir_program_aten) # All passes by default. + + exir_program_aten_quant = calibrate_and_quantize( + exir_program_aten, + calibration_inputs, + NeutronQuantizer(neutron_target_spec), + is_qat=is_qat, + ) + + # Check linear and activation are in same QDQ cluster. + nodes = list(exir_program_aten_quant.graph.nodes) + assert len(nodes) == 22 + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[10] + ) + assert ( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[11] + ) + ) + assert ( + nodes[12].target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[14] + ) + assert ( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[15] + ) + ) + assert ( + nodes[16].target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + + @parameterized.expand(all_activation_cases) + def test_move_activation_before_concat__addmm(self, activation, inplace, is_qat): + input_shape = (1, 8) + model = LinearConcatActivationModule( + activation=activation, inplace=inplace, in_channels=8, mode="addmm" + ) + + calibration_inputs = get_random_calibration_inputs( + to_model_input_spec(input_shape) + ) + example_input = calibration_inputs[0] + + exir_program_aten = torch.export.export( + model, example_input, strict=True + ).module() + + outputs_before = [o.detach().numpy() for o in exir_program_aten(*example_input)] + nodes = list(exir_program_aten.graph.nodes) + assert len(nodes) == 8 + cat_node = nodes[5] + assert cat_node.target == torch.ops.aten.cat.default + assert all( + neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + input_node + ) + and len(input_node.users) == 1 + for input_node in cat_node.all_input_nodes + ) + assert ( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[6] + ) + ) + + # Apply the optimization. + NeutronAtenPassManager( + neutron_target_spec, + [MoveActivationBeforeConcat(neutron_target_spec)], + )(exir_program_aten) + + nodes = list(exir_program_aten.graph.nodes) + + # Make sure the optimization was applied. + assert len(nodes) == 9 + cat_node = nodes[7] + assert cat_node.target == torch.ops.aten.cat.default + assert all( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + input_node + ) + and len(input_node.users) == 1 + for input_node in cat_node.all_input_nodes + ) + assert nodes[8].target == "output" + + outputs_after = [o.detach().numpy() for o in exir_program_aten(*example_input)] + + # Make sure the model still produces the exact same output. + assert np.allclose(outputs_before[0], outputs_after[0]) + + # Run pre-processing passes of the float32 aten dialect program. + neutron_aten_pass_manager = NeutronAtenPassManager(neutron_target_spec) + neutron_aten_pass_manager(exir_program_aten) # All passes by default. + + exir_program_aten_quant = calibrate_and_quantize( + exir_program_aten, + calibration_inputs, + NeutronQuantizer(neutron_target_spec), + is_qat=is_qat, + ) + + # Check addmm and activation are in same QDQ cluster. 
+ nodes = list(exir_program_aten_quant.graph.nodes) + assert len(nodes) == 22 + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[10] + ) + assert ( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[11] + ) + ) + assert ( + nodes[12].target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[14] + ) + assert ( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[15] + ) + ) + assert ( + nodes[16].target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + + @parameterized.expand(all_activation_cases) + def test_move_activation_before_concat__mm(self, activation, inplace, is_qat): + input_shape = (1, 8) + model = LinearConcatActivationModule( + activation=activation, inplace=inplace, in_channels=8, mode="mm" + ) + + calibration_inputs = get_random_calibration_inputs( + to_model_input_spec(input_shape) + ) + example_input = calibration_inputs[0] + + exir_program_aten = torch.export.export( + model, example_input, strict=True + ).module() + + outputs_before = [o.detach().numpy() for o in exir_program_aten(*example_input)] + nodes = list(exir_program_aten.graph.nodes) + assert len(nodes) == 7 + cat_node = nodes[4] + assert cat_node.target == torch.ops.aten.cat.default + assert all( + neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + input_node + ) + and len(input_node.users) == 1 + for input_node in cat_node.all_input_nodes + ) + assert ( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[5] + ) + ) + + # Apply the optimization. + NeutronAtenPassManager( + neutron_target_spec, + [MoveActivationBeforeConcat(neutron_target_spec)], + )(exir_program_aten) + + nodes = list(exir_program_aten.graph.nodes) + + # Make sure the optimization was applied. + assert len(nodes) == 8 + cat_node = nodes[6] + assert cat_node.target == torch.ops.aten.cat.default + assert all( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + input_node + ) + and len(input_node.users) == 1 + for input_node in cat_node.all_input_nodes + ) + assert nodes[7].target == "output" + + outputs_after = [o.detach().numpy() for o in exir_program_aten(*example_input)] + + # Make sure the model still produces the exact same output. + assert np.allclose(outputs_before[0], outputs_after[0]) + + # Run pre-processing passes of the float32 aten dialect program. + neutron_aten_pass_manager = NeutronAtenPassManager(neutron_target_spec) + neutron_aten_pass_manager(exir_program_aten) # All passes by default. + + exir_program_aten_quant = calibrate_and_quantize( + exir_program_aten, + calibration_inputs, + NeutronQuantizer(neutron_target_spec), + is_qat=is_qat, + ) + + # Check mm and activation are in same QDQ cluster. 
+ nodes = list(exir_program_aten_quant.graph.nodes) + assert len(nodes) == 19 + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[7] + ) + assert ( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[8] + ) + ) + assert ( + nodes[9].target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[11] + ) + assert ( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[12] + ) + ) + assert ( + nodes[13].target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + + @parameterized.expand(all_activation_cases) + def test_move_activation_before_concat_quantization__conv( + self, activation, inplace, use_qat + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, + call_original=True, + owner=EdgeProgramToIRConverter, + ) as converter_spy: + input_shape = (1, 8, 8, 8) + model = ConvConcatActivationModule( + activation=activation, inplace=inplace, in_channels=8 + ) + + edge_program = to_quantized_edge_program( + model, + input_shape, + use_qat=use_qat, + use_neutron_for_format_conversion=False, + ).exported_program() + + # Make sure that all nodes were delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=concat_cluster_ops + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + tflite_input_preprocess=ToChannelLastPreprocess(), + tflite_output_preprocess=ToChannelFirstPreprocess(), + ) + + @parameterized.expand(all_activation_cases) + def test_move_activation_before_concat_quantization__linear( + self, activation, inplace, use_qat + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, + call_original=True, + owner=EdgeProgramToIRConverter, + ) as converter_spy: + input_shape = (1, 8) + model = LinearConcatActivationModule( + activation=activation, inplace=inplace, in_channels=8, mode="linear" + ) + + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + # Make sure that all nodes were delegated. 
+ assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=concat_cluster_ops + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + ) + + @parameterized.expand(all_activation_cases) + def test_move_activation_before_concat_quantization__addmm( + self, activation, inplace, use_qat + ): + torch.manual_seed(23) + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, + call_original=True, + owner=EdgeProgramToIRConverter, + ) as converter_spy: + input_shape = (1, 8) + model = LinearConcatActivationModule( + activation=activation, inplace=inplace, in_channels=8, mode="addmm" + ) + + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + # Make sure that all nodes were delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=concat_cluster_ops + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + atol=1.0, + ) + + @parameterized.expand(all_activation_cases) + def test_move_activation_before_concat_quantization__mm( + self, activation, inplace, use_qat + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, + call_original=True, + owner=EdgeProgramToIRConverter, + ) as converter_spy: + input_shape = (1, 8) + model = LinearConcatActivationModule( + activation=activation, inplace=inplace, in_channels=8, mode="mm" + ) + + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + # Make sure that all nodes were delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=concat_cluster_ops + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + ) + + @parameterized.expand(all_concat_cluster_cases) + def test_concat_cluster_quantization__conv( + self, activation1, activation2, act1_inplace, act2_inplace, use_qat + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, + call_original=True, + owner=EdgeProgramToIRConverter, + ) as converter_spy: + with kgb.spy_on( + calibrate_and_quantize, call_original=True + ) as quantizer_spy: + input_shape = (1, 8, 8, 8) + model = ConvActivationConcatModule( + activation1, activation2, act1_inplace, act2_inplace, in_channels=8 + ) + + edge_program = to_quantized_edge_program( + model, + input_shape, + use_qat=use_qat, + use_neutron_for_format_conversion=False, + ).exported_program() + + # Make sure that all nodes were delegated. 
+ assert not graph_contains_any_of_ops( + graph=edge_program.graph, + ops=concat_cluster_ops, + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + tflite_flatbuffers_model, io_formats = converter_spy.calls[ + -1 + ].return_value + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + exir_program_aten_quant: GraphModule = quantizer_spy.calls[ + -1 + ].return_value + + # Check convolution and activation are in same QDQ cluster. + nodes = list(exir_program_aten_quant.graph.nodes) + assert len(nodes) == 26 + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[14] + ) + assert neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[15] + ) + assert ( + nodes[16].target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[18] + ) + assert neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[19] + ) + assert ( + nodes[20].target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + + input_data = ( + np.random.random(input_shape).astype(np.float32) * 50 + ).astype(np.int8) + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + tflite_input_preprocess=ToChannelLastPreprocess(), + tflite_output_preprocess=ToChannelFirstPreprocess(), + ) + + @parameterized.expand(all_concat_cluster_cases) + def test_concat_cluster_quantization__linear( + self, activation1, activation2, act1_inplace, act2_inplace, use_qat + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, + call_original=True, + owner=EdgeProgramToIRConverter, + ) as converter_spy: + with kgb.spy_on( + calibrate_and_quantize, call_original=True + ) as quantizer_spy: + input_shape = (1, 8) + model = LinearActivationConcatModule( + activation1, activation2, act1_inplace, act2_inplace, in_channels=8 + ) + + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + # Make sure that all nodes were delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, + ops=concat_cluster_ops, + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + tflite_flatbuffers_model, io_formats = converter_spy.calls[ + -1 + ].return_value + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + exir_program_aten_quant: GraphModule = quantizer_spy.calls[ + -1 + ].return_value + + # Check linear and activation are in same QDQ cluster. 
+ nodes = list(exir_program_aten_quant.graph.nodes) + assert len(nodes) == 22 + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[10] + ) + assert neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[11] + ) + assert ( + nodes[12].target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[14] + ) + assert neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[15] + ) + assert ( + nodes[16].target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + + input_data = ( + np.random.random(input_shape).astype(np.float32) * 50 + ).astype(np.int8) + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + tflite_input_preprocess=ToChannelLastPreprocess(), + tflite_output_preprocess=ToChannelFirstPreprocess(), + ) diff --git a/backends/nxp/tests/test_neutron_backend.py b/backends/nxp/tests/test_neutron_backend.py index 53e54ec2f56..867b585ef64 100644 --- a/backends/nxp/tests/test_neutron_backend.py +++ b/backends/nxp/tests/test_neutron_backend.py @@ -1,4 +1,4 @@ -# Copyright 2024 NXP +# Copyright 2024-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -21,7 +21,9 @@ def test_neutron_backend__single_conv_model(): def test_neutron_backend__single_conv_model__payload_header_channels_last(): edge_program_manager = to_quantized_edge_program( - Conv2dModule(bias=False), (1, 4, 32, 32) + Conv2dModule(bias=False), + (1, 4, 32, 32), + use_neutron_for_format_conversion=False, ) payload = ( edge_program_manager.exported_program().graph_module.lowered_module_0.processed_bytes @@ -34,7 +36,10 @@ def test_neutron_backend__single_conv_model__payload_header_channels_last(): assert payload[4] == 0x1 # Channels last 0-th Neutron output assert payload[5] == 0x0 # Map 0-th Neutron input to 0-th model input assert payload[6] == 0x0 # Map 0-th Neutron output to 0-th model output - assert all(byte == 0x0 for byte in payload[7:16]) # Aligned to 16 bytes + assert ( + payload[7] == 0x0 or payload[7] == 0x1 + ) # Payload version is 0 or 1 depending on the Neutron Software + assert all(byte == 0x0 for byte in payload[8:16]) # Aligned to 16 bytes assert payload[17] != 0x0 # Followed by non-zero content @@ -51,5 +56,8 @@ def test_neutron_backend__linear_softmax_model__payload_header_formatless(): assert payload[4] == 0x0 # Formatless 0-th Neutron output assert payload[5] == 0x0 # Map 0-th Neutron input to 0-th model input assert payload[6] == 0x0 # Map 0-th Neutron output to 0-th model output - assert all(byte == 0x0 for byte in payload[7:16]) # Aligned to 16 bytes + assert ( + payload[7] == 0x0 or payload[7] == 0x1 + ) # Payload version is 0 or 1 depending on the Neutron Software + assert all(byte == 0x0 for byte in payload[8:16]) # Aligned to 16 bytes assert payload[17] != 0x0 # Followed by non-zero content diff --git a/backends/nxp/tests/test_neutron_backend_executor.py b/backends/nxp/tests/test_neutron_backend_executor.py index 3503403311f..ac87a569bb1 100644 --- a/backends/nxp/tests/test_neutron_backend_executor.py +++ b/backends/nxp/tests/test_neutron_backend_executor.py @@ -11,10 +11,13 @@ ) from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOptions import BuiltinOptions from executorch.backends.nxp.backend.ir.lib.tflite.Model import Model +from 
executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
+from executorch.backends.nxp.nxp_backend import PayloadComposer
 from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
 from executorch.backends.nxp.tests.executors import (
     convert_run_compare,
     EdgeProgramExecutor,
+    graph_contains_any_of_ops,
     TFLiteExecutor,
     ToNHWCPreprocess,
 )
@@ -108,3 +111,217 @@ def test_conv_fc__lowered_program_and_tflite_output_match(mocker):
         input_data=input_data,
         tflite_input_preprocess=ToNHWCPreprocess(),
     )
+
+
+def test_delegating_format_related_transpose_operators__unsupported_shapes(mocker):
+    # This test focuses on the case when Neutron would not support the inserted Transpose operators, so they are not
+    # inserted and the runtime will permute the data instead.
+
+    # Make sure none of the dimensions are multiples of `num_macs` (8), for proper testing.
+    model = Conv2dModule(in_channels=3, out_channels=3, padding=1, stride=1)
+    input_shape = (1, 3, 3, 3)
+
+    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
+    payload_header_spy = mocker.spy(PayloadComposer, "_create_payload_header")
+    edge_program = to_quantized_edge_program(
+        model,
+        input_shape,
+        use_neutron_for_format_conversion=True,  # Make sure the IR converter inserts the extra `Transpose` operators.
+    ).exported_program()
+
+    # Make sure the edge_program only contains the 1 delegate call.
+    nodes = list(edge_program.graph.nodes)
+    assert len(nodes) == 7
+    assert "call_delegate" in nodes[3].name
+    assert not graph_contains_any_of_ops(
+        edge_program.graph, [torch.ops.aten.convolution.default]
+    )
+    assert not graph_contains_any_of_ops(
+        edge_program.graph, [torch.ops.aten.permute_copy.default]
+    )
+
+    # Capture the converted IR model.
+    tflite_flatbuffers_model, _ = converter_spy.spy_return
+
+    # Make sure the `Transpose` ops are NOT in the IR model.
+    tflite_subgraph = Model.GetRootAs(tflite_flatbuffers_model).Subgraphs(0)
+    assert tflite_subgraph.OperatorsLength() == 2
+    assert (
+        tflite_subgraph.Operators(0).BuiltinOptionsType() == BuiltinOptions.PadV2Options
+    )
+    assert (
+        tflite_subgraph.Operators(1).BuiltinOptionsType()
+        == BuiltinOptions.Conv2DOptions
+    )
+
+    # Get the header of the payload for the delegated partition.
+    payload_header = payload_header_spy.spy_return
+    assert payload_header.size == 8
+    # The 4th and 5th bytes indicate the format. `1` means `channels_last`, which means the runtime will transpose the data.
+    assert all(payload_header[3:5] == [1, 1])  # [<input format>, <output format>]
+
+
+def test_delegating_format_related_transpose_operators__supported_case(mocker):
+    # Make sure the output channels (channels for the trailing Transpose) and the last input dimension (channels for
+    # the leading Transpose) are multiples of `num_macs`.
+
+    num_macs = NeutronTargetSpec("imxrt700", "SDK_25_09").get_num_macs()
+    model = Conv2dModule(
+        in_channels=num_macs, out_channels=num_macs, padding=1, stride=1
+    )
+    input_shape = (1, num_macs, num_macs, num_macs)
+
+    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
+    payload_header_spy = mocker.spy(PayloadComposer, "_create_payload_header")
+    edge_program = to_quantized_edge_program(
+        model,
+        input_shape,
+        use_neutron_for_format_conversion=True,  # Make sure the IR converter inserts the extra `Transpose` operators.
+    ).exported_program()
+
+    # Make sure the edge_program only contains the 1 delegate call.
+    nodes = list(edge_program.graph.nodes)
+    assert len(nodes) == 7
+    assert "call_delegate" in nodes[3].name
+    assert not graph_contains_any_of_ops(
+        edge_program.graph, [torch.ops.aten.convolution.default]
+    )
+    assert not graph_contains_any_of_ops(
+        edge_program.graph, [torch.ops.aten.permute_copy.default]
+    )
+
+    # Capture the converted IR model.
+    tflite_flatbuffers_model, _ = converter_spy.spy_return
+
+    # Make sure the `Transpose` ops ARE in the IR model.
+    tflite_subgraph = Model.GetRootAs(tflite_flatbuffers_model).Subgraphs(0)
+    assert tflite_subgraph.OperatorsLength() == 4
+    assert (
+        tflite_subgraph.Operators(0).BuiltinOptionsType()
+        == BuiltinOptions.TransposeOptions
+    )
+    assert (
+        tflite_subgraph.Operators(1).BuiltinOptionsType() == BuiltinOptions.PadV2Options
+    )
+    assert (
+        tflite_subgraph.Operators(2).BuiltinOptionsType()
+        == BuiltinOptions.Conv2DOptions
+    )
+    assert (
+        tflite_subgraph.Operators(3).BuiltinOptionsType()
+        == BuiltinOptions.TransposeOptions
+    )
+
+    # Get the header of the payload for the delegated partition.
+    payload_header = payload_header_spy.spy_return
+    assert payload_header.size == 8
+    # The 4th and 5th bytes indicate the format. `0` means `formatless`, which means the runtime will NOT transpose the data.
+    assert all(payload_header[3:5] == [0, 0])  # [<input format>, <output format>]
+
+
+def test_delegating_format_related_transpose_operators__supported_output__unsupported_input(
+    mocker,
+):
+    num_macs = NeutronTargetSpec("imxrt700", "SDK_25_09").get_num_macs()
+    model = Conv2dModule(
+        in_channels=num_macs,
+        out_channels=num_macs,  # The output `Transpose` will be supported.
+        padding=1,
+        stride=1,
+    )
+    input_shape = (1, num_macs, num_macs, 3)  # The input `Transpose` is not supported.
+
+    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
+    payload_header_spy = mocker.spy(PayloadComposer, "_create_payload_header")
+    edge_program = to_quantized_edge_program(
+        model,
+        input_shape,
+        use_neutron_for_format_conversion=True,  # Make sure the IR converter inserts the extra `Transpose` operators.
+    ).exported_program()
+
+    # Make sure the edge_program only contains the 1 delegate call.
+    nodes = list(edge_program.graph.nodes)
+    assert len(nodes) == 7
+    assert "call_delegate" in nodes[3].name
+    assert not graph_contains_any_of_ops(
+        edge_program.graph, [torch.ops.aten.convolution.default]
+    )
+    assert not graph_contains_any_of_ops(
+        edge_program.graph, [torch.ops.aten.permute_copy.default]
+    )
+
+    # Capture the converted IR model.
+    tflite_flatbuffers_model, _ = converter_spy.spy_return
+
+    # Make sure there is just the 1 `Transpose` in the model.
+    tflite_subgraph = Model.GetRootAs(tflite_flatbuffers_model).Subgraphs(0)
+    assert tflite_subgraph.OperatorsLength() == 3
+    assert (
+        tflite_subgraph.Operators(0).BuiltinOptionsType() == BuiltinOptions.PadV2Options
+    )
+    assert (
+        tflite_subgraph.Operators(1).BuiltinOptionsType()
+        == BuiltinOptions.Conv2DOptions
+    )
+    assert (
+        tflite_subgraph.Operators(2).BuiltinOptionsType()
+        == BuiltinOptions.TransposeOptions
+    )
+
+    # Get the header of the payload for the delegated partition.
+    payload_header = payload_header_spy.spy_return
+    assert payload_header.size == 8
+    # The 4th and 5th bytes indicate the format. Here only the input is `channels_last` (`1`), so the runtime will transpose just the input data.
+    assert all(payload_header[3:5] == [1, 0])  # [<input format>, <output format>]
+
+
+def test_delegating_format_related_transpose_operators__supported_input__unsupported_output(
+    mocker,
+):
+    num_macs = NeutronTargetSpec("imxrt700", "SDK_25_09").get_num_macs()
+    model = Conv2dModule(
+        in_channels=num_macs,
+        out_channels=3,  # The output `Transpose` will NOT be supported.
+        stride=1,
+    )
+    input_shape = (1, num_macs, 3, num_macs)  # The input `Transpose` is supported.
+
+    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
+    payload_header_spy = mocker.spy(PayloadComposer, "_create_payload_header")
+    edge_program = to_quantized_edge_program(
+        model,
+        input_shape,
+        use_neutron_for_format_conversion=True,  # Make sure the IR converter inserts the extra `Transpose` operators.
+    ).exported_program()
+
+    # Make sure the edge_program only contains the 1 delegate call.
+    nodes = list(edge_program.graph.nodes)
+    assert len(nodes) == 7
+    assert "call_delegate" in nodes[3].name
+    assert not graph_contains_any_of_ops(
+        edge_program.graph, [torch.ops.aten.convolution.default]
+    )
+    assert not graph_contains_any_of_ops(
+        edge_program.graph, [torch.ops.aten.permute_copy.default]
+    )
+
+    # Capture the converted IR model.
+    tflite_flatbuffers_model, _ = converter_spy.spy_return
+
+    # Make sure there is just the 1 `Transpose` in the model.
+    tflite_subgraph = Model.GetRootAs(tflite_flatbuffers_model).Subgraphs(0)
+    assert tflite_subgraph.OperatorsLength() == 2
+    assert (
+        tflite_subgraph.Operators(0).BuiltinOptionsType()
+        == BuiltinOptions.TransposeOptions
+    )
+    assert (
+        tflite_subgraph.Operators(1).BuiltinOptionsType()
+        == BuiltinOptions.Conv2DOptions
+    )
+
+    # Get the header of the payload for the delegated partition.
+    payload_header = payload_header_spy.spy_return
+    assert payload_header.size == 8
+    # The 4th and 5th bytes indicate the format. Here only the output is `channels_last` (`1`), so the runtime will transpose just the output data.
+    assert all(payload_header[3:5] == [0, 1])  # [<input format>, <output format>]
diff --git a/backends/nxp/tests/test_neutron_converter_manager.py b/backends/nxp/tests/test_neutron_converter_manager.py
index af723ec9c7a..410c58620d7 100644
--- a/backends/nxp/tests/test_neutron_converter_manager.py
+++ b/backends/nxp/tests/test_neutron_converter_manager.py
@@ -1,4 +1,4 @@
-# Copyright 2024 NXP
+# Copyright 2024-2025 NXP
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
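The payload-header checks above pin down a small part of the layout: the header produced by PayloadComposer._create_payload_header is 8 bytes, with byte 3 flagging the 0-th input format and byte 4 the 0-th output format (1 = channels_last, meaning the runtime permutes that tensor; 0 = formatless, per the assertions in test_neutron_backend.py earlier in this patch). A small helper sketch for reading those two flags (the helper itself is hypothetical; only the byte positions and flag values come from the tests):

import numpy as np

def describe_io_formats(payload_header: np.ndarray) -> tuple[str, str]:
    # Bytes 3 and 4 are the input/output format flags asserted by the tests above.
    fmt = {
        0: "formatless (runtime does not transpose)",
        1: "channels_last (runtime transposes)",
    }
    return fmt[int(payload_header[3])], fmt[int(payload_header[4])]

# Usage with the spy from the tests above:
#   in_fmt, out_fmt = describe_io_formats(payload_header_spy.spy_return)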
@@ -13,6 +13,7 @@ from executorch.backends.nxp.backend.neutron_converter_manager import ( NeutronConverterManager, ) +from executorch.backends.nxp.backend.node_format_inference import NodeFormatInference from executorch.backends.nxp.tests.models import Conv2dModule @@ -23,15 +24,14 @@ def test_conv2d_neutron_conversion__default_flavor(): exir_program = torch.export.export(model, example_input) edge_program_manager = exir.to_edge(exir_program) + NodeFormatInference(edge_program_manager.exported_program()).identify_node_formats() edge_program_converter = EdgeProgramToIRConverter() tflite_model, _ = edge_program_converter.convert_program( edge_program_manager.exported_program() ) neutron_converter_manager = NeutronConverterManager() - neutron_model = neutron_converter_manager.convert( - tflite_model, "imxrt700", "SDK_25_06" - ) + neutron_model = neutron_converter_manager.convert(tflite_model, "imxrt700") assert len( neutron_model @@ -45,15 +45,15 @@ def test__conv2d_neutron_conversion__invalid_flavor(): exir_program = torch.export.export(model, example_input) edge_program_manager = exir.to_edge(exir_program) + NodeFormatInference(edge_program_manager.exported_program()).identify_node_formats() edge_program_converter = EdgeProgramToIRConverter() tflite_model, _ = edge_program_converter.convert_program( edge_program_manager.exported_program() ) - neutron_converter_manager = NeutronConverterManager() with pytest.raises(RuntimeError) as excinfo: - _ = neutron_converter_manager.convert(tflite_model, "imxrt700", "bad_flavor") + _ = NeutronConverterManager("bad_flavor").convert(tflite_model, "imxrt700") - assert "Neutron Converter module with flavor 'bad_flavor' not found." in str( + assert "Neutron Converter module 'neutron_converter_bad_flavor' not found." in str( excinfo ) diff --git a/backends/nxp/tests/test_node_format_inference.py b/backends/nxp/tests/test_node_format_inference.py index e2796187ce8..412c422dc6d 100644 --- a/backends/nxp/tests/test_node_format_inference.py +++ b/backends/nxp/tests/test_node_format_inference.py @@ -1,4 +1,4 @@ -# Copyright 2024 NXP +# Copyright 2024-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
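The invalid-flavor test above expects the error "Neutron Converter module 'neutron_converter_bad_flavor' not found.", which suggests that NeutronConverterManager(flavor) resolves a Python module named neutron_converter_<flavor>. A rough sketch of such a lookup, assumed rather than taken from the actual implementation:

import importlib

def _load_neutron_converter(flavor: str):
    module_name = f"neutron_converter_{flavor}"  # Naming pattern taken from the error message above.
    try:
        return importlib.import_module(module_name)
    except ModuleNotFoundError as err:
        raise RuntimeError(f"Neutron Converter module '{module_name}' not found.") from err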
@@ -9,6 +9,7 @@ from executorch.backends.nxp.backend.node_format_inference import ( NodeFormat, NodeFormatInference, + NXP_NODE_FORMAT, ) from executorch.backends.nxp.neutron_pass_manager import NeutronPassManager from executorch.backends.nxp.tests.models import ( @@ -27,7 +28,7 @@ def test_convolution(): exir_program = torch.export.export(model, example_input) edge_program = exir.to_edge(exir_program).exported_program() - node_formats = NodeFormatInference(edge_program).identify_node_formats() + NodeFormatInference(edge_program).identify_node_formats() expected_mapping = { "p_conv_weight": NodeFormat.CHANNELS_FIRST, @@ -37,8 +38,8 @@ def test_convolution(): "output": NodeFormat.CHANNELS_FIRST, } - for node, node_format in node_formats.items(): - assert expected_mapping[node.name] == node_format + for node in edge_program.graph.nodes: + assert expected_mapping[node.name] == node.meta[NXP_NODE_FORMAT] def test_softmax(): @@ -48,7 +49,7 @@ def test_softmax(): exir_program = torch.export.export(model, example_input) edge_program = exir.to_edge(exir_program).exported_program() - node_formats = NodeFormatInference(edge_program).identify_node_formats() + NodeFormatInference(edge_program).identify_node_formats() expected_mapping = { "x": NodeFormat.FORMATLESS, @@ -56,8 +57,8 @@ def test_softmax(): "output": NodeFormat.FORMATLESS, } - for node, node_format in node_formats.items(): - assert expected_mapping[node.name] == node_format + for node in edge_program.graph.nodes: + assert expected_mapping[node.name] == node.meta[NXP_NODE_FORMAT] def test_maxpool2d(): @@ -78,7 +79,7 @@ def test_maxpool2d(): # Remove MaxPool-related "getitem" nodes from graph edge_program = NeutronPassManager(edge_program, [RemoveGetItemPass]).transform() - node_formats = NodeFormatInference(edge_program).identify_node_formats() + NodeFormatInference(edge_program).identify_node_formats() expected_mapping = { "x": NodeFormat.CHANNELS_FIRST, @@ -86,5 +87,5 @@ def test_maxpool2d(): "output": NodeFormat.CHANNELS_FIRST, } - for node, node_format in node_formats.items(): - assert expected_mapping[node.name] == node_format + for node in edge_program.graph.nodes: + assert expected_mapping[node.name] == node.meta[NXP_NODE_FORMAT] diff --git a/backends/nxp/tests/test_per_channel_conversion.py b/backends/nxp/tests/test_per_channel_conversion.py new file mode 100644 index 00000000000..b3034ff17ed --- /dev/null +++ b/backends/nxp/tests/test_per_channel_conversion.py @@ -0,0 +1,182 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
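The test_node_format_inference.py changes above switch the API from a returned mapping to per-node metadata: identify_node_formats() now annotates the graph in place, and callers read node.meta[NXP_NODE_FORMAT]. A compact usage sketch built only from the calls shown in that diff (the helper name is illustrative):

from executorch.backends.nxp.backend.node_format_inference import (
    NodeFormatInference,
    NXP_NODE_FORMAT,
)

def collect_node_formats(edge_program):
    # identify_node_formats() annotates nodes in place; the result lives in node.meta.
    NodeFormatInference(edge_program).identify_node_formats()
    return {node.name: node.meta[NXP_NODE_FORMAT] for node in edge_program.graph.nodes}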
+ +import unittest + +import kgb +import numpy as np +import torch + +from executorch.backends.nxp.backend.edge_program_converter import ( + EdgeProgramToIRConverter, +) +from executorch.backends.nxp.quantizer.neutron_quantizer import ( + act_qspec, + NeutronAtenQuantizer, + wgt_qspec, +) +from executorch.backends.nxp.quantizer.patterns import ( + NodeArgsIdx, + PartitionAnchors, + QuantizationPattern, +) +from executorch.backends.nxp.quantizer.utils import get_bias_qparams +from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program +from executorch.backends.nxp.tests.executors import ( + convert_run_compare, + ToChannelFirstPreprocess, + ToChannelLastPreprocess, +) +from executorch.backends.nxp.tests.models import Conv2dModule +from executorch.exir.dialects._ops import ops as exir_ops +from parameterized import parameterized + +from torch import fx +from torch._ops import OpOverload +from torch.export import ExportedProgram +from torchao.quantization.pt2e import ( + FusedMovingAvgObsFakeQuantize, + MinMaxObserver, + MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, + PerChannelMinMaxObserver, +) +from torchao.quantization.pt2e.quantizer import ( + DerivedQuantizationSpec, + QuantizationConfig, + QuantizationSpec, +) + + +class Conv2dPatternPerChannel(QuantizationPattern): + + def __init__(self, is_per_channel: bool, is_qat: bool): + super().__init__(is_qat=is_qat) + self.is_per_channel = is_per_channel + + def partition_types(self) -> list[OpOverload]: + return [torch.ops.aten.conv2d.default] + + def get_anchors( + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] + ) -> PartitionAnchors: + conv2d_node = fused_partition[0].nodes[-1] + + bias_qscheme = ( + torch.per_channel_symmetric + if self.is_per_channel + else torch.per_tensor_symmetric + ) + bias_quantization_qspec = DerivedQuantizationSpec( + derived_from=[ + (conv2d_node.args[0], conv2d_node), + (conv2d_node.args[1], conv2d_node), + ], + derive_qparams_fn=get_bias_qparams, + dtype=torch.int32, + quant_min=-(2**31) + 1, + quant_max=2**31 - 1, + qscheme=bias_qscheme, + ch_axis=0, + ) + + weight_qscheme = ( + torch.per_channel_symmetric + if self.is_per_channel + else torch.per_tensor_symmetric + ) + if self.is_qat: + observer = ( + MovingAveragePerChannelMinMaxObserver + if self.is_per_channel + else MovingAverageMinMaxObserver + ) + weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( + observer=observer + ) + else: + weight_observer_or_fake_quant_ctr = ( + PerChannelMinMaxObserver if self.is_per_channel else MinMaxObserver + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr, + quant_min=-127, + quant_max=127, + qscheme=weight_qscheme, + ch_axis=0, + ) + + return PartitionAnchors( + inputs=[(conv2d_node, NodeArgsIdx(0))], + weights=[(conv2d_node, NodeArgsIdx(1), weight_quantization_spec)], + biases=[(conv2d_node, NodeArgsIdx(2), bias_quantization_qspec)], + output=[(conv2d_node,)], + ) + + +class TestPerChannelConversion(unittest.TestCase): + __test__ = False # Prevent interfering with PyTest tests + + @classmethod + def setUpClass(cls): + torch.manual_seed(25) + np.random.seed(25) + + @parameterized.expand([("QAT", True), ("PTQ", False)]) + def test_per_channel_convolution(self, _, use_qat: bool): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, + call_original=True, + owner=EdgeProgramToIRConverter, + ) as converter_spy: + model = Conv2dModule( + 
in_channels=8, out_channels=32, kernel_size=5, padding=3 + ) + input_shape = (1, 8, 32, 32) + + activation_qspec = act_qspec(is_qat=use_qat) + static_qconfig = QuantizationConfig( + activation_qspec, activation_qspec, wgt_qspec, None + ) + _ = to_quantized_edge_program( + model, + input_shape, + get_quantizer_fn=lambda: NeutronAtenQuantizer( + Conv2dPatternPerChannel(is_per_channel=True, is_qat=use_qat), + static_qconfig, + ), + use_qat=use_qat, + use_neutron_for_format_conversion=False, + ) + + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + + convert_run_compare( + exported_program, + tflite_input_preprocess=ToChannelLastPreprocess(), + tfl_model=tflite_flatbuffers_model, + tflite_output_preprocess=ToChannelFirstPreprocess(), + input_data=input_data, + atol=1.0, + ) + + nodes = list(exported_program.graph.nodes) + + assert ( + nodes[8].target + == exir_ops.edge.quantized_decomposed.dequantize_per_channel.default + ) + assert ( + nodes[9].target + == exir_ops.edge.quantized_decomposed.dequantize_per_channel.default + ) + assert nodes[10].target == exir_ops.edge.aten.convolution.default diff --git a/backends/nxp/tests/test_qdq_clustering_conv.py b/backends/nxp/tests/test_qdq_clustering_conv.py index 1713aace1fe..ffae931dbb4 100644 --- a/backends/nxp/tests/test_qdq_clustering_conv.py +++ b/backends/nxp/tests/test_qdq_clustering_conv.py @@ -16,13 +16,13 @@ def test_conv2d_partitioner(): lowered_module = edge_program.exported_program().graph_module.lowered_module_0 nodes = list(lowered_module.original_module.graph.nodes) - assert len(nodes) == 7 + assert len(nodes) == 9 - q_x_node = nodes[1] - dq_w_node = nodes[2] - dq_x_node = nodes[3] - conv_node = nodes[4] - q_y_node = nodes[5] + q_x_node = nodes[3] + dq_w_node = nodes[4] + dq_x_node = nodes[5] + conv_node = nodes[6] + q_y_node = nodes[7] assert "cluster" not in q_x_node.meta assert dq_w_node.meta["cluster"] == "aten_convolution_default_cluster" diff --git a/backends/nxp/tests/test_quantizer.py b/backends/nxp/tests/test_quantizer.py index ef5fbb0cbca..27422f9ce1e 100644 --- a/backends/nxp/tests/test_quantizer.py +++ b/backends/nxp/tests/test_quantizer.py @@ -1,20 +1,86 @@ -# Copyright 2024 NXP +# Copyright 2024-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # Tests for NeutronQuantizer. 
+import itertools from copy import deepcopy +import executorch.backends.nxp.tests.executorch_pipeline as executorch_pipeline import executorch.backends.nxp.tests.models as models +import numpy as np +import pytest import torch + +from executorch.backends.nxp.backend.edge_program_converter import ( + EdgeProgramToIRConverter, +) + from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer -from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from executorch.backends.nxp.tests.executorch_pipeline import ( + neutron_target_spec, + to_quantized_edge_program, +) +from executorch.backends.nxp.tests.executors import ( + convert_run_compare, + graph_contains_any_of_ops, + ToChannelFirstPreprocess, + ToChannelLastPreprocess, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch.export import export, ExportedProgram +from torch.fx import GraphModule +from torchao.quantization.pt2e import ( + move_exported_model_to_eval, + move_exported_model_to_train, +) +from torchao.quantization.pt2e.quantize_pt2e import ( + convert_pt2e, + prepare_pt2e, + prepare_qat_pt2e, +) + +fuse_activation_ops = [ + exir_ops.edge.aten.addmm.default, + exir_ops.edge.aten.mm.default, + exir_ops.edge.aten.convolution.default, + exir_ops.edge.aten.hardtanh.default, + exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.sigmoid.default, + exir_ops.edge.aten.tanh.default, +] + + +# Permutation of all supported combinations of: +# , , +all_activation_cases = list( + itertools.product( + ["relu", "relu6", "tanh"], + [True, False], + [True, False], + ) +) + [ + ("sigmoid", False, True), + ("sigmoid", False, False), +] + +@pytest.fixture(autouse=True) +def reseed_model_per_test_run(): + torch.manual_seed(23) -def _get_target_name(node): - return node._pretty_print_target(node.target) + +def _prepare_for_quantization(exported_model, is_qat: bool = False): + if is_qat: + return prepare_qat_pt2e( + exported_model.module(), NeutronQuantizer(neutron_target_spec, is_qat=True) + ) + else: + return prepare_pt2e( + exported_model.module(), NeutronQuantizer(neutron_target_spec) + ) def test_quantizer_conv2d(): @@ -22,11 +88,10 @@ def test_quantizer_conv2d(): model.eval() example_input = (torch.ones(1, 4, 32, 32),) - quantizer = NeutronQuantizer() - graph_module = torch.export.export(model, example_input, strict=True).module() + exported_model = torch.export.export(model, example_input, strict=True) # noinspection PyTypeChecker - m = prepare_pt2e(graph_module, quantizer) + m = _prepare_for_quantization(exported_model) m(*example_input) m = convert_pt2e(m) @@ -34,26 +99,25 @@ def test_quantizer_conv2d(): m(*example_input) nodes = list(m.graph.nodes) - assert len(nodes) == 11 - assert nodes[7].name == "conv2d" + assert len(nodes) == 15 + assert nodes[11].name == "conv2d" # [0]: Input, [1] : weights, [2]: bias assert ( - _get_target_name(nodes[7].args[0]) - == "torch.ops.quantized_decomposed.dequantize_per_tensor.default" + nodes[11].args[0].target + == torch.ops.quantized_decomposed.dequantize_per_tensor.default ) assert ( - _get_target_name(nodes[7].args[1]) - == "torch.ops.quantized_decomposed.dequantize_per_tensor.default" + nodes[11].args[1].target + == torch.ops.quantized_decomposed.dequantize_per_channel.default ) assert ( - _get_target_name(nodes[7].args[2]) - == "torch.ops.quantized_decomposed.dequantize_per_tensor.default" + nodes[11].args[2].target + == torch.ops.quantized_decomposed.dequantize_per_channel.default ) assert ( - _get_target_name(nodes[8]) - == 
"torch.ops.quantized_decomposed.quantize_per_tensor.default" + nodes[12].target == torch.ops.quantized_decomposed.quantize_per_tensor.default ) - assert nodes[8].args[0].name == "conv2d" + assert nodes[12].args[0].target == torch.ops.aten.conv2d.default def test_quantizer_linear(): @@ -61,11 +125,10 @@ def test_quantizer_linear(): model.eval() example_input = (torch.ones(10, 32),) - quantizer = NeutronQuantizer() - graph_module = torch.export.export(model, example_input, strict=True).module() + exported_model = torch.export.export(model, example_input, strict=True) # noinspection PyTypeChecker - m = prepare_pt2e(graph_module, quantizer) + m = _prepare_for_quantization(exported_model) m(*example_input) m = convert_pt2e(m) @@ -77,22 +140,19 @@ def test_quantizer_linear(): assert nodes[7].name == "linear" # [0]: Input, [1] : weights, [2]: bias assert ( - _get_target_name(nodes[7].args[0]) - == "torch.ops.quantized_decomposed.dequantize_per_tensor.default" - ) - assert ( - _get_target_name(nodes[7].args[1]) - == "torch.ops.quantized_decomposed.dequantize_per_tensor.default" + nodes[7].args[0].target + == torch.ops.quantized_decomposed.dequantize_per_tensor.default ) assert ( - _get_target_name(nodes[7].args[2]) - == "torch.ops.quantized_decomposed.dequantize_per_tensor.default" + nodes[7].args[1].target + == torch.ops.quantized_decomposed.dequantize_per_tensor.default ) assert ( - _get_target_name(nodes[8]) - == "torch.ops.quantized_decomposed.quantize_per_tensor.default" + nodes[7].args[2].target + == torch.ops.quantized_decomposed.dequantize_per_tensor.default ) - assert nodes[8].args[0].name == "linear" + assert nodes[8].target == torch.ops.quantized_decomposed.quantize_per_tensor.default + assert nodes[8].args[0].target == torch.ops.aten.linear.default def test_quantizer_maxpool2d(): @@ -100,11 +160,10 @@ def test_quantizer_maxpool2d(): model.eval() example_input = (torch.ones(1, 8, 32, 32),) - quantizer = NeutronQuantizer() - graph_module = torch.export.export(model, example_input, strict=True).module() + exported_model = torch.export.export(model, example_input, strict=True) # noinspection PyTypeChecker - m = prepare_pt2e(graph_module, quantizer) + m = _prepare_for_quantization(exported_model) m(*example_input) m = convert_pt2e(m) @@ -112,22 +171,21 @@ def test_quantizer_maxpool2d(): m(*example_input) nodes = list(m.graph.nodes) - assert len(nodes) == 14 + assert len(nodes) == 18 # Check if QDQ pattern: - assert nodes[10].name == "max_pool2d" + assert nodes[14].target == torch.ops.aten.max_pool2d.default assert ( - _get_target_name(nodes[10].args[0]) - == "torch.ops.quantized_decomposed.dequantize_per_tensor.default" + nodes[14].args[0].target + == torch.ops.quantized_decomposed.dequantize_per_tensor.default ) assert ( - _get_target_name(nodes[11]) - == "torch.ops.quantized_decomposed.quantize_per_tensor.default" + nodes[15].target == torch.ops.quantized_decomposed.quantize_per_tensor.default ) - assert nodes[11].args[0].name == "max_pool2d" + assert nodes[15].args[0].target == torch.ops.aten.max_pool2d.default # Check if input and output quantization is same - input_quant = nodes[10].args[0].args[1:] - output_quant = nodes[11].args[1:] + input_quant = nodes[14].args[0].args[1:] + output_quant = nodes[15].args[1:] assert input_quant == output_quant @@ -136,11 +194,10 @@ def test_quantizer_softmax(): model.eval() example_input = (torch.ones(1, 10),) - quantizer = NeutronQuantizer() - graph_module = torch.export.export(model, example_input, strict=True).module() + exported_model = 
torch.export.export(model, example_input, strict=True) # noinspection PyTypeChecker - m = prepare_pt2e(graph_module, quantizer) + m = _prepare_for_quantization(exported_model) m(*example_input) m = convert_pt2e(m) @@ -150,16 +207,13 @@ def test_quantizer_softmax(): nodes = list(m.graph.nodes) assert len(nodes) == 7 # Check if QDQ pattern: - assert nodes[3].name == "softmax" + assert nodes[3].target == torch.ops.aten.softmax.int assert ( - _get_target_name(nodes[3].args[0]) - == "torch.ops.quantized_decomposed.dequantize_per_tensor.default" + nodes[3].args[0].target + == torch.ops.quantized_decomposed.dequantize_per_tensor.default ) - assert ( - _get_target_name(nodes[4]) - == "torch.ops.quantized_decomposed.quantize_per_tensor.default" - ) - assert nodes[4].args[0].name == "softmax" + assert nodes[4].target == torch.ops.quantized_decomposed.quantize_per_tensor.default + assert nodes[4].args[0].target == torch.ops.aten.softmax.int # Check output quantization scale, zp, _, _, dtype = nodes[4].args[1:] @@ -173,11 +227,10 @@ def test_quantizer_single_maxpool2d(): model.eval() example_input = (torch.ones(1, 4, 32, 32),) - quantizer = NeutronQuantizer() - graph_module = torch.export.export(model, example_input, strict=True).module() + exported_model = torch.export.export(model, example_input, strict=True) # noinspection PyTypeChecker - m = prepare_pt2e(graph_module, quantizer) + m = _prepare_for_quantization(exported_model) m(*example_input) m = convert_pt2e(m) @@ -186,7 +239,7 @@ def test_quantizer_single_maxpool2d(): nodes = list(m.graph.nodes) assert len(nodes) == 7 - assert nodes[3].name == "max_pool2d" + assert nodes[3].target == torch.ops.aten.max_pool2d.default assert "quantization_annotation" not in nodes[1].meta @@ -195,11 +248,10 @@ def test_quantizer_conv2d_relu(): model.eval() example_input = (torch.ones(1, 4, 32, 32),) - quantizer = NeutronQuantizer() - graph_module = torch.export.export(model, example_input, strict=True).module() + exported_model = torch.export.export(model, example_input, strict=True) # noinspection PyTypeChecker - m = prepare_pt2e(graph_module, quantizer) + m = _prepare_for_quantization(exported_model) m(*example_input) m = convert_pt2e(m) @@ -207,10 +259,14 @@ def test_quantizer_conv2d_relu(): m(*example_input) nodes = list(m.graph.nodes) + assert len(nodes) == 12 - assert nodes[7].name == "dequantize_per_tensor_default_2" - assert nodes[8].name == "relu" - assert nodes[9].name == "quantize_per_tensor_default_3" + assert ( + nodes[6].target == torch.ops.quantized_decomposed.dequantize_per_tensor.default + ) + assert nodes[7].target == torch.ops.aten.conv2d.default + assert nodes[8].target == torch.ops.aten.relu.default + assert nodes[9].target == torch.ops.quantized_decomposed.quantize_per_tensor.default def test_quantizer_conv2d_avg_pool2d(): @@ -218,11 +274,10 @@ def test_quantizer_conv2d_avg_pool2d(): model.eval() example_input = (torch.ones(1, 4, 16, 16),) - quantizer = NeutronQuantizer() - graph_module = torch.export.export(model, example_input, strict=True).module() + exported_model = torch.export.export(model, example_input, strict=True) # noinspection PyTypeChecker - m = prepare_pt2e(graph_module, quantizer) + m = _prepare_for_quantization(exported_model) m(*example_input) m = convert_pt2e(m) @@ -230,10 +285,15 @@ def test_quantizer_conv2d_avg_pool2d(): m(*example_input) nodes = list(m.graph.nodes) - assert len(nodes) == 14 - assert nodes[9].name == "dequantize_per_tensor_default_3" - assert nodes[10].name == "avg_pool2d" - assert nodes[11].name == 
"quantize_per_tensor_default_4" + + assert len(nodes) == 18 + assert ( + nodes[13].target == torch.ops.quantized_decomposed.dequantize_per_tensor.default + ) + assert nodes[14].target == torch.ops.aten.avg_pool2d.default + assert ( + nodes[15].target == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) def test_quantizer_conv2d_permute(): @@ -241,11 +301,10 @@ def test_quantizer_conv2d_permute(): model.eval() example_input = (torch.ones(1, 4, 16, 16),) - quantizer = NeutronQuantizer() - graph_module = torch.export.export(model, example_input, strict=True).module() + exported_model = torch.export.export(model, example_input, strict=True) # noinspection PyTypeChecker - m = prepare_pt2e(graph_module, quantizer) + m = _prepare_for_quantization(exported_model) m(*example_input) m = convert_pt2e(m) @@ -253,10 +312,15 @@ def test_quantizer_conv2d_permute(): m(*example_input) nodes = list(m.graph.nodes) - assert len(nodes) == 12 - assert nodes[7].name == "dequantize_per_tensor_default_2" - assert nodes[8].name == "permute" - assert nodes[9].name == "quantize_per_tensor_default_3" + + assert len(nodes) == 14 + assert ( + nodes[9].target == torch.ops.quantized_decomposed.dequantize_per_tensor.default + ) + assert nodes[10].target == torch.ops.aten.permute.default + assert ( + nodes[11].target == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) def test_multiple_shared_spec_ops_in_row(): @@ -268,11 +332,10 @@ def test_multiple_shared_spec_ops_in_row(): model.eval() example_input = (torch.ones(1, 3, 64, 64),) - quantizer = NeutronQuantizer() - graph_module = torch.export.export(model, example_input, strict=True).module() + exported_model = torch.export.export(model, example_input, strict=True) # noinspection PyTypeChecker - m = prepare_pt2e(graph_module, quantizer) + m = _prepare_for_quantization(exported_model) m(*example_input) m = convert_pt2e(m) @@ -282,14 +345,18 @@ def test_multiple_shared_spec_ops_in_row(): nodes = list(m.graph.nodes) assert len(nodes) == 15 - assert nodes[-5].name == "dequantize_per_tensor_default_3" - assert nodes[-4].name == "max_pool2d" - assert nodes[-3].name == "quantize_per_tensor_default_4" + assert ( + nodes[-5].target == torch.ops.quantized_decomposed.dequantize_per_tensor.default + ) + assert nodes[-4].target == torch.ops.aten.max_pool2d.default + assert ( + nodes[-3].target == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) # Assert that post-ReLU quantize and pre-MaxPool dequantize has same specs assert nodes[-6].args[1:] == nodes[-5].args[1:] # Assert that post-Conv quantize and pre-ReLU dequantize has same specs - assert nodes[6].args[1:] == nodes[7].args[1:] + assert nodes[5].args[1:] == nodes[6].args[1:] def test_quantizers_order_invariance(): @@ -301,7 +368,7 @@ def test_quantizers_order_invariance(): model.eval() example_input = (torch.ones(1, 4, 64, 64),) - quantizer = NeutronQuantizer() + quantizer = NeutronQuantizer(neutron_target_spec) graph_module = torch.export.export(model, example_input, strict=True).module() @@ -323,3 +390,249 @@ def test_quantizers_order_invariance(): assert len(nodes) == len(nodes_reversed) assert all(n == n_reversed for n, n_reversed in zip(nodes, nodes_reversed)) + + +@pytest.mark.parametrize("activation, inplace, use_qat", all_activation_cases) +def test_quantizer__linear_w_activation(mocker, activation, inplace, use_qat): + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + quantizer_spy = mocker.spy(executorch_pipeline, "calibrate_and_quantize") + + input_shape = 
(1, 4) + model = models.LinearActivationModule( + activation=activation, + inplace=inplace, + in_channels=input_shape[1], + mode="linear", + ) + + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + # Make sure that all nodes were delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, + ops=fuse_activation_ops, + ) + assert any("lowered_module" in node.name for node in edge_program.graph.nodes) + + tflite_flatbuffers_model, io_formats = converter_spy.spy_return + exported_program: ExportedProgram = converter_spy.call_args.args[1] + exir_program_aten_quant: GraphModule = quantizer_spy.spy_return + + # Check linear and activation are in the same QDQ cluster + nodes = list(exir_program_aten_quant.graph.nodes) + assert len(nodes) == 12 + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[7] + ) + assert neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[8] + ) + assert nodes[9].target == torch.ops.quantized_decomposed.quantize_per_tensor.default + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + atol=1.0, + ) + + +@pytest.mark.parametrize("activation, inplace, use_qat", all_activation_cases) +def test_quantizer__addmm_w_activation(mocker, activation, inplace, use_qat): + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + quantizer_spy = mocker.spy(executorch_pipeline, "calibrate_and_quantize") + + input_shape = (1, 4) + model = models.LinearActivationModule( + activation=activation, inplace=inplace, in_channels=input_shape[1], mode="addmm" + ) + + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + # Make sure that all nodes were delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, + ops=fuse_activation_ops, + ) + assert any("lowered_module" in node.name for node in edge_program.graph.nodes) + + tflite_flatbuffers_model, io_formats = converter_spy.spy_return + exported_program: ExportedProgram = converter_spy.call_args.args[1] + exir_program_aten_quant: GraphModule = quantizer_spy.spy_return + + # Check linear and activation are in the same QDQ cluster + nodes = list(exir_program_aten_quant.graph.nodes) + assert len(nodes) == 12 + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[7] + ) + assert neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[8] + ) + assert nodes[9].target == torch.ops.quantized_decomposed.quantize_per_tensor.default + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + atol=1.0, + ) + + +@pytest.mark.parametrize("activation, inplace, use_qat", all_activation_cases) +def test_quantizer__mm_w_activation(mocker, activation, inplace, use_qat): + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + quantizer_spy = mocker.spy(executorch_pipeline, "calibrate_and_quantize") + + input_shape = (1, 4) + model = models.LinearActivationModule( + activation=activation, inplace=inplace, in_channels=input_shape[1], mode="mm" + ) + + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + # Make sure that all nodes were delegated. 
+ assert not graph_contains_any_of_ops( + graph=edge_program.graph, + ops=fuse_activation_ops, + ) + assert any("lowered_module" in node.name for node in edge_program.graph.nodes) + + tflite_flatbuffers_model, io_formats = converter_spy.spy_return + exported_program: ExportedProgram = converter_spy.call_args.args[1] + exir_program_aten_quant: GraphModule = quantizer_spy.spy_return + + # Check linear and activation are in the same QDQ cluster + nodes = list(exir_program_aten_quant.graph.nodes) + assert len(nodes) == 10 + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[5] + ) + assert neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[6] + ) + assert nodes[7].target == torch.ops.quantized_decomposed.quantize_per_tensor.default + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + atol=1.0, + ) + + +@pytest.mark.parametrize("activation, inplace, use_qat", all_activation_cases) +def test_quantizer__conv_w_activation(mocker, activation, inplace, use_qat): + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + quantizer_spy = mocker.spy(executorch_pipeline, "calibrate_and_quantize") + + input_shape = (1, 4, 8, 8) + model = models.ConvActivationModule( + activation=activation, inplace=inplace, in_channels=input_shape[1] + ) + + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=use_qat + ).exported_program() + + # Make sure that all nodes were delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, + ops=fuse_activation_ops, + ) + assert any("lowered_module" in node.name for node in edge_program.graph.nodes) + + tflite_flatbuffers_model, io_formats = converter_spy.spy_return + exported_program: ExportedProgram = converter_spy.call_args.args[1] + exir_program_aten_quant: GraphModule = quantizer_spy.spy_return + + # Check linear and activation are in the same QDQ cluster + nodes = list(exir_program_aten_quant.graph.nodes) + assert len(nodes) == 16 + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[11] + ) + assert neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[12] + ) + assert ( + nodes[13].target == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + tflite_input_preprocess=ToChannelLastPreprocess(), + tflite_output_preprocess=ToChannelFirstPreprocess(), + atol=1.0, + ) + + +def test_qat_train(loss_tolerance: float = 0.02): + def evaluate(model, inputs, gts): + with torch.no_grad(): + test_outputs = model(inputs) + loss = torch.nn.functional.mse_loss(test_outputs, gts) + return loss + + def train_step(model, optimizer): + optimizer.zero_grad() + batch = torch.randn(100, 1).clamp(-1, 1) + outputs = model(batch) + loss = torch.nn.functional.mse_loss(outputs, torch.sin(batch)) + loss.backward() + optimizer.step() + + model = models.MLP() + model.train() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + + for _ in range(100): + train_step(model, optimizer) + + test_inputs = torch.randn(20, 1).clamp(-1, 1) + + model.eval() + eval_loss = evaluate(model, test_inputs, torch.sin(test_inputs)) + + exported_model = export(model, (torch.randn(1, 1),), strict=True) + 
prepared_model = _prepare_for_quantization(exported_model, is_qat=True) + + prepared_model = move_exported_model_to_train(prepared_model) + for _ in range(30): + train_step(prepared_model, optimizer) + prepared_model = move_exported_model_to_eval(prepared_model) + + quantized_model = convert_pt2e(prepared_model) + + test_inputs = torch.randn(100, 1).clamp(-1, 1) + + quant_eval_loss = evaluate(quantized_model, test_inputs, torch.sin(test_inputs)) + + assert (quant_eval_loss - eval_loss) < loss_tolerance + + +def test_qat_produces_same_graph_as_ptq(): + model = models.MiniConvNetWithRegressionHead() + model.eval() + exported_model = export(model, ((torch.randn(1, 3, 32, 32),)), strict=True) + + qat_prepared_model = _prepare_for_quantization(exported_model, is_qat=True) + qat_quantized_model = convert_pt2e(qat_prepared_model) + + ptq_prepared_model = _prepare_for_quantization(exported_model, is_qat=False) + ptq_quantized_model = convert_pt2e(ptq_prepared_model) + + assert all( + ptqn.target == qatn.target + for qatn, ptqn in zip( + qat_quantized_model.graph.nodes, ptq_quantized_model.graph.nodes + ) + ) diff --git a/backends/nxp/tests/test_removing_dead_code.py b/backends/nxp/tests/test_removing_dead_code.py index 7b8641fb247..8b3a979f412 100644 --- a/backends/nxp/tests/test_removing_dead_code.py +++ b/backends/nxp/tests/test_removing_dead_code.py @@ -9,8 +9,11 @@ import pytest import torch -from executorch.backends.nxp.tests.executorch_pipeline import _quantize_model +from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer +from executorch.backends.nxp.quantizer.utils import calibrate_and_quantize +from executorch.backends.nxp.tests.executorch_pipeline import neutron_target_spec from executorch.backends.nxp.tests.executors import graph_contains_any_of_ops +from parameterized import parameterized @pytest.fixture(autouse=True) @@ -32,7 +35,13 @@ def forward(self, x): class TestRemovingDeadCode(unittest.TestCase): __test__ = False # Prevent interfering with PyTest tests - def test_removing_dead_code(self): + @classmethod + def setUpClass(cls): + torch.manual_seed(23) + np.random.seed(23) + + @parameterized.expand([("QAT", True), ("PTQ", False)]) + def test_removing_dead_code(self, _, is_qat: bool): input_shape = (42,) example_inputs = (torch.ones(input_shape),) model = DeadCodeModule() @@ -45,16 +54,12 @@ def test_removing_dead_code(self): ) # The `NeutronQuantizer` should remove the dead code in the `transform_for_annotation()` method. - exir_program_aten_quant = _quantize_model( - exir_program_aten.module(), [example_inputs] + quantizer = NeutronQuantizer(neutron_target_spec) + exir_program_aten_quant = calibrate_and_quantize( + exir_program_aten, [example_inputs], quantizer, is_qat=is_qat ) # Make sure the is no `add` operation in the graph anymore. 
assert not any( "add" in str(node.target) for node in exir_program_aten_quant.graph.nodes ) - - @classmethod - def setUpClass(cls): - torch.manual_seed(23) - np.random.seed(23) diff --git a/backends/nxp/tests/test_removing_nodes_with_known_outputs.py b/backends/nxp/tests/test_removing_nodes_with_known_outputs.py index 8f5549c8526..0c496356791 100644 --- a/backends/nxp/tests/test_removing_nodes_with_known_outputs.py +++ b/backends/nxp/tests/test_removing_nodes_with_known_outputs.py @@ -17,6 +17,7 @@ from executorch.backends.nxp.aten_passes.split_gru_based_on_num_layers import ( SplitGRUBasedOnNumLayers, ) +from executorch.backends.nxp.tests.executorch_pipeline import neutron_target_spec from executorch.backends.nxp.tests.executors import graph_contains_any_of_ops from parameterized import parameterized from torch import nn @@ -57,7 +58,9 @@ def test_removing_nodes__zeros(self): outputs_before = [o.detach().numpy() for o in exir_program_aten(*example_input)] # Apply the optimization. - NeutronAtenPassManager([RemoveNodesWithKnownOutputs()])(exir_program_aten) + NeutronAtenPassManager(neutron_target_spec, [RemoveNodesWithKnownOutputs()])( + exir_program_aten + ) # Make sure the `aten.zeros` is no longer in the model. assert not graph_contains_any_of_ops( @@ -81,7 +84,9 @@ def test_removing_nodes__split(self, num_layers): exir_program_aten = torch.export.export(model, example_input).module() # Apply the pass to split the `aten.gru.input` into multiple instances, and add a `split` node. - NeutronAtenPassManager([SplitGRUBasedOnNumLayers()])(exir_program_aten) + NeutronAtenPassManager(neutron_target_spec, [SplitGRUBasedOnNumLayers()])( + exir_program_aten + ) # Make sure the `aten.zeros` and `torch.split` are in the model. assert graph_contains_any_of_ops( @@ -93,7 +98,9 @@ def test_removing_nodes__split(self, num_layers): outputs_before = [o.detach().numpy() for o in exir_program_aten(*example_input)] # Apply the optimization. - NeutronAtenPassManager([RemoveNodesWithKnownOutputs()])(exir_program_aten) + NeutronAtenPassManager(neutron_target_spec, [RemoveNodesWithKnownOutputs()])( + exir_program_aten + ) # Make sure the `aten.zeros` and `torch.split` are no longer in the model. 
assert not graph_contains_any_of_ops( diff --git a/backends/nxp/tests/test_split_group_convolution.py b/backends/nxp/tests/test_split_group_convolution.py index 1da53af794d..e8d807963ee 100644 --- a/backends/nxp/tests/test_split_group_convolution.py +++ b/backends/nxp/tests/test_split_group_convolution.py @@ -17,9 +17,11 @@ ) from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec +from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer +from executorch.backends.nxp.quantizer.utils import calibrate_and_quantize from executorch.backends.nxp.tests.executorch_pipeline import ( - _quantize_model, get_random_calibration_inputs, + neutron_target_spec, to_model_input_spec, ) from executorch.backends.nxp.tests.executors import graph_contains_any_of_ops @@ -36,11 +38,16 @@ def _quantize_and_lower_module( - module: GraphModule, input_shape: tuple[int, ...], target="imxrt700" + module: GraphModule, input_shape: tuple[int, ...], is_qat: bool, target="imxrt700" ) -> EdgeProgramManager: calibration_inputs = get_random_calibration_inputs(to_model_input_spec(input_shape)) - exir_program_aten__module_quant = _quantize_model(module, calibration_inputs) + exir_program_aten__module_quant = calibrate_and_quantize( + module, + calibration_inputs, + NeutronQuantizer(neutron_target_spec), + is_qat=is_qat, + ) edge_compile_config = EdgeCompileConfig(_check_ir_validity=False) edge_program_manager = export_to_edge( @@ -49,8 +56,8 @@ def _quantize_and_lower_module( edge_compile_config=edge_compile_config, ) - compile_spec = generate_neutron_compile_spec(target, "SDK_25_06") - partitioner = NeutronPartitioner(compile_spec) + compile_spec = generate_neutron_compile_spec(target, "SDK_25_09") + partitioner = NeutronPartitioner(compile_spec, neutron_target_spec) return edge_program_manager.to_backend(partitioner) @@ -64,12 +71,17 @@ def setUp(cls): @parameterized.expand( [ - ["group = 2", [1, 16, 10, 10], 2], - ["group = 3", [1, 24, 10, 10], 3], - ["group = 8", [1, 8, 10, 10], 8], + ["QAT; group = 2", [1, 16, 10, 10], 2, True], + ["PTQ; group = 2", [1, 16, 10, 10], 2, False], + ["QAT; group = 3", [1, 24, 10, 10], 3, True], + ["PTQ; group = 3", [1, 24, 10, 10], 3, False], + ["QAT; group = 8", [1, 8, 10, 10], 8, True], + ["PTQ; group = 8", [1, 8, 10, 10], 8, False], ] ) - def test_split_group_convolution__2d(self, _, input_shape: list[int], group: int): + def test_split_group_convolution__2d( + self, _, input_shape: list[int], group: int, is_qat: bool + ): example_input = (torch.ones(input_shape),) module = Conv2dModule( @@ -83,9 +95,9 @@ def test_split_group_convolution__2d(self, _, input_shape: list[int], group: int graph_module = torch.export.export(module, example_input, strict=True).module() original_module = deepcopy(graph_module) - modified_module = NeutronAtenPassManager([SplitGroupConvolution()])( - graph_module - ).graph_module + modified_module = NeutronAtenPassManager( + neutron_target_spec, [SplitGroupConvolution()] + )(graph_module).graph_module # Make sure the fusion worked. 
original_nodes = list(original_module.graph.nodes) @@ -106,11 +118,11 @@ def test_split_group_convolution__2d(self, _, input_shape: list[int], group: int input_data = torch.randn(input_shape, dtype=torch.float32) out1 = original_module(input_data).detach().numpy() out2 = modified_module(input_data).detach().numpy() - assert np.allclose(out1, out2, atol=2.0e-7) + assert np.allclose(out1, out2, atol=2.0e-7, rtol=1.9e-4) # Make sure the graph can be correctly quantized and lowered to edge. ep = _quantize_and_lower_module( - modified_module, tuple(input_shape) + modified_module, tuple(input_shape), is_qat=is_qat ).exported_program() nodes = list(ep.graph.nodes) assert nodes[-5].name == "lowered_module_0" @@ -121,12 +133,17 @@ def test_split_group_convolution__2d(self, _, input_shape: list[int], group: int @parameterized.expand( [ - ["group = 2", [1, 16, 10], 2], - ["group = 3", [1, 24, 10], 3], - ["group = 6", [1, 24, 10], 6], + ["QAT; group = 2", [1, 16, 10], 2, True], + ["PTQ; group = 2", [1, 16, 10], 2, False], + ["QAT; group = 3", [1, 24, 10], 3, True], + ["PTQ; group = 3", [1, 24, 10], 3, False], + ["QAT; group = 6", [1, 24, 10], 6, True], + ["PTQ; group = 6", [1, 24, 10], 6, False], ] ) - def test_split_group_convolution__1d(self, _, input_shape: list[int], group: int): + def test_split_group_convolution__1d( + self, _, input_shape: list[int], group: int, is_qat: bool + ): example_input = (torch.ones(input_shape),) module = Conv1dModule( @@ -140,9 +157,9 @@ def test_split_group_convolution__1d(self, _, input_shape: list[int], group: int graph_module = torch.export.export(module, example_input).module() original_module = deepcopy(graph_module) - modified_module = NeutronAtenPassManager([SplitGroupConvolution()])( - graph_module - ).graph_module + modified_module = NeutronAtenPassManager( + neutron_target_spec, [SplitGroupConvolution()] + )(graph_module).graph_module # Make sure the fusion worked. original_nodes = list(original_module.graph.nodes) @@ -167,7 +184,7 @@ def test_split_group_convolution__1d(self, _, input_shape: list[int], group: int # Make sure the graph can be correctly quantized and lowered to edge. ep = _quantize_and_lower_module( - modified_module, tuple(input_shape) + modified_module, tuple(input_shape), is_qat=is_qat ).exported_program() nodes = list(ep.graph.nodes) assert nodes[-5].name == "lowered_module_0" @@ -194,9 +211,9 @@ def test_split_group_convolution__3d(self, _, input_shape: list[int], group: int graph_module = torch.export.export(module, example_input).module() original_module = deepcopy(graph_module) - modified_module = NeutronAtenPassManager([SplitGroupConvolution()])( - graph_module - ).graph_module + modified_module = NeutronAtenPassManager( + neutron_target_spec, [SplitGroupConvolution()] + )(graph_module).graph_module # Verify that the pass has NOT made any changes, as it is disabled for 3D convolution. 
original_nodes = list(original_module.graph.nodes) @@ -213,7 +230,8 @@ def test_split_group_convolution__3d(self, _, input_shape: list[int], group: int out2 = modified_module(input_data).detach().numpy() assert np.allclose(out1, out2) - def test_split_group_convolution__applied_by_default(self): + @parameterized.expand([("QAT", True), ("PTQ", False)]) + def test_split_group_convolution__applied_by_default(self, _, is_qat: bool): input_shape = [1, 16, 10, 10] group = 2 example_input = (torch.ones(input_shape),) @@ -228,7 +246,7 @@ def test_split_group_convolution__applied_by_default(self): graph_module = torch.export.export(module, example_input).module() original_module = deepcopy(graph_module) - modified_module = NeutronAtenPassManager()( + modified_module = NeutronAtenPassManager(neutron_target_spec)( graph_module ).graph_module # Default passes. @@ -255,7 +273,7 @@ def test_split_group_convolution__applied_by_default(self): # Make sure the graph can be correctly quantized and lowered to edge. ep = _quantize_and_lower_module( - modified_module, tuple(input_shape) + modified_module, tuple(input_shape), is_qat=is_qat ).exported_program() nodes = list(ep.graph.nodes) assert nodes[-5].name == "lowered_module_0" diff --git a/backends/nxp/tests/use_qat.py b/backends/nxp/tests/use_qat.py new file mode 100644 index 00000000000..5994d5aa193 --- /dev/null +++ b/backends/nxp/tests/use_qat.py @@ -0,0 +1,11 @@ +import pytest + + +@pytest.fixture +def use_qat(request): + return request.param + + +def pytest_generate_tests(metafunc): + if "use_qat" in metafunc.fixturenames: + metafunc.parametrize("use_qat", [True, False], indirect=True) diff --git a/backends/openvino/CMakeLists.txt b/backends/openvino/CMakeLists.txt index 4d32d8932c2..736ed6d8603 100644 --- a/backends/openvino/CMakeLists.txt +++ b/backends/openvino/CMakeLists.txt @@ -53,35 +53,11 @@ target_sources( executorch_target_link_options_shared_lib(openvino_backend) -if(EXECUTORCH_BUILD_OPENVINO_EXECUTOR_RUNNER) - # Build executor runner binary for openvino backend - list(APPEND openvino_executor_runner_libs openvino_backend executorch) - - set(_openvino_executor_runner__srcs - ${EXECUTORCH_ROOT}/examples/portable/executor_runner/executor_runner.cpp - ${EXECUTORCH_ROOT}/extension/data_loader/file_data_loader.cpp - ${EXECUTORCH_ROOT}/extension/evalue_util/print_evalue.cpp - ${EXECUTORCH_ROOT}/extension/runner_util/inputs.cpp - ${EXECUTORCH_ROOT}/extension/runner_util/inputs_portable.cpp - ) - add_executable(openvino_executor_runner ${_openvino_executor_runner__srcs}) - - list(APPEND openvino_executor_runner_libs) - - target_link_libraries( - openvino_executor_runner gflags portable_ops_lib - ${openvino_executor_runner_libs} - ) - target_compile_options( - openvino_executor_runner PUBLIC ${_common_compile_options} - ) -endif() - # Install OpenVINO backend library to the lib directory install( TARGETS openvino_backend EXPORT ExecuTorchTargets - DESTINATION lib + DESTINATION ${CMAKE_INSTALL_LIBDIR} INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} ) diff --git a/backends/openvino/README.md b/backends/openvino/README.md index a67cf12eca2..5ce38ade56f 100644 --- a/backends/openvino/README.md +++ b/backends/openvino/README.md @@ -18,6 +18,11 @@ For more information on the supported hardware, please refer to [OpenVINO System executorch ├── backends │ └── openvino +│ ├── quantizer +│ ├── observers +│ └── nncf_observers.py +│ ├── __init__.py +│ └── quantizer.py │ ├── runtime │ ├── OpenvinoBackend.cpp │ └── OpenvinoBackend.h @@ -42,11 +47,23 @@ executorch 
Before you begin, ensure you have openvino installed and configured on your system. -### Build OpenVINO from Source +### Use OpenVINO from Release Packages + +1. Download the OpenVINO release package from [here](https://docs.openvino.ai/2025/get-started/install-openvino.html). Make sure to select your configuration and click on **OpenVINO Archives** under the distribution section to download the appropriate archive for your platform. + +2. Extract the release package from the archive and set the environment variables. + + ```bash + tar -zxf openvino_toolkit_.tgz + cd openvino_toolkit_ + source setupvars.sh + ``` + +### (Optional) Build OpenVINO from Source ```bash git clone https://github.com/openvinotoolkit/openvino.git -cd openvino && git checkout b16b776ac119dafda51f69a80f1e6b7376d02c3b +cd openvino git submodule update --init --recursive sudo ./install_build_dependencies.sh mkdir build && cd build @@ -59,44 +76,45 @@ cd source setupvars.sh ``` -### Use OpenVINO from Release Packages - -1. Download the OpenVINO release package from [here](https://docs.openvino.ai/2025/get-started/install-openvino.html). Make sure to select your configuration and click on **OpenVINO Archives** under the distribution section to download the appropriate archive for your platform. - -2. Extract the release package from the archive and set the environment variables. - - ```bash - tar -zxf openvino_toolkit_.tgz - cd openvino_toolkit_ - source setupvars.sh - ``` - For more information about OpenVINO build, refer to the [OpenVINO Build Instructions](https://github.com/openvinotoolkit/openvino/blob/master/docs/dev/build_linux.md). ### Setup Follow the steps below to setup your build environment: -1. **Setup ExecuTorch Environment**: Refer to the [Environment Setup](https://pytorch.org/executorch/main/getting-started-setup#environment-setup) guide for detailed instructions on setting up the ExecuTorch environment. -2. **Setup OpenVINO Backend Environment** -- Install the dependent libs. Ensure that you are inside `executorch/backends/openvino/` directory +1. **Create a Virtual Environment** +- Create a virtual environment and activate it by executing the commands below. ```bash - pip install -r requirements.txt + python -m venv env + source env/bin/activate ``` - Note: To achieve optimal performance with NNCF quantization, you should install the latest development version of NNCF (version 2.16.0.dev0+191b53d9 or higher). -3. Navigate to `scripts/` directory. - -4. **Build OpenVINO Backend C++ Libraries and Executor Runner**: Once the prerequisites are in place, run the `openvino_build.sh` script to start the build process. By default, OpenVINO backend will be built under `cmake-out/backends/openvino/` as `libopenvino_backend.a` - +2. **Clone ExecuTorch Repository from Github** +- Clone Executorch repository by executing the command below. ```bash - ./openvino_build.sh + git clone --recurse-submodules https://github.com/pytorch/executorch.git ``` - **Build OpenVINO Backend Python Package with Pybindings**: To build and install the OpenVINO backend Python package with Python bindings, run the `openvino_build.sh` script with the `--enable_python` argument. This will compile and install the ExecuTorch Python package with the OpenVINO backend into your Python environment. This option will also enable python bindings required to execute OpenVINO backend tests and `aot_optimize_and_infer.py` script inside `executorch/examples/openvino` folder. - +3. 
**Build ExecuTorch with OpenVINO Backend**
+- Ensure that you are inside the `executorch/backends/openvino/scripts` directory. The following command builds and installs ExecuTorch with the OpenVINO backend and also compiles the C++ runtime libraries and binaries into `/cmake-out` for quick inference testing.
   ```bash
+   ./openvino_build.sh
+   ```
+- Optionally, the `openvino_build.sh` script can be used to build the Python package or the C++ libraries/binaries separately.
+
+   **Build OpenVINO Backend Python Package with Pybindings**: To build and install the OpenVINO backend Python package with Python bindings, run the `openvino_build.sh` script with the `--enable_python` argument as shown in the command below. This will compile and install the ExecuTorch Python package with the OpenVINO backend into your Python environment. This option will also enable the Python bindings required to execute OpenVINO backend tests and the `aot_optimize_and_infer.py` script inside the `executorch/examples/openvino` folder.
+   ```bash
   ./openvino_build.sh --enable_python
   ```
+   **Build C++ Runtime Libraries for OpenVINO Backend**: Run the `openvino_build.sh` script with the `--cpp_runtime` flag to build the C++ runtime libraries as shown in the command below. The compiled library files and binaries can be found in the `/cmake-out` directory. The binary located at `/cmake-out/executor_runner` can be used to run inference with vision models.
+   ```bash
+   ./openvino_build.sh --cpp_runtime
+   ```
+   **Build C++ Llama Runner**: First, ensure the C++ runtime libraries are built by following the earlier instructions. Then, run the `openvino_build.sh` script with the `--llama_runner` flag to compile the Llama runner as shown in the command below; this enables running inference with models exported using `export_llama`. The resulting binary is located at: `/cmake-out/examples/models/llama/llama_main`
+   ```bash
+   ./openvino_build.sh --llama_runner
+   ```
+
+For more information about ExecuTorch environment setup, refer to the [Environment Setup](https://pytorch.org/executorch/main/getting-started-setup#environment-setup) guide.
 
 ### Run
diff --git a/backends/openvino/partitioner.py b/backends/openvino/partitioner.py
index bc3fde573e2..0d407e33f6e 100644
--- a/backends/openvino/partitioner.py
+++ b/backends/openvino/partitioner.py
@@ -26,12 +26,24 @@ from torch.fx.passes.operator_support import OperatorSupportBase
+class PatternNode:
+    op_types: dict[str, Optional[list]] = {}
+
+    def __init__(self):
+        self.op_types = {}
+
+
 class OpenvinoOperatorsSupport(OperatorSupportBase):
+    extended_support_dict = {
+        "torch.ops.dim_order_ops._clone_dim_order.default": None,
+        "torch.ops.dim_order_ops._to_dim_order_copy.default": None,
+    }
 
     def __init__(
         self,
         op_types_to_skip: Optional[set] = None,
         op_names_to_skip: Optional[set] = None,
+        enabled_ops_by_name: Optional[set] = None,
     ) -> None:
         """
         Initializes the OpenvinoOperatorsSupport class.
@@ -43,9 +55,12 @@ def __init__( op_types_to_skip = set() if op_names_to_skip is None: op_names_to_skip = set() + if enabled_ops_by_name is None: + enabled_ops_by_name = set() self._op_types_to_skip = op_types_to_skip self._op_names_to_skip = op_names_to_skip + self._enabled_ops_by_name = enabled_ops_by_name def is_node_supported(self, _, node: torch.fx.Node) -> bool: """ @@ -62,7 +77,13 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool: op_type = node.target.__name__ else: op_type = str(node.target) - supported_ops = OperatorSupport(options)._support_dict + + if node.name in self._enabled_ops_by_name: + return True + + supported_ops = ( + OperatorSupport(options)._support_dict | self.extended_support_dict + ) if op_type == "getitem": return True @@ -99,6 +120,7 @@ def __init__( self.delegation_spec = DelegationSpec(OpenvinoBackend.__name__, compile_spec) self._op_types_to_skip = op_types_to_skip self._op_names_to_skip = op_names_to_skip + self._enabled_ops_by_name: set = set() def ops_to_not_decompose( self, @@ -117,9 +139,72 @@ def ops_to_not_decompose( torch.ops.aten.upsample_bilinear2d.vec, torch.ops.aten.upsample_nearest2d.default, torch.ops.aten.upsample_nearest2d.vec, + torch.ops.aten.stack.default, ] return (ops_not_decompose, None) + def check_pattern( + self, node: torch.fx.Node, pattern: type[PatternNode], enabled_ops: list + ) -> bool: + if node.op == "call_function": + if ("call_function" + ":" + str(node.target.__name__)) in pattern.op_types: # type: ignore[union-attr] + pt_input_nodes = node.all_input_nodes + pattern_input_ops = pattern.op_types[ + "call_function" + ":" + str(node.target.__name__) # type: ignore[union-attr] + ] + if pattern_input_ops is None: + enabled_ops.append(node) + return True + if len(pt_input_nodes) != len(pattern_input_ops): + return False + for i in range(len(pt_input_nodes)): + if not self.check_pattern( + pt_input_nodes[i], pattern_input_ops[i], enabled_ops + ): + return False + enabled_ops.append(node) + return True + elif node.op == "get_attr": + if "get_attr" in pattern.op_types: + return True + else: + return False + elif node.op == "placeholder": + if "placeholder" in pattern.op_types: + return True + else: + return False + return False + + def capture_nncf_patterns(self, graph_module: torch.fx.GraphModule): + const_node = PatternNode + const_node.op_types["get_attr"] = None + const_node.op_types["placeholder"] = None + bitwise_right_shift_node = PatternNode + bitwise_right_shift_node.op_types[ + "call_function:aten.bitwise_right_shift.Tensor_Scalar" + ] = [const_node] + bitwise_and_node = PatternNode + bitwise_and_node.op_types["call_function:aten.bitwise_and.Scalar"] = [ + const_node + ] + stack_node = PatternNode + stack_node.op_types["call_function:aten.stack.default"] = [ + bitwise_and_node, + bitwise_right_shift_node, + ] + + for node in graph_module.graph.nodes: + if ( + str(node.op) == "call_function" + and str(node.target.__name__) == "aten.stack.default" + ): + enabled_ops: list = [] + pattern_match = self.check_pattern(node, stack_node, enabled_ops) + if pattern_match: + for pattern_op in enabled_ops: + self._enabled_ops_by_name.add(pattern_op.name) + def partition(self, exported_program: ExportedProgram) -> PartitionResult: """ Partitions an exported program into supported and unsupported segments. @@ -127,9 +212,14 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: :param exported_program: The exported program. 
:return: A PartitionResult containing the partitioned graph and delegation tags. """ + self.capture_nncf_patterns(exported_program.graph_module) partitioner = CapabilityBasedPartitioner( exported_program.graph_module, - OpenvinoOperatorsSupport(self._op_types_to_skip, self._op_names_to_skip), + OpenvinoOperatorsSupport( + self._op_types_to_skip, + self._op_names_to_skip, + self._enabled_ops_by_name, + ), allows_single_node_partition=True, ) partition_list = partitioner.propose_partitions() diff --git a/backends/openvino/preprocess.py b/backends/openvino/preprocess.py index c343f44a8b5..691115f6579 100644 --- a/backends/openvino/preprocess.py +++ b/backends/openvino/preprocess.py @@ -14,6 +14,8 @@ PreprocessResult, ) from executorch.exir.backend.compile_spec_schema import CompileSpec + +from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass from openvino.frontend.pytorch.torchdynamo.compile import ( # type: ignore[import-untyped] openvino_compile, ) @@ -36,6 +38,12 @@ def preprocess( Returns: PreprocessResult: The result of preprocessing, including the compiled model bytes. """ + transformed_ep = DimOrderOpsRevertPass()(edge_program.graph_module) + + # Update the edge_program with the transformed graph + if transformed_ep and transformed_ep.graph_module: + edge_program._graph_module = transformed_ep.graph_module + input_names = edge_program.graph_signature.user_inputs args = [] for node in edge_program.graph.nodes: diff --git a/backends/openvino/quantizer/__init__.py b/backends/openvino/quantizer/__init__.py index df038483f2f..5aae52ef3e8 100644 --- a/backends/openvino/quantizer/__init__.py +++ b/backends/openvino/quantizer/__init__.py @@ -1,3 +1,3 @@ -from .quantizer import OpenVINOQuantizer, quantize_model +from .quantizer import OpenVINOQuantizer, QuantizationMode, quantize_model -__all__ = ["OpenVINOQuantizer", "quantize_model"] +__all__ = ["OpenVINOQuantizer", "quantize_model", "QuantizationMode"] diff --git a/backends/openvino/quantizer/observers.py b/backends/openvino/quantizer/observers.py new file mode 100644 index 00000000000..6cda4561604 --- /dev/null +++ b/backends/openvino/quantizer/observers.py @@ -0,0 +1,186 @@ +# Copyright (c) Intel Corporation +# +# Licensed under the BSD License (the "License"); you may not use this file +# except in compliance with the License. See the license file found in the +# LICENSE file in the root directory of this source tree. 
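The observers introduced in this new file compute data-free integer quantization parameters for constant weights and later replace each observer node with a packed weight plus an OpenVINO decompression module. As a rough illustration of what the symmetric per-channel INT8 case boils down to (plain torch only; the actual implementation delegates to NNCF's `do_integer_quantization` and also covers INT4 and asymmetric variants):

```python
import torch


def int8_symmetric_per_channel(weight: torch.Tensor, ch_axis: int = 0):
    """Illustrative only: quantize a weight tensor per output channel to int8."""
    # Reduce over every axis except the channel axis.
    reduce_dims = tuple(d for d in range(weight.dim()) if d != ch_axis)
    max_abs = weight.abs().amax(dim=reduce_dims, keepdim=True)
    scale = max_abs.clamp(min=torch.finfo(weight.dtype).eps) / 127.0
    q_weight = torch.clamp(torch.round(weight / scale), -127, 127).to(torch.int8)
    return q_weight, scale.squeeze(), None  # symmetric quantization has no zero point


w = torch.randn(16, 8, 3, 3)  # e.g. a conv weight, output channels first
q_w, scale, zp = int8_symmetric_per_channel(w)
dequant = q_w.float() * scale.view(-1, 1, 1, 1)  # what a decompressor reconstructs
assert torch.allclose(dequant, w, atol=scale.max().item())
```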
+ +# mypy: disable-error-code=import-not-found + +from abc import ABC, abstractmethod +from typing import Optional, Tuple + +import torch + +from nncf.experimental.torch.fx.node_utils import ( # type: ignore[import-untyped] + get_tensor_constant_from_node, +) +from nncf.experimental.torch.fx.transformations import ( # type: ignore[import-untyped] + constant_update, + module_insertion, + node_removal, +) +from nncf.quantization.algorithms.weight_compression.config import ( # type: ignore[import-untyped] + WeightCompressionParameters, +) +from nncf.quantization.algorithms.weight_compression.weight_lowering import ( # type: ignore[import-untyped] + do_integer_quantization, +) +from nncf.tensor.tensor import Tensor as NNCFTensor # type: ignore[import-untyped] +from nncf.torch.graph.transformations.commands import ( # type: ignore[import-untyped] + PTTargetPoint, + TargetType, +) +from nncf.torch.quantization.layers import ( # type: ignore[import-untyped] + BaseWeightsDecompressor, + INT4AsymmetricWeightsDecompressor, + INT4SymmetricWeightsDecompressor, + INT8AsymmetricWeightsDecompressor, + INT8SymmetricWeightsDecompressor, +) +from torchao.quantization.pt2e import ObserverBase + + +class WeightObserverBase(ObserverBase, ABC): + """ + Base implementation of an NNCF observer that defines the rules for compressing layer weights into the OpenVINO representation. + """ + + def __init__( + self, + wc_param: WeightCompressionParameters, + dtype: torch.dtype, + **kwargs, + ) -> None: + """ + :param wc_param: Weight compression parameters container. + :param dtype: target dtype for the quantization. + """ + super().__init__(dtype=dtype, is_dynamic=False) + self._wc_param = wc_param + + def calculate_qparams( # type: ignore[override] + self, + weight: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Calculates quantization parameters: quantized weight, quantization scale and quantization zero point. + + :param weight: FP weight to be used for calculating qparams. + :return: A tuple containing the quantized weight, quantization scale and quantization zero point. + """ + wc_param = self._wc_param + wc_config = wc_param.compression_config + reduction_axes = wc_param.reduction_axes + q_weight, scale, zp = do_integer_quantization( + NNCFTensor(weight), wc_config, reduction_axes=reduction_axes + ) + zp = zp.data if zp is not None else None + return q_weight.data, scale.data, zp + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + def convert( + self, model: torch.fx.GraphModule, observer_node: torch.fx.Node + ) -> None: + """ + Replaces the given observer node from the given model with a quantized + weight and a OpenVINO specific decompression module. + + :param model: A `torch.fx.GraphModule` representing the statically traced model + with observer nodes attached and calibrated. + :param observer_node: The `torch.fx.Node` corresponding to the observer module for + the weight that is being transformed into a compressed representation. + """ + weight_node = observer_node.args[0] + original_weight = get_tensor_constant_from_node(weight_node, model) + q_weight, scale, zero_point = self.calculate_qparams(original_weight) + + decompressor = self._create_decompressor( + scale, zero_point, q_weight, original_weight + ) + packed_q_weight = decompressor.pack_weight(q_weight) + + # Weight port id is 0 since observer is inserted for a single weight only. 
+ constant_update(model, observer_node, packed_q_weight, input_port_id=0) + + compressed_weight_name = observer_node.all_input_nodes[0].name + decompressor_suffix = "_".join( + compressed_weight_name.replace(".", "_").split("_")[:-2] + ) + decompressor_name = f"{decompressor.quantization_mode}_weights_decompressor_{decompressor_suffix}" + + module_insertion( + model, + decompressor, + [ + PTTargetPoint( + TargetType.OPERATOR_POST_HOOK, + target_node_name=compressed_weight_name, + ) + ], + decompressor_name, + ) + node_removal(model, observer_node, 0) + + @abstractmethod + def _create_decompressor( + self, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + q_weight: torch.Tensor, + original_weight: torch.Tensor, + ) -> BaseWeightsDecompressor: + """ + Returns a respective NNCF decompressor for different types of quantization. + + :param scale: Calculated scale quantization parameter. + :param zero_point: Calculated zero_point quantization parameter. + :param q_weight: Calculated quantized weight. + :param original_weight: FP weight. + :return: NNCF observer according to the qmode which creates the decompression subgraph supported by OpenVINO. + """ + + +class INT4WeightObserver(WeightObserverBase): + """ + OpenVINO INT4 Weight Compression observer. + """ + + def _create_decompressor( + self, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + q_weight: torch.Tensor, + original_weight: torch.Tensor, + ) -> BaseWeightsDecompressor: + if zero_point is None: + return INT4SymmetricWeightsDecompressor( + scale, q_weight.shape, original_weight.shape, original_weight.dtype + ) + return INT4AsymmetricWeightsDecompressor( + scale, + zero_point, + q_weight.shape, + original_weight.shape, + original_weight.dtype, + ) + + +class INT8WeightObserver(WeightObserverBase): + """ + OpenVINO INT8 Weight Compression per channel observer. 
+ """ + + def _create_decompressor( + self, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + q_weight: torch.Tensor, + original_weight: torch.Tensor, + ) -> BaseWeightsDecompressor: + if zero_point is None: + return INT8SymmetricWeightsDecompressor(scale, original_weight.dtype) + return INT8AsymmetricWeightsDecompressor( + scale, zero_point, original_weight.dtype + ) diff --git a/backends/openvino/quantizer/quantizer.py b/backends/openvino/quantizer/quantizer.py index edce272ff9b..5766013689b 100644 --- a/backends/openvino/quantizer/quantizer.py +++ b/backends/openvino/quantizer/quantizer.py @@ -15,8 +15,24 @@ import nncf.experimental.torch.fx as nncf_fx # type: ignore[import-untyped] import torch.fx - +from executorch.backends.openvino.quantizer.observers import ( + INT4WeightObserver, + INT8WeightObserver, +) from nncf.common.graph.graph import NNCFGraph # type: ignore[import-untyped] +from nncf.common.logging import nncf_logger # type: ignore[import-untyped] +from nncf.quantization.algorithms.min_max.algorithm import ( # type: ignore[import-untyped] + MinMaxQuantization, +) +from nncf.quantization.algorithms.weight_compression.config import ( # type: ignore[import-untyped] + WeightCompressionParameters, +) +from nncf.quantization.quantize_model import ( # type: ignore[import-untyped] + get_weight_compression_configuration, +) +from nncf.torch.model_graph_manager import ( # type: ignore[import-untyped] + get_weight_tensor_port_ids, +) from torchao.quantization.pt2e import ( HistogramObserver, PerChannelMinMaxObserver, @@ -30,7 +46,8 @@ Quantizer, SharedQuantizationSpec, ) -from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY + +QUANT_ANNOTATION_KEY = "quantization_annotation" class QuantizationMode(Enum): @@ -40,11 +57,19 @@ class QuantizationMode(Enum): - INT8_SYM: INT8 symmetric quantization for both activations and weights. - INT8_MIXED: INT8 asymmetric quantization for activations, symmetric for weights. - INT8_TRANSFORMER: Optimized INT8 quantization for transformer-based models + - INT8WO_SYM: INT8 symmetric quantization for weights only. + - INT8WO_ASYM: INT8 asymmetric quantization for weights only. + - INT4WO_SYM: INT4 symmetric quantization for weights only. + - INT4WO_ASYM: INT4 asymmetric quantization for weights only """ INT8_SYM = "int8_sym" INT8_MIXED = "int8_mixed" INT8_TRANSFORMER = "int8_transformer" + INT8WO_SYM = "int8wo_sym" + INT8WO_ASYM = "int8wo_asym" + INT4WO_SYM = "int4wo_sym" + INT4WO_ASYM = "int4wo_asym" class OpenVINOQuantizer(Quantizer): @@ -53,10 +78,17 @@ class OpenVINOQuantizer(Quantizer): optimally for the inference via OpenVINO. """ + WEIGHTS_ONLY_COMPRESSION_MODES = ( + QuantizationMode.INT4WO_SYM, + QuantizationMode.INT4WO_ASYM, + QuantizationMode.INT8WO_SYM, + QuantizationMode.INT8WO_ASYM, + ) + def __init__( self, *, - mode: Optional[QuantizationMode] = QuantizationMode.INT8_SYM, + mode: QuantizationMode = QuantizationMode.INT8_SYM, **kwargs, ): """ @@ -65,22 +97,36 @@ def __init__( - INT8_MIXED: INT8 asymmetric quantization for activations, symmetric for weights. - INT8_TRANSFORMER: Optimized INT8 quantization for transformer-based models Default value is INT8_SYM. + - INT4_SYM: Symmetric INT4 Weights-Only Compression + - INT4_ASYM: Asymmetric INT4 Weights-Only Compression :param kwargs: Arguments to pass to the NNCF MinMaxQuantization algorithm. 
""" - if mode == QuantizationMode.INT8_SYM: - preset = quantization.structs.QuantizationPreset.PERFORMANCE - model_type = None - elif mode == QuantizationMode.INT8_MIXED: - preset = quantization.structs.QuantizationPreset.MIXED - model_type = None - else: - preset = None - model_type = nncf.parameters.ModelType.TRANSFORMER - self._min_max_algo = ( - nncf.quantization.algorithms.min_max.algorithm.MinMaxQuantization( + self.mode = mode + if self.mode not in OpenVINOQuantizer.WEIGHTS_ONLY_COMPRESSION_MODES: + if mode == QuantizationMode.INT8_SYM: + preset = quantization.structs.QuantizationPreset.PERFORMANCE + model_type = None + elif mode == QuantizationMode.INT8_MIXED: + preset = quantization.structs.QuantizationPreset.MIXED + model_type = None + else: + preset = None + model_type = nncf.parameters.ModelType.TRANSFORMER + self._algo = MinMaxQuantization( preset=preset, model_type=model_type, **kwargs ) - ) + else: + compression_mode = mode.value.replace( + "wo", "" + ) # Mode value has to match NNCF CompressWeightsMode + weight_compression_configuration = get_weight_compression_configuration( + nncf.CompressWeightsMode(compression_mode), + **kwargs, + ) + subset_size = 1 # Doesn't really matter in this case since it is data-free. Should just be +ve + self._algo = nncf.quantization.algorithms.weight_compression.algorithm.WeightCompression( + subset_size=subset_size, **weight_compression_configuration + ) def set_ignored_scope( self, @@ -101,7 +147,7 @@ def set_ignored_scope( :param validate: If set to True, then a RuntimeError will be raised if any ignored scope does not match in the model graph. """ - self._min_max_algo.set_ignored_scope( + self._algo.set_ignored_scope( nncf.IgnoredScope( names=names or [], patterns=patterns or [], @@ -114,27 +160,73 @@ def set_ignored_scope( def get_nncf_quantization_setup( self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph ) -> quantization.quantizer_setup.SingleConfigQuantizerSetup: - self._min_max_algo._set_backend_entity(model) - return self._min_max_algo.find_quantization_setup(model, nncf_graph) + self._algo._set_backend_entity(model) + return self._algo.find_quantization_setup(model, nncf_graph) - def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: - nncf_graph = nncf_fx.nncf_graph_builder.GraphConverter.create_nncf_graph(model) - quantization_setup = self.get_nncf_quantization_setup(model, nncf_graph) + def _annotate_weight_compression( + self, + model: torch.fx.GraphModule, + graph: torch.fx.Graph, + nncf_graph: NNCFGraph, + node_vs_torch_annotation: DefaultDict[torch.fx.Node, QuantizationAnnotation], + ) -> DefaultDict[torch.fx.Node, QuantizationAnnotation]: + """ + Annotates the model graph with weight-only quantization specs. - graph = model.graph - node_vs_torch_annotation: DefaultDict[torch.fx.Node, QuantizationAnnotation] = ( - defaultdict(QuantizationAnnotation) + Identifies compressible nodes in the NNCF graph and attaches the corresponding + TorchAO quantization specifications to their weight edges for later transformation. + + :param model: The FX GraphModule to annotate. + :param graph: The underlying FX graph. + :param nncf_graph: The corresponding NNCF graph. + :param node_vs_torch_annotation: A mapping of FX nodes to quantization annotations. + :return: Updated mapping of FX nodes with weight compression annotations. 
+ """ + self._algo.set_backend_entity(model) + all_wc_params, _ = self._algo.get_weight_compression_parameters( + model, nncf_graph ) + for wc_param in all_wc_params: + node_with_weight = wc_param.node_with_weight + target_node = nncf_fx.node_utils.get_graph_node_by_name( + graph, node_with_weight.node_name + ) + annotation = node_vs_torch_annotation[target_node] + edge_or_node = self._get_weight_edge(target_node, nncf_graph) + qspec = self._get_torch_ao_qspec_from_nncf_config_for_wc(wc_param=wc_param) + self._fill_torch_ao_annotation(edge_or_node, qspec, annotation) + + return node_vs_torch_annotation + + def _annotate_post_training_quantization( + self, + model: torch.fx.GraphModule, + graph: torch.fx.Graph, + nncf_graph: NNCFGraph, + node_vs_torch_annotation: DefaultDict[torch.fx.Node, QuantizationAnnotation], + ) -> DefaultDict[torch.fx.Node, QuantizationAnnotation]: + """ + Annotates the model graph with post-training quantization configurations. + + :param model: The FX GraphModule to annotate. + :param graph: The underlying FX graph. + :param nncf_graph: The corresponding NNCF graph. + :param node_vs_torch_annotation: A mapping of FX nodes to quantization annotations. + :return: Updated mapping of FX nodes with post-training quantization annotations. + """ + quantization_setup = self.get_nncf_quantization_setup(model, nncf_graph) + for qp in quantization_setup.quantization_points.values(): edge_or_node, annotation = self._get_edge_or_node_and_annotation( graph, nncf_graph, qp, node_vs_torch_annotation ) - qspec: QuantizationSpecBase = self._get_torch_ao_qspec_from_qp(qp) + qspec: QuantizationSpecBase = ( + self._get_torch_ao_qspec_from_nncf_config_for_ptq(qp) + ) self._fill_torch_ao_annotation(edge_or_node, qspec, annotation) for quantizer_ids in quantization_setup.unified_scale_groups.values(): - root_quantizer_id = self._get_unified_scales_root_quantizer_id( nncf_graph, quantizer_ids, quantization_setup ) @@ -145,14 +237,12 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: for q_id in quantizer_ids ): qps = [ - quantization_setup.quantization_points[q_id] - for q_id in quantizer_ids + quantization_setup.quantization_points[qid] for qid in quantizer_ids ] - msg = ( + raise nncf.InternalError( "Different quantization configs are set to one unified scale group:" f"{[(qp.insertion_point.__dict__, str(qp.qconfig)) for qp in qps]}" ) - raise nncf.InternalError(msg) root_target_node = nncf_fx.node_utils.get_graph_node_by_name( graph, root_qp.insertion_point.target_node_name @@ -165,16 +255,35 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: if quantizer_id == root_quantizer_id: continue - qspec = SharedQuantizationSpec(root_edge_or_node) + qspec = SharedQuantizationSpec(root_edge_or_node) # type: ignore[assignment] qp = quantization_setup.quantization_points[quantizer_id] edge_or_node, annotation = self._get_edge_or_node_and_annotation( graph, nncf_graph, qp, node_vs_torch_annotation ) self._fill_torch_ao_annotation(edge_or_node, qspec, annotation) + return node_vs_torch_annotation + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + nncf_graph = nncf_fx.nncf_graph_builder.GraphConverter.create_nncf_graph(model) + graph = model.graph + node_vs_torch_annotation: DefaultDict[torch.fx.Node, QuantizationAnnotation] = ( + defaultdict(QuantizationAnnotation) + ) + + if self.mode in OpenVINOQuantizer.WEIGHTS_ONLY_COMPRESSION_MODES: + node_vs_torch_annotation = self._annotate_weight_compression( + model, graph, 
nncf_graph, node_vs_torch_annotation + ) + else: + node_vs_torch_annotation = self._annotate_post_training_quantization( + model, graph, nncf_graph, node_vs_torch_annotation + ) + for node, annotation in node_vs_torch_annotation.items(): - assert Q_ANNOTATION_KEY not in node.meta - node.meta[Q_ANNOTATION_KEY] = annotation + assert QUANT_ANNOTATION_KEY not in node.meta + node.meta[QUANT_ANNOTATION_KEY] = annotation + return model @staticmethod @@ -236,6 +345,33 @@ def _get_edge_or_node_and_annotation( edge_or_node = OpenVINOQuantizer._get_edge_or_node(target_node, qp, nncf_graph) return edge_or_node, annotation + @staticmethod + def _get_weight_edge( + target_node: torch.fx.Node, + nncf_graph: NNCFGraph, + ) -> tuple[torch.fx.Node, torch.fx.Node]: + """ + Returns the FX node corresponding to the weight tensor input of a given operator node. + Uses the NNCF graph to identify which input port of the target node holds the weight. + If multiple weight ports are present, a warning is issued and only the first one is used. + + :param target_node: FX node representing a weighted operation (e.g., Linear, Conv). + :param nncf_graph: NNCFGraph used to determine weight port indices. + :return: Edge represented by a Tuple of (weight_node, target_node), where weight_node is the FX node supplying the weight. + """ + nncf_node = nncf_graph.get_node_by_name(target_node.name) + weights_ports_ids = get_weight_tensor_port_ids(nncf_node, nncf_graph) + if len(weights_ports_ids) > 1: + # TODO(dlyakhov): support quantization for nodes with several weights + nncf_logger.warning( + f"Quantization of the weighted node {target_node.name}" + " is not yet supported by the OpenVINOQuantizer." + f" Only the weight on port ID {weights_ports_ids[0]} will be quantized." + f" Quantizable weights are located on ports: {weights_ports_ids}." + ) + weight_node = target_node.all_input_nodes[weights_ports_ids[0]] + return (weight_node, target_node) + @staticmethod def _get_edge_or_node( target_node: torch.fx.Node, @@ -252,22 +388,7 @@ def _get_edge_or_node( """ ip = qp.insertion_point if qp.is_weight_quantization_point(): - nncf_node = nncf_graph.get_node_by_name(target_node.name) - weights_ports_ids = ( - nncf.torch.model_graph_manager.get_weight_tensor_port_ids( - nncf_node, nncf_graph - ) - ) - if len(weights_ports_ids) > 1: - # TODO(dlyakhov): support quantization for nodes with several weights - nncf.common.logging.nncf_logger.warning( - f"Quantization of the weighted node {target_node.name}" - " is not yet supported by the OpenVINOQuantizer." - f" Only the weight on port ID {weights_ports_ids[0]} will be quantized." - f" Quantizable weights are located on ports: {weights_ports_ids}." - ) - weight_node = target_node.all_input_nodes[weights_ports_ids[0]] - return (weight_node, target_node) + return OpenVINOQuantizer._get_weight_edge(target_node, nncf_graph) if ip.input_port_id is None: return target_node @@ -294,22 +415,78 @@ def _fill_torch_ao_annotation( annotation_to_update.input_qspec_map[edge_or_node[0]] = qspec @staticmethod - def _get_torch_ao_qspec_from_qp( + def _get_torch_ao_qspec_from_nncf_config_for_wc( + wc_param: WeightCompressionParameters, + ) -> QuantizationSpec: + """ + Returns a TorchAO QuantizationSpec based on NNCF weight compression parameter. + + :param wc_param: NNCF Weight compression parameters for the node. + :return: A TorchAO QuantizationSpec. 
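+
+        For reference, the mapping below produces per-channel (axis 0) specs on a
+        torch.int8 container: INT4 symmetric uses [-8, 7], INT4 asymmetric [0, 15],
+        INT8 symmetric [-128, 127], and INT8 asymmetric [0, 255].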
+ """ + observer: Type[UniformQuantizationObserverBase] + + extra_args: Dict[str, Any] = {} + + qmode = wc_param.compression_config.mode + extra_args["wc_param"] = wc_param + is_asym_mode = wc_param.compression_config.is_asym_mode + if qmode in [ + nncf.CompressWeightsMode.INT4_ASYM, + nncf.CompressWeightsMode.INT4_SYM, + ]: + observer = INT4WeightObserver # type: ignore[type-abstract] + quant_min = -8 if not is_asym_mode else 0 + quant_max = 7 if not is_asym_mode else 15 + dtype = torch.int8 + channel_axis = 0 + torch_qscheme = torch_qscheme = ( + torch.per_channel_symmetric + if not is_asym_mode + else torch.per_channel_affine + ) + else: + observer = INT8WeightObserver # type: ignore[type-abstract] + quant_min = -128 if not is_asym_mode else 0 + quant_max = 127 if not is_asym_mode else 255 + dtype = torch.int8 + channel_axis = 0 + torch_qscheme = ( + torch.per_channel_symmetric + if not is_asym_mode + else torch.per_channel_affine + ) + return QuantizationSpec( + dtype=dtype, + observer_or_fake_quant_ctr=observer.with_args(**extra_args), + quant_min=quant_min, + quant_max=quant_max, + qscheme=torch_qscheme, + ch_axis=channel_axis, + is_dynamic=False, + ) + + @staticmethod + def _get_torch_ao_qspec_from_nncf_config_for_ptq( qp: quantization.quantizer_setup.QuantizationPointBase, ) -> QuantizationSpec: """ - Retrieves the quantization configuration from the given quantization point and - converts it into a QuantizationSpec. + Returns a TorchAO QuantizationSpec based on NNCF quantization point. - :param qp: An instance of QuantizationPointBase. - :return: A QuantizationSpec retrieved and converted from the quantization point. + :param qp: Quantization point from NNCF. + :return: A TorchAO QuantizationSpec. """ + observer: Type[UniformQuantizationObserverBase] + # Eps value is copied from nncf/torch/quantization/layers.py - extra_args = {"eps": 1e-16} - qconfig = qp.qconfig - is_weight = qp.is_weight_quantization_point() + extra_args: Dict[str, Any] = {"eps": 1e-16} - observer: Type[UniformQuantizationObserverBase] + is_weight = qp.is_weight_quantization_point() + qconfig = qp.qconfig + dtype = torch.int8 + quant_min = None + quant_max = None + channel_axis = None if qconfig.per_channel: torch_qscheme = ( @@ -329,6 +506,11 @@ def _get_torch_ao_qspec_from_qp( quant_max = 127 dtype = torch.int8 channel_axis = 0 + torch_qscheme = ( + torch.per_channel_symmetric + if qconfig.mode is quantization.structs.QuantizationScheme.SYMMETRIC + else torch.per_channel_affine + ) else: observer = ( HistogramObserver diff --git a/backends/openvino/requirements.txt b/backends/openvino/requirements.txt index 316633e9004..519818d0aac 100644 --- a/backends/openvino/requirements.txt +++ b/backends/openvino/requirements.txt @@ -1,2 +1,2 @@ transformers -git+https://github.com/openvinotoolkit/nncf@6b0fc1c#egg=nncf +git+https://github.com/openvinotoolkit/nncf@3d753ac#egg=nncf diff --git a/backends/openvino/runtime/OpenvinoBackend.cpp b/backends/openvino/runtime/OpenvinoBackend.cpp index 8ec40d7f7c6..bac006ce916 100644 --- a/backends/openvino/runtime/OpenvinoBackend.cpp +++ b/backends/openvino/runtime/OpenvinoBackend.cpp @@ -114,6 +114,26 @@ exr::Error OpenvinoBackend::execute( ov_type, input_shape, input_tensor.mutable_data_ptr()); infer_request->set_input_tensor(i, ov_input_tensor); + + if (args[i]->isInt()) { + int64_t* val = &(args[i]->payload.copyable_union.as_int); + + // Create OpenVINO tensor from integer input + ov::Tensor ov_input_tensor(ov::element::i64, ov::Shape{1}, val); + 
infer_request->set_input_tensor(i, ov_input_tensor); + } else { + auto input_tensor = args[i]->toTensor(); + ov::Shape input_shape( + input_tensor.sizes().begin(), input_tensor.sizes().end()); + + // Convert input tensor to OpenVINO tensor + ov::element::Type ov_type = + convert_to_openvino_type(input_tensor.scalar_type()); + ov::Tensor ov_input_tensor( + ov_type, input_shape, input_tensor.mutable_data_ptr()); + + infer_request->set_input_tensor(i, ov_input_tensor); + } } // Set outputs @@ -165,10 +185,14 @@ ov::element::Type OpenvinoBackend::convert_to_openvino_type( switch (scalar_type) { case exa::ScalarType::Float: return ov::element::f32; + case exa::ScalarType::Half: + return ov::element::f16; case exa::ScalarType::Int: return ov::element::i32; case exa::ScalarType::Char: return ov::element::i8; + case exa::ScalarType::Byte: + return ov::element::u8; case exa::ScalarType::Long: return ov::element::i64; case exa::ScalarType::Bool: diff --git a/backends/openvino/scripts/openvino_build.sh b/backends/openvino/scripts/openvino_build.sh index 5a26f0b6dae..6762229081f 100755 --- a/backends/openvino/scripts/openvino_build.sh +++ b/backends/openvino/scripts/openvino_build.sh @@ -7,55 +7,108 @@ set -e EXECUTORCH_ROOT=$(realpath "$(dirname "$0")/../../..") echo EXECUTORCH_ROOT=${EXECUTORCH_ROOT} -main() { - build_type=${1:-"--cpp_runtime"} +install_requirements() { + echo "Installing Requirements For OpenVINO Backend" + cd "$EXECUTORCH_ROOT" + pip install -r backends/openvino/requirements.txt +} - # If the first arguments is --cpp_runtime (default), build libraries for C++ runtime - if [[ -z "$build_type" || "$build_type" == "--cpp_runtime" ]]; then - echo "Building C++ Runtime Libraries" +build_cpp_runtime() { + echo "Building C++ Runtime Libraries" + + # Set build directory + local build_dir="cmake-out" + + # Enter the Executorch root directory + cd "$EXECUTORCH_ROOT" + rm -rf "${build_dir}" + + # Configure the project with CMake + # Note: Add any additional configuration options you need here + cmake -DCMAKE_INSTALL_PREFIX="${build_dir}" \ + -DCMAKE_BUILD_TYPE=Release \ + -DEXECUTORCH_BUILD_OPENVINO=ON \ + -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ + -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ + -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ + -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ + -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ + -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \ + -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ + -DEXECUTORCH_BUILD_EXTENSION_LLM=ON \ + -DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON \ + -B"${build_dir}" + + + # Build the project + cmake --build ${build_dir} --target install --config Release -j$(nproc) +} - # Set build directory - local build_dir="cmake-out" +build_llama_runner() { + echo "Building Export Llama Runner" - # Create and enter the build directory - cd "$EXECUTORCH_ROOT" - rm -rf "${build_dir}" + # Set build directory + local build_dir="cmake-out" - # Configure the project with CMake - # Note: Add any additional configuration options you need here - cmake -DCMAKE_INSTALL_PREFIX="${build_dir}" \ - -DCMAKE_BUILD_TYPE=Release \ - -DEXECUTORCH_BUILD_OPENVINO=ON \ - -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ - -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ - -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ - -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ - -DEXECUTORCH_BUILD_OPENVINO_EXECUTOR_RUNNER=ON \ - -B"${build_dir}" + # Enter the Executorch root directory + cd "$EXECUTORCH_ROOT" + # Configure the project with CMake + # Note: 
Add any additional configuration options you need here + cmake -DCMAKE_INSTALL_PREFIX="${build_dir}" \ + -DCMAKE_BUILD_TYPE=Release \ + -B"${build_dir}"/examples/models/llama \ + examples/models/llama + # Build the export llama runner + cmake --build cmake-out/examples/models/llama -j$(nproc) --config Release +} - # Build the project - cmake --build ${build_dir} --target install --config Release -j$(nproc) +build_python_enabled() { + echo "Building Python Package with Pybinding" - # If the first arguments is --enable_python, build python package with python bindings - elif [[ "$build_type" == "--enable_python" ]]; then - echo "Building Python Package with Pybinding" + # Enter the Executorch root directory + cd "$EXECUTORCH_ROOT" + ./install_executorch.sh --clean + + # Set parameters to configure the project with CMake + # Note: Add any additional configuration options you need here + export CMAKE_ARGS="-DEXECUTORCH_BUILD_OPENVINO=ON \ + -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON" + export CMAKE_BUILD_ARGS="--target openvino_backend" - # Create and enter the build directory - cd "$EXECUTORCH_ROOT" - ./install_executorch.sh --clean + # Build the package + ./install_executorch.sh --minimal - # Set parameters to configure the project with CMake - # Note: Add any additional configuration options you need here - export CMAKE_ARGS="-DEXECUTORCH_BUILD_OPENVINO=ON \ - -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON" - export CMAKE_BUILD_ARGS="--target openvino_backend" + # Install torchao + # Note: --no-build-isolation is required because torchao's setup.py imports torch + # See comment in torchao's pyproject.toml for more details + pip install third-party/ao --no-build-isolation +} + +main() { + build_type=${1:-"--build_all"} + + # If the first arguments is --build_all (default), build python package, C++ runtime, and llama runner binary + if [[ -z "$build_type" || "$build_type" == "--build_all" ]]; then + install_requirements + build_python_enabled + build_cpp_runtime + build_llama_runner - # Build the package - ./install_executorch.sh --minimal + # If the first arguments is --cpp_runtime, build libraries for C++ runtime + elif [[ "$build_type" == "--cpp_runtime" ]]; then + build_cpp_runtime - # Install torchao - pip install third-party/ao + # If the first arguments is --llama_runner, build export llama runner binary + # Note: c++ runtime with openvino backend should be built before building export llama runner + elif [[ "$build_type" == "--llama_runner" ]]; then + build_llama_runner + + # If the first arguments is --enable_python, build python package with python bindings + elif [[ "$build_type" == "--enable_python" ]]; then + install_requirements + build_python_enabled else echo "Error: Argument is not valid: $build_type" diff --git a/backends/qualcomm/CMakeLists.txt b/backends/qualcomm/CMakeLists.txt index 32105597260..8ce1ce1bdbf 100644 --- a/backends/qualcomm/CMakeLists.txt +++ b/backends/qualcomm/CMakeLists.txt @@ -23,6 +23,47 @@ get_filename_component( _common_include_directories "${EXECUTORCH_SOURCE_DIR}/.." ABSOLUTE ) +# We only download QNN SDK when we build pip wheel for ExecuTorch. Please don't +# change this code unless you know what you are doing. 
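+# For regular (non-wheel) builds nothing is downloaded here: QNN_SDK_ROOT must
+# already be defined (for example, passed to CMake as -DQNN_SDK_ROOT=<path to the
+# Qualcomm AI Engine Direct SDK>), otherwise configuration fails below.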
+if(EXECUTORCH_BUILD_WHEEL_DO_NOT_USE) + set(_qnn_default_sdk_dir "${CMAKE_CURRENT_BINARY_DIR}/sdk/qnn") + + if(EXISTS "${_qnn_default_sdk_dir}" AND EXISTS "${_qnn_default_sdk_dir}/lib") + message(STATUS "Found cached Qualcomm SDK at ${_qnn_default_sdk_dir}") + set(QNN_SDK_ROOT + ${_qnn_default_sdk_dir} + CACHE PATH "Qualcomm SDK root directory" FORCE + ) + else() + message(STATUS "Downloading Qualcomm SDK") + execute_process( + COMMAND + ${PYTHON_EXECUTABLE} + ${EXECUTORCH_SOURCE_DIR}/backends/qualcomm/scripts/download_qnn_sdk.py + --dst-folder ${_qnn_default_sdk_dir} --print-sdk-path + WORKING_DIRECTORY ${EXECUTORCH_SOURCE_DIR} + RESULT_VARIABLE _qnn_sdk_download_result + OUTPUT_VARIABLE _qnn_sdk_download_output + ERROR_VARIABLE _qnn_sdk_download_error + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + if(NOT _qnn_sdk_download_result EQUAL 0 OR _qnn_sdk_download_output + STREQUAL "" + ) + message( + FATAL_ERROR + "Failed to download Qualcomm SDK. stdout: ${_qnn_sdk_download_output}\n" + "stderr: ${_qnn_sdk_download_error}" + ) + endif() + set(QNN_SDK_ROOT + ${_qnn_sdk_download_output} + CACHE PATH "Qualcomm SDK root directory" FORCE + ) + endif() + set(ENV{QNN_SDK_ROOT} ${QNN_SDK_ROOT}) +endif() + if(NOT DEFINED QNN_SDK_ROOT) message( FATAL_ERROR @@ -109,6 +150,7 @@ add_library(qnn_executorch_backend SHARED) add_library(qnn_executorch_header INTERFACE) add_library(qnn_executorch_logging STATIC) add_library(qnn_factory STATIC) +add_library(qnn_backend_unified_registry STATIC) add_library(qnn_function_interface INTERFACE) add_library(qnn_graph STATIC) add_library(qnn_implementation STATIC) @@ -172,13 +214,30 @@ target_link_libraries( ) target_link_libraries( - qnn_dlc_manager PRIVATE qnn_factory qnn_backend qnn_device qnn_context - qnn_graph qnn_mem_manager + qnn_backend_unified_registry PRIVATE qnn_schema qnn_backend qnn_device + qnn_implementation +) + +target_link_libraries( + qnn_dlc_manager + PRIVATE qnn_factory + qnn_backend_unified_registry + qnn_backend + qnn_device + qnn_context + qnn_graph + qnn_mem_manager ) target_link_libraries( - qnn_manager PRIVATE qnn_factory wrappers qnn_schema utils shared_buffer - qnn_dlc_manager + qnn_manager + PRIVATE qnn_factory + qnn_backend_unified_registry + wrappers + qnn_schema + utils + shared_buffer + qnn_dlc_manager ) target_link_libraries( qnn_executorch_backend @@ -214,7 +273,9 @@ add_subdirectory( install( TARGETS qnn_executorch_backend EXPORT ExecuTorchTargets - DESTINATION lib + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}/executorch/backends/qualcomm + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}/executorch/backends/qualcomm + RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR}/executorch/backends/qualcomm ) # QNN pybind @@ -224,15 +285,11 @@ if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64") ${CMAKE_CURRENT_BINARY_DIR}/pybind11 ) add_library(PyQnnManagerAdaptor MODULE) - add_library(PyQnnWrapperAdaptor MODULE) # PyQnnManager containing a pybind type triggers the warning because pybind11 # code internally forces hidden visibility. 
set_target_properties( PyQnnManagerAdaptor PROPERTIES CXX_VISIBILITY_PRESET hidden ) - set_target_properties( - PyQnnWrapperAdaptor PROPERTIES CXX_VISIBILITY_PRESET hidden - ) target_link_libraries( PyQnnManagerAdaptor @@ -244,18 +301,14 @@ if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64") executorch extension_tensor qnn_backend_options - ) - target_link_libraries( - PyQnnWrapperAdaptor PRIVATE pybind11::module pybind11::lto wrappers - qnn_executorch_logging qnn_executorch_header + wrappers + qnn_executorch_logging ) pybind11_extension(PyQnnManagerAdaptor) - pybind11_extension(PyQnnWrapperAdaptor) if(NOT MSVC AND NOT ${CMAKE_BUILD_TYPE} MATCHES RelWithDebInfo) # Strip unnecessary sections of the binary pybind11_strip(PyQnnManagerAdaptor) - pybind11_strip(PyQnnWrapperAdaptor) endif() if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") @@ -266,13 +319,18 @@ if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64") target_compile_options( PyQnnManagerAdaptor PUBLIC ${_pybind_compile_options} ) - target_compile_options( - PyQnnWrapperAdaptor PUBLIC ${_pybind_compile_options} - ) endif() add_subdirectory( ${QNN_EXECUTORCH_ROOT_DIR}/aot/python ${CMAKE_CURRENT_BINARY_DIR}/qnn_executorch/python ) + + install( + TARGETS PyQnnManagerAdaptor + LIBRARY + DESTINATION ${CMAKE_INSTALL_LIBDIR}/executorch/backends/qualcomm/python + RUNTIME + DESTINATION ${CMAKE_INSTALL_LIBDIR}/executorch/backends/qualcomm/python + ) endif() diff --git a/backends/qualcomm/README.md b/backends/qualcomm/README.md index 7c5853b3a6f..fa03bbd0860 100644 --- a/backends/qualcomm/README.md +++ b/backends/qualcomm/README.md @@ -21,12 +21,17 @@ Please check `generate_qnn_executorch_compiler_spec()` in - Snapdragon 8 Gen 2 - Snapdragon 8 Gen 3 - Snapdragon 8 Elite +- Snapdragon 8 Elite Gen 5 - SA8295 +- SA8255 - SSG2115P - SSG2125P - SXR1230P - SXR2230P - SXR2330P +- QCS9100 +- SAR2230P +- SW6100 ### Adding more supported Chipset Currently, users cannot add additional chipset models because the chipset ID is not accessible to community users. If you have specific chipset models you wish to add, please contact one of the authors in the `Code Reviews` section at the bottom of this page. @@ -51,7 +56,7 @@ backends/qualcomm | | # Meanwhile, this is also the runtime responsbile for executing compiled | | # models on a device. | └── backends # Backends supported by QNN. -| └── htpbackend +| └── gpu / htp | ├── aarch64 # Configuration required to run on device. (Device Part). | └── x86_64 # Configuration required to compile graph on host. (AoT Part). ├── scripts # Misc supporting scripts, not related to core functionality. @@ -128,7 +133,7 @@ PRs are always welcome to help improve the codebase in a comprehensive manner. B - **Code Reviews**:
Please ping authors in Qualcomm AI Engine Direct related PRs for reviewing, possible candidates are listed below: - [shewu-quic](https://github.com/shewu-quic) - - [chunit-quic](https://github.com/chunit-quic) + - [chenweng-quic](https://github.com/chenweng-quic) - [winskuo-quic](https://github.com/winskuo-quic) - [DannyYuyang-quic](https://github.com/DannyYuyang-quic) - [haowhsu-quic](https://github.com/haowhsu-quic) diff --git a/backends/qualcomm/__init__.py b/backends/qualcomm/__init__.py new file mode 100644 index 00000000000..5770dfb0fcd --- /dev/null +++ b/backends/qualcomm/__init__.py @@ -0,0 +1,13 @@ +import os + +from .scripts.download_qnn_sdk import install_qnn_sdk, is_linux_x86 + + +env_flag = os.getenv("EXECUTORCH_BUILDING_WHEEL", "0").lower() +# If users have preinstalled QNN_SDK_ROOT, we will use it. +qnn_sdk_root_flag = os.getenv("QNN_SDK_ROOT", None) + +if env_flag not in ("1", "true", "yes") and not qnn_sdk_root_flag and is_linux_x86(): + ok = install_qnn_sdk() + if not ok: + raise RuntimeError("Failed to install QNN SDK. Please check the logs above.") diff --git a/backends/qualcomm/_passes/TARGETS b/backends/qualcomm/_passes/TARGETS index 62a0fc43a78..876b51d3863 100644 --- a/backends/qualcomm/_passes/TARGETS +++ b/backends/qualcomm/_passes/TARGETS @@ -15,5 +15,6 @@ runtime.python_library( "//executorch/backends/transforms:decompose_sdpa", "//executorch/exir/backend:backend_details", "//executorch/exir/backend:compile_spec_schema", + "//executorch/backends/qualcomm/quantizer:quantizer", ], ) diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py index 15fce79ea12..49449fe2190 100644 --- a/backends/qualcomm/_passes/__init__.py +++ b/backends/qualcomm/_passes/__init__.py @@ -11,16 +11,23 @@ from .canonicalize_conv import CanonicalizeConv from .convert_bmm_to_matmul import ConvertBmmToMatmul from .convert_linear_to_conv2d import ConvertLinearToConv2d +from .convert_mha_to_sha import ConvertMhaToSha from .convert_square_to_pow import ConvertSquareToPow from .decompose_any import DecomposeAny +from .decompose_binary_alpha import DecomposeBinaryAlpha from .decompose_cdist import DecomposeCDist from .decompose_col_im import DecomposeColIm from .decompose_einsum import DecomposeEinsum from .decompose_expm1 import DecomposeExpM1 +from .decompose_floor_divide import DecomposeFloorDivide +from .decompose_glu import DecomposeGlu from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm +from .decompose_maxpool3d import DecomposeMaxPool3d from .decompose_minmaxdim import DecomposeMinMaxDim from .decompose_roll import DecomposeRoll from .decompose_silu import DecomposeSilu +from .decompose_threshold import DecomposeThreshold +from .decompose_triu import DecomposeTriu from .decompose_wrap_with_autocast import DecomposeWrapWithAutocast from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape from .fixed_linear_keep_dim import FixedLinearKeepDim @@ -30,6 +37,7 @@ from .i64_to_i32 import I64toI32 from .insert_io_qdq import InsertIOQDQ from .insert_requantize import InsertRequantize +from .insert_reshape_for_reduce_ops import InsertReshapeForReduceOps from .layout_transform import LayoutTransform from .lift_constant_scalar_operands import LiftConstantScalarOperands from .recompose_pixel_unshuffle import RecomposePixelUnshuffle @@ -42,7 +50,6 @@ from .seq_mse import SeqMSE from .tag_quant_io import TagQuantIO - __all__ = [ AnnotateAdaptiveAvgPool1D, AnnotateQuantAttrs, @@ -51,16 +58,23 @@ CanonicalizeConv, ConvertBmmToMatmul, 
ConvertLinearToConv2d, + ConvertMhaToSha, ConvertSquareToPow, DecomposeAny, + DecomposeBinaryAlpha, DecomposeCDist, DecomposeColIm, DecomposeEinsum, DecomposeExpM1, + DecomposeFloorDivide, + DecomposeGlu, DecomposeLinalgVectorNorm, + DecomposeMaxPool3d, DecomposeMinMaxDim, DecomposeRoll, DecomposeSilu, + DecomposeThreshold, + DecomposeTriu, DecomposeWrapWithAutocast, ExpandBroadcastTensorShape, FixedLinearKeepDim, @@ -69,6 +83,7 @@ FuseConsecutiveTranspose, I64toI32, InsertIOQDQ, + InsertReshapeForReduceOps, InsertRequantize, LayoutTransform, LiftConstantScalarOperands, diff --git a/backends/qualcomm/_passes/annotate_quant_attrs.py b/backends/qualcomm/_passes/annotate_quant_attrs.py index 610e88e6d3b..6077d51b099 100644 --- a/backends/qualcomm/_passes/annotate_quant_attrs.py +++ b/backends/qualcomm/_passes/annotate_quant_attrs.py @@ -19,6 +19,7 @@ QCOM_SCALE, QCOM_ZERO_POINT, ) +from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from .utils import get_quant_attrs @@ -38,6 +39,9 @@ def __init__( super(AnnotateQuantAttrs, self).__init__() self.edge_program = edge_program self.skip_advanced_requant = skip_advanced_requant + self.skip_requant_allowlist = { + exir_ops.edge.aten.sigmoid.default, + } def _annotate_source_nodes( self, quant_node: torch.fx.Node, quant_attrs: Dict[str, Any] @@ -80,6 +84,10 @@ def _annotate_requant(self, n): # node1 -> q_ui8 (n) -> dq_ui8 -> q_int32 -> dq_int32 -> node2 -> .... # We store {node2: quant_attr in dq_int32} in node1.meta if n.target in q_ops and n.args[0].target not in dq_ops: + # for some fixed scale op, there is no need to requantize it + if n.args[0].target in self.skip_requant_allowlist: + return + dq_nodes = self._find_last_dq_nodes(n) q_attrs = get_quant_attrs(self.edge_program, n) for dq_node in dq_nodes: diff --git a/backends/qualcomm/_passes/canonicalize_conv.py b/backends/qualcomm/_passes/canonicalize_conv.py index 3804fb05da0..8836ed44328 100644 --- a/backends/qualcomm/_passes/canonicalize_conv.py +++ b/backends/qualcomm/_passes/canonicalize_conv.py @@ -9,7 +9,6 @@ import torch from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter -from executorch.backends.qualcomm.utils.constants import QCOM_REQUANTIZE from executorch.exir.pass_base import ExportPass, PassResult from torch._guards import detect_fake_mode @@ -34,6 +33,7 @@ def __init__(self, edge_program: torch.export.ExportedProgram): self.transpose_conv_set = { torch.ops.aten.conv_transpose1d.default, torch.ops.aten.conv_transpose2d.input, + torch.ops.aten.conv_transpose3d.input, } def dilate(self, tensor, dilation): @@ -196,14 +196,6 @@ def call(self, graph_module: torch.fx.GraphModule): ) squeeze_node.meta = copy_meta(node.meta) - if QCOM_REQUANTIZE in input_node.meta: - input_node.meta.pop(QCOM_REQUANTIZE) - if QCOM_REQUANTIZE in node.meta: - squeeze_node.meta[QCOM_REQUANTIZE] = node.meta[ - QCOM_REQUANTIZE - ] - conv2d_node.meta.pop(QCOM_REQUANTIZE, None) - for user in node.users.copy(): user.replace_input_with(node, squeeze_node) diff --git a/backends/qualcomm/_passes/convert_bmm_to_matmul.py b/backends/qualcomm/_passes/convert_bmm_to_matmul.py index 3d4e44dfa42..262a3b9ef0f 100644 --- a/backends/qualcomm/_passes/convert_bmm_to_matmul.py +++ b/backends/qualcomm/_passes/convert_bmm_to_matmul.py @@ -47,7 +47,13 @@ def call(self, graph_module: torch.fx.GraphModule): graph = graph_module.graph partitions = get_source_partitions( graph, - [operator.matmul, torch.matmul, torch.bmm, 
torch.ops.aten.matmul.default], + [ + "matmul", + operator.matmul, + torch.matmul, + torch.bmm, + torch.ops.aten.matmul.default, + ], ) for _, src_partitions in partitions.items(): for src_partition in src_partitions: diff --git a/backends/qualcomm/_passes/convert_mha_to_sha.py b/backends/qualcomm/_passes/convert_mha_to_sha.py new file mode 100644 index 00000000000..b225fe6d149 --- /dev/null +++ b/backends/qualcomm/_passes/convert_mha_to_sha.py @@ -0,0 +1,626 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import torch + +from executorch.backends.qualcomm._passes.utils import find_pattern +from executorch.backends.qualcomm.utils.constants import ( + QCOM_BLOCK_SIZE, + QCOM_QUANT_ATTRS, + QCOM_REQUANTIZE, + QCOM_SCALE, + QCOM_SCALES, + QCOM_ZERO_POINT, + QCOM_ZERO_POINTS, +) + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + +from executorch.exir.passes.constant_prop_pass import constant_prop_pass + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) + + +def _is_node(node): + return isinstance(node, torch.fx.Node) + + +def _is_output(node): + return _is_node(node) and node.op == "output" + + +def _is_call(node): + return _is_node(node) and node.op == "call_function" + + +def _is_unsqueeze(node): + return _is_call(node) and node.target == exir_ops.edge.aten.unsqueeze_copy.default + + +def _is_view(node): + return _is_call(node) and node.target == exir_ops.edge.aten.view_copy.default + + +def _is_permute(node): + return _is_call(node) and node.target == exir_ops.edge.aten.permute_copy.default + + +def _is_matmul(node): + return _is_call(node) and node.target == exir_ops.edge.aten.matmul.default + + +def _is_bmm(node): + return _is_call(node) and node.target == exir_ops.edge.aten.bmm.default + + +def _is_expand(node): + return _is_call(node) and node.target == exir_ops.edge.aten.expand_copy.default + + +def _is_conv(node): + return _is_call(node) and node.target == exir_ops.edge.aten.convolution.default + + +def _is_softmax(node): + return _is_call(node) and node.target in [ + exir_ops.edge.aten._softmax.default, + exir_ops.edge.aten._safe_softmax.default, + ] + + +def _shape(node): + assert "val" in node.meta + return list(node.meta["val"].shape) + + +@dataclass +class Sha: + axis: int + heads: int + + def __repr__(self): + return f"Sha(axis={self.axis}, heads={self.heads})" + + +class ConvertMhaToSha(ExportPass): + """ + b=batch, e=emb=h*d, h=heads, d=head_size, s=seq_len, p=past, c=s+p + + i[bse] ─┬─ q[bse] ─ [bhsd] ─ RoPE ─ [bhsd] ───────────────────── qk[bhsc] ─ mask ─ softmax ─ qkv[bhsd] ─ [bse] ─ o[bse] + ├─ k[bse] ─ [bhsd] ─ RoPE ─ [bhds] ─ k_cat[bhdc] ─(k_exp)─┘ │ + │ past_k[bhdp] ──┘ │ + └─ v[bse] ─ [bhsd] ───────────────── v_cat[bhcd] ─(v_exp)-────────────────────────────┘ + past_v[bhpd] ──┘ + """ + + def __init__( + self, + edge_program: torch.export.ExportedProgram, + verbose=False, + ): + super().__init__() + self.edge_program = edge_program + self.verbose = verbose + + def _nodes(self, graph_module, wanted_sources, node_checker=None): + nodes = [] + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target in wanted_sources: + if 
node_checker is None or node_checker(node): + nodes.append(node) + return nodes + + def _get_attention_output(self, softmax): + """Output of MHA block or input of output projection""" + + pattern_qk = [_is_softmax, "*", lambda x: _is_matmul(x) or _is_bmm(x)] + qk = find_pattern(softmax, pattern_qk) + if not qk: + return None, None, None + + patterns_qkv = [ + _is_softmax, + "*", + lambda x: _is_matmul(x) or _is_bmm(x), + "*", + _is_permute, + _is_view, + ] + + qkv = find_pattern(softmax, patterns_qkv, from_args=False) + if qkv is None: + return None, None, None + + permute, reshape = qkv[0][-2:] + matmul = qkv[0][2] + attn_output = matmul + sha_axis = 1 + remove_nodes = [permute] + # the shape of attn_output should be [bhsd] + shape = _shape(attn_output.args[0]) + heads = shape[sha_axis] + sha = Sha(axis=sha_axis, heads=heads) + + return attn_output, sha, remove_nodes + + def _update_requantize_user(self, node): + if QCOM_REQUANTIZE in node.meta: + user_node_list = [user.name for user in node.users.keys()] + + new_dict = {} + for original_key in node.meta[QCOM_REQUANTIZE]: + for new_key in user_node_list: + # new_keys are the name of the split nodes whose naming pattern follows: _h_xxx + if original_key in new_key: + new_dict.update( + {new_key: node.meta[QCOM_REQUANTIZE][original_key]} + ) + node.meta[QCOM_REQUANTIZE] = new_dict + + def _split( # noqa: C901 + self, + graph_module: torch.fx.GraphModule, + attn_output: torch.fx.Node, + sha: Sha, + remove_nodes: List, + ): + """ + Main MHA to SHAs + - Start from the attention output or the input of the output projection node, assuming the head axis is 2. + - Recursively visit parent nodes until reaching the static Linear/Conv2D nodes, which must be the Q/K/V projection nodes. + - Splitting begins from the end of the recursion, which must be the Q/K/V projection nodes. + - The visit call will return the split nodes, which will be used by subsequent child visitors. 
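+
+        Illustrative shape walk-through (b=batch, h=heads, s=seq_len, d=head_size):
+        the [b, h, s, d] attention output is split along the head axis into h
+        tensors of shape [b, 1, s, d]; a Q/K/V projection weight of shape
+        [h * d, in_feature, 1, 1] is sliced into h weights of shape
+        [d, in_feature, 1, 1]; the per-head results are concatenated back along
+        the head axis at the end.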
+ + Known issue + - Packed Q/K/V projection is not supported yet + """ + + def _visit_reshape(node, sha): + """Reshape: handle GQA pattern""" + in_shape, out_shape = _shape(node.args[0]), _shape(node) + if out_shape[sha.axis] % sha.heads == 1: + return _no_split(node, sha) + + assert ( + out_shape[sha.axis] % sha.heads == 0 + ), f"mismatching num_heads, {out_shape[sha.axis]} % {sha.heads} != 0" + + pattern_simple_gqa = [ + _is_view, + lambda x: _is_expand(x) and len(_shape(x)) == 5, + _is_unsqueeze, + ] + + if gqa := find_pattern(node, pattern_simple_gqa): + # GQA pattern: skip these and adjust sha.heads + if self.verbose: + logging.info(f"{__name__}:_visit_reshape: {node} is for GQA!") + _, expand, unsqueeze = gqa[0] + expand_shape = expand.args[1] + unsqueeze_dim = unsqueeze.args[1] + repeat_count = expand_shape[unsqueeze_dim] + kv_sha = Sha(sha.axis, in_shape[sha.axis]) + new_arg0s = _visit(unsqueeze.args[0], kv_sha) + new_arg0s = [arg for arg in new_arg0s for _ in range(repeat_count)] + else: + new_arg0s = _visit(node.args[0], sha) + + out_shape[sha.axis] //= sha.heads + new_args = [(arg0, out_shape) for arg0 in new_arg0s] + return _split_call(node, sha, new_args, out_shape) + + def _visit_permute(node, sha): + """Transpose: permute sha axis as well""" + out_shape = _shape(node) + assert ( + out_shape[sha.axis] % sha.heads == 0 + ), f"mismatching num_heads, {out_shape[sha.axis]} % {sha.heads} != 0" + out_shape[sha.axis] //= sha.heads + permute = node.args[1] + sha_permuted = Sha(axis=permute[sha.axis], heads=sha.heads) + new_arg0s = _visit(node.args[0], sha_permuted) + new_args = [(arg0, node.args[1]) for arg0 in new_arg0s] + return _split_call(node, sha, new_args, out_shape) + + def _visit_expand(node, sha): + out_shape = _shape(node) + if out_shape[sha.axis] != 1: + assert ( + out_shape[sha.axis] % sha.heads == 0 + ), f"mismatching num_heads, {out_shape[sha.axis]} % {sha.heads} != 0" + out_shape[sha.axis] //= sha.heads + + exp_shape = node.args[1] + if exp_shape[sha.axis] == 1: + return _visit_default(node, sha) + + assert ( + exp_shape[sha.axis] % sha.heads == 0 + ), f"mismatching expand shape, {exp_shape[sha.axis]} % {sha.heads} != 0" + new_exp_shape = type(exp_shape)( + [ + dim // sha.heads if axis == sha.axis else dim + for axis, dim in enumerate(exp_shape) + ] + ) + new_args = [(node.args[0], new_exp_shape)] * sha.heads + new_nodes = _split_call(node, sha, new_args, out_shape) + return new_nodes + + def _visit_cat(node, sha): + out_shape = _shape(node) + if out_shape[sha.axis] != 1: + assert ( + out_shape[sha.axis] % sha.heads == 0 + ), f"mismatching num_heads, {out_shape[sha.axis]} % {sha.heads} != 0" + out_shape[sha.axis] //= sha.heads + + assert isinstance(node.args[0], (tuple, list)) # concat + split_arg0s = [_visit(arg, sha) for arg in node.args[0]] + new_arg0s = list(zip(*split_arg0s)) + split_arg1s = [_visit(arg, sha) for arg in node.args[1:]] + new_arg1s = list(zip(*split_arg1s)) + new_args = [(arg0, *arg1) for arg0, arg1 in zip(new_arg0s, new_arg1s)] + + new_nodes = _split_call(node, sha, new_args, out_shape) + return new_nodes + + def _visit_default(node, sha): + out_shape = _shape(node) + + if out_shape[sha.axis] != 1: + assert ( + out_shape[sha.axis] % sha.heads == 0 + ), f"mismatching num_heads, {out_shape[sha.axis]} % {sha.heads} != 0" + out_shape[sha.axis] //= sha.heads + + assert not isinstance( + node.args[0], (tuple, list) + ), f"Unexpected cat node:{node}" + split_args = [_visit(arg, sha) for arg in node.args] + new_args = list(zip(*split_args)) + new_nodes = 
_split_call(node, sha, new_args, out_shape)
+            return new_nodes
+
+        def _is_mha(node, sha):
+            if not _is_node(node):
+                return False
+            out_shape = _shape(node)
+            return len(out_shape) > sha.axis and out_shape[sha.axis] == sha.heads
+
+        def _visit_binary(node, sha):
+            """elementwise binary operator visit mha inputs only"""
+            out_shape = _shape(node)
+            if out_shape[sha.axis] != 1:
+                assert (
+                    out_shape[sha.axis] % sha.heads == 0
+                ), f"mismatching num_heads, {out_shape[sha.axis]} % {sha.heads} != 0"
+                out_shape[sha.axis] //= sha.heads
+
+            split_args = [
+                (_visit(arg, sha) if _is_mha(arg, sha) else [arg] * sha.heads)
+                for arg in node.args
+            ]
+            new_args = list(zip(*split_args))
+            new_nodes = _split_call(node, sha, new_args, out_shape)
+            return new_nodes
+
+        def _visit_placeholder(node, sha):
+            in_shape = _shape(node)
+            if (
+                in_shape
+                and len(in_shape) > sha.axis
+                and in_shape[sha.axis] == sha.heads
+            ):  # split past_kv by heads
+                new_nodes = _split_placeholder(
+                    node, axis=sha.axis, size=1, count=sha.heads
+                )
+            else:
+                # position embedding, attention mask, and R3 weights
+                new_nodes = _no_split(node, sha)
+            return new_nodes
+
+        def _get_slicers(count, axis, size):
+            return [
+                tuple(
+                    [
+                        (
+                            slice(size * idx, size * (idx + 1))
+                            if ax == axis
+                            else slice(None)
+                        )
+                        for ax in range(axis + 1)
+                    ]
+                )
+                for idx in range(count)
+            ]
+
+        def _split_call(node, sha, new_args, out_shape):
+            with graph_module.graph.inserting_after(node):
+                new_nodes = []
+                slicers = _get_slicers(sha.heads, sha.axis, out_shape[sha.axis])
+                for head, (args, slicer) in enumerate(zip(new_args, slicers)):
+                    name = f"{node.name}_h_{head}"
+                    new_nodes.append(
+                        _duplicate_call(node, args, None, slicer, name=name)
+                    )
+            return new_nodes
+
+        def _create_call(
+            op_target, args: Tuple, kwargs: Optional[dict] = None, name: str = None
+        ):
+            return graph_module.graph.create_node(
+                "call_function",
+                op_target,
+                args=args,
+                kwargs=kwargs or {},
+                name=name,
+            )
+
+        def _no_split(node, sha):
+            return [node] * sha.heads
+
+        def _copy_meta(dst_node, src_node, slicer):
+            dst_node.meta = src_node.meta.copy()
+            dst, src = dst_node.meta, src_node.meta
+            if "val" in src:
+                dst["val"] = src["val"].clone()[slicer]
+            if (src_tensor_meta := src.get("tensor_meta", None)) is not None:
+                tensor_meta = dict(zip(src_tensor_meta._fields, [*src_tensor_meta]))
+                tensor_meta["shape"] = dst["val"].shape
+                tensor_meta["stride"] = dst["val"].stride()
+                dst["tensor_meta"] = type(src_tensor_meta)(**tensor_meta)
+            # PCQ
+            if QCOM_QUANT_ATTRS in src and QCOM_SCALES in src[QCOM_QUANT_ATTRS]:
+                dst[QCOM_QUANT_ATTRS] = src[QCOM_QUANT_ATTRS].copy()
+                # slice for per channel quantize
+                dst[QCOM_QUANT_ATTRS][QCOM_SCALES] = src[QCOM_QUANT_ATTRS][
+                    QCOM_SCALES
+                ].clone()[slicer]
+                dst[QCOM_QUANT_ATTRS][QCOM_ZERO_POINTS] = src[QCOM_QUANT_ATTRS][
+                    QCOM_ZERO_POINTS
+                ].clone()[slicer]
+
+            # LPBQ
+            if QCOM_QUANT_ATTRS in src and QCOM_BLOCK_SIZE in src[QCOM_QUANT_ATTRS]:
+                dst[QCOM_QUANT_ATTRS] = src[QCOM_QUANT_ATTRS].copy()
+                dst[QCOM_QUANT_ATTRS][QCOM_SCALE] = src[QCOM_QUANT_ATTRS][
+                    QCOM_SCALE
+                ].clone()[slicer]
+                dst[QCOM_QUANT_ATTRS][QCOM_ZERO_POINT] = src[QCOM_QUANT_ATTRS][
+                    QCOM_ZERO_POINT
+                ].clone()[slicer]
+
+            if "example_value" in src:
+                dst["example_value"] = src["example_value"].clone()[slicer]
+
+            if QCOM_REQUANTIZE in src:
+                # We assume there is no requantize happens on the per-channel quantization weights, only per-tensor quantization
+                dst[QCOM_REQUANTIZE] = src[QCOM_REQUANTIZE].copy()
+
+        def _duplicate_call(
+            node, args: Tuple,
kwargs: Optional[dict] = None, slicer=None, name=None + ): + """Create SHA nodes by duplicating""" + assert ( + node.op == "call_function" + ), f"Unexpected node:{node.name}:{node.target}" + new_node = _create_call(node.target, args, kwargs, name=name) + _copy_meta(new_node, node, slicer) + return new_node + + def _split_placeholder(node, axis, size, count): + slice_op = exir_ops.edge.aten.slice_copy.Tensor + with graph_module.graph.inserting_after(node): + sliced_nodes = [] + for head, slicer in zip(range(count), _get_slicers(count, axis, size)): + sliced = _create_call( + slice_op, + (node, axis, slicer[axis].start, slicer[axis].stop), + name=f"{node.name}_h_{head}", + ) + _copy_meta(sliced, node, slicer) + sliced_nodes.append(sliced) + return sliced_nodes + + def _visit_linear_conv(node, sha): + """ + 0. Reshape of making multi-heads of MHA + - embedding = head * head_dim + - [batch, sequence, embedding] -> [batch, sequence, head, head_dim], + - [batch, sequence, embedding, 1] -> [batch, sequence, head, head_dim], embedding=head * head_dim + + 1. **q/k/v projections => stop recursion** + - 3D input and output + - Split output features + - ConvInplaceLinear + - [3d-unsqueeze-4d-permute-conv2d-permute-squeeze-3d] + - input: permute_copy(input): 4D[batch, in_feature, 1, num_input] => re-use + - weight[out_feature = heads * head_dim, in_feature, 1, 1] => heads * [head_dim, in_feature, 1, 1] + - So, split_axis=0 for Conv2D + + 2. **R3 of SpinQuant => continue recursion** + - 4D input and output + - ConvInplaceLinear + - [4d-permute-conv2d-permute-4d], **same as 3D case but no squeeze/unsqueeze** + - input: 4D [batch, head_dim, heads, num_input] => heads * [batch, head_dim, 1, num_input] + - weight: 2D [head_dim, head_dim, 1, 1] => re-use + """ + + def _is_making_mha(cur): + cur_sha = sha + pattern_conv_mha = ([_is_conv, "*", _is_permute, "*", _is_view], False) + if mha := find_pattern(cur, *pattern_conv_mha): + permute, reshape = mha[0][-3], mha[0][-1] + permutation = permute.args[1] + cur_sha = Sha( + permutation.index(sha.axis), sha.heads + ) # to reverse permute + else: + return False + + # Check whether this reshape is to make multi-heads or not + if len(reshape.args[1]) == 4: + # got MHA reshape + in_shape, out_shape = _shape(reshape.args[0]), _shape(reshape) + if ( + len(out_shape) > cur_sha.axis + 1 + and in_shape[cur_sha.axis] + == out_shape[cur_sha.axis] * out_shape[cur_sha.axis + 1] + ): + return True + return False + + if _is_making_mha(node): + if self.verbose: + logging.info( + f"{__name__}:_visit_linear_conv: {node} is making MHA!" 
+ ) + out_feature, *_ = _shape(node.args[1]) + assert out_feature % sha.heads == 0 + out_feature_per_head = out_feature // sha.heads + + split_axis = 0 + new_weights = _split_placeholder( + node.args[1], + axis=split_axis, + size=out_feature_per_head, + count=sha.heads, + ) + if node.args[2] is not None: + new_bias = _split_placeholder( + node.args[2], + axis=split_axis, + size=out_feature_per_head, + count=sha.heads, + ) + + with graph_module.graph.inserting_after(node): + new_nodes = [] + slicers = _get_slicers(sha.heads, 1, out_feature_per_head) + if node.args[2] is not None: + for head, (weight, bias, slicer) in enumerate( + zip(new_weights, new_bias, slicers) + ): + name = f"{node.name}_h_{head}" + sliced = _duplicate_call( + node, + (node.args[0], weight, bias) + node.args[3:], + None, + slicer, + name=name, + ) + new_nodes.append(sliced) + else: + for head, (weight, slicer) in enumerate( + zip(new_weights, slicers) + ): + name = f"{node.name}_h_{head}" + sliced = _duplicate_call( + node, + (node.args[0], weight) + node.args[2:], + None, + slicer, + name=name, + ) + new_nodes.append(sliced) + + return new_nodes + else: + return _visit_default(node, sha) + + def _concat_sha_nodes(node, sha): + """Concat sha nodes and replace old node""" + sha_nodes = visited[node] + with graph_module.graph.inserting_after(sha_nodes[0]): + cat = exir_ops.edge.aten.cat.default + name = f"{node.name}_sha_concat" + new_node = _create_call(cat, (sha_nodes, sha.axis), name=name) + new_node.meta = node.meta.copy() + fake_tensors = [n.meta["val"] for n in sha_nodes] + result_fake_tensor = torch.cat(fake_tensors, sha.axis) + new_node.meta["val"] = result_fake_tensor + node.replace_all_uses_with(new_node) + + def _visit(node, sha): + if not _is_node(node): + return [node for _ in range(sha.heads)] + + if node in visited: + return visited[node] + + visitors = { + "placeholder": _visit_placeholder, + exir_ops.edge.aten.expand_copy.default: _visit_expand, + exir_ops.edge.aten.view_copy.default: _visit_reshape, + exir_ops.edge.aten.permute_copy.default: _visit_permute, + exir_ops.edge.aten.convolution.default: _visit_linear_conv, + exir_ops.edge.aten.mm.default: _visit_linear_conv, + exir_ops.edge.aten.cat.default: _visit_cat, + exir_ops.edge.aten.add.Tensor: _visit_binary, + exir_ops.edge.aten.mul.Tensor: _visit_binary, + exir_ops.edge.aten.eq.Tensor: _no_split, + } + + target = node.target if _is_call(node) else node.op + visited[node] = visitors.get(target, _visit_default)(node, sha) + + if [user for user in node.users.keys() if _is_output(user)]: + _concat_sha_nodes(node, sha) + return visited[node] + + if self.verbose: + logging.info(f"{__name__}:_split: attn_output:{attn_output}, sha:{sha}!") + visited = {} + _visit(attn_output, sha) + opt_sha = Sha(axis=3, heads=sha.heads) + _concat_sha_nodes(attn_output, opt_sha) + for remove_node in remove_nodes: + assert _is_permute(remove_node) or _is_view( + remove_node + ), "The removed nodes must be either transpose or reshape" + rnode_input = remove_node.args[0] + for user in list(remove_node.users): + new_args = tuple( + rnode_input if arg is remove_node else arg for arg in user.args + ) + user.args = new_args + for remove_node in remove_nodes: + graph_module.graph.erase_node(remove_node) + + def call(self, graph_module: torch.fx.GraphModule): + modified = False + softmaxes = self._nodes( + graph_module, + [ + exir_ops.edge.aten._softmax.default, + exir_ops.edge.aten._safe_softmax.default, + ], + ) + for softmax in softmaxes: + attn_output, sha, remove_nodes = 
self._get_attention_output(softmax) + if not attn_output: + continue + + self._split(graph_module, attn_output, sha, remove_nodes) + modified = True + + if modified: + for node in graph_module.graph.nodes: + self._update_requantize_user(node) + graph_module.graph.eliminate_dead_code() + constant_prop_pass(self.edge_program) # need to fuse sha weights + graph_module.recompile() + graph_module.graph.lint() + + return PassResult(graph_module, modified=modified) diff --git a/backends/qualcomm/_passes/decompose_any.py b/backends/qualcomm/_passes/decompose_any.py index e92bf11dd18..0cb959ff77f 100644 --- a/backends/qualcomm/_passes/decompose_any.py +++ b/backends/qualcomm/_passes/decompose_any.py @@ -8,6 +8,8 @@ from executorch.exir import to_edge from executorch.exir.pass_base import ExportPass, PassResult +from .utils import merge_decomposed_graph + class Any(torch.nn.Module): def __init__(self, dim, keepdim): @@ -49,26 +51,12 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # remap is used to map original node values to new node values, # which ensures that reference to nodes are correctly updated in the new graph remap = {"x": node.args[0]} - - for decomposed_node in decomposed_module.graph.nodes: - # no need to copy existent 'output' - if decomposed_node.op == "output": - for user in node.users.copy(): - # remap - user.replace_input_with( - node, - remap[decomposed_node.args[0][0]], - ) - # no need to copy existent placeholders - elif decomposed_node.op == "placeholder": - # replace node map from string to graph node - remap[decomposed_node] = remap.pop(decomposed_node.name) - else: - remap[decomposed_node] = graph.node_copy( - decomposed_node, - arg_transform=lambda x, remap=remap: remap[x], - ) - + merge_decomposed_graph( + remap=remap, + target_node=node, + target_graph=graph, + decomposed_graph_module=decomposed_module, + ) graph.erase_node(node) graph.eliminate_dead_code() diff --git a/backends/qualcomm/_passes/decompose_binary_alpha.py b/backends/qualcomm/_passes/decompose_binary_alpha.py new file mode 100644 index 00000000000..df767f10ca9 --- /dev/null +++ b/backends/qualcomm/_passes/decompose_binary_alpha.py @@ -0,0 +1,61 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.exir.pass_base import ExportPass, PassResult + +from .utils import copy_meta + +decomp_set = {torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor} + + +class DecomposeBinaryAlpha(ExportPass): + """ + QNN does not support alpha parameter for add/sub. 
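+    torch.add(x, y, alpha=a) computes x + a * y, so alpha is folded into a mul on
+    the second operand (or directly into the value when that operand is a scalar).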
+ Decompose to mul + add / mul + sub + """ + + def __init__(self) -> None: + super().__init__() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + graph = graph_module.graph + for node in graph.nodes: + if ( + node.target in decomp_set + and "alpha" in node.kwargs + and node.kwargs["alpha"] != 1 + ): + alpha = node.kwargs["alpha"] + # Remove alpha from immutable dict + node.kwargs = {k: v for k, v in node.kwargs.items() if k != "alpha"} + input2_node = node.args[1] + # If input2 is constant, we can just multiply the value for optimization + if isinstance(input2_node, (int, float)): + arg_list = list(node.args) + arg_list[1] = input2_node * alpha + node.args = tuple(arg_list) + continue + with graph.inserting_before(node): + mul_op = torch.ops.aten.mul.Scalar + mul_node = graph.create_node( + "call_function", + mul_op, + ( + input2_node, + alpha, + ), + ) + mul_node.meta = copy_meta(node.meta) + node.replace_input_with(input2_node, mul_node) + node.args = ( + node.args[0], + mul_node, + ) + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/decompose_cdist.py b/backends/qualcomm/_passes/decompose_cdist.py index d18a0295ffb..a3c812bdc37 100644 --- a/backends/qualcomm/_passes/decompose_cdist.py +++ b/backends/qualcomm/_passes/decompose_cdist.py @@ -7,6 +7,8 @@ import torch from executorch.exir.pass_base import ExportPass, PassResult +from .utils import merge_decomposed_graph + class CDist(torch.nn.Module): def __init__(self): @@ -54,26 +56,12 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # remap is used to map original node values to new node values, # which ensures that reference to nodes are correctly updated in the new graph remap = {"x": node.args[0], "y": node.args[1]} - - for decomposed_node in decomposed_module.graph.nodes: - # no need to copy existent 'output' - if decomposed_node.op == "output": - for user in node.users.copy(): - # remap - user.replace_input_with( - node, - remap[decomposed_node.args[0][0]], - ) - # no need to copy existent placeholders - elif decomposed_node.op == "placeholder": - # replace node map from string to graph node - remap[decomposed_node] = remap.pop(decomposed_node.name) - else: - remap[decomposed_node] = graph.node_copy( - decomposed_node, - arg_transform=lambda x, remap=remap: remap[x], - ) - + merge_decomposed_graph( + remap=remap, + target_node=node, + target_graph=graph, + decomposed_graph_module=decomposed_module, + ) graph.erase_node(node) graph.eliminate_dead_code() diff --git a/backends/qualcomm/_passes/decompose_einsum.py b/backends/qualcomm/_passes/decompose_einsum.py index 046c1598311..464d989333f 100644 --- a/backends/qualcomm/_passes/decompose_einsum.py +++ b/backends/qualcomm/_passes/decompose_einsum.py @@ -8,7 +8,7 @@ from executorch.exir.pass_base import ExportPass, PassResult from torch.fx.experimental.proxy_tensor import make_fx -from .utils import copy_nn_module_stack +from .utils import merge_decomposed_graph class DecomposeEinsum(ExportPass): @@ -37,30 +37,13 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: for i, arg in enumerate(node.args[1]): remap[f"arg1_{i+1}"] = arg - for decomposed_node in decomposed_module.graph.nodes: - copy_nn_module_stack(node, decomposed_node) - # This is the arg[0] equation string, which is not required anymore after decomposition - if "arg0" in decomposed_node.name: - continue - - # no need to copy existent 'output' - if decomposed_node.op == "output": - 
for user in node.users.copy(): - # remap - user.replace_input_with( - node, - remap[decomposed_node.args[0][0]], - ) - # no need to copy existent placeholders - elif decomposed_node.op == "placeholder": - # replace node map from string to graph node - remap[decomposed_node] = remap.pop(decomposed_node.name) - else: - remap[decomposed_node] = graph.node_copy( - decomposed_node, - arg_transform=lambda x, remap=remap: remap[x], - ) - + merge_decomposed_graph( + remap=remap, + target_node=node, + target_graph=graph, + decomposed_graph_module=decomposed_module, + predicate=lambda decomp_node: "arg0" not in decomp_node.name, + ) graph.erase_node(node) graph.eliminate_dead_code() diff --git a/backends/qualcomm/_passes/decompose_floor_divide.py b/backends/qualcomm/_passes/decompose_floor_divide.py new file mode 100644 index 00000000000..f7de074259e --- /dev/null +++ b/backends/qualcomm/_passes/decompose_floor_divide.py @@ -0,0 +1,62 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.exir.pass_base import ExportPass, PassResult + +from .utils import merge_decomposed_graph + + +class FloorDivide(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + dtype = x.dtype + result = torch.div(x, y) + result = torch.floor(result) + return result.to(dtype) + + +class DecomposeFloorDivide(ExportPass): + """ + Decompose for math equivalent op. + Since QNN does not support floor_divide operations for int32 or int64 inputs, + it is necessary to decompose the operation into a division using floating-point precision, + followed by applying the floor function. + """ + + def __init__(self) -> None: + super().__init__() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + graph = graph_module.graph + for node in graph.nodes: + model = FloorDivide() + if ( + torch.ops.aten.floor_divide.default == node.target + and not torch.is_floating_point(node.meta["val"]) + ): + decomposed_module = torch.export.export( + model, + (node.args[0].meta["val"], node.args[1].meta["val"]), + strict=True, + ).module() + with graph.inserting_before(node): + # remap is used to map original node values to new node values, + # which ensures that reference to nodes are correctly updated in the new graph + remap = {"x": node.args[0], "y": node.args[1]} + merge_decomposed_graph( + remap=remap, + target_node=node, + target_graph=graph, + decomposed_graph_module=decomposed_module, + ) + graph.erase_node(node) + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/decompose_glu.py b/backends/qualcomm/_passes/decompose_glu.py new file mode 100644 index 00000000000..de363468799 --- /dev/null +++ b/backends/qualcomm/_passes/decompose_glu.py @@ -0,0 +1,55 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
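As a sanity check on the DecomposeFloorDivide rewrite above: dividing in floating point, flooring, and casting back to the input dtype reproduces integer floor_divide semantics, including rounding toward negative infinity. A small standalone sketch (illustration only, assuming a recent PyTorch where floor_divide performs true floor division):

import torch

a = torch.tensor([7, -7, 5, -5], dtype=torch.int32)
b = torch.tensor([2, 2, -3, -3], dtype=torch.int32)

# True division in float, then floor, then cast back to the input dtype,
# the same sequence the FloorDivide module above exports.
decomposed = torch.floor(torch.div(a, b)).to(a.dtype)
assert torch.equal(decomposed, torch.floor_divide(a, b))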
+ +import torch +from executorch.exir.pass_base import ExportPass, PassResult + +from .utils import merge_decomposed_graph + + +# this wrapper is required for IO name mapping with decomposed graph +class Glu(torch.nn.Module): + def __init__(self, dim=-1): + super().__init__() + self.glu = torch.nn.GLU(dim=dim) + + def forward(self, x): + return self.glu(x) + + +class DecomposeGlu(ExportPass): + """ + Decompose glu for quantization annotation to work properly. + """ + + def __init__(self) -> None: + super().__init__() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + graph = graph_module.graph + for node in graph.nodes: + if node.target == torch.ops.aten.glu.default: + ep = torch.export.export( + Glu(dim=-1 if len(node.args) < 2 else node.args[1]), + (node.args[0].meta["val"],), + ) + decomposed_module = ep.run_decompositions().graph_module + + with graph.inserting_before(node): + # remap is used to map original node values to new node values, + # which ensures that reference to nodes are correctly updated in the new graph + remap = {"x": node.args[0]} + merge_decomposed_graph( + remap=remap, + target_node=node, + target_graph=graph, + decomposed_graph_module=decomposed_module, + ) + graph.erase_node(node) + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/decompose_linalg_vector_norm.py b/backends/qualcomm/_passes/decompose_linalg_vector_norm.py index 993f088da12..94a5b10ba3f 100644 --- a/backends/qualcomm/_passes/decompose_linalg_vector_norm.py +++ b/backends/qualcomm/_passes/decompose_linalg_vector_norm.py @@ -8,7 +8,7 @@ from executorch.exir import to_edge from executorch.exir.pass_base import ExportPass, PassResult -from .utils import copy_nn_module_stack +from .utils import merge_decomposed_graph class LinalgVectorNorm(torch.nn.Module): @@ -62,27 +62,12 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # remap is used to map original node values to new node values, # which ensures that reference to nodes are correctly updated in the new graph remap = {"x": node.args[0]} - - for decomposed_node in decomposed_module.graph.nodes: - copy_nn_module_stack(node, decomposed_node) - # no need to copy existent 'output' - if decomposed_node.op == "output": - for user in node.users.copy(): - # remap - user.replace_input_with( - node, - remap[decomposed_node.args[0][0]], - ) - # no need to copy existent placeholders - elif decomposed_node.op == "placeholder": - # replace node map from string to graph node - remap[decomposed_node] = remap.pop(decomposed_node.name) - else: - remap[decomposed_node] = graph.node_copy( - decomposed_node, - arg_transform=lambda x, remap=remap: remap[x], - ) - + merge_decomposed_graph( + remap=remap, + target_node=node, + target_graph=graph, + decomposed_graph_module=decomposed_module, + ) graph.erase_node(node) graph.eliminate_dead_code() diff --git a/backends/qualcomm/_passes/decompose_maxpool3d.py b/backends/qualcomm/_passes/decompose_maxpool3d.py new file mode 100644 index 00000000000..f8e750ebc64 --- /dev/null +++ b/backends/qualcomm/_passes/decompose_maxpool3d.py @@ -0,0 +1,133 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
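The Glu wrapper above exists only so the decomposed graph's IO names can be remapped; functionally, GLU is a split along the chosen dim followed by a sigmoid gate. A minimal equivalence sketch (illustration only, not part of the patch):

import torch

x = torch.randn(2, 8)

# GLU halves the input along `dim` and gates the first half with the
# sigmoid of the second half.
a, b = x.chunk(2, dim=-1)
assert torch.allclose(torch.nn.functional.glu(x, dim=-1), a * torch.sigmoid(b))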
+import warnings +from typing import cast, List + +import torch +import torch.nn as nn +from executorch.exir import to_edge +from executorch.exir.pass_base import ExportPass, PassResult + +from .utils import merge_decomposed_graph + + +class ModelMaxPool3D(torch.nn.Module): + def __init__( + self, filter_size, stride, padding, dilation, return_indices, ceil_mode + ): + super().__init__() + + self.pool2d_hw = nn.MaxPool2d( + kernel_size=[1, filter_size[2]], # (H, W) part + stride=[1, stride[2]], + padding=[0, padding[2]], + return_indices=return_indices, + ceil_mode=ceil_mode, + ) + self.pool2d_dh = nn.MaxPool2d( + kernel_size=filter_size[:2], # (D, H) part + stride=stride[:2], + padding=padding[:2], + return_indices=return_indices, + ceil_mode=ceil_mode, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + N, C, D, H, W = x.shape + x_ = x.permute(0, 1, 4, 2, 3) + x1_1d = x_.reshape(N * C, W, D, H) + # first pool over (D, H) + out_pool1d_0 = self.pool2d_dh(x1_1d) + D_out = out_pool1d_0.shape[2] + # NC, W, D, H-> NC, D, H, W + x1b = out_pool1d_0.permute(0, 2, 3, 1) + # second pool over (H, W) + out4d = self.pool2d_hw(x1b) + H_out2 = out4d.shape[2] + W_out = out4d.shape[3] + out = out4d.reshape(N, C, D_out, H_out2, W_out) + return out + + +class DecomposeMaxPool3d(ExportPass): + # The max_pool3d is not supported yet by QNN. + # Decompose: input -> permute -> reshape -> max_pool2d -> permute -> max_pool2d -> reshape -> output + + def __init__(self, quantization_capture=False) -> None: + super().__init__() + self.quantization_capture = quantization_capture + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + graph = graph_module.graph + for node in graph.nodes: + if node.op == "call_function" and "max_pool3d" in str(node.target): + # kernel info + filter_size = cast(List[int], node.args[1]) + if len(filter_size) == 1: + filter_size *= 3 + + num_args = len(node.args) + + # stride info + stride = filter_size + if num_args > 2: + stride = cast(List[int], node.args[2]) + if len(stride) == 1: + stride *= 3 + + # padding info + padding = [0, 0, 0] + if num_args > 3: + padding = cast(List[int], node.args[3]) + if len(padding) == 1: + padding *= 3 + + # dilation info + dilation = [1, 1, 1] + if num_args > 4: + dilation = cast(List[int], node.args[4]) + if len(padding) == 1: + dilation *= 3 + + ceil_mode = node.args[5] if num_args > 5 else False + return_indices = node.args[6] if num_args > 6 else False + if return_indices: + warnings.warn( + "[QNN Delegate Op Builder]: The case return_indices=True is not be support, fallback", + stacklevel=1, + ) + return + + model = ModelMaxPool3D( + filter_size, stride, padding, dilation, return_indices, ceil_mode + ) + if self.quantization_capture: + decomposed_module = torch.export.export( + model, (node.args[0].meta["val"],), strict=True + ).module() + else: + edge_mgr = to_edge( + torch.export.export( + model, (node.args[0].meta["val"],), strict=True + ) + ) + decomposed_module = edge_mgr.exported_program() + + with graph.inserting_before(node): + # remap is used to map original node values to new node values, + # which ensures that reference to nodes are correctly updated in the new graph + remap = {"x": node.args[0]} + merge_decomposed_graph( + remap=remap, + target_node=node, + target_graph=graph, + decomposed_graph_module=decomposed_module, + ) + graph.erase_node(node) + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/decompose_roll.py 
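For context on the DecomposeMaxPool3d rewrite above, the two chained 2D pools reproduce a plain 3D max pool when dilation is 1: the first pool reduces (D, H) with W folded into the batch dimension, and the second reduces W. A standalone equivalence sketch for the simple kernel=stride=2 case (illustration only, not part of the patch):

import torch
import torch.nn as nn

x = torch.randn(1, 3, 8, 8, 8)
ref = nn.MaxPool3d(kernel_size=2, stride=2)(x)

# Same computation as two 2D pools, mirroring ModelMaxPool3D: first max over
# (D, H) with W folded into the batch, then max over W.
n, c, d, h, w = x.shape
t = x.permute(0, 1, 4, 2, 3).reshape(n * c, w, d, h)
t = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))(t)   # pools D, H
t = t.permute(0, 2, 3, 1)                                 # NC, D', H', W
t = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))(t)    # pools W
out = t.reshape(n, c, *t.shape[1:])
assert torch.allclose(ref, out)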
b/backends/qualcomm/_passes/decompose_roll.py index e13433508f5..e6f60d55464 100644 --- a/backends/qualcomm/_passes/decompose_roll.py +++ b/backends/qualcomm/_passes/decompose_roll.py @@ -7,7 +7,7 @@ from executorch.exir.pass_base import ExportPass, PassResult -from .utils import copy_nn_module_stack +from .utils import merge_decomposed_graph class SliceCopy(torch.nn.Module): @@ -65,27 +65,12 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # remap is used to map original node values to new node values, # which ensures that reference to nodes are correctly updated in the new graph remap = {"x": input_node} - - for decomposed_node in decomposed_module.graph.nodes: - copy_nn_module_stack(node, decomposed_node) - # no need to copy existent 'output' - if decomposed_node.op == "output": - for user in node.users.copy(): - # remap - user.replace_input_with( - node, - remap[decomposed_node.args[0][0]], - ) - # no need to copy existent placeholders - elif decomposed_node.op == "placeholder": - # replace node map from string to graph node - remap[decomposed_node] = remap.pop(decomposed_node.name) - else: - remap[decomposed_node] = graph.node_copy( - decomposed_node, - arg_transform=lambda x, remap=remap: remap[x], - ) - + merge_decomposed_graph( + remap=remap, + target_node=node, + target_graph=graph, + decomposed_graph_module=decomposed_module, + ) graph.erase_node(node) graph.eliminate_dead_code() diff --git a/backends/qualcomm/_passes/decompose_silu.py b/backends/qualcomm/_passes/decompose_silu.py index c3ac45a8d9d..4336b6e95a3 100644 --- a/backends/qualcomm/_passes/decompose_silu.py +++ b/backends/qualcomm/_passes/decompose_silu.py @@ -17,10 +17,10 @@ def __init__(self): def call(self, graph_module: torch.fx.GraphModule): graph = graph_module.graph for node in graph.nodes: - if ( - node.op == "call_function" - and node.target == torch.ops.aten.silu.default - ): + if node.op == "call_function" and node.target in { + torch.ops.aten.silu.default, + torch.ops.aten.silu_.default, + }: silu_node = node silu_node_input = node.args[0] with graph_module.graph.inserting_after(silu_node_input): diff --git a/backends/qualcomm/_passes/decompose_threshold.py b/backends/qualcomm/_passes/decompose_threshold.py new file mode 100644 index 00000000000..0f0a1bc4ea8 --- /dev/null +++ b/backends/qualcomm/_passes/decompose_threshold.py @@ -0,0 +1,61 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import torch + +from executorch.exir.pass_base import ExportPass, PassResult + +from .utils import merge_decomposed_graph + + +class DecomposeModule(torch.nn.Module): + def __init__(self, threshold, value): + super().__init__() + self.threshold = threshold + self.value = value + + def forward(self, x): + return torch.where(x <= self.threshold, self.value, x) + + +class DecomposeThreshold(ExportPass): + """ + Decompose threshold to less_equal and where. 
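Concretely, aten.threshold keeps elements strictly greater than the threshold and writes value elsewhere, which is exactly the where(x <= threshold, value, x) form used by DecomposeModule above. A quick standalone check (illustration only, not part of the patch):

import torch

x = torch.randn(6)
threshold, value = 0.2, -1.0

# threshold keeps x where x > threshold and replaces the rest with `value`,
# i.e. where(x <= threshold, value, x).
ref = torch.nn.functional.threshold(x, threshold, value)
assert torch.equal(ref, torch.where(x <= threshold, torch.tensor(value), x))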
+ """ + + def __init__(self) -> None: + super().__init__() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + graph = graph_module.graph + for node in graph.nodes: + if node.target in { + torch.ops.aten.threshold_.default, + torch.ops.aten.threshold.default, + }: + input_node = node.args[0] + threshold = node.args[1] + value = node.args[2] + + model = DecomposeModule(threshold, value) + decomposed_module = torch.export.export( + model, (input_node.meta["val"],), strict=True + ).module() + + with graph.inserting_before(node): + # remap is used to map original node values to new node values, + # which ensures that reference to nodes are correctly updated in the new graph + remap = {"x": input_node} + merge_decomposed_graph( + remap=remap, + target_node=node, + target_graph=graph, + decomposed_graph_module=decomposed_module, + ) + graph.erase_node(node) + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/decompose_triu.py b/backends/qualcomm/_passes/decompose_triu.py new file mode 100644 index 00000000000..cb0450a499d --- /dev/null +++ b/backends/qualcomm/_passes/decompose_triu.py @@ -0,0 +1,71 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict + +import torch +from executorch.exir.pass_base import ExportPass, PassResult +from torch._decomp import get_decompositions +from torch.fx.experimental.proxy_tensor import make_fx + +from .utils import merge_decomposed_graph + + +class DecomposeTriu(ExportPass): + """ + Decompose triu during quantization or export stage + This allows LiftConstantScalarOperands to lift the scalar into a scalar_tensor. + Otherwise, after to_edge, the triu operation will be decomposed into several operations that include aten.ge.Scalar. 
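For illustration, the decomposition the pass requests can be previewed directly with make_fx and the aten decomposition table; tracing triu this way replaces the single node with aten-level arithmetic and comparison ops, which the pass then merges back into the original graph. A standalone sketch (illustration only, not part of the patch):

import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

decomp = get_decompositions([torch.ops.aten.triu.default])

# Trace aten.triu through its decomposition; the resulting graph contains
# the lower-level ops the DecomposeTriu pass copies into the target graph.
gm = make_fx(torch.ops.aten.triu.default, decomposition_table=decomp)(
    torch.randn(4, 4)
)
print(gm.graph)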
+ """ + + def __init__(self) -> None: + super().__init__() + + def _replace_output( + self, node: torch.fx.Node, output_node: torch.fx.Node, remap: Dict + ): + for user in node.users.copy(): + # remap + user.replace_input_with( + node, + remap[output_node.args[0]], + ) + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + graph = graph_module.graph + decom_mappings = get_decompositions([torch.ops.aten.triu.default]) + + for node in graph.nodes: + if node.target == torch.ops.aten.triu.default: + input_args = [node.args[0].meta["val"]] + # The args[1], diagonal, is optional + if len(node.args) > 1: + input_args.append(node.args[1]) + decomposed_module = make_fx( + node.target, + decomposition_table=decom_mappings, + tracing_mode="fake", + )(*input_args) + + with graph.inserting_before(node): + # remap is used to map original node values to new node values, + # which ensures that reference to nodes are correctly updated in the new graph + remap = {} + remap["arg0_1"] = node.args[0] + + merge_decomposed_graph( + remap=remap, + target_node=node, + target_graph=graph, + decomposed_graph_module=decomposed_module, + predicate=lambda decomp_node: "arg1_1" not in decomp_node.name, + output_processor=self._replace_output, + ) + graph.erase_node(node) + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/decompose_wrap_with_autocast.py b/backends/qualcomm/_passes/decompose_wrap_with_autocast.py index 6c073bd309c..1b60b740ed3 100644 --- a/backends/qualcomm/_passes/decompose_wrap_with_autocast.py +++ b/backends/qualcomm/_passes/decompose_wrap_with_autocast.py @@ -10,7 +10,7 @@ import torch from executorch.exir.pass_base import ExportPass, PassResult -from .utils import copy_nn_module_stack +from .utils import merge_decomposed_graph class DecomposeWrapWithAutocast(ExportPass): @@ -52,7 +52,7 @@ def _replace(self, gm: torch.fx.GraphModule) -> None: graph = gm.graph for node in graph.nodes: if isinstance(node.target, torch._higher_order_ops.wrap.WrapWithAutocast): - submod, submod_name = self._get_submod(gm, node) + submod, _ = self._get_submod(gm, node) n_args = node.args input_submod = n_args[4] decomposed_module = submod @@ -61,22 +61,13 @@ def _replace(self, gm: torch.fx.GraphModule) -> None: # which ensures that reference to nodes are correctly updated in the new graph # remap = {"expand_1": node.args[5], "to_4": node.args[6]} remap = {n_args[i].name: n_args[i] for i in range(5, len(n_args))} - - for decomposed_node in decomposed_module.graph.nodes: - copy_nn_module_stack(node, decomposed_node) - # no need to copy existent 'output' - if decomposed_node.op == "output": - self._replace_output(node, decomposed_node, remap) - # no need to copy existent placeholders - elif decomposed_node.op == "placeholder": - # replace node map from string to graph node - remap[decomposed_node] = remap.pop(decomposed_node.name) - else: - remap[decomposed_node] = graph.node_copy( - decomposed_node, - arg_transform=lambda x, remap=remap: remap[x], - ) - + merge_decomposed_graph( + remap=remap, + target_node=node, + target_graph=graph, + decomposed_graph_module=decomposed_module, + output_processor=self._replace_output, + ) graph.erase_node(node) graph.erase_node(input_submod) diff --git a/backends/qualcomm/_passes/fixed_linear_keep_dim.py b/backends/qualcomm/_passes/fixed_linear_keep_dim.py index 19f5c631921..04c0f92cebf 100644 --- a/backends/qualcomm/_passes/fixed_linear_keep_dim.py +++ 
b/backends/qualcomm/_passes/fixed_linear_keep_dim.py @@ -5,10 +5,14 @@ # LICENSE file in the root directory of this source tree. import torch +from executorch.backends.qualcomm.builders.node_visitor import dq_ops +from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from executorch.exir.passes import dead_code_elimination_pass +from .utils import copy_meta, get_quant_attrs + class FixedLinearKeepDim(ExportPass): """ @@ -18,8 +22,12 @@ class FixedLinearKeepDim(ExportPass): view_copy = exir_ops.edge.aten.view_copy.default linear = exir_ops.edge.aten.linear.default - def __init__(self): + def __init__( + self, + edge_program: torch.export.ExportedProgram, + ): super(FixedLinearKeepDim, self).__init__() + self.edge_program = edge_program def _fixed_keep_dim(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: @@ -46,9 +54,15 @@ def _fixed_keep_dim(self, graph_module: torch.fx.GraphModule): ) # meta needs to be copied elementwisely for fake-tensor # to be updated correctly and not affect meta of input_node - for k, v in input_node.meta.items(): - squeeze_node.meta[k] = v + squeeze_node.meta = copy_meta(input_node.meta) squeeze_node.meta["val"] = input_tensor.reshape(squeeze_dim) + # if input_node is dequantize, we need to fetch encodings manually + # TODO: remove this when constant fold mechanism is introduced + if input_node.target in dq_ops: + squeeze_node.meta[QCOM_QUANT_ATTRS] = get_quant_attrs( + self.edge_program, input_node + ) + for user in input_users: if user == linear_node: user.replace_input_with(input_node, squeeze_node) @@ -66,8 +80,7 @@ def _fixed_keep_dim(self, graph_module: torch.fx.GraphModule): ) # meta needs to be copied elementwisely for fake-tensor # to be updated correctly and not affect meta of unsqueeze_node - for k, v in linear_node.meta.items(): - unsqueeze_node.meta[k] = v + unsqueeze_node.meta = copy_meta(linear_node.meta) # update linear node's shape linear_node.meta["val"] = linear_output.reshape( (squeeze_node.meta["val"].shape[0], linear_output.shape[-1]) diff --git a/backends/qualcomm/_passes/insert_reshape_for_reduce_ops.py b/backends/qualcomm/_passes/insert_reshape_for_reduce_ops.py new file mode 100644 index 00000000000..52f9546c28e --- /dev/null +++ b/backends/qualcomm/_passes/insert_reshape_for_reduce_ops.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.passes import dead_code_elimination_pass + + +class InsertReshapeForReduceOps(ExportPass): + """ + Rewrite `aten.argmax.default` with `dim=None` into + a reshape-to-1D followed by argmax(dim=0). + + PyTorch semantics: + torch.argmax(x, dim=None) -> flatten(x) then argmax along axis=0 + + QNN requires an explicit axis, so we insert the reshape. 
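A minimal check of that equivalence (illustration only, not part of the patch):

import torch

x = torch.randn(3, 4, 5)

# argmax/argmin with dim=None index into the flattened tensor, so an explicit
# reshape(-1) followed by a reduction along dim=0 gives the same result.
assert torch.equal(torch.argmax(x), torch.argmax(x.reshape(-1), dim=0))
assert torch.equal(torch.argmin(x), torch.argmin(x.reshape(-1), dim=0))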
+ """ + + def __init__(self): + super().__init__() + self.op_map = {torch.ops.aten.argmax.default, torch.ops.aten.argmin.default} + + def call(self, graph_module: torch.fx.GraphModule): + graph = graph_module.graph + modified = False + + for n in graph.nodes: + if n.target in self.op_map: + dim_arg = None if len(n.args) == 1 else n.args[1] + + if dim_arg is None: + inp = n.args[0] + + # Insert reshape before argmax + with graph.inserting_before(n): + reshape_node = graph.create_node( + "call_function", + torch.ops.aten.reshape.default, + (inp, [-1]), + {}, + ) + reshape_node.meta = dict(inp.meta) + if "val" in inp.meta: + reshape_node.meta["val"] = inp.meta["val"].reshape(-1) + + # Rewrite argmax: take reshape_node as input, set dim=0 + n.args = (reshape_node, 0, *n.args[2:]) + + modified = True + + if modified: + graph_module.recompile() + dead_code_elimination_pass(graph_module) + + return PassResult(graph_module, modified) diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py index f285eb79bfb..691ba1607ff 100644 --- a/backends/qualcomm/_passes/layout_transform.py +++ b/backends/qualcomm/_passes/layout_transform.py @@ -42,8 +42,13 @@ class LayoutTransform(ExportPass): layout_sensitive_ops = { exir_ops.edge.aten.adaptive_avg_pool2d.default, + exir_ops.edge.aten._adaptive_avg_pool3d.default, + exir_ops.edge.aten.adaptive_max_pool2d.default, exir_ops.edge.aten.avg_pool2d.default, + exir_ops.edge.aten.avg_pool3d.default, exir_ops.edge.aten.convolution.default, + exir_ops.edge.aten.grid_sampler_2d.default, + exir_ops.edge.aten.grid_sampler_3d.default, exir_ops.edge.aten.instance_norm.default, exir_ops.edge.aten.max_pool2d_with_indices.default, exir_ops.edge.aten._native_batch_norm_legit_no_training.default, @@ -93,6 +98,7 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten.le.Tensor, exir_ops.edge.aten.linear.default, exir_ops.edge.aten.log.default, + exir_ops.edge.aten.logical_and.default, exir_ops.edge.aten.logical_not.default, exir_ops.edge.aten.lt.Scalar, exir_ops.edge.aten.lt.Tensor, diff --git a/backends/qualcomm/_passes/lift_constant_scalar_operands.py b/backends/qualcomm/_passes/lift_constant_scalar_operands.py index f5c5915cab2..e5d9371709d 100644 --- a/backends/qualcomm/_passes/lift_constant_scalar_operands.py +++ b/backends/qualcomm/_passes/lift_constant_scalar_operands.py @@ -51,9 +51,11 @@ class TensorOpInfo: # The scalar number arg[1] is missing when using default. 
Result in a corner case to deal aten.leaky_relu.default: TensorOpInfo(aten.prelu.default, True, False), aten.leaky_relu_.default: TensorOpInfo(aten.prelu.default, True, False), + aten.where.ScalarSelf: TensorOpInfo(aten.where.self, False, True), aten.where.ScalarOther: TensorOpInfo(aten.where.self, False, True), aten.where.Scalar: TensorOpInfo(aten.where.self, False, True), aten.masked_fill.Scalar: TensorOpInfo(aten.masked_fill.Tensor, False, False), + aten.masked_fill_.Scalar: TensorOpInfo(aten.masked_fill.Tensor, False, False), aten.bitwise_xor.Scalar: TensorOpInfo(aten.bitwise_xor.Tensor, False, False), } diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index ffb9f3221df..46a1dfb0970 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -16,16 +16,23 @@ CanonicalizeConv, ConvertBmmToMatmul, ConvertLinearToConv2d, + ConvertMhaToSha, ConvertSquareToPow, DecomposeAny, + DecomposeBinaryAlpha, DecomposeCDist, DecomposeColIm, DecomposeEinsum, DecomposeExpM1, + DecomposeFloorDivide, + DecomposeGlu, DecomposeLinalgVectorNorm, + DecomposeMaxPool3d, DecomposeMinMaxDim, DecomposeRoll, DecomposeSilu, + DecomposeThreshold, + DecomposeTriu, DecomposeWrapWithAutocast, ExpandBroadcastTensorShape, FixedLinearKeepDim, @@ -35,6 +42,7 @@ I64toI32, InsertIOQDQ, InsertRequantize, + InsertReshapeForReduceOps, LayoutTransform, LiftConstantScalarOperands, RecomposePixelUnshuffle, @@ -82,16 +90,16 @@ def get_capture_program_passes(): (AnnotateQuantAttrs, True), (AnnotateStack, True), (AnnotateUnbind, True), - (CanonicalizeConv, True), (ConvertBmmToMatmul, False), (DecomposeAny, True), (DecomposeColIm, True), (DecomposeMinMaxDim, True), - (ExpandBroadcastTensorShape, False), + (ExpandBroadcastTensorShape, True), (FixedLinearKeepDim, True), (FoldQDQ, True), (I64toI32, True), (LayoutTransform, True), + (DecomposeMaxPool3d, True), (RecomposePixelUnshuffle, True), (RecomposeRmsNorm, True), (Remove0DTensor, True), @@ -193,26 +201,40 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(RecomposePixelUnshuffle(quantization_capture=True)) self.add_pass(RecomposeRmsNorm(quantization_capture=True)) self.add_pass(ReplaceArangeArgs()) + self.add_pass(DecomposeBinaryAlpha()) self.add_pass(DecomposeCDist()) + self.add_pass(DecomposeMaxPool3d(quantization_capture=True)) self.add_pass(DecomposeScaledDotProductAttention()) self.add_pass(DecomposeRoll()) self.add_pass(DecomposeSilu()) + self.add_pass(DecomposeThreshold()) + self.add_pass(DecomposeTriu()) self.add_pass(DecomposeWrapWithAutocast()) self.add_pass(DecomposeEinsum()) self.add_pass(DecomposeExpM1()) + self.add_pass(DecomposeGlu()) self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True)) self.add_pass(ReplaceInfValues()) self.add_pass(LiftConstantScalarOperands()) + self.add_pass(InsertReshapeForReduceOps()) return self._transform(graph_module) def transform_for_export_pipeline( self, exported_program: ExportedProgram, convert_linear_to_conv2d: bool = False ): + self.add_pass(DecomposeBinaryAlpha()) self.add_pass(DecomposeCDist()) self.add_pass(DecomposeScaledDotProductAttention()) self.add_pass(DecomposeRoll()) + self.add_pass(DecomposeThreshold()) + self.add_pass(DecomposeTriu()) self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True)) self.add_pass(DecomposeExpM1()) + # DecomposeFloorDivide does not apply to the annotation pipeline, + # since the CPU QDQ model would reduce accuracy. 
+ # We keep div and floor operations in floating-point to maintain precision. + # This pass is needed before to_edge pipeline to avoid mixed type for div operator with RemoveMixedTypeOperators pass. + self.add_pass(DecomposeFloorDivide()) self.add_pass(DecomposeWrapWithAutocast()) # this pass will rewrite state_dict, it needs to be accomplished before # to_edge_transform_and_lower @@ -221,12 +243,17 @@ def transform_for_export_pipeline( self.add_pass(ConvertLinearToConv2d(exported_program)) self.add_pass(ConvertSquareToPow()) self.add_pass(LiftConstantScalarOperands()) + self.add_pass(InsertReshapeForReduceOps()) self._transform(exported_program.graph_module) ep = lift_constant_tensor_pass(exported_program) return ep - def transform_for_preprocess_pipeline(self, exported_program: ExportedProgram): + def transform_for_preprocess_pipeline( + self, exported_program: ExportedProgram, use_mha2sha=False + ): self.add_pass(FoldQDQ(exported_program, force_fold=True)) + if use_mha2sha: + self.add_pass(ConvertMhaToSha(exported_program)) self.add_pass(InsertRequantize()) self.add_pass(InsertIOQDQ(exported_program)) self.add_pass(LayoutTransform(exported_program, insert_permute=True)) diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index 6d908707892..e395511e438 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -69,6 +69,7 @@ def get_passes_dependency_for_capture_program(): DecomposeAny, DecomposeColIm, DecomposeLinalgVectorNorm, + DecomposeMaxPool3d, ExpandBroadcastTensorShape, FixedLinearKeepDim, FoldQDQ, @@ -93,6 +94,7 @@ def get_passes_dependency_for_capture_program(): DecomposeAny: [RemoveRedundancy], DecomposeColIm: [FoldQDQ], DecomposeLinalgVectorNorm: [RemoveRedundancy], + DecomposeMaxPool3d: [RemoveRedundancy], ExpandBroadcastTensorShape: [FoldQDQ], FixedLinearKeepDim: [FoldQDQ], FoldQDQ: [AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind], @@ -117,6 +119,45 @@ def copy_nn_module_stack(src, target): target.meta["nn_module_stack"] = value +def merge_decomposed_graph( + remap: Dict[str, torch.fx.Node], + target_node: torch.fx.Node, + target_graph: torch.fx.GraphModule, + decomposed_graph_module: torch.fx.GraphModule, + predicate: Callable[[torch.fx.Node], None] = None, + # target_node, decomposed_output_node, remap + output_processor: Callable[ + [torch.fx.Node, torch.fx.Node, Dict[str, torch.fx.Node]], None + ] = None, +) -> None: + def default_output_process(node): + for user in node.users.copy(): + # remap + user.replace_input_with( + node, + remap[decomposed_node.args[0][0]], + ) + + for decomposed_node in decomposed_graph_module.graph.nodes: + copy_nn_module_stack(target_node, decomposed_node) + if predicate is None or predicate(decomposed_node): + # no need to copy existent 'output' + if decomposed_node.op == "output": + if output_processor is None: + default_output_process(target_node) + else: + output_processor(target_node, decomposed_node, remap) + # no need to copy existent placeholders + elif decomposed_node.op == "placeholder": + # replace node map from string to graph node + remap[decomposed_node] = remap.pop(decomposed_node.name) + else: + remap[decomposed_node] = target_graph.node_copy( + decomposed_node, + arg_transform=lambda x, remap=remap: remap[x], + ) + + def is_float_tensor(node: torch.fx.Node) -> bool: if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor): return False @@ -138,7 +179,7 @@ def _next(node, from_args=True): yield from list(node.users) -def _find_pattern( +def 
find_pattern( node: torch.fx.Node, pattern: List[Callable[[torch.fx.Node], bool] | str], from_args: bool = True, @@ -151,6 +192,7 @@ def _find_pattern( - pattern: predicate list, can contain followings Callable(fx.node): predicate '*': wildcard + '?': any single node - from_args: if True find from node.args, otherwise from node.users - max_wildcard_life: max number of skips for wildcard @@ -158,7 +200,7 @@ def _find_pattern( Otherwise, return list of matched node list, which is the same length as pattern """ - asterisk = "*" + asterisk, question = "*", "?" def _probe( cur, hist, pat_idx, asterisk_life_count=max_wildcard_life, verbose=verbose @@ -173,7 +215,7 @@ def _probe( print( f"cur:{cur}, idx:{pat_idx}, life={asterisk_life_count}, pattern:{pattern[pat_idx]} hist={hist}" ) - if _pred(cur, pattern[pat_idx]): + if pattern[pat_idx] == question or _pred(cur, pattern[pat_idx]): hist.append(cur) for child in _next(cur, from_args): _probe(child, hist, pat_idx + 1) @@ -197,7 +239,8 @@ def _probe( # Check if pattern is valid assert all( - isinstance(i, Callable) or (isinstance(i, str) and i == "*") for i in pattern + isinstance(i, Callable) or (isinstance(i, str) and (i == "*" or i == "?")) + for i in pattern ), f"Invalid pattern: {pattern}" # Start probing @@ -210,7 +253,7 @@ def find_patterns(node, patterns, **kwargs): assert isinstance(patterns, list) and isinstance(patterns[0], list) results = [] for pattern in patterns: - result = _find_pattern(node, pattern, **kwargs) + result = find_pattern(node, pattern, **kwargs) results.append(result) return results diff --git a/backends/qualcomm/aot/python/CMakeLists.txt b/backends/qualcomm/aot/python/CMakeLists.txt index 337cfae1776..f84d7f01d86 100644 --- a/backends/qualcomm/aot/python/CMakeLists.txt +++ b/backends/qualcomm/aot/python/CMakeLists.txt @@ -9,9 +9,3 @@ target_sources( PyQnnManagerAdaptor PUBLIC ${CMAKE_CURRENT_LIST_DIR}/PyQnnManagerAdaptor.cpp ${CMAKE_CURRENT_LIST_DIR}/PyQnnManagerAdaptor.h ) - -# PyQnnWrapperAdaptor -target_sources( - PyQnnWrapperAdaptor PUBLIC ${CMAKE_CURRENT_LIST_DIR}/PyQnnWrapperAdaptor.cpp - ${CMAKE_CURRENT_LIST_DIR}/PyQnnWrapperAdaptor.h -) diff --git a/backends/qualcomm/aot/python/PyQnnManagerAdaptor.cpp b/backends/qualcomm/aot/python/PyQnnManagerAdaptor.cpp index 2511cd96636..1f45f062cfb 100644 --- a/backends/qualcomm/aot/python/PyQnnManagerAdaptor.cpp +++ b/backends/qualcomm/aot/python/PyQnnManagerAdaptor.cpp @@ -6,7 +6,11 @@ * LICENSE file in the root directory of this source tree. 
*/ #include +#include +#include +#include #include +#include #include "QnnSdkBuildId.h" namespace py = pybind11; @@ -15,6 +19,133 @@ namespace backends { namespace qnn { using executorch::runtime::Error; +std::unique_ptr CreateQuantizationParamWrapper( + const Qnn_QuantizationEncoding_t& encoding, + py::dict& quant_info) { + std::unique_ptr quantize_param_wrapper; + if (encoding == QNN_QUANTIZATION_ENCODING_UNDEFINED) { + quantize_param_wrapper = std::make_unique(); + } else if (encoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { + int32_t axis = quant_info["axis"].cast(); + std::vector scale_offset = + quant_info["scale_offset"].cast>(); + + quantize_param_wrapper = + std::make_unique( + axis, scale_offset); + } else if (encoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET) { + uint32_t bitwidth = quant_info["bitwidth"].cast(); + int32_t axis = quant_info["axis"].cast(); + std::vector scale_offset = + quant_info["scale_offset"].cast>(); + uint32_t num_elements = scale_offset.size(); + std::vector scales; + std::vector offsets; + for (const auto& scale_offset : scale_offset) { + scales.push_back(scale_offset.scale); + offsets.push_back(scale_offset.offset); + } + quantize_param_wrapper = + std::make_unique( + bitwidth, axis, num_elements, scales, offsets); + } else if (encoding == QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET) { + uint32_t bitwidth = quant_info["bitwidth"].cast(); + float scale = quant_info["scale"].cast(); + int32_t offset = quant_info["offset"].cast(); + quantize_param_wrapper = + std::make_unique( + bitwidth, scale, offset); + } else if (encoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { + float scale = quant_info["scale"].cast(); + int32_t offset = quant_info["offset"].cast(); + quantize_param_wrapper = + std::make_unique(scale, offset); + } else if (encoding == QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION) { + int32_t axis = quant_info["axis"].cast(); + std::vector scale_offset = + quant_info["block_scale_offset"].cast>(); + uint32_t num_blocks_per_axis = + quant_info["num_blocks_per_axis"].cast(); + uint32_t block_scale_bitwidth = + quant_info["block_scale_bitwidth"].cast(); + Qnn_BlockwiseExpansionBlockScaleStorageType_t block_storage_type = + quant_info["block_storage_type"] + .cast(); + std::vector buf = + quant_info["block_scales"].cast>(); + quantize_param_wrapper = + std::make_unique( + axis, + scale_offset, + num_blocks_per_axis, + block_scale_bitwidth, + block_storage_type, + buf.data(), + buf.size()); + } else { + QNN_EXECUTORCH_LOG_ERROR( + "Unknown the encoding of quantization: %d", encoding); + } + return quantize_param_wrapper; +} + +std::string GetScalarValue(const Qnn_Scalar_t& scalar) { + switch (scalar.dataType) { + case QNN_DATATYPE_FLOAT_32: + return std::to_string(scalar.floatValue); + case QNN_DATATYPE_FLOAT_64: + return std::to_string(scalar.doubleValue); + case QNN_DATATYPE_UINT_64: + return std::to_string(scalar.uint64Value); + case QNN_DATATYPE_INT_64: + return std::to_string(scalar.int64Value); + case QNN_DATATYPE_UINT_32: + return std::to_string(scalar.uint32Value); + case QNN_DATATYPE_INT_32: + return std::to_string(scalar.int32Value); + case QNN_DATATYPE_UINT_16: + return std::to_string(scalar.uint16Value); + case QNN_DATATYPE_INT_16: + return std::to_string(scalar.int16Value); + case QNN_DATATYPE_UINT_8: + return std::to_string(scalar.uint8Value); + case QNN_DATATYPE_INT_8: + return std::to_string(scalar.int8Value); + case QNN_DATATYPE_BOOL_8: + return std::to_string(static_cast(scalar.bool8Value)); + case 
QNN_DATATYPE_STRING: + return std::string(scalar.stringValue); + default: + return "QNN_DATATYPE_UNDEFINED"; + } +} + +std::shared_ptr CreateTensorWrapper( + const std::string& tensor_name, + Qnn_TensorType_t tensor_type, + Qnn_DataType_t data_type, + const Qnn_QuantizationEncoding_t& encoding, + py::dict& quant_info, + std::uint32_t rank, + const std::vector& dims, + const std::vector& dynamic_dims, + py::array& data, + bool copy_data) { + std::unique_ptr quantize_param_wrapper = + CreateQuantizationParamWrapper(encoding, quant_info); + + return CreateTensorWrapper( + tensor_name, + tensor_type, + data_type, + std::move(quantize_param_wrapper), + rank, + dims.data(), + dynamic_dims.data(), + 0, + data.size() == 0 ? nullptr : data.data(), + copy_data); +} std::string GetQnnSdkBuildId(std::string library_path) { QnnImplementation qnn_loaded_backend = QnnImplementation(library_path); @@ -28,15 +159,40 @@ std::string GetQnnSdkBuildId(std::string library_path) { if (err != QNN_SUCCESS || id == nullptr) { throw std::runtime_error("Failed to get QNN backend build ID"); } - qnn_loaded_backend.TerminateAllBackends(); + qnn_loaded_backend.Unload(); return std::string(id); } +py::array_t StripProtocol(const py::bytes& preprocessed_binary) { + py::buffer_info info(py::buffer(preprocessed_binary).request()); + + void* buf_ptr = nullptr; + size_t buf_size = 0; + // check if it's a qnn context binary + auto [status, signature, ctx_size, ctx_bin] = + QnnContextCustomProtocol().DeserializeContextCustomBuffer(info.ptr); + + if (status == Error::Ok) { + buf_size = ctx_size; + buf_ptr = ctx_bin; + } else { + // the format should be DLC, return nothing here + return py::array_t(0); + } + + auto result = py::array_t(buf_size); + auto result_buffer = result.request(); + std::memcpy(result_buffer.ptr, buf_ptr, buf_size); + return result; +} + PYBIND11_MODULE(PyQnnManagerAdaptor, m) { // TODO: Add related documents for configurations listed below using namespace qnn_delegate; + PYBIND11_NUMPY_DTYPE(PyQnnTensorWrapper::EncodingData, scale, offset); m.def("GetQnnSdkBuildId", &GetQnnSdkBuildId); + m.def("StripProtocol", &StripProtocol); py::class_(m, "QnnExecuTorchContextBinary") .def(py::init<>()); @@ -49,6 +205,8 @@ PYBIND11_MODULE(PyQnnManagerAdaptor, m) { .def(py::init()) .def(py::init()) .def("Init", &PyQnnManager::Init) + .def("InitBackend", &PyQnnManager::InitBackend) + .def("InitContext", &PyQnnManager::InitContext) .def("IsNodeSupportedByBackend", &PyQnnManager::IsNodeSupportedByBackend) .def( "Compile", @@ -57,6 +215,7 @@ PYBIND11_MODULE(PyQnnManagerAdaptor, m) { std::vector>>&>( &PyQnnManager::Compile)) .def("Destroy", &PyQnnManager::Destroy) + .def("DestroyContext", &PyQnnManager::DestroyContext) .def("IsAvailable", &PyQnnManager::IsAvailable) .def("IsTensorDump", &PyQnnManager::IsTensorDump) .def("AllocateTensor", &PyQnnManager::AllocateTensor) @@ -66,8 +225,377 @@ PYBIND11_MODULE(PyQnnManagerAdaptor, m) { .def("GetSpillFillBufferSize", &PyQnnManager::GetSpillFillBufferSize) .def( "MakeBinaryInfo", - py::overload_cast(&PyQnnManager::MakeBinaryInfo)) - .def("StripProtocol", &PyQnnManager::StripProtocol); + py::overload_cast(&PyQnnManager::MakeBinaryInfo)); + + py::enum_(m, "Qnn_TensorType_t") + .value( + "QNN_TENSOR_TYPE_APP_WRITE", + Qnn_TensorType_t::QNN_TENSOR_TYPE_APP_WRITE) + .value( + "QNN_TENSOR_TYPE_APP_READ", + Qnn_TensorType_t::QNN_TENSOR_TYPE_APP_READ) + .value( + "QNN_TENSOR_TYPE_APP_READWRITE", + Qnn_TensorType_t::QNN_TENSOR_TYPE_APP_READWRITE) + .value("QNN_TENSOR_TYPE_NATIVE", 
Qnn_TensorType_t::QNN_TENSOR_TYPE_NATIVE) + .value("QNN_TENSOR_TYPE_STATIC", Qnn_TensorType_t::QNN_TENSOR_TYPE_STATIC) + .value("QNN_TENSOR_TYPE_NULL", Qnn_TensorType_t::QNN_TENSOR_TYPE_NULL) + .value( + "QNN_TENSOR_TYPE_UNDEFINED", + Qnn_TensorType_t::QNN_TENSOR_TYPE_UNDEFINED) + .export_values(); + + py::enum_(m, "Qnn_DataType_t") + .value("QNN_DATATYPE_INT_8", Qnn_DataType_t::QNN_DATATYPE_INT_8) + .value("QNN_DATATYPE_INT_16", Qnn_DataType_t::QNN_DATATYPE_INT_16) + .value("QNN_DATATYPE_INT_32", Qnn_DataType_t::QNN_DATATYPE_INT_32) + .value("QNN_DATATYPE_INT_64", Qnn_DataType_t::QNN_DATATYPE_INT_64) + .value("QNN_DATATYPE_UINT_8", Qnn_DataType_t::QNN_DATATYPE_UINT_8) + .value("QNN_DATATYPE_UINT_16", Qnn_DataType_t::QNN_DATATYPE_UINT_16) + .value("QNN_DATATYPE_UINT_32", Qnn_DataType_t::QNN_DATATYPE_UINT_32) + .value("QNN_DATATYPE_UINT_64", Qnn_DataType_t::QNN_DATATYPE_UINT_64) + .value("QNN_DATATYPE_FLOAT_16", Qnn_DataType_t::QNN_DATATYPE_FLOAT_16) + .value("QNN_DATATYPE_FLOAT_32", Qnn_DataType_t::QNN_DATATYPE_FLOAT_32) + .value( + "QNN_DATATYPE_SFIXED_POINT_8", + Qnn_DataType_t::QNN_DATATYPE_SFIXED_POINT_8) + .value( + "QNN_DATATYPE_SFIXED_POINT_16", + Qnn_DataType_t::QNN_DATATYPE_SFIXED_POINT_16) + .value( + "QNN_DATATYPE_SFIXED_POINT_32", + Qnn_DataType_t::QNN_DATATYPE_SFIXED_POINT_32) + .value( + "QNN_DATATYPE_UFIXED_POINT_8", + Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_8) + .value( + "QNN_DATATYPE_UFIXED_POINT_16", + Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_16) + .value( + "QNN_DATATYPE_UFIXED_POINT_32", + Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_32) + .value("QNN_DATATYPE_BOOL_8", Qnn_DataType_t::QNN_DATATYPE_BOOL_8) + .value("QNN_DATATYPE_UNDEFINED", Qnn_DataType_t::QNN_DATATYPE_UNDEFINED) + .export_values(); + + py::enum_(m, "Qnn_QuantizationEncoding_t") + .value( + "QNN_QUANTIZATION_ENCODING_SCALE_OFFSET", + Qnn_QuantizationEncoding_t::QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) + .value( + "QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET", + Qnn_QuantizationEncoding_t:: + QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) + .value( + "QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET", + Qnn_QuantizationEncoding_t::QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET) + .value( + "QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET", + Qnn_QuantizationEncoding_t:: + QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET) + .value( + "QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION", + Qnn_QuantizationEncoding_t:: + QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION) + .value( + "QNN_QUANTIZATION_ENCODING_UNDEFINED", + Qnn_QuantizationEncoding_t::QNN_QUANTIZATION_ENCODING_UNDEFINED) + .export_values(); + + py::enum_( + m, "Qnn_BlockwiseExpansionBlockScaleStorageType_t") + .value( + "QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8", + Qnn_BlockwiseExpansionBlockScaleStorageType_t:: + QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8) + .value( + "QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_16", + Qnn_BlockwiseExpansionBlockScaleStorageType_t:: + QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_16) + .value( + "QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_UNDEFINED", + Qnn_BlockwiseExpansionBlockScaleStorageType_t:: + QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_UNDEFINED) + .export_values(); + + py::class_>(m, "OpWrapper") + .def(py::init< + const std::string&, + const std::string&, + const std::string&>()) + .def( + "GetInputTensors", + &OpWrapper::GetInputTensors, + "A function which gets input tensors") + .def( + "GetOutputTensors", + &OpWrapper::GetOutputTensors, + "A function which gets output tensors") + 
.def("GetOpType", &OpWrapper::GetOpType, "A function which gets op type") + .def("GetName", &OpWrapper::GetName, "A function which gets name") + .def( + "GetPackageName", + &OpWrapper::GetPackageName, + "A function which gets package name") + .def( + "GetParams", &OpWrapper::GetRawParams, "A function which gets params") + // lambda function + // python: op_wrapper.GetOpConfig() + .def( + "GetOpConfig", + [](OpWrapper& self) { + auto op_config = self.GetOpConfig(); + py::dict result; + py::list params_list; + py::list input_tensors_list; + py::list output_tensors_list; + result["version"] = op_config.version; + result["name"] = QNN_OP_VER_PTR(op_config)->name; + result["packageName"] = QNN_OP_VER_PTR(op_config)->packageName; + result["typeName"] = QNN_OP_VER_PTR(op_config)->typeName; + result["numOfParams"] = QNN_OP_VER_PTR(op_config)->numOfParams; + for (size_t i = 0; i < QNN_OP_VER_PTR(op_config)->numOfParams; + ++i) { + params_list.append(QNN_OP_VER_PTR(op_config)->params[i]); + } + result["params"] = params_list; + result["numOfInputs"] = QNN_OP_VER_PTR(op_config)->numOfInputs; + for (size_t i = 0; i < QNN_OP_VER_PTR(op_config)->numOfInputs; + ++i) { + input_tensors_list.append( + QNN_OP_VER_PTR(op_config)->inputTensors[i]); + } + result["inputTensors"] = input_tensors_list; + result["numOfOutputs"] = QNN_OP_VER_PTR(op_config)->numOfOutputs; + for (size_t i = 0; i < QNN_OP_VER_PTR(op_config)->numOfOutputs; + ++i) { + output_tensors_list.append( + QNN_OP_VER_PTR(op_config)->outputTensors[i]); + } + result["outputTensors"] = output_tensors_list; + return result; + }, + "Get operator configuration"); + + py::class_>(m, "TensorWrapper") + .def(py::init(py::overload_cast< + const std::string&, + Qnn_TensorType_t, + Qnn_DataType_t, + const Qnn_QuantizationEncoding_t&, + py::dict&, + std::uint32_t, + const std::vector&, + const std::vector&, + py::array&, + bool>(&CreateTensorWrapper))); + + py::class_(m, "QuantizeParamsWrapper"); + + py::class_(m, "Qnn_ScaleOffset_t") + .def(py::init()) + .def_readonly("scale", &Qnn_ScaleOffset_t::scale) + .def_readonly("offset", &Qnn_ScaleOffset_t::offset); + + py::class_>( + m, "PyQnnOpWrapper") + .def(py::init< + const std::string&, + const std::string&, + const std::string&>()) + .def( + "AddInputTensors", + &PyQnnOpWrapper::AddInputTensors, + "A function which add input tensor wrapper into op wrapper", + py::arg("tensors")) + .def( + "AddOutputTensors", + &PyQnnOpWrapper::AddOutputTensors, + "A function which add output tensor wrapper into op wrapper", + py::arg("tensors")) + .def( + "AddTensorParam", + &PyQnnOpWrapper::AddTensorParam, + "A function which add tensor parameter into op wrapper", + py::arg("name"), + py::arg("data_type"), + py::arg("rank"), + py::arg("dims"), + py::arg("data"), + py::arg("copy_data")) + .def( + "AddScalarParam", + &PyQnnOpWrapper::AddScalarParam, + "A function which add scalar parameter into op wrapper", + py::arg("name"), + py::arg("data_type"), + py::arg("attrData")) + .def( + "GetOpWrapper", + &PyQnnOpWrapper::GetOpWrapper, + "A function which get op wrapper"); + + py::class_(m, "Encoding") + .def_readonly("data", &PyQnnTensorWrapper::Encoding::data) + .def_readonly("axis", &PyQnnTensorWrapper::Encoding::axis); + + py::class_>( + m, "PyQnnTensorWrapper") + .def(py::init&>()) + .def("GetDims", &PyQnnTensorWrapper::GetDims) + .def("GetDataType", &PyQnnTensorWrapper::GetDataType) + .def("GetName", &PyQnnTensorWrapper::GetName) + .def("GetEncodings", &PyQnnTensorWrapper::GetEncodings); + + py::class_(m, "Qnn_OpConfig") + 
.def_readonly("version", &Qnn_OpConfig_t::version) + // getter + // python: op_wrapper.GetOpConfig().v1 + .def_property_readonly( + "v1", [](const Qnn_OpConfig_t& config) -> const Qnn_OpConfigV1_t& { + return config.v1; + }); + + py::enum_(m, "Qnn_OpConfigVersion") + .value("QNN_OPCONFIG_VERSION_1", QNN_OPCONFIG_VERSION_1) + .value("QNN_OPCONFIG_VERSION_UNDEFINED", QNN_OPCONFIG_VERSION_UNDEFINED) + .export_values(); + + py::class_(m, "Qnn_OpConfigV1") + .def_readonly("name", &Qnn_OpConfigV1_t::name) + .def_readonly("packageName", &Qnn_OpConfigV1_t::packageName) + .def_readonly("typeName", &Qnn_OpConfigV1_t::typeName) + .def_readonly("numOfParams", &Qnn_OpConfigV1_t::numOfParams) + .def_readonly("params", &Qnn_OpConfigV1_t::params) + .def_readonly("numOfInputs", &Qnn_OpConfigV1_t::numOfInputs) + .def_readonly("inputTensors", &Qnn_OpConfigV1_t::inputTensors) + .def_readonly("numOfOutputs", &Qnn_OpConfigV1_t::numOfOutputs) + .def_readonly("outputTensors", &Qnn_OpConfigV1_t::outputTensors); + + py::class_(m, "Qnn_Param") + .def_readonly("paramType", &Qnn_Param_t::paramType) + .def_readonly("name", &Qnn_Param_t::name) + .def_property_readonly( + "scalarParam", + [](const Qnn_Param_t& param) -> const Qnn_Scalar_t& { + if (param.paramType == Qnn_ParamType_t::QNN_PARAMTYPE_SCALAR) { + return param.scalarParam; + } + throw std::runtime_error("ParamType is not scalar."); + }) + .def_property_readonly( + "tensorParam", [](const Qnn_Param_t& param) -> const Qnn_Tensor_t& { + if (param.paramType == Qnn_ParamType_t::QNN_PARAMTYPE_TENSOR) { + return param.tensorParam; + } + throw std::runtime_error("ParamType is not tensor."); + }); + + py::enum_(m, "Qnn_ParamType_t") + .value("QNN_PARAMTYPE_SCALAR", Qnn_ParamType_t::QNN_PARAMTYPE_SCALAR) + .value("QNN_PARAMTYPE_TENSOR", Qnn_ParamType_t::QNN_PARAMTYPE_TENSOR) + .value( + "QNN_PARAMTYPE_UNDEFINED", Qnn_ParamType_t::QNN_PARAMTYPE_UNDEFINED) + .export_values(); + + py::class_(m, "Qnn_Scalar_t") + .def_readonly("dataType", &Qnn_Scalar_t::dataType) + .def("value", &GetScalarValue, "Get the value of the scalar as a string"); + + py::class_(m, "Qnn_Tensor_t") + .def_readonly("version", &Qnn_Tensor_t::version) + .def_property_readonly("v2", [](Qnn_Tensor_t& t) -> Qnn_TensorV2_t& { + if (t.version == QNN_TENSOR_VERSION_2) { + return t.v2; + } + throw std::runtime_error("Tensor version is not V2."); + }); + + py::enum_(m, "Qnn_TensorVersion_t") + .value("QNN_TENSOR_VERSION_1", Qnn_TensorVersion_t::QNN_TENSOR_VERSION_1) + .value("QNN_TENSOR_VERSION_2", Qnn_TensorVersion_t::QNN_TENSOR_VERSION_2) + .value( + "QNN_TENSOR_VERSION_UNDEFINED", + Qnn_TensorVersion_t::QNN_TENSOR_VERSION_UNDEFINED) + .export_values(); + + py::class_(m, "Qnn_TensorV2_t") + .def_readonly("id", &Qnn_TensorV2_t::id) + .def_readonly("name", &Qnn_TensorV2_t::name) + .def_readonly("type", &Qnn_TensorV2_t::type) + .def_readonly("dataFormat", &Qnn_TensorV2_t::dataFormat) + .def_readonly("dataType", &Qnn_TensorV2_t::dataType) + .def_readonly("quantizeParams", &Qnn_TensorV2_t::quantizeParams) + .def_readonly("rank", &Qnn_TensorV2_t::rank) + // change dimensions pointer to vector(begin to rank) + .def_property_readonly( + "dimensions", + [](const Qnn_TensorV2_t& t) { + return std::vector(t.dimensions, t.dimensions + t.rank); + }) + .def_property_readonly( + "isDynamicDimensions", + [](const Qnn_TensorV2_t& t) { + return t.dimensions == nullptr + ? 
std::vector() + : std::vector(t.dimensions, t.dimensions + t.rank); + }) + .def_readonly("memType", &Qnn_TensorV2_t::memType); + + py::enum_(m, "Qnn_TensorMemType_t") + .value( + "QNN_TENSORMEMTYPE_RAW", Qnn_TensorMemType_t::QNN_TENSORMEMTYPE_RAW) + .value( + "QNN_TENSORMEMTYPE_MEMHANDLE", + Qnn_TensorMemType_t::QNN_TENSORMEMTYPE_MEMHANDLE) + .value( + "QNN_TENSORMEMTYPE_UNDEFINED", + Qnn_TensorMemType_t::QNN_TENSORMEMTYPE_UNDEFINED) + .export_values(); + + py::class_(m, "QnnQuantizeParams") + .def_readonly( + "encodingDefinition", &Qnn_QuantizeParams_t::encodingDefinition) + .def_readonly( + "quantizationEncoding", &Qnn_QuantizeParams_t::quantizationEncoding) + .def_property_readonly( + "scaleOffsetEncoding", + [](const Qnn_QuantizeParams_t& qp) { + if (qp.quantizationEncoding == + QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { + return qp.scaleOffsetEncoding; + } + throw std::runtime_error( + "Invalid quantization encoding type for scaleOffsetEncoding."); + }) + .def_property_readonly( + "axisScaleOffsetEncoding", [](const Qnn_QuantizeParams_t& qp) { + if (qp.quantizationEncoding == + QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { + return qp.axisScaleOffsetEncoding; + } + throw std::runtime_error( + "Invalid quantization encoding type for axisScaleOffsetEncoding."); + }); + + py::enum_(m, "QnnDefinition") + .value( + "QNN_DEFINITION_IMPL_GENERATED", + Qnn_Definition_t::QNN_DEFINITION_IMPL_GENERATED) + .value("QNN_DEFINITION_DEFINED", Qnn_Definition_t::QNN_DEFINITION_DEFINED) + .value( + "QNN_DEFINITION_UNDEFINED", + Qnn_Definition_t::QNN_DEFINITION_UNDEFINED) + .export_values(); + + py::class_(m, "QnnAxisScaleOffset") + .def_readonly("axis", &Qnn_AxisScaleOffset_t::axis) + .def_readonly("numScaleOffsets", &Qnn_AxisScaleOffset_t::numScaleOffsets) + .def_property_readonly( + "scaleOffset", [](const Qnn_AxisScaleOffset_t& aso) { + return std::vector( + aso.scaleOffset, aso.scaleOffset + aso.numScaleOffsets); + }); } } // namespace qnn } // namespace backends diff --git a/backends/qualcomm/aot/python/PyQnnManagerAdaptor.h b/backends/qualcomm/aot/python/PyQnnManagerAdaptor.h index c8044e5db0e..c1434db5573 100644 --- a/backends/qualcomm/aot/python/PyQnnManagerAdaptor.h +++ b/backends/qualcomm/aot/python/PyQnnManagerAdaptor.h @@ -6,7 +6,8 @@ * LICENSE file in the root directory of this source tree. 
*/ #pragma once -#include +#include +#include #include #include #include @@ -22,6 +23,171 @@ namespace py = pybind11; namespace executorch { namespace backends { namespace qnn { +class PyQnnOpWrapper { + public: + explicit PyQnnOpWrapper( + const std::string& name, + const std::string& package_name, + const std::string& op_type) { + op_wrapper_ = std::make_shared(name, package_name, op_type); + } + void AddInputTensors( + const std::vector>& tensors) { + op_wrapper_->AddInputTensors(tensors); + } + + void AddOutputTensors( + const std::vector>& tensors) { + op_wrapper_->AddOutputTensors(tensors); + } + + void AddTensorParam( + const std::string& name, + Qnn_DataType_t data_type, + std::uint32_t rank, + const std::vector& dims, + py::array& data, + bool copy_data) { + op_wrapper_->AddTensorParam( + name, data_type, rank, dims.data(), data.data(), copy_data); + } + + void AddScalarParam( + const std::string& name, + Qnn_DataType_t data_type, + py::dict& attrData) { + switch (data_type) { + case Qnn_DataType_t::QNN_DATATYPE_INT_32: + op_wrapper_->AddScalarParam( + name, data_type, attrData["data"].cast()); + break; + case Qnn_DataType_t::QNN_DATATYPE_INT_16: + op_wrapper_->AddScalarParam( + name, data_type, attrData["data"].cast()); + break; + case Qnn_DataType_t::QNN_DATATYPE_INT_8: + op_wrapper_->AddScalarParam( + name, data_type, attrData["data"].cast()); + break; + case Qnn_DataType_t::QNN_DATATYPE_UINT_32: + op_wrapper_->AddScalarParam( + name, data_type, attrData["data"].cast()); + break; + case Qnn_DataType_t::QNN_DATATYPE_UINT_16: + op_wrapper_->AddScalarParam( + name, data_type, attrData["data"].cast()); + break; + case Qnn_DataType_t::QNN_DATATYPE_UINT_8: + op_wrapper_->AddScalarParam( + name, data_type, attrData["data"].cast()); + break; + case Qnn_DataType_t::QNN_DATATYPE_FLOAT_32: + case Qnn_DataType_t::QNN_DATATYPE_FLOAT_16: + op_wrapper_->AddScalarParam( + name, data_type, attrData["data"].cast()); + break; + case Qnn_DataType_t::QNN_DATATYPE_BOOL_8: + op_wrapper_->AddScalarParam( + name, data_type, attrData["data"].cast()); + break; + default: + QNN_EXECUTORCH_LOG_ERROR( + "%s has invalid data type: %d", name.c_str(), data_type); + break; + } + } + std::shared_ptr& GetOpWrapper() { + return op_wrapper_; + } + + private: + std::shared_ptr op_wrapper_; +}; + +class PyQnnTensorWrapper { + public: + explicit PyQnnTensorWrapper(const std::shared_ptr& wrapper) { + tensor_wrapper_ = wrapper; + } + struct EncodingData { + float scale; + int32_t offset; + }; + struct Encoding { + py::array_t data; + int32_t axis; + }; + + py::array_t GetDims() { + std::uint32_t* dim = tensor_wrapper_->GetDims(); + size_t shape[1]{tensor_wrapper_->GetRank()}; + size_t stride[1]{sizeof(std::uint32_t)}; + auto ret = py::array_t(shape, stride); + auto view = ret.mutable_unchecked<1>(); + for (int i = 0; i < ret.shape(0); ++i) { + view(i) = dim[i]; + } + return ret; + } + std::string GetName() { + return tensor_wrapper_->GetName(); + } + Qnn_DataType_t GetDataType() { + return tensor_wrapper_->GetDataType(); + } + Encoding GetEncodings() { + auto q_param = tensor_wrapper_->GetQuantizeParams(); + size_t stride[1]{sizeof(EncodingData)}; + + switch (q_param.quantizationEncoding) { + case QNN_QUANTIZATION_ENCODING_SCALE_OFFSET: { + Qnn_ScaleOffset_t data = q_param.scaleOffsetEncoding; + size_t shape[1]{1}; + auto enc_data = py::array_t(shape, stride); + auto view = enc_data.mutable_unchecked<1>(); + view(0) = {data.scale, data.offset}; + return {enc_data, -1}; + } + case 
QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET: { + Qnn_AxisScaleOffset_t data = q_param.axisScaleOffsetEncoding; + size_t shape[1]{data.numScaleOffsets}; + auto enc_data = py::array_t(shape, stride); + auto view = enc_data.mutable_unchecked<1>(); + for (int i = 0; i < enc_data.shape(0); ++i) { + view(i) = {data.scaleOffset[i].scale, data.scaleOffset[i].offset}; + } + return {enc_data, data.axis}; + } + case QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET: { + Qnn_BwScaleOffset_t data = q_param.bwScaleOffsetEncoding; + size_t shape[1]{1}; + auto enc_data = py::array_t(shape, stride); + auto view = enc_data.mutable_unchecked<1>(); + view(0) = {data.scale, data.offset}; + return {enc_data, -1}; + } + case QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET: { + Qnn_BwAxisScaleOffset_t data = q_param.bwAxisScaleOffsetEncoding; + size_t shape[1]{data.numElements}; + auto enc_data = py::array_t(shape, stride); + auto view = enc_data.mutable_unchecked<1>(); + for (int i = 0; i < enc_data.shape(0); ++i) { + view(i) = {data.scales[i], data.offsets[i]}; + } + return {enc_data, data.axis}; + } + default: + QNN_EXECUTORCH_LOG_WARN( + "%s QNN_QUANTIZATION_ENCODING_UNDEFINED detected", + GetName().c_str()); + break; + } + return {}; + } + + private: + std::shared_ptr tensor_wrapper_; +}; class PyQnnManager { public: // used for AoT compilation @@ -50,7 +216,24 @@ class PyQnnManager { } executorch::runtime::Error Init() { - return qnn_manager_->Init(); + ET_CHECK_OR_RETURN_ERROR( + qnn_manager_->InitBackend() == Error::Ok, + Internal, + "Fail to initailize backend"); + ET_CHECK_OR_RETURN_ERROR( + qnn_manager_->InitContext() == Error::Ok, + Internal, + "Fail to initailize context"); + return Error::Ok; + } + + executorch::runtime::Error InitBackend() { + return qnn_manager_->InitBackend(); + } + + executorch::runtime::Error InitContext( + const std::vector& graph_names) { + return qnn_manager_->InitContext(std::optional{graph_names}); } bool IsNodeSupportedByBackend( @@ -90,6 +273,10 @@ class PyQnnManager { return qnn_manager_->Destroy(); } + void DestroyContext() { + return qnn_manager_->DestroyContext(); + } + bool IsAvailable() { return qnn_manager_->IsAvailable(); } @@ -148,29 +335,6 @@ class PyQnnManager { return result; } - py::array_t StripProtocol(const py::bytes& preprocessed_binary) { - py::buffer_info info(py::buffer(preprocessed_binary).request()); - - void* buf_ptr = nullptr; - size_t buf_size = 0; - // check if it's a qnn context binary - auto [status, signature, ctx_size, ctx_bin] = - QnnContextCustomProtocol().DeserializeContextCustomBuffer(info.ptr); - - if (status == Error::Ok) { - buf_size = ctx_size; - buf_ptr = ctx_bin; - } else { - // the format should be DLC, return nothing here - return py::array_t(0); - } - - auto result = py::array_t(buf_size); - auto result_buffer = result.request(); - std::memcpy(result_buffer.ptr, buf_ptr, buf_size); - return result; - } - private: // Store the bytes object instead of a raw pointer so that this module will // keep the bytes alive. 
@@ -178,8 +342,8 @@ class PyQnnManager { QnnExecuTorchContextBinary qnn_executorch_context_binary_; std::shared_ptr qnn_manager_; QnnContextCustomProtocol custom_context_custom_buffer_; - flatbuffers::FlatBufferBuilder builder_; }; + } // namespace qnn } // namespace backends } // namespace executorch diff --git a/backends/qualcomm/aot/python/PyQnnWrapperAdaptor.cpp b/backends/qualcomm/aot/python/PyQnnWrapperAdaptor.cpp deleted file mode 100644 index 39f1f3ee48f..00000000000 --- a/backends/qualcomm/aot/python/PyQnnWrapperAdaptor.cpp +++ /dev/null @@ -1,524 +0,0 @@ -/* - * Copyright (c) Qualcomm Innovation Center, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ -#include -#include -#include -#include -#include -#include - -#include - -namespace py = pybind11; -namespace executorch { -namespace backends { -namespace qnn { -std::unique_ptr CreateQuantizationParamWrapper( - const Qnn_QuantizationEncoding_t& encoding, - py::dict& quant_info) { - std::unique_ptr quantize_param_wrapper; - if (encoding == QNN_QUANTIZATION_ENCODING_UNDEFINED) { - quantize_param_wrapper = std::make_unique(); - } else if (encoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { - int32_t axis = quant_info["axis"].cast(); - std::vector scale_offset = - quant_info["scale_offset"].cast>(); - - quantize_param_wrapper = - std::make_unique( - axis, scale_offset); - } else if (encoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET) { - uint32_t bitwidth = quant_info["bitwidth"].cast(); - int32_t axis = quant_info["axis"].cast(); - std::vector scale_offset = - quant_info["scale_offset"].cast>(); - uint32_t num_elements = scale_offset.size(); - std::vector scales; - std::vector offsets; - for (const auto& scale_offset : scale_offset) { - scales.push_back(scale_offset.scale); - offsets.push_back(scale_offset.offset); - } - quantize_param_wrapper = - std::make_unique( - bitwidth, axis, num_elements, scales, offsets); - } else if (encoding == QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET) { - uint32_t bitwidth = quant_info["bitwidth"].cast(); - float scale = quant_info["scale"].cast(); - int32_t offset = quant_info["offset"].cast(); - quantize_param_wrapper = - std::make_unique( - bitwidth, scale, offset); - } else if (encoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { - float scale = quant_info["scale"].cast(); - int32_t offset = quant_info["offset"].cast(); - quantize_param_wrapper = - std::make_unique(scale, offset); - } else if (encoding == QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION) { - int32_t axis = quant_info["axis"].cast(); - std::vector scale_offset = - quant_info["block_scale_offset"].cast>(); - uint32_t num_blocks_per_axis = - quant_info["num_blocks_per_axis"].cast(); - uint32_t block_scale_bitwidth = - quant_info["block_scale_bitwidth"].cast(); - Qnn_BlockwiseExpansionBlockScaleStorageType_t block_storage_type = - quant_info["block_storage_type"] - .cast(); - std::vector buf = - quant_info["block_scales"].cast>(); - quantize_param_wrapper = - std::make_unique( - axis, - scale_offset, - num_blocks_per_axis, - block_scale_bitwidth, - block_storage_type, - buf.data(), - buf.size()); - } else { - QNN_EXECUTORCH_LOG_ERROR( - "Unknown the encoding of quantization: %d", encoding); - } - return quantize_param_wrapper; -} - -std::string GetScalarValue(const Qnn_Scalar_t& scalar) { - switch (scalar.dataType) { - case QNN_DATATYPE_FLOAT_32: - return std::to_string(scalar.floatValue); - case 
QNN_DATATYPE_FLOAT_64: - return std::to_string(scalar.doubleValue); - case QNN_DATATYPE_UINT_64: - return std::to_string(scalar.uint64Value); - case QNN_DATATYPE_INT_64: - return std::to_string(scalar.int64Value); - case QNN_DATATYPE_UINT_32: - return std::to_string(scalar.uint32Value); - case QNN_DATATYPE_INT_32: - return std::to_string(scalar.int32Value); - case QNN_DATATYPE_UINT_16: - return std::to_string(scalar.uint16Value); - case QNN_DATATYPE_INT_16: - return std::to_string(scalar.int16Value); - case QNN_DATATYPE_UINT_8: - return std::to_string(scalar.uint8Value); - case QNN_DATATYPE_INT_8: - return std::to_string(scalar.int8Value); - case QNN_DATATYPE_BOOL_8: - return std::to_string(static_cast(scalar.bool8Value)); - case QNN_DATATYPE_STRING: - return std::string(scalar.stringValue); - default: - return "QNN_DATATYPE_UNDEFINED"; - } -} - -std::shared_ptr CreateTensorWrapper( - const std::string& tensor_name, - Qnn_TensorType_t tensor_type, - Qnn_DataType_t data_type, - const Qnn_QuantizationEncoding_t& encoding, - py::dict& quant_info, - std::uint32_t rank, - const std::vector& dims, - const std::vector& dynamic_dims, - py::array& data, - bool copy_data) { - std::unique_ptr quantize_param_wrapper = - CreateQuantizationParamWrapper(encoding, quant_info); - - return CreateTensorWrapper( - tensor_name, - tensor_type, - data_type, - std::move(quantize_param_wrapper), - rank, - dims.data(), - dynamic_dims.data(), - 0, - data.size() == 0 ? nullptr : data.data(), - copy_data); -} - -PYBIND11_MODULE(PyQnnWrapperAdaptor, m) { - PYBIND11_NUMPY_DTYPE(PyQnnTensorWrapper::EncodingData, scale, offset); - - py::enum_(m, "Qnn_TensorType_t") - .value( - "QNN_TENSOR_TYPE_APP_WRITE", - Qnn_TensorType_t::QNN_TENSOR_TYPE_APP_WRITE) - .value( - "QNN_TENSOR_TYPE_APP_READ", - Qnn_TensorType_t::QNN_TENSOR_TYPE_APP_READ) - .value( - "QNN_TENSOR_TYPE_APP_READWRITE", - Qnn_TensorType_t::QNN_TENSOR_TYPE_APP_READWRITE) - .value("QNN_TENSOR_TYPE_NATIVE", Qnn_TensorType_t::QNN_TENSOR_TYPE_NATIVE) - .value("QNN_TENSOR_TYPE_STATIC", Qnn_TensorType_t::QNN_TENSOR_TYPE_STATIC) - .value("QNN_TENSOR_TYPE_NULL", Qnn_TensorType_t::QNN_TENSOR_TYPE_NULL) - .value( - "QNN_TENSOR_TYPE_UNDEFINED", - Qnn_TensorType_t::QNN_TENSOR_TYPE_UNDEFINED) - .export_values(); - - py::enum_(m, "Qnn_DataType_t") - .value("QNN_DATATYPE_INT_8", Qnn_DataType_t::QNN_DATATYPE_INT_8) - .value("QNN_DATATYPE_INT_16", Qnn_DataType_t::QNN_DATATYPE_INT_16) - .value("QNN_DATATYPE_INT_32", Qnn_DataType_t::QNN_DATATYPE_INT_32) - .value("QNN_DATATYPE_INT_64", Qnn_DataType_t::QNN_DATATYPE_INT_64) - .value("QNN_DATATYPE_UINT_8", Qnn_DataType_t::QNN_DATATYPE_UINT_8) - .value("QNN_DATATYPE_UINT_16", Qnn_DataType_t::QNN_DATATYPE_UINT_16) - .value("QNN_DATATYPE_UINT_32", Qnn_DataType_t::QNN_DATATYPE_UINT_32) - .value("QNN_DATATYPE_UINT_64", Qnn_DataType_t::QNN_DATATYPE_UINT_64) - .value("QNN_DATATYPE_FLOAT_16", Qnn_DataType_t::QNN_DATATYPE_FLOAT_16) - .value("QNN_DATATYPE_FLOAT_32", Qnn_DataType_t::QNN_DATATYPE_FLOAT_32) - .value( - "QNN_DATATYPE_SFIXED_POINT_8", - Qnn_DataType_t::QNN_DATATYPE_SFIXED_POINT_8) - .value( - "QNN_DATATYPE_SFIXED_POINT_16", - Qnn_DataType_t::QNN_DATATYPE_SFIXED_POINT_16) - .value( - "QNN_DATATYPE_SFIXED_POINT_32", - Qnn_DataType_t::QNN_DATATYPE_SFIXED_POINT_32) - .value( - "QNN_DATATYPE_UFIXED_POINT_8", - Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_8) - .value( - "QNN_DATATYPE_UFIXED_POINT_16", - Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_16) - .value( - "QNN_DATATYPE_UFIXED_POINT_32", - Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_32) - 
.value("QNN_DATATYPE_BOOL_8", Qnn_DataType_t::QNN_DATATYPE_BOOL_8) - .value("QNN_DATATYPE_UNDEFINED", Qnn_DataType_t::QNN_DATATYPE_UNDEFINED) - .export_values(); - - py::enum_(m, "Qnn_QuantizationEncoding_t") - .value( - "QNN_QUANTIZATION_ENCODING_SCALE_OFFSET", - Qnn_QuantizationEncoding_t::QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) - .value( - "QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET", - Qnn_QuantizationEncoding_t:: - QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) - .value( - "QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET", - Qnn_QuantizationEncoding_t::QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET) - .value( - "QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET", - Qnn_QuantizationEncoding_t:: - QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET) - .value( - "QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION", - Qnn_QuantizationEncoding_t:: - QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION) - .value( - "QNN_QUANTIZATION_ENCODING_UNDEFINED", - Qnn_QuantizationEncoding_t::QNN_QUANTIZATION_ENCODING_UNDEFINED) - .export_values(); - - py::enum_( - m, "Qnn_BlockwiseExpansionBlockScaleStorageType_t") - .value( - "QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8", - Qnn_BlockwiseExpansionBlockScaleStorageType_t:: - QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8) - .value( - "QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_16", - Qnn_BlockwiseExpansionBlockScaleStorageType_t:: - QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_16) - .value( - "QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_UNDEFINED", - Qnn_BlockwiseExpansionBlockScaleStorageType_t:: - QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_UNDEFINED) - .export_values(); - - py::class_>(m, "OpWrapper") - .def(py::init< - const std::string&, - const std::string&, - const std::string&>()) - .def( - "GetInputTensors", - &OpWrapper::GetInputTensors, - "A function which gets input tensors") - .def( - "GetOutputTensors", - &OpWrapper::GetOutputTensors, - "A function which gets output tensors") - .def("GetOpType", &OpWrapper::GetOpType, "A function which gets op type") - .def("GetName", &OpWrapper::GetName, "A function which gets name") - .def( - "GetPackageName", - &OpWrapper::GetPackageName, - "A function which gets package name") - .def( - "GetParams", &OpWrapper::GetRawParams, "A function which gets params") - // lambda function - // python: op_wrapper.GetOpConfig() - .def( - "GetOpConfig", - [](OpWrapper& self) { - auto op_config = self.GetOpConfig(); - py::dict result; - py::list params_list; - py::list input_tensors_list; - py::list output_tensors_list; - result["version"] = op_config.version; - result["name"] = QNN_OP_VER_PTR(op_config)->name; - result["packageName"] = QNN_OP_VER_PTR(op_config)->packageName; - result["typeName"] = QNN_OP_VER_PTR(op_config)->typeName; - result["numOfParams"] = QNN_OP_VER_PTR(op_config)->numOfParams; - for (size_t i = 0; i < QNN_OP_VER_PTR(op_config)->numOfParams; - ++i) { - params_list.append(QNN_OP_VER_PTR(op_config)->params[i]); - } - result["params"] = params_list; - result["numOfInputs"] = QNN_OP_VER_PTR(op_config)->numOfInputs; - for (size_t i = 0; i < QNN_OP_VER_PTR(op_config)->numOfInputs; - ++i) { - input_tensors_list.append( - QNN_OP_VER_PTR(op_config)->inputTensors[i]); - } - result["inputTensors"] = input_tensors_list; - result["numOfOutputs"] = QNN_OP_VER_PTR(op_config)->numOfOutputs; - for (size_t i = 0; i < QNN_OP_VER_PTR(op_config)->numOfOutputs; - ++i) { - output_tensors_list.append( - QNN_OP_VER_PTR(op_config)->outputTensors[i]); - } - result["outputTensors"] = output_tensors_list; - return result; - }, - "Get 
operator configuration"); - - py::class_>(m, "TensorWrapper") - .def(py::init(py::overload_cast< - const std::string&, - Qnn_TensorType_t, - Qnn_DataType_t, - const Qnn_QuantizationEncoding_t&, - py::dict&, - std::uint32_t, - const std::vector&, - const std::vector&, - py::array&, - bool>(&CreateTensorWrapper))); - - py::class_(m, "QuantizeParamsWrapper"); - - py::class_(m, "Qnn_ScaleOffset_t") - .def(py::init()) - .def_readonly("scale", &Qnn_ScaleOffset_t::scale) - .def_readonly("offset", &Qnn_ScaleOffset_t::offset); - - py::class_>( - m, "PyQnnOpWrapper") - .def(py::init< - const std::string&, - const std::string&, - const std::string&>()) - .def( - "AddInputTensors", - &PyQnnOpWrapper::AddInputTensors, - "A function which add input tensor wrapper into op wrapper", - py::arg("tensors")) - .def( - "AddOutputTensors", - &PyQnnOpWrapper::AddOutputTensors, - "A function which add output tensor wrapper into op wrapper", - py::arg("tensors")) - .def( - "AddTensorParam", - &PyQnnOpWrapper::AddTensorParam, - "A function which add tensor parameter into op wrapper", - py::arg("name"), - py::arg("data_type"), - py::arg("rank"), - py::arg("dims"), - py::arg("data"), - py::arg("copy_data")) - .def( - "AddScalarParam", - &PyQnnOpWrapper::AddScalarParam, - "A function which add scalar parameter into op wrapper", - py::arg("name"), - py::arg("data_type"), - py::arg("attrData")) - .def( - "GetOpWrapper", - &PyQnnOpWrapper::GetOpWrapper, - "A function which get op wrapper"); - - py::class_(m, "Encoding") - .def_readonly("data", &PyQnnTensorWrapper::Encoding::data) - .def_readonly("axis", &PyQnnTensorWrapper::Encoding::axis); - - py::class_>( - m, "PyQnnTensorWrapper") - .def(py::init&>()) - .def("GetDims", &PyQnnTensorWrapper::GetDims) - .def("GetDataType", &PyQnnTensorWrapper::GetDataType) - .def("GetName", &PyQnnTensorWrapper::GetName) - .def("GetEncodings", &PyQnnTensorWrapper::GetEncodings); - - py::class_(m, "Qnn_OpConfig") - .def_readonly("version", &Qnn_OpConfig_t::version) - // getter - // python: op_wrapper.GetOpConfig().v1 - .def_property_readonly( - "v1", [](const Qnn_OpConfig_t& config) -> const Qnn_OpConfigV1_t& { - return config.v1; - }); - - py::enum_(m, "Qnn_OpConfigVersion") - .value("QNN_OPCONFIG_VERSION_1", QNN_OPCONFIG_VERSION_1) - .value("QNN_OPCONFIG_VERSION_UNDEFINED", QNN_OPCONFIG_VERSION_UNDEFINED) - .export_values(); - - py::class_(m, "Qnn_OpConfigV1") - .def_readonly("name", &Qnn_OpConfigV1_t::name) - .def_readonly("packageName", &Qnn_OpConfigV1_t::packageName) - .def_readonly("typeName", &Qnn_OpConfigV1_t::typeName) - .def_readonly("numOfParams", &Qnn_OpConfigV1_t::numOfParams) - .def_readonly("params", &Qnn_OpConfigV1_t::params) - .def_readonly("numOfInputs", &Qnn_OpConfigV1_t::numOfInputs) - .def_readonly("inputTensors", &Qnn_OpConfigV1_t::inputTensors) - .def_readonly("numOfOutputs", &Qnn_OpConfigV1_t::numOfOutputs) - .def_readonly("outputTensors", &Qnn_OpConfigV1_t::outputTensors); - - py::class_(m, "Qnn_Param") - .def_readonly("paramType", &Qnn_Param_t::paramType) - .def_readonly("name", &Qnn_Param_t::name) - .def_property_readonly( - "scalarParam", - [](const Qnn_Param_t& param) -> const Qnn_Scalar_t& { - if (param.paramType == Qnn_ParamType_t::QNN_PARAMTYPE_SCALAR) { - return param.scalarParam; - } - throw std::runtime_error("ParamType is not scalar."); - }) - .def_property_readonly( - "tensorParam", [](const Qnn_Param_t& param) -> const Qnn_Tensor_t& { - if (param.paramType == Qnn_ParamType_t::QNN_PARAMTYPE_TENSOR) { - return param.tensorParam; - } - throw 
std::runtime_error("ParamType is not tensor."); - }); - - py::enum_(m, "Qnn_ParamType_t") - .value("QNN_PARAMTYPE_SCALAR", Qnn_ParamType_t::QNN_PARAMTYPE_SCALAR) - .value("QNN_PARAMTYPE_TENSOR", Qnn_ParamType_t::QNN_PARAMTYPE_TENSOR) - .value( - "QNN_PARAMTYPE_UNDEFINED", Qnn_ParamType_t::QNN_PARAMTYPE_UNDEFINED) - .export_values(); - - py::class_(m, "Qnn_Scalar_t") - .def_readonly("dataType", &Qnn_Scalar_t::dataType) - .def("value", &GetScalarValue, "Get the value of the scalar as a string"); - - py::class_(m, "Qnn_Tensor_t") - .def_readonly("version", &Qnn_Tensor_t::version) - .def_property_readonly("v2", [](Qnn_Tensor_t& t) -> Qnn_TensorV2_t& { - if (t.version == QNN_TENSOR_VERSION_2) { - return t.v2; - } - throw std::runtime_error("Tensor version is not V2."); - }); - - py::enum_(m, "Qnn_TensorVersion_t") - .value("QNN_TENSOR_VERSION_1", Qnn_TensorVersion_t::QNN_TENSOR_VERSION_1) - .value("QNN_TENSOR_VERSION_2", Qnn_TensorVersion_t::QNN_TENSOR_VERSION_2) - .value( - "QNN_TENSOR_VERSION_UNDEFINED", - Qnn_TensorVersion_t::QNN_TENSOR_VERSION_UNDEFINED) - .export_values(); - - py::class_(m, "Qnn_TensorV2_t") - .def_readonly("id", &Qnn_TensorV2_t::id) - .def_readonly("name", &Qnn_TensorV2_t::name) - .def_readonly("type", &Qnn_TensorV2_t::type) - .def_readonly("dataFormat", &Qnn_TensorV2_t::dataFormat) - .def_readonly("dataType", &Qnn_TensorV2_t::dataType) - .def_readonly("quantizeParams", &Qnn_TensorV2_t::quantizeParams) - .def_readonly("rank", &Qnn_TensorV2_t::rank) - // change dimensions pointer to vector(begin to rank) - .def_property_readonly( - "dimensions", - [](const Qnn_TensorV2_t& t) { - return std::vector(t.dimensions, t.dimensions + t.rank); - }) - .def_property_readonly( - "isDynamicDimensions", - [](const Qnn_TensorV2_t& t) { - return t.dimensions == nullptr - ? 
std::vector() - : std::vector(t.dimensions, t.dimensions + t.rank); - }) - .def_readonly("memType", &Qnn_TensorV2_t::memType); - - py::enum_(m, "Qnn_TensorMemType_t") - .value( - "QNN_TENSORMEMTYPE_RAW", Qnn_TensorMemType_t::QNN_TENSORMEMTYPE_RAW) - .value( - "QNN_TENSORMEMTYPE_MEMHANDLE", - Qnn_TensorMemType_t::QNN_TENSORMEMTYPE_MEMHANDLE) - .value( - "QNN_TENSORMEMTYPE_UNDEFINED", - Qnn_TensorMemType_t::QNN_TENSORMEMTYPE_UNDEFINED) - .export_values(); - - py::class_(m, "QnnQuantizeParams") - .def_readonly( - "encodingDefinition", &Qnn_QuantizeParams_t::encodingDefinition) - .def_readonly( - "quantizationEncoding", &Qnn_QuantizeParams_t::quantizationEncoding) - .def_property_readonly( - "scaleOffsetEncoding", - [](const Qnn_QuantizeParams_t& qp) { - if (qp.quantizationEncoding == - QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { - return qp.scaleOffsetEncoding; - } - throw std::runtime_error( - "Invalid quantization encoding type for scaleOffsetEncoding."); - }) - .def_property_readonly( - "axisScaleOffsetEncoding", [](const Qnn_QuantizeParams_t& qp) { - if (qp.quantizationEncoding == - QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { - return qp.axisScaleOffsetEncoding; - } - throw std::runtime_error( - "Invalid quantization encoding type for axisScaleOffsetEncoding."); - }); - - py::enum_(m, "QnnDefinition") - .value( - "QNN_DEFINITION_IMPL_GENERATED", - Qnn_Definition_t::QNN_DEFINITION_IMPL_GENERATED) - .value("QNN_DEFINITION_DEFINED", Qnn_Definition_t::QNN_DEFINITION_DEFINED) - .value( - "QNN_DEFINITION_UNDEFINED", - Qnn_Definition_t::QNN_DEFINITION_UNDEFINED) - .export_values(); - - py::class_(m, "QnnAxisScaleOffset") - .def_readonly("axis", &Qnn_AxisScaleOffset_t::axis) - .def_readonly("numScaleOffsets", &Qnn_AxisScaleOffset_t::numScaleOffsets) - .def_property_readonly( - "scaleOffset", [](const Qnn_AxisScaleOffset_t& aso) { - return std::vector( - aso.scaleOffset, aso.scaleOffset + aso.numScaleOffsets); - }); -} -} // namespace qnn -} // namespace backends -} // namespace executorch diff --git a/backends/qualcomm/aot/python/PyQnnWrapperAdaptor.h b/backends/qualcomm/aot/python/PyQnnWrapperAdaptor.h deleted file mode 100644 index 33c0bd63cac..00000000000 --- a/backends/qualcomm/aot/python/PyQnnWrapperAdaptor.h +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Copyright (c) Qualcomm Innovation Center, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ -#pragma once - -#include -#include -#include -#include -#include -#include -namespace py = pybind11; -namespace executorch { -namespace backends { -namespace qnn { -class PyQnnOpWrapper { - public: - explicit PyQnnOpWrapper( - const std::string& name, - const std::string& package_name, - const std::string& op_type) { - op_wrapper_ = std::make_shared(name, package_name, op_type); - } - void AddInputTensors( - const std::vector>& tensors) { - op_wrapper_->AddInputTensors(tensors); - } - - void AddOutputTensors( - const std::vector>& tensors) { - op_wrapper_->AddOutputTensors(tensors); - } - - void AddTensorParam( - const std::string& name, - Qnn_DataType_t data_type, - std::uint32_t rank, - const std::vector& dims, - py::array& data, - bool copy_data) { - op_wrapper_->AddTensorParam( - name, data_type, rank, dims.data(), data.data(), copy_data); - } - - void AddScalarParam( - const std::string& name, - Qnn_DataType_t data_type, - py::dict& attrData) { - switch (data_type) { - case Qnn_DataType_t::QNN_DATATYPE_INT_32: - op_wrapper_->AddScalarParam( - name, data_type, attrData["data"].cast()); - break; - case Qnn_DataType_t::QNN_DATATYPE_INT_16: - op_wrapper_->AddScalarParam( - name, data_type, attrData["data"].cast()); - break; - case Qnn_DataType_t::QNN_DATATYPE_INT_8: - op_wrapper_->AddScalarParam( - name, data_type, attrData["data"].cast()); - break; - case Qnn_DataType_t::QNN_DATATYPE_UINT_32: - op_wrapper_->AddScalarParam( - name, data_type, attrData["data"].cast()); - break; - case Qnn_DataType_t::QNN_DATATYPE_UINT_16: - op_wrapper_->AddScalarParam( - name, data_type, attrData["data"].cast()); - break; - case Qnn_DataType_t::QNN_DATATYPE_UINT_8: - op_wrapper_->AddScalarParam( - name, data_type, attrData["data"].cast()); - break; - case Qnn_DataType_t::QNN_DATATYPE_FLOAT_32: - case Qnn_DataType_t::QNN_DATATYPE_FLOAT_16: - op_wrapper_->AddScalarParam( - name, data_type, attrData["data"].cast()); - break; - case Qnn_DataType_t::QNN_DATATYPE_BOOL_8: - op_wrapper_->AddScalarParam( - name, data_type, attrData["data"].cast()); - break; - default: - QNN_EXECUTORCH_LOG_ERROR( - "%s has invalid data type: %d", name.c_str(), data_type); - break; - } - } - std::shared_ptr& GetOpWrapper() { - return op_wrapper_; - } - - private: - std::shared_ptr op_wrapper_; -}; - -class PyQnnTensorWrapper { - public: - explicit PyQnnTensorWrapper(const std::shared_ptr& wrapper) { - tensor_wrapper_ = wrapper; - } - struct EncodingData { - float scale; - int32_t offset; - }; - struct Encoding { - py::array_t data; - int32_t axis; - }; - - py::array_t GetDims() { - std::uint32_t* dim = tensor_wrapper_->GetDims(); - size_t shape[1]{tensor_wrapper_->GetRank()}; - size_t stride[1]{sizeof(std::uint32_t)}; - auto ret = py::array_t(shape, stride); - auto view = ret.mutable_unchecked<1>(); - for (int i = 0; i < ret.shape(0); ++i) { - view(i) = dim[i]; - } - return ret; - } - std::string GetName() { - return tensor_wrapper_->GetName(); - } - Qnn_DataType_t GetDataType() { - return tensor_wrapper_->GetDataType(); - } - Encoding GetEncodings() { - auto q_param = tensor_wrapper_->GetQuantizeParams(); - size_t stride[1]{sizeof(EncodingData)}; - - switch (q_param.quantizationEncoding) { - case QNN_QUANTIZATION_ENCODING_SCALE_OFFSET: { - Qnn_ScaleOffset_t data = q_param.scaleOffsetEncoding; - size_t shape[1]{1}; - auto enc_data = py::array_t(shape, stride); - auto view = enc_data.mutable_unchecked<1>(); - view(0) = {data.scale, data.offset}; - return {enc_data, -1}; - } - case QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET: { - 
Qnn_AxisScaleOffset_t data = q_param.axisScaleOffsetEncoding; - size_t shape[1]{data.numScaleOffsets}; - auto enc_data = py::array_t(shape, stride); - auto view = enc_data.mutable_unchecked<1>(); - for (int i = 0; i < enc_data.shape(0); ++i) { - view(i) = {data.scaleOffset[i].scale, data.scaleOffset[i].offset}; - } - return {enc_data, data.axis}; - } - case QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET: { - Qnn_BwScaleOffset_t data = q_param.bwScaleOffsetEncoding; - size_t shape[1]{1}; - auto enc_data = py::array_t(shape, stride); - auto view = enc_data.mutable_unchecked<1>(); - view(0) = {data.scale, data.offset}; - return {enc_data, -1}; - } - case QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET: { - Qnn_BwAxisScaleOffset_t data = q_param.bwAxisScaleOffsetEncoding; - size_t shape[1]{data.numElements}; - auto enc_data = py::array_t(shape, stride); - auto view = enc_data.mutable_unchecked<1>(); - for (int i = 0; i < enc_data.shape(0); ++i) { - view(i) = {data.scales[i], data.offsets[i]}; - } - return {enc_data, data.axis}; - } - default: - QNN_EXECUTORCH_LOG_WARN( - "%s QNN_QUANTIZATION_ENCODING_UNDEFINED detected", - GetName().c_str()); - break; - } - return {}; - } - - private: - std::shared_ptr tensor_wrapper_; -}; -} // namespace qnn -} // namespace backends -} // namespace executorch diff --git a/backends/qualcomm/aot/python/targets.bzl b/backends/qualcomm/aot/python/targets.bzl index 74fbd1da511..0133aa73b93 100644 --- a/backends/qualcomm/aot/python/targets.bzl +++ b/backends/qualcomm/aot/python/targets.bzl @@ -46,38 +46,6 @@ def define_common_targets(): ) - runtime.cxx_python_extension( - name = "PyQnnWrapperAdaptor", - srcs = [ - "PyQnnWrapperAdaptor.cpp", - ], - headers = [ - "PyQnnWrapperAdaptor.h", - ], - base_module = "executorch.backends.qualcomm.python", - preprocessor_flags = [ - "-DEXECUTORCH_PYTHON_MODULE_NAME={}".format(PYTHON_MODULE_NAME), - ], - deps = [ - "//executorch/runtime/core:core", - "//executorch/backends/qualcomm/aot/python:python_lib", - "//executorch/backends/qualcomm/aot/wrappers:wrappers", - "//executorch/backends/qualcomm/runtime:logging", - "//executorch/backends/qualcomm:schema", - "//executorch/backends/qualcomm/runtime:runtime", - "fbsource//third-party/pybind11:pybind11", - "fbsource//third-party/qualcomm/qnn/qnn-{0}:api".format(get_qnn_library_version()), - "fbsource//third-party/qualcomm/qnn/qnn-{0}:app_sources".format(get_qnn_library_version()), - ], - external_deps = [ - "libtorch_python", - ], - use_static_deps = True, - visibility = [ - "//executorch/backends/qualcomm/...", - ], - ) - runtime.cxx_library( name = "python_lib", srcs = glob([ diff --git a/backends/qualcomm/builders/README.md b/backends/qualcomm/builders/README.md index 6ba4eafb01f..a0643bd4f1d 100644 --- a/backends/qualcomm/builders/README.md +++ b/backends/qualcomm/builders/README.md @@ -2,20 +2,24 @@ Thank you for contributing to Qualcomm AI Engine Direct delegate for ExecuTorch. Reading and following these guidelines will help you quickly get the essentials of implementing operator builder to unblock yourself and land pull requests more efficiently. 
## Sections -* [References](#references) -* [Getting Started](#getting-started) - * [Identify Unsupported Operator](#identify-unsupported-operator) - * [Check Operator Spec](#check-operator-spec) - * [Implementation](#implementation) - * [Quantizer Annotation](#quantizer-annotation) -* [Operator Support Status](#operator-support-status) -* [Issues](#issues) -* [Pull Requests](#pull-requests) +- [Contribution for More Operators](#contribution-for-more-operators) + - [Sections](#sections) + - [References](#references) + - [Qualcomm AI Engine Direct](#qualcomm-ai-engine-direct) + - [PyTorch](#pytorch) + - [Getting Started](#getting-started) + - [Identify Unsupported Operator](#identify-unsupported-operator) + - [Check Operator Spec](#check-operator-spec) + - [Implementation](#implementation) + - [Quantizer Annotation](#quantizer-annotation) + - [Operator Support Status](#operator-support-status) + - [Issues](#issues) + - [Pull Requests](#pull-requests) ## References ### Qualcomm AI Engine Direct -- [Operator Definitions](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/MasterOpDef.html) -- [Supported Operators in Backends](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/operations.html#backend-supplements) +- [Operator Definitions](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-10/MasterOpDef.html) +- [Supported Operators in Backends](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-10/operations.html#backend-supplements) ### PyTorch - [torch.nn Operator Definitions](https://pytorch.org/docs/stable/nn.html) @@ -37,7 +41,7 @@ class MyModel(torch.nn.Module): ``` At the time we try to lower it with Qualcomm backend: ```python -from excutorch.examples.qualcomm.utils import build_executorch_binary +from executorch.examples.qualcomm.utils import build_executorch_binary build_executorch_binary( model=MyModel(), @@ -120,9 +124,9 @@ It will provide more hint to the source PyTorch layer where the missing operator }; } Qnn_Param_t; ``` - The name value equals to the parameter name described in [Operator Definitions](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/MasterOpDef.html), there are `epsilon`, `axes` for `LayerNorm` case.
+ The name value equals the parameter name described in [Operator Definitions](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-10/MasterOpDef.html); for the `LayerNorm` case, these are `epsilon` and `axes`.
- If you find it hard to correlate missing operator with documentation, this [table](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/SupportedOps.html) might be helpful for searching. In some cases, an exact match may not exist. Consider seeking for a math equivalent approach or notify maintainer for further analysis. + If you find it hard to correlate the missing operator with the documentation, this [table](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-10/SupportedOps.html) might be helpful for searching. In some cases, an exact match may not exist. Consider seeking a mathematically equivalent approach or notifying the maintainer for further analysis. - **PyTorch**:
We could also read the IO spec from [function declaration](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/layer_norm.cpp) mentioned in [PyTorch Documentation](#pytorch): @@ -168,7 +172,7 @@ The content should have exact match with literal values mentioned in [Qualcomm A Next, create a new file with name in snake case format (e.g. `op_layer_norm.py`) and import required modules (please check comments for getting the ideas of usage): ```python # pybind interface for invoking QNN APIs -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager # tensors or other numerics will be shipped in numpy format import numpy as np import torch @@ -195,8 +199,8 @@ class LayerNormVisitor(NodeVisitor): def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: ``` It's mandatory to have `target` member in list form, since there would have multiple targets map to the same implementation. e.g. `aten.leaky_relu.default`, `aten.prelu.default` have similar equations but only differ in negative slope.
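As a minimal sketch of that convention, a single visitor can register for both targets and share one `define_node` implementation (the class name here is hypothetical; the decorator, base class, and imports follow the builder modules in this package):
```python
# Hypothetical example: one visitor serving two operator targets.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
import torch

from .node_visitor import NodeVisitor
from .node_visitor_manager import register_node_visitor


@register_node_visitor
class ReluFamilyVisitor(NodeVisitor):
    # Both ops share the same lowering logic, so both targets map to this class.
    target = ["aten.leaky_relu.default", "aten.prelu.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper],
    ) -> PyQnnManager.PyQnnOpWrapper:
        ...  # build input/output tensors and the op wrapper as in the steps below
```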
The `nodes_to_wrappers` argument is a dictionary maintaining the relationship between a graph node and its output tensor; it acts as a memo so that tensor objects are not re-created for nodes that have already been traversed.
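A rough sketch of that memoization pattern, mirroring how `define_tensor` consults the dictionary (assuming `nodes_to_wrappers` is a `defaultdict(dict)` keyed by node name, and `make_tensor_wrapper` is a hypothetical stand-in for the real wrapper creation):
```python
from collections import defaultdict

def get_or_create_wrapper(node_name, nodes_to_wrappers, wrapper_idx=0):
    # nodes_to_wrappers: node name -> {wrapper_idx: TensorWrapper}
    if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
        return cached  # node already traversed: reuse its wrapped tensor
    tensor_wrapper = make_tensor_wrapper(node_name)  # hypothetical creation helper
    nodes_to_wrappers[node_name][wrapper_idx] = tensor_wrapper
    return tensor_wrapper

nodes_to_wrappers = defaultdict(dict)
```
The `wrapper_idx` matters for nodes with multiple outputs (e.g. `topk`), where each output gets its own wrapped tensor.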
@@ -210,7 +214,7 @@ Now, we can start to fill in function body step by step: input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) ``` @@ -221,7 +225,7 @@ Now, we can start to fill in function body step by step: - **tensor_source_node**: current graph source node of the tensor - **target_build_node**: current node to build, which is important for fixed point mixed-precision to work properly - **tensor**: torch tensor emitted by node - - **tensor_type**: type compatible with QNN SDK, oftenly use `QNN_TENSOR_TYPE_NATIVE` for intermediate outputs and `QNN_TENSOR_TYPE_STATIC` for constant parameters + - **tensor_type**: type compatible with QNN SDK, often use `QNN_TENSOR_TYPE_NATIVE` for intermediate outputs and `QNN_TENSOR_TYPE_STATIC` for constant parameters - **nodes_to_wrappers**: dictionary of graph node and its output tensor (note: the tensor here is not a torch tensor but a wrapped object for QNN) - **node_name**: (optional) tensor name for user to specify - **wrapper_idx**: (optional) defaults to zero if node is not a tuple, otherwise it acts as an indexer to output tensors. e.g. when slicing input tensor into multiple outputs, `wrapper_idx` is necessary for getting correct wrapped tensor object @@ -234,7 +238,7 @@ Now, we can start to fill in function body step by step: weight_node, node, weight_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, ) @@ -244,7 +248,7 @@ Now, we can start to fill in function body step by step: bias_node, node, bias_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, ) ``` @@ -272,15 +276,15 @@ Now, we can start to fill in function body step by step: node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) ``` - Althought the input / output activations might map to the graph IOs (a.k.a. user inputs / outputs) with corresponding type `QNN_TENSOR_TYPE_APP_READ` / `QNN_TENSOR_TYPE_APP_WRITE`. Users are still expected to have `QNN_TENSOR_TYPE_NATIVE` for all nodes' IOs and leave the detection logic handled inside `define_tensor` method. + Although the input / output activations might map to the graph IOs (a.k.a. user inputs / outputs) with corresponding type `QNN_TENSOR_TYPE_APP_READ` / `QNN_TENSOR_TYPE_APP_WRITE`. Users are still expected to have `QNN_TENSOR_TYPE_NATIVE` for all nodes' IOs and leave the detection logic handled inside `define_tensor` method. 5. 
Generate operator object in QNN graph: ```python - layer_norm_op = PyQnnWrapper.PyQnnOpWrapper( + layer_norm_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpLayerNorm.op_name, @@ -300,12 +304,12 @@ Now, we can start to fill in function body step by step: ```python layer_norm_op.AddScalarParam( OpLayerNorm.param_epsilon, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, {QCOM_DATA: np.float32(epsilon)}, ) layer_norm_op.AddTensorParam( OpLayerNorm.param_axes, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, len(axis_shape), axis_shape, np.array(axis, dtype=np.uint32), @@ -326,7 +330,7 @@ Now, we can start to fill in function body step by step: - **data_type**: type compatible with QNN SDK, e.g. `QNN_DATATYPE_FLOAT_32`, `QNN_DATATYPE_UINT_32`, etc. - **rank**: dimensions of tensor - **dims**: shape of tensor - - **data**: tesnor data + - **data**: tensor data - **copy_data**: user should specify to True for constant parameters 8. Last, return operator object for partitioner to conduct validation: @@ -365,7 +369,7 @@ Please help update following table if you are contributing new operators: + 🚫 = Deprecated, supported with other QNN Ops -| Operators | HTP - 90/116 Enabled | +| Operators | HTP - 94/116 Enabled | |-----------|---------| | Argmax | ✓ | | Argmin | ✓ | @@ -375,7 +379,7 @@ Please help update following table if you are contributing new operators: | ChannelShuffle | ✗ | | Concat | ✓ | | Conv2d | ✓ | -| Conv3d | ✗ | +| Conv3d | ✓ | | Convert | ✓ | | CreateSparse | ✗ | | CumulativeSum | ✓ | @@ -431,7 +435,7 @@ Please help update following table if you are contributing new operators: | Gelu | ✓ | | GetSparseIndices | ✗ | | GetSparseValues | ✗ | -| GridSample | ✗ | +| GridSample | ✓ | | GroupNorm | ✓ | | HardSwish | ✓ | | InstanceNorm | ✓ | @@ -448,7 +452,7 @@ Please help update following table if you are contributing new operators: | Pack | ✓ | | Pad | ✓ | | PoolAvg2d | ✓ | -| PoolAvg3d | ✗ | +| PoolAvg3d | ✓ | | PoolMax2d | ✓ | | Prelu | ✓ | | Quantize | ✓ | @@ -481,7 +485,7 @@ Please help update following table if you are contributing new operators: | TopK | ✓ | | TransPose | ✓ | | TransPoseConv2d | ✓ | -| TransPoseConv3d | ✗ | +| TransPoseConv3d | ✓ | | Unpack | ✓ | ## Issues diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index 9800fb7bdab..e982985477d 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -8,6 +8,7 @@ node_visitor, op_abs, op_adaptive_avg_pool2d, + op_adaptive_max_pool2d, op_add, op_amax, op_amin, @@ -18,13 +19,14 @@ op_asin, op_atan, op_avg_pool2d, + op_avg_pool3d, op_batch_norm, op_binary, op_bmm, op_cat, op_ceil, op_clamp, - op_conv2d, + op_conv, op_copy, op_cos, op_cum_sum, @@ -43,6 +45,7 @@ op_gather, op_ge, op_gelu, + op_grid_sampler_2d, op_group_norm, op_gt, op_hardsigmoid, @@ -113,6 +116,7 @@ node_visitor, op_abs, op_adaptive_avg_pool2d, + op_adaptive_max_pool2d, op_add, op_amax, op_amin, @@ -123,13 +127,14 @@ op_asin, op_atan, op_avg_pool2d, + op_avg_pool3d, op_batch_norm, op_binary, op_bmm, op_cat, op_ceil, op_clamp, - op_conv2d, + op_conv, op_copy, op_cos, op_cum_sum, @@ -148,6 +153,7 @@ op_gather, op_ge, op_gelu, + op_grid_sampler_2d, op_group_norm, op_gt, op_hardswish, diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index bc2b62c8c0b..f2fcf65c896 100644 --- 
a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -7,7 +7,7 @@ import copy from typing import Any, Dict, Tuple -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -30,6 +30,7 @@ QCOM_SCALE, QCOM_SCALE_OFFSET, QCOM_SCALES, + QCOM_TENSOR_NAME, QCOM_ZERO_POINT, QCOM_ZERO_POINTS, ) @@ -48,28 +49,28 @@ QNN_QUANT_TYPE_MAP = { - torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_8, - torch.int16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_16, - torch.int32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_32, + torch.int8: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_8, + torch.int16: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_16, + torch.int32: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_32, # Note that there is no int64 tensor data type in Qnn. - torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UNDEFINED, - torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8, - torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16, + torch.int64: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UNDEFINED, + torch.uint8: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8, + torch.uint16: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16, } QNN_TENSOR_TYPE_MAP = { - torch.bool: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, - torch.float32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, + torch.bool: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + torch.float32: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, # Note that there is no float64 tensor data type in Qnn. - torch.float64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, - torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_8, - torch.int16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_16, - torch.int32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32, - torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_64, - torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_8, - torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16, - torch.uint32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - float: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, - int: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + torch.float64: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, + torch.int8: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_INT_8, + torch.int16: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_INT_16, + torch.int32: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_INT_32, + torch.int64: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_INT_64, + torch.uint8: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_8, + torch.uint16: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_16, + torch.uint32: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + float: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, + int: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, } PER_CHANNEL_ENCODING = { @@ -153,32 +154,40 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict): scales, scale_offset, quantized_scales = quant_attrs[QCOM_SCALE], [], [] # channel in observers defaults to zero num_channels = node.meta["val"].shape[0] + user_0 = self.get_first_user(node) + + ch_axis = 0 + # args[6] to check if it is transpose conv + if user_0.target == exir_ops.edge.aten.convolution.default and user_0.args[6]: + num_channels = node.meta["val"].shape[1] + ch_axis 
= 1 # TODO: expand this when QNN starts to support more configurations bitwidth_of_scale = 4 quant_scales_dtype = torch.uint8 num_steps = 2**bitwidth_of_scale scale_storage_type = ( - PyQnnWrapper.Qnn_BlockwiseExpansionBlockScaleStorageType_t.QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8 + PyQnnManager.Qnn_BlockwiseExpansionBlockScaleStorageType_t.QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8 ) for ch in range(num_channels): - max_scale = scales[ch].reshape(1, -1).amax(dim=-1) / num_steps + candidates = scales[ch] if ch_axis == 0 else scales[:, ch, ...] + max_scale = candidates.reshape(1, -1).amax(dim=-1) / num_steps q_scales = torch.clamp( - input=torch.round(input=scales[ch] / max_scale), + input=torch.round(input=candidates / max_scale), min=1, max=2**bitwidth_of_scale, ).to(quant_scales_dtype) quantized_scales.append(q_scales) # symmetric quantization is required - scale_offset.append(PyQnnWrapper.Qnn_ScaleOffset_t(max_scale, 0)) + scale_offset.append(PyQnnManager.Qnn_ScaleOffset_t(max_scale, 0)) # skip dequantize op, e.g. frozen_param -> dq -> conv2d user_0 = self.get_first_user(node) - if "convolution" in user_0.target.__name__: + if user_0.target == exir_ops.edge.aten.convolution.default: # OIHW (pytorch) -> HWIO (QNN) - quant_config[QCOM_AXIS] = 3 + quant_config[QCOM_AXIS] = node.meta["val"].dim() - 1 quant_config[QCOM_AXIS_ORDER] = (2, 3, 1, 0) - elif "linear" in user_0.target.__name__: + elif user_0.target == exir_ops.edge.aten.linear.default: # OI (pytorch) -> OI (QNN) quant_config[QCOM_AXIS] = 0 quant_config[QCOM_AXIS_ORDER] = (0, 1) @@ -194,7 +203,7 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict): ) quant_config[QCOM_BLOCK_STORAGE_TYPE] = scale_storage_type return ( - PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION, + PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION, quant_config, ) @@ -211,14 +220,14 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict): for i in range(len(scales)): # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h scale_offset.append( - PyQnnWrapper.Qnn_ScaleOffset_t(scales[i], -zero_points[i]) + PyQnnManager.Qnn_ScaleOffset_t(scales[i], -zero_points[i]) ) # skip dequantize op, e.g. frozen_param -> dq -> conv2d user_0 = self.get_first_user(node) # Memory layout of QNN conv weight always ends in Output. 
Like conv2d is HWIO - if "convolution" in user_0.target.__name__: - quant_config[QCOM_AXIS] = 3 + if user_0.target == exir_ops.edge.aten.convolution.default: + quant_config[QCOM_AXIS] = node.meta["val"].dim() - 1 else: quant_config[QCOM_AXIS] = quant_attrs[QCOM_AXIS] @@ -230,11 +239,11 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict): ): quant_config[QCOM_BITWIDTH] = 4 return ( - PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET, + PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET, quant_config, ) return ( - PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET, + PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET, quant_config, ) @@ -249,11 +258,11 @@ def make_qnn_per_tensor_config(self, quant_attrs: Dict): ): quant_config[QCOM_BITWIDTH] = 4 return ( - PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET, + PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET, quant_config, ) return ( - PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET, + PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET, quant_config, ) @@ -262,7 +271,7 @@ def get_quant_encoding_conf( ) -> Tuple[Any, Dict]: if not node.meta.get(QCOM_QUANT_ATTRS, None): return ( - PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED, + PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED, {}, ) is_input_tensor = node != target_node @@ -313,8 +322,8 @@ def get_quant_tensor_value( def get_tensor_type( self, node: torch.fx.Node, - tensor_type: PyQnnWrapper.Qnn_TensorType_t, - ) -> PyQnnWrapper.Qnn_TensorType_t: + tensor_type: PyQnnManager.Qnn_TensorType_t, + ) -> PyQnnManager.Qnn_TensorType_t: is_input = is_graph_input(node, self.edge_program) or is_mutable_buffer_input( node, self.edge_program ) @@ -325,25 +334,25 @@ def get_tensor_type( node in self.external_ids ), f"Node {node}, is_input: {is_input}, is_output: {is_output}, ext_ids: {self.external_ids.keys()}" if is_input: - return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_WRITE + return PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_WRITE if is_output: - return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_READ + return PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_READ if is_parameter(node, self.edge_program): - return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC + return PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC # dump all tensor, set to app read, and we only dump native tensors if ( self.enable_tensor_dump - and tensor_type == PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + and tensor_type == PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE ): - return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_READ + return PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_READ return tensor_type def get_data_type( self, tensor: torch.Tensor, quant_config: Dict, - ) -> PyQnnWrapper.Qnn_TensorType_t: + ) -> PyQnnManager.Qnn_TensorType_t: if quant_config: quant_config[QCOM_DTYPE] = deduce_dtype(tensor, quant_config) return QNN_QUANT_TYPE_MAP[quant_config[QCOM_DTYPE]] @@ -387,26 +396,33 @@ def get_tensor_name( tensor_name = f"output_mutbuf_{position_index}_{tensor_name}" elif is_graph_output(node): tensor_name = f"output_{tensor_name}" + + # Save this for intermediate debugger + # Needs idx since node like topk has 2 
outputs + if QCOM_TENSOR_NAME in node.meta: + node.meta[QCOM_TENSOR_NAME][wrapper_idx] = tensor_name + else: + node.meta[QCOM_TENSOR_NAME] = {wrapper_idx: tensor_name} return tensor_name def define_custom_tensor_wrapper( self, node_name: str, - tensor_type: PyQnnWrapper.Qnn_TensorType_t, - dtype: PyQnnWrapper.Qnn_DataType_t, - quant_encoding: PyQnnWrapper.Qnn_QuantizationEncoding_t, + tensor_type: PyQnnManager.Qnn_TensorType_t, + dtype: PyQnnManager.Qnn_DataType_t, + quant_encoding: PyQnnManager.Qnn_QuantizationEncoding_t, quant_configs: dict, dims: torch.Size, tensor: torch.Tensor, is_fake_tensor: bool, - nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]], + nodes_to_wrappers: Dict[str, Dict[int, PyQnnManager.TensorWrapper]], wrapper_idx: int = 0, - ) -> PyQnnWrapper.TensorWrapper: + ) -> PyQnnManager.TensorWrapper: if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None): return cached if is_fake_tensor: dynamic_dims, nominal_dims = self.get_dynamic_dimension(dims) - tensor_wrapper = PyQnnWrapper.TensorWrapper( + tensor_wrapper = PyQnnManager.TensorWrapper( node_name, tensor_type, dtype, @@ -429,11 +445,11 @@ def define_tensor( tensor_source_node: torch.fx.Node, target_build_node: torch.fx.Node, tensor: torch.Tensor, - tensor_type: PyQnnWrapper.Qnn_TensorType_t, - nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]], + tensor_type: PyQnnManager.Qnn_TensorType_t, + nodes_to_wrappers: Dict[str, Dict[int, PyQnnManager.TensorWrapper]], node_name: str = None, wrapper_idx: int = 0, - ) -> PyQnnWrapper.TensorWrapper: + ) -> PyQnnManager.TensorWrapper: """ Covert torch.Tensor to TensorWrapper @@ -459,7 +475,7 @@ def define_tensor( ) dtype = self.get_data_type(tensor, quant_configs) if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor): - tensor_wrapper = PyQnnWrapper.TensorWrapper( + tensor_wrapper = PyQnnManager.TensorWrapper( tensor_name, tensor_type, dtype, @@ -478,7 +494,7 @@ def define_tensor( tensor_source_node.meta[QCOM_QUANT_ATTRS], quant_configs, ) - tensor_wrapper = PyQnnWrapper.TensorWrapper( + tensor_wrapper = PyQnnManager.TensorWrapper( tensor_name, tensor_type, dtype, @@ -496,7 +512,7 @@ def define_tensor( def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[str, Dict[int, PyQnnManager.TensorWrapper]], + ) -> PyQnnManager.PyQnnOpWrapper: """Convert torch.fx.Node to OpWrapper""" raise NotImplementedError("NodeVisitor must be extended!") diff --git a/backends/qualcomm/builders/op_abs.py b/backends/qualcomm/builders/op_abs.py index 1df49b88912..741ba3fa66d 100644 --- a/backends/qualcomm/builders/op_abs.py +++ b/backends/qualcomm/builders/op_abs.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,14 +24,14 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: out_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) abs_output_tensors = [output_tensor_wrapper] @@ -41,12 +41,12 @@ def define_node( input_node, node, self.get_tensor(input_node, node), - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) abs_input_tensors = [input_tensor_wrapper] - abs_op = PyQnnWrapper.PyQnnOpWrapper( + abs_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseAbs.op_name, diff --git a/backends/qualcomm/builders/op_adaptive_avg_pool2d.py b/backends/qualcomm/builders/op_adaptive_avg_pool2d.py index 1b0d58482ec..5081fb150f2 100644 --- a/backends/qualcomm/builders/op_adaptive_avg_pool2d.py +++ b/backends/qualcomm/builders/op_adaptive_avg_pool2d.py @@ -6,7 +6,7 @@ import warnings from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -26,8 +26,8 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) @@ -35,7 +35,7 @@ def define_node( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -74,11 +74,11 @@ def define_node( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - adaptive_avg_pool2d_op = PyQnnWrapper.PyQnnOpWrapper( + adaptive_avg_pool2d_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpPoolAvg2d.op_name, @@ -89,7 +89,7 @@ def define_node( adaptive_avg_pool2d_op.AddTensorParam( OpPoolAvg2d.param_filter_size, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, len(filter_shape), filter_shape, np.array( @@ -101,7 +101,7 @@ def define_node( adaptive_avg_pool2d_op.AddTensorParam( OpPoolAvg2d.param_stride, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, len(stride_shape), stride_shape, np.array( @@ -113,7 +113,7 @@ def define_node( adaptive_avg_pool2d_op.AddTensorParam( OpPoolAvg2d.param_pad_amount, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, len(padding_shape), padding_shape, np.array( diff --git a/backends/qualcomm/builders/op_adaptive_max_pool2d.py 
b/backends/qualcomm/builders/op_adaptive_max_pool2d.py new file mode 100644 index 00000000000..7513a0cb6b3 --- /dev/null +++ b/backends/qualcomm/builders/op_adaptive_max_pool2d.py @@ -0,0 +1,151 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import warnings +from typing import cast, Dict, List + +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager +import numpy as np + +import torch +from executorch.backends.qualcomm.utils.constants import QCOM_DATA + +from .node_visitor import NodeVisitor +from .node_visitor_manager import register_node_visitor +from .qnn_constants import OpPoolMax2d, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class AdaptiveMaxPool2D(NodeVisitor): + target = ["aten.adaptive_max_pool2d.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: + input_node = self.get_node(node.args[0]) + input_tensor = self.get_tensor(input_node, node) + input_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + users = list(node.users.keys()) + for user in users: + if user.target.__name__ == "getitem": + getitem_index = user.args[1] + if getitem_index != 0: + warnings.warn( + f"[QNN Delegate Op Builder]: Expected second argument of getitem node for {node.target.__name__ } to be 0, got {getitem_index}", + stacklevel=1, + ) + return + + if len(node.args) > 2: + warnings.warn( + "[QNN Delegate Op Builder]: The return_indices is not supported, fallback op", + stacklevel=1, + ) + return + + input_height = input_tensor.shape[1] + input_width = input_tensor.shape[2] + # output cases + out_wh = cast(List[int], node.args[1]) + if len(out_wh) == 1: + output_height = node.args[1][0] + output_width = node.args[1][0] + else: + output_height = node.args[1][0] + output_width = node.args[1][1] + if output_height is None: + output_height = input_height + if output_width is None: + output_width = input_width + # NOTE: Here we need not to emphasize on mode, cuz the output shape is decided by user. 
+ mode = OpPoolMax2d.RoundingMode.FLOOR + + # floor division + stride_height = input_height // output_height + filter_height = input_height - (output_height - 1) * stride_height + stride_width = input_width // output_width + filter_width = input_width - (output_width - 1) * stride_width + + filter = [filter_height, filter_width] + filter_shape = [len(filter)] + + stride = [stride_height, stride_width] + stride_shape = [len(stride)] + + padding = [0, 0] + padding_shape = [len(padding), len(padding)] + + out_tensor = self.get_tensor(node, node, 0) + output_tensor_wrapper = self.define_tensor( + node, + node, + out_tensor, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + adaptive_max_pool2d_op = PyQnnManager.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpPoolMax2d.op_name, + ) + + adaptive_max_pool2d_op.AddInputTensors([input_tensor_wrapper]) + adaptive_max_pool2d_op.AddOutputTensors([output_tensor_wrapper]) + + adaptive_max_pool2d_op.AddTensorParam( + OpPoolMax2d.param_filter_size, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(filter_shape), + filter_shape, + np.array( + filter, + dtype=np.uint32, + ), + True, + ) + + adaptive_max_pool2d_op.AddTensorParam( + OpPoolMax2d.param_stride, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(stride_shape), + stride_shape, + np.array( + stride, + dtype=np.uint32, + ), + True, + ) + + adaptive_max_pool2d_op.AddTensorParam( + OpPoolMax2d.param_pad_amount, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(padding_shape), + padding_shape, + np.array( + [[padding[0], padding[0]], [padding[1], padding[1]]], + dtype=np.uint32, + ), + True, + ) + + adaptive_max_pool2d_op.AddScalarParam( + OpPoolMax2d.param_rounding_mode, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + {QCOM_DATA: np.uint32(mode)}, + ) + + return adaptive_max_pool2d_op diff --git a/backends/qualcomm/builders/op_add.py b/backends/qualcomm/builders/op_add.py index d2f4a39fc3d..3c1c9beb79e 100644 --- a/backends/qualcomm/builders/op_add.py +++ b/backends/qualcomm/builders/op_add.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
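Note on the new op_adaptive_max_pool2d.py builder above: it lowers the adaptive op onto QNN's fixed-window PoolMax2d by deriving a stride and filter size from the requested output size. A minimal standalone sketch of that arithmetic follows (the helper name is illustrative, not part of this patch); the fixed-window reformulation matches PyTorch's adaptive pooling exactly only when the derived windows tile the input evenly.

def adaptive_pool_params(input_size: int, output_size: int) -> tuple:
    # Mirrors the derivation in the builder: floor-divide for the stride,
    # then size the filter so the last window ends exactly at the input edge.
    stride = input_size // output_size
    kernel = input_size - (output_size - 1) * stride
    return kernel, stride

# e.g. a 7x7 feature map pooled to 3x3 uses kernel 3 / stride 2,
# and a 10x10 map pooled to 5x5 uses kernel 2 / stride 2.
assert adaptive_pool_params(7, 3) == (3, 2)
assert adaptive_pool_params(10, 5) == (2, 2)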
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,14 +24,14 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: out_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) add_output_tensors = [output_tensor_wrapper] @@ -40,7 +40,7 @@ def define_node( for index in range(2): input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) - tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + tensor_type = PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE input_tensor_wrapper = self.define_tensor( input_node, @@ -51,7 +51,7 @@ def define_node( ) add_input_tensors.append(input_tensor_wrapper) - add_op = PyQnnWrapper.PyQnnOpWrapper( + add_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseAdd.op_name, diff --git a/backends/qualcomm/builders/op_amax.py b/backends/qualcomm/builders/op_amax.py index d0335f95463..5305528996b 100644 --- a/backends/qualcomm/builders/op_amax.py +++ b/backends/qualcomm/builders/op_amax.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import cast, Dict, List -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np @@ -27,15 +27,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -60,11 +60,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - reduce_max_op = PyQnnWrapper.PyQnnOpWrapper( + reduce_max_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpReduceMax.op_name, @@ -73,7 +73,7 @@ def define_node( reduce_max_op.AddOutputTensors([output_tensor_wrapper]) reduce_max_op.AddTensorParam( OpReduceMax.param_axes, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, len(mean_dims_shape), mean_dims_shape, np.array(mean_dims, dtype=np.uint32), @@ -83,7 +83,7 @@ def define_node( keep_dims = cast(bool, node.args[2]) reduce_max_op.AddScalarParam( OpReduceMax.param_keep_dims, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, {QCOM_DATA: keep_dims}, ) diff --git a/backends/qualcomm/builders/op_amin.py b/backends/qualcomm/builders/op_amin.py index 142340dbae0..c8589591a1a 100644 
--- a/backends/qualcomm/builders/op_amin.py +++ b/backends/qualcomm/builders/op_amin.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import cast, Dict, List -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np @@ -27,15 +27,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -60,11 +60,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - reduce_min_op = PyQnnWrapper.PyQnnOpWrapper( + reduce_min_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpReduceMin.op_name, @@ -73,7 +73,7 @@ def define_node( reduce_min_op.AddOutputTensors([output_tensor_wrapper]) reduce_min_op.AddTensorParam( OpReduceMin.param_axes, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, len(mean_dims_shape), mean_dims_shape, np.array(mean_dims, dtype=np.uint32), @@ -83,7 +83,7 @@ def define_node( keep_dims = cast(bool, node.args[2]) reduce_min_op.AddScalarParam( OpReduceMin.param_keep_dims, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, {QCOM_DATA: keep_dims}, ) diff --git a/backends/qualcomm/builders/op_and.py b/backends/qualcomm/builders/op_and.py index 9e43b4df5b2..b3736025414 100644 --- a/backends/qualcomm/builders/op_and.py +++ b/backends/qualcomm/builders/op_and.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -16,7 +16,7 @@ @register_node_visitor class OpAnd(NodeVisitor): - target = ["aten.bitwise_and.Tensor"] + target = ["aten.bitwise_and.Tensor", "aten.logical_and.default"] def __init__(self, *args) -> None: super().__init__(*args) @@ -24,14 +24,14 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: out_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) and_output_tensors = [output_tensor_wrapper] @@ -40,7 +40,7 @@ def define_node( for index in range(2): input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) - tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + tensor_type = PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE input_tensor_wrapper = self.define_tensor( input_node, @@ -50,7 +50,7 @@ def define_node( nodes_to_wrappers, ) and_input_tensors.append(input_tensor_wrapper) - and_op = PyQnnWrapper.PyQnnOpWrapper( + and_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseAnd.op_name, diff --git a/backends/qualcomm/builders/op_arange.py b/backends/qualcomm/builders/op_arange.py index e8c4c7d5267..0a95d55dca3 100644 --- a/backends/qualcomm/builders/op_arange.py +++ b/backends/qualcomm/builders/op_arange.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -23,8 +23,8 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: start, end = node.args[0:2] step = node.args[2] if len(node.args) > 2 else 1 out_tensor = torch.arange(start, end, step) @@ -36,6 +36,6 @@ def define_node( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, ) diff --git a/backends/qualcomm/builders/op_argmax.py b/backends/qualcomm/builders/op_argmax.py index e81b0dd1d95..60293e0c104 100644 --- a/backends/qualcomm/builders/op_argmax.py +++ b/backends/qualcomm/builders/op_argmax.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
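Note on op_and.py above: it now registers both aten.bitwise_and.Tensor and aten.logical_and.default against the same OpElementWiseAnd visitor. For boolean inputs the two PyTorch ops agree, which is what makes sharing a single lowering reasonable; a quick self-contained check:

import torch

a = torch.tensor([True, False, True, False])
b = torch.tensor([True, True, False, False])
# On bool tensors bitwise_and and logical_and produce identical results,
# so one QNN ElementWiseAnd lowering can serve both targets.
assert torch.equal(torch.bitwise_and(a, b), torch.logical_and(a, b))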
from typing import cast, Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA @@ -25,8 +25,8 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) output_tensor = self.get_tensor(node, node) @@ -34,7 +34,7 @@ def define_node( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) argmax_input_tensors = [argmax_inp_tensor_wrapper] @@ -42,7 +42,7 @@ def define_node( node, node, output_tensor.to(torch.int32), - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) argmax_output_tensors = [argmax_out_tensor_wrapper] @@ -53,7 +53,7 @@ def define_node( if QCOM_AXIS_ORDER in node.meta: dim = node.meta[QCOM_AXIS_ORDER].index(dim) - argmax_op = PyQnnWrapper.PyQnnOpWrapper( + argmax_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpArgmax.op_name, @@ -63,7 +63,7 @@ def define_node( argmax_op.AddScalarParam( OpArgmax.param_axis, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {QCOM_DATA: np.uint32(dim)}, ) @@ -71,7 +71,7 @@ def define_node( keep_dims = cast(bool, node.args[2]) argmax_op.AddScalarParam( OpArgmax.param_keep_dims, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, {QCOM_DATA: keep_dims}, ) diff --git a/backends/qualcomm/builders/op_argmin.py b/backends/qualcomm/builders/op_argmin.py index a9fa2021bb0..d66a4300b96 100644 --- a/backends/qualcomm/builders/op_argmin.py +++ b/backends/qualcomm/builders/op_argmin.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import cast, Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA @@ -25,8 +25,8 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) output_tensor = self.get_tensor(node, node) @@ -34,7 +34,7 @@ def define_node( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) argmin_input_tensors = [argmin_inp_tensor_wrapper] @@ -42,7 +42,7 @@ def define_node( node, node, output_tensor.to(torch.int32), - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) argmin_output_tensors = [argmin_out_tensor_wrapper] @@ -53,7 +53,7 @@ def define_node( if QCOM_AXIS_ORDER in node.meta: dim = node.meta[QCOM_AXIS_ORDER].index(dim) - argmin_op = PyQnnWrapper.PyQnnOpWrapper( + argmin_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpArgmin.op_name, @@ -63,7 +63,7 @@ def define_node( argmin_op.AddScalarParam( OpArgmin.param_axis, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {QCOM_DATA: np.uint32(dim)}, ) @@ -71,7 +71,7 @@ def define_node( keep_dims = cast(bool, node.args[2]) argmin_op.AddScalarParam( OpArgmin.param_keep_dims, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, {QCOM_DATA: keep_dims}, ) diff --git a/backends/qualcomm/builders/op_asin.py b/backends/qualcomm/builders/op_asin.py index ff50380e62c..1907890c69a 100644 --- a/backends/qualcomm/builders/op_asin.py +++ b/backends/qualcomm/builders/op_asin.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree.from typing import cast, Dict from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch from .node_visitor import NodeVisitor @@ -24,15 +24,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -41,11 +41,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - asin_op = PyQnnWrapper.PyQnnOpWrapper( + asin_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseAsin.op_name, diff --git a/backends/qualcomm/builders/op_atan.py 
b/backends/qualcomm/builders/op_atan.py index 83c47b9103d..f208f1bedb5 100644 --- a/backends/qualcomm/builders/op_atan.py +++ b/backends/qualcomm/builders/op_atan.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch from .node_visitor import NodeVisitor @@ -23,15 +23,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -40,11 +40,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - atan_op = PyQnnWrapper.PyQnnOpWrapper( + atan_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseAtan.op_name, diff --git a/backends/qualcomm/builders/op_avg_pool2d.py b/backends/qualcomm/builders/op_avg_pool2d.py index 6e0f70474ea..4e44c333f6e 100644 --- a/backends/qualcomm/builders/op_avg_pool2d.py +++ b/backends/qualcomm/builders/op_avg_pool2d.py @@ -6,7 +6,7 @@ import warnings from typing import cast, Dict, List -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -33,15 +33,15 @@ def _get_filter_size(self, node): def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -50,7 +50,7 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -112,7 +112,7 @@ def define_node( ) return - avg_pool2d_op = PyQnnWrapper.PyQnnOpWrapper( + avg_pool2d_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpPoolAvg2d.op_name, @@ -122,7 +122,7 @@ def define_node( avg_pool2d_op.AddTensorParam( OpPoolAvg2d.param_filter_size, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, len(filter_size_shape), filter_size_shape, np.array( @@ -133,7 +133,7 @@ def define_node( ) avg_pool2d_op.AddTensorParam( OpPoolAvg2d.param_stride, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, len(stride_shape), stride_shape, np.array( @@ -144,7 +144,7 @@ def define_node( ) avg_pool2d_op.AddTensorParam( OpPoolAvg2d.param_pad_amount, - 
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, len(padding_shape), padding_shape, np.array( @@ -156,12 +156,12 @@ def define_node( avg_pool2d_op.AddScalarParam( OpPoolAvg2d.param_rounding_mode, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {QCOM_DATA: np.uint32(mode)}, ) avg_pool2d_op.AddScalarParam( OpPoolAvg2d.param_count_pad_for_edges, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, {QCOM_DATA: count_include_pad}, ) diff --git a/backends/qualcomm/builders/op_avg_pool3d.py b/backends/qualcomm/builders/op_avg_pool3d.py new file mode 100644 index 00000000000..5e27ce2b4c1 --- /dev/null +++ b/backends/qualcomm/builders/op_avg_pool3d.py @@ -0,0 +1,299 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import warnings +from typing import cast, Dict, List + +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager +import numpy as np + +import torch +from executorch.backends.qualcomm.utils.constants import QCOM_DATA + +from .node_visitor import NodeVisitor +from .node_visitor_manager import register_node_visitor +from .qnn_constants import OpPoolAvg3d, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class AvgPool3d(NodeVisitor): + target = ["aten.avg_pool3d.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: + + input_node = self.get_node(node.args[0]) + input_tensor = self.get_tensor(input_node, node) + input_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + # kernel info + filter_size = cast(List[int], node.args[1]) + if len(filter_size) == 1: + filter_size *= 3 + filter_size_shape = [len(filter_size)] + + # stride info + stride = cast(List[int], node.args[2]) + if len(stride) == 1: + stride *= 3 + stride_shape = [len(stride)] + + # padding info + padding = [0, 0, 0] + if len(node.args) > 3: + padding = cast(List[int], node.args[3]) + if len(padding) == 1: + padding *= 3 + + # if ceil mode is True, use ceil instead of floor to compute the output shape + mode = OpPoolAvg3d.RoundingMode.FLOOR + if len(node.args) > 4: + ceil_mode = cast(bool, node.args[4]) + if ceil_mode: + mode = OpPoolAvg3d.RoundingMode.CEIL + + count_pad_for_edges = node.args[5] if len(node.args) > 5 else False + + # pad left, pad right + depth_pad_l = padding[0] + depth_pad_r = padding[0] + height_pad_l = padding[1] + height_pad_r = padding[1] + width_pad_l = padding[2] + width_pad_r = padding[2] + + shape_pad = [ + [depth_pad_l, depth_pad_r], + [height_pad_l, height_pad_r], + [width_pad_l, width_pad_r], + ] + padding_shape = [len(shape_pad), len(shape_pad[0])] + + out_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + node, + out_tensor, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + avg_pool3d_op = PyQnnManager.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpPoolAvg3d.op_name, + ) + + avg_pool3d_op.AddInputTensors([input_tensor_wrapper]) + 
avg_pool3d_op.AddOutputTensors([output_tensor_wrapper]) + + avg_pool3d_op.AddTensorParam( + OpPoolAvg3d.param_filter_size, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(filter_size_shape), + filter_size_shape, + np.array( + filter_size, + dtype=np.uint32, + ), + True, + ) + + avg_pool3d_op.AddTensorParam( + OpPoolAvg3d.param_stride, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(stride_shape), + stride_shape, + np.array( + stride, + dtype=np.uint32, + ), + True, + ) + + avg_pool3d_op.AddTensorParam( + OpPoolAvg3d.param_pad_amount, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(padding_shape), + padding_shape, + np.array( + shape_pad, + dtype=np.uint32, + ), + True, + ) + + avg_pool3d_op.AddScalarParam( + OpPoolAvg3d.param_count_pad_for_edges, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + {QCOM_DATA: count_pad_for_edges}, + ) + + avg_pool3d_op.AddScalarParam( + OpPoolAvg3d.param_rounding_mode, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + {QCOM_DATA: np.uint32(mode)}, + ) + + return avg_pool3d_op + + +@register_node_visitor +class AdaptiveAvgPool3d(NodeVisitor): + target = ["aten._adaptive_avg_pool3d.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: + input_node = self.get_node(node.args[0]) + input_tensor = self.get_tensor(input_node, node) + input_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + # NOTE: This operator is layout sensitive, so the input tensor shape is always N,D,H,W,C. + input_depth = input_tensor.shape[1] + input_height = input_tensor.shape[2] + input_width = input_tensor.shape[3] + output_depth = node.args[1][0] + output_height = node.args[1][1] + output_width = node.args[1][2] + if output_depth is None: + output_depth = input_depth + if output_height is None: + output_height = input_height + if output_width is None: + output_width = input_width + + # kernel info & stride info + stride_height = input_height // output_height + filter_height = input_height - (output_height - 1) * stride_height + stride_width = input_width // output_width + filter_width = input_width - (output_width - 1) * stride_width + stride_depth = input_depth // output_depth + filter_depth = input_depth - (output_depth - 1) * stride_depth + + filter_size = [filter_depth, filter_height, filter_width] + filter_shape = [len(filter_size)] + stride = [stride_depth, stride_height, stride_width] + stride_shape = [len(stride)] + + depth = (output_depth - 1) * stride_depth + filter_depth - input_depth + height = (output_height - 1) * stride_height + filter_height - input_height + width = (output_width - 1) * stride_width + filter_width - input_width + + if any(x != 0 for x in (depth, height, width)): + warnings.warn( + "[QNN Delegate Op Builder]: Depth or Height or Width is not suitable, fallback op", + stacklevel=1, + ) + return + + count_pad_for_edges = False + # This operator use the default rounding mode of avg_pool3d, floor. 
+ mode = OpPoolAvg3d.RoundingMode.FLOOR + + # pad left, pad right, use default 0 + depth_pad_b = 0 + depth_pad_a = 0 + height_pad_b = 0 + height_pad_a = 0 + width_pad_b = 0 + width_pad_a = 0 + + shape_pad = [ + [depth_pad_b, depth_pad_a], + [height_pad_b, height_pad_a], + [width_pad_b, width_pad_a], + ] + padding_shape = [len(shape_pad), len(shape_pad[0])] + + out_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + node, + out_tensor, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + adaptive_avg_pool3d_op = PyQnnManager.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpPoolAvg3d.op_name, + ) + + adaptive_avg_pool3d_op.AddInputTensors([input_tensor_wrapper]) + adaptive_avg_pool3d_op.AddOutputTensors([output_tensor_wrapper]) + + adaptive_avg_pool3d_op.AddTensorParam( + OpPoolAvg3d.param_filter_size, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(filter_shape), + filter_shape, + np.array( + filter_size, + dtype=np.uint32, + ), + True, + ) + + adaptive_avg_pool3d_op.AddTensorParam( + OpPoolAvg3d.param_stride, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(stride_shape), + stride_shape, + np.array( + stride, + dtype=np.uint32, + ), + True, + ) + + adaptive_avg_pool3d_op.AddTensorParam( + OpPoolAvg3d.param_pad_amount, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(padding_shape), + padding_shape, + np.array( + shape_pad, + dtype=np.uint32, + ), + True, + ) + + adaptive_avg_pool3d_op.AddScalarParam( + OpPoolAvg3d.param_count_pad_for_edges, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + {QCOM_DATA: count_pad_for_edges}, + ) + + adaptive_avg_pool3d_op.AddScalarParam( + OpPoolAvg3d.param_rounding_mode, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + {QCOM_DATA: np.uint32(mode)}, + ) + + return adaptive_avg_pool3d_op diff --git a/backends/qualcomm/builders/op_batch_norm.py b/backends/qualcomm/builders/op_batch_norm.py index 25a9c2b123e..eecc32113b1 100644 --- a/backends/qualcomm/builders/op_batch_norm.py +++ b/backends/qualcomm/builders/op_batch_norm.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
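Note on the explicit avg_pool3d path of op_avg_pool3d.py above: single-element kernel, stride and padding arguments are broadcast to all three spatial dimensions, the symmetric padding is re-expressed as the (3, 2) before/after pad-amount tensor QNN expects, and CEIL rounding is selected only when ceil_mode is set; the adaptive variant then reuses the same fixed kernel/stride derivation sketched for the 2D case earlier. A small sketch of the argument normalization (the helper name is illustrative, not part of this patch):

import numpy as np

def normalize_avg_pool3d_args(kernel, stride, padding, ceil_mode=False):
    # Broadcast single-element arguments to depth, height and width.
    kernel = list(kernel) * 3 if len(kernel) == 1 else list(kernel)
    stride = list(stride) * 3 if len(stride) == 1 else list(stride)
    padding = list(padding) * 3 if len(padding) == 1 else list(padding)
    # QNN takes a (3, 2) pad-amount tensor: [before, after] per spatial dim.
    pad_amount = np.array([[p, p] for p in padding], dtype=np.uint32)
    rounding = "CEIL" if ceil_mode else "FLOOR"
    return kernel, stride, pad_amount, rounding

k, s, p, r = normalize_avg_pool3d_args([2], [2], [1], ceil_mode=True)
assert k == [2, 2, 2] and s == [2, 2, 2] and p.shape == (3, 2) and r == "CEIL"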
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch from executorch.backends.qualcomm.utils.constants import ( @@ -54,8 +54,8 @@ def try_dequantize(self, node: torch.fx.Node, tensor: torch.Tensor): def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) @@ -71,7 +71,7 @@ def define_node( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) batch_norm_input_tensors = [input_tensor_wrapper] @@ -81,7 +81,7 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) batch_norm_output_tensors = [output_tensor_wrapper] @@ -117,7 +117,7 @@ def define_node( filter_node, node, filter_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, ) batch_norm_input_tensors.append(filter_tensor_wrapper) @@ -135,12 +135,12 @@ def define_node( bias_node, node, bias_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, ) batch_norm_input_tensors.append(bias_tensor_wrapper) - batch_norm_op = PyQnnWrapper.PyQnnOpWrapper( + batch_norm_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpBatchnorm.op_name, diff --git a/backends/qualcomm/builders/op_binary.py b/backends/qualcomm/builders/op_binary.py index 4f4d8b9b560..2421fcb9af5 100644 --- a/backends/qualcomm/builders/op_binary.py +++ b/backends/qualcomm/builders/op_binary.py @@ -6,7 +6,7 @@ import warnings from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch from executorch.backends.qualcomm.utils.constants import QCOM_DATA @@ -33,14 +33,14 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: out_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) binary_output_tensors = [output_tensor_wrapper] @@ -49,7 +49,7 @@ def define_node( for index in range(2): input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) - tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + tensor_type = PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE input_tensor_wrapper = self.define_tensor( input_node, @@ -60,7 +60,7 @@ def define_node( ) binary_input_tensors.append(input_tensor_wrapper) - binary_op = PyQnnWrapper.PyQnnOpWrapper( + binary_op = PyQnnManager.PyQnnOpWrapper( node.name, 
QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseBinary.op_name, @@ -77,7 +77,7 @@ def define_node( binary_op.AddScalarParam( OpElementWiseBinary.param_operation, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {QCOM_DATA: np.uint32(QNN_BINARY_OPERATOR[node.target])}, ) diff --git a/backends/qualcomm/builders/op_bmm.py b/backends/qualcomm/builders/op_bmm.py index 92c8f1dde3e..191c4497b2c 100644 --- a/backends/qualcomm/builders/op_bmm.py +++ b/backends/qualcomm/builders/op_bmm.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,8 +24,8 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: bmm_input_tensors = [] for index in range(2): input_node = self.get_node(node.args[index]) @@ -35,7 +35,7 @@ def define_node( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) bmm_input_tensors.append(input_tensor_wrapper) @@ -45,12 +45,12 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) bmm_output_tensors = [output_tensor_wrapper] - bmm_op = PyQnnWrapper.PyQnnOpWrapper( + bmm_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpMatMul.op_name ) bmm_op.AddInputTensors(bmm_input_tensors) diff --git a/backends/qualcomm/builders/op_cat.py b/backends/qualcomm/builders/op_cat.py index 9f6eb6676cf..31045426959 100644 --- a/backends/qualcomm/builders/op_cat.py +++ b/backends/qualcomm/builders/op_cat.py @@ -6,7 +6,7 @@ import warnings from typing import cast, Dict, List -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -27,24 +27,25 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: - list_of_tensors = cast(List[torch.fx.Node], node.args[0]) - list_of_tensor_wrappers = [] - - for tensor_input in list_of_tensors: - input_tensor = self.get_tensor(self.get_node(tensor_input), node) - list_of_tensor_wrappers.append( + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: + input_nodes = cast(List[torch.fx.Node], node.args[0]) + input_tensor_wrappers = [] + + for input_node in input_nodes: + source_input_node = self.get_node(input_node) + input_tensor = self.get_tensor(source_input_node, node) + input_tensor_wrappers.append( self.define_tensor( - tensor_input, + source_input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) ) - if len(list_of_tensors) != len(list_of_tensor_wrappers): + if len(input_nodes) != len(input_tensor_wrappers): warnings.warn( "[QNN Delegate Op Builder]: The number or input tensors is not 
equal to the number of input tensor wrappers.", stacklevel=1, @@ -56,7 +57,7 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -71,17 +72,17 @@ def define_node( if QCOM_AXIS_ORDER in node.meta: axis = node.meta[QCOM_AXIS_ORDER].index(axis) - concat_op = PyQnnWrapper.PyQnnOpWrapper( + concat_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpConcat.op_name, ) - concat_op.AddInputTensors(list_of_tensor_wrappers) + concat_op.AddInputTensors(input_tensor_wrappers) concat_op.AddOutputTensors([output_tensor_wrapper]) concat_op.AddScalarParam( OpConcat.param_axis, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {QCOM_DATA: np.uint32(axis)}, ) diff --git a/backends/qualcomm/builders/op_ceil.py b/backends/qualcomm/builders/op_ceil.py index 6b85592165c..58397b4382c 100644 --- a/backends/qualcomm/builders/op_ceil.py +++ b/backends/qualcomm/builders/op_ceil.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,15 +24,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -41,11 +41,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - ceil_op = PyQnnWrapper.PyQnnOpWrapper( + ceil_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseCeil.op_name, diff --git a/backends/qualcomm/builders/op_clamp.py b/backends/qualcomm/builders/op_clamp.py index 1e13b70f78e..b19790e7735 100644 --- a/backends/qualcomm/builders/op_clamp.py +++ b/backends/qualcomm/builders/op_clamp.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
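Note on the concat builder above (and the argmax/argmin builders earlier in this patch): when the node carries QCOM_AXIS_ORDER, the user-facing axis is remapped into the permuted layout via index() on the stored permutation. A small sketch of that remapping; the (0, 2, 3, 1) order is the usual NCHW-to-NHWC permutation and is assumed here purely for illustration:

# Position of an original dimension after the layout permutation.
axis_order = (0, 2, 3, 1)  # assumed NCHW -> NHWC permutation

def remap_axis(axis: int, order=axis_order) -> int:
    return order.index(axis)

assert remap_axis(1) == 3  # channels: dim 1 in NCHW becomes dim 3 in NHWC
assert remap_axis(3) == 2  # width:    dim 3 in NCHW becomes dim 2 in NHWC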
from typing import cast, Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -26,15 +26,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -55,11 +55,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - clamp_op = PyQnnWrapper.PyQnnOpWrapper( + clamp_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpReluMinMax.op_name, @@ -68,12 +68,12 @@ def define_node( clamp_op.AddOutputTensors([output_tensor_wrapper]) clamp_op.AddScalarParam( OpReluMinMax.param_max_value, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, {QCOM_DATA: np.float32(output_max)}, ) clamp_op.AddScalarParam( OpReluMinMax.param_min_value, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, {QCOM_DATA: np.float32(output_min)}, ) diff --git a/backends/qualcomm/builders/op_conv.py b/backends/qualcomm/builders/op_conv.py new file mode 100644 index 00000000000..aadb893c8ff --- /dev/null +++ b/backends/qualcomm/builders/op_conv.py @@ -0,0 +1,263 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import cast, Dict, List + +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager +import numpy as np +import torch +from executorch.backends.qualcomm.utils.constants import QCOM_DATA, QCOM_QUANT_ATTRS + +from .node_visitor import NodeVisitor, PER_CHANNEL_ENCODING +from .node_visitor_manager import register_node_visitor +from .qnn_constants import ( + OpConv2d, + OpConv3d, + OpDepthWiseConv2d, + OpTransposeConv2d, + OpTransposeConv3d, + QNN_OP_PACKAGE_NAME_QTI_AISW, +) +from .utils import get_parameter + + +@register_node_visitor +class Conv2d(NodeVisitor): + target = ["aten.convolution.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def _add_conv_op_parameter( + self, + OP, + conv_op, + conv_input_tensors, + conv_output_tensors, + stride, + stride_shape, + padding, + padding_shape, + dilation, + dilation_shape, + output_padding=None, + output_padding_shape=None, + transpose_conv=False, + groups=None, + ) -> PyQnnManager.PyQnnOpWrapper: + """ + This function is shared among Conv1D, Conv2D, and DepthWise Conv2D as most of the required parameters overlaps. 
+ """ + conv_op.AddInputTensors(conv_input_tensors) + conv_op.AddOutputTensors(conv_output_tensors) + conv_op.AddTensorParam( + OP.param_stride, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(stride_shape), + stride_shape, + np.array(stride, dtype=np.uint32), + True, + ) + conv_op.AddTensorParam( + OP.param_pad_amount, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(padding_shape), + padding_shape, + np.array( + padding, + dtype=np.uint32, + ), + True, + ) + + if transpose_conv: + conv_op.AddTensorParam( + OP.param_output_padding, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(output_padding_shape), + output_padding_shape, + np.array(output_padding, dtype=np.uint32), + True, + ) + else: + conv_op.AddTensorParam( + OP.param_dilation, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(dilation_shape), + dilation_shape, + np.array(dilation, dtype=np.uint32), + True, + ) + + if groups is not None: + conv_op.AddScalarParam( + OP.param_group, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + {QCOM_DATA: np.uint32(groups)}, + ) + + return conv_op + + def _reduce_bias_scales( + self, + node: torch.fx.Node, + filter_node: torch.fx.Node, + bias_node: torch.fx.Node, + groups: int, + ): + """_summary_ + If transpose_conv has groups, need special handle for bias_node's per channel quant. + Check _derived_bias_quant_spec under backends/qualcomm/quantizer/qconfig.py for more info. + """ + + filter_scales = filter_node.meta[QCOM_QUANT_ATTRS]["scales"] + bias_scales = bias_node.meta[QCOM_QUANT_ATTRS]["scales"] + bias_zero_points = bias_node.meta[QCOM_QUANT_ATTRS]["zero_points"] + + # Adding this condition to prevent reduce twice: op_validation and qnn_preprocess + if filter_scales.numel() != bias_scales.numel(): + bias_scales = bias_scales.view(-1, groups)[:, 0] + bias_zero_points = bias_zero_points.view(-1, groups)[:, 0] + bias_node.meta[QCOM_QUANT_ATTRS]["scales"] = bias_scales + bias_node.meta[QCOM_QUANT_ATTRS]["zero_points"] = bias_zero_points + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[str, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: + input_node = self.get_node(node.args[0]) + input_tensor = self.get_tensor(input_node, node) + assert ( + input_tensor.dim() != 3 + ), "All Conv1D should be converted to Conv2D in CanonicalizeConv," + assert input_tensor.dim() in { + 4, + 5, + }, "Only Conv2d and Conv3d is supported in conv builder," + + is_conv2d = input_tensor.dim() == 4 + input_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + filter_node = self.get_node(node.args[1]) + filter_tensor = get_parameter(filter_node, self.edge_program) + + stride = cast(List[int], node.args[3]) + padding = cast(List[int], node.args[4]) + dilation = cast(List[int], node.args[5]) + output_padding = cast(List[int], node.args[7]) + groups = cast(int, node.args[8]) + + # weight of pytorch OIHW(conv2d) / OIDHW(conv3d) or IOHW(conv_transpose2d) / IODHW(conv_transpose3d), + # yet QNN is HWIO or DHWIO for both conv and conv_transpose. 
+ is_transpose_conv = cast(bool, node.args[6]) + if is_conv2d: + filter_axis_order = (2, 3, 0, 1) if is_transpose_conv else (2, 3, 1, 0) + else: + filter_axis_order = ( + (2, 3, 4, 0, 1) if is_transpose_conv else (2, 3, 4, 1, 0) + ) + filter_tensor = filter_tensor.permute(dims=filter_axis_order).contiguous() + filter_tensor_wrapper = self.define_tensor( + filter_node, + node, + filter_tensor, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + ) + conv_input_tensors = [input_tensor_wrapper, filter_tensor_wrapper] + if node.args[2] is not None: + bias_node = self.get_node(node.args[2]) + # TODO: Double check on condition below once QNN supports transpose_conv with block_quant. + # By checking node.args[1].target, only allow per_channel_quant to go through and bypass block_quant. + if ( + is_transpose_conv + and groups != 1 + and bias_node.meta.get(QCOM_QUANT_ATTRS) is not None + and node.args[1].target in PER_CHANNEL_ENCODING + ): + self._reduce_bias_scales(node, filter_node, bias_node, groups) + + bias_tensor = get_parameter(bias_node, self.edge_program) + bias_tensor_wrapper = self.define_tensor( + bias_node, + node, + bias_tensor, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + ) + conv_input_tensors.append(bias_tensor_wrapper) + output_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + node, + output_tensor, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + conv_output_tensors = [output_tensor_wrapper] + + # Qnn filter tensor is (H, W, Cin, Cout) or (D, H, W, Cin, Cout) + group_input_channels = filter_tensor.shape[-2] + group_output_channels = int(filter_tensor.shape[-1] / groups) + # 1) groups = input_channels (i.e. group_input_channels = 1) + # 2) output_channels is a positive integer multiple of input channels + # TODO: Currently, negative results will be zero with Depthwise conv2d when input_channel == groups == 1 + # and test on QNN 2.14 rc1. Need to carefully investigate. + is_depthwise_conv = ( + (group_input_channels == 1) + and (group_output_channels % group_input_channels == 0) + and (groups > 2) + ) + if len(padding) == 1: + padding = padding + padding + padding = [[x, x] for x in padding] + + stride_shape = [len(stride)] + padding_shape = [len(padding), len(padding[0])] + dilation_shape = [len(dilation)] + output_padding_shape = [len(output_padding)] + + if is_transpose_conv: + assert all( + val == 1 for val in dilation + ), "CanonicalizeConv pass should perform dilate for transpose_conv." + op_class = OpTransposeConv2d if is_conv2d else OpTransposeConv3d + elif is_depthwise_conv: + assert is_conv2d, "DepthWise only supports Conv2d" + op_class = OpDepthWiseConv2d + else: + op_class = OpConv2d if is_conv2d else OpConv3d + + conv_op = PyQnnManager.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + op_class.op_name, + ) + conv_op = self._add_conv_op_parameter( + op_class, + conv_op, + conv_input_tensors, + conv_output_tensors, + stride, + stride_shape, + padding, + padding_shape, + dilation, + dilation_shape, + output_padding, + output_padding_shape, + is_transpose_conv, + None if is_depthwise_conv else groups, + ) + + return conv_op diff --git a/backends/qualcomm/builders/op_conv2d.py b/backends/qualcomm/builders/op_conv2d.py deleted file mode 100644 index 1cfc1e45c9b..00000000000 --- a/backends/qualcomm/builders/op_conv2d.py +++ /dev/null @@ -1,213 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. 
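Note on the unified conv builder in op_conv.py above (which replaces op_conv2d.py, deleted below): it makes two layout decisions worth spelling out. It permutes PyTorch's OIHW/OIDHW weights (IOHW/IODHW for transposed convolution) into QNN's HWIO/DHWIO layout, and it classifies the op as depthwise when each group sees exactly one input channel. A condensed sketch of both decisions; the helper names are illustrative, and the real builder additionally handles bias, quantization attributes and parameter shapes:

import torch

def to_qnn_filter(weight: torch.Tensor, transposed: bool) -> torch.Tensor:
    # PyTorch conv weights are OIHW/OIDHW (IOHW/IODHW when transposed);
    # QNN expects HWIO for 2D and DHWIO for 3D, for both variants.
    if weight.dim() == 4:   # conv2d
        order = (2, 3, 0, 1) if transposed else (2, 3, 1, 0)
    else:                   # conv3d
        order = (2, 3, 4, 0, 1) if transposed else (2, 3, 4, 1, 0)
    return weight.permute(order).contiguous()

def is_depthwise(qnn_filter: torch.Tensor, groups: int) -> bool:
    # After permutation the last two dims are (Cin per group, Cout total).
    group_in = qnn_filter.shape[-2]
    group_out = qnn_filter.shape[-1] // groups
    return group_in == 1 and group_out % group_in == 0 and groups > 2

# A 3x3 depthwise conv over 16 channels: weight is (16, 1, 3, 3) in OIHW.
w = torch.randn(16, 1, 3, 3)
f = to_qnn_filter(w, transposed=False)   # -> shape (3, 3, 1, 16)
assert f.shape == (3, 3, 1, 16)
assert is_depthwise(f, groups=16)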
-# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import cast, Dict, List - -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper - -import numpy as np -import torch -from executorch.backends.qualcomm.utils.constants import QCOM_DATA - -from .node_visitor import NodeVisitor -from .node_visitor_manager import register_node_visitor -from .qnn_constants import ( - OpConv2d, - OpDepthWiseConv2d, - OpTransposeConv2d, - QNN_OP_PACKAGE_NAME_QTI_AISW, -) -from .utils import get_parameter - - -@register_node_visitor -class Conv2d(NodeVisitor): - target = ["aten.convolution.default"] - - def __init__(self, *args) -> None: - super().__init__(*args) - - def _add_conv_op_parameter( - self, - OP, - conv_op, - conv_input_tensors, - conv_output_tensors, - stride, - stride_shape, - padding, - padding_shape, - dilation, - dilation_shape, - output_padding=None, - output_padding_shape=None, - transpose_conv=False, - groups=None, - ) -> PyQnnWrapper.PyQnnOpWrapper: - """ - This function is shared among Conv1D, Conv2D, and DepthWise Conv2D as most of the required parameters overlaps. - """ - conv_op.AddInputTensors(conv_input_tensors) - conv_op.AddOutputTensors(conv_output_tensors) - conv_op.AddTensorParam( - OP.param_stride, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - len(stride_shape), - stride_shape, - np.array(stride, dtype=np.uint32), - True, - ) - conv_op.AddTensorParam( - OP.param_pad_amount, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - len(padding_shape), - padding_shape, - np.array( - [[padding[0], padding[0]], [padding[1], padding[1]]], - dtype=np.uint32, - ), - True, - ) - - if transpose_conv: - conv_op.AddTensorParam( - OP.param_output_padding, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - len(output_padding_shape), - output_padding_shape, - np.array(output_padding, dtype=np.uint32), - True, - ) - else: - conv_op.AddTensorParam( - OP.param_dilation, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - len(dilation_shape), - dilation_shape, - np.array(dilation, dtype=np.uint32), - True, - ) - - if groups is not None: - conv_op.AddScalarParam( - OP.param_group, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - {QCOM_DATA: np.uint32(groups)}, - ) - - return conv_op - - def define_node( - self, - node: torch.fx.Node, - nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = self.get_node(node.args[0]) - input_tensor = self.get_tensor(input_node, node) - assert ( - input_tensor.dim() == 4 - ), "All Conv1D should be converted to Conv2D in CanonicalizeConv," - input_tensor_wrapper = self.define_tensor( - input_node, - node, - input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, - nodes_to_wrappers, - ) - - filter_node = self.get_node(node.args[1]) - filter_tensor = get_parameter(filter_node, self.edge_program) - # weight of pytorch OIHW(conv2d) | IOHW(conv_transpose2d), yet QNN is HWIO - is_transpose_conv = cast(bool, node.args[6]) - filter_axis_order = (2, 3, 0, 1) if is_transpose_conv else (2, 3, 1, 0) - filter_tensor = filter_tensor.permute(dims=filter_axis_order).contiguous() - filter_tensor_wrapper = self.define_tensor( - filter_node, - node, - filter_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, - nodes_to_wrappers, - ) - conv_input_tensors = [input_tensor_wrapper, filter_tensor_wrapper] - - if node.args[2] is not None: - bias_node = 
self.get_node(node.args[2]) - bias_tensor = get_parameter(bias_node, self.edge_program) - bias_tensor_wrapper = self.define_tensor( - bias_node, - node, - bias_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, - nodes_to_wrappers, - ) - conv_input_tensors.append(bias_tensor_wrapper) - - output_tensor = self.get_tensor(node, node) - output_tensor_wrapper = self.define_tensor( - node, - node, - output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, - nodes_to_wrappers, - ) - conv_output_tensors = [output_tensor_wrapper] - - stride = cast(List[int], node.args[3]) - padding = cast(List[int], node.args[4]) - dilation = cast(List[int], node.args[5]) - output_padding = cast(List[int], node.args[7]) - - groups = cast(int, node.args[8]) - # Qnn filter tensor is (H, W, Cin, Cout) - group_input_channels = filter_tensor.shape[2] - group_output_channels = int(filter_tensor.shape[3] / groups) - # 1) groups = input_channels (i.e. group_input_channels = 1) - # 2) output_channels is a positive integer multiple of input channels - # TODO: Currently, negative results will be zero with Depthwise conv2d when input_channel == groups == 1 - # and test on QNN 2.14 rc1. Need to carefully investigate. - is_depthwise_conv = ( - (group_input_channels == 1) - and (group_output_channels % group_input_channels == 0) - and (groups > 2) - ) - if len(padding) == 1: - padding = padding + padding - - stride_shape = [len(stride)] - padding_shape = [2, 2] - dilation_shape = [len(dilation)] - output_padding_shape = [len(output_padding)] - - if is_depthwise_conv: - op_class = OpDepthWiseConv2d - elif is_transpose_conv: - op_class = OpTransposeConv2d - else: - op_class = OpConv2d - - conv_op = PyQnnWrapper.PyQnnOpWrapper( - node.name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - op_class.op_name, - ) - conv_op = self._add_conv_op_parameter( - op_class, - conv_op, - conv_input_tensors, - conv_output_tensors, - stride, - stride_shape, - padding, - padding_shape, - dilation, - dilation_shape, - output_padding, - output_padding_shape, - is_transpose_conv, - None if is_depthwise_conv else groups, - ) - - return conv_op diff --git a/backends/qualcomm/builders/op_copy.py b/backends/qualcomm/builders/op_copy.py index 164c910835e..a1caa1c98a2 100644 --- a/backends/qualcomm/builders/op_copy.py +++ b/backends/qualcomm/builders/op_copy.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
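Note on the _reduce_bias_scales helper introduced in op_conv.py above: when a grouped transposed convolution carries per-channel quantization, it trims the derived bias scales and zero points so they line up again with the filter's per-channel encoding (the full reasoning lives in _derived_bias_quant_spec under backends/qualcomm/quantizer/qconfig.py and is not reproduced here). A standalone sketch of the reduction itself, assuming the derived spec lays the scales out in consecutive runs of `groups` duplicates per channel; the numbers are made up:

import torch

def reduce_grouped_scales(scales: torch.Tensor, groups: int) -> torch.Tensor:
    # Keep the first value of every run of `groups` duplicates, mirroring
    # bias_scales.view(-1, groups)[:, 0] in the builder above.
    return scales.view(-1, groups)[:, 0]

bias_scales = torch.tensor([0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 0.4, 0.4])
assert torch.equal(reduce_grouped_scales(bias_scales, groups=2),
                   torch.tensor([0.1, 0.2, 0.3, 0.4]))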
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS @@ -25,15 +25,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[1]) input_tensor = self.get_tensor(input_node, node) copy_inp_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -48,12 +48,12 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) copy_output_tensors = [output_tensor_wrapper] - copy_op = PyQnnWrapper.PyQnnOpWrapper( + copy_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpReshape.op_name, diff --git a/backends/qualcomm/builders/op_cos.py b/backends/qualcomm/builders/op_cos.py index 9ff11d86dda..9701f9a211b 100644 --- a/backends/qualcomm/builders/op_cos.py +++ b/backends/qualcomm/builders/op_cos.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,15 +24,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -41,11 +41,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - cos_op = PyQnnWrapper.PyQnnOpWrapper( + cos_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseCos.op_name, diff --git a/backends/qualcomm/builders/op_cum_sum.py b/backends/qualcomm/builders/op_cum_sum.py index da2b025fe9f..5668fb2ab70 100644 --- a/backends/qualcomm/builders/op_cum_sum.py +++ b/backends/qualcomm/builders/op_cum_sum.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import cast, Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -36,15 +36,15 @@ def get_param(self, node, input_tensor): def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -57,11 +57,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - cumsum_op = PyQnnWrapper.PyQnnOpWrapper( + cumsum_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpCumulativeSum.op_name, @@ -70,17 +70,17 @@ def define_node( cumsum_op.AddOutputTensors([output_tensor_wrapper]) cumsum_op.AddScalarParam( OpCumulativeSum.param_axis, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {QCOM_DATA: dim}, ) cumsum_op.AddScalarParam( OpCumulativeSum.param_exclusive, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, {QCOM_DATA: False}, ) cumsum_op.AddScalarParam( OpCumulativeSum.param_reverse, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, {QCOM_DATA: False}, ) diff --git a/backends/qualcomm/builders/op_custom_op.py b/backends/qualcomm/builders/op_custom_op.py index 52a15ef95f2..203e3a2ac64 100644 --- a/backends/qualcomm/builders/op_custom_op.py +++ b/backends/qualcomm/builders/op_custom_op.py @@ -6,7 +6,7 @@ import warnings from typing import Dict, Iterable -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -32,9 +32,9 @@ def __init__(self, op_package_info: QnnExecuTorchOpPackageInfo, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: - custom_op = PyQnnWrapper.PyQnnOpWrapper( + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: + custom_op = PyQnnManager.PyQnnOpWrapper( node.name, self.op_package_info.op_package_name, self.op_package_info.qnn_op_type_name, @@ -57,7 +57,7 @@ def define_node( arg, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) custom_input_tensors.append(input_tensor_wrapper) @@ -83,7 +83,7 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) custom_output_tensors = [output_tensor_wrapper] diff --git a/backends/qualcomm/builders/op_depth_to_space.py b/backends/qualcomm/builders/op_depth_to_space.py index 908e0949162..5e1dbe639c5 100644 --- a/backends/qualcomm/builders/op_depth_to_space.py +++ 
b/backends/qualcomm/builders/op_depth_to_space.py @@ -6,7 +6,7 @@ from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -27,15 +27,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -44,7 +44,7 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -54,7 +54,7 @@ def define_node( block_size = np.array(block_size, dtype=np.uint32) block_size_shape = [2] - depth_to_space_op = PyQnnWrapper.PyQnnOpWrapper( + depth_to_space_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpDepthToSpace.op_name, @@ -63,7 +63,7 @@ def define_node( depth_to_space_op.AddOutputTensors([output_tensor_wrapper]) depth_to_space_op.AddTensorParam( OpDepthToSpace.param_block_size, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, len(block_size.shape), block_size_shape, block_size, @@ -71,7 +71,7 @@ def define_node( ) depth_to_space_op.AddScalarParam( OpDepthToSpace.param_mode, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {QCOM_DATA: np.uint32(OpDepthToSpace.Mode.CRD)}, ) diff --git a/backends/qualcomm/builders/op_dequantize.py b/backends/qualcomm/builders/op_dequantize.py index c4d9b8c29a4..31801c19aaf 100644 --- a/backends/qualcomm/builders/op_dequantize.py +++ b/backends/qualcomm/builders/op_dequantize.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -21,8 +21,8 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: dequant_input_tensors = [] input_node = node.args[0] input_tensor = self.get_tensor(input_node, node) @@ -30,7 +30,7 @@ def define_node( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) dequant_input_tensors.append(inp_tensor_wrapper) @@ -40,12 +40,12 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) dequant_output_tensors = [output_tensor_wrapper] - dequant_op = PyQnnWrapper.PyQnnOpWrapper( + dequant_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpDequantize.op_name, diff --git a/backends/qualcomm/builders/op_div.py b/backends/qualcomm/builders/op_div.py index 9fc4a9302b0..1467147d390 100644 --- a/backends/qualcomm/builders/op_div.py +++ b/backends/qualcomm/builders/op_div.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,14 +24,14 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: out_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) div_output_tensors = [output_tensor_wrapper] @@ -40,7 +40,7 @@ def define_node( for index in range(2): input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) - tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + tensor_type = PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE input_tensor_wrapper = self.define_tensor( input_node, @@ -51,7 +51,7 @@ def define_node( ) div_input_tensors.append(input_tensor_wrapper) - div_op = PyQnnWrapper.PyQnnOpWrapper( + div_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseDivide.op_name, diff --git a/backends/qualcomm/builders/op_elu.py b/backends/qualcomm/builders/op_elu.py index 65e8d93f414..cfbb011f2f5 100644 --- a/backends/qualcomm/builders/op_elu.py +++ b/backends/qualcomm/builders/op_elu.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
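The op_elu.py hunk below now forwards the optional alpha argument to QNN as a float32 scalar and accepts it whenever it is present. A minimal sketch, with a hypothetical alpha value, of why the previous uint32 cast was lossy:

import numpy as np

alpha = 0.5                # ELU's alpha is a floating-point value in PyTorch
print(np.uint32(alpha))    # 0   -> the old cast truncated fractional alphas
print(np.float32(alpha))   # 0.5 -> matches the QNN_DATATYPE_FLOAT_32 scalar param below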
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -26,8 +26,8 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: # tensor input input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) @@ -36,7 +36,7 @@ def define_node( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) elu_input_tensors = [input_tensor_wrapper] @@ -46,24 +46,23 @@ def define_node( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) elu_output_tensors = [output_tensor_wrapper] - elu_op = PyQnnWrapper.PyQnnOpWrapper( + elu_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElu.op_name, ) elu_op.AddInputTensors(elu_input_tensors) elu_op.AddOutputTensors(elu_output_tensors) - - if len(node.args) == 2: + if len(node.args) > 1: elu_op.AddScalarParam( OpElu.param_alpha, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, - {QCOM_DATA: np.uint32(node.args[1])}, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, + {QCOM_DATA: np.float32(node.args[1])}, ) return elu_op diff --git a/backends/qualcomm/builders/op_embedding.py b/backends/qualcomm/builders/op_embedding.py index 45adc20fa79..03a33f14bec 100644 --- a/backends/qualcomm/builders/op_embedding.py +++ b/backends/qualcomm/builders/op_embedding.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -27,15 +27,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: weight_node = self.get_node(node.args[0]) weight_tensor = get_parameter(weight_node, self.edge_program) weight_tensor_wrapper = self.define_tensor( weight_node, node, weight_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, ) @@ -45,7 +45,7 @@ def define_node( indices_node, node, indices_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -56,12 +56,12 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) gather_output_tensors = [output_tensor_wrapper] - gather_op = PyQnnWrapper.PyQnnOpWrapper( + gather_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpGather.op_name, @@ -72,7 +72,7 @@ def define_node( # For now, default axis is zero. 
gather_op.AddScalarParam( OpGather.param_axis, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_INT_32, {QCOM_DATA: np.int32(0)}, ) diff --git a/backends/qualcomm/builders/op_eq.py b/backends/qualcomm/builders/op_eq.py index fcf3213d3a9..98fb34834e5 100644 --- a/backends/qualcomm/builders/op_eq.py +++ b/backends/qualcomm/builders/op_eq.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,14 +24,14 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: out_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) output_tensors = [output_tensor_wrapper] @@ -40,7 +40,7 @@ def define_node( for index in range(2): input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) - tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + tensor_type = PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE input_tensor_wrapper = self.define_tensor( input_node, @@ -51,7 +51,7 @@ def define_node( ) input_tensors.append(input_tensor_wrapper) - eq_op = PyQnnWrapper.PyQnnOpWrapper( + eq_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseEqual.op_name, diff --git a/backends/qualcomm/builders/op_exp.py b/backends/qualcomm/builders/op_exp.py index 9a80e7fb4f4..333c57bcf63 100644 --- a/backends/qualcomm/builders/op_exp.py +++ b/backends/qualcomm/builders/op_exp.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,8 +24,8 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: # tensor input input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) @@ -34,7 +34,7 @@ def define_node( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) exp_input_tensors = [input_tensor_wrapper] @@ -44,12 +44,12 @@ def define_node( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) exp_output_tensors = [output_tensor_wrapper] - exp_op = PyQnnWrapper.PyQnnOpWrapper( + exp_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseExp.op_name, diff --git a/backends/qualcomm/builders/op_expand.py b/backends/qualcomm/builders/op_expand.py index 01a8da42752..1aadb17e513 100644 --- a/backends/qualcomm/builders/op_expand.py +++ b/backends/qualcomm/builders/op_expand.py @@ -6,7 +6,7 @@ import warnings from typing import cast, Dict, List -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -26,15 +26,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -43,7 +43,7 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -66,7 +66,7 @@ def define_node( if sizes[i] != -1 and shape[i] == 1: multiples[i] = sizes[i] - tile_op = PyQnnWrapper.PyQnnOpWrapper( + tile_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpTile.op_name, @@ -75,7 +75,7 @@ def define_node( tile_op.AddOutputTensors([output_tensor_wrapper]) tile_op.AddTensorParam( OpTile.param_multiples, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, len(multiples_shape), multiples_shape, np.array(multiples, dtype=np.uint32), diff --git a/backends/qualcomm/builders/op_flip.py b/backends/qualcomm/builders/op_flip.py index 16d68942d31..3eea5867ac8 100644 --- a/backends/qualcomm/builders/op_flip.py +++ b/backends/qualcomm/builders/op_flip.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -27,11 +27,11 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) - tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + tensor_type = PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE input_tensor_wrapper = self.define_tensor( input_node, @@ -46,7 +46,7 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) ranges = [] @@ -62,7 +62,7 @@ def define_node( ranges.extend([0, size, 1]) range_shape = [input_tensor.dim(), 3] - stride_slice_op = PyQnnWrapper.PyQnnOpWrapper( + stride_slice_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpStridedSlice.op_name, @@ -71,7 +71,7 @@ def define_node( stride_slice_op.AddOutputTensors([output_tensor_wrapper]) stride_slice_op.AddTensorParam( OpStridedSlice.param_ranges, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_INT_32, len(range_shape), range_shape, np.array(ranges, dtype=np.int32), diff --git a/backends/qualcomm/builders/op_floor.py b/backends/qualcomm/builders/op_floor.py index 3d69389686e..25780301c3d 100644 --- a/backends/qualcomm/builders/op_floor.py +++ b/backends/qualcomm/builders/op_floor.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch from .node_visitor import NodeVisitor @@ -23,15 +23,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) floor_inp_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) floor_input_tensors = [floor_inp_tensor_wrapper] @@ -41,12 +41,12 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) floor_output_tensors = [output_tensor_wrapper] - floor_op = PyQnnWrapper.PyQnnOpWrapper( + floor_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseFloor.op_name, diff --git a/backends/qualcomm/builders/op_full.py b/backends/qualcomm/builders/op_full.py index d58efd77791..5ac2e95c57b 100644 --- a/backends/qualcomm/builders/op_full.py +++ b/backends/qualcomm/builders/op_full.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -23,8 +23,8 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: out_tensor = torch.full( node.args[0], node.args[1], dtype=node.meta["val"].dtype ) @@ -36,6 +36,6 @@ def define_node( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, ) diff --git a/backends/qualcomm/builders/op_full_like.py b/backends/qualcomm/builders/op_full_like.py index 69609d887aa..66f80ecc80a 100644 --- a/backends/qualcomm/builders/op_full_like.py +++ b/backends/qualcomm/builders/op_full_like.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -23,8 +23,8 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: in_tensor = node.args[0].meta["val"] ref_tensor = torch.zeros(in_tensor.shape, dtype=in_tensor.dtype) out_tensor = torch.full_like(ref_tensor, node.args[1]) @@ -36,6 +36,6 @@ def define_node( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, ) diff --git a/backends/qualcomm/builders/op_gather.py b/backends/qualcomm/builders/op_gather.py index 140d2a79caf..ed6f3f0a56b 100644 --- a/backends/qualcomm/builders/op_gather.py +++ b/backends/qualcomm/builders/op_gather.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import cast, Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -26,15 +26,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = node.args[0] input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -46,7 +46,7 @@ def define_node( indices_node, node, indices_tensor.to(torch.int32), - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) gather_input_tensors = [input_tensor_wrapper, indices_tensor_wrapper] @@ -55,12 +55,12 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) gather_output_tensors = [output_tensor_wrapper] - gather_op = PyQnnWrapper.PyQnnOpWrapper( + gather_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpGatherElements.op_name, @@ -69,7 +69,7 @@ def define_node( gather_op.AddOutputTensors(gather_output_tensors) gather_op.AddScalarParam( OpGatherElements.param_axis, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {QCOM_DATA: np.uint32(dim)}, ) diff --git a/backends/qualcomm/builders/op_ge.py b/backends/qualcomm/builders/op_ge.py index 6c5671ff5f2..2aeb5a87d71 100644 --- a/backends/qualcomm/builders/op_ge.py +++ b/backends/qualcomm/builders/op_ge.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,14 +24,14 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: out_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) output_tensors = [output_tensor_wrapper] @@ -40,7 +40,7 @@ def define_node( for index in range(2): input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) - tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + tensor_type = PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE input_tensor_wrapper = self.define_tensor( input_node, @@ -51,7 +51,7 @@ def define_node( ) input_tensors.append(input_tensor_wrapper) - ge_op = PyQnnWrapper.PyQnnOpWrapper( + ge_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseGreaterEqual.op_name, diff --git a/backends/qualcomm/builders/op_gelu.py b/backends/qualcomm/builders/op_gelu.py index 3d111f0cf98..8d74eaaef16 100644 --- a/backends/qualcomm/builders/op_gelu.py +++ b/backends/qualcomm/builders/op_gelu.py @@ -6,7 +6,7 @@ from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -25,15 +25,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -42,11 +42,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - gelu_op = PyQnnWrapper.PyQnnOpWrapper( + gelu_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpGelu.op_name, diff --git a/backends/qualcomm/builders/op_grid_sampler_2d.py b/backends/qualcomm/builders/op_grid_sampler_2d.py new file mode 100644 index 00000000000..4327aa81fe8 --- /dev/null +++ b/backends/qualcomm/builders/op_grid_sampler_2d.py @@ -0,0 +1,162 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+import warnings +from typing import cast, Dict, List + +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager +import numpy as np + +import torch + +from executorch.backends.qualcomm.utils.constants import QCOM_DATA, QCOM_DTYPE + +from .node_visitor import NodeVisitor, QNN_QUANT_TYPE_MAP, QNN_TENSOR_TYPE_MAP +from .node_visitor_manager import register_node_visitor +from .qnn_constants import OpGridSample, OpTranspose, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class GridSample(NodeVisitor): + target = ["aten.grid_sampler_2d.default", "aten.grid_sampler_3d.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: + grid_sample_op_list = [] + input_node = self.get_node(node.args[0]) + input_tensor = self.get_tensor(input_node, node) + input_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + grid_node = self.get_node(node.args[1]) + grid_tensor = self.get_tensor(grid_node, node) + grid_tensor_wrapper = self.define_tensor( + grid_node, + node, + grid_tensor, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + input_shape = input_node.meta["val"].shape + input_rank = len(input_shape) + if input_rank not in [4, 5]: + warnings.warn( + "[QNN Delegate Op Builder]: The shape is not supported, fallback op", + stacklevel=1, + ) + return + + # About this operator, in ATen, the layout of input_tensor and of grid_tensor are not identical. + # But in HW they are all NHWC or NDHWC. So, we make shape transformation again. 
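For reference, a minimal sketch of the rank-4 case handled below, assuming the usual ATen grid_sample grid shape (N, H_out, W_out, 2): permuting with (0, 3, 1, 2) makes the grid channel-first like the activation tensor, presumably so the delegate's common NCHW-to-NHWC handling applies to both inputs.

import torch

grid = torch.rand(1, 16, 16, 2)       # hypothetical (N, H_out, W_out, 2) grid
permuted = grid.permute(0, 3, 1, 2)   # (N, 2, H_out, W_out), channel-first
print(permuted.shape)                 # torch.Size([1, 2, 16, 16])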
+ if input_rank == 4: + dims_shape_back = (0, 3, 1, 2) + elif input_rank == 5: + dims_shape_back = (0, 4, 1, 2, 3) + else: + warnings.warn( + f"[QNN Delegate Op Builder]: Not support rank {input_rank}, fallback op", + stacklevel=1, + ) + return + + grid_quant_encoding, grid_quant_configs = self.get_quant_encoding_conf( + grid_node, node + ) + grid_dtype = ( + QNN_TENSOR_TYPE_MAP[grid_tensor.dtype] + if grid_quant_encoding + == PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED + else QNN_QUANT_TYPE_MAP[ + ( + torch.uint16 + if grid_quant_configs[QCOM_DTYPE] == torch.int32 + else grid_quant_configs[QCOM_DTYPE] + ) + ] + ) + # transpose + permute_output_tensor = grid_tensor.permute(dims=dims_shape_back) + transpose_output_tensor_wrapper = self.define_custom_tensor_wrapper( + node_name=node.name + "_transpose", + tensor_type=PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + dtype=grid_dtype, + quant_encoding=grid_quant_encoding, + quant_configs=grid_quant_configs, + dims=permute_output_tensor.size(), + tensor=permute_output_tensor, + is_fake_tensor=True, + nodes_to_wrappers=nodes_to_wrappers, + ) + + permute_order = cast(List[int], dims_shape_back) + permute_order_shape = [len(permute_order)] + transpose_op = PyQnnManager.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpTranspose.op_name, + ) + transpose_op.AddInputTensors([grid_tensor_wrapper]) + transpose_op.AddOutputTensors([transpose_output_tensor_wrapper]) + transpose_op.AddTensorParam( + OpTranspose.param_perm, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(permute_order_shape), + permute_order_shape, + np.array(permute_order, dtype=np.uint32), + True, + ) + grid_sample_op_list.append(transpose_op) + + out_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + node, + out_tensor, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + align_corners = node.args[4] if len(node.args) > 4 else False + padding_mode = node.args[3] if len(node.args) > 3 else 0 + interpo_mode = node.args[2] if len(node.args) > 2 else 0 + + grid_sample_op = PyQnnManager.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpGridSample.op_name, + ) + grid_sample_op.AddInputTensors( + [input_tensor_wrapper, transpose_output_tensor_wrapper] + ) + grid_sample_op.AddOutputTensors([output_tensor_wrapper]) + grid_sample_op.AddScalarParam( + OpGridSample.param_align_corners, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + {QCOM_DATA: align_corners}, + ) + grid_sample_op.AddScalarParam( + OpGridSample.param_mode, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + {QCOM_DATA: np.uint32(interpo_mode)}, + ) + grid_sample_op.AddScalarParam( + OpGridSample.param_padding_mode, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + {QCOM_DATA: np.uint32(padding_mode)}, + ) + grid_sample_op_list.append(grid_sample_op) + return grid_sample_op_list diff --git a/backends/qualcomm/builders/op_group_norm.py b/backends/qualcomm/builders/op_group_norm.py index c492616d999..7a18373aa6c 100644 --- a/backends/qualcomm/builders/op_group_norm.py +++ b/backends/qualcomm/builders/op_group_norm.py @@ -6,7 +6,7 @@ from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -28,15 +28,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: 
Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -46,7 +46,7 @@ def define_node( weight_node, node, weight_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, ) @@ -56,7 +56,7 @@ def define_node( bias_node, node, bias_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, ) group = node.args[6] @@ -67,11 +67,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - group_norm_op = PyQnnWrapper.PyQnnOpWrapper( + group_norm_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpGroupNorm.op_name, @@ -82,12 +82,12 @@ def define_node( group_norm_op.AddOutputTensors([output_tensor_wrapper]) group_norm_op.AddScalarParam( OpGroupNorm.param_epsilon, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, {QCOM_DATA: np.float32(epsilon)}, ) group_norm_op.AddScalarParam( OpGroupNorm.param_group, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {QCOM_DATA: np.uint32(group)}, ) diff --git a/backends/qualcomm/builders/op_gt.py b/backends/qualcomm/builders/op_gt.py index e296589af5a..66b94fa3b75 100644 --- a/backends/qualcomm/builders/op_gt.py +++ b/backends/qualcomm/builders/op_gt.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,14 +24,14 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: out_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) output_tensors = [output_tensor_wrapper] @@ -40,7 +40,7 @@ def define_node( for index in range(2): input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) - tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + tensor_type = PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE input_tensor_wrapper = self.define_tensor( input_node, @@ -51,7 +51,7 @@ def define_node( ) input_tensors.append(input_tensor_wrapper) - gt_op = PyQnnWrapper.PyQnnOpWrapper( + gt_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseGreater.op_name, diff --git a/backends/qualcomm/builders/op_hardsigmoid.py b/backends/qualcomm/builders/op_hardsigmoid.py index 70ac35828d8..ca834358bf8 100644 --- a/backends/qualcomm/builders/op_hardsigmoid.py +++ b/backends/qualcomm/builders/op_hardsigmoid.py @@ -6,7 +6,7 @@ from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -27,15 +27,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -44,11 +44,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - hardsigmoid_op = PyQnnWrapper.PyQnnOpWrapper( + hardsigmoid_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseNeuron.op_name, @@ -59,19 +59,19 @@ def define_node( # The operation enum of hardsigmoid in QNN hardsigmoid_op.AddScalarParam( OpElementWiseNeuron.param_operation, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {QCOM_DATA: np.uint32(2)}, ) # The parameter used in Pytorch definition for hardsigmoid hardsigmoid_op.AddScalarParam( OpElementWiseNeuron.param_alpha, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, {QCOM_DATA: np.float32(1 / 6)}, ) hardsigmoid_op.AddScalarParam( OpElementWiseNeuron.param_beta, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, {QCOM_DATA: np.float32(1 / 
2)}, ) diff --git a/backends/qualcomm/builders/op_hardswish.py b/backends/qualcomm/builders/op_hardswish.py index 8a8fa25847d..2fa8a6276c2 100644 --- a/backends/qualcomm/builders/op_hardswish.py +++ b/backends/qualcomm/builders/op_hardswish.py @@ -6,7 +6,7 @@ from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -25,15 +25,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -42,11 +42,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - hardswish_op = PyQnnWrapper.PyQnnOpWrapper( + hardswish_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpHardSwish.op_name, diff --git a/backends/qualcomm/builders/op_hardtanh.py b/backends/qualcomm/builders/op_hardtanh.py index 755e45f0e3b..3e321b4e028 100644 --- a/backends/qualcomm/builders/op_hardtanh.py +++ b/backends/qualcomm/builders/op_hardtanh.py @@ -6,7 +6,7 @@ from typing import cast, Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -27,15 +27,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -54,11 +54,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - hardtanh_op = PyQnnWrapper.PyQnnOpWrapper( + hardtanh_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpReluMinMax.op_name, @@ -67,12 +67,12 @@ def define_node( hardtanh_op.AddOutputTensors([output_tensor_wrapper]) hardtanh_op.AddScalarParam( OpReluMinMax.param_max_value, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, {QCOM_DATA: np.float32(output_max)}, ) hardtanh_op.AddScalarParam( OpReluMinMax.param_min_value, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, {QCOM_DATA: np.float32(output_min)}, ) diff --git a/backends/qualcomm/builders/op_index.py b/backends/qualcomm/builders/op_index.py index 2a7da815265..997dad4cc6f 100644 --- a/backends/qualcomm/builders/op_index.py +++ 
b/backends/qualcomm/builders/op_index.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -27,15 +27,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -51,7 +51,7 @@ def define_node( indices_node, node, indices_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -62,12 +62,12 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) gather_output_tensors = [output_tensor_wrapper] - gather_op = PyQnnWrapper.PyQnnOpWrapper( + gather_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpGather.op_name, @@ -78,7 +78,7 @@ def define_node( # If support tuple of tensor, need to refine it based on len gather_op.AddScalarParam( OpGather.param_axis, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_INT_32, {QCOM_DATA: np.int32(axis)}, ) diff --git a/backends/qualcomm/builders/op_index_put.py b/backends/qualcomm/builders/op_index_put.py index c3c42ed483a..84eb2368967 100644 --- a/backends/qualcomm/builders/op_index_put.py +++ b/backends/qualcomm/builders/op_index_put.py @@ -1,14 +1,19 @@ import warnings +from collections import OrderedDict from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch -from executorch.backends.qualcomm.utils.constants import QCOM_DATA, QCOM_QUANT_ATTRS +from executorch.backends.qualcomm.utils.constants import ( + QCOM_DATA, + QCOM_DTYPE, + QCOM_QUANT_ATTRS, +) from executorch.exir.dialects._ops import ops as exir_ops -from .node_visitor import NodeVisitor, QNN_TENSOR_TYPE_MAP +from .node_visitor import NodeVisitor, QNN_QUANT_TYPE_MAP, QNN_TENSOR_TYPE_MAP from .node_visitor_manager import register_node_visitor from .qnn_constants import ( OpConcat, @@ -26,72 +31,131 @@ class IndexPutVisitor(NodeVisitor): def __init__(self, *args) -> None: super().__init__(*args) - def define_node( + def define_node( # noqa: C901 self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: op_wrapper_list = [] input_node = self.get_node(node.args[0]) # Because the args[0] of index_put op doesn't annotate, need to fill in the quant_attr with the node here. 
if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS): quant_attrs = quant_attrs.copy() input_node.meta[QCOM_QUANT_ATTRS] = quant_attrs + input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - indicies_node = node.args[1] - index_node_dim = None - index_nodes = [] - index_tensors = [] + indices_nodes = ( + node.args[1] if isinstance(node.args[1], list) else [node.args[1]] + ) target_index = [] + all_range_index = OrderedDict() + index_dtype = [ + node.meta["val"].dtype for node in indices_nodes if node is not None + ][0] + + # preprocess: + # - broadcast dimension for multiple specified index + # - broadcast specified index if dimensions are not matched + max_indices_in_specified_index = 0 + for index, idx_node in enumerate(indices_nodes): + if isinstance(idx_node, torch.fx.Node): + last_specified_index_node = index + if max_indices_in_specified_index < idx_node.meta["val"].nelement(): + max_indices_in_specified_index = idx_node.meta["val"].nelement() # If there is None in a list, it means all range at that dimension - # E.g., indicies_node: [None, None, aten__to_copy_default_1] - if isinstance(indicies_node, list): - for index, idx_node in enumerate(indicies_node): - # First, collect the indice_node and index of None to construct the shape of index node - # E.g., shape of input: [1, 1024, 12, 64] - # For "None" axis (assume indicies_node: [None, None, aten__to_copy_default_1]), - # target_index: [1, 1024, x], x is the shape of index_tensor, index_node_dim: 2 - if isinstance(idx_node, torch.fx.Node): - index_nodes.append(idx_node) - index_tensors.append(self.get_tensor(idx_node, idx_node)) - target_index.extend(index_tensors[-1].size()) - index_node_dim = index - elif idx_node is None and index_node_dim is None: - # E.g., indicies_node: [None, aten__to_copy_default_1, None] - # Don't need to consider "None" after index_node. - target_index.append(input_tensor.size(index)) - else: - warnings.warn( - f"[QNN Delegate Op Builder]: Get the index {idx_node} that is neither a node nor None", - stacklevel=1, + for index, idx_node in enumerate(indices_nodes): + # First, collect the index_node and index of None to construct the shape of index node + # E.g., shape of input: [1, 1024, 12, 64] + # For "None" axis (assume indices_node: [None, None, aten__to_copy_default_1]), + # target_index: [1, 1024, x], x is the shape of index_tensor, index_node_dim: 2 + if isinstance(idx_node, torch.fx.Node): + # e.g. for case [index_node_0, None, index_node_1], nodes will have the same number of indices + target_index.append( + self.get_tensor(idx_node, idx_node).nelement() + if last_specified_index_node == index + else 1 + ) + elif idx_node is None: + # E.g., indices_node: [None, None, aten__to_copy_default_1] + all_range_index[index] = torch.arange( + input_tensor.size(index), dtype=index_dtype + ) + target_index.append(input_tensor.size(index)) + else: + warnings.warn( + f"[QNN Delegate Op Builder]: Get the index {idx_node} that is neither a node nor None", + stacklevel=1, + ) + return + + # preprocess all range indices if any + if None in indices_nodes: + all_range_tensor = torch.cartesian_prod(*all_range_index.values()) + # repeat all_range_tensor interleavely for future concatenation + # e.g. 
input_node = [5, 4, 3, 2], indices = [index_0_node, None, index_2_node] + # index_0.shape == index_2.shape == 2 (will guarantee this condition) + # where user specified (3, 4) for index_0, (0, 1) for index_2 + # --- + # we should have all_range_tensor: [0, 1, 2, 3] + # repeat interleavely with 2 to match future tiled index_0_node & index_2_node + # we'll have 1(index_0 -> same as index_2)*4(index_1)*2(index_2) indices in total: + # | index_0_node | None | index_2_node | + # | 3 | 0 | 0 | + # | 4 | 0 | 1 | + # | 3 | 1 | 0 | + # | 4 | 1 | 1 | + # | 3 | 2 | 0 | + # | 4 | 2 | 1 | + # | 3 | 3 | 0 | + # | 4 | 3 | 1 | + all_range_tensor_aug = all_range_tensor.repeat_interleave( + max_indices_in_specified_index, dim=0 + ) + for index in all_range_index.keys(): + # Repeat index for "None" axis in indices_nodes + range_index_node = torch.fx.Node( + node.graph, + node.name + f"_all_range_index_{index}", + "call_function", + exir_ops.edge.aten.tensor.default, + (), # args + {}, # kwargs + ) + range_indices = ( + ( + all_range_tensor_aug[:, index] + if all_range_tensor_aug.dim() > 1 + else + # if there is only one None + all_range_tensor_aug ) - return - # Assume that there is only one node in list - assert len(index_nodes) == 1, "Not support multiple indices tensor" - indice_node = index_nodes[0] - indice_tensor = index_tensors[0] - indices_tensor_wrapper = self.define_tensor( - indice_node, - node, - indice_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, - nodes_to_wrappers, - ) + .reshape(-1, 1) + .contiguous() + ) + target_index_tensor_wrapper = self.define_tensor( + range_index_node, + node, + range_indices, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + ) + # store it for future concatenation + all_range_index[index] = (range_indices, target_index_tensor_wrapper) # Need to reconstruct the index tensor. # E.g., based on ScatterND Op Def in QNN Docs. 
# Torch: # Given that # shape of input: [1, 12, 1024, 64] - # indicies_node: [None, None, aten__to_copy_default_1] + # indices_node: [None, None, aten__to_copy_default_1] # shape of aten__to_copy_default_1: [1] # QNN: # Index tensor: @@ -104,140 +168,159 @@ def define_node( # update_indices = indices.shape[:-1] # for idx in np.ndindex(update_indices): # output[indices[idx]] = updates[idx] + specified_index = OrderedDict() + for i, indices_node in enumerate(indices_nodes): + if indices_node is None: + continue - # Append one dimension to specify x-tuple - index_shape = target_index + [1] - # Reshape the index_node for tile op - reshape_shape = [ - shape if id == index_node_dim else 1 for id, shape in enumerate(index_shape) - ] - reshape_output_tensor = indice_tensor.reshape(reshape_shape) - reshape_output_tensor_wrapper = self.define_custom_tensor_wrapper( - node_name=node.name + "_reshape", - tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, - dtype=QNN_TENSOR_TYPE_MAP[reshape_output_tensor.dtype], - quant_encoding=PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED, - quant_configs={}, - dims=reshape_output_tensor.size(), - tensor=reshape_output_tensor, - is_fake_tensor=True, - nodes_to_wrappers=nodes_to_wrappers, - ) - reshape_op = PyQnnWrapper.PyQnnOpWrapper( - node.name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - OpReshape.op_name, - ) - reshape_op.AddInputTensors([indices_tensor_wrapper]) - reshape_op.AddOutputTensors([reshape_output_tensor_wrapper]) - op_wrapper_list.append(reshape_op) - index_put_index_input_tensor_wrapper = reshape_output_tensor_wrapper + indices_tensor = self.get_tensor(indices_node, indices_node) + indices_tensor_wrapper = self.define_tensor( + indices_node, + node, + indices_tensor, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + if indices_tensor.nelement() < max_indices_in_specified_index: + # broadcast the specified index + indices_tensor = indices_tensor.repeat(max_indices_in_specified_index) + indices_multiples = [max_indices_in_specified_index] + indices_multiples_shape = [len(indices_multiples)] + indices_tile_tensor_wrapper = self.define_custom_tensor_wrapper( + node_name=node.name + f"_indices_tile_{i}", + tensor_type=PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + dtype=QNN_TENSOR_TYPE_MAP[indices_tensor.dtype], + quant_encoding=PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED, + quant_configs={}, + dims=indices_tensor.size(), + tensor=indices_tensor, + is_fake_tensor=True, + nodes_to_wrappers=nodes_to_wrappers, + ) + tile_op = PyQnnManager.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpTile.op_name, + ) + tile_op.AddInputTensors([indices_tensor_wrapper]) + tile_op.AddOutputTensors([indices_tile_tensor_wrapper]) + tile_op.AddTensorParam( + OpTile.param_multiples, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(indices_multiples_shape), + indices_multiples_shape, + np.array(indices_multiples, dtype=np.uint32), + True, + ) + op_wrapper_list.append(tile_op) + indices_tensor_wrapper = indices_tile_tensor_wrapper - # Tile the index_node and concat the target index - if None in indicies_node: - tile_output_tensor = reshape_output_tensor.expand(index_shape) - # Tile the index_node to align with the shape of target_index - # Only need to tile the dim of None axis - # E.g., indicies_node: [None, None, aten__to_copy_default_1] - # Should tile the first two dimension. 
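For reference, a minimal sketch that reproduces the index table from the comments above (hypothetical input of shape (5, 4, 3, 2), indices [index_0, None, index_2] with index_0 = (3, 4) and index_2 = (0, 1)) and then applies the ScatterND-style update loop quoted from the QNN docs:

import torch

index_0 = torch.tensor([3, 4])
index_2 = torch.tensor([0, 1])
dim1_size = 4                                  # size of the "None" axis

# Every value of the None axis, repeated once per specified-index entry
# (1-D here because there is only a single None axis).
all_range = torch.cartesian_prod(torch.arange(dim1_size))              # [0, 1, 2, 3]
all_range = all_range.repeat_interleave(index_0.numel()).reshape(-1, 1)

# Specified indices tiled to cover every None-axis value.
idx0 = index_0.repeat(dim1_size).reshape(-1, 1)                        # [3, 4, 3, 4, ...]
idx2 = index_2.repeat(dim1_size).reshape(-1, 1)                        # [0, 1, 0, 1, ...]

scatter_nd_indices = torch.cat([idx0, all_range, idx2], dim=-1)
print(scatter_nd_indices)  # rows: [3,0,0] [4,0,1] [3,1,0] [4,1,1] [3,2,0] [4,2,1] [3,3,0] [4,3,1]

# ScatterND-style update per the pseudocode above: each index row addresses the
# first three dimensions, so it receives one trailing slice of the updates.
out = torch.zeros(5, 4, 3, 2)
updates = torch.ones(scatter_nd_indices.size(0), 2)
for row, upd in zip(scatter_nd_indices.tolist(), updates):
    out[tuple(row)] = upd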
- multiples = [ - shape if id != index_node_dim else 1 - for id, shape in enumerate(index_shape) - ] - multiples_shape = [len(index_shape)] - tile_output_tensor_wrapper = self.define_custom_tensor_wrapper( - node_name=node.name + "_tile", - tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, - dtype=QNN_TENSOR_TYPE_MAP[tile_output_tensor.dtype], - quant_encoding=PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED, + # Append one dimension to specify x-tuple + # Reshape the index_node for tile op + reshape_shape = list(indices_tensor.shape) + [1] + reshape_output_tensor = indices_tensor.reshape(reshape_shape) + reshape_output_tensor_wrapper = self.define_custom_tensor_wrapper( + node_name=node.name + f"_reshape_{i}", + tensor_type=PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + dtype=QNN_TENSOR_TYPE_MAP[reshape_output_tensor.dtype], + quant_encoding=PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED, quant_configs={}, - dims=tile_output_tensor.size(), - tensor=tile_output_tensor, + dims=reshape_output_tensor.size(), + tensor=reshape_output_tensor, is_fake_tensor=True, nodes_to_wrappers=nodes_to_wrappers, ) - tile_op = PyQnnWrapper.PyQnnOpWrapper( + reshape_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, - OpTile.op_name, - ) - tile_op.AddInputTensors([reshape_output_tensor_wrapper]) - tile_op.AddOutputTensors([tile_output_tensor_wrapper]) - tile_op.AddTensorParam( - OpTile.param_multiples, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - len(multiples_shape), - multiples_shape, - np.array(multiples, dtype=np.uint32), - True, + OpReshape.op_name, ) - op_wrapper_list.append(tile_op) + reshape_op.AddInputTensors([indices_tensor_wrapper]) + reshape_op.AddOutputTensors([reshape_output_tensor_wrapper]) + op_wrapper_list.append(reshape_op) + index_tensor_wrapper = reshape_output_tensor_wrapper + index_tensor = reshape_output_tensor - # Repeat index for "None" axis in indicies_node - ranges = [ - torch.arange(dim, dtype=indice_tensor.dtype) - for dim in target_index[:-1] - ] - target_index_shape = target_index + [len(ranges)] - target_index_tensor = torch.cartesian_prod(*ranges) - reshape_target_index_shape = [ - shape if id != index_node_dim else 1 - for id, shape in enumerate(target_index_shape) - ] - target_index_tensor = target_index_tensor.reshape( - reshape_target_index_shape - ) - target_index_tensor = target_index_tensor.expand( - target_index_shape - ).contiguous() - target_index_node = torch.fx.Node( - node.graph, - node.name + "_target_index", - "call_function", - exir_ops.edge.aten.tensor.default, - (), # args - {}, # kwargs - ) - target_index_tensor_wrapper = self.define_tensor( - target_index_node, - node, - target_index_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, - nodes_to_wrappers, - ) + # Tile the index_node and concat the target index + if None in indices_nodes: + tile_output_tensor = reshape_output_tensor.repeat( + all_range_tensor.size(0), 1 + ) + # Tile the index_node to align with the shape of target_index + # Only need to tile the dim of None axis + # E.g., indices_node: [None, None, aten__to_copy_default_1] + # Should tile the number of indices combination of first two dimension + # times number of indices specified by aten__to_copy_default_1 + multiples = [all_range_tensor.size(0), 1] + multiples_shape = [len(multiples)] + tile_output_tensor_wrapper = self.define_custom_tensor_wrapper( + node_name=node.name + f"_tile_{i}", + 
tensor_type=PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + dtype=QNN_TENSOR_TYPE_MAP[tile_output_tensor.dtype], + quant_encoding=PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED, + quant_configs={}, + dims=tile_output_tensor.size(), + tensor=tile_output_tensor, + is_fake_tensor=True, + nodes_to_wrappers=nodes_to_wrappers, + ) + tile_op = PyQnnManager.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpTile.op_name, + ) + tile_op.AddInputTensors([reshape_output_tensor_wrapper]) + tile_op.AddOutputTensors([tile_output_tensor_wrapper]) + tile_op.AddTensorParam( + OpTile.param_multiples, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(multiples_shape), + multiples_shape, + np.array(multiples, dtype=np.uint32), + True, + ) + op_wrapper_list.append(tile_op) + index_tensor_wrapper = tile_output_tensor_wrapper + index_tensor = tile_output_tensor + + specified_index[i] = (index_tensor, index_tensor_wrapper) - # Concat target_index and tile output to reconstruct index_node - # Cannot use QNN Pack (stack) since QNN Pack is not support int32 dtype - concat_output_tensor = torch.concat( - (target_index_tensor, tile_output_tensor), dim=-1 + # Concat target_index and tile output to reconstruct index_node + # Cannot use QNN Pack (stack) since QNN Pack is not support int32 dtype + index_tensors, index_tensor_wrappers = [], [] + for i, arg in enumerate(indices_nodes): + tensor, tensor_wrapper = ( + all_range_index[i] if arg is None else specified_index[i] ) + index_tensors.append(tensor) + index_tensor_wrappers.append(tensor_wrapper) + + if len(index_tensor_wrappers) > 1: + concat_output_tensor = torch.concat(index_tensors, dim=-1) concat_output_tensor_wrapper = self.define_custom_tensor_wrapper( node_name=node.name + "_concat", - tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + tensor_type=PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, dtype=QNN_TENSOR_TYPE_MAP[concat_output_tensor.dtype], - quant_encoding=PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED, + quant_encoding=PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED, quant_configs={}, dims=concat_output_tensor.size(), tensor=concat_output_tensor, is_fake_tensor=True, nodes_to_wrappers=nodes_to_wrappers, ) - concat_op = PyQnnWrapper.PyQnnOpWrapper( + concat_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpConcat.op_name, ) - concat_op.AddInputTensors( - [target_index_tensor_wrapper, tile_output_tensor_wrapper] - ) + concat_op.AddInputTensors(index_tensor_wrappers) concat_op.AddOutputTensors([concat_output_tensor_wrapper]) concat_op.AddScalarParam( OpConcat.param_axis, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {QCOM_DATA: np.uint32(concat_output_tensor.dim() - 1)}, ) op_wrapper_list.append(concat_op) - index_put_index_input_tensor_wrapper = concat_output_tensor_wrapper value_node = self.get_node(node.args[2]) value_tensor = self.get_tensor(value_node, node) @@ -245,29 +328,152 @@ def define_node( value_node, node, value_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) + # handle broadcast scenario + # e.g. 
input_tensor: (1, 12, 1024, 64), value_tensor: (1, 64) + # => value_reshape_tensor: (1, 1, 1, 64) + new_value_shape = ( + *([1] * (input_tensor.dim() - value_tensor.dim())), + *value_tensor.shape, + ) + # reshape the value_node for tile op + value_quant_encoding, value_quant_configs = self.get_quant_encoding_conf( + value_node, node + ) + value_dtype = ( + QNN_TENSOR_TYPE_MAP[value_tensor.dtype] + if value_quant_encoding + == PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED + else QNN_QUANT_TYPE_MAP[ + ( + torch.uint16 + if value_quant_configs[QCOM_DTYPE] == torch.int32 + else value_quant_configs[QCOM_DTYPE] + ) + ] + ) + value_reshape_tensor = value_tensor.reshape(new_value_shape) + value_reshape_tensor_wrapper = self.define_custom_tensor_wrapper( + node_name=node.name + "_value_reshape", + tensor_type=PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + dtype=value_dtype, + quant_encoding=value_quant_encoding, + quant_configs=value_quant_configs, + dims=value_reshape_tensor.size(), + tensor=value_reshape_tensor, + is_fake_tensor=True, + nodes_to_wrappers=nodes_to_wrappers, + ) + value_reshape_op = PyQnnManager.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpReshape.op_name, + ) + value_reshape_op.AddInputTensors([value_tensor_wrapper]) + value_reshape_op.AddOutputTensors([value_reshape_tensor_wrapper]) + op_wrapper_list.append(value_reshape_op) + + # e.g. input_tensor: (1, 12, 1024, 64), index_tensor: (None, None, 2), value_tensor: (1, 64) + # => multiples: [1, 12, 2, 1] + value_multiples = [] + for i in range(input_tensor.dim() - 1, -1, -1): + if i in specified_index: + # all user specified index node wil have the same dimension + multiplier = ( + indices_nodes[i].meta["val"].nelement() // new_value_shape[i] + if i == last_specified_index_node + else 1 + ) + else: + multiplier = input_tensor.shape[i] // new_value_shape[i] + value_multiples.insert(0, multiplier) + + value_tile_tensor = value_reshape_tensor.repeat(value_multiples) + value_multiples_shape = [len(value_multiples)] + value_tile_tensor_wrapper = self.define_custom_tensor_wrapper( + node_name=node.name + "_value_tile", + tensor_type=PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + dtype=value_dtype, + quant_encoding=value_quant_encoding, + quant_configs=value_quant_configs, + dims=value_tile_tensor.size(), + tensor=value_tile_tensor, + is_fake_tensor=True, + nodes_to_wrappers=nodes_to_wrappers, + ) + value_tile_op = PyQnnManager.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpTile.op_name, + ) + value_tile_op.AddInputTensors([value_reshape_tensor_wrapper]) + value_tile_op.AddOutputTensors([value_tile_tensor_wrapper]) + value_tile_op.AddTensorParam( + OpTile.param_multiples, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(value_multiples_shape), + value_multiples_shape, + np.array(value_multiples, dtype=np.uint32), + True, + ) + op_wrapper_list.append(value_tile_op) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - index_put_op = PyQnnWrapper.PyQnnOpWrapper( + index_put_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpScatterNd.op_name, ) + # accumulation + if len(node.args) > 3 and node.args[3]: + index_put_op.AddScalarParam( + OpScatterNd.param_reduction, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + {QCOM_DATA: 
1}, + ) + + # check final index_input tensor + index_input_tensor, index_input_tensor_wrapper = ( + (concat_output_tensor, concat_output_tensor_wrapper) + if len(index_tensor_wrappers) > 1 + else specified_index[last_specified_index_node] + ) + target_index_reshape_tensor = index_input_tensor.reshape((*target_index, -1)) + target_index_reshape_tensor_wrapper = self.define_custom_tensor_wrapper( + node_name=node.name + "_target_index_reshape", + tensor_type=PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + dtype=QNN_TENSOR_TYPE_MAP[target_index_reshape_tensor.dtype], + quant_encoding=PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED, + quant_configs={}, + dims=target_index_reshape_tensor.size(), + tensor=target_index_reshape_tensor, + is_fake_tensor=True, + nodes_to_wrappers=nodes_to_wrappers, + ) + target_index_reshape_op = PyQnnManager.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpReshape.op_name, + ) + target_index_reshape_op.AddInputTensors([index_input_tensor_wrapper]) + target_index_reshape_op.AddOutputTensors([target_index_reshape_tensor_wrapper]) + op_wrapper_list.append(target_index_reshape_op) + index_put_op.AddInputTensors( [ input_tensor_wrapper, - index_put_index_input_tensor_wrapper, - value_tensor_wrapper, + target_index_reshape_tensor_wrapper, + value_tile_tensor_wrapper, ] ) index_put_op.AddOutputTensors([output_tensor_wrapper]) diff --git a/backends/qualcomm/builders/op_index_select.py b/backends/qualcomm/builders/op_index_select.py index 22733e45397..56d5963e5e9 100644 --- a/backends/qualcomm/builders/op_index_select.py +++ b/backends/qualcomm/builders/op_index_select.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -26,15 +26,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -47,7 +47,7 @@ def define_node( indices_node, node, indices_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -58,12 +58,12 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) gather_output_tensors = [output_tensor_wrapper] - gather_op = PyQnnWrapper.PyQnnOpWrapper( + gather_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpGather.op_name, @@ -74,7 +74,7 @@ def define_node( # If support tuple of tensor, need to refine it based on len gather_op.AddScalarParam( OpGather.param_axis, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_INT_32, {QCOM_DATA: np.int32(axis)}, ) diff --git a/backends/qualcomm/builders/op_instance_norm.py b/backends/qualcomm/builders/op_instance_norm.py index 
08c4730ce1d..45a863d140a 100644 --- a/backends/qualcomm/builders/op_instance_norm.py +++ b/backends/qualcomm/builders/op_instance_norm.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch from executorch.backends.qualcomm.utils.constants import ( @@ -34,8 +34,8 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) weight_node = self.get_node(node.args[1]) bias_node = self.get_node(node.args[2]) @@ -44,7 +44,7 @@ def define_node( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) input_tensor_wrappers = [input_tensor_wrapper] @@ -54,7 +54,7 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) output_tensor_wrappers = [output_tensor_wrapper] @@ -85,7 +85,7 @@ def define_node( weight_node, node, weight_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, ) input_tensor_wrappers.append(weight_tensor_wrapper) @@ -96,12 +96,12 @@ def define_node( bias_node, node, bias_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, ) input_tensor_wrappers.append(bias_tensor_wrapper) - instance_norm_op = PyQnnWrapper.PyQnnOpWrapper( + instance_norm_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpInstanceNorm.op_name, diff --git a/backends/qualcomm/builders/op_layer_norm.py b/backends/qualcomm/builders/op_layer_norm.py index 7c17980a82e..a51056eb7bb 100644 --- a/backends/qualcomm/builders/op_layer_norm.py +++ b/backends/qualcomm/builders/op_layer_norm.py @@ -7,7 +7,7 @@ import warnings from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -29,15 +29,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -60,7 +60,7 @@ def define_node( weight_node, node, weight_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, ) @@ -73,7 +73,7 @@ def define_node( bias_node, node, bias_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, 
nodes_to_wrappers, ) layer_norm_input_tensors.append(bias_tensor_wrapper) @@ -85,11 +85,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - layer_norm_op = PyQnnWrapper.PyQnnOpWrapper( + layer_norm_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpLayerNorm.op_name, @@ -98,12 +98,12 @@ def define_node( layer_norm_op.AddOutputTensors([output_tensor_wrapper]) layer_norm_op.AddScalarParam( OpLayerNorm.param_epsilon, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, {QCOM_DATA: np.float32(epsilon)}, ) layer_norm_op.AddTensorParam( OpLayerNorm.param_axes, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, len(axis_shape), axis_shape, np.array(axis, dtype=np.uint32), diff --git a/backends/qualcomm/builders/op_le.py b/backends/qualcomm/builders/op_le.py index ad6a78b3da8..c58edc8e5c3 100644 --- a/backends/qualcomm/builders/op_le.py +++ b/backends/qualcomm/builders/op_le.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,14 +24,14 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: out_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) output_tensors = [output_tensor_wrapper] @@ -40,7 +40,7 @@ def define_node( for index in range(2): input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) - tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + tensor_type = PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE input_tensor_wrapper = self.define_tensor( input_node, @@ -51,7 +51,7 @@ def define_node( ) input_tensors.append(input_tensor_wrapper) - le_op = PyQnnWrapper.PyQnnOpWrapper( + le_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseLessEqual.op_name, diff --git a/backends/qualcomm/builders/op_linear.py b/backends/qualcomm/builders/op_linear.py index d5ac153b8d1..0a1d8f5cbdf 100644 --- a/backends/qualcomm/builders/op_linear.py +++ b/backends/qualcomm/builders/op_linear.py @@ -6,7 +6,7 @@ from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch from executorch.backends.qualcomm.utils.constants import ( @@ -18,7 +18,6 @@ from .node_visitor import NodeVisitor from .node_visitor_manager import register_node_visitor from .qnn_constants import OpFullyConnected, QNN_OP_PACKAGE_NAME_QTI_AISW -from .utils import get_parameter @register_node_visitor @@ -31,8 +30,8 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: 
Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: linear_input_tensors = [] input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) @@ -40,7 +39,7 @@ def define_node( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) linear_input_tensors.append(input_tensor_wrapper) @@ -55,32 +54,26 @@ def define_node( quant_attrs[QCOM_ZERO_POINTS] = quant_attrs[QCOM_ZERO_POINTS].reshape( [-1, 1] ) - - weight_tensor = get_parameter(weight_node, self.edge_program) + weight_tensor = self.get_tensor(weight_node, node) weight_tensor_wrapper = self.define_tensor( weight_node, node, weight_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + # It will determine correct QNN tensor type in define_tensor. + # This param seems unnecessary, which we could possibly remove this in the future. + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) linear_input_tensors.append(weight_tensor_wrapper) if len(node.args) >= 3: bias_node = self.get_node(node.args[2]) - - bias_tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC - bias_tensor = get_parameter(bias_node, self.edge_program) - # if bias_node is getitem - if bias_tensor is None: - bias_tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE - bias_tensor = bias_node.meta["val"] - + bias_tensor = self.get_tensor(bias_node, node) bias_tensor_wrapper = self.define_tensor( bias_node, node, bias_tensor, - bias_tensor_type, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) linear_input_tensors.append(bias_tensor_wrapper) @@ -90,11 +83,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - linear_op = PyQnnWrapper.PyQnnOpWrapper( + linear_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpFullyConnected.op_name, diff --git a/backends/qualcomm/builders/op_log.py b/backends/qualcomm/builders/op_log.py index 397e2072489..8d8d4e64a98 100644 --- a/backends/qualcomm/builders/op_log.py +++ b/backends/qualcomm/builders/op_log.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
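# ---------------------------------------------------------------------------
# Editor's aside, not part of the patch, for the op_linear.py hunk above:
# torch keeps nn.Linear weights in [out_features, in_features] layout, which
# the builder now forwards via get_tensor without transposing, and
# per-output-channel zero points are reshaped into column vectors ([-1, 1]),
# presumably so each weight row pairs with its own offset. Values are made up.
import torch
import torch.nn.functional as F

x = torch.randn(4, 8)                  # [batch, in_features]
w = torch.randn(16, 8)                 # [out_features, in_features]
b = torch.randn(16)

assert torch.allclose(F.linear(x, w, b), x @ w.t() + b)

zero_points = torch.zeros(16, dtype=torch.int32)   # one per output channel
assert zero_points.reshape([-1, 1]).shape == (16, 1)
# ---------------------------------------------------------------------------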
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,15 +24,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) log_inp_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) log_input_tensors = [log_inp_tensor_wrapper] @@ -42,12 +42,12 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) log_output_tensors = [output_tensor_wrapper] - log_op = PyQnnWrapper.PyQnnOpWrapper( + log_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseLog.op_name, diff --git a/backends/qualcomm/builders/op_log_softmax.py b/backends/qualcomm/builders/op_log_softmax.py index 947140006a3..73dddb684e0 100644 --- a/backends/qualcomm/builders/op_log_softmax.py +++ b/backends/qualcomm/builders/op_log_softmax.py @@ -6,7 +6,7 @@ import warnings from typing import cast, Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA @@ -26,8 +26,8 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) @@ -35,7 +35,7 @@ def define_node( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) log_softmax_input_tensors = [log_softmax_inp_tensor_wrapper] @@ -45,7 +45,7 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) log_softmax_output_tensors = [log_softmax_output_tensor_wrapper] @@ -65,7 +65,7 @@ def define_node( ) return None - log_softmax_op = PyQnnWrapper.PyQnnOpWrapper( + log_softmax_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpLogSoftmax.op_name, @@ -75,7 +75,7 @@ def define_node( log_softmax_op.AddScalarParam( OpLogSoftmax.param_axis, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {QCOM_DATA: np.uint32(dim)}, ) return log_softmax_op diff --git a/backends/qualcomm/builders/op_logical_not.py b/backends/qualcomm/builders/op_logical_not.py index 4e8fc8543a7..fc639b21ef5 100644 --- a/backends/qualcomm/builders/op_logical_not.py +++ b/backends/qualcomm/builders/op_logical_not.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,15 +24,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -41,11 +41,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - logical_not_op = PyQnnWrapper.PyQnnOpWrapper( + logical_not_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseNot.op_name, diff --git a/backends/qualcomm/builders/op_lt.py b/backends/qualcomm/builders/op_lt.py index 2558a97dfab..aeb9b8a1eaa 100644 --- a/backends/qualcomm/builders/op_lt.py +++ b/backends/qualcomm/builders/op_lt.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,14 +24,14 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: out_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) output_tensors = [output_tensor_wrapper] @@ -40,7 +40,7 @@ def define_node( for index in range(2): input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) - tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + tensor_type = PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE input_tensor_wrapper = self.define_tensor( input_node, @@ -51,7 +51,7 @@ def define_node( ) input_tensors.append(input_tensor_wrapper) - lt_op = PyQnnWrapper.PyQnnOpWrapper( + lt_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseLess.op_name, diff --git a/backends/qualcomm/builders/op_matmul.py b/backends/qualcomm/builders/op_matmul.py index 5a1e366f384..bc3bddfcd69 100644 --- a/backends/qualcomm/builders/op_matmul.py +++ b/backends/qualcomm/builders/op_matmul.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,8 +24,8 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: matmul_input_tensors = [] for index in range(2): input_node = self.get_node(node.args[index]) @@ -35,7 +35,7 @@ def define_node( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) matmul_input_tensors.append(input_tensor_wrapper) @@ -45,12 +45,12 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) matmul_output_tensors = [output_tensor_wrapper] - matmul_op = PyQnnWrapper.PyQnnOpWrapper( + matmul_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpMatMul.op_name ) matmul_op.AddInputTensors(matmul_input_tensors) diff --git a/backends/qualcomm/builders/op_max.py b/backends/qualcomm/builders/op_max.py index 8406973ab5a..92d58b7f2b8 100644 --- a/backends/qualcomm/builders/op_max.py +++ b/backends/qualcomm/builders/op_max.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,14 +24,14 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: out_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) min_output_tensors = [output_tensor_wrapper] @@ -40,7 +40,7 @@ def define_node( for index in range(2): input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) - tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + tensor_type = PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE input_tensor_wrapper = self.define_tensor( input_node, @@ -51,7 +51,7 @@ def define_node( ) min_input_tensors.append(input_tensor_wrapper) - max_op = PyQnnWrapper.PyQnnOpWrapper( + max_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseMaximum.op_name, diff --git a/backends/qualcomm/builders/op_max_dim.py b/backends/qualcomm/builders/op_max_dim.py index 354444da550..a00f04298af 100644 --- a/backends/qualcomm/builders/op_max_dim.py +++ b/backends/qualcomm/builders/op_max_dim.py @@ -6,7 +6,7 @@ from typing import cast, Dict, List -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -27,15 +27,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: 
Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> List[PyQnnWrapper.PyQnnOpWrapper]: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> List[PyQnnManager.PyQnnOpWrapper]: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -52,7 +52,7 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -62,7 +62,7 @@ def define_node( dims = [node.meta[QCOM_AXIS_ORDER].index(max_dim) for max_dim in dims] dims_shape = [len(dims)] - reduce_max_op = PyQnnWrapper.PyQnnOpWrapper( + reduce_max_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpReduceMax.op_name, @@ -72,7 +72,7 @@ def define_node( reduce_max_op.AddTensorParam( OpReduceMax.param_axes, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, len(dims_shape), dims_shape, np.array(dims, dtype=np.uint32), @@ -82,7 +82,7 @@ def define_node( keep_dims = cast(bool, node.args[2]) reduce_max_op.AddScalarParam( OpReduceMax.param_keep_dims, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, {QCOM_DATA: keep_dims}, ) diff --git a/backends/qualcomm/builders/op_max_pool2d.py b/backends/qualcomm/builders/op_max_pool2d.py index 5da1bd1ac0f..6091fb8d053 100644 --- a/backends/qualcomm/builders/op_max_pool2d.py +++ b/backends/qualcomm/builders/op_max_pool2d.py @@ -6,7 +6,7 @@ import warnings from typing import cast, Dict, List -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -27,15 +27,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -55,7 +55,7 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) # kernel info @@ -94,7 +94,7 @@ def define_node( if ceil_mode: mode = OpPoolMax2d.RoundingMode.CEIL - max_pool2d_op = PyQnnWrapper.PyQnnOpWrapper( + max_pool2d_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpPoolMax2d.op_name, @@ -104,7 +104,7 @@ def define_node( max_pool2d_op.AddTensorParam( OpPoolMax2d.param_filter_size, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, len(filter_size_shape), filter_size_shape, np.array( @@ -115,7 +115,7 @@ def define_node( ) max_pool2d_op.AddTensorParam( OpPoolMax2d.param_stride, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, len(stride_shape), stride_shape, np.array( 
@@ -126,7 +126,7 @@ def define_node( ) max_pool2d_op.AddTensorParam( OpPoolMax2d.param_pad_amount, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, len(padding_shape), padding_shape, np.array( @@ -138,7 +138,7 @@ def define_node( max_pool2d_op.AddScalarParam( OpPoolMax2d.param_rounding_mode, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {QCOM_DATA: np.uint32(mode)}, ) diff --git a/backends/qualcomm/builders/op_mean_dim.py b/backends/qualcomm/builders/op_mean_dim.py index 630b1b0b8de..45b91e529a0 100644 --- a/backends/qualcomm/builders/op_mean_dim.py +++ b/backends/qualcomm/builders/op_mean_dim.py @@ -4,9 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import cast, Dict, List +from typing import cast, Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -27,20 +27,35 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) # mean dims and keep dims - mean_dims = cast(List[int], node.args[1]) + rank = len(input_node.meta["val"].shape) + + if rank == 0: + raise RuntimeError( + "Mean doesn't support 0d input, please report a bug in https://github.com/pytorch/executorch/issues" + ) + + dim_arg = node.args[1] + + if dim_arg is None or len(dim_arg) == 0: + mean_dims = list(range(rank)) # reduce over all dims + elif isinstance(dim_arg, int): + mean_dims = [dim_arg] + else: + mean_dims = list(dim_arg) + mean_dims = [ mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims ] @@ -55,11 +70,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - reduce_mean_op = PyQnnWrapper.PyQnnOpWrapper( + reduce_mean_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpReduceMean.op_name, @@ -68,7 +83,7 @@ def define_node( reduce_mean_op.AddOutputTensors([output_tensor_wrapper]) reduce_mean_op.AddTensorParam( OpReduceMean.param_axes, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, len(mean_dims_shape), mean_dims_shape, np.array(mean_dims, dtype=np.uint32), @@ -78,7 +93,7 @@ def define_node( keep_dims = cast(bool, node.args[2]) reduce_mean_op.AddScalarParam( OpReduceMean.param_keep_dims, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, {QCOM_DATA: keep_dims}, ) diff --git a/backends/qualcomm/builders/op_min.py b/backends/qualcomm/builders/op_min.py index 28c766cffb5..359ac0e6580 100644 --- a/backends/qualcomm/builders/op_min.py +++ b/backends/qualcomm/builders/op_min.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
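# ---------------------------------------------------------------------------
# Editor's aside, not part of the patch, for the op_mean_dim.py hunk above:
# a condensed, hypothetical version of the new dim-argument normalization
# (None / empty / int / list, possibly negative) that produces the
# non-negative axis list handed to QNN ReduceMean.
def normalize_mean_dims(dim_arg, rank):
    if rank == 0:
        raise RuntimeError("Mean doesn't support 0d input")
    if dim_arg is None or (hasattr(dim_arg, "__len__") and len(dim_arg) == 0):
        dims = list(range(rank))          # reduce over every dimension
    elif isinstance(dim_arg, int):
        dims = [dim_arg]
    else:
        dims = list(dim_arg)
    return [d % rank for d in dims]       # fold negative dims into range

assert normalize_mean_dims(None, 4) == [0, 1, 2, 3]
assert normalize_mean_dims(-1, 4) == [3]
assert normalize_mean_dims([1, -2], 4) == [1, 2]
# ---------------------------------------------------------------------------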
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,14 +24,14 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: out_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) min_output_tensors = [output_tensor_wrapper] @@ -40,7 +40,7 @@ def define_node( for index in range(2): input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) - tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + tensor_type = PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE input_tensor_wrapper = self.define_tensor( input_node, @@ -51,7 +51,7 @@ def define_node( ) min_input_tensors.append(input_tensor_wrapper) - min_op = PyQnnWrapper.PyQnnOpWrapper( + min_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseMinimum.op_name, diff --git a/backends/qualcomm/builders/op_min_dim.py b/backends/qualcomm/builders/op_min_dim.py index 6425a9aa755..63775847cdf 100644 --- a/backends/qualcomm/builders/op_min_dim.py +++ b/backends/qualcomm/builders/op_min_dim.py @@ -6,7 +6,7 @@ from typing import cast, Dict, List -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -27,15 +27,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> List[PyQnnWrapper.PyQnnOpWrapper]: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> List[PyQnnManager.PyQnnOpWrapper]: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -52,7 +52,7 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -62,7 +62,7 @@ def define_node( dims = [node.meta[QCOM_AXIS_ORDER].index(min_dim) for min_dim in dims] dims_shape = [len(dims)] - reduce_min_op = PyQnnWrapper.PyQnnOpWrapper( + reduce_min_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpReduceMin.op_name, @@ -72,7 +72,7 @@ def define_node( reduce_min_op.AddTensorParam( OpReduceMin.param_axes, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, len(dims_shape), dims_shape, np.array(dims, dtype=np.uint32), @@ -82,7 +82,7 @@ def define_node( keep_dims = cast(bool, node.args[2]) reduce_min_op.AddScalarParam( OpReduceMin.param_keep_dims, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, {QCOM_DATA: keep_dims}, ) diff --git a/backends/qualcomm/builders/op_mul.py b/backends/qualcomm/builders/op_mul.py 
index f003007e0df..655f0b8fecd 100644 --- a/backends/qualcomm/builders/op_mul.py +++ b/backends/qualcomm/builders/op_mul.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,14 +24,14 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: out_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) mul_output_tensors = [output_tensor_wrapper] @@ -40,7 +40,7 @@ def define_node( for index in range(2): input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) - tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + tensor_type = PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE input_tensor_wrapper = self.define_tensor( input_node, @@ -51,7 +51,7 @@ def define_node( ) mul_input_tensors.append(input_tensor_wrapper) - mul_op = PyQnnWrapper.PyQnnOpWrapper( + mul_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseMultiply.op_name, diff --git a/backends/qualcomm/builders/op_ne.py b/backends/qualcomm/builders/op_ne.py index 660c78e3e14..9a8bfac83f1 100644 --- a/backends/qualcomm/builders/op_ne.py +++ b/backends/qualcomm/builders/op_ne.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,14 +24,14 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: out_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) output_tensors = [output_tensor_wrapper] @@ -40,7 +40,7 @@ def define_node( for index in range(2): input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) - tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + tensor_type = PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE input_tensor_wrapper = self.define_tensor( input_node, @@ -51,7 +51,7 @@ def define_node( ) input_tensors.append(input_tensor_wrapper) - ne_op = PyQnnWrapper.PyQnnOpWrapper( + ne_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseNotEqual.op_name, diff --git a/backends/qualcomm/builders/op_neg.py b/backends/qualcomm/builders/op_neg.py index 911fbe742c8..a1e7f904c20 100644 --- a/backends/qualcomm/builders/op_neg.py +++ b/backends/qualcomm/builders/op_neg.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch from .node_visitor import NodeVisitor @@ -23,15 +23,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) neg_inp_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) neg_input_tensors = [neg_inp_tensor_wrapper] @@ -40,11 +40,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) neg_output_tensors = [output_tensor_wrapper] - neg_op = PyQnnWrapper.PyQnnOpWrapper( + neg_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseNeg.op_name, diff --git a/backends/qualcomm/builders/op_or.py b/backends/qualcomm/builders/op_or.py index c0a995d3631..b397cb795c1 100644 --- a/backends/qualcomm/builders/op_or.py +++ b/backends/qualcomm/builders/op_or.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,14 +24,14 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: out_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) or_output_tensors = [output_tensor_wrapper] @@ -40,7 +40,7 @@ def define_node( for index in range(2): input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) - tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + tensor_type = PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE input_tensor_wrapper = self.define_tensor( input_node, @@ -50,7 +50,7 @@ def define_node( nodes_to_wrappers, ) or_input_tensors.append(input_tensor_wrapper) - or_op = PyQnnWrapper.PyQnnOpWrapper( + or_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseOr.op_name, diff --git a/backends/qualcomm/builders/op_pad.py b/backends/qualcomm/builders/op_pad.py index 7832e180ebb..2302d6d55a0 100644 --- a/backends/qualcomm/builders/op_pad.py +++ b/backends/qualcomm/builders/op_pad.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import cast, Dict, List -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -26,15 +26,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) pad_inp_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) pad_input_tensors = [pad_inp_tensor_wrapper] @@ -44,7 +44,7 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) pad_output_tensors = [output_tensor_wrapper] @@ -64,7 +64,7 @@ def define_node( pad_amount = pad_amount[list(node.meta[QCOM_AXIS_ORDER])] pad_amount_val = node.args[2] - pad_op = PyQnnWrapper.PyQnnOpWrapper( + pad_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpPad.op_name, @@ -75,7 +75,7 @@ def define_node( # For now, we only support constant (0) padding due to torch implementation pad_op.AddScalarParam( OpPad.param_scheme, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {QCOM_DATA: np.uint32(OpPad.Scheme.CONSTANT)}, ) @@ -87,7 +87,7 @@ def define_node( pad_op.AddTensorParam( OpPad.param_pad_amount, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, len(pad_amount_shape), pad_amount_shape, pad_amount, diff --git a/backends/qualcomm/builders/op_pow.py b/backends/qualcomm/builders/op_pow.py index 50568bfbcc1..395f732c46f 100644 --- a/backends/qualcomm/builders/op_pow.py +++ b/backends/qualcomm/builders/op_pow.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
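# ---------------------------------------------------------------------------
# Editor's aside, not part of the patch, for the op_pad.py hunk above: torch's
# F.pad spec is flat and starts from the last dimension, while
# OpPad.param_pad_amount carries a per-dimension [before, after] table that is
# then permuted with node.meta[QCOM_AXIS_ORDER]. flat_pad_to_table is a
# hypothetical helper, not the builder's actual conversion code.
import numpy as np

def flat_pad_to_table(flat_pad, rank):
    table = np.zeros((rank, 2), dtype=np.uint32)
    for i, (before, after) in enumerate(zip(flat_pad[0::2], flat_pad[1::2])):
        table[rank - 1 - i] = (before, after)   # flat spec is last-dim first
    return table

# NCHW tensor padded 1 left / 2 right on W and 3 top / 4 bottom on H:
table = flat_pad_to_table([1, 2, 3, 4], rank=4)
assert table.tolist() == [[0, 0], [0, 0], [3, 4], [1, 2]]
axis_order = (0, 2, 3, 1)   # a typical NCHW -> NHWC order; the real value comes from node.meta[QCOM_AXIS_ORDER]
assert table[list(axis_order)].tolist() == [[0, 0], [3, 4], [1, 2], [0, 0]]
# ---------------------------------------------------------------------------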
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -25,14 +25,14 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: out_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) pow_output_tensors = [output_tensor_wrapper] @@ -41,7 +41,7 @@ def define_node( input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) - tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + tensor_type = PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE input_tensor_wrapper = self.define_tensor( input_node, @@ -58,13 +58,13 @@ def define_node( exp_node, node, exp_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, ) pow_input_tensors = [input_tensor_wrapper, exp_tensor_wrapper] - pow_op = PyQnnWrapper.PyQnnOpWrapper( + pow_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWisePower.op_name, diff --git a/backends/qualcomm/builders/op_prelu.py b/backends/qualcomm/builders/op_prelu.py index 5291acfbc8c..217fc77935d 100644 --- a/backends/qualcomm/builders/op_prelu.py +++ b/backends/qualcomm/builders/op_prelu.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER @@ -25,15 +25,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) prelu_inp_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -55,7 +55,7 @@ def define_node( coeff_node, node, coeff_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, ) prelu_input_tensors = [prelu_inp_tensor_wrapper, coeff_tensor_wrapper] @@ -65,12 +65,12 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) prelu_output_tensors = [output_tensor_wrapper] - prelu_op = PyQnnWrapper.PyQnnOpWrapper( + prelu_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpPRelu.op_name, diff --git a/backends/qualcomm/builders/op_quantize.py b/backends/qualcomm/builders/op_quantize.py index 7d7bd3ec9ec..b3125390e8e 100644 --- a/backends/qualcomm/builders/op_quantize.py +++ b/backends/qualcomm/builders/op_quantize.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch from executorch.backends.qualcomm.utils.constants import QCOM_ENCODING, QCOM_QUANT_ATTRS @@ -22,8 +22,8 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: quant_input_tensors = [] input_node = node.args[0] input_tensor = self.get_tensor(input_node, node) @@ -31,7 +31,7 @@ def define_node( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) quant_input_tensors.append(inp_tensor_wrapper) @@ -47,12 +47,12 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) quant_output_tensors = [output_tensor_wrapper] - quant_op = PyQnnWrapper.PyQnnOpWrapper( + quant_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpQuantize.op_name, diff --git a/backends/qualcomm/builders/op_relu.py b/backends/qualcomm/builders/op_relu.py index 94afce56113..140dc5f4df0 100644 --- a/backends/qualcomm/builders/op_relu.py +++ b/backends/qualcomm/builders/op_relu.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,15 +24,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) relu_inp_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) relu_input_tensors = [relu_inp_tensor_wrapper] @@ -42,12 +42,12 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) relu_output_tensors = [output_tensor_wrapper] - relu_op = PyQnnWrapper.PyQnnOpWrapper( + relu_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpRelu.op_name, diff --git a/backends/qualcomm/builders/op_repeat.py b/backends/qualcomm/builders/op_repeat.py index abd0cff73e8..0ebbc57bb5a 100644 --- a/backends/qualcomm/builders/op_repeat.py +++ b/backends/qualcomm/builders/op_repeat.py @@ -6,7 +6,7 @@ from typing import cast, Dict, List -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -26,15 +26,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -43,14 +43,14 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) multiples = cast(List[int], node.args[1]) multiples_shape = [len(multiples)] - tile_op = PyQnnWrapper.PyQnnOpWrapper( + tile_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpTile.op_name, @@ -59,7 +59,7 @@ def define_node( tile_op.AddOutputTensors([output_tensor_wrapper]) tile_op.AddTensorParam( OpTile.param_multiples, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, len(multiples_shape), multiples_shape, np.array(multiples, dtype=np.uint32), diff --git a/backends/qualcomm/builders/op_reshape.py b/backends/qualcomm/builders/op_reshape.py index 6cc7d81af33..a4f8e1eb554 100644 --- a/backends/qualcomm/builders/op_reshape.py +++ b/backends/qualcomm/builders/op_reshape.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
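# ---------------------------------------------------------------------------
# Editor's aside, not part of the patch, for the op_repeat.py hunk above:
# aten.repeat multiples map one-to-one onto QNN Tile's param_multiples; each
# entry is the number of copies laid out along that dimension. Values are
# made up.
import numpy as np
import torch

x = torch.arange(6).reshape(2, 3)
multiples = [2, 3]                       # node.args[1] in the builder above
assert list(x.repeat(*multiples).shape) == [4, 9]
assert np.tile(x.numpy(), multiples).shape == (4, 9)   # same tiling semantics
# ---------------------------------------------------------------------------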
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,15 +24,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -40,11 +40,11 @@ def define_node( node, node, node.meta["val"], - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - reshape_op = PyQnnWrapper.PyQnnOpWrapper( + reshape_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpReshape.op_name, diff --git a/backends/qualcomm/builders/op_resize.py b/backends/qualcomm/builders/op_resize.py index 04216ce9d2c..13a08ea98e0 100644 --- a/backends/qualcomm/builders/op_resize.py +++ b/backends/qualcomm/builders/op_resize.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import cast, Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -27,15 +27,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) align_corners = cast(bool, node.args[2]) @@ -49,10 +49,10 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - resize_op = PyQnnWrapper.PyQnnOpWrapper( + resize_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpResize.op_name, @@ -62,23 +62,23 @@ def define_node( resize_op.AddScalarParam( OpResize.param_exclude_outside, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, {QCOM_DATA: False}, ) resize_op.AddScalarParam( OpResize.param_transformation_mode, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {QCOM_DATA: transformation_mode}, ) resize_op.AddScalarParam( OpResize.param_interpolation_mode, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {QCOM_DATA: interpolation_mode}, ) resize_op.AddScalarParam( OpResize.param_cubic_coeff, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, {QCOM_DATA: cubic_coeff}, ) diff --git a/backends/qualcomm/builders/op_rms_norm.py 
b/backends/qualcomm/builders/op_rms_norm.py index 6d5060f730b..aca32071de9 100644 --- a/backends/qualcomm/builders/op_rms_norm.py +++ b/backends/qualcomm/builders/op_rms_norm.py @@ -7,7 +7,7 @@ import warnings from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -34,8 +34,8 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: # args of node : ['input', 'normalized_shape', 'weight', 'eps'] input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) @@ -43,7 +43,7 @@ def define_node( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -67,7 +67,7 @@ def define_node( weight_node, node, weight_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, ) @@ -89,21 +89,23 @@ def define_node( bias_node, node, bias_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, ) - epsilon = node.args[3] + epsilon = torch.finfo(torch.float32).eps + if len(node.args) > 3: + epsilon = node.args[3] output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - rms_nrom_op = PyQnnWrapper.PyQnnOpWrapper( + rms_nrom_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpRmsNorm.op_name, @@ -115,12 +117,12 @@ def define_node( rms_nrom_op.AddOutputTensors([output_tensor_wrapper]) rms_nrom_op.AddScalarParam( OpRmsNorm.param_epsilon, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, {QCOM_DATA: np.float32(epsilon)}, ) rms_nrom_op.AddTensorParam( OpRmsNorm.param_axes, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, len(axes_shape), axes_shape, np.array(axes, dtype=np.uint32), diff --git a/backends/qualcomm/builders/op_round.py b/backends/qualcomm/builders/op_round.py index 08aa83b5811..c458375588a 100644 --- a/backends/qualcomm/builders/op_round.py +++ b/backends/qualcomm/builders/op_round.py @@ -1,7 +1,7 @@ import warnings from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch from .node_visitor import NodeVisitor @@ -20,15 +20,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + 
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -44,11 +44,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - round_op = PyQnnWrapper.PyQnnOpWrapper( + round_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseRound.op_name, diff --git a/backends/qualcomm/builders/op_rsqrt.py b/backends/qualcomm/builders/op_rsqrt.py index 0f0a069441d..28cd191c8d5 100644 --- a/backends/qualcomm/builders/op_rsqrt.py +++ b/backends/qualcomm/builders/op_rsqrt.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,15 +24,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) rsqrt_inp_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) rsqrt_input_tensors = [rsqrt_inp_tensor_wrapper] @@ -42,12 +42,12 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) rsqrt_output_tensors = [output_tensor_wrapper] - rsqrt_op = PyQnnWrapper.PyQnnOpWrapper( + rsqrt_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseRsqrt.op_name, diff --git a/backends/qualcomm/builders/op_scalar_tensor.py b/backends/qualcomm/builders/op_scalar_tensor.py index bb6b5825803..69daf23f38f 100644 --- a/backends/qualcomm/builders/op_scalar_tensor.py +++ b/backends/qualcomm/builders/op_scalar_tensor.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -23,8 +23,8 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: val = node.args[0] out_tensor = torch.tensor([val], dtype=node.meta["val"].dtype) @@ -46,6 +46,6 @@ def define_node( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, ) diff --git a/backends/qualcomm/builders/op_select_copy.py b/backends/qualcomm/builders/op_select_copy.py index 69d237c282d..f004fed8a67 100644 --- a/backends/qualcomm/builders/op_select_copy.py +++ b/backends/qualcomm/builders/op_select_copy.py @@ -6,7 +6,7 @@ import math from typing import cast, Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -27,15 +27,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -44,7 +44,7 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -63,7 +63,7 @@ def define_node( range_shape = [input_tensor_rank, 3] - stride_slice_op = PyQnnWrapper.PyQnnOpWrapper( + stride_slice_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpStridedSlice.op_name, @@ -73,7 +73,7 @@ def define_node( stride_slice_op.AddTensorParam( OpStridedSlice.param_ranges, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_INT_32, len(range_shape), range_shape, np.array(ranges, dtype=np.int32), @@ -82,7 +82,7 @@ def define_node( stride_slice_op.AddScalarParam( OpStridedSlice.param_shrink_axes, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {QCOM_DATA: np.uint32(math.pow(2, dim))}, ) diff --git a/backends/qualcomm/builders/op_sigmoid.py b/backends/qualcomm/builders/op_sigmoid.py index 20f933ed128..f374ebbc996 100644 --- a/backends/qualcomm/builders/op_sigmoid.py +++ b/backends/qualcomm/builders/op_sigmoid.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,15 +24,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) sigmoid_inp_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) sigmoid_input_tensors = [sigmoid_inp_tensor_wrapper] @@ -42,12 +42,12 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) sigmoid_output_tensors = [output_tensor_wrapper] - sigmoid_op = PyQnnWrapper.PyQnnOpWrapper( + sigmoid_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpSigmoid.op_name, diff --git a/backends/qualcomm/builders/op_sign.py b/backends/qualcomm/builders/op_sign.py index faf2f2e0066..2a77a3e9d82 100644 --- a/backends/qualcomm/builders/op_sign.py +++ b/backends/qualcomm/builders/op_sign.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,15 +24,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -41,11 +41,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - sign_op = PyQnnWrapper.PyQnnOpWrapper( + sign_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseSign.op_name, diff --git a/backends/qualcomm/builders/op_sin.py b/backends/qualcomm/builders/op_sin.py index 5c389ca3b20..5099cabd23c 100644 --- a/backends/qualcomm/builders/op_sin.py +++ b/backends/qualcomm/builders/op_sin.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,15 +24,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -41,11 +41,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - sin_op = PyQnnWrapper.PyQnnOpWrapper( + sin_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseSin.op_name, diff --git a/backends/qualcomm/builders/op_skip_ops.py b/backends/qualcomm/builders/op_skip_ops.py index f52f69d6019..2a60b320927 100644 --- a/backends/qualcomm/builders/op_skip_ops.py +++ b/backends/qualcomm/builders/op_skip_ops.py @@ -6,7 +6,7 @@ from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -25,7 +25,7 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], ) -> None: return @@ -41,14 +41,14 @@ class OpGetItem(OpSkipOps): def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], ) -> None: if isinstance(node.args[1], tuple) or isinstance(node.args[1], list): raise AssertionError( f"Invalid number of index for {node.name }: {len(node.args[1])}" ) idx = node.args[1] - # to fit the format of nodes_to_wrappers, Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]], + # to fit the format of nodes_to_wrappers, Dict[str, Dict[int, PyQnnManager.TensorWrapper]], nodes_to_wrappers[node.name] = { 0: nodes_to_wrappers.get(node.args[0].name).get(idx) } diff --git a/backends/qualcomm/builders/op_slice_copy.py b/backends/qualcomm/builders/op_slice_copy.py index 5923d438252..8dcf7365361 100644 --- a/backends/qualcomm/builders/op_slice_copy.py +++ b/backends/qualcomm/builders/op_slice_copy.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import cast, Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER @@ -25,11 +25,11 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) - tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + tensor_type = PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE input_tensor_wrapper = self.define_tensor( input_node, @@ -44,7 +44,7 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) dim = cast(int, node.args[1]) @@ -75,7 +75,7 @@ def define_node( range_shape = [input_tensor_rank, 3] - stride_slice_op = PyQnnWrapper.PyQnnOpWrapper( + stride_slice_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpStridedSlice.op_name, @@ -85,7 +85,7 @@ def define_node( stride_slice_op.AddTensorParam( OpStridedSlice.param_ranges, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_INT_32, len(range_shape), range_shape, np.array(ranges, dtype=np.int32), diff --git a/backends/qualcomm/builders/op_slice_scatter.py b/backends/qualcomm/builders/op_slice_scatter.py index 9fa162d6653..da9e810d064 100644 --- a/backends/qualcomm/builders/op_slice_scatter.py +++ b/backends/qualcomm/builders/op_slice_scatter.py @@ -1,6 +1,6 @@ from typing import cast, Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch from executorch.exir.dialects._ops import ops as exir_ops @@ -20,15 +20,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -38,7 +38,7 @@ def define_node( value_node, node, value_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -47,7 +47,7 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) dim = cast(int, node.args[2]) if len(node.args) > 2 else 0 @@ -101,11 +101,11 @@ def define_node( target_index_node, node, target_index_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, ) - slice_scatter_op = PyQnnWrapper.PyQnnOpWrapper( + slice_scatter_op = PyQnnManager.PyQnnOpWrapper( node.name, 
QNN_OP_PACKAGE_NAME_QTI_AISW, OpScatterNd.op_name, diff --git a/backends/qualcomm/builders/op_softmax.py b/backends/qualcomm/builders/op_softmax.py index 556f5701f54..356c240fc72 100644 --- a/backends/qualcomm/builders/op_softmax.py +++ b/backends/qualcomm/builders/op_softmax.py @@ -6,7 +6,7 @@ import warnings from typing import cast, Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA @@ -26,15 +26,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) softmax_inp_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) softmax_input_tensors = [softmax_inp_tensor_wrapper] @@ -44,7 +44,7 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) softmax_output_tensors = [output_tensor_wrapper] @@ -62,7 +62,7 @@ def define_node( ) return None - softmax_op = PyQnnWrapper.PyQnnOpWrapper( + softmax_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpSoftmax.op_name, @@ -72,7 +72,7 @@ def define_node( softmax_op.AddScalarParam( OpSoftmax.param_axis, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {QCOM_DATA: np.uint32(dim)}, ) diff --git a/backends/qualcomm/builders/op_space_to_depth.py b/backends/qualcomm/builders/op_space_to_depth.py index 74e31df475f..b16fa85bbac 100644 --- a/backends/qualcomm/builders/op_space_to_depth.py +++ b/backends/qualcomm/builders/op_space_to_depth.py @@ -6,7 +6,7 @@ from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -27,15 +27,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -44,7 +44,7 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -54,7 +54,7 @@ def define_node( block_size = np.array(block_size, dtype=np.uint32) block_size_shape = [2] - space_to_depth_op = PyQnnWrapper.PyQnnOpWrapper( + space_to_depth_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpSpaceToDepth.op_name, @@ -63,7 +63,7 @@ def define_node( 
space_to_depth_op.AddOutputTensors([output_tensor_wrapper]) space_to_depth_op.AddTensorParam( OpSpaceToDepth.param_block_size, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, len(block_size.shape), block_size_shape, block_size, @@ -71,7 +71,7 @@ def define_node( ) space_to_depth_op.AddScalarParam( OpSpaceToDepth.param_mode, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {QCOM_DATA: np.uint32(OpSpaceToDepth.Mode.CRD)}, ) diff --git a/backends/qualcomm/builders/op_split_with_sizes.py b/backends/qualcomm/builders/op_split_with_sizes.py index fc5ba0f11fb..bbac10862c0 100644 --- a/backends/qualcomm/builders/op_split_with_sizes.py +++ b/backends/qualcomm/builders/op_split_with_sizes.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import cast, Dict, List -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -26,8 +26,8 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) @@ -36,7 +36,7 @@ def define_node( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) input_tensor_wrappers = [input_tensor_wrapper] @@ -49,7 +49,7 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, wrapper_idx=index, ) @@ -75,7 +75,7 @@ def define_node( if QCOM_AXIS_ORDER in node.meta: dim = node.meta[QCOM_AXIS_ORDER].index(dim) - split_op = PyQnnWrapper.PyQnnOpWrapper( + split_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpSplit.op_name, @@ -84,7 +84,7 @@ def define_node( split_op.AddOutputTensors(output_tensor_wrappers) split_op.AddTensorParam( OpSplit.param_split_index, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, len(split_indices_shape), split_indices_shape, np.array(split_indices, dtype=np.uint32), @@ -93,7 +93,7 @@ def define_node( split_op.AddScalarParam( OpSplit.param_axis, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {QCOM_DATA: np.uint32(dim)}, ) return split_op diff --git a/backends/qualcomm/builders/op_sqrt.py b/backends/qualcomm/builders/op_sqrt.py index b71d4d68c30..d4bfe724853 100644 --- a/backends/qualcomm/builders/op_sqrt.py +++ b/backends/qualcomm/builders/op_sqrt.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,8 +24,8 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: # tensor input input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) @@ -34,7 +34,7 @@ def define_node( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) sqrt_input_tensors = [input_tensor_wrapper] @@ -44,12 +44,12 @@ def define_node( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) sqrt_output_tensors = [output_tensor_wrapper] - sqrt_op = PyQnnWrapper.PyQnnOpWrapper( + sqrt_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseSquareRoot.op_name, diff --git a/backends/qualcomm/builders/op_squeeze.py b/backends/qualcomm/builders/op_squeeze.py index 0cb7bf142b9..24b2d194dd7 100644 --- a/backends/qualcomm/builders/op_squeeze.py +++ b/backends/qualcomm/builders/op_squeeze.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,8 +24,8 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) @@ -33,7 +33,7 @@ def define_node( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -42,11 +42,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - squeeze_op = PyQnnWrapper.PyQnnOpWrapper( + squeeze_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpReshape.op_name, diff --git a/backends/qualcomm/builders/op_stack.py b/backends/qualcomm/builders/op_stack.py index 2d8587d51cd..6b27d7f5e4c 100644 --- a/backends/qualcomm/builders/op_stack.py +++ b/backends/qualcomm/builders/op_stack.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import cast, Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -26,8 +26,8 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node_list = node.args[0] stack_input_tensors = [] for input_node in input_node_list: @@ -36,7 +36,7 @@ def define_node( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) stack_input_tensors.append(stack_inp_tensor_wrapper) @@ -45,7 +45,7 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) stack_output_tensors = [output_tensor_wrapper] @@ -55,7 +55,7 @@ def define_node( dim = dim % len(output_tensor.shape) if QCOM_AXIS_ORDER in node.meta: dim = node.meta[QCOM_AXIS_ORDER].index(dim) - stack_op = PyQnnWrapper.PyQnnOpWrapper( + stack_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpPack.op_name, @@ -65,7 +65,7 @@ def define_node( stack_op.AddScalarParam( OpPack.param_axis, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {QCOM_DATA: np.uint32(dim)}, ) diff --git a/backends/qualcomm/builders/op_sub.py b/backends/qualcomm/builders/op_sub.py index 064d9b3cd42..9d70be6cc59 100644 --- a/backends/qualcomm/builders/op_sub.py +++ b/backends/qualcomm/builders/op_sub.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,14 +24,14 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: out_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) sub_output_tensors = [output_tensor_wrapper] @@ -40,7 +40,7 @@ def define_node( for index in range(2): input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) - tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + tensor_type = PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE input_tensor_wrapper = self.define_tensor( input_node, @@ -51,7 +51,7 @@ def define_node( ) sub_input_tensors.append(input_tensor_wrapper) - sub_op = PyQnnWrapper.PyQnnOpWrapper( + sub_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseSubtract.op_name, diff --git a/backends/qualcomm/builders/op_sum_int_list.py b/backends/qualcomm/builders/op_sum_int_list.py index af5fd1cecba..8b686261b97 100644 --- a/backends/qualcomm/builders/op_sum_int_list.py +++ b/backends/qualcomm/builders/op_sum_int_list.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import cast, Dict, List -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -26,8 +26,8 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) @@ -35,7 +35,7 @@ def define_node( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) sum_input_tensors = [input_tensor_wrapper] @@ -54,11 +54,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) sum_output_tensors = [output_tensor_wrapper] - sum_op = PyQnnWrapper.PyQnnOpWrapper( + sum_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpReduceSum.op_name, @@ -67,7 +67,7 @@ def define_node( sum_op.AddOutputTensors(sum_output_tensors) sum_op.AddTensorParam( OpReduceSum.param_axes, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, len(sum_dims_shape), sum_dims_shape, np.array(sum_dims, dtype=np.uint32), @@ -78,7 +78,7 @@ def define_node( keep_dims = cast(bool, node.args[2]) sum_op.AddScalarParam( OpReduceSum.param_keep_dims, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, {QCOM_DATA: keep_dims}, ) return sum_op diff --git 
a/backends/qualcomm/builders/op_tanh.py b/backends/qualcomm/builders/op_tanh.py index c61439398e3..fc95aec8335 100644 --- a/backends/qualcomm/builders/op_tanh.py +++ b/backends/qualcomm/builders/op_tanh.py @@ -6,7 +6,7 @@ from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -25,15 +25,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -42,11 +42,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - tanh_op = PyQnnWrapper.PyQnnOpWrapper( + tanh_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpTanh.op_name, diff --git a/backends/qualcomm/builders/op_to.py b/backends/qualcomm/builders/op_to.py index 6774b0e3af6..9b024307934 100644 --- a/backends/qualcomm/builders/op_to.py +++ b/backends/qualcomm/builders/op_to.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS @@ -22,12 +22,12 @@ class To(NodeVisitor): sufixed_16_offset_diff = 32768 epsilon = 1e-6 sufixed_8 = { - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_8, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_8, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8, } sufixed_16 = { - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_16, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_16, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16, } def __init__(self, *args) -> None: @@ -79,8 +79,8 @@ def is_cast_node(self, node): def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) @@ -88,7 +88,7 @@ def define_node( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) node_input_tensors = [input_tensor_wrapper] @@ -110,7 +110,7 @@ def define_node( ) cast_intermediate_tensor_wrapper = self.define_custom_tensor_wrapper( node_name=node.name + "_cast", - tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + tensor_type=PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, dtype=QNN_TENSOR_TYPE_MAP[torch.int32], quant_encoding=input_quant_encoding, 
quant_configs=input_quant_configs, @@ -119,7 +119,7 @@ def define_node( is_fake_tensor=True, nodes_to_wrappers=nodes_to_wrappers, ) - cast_op = PyQnnWrapper.PyQnnOpWrapper( + cast_op = PyQnnManager.PyQnnOpWrapper( f"{node.name}_cast", QNN_OP_PACKAGE_NAME_QTI_AISW, OpCast.op_name, @@ -134,12 +134,12 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) qnn_op = OpCast if self.is_cast_node(node) else OpConvert - op = PyQnnWrapper.PyQnnOpWrapper( + op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, qnn_op.op_name ) op.AddInputTensors(node_input_tensors) diff --git a/backends/qualcomm/builders/op_topk.py b/backends/qualcomm/builders/op_topk.py index f310752c8f6..d027702ddf0 100644 --- a/backends/qualcomm/builders/op_topk.py +++ b/backends/qualcomm/builders/op_topk.py @@ -6,7 +6,7 @@ import warnings from typing import cast, Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -31,8 +31,8 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) @@ -40,7 +40,7 @@ def define_node( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, ) @@ -70,7 +70,7 @@ def define_node( node, node, output_val_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -80,13 +80,13 @@ def define_node( node, node, output_idx_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, wrapper_idx=1, ) topk_output_tensors = [output_val_tensor_wrapper, output_index_tensor_wrapper] - topk_op = PyQnnWrapper.PyQnnOpWrapper( + topk_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpTopK.op_name, @@ -96,7 +96,7 @@ def define_node( topk_op.AddScalarParam( OpTopK.param_k, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {QCOM_DATA: np.uint32(k)}, ) @@ -105,7 +105,7 @@ def define_node( largest = cast(bool, node.args[3]) topk_op.AddScalarParam( OpTopK.param_largest, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, {QCOM_DATA: largest}, ) diff --git a/backends/qualcomm/builders/op_transpose.py b/backends/qualcomm/builders/op_transpose.py index dbed10ced46..e3682c8bcde 100644 --- a/backends/qualcomm/builders/op_transpose.py +++ b/backends/qualcomm/builders/op_transpose.py @@ -6,7 +6,7 @@ from typing import cast, Dict, List -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -27,8 +27,8 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + 
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) permute_node = input_node if QCOM_INSERTED_PERMUTE in node.meta else node input_tensor = self.get_tensor(input_node, permute_node) @@ -36,12 +36,14 @@ def define_node( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) # permutation permute_order = cast(List[int], node.args[1]) + # to prevent negative values + permute_order = [x % len(permute_order) for x in permute_order] permute_order_shape = [len(permute_order)] output_tensor = input_tensor.permute(permute_order) @@ -49,11 +51,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - transpose_op = PyQnnWrapper.PyQnnOpWrapper( + transpose_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpTranspose.op_name, @@ -65,7 +67,7 @@ def define_node( transpose_op.AddTensorParam( OpTranspose.param_perm, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, len(permute_order_shape), permute_order_shape, np.array(permute_order, dtype=np.uint32), diff --git a/backends/qualcomm/builders/op_unbind.py b/backends/qualcomm/builders/op_unbind.py index 7db8bf07596..66d05fcc636 100644 --- a/backends/qualcomm/builders/op_unbind.py +++ b/backends/qualcomm/builders/op_unbind.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import cast, Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import numpy as np import torch @@ -26,15 +26,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, ) unbind_input_tensors = [input_tensor_wrapper] @@ -46,7 +46,7 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, wrapper_idx=i, ) @@ -57,7 +57,7 @@ def define_node( dim = dim % len(input_tensor.shape) if QCOM_AXIS_ORDER in node.meta: dim = node.meta[QCOM_AXIS_ORDER].index(dim) - unbind_op = PyQnnWrapper.PyQnnOpWrapper( + unbind_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpUnpack.op_name, @@ -67,7 +67,7 @@ def define_node( unbind_op.AddScalarParam( OpUnpack.param_axis, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {QCOM_DATA: np.uint32(dim)}, ) diff --git a/backends/qualcomm/builders/op_unsqueeze.py b/backends/qualcomm/builders/op_unsqueeze.py index 3408f3ec14f..46b7a904137 100644 --- a/backends/qualcomm/builders/op_unsqueeze.py +++ b/backends/qualcomm/builders/op_unsqueeze.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of 
this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,8 +24,8 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) @@ -33,7 +33,7 @@ def define_node( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -42,11 +42,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - unsqueeze_op = PyQnnWrapper.PyQnnOpWrapper( + unsqueeze_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpReshape.op_name, diff --git a/backends/qualcomm/builders/op_upsample_bilinear2d.py b/backends/qualcomm/builders/op_upsample_bilinear2d.py index 7394823899e..fb9b4b078fd 100644 --- a/backends/qualcomm/builders/op_upsample_bilinear2d.py +++ b/backends/qualcomm/builders/op_upsample_bilinear2d.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch from executorch.backends.qualcomm.utils.constants import QCOM_DATA @@ -25,15 +25,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -42,11 +42,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - resize_bilinear_op = PyQnnWrapper.PyQnnOpWrapper( + resize_bilinear_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpResizeBilinear.op_name, @@ -56,12 +56,12 @@ def define_node( resize_bilinear_op.AddScalarParam( OpResizeBilinear.param_align_corners, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, {QCOM_DATA: node.args[2]}, ) resize_bilinear_op.AddScalarParam( OpResizeBilinear.param_half_pixel_centers, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, {QCOM_DATA: not node.args[2]}, ) diff --git a/backends/qualcomm/builders/op_upsample_nearest2d.py b/backends/qualcomm/builders/op_upsample_nearest2d.py index a338f54b91f..8d85ba3576a 100644 --- a/backends/qualcomm/builders/op_upsample_nearest2d.py +++ b/backends/qualcomm/builders/op_upsample_nearest2d.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source 
tree. from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch from executorch.backends.qualcomm.utils.constants import QCOM_DATA @@ -25,15 +25,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -42,11 +42,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - resize_nearest_op = PyQnnWrapper.PyQnnOpWrapper( + resize_nearest_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpResizeNearestNeighbor.op_name, @@ -56,12 +56,12 @@ def define_node( # align_corners is guaranteed to be false resize_nearest_op.AddScalarParam( OpResizeNearestNeighbor.param_align_corners, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, {QCOM_DATA: False}, ) resize_nearest_op.AddScalarParam( OpResizeNearestNeighbor.param_half_pixel_centers, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, {QCOM_DATA: True}, ) diff --git a/backends/qualcomm/builders/op_where.py b/backends/qualcomm/builders/op_where.py index 460431a4814..6a9c750ee5e 100644 --- a/backends/qualcomm/builders/op_where.py +++ b/backends/qualcomm/builders/op_where.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,15 +24,15 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: conditional_input_node = self.get_node(node.args[0]) conditional_input_tensor = self.get_tensor(conditional_input_node, node) conditional_input_tensor_wrapper = self.define_tensor( conditional_input_node, node, conditional_input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -42,7 +42,7 @@ def define_node( true_input_node, node, true_input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -52,7 +52,7 @@ def define_node( false_input_node, node, false_input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) @@ -61,11 +61,11 @@ def define_node( node, node, output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - where_op = PyQnnWrapper.PyQnnOpWrapper( + where_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseSelect.op_name, diff --git a/backends/qualcomm/builders/op_xor.py b/backends/qualcomm/builders/op_xor.py index d4462d9c707..eff5d1c009b 100644 --- a/backends/qualcomm/builders/op_xor.py +++ b/backends/qualcomm/builders/op_xor.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import Dict -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,14 +24,14 @@ def __init__(self, *args) -> None: def define_node( self, node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], + ) -> PyQnnManager.PyQnnOpWrapper: out_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( node, node, out_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) xor_output_tensors = [output_tensor_wrapper] @@ -40,7 +40,7 @@ def define_node( for index in range(2): input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) - tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE + tensor_type = PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE input_tensor_wrapper = self.define_tensor( input_node, @@ -50,7 +50,7 @@ def define_node( nodes_to_wrappers, ) xor_input_tensors.append(input_tensor_wrapper) - xor_op = PyQnnWrapper.PyQnnOpWrapper( + xor_op = PyQnnManager.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpElementWiseXor.op_name, diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py index b0c44dcae80..ecc221885dc 100644 --- a/backends/qualcomm/builders/qnn_constants.py +++ b/backends/qualcomm/builders/qnn_constants.py @@ -59,6 +59,15 @@ class OpConv2d: param_dilation: str = "dilation" +@dataclass(init=False, frozen=True) +class OpConv3d: + op_name: str = "Conv3d" + param_stride: str = "stride" + param_pad_amount: str = "pad_amount" + param_group: str = "group" + param_dilation: str = "dilation" + + @dataclass(init=False, frozen=True) class OpConvert: op_name: str = "Convert" @@ -295,6 +304,24 @@ class OpGather: param_axis: str = "axis" +class OpGridSample: + op_name: str = "GridSample" + param_align_corners: str = "align_corners" + param_mode: str = "mode" + param_padding_mode: str = "padding_mode" + + @unique + class Mode(IntEnum): + BILINAR = 0 + NEAREST = 1 + + @unique + class PaddingMode(IntEnum): + ZEROS = 0 + BORDER = 1 + REFLECTION = 2 + + @dataclass(init=False, frozen=True) class OpGatherElements: op_name: str = "GatherElements" @@ -389,6 +416,21 @@ class RoundingMode(IntEnum): CEIL = 1 +@dataclass(init=False, frozen=True) +class OpPoolAvg3d: + op_name: str = "PoolAvg3d" + param_filter_size: str = "filter_size" + param_stride: str = "stride" + param_pad_amount: str = "pad_amount" + param_count_pad_for_edges: str = "count_pad_for_edges" + param_rounding_mode: str = "rounding_mode" + + @unique + class RoundingMode(IntEnum): + FLOOR = 0 + CEIL = 1 + + @dataclass(init=False, frozen=True) class OpPoolMax2d: op_name: str = "PoolMax2d" @@ -573,6 +615,15 @@ class OpTransposeConv2d: param_output_padding: str = "output_padding" +@dataclass(init=False, frozen=True) +class OpTransposeConv3d: + op_name: str = "TransposeConv3d" + param_stride: str = "stride" + param_pad_amount: str = "pad_amount" + param_group: str = "group" + param_output_padding: str = "output_padding" + + @dataclass(init=False, frozen=True) class OpUnpack: op_name: str = "UnPack" diff --git a/backends/qualcomm/builders/targets.bzl b/backends/qualcomm/builders/targets.bzl index 39159e56cd8..fc7d88781cc 100644 --- 
a/backends/qualcomm/builders/targets.bzl +++ b/backends/qualcomm/builders/targets.bzl @@ -17,7 +17,6 @@ def define_common_targets(): deps = [ "//executorch/exir/backend:backend_details", "//executorch/exir/backend:compile_spec_schema", - "//executorch/backends/qualcomm/aot/python:PyQnnWrapperAdaptor", "//executorch/backends/qualcomm/aot/python:PyQnnManagerAdaptor", "//executorch/backends/qualcomm/utils:utils", "//executorch/exir:lib", diff --git a/backends/qualcomm/debugger/README.md b/backends/qualcomm/debugger/README.md index 60ecb3d71b3..1c91382131f 100644 --- a/backends/qualcomm/debugger/README.md +++ b/backends/qualcomm/debugger/README.md @@ -18,7 +18,7 @@ python -m examples.qualcomm.util_scripts.qairt_visualizer_demo -H ${host} -s {de - If online prepare mode is `enabled`, the following artifacts will be generated: - `model`.dlc - `optrace`.json - - `QHAS` + - `QHAS`.json - If online prepare mode is `disabled`, the following artifacts will be generated: - `model`.bin - `optrace`.json @@ -54,6 +54,7 @@ adb = SimpleADB( device_id=args.device, host_id=args.host, soc_model=args.model, + target=args.target, ) binaries_trace = generate_optrace( args, adb, f"{args.artifact}/{pte_filename}.pte", example_input @@ -91,3 +92,160 @@ Note: Files ending with `.bin ` do not support graph visualization in qairt_visu For more details, visit the [QAIRT Visualizer](https://pypi.org/project/qairt-visualizer/). + + +# ExecuTorch QNN Intermediate Output Debugger + +ExecuTorch QNN Intermediate Output Debugger is a tool that helps users debug intermediate output accuracy by comparing CPU outputs with QNN outputs. This tool offers a variety of output formats and flexibility for users to define their own metrics when debugging. + +Below, we will go through the details step by step on how to customize your own debugger. By the end of this tutorial, users should understand the mechanism behind the ExecuTorch QNN Debugger and how to apply the debugger to the desired model. In the rest of the tutorial, we will use the term `intermediate output` and `per-layer dump` interchangeably. + +To make the implementation process smooth, we have also provided an example script, [qnn_intermediate_debugger_demo.py](../../../examples/qualcomm/util_scripts/qnn_intermediate_debugger_demo.py), which is an end-to-end example that goes through the steps for implementation. Refer to [Example Script](#example-script) section for more information. + +## Introduction + +1. Why do we need ExecuTorch QNN Intermediate Output Debugger? + During inference, there might be gaps between QNN and CPU final outputs. This leaves developers unsure about the root cause of accuracy drop. By using this debugger, users can gain better insight into which operation is causing the accuracy drop. Please note that the accuracy drop here refers to comparing QNN with CPU outputs, not the ground truth. + +2. Who is this tool for? + This tool is mainly for developers aiming to align QNN with CPU accuracy. Users will be able to identify which layer in the model is causing the accuracy drop, helping them either circumvent the issue by replacing the layer with other operations or contact authors in Qualcomm AI Engine Direct to resolve the accuracy issue. Please refer to the last section under [README.md](../README.md) for authors to contact when encountering any issues. 
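+
+As a rough illustration of the comparison this tool automates (this snippet is not part of the debugger's API; the tensors below are hypothetical stand-ins for one layer's outputs), the per-layer check boils down to scoring how close a QNN output is to its CPU reference:
+```python
+import torch
+
+# Hypothetical per-layer outputs: one from the CPU reference, one from QNN (already dequantized).
+cpu_output = torch.randn(1, 64, 56, 56)
+qnn_output = cpu_output + 0.01 * torch.randn_like(cpu_output)
+
+# A simple similarity score; the debugger's metric evaluators generalize this idea
+# and flag layers whose score falls below a chosen threshold.
+score = torch.nn.functional.cosine_similarity(
+    qnn_output.flatten(), cpu_output.flatten(), dim=0
+)
+print(f"cosine similarity: {score.item():.3f}")
+```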
+
+
+## Design Flow
+```mermaid
+flowchart TB;
+    nn.Module;
+    nn.Module --> edge_program["Edge Program"];
+    edge_program --> qnn_lower["QNN with Per-Layer Dump"];
+    qnn_lower --> qnn_inference[QNN Inference];
+    qnn_inference --> debug
+    edge_program --> cpu_lower["Edge CPU with Per-Layer Dump"];
+    cpu_lower --> cpu_inference["CPU Inference"];
+    cpu_inference --> debug["Debug"];
+    debug --> output["Output Results"]
+```
+
+## Instructions
+
+### 1. Setup
+1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch.
+2. Follow the [tutorial](https://pytorch.org/executorch/stable/build-run-qualcomm-ai-engine-direct-backend.html) to build Qualcomm AI Engine Direct Backend.
+
+### 2. Enable Flag
+
+When executing the script, please add the flag `--dump_intermediate_outputs`. This tells QNN to dump all intermediate tensors during execution.
+
+### 3. Add debugger to the example script
+Initialize a `QNNIntermediateDebugger`, then pass the initialized `QNNIntermediateDebugger` and `args.dump_intermediate_outputs` to the `build_executorch_binary` method as well.
+#### Example:
+```python
+from executorch.examples.qualcomm.utils import build_executorch_binary
+from executorch.backends.qualcomm.debugger.qnn_intermediate_debugger import QNNIntermediateDebugger
+
+qnn_intermediate_debugger = QNNIntermediateDebugger()
+build_executorch_binary(
+    model=MyModel(),
+    inputs=(torch.randn(200, 768),),
+    soc_model="SM8650",
+    file_name="my_model",
+    dataset=my_dataset,
+    dump_intermediate_outputs=args.dump_intermediate_outputs, # Add this flag
+    qnn_intermediate_debugger=qnn_intermediate_debugger, # Add this flag
+)
+```
+
+### 4. Set data num to 1
+It is perfectly fine for users to pass as many data samples as desired to `build_executorch_binary`, which helps achieve better quantization results. However, after `build_executorch_binary` is called, we need to ensure that we only perform one inference during execution. Please also ensure that the CPU and QNN use the same input during execution; otherwise, the debugging results might not be accurate.
+
+### 5. Pass flag to SimpleADB
+When creating `SimpleADB`, please also pass the flag `args.dump_intermediate_outputs`. This tells the runner to create files that store the intermediate output schema and binary data.
+#### Example:
+```python
+adb = SimpleADB(
+    qnn_sdk=os.getenv("QNN_SDK_ROOT"),
+    build_path=f"{args.build_folder}",
+    pte_path=f"{args.artifact}/{pte_filename}.pte",
+    workspace=f"/data/local/tmp/executorch/{pte_filename}",
+    device_id=args.device,
+    host_id=args.host,
+    soc_model=args.model,
+    shared_buffer=args.shared_buffer,
+    dump_intermediate_outputs=args.dump_intermediate_outputs, # Add this flag
+)
+```
+
+### 6. Pull and process the results
+After QNN execution with the runner, if the previous steps are done correctly, we should be able to get two files: `etdump.etdp` and `debug_output.bin`.
+The following example pulls the files back and calls a callback function to process the results. In this callback function, we create the `Inspector` and then perform CPU inference to get the CPU intermediate results. Now that we have both QNN and CPU intermediate results, we can generate results that compare the accuracy. With the example below, we should get `debug_graph.svg` as an output in the current directory.
+#### Example:
+```python
+from executorch.backends.qualcomm.debugger.metrics_evaluator import CosineSimilarityEvaluator
+from executorch.backends.qualcomm.debugger.qnn_intermediate_debugger import OutputFormat
+from executorch.devtools import Inspector
+
+def validate_intermediate_tensor():
+    inspector = Inspector(
+        etdump_path=f"{args.artifact}/etdump.etdp",
+        debug_buffer_path=f"{args.artifact}/debug_output.bin",
+    )
+    qnn_intermediate_debugger.intermediate_output_module(*(inputs[0]))
+    qnn_intermediate_debugger.generate_results(
+        title="debug_graph",
+        path=".",
+        output_format=OutputFormat.SVG_GRAPHS,
+        inspector=inspector,
+        evaluator=CosineSimilarityEvaluator(0.9),
+    )
+
+adb.pull_debug_output(
+    args.artifact, args.artifact, callback=validate_intermediate_tensor
+)
+```
+
+#### Additional Options
+The above example sets the output format to SVG and uses cosine similarity as the evaluation metric. Based on different needs, users can choose other output formats, as shown in the `OutputFormat` class under [qnn_intermediate_debugger](./qnn_intermediate_debugger.py):
+```python
+class OutputFormat(IntEnum):
+    SVG_GRAPHS = 0
+    CSV_FILES = 1
+    DUMP_RAW = 2
+```
+
+For evaluation metrics, users who would like to implement their own metrics can subclass [MetricEvaluatorBase](./metrics_evaluator.py). The following shows how to define a custom metric.
+```python
+from typing import Any, Tuple
+
+import torch
+import torch.nn.functional as F
+
+from executorch.backends.qualcomm.debugger.metrics_evaluator import MetricEvaluatorBase
+
+class RootMeanSquaredErrorEvaluator(MetricEvaluatorBase):
+    def __init__(self, threshold=0.02):
+        self.threshold = threshold
+
+    def metric_name(self) -> str:
+        return "Root Mean Squared Error"
+
+    def evaluate(
+        self, qnn_output: torch.Tensor, cpu_output: torch.Tensor
+    ) -> Tuple[Any, bool]:
+        mse = F.mse_loss(qnn_output, cpu_output)
+        rmse = torch.sqrt(mse)
+        valid = rmse < self.threshold
+        return rmse, valid
+
+qnn_intermediate_debugger.generate_results(
+    title="my_metric",
+    path=".",
+    output_format=OutputFormat.SVG_GRAPHS,
+    inspector=inspector,
+    evaluator=RootMeanSquaredErrorEvaluator(),
+)
+```
+
+### Example Script
+We have provided an inception_v3 demo script to help users better understand how to apply the debugger to their scripts. Please refer to [qnn_intermediate_debugger_demo.py](../../../examples/qualcomm/util_scripts/qnn_intermediate_debugger_demo.py) for the example script.
+
+Before running the example script, please ensure that the dataset is downloaded. An example dataset can be retrieved [here](https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000).
+
+To execute the model:
+```bash
+python examples/qualcomm/util_scripts/qnn_intermediate_debugger_demo.py -b build-android -m ${SOC_MODEL} --device ${SERIAL_NUM} --dataset ${PATH_TO_DATASET} --dump_intermediate_outputs
+```
+
+### Limitations
+1. The current debugger only supports performing one execution. Multiple executions may cause unknown behavior and are not recommended.
+2. Please ignore this if you are using `qnn_executor_runner`. If you have decided to write your own runner, please follow the [tutorial](https://pytorch.org/executorch/stable/etdump.html) on how to integrate ETDump into your own runner.
+3. The current debugger does not support graphs with partitions. (WIP)
+4. The current debugger does not support LLM models.
(WIP) diff --git a/backends/qualcomm/debugger/TARGETS b/backends/qualcomm/debugger/TARGETS index 85f204f9718..6a7732231fc 100644 --- a/backends/qualcomm/debugger/TARGETS +++ b/backends/qualcomm/debugger/TARGETS @@ -10,3 +10,21 @@ runtime.python_library( "fbsource//third-party/pypi/pandas:pandas", ] ) + +runtime.python_library( + name = "qnn_intermediate_debugger", + srcs = [ + "format_outputs.py", + "metrics_evaluator.py", + "qnn_intermediate_debugger.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/backends/qualcomm/_passes:passes", + "//executorch/backends/qualcomm/utils:utils", + "//executorch/devtools:lib", + "//executorch/exir:sym_util", + "fbsource//third-party/pypi/graphviz:graphviz", + "fbsource//third-party/pypi/pandas:pandas", + ], +) diff --git a/backends/qualcomm/debugger/format_outputs.py b/backends/qualcomm/debugger/format_outputs.py new file mode 100644 index 00000000000..05f5c908919 --- /dev/null +++ b/backends/qualcomm/debugger/format_outputs.py @@ -0,0 +1,221 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import csv +import os +from typing import Any + +import pydot +import torch +from executorch.backends.qualcomm.utils.constants import ( + QCOM_QUANT_ATTRS, + QCOM_SCALE, + QCOM_SCALES, + QCOM_TENSOR_NAME, + QCOM_ZERO_POINT, + QCOM_ZERO_POINTS, +) + +from .metrics_evaluator import MetricEvaluatorBase + + +# Copied from site-packages/torch/fx/passes/graph_drawer.py +def typename(target: Any) -> str: + from torch.fx.node import _get_qualified_name + + if isinstance(target, torch.nn.Module): + ret = torch.typename(target) + elif isinstance(target, str): + ret = target + else: + ret = _get_qualified_name(target) + + # Escape "{" and "}" to prevent dot files like: + # https://gist.github.com/SungMinCho/1a017aab662c75d805c5954d62c5aabc + # which triggers `Error: bad label format (...)` from dot + return ret.replace("{", r"\{").replace("}", r"\}") + + +def retrieve_node_info(evaluator, node, node_tensor_map): + + node_info = {} + node_info["name"] = node.name + node_info["op_code"] = node.op + node_info["target"] = typename(node.target) + node_info["num_users"] = len(node.users) + + if "val" in node.meta: + if isinstance(node.meta["val"], torch.Tensor): + node_info["pytorch_layout"] = node.meta["val"].shape + elif isinstance(node.meta["val"], (list, tuple)): + shape_list = [] + for i in range(len(node.meta["val"])): + shape_list.append(node.meta["val"][i].shape) + node_info["pytorch_layout"] = shape_list + + if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS): + node_info["scale(s)"] = ( + quant_attrs.get(QCOM_SCALES) + if QCOM_SCALES in quant_attrs + else quant_attrs.get(QCOM_SCALE) + ) + node_info["zero_point(s)"] = ( + quant_attrs.get(QCOM_ZERO_POINTS) + if QCOM_ZERO_POINTS in quant_attrs + else quant_attrs.get(QCOM_ZERO_POINT) + ) + + if node.name in node_tensor_map: + qnn_output, cpu_output, meta = node_tensor_map[node.name] + node_info[QCOM_TENSOR_NAME] = meta.get(QCOM_TENSOR_NAME) + node_info[evaluator.metric_name()], node_info["is_valid_score"] = ( + evaluator.evaluate(qnn_output, cpu_output) + ) + + # The values in meta are directly retrieved from the node during the forward hook, which means the values should be the same for meta and node.meta. + # Storing these data during the forward hook helps us compare QNN tensors with CPU tensors without traversing the graph. 
+ # We only check "scale" and not "scales" since the forward hook only stores the node's output, which should always be per tensor. + if QCOM_QUANT_ATTRS in node.meta: + assert ( + node_info["scale(s)"] == node.meta[QCOM_QUANT_ATTRS][QCOM_SCALE] + ), "node meta scale should be same as scale retrieve during forward hook" + assert ( + node_info["zero_point(s)"] + == node.meta[QCOM_QUANT_ATTRS][QCOM_ZERO_POINT] + ), "node meta zero_point should be same as zero_point retrieve during forward hook" + + return node_info + + +def export_svg( + title: str, + path: str, + evaluator: MetricEvaluatorBase, + edge_module: torch.fx.GraphModule, + node_tensor_map: dict, +): + def get_node_style(is_valid_score: bool): + template = { + "shape": "record", + "style": '"filled,rounded"', + "fontcolor": "#000000", + } + + if is_valid_score is None: + template["fillcolor"] = "LemonChiffon1" # No match between QNN and CPU + elif is_valid_score: + template["fillcolor"] = "DarkOliveGreen3" # Good accuracy + else: + template["fillcolor"] = "Coral1" # Bad accuracy + + return template + + pydot_graph = pydot.Dot(graph_type="graph") + node_map = {} + + # Create node + for node in edge_module.graph.nodes: + # These are just nodes before fold_quant and still there + if len(node.users) == 0 and node.op == "placeholder": + continue + node_info = retrieve_node_info( + evaluator=evaluator, node=node, node_tensor_map=node_tensor_map + ) + + node_label = "{" + node_label += f"name=%{node_info.get('name')}" + r"\n" + node_label += f"|op_code={node_info.get('op_code')}" + r"\n" + node_label += f"|qnn_tensor_name={node_info.get('qnn_tensor_name')}" + r"\n" + node_label += f"|target={node_info.get('target')}" + r"\n" + node_label += f"|num_users={node_info.get('num_users')}" + r"\n" + node_label += f"|pytorch_layout={node_info.get('pytorch_layout')}" + r"\n" + node_label += f"|scale(s)={node_info.get('scale(s)')}" + r"\n" + node_label += f"|zero_point(s)={node_info.get('zero_point(s)')}" + r"\n" + node_label += ( + f"|{evaluator.metric_name()}={node_info.get(evaluator.metric_name())}" + + r"\n" + ) + node_label += f"|is_valid_score={node_info.get('is_valid_score')}" + r"\n" + node_label += "}" + + template = get_node_style(node_info.get("is_valid_score")) + pydot_node = pydot.Node(node.name, label=node_label, **template) + node_map[node.name] = pydot_node + pydot_graph.add_node(pydot_node) + + # Create edge + for node in edge_module.graph.nodes: + if len(node.users) == 0 and node.op == "placeholder": + continue + cur_pydot_node = node_map[node.name] + users = list(node.users.keys()) + for user in users: + user_pydot_node = node_map[user.name] + pydot_graph.add_edge( + pydot.Edge(cur_pydot_node, user_pydot_node, dir="forward") + ) + + pydot_graph.write_svg(f"{path}/{title}.svg") + print(f"Intermediate debugger graph saved at: {path}/{title}.svg") + + +def export_csv( + title: str, + path: str, + evaluator: MetricEvaluatorBase, + edge_module: torch.fx.GraphModule, + node_tensor_map: dict, +): + node_info_list = [] + for node in edge_module.graph.nodes: + # These are just nodes before fold_quant and still there + if len(node.users) == 0 and node.op == "placeholder": + continue + node_info = retrieve_node_info( + evaluator=evaluator, node=node, node_tensor_map=node_tensor_map + ) + node_info_list.append(node_info) + + # Writing to a CSV file + with open(f"{path}/{title}.csv", mode="w", newline="") as csv_file: + fieldnames = [ + "name", + "op_code", + "qnn_tensor_name", + "target", + "num_users", + "pytorch_layout", + "scale(s)", + 
"zero_point(s)", + f"{evaluator.metric_name()}", + "is_valid_score", + ] + writer = csv.DictWriter(csv_file, fieldnames=fieldnames) + + writer.writeheader() + writer.writerows(node_info_list) + + print(f"Intermediate debugger csv saved at: {path}/{title}.csv") + + +def export_raw( + path: str, + edge_module: torch.fx.GraphModule, + node_tensor_map: dict, +): + for node in edge_module.graph.nodes: + # These are just unused nodes before fold_quant and still there + if len(node.users) == 0 and node.op == "placeholder": + continue + if paired_event := node_tensor_map.get(node.name): + qnn_output, cpu_output, meta = paired_event + qnn_tensor_name = meta[QCOM_TENSOR_NAME] + qnn_output_path = os.path.join(path, qnn_tensor_name + "_qnn.raw") + cpu_output_path = os.path.join(path, qnn_tensor_name + "_cpu.raw") + qnn_output.numpy().tofile(qnn_output_path) + cpu_output.numpy().tofile(cpu_output_path) + + print(f"Intermediate debugger raw files saved at: {path}") diff --git a/backends/qualcomm/debugger/metrics_evaluator.py b/backends/qualcomm/debugger/metrics_evaluator.py new file mode 100644 index 00000000000..55c8b92b034 --- /dev/null +++ b/backends/qualcomm/debugger/metrics_evaluator.py @@ -0,0 +1,90 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC, abstractmethod +from typing import Any, Tuple + +import torch + + +class MetricEvaluatorBase(ABC): + @abstractmethod + def metric_name(self) -> str: + """ + A name for this metric evaluation + + Returns: + str: name of the metric evaluation + """ + ... + + @abstractmethod + def evaluate( + self, qnn_output: torch.Tensor, cpu_output: torch.Tensor, **kwargs + ) -> Tuple[Any, bool]: + """ + This abstract method should accept both QNN and CPU outputs for a single layer. + Define your own logic to compare the results. + + Args: + qnn_output (torch.Tensor): QNN intermediate output + cpu_output (torch.Tensor): CPU intermediate output + + Returns: + Tuple[Any, bool]: Return 2 elements: + 1) Score or anything that you would like to be printed under metrics category for svg graph or csv file. + 2) A boolean that indicates whether the evaluation result is acceptable or not. + """ + ... 
+ + +class AtolEvaluator(MetricEvaluatorBase): + def __init__(self, threshold=1e-1): + self.threshold = threshold + + def metric_name(self) -> str: + return "Atol Similarity" + + def evaluate( + self, qnn_output: torch.Tensor, cpu_output: torch.Tensor + ) -> Tuple[Any, bool]: + avg_atol = torch.mean(torch.abs(qnn_output - cpu_output)) + valid = avg_atol < self.threshold + formatted_score = f"{avg_atol:.3f}" + return formatted_score, valid + + +class CosineSimilarityEvaluator(MetricEvaluatorBase): + def __init__(self, threshold=0.9): + self.threshold = threshold + + def metric_name(self) -> str: + return "Cosine Similarity" + + def evaluate( + self, qnn_output: torch.Tensor, cpu_output: torch.Tensor + ) -> Tuple[Any, bool]: + score = torch.nn.functional.cosine_similarity( + qnn_output.flatten(), cpu_output.flatten(), dim=0 + ).item() + valid = score > self.threshold + formatted_score = f"{score:.3f}" + return formatted_score, valid + + +class MeanSquaredErrorEvaluator(MetricEvaluatorBase): + def __init__(self, threshold=0.01): + self.threshold = threshold + + def metric_name(self) -> str: + return "Mean Squared Error" + + def evaluate( + self, qnn_output: torch.Tensor, cpu_output: torch.Tensor + ) -> Tuple[Any, bool]: + mse = torch.mean((qnn_output - cpu_output) ** 2) + valid = mse < self.threshold + return mse, valid diff --git a/backends/qualcomm/debugger/qnn_intermediate_debugger.py b/backends/qualcomm/debugger/qnn_intermediate_debugger.py new file mode 100644 index 00000000000..904dd4f6ccb --- /dev/null +++ b/backends/qualcomm/debugger/qnn_intermediate_debugger.py @@ -0,0 +1,328 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import operator +import os +import warnings +from enum import IntEnum + +import torch + +from executorch.backends.qualcomm._passes.layout_transform import LayoutTransform +from executorch.backends.qualcomm.utils.constants import ( + QCOM_AXIS_ORDER, + QCOM_QUANT_ATTRS, + QCOM_SCALE, + QCOM_TENSOR_NAME, + QCOM_ZERO_POINT, +) +from executorch.devtools import Inspector +from executorch.exir.sym_util import eval_shape + +from .format_outputs import export_csv, export_raw, export_svg +from .metrics_evaluator import MetricEvaluatorBase + + +class OutputFormat(IntEnum): + SVG_GRAPHS = 0 + CSV_FILES = 1 + DUMP_RAW = 2 + + +class IntermediateModule(torch.nn.Module): + """ + This class serves as an intermediate point and is inserted right after the call_function node. + It also saves some metadata such as scale, offset, etc. + Since we just want to check the intermediate output, we will directly return the value during the forward call. + """ + + def __init__( + self, + module_name: str, + qnn_tensor_name: str, + node_name: str, + scale: float, + zero_point: int, + revert_order: bool = None, + ): + super().__init__() + self.module_name = module_name + self.qnn_tensor_name = qnn_tensor_name + self.node_name = node_name + self.scale = scale + self.zero_point = zero_point + self.revert_order = revert_order + + def forward(self, x): + return x + + +class QNNIntermediateDebugger: + """This is a debugger tool capable of retrieving intermediate results for CPU edge EP. + We can further compare these with QNN's intermediate output to identify any QNN accuracy issues. 
+ """ + + def __init__(self): + self.intermediate_outputs = {} + + def set_edge_module(self, edge_module: torch.fx.graph_module.GraphModule): + self.orig_edge = copy.deepcopy(edge_module) + self.intermediate_output_module = self._insert_intermediate_module( + copy.deepcopy(edge_module) + ) + + def generate_results( + self, + title: str, + path: str, + output_format: OutputFormat, + inspector: Inspector, + evaluator: MetricEvaluatorBase = None, + keep_qnn_layout: bool = False, + ): + assert isinstance( + output_format, OutputFormat + ), "output_format passed in is not an instance of OutputFormat" + os.makedirs(path, exist_ok=True) + if keep_qnn_layout: + warnings.warn( + "[QNN Delegate Debugger]: keep_qnn_layout is not recommended for general use case. " + "QNN and CPU has different dtype(FP V.S. Quantized) and data formats(NCHW V.S. NHWC) in a lot of cases.", + stacklevel=1, + ) + + # Due to users can switch between keep_qnn_layout between generate_results, rematch this every time. + # Make this a class variable if repeat matching is taking too long and handle keep_qnn_layout. + node_tensor_map = self._match_tensors( + inspector=inspector, + keep_qnn_layout=keep_qnn_layout, + ) + + if output_format == OutputFormat.SVG_GRAPHS: + assert evaluator is not None, "Please provide an evaluator." + export_svg( + title=title, + path=path, + evaluator=evaluator, + edge_module=self.orig_edge, + node_tensor_map=node_tensor_map, + ) + elif output_format == OutputFormat.CSV_FILES: + assert evaluator is not None, "Please provide an evaluator." + export_csv( + title=title, + path=path, + evaluator=evaluator, + edge_module=self.orig_edge, + node_tensor_map=node_tensor_map, + ) + elif output_format == OutputFormat.DUMP_RAW: + warnings.warn( + f"[QNN Delegate Debugger]: Param 'title' will be ignored, all raw files will be stored under: {path}", + stacklevel=1, + ) + if evaluator: + warnings.warn( + "[QNN Delegate Debugger]: Param 'evaluator' will be ignored as DUMP_RAW will only dump tensors to raw files but won't perform comparison.", + stacklevel=1, + ) + export_raw( + path=path, + edge_module=self.intermediate_output_module, + node_tensor_map=node_tensor_map, + ) + else: + warnings.warn( + "[QNN Delegate Debugger]: Unknown output format, do nothing.", + stacklevel=1, + ) + return + + def _insert_intermediate_module( # noqa: C901 + self, edge_module: torch.fx.graph_module.GraphModule + ): + """ + This feature is for intermediate tensor dump on the host CPU. + After we get an edge GraphModule, we insert submodule between each call_function node, + and we register forward hooks to store the intermediate results. + We have to use the edge GraphModule because this is the graph closest to what QNN is executing + while still being a valid graph to ExecuTorch. + + Args: + edge_module (exir.ExirExportedProgram): A deep copy of edge ir graph module. + We need to deep copy so we don't mess up the original edge_ep. + Returns: + exir.ExirExportedProgram: A deep copy of edge graph_module with intermediate modules inserted. 
+ """ + + def hook_fn(module, input, output): + meta = {} + meta[QCOM_TENSOR_NAME] = module.qnn_tensor_name + meta["node_name"] = module.node_name + meta[QCOM_SCALE] = module.scale + meta[QCOM_ZERO_POINT] = module.zero_point + meta["revert_order"] = module.revert_order + meta["output"] = output # CPU output + + assert ( + module.qnn_tensor_name not in self.intermediate_outputs + ), f"{module.qnn_tensor_name} checked already, check if this is a potential error" + self.intermediate_outputs[module.qnn_tensor_name] = meta + + graph = edge_module.graph + module_count = 0 + for node in graph.nodes: + if node.op == "call_function": + module_name = f"intermediate_module_{module_count}" + module_count += 1 + with graph.inserting_after(node): + scale = None + zero_point = None + if QCOM_QUANT_ATTRS in node.meta: + scale = node.meta[QCOM_QUANT_ATTRS][QCOM_SCALE] + zero_point = node.meta[QCOM_QUANT_ATTRS][QCOM_ZERO_POINT] + + revert_order = QCOM_AXIS_ORDER in node.meta + + if node.target == operator.getitem: + index = node.args[1] + # Ex: topk -> intermediate_module -> get_item + src_node = node.args[0].args[0] + qnn_tensor_name = src_node.meta[QCOM_TENSOR_NAME][index] + elif any(user.target == operator.getitem for user in node.users): + # For cases like topK, qnn_tensor_name is stored in get_item instead of source_node itself. + assert all( + user.target == operator.getitem for user in node.users + ), "Expect all users to be get_item node" + qnn_tensor_name = node.name + elif QCOM_TENSOR_NAME in node.meta: + assert ( + len(node.meta[QCOM_TENSOR_NAME]) == 1 + ), "Expecting a single qnn_tensor name but get more than 1." + qnn_tensor_name = node.meta[QCOM_TENSOR_NAME][0] + else: + # Unused + qnn_tensor_name = node.name + + obs = IntermediateModule( + module_name=module_name, + qnn_tensor_name=qnn_tensor_name, + node_name=node.name, + scale=scale, + zero_point=zero_point, + revert_order=revert_order, + ) + setattr( + edge_module, + module_name, + obs, + ) + new_obs = graph.create_node("call_module", module_name, (node,), {}) + orig_users = list(node.users.keys()) + for user_node in orig_users: + if user_node is new_obs: + continue + user_node.replace_input_with(node, new_obs) + + # Register hooks for all intermediate layers + for ( + _, + layer, + ) in edge_module.named_modules(): + if isinstance(layer, IntermediateModule): + layer.register_forward_hook(hook_fn) + + graph.eliminate_dead_code() + edge_module.recompile() + + return edge_module + + def _process_qnn_output(self, qnn_output: torch.tensor, meta: dict) -> torch.tensor: + """ + QNN intermediate results are all quantized. + We need to dequantize them to match CPU float values. + Additionally, we need to revert the layout format for layout-sensitive nodes. + + Args: + qnn_output (torch.tensor): QNN intermediate output from inspector event + meta (dict): The meta for this tensor/node that is stored during insert_intermediate_module(). + + Returns: + torch.tensor: Processed tensor that should have same dtype and shape as CPU tensors. 
+ """ + qnn_output = qnn_output.to(torch.float32) + if meta[QCOM_SCALE] is not None: + scale = meta[QCOM_SCALE] + zero_point = meta[QCOM_ZERO_POINT] + qnn_output = ( + qnn_output.sub(zero_point).mul(scale).to(torch.float32).contiguous() + ) + if meta["revert_order"]: + axis_order = LayoutTransform.get_axis_order( + eval_shape(qnn_output.shape), reverse=True + ) + qnn_output = qnn_output.permute(axis_order) + return qnn_output + + def _match_tensors(self, inspector: Inspector, keep_qnn_layout: bool = False): + """ + Map QNN tensors back to CPU tensors. + Create a map using the node name as the key and (preprocessed/postprocessed QNN tensor, CPU tensor, meta) as the value. + We need meta because it holds values such as scale, offset, layout sensitivity, etc. + + Args: + inspector (Inspector): Inspector that parse QNN runtime intermediate outputs + keep_qnn_layout (bool): If true, store QNN outputs in NHWC format. Not recommended for general users. + + Returns: + A dict storing {node_name : tuple(qnn_output, cpu_output, meta_info)} + Meta_info is the info stored during forward hook_fn. + """ + + # node_tensor_map {key: tuple(qnn_output, cpu_output, meta_info)} + node_tensor_map = {} + # OPs that only exists in QNN but not CPU Golden + unmatched_qnn_tensors = [] + # E.g.: DELEGATE_CALL (This is the model input data), 'Method::execute' + ignored_events = [] + # Collected with forward hook + intermediate_outputs = self.intermediate_outputs + for event_block in inspector.event_blocks: + if event_block.name == "Execute": + for event in event_block.events: + # If user enables profiling and dump intermediate outputs the same time, we need to skip the profiling event + if event.perf_data is not None and event.is_delegated_op: + continue + if meta := intermediate_outputs.get(event.name): + node_name = meta["node_name"] + cpu_output = meta["output"] + qnn_output = ( + event.debug_data[0] + if keep_qnn_layout + else self._process_qnn_output(event.debug_data[0], meta) + ) + node_tensor_map[node_name] = ( + qnn_output, + cpu_output, + meta, + ) + + else: + ( + unmatched_qnn_tensors.append(event.name) + if event.is_delegated_op + else ignored_events.append(event.name) + ) + + warnings.warn( + f"The following events are ignored: {ignored_events}", stacklevel=1 + ) + warnings.warn( + f"The following QNN OPs are missing CPU reference. OPs added during qnn_preprocess will not have CPU reference. Please ensure the operations below are created during qnn_preprocess. 
{unmatched_qnn_tensors}", + stacklevel=1, + ) + return node_tensor_map diff --git a/backends/qualcomm/debugger/utils.py b/backends/qualcomm/debugger/utils.py index b1d3ea84900..aaa403dd7c0 100644 --- a/backends/qualcomm/debugger/utils.py +++ b/backends/qualcomm/debugger/utils.py @@ -7,7 +7,7 @@ from pathlib import Path from typing import Tuple -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import pandas as pd import torch from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset @@ -21,7 +21,7 @@ def __init__( self, filename: str, directory: str, - py_op_wrapper_list: [PyQnnWrapper.PyQnnOpWrapper], + py_op_wrapper_list: [PyQnnManager.PyQnnOpWrapper], dot_string=False, ): self.filename = filename @@ -98,13 +98,13 @@ def add_node(self, node_list, excel_data): offset = [] if ( quantization_encoding - == PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET + == PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET ): scale.append(node.quantizeParams.scaleOffsetEncoding.scale) offset.append(node.quantizeParams.scaleOffsetEncoding.offset) elif ( quantization_encoding - == PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET + == PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET ): for i in range( node.quantizeParams.axisScaleOffsetEncoding.numScaleOffsets @@ -159,7 +159,7 @@ def to_excel(self, excel_data): offset = entry["offset"] if ( entry["tensor_type"] - == PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC + == PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC ): param_rows.append({"name": name, "scale": scale, "offset": offset}) else: @@ -183,11 +183,11 @@ def draw(self): cleanup=not self.dot_string, ) source_file = os.path.join(temp_directory, f"{self.filename}.svg") - destination_file = os.path.join(".", f"{self.filename}.svg") + destination_file = os.path.join(self.directory, f"{self.filename}.svg") shutil.move(source_file, destination_file) if self.dot_string: dot_file = os.path.join(temp_directory, f"{self.filename}") - dot_dest_file = os.path.join(".", f"{self.filename}.dot") + dot_dest_file = os.path.join(self.directory, f"{self.filename}.dot") shutil.move(dot_file, dot_dest_file) @@ -348,8 +348,8 @@ def generate_optrace( qnn_binary_file="forward_0.dlc", ): """ - Generate Qnn HTP Optrace Profiling https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/htp_backend.html#qnn-htp-optrace-profiling - and QNN HTP Analysis Summary (QHAS) https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/htp_backend.html#qnn-htp-analysis-summary-qhas + Generate Qnn HTP Optrace Profiling https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-10/htp_backend.html#qnn-htp-optrace-profiling + and QNN HTP Analysis Summary (QHAS) https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-10/htp_backend.html#qnn-htp-analysis-summary-qhas . You can utilize the QAIRT Visualizer (https://pypi.org/project/qairt-visualizer/) to visualize the results from the files above. 
""" graph_name, file_extension = os.path.splitext(qnn_binary_file) diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py index 7a2924fe756..2447e6a06c6 100644 --- a/backends/qualcomm/partition/common_defs.py +++ b/backends/qualcomm/partition/common_defs.py @@ -10,14 +10,16 @@ from executorch.exir.dialects._ops import ops as exir_ops not_supported_operator = [ + # output size is data dependent on the slice index + exir_ops.edge.aten._embedding_bag.default, + # for graph sharding purpose, different from the op used in decoder models exir_ops.edge.dim_order_ops._clone_dim_order.default, + # QNN does not support 4-bit embedding exir_ops.edge.quantized_decomposed.embedding_4bit.dtype, ] to_be_implemented_operator = [ - exir_ops.edge.aten._adaptive_avg_pool3d.default, - exir_ops.edge.aten.adaptive_max_pool2d.default, - exir_ops.edge.aten.avg_pool3d.default, + exir_ops.edge.aten.adaptive_max_pool3d.default, exir_ops.edge.aten.div.Tensor_mode, exir_ops.edge.aten.log10.default, exir_ops.edge.aten.log1p.default, diff --git a/backends/qualcomm/partition/qnn_partitioner.py b/backends/qualcomm/partition/qnn_partitioner.py index 19e998f59a3..9d5a7467b8b 100644 --- a/backends/qualcomm/partition/qnn_partitioner.py +++ b/backends/qualcomm/partition/qnn_partitioner.py @@ -8,7 +8,6 @@ from collections import defaultdict from typing import Any, Callable, Dict, List, Optional, Tuple -import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch from executorch.backends.qualcomm.builders import node_visitor_manager from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader @@ -16,11 +15,11 @@ from executorch.backends.qualcomm.serialization.qc_schema_serialize import ( flatbuffer_to_option, ) -from executorch.backends.qualcomm.utils.constants import ( - QCOM_AXIS_ORDER, - QCOM_BYPASS_NODE, -) +from executorch.backends.qualcomm.utils.constants import QCOM_BYPASS_NODE +from executorch.backends.qualcomm.utils.qnn_manager_lifecycle import ( + get_current_qnn_manager, +) from executorch.exir.backend.backend_details import CompileSpec from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import ( generate_partitions_from_list_of_nodes, @@ -55,7 +54,8 @@ def __init__( skip_node_id_set: set = None, skip_node_op_set: set = None, ): - python_options = flatbuffer_to_option(compiler_specs[0].value) + option = generate_qnn_executorch_option(compiler_specs) + python_options = flatbuffer_to_option(option) self.node_visitors = node_visitor_manager.get_node_visitors( edge_program, op_package_infos=python_options.op_package_options.op_package_infos, @@ -64,12 +64,10 @@ def __init__( self.skip_node_op_set = skip_node_op_set self.skip_node_id_set = skip_node_id_set self.nodes_to_wrappers = defaultdict(dict) - self.qnn_manager = PyQnnManager.QnnManager( - generate_qnn_executorch_option(compiler_specs) + self.qnn_manager = get_current_qnn_manager( + python_options.backend_options.backend_type, compiler_specs ) - self.qnn_manager.Init() - def is_node_supported(self, _, node: torch.fx.Node) -> bool: if node.op != "call_function" or node.target in not_supported_operator: return False @@ -118,9 +116,6 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool: print(f"[QNN Partitioner Op Support]: {node.target.__name__} | {supported}") return supported - def __del__(self): - self.qnn_manager.Destroy() - class QnnPartitioner(Partitioner): """ @@ -213,11 +208,6 @@ def partition(self, edge_program: 
torch.export.ExportedProgram) -> PartitionResu ) tag_mutated_buffer(edge_program) - # pop certain keys in meta for not affecting the passes in compilation - for node in edge_program.graph_module.graph.nodes: - if hasattr(node, "meta"): - # TODO: need to put property name in common definitions - node.meta.pop(QCOM_AXIS_ORDER, "") return PartitionResult( tagged_exported_program=edge_program, partition_tags=self.partition_tags ) diff --git a/backends/qualcomm/partition/utils.py b/backends/qualcomm/partition/utils.py index 97e0b4bd109..7c45845f516 100644 --- a/backends/qualcomm/partition/utils.py +++ b/backends/qualcomm/partition/utils.py @@ -16,11 +16,19 @@ def generate_qnn_executorch_option( compiler_specs: List[CompileSpec], ) -> bytes: + qnn_compile_spec_buffer = None + for compiler_spec in compiler_specs: if compiler_spec.key == QCOM_QNN_COMPILE_SPEC: qnn_compile_spec_buffer = compiler_spec.value else: raise ValueError(f"unknown compiler spec key value: {compiler_spec.key}") + + if qnn_compile_spec_buffer is None: + raise ValueError( + f"QNN compile spec (key={QCOM_QNN_COMPILE_SPEC}) not found in compiler_specs" + ) + return qnn_compile_spec_buffer diff --git a/backends/qualcomm/qnn_preprocess.py b/backends/qualcomm/qnn_preprocess.py index 4e9cda21d02..c0351b01ed6 100644 --- a/backends/qualcomm/qnn_preprocess.py +++ b/backends/qualcomm/qnn_preprocess.py @@ -8,8 +8,6 @@ from collections import defaultdict from typing import Dict, final, List -import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager - import torch # noqa: F401 from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager from executorch.backends.qualcomm.builders.node_visitor_manager import get_node_visitors @@ -20,7 +18,10 @@ ) from executorch.backends.qualcomm.serialization.qc_schema_serialize import ( flatbuffer_to_option, - option_to_flatbuffer, +) +from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER +from executorch.backends.qualcomm.utils.qnn_manager_lifecycle import ( + get_current_qnn_manager, ) from executorch.exir.backend.backend_details import ( BackendDetails, @@ -30,6 +31,7 @@ from torch.export.exported_program import ExportedProgram DEFAULT_DEBUG_HANDLE = 65535 +DEFAULT_GRAPH_NAME = "forward" logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -42,9 +44,16 @@ def _build_op_wrappers( edge_program: ExportedProgram, enable_tensor_dump: bool, op_package_infos: List[QnnExecuTorchOpPackageInfo], + use_mha2sha: bool, ): + for node in edge_program.graph_module.graph.nodes: + if hasattr(node, "meta"): + # pop certain keys in meta for not affecting the passes in compilation + node.meta.pop(QCOM_AXIS_ORDER, "") # QNN Delegate Specific Passes - graph_module = QnnPassManager().transform_for_preprocess_pipeline(edge_program) + graph_module = QnnPassManager().transform_for_preprocess_pipeline( + edge_program, use_mha2sha=use_mha2sha + ) assert graph_module is not None nodes_to_wrappers = defaultdict(dict) @@ -99,13 +108,16 @@ def preprocess( compile_specs: List[CompileSpec], ) -> PreprocessResult: option = generate_qnn_executorch_option(compile_specs) - qnn_manager = PyQnnManager.QnnManager(option) - qnn_manager.Init() obj_options = flatbuffer_to_option(option) + qnn_manager = get_current_qnn_manager( + obj_options.backend_options.backend_type, compile_specs + ) + qnn_manager.InitContext([DEFAULT_GRAPH_NAME]) py_op_wrapper_list = QnnBackend._build_op_wrappers( edge_program, qnn_manager.IsTensorDump(), 
obj_options.op_package_options.op_package_infos, + obj_options.use_mha2sha, ) qnn_context_binary = qnn_manager.Compile( @@ -118,7 +130,7 @@ def preprocess( f"Record all QNN API calls from saver backend at: {obj_options.saver_output_dir}" ) assert len(qnn_context_binary) != 0, "Failed to generate Qnn context binary." - qnn_manager.Destroy() + qnn_manager.DestroyContext() # For now, debug_handle_map is not used by QNN ExecuTorch return PreprocessResult( processed_bytes=bytes(qnn_context_binary), @@ -132,12 +144,9 @@ def preprocess_multimethod( ) -> PreprocessResult: # TODO: refactor QnnManager to consume multiple compile_spec # take first compile_specs here for the same partitions - graph_name = list(edge_programs.keys()) + graph_names = list(edge_programs.keys()) compile_spec = list(compile_specs.values())[0][0] - # gather all graph names option = flatbuffer_to_option(compile_spec[0].value) - option.graph_name = graph_name - compile_spec[0].value = option_to_flatbuffer(option) # check if each graph has equal number of partitions num_sub_graphs = set() for edge_program in edge_programs.values(): @@ -149,15 +158,15 @@ def preprocess_multimethod( all_processed_results = {key: [] for key in edge_programs.keys()} num_sub_graphs = next(iter(num_sub_graphs)) + qnn_manager = get_current_qnn_manager( + option.backend_options.backend_type, compile_spec + ) for i in range(num_sub_graphs): # e.g. 2 methods (x, y) with 3 partitions # > context_binary_0: [x.subgraph_0, y.subgraph_0] # > context_binary_1: [x.subgraph_1, y.subgraph_1] # > context_binary_2: [x.subgraph_2, y.subgraph_2] - qnn_manager = PyQnnManager.QnnManager( - generate_qnn_executorch_option(compile_spec) - ) - qnn_manager.Init() + qnn_manager.InitContext(graph_names) py_op_wrapper_list, ctx_binary_list = [], [] for j, programs in enumerate(edge_programs.values()): logger.info(f"Processing Method({j}): ({i+1}/{num_sub_graphs})") @@ -165,6 +174,7 @@ def preprocess_multimethod( programs[i], qnn_manager.IsTensorDump(), option.op_package_options.op_package_infos, + option.use_mha2sha, ) if isinstance(py_op_wrappers, bytes): ctx_binary_list.append(py_op_wrappers) @@ -177,7 +187,9 @@ def preprocess_multimethod( ) if len(py_op_wrapper_list) == len(edge_programs.values()): - qnn_context_binary = qnn_manager.Compile(graph_name, py_op_wrapper_list) + qnn_context_binary = qnn_manager.Compile( + graph_names, py_op_wrapper_list + ) if option.saver: # TODO: Currently, only the first method is saved. Update this logic if saving multiple methods becomes necessary in the future. exit( @@ -186,7 +198,7 @@ def preprocess_multimethod( assert ( len(qnn_context_binary) != 0 ), "Failed to generate Qnn context binary." - qnn_manager.Destroy() + qnn_manager.DestroyContext() # methods should share the same context binary for current partition for key in edge_programs.keys(): all_processed_results[key].append( diff --git a/backends/qualcomm/quantizer/README.md b/backends/qualcomm/quantizer/README.md index 6870ecc76ac..1f9a373d47d 100644 --- a/backends/qualcomm/quantizer/README.md +++ b/backends/qualcomm/quantizer/README.md @@ -9,7 +9,7 @@ Thank you for contributing to Qualcomm AI Engine Direct delegate for ExecuTorch. 
## References ### Qualcomm AI Engine Direct -- [Operator Definitions for HTP](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/HtpOpDefSupplement.html) +- [Operator Definitions for HTP](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-10/HtpOpDefSupplement.html) ### PyTorch - [ATen Operator Definitions](https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native) @@ -40,7 +40,7 @@ In order to conduct PTQ for floating point precision graph, observers are requir kernel --> id6(Q_k) --> id7(DQ_k) --> id1(convolution) bias --> id8(Q_b) --> id9(DQ_b) --> id1(convolution) ``` -Qualcomm backend will consume the generated encodings and lower operators with fixed precision. This tutorial will guide you through the details of inserting observer and some useful utilies. +Qualcomm backend will consume the generated encodings and lower operators with fixed precision. This tutorial will guide you through the details of inserting observer and some useful utilities. ### Register Annotation via Operator Type Let's start with hooking callback for designated operator target: @@ -66,7 +66,7 @@ def annotate_xxx(node: Node, quantization_config: QuantizationConfig) -> None: - __quantization_config__: data structure describing quantization configurations for IO activation / weight / bias ### Example of Conv2d Annotation -Conv2d accepts up to three input tensors: `input activation`, `kernel`, `bias`. There are constraints imposed by [Qualcomm AI Engine Direct Manual](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/HtpOpDefSupplement.html#conv2d).
+Conv2d accepts up to three input tensors: `input activation`, `kernel`, `bias`. There are constraints imposed by [Qualcomm AI Engine Direct Manual](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-10/HtpOpDefSupplement.html#conv2d).
Take 8-bit fixed point as example: - __weight__: must be symmetrically quantized if per-channel observer is applied - __bias__: must have `QNN_DATATYPE_SFIXED_POINT_32` and be symmetrically quantized with expected encoding `scales = weight.scales * input.scale`, `offset = 0` if per-channel observer is applied. @@ -105,7 +105,7 @@ def ptq_per_channel_quant_config( return quantization_config ``` -Here we choose `torch.uint8` + `MinMaxObserver` for better converage of IO activation and apply rules to `weight` w/`PerChannelMinMaxObserver`, `bias` w/`_derived_bias_quant_spec` (a callable method to calculate encoding in desired way) to meet aforementioned constraints. The well-defined `quantizaton_config` will then be shipped to callback for annotation.
+Here we choose `torch.uint8` + `MinMaxObserver` for better coverage of IO activation and apply rules to `weight` w/`PerChannelMinMaxObserver`, `bias` w/`_derived_bias_quant_spec` (a callable method to calculate the encoding in the desired way) to meet the aforementioned constraints. The well-defined `quantization_config` will then be shipped to the callback for annotation.
Now, we can start to fill in the function body: - Register annotator @@ -147,13 +147,13 @@ Now, we can start to fill in the function body: - Update node's meta with framework compatible data structure ```python - node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, output_qspec=quantization_config.output_activation, _annotated=True, ) ``` - After done processing `input_qspec_map`, it's required to have it in node's meta with special tag (`QUANT_ANNOTATION_KEY`) for `convert_pt2e` to properly insert observers. + After done processing `input_qspec_map`, it's required to have it in node's meta with special tag (`Q_ANNOTATION_KEY`) for `convert_pt2e` to properly insert observers. ### Common Annotators For operators without extra parameters to be observed, there are pre-defined annotation method for convenience: diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py index a27ad2b4a5c..36f321f866c 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -25,6 +25,8 @@ ) from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY +from .observers.concat_observer import ConcatObserver + from .qconfig import ( get_16a16w_qnn_ptq_config, get_16a4w_qnn_qat_config, @@ -68,7 +70,7 @@ def _is_float_tensor(node: Node): or not isinstance(node.meta["val"], FakeTensor) ): return False - return node.meta["val"].dtype == torch.float32 + return node.meta["val"].dtype in (torch.bfloat16, torch.float32) def _mark_nodes_as_annotated(nodes: List[Node]): @@ -518,6 +520,29 @@ def annotate_full(node: Node, quantization_config: QuantizationConfig) -> None: ) +@register_annotator([torch.ops.aten.grid_sampler.default]) +def annotate_grid_sampler(node: Node, quantization_config: QuantizationConfig) -> None: + if _is_annotated([node]): + return + input_act_qsec = quantization_config.input_activation + output_act_qsec = quantization_config.output_activation + + input_qspec_map = {} + input_act0 = node.args[0] + if isinstance(input_act0, Node): + input_qspec_map[input_act0] = input_act_qsec + + input_act1 = node.args[1] + if isinstance(input_act1, Node): + input_qspec_map[input_act1] = input_act_qsec + + node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qsec, + _annotated=True, + ) + + @register_annotator( [torch.ops.aten.hardswish.default, torch.ops.aten.hardswish_.default] ) @@ -559,6 +584,27 @@ def annotate_neg(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) +@register_annotator([torch.ops.aten.adaptive_max_pool2d.default]) +def annotate_adaptive_max_pool2d( + node: Node, quantization_config: QuantizationConfig +) -> None: + if _is_annotated([node]): + return + input_act_qsec = quantization_config.input_activation + output_act_qsec = quantization_config.output_activation + + input_qspec_map = {} + input_act0 = node.args[0] + if isinstance(input_act0, Node): + input_qspec_map[input_act0] = input_act_qsec + + node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qsec, + _annotated=True, + ) + + @register_annotator( [ torch.ops.aten.adaptive_avg_pool1d.default, @@ -576,6 +622,18 @@ def annotate_avgpool2d(node: Node, quantization_config: QuantizationConfig) -> N annotate_single_in_single_out(node, quantization_config) 
+@register_annotator([torch.ops.aten.avg_pool3d.default]) +def annotate_avgpool3d(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_single_in_single_out(node, quantization_config) + + +@register_annotator([torch.ops.aten.adaptive_avg_pool3d.default]) +def annotate_adaptive_avgpool3d( + node: Node, quantization_config: QuantizationConfig +) -> None: + annotate_single_in_single_out(node, quantization_config) + + @register_annotator([torch.ops.aten.permute.default]) def annotate_permute(node: Node, quantization_config: QuantizationConfig) -> None: annotate_in_out_obs_sharing_op(node, quantization_config) @@ -674,9 +732,11 @@ def annotate_pad(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) -@register_annotator([torch.ops.aten.reshape.default]) +@register_annotator([torch.ops.aten.reshape.default, torch.ops.aten.unflatten.int]) def annotate_reshape(node: Node, quantization_config: QuantizationConfig) -> None: - annotate_single_in_single_out(node, quantization_config) + annotate_in_out_obs_sharing_op(node, quantization_config) + if not _is_annotated([node]): + annotate_single_in_share_out(node, quantization_config) @register_annotator([torch.ops.aten.select.int]) @@ -691,7 +751,7 @@ def annotate_sign(node: Node, quantization_config: QuantizationConfig) -> None: @register_annotator([torch.ops.aten.slice.Tensor]) def annotate_slice(node: Node, quantization_config: QuantizationConfig) -> None: - annotate_single_in_single_out(node, quantization_config) + annotate_single_in_share_out(node, quantization_config) @register_annotator([torch.ops.aten.slice_scatter.default]) @@ -839,7 +899,7 @@ def annotate_sigmoid(node: Node, quantization_config: QuantizationConfig) -> Non ) -@register_annotator([torch.ops.aten.__and__.Tensor]) +@register_annotator([torch.ops.aten.__and__.Tensor, torch.ops.aten.logical_and.default]) def annotate_and(node: Node, quantization_config: QuantizationConfig) -> None: annotate_binary(node, quantization_config) @@ -879,7 +939,7 @@ def annotate_unsqueeze_copy( annotate_single_in_share_out(node, quantization_config) -@register_annotator([torch.ops.aten.transpose.int]) +@register_annotator([torch.ops.aten.transpose.int, torch.ops.aten.swapaxes.default]) def annotate_transpose(node: Node, quantization_config: QuantizationConfig) -> None: annotate_in_out_obs_sharing_op(node, quantization_config) if not _is_annotated([node]): @@ -1094,11 +1154,13 @@ def annotate_cdist(node: Node, quantization_config: QuantizationConfig) -> None: @register_annotator( [ + torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default, torch.ops.aten.conv2d.padding, - torch.ops.aten.conv1d.default, - torch.ops.aten.conv_transpose2d.input, + torch.ops.aten.conv3d.default, torch.ops.aten.conv_transpose1d.default, + torch.ops.aten.conv_transpose2d.input, + torch.ops.aten.conv_transpose3d.input, torch.ops.aten.convolution.default, ] ) @@ -1275,31 +1337,40 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) -> @register_annotator([torch.ops.aten.cat.default, torch.ops.aten.concat.default]) def annotate_cat(node: Node, quantization_config: QuantizationConfig) -> None: - input_nodes = node.args[0] if _is_annotated([node]) or not _is_float_tensor(node): return - assert isinstance(input_nodes, Sequence) - - first_input_node = input_nodes[0] - input_qspec_map = {} - assert isinstance(first_input_node, Node) - assert isinstance(node, Node) - if _is_float_tensor(first_input_node): - 
input_qspec_map[first_input_node] = quantization_config.input_activation - share_qparams_with_input_act0_qspec = SharedQuantizationSpec( - (first_input_node, node) - ) - - for input_node in input_nodes[1:]: - if input_node not in input_qspec_map: - assert isinstance(input_node, Node) - if _is_float_tensor(input_node): - input_qspec_map[input_node] = share_qparams_with_input_act0_qspec - + input_qspec_map, input_nodes = {}, node.args[0] + for input in input_nodes: + input_qspec = input.meta.get(Q_ANNOTATION_KEY, None) + if ( + # placeholder + input_qspec is None + or + # keep shared qspec here for propagation the data range + # without introducing extra requantizations + not isinstance(input_qspec.output_qspec, SharedQuantizationSpec) + ): + input_qspec_map[input] = quantization_config.input_activation + + output_qspec = QuantizationSpec( + dtype=quantization_config.output_activation.dtype, + qscheme=quantization_config.output_activation.qscheme, + quant_max=quantization_config.output_activation.quant_max, + quant_min=quantization_config.output_activation.quant_min, + observer_or_fake_quant_ctr=ConcatObserver.with_args( + # we need to know the concat node in order to hack all the input observers' data range + # since deep copy of fake tensor (node.meta["val"]) is inhibited + # we could only ship grap & node name and perform postprocess inside observer currently + **{ + "node_name": node.name, + "graph": node.graph, + } + ), + ) node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, - output_qspec=share_qparams_with_input_act0_qspec, + output_qspec=output_qspec, _annotated=True, ) @@ -1343,6 +1414,7 @@ def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None: input_act = node.args[0] assert isinstance(input_act, Node) input_qspec_map[input_act] = quantization_config.input_activation + share_qparams_with_input_node_qspec = SharedQuantizationSpec((input_act, node)) node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, @@ -1351,12 +1423,12 @@ def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None: for user in node.users: user.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( - output_qspec=quantization_config.output_activation, + output_qspec=share_qparams_with_input_node_qspec, _annotated=True, ) -@register_annotator([torch.ops.aten.where.self]) +@register_annotator([torch.ops.aten.where.self, torch.ops.aten.where.ScalarSelf]) def annotate_where(node: Node, quantization_config: QuantizationConfig) -> None: if _is_annotated([node]): return @@ -1366,7 +1438,6 @@ def annotate_where(node: Node, quantization_config: QuantizationConfig) -> None: assert isinstance(input_node, Node) if _is_float_tensor(input_node): input_qspec_map[input_node] = quantization_config.input_activation - node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, output_qspec=( diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index 3f10dbaa3fc..e6969913c4e 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -3,7 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from enum import Enum, unique + from typing import Sequence import torch @@ -17,7 +17,6 @@ get_8a8w_qnn_ptq_config, get_8a8w_qnn_qat_config, get_ptq_per_channel_quant_config, - get_qat_per_channel_quant_config, QuantizationConfig, ) from executorch.exir.dialects._ops import ops as exir_ops @@ -32,36 +31,6 @@ ) -def annotate_down_proj( - gm: torch.fx.GraphModule, quantization_config: QuantizationConfig -): - for node in gm.graph.nodes: - if ( - node.target == torch.ops.aten.conv2d.default - and any(s in node.meta["stack_trace"] for s in ["forward_feedfoward_conv"]) - and node.args[0].target == torch.ops.aten.mul.Tensor - ): - input_qspec_map = {} - input_qspec_map[node.args[0]] = quantization_config.input_activation - input_qspec_map[node.args[1]] = quantization_config.weight - node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=quantization_config.output_activation, - _annotated=True, - ) - - -@unique -class StaticLLMQuantConfig(Enum): - """ - Layer namespace configuration for Qualcomm's static LLaMA quantization. - """ - - wq_sha = "wq_sha" # Query weight (single head) - wk_sha = "wk_sha" # Key weight (single head) - wv_sha = "wv_sha" # Value weight (single head) - - def annotate_eurobert(gm: torch.fx.GraphModule): """ QNN does not support int32 -> signed 16bit quant @@ -123,46 +92,6 @@ def annotate_mimi_decoder(gm: torch.fx.GraphModule): break -def annotate_output_16a8w(gm: torch.fx.GraphModule, is_qat: bool = False) -> None: - """ - This function is for static LLM models. - This function will annotate the last conv(linear), which is the lm_head, as 16a8w. - """ - - def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None: - input_qspec_map = {} - input_act = node.args[0] - input_spec = quantization_config.input_activation - input_qspec_map[input_act] = input_spec - - weight = node.args[1] - input_qspec_map[weight] = quantization_config.weight - - node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=quantization_config.output_activation, - _annotated=True, - ) - - if is_qat: - quantization_config_16a8w_per_channel = get_qat_per_channel_quant_config( - torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver - ) - else: - quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config( - torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver - ) - for node in gm.graph.nodes: - if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default: - if "nn_module_stack" in node.meta: - module_values_list = list(node.meta["nn_module_stack"].values()) - full_qualified_name = module_values_list[-1][0] - if full_qualified_name == "output.conv": - annotate_conv2d( - node, quantization_config=quantization_config_16a8w_per_channel - ) - - def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict): for node in gm.graph.nodes: if node.op == "output": @@ -197,48 +126,6 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict): ) -def annotate_qkv_proj_sha( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - qkv_tags: set[StaticLLMQuantConfig], -): - """ - Annotates QKV projection layers in a GraphModule for quantization, - specifically layers defined in StaticLLMQuantConfig. - - Args: - qkv_tags (set[StaticLLMQuantConfig]): A set of enum tags indicating which QKV layers - (e.g., wq, wk, wv) should be annotated for quantization. 
Only tags defined in - StaticLLMQuantConfig are allowed. - - Raises: - ValueError: If any tag in `qkv_tags` is not among the allowed enum members. - """ - - # Get all valid tags from the StaticLLMQuantConfig enum - allowed_tags = set(StaticLLMQuantConfig) - invalid_tags = qkv_tags - allowed_tags - if invalid_tags: - raise ValueError( - f"Invalid qkv tags: {invalid_tags}. Allowed tags are: {allowed_tags}" - ) - - for node in gm.graph.nodes: - if node.target == torch.ops.aten.conv2d.default and any( - tag.value in node.meta["stack_trace"] for tag in qkv_tags - ): - input_qspec_map = {} - input_qspec_map[node.args[0]] = quantization_config.input_activation - input_qspec_map[node.args[1]] = quantization_config.weight - if len(node.args) > 2 and isinstance(node.args[2], Node): - input_qspec_map[node.args[2]] = quantization_config.bias(node) - node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=quantization_config.output_activation, - _annotated=True, - ) - - def annotate_kv_8bit( # noqa: C901 gm: torch.fx.GraphModule, is_qat=False, @@ -259,7 +146,6 @@ def annotate_matmul(node: Node, quantization_config: QuantizationConfig): input_act = node.args[0] input_spec = quantization_config.input_activation input_qspec_map[input_act] = input_spec - input_act1 = node.args[1] input_spec1 = quantization_config.weight input_qspec_map[input_act1] = input_spec1 @@ -370,7 +256,10 @@ def annotate_matmul_input1(node: Node, is_qat: str): torch.ops.aten.transpose.int, torch.ops.aten.view.default, torch.ops.aten.reshape.default, + torch.ops.aten.select.int, torch.ops.aten.slice.Tensor, + torch.ops.aten.expand.default, + torch.ops.aten.unsqueeze.default, ]: annotate_single_in_single_out(node, quantization_config_8a8w) node = node.args[0] diff --git a/backends/qualcomm/quantizer/observers/concat_observer.py b/backends/qualcomm/quantizer/observers/concat_observer.py new file mode 100644 index 00000000000..cd2a1a99805 --- /dev/null +++ b/backends/qualcomm/quantizer/observers/concat_observer.py @@ -0,0 +1,74 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +from torchao.quantization.pt2e import UniformQuantizationObserverBase + + +class ConcatObserver(UniformQuantizationObserverBase): + """ + Fetch maximum data range of all tensors to be concatenated + """ + + def __init__( + self, + node_name, + graph, + dtype=torch.uint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, # noqa: B008 + is_dynamic=False, + **kwargs, + ) -> None: + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) + self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) + # get concat node and its inputs + self.concat_node = [node for node in graph.nodes if node.name == node_name][0] + self.input_nodes = self.concat_node.args[0] + self.input_observers = [] + + def forward(self, x_orig): + # calculate the min / max first + self.min_val = min(self.min_val, x_orig.min()) + self.max_val = max(self.max_val, x_orig.max()) + + if len(self.input_observers) == 0: + # collect observers first if they are not cached + # we cannot do this in constructor since observers have not appeared + for node in self.input_nodes: + obs_node = list( + filter(lambda user: user != self.concat_node, node.users.keys()) + )[0] + self.input_observers.append( + getattr(obs_node.graph.owning_module, obs_node.name) + ) + + # update min / max for all observers of input nodes + for observers in self.input_observers: + observers.min_val = self.min_val + observers.max_val = self.max_val + + return x_orig + + def calculate_qparams(self): + return self._calculate_qparams(self.min_val, self.max_val) diff --git a/backends/qualcomm/quantizer/observers/per_block_param_observer.py b/backends/qualcomm/quantizer/observers/per_block_param_observer.py index b3f854db527..13ab51008ed 100644 --- a/backends/qualcomm/quantizer/observers/per_block_param_observer.py +++ b/backends/qualcomm/quantizer/observers/per_block_param_observer.py @@ -7,12 +7,13 @@ from typing import Tuple import torch -from torchao.quantization.pt2e import MappingType, PerBlock +from torchao.quantization.pt2e import FakeQuantize, MappingType, PerBlock from torchao.quantization.pt2e._affine_quantization import ( _get_reduction_params, AffineQuantizedMinMaxObserver, choose_qparams_affine_with_min_max, ) +from torchao.quantization.quant_primitives import _fake_quantize_affine class PerBlockParamObserver(AffineQuantizedMinMaxObserver): @@ -89,3 +90,56 @@ def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: self.preserve_zero, self.zero_point_domain, ) + + +class PerBlockParamFakeQuantize(FakeQuantize): + def __init__( + self, + dtype: torch.dtype = torch.int8, + block_size: torch.Size = None, + quant_min: int = None, + quant_max: int = None, + eps: float = torch.finfo(torch.float32).eps, # noqa: B008 + **kwargs, + ): + super().__init__() + assert ( + block_size is not None + ), "block_size must be provided for per-block quantization" + + self.activation_post_process = PerBlockParamObserver( + dtype=dtype, + block_size=block_size, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + **kwargs, + ) + self.dtype = dtype + self.block_size = block_size + self.quant_min = quant_min if quant_min is not None else 
torch.iinfo(dtype).min + self.quant_max = quant_max if quant_max is not None else torch.iinfo(dtype).max + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.numel() == 0: + return x + + self.activation_post_process(x) + scale, zero_point = self.activation_post_process.calculate_qparams() + + return _fake_quantize_affine( + x, + self.block_size, + scale, + zero_point, + quant_dtype=self.dtype, + quant_min=self.quant_min, + quant_max=self.quant_max, + ) + + def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: + return self.activation_post_process.calculate_qparams() + + def convert(self, model, observer_node): + self.activation_post_process.convert(model, observer_node) diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py index 2f26cd27d31..77fb989ba44 100644 --- a/backends/qualcomm/quantizer/qconfig.py +++ b/backends/qualcomm/quantizer/qconfig.py @@ -10,6 +10,7 @@ import torch from executorch.backends.qualcomm.quantizer.observers.per_block_param_observer import ( + PerBlockParamFakeQuantize, PerBlockParamObserver, ) from torch import Tensor @@ -52,12 +53,37 @@ def _derive_bias_qparams_fn( act_scale, weight_scale ) derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32) - derived_zero = torch.zeros(derived_scale.size()).to(torch.int32) - if isinstance(weight_obs_or_fq, PerBlockParamObserver): + # TransposeConv per channel axis=1, and the weight_shape[1] = out_channel / groups. + # E.g., out_channel = 6, groups = 2, weight_shape[1] = 3, which means there are 3 pairs of scale/offset. + # However, bias still has 6 values, meaning it requires repeat_interleave 2 times derived_scale in order to + # generate 6 pairs of scale/offset to perform per channel quantization. For bias node, Conv OP builder will later + # only pass 3 pairs of scale/offset to QNN. 
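To make the grouped transposed-conv comment above concrete, here is a small worked example of the repeat_interleave step it describes (the numbers are made up; only torch is assumed):

    import torch

    out_channels, groups = 6, 2
    # Per-channel quantization on axis=1 yields out_channels / groups = 3 weight scales.
    weight_scale = torch.tensor([0.02, 0.03, 0.05])
    act_scale = torch.tensor(0.1)
    derived_scale = act_scale * weight_scale              # 3 derived bias scales
    # Bias has 6 elements, so each scale is repeated `groups` times before use.
    bias_scale = derived_scale.repeat_interleave(groups)
    print(bias_scale)  # tensor([0.0020, 0.0020, 0.0030, 0.0030, 0.0050, 0.0050])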
+ if ( + node.target + in { + torch.ops.aten.conv_transpose2d.input, + torch.ops.aten.conv_transpose3d.input, + } + and len(node.args) > 6 + and node.args[6] != 1 + ): + groups = node.args[6] + derived_scale = derived_scale.repeat_interleave(groups) + derived_zero = torch.zeros(derived_scale.size(), device=weight_zp.device).to( + torch.int32 + ) + + # Handle per-block quantization for both observer and fake quantize + weight_observer = weight_obs_or_fq + if isinstance(weight_obs_or_fq, PerBlockParamFakeQuantize): + # Extract the underlying observer from the fake quantize wrapper + weight_observer = weight_obs_or_fq.activation_post_process + + if isinstance(weight_observer, PerBlockParamObserver): # keep maximum scale of each channel for bias derived_scale = ( derived_scale.view(derived_scale.size(0), -1).amax(dim=-1) - / weight_obs_or_fq.num_steps + / weight_observer.num_steps ) derived_zero = derived_zero.view(derived_zero.size(0), -1).amax(dim=-1) return (derived_scale, derived_zero) @@ -66,7 +92,6 @@ def _derive_bias_qparams_fn( assert isinstance(input_act, Node) weight = node.args[1] assert isinstance(weight, Node) - return DerivedQuantizationSpec( derived_from=[(input_act, node), (weight, node)], derive_qparams_fn=_derive_bias_qparams_fn, @@ -81,7 +106,8 @@ def _derive_bias_qparams_fn( def get_8a8w_qnn_ptq_config( act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver ) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-12} + # the smallest scale: 0.0001 / 255 + extra_args: Dict[str, Any] = {"eps": 2**-21} act_quantization_spec = QuantizationSpec( dtype=torch.uint8, @@ -119,11 +145,68 @@ def get_8a8w_qnn_ptq_config( return quantization_config +def get_8a4w_qnn_ptq_config( + act_symmetric: bool = True, act_observer=MovingAverageMinMaxObserver +) -> QuantizationConfig: + # the smallest scale: 0.0001 / 255 + extra_args: Dict[str, Any] = {"eps": 2**-21} + + if act_symmetric: + # If zero_point is 128, htp can do optimizations. + # If we keep quant_min and quant_max none, observer will default use 128 as zero_point. + # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired. + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + else: + # PyTorch will remove redundant observers based on attributes such as: + # dtype, quant_min, quant_max, ch_axis, etc. + # Providing values like quant_min and quant_max can help observers compare + # and further reduce the number of observers. 
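For reference, the "smallest scale" comments attached to the new eps values in this file correspond to the arithmetic below; this is only a sanity check, not part of the patch (eps acts as a lower bound on the observer's computed scale):

    # 8-bit activations: target smallest scale ~ 0.0001 / 255
    print(0.0001 / 255, 2**-21)    # 3.92e-07 vs. 4.77e-07
    # 16-bit activations: target smallest scale ~ 0.0001 / 65535
    print(0.0001 / 65535, 2**-29)  # 1.53e-09 vs. 1.86e-09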
+ act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + quant_min=torch.iinfo(torch.uint8).min, + quant_max=torch.iinfo(torch.uint8).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-7, + quant_max=7, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + # 4 bits quantization only supports specific ops. def get_16a4w_qnn_ptq_config( act_observer=MovingAverageMinMaxObserver, ) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-20} + # the smallest scale: 0.0001 / 65535 + extra_args: Dict[str, Any] = {"eps": 2**-29} act_quantization_spec = QuantizationSpec( dtype=torch.int32, quant_min=torch.iinfo(torch.uint16).min, @@ -162,7 +245,8 @@ def get_16a4w_qnn_ptq_config( def get_16a8w_qnn_ptq_config( act_observer=MovingAverageMinMaxObserver, ) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-20} + # the smallest scale: 0.0001 / 65535 + extra_args: Dict[str, Any] = {"eps": 2**-29} act_quantization_spec = QuantizationSpec( dtype=torch.int32, quant_min=torch.iinfo(torch.uint16).min, @@ -199,13 +283,13 @@ def get_16a8w_qnn_ptq_config( def get_16a8w_qnn_qat_config( act_observer=MovingAverageMinMaxObserver, ) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-20} - act_fake_quant_ctr = FakeQuantize.with_args( + # the smallest scale: 0.0001 / 65535 + extra_args: Dict[str, Any] = {"eps": 2**-29} + act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( dtype=torch.int32, quant_min=torch.iinfo(torch.uint16).min, quant_max=torch.iinfo(torch.uint16).max, qscheme=torch.per_tensor_affine, - reduce_range=True, observer=act_observer.with_args(**extra_args), ) act_quantization_spec = QuantizationSpec( @@ -220,7 +304,6 @@ def get_16a8w_qnn_qat_config( quant_min=torch.iinfo(torch.int8).min + 1, quant_max=torch.iinfo(torch.int8).max, qscheme=torch.per_tensor_symmetric, - reduce_range=True, observer=MovingAverageMinMaxObserver, ) weight_quantization_spec = QuantizationSpec( @@ -258,7 +341,8 @@ def get_16a8w_qnn_qat_config( def get_16a16w_qnn_ptq_config( act_observer=MovingAverageMinMaxObserver, ) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-20} + # the smallest scale: 0.0001 / 65535 + extra_args: Dict[str, Any] = {"eps": 2**-29} act_quantization_spec = QuantizationSpec( dtype=torch.int32, quant_min=torch.iinfo(torch.uint16).min, @@ -300,8 +384,10 @@ def get_ptq_per_channel_quant_config( weight_dtype=torch.int8, act_observer=MovingAverageMinMaxObserver, act_symmetric: bool = False, + ch_axis: int = 0, ) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-12} + # the smallest scale: 0.0001 / 65535 + extra_args: Dict[str, Any] = {"eps": 2**-29} supported_act_types = { torch.uint8, @@ -349,7 +435,7 @@ def get_ptq_per_channel_quant_config( ), quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max, 
qscheme=torch.per_channel_symmetric, - ch_axis=0, + ch_axis=ch_axis, observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args), ) @@ -370,8 +456,10 @@ def get_ptq_per_block_quant_config( weight_dtype=torch.int8, act_observer=MovingAverageMinMaxObserver, act_symmetric: bool = False, + ch_axis: int = 0, ) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-12} + # the smallest scale: 0.0001 / 65535 + extra_args: Dict[str, Any] = {"eps": 2**-29} quantization_config = get_ptq_per_channel_quant_config( act_dtype=act_dtype, weight_dtype=weight_dtype, @@ -385,7 +473,7 @@ def get_ptq_per_block_quant_config( ), quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max, qscheme=torch.per_channel_symmetric, - ch_axis=0, + ch_axis=ch_axis, observer_or_fake_quant_ctr=PerBlockParamObserver.with_args(**extra_args), ) return QuantizationConfig( @@ -396,11 +484,97 @@ def get_ptq_per_block_quant_config( ) +def get_qat_per_block_quant_config( + act_dtype=torch.uint8, + weight_dtype=torch.int8, + act_observer=MovingAverageMinMaxObserver, + act_symmetric: bool = False, + ch_axis: int = 0, +) -> QuantizationConfig: + supported_act_types = { + torch.uint8, + torch.uint16, + torch.int8, + torch.int16, + } + supported_weight_dtypes = {torch.int4, torch.int8} + assert ( + act_dtype in supported_act_types + ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}" + + assert ( + weight_dtype in supported_weight_dtypes + ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}" + + # torch does not support uint16 quantization, use int32 to bypass + if act_symmetric: + # If zero_point is 128, htp can do optimizations. + # If we keep quant_min and quant_max none, observer will default use 128 as zero_point. + # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired. 
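Below is a hedged sketch of requesting the new QAT per-block config defined in this function directly; the import path simply mirrors the file location in this patch, and the block-size tuple is hypothetical:

    import torch
    from executorch.backends.qualcomm.quantizer.qconfig import (
        get_qat_per_block_quant_config,
    )

    cfg = get_qat_per_block_quant_config(
        act_dtype=torch.uint16,
        weight_dtype=torch.int4,
        ch_axis=0,
    )
    # As in the quantizer, block_size is attached to the config afterwards.
    cfg.block_size = (1, 64, 1, 1)  # hypothetical block shape for a conv weight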
+ act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( + dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, + qscheme=torch.per_tensor_symmetric, + observer=act_observer, + ) + act_quantization_spec = QuantizationSpec( + dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=act_fake_quant_ctr, + ) + else: + act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( + dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, + quant_min=torch.iinfo(act_dtype).min, + quant_max=torch.iinfo(act_dtype).max, + qscheme=torch.per_tensor_affine, + observer=act_observer, + ) + act_quantization_spec = QuantizationSpec( + dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, + quant_min=torch.iinfo(act_dtype).min, + quant_max=torch.iinfo(act_dtype).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_fake_quant_ctr, + ) + + weight_fake_quant_ctr = PerBlockParamFakeQuantize.with_args( + dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype, + quant_min=( + -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1 + ), + quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max, + qscheme=torch.per_channel_symmetric, + ch_axis=ch_axis, + ) + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype, + quant_min=( + -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1 + ), + quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max, + qscheme=torch.per_channel_symmetric, + ch_axis=ch_axis, + observer_or_fake_quant_ctr=weight_fake_quant_ctr, + ) + + bias_quantization_spec = _derived_bias_quant_spec + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + # TODO merge qat and ptq to a function, and use a bool flag to control it def get_8a8w_qnn_qat_config( act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver ) -> QuantizationConfig: - act_fake_quant_ctr = FakeQuantize.with_args( + act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( dtype=torch.uint8, qscheme=( torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine @@ -421,7 +595,6 @@ def get_8a8w_qnn_qat_config( quant_min=torch.iinfo(torch.int8).min + 1, quant_max=torch.iinfo(torch.int8).max, qscheme=torch.per_tensor_symmetric, - reduce_range=True, observer=MovingAverageMinMaxObserver, ) weight_quantization_spec = QuantizationSpec( @@ -438,7 +611,6 @@ def get_8a8w_qnn_qat_config( quant_min=torch.iinfo(torch.int32).min, quant_max=torch.iinfo(torch.int32).max, qscheme=torch.per_tensor_symmetric, - reduce_range=True, observer=MovingAverageMinMaxObserver, ) bias_quantization_spec = QuantizationSpec( @@ -462,12 +634,11 @@ def get_8a8w_qnn_qat_config( def get_16a4w_qnn_qat_config( act_observer=MovingAverageMinMaxObserver, ) -> QuantizationConfig: - act_fake_quant_ctr = FakeQuantize.with_args( + act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( dtype=torch.int32, quant_min=torch.iinfo(torch.uint16).min, quant_max=torch.iinfo(torch.uint16).max, qscheme=torch.per_tensor_affine, - reduce_range=True, observer=act_observer, ) act_quantization_spec = QuantizationSpec( @@ -484,7 +655,6 @@ def get_16a4w_qnn_qat_config( quant_max=7, 
qscheme=torch.per_tensor_symmetric, ch_axis=0, - reduce_range=True, observer=MovingAverageMinMaxObserver, ) weight_quantization_spec = QuantizationSpec( @@ -501,7 +671,6 @@ def get_16a4w_qnn_qat_config( quant_min=torch.iinfo(torch.int32).min, quant_max=torch.iinfo(torch.int32).max, qscheme=torch.per_tensor_symmetric, - reduce_range=True, observer=MovingAverageMinMaxObserver, ) bias_quantization_spec = QuantizationSpec( @@ -527,6 +696,7 @@ def get_qat_per_channel_quant_config( weight_dtype=torch.int8, act_observer=MovingAverageMinMaxObserver, act_symmetric=False, + ch_axis: int = 0, ) -> QuantizationConfig: supported_act_types = { torch.uint8, @@ -548,10 +718,9 @@ def get_qat_per_channel_quant_config( # If zero_point is 128, htp can do optimizations. # If we keep quant_min and quant_max none, observer will default use 128 as zero_point. # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired. - act_fake_quant_ctr = FakeQuantize.with_args( + act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, qscheme=torch.per_tensor_symmetric, - reduce_range=True, observer=act_observer, ) act_quantization_spec = QuantizationSpec( @@ -561,12 +730,11 @@ def get_qat_per_channel_quant_config( observer_or_fake_quant_ctr=act_fake_quant_ctr, ) else: - act_fake_quant_ctr = FakeQuantize.with_args( + act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, quant_min=torch.iinfo(act_dtype).min, quant_max=torch.iinfo(act_dtype).max, qscheme=torch.per_tensor_affine, - reduce_range=True, observer=act_observer, ) act_quantization_spec = QuantizationSpec( @@ -584,7 +752,7 @@ def get_qat_per_channel_quant_config( ), quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max, qscheme=torch.per_channel_symmetric, - ch_axis=0, + ch_axis=ch_axis, observer=MovingAveragePerChannelMinMaxObserver, ) weight_quantization_spec = QuantizationSpec( @@ -594,7 +762,7 @@ def get_qat_per_channel_quant_config( ), quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max, qscheme=torch.per_channel_symmetric, - ch_axis=0, + ch_axis=ch_axis, observer_or_fake_quant_ctr=weight_fake_quant_ctr, ) diff --git a/backends/qualcomm/quantizer/quant_recipe.py b/backends/qualcomm/quantizer/quant_recipe.py new file mode 100644 index 00000000000..92b9757e1fb --- /dev/null +++ b/backends/qualcomm/quantizer/quant_recipe.py @@ -0,0 +1,402 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import re +from abc import ABC, abstractmethod +from enum import IntEnum, unique +from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple + +import torch +from executorch.backends.qualcomm.quantizer.quantizer import ( + ModuleQConfig, + QnnQuantizer, + QuantDtype, + QuantizationConfig, +) +from tabulate import tabulate +from torch._ops import OpOverload +from torchao.quantization.pt2e import UniformQuantizationObserverBase + +from .annotators import OP_ANNOTATOR + + +def extract_node_metadata_mapping(node: torch.fx.Node): + deepest_module = None + + if node.op == "call_function" and "nn_module_stack" in node.meta: + deepest_module = list(node.meta["nn_module_stack"].values())[-1][0] + + return deepest_module + + +@unique +class QuantGranularity(IntEnum): + """ + Defines the quantization granularity levels: + - PER_TENSOR: single scale offset for entire tensor. + - PER_CHANNEL: independent scale/offset per channel within tensor. + - PER_BLOCK: independent scale/offset per block within tensor. + """ + + PER_TENSOR = 0 + PER_CHANNEL = 1 + PER_BLOCK = 2 + + +class QuantizationStrategy(ABC): + """ + Abstract base class for strategies that assign quantization config to FX graph nodes. + + Each strategy defines how to match nodes (e.g., by operator target, module stack pattern) + and provides a corresponding quantization config when a match occurs. + + Attributes: + quant_dtype (QuantDtype): Data type for quantization (e.g., 16a8w, 16a4w). + is_qat (bool): Whether the strategy applies QAT (True) or PTQ (False). + granularity (QuantGranularity): Quantization granularity (PER_TENSOR, PER_CHANNEL, PER_BLOCK). + act_observer (UniformQuantizationObserverBase): Observer class for activation quantization. + extra_kwargs (Dict): Additional configuration parameters (e.g., block size). + note (str): Developer notes or comments. + priority (int): Priority for resolving conflicts among multiple strategies. + + Abstract Methods: + _matches(node): Return True if the node matches this strategy's criteria. 
+ """ + + def __init__( + self, + quant_dtype: QuantDtype, + is_qat: bool, + granularity: QuantGranularity, + act_observer: UniformQuantizationObserverBase, + extra_kwargs: Dict, + note: str, + priority: int, + ): + self.quant_dtype = quant_dtype + self.is_qat = is_qat + self.granularity = granularity + self.act_observer = act_observer + self.extra_kwargs = extra_kwargs + self.note = note + self.priority = priority + + self.quant_config = ModuleQConfig( + quant_dtype=self.quant_dtype, + is_qat=self.is_qat, + is_conv_per_channel=True, + is_linear_per_channel=True, + act_observer=self.act_observer, + ) + + @abstractmethod + def _matches(self, node: torch.fx.Node) -> bool: + pass + + def get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]: + op: OpOverload = node.target + + if not self._matches(node): + return None + + if self.granularity == QuantGranularity.PER_TENSOR: + return self.quant_config.quant_config + elif self.granularity == QuantGranularity.PER_CHANNEL: + ch_axis = self.quant_config.use_per_channel_weight_quant_ops.get(op) + assert ( + ch_axis is not None + and len(self.quant_config.per_channel_quant_config_list) > ch_axis + ), f"Unsupported per channel quantization axis: {ch_axis}, please increase the range of per_channel_quant_config_list" + return self.quant_config.per_channel_quant_config_list[ch_axis] + elif self.granularity == QuantGranularity.PER_BLOCK: + ch_axis = self.quant_config.op_axis_dict.get(op) + assert ( + ch_axis is not None + and len(self.quant_config.per_block_quant_config_list) > ch_axis + ), f"Unsupported per block quantization axis: {ch_axis}, please increase the range of per_block_quant_config_list" + config = self.quant_config.per_block_quant_config_list[ch_axis] + config.block_size = self.extra_kwargs["block_size"] + return config + else: + raise ValueError( + f"Unsupported quantization granularity: {self.granularity}. " + f"Supported values: {[granularity.name for granularity in QuantGranularity]}" + ) + + +class ByNodeTarget(QuantizationStrategy): + """ + Strategy that assigns quantization config to nodes based on their op target. + Useful for applying quantization to specific operations such as `aten.conv2d` or `aten.linear`. + + Attributes: + targets (Set[OpOverload]): Set of op overloads to match against node targets. + """ + + def __init__( + self, + quant_dtype, + is_qat, + granularity, + act_observer, + extra_kwargs, + note, + priority, + targets: Set[OpOverload], + ): + super().__init__( + quant_dtype, + is_qat, + granularity, + act_observer, + extra_kwargs, + note, + priority, + ) + self.targets = targets + + def _matches(self, node: torch.fx.Node) -> bool: + # Matching: A node matches if its `node.target` is in the `targets` set. + return node.target in self.targets + + +class ByNameRegex(QuantizationStrategy): + """ + Strategy that assigns quantization config to nodes whose module stack matches given regex patterns. + Useful for targeting layers by name patterns (e.g., "layers.[0-3].feed_forward" or "layers.*.attention") in the module hierarchy. + + Attributes: + patterns (Set[str]): Set of regex patterns to match against module stack paths. 
+ """ + + def __init__( + self, + quant_dtype, + is_qat, + granularity, + act_observer, + extra_kwargs, + note, + priority, + patterns: Set[str], + ): + super().__init__( + quant_dtype, + is_qat, + granularity, + act_observer, + extra_kwargs, + note, + priority, + ) + self.patterns = patterns + + def _matches(self, node: torch.fx.Node) -> bool: + # Matching: A node matches if its `nn_module_stack` metadata contains a module path that matches any regex pattern. + if node.op == "call_function" and "nn_module_stack" in node.meta: + for module_stack, _ in list(node.meta["nn_module_stack"].values())[::-1]: + if module_stack and any( + re.search(p, module_stack) for p in self.patterns + ): + return True + return False + + +class QuantRecipe: + """ + A QuantRecipe builder for defining quantization strategies to an FX GraphModule. + + QuantRecipe manages a collection of strategies (e.g., by operator target or regex pattern) + and applies them to nodes in an FX graph to produce fine-grained quantization annotations. + + Attributes: + verbose (bool): If True, prints a summary after annotation. + custom_quant_annotations (Sequence[Callable]): Custom annotation functions applied after strategies. + + _strategies (List[QuantizationStrategy]): Registered quantization strategies. + _pending_annotate_nodes (Dict[torch.fx.Node, Tuple[QuantizationConfig, QuantizationStrategy]]): + Internal mapping of nodes to their resolved quantization config and strategy. + """ + + def __init__( + self, + quant_dtype, + is_qat, + act_observer: UniformQuantizationObserverBase, + granularity: QuantGranularity, + note: str = "", + extra_kwargs: Optional[dict] = None, + verbose: bool = False, + ): + """ + Initialize a QuantRecipe with a default quantization strategy. + + Args: + quant_dtype (QuantDtype): Data type for quantization (e.g., int8, int4). + is_qat (bool): Whether to apply QAT (True) or PTQ (False). + act_observer (UniformQuantizationObserverBase): Observer class for activation quantization. + granularity (QuantGranularity): Quantization granularity (PER_TENSOR, PER_CHANNEL, PER_BLOCK). + note (str): Optional description for the default strategy. + extra_kwargs (dict, optional): Additional parameters (e.g., block size, group size). + verbose (bool): If True, prints a summary table after annotation. + """ + + self.verbose = verbose + self.custom_quant_annotations: Sequence[Callable] = [] + + self._strategies: List[QuantizationStrategy] = [] + self._pending_annotate_nodes: Dict[ + torch.fx.Node, Tuple[QuantizationConfig, QuantizationStrategy] + ] = {} + self._default_strategy = ByNodeTarget( + quant_dtype, + is_qat, + granularity, + act_observer, + extra_kwargs or {}, + note, + priority=1, + targets=QnnQuantizer.SUPPORTED_OPS, + ) + + def _annotate_custom_annotation(self, gm: torch.fx.GraphModule) -> None: + for annotation_func in self.custom_quant_annotations: + annotation_func(gm) + + def annotate(self, graph_module: torch.fx.GraphModule): + # Sort node level strategies by (priority, insertion index). + # Higher priority value comes first; if priorities are equal, original insertion order is preserved. 
+ strategies: List[QuantizationStrategy] = [ + strategy + for _, strategy in sorted( + enumerate(self._strategies), + key=lambda x: (x[1].priority, x[0]), + reverse=True, + ) + ] + # Ensure the default strategy is appended last + strategies.append(self._default_strategy) + + for node in graph_module.graph.nodes: + for strategy in strategies: + if isinstance(node.target, str) or node in self._pending_annotate_nodes: + continue + + if quant_config := strategy.get_quant_config(node): + self._pending_annotate_nodes[node] = (quant_config, strategy) + + if self.verbose: + print(self.summary()) + + for node in graph_module.graph.nodes: + if isinstance(node.target, str): + continue + if node not in self._pending_annotate_nodes: + print(f"No quant config is implemented for op, {node.target}") + continue + + OP_ANNOTATOR[node.target](node, self._pending_annotate_nodes[node][0]) + + # custom annotation + self._annotate_custom_annotation(graph_module) + + def add_node_target( + self, + targets, + quant_dtype, + is_qat, + act_observer: UniformQuantizationObserverBase, + granularity: QuantGranularity, + note: str = "", + priority: int = 1, + extra_kwargs: Optional[dict] = None, + ): + self._strategies.append( + ByNodeTarget( + quant_dtype, + is_qat, + granularity, + act_observer, + extra_kwargs or {}, + note, + priority, + targets, + ), + ) + return self + + def add_regex( + self, + regex, + quant_dtype, + is_qat, + act_observer: UniformQuantizationObserverBase, + granularity: QuantGranularity, + note: str = "", + priority: int = 1, + extra_kwargs: Optional[dict] = None, + ): + """ + Add a quantization strategy targeting nodes whose module stack matches given regex patterns. + + Args: + regex (Iterable[str]): Regex patterns to match module stack paths. + quant_dtype (QuantDtype): Data type for quantization. + is_qat (bool): Whether to apply QAT or PTQ. + act_observer (UniformQuantizationObserverBase): Observer for activation quantization. + granularity (QuantGranularity): Tensor/channel/block granularity. + note (str): Optional description for the strategy. + priority (int): Strategy priority (higher value = higher precedence). + extra_kwargs (dict, optional): Additional parameters for the strategy. 
+ """ + self._strategies.append( + ByNameRegex( + quant_dtype, + is_qat, + granularity, + act_observer, + extra_kwargs or {}, + note, + priority, + regex, + ), + ) + return self + + def summary(self, max_rows: int = -1): + if not self._pending_annotate_nodes: + return None + + headers = [ + "module_stack", + "op_target", + "quantize", + "act_observer", + "granularity", + "note", + "extra_kwargs", + ] + rows = [] + for i, (node, (_, strategy)) in enumerate(self._pending_annotate_nodes.items()): + if max_rows > 0 and i >= max_rows: + break + + row = [ + extract_node_metadata_mapping(node), + node.target, + f"{strategy.quant_dtype.name}/{'QAT' if strategy.is_qat else 'PTQ'}", + strategy.act_observer.__name__, + strategy.granularity.name, + strategy.note, + strategy.extra_kwargs, + ] + rows.append(row) + + if max_rows > 0 and len(self._pending_annotate_nodes) > max_rows: + rows.append(["..."] * len(headers)) + + return tabulate(rows, headers=headers, tablefmt="grid") diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 5943b54d968..0d54b250bfd 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -24,10 +24,12 @@ get_16a4w_qnn_qat_config, get_16a8w_qnn_ptq_config, get_16a8w_qnn_qat_config, + get_8a4w_qnn_ptq_config, get_8a8w_qnn_ptq_config, get_8a8w_qnn_qat_config, get_ptq_per_block_quant_config, get_ptq_per_channel_quant_config, + get_qat_per_block_quant_config, get_qat_per_channel_quant_config, QuantizationConfig, ) @@ -44,6 +46,7 @@ "get_16a16w_qnn_ptq_config", "get_8a8w_qnn_ptq_config", "get_8a8w_qnn_qat_config", + "get_8a4w_qnn_ptq_config", "get_16a4w_qnn_qat_config", "get_ptq_per_block_quant_config", ] @@ -60,6 +63,7 @@ class QuantDtype(IntEnum): use_16a4w = 2 use_16a4w_block = 3 use_8a8w = 4 + use_8a4w = 5 QUANT_CONFIG_DICT = { @@ -109,6 +113,15 @@ class QuantDtype(IntEnum): partial(get_ptq_per_channel_quant_config), None, ), + (QuantDtype.use_8a4w, False): ( + get_8a4w_qnn_ptq_config, + partial( + get_ptq_per_channel_quant_config, + act_dtype=torch.uint8, + weight_dtype=torch.int4, + ), + None, + ), # QAT, (QuantDtype.use_16a4w, True): ( get_16a4w_qnn_qat_config, @@ -119,6 +132,19 @@ class QuantDtype(IntEnum): ), None, ), + (QuantDtype.use_16a4w_block, True): ( + get_16a4w_qnn_qat_config, + partial( + get_qat_per_channel_quant_config, + act_dtype=torch.uint16, + weight_dtype=torch.int4, + ), + partial( + get_qat_per_block_quant_config, + act_dtype=torch.uint16, + weight_dtype=torch.int4, + ), + ), (QuantDtype.use_8a8w, True): ( get_8a8w_qnn_qat_config, partial(get_qat_per_channel_quant_config), @@ -150,32 +176,62 @@ def __post_init__(self): if self.act_observer else quant_config_func() ) - self.per_channel_quant_config = ( - per_channel_quant_config_func(act_observer=self.act_observer) - if self.act_observer - else per_channel_quant_config_func() - ) - self.use_per_channel_weight_quant_ops = set() + + # Assume per_channel_quant/per_block_quant only happen on axis_0 or axis_1, increase the range if there's a need + potential_axis = 2 + + self.per_channel_quant_config_list = [] + for i in range(potential_axis): + self.per_channel_quant_config_list.append( + ( + per_channel_quant_config_func( + act_observer=self.act_observer, ch_axis=i + ) + if self.act_observer + else per_channel_quant_config_func(ch_axis=i) + ) + ) + + # Key is the node target, and value is the axis to perform per channel quantization + self.op_axis_dict = { + torch.ops.aten.conv1d.default: 0, + 
torch.ops.aten.conv2d.default: 0, + torch.ops.aten.conv3d.default: 0, + torch.ops.aten.conv_transpose2d.input: 1, + torch.ops.aten.conv_transpose3d.input: 1, + torch.ops.aten.linear.default: 0, + } + + self.use_per_channel_weight_quant_ops = {} if self.is_conv_per_channel: + conv_ops = [ + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.conv3d.default, + torch.ops.aten.conv_transpose2d.input, + torch.ops.aten.conv_transpose3d.input, + ] self.use_per_channel_weight_quant_ops.update( - { - torch.ops.aten.conv1d.default, - torch.ops.aten.conv2d.default, - torch.ops.aten.conv_transpose2d.input, - } + {k: self.op_axis_dict[k] for k in conv_ops if k in self.op_axis_dict} ) if self.is_linear_per_channel: + linear_ops = [torch.ops.aten.linear.default] self.use_per_channel_weight_quant_ops.update( - { - torch.ops.aten.linear.default, - } + {k: self.op_axis_dict[k] for k in linear_ops if k in self.op_axis_dict} ) + if per_block_quant_config_func: - self.per_block_quant_config = ( - per_block_quant_config_func(act_observer=self.act_observer) - if self.act_observer - else per_block_quant_config_func() - ) + self.per_block_quant_config_list = [] + for i in range(potential_axis): + self.per_block_quant_config_list.append( + ( + per_block_quant_config_func( + act_observer=self.act_observer, ch_axis=i + ) + if self.act_observer + else per_block_quant_config_func(ch_axis=i) + ) + ) class QnnQuantizer(Quantizer): @@ -212,10 +268,12 @@ def __init__(self): self.submodule_qconfig_list: List[ Tuple[Callable[[torch.fx.Node], bool], ModuleQConfig] ] = [] + self.block_size_map = {} self.custom_quant_annotations: Sequence[Callable] = [] self.discard_nodes: Set[str] = set() + self.recipe = None def _annotate(self, gm: GraphModule) -> None: """ @@ -268,16 +326,22 @@ def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig] op = node.target if isinstance(op, str): return - + config = self._get_submodule_qconfig(node) if block_size := self.block_size_map.get(node.name): - config = self.default_quant_config.per_block_quant_config + ch_axis = config.op_axis_dict.get(node.target, 0) + assert ( + len(config.per_block_quant_config_list) > ch_axis + ), f"Unsupported per block quantization axis: {ch_axis}, please increase the range of per_block_quant_config_list" + config = config.per_block_quant_config_list[ch_axis] config.block_size = block_size return config - config = self._get_submodule_qconfig(node) - if op in config.use_per_channel_weight_quant_ops: - return config.per_channel_quant_config + ch_axis = config.use_per_channel_weight_quant_ops[op] + assert ( + len(config.per_channel_quant_config_list) > ch_axis + ), f"Unsupported per channel quantization axis: {ch_axis}, please increase the range of per_channel_quant_config_list" + return config.per_channel_quant_config_list[ch_axis] if op in self.quant_ops: return config.quant_config @@ -312,14 +376,20 @@ def annotate(self, model: GraphModule) -> GraphModule: """ Annotates GraphModule during prepare_pt2e. + If a recipe is provided, it will be used to annotate the model. + Otherwise, fallback to the default annotation flow. + Args: model (GraphModule): The FX GraphModule to annotate. Returns: GraphModule: The annotated model. 
""" - self._annotate(model) - self._annotate_custom_annotation(model) + if self.recipe: + self.recipe.annotate(model) + else: + self._annotate(model) + self._annotate_custom_annotation(model) return model @@ -353,10 +423,10 @@ def set_default_quant_config( """ self.default_quant_config = ModuleQConfig( quant_dtype, - is_qat, - is_conv_per_channel, - is_linear_per_channel, - act_observer, + is_qat=is_qat, + is_conv_per_channel=is_conv_per_channel, + is_linear_per_channel=is_linear_per_channel, + act_observer=act_observer, ) def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None: diff --git a/backends/qualcomm/recipes/TARGETS b/backends/qualcomm/recipes/TARGETS index 12d1bac6f12..6a7abfd61c6 100644 --- a/backends/qualcomm/recipes/TARGETS +++ b/backends/qualcomm/recipes/TARGETS @@ -30,6 +30,7 @@ runtime.python_library( deps = [ "//caffe2:torch", "//executorch/export:lib", + "//executorch/runtime:runtime", # @manual "//executorch/backends/qualcomm/partition:partition", "//executorch/backends/qualcomm/serialization:serialization", "//executorch/backends/qualcomm/utils:utils", diff --git a/backends/qualcomm/runtime/QnnBackendOptions.cpp b/backends/qualcomm/runtime/QnnBackendOptions.cpp index 17e9975008d..1ce48cfcd61 100644 --- a/backends/qualcomm/runtime/QnnBackendOptions.cpp +++ b/backends/qualcomm/runtime/QnnBackendOptions.cpp @@ -21,12 +21,28 @@ T get_option(T aot_option) { executorch::runtime::BackendOption backend_option; if constexpr (std::is_same_v) { - backend_option = {QNN_RUNTIME_LOG_LEVEL, -1}; + std::strncpy( + backend_option.key, + QNN_RUNTIME_LOG_LEVEL, + runtime::kMaxOptionKeyLength); + backend_option.key[runtime::kMaxOptionKeyLength - 1] = '\0'; + backend_option.value = -1; } else if constexpr (std::is_same_v) { - backend_option = {QNN_RUNTIME_HTP_PERFORMANCE_MODE, -1}; + std::strncpy( + backend_option.key, + QNN_RUNTIME_HTP_PERFORMANCE_MODE, + runtime::kMaxOptionKeyLength); + backend_option.key[runtime::kMaxOptionKeyLength - 1] = '\0'; + backend_option.value = -1; } else if constexpr (std::is_same_v) { - backend_option = {QNN_RUNTIME_PROFILE_LEVEL, -1}; + std::strncpy( + backend_option.key, + QNN_RUNTIME_PROFILE_LEVEL, + runtime::kMaxOptionKeyLength); + backend_option.key[runtime::kMaxOptionKeyLength - 1] = '\0'; + backend_option.value = -1; } + // This will call get_option under runtime backend interface status = get_option(QNN_BACKEND, backend_option); diff --git a/backends/qualcomm/runtime/QnnExecuTorch.h b/backends/qualcomm/runtime/QnnExecuTorch.h index d8fbade3b3b..ccd02273c4f 100644 --- a/backends/qualcomm/runtime/QnnExecuTorch.h +++ b/backends/qualcomm/runtime/QnnExecuTorch.h @@ -69,10 +69,6 @@ void* QnnExecuTorchAllocCustomMem(size_t bytes, size_t alignment); /// handle to tensor wrapper during execution void QnnExecuTorchAddCustomMemTensorAddr(void* tensor_addr, void* custom_mem); -/// Add custom mem tensor info. Help to bring forward the memHandle creating -/// time from execution to initialization. -void QnnExecuTorchAddCustomMemTensorInfo(const CustomMemTensorInfo& info); - /// Free the allocated shared memory. 
void QnnExecuTorchFreeCustomMem(void* buffer_ptr); diff --git a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp index 988c4b84a68..41c2370e4cb 100644 --- a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp +++ b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp @@ -90,7 +90,11 @@ Result QnnExecuTorchBackend::init( } ET_CHECK_OR_RETURN_ERROR( - qnn_manager->Init() == Error::Ok, + qnn_manager->InitBackend() == Error::Ok, + Internal, + "Fail to initialize Qnn Manager"); + ET_CHECK_OR_RETURN_ERROR( + qnn_manager->InitContext() == Error::Ok, Internal, "Fail to initialize Qnn Manager"); diff --git a/backends/qualcomm/runtime/QnnManager.cpp b/backends/qualcomm/runtime/QnnManager.cpp index 5e3220f25d9..17dc6bf4e19 100644 --- a/backends/qualcomm/runtime/QnnManager.cpp +++ b/backends/qualcomm/runtime/QnnManager.cpp @@ -54,15 +54,9 @@ QnnManager::~QnnManager() { QnnManager::QnnManager( const QnnExecuTorchOptions* options, const QnnExecuTorchContextBinary& qnn_executorch_context_binary) - : qnn_context_blob_(qnn_executorch_context_binary), - qnn_loaded_backend_(""), - // options' life cycle is decided by compiler specs which is - // kept by executorch runtime framework - // please pay attention to any potential seg fault - options_(options) { + : qnn_context_blob_(qnn_executorch_context_binary), options_(options) { QnnExecuTorchBackendType backend_type = options->backend_options()->backend_type(); - std::string library_path = options->library_path()->str(); if (get_option(options_->log_level()) >= QnnExecuTorchLogLevel::kLogLevelInfo) { @@ -71,10 +65,8 @@ QnnManager::QnnManager( EnumNameQcomChipset(options_->soc_info()->soc_model())); QNN_EXECUTORCH_LOG_INFO( "backend_type: %s", EnumNameQnnExecuTorchBackendType(backend_type)); - for (auto name : *options_->graph_name()) { - QNN_EXECUTORCH_LOG_INFO("graph_name: %s", name->c_str()); - } - QNN_EXECUTORCH_LOG_INFO("library_path: %s", library_path.c_str()); + QNN_EXECUTORCH_LOG_INFO( + "library_path: %s", options->library_path()->str().c_str()); QNN_EXECUTORCH_LOG_INFO("dump intermediate outputs: %s", IsTensorDump()); QNN_EXECUTORCH_LOG_INFO( "log_level: %s", @@ -95,81 +87,13 @@ QnnManager::QnnManager( options_->op_package_options()->op_package_infos()->size()); } - if (library_path.empty()) { - switch (backend_type) { - case QnnExecuTorchBackendType::kHtpBackend: - library_path = htp_library_name_; - break; - case QnnExecuTorchBackendType::kDspBackend: - library_path = dsp_library_name_; - break; - case QnnExecuTorchBackendType::kGpuBackend: - library_path = gpu_library_name_; - break; - default: - QNN_EXECUTORCH_LOG_ERROR("Unknown backend type: %d", backend_type); - break; - } - } - qnn_loaded_backend_ = QnnImplementation(library_path); backend_params_ptr_ = std::make_unique(); + backend_bundle_ptr_ = std::make_shared(); qnn_dlc_manager_ = std::make_shared(qnn_context_blob_, options_); } -Error QnnManager::LoadQnnLibrary() { - auto config = GetImplementationConfig(); - Error ret = qnn_loaded_backend_.Load(config.get()); - return ret; -} - -Error QnnManager::PreRegisterMem() { - SharedBuffer& shared_buffer_manager = SharedBuffer::GetSharedBufferManager(); - for (const auto info : shared_buffer_manager.GetCustomMemTensorInfoSet()) { - void* unaligned_custom_mem_base = - shared_buffer_manager.GetUnAlignedAddr(info.custom_mem); - - size_t tensor_offset = (static_cast(info.custom_mem) - - static_cast(unaligned_custom_mem_base)) + - info.pos; - size_t total_custom_mem_size = - 
shared_buffer_manager.GetAllocatedSize(info.custom_mem); - - int32_t mem_fd = shared_buffer_manager.MemToFd(unaligned_custom_mem_base); - if (mem_fd == -1) { - QNN_EXECUTORCH_LOG_WARN( - "PreRegisterMem failed to get file descriptor.", - "custom_mem: %p", - "tensor_addr: %p", - "pos: %uz", - "tensor_bytes: %uz", - "shape: %p", - "rank: %zu", - "qnn_dtype: %X", - info.custom_mem, - info.tensor_addr, - info.pos, - info.tensor_bytes, - info.shape, - info.rank, - info.dtype); - return Error::Internal; - } - - ET_CHECK_OR_RETURN_ERROR( - backend_params_ptr_->qnn_mem_manager_ptr_->PreRegisterCustomMemHandle( - mem_fd, - unaligned_custom_mem_base, - total_custom_mem_size, - tensor_offset, - info) == Error::Ok, - Internal, - "Fail to register to shared memory."); - } - return Error::Ok; -} - Error QnnManager::RegisterMem( void* data_ptr, const std::shared_ptr& tensor_wrapper) { @@ -256,6 +180,9 @@ Error QnnManager::RegisterCustomMem( Qnn_MemHandle_t pre_registered_handle = backend_params_ptr_->qnn_mem_manager_ptr_->GetPreRegisteredHandle(info); + // If this memory block has already been registered, we can use it directly. + // This applies when running llama in lookahead mode with the same AR-N model + // handling both the prompt processor and the token generator. if (pre_registered_handle != nullptr) { if (get_option(options_->log_level()) >= QnnExecuTorchLogLevel::kLogLevelInfo) { @@ -268,15 +195,15 @@ Error QnnManager::RegisterCustomMem( } SharedBuffer& shared_buffer_manager = SharedBuffer::GetSharedBufferManager(); - void* unaligned_custom_mem_base = - shared_buffer_manager.GetUnAlignedAddr(custom_mem_base); - size_t tensor_offset = static_cast(custom_mem_base) - - static_cast(unaligned_custom_mem_base) + info.pos; + size_t tensor_offset = info.pos; size_t total_custom_mem_size = shared_buffer_manager.GetAllocatedSize(custom_mem_base); - int32_t mem_fd = shared_buffer_manager.MemToFd(unaligned_custom_mem_base); + int32_t mem_fd = shared_buffer_manager.MemToFd(custom_mem_base); + // Note: If obtaining the file descriptor fails, it may be due to memory not + // being released with QnnExecuTorchFreeCustomMem. In this situation, we could + // consider adding a map to monitor it. if (mem_fd == -1) { QNN_EXECUTORCH_LOG_WARN( "Tensor name %s failed to get file descriptor.", @@ -289,7 +216,6 @@ Error QnnManager::RegisterCustomMem( tensor_wrapper, mem_fd, data_ptr, - unaligned_custom_mem_base, total_custom_mem_size, tensor_offset, info) == Error::Ok, @@ -299,15 +225,20 @@ Error QnnManager::RegisterCustomMem( return Error::Ok; } -Error QnnManager::Init() { +Error QnnManager::InitBackend() { + // Get or create the shared backend bundle + Error err = QnnBackendUnifiedRegistry::GetInstance().GetOrCreateBackendBundle( + options_, backend_bundle_ptr_); ET_CHECK_OR_RETURN_ERROR( - LoadQnnLibrary() == Error::Ok, Internal, "Fail to load Qnn library"); - logger_ = std::make_unique( - qnn_loaded_backend_, LoggingCallback, get_option(options_->log_level())); - std::vector graph_names; - for (auto name : *options_->graph_name()) { - graph_names.emplace_back(name->str()); - } + err == Error::Ok, + Internal, + "Fail to get or create shared Qnn backend bundle. 
Error code: %d", + static_cast(err)); + return Error::Ok; +} + +Error QnnManager::InitContext( + std::optional> graph_names) { if (backend_params_ptr_->backend_init_state_ == BackendInitializeState::UNINITIALIZED) { QNN_EXECUTORCH_LOG_INFO( @@ -315,8 +246,9 @@ Error QnnManager::Init() { "parameters for Qnn executorch backend type %d", options_->backend_options()->backend_type()); backend_params_ptr_ = QnnBackendFactory().Create( - qnn_loaded_backend_, - logger_.get(), + backend_bundle_ptr_->implementation.get(), + backend_bundle_ptr_->qnn_backend_ptr.get(), + backend_bundle_ptr_->qnn_device_ptr.get(), qnn_context_blob_, options_, qnn_dlc_manager_.get()); @@ -324,20 +256,13 @@ Error QnnManager::Init() { backend_params_ptr_ != nullptr, Internal, "Failed to load Qnn backend."); + // Note: For online_prepare or deserialization, the graph name will be + // obtained from the binary. ET_CHECK_OR_RETURN_ERROR( - backend_params_ptr_->qnn_backend_cache_ptr_->Configure(graph_names) == - Error::Ok, + backend_params_ptr_->qnn_backend_cache_ptr_->Configure( + graph_names.value_or(std::vector{})) == Error::Ok, Internal, "Fail to configure Qnn backend cache"); - ET_CHECK_OR_RETURN_ERROR( - backend_params_ptr_->qnn_backend_ptr_->Configure( - options_->op_package_options()) == Error::Ok, - Internal, - "Fail to configure Qnn backend"); - ET_CHECK_OR_RETURN_ERROR( - backend_params_ptr_->qnn_device_ptr_->Configure() == Error::Ok, - Internal, - "Fail to configure Qnn device"); ET_CHECK_OR_RETURN_ERROR( backend_params_ptr_->qnn_context_ptr_->Configure() == Error::Ok, Internal, @@ -355,21 +280,16 @@ Error QnnManager::Init() { BackendInitializeState::INITIALIZED; } -#if defined(__aarch64__) - ET_CHECK_OR_RETURN_ERROR( - PreRegisterMem() == Error::Ok, - Internal, - "Fail to pre register custom memory handle"); -#endif - if (IsOnlinePrepare()) { + // Check whether the QNN version supports the DLC format. 
Qnn_ApiVersion_t qnn_version = {QNN_VERSION_INIT}; - qnn_loaded_backend_.GetQnnInterface().qnn_backend_get_api_version( - &qnn_version); + backend_bundle_ptr_->implementation->GetQnnInterface() + .qnn_backend_get_api_version(&qnn_version); ET_CHECK_OR_RETURN_ERROR( - qnn_dlc_manager_->SetUpDlcEnvironment(qnn_version.coreApiVersion) == - Error::Ok, + qnn_dlc_manager_->SetUpDlcEnvironment( + qnn_version.coreApiVersion, + graph_names.value_or(std::vector{})) == Error::Ok, Internal, "Fail to setup Dlc environment"); } @@ -514,13 +434,14 @@ Error QnnManager::ProfileExecuteData( } void QnnManager::Destroy() { - QNN_EXECUTORCH_LOG_INFO("Destroy Qnn backend parameters"); backend_params_ptr_.reset(new BackendConfigParameters()); - qnn_dlc_manager_->ResetBackendParams(); - logger_.reset(); - qnn_dlc_manager_->ResetLogger(); - qnn_loaded_backend_.TerminateAllBackends(); - qnn_dlc_manager_->TerminateAllBackends(); + backend_bundle_ptr_.reset(new QnnBackendBundle()); + qnn_dlc_manager_->Destroy(); +} + +void QnnManager::DestroyContext() { + backend_params_ptr_.reset(new BackendConfigParameters()); + qnn_dlc_manager_->Destroy(); } bool QnnManager::IsNodeSupportedByBackend( @@ -540,7 +461,7 @@ bool QnnManager::IsNodeSupportedByBackend( } } - error = backend_params_ptr_->qnn_backend_ptr_->BackendValidateOpConfig( + error = backend_bundle_ptr_->qnn_backend_ptr->BackendValidateOpConfig( op_wrapper->GetOpConfig()); if (error != QNN_SUCCESS) { QNN_EXECUTORCH_LOG_WARN( @@ -697,8 +618,3 @@ void QnnExecuTorchAddCustomMemTensorAddr(void* tensor_addr, void* custom_mem) { executorch::backends::qnn::SharedBuffer::GetSharedBufferManager() .AddCusomMemTensorAddr(tensor_addr, custom_mem); } - -void QnnExecuTorchAddCustomMemTensorInfo(const CustomMemTensorInfo& info) { - executorch::backends::qnn::SharedBuffer::GetSharedBufferManager() - .AddCusomMemTensorInfo(info); -} diff --git a/backends/qualcomm/runtime/QnnManager.h b/backends/qualcomm/runtime/QnnManager.h index c01a537f7bd..866a4edbad6 100644 --- a/backends/qualcomm/runtime/QnnManager.h +++ b/backends/qualcomm/runtime/QnnManager.h @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -30,7 +31,13 @@ class QnnManager { const QnnExecuTorchContextBinary& qnn_executorch_context_binary); ~QnnManager(); - executorch::runtime::Error Init(); + // Initialize the shared backend bundle such as QnnBackend and QnnDevice + executorch::runtime::Error InitBackend(); + // Initialize the non-shared QNN components, create the QnnGraph using the + // provided graph_names. Note: For online_prepare or deserialization, the + // graph name will be obtained from the binary. 
+ executorch::runtime::Error InitContext( + std::optional> graph_names = std::nullopt); executorch::runtime::Error AllocateTensor(const std::string& graph_name); executorch::runtime::Error AllocateTensor( const std::string& graph_name, @@ -47,7 +54,11 @@ class QnnManager { const std::string& graph_name, executorch::runtime::EventTracer* event_tracer); + // Destroy all QNN components and decrease reference count of shared QNN + // resource void Destroy(); + // Only destroy all non-shared QNN components + void DestroyContext(); bool IsAvailable() { return true; @@ -103,35 +114,11 @@ class QnnManager { return backend_params_ptr_->qnn_context_ptr_->GetGraphNames(); } - std::string GetBinarySignature(); - private: - std::unique_ptr GetImplementationConfig() { - if (options_->saver()) { - auto outputDirCfg = std::make_unique(); - outputDirCfg->option = QNN_SAVER_CONFIG_OPTION_OUTPUT_DIRECTORY; - outputDirCfg->outputDirectory = options_->saver_output_dir()->c_str(); - - auto saverCfg = std::make_unique(2); - saverCfg[0] = outputDirCfg.release(); - saverCfg[1] = nullptr; - - return saverCfg; - } else { - return nullptr; - } - } - - executorch::runtime::Error LoadQnnLibrary(); - - static constexpr const char* htp_library_name_ = "libQnnHtp.so"; - static constexpr const char* gpu_library_name_ = "libQnnGpu.so"; - static constexpr const char* dsp_library_name_ = "libQnnDsp.so"; - QnnExecuTorchContextBinary qnn_context_blob_; std::unique_ptr backend_params_ptr_; - QnnImplementation qnn_loaded_backend_; - std::unique_ptr logger_; + std::shared_ptr + backend_bundle_ptr_; // New member to hold shared resources const QnnExecuTorchOptions* options_; std::unordered_map>> input_tensors_; diff --git a/backends/qualcomm/runtime/SharedBuffer.cpp b/backends/qualcomm/runtime/SharedBuffer.cpp index 99dee7c9a7b..d79f8041932 100644 --- a/backends/qualcomm/runtime/SharedBuffer.cpp +++ b/backends/qualcomm/runtime/SharedBuffer.cpp @@ -69,14 +69,6 @@ void* SharedBuffer::GetCustomMemBase(void* buf) { return it->second; } -void* SharedBuffer::GetUnAlignedAddr(void* buf) { - auto it = restore_map_.find(buf); - if (it == restore_map_.end()) { - return nullptr; - } - return it->second; -} - size_t SharedBuffer::GetAllocatedSize(void* buf) { auto it = allocated_size_map_.find(buf); if (it == allocated_size_map_.end()) { @@ -123,10 +115,10 @@ void* SharedBuffer::AllocMem(size_t bytes, size_t alignment) { QNN_EXECUTORCH_LOG_WARN("Failed to allocate the tensor by RPC memory."); return nullptr; } - allocated_size_map_.insert({buf, allocate_bytes}); auto aligned_buf = reinterpret_cast( alignTo(alignment, reinterpret_cast(buf))); bool status = restore_map_.insert({aligned_buf, buf}).second; + allocated_size_map_.insert({aligned_buf, allocate_bytes}); if (!status) { QNN_EXECUTORCH_LOG_ERROR("Failed to allocate the tensor by RPC memory."); rpc_mem_free_(buf); @@ -152,6 +144,15 @@ void SharedBuffer::FreeMem(void* buf) { } else { rpc_mem_free_(restore_map_[buf]); restore_map_.erase(buf); + allocated_size_map_.erase(buf); + // Unbind the custom memory from tensor address. 
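The new `QnnManager` surface above splits the lifecycle into a shared half (`InitBackend` / `Destroy`) and a per-context half (`InitContext` / `DestroyContext`), as documented in the header comments. A rough sketch of the implied call ordering; the stub class and its bookkeeping are illustrative only, not the real implementation:

```cpp
#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Stub that mimics only the two-phase lifecycle; no QNN calls are made.
struct ManagerSketch {
  bool backend_ready = false;  // shared bundle: implementation/backend/device
  bool context_ready = false;  // non-shared: cache/context/graph/mem manager

  void InitBackend() { backend_ready = true; }
  void InitContext(
      std::optional<std::vector<std::string>> graph_names = std::nullopt) {
    (void)graph_names;         // empty -> names come from the context binary
    context_ready = backend_ready;
  }
  void DestroyContext() { context_ready = false; }             // keep backend
  void Destroy() { context_ready = false; backend_ready = false; }
};

int main() {
  ManagerSketch m;
  m.InitBackend();                          // acquire shared backend bundle
  m.InitContext(std::vector<std::string>{"forward"});
  m.DestroyContext();                       // drop only per-context state
  m.InitContext();                          // reuse the still-live backend
  m.Destroy();                              // release everything
  std::cout << std::boolalpha << m.backend_ready << '\n';  // false
  return 0;
}
```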
+ auto mit = custom_mem_to_tensor_addr_.find(buf); + if (mit != custom_mem_to_tensor_addr_.end()) { + for (auto it = mit->second.begin(); it != mit->second.end(); ++it) { + tensor_addr_to_custom_mem_.erase(*it); + } + custom_mem_to_tensor_addr_.erase(buf); + } } } @@ -185,14 +186,18 @@ Error SharedBuffer::Load() { } void SharedBuffer::AddCusomMemTensorAddr(void* tensor_addr, void* custom_mem) { - tensor_addr_to_custom_mem_.insert({tensor_addr, custom_mem}); + bool status = + tensor_addr_to_custom_mem_.insert({tensor_addr, custom_mem}).second; + if (!status) { + QNN_EXECUTORCH_LOG_WARN( + "Tensor address %p already associated with custom memory %p", + tensor_addr, + custom_mem); + return; + } + custom_mem_to_tensor_addr_[custom_mem].insert(tensor_addr); }; -void SharedBuffer::AddCusomMemTensorInfo(const CustomMemTensorInfo& info) { - custom_mem_tensor_info_set_.insert(info); - tensor_addr_to_custom_mem_.insert({info.tensor_addr, info.custom_mem}); -} - Error SharedBuffer::UnLoad() { if (dlclose(lib_cdsp_rpc_) != 0) { QNN_EXECUTORCH_LOG_ERROR( diff --git a/backends/qualcomm/runtime/SharedBuffer.h b/backends/qualcomm/runtime/SharedBuffer.h index a02ea0e4c25..6bf06a6350b 100644 --- a/backends/qualcomm/runtime/SharedBuffer.h +++ b/backends/qualcomm/runtime/SharedBuffer.h @@ -59,19 +59,10 @@ class SharedBuffer final { // memory handle is registered during execution void AddCusomMemTensorAddr(void* tensor_addr, void* custom_mem); - // memory handle can be registered before execution - void AddCusomMemTensorInfo(const CustomMemTensorInfo& info); - size_t GetAllocatedSize(void* buf); void* GetCustomMemBase(void* buf); - void* GetUnAlignedAddr(void* buf); - - const std::unordered_set& GetCustomMemTensorInfoSet() { - return custom_mem_tensor_info_set_; - }; - private: SharedBuffer() = default; @@ -93,7 +84,10 @@ class SharedBuffer final { std::unordered_map allocated_size_map_; // Maps for the custom memory std::unordered_map tensor_addr_to_custom_mem_; - std::unordered_set custom_mem_tensor_info_set_; + // After the custom memory is freed, we will ensure that no tensor addresses + // remain linked to this custom memory. 
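The SharedBuffer changes above key `allocated_size_map_` by the aligned pointer and add a reverse index (declared just below) from a custom-memory buffer to every tensor address bound to it, so `FreeMem` can drop all stale bindings when the buffer goes away. A minimal sketch of that forward-map / reverse-index bookkeeping with a stand-in class:

```cpp
#include <iostream>
#include <unordered_map>
#include <unordered_set>

// Stand-in registry: Bind() mirrors AddCusomMemTensorAddr, OnFree() mirrors
// the cleanup added to FreeMem. Not the real SharedBuffer.
struct BindingRegistry {
  std::unordered_map<void*, void*> tensor_addr_to_custom_mem;
  std::unordered_map<void*, std::unordered_set<void*>> custom_mem_to_tensor_addr;

  void Bind(void* tensor_addr, void* custom_mem) {
    if (!tensor_addr_to_custom_mem.insert({tensor_addr, custom_mem}).second) {
      return;  // already bound; the real code logs a warning and bails out
    }
    custom_mem_to_tensor_addr[custom_mem].insert(tensor_addr);
  }

  // On free, erase every tensor address still pointing at this buffer so no
  // lookup can return a dangling custom-memory base.
  void OnFree(void* custom_mem) {
    auto it = custom_mem_to_tensor_addr.find(custom_mem);
    if (it == custom_mem_to_tensor_addr.end()) {
      return;
    }
    for (void* tensor_addr : it->second) {
      tensor_addr_to_custom_mem.erase(tensor_addr);
    }
    custom_mem_to_tensor_addr.erase(it);
  }
};

int main() {
  int buffer = 0, t0 = 0, t1 = 0;
  BindingRegistry registry;
  registry.Bind(&t0, &buffer);
  registry.Bind(&t1, &buffer);
  registry.OnFree(&buffer);
  std::cout << registry.tensor_addr_to_custom_mem.size() << '\n';  // prints 0
  return 0;
}
```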
+ std::unordered_map> + custom_mem_to_tensor_addr_; std::atomic_bool initialize_{false}; static std::mutex init_mutex_; }; diff --git a/backends/qualcomm/runtime/backends/CMakeLists.txt b/backends/qualcomm/runtime/backends/CMakeLists.txt index 6a44f3234c5..d0f025bfbaa 100644 --- a/backends/qualcomm/runtime/backends/CMakeLists.txt +++ b/backends/qualcomm/runtime/backends/CMakeLists.txt @@ -43,58 +43,70 @@ target_sources( ${CMAKE_CURRENT_LIST_DIR}/QnnProfiler.cpp ) -# qnn_device -set(HOST_ARCHITECTURE - ${CMAKE_CURRENT_LIST_DIR}/htpbackend/${CMAKE_SYSTEM_PROCESSOR} +set(HOST_ARCHITECTURE_GPU + ${CMAKE_CURRENT_LIST_DIR}/gpu/${CMAKE_SYSTEM_PROCESSOR} +) +set(HOST_ARCHITECTURE_HTP + ${CMAKE_CURRENT_LIST_DIR}/htp/${CMAKE_SYSTEM_PROCESSOR} ) +set(HOST_ARCHITECTURE_IR ${CMAKE_CURRENT_LIST_DIR}/ir/${CMAKE_SYSTEM_PROCESSOR}) +# qnn_device target_sources( qnn_device PUBLIC ${CMAKE_CURRENT_LIST_DIR}/QnnDeviceCommon.h - ${CMAKE_CURRENT_LIST_DIR}/htpbackend/HtpDevice.h + ${CMAKE_CURRENT_LIST_DIR}/gpu/GpuDevice.h + ${CMAKE_CURRENT_LIST_DIR}/htp/HtpDevice.h PRIVATE ${CMAKE_CURRENT_LIST_DIR}/QnnDeviceCommon.cpp - ${CMAKE_CURRENT_LIST_DIR}/htpbackend/HtpDevice.cpp - ${CMAKE_CURRENT_LIST_DIR}/htpbackend/HtpDevicePlatformInfoConfig.h - ${CMAKE_CURRENT_LIST_DIR}/htpbackend/HtpDeviceCustomConfig.h + ${CMAKE_CURRENT_LIST_DIR}/htp/HtpDevice.cpp + ${CMAKE_CURRENT_LIST_DIR}/htp/HtpDevicePlatformInfoConfig.h + ${CMAKE_CURRENT_LIST_DIR}/htp/HtpDeviceCustomConfig.h # When offline prepare context cache in x86 host we have to provide # platform infomation and SocModel to Qnn - ${HOST_ARCHITECTURE}/HtpDevicePlatformInfoConfig.cpp - ${HOST_ARCHITECTURE}/HtpDeviceCustomConfig.cpp + ${HOST_ARCHITECTURE_HTP}/HtpDevicePlatformInfoConfig.cpp + ${HOST_ARCHITECTURE_HTP}/HtpDeviceCustomConfig.cpp ) # qnn_context target_sources( qnn_context PUBLIC ${CMAKE_CURRENT_LIST_DIR}/QnnContextCommon.h - ${CMAKE_CURRENT_LIST_DIR}/htpbackend/HtpContext.h - ${CMAKE_CURRENT_LIST_DIR}/irbackend/IrContext.h - PRIVATE - ${CMAKE_CURRENT_LIST_DIR}/QnnContextCommon.cpp - ${CMAKE_CURRENT_LIST_DIR}/htpbackend/HtpContext.cpp - ${CMAKE_CURRENT_LIST_DIR}/htpbackend/HtpContextCustomConfig.h - ${HOST_ARCHITECTURE}/HtpContextCustomConfig.cpp - ${CMAKE_CURRENT_LIST_DIR}/irbackend/${CMAKE_SYSTEM_PROCESSOR}/IrContext.cpp + ${CMAKE_CURRENT_LIST_DIR}/htp/HtpContext.h + ${CMAKE_CURRENT_LIST_DIR}/ir/IrContext.h + ${CMAKE_CURRENT_LIST_DIR}/gpu/GpuContext.h + PRIVATE ${CMAKE_CURRENT_LIST_DIR}/QnnContextCommon.cpp + ${CMAKE_CURRENT_LIST_DIR}/htp/HtpContext.cpp + ${CMAKE_CURRENT_LIST_DIR}/htp/HtpContextCustomConfig.h + ${CMAKE_CURRENT_LIST_DIR}/gpu/GpuContext.cpp + ${CMAKE_CURRENT_LIST_DIR}/gpu/GpuContextCustomConfig.h + ${HOST_ARCHITECTURE_GPU}/GpuContextCustomConfig.cpp + ${HOST_ARCHITECTURE_HTP}/HtpContextCustomConfig.cpp + ${HOST_ARCHITECTURE_IR}/IrContext.cpp ) # qnn_backend_cache target_sources( qnn_backend_cache PUBLIC ${CMAKE_CURRENT_LIST_DIR}/QnnBackendCache.h - ${CMAKE_CURRENT_LIST_DIR}/htpbackend/HtpBackendCache.h + ${CMAKE_CURRENT_LIST_DIR}/htp/HtpBackendCache.h PRIVATE ${CMAKE_CURRENT_LIST_DIR}/QnnBackendCache.cpp - ${CMAKE_CURRENT_LIST_DIR}/htpbackend/HtpBackendCache.cpp + ${CMAKE_CURRENT_LIST_DIR}/htp/HtpBackendCache.cpp ) # qnn_graph target_sources( qnn_graph PUBLIC ${CMAKE_CURRENT_LIST_DIR}/QnnGraphCommon.h - ${CMAKE_CURRENT_LIST_DIR}/htpbackend/HtpGraph.h + ${CMAKE_CURRENT_LIST_DIR}/gpu/GpuGraph.h + ${CMAKE_CURRENT_LIST_DIR}/htp/HtpGraph.h PRIVATE ${CMAKE_CURRENT_LIST_DIR}/QnnGraphCommon.cpp - ${CMAKE_CURRENT_LIST_DIR}/htpbackend/HtpGraph.cpp - 
${CMAKE_CURRENT_LIST_DIR}/htpbackend/HtpGraphCustomConfig.h - ${CMAKE_CURRENT_LIST_DIR}/htpbackend/HtpGraphCustomConfig.cpp - ${HOST_ARCHITECTURE}/HtpGraphCustomConfig.cpp + ${CMAKE_CURRENT_LIST_DIR}/gpu/GpuGraph.cpp + ${CMAKE_CURRENT_LIST_DIR}/gpu/GpuGraphCustomConfig.h + ${CMAKE_CURRENT_LIST_DIR}/gpu/GpuGraphCustomConfig.cpp + ${CMAKE_CURRENT_LIST_DIR}/htp/HtpGraph.cpp + ${CMAKE_CURRENT_LIST_DIR}/htp/HtpGraphCustomConfig.h + ${CMAKE_CURRENT_LIST_DIR}/htp/HtpGraphCustomConfig.cpp + ${HOST_ARCHITECTURE_HTP}/HtpGraphCustomConfig.cpp ) # qnn_op_package_manager @@ -108,9 +120,13 @@ target_sources( target_sources( qnn_backend PUBLIC ${CMAKE_CURRENT_LIST_DIR}/QnnBackendCommon.h - ${CMAKE_CURRENT_LIST_DIR}/htpbackend/HtpBackend.h - ${CMAKE_CURRENT_LIST_DIR}/irbackend/IrBackend.h + ${CMAKE_CURRENT_LIST_DIR}/gpu/GpuBackend.h + ${CMAKE_CURRENT_LIST_DIR}/htp/HtpBackend.h + ${CMAKE_CURRENT_LIST_DIR}/ir/IrBackend.h PRIVATE ${CMAKE_CURRENT_LIST_DIR}/QnnBackendCommon.cpp + ${CMAKE_CURRENT_LIST_DIR}/gpu/GpuBackend.cpp + ${CMAKE_CURRENT_LIST_DIR}/gpu/GpuBackendCustomConfig.h + ${CMAKE_CURRENT_LIST_DIR}/gpu/GpuBackendCustomConfig.cpp ) # qnn_mem_manager @@ -138,6 +154,12 @@ target_sources( target_sources( qnn_dlc_manager PUBLIC ${CMAKE_CURRENT_LIST_DIR}/QnnDlcManager.h - PRIVATE - ${CMAKE_CURRENT_LIST_DIR}/irbackend/${CMAKE_SYSTEM_PROCESSOR}/QnnDlcManager.cpp + PRIVATE ${HOST_ARCHITECTURE_IR}/QnnDlcManager.cpp +) + +# qnn_backend_unified_registry +target_sources( + qnn_backend_unified_registry + PUBLIC ${CMAKE_CURRENT_LIST_DIR}/QnnBackendUnifiedRegistry.h + PRIVATE ${CMAKE_CURRENT_LIST_DIR}/QnnBackendUnifiedRegistry.cpp ) diff --git a/backends/qualcomm/runtime/backends/QnnBackendCommon.cpp b/backends/qualcomm/runtime/backends/QnnBackendCommon.cpp index 960bbd9513e..81ec3ebde26 100644 --- a/backends/qualcomm/runtime/backends/QnnBackendCommon.cpp +++ b/backends/qualcomm/runtime/backends/QnnBackendCommon.cpp @@ -13,7 +13,7 @@ namespace qnn { using executorch::runtime::Error; QnnBackend::~QnnBackend() { - const QnnInterface& qnn_interface = implementation_.GetQnnInterface(); + const QnnInterface& qnn_interface = implementation_->GetQnnInterface(); Qnn_ErrorHandle_t error = QNN_SUCCESS; if (nullptr != handle_) { QNN_EXECUTORCH_LOG_INFO("Destroy Qnn backend"); @@ -34,7 +34,7 @@ void QnnBackend::BackendRegisterOpPackage( const flatbuffers::Vector< flatbuffers::Offset>* op_packages_infos) { - const QnnInterface& qnn_interface = implementation_.GetQnnInterface(); + const QnnInterface& qnn_interface = implementation_->GetQnnInterface(); Qnn_ErrorHandle_t error = QNN_SUCCESS; QnnExecuTorchOpPackagePlatform current_platform = QnnExecuTorchOpPackagePlatform::UNKNOWN; @@ -71,7 +71,7 @@ void QnnBackend::BackendRegisterOpPackage( Error QnnBackend::Configure( const QnnExecuTorchOpPackageOptions* op_package_options) { // create qnn backend - const QnnInterface& qnn_interface = implementation_.GetQnnInterface(); + const QnnInterface& qnn_interface = implementation_->GetQnnInterface(); Qnn_ErrorHandle_t error = QNN_SUCCESS; std::vector temp_backend_config; @@ -102,7 +102,7 @@ Error QnnBackend::Configure( } Error QnnBackend::VerifyQNNSDKVersion() { - const QnnInterface& qnn_interface = implementation_.GetQnnInterface(); + const QnnInterface& qnn_interface = implementation_->GetQnnInterface(); Qnn_ApiVersion_t qnn_version = {QNN_VERSION_INIT}; Qnn_ErrorHandle_t error = diff --git a/backends/qualcomm/runtime/backends/QnnBackendCommon.h b/backends/qualcomm/runtime/backends/QnnBackendCommon.h index a66119dab22..e146a67d772 100644 
--- a/backends/qualcomm/runtime/backends/QnnBackendCommon.h +++ b/backends/qualcomm/runtime/backends/QnnBackendCommon.h @@ -27,10 +27,11 @@ namespace qnn { // qnn backend class QnnBackend { public: - explicit QnnBackend( - const QnnImplementation& implementation, - QnnLogger* logger) + explicit QnnBackend(QnnImplementation* implementation, QnnLogger* logger) : handle_(nullptr), implementation_(implementation), logger_(logger) {} + QnnBackend(const QnnBackend&) = delete; // Delete copy constructor + QnnBackend& operator=(const QnnBackend&) = + delete; // Delete assignment operator virtual ~QnnBackend(); virtual bool IsProfileEventTypeParentOfNodeTime( @@ -42,7 +43,7 @@ class QnnBackend { const QnnExecuTorchOpPackageOptions* op_package_options); Qnn_ErrorHandle_t BackendValidateOpConfig(const Qnn_OpConfig_t& op_config) { - return implementation_.GetQnnInterface().qnn_backend_validate_op_config( + return implementation_->GetQnnInterface().qnn_backend_validate_op_config( handle_, op_config); }; @@ -65,7 +66,7 @@ class QnnBackend { flatbuffers::Offset>* op_packages_info); Qnn_BackendHandle_t handle_; - const QnnImplementation& implementation_; + QnnImplementation* implementation_; QnnOpPackageManager op_package_manager_; QnnLogger* logger_; executorch::runtime::Error VersionChecker( diff --git a/backends/qualcomm/runtime/backends/QnnBackendFactory.cpp b/backends/qualcomm/runtime/backends/QnnBackendFactory.cpp index e7e9db6fed8..9c559d83fcc 100644 --- a/backends/qualcomm/runtime/backends/QnnBackendFactory.cpp +++ b/backends/qualcomm/runtime/backends/QnnBackendFactory.cpp @@ -16,8 +16,9 @@ namespace qnn { using executorch::runtime::Error; std::unique_ptr QnnBackendFactory::Create( - const QnnImplementation& implementation, - QnnLogger* logger, + QnnImplementation* implementation_ptr, + QnnBackend* qnn_backend_ptr, + QnnDevice* qnn_device_ptr, const QnnExecuTorchContextBinary& qnn_context_blob, const QnnExecuTorchOptions* options, QnnDlcManager* qnn_dlc_manager) { @@ -26,15 +27,8 @@ std::unique_ptr QnnBackendFactory::Create( switch (options->backend_options()->backend_type()) { case QnnExecuTorchBackendType::kHtpBackend: { auto htp_options = options->backend_options()->htp_options(); - const std::string skel_library_dir = - htp_options->skel_library_dir()->str(); - if (!skel_library_dir.empty()) { - setenv("ADSP_LIBRARY_PATH", skel_library_dir.c_str(), /*overwrite=*/1); - } if (get_option(options->log_level()) >= QnnExecuTorchLogLevel::kLogLevelInfo) { - QNN_EXECUTORCH_LOG_INFO( - "skel_library_dir: %s", skel_library_dir.c_str()); QNN_EXECUTORCH_LOG_INFO( "htp_arch in htp_info: %s", EnumNameHtpArch(options->soc_info()->htp_info()->htp_arch())); @@ -53,51 +47,91 @@ std::unique_ptr QnnBackendFactory::Create( EnumNameQnnExecuTorchHtpPdSession(htp_options->pd_session())); QNN_EXECUTORCH_LOG_INFO( "use_conv_hmx in htp_options: %d", htp_options->use_conv_hmx()); + QNN_EXECUTORCH_LOG_INFO( + "use_dlbc in htp_options: %d", htp_options->use_dlbc()); QNN_EXECUTORCH_LOG_INFO( "use_fold_relu in htp_options: %d", htp_options->use_fold_relu()); + QNN_EXECUTORCH_LOG_INFO( + "use_multi_contexts in htp_options: %d", + htp_options->use_multi_contexts()); + QNN_EXECUTORCH_LOG_INFO( + "use_weight_sharing in htp_options: %d", + htp_options->use_weight_sharing()); } - backend_params->qnn_backend_ptr_ = - std::make_unique(implementation, logger); - - backend_params->qnn_device_ptr_ = std::make_unique( - implementation, logger, options->soc_info(), htp_options); - backend_params->qnn_backend_cache_ptr_ = 
std::make_unique(qnn_context_blob); backend_params->qnn_context_ptr_ = std::make_unique( - implementation, - backend_params->qnn_backend_ptr_.get(), - backend_params->qnn_device_ptr_.get(), + implementation_ptr, + qnn_backend_ptr, + qnn_device_ptr, backend_params->qnn_backend_cache_ptr_.get(), htp_options, qnn_dlc_manager); backend_params->qnn_graph_ptr_ = std::make_unique( - implementation, - backend_params->qnn_backend_ptr_.get(), + implementation_ptr, + qnn_backend_ptr, backend_params->qnn_context_ptr_.get(), get_option(options->profile_level()), options->soc_info(), htp_options); - backend_params->qnn_mem_manager_ptr_ = std::make_unique( - implementation, + } break; + case QnnExecuTorchBackendType::kGpuBackend: { + auto gpu_options = options->backend_options()->gpu_options(); + if (options->log_level() >= QnnExecuTorchLogLevel::kLogLevelInfo) { + QNN_EXECUTORCH_LOG_INFO( + "performance_mode in gpu_options: %s", + EnumNameQnnExecuTorchGpuPerformanceMode( + gpu_options->performance_mode())); + QNN_EXECUTORCH_LOG_INFO( + "precision in gpu_options: %s", + EnumNameQnnExecuTorchGpuPrecision(gpu_options->precision())); + QNN_EXECUTORCH_LOG_INFO( + "use_memory_optimizations in gpu_options: %d", + gpu_options->use_memory_optimizations()); + QNN_EXECUTORCH_LOG_INFO( + "use_node_optimizations in gpu_options: %d", + gpu_options->use_node_optimizations()); + QNN_EXECUTORCH_LOG_INFO( + "use_queue_recording in gpu_options: %d", + gpu_options->use_queue_recording()); + QNN_EXECUTORCH_LOG_INFO( + "use_weight_sharing in gpu_options: %d", + gpu_options->use_weight_sharing()); + } + + backend_params->qnn_backend_cache_ptr_ = + std::make_unique(qnn_context_blob); + + backend_params->qnn_context_ptr_ = std::make_unique( + implementation_ptr, + qnn_backend_ptr, + qnn_device_ptr, + backend_params->qnn_backend_cache_ptr_.get(), + qnn_dlc_manager, + gpu_options); + + backend_params->qnn_graph_ptr_ = std::make_unique( + implementation_ptr, + qnn_backend_ptr, backend_params->qnn_context_ptr_.get(), - get_option(options->log_level())); - backend_params->backend_init_state_ = BackendInitializeState::INITIALIZED; + options->profile_level(), + gpu_options); } break; - case QnnExecuTorchBackendType::kGpuBackend: case QnnExecuTorchBackendType::kDspBackend: case QnnExecuTorchBackendType::kUndefinedBackend: default: return nullptr; } - if (backend_params->qnn_backend_ptr_->VerifyQNNSDKVersion() == Error::Ok) { - return backend_params; - } + backend_params->qnn_mem_manager_ptr_ = std::make_unique( + implementation_ptr, + backend_params->qnn_context_ptr_.get(), + options->log_level()); - return nullptr; + backend_params->backend_init_state_ = BackendInitializeState::INITIALIZED; + return backend_params; } } // namespace qnn } // namespace backends diff --git a/backends/qualcomm/runtime/backends/QnnBackendFactory.h b/backends/qualcomm/runtime/backends/QnnBackendFactory.h index 3d78a36b9f0..c125d5ffca4 100644 --- a/backends/qualcomm/runtime/backends/QnnBackendFactory.h +++ b/backends/qualcomm/runtime/backends/QnnBackendFactory.h @@ -17,11 +17,11 @@ #include #include #include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include #include namespace executorch { @@ -31,22 +31,18 @@ namespace qnn { class QnnDlcManager; typedef enum { UNINITIALIZED, INITIALIZED } BackendInitializeState; -// @brief Struct containing all handles for a given QNN backend +// @brief Struct containing non-shared handles for a given QNN backend typedef struct BackendConfigParameters { - std::unique_ptr 
qnn_backend_ptr_; BackendInitializeState backend_init_state_; std::unique_ptr qnn_context_ptr_; - std::unique_ptr qnn_device_ptr_; std::unique_ptr qnn_graph_ptr_; std::unique_ptr qnn_mem_manager_ptr_; std::unique_ptr qnn_backend_cache_ptr_; // Default ctor BackendConfigParameters() - : qnn_backend_ptr_(nullptr), - backend_init_state_(BackendInitializeState::UNINITIALIZED), + : backend_init_state_(BackendInitializeState::UNINITIALIZED), qnn_context_ptr_(nullptr), - qnn_device_ptr_(nullptr), qnn_graph_ptr_(nullptr), qnn_mem_manager_ptr_(nullptr), qnn_backend_cache_ptr_(nullptr) {} @@ -56,8 +52,6 @@ typedef struct BackendConfigParameters { qnn_backend_cache_ptr_.reset(); qnn_mem_manager_ptr_.reset(); qnn_context_ptr_.reset(); - qnn_device_ptr_.reset(); - qnn_backend_ptr_.reset(); backend_init_state_ = BackendInitializeState::UNINITIALIZED; } @@ -66,8 +60,9 @@ typedef struct BackendConfigParameters { class QnnBackendFactory { public: std::unique_ptr Create( - const QnnImplementation& implementation, - QnnLogger* logger, + QnnImplementation* implementation, + QnnBackend* qnn_backend_ptr, + QnnDevice* qnn_device_ptr, const QnnExecuTorchContextBinary& qnn_context_blob, const QnnExecuTorchOptions* options, QnnDlcManager* qnn_dlc_manager); diff --git a/backends/qualcomm/runtime/backends/QnnBackendUnifiedRegistry.cpp b/backends/qualcomm/runtime/backends/QnnBackendUnifiedRegistry.cpp new file mode 100644 index 00000000000..8b1dcdf7a9d --- /dev/null +++ b/backends/qualcomm/runtime/backends/QnnBackendUnifiedRegistry.cpp @@ -0,0 +1,164 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace executorch { +namespace backends { +namespace qnn { +using executorch::runtime::Error; + +// Static instance for the singleton +QnnBackendUnifiedRegistry& QnnBackendUnifiedRegistry::GetInstance() { + static QnnBackendUnifiedRegistry instance; + return instance; +} + +// Private constructor +QnnBackendUnifiedRegistry::QnnBackendUnifiedRegistry() = default; + +// Destructor +QnnBackendUnifiedRegistry::~QnnBackendUnifiedRegistry() { + CleanupExpired(); +} + +Error QnnBackendUnifiedRegistry::GetOrCreateBackendBundle( + const QnnExecuTorchOptions* options, + std::shared_ptr& bundle) { + std::lock_guard lock(mutex_); + + // Extract relevant parameters from options for creation and validation + std::string current_lib_path = options->library_path()->str(); + QnnExecuTorchLogLevel current_log_level = get_option(options->log_level()); + QnnExecuTorchBackendType backend_type = + options->backend_options()->backend_type(); + + if (current_lib_path.empty()) { + switch (backend_type) { + case QnnExecuTorchBackendType::kHtpBackend: { + current_lib_path = htp_library_name_; + break; + } + case QnnExecuTorchBackendType::kGpuBackend: { + current_lib_path = gpu_library_name_; + break; + } + case QnnExecuTorchBackendType::kDspBackend: + case QnnExecuTorchBackendType::kUndefinedBackend: + default: + QNN_EXECUTORCH_LOG_ERROR( + "Unsupported backend type: %s", + EnumNameQnnExecuTorchBackendType(backend_type)); + return Error::NotFound; + } + } + + // Check if resources already exist + auto it = qnn_backend_bundles_map_.find(backend_type); + if (it != qnn_backend_bundles_map_.end()) { + // Create new shared_ptr that shares ownership of the managed object. 
+ if (auto existing_bundle = it->second.lock()) { + bundle = existing_bundle; + if (bundle->qnn_logger_ptr->GetLogLevel() != current_log_level) { + bundle->qnn_logger_ptr = std::make_unique( + bundle->implementation.get(), LoggingCallback, current_log_level); + } + QNN_EXECUTORCH_LOG_INFO( + "Use cached backend bundle for current backend: %s", + EnumNameQnnExecuTorchBackendType(backend_type)); + return Error::Ok; + } + } + + QNN_EXECUTORCH_LOG_INFO("Creating new backend bundle."); + + // 1. Create QnnImplementation and load qnn library + std::unique_ptr implementation = + std::make_unique(current_lib_path); + auto config = GetImplementationConfig(options); + Error ret = implementation->Load(config.get()); + ET_CHECK_OR_RETURN_ERROR( + ret == Error::Ok, Internal, "Fail to load Qnn library"); + + // 2. Create QnnLogger + std::unique_ptr logger = std::make_unique( + implementation.get(), LoggingCallback, current_log_level); + + // 3. Create QnnBackend (specific type based on options) + // 4. Create QnnDevice (specific type based on options) + std::unique_ptr backend = nullptr; + std::unique_ptr device = nullptr; + + switch (backend_type) { + case QnnExecuTorchBackendType::kHtpBackend: { + auto htp_options = options->backend_options()->htp_options(); + backend = + std::make_unique(implementation.get(), logger.get()); + device = std::make_unique( + implementation.get(), logger.get(), options->soc_info(), htp_options); + break; + } + case QnnExecuTorchBackendType::kGpuBackend: { + auto gpu_options = options->backend_options()->gpu_options(); + backend = std::make_unique( + implementation.get(), logger.get(), gpu_options); + device = std::make_unique(implementation.get(), logger.get()); + break; + } + case QnnExecuTorchBackendType::kDspBackend: + case QnnExecuTorchBackendType::kUndefinedBackend: + default: + return Error::NotFound; + } + ET_CHECK_OR_RETURN_ERROR( + backend->Configure(options->op_package_options()) == Error::Ok, + Internal, + "Fail to configure Qnn backend"); + ET_CHECK_OR_RETURN_ERROR( + device->Configure() == Error::Ok, + Internal, + "Fail to configure Qnn device"); + + if (backend->VerifyQNNSDKVersion() != Error::Ok) { + return Error::Internal; + } + + bundle->implementation = std::move(implementation); + bundle->qnn_logger_ptr = std::move(logger); + bundle->qnn_backend_ptr = std::move(backend); + bundle->qnn_device_ptr = std::move(device); + qnn_backend_bundles_map_.emplace( + backend_type, bundle); // Store weak_ptr to the bundle + + return Error::Ok; +} + +void QnnBackendUnifiedRegistry::CleanupExpired() { + std::lock_guard lock(mutex_); + + for (auto it = qnn_backend_bundles_map_.begin(); + it != qnn_backend_bundles_map_.end();) { + if (it->second.expired()) { + it = qnn_backend_bundles_map_.erase(it); + } else { + ++it; + } + } +} + +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/QnnBackendUnifiedRegistry.h b/backends/qualcomm/runtime/backends/QnnBackendUnifiedRegistry.h new file mode 100644 index 00000000000..b2549a3356c --- /dev/null +++ b/backends/qualcomm/runtime/backends/QnnBackendUnifiedRegistry.h @@ -0,0 +1,106 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace executorch { +namespace backends { +namespace qnn { + +// A bundle struct to hold all shared QNN backend resources +struct QnnBackendBundle { + std::unique_ptr implementation; + std::unique_ptr qnn_logger_ptr; + std::unique_ptr qnn_backend_ptr; + std::unique_ptr qnn_device_ptr; + + // Default ctor + QnnBackendBundle() + : implementation(nullptr), + qnn_logger_ptr(nullptr), + qnn_backend_ptr(nullptr), + qnn_device_ptr(nullptr) {} + // Default dtor + ~QnnBackendBundle() { + qnn_device_ptr.reset(); + qnn_backend_ptr.reset(); + qnn_logger_ptr.reset(); + implementation.reset(); + } +}; + +class QnnBackendUnifiedRegistry { + // Singleton class to manage shared QNN backend resources. It ensures that + // only one instance of the registry exists throughout the application's + // lifetime. The registry maintains a map of backend bundles indexed by + // backend_type. Each bundle contains QnnImplentation, QnnLogger, QnnBackend, + // and QnnDevice objects for a specific backend type. The registry provides + // methods to get or create backend bundles, ensuring that resources are + // properly managed and reused when possible. It also includes a cleanup + // mechanism to remove expired bundles. + public: + static QnnBackendUnifiedRegistry& GetInstance(); + + executorch::runtime::Error GetOrCreateBackendBundle( + const QnnExecuTorchOptions* options, + std::shared_ptr& bundle); + + void CleanupExpired(); + + private: + QnnBackendUnifiedRegistry(); + ~QnnBackendUnifiedRegistry(); + + // Delete copy constructor and assignment operator + QnnBackendUnifiedRegistry(const QnnBackendUnifiedRegistry&) = delete; + QnnBackendUnifiedRegistry& operator=(const QnnBackendUnifiedRegistry&) = + delete; + + static constexpr const char* htp_library_name_ = "libQnnHtp.so"; + static constexpr const char* gpu_library_name_ = "libQnnGpu.so"; + static constexpr const char* dsp_library_name_ = "libQnnDsp.so"; + + std::unique_ptr GetImplementationConfig( + const QnnExecuTorchOptions* options) { + if (options->saver()) { + auto outputDirCfg = std::make_unique(); + outputDirCfg->option = QNN_SAVER_CONFIG_OPTION_OUTPUT_DIRECTORY; + outputDirCfg->outputDirectory = options->saver_output_dir()->c_str(); + + auto saverCfg = std::make_unique(2); + saverCfg[0] = outputDirCfg.release(); + saverCfg[1] = nullptr; + + return saverCfg; + } else { + return nullptr; + } + } + + // Stores the collection of shared resources, with backend_type being used as + // the key. 
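The registry header above describes a singleton that caches one shared bundle per backend type. Because the map (declared just below) stores `std::weak_ptr`, the registry never keeps a bundle alive on its own: when the last `QnnManager` drops its `shared_ptr`, the shared QNN resources are destroyed, `CleanupExpired` can prune the dead entry, and the next request rebuilds the bundle on demand. A condensed sketch of that caching scheme with placeholder types:

```cpp
#include <iostream>
#include <memory>
#include <mutex>
#include <unordered_map>

struct Bundle {
  int backend_type;  // placeholder for implementation/logger/backend/device
};

class RegistrySketch {
 public:
  static RegistrySketch& GetInstance() {
    static RegistrySketch instance;  // Meyers singleton, as in the registry
    return instance;
  }

  std::shared_ptr<Bundle> GetOrCreate(int backend_type) {
    std::lock_guard<std::mutex> lock(mutex_);
    if (auto existing = cache_[backend_type].lock()) {
      return existing;                      // reuse the still-live bundle
    }
    auto fresh = std::make_shared<Bundle>(Bundle{backend_type});
    cache_[backend_type] = fresh;           // observe only, never own
    return fresh;
  }

 private:
  std::unordered_map<int, std::weak_ptr<Bundle>> cache_;
  std::mutex mutex_;
};

int main() {
  auto a = RegistrySketch::GetInstance().GetOrCreate(1);
  auto b = RegistrySketch::GetInstance().GetOrCreate(1);
  std::cout << (a == b) << '\n';  // 1: both callers share one bundle
  a.reset();
  b.reset();                      // last owner gone -> bundle destroyed
  auto c = RegistrySketch::GetInstance().GetOrCreate(1);  // rebuilt on demand
  std::cout << (c != nullptr) << '\n';
  return 0;
}
```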
+ std::unordered_map> + qnn_backend_bundles_map_; + + std::mutex mutex_; // Protects access to resources and ensures atomic + // creation/destruction +}; + +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/QnnContextCommon.cpp b/backends/qualcomm/runtime/backends/QnnContextCommon.cpp index ee49b10215a..e16a173db6c 100644 --- a/backends/qualcomm/runtime/backends/QnnContextCommon.cpp +++ b/backends/qualcomm/runtime/backends/QnnContextCommon.cpp @@ -14,7 +14,7 @@ namespace backends { namespace qnn { QnnContext::~QnnContext() { - const QnnInterface& qnn_interface = implementation_.GetQnnInterface(); + const QnnInterface& qnn_interface = implementation_->GetQnnInterface(); Qnn_ErrorHandle_t error = QNN_SUCCESS; if (handle_ != nullptr) { QNN_EXECUTORCH_LOG_INFO("Destroy Qnn context"); @@ -33,7 +33,7 @@ QnnContext::~QnnContext() { Error QnnContext::Configure() { // create qnn context - const QnnInterface& qnn_interface = implementation_.GetQnnInterface(); + const QnnInterface& qnn_interface = implementation_->GetQnnInterface(); Qnn_ErrorHandle_t error = QNN_SUCCESS; std::vector temp_context_config; @@ -95,7 +95,7 @@ Error QnnContext::Configure() { Error QnnContext::GetContextBinary( QnnExecuTorchContextBinary& qnn_executorch_context_binary) { - const QnnInterface& qnn_interface = implementation_.GetQnnInterface(); + const QnnInterface& qnn_interface = implementation_->GetQnnInterface(); Qnn_ContextBinarySize_t binary_size = 0; Qnn_ContextBinarySize_t bytes_written = 0; Qnn_ErrorHandle_t error = diff --git a/backends/qualcomm/runtime/backends/QnnContextCommon.h b/backends/qualcomm/runtime/backends/QnnContextCommon.h index 0e9e12ef544..7d507a4a50c 100644 --- a/backends/qualcomm/runtime/backends/QnnContextCommon.h +++ b/backends/qualcomm/runtime/backends/QnnContextCommon.h @@ -24,7 +24,7 @@ class QnnDlcManager; class QnnContext { public: explicit QnnContext( - const QnnImplementation& implementation, + QnnImplementation* implementation, QnnBackend* backend, QnnDevice* device, QnnBackendCache* cache, @@ -74,7 +74,7 @@ class QnnContext { private: Qnn_ContextHandle_t handle_; - const QnnImplementation& implementation_; + QnnImplementation* implementation_; QnnBackend* backend_; QnnDevice* device_; QnnBackendCache* cache_; diff --git a/backends/qualcomm/runtime/backends/QnnDeviceCommon.cpp b/backends/qualcomm/runtime/backends/QnnDeviceCommon.cpp index 93d705efd3e..0280ec4f383 100644 --- a/backends/qualcomm/runtime/backends/QnnDeviceCommon.cpp +++ b/backends/qualcomm/runtime/backends/QnnDeviceCommon.cpp @@ -13,7 +13,7 @@ namespace qnn { using executorch::runtime::Error; QnnDevice::~QnnDevice() { - const QnnInterface& qnn_interface = implementation_.GetQnnInterface(); + const QnnInterface& qnn_interface = implementation_->GetQnnInterface(); Qnn_ErrorHandle_t error = QNN_SUCCESS; if (nullptr != handle_) { QNN_EXECUTORCH_LOG_INFO("Destroy Qnn device"); @@ -32,7 +32,7 @@ QnnDevice::~QnnDevice() { Error QnnDevice::Configure() { // create qnn device - const QnnInterface& qnn_interface = implementation_.GetQnnInterface(); + const QnnInterface& qnn_interface = implementation_->GetQnnInterface(); Qnn_ErrorHandle_t error = QNN_SUCCESS; std::vector temp_device_config; diff --git a/backends/qualcomm/runtime/backends/QnnDeviceCommon.h b/backends/qualcomm/runtime/backends/QnnDeviceCommon.h index 85de00f8623..27da759c507 100644 --- a/backends/qualcomm/runtime/backends/QnnDeviceCommon.h +++ b/backends/qualcomm/runtime/backends/QnnDeviceCommon.h @@ 
-20,8 +20,10 @@ namespace backends { namespace qnn { class QnnDevice { public: - explicit QnnDevice(const QnnImplementation& implementation, QnnLogger* logger) + explicit QnnDevice(QnnImplementation* implementation, QnnLogger* logger) : implementation_(implementation), handle_(nullptr), logger_(logger) {} + QnnDevice(const QnnDevice&) = delete; // Delete copy constructor + QnnDevice& operator=(const QnnDevice&) = delete; // Delete assignment operator virtual ~QnnDevice(); @@ -29,7 +31,7 @@ class QnnDevice { return handle_; } - executorch::runtime::Error Configure(); + virtual executorch::runtime::Error Configure(); protected: virtual executorch::runtime::Error MakeConfig( @@ -40,7 +42,7 @@ class QnnDevice { virtual executorch::runtime::Error AfterCreateDevice() { return executorch::runtime::Error::Ok; }; - const QnnImplementation& implementation_; + QnnImplementation* implementation_; private: Qnn_DeviceHandle_t handle_; diff --git a/backends/qualcomm/runtime/backends/QnnDlcManager.h b/backends/qualcomm/runtime/backends/QnnDlcManager.h index a57906df4e3..4c320fde9ac 100644 --- a/backends/qualcomm/runtime/backends/QnnDlcManager.h +++ b/backends/qualcomm/runtime/backends/QnnDlcManager.h @@ -10,7 +10,8 @@ #include #include -#include +#include +#include #include "QnnWrapperUtils.hpp" namespace executorch { @@ -35,23 +36,23 @@ class QnnDlcManager { std::unique_ptr backend_params_ptr_ = std::make_unique(); + std::unique_ptr backend_bundle_ptr_ = + std::make_unique(); - void ResetBackendParams(); - void ResetLogger(); - void TerminateAllBackends(); + void Destroy(); - Error SetUpDlcEnvironment(const Qnn_Version_t& coreApiVersion); + Error SetUpDlcEnvironment( + const Qnn_Version_t& coreApiVersion, + const std::vector& graph_names); Error RegisterGraphsFromDLC( - const QnnImplementation& implementation, + QnnImplementation* implementation, QnnBackend* backend, QnnContext* context, QnnBackendCache* cache); private: static constexpr const char* library_name_ = "libQnnIr.so"; - QnnImplementation qnn_loaded_backend_; - std::unique_ptr logger_; const QnnExecuTorchContextBinary& qnn_context_blob_; const QnnExecuTorchOptions* options_; @@ -64,7 +65,7 @@ class QnnDlcManager { Error Create(); - Error Configure(); + Error Configure(const std::vector& graph_names); }; } // namespace qnn } // namespace backends diff --git a/backends/qualcomm/runtime/backends/QnnFunctionInterface.h b/backends/qualcomm/runtime/backends/QnnFunctionInterface.h index 548c363f388..0e1e4727aa3 100644 --- a/backends/qualcomm/runtime/backends/QnnFunctionInterface.h +++ b/backends/qualcomm/runtime/backends/QnnFunctionInterface.h @@ -105,6 +105,9 @@ class QnnInterface { const QNN_INTERFACE_VER_TYPE& GetInterfaceVer() const { return qnn_interface_->QNN_INTERFACE_VER_NAME; } + void Unload() { + qnn_interface_ = nullptr; + } private: // --------- QnnInterface --------- diff --git a/backends/qualcomm/runtime/backends/QnnGraphCommon.cpp b/backends/qualcomm/runtime/backends/QnnGraphCommon.cpp index 9fe81f4cf54..44bf11bc0f5 100644 --- a/backends/qualcomm/runtime/backends/QnnGraphCommon.cpp +++ b/backends/qualcomm/runtime/backends/QnnGraphCommon.cpp @@ -14,7 +14,7 @@ using executorch::runtime::Error; Error QnnGraph::Configure(const std::string& graph_name) { // create qnn backend - const QnnInterface& qnn_interface = implementation_.GetQnnInterface(); + const QnnInterface& qnn_interface = implementation_->GetQnnInterface(); Qnn_ErrorHandle_t error = QNN_SUCCESS; std::vector temp_graph_config; ET_CHECK_OR_RETURN_ERROR( @@ -81,7 +81,7 @@ 
Qnn_ErrorHandle_t QnnGraph::GraphExecute( return QNN_COMMON_ERROR_GENERAL; } - return implementation_.GetQnnInterface().qnn_graph_execute( + return implementation_->GetQnnInterface().qnn_graph_execute( handle_[graph_name], input_tensor_structs.data(), input_tensor_structs.size(), @@ -94,7 +94,7 @@ Qnn_ErrorHandle_t QnnGraph::GraphExecute( Error QnnGraph::EnsureTensorInQnnGraph( const std::string& graph_name, const std::shared_ptr& tensor_wrapper) { - const QnnInterface& qnn_interface = implementation_.GetQnnInterface(); + const QnnInterface& qnn_interface = implementation_->GetQnnInterface(); Qnn_ErrorHandle_t error = QNN_SUCCESS; if (!tensor_wrapper->IsTensorCreated()) { diff --git a/backends/qualcomm/runtime/backends/QnnGraphCommon.h b/backends/qualcomm/runtime/backends/QnnGraphCommon.h index 33f903dae41..fbb5ab80140 100644 --- a/backends/qualcomm/runtime/backends/QnnGraphCommon.h +++ b/backends/qualcomm/runtime/backends/QnnGraphCommon.h @@ -23,7 +23,7 @@ namespace qnn { class QnnGraph { public: explicit QnnGraph( - const QnnImplementation& implementation, + QnnImplementation* implementation, QnnBackend* backend, QnnContext* context, const QnnExecuTorchProfileLevel& profile_level) @@ -44,7 +44,7 @@ class QnnGraph { Qnn_ErrorHandle_t GraphAddNode( const std::string& graph_name, const Qnn_OpConfig_t& op_config) { - return implementation_.GetQnnInterface().qnn_graph_add_node( + return implementation_->GetQnnInterface().qnn_graph_add_node( handle_[graph_name], op_config); }; executorch::runtime::Error EnsureTensorInQnnGraph( @@ -52,7 +52,7 @@ class QnnGraph { const std::shared_ptr& tensor_wrapper); Qnn_ErrorHandle_t GraphFinalize(const std::string& graph_name) { - return implementation_.GetQnnInterface().qnn_graph_finalize( + return implementation_->GetQnnInterface().qnn_graph_finalize( handle_[graph_name], profile_[graph_name]->GetHandle(), nullptr /* signal_handle */); @@ -84,7 +84,7 @@ class QnnGraph { private: std::unordered_map handle_; - const QnnImplementation& implementation_; + QnnImplementation* implementation_; QnnBackend* backend_; QnnContext* context_; QnnExecuTorchProfileLevel profile_level_; diff --git a/backends/qualcomm/runtime/backends/QnnImplementation.cpp b/backends/qualcomm/runtime/backends/QnnImplementation.cpp index 42f866d22cc..246800791e6 100644 --- a/backends/qualcomm/runtime/backends/QnnImplementation.cpp +++ b/backends/qualcomm/runtime/backends/QnnImplementation.cpp @@ -6,7 +6,7 @@ * LICENSE file in the root directory of this source tree. 
*/ #include - +#include #include "QnnInterface.h" namespace executorch { namespace backends { @@ -14,6 +14,14 @@ namespace qnn { using executorch::runtime::Error; +struct DlCloser { + int operator()(void* handle) { + if (handle == nullptr) + return 0; + return dlclose(handle); + } +}; + Error QnnImplementation::InitBackend( void* const lib_handle, const QnnSaver_Config_t** saver_config) { @@ -34,43 +42,39 @@ Error QnnImplementation::InitBackend( return Error::Ok; } -// instantiate static members -// NOLINTNEXTLINE(fuchsia-statically-constructed-objects) -std::unordered_map - QnnImplementation::lib_path_to_backend_id_; -// NOLINTNEXTLINE(fuchsia-statically-constructed-objects) -std::unordered_map - QnnImplementation::loaded_backend_; -// NOLINTNEXTLINE(fuchsia-statically-constructed-objects) -std::unordered_map - QnnImplementation::loaded_lib_handle_; -// NOLINTNEXTLINE(fuchsia-statically-constructed-objects) -std::mutex QnnImplementation::be_init_mutex_; - -Error QnnImplementation::StartBackend( +QnnImplementation::~QnnImplementation() { + Unload(); +} + +const QnnInterface_t* QnnImplementation::StartBackend( const std::string& lib_path, const QnnSaver_Config_t** saver_config) { Qnn_ErrorHandle_t error = QNN_SUCCESS; - void* lib_handle = nullptr; - lib_handle = dlopen(lib_path.c_str(), RTLD_NOW | RTLD_GLOBAL); + // If the library is already loaded, return the handle. + std::unique_ptr lib_handle( + dlopen(lib_path.c_str(), RTLD_NOW | RTLD_NOLOAD)); + if (!lib_handle) { + lib_handle = std::unique_ptr( + dlopen(lib_path.c_str(), RTLD_NOW | RTLD_GLOBAL)); + } if (lib_handle == nullptr) { QNN_EXECUTORCH_LOG_ERROR( "Cannot Open QNN library %s, with error: %s", lib_path.c_str(), dlerror()); - return Error::Internal; + return nullptr; } // load get_provider function auto get_providers = loadQnnFunction( - lib_handle, "QnnInterface_getProviders"); + lib_handle.get(), "QnnInterface_getProviders"); if (get_providers == nullptr) { QNN_EXECUTORCH_LOG_ERROR( "QnnImplementation::Load Cannot load symbol " "QnnInterface_getProviders : %s", dlerror()); - return Error::Internal; + return nullptr; } // Get QnnInterface Providers @@ -82,7 +86,7 @@ Error QnnImplementation::StartBackend( QNN_EXECUTORCH_LOG_ERROR( "Qnn Interface failed to get providers. Error %d", QNN_GET_ERROR_CODE(error)); - return Error::Internal; + return nullptr; } if (num_providers != required_num_providers_) { @@ -91,115 +95,47 @@ Error QnnImplementation::StartBackend( "%d instead of required %d", num_providers, required_num_providers_); - return Error::Internal; + return nullptr; } - BackendIdType backend_id = provider_list[0]->backendId; - - // store everything - lib_path_to_backend_id_[lib_path] = backend_id; - - // we use lib_path as the first unique key. - // Users can get wrong like, he or she assigns - // library_path=libQnnHtp_1.so - // library_path=libQnnHtp_2.so - // for different QnnBackend instances. - // So we warning out here. - if (loaded_backend_.count(backend_id) > 0) { - QNN_EXECUTORCH_LOG_WARN( - "lib_path %s is loaded, but backend %d " - "already exists. 
Overwriting previous loaded backend...", - lib_path.c_str(), - backend_id); - } - loaded_backend_[backend_id] = provider_list[0]; - - if (loaded_lib_handle_.count(backend_id) > 0) { - QNN_EXECUTORCH_LOG_WARN("closing %pK...", loaded_lib_handle_[backend_id]); - - int dlclose_error = dlclose(loaded_lib_handle_[backend_id]); - if (dlclose_error != 0) { - QNN_EXECUTORCH_LOG_WARN( - "Sadly, fail to close %pK with error %s", - loaded_lib_handle_[backend_id], - dlerror()); - } - } - loaded_lib_handle_[backend_id] = lib_handle; - // Saver backend need initialization. - Error be_init_st = InitBackend(loaded_lib_handle_[backend_id], saver_config); + Error be_init_st = InitBackend(lib_handle.get(), saver_config); if (be_init_st != Error::Ok) { - // backend init fails. clear things - lib_path_to_backend_id_.erase(lib_path); - loaded_backend_.erase(backend_id); - - int dlclose_error = dlclose(loaded_lib_handle_[backend_id]); - if (dlclose_error != 0) { - QNN_EXECUTORCH_LOG_WARN( - "fail to close %pK after backend-init " - "failure, with error %s", - loaded_lib_handle_[backend_id], - dlerror()); - } - - loaded_lib_handle_.erase(backend_id); - return be_init_st; + return nullptr; } - return Error::Ok; + // hold the lib_handle + lib_handle_ = lib_handle.release(); + return provider_list[0]; } -Error QnnImplementation::TerminateAllBackends() { - Error ret_status = Error::Ok; - - loaded_backend_.clear(); +Error QnnImplementation::Unload() { + qnn_interface_.Unload(); - for (auto& it : loaded_lib_handle_) { - int dlclose_error = dlclose(it.second); - if (dlclose_error != 0) { - QNN_EXECUTORCH_LOG_ERROR( - "Fail to close QNN backend %d with error %s", it.first, dlerror()); - ret_status = Error::Internal; - } + if (lib_handle_ == nullptr) { + return Error::Ok; } - loaded_lib_handle_.clear(); - lib_path_to_backend_id_.clear(); - return ret_status; + int dlclose_error = dlclose(lib_handle_); + if (dlclose_error != 0) { + QNN_EXECUTORCH_LOG_ERROR( + "Fail to close QNN backend %s with error %s", + lib_path_.c_str(), + dlerror()); + return Error::Internal; + } + lib_handle_ = nullptr; + return Error::Ok; } Error QnnImplementation::Load(const QnnSaver_Config_t** saver_config) { - BackendIdType backend_id = QNN_BACKEND_ID_NULL; - { - const std::lock_guard lock(be_init_mutex_); - - if (lib_path_to_backend_id_.count(lib_path_) == 0) { - Error st = StartBackend(lib_path_, saver_config); - ET_CHECK_OR_RETURN_ERROR( - st == Error::Ok, Internal, "Fail to start backend"); - } - - // Get backend ID - backend_id = lib_path_to_backend_id_[lib_path_]; - - // really don't expect. - if (loaded_backend_.count(backend_id) == 0 || - loaded_lib_handle_.count(backend_id) == 0) { - QNN_EXECUTORCH_LOG_ERROR( - "library %s is loaded but " - "loaded backend count=%zu, " - "loaded lib_handle count=%zu", - lib_path_.c_str(), - loaded_backend_.count(backend_id), - loaded_lib_handle_.count(backend_id)); - return Error::Internal; - } - } // be_init_mutex_ release. 
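The rewritten `StartBackend`/`Unload` above replaces the static handle maps with an instance-owned handle: it first probes with `RTLD_NOLOAD` (which returns a handle only if the library is already resident), falls back to a real `dlopen`, and guards the handle with a `unique_ptr` whose deleter calls `dlclose`, releasing the raw handle into the member only once initialization succeeds. A small stand-alone sketch of that pattern; the library name below is just an example:

```cpp
#include <dlfcn.h>

#include <cstdio>
#include <memory>

// Deleter so every early-return path closes the library automatically.
struct DlCloser {
  void operator()(void* handle) const {
    if (handle != nullptr) {
      dlclose(handle);
    }
  }
};

using DlHandle = std::unique_ptr<void, DlCloser>;

DlHandle OpenOnce(const char* lib_path) {
  // Probe first: RTLD_NOLOAD succeeds only if the object is already loaded.
  DlHandle handle(dlopen(lib_path, RTLD_NOW | RTLD_NOLOAD));
  if (!handle) {
    handle.reset(dlopen(lib_path, RTLD_NOW | RTLD_GLOBAL));  // load it now
  }
  return handle;
}

int main() {
  DlHandle h = OpenOnce("libm.so.6");  // any shared library works here
  std::printf("loaded: %s\n", h ? "yes" : dlerror());
  return 0;  // the unique_ptr deleter dlclose()s the handle on scope exit
}
```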
+ const QnnInterface_t* p_qnn_intf = StartBackend(lib_path_, saver_config); + ET_CHECK_OR_RETURN_ERROR( + p_qnn_intf != nullptr, Internal, "Fail to start backend"); // Connect QnnInterface - qnn_interface_.SetQnnInterface(loaded_backend_[backend_id]); + qnn_interface_.SetQnnInterface(p_qnn_intf); return Error::Ok; } diff --git a/backends/qualcomm/runtime/backends/QnnImplementation.h b/backends/qualcomm/runtime/backends/QnnImplementation.h index a49ee6516fc..3059166523d 100644 --- a/backends/qualcomm/runtime/backends/QnnImplementation.h +++ b/backends/qualcomm/runtime/backends/QnnImplementation.h @@ -11,9 +11,7 @@ #include #include -#include #include -#include namespace executorch { namespace backends { namespace qnn { @@ -29,32 +27,32 @@ class QnnImplementation { explicit QnnImplementation(std::string lib_path) : lib_path_(std::move(lib_path)){}; + QnnImplementation(const QnnImplementation&) = + delete; // Delete copy constructor + QnnImplementation& operator=(const QnnImplementation&) = + delete; // Delete assignment operator + ~QnnImplementation(); executorch::runtime::Error Load(const QnnSaver_Config_t** saver_config); const QnnInterface& GetQnnInterface() const; - executorch::runtime::Error TerminateAllBackends(); + executorch::runtime::Error Unload(); private: - static constexpr const int required_num_providers_{1}; + static constexpr int required_num_providers_{1}; - static executorch::runtime::Error StartBackend( + const QnnInterface_t* StartBackend( const std::string& lib_path, const QnnSaver_Config_t** saver_config); - static executorch::runtime::Error InitBackend( + executorch::runtime::Error InitBackend( void* const lib_handle, const QnnSaver_Config_t** saver_config); std::string lib_path_; + void* lib_handle_{nullptr}; QnnInterface qnn_interface_; - - static std::unordered_map lib_path_to_backend_id_; - static std::unordered_map - loaded_backend_; - static std::unordered_map loaded_lib_handle_; - static std::mutex be_init_mutex_; }; } // namespace qnn } // namespace backends diff --git a/backends/qualcomm/runtime/backends/QnnLogger.cpp b/backends/qualcomm/runtime/backends/QnnLogger.cpp index 5b86894d874..fec6d426c04 100644 --- a/backends/qualcomm/runtime/backends/QnnLogger.cpp +++ b/backends/qualcomm/runtime/backends/QnnLogger.cpp @@ -40,11 +40,11 @@ void LoggingCallback( QNN_EXECUTORCH_LOG(log_level, buffer); } QnnLogger::QnnLogger( - const QnnImplementation& implementation, + QnnImplementation* implementation, QnnLog_Callback_t callback, QnnExecuTorchLogLevel log_level) - : handle_(nullptr), implementation_(implementation) { - const QnnInterface& qnn_interface = implementation.GetQnnInterface(); + : handle_(nullptr), implementation_(implementation), log_level_(log_level) { + const QnnInterface& qnn_interface = implementation->GetQnnInterface(); QnnLog_Level_t qnn_log_level = QNN_LOG_LEVEL_ERROR; if (log_level > QnnExecuTorchLogLevel::kLogOff) { @@ -86,7 +86,7 @@ QnnLogger::QnnLogger( } QnnLogger::~QnnLogger() { - const QnnInterface& qnn_interface = implementation_.GetQnnInterface(); + const QnnInterface& qnn_interface = implementation_->GetQnnInterface(); if (handle_ != nullptr) { Qnn_ErrorHandle_t error = qnn_interface.qnn_log_free(handle_); if (error != QNN_SUCCESS) { diff --git a/backends/qualcomm/runtime/backends/QnnLogger.h b/backends/qualcomm/runtime/backends/QnnLogger.h index 80be4f61c59..d329ab94407 100644 --- a/backends/qualcomm/runtime/backends/QnnLogger.h +++ b/backends/qualcomm/runtime/backends/QnnLogger.h @@ -21,18 +21,25 @@ void LoggingCallback( class QnnLogger 
{ public: explicit QnnLogger( - const QnnImplementation& implementation, + QnnImplementation* implementation, QnnLog_Callback_t callback, QnnExecuTorchLogLevel log_level); + QnnLogger(const QnnLogger&) = delete; // Delete copy constructor + QnnLogger& operator=(const QnnLogger&) = delete; // Delete assignment operator ~QnnLogger(); Qnn_LogHandle_t GetHandle() { return handle_; } + QnnExecuTorchLogLevel GetLogLevel() { + return log_level_; + } + private: Qnn_LogHandle_t handle_; - const QnnImplementation& implementation_; + QnnImplementation* implementation_; + QnnExecuTorchLogLevel log_level_; }; } // namespace qnn } // namespace backends diff --git a/backends/qualcomm/runtime/backends/QnnMemManager.cpp b/backends/qualcomm/runtime/backends/QnnMemManager.cpp index 3b99dd10868..bf60c65b4cf 100644 --- a/backends/qualcomm/runtime/backends/QnnMemManager.cpp +++ b/backends/qualcomm/runtime/backends/QnnMemManager.cpp @@ -25,7 +25,7 @@ Error QnnMemManager::RegisterIonMem( const std::shared_ptr& tensor_wrapper, int32_t mem_fd, void* mem_ptr) { - const QnnInterface& qnn_interface = implementation_.GetQnnInterface(); + const QnnInterface& qnn_interface = implementation_->GetQnnInterface(); Qnn_MemDescriptor_t descriptor = { {tensor_wrapper->GetRank(), tensor_wrapper->GetDims(), nullptr}, tensor_wrapper->GetDataType(), @@ -56,17 +56,14 @@ Error QnnMemManager::RegisterIonMem( return Error::Ok; } -// TODO: Find a better way to unify RegisterCustomMem and -// PreRegisterCustomMemHandle Error QnnMemManager::RegisterCustomMem( const std::shared_ptr& tensor_wrapper, int32_t mem_fd, void* mem_ptr, - void* unaligned_custom_mem_base, size_t total_custom_mem_size, size_t tensor_offset, const CustomMemTensorInfo& info) { - const QnnInterface& qnn_interface = implementation_.GetQnnInterface(); + const QnnInterface& qnn_interface = implementation_->GetQnnInterface(); Qnn_MemDescriptor_t descriptor = { {tensor_wrapper->GetRank(), tensor_wrapper->GetDims(), nullptr}, tensor_wrapper->GetDataType(), @@ -107,46 +104,6 @@ Error QnnMemManager::RegisterCustomMem( return Error::Ok; } -Error QnnMemManager::PreRegisterCustomMemHandle( - int32_t mem_fd, - void* unaligned_custom_mem_base, - size_t total_custom_mem_size, - size_t tensor_offset, - const CustomMemTensorInfo& info) { - const QnnInterface& qnn_interface = implementation_.GetQnnInterface(); - Qnn_MemDescriptor_t descriptor = { - {info.rank, info.shape, nullptr}, - scalar_type_to_qnn_dtype_[info.dtype], - QNN_MEM_TYPE_CUSTOM, - {{mem_fd}}}; - Qnn_MemHandle_t handle = nullptr; - Qnn_ErrorHandle_t error = QNN_SUCCESS; - - QnnMemHtp_Descriptor_t htp_descriptor; - htp_descriptor.type = QNN_HTP_MEM_SHARED_BUFFER; - htp_descriptor.size = total_custom_mem_size; - - QnnHtpMem_SharedBufferConfig_t htpSharedBuffConfig = {mem_fd, tensor_offset}; - htp_descriptor.sharedBufferConfig = htpSharedBuffConfig; - - descriptor.customInfo = &htp_descriptor; - - error = qnn_interface.qnn_mem_register( - context_->GetHandle(), - &descriptor, - /*numDescriptors=*/1, - &handle); - if (error != QNN_SUCCESS) { - QNN_EXECUTORCH_LOG_WARN( - "PreRegisterCustomMemHandle fail", QNN_GET_ERROR_CODE(error)); - return Error::Internal; - } - - pre_registered_handles_.insert({info, handle}); - registered_map_.insert({handle, nullptr}); - return Error::Ok; -} - void* QnnMemManager::GetPreRegisteredHandle(const CustomMemTensorInfo& info) { auto it = pre_registered_handles_.find(info); if (it == pre_registered_handles_.end()) { @@ -165,7 +122,7 @@ Error QnnMemManager::SetMemHandle( } void 
QnnMemManager::DeRegisterMem() { - const QnnInterface& qnn_interface = implementation_.GetQnnInterface(); + const QnnInterface& qnn_interface = implementation_->GetQnnInterface(); Qnn_ErrorHandle_t error = QNN_SUCCESS; for (auto& it : registered_map_) { diff --git a/backends/qualcomm/runtime/backends/QnnMemManager.h b/backends/qualcomm/runtime/backends/QnnMemManager.h index 6a7f00b016a..35b039566e5 100644 --- a/backends/qualcomm/runtime/backends/QnnMemManager.h +++ b/backends/qualcomm/runtime/backends/QnnMemManager.h @@ -20,7 +20,7 @@ namespace qnn { class QnnMemManager { public: explicit QnnMemManager( - const QnnImplementation& implementation, + QnnImplementation* implementation, QnnContext* context, QnnExecuTorchLogLevel log_level) : implementation_(implementation), @@ -39,16 +39,6 @@ class QnnMemManager { const std::shared_ptr& tensor_wrapper, int32_t mem_fd, void* mem_ptr, - void* unaligned_custom_mem_base, - size_t total_custom_mem_size, - size_t tensor_offset, - const CustomMemTensorInfo& info); - - // Pre-register custom mem handle from SharedBuffer. Bring forward the - // memHandle creating time from execution to initialization. - executorch::runtime::Error PreRegisterCustomMemHandle( - int32_t mem_fd, - void* unaligned_custom_mem_base, size_t total_custom_mem_size, size_t tensor_offset, const CustomMemTensorInfo& info); @@ -65,7 +55,7 @@ class QnnMemManager { private: void DeRegisterMem(); - const QnnImplementation& implementation_; + QnnImplementation* implementation_; QnnContext* context_; QnnExecuTorchLogLevel log_level_; // Store the registered Qnn_MemHandle_t for de-registration diff --git a/backends/qualcomm/runtime/backends/QnnProfiler.cpp b/backends/qualcomm/runtime/backends/QnnProfiler.cpp index fd580867db5..8345434a145 100644 --- a/backends/qualcomm/runtime/backends/QnnProfiler.cpp +++ b/backends/qualcomm/runtime/backends/QnnProfiler.cpp @@ -13,12 +13,12 @@ namespace backends { namespace qnn { QnnProfile::QnnProfile( - const QnnImplementation& implementation, + QnnImplementation* implementation, QnnBackend* backend, const QnnExecuTorchProfileLevel& profile_level) : handle_(nullptr), implementation_(implementation), backend_(backend) { if (profile_level != QnnExecuTorchProfileLevel::kProfileOff) { - const QnnInterface& qnn_interface = implementation_.GetQnnInterface(); + const QnnInterface& qnn_interface = implementation_->GetQnnInterface(); QnnProfile_Level_t qnnProfileLevel = 0; if (profile_level == QnnExecuTorchProfileLevel::kProfileBasic) { @@ -72,7 +72,7 @@ QnnProfile::QnnProfile( Qnn_ErrorHandle_t QnnProfile::ProfileData( executorch::runtime::EventTracer* event_tracer) { - const QnnInterface& qnn_interface = implementation_.GetQnnInterface(); + const QnnInterface& qnn_interface = implementation_->GetQnnInterface(); const QnnProfile_EventId_t* events_ptr = nullptr; const QnnProfile_EventId_t* sub_events_ptr = nullptr; std::uint32_t num_events = 0; @@ -167,7 +167,7 @@ Qnn_ErrorHandle_t QnnProfile::ProfileData( } QnnProfile::~QnnProfile() { - const QnnInterface& qnn_interface = implementation_.GetQnnInterface(); + const QnnInterface& qnn_interface = implementation_->GetQnnInterface(); if (handle_ != nullptr) { Qnn_ErrorHandle_t error = qnn_interface.qnn_profile_free(handle_); if (error != QNN_SUCCESS) { diff --git a/backends/qualcomm/runtime/backends/QnnProfiler.h b/backends/qualcomm/runtime/backends/QnnProfiler.h index e21385aca7d..de8fbd1d9d5 100644 --- a/backends/qualcomm/runtime/backends/QnnProfiler.h +++ b/backends/qualcomm/runtime/backends/QnnProfiler.h @@ -19,7 
+19,7 @@ namespace qnn { class QnnProfile { public: explicit QnnProfile( - const QnnImplementation& implementation, + QnnImplementation* implementation, QnnBackend* backend, const QnnExecuTorchProfileLevel& profile_level); ~QnnProfile(); @@ -31,7 +31,7 @@ class QnnProfile { private: Qnn_ProfileHandle_t handle_; - const QnnImplementation& implementation_; + QnnImplementation* implementation_; QnnBackend* backend_; }; } // namespace qnn diff --git a/backends/qualcomm/runtime/backends/gpu/GpuBackend.cpp b/backends/qualcomm/runtime/backends/gpu/GpuBackend.cpp new file mode 100644 index 00000000000..2e23615ddd6 --- /dev/null +++ b/backends/qualcomm/runtime/backends/gpu/GpuBackend.cpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include "GPU/QnnGpuCommon.h" + +namespace executorch { +namespace backends { +namespace qnn { + +using executorch::runtime::Error; + +GpuBackend::GpuBackend( + QnnImplementation* implementation, + QnnLogger* logger, + const QnnExecuTorchGpuBackendOptions* gpu_options) + : QnnBackend(implementation, logger) { + gpu_backend_custom_config_ = + std::make_unique(gpu_options); +} + +Qnn_Version_t GpuBackend::GetExpectedBackendVersion() const { + Qnn_Version_t backend_version; + backend_version.major = QNN_GPU_API_VERSION_MAJOR; + backend_version.minor = QNN_GPU_API_VERSION_MINOR; + backend_version.patch = QNN_GPU_API_VERSION_PATCH; + return backend_version; +} + +bool GpuBackend::IsProfileEventTypeParentOfNodeTime( + QnnProfile_EventType_t event_type) { + return (event_type == QNN_PROFILE_EVENTTYPE_EXECUTE); +} + +Error GpuBackend::MakeConfig(std::vector& config) { + const std::vector& backend_custom_config = + gpu_backend_custom_config_->CreateBackendCustomConfig(); + + uint32_t num_custom_configs = backend_custom_config.size(); + backend_config_.resize(num_custom_configs); + // +1 for null terminated + config.reserve(num_custom_configs + 1); + + for (std::size_t i = 0; i < num_custom_configs; ++i) { + backend_config_[i].option = QNN_BACKEND_CONFIG_OPTION_CUSTOM; + backend_config_[i].customConfig = backend_custom_config[i]; + config.push_back(&backend_config_[i]); + } + + config.push_back(nullptr); + return Error::Ok; +} + +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/gpu/GpuBackend.h b/backends/qualcomm/runtime/backends/gpu/GpuBackend.h new file mode 100644 index 00000000000..1a91e85c1fd --- /dev/null +++ b/backends/qualcomm/runtime/backends/gpu/GpuBackend.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include + +namespace executorch { +namespace backends { +namespace qnn { + +class GpuBackend : public QnnBackend { + public: + GpuBackend( + QnnImplementation* implementation, + QnnLogger* logger, + const QnnExecuTorchGpuBackendOptions* gpu_options); + + Qnn_Version_t GetExpectedBackendVersion() const override; + + bool IsProfileEventTypeParentOfNodeTime( + QnnProfile_EventType_t event_type) override; + + protected: + executorch::runtime::Error MakeConfig( + std::vector& config) override; + + private: + std::vector backend_config_; + std::unique_ptr gpu_backend_custom_config_; +}; + +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/gpu/GpuBackendCustomConfig.cpp b/backends/qualcomm/runtime/backends/gpu/GpuBackendCustomConfig.cpp new file mode 100644 index 00000000000..60e289493d0 --- /dev/null +++ b/backends/qualcomm/runtime/backends/gpu/GpuBackendCustomConfig.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace executorch { +namespace backends { +namespace qnn { + +GpuBackendCustomConfig::GpuBackendCustomConfig( + const QnnExecuTorchGpuBackendOptions* gpu_options) + : gpu_options_(gpu_options) {} + +QnnGpuBackend_CustomConfig_t* +GpuBackendCustomConfig::AllocBackendCustomConfig() { + gpu_backend_config_.emplace_back( + std::make_unique()); + gpu_backend_config_.back()->option = QNN_GPU_BACKEND_CONFIG_OPTION_UNDEFINED; + return gpu_backend_config_.back().get(); +} + +std::vector +GpuBackendCustomConfig::CreateBackendCustomConfig() { + std::vector ret; + QnnGpuBackend_CustomConfig_t* p_custom_config = nullptr; + + if (gpu_options_->use_weight_sharing()) { + p_custom_config = AllocBackendCustomConfig(); + p_custom_config->option = + QNN_GPU_BACKEND_CONFIG_OPTION_WEIGHT_SHARING_ENABLED; + p_custom_config->weightSharingEnabled = 1; + ret.push_back(static_cast(p_custom_config)); + } + return ret; +} + +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/gpu/GpuBackendCustomConfig.h b/backends/qualcomm/runtime/backends/gpu/GpuBackendCustomConfig.h new file mode 100644 index 00000000000..150235a82e6 --- /dev/null +++ b/backends/qualcomm/runtime/backends/gpu/GpuBackendCustomConfig.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
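GpuBackendCustomConfig above allocates each custom-config struct into a vector of unique_ptr and hands out raw pointers, so the addresses stay stable for the lifetime of the config object, and it only emits a config when the corresponding option (weight sharing) is enabled. A reduced sketch of that allocator pattern; CustomConfig and the option value are illustrative stand-ins, not the real QNN definitions.

#include <cstdint>
#include <memory>
#include <vector>

// Stand-in for an SDK custom-config struct (the real
// QnnGpuBackend_CustomConfig_t lives in the QNN headers).
struct CustomConfig {
  int option = 0;
  std::uint8_t weightSharingEnabled = 0;
};

class BackendCustomConfigLike {
 public:
  // Emits a config only when the corresponding option is enabled; the raw
  // pointers point into storage owned by configs_, so they stay valid for
  // the lifetime of this object.
  std::vector<CustomConfig*> Create(bool use_weight_sharing) {
    std::vector<CustomConfig*> ret;
    if (use_weight_sharing) {
      CustomConfig* cfg = Alloc();
      cfg->option = 1;  // placeholder for the "weight sharing" option id
      cfg->weightSharingEnabled = 1;
      ret.push_back(cfg);
    }
    return ret;
  }

 private:
  CustomConfig* Alloc() {
    configs_.emplace_back(std::make_unique<CustomConfig>());
    return configs_.back().get();
  }

  std::vector<std::unique_ptr<CustomConfig>> configs_;
};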
+ */ + +#pragma once + +#include +#include + +#include +#include + +#include "GPU/QnnGpuBackend.h" + +namespace executorch { +namespace backends { +namespace qnn { + +using namespace qnn_delegate; + +class GpuBackendCustomConfig { + public: + explicit GpuBackendCustomConfig( + const QnnExecuTorchGpuBackendOptions* gpu_options); + + std::vector CreateBackendCustomConfig(); + + private: + QnnGpuBackend_CustomConfig_t* AllocBackendCustomConfig(); + std::vector> + gpu_backend_config_; + const QnnExecuTorchGpuBackendOptions* gpu_options_; +}; + +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/gpu/GpuContext.cpp b/backends/qualcomm/runtime/backends/gpu/GpuContext.cpp new file mode 100644 index 00000000000..07952e77eef --- /dev/null +++ b/backends/qualcomm/runtime/backends/gpu/GpuContext.cpp @@ -0,0 +1,50 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace executorch { +namespace backends { +namespace qnn { + +using executorch::runtime::Error; + +GpuContext::GpuContext( + QnnImplementation* implementation, + QnnBackend* backend, + QnnDevice* device, + QnnBackendCache* cache, + QnnDlcManager* qnn_dlc_manager, + const QnnExecuTorchGpuBackendOptions* gpu_options) + : QnnContext(implementation, backend, device, cache, qnn_dlc_manager) { + gpu_context_custom_config_ = + std::make_unique(gpu_options); +} + +Error GpuContext::MakeConfig(std::vector& config) { + const std::vector& context_custom_config = + gpu_context_custom_config_->CreateContextCustomConfig(); + + uint32_t num_custom_configs = context_custom_config.size(); + context_config_.resize(num_custom_configs); + // +1 for null terminated + config.reserve(num_custom_configs + 1); + + for (std::size_t i = 0; i < num_custom_configs; ++i) { + context_config_[i].option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM; + context_config_[i].customConfig = context_custom_config[i]; + config.push_back(&context_config_[i]); + } + + config.push_back(nullptr); + return Error::Ok; +} + +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/gpu/GpuContext.h b/backends/qualcomm/runtime/backends/gpu/GpuContext.h new file mode 100644 index 00000000000..29a36982db9 --- /dev/null +++ b/backends/qualcomm/runtime/backends/gpu/GpuContext.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
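GpuContext::MakeConfig above (like the backend and graph variants) copies each custom config into owned storage, pushes one pointer per entry, and appends a trailing nullptr because the consuming QNN API expects a null-terminated array of config pointers. A standalone sketch of that shape, with stand-in config types rather than the real QnnContext_Config_t:

#include <cstddef>
#include <vector>

// Stand-ins; the real context/custom config types come from the QNN headers.
struct CustomConfig {};
struct ContextConfig {
  int option = 0;
  CustomConfig* customConfig = nullptr;
};

// One ContextConfig entry per custom config, then a trailing nullptr,
// because the consuming API expects a null-terminated array of pointers.
void MakeConfig(const std::vector<CustomConfig*>& custom,
                std::vector<ContextConfig>& storage,
                std::vector<const ContextConfig*>& out) {
  storage.resize(custom.size());
  out.reserve(custom.size() + 1);  // +1 for the null terminator
  for (std::size_t i = 0; i < custom.size(); ++i) {
    storage[i].customConfig = custom[i];
    out.push_back(&storage[i]);
  }
  out.push_back(nullptr);
}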
+ */ + +#pragma once + +#include + +namespace executorch { +namespace backends { +namespace qnn { + +class QnnDlcManager; +class GpuContext : public QnnContext { + public: + GpuContext( + QnnImplementation* implementation, + QnnBackend* backend, + QnnDevice* device, + QnnBackendCache* cache, + QnnDlcManager* qnn_dlc_manager, + const QnnExecuTorchGpuBackendOptions* gpu_options); + + protected: + executorch::runtime::Error MakeConfig( + std::vector& config) override; + + private: + std::vector context_config_; + std::unique_ptr gpu_context_custom_config_; +}; + +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/gpu/GpuContextCustomConfig.h b/backends/qualcomm/runtime/backends/gpu/GpuContextCustomConfig.h new file mode 100644 index 00000000000..8a1f635bee0 --- /dev/null +++ b/backends/qualcomm/runtime/backends/gpu/GpuContextCustomConfig.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include + +#include "GPU/QnnGpuContext.h" + +namespace executorch { +namespace backends { +namespace qnn { + +using namespace qnn_delegate; + +class GpuContextCustomConfig { + public: + explicit GpuContextCustomConfig( + const QnnExecuTorchGpuBackendOptions* gpu_options) + : gpu_options_(gpu_options) {} + + std::vector CreateContextCustomConfig(); + + private: + QnnGpuContext_CustomConfig_t* AllocContextCustomConfig() { + gpu_context_config_.emplace_back( + std::make_unique()); + gpu_context_config_.back()->option = + QNN_GPU_CONTEXT_CONFIG_OPTION_UNDEFINED; + return gpu_context_config_.back().get(); + } + std::vector> + gpu_context_config_; + [[maybe_unused]] const QnnExecuTorchGpuBackendOptions* gpu_options_; +}; + +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/gpu/GpuDevice.h b/backends/qualcomm/runtime/backends/gpu/GpuDevice.h new file mode 100644 index 00000000000..7a0141cb566 --- /dev/null +++ b/backends/qualcomm/runtime/backends/gpu/GpuDevice.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include + +namespace executorch { +namespace backends { +namespace qnn { + +class GpuDevice : public QnnDevice { + public: + GpuDevice(QnnImplementation* implementation, QnnLogger* logger) + : QnnDevice(implementation, logger){}; + + // GPU backend does not support device creation + executorch::runtime::Error Configure() override { + return executorch::runtime::Error::Ok; + } +}; + +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/gpu/GpuGraph.cpp b/backends/qualcomm/runtime/backends/gpu/GpuGraph.cpp new file mode 100644 index 00000000000..286fbe498a9 --- /dev/null +++ b/backends/qualcomm/runtime/backends/gpu/GpuGraph.cpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
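GpuDevice above overrides Configure() to return success immediately because, per its comment, the GPU backend has no device-creation step. A minimal illustration of opting out of an optional stage through a virtual override; the Status enum and class names are stand-ins.

// Status and class names below are illustrative stand-ins.
enum class Status { Ok, Internal };

class DeviceLike {
 public:
  virtual ~DeviceLike() = default;
  virtual Status Configure() = 0;
};

// GPU-style backend: there is no device handle to create, so the override
// simply reports success and the rest of the setup pipeline proceeds.
class NoDeviceCreationDevice : public DeviceLike {
 public:
  Status Configure() override { return Status::Ok; }
};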
+ */ + +#include + +namespace executorch { +namespace backends { +namespace qnn { + +using executorch::runtime::Error; + +GpuGraph::GpuGraph( + QnnImplementation* implementation, + QnnBackend* backend, + QnnContext* context, + const QnnExecuTorchProfileLevel& profile_level, + const QnnExecuTorchGpuBackendOptions* gpu_options) + : QnnGraph(implementation, backend, context, profile_level) { + gpu_graph_custom_config_ = + std::make_unique(gpu_options); +} + +Error GpuGraph::MakeConfig(std::vector& config) { + const std::vector& graph_custom_config = + gpu_graph_custom_config_->CreateGraphCustomConfig(); + + uint32_t num_custom_configs = graph_custom_config.size(); + graph_config_.resize(num_custom_configs); + // +1 for null terminated + config.reserve(num_custom_configs + 1); + + for (std::size_t i = 0; i < num_custom_configs; ++i) { + graph_config_[i].option = QNN_GRAPH_CONFIG_OPTION_CUSTOM; + graph_config_[i].customConfig = graph_custom_config[i]; + config.push_back(&graph_config_[i]); + } + + config.push_back(nullptr); + return Error::Ok; +} + +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/gpu/GpuGraph.h b/backends/qualcomm/runtime/backends/gpu/GpuGraph.h new file mode 100644 index 00000000000..8cf73216eae --- /dev/null +++ b/backends/qualcomm/runtime/backends/gpu/GpuGraph.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch { +namespace backends { +namespace qnn { + +class GpuGraph : public QnnGraph { + public: + GpuGraph( + QnnImplementation* implementation, + QnnBackend* backend, + QnnContext* context, + const QnnExecuTorchProfileLevel& profile_level, + const QnnExecuTorchGpuBackendOptions* gpu_options); + + protected: + executorch::runtime::Error MakeConfig( + std::vector& config) override; + + private: + std::vector graph_config_; + std::unique_ptr gpu_graph_custom_config_; +}; + +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/gpu/GpuGraphCustomConfig.cpp b/backends/qualcomm/runtime/backends/gpu/GpuGraphCustomConfig.cpp new file mode 100644 index 00000000000..17f094db805 --- /dev/null +++ b/backends/qualcomm/runtime/backends/gpu/GpuGraphCustomConfig.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +namespace executorch { +namespace backends { +namespace qnn { + +GpuGraphCustomConfig::GpuGraphCustomConfig( + const QnnExecuTorchGpuBackendOptions* gpu_options) + : gpu_options_(gpu_options) {} + +QnnGpuGraph_CustomConfig_t* GpuGraphCustomConfig::AllocGraphCustomConfig() { + gpu_graph_config_.emplace_back( + std::make_unique()); + return gpu_graph_config_.back().get(); +} + +std::vector +GpuGraphCustomConfig::CreateGraphCustomConfig() { + std::vector ret; + QnnGpuGraph_CustomConfig_t* p_custom_config = nullptr; + + p_custom_config = AllocGraphCustomConfig(); + p_custom_config->precision = + static_cast(gpu_options_->precision()); + p_custom_config->disableMemoryOptimizations = + !gpu_options_->use_memory_optimizations(); + p_custom_config->disableNodeOptimizations = + !gpu_options_->use_node_optimizations(); + p_custom_config->disableQueueRecording = !gpu_options_->use_queue_recording(); + ret.push_back(static_cast(p_custom_config)); + return ret; +} + +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/gpu/GpuGraphCustomConfig.h b/backends/qualcomm/runtime/backends/gpu/GpuGraphCustomConfig.h new file mode 100644 index 00000000000..a47cd1a3345 --- /dev/null +++ b/backends/qualcomm/runtime/backends/gpu/GpuGraphCustomConfig.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include + +#include "GPU/QnnGpuGraph.h" + +namespace executorch { +namespace backends { +namespace qnn { + +using namespace qnn_delegate; + +class GpuGraphCustomConfig { + public: + explicit GpuGraphCustomConfig( + const QnnExecuTorchGpuBackendOptions* gpu_options); + + std::vector CreateGraphCustomConfig(); + + private: + QnnGpuGraph_CustomConfig_t* AllocGraphCustomConfig(); + std::vector> gpu_graph_config_; + const QnnExecuTorchGpuBackendOptions* gpu_options_; +}; + +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/gpu/aarch64/GpuContextCustomConfig.cpp b/backends/qualcomm/runtime/backends/gpu/aarch64/GpuContextCustomConfig.cpp new file mode 100644 index 00000000000..b4f200897ba --- /dev/null +++ b/backends/qualcomm/runtime/backends/gpu/aarch64/GpuContextCustomConfig.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
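CreateGraphCustomConfig above translates the positive use_* options from the ExecuTorch GPU options into the SDK's negative disable* fields by negation. A small stand-in showing just that mapping (field and option names are illustrative, not the real QnnGpuGraph_CustomConfig_t layout):

// Illustrative stand-ins for the SDK config and the user-facing options.
struct GraphCustomConfig {
  bool disableMemoryOptimizations = false;
  bool disableNodeOptimizations = false;
  bool disableQueueRecording = false;
};

struct GpuOptions {
  bool use_memory_optimizations = true;
  bool use_node_optimizations = true;
  bool use_queue_recording = true;
};

// Positive "use_*" options become negative "disable*" fields by negation.
GraphCustomConfig ToCustomConfig(const GpuOptions& opt) {
  GraphCustomConfig cfg;
  cfg.disableMemoryOptimizations = !opt.use_memory_optimizations;
  cfg.disableNodeOptimizations = !opt.use_node_optimizations;
  cfg.disableQueueRecording = !opt.use_queue_recording;
  return cfg;
}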
+ */ + +#include + +namespace executorch { +namespace backends { +namespace qnn { + +std::vector +GpuContextCustomConfig::CreateContextCustomConfig() { + std::vector ret; + QnnGpuContext_CustomConfig_t* p_custom_config = nullptr; + + p_custom_config = AllocContextCustomConfig(); + p_custom_config->option = QNN_GPU_CONTEXT_CONFIG_OPTION_PERF_HINT; + p_custom_config->perfHint = + static_cast(gpu_options_->performance_mode()); + ret.push_back(static_cast(p_custom_config)); + return ret; +} + +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/gpu/x86_64/GpuContextCustomConfig.cpp b/backends/qualcomm/runtime/backends/gpu/x86_64/GpuContextCustomConfig.cpp new file mode 100644 index 00000000000..69784c1797f --- /dev/null +++ b/backends/qualcomm/runtime/backends/gpu/x86_64/GpuContextCustomConfig.cpp @@ -0,0 +1,22 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace executorch { +namespace backends { +namespace qnn { + +std::vector +GpuContextCustomConfig::CreateContextCustomConfig() { + return {}; +} + +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/htpbackend/HtpBackend.h b/backends/qualcomm/runtime/backends/htp/HtpBackend.h similarity index 94% rename from backends/qualcomm/runtime/backends/htpbackend/HtpBackend.h rename to backends/qualcomm/runtime/backends/htp/HtpBackend.h index 5b5b1586cdb..3e3f727ecea 100644 --- a/backends/qualcomm/runtime/backends/htpbackend/HtpBackend.h +++ b/backends/qualcomm/runtime/backends/htp/HtpBackend.h @@ -16,7 +16,7 @@ namespace backends { namespace qnn { class HtpBackend : public QnnBackend { public: - HtpBackend(const QnnImplementation& implementation, QnnLogger* logger) + HtpBackend(QnnImplementation* implementation, QnnLogger* logger) : QnnBackend(implementation, logger) {} ~HtpBackend() {} diff --git a/backends/qualcomm/runtime/backends/htpbackend/HtpBackendCache.cpp b/backends/qualcomm/runtime/backends/htp/HtpBackendCache.cpp similarity index 96% rename from backends/qualcomm/runtime/backends/htpbackend/HtpBackendCache.cpp rename to backends/qualcomm/runtime/backends/htp/HtpBackendCache.cpp index 030b5666daf..3038a100d03 100644 --- a/backends/qualcomm/runtime/backends/htpbackend/HtpBackendCache.cpp +++ b/backends/qualcomm/runtime/backends/htp/HtpBackendCache.cpp @@ -5,7 +5,7 @@ * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
*/ -#include +#include #include "HTP/QnnHtpSystemContext.h" namespace executorch { diff --git a/backends/qualcomm/runtime/backends/htpbackend/HtpBackendCache.h b/backends/qualcomm/runtime/backends/htp/HtpBackendCache.h similarity index 100% rename from backends/qualcomm/runtime/backends/htpbackend/HtpBackendCache.h rename to backends/qualcomm/runtime/backends/htp/HtpBackendCache.h diff --git a/backends/qualcomm/runtime/backends/htpbackend/HtpContext.cpp b/backends/qualcomm/runtime/backends/htp/HtpContext.cpp similarity index 94% rename from backends/qualcomm/runtime/backends/htpbackend/HtpContext.cpp rename to backends/qualcomm/runtime/backends/htp/HtpContext.cpp index 50d299b55e9..0056a2c0917 100644 --- a/backends/qualcomm/runtime/backends/htpbackend/HtpContext.cpp +++ b/backends/qualcomm/runtime/backends/htp/HtpContext.cpp @@ -7,7 +7,7 @@ */ #include -#include +#include #include "HTP/QnnHtpCommon.h" diff --git a/backends/qualcomm/runtime/backends/htpbackend/HtpContext.h b/backends/qualcomm/runtime/backends/htp/HtpContext.h similarity index 91% rename from backends/qualcomm/runtime/backends/htpbackend/HtpContext.h rename to backends/qualcomm/runtime/backends/htp/HtpContext.h index 88660db080a..a0389ea5983 100644 --- a/backends/qualcomm/runtime/backends/htpbackend/HtpContext.h +++ b/backends/qualcomm/runtime/backends/htp/HtpContext.h @@ -10,7 +10,7 @@ #include #include -#include +#include namespace executorch { namespace backends { @@ -20,7 +20,7 @@ class QnnDlcManager; class HtpContext : public QnnContext { public: HtpContext( - const QnnImplementation& implementation, + QnnImplementation* implementation, QnnBackend* backend, QnnDevice* device, QnnBackendCache* cache, diff --git a/backends/qualcomm/runtime/backends/htpbackend/HtpContextCustomConfig.h b/backends/qualcomm/runtime/backends/htp/HtpContextCustomConfig.h similarity index 100% rename from backends/qualcomm/runtime/backends/htpbackend/HtpContextCustomConfig.h rename to backends/qualcomm/runtime/backends/htp/HtpContextCustomConfig.h diff --git a/backends/qualcomm/runtime/backends/htpbackend/HtpDevice.cpp b/backends/qualcomm/runtime/backends/htp/HtpDevice.cpp similarity index 99% rename from backends/qualcomm/runtime/backends/htpbackend/HtpDevice.cpp rename to backends/qualcomm/runtime/backends/htp/HtpDevice.cpp index 35a20048fc5..cb7700844b7 100644 --- a/backends/qualcomm/runtime/backends/htpbackend/HtpDevice.cpp +++ b/backends/qualcomm/runtime/backends/htp/HtpDevice.cpp @@ -7,7 +7,7 @@ */ #include -#include +#include #include "HTP/QnnHtpCommon.h" #include "Saver/QnnSaverCommon.h" @@ -376,7 +376,7 @@ void HtpDevice::ReleasePerformanceVote() { Error HtpDevice::AfterCreateDevice() { if (IsPerfModeEnabled()) { - const QnnInterface& qnn_interface = implementation_.GetQnnInterface(); + const QnnInterface& qnn_interface = implementation_->GetQnnInterface(); Qnn_ErrorHandle_t error = QNN_SUCCESS; // Get htp_perf_infra diff --git a/backends/qualcomm/runtime/backends/htpbackend/HtpDevice.h b/backends/qualcomm/runtime/backends/htp/HtpDevice.h similarity index 90% rename from backends/qualcomm/runtime/backends/htpbackend/HtpDevice.h rename to backends/qualcomm/runtime/backends/htp/HtpDevice.h index 9052deb6b52..abc5fde00d1 100644 --- a/backends/qualcomm/runtime/backends/htpbackend/HtpDevice.h +++ b/backends/qualcomm/runtime/backends/htp/HtpDevice.h @@ -9,8 +9,8 @@ #include #include -#include -#include +#include +#include #include #include "HTP/QnnHtpDevice.h" @@ -24,7 +24,7 @@ namespace qnn { class HtpDevice : public QnnDevice { public: 
HtpDevice( - const QnnImplementation& implementation, + QnnImplementation* implementation, QnnLogger* logger, const SocInfo* soc_info, const QnnExecuTorchHtpBackendOptions* htp_options) @@ -38,7 +38,7 @@ class HtpDevice : public QnnDevice { } ~HtpDevice(); - // Defines Qnn performance mode vote types for htpbackend + // Defines Qnn performance mode vote types for htp enum PerformanceModeVoteType { kNoVote = 0, kUpVote = 1, diff --git a/backends/qualcomm/runtime/backends/htpbackend/HtpDeviceCustomConfig.h b/backends/qualcomm/runtime/backends/htp/HtpDeviceCustomConfig.h similarity index 100% rename from backends/qualcomm/runtime/backends/htpbackend/HtpDeviceCustomConfig.h rename to backends/qualcomm/runtime/backends/htp/HtpDeviceCustomConfig.h diff --git a/backends/qualcomm/runtime/backends/htpbackend/HtpDevicePlatformInfoConfig.h b/backends/qualcomm/runtime/backends/htp/HtpDevicePlatformInfoConfig.h similarity index 100% rename from backends/qualcomm/runtime/backends/htpbackend/HtpDevicePlatformInfoConfig.h rename to backends/qualcomm/runtime/backends/htp/HtpDevicePlatformInfoConfig.h diff --git a/backends/qualcomm/runtime/backends/htpbackend/HtpGraph.cpp b/backends/qualcomm/runtime/backends/htp/HtpGraph.cpp similarity index 93% rename from backends/qualcomm/runtime/backends/htpbackend/HtpGraph.cpp rename to backends/qualcomm/runtime/backends/htp/HtpGraph.cpp index 29dcf0a58c3..6208febe61a 100644 --- a/backends/qualcomm/runtime/backends/htpbackend/HtpGraph.cpp +++ b/backends/qualcomm/runtime/backends/htp/HtpGraph.cpp @@ -6,7 +6,7 @@ * LICENSE file in the root directory of this source tree. */ -#include +#include namespace executorch { namespace backends { namespace qnn { diff --git a/backends/qualcomm/runtime/backends/htpbackend/HtpGraph.h b/backends/qualcomm/runtime/backends/htp/HtpGraph.h similarity index 90% rename from backends/qualcomm/runtime/backends/htpbackend/HtpGraph.h rename to backends/qualcomm/runtime/backends/htp/HtpGraph.h index c3add50d08b..d9e5964ddd7 100644 --- a/backends/qualcomm/runtime/backends/htpbackend/HtpGraph.h +++ b/backends/qualcomm/runtime/backends/htp/HtpGraph.h @@ -8,7 +8,7 @@ #pragma once #include -#include +#include #include @@ -19,7 +19,7 @@ namespace qnn { class HtpGraph : public QnnGraph { public: HtpGraph( - const QnnImplementation& implementation, + QnnImplementation* implementation, QnnBackend* backend, QnnContext* context, const QnnExecuTorchProfileLevel& profile_level, diff --git a/backends/qualcomm/runtime/backends/htp/HtpGraphCustomConfig.cpp b/backends/qualcomm/runtime/backends/htp/HtpGraphCustomConfig.cpp new file mode 100644 index 00000000000..17b8438880d --- /dev/null +++ b/backends/qualcomm/runtime/backends/htp/HtpGraphCustomConfig.cpp @@ -0,0 +1,77 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include +#include +#include +namespace executorch { +namespace backends { +namespace qnn { +std::vector +HtpGraphCustomConfig::CreateGraphCustomConfigCommon( + const SocInfo* qcom_target_soc_info, + float opt_level) { + std::vector ret; + QnnHtpGraph_CustomConfig_t* p_custom_config = nullptr; + + if (!htp_options_->use_conv_hmx()) { + p_custom_config = AllocGraphCustomConfig(); + p_custom_config->option = + QNN_HTP_GRAPH_CONFIG_OPTION_SHORT_DEPTH_CONV_ON_HMX_OFF; + p_custom_config->shortDepthConvOnHmxOff = true; + ret.push_back(static_cast(p_custom_config)); + } + + if (!htp_options_->use_fold_relu()) { + p_custom_config = AllocGraphCustomConfig(); + p_custom_config->option = + QNN_HTP_GRAPH_CONFIG_OPTION_FOLD_RELU_ACTIVATION_INTO_CONV_OFF; + p_custom_config->foldReluActivationIntoConvOff = true; + ret.push_back(static_cast(p_custom_config)); + } + + switch (htp_options_->precision()) { + case QnnExecuTorchHtpPrecision::kHtpFp16: + p_custom_config = AllocGraphCustomConfig(); + p_custom_config->option = QNN_HTP_GRAPH_CONFIG_OPTION_PRECISION; + p_custom_config->precision = QNN_PRECISION_FLOAT16; + ret.push_back(static_cast(p_custom_config)); + break; + case QnnExecuTorchHtpPrecision::kHtpQuantized: + default: + break; + } + + QNN_EXECUTORCH_LOG_INFO( + "Running level=%d optimization.", static_cast(opt_level)); + + p_custom_config = AllocGraphCustomConfig(); + p_custom_config->option = QNN_HTP_GRAPH_CONFIG_OPTION_OPTIMIZATION; + p_custom_config->optimizationOption.type = + QNN_HTP_GRAPH_OPTIMIZATION_TYPE_FINALIZE_OPTIMIZATION_FLAG; + p_custom_config->optimizationOption.floatValue = opt_level; + ret.push_back(static_cast(p_custom_config)); + + p_custom_config = AllocGraphCustomConfig(); + p_custom_config->option = QNN_HTP_GRAPH_CONFIG_OPTION_VTCM_SIZE; + p_custom_config->vtcmSizeInMB = + qcom_target_soc_info->htp_info()->vtcm_size_in_mb(); + ret.push_back(static_cast(p_custom_config)); + + p_custom_config = AllocGraphCustomConfig(); + p_custom_config->option = QNN_HTP_GRAPH_CONFIG_OPTION_OPTIMIZATION; + p_custom_config->optimizationOption.type = + QNN_HTP_GRAPH_OPTIMIZATION_TYPE_ENABLE_DLBC; + p_custom_config->optimizationOption.floatValue = + htp_options_->use_dlbc() ? 1.0 : 0.0; + ret.push_back(static_cast(p_custom_config)); + + return ret; +} +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/htpbackend/HtpGraphCustomConfig.h b/backends/qualcomm/runtime/backends/htp/HtpGraphCustomConfig.h similarity index 100% rename from backends/qualcomm/runtime/backends/htpbackend/HtpGraphCustomConfig.h rename to backends/qualcomm/runtime/backends/htp/HtpGraphCustomConfig.h diff --git a/backends/qualcomm/runtime/backends/htp/aarch64/HtpContextCustomConfig.cpp b/backends/qualcomm/runtime/backends/htp/aarch64/HtpContextCustomConfig.cpp new file mode 100644 index 00000000000..676795797f8 --- /dev/null +++ b/backends/qualcomm/runtime/backends/htp/aarch64/HtpContextCustomConfig.cpp @@ -0,0 +1,39 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
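CreateGraphCustomConfigCommon above keeps the target-independent HTP graph options in one place; the aarch64 and x86_64 files later in this diff are thin wrappers that differ only in the finalize-optimization level they pass (1 and 3 respectively). A compact sketch of that split, with an illustrative config type:

#include <vector>

// Illustrative config carrying only the field the wrappers disagree on.
struct GraphConfig {
  float optimizationLevel = 0.0f;
};

// Shared, target-independent construction logic.
std::vector<GraphConfig> CreateGraphConfigCommon(float opt_level) {
  return {GraphConfig{opt_level}};
}

// The aarch64 build passes level 1 ...
std::vector<GraphConfig> CreateGraphConfigAarch64() {
  return CreateGraphConfigCommon(/*opt_level=*/1.0f);
}

// ... while the x86_64 build passes level 3.
std::vector<GraphConfig> CreateGraphConfigX86() {
  return CreateGraphConfigCommon(/*opt_level=*/3.0f);
}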
+ */ + +#include +#include + +namespace executorch { +namespace backends { +namespace qnn { + +std::vector +HtpContextCustomConfig::CreateContextCustomConfig() { + std::vector ret; + QnnHtpContext_CustomConfig_t* p_custom_config = nullptr; + const HtpContext* htp_ctx = static_cast(context_); + + if (htp_options_->use_multi_contexts() && + htp_options_->max_sf_buf_size() != 0) { + p_custom_config = AllocContextCustomConfig(); + p_custom_config->option = + QNN_HTP_CONTEXT_CONFIG_OPTION_REGISTER_MULTI_CONTEXTS; + QnnHtpContext_GroupRegistration_t group_info; + group_info.firstGroupHandle = htp_ctx->GetSpillFillHandle(); + group_info.maxSpillFillBuffer = htp_options_->max_sf_buf_size(); + p_custom_config->groupRegistration = group_info; + ret.push_back(static_cast(p_custom_config)); + } + + return ret; +} + +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/htp/aarch64/HtpDeviceCustomConfig.cpp b/backends/qualcomm/runtime/backends/htp/aarch64/HtpDeviceCustomConfig.cpp new file mode 100644 index 00000000000..8207f5071ba --- /dev/null +++ b/backends/qualcomm/runtime/backends/htp/aarch64/HtpDeviceCustomConfig.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +namespace executorch { +namespace backends { +namespace qnn { +std::vector +HtpDeviceCustomConfig::CreateDeviceCustomConfig( + const SocInfo* /*qcom_target_soc_info*/) { + return {}; +} +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/htp/aarch64/HtpDevicePlatformInfoConfig.cpp b/backends/qualcomm/runtime/backends/htp/aarch64/HtpDevicePlatformInfoConfig.cpp new file mode 100644 index 00000000000..91221a78fd6 --- /dev/null +++ b/backends/qualcomm/runtime/backends/htp/aarch64/HtpDevicePlatformInfoConfig.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +namespace executorch { +namespace backends { +namespace qnn { +std::vector +HtpDevicePlatformInfoConfig::CreateDevicePlatformInfo( + const SocInfo* /*qcom_target_soc_info*/) { + return {}; +} +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/htp/aarch64/HtpGraphCustomConfig.cpp b/backends/qualcomm/runtime/backends/htp/aarch64/HtpGraphCustomConfig.cpp new file mode 100644 index 00000000000..faac23edc12 --- /dev/null +++ b/backends/qualcomm/runtime/backends/htp/aarch64/HtpGraphCustomConfig.cpp @@ -0,0 +1,21 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +namespace executorch { +namespace backends { +namespace qnn { +std::vector +HtpGraphCustomConfig::CreateGraphCustomConfig( + const SocInfo* qcom_target_soc_info) { + return CreateGraphCustomConfigCommon(qcom_target_soc_info, 1); +} +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/htp/x86_64/HtpContextCustomConfig.cpp b/backends/qualcomm/runtime/backends/htp/x86_64/HtpContextCustomConfig.cpp new file mode 100644 index 00000000000..4850afa14a2 --- /dev/null +++ b/backends/qualcomm/runtime/backends/htp/x86_64/HtpContextCustomConfig.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace executorch { +namespace backends { +namespace qnn { + +std::vector +HtpContextCustomConfig::CreateContextCustomConfig() { + std::vector ret; + QnnHtpContext_CustomConfig_t* p_custom_config = nullptr; + + if (htp_options_->use_weight_sharing()) { + p_custom_config = AllocContextCustomConfig(); + p_custom_config->option = + QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED; + p_custom_config->weightSharingEnabled = true; + ret.push_back(static_cast(p_custom_config)); + } + + return ret; +} + +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/htp/x86_64/HtpDeviceCustomConfig.cpp b/backends/qualcomm/runtime/backends/htp/x86_64/HtpDeviceCustomConfig.cpp new file mode 100644 index 00000000000..9afbf489bc1 --- /dev/null +++ b/backends/qualcomm/runtime/backends/htp/x86_64/HtpDeviceCustomConfig.cpp @@ -0,0 +1,28 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +namespace executorch { +namespace backends { +namespace qnn { +std::vector +HtpDeviceCustomConfig::CreateDeviceCustomConfig( + const SocInfo* qcom_target_soc_info) { + std::vector ret; + QnnHtpDevice_CustomConfig_t* p_custom_config = nullptr; + + p_custom_config = AllocDeviceCustomConfig(); + p_custom_config->option = QNN_HTP_DEVICE_CONFIG_OPTION_SOC; + p_custom_config->socModel = + static_cast(qcom_target_soc_info->soc_model()); + ret.push_back(static_cast(p_custom_config)); + + return ret; +} +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/htp/x86_64/HtpDevicePlatformInfoConfig.cpp b/backends/qualcomm/runtime/backends/htp/x86_64/HtpDevicePlatformInfoConfig.cpp new file mode 100644 index 00000000000..15c677e8a68 --- /dev/null +++ b/backends/qualcomm/runtime/backends/htp/x86_64/HtpDevicePlatformInfoConfig.cpp @@ -0,0 +1,61 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include +#include +namespace executorch { +namespace backends { +namespace qnn { +std::vector +HtpDevicePlatformInfoConfig::CreateDevicePlatformInfo( + const SocInfo* qcom_target_soc_info) { + std::vector ret; + QnnDevice_PlatformInfo_t* p_platform_info = nullptr; + QnnDevice_HardwareDeviceInfo_t* p_hw_device_info = nullptr; + QnnHtpDevice_DeviceInfoExtension_t* p_device_info_extension = nullptr; + QnnDevice_CoreInfo_t* p_core_info = nullptr; + + p_platform_info = AllocDevicePlatformInfo(); + p_platform_info->version = QNN_DEVICE_PLATFORM_INFO_VERSION_1; + p_platform_info->v1.numHwDevices = 1; + + p_hw_device_info = AllocHwDeviceInfo(); + p_hw_device_info->version = QNN_DEVICE_HARDWARE_DEVICE_INFO_VERSION_1; + p_hw_device_info->v1.deviceId = 0; + p_hw_device_info->v1.deviceType = 0; + p_hw_device_info->v1.numCores = 1; + + p_device_info_extension = AllocDeviceInfoExtension(); + p_device_info_extension->devType = QNN_HTP_DEVICE_TYPE_ON_CHIP; + p_device_info_extension->onChipDevice.vtcmSize = + qcom_target_soc_info->htp_info()->vtcm_size_in_mb(); + // Given by user, default value is unsigned pd + p_device_info_extension->onChipDevice.signedPdSupport = + htp_options_->pd_session() == QnnExecuTorchHtpPdSession::kHtpSignedPd; + p_device_info_extension->onChipDevice.socModel = + static_cast(qcom_target_soc_info->soc_model()); + p_device_info_extension->onChipDevice.arch = static_cast( + qcom_target_soc_info->htp_info()->htp_arch()); + // For Htp, dlbcSupport is true + p_device_info_extension->onChipDevice.dlbcSupport = true; + p_hw_device_info->v1.deviceInfoExtension = p_device_info_extension; + + p_core_info = AllocCoreInfo(); + p_core_info->version = QNN_DEVICE_CORE_INFO_VERSION_1; + p_core_info->v1.coreId = 0; + p_core_info->v1.coreType = 0; + p_core_info->v1.coreInfoExtension = nullptr; + p_hw_device_info->v1.cores = p_core_info; + + p_platform_info->v1.hwDevices = p_hw_device_info; + ret.push_back(p_platform_info); + + return ret; +} +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/htp/x86_64/HtpGraphCustomConfig.cpp b/backends/qualcomm/runtime/backends/htp/x86_64/HtpGraphCustomConfig.cpp new file mode 100644 index 00000000000..ec01f2bbfdd --- /dev/null +++ b/backends/qualcomm/runtime/backends/htp/x86_64/HtpGraphCustomConfig.cpp @@ -0,0 +1,21 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace executorch { +namespace backends { +namespace qnn { +std::vector +HtpGraphCustomConfig::CreateGraphCustomConfig( + const SocInfo* qcom_target_soc_info) { + return CreateGraphCustomConfigCommon(qcom_target_soc_info, 3); +} +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/htpbackend/HtpGraphCustomConfig.cpp b/backends/qualcomm/runtime/backends/htpbackend/HtpGraphCustomConfig.cpp deleted file mode 100644 index d43f8320285..00000000000 --- a/backends/qualcomm/runtime/backends/htpbackend/HtpGraphCustomConfig.cpp +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright (c) Qualcomm Innovation Center, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
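The x86_64 CreateDevicePlatformInfo above wires a three-level descriptor (platform -> hardware device -> core) out of structs whose storage is owned by the config object's Alloc* helpers, so the raw pointers it links together remain valid. A simplified stand-in of that wiring; the struct layouts are illustrative, not the real QnnDevice_* definitions.

#include <memory>
#include <vector>

// Reduced stand-ins for the QNN device-descriptor structs.
struct CoreInfo {
  int coreId = 0;
};
struct HwDeviceInfo {
  int deviceId = 0;
  const CoreInfo* cores = nullptr;
};
struct PlatformInfo {
  int numHwDevices = 0;
  const HwDeviceInfo* hwDevices = nullptr;
};

class PlatformInfoBuilder {
 public:
  // Each level is allocated into builder-owned storage before being linked,
  // so every pointer handed out stays valid while the builder lives.
  const PlatformInfo* Build() {
    cores_.push_back(std::make_unique<CoreInfo>());
    devices_.push_back(std::make_unique<HwDeviceInfo>());
    platforms_.push_back(std::make_unique<PlatformInfo>());

    devices_.back()->cores = cores_.back().get();
    platforms_.back()->numHwDevices = 1;
    platforms_.back()->hwDevices = devices_.back().get();
    return platforms_.back().get();
  }

 private:
  std::vector<std::unique_ptr<CoreInfo>> cores_;
  std::vector<std::unique_ptr<HwDeviceInfo>> devices_;
  std::vector<std::unique_ptr<PlatformInfo>> platforms_;
};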
- */ -#include -#include -#include -namespace executorch { -namespace backends { -namespace qnn { -std::vector -HtpGraphCustomConfig::CreateGraphCustomConfigCommon( - const SocInfo* qcom_target_soc_info, - float opt_level) { - std::vector ret; - QnnHtpGraph_CustomConfig_t* p_custom_config = nullptr; - - if (!htp_options_->use_conv_hmx()) { - p_custom_config = AllocGraphCustomConfig(); - p_custom_config->option = - QNN_HTP_GRAPH_CONFIG_OPTION_SHORT_DEPTH_CONV_ON_HMX_OFF; - p_custom_config->shortDepthConvOnHmxOff = true; - ret.push_back(static_cast(p_custom_config)); - } - - if (!htp_options_->use_fold_relu()) { - p_custom_config = AllocGraphCustomConfig(); - p_custom_config->option = - QNN_HTP_GRAPH_CONFIG_OPTION_FOLD_RELU_ACTIVATION_INTO_CONV_OFF; - p_custom_config->foldReluActivationIntoConvOff = true; - ret.push_back(static_cast(p_custom_config)); - } - - switch (htp_options_->precision()) { - case QnnExecuTorchHtpPrecision::kHtpFp16: - p_custom_config = AllocGraphCustomConfig(); - p_custom_config->option = QNN_HTP_GRAPH_CONFIG_OPTION_PRECISION; - p_custom_config->precision = QNN_PRECISION_FLOAT16; - ret.push_back(static_cast(p_custom_config)); - break; - case QnnExecuTorchHtpPrecision::kHtpQuantized: - default: - break; - } - - QNN_EXECUTORCH_LOG_INFO( - "Running level=%d optimization.", static_cast(opt_level)); - - p_custom_config = AllocGraphCustomConfig(); - p_custom_config->option = QNN_HTP_GRAPH_CONFIG_OPTION_OPTIMIZATION; - p_custom_config->optimizationOption.type = - QNN_HTP_GRAPH_OPTIMIZATION_TYPE_FINALIZE_OPTIMIZATION_FLAG; - p_custom_config->optimizationOption.floatValue = opt_level; - ret.push_back(static_cast(p_custom_config)); - - p_custom_config = AllocGraphCustomConfig(); - p_custom_config->option = QNN_HTP_GRAPH_CONFIG_OPTION_VTCM_SIZE; - p_custom_config->vtcmSizeInMB = - qcom_target_soc_info->htp_info()->vtcm_size_in_mb(); - ret.push_back(static_cast(p_custom_config)); - - p_custom_config = AllocGraphCustomConfig(); - p_custom_config->option = QNN_HTP_GRAPH_CONFIG_OPTION_OPTIMIZATION; - p_custom_config->optimizationOption.type = - QNN_HTP_GRAPH_OPTIMIZATION_TYPE_ENABLE_DLBC; - p_custom_config->optimizationOption.floatValue = - htp_options_->use_dlbc() ? 1.0 : 0.0; - ret.push_back(static_cast(p_custom_config)); - - return ret; -} -} // namespace qnn -} // namespace backends -} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/htpbackend/aarch64/HtpContextCustomConfig.cpp b/backends/qualcomm/runtime/backends/htpbackend/aarch64/HtpContextCustomConfig.cpp deleted file mode 100644 index 04a5d844dd0..00000000000 --- a/backends/qualcomm/runtime/backends/htpbackend/aarch64/HtpContextCustomConfig.cpp +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) Qualcomm Innovation Center, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include -#include - -namespace executorch { -namespace backends { -namespace qnn { - -std::vector -HtpContextCustomConfig::CreateContextCustomConfig() { - std::vector ret; - QnnHtpContext_CustomConfig_t* p_custom_config = nullptr; - const HtpContext* htp_ctx = static_cast(context_); - - if (htp_options_->use_multi_contexts() && - htp_options_->max_sf_buf_size() != 0) { - p_custom_config = AllocContextCustomConfig(); - p_custom_config->option = - QNN_HTP_CONTEXT_CONFIG_OPTION_REGISTER_MULTI_CONTEXTS; - QnnHtpContext_GroupRegistration_t group_info; - group_info.firstGroupHandle = htp_ctx->GetSpillFillHandle(); - group_info.maxSpillFillBuffer = htp_options_->max_sf_buf_size(); - p_custom_config->groupRegistration = group_info; - ret.push_back(static_cast(p_custom_config)); - } - - return ret; -} - -} // namespace qnn -} // namespace backends -} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/htpbackend/aarch64/HtpDeviceCustomConfig.cpp b/backends/qualcomm/runtime/backends/htpbackend/aarch64/HtpDeviceCustomConfig.cpp deleted file mode 100644 index 81ac4a14372..00000000000 --- a/backends/qualcomm/runtime/backends/htpbackend/aarch64/HtpDeviceCustomConfig.cpp +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright (c) Qualcomm Innovation Center, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ -#include -namespace executorch { -namespace backends { -namespace qnn { -std::vector -HtpDeviceCustomConfig::CreateDeviceCustomConfig( - const SocInfo* /*qcom_target_soc_info*/) { - return {}; -} -} // namespace qnn -} // namespace backends -} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/htpbackend/aarch64/HtpDevicePlatformInfoConfig.cpp b/backends/qualcomm/runtime/backends/htpbackend/aarch64/HtpDevicePlatformInfoConfig.cpp deleted file mode 100644 index c191791fa63..00000000000 --- a/backends/qualcomm/runtime/backends/htpbackend/aarch64/HtpDevicePlatformInfoConfig.cpp +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright (c) Qualcomm Innovation Center, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ -#include -namespace executorch { -namespace backends { -namespace qnn { -std::vector -HtpDevicePlatformInfoConfig::CreateDevicePlatformInfo( - const SocInfo* /*qcom_target_soc_info*/) { - return {}; -} -} // namespace qnn -} // namespace backends -} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/htpbackend/aarch64/HtpGraphCustomConfig.cpp b/backends/qualcomm/runtime/backends/htpbackend/aarch64/HtpGraphCustomConfig.cpp deleted file mode 100644 index 096fda7b059..00000000000 --- a/backends/qualcomm/runtime/backends/htpbackend/aarch64/HtpGraphCustomConfig.cpp +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright (c) Qualcomm Innovation Center, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include - -namespace executorch { -namespace backends { -namespace qnn { -std::vector -HtpGraphCustomConfig::CreateGraphCustomConfig( - const SocInfo* qcom_target_soc_info) { - return CreateGraphCustomConfigCommon(qcom_target_soc_info, 1); -} -} // namespace qnn -} // namespace backends -} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/htpbackend/x86_64/HtpContextCustomConfig.cpp b/backends/qualcomm/runtime/backends/htpbackend/x86_64/HtpContextCustomConfig.cpp deleted file mode 100644 index 1fc2940eaa7..00000000000 --- a/backends/qualcomm/runtime/backends/htpbackend/x86_64/HtpContextCustomConfig.cpp +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) Qualcomm Innovation Center, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include - -namespace executorch { -namespace backends { -namespace qnn { - -std::vector -HtpContextCustomConfig::CreateContextCustomConfig() { - std::vector ret; - QnnHtpContext_CustomConfig_t* p_custom_config = nullptr; - - if (htp_options_->use_weight_sharing()) { - p_custom_config = AllocContextCustomConfig(); - p_custom_config->option = - QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED; - p_custom_config->weightSharingEnabled = true; - ret.push_back(static_cast(p_custom_config)); - } - - return ret; -} - -} // namespace qnn -} // namespace backends -} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/htpbackend/x86_64/HtpDeviceCustomConfig.cpp b/backends/qualcomm/runtime/backends/htpbackend/x86_64/HtpDeviceCustomConfig.cpp deleted file mode 100644 index 154433c10b0..00000000000 --- a/backends/qualcomm/runtime/backends/htpbackend/x86_64/HtpDeviceCustomConfig.cpp +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) Qualcomm Innovation Center, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ -#include -namespace executorch { -namespace backends { -namespace qnn { -std::vector -HtpDeviceCustomConfig::CreateDeviceCustomConfig( - const SocInfo* qcom_target_soc_info) { - std::vector ret; - QnnHtpDevice_CustomConfig_t* p_custom_config = nullptr; - - p_custom_config = AllocDeviceCustomConfig(); - p_custom_config->option = QNN_HTP_DEVICE_CONFIG_OPTION_SOC; - p_custom_config->socModel = - static_cast(qcom_target_soc_info->soc_model()); - ret.push_back(static_cast(p_custom_config)); - - return ret; -} -} // namespace qnn -} // namespace backends -} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/htpbackend/x86_64/HtpDevicePlatformInfoConfig.cpp b/backends/qualcomm/runtime/backends/htpbackend/x86_64/HtpDevicePlatformInfoConfig.cpp deleted file mode 100644 index b025f0b2aa6..00000000000 --- a/backends/qualcomm/runtime/backends/htpbackend/x86_64/HtpDevicePlatformInfoConfig.cpp +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright (c) Qualcomm Innovation Center, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ -#include -#include -namespace executorch { -namespace backends { -namespace qnn { -std::vector -HtpDevicePlatformInfoConfig::CreateDevicePlatformInfo( - const SocInfo* qcom_target_soc_info) { - std::vector ret; - QnnDevice_PlatformInfo_t* p_platform_info = nullptr; - QnnDevice_HardwareDeviceInfo_t* p_hw_device_info = nullptr; - QnnHtpDevice_DeviceInfoExtension_t* p_device_info_extension = nullptr; - QnnDevice_CoreInfo_t* p_core_info = nullptr; - - p_platform_info = AllocDevicePlatformInfo(); - p_platform_info->version = QNN_DEVICE_PLATFORM_INFO_VERSION_1; - p_platform_info->v1.numHwDevices = 1; - - p_hw_device_info = AllocHwDeviceInfo(); - p_hw_device_info->version = QNN_DEVICE_HARDWARE_DEVICE_INFO_VERSION_1; - p_hw_device_info->v1.deviceId = 0; - p_hw_device_info->v1.deviceType = 0; - p_hw_device_info->v1.numCores = 1; - - p_device_info_extension = AllocDeviceInfoExtension(); - p_device_info_extension->devType = QNN_HTP_DEVICE_TYPE_ON_CHIP; - p_device_info_extension->onChipDevice.vtcmSize = - qcom_target_soc_info->htp_info()->vtcm_size_in_mb(); - // Given by user, default value is unsigned pd - p_device_info_extension->onChipDevice.signedPdSupport = - htp_options_->pd_session() == QnnExecuTorchHtpPdSession::kHtpSignedPd; - p_device_info_extension->onChipDevice.socModel = - static_cast(qcom_target_soc_info->soc_model()); - p_device_info_extension->onChipDevice.arch = static_cast( - qcom_target_soc_info->htp_info()->htp_arch()); - // For Htp, dlbcSupport is true - p_device_info_extension->onChipDevice.dlbcSupport = true; - p_hw_device_info->v1.deviceInfoExtension = p_device_info_extension; - - p_core_info = AllocCoreInfo(); - p_core_info->version = QNN_DEVICE_CORE_INFO_VERSION_1; - p_core_info->v1.coreId = 0; - p_core_info->v1.coreType = 0; - p_core_info->v1.coreInfoExtension = nullptr; - p_hw_device_info->v1.cores = p_core_info; - - p_platform_info->v1.hwDevices = p_hw_device_info; - ret.push_back(p_platform_info); - - return ret; -} -} // namespace qnn -} // namespace backends -} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/htpbackend/x86_64/HtpGraphCustomConfig.cpp b/backends/qualcomm/runtime/backends/htpbackend/x86_64/HtpGraphCustomConfig.cpp deleted file mode 100644 index 330ca43e20b..00000000000 --- a/backends/qualcomm/runtime/backends/htpbackend/x86_64/HtpGraphCustomConfig.cpp +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright (c) Qualcomm Innovation Center, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include - -namespace executorch { -namespace backends { -namespace qnn { -std::vector -HtpGraphCustomConfig::CreateGraphCustomConfig( - const SocInfo* qcom_target_soc_info) { - return CreateGraphCustomConfigCommon(qcom_target_soc_info, 3); -} -} // namespace qnn -} // namespace backends -} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/irbackend/IrBackend.h b/backends/qualcomm/runtime/backends/ir/IrBackend.h similarity index 93% rename from backends/qualcomm/runtime/backends/irbackend/IrBackend.h rename to backends/qualcomm/runtime/backends/ir/IrBackend.h index ddeb3a24460..72bb59c84f9 100644 --- a/backends/qualcomm/runtime/backends/irbackend/IrBackend.h +++ b/backends/qualcomm/runtime/backends/ir/IrBackend.h @@ -18,7 +18,7 @@ namespace backends { namespace qnn { class IrBackend : public QnnBackend { public: - IrBackend(const QnnImplementation& implementation, QnnLogger* logger) + IrBackend(QnnImplementation* implementation, QnnLogger* logger) : QnnBackend(implementation, logger) {} ~IrBackend() {} diff --git a/backends/qualcomm/runtime/backends/irbackend/IrContext.h b/backends/qualcomm/runtime/backends/ir/IrContext.h similarity index 100% rename from backends/qualcomm/runtime/backends/irbackend/IrContext.h rename to backends/qualcomm/runtime/backends/ir/IrContext.h diff --git a/backends/qualcomm/runtime/backends/ir/aarch64/IrContext.cpp b/backends/qualcomm/runtime/backends/ir/aarch64/IrContext.cpp new file mode 100644 index 00000000000..12a27b19ccd --- /dev/null +++ b/backends/qualcomm/runtime/backends/ir/aarch64/IrContext.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +namespace executorch { +namespace backends { +namespace qnn { + +using executorch::runtime::Error; + +Error IrContext::GetContextBinary( + QnnExecuTorchContextBinary& qnn_executorch_context_binary) { + return Error::Ok; +} + +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/ir/aarch64/QnnDlcManager.cpp b/backends/qualcomm/runtime/backends/ir/aarch64/QnnDlcManager.cpp new file mode 100644 index 00000000000..6512b5730b5 --- /dev/null +++ b/backends/qualcomm/runtime/backends/ir/aarch64/QnnDlcManager.cpp @@ -0,0 +1,140 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace qnn { + +QnnDlcManager::QnnDlcManager( + const QnnExecuTorchContextBinary& qnn_context_blob, + const QnnExecuTorchOptions* options) + : qnn_context_blob_(qnn_context_blob), options_(options) { + if (options_ == nullptr) { + QNN_EXECUTORCH_LOG_ERROR( + "Fail to create QnnDlcManager, options is nullptr"); + } +} + +Error QnnDlcManager::LoadQnnIrLibrary() { + return Error::Ok; +} + +Error QnnDlcManager::Create() { + return Error::Ok; +} + +Error QnnDlcManager::Configure(const std::vector& graph_names) { + return Error::Ok; +} + +Error QnnDlcManager::SetUpDlcEnvironment( + const Qnn_Version_t& coreApiVersion, + const std::vector& graph_names) { + return Error::Ok; +} + +Error QnnDlcManager::RegisterGraphsFromDLC( + QnnImplementation* implementation, + QnnBackend* backend, + QnnContext* context, + QnnBackendCache* cache) { + void* lib_handle = dlopen(dlc_lib_, RTLD_NOW | RTLD_LOCAL); + if (lib_handle == nullptr) { + QNN_EXECUTORCH_LOG_ERROR( + "Cannot Open lib %s, with error: %s", dlc_lib_, dlerror()); + return Error::Internal; + } + QnnModel_composeGraphsFromDlc composeGraphsFromDlc = + loadQnnFunction( + lib_handle, "QnnModel_composeGraphsFromDlc"); + if (composeGraphsFromDlc == nullptr) { + QNN_EXECUTORCH_LOG_ERROR( + "Cannot load symbol " + "QnnModel_composeGraphsFromDlc : %s", + dlerror()); + return Error::Internal; + } + + // memfd_create on android api level 30 and above + int fd = -1; +#ifdef __ANDROID__ +#if __ANDROID_API__ >= 30 + fd = memfd_create("tmp.dlc", 0); +#endif +#endif + if (fd == -1) { + QNN_EXECUTORCH_LOG_ERROR("memfd_create fail"); + return Error::Internal; + } + + if (ftruncate(fd, qnn_context_blob_.nbytes) == -1) { + QNN_EXECUTORCH_LOG_ERROR("ftruncate fail"); + close(fd); + return Error::Internal; + } + + void* addr = mmap( + NULL, + qnn_context_blob_.nbytes, + PROT_READ | PROT_WRITE, + MAP_SHARED, + fd, + 0); + if (addr == MAP_FAILED) { + QNN_EXECUTORCH_LOG_ERROR("mmap"); + close(fd); + return Error::Internal; + } + + memcpy(addr, qnn_context_blob_.buffer, qnn_context_blob_.nbytes); + + char dlc_path[256]; + snprintf(dlc_path, sizeof(dlc_path), "/proc/self/fd/%d", fd); + + const QNN_INTERFACE_VER_TYPE& interfaceVer = + implementation->GetQnnInterface().GetInterfaceVer(); + + if (composeGraphsFromDlc( + /*backendHandle=*/backend->GetHandle(), + /*interface=*/interfaceVer, + /*contextHandle=*/context->GetHandle(), + /*graphsConfigInfo=*/nullptr, + /*dlcPath=*/dlc_path, + /*numGraphsConfigInfo=*/0, + /*graphsInfo=*/&qnn_dlc_graph_info_, + /*numGraphsInfo=*/&qnn_dlc_graph_info_num_, + /*debug=*/false, + /*logCallback=*/nullptr, + /*maxLogLevel=*/QNN_LOG_LEVEL_VERBOSE) != + qnn_wrapper_api::ModelError_t::MODEL_NO_ERROR) { + QNN_EXECUTORCH_LOG_ERROR("Failed to open Dlc"); + return Error::Internal; + } + munmap(addr, qnn_context_blob_.nbytes); + close(fd); + dlclose(lib_handle); + + for (uint32_t i = 0; i < qnn_dlc_graph_info_num_; ++i) { + auto& graphInfo = (*qnn_dlc_graph_info_)[i]; + cache->SetGraphNames(graphInfo.graphName); + } + + return Error::Ok; +} + +void QnnDlcManager::Destroy() {} + +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/ir/x86_64/IrContext.cpp b/backends/qualcomm/runtime/backends/ir/x86_64/IrContext.cpp new file mode 100644 index 00000000000..cf5df3de8e9 --- /dev/null +++ b/backends/qualcomm/runtime/backends/ir/x86_64/IrContext.cpp @@ -0,0 +1,43 @@ +/* + * 
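RegisterGraphsFromDLC above feeds the in-memory DLC blob to an API that only accepts a file path by copying it into an anonymous memfd and passing /proc/self/fd/<fd>. A self-contained, Linux-only sketch of that technique (memfd_create needs Linux 3.17+/glibc 2.27+, or Android API level 30+ as the code's guard notes); the function name is illustrative.

#ifndef _GNU_SOURCE
#define _GNU_SOURCE  // for memfd_create with glibc
#endif
#include <sys/mman.h>
#include <unistd.h>

#include <cstdio>
#include <cstring>
#include <string>

// Copies `nbytes` of `data` into an anonymous memfd and returns a
// /proc/self/fd/<fd> path that file-path-only APIs can open. On success the
// caller owns `out_fd` and must close() it once the path is no longer used;
// on failure the function returns an empty string and sets out_fd to -1.
std::string BlobToProcPath(const void* data, std::size_t nbytes, int& out_fd) {
  out_fd = memfd_create("tmp.blob", 0);
  if (out_fd == -1) {
    return "";
  }
  if (ftruncate(out_fd, static_cast<off_t>(nbytes)) == -1) {
    close(out_fd);
    out_fd = -1;
    return "";
  }
  void* addr =
      mmap(nullptr, nbytes, PROT_READ | PROT_WRITE, MAP_SHARED, out_fd, 0);
  if (addr == MAP_FAILED) {
    close(out_fd);
    out_fd = -1;
    return "";
  }
  std::memcpy(addr, data, nbytes);
  munmap(addr, nbytes);

  char path[64];
  std::snprintf(path, sizeof(path), "/proc/self/fd/%d", out_fd);
  return std::string(path);
}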
Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +namespace executorch { +namespace backends { +namespace qnn { + +using executorch::runtime::Error; + +Error IrContext::GetContextBinary( + QnnExecuTorchContextBinary& qnn_executorch_context_binary) { + // read Dlc and write to buffer + std::string dlc_name = GetGraphNames()[0] + ".dlc"; + std::ifstream dlc_file(dlc_name, std::ios::binary | std::ios::ate); + if (dlc_file.is_open()) { + std::streamsize size = dlc_file.tellg(); + dlc_file.seekg(0, std::ios::beg); + + buffer_ = std::vector(size); + dlc_file.read(buffer_.data(), size); + dlc_file.close(); + qnn_executorch_context_binary.buffer = buffer_.data(); + qnn_executorch_context_binary.nbytes = size; + return Error::Ok; + } else { + QNN_EXECUTORCH_LOG_ERROR( + "Unable to open dlc file %s for building QnnExecuTorchContextBinary", + dlc_name.c_str()); + } + return Error::Internal; +} +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/ir/x86_64/QnnDlcManager.cpp b/backends/qualcomm/runtime/backends/ir/x86_64/QnnDlcManager.cpp new file mode 100644 index 00000000000..ee7e79cfa80 --- /dev/null +++ b/backends/qualcomm/runtime/backends/ir/x86_64/QnnDlcManager.cpp @@ -0,0 +1,138 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include + +namespace executorch { +namespace backends { +namespace qnn { + +QnnDlcManager::QnnDlcManager( + const QnnExecuTorchContextBinary& qnn_context_blob, + const QnnExecuTorchOptions* options) + : qnn_context_blob_(qnn_context_blob), options_(options) { + if (options_ == nullptr) { + QNN_EXECUTORCH_LOG_ERROR( + "Fail to create QnnDlcManager, options is nullptr"); + } +} + +Error QnnDlcManager::LoadQnnIrLibrary() { + backend_bundle_ptr_->implementation = + std::make_unique(library_name_); + Error ret = backend_bundle_ptr_->implementation->Load(nullptr); + return ret; +} + +Error QnnDlcManager::Create() { + backend_bundle_ptr_->qnn_backend_ptr = std::make_unique( + backend_bundle_ptr_->implementation.get(), + backend_bundle_ptr_->qnn_logger_ptr.get()); + + backend_bundle_ptr_->qnn_device_ptr = std::make_unique( + backend_bundle_ptr_->implementation.get(), + backend_bundle_ptr_->qnn_logger_ptr.get()); + + backend_params_ptr_->qnn_backend_cache_ptr_ = + std::make_unique(qnn_context_blob_); + + backend_params_ptr_->qnn_context_ptr_ = std::make_unique( + backend_bundle_ptr_->implementation.get(), + backend_bundle_ptr_->qnn_backend_ptr.get(), + backend_bundle_ptr_->qnn_device_ptr.get(), + backend_params_ptr_->qnn_backend_cache_ptr_.get(), + nullptr); + + backend_params_ptr_->qnn_graph_ptr_ = std::make_unique( + backend_bundle_ptr_->implementation.get(), + backend_bundle_ptr_->qnn_backend_ptr.get(), + backend_params_ptr_->qnn_context_ptr_.get(), + get_option(options_->profile_level())); + backend_params_ptr_->backend_init_state_ = + BackendInitializeState::INITIALIZED; + return backend_bundle_ptr_->qnn_backend_ptr->VerifyQNNSDKVersion(); +} + +Error QnnDlcManager::Configure(const std::vector& graph_names) { + ET_CHECK_OR_RETURN_ERROR( + backend_params_ptr_ != nullptr, Internal, "Failed to load Qnn backend."); + 
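On x86_64, IrContext::GetContextBinary above reads the emitted .dlc back into an owned buffer using the open-at-end / tellg / seekg / read idiom. A standalone version of that read-whole-file helper (the function name is illustrative):

#include <cstddef>
#include <fstream>
#include <string>
#include <vector>

// Reads the whole file at `path` into `out`; returns false if the file
// cannot be opened or the read fails.
bool ReadWholeFile(const std::string& path, std::vector<char>& out) {
  std::ifstream file(path, std::ios::binary | std::ios::ate);
  if (!file.is_open()) {
    return false;
  }
  const std::streamsize size = file.tellg();  // opened at end, so tellg() is the size
  file.seekg(0, std::ios::beg);
  out.resize(static_cast<std::size_t>(size));
  return static_cast<bool>(file.read(out.data(), size));
}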
ET_CHECK_OR_RETURN_ERROR( + backend_params_ptr_->qnn_backend_cache_ptr_->Configure(graph_names) == + Error::Ok, + Internal, + "Fail to configure Qnn backend cache"); + ET_CHECK_OR_RETURN_ERROR( + backend_bundle_ptr_->qnn_backend_ptr->Configure( + options_->op_package_options()) == Error::Ok, + Internal, + "Fail to configure Qnn backend"); + ET_CHECK_OR_RETURN_ERROR( + backend_params_ptr_->qnn_context_ptr_->Configure() == Error::Ok, + Internal, + "Fail to configure Qnn context"); + for (const std::string& graph_name : + backend_params_ptr_->qnn_context_ptr_->GetGraphNames()) { + ET_CHECK_OR_RETURN_ERROR( + backend_params_ptr_->qnn_graph_ptr_->Configure(graph_name) == Error::Ok, + Internal, + "Fail to configure Qnn graph"); + } + backend_params_ptr_->backend_init_state_ = + BackendInitializeState::INITIALIZED; + + return Error::Ok; +} + +Error QnnDlcManager::SetUpDlcEnvironment( + const Qnn_Version_t& coreApiVersion, + const std::vector<std::string>& graph_names) { + ET_CHECK_MSG( + (coreApiVersion.major >= 2 && coreApiVersion.minor >= 23), + "Qnn API version %u.%u.%u is not supported for the Qnn IR backend. The minimum supported version is 2.23.0 (QNN SDK 2.30.0).", + coreApiVersion.major, + coreApiVersion.minor, + coreApiVersion.patch); + + ET_CHECK_OR_RETURN_ERROR( + LoadQnnIrLibrary() == Error::Ok, + Internal, + "Fail to load Qnn IR library."); + + backend_bundle_ptr_->qnn_logger_ptr = std::make_unique( + backend_bundle_ptr_->implementation.get(), + LoggingCallback, + get_option(options_->log_level())); + + ET_CHECK_OR_RETURN_ERROR( + Create() == Error::Ok, Internal, "Failed to load Qnn IR backend."); + + ET_CHECK_OR_RETURN_ERROR( + Configure(graph_names) == Error::Ok, + Internal, + "Fail to configure IR backend."); + + return Error::Ok; +} + +Error QnnDlcManager::RegisterGraphsFromDLC( + QnnImplementation* implementation, + QnnBackend* backend, + QnnContext* context, + QnnBackendCache* cache) { + return Error::Ok; +} + +void QnnDlcManager::Destroy() { + backend_params_ptr_.reset(new BackendConfigParameters()); + backend_bundle_ptr_.reset(new QnnBackendBundle()); +} + +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/irbackend/aarch64/IrContext.cpp b/backends/qualcomm/runtime/backends/irbackend/aarch64/IrContext.cpp deleted file mode 100644 index 44ce8de8f46..00000000000 --- a/backends/qualcomm/runtime/backends/irbackend/aarch64/IrContext.cpp +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) Qualcomm Innovation Center, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include - -#include -#include - -namespace executorch { -namespace backends { -namespace qnn { - -using executorch::runtime::Error; - -Error IrContext::GetContextBinary( - QnnExecuTorchContextBinary& qnn_executorch_context_binary) { - return Error::Ok; -} - -} // namespace qnn -} // namespace backends -} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/irbackend/aarch64/QnnDlcManager.cpp b/backends/qualcomm/runtime/backends/irbackend/aarch64/QnnDlcManager.cpp deleted file mode 100644 index d8c09dabcbe..00000000000 --- a/backends/qualcomm/runtime/backends/irbackend/aarch64/QnnDlcManager.cpp +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Copyright (c) Qualcomm Innovation Center, Inc. - * All rights reserved.
- * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ -#include -#include -#include -#include -#include -#include - -namespace executorch { -namespace backends { -namespace qnn { - -QnnDlcManager::QnnDlcManager( - const QnnExecuTorchContextBinary& qnn_context_blob, - const QnnExecuTorchOptions* options) - : qnn_loaded_backend_(""), - qnn_context_blob_(qnn_context_blob), - options_(options) { - if (options_ == nullptr) { - QNN_EXECUTORCH_LOG_ERROR( - "Fail to create QnnDlcManager, options is nullptr"); - } -} - -Error QnnDlcManager::LoadQnnIrLibrary() { - return Error::Ok; -} - -Error QnnDlcManager::Create() { - return Error::Ok; -} - -Error QnnDlcManager::Configure() { - return Error::Ok; -} - -Error QnnDlcManager::SetUpDlcEnvironment(const Qnn_Version_t& coreApiVersion) { - return Error::Ok; -} - -Error QnnDlcManager::RegisterGraphsFromDLC( - const QnnImplementation& implementation, - QnnBackend* backend, - QnnContext* context, - QnnBackendCache* cache) { - void* lib_handle = dlopen(dlc_lib_, RTLD_NOW | RTLD_LOCAL); - if (lib_handle == nullptr) { - QNN_EXECUTORCH_LOG_ERROR( - "Cannot Open lib %s, with error: %s", dlc_lib_, dlerror()); - return Error::Internal; - } - QnnModel_composeGraphsFromDlc composeGraphsFromDlc = - loadQnnFunction( - lib_handle, "QnnModel_composeGraphsFromDlc"); - if (composeGraphsFromDlc == nullptr) { - QNN_EXECUTORCH_LOG_ERROR( - "Cannot load symbol " - "QnnModel_composeGraphsFromDlc : %s", - dlerror()); - return Error::Internal; - } - - // memfd_create on android api level 30 and above - int fd = -1; -#ifdef __ANDROID__ -#if __ANDROID_API__ >= 30 - fd = memfd_create("tmp.dlc", 0); -#endif -#endif - if (fd == -1) { - QNN_EXECUTORCH_LOG_ERROR("memfd_create fail"); - return Error::Internal; - } - - if (ftruncate(fd, qnn_context_blob_.nbytes) == -1) { - QNN_EXECUTORCH_LOG_ERROR("ftruncate fail"); - close(fd); - return Error::Internal; - } - - void* addr = mmap( - NULL, - qnn_context_blob_.nbytes, - PROT_READ | PROT_WRITE, - MAP_SHARED, - fd, - 0); - if (addr == MAP_FAILED) { - QNN_EXECUTORCH_LOG_ERROR("mmap"); - close(fd); - return Error::Internal; - } - - memcpy(addr, qnn_context_blob_.buffer, qnn_context_blob_.nbytes); - - char dlc_path[256]; - snprintf(dlc_path, sizeof(dlc_path), "/proc/self/fd/%d", fd); - - const QNN_INTERFACE_VER_TYPE& interfaceVer = - implementation.GetQnnInterface().GetInterfaceVer(); - - if (composeGraphsFromDlc( - /*backendHandle=*/backend->GetHandle(), - /*interface=*/interfaceVer, - /*contextHandle=*/context->GetHandle(), - /*graphsConfigInfo=*/nullptr, - /*dlcPath=*/dlc_path, - /*numGraphsConfigInfo=*/0, - /*graphsInfo=*/&qnn_dlc_graph_info_, - /*numGraphsInfo=*/&qnn_dlc_graph_info_num_, - /*debug=*/false, - /*logCallback=*/nullptr, - /*maxLogLevel=*/QNN_LOG_LEVEL_VERBOSE) != - qnn_wrapper_api::ModelError_t::MODEL_NO_ERROR) { - QNN_EXECUTORCH_LOG_ERROR("Failed to open Dlc"); - return Error::Internal; - } - munmap(addr, qnn_context_blob_.nbytes); - close(fd); - dlclose(lib_handle); - - for (uint32_t i = 0; i < qnn_dlc_graph_info_num_; ++i) { - auto& graphInfo = (*qnn_dlc_graph_info_)[i]; - cache->SetGraphNames(graphInfo.graphName); - } - - return Error::Ok; -} - -void QnnDlcManager::ResetBackendParams() {} -void QnnDlcManager::ResetLogger() {} -void QnnDlcManager::TerminateAllBackends() {} - -} // namespace qnn -} // namespace backends -} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/irbackend/x86_64/IrContext.cpp 
b/backends/qualcomm/runtime/backends/irbackend/x86_64/IrContext.cpp deleted file mode 100644 index f167aae9319..00000000000 --- a/backends/qualcomm/runtime/backends/irbackend/x86_64/IrContext.cpp +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright (c) Qualcomm Innovation Center, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include - -#include -#include -namespace executorch { -namespace backends { -namespace qnn { - -using executorch::runtime::Error; - -Error IrContext::GetContextBinary( - QnnExecuTorchContextBinary& qnn_executorch_context_binary) { - // read Dlc and write to buffer - std::string dlc_name = GetGraphNames()[0] + ".dlc"; - std::ifstream dlc_file(dlc_name, std::ios::binary | std::ios::ate); - if (dlc_file.is_open()) { - std::streamsize size = dlc_file.tellg(); - dlc_file.seekg(0, std::ios::beg); - - buffer_ = std::vector(size); - dlc_file.read(buffer_.data(), size); - dlc_file.close(); - qnn_executorch_context_binary.buffer = buffer_.data(); - qnn_executorch_context_binary.nbytes = size; - return Error::Ok; - } else { - QNN_EXECUTORCH_LOG_ERROR( - "Unable to open dlc file %s for building QnnExecuTorchContextBinary", - dlc_name.c_str()); - } - return Error::Internal; -} -} // namespace qnn -} // namespace backends -} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/irbackend/x86_64/QnnDlcManager.cpp b/backends/qualcomm/runtime/backends/irbackend/x86_64/QnnDlcManager.cpp deleted file mode 100644 index 280751cf160..00000000000 --- a/backends/qualcomm/runtime/backends/irbackend/x86_64/QnnDlcManager.cpp +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Copyright (c) Qualcomm Innovation Center, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ -#include -#include -#include - -namespace executorch { -namespace backends { -namespace qnn { - -QnnDlcManager::QnnDlcManager( - const QnnExecuTorchContextBinary& qnn_context_blob, - const QnnExecuTorchOptions* options) - : qnn_loaded_backend_(""), - qnn_context_blob_(qnn_context_blob), - options_(options) { - if (options_ == nullptr) { - QNN_EXECUTORCH_LOG_ERROR( - "Fail to create QnnDlcManager, options is nullptr"); - } -} - -Error QnnDlcManager::LoadQnnIrLibrary() { - qnn_loaded_backend_ = QnnImplementation(library_name_); - Error ret = qnn_loaded_backend_.Load(nullptr); - return ret; -} - -Error QnnDlcManager::Create() { - backend_params_ptr_->qnn_backend_ptr_ = - std::make_unique(qnn_loaded_backend_, logger_.get()); - - backend_params_ptr_->qnn_device_ptr_ = - std::make_unique(qnn_loaded_backend_, logger_.get()); - - backend_params_ptr_->qnn_backend_cache_ptr_ = - std::make_unique(qnn_context_blob_); - - backend_params_ptr_->qnn_context_ptr_ = std::make_unique( - qnn_loaded_backend_, - backend_params_ptr_->qnn_backend_ptr_.get(), - backend_params_ptr_->qnn_device_ptr_.get(), - backend_params_ptr_->qnn_backend_cache_ptr_.get(), - nullptr); - - backend_params_ptr_->qnn_graph_ptr_ = std::make_unique( - qnn_loaded_backend_, - backend_params_ptr_->qnn_backend_ptr_.get(), - backend_params_ptr_->qnn_context_ptr_.get(), - get_option(options_->profile_level())); - backend_params_ptr_->backend_init_state_ = - BackendInitializeState::INITIALIZED; - return backend_params_ptr_->qnn_backend_ptr_->VerifyQNNSDKVersion(); -} - -Error QnnDlcManager::Configure() { - ET_CHECK_OR_RETURN_ERROR( - backend_params_ptr_ != nullptr, Internal, "Failed to load Qnn backend."); - std::vector graph_names; - for (auto name : *options_->graph_name()) { - graph_names.emplace_back(name->str()); - } - ET_CHECK_OR_RETURN_ERROR( - backend_params_ptr_->qnn_backend_cache_ptr_->Configure(graph_names) == - Error::Ok, - Internal, - "Fail to configure Qnn backend cache"); - ET_CHECK_OR_RETURN_ERROR( - backend_params_ptr_->qnn_backend_ptr_->Configure( - options_->op_package_options()) == Error::Ok, - Internal, - "Fail to configure Qnn backend"); - ET_CHECK_OR_RETURN_ERROR( - backend_params_ptr_->qnn_context_ptr_->Configure() == Error::Ok, - Internal, - "Fail to configure Qnn context"); - for (const std::string& graph_name : - backend_params_ptr_->qnn_context_ptr_->GetGraphNames()) { - ET_CHECK_OR_RETURN_ERROR( - backend_params_ptr_->qnn_graph_ptr_->Configure(graph_name) == Error::Ok, - Internal, - "Fail to configure Qnn graph"); - } - backend_params_ptr_->backend_init_state_ = - BackendInitializeState::INITIALIZED; - - return Error::Ok; -} - -Error QnnDlcManager::SetUpDlcEnvironment(const Qnn_Version_t& coreApiVersion) { - ET_CHECK_MSG( - (coreApiVersion.major >= 2 && coreApiVersion.minor >= 23), - "Qnn API version %u.%u.%u is not supported for Qnn IR backend, The minimum supported version is 2.23.0 or QNN_SDK version 2.30.0", - coreApiVersion.major, - coreApiVersion.minor, - coreApiVersion.patch); - - ET_CHECK_OR_RETURN_ERROR( - LoadQnnIrLibrary() == Error::Ok, - Internal, - "Fail to Load Qnn IR library."); - - logger_ = std::make_unique( - qnn_loaded_backend_, LoggingCallback, get_option(options_->log_level())); - - ET_CHECK_OR_RETURN_ERROR( - Create() == Error::Ok, Internal, "Failed to load Qnn IR backend."); - - ET_CHECK_OR_RETURN_ERROR( - Configure() == Error::Ok, Internal, "Fail to configure IR backend."); - - return Error::Ok; -} - -Error QnnDlcManager::RegisterGraphsFromDLC( - const QnnImplementation& 
implementation, - QnnBackend* backend, - QnnContext* context, - QnnBackendCache* cache) { - return Error::Ok; -} - -void QnnDlcManager::ResetBackendParams() { - backend_params_ptr_.reset(new BackendConfigParameters()); -} - -void QnnDlcManager::ResetLogger() { - logger_.reset(); -} - -void QnnDlcManager::TerminateAllBackends() { - qnn_loaded_backend_.TerminateAllBackends(); -} - -} // namespace qnn -} // namespace backends -} // namespace executorch diff --git a/backends/qualcomm/runtime/targets.bzl b/backends/qualcomm/runtime/targets.bzl index db3706ba221..85cece2bae7 100644 --- a/backends/qualcomm/runtime/targets.bzl +++ b/backends/qualcomm/runtime/targets.bzl @@ -44,10 +44,12 @@ def define_common_targets(): [ "*.cpp", "backends/*.cpp", - "backends/irbackend/*.cpp", - "backends/htpbackend/*.cpp", - ] + (["backends/htpbackend/x86_64/*.cpp"] if include_aot_qnn_lib else ["backends/htpbackend/aarch64/*.cpp"]) + ( - ["backends/irbackend/x86_64/*.cpp"] if include_aot_qnn_lib else ["backends/irbackend/aarch64/*.cpp"] + "backends/gpu/*.cpp", + "backends/htp/*.cpp", + "backends/ir/*.cpp", + ] + (["backends/gpu/x86_64/*.cpp"] if include_aot_qnn_lib else ["backends/gpu/aarch64/*.cpp"]) + ( + ["backends/htp/x86_64/*.cpp"] if include_aot_qnn_lib else ["backends/htp/aarch64/*.cpp"]) + ( + ["backends/ir/x86_64/*.cpp"] if include_aot_qnn_lib else ["backends/ir/aarch64/*.cpp"] ), exclude = ["Logging.cpp"], ), @@ -55,8 +57,9 @@ def define_common_targets(): [ "*.h", "backends/*.h", - "backends/irbackend/*.h", - "backends/htpbackend/*.h", + "backends/gpu/*.h", + "backends/htp/*.h", + "backends/ir/*.h", ], exclude = ["Logging.h"], ), diff --git a/backends/qualcomm/scripts/build.sh b/backends/qualcomm/scripts/build.sh index 297f81fc85d..b8f366d2f7c 100755 --- a/backends/qualcomm/scripts/build.sh +++ b/backends/qualcomm/scripts/build.sh @@ -1,10 +1,20 @@ +#!/usr/bin/env bash # Copyright (c) Qualcomm Innovation Center, Inc. # All rights reserved # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. set -e -set -o xtrace + +pip install pydot + +# Check if running on macOS/Darwin +if [[ "$(uname -s)" == "Darwin" ]]; then + echo "Error: the Qualcomm backend Python interface requires a Linux operating system." + echo "macOS/Darwin is not supported for building the Qualcomm backend." + echo "Please use an x64 Linux system or an x64 Linux container to build this backend."
+ exit 1 +fi if [[ -z ${QNN_SDK_ROOT} ]]; then echo "Please export QNN_SDK_ROOT=/path/to/qnn_sdk" @@ -12,12 +22,15 @@ if [[ -z ${QNN_SDK_ROOT} ]]; then fi +set -o xtrace usage() { echo "Usage: Build the aarch64 version of executor runner or the python interface of Qnn Manager" echo "First, you need to set the environment variable for QNN_SDK_ROOT" - echo ", and if you want to build the aarch64 version of executor runner" + echo ", and if you want to build the android version of executor runner" echo ", you need to export ANDROID_NDK_ROOT=/path/to/android_ndkXX" + echo "(or export TOOLCHAIN_ROOT_HOST=/path/to/sysroots/xx_host, " + echo "TOOLCHAIN_ROOT_TARGET=/path/to/sysroots/xx_target for linux embedded with --enable_linux_embedded)" echo "e.g.: executorch$ ./backends/qualcomm/scripts/build.sh --skip_x86_64" exit 1 } @@ -27,8 +40,10 @@ usage() { BUILD_X86_64="true" CMAKE_X86_64="build-x86" -BUILD_AARCH64="true" -CMAKE_AARCH64="build-android" +BUILD_ANDROID="true" +CMAKE_ANDROID="build-android" +BUILD_OE_LINUX="false" +CMAKE_OE_LINUX="build-oe-linux" CLEAN="true" BUILD_TYPE="RelWithDebInfo" BUILD_JOB_NUMBER="16" @@ -41,7 +56,7 @@ if [ -z BUCK2 ]; then BUCK2="buck2" fi -long_options=skip_x86_64,skip_aarch64,no_clean,release,job_number: +long_options=skip_x86_64,skip_linux_android,skip_linux_embedded,enable_linux_embedded,no_clean,release,job_number: parsed_args=$(getopt -a --options '' --longoptions $long_options --name "$0" -- "$@") eval set -- "$parsed_args" @@ -50,7 +65,9 @@ eval set -- "$parsed_args" while true ; do case "$1" in --skip_x86_64) BUILD_X86_64="false"; shift;; - --skip_aarch64) BUILD_AARCH64="false"; shift;; + --skip_linux_android) BUILD_ANDROID="false"; shift;; + --skip_linux_embedded) BUILD_OE_LINUX="false"; shift;; + --enable_linux_embedded) BUILD_ANDROID="false"; BUILD_OE_LINUX="true"; shift;; --no_clean) CLEAN="false"; shift;; --release) BUILD_TYPE="Release"; shift;; --job_number) BUILD_JOB_NUMBER="$2"; shift 2;; @@ -60,13 +77,13 @@ done PRJ_ROOT="$( cd "$(dirname "$0")/../../.." ; pwd -P)" -if [ "$BUILD_AARCH64" = true ]; then +if [ "$BUILD_ANDROID" = true ]; then if [[ -z ${ANDROID_NDK_ROOT} ]]; then echo "Please export ANDROID_NDK_ROOT=/path/to/android_ndkXX" exit -1 fi - BUILD_ROOT=$PRJ_ROOT/$CMAKE_AARCH64 + BUILD_ROOT=$PRJ_ROOT/$CMAKE_ANDROID if [ "$CLEAN" = true ]; then rm -rf $BUILD_ROOT && mkdir $BUILD_ROOT else @@ -85,6 +102,7 @@ if [ "$BUILD_AARCH64" = true ]; then -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ + -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \ -DEXECUTORCH_ENABLE_LOGGING=ON \ @@ -133,6 +151,94 @@ if [ "$BUILD_AARCH64" = true ]; then cmake --build $LLAMA_EXAMPLE_ROOT -j$BUILD_JOB_NUMBER fi +if [ "$BUILD_OE_LINUX" = true ]; then + if [[ -z ${TOOLCHAIN_ROOT_HOST} ]]; then + echo "Please export e.g. TOOLCHAIN_ROOT_HOST=/path/to/sysroots/x86_64-qtisdk-linux" + exit -1 + fi + if [[ -z ${TOOLCHAIN_ROOT_TARGET} ]]; then + echo "Please export e.g. TOOLCHAIN_ROOT_TARGET=/path/to/sysroots/armv8a-oe-linux" + exit -1 + fi + + BUILD_ROOT=$PRJ_ROOT/$CMAKE_OE_LINUX + if [ "$CLEAN" = true ]; then + rm -rf $BUILD_ROOT && mkdir $BUILD_ROOT + else + # Force rebuild flatccrt for the correct platform + cd $BUILD_ROOT/third-party/flatcc && make clean + fi + + TOOLCHAN_PREFIX=$TOOLCHAIN_ROOT_HOST/usr/bin/aarch64-oe-linux/aarch64-oe-linux- + cd $BUILD_ROOT + cmake .. 
\ + -DCMAKE_INSTALL_PREFIX=$BUILD_ROOT \ + -DCMAKE_BUILD_TYPE=$BUILD_TYPE \ + -DCMAKE_C_COMPILER=${TOOLCHAN_PREFIX}gcc \ + -DCMAKE_CXX_COMPILER=${TOOLCHAN_PREFIX}g++ \ + -DCMAKE_SYSROOT=$TOOLCHAIN_ROOT_TARGET \ + -DCMAKE_SYSTEM_NAME=Linux \ + -DCMAKE_SYSTEM_PROCESSOR=aarch64 \ + -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + -DEXECUTORCH_BUILD_QNN=ON \ + -DEXECUTORCH_BUILD_DEVTOOLS=ON \ + -DEXECUTORCH_BUILD_EXTENSION_LLM=ON \ + -DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON \ + -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ + -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ + -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ + -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ + -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \ + -DEXECUTORCH_ENABLE_LOGGING=ON \ + -DQNN_SDK_ROOT=$QNN_SDK_ROOT \ + -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ + -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ + -B$BUILD_ROOT + + cmake --build $BUILD_ROOT -j$BUILD_JOB_NUMBER --target install + + EXAMPLE_ROOT=examples/qualcomm + CMAKE_PREFIX_PATH="${BUILD_ROOT};${BUILD_ROOT}/third-party/gflags;" + + cmake $PRJ_ROOT/$EXAMPLE_ROOT \ + -DCMAKE_BUILD_TYPE=$BUILD_TYPE \ + -DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH \ + -DSUPPORT_REGEX_LOOKAHEAD=ON \ + -DBUILD_TESTING=OFF \ + -DEXECUTORCH_ENABLE_LOGGING=ON \ + -DCMAKE_C_COMPILER=${TOOLCHAN_PREFIX}gcc \ + -DCMAKE_CXX_COMPILER=${TOOLCHAN_PREFIX}g++ \ + -DCMAKE_SYSROOT=$TOOLCHAIN_ROOT_TARGET \ + -DCMAKE_SYSTEM_NAME=Linux \ + -DCMAKE_SYSTEM_PROCESSOR=aarch64 \ + -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ + -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \ + -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ + -B$EXAMPLE_ROOT + + cmake --build $EXAMPLE_ROOT -j$BUILD_JOB_NUMBER + + LLAMA_EXAMPLE_ROOT=examples/models/llama + cmake $PRJ_ROOT/$LLAMA_EXAMPLE_ROOT \ + -DBUILD_TESTING=OFF \ + -DCMAKE_BUILD_TYPE=$BUILD_TYPE \ + -DCMAKE_C_COMPILER=${TOOLCHAN_PREFIX}gcc \ + -DCMAKE_CXX_COMPILER=${TOOLCHAN_PREFIX}g++ \ + -DCMAKE_SYSROOT=$TOOLCHAIN_ROOT_TARGET \ + -DCMAKE_SYSTEM_NAME=Linux \ + -DCMAKE_SYSTEM_PROCESSOR=aarch64 \ + -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + -DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH \ + -DEXECUTORCH_ENABLE_LOGGING=ON \ + -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \ + -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ + -B$LLAMA_EXAMPLE_ROOT + + cmake --build $LLAMA_EXAMPLE_ROOT -j$BUILD_JOB_NUMBER +fi + if [ "$BUILD_X86_64" = true ]; then BUILD_ROOT=$PRJ_ROOT/$CMAKE_X86_64 if [ "$CLEAN" = true ]; then @@ -154,6 +260,7 @@ if [ "$BUILD_X86_64" = true ]; then -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ + -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \ diff --git a/backends/qualcomm/scripts/download_qnn_sdk.py b/backends/qualcomm/scripts/download_qnn_sdk.py new file mode 100644 index 00000000000..5524adf8988 --- /dev/null +++ b/backends/qualcomm/scripts/download_qnn_sdk.py @@ -0,0 +1,642 @@ +import argparse +import ctypes +import logging +import os +import pathlib +import platform +import re +import shutil +import subprocess +import sys +import tarfile +import tempfile +import urllib.request +import zipfile +from typing import Dict, List, Optional, Tuple + +import requests +from requests.adapters import HTTPAdapter, Retry + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + +PKG_ROOT = 
pathlib.Path(__file__).parent.parent +SDK_DIR = PKG_ROOT / "sdk" / "qnn" + + +def is_linux_x86() -> bool: + """ + Check if the current platform is Linux x86_64. + + Returns: + bool: True if the system is Linux x86_64, False otherwise. + """ + return platform.system().lower() == "linux" and platform.machine().lower() in ( + "x86_64", + "amd64", + "i386", + "i686", + ) + + +######################### +# Cache directory helper +######################### + +APP_NAMESPACE = ["executorch", "qnn"] + + +def _get_staging_dir(*parts: str) -> pathlib.Path: + r""" + Return a cross-platform staging directory for staging SDKs/libraries. + + - On Linux: + ~/.cache/executorch/qnn/ + (falls back to $HOME/.cache if $XDG_CACHE_HOME is unset) + + - On Windows (not supported yet, but as placeholder): + %LOCALAPPDATA%\executorch\qnn\ + (falls back to $HOME/AppData/Local if %LOCALAPPDATA% is unset) + + - Override: + If QNN_STAGING_DIR is set in the environment, that path is used instead. + + Args: + parts (str): Subdirectories to append under the root staging dir. + + Returns: + pathlib.Path: Fully qualified staging path. + """ + # Environment override wins + base = os.environ.get("QNN_STAGING_DIR") + if base: + return pathlib.Path(base).joinpath(*parts) + + system = platform.system().lower() + if system == "windows": + # On Windows, prefer %LOCALAPPDATA%, fallback to ~/AppData/Local + base = pathlib.Path( + os.environ.get("LOCALAPPDATA", pathlib.Path.home() / "AppData" / "Local") + ) + elif is_linux_x86(): + # On Linux/Unix, prefer $XDG_CACHE_HOME, fallback to ~/.cache + base = pathlib.Path( + os.environ.get("XDG_CACHE_HOME", pathlib.Path.home() / ".cache") + ) + else: + raise ValueError(f"Unsupported platform: {system}") + + return base.joinpath(*APP_NAMESPACE, *parts) + + +def _atomic_download(url: str, dest: pathlib.Path): + """ + Download URL into dest atomically: + - Write to a temp file in the same dir + - Move into place if successful + """ + dest.parent.mkdir(parents=True, exist_ok=True) + + # Temp file in same dir (guarantees atomic rename) + with tempfile.NamedTemporaryFile(dir=dest.parent, delete=False) as tmp: + tmp_path = pathlib.Path(tmp.name) + + try: + urllib.request.urlretrieve(url, tmp_path) + tmp_path.replace(dest) # atomic rename + except Exception: + # Clean up partial file on failure + if tmp_path.exists(): + tmp_path.unlink(missing_ok=True) + raise + + +#################### +# qnn sdk download management +#################### + + +def _download_archive(url: str, archive_path: pathlib.Path) -> bool: + """Robust streaming download with retries.""" + + logger.debug("Archive will be saved to: %s", archive_path) + + session = requests.Session() + retries = Retry( + total=5, + backoff_factor=1.0, + status_forcelist=[429, 500, 502, 503, 504], + allowed_methods=["GET"], + ) + session.mount("https://", HTTPAdapter(max_retries=retries)) + + try: + with session.get(url, stream=True) as r: + r.raise_for_status() + + downloaded = 0 + chunk_size = 1024 * 1024 # 1MB + + with open(archive_path, "wb") as f: + for chunk in r.iter_content(chunk_size): + if chunk: + f.write(chunk) + downloaded += len(chunk) + + logger.info("Download completed!") + + except Exception as e: + logger.exception("Error during download: %s", e) + return False + + if archive_path.exists() and archive_path.stat().st_size == 0: + logger.warning("Downloaded file is empty!") + return False + elif not archive_path.exists(): + logger.error("File was not downloaded!") + return False + + return True + + +def _extract_archive( + url: str, 
archive_path: pathlib.Path, content_dir: str, dst_folder: pathlib.Path +): + """Extract archive based on type (zip or tar).""" + if url.endswith(".zip"): + logger.info("Extracting ZIP archive...") + _extract_zip(archive_path, content_dir, dst_folder) + elif url.endswith((".tar.gz", ".tgz")): + logger.info("Extracting TAR archive...") + _extract_tar(archive_path, content_dir, dst_folder) + else: + raise ValueError(f"Unsupported archive format: {url}") + + +def _verify_extraction(dst_folder: pathlib.Path): + """Check if extraction succeeded and log contents.""" + logger.info("Verifying extraction to %s", dst_folder) + if dst_folder.exists(): + logger.debug("SDK directory exists. Contents:") + for item in dst_folder.iterdir(): + logger.debug(" %s", item.name) + else: + logger.error("SDK directory was not created!") + + +def _download_qnn_sdk(dst_folder=SDK_DIR) -> Optional[pathlib.Path]: + """ + Download and extract the Qualcomm SDK into dst_folder. + Only runs on Linux x86 platforms. + """ + QNN_VERSION = "2.37.0.250724" + logger.info("Downloading Qualcomm SDK...") + QAIRT_URL = ( + f"https://softwarecenter.qualcomm.com/api/download/software/sdks/" + f"Qualcomm_AI_Runtime_Community/All/{QNN_VERSION}/v{QNN_VERSION}.zip" + ) + QAIRT_CONTENT_DIR = f"qairt/{QNN_VERSION}" + if not is_linux_x86(): + logger.info("[QNN] Skipping Qualcomm SDK (only supported on Linux x86).") + return None + else: + logger.info("[QNN] Downloading Qualcomm SDK for Linux x86") + + dst_folder.mkdir(parents=True, exist_ok=True) + + with tempfile.TemporaryDirectory() as tmpdir: + archive_path = pathlib.Path(tmpdir) / pathlib.Path(QAIRT_URL).name + if not _download_archive(QAIRT_URL, archive_path): + return None + + _extract_archive(QAIRT_URL, archive_path, QAIRT_CONTENT_DIR, dst_folder) + _verify_extraction(dst_folder) + + return dst_folder + + +def _extract_zip(archive_path, content_dir, target_dir): + logger.debug("Extracting %s to %s", archive_path, target_dir) + logger.debug("Looking for content in subdirectory: %s", content_dir) + + target_dir.mkdir(parents=True, exist_ok=True) + + with zipfile.ZipFile(archive_path, "r") as zip_ref: + files_to_extract = [f for f in zip_ref.namelist() if f.startswith(content_dir)] + + for file in files_to_extract: + relative_path = pathlib.Path(file).relative_to(content_dir) + if relative_path == pathlib.Path("."): + continue + + out_path = target_dir / relative_path + if file.endswith("/"): + out_path.mkdir(parents=True, exist_ok=True) + else: + out_path.parent.mkdir(parents=True, exist_ok=True) + with zip_ref.open(file) as src, open(out_path, "wb") as dst: + shutil.copyfileobj(src, dst) + + +def _extract_tar(archive_path: pathlib.Path, prefix: str, target_dir: pathlib.Path): + with tarfile.open(archive_path, "r:gz") as tf: + for m in tf.getmembers(): + if not m.name.startswith(prefix + "/"): + continue + relpath = pathlib.Path(m.name).relative_to(prefix) + if not relpath.parts or relpath.parts[0] == "..": + continue + + out_path = target_dir / relpath + if m.isdir(): + out_path.mkdir(parents=True, exist_ok=True) + else: + out_path.parent.mkdir(parents=True, exist_ok=True) + src = tf.extractfile(m) + if src is None: + continue + with src, open(out_path, "wb") as dst: + dst.write(src.read()) + + +#################### +# libc management +#################### + +GLIBC_VERSION = "2.34" +GLIBC_REEXEC_GUARD = "QNN_GLIBC_REEXEC" +MINIMUM_LIBC_VERSION = GLIBC_VERSION + + +def _get_glibc_libdir() -> pathlib.Path: + glibc_root = _get_staging_dir(f"glibc-{GLIBC_VERSION}") + return glibc_root / 
"lib" + + +def _parse_version(v: str) -> tuple[int, int]: + """Turn '2.34' → (2,34) so it can be compared.""" + parts = v.split(".") + return int(parts[0]), int(parts[1]) if len(parts) > 1 else 0 + + +def _current_glibc_version() -> str: + """Return system glibc version string (via ctypes).""" + try: + libc = ctypes.CDLL("libc.so.6") + func = libc.gnu_get_libc_version + func.restype = ctypes.c_char_p + return func().decode() + except Exception as e: + return f"error:{e}" + + +def _resolve_glibc_loader() -> pathlib.Path | None: + """Return staged ld.so path if available.""" + for p in [ + _get_glibc_libdir() / f"ld-{GLIBC_VERSION}.so", + _get_glibc_libdir() / "ld-linux-x86-64.so.2", + ]: + if p.exists(): + return p + return None + + +def _stage_prebuilt_glibc(): + """Download + extract Fedora 35 glibc RPM into /tmp.""" + logger.info(">>> Staging prebuilt glibc-%s from Fedora 35 RPM", GLIBC_VERSION) + _get_glibc_libdir().mkdir(parents=True, exist_ok=True) + rpm_path = _get_staging_dir("glibc") / "glibc.rpm" + work_dir = _get_staging_dir("glibc") / "extracted" + rpm_url = ( + "https://archives.fedoraproject.org/pub/archive/fedora/linux/releases/35/" + "Everything/x86_64/os/Packages/g/glibc-2.34-7.fc35.x86_64.rpm" + ) + + rpm_path.parent.mkdir(parents=True, exist_ok=True) + logger.info("[glibc] Downloading %s -> %s", rpm_url, rpm_path) + try: + urllib.request.urlretrieve(rpm_url, rpm_path) + except Exception as e: + logger.error("[glibc] Failed to download %s: %s", rpm_url, e) + raise + + # Extract + if work_dir.exists(): + shutil.rmtree(work_dir) + work_dir.mkdir(parents=True) + subprocess.check_call(["bsdtar", "-C", str(work_dir), "-xf", str(rpm_path)]) + + # Copy runtime libs + staged = [ + "ld-linux-x86-64.so.2", + "libc.so.6", + "libdl.so.2", + "libpthread.so.0", + "librt.so.1", + "libm.so.6", + "libutil.so.1", + ] + for lib in staged: + src = work_dir / "lib64" / lib + if src.exists(): + shutil.copy2(src, _get_glibc_libdir() / lib) + logger.info("[glibc] Staged %s", lib) + else: + logger.warning("[glibc] Missing %s in RPM", lib) + + +def ensure_glibc_minimum(min_version: str = GLIBC_VERSION): + """ + Ensure process runs under glibc >= min_version. + - If system glibc is new enough → skip. + - Else → stage Fedora RPM and re-exec under staged loader. 
+ """ + current = _current_glibc_version() + logger.info("[glibc] Current loaded glibc: %s", current) + + # If system glibc already sufficient → skip everything + m = re.match(r"(\d+\.\d+)", current) + if m and _parse_version(m.group(1)) >= _parse_version(min_version): + logger.info("[glibc] System glibc >= %s, no staging needed.", min_version) + return + + # Avoid infinite loop + if os.environ.get(GLIBC_REEXEC_GUARD) == "1": + logger.info("[glibc] Already re-exec'd once, continuing.") + return + + # Stage prebuilt if not already staged + if not (_get_glibc_libdir() / "libc.so.6").exists(): + _stage_prebuilt_glibc() + + loader = _resolve_glibc_loader() + if not loader: + logger.error("[glibc] Loader not found in %s", _get_glibc_libdir()) + return + + logger.info( + "[glibc] Re-execing under loader %s with libdir %s", loader, _get_glibc_libdir() + ) + os.environ[GLIBC_REEXEC_GUARD] = "1" + os.execv( + str(loader), + [str(loader), "--library-path", str(_get_glibc_libdir()), sys.executable] + + sys.argv, + ) + + +#################### +# libc++ management +#################### + +LLVM_VERSION = "14.0.0" +LIBCXX_BASE_NAME = f"clang+llvm-{LLVM_VERSION}-x86_64-linux-gnu-ubuntu-18.04" +LLVM_URL = f"https://github.com/llvm/llvm-project/releases/download/llvmorg-{LLVM_VERSION}/{LIBCXX_BASE_NAME}.tar.xz" +REQUIRED_LIBCXX_LIBS = [ + "libc++.so.1.0", + "libc++abi.so.1.0", + "libunwind.so.1", +] + + +def _stage_libcxx(target_dir: pathlib.Path): + target_dir.mkdir(parents=True, exist_ok=True) + + if all((target_dir / libname).exists() for libname in REQUIRED_LIBCXX_LIBS): + logger.info("[libcxx] Already staged at %s, skipping download", target_dir) + return + + libcxx_stage = _get_staging_dir(f"libcxx-{LLVM_VERSION}") + temp_tar = libcxx_stage / f"{LIBCXX_BASE_NAME}.tar.xz" + temp_extract = libcxx_stage / LIBCXX_BASE_NAME + + if not temp_tar.exists(): + logger.info("[libcxx] Downloading %s", LLVM_URL) + _atomic_download(LLVM_URL, temp_tar) + + # Sanity check before extracting + if not temp_tar.exists() or temp_tar.stat().st_size == 0: + raise FileNotFoundError(f"[libcxx] Tarball missing or empty: {temp_tar}") + + logger.info("[libcxx] Extracting %s", temp_tar) + with tarfile.open(temp_tar, "r:xz") as tar: + tar.extractall(temp_extract.parent) + + lib_src = temp_extract / "lib" / "x86_64-unknown-linux-gnu" + for fname in REQUIRED_LIBCXX_LIBS: + src_path = lib_src / fname + if not src_path.exists(): + logger.warning( + "[libcxx] %s not found in extracted LLVM src_path %s", fname, src_path + ) + continue + shutil.copy(src_path, target_dir / fname) + + logger.info("[libcxx] Staged libc++ to %s", target_dir) + + +REQUIRED_QNN_LIBS: List[str] = [ + "libQnnHtp.so", +] + + +def _ld_library_paths() -> List[pathlib.Path]: + """Split LD_LIBRARY_PATH into ordered directories (skip empties).""" + raw = os.environ.get("LD_LIBRARY_PATH", "") + return [pathlib.Path(p) for p in raw.split(":") if p.strip()] + + +def _find_lib_in_ld_paths( + libname: str, ld_dirs: Optional[List[pathlib.Path]] = None +) -> Optional[pathlib.Path]: + """Return first matching path to `libname` in LD_LIBRARY_PATH, or None.""" + if ld_dirs is None: + ld_dirs = _ld_library_paths() + for d in ld_dirs: + candidate = d / libname + try: + if candidate.exists(): + return candidate.resolve() + except Exception: + # Ignore unreadable / permission issues, keep looking. 
+ pass + return None + + +def _check_libs_in_ld( + libnames: List[str], +) -> Tuple[bool, Dict[str, Optional[pathlib.Path]]]: + """ + Check if each lib in `libnames` exists in LD_LIBRARY_PATH directories. + + Returns: + all_present: True iff every lib was found + locations: mapping lib -> path (or None if missing) + """ + ld_dirs = _ld_library_paths() + locations: Dict[str, Optional[pathlib.Path]] = {} + for lib in libnames: + locations[lib] = _find_lib_in_ld_paths(lib, ld_dirs) + all_present = all(locations[lib] is not None for lib in libnames) + return all_present, locations + + +# ----------------------- +# Ensure QNN SDK library +# ----------------------- +def _ensure_qnn_sdk_lib() -> bool: + """ + Ensure libQnnHtp.so is available. + - If found in LD_LIBRARY_PATH: do nothing, return True. + - Otherwise: ensure packaged SDK is present, then load libQnnHtp.so from it. + """ + all_present, locs = _check_libs_in_ld(REQUIRED_QNN_LIBS) + if all_present: + logger.info( + "[QNN] libQnnHtp.so found in LD_LIBRARY_PATH; skipping SDK install." + ) + for lib, p in locs.items(): + logger.info(" - %s: %s", lib, p) + return True + + # Not found → use packaged SDK + qnn_sdk_dir = SDK_DIR + logger.info("[QNN] libQnnHtp.so not found in LD_LIBRARY_PATH.") + if not qnn_sdk_dir.exists(): + logger.info("[QNN] SDK dir missing; downloading...") + _download_qnn_sdk() + else: + logger.info("[QNN] Using existing SDK at %s", qnn_sdk_dir) + + os.environ["QNN_SDK_ROOT"] = str(qnn_sdk_dir) + + qnn_lib = qnn_sdk_dir / "lib" / "x86_64-linux-clang" / "libQnnHtp.so" + logger.info("[QNN] Loading %s", qnn_lib) + lib_loaded = False + try: + ctypes.CDLL(str(qnn_lib), mode=ctypes.RTLD_GLOBAL) + logger.info("[QNN] Loaded libQnnHtp.so from packaged SDK.") + lib_loaded = True + except OSError as e: + logger.error("[QNN][ERROR] Failed to load %s: %s", qnn_lib, e) + return lib_loaded + + +def _load_libcxx_libs(lib_path): + logger.debug("running _load_libcxx_libs") + candidates = list(lib_path.glob("*.so*")) + priority = ["libc++abi", "libc++"] + sorted_candidates = [ + f for name in priority for f in candidates if f.name.startswith(name) + ] + sorted_candidates += [f for f in candidates if f not in sorted_candidates] + logger.debug("sorted_candidates: %s", sorted_candidates) + for sofile in sorted_candidates: + try: + ctypes.CDLL(str(sofile), mode=ctypes.RTLD_GLOBAL) + logger.info("Loaded %s", sofile.name) + except OSError as e: + logger.warning("[WARN] Failed to load %s: %s", sofile.name, e) + + +# --------------------- +# Ensure libc++ family +# --------------------- +def _ensure_libcxx_stack() -> bool: + """ + Ensure libc++ stack is available. + - If all required libc++ libs are found in LD_LIBRARY_PATH: do nothing. + - Otherwise: stage and load the packaged libc++ bundle. + """ + all_present, locs = _check_libs_in_ld(REQUIRED_LIBCXX_LIBS) + if all_present: + logger.info( + "[libcxx] All libc++ libs present in LD_LIBRARY_PATH; skipping staging." + ) + for lib, p in locs.items(): + logger.info(" - %s: %s", lib, p) + return True + + logger.info( + "[libcxx] Some libc++ libs missing in LD_LIBRARY_PATH; staging packaged libc++..." 
+ ) + lib_loaded = False + try: + libcxx_dir = PKG_ROOT / "sdk" / f"libcxx-{LLVM_VERSION}" + _stage_libcxx(libcxx_dir) + _load_libcxx_libs(libcxx_dir) + logger.info("[libcxx] Staged and loaded libc++ from %s", libcxx_dir) + lib_loaded = True + except Exception as e: + logger.exception("[libcxx][ERROR] Failed to stage/load libc++: %s", e) + return lib_loaded + + +# --------------- +# Public entrypoint +# --------------- +def install_qnn_sdk() -> bool: + """ + Initialize Qualcomm backend with separated logic: + + QNN SDK: + - If libQnnHtp.so exists in LD_LIBRARY_PATH: do nothing. + - Else: ensure packaged SDK, load libQnnHtp.so. + + libc++ stack: + - If required libc++ libs exist in LD_LIBRARY_PATH: do nothing. + - Else: stage and load packaged libc++. + + Returns: + True if both steps succeeded (or were already satisfied), else False. + """ + logger.info("[QNN] Starting SDK installation") + + # Make sure we’re running under >= 2.34 + ensure_glibc_minimum(GLIBC_VERSION) + + # libc++ and QNN SDK setup + return _ensure_libcxx_stack() and _ensure_qnn_sdk_lib() + + +def main(argv: Optional[List[str]] = None) -> int: + parser = argparse.ArgumentParser( + description="Helper utility for Qualcomm SDK staging." + ) + parser.add_argument( + "--dst-folder", + type=pathlib.Path, + default=SDK_DIR, + help="Destination directory for the Qualcomm SDK.", + ) + parser.add_argument( + "--print-sdk-path", + action="store_true", + help="Print the resolved Qualcomm SDK path to stdout.", + ) + parser.add_argument( + "--install-sdk", + action="store_true", + help="Ensure the SDK and runtime libraries are staged and loaded.", + ) + args = parser.parse_args(argv) + + logging.basicConfig(level=logging.INFO) + + sdk_path: Optional[pathlib.Path] + if args.install_sdk: + if not install_qnn_sdk(): + return 1 + sdk_path = pathlib.Path(os.environ.get("QNN_SDK_ROOT", args.dst_folder)) + else: + sdk_path = _download_qnn_sdk(dst_folder=args.dst_folder) + if sdk_path is None: + return 1 + + if args.print_sdk_path and sdk_path is not None: + print(sdk_path) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/backends/qualcomm/scripts/install_qnn_sdk.sh b/backends/qualcomm/scripts/install_qnn_sdk.sh index a97d4258770..5bc0f7eeb1d 100644 --- a/backends/qualcomm/scripts/install_qnn_sdk.sh +++ b/backends/qualcomm/scripts/install_qnn_sdk.sh @@ -27,7 +27,7 @@ setup_android_ndk() { mkdir -p "${NDK_INSTALL_DIR}" NDK_ZIP="android-ndk-${NDK_VERSION}-linux.zip" - curl -Lo "/tmp/${NDK_ZIP}" "https://dl.google.com/android/repository/${NDK_ZIP}" + curl --retry 3 --retry-delay 5 --retry-connrefused --continue-at - -Lo "/tmp/${NDK_ZIP}" "https://dl.google.com/android/repository/${NDK_ZIP}" unzip -q "/tmp/${NDK_ZIP}" -d "${NDK_INSTALL_DIR}" mv "${NDK_INSTALL_DIR}/android-ndk-${NDK_VERSION}" "${NDK_INSTALL_DIR}/ndk" @@ -48,7 +48,7 @@ install_qnn() { echo "Start installing qnn v${QNN_VERSION}" QNN_INSTALLATION_DIR="/tmp/qnn" - + if [ -d "${QNN_INSTALLATION_DIR}/${QNN_VERSION}" ]; then echo "QNN SDK already installed at ${QNN_INSTALLATION_DIR}/${QNN_VERSION}" export QNN_SDK_ROOT="${QNN_INSTALLATION_DIR}/${QNN_VERSION}" @@ -64,7 +64,7 @@ install_qnn() { mkdir -p "${QNN_INSTALLATION_DIR}" QNN_ZIP_FILE="v${QNN_VERSION}.zip" - curl -Lo "/tmp/${QNN_ZIP_FILE}" "${QNN_ZIP_URL}" + curl --retry 3 -Lo "/tmp/${QNN_ZIP_FILE}" "${QNN_ZIP_URL}" echo "Finishing downloading qnn sdk." unzip -qo "/tmp/${QNN_ZIP_FILE}" -d /tmp echo "Finishing unzip qnn sdk." 
@@ -117,7 +117,7 @@ setup_libcpp() { LLVM_URL="https://github.com/llvm/llvm-project/releases/download/llvmorg-${LLVM_VERSION}/${BASE_NAME}.tar.xz" echo "Downloading LLVM from ${LLVM_URL}" - curl -fLO "${LLVM_URL}" || { + curl --retry 3 -fLO "${LLVM_URL}" || { echo "Error: Failed to download LLVM" exit 1 } diff --git a/backends/qualcomm/serialization/qc_compiler_spec.fbs b/backends/qualcomm/serialization/qc_compiler_spec.fbs index 8aeaa060a50..4169e055454 100644 --- a/backends/qualcomm/serialization/qc_compiler_spec.fbs +++ b/backends/qualcomm/serialization/qc_compiler_spec.fbs @@ -18,6 +18,7 @@ enum HtpArch: int { V73 = 73, V75 = 75, V79 = 79, + V81 = 81, } table HtpInfo { @@ -33,16 +34,22 @@ table HtpInfo { enum QcomChipset: int { UNKNOWN_SM = 0, SA8295 = 39, + SM8350 = 35, SM8450 = 36, SM8475 = 42, SM8550 = 43, SM8650 = 57, SM8750 = 69, + SM8850 = 87, SSG2115P = 46, SSG2125P = 58, SXR1230P = 45, SXR2230P = 53, SXR2330P = 75, + QCS9100 = 77, + SAR2230P = 95, + SA8255 = 52, + SW6100 = 96, } /// Indicate the information of the specified SoC. @@ -54,6 +61,50 @@ table SocInfo { htp_info:HtpInfo; } +/// Defines performance modes available for GPU backend. +enum QnnExecuTorchGpuPerformanceMode: int { + kGpuPerfHintHigh = 0, + kGpuPerfHintNormal, + kGpuPerfHintLow, +} + +/// Defines the optimization levels of the graph tensors that are not input nor +/// output tensors. This enum controls the trade-off between performance and +/// accuracy. +enum QnnExecuTorchGpuPrecision: int { + kGpuPrecisionFp32 = 0, + kGpuPrecisionFp16, + kGpuPrecisionHybrid, + kGpuPrecisionUserProvided, +} + +/// Specifies the backend options for the GPU backend. +table QnnExecuTorchGpuBackendOptions { + /// kGpuPerfHintHigh - best inference latency at the expense of power consumption. + /// kGpuPerfHintNormal - balanced performance dependent upon power management. + /// kGpuPerfHintLow - lowest power consumption at the expense of inference latency. + performance_mode:QnnExecuTorchGpuPerformanceMode; + + /// kGpuPrecisionFp32 - best accuracy at the expense of performance. + /// kGpuPrecisionFp16 - best performance at the expense of accuracy. + /// kGpuPrecisionHybrid - good trade-off between performance and accuracy. + /// kGpuPrecisionUserProvided - backend will not optimize NATIVE tensor data types. + precision:QnnExecuTorchGpuPrecision; + + /// Backend will share NATIVE tensor memory based upon analysis of the network topology. + use_memory_optimizations:bool; + + /// Backend will fuse compatible operations into one operation to improve performance. + use_node_optimizations:bool; + + /// Backend will use queue recording to improve performance. + use_queue_recording:bool; + + /// When multiple graphs appear inside the same context, + /// weights could be reused across all graphs. + use_weight_sharing:bool; +} + /// Defines performance modes available for HTP backend. enum QnnExecuTorchHtpPerformanceMode: int { kHtpDefault = 0, @@ -105,10 +156,6 @@ table QnnExecuTorchHtpBackendOptions { /// Signed or unsigned HTP PD session. The default PD session is unsigned. pd_session:QnnExecuTorchHtpPdSession; - /// Optional parameter specifying the directory of QNN Skel library. Only - /// useful for backends which have a Skel library. - skel_library_dir:string; - /// With using conv hmx with short depths, we might have better performance, /// but convolution that have short depth and/or weights that are not /// symmetric could exhibit inaccurate results. 
@@ -165,7 +212,6 @@ enum QnnExecuTorchOpPackagePlatform: int { AARCH64_ANDROID, } - table QnnExecuTorchOpPackageInfo { /// The name of the op package. op_package_name:string; @@ -190,7 +236,6 @@ table QnnExecuTorchOpPackageInfo { platform:QnnExecuTorchOpPackagePlatform; } - table QnnExecuTorchOpPackageOptions { /// An array of QnnExecuTorchOpPackageInfo structures. op_package_infos:[QnnExecuTorchOpPackageInfo]; @@ -203,6 +248,8 @@ table QnnExecuTorchBackendOptions { backend_type:QnnExecuTorchBackendType; htp_options:QnnExecuTorchHtpBackendOptions; + + gpu_options:QnnExecuTorchGpuBackendOptions; } table QnnExecuTorchOptions { @@ -212,10 +259,6 @@ table QnnExecuTorchOptions { /// Optional backend specific options for the HTP backend. backend_options:QnnExecuTorchBackendOptions; - /// Optional parameter to create qnn graph if QNN context blob is not given - /// It could be a list of names only when doing weight-sharing lowering - graph_name:[string]; - /// Optional parameter to override the QNN backend library. library_path:string; @@ -238,14 +281,17 @@ table QnnExecuTorchOptions { /// Is model from qnn context binary is_from_context_binary:bool; - // Enable this option to record all QNN API calls for debugging purpose + /// Enable this option to record all QNN API calls for debugging purpose saver:bool; - // Path to saver output folder + /// Path to saver output folder saver_output_dir:string; /// Optional structure to specify op packages loaded and used by the backend. op_package_options:QnnExecuTorchOpPackageOptions; + + /// This experimental parameter is used to decide whether to enable multi-head attention to single-head attention pass, aiming to reduce time consumption in AOT and improve performance on HTP. + use_mha2sha:bool; } root_type QnnExecuTorchOptions; diff --git a/backends/qualcomm/serialization/qc_schema.py b/backends/qualcomm/serialization/qc_schema.py index f3b9e2cc1a5..5f5675af067 100644 --- a/backends/qualcomm/serialization/qc_schema.py +++ b/backends/qualcomm/serialization/qc_schema.py @@ -10,7 +10,7 @@ from dataclasses import dataclass, field from enum import IntEnum, unique -from typing import List +from typing import List, Optional @dataclass @@ -27,6 +27,7 @@ class HtpArch(IntEnum): V73 = 73 V75 = 75 V79 = 79 + V81 = 81 @dataclass @@ -39,16 +40,22 @@ class HtpInfo: class QcomChipset(IntEnum): UNKNOWN_SM = 0 SA8295 = 39 # v68 + SM8350 = 35 # v68 SM8450 = 36 # v69 SM8475 = 42 # v69 SM8550 = 43 # v73 SM8650 = 57 # v75 SM8750 = 69 # v79 + SM8850 = 87 # v81 SSG2115P = 46 # v73 SSG2125P = 58 # v73 SXR1230P = 45 # v73 SXR2230P = 53 # v69 SXR2330P = 75 # v79 + QCS9100 = 77 # v73 + SAR2230P = 95 # v81 + SA8255 = 52 # v73 + SW6100 = 96 # v81 @dataclass @@ -59,19 +66,54 @@ class SocInfo: _soc_info_table = { QcomChipset.SA8295: SocInfo(QcomChipset.SA8295, HtpInfo(HtpArch.V68, 8)), + QcomChipset.SM8350: SocInfo(QcomChipset.SM8350, HtpInfo(HtpArch.V68, 4)), QcomChipset.SM8450: SocInfo(QcomChipset.SM8450, HtpInfo(HtpArch.V69, 8)), QcomChipset.SM8475: SocInfo(QcomChipset.SM8475, HtpInfo(HtpArch.V69, 8)), QcomChipset.SM8550: SocInfo(QcomChipset.SM8550, HtpInfo(HtpArch.V73, 8)), + QcomChipset.SA8255: SocInfo(QcomChipset.SA8255, HtpInfo(HtpArch.V73, 8)), QcomChipset.SM8650: SocInfo(QcomChipset.SM8650, HtpInfo(HtpArch.V75, 8)), QcomChipset.SM8750: SocInfo(QcomChipset.SM8750, HtpInfo(HtpArch.V79, 8)), + QcomChipset.SM8850: SocInfo(QcomChipset.SM8850, HtpInfo(HtpArch.V81, 8)), QcomChipset.SSG2115P: SocInfo(QcomChipset.SSG2115P, HtpInfo(HtpArch.V73, 2)), QcomChipset.SSG2125P: 
SocInfo(QcomChipset.SSG2125P, HtpInfo(HtpArch.V73, 2)), QcomChipset.SXR1230P: SocInfo(QcomChipset.SXR1230P, HtpInfo(HtpArch.V73, 2)), QcomChipset.SXR2230P: SocInfo(QcomChipset.SXR2230P, HtpInfo(HtpArch.V69, 8)), QcomChipset.SXR2330P: SocInfo(QcomChipset.SXR2330P, HtpInfo(HtpArch.V79, 8)), + QcomChipset.QCS9100: SocInfo(QcomChipset.QCS9100, HtpInfo(HtpArch.V73, 8)), + QcomChipset.SAR2230P: SocInfo(QcomChipset.SAR2230P, HtpInfo(HtpArch.V81, 4)), + QcomChipset.SW6100: SocInfo(QcomChipset.SW6100, HtpInfo(HtpArch.V81, 4)), } +@unique +class QnnExecuTorchGpuPerformanceMode(IntEnum): + kGpuPerfHintHigh = 0 + kGpuPerfHintNormal = 1 + kGpuPerfHintLow = 2 + + +@unique +class QnnExecuTorchGpuPrecision(IntEnum): + kGpuPrecisionFp32 = 0 + kGpuPrecisionFp16 = 1 + kGpuPrecisionHybrid = 2 + kGpuPrecisionUserProvided = 3 + + +@dataclass +class QnnExecuTorchGpuBackendOptions: + performance_mode: QnnExecuTorchGpuPerformanceMode = ( + QnnExecuTorchGpuPerformanceMode.kGpuPerfHintHigh + ) + precision: QnnExecuTorchGpuPrecision = ( + QnnExecuTorchGpuPrecision.kGpuPrecisionUserProvided + ) + use_memory_optimizations: bool = True + use_node_optimizations: bool = True + use_queue_recording: bool = True + use_weight_sharing: bool = False + + @unique class QnnExecuTorchHtpPerformanceMode(IntEnum): kHtpDefault = 0 @@ -113,7 +155,6 @@ class QnnExecuTorchHtpBackendOptions: ) precision: QnnExecuTorchHtpPrecision = QnnExecuTorchHtpPrecision.kHtpQuantized pd_session: QnnExecuTorchHtpPdSession = QnnExecuTorchHtpPdSession.kHtpUnsignedPd - skel_library_dir: str = "" use_conv_hmx: bool = True use_dlbc: bool = False use_fold_relu: bool = True @@ -142,7 +183,8 @@ class QnnExecuTorchProfileLevel(IntEnum): @dataclass class QnnExecuTorchBackendOptions: backend_type: QnnExecuTorchBackendType - htp_options: QnnExecuTorchHtpBackendOptions + htp_options: Optional[QnnExecuTorchHtpBackendOptions] = None + gpu_options: Optional[QnnExecuTorchGpuBackendOptions] = None @unique @@ -179,7 +221,6 @@ class QnnExecuTorchOpPackageOptions: class QnnExecuTorchOptions: soc_info: SocInfo backend_options: QnnExecuTorchBackendOptions - graph_name: List[str] = field(default_factory=lambda: ["forward"]) library_path: str = "" log_level: QnnExecuTorchLogLevel = QnnExecuTorchLogLevel.kLogOff online_prepare: bool = False @@ -192,3 +233,4 @@ class QnnExecuTorchOptions: op_package_options: QnnExecuTorchOpPackageOptions = field( default_factory=QnnExecuTorchOpPackageOptions ) + use_mha2sha: bool = False diff --git a/backends/qualcomm/tests/TARGETS b/backends/qualcomm/tests/TARGETS index 639303c7eb8..005cc33c7e9 100644 --- a/backends/qualcomm/tests/TARGETS +++ b/backends/qualcomm/tests/TARGETS @@ -2,6 +2,8 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest") load("@fbsource//xplat/executorch/backends/qualcomm/qnn_version.bzl", "get_qnn_library_version") +oncall("executorch") + runtime.python_library( name = "models", srcs = ["models.py"], @@ -35,6 +37,7 @@ runtime.python_library( "//executorch/examples/qualcomm:utils", "//executorch/examples/models:models", "//executorch/backends/qualcomm/debugger:utils", + "//executorch/backends/qualcomm/debugger:qnn_intermediate_debugger", ], ) @@ -47,3 +50,21 @@ runtime.python_library( ":test_qnn_delegate" ] ) + +runtime.python_test( + name = "test_passes", + srcs = [ + "test_passes.py", + ], + deps = [ + "fbsource//third-party/pypi/expecttest:expecttest", # @manual + "//caffe2:torch", + "//executorch/exir:lib", + 
"//executorch/backends/qualcomm/_passes:passes", + "//executorch/backends/qualcomm/partition:partition", + "//executorch/examples/models/llama:transformer_modules", + "//executorch/examples/qualcomm/oss_scripts/llama:masking_utils", + "//executorch/examples/qualcomm/oss_scripts/llama:static_llama", + "//executorch/backends/qualcomm/builders:builders", + ], +) diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 0e290575beb..2b73e0c6dfb 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -4,8 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import torch +from typing import List, Optional, Tuple, Union +import torch # module with related operator only @@ -40,6 +41,19 @@ def forward(self, x): return torch.abs(x) +class AdaptiveMaxPool2D(torch.nn.Module): + def __init__(self, output_size, return_indices=False): + super().__init__() + self.output_size = output_size + self.return_indices = return_indices + + def forward(self, x): + adaptive_max_pool = torch.nn.AdaptiveMaxPool2d( + self.output_size, self.return_indices + ) + return adaptive_max_pool(x) + + class AdaptiveAvgPool1D(torch.nn.Module): def __init__(self): super().__init__() @@ -58,6 +72,16 @@ def forward(self, x): return adaptive_avg_pool(x) +class AdaptiveAvgPool3D(torch.nn.Module): + def __init__(self, output_size): + super().__init__() + self.output_size = output_size + + def forward(self, x): + adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d(self.output_size) + return adaptive_avg_pool3d(x) + + class Add(torch.nn.Module): def __init__(self): super().__init__() @@ -66,6 +90,28 @@ def forward(self, x, y): return torch.add(x, y) +class AddAlpha(torch.nn.Module): + def __init__(self, alpha): + super().__init__() + self.alpha = alpha + + def forward(self, x, y): + return torch.add(x, y, alpha=self.alpha) + + +class AddAlphaConstant(torch.nn.Module): + def __init__(self, alpha, constant_first=False): + super().__init__() + self.alpha = alpha + self.constant_first = constant_first + + def forward(self, x): + if self.constant_first: + return torch.add(5.0, x, alpha=self.alpha) + else: + return torch.add(x, 5.0, alpha=self.alpha) + + class AddConstantFloat(torch.nn.Module): def __init__(self): super().__init__() @@ -148,21 +194,23 @@ def forward(self, y): class Argmax(torch.nn.Module): - def __init__(self): + def __init__(self, dim: Optional[int] = None, keepdim: bool = False): super().__init__() + self.dim = dim + self.keepdim = keepdim def forward(self, x): - x = torch.argmax(x, dim=0, keepdim=True) - return x + return torch.argmax(x, dim=self.dim, keepdim=self.keepdim) class Argmin(torch.nn.Module): - def __init__(self): + def __init__(self, dim: Optional[int] = None, keepdim: bool = False): super().__init__() + self.dim = dim + self.keepdim = keepdim def forward(self, x): - x = torch.argmin(x, dim=0, keepdim=True) - return x + return torch.argmin(x, dim=self.dim, keepdim=self.keepdim) class ArgminViewSqueezeConv2D(torch.nn.Module): @@ -199,6 +247,21 @@ def forward(self, x): return torch.atan(x) +class AvgPool3d(torch.nn.Module): + def __init__(self, kernel_size, stride, padding, ceil_mode, count_include_pad): + super().__init__() + self.avg_pool3d = torch.nn.AvgPool3d( + kernel_size=kernel_size, + stride=stride, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + ) + + def forward(self, x): + return self.avg_pool3d(x) + + class 
AvgPoolModule(torch.nn.Module): def __init__(self, kernel_size, stride, padding, ceil_mode): super().__init__() @@ -274,6 +337,15 @@ def forward(self, x, y): return torch.cat((y, y, x, x), axis=2) +class Cat5(torch.nn.Module): + def __init__(self): + super().__init__() + self.const_tensor = torch.randn(1, 1, 2, 2) + + def forward(self, x, y): + return torch.cat((x, y, self.const_tensor), axis=2) + + class CausalMask(torch.nn.Module): def __init__(self): super().__init__() @@ -588,40 +660,6 @@ def forward(self, x): return self.conv(x) -class ConvTranspose1dSingle(torch.nn.Module): - def __init__(self, bias=True, dilation=1): - super().__init__() - self.conv_transpose = torch.nn.ConvTranspose1d( - in_channels=1, - out_channels=3, - kernel_size=3, - stride=2, - padding=1, - dilation=dilation, - bias=bias, - ) - - def forward(self, x): - return self.conv_transpose(x) - - -class ConvTranspose2dSingle(torch.nn.Module): - def __init__(self, bias=True, dilation=1): - super().__init__() - self.conv_transpose = torch.nn.ConvTranspose2d( - in_channels=1, - out_channels=3, - kernel_size=3, - stride=2, - padding=1, - dilation=dilation, - bias=bias, - ) - - def forward(self, x): - return self.conv_transpose(x) - - class Conv2dDownUpSample(torch.nn.Module): def __init__(self, bias=True): super().__init__() @@ -706,6 +744,90 @@ def forward(self, x): return topk_values +class Conv3dSequential(torch.nn.Module): + def __init__(self, bias=True): + super().__init__() + self.first = torch.nn.Conv3d( + in_channels=1, + out_channels=3, + kernel_size=(3, 3, 3), + padding=1, + bias=bias, + ) + self.second = torch.nn.Conv3d( + in_channels=3, + out_channels=2, + kernel_size=(3, 3, 3), + padding=1, + bias=bias, + ) + + def forward(self, x): + return self.second(self.first(x)) + + +class ConvTranspose1dSingle(torch.nn.Module): + def __init__(self, bias=True, dilation=1): + super().__init__() + self.conv_transpose = torch.nn.ConvTranspose1d( + in_channels=1, + out_channels=3, + kernel_size=3, + stride=2, + padding=1, + dilation=dilation, + bias=bias, + ) + + def forward(self, x): + return self.conv_transpose(x) + + +class ConvTranspose2dSingle(torch.nn.Module): + def __init__( + self, + bias=True, + in_channels=1, + out_channels=3, + kernel_size=1, + stride=1, + padding=1, + dilation=1, + groups=1, + ): + super().__init__() + self.conv_transpose = torch.nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, x): + return self.conv_transpose(x) + + +class ConvTranspose3dSingle(torch.nn.Module): + def __init__(self, bias=True, dilation=1): + super().__init__() + self.conv_transpose = torch.nn.ConvTranspose3d( + in_channels=1, + out_channels=3, + kernel_size=3, + stride=2, + padding=1, + dilation=dilation, + bias=bias, + ) + + def forward(self, x): + return self.conv_transpose(x) + + class Cos(torch.nn.Module): def __init__(self): super().__init__() @@ -989,6 +1111,20 @@ def forward(self, x): return x > self.constant +class GridSample(torch.nn.Module): + def __init__(self, mode, padding_mode, align_corners): + super().__init__() + self.mode = mode + self.align_corners = align_corners + self.padding_mode = padding_mode + + def forward(self, x, grid): + grid_sample = torch.nn.functional.grid_sample( + x, grid, self.mode, self.padding_mode, self.align_corners + ) + return grid_sample + + class GroupNorm(torch.nn.Module): def __init__(self, bias=True): 
super().__init__() @@ -1068,20 +1204,62 @@ def forward(self, input_pos, k_val): class IndexPut(torch.nn.Module): - def __init__(self, skip_mutable_buffer=False): + def __init__(self, skip_mutable_buffer=False, mode=0): super().__init__() self.skip_mutable_buffer = skip_mutable_buffer self.register_buffer( "k_cache", - torch.zeros((1, 1024, 12, 64), dtype=torch.float32), + torch.zeros((2, 1024, 12, 64), dtype=torch.float32), persistent=True, ) + self.mode = mode def forward(self, input_pos, k_val): - k_out = torch.ops.aten.index_put_(self.k_cache, [None, input_pos], k_val) + match self.mode: + case 0: + k_out = torch.ops.aten.index_put_(self.k_cache, [input_pos], k_val) + case 1: + k_out = torch.ops.aten.index_put_( + self.k_cache, [None, input_pos], k_val + ) + case 2: + k_out = torch.ops.aten.index_put_( + self.k_cache, [None, None, input_pos], k_val + ) + case 3: + k_out = torch.ops.aten.index_put_( + self.k_cache, [input_pos[0], input_pos[1]], k_val + ) + case 4: + k_out = torch.ops.aten.index_put_( + self.k_cache, [None, input_pos[0], input_pos[1]], k_val + ) + case 5: + k_out = torch.ops.aten.index_put_( + self.k_cache, [input_pos[0], None, input_pos[1]], k_val + ) + return k_out + 0 +class IndexPutSuite(torch.nn.Module): + def __init__(self, accumulate=False, in_place=False): + super().__init__() + self.accumulate = accumulate + self.in_place = in_place + + def forward(self, x, indices, values): + if self.in_place: + # Clone the input to avoid modifying it in-place + result = x.clone() + # Apply index_put_ and return the modified tensor + result.index_put_(indices, values, self.accumulate) + return result + else: + # Use the non-in-place variant which returns a new tensor + return torch.index_put(x, indices, values, self.accumulate) + + class IndexSelect(torch.nn.Module): def __init__(self, dim): super().__init__() @@ -1105,12 +1283,18 @@ class LargeTensorLinear(torch.nn.Module): def __init__(self): super().__init__() hidden_dim = 4096 - self.linear1 = torch.nn.Linear(512, hidden_dim) + self.linear1_1 = torch.nn.Linear(512, hidden_dim) + self.linear1_2 = torch.nn.Linear(512, hidden_dim) + self.linear1_3 = torch.nn.Linear(512, hidden_dim) self.linear2 = torch.nn.Linear(hidden_dim, 512) + self.linear3 = torch.nn.Linear(hidden_dim, 512) + self.linear4 = torch.nn.Linear(hidden_dim, 512) def forward(self, x): - x1 = self.linear1(x) + self.linear1(x) - return self.linear2(x1) + x1 = self.linear1_1(x) + self.linear1_1(x) + x2 = self.linear1_2(x) + self.linear1_2(x) + x3 = self.linear1_3(x) + self.linear1_3(x) + return self.linear2(x1) * self.linear3(x2) * self.linear4(x3) class LayerNorm(torch.nn.Module): @@ -1193,6 +1377,19 @@ def forward(self, x): return x + N +class LinalgVectorNorm(torch.nn.Module): + def __init__(self, ord=2.0, dim=None, keepdim=False): + super().__init__() + self.ord = ord + self.dim = dim + self.keepdim = keepdim + + def forward(self, x): + return torch.linalg.vector_norm( + x, ord=self.ord, dim=self.dim, keepdim=self.keepdim + ) + + class Linear(torch.nn.Module): def __init__(self, use_bias: bool = True): super().__init__() @@ -1202,17 +1399,24 @@ def forward(self, x): return self.linear(x) -class LinalgVectorNorm(torch.nn.Module): - def __init__(self, ord=2.0, dim=None, keepdim=False): +class LinearNonConstantWeight(torch.nn.Module): + def __init__(self): super().__init__() - self.ord = ord - self.dim = dim - self.keepdim = keepdim + self.input_dim = 512 + self.output_dim = 128 + self.linear = torch.nn.Linear(self.input_dim, 3 * self.output_dim, True).eval() def 
forward(self, x): - return torch.linalg.vector_norm( - x, ord=self.ord, dim=self.dim, keepdim=self.keepdim + w_q, w_k, w_v = self.linear.weight.split( + [self.output_dim, self.output_dim, self.output_dim] ) + b_q, b_k, b_v = self.linear.bias.split( + [self.output_dim, self.output_dim, self.output_dim] + ) + q = torch.nn.functional.linear(x, w_q, b_q) + k = torch.nn.functional.linear(x, w_k, b_k) + v = torch.nn.functional.linear(x, w_v, b_v) + return q * k * v class Log(torch.nn.Module): @@ -1223,6 +1427,14 @@ def forward(self, x): return torch.log(x) +class LogicalAnd(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.logical_and(x != 0, y != 0).float() + + class LogicalNot(torch.nn.Module): def __init__(self): super().__init__() @@ -1254,20 +1466,38 @@ def forward(self, x): return self.max_pool2d(x) -class MeanWKeppDim(torch.nn.Module): - def __init__(self): +class MaxPool3d(torch.nn.Module): + def __init__( + self, kernel_size, stride, padding, dilation, ceil_mode, return_indices + ): super().__init__() + self.max_pool3d = torch.nn.MaxPool3d( + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + return_indices=return_indices, + ceil_mode=ceil_mode, + ) def forward(self, x): - return torch.mean(x, (-1, -2), keepdim=True) + return self.max_pool3d(x) -class MeanWOKeppDim(torch.nn.Module): - def __init__(self): +class Mean(torch.nn.Module): + def __init__( + self, + dim: Optional[Union[int, Tuple[int, ...], List[int]]] = None, + keepdim: bool = False, + dtype: Optional[torch.dtype] = None, + ): super().__init__() + self.dim = dim + self.keepdim = keepdim + self.dtype = dtype def forward(self, x): - return torch.mean(x, (-1, -2)) + return torch.mean(x, dim=self.dim, keepdim=self.keepdim, dtype=self.dtype) class MaskedFill(torch.nn.Module): @@ -1428,6 +1658,15 @@ def forward(self, x): ) +class Permute(torch.nn.Module): + def __init__(self, dims: List[int]): + super().__init__() + self.dims = dims + + def forward(self, x): + return x.permute(self.dims) + + class PixelShuffle(torch.nn.Module): def __init__(self, scale): super().__init__() @@ -1461,11 +1700,12 @@ def forward(self, x): class PowTensorScalar(torch.nn.Module): - def __init__(self): + def __init__(self, exponent=2): super().__init__() + self.exponent = exponent def forward(self, x): - return torch.pow(x, 2) + return torch.pow(x, self.exponent) class PReLUDefault(torch.nn.Module): @@ -1618,10 +1858,11 @@ def forward(self, x): class RmsNorm(torch.nn.Module): - def __init__(self): + def __init__(self, eps=None): super().__init__() - self.eps = 1e-5 - self.rms = torch.nn.RMSNorm([4], 1e-5) + self.rms = torch.nn.RMSNorm([4]) + if eps: + self.rms = torch.nn.RMSNorm([4], eps) def forward(self, x): return self.rms(x) @@ -1846,6 +2087,36 @@ def forward(self, x, y): return torch.sub(x, y) +class Sub_y_x_from_x_y(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.sub(y, x) + + +class SubAlpha(torch.nn.Module): + def __init__(self, alpha): + super().__init__() + self.alpha = alpha + + def forward(self, x, y): + return torch.sub(x, y, alpha=self.alpha) + + +class SubAlphaConstant(torch.nn.Module): + def __init__(self, alpha, constant_first=False): + super().__init__() + self.alpha = alpha + self.constant_first = constant_first + + def forward(self, x): + if self.constant_first: + return torch.sub(5.0, x, alpha=self.alpha) + else: + return torch.sub(x, 5.0, alpha=self.alpha) + + class 
SubConstantFloat(torch.nn.Module): def __init__(self): super().__init__() @@ -1882,6 +2153,16 @@ def forward(self, x): return torch.sum(x, dim=(2, 3), keepdim=True) +class SwapAxes(torch.nn.Module): + def __init__(self, axis0, axis1): + super().__init__() + self.axis0 = axis0 + self.axis1 = axis1 + + def forward(self, x): + return torch.swapaxes(x, axis0=self.axis0, axis1=self.axis1) + + class Tanh(torch.nn.Module): def __init__(self): super().__init__() @@ -1890,6 +2171,19 @@ def forward(self, x): return torch.tanh(x) +class Threshold(torch.nn.Module): + def __init__(self, threshold=0.0, value=0.0, inplace=False): + super().__init__() + self.threshold = threshold + self.value = value + self.inplace = inplace + + def forward(self, x): + return torch.nn.functional.threshold( + x, threshold=self.threshold, value=self.value, inplace=self.inplace + ) + + class TopKandIndex(torch.nn.Module): def __init__(self): super().__init__() @@ -1900,6 +2194,32 @@ def forward(self, x): return a + self.idx_source[b] +class Triu(torch.nn.Module): + def __init__(self, diagonal: Optional[int] = None): + super().__init__() + self.diagonal = diagonal + + def forward(self, x): + if self.diagonal: + return torch.triu(x, diagonal=self.diagonal) + return torch.triu(x) + + +class TriuConstant(torch.nn.Module): + def __init__(self, diagonal, constant_dtype=torch.float32): + super().__init__() + self.diagonal = diagonal + self.constant_dtype = constant_dtype + self.register_buffer("mask", torch.ones((5, 5), dtype=constant_dtype)) + + def forward(self, x): + mask = torch.triu(self.mask, diagonal=self.diagonal) + if self.constant_dtype == torch.bool: + mask = torch.zeros(x.shape, dtype=x.dtype).masked_fill_(mask, -10000.0) + # Add x to avoid no input in graph + return mask + x + + class Unbind(torch.nn.Module): def __init__(self): super().__init__() @@ -1908,6 +2228,16 @@ def forward(self, x): return torch.unbind(x) +class Unflatten(torch.nn.Module): + def __init__(self, dim, sizes): + super().__init__() + self.dim = dim + self.sizes = sizes + + def forward(self, x): + return torch.unflatten(x, dim=self.dim, sizes=self.sizes) + + class Unfold(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_passes.py b/backends/qualcomm/tests/test_passes.py new file mode 100644 index 00000000000..8af66c4cbef --- /dev/null +++ b/backends/qualcomm/tests/test_passes.py @@ -0,0 +1,154 @@ +import unittest + +import torch +from executorch.backends.qualcomm._passes import ( + ConvertBmmToMatmul, + ConvertMhaToSha, + InsertReshapeForReduceOps, + RemoveRedundancy, +) + +from executorch.exir import to_edge +from executorch.exir.dialects._ops import ops as exir_ops + + +class TestPasses(unittest.TestCase): + def test_insert_reshape_for_argmax(self): + class ArgmaxModule(torch.nn.Module): + def forward(self, x): + return torch.argmax(x, dim=None) + + mod = ArgmaxModule() + + x = torch.tensor([[1.0, 5.0], [3.0, 2.0]]) + ep = torch.export.export(mod, (x,)) + # Run original module for reference + ref = mod(x) + + reshape_nodes = [ + n for n in ep.graph.nodes if n.target == torch.ops.aten.reshape.default + ] + argmax_nodes = [ + n for n in ep.graph.nodes if n.target == torch.ops.aten.argmax.default + ] + self.assertTrue(len(reshape_nodes) == 0, "Reshape node not inserted") + self.assertTrue(len(argmax_nodes) == 1, "Argmax node missing") + + InsertReshapeForReduceOps()(ep.graph_module) + + out = ep.graph_module(x) + + # Check graph structure: argmax should take a reshape as input + reshape_nodes = [ + n 
for n in ep.graph.nodes if n.target == torch.ops.aten.reshape.default + ] + argmax_nodes = [ + n for n in ep.graph.nodes if n.target == torch.ops.aten.argmax.default + ] + self.assertTrue(len(reshape_nodes) == 1, "Reshape node should be inserted") + self.assertTrue(len(argmax_nodes) == 1, "Argmax node missing") + + argmax_node = argmax_nodes[0] + self.assertEqual(argmax_node.args[1], 0, "Argmax dim not set to 0") + + # Execute new graph and compare with reference + out = ep.graph_module(x) + self.assertTrue( + torch.equal(*out, ref), f"Output mismatch: got {out}, expected {ref}" + ) + + def test_mha_to_sha(self): + from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d + from executorch.examples.models.llama.model_args import ModelArgs + from executorch.examples.qualcomm.oss_scripts.llama.masking_utils import ( + CausalAttentionMask, + ) + from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import ( + LlamaAttention, + ) + + # Initialize model config + args = ModelArgs() + args.max_seq_len = 128 + args.ar_len = 32 + args.use_kv_cache = True + args.dim = 32 + args.n_heads = 8 + args.n_kv_heads = 8 + args.n_layers = 2 + args.head_dim = args.dim // args.n_heads + mod = convert_linear_to_conv2d(LlamaAttention(0, args, True)) + + # Prepare inputs + hidden_states = torch.randn(args.max_batch_size, args.ar_len, args.dim) + freqs_cos = torch.randn(args.ar_len, 1) + freqs_sin = torch.randn(args.ar_len, 1) + atten_mask = CausalAttentionMask( + args.max_batch_size, args.ar_len, args.max_seq_len + ) + k_cache = torch.zeros( + args.max_batch_size, + args.n_kv_heads, + args.head_dim, + args.max_seq_len - args.ar_len, + ) + + v_cache = torch.zeros( + args.max_batch_size, + args.n_kv_heads, + args.max_seq_len - args.ar_len, + args.head_dim, + ) + sample_input = ( + hidden_states, + freqs_cos, + freqs_sin, + atten_mask.mask, + k_cache, + v_cache, + ) + + # Run original module for reference + refs = mod(*sample_input) + + # Export the module to an edge program (linear layers were already converted to conv2d above) + edge_program = to_edge(torch.export.export(mod, sample_input)) + new_ep = edge_program.exported_program() + + conv_nodes = [ + n + for n in new_ep.graph.nodes + if n.target == exir_ops.edge.aten.convolution.default + ] + # WQ, WK, WV, O + self.assertTrue(len(conv_nodes) == 4, "Convolution nodes missing") + + # Convert MHA to SHA + # This is a simplified version of what happens in the full pipeline to test the core functionality + graph_module = RemoveRedundancy(quantization_capture=False)( + new_ep.graph_module + ).graph_module + graph_module = ConvertBmmToMatmul()(graph_module).graph_module + graph_module = ConvertMhaToSha(new_ep)(graph_module).graph_module + + conv_nodes = [ + n + for n in new_ep.graph.nodes + if n.target == exir_ops.edge.aten.convolution.default + ] + # Check graph structure: WQ, WK, WV should be converted to SHA + self.assertTrue(len(conv_nodes) == 25, "Convolution nodes should be split") + + # Execute new graph and compare with reference + outs = graph_module( + *new_ep.state_dict.values(), *new_ep.constants.values(), *sample_input + ) + for i, (out, ref) in enumerate(zip(outs, refs)): + self.assertTrue( + torch.allclose(out, *ref, rtol=1e-6, atol=1e-6), + f"Output {i} mismatch: got {out}, expected {ref}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 6ef4fa8fe13..aa3f28b34ee 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ 
b/backends/qualcomm/tests/test_qnn_delegate.py @@ -3,12 +3,15 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import csv import io +import itertools import json import subprocess import sys import tempfile import unittest +from dataclasses import dataclass from functools import partial from multiprocessing.connection import Listener from pathlib import Path @@ -28,6 +31,7 @@ generate_context_binary, ModuleQConfig, prepare_pt2e, + QnnExecuTorchBackendType, QuantDtype, TestQNN, validate_context_binary, @@ -45,6 +49,7 @@ capture_program, dump_context_from_pte, from_context_binary, + generate_gpu_compiler_spec, generate_htp_compiler_spec, generate_qnn_executorch_compiler_spec, is_qnn_sdk_version_less_than, @@ -68,11 +73,7 @@ from collections import defaultdict from typing import List -from executorch.backends.qualcomm._passes import ( - ExpandBroadcastTensorShape, - FoldQDQ, - TagQuantIO, -) +from executorch.backends.qualcomm._passes import FoldQDQ, TagQuantIO from executorch.backends.qualcomm.builders.node_visitor_manager import get_node_visitors from executorch.backends.qualcomm.debugger.utils import DrawGraph from executorch.examples.models.deeplab_v3 import DeepLabV3ResNet101Model @@ -93,9 +94,16 @@ class TestQNNFloatingPointOperator(TestQNN): # TODO: refactor to support different backends def setUp(self): + match self.get_backend_type(): + case QnnExecuTorchBackendType.kHtpBackend: + backend_options = generate_htp_compiler_spec(use_fp16=True) + case QnnExecuTorchBackendType.kGpuBackend: + backend_options = generate_gpu_compiler_spec() + case _: + raise ValueError("Backend is not implemented yet") + TestQNN.atol = 1e-1 TestQNN.rtol = 1e-1 - backend_options = generate_htp_compiler_spec(use_fp16=True) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, @@ -122,6 +130,36 @@ def test_qnn_backend_adaptive_avg_pool2d(self): sample_input = (torch.randn(1, 512, 7, 7),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_adaptive_avg_pool3d(self): + # NOTE: Support the cases mod(input_dhw, output_dhw) = 0 + modules = [ + AdaptiveAvgPool3D((2, 2, 2)), # noqa: F405 + AdaptiveAvgPool3D((8)), # noqa: F405 + AdaptiveAvgPool3D((2, None, None)), # noqa: F405 + ] + sample_inputs = [ + (torch.randn(1, 512, 16, 8, 16),), + ] + for j in range(len(sample_inputs)): + for i, module in enumerate(modules): + with self.subTest(i=i): + self.lower_module_and_test_output(module, sample_inputs[j]) + + def test_qnn_backend_adaptive_max_pool2d(self): + sample_input = (torch.randn(1, 512, 24, 24),) + # NOTE: Currently, we only support the return_indices is False and default is False. + # NOTE: Currently, we only support the case mod(in_w, out_w)=0 and mod(in_h, out_h)=0. 
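+        # Illustrative sketch (editor's example, values not from the original patch): for the
+        # 24x24 input below, output sizes such as (4, 4) or (12, 12) satisfy the divisibility
+        # constraint above (24 % 4 == 0, 24 % 12 == 0) and can be delegated, whereas an output
+        # size like (5, 5) would not satisfy mod(in_w, out_w)=0 and falls outside the supported cases.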
+ modules = [ + AdaptiveMaxPool2D((1, 1), False), # noqa: F405 + AdaptiveMaxPool2D((4, 4)), # noqa: F405 + AdaptiveMaxPool2D((24, 24)), # noqa: F405 + AdaptiveMaxPool2D((None, 4)), # noqa: F405 + AdaptiveMaxPool2D((12, None)), # noqa: F405 + ] + for i, module in enumerate(modules): + with self.subTest(i=i): + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_alias(self): module = Alias() # noqa: F405 sample_input = (torch.randn(1, 10),) @@ -173,14 +211,64 @@ def test_qnn_backend_arange(self): self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_argmax(self): - module = Argmax() # noqa: F405 - sample_input = (torch.randn(16, 3, 4, 4),) - self.lower_module_and_test_output(module, sample_input) + test_cases = [ + { + QCOM_MODULE: Argmax(), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),), + }, + { + QCOM_MODULE: Argmax(dim=0, keepdim=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),), + }, + { + QCOM_MODULE: Argmax(dim=1, keepdim=False), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(8, 5),), + }, + { + QCOM_MODULE: Argmax(dim=None, keepdim=False), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.tensor([5.0]),), + }, + { + QCOM_MODULE: Argmax(dim=2, keepdim=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4),), + }, + ] + + for i, case in enumerate(test_cases): + with self.subTest(i=i): + self.lower_module_and_test_output( + case[QCOM_MODULE], case[QCOM_SAMPLE_INPUTS] + ) def test_qnn_backend_argmin(self): - module = Argmin() # noqa: F405 - sample_input = (torch.rand(3, 4),) - self.lower_module_and_test_output(module, sample_input) + test_cases = [ + { + QCOM_MODULE: Argmin(), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),), + }, + { + QCOM_MODULE: Argmin(dim=0, keepdim=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),), + }, + { + QCOM_MODULE: Argmin(dim=1, keepdim=False), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(8, 5),), + }, + { + QCOM_MODULE: Argmin(dim=None, keepdim=False), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.tensor([5.0]),), + }, + { + QCOM_MODULE: Argmin(dim=2, keepdim=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4),), + }, + ] + + for i, case in enumerate(test_cases): + with self.subTest(i=i): + self.lower_module_and_test_output( + case[QCOM_MODULE], case[QCOM_SAMPLE_INPUTS] + ) @unittest.expectedFailure def test_qnn_backend_asin(self): @@ -208,6 +296,25 @@ def test_qnn_backend_avg_pool2d(self): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_inputs[i]) + def test_qnn_backend_avg_pool3d(self): + # NOTE: Support the cases mod(input_dhw, filter_dhw) = 0 + # NOTE: The pad should be at most half of effective kernel size. 
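+        # Illustrative sketch (editor's example, values assumed): the configurations below keep
+        # padding within half of the kernel size, e.g. kernel_size=8 with padding=1; a setting
+        # such as AvgPool3d((8), (2), (5), True, True) would violate padding <= kernel_size // 2
+        # and be rejected by torch itself before lowering is attempted.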
+ modules = [ + AvgPool3d((8), (2), (1), True, True), # noqa: F405 + AvgPool3d((8), (2), (1), True, False), # noqa: F405 + AvgPool3d((8), (2), (1), False, False), # noqa: F405 + AvgPool3d((16, 16, 16), (4, 4, 4), (1, 1, 1), False, True), # noqa: F405 + AvgPool3d((8, 8, 8), (2, 2, 2), (1, 1, 1), True, True), # noqa: F405 + AvgPool3d((12, 12, 12), (4, 6, 2), (0, 0, 0), True, True), # noqa: F405 + ] + sample_inputs = [ + (torch.randn(1, 3, 64, 48, 32),), + ] + for j in range(len(sample_inputs)): + for i, module in enumerate(modules): + with self.subTest(i=i): + self.lower_module_and_test_output(module, sample_inputs[j]) + def test_qnn_backend_batch_norm(self): modules = [BatchNorm(32), BatchNorm(32, False)] # noqa: F405 sample_input = (torch.randn([4, 32, 16, 16]),) @@ -232,7 +339,7 @@ def test_qnn_backend_cast(self): self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_cat(self): - modules = [Cat2(), Cat3(), Cat4()] # noqa: F405 + modules = [Cat2(), Cat3(), Cat4(), Cat5()] # noqa: F405 sample_input = (torch.randn(1, 1, 2, 2), torch.randn(1, 1, 4, 2)) for i, module in enumerate(modules): with self.subTest(i=i): @@ -282,6 +389,13 @@ def test_qnn_backend_conv2d_channel_last(self): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv3d_sequential(self): + modules = [Conv3dSequential(), Conv3dSequential(bias=False)] # noqa: F405 + sample_input = (torch.randn([2, 1, 10, 32, 32]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv_transpose1d(self): modules = [ ConvTranspose1dSingle(), # noqa: F405 @@ -294,14 +408,107 @@ def test_qnn_backend_conv_transpose1d(self): self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_conv_transpose2d(self): + test_comb = [ + { + QCOM_MODULE: [ConvTranspose2dSingle()], # noqa: F405 + QCOM_SAMPLE_INPUTS: [ + (torch.randn(1, 1, 16, 16),), + ], + }, + { + QCOM_MODULE: [ConvTranspose2dSingle(bias=False)], # noqa: F405 + QCOM_SAMPLE_INPUTS: [ + (torch.randn(1, 1, 16, 16),), + ], + }, + { + QCOM_MODULE: [ + ConvTranspose2dSingle( # noqa: F405 + in_channels=2, + out_channels=3, + dilation=2, + kernel_size=3, + stride=2, + ) + ], + QCOM_SAMPLE_INPUTS: [ + (torch.randn(1, 2, 16, 16),), + ], + }, + { + QCOM_MODULE: [ + ConvTranspose2dSingle( # noqa: F405 + in_channels=2, + out_channels=3, + dilation=(2, 3), + kernel_size=3, + stride=2, + ) + ], + QCOM_SAMPLE_INPUTS: [ + (torch.randn(1, 2, 16, 16),), + ], + }, + { + QCOM_MODULE: [ + ConvTranspose2dSingle( # noqa: F405 + in_channels=2, + out_channels=3, + dilation=(2, 1), + kernel_size=3, + stride=2, + ) + ], + QCOM_SAMPLE_INPUTS: [ + (torch.randn(1, 2, 16, 16),), + ], + }, + { + QCOM_MODULE: [ + ConvTranspose2dSingle( # noqa: F405 + in_channels=2, + out_channels=3, + dilation=(2, 1), + kernel_size=3, + stride=2, + ) + ], + QCOM_SAMPLE_INPUTS: [ + (torch.randn(1, 2, 16, 16),), + ], + }, + { + QCOM_MODULE: [ + ConvTranspose2dSingle( # noqa: F405 + in_channels=6, + out_channels=6, + kernel_size=3, + padding=0, + groups=2, + ) + ], + QCOM_SAMPLE_INPUTS: [ + (torch.randn(4, 6, 16, 16),), + ], + }, + ] + + index = 0 + for comb in test_comb: + for module in comb[QCOM_MODULE]: + for sample_input in comb[QCOM_SAMPLE_INPUTS]: + with self.subTest(i=index): + index += 1 + self.lower_module_and_test_output(module, sample_input) + + def test_qnn_backend_conv_transpose3d(self): modules = [ - ConvTranspose2dSingle(), # noqa: F405 - 
ConvTranspose2dSingle(bias=False), # noqa: F405 - ConvTranspose2dSingle(dilation=2), # noqa: F405 - ConvTranspose2dSingle(dilation=(2, 3)), # noqa: F405 - ConvTranspose2dSingle(dilation=(2, 1)), # noqa: F405 + ConvTranspose3dSingle(), # noqa: F405 + ConvTranspose3dSingle(bias=False), # noqa: F405 + ConvTranspose3dSingle(dilation=2), # noqa: F405 + ConvTranspose3dSingle(dilation=(3, 2, 3)), # noqa: F405 ] - sample_input = (torch.randn([1, 1, 33, 33]),) + sample_input = (torch.randn([1, 1, 3, 3, 3]),) for i, module in enumerate(modules): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) @@ -328,8 +535,8 @@ def test_qnn_backend_cumsum(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): - self.lower_module_and_test_output(module, sample_input) index += 1 + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_einsum_outer_product(self): module = EinsumOuterProduct() # noqa: F405 @@ -372,6 +579,24 @@ def test_qnn_backend_element_wise_add(self): ], QCOM_SAMPLE_INPUTS: [(torch.randint(0, 10, size=(2, 3)),)], }, + { + QCOM_MODULE: [ + AddAlpha(alpha=2), # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [ + ( + torch.tensor([[1.2, 1.3, 1.4]]), + torch.tensor([[0.8, 1.6, 0.2]]), + ) + ], + }, + { + QCOM_MODULE: [ + AddAlphaConstant(alpha=2, constant_first=True), # noqa: F405 + AddAlphaConstant(alpha=2, constant_first=False), # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [(torch.tensor([[1.2, 1.3, 1.4]]),)], + }, ] index = 0 @@ -379,8 +604,8 @@ def test_qnn_backend_element_wise_add(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): - self.lower_module_and_test_output(module, sample_input) index += 1 + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_element_wise_and(self): module = And(torch.tensor(1.7), torch.tensor(0.2)) # noqa: F405 @@ -418,8 +643,8 @@ def test_qnn_backend_element_wise_div(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): - self.lower_module_and_test_output(module, sample_input) index += 1 + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_element_wise_mul(self): test_comb = [ @@ -445,8 +670,8 @@ def test_qnn_backend_element_wise_mul(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): - self.lower_module_and_test_output(module, sample_input) index += 1 + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_element_wise_or(self): test_comb = [ @@ -495,6 +720,24 @@ def test_qnn_backend_element_wise_sub(self): QCOM_MODULE: [SubConstantFloat()], # noqa: F405 QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)], }, + { + QCOM_MODULE: [ + SubAlpha(alpha=2), # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [ + ( + torch.tensor([[1.2, 1.3, 1.4]]), + torch.tensor([[0.8, 1.6, 0.2]]), + ) + ], + }, + { + QCOM_MODULE: [ + SubAlphaConstant(alpha=2, constant_first=True), # noqa: F405 + SubAlphaConstant(alpha=2, constant_first=False), # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [(torch.tensor([[1.2, 1.3, 1.4]]),)], + }, ] index = 0 @@ -502,10 +745,9 @@ def test_qnn_backend_element_wise_sub(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): - self.lower_module_and_test_output(module, sample_input) index += 1 + self.lower_module_and_test_output(module, sample_input) - 
@unittest.expectedFailure def test_qnn_backend_elu(self): module = Elu() # noqa: F405 sample_input = (torch.randn(2, 5, 1, 3),) @@ -539,16 +781,12 @@ def test_qnn_backend_expand(self): (torch.randn([3, 1]),), (torch.randn([4]),), ] - passes_job = get_capture_program_passes() - passes_job[ExpandBroadcastTensorShape][QCOM_PASS_ACTIVATE_KEY] = True index = 0 for module in modules: for sample_input in sample_inputs: with self.subTest(i=index): - self.lower_module_and_test_output( - module, sample_input, passes_job=passes_job - ) index += 1 + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_expm1(self): sample_input = (torch.randn(3, 4, 5),) @@ -571,6 +809,21 @@ def test_qnn_backend_floor_divide(self): { QCOM_MODULE: [FloorDiv()], # noqa: F405 QCOM_SAMPLE_INPUTS: [ + (torch.randint(-100, 100, (10, 10)), torch.full((10, 10), 3)), + ( + torch.randint(-100, 100, (10, 10)).float(), + torch.full((10, 10), 2.5), + ), + (torch.randint(-1000, 1000, (10, 10)), torch.full((10, 10), 100)), + (torch.tensor([10]), torch.arange(1, 5)), # Failed + (torch.arange(-10, 10), torch.tensor([2])), + (torch.randint(-100, 100, (20,)), torch.full((20,), 2)), + (torch.randint(-100, 100, (5, 10)), torch.full((5, 10), 2)), + (torch.randint(-100, 100, (3, 4, 5)), torch.full((3, 4, 5), 2)), + ( + torch.randint(-100, 100, (2, 3, 4, 5)), + torch.full((2, 3, 4, 5), 2), + ), (torch.randn(2, 5, 1, 3), eps + torch.randn(2, 5, 1, 3)), (torch.randn([2, 5, 1, 3]), eps + torch.randn([4, 1])), ], @@ -586,8 +839,8 @@ def test_qnn_backend_floor_divide(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): - self.lower_module_and_test_output(module, sample_input) index += 1 + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_fold(self): sample_input = (torch.randn(3, 512, 256),) @@ -631,6 +884,36 @@ def test_qnn_backend_gelu(self): sample_input = (torch.randn(2, 5, 1, 3),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_grid_sampler(self): + # NOTE: The grid_sampler 3d version is not supported in fp16. 
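+        # Editor's note (shapes follow the torch.nn.functional.grid_sample convention): the 2D
+        # case pairs a 4D input (N, C, H, W) with an (N, H_out, W_out, 2) grid, while the 3D case
+        # pairs a 5D input (N, C, D, H, W) with an (N, D_out, H_out, W_out, 3) grid; only the 2D
+        # shapes appear in sample_inputs below, matching the fp16 limitation noted above.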
+ modes = ["bilinear", "nearest"] + padding_modes = ["zeros", "border", "reflection"] + align_corners = [False, True] + grid_samples = [ + GridSample(mode, pad, align) # noqa: F405 + for mode, pad, align in itertools.product( + modes, padding_modes, align_corners + ) + ] + sample_inputs = [ + ( + torch.randn(1, 12, 14, 14), + torch.randn(1, 3, 3, 2), + ), # for grid_sampler 2d + ] + + for j in range(len(sample_inputs)): + for i, module in enumerate(grid_samples): + with self.subTest(i=i, j=j, module=module): + self.lower_module_and_test_output(module, sample_inputs[j]) + + def test_qnn_backend_glu(self): + modules = [torch.nn.GLU(), torch.nn.GLU(dim=0)] + sample_input = (torch.randn(2, 5, 1, 4),) + for i, module in enumerate(modules): + with self.subTest(i=i): + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_greater_equal(self): test_comb = [ { @@ -760,28 +1043,191 @@ def test_qnn_backend_index_copy(self): ) def test_qnn_backend_index_put(self): - test_comb = [ - { - QCOM_MODULE: IndexPut(skip_mutable_buffer=False), # noqa: F405 - QCOM_SAMPLE_INPUTS: ( - torch.tensor([2], dtype=torch.int32), - torch.randn([1, 1, 12, 64]), + skip_mutable_buffer = [False, True] + total_test_combo = [] + # mode 0 + sample_inputs = [ + (torch.tensor([0], dtype=torch.int32), torch.randn([1, 1, 12, 64])), + (torch.tensor([0], dtype=torch.int32), torch.randn([1, 64])), + (torch.tensor([0, 1], dtype=torch.int32), torch.randn([2, 1, 12, 64])), + (torch.tensor([0, 1], dtype=torch.int32), torch.randn([1, 64])), + ] + total_test_combo.append( + list(itertools.product(skip_mutable_buffer, sample_inputs)) + ) + # mode 1 + sample_inputs = [ + (torch.tensor([2], dtype=torch.int32), torch.randn([1, 1, 12, 64])), + (torch.tensor([2], dtype=torch.int32), torch.randn([1, 64])), + (torch.tensor([2, 3], dtype=torch.int32), torch.randn([1, 2, 12, 64])), + (torch.tensor([2, 3], dtype=torch.int32), torch.randn([1, 64])), + ] + total_test_combo.append( + list(itertools.product(skip_mutable_buffer, sample_inputs)) + ) + # mode 2 + sample_inputs = [ + (torch.tensor([2], dtype=torch.int32), torch.randn([1, 1, 1, 64])), + (torch.tensor([2], dtype=torch.int32), torch.randn([1, 64])), + (torch.tensor([0, 1], dtype=torch.int32), torch.randn([1, 1, 2, 64])), + (torch.tensor([2, 3], dtype=torch.int32), torch.randn([1, 64])), + ] + total_test_combo.append( + list(itertools.product(skip_mutable_buffer, sample_inputs)) + ) + # mode 3 + sample_inputs = [ + ( + ( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([2, 3], dtype=torch.int32), ), - }, - { - QCOM_MODULE: IndexPut(skip_mutable_buffer=True), # noqa: F405 - QCOM_SAMPLE_INPUTS: ( - torch.tensor([2], dtype=torch.int32), - torch.randn([1, 1, 12, 64]), + torch.randn([2, 12, 64]), + ), + ( + ( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([2, 3], dtype=torch.int32), ), - }, + torch.randn([1, 64]), + ), ] - for i, test in enumerate(test_comb): + total_test_combo.append( + list(itertools.product(skip_mutable_buffer, sample_inputs)) + ) + # mode 4 + sample_inputs = [ + ( + ( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([2, 3], dtype=torch.int32), + ), + torch.randn([2, 64]), + ), + ( + ( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([2, 3], dtype=torch.int32), + ), + torch.randn([1, 64]), + ), + ] + total_test_combo.append( + list(itertools.product(skip_mutable_buffer, sample_inputs)) + ) + # mode 5 + sample_inputs = [ + ( + ( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([2, 3], dtype=torch.int32), + ), + 
torch.randn([64]), + ), + ( + ( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([2, 3], dtype=torch.int32), + ), + torch.randn([1]), + ), + ] + total_test_combo.append( + list(itertools.product(skip_mutable_buffer, sample_inputs)) + ) + + for i, test_combo in enumerate(total_test_combo): + for j, combo in enumerate(test_combo): + with self.subTest(f"mode_{i}-{j}"): + self.lower_module_and_test_output( + IndexPut(skip_mutable_buffer=combo[0], mode=i), # noqa: F405 + combo[1], + skip_mutable_buffer=combo[0], + ) + + def test_qnn_backend_index_put_suite(self): + accumulate = [False, True] + in_place = [False, True] + sample_inputs = [ + # basic + ( + torch.rand(5, 2) * 100, + (torch.tensor([0, 2]),), + torch.tensor([10.0, 20.0]), + ), + (torch.rand(5, 2), (torch.tensor([0, 2]),), torch.tensor([10.0, 20.0])), + # shape + (torch.rand(5), (torch.tensor([0, 2]),), torch.tensor([10.0, 20.0])), + ( + torch.rand(5, 2), + (torch.tensor([0, 2]), torch.tensor([1, 1])), + torch.tensor([10.0, 20.0]), + ), + ( + torch.rand(5, 3, 2), + (torch.tensor([0, 2]), torch.tensor([1, 1]), torch.tensor([0, 1])), + torch.tensor([10.0, 20.0]), + ), + # TODO: not supported by HTP + # ( + # torch.rand(5, 3, 2, 4), + # (torch.tensor([0, 2]), torch.tensor([1, 1]), torch.tensor([0, 1]), torch.tensor([2, 3])), + # torch.tensor([10.0]), + # ), + # indices + (torch.rand(5, 2), (torch.tensor([2]),), torch.tensor([10.0])), + ( + torch.rand(5, 3), + (torch.tensor([0, 2, 4]),), + torch.tensor([10.0, 20.0, 30.0]), + ), + ( + torch.rand(5), + (torch.tensor([1, 1, 3, 3]),), + torch.tensor([10.0, 20.0, 30.0, 40.0]), + ), + # broadcasting + (torch.rand(5, 3), (torch.tensor([0, 2, 4]),), torch.tensor([42.0])), + ( + torch.rand(3, 4), + (torch.tensor([0, 1]), torch.tensor([1, 2])), + torch.tensor([10.0, 20.0]), + ), + (torch.rand(4, 2), (torch.tensor([0, 2]),), torch.tensor([5.0, 15.0])), + ( + torch.rand(3, 2, 2), + (torch.tensor([0, 1]),), + torch.tensor([[1.0, 2.0], [3.0, 4.0]]), + ), + (torch.rand(4, 2), (torch.tensor([1, 1, 1]),), torch.tensor([5.0])), + # two-index + ( + torch.rand(4, 3), + (torch.tensor([0, 1, 2]), torch.tensor([1, 0, 2])), + torch.tensor([10.0, 20.0, 30.0]), + ), + ( + torch.rand(3, 3), + (torch.tensor([0, 2]), torch.tensor([1, 1])), + torch.tensor([15.0, 25.0]), + ), + ( + torch.rand(3, 2), + (torch.tensor([1, 1, 2]), torch.tensor([0, 0, 1])), + torch.tensor([5.0, 10.0, 15.0]), + ), + ( + torch.rand(3, 2), + (torch.tensor([1]), torch.tensor([0, 0, 1])), + torch.tensor([5.0, 10.0, 15.0]), + ), + ] + test_combo = list(itertools.product(accumulate, in_place, sample_inputs)) + for i, combo in enumerate(test_combo): with self.subTest(i=i): self.lower_module_and_test_output( - test[QCOM_MODULE], - test[QCOM_SAMPLE_INPUTS], - skip_mutable_buffer=test[QCOM_MODULE].skip_mutable_buffer, + IndexPutSuite(accumulate=combo[0], in_place=combo[1]), # noqa: F405 + combo[2], ) def test_qnn_backend_index_select(self): @@ -860,8 +1306,8 @@ def test_qnn_backend_leaky_relu(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): - self.lower_module_and_test_output(module, sample_input) index += 1 + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_less_equal(self): test_comb = [ @@ -915,15 +1361,27 @@ def test_qnn_backend_linalg_vector_norm(self): self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_linear(self): - module = Linear() # noqa: F405 + modules = [ + Linear(), # noqa: F405 + LinearNonConstantWeight(), # 
noqa: F405 + ] sample_input = (torch.randn([3, 512]),) - self.lower_module_and_test_output(module, sample_input) + for i, module in enumerate(modules): + with self.subTest(i=i): + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_log(self): module = Log() # noqa: F405 sample_input = (torch.rand([1, 2, 3, 4]),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_logical_and(self): + module = LogicalAnd() # noqa: F405 + input1 = torch.tensor([True, False, True, False]) + input2 = torch.tensor([True, True, False, False]) + sample_input = (input1, input2) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_logical_not(self): module = LogicalNot() # noqa: F405 sample_input = (torch.rand([1, 2, 3, 4]),) @@ -949,20 +1407,87 @@ def test_qnn_backend_max_pool2d(self): sample_input = (torch.randn(4, 3, 24, 24),) self.lower_module_and_test_output(module, sample_input) - def test_qnn_backend_mean_dim(self): - modules = [MeanWKeppDim(), MeanWOKeppDim()] # noqa: F405 - sample_input = (torch.randn([2, 5, 1, 3]),) + def test_qnn_backend_max_pool3d(self): + # NOTE: The pad should be at most half of effective kernel size. + modules = [ + MaxPool3d((3), (1), (1), (1), False, False), # noqa: F405 + MaxPool3d((7), (1), (3), (1), False, False), # noqa: F405 + MaxPool3d((7), (1), (3), (1), True, False), # noqa: F405 + MaxPool3d( # noqa: F405 + (7, 7, 7), (1, 1, 1), (3, 3, 3), (1, 1, 1), True, False + ), # noqa: F405 + MaxPool3d( # noqa: F405 + (7, 9, 13), (1, 1, 1), (3, 4, 6), (1, 1, 1), False, False + ), # noqa: F405 + ] + sample_input = (torch.randn(1, 7, 21, 35, 28),) for i, module in enumerate(modules): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) - @unittest.skip("failed to lower in QNN 2.26") - def test_qnn_backend_mha(self): - module = MultiheadAttention() # noqa: F405 - sample_input = (torch.randn(1, 197, 96),) - self.lower_module_and_test_output(module, sample_input) - - def test_qnn_backend_minimum(self): + def test_qnn_backend_mean(self): + test_comb = [ + # Reduce over last two dims, keepdim=True + { + QCOM_MODULE: Mean(dim=(-1, -2), keepdim=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),), + }, + # Reduce over last two dims, keepdim=False + { + QCOM_MODULE: Mean(dim=(-1, -2), keepdim=False), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),), + }, + # Default: reduce all dims + { + QCOM_MODULE: Mean(), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(10, 10),), + }, + # TODO: To be enabled via reshape input to 1d tensor + # # Scalar case + # { + # QCOM_MODULE: Mean(), + # QCOM_SAMPLE_INPUTS: (torch.tensor(5.0),), + # }, + # Edge case: dim is a empty list + { + QCOM_MODULE: Mean(dim=[]), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),), + }, + # Edge case: reduce along dim=0 (batch dimension) + { + QCOM_MODULE: Mean(dim=0), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),), + }, + # Edge case: reduce along dim=0 with keepdim=True + { + QCOM_MODULE: Mean(dim=0, keepdim=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),), + }, + # Edge case: reduce along multiple dims + { + QCOM_MODULE: Mean(dim=(0, 2)), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(3, 4, 5),), + }, + # Edge case: high-dimensional tensor + { + QCOM_MODULE: Mean(dim=(1, 3), keepdim=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4, 5, 6),), + }, + ] + + for i, test in enumerate(test_comb): + with self.subTest(i=i): + 
self.lower_module_and_test_output( + test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS] + ) + + @unittest.skip("failed to lower in QNN 2.26") + def test_qnn_backend_mha(self): + module = MultiheadAttention() # noqa: F405 + sample_input = (torch.randn(1, 197, 96),) + self.lower_module_and_test_output(module, sample_input) + + def test_qnn_backend_minimum(self): module = Minimum() # noqa: F405 sample_input = (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4)) self.lower_module_and_test_output(module, sample_input) @@ -999,6 +1524,16 @@ def test_qnn_backend_pad(self): sample_input = (torch.randn([1, 8, 128]),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_permute(self): + modules = [ + Permute([0, 2, 3, 1]), # noqa: F405 + Permute([-1, -3, -2, -4]), # noqa: F405 + ] + sample_input = (torch.randn([2, 3, 4, 5]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_pixel_shuffle(self): module = PixelShuffle(2) # noqa: F405 sample_input = (torch.ones([2, 4, 3, 3]),) @@ -1010,9 +1545,28 @@ def test_qnn_backend_pixel_unshuffle(self): self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_pow_tensor_scalar(self): - module = PowTensorScalar() # noqa: F405 - sample_input = (torch.rand([2, 4, 3, 3]),) - self.lower_module_and_test_output(module, sample_input) + test_comb = [ + { + QCOM_MODULE: [ + PowTensorScalar(), # noqa: F405 + PowTensorScalar(1), # noqa: F405 + PowTensorScalar(-1), # noqa: F405 + PowTensorScalar(0.5), # noqa: F405 + ], # noqa: F405 + QCOM_SAMPLE_INPUTS: [(torch.rand(10, 10) + 0.1,)], + }, + { + QCOM_MODULE: [PowTensorScalar(10)], # noqa: F405 + QCOM_SAMPLE_INPUTS: [(torch.rand(10, 10) * 0.5 + 0.5,)], + }, + ] + index = 0 + for comb in test_comb: + for module in comb[QCOM_MODULE]: + for sample_input in comb[QCOM_SAMPLE_INPUTS]: + with self.subTest(i=index): + index += 1 + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_prelu(self): test_comb = [ @@ -1031,8 +1585,8 @@ def test_qnn_backend_prelu(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): - self.lower_module_and_test_output(module, sample_input) index += 1 + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_relu(self): module = Relu() # noqa: F405 @@ -1055,9 +1609,14 @@ def test_qnn_backend_reshape(self): self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_rms_norm(self): - module = RmsNorm() # noqa: F405 - sample_input = (torch.abs(torch.randn([1, 1, 1, 4])),) - self.lower_module_and_test_output(module, sample_input) + modules = [ + RmsNorm(), # noqa: F405 + RmsNorm(eps=1e-5), # noqa: F405 + ] + sample_input = (torch.randn([1, 1, 1, 4]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_roll(self): modules = [ @@ -1147,10 +1706,8 @@ def test_qnn_backend_slice_scatter(self): ], QCOM_SAMPLE_INPUTS: [ ( - ( - torch.zeros(8, 8), - torch.ones(8, 2), - ) + torch.zeros(8, 8), + torch.ones(8, 2), ) ], }, @@ -1161,8 +1718,8 @@ def test_qnn_backend_slice_scatter(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): - self.lower_module_and_test_output(module, sample_input) index += 1 + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_stack(self): module = Stack() # noqa: F405 @@ 
-1195,11 +1752,63 @@ def test_qnn_backend_sum_int_list(self): sample_input = (torch.randn([1, 4, 8, 8]),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_swapaxes(self): + module = SwapAxes(0, 1) # noqa: F405 + sample_input = (torch.randn([1, 2, 3, 4]),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_tanh(self): module = Tanh() # noqa: F405 sample_input = (torch.randn(2, 5, 1, 3),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_threshold(self): + modules = [ + Threshold(), # noqa: F405 + Threshold(threshold=0.5, value=3.0, inplace=True), # noqa: F405 + Threshold(threshold=0.5, value=3.0, inplace=False), # noqa: F405 + ] + sample_input = (torch.randn(2, 5, 1, 3),) + for i, module in enumerate(modules): + with self.subTest(i=i): + self.lower_module_and_test_output(module, sample_input) + + def test_qnn_backend_triu(self): + test_comb = [ + { + QCOM_MODULE: [ + Triu(), # noqa: F405 + Triu(diagonal=1), # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [ + (torch.randn(3, 3),), + (torch.randn(1, 2, 3, 3),), + ], + }, + { + QCOM_MODULE: [ + TriuConstant(1), # noqa: F405 + TriuConstant(1, constant_dtype=torch.bool), # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [ + (torch.zeros(5, 5),), + ], + }, + ] + + index = 0 + for comb in test_comb: + for module in comb[QCOM_MODULE]: + for sample_input in comb[QCOM_SAMPLE_INPUTS]: + with self.subTest(i=index): + index += 1 + self.lower_module_and_test_output(module, sample_input) + + def test_qnn_backend_unflatten(self): + module = Unflatten(dim=1, sizes=(2, 3, 4)) # noqa: F405 + sample_input = (torch.randn([1, 24]),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_unbind(self): module = Unbind() # noqa: F405 sample_input = (torch.randn([3, 3]),) @@ -1514,6 +2123,39 @@ def test_qnn_backend_16a4w_conv2d_qat(self): ) self.lower_module_and_test_output(converted, sample_input) + def test_qnn_backend_16a4w_block_conv2d_qat(self): + in_c = 512 + out_c = 32 + kernel = 1 + padding = 0 + modules = [ + Conv2dSingle( # noqa: F405 + in_channel=in_c, + out_channel=out_c, + kernel_size=kernel, + padding=padding, + ), + Conv2dSingle( # noqa: F405 + in_channel=in_c, + out_channel=out_c, + kernel_size=kernel, + padding=padding, + ), + ] # noqa: F405 + sample_input = (torch.randn([1, 512, 3, 3]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + prepared = self.get_prepared_qat_module( + module, + sample_input, + quant_dtype=QuantDtype.use_16a4w_block, + block_size_map={"conv2d": (1, 32, 1, 1)}, + ) + converted = self.get_converted_sgd_trained_module( + module, prepared, sample_input + ) + self.lower_module_and_test_output(converted, sample_input) + def test_qnn_backend_16a4w_layer_norm(self): module = LayerNorm() # noqa: F405 sample_input = (torch.randn(196, 768),) @@ -1574,6 +2216,38 @@ def test_qnn_backend_adaptive_avg_pool2d(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_adaptive_avg_pool3d(self): + # NOTE: Support the cases mod(input_dhw, output_dhw) = 0 + modules = [ + AdaptiveAvgPool3D((2, 2, 2)), # noqa: F405 + AdaptiveAvgPool3D((8)), # noqa: F405 + AdaptiveAvgPool3D((2, None, None)), # noqa: F405 + ] + sample_inputs = [ + (torch.randn(1, 512, 16, 8, 16),), + ] + for j in range(len(sample_inputs)): + for i, module in enumerate(modules): + with self.subTest(i=i): + module = self.get_qdq_module(module, sample_inputs[j]) + 
self.lower_module_and_test_output(module, sample_inputs[j]) + + def test_qnn_backend_adaptive_max_pool2d(self): + sample_input = (torch.randn(1, 512, 24, 24),) + # NOTE: Currently, we only support the return_indices is False and default is False. + # NOTE: Currently, we only support the case mod(in_w, out_w)=0 and mod(in_h, out_h)=0. + modules = [ + AdaptiveMaxPool2D((1, 1), False), # noqa: F405 + AdaptiveMaxPool2D((4, 4)), # noqa: F405 + AdaptiveMaxPool2D((24, 24)), # noqa: F405 + AdaptiveMaxPool2D((None, 4)), # noqa: F405 + AdaptiveMaxPool2D((12, None)), # noqa: F405 + ] + for i, module in enumerate(modules): + with self.subTest(i=i): + module_one = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module_one, sample_input) + def test_qnn_backend_alias(self): module = Alias() # noqa: F405 sample_input = (torch.randn(1, 10),) @@ -1631,16 +2305,66 @@ def test_qnn_backend_arange(self): self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_argmax(self): - module = Argmax() # noqa: F405 - sample_input = (torch.randn(16, 3, 4, 4),) - module = self.get_qdq_module(module, sample_input) - self.lower_module_and_test_output(module, sample_input) + test_cases = [ + { + QCOM_MODULE: Argmax(), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),), + }, + { + QCOM_MODULE: Argmax(dim=0, keepdim=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),), + }, + { + QCOM_MODULE: Argmax(dim=1, keepdim=False), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(8, 5),), + }, + { + QCOM_MODULE: Argmax(dim=None, keepdim=False), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.tensor([5.0]),), + }, + { + QCOM_MODULE: Argmax(dim=2, keepdim=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4),), + }, + ] + + for i, case in enumerate(test_cases): + with self.subTest(i=i): + module = self.get_qdq_module( + case[QCOM_MODULE], case[QCOM_SAMPLE_INPUTS] + ) + self.lower_module_and_test_output(module, case[QCOM_SAMPLE_INPUTS]) def test_qnn_backend_argmin(self): - module = Argmin() # noqa: F405 - sample_input = (torch.randn(16, 3, 4, 4),) - module = self.get_qdq_module(module, sample_input) - self.lower_module_and_test_output(module, sample_input) + test_cases = [ + { + QCOM_MODULE: Argmin(), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),), + }, + { + QCOM_MODULE: Argmin(dim=0, keepdim=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),), + }, + { + QCOM_MODULE: Argmin(dim=1, keepdim=False), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(8, 5),), + }, + { + QCOM_MODULE: Argmin(dim=None, keepdim=False), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.tensor([5.0]),), + }, + { + QCOM_MODULE: Argmin(dim=2, keepdim=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4),), + }, + ] + + for i, case in enumerate(test_cases): + with self.subTest(i=i): + module = self.get_qdq_module( + case[QCOM_MODULE], case[QCOM_SAMPLE_INPUTS] + ) + self.lower_module_and_test_output(module, case[QCOM_SAMPLE_INPUTS]) def test_qnn_backend_asin(self): module = Asin() # noqa: F405 @@ -1670,6 +2394,26 @@ def test_qnn_backend_avg_pool2d(self): module = self.get_qdq_module(module, sample_inputs[i]) self.lower_module_and_test_output(module, sample_inputs[i]) + def test_qnn_backend_avg_pool3d(self): + # NOTE: Support the cases mod(input_dhw, filter_dhw) = 0 + # NOTE: The pad should be at most half of effective kernel size. 
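+        # Editor's sketch of the quantized flow exercised below (method names from this file):
+        #     qdq = self.get_qdq_module(module, sample_input)       # roughly: capture, calibrate, convert to a Q/DQ graph
+        #     self.lower_module_and_test_output(qdq, sample_input)  # delegate to QNN and compare against the reference output
+        # i.e. the same AvgPool3d configurations as the fp16 variant of this test, but on the quantized path.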
+ modules = [ + AvgPool3d((8), (2), (1), True, True), # noqa: F405 + AvgPool3d((8), (2), (1), True, False), # noqa: F405 + AvgPool3d((8), (2), (1), False, False), # noqa: F405 + AvgPool3d((16, 16, 16), (4, 4, 4), (1, 1, 1), False, True), # noqa: F405 + AvgPool3d((8, 8, 8), (2, 2, 2), (1, 1, 1), True, True), # noqa: F405 + AvgPool3d((12, 12, 12), (4, 6, 2), (0, 0, 0), True, True), # noqa: F405 + ] + sample_inputs = [ + (torch.randn(1, 3, 64, 48, 32),), + ] + for j in range(len(sample_inputs)): + for i, module in enumerate(modules): + with self.subTest(i=i): + module = self.get_qdq_module(module, sample_inputs[j]) + self.lower_module_and_test_output(module, sample_inputs[j]) + def test_qnn_backend_batch_norm(self): modules = [BatchNorm(32), BatchNorm(32, False)] # noqa: F405 sample_input = (torch.randn([4, 32, 16, 16]),) @@ -1692,7 +2436,7 @@ def test_qnn_backend_cast(self): self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_cat(self): - modules = [Cat2(), Cat3(), Cat4()] # noqa: F405 + modules = [Cat2(), Cat3(), Cat4(), Cat5()] # noqa: F405 sample_input = (torch.randn(1, 1, 2, 2), torch.randn(1, 1, 4, 2)) for i, module in enumerate(modules): with self.subTest(i=i): @@ -1782,6 +2526,14 @@ def test_qnn_backend_conv2d_channel_last(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv3d_sequential(self): + modules = [Conv3dSequential(), Conv3dSequential(bias=False)] # noqa: F405 + sample_input = (torch.randn([2, 1, 10, 32, 32]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + qdq_module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(qdq_module, sample_input) + def test_qnn_backend_conv_transpose1d(self): modules = [ ConvTranspose1dSingle(), # noqa: F405 @@ -1795,13 +2547,138 @@ def test_qnn_backend_conv_transpose1d(self): self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_conv_transpose2d(self): + test_comb = [ + { + QCOM_MODULE: [ConvTranspose2dSingle()], # noqa: F405 + QCOM_SAMPLE_INPUTS: [ + (torch.randn(1, 1, 16, 16),), + ], + }, + { + QCOM_MODULE: [ConvTranspose2dSingle(bias=False)], # noqa: F405 + QCOM_SAMPLE_INPUTS: [ + (torch.randn(1, 1, 16, 16),), + ], + }, + { + QCOM_MODULE: [ + ConvTranspose2dSingle( # noqa: F405 + in_channels=2, + out_channels=3, + dilation=2, + kernel_size=3, + stride=2, + ) + ], + QCOM_SAMPLE_INPUTS: [ + (torch.randn(1, 2, 16, 16),), + ], + }, + { + QCOM_MODULE: [ + ConvTranspose2dSingle( # noqa: F405 + in_channels=2, + out_channels=3, + dilation=(2, 3), + kernel_size=3, + stride=2, + ) + ], + QCOM_SAMPLE_INPUTS: [ + (torch.randn(1, 2, 16, 16),), + ], + }, + { + QCOM_MODULE: [ + ConvTranspose2dSingle( # noqa: F405 + in_channels=2, + out_channels=3, + dilation=(2, 1), + kernel_size=3, + stride=2, + ) + ], + QCOM_SAMPLE_INPUTS: [ + (torch.randn(1, 2, 16, 16),), + ], + }, + { + QCOM_MODULE: [ + ConvTranspose2dSingle( # noqa: F405 + in_channels=2, + out_channels=3, + dilation=(2, 1), + kernel_size=3, + stride=2, + ) + ], + QCOM_SAMPLE_INPUTS: [ + (torch.randn(1, 2, 16, 16),), + ], + }, + { + QCOM_MODULE: [ + ConvTranspose2dSingle( # noqa: F405 + in_channels=6, + out_channels=6, + kernel_size=3, + padding=0, + groups=2, + ) + ], + QCOM_SAMPLE_INPUTS: [ + (torch.randn(4, 6, 16, 16),), + ], + }, + ] + + index = 0 + for comb in test_comb: + for module in comb[QCOM_MODULE]: + for sample_input in comb[QCOM_SAMPLE_INPUTS]: + with self.subTest(i=index): + index += 1 + gm = 
self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(gm, sample_input) + + @unittest.skip("As of QNN 2.37, transpose conv block quant is not supported") + def test_qnn_backend_conv_transpose2d_block(self): + i_ch, o_ch, kernel, padding = 128, 32, (1, 1), 0 modules = [ - ConvTranspose2dSingle(), # noqa: F405 - ConvTranspose2dSingle(bias=False), # noqa: F405 - ConvTranspose2dSingle(dilation=(2, 3)), # noqa: F405 - ConvTranspose2dSingle(dilation=(2, 1)), # noqa: F405 + ConvTranspose2dSingle( # noqa: F405 + bias=False, + in_channels=i_ch, + out_channels=o_ch, + kernel_size=kernel, + padding=padding, + ), + ConvTranspose2dSingle( # noqa: F405 + in_channels=i_ch, + out_channels=o_ch, + kernel_size=kernel, + padding=padding, + ), ] - sample_input = (torch.randn([1, 1, 3, 3]),) + + sample_input = (torch.randn(1, 128, 16, 16),) + for i, module in enumerate(modules): + with self.subTest(i=i): + module = self.get_qdq_module( + module, + sample_input, + quant_dtype=QuantDtype.use_16a4w_block, + block_size_map={"conv_transpose2d": (16, 1, 1, 1)}, + ) + self.lower_module_and_test_output(module, sample_input) + + def test_qnn_backend_conv_transpose3d(self): + modules = [ + ConvTranspose3dSingle(), # noqa: F405 + ConvTranspose3dSingle(bias=False), # noqa: F405 + ConvTranspose3dSingle(dilation=2), # noqa: F405 + ConvTranspose3dSingle(dilation=(3, 2, 3)), # noqa: F405 + ] + sample_input = (torch.randn([1, 1, 3, 3, 3]),) for i, module in enumerate(modules): with self.subTest(i=i): module = self.get_qdq_module(module, sample_input) @@ -1856,16 +2733,34 @@ def test_qnn_backend_element_wise_add(self): QCOM_MODULE: [AddConstantFloat(), AddConstantLong()], # noqa: F405 QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)], }, - ] - - index = 0 - for comb in test_comb: - for module in comb[QCOM_MODULE]: - for sample_input in comb[QCOM_SAMPLE_INPUTS]: - with self.subTest(i=index): - gm = self.get_qdq_module(module, sample_input) - self.lower_module_and_test_output(gm, sample_input) + { + QCOM_MODULE: [ + AddAlpha(alpha=2), # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [ + ( + torch.tensor([[1.2, 1.3, 1.4]]), + torch.tensor([[0.8, 1.6, 0.2]]), + ) + ], + }, + { + QCOM_MODULE: [ + AddAlphaConstant(alpha=2, constant_first=True), # noqa: F405 + AddAlphaConstant(alpha=2, constant_first=False), # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [(torch.tensor([[1.2, 1.3, 1.4]]),)], + }, + ] + + index = 0 + for comb in test_comb: + for module in comb[QCOM_MODULE]: + for sample_input in comb[QCOM_SAMPLE_INPUTS]: + with self.subTest(i=index): index += 1 + gm = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(gm, sample_input) def test_qnn_backend_element_wise_and(self): module = And(torch.tensor(1.7), torch.tensor(0.2)) # noqa: F405 @@ -1904,9 +2799,9 @@ def test_qnn_backend_element_wise_div(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): + index += 1 gm = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(gm, sample_input) - index += 1 def test_qnn_backend_element_wise_mul(self): test_comb = [ @@ -1932,9 +2827,9 @@ def test_qnn_backend_element_wise_mul(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): + index += 1 gm = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(gm, sample_input) - index += 1 def test_qnn_backend_element_wise_or(self): test_comb = [ @@ -1985,6 +2880,24 @@ def 
test_qnn_backend_element_wise_sub(self): QCOM_MODULE: [SubConstantFloat(), SubConstantLong()], # noqa: F405 QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)], }, + { + QCOM_MODULE: [ + SubAlpha(alpha=2), # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [ + ( + torch.tensor([[1.2, 1.3, 1.4]]), + torch.tensor([[0.8, 1.6, 0.2]]), + ) + ], + }, + { + QCOM_MODULE: [ + SubAlphaConstant(alpha=2, constant_first=True), # noqa: F405 + SubAlphaConstant(alpha=2, constant_first=False), # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [(torch.tensor([[1.2, 1.3, 1.4]]),)], + }, ] index = 0 @@ -1992,9 +2905,9 @@ def test_qnn_backend_element_wise_sub(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): + index += 1 gm = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(gm, sample_input) - index += 1 def test_qnn_backend_elu(self): module = Elu() # noqa: F405 @@ -2037,17 +2950,13 @@ def test_qnn_backend_expand(self): (torch.randn([3, 1]),), (torch.randn([4]),), ] - passes_job = get_capture_program_passes() - passes_job[ExpandBroadcastTensorShape][QCOM_PASS_ACTIVATE_KEY] = True index = 0 for module in modules: for sample_input in sample_inputs: with self.subTest(i=index): - module = self.get_qdq_module(module, sample_input) - self.lower_module_and_test_output( - module, sample_input, passes_job=passes_job - ) index += 1 + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_expm1(self): sample_input = (torch.randn(3, 4, 5),) @@ -2073,6 +2982,21 @@ def test_qnn_backend_floor_divide(self): { QCOM_MODULE: [FloorDiv()], # noqa: F405 QCOM_SAMPLE_INPUTS: [ + (torch.randint(-100, 100, (10, 10)), torch.full((10, 10), 3)), + ( + torch.randint(-100, 100, (10, 10)).float(), + torch.full((10, 10), 2.5), + ), + (torch.randint(-1000, 1000, (10, 10)), torch.full((10, 10), 100)), + (torch.tensor([10]), torch.arange(1, 5)), + (torch.arange(-10, 10), torch.tensor([2])), + (torch.randint(-100, 100, (20,)), torch.full((20,), 2)), + (torch.randint(-100, 100, (5, 10)), torch.full((5, 10), 2)), + (torch.randint(-100, 100, (3, 4, 5)), torch.full((3, 4, 5), 2)), + ( + torch.randint(-100, 100, (2, 3, 4, 5)), + torch.full((2, 3, 4, 5), 2), + ), (torch.randn(2, 5, 1, 3), eps + torch.randn(2, 5, 1, 3)), (torch.randn([2, 5, 1, 3]), eps + torch.randn([4, 1])), ], @@ -2088,9 +3012,12 @@ def test_qnn_backend_floor_divide(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): - gm = self.get_qdq_module(module, sample_input) - self.lower_module_and_test_output(gm, sample_input) index += 1 + # Support int input cases with bypass_check=True + gm = self.get_qdq_module( + module, sample_input, bypass_check=True + ) + self.lower_module_and_test_output(gm, sample_input) def test_qnn_backend_fold(self): sample_input = (torch.randn(3, 512, 256),) @@ -2139,6 +3066,42 @@ def test_qnn_backend_gelu(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_grid_sampler(self): + modes = ["bilinear", "nearest"] + padding_modes = ["zeros", "border", "reflection"] + align_corners = [False, True] + grid_samples = [ + GridSample(mode, pad, align) # noqa: F405 + for mode, pad, align in itertools.product( + modes, padding_modes, align_corners + ) + ] + sample_inputs = [ + ( + torch.randn(1, 12, 14, 14), + torch.randn(1, 3, 3, 2), + ), # for grid_sampler 2d + ( + torch.randn(1, 15, 9, 
17, 33), + torch.randn(1, 7, 8, 9, 3), + ), # for grid_sampler 3d + ] + for j in range(len(sample_inputs)): + for i, module in enumerate(grid_samples): + with self.subTest(i=i, j=j, module=module): + module = self.get_qdq_module( + module, sample_inputs[j], quant_dtype=QuantDtype.use_16a16w + ) + self.lower_module_and_test_output(module, sample_inputs[j]) + + def test_qnn_backend_glu(self): + modules = [torch.nn.GLU(), torch.nn.GLU(dim=0)] + sample_input = (torch.randn(2, 5, 1, 4),) + for i, module in enumerate(modules): + with self.subTest(i=i): + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_greater_equal(self): test_comb = [ { @@ -2278,32 +3241,197 @@ def test_qnn_backend_index_copy(self): ) def test_qnn_backend_index_put(self): - test_comb = [ - { - QCOM_MODULE: IndexPut(skip_mutable_buffer=False), # noqa: F405 - QCOM_SAMPLE_INPUTS: ( - torch.tensor([2], dtype=torch.int32), - torch.randn([1, 1, 12, 64]), + skip_mutable_buffer = [False, True] + total_test_combo = [] + # mode 0 + sample_inputs = [ + (torch.tensor([0], dtype=torch.int32), torch.randn([1, 1, 12, 64])), + (torch.tensor([0], dtype=torch.int32), torch.randn([1, 64])), + (torch.tensor([0, 1], dtype=torch.int32), torch.randn([2, 1, 12, 64])), + (torch.tensor([0, 1], dtype=torch.int32), torch.randn([1, 64])), + ] + total_test_combo.append( + list(itertools.product(skip_mutable_buffer, sample_inputs)) + ) + # mode 1 + sample_inputs = [ + (torch.tensor([2], dtype=torch.int32), torch.randn([1, 1, 12, 64])), + (torch.tensor([2], dtype=torch.int32), torch.randn([1, 64])), + (torch.tensor([2, 3], dtype=torch.int32), torch.randn([1, 2, 12, 64])), + (torch.tensor([2, 3], dtype=torch.int32), torch.randn([1, 64])), + ] + total_test_combo.append( + list(itertools.product(skip_mutable_buffer, sample_inputs)) + ) + # mode 2 + sample_inputs = [ + (torch.tensor([2], dtype=torch.int32), torch.randn([1, 1, 1, 64])), + (torch.tensor([2], dtype=torch.int32), torch.randn([1, 64])), + (torch.tensor([0, 1], dtype=torch.int32), torch.randn([1, 1, 2, 64])), + (torch.tensor([2, 3], dtype=torch.int32), torch.randn([1, 64])), + ] + total_test_combo.append( + list(itertools.product(skip_mutable_buffer, sample_inputs)) + ) + # mode 3 + sample_inputs = [ + ( + ( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([2, 3], dtype=torch.int32), ), - }, - { - QCOM_MODULE: IndexPut(skip_mutable_buffer=True), # noqa: F405 - QCOM_SAMPLE_INPUTS: ( - torch.tensor([2], dtype=torch.int32), - torch.randn([1, 1, 12, 64]), + torch.randn([2, 12, 64]), + ), + ( + ( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([2, 3], dtype=torch.int32), ), - }, + torch.randn([1, 64]), + ), ] - for i, test in enumerate(test_comb): + total_test_combo.append( + list(itertools.product(skip_mutable_buffer, sample_inputs)) + ) + # mode 4 + sample_inputs = [ + ( + ( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([2, 3], dtype=torch.int32), + ), + torch.randn([2, 64]), + ), + ( + ( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([2, 3], dtype=torch.int32), + ), + torch.randn([1, 64]), + ), + ] + total_test_combo.append( + list(itertools.product(skip_mutable_buffer, sample_inputs)) + ) + # mode 5 + sample_inputs = [ + ( + ( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([2, 3], dtype=torch.int32), + ), + torch.randn([64]), + ), + ( + ( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([2, 3], dtype=torch.int32), + ), + torch.randn([1]), + ), + ] + 
total_test_combo.append( + list(itertools.product(skip_mutable_buffer, sample_inputs)) + ) + + for i, test_combo in enumerate(total_test_combo): + for j, combo in enumerate(test_combo): + with self.subTest(f"mode_{i}-{j}"): + module = self.get_qdq_module( + IndexPut(skip_mutable_buffer=combo[0], mode=i), # noqa: F405 + combo[1], + ) + self.lower_module_and_test_output( + module, + combo[1], + skip_mutable_buffer=combo[0], + ) + + def test_qnn_backend_index_put_suite(self): + accumulate = [False, True] + in_place = [False, True] + sample_inputs = [ + # basic + ( + torch.rand(5, 2) * 100, + (torch.tensor([0, 2]),), + torch.tensor([10.0, 20.0]), + ), + (torch.rand(5, 2), (torch.tensor([0, 2]),), torch.tensor([10.0, 20.0])), + # shape + (torch.rand(5), (torch.tensor([0, 2]),), torch.tensor([10.0, 20.0])), + ( + torch.rand(5, 2), + (torch.tensor([0, 2]), torch.tensor([1, 1])), + torch.tensor([10.0, 20.0]), + ), + ( + torch.rand(5, 3, 2), + (torch.tensor([0, 2]), torch.tensor([1, 1]), torch.tensor([0, 1])), + torch.tensor([10.0, 20.0]), + ), + # TODO: not supported by HTP + # ( + # torch.rand(5, 3, 2, 4), + # (torch.tensor([0, 2]), torch.tensor([1, 1]), torch.tensor([0, 1]), torch.tensor([2, 3])), + # torch.tensor([10.0]), + # ), + # indices + (torch.rand(5, 2), (torch.tensor([2]),), torch.tensor([10.0])), + ( + torch.rand(5, 3), + (torch.tensor([0, 2, 4]),), + torch.tensor([10.0, 20.0, 30.0]), + ), + ( + torch.rand(5), + (torch.tensor([1, 1, 3, 3]),), + torch.tensor([10.0, 20.0, 30.0, 40.0]), + ), + # broadcasting + (torch.rand(5, 3), (torch.tensor([0, 2, 4]),), torch.tensor([42.0])), + ( + torch.rand(3, 4), + (torch.tensor([0, 1]), torch.tensor([1, 2])), + torch.tensor([10.0, 20.0]), + ), + (torch.rand(4, 2), (torch.tensor([0, 2]),), torch.tensor([5.0, 15.0])), + ( + torch.rand(3, 2, 2), + (torch.tensor([0, 1]),), + torch.tensor([[1.0, 2.0], [3.0, 4.0]]), + ), + (torch.rand(4, 2), (torch.tensor([1, 1, 1]),), torch.tensor([5.0])), + # two-index + ( + torch.rand(4, 3), + (torch.tensor([0, 1, 2]), torch.tensor([1, 0, 2])), + torch.tensor([10.0, 20.0, 30.0]), + ), + ( + torch.rand(3, 3), + (torch.tensor([0, 2]), torch.tensor([1, 1])), + torch.tensor([15.0, 25.0]), + ), + ( + torch.rand(3, 2), + (torch.tensor([1, 1, 2]), torch.tensor([0, 0, 1])), + torch.tensor([5.0, 10.0, 15.0]), + ), + ( + torch.rand(3, 2), + (torch.tensor([1]), torch.tensor([0, 0, 1])), + torch.tensor([5.0, 10.0, 15.0]), + ), + ] + test_combo = list(itertools.product(accumulate, in_place, sample_inputs)) + for i, combo in enumerate(test_combo): with self.subTest(i=i): module = self.get_qdq_module( - test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS] - ) - self.lower_module_and_test_output( - module, - test[QCOM_SAMPLE_INPUTS], - skip_mutable_buffer=test[QCOM_MODULE].skip_mutable_buffer, + IndexPutSuite(accumulate=combo[0], in_place=combo[1]), # noqa: F405 + combo[2], ) + self.lower_module_and_test_output(module, combo[2]) def test_qnn_backend_index_select(self): module = IndexSelect(dim=1) # noqa: F405 @@ -2388,9 +3516,9 @@ def test_qnn_backend_leaky_relu(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): + index += 1 module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) - index += 1 def test_qnn_backend_less_equal(self): test_comb = [ @@ -2442,10 +3570,15 @@ def test_qnn_backend_linalg_vector_norm(self): self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_linear(self): - module = Linear() # noqa: 
F405 + modules = [ + Linear(), # noqa: F405 + LinearNonConstantWeight(), # noqa: F405 + ] sample_input = (torch.randn([3, 512]),) - module = self.get_qdq_module(module, sample_input) - self.lower_module_and_test_output(module, sample_input) + for i, module in enumerate(modules): + with self.subTest(i=i): + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) @unittest.skipIf(is_qnn_sdk_version_less_than("2.30"), "UT pass after QNN 2.30") def test_qnn_backend_linear_block(self): @@ -2484,6 +3617,14 @@ def test_qnn_backend_log(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_logical_and(self): + module = LogicalAnd() # noqa: F405 + input1 = torch.tensor([0.0]) + input2 = torch.tensor([1.0]) + sample_input = (input1, input2) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_logical_not(self): module = LogicalNot() # noqa: F405 sample_input = (torch.rand([1, 2, 3, 4]),) @@ -2514,14 +3655,82 @@ def test_qnn_backend_max_pool2d(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) - def test_qnn_backend_mean_dim(self): - modules = [MeanWKeppDim(), MeanWOKeppDim()] # noqa: F405 - sample_input = (torch.randn([2, 5, 1, 3]),) + def test_qnn_backend_max_pool3d(self): + # NOTE: The pad should be at most half of effective kernel size. + modules = [ + MaxPool3d((3), (1), (1), (1), False, False), # noqa: F405 + MaxPool3d((7), (1), (3), (1), False, False), # noqa: F405 + MaxPool3d((7), (1), (3), (1), True, False), # noqa: F405 + MaxPool3d( # noqa: F405 + (7, 7, 7), (1, 1, 1), (3, 3, 3), (1, 1, 1), True, False + ), # noqa: F405 + MaxPool3d( # noqa: F405 + (7, 9, 13), (1, 1, 1), (3, 4, 6), (1, 1, 1), False, False + ), # noqa: F405 + ] + sample_input = (torch.randn(1, 7, 21, 35, 28),) for i, module in enumerate(modules): with self.subTest(i=i): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_mean(self): + test_comb = [ + # Reduce over last two dims, keepdim=True + { + QCOM_MODULE: Mean(dim=(-1, -2), keepdim=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),), + }, + # Reduce over last two dims, keepdim=False + { + QCOM_MODULE: Mean(dim=(-1, -2), keepdim=False), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn([2, 5, 1, 3]),), + }, + # Default: reduce all dims + { + QCOM_MODULE: Mean(), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(10, 10),), + }, + # TODO: To be enabled via reshape input to 1d tensor + # Scalar case + # { + # QCOM_MODULE: Mean(), + # QCOM_SAMPLE_INPUTS: (torch.tensor(5.0),), + # }, + # Edge case: dim is a empty list + { + QCOM_MODULE: Mean(dim=[]), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),), + }, + # Edge case: reduce along dim=0 (batch dimension) + { + QCOM_MODULE: Mean(dim=0), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),), + }, + # Edge case: reduce along dim=0 with keepdim=True + { + QCOM_MODULE: Mean(dim=0, keepdim=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(4, 6, 8),), + }, + # Edge case: reduce along multiple dims + { + QCOM_MODULE: Mean(dim=(0, 2)), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(3, 4, 5),), + }, + # Edge case: high-dimensional tensor + { + QCOM_MODULE: Mean(dim=(1, 3), keepdim=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4, 5, 
6),), + }, + ] + + for i, test in enumerate(test_comb): + with self.subTest(i=i): + module = self.get_qdq_module( + test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS] + ) + self.lower_module_and_test_output(module, test[QCOM_SAMPLE_INPUTS]) + def test_qnn_backend_mha(self): module = MultiheadAttention() # noqa: F405 sample_input = (torch.randn(1, 197, 96),) @@ -2570,6 +3779,17 @@ def test_qnn_backend_pad(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_permute(self): + modules = [ + Permute([0, 2, 3, 1]), # noqa: F405 + Permute([-1, -3, -2, -4]), # noqa: F405 + ] + sample_input = (torch.randn([2, 3, 4, 5]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_pixel_shuffle(self): module = PixelShuffle(2) # noqa: F405 sample_input = (torch.ones([2, 4, 3, 3]),) @@ -2583,10 +3803,29 @@ def test_qnn_backend_pixel_unshuffle(self): self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_pow_tensor_scalar(self): - module = PowTensorScalar() # noqa: F405 - sample_input = (torch.rand([2, 4, 3, 3]),) - module = self.get_qdq_module(module, sample_input) - self.lower_module_and_test_output(module, sample_input) + test_comb = [ + { + QCOM_MODULE: [ + PowTensorScalar(), # noqa: F405 + PowTensorScalar(1), # noqa: F405 + PowTensorScalar(-1), # noqa: F405 + PowTensorScalar(0.5), # noqa: F405 + ], # noqa: F405 + QCOM_SAMPLE_INPUTS: [(torch.rand(10, 10) + 0.1,)], + }, + { + QCOM_MODULE: [PowTensorScalar(10)], # noqa: F405 + QCOM_SAMPLE_INPUTS: [(torch.rand(10, 10) * 0.5 + 0.5,)], + }, + ] + index = 0 + for comb in test_comb: + for module in comb[QCOM_MODULE]: + for sample_input in comb[QCOM_SAMPLE_INPUTS]: + with self.subTest(i=index): + index += 1 + qdq_module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(qdq_module, sample_input) def test_qnn_backend_prelu(self): test_comb = [ @@ -2605,9 +3844,9 @@ def test_qnn_backend_prelu(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): + index += 1 module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) - index += 1 def test_qnn_backend_relu(self): module = Relu() # noqa: F405 @@ -2634,12 +3873,17 @@ def test_qnn_backend_reshape(self): self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_rms_norm(self): - module = RmsNorm() # noqa: F405 - sample_input = (torch.abs(torch.randn([1, 1, 1, 4])),) - module = self.get_qdq_module( - module, sample_input, quant_dtype=QuantDtype.use_16a4w - ) - self.lower_module_and_test_output(module, sample_input) + modules = [ + RmsNorm(), # noqa: F405 + RmsNorm(eps=1e-5), # noqa: F405 + ] + sample_input = (torch.randn([1, 1, 1, 4]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + module = self.get_qdq_module( + module, sample_input, quant_dtype=QuantDtype.use_16a4w + ) + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_roll(self): modules = [ @@ -2745,10 +3989,8 @@ def test_qnn_backend_slice_scatter(self): ], QCOM_SAMPLE_INPUTS: [ ( - ( - torch.zeros(8, 8), - torch.ones(8, 2), - ) + torch.zeros(8, 8), + torch.ones(8, 2), ) ], }, @@ -2759,9 +4001,9 @@ def test_qnn_backend_slice_scatter(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with 
self.subTest(i=index): + index += 1 module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) - index += 1 def test_qnn_backend_softmax(self): modules = [Softmax(dim=1), Softmax(dim=-1)] # noqa: F405 @@ -2799,12 +4041,68 @@ def test_qnn_backend_sum_int_list(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_swapaxes(self): + module = SwapAxes(0, 1) # noqa: F405 + sample_input = (torch.randn([1, 2, 3, 4]),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_tanh(self): module = Tanh() # noqa: F405 sample_input = (torch.randn(2, 5, 1, 3),) module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_threshold(self): + modules = [ + Threshold(), # noqa: F405 + Threshold(threshold=0.5, value=3.0, inplace=True), # noqa: F405 + Threshold(threshold=0.5, value=3.0, inplace=False), # noqa: F405 + ] + sample_input = (torch.randn(2, 5, 1, 3),) + for i, module in enumerate(modules): + with self.subTest(i=i): + qdq_module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(qdq_module, sample_input) + + def test_qnn_backend_triu(self): + test_comb = [ + { + QCOM_MODULE: [ + Triu(), # noqa: F405 + Triu(diagonal=1), # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [ + (torch.randn(3, 3),), + (torch.randn(1, 2, 3, 3),), + ], + }, + { + QCOM_MODULE: [ + TriuConstant(1), # noqa: F405 + TriuConstant(1, constant_dtype=torch.bool), # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [ + (torch.zeros((5, 5)),), + ], + }, + ] + + index = 0 + for comb in test_comb: + for module in comb[QCOM_MODULE]: + for sample_input in comb[QCOM_SAMPLE_INPUTS]: + with self.subTest(i=index): + index += 1 + qdq_module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(qdq_module, sample_input) + + def test_qnn_backend_unflatten(self): + module = Unflatten(dim=1, sizes=(2, 3, 4)) # noqa: F405 + sample_input = (torch.randn([1, 24]),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_unbind(self): module = Unbind() # noqa: F405 sample_input = (torch.randn([3, 3]),) @@ -2928,7 +4226,52 @@ def test_qnn_backend_chunk_add(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) - def test_qnn_backend_conv1d_relu_log_softmax(self): + def test_qnn_backend_conformer(self): + from typing import Tuple + + import torchaudio + + class PatchedConformer(torch.nn.Module): + """ + A lightly modified version of the top-level Conformer module, such that it can be exported. + Instead of taking lengths and computing the padding mask, it takes the padding mask directly. 
+ See https://github.com/pytorch/audio/blob/main/src/torchaudio/models/conformer.py#L215 + """ + + def __init__(self, conformer): + super().__init__() + self.conformer = conformer + + def forward( + self, input: torch.Tensor, encoder_padding_mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = input.transpose(0, 1) + for layer in self.conformer.conformer_layers: + x = layer(x, encoder_padding_mask) + return x.transpose(0, 1) + + inner_model = torchaudio.models.Conformer( + input_dim=80, + num_heads=4, + ffn_dim=128, + num_layers=4, + depthwise_conv_kernel_size=31, + ) + lengths = torch.randint(1, 400, (10,)) + encoder_padding_mask = torchaudio.models.conformer._lengths_to_padding_mask( + lengths + ) + sample_input = ( + torch.rand(10, int(lengths.max()), 80), + encoder_padding_mask.to(torch.float32), + ) + module = PatchedConformer(inner_model).eval() + module = self.get_qdq_module( + module, sample_input, quant_dtype=QuantDtype.use_16a8w + ) + self.lower_module_and_test_output(module, sample_input) + + def test_qnn_backend_conv1d_relu_log_softmax(self): modules = [ Conv1dReluLogSoftmax(dim=1), # noqa: F405 Conv1dReluLogSoftmax(dim=-1), # noqa: F405 @@ -3227,20 +4570,38 @@ def setUp(self): saver=False, ) - def test_qnn_backend_dump_intermediate_outputs(self): + def test_qnn_backend_dump_intermediate_outputs_topk(self): backend_options = generate_htp_compiler_spec(use_fp16=True) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, dump_intermediate_outputs=True, ) - module = Relu() # noqa: F405 - sample_input = (torch.randn([2, 5, 1, 3]),) + module = TopKandIndex() # noqa: F405 + sample_input = (torch.randn(3, 10),) + self.lower_module_and_test_output( + module, + sample_input, + expected_partitions=1, + expected_intermediate_events=7, + expected_compared_events=5, + ) + + def test_qnn_backend_dump_intermediate_outputs_simple_model(self): + backend_options = generate_htp_compiler_spec(use_fp16=True) + TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=self.chipset_table[TestQNN.model], + backend_options=backend_options, + dump_intermediate_outputs=True, + ) + module = SimpleModel() # noqa: F405 + sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) self.lower_module_and_test_output( module, sample_input, expected_partitions=1, - expected_intermediate_events=3, + expected_intermediate_events=20, + expected_compared_events=16, ) def test_qnn_backend_skip_node_id(self): @@ -3263,9 +4624,7 @@ def test_qnn_backend_skip_node_op(self): skip_node_op_set={"aten.add.Tensor"}, ) - @unittest.expectedFailure def test_qnn_backend_spill_fill_buffer_size(self): - # TODO: Fix self.assertNotEqual(0, max_sf_size) module = LargeTensorLinear() # noqa: F405 sample_input = (torch.randn(1, 256, 512),) backend_options = generate_htp_compiler_spec( @@ -3351,10 +4710,8 @@ def test_qnn_backend_multi_graphs(self): generate_qnn_executorch_compiler_spec( soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, - graph_name=graph_name, ) - for graph_name in graph_names - ] + ] * len(graph_names) modules_dict = {} sample_inputs_dict = {} @@ -3559,11 +4916,7 @@ def test_qnn_backend_context_extraction(self): lowered_module = edge_prog_mgr.exported_program().graph_module._modules[ "lowered_module_0" ] - qnn_mgr = PyQnnManagerAdaptor.QnnManager( - lowered_module.compile_specs[0].value - ) - qnn_mgr.Init() - binary = 
qnn_mgr.StripProtocol(lowered_module.processed_bytes) + binary = PyQnnManagerAdaptor.StripProtocol(lowered_module.processed_bytes) validate(binary) def test_qnn_backend_dump_context_from_pte(self): @@ -3803,21 +5156,40 @@ def setUp(self): saver=False, ) - def test_qnn_backend_dump_intermediate_outputs(self): + def test_qnn_backend_dump_intermediate_outputs_simple_model(self): backend_options = generate_htp_compiler_spec(use_fp16=False) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, dump_intermediate_outputs=True, ) - module = Relu() # noqa: F405 - sample_input = (torch.randn([2, 5, 1, 3]),) + module = SimpleModel() # noqa: F405 + sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output( + module, + sample_input, + expected_partitions=1, + expected_intermediate_events=21, + expected_compared_events=14, + ) + + def test_qnn_backend_dump_intermediate_outputs_topk(self): + backend_options = generate_htp_compiler_spec(use_fp16=False) + TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=self.chipset_table[TestQNN.model], + backend_options=backend_options, + dump_intermediate_outputs=True, + ) + module = TopKandIndex() # noqa: F405 + sample_input = (torch.randn(3, 10),) module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output( module, sample_input, expected_partitions=1, - expected_intermediate_events=5, + expected_intermediate_events=8, + expected_compared_events=5, ) def test_qnn_backend_dynamic_shape(self): @@ -4168,10 +5540,8 @@ def test_qnn_backend_multi_graphs(self): generate_qnn_executorch_compiler_spec( soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, - graph_name=graph_name, ) - for graph_name in graph_names - ] + ] * len(graph_names) modules_dict = {} sample_inputs_dict = {} compiler_specs_dict = {} @@ -4386,11 +5756,7 @@ def test_qnn_backend_context_extraction(self): lowered_module = edge_prog_mgr.exported_program().graph_module._modules[ "lowered_module_0" ] - qnn_mgr = PyQnnManagerAdaptor.QnnManager( - lowered_module.compile_specs[0].value - ) - qnn_mgr.Init() - binary = qnn_mgr.StripProtocol(lowered_module.processed_bytes) + binary = PyQnnManagerAdaptor.StripProtocol(lowered_module.processed_bytes) validate(binary) def test_qnn_backend_dump_context_from_pte(self): @@ -4665,85 +6031,74 @@ def test_qnn_backend_seq_mse(self): class TestExampleLLMScript(TestQNN): - def test_static_gemma3_1b(self): - if not self.required_envs(): - self.skipTest("missing required envs") - prompt = "My favourite condiment is " - cmds = [ - "python", - f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py", - "--artifact", - self.artifact_dir, - "--build_folder", - self.build_folder, - "--model", - self.model, - "--ip", - self.ip, - "--port", - str(self.port), - "--prompt", - f"{prompt}", - "--ptq", - "16a4w_block", - "--temperature", - "0", - "--decoder_model", - "gemma3-1b", - "--model_mode", - "kv", - "--max_seq_len", - "1024", - "--eval_perplexity", - "--tasks", - "wikitext", - "--limit", - "1", - "--enable_masked_softmax", - ] - if self.compile_only: - cmds.extend(["--compile_only"]) - elif self.device: - cmds.extend(["--device", self.device]) - if self.host: - cmds.extend(["--host", self.host]) - elif self.enable_x86_64: - cmds.extend(["--enable_x86_64"]) - if self.pre_gen_pte: - cmds.extend(["--pre_gen_pte", 
self.pre_gen_pte]) + @dataclass(frozen=True) + class LlmSpecs: + SM8650: float + SM8750: float + ppl: float + pte_size: float - p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) - with Listener((self.ip, self.port)) as listener: - conn = listener.accept() - p.communicate() - msg = json.loads(conn.recv()) - if "Error" in msg: - self.fail(msg["Error"]) - else: - if not self.compile_only: - self.assertLessEqual(msg["wiki_ppl"], 23) - if not self.enable_x86_64: - pte_size = msg["pte_size"] - self.assertLessEqual(pte_size, 1_200_000_000) # 1.2GB - inference_speed_ref = {"SM8650": 70, "SM8750": 100} - if ( - not self.compile_only - and not self.enable_x86_64 - and self.model in inference_speed_ref - ): - self.assertGreaterEqual( - msg["inference_speed"], inference_speed_ref[self.model] - ) + # TODO: refactor to support different backends + def setUp(self): + self.llm_specs = { + "gemma-2b": TestExampleLLMScript.LlmSpecs( + SM8650=32, SM8750=36, ppl=35, pte_size=2_700_000_000 + ), # 2.7 GB + "gemma3-1b": TestExampleLLMScript.LlmSpecs( + SM8650=70, SM8750=100, ppl=23, pte_size=1_200_000_000 + ), # 1.2 GB + "glm-1_5b": TestExampleLLMScript.LlmSpecs( + SM8650=42, SM8750=52, ppl=21, pte_size=1_100_000_000 + ), # 1.1 GB + "phi_4_mini": TestExampleLLMScript.LlmSpecs( + SM8650=14, SM8750=19, ppl=12, pte_size=4_000_000_000 + ), # 4GB + "llama3_2-1b_instruct": TestExampleLLMScript.LlmSpecs( + SM8650=37, SM8750=45, ppl=16, pte_size=1_500_000_000 + ), # 1.5 GB + "llama3_2-3b_instruct": TestExampleLLMScript.LlmSpecs( + SM8650=21, SM8750=26, ppl=11, pte_size=2_800_000_000 + ), # 2.8 GB + "qwen2_5-0_5b": TestExampleLLMScript.LlmSpecs( + SM8650=115, SM8750=155, ppl=15, pte_size=600_000_000 + ), # 600 MB + "qwen2_5-1_5b": TestExampleLLMScript.LlmSpecs( + SM8650=38, SM8750=47, ppl=10, pte_size=1_500_000_000 + ), # 1.5 GB + "qwen3-0_6b": TestExampleLLMScript.LlmSpecs( + SM8650=47, SM8750=68, ppl=21, pte_size=700_000_000 + ), # 700 MB + "qwen3-1_7b": TestExampleLLMScript.LlmSpecs( + SM8650=28, SM8750=34, ppl=15, pte_size=1_800_000_000 + ), # 1.8 GB + "smollm2_135m": TestExampleLLMScript.LlmSpecs( + SM8650=214, SM8750=260, ppl=23, pte_size=210_000_000 + ), # 210 MB + "smollm3-3b": TestExampleLLMScript.LlmSpecs( + SM8650=23, SM8750=28, ppl=10, pte_size=2_600_000_000 + ), # 2.6 GB + } - def test_llama3_2_instruct(self): - if not self.required_envs(): + def test_static_llm_model(self): + if not self.required_envs([self.model_name]): self.skipTest("missing required envs") assert ( - self.llama_artifacts is not None - ), "Please provide path to llama artifacts" + self.model_name in self.llm_specs + ), f"Unable to find {self.model_name} under model_specs." - prompt = "What is the meaning of life?" + is_llama_model = self.model_name in { + "llama3_2-1b_instruct", + "llama3_2-3b_instruct", + } + if is_llama_model: + assert ( + self.llama_artifacts is not None + ), "Please provide path to llama artifacts" + + prompt = ( + "I would like to learn python, could you teach me with a simple example?" 
+ ) cmds = [ "python", f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py", @@ -4753,12 +6108,6 @@ def test_llama3_2_instruct(self): self.build_folder, "--model", self.model, - "--checkpoint", - f"{self.llama_artifacts}/consolidated.00.pth", - "--params", - f"{self.llama_artifacts}/params.json", - "--tokenizer_model", - f"{self.llama_artifacts}/tokenizer.model", "--ip", self.ip, "--port", @@ -4768,90 +6117,30 @@ def test_llama3_2_instruct(self): "--temperature", "0", "--decoder_model", - "llama3_2-1b_instruct", + self.model_name, "--model_mode", "kv", "--max_seq_len", "1024", - "--eval_perplexity", + "--run_lm_eval", "--tasks", "wikitext", "--limit", "1", ] - if self.compile_only: - cmds.extend(["--compile_only"]) - elif self.device: - cmds.extend(["--device", self.device]) - if self.host: - cmds.extend(["--host", self.host]) - elif self.enable_x86_64: - cmds.extend(["--enable_x86_64"]) - if self.pre_gen_pte: - cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) - p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) - with Listener((self.ip, self.port)) as listener: - conn = listener.accept() - p.communicate() - msg = json.loads(conn.recv()) - if "Error" in msg: - self.fail(msg["Error"]) - else: - inference_speed_ref = {"SM8650": 37, "SM8750": 49} - if ( - not self.compile_only - and not self.enable_x86_64 - and self.model in inference_speed_ref - ): - self.assertLessEqual(msg["pte_size"], 1_500_000_000) - self.assertLessEqual(msg["wiki_ppl"], 15) - self.assertGreaterEqual( - msg["inference_speed"], inference_speed_ref[self.model] - ) - - def test_llama_stories_260k(self): - if not self.required_envs(): - self.skipTest("missing required envs") - assert ( - self.llama_artifacts is not None - ), "Please provide path to llama artifacts" + if is_llama_model: + cmds.extend( + [ + "--checkpoint", + f"{self.llama_artifacts}/consolidated.00.pth", + "--params", + f"{self.llama_artifacts}/params.json", + "--tokenizer_model", + f"{self.llama_artifacts}/tokenizer.model", + ] + ) - prompt = "Once" - cmds = [ - "python", - f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py", - "--artifact", - self.artifact_dir, - "--build_folder", - self.build_folder, - "--model", - self.model, - "--checkpoint", - f"{self.llama_artifacts}/stories260K.pt", - "--params", - f"{self.llama_artifacts}/params.json", - "--tokenizer_model", - f"{self.llama_artifacts}/tokenizer.model", - "--tokenizer_bin", - f"{self.llama_artifacts}/tokenizer.bin", - "--ip", - self.ip, - "--port", - str(self.port), - "--prompt", - f"{prompt}", - "--temperature", - "0", - "--decoder_model", - "stories260k", - "--model_mode", - "hybrid", - "--prefill_ar_len", - "32", - "--max_seq_len", - "128", - ] if self.compile_only: cmds.extend(["--compile_only"]) elif self.device: @@ -4863,7 +6152,6 @@ def test_llama_stories_260k(self): if self.pre_gen_pte: cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) - golden_start_with = "Once upon a time," p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) with Listener((self.ip, self.port)) as listener: conn = listener.accept() @@ -4872,28 +6160,30 @@ def test_llama_stories_260k(self): if "Error" in msg: self.fail(msg["Error"]) else: + llm_spec = self.llm_specs[self.model_name] + pte_size = msg["pte_size"] + self.assertLessEqual(pte_size, llm_spec.pte_size) + print(f"Model Name: {self.model_name}\nTarget Device: {self.model}") + print(f"PTE Size: {pte_size} bytes") if not self.compile_only: - model_out = msg["result"][0] - print(f"Model CI result:{model_out[: 
len(golden_start_with)]}") - self.assertTrue( - model_out.startswith(golden_start_with), - f"Expected Output: {golden_start_with}. Actual Output: {model_out}", - ) - # x86 does not allow weight sharing, so we don't check pte size - if not self.enable_x86_64: - pte_size = msg["pte_size"] - self.assertLessEqual(pte_size, 2_020_000) # 2MB - if not self.compile_only and not self.enable_x86_64: - self.assertGreaterEqual(msg["inference_speed"], 1600) # Lanai + ppl = msg["wiki_ppl"] + print(f"PPL: {ppl}") + self.assertLessEqual(ppl, llm_spec.ppl) + if not self.enable_x86_64 and hasattr(llm_spec, self.model): + device_inference_speed = msg["inference_speed"] + expected_inference_speed = getattr(llm_spec, self.model) + print( + f"Prompt Evaluation: {device_inference_speed} tokens/second" + ) + self.assertGreaterEqual( + device_inference_speed, expected_inference_speed + ) - def test_llama_stories_110m(self): + def test_codegen2_1b(self): if not self.required_envs(): self.skipTest("missing required envs") - assert ( - self.llama_artifacts is not None - ), "Please provide path to llama artifacts" - prompt = "Once" + prompt = "def hello_world():" cmds = [ "python", f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py", @@ -4903,28 +6193,18 @@ def test_llama_stories_110m(self): self.build_folder, "--model", self.model, - "--checkpoint", - f"{self.llama_artifacts}/stories110M.pt", - "--params", - f"{self.llama_artifacts}/params.json", - "--tokenizer_model", - f"{self.llama_artifacts}/tokenizer.model", - "--tokenizer_bin", - f"{self.llama_artifacts}/tokenizer.bin", "--ip", self.ip, "--port", str(self.port), "--prompt", - f"{prompt}", + prompt, "--temperature", "0", "--decoder_model", - "stories110m", + "codegen2_1b", "--model_mode", - "hybrid", - "--prefill_ar_len", - "32", + "kv", "--max_seq_len", "128", ] @@ -4939,7 +6219,7 @@ def test_llama_stories_110m(self): if self.pre_gen_pte: cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) - golden_start_with = "Once upon a time," + golden_start_with = "def hello_world():" p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) with Listener((self.ip, self.port)) as listener: conn = listener.accept() @@ -4954,77 +6234,17 @@ def test_llama_stories_110m(self): model_out.startswith(golden_start_with), f"Expected Output: {golden_start_with}. 
Actual Output: {model_out}", ) - # x86 does not allow weight sharing, so we don't check pte size if not self.enable_x86_64: pte_size = msg["pte_size"] - self.assertLessEqual(pte_size, 130_000_000) # 130MB + self.assertLessEqual(pte_size, 1_200_000_000) # 1200MB if not self.compile_only and not self.enable_x86_64: - self.assertGreaterEqual(msg["inference_speed"], 220) # Lanai - - def test_static_phi4(self): - if not self.required_envs(): - self.skipTest("missing required envs") - - prompt = "My favourite condiment is " - cmds = [ - "python", - f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py", - "--artifact", - self.artifact_dir, - "--build_folder", - self.build_folder, - "--model", - self.model, - "--ip", - self.ip, - "--port", - str(self.port), - "--prompt", - f"{prompt}", - "--decoder_model", - "phi_4_mini", - "--model_mode", - "kv", - "--max_seq_len", - "1024", - "--eval_perplexity", - "--tasks", - "wikitext", - "--limit", - "1", - ] - if self.compile_only: - cmds.extend(["--compile_only"]) - elif self.device: - cmds.extend(["--device", self.device]) - if self.host: - cmds.extend(["--host", self.host]) - elif self.enable_x86_64: - cmds.extend(["--enable_x86_64"]) - if self.pre_gen_pte: - cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) + self.assertGreaterEqual(msg["inference_speed"], 60) - p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) - with Listener((self.ip, self.port)) as listener: - conn = listener.accept() - p.communicate() - msg = json.loads(conn.recv()) - if "Error" in msg: - self.fail(msg["Error"]) - else: - inference_speed_ref = {"SM8650": 14, "SM8750": 19} - self.assertLessEqual(msg["wiki_ppl"], 12) - self.assertLessEqual(msg["pte_size"], 4_000_000_000) # 4GB - if self.model in inference_speed_ref: - self.assertGreaterEqual( - msg["inference_speed"], inference_speed_ref[self.model] - ) - - def test_static_qwen2_5(self): + def test_granite_3_3_2b_instruct(self): if not self.required_envs(): self.skipTest("missing required envs") - prompt = "My favourite condiment is " + prompt = "What is the meaning of life?" 
cmds = [ "python", f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py", @@ -5040,17 +6260,21 @@ def test_static_qwen2_5(self): str(self.port), "--prompt", f"{prompt}", + "--temperature", + "0", "--decoder_model", - "qwen2_5-0_5b", + "granite_3_3-2b_instruct", "--model_mode", "kv", "--max_seq_len", "1024", - "--eval_perplexity", + "--run_lm_eval", "--tasks", - "wikitext", + "hellaswag", "--limit", - "1", + "10", + "--kv_updater", + "shift_pointer", ] if self.compile_only: cmds.extend(["--compile_only"]) @@ -5071,99 +6295,59 @@ def test_static_qwen2_5(self): if "Error" in msg: self.fail(msg["Error"]) else: - inference_speed_ref = {"SM8650": 115, "SM8750": 155} - self.assertLessEqual(msg["wiki_ppl"], 15) - self.assertLessEqual(msg["pte_size"], 600_000_000) # 600MB - if self.model in inference_speed_ref: + inference_speed_ref = {"SM8650": 20, "SM8750": 22} + if ( + not self.compile_only + and not self.enable_x86_64 + and self.model in inference_speed_ref + ): + self.assertLessEqual(msg["pte_size"], 1_600_000_000) + self.assertGreaterEqual(msg["acc_norm"], 0.2) self.assertGreaterEqual( msg["inference_speed"], inference_speed_ref[self.model] ) - def test_static_qwen3(self): + def test_llama_stories_260k(self): if not self.required_envs(): - self.skipTest("missing required envs") - - prompt = "My favourite condiment is " - cmds = [ - "python", - f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py", - "--artifact", - self.artifact_dir, - "--build_folder", - self.build_folder, - "--model", - self.model, - "--ip", - self.ip, - "--port", - str(self.port), - "--prompt", - f"{prompt}", - "--decoder_model", - "qwen3-0_6b", - "--model_mode", - "kv", - "--max_seq_len", - "1024", - "--eval_perplexity", - "--tasks", - "wikitext", - "--limit", - "1", - ] - if self.compile_only: - cmds.extend(["--compile_only"]) - elif self.device: - cmds.extend(["--device", self.device]) - if self.host: - cmds.extend(["--host", self.host]) - elif self.enable_x86_64: - cmds.extend(["--enable_x86_64"]) - if self.pre_gen_pte: - cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) - - p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) - with Listener((self.ip, self.port)) as listener: - conn = listener.accept() - p.communicate() - msg = json.loads(conn.recv()) - if "Error" in msg: - self.fail(msg["Error"]) - else: - inference_speed_ref = {"SM8650": 38, "SM8750": 56} - self.assertLessEqual(msg["wiki_ppl"], 18) - self.assertLessEqual(msg["pte_size"], 950_000_000) # 950MB - if self.model in inference_speed_ref: - self.assertGreaterEqual( - msg["inference_speed"], inference_speed_ref[self.model] - ) - - def test_qwen2_5(self): - if not self.required_envs([]): - self.skipTest("missing required envs") - prompt = "My favourite condiment is " - cmds = [ - "python", - f"{self.executorch_root}/examples/qualcomm/oss_scripts/qwen2_5/qwen2_5.py", - "--prompt", - prompt, - "--decoder_model", - "qwen2.5_0.5B", - "--ptq", - "16a8w", - "--enable_spinquant_r3", - "--max_seq_len", - "128", + self.skipTest("missing required envs") + assert ( + self.llama_artifacts is not None + ), "Please provide path to llama artifacts" + + prompt = "Once" + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py", "--artifact", self.artifact_dir, "--build_folder", self.build_folder, "--model", self.model, + "--checkpoint", + f"{self.llama_artifacts}/stories260K.pt", + "--params", + f"{self.llama_artifacts}/params.json", + "--tokenizer_model", + f"{self.llama_artifacts}/tokenizer.model", + 
"--tokenizer_bin", + f"{self.llama_artifacts}/tokenizer.bin", "--ip", self.ip, "--port", str(self.port), + "--prompt", + f"{prompt}", + "--temperature", + "0", + "--decoder_model", + "stories260k", + "--model_mode", + "hybrid", + "--prefill_ar_len", + "32", + "--max_seq_len", + "128", ] if self.compile_only: cmds.extend(["--compile_only"]) @@ -5176,7 +6360,7 @@ def test_qwen2_5(self): if self.pre_gen_pte: cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) - golden_start_with = "My favourite condiment is iced tea." + golden_start_with = "Once upon a time," p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) with Listener((self.ip, self.port)) as listener: conn = listener.accept() @@ -5187,16 +6371,26 @@ def test_qwen2_5(self): else: if not self.compile_only: model_out = msg["result"][0] + print(f"Model CI result:{model_out[: len(golden_start_with)]}") self.assertTrue( model_out.startswith(golden_start_with), - f"Expected Output: '{golden_start_with}' Actual Output: '{model_out}'", + f"Expected Output: {golden_start_with}. Actual Output: {model_out}", ) + # x86 does not allow weight sharing, so we don't check pte size + if not self.enable_x86_64: + pte_size = msg["pte_size"] + self.assertLessEqual(pte_size, 2_020_000) # 2MB + if not self.compile_only and not self.enable_x86_64: + self.assertGreaterEqual(msg["inference_speed"], 1600) # Lanai - def test_static_smollm2(self): + def test_llama_stories_110m(self): if not self.required_envs(): self.skipTest("missing required envs") + assert ( + self.llama_artifacts is not None + ), "Please provide path to llama artifacts" - prompt = "My favourite condiment is " + prompt = "Once" cmds = [ "python", f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py", @@ -5206,27 +6400,32 @@ def test_static_smollm2(self): self.build_folder, "--model", self.model, + "--target", + self.target, + "--checkpoint", + f"{self.llama_artifacts}/stories110M.pt", + "--params", + f"{self.llama_artifacts}/params.json", + "--tokenizer_model", + f"{self.llama_artifacts}/tokenizer.model", + "--tokenizer_bin", + f"{self.llama_artifacts}/tokenizer.bin", "--ip", self.ip, "--port", str(self.port), "--prompt", f"{prompt}", - "--decoder_model", - "smollm2_135m", - "--model_mode", - "kv", "--temperature", "0", + "--decoder_model", + "stories110m", + "--model_mode", + "hybrid", "--prefill_ar_len", - "128", + "32", "--max_seq_len", - "1024", - "--eval_perplexity", - "--task", - "wikitext", - "--limit", - "1", + "128", ] if self.compile_only: cmds.extend(["--compile_only"]) @@ -5239,6 +6438,7 @@ def test_static_smollm2(self): if self.pre_gen_pte: cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) + golden_start_with = "Once upon a time," p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) with Listener((self.ip, self.port)) as listener: conn = listener.accept() @@ -5247,17 +6447,36 @@ def test_static_smollm2(self): if "Error" in msg: self.fail(msg["Error"]) else: - self.assertLessEqual(msg["wiki_ppl"], 25) - self.assertGreaterEqual(msg["inference_speed"], 200) + if not self.compile_only: + model_out = msg["result"][0] + self.assertTrue( + model_out.startswith(golden_start_with), + f"Expected Output: {golden_start_with}. 
Actual Output: {model_out}", + ) + # x86 does not allow weight sharing, so we don't check pte size + if not self.enable_x86_64: + pte_size = msg["pte_size"] + self.assertLessEqual(pte_size, 135_000_000) # 135MB + if not self.compile_only and not self.enable_x86_64: + self.assertGreaterEqual(msg["inference_speed"], 220) # Lanai - def test_static_smollm3(self): - if not self.required_envs(): + def test_qwen2_5(self): + # This is not testing static llm flow. + if not self.required_envs([]): self.skipTest("missing required envs") - prompt = "My favourite condiment is " cmds = [ "python", - f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/qwen2_5/qwen2_5.py", + "--prompt", + prompt, + "--decoder_model", + "qwen2.5_0.5B", + "--ptq", + "16a8w", + "--enable_spinquant_r3", + "--max_seq_len", + "128", "--artifact", self.artifact_dir, "--build_folder", @@ -5268,21 +6487,6 @@ def test_static_smollm3(self): self.ip, "--port", str(self.port), - "--prompt", - f"{prompt}", - "--decoder_model", - "smollm3-3b", - "--model_mode", - "kv", - "--temperature", - "0", - "--max_seq_len", - "1024", - "--eval_perplexity", - "--task", - "wikitext", - "--limit", - "1", ] if self.compile_only: cmds.extend(["--compile_only"]) @@ -5295,6 +6499,7 @@ def test_static_smollm3(self): if self.pre_gen_pte: cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) + golden_start_with = "My favourite condiment is iced tea." p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) with Listener((self.ip, self.port)) as listener: conn = listener.accept() @@ -5303,12 +6508,11 @@ def test_static_smollm3(self): if "Error" in msg: self.fail(msg["Error"]) else: - inference_speed_ref = {"SM8650": 23, "SM8750": 28} - self.assertLessEqual(msg["wiki_ppl"], 10) - self.assertLessEqual(msg["pte_size"], 2_600_000_000) # 2.6GB - if self.model in inference_speed_ref: - self.assertGreaterEqual( - msg["inference_speed"], inference_speed_ref[self.model] + if not self.compile_only: + model_out = msg["result"][0] + self.assertTrue( + model_out.startswith(golden_start_with), + f"Expected Output: '{golden_start_with}' Actual Output: '{model_out}'", ) @@ -5329,6 +6533,8 @@ def test_albert(self): self.device, "--model", self.model, + "--target", + self.target, "--ip", self.ip, "--port", @@ -5365,6 +6571,8 @@ def test_bert(self): self.device, "--model", self.model, + "--target", + self.target, "--ip", self.ip, "--port", @@ -5402,6 +6610,8 @@ def test_conv_former(self): self.device, "--model", self.model, + "--target", + self.target, "--ip", self.ip, "--port", @@ -5423,6 +6633,43 @@ def test_conv_former(self): self.assertGreaterEqual(msg["top_1"], 70) self.assertGreaterEqual(msg["top_5"], 92) + def test_convnext_small(self): + if not self.required_envs([self.image_dataset]): + self.skipTest("missing required envs") + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/convnext_small.py", + "--dataset", + self.image_dataset, + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--device", + self.device, + "--model", + self.model, + "--ip", + self.ip, + "--port", + str(self.port), + "--seed", + str(1126), + ] + if self.host: + cmds.extend(["--host", self.host]) + + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["top_1"], 76) 
+ self.assertGreaterEqual(msg["top_5"], 97) + def test_cvt(self): if not self.required_envs([self.image_dataset]): self.skipTest("missing required envs") @@ -5440,6 +6687,8 @@ def test_cvt(self): self.device, "--model", self.model, + "--target", + self.target, "--ip", self.ip, "--port", @@ -5477,6 +6726,8 @@ def test_deit(self): self.device, "--model", self.model, + "--target", + self.target, "--ip", self.ip, "--port", @@ -5514,6 +6765,8 @@ def test_dino_v2(self): self.device, "--model", self.model, + "--target", + self.target, "--ip", self.ip, "--port", @@ -5551,6 +6804,8 @@ def test_distilbert(self): self.device, "--model", self.model, + "--target", + self.target, "--ip", self.ip, "--port", @@ -5586,6 +6841,8 @@ def test_dit(self): self.device, "--model", self.model, + "--target", + self.target, "--ip", self.ip, "--port", @@ -5623,6 +6880,8 @@ def test_efficientnet(self): self.device, "--model", self.model, + "--target", + self.target, "--ip", self.ip, "--port", @@ -5644,6 +6903,7 @@ def test_efficientnet(self): self.assertGreaterEqual(msg["top_1"], 61) self.assertGreaterEqual(msg["top_5"], 88) + @unittest.skip("Bad accuracy, need investigation") def test_efficientSAM(self): if not self.required_envs( [self.image_dataset, self.pretrained_weight, self.oss_repo] @@ -5662,6 +6922,8 @@ def test_efficientSAM(self): self.device, "--model", self.model, + "--target", + self.target, "--oss_repo", self.oss_repo, "--pretrained_weight", @@ -5702,6 +6964,8 @@ def test_esrgan(self): self.device, "--model", self.model, + "--target", + self.target, "--default_dataset", "--oss_repo", self.oss_repo, @@ -5742,6 +7006,8 @@ def test_eurobert(self): self.device, "--model", self.model, + "--target", + self.target, "--ip", self.ip, "--port", @@ -5780,6 +7046,8 @@ def test_fastvit(self): self.device, "--model", self.model, + "--target", + self.target, "--oss_repo", self.oss_repo, "--pretrained_weight", @@ -5822,6 +7090,8 @@ def test_fbnet(self): self.device, "--model", self.model, + "--target", + self.target, "--ip", self.ip, "--port", @@ -5860,6 +7130,8 @@ def test_focalnet(self): self.device, "--model", self.model, + "--target", + self.target, "--ip", self.ip, "--port", @@ -5900,6 +7172,8 @@ def test_gMLP(self): self.device, "--model", self.model, + "--target", + self.target, "--ip", self.ip, "--port", @@ -5921,6 +7195,43 @@ def test_gMLP(self): self.assertGreaterEqual(msg["top_1"], 70) self.assertGreaterEqual(msg["top_5"], 88) + def test_maxvit_t(self): + if not self.required_envs([self.image_dataset]): + self.skipTest("missing required envs") + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/maxvit_t.py", + "--dataset", + self.image_dataset, + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--device", + self.device, + "--model", + self.model, + "--ip", + self.ip, + "--port", + str(self.port), + "--seed", + str(1126), + ] + if self.host: + cmds.extend(["--host", self.host]) + + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["top_1"], 72) + self.assertGreaterEqual(msg["top_5"], 91) + @unittest.skip("Only outputs good accuracy in QNN 2.29") def test_mobilevit_v2(self): if not self.required_envs([self.image_dataset]): @@ -5939,6 +7250,8 @@ def test_mobilevit_v2(self): self.device, "--model", self.model, + "--target", + self.target, 
"--ip", self.ip, "--port", @@ -5979,6 +7292,8 @@ def test_mobilevit_v1(self): self.device, "--model", self.model, + "--target", + self.target, "--ip", self.ip, "--port", @@ -6056,6 +7371,8 @@ def test_regnet(self): self.device, "--model", self.model, + "--target", + self.target, "--ip", self.ip, "--port", @@ -6095,6 +7412,8 @@ def test_retinanet(self): self.device, "--model", self.model, + "--target", + self.target, "--dataset", self.image_dataset, "--ip", @@ -6122,9 +7441,86 @@ def test_roberta(self): self.skipTest("missing required envs") cmds = [ "python", - f"{self.executorch_root}/examples/qualcomm/oss_scripts/roberta.py", - "--dataset", - self.sentence_dataset, + f"{self.executorch_root}/examples/qualcomm/oss_scripts/roberta.py", + "--dataset", + self.sentence_dataset, + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--device", + self.device, + "--model", + self.model, + "--target", + self.target, + "--ip", + self.ip, + "--port", + str(self.port), + "--seed", + str(1126), + ] + if self.host: + cmds.extend(["--host", self.host]) + + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["accuracy"], 0.54) + + def test_squeezenet(self): + if not self.required_envs([self.image_dataset]): + self.skipTest("missing required envs") + + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/squeezenet.py", + "--dataset", + self.image_dataset, + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--device", + self.device, + "--model", + self.model, + "--target", + self.target, + "--ip", + self.ip, + "--port", + str(self.port), + "--seed", + str(1126), + ] + if self.host: + cmds.extend(["--host", self.host]) + + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["top_1"], 27) + self.assertGreaterEqual(msg["top_5"], 59) + + def test_ssd300_vgg16(self): + if not self.required_envs([self.pretrained_weight, self.oss_repo]): + self.skipTest("missing required envs") + + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/ssd300_vgg16.py", "--artifact", self.artifact_dir, "--build_folder", @@ -6133,6 +7529,12 @@ def test_roberta(self): self.device, "--model", self.model, + "--target", + self.target, + "--oss_repo", + self.oss_repo, + "--pretrained_weight", + self.pretrained_weight, "--ip", self.ip, "--port", @@ -6151,15 +7553,14 @@ def test_roberta(self): if "Error" in msg: self.fail(msg["Error"]) else: - self.assertGreaterEqual(msg["accuracy"], 0.54) + self.assertGreaterEqual(msg["mAP"], 0.76) - def test_squeezenet(self): + def test_swin_transformer(self): if not self.required_envs([self.image_dataset]): self.skipTest("missing required envs") - cmds = [ "python", - f"{self.executorch_root}/examples/qualcomm/oss_scripts/squeezenet.py", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/swin_transformer.py", "--dataset", self.image_dataset, "--artifact", @@ -6170,6 +7571,8 @@ def test_squeezenet(self): self.device, "--model", self.model, + "--target", + self.target, "--ip", self.ip, "--port", @@ -6188,16 +7591,17 @@ def test_squeezenet(self): if "Error" in msg: 
self.fail(msg["Error"]) else: - self.assertGreaterEqual(msg["top_1"], 27) - self.assertGreaterEqual(msg["top_5"], 59) + self.assertGreaterEqual(msg["top_1"], 71) + self.assertGreaterEqual(msg["top_5"], 90) - def test_ssd300_vgg16(self): - if not self.required_envs([self.pretrained_weight, self.oss_repo]): + def test_swin_v2_t(self): + if not self.required_envs([self.image_dataset]): self.skipTest("missing required envs") - cmds = [ "python", - f"{self.executorch_root}/examples/qualcomm/oss_scripts/ssd300_vgg16.py", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/swin_v2_t.py", + "--dataset", + self.image_dataset, "--artifact", self.artifact_dir, "--build_folder", @@ -6206,10 +7610,6 @@ def test_ssd300_vgg16(self): self.device, "--model", self.model, - "--oss_repo", - self.oss_repo, - "--pretrained_weight", - self.pretrained_weight, "--ip", self.ip, "--port", @@ -6228,16 +7628,17 @@ def test_ssd300_vgg16(self): if "Error" in msg: self.fail(msg["Error"]) else: - self.assertGreaterEqual(msg["mAP"], 0.76) + self.assertGreaterEqual(msg["top_1"], 63) + self.assertGreaterEqual(msg["top_5"], 92) - def test_swin_transformer(self): - if not self.required_envs([self.image_dataset]): + def test_t5(self): + if not self.required_envs([self.qa_dataset]): self.skipTest("missing required envs") cmds = [ "python", - f"{self.executorch_root}/examples/qualcomm/oss_scripts/swin_transformer.py", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/t5/t5.py", "--dataset", - self.image_dataset, + self.qa_dataset, "--artifact", self.artifact_dir, "--build_folder", @@ -6246,6 +7647,8 @@ def test_swin_transformer(self): self.device, "--model", self.model, + "--target", + self.target, "--ip", self.ip, "--port", @@ -6264,17 +7667,16 @@ def test_swin_transformer(self): if "Error" in msg: self.fail(msg["Error"]) else: - self.assertGreaterEqual(msg["top_1"], 71) - self.assertGreaterEqual(msg["top_5"], 90) + self.assertGreaterEqual(msg["f1"], 0.72) - def test_t5(self): - if not self.required_envs([self.qa_dataset]): + def test_vit_b_16(self): + if not self.required_envs([self.image_dataset]): self.skipTest("missing required envs") cmds = [ "python", - f"{self.executorch_root}/examples/qualcomm/oss_scripts/t5/t5.py", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/vit_b_16.py", "--dataset", - self.qa_dataset, + self.image_dataset, "--artifact", self.artifact_dir, "--build_folder", @@ -6301,7 +7703,8 @@ def test_t5(self): if "Error" in msg: self.fail(msg["Error"]) else: - self.assertGreaterEqual(msg["f1"], 0.72) + self.assertGreaterEqual(msg["top_1"], 72) + self.assertGreaterEqual(msg["top_5"], 96) def test_whisper(self): if not self.required_envs(): @@ -6318,6 +7721,8 @@ def test_whisper(self): self.device, "--model", self.model, + "--target", + self.target, "--ip", self.ip, "--port", @@ -6585,6 +7990,8 @@ def test_mobilenet_v2(self): self.device, "--model", self.model, + "--target", + self.target, "--ip", self.ip, "--port", @@ -6625,6 +8032,8 @@ def test_mobilenet_v3(self): self.device, "--model", self.model, + "--target", + self.target, "--ip", self.ip, "--port", @@ -6665,6 +8074,8 @@ def test_inception_v3(self): self.device, "--model", self.model, + "--target", + self.target, "--ip", self.ip, "--port", @@ -6705,6 +8116,8 @@ def test_inception_v4(self): self.device, "--model", self.model, + "--target", + self.target, "--ip", self.ip, "--port", @@ -6745,6 +8158,8 @@ def test_vit(self): self.device, "--model", self.model, + "--target", + self.target, "--ip", self.ip, "--port", @@ -6783,6 +8198,8 @@ 
def test_edsr(self): self.device, "--model", self.model, + "--target", + self.target, "--default_dataset", "--ip", self.ip, @@ -6822,6 +8239,8 @@ def test_deeplab_v3(self): self.device, "--model", self.model, + "--target", + self.target, "--download", "--ip", self.ip, @@ -6863,6 +8282,8 @@ def test_mobilebert(self): self.device, "--model", self.model, + "--target", + self.target, "--pretrained_weight", self.pretrained_weight, "--ip", @@ -6904,6 +8325,8 @@ def test_ptq_mobilebert(self): self.device, "--model", self.model, + "--target", + self.target, "--pretrained_weight", self.pretrained_weight, "--ptq", @@ -6946,6 +8369,8 @@ def test_wav2letter(self): self.device, "--model", self.model, + "--target", + self.target, "--pretrained_weight", self.pretrained_weight, "--ip", @@ -6994,6 +8419,9 @@ def test_export_example(self): class TestUtilsScript(TestQNN): + TestQNN.atol = 1e-1 + TestQNN.rtol = 1 + def required_envs(self, conditions=None) -> bool: conditions = [] if conditions is None else conditions return all( @@ -7004,6 +8432,150 @@ def required_envs(self, conditions=None) -> bool: ] ) + def test_cli(self): + with tempfile.TemporaryDirectory() as tmp_dir: + sample_input = torch.randn(1, 2, 3, 4) + ep = torch.export.export(Relu(), (sample_input,)) # noqa: F405 + torch.export.save(ep, f"{tmp_dir}/relu.pt2") + torch.save(sample_input, f"{tmp_dir}/input_0_0.pt") + with open(f"{tmp_dir}/input_list", "w") as f: + f.write(f"{tmp_dir}/input_0_0.pt\n") + + # quantize + cmds = [ + "python", + "-m", + "examples.qualcomm.util_scripts.cli", + "quantize", + "--artifact", + f"{tmp_dir}/relu.pt2", + "--output_folder", + f"{tmp_dir}/q_out", + "--input_list", + f"{tmp_dir}/input_list", + ] + subprocess.run(cmds, stdout=subprocess.DEVNULL) + self.assertTrue(os.path.isfile(f"{tmp_dir}/q_out/relu_quantized.pt2")) + # compile + cmds = [ + "python", + "-m", + "examples.qualcomm.util_scripts.cli", + "compile", + "--artifact", + f"{tmp_dir}/q_out/relu_quantized.pt2", + "--output_folder", + f"{tmp_dir}/c_out", + "--model", + self.model, + ] + subprocess.run(cmds, stdout=subprocess.DEVNULL) + self.assertTrue(os.path.isfile(f"{tmp_dir}/c_out/relu_quantized.pte")) + self.assertTrue(os.path.isfile(f"{tmp_dir}/c_out/relu_quantized.svg")) + # execute + cmds = [ + "python", + "-m", + "examples.qualcomm.util_scripts.cli", + "execute", + "--artifact", + f"{tmp_dir}/c_out/relu_quantized.pte", + "--output_folder", + f"{tmp_dir}/e_out", + "--model", + self.model, + "--target", + self.target, + "--device", + self.device, + "--host", + self.host, + "--build_folder", + self.build_folder, + "--input_list", + f"{tmp_dir}/input_list", + ] + if self.host: + cmds.extend(["--host", self.host]) + subprocess.run(cmds, stdout=subprocess.DEVNULL) + self.assertTrue(os.path.isfile(f"{tmp_dir}/e_out/Result_0/output_0.pt")) + + def test_cli_with_input_list_assignment(self): + with tempfile.TemporaryDirectory() as tmp_dir: + sample_input = torch.randn(1, 2, 3, 4) + sample_input2 = torch.randn(1, 2, 3, 4) + ep = torch.export.export( + Sub_y_x_from_x_y(), (sample_input, sample_input2) # noqa: F405 + ) + torch.export.save(ep, f"{tmp_dir}/sub.pt2") + torch.save(sample_input, f"{tmp_dir}/input_0_0.pt") + torch.save(sample_input2, f"{tmp_dir}/input_0_1.pt") + with open(f"{tmp_dir}/input_list", "w") as f: + f.write(f"x:={tmp_dir}/input_0_0.pt y:={tmp_dir}/input_0_1.pt\n") + + # quantize + cmds = [ + "python", + "-m", + "examples.qualcomm.util_scripts.cli", + "quantize", + "--artifact", + f"{tmp_dir}/sub.pt2", + "--output_folder", + f"{tmp_dir}/q_out", 
+ "--input_list", + f"{tmp_dir}/input_list", + ] + subprocess.run(cmds, stdout=subprocess.DEVNULL) + self.assertTrue(os.path.isfile(f"{tmp_dir}/q_out/sub_quantized.pt2")) + # compile + cmds = [ + "python", + "-m", + "examples.qualcomm.util_scripts.cli", + "compile", + "--artifact", + f"{tmp_dir}/q_out/sub_quantized.pt2", + "--output_folder", + f"{tmp_dir}/c_out", + "--model", + self.model, + ] + subprocess.run(cmds, stdout=subprocess.DEVNULL) + self.assertTrue(os.path.isfile(f"{tmp_dir}/c_out/sub_quantized.pte")) + self.assertTrue(os.path.isfile(f"{tmp_dir}/c_out/sub_quantized.svg")) + # execute + cmds = [ + "python", + "-m", + "examples.qualcomm.util_scripts.cli", + "execute", + "--artifact", + f"{tmp_dir}/c_out/sub_quantized.pte", + "--output_folder", + f"{tmp_dir}/e_out", + "--model", + self.model, + "--target", + self.target, + "--device", + self.device, + "--host", + self.host, + "--build_folder", + self.build_folder, + "--input_list", + f"{tmp_dir}/input_list", + ] + if self.host: + cmds.extend(["--host", self.host]) + subprocess.run(cmds, stdout=subprocess.DEVNULL) + output_file = f"{tmp_dir}/e_out/Result_0/output_0.pt" + self.assertTrue(os.path.isfile(output_file)) + device_output = torch.load(output_file, weights_only=True) + golden_output = ep.module()(sample_input, sample_input2) + self._assert_outputs_equal(golden_output, device_output) + def test_custom_op(self): if not self.required_envs([self.op_package_dir]): self.skipTest("missing required envs") @@ -7018,6 +8590,8 @@ def test_custom_op(self): self.device, "--model", self.model, + "--target", + self.target, "--ip", self.ip, "--port", @@ -7050,6 +8624,8 @@ def test_debugger_generate_optrace(self): self.device, "--model", self.model, + "--target", + self.target, "--ip", self.ip, "--port", @@ -7075,67 +8651,69 @@ def test_debugger_generate_optrace(self): qhas_data = json.load(qhas_file) self.assertIn("data", qhas_data) - def test_cli(self): - with tempfile.TemporaryDirectory() as tmp_dir: - sample_input = torch.randn(1, 2, 3, 4) - ep = torch.export.export(Relu(), (sample_input,)) # noqa: F405 - torch.export.save(ep, f"{tmp_dir}/relu.pt2") - torch.save(sample_input, f"{tmp_dir}/input_0_0.pt") - with open(f"{tmp_dir}/input_list", "w") as f: - f.write(f"{tmp_dir}/input_0_0.pt\n") + def test_intermediate_debugger(self): + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/util_scripts/qnn_intermediate_debugger_demo.py", + "--artifact", + self.artifact_dir, + "--dataset", + self.image_dataset, + "--build_folder", + self.build_folder, + "--device", + self.device, + "--model", + self.model, + "--ip", + self.ip, + "--port", + str(self.port), + "--dump_intermediate_outputs", + ] + if self.host: + cmds.extend(["--host", self.host]) - # quantize - cmds = [ - "python", - "-m", - "examples.qualcomm.util_scripts.cli", - "quantize", - "--artifact", - f"{tmp_dir}/relu.pt2", - "--output_folder", - f"{tmp_dir}/q_out", - "--input_list", - f"{tmp_dir}/input_list", - ] - subprocess.run(cmds, stdout=subprocess.DEVNULL) - self.assertTrue(os.path.isfile(f"{tmp_dir}/q_out/relu_quantized.pt2")) - # compile - cmds = [ - "python", - "-m", - "examples.qualcomm.util_scripts.cli", - "compile", - "--artifact", - f"{tmp_dir}/q_out/relu_quantized.pt2", - "--output_folder", - f"{tmp_dir}/c_out", - "--model", - self.model, - ] - subprocess.run(cmds, stdout=subprocess.DEVNULL) - self.assertTrue(os.path.isfile(f"{tmp_dir}/c_out/relu_quantized.pte")) - self.assertTrue(os.path.isfile(f"{tmp_dir}/c_out/relu_quantized.svg")) - # execute - cmds = [ - 
"python", - "-m", - "examples.qualcomm.util_scripts.cli", - "execute", - "--artifact", - f"{tmp_dir}/c_out/relu_quantized.pte", - "--output_folder", - f"{tmp_dir}/e_out", - "--model", - self.model, - "--device", - self.device, - "--build_folder", - self.build_folder, - "--input_list", - f"{tmp_dir}/input_list", - ] - subprocess.run(cmds, stdout=subprocess.DEVNULL) - self.assertTrue(os.path.isfile(f"{tmp_dir}/e_out/output_0_0.pt")) + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + svg_path = msg["svg_path"] + csv_path = msg["csv_path"] + min_accepted = 235 + max_accepted = 241 + # Having a +- 3 tolerance, expecting 238 events + assert os.path.exists(svg_path), f"Unable to find SVG file: {svg_path}" + assert os.path.exists(csv_path), f"Unable to find CSV file: {csv_path}" + + csv_valid_count = 0 + with open(csv_path, mode="r", newline="") as csv_file: + reader = csv.reader(csv_file) + header = next(reader) + index = header.index("is_valid_score") + for row in reader: + if len(row) > index and row[index].strip().upper() == "TRUE": + csv_valid_count += 1 + # We assume csv_valid_count == compared_events, since all compared events meet metric's threshold + assert ( + min_accepted <= csv_valid_count <= max_accepted + ), f"Expected CSV events with valid score is outside of expected range, number of valid score events found: {csv_valid_count}" + + svg_valid_count = 0 + with open(svg_path, "r", encoding="utf-8") as svg_file: + for line in svg_file: + svg_valid_count += line.count("is_valid_score=True") + # We assume svg_valid_count == compared_events, since all compared events meet metric's threshold + assert ( + min_accepted <= svg_valid_count <= max_accepted + ), f"Expected SVG events with valid score is outside of expected range, number of valid score events found: {svg_valid_count}" + print( + f"CSV valid count: {csv_valid_count}. 
SVG valid count: {svg_valid_count}" + ) def setup_environment(): @@ -7206,13 +8784,13 @@ def setup_environment(): default="", type=str, ) - parser.add_argument( - "--pre_gen_pte", - help="Run the pre-generated pte in the given directory.", + "--backend", + help="Backend to be deployed ('htp'/'gpu' are currently supported).", + choices=["htp", "gpu"], + default="htp", type=str, ) - parser.add_argument( "--llama_artifacts", help="A folder that contains: weight, tokenizer, and params.", @@ -7242,6 +8820,8 @@ def setup_environment(): TestQNN.pre_gen_pte = args.pre_gen_pte TestQNN.llama_artifacts = args.llama_artifacts TestQNN.op_package_dir = args.op_package_dir + TestQNN.target = args.target + TestQNN.backend = args.backend return sys.argv[:1] + ns_args diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index 93eee4dfc31..f4b9339e1c2 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -15,9 +15,15 @@ import torchao from executorch import exir from executorch.backends.qualcomm.builders.node_visitor import dq_ops +from executorch.backends.qualcomm.debugger.qnn_intermediate_debugger import ( + QNNIntermediateDebugger, +) from executorch.backends.qualcomm.qnn_preprocess import QnnBackend from executorch.backends.qualcomm.quantizer.quantizer import ModuleQConfig, QuantDtype -from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset +from executorch.backends.qualcomm.serialization.qc_schema import ( + QcomChipset, + QnnExecuTorchBackendType, +) from executorch.backends.qualcomm.utils.constants import ( QCOM_DTYPE, QCOM_PASS_ACTIVATE_KEY, @@ -39,6 +45,7 @@ ) from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.backend.utils import get_delegates from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass @@ -163,10 +170,14 @@ class TestQNN(unittest.TestCase): pretrained_weight: str = "" enable_profile: bool = False op_package_dir: str = "" + target: str = "" + model_name: str = "" + backend: str = "" online_prepare: bool = False use_8a8w: str = "8a8w" use_16a16w: str = "16a16w" use_16a4w: str = "16a4w" + oss_repo: str = "" shared_buffer: bool = False enable_x86_64: bool = False compile_only: bool = False @@ -211,6 +222,9 @@ def _save_model_and_expected_output( return ref_outputs, pte_fname + def get_backend_type(self): + return getattr(QnnExecuTorchBackendType, f"k{self.backend.title()}Backend") + def required_envs(self, conditions=None) -> bool: conditions = [] if conditions is None else conditions return all( @@ -237,6 +251,8 @@ def verify_output( # noqa: C901 extra_cmds: str = "", output_callback: Optional[Callable[[str], None]] = None, save_inference_speed: bool = False, + expected_compared_events: int = -1, + qnn_intermediate_debugger: QNNIntermediateDebugger = None, ): with tempfile.TemporaryDirectory() as tmp_dir: ( @@ -294,10 +310,25 @@ def validate_intermediate_tensor(): inspector = Inspector( etdump_path=etdump_path, debug_buffer_path=debug_output_path ) + node_tensor_map = qnn_intermediate_debugger._match_tensors( + inspector=inspector, keep_qnn_layout=False + ) + self.assertEqual( + len(node_tensor_map), + expected_compared_events, + msg=f"Unexpected number of compared events, expecting {expected_compared_events}, but has {len(node_tensor_map)} events.", + ) + # Compare accuracy for each layer + for _, value in node_tensor_map.items(): + 
self._assert_outputs_equal( + value[0].to(torch.float32), value[1].to(torch.float32) + ) for event_block in inspector.event_blocks: if event_block.name == "Execute": - self.assertTrue( - len(event_block.events) == expected_intermediate_events + self.assertEqual( + len(event_block.events), + expected_intermediate_events, + msg=f"Unexpected number of intermediate events, expecting {expected_intermediate_events}, but has {len(event_block.events)} events.", ) processed_inputs = list(sample_inputs) @@ -411,6 +442,7 @@ def validate_intermediate_tensor(): dump_intermediate_outputs=( True if expected_intermediate_events != -1 else False ), + backend=self.get_backend_type(), expected_input_shape=( (tensor.shape for tensor in processed_inputs) if check_io_shape @@ -421,6 +453,7 @@ def validate_intermediate_tensor(): if check_io_shape else None ), + target=self.target, ) adb.push( inputs=[processed_inputs], @@ -457,6 +490,7 @@ def lower_module_and_test_output( expected_partitions: int = 1, expected_profile_events: int = -1, expected_intermediate_events: int = -1, + expected_compared_events: int = -1, assert_output_equal: bool = True, passes_job: Optional[OrderedDict] = None, skip_node_id_set: set = None, @@ -479,6 +513,24 @@ def lower_module_and_test_output( generate_etrecord=self.enable_profile, ) + qnn_intermediate_debugger = None + if expected_intermediate_events != -1: + lowered_module_nodes = get_delegates( + delegated_program.exported_program().graph + ) + assert len(lowered_module_nodes) == 1, "Length not correct" + + lowered_module_node = lowered_module_nodes[0] + lower_module = getattr( + delegated_program.exported_program().graph_module, + lowered_module_node.name, + ) + edge_module = lower_module.original_module.module() + + qnn_intermediate_debugger = QNNIntermediateDebugger() + qnn_intermediate_debugger.set_edge_module(edge_module=edge_module) + qnn_intermediate_debugger.intermediate_output_module(*sample_inputs) + exec_prog = delegated_program.to_executorch( exir.ExecutorchBackendConfig( # For shared buffer, user must pass the memory address @@ -513,15 +565,17 @@ def lower_module_and_test_output( or expected_intermediate_events != -1 ): self.verify_output( - module, - sample_inputs, - exec_prog, - etrecord_path, - expected_profile_events, - expected_intermediate_events, + module=module, + sample_inputs=sample_inputs, + executorch_prog=exec_prog, + etrecord_path=etrecord_path, + expected_profile_events=expected_profile_events, + expected_intermediate_events=expected_intermediate_events, extra_cmds=extra_cmds, output_callback=output_callback, save_inference_speed=save_inference_speed, + expected_compared_events=expected_compared_events, + qnn_intermediate_debugger=qnn_intermediate_debugger, ) def get_qdq_module( @@ -551,6 +605,7 @@ def get_qdq_module( if block_size_map is not None: quantizer.set_block_size_map(block_size_map) prepared = prepare_pt2e(m, quantizer) + prepared(*inputs) quantized_module = convert_pt2e(prepared) nodes = {node.target for node in quantized_module.graph.nodes} @@ -574,6 +629,7 @@ def get_prepared_qat_module( is_linear_per_channel: Optional[bool] = False, custom_quant_annotations: Tuple[Callable] = (), quant_dtype: QuantDtype = QuantDtype.use_8a8w, + block_size_map: Dict[str, Tuple] = None, submodule_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None, ) -> torch.fx.GraphModule: m = torch.export.export(module, inputs, strict=True).module() @@ -586,6 +642,8 @@ def get_prepared_qat_module( is_qat=True, submodule_qconfig_list=submodule_qconfig_list, ) 
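The harness change above calls prepared(*inputs) between prepare_pt2e and convert_pt2e; that call is the calibration step that lets the inserted observers see representative data before quantization parameters are baked in. A minimal standalone sketch of the same flow, where XNNPACKQuantizer is only a stand-in for whichever quantizer the test harness actually configures:

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x + 1.0)

example = (torch.randn(1, 4),)
m = torch.export.export(Tiny(), example, strict=True).module()

quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(m, quantizer)
prepared(*example)                   # calibration: observers record activation ranges here
quantized = convert_pt2e(prepared)   # q/dq nodes now carry calibrated scales and zero points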
+ if block_size_map is not None: + quantizer.set_block_size_map(block_size_map) submodule_qconfig_list = submodule_qconfig_list or [] quantizer.set_submodule_qconfig_list(submodule_qconfig_list) @@ -618,6 +676,7 @@ def get_adb_tool(self, pte_fname): host_id=self.host, soc_model=self.model, error_only=self.error_only, + target=self.target, ) return adb diff --git a/backends/qualcomm/utils/constants.py b/backends/qualcomm/utils/constants.py index a4a087287a4..5a6e7570e82 100644 --- a/backends/qualcomm/utils/constants.py +++ b/backends/qualcomm/utils/constants.py @@ -33,6 +33,7 @@ QCOM_SCALE = "scale" QCOM_SCALES = "scales" QCOM_SCALE_OFFSET = "scale_offset" +QCOM_TENSOR_NAME = "qnn_tensor_name" QCOM_ZERO_POINT = "zero_point" QCOM_ZERO_POINTS = "zero_points" QCOM_PASS_ACTIVATE_KEY = "activate" diff --git a/backends/qualcomm/utils/qnn_manager_lifecycle.py b/backends/qualcomm/utils/qnn_manager_lifecycle.py new file mode 100644 index 00000000000..2e1ba7fd2d7 --- /dev/null +++ b/backends/qualcomm/utils/qnn_manager_lifecycle.py @@ -0,0 +1,88 @@ +import contextlib +import logging +import threading +from typing import Dict, List + +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager + +from executorch.backends.qualcomm.partition.utils import generate_qnn_executorch_option +from executorch.backends.qualcomm.serialization.qc_schema import ( + QnnExecuTorchBackendType, +) +from executorch.backends.qualcomm.serialization.qc_schema_serialize import ( + flatbuffer_to_option, +) +from executorch.exir.backend.compile_spec_schema import CompileSpec + +# Thread-local storage for QnnManager instances +_current_qnn_managers = threading.local() + + +class QnnManagerRegistry: + def __init__(self): + # Registry stores {backend_type: QnnManager instance} + self._registry = {} + + def get_or_create_qnn_manager( + self, backend_type: QnnExecuTorchBackendType, option: bytes + ) -> PyQnnManager.QnnManager: + if backend_type not in self._registry: + qnn_manager = PyQnnManager.QnnManager(option) + qnn_manager.InitBackend() + self._registry[backend_type] = qnn_manager + return self._registry[backend_type] + + def destroy_qnn_manager(self, backend_type: QnnExecuTorchBackendType): + if backend_type in self._registry: + self._registry[backend_type].Destroy() + del self._registry[backend_type] + else: + logging.warning( + f"Attempted to destroy non-existent QnnManager for backend type {backend_type.name}" + ) + + +@contextlib.contextmanager +def QnnManagerContext(compile_specs: Dict[str, List[CompileSpec]]): + # Create a new registry for the current context + current_context_registry = QnnManagerRegistry() + _current_qnn_managers.active_registry = current_context_registry + + backend_types_in_this_context = set() + + try: + for compile_spec_list in compile_specs.values(): + option = generate_qnn_executorch_option(compile_spec_list) + python_options = flatbuffer_to_option(option) + backend_type = python_options.backend_options.backend_type + + # Use the current_context_registry to get/create the manager + current_context_registry.get_or_create_qnn_manager(backend_type, option) + backend_types_in_this_context.add(backend_type) + yield + finally: + # Destroy only the managers created within this context + for backend_type in backend_types_in_this_context: + current_context_registry.destroy_qnn_manager(backend_type) + + # Clear the active registry reference + _current_qnn_managers.active_registry = None + + +def get_current_qnn_manager( + backend_type: QnnExecuTorchBackendType, compile_specs: 
List[CompileSpec] +) -> PyQnnManager.QnnManager: + """ + Retrieves the QnnManager instance active for the current QnnManagerContext invocation. + Return a new QnnManger if no QnnManager is active for the given backend_type in the current context. + """ + active_registry = getattr(_current_qnn_managers, "active_registry", None) + if active_registry is None or backend_type not in active_registry._registry: + logging.warning( + f"No QnnManager active for backend type {backend_type.name} in the current QnnManagerContext. " + "It would be better to use to_edge_transform_and_lower_to_qnn to lowering to QNN Backend." + ) + return QnnManagerRegistry().get_or_create_qnn_manager( + backend_type, generate_qnn_executorch_option(compile_specs) + ) + return active_registry._registry[backend_type] diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index be4e86de50f..4a68f434895 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -34,6 +34,8 @@ QcomChipset, QnnExecuTorchBackendOptions, QnnExecuTorchBackendType, + QnnExecuTorchGpuBackendOptions, + QnnExecuTorchGpuPrecision, QnnExecuTorchHtpBackendOptions, QnnExecuTorchHtpPerformanceMode, QnnExecuTorchHtpPrecision, @@ -50,6 +52,7 @@ QCOM_QNN_COMPILE_SPEC, QCOM_QUANTIZED_IO, ) +from executorch.backends.qualcomm.utils.qnn_manager_lifecycle import QnnManagerContext from executorch.exir import EdgeCompileConfig, ExirExportedProgram, to_edge from executorch.exir.backend.compile_spec_schema import CompileSpec @@ -58,6 +61,7 @@ EdgeProgramManager, to_edge_transform_and_lower, ) +from tabulate import tabulate from torch._decomp import core_aten_decompositions, remove_decompositions from torch.export.exported_program import ExportedProgram from torch.fx import passes @@ -157,7 +161,7 @@ def __init__(self, weight, bias=None): def forward(self, x): rank = x.dim() - x = x.unsqueeze(-1) if rank == 3 else x.reshape(1, *x.shape, 1) + x = x.reshape(*x.shape, 1) if rank == 3 else x.reshape(1, *x.shape, 1) x = torch.transpose(x, 1, 2) res = self.conv(x) res = torch.transpose(res, 1, 2) @@ -184,8 +188,9 @@ def replace_linear(module: torch.nn.Module): def dump_context_from_pte(pte_path) -> List[str]: """ Dump compiled binaries under the same directory of pte_path. - For partitioned graph, there will be multiple files with names f"{graph_name}_{index}". - Where 'graph_name' comes from the compiler_specs and 'index' represents the execution order. + For partitioned graph, there will be multiple files with names f"{method_name}_{index}". + 'method_name' refers to the name of a method in the nn.Module that was traced to + generate this program, while 'index' indicates the order of execution. Args: pte_path (str): The path of generated pte. 
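The new qnn_manager_lifecycle module above keeps one shared PyQnnManager per backend type for the duration of a with QnnManagerContext(...) block, tracked through a thread-local registry, and get_current_qnn_manager falls back to a fresh manager (with a warning) outside that scope. A simplified, QNN-free sketch of the same pattern, using only the standard library:

import contextlib
import threading

_tls = threading.local()

class Registry:
    def __init__(self):
        self._items = {}

    def get_or_create(self, key, factory):
        if key not in self._items:
            self._items[key] = factory()
        return self._items[key]

@contextlib.contextmanager
def registry_context():
    _tls.active = Registry()
    try:
        yield _tls.active
    finally:
        _tls.active = None   # everything created inside the block is torn down with it

def current_or_fresh(key, factory):
    active = getattr(_tls, "active", None)
    if active is None:
        return factory()     # mirrors get_current_qnn_manager's fallback path
    return active.get_or_create(key, factory)

with registry_context():
    mgr = current_or_fresh("htp", dict)
    assert current_or_fresh("htp", dict) is mgr   # reused within the same context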
@@ -197,17 +202,9 @@ def dump_context_from_pte(pte_path) -> List[str]: with open(pte_path, "rb") as f: program_data = f.read() - program = deserialize_pte_binary(program_data) + program = deserialize_pte_binary(program_data).program ctx_path = os.path.dirname(pte_path) - dummy_compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=QcomChipset.SM8650, - backend_options=generate_htp_compiler_spec(use_fp16=False), - ) - qnn_mgr = PyQnnManagerAdaptor.QnnManager( - generate_qnn_executorch_option(dummy_compiler_specs) - ) - qnn_mgr.Init() dumpfiles = [] for execution_plan in program.execution_plan: for i, delegate in enumerate(execution_plan.delegates): @@ -215,7 +212,7 @@ def dump_context_from_pte(pte_path) -> List[str]: processed_bytes = program.backend_delegate_data[ delegate.processed.index ].data - binary = qnn_mgr.StripProtocol(processed_bytes) + binary = PyQnnManagerAdaptor.StripProtocol(processed_bytes) file_extension = ".bin" if len(binary) == 0: binary = processed_bytes @@ -441,15 +438,15 @@ def ensure_graph_specific_dict(value, graph_names): transform_passes[graph_name] = QnnPassManager().get_to_edge_transform_passes( ep, passes_job=passes_job[graph_name], dep_table=dep_table[graph_name] ) - - return to_edge_transform_and_lower( - aten_programs, - transform_passes=transform_passes, - partitioner=qnn_partitioners, - constant_methods=constant_methods, - compile_config=qnn_edge_config(), - generate_etrecord=generate_etrecord, - ) + with QnnManagerContext(compiler_specs): + return to_edge_transform_and_lower( + aten_programs, + transform_passes=transform_passes, + partitioner=qnn_partitioners, + constant_methods=constant_methods, + compile_config=qnn_edge_config(), + generate_etrecord=generate_etrecord, + ) def capture_program( @@ -933,6 +930,47 @@ def draw_graph(title, path, graph_module: torch.fx.GraphModule): f.write(graph.get_dot_graph().create_svg()) +def generate_gpu_compiler_spec( + precision: QnnExecuTorchGpuPrecision = QnnExecuTorchGpuPrecision.kGpuPrecisionUserProvided, + use_memory_optimizations: bool = True, + use_node_optimizations: bool = True, + use_queue_recording: bool = True, + use_weight_sharing: bool = False, +) -> QnnExecuTorchBackendOptions: + """ + Helper function generating backend options for QNN HTP + + Args: + precision: + kGpuPrecisionFp32 - Sets the precision mode to floating point 32-bit (FP32). + kGpuPrecisionFp16 - Sets the precision mode to floating point 16-bit (FP16). + kGpuPrecisionHybrid - Sets the precision mode to FP16 for storage and FP32 for calculations. + kGpuPrecisionUserProvided - Uses the tensor data type provided by the user. + use_memory_optimizations: If true, backend will share NATIVE tensor memory + based upon analysis of the network topology. + use_node_optimizations: If true, backend will fuse compatible operations into + one operation to improve performance. + use_queue_recording: If true, backend will use queue recording to improve performance. + use_weight_sharing: Used with multiple_graphs, where model size will be + reduced when operations have the same weights across multiple graphs. + + Returns: + QnnExecuTorchGpuBackendOptions: backend options for QNN GPU. 
+ """ + # TODO: enable performance hint mechanism in runtime and make this as an option + gpu_options = QnnExecuTorchGpuBackendOptions() + gpu_options.precision = precision + gpu_options.use_memory_optimizations = use_memory_optimizations + gpu_options.use_node_optimizations = use_node_optimizations + gpu_options.use_queue_recording = use_queue_recording + gpu_options.use_weight_sharing = use_weight_sharing + + return QnnExecuTorchBackendOptions( + backend_type=QnnExecuTorchBackendType.kGpuBackend, + gpu_options=gpu_options, + ) + + def generate_htp_compiler_spec( use_fp16: bool, use_dlbc: bool = False, @@ -987,8 +1025,8 @@ def generate_qnn_executorch_compiler_spec( optrace: bool = False, shared_buffer: bool = False, is_from_context_binary: bool = False, - graph_name: str = "forward", op_package_options: QnnExecuTorchOpPackageOptions = None, + use_mha2sha: bool = False, ) -> List[CompileSpec]: """ Helper function generating compiler specs for Qualcomm AI Engine Direct @@ -1001,6 +1039,7 @@ def generate_qnn_executorch_compiler_spec( SM8550(Snapdragon 8 Gen 2) SM8650(Snapdragon 8 Gen 3) SM8750(Snapdragon 8 Elite) + SM8850(Snapdragon 8 Elite Gen 5) backend_options: Options required by different backends. debug: Enable verbose logging. Disclaimer: this option must change in the near future. @@ -1016,9 +1055,9 @@ def generate_qnn_executorch_compiler_spec( shared_buffer: Enables usage of shared buffer between application and backend for graph I/O. is_from_context_binary: True if current graph comes from pre-built context binary. - graph_name: Assign unique graph name if lowering multiple methods. op_package_options: Optional structure to specify op packages loaded and used by the backend. + use_mha2sha: This experimental parameter is used to decide whether to enable multi-head attention to single-head attention pass, aiming to reduce time consumption in AOT and improve performance on HTP. Returns: List[CompileSpec]: Compiler specs for Qualcomm AI Engine Direct. 
@@ -1041,7 +1080,6 @@ def generate_qnn_executorch_compiler_spec( qnn_executorch_options = QnnExecuTorchOptions( _soc_info_table[soc_model], backend_options ) - qnn_executorch_options.graph_name = [graph_name] qnn_executorch_options.log_level = ( QnnExecuTorchLogLevel.kLogLevelDebug if debug @@ -1081,6 +1119,8 @@ def generate_qnn_executorch_compiler_spec( if op_package_options and len(op_package_options.op_package_infos) > 0: qnn_executorch_options.op_package_options = op_package_options + qnn_executorch_options.use_mha2sha = use_mha2sha + return [ CompileSpec(QCOM_QNN_COMPILE_SPEC, option_to_flatbuffer(qnn_executorch_options)) ] @@ -1089,35 +1129,76 @@ def generate_qnn_executorch_compiler_spec( def get_soc_to_arch_map(): return { "SA8295": HtpArch.V68, + "SM8350": HtpArch.V68, "SM8450": HtpArch.V69, "SM8475": HtpArch.V69, "SM8550": HtpArch.V73, + "SA8255": HtpArch.V73, "SM8650": HtpArch.V75, "SM8750": HtpArch.V79, + "SM8850": HtpArch.V81, "SSG2115P": HtpArch.V73, "SSG2125P": HtpArch.V73, "SXR1230P": HtpArch.V73, "SXR2230P": HtpArch.V69, "SXR2330P": HtpArch.V79, + "QCS9100": HtpArch.V73, + "SAR2230P": HtpArch.V81, + "SW6100": HtpArch.V81, } def get_soc_to_chipset_map(): return { "SA8295": QcomChipset.SA8295, + "SM8350": QcomChipset.SM8350, "SM8450": QcomChipset.SM8450, "SM8475": QcomChipset.SM8475, "SM8550": QcomChipset.SM8550, + "SA8255": QcomChipset.SA8255, "SM8650": QcomChipset.SM8650, "SM8750": QcomChipset.SM8750, + "SM8850": QcomChipset.SM8850, "SSG2115P": QcomChipset.SSG2115P, "SSG2125P": QcomChipset.SSG2125P, "SXR1230P": QcomChipset.SXR1230P, "SXR2230P": QcomChipset.SXR2230P, "SXR2330P": QcomChipset.SXR2330P, + "QCS9100": QcomChipset.QCS9100, + "SAR2230P": QcomChipset.SAR2230P, + "SW6100": QcomChipset.SW6100, } +def show_nn_module_stack_for_quant_recipe(gm: torch.fx.GraphModule, supported_ops): + """ + Print a quick preview of op targets and module stack. + + Use this to inspect the FX graph and identify module stack, which helps you craft regex or op-target for quantization recipe. + + """ + + module_metadata = {} + for node in gm.graph.nodes: + target = node.target + deepest_module = None + if node.op == "call_function" and "nn_module_stack" in node.meta: + deepest_module = list(node.meta["nn_module_stack"].values())[-1][0] + if node.target in supported_ops: + module_metadata.setdefault((target, deepest_module), []).append(node) + + table_rows = [] + for (target, module_stack), nodes in module_metadata.items(): + node_names = ", ".join([node.name for node in nodes]) + table_rows.append([str(target), module_stack, node_names]) + + print( + tabulate( + table_rows, headers=["Op Target", "Module Stack", "Nodes"], tablefmt="grid" + ) + ) + + def tag_quant_io(gm: torch.fx.GraphModule, get_quant_io_dtype_fn: Callable): """ Tag io nodes which get/output quantized tensor. No need to insert q/dq in qnn_preprocess diff --git a/backends/samsung/CMakeLists.txt b/backends/samsung/CMakeLists.txt index fff3ece5239..6ea020c0970 100644 --- a/backends/samsung/CMakeLists.txt +++ b/backends/samsung/CMakeLists.txt @@ -161,7 +161,7 @@ if(${ANDROID}) install( TARGETS enn_backend enn_logging EXPORT ExecuTorchTargets - DESTINATION lib + DESTINATION ${CMAKE_INSTALL_LIBDIR} ) endif() diff --git a/backends/samsung/_passes/annotate_qparams.py b/backends/samsung/_passes/annotate_qparams.py new file mode 100644 index 00000000000..663d1fdf5fa --- /dev/null +++ b/backends/samsung/_passes/annotate_qparams.py @@ -0,0 +1,201 @@ +# Copyright (c) 2025 Samsung Electronics Co. 
LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import operator +from typing import Any, Dict, List, Optional + +import torch +from executorch.backends.samsung.utils.constants import QuantConstants +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch._export.utils import get_buffer +from torch.export import ExportedProgram +from torch.fx import GraphModule, Node + + +class AnnotateQparamsPass(ExportPass): + """This parse is to add quantize properties to node need to be quantized. + + Annotate Quant params: + For src_node->Q->DQ->..., we will add the quant params from Q->DQ node + to the src_node + + Annotate Requantize: + For src_node->Q->DQ->Q->DQ->..., if the multiple Q->DQ contains + different quant params, we will mark the src_node as need requantize, + and add Q->DQ after removing all the Q->DQs. + """ + + propagate_nodes = { + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.squeeze_copy.default, + exir_ops.edge.aten.squeeze_copy.dim, + exir_ops.edge.aten.squeeze_copy.dims, + exir_ops.edge.aten.slice_copy.Tensor, + exir_ops.edge.aten.unsqueeze_copy.default, + exir_ops.edge.aten.concat.default, + exir_ops.edge.aten.cat.default, + exir_ops.edge.aten.expand_copy.default, + } + + def __init__(self, edge_program: ExportedProgram): + super().__init__() + self.edge_program = edge_program + + def _get_last_dqs(self, node: Node) -> List[Node]: + r"""From one Q-DQ node, find the last DQs in the quantization node chain. + + + need to consider such case: + /--Q-DQ-node1 + node->Q->DQ--node-node2 + \--Q-DQ-node3 + This is a dfs implemention, so result will keep sorted + Args: + node (Node): Search DQ from this node. + + Returns: + List[Node]: list of DQ node by original sequence + """ + + def _impl(node: Node, res_list: List[Node]): + if ( + node.target not in QuantConstants.QUANT_OPS_KEY_MAP + and node.target not in QuantConstants.DEQUANT_OPS_KEY_MAP + ): + return + for user in node.users.keys(): + if ( + user.target not in QuantConstants.QUANT_OPS_KEY_MAP + and user.target not in QuantConstants.DEQUANT_OPS_KEY_MAP + ): + res_list.append(node) + else: + _impl(user, res_list) + + res_list: List[Node] = [] + for user in node.users: + _impl(user, res_list) + return res_list + + def _propagate_quant_params(self, node: Node): + assert ( + quantize_attrs := node.meta.get("quantize_attrs") + ), "Must be annotated node." 
+ requantize_map: Dict[Node, Node] = node.meta.get("requantize", {}) + while node.users: + if len(node.users) != 1: + break + user = list(node.users.keys())[0] + if ( + user.target not in QuantConstants.QUANT_OPS_KEY_MAP + and user.target not in QuantConstants.DEQUANT_OPS_KEY_MAP + ): + break + node = user + # Case1: ...-q-dq(cur)-propagate_node-node(not d-dq) + # Case2: propagate_node(propagateed)-propagate_node-node(not q-dq) + for idx, user in enumerate(node.users.keys()): + # For the branch who need to be requantized, we propagate the requantize params + user_attrs = requantize_map.get(idx, quantize_attrs) + if user.target not in self.propagate_nodes: + continue + if len(user.users) == 1: + # Possibily no need for checking len(users)>1 + user_of_user = list(user.users)[0] + # node-q-dq-propagate-q-dq not need for propagatey + if ( + user_of_user.target in QuantConstants.QUANT_OPS_KEY_MAP + or user_of_user.target in QuantConstants.DEQUANT_OPS_KEY_MAP + ): + continue + # propagate quant for node-q-dq-propagate_node-node(not qdq) + user.meta["quantize_attrs"] = user_attrs + self._propagate_quant_params(user) + + def _annotate_requantize(self, node: Node): + assert ( + ori_quant_attrs := node.meta.get("quantize_attrs") + ), "No quant parameters found" + list_for_requantize = self._get_last_dqs(node) + node.meta["requantize"] = node.meta.get("requantize", {}) + + # We use index to mark the output to be requantized + # Because user obj and name may change when we requantize them. + + def _check_same(requant_obj, ori_obj) -> bool: + if type(requant_obj) != type(ori_obj): # noqa E721 + # We need actually same type here. + return False + if not isinstance(requant_obj, torch.Tensor): + return requant_obj == ori_obj + if requant_obj.shape != ori_obj.shape: + return False + return bool((requant_obj == ori_obj).all()) + + requantize_map: Dict[int, Dict] = node.meta["requantize"] + for idx, dq in enumerate(list_for_requantize): + q = dq.all_input_nodes[0] + if q.target not in QuantConstants.QUANT_OPS_KEY_MAP: + continue + key_map = QuantConstants.DEQUANT_OPS_KEY_MAP[dq.target] + requantize_attrs = self.get_quant_attrs(q, key_map) + if not all( + _check_same(ori_quant_attrs[key], requantize_attrs[key]) + for key in key_map.values() + ): + requantize_map[idx] = requantize_attrs + + def _annotate(self, graph_module: GraphModule): + for node in graph_module.graph.nodes: + key_map = QuantConstants.QUANT_OPS_KEY_MAP.get(node.target, None) + if not key_map: + continue + source_node = node.args[0] + if source_node.target in ( + *QuantConstants.QUANT_OPS_KEY_MAP, + *QuantConstants.DEQUANT_OPS_KEY_MAP, + ): + # Currently, don't add quant info for d_qd node here. + continue + elif source_node.target == operator.getitem: + source_node = source_node.args[0] + quant_attrs = self.get_quant_attrs(node, key_map) + source_node.meta["quantize_attrs"] = quant_attrs + self._annotate_requantize(source_node) + self._propagate_quant_params(source_node) + + def call(self, graph_module: GraphModule): + self._annotate(graph_module) + graph_module.recompile() + return PassResult(graph_module, True) + + def get_quant_attrs( + self, quant_node: torch.fx.Node, key_map: Optional[Dict] = None + ) -> Dict[str, Any]: + quant_attr_keys = [arg.name for arg in quant_node.target._schema.arguments] + quant_attrs = dict.fromkeys(quant_attr_keys) + for key, attr in zip(quant_attr_keys[1:], quant_node.args[1:]): + # For channel-wise quantization, params are stored by buffer nodes. 
+ if isinstance(attr, torch.fx.Node): + attr = get_buffer(self.edge_program, attr) + quant_attrs[key] = attr + quant_attrs["target"] = quant_node.target + if key_map is None: + return quant_attrs + miss_attrs = [] + for aten_attr, snc_attr in key_map.items(): + if aten_attr not in quant_attrs: + miss_attrs.append(aten_attr) + continue + attr = quant_attrs[aten_attr] + quant_attrs.pop(aten_attr) + quant_attrs[snc_attr] = attr + assert ( + not miss_attrs + ), f"Miss quant attrs {miss_attrs} for node {quant_node.name}" + return quant_attrs diff --git a/backends/samsung/_passes/annotate_scalar_parameters.py b/backends/samsung/_passes/annotate_scalar_parameters.py new file mode 100644 index 00000000000..643685bdb25 --- /dev/null +++ b/backends/samsung/_passes/annotate_scalar_parameters.py @@ -0,0 +1,65 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.samsung.quantizer.quantizer import global_quant_info +from executorch.backends.samsung.utils.constants import QuantConstants +from executorch.backends.transforms.utils import get_param_tensor, is_param_node +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch.export import ExportedProgram + + +class AnnotateScalarParametersPass(ExportPass): + """ + Need to add quantization parameters for scalars for some ops + Ifm(Quantized)------TargetOP--- + Scalar(Non-Quant)---/ + Notice: Such scalars are converted to tensor node by default pass + """ + + TARGET_OPS = { + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.div.Tensor, + } + + def __init__(self, edge_program: ExportedProgram): + super().__init__() + self.edge_program = edge_program + + def annotate(self, graph_module: torch.fx.GraphModule): + for node in graph_module.graph.nodes: + if node.target not in self.TARGET_OPS or "quantize_attrs" not in node.meta: + continue + torch_quant_dtype = global_quant_info.weight_precison.torch_dtype + for input_arg in node.all_input_nodes: + if input_arg.op not in ("placeholder", "get_attr") or not is_param_node( + self.edge_program, input_arg + ): + continue + else: + tensor = get_param_tensor(self.edge_program, input_arg) + if not tensor.shape: + qparams = { + QuantConstants.QUANT_KEY.scale: float(tensor), + QuantConstants.QUANT_KEY.quant_dtype: torch_quant_dtype, + QuantConstants.QUANT_KEY.quant_max: torch.iinfo( + torch_quant_dtype + ).max, + QuantConstants.QUANT_KEY.quant_min: torch.iinfo( + torch_quant_dtype + ).min, + QuantConstants.QUANT_KEY.zero_point: 0, + } + input_arg.meta["quantize_attrs"] = qparams + + def call(self, graph_module: torch.fx.GraphModule): + graph = graph_module.graph + self.annotate(graph_module) + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/samsung/_passes/conv1d_to_conv2d.py b/backends/samsung/_passes/conv1d_to_conv2d.py index 57f1074b348..1b8782d956b 100644 --- a/backends/samsung/_passes/conv1d_to_conv2d.py +++ b/backends/samsung/_passes/conv1d_to_conv2d.py @@ -5,84 +5,93 @@ # LICENSE file in the root directory of this source tree. 
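The two annotation passes above read quantization parameters off Q/DQ nodes and stash them in node.meta["quantize_attrs"] on the producing node. A toy torch.fx version of that metadata hand-off, using torch.quantize_per_tensor purely as a stand-in for the edge-dialect quantize ops:

import torch
import torch.fx as fx

class M(torch.nn.Module):
    def forward(self, x):
        y = x + 1.0
        return torch.quantize_per_tensor(y, 0.1, 0, torch.quint8).dequantize()

gm = fx.symbolic_trace(M())
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.quantize_per_tensor:
        src, scale, zero_point, dtype = node.args
        # Annotate the producer, as AnnotateQparamsPass does for the source node.
        src.meta["quantize_attrs"] = {
            "scale": scale,
            "zero_point": zero_point,
            "dtype": dtype,
        }

print([n.name for n in gm.graph.nodes if "quantize_attrs" in n.meta])  # ['add']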
import torch +from executorch.backends.transforms.utils import get_param_tensor from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult -from torch._export.utils import get_param class Conv1dToConv2d(ExportPass): - def __init__(self, edge_program: ExportedProgram): super().__init__() self.edge_program = edge_program + def update_kernel(self, weight_node: torch.Tensor): + # lifted tensor in tensor constant + weight_3d = get_param_tensor(self.edge_program, weight_node) + if param_name := self.edge_program.graph_signature.inputs_to_parameters.get( + weight_node.name + ): + new_weight_param = torch.nn.Parameter( + data=weight_3d.data.contiguous().unsqueeze(dim=-1), requires_grad=False + ) + self.edge_program.state_dict[param_name] = new_weight_param + elif tensor_name := self.edge_program.graph_signature.inputs_to_lifted_tensor_constants.get( + weight_node.name + ): + self.edge_program.constants[tensor_name] = torch.unsqueeze(weight_3d, -1) + else: + RuntimeError("Weight of 1d conv should be constant tensor or Parameter obj") + weight_node.meta["val"] = weight_node.meta["val"].data.unsqueeze(dim=-1) + def call(self, graph_module: torch.fx.GraphModule): graph = graph_module.graph node_list = list(graph.nodes) for node in node_list: - if node.op == "call_function": - if node.target == exir_ops.edge.aten.convolution.default: - stride = list(node.args[3]) - if len(stride) != 1: - continue + if node.op != "call_function": + continue + if node.target != exir_ops.edge.aten.convolution.default: + continue + stride = list(node.args[3]) + if len(stride) != 1: + continue - # convert 3dim weight to 4dim - weight_node = node.args[1] - weight_3dim = get_param(self.edge_program, weight_node) - weight_4dim = torch.nn.Parameter( - data=weight_3dim.data.contiguous().unsqueeze(dim=-1), - requires_grad=False, - ) - parameter_name = ( - self.edge_program.graph_signature.inputs_to_parameters[ - weight_node.name - ] - ) - self.edge_program.state_dict[parameter_name] = weight_4dim - weight_node.meta["val"] = weight_node.meta["val"].data.unsqueeze( - dim=-1 - ) + # convert 3dim weight to 4dim + weight_node = node.args[1] + self.update_kernel(weight_node) - # Extend stride, padding, and dilation - node.args = ( - node.args[0], - node.args[1], - node.args[2], - node.args[3] + [1], # stride - node.args[4] + [0], # padding - node.args[5] + [1], # dilation - node.args[6], - node.args[7], - node.args[8], - ) + # Extend stride, padding, and dilation + node.args = ( + node.args[0], + node.args[1], + node.args[2], + node.args[3] + [1], # stride + node.args[4] + [0], # padding + node.args[5] + [1], # dilation + node.args[6], + node.args[7], + node.args[8], + ) + # unsqueeze -> conv2d -> squeeze - # unsqueeze -> conv2d -> squeeze - with graph.inserting_before(node): - input_node = node.args[0] - unsqueeze_before = graph.create_node( - "call_function", exir_ops.edge.aten.unsqueeze_copy.default - ) - unsqueeze_before.args = ( - input_node, - -1, - ) - node.replace_input_with(input_node, unsqueeze_before) + with graph.inserting_before(node): + input_node = node.args[0] + prev_qparams = input_node.meta.get("quantize_attrs") + unsqueeze_before = graph.create_node( + "call_function", exir_ops.edge.aten.unsqueeze_copy.default + ) + unsqueeze_before.args = ( + input_node, + -1, + ) + node.replace_input_with(input_node, unsqueeze_before) - with graph.inserting_after(node): - squeeze_after = graph.create_node( - "call_function", 
exir_ops.edge.aten.squeeze_copy.dims - ) - squeeze_after.args = ( - node, - [-1], - ) - original_users = [ - user for user in node.users if user != squeeze_after - ] - for user in original_users: - user.replace_input_with(node, squeeze_after) + with graph.inserting_after(node): + squeeze_after = graph.create_node( + "call_function", exir_ops.edge.aten.squeeze_copy.dims + ) + squeeze_after.args = ( + node, + [-1], + ) + original_users = [user for user in node.users if user != squeeze_after] + for user in original_users: + user.replace_input_with(node, squeeze_after) + if quant_attr := node.meta.get("quantize_attrs"): + squeeze_after.meta["quantize_attrs"] = quant_attr + if prev_qparams is not None: + unsqueeze_before.meta["quantize_attrs"] = prev_qparams graph_module.recompile() - graph_module = super().call(graph_module).graph_module + _ = super().call(graph_module).graph_module return PassResult(graph_module, True) diff --git a/backends/samsung/_passes/fold_qdq.py b/backends/samsung/_passes/fold_qdq.py new file mode 100644 index 00000000000..c6f3699ece7 --- /dev/null +++ b/backends/samsung/_passes/fold_qdq.py @@ -0,0 +1,36 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.backends.samsung.utils.constants import QuantConstants +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.passes import dead_code_elimination_pass +from torch.fx import GraphModule + + +class FoldQDQPass(ExportPass): + def __init__(self): + super().__init__() + + def _fold( + self, + graph_module: GraphModule, + ): + for node in graph_module.graph.nodes: + if node.target not in ( + *QuantConstants.QUANT_OPS_KEY_MAP.keys(), + *QuantConstants.DEQUANT_OPS_KEY_MAP.keys(), + ): + continue + for user in [user for user in node.users.keys()]: # noqa: C416 + user.replace_input_with(node, node.args[0]) + graph_module.graph.erase_node(node) + + def call(self, graph_module: GraphModule): + self._fold(graph_module) + graph_module.recompile() + dead_code_elimination_pass(graph_module) + _ = super().call(graph_module).graph_module + return PassResult(graph_module, True) diff --git a/backends/samsung/_passes/fuse_conv_act.py b/backends/samsung/_passes/fuse_conv_act.py new file mode 100644 index 00000000000..c034c98bb14 --- /dev/null +++ b/backends/samsung/_passes/fuse_conv_act.py @@ -0,0 +1,77 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
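The rewritten Conv1dToConv2d pass above expresses a 1-D convolution as unsqueeze, conv2d, squeeze, widening the kernel, stride, padding, and dilation with a trailing unit dimension. A standalone numerical check of that equivalence:

import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 32)   # (N, C_in, L)
w = torch.randn(16, 8, 3)   # (C_out, C_in, K)

ref = F.conv1d(x, w, stride=1, padding=1)
# Same computation as conv2d, mirroring the pass: unsqueeze a trailing width
# dimension on input and kernel, extend stride/padding, then squeeze it back.
out = F.conv2d(x.unsqueeze(-1), w.unsqueeze(-1), stride=(1, 1), padding=(1, 0)).squeeze(-1)

torch.testing.assert_close(ref, out)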
+ +from typing import Optional + +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.passes import dead_code_elimination_pass +from torch.fx import GraphModule + + +def map_hardtan_relux(tanhnode: torch.fx.node.Node) -> Optional[str]: + assert ( + tanhnode.target == exir_ops.edge.aten.hardtanh.default + ), "Must be a hardtanh node" + if not tanhnode.args[1] == 0.0: + return None + if tanhnode.args[2] == 6.0: + return "RELU6" + return None + + +class FuseConvActPass(ExportPass): + TARGET_ACTS_MAP = { + exir_ops.edge.aten.relu.default: (lambda x: "RELU"), + exir_ops.edge.aten.relu_.default: (lambda x: "RELU"), + exir_ops.edge.aten.relu6.default: (lambda x: "RELU6"), + exir_ops.edge.aten.relu6_.default: (lambda x: "RELU6"), + exir_ops.edge.aten.hardtanh.default: map_hardtan_relux, + exir_ops.edge.aten.hardtanh_.default: map_hardtan_relux, + } + + def _fuse( + self, + graph_module: GraphModule, + ): + for target_conv, target_act in self.get_target_conv_act(graph_module): + assert ( + act_name := self.TARGET_ACTS_MAP.get(target_act.target)(target_act) + ), f"Not supported {target_act.name} now." + target_conv.meta["activation"] = act_name + if "quantize_attrs" in target_act.meta: + target_conv.meta["quantize_attrs"] = target_act.meta["quantize_attrs"] + + # If we merge the real out activation to conv, the conv should be the real out + if "real_out" in target_act.meta: + target_conv.meta["real_out"] = target_act.meta["real_out"] + for user in [user for user in target_act.users.keys()]: # noqa: C416 + user.replace_input_with(target_act, target_conv) + graph_module.graph.erase_node(target_act) + + def get_target_conv_act(self, graph_module: GraphModule): + for node in graph_module.graph.nodes: + if node.target != exir_ops.edge.aten.convolution.default: + continue + if len(node.users) != 1: + # Such cases couldn't be conv + act + continue + act_node = list(node.users.keys())[0] + if act_node.target not in self.TARGET_ACTS_MAP: + continue + if "quantize_attrs" in node.meta: + # If the conv's output is quantized + # We do not fuse them + continue + yield node, act_node + + def call(self, graph_module: GraphModule): + self._fuse(graph_module) + graph_module.recompile() + dead_code_elimination_pass(graph_module) + _ = super().call(graph_module).graph_module + return PassResult(graph_module, True) diff --git a/backends/samsung/_passes/insert_qdq.py b/backends/samsung/_passes/insert_qdq.py new file mode 100644 index 00000000000..a59b011ac4b --- /dev/null +++ b/backends/samsung/_passes/insert_qdq.py @@ -0,0 +1,164 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
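FuseConvActPass above only folds hardtanh into the preceding convolution when it is exactly hardtanh(x, 0.0, 6.0), which is the same function as ReLU6. A quick numerical confirmation of that equivalence:

import torch
import torch.nn.functional as F

x = torch.randn(4, 8) * 10
torch.testing.assert_close(
    F.hardtanh(x, min_val=0.0, max_val=6.0),
    F.relu6(x),
)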
+ +from enum import Enum +from typing import Any, Dict + +import torch +from executorch.backends.samsung._passes.utils import none_quant_tensor_quant_meta +from executorch.backends.samsung.utils.constants import QuantConstants +from executorch.backends.samsung.utils.utils import is_graph_input, is_graph_output + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch.export import ExportedProgram +from torch.fx import GraphModule + + +class QType(Enum): + Quant = 0 + Dequant = 1 + + +class InsertQDQPass(ExportPass): + QDQ_MAP = { + # per tensor + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor: exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, + # per channel + exir_ops.edge.quantized_decomposed.quantize_per_channel.default: exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + } + + def __init__(self, edge_program: ExportedProgram): + super().__init__() + self.edge_program = edge_program + + def _create_qdq_node( + self, + graph_module: GraphModule, + qtype: QType, + input_node: torch.fx.Node, + quant_attrs: Dict[str, Any], + ) -> torch.fx.Node: + assert (target := quant_attrs.get("target")), "" + new_node_args = [input_node] + new_node_meta_val = input_node.meta["val"] + new_node_quant_attrs = {} + if qtype == QType.Dequant: + target = self.QDQ_MAP[target] + else: + # For input node, we should set the val type as quant type + key = QuantConstants.QUANT_KEY.quant_dtype + new_node_meta_val = new_node_meta_val.to(quant_attrs[key]) + new_node_quant_attrs.update(quant_attrs) + + for arg in target._schema.arguments[1:]: + name = arg.name + if name == "out_dtype": + continue + if qtype == QType.Quant: + key = QuantConstants.QUANT_OPS_KEY_MAP[target].get(name, name) + else: + key = QuantConstants.DEQUANT_OPS_KEY_MAP[target].get(name, name) + arg_value = quant_attrs[key] + if isinstance(arg.type, torch.Tensor) and ( + isinstance(arg_value, int) or isinstance(arg_value, float) + ): + arg_value = torch.Tensor(arg_value) + new_node_args.append(arg_value) + + new_node = graph_module.graph.create_node( + "call_function", target, tuple(new_node_args) + ) + if new_node_quant_attrs: + new_node.meta["quantize_attrs"] = new_node_quant_attrs + else: + new_node.meta["quantize_attrs"] = { + QuantConstants.QUANT_KEY.quant_dtype: torch.float32, + QuantConstants.QUANT_KEY.scale: [1.0], + QuantConstants.QUANT_KEY.zero_point: [0], + } + new_node.meta["val"] = new_node_meta_val + return new_node + + def _add_dq_after(self, graph_module: GraphModule, node: torch.fx.Node): + if not (quant_attrs := node.meta.get("quantize_attrs")): + return + with graph_module.graph.inserting_after(node): + new_node = self._create_qdq_node( + graph_module, QType.Dequant, node, quant_attrs + ) + users = [user for user in node.users.keys() if (user.op == "output")] + for user in users: + user.replace_input_with(node, new_node) + + def _add_q_after(self, graph_module: GraphModule, node: torch.fx.Node): + # In node don't need quant attrs after insert new quantize node. 
+ if not (quant_attrs := node.meta.pop("quantize_attrs", None)): + return + node.meta["quantize_attrs"] = none_quant_tensor_quant_meta() + with graph_module.graph.inserting_after(node): + users = list(node.users.keys()) + new_node = self._create_qdq_node( + graph_module, QType.Quant, node, quant_attrs + ) + for user in users: + if user.target not in QuantConstants.QUANT_OPS_KEY_MAP: + user.replace_input_with(node, new_node) + + def _add_q_before( + self, + graph_module: GraphModule, + node: torch.fx.Node, + from_node: torch.fx.Node, + quantize_attrs: Dict, + ): + with graph_module.graph.inserting_before(node): + new_quant_node = self._create_qdq_node( + graph_module, QType.Quant, from_node, quantize_attrs + ) + node.replace_input_with(from_node, new_quant_node) + return new_quant_node + + def _add_dq_before( + self, + graph_module: GraphModule, + node: torch.fx.Node, + from_node: torch.fx.Node, + quantize_attrs: Dict, + ): + with graph_module.graph.inserting_before(node): + new_dequant_node = self._create_qdq_node( + graph_module, QType.Dequant, from_node, quantize_attrs + ) + node.replace_input_with(from_node, new_dequant_node) + return new_dequant_node + + def _add_qdq_for_requantize(self, graph_module: GraphModule): + for node in graph_module.graph.nodes: + requant_map: Dict[int, Dict] = node.meta.get("requantize") + if requant_map is None: + continue + assert (ori_quant_attrs := node.meta.get("quantize_attrs")) + usr_list = list(node.users.keys()) + for user_idx, requant_params in requant_map.items(): + user = usr_list[user_idx] + q_node = self._add_q_before(graph_module, user, node, requant_params) + _ = self._add_dq_before(graph_module, q_node, node, ori_quant_attrs) + + def _add_qdq(self, graph_module: GraphModule): + for node in list(graph_module.graph.nodes): + if is_graph_input(self.edge_program, node): + self._add_q_after(graph_module, node) + elif is_graph_output(node): + self._add_dq_after(graph_module, node) + + def call(self, graph_module: GraphModule): + self._add_qdq(graph_module) + self._add_qdq_for_requantize(graph_module) + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/samsung/_passes/remove_useless_ops.py b/backends/samsung/_passes/remove_useless_ops.py new file mode 100644 index 00000000000..c88a2d4a5d8 --- /dev/null +++ b/backends/samsung/_passes/remove_useless_ops.py @@ -0,0 +1,87 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
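The Q/DQ nodes that InsertQDQPass materializes at graph inputs and outputs stand for ordinary affine quantization. The round trip below shows the arithmetic they represent, written as plain tensor math rather than the edge-dialect ops themselves:

import torch

def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    return torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.int8)

def dequantize(q, scale, zero_point):
    return (q.to(torch.float32) - zero_point) * scale

x = torch.randn(2, 3)
scale, zero_point = 0.05, 0
roundtrip = dequantize(quantize(x, scale, zero_point), scale, zero_point)
print((roundtrip - x).abs().max())  # at most scale / 2 for values inside the clamp range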
+ +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.passes import dead_code_elimination_pass +from torch.fx import GraphModule + + +class RemoveUselessOpPass(ExportPass): + # such ops should be single-in and single-out + USELESS_OP_SET = { + exir_ops.edge.aten._to_copy.default, + exir_ops.edge.aten.clone.default, + exir_ops.edge.aten.clone.default, + exir_ops.edge.aten.alias.default, + exir_ops.edge.aten.lift_fresh_copy.default, + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + } + + def __init__(self): + super().__init__() + + def gen_pattern_as_strided_copy(self, graph_module: GraphModule): + for node in list(graph_module.graph.nodes): # noqa: C416 + if node.target != exir_ops.edge.aten.mean.dim: + continue + if len(node.users) != 1: + continue + successor = list(node.users.keys())[0] + if successor.target != exir_ops.edge.aten.as_strided_copy.default: + continue + is_pattern = True + count = 0 + for i, stride in enumerate(successor.args[2]): + if stride < node.meta["val"].size()[i]: + if stride == 1: + count += 1 + else: + is_pattern = False + break + if count >= 2: + is_pattern = False + break + if is_pattern: + yield successor + + def _fold_as_strided_copy( + self, + graph_module: GraphModule, + ): + for as_strided_copy_node in self.gen_pattern_as_strided_copy(graph_module): + for user in list(as_strided_copy_node.users.keys()): + user.replace_input_with( + as_strided_copy_node, as_strided_copy_node.args[0] + ) + graph_module.graph.erase_node(as_strided_copy_node) + + def _remove_useless( + self, + graph_module: GraphModule, + ): + for node in graph_module.graph.nodes: + if node.target not in self.USELESS_OP_SET: + continue + + # Prevent from removing if data type may change. + if ( + node.target == exir_ops.edge.aten._to_copy.default + or node.target == exir_ops.edge.dim_order_ops._to_dim_order_copy.default + ) and "memory_format" not in node.kwargs: + continue + + for user in [user for user in node.users.keys()]: # noqa: C416 + user.replace_input_with(node, node.all_input_nodes[0]) + graph_module.graph.erase_node(node) + self._fold_as_strided_copy(graph_module) + + def call(self, graph_module: GraphModule): + self._remove_useless(graph_module) + graph_module.recompile() + dead_code_elimination_pass(graph_module) + _ = super().call(graph_module).graph_module + return PassResult(graph_module, True) diff --git a/backends/samsung/_passes/utils.py b/backends/samsung/_passes/utils.py new file mode 100644 index 00000000000..afa7c72c601 --- /dev/null +++ b/backends/samsung/_passes/utils.py @@ -0,0 +1,15 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
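RemoveUselessOpPass drops single-input, single-output no-ops such as clone and alias by rewiring their users to the op's input and erasing the node. The same rewiring on a vanilla torch.fx graph, outside the edge dialect, looks like this:

import torch
import torch.fx as fx

class M(torch.nn.Module):
    def forward(self, x):
        return x.clone().relu()

gm = fx.symbolic_trace(M())
for node in list(gm.graph.nodes):
    if node.op == "call_method" and node.target == "clone":
        node.replace_all_uses_with(node.args[0])  # users now read clone's input directly
        gm.graph.erase_node(node)
gm.graph.lint()
gm.recompile()

x = torch.randn(3)
torch.testing.assert_close(gm(x), torch.relu(x))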
+ +import torch + + +def none_quant_tensor_quant_meta(): + return { + "quant_dtype": torch.float32, + "scales": 1, + "zero_points": 0, + } diff --git a/backends/samsung/build.sh b/backends/samsung/build.sh index dfa6407ff50..4845c760f0c 100755 --- a/backends/samsung/build.sh +++ b/backends/samsung/build.sh @@ -45,6 +45,7 @@ function build_x86_64() { -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -S ${PROJECT_DIR} \ -B ${X86_64_BUILD_DIR} @@ -77,6 +78,7 @@ function build_android() { -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ + -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DEXECUTORCH_ENABLE_LOGGING=1 \ -DEXECUTORCH_BUILD_DEVTOOLS=ON \ diff --git a/backends/samsung/builders/__init__.py b/backends/samsung/builders/__init__.py index 02a457fd06e..978da82b370 100644 --- a/backends/samsung/builders/__init__.py +++ b/backends/samsung/builders/__init__.py @@ -14,11 +14,13 @@ op_clamp, op_constant_pad_nd, op_conv2d, + op_dequantize, op_div, op_embedding, op_expand_copy, op_gelu, op_getitem, + op_hardsigmoid, op_hardswish, op_hardtanh, op_layer_norm, @@ -32,6 +34,7 @@ op_mul, op_permute, op_pixel_shuffle, + op_quantize, op_relu, op_reshape, op_rsqrt, @@ -57,6 +60,7 @@ op_clamp, op_conv2d, op_constant_pad_nd, + op_dequantize, op_div, op_embedding, op_expand_copy, @@ -64,6 +68,7 @@ op_getitem, op_hardswish, op_hardtanh, + op_hardsigmoid, op_layer_norm, op_leaky_relu, op_linear, @@ -75,6 +80,7 @@ op_mul, op_permute, op_pixel_shuffle, + op_quantize, op_relu, op_reshape, op_rsqrt, diff --git a/backends/samsung/builders/node_visitor.py b/backends/samsung/builders/node_visitor.py index a35c0b4715d..0d2707da8f5 100644 --- a/backends/samsung/builders/node_visitor.py +++ b/backends/samsung/builders/node_visitor.py @@ -14,6 +14,7 @@ get_tensor_type, ) from executorch.backends.samsung.serialization.enn_graph_schema import EnnGraph +from executorch.backends.samsung.utils.constants import QuantConstants from executorch.backends.transforms.utils import is_param_node from torch.export import ExportedProgram @@ -61,18 +62,26 @@ def define_tensor( dims = [1] if len(tensor.size()) == 0 else list(tensor.size()) + quant_attrs = node.meta.get("quantize_attrs") enn_tensor_id = enn_graph.define_tensor( node.name, dims, data_type, tensor_type.name, const_data, + quant_param=quant_attrs, ) assert enn_tensor_id is not None vals_to_ids[node] = enn_tensor_id return enn_tensor_id + def _update_params_qdtype(self, node: torch.fx.Node, params: Dict): + if qdtype := node.meta.get("quantize_attrs", {}).get( + QuantConstants.QUANT_KEY.quant_dtype + ): + params["quant_dtype"] = EnnGraph._affine_meta_param(qdtype) + _node_visitor_dict = {} @@ -92,6 +101,7 @@ def register_node_visitor(visitor): raise TypeError( f"target of vistor should be str|Tuple[str]|List[str], not{type(visitor.target)}" ) + return visitor def get_node_visitors(*args) -> Dict[str, NodeVisitor]: diff --git a/backends/samsung/builders/op_add.py b/backends/samsung/builders/op_add.py index 1b0dddb0d02..a6eb79897dd 100644 --- a/backends/samsung/builders/op_add.py +++ b/backends/samsung/builders/op_add.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
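Note: most visitors in this patch now build a params dict, let _update_params_qdtype stamp the quantized dtype into it when the node carries quantize_attrs, and forward the dict to define_op. The helper's lookup pattern is sketched below with plain dicts; the key name and the dtype-to-string mapping are assumptions standing in for QuantConstants and EnnGraph._affine_meta_param.

import torch

QUANT_DTYPE_KEY = "quant_dtype"  # assumed key, mirroring QuantConstants.QUANT_KEY

def update_params_qdtype(node_meta: dict, params: dict) -> None:
    # Only annotate the op params when the node was tagged with quantize_attrs
    # that carry a dtype entry; float graphs leave params untouched.
    if qdtype := node_meta.get("quantize_attrs", {}).get(QUANT_DTYPE_KEY):
        params[QUANT_DTYPE_KEY] = {torch.int8: "AINT8"}.get(qdtype, str(qdtype))

params = {}
update_params_qdtype({"quantize_attrs": {QUANT_DTYPE_KEY: torch.int8}}, params)
print(params)  # {'quant_dtype': 'AINT8'}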
+ from typing import Dict import torch @@ -28,9 +29,13 @@ def define_node( ) -> None: input1 = node.args[0] input_id_1 = self.define_tensor(input1, enn_graph, vals_to_ids) + params = {} + self._update_params_qdtype(node, params) input2 = node.args[1] input_id_2 = self.define_tensor(input2, enn_graph, vals_to_ids) output_id = self.define_tensor(node, enn_graph, vals_to_ids) - enn_graph.define_op(node.name, "ELTSUM", [input_id_1, input_id_2], [output_id]) + enn_graph.define_op( + node.name, "ELTSUM", [input_id_1, input_id_2], [output_id], params + ) diff --git a/backends/samsung/builders/op_avg_pool2d.py b/backends/samsung/builders/op_avg_pool2d.py index ad7ccbac3ae..bfca8b89b22 100644 --- a/backends/samsung/builders/op_avg_pool2d.py +++ b/backends/samsung/builders/op_avg_pool2d.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + from typing import cast, Dict, List import torch @@ -49,6 +50,7 @@ def define_node( params["stride_w"] = stride[1] params["padding"] = "EXPLICIT" params["explicit_padding"] = explicit_padding + self._update_params_qdtype(node, params) if len(node.args) > 4: ceil_mode = cast(bool, node.args[4]) @@ -64,7 +66,5 @@ def define_node( assert ( divisor_override == kernel_size[0] * kernel_size[1] ), "Not supported divisor_override which is not equal to pooling region." - output_id = self.define_tensor(node, enn_graph, vals_to_ids) - enn_graph.define_op(node.name, "AVGPOOL2D", [input_id], [output_id], params) diff --git a/backends/samsung/builders/op_bmm.py b/backends/samsung/builders/op_bmm.py index 6ba8864ebb3..13e0d19cb14 100644 --- a/backends/samsung/builders/op_bmm.py +++ b/backends/samsung/builders/op_bmm.py @@ -16,7 +16,7 @@ @register_node_visitor class BMMVisitor(NodeVisitor): - target = "aten.bmm.default" + target = ["aten.bmm.default"] def __init__(self, *args) -> None: super().__init__(*args) @@ -29,12 +29,15 @@ def define_node( ) -> None: input1 = node.args[0] input_id_1 = self.define_tensor(input1, enn_graph, vals_to_ids) + input2 = node.args[1] input_id_2 = self.define_tensor(input2, enn_graph, vals_to_ids) # output output_id = self.define_tensor(node, enn_graph, vals_to_ids) + params = {} + self._update_params_qdtype(node, params) enn_graph.define_op( - node.name, "BATCH_MATMUL", [input_id_1, input_id_2], [output_id] + node.name, "BATCH_MATMUL", [input_id_1, input_id_2], [output_id], params ) diff --git a/backends/samsung/builders/op_cat.py b/backends/samsung/builders/op_cat.py index e9c0a32b389..09387f2e361 100644 --- a/backends/samsung/builders/op_cat.py +++ b/backends/samsung/builders/op_cat.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
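Note: two small but load-bearing changes appear in this patch: register_node_visitor now returns the class, so the decorated name is no longer rebound to None, and a visitor target may be a single string or a list, as BMMVisitor now uses. A reduced sketch of that registration contract, with the type checking omitted:

_node_visitor_dict = {}

def register_node_visitor(visitor):
    targets = (
        visitor.target if isinstance(visitor.target, (tuple, list)) else [visitor.target]
    )
    for target in targets:
        _node_visitor_dict[target] = visitor
    return visitor  # returning the class keeps the decorated name usable

@register_node_visitor
class BMMVisitor:
    target = ["aten.bmm.default"]

print(_node_visitor_dict["aten.bmm.default"] is BMMVisitor)  # True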
+ from typing import cast, Dict, List import torch @@ -12,6 +13,7 @@ ) from executorch.backends.samsung.serialization.enn_graph_schema import EnnGraph from executorch.backends.transforms import get_shape +from executorch.backends.transforms.utils import is_param_node @register_node_visitor @@ -29,14 +31,20 @@ def define_node( ) -> None: tensors = cast(List[torch.fx.Node], node.args[0]) input_tensor_ids = [] - - for in_tensor in tensors: + constant_idx = None + for idx, in_tensor in enumerate(tensors): + if is_param_node(self.exported_program, in_tensor): + assert constant_idx is None, "Only support at most 1 constant tensor" + constant_idx = idx input_id = self.define_tensor(in_tensor, enn_graph, vals_to_ids) input_tensor_ids.append(input_id) in_shape = get_shape(node) axis = cast(int, node.args[1]) % len(in_shape) if len(node.args) >= 2 else 0 params = {"axis": axis} + if constant_idx is not None: + params["constant_index"] = constant_idx + self._update_params_qdtype(node, params) output_id = self.define_tensor(node, enn_graph, vals_to_ids) enn_graph.define_op(node.name, "CONCAT", input_tensor_ids, [output_id], params) diff --git a/backends/samsung/builders/op_clamp.py b/backends/samsung/builders/op_clamp.py index c5670b80fa3..74af83212a5 100644 --- a/backends/samsung/builders/op_clamp.py +++ b/backends/samsung/builders/op_clamp.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + from typing import cast, Dict import torch @@ -32,12 +33,15 @@ def define_node( # The default value of lower bound and upper bound output_min = torch.finfo(torch.float32).min output_max = torch.finfo(torch.float32).max + if node.args[1] is not None: output_min = cast(float, node.args[1]) if len(node.args) > 2 and node.args[2] is not None: output_max = cast(float, node.args[2]) params = {"minimum": output_min, "maximum": output_max} + self._update_params_qdtype(node, params) + output_id = self.define_tensor(node, enn_graph, vals_to_ids) enn_graph.define_op(node.name, "CLIP", [input_id], [output_id], params) diff --git a/backends/samsung/builders/op_conv2d.py b/backends/samsung/builders/op_conv2d.py index 881a533801f..ab77d8df626 100644 --- a/backends/samsung/builders/op_conv2d.py +++ b/backends/samsung/builders/op_conv2d.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + from typing import cast, Dict, List import torch @@ -56,6 +57,9 @@ def define_node( input_shape = get_shape(input) kernel_shape = get_shape(weight_node) params = {} + self._update_params_qdtype(node, params) + if "activation" in node.meta: + params["activation"] = node.meta["activation"] params["kernel_h"] = kernel_shape[2] params["kernel_w"] = kernel_shape[3] params["stride_h"] = stride[0] diff --git a/backends/samsung/builders/op_dequantize.py b/backends/samsung/builders/op_dequantize.py new file mode 100644 index 00000000000..a1c31af4037 --- /dev/null +++ b/backends/samsung/builders/op_dequantize.py @@ -0,0 +1,19 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
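Note: the CONCAT visitor now records which input, if any, is a compile-time constant and passes its position as constant_index, asserting that at most one constant input is present. The bookkeeping reduces to the following, with the is_param_node results replaced by a plain boolean list for illustration:

inputs_are_const = [False, True, False]  # e.g. is_param_node(...) per cat input
constant_idx = None
for idx, is_const in enumerate(inputs_are_const):
    if is_const:
        assert constant_idx is None, "Only support at most 1 constant tensor"
        constant_idx = idx

params = {"axis": 1}
if constant_idx is not None:
    params["constant_index"] = constant_idx
print(params)  # {'axis': 1, 'constant_index': 1}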
+ +from executorch.backends.samsung.builders.node_visitor import register_node_visitor +from executorch.backends.samsung.builders.op_quantize import _QuantOpVistorBase + + +# Dequant ops here +@register_node_visitor +class DequantizeVistor(_QuantOpVistorBase): + target = [ + "quantized_decomposed.dequantize_per_tensor.default", + "quantized_decomposed.dequantize_per_tensor.tensor", + "quantized_decomposed.dequantize_per_channel.default", + "quantized_decomposed.dequantize_per_channel.tensor", + ] diff --git a/backends/samsung/builders/op_div.py b/backends/samsung/builders/op_div.py index 89d773ddb0e..8b0e7cdd5af 100644 --- a/backends/samsung/builders/op_div.py +++ b/backends/samsung/builders/op_div.py @@ -27,13 +27,16 @@ def define_node( enn_graph: EnnGraph, vals_to_ids: Dict[torch.Tensor, int], ) -> None: - # inputs input1 = node.args[0] input_id_1 = self.define_tensor(input1, enn_graph, vals_to_ids) + input2 = node.args[1] input_id_2 = self.define_tensor(input2, enn_graph, vals_to_ids) - + params = {} + self._update_params_qdtype(node, params) # output output_id = self.define_tensor(node, enn_graph, vals_to_ids) - enn_graph.define_op(node.name, "ELTDIV", [input_id_1, input_id_2], [output_id]) + enn_graph.define_op( + node.name, "ELTDIV", [input_id_1, input_id_2], [output_id], params + ) diff --git a/backends/samsung/builders/op_gelu.py b/backends/samsung/builders/op_gelu.py index 059a3b77850..88417f688f9 100644 --- a/backends/samsung/builders/op_gelu.py +++ b/backends/samsung/builders/op_gelu.py @@ -27,8 +27,14 @@ def define_node( enn_graph: EnnGraph, vals_to_ids: Dict[torch.Tensor, int], ) -> None: - input_id = self.define_tensor(node.args[0], enn_graph, vals_to_ids) + # input1 + input = node.args[0] + input_id = self.define_tensor(input, enn_graph, vals_to_ids) + # output output_id = self.define_tensor(node, enn_graph, vals_to_ids) - enn_graph.define_op(node.name, "GELU", [input_id], [output_id]) + params = {} + self._update_params_qdtype(node, params) + + enn_graph.define_op(node.name, "GELU", [input_id], [output_id], params) diff --git a/backends/samsung/builders/op_hardsigmoid.py b/backends/samsung/builders/op_hardsigmoid.py new file mode 100644 index 00000000000..3a50d65da41 --- /dev/null +++ b/backends/samsung/builders/op_hardsigmoid.py @@ -0,0 +1,35 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Dict + +import torch +from executorch.backends.samsung.builders.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.samsung.serialization.enn_graph_schema import EnnGraph + + +@register_node_visitor +class HardSigmoidVisitor(NodeVisitor): + target = "aten.hardsigmoid.default" + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + enn_graph: EnnGraph, + vals_to_ids: Dict[torch.Tensor, int], + ) -> None: + input = node.args[0] + input_id = self.define_tensor(input, enn_graph, vals_to_ids) + output_id = self.define_tensor(node, enn_graph, vals_to_ids) + params = {} + self._update_params_qdtype(node, params) + enn_graph.define_op(node.name, "HardSigmoid", [input_id], [output_id], params) diff --git a/backends/samsung/builders/op_hardswish.py b/backends/samsung/builders/op_hardswish.py index 72a99d17b83..8c30125e8a4 100644 --- a/backends/samsung/builders/op_hardswish.py +++ b/backends/samsung/builders/op_hardswish.py @@ -29,7 +29,7 @@ def define_node( ) -> None: input = node.args[0] input_id = self.define_tensor(input, enn_graph, vals_to_ids) - + params = {} + self._update_params_qdtype(node, params) output_id = self.define_tensor(node, enn_graph, vals_to_ids) - - enn_graph.define_op(node.name, "HARDSWISH", [input_id], [output_id]) + enn_graph.define_op(node.name, "HARDSWISH", [input_id], [output_id], params) diff --git a/backends/samsung/builders/op_hardtanh.py b/backends/samsung/builders/op_hardtanh.py index 4f667bf5299..7d65e97a566 100644 --- a/backends/samsung/builders/op_hardtanh.py +++ b/backends/samsung/builders/op_hardtanh.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + from typing import cast, Dict import torch @@ -29,9 +30,12 @@ def define_node( input = node.args[0] input_id = self.define_tensor(input, enn_graph, vals_to_ids) + # default value of output_min and output_max output_min = cast(float, node.args[1]) if len(node.args) > 1 else -1 output_max = cast(float, node.args[2]) if len(node.args) > 2 else 1 + params = {"minimum": output_min, "maximum": output_max} + self._update_params_qdtype(node, params) output_id = self.define_tensor(node, enn_graph, vals_to_ids) diff --git a/backends/samsung/builders/op_layer_norm.py b/backends/samsung/builders/op_layer_norm.py index e6f853178d8..098bc92dc84 100644 --- a/backends/samsung/builders/op_layer_norm.py +++ b/backends/samsung/builders/op_layer_norm.py @@ -46,9 +46,8 @@ def define_node( epsilon = node.args[4] if len(node.args) > 4 else 1e-5 params = {"epsilon": epsilon} - + self._update_params_qdtype(node, params) output_id = self.define_tensor(node, enn_graph, vals_to_ids) - enn_graph.define_op( node.name, "LAYERNORM", all_input_tensors, [output_id], params ) diff --git a/backends/samsung/builders/op_linear.py b/backends/samsung/builders/op_linear.py index 2f7aa1e6415..720439de976 100644 --- a/backends/samsung/builders/op_linear.py +++ b/backends/samsung/builders/op_linear.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+ from typing import Dict import torch @@ -43,6 +44,7 @@ def define_node( weight_shape = get_shape(weight_node) params = {"in_channels": weight_shape[1], "out_channels": weight_shape[0]} + self._update_params_qdtype(node, params) output_id = self.define_tensor(node, enn_graph, vals_to_ids) diff --git a/backends/samsung/builders/op_max_pool2d.py b/backends/samsung/builders/op_max_pool2d.py index d386dd30b1a..57b716fcb34 100644 --- a/backends/samsung/builders/op_max_pool2d.py +++ b/backends/samsung/builders/op_max_pool2d.py @@ -73,6 +73,7 @@ def define_node( params["explicit_padding"] = explicit_padding params["dilation_h"] = dilation[0] params["dilation_w"] = dilation[1] + self._update_params_qdtype(node, params) if len(node.args) > 5: ceil_mode = cast(bool, node.args[5]) diff --git a/backends/samsung/builders/op_mean_dim.py b/backends/samsung/builders/op_mean_dim.py index 2f07f870ec4..3d0377703a7 100644 --- a/backends/samsung/builders/op_mean_dim.py +++ b/backends/samsung/builders/op_mean_dim.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + from typing import cast, Dict, List import torch @@ -27,6 +28,7 @@ def define_node( enn_graph: EnnGraph, vals_to_ids: Dict[torch.Tensor, int], ) -> None: + # input input = node.args[0] input_id = self.define_tensor(input, enn_graph, vals_to_ids) @@ -37,8 +39,11 @@ def define_node( in_shape = get_shape(input) for dim in dims: reduce_axes.append(dim % len(in_shape)) - reduce_axes.sort() + + if len(node.args[1]) > 1: + reduce_axes.sort() keep_dim = node.args[2] if len(node.args) >= 3 else False params = {"keep_dims": keep_dim, "axis": reduce_axes} + self._update_params_qdtype(node, params) enn_graph.define_op(node.name, "REDUCEMEAN", [input_id], [output_id], params) diff --git a/backends/samsung/builders/op_mul.py b/backends/samsung/builders/op_mul.py index dce531ff0b0..6dd7c0dd9f0 100644 --- a/backends/samsung/builders/op_mul.py +++ b/backends/samsung/builders/op_mul.py @@ -1,5 +1,9 @@ -# Copyright (c) 2024 Samsung Electronics Co. LTD +# Copyright (c) 2025 Samsung Electronics Co. LTD # All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + from typing import Dict import torch @@ -23,11 +27,17 @@ def define_node( enn_graph: EnnGraph, vals_to_ids: Dict[torch.Tensor, int], ) -> None: + input1 = node.args[0] input_id_1 = self.define_tensor(input1, enn_graph, vals_to_ids) + input2 = node.args[1] input_id_2 = self.define_tensor(input2, enn_graph, vals_to_ids) + params = {} + self._update_params_qdtype(node, params) output_id = self.define_tensor(node, enn_graph, vals_to_ids) - enn_graph.define_op(node.name, "ELTMUL", [input_id_1, input_id_2], [output_id]) + enn_graph.define_op( + node.name, "ELTMUL", [input_id_1, input_id_2], [output_id], params + ) diff --git a/backends/samsung/builders/op_quantize.py b/backends/samsung/builders/op_quantize.py new file mode 100644 index 00000000000..dcf30e291f9 --- /dev/null +++ b/backends/samsung/builders/op_quantize.py @@ -0,0 +1,60 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
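Note: op_mean_dim now wraps negative dims to positive axes and only sorts the axis list when more than one dim was requested, so a single user-specified axis is forwarded unchanged. The normalization amounts to:

in_rank = 4          # len(get_shape(input))
dims = [-1, -2]      # node.args[1]
reduce_axes = [d % in_rank for d in dims]
if len(dims) > 1:
    reduce_axes.sort()
print(reduce_axes)   # [2, 3]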
+ +from typing import Dict + +import torch +from executorch.backends.samsung.builders.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.samsung.serialization.enn_graph_schema import EnnGraph +from executorch.backends.samsung.utils.constants import QuantConstants + + +class _QuantOpVistorBase(NodeVisitor): + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + enn_graph: EnnGraph, + vals_to_ids: Dict[torch.Tensor, int], + ) -> None: + # input + input = node.args[0] + input_id = self.define_tensor(input, enn_graph, vals_to_ids) + + scales = node.args[1] + if isinstance(scales, torch.Tensor): + scales = scales.tolist() + elif not isinstance(scales, list): + scales = torch.tensor(scales).reshape([1]).tolist() + zero_points = node.args[2] + if isinstance(zero_points, torch.Tensor): + zero_points = zero_points.tolist() + elif not isinstance(zero_points, list): + zero_points = torch.tensor(zero_points).reshape([1]).tolist() + + output_id = self.define_tensor(node, enn_graph, vals_to_ids) + + params = {"scales": scales, "zero_points": zero_points} + + if node.target in QuantConstants.QUANT_OPS_KEY_MAP: + enn_graph.define_op(node.name, "QUANTIZE", [input_id], [output_id], params) + else: + enn_graph.define_op( + node.name, "DEQUANTIZE", [input_id], [output_id], params + ) + + +@register_node_visitor +class QuantizeVistor(_QuantOpVistorBase): + target = [ + "quantized_decomposed.quantize_per_tensor.default", + "quantized_decomposed.quantize_per_channel.default", + ] diff --git a/backends/samsung/builders/op_relu.py b/backends/samsung/builders/op_relu.py index ba90116be1d..a4a2b6bc4f0 100644 --- a/backends/samsung/builders/op_relu.py +++ b/backends/samsung/builders/op_relu.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
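Note: _QuantOpVistorBase coerces the scale and zero-point arguments to plain Python lists whether they arrive as scalars, lists, or tensors, and then emits either a QUANTIZE or DEQUANTIZE op depending on whether the target appears in QuantConstants.QUANT_OPS_KEY_MAP. The coercion step in isolation:

import torch

def to_param_list(value):
    # Normalize scalar / list / tensor quant params to a flat Python list.
    if isinstance(value, torch.Tensor):
        return value.tolist()
    if isinstance(value, list):
        return value
    return torch.tensor(value).reshape([1]).tolist()

print(to_param_list(0.5))                        # [0.5]
print(to_param_list([0.5, 0.25]))                # [0.5, 0.25]
print(to_param_list(torch.tensor([0.5, 0.25])))  # [0.5, 0.25]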
+ from typing import Dict import torch @@ -30,5 +31,7 @@ def define_node( input_id = self.define_tensor(input, enn_graph, vals_to_ids) output_id = self.define_tensor(node, enn_graph, vals_to_ids) + params = {} + self._update_params_qdtype(node, params) - enn_graph.define_op(node.name, "RELU", [input_id], [output_id]) + enn_graph.define_op(node.name, "RELU", [input_id], [output_id], params) diff --git a/backends/samsung/builders/op_softmax.py b/backends/samsung/builders/op_softmax.py index 1e2e4a378dc..7f569cea6fc 100644 --- a/backends/samsung/builders/op_softmax.py +++ b/backends/samsung/builders/op_softmax.py @@ -35,5 +35,5 @@ def define_node( axis = cast(int, node.args[1]) params = {"axis": axis} - + self._update_params_qdtype(node, params) enn_graph.define_op(node.name, "SOFTMAX", [input_id], [output_id], params) diff --git a/backends/samsung/builders/op_squeeze.py b/backends/samsung/builders/op_squeeze.py index d165a22fcb3..82fa17fbc95 100644 --- a/backends/samsung/builders/op_squeeze.py +++ b/backends/samsung/builders/op_squeeze.py @@ -33,4 +33,5 @@ def define_node( # output output_id = self.define_tensor(node, enn_graph, vals_to_ids) - enn_graph.define_op(node.name, "RESHAPE", [input_id], [output_id]) + params = {"new_shape": [*node.meta["val"].shape]} + enn_graph.define_op(node.name, "RESHAPE", [input_id], [output_id], params) diff --git a/backends/samsung/builders/op_to_copy.py b/backends/samsung/builders/op_to_copy.py index 545672ef6a3..c770602bb5f 100644 --- a/backends/samsung/builders/op_to_copy.py +++ b/backends/samsung/builders/op_to_copy.py @@ -11,6 +11,8 @@ NodeVisitor, register_node_visitor, ) + +from executorch.backends.samsung.builders.utils import get_map_dtype, get_tensor from executorch.backends.samsung.serialization.enn_graph_schema import EnnGraph @@ -35,5 +37,8 @@ def define_node( input_id = self.define_tensor(input, enn_graph, vals_to_ids) output_id = self.define_tensor(node, enn_graph, vals_to_ids) + params = {} + out_tensor = get_tensor(self.exported_program, node) + params["out_dtype"] = get_map_dtype(out_tensor.dtype) - enn_graph.define_op(node.name, "CAST", [input_id], [output_id]) + enn_graph.define_op(node.name, "CAST", [input_id], [output_id], params) diff --git a/backends/samsung/builders/op_unsqueeze.py b/backends/samsung/builders/op_unsqueeze.py index 942c3307de7..61fa06e6310 100644 --- a/backends/samsung/builders/op_unsqueeze.py +++ b/backends/samsung/builders/op_unsqueeze.py @@ -31,4 +31,5 @@ def define_node( output_id = self.define_tensor(node, enn_graph, vals_to_ids) - enn_graph.define_op(node.name, "RESHAPE", [input_id], [output_id]) + params = {"new_shape": [*node.meta["val"].shape]} + enn_graph.define_op(node.name, "RESHAPE", [input_id], [output_id], params) diff --git a/backends/samsung/builders/op_upsample_bilinear2d.py b/backends/samsung/builders/op_upsample_bilinear2d.py index a934b2789ba..d4b040460e3 100644 --- a/backends/samsung/builders/op_upsample_bilinear2d.py +++ b/backends/samsung/builders/op_upsample_bilinear2d.py @@ -46,6 +46,7 @@ def define_node( "upsampling_factor": scale_factor, "half_pixel_centers": True, } + self._update_params_qdtype(node, params) output_id = self.define_tensor(node, enn_graph, vals_to_ids) enn_graph.define_op( node.name, "RESIZE_BILINEAR", [input_id], [output_id], params diff --git a/backends/samsung/builders/utils.py b/backends/samsung/builders/utils.py index 58c84ff6d31..a640071c798 100644 --- a/backends/samsung/builders/utils.py +++ b/backends/samsung/builders/utils.py @@ -9,7 +9,6 @@ import torch from 
executorch.backends.samsung.utils.utils import is_graph_input, is_graph_output from executorch.backends.transforms.utils import get_param_tensor, is_param_node - from torch.export import ExportedProgram DATA_TYPE_STR_MAPPING = { diff --git a/backends/samsung/enn_preprocess.py b/backends/samsung/enn_preprocess.py index dde01bc09c7..0847ec0adeb 100644 --- a/backends/samsung/enn_preprocess.py +++ b/backends/samsung/enn_preprocess.py @@ -9,10 +9,16 @@ import executorch.backends.samsung.python.PyEnnWrapperAdaptor as PyEnnWrapper import torch +from executorch.backends.samsung._passes.annotate_qparams import AnnotateQparamsPass +from executorch.backends.samsung._passes.annotate_scalar_parameters import ( + AnnotateScalarParametersPass, +) from executorch.backends.samsung._passes.conv1d_to_conv2d import Conv1dToConv2d from executorch.backends.samsung._passes.customized_constant_prop import ( ConstantPropPass, ) +from executorch.backends.samsung._passes.fold_qdq import FoldQDQPass +from executorch.backends.samsung._passes.insert_qdq import InsertQDQPass from executorch.backends.samsung._passes.replace_scalar_ops import ReplaceOpsWithScalar from executorch.backends.samsung.builders.node_visitor import get_node_visitors from executorch.backends.samsung.serialization.compile_options import ( @@ -53,12 +59,16 @@ def preprocess( enn_preprocess_passes = PassManager( passes=[ + AnnotateQparamsPass(edge_program), + FoldQDQPass(), ConstantPropPass(edge_program), Conv1dToConv2d(edge_program), FuseBatchNormWithConvPass(edge_program), AddmmToLinearTransform(), ReplaceOpsWithScalar(), RemoveGetItemPass(), + InsertQDQPass(edge_program), + AnnotateScalarParametersPass(edge_program), ] ) pass_result = enn_preprocess_passes(edge_program.graph_module) diff --git a/backends/samsung/partition/enn_partitioner.py b/backends/samsung/partition/enn_partitioner.py index 952cb000429..368d069c380 100644 --- a/backends/samsung/partition/enn_partitioner.py +++ b/backends/samsung/partition/enn_partitioner.py @@ -129,5 +129,6 @@ def ops_to_not_decompose( torch.ops.aten.prelu.default, torch.ops.aten.layer_norm.default, torch.ops.aten.pixel_shuffle.default, + torch.ops.aten.hardsigmoid.default, ] return (ops_not_to_decompose, None) diff --git a/backends/samsung/quantizer/__init__.py b/backends/samsung/quantizer/__init__.py new file mode 100644 index 00000000000..621eec69240 --- /dev/null +++ b/backends/samsung/quantizer/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .qconfig import Precision +from .quantizer import EnnQuantizer + +__all__ = [EnnQuantizer, Precision] diff --git a/backends/samsung/quantizer/annotator.py b/backends/samsung/quantizer/annotator.py new file mode 100644 index 00000000000..31015698006 --- /dev/null +++ b/backends/samsung/quantizer/annotator.py @@ -0,0 +1,871 @@ +# Copyright (c) Qualcomm Innovation Center, Inc +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
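Note: enn_preprocess.py now brackets the existing passes with the quantization-aware ones: qparams are annotated and the explicit Q/DQ ops folded before constant propagation and fusion run, and Q/DQ nodes are re-inserted (with scalar parameters annotated) at the end. The stand-alone sketch below only illustrates the chaining idea using torch.fx's generic PassManager; the backend itself uses exir's PassManager with the pass classes listed above.

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager

def annotate_pass(gm):
    # Stage 1: tag nodes (stands in for AnnotateQparamsPass).
    for node in gm.graph.nodes:
        node.meta.setdefault("annotated", True)
    return PassResult(gm, True)

def cleanup_pass(gm):
    # Stage 2: tidy up (stands in for the dead-code elimination the passes run).
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return PassResult(gm, True)

gm = symbolic_trace(torch.nn.Sequential(torch.nn.ReLU()))
result = PassManager(passes=[annotate_pass, cleanup_pass])(gm)
print(result.modified)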
+ +from typing import Callable, Dict, List + +import torch +from torch._ops import OpOverload +from torch._subclasses import FakeTensor + +from torch.fx import Graph, Node + +from torchao.quantization.pt2e import FixedQParamsObserver +from torchao.quantization.pt2e.quantizer import ( + annotate_output_qspec, + QuantizationAnnotation, + QuantizationSpec, + SharedQuantizationSpec, +) + +from .qconfig import QuantizationConfig + +OP_ANNOTATOR: Dict[OpOverload, Callable] = {} + +ADD_OPS = [ + torch.ops.aten.add, + torch.ops.aten.add.Tensor, + torch.ops.aten.add_.Tensor, +] + + +def register_annotator(ops: List[OpOverload]): + def decorator(annotator: Callable): + for op in ops: + OP_ANNOTATOR[op] = annotator + + return decorator + + +def annotate(graph: Graph, quant_config: QuantizationConfig) -> None: + # Pattern annotation + _annotate_fused_activation_pattern(graph, quant_config) + + # Per-op annotation + for node in graph.nodes: + if node.op == "placeholder": + annotate_placeholder(node, quant_config) + elif node.op == "call_function": + annotate_func = OP_ANNOTATOR.get(node.target, None) + if annotate_func is not None: + annotate_func(node, quant_config) + + +def _is_annotated(nodes: List[Node]): + """ + Given a list of nodes (that represents an operator pattern), + return True if any of the node + is annotated, otherwise return False + """ + annotated = False + for node in nodes: + annotated = annotated or ( + "quantization_annotation" in node.meta + and node.meta["quantization_annotation"]._annotated + ) + return annotated + + +def _is_fake_tensor(node: Node): + if ( + isinstance(node, Node) + and "val" in node.meta + and isinstance(node.meta["val"], FakeTensor) + ): + return True + return False + + +def _is_float_tensor(node: Node): + """Check if the node's tensor is a float tensor, + so that we can skip quantization for the node + since observers only works with float Tensors + """ + if not _is_fake_tensor(node): + return False + return node.meta["val"].dtype in [torch.float32, torch.float16] + + +def _mark_nodes_as_annotated(nodes: List[Node]): + for node in nodes: + if "quantization_annotation" not in node.meta: + node.meta["quantization_annotation"] = QuantizationAnnotation() + node.meta["quantization_annotation"]._annotated = True + + +# for nodes whose targets ars placehold (not call_function) +def annotate_placeholder(node: Node, quant_config: QuantizationConfig) -> None: + if _is_annotated([node]): + return + + if _is_float_tensor(node): + annotate_output_qspec(node, quant_config.output_activation) + + _mark_nodes_as_annotated([node]) + + +# CASE 1: fused_activation case (ex. 
Conv2D + ReLU) +def _is_hardtanh_for_relux(relu_node: torch.fx.node.Node): + if relu_node.target in [ + torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, + ]: + # checking if hardtanh is convertable to ReLU6 + # ReLU1 is not supported now + if not relu_node.args[1] == 0.0: + return False + if relu_node.args[2] == 6.0: # for ReLU6 + return True + return True + + +def _annotate_fused_activation_pattern( + graph: Graph, quant_config: QuantizationConfig +) -> None: + for relu_node in graph.nodes: + # Check relu/relu6 node + if relu_node.op != "call_function": + continue + if relu_node.target not in [ + # The strategy of ReLU and ReLU6 is fold_activation in ENNQuant + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + torch.ops.aten.relu6.default, + torch.ops.aten.relu6_.default, + torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, + ]: + continue + + if not _is_hardtanh_for_relux(relu_node): + continue + + producer_node = relu_node.args[0] + if not isinstance(producer_node, Node): + continue + if producer_node.op != "call_function": + continue + if len(producer_node.users) != 1: + continue + + # Handle affine + relu fusion + if producer_node.target in [ + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.linear.default, + ]: + # input & weight (or bias) setting for Conv node(producer_node) + quantization_annotation = producer_node.meta.get( + "quantization_annotation", QuantizationAnnotation() + ) + if quantization_annotation.input_qspec_map is None: + quantization_annotation.input_qspec_map = {} + + input = producer_node.args[0] + quantization_annotation.input_qspec_map[input] = ( + quant_config.input_activation + ) + + quantization_annotation.input_qspec_map[producer_node.args[1]] = ( + quant_config.weight + ) + if len(producer_node.args) > 2 and quant_config.bias is not None: + quantization_annotation.input_qspec_map[producer_node.args[2]] = ( + quant_config.bias + ) + + producer_node.meta["quantization_annotation"] = quantization_annotation + producer_node.meta["quantization_annotation"]._annotated = True + # out setting for activation node (relu_node) + quantization_annotation = relu_node.meta.get( + "quantization_annotation", QuantizationAnnotation() + ) + quantization_annotation.output_qspec = quant_config.output_activation + + relu_node.meta["quantization_annotation"] = quantization_annotation + relu_node.meta["quantization_annotation"]._annotated = True + continue + + +# CASE 2-1: two input case without Shared Quant +@register_annotator( + [ + torch.ops.aten.div, + torch.ops.aten.div.Tensor, + torch.ops.aten.divide.Tensor, + torch.ops.aten.matmul.default, + torch.ops.aten.bmm.default, + torch.ops.aten.sum.dim_IntList, + ] +) +def annotate_2in1out(node: Node, quant_config: QuantizationConfig) -> None: + input_act0 = node.args[0] + input_act1 = node.args[1] + # skipping quantization if 1st input is not float. 
+ if _is_annotated([node]) or not _is_float_tensor(input_act0): + return + + input_act_qspec = quant_config.input_activation + output_act_qspec = ( + quant_config.output_activation if _is_float_tensor(node) else None + ) + + input_qspec_map = {} + if _is_float_tensor(input_act0): + input_qspec_map[input_act0] = input_act_qspec + + if _is_float_tensor(input_act1): + input_qspec_map[input_act1] = input_act_qspec + + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qspec, + _annotated=True, + ) + + +# getting QuantAnnot though the first input +def _get_quantization_annotation(node: Node): + if node.op == "placeholder": + return False + elif "quantization_annotation" in node.meta: + return node + elif node.args == (): + return False + elif isinstance(node.args[0], Node): + return _get_quantization_annotation(node.args[0]) + elif isinstance(node.args[0], list): + # for cat, concatenate and stack + if isinstance(node.args[0][0], Node): + return _get_quantization_annotation(node.args[0][0]) + else: + return False + else: + return False + + +# CASE 2-2: two input case with Shared Quant +# ops.add / ops.add_ are processed by another annotator +@register_annotator( + [ + torch.ops.aten.sub, + torch.ops.aten.mul, + torch.ops.aten.sub.Tensor, + torch.ops.aten.mul.Tensor, + torch.ops.aten.sub_.Tensor, + torch.ops.aten.mul_.Tensor, + torch.ops.aten.rsub.Scalar, + torch.ops.aten.mul.Scalar, + ] +) +def annotate_2in1out_with_SharedQuant( + node: Node, quant_config: QuantizationConfig +) -> None: + + input_qspec_map = {} + input0 = node.args[0] + input1 = node.args[1] + + # skipping quantization if 1st input is not float. + if _is_annotated([node]) or not _is_float_tensor(input0): + return + if ( + isinstance(input0, Node) + and isinstance(input1, float) + and not _get_quantization_annotation(input0) + ): + return + if ( + isinstance(input0, float) + and isinstance(input1, Node) + and not _get_quantization_annotation(input1) + ): + return + if isinstance(input0, Node) and isinstance(input1, Node): + shared_qspec = SharedQuantizationSpec((input0, node)) + input_qspec_map[input0] = quant_config.input_activation + input_qspec_map[input1] = shared_qspec + + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=shared_qspec, + _annotated=True, + ) + + else: + input_act_qspec = quant_config.input_activation + output_act_qspec = ( + quant_config.output_activation if _is_float_tensor(node) else None + ) + + input_qspec_map = {} + input_act0 = node.args[0] + if _is_float_tensor(input_act0): + input_qspec_map[input_act0] = input_act_qspec + + input_act1 = node.args[1] + if _is_float_tensor(input_act1): + input_qspec_map[input_act1] = input_act_qspec + + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qspec, + _annotated=True, + ) + + +# CASE 2-3: only for add ops +@register_annotator(ADD_OPS) +def annotate_add_ops_with_SharedQuant( + node: Node, quant_config: QuantizationConfig +) -> None: + + input_qspec_map = {} + input0 = node.args[0] + input1 = node.args[1] + + # skipping quantization if 1st input is not float. 
+ if _is_annotated([node]) or not _is_float_tensor(input0): + return + + if isinstance(input0, Node) and isinstance(input1, Node): + NonQuantShare_ops_for_add = [torch.ops.aten.dropout.default] + ADD_OPS + if ( + input0.op == "call_function" and input0.target in NonQuantShare_ops_for_add + ) or ( + input1.op == "call_function" and input1.target in NonQuantShare_ops_for_add + ): + input_act_qspec = quant_config.input_activation + output_act_qspec = ( + quant_config.output_activation if _is_float_tensor(node) else None + ) + + input_qspec_map = {} + input_act0 = node.args[0] + if _is_float_tensor(input_act0): + input_qspec_map[input_act0] = input_act_qspec + + input_act1 = node.args[1] + if _is_float_tensor(input_act1): + input_qspec_map[input_act1] = input_act_qspec + + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qspec, + _annotated=True, + ) + else: + shared_qspec = SharedQuantizationSpec((input0, node)) + input_qspec_map[input0] = quant_config.input_activation + input_qspec_map[input1] = shared_qspec + + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=shared_qspec, + _annotated=True, + ) + elif ( + isinstance(input0, Node) + and isinstance(input1, float) + and not _get_quantization_annotation(input0) + ): + pass + elif ( + isinstance(input0, float) + and isinstance(input1, Node) + and not _get_quantization_annotation(input1) + ): + pass + else: + input_act_qspec = quant_config.input_activation + output_act_qspec = ( + quant_config.output_activation if _is_float_tensor(node) else None + ) + + input_qspec_map = {} + input_act0 = node.args[0] + if _is_float_tensor(input_act0): + input_qspec_map[input_act0] = input_act_qspec + + input_act1 = node.args[1] + if _is_float_tensor(input_act1): + input_qspec_map[input_act1] = input_act_qspec + + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qspec, + _annotated=True, + ) + + +# CASE 3-1: Single input + Single Out case without Shared Quant +@register_annotator( + [ + torch.ops.aten.ceil.default, + torch.ops.aten.clamp.default, + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + torch.ops.aten.relu6.default, + torch.ops.aten.relu6_.default, + torch.ops.aten.cos.default, + torch.ops.aten.sin.default, + torch.ops.aten.tanh.default, + torch.ops.aten.hardswish.default, + torch.ops.aten.hardswish_.default, + torch.ops.aten.hardsigmoid.default, + torch.ops.aten.hardsigmoid_.default, + torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, + torch.ops.aten.mean.default, + torch.ops.aten.adaptive_avg_pool2d.default, + torch.ops.aten.avg_pool2d.default, + torch.ops.aten.leaky_relu.default, + torch.ops.aten.leaky_relu_.default, + torch.ops.aten.prelu.default, + torch.ops.aten.upsample_bilinear2d.vec, + torch.ops.aten.upsample_nearest2d.vec, + torch.ops.aten.mean.dim, + torch.ops.aten.sqrt.default, + torch.ops.aten.gelu.default, + torch.ops.aten.scaled_dot_product_attention.default, + torch.ops.aten.rsqrt.default, + torch.ops.aten.pow.Tensor_Scalar, + torch.ops.aten.topk.default, + ] +) +def annotate_1in1out(node: Node, quant_config: QuantizationConfig) -> None: + # skipping quantization if input is not float. 
+ if _is_annotated([node]) or not _is_float_tensor(node.args[0]): + return + + quantization_annotation = node.meta.get( + "quantization_annotation", QuantizationAnnotation() + ) + if quantization_annotation.input_qspec_map is None: + quantization_annotation.input_qspec_map = {} + + # one inputs + one output case. + input_act_qspec = quant_config.input_activation + quantization_annotation.input_qspec_map[node.args[0]] = input_act_qspec + quantization_annotation.output_qspec = quant_config.output_activation + + node.meta["quantization_annotation"] = quantization_annotation + node.meta["quantization_annotation"]._annotated = True + + +# CASE 3-2: Single input + Single Out case with Shared Quant +@register_annotator( + [ + torch.ops.aten.permute.default, + torch.ops.aten.view.default, + torch.ops.aten._unsafe_view.default, + torch.ops.aten.squeeze.default, + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze_copy.dims, + torch.ops.aten.unsqueeze.default, + torch.ops.aten.unsqueeze_copy.default, + torch.ops.aten.transpose.int, + torch.ops.aten.expand.default, + torch.ops.aten.max_pool2d.default, + torch.ops.aten.max_pool2d_with_indices.default, + torch.ops.aten.reshape.default, + torch.ops.aten.select.int, + torch.ops.aten.flatten.using_ints, + torch.ops.aten.pad.default, + torch.ops.aten.slice.Tensor, + torch.ops.aten.to.dtype, + ] +) +def annotate_1in1out_with_SharedQuant( + node: Node, quant_config: QuantizationConfig +) -> None: + input_qspec_map = {} + input = node.args[0] + assert isinstance(input, Node) + if _is_annotated([node]) or not _is_float_tensor(input): + return + + shared_qspec = SharedQuantizationSpec((input, node)) + + # get QuantAnnot from the input path + shared_quant_node = _get_quantization_annotation(input) + if shared_quant_node: + input_qspec_map[shared_quant_node] = SharedQuantizationSpec(shared_quant_node) + shared_qspec = SharedQuantizationSpec((shared_quant_node, node)) + else: + # if no QuantAnnot in the input path + input_qspec_map[input] = quant_config.input_activation + shared_qspec = SharedQuantizationSpec((input, node)) + + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=shared_qspec, + _annotated=True, + ) + + +# CASE 3-3: Single input + Single Out case with FP +@register_annotator( + [ + torch.ops.aten.softmax.int, + torch.ops.aten._softmax.default, + torch.ops.aten._safe_softmax.default, + torch.ops.aten.log_softmax.int, + ] +) +def annotate_1in1out_with_SharedQuant_for_FP( + node: Node, quant_config: QuantizationConfig +) -> None: + input_qspec_map = {} + input = node.args[0] + assert isinstance(input, Node) + + if _is_annotated([node]) or not _is_float_tensor(input): + return + + if input.target in ADD_OPS and _is_annotated([input]): + del input.meta["quantization_annotation"] + + # get QuantAnnot from the input path + shared_quant_node = _get_quantization_annotation(input) + if shared_quant_node: + # if QuantAnnot in the input path, input_qspec is shared, but output_qspec is not. 
+ input_qspec_map[shared_quant_node] = SharedQuantizationSpec(shared_quant_node) + + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quant_config.output_activation, + _annotated=True, + ) + else: + # if no QuantAnnot in the input path + node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=quant_config.output_activation, + _annotated=True, + ) + + +# CASE 4: One value input + one index input with Shared Quant +@register_annotator([torch.ops.aten.index.Tensor]) +def annotate_index(node: Node, quant_config: QuantizationConfig) -> None: + input_qspec_map = {} + input = node.args[0] + assert isinstance(input, Node) + + if _is_annotated([node]) or not _is_float_tensor(input): + return + + # get QuantAnnt from the input path + shared_quant_node = _get_quantization_annotation(input) + if shared_quant_node: + shared_qspec = SharedQuantizationSpec((shared_quant_node, node)) + input_qspec_map[input] = quant_config.input_activation + + # sharing QuantAnnot with the parent + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=shared_qspec, + _annotated=True, + ) + + +# CASE 5 input + index + value & output with Shared Quant +@register_annotator( + [torch.ops.aten.index_put.default, torch.ops.aten.index_put_.default] +) +def annotate_index_put(node: Node, quant_config: QuantizationConfig) -> None: + input_qspec_map = {} + input = node.args[0] # from KVCache in LLAMA + value = node.args[2] # from linear projection layer + assert isinstance(input, Node) + assert isinstance(value, Node) + + if _is_annotated([node]) or not _is_float_tensor(input): + return + + # get QuantAnnot from input path + shared_quant_node = _get_quantization_annotation(input) + if shared_quant_node: + shared_qspec = SharedQuantizationSpec((shared_quant_node, node)) + input_qspec_map[input] = shared_qspec + input_qspec_map[value] = shared_qspec + output_qspec = shared_qspec + else: + # if no QuantAnnot in input path, asign the default QuantAnnot from quant_config. + input_qspec_map[input] = quant_config.input_activation + input_qspec_map[value] = SharedQuantizationSpec((input, node)) + output_qspec = SharedQuantizationSpec((input, node)) + + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_qspec, + _annotated=True, + ) + + +# CASE 6 unbind + getitem case +# (inputQuant--unbinde--no Qunat) --> (no Qunat--getitem--outputQuant) +@register_annotator([torch.ops.aten.unbind.int]) +def annotate_unbind(node: Node, quant_config: QuantizationConfig) -> None: + input_qspec_map = {} + input = node.args[0] + assert isinstance(input, Node) + + if _is_annotated([node]) or not _is_float_tensor(input): + return + + # get QuantAnnot from input path + shared_quant_node = _get_quantization_annotation(input) + if shared_quant_node: + input_qspec_map[input] = quant_config.input_activation + shared_qspec = SharedQuantizationSpec((shared_quant_node, node)) + else: + # if no QuantAnnot in input path, asign the default QuantAnnot from quant_config. 
+ input_qspec_map[input] = quant_config.input_activation + shared_qspec = SharedQuantizationSpec((input, node)) + + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=shared_qspec, + _annotated=True, + ) + + for users_node in node.users: + users_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=shared_qspec, + _annotated=True, + ) + + +# CASE 7: stand-alone Conv2d and Conv1d +@register_annotator( + [ + torch.ops.aten.conv2d.default, + torch.ops.aten.conv1d.default, + torch.ops.aten.linear.default, + ] +) +def annotate_conv2d(node: Node, quant_config: QuantizationConfig) -> None: + # skipping quantization if weights are not float + if _is_annotated([node]) or not _is_float_tensor(node.args[1]): + return + + input = node.args[0] + # input & weight (or bias) setting for Conv node(producer_node) + quantization_annotation = node.meta.get( + "quantization_annotation", QuantizationAnnotation() + ) + if quantization_annotation.input_qspec_map is None: + quantization_annotation.input_qspec_map = {} + + shared_quant_node = _get_quantization_annotation(input) + if shared_quant_node: + quantization_annotation.input_qspec_map[input] = SharedQuantizationSpec( + shared_quant_node + ) + else: + quantization_annotation.input_qspec_map[input] = quant_config.input_activation + quantization_annotation.input_qspec_map[node.args[1]] = quant_config.weight + if len(node.args) > 2 and quant_config.bias is not None: + quantization_annotation.input_qspec_map[node.args[2]] = quant_config.bias + quantization_annotation.output_qspec = quant_config.output_activation + + node.meta["quantization_annotation"] = quantization_annotation + node.meta["quantization_annotation"]._annotated = True + + +# CASE 8: embedding +@register_annotator([torch.ops.aten.embedding.default]) +def annotate_embedding(node: Node, quant_config: QuantizationConfig) -> None: + input_qspec_map = {} + weight = node.args[0] + if _is_annotated([node]) or not _is_float_tensor(weight): + return + + input_qspec_map[weight] = quant_config.input_activation + + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quant_config.output_activation, + _annotated=True, + ) + + +# CASE 9: Concat & Stack +@register_annotator( + [ + torch.ops.aten.cat.default, + torch.ops.aten.concat.default, + torch.ops.aten.stack.default, + ] +) +def annotate_cat(node: Node, quant_config: QuantizationConfig) -> None: + inputs = node.args[0] + first_input = inputs[0] + assert isinstance(inputs, list) + assert isinstance(first_input, Node) + + if _is_annotated([node]) or not _is_float_tensor(first_input): + return + + input_qspec_map = {} + shared_qspec = SharedQuantizationSpec((first_input, node)) + for input in inputs: + if input == first_input: + input_qspec_map[input] = quant_config.input_activation + else: + input_qspec_map[input] = shared_qspec + + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=shared_qspec, + _annotated=True, + ) + + +# CASE 10: various normalizations +@register_annotator([torch.ops.aten.rms_norm.default]) +def annotate_rms_norm(node: Node, quant_config: QuantizationConfig) -> None: + if _is_annotated([node]): + return + + quantization_annotation = node.meta.get( + "quantization_annotation", QuantizationAnnotation() + ) + if quantization_annotation.input_qspec_map is None: + quantization_annotation.input_qspec_map = {} + + 
quantization_annotation.input_qspec_map[node.args[0]] = ( + quant_config.input_activation + ) # active + quantization_annotation.input_qspec_map[node.args[2]] = ( + quant_config.input_activation + ) # weight + quantization_annotation.output_qspec = quant_config.output_activation + node.meta["quantization_annotation"] = quantization_annotation + node.meta["quantization_annotation"]._annotated = True + + +@register_annotator([torch.ops.aten.group_norm.default]) +def annotate_group_norm(node: Node, quant_config: QuantizationConfig) -> None: + if _is_annotated([node]): + return + + quantization_annotation = node.meta.get( + "quantization_annotation", QuantizationAnnotation() + ) + if quantization_annotation.input_qspec_map is None: + quantization_annotation.input_qspec_map = {} + + quantization_annotation.input_qspec_map[node.args[0]] = ( + quant_config.input_activation + ) # active + quantization_annotation.input_qspec_map[node.args[2]] = ( + quant_config.weight + ) # weight + quantization_annotation.output_qspec = quant_config.output_activation + + node.meta["quantization_annotation"] = quantization_annotation + node.meta["quantization_annotation"]._annotated = True + + +@register_annotator([torch.ops.aten.layer_norm.default]) +def annotate_layer_norm(node: Node, quant_config: QuantizationConfig) -> None: + if _is_annotated([node]): + return + + quantization_annotation = node.meta.get( + "quantization_annotation", QuantizationAnnotation() + ) + if quantization_annotation.input_qspec_map is None: + quantization_annotation.input_qspec_map = {} + + quantization_annotation.input_qspec_map[node.args[0]] = ( + quant_config.input_activation + ) # active + quantization_annotation.input_qspec_map[node.args[2]] = ( + quant_config.input_activation + ) # weight + quantization_annotation.output_qspec = quant_config.output_activation + + node.meta["quantization_annotation"] = quantization_annotation + node.meta["quantization_annotation"]._annotated = True + + +@register_annotator([torch.ops.aten._native_batch_norm_legit_no_training.default]) +def annotate_batch_norm(node: Node, quant_config: QuantizationConfig) -> None: + if _is_annotated([node]): + return + + quantization_annotation = node.meta.get( + "quantization_annotation", QuantizationAnnotation() + ) + if quantization_annotation.input_qspec_map is None: + quantization_annotation.input_qspec_map = {} + + quantization_annotation.input_qspec_map[node.args[0]] = ( + quant_config.input_activation + ) # active + + quantization_annotation.input_qspec_map[node.args[1]] = ( + quant_config.input_activation + ) # weight + quantization_annotation.output_qspec = quant_config.output_activation + + node.meta["quantization_annotation"] = quantization_annotation + node.meta["quantization_annotation"]._annotated = True + + +# CASE 11: Sigmoid +@register_annotator([torch.ops.aten.sigmoid, torch.ops.aten.sigmoid.default]) +def annotate_sigmoid(node: Node, quant_config: QuantizationConfig) -> None: + if _is_annotated([node]): + return + + input_qspec_map = {} + input_act = node.args[0] + input_qspec_map[input_act] = quant_config.input_activation + + assert isinstance(input_act, Node) + out_qconf = quant_config.output_activation + + q_max = ( + torch.iinfo(out_qconf.dtype).max + if out_qconf.quant_max is None + else out_qconf.quant_max + ) + q_min = ( + torch.iinfo(out_qconf.dtype).min + if out_qconf.quant_min is None + else out_qconf.quant_min + ) + + scale = 1 / (q_max - q_min + 1) + + bias_obs_ctr = FixedQParamsObserver.with_args( + scale=scale, + zero_point=0, + 
dtype=quant_config.output_activation.dtype, + qscheme=torch.torch.per_tensor_affine, + quant_max=q_max, + quant_min=q_min, + ) + + # make sigmoid map to the range between 0~1 + out_act_quantization_spec = QuantizationSpec( + dtype=quant_config.output_activation.dtype, + quant_max=q_max, + quant_min=q_min, + observer_or_fake_quant_ctr=bias_obs_ctr, + qscheme=torch.torch.per_tensor_affine, + ) + + if _is_float_tensor(node): + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=out_act_quantization_spec, + _annotated=True, + ) diff --git a/backends/samsung/quantizer/qconfig.py b/backends/samsung/quantizer/qconfig.py new file mode 100644 index 00000000000..f32c8d39796 --- /dev/null +++ b/backends/samsung/quantizer/qconfig.py @@ -0,0 +1,174 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from enum import IntEnum, unique +from typing import Callable, Optional + +import torch +from torchao.quantization.pt2e import ( + FakeQuantize, + MinMaxObserver, + PerChannelMinMaxObserver, +) +from torchao.quantization.pt2e.quantizer import QuantizationSpec + + +@unique +class Precision(IntEnum): + A8W8 = 3 + + +@dataclass(eq=True, frozen=True) +class QuantizationConfig: + input_activation: Optional[QuantizationSpec] + output_activation: Optional[QuantizationSpec] + weight: Optional[QuantizationSpec] + bias: Optional[QuantizationSpec | Callable] + + +def get_quant_config( + precision: Precision, + is_per_channel: bool = False, + is_qat: bool = False, +) -> QuantizationConfig: + + precision_mappings = { + Precision.A8W8: get_a8w8_enn_quant_config, + } + if precision not in precision_mappings: + raise RuntimeError("Unrecognized precision setting.") + + is_weight_symm = is_per_channel + + qconfig_fn = precision_mappings[precision] + return qconfig_fn(is_per_channel, is_qat, wei_symmetric=is_weight_symm) + + +def _get_activation_qspec( + dtype, + is_symmetric, + is_qat, + observer_cls=MinMaxObserver, + quant_min=None, + quant_max=None, +): + eps_value = 2**-12 + if quant_max is None: + quant_max = torch.iinfo(dtype).max + if quant_min is None: + quant_min = torch.iinfo(dtype).min + + qscheme = torch.per_tensor_symmetric if is_symmetric else torch.per_tensor_affine + if is_qat: + observer_or_fake_quant = FakeQuantize.with_args( + observer=observer_cls, eps=eps_value + ) + else: + observer_or_fake_quant = observer_cls.with_args(eps=eps_value) + + return QuantizationSpec( + dtype=dtype, + quant_min=quant_min, + quant_max=quant_max, + qscheme=qscheme, + observer_or_fake_quant_ctr=observer_or_fake_quant, + ) + + +def _get_weight_qspec( + dtype, is_symmetric, is_per_channel, is_qat, quant_min=None, quant_max=None +): + assert is_symmetric or not is_per_channel, "Not support asymm+perchannel mode" + + eps_value = 2**-12 + + if quant_max is None: + quant_max = torch.iinfo(dtype).max + if quant_min is None: + quant_min = torch.iinfo(dtype).min + + if not is_per_channel: + qscheme = ( + torch.per_tensor_symmetric if is_symmetric else torch.per_tensor_affine + ) + observer_cls = MinMaxObserver + else: + qscheme = ( + torch.per_channel_symmetric if is_symmetric else torch.per_channel_affine + ) + observer_cls = PerChannelMinMaxObserver + + if is_qat: + observer_or_fake_quant = FakeQuantize.with_args( + observer=observer_cls, eps=eps_value + ) + else: + 
observer_or_fake_quant = observer_cls.with_args(eps=eps_value) + + return QuantizationSpec( + dtype=dtype, + quant_min=quant_min, + quant_max=quant_max, + qscheme=qscheme, + ch_axis=0, + observer_or_fake_quant_ctr=observer_or_fake_quant, + ) + + +def get_a8w8_enn_quant_config( + is_per_channel=True, is_qat=False, act_symmetric=False, wei_symmetric=False +) -> QuantizationConfig: + act_quantization_spec = _get_activation_qspec(torch.int8, act_symmetric, is_qat) + wgt_quantization_spec = _get_weight_qspec( + torch.int8, wei_symmetric, is_per_channel, is_qat + ) + bias_quantization_spec = None + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=wgt_quantization_spec, + bias=bias_quantization_spec, + ) + return quantization_config + + +class QuantInfo: + def __init__(self, torch_dtype: torch.dtype, string: str): + self._torch_dtype = torch_dtype + self._string = string + + @property + def torch_dtype(self): + return self._torch_dtype + + @property + def string(self): + return self._string + + +class QuantInfoManager: + QUANT_INFO_MAP = { + Precision.A8W8: (QuantInfo(torch.int8, "INT8"), QuantInfo(torch.int8, "INT8")), + } + FP_INFO = ( + QuantInfo(torch.float32, "FLOAT32"), + QuantInfo(torch.float32, "FLOAT32"), + ) + + def __init__(self): + self.precision = None + + def set_precision(self, precision: Precision): + self.precision = precision + + @property + def weight_precison(self) -> Optional[QuantInfo]: + return self.QUANT_INFO_MAP.get(self.precision, self.FP_INFO)[0] + + @property + def act_precision(self) -> Optional[QuantInfo]: + return self.QUANT_INFO_MAP.get(self.precision, self.FP_INFO)[1] diff --git a/backends/samsung/quantizer/quantizer.py b/backends/samsung/quantizer/quantizer.py new file mode 100644 index 00000000000..cf46677d000 --- /dev/null +++ b/backends/samsung/quantizer/quantizer.py @@ -0,0 +1,65 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Sequence + +import torch +from torch.fx import GraphModule +from torchao.quantization.pt2e.quantizer import Quantizer + +from .annotator import annotate +from .qconfig import get_quant_config, Precision, QuantInfoManager + + +global_quant_info = QuantInfoManager() + + +class EnnQuantizer(Quantizer): + + def __init__(self): + super().__init__() + + self._precision = Precision.A8W8 + global_quant_info.set_precision(self._precision) + self._is_per_channel = True + self._is_qat = False + self.custom_quant_annotations: Sequence[Callable] = [] + + def setup_precision(self, quant_dtype: Precision) -> None: + assert quant_dtype in Precision, f"No support for Precision {quant_dtype}." + self._precision = quant_dtype + global_quant_info.set_precision(self._precision) + + def setup_quant_params( + self, quant_dtype: Precision, is_per_channel=True, is_qat=False + ) -> None: + assert quant_dtype in Precision, f"No support for Precision {quant_dtype}." 
+ self._precision = quant_dtype + self._is_per_channel = is_per_channel + self._is_qat = is_qat + + def annotate(self, model: GraphModule) -> GraphModule: + self._annotate(model) + self._annotate_custom_annotation(model) + return model + + def _annotate(self, gm: GraphModule) -> None: + quant_config = get_quant_config( + self._precision, self._is_per_channel, self._is_qat + ) + annotate(gm.graph, quant_config) + + def add_custom_quant_annotations( + self, custom_quant_annotations: Sequence[Callable] + ) -> None: + self.custom_quant_annotations = custom_quant_annotations + + def _annotate_custom_annotation(self, gm: GraphModule) -> None: + for annotation_func in self.custom_quant_annotations: + annotation_func(gm) + + def validate(self, model: torch.fx.GraphModule) -> None: + return diff --git a/backends/samsung/serialization/compile_options.py b/backends/samsung/serialization/compile_options.py index 1ad2350cfeb..a4af40368e9 100644 --- a/backends/samsung/serialization/compile_options.py +++ b/backends/samsung/serialization/compile_options.py @@ -11,7 +11,8 @@ from dataclasses import dataclass from enum import IntEnum, unique -import pkg_resources +from importlib.resources import files + from executorch.exir._serialize._dataclass import _DataclassEncoder from executorch.exir._serialize._flatbuffer import _flatc_compile from executorch.exir.backend.backend_details import CompileSpec @@ -36,12 +37,15 @@ def gen_samsung_backend_compile_spec_core(options: EnnExecuTorchOptions) -> Comp with tempfile.TemporaryDirectory() as d: # schema schema_path = os.path.join(d, "{}.fbs".format(COMPILE_OPTION_SCHEMA_NAME)) + + schema_content = ( + files(__package__) + .joinpath(f"{COMPILE_OPTION_SCHEMA_NAME}.fbs") + .read_bytes() + ) + with open(schema_path, "wb") as schema_file: - schema_file.write( - pkg_resources.resource_string( - __name__, "{}.fbs".format(COMPILE_OPTION_SCHEMA_NAME) - ) - ) + schema_file.write(schema_content) # dump json json_path = os.path.join(d, "{}.json".format(COMPILE_OPTION_SCHEMA_NAME)) enn_options_json = json.dumps(options, cls=_DataclassEncoder, indent=4) diff --git a/backends/samsung/serialization/enn_graph_schema.py b/backends/samsung/serialization/enn_graph_schema.py index 7e74182f9d7..5209a8672ee 100644 --- a/backends/samsung/serialization/enn_graph_schema.py +++ b/backends/samsung/serialization/enn_graph_schema.py @@ -5,13 +5,16 @@ # LICENSE file in the root directory of this source tree. 
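As an aside on the `compile_options.py` change above: `pkg_resources.resource_string` is deprecated along with the rest of `pkg_resources`, and the stdlib `importlib.resources` API reads the same packaged data. A minimal, self-contained sketch of the pattern; the package and resource names below are hypothetical placeholders, not values from this patch.

```python
from importlib.resources import files


def read_packaged_resource(package: str, resource_name: str) -> bytes:
    """Stdlib replacement for pkg_resources.resource_string(package, name)."""
    return files(package).joinpath(resource_name).read_bytes()


# Hypothetical usage; the real call site above passes __package__ and the
# backend's compile-option schema file name.
# schema_bytes = read_packaged_resource("some_backend.serialization", "options.fbs")
```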
import logging -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import executorch.backends.samsung.python.PyGraphWrapperAdaptor as PyGraphWrapper import numpy as np import torch +from executorch.backends.samsung.builders.utils import DATA_TYPE_STR_MAPPING +from executorch.backends.samsung.utils.constants import QuantConstants +from executorch.backends.samsung.utils.utils import quantize_tensor class EnnGraph: @@ -24,6 +27,10 @@ def __init__(self): self.inputs = [] self.outputs = [] + def init(self, name: str, soc_name): + self.name = name + self.soc_name = soc_name + def define_op( self, name, @@ -46,22 +53,54 @@ def define_op( py_param_wrapper.SetScalarValue(params[key]) else: logging.error("Unsupported param type.") + # Set op.AddOpParam(py_param_wrapper) self.graph.DefineOpNode(op) - def define_tensor( + def define_tensor( # noqa: C901 self, name: str, shape: List, data_type: str, tensor_type: str, data: Optional[Union[np.ndarray, torch.Tensor]] = None, + quant_param: Optional[Dict[str, Any]] = None, ) -> int: layout = "NCHW" if len(shape) == 4 else "UNDEFINED" + if quant_param is not None: + data_type = DATA_TYPE_STR_MAPPING[ + quant_param[QuantConstants.QUANT_KEY.quant_dtype] + ] + tensor = PyGraphWrapper.PyEnnTensorWrapper(name, shape, data_type, layout) + if quant_param is not None: + need_quantize = True + + scales = self._affine_meta_param( + quant_param[QuantConstants.QUANT_KEY.scale] + ) + zero_points = self._affine_meta_param( + quant_param[QuantConstants.QUANT_KEY.zero_point] + ) + q_dtype = self._affine_meta_param( + quant_param[QuantConstants.QUANT_KEY.quant_dtype] + ) + tensor.AddQuantizeParam(q_dtype, scales, zero_points) + + if need_quantize and data is not None: + if isinstance(data, np.ndarray): + data = torch.tensor(data) + data = quantize_tensor( + data, + scales, + zero_points, + quant_param[QuantConstants.QUANT_KEY.quant_dtype], + axis=quant_param.get("axis"), + ) + if data is not None: if isinstance(data, torch.Tensor): data = data.detach().numpy() @@ -83,3 +122,20 @@ def finish(self): def serialize(self): return self.graph.Serialize() + + @staticmethod + def _affine_meta_param(param: Any) -> str: + type_str_affine_table = { + torch.int8: "AINT8", + } + if isinstance(param, str): + return param + if isinstance(param, (float, int)): + return [param] + if hasattr(param, "tolist"): + return param.tolist() + if isinstance(param, torch.dtype): + # Convenient for debugging + param = type_str_affine_table.get(param, "") + + return param diff --git a/backends/samsung/utils/constants.py b/backends/samsung/utils/constants.py new file mode 100644 index 00000000000..7c3997b9fe2 --- /dev/null +++ b/backends/samsung/utils/constants.py @@ -0,0 +1,45 @@ +# Copyright (c) 2025 Samsung Electronics Co. LTD +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
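The `_affine_meta_param` helper above normalizes quantization metadata (scales, zero points, dtypes) into plain Python values before they reach the graph wrapper. Below is a stand-alone re-implementation of those rules for illustration only (not the backend code), runnable with just `torch`:

```python
import torch


def normalize_quant_meta(param):
    # Strings pass through, scalars become one-element lists, tensors become
    # Python lists, and a torch dtype maps to a backend type string.
    if isinstance(param, str):
        return param
    if isinstance(param, (float, int)):
        return [param]
    if hasattr(param, "tolist"):
        return param.tolist()
    if isinstance(param, torch.dtype):
        return {torch.int8: "AINT8"}.get(param, "")
    return param


print(normalize_quant_meta(0.05))                      # [0.05]
print(normalize_quant_meta(torch.tensor([1.0, 2.0])))  # [1.0, 2.0]
print(normalize_quant_meta(torch.int8))                # AINT8
```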
+ +from executorch.exir.dialects._ops import ops as exir_ops + + +class QuantConstants: + # TODO: check keys + class QUANT_KEY: + scale = "scales" + zero_point = "zero_points" + quant_min = "quant_min" + quant_max = "quant_max" + quant_dtype = "quant_dtype" + + PERCHANNEL_KEY_MAP = { + "scales": QUANT_KEY.scale, + "zero_points": QUANT_KEY.zero_point, + "quant_min": QUANT_KEY.quant_min, + "quant_max": QUANT_KEY.quant_max, + "dtype": QUANT_KEY.quant_dtype, + } + # SNC ir always use key 'scales' and 'zero_points' + PERTENSOR_KEY_MAP = { + "scale": QUANT_KEY.scale, + "zero_point": QUANT_KEY.zero_point, + "quant_min": QUANT_KEY.quant_min, + "quant_max": QUANT_KEY.quant_max, + "dtype": QUANT_KEY.quant_dtype, + } + + QUANT_OPS_KEY_MAP = { + exir_ops.edge.quantized_decomposed.quantize_per_channel.default: PERCHANNEL_KEY_MAP, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: PERTENSOR_KEY_MAP, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor: PERTENSOR_KEY_MAP, + } + + DEQUANT_OPS_KEY_MAP = { + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: PERTENSOR_KEY_MAP, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor: PERTENSOR_KEY_MAP, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default: PERCHANNEL_KEY_MAP, + } diff --git a/backends/samsung/utils/export_utils.py b/backends/samsung/utils/export_utils.py index aaf407ef0b3..39992f2ea2a 100644 --- a/backends/samsung/utils/export_utils.py +++ b/backends/samsung/utils/export_utils.py @@ -4,20 +4,30 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional, Tuple +import logging +from typing import List, Optional, Tuple import executorch.exir as exir import torch +from executorch.backends.samsung._passes.fuse_conv_act import FuseConvActPass +from executorch.backends.samsung._passes.remove_useless_ops import RemoveUselessOpPass from executorch.backends.samsung.partition.enn_partitioner import EnnPartitioner +from executorch.backends.samsung.quantizer.quantizer import EnnQuantizer, Precision +from executorch.backends.transforms.decompose_sdpa import ( + DecomposeScaledDotProductAttention, +) from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform from executorch.exir import EdgeCompileConfig from executorch.exir.backend.backend_details import CompileSpec - from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_manager import PassType from executorch.exir.program._program import to_edge_transform_and_lower +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e def get_edge_compile_config(): + # Maybe most ops in non-decomposition list should be added here + # TODO: to confirm whether all op in none-decomposed table should be added here return EdgeCompileConfig( _skip_dim_order=True, _core_aten_ops_exception_list=[ @@ -29,24 +39,55 @@ def get_edge_compile_config(): exir_ops.edge.aten._safe_softmax.default, exir_ops.edge.aten.layer_norm.default, exir_ops.edge.aten.matmul.default, + exir_ops.edge.aten.hardsigmoid.default, ], ) +def get_enn_pass_list() -> List[PassType]: + return [ + RemoveUselessOpPass(), + RemoveCloneOpsTransform(), + FuseConvActPass(), + ] + + +def quantize_module( + module: torch.nn.Module, + inputs, + calibration_dataset, + precision: Precision, + is_per_channel: bool = True, + is_qat: bool = False, +) -> torch.nn.Module: + quantizer = EnnQuantizer() + 
quantizer.setup_quant_params(precision, is_per_channel, is_qat) + logging.info("Export nn module for quantization...") + exported_module = torch.export.export_for_training(module, inputs).module() + DecomposeScaledDotProductAttention()(exported_module) + logging.info("Quantizing the module...") + annotated_module = prepare_pt2e(exported_module, quantizer) + for data in calibration_dataset: + annotated_module(*data) + quantized_module = convert_pt2e(annotated_module, fold_quantize=False) + logging.info("Quantizing finished.") + return quantized_module + + def to_edge_transform_and_lower_to_enn( module: torch.nn.Module, inputs: Tuple[torch.Tensor], + custom_pass_config: List[PassType] = None, compile_specs: Optional[CompileSpec] = None, ) -> exir.ExecutorchProgramManager: - assert ( - compile_specs is not None - ), "Please provide compile specifications for enn backend" + assert compile_specs is not None, "For now, we must deliver complile specs" prog = torch.export.export(module, inputs) - - ahead_pass_list = [RemoveCloneOpsTransform()] + pass_list = get_enn_pass_list() + if custom_pass_config: + pass_list.extend(custom_pass_config) return to_edge_transform_and_lower( prog, - ahead_pass_list, + pass_list, {"forward": [EnnPartitioner(compile_specs)]}, compile_config=get_edge_compile_config(), ) diff --git a/backends/samsung/utils/utils.py b/backends/samsung/utils/utils.py index 5da9808f38f..bbbec518b2a 100644 --- a/backends/samsung/utils/utils.py +++ b/backends/samsung/utils/utils.py @@ -4,12 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List +from typing import List, Optional, Tuple import torch from executorch.backends.transforms.utils import is_param_node from executorch.exir.backend.backend_details import CompileSpec +from executorch.exir.dialects._ops import ops as exir_ops from torch.export.exported_program import ExportedProgram @@ -35,3 +36,90 @@ def is_graph_output(node: torch.fx.Node) -> bool: ): return True return False + + +def _quantize_per_tensor( + in_tensor: torch.Tensor, + scales: List[float], + zeropoints: List[int], + dtype: torch.dtype, + qrange: Optional[Tuple[int, int]], +): + assert ( + len(scales) == 1 + ), "For per-tensor quantization, there should be only one scale/zeropoint" + return exir_ops.edge.quantized_decomposed.quantize_per_tensor.default( + in_tensor, + torch.Tensor(scales), + torch.Tensor(zeropoints), + qrange[0], + qrange[1], + dtype, + ) + + +def _quantize_per_channel( + in_tensor: torch.Tensor, + scales: List[float], + zeropoints: List[int], + dtype: torch.dtype, + qrange: Optional[Tuple[int, int]], + axis: Optional[int], # Only for per-channel +): + assert ( + len(scales) == in_tensor.shape[axis] + ), "Shape not match for quant params and input tensor" + return exir_ops.edge.quantized_decomposed.quantize_per_channel.default( + in_tensor, + torch.Tensor(scales), + torch.Tensor(zeropoints), + axis, + qrange[0], + qrange[1], + dtype, + ) + + +def quantize_tensor( + in_tensor: torch.Tensor, + scales: List[float], + zeropoints: List[int], + dtype: torch.dtype, + qrange: Optional[Tuple[int, int]] = None, + axis: Optional[int] = None, # Only for per-channel +) -> torch.Tensor: + """ + To quantize constant tensor by executorch OPs. If `axis` not set, we quantize the tensor by per tensor. + If `axis` was set, we do per-channel quantize. + + :param in_tensor: The tensor to be quantized + :param scales: List of scales. 
For per-tensor quantization, it should contain only one element + :param zeropoints: List of zeropoints. For per-tensor quantization, it should contain only one element + :param dtype: The output dtype + :param qrange: The quantization range (qmin, qmax). + If not set, we will get the maximum range of the dtype by `torch.iinfo` + :param axis: We do per-channel quantize by which axis. + Only when this parameter set, we do per-channel quantization + :type in_tensor: torch.Tensor + :type scalse: List[float] + :type zeropoints: List[int] + :type dtype: torch.dtype + :type qrange: Optional[Tuple[int,int]] + :type axis: Optional[int] + :return: The quantized tensor + """ + assert len(scales) == len( + zeropoints + ), "scales should have same shape with zeropoints" + if not qrange: + qrange = (torch.iinfo(dtype).min, torch.iinfo(dtype).max) + + if axis is not None: + return _quantize_per_channel(in_tensor, scales, zeropoints, dtype, qrange, axis) + return _quantize_per_tensor( + in_tensor, + scales, + zeropoints, + dtype, + qrange, + ) diff --git a/backends/test/harness/stages/__init__.py b/backends/test/harness/stages/__init__.py index 36ed435ebd7..14431191621 100644 --- a/backends/test/harness/stages/__init__.py +++ b/backends/test/harness/stages/__init__.py @@ -1,6 +1,6 @@ from .export import Export from .partition import Partition -from .quantize import Quantize +from .quantize import Quantize, Quantize_ from .run_passes import RunPasses from .serialize import Serialize from .stage import Stage, StageType @@ -12,6 +12,7 @@ "Export", "Partition", "Quantize", + "Quantize_", "RunPasses", "Serialize", "Stage", diff --git a/backends/test/harness/stages/quantize.py b/backends/test/harness/stages/quantize.py index 9edb600e19f..6c6036c8104 100644 --- a/backends/test/harness/stages/quantize.py +++ b/backends/test/harness/stages/quantize.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Sequence, Tuple +from typing import Any, Callable, Optional, Sequence, Tuple import torch @@ -15,6 +15,8 @@ prepare_qat_pt2e, ) from torchao.quantization.pt2e.quantizer import Quantizer +from torchao.quantization.quant_api import quantize_ +from torchao.utils import unwrap_tensor_subclass class Quantize(Stage): @@ -79,3 +81,48 @@ def graph_module(self) -> str: def run_artifact(self, inputs): return self.converted_graph.forward(*inputs) + + +class Quantize_(Stage): + """ + TorchAO quantization stage using the quantize_ API. 
+ """ + + def __init__( + self, + config: Any, + filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None, + ): + """ + Args: + config: TorchAO quantization config (e.g., Int4WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig) + filter_fn: Optional filter function to select which modules to quantize + """ + self.config = config + self.filter_fn = filter_fn + self.quantized_module = None + + def stage_type(self) -> str: + return StageType.QUANTIZE + + def run( + self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]] + ) -> None: + # Apply quantize_ to the model + quantize_(artifact, self.config, self.filter_fn) + + # Unwrap tensor subclasses for export compatibility + unwrap_tensor_subclass(artifact) + + self.quantized_module = artifact + + @property + def artifact(self) -> torch.nn.Module: + return self.quantized_module + + @property + def graph_module(self) -> torch.nn.Module: + return self.quantized_module + + def run_artifact(self, inputs): + return self.quantized_module.forward(*inputs) diff --git a/backends/test/harness/tester.py b/backends/test/harness/tester.py index 351bab4a605..02c6fc4c82d 100644 --- a/backends/test/harness/tester.py +++ b/backends/test/harness/tester.py @@ -1,3 +1,8 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import random from collections import Counter, OrderedDict from typing import Any, Callable, Dict, List, Optional, Tuple @@ -62,6 +67,7 @@ def __init__( StageType.RUN_PASSES: [ StageType.PARTITION, StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.TO_EXECUTORCH, ], # TODO Make this Stage optional StageType.PARTITION: [StageType.TO_EXECUTORCH], diff --git a/backends/test/multi_method_delegate_test.cpp b/backends/test/multi_method_delegate_test.cpp index e24585434c4..bf17d7c8743 100644 --- a/backends/test/multi_method_delegate_test.cpp +++ b/backends/test/multi_method_delegate_test.cpp @@ -5,6 +5,10 @@ #include #include +#include + +#include +#include #include #include @@ -12,6 +16,11 @@ #include #include +using executorch::backends::xnnpack::workspace_sharing_mode_option_key; +using executorch::backends::xnnpack::WorkspaceSharingMode; +using executorch::backends::xnnpack::xnnpack_backend_key; + +using executorch::runtime::BackendOptions; using executorch::runtime::Error; using executorch::runtime::EValue; using executorch::runtime::HierarchicalAllocator; @@ -126,34 +135,61 @@ class XNNPACKMultiDelegateTest : public ETPTEMethodRunBaseTest { num_threads = 40; kMethodName = "forward"; } -}; -// This test is to validate the assumption that the delegate is thread safe. -// That includes the following: -// 1. The delegate can be initilized by multiple threads in parallel. -// 2. The delegate can be executed by multiple threads in parallel. -// 3. The delegate can be destroyed by multiple threads in parallel. -// Regardless of the underlying implementation of the delegate. -// This is particularly important when we have shared resources across -// delegate instances through a singleton backend instance. -TEST_F(XNNPACKMultiDelegateTest, MultipleThreads) { - ASSERT_NE(kTestPTE1Path.size(), 0); - ASSERT_NE(kTestPTE2Path.size(), 0); - ASSERT_NE(num_threads, 0); - ASSERT_NE(kMethodName.size(), 0); - - std::vector threads(num_threads); - std::atomic count{0}; - - for (int i = 0; i < num_threads; i++) { - threads[i] = std::thread([&, i]() { - run(i, i % 7 ? 
kTestPTE1Path : kTestPTE2Path, kMethodName, count); - }); + // This test is to validate the assumption that the delegate is thread safe. + // That includes the following: + // 1. The delegate can be initilized by multiple threads in parallel. + // 2. The delegate can be executed by multiple threads in parallel. + // 3. The delegate can be destroyed by multiple threads in parallel. + // Regardless of the underlying implementation of the delegate. + // This is particularly important when we have shared resources across + // delegate instances through a singleton backend instance. + void runStressTest() { + ASSERT_NE(kTestPTE1Path.size(), 0); + ASSERT_NE(kTestPTE2Path.size(), 0); + ASSERT_NE(num_threads, 0); + ASSERT_NE(kMethodName.size(), 0); + + std::vector threads(num_threads); + std::atomic count{0}; + + for (int i = 0; i < num_threads; i++) { + threads[i] = std::thread([&, i]() { + run(i, i % 7 ? kTestPTE1Path : kTestPTE2Path, kMethodName, count); + }); + } + for (int i = 0; i < num_threads; i++) { + threads[i].join(); + } + ASSERT_EQ(count, num_threads); } - for (int i = 0; i < num_threads; i++) { - threads[i].join(); + + void setWorkspaceSharingMode(WorkspaceSharingMode mode) { + executorch::runtime::runtime_init(); + + BackendOptions<1> backend_options; + backend_options.set_option( + workspace_sharing_mode_option_key, static_cast(mode)); + + auto status = executorch::runtime::set_option( + xnnpack_backend_key, backend_options.view()); + ASSERT_EQ(status, Error::Ok); } - ASSERT_EQ(count, num_threads); +}; + +TEST_F(XNNPACKMultiDelegateTest, MultipleThreadsSharingDisabled) { + setWorkspaceSharingMode(WorkspaceSharingMode::Disabled); + runStressTest(); +} + +TEST_F(XNNPACKMultiDelegateTest, MultipleThreadsPerModelSharing) { + setWorkspaceSharingMode(WorkspaceSharingMode::PerModel); + runStressTest(); +} + +TEST_F(XNNPACKMultiDelegateTest, MultipleThreadsGlobalSharing) { + setWorkspaceSharingMode(WorkspaceSharingMode::Global); + runStressTest(); } // TODO(T208989291): Add more tests here. For example, diff --git a/backends/test/suite/README.md b/backends/test/suite/README.md index 564f44362ad..901cd461dbe 100644 --- a/backends/test/suite/README.md +++ b/backends/test/suite/README.md @@ -5,37 +5,71 @@ This directory contains tests that validate correctness and coverage of backends These tests are intended to ensure that backends are robust and provide a smooth, "out-of-box" experience for users across the full span of input patterns. They are not intended to be a replacement for backend-specific tests, as they do not attempt to validate performance or that backends delegate operators that they expect to. ## Running Tests and Interpreting Output -Tests can be run from the command line, either using the runner.py entry point or the standard Python unittest runner. When running through runner.py, the test runner will report test statistics, including the number of tests with each result type. +Tests can be run from the command line using pytest. When generating a JSON test report, the runner will report detailed test statistics, including output accuracy, delegated nodes, lowering timing, and more. -Backends can be specified with the `ET_TEST_ENABLED_BACKENDS` environment variable. By default, all available backends are enabled. Note that backends such as Core ML or Vulkan may require specific hardware or software to be available. See the documentation for each backend for information on requirements. 
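Since the suite now leans on `pytest-json-report`, here is a short sketch of how the per-subtest metadata described below could be consumed after a run. The `test_report.json` path matches the example command later in this README, and the keys follow the `metadata`/`subtests` layout attached by `conftest.py`.

```python
import json

# Load a report produced with:
#   pytest ... --json-report --json-report-file="test_report.json"
with open("test_report.json") as f:
    report = json.load(f)

for test in report["tests"]:
    for subtest in test.get("metadata", {}).get("subtests", []):
        print(subtest["Test ID"], subtest["Result"], subtest.get("Result Detail", ""))
```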
+Each backend and test flow (recipe) registers a pytest [marker](https://docs.pytest.org/en/stable/example/markers.html) that can be passed to pytest with the `-m marker` argument to filter execution. -Example: +To run all XNNPACK backend operator tests: ``` -ET_TEST_ENABLED_BACKENDS=xnnpack python -m executorch.backends.test.suite.runner +pytest -c /dev/nul backends/test/suite/operators/ -m backend_xnnpack -n auto ``` +To run all model tests for the CoreML static int8 lowering flow: +``` +pytest -c /dev/nul backends/test/suite/models/ -m flow_coreml_static_int8 -n auto ``` -2465 Passed / 2494 -16 Failed -13 Skipped -[Success] -736 Delegated -1729 Undelegated +To run a specific test: +``` +pytest -c /dev/nul backends/test/suite/ -k "test_prelu_f32_custom_init[xnnpack]" +``` -[Failure] -5 Lowering Fail -3 PTE Run Fail -8 Output Mismatch Fail +To generate a JSON report: +``` +pytest -c /dev/nul backends/test/suite/operators/ -n auto --json-report --json-report-file="test_report.json" ``` -Outcomes can be interpreted as follows: - * Success (delegated): The test passed and at least one op was delegated by the backend. - * Success (undelegated): The test passed with no ops delegated by the backend. This is a pass, as the partitioner works as intended. - * Skipped: test fails in eager or export (indicative of a test or dynamo issue). - * Lowering fail: The test fails in to_edge_transform_and_lower. - * PTE run failure: The test errors out when loading or running the method. - * Output mismatch failure: Output delta (vs eager) exceeds the configured tolerance. +See [pytest-json-report](https://pypi.org/project/pytest-json-report/) for information on the report format. The test logic in this repository attaches additional metadata to each test entry under the `metadata`/`subtests` keys. One entry is created for each call to `test_runner.lower_and_run_model`. + +Here is a excerpt from a test run, showing a successful run of the `test_add_f32_bcast_first[xnnpack]` test. +```json +"tests": [ + { + "nodeid": "operators/test_add.py::test_add_f32_bcast_first[xnnpack]", + "lineno": 38, + "outcome": "passed", + "keywords": [ + "test_add_f32_bcast_first[xnnpack]", + "flow_xnnpack", + "backend_xnnpack", + ... + ], + "metadata": { + "subtests": [ + { + "Test ID": "test_add_f32_bcast_first[xnnpack]", + "Test Case": "test_add_f32_bcast_first", + "Subtest": 0, + "Flow": "xnnpack", + "Result": "Pass", + "Result Detail": "", + "Error": "", + "Delegated": "True", + "Quantize Time (s)": null, + "Lower Time (s)": "2.881", + "Output 0 Error Max": "0.000", + "Output 0 Error MAE": "0.000", + "Output 0 SNR": "inf", + "Delegated Nodes": 1, + "Undelegated Nodes": 0, + "Delegated Ops": { + "aten::add.Tensor": 1 + }, + "PTE Size (Kb)": "1.600" + } + ] + } +``` ## Backend Registration @@ -43,11 +77,11 @@ To plug into the test framework, each backend should provide an implementation o At a minimum, the backend will likely need to provide a custom implementation of the Partition and ToEdgeTransformAndLower stages using the appropriate backend partitioner. See backends/xnnpack/test/tester/tester.py for an example implementation. -Once a tester is available, the backend flow(s) can be added in __init__.py in this directory by adding an entry to `ALL_TESTER_FLOWS`. Each flow entry consists of a name (used in the test case naming) and a function to instantiate a tester for a given model and input tuple. +Once a tester is available, the backend flow(s) can be added under flows/ and registered in flow.py. 
It is intended that this will be unified with the lowering recipes under executorch/export in the near future. ## Test Cases -Operator test cases are defined under the operators/ directory. Tests are written in a backend-independent manner, and each test is programmatically expanded to generate a variant for each registered backend flow. The `@operator_test` decorator is applied to each test class to trigger this behavior. Tests can also be tagged with an appropriate type specifier, such as `@dtype_test`, to generate variants for each dtype. The decorators and "magic" live in __init__.py in this directory. +Operator test cases are defined under the operators/ directory. Model tests are under models/. Tests are written in a backend-independent manner, and each test is programmatically expanded to generate a variant for each registered backend flow by use of the `test_runner` fixture parameter. Tests can additionally be parameterized using standard pytest decorators. Parameterizing over dtype is a common use case. ## Evolution of this Test Suite diff --git a/backends/test/suite/__init__.py b/backends/test/suite/__init__.py index 43d4e16818f..734a6690fd2 100644 --- a/backends/test/suite/__init__.py +++ b/backends/test/suite/__init__.py @@ -11,6 +11,7 @@ import os import executorch.backends.test.suite.flow +import torch from executorch.backends.test.suite.flow import TestFlow from executorch.backends.test.suite.runner import runner_main @@ -55,6 +56,11 @@ def get_test_flows() -> dict[str, TestFlow]: return _ALL_TEST_FLOWS +def dtype_to_str(dtype: torch.dtype) -> str: + # Strip off "torch." + return str(dtype)[6:] + + def load_tests(loader, suite, pattern): package_dir = os.path.dirname(__file__) discovered_suite = loader.discover( diff --git a/backends/test/suite/conftest.py b/backends/test/suite/conftest.py new file mode 100644 index 00000000000..70a97454c4e --- /dev/null +++ b/backends/test/suite/conftest.py @@ -0,0 +1,182 @@ +from typing import Any + +import pytest +import torch + +from executorch.backends.test.suite.flow import all_flows +from executorch.backends.test.suite.reporting import _sum_op_counts +from executorch.backends.test.suite.runner import run_test + + +def pytest_configure(config): + backends = set() + + for flow in all_flows().values(): + config.addinivalue_line( + "markers", + f"flow_{flow.name}: mark a test as testing the {flow.name} flow", + ) + + if flow.backend not in backends: + config.addinivalue_line( + "markers", + f"backend_{flow.backend}: mark a test as testing the {flow.backend} backend", + ) + backends.add(flow.backend) + + +class TestRunner: + def __init__(self, flow, test_name, test_base_name): + self._flow = flow + self._test_name = test_name + self._test_base_name = test_base_name + self._subtest = 0 + self._results = [] + + def lower_and_run_model( + self, + model: torch.nn.Module, + inputs: Any, + generate_random_test_inputs=True, + dynamic_shapes=None, + ): + run_summary = run_test( + model, + inputs, + self._flow, + self._test_name, + self._test_base_name, + self._subtest, + None, + generate_random_test_inputs=generate_random_test_inputs, + dynamic_shapes=dynamic_shapes, + ) + + self._subtest += 1 + self._results.append(run_summary) + + if not run_summary.result.is_success(): + if run_summary.result.is_backend_failure(): + raise RuntimeError("Test failure.") from run_summary.error + else: + # Non-backend failure indicates a bad test. Mark as skipped. + pytest.skip( + f"Test failed for reasons other than backend failure. 
Error: {run_summary.error}" + ) + + +@pytest.fixture( + params=[ + pytest.param( + f, + marks=[ + getattr(pytest.mark, f"flow_{f.name}"), + getattr(pytest.mark, f"backend_{f.backend}"), + ], + ) + for f in all_flows().values() + ], + ids=str, +) +def test_runner(request): + return TestRunner(request.param, request.node.name, request.node.originalname) + + +@pytest.hookimpl(optionalhook=True) +def pytest_json_runtest_metadata(item, call): + # Store detailed results in the test report under the metadata key. + metadata = {"subtests": []} + + if hasattr(item, "funcargs") and "test_runner" in item.funcargs: + runner_instance = item.funcargs["test_runner"] + + for record in runner_instance._results: + subtest_metadata = {} + + error_message = "" + if record.error is not None: + error_str = str(record.error) + if len(error_str) > 400: + error_message = error_str[:200] + "..." + error_str[-200:] + else: + error_message = error_str + + subtest_metadata["Test ID"] = record.name + subtest_metadata["Test Case"] = record.base_name + subtest_metadata["Subtest"] = record.subtest_index + subtest_metadata["Flow"] = record.flow + subtest_metadata["Result"] = record.result.to_short_str() + subtest_metadata["Result Detail"] = record.result.to_detail_str() + subtest_metadata["Error"] = error_message + subtest_metadata["Delegated"] = "True" if record.is_delegated() else "False" + subtest_metadata["Quantize Time (s)"] = ( + f"{record.quantize_time.total_seconds():.3f}" + if record.quantize_time + else None + ) + subtest_metadata["Lower Time (s)"] = ( + f"{record.lower_time.total_seconds():.3f}" + if record.lower_time + else None + ) + + for output_idx, error_stats in enumerate(record.tensor_error_statistics): + subtest_metadata[f"Output {output_idx} Error Max"] = ( + f"{error_stats.error_max:.3f}" + ) + subtest_metadata[f"Output {output_idx} Error MAE"] = ( + f"{error_stats.error_mae:.3f}" + ) + subtest_metadata[f"Output {output_idx} SNR"] = f"{error_stats.sqnr:.3f}" + + subtest_metadata["Delegated Nodes"] = _sum_op_counts( + record.delegated_op_counts + ) + subtest_metadata["Undelegated Nodes"] = _sum_op_counts( + record.undelegated_op_counts + ) + if record.delegated_op_counts: + subtest_metadata["Delegated Ops"] = dict(record.delegated_op_counts) + if record.undelegated_op_counts: + subtest_metadata["Undelegated Ops"] = dict(record.undelegated_op_counts) + subtest_metadata["PTE Size (Kb)"] = ( + f"{record.pte_size_bytes / 1000.0:.3f}" if record.pte_size_bytes else "" + ) + + metadata["subtests"].append(subtest_metadata) + return metadata + + +@pytest.hookimpl(optionalhook=True) +def pytest_json_modifyreport(json_report): + # Post-process the report, mainly to populate metadata for crashed tests. The runtest_metadata + # hook doesn't seem to be called when there's a native crash, but xdist still creates a report + # entry. + + for test_data in json_report["tests"]: + if "metadata" not in test_data: + test_data["metadata"] = {} + metadata = test_data["metadata"] + if "subtests" not in metadata: + metadata["subtests"] = [] + subtests = metadata["subtests"] + + # Native crashes are recorded differently and won't have the full metadata. + # Pytest-xdist records crash info under the "???" key. + if "???" 
in test_data: + test_id = test_data["nodeid"].removeprefix("::") # Remove leading :: + test_base_id = test_id.split("[")[ + 0 + ] # Strip parameterization to get the base test case + params = test_id[len(test_base_id) + 1 : -1].split("-") + flow = params[0] + + crashed_test_meta = { + "Test ID": test_id, + "Test Case": test_base_id, + "Flow": flow, + "Result": "Fail", + "Result Detail": "Process Crash", + "Error": test_data["???"].get("longrepr", "Process crashed."), + } + subtests.append(crashed_test_meta) diff --git a/backends/test/suite/flow.py b/backends/test/suite/flow.py index b7a126eaf35..f3c9ee75083 100644 --- a/backends/test/suite/flow.py +++ b/backends/test/suite/flow.py @@ -1,6 +1,11 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import logging -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Callable from executorch.backends.test.harness import Tester @@ -35,6 +40,18 @@ class TestFlow: is_delegated: bool = True """ Indicates whether the flow is expected to generate CALL_DELEGATE nodes. """ + skip_patterns: list[str] = field(default_factory=lambda: []) + """ Tests with names containing any substrings in this list are skipped. """ + + supports_serialize: bool = True + """ True if the test flow supports the Serialize stage. """ + + def should_skip_test(self, test_name: str) -> bool: + return any(pattern in test_name for pattern in self.skip_patterns) + + def __str__(self): + return self.name + def all_flows() -> dict[str, TestFlow]: flows = [] @@ -109,4 +126,25 @@ def all_flows() -> dict[str, TestFlow]: except Exception as e: logger.info(f"Skipping QNN flow registration: {e}") + try: + from executorch.backends.test.suite.flows.arm import ( + ARM_ETHOS_U55_FLOW, + ARM_ETHOS_U85_FLOW, + ARM_TOSA_FP_FLOW, + ARM_TOSA_INT_FLOW, + ARM_VGF_FP_FLOW, + ARM_VGF_INT_FLOW, + ) + + flows += [ + ARM_TOSA_FP_FLOW, + ARM_TOSA_INT_FLOW, + ARM_ETHOS_U55_FLOW, + ARM_ETHOS_U85_FLOW, + ARM_VGF_FP_FLOW, + ARM_VGF_INT_FLOW, + ] + except Exception as e: + logger.info(f"Skipping ARM flow registration: {e}") + return {f.name: f for f in flows if f is not None} diff --git a/backends/test/suite/flows/arm.py b/backends/test/suite/flows/arm.py new file mode 100644 index 00000000000..29ef504d50c --- /dev/null +++ b/backends/test/suite/flows/arm.py @@ -0,0 +1,92 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
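The new `skip_patterns` field on `TestFlow` skips tests by name substring (the Core ML and Vulkan flows below use it for known-unsupported cases). A minimal sketch of the matching behavior, using a stand-in dataclass so it runs without the test suite installed:

```python
from dataclasses import dataclass, field


@dataclass
class FlowSketch:
    # Stand-in for TestFlow, reduced to what should_skip_test needs.
    name: str
    skip_patterns: list[str] = field(default_factory=list)

    def should_skip_test(self, test_name: str) -> bool:
        return any(pattern in test_name for pattern in self.skip_patterns)


coreml_like = FlowSketch("coreml", skip_patterns=["test_argmin", "test_argmax"])
print(coreml_like.should_skip_test("test_argmax_f32[coreml]"))  # True
print(coreml_like.should_skip_test("test_add_f32[coreml]"))     # False
```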
+ +# Create flows for Arm Backends used to test operator and model suits + +from collections.abc import Callable + +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec +from executorch.backends.arm.quantizer import get_symmetric_quantization_config +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.backends.arm.util._factory import create_quantizer +from executorch.backends.test.suite.flow import TestFlow +from executorch.backends.xnnpack.test.tester.tester import Quantize + + +def _create_arm_flow( + name: str, + compile_spec_factory: Callable[[], ArmCompileSpec], + support_serialize: bool = True, + quantize: bool = True, + symmetric_io_quantization: bool = False, + per_channel_quantization: bool = True, + use_portable_ops: bool = True, + timeout: int = 1200, +) -> TestFlow: + + def _create_arm_tester(*args, **kwargs) -> ArmTester: + spec = compile_spec_factory() + kwargs["compile_spec"] = spec + return ArmTester( + *args, **kwargs, use_portable_ops=use_portable_ops, timeout=timeout + ) + + if quantize: + + def create_quantize_stage() -> Quantize: + spec = compile_spec_factory() + quantizer = create_quantizer(spec) + quantization_config = get_symmetric_quantization_config( + is_per_channel=per_channel_quantization + ) + if symmetric_io_quantization: + quantizer.set_io(quantization_config) + return Quantize(quantizer, quantization_config) # type: ignore + + return TestFlow( + name, + backend="arm", + tester_factory=_create_arm_tester, + supports_serialize=support_serialize, + quantize=quantize, + quantize_stage_factory=(create_quantize_stage if quantize else False), # type: ignore + ) + + +ARM_TOSA_FP_FLOW = _create_arm_flow( + "arm_tosa_fp", + lambda: common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"), + support_serialize=False, + quantize=False, +) +ARM_TOSA_INT_FLOW = _create_arm_flow( + "arm_tosa_int", + lambda: common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"), + support_serialize=False, + quantize=True, +) +ARM_ETHOS_U55_FLOW = _create_arm_flow( + "arm_ethos_u55", + lambda: common.get_u55_compile_spec(), + quantize=True, +) +ARM_ETHOS_U85_FLOW = _create_arm_flow( + "arm_ethos_u85", + lambda: common.get_u85_compile_spec(), + quantize=True, +) +ARM_VGF_FP_FLOW = _create_arm_flow( + "arm_vgf_fp", + lambda: common.get_vgf_compile_spec(tosa_spec="TOSA-1.0+FP"), + quantize=False, + use_portable_ops=False, +) +ARM_VGF_INT_FLOW = _create_arm_flow( + "arm_vgf_int", + lambda: common.get_vgf_compile_spec(tosa_spec="TOSA-1.0+INT"), + quantize=True, + use_portable_ops=False, +) diff --git a/backends/test/suite/flows/coreml.py b/backends/test/suite/flows/coreml.py index fd956b64f05..8a532ff0003 100644 --- a/backends/test/suite/flows/coreml.py +++ b/backends/test/suite/flows/coreml.py @@ -19,6 +19,7 @@ def _create_coreml_flow( CoreMLTester, minimum_deployment_target=minimum_deployment_target ), quantize=quantize, + skip_patterns=["test_argmin", "test_argmax"], ) diff --git a/backends/test/suite/flows/qualcomm.py b/backends/test/suite/flows/qualcomm.py index 9998caa51b6..99deb3d4877 100644 --- a/backends/test/suite/flows/qualcomm.py +++ b/backends/test/suite/flows/qualcomm.py @@ -42,7 +42,7 @@ def create_quantize_stage() -> Quantize: QNN_TEST_FLOW = _create_qnn_flow("qnn") QNN_16A16W_TEST_FLOW = _create_qnn_flow( - "qnn_16a16w", quantize=True, quant_dtype=QuantDtype.use_8a8w, use_fp16=False + "qnn_16a16w", quantize=True, quant_dtype=QuantDtype.use_16a16w, use_fp16=False ) 
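One practical effect of registering these Arm flows is that `pytest_configure` in `conftest.py` (earlier in this patch) derives filter markers from them. A small illustration of the derived names, using a stub in place of `TestFlow`:

```python
from dataclasses import dataclass


@dataclass
class FlowStub:
    # Stand-in for TestFlow with just the fields pytest_configure reads.
    name: str
    backend: str


flows = [FlowStub("arm_tosa_int", "arm"), FlowStub("arm_ethos_u55", "arm")]
markers = sorted({f"flow_{f.name}" for f in flows} | {f"backend_{f.backend}" for f in flows})
print(markers)  # ['backend_arm', 'flow_arm_ethos_u55', 'flow_arm_tosa_int']
```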
QNN_16A8W_TEST_FLOW = _create_qnn_flow( "qnn_16a8w", quantize=True, quant_dtype=QuantDtype.use_16a8w, use_fp16=False diff --git a/backends/test/suite/flows/vulkan.py b/backends/test/suite/flows/vulkan.py index 2a8c4e506fa..a3a4fb55aba 100644 --- a/backends/test/suite/flows/vulkan.py +++ b/backends/test/suite/flows/vulkan.py @@ -20,6 +20,7 @@ def _create_vulkan_flow_base( tester_factory=VulkanTester, quantize=quantize_stage_factory is not None, quantize_stage_factory=quantize_stage_factory, + skip_patterns=["float16", "float64"], # Not supported in swiftshader ) diff --git a/backends/test/suite/generate_markdown_summary.py b/backends/test/suite/generate_markdown_summary.py index 37bf758fed0..e54fc691723 100644 --- a/backends/test/suite/generate_markdown_summary.py +++ b/backends/test/suite/generate_markdown_summary.py @@ -1,7 +1,58 @@ import argparse import csv +import json import sys +from dataclasses import dataclass, field + + +@dataclass +class ResultCounts: + """ + Represents aggregated result counts for each status. + """ + + total: int = 0 + passes: int = 0 + fails: int = 0 + skips: int = 0 + by_detail: dict[str, int] = field(default_factory=lambda: {}) + + def add_row(self, result_value: str, result_detail: str) -> None: + """ + Update the result counts for the specified row. + """ + + self.total += 1 + + if result_value == "Pass": + self.passes += 1 + elif result_value == "Fail": + self.fails += 1 + elif result_value == "Skip": + self.skips += 1 + else: + raise RuntimeError(f"Unknown result value {result_value}") + + if result_detail: + if result_detail not in self.by_detail: + self.by_detail[result_detail] = 0 + + self.by_detail[result_detail] += 1 + + +@dataclass +class AggregatedSummary: + """ + Represents aggegrated summary data for the test run. + """ + + counts: ResultCounts + counts_by_params: dict[str, ResultCounts] + failed_tests: list[list[str]] + header: list[str] + + # # A standalone script to generate a Markdown representation of a test report. # This is primarily intended to be used with GitHub actions to generate a nice @@ -12,14 +63,7 @@ # -def generate_markdown(csv_path: str, exit_code: int = 0): # noqa (C901) - # Print warning if exit code is non-zero - if exit_code != 0: - print("> [!WARNING]") - print( - f"> Exit code {exit_code} was non-zero. Test process may have crashed. 
Check the job logs for more information.\n" - ) - +def aggregate_results(csv_path: str) -> AggregatedSummary: with open(csv_path, newline="", encoding="utf-8") as f: reader = csv.reader(f) rows = list(reader) @@ -27,78 +71,159 @@ def generate_markdown(csv_path: str, exit_code: int = 0): # noqa (C901) header = rows[0] data_rows = rows[1:] - # Find the Result and Result Detail column indices - result_column_index = None - result_detail_column_index = None - for i, col in enumerate(header): - if col.lower() == "result": - result_column_index = i - elif col.lower() == "result detail": - result_detail_column_index = i + header_indices_by_name = {n.lower(): i for (i, n) in enumerate(header)} + params_column_index = header_indices_by_name.get("params", None) + result_column_index = header_indices_by_name["result"] + result_detail_column_index = header_indices_by_name["result detail"] # Count results and prepare data - pass_count = 0 - fail_count = 0 - skip_count = 0 + counts = ResultCounts() failed_tests = [] - processed_rows = [] - result_detail_counts = {} + counts_by_param = {} for row in data_rows: + result = row[result_column_index] + result_detail = row[result_detail_column_index] + + counts.add_row(result, result_detail) + + params = row[params_column_index] if params_column_index else None + if params: + if params not in counts_by_param: + counts_by_param[params] = ResultCounts() + counts_by_param[params].add_row(result, result_detail) + # Make a copy of the row to avoid modifying the original - processed_row = row.copy() + processed_row = [escape_for_markdown(cell) for cell in row] # Count results and collect failed tests if result_column_index is not None and result_column_index < len(row): result_value = row[result_column_index].strip().lower() if result_value == "pass": - pass_count += 1 processed_row[result_column_index] = ( 'Pass' ) elif result_value == "fail": - fail_count += 1 processed_row[result_column_index] = ( 'Fail' ) failed_tests.append(processed_row.copy()) elif result_value == "skip": - skip_count += 1 processed_row[result_column_index] = ( 'Skip' ) - # Count result details (excluding empty ones) - if result_detail_column_index is not None and result_detail_column_index < len( - row - ): - result_detail_value = row[result_detail_column_index].strip() - if result_detail_value: # Only count non-empty result details - if result_detail_value in result_detail_counts: - result_detail_counts[result_detail_value] += 1 - else: - result_detail_counts[result_detail_value] = 1 + return AggregatedSummary( + counts=counts, + failed_tests=failed_tests, + counts_by_params=counts_by_param, + header=header, + ) + + +def escape_for_markdown(text: str) -> str: + """ + Modify a string to properly display in a markdown table cell. + """ + if not text: + return text - processed_rows.append(processed_row) + # Replace newlines with
<br/> tags + escaped = text.replace("\n", "<br/>
") + + # Escape backslashes. + escaped = escaped.replace("\\", "\\\\") + + # Escape pipe characters that would break table structure + escaped = escaped.replace("|", "\\|") + + return escaped + + +def generate_markdown(csv_path: str, exit_code: int = 0): # noqa (C901) + # Print warning if exit code is non-zero + if exit_code != 0: + print("> [!WARNING]") + print( + f"> Exit code {exit_code} was non-zero. Test process may have crashed. Check the job logs for more information.\n" + ) + + results = aggregate_results(csv_path) # Generate Summary section - total_rows = len(data_rows) print("# Summary\n") - print(f"- **Pass**: {pass_count}/{total_rows}") - print(f"- **Fail**: {fail_count}/{total_rows}") - print(f"- **Skip**: {skip_count}/{total_rows}") + total_excluding_skips = results.counts.passes + results.counts.fails + pass_fraction = results.counts.passes / total_excluding_skips + fail_fraction = results.counts.fails / total_excluding_skips + print( + f"- **Pass**: {results.counts.passes}/{total_excluding_skips} ({pass_fraction*100:.2f}%)" + ) + print( + f"- **Fail**: {results.counts.fails}/{total_excluding_skips} ({fail_fraction*100:.2f}%)" + ) + print(f"- **Skip**: {results.counts.skips}") + + if results.counts_by_params: + print("\n## Results by Parameters\n") + + # Extract all unique parameter keys from the JSON strings + all_param_keys = set() + parsed_params = {} + + for params_str in results.counts_by_params.keys(): + # Parse the JSON string (it's a string representation of a dict) + params_dict = json.loads(params_str) + parsed_params[params_str] = params_dict + all_param_keys.update(params_dict.keys()) + + if parsed_params and len(parsed_params) > 1: + # Sort parameter keys for consistent column ordering + sorted_param_keys = sorted(all_param_keys) + + # Create table header + header_cols = sorted_param_keys + ["Pass", "Fail", "Skip", "Pass %"] + print("| " + " | ".join(header_cols) + " |") + print("|" + "|".join(["---"] * len(header_cols)) + "|") + + # Create table rows + for params_str, counts in results.counts_by_params.items(): + if params_str in parsed_params: + params_dict = parsed_params[params_str] + row_values = [] + + # Add parameter values + for key in sorted_param_keys: + value = params_dict.get(key, "") + row_values.append(str(value)) + + pass_fraction = counts.passes / (counts.passes + counts.fails) + + # Add count values + row_values.extend( + [ + str(counts.passes), + str(counts.fails), + str(counts.skips), + f"{pass_fraction*100:.2f}%", + ] + ) + + print("| " + " | ".join(row_values) + " |") + + print() print("## Failure Breakdown:") - total_rows_with_result_detail = sum(result_detail_counts.values()) - for detail, count in sorted(result_detail_counts.items()): + total_rows_with_result_detail = sum(results.counts.by_detail.values()) + for detail, count in sorted(results.counts.by_detail.items()): print(f"- **{detail}**: {count}/{total_rows_with_result_detail}") # Generate Failed Tests section print("# Failed Tests\n") - if failed_tests: - print("| " + " | ".join(header) + " |") - print("|" + "|".join(["---"] * len(header)) + "|") - for row in failed_tests: + if results.failed_tests: + escaped_header = [escape_for_markdown(col) for col in results.header] + print("| " + " | ".join(escaped_header) + " |") + print("|" + "|".join(["---"] * len(results.header)) + "|") + for row in results.failed_tests: print("| " + " | ".join(row) + " |") else: print("No failed tests.\n") diff --git a/backends/test/suite/generate_markdown_summary_json.py 
b/backends/test/suite/generate_markdown_summary_json.py new file mode 100644 index 00000000000..4b6edc2a635 --- /dev/null +++ b/backends/test/suite/generate_markdown_summary_json.py @@ -0,0 +1,229 @@ +import argparse +import json + +from dataclasses import dataclass, field + + +@dataclass +class ResultCounts: + """ + Represents aggregated result counts for each status. + """ + + total: int = 0 + passes: int = 0 + fails: int = 0 + skips: int = 0 + by_detail: dict[str, int] = field(default_factory=lambda: {}) + + def add_row(self, result_value: str, result_detail: str) -> None: + """ + Update the result counts for the specified row. + """ + + self.total += 1 + + if result_value == "Pass": + self.passes += 1 + elif result_value == "Fail": + self.fails += 1 + elif result_value == "Skip": + self.skips += 1 + else: + raise RuntimeError(f"Unknown result value {result_value}") + + if result_detail: + if result_detail not in self.by_detail: + self.by_detail[result_detail] = 0 + + self.by_detail[result_detail] += 1 + + +@dataclass +class AggregatedSummary: + """ + Represents aggegrated summary data for the test run. + """ + + counts: ResultCounts + counts_by_params: dict[str, ResultCounts] + failed_tests: list[list[str]] + + +# +# A standalone script to generate a Markdown representation of a test report. +# This is primarily intended to be used with GitHub actions to generate a nice +# representation of the test results when looking at the action run. +# +# Usage: python executorch/backends/test/suite/generate_markdown_summary.py +# Markdown is written to stdout. +# + + +def aggregate_results(json_path: str) -> AggregatedSummary: + with open(json_path) as f: + data = json.load(f) + + # Count results and prepare data + counts = ResultCounts() + failed_tests = [] + counts_by_param = {} + + for test_data in data["tests"]: + result_meta = test_data["metadata"] + for subtest_meta in result_meta["subtests"]: + result = subtest_meta["Result"] + result_detail = subtest_meta.get("Result Detail") or "" + + counts.add_row(result, result_detail) + + test_id = subtest_meta["Test ID"] + base_test = subtest_meta["Test Case"] + params = test_id[len(base_test) + 1 : -1] + + if params: + if params not in counts_by_param: + counts_by_param[params] = ResultCounts() + counts_by_param[params].add_row(result, result_detail) + + if result.lower() == "fail": + failed_tests.append(subtest_meta) + + return AggregatedSummary( + counts=counts, + failed_tests=failed_tests, + counts_by_params=counts_by_param, + ) + + +def escape_for_markdown(text: str) -> str: + """ + Modify a string to properly display in a markdown table cell. + """ + if not text: + return text + + # Replace newlines with
<br/> tags + escaped = text.replace("\n", "<br/>
") + + # Escape backslashes. + escaped = escaped.replace("\\", "\\\\") + + # Escape pipe characters that would break table structure + escaped = escaped.replace("|", "\\|") + + return escaped + + +def generate_markdown(json_path: str, exit_code: int = 0): # noqa (C901) + results = aggregate_results(json_path) + + # Generate Summary section + print("# Summary\n") + total_excluding_skips = results.counts.passes + results.counts.fails + pass_fraction = results.counts.passes / total_excluding_skips + fail_fraction = results.counts.fails / total_excluding_skips + print( + f"- **Pass**: {results.counts.passes}/{total_excluding_skips} ({pass_fraction*100:.2f}%)" + ) + print( + f"- **Fail**: {results.counts.fails}/{total_excluding_skips} ({fail_fraction*100:.2f}%)" + ) + print(f"- **Skip**: {results.counts.skips}") + + if results.counts_by_params: + print("\n## Results by Parameters\n") + + if len(results.counts_by_params) > 0: + # Create table header + header_cols = ["Params", "Pass", "Fail", "Skip", "Pass %"] + print("| " + " | ".join(header_cols) + " |") + print("|" + "|".join(["---"] * len(header_cols)) + "|") + + # Create table rows + for params_str, counts in results.counts_by_params.items(): + row_values = [params_str] + + # Add parameter values + pass_fraction = counts.passes / (counts.passes + counts.fails) + + # Add count values + row_values.extend( + [ + str(counts.passes), + str(counts.fails), + str(counts.skips), + f"{pass_fraction*100:.2f}%", + ] + ) + + print("| " + " | ".join(row_values) + " |") + + print() + + print("## Failure Breakdown:") + total_rows_with_result_detail = sum(results.counts.by_detail.values()) + for detail, count in sorted(results.counts.by_detail.items()): + print(f"- **{detail}**: {count}/{total_rows_with_result_detail}") + + # Generate Failed Tests section + print("# Failed Tests\n") + print( + "To reproduce, run the following command from the root of the ExecuTorch repository:" + ) + print("```") + print('pytest -c /dev/nul backends/test/suite/ -k ""') + print("```") + if results.failed_tests: + header = build_header(results.failed_tests) + + escaped_header = [escape_for_markdown(col) for col in header.keys()] + print("| " + " | ".join(escaped_header) + " |") + print("|" + "|".join(["---"] * len(escaped_header)) + "|") + for rec in results.failed_tests: + row = build_row(rec, header) + print("| " + " | ".join(row) + " |") + else: + print("No failed tests.\n") + + +def build_header(data) -> dict[str, int]: + """ + Find the union of all keys and return a dict of header keys and indices. Try to preserve + ordering as much as possible. + """ + + keys = max(data, key=len) + + header = {k: i for (i, k) in enumerate(keys)} + + for rec in data: + keys = set(rec.keys()) + for k in keys: + if k not in header: + header[k] = len(header) + + return header + + +def build_row(rec, header: dict[str, int]) -> list[str]: + row = [""] * len(header) + for k, v in rec.items(): + row[header[k]] = escape_for_markdown(str(v)) + return row + + +def main(): + parser = argparse.ArgumentParser( + description="Generate a Markdown representation of a test report." + ) + parser.add_argument("json_path", help="Path to the test report CSV file.") + parser.add_argument( + "--exit-code", type=int, default=0, help="Exit code from the test process." 
+ ) + args = parser.parse_args() + generate_markdown(args.json_path, args.exit_code) + + +if __name__ == "__main__": + main() diff --git a/backends/test/suite/models/__init__.py b/backends/test/suite/models/__init__.py index 65b546b0eb5..6ac1a72bde6 100644 --- a/backends/test/suite/models/__init__.py +++ b/backends/test/suite/models/__init__.py @@ -5,131 +5,3 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe - -import itertools -import os -import unittest -from typing import Any, Callable - -import torch -from executorch.backends.test.suite import get_test_flows -from executorch.backends.test.suite.context import get_active_test_context, TestContext -from executorch.backends.test.suite.flow import TestFlow -from executorch.backends.test.suite.reporting import log_test_summary -from executorch.backends.test.suite.runner import run_test - - -DTYPES: list[torch.dtype] = [ - torch.float16, - torch.float32, -] - - -def load_tests(loader, suite, pattern): - package_dir = os.path.dirname(__file__) - discovered_suite = loader.discover( - start_dir=package_dir, pattern=pattern or "test_*.py" - ) - suite.addTests(discovered_suite) - return suite - - -def _create_test( - cls, - test_func: Callable, - flow: TestFlow, - dtype: torch.dtype, - use_dynamic_shapes: bool, -): - dtype_name = str(dtype)[6:] # strip "torch." - test_name = f"{test_func.__name__}_{flow.name}_{dtype_name}" - if use_dynamic_shapes: - test_name += "_dynamic_shape" - - def wrapped_test(self): - params = { - "dtype": dtype, - "use_dynamic_shapes": use_dynamic_shapes, - } - with TestContext(test_name, test_func.__name__, flow.name, params): - test_func(self, flow, dtype, use_dynamic_shapes) - - wrapped_test._name = test_func.__name__ # type: ignore - wrapped_test._flow = flow # type: ignore - - setattr(cls, test_name, wrapped_test) - - -# Expand a test into variants for each registered flow. -def _expand_test(cls, test_name: str) -> None: - test_func = getattr(cls, test_name) - supports_dynamic_shapes = getattr(test_func, "supports_dynamic_shapes", True) - dynamic_shape_values = [True, False] if supports_dynamic_shapes else [False] - dtypes = getattr(test_func, "dtypes", DTYPES) - - for flow, dtype, use_dynamic_shapes in itertools.product( - get_test_flows().values(), dtypes, dynamic_shape_values - ): - _create_test(cls, test_func, flow, dtype, use_dynamic_shapes) - delattr(cls, test_name) - - -def model_test_cls(cls) -> Callable | None: - """Decorator for model tests. Handles generating test variants for each test flow and configuration.""" - for key in dir(cls): - if key.startswith("test_"): - _expand_test(cls, key) - return cls - - -def model_test_params( - supports_dynamic_shapes: bool = True, - dtypes: list[torch.dtype] | None = None, -) -> Callable: - """Optional parameter decorator for model tests. Specifies test pararameters. Only valid with a class decorated by model_test_cls.""" - - def inner_decorator(func: Callable) -> Callable: - func.supports_dynamic_shapes = supports_dynamic_shapes # type: ignore - - if dtypes is not None: - func.dtypes = dtypes # type: ignore - - return func - - return inner_decorator - - -def run_model_test( - model: torch.nn.Module, - inputs: tuple[Any], - flow: TestFlow, - dtype: torch.dtype, - dynamic_shapes: Any | None, -): - model = model.to(dtype) - context = get_active_test_context() - - # This should be set in the wrapped test. See _create_test above. - assert context is not None, "Missing test context." 
- - run_summary = run_test( - model, - inputs, - flow, - context.test_name, - context.test_base_name, - 0, # subtest_index - currently unused for model tests - context.params, - dynamic_shapes=dynamic_shapes, - ) - - log_test_summary(run_summary) - - if not run_summary.result.is_success(): - if run_summary.result.is_backend_failure(): - raise RuntimeError("Test failure.") from run_summary.error - else: - # Non-backend failure indicates a bad test. Mark as skipped. - raise unittest.SkipTest( - f"Test failed for reasons other than backend failure. Error: {run_summary.error}" - ) diff --git a/backends/test/suite/models/test_torchaudio.py b/backends/test/suite/models/test_torchaudio.py index 69f6de4684f..2287b226c37 100644 --- a/backends/test/suite/models/test_torchaudio.py +++ b/backends/test/suite/models/test_torchaudio.py @@ -9,15 +9,11 @@ import unittest from typing import Tuple +import pytest import torch import torchaudio -from executorch.backends.test.suite.flow import TestFlow -from executorch.backends.test.suite.models import ( - model_test_cls, - model_test_params, - run_model_test, -) +from executorch.backends.test.suite import dtype_to_str from torch.export import Dim # @@ -47,64 +43,68 @@ def forward( return x.transpose(0, 1) -@model_test_cls -class TorchAudio(unittest.TestCase): - @model_test_params(dtypes=[torch.float32], supports_dynamic_shapes=False) - def test_conformer( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - inner_model = torchaudio.models.Conformer( - input_dim=80, - num_heads=4, - ffn_dim=128, - num_layers=4, - depthwise_conv_kernel_size=31, - ) - model = PatchedConformer(inner_model) - lengths = torch.randint(1, 400, (10,)) +@pytest.mark.parametrize("dtype", [torch.float32], ids=dtype_to_str) +@pytest.mark.parametrize("use_dynamic_shapes", [False], ids=["static_shapes"]) +def test_conformer(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + inner_model = torchaudio.models.Conformer( + input_dim=80, + num_heads=4, + ffn_dim=128, + num_layers=4, + depthwise_conv_kernel_size=31, + ) + model = PatchedConformer(inner_model).eval().to(dtype) + lengths = torch.randint(1, 400, (10,)) - encoder_padding_mask = torchaudio.models.conformer._lengths_to_padding_mask( - lengths - ) - inputs = ( - torch.rand(10, int(lengths.max()), 80), - encoder_padding_mask, - ) + encoder_padding_mask = torchaudio.models.conformer._lengths_to_padding_mask(lengths) + inputs = ( + torch.rand(10, int(lengths.max()), 80), + encoder_padding_mask, + ) + + test_runner.lower_and_run_model(model, inputs) - run_model_test(model, inputs, flow, dtype, None) - - @model_test_params(dtypes=[torch.float32]) - def test_wav2letter( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchaudio.models.Wav2Letter() - inputs = (torch.randn(1, 1, 1024, dtype=dtype),) - dynamic_shapes = ( - { - "x": { - 2: Dim("d", min=900, max=1024), - } + +@pytest.mark.parametrize("dtype", [torch.float32], ids=dtype_to_str) +@pytest.mark.parametrize( + "use_dynamic_shapes", [False, True], ids=["static_shapes", "dynamic_shapes"] +) +def test_wav2letter(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchaudio.models.Wav2Letter().to(dtype) + inputs = (torch.randn(1, 1, 1024, dtype=dtype),) + dynamic_shapes = ( + { + "x": { + 2: Dim("d", min=900, max=1024), } - if use_dynamic_shapes - else None - ) - run_model_test(model, inputs, flow, dtype, dynamic_shapes) - - @unittest.skip("This model times out on all backends.") - def test_wavernn( - 
self, - flow: TestFlow, - dtype: torch.dtype, - use_dynamic_shapes: bool, - ): - model = torchaudio.models.WaveRNN( + } + if use_dynamic_shapes + else None + ) + + test_runner.lower_and_run_model(model, inputs, dynamic_shapes=dynamic_shapes) + + +@pytest.mark.parametrize("dtype", [torch.float32], ids=dtype_to_str) +@pytest.mark.parametrize("use_dynamic_shapes", [False], ids=["static_shapes"]) +@unittest.skip("This model times out on all backends.") +def test_wavernn( + test_runner, + dtype: torch.dtype, + use_dynamic_shapes: bool, +): + model = ( + torchaudio.models.WaveRNN( upsample_scales=[5, 5, 8], n_classes=512, hop_length=200 - ).eval() - - # See https://docs.pytorch.org/audio/stable/generated/torchaudio.models.WaveRNN.html#forward - inputs = ( - torch.randn(1, 1, (64 - 5 + 1) * 200), # waveform - torch.randn(1, 1, 128, 64), # specgram ) + .eval() + .to(dtype) + ) + + # See https://docs.pytorch.org/audio/stable/generated/torchaudio.models.WaveRNN.html#forward + inputs = ( + torch.randn(1, 1, (64 - 5 + 1) * 200).to(dtype), # waveform + torch.randn(1, 1, 128, 64).to(dtype), # specgram + ) - run_model_test(model, inputs, flow, dtype, None) + test_runner.lower_and_run_model(model, inputs) diff --git a/backends/test/suite/models/test_torchvision.py b/backends/test/suite/models/test_torchvision.py index e69de80a871..58cf6a990d4 100644 --- a/backends/test/suite/models/test_torchvision.py +++ b/backends/test/suite/models/test_torchvision.py @@ -6,17 +6,12 @@ # pyre-unsafe -import unittest +import pytest import torch import torchvision +from executorch.backends.test.suite import dtype_to_str -from executorch.backends.test.suite.flow import TestFlow -from executorch.backends.test.suite.models import ( - model_test_cls, - model_test_params, - run_model_test, -) from torch.export import Dim # @@ -25,148 +20,175 @@ # multiple size variants, one small or medium variant is used. # +PARAMETERIZE_DTYPE = pytest.mark.parametrize("dtype", [torch.float32], ids=dtype_to_str) +PARAMETERIZE_DYNAMIC_SHAPES = pytest.mark.parametrize( + "use_dynamic_shapes", [False, True], ids=["static_shapes", "dynamic_shapes"] +) +PARAMETERIZE_STATIC_ONLY = pytest.mark.parametrize( + "use_dynamic_shapes", [False], ids=["static_shapes"] +) + + +def _test_cv_model( + model: torch.nn.Module, + test_runner, + dtype: torch.dtype, + use_dynamic_shapes: bool, +): + model = model.eval().to(dtype) + + # Test a CV model that follows the standard conventions. + inputs = (torch.randn(1, 3, 224, 224, dtype=dtype),) -@model_test_cls -class TorchVision(unittest.TestCase): - def _test_cv_model( - self, - model: torch.nn.Module, - flow: TestFlow, - dtype: torch.dtype, - use_dynamic_shapes: bool, - ): - # Test a CV model that follows the standard conventions. 
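The converted model tests above are plain pytest functions that receive a test_runner fixture and call lower_and_run_model. The fixture itself is not part of this hunk; the sketch below is a hypothetical, minimal stand-in showing the contract the tests rely on (the _StubRunner class, the eager-only execution, and the single "stub_flow" param are assumptions, not the suite's real fixture).

import pytest
import torch


class _StubRunner:
    """Hypothetical stand-in for the object the real test_runner fixture yields."""

    def __init__(self, flow_name: str):
        self._flow = flow_name

    def lower_and_run_model(
        self, model, inputs, dynamic_shapes=None, generate_random_test_inputs=True
    ):
        # The real runner exports the model, lowers it to the backend under
        # test, and compares against eager outputs; this stub only runs eager
        # mode so the sketch stays self-contained.
        with torch.no_grad():
            return model(*inputs)


@pytest.fixture(params=["stub_flow"])
def test_runner(request):
    # One fixture instance per flow; the real suite derives the param list
    # from its flow registry.
    return _StubRunner(request.param)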
- inputs = (torch.randn(1, 3, 224, 224, dtype=dtype),) - - dynamic_shapes = ( - ( - { - 2: Dim("height", min=1, max=16) * 16, - 3: Dim("width", min=1, max=16) * 16, - }, - ) - if use_dynamic_shapes - else None + dynamic_shapes = ( + ( + { + 2: Dim("height", min=1, max=16) * 16, + 3: Dim("width", min=1, max=16) * 16, + }, ) + if use_dynamic_shapes + else None + ) + + test_runner.lower_and_run_model(model, inputs, dynamic_shapes=dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_alexnet(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.alexnet() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_convnext_small(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.convnext_small() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_densenet161(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.densenet161() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_efficientnet_b4(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.efficientnet_b4() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_efficientnet_v2_s(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.efficientnet_v2_s() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_googlenet(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.googlenet() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_inception_v3(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.inception_v3() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_STATIC_ONLY +def test_maxvit_t(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.maxvit_t() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_mnasnet1_0(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.mnasnet1_0() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_mobilenet_v2(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.mobilenet_v2() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_mobilenet_v3_small(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.mobilenet_v3_small() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_regnet_y_1_6gf(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.regnet_y_1_6gf() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_resnet50(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = 
torchvision.models.resnet50() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_resnext50_32x4d(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.resnext50_32x4d() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_shufflenet_v2_x1_0(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.shufflenet_v2_x1_0() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_squeezenet1_1(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.squeezenet1_1() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_swin_v2_t(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.swin_v2_t() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_vgg11(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.vgg11() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + + +@PARAMETERIZE_DTYPE +@PARAMETERIZE_STATIC_ONLY +def test_vit_b_16(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.vit_b_16() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) + - run_model_test(model, inputs, flow, dtype, dynamic_shapes) - - def test_alexnet( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.alexnet() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_convnext_small( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.convnext_small() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_densenet161( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.densenet161() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_efficientnet_b4( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.efficientnet_b4() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_efficientnet_v2_s( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.efficientnet_v2_s() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_googlenet( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.googlenet() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_inception_v3( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.inception_v3() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - @model_test_params(supports_dynamic_shapes=False) - def test_maxvit_t( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.maxvit_t() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_mnasnet1_0( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.mnasnet1_0() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_mobilenet_v2( 
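The torchvision conversion reuses module-level parametrize markers (PARAMETERIZE_DTYPE, PARAMETERIZE_DYNAMIC_SHAPES, PARAMETERIZE_STATIC_ONLY) so every model test expands into the same dtype/shape variants without repeating decorators. A minimal, self-contained illustration of that pattern follows; the model and ids below are placeholders, not suite code.

import pytest
import torch

# Shared markers, mirroring PARAMETERIZE_DTYPE / PARAMETERIZE_DYNAMIC_SHAPES above.
PARAM_DTYPE = pytest.mark.parametrize("dtype", [torch.float32], ids=["float32"])
PARAM_SHAPES = pytest.mark.parametrize(
    "use_dynamic_shapes", [False, True], ids=["static_shapes", "dynamic_shapes"]
)


@PARAM_DTYPE
@PARAM_SHAPES
def test_example_model(dtype, use_dynamic_shapes):
    # Expands at collection time into float32/static and float32/dynamic variants.
    x = torch.randn(1, 3, 224, 224, dtype=dtype)
    assert x.dtype == dtype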
- self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.mobilenet_v2() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_mobilenet_v3_small( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.mobilenet_v3_small() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_regnet_y_1_6gf( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.regnet_y_1_6gf() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_resnet50( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.resnet50() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_resnext50_32x4d( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.resnext50_32x4d() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_shufflenet_v2_x1_0( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.shufflenet_v2_x1_0() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_squeezenet1_1( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.squeezenet1_1() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_swin_v2_t( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.swin_v2_t() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_vgg11(self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool): - model = torchvision.models.vgg11() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - @model_test_params(supports_dynamic_shapes=False) - def test_vit_b_16( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.vit_b_16() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - - def test_wide_resnet50_2( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): - model = torchvision.models.wide_resnet50_2() - self._test_cv_model(model, flow, dtype, use_dynamic_shapes) +@PARAMETERIZE_DTYPE +@PARAMETERIZE_DYNAMIC_SHAPES +def test_wide_resnet50_2(test_runner, dtype: torch.dtype, use_dynamic_shapes: bool): + model = torchvision.models.wide_resnet50_2() + _test_cv_model(model, test_runner, dtype, use_dynamic_shapes) diff --git a/backends/test/suite/operators/__init__.py b/backends/test/suite/operators/__init__.py index 6ceb9086f71..825aa316771 100644 --- a/backends/test/suite/operators/__init__.py +++ b/backends/test/suite/operators/__init__.py @@ -6,19 +6,14 @@ # pyre-unsafe -import copy import os +import sys import unittest from enum import Enum -from typing import Callable +import pytest import torch -from executorch.backends.test.suite import get_test_flows -from executorch.backends.test.suite.context import get_active_test_context, TestContext -from executorch.backends.test.suite.flow import TestFlow -from executorch.backends.test.suite.reporting import log_test_summary -from executorch.backends.test.suite.runner import run_test def load_tests(loader, suite, pattern): @@ -66,107 +61,48 @@ def dtype_test(func): return func -# Class annotation for operator tests. This triggers the test framework to register -# the tests. 
-def operator_test(cls): - _create_tests(cls) - return cls - - -# Generate test cases for each backend flow. -def _create_tests(cls): - for key in dir(cls): - if key.startswith("test_"): - _expand_test(cls, key) - +class OperatorTest(unittest.TestCase): + pass -# Expand a test into variants for each registered flow. -def _expand_test(cls, test_name: str): - test_func = getattr(cls, test_name) - for flow in get_test_flows().values(): - _create_test_for_backend(cls, test_func, flow) - delattr(cls, test_name) +class TestCaseShim: + def __init__(self, test_runner): + self._test_runner = test_runner -def _make_wrapped_test( - test_func: Callable, - test_name: str, - test_base_name: str, - flow: TestFlow, - params: dict | None = None, -): - def wrapped_test(self): - with TestContext(test_name, test_base_name, flow.name, params): - test_kwargs = copy.copy(params) or {} - test_kwargs["flow"] = flow + def _test_op(self, model, args, flow, generate_random_test_inputs=False): + self._test_runner.lower_and_run_model( + model, args, generate_random_test_inputs=generate_random_test_inputs + ) - test_func(self, **test_kwargs) - wrapped_test._name = test_name - wrapped_test._flow = flow +def wrap_test(original_func, test_type): + if test_type == TestType.STANDARD: - return wrapped_test + def wrapped_func(test_runner): + shim = TestCaseShim(test_runner) + original_func(shim, test_runner._flow) + return wrapped_func + elif test_type == TestType.DTYPE: -def _create_test_for_backend( - cls, - test_func: Callable, - flow: TestFlow, -): - test_type = getattr(test_func, "test_type", TestType.STANDARD) + @pytest.mark.parametrize("dtype", [torch.float32], ids=lambda s: str(s)[6:]) + def wrapped_func(test_runner, dtype): + shim = TestCaseShim(test_runner) + original_func(shim, test_runner._flow, dtype) - if test_type == TestType.STANDARD: - test_name = f"{test_func.__name__}_{flow.name}" - wrapped_test = _make_wrapped_test( - test_func, test_name, test_func.__name__, flow - ) - setattr(cls, test_name, wrapped_test) - elif test_type == TestType.DTYPE: - for dtype in DTYPES: - dtype_name = str(dtype)[6:] # strip "torch." - test_name = f"{test_func.__name__}_{dtype_name}_{flow.name}" - wrapped_test = _make_wrapped_test( - test_func, - test_name, - test_func.__name__, - flow, - {"dtype": dtype}, - ) - setattr(cls, test_name, wrapped_test) + return wrapped_func else: - raise NotImplementedError(f"Unknown test type {test_type}.") + raise ValueError() -class OperatorTest(unittest.TestCase): - def _test_op( - self, model, inputs, flow: TestFlow, generate_random_test_inputs: bool = True - ): - context = get_active_test_context() - - # This should be set in the wrapped test. See _make_wrapped_test above. - assert context is not None, "Missing test context." - - run_summary = run_test( - model, - inputs, - flow, - context.test_name, - context.test_base_name, - context.subtest_index, - context.params, - generate_random_test_inputs=generate_random_test_inputs, - ) - - log_test_summary(run_summary) +def operator_test(cls): + parent_module = sys.modules[cls.__module__] - # This is reset when a new test is started - it creates the context per-test. 
- context.subtest_index = context.subtest_index + 1 + for func_name in dir(cls): + if func_name.startswith("test"): + original_func = getattr(cls, func_name) + test_type = getattr(original_func, "test_type", TestType.STANDARD) + wrapped_func = wrap_test(original_func, test_type) + setattr(parent_module, func_name, wrapped_func) - if not run_summary.result.is_success(): - if run_summary.result.is_backend_failure(): - raise RuntimeError("Test failure.") from run_summary.error - else: - # Non-backend failure indicates a bad test. Mark as skipped. - raise unittest.SkipTest( - f"Test failed for reasons other than backend failure. Error: {run_summary.error}" - ) + return None diff --git a/backends/test/suite/operators/test_abs.py b/backends/test/suite/operators/test_abs.py index fdfc6be671e..484281e294e 100644 --- a/backends/test/suite/operators/test_abs.py +++ b/backends/test/suite/operators/test_abs.py @@ -7,6 +7,8 @@ # pyre-unsafe +import unittest + import torch from executorch.backends.test.suite.flow import TestFlow @@ -45,6 +47,7 @@ def test_abs_shapes(self, flow: TestFlow) -> None: # 3D tensor self._test_op(AbsModel(), (torch.randn(3, 4, 5),), flow) + @unittest.skip("NaN and Inf are not enforced for backends.") def test_abs_edge_cases(self, flow: TestFlow) -> None: # Test edge cases diff --git a/backends/test/suite/operators/test_add.py b/backends/test/suite/operators/test_add.py index 6b21c3bf985..850e6f5132c 100644 --- a/backends/test/suite/operators/test_add.py +++ b/backends/test/suite/operators/test_add.py @@ -7,14 +7,8 @@ # pyre-unsafe +import pytest import torch -from executorch.backends.test.suite.flow import TestFlow - -from executorch.backends.test.suite.operators import ( - dtype_test, - operator_test, - OperatorTest, -) class Model(torch.nn.Module): @@ -31,55 +25,52 @@ def forward(self, x, y): return torch.add(x, y, alpha=self.alpha) -@operator_test -class Add(OperatorTest): - @dtype_test - def test_add_dtype(self, flow: TestFlow, dtype) -> None: - self._test_op( - Model(), - ( - (torch.rand(2, 10) * 100).to(dtype), - (torch.rand(2, 10) * 100).to(dtype), - ), - flow, - ) - - def test_add_f32_bcast_first(self, flow: TestFlow) -> None: - self._test_op( - Model(), - ( - torch.randn(5), - torch.randn(1, 5, 1, 5), - ), - flow, - ) - - def test_add_f32_bcast_second(self, flow: TestFlow) -> None: - self._test_op( - Model(), - ( - torch.randn(4, 4, 2, 7), - torch.randn(2, 7), - ), - flow, - ) - - def test_add_f32_bcast_unary(self, flow: TestFlow) -> None: - self._test_op( - Model(), - ( - torch.randn(5), - torch.randn(1, 1, 5), - ), - flow, - ) - - def test_add_f32_alpha(self, flow: TestFlow) -> None: - self._test_op( - ModelAlpha(alpha=2), - ( - torch.randn(1, 25), - torch.randn(1, 25), - ), - flow, - ) +@pytest.mark.parametrize("dtype", [torch.float32], ids=lambda s: str(s)[6:]) +def test_add_dtype(test_runner, dtype) -> None: + test_runner.lower_and_run_model( + Model(), + ( + (torch.rand(2, 10) * 100).to(dtype), + (torch.rand(2, 10) * 100).to(dtype), + ), + ) + + +def test_add_f32_bcast_first(test_runner) -> None: + test_runner.lower_and_run_model( + Model(), + ( + torch.randn(5), + torch.randn(1, 5, 1, 5), + ), + ) + + +def test_add_f32_bcast_second(test_runner) -> None: + test_runner.lower_and_run_model( + Model(), + ( + torch.randn(4, 4, 2, 7), + torch.randn(2, 7), + ), + ) + + +def test_add_f32_bcast_unary(test_runner) -> None: + test_runner.lower_and_run_model( + Model(), + ( + torch.randn(5), + torch.randn(1, 1, 5), + ), + ) + + +def test_add_f32_alpha(test_runner) -> None: 
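The reworked operator_test decorator keeps the legacy class-based operator tests importable while exposing them to pytest as module-level functions, and TestCaseShim forwards the old self._test_op(...) calls to test_runner.lower_and_run_model(...). The snippet below is a simplified, self-contained sketch of that adapter idea; DummyRunner, _Shim, and the single flow are assumptions, and the real decorator additionally handles the DTYPE parametrization path.

import sys

import pytest
import torch


class DummyRunner:
    """Assumed stand-in for what the real test_runner fixture provides."""

    _flow = "stub_flow"

    def lower_and_run_model(self, model, inputs, **kwargs):
        return model(*inputs)


@pytest.fixture
def test_runner():
    return DummyRunner()


class _Shim:
    # Plays the role of TestCaseShim: adapts legacy self._test_op(...) calls.
    def __init__(self, runner):
        self._runner = runner

    def _test_op(self, model, inputs, flow, **kwargs):
        self._runner.lower_and_run_model(model, inputs, **kwargs)


def expose_as_pytest_functions(cls):
    # Re-register each test_* method of the legacy class as a module-level
    # function taking the test_runner fixture, so pytest can collect it.
    module = sys.modules[cls.__module__]
    for name in dir(cls):
        if name.startswith("test"):
            method = getattr(cls, name)

            def wrapped(test_runner, _method=method):
                _method(_Shim(test_runner), test_runner._flow)

            setattr(module, name, wrapped)


class LegacyAbsTests:
    def test_abs_smoke(self, flow):
        self._test_op(torch.nn.Identity(), (torch.randn(4),), flow)


expose_as_pytest_functions(LegacyAbsTests)
# The module now has a collectable test_abs_smoke(test_runner) function.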
+ test_runner.lower_and_run_model( + ModelAlpha(alpha=2), + ( + torch.randn(1, 25), + torch.randn(1, 25), + ), + ) diff --git a/backends/test/suite/operators/test_amax.py b/backends/test/suite/operators/test_amax.py index 0c9a8c06f0d..04e0b17ae0a 100644 --- a/backends/test/suite/operators/test_amax.py +++ b/backends/test/suite/operators/test_amax.py @@ -6,6 +6,7 @@ # pyre-unsafe +import unittest from typing import List, Optional, Tuple, Union import torch @@ -201,6 +202,7 @@ def test_amax_shapes(self, flow: TestFlow) -> None: flow, ) + @unittest.skip("NaN and Inf are not enforced for backends.") def test_amax_edge_cases(self, flow: TestFlow) -> None: x = torch.tensor([[1.0, float("inf"), 3.0], [4.0, 5.0, float("inf")]]) self._test_op( diff --git a/backends/test/suite/operators/test_amin.py b/backends/test/suite/operators/test_amin.py index f4b88b1dade..7aa5c6b7a34 100644 --- a/backends/test/suite/operators/test_amin.py +++ b/backends/test/suite/operators/test_amin.py @@ -6,6 +6,7 @@ # pyre-unsafe +import unittest from typing import List, Optional, Tuple, Union import torch @@ -203,6 +204,7 @@ def test_amin_shapes(self, flow: TestFlow) -> None: flow, ) + @unittest.skip("NaN and Inf are not enforced for backends.") def test_amin_edge_cases(self, flow: TestFlow) -> None: x = torch.tensor([[1.0, float("-inf"), 3.0], [4.0, 5.0, float("-inf")]]) self._test_op( diff --git a/backends/test/suite/operators/test_argmax.py b/backends/test/suite/operators/test_argmax.py index dc8b57fc214..ca3ae9e1805 100644 --- a/backends/test/suite/operators/test_argmax.py +++ b/backends/test/suite/operators/test_argmax.py @@ -6,6 +6,7 @@ # pyre-unsafe +import unittest from typing import Optional import torch @@ -143,6 +144,7 @@ def test_argmax_shapes(self, flow: TestFlow) -> None: flow, ) + @unittest.skip("NaN and Inf are not enforced for backends.") def test_argmax_edge_cases(self, flow: TestFlow) -> None: x = torch.tensor([[1.0, float("inf"), 3.0], [4.0, 5.0, float("inf")]]) self._test_op( diff --git a/backends/test/suite/operators/test_argmin.py b/backends/test/suite/operators/test_argmin.py index d7a24e24f5a..aaf4e9bd167 100644 --- a/backends/test/suite/operators/test_argmin.py +++ b/backends/test/suite/operators/test_argmin.py @@ -6,6 +6,7 @@ # pyre-unsafe +import unittest from typing import Optional import torch @@ -143,6 +144,7 @@ def test_argmin_shapes(self, flow: TestFlow) -> None: flow, ) + @unittest.skip("NaN and Inf are not enforced for backends.") def test_argmin_edge_cases(self, flow: TestFlow) -> None: x = torch.tensor([[1.0, float("-inf"), 3.0], [4.0, 5.0, float("-inf")]]) self._test_op( diff --git a/backends/test/suite/operators/test_ceil.py b/backends/test/suite/operators/test_ceil.py index 198c9e9fe16..4d7c0a5e888 100644 --- a/backends/test/suite/operators/test_ceil.py +++ b/backends/test/suite/operators/test_ceil.py @@ -7,6 +7,8 @@ # pyre-unsafe +import unittest + import torch from executorch.backends.test.suite.flow import TestFlow @@ -45,6 +47,7 @@ def test_ceil_shapes(self, flow: TestFlow) -> None: # 3D tensor self._test_op(CeilModel(), (torch.randn(3, 4, 5),), flow) + @unittest.skip("NaN and Inf are not enforced for backends.") def test_ceil_edge_cases(self, flow: TestFlow) -> None: # Test edge cases diff --git a/backends/test/suite/operators/test_clamp.py b/backends/test/suite/operators/test_clamp.py index 67c61c67caa..49419f0453a 100644 --- a/backends/test/suite/operators/test_clamp.py +++ b/backends/test/suite/operators/test_clamp.py @@ -7,6 +7,8 @@ # pyre-unsafe +import unittest + import 
torch from executorch.backends.test.suite.flow import TestFlow @@ -56,6 +58,7 @@ def test_clamp_shapes(self, flow: TestFlow) -> None: # 3D tensor self._test_op(model, (torch.randn(3, 4, 5),), flow) + @unittest.skip("NaN and Inf are not enforced for backends.") def test_clamp_edge_cases(self, flow: TestFlow) -> None: # Test edge cases diff --git a/backends/test/suite/operators/test_div.py b/backends/test/suite/operators/test_div.py index 656d350585d..d493c97a20d 100644 --- a/backends/test/suite/operators/test_div.py +++ b/backends/test/suite/operators/test_div.py @@ -46,6 +46,7 @@ def test_divide_dtype(self, flow: TestFlow, dtype) -> None: ), # Adding 0.1 to avoid division by zero ), flow, + generate_random_test_inputs=False, ) def test_divide_f32_bcast_first(self, flow: TestFlow) -> None: @@ -57,6 +58,7 @@ def test_divide_f32_bcast_first(self, flow: TestFlow) -> None: + 0.1, # Using abs and adding 0.1 to avoid division by zero ), flow, + generate_random_test_inputs=False, ) def test_divide_f32_bcast_second(self, flow: TestFlow) -> None: @@ -68,6 +70,7 @@ def test_divide_f32_bcast_second(self, flow: TestFlow) -> None: + 0.1, # Using abs and adding 0.1 to avoid division by zero ), flow, + generate_random_test_inputs=False, ) def test_divide_f32_bcast_unary(self, flow: TestFlow) -> None: @@ -79,6 +82,7 @@ def test_divide_f32_bcast_unary(self, flow: TestFlow) -> None: + 0.1, # Using abs and adding 0.1 to avoid division by zero ), flow, + generate_random_test_inputs=False, ) def test_divide_f32_trunc(self, flow: TestFlow) -> None: @@ -90,6 +94,7 @@ def test_divide_f32_trunc(self, flow: TestFlow) -> None: + 0.1, # Using abs and adding 0.1 to avoid division by zero ), flow, + generate_random_test_inputs=False, ) def test_divide_f32_floor(self, flow: TestFlow) -> None: @@ -101,4 +106,5 @@ def test_divide_f32_floor(self, flow: TestFlow) -> None: + 0.1, # Using abs and adding 0.1 to avoid division by zero ), flow, + generate_random_test_inputs=False, ) diff --git a/backends/test/suite/operators/test_elu.py b/backends/test/suite/operators/test_elu.py index f768a426954..361e1382c37 100644 --- a/backends/test/suite/operators/test_elu.py +++ b/backends/test/suite/operators/test_elu.py @@ -7,6 +7,8 @@ # pyre-unsafe +import unittest + import torch from executorch.backends.test.suite.flow import TestFlow @@ -42,5 +44,6 @@ def test_elu_f32_multi_dim(self, flow: TestFlow) -> None: def test_elu_f32_alpha(self, flow: TestFlow) -> None: self._test_op(Model(alpha=0.5), (torch.randn(3, 4, 5),), flow) + @unittest.skip("In place activations aren't properly defunctionalized yet.") def test_elu_f32_inplace(self, flow: TestFlow) -> None: self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), flow) diff --git a/backends/test/suite/operators/test_exp.py b/backends/test/suite/operators/test_exp.py index bdae5c6a5e6..54196d81ba9 100644 --- a/backends/test/suite/operators/test_exp.py +++ b/backends/test/suite/operators/test_exp.py @@ -7,6 +7,8 @@ # pyre-unsafe +import unittest + import torch from executorch.backends.test.suite.flow import TestFlow @@ -46,6 +48,7 @@ def test_exp_shapes(self, flow: TestFlow) -> None: # 3D tensor self._test_op(ExpModel(), (torch.randn(3, 4, 5),), flow) + @unittest.skip("NaN and Inf are not enforced for backends.") def test_exp_edge_cases(self, flow: TestFlow) -> None: # Test edge cases diff --git a/backends/test/suite/operators/test_floor.py b/backends/test/suite/operators/test_floor.py index fcc834afa16..bce9f0b4d34 100644 --- a/backends/test/suite/operators/test_floor.py +++ 
b/backends/test/suite/operators/test_floor.py @@ -7,6 +7,8 @@ # pyre-unsafe +import unittest + import torch from executorch.backends.test.suite.flow import TestFlow @@ -42,6 +44,7 @@ def test_floor_shapes(self, flow: TestFlow) -> None: # 3D tensor self._test_op(FloorModel(), (torch.randn(3, 4, 5),), flow) + @unittest.skip("NaN and Inf are not enforced for backends.") def test_floor_edge_cases(self, flow: TestFlow) -> None: # Test edge cases diff --git a/backends/test/suite/operators/test_floor_divide.py b/backends/test/suite/operators/test_floor_divide.py index 87104af11dc..c14151b6181 100644 --- a/backends/test/suite/operators/test_floor_divide.py +++ b/backends/test/suite/operators/test_floor_divide.py @@ -6,6 +6,8 @@ # pyre-unsafe +import unittest + import torch from executorch.backends.test.suite.flow import TestFlow @@ -178,6 +180,7 @@ def test_floor_divide_values(self, flow: TestFlow) -> None: y = torch.tensor([-2.0]).expand_as(x).clone() self._test_op(model, (x, y), flow, generate_random_test_inputs=False) + @unittest.skip("NaN and Inf are not enforced for backends.") def test_floor_divide_edge_cases(self, flow: TestFlow) -> None: # Test edge cases model = FloorDivideModel() diff --git a/backends/test/suite/operators/test_hardsigmoid.py b/backends/test/suite/operators/test_hardsigmoid.py index 238b18b1e0d..8ca254d4f61 100644 --- a/backends/test/suite/operators/test_hardsigmoid.py +++ b/backends/test/suite/operators/test_hardsigmoid.py @@ -7,6 +7,8 @@ # pyre-unsafe +import unittest + import torch from executorch.backends.test.suite.flow import TestFlow @@ -38,6 +40,7 @@ def test_hardsigmoid_f32_single_dim(self, flow: TestFlow) -> None: def test_hardsigmoid_f32_multi_dim(self, flow: TestFlow) -> None: self._test_op(Model(), (torch.randn(2, 3, 4, 5),), flow) + @unittest.skip("In place activations aren't properly defunctionalized yet.") def test_hardsigmoid_f32_inplace(self, flow: TestFlow) -> None: self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), flow) diff --git a/backends/test/suite/operators/test_hardswish.py b/backends/test/suite/operators/test_hardswish.py index 66902791c33..a93516542c8 100644 --- a/backends/test/suite/operators/test_hardswish.py +++ b/backends/test/suite/operators/test_hardswish.py @@ -7,6 +7,8 @@ # pyre-unsafe +import unittest + import torch from executorch.backends.test.suite.flow import TestFlow @@ -38,6 +40,7 @@ def test_hardswish_f32_single_dim(self, flow: TestFlow) -> None: def test_hardswish_f32_multi_dim(self, flow: TestFlow) -> None: self._test_op(Model(), (torch.randn(2, 3, 4, 5),), flow) + @unittest.skip("In place activations aren't properly defunctionalized yet.") def test_hardswish_f32_inplace(self, flow: TestFlow) -> None: self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), flow) diff --git a/backends/test/suite/operators/test_hardtanh.py b/backends/test/suite/operators/test_hardtanh.py index 2fcd1dbf563..7520c3faeae 100644 --- a/backends/test/suite/operators/test_hardtanh.py +++ b/backends/test/suite/operators/test_hardtanh.py @@ -7,6 +7,8 @@ # pyre-unsafe +import unittest + import torch from executorch.backends.test.suite.flow import TestFlow @@ -45,6 +47,7 @@ def test_hardtanh_f32_multi_dim(self, flow: TestFlow) -> None: def test_hardtanh_f32_custom_range(self, flow: TestFlow) -> None: self._test_op(Model(min_val=-2.0, max_val=2.0), (torch.randn(3, 4, 5),), flow) + @unittest.skip("In place activations aren't properly defunctionalized yet.") def test_hardtanh_f32_inplace(self, flow: TestFlow) -> None: 
self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), flow) diff --git a/backends/test/suite/operators/test_leaky_relu.py b/backends/test/suite/operators/test_leaky_relu.py index 983da47bba3..79ed5425623 100644 --- a/backends/test/suite/operators/test_leaky_relu.py +++ b/backends/test/suite/operators/test_leaky_relu.py @@ -7,6 +7,8 @@ # pyre-unsafe +import unittest + import torch from executorch.backends.test.suite.flow import TestFlow @@ -44,6 +46,7 @@ def test_leaky_relu_f32_multi_dim(self, flow: TestFlow) -> None: def test_leaky_relu_f32_custom_slope(self, flow: TestFlow) -> None: self._test_op(Model(negative_slope=0.1), (torch.randn(3, 4, 5),), flow) + @unittest.skip("In place activations aren't properly defunctionalized yet.") def test_leaky_relu_f32_inplace(self, flow: TestFlow) -> None: self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), flow) diff --git a/backends/test/suite/operators/test_log.py b/backends/test/suite/operators/test_log.py index 96ba8da1292..320f4fe463b 100644 --- a/backends/test/suite/operators/test_log.py +++ b/backends/test/suite/operators/test_log.py @@ -7,6 +7,8 @@ # pyre-unsafe +import unittest + import torch from executorch.backends.test.suite.flow import TestFlow @@ -32,20 +34,41 @@ def test_log_dtype(self, flow: TestFlow, dtype) -> None: # Test with different dtypes model = LogModel().to(dtype) # Use positive values only for log - self._test_op(model, (torch.rand(10, 10).to(dtype) + 0.01,), flow) + self._test_op( + model, + (torch.rand(10, 10).to(dtype) + 0.01,), + flow, + generate_random_test_inputs=False, + ) def test_log_shapes(self, flow: TestFlow) -> None: # Test with different tensor shapes # 1D tensor - self._test_op(LogModel(), (torch.rand(20) + 0.01,), flow) + self._test_op( + LogModel(), + (torch.rand(20) + 0.01,), + flow, + generate_random_test_inputs=False, + ) # 2D tensor - self._test_op(LogModel(), (torch.rand(5, 10) + 0.01,), flow) + self._test_op( + LogModel(), + (torch.rand(5, 10) + 0.01,), + flow, + generate_random_test_inputs=False, + ) # 3D tensor - self._test_op(LogModel(), (torch.rand(3, 4, 5) + 0.01,), flow) + self._test_op( + LogModel(), + (torch.rand(3, 4, 5) + 0.01,), + flow, + generate_random_test_inputs=False, + ) + @unittest.skip("NaN and Inf are not enforced for backends.") def test_log_edge_cases(self, flow: TestFlow) -> None: # Test edge cases # Tensor with infinity diff --git a/backends/test/suite/operators/test_log10.py b/backends/test/suite/operators/test_log10.py index 7d0e2e111d6..aeb97671f1b 100644 --- a/backends/test/suite/operators/test_log10.py +++ b/backends/test/suite/operators/test_log10.py @@ -7,6 +7,8 @@ # pyre-unsafe +import unittest + import torch from executorch.backends.test.suite.flow import TestFlow @@ -46,6 +48,7 @@ def test_log10_shapes(self, flow: TestFlow) -> None: # 3D tensor self._test_op(Log10Model(), (torch.rand(3, 4, 5) + 0.01,), flow) + @unittest.skip("NaN and Inf are not enforced for backends.") def test_log10_edge_cases(self, flow: TestFlow) -> None: # Test edge cases # Tensor with infinity diff --git a/backends/test/suite/operators/test_log1p.py b/backends/test/suite/operators/test_log1p.py index 383e3116b32..08a5c382076 100644 --- a/backends/test/suite/operators/test_log1p.py +++ b/backends/test/suite/operators/test_log1p.py @@ -7,6 +7,8 @@ # pyre-unsafe +import unittest + import torch from executorch.backends.test.suite.flow import TestFlow @@ -46,6 +48,7 @@ def test_log1p_shapes(self, flow: TestFlow) -> None: # 3D tensor self._test_op(Log1pModel(), (torch.rand(3, 4, 5) * 2 - 
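The log/sqrt/rsqrt/div tests above now pass generate_random_test_inputs=False. The motivation, as I read it, is that regenerated random inputs are not guaranteed to stay inside the ops' valid domain the way the hand-shifted tensors (for example torch.rand(...) + 0.01) are, so they can produce NaN or Inf, which the suite explicitly does not require backends to match. A tiny standalone illustration; the normal-distribution sampling here is only an assumption about how random inputs might be drawn.

import torch

torch.manual_seed(0)
x = torch.randn(1000)                                   # roughly half the values are negative
print(torch.log(x).isnan().any().item())                # True: log of a negative is NaN
print(torch.log(x.abs() + 0.01).isnan().any().item())   # False: shifted into the valid domain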
0.5,), flow) + @unittest.skip("NaN and Inf are not enforced for backends.") def test_log1p_edge_cases(self, flow: TestFlow) -> None: # Test edge cases # Tensor with infinity diff --git a/backends/test/suite/operators/test_log2.py b/backends/test/suite/operators/test_log2.py index ddcafaf08d2..16161d334f6 100644 --- a/backends/test/suite/operators/test_log2.py +++ b/backends/test/suite/operators/test_log2.py @@ -7,6 +7,8 @@ # pyre-unsafe +import unittest + import torch from executorch.backends.test.suite.flow import TestFlow @@ -46,6 +48,7 @@ def test_log2_shapes(self, flow: TestFlow) -> None: # 3D tensor self._test_op(Log2Model(), (torch.rand(3, 4, 5) + 0.01,), flow) + @unittest.skip("NaN and Inf are not enforced for backends.") def test_log2_edge_cases(self, flow: TestFlow) -> None: # Test edge cases # Tensor with infinity diff --git a/backends/test/suite/operators/test_lstm.py b/backends/test/suite/operators/test_lstm.py index 91dd73c9052..11632e1e055 100644 --- a/backends/test/suite/operators/test_lstm.py +++ b/backends/test/suite/operators/test_lstm.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -15,6 +16,11 @@ operator_test, OperatorTest, ) +from torch.nn.quantizable.modules.rnn import LSTM as QuantizableLSTM + + +def _get_lstm_cls(use_quantizable_lstm: bool): + return QuantizableLSTM if use_quantizable_lstm else torch.nn.LSTM class Model(torch.nn.Module): @@ -27,9 +33,11 @@ def __init__( batch_first=True, dropout=0.0, bidirectional=False, + use_quantizable_lstm: bool = False, ): super().__init__() - self.lstm = torch.nn.LSTM( + lstm_cls = _get_lstm_cls(use_quantizable_lstm) + self.lstm = lstm_cls( input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, @@ -47,106 +55,133 @@ def forward(self, x): class LSTM(OperatorTest): @dtype_test def test_lstm_dtype(self, flow: TestFlow, dtype) -> None: + use_quantizable_lstm = flow.quantize self._test_op( - Model(num_layers=2).to(dtype), + Model(num_layers=2, use_quantizable_lstm=use_quantizable_lstm).to(dtype), ((torch.rand(1, 10, 64) * 10).to(dtype),), # (batch=1, seq_len, input_size) flow, ) @dtype_test def test_lstm_no_bias_dtype(self, flow: TestFlow, dtype) -> None: + use_quantizable_lstm = flow.quantize self._test_op( - Model(num_layers=2, bias=False).to(dtype), + Model( + num_layers=2, bias=False, use_quantizable_lstm=use_quantizable_lstm + ).to(dtype), ((torch.rand(1, 10, 64) * 10).to(dtype),), flow, ) def test_lstm_feature_sizes(self, flow: TestFlow) -> None: + use_quantizable_lstm = flow.quantize self._test_op( - Model(input_size=32, hidden_size=16), + Model( + input_size=32, + hidden_size=16, + use_quantizable_lstm=use_quantizable_lstm, + ), (torch.randn(1, 8, 32),), # (batch=1, seq_len, input_size) flow, ) self._test_op( - Model(input_size=128, hidden_size=64), + Model( + input_size=128, + hidden_size=64, + use_quantizable_lstm=use_quantizable_lstm, + ), (torch.randn(1, 12, 128),), flow, ) self._test_op( - Model(input_size=256, hidden_size=128), + Model( + input_size=256, + hidden_size=128, + use_quantizable_lstm=use_quantizable_lstm, + ), (torch.randn(1, 6, 256),), flow, ) self._test_op( - Model(input_size=16, hidden_size=32), + Model( + input_size=16, + hidden_size=32, + use_quantizable_lstm=use_quantizable_lstm, + ), (torch.randn(1, 5, 16),), flow, ) def test_lstm_batch_sizes(self, 
flow: TestFlow) -> None: + use_quantizable_lstm = flow.quantize self._test_op( - Model(), + Model(use_quantizable_lstm=use_quantizable_lstm), (torch.randn(8, 10, 64),), flow, ) self._test_op( - Model(), + Model(use_quantizable_lstm=use_quantizable_lstm), (torch.randn(32, 10, 64),), flow, ) self._test_op( - Model(), + Model(use_quantizable_lstm=use_quantizable_lstm), (torch.randn(100, 10, 64),), flow, ) def test_lstm_seq_lengths(self, flow: TestFlow) -> None: + use_quantizable_lstm = flow.quantize self._test_op( - Model(), + Model(use_quantizable_lstm=use_quantizable_lstm), (torch.randn(1, 5, 64),), flow, ) self._test_op( - Model(), + Model(use_quantizable_lstm=use_quantizable_lstm), (torch.randn(1, 20, 64),), flow, ) self._test_op( - Model(), + Model(use_quantizable_lstm=use_quantizable_lstm), (torch.randn(1, 50, 64),), flow, ) def test_lstm_batch_first_false(self, flow: TestFlow) -> None: + use_quantizable_lstm = flow.quantize self._test_op( - Model(batch_first=False), + Model(batch_first=False, use_quantizable_lstm=use_quantizable_lstm), (torch.randn(10, 1, 64),), # (seq_len, batch=1, input_size) flow, ) def test_lstm_num_layers(self, flow: TestFlow) -> None: + use_quantizable_lstm = flow.quantize self._test_op( - Model(num_layers=2), + Model(num_layers=2, use_quantizable_lstm=use_quantizable_lstm), (torch.randn(1, 10, 64),), flow, ) self._test_op( - Model(num_layers=3), + Model(num_layers=3, use_quantizable_lstm=use_quantizable_lstm), (torch.randn(1, 10, 64),), flow, ) def test_lstm_bidirectional(self, flow: TestFlow) -> None: + use_quantizable_lstm = flow.quantize self._test_op( - Model(bidirectional=True), + Model(bidirectional=True, use_quantizable_lstm=use_quantizable_lstm), (torch.randn(1, 10, 64),), flow, ) def test_lstm_with_dropout(self, flow: TestFlow) -> None: # Note: Dropout is only effective with num_layers > 1 + use_quantizable_lstm = flow.quantize self._test_op( - Model(num_layers=2, dropout=0.2), + Model(num_layers=2, dropout=0.2, use_quantizable_lstm=use_quantizable_lstm), (torch.randn(1, 10, 64),), flow, ) @@ -154,9 +189,10 @@ def test_lstm_with_dropout(self, flow: TestFlow) -> None: def test_lstm_with_initial_states(self, flow: TestFlow) -> None: # Create a model that accepts initial states class ModelWithStates(torch.nn.Module): - def __init__(self): + def __init__(self, use_quantizable_lstm: bool = False): super().__init__() - self.lstm = torch.nn.LSTM( + lstm_cls = _get_lstm_cls(use_quantizable_lstm) + self.lstm = lstm_cls( input_size=64, hidden_size=32, num_layers=2, @@ -169,9 +205,10 @@ def forward(self, x, h0, c0): batch_size = 1 num_layers = 2 hidden_size = 32 + use_quantizable_lstm = flow.quantize self._test_op( - ModelWithStates(), + ModelWithStates(use_quantizable_lstm=use_quantizable_lstm), ( torch.randn(batch_size, 10, 64), # input torch.randn(num_layers, batch_size, hidden_size), # h0 @@ -183,9 +220,10 @@ def forward(self, x, h0, c0): def test_lstm_return_hidden_states(self, flow: TestFlow) -> None: # Create a model that returns both output and hidden states class ModelWithHiddenStates(torch.nn.Module): - def __init__(self): + def __init__(self, use_quantizable_lstm: bool = False): super().__init__() - self.lstm = torch.nn.LSTM( + lstm_cls = _get_lstm_cls(use_quantizable_lstm) + self.lstm = lstm_cls( input_size=64, hidden_size=32, num_layers=2, @@ -200,9 +238,10 @@ def forward(self, x): batch_size = 1 seq_len = 10 input_size = 64 + use_quantizable_lstm = flow.quantize self._test_op( - ModelWithHiddenStates(), + 
ModelWithHiddenStates(use_quantizable_lstm=use_quantizable_lstm), (torch.randn(batch_size, seq_len, input_size),), flow, ) diff --git a/backends/test/suite/operators/test_mean.py b/backends/test/suite/operators/test_mean.py index 746a4b16d9f..6c5c779364b 100644 --- a/backends/test/suite/operators/test_mean.py +++ b/backends/test/suite/operators/test_mean.py @@ -6,6 +6,7 @@ # pyre-unsafe +import unittest from typing import List, Optional, Tuple, Union import torch @@ -229,6 +230,7 @@ def test_mean_shapes(self, flow: TestFlow) -> None: flow, ) + @unittest.skip("NaN and Inf are not enforced for backends.") def test_mean_edge_cases(self, flow: TestFlow) -> None: x = torch.tensor([[1.0, float("inf"), 3.0], [4.0, 5.0, float("inf")]]) self._test_op( diff --git a/backends/test/suite/operators/test_median.py b/backends/test/suite/operators/test_median.py index 93823b812ca..0b515d68efd 100644 --- a/backends/test/suite/operators/test_median.py +++ b/backends/test/suite/operators/test_median.py @@ -6,6 +6,7 @@ # pyre-unsafe +import unittest from typing import Optional import torch @@ -167,6 +168,7 @@ def test_median_shapes(self, flow: TestFlow) -> None: # 5D tensor self._test_op(MedianValueOnlyModel(), (torch.randn(2, 2, 3, 4, 5),), flow) + @unittest.skip("NaN and Inf are not enforced for backends.") def test_median_edge_cases(self, flow: TestFlow) -> None: # Tensor with NaN (NaN should be propagated) x = torch.tensor([[1.0, float("nan"), 3.0], [4.0, 5.0, float("nan")]]) diff --git a/backends/test/suite/operators/test_neg.py b/backends/test/suite/operators/test_neg.py index 35c9d851817..bc1adede877 100644 --- a/backends/test/suite/operators/test_neg.py +++ b/backends/test/suite/operators/test_neg.py @@ -6,6 +6,8 @@ # pyre-unsafe +import unittest + import torch from executorch.backends.test.suite.flow import TestFlow @@ -55,6 +57,7 @@ def test_neg_shapes(self, flow: TestFlow) -> None: NegModel(), (torch.randn(3, 4, 5),), flow, generate_random_test_inputs=False ) + @unittest.skip("NaN and Inf are not enforced for backends.") def test_neg_edge_cases(self, flow: TestFlow) -> None: # Test edge cases diff --git a/backends/test/suite/operators/test_pow.py b/backends/test/suite/operators/test_pow.py index 334038d73d3..3082ad6ebaf 100644 --- a/backends/test/suite/operators/test_pow.py +++ b/backends/test/suite/operators/test_pow.py @@ -6,6 +6,8 @@ # pyre-unsafe +import unittest + import torch from executorch.backends.test.suite.flow import TestFlow @@ -127,6 +129,7 @@ def test_pow_shapes(self, flow: TestFlow) -> None: model, (torch.rand(3, 4, 5) + 0.1,), flow, generate_random_test_inputs=False ) + @unittest.skip("NaN and Inf are not enforced for backends.") def test_pow_edge_cases(self, flow: TestFlow) -> None: # Test edge cases diff --git a/backends/test/suite/operators/test_relu.py b/backends/test/suite/operators/test_relu.py index c9f416f090f..3c4ef2a98d0 100644 --- a/backends/test/suite/operators/test_relu.py +++ b/backends/test/suite/operators/test_relu.py @@ -7,6 +7,8 @@ # pyre-unsafe +import unittest + import torch from executorch.backends.test.suite.flow import TestFlow @@ -38,5 +40,6 @@ def test_relu_f32_single_dim(self, flow: TestFlow) -> None: def test_relu_f32_multi_dim(self, flow: TestFlow) -> None: self._test_op(Model(), (torch.randn(2, 3, 4, 5),), flow) + @unittest.skip("In place activations aren't properly defunctionalized yet.") def test_relu_f32_inplace(self, flow: TestFlow) -> None: self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), flow) diff --git 
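test_lstm.py now switches the module under test to the quantizable LSTM whenever the flow quantizes (flow.quantize), so observers can be placed inside the recurrent cell, and keeps eager torch.nn.LSTM otherwise. A minimal sketch of the toggle, independent of the suite (only the eager branch is exercised at the end):

import torch
from torch.nn.quantizable.modules.rnn import LSTM as QuantizableLSTM


def make_lstm(use_quantizable_lstm: bool) -> torch.nn.Module:
    lstm_cls = QuantizableLSTM if use_quantizable_lstm else torch.nn.LSTM
    return lstm_cls(input_size=64, hidden_size=32, num_layers=2, batch_first=True)


x = torch.randn(1, 10, 64)                      # (batch, seq_len, input_size)
out, (h, c) = make_lstm(use_quantizable_lstm=False)(x)
print(out.shape)                                # torch.Size([1, 10, 32])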
a/backends/test/suite/operators/test_round.py b/backends/test/suite/operators/test_round.py index ca8e6368d48..3a3577bea32 100644 --- a/backends/test/suite/operators/test_round.py +++ b/backends/test/suite/operators/test_round.py @@ -6,6 +6,8 @@ # pyre-unsafe +import unittest + import torch from executorch.backends.test.suite.flow import TestFlow @@ -52,6 +54,7 @@ def test_round_values(self, flow: TestFlow) -> None: x = torch.arange(-5, 5, 0.5) # [-5.0, -4.5, -4.0, ..., 4.0, 4.5] self._test_op(RoundModel(), (x,), flow, generate_random_test_inputs=False) + @unittest.skip("NaN and Inf are not enforced for backends.") def test_round_edge_cases(self, flow: TestFlow) -> None: # Test edge cases @@ -98,6 +101,7 @@ def test_round_decimals(self, flow: TestFlow) -> None: RoundModel(decimals=-2), (x,), flow, generate_random_test_inputs=False ) + @unittest.skip("NaN and Inf are not enforced for backends.") def test_round_decimals_edge_cases(self, flow: TestFlow) -> None: # Test edge cases with decimal places diff --git a/backends/test/suite/operators/test_rsqrt.py b/backends/test/suite/operators/test_rsqrt.py index 175bbcdb2cc..0b7c9739cf7 100644 --- a/backends/test/suite/operators/test_rsqrt.py +++ b/backends/test/suite/operators/test_rsqrt.py @@ -6,6 +6,8 @@ # pyre-unsafe +import unittest + import torch from executorch.backends.test.suite.flow import TestFlow @@ -31,20 +33,39 @@ def test_rsqrt_dtype(self, flow: TestFlow, dtype) -> None: # Test with different dtypes model = RsqrtModel().to(dtype) # Use positive values only for rsqrt to avoid division by zero - self._test_op(model, (torch.rand(10, 10).to(dtype) + 0.01,), flow) + self._test_op( + model, + (torch.rand(10, 10).to(dtype) + 0.01,), + flow, + generate_random_test_inputs=False, + ) def test_rsqrt_shapes(self, flow: TestFlow) -> None: # Test with different tensor shapes - # 1D tensor - self._test_op(RsqrtModel(), (torch.rand(20) + 0.01,), flow) - + self._test_op( + RsqrtModel(), + (torch.rand(20) + 0.01,), + flow, + generate_random_test_inputs=False, + ) # 2D tensor - self._test_op(RsqrtModel(), (torch.rand(5, 10) + 0.01,), flow) + self._test_op( + RsqrtModel(), + (torch.rand(5, 10) + 0.01,), + flow, + generate_random_test_inputs=False, + ) # 3D tensor - self._test_op(RsqrtModel(), (torch.rand(3, 4, 5) + 0.01,), flow) + self._test_op( + RsqrtModel(), + (torch.rand(3, 4, 5) + 0.01,), + flow, + generate_random_test_inputs=False, + ) + @unittest.skip("NaN and Inf are not enforced for backends.") def test_rsqrt_edge_cases(self, flow: TestFlow) -> None: # Tensor with infinity x = torch.tensor([float("inf"), 1.0, 4.0]) diff --git a/backends/test/suite/operators/test_silu.py b/backends/test/suite/operators/test_silu.py index 69b6576734f..cf6d343f271 100644 --- a/backends/test/suite/operators/test_silu.py +++ b/backends/test/suite/operators/test_silu.py @@ -7,6 +7,8 @@ # pyre-unsafe +import unittest + import torch from executorch.backends.test.suite.flow import TestFlow @@ -38,6 +40,7 @@ def test_silu_f32_single_dim(self, flow: TestFlow) -> None: def test_silu_f32_multi_dim(self, flow: TestFlow) -> None: self._test_op(Model(), (torch.randn(2, 3, 4, 5),), flow) + @unittest.skip("In place activations aren't properly defunctionalized yet.") def test_silu_f32_inplace(self, flow: TestFlow) -> None: self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), flow) diff --git a/backends/test/suite/operators/test_sqrt.py b/backends/test/suite/operators/test_sqrt.py index c3874dcb209..4a3f931204d 100644 --- a/backends/test/suite/operators/test_sqrt.py +++ 
b/backends/test/suite/operators/test_sqrt.py @@ -6,6 +6,8 @@ # pyre-unsafe +import unittest + import torch from executorch.backends.test.suite.flow import TestFlow @@ -31,20 +33,32 @@ def test_sqrt_dtype(self, flow: TestFlow, dtype) -> None: # Test with different dtypes model = SqrtModel().to(dtype) # Use non-negative values only for sqrt - self._test_op(model, (torch.rand(10, 10).to(dtype),), flow) + self._test_op( + model, + (torch.rand(10, 10).to(dtype),), + flow, + generate_random_test_inputs=False, + ) def test_sqrt_shapes(self, flow: TestFlow) -> None: # Test with different tensor shapes # 1D tensor - self._test_op(SqrtModel(), (torch.rand(20),), flow) + self._test_op( + SqrtModel(), (torch.rand(20),), flow, generate_random_test_inputs=False + ) # 2D tensor - self._test_op(SqrtModel(), (torch.rand(5, 10),), flow) + self._test_op( + SqrtModel(), (torch.rand(5, 10),), flow, generate_random_test_inputs=False + ) # 3D tensor - self._test_op(SqrtModel(), (torch.rand(3, 4, 5),), flow) + self._test_op( + SqrtModel(), (torch.rand(3, 4, 5),), flow, generate_random_test_inputs=False + ) + @unittest.skip("NaN and Inf are not enforced for backends.") def test_sqrt_edge_cases(self, flow: TestFlow) -> None: # Test edge cases diff --git a/backends/test/suite/operators/test_square.py b/backends/test/suite/operators/test_square.py index 52cd739bf9f..39ed212e426 100644 --- a/backends/test/suite/operators/test_square.py +++ b/backends/test/suite/operators/test_square.py @@ -6,6 +6,8 @@ # pyre-unsafe +import unittest + import torch from executorch.backends.test.suite.flow import TestFlow @@ -44,6 +46,7 @@ def test_square_shapes(self, flow: TestFlow) -> None: # 3D tensor self._test_op(SquareModel(), (torch.randn(3, 4, 5),), flow) + @unittest.skip("NaN and Inf are not enforced for backends.") def test_square_edge_cases(self, flow: TestFlow) -> None: # Test edge cases diff --git a/backends/test/suite/operators/test_sub.py b/backends/test/suite/operators/test_sub.py index be7b871fdad..2243eb6ee71 100644 --- a/backends/test/suite/operators/test_sub.py +++ b/backends/test/suite/operators/test_sub.py @@ -6,7 +6,6 @@ # pyre-unsafe - import torch from executorch.backends.test.suite.flow import TestFlow diff --git a/backends/test/suite/operators/test_threshold.py b/backends/test/suite/operators/test_threshold.py index 42b6fb801e5..3f69a9f41fe 100644 --- a/backends/test/suite/operators/test_threshold.py +++ b/backends/test/suite/operators/test_threshold.py @@ -7,6 +7,8 @@ # pyre-unsafe +import unittest + import torch from executorch.backends.test.suite.flow import TestFlow @@ -51,6 +53,7 @@ def test_threshold_f32_custom_value(self, flow: TestFlow) -> None: def test_threshold_f32_custom_threshold_value(self, flow: TestFlow) -> None: self._test_op(Model(threshold=0.5, value=1.0), (torch.randn(3, 4, 5),), flow) + @unittest.skip("In place activations aren't properly defunctionalized yet.") def test_threshold_f32_inplace(self, flow: TestFlow) -> None: self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), flow) diff --git a/backends/test/suite/operators/test_trunc.py b/backends/test/suite/operators/test_trunc.py index 1d6d18817bd..71dcbf59176 100644 --- a/backends/test/suite/operators/test_trunc.py +++ b/backends/test/suite/operators/test_trunc.py @@ -6,6 +6,8 @@ # pyre-unsafe +import unittest + import torch from executorch.backends.test.suite.flow import TestFlow @@ -44,6 +46,7 @@ def test_trunc_shapes(self, flow: TestFlow) -> None: # 3D tensor self._test_op(TruncModel(), (torch.randn(3, 4, 5) * 5,), flow) + 
@unittest.skip("NaN and Inf are not enforced for backends.") def test_trunc_edge_cases(self, flow: TestFlow) -> None: # Test edge cases diff --git a/backends/test/suite/reporting.py b/backends/test/suite/reporting.py index ce8a48dcc12..09e950ab672 100644 --- a/backends/test/suite/reporting.py +++ b/backends/test/suite/reporting.py @@ -1,4 +1,5 @@ import csv +import json from collections import Counter from dataclasses import dataclass, field @@ -45,6 +46,8 @@ ] ) +CSV_FIELD_NAMES.append("Error") + # Operators that are excluded from the counts returned by count_ops. These are used to # exclude operatations that are not logically relevant or delegatable to backends. @@ -341,7 +344,9 @@ def _sum_op_counts(counter: Counter | None) -> int | None: def _serialize_params(params: dict[str, Any] | None) -> str: if params is not None: - return str(dict(sorted(params.items()))) + # Convert values to strings - JSON conversion doesn't like dtypes. + str_params = {k: str(v) for k, v in params.items()} + return json.dumps(str_params) else: return "" @@ -365,6 +370,15 @@ def write_csv_header(output: TextIO): def write_csv_row(record: TestCaseSummary, output: TextIO): writer = csv.DictWriter(output, CSV_FIELD_NAMES) + # Truncate error message if it's too long, keeping first and last 200 characters + error_message = "" + if record.error is not None: + error_str = str(record.error) + if len(error_str) > 400: + error_message = error_str[:200] + "..." + error_str[-200:] + else: + error_message = error_str + row = { "Test ID": record.name, "Test Case": record.base_name, @@ -373,6 +387,7 @@ def write_csv_row(record: TestCaseSummary, output: TextIO): "Params": _serialize_params(record.params), "Result": record.result.to_short_str(), "Result Detail": record.result.to_detail_str(), + "Error": error_message, "Delegated": "True" if record.is_delegated() else "False", "Quantize Time (s)": ( f"{record.quantize_time.total_seconds():.3f}" diff --git a/backends/test/suite/runner.py b/backends/test/suite/runner.py index 1f84db9c730..a6d7d07bce0 100644 --- a/backends/test/suite/runner.py +++ b/backends/test/suite/runner.py @@ -15,6 +15,7 @@ UNSUPPORTED_PORTABLE_OPS = { "aten::_embedding_bag", "aten::_adaptive_avg_pool2d", + "aten::adaptive_max_pool2d", "aten::median", "aten::median.dim", "aten::round.decimals", @@ -34,6 +35,7 @@ TestResult, ) from executorch.exir import EdgeProgramManager +from executorch.exir.dialects._ops import ops as exir_ops # A list of all runnable test suites and the corresponding python package. @@ -43,6 +45,24 @@ } +def _graph_has_unsupported_patterns(program: torch.export.ExportedProgram) -> bool: + # Returns true if the model contains patterns that will fail when running on the ET + # portable kernel library. + + # Check for 3d convolutions. All convs (1d, 2d, 3d) use the same op, so we need to look at + # the input meta to determine the rank. + for node in program.graph.nodes: + if ( + node.op == "call_function" + and node.target == exir_ops.edge.aten.convolution.default + ): + in_rank = node.args[0].meta["val"].dim() + if in_rank > 4: + return True + + return False + + def _get_test_seed(test_base_name: str) -> int: # Set the seed based on the test base name to give consistent inputs between backends. Add the # run seed to allow for reproducible results, but still allow for run-to-run variation. @@ -162,7 +182,7 @@ def build_result( # Check if any undelegated ops are in the unsupported ops set. 
has_unsupported_ops = any( op in UNSUPPORTED_PORTABLE_OPS for op in undelegated_op_counts.keys() - ) + ) or _graph_has_unsupported_patterns(edge_manager._etrecord.edge_dialect_program) # Skip the test if there are unsupported portable ops remaining. if has_unsupported_ops: @@ -171,8 +191,11 @@ def build_result( # Only run the runtime portion if something was delegated (or the flow doesn't delegate) if is_delegated or not flow.is_delegated: try: - tester.to_executorch().serialize() - extra_stats["pte_size_bytes"] = len(tester.get_artifact()) + tester.to_executorch() + + if flow.supports_serialize: + tester.serialize() + extra_stats["pte_size_bytes"] = len(tester.get_artifact()) except Exception as e: # We could introduce a result value for this, but I'm not sure it's necessary. # We can do this if we ever see to_executorch() or serialize() fail due a backend issue. diff --git a/backends/test/suite/tests/test_reporting.py b/backends/test/suite/tests/test_reporting.py index 58ff76cba17..e42681fc678 100644 --- a/backends/test/suite/tests/test_reporting.py +++ b/backends/test/suite/tests/test_reporting.py @@ -1,3 +1,4 @@ +import json import unittest from csv import DictReader @@ -102,14 +103,16 @@ def test_csv_report_simple(self): self.assertEqual(records[2]["Test Case"], "test2") self.assertEqual(records[2]["Flow"], "flow1") self.assertEqual(records[2]["Result"], "Pass") - self.assertEqual(records[2]["Params"], str({"dtype": torch.float32})) + self.assertEqual(records[2]["Params"], json.dumps({"dtype": "torch.float32"})) # Validate fourth record: test2, backend2, EXPORT_FAIL with use_dynamic_shapes param self.assertEqual(records[3]["Test ID"], "test2_backend2_flow1") self.assertEqual(records[3]["Test Case"], "test2") self.assertEqual(records[3]["Flow"], "flow1") self.assertEqual(records[3]["Result"], "Skip") - self.assertEqual(records[3]["Params"], str({"use_dynamic_shapes": True})) + self.assertEqual( + records[3]["Params"], json.dumps({"use_dynamic_shapes": "True"}) + ) def test_count_ops(self): """ diff --git a/backends/transforms/decompose_sdpa.py b/backends/transforms/decompose_sdpa.py index d49e0da0c9b..6c36d1803fc 100644 --- a/backends/transforms/decompose_sdpa.py +++ b/backends/transforms/decompose_sdpa.py @@ -7,6 +7,7 @@ # pyre-strict import math +from typing import Set, Type import torch from executorch.exir.pass_base import ExportPass, PassResult @@ -19,6 +20,8 @@ class DecomposeScaledDotProductAttention(ExportPass): Decompose from scaled_dot_product_attention to multiple nodes. 
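The new _graph_has_unsupported_patterns check in runner.py skips tests whose graphs contain 3d convolutions: all conv ranks lower to the same convolution op, so the pass inspects the input tensor's rank recorded in node.meta["val"] and treats anything above rank 4 as unsupported by the portable kernels. Below is a self-contained sketch of the same rank check run on a plain torch.export program; matching any conv-flavored target by name, rather than the edge-dialect convolution op used in runner.py, is a simplification.

import torch


class Conv3dModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv3d(2, 4, kernel_size=3)

    def forward(self, x):
        return self.conv(x)


def has_high_rank_conv(program: torch.export.ExportedProgram) -> bool:
    for node in program.graph.nodes:
        if node.op == "call_function" and "conv" in str(node.target):
            # The first argument of the conv op is the input tensor; its fake
            # value in node.meta carries the rank.
            in_rank = node.args[0].meta["val"].dim()
            if in_rank > 4:
                return True
    return False


ep = torch.export.export(Conv3dModel(), (torch.randn(1, 2, 8, 8, 8),))
print(has_high_rank_conv(ep))  # True: the conv input is rank 5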
""" + _passes_required_after: Set[Type[ExportPass]] = set() + def __init__(self, allow_non_fake_inputs: bool = True) -> None: super().__init__() # With allow_non_fake_inputs=False, we don't get _unsafe_view ops diff --git a/backends/transforms/fuse_view_copy.py b/backends/transforms/fuse_view_copy.py index c740515cdcc..b7c52f95fa3 100644 --- a/backends/transforms/fuse_view_copy.py +++ b/backends/transforms/fuse_view_copy.py @@ -7,14 +7,48 @@ # pyre-strict +from typing import Set, Type + import torch from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult +UNARY_ELEMENTWISE_OPS = [ + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.alias_copy.default, + exir_ops.edge.aten.clone.default, + exir_ops.edge.dim_order_ops._clone_dim_order.default, + exir_ops.edge.aten._to_copy.default, + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.aten.abs.default, + exir_ops.edge.aten.clamp.default, + exir_ops.edge.aten.ceil.default, + exir_ops.edge.aten.floor.default, + exir_ops.edge.aten.neg.default, + exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.round.default, + exir_ops.edge.aten.sigmoid.default, + exir_ops.edge.aten.silu.default, + exir_ops.edge.aten.sqrt.default, + exir_ops.edge.aten.tanh.default, + exir_ops.edge.aten.sign.default, + exir_ops.edge.aten.reciprocal.default, + exir_ops.edge.aten.gelu.default, + exir_ops.edge.aten.rsqrt.default, + exir_ops.edge.aten.exp.default, + exir_ops.edge.aten.log.default, +] + + def merge_view_copy_chains(graph: torch.fx.Graph) -> tuple[torch.fx.Graph, bool]: """ - Find chains of view_copy nodes and merge them into one view_copy node. + Find chains of view_copy nodes and unary elementwise ops and set all + view_copy nodes to have the final shape. The views will then be removed + by the remove_noop_view_copy call. + Only merges view_copy nodes that are not used by any other nodes. 
""" ops = exir_ops.edge @@ -22,21 +56,25 @@ def merge_view_copy_chains(graph: torch.fx.Graph) -> tuple[torch.fx.Graph, bool] modified = False for node in graph.nodes: if node.op == "call_function" and node.target == view_op: - # find ending view_copy node in chain + # Find a chain of unary elementwise ops and save all view_copy nodes end_node = node + view_ops = [node] while ( end_node.op == "call_function" - and end_node.target == view_op + and end_node.target in UNARY_ELEMENTWISE_OPS and len(end_node.users) == 1 - and list(end_node.users)[0].target == view_op + and list(end_node.users)[0].target in UNARY_ELEMENTWISE_OPS ): end_node = list(end_node.users)[0] - # we can swap the first node's shape arg with the last node's shape arg - if node != end_node: - with graph.inserting_after(node): - new_args = (node.args[0], end_node.args[1]) + if end_node.target == view_op: + view_ops.append(end_node) + + # Set all view_copy nodes to have the final shape + if len(view_ops) > 1: + final_shape = view_ops[-1].args[1] + for node in view_ops: + new_args = (node.args[0], final_shape) node.args = new_args - end_node.replace_all_uses_with(node) modified = True graph.eliminate_dead_code() @@ -62,11 +100,17 @@ def remove_noop_view_copy(graph: torch.fx.Graph) -> tuple[torch.fx.Graph, bool]: class FuseViewCopyTransform(ExportPass): + _passes_required_after: Set[Type[ExportPass]] = set() + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - graph_module.graph, merge_modified = merge_view_copy_chains(graph_module.graph) - graph_module.graph, noop_modified = remove_noop_view_copy(graph_module.graph) - modified = merge_modified or noop_modified + graph_module.graph, modified = merge_view_copy_chains(graph_module.graph) if modified: graph_module.recompile() graph_module = super().call(graph_module).graph_module + + graph_module.graph, modified = remove_noop_view_copy(graph_module.graph) + if modified: + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, modified) diff --git a/backends/transforms/remove_clone_ops.py b/backends/transforms/remove_clone_ops.py index 01fe2ee26a4..07cc3e9efb1 100644 --- a/backends/transforms/remove_clone_ops.py +++ b/backends/transforms/remove_clone_ops.py @@ -25,11 +25,18 @@ class RemoveCloneOpsTransform(ExportPass): exir_ops.edge.dim_order_ops._clone_dim_order.default, } - def __init__(self) -> None: + def __init__( + self, + preserve_input_output_copies: bool = False, + eliminate_quant_dequant_pairs: bool = True, + ) -> None: super().__init__() + self._preserve_input_output_copies = preserve_input_output_copies + self._eliminate_quant_dequant_pairs = eliminate_quant_dequant_pairs - def _remove(self, graph_module: torch.fx.GraphModule) -> None: + def _remove(self, graph_module: torch.fx.GraphModule) -> bool: dequant_nodes = [] + modified = False for n in graph_module.graph.nodes: if n.target not in self.clone_ops: @@ -38,6 +45,12 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None: if self._is_non_identity_clone(n): continue + # If preserve_input_output_copies is set, don't remove clones that directly + # copy from input to output. 
+ if self._is_input_output_copy(n) and self._preserve_input_output_copies: + continue + + modified = True to_be_removed = n for user_n in list(n.users.keys()): user_n.replace_input_with(n, n.args[0]) @@ -45,13 +58,18 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None: dequant_nodes += [n.args[0]] graph_module.graph.erase_node(to_be_removed) - eliminate_dq_q(graph_module, dequant_nodes) + if self._eliminate_quant_dequant_pairs: + eliminate_dq_q(graph_module, dequant_nodes) + + return modified def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - self._remove(graph_module) - graph_module.recompile() - dead_code_elimination_pass(graph_module) - return PassResult(graph_module, True) + if self._remove(graph_module): + graph_module.recompile() + dead_code_elimination_pass(graph_module) + return PassResult(graph_module, True) + else: + return PassResult(graph_module, False) def _is_non_identity_clone(self, node: torch.fx.Node) -> bool: """Return True if clone has modified memory layout or dim order.""" @@ -76,3 +94,16 @@ def _is_non_identity_clone(self, node: torch.fx.Node) -> bool: ) return False + + def _is_input_output_copy(self, node: torch.fx.Node) -> bool: + """Return True if the node input is a graph input and output goes into an output node.""" + + input_node = node.args[0] + if input_node.op != "placeholder": + return False + + for users in node.users: + if users.op == "output": + return True + + return False diff --git a/backends/transforms/remove_getitem_op.py b/backends/transforms/remove_getitem_op.py index 733393b6d9a..9908df70765 100644 --- a/backends/transforms/remove_getitem_op.py +++ b/backends/transforms/remove_getitem_op.py @@ -1,13 +1,16 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
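# --- Illustrative aside, not part of the diff above. ---
# RemoveCloneOpsTransform(preserve_input_output_copies=True) keeps clones that read a
# graph input and feed the graph output, since removing them would alias the program
# output to the program input. A simplified standalone version of that check on a plain
# torch.fx graph (the real pass matches edge-dialect clone ops):
import torch

class PassThrough(torch.nn.Module):
    def forward(self, x):
        return torch.clone(x)

def is_input_output_copy(node: torch.fx.Node) -> bool:
    if node.op != "call_function" or not node.args:
        return False
    src = node.args[0]
    if not isinstance(src, torch.fx.Node) or src.op != "placeholder":
        return False
    return any(user.op == "output" for user in node.users)

gm = torch.fx.symbolic_trace(PassThrough())
print([n.name for n in gm.graph.nodes if is_input_output_copy(n)])  # e.g. ['clone']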
+import copy + import torch from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.pass_base import ExportPass, PassResult, PROTECTED_KEYS class RemoveGetItemPass(ExportPass): @@ -77,6 +80,10 @@ def call(self, graph_module: torch.fx.GraphModule): args=node.args, kwargs=node.kwargs, ) + new_max_wd.meta = node.meta.copy() + new_max_wd.meta["val"] = new_max_wd.meta["val"][0] + + _copy_node_metadata(node, new_max_wd) getitem_node.replace_all_uses_with(new_max_wd) @@ -88,3 +95,15 @@ def call(self, graph_module: torch.fx.GraphModule): graph_module = super().call(graph_module).graph_module return PassResult(graph_module, True) + + +def _copy_node_metadata(node: torch.fx.Node, new_max_wd: torch.fx.Node): + """Copy metadata from original node to new node.""" + + for key, value in node.meta.items(): + if key in PROTECTED_KEYS: + continue + try: + new_max_wd.meta[key] = copy.deepcopy(value) + except Exception: + new_max_wd.meta[key] = value diff --git a/backends/transforms/targets.bzl b/backends/transforms/targets.bzl index ca09d34c2fe..11a5b26e095 100644 --- a/backends/transforms/targets.bzl +++ b/backends/transforms/targets.bzl @@ -182,6 +182,7 @@ def define_common_targets(): ], visibility = [ "//executorch/backends/...", + "@EXECUTORCH_CLIENTS", ], deps = [ "//caffe2:torch", diff --git a/backends/vulkan/CMakeLists.txt b/backends/vulkan/CMakeLists.txt index 29ff90e7293..d9acde79ecf 100644 --- a/backends/vulkan/CMakeLists.txt +++ b/backends/vulkan/CMakeLists.txt @@ -105,17 +105,33 @@ target_include_directories( $ ) +# vulkan runtime utils files + +file(GLOB_RECURSE vulkan_runtime_utils_cpp ${RUNTIME_PATH}/utils/*.cpp) + # vulkan_backend +# Try to find boost to log stack traces when throwing exceptions +find_package(Boost 1.89 COMPONENTS stacktrace_basic stacktrace_addr2line) + file(GLOB vulkan_backend_cpp ${RUNTIME_PATH}/*.cpp) list(APPEND vulkan_backend_cpp ${vulkan_graph_cpp}) list(APPEND vulkan_backend_cpp ${vulkan_standard_shaders_cpp}) +list(APPEND vulkan_backend_cpp ${vulkan_runtime_utils_cpp}) add_library(vulkan_backend ${vulkan_backend_cpp}) target_include_directories( vulkan_backend PRIVATE ${SCHEMA_INCLUDE_DIR} ${COMMON_INCLUDES} ) target_link_libraries(vulkan_backend PRIVATE vulkan_schema executorch_core) +# Optionally link boost for stacktraces if boost is available +if(DEFINED Boost_STACKTRACE_BASIC_LIBRARY) + target_link_libraries( + vulkan_backend PRIVATE ${Boost_STACKTRACE_LIBRARY} + ${Boost_STACKTRACE_ADDR2LINE_LIBRARY} + ) + list(APPEND VULKAN_CXX_FLAGS "-DETVK_BOOST_STACKTRACE_AVAILABLE") +endif() target_compile_options(vulkan_backend PRIVATE ${VULKAN_CXX_FLAGS}) # Link this library with --whole-archive due to dynamic backend registration executorch_target_link_options_shared_lib(vulkan_backend) @@ -127,7 +143,7 @@ set_property(TARGET vulkan_backend PROPERTY CXX_STANDARD 17) install( TARGETS vulkan_backend vulkan_schema EXPORT ExecuTorchTargets - DESTINATION lib + DESTINATION ${CMAKE_INSTALL_LIBDIR} INCLUDES DESTINATION ${COMMON_INCLUDES} ) diff --git a/backends/vulkan/README.md b/backends/vulkan/README.md index e0a953d05fe..b51a736c7df 100644 --- a/backends/vulkan/README.md +++ b/backends/vulkan/README.md @@ -1,205 +1,4 @@ -# Vulkan Backend +# The ExecuTorch Vulkan Backend -The ExecuTorch Vulkan delegate is a native GPU delegate for ExecuTorch that is -built on top of the cross-platform Vulkan GPU API standard. 
It is primarily -designed to leverage the GPU to accelerate model inference on Android devices, -but can be used on any platform that supports an implementation of Vulkan: -laptops, servers, and edge devices. - -::::{note} -The Vulkan delegate is currently under active development, and its components -are subject to change. -:::: - -## What is Vulkan? - -Vulkan is a low-level GPU API specification developed as a successor to OpenGL. -It is designed to offer developers more explicit control over GPUs compared to -previous specifications in order to reduce overhead and maximize the -capabilities of the modern graphics hardware. - -Vulkan has been widely adopted among GPU vendors, and most modern GPUs (both -desktop and mobile) in the market support Vulkan. Vulkan is also included in -Android from Android 7.0 onwards. - -**Note that Vulkan is a GPU API, not a GPU Math Library**. That is to say it -provides a way to execute compute and graphics operations on a GPU, but does not -come with a built-in library of performant compute kernels. - -## The Vulkan Compute Library - -The ExecuTorch Vulkan Delegate is a wrapper around a standalone runtime known as -the **Vulkan Compute Library**. The aim of the Vulkan Compute Library is to -provide GPU implementations for PyTorch operators via GLSL compute shaders. - -The Vulkan Compute Library is a fork/iteration of the [PyTorch Vulkan Backend](https://pytorch.org/tutorials/prototype/vulkan_workflow.html). -The core components of the PyTorch Vulkan backend were forked into ExecuTorch -and adapted for an AOT graph-mode style of model inference (as opposed to -PyTorch which adopted an eager execution style of model inference). - -The components of the Vulkan Compute Library are contained in the -`executorch/backends/vulkan/runtime/` directory. The core components are listed -and described below: - -``` -runtime/ -├── api/ .................... Wrapper API around Vulkan to manage Vulkan objects -└── graph/ .................. ComputeGraph class which implements graph mode inference - └── ops/ ................ Base directory for operator implementations - ├── glsl/ ........... GLSL compute shaders - │ ├── *.glsl - │ └── conv2d.glsl - └── impl/ ........... C++ code to dispatch GPU compute shaders - ├── *.cpp - └── Conv2d.cpp -``` - -## Features - -The Vulkan delegate currently supports the following features: - -* **Memory Planning** - * Intermediate tensors whose lifetimes do not overlap will share memory allocations. This reduces the peak memory usage of model inference. -* **Capability Based Partitioning**: - * A graph can be partially lowered to the Vulkan delegate via a partitioner, which will identify nodes (i.e. operators) that are supported by the Vulkan delegate and lower only supported subgraphs -* **Support for upper-bound dynamic shapes**: - * Tensors can change shape between inferences as long as its current shape is smaller than the bounds specified during lowering - -In addition to increasing operator coverage, the following features are -currently in development: - -* **Quantization Support** - * We are currently working on support for 8-bit dynamic quantization, with plans to extend to other quantization schemes in the future. -* **Memory Layout Management** - * Memory layout is an important factor to optimizing performance. We plan to introduce graph passes to introduce memory layout transitions throughout a graph to optimize memory-layout sensitive operators such as Convolution and Matrix Multiplication. 
-* **Selective Build** - * We plan to make it possible to control build size by selecting which operators/shaders you want to build with - -## End to End Example - -To further understand the features of the Vulkan Delegate and how to use it, -consider the following end to end example with a simple single operator model. - -### Compile and lower a model to the Vulkan Delegate - -Assuming ExecuTorch has been set up and installed, the following script can be -used to produce a lowered MobileNet V2 model as `vulkan_mobilenetv2.pte`. - -Once ExecuTorch has been set up and installed, the following script can be used -to generate a simple model and lower it to the Vulkan delegate. - -``` -# Note: this script is the same as the script from the "Setting up ExecuTorch" -# page, with one minor addition to lower to the Vulkan backend. -import torch -from torch.export import export -from executorch.exir import to_edge - -from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner - -# Start with a PyTorch model that adds two input tensors (matrices) -class Add(torch.nn.Module): - def __init__(self): - super(Add, self).__init__() - - def forward(self, x: torch.Tensor, y: torch.Tensor): - return x + y - -# 1. torch.export: Defines the program with the ATen operator set. -aten_dialect = export(Add(), (torch.ones(1), torch.ones(1))) - -# 2. to_edge: Make optimizations for Edge devices -edge_program = to_edge(aten_dialect) -# 2.1 Lower to the Vulkan backend -edge_program = edge_program.to_backend(VulkanPartitioner()) - -# 3. to_executorch: Convert the graph to an ExecuTorch program -executorch_program = edge_program.to_executorch() - -# 4. Save the compiled .pte program -with open("vk_add.pte", "wb") as file: - file.write(executorch_program.buffer) -``` - -Like other ExecuTorch delegates, a model can be lowered to the Vulkan Delegate -using the `to_backend()` API. The Vulkan Delegate implements the -`VulkanPartitioner` class which identifies nodes (i.e. operators) in the graph -that are supported by the Vulkan delegate, and separates compatible sections of -the model to be executed on the GPU. - -This means the a model can be lowered to the Vulkan delegate even if it contains -some unsupported operators. This will just mean that only parts of the graph -will be executed on the GPU. - - -::::{note} -The [supported ops list](https://github.com/pytorch/executorch/blob/main/backends/vulkan/op_registry.py#L194) -Vulkan partitioner code can be inspected to examine which ops are currently -implemented in the Vulkan delegate. -:::: - -### Build Vulkan Delegate libraries - -The easiest way to build and test the Vulkan Delegate is to build for Android -and test on a local Android device. Android devices have built in support for -Vulkan, and the Android NDK ships with a GLSL compiler which is needed to -compile the Vulkan Compute Library's GLSL compute shaders. - -The Vulkan Delegate libraries can be built by setting `-DEXECUTORCH_BUILD_VULKAN=ON` -when building with CMake. - -First, make sure that you have the Android NDK installed; any NDK version past -NDK r19c should work. Note that the examples in this doc have been validated with -NDK r27b. The Android SDK should also be installed so that you have access to `adb`. - -The instructions in this page assumes that the following environment variables -are set. 
- -```shell -export ANDROID_NDK= -# Select the appropriate Android ABI for your device -export ANDROID_ABI=arm64-v8a -# All subsequent commands should be performed from ExecuTorch repo root -cd -# Make sure adb works -adb --version -``` - -To build and install ExecuTorch libraries (for Android) with the Vulkan -Delegate: - -```shell -# From executorch root directory -(rm -rf cmake-android-out && \ - pp cmake . -DCMAKE_INSTALL_PREFIX=cmake-android-out \ - -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ - -DANDROID_ABI=$ANDROID_ABI \ - -DEXECUTORCH_BUILD_VULKAN=ON \ - -DPYTHON_EXECUTABLE=python \ - -Bcmake-android-out && \ - cmake --build cmake-android-out -j16 --target install) -``` - -### Run the Vulkan model on device - -::::{note} -Since operator support is currently limited, only binary arithmetic operators -will run on the GPU. Expect inference to be slow as the majority of operators -are being executed via Portable operators. -:::: - -Now, the partially delegated model can be executed (partially) on your device's -GPU! - -```shell -# Build a model runner binary linked with the Vulkan delegate libs -cmake --build cmake-android-out --target executor_runner -j32 - -# Push model to device -adb push vk_add.pte /data/local/tmp/vk_add.pte -# Push binary to device -adb push cmake-android-out/executor_runner /data/local/tmp/runner_bin - -# Run the model -adb shell /data/local/tmp/runner_bin --model_path /data/local/tmp/vk_add.pte -``` +Please see the [Vulkan Backend Overview](../../docs/source/backends/vulkan/vulkan-overview.md) +to learn more about the ExecuTorch Vulkan Backend. diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS index aed41114ada..453b4814637 100644 --- a/backends/vulkan/_passes/TARGETS +++ b/backends/vulkan/_passes/TARGETS @@ -63,19 +63,6 @@ runtime.python_library( ], ) -runtime.python_library( - name = "remove_local_scalar_dense", - srcs = ["remove_local_scalar_dense_ops.py"], - visibility = [ - "//executorch/backends/...", - ], - deps = [ - "//caffe2:torch", - "//executorch/exir:pass_base", - "//executorch/exir/dialects:lib", - ], -) - runtime.python_library( name = "remove_redundant_ops", srcs = ["remove_redundant_ops.py"], @@ -148,7 +135,6 @@ runtime.python_library( ":fuse_quantized_ops", ":insert_prepack_nodes", ":remove_asserts", - ":remove_local_scalar_dense", ":remove_redundant_ops", ":squeeze_unsqueeze_inputs", ":tag_memory_meta_pass", diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py index f4ef6b2ac0e..d6a6823ca88 100644 --- a/backends/vulkan/_passes/__init__.py +++ b/backends/vulkan/_passes/__init__.py @@ -16,9 +16,6 @@ remove_asserts, RemoveAssertsTransform, ) -from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import ( - RemoveLocalScalarDenseOpsTransform, -) from executorch.backends.vulkan._passes.remove_redundant_ops import ( RemoveRedundantOpsTransform, ) @@ -34,7 +31,6 @@ "insert_prepack_nodes", "remove_asserts", "RemoveAssertsTransform", - "RemoveLocalScalarDenseOpsTransform", "RemoveRedundantOpsTransform", "SqueezeUnsqueezeInputs", "TagMemoryMetaPass", diff --git a/backends/vulkan/_passes/fold_qdq.py b/backends/vulkan/_passes/fold_qdq.py index 3beccc2205c..a6a5e751c05 100644 --- a/backends/vulkan/_passes/fold_qdq.py +++ b/backends/vulkan/_passes/fold_qdq.py @@ -17,9 +17,8 @@ class FoldQDQPass(ExportPass): valid quant op patterns have already been fused before this pass. 
""" - def __init__(self, edge_program: torch.export.ExportedProgram): - super(FoldQDQPass, self).__init__() - self.edge_program = edge_program + def __init__(self): + super().__init__() def call(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: diff --git a/backends/vulkan/_passes/fuse_patterns.py b/backends/vulkan/_passes/fuse_patterns.py index 6ced1f32a7c..1575dd6a4f6 100644 --- a/backends/vulkan/_passes/fuse_patterns.py +++ b/backends/vulkan/_passes/fuse_patterns.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Optional + import executorch.backends.vulkan.patterns as vk_patterns import torch @@ -13,13 +15,15 @@ class FusePatternsPass(ExportPass): - def __init__(self, exported_program: ExportedProgram) -> None: + def __init__(self) -> None: super().__init__() - self.program = exported_program + self._exported_program: Optional[ExportedProgram] = None def call(self, graph_module: torch.fx.GraphModule): + assert self._exported_program is not None + total_replaced = vk_patterns.replace_all_fusable_subgraphs( - self.program, graph_module + self._exported_program, graph_module ) if total_replaced > 0: diff --git a/backends/vulkan/_passes/fuse_quantized_ops.py b/backends/vulkan/_passes/fuse_quantized_ops.py index ca9f7541159..bb8cf5f2e64 100644 --- a/backends/vulkan/_passes/fuse_quantized_ops.py +++ b/backends/vulkan/_passes/fuse_quantized_ops.py @@ -211,18 +211,20 @@ def fuse_into_linear_qcnw_node( class FuseQuantizedOpsTransform(ExportPass): - def __init__(self, exported_program: ExportedProgram) -> None: + def __init__(self) -> None: super().__init__() - self.program = exported_program + self._exported_program: Optional[ExportedProgram] = None def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + assert self._exported_program is not None + for node in graph_module.graph.nodes: # Check for linear_qcnw pattern (weight-only quantization) - qcnw_details = matches_linear_qcnw_pattern(self.program, node) + qcnw_details = matches_linear_qcnw_pattern(self._exported_program, node) if qcnw_details is not None: qcnw_method, qcnw_nbits = qcnw_details fuse_into_linear_qcnw_node( - self.program, graph_module, node, qcnw_method, qcnw_nbits + self._exported_program, graph_module, node, qcnw_method, qcnw_nbits ) continue diff --git a/backends/vulkan/_passes/remove_local_scalar_dense_ops.py b/backends/vulkan/_passes/remove_local_scalar_dense_ops.py deleted file mode 100644 index 6ce3572ec0c..00000000000 --- a/backends/vulkan/_passes/remove_local_scalar_dense_ops.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -import torch -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult - -from torch._subclasses.fake_tensor import FakeTensor - - -def node_is_local_scalar_dense_chain(node: torch.fx.Node) -> bool: - """ - Converting a tensor to a scalar via tensor[0].item() creates a index_select + - local_scalar_dense pattern in the graph. Check if a node is the start of this pattern. 
- """ - if ( - node.op == "call_function" - and node.target == exir_ops.edge.aten.select_copy.int - and len(node.users) == 1 - ): - user = list(node.users.keys())[0] - return user.target == torch.ops.aten._local_scalar_dense.default - - return False - - -def tag_node_if_scalar_tensor(node: torch.fx.Node) -> None: - """ - A scalar tensor in the Vulkan backend is a tensor that can be represented as a scalar - value instead of a Tensor object. The criteria for identifying a tensor as a scalar - tensor are as follows: - - 1. The tensor has only 1 element - 2. One of the node's uses is converting it to a scalar via `tensor[0].item()`, which - creates a index_select + local_scalar_dense pattern in the graph - - If any of these criteria are fulfilled, then tag the node for the tensor to mark it - so that it is added as a scalar value during serialization. - """ - tensor_val = node.meta["val"] - if not isinstance(tensor_val, FakeTensor): - return - - # Scalar tensors must have only one element - if tensor_val.numel() != 1: - return - - for user in node.users: - if node_is_local_scalar_dense_chain(user): - node.meta["etvk_is_scalar_tensor"] = True - - -def remove_local_scalar_dense_chain(graph: torch.fx.Graph, node: torch.fx.Node) -> None: - """ - Remove the index_select + local_scalar_dense pattern in the graph in favor of passing - the original scalar tensor directly. - """ - replace_node = node.args[0] - assert isinstance(replace_node, torch.fx.Node) - # If the argument to the local_scalar_dense op is a select op with only - # one user, and the argument to the select op is a tensor with only one - # element (i.e. a scalar tensor), then replace the entire pattern with the - # scalar tensor. - if ( - replace_node.op == "call_function" - and replace_node.target == exir_ops.edge.aten.select_copy.int - ): - # pyre-ignore - if replace_node.args[0].meta["val"].numel() == 1: - replace_node = replace_node.args[0] - assert isinstance(replace_node, torch.fx.Node) - assert replace_node.meta.get("etvk_is_scalar_tensor", True) - - with graph.inserting_after(node): - node.replace_all_uses_with(replace_node) - - -def remove_local_scalar_dense_ops(graph: torch.fx.Graph) -> torch.fx.Graph: - """ - The purpose of this pass is twofold: - 1. Tag scalar tensors (see `tag_node_if_scalar_tensor()` for the criteria) - 2. Remove the index_select + local_scalar_dense pattern in the graph in favor of - passing the original scalar tensor directly (see `remove_local_scalar_dense_chain()`) - - This makes it easier to deal with scalar tensors in the Vulkan backend. In particular, - it allows serializing scalar tensors as SymInt objects instead of Tensor objects. - Because scalar tensors are often used to inform tensor shapes, their values need to - be easily accessed by the CPU during resizing logic, while also being able to reflect - updates to their value in any GPU shaders that reference them. 
- """ - target_op = torch.ops.aten._local_scalar_dense.default - for node in graph.nodes: - tag_node_if_scalar_tensor(node) - - if node.op == "call_function" and node.target == target_op: - remove_local_scalar_dense_chain(graph, node) - - graph.eliminate_dead_code() - return graph - - -class RemoveLocalScalarDenseOpsTransform(ExportPass): - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - graph_module.graph = remove_local_scalar_dense_ops(graph_module.graph) - return PassResult(graph_module, True) diff --git a/backends/vulkan/_passes/remove_redundant_ops.py b/backends/vulkan/_passes/remove_redundant_ops.py index 8e602dd17b4..25bdd34de70 100644 --- a/backends/vulkan/_passes/remove_redundant_ops.py +++ b/backends/vulkan/_passes/remove_redundant_ops.py @@ -31,35 +31,37 @@ class RemoveRedundantOpsTransform(ExportPass): exir_ops.edge.aten.lift_fresh_copy.default, exir_ops.edge.dim_order_ops._to_dim_order_copy.default, exir_ops.edge.dim_order_ops._clone_dim_order.default, + exir_ops.edge.aten.expand_copy.default, } def __init__(self) -> None: super(RemoveRedundantOpsTransform, self).__init__() def _should_remove(self, node: torch.fx.Node) -> bool: - if node.target in self.redundant_ops: - return True - - # Only remove to_copy if dtype does not change. Otherwise, memory format changes - # will be handled internally by the backend. - if ( - node.target == exir_ops.edge.aten._to_copy.default - or node.target == torch.ops.aten._to_copy.default - ): - src_dtype = node.meta["val"].dtype - # pyre-ignore - dst_dtype = node.args[0].meta["val"].dtype - return src_dtype == dst_dtype - - return False + if node.target not in self.redundant_ops: + return False + + orig_node = node.args[0] + assert isinstance(orig_node, torch.fx.Node) + + src_dtype = orig_node.meta["val"].dtype + dst_dtype = node.meta["val"].dtype + + # Do not remove if the op is converting the dtype. 
+ if src_dtype != dst_dtype: + return False + + src_shape = orig_node.meta["val"].shape + dst_shape = node.meta["val"].shape + + return src_shape == dst_shape def _remove(self, graph_module: torch.fx.GraphModule) -> None: for node in graph_module.graph.nodes: if not self._should_remove(node): continue - with graph_module.graph.inserting_after(node): - node.replace_all_uses_with(node.args[0]) + node.replace_all_uses_with(node.args[0]) graph_module.graph.eliminate_dead_code() diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py index db53cc666a8..00b6c62d5d2 100644 --- a/backends/vulkan/_passes/tag_memory_meta_pass.py +++ b/backends/vulkan/_passes/tag_memory_meta_pass.py @@ -6,22 +6,16 @@ import logging import operator - from typing import Any import executorch.backends.vulkan.utils as utils - import torch - from executorch.backends.vulkan.op_registry import get_op_features, has_impl, OpFeatures - from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( VkMemoryLayout, VkStorageType, ) - from executorch.exir.dialects._ops import ops as exir_ops - from executorch.exir.pass_base import ExportPass, PassResult from executorch.exir.tensor import TensorSpec @@ -130,15 +124,17 @@ def __init__( texture_limits: utils.ImageExtents, default_storage_type: VkStorageType = VkStorageType.TEXTURE_3D, default_memory_layout: VkMemoryLayout = VkMemoryLayout.TENSOR_WIDTH_PACKED, + force_fp16: bool = False, ): super().__init__() self.default_storage: VkStorageType = default_storage_type self.default_layout: VkMemoryLayout = default_memory_layout self.texture_limits = texture_limits + self.force_fp16 = force_fp16 # Magic number to limit "lookahead" when tracing through users of an operator # to constrain the representation of its arguments/outputs. - self.max_trace_search_depth = 20 + self.max_trace_search_depth = None def is_valid_op_node(self, node: Any) -> bool: """ @@ -230,6 +226,11 @@ def get_arg_tensor_source_repset( """ arg_node = op_node.args[arg_i] + # For non-tensor arguments, return ALL_STORAGES_REPSET so that the respset does + # not appear to be empty. + if not utils.is_tensor_arg_node(arg_node): + return utils.ALL_STORAGES_REPSET + # Special case for cat - use the first tensor in the list as representative if isinstance(arg_node, list): arg_node = arg_node[0] @@ -357,12 +358,18 @@ def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> No 2. Then, try to trace through the users of the argument to find a representation that can be used for as long as possible without needing a transition. """ + # If forcing fp16, then try to use texture storage whenever possible. This is + # a temporary stopgap measure until all buffer implementations properly account + # for potential overflow of fp16 representation range when doing math in fp16. 
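# --- Illustrative aside, not part of the diff above. ---
# Why force_fp16 steers arguments toward texture storage: doing the math in fp16 can
# overflow the representable range (max finite value ~65504), which the buffer paths do
# not yet guard against.
import torch

a = torch.tensor(300.0, dtype=torch.float16)
print(a * a)                  # tensor(inf, dtype=torch.float16): 90000 overflows fp16
print(a.float() * a.float())  # tensor(90000.): the same math in fp32 is fine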
+ if self.force_fp16: + op_repsets.try_constrain_with_arg_repset(arg_i, utils.ANY_TEXTURE) + arg_source_repset = self.get_arg_tensor_source_repset(op_repsets.op_node, arg_i) op_repsets.try_constrain_with_arg_repset(arg_i, arg_source_repset) arg_repset = op_repsets.get_arg_repset(arg_i) if arg_repset.is_constrained(): - return arg_repset + return arg_node = op_repsets.op_node.args[arg_i] @@ -372,6 +379,20 @@ def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> No arg_repset = self.trace_node_users_to_constrain_repset(arg_node, arg_repset) op_repsets.try_constrain_with_arg_repset(arg_i, arg_repset) + def constrain_op_out_repset(self, op_repsets: utils.OpRepSets) -> None: + """ + Similar to the `constrain_op_arg_repset` function, but for the output repset of + the operator. + """ + out_repset = op_repsets.get_out_repset(0) + if out_repset.is_constrained(): + return + + op_node = op_repsets.op_node + out_respset = self.trace_node_users_to_constrain_repset(op_node, out_repset) + + op_repsets.try_constrain_with_out_repset(out_respset) + def constrain_op_repsets(self, op_repsets: utils.OpRepSets) -> None: # For most ops, constraining the argument repsets will also contrain the output # repset due to OpRepSets maintaining synchronization rules. @@ -379,14 +400,12 @@ def constrain_op_repsets(self, op_repsets: utils.OpRepSets) -> None: if utils.is_tensor_arg_node(op_repsets.op_node.args[i]): self.constrain_op_arg_repset(i, op_repsets) - # TODO(ssjia): For most ops, inputs and outputs must be synchronized, so there - # is no need to constrain output repsets explicitly. Currently, the exceptions - # (i.e. choose qparams) already define constrined repsets for the output, so - # there is again no need to explicitly constrain the outputs. If an operator - # appears later on that does not sync input and output representations, and - # defines ambiguous repsets for the output tensor(s), then we will need to add - # additional logic to this function to constrain the output repsets separately - # from the input repsets. + # However, some operators do not sync input and output representations and also + # define ambiguous repsets for the output tensor(s). In those cases we will need + # to execute additional logic to constrain the output repsets separately from + # the input repsets. + if not op_repsets.sync_primary_io_repr and op_repsets.sync_outs_repr: + self.constrain_op_out_repset(op_repsets) def set_op_node_tensor_reprs( self, graph_module: torch.fx.GraphModule, op_node: torch.fx.Node diff --git a/backends/vulkan/cmake/ShaderLibrary.cmake b/backends/vulkan/cmake/ShaderLibrary.cmake index 1b6838c4dfd..16a60abf6f3 100644 --- a/backends/vulkan/cmake/ShaderLibrary.cmake +++ b/backends/vulkan/cmake/ShaderLibrary.cmake @@ -24,22 +24,17 @@ if(NOT EXECUTORCH_ROOT) message("WARNING: EXECUTORCH_ROOT is not set! A failure is likely imminent.") endif() -if(ANDROID) - if(NOT ANDROID_NDK) - message(FATAL_ERROR "ANDROID_NDK not set") - endif() - - if(NOT GLSLC_PATH) - set(GLSLC_PATH - "${ANDROID_NDK}/shader-tools/${ANDROID_NDK_HOST_SYSTEM_NAME}/glslc" - ) - endif() -else() - find_program(GLSLC_PATH glslc PATHS $ENV{PATH}) +find_program(GLSLC_PATH glslc PATHS $ENV{PATH}) - if(NOT GLSLC_PATH) - message(FATAL_ERROR "USE_VULKAN glslc not found") - endif() +if(NOT GLSLC_PATH) + message( + FATAL_ERROR + "glslc from the Vulkan SDK must be installed to build the Vulkan backend. 
" + "Please install the Vulkan SDK 1.4.321.0 or newer from " + "https://vulkan.lunarg.com/sdk/home and ensure that the glslc binary is in your PATH. " + "Note that the glslc distributed with the Android NDK is not compatible since it " + "does not support the GL_EXT_integer_dot_product extension. " + ) endif() # Required to enable linking with --whole-archive diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index 56e803b9127..aed8b591fea 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -9,6 +9,8 @@ import executorch.backends.vulkan.patterns as vk_patterns import torch.library +from torch._subclasses.fake_tensor import FakeTensor + namespace = "et_vk" lib = torch.library.Library(namespace, "DEF") @@ -354,18 +356,20 @@ def linear_q8ta_q8csw( lib.impl(name, linear_q8ta_q8csw, "CompositeExplicitAutograd") qa_q8csw_linear = getattr(getattr(torch.ops, namespace), name) -####################### -## conv2d_q8ta_q8csw ## -####################### +############################ +## conv2d_q8ta_q8csw_q8to ## +############################ -def conv2d_q8ta_q8csw( +def conv2d_q8ta_q8csw_q8to( x: torch.Tensor, input_scale: float, input_zero_point: int, weights: torch.Tensor, weight_sums: torch.Tensor, weight_scales: torch.Tensor, + output_scale: float, + output_zero_point: int, bias: Optional[torch.Tensor], kernel_size: list, stride: list, @@ -373,27 +377,103 @@ def conv2d_q8ta_q8csw( dilation: list, groups: int, ): - IC = x.shape[1] + x = torch.ops.quantized_decomposed.dequantize_per_tensor( + x, input_scale, input_zero_point, -128, 127, x.dtype + ) + + # Calculate weight dimensions + OC = weights.shape[0] + assert OC % groups == 0, "Output channels must be divisible by groups" + IC_per_group = int(x.shape[1] / groups) K_h, K_w = kernel_size[0], kernel_size[1] - canonical_weight_K_dim = K_h * K_w * IC + orig_weight_K_dim = K_h * K_w * IC_per_group + # Remove any padding added to in_features dim to align to a multiple of 4 + if weights.shape[-1] > orig_weight_K_dim: + weights = weights[:, :orig_weight_K_dim] + # Remove any padding added to output channels dim to align to a multiple of 4 - if weights.shape[-1] != canonical_weight_K_dim: - weights = weights[:, :canonical_weight_K_dim] - weight_scales = weight_scales[:canonical_weight_K_dim] + if weight_scales.shape[0] > OC: + weight_scales = weight_scales[:OC] if bias is not None: - bias = bias[:canonical_weight_K_dim] + bias = bias[:OC] + + # Reshape to original 4D format (OC, IC, H, W) + weights = weights.view(OC, IC_per_group, K_h, K_w) weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32) + # Dequantize weights + weights = torch.ops.quantized_decomposed.dequantize_per_channel( + weights, + weight_scales, + weight_zeros, + 0, # axis=0 for output channel quantization + -127, + 127, + torch.int8, + ) - # Calculate dimensions - OC = weights.shape[0] - in_features = weights.shape[1] - IC = in_features // (K_h * K_w) + # Perform convolution + out = torch.nn.functional.conv2d( + x, weights, bias, stride, padding, dilation, groups + ) - # Reshape to original 4D format (OC, IC, H, W) - weights = weights.view(OC, IC, K_h, K_w) + out = torch.ops.quantized_decomposed.quantize_per_tensor( + out, output_scale, output_zero_point, -128, 127, torch.int8 + ) + + return out + + +name = "conv2d_q8ta_q8csw_q8to" +lib.define( + f""" + {name}( + Tensor x, + float input_scale, + int input_zero_point, + Tensor weights, + Tensor weight_sums, + Tensor weight_scales, + float output_scale, + int 
output_zero_point, + Tensor? bias, + SymInt[] kernel_size, + SymInt[] stride, + SymInt[] padding, + SymInt[] dilation, + SymInt groups) -> Tensor + """ +) +lib.impl(name, conv2d_q8ta_q8csw_q8to, "CompositeExplicitAutograd") +conv2d_q8ta_q8csw_op = getattr(getattr(torch.ops, namespace), name) + + +def conv2d_q8ta_q8csw_q8to_dw( + x: torch.Tensor, + input_scale: float, + input_zero_point: int, + weights: torch.Tensor, + weight_sums: torch.Tensor, + weight_scales: torch.Tensor, + output_scale: float, + output_zero_point: int, + bias: Optional[torch.Tensor], + kernel_size: list, + stride: list, + padding: list, + dilation: list, + groups: int, +): + x = torch.ops.quantized_decomposed.dequantize_per_tensor( + x, input_scale, input_zero_point, -128, 127, x.dtype + ) + + # Restore weight to original data layout + K_h, K_w, OC = weights.shape + weights = weights.permute(2, 0, 1).reshape(OC, 1, K_h, K_w) + weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32) # Dequantize weights weights = torch.ops.quantized_decomposed.dequantize_per_channel( weights, @@ -410,10 +490,14 @@ def conv2d_q8ta_q8csw( x, weights, bias, stride, padding, dilation, groups ) + out = torch.ops.quantized_decomposed.quantize_per_tensor( + out, output_scale, output_zero_point, -128, 127, torch.int8 + ) + return out -name = "conv2d_q8ta_q8csw" +name = "conv2d_q8ta_q8csw_q8to_dw" lib.define( f""" {name}( @@ -423,6 +507,8 @@ def conv2d_q8ta_q8csw( Tensor weights, Tensor weight_sums, Tensor weight_scales, + float output_scale, + int output_zero_point, Tensor? bias, SymInt[] kernel_size, SymInt[] stride, @@ -431,8 +517,8 @@ def conv2d_q8ta_q8csw( SymInt groups) -> Tensor """ ) -lib.impl(name, conv2d_q8ta_q8csw, "CompositeExplicitAutograd") -conv2d_q8ta_q8csw_op = getattr(getattr(torch.ops, namespace), name) +lib.impl(name, conv2d_q8ta_q8csw_q8to_dw, "CompositeExplicitAutograd") +conv2d_q8ta_q8csw_dw_op = getattr(getattr(torch.ops, namespace), name) ###################### ## apply_rotary_emb ## @@ -452,3 +538,60 @@ def apply_rotary_emb_impl( ) lib.impl(name, apply_rotary_emb_impl, "CompositeExplicitAutograd") apply_rotary_emb_op = getattr(getattr(torch.ops, namespace), name) + +######################## +## add_q8ta_q8ta_q8to ## +######################## + + +def add_q8ta_q8ta_q8to_impl( + input_a: torch.Tensor, + input_b: torch.Tensor, + input_a_scale: float, + input_a_zero_point: int, + input_b_scale: float, + input_b_zero_point: int, + output_scale: float, + output_zero_point: int, + alpha: float, +): + # Dequantize inputs to float + dequant_a = torch.ops.quantized_decomposed.dequantize_per_tensor( + input_a, input_a_scale, input_a_zero_point, -128, 127, input_a.dtype + ) + dequant_b = torch.ops.quantized_decomposed.dequantize_per_tensor( + input_b, input_b_scale, input_b_zero_point, -128, 127, input_b.dtype + ) + + # Perform addition with alpha scaling + result = dequant_a + alpha * dequant_b + + # Quantize the result back to int8 + quantized_result = torch.ops.quantized_decomposed.quantize_per_tensor( + result, output_scale, output_zero_point, -128, 127, torch.int8 + ) + + return quantized_result + + +name = "add_q8ta_q8ta_q8to" +lib.define( + f"{name}(Tensor input_a, Tensor input_b, float input_a_scale, int input_a_zero_point, float input_b_scale, int input_b_zero_point, float output_scale, int output_zero_point, float alpha) -> Tensor" +) +lib.impl(name, add_q8ta_q8ta_q8to_impl, "CompositeExplicitAutograd") +add_q8ta_q8ta_q8to_op = getattr(getattr(torch.ops, namespace), name) + +############################# +## 
select_as_symint ## +############################# + + +def select_as_symint_impl(x: torch.Tensor, dim: int, index: int): + assert isinstance(x, FakeTensor) + return x.fake_mode.shape_env.create_unbacked_symint() + + +name = "select_as_symint" +lib.define(f"{name}(Tensor x, int dim, int index) -> SymInt") +lib.impl(name, select_as_symint_impl, "Meta") +select_as_symint_op = getattr(getattr(torch.ops, namespace), name) diff --git a/backends/vulkan/docs/android_demo.md b/backends/vulkan/docs/android_demo.md deleted file mode 100644 index ff84938b06f..00000000000 --- a/backends/vulkan/docs/android_demo.md +++ /dev/null @@ -1,128 +0,0 @@ -# Building and Running ExecuTorch with the Vulkan Backend - -The [ExecuTorch Vulkan Delegate](../../../docs/source/native-delegates-executorch-vulkan-delegate.md) -is a native GPU delegate for ExecuTorch. - - -::::{grid} 2 -:::{grid-item-card} What you will learn in this tutorial: -:class-card: card-content -* How to export the Llama3.2-1B parameter model with partial GPU delegation -* How to execute the partially delegated model on Android -::: -:::{grid-item-card} Prerequisites: -:class-card: card-prerequisites -* Follow [**Setting up ExecuTorch**](../../../docs/source/getting-started-setup.rst) -* It is also recommended that you read through [**ExecuTorch Vulkan Delegate**](../../../docs/source/native-delegates-executorch-vulkan-delegate.md) and follow the example in that page -::: -:::: - -## Prerequisites - -Note that all the steps below should be performed from the ExecuTorch repository -root directory, and assumes that you have gone through the steps of setting up -ExecuTorch. - -It is also assumed that the Android NDK and Android SDK is installed, and the -following environment examples are set. - -```shell -export ANDROID_NDK= -# Select an appropriate Android ABI for your device -export ANDROID_ABI=arm64-v8a -# All subsequent commands should be performed from ExecuTorch repo root -cd -# Make sure adb works -adb --version -``` - -## Lowering the Llama3.2-1B model to Vulkan - -::::{note} -The resultant model will only be partially delegated to the Vulkan backend. In -particular, only binary arithmetic operators (`aten.add`, `aten.sub`, -`aten.mul`, `aten.div`), matrix multiplication operators (`aten.mm`, `aten.bmm`), -and linear layers (`aten.linear`) will be executed on the GPU via the Vulkan -delegate. The rest of the model will be executed using Portable operators. - -Operator support for LLaMA models is currently in active development; please -check out the `main` branch of the ExecuTorch repo for the latest capabilities. -:::: - -First, obtain the `consolidated.00.pth`, `params.json` and `tokenizer.model` -files for the `Llama3.2-1B` model from the [Llama website](https://www.llama.com/llama-downloads/). - -Once the files have been downloaded, the `export_llama` script can be used to -partially lower the Llama model to Vulkan. - -```shell -# The files will usually be downloaded to ~/.llama -python -m examples.models.llama.export_llama \ - --disable_dynamic_shape --vulkan -kv --use_sdpa_with_kv_cache -d fp32 \ - --model "llama3_2" \ - -c ~/.llama/checkpoints/Llama3.2-1B/consolidated.00.pth \ - -p ~/.llama/checkpoints/Llama3.2-1B/params.json \ - --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' -``` - -A `vulkan_llama2.pte` file should have been created as a result of running the -script. 
- -Push the tokenizer binary and `vulkan_llama2.pte` onto your Android device: - -```shell -adb push ~/.llama/tokenizer.model /data/local/tmp/ -adb push vulkan_llama2.pte /data/local/tmp/ -``` - -## Build and Run the LLaMA runner binary on Android - -First, build and install ExecuTorch libraries, then build the LLaMA runner -binary using the Android NDK toolchain. - -```shell -./install_executorch.sh --clean -(mkdir cmake-android-out && \ - cmake . -DCMAKE_INSTALL_PREFIX=cmake-android-out \ - -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ - -DANDROID_ABI=$ANDROID_ABI \ - -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ - -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ - -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ - -DEXECUTORCH_BUILD_VULKAN=ON \ - -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ - -DEXECUTORCH_BUILD_KERNELS_LLM=ON \ - -DPYTHON_EXECUTABLE=python \ - -Bcmake-android-out && \ - cmake --build cmake-android-out -j16 --target install) - -# Build LLaMA Runner library -(rm -rf cmake-android-out/examples/models/llama && \ - cmake examples/models/llama \ - -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ - -DANDROID_ABI=$ANDROID_ABI \ - -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \ - -DEXECUTORCH_BUILD_KERNELS_LLM=ON \ - -DCMAKE_INSTALL_PREFIX=cmake-android-out \ - -DPYTHON_EXECUTABLE=python \ - -Bcmake-android-out/examples/models/llama && \ - cmake --build cmake-android-out/examples/models/llama -j16) -``` - -Finally, push and run the llama runner binary on your Android device. Note that -your device must have sufficient GPU memory to execute the model. - -```shell -adb push cmake-android-out/examples/models/llama/llama_main /data/local/tmp/llama_main - -adb shell /data/local/tmp/llama_main \ - --model_path=/data/local/tmp/vulkan_llama2.pte \ - --tokenizer_path=/data/local/tmp/tokenizer.model \ - --prompt "Hello" -``` - -Note that currently model inference will be very slow due to the high amount of -delegate blobs in the lowered graph, which requires a transfer to and from the -GPU for each sub graph. Performance is expected to improve drastically as more -of the model can be lowered to the Vulkan delegate, and techniques such as -quantization are supported. diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 4c686e0cfc5..feba4f6f072 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -7,19 +7,12 @@ # pyre-unsafe import operator - from typing import Any, Callable, Dict, List, Optional, Union import executorch.backends.vulkan.custom_ops_lib # noqa - import executorch.backends.vulkan.utils as utils - import torch - -from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkMemoryLayout - from executorch.exir.dialects._ops import ops as exir_ops - from executorch.exir.dialects.edge._ops import EdgeOpOverload from torch._subclasses.fake_tensor import FakeTensor @@ -48,6 +41,9 @@ class OpFeatures: # Optional check function used during partitioning to determine if a node's # inputs are supported by the operator implementation. "are_node_inputs_supported_fn", + # Optional function to determine valid representation sets for input and outputs + # once a node's actual inputs are known. 
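# --- Illustrative aside, not part of the diff above. ---
# Mock of the new pick_io_storage_fn hook: an op registration can now compute its
# input/output storage requirements from the concrete node instead of static lists.
# OpFeatures, TensorRepSetList and the real storage constants are not used here; this
# is only a sketch of the callback mechanism.
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple

@dataclass
class MockOpFeatures:
    inputs_storage: List[str]
    outputs_storage: List[str]
    pick_io_storage_fn: Optional[Callable[[object], Tuple[List[str], List[str]]]] = None

    def make_op_repsets(self, op_node) -> Tuple[List[str], List[str]]:
        if self.pick_io_storage_fn is not None:
            return self.pick_io_storage_fn(op_node)
        return self.inputs_storage, self.outputs_storage

features = MockOpFeatures(
    inputs_storage=["ANY_TEXTURE"],
    outputs_storage=["ANY_TEXTURE"],
    pick_io_storage_fn=lambda node: (["CONTIGUOUS_BUFFER"], ["CONTIGUOUS_BUFFER"]),
)
print(features.make_op_repsets(op_node=None))  # the callback overrides the static lists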
+ "pick_io_storage_fn", ] def __init__( @@ -61,6 +57,7 @@ def __init__( supports_resize: bool = False, supports_prepacking: bool = False, are_node_inputs_supported_fn: Optional[Callable] = allow_node, + pick_io_storage_fn: Optional[Callable] = None, ): self.inputs_storage: utils.TensorRepSetList = utils.TensorRepSetList( inputs_storage if inputs_storage is not None else [] @@ -77,15 +74,21 @@ def __init__( self.supports_prepacking = supports_prepacking self.are_node_inputs_supported_fn = are_node_inputs_supported_fn + self.pick_io_storage_fn = pick_io_storage_fn def make_op_repsets( self, op_node: torch.fx.Node, texture_limits: utils.ImageExtents = utils.DEFAULT_TEXTURE_LIMITS, ) -> utils.OpRepSets: - return utils.OpRepSets( - self.inputs_storage, self.outputs_storage, op_node, texture_limits - ) + inputs_storage = self.inputs_storage + outputs_storage = self.outputs_storage + if self.pick_io_storage_fn is not None: + i_storage, o_storage = self.pick_io_storage_fn(op_node) + inputs_storage = utils.TensorRepSetList(i_storage) + outputs_storage = utils.TensorRepSetList(o_storage) + + return utils.OpRepSets(inputs_storage, outputs_storage, op_node, texture_limits) ####################### @@ -121,6 +124,7 @@ def update_features_impl(op: OpKey): # Symbolic integer ops torch.ops.aten.sym_size.int, operator.add, + operator.sub, operator.lt, operator.gt, operator.ge, @@ -140,13 +144,9 @@ def register_ephemeral_op(): @update_features( [ - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, exir_ops.edge.quantized_decomposed.quantize_per_channel.default, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, - exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, exir_ops.edge.quantized_decomposed.quantize_per_token.default, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, exir_ops.edge.quantized_decomposed.dequantize_per_token.default, ] ) @@ -220,6 +220,18 @@ def register_binary_op(): ) +@update_features( + [ + exir_ops.edge.aten.pow.Tensor_Scalar, + ] +) +def register_binary_scalar_op(): + return OpFeatures( + inputs_storage=utils.ANY_STORAGE, + supports_resize=True, + ) + + @update_features( [ exir_ops.edge.aten.abs.default, @@ -277,27 +289,9 @@ def check_to_copy_node(node: torch.fx.Node) -> bool: @update_features(exir_ops.edge.dim_order_ops._to_dim_order_copy.default) def register_to_copy_dim_order_op(): - # Currently there is no "real" implementation for to_dim_order_copy, but it can be - # removed as long as the operator is not changing the dtype, i.e. the operator call - # is modifying the dim order only. Therefore, check that the input and output dtypes - # are the same, if so the operator is safe to remove. 
- def check_dim_order_copy_node(node: torch.fx.Node) -> bool: - in_arg = node.args[0] - if not isinstance(in_arg, torch.fx.Node): - return False - - in_tensor = in_arg.meta.get("val", None) - out_tensor = node.meta.get("val", None) - - if in_tensor.dtype != out_tensor.dtype: - return False - - return True - return OpFeatures( - inputs_storage=utils.ANY_STORAGE, + inputs_storage=utils.ANY_BUFFER, supports_resize=True, - are_node_inputs_supported_fn=check_dim_order_copy_node, ) @@ -400,56 +394,131 @@ def register_softmax_op(): ) +def get_dims_reduced(node: torch.fx.Node) -> Union[int, List[int]]: + ndim = utils.ndim_of(node.args[0]) + assert ndim is not None + dims_reduced = None + if len(node.args) >= 2: + dims_reduced = node.args[1] + + # If dim_list is None, return a list containing all the dims of the tensor + if dims_reduced is None: + dims_reduced = list(range(ndim)) + + # Special case for reducing tensors with shape [1, N] - this is equivalent to + # reducing the last dim. + if utils.is_unsqueezed_vector(node) and ndim == 2: + dims_reduced = 1 + + if isinstance(dims_reduced, (list, tuple)) and len(dims_reduced) == 1: + dims_reduced = dims_reduced[0] + + assert isinstance(dims_reduced, (int, list, tuple)) + return utils.normalize_dims(dims_reduced, ndim) + + +def get_keepdim_setting(node: torch.fx.Node) -> bool: + for arg in node.args: + if isinstance(arg, bool): + return arg + + # Assume false by default + return False + + +def is_reduce_node_supported_by_per_row_impl(node: torch.fx.Node) -> bool: + """ + Checks if a reduction node is supported by the Vulkan backend's reduce per row + special case implementation. + """ + input_ndim = utils.ndim_of(node.args[0]) + assert input_ndim is not None + dims_reduced = get_dims_reduced(node) + + return dims_reduced == input_ndim - 1 + + +def is_reduce_node_supported_by_general_impl(node: torch.fx.Node) -> bool: + dims_reduced = get_dims_reduced(node) + # Only 1D and 2D reductions are supported at the moment. + if isinstance(dims_reduced, (list, tuple)) and len(dims_reduced) > 2: + return False + + keepdim = get_keepdim_setting(node) + # keepdim = False is not supported yet for general implementation + if isinstance(keepdim, bool) and not keepdim: + return False + + return True + + +def is_reduce_node_supported(node: torch.fx.Node) -> bool: + return is_reduce_node_supported_by_per_row_impl( + node + ) or is_reduce_node_supported_by_general_impl(node) + + +def pick_storage_for_reduce(node: torch.fx.Node): + inputs_storage = utils.NO_STORAGE + outputs_storage = utils.NO_STORAGE + + ndim = utils.ndim_of(node.args[0]) + dim_list = get_dims_reduced(node) + + if is_reduce_node_supported_by_general_impl(node): + inputs_storage = inputs_storage.make_union(utils.ANY_TEXTURE) + outputs_storage = inputs_storage + + # For 1D reductions of the last dim, a special reduce per row case is implemented + # for buffer backed tensors. 
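# --- Illustrative aside, not part of the diff above. ---
# Sketch of the dim normalization that the reduce registration relies on: a None dim
# means "reduce everything", a single-element list collapses to an int, and negative
# dims wrap around (utils.normalize_dims is assumed to behave roughly like this).
# Reducing the last dim of a rank-n input (result == n - 1) hits the per-row buffer path.
def dims_reduced(ndim, dim=None):
    dims = list(range(ndim)) if dim is None else dim
    if isinstance(dims, (list, tuple)) and len(dims) == 1:
        dims = dims[0]
    if isinstance(dims, int):
        return dims % ndim
    return [d % ndim for d in dims]

assert dims_reduced(3, [-1]) == 2          # last-dim reduction -> per-row special case
assert dims_reduced(3, None) == [0, 1, 2]  # full reduction
assert dims_reduced(3, [0, -1]) == [0, 2]  # 2D reduction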
+ if is_reduce_node_supported_by_per_row_impl(node): + inputs_storage = inputs_storage.make_union(utils.CONTIGUOUS_BUFFER) + outputs_storage = inputs_storage + return inputs_storage, outputs_storage + + # For 2D reductions, the packed dimension cannot be one of the reduced dims + if isinstance(dim_list, (list, tuple)) and len(dim_list) == 2: + # pyre-ignore[6] + reduce_dim1_whcn = utils.nchw_dim_to_whcn_dim(dim_list[0], ndim) + # pyre-ignore[6] + reduce_dim2_whcn = utils.nchw_dim_to_whcn_dim(dim_list[1], ndim) + + possible_packed_dims = {0, 1, 2} + possible_packed_dims.discard(reduce_dim1_whcn) + possible_packed_dims.discard(reduce_dim2_whcn) + + packed_dim = possible_packed_dims.pop() + assert packed_dim in [0, 1, 2] + + if packed_dim == 0: + inputs_storage = utils.WIDTH_PACKED_TEXTURE + outputs_storage = utils.WIDTH_PACKED_TEXTURE + elif packed_dim == 1: + inputs_storage = utils.HEIGHT_PACKED_TEXTURE + outputs_storage = utils.HEIGHT_PACKED_TEXTURE + else: + inputs_storage = utils.CHANNELS_PACKED_TEXTURE + outputs_storage = utils.CHANNELS_PACKED_TEXTURE + + return inputs_storage, outputs_storage + + @update_features( [ exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.sum.dim_IntList, exir_ops.edge.aten.amax.default, exir_ops.edge.aten.amin.default, + exir_ops.edge.aten.argmax.default, + exir_ops.edge.aten.argmin.default, ] ) def register_reduce_op(): - def check_reduce_node(node: torch.fx.Node) -> bool: - dim_list = node.args[1] - if isinstance(dim_list, list) and len(dim_list) > 2: - return False - - if isinstance(dim_list, list) and len(dim_list) == 2: - # Try to get the memory layout for this node - try: - memory_layout = utils.get_node_memory_layout(node) - - # If we have memory layout information, check if any dimension in dim_list corresponds to a packed dimension - if ( - memory_layout is not None - and memory_layout != VkMemoryLayout.DEFAULT_LAYOUT - ): - # For now only default layout is supported for 2D reduction. - # Because we can't determine if the input is NCHW or NHWC here, - # assume the reduction dimension is packed so we cannot support it. 
- return False - except (AssertionError, KeyError, AttributeError): - # If we can't get memory layout information, we'll assume the dims aren't packed - pass - - def try_find_keepdim_arg(node: torch.fx.Node) -> bool: - for arg in node.args: - if isinstance(arg, bool): - return arg - - # Assume false by default - return False - - keepdim = try_find_keepdim_arg(node) - if isinstance(keepdim, bool) and not keepdim: - return False - - return True - return OpFeatures( inputs_storage=utils.ANY_TEXTURE, supports_resize=True, - are_node_inputs_supported_fn=check_reduce_node, + are_node_inputs_supported_fn=is_reduce_node_supported, + pick_io_storage_fn=pick_storage_for_reduce, ) @@ -474,6 +543,24 @@ def register_2d_pool_op(): ] ) def register_convolution_op(): + def check_conv_node(node: torch.fx.Node) -> bool: + x = node.args[0] + assert isinstance(x, torch.fx.Node) + x_shape = x.meta["val"].size() + # 4-D input implies 2D convolution + if len(x_shape) == 4: + batches = x.meta["val"].size()[0] + if batches != 1: + return False + # 3-D input implies 1D convolution + if len(x_shape) == 3: + transpose = node.args[6] + # Transposed 1D convolution is not supported yet + if transpose: + return False + + return True + return OpFeatures( inputs_storage=[ utils.CHANNELS_PACKED_TEXTURE, # input @@ -490,23 +577,27 @@ def register_convolution_op(): ], supports_resize=True, supports_prepacking=True, + are_node_inputs_supported_fn=check_conv_node, ) @update_features( [ - exir_ops.edge.et_vk.conv2d_q8ta_q8csw.default, + exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to.default, + exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to_dw.default, ] ) def register_quantized_conv_op(): return OpFeatures( inputs_storage=[ - utils.CHANNELS_PACKED_TEXTURE, # input + utils.PACKED_INT8_4W4C_BUFFER, # input utils.NO_STORAGE, # input_scale (non tensor) utils.NO_STORAGE, # input_zero_point (non tensor) utils.NO_STORAGE, # weight (prepacked) utils.NO_STORAGE, # weight_sums (prepacked) utils.NO_STORAGE, # weight_scales (prepacked) + utils.NO_STORAGE, # output_scale (non tensor) + utils.NO_STORAGE, # output_zero_point (non tensor) utils.NO_STORAGE, # bias (prepacked) utils.NO_STORAGE, # kernel_size (non tensor) utils.NO_STORAGE, # stride (non tensor) @@ -520,10 +611,57 @@ def register_quantized_conv_op(): ) +@update_features( + [ + exir_ops.edge.et_vk.add_q8ta_q8ta_q8to.default, + ] +) +def register_quantized_binary_op(): + return OpFeatures( + inputs_storage=utils.PACKED_INT8_4W4C_BUFFER, + supports_resize=False, + supports_prepacking=True, + ) + + +@update_features( + [ + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, + ] +) +def register_quantize_op(): + return OpFeatures( + inputs_storage=[ + utils.CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER, + ], + outputs_storage=[ + utils.PACKED_INT8_4W4C_BUFFER, + ], + ) + + +@update_features( + [ + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, + ] +) +def register_dequantize_op(): + return OpFeatures( + inputs_storage=[ + utils.PACKED_INT8_4W4C_BUFFER, + ], + outputs_storage=[ + utils.CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER, + ], + ) + + @update_features("llama::sdpa_with_kv_cache") def register_sdpa_with_kv_cache_op(): return OpFeatures( - inputs_storage=utils.WIDTH_PACKED_TEXTURE, + inputs_storage=utils.CONTIGUOUS_ANY, supports_resize=True, supports_prepacking=True, ) @@ -545,7 +683,7 @@ def register_sdpa_ops(): 
@update_features(exir_ops.edge.et_vk.apply_rotary_emb.default) def register_rotary_emb_op(): return OpFeatures( - inputs_storage=utils.WIDTH_PACKED_TEXTURE, + inputs_storage=utils.CONTIGUOUS_ANY, supports_resize=True, ) @@ -569,6 +707,7 @@ def register_view_ops(): exir_ops.edge.aten.unsqueeze_copy.default, exir_ops.edge.aten.clone.default, exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.gather.default, ] ) def register_view_ops_with_buffer_meta(): @@ -601,6 +740,7 @@ def register_cat_op(): [ exir_ops.edge.aten.select_copy.int, exir_ops.edge.aten.slice_copy.Tensor, + exir_ops.edge.aten.split_with_sizes_copy.default, ] ) def register_transfer_ops(): @@ -643,10 +783,7 @@ def register_ported_op(): # Ops ported from PyTorch Vulkan backend. These ops are in a separate registry because they support all packed dimensions @update_features( [ - # Tensor combination exir_ops.edge.aten.repeat.default, - exir_ops.edge.aten.split_with_sizes_copy.default, - exir_ops.edge.aten.split.Tensor, ] ) def register_ported_op_all_packed_dims(): @@ -666,6 +803,7 @@ def register_ported_ops_with_prepacking(): return OpFeatures( inputs_storage=utils.CHANNELS_PACKED_TEXTURE, supports_prepacking=True, + supports_resize=True, ) @@ -696,6 +834,7 @@ def register_ported_ops_with_prepacking_all_dims(): return OpFeatures( inputs_storage=utils.ANY_TEXTURE, supports_prepacking=True, + supports_resize=True, ) diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index e5b2d0f7864..bc3bf14bf14 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -36,7 +36,7 @@ Partitioner, PartitionResult, ) -from executorch.exir.backend.utils import tag_constant_data +from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer from executorch.exir.dialects._ops import ops as exir_ops from torch.export.exported_program import ExportedProgram @@ -59,6 +59,7 @@ def __init__( texture_limits: utils.ImageExtents, buffer_limit: int, require_dynamic_shape: bool = False, + skip_bool_tensors: bool = False, operator_blocklist: Optional[Set[OpKey]] = None, operator_allowlist: Optional[Set[OpKey]] = None, fusable_subgraphs: Optional[List[PatternMatch]] = None, @@ -69,6 +70,7 @@ def __init__( self.texture_limits: utils.ImageExtents = texture_limits self.buffer_limit = buffer_limit self.require_dynamic_shapes = require_dynamic_shape + self.skip_bool_tensors = skip_bool_tensors self.operator_blocklist: Set[OpKey] = ( operator_blocklist if operator_blocklist is not None else set() ) @@ -117,6 +119,11 @@ def op_node_is_compatible( # noqa: C901: Function is too complex return False, "no operator implementation" features = get_op_features(target) + # bool tensors are internally represented with int8 buffers, which may not be + # supported by some GPUs. Therefore, provide the option to skip these tensors. + if self.skip_bool_tensors and utils.op_contains_bool_tensor(node): + return False, f"op {utils.node_io_str(node)} contains bool tensor" + # Get the possible tensor representations for each tensor participating in the # this operator. Then check that all tensors are representable as either a # buffer or texture. 
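# Illustrative sketch (not part of the patch): roughly the decision that a bool-tensor
# check such as utils.op_contains_bool_tensor has to make for the new skip_bool_tensors
# option above. The node is modelled as plain dtype lists rather than a torch.fx.Node,
# and the exact inputs/output inspection is an assumption for illustration.
import torch

def contains_bool_tensor_example(input_dtypes, output_dtype) -> bool:
    # Bool tensors are backed by int8 buffers on the GPU, which some devices cannot
    # handle, so an op touching any bool tensor can be skipped by the partitioner.
    return output_dtype == torch.bool or any(dt == torch.bool for dt in input_dtypes)

# With skip_bool_tensors=True the partitioner would leave this op on the CPU:
assert contains_bool_tensor_example([torch.float32, torch.bool], torch.float32)
assert not contains_bool_tensor_example([torch.float32], torch.float32)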
@@ -177,36 +184,6 @@ def is_linear_permute(self, node: torch.fx.Node) -> Tuple[bool, bool]: return False, False - def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> Tuple[bool, bool]: - """ - Scalar tensors are usually converted to scalar values in the graph via` - scalar_tensor[0].item()` in Python, which translates to a chain of - `local_scalar_dense(torch.select.int(scalar_tensor, 0, 0))` in the graph. - This function marks the entire chain as supported by the Vulkan delegate. - - Later, within vulkan_preprocess there will be a graph transform which replaces - the chain with passing in the scalar tensor directly. - - Similar to the `is_linear_permute` function, this function has 2 return values. - """ - if node.target == exir_ops.edge.aten.select_copy.int: - if len(node.users) != 1: - return False, False - # pyre-ignore - if node.args[0].meta["val"].numel() != 1: - return False, False - - local_scalar_dense = list(node.users.keys())[0] - if local_scalar_dense.target != torch.ops.aten._local_scalar_dense.default: - return False, False - - return self.is_in_local_scalar_dense_chain(local_scalar_dense) - - if node.target == torch.ops.aten._local_scalar_dense.default: - return True, all(self.node_is_compatible(user)[0] for user in node.users) - - return False, False - def log_skip(self, node: torch.fx.Node, reason: str) -> None: if node.op == "call_function": logger.info( @@ -254,15 +231,6 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool: # noqa: C901 self.log_skip(node, "permute node of non compatible linear node") return False - is_in_local_scalar_dense_chain, dst_node_is_compatible = ( - self.is_in_local_scalar_dense_chain(node) - ) - if is_in_local_scalar_dense_chain and dst_node_is_compatible: - return True - elif is_in_local_scalar_dense_chain: - self.log_skip(node, "local scalar dense of incompatible op node") - return False - features = None if target not in vulkan_supported_ops: # For some ops, i.e. 
custom ops the name is registered instead of the @@ -397,6 +365,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: texture_limits, buffer_limit, require_dynamic_shape=self.options.get("require_dynamic_shapes", False), + skip_bool_tensors=self.options.get("skip_bool_tensors", False), operator_blocklist=self.operator_blocklist, operator_allowlist=self.operator_allowlist, fusable_subgraphs=fusable_subgraphs, @@ -419,6 +388,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: logger.info(f"Found {pl} Vulkan subgraphs to be partitioned.") tag_constant_data(exported_program) + tag_mutated_buffer(exported_program) return PartitionResult( tagged_exported_program=exported_program, partition_tags=partition_tags diff --git a/backends/vulkan/patterns/TARGETS b/backends/vulkan/patterns/TARGETS index 791edf58984..3baf7c9e251 100644 --- a/backends/vulkan/patterns/TARGETS +++ b/backends/vulkan/patterns/TARGETS @@ -11,6 +11,9 @@ runtime.python_library( "rope.py", "quantized_linear.py", "quantized_convolution.py", + "quantized_binary.py", + "sdpa.py", + "select_as_symint.py", ], visibility = [ "//executorch/backends/...", diff --git a/backends/vulkan/patterns/__init__.py b/backends/vulkan/patterns/__init__.py index 8ffad98b3c3..9b875def944 100644 --- a/backends/vulkan/patterns/__init__.py +++ b/backends/vulkan/patterns/__init__.py @@ -6,12 +6,18 @@ from typing import List +import executorch.backends.vulkan.patterns.quantized_binary # noqa + import executorch.backends.vulkan.patterns.quantized_convolution # noqa import executorch.backends.vulkan.patterns.quantized_linear # noqa import executorch.backends.vulkan.patterns.rope # noqa +import executorch.backends.vulkan.patterns.sdpa # noqa + +import executorch.backends.vulkan.patterns.select_as_symint # noqa + import torch from executorch.backends.vulkan.patterns.pattern_registry import ( diff --git a/backends/vulkan/patterns/quantized_binary.py b/backends/vulkan/patterns/quantized_binary.py new file mode 100644 index 00000000000..da4985b931d --- /dev/null +++ b/backends/vulkan/patterns/quantized_binary.py @@ -0,0 +1,161 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Optional + +import executorch.backends.vulkan.utils as utils + +import torch + +from executorch.backends.vulkan.patterns.pattern_registry import ( + PatternMatch, + register_pattern_detector, + register_pattern_replacement, +) + +from executorch.exir import ExportedProgram +from executorch.exir.dialects._ops import ops as exir_ops + + +class QuantizedBinaryMatch(PatternMatch): + def __init__(self, binary_node: torch.fx.Node) -> None: + self.anchor_node = binary_node + self.match_found = False + self.all_nodes = [self.anchor_node] + + # Extract alpha parameter if it exists (for add operations) + self.alpha = 1.0 + if len(binary_node.args) > 2 and binary_node.args[2] is not None: + # Alpha is typically a scalar value + if isinstance(binary_node.args[2], (int, float)): + self.alpha = binary_node.args[2] + + # Identify input nodes - both should be dequantize nodes for static quantization + if len(binary_node.args) < 2: + return + + input_a_node = binary_node.args[0] + assert isinstance(input_a_node, torch.fx.Node) + input_b_node = binary_node.args[1] + assert isinstance(input_b_node, torch.fx.Node) + + # Both arguments must be dequant nodes for static quantization + if not utils.is_dequant_node(input_a_node) or not utils.is_dequant_node( + input_b_node + ): + return + + self.dequantize_input_a_node = input_a_node + self.dequantize_input_b_node = input_b_node + + # Extract quantization parameters for input A + self.quantize_input_a_node = self.dequantize_input_a_node.args[0] + self.input_a_scales_node = self.dequantize_input_a_node.args[1] + self.input_a_zeros_node = self.dequantize_input_a_node.args[2] + + # Extract quantization parameters for input B + self.quantize_input_b_node = self.dequantize_input_b_node.args[0] + self.input_b_scales_node = self.dequantize_input_b_node.args[1] + self.input_b_zeros_node = self.dequantize_input_b_node.args[2] + + self.all_nodes.extend( + [self.dequantize_input_a_node, self.dequantize_input_b_node] + ) + + # Identify output node + self.output_node = self.anchor_node + + # The binary operation output must have only one user; it will be either a relu node + # or a quantize node. 
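# Illustrative sketch (not part of the patch): the graph shape QuantizedBinaryMatch
# looks for, written with a minimal stand-in node type instead of torch.fx.Node.
# The string targets below are placeholders for the corresponding edge dialect ops.
from dataclasses import dataclass, field
from typing import Any, List

@dataclass
class ToyNode:
    target: str
    args: List[Any] = field(default_factory=list)
    users: List["ToyNode"] = field(default_factory=list)

def matches_quantized_add_example(add_node: ToyNode) -> bool:
    # Both inputs must come straight from dequantize nodes (static quantization).
    if not all(a.target == "dequantize_per_tensor" for a in add_node.args[:2]):
        return False
    # The add must feed exactly one user: an optional relu, then a quantize node.
    if len(add_node.users) != 1:
        return False
    tail = add_node.users[0]
    if tail.target == "relu":
        if len(tail.users) != 1:
            return False
        tail = tail.users[0]
    return tail.target == "quantize_per_tensor"

dq_a = ToyNode("dequantize_per_tensor")
dq_b = ToyNode("dequantize_per_tensor")
q_out = ToyNode("quantize_per_tensor")
relu = ToyNode("relu", users=[q_out])
add = ToyNode("add", args=[dq_a, dq_b], users=[relu])
assert matches_quantized_add_example(add)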
+ if len(self.output_node.users) != 1: + return + + cur_node = list(self.output_node.users)[0] + self.relu_node = None + if cur_node.target == exir_ops.edge.aten.relu.default: + self.relu_node = cur_node + self.all_nodes.append(self.relu_node) + # If there's a relu, get its user (should be the quantize node) + if len(cur_node.users) != 1: + return + cur_node = list(cur_node.users)[0] + + if not utils.is_quant_node(cur_node): + return + + self.quantize_output_node = cur_node + self.output_scales_node = self.quantize_output_node.args[1] + self.output_zeros_node = self.quantize_output_node.args[2] + + self.all_nodes.append(self.quantize_output_node) + + self.match_found = True + + +# Define the binary operation anchor nodes that we support +binary_anchor_nodes = { + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.add_.Tensor, +} + + +@register_pattern_detector("quantized_binary") +def find_quantized_binary_patterns( + node: torch.fx.Node, +) -> Optional[QuantizedBinaryMatch]: + if node.target not in binary_anchor_nodes: + return None + + matched_pattern = QuantizedBinaryMatch(node) + if matched_pattern.match_found: + return matched_pattern + + return None + + +## +## Pattern Replacement +## + + +@register_pattern_replacement("quantized_binary") +def make_add_q8ta_q8ta_q8to_custom_op( + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + match: QuantizedBinaryMatch, +): + # Determine the operation type based on the anchor node + op_target = None + if match.anchor_node.target in { + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.add_.Tensor, + }: + op_target = exir_ops.edge.et_vk.add_q8ta_q8ta_q8to.default + else: + # For future binary operations, add more mappings here + raise NotImplementedError( + f"Unsupported binary operation: {match.anchor_node.target}" + ) + + with graph_module.graph.inserting_before(match.output_node): + qbinary_node = graph_module.graph.create_node( + "call_function", + op_target, + args=( + match.quantize_input_a_node, + match.quantize_input_b_node, + match.input_a_scales_node, + match.input_a_zeros_node, + match.input_b_scales_node, + match.input_b_zeros_node, + match.output_scales_node, + match.output_zeros_node, + match.alpha, # Alpha parameter for scaling + ), + ) + + qbinary_node.meta["val"] = match.output_node.meta["val"] + match.quantize_output_node.replace_all_uses_with(qbinary_node) diff --git a/backends/vulkan/patterns/quantized_convolution.py b/backends/vulkan/patterns/quantized_convolution.py index 65b51b5e103..522a19c58d6 100644 --- a/backends/vulkan/patterns/quantized_convolution.py +++ b/backends/vulkan/patterns/quantized_convolution.py @@ -76,11 +76,13 @@ def __init__(self, conv_node: torch.fx.Node) -> None: # Identify output node self.output_node = self.anchor_node - out_channels = self.output_node.meta["val"].shape[-1] - # The implementation requires that for grouped convolutions, a group does not - # cross any texel boundary. The output channels per group must be a multiple of - # 4. If this is not true, then don't match the pattern. - if self.groups > 1 and (out_channels / self.groups) % 4 == 0: + out_channels = self.output_node.meta["val"].shape[-3] + # The implementation requires that for non-depthwise grouped convolutions, a + # group does not cross the texel boundary. The output channels per group must be + # a multiple of 4. If this is not true, then don't match the pattern. 
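# Illustrative sketch (not part of the patch): the arithmetic behind the grouped
# convolution check above. Non-depthwise grouped convs are only matched when each
# group's output channels fill whole texels, i.e. are a multiple of 4.
def grouped_conv_is_matchable_example(out_channels: int, groups: int) -> bool:
    if groups > 1 and groups < out_channels:           # grouped, but not depthwise
        return (out_channels // groups) % 4 == 0       # 4 output channels per texel
    return True                                        # regular or depthwise conv

assert grouped_conv_is_matchable_example(out_channels=64, groups=4)       # 16 per group
assert not grouped_conv_is_matchable_example(out_channels=64, groups=32)  # 2 per group
assert grouped_conv_is_matchable_example(out_channels=64, groups=64)      # depthwise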
+ if (self.groups > 1 and self.groups < out_channels) and ( + out_channels / self.groups + ) % 4 != 0: return # Identify bias node, if applicable @@ -93,23 +95,37 @@ def __init__(self, conv_node: torch.fx.Node) -> None: self.all_nodes.extend(arg_chain) # Identify input node - self.fp_input_node, self.quantize_input_node, dq_node = ( - utils.maybe_skip_q_dq_arg_chain(self.anchor_node.args[0]) - ) - assert self.fp_input_node is not None - self.all_nodes.append(self.fp_input_node) - assert self.quantize_input_node is not None - assert dq_node is not None - - self.input_scales_node = self.quantize_input_node.args[1] - self.input_zeros_node = self.quantize_input_node.args[2] - - self.all_nodes.extend( - [ - self.quantize_input_node, - dq_node, - ] - ) + primary_input_node = self.anchor_node.args[0] + assert isinstance(primary_input_node, torch.fx.Node) + # Argument must be a dequant node for static quantization + if not utils.is_dequant_node(primary_input_node): + return + + self.dequantize_input_node = primary_input_node + self.quantize_input_node = self.dequantize_input_node.args[0] + + self.input_scales_node = self.dequantize_input_node.args[1] + self.input_zeros_node = self.dequantize_input_node.args[2] + + self.all_nodes.extend([self.dequantize_input_node]) + + # The convolution output must have only one user; it will be either a relu node + # or a dequantize node. + if len(self.output_node.users) != 1: + return + + cur_node = list(self.output_node.users)[0] + self.relu_node = None + if cur_node.target == exir_ops.edge.aten.relu.default: + self.relu_node = cur_node + cur_node = list(cur_node.users)[0] + + if not utils.is_quant_node(cur_node): + return + + self.quantize_output_node = cur_node + self.output_scales_node = self.quantize_output_node.args[1] + self.output_zeros_node = self.quantize_output_node.args[2] self.match_found = True @@ -161,13 +177,26 @@ def make_conv2d_q8ta_q8csw_custom_op( bias_tensor = get_param_tensor(ep, match.bias_node) assert bias_tensor is not None - OC, IC, H, W = weight_tensor.shape + OC, IC_per_group, H, W = weight_tensor.shape - # Reshape weight tensor from (OC, IC, H, W) to (OC, H * W * IC) (i.e. matrix format) - # This prepares the weights for Im2Col-based convolution - weight_tensor = ( - weight_tensor.permute(0, 2, 3, 1).contiguous().view(OC, H * W * IC).contiguous() - ) + is_depthwise_conv = IC_per_group == 1 and match.groups == OC + + if is_depthwise_conv: + assert OC % 4 == 0, "depthwise conv requires that OC is divisible by 4" + # Depthwise convs use a specialized layout; the weight tensor is reshaped to + # (H, W, OC) + weight_tensor = ( + weight_tensor.permute(2, 3, 1, 0).contiguous().view(H, W, OC).contiguous() + ) + else: + # Reshape weight tensor from (OC, IC_per_group, H, W) to (OC, H * W * IC_per_group) + # (i.e. matrix format). This prepares the weights for Im2Col-based convolution. + weight_tensor = ( + weight_tensor.permute(0, 2, 3, 1) + .contiguous() + .view(OC, H * W * IC_per_group) + .contiguous() + ) # Need to make sure that OC dim is a multiple of 4 so that data load/stores are well # aligned with texel boundaries. 
Add padding to align to the next multiple of 4 if @@ -178,6 +207,7 @@ def make_conv2d_q8ta_q8csw_custom_op( utils.align_width_and_update_state_dict( ep, match.weight_scales_node, weight_scales_tensor ) + if bias_tensor is not None: utils.align_width_and_update_state_dict(ep, match.bias_node, bias_tensor) @@ -185,7 +215,7 @@ def make_conv2d_q8ta_q8csw_custom_op( with graph_module.graph.inserting_before(first_graph_node): qweight_tensor_name = utils.get_tensor_name(ep, match.weight_node) # Pre-compute the weight sums which are needed to apply activation zero point - # when using integer accumulation. For the reshaped 2D weight matrix (IC * H * W, OC), + # when using integer accumulation. For the reshaped 2D weight matrix (IC_per_group * H * W, OC), # sum over dimension 0 to get sums per output channel sum_per_output_channel = weight_tensor.sum(dim=1).to(torch.int32).contiguous() sums_name = qweight_tensor_name + "_sums" @@ -201,16 +231,22 @@ def make_conv2d_q8ta_q8csw_custom_op( ) with graph_module.graph.inserting_before(match.output_node): + op_target = exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to.default + if is_depthwise_conv: + op_target = exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to_dw.default + qconv_node = graph_module.graph.create_node( "call_function", - exir_ops.edge.et_vk.conv2d_q8ta_q8csw.default, + op_target, args=( - match.fp_input_node, + match.quantize_input_node, match.input_scales_node, match.input_zeros_node, match.weight_node, weight_sums_node, match.weight_scales_node, + match.output_scales_node, + match.output_zeros_node, match.bias_node, # Add bias after weight_scales [H, W], # Pass kernel size information before stride match.stride, @@ -221,4 +257,4 @@ def make_conv2d_q8ta_q8csw_custom_op( ) qconv_node.meta["val"] = match.output_node.meta["val"] - match.output_node.replace_all_uses_with(qconv_node) + match.quantize_output_node.replace_all_uses_with(qconv_node) diff --git a/backends/vulkan/patterns/quantized_linear.py b/backends/vulkan/patterns/quantized_linear.py index 882d0d41e6d..374e29c634d 100644 --- a/backends/vulkan/patterns/quantized_linear.py +++ b/backends/vulkan/patterns/quantized_linear.py @@ -92,9 +92,11 @@ def __init__(self, mm_node: torch.fx.Node) -> None: return # Identify input node - self.fp_input_node, self.quantize_input_node, dq_node = ( - utils.maybe_skip_q_dq_arg_chain(self.anchor_node.args[0]) - ) + ( + self.fp_input_node, + self.quantize_input_node, + dq_node, + ) = utils.maybe_skip_q_dq_arg_chain(self.anchor_node.args[0]) assert self.fp_input_node is not None self.all_nodes.append(self.fp_input_node) @@ -386,7 +388,7 @@ def make_linear_dq8ca_q4gsw_op( weight_sums_node = create_constant_placeholder( exp_program=ep, graph=graph_module.graph, - kind=InputKind.CONSTANT_TENSOR, + kind=InputKind.PARAMETER, name=sums_name, data=sum_per_quant_group, ) @@ -429,7 +431,7 @@ def make_linear_q8ta_q8csw_custom_op( weight_sums_node = create_constant_placeholder( exp_program=ep, graph=graph_module.graph, - kind=InputKind.CONSTANT_TENSOR, + kind=InputKind.PARAMETER, name=sums_name, data=sum_per_output_channel, ) diff --git a/backends/vulkan/patterns/sdpa.py b/backends/vulkan/patterns/sdpa.py new file mode 100644 index 00000000000..f67799f9b76 --- /dev/null +++ b/backends/vulkan/patterns/sdpa.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
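# Illustrative sketch (not part of the patch): how the weight repacking done in
# make_conv2d_q8ta_q8csw_custom_op above can be reproduced on a plain tensor. The
# shapes are made up for the example; the per-output-channel sums feed the activation
# zero-point correction used with integer accumulation.
import torch

OC, IC_per_group, H, W = 8, 3, 3, 3
weight = torch.randint(-128, 127, (OC, IC_per_group, H, W), dtype=torch.int8)

# Im2Col-style layout: one row of length H * W * IC_per_group per output channel.
weight_matrix = weight.permute(0, 2, 3, 1).contiguous().view(OC, H * W * IC_per_group)
weight_sums = weight_matrix.to(torch.int32).sum(dim=1)  # one sum per output channel
assert weight_matrix.shape == (OC, H * W * IC_per_group)
assert weight_sums.shape == (OC,)

# Depthwise case (IC_per_group == 1, groups == OC): the weights are instead laid out
# as (H, W, OC), with OC kept at a multiple of 4 so loads stay texel-aligned.
dw_weight = torch.randint(-128, 127, (OC, 1, H, W), dtype=torch.int8)
dw_packed = dw_weight.permute(2, 3, 1, 0).contiguous().view(H, W, OC)
assert dw_packed.shape == (H, W, OC)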
+ +from typing import Any, Optional + +import executorch.backends.vulkan.utils as utils + +import torch + +from executorch.backends.vulkan.patterns.pattern_registry import ( + PatternMatch, + register_pattern_detector, + register_pattern_replacement, +) + +from executorch.exir import ExportedProgram + + +def is_update_cache_node(node: Any) -> bool: + return utils.node_has_target(node, "llama::update_cache") + + +def is_custom_sdpa_node(node: Any) -> bool: + return utils.node_has_target(node, "llama::custom_sdpa") + + +def is_sdpa_with_kv_cache_node(node: Any) -> bool: + return utils.node_has_target(node, "llama::sdpa_with_kv_cache") + + +class CausalSDPAMatch(PatternMatch): + def __init__(self, custom_sdpa_node: torch.fx.Node) -> None: + self.anchor_node = custom_sdpa_node + self.match_found = False + self.all_nodes = [self.anchor_node] + + # llama.custom_sdpa has signature: + # custom_sdpa(query, key_cache, value_cache, start_pos, attn_mask, dropout_p, is_causal, scale) -> output + if len(custom_sdpa_node.args) < 4: + return + + self.query_node = custom_sdpa_node.args[0] + self.key_cache_node = custom_sdpa_node.args[1] + self.value_cache_node = custom_sdpa_node.args[2] + self.start_pos_node = custom_sdpa_node.args[3] + self.attn_mask_node = custom_sdpa_node.args[4] + self.dropout_p_node = custom_sdpa_node.args[5] + self.is_causal_node = custom_sdpa_node.args[6] + if len(custom_sdpa_node.args) > 7: + self.scale_node = custom_sdpa_node.args[7] + else: + self.scale_node = None + + # try to find update key cache node + self.update_key_cache_node = None + for user in self.key_cache_node.users: + if is_update_cache_node(user): + self.update_key_cache_node = user + break + + self.key_projection_node = None + if self.update_key_cache_node is not None: + self.key_projection_node = self.update_key_cache_node.args[0] + + # find update value cache node + self.update_value_cache_node = None + for user in self.value_cache_node.users: + if is_update_cache_node(user): + self.update_value_cache_node = user + break + + self.value_projection_node = None + if self.update_value_cache_node is not None: + self.value_projection_node = self.update_value_cache_node.args[0] + + # We have additional optional arguments but we don't need to capture them + # since the new op doesn't use them + + self.match_found = True + + +@register_pattern_detector("causal_sdpa") +def find_causal_sdpa_patterns( + node: torch.fx.Node, +) -> Optional[CausalSDPAMatch]: + if not is_custom_sdpa_node(node): + return None + + matched_pattern = CausalSDPAMatch(node) + if matched_pattern.match_found: + return matched_pattern + + return None + + +## +## Pattern Replacement +## + + +def find_singleton_start_pos_node(graph_module: torch.fx.GraphModule): + for node in graph_module.graph.nodes: + if is_update_cache_node(node): + return node.args[2] + + if is_sdpa_with_kv_cache_node(node): + return node.args[5] + + raise Exception( + "Could not find an instance of llama::update_cache or sdpa_with_kv_cache" + ) + + +@register_pattern_replacement("causal_sdpa") +def replace_custom_sdpa_with_causal_sdpa( + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + match: CausalSDPAMatch, +): + assert match.update_key_cache_node is not None + assert match.key_projection_node is not None + assert match.update_value_cache_node is not None + assert match.value_projection_node is not None + + singleton_start_pos_node = find_singleton_start_pos_node(graph_module) + + with graph_module.graph.inserting_before(match.anchor_node): + new_node = 
graph_module.graph.create_node( + "call_function", + torch.ops.llama.sdpa_with_kv_cache.default, + args=( + match.query_node, + match.key_projection_node, + match.value_projection_node, + match.key_cache_node, + match.value_cache_node, + singleton_start_pos_node, + 1, + match.attn_mask_node, + match.dropout_p_node, + match.is_causal_node, + match.scale_node, + ), + ) + + new_node.meta["val"] = match.anchor_node.meta["val"] + match.anchor_node.replace_all_uses_with(new_node) + + # Manually erase update_cache nodes since DCE will not remove them since they + # modify inputs (specifically, the cache args are modified) + graph_module.graph.erase_node(match.update_key_cache_node) + graph_module.graph.erase_node(match.update_value_cache_node) diff --git a/backends/vulkan/patterns/select_as_symint.py b/backends/vulkan/patterns/select_as_symint.py new file mode 100644 index 00000000000..e7226b08188 --- /dev/null +++ b/backends/vulkan/patterns/select_as_symint.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch + +from executorch.backends.vulkan.patterns.pattern_registry import ( + PatternMatch, + register_pattern_detector, + register_pattern_replacement, +) + +from executorch.exir import ExportedProgram +from executorch.exir.dialects._ops import ops as exir_ops + + +class SelectAsSymIntMatch(PatternMatch): + def __init__(self, local_scalar_dense_node: torch.fx.Node) -> None: + self.anchor_node = local_scalar_dense_node + self.match_found = False + + # Check if the input to local_scalar_dense is a select_copy node + if len(local_scalar_dense_node.args) < 1: + return + + select_node = local_scalar_dense_node.args[0] + if not isinstance(select_node, torch.fx.Node): + return + + if ( + select_node.op != "call_function" + or select_node.target != exir_ops.edge.aten.select_copy.int + ): + return + + # select_copy.int has signature: select_copy(Tensor self, int dim, int index) + if len(select_node.args) < 3: + return + + self.select_node = select_node + + self.tensor_node = select_node.args[0] + self.dim_node = select_node.args[1] + self.index_node = select_node.args[2] + + self.all_nodes = [ + self.anchor_node, + self.select_node, + self.tensor_node, + self.dim_node, + self.index_node, + ] + + self.match_found = True + + +@register_pattern_detector("select_as_symint") +def find_select_as_symint_patterns( + node: torch.fx.Node, +) -> Optional[SelectAsSymIntMatch]: + if node.target != torch.ops.aten._local_scalar_dense.default: + return None + + matched_pattern = SelectAsSymIntMatch(node) + if matched_pattern.match_found: + return matched_pattern + + return None + + +## +## Pattern Replacement +## + + +@register_pattern_replacement("select_as_symint") +def replace_select_local_scalar_dense_with_select_as_symint( + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + match: SelectAsSymIntMatch, +): + with graph_module.graph.inserting_before(match.anchor_node): + new_node = graph_module.graph.create_node( + "call_function", + exir_ops.edge.et_vk.select_as_symint.default, + args=( + match.tensor_node, + match.dim_node, + match.index_node, + ), + ) + + new_node.meta["val"] = match.anchor_node.meta["val"] + match.anchor_node.replace_all_uses_with(new_node) + + # # Remove both the local_scalar_dense and select_copy nodes + # graph_module.graph.erase_node(match.anchor_node) + # # 
Only erase select_node if it has no other users + # if len(match.select_node.users) == 0: + # graph_module.graph.erase_node(match.select_node) diff --git a/backends/vulkan/quantizer/vulkan_quantizer.py b/backends/vulkan/quantizer/vulkan_quantizer.py index 40212c35c27..3d1e1eab0f2 100644 --- a/backends/vulkan/quantizer/vulkan_quantizer.py +++ b/backends/vulkan/quantizer/vulkan_quantizer.py @@ -124,11 +124,22 @@ class VulkanQuantizer(Quantizer): def __init__(self) -> None: super().__init__() self.global_config: Optional[QuantizationConfig] = None + # If specified, only quantize nodes that return true for the filter + # function. + self.filter_fn: Optional[Callable[[Node], bool]] = None def set_global(self, quantization_config: QuantizationConfig) -> VulkanQuantizer: self.global_config = quantization_config return self + def set_filter_function(self, filter_fn: Callable[[Node], bool]): + """ + Set the filter function. We only quantize nodes that return True for + the filter function. + """ + self.filter_fn = filter_fn + return self + def transform_for_annotation( self, model: torch.fx.GraphModule ) -> torch.fx.GraphModule: @@ -149,8 +160,14 @@ def _annotate_all_patterns( if quantization_config is None: return model + # Create a combined filter function, which returns True only when + # both filter_fn and self.filter_fn return True. + def combined_filter_fn(n: Node) -> bool: + combined_filter = [self.filter_fn, filter_fn] + return all(f(n) for f in combined_filter if f is not None) + for op in _SUPPORTED_OPS: - OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) + OP_TO_ANNOTATOR[op](model, quantization_config, combined_filter_fn) return model def _annotate_for_quantization_config( diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index 7b138072d50..677b042beb6 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -19,6 +19,7 @@ #include #include #ifdef ET_EVENT_TRACER_ENABLED +#include #include #endif // ET_EVENT_TRACER_ENABLED #include @@ -86,6 +87,32 @@ vkapi::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) { return vkapi::kFloat; case vkgraph::VkDataType::FLOAT64: return vkapi::kDouble; + default: + VK_THROW("Invalid VkDataType type encountered!"); + } +} + +vkapi::ScalarType equivalent_scalar_type( + const executorch::runtime::etensor::ScalarType& et_datatype) { + switch (et_datatype) { + case executorch::runtime::etensor::ScalarType::Byte: + return vkapi::kByte; + case executorch::runtime::etensor::ScalarType::Char: + return vkapi::kChar; + case executorch::runtime::etensor::ScalarType::Int: + return vkapi::kInt; + case executorch::runtime::etensor::ScalarType::Long: + return vkapi::kLong; + case executorch::runtime::etensor::ScalarType::Half: + return vkapi::kHalf; + case executorch::runtime::etensor::ScalarType::Float: + return vkapi::kFloat; + case executorch::runtime::etensor::ScalarType::Double: + return vkapi::kDouble; + case executorch::runtime::etensor::ScalarType::Bool: + return vkapi::kBool; + default: + VK_THROW("Invalid etensor::ScalarType encountered!"); } } @@ -113,6 +140,10 @@ utils::GPUMemoryLayout get_memory_layout( return utils::kHeightPacked; case vkgraph::VkMemoryLayout::TENSOR_CHANNELS_PACKED: return utils::kChannelsPacked; + case vkgraph::VkMemoryLayout::PACKED_INT8_4W4C: + return utils::kPackedInt8_4W4C; + case vkgraph::VkMemoryLayout::PACKED_INT8_4H4W: + return utils::kPackedInt8_4H4W; default: break; } @@ -149,6 +180,12 @@ GraphConfig 
get_graph_config(ArrayRef& compile_specs) { config.expect_dynamic_shapes = true; } } + if (strcmp(spec.key, "warmup_execute_after_compile") == 0) { + ET_CHECK_MSG(value_size == sizeof(uint8_t), "Unexpected value size!"); + bool value = getBool(value_data); + + config.warmup_execute_after_compile = value; + } } #ifdef ET_EVENT_TRACER_ENABLED config.enable_querypool = true; @@ -343,6 +380,15 @@ class GraphBuilder { } } + vkapi::ScalarType get_staging_scalar_type_of(const uint32_t fb_id) { + VkTensorPtr tensor_fb = + flatbuffer_->values()->Get(fb_id)->value_as_VkTensor(); + if (tensor_fb->staging_datatype() == vkgraph::VkDataType::UNSET) { + return get_scalar_type(tensor_fb->datatype()); + } + return get_scalar_type(tensor_fb->staging_datatype()); + } + void build_graph() { // Resize the mapping to the number of values in the flatbuffer resize(flatbuffer_->values()->size()); @@ -359,7 +405,8 @@ class GraphBuilder { for (const uint32_t fb_id : *flatbuffer_->input_ids()) { const ValueRef ref = get_fb_id_valueref(fb_id); if (compute_graph_->val_is_tensor(ref)) { - compute_graph_->set_input_tensor(ref); + compute_graph_->set_input_tensor( + ref, get_staging_scalar_type_of(fb_id)); } else { compute_graph_->set_val_as_input(ref); } @@ -376,6 +423,13 @@ class GraphBuilder { args.push_back(get_fb_id_valueref(static_cast(arg_fb_id))); } +#ifdef ET_EVENT_TRACER_ENABLED + std::string operator_json = + make_operator_json(compute_graph_, op_name, args); + set_and_get_current_operator_json(operator_json); + get_current_operator_count(true); +#endif // ET_EVENT_TRACER_ENABLED + auto vkFn = VK_GET_OP_FN(op_name); vkFn(*compute_graph_, args); } @@ -384,7 +438,15 @@ class GraphBuilder { // values as well if the source graph returns parameter nodes. for (const uint32_t fb_id : *flatbuffer_->output_ids()) { const ValueRef ref = get_fb_id_valueref(fb_id); - compute_graph_->set_output_value(ref); + if (compute_graph_->val_is_tensor(ref)) { +#ifdef ET_EVENT_TRACER_ENABLED + get_current_operator_count(true); +#endif // ET_EVENT_TRACER_ENABLED + compute_graph_->set_output_tensor( + ref, get_staging_scalar_type_of(fb_id)); + } else { + compute_graph_->set_output_value(ref); + } } if (compute_graph_->graphconfig().enable_querypool) { @@ -534,6 +596,8 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { compute_graph->prepack(); + compute_graph->optional_warmup_execute(); + return Error::Ok; } @@ -582,10 +646,11 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { bool was_resized = maybe_resize_input(compute_graph, i, args[i]->toTensor()); should_propagate_resize = should_propagate_resize || was_resized; - compute_graph->copy_into_staging( + compute_graph->maybe_cast_and_copy_into_staging( compute_graph->inputs()[i].staging, args[i]->toTensor().const_data_ptr(), - args[i]->toTensor().numel()); + args[i]->toTensor().numel(), + equivalent_scalar_type(args[i]->toTensor().scalar_type())); } else if (compute_graph->val_is_symint(iref)) { VK_CHECK_COND( args[i]->isTensor(), @@ -603,7 +668,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { } } - if (should_propagate_resize) { + if (should_propagate_resize || compute_graph->has_data_dependent_shapes()) { compute_graph->propagate_resize(); } @@ -617,10 +682,11 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { maybe_resize_output(compute_graph, i, args[o]->toTensor()); // args holds inputs directly followed by outputs, so the i'th output // for compute_graph 
corresponds to the o'th arg - compute_graph->copy_from_staging( + compute_graph->maybe_cast_and_copy_from_staging( compute_graph->outputs()[i].staging, args[o]->toTensor().mutable_data_ptr(), - args[o]->toTensor().numel()); + args[o]->toTensor().numel(), + equivalent_scalar_type(args[o]->toTensor().scalar_type())); } // TensorRef values represent constant tensors which will not have been // modified by the graph execution. Therefore, if a constant tensor is @@ -639,16 +705,14 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { compute_graph->context()->querypool().extract_results(); for (const auto& r : compute_graph->context()->querypool().get_shader_timestamp_data()) { - std::string event_name = - r.kernel_name + "_" + std::to_string(r.dispatch_id); + std::string event_name = "{" + r.kernel_name + + ", \"dispatch_id\": " + std::to_string(r.dispatch_id) + "}"; event_tracer_log_profiling_delegate( event_tracer, event_name.c_str(), /* delegate_debug_id = */ -1, r.start_time_ns, - r.end_time_ns, - (void*)(&r.metadata), - sizeof(r.metadata)); + r.end_time_ns); } #endif // ET_EVENT_TRACER_ENABLED diff --git a/backends/vulkan/runtime/api/Context.cpp b/backends/vulkan/runtime/api/Context.cpp index 8599cbfffb6..326391424df 100644 --- a/backends/vulkan/runtime/api/Context.cpp +++ b/backends/vulkan/runtime/api/Context.cpp @@ -117,6 +117,18 @@ void Context::check_device_capabilities(const vkapi::ShaderInfo& shader) { shader.kernel_name, vkapi::VulkanExtension::INTEGER_DOT_PRODUCT); } } + if (shader.requires_shader_int64) { + if (!adapter_p_->supports_int64_shader_types()) { + throw vkapi::ShaderNotSupportedError( + shader.kernel_name, vkapi::VulkanExtension::SHADER_INT64); + } + } + if (shader.requires_shader_float64) { + if (!adapter_p_->supports_float64_shader_types()) { + throw vkapi::ShaderNotSupportedError( + shader.kernel_name, vkapi::VulkanExtension::SHADER_FLOAT64); + } + } } vkapi::DescriptorSet Context::get_descriptor_set( @@ -276,9 +288,8 @@ Context* context() { return context.get(); } -#ifdef VULKAN_DEBUG - -#ifdef VK_KHR_pipeline_executable_properties +#if defined(VK_KHR_pipeline_executable_properties) && \ + defined(ETVK_INSPECT_PIPELINES) VkPipeline Context::get_shader_pipeline( const vkapi::ShaderInfo& shader, @@ -490,9 +501,7 @@ void Context::print_shader_executable_properties( } } -#endif // VK_KHR_pipeline_executable_properties - -#endif // VULKAN_DEBUG +#endif // VK_KHR_pipeline_executable_properties && ETVK_INSPECT_PIPELINES } // namespace api } // namespace vkcompute diff --git a/backends/vulkan/runtime/api/Context.h b/backends/vulkan/runtime/api/Context.h index 9c7301b9971..5764cb6a894 100644 --- a/backends/vulkan/runtime/api/Context.h +++ b/backends/vulkan/runtime/api/Context.h @@ -234,9 +234,8 @@ class Context final { void flush(); -#ifdef VULKAN_DEBUG - -#ifdef VK_KHR_pipeline_executable_properties +#if defined(VK_KHR_pipeline_executable_properties) && \ + defined(ETVK_INSPECT_PIPELINES) VkPipeline get_shader_pipeline( const vkapi::ShaderInfo& shader, @@ -260,9 +259,7 @@ class Context final { const vkapi::ShaderInfo& shader, const vkapi::SpecVarList& spec_constants); -#endif // VK_KHR_pipeline_executable_properties - -#endif // VULKAN_DEBUG +#endif // VK_KHR_pipeline_executable_properties && ETVK_INSPECT_PIPELINES }; bool available(); diff --git a/backends/vulkan/runtime/api/containers/StagingBuffer.h b/backends/vulkan/runtime/api/containers/StagingBuffer.h index 1e9f569fc4a..47469a06156 100644 --- 
a/backends/vulkan/runtime/api/containers/StagingBuffer.h +++ b/backends/vulkan/runtime/api/containers/StagingBuffer.h @@ -31,11 +31,13 @@ class StagingBuffer final { StagingBuffer( Context* context_p, const vkapi::ScalarType dtype, - const size_t numel) + const size_t numel, + const vkapi::CopyDirection direction) : context_p_(context_p), dtype_(dtype), vulkan_buffer_(context_p_->adapter_ptr()->vma().create_staging_buffer( - element_size(dtype_) * numel)), + element_size(dtype_) * numel, + direction)), mapped_data_(nullptr) {} StagingBuffer(const StagingBuffer&) = delete; @@ -48,7 +50,7 @@ class StagingBuffer final { context_p_->register_buffer_cleanup(vulkan_buffer_); } - inline vkapi::ScalarType dtype() { + inline vkapi::ScalarType dtype() const { return dtype_; } @@ -81,6 +83,15 @@ class StagingBuffer final { VK_WHOLE_SIZE); } + template + void cast_and_copy_from(const SRC_T* src, const size_t numel) { + VK_CHECK_COND(numel <= this->numel()); + DST_T* dst = reinterpret_cast(data()); + for (size_t i = 0; i < numel; ++i) { + dst[i] = static_cast(src[i]); + } + } + inline void copy_to(void* dst, const size_t nbytes) { VK_CHECK_COND(nbytes <= this->nbytes()); vmaInvalidateAllocation( @@ -91,9 +102,32 @@ class StagingBuffer final { memcpy(dst, data(), nbytes); } + template + void cast_and_copy_to(DST_T* dst, const size_t numel) { + VK_CHECK_COND(numel <= this->numel()); + const SRC_T* src = reinterpret_cast(data()); + for (size_t i = 0; i < numel; ++i) { + dst[i] = static_cast(src[i]); + } + } + inline void set_staging_zeros() { memset(data(), 0, nbytes()); } + + template + T select_element_at_dim( + const std::vector& sizes, + const int64_t dim, + const int64_t index) { + int64_t stride = 1; + for (size_t i = dim + 1; i < sizes.size(); ++i) { + stride *= sizes[i]; + } + const int64_t offset = index * stride; + const T* typed_data = reinterpret_cast(data()); + return typed_data[offset]; + } }; } // namespace api diff --git a/backends/vulkan/runtime/api/containers/Tensor.cpp b/backends/vulkan/runtime/api/containers/Tensor.cpp index 433ae15db4e..5a1c445889e 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.cpp +++ b/backends/vulkan/runtime/api/containers/Tensor.cpp @@ -14,6 +14,21 @@ namespace vkcompute { namespace api { +/* + * For PackedInt8 memory layouts, ensure that the scalar type used for the + * tensor is kInt8x4. Otherwise, return the original scalar type. + */ +vkapi::ScalarType get_effective_scalar_type( + const vkapi::ScalarType dtype, + const utils::GPUMemoryLayout memory_layout) { + vkapi::ScalarType effective_dtype = dtype; + if (utils::is_packed_int8_layout(memory_layout)) { + VK_CHECK_COND(dtype == vkapi::kInt8x4 || dtype == vkapi::kChar); + effective_dtype = vkapi::kInt8x4; + } + return effective_dtype; +} + /* * Used to infer the sizes of a tensor that would correspond to a given * VulkanImage. @@ -187,6 +202,7 @@ std::vector calculate_padded_sizes( utils::uvec3 calculate_image_extents( const std::vector& padded_sizes, + const utils::GPUMemoryLayout memory_layout, const std::vector& axis_map, const int32_t packed_dim) { utils::uvec3 extents({1, 1, 1}); @@ -205,6 +221,28 @@ utils::uvec3 calculate_image_extents( extents[axis] = utils::safe_downcast(padded_sizes.at(dim)); } + // For "regular" tensor dtypes, 4 elements along the packed dim are packed + // into one texel (4-component vectorized type). 
However, for packed int8 + // memory layouts, an additional level of packing is employed where 4 int8 + // elements are packed into one int32, and then 4 int32 are packed into each + // ivec4 texel. + if (utils::is_packed_int8_layout(memory_layout)) { + // Each int in the ivec4 contains 4 channels. The overall ivec4 contains + // data for a 1Hx4Wx4C block of the input tensor. + if (memory_layout == utils::kPackedInt8_4W4C) { + VK_CHECK_COND(packed_dim == 2); + extents[axis_map.at(0)] = utils::div_up(extents[axis_map.at(0)], 4u); + } + // Each int in the ivec4 contains 4 elements along the width dim. The + // overall ivec4 contains data for a 4Hx4W block of the input tensor. + else if (memory_layout == utils::kPackedInt8_4H4W) { + VK_CHECK_COND(packed_dim == 0); + extents[axis_map.at(1)] = utils::div_up(extents[axis_map.at(1)], 4u); + } else { + VK_THROW("Unhandled packed int8 memory layout!"); + } + } + // axis_map[3] indicates the WHCN index of the dimension used for batch // concatenation. Thus a double lookup is required to determine the image axis // used for batch concatenation. @@ -215,6 +253,7 @@ utils::uvec3 calculate_image_extents( VK_CHECK_COND(extents[axis_map.at(packed_dim)] % 4 == 0); extents[axis_map.at(packed_dim)] /= 4; + return extents; } @@ -247,35 +286,72 @@ utils::uvec3 calculate_logical_limits( */ utils::uvec3 calculate_logical_limits( const std::vector& sizes, + const utils::GPUMemoryLayout memory_layout, const std::vector& axis_map, const int32_t packed_dim) { return calculate_logical_limits( calculate_image_extents( - calculate_padded_sizes(sizes, packed_dim), axis_map, packed_dim), + calculate_padded_sizes(sizes, packed_dim), + memory_layout, + axis_map, + packed_dim), axis_map); } int64_t calculate_gpu_buffer_numel( + const std::vector& sizes, + const utils::GPUMemoryLayout memory_layout, + const vkapi::ScalarType dtype) { + size_t numel; + + // Mirrors the logic in calculate_image_extents for packed int8 memory layouts + if (dtype == vkapi::kInt8x4) { + VK_CHECK_COND(utils::is_packed_int8_layout(memory_layout)); + std::vector blocks_in_dim = + flip_and_unsqueeze(sizes, kTensorSizes, 0); + // Each ivec4 contains data for a 1Hx4Wx4C block of the input + if (memory_layout == utils::kPackedInt8_4W4C) { + blocks_in_dim[0] = utils::div_up_4(blocks_in_dim[0]); + blocks_in_dim[2] = utils::div_up_4(blocks_in_dim[2]); + } + // Each ivec4 contains data for a 4Hx4W block of the input + else if (memory_layout == utils::kPackedInt8_4H4W) { + blocks_in_dim[0] = utils::div_up_4(blocks_in_dim[0]); + blocks_in_dim[1] = utils::div_up_4(blocks_in_dim[1]); + } else { + VK_THROW("Unhandled packed int8 memory layout!"); + } + // Each block is represented as an ivec4, and the base dtype of the buffer + // is int. Therefore, need to multiply the number of blocks by 4 to obtain + // the number of int elements in the data buffer. + numel = utils::multiply_integers(blocks_in_dim) * 4; + } + // Case for "regular" dtypes/memory layouts + else { + numel = utils::multiply_integers(sizes); + + // For 8-bit types, align to the next multiple of 4. For devices that do not + // support 8-bit storage buffers, the tensor data will be interpreted as an + // array of int32 instead. 
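# Illustrative sketch (not part of the patch): the block math behind
# calculate_gpu_buffer_numel for the packed int8 layouts, assuming sizes are given in
# NCHW order and flipped to WHCN as in the C++ code above. Each 4W4C (or 4H4W) block
# is one ivec4, i.e. 4 int32 elements in the underlying buffer.
def div_up_4(x: int) -> int:
    return (x + 3) // 4

def packed_int8_buffer_numel_example(n: int, c: int, h: int, w: int, layout: str) -> int:
    whcn = [w, h, c, n]
    if layout == "4W4C":      # one ivec4 covers a 1Hx4Wx4C block of the tensor
        whcn[0] = div_up_4(whcn[0])
        whcn[2] = div_up_4(whcn[2])
    elif layout == "4H4W":    # one ivec4 covers a 4Hx4W block of the tensor
        whcn[0] = div_up_4(whcn[0])
        whcn[1] = div_up_4(whcn[1])
    else:
        raise ValueError(layout)
    blocks = whcn[0] * whcn[1] * whcn[2] * whcn[3]
    return blocks * 4         # 4 int32 values per ivec4 block

# A 1x8x5x6 int8 tensor in 4W4C layout needs ceil(6/4) * 5 * ceil(8/4) = 20 blocks.
assert packed_int8_buffer_numel_example(1, 8, 5, 6, "4W4C") == 20 * 4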
+ if (vkapi::element_size(dtype) == 1) { + numel = utils::align_up_4(numel); + } + } + return numel; +} + +int64_t calculate_staging_or_gpu_buffer_numel( Context* const context, const std::vector& sizes, const utils::uvec3 image_extents, const utils::StorageType storage_type, + const utils::GPUMemoryLayout memory_layout, const vkapi::ScalarType dtype) { // For texture backed tensors, simply multiply the total number of texels by 4 if (storage_type != utils::kBuffer) { return image_extents[0] * image_extents[1] * image_extents[2] * 4; } - const bool is_int8 = dtype == vkapi::kChar; - const bool int8_supported = - context->adapter_ptr()->has_full_int8_buffers_support(); - const size_t numel = utils::multiply_integers(sizes); - // For int8 tensors, if the device does not support int8 buffers, then int32 - // is used instead to represent the buffer data. Therefore the number of - // elements in the buffer is aligned to the next multiple of 4. - if (is_int8 && int8_supported) { - return utils::align_up_4(numel); - } - return numel; + return calculate_gpu_buffer_numel(sizes, memory_layout, dtype); } template ::value>> @@ -332,10 +408,12 @@ vkapi::VulkanImage allocate_image( Context* const context_ptr, utils::uvec3& image_extents, const utils::StorageType storage_type, - const VkFormat image_format, + const vkapi::ScalarType dtype, const bool allocate_memory) { vkapi::Adapter* adapter_ptr = context_ptr->adapter_ptr(); + const VkFormat image_format = vkcompute::vkapi::to_vkformat(dtype); + vkapi::ImageSampler::Properties sampler_props{ VK_FILTER_NEAREST, VK_SAMPLER_MIPMAP_MODE_NEAREST, @@ -420,6 +498,7 @@ vkapi::VulkanBuffer allocate_buffer( vTensorStorage::vTensorStorage( Context* const context, const utils::StorageType storage_type, + const utils::GPUMemoryLayout memory_layout, const std::vector& axis_map, const int32_t packed_dim, const std::vector& sizes, @@ -429,20 +508,22 @@ vTensorStorage::vTensorStorage( storage_type_{storage_type}, image_extents_(calculate_image_extents( calculate_padded_sizes(sizes, packed_dim), + memory_layout, axis_map, packed_dim)), - buffer_length_{calculate_gpu_buffer_numel( + buffer_length_{calculate_staging_or_gpu_buffer_numel( context_, sizes, image_extents_, storage_type, + memory_layout, dtype)}, buffer_offset_{0}, image_(allocate_image( context_, image_extents_, storage_type_, - to_vkformat(dtype), + dtype, allocate_memory)), buffer_(allocate_buffer( context_, @@ -553,7 +634,7 @@ vTensor::vTensor( const utils::GPUMemoryLayout memory_layout, const bool allocate_memory, const utils::AxisMapLayout axis_map_layout) - : dtype_(dtype), + : dtype_(get_effective_scalar_type(dtype, memory_layout)), // Calculate tensor metadata sizes_(sizes.begin(), sizes.end()), packed_dim_(utils::to_packed_dim(memory_layout)), @@ -576,6 +657,7 @@ vTensor::vTensor( storage_(std::make_shared( context, storage_type, + memory_layout, axis_map_, packed_dim_, sizes, @@ -754,6 +836,50 @@ void vTensor::BufferMetadata::update( numel = utils::safe_downcast(src_numel); } +vTensor::TextureMetadata::TextureMetadata( + const std::vector& src_sizes, + const TextureLimits& src_logical_limits, + const std::vector& src_axis_map, + const int32_t src_packed_dim) { + update(src_sizes, src_logical_limits, src_axis_map, src_packed_dim); +} + +void vTensor::TextureMetadata::update( + const std::vector& src_sizes, + const TextureLimits& src_logical_limits, + const std::vector& src_axis_map, + const int32_t src_packed_dim) { + // Convert sizes to flipped and unsqueezed format (fixed to 4 dimensions for + // 
texture) + std::vector fu_sizes = + flip_and_unsqueeze(src_sizes, kTensorSizes, 0, 4); + + // Copy sizes (up to 4 elements) + for (int i = 0; i < 4; ++i) { + sizes[i] = fu_sizes.at(i); + } + + // Copy logical limits (3 elements) + logical_limits[0] = + utils::safe_downcast(src_logical_limits.limits[0]); + logical_limits[1] = + utils::safe_downcast(src_logical_limits.limits[1]); + logical_limits[2] = + utils::safe_downcast(src_logical_limits.limits[2]); + logical_limits[3] = 1u; + + // Copy axis map (up to 4 elements) + for (int i = 0; i < 4 && i < src_axis_map.size(); ++i) { + axis_map[i] = utils::safe_downcast(src_axis_map.at(i)); + } + // Pad with zeros if axis_map is smaller than 4 + for (int i = src_axis_map.size(); i < 4; ++i) { + axis_map[i] = 0; + } + + packed_dim = src_packed_dim; +} + vkapi::VulkanImage& vTensor::image( vkapi::PipelineBarrier& pipeline_barrier, const vkapi::PipelineStageFlags stage) & { @@ -785,6 +911,16 @@ vkapi::VulkanBuffer& vTensor::buffer( } utils::GPUMemoryLayout vTensor::estimate_memory_layout() const { + if (dtype_ == vkapi::kInt8x4) { + switch (packed_dim_) { + case WHCN::kChannelsDim: + return utils::kPackedInt8_4W4C; + case WHCN::kWidthDim: + return utils::kPackedInt8_4H4W; + default: + VK_THROW("Invalid packed dim for Tensor with kInt8x4 type"); + } + } switch (packed_dim_) { case WHCN::kWidthDim: return utils::kWidthPacked; @@ -856,6 +992,16 @@ const vkapi::BufferBindInfo vTensor::buffer_meta_ubo() { return vkapi::BufferBindInfo(buffer_meta_.buffer(), 0, ubo_nbytes); } +const vkapi::BufferBindInfo vTensor::texture_meta_ubo() { + size_t ubo_nbytes = sizeof(TextureMetadata); + if (!texture_meta_.buffer()) { + TextureLimits limits(logical_limits()); + TextureMetadata data(sizes_, limits, axis_map_, packed_dim_); + texture_meta_ = ParamsBuffer(storage_->context_, data); + } + return vkapi::BufferBindInfo(texture_meta_.buffer(), 0, ubo_nbytes); +} + VkMemoryRequirements vTensor::get_memory_requirements() const { switch (storage_type()) { case utils::kBuffer: @@ -914,8 +1060,8 @@ void vTensor::update_metadata() { flip_and_unsqueeze_ivec4(dim_order_, kTensorDimOrder, numel_); uniform_data_->strides_v = flip_and_unsqueeze_ivec4(strides_, kTensorStrides, numel_); - uniform_data_->logical_limits.limits = - calculate_logical_limits(sizes_, axis_map_, packed_dim_); + uniform_data_->logical_limits.limits = calculate_logical_limits( + sizes_, estimate_memory_layout(), axis_map_, packed_dim_); if (sizes_uniform_offset_ != kUniformOffsetUnset) { uniforms_.update(uniform_data_->sizes_v, sizes_uniform_offset_); @@ -939,14 +1085,24 @@ void vTensor::update_metadata() { BufferMetadata data(sizes_, dim_order_, strides_, numel_); buffer_meta_.update(data); } + + if (texture_meta_.buffer()) { + TextureMetadata data( + sizes_, uniform_data_->logical_limits, axis_map_, packed_dim_); + texture_meta_.update(data); + } } void vTensor::check_sizes(const std::vector& sizes) const { + utils::GPUMemoryLayout est_memory_layout = estimate_memory_layout(); if (storage_type() != utils::kBuffer) { // For texture storage check that the current texture is large enough for // the new sizes of the tensor. 
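# Illustrative sketch (not part of the patch): the resize validity rule that
# check_sizes enforces, reduced to plain integers. A resize is only accepted when the
# image extents (or buffer length) required by the new sizes fit within what was
# originally allocated for the tensor.
def texture_resize_is_valid_example(new_extents, allocated_extents) -> bool:
    return all(n <= a for n, a in zip(new_extents, allocated_extents))

assert texture_resize_is_valid_example((8, 4, 2), (16, 8, 4))
assert not texture_resize_is_valid_example((32, 4, 2), (16, 8, 4))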
utils::uvec3 virtual_extents = calculate_image_extents( - calculate_padded_sizes(sizes_, packed_dim_), axis_map_, packed_dim_); + calculate_padded_sizes(sizes_, packed_dim_), + est_memory_layout, + axis_map_, + packed_dim_); bool valid_resize = virtual_extents[0] <= storage_->image_extents_[0]; valid_resize = @@ -958,9 +1114,10 @@ void vTensor::check_sizes(const std::vector& sizes) const { valid_resize, "tensor sizes requires a larger texture than the current one."); } else { - // For buffer storage check that the current buffer is large enough for the - // new sizes of the tensor. - int64_t numel = utils::multiply_integers(sizes); + // For buffer storage check that the current buffer is large enough for + // the new sizes of the tensor. + int64_t numel = + calculate_gpu_buffer_numel(sizes_, est_memory_layout, dtype_); bool valid_resize = numel + storage_->buffer_offset_ <= storage_->buffer_length_; VK_CHECK_COND( diff --git a/backends/vulkan/runtime/api/containers/Tensor.h b/backends/vulkan/runtime/api/containers/Tensor.h index 66c1fd1e4da..967148b8dbe 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.h +++ b/backends/vulkan/runtime/api/containers/Tensor.h @@ -99,6 +99,7 @@ class vTensorStorage final { vTensorStorage( Context* context, const utils::StorageType storage_type, + const utils::GPUMemoryLayout memory_layout, const std::vector& axis_map, const int32_t packed_dim, const std::vector& sizes, @@ -284,6 +285,25 @@ class vTensor final { size_t numel); }; + struct TextureMetadata { + int32_t sizes[4]; + int32_t logical_limits[4]; + int32_t axis_map[4]; + int32_t packed_dim; + + TextureMetadata( + const std::vector& sizes, + const TextureLimits& logical_limits, + const std::vector& axis_map, + const int32_t packed_dim); + + void update( + const std::vector& sizes, + const TextureLimits& logical_limits, + const std::vector& axis_map, + const int32_t packed_dim); + }; + private: /* * "Core" tensor metadata. 
They are the minimum amount of information required @@ -359,6 +379,12 @@ class vTensor final { */ ParamsBuffer buffer_meta_; + /* + * Used to store data for TextureMetadata to pass to shaders as + * texture_meta_ubo + */ + ParamsBuffer texture_meta_; + uint32_t uniforms_size_ = 0u; uint32_t sizes_uniform_offset_ = kUniformOffsetUnset; uint32_t dim_order_uniform_offset_ = kUniformOffsetUnset; @@ -586,6 +612,8 @@ class vTensor final { const vkapi::BufferBindInfo buffer_meta_ubo(); + const vkapi::BufferBindInfo texture_meta_ubo(); + public: inline size_t staging_buffer_numel() const { return storage_->buffer_len(); diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index 3f2d616b428..ab709092351 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -233,6 +233,14 @@ def texel_component_type(dtype: str) -> str: raise AssertionError(f"Invalid vec4 type: {vec4_type}") +def accum_vec_type(dtype: str) -> str: + return texel_type(dtype) + + +def accum_scalar_type(dtype: str) -> str: + return texel_component_type(dtype) + + def texel_load_type(dtype: str, storage_type: str) -> str: if storage_type.lower() == "buffer": return buffer_gvec_type(dtype, 4) @@ -455,6 +463,8 @@ def define_required_extensions(dtypes: Union[str, List[str]]): "buffer_gvec_type": buffer_gvec_type, "texel_type": texel_type, "gvec_type": gvec_type, + "accum_vec_type": accum_vec_type, + "accum_scalar_type": accum_scalar_type, "texel_component_type": texel_component_type, "texel_load_type": texel_load_type, "texel_load_component_type": texel_load_component_type, @@ -670,7 +680,7 @@ def addSrcAndYamlFiles(self, src_dir_paths: List[str]) -> None: if len(file) > 1: self.template_yaml_files.append(file) - def generateVariantCombinations( + def generateVariantCombinations( # noqa: C901 self, iterated_params: Dict[str, Any], exclude_params: Optional[Set[str]] = None, @@ -679,7 +689,25 @@ def generateVariantCombinations( exclude_params = set() all_iterated_params = [] for param_name, value_list in iterated_params.items(): - if param_name not in exclude_params: + if re.match(r"^combination\d*$", param_name): + param_values = [] + param_names = value_list["parameter_names"] + combos = value_list["combos"] + for combo in combos: + parameter_values = combo["parameter_values"] + if "suffix" in combo: + suffix = combo["suffix"] + else: + suffix = "" + for param_value in parameter_values: + if len(str(param_value)) > 0: + suffix += "_" + str(param_value) + suffix = suffix[1:] + param_values.append((param_names, suffix, parameter_values)) + + all_iterated_params.append(param_values) + + elif param_name not in exclude_params: param_values = [] for value in value_list: if "RANGE" in value: @@ -713,7 +741,7 @@ def generateVariantCombinations( return list(product(*all_iterated_params)) - def parseTemplateYaml(self, yaml_file: str) -> None: + def parseTemplateYaml(self, yaml_file: str) -> None: # noqa: C901 with open(yaml_file) as f: contents = yaml.load(f, Loader=UniqueKeyLoader) for template_name, params_dict in contents.items(): @@ -762,10 +790,21 @@ def parseTemplateYaml(self, yaml_file: str) -> None: default_params_copy[key] = variant[key] variant_name = variant["NAME"] - for param_value in combination: - default_params_copy[param_value[0]] = param_value[2] - if len(str(param_value[1])) > 0: - variant_name = f"{variant_name}_{param_value[1]}" + + for setting in combination: + param_names = setting[0] + suffix = setting[1] + param_values = 
setting[2] + if isinstance(param_names, list): + for param_name, param_value in zip( + param_names, param_values + ): + default_params_copy[param_name] = param_value + else: + default_params_copy[param_names] = param_values + + if len(str(suffix)) > 0: + variant_name = f"{variant_name}_{suffix}" default_params_copy["NAME"] = variant_name default_params_copy["VARIANT_NAME"] = variant["NAME"] @@ -1104,6 +1143,8 @@ class ShaderInfo: requires_16bit_storage_ext: bool = False requires_8bit_storage_ext: bool = False requires_integer_dot_product_ext: bool = False + requires_shader_int64_ext: bool = False + requires_shader_float64_ext: bool = False def getName(filePath: str) -> str: @@ -1193,7 +1234,7 @@ def determineDescriptorType(lineStr: str) -> str: ) -def getShaderInfo(srcFilePath: str) -> ShaderInfo: +def getShaderInfo(srcFilePath: str) -> ShaderInfo: # noqa: C901 shader_info = ShaderInfo([], [], "") with open(srcFilePath) as srcFile: for line in srcFile: @@ -1216,6 +1257,10 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo: shader_info.requires_8bit_storage_ext = True if "GL_EXT_integer_dot_product" in line: shader_info.requires_integer_dot_product_ext = True + if "GL_EXT_shader_explicit_arithmetic_types_int64" in line: + shader_info.requires_shader_int64_ext = True + if "GL_EXT_shader_explicit_arithmetic_types_float64" in line: + shader_info.requires_shader_float64_ext = True return shader_info @@ -1292,6 +1337,8 @@ def to_cpp_str(val: bool): to_cpp_str(shader_info.requires_16bit_storage_ext), to_cpp_str(shader_info.requires_8bit_storage_ext), to_cpp_str(shader_info.requires_integer_dot_product_ext), + to_cpp_str(shader_info.requires_shader_int64_ext), + to_cpp_str(shader_info.requires_shader_float64_ext), ] shader_info_str = textwrap.indent( diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index 6609298b0d8..346ffd4d35b 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -15,6 +15,24 @@ #include +#ifdef ET_EVENT_TRACER_ENABLED +std::string& set_and_get_current_operator_json(const std::string& json) { + static std::string current_operator_json; + if (json.size() > 0) { + current_operator_json = json; + } + return current_operator_json; +} + +size_t get_current_operator_count(const bool increment) { + static int count = 0; + if (increment) { + count++; + } + return count; +} +#endif /* ET_EVENT_TRACER_ENABLED */ + namespace vkcompute { // @@ -310,6 +328,8 @@ vkapi::ScalarType ComputeGraph::dtype_of(const ValueRef idx) const { return val.toConstTensor().dtype(); } else if (val.isTensorRef()) { return val.toConstTensorRef().dtype; + } else if (val.isStaging()) { + return val.toConstStaging().dtype(); } else if (val.isBool()) { return vkapi::ScalarType::Bool; } else if (val.isDouble()) { @@ -541,10 +561,11 @@ ValueRef ComputeGraph::add_tensorref( ValueRef ComputeGraph::add_staging( const vkapi::ScalarType dtype, - const size_t numel) { + const size_t numel, + const vkapi::CopyDirection direction) { ValueRef idx(static_cast(values_.size())); check_no_active_value_ptrs(); - values_.emplace_back(api::StagingBuffer(context(), dtype, numel)); + values_.emplace_back(api::StagingBuffer(context(), dtype, numel, direction)); return idx; } @@ -585,21 +606,47 @@ ValueRef ComputeGraph::get_or_add_value_for_int(const int64_t val) { return add_scalar(val); } +ValueRef ComputeGraph::set_input_tensor( + const ValueRef idx, + vkapi::ScalarType staging_dtype) { + // For texture storage, 
the buffer size needs to account for the zero + // padding applied by unused texel elements. + size_t buf_numel = get_tensor(idx)->staging_buffer_numel(); + ValueRef staging_idx = add_staging( + staging_dtype, buf_numel, vkapi::CopyDirection::HOST_TO_DEVICE); + add_staging_to_tensor_node(*this, staging_idx, idx); + inputs_.push_back({idx, staging_idx}); + return staging_idx; +} + ValueRef ComputeGraph::set_input_tensor( const ValueRef idx, const bool use_staging) { if (use_staging) { vkapi::ScalarType dtype = get_tensor(idx)->dtype(); - // For texture storage, the buffer size needs to account for the zero - // padding applied by unused texel elements. - size_t buf_numel = get_tensor(idx)->staging_buffer_numel(); - ValueRef staging_idx = add_staging(dtype, buf_numel); - add_staging_to_tensor_node(*this, staging_idx, idx); - inputs_.push_back({idx, staging_idx}); - return staging_idx; - } - inputs_.push_back({idx, kDummyValueRef}); - return idx; + return set_input_tensor(idx, dtype); + } else { + inputs_.push_back({idx, kDummyValueRef}); + return idx; + } +} + +ValueRef ComputeGraph::set_output_tensor( + const ValueRef idx, + vkapi::ScalarType staging_dtype) { + // For texture storage, the buffer size needs to account for the zero + // padding applied by unused texel elements. + size_t buf_numel = get_tensor(idx)->staging_buffer_numel(); + ValueRef staging_idx = add_staging( + staging_dtype, buf_numel, vkapi::CopyDirection::DEVICE_TO_HOST); + // We only run this when the tensor is non-empty. When the underlying + // tensor is empty (e.g. padded_numel == 0), we do not allocate a VkImage to + // tensor, we will not be able to bind the node for execution. + if (buf_numel > 0) { + add_tensor_to_staging_node(*this, idx, staging_idx); + } + outputs_.push_back({idx, staging_idx}); + return staging_idx; } ValueRef ComputeGraph::set_output_tensor( @@ -607,21 +654,11 @@ ValueRef ComputeGraph::set_output_tensor( const bool use_staging) { if (use_staging) { vkapi::ScalarType dtype = get_tensor(idx)->dtype(); - // For texture storage, the buffer size needs to account for the zero - // padding applied by unused texel elements. - size_t buf_numel = get_tensor(idx)->staging_buffer_numel(); - ValueRef staging_idx = add_staging(dtype, buf_numel); - // We only run this when the tensor is non-empty. When the underlying - // tensor is empty (e.g. padded_numel == 0), we do not allocate a VkImage to - // tensor, we will not be able to bind the node for execution. 
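The overloads above let a graph stage an input or output through a buffer whose dtype differs from the host data, with maybe_cast_and_copy_into_staging / maybe_cast_and_copy_from_staging (declared later in this patch) doing the element-wise cast. A hedged C++ usage sketch, assuming a fully built ComputeGraph `graph` and tensor values `v_in`/`v_out` that were added with an int32 dtype while the host keeps int64 data; the include path and the helper function name are illustrative, not part of the patch:

// Illustrative only: run_with_int64_io() is a hypothetical helper; the
// ComputeGraph member functions it calls are the ones added or changed in
// this diff (set_input_tensor / set_output_tensor with an explicit staging
// dtype, maybe_cast_and_copy_into_staging / _from_staging).
#include <cstdint>
#include <vector>
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>  // assumed include path

void run_with_int64_io(
    vkcompute::ComputeGraph& graph,
    const vkcompute::ValueRef v_in,   // int32 tensor in the graph
    const vkcompute::ValueRef v_out,  // int32 tensor in the graph
    const std::vector<int64_t>& host_in,
    std::vector<int64_t>& host_out /* pre-sized to the output numel */) {
  using namespace vkcompute;
  // Stage both tensors through int32 staging buffers.
  ValueRef in_staging = graph.set_input_tensor(v_in, vkapi::kInt);
  ValueRef out_staging = graph.set_output_tensor(v_out, vkapi::kInt);
  graph.prepare();
  graph.prepack();
  // Narrow int64 host data to the int32 staging dtype on the way in...
  graph.maybe_cast_and_copy_into_staging(
      in_staging, host_in.data(), host_in.size(), vkapi::kLong);
  graph.execute();
  // ...and widen the int32 staging data back to int64 on the way out.
  graph.maybe_cast_and_copy_from_staging(
      out_staging, host_out.data(), host_out.size(), vkapi::kLong);
}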
- if (buf_numel > 0) { - add_tensor_to_staging_node(*this, idx, staging_idx); - } - outputs_.push_back({idx, staging_idx}); - return staging_idx; + return set_output_tensor(idx, dtype); + } else { + outputs_.push_back({idx, kDummyValueRef}); + return idx; } - outputs_.push_back({idx, kDummyValueRef}); - return idx; } ValueRef ComputeGraph::set_output_value(const ValueRef idx) { @@ -667,6 +704,17 @@ int32_t ComputeGraph::read_symint(const ValueRef idx) { return get_symint(idx)->get(); } +ValueRef ComputeGraph::staging_of(const ValueRef idx) { + for (size_t i = 0; i < inputs_.size(); ++i) { + if (inputs_[i].value == idx) { + if (is_valid(inputs_[i].staging)) { + return inputs_[i].staging; + } + } + } + VK_THROW("Could not find staging buffer for value at index ", idx); +} + SharedObject& ComputeGraph::get_shared_object(const int64_t idx) { if (idx >= shared_objects_.size()) { shared_objects_.resize(static_cast(idx + 1)); @@ -847,6 +895,36 @@ void ComputeGraph::copy_into_staging( staging->copy_from(data, nbytes); } +void ComputeGraph::maybe_cast_and_copy_into_staging( + const ValueRef idx, + const void* data, + const size_t numel, + const vkapi::ScalarType src_data_dtype) { + StagingPtr staging = get_staging(idx); + vkapi::ScalarType staging_dtype = staging->dtype(); + if (src_data_dtype == staging_dtype) { + size_t nbytes = numel * vkapi::element_size(staging_dtype); + staging->copy_from(data, nbytes); + return; + } else { + // Hard-coded type conversion cases + if (src_data_dtype == vkapi::kLong && staging_dtype == vkapi::kInt) { + const int64_t* casted_data = reinterpret_cast(data); + staging->cast_and_copy_from(casted_data, numel); + } else if ( + src_data_dtype == vkapi::kDouble && staging_dtype == vkapi::kFloat) { + const double* casted_data = reinterpret_cast(data); + staging->cast_and_copy_from(casted_data, numel); + } else { + VK_THROW( + "Unsupported type conversion from ", + src_data_dtype, + " to staging dtype ", + staging_dtype); + } + } +} + void ComputeGraph::copy_from_staging( const ValueRef idx, void* data, @@ -856,6 +934,36 @@ void ComputeGraph::copy_from_staging( staging->copy_to(data, nbytes); } +void ComputeGraph::maybe_cast_and_copy_from_staging( + const ValueRef idx, + void* data, + const size_t numel, + const vkapi::ScalarType dst_data_dtype) { + StagingPtr staging = get_staging(idx); + vkapi::ScalarType staging_dtype = staging->dtype(); + if (dst_data_dtype == staging_dtype) { + size_t nbytes = numel * vkapi::element_size(staging_dtype); + staging->copy_to(data, nbytes); + return; + } else { + // Hard-coded type conversion cases + if (dst_data_dtype == vkapi::kLong && staging_dtype == vkapi::kInt) { + int64_t* casted_data = reinterpret_cast(data); + staging->cast_and_copy_to(casted_data, numel); + } else if ( + dst_data_dtype == vkapi::kDouble && staging_dtype == vkapi::kFloat) { + double* casted_data = reinterpret_cast(data); + staging->cast_and_copy_to(casted_data, numel); + } else { + VK_THROW( + "Unsupported type conversion from staging dtype ", + staging_dtype, + " to ", + dst_data_dtype); + } + } +} + void ComputeGraph::prepare() { #define MERGE_FIELD(field) \ static_cast(std::ceil( \ @@ -1020,6 +1128,12 @@ void ComputeGraph::prepack() { } } +void ComputeGraph::optional_warmup_execute() { + if (config_.warmup_execute_after_compile) { + execute(); + } +} + void ComputeGraph::execute() { if (deferred_cmd_list_.empty()) { context_->flush(); diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 
23b5517fd22..18e97d7b516 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -25,6 +25,11 @@ #include #include +#ifdef ET_EVENT_TRACER_ENABLED +std::string& set_and_get_current_operator_json(const std::string& json); +size_t get_current_operator_count(const bool increment = false); +#endif + namespace vkcompute { // Define valid scalar types that the Value class can @@ -449,6 +454,18 @@ class ComputeGraph final { return values_.at(idx).toTensor().buffer_meta_ubo(); } + inline vkapi::BufferBindInfo texture_meta_ubo(const ValueRef idx) { + return values_.at(idx).toTensor().texture_meta_ubo(); + } + + inline vkapi::BufferBindInfo meta_ubo(const ValueRef idx) { + if (is_buffer_storage(idx)) { + return buffer_meta_ubo(idx); + } else { + return texture_meta_ubo(idx); + } + } + inline vkapi::BufferBindInfo strides_ubo(const ValueRef idx) { return values_.at(idx).toTensor().strides_ubo(); } @@ -627,6 +644,10 @@ class ComputeGraph final { bool device_name_contains(const char* substr); + int64_t max_buffer_numel() { + return static_cast(context_->adapter_ptr()->max_buffer_numel()); + } + // // Graph Building // @@ -746,7 +767,10 @@ class ComputeGraph final { * use memory that is visible to both the CPU and GPU, and therefore is used * as a intermediary when transferring data between the CPU and GPU. */ - ValueRef add_staging(const vkapi::ScalarType dtype, const size_t numel); + ValueRef add_staging( + const vkapi::ScalarType dtype, + const size_t numel, + const vkapi::CopyDirection direction); ValueRef add_none(); @@ -771,7 +795,16 @@ class ComputeGraph final { */ ValueRef get_or_add_value_for_int(const int64_t val); + ValueRef set_input_tensor( + const ValueRef idx, + vkapi::ScalarType staging_dtype); + ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true); + + ValueRef set_output_tensor( + const ValueRef idx, + vkapi::ScalarType staging_dtype); + ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true); ValueRef set_output_value(const ValueRef idx); @@ -803,6 +836,8 @@ class ComputeGraph final { inputs_.push_back({idx, kDummyValueRef}); } + ValueRef staging_of(const ValueRef idx); + inline void set_val_as_output(const ValueRef idx) { outputs_.push_back({idx, kDummyValueRef}); } @@ -947,8 +982,21 @@ class ComputeGraph final { void copy_into_staging(const ValueRef idx, const void* data, const size_t numel); + + void maybe_cast_and_copy_into_staging( + const ValueRef idx, + const void* data, + const size_t numel, + const vkapi::ScalarType src_data_dtype); + void copy_from_staging(const ValueRef idx, void* data, const size_t numel); + void maybe_cast_and_copy_from_staging( + const ValueRef idx, + void* data, + const size_t numel, + const vkapi::ScalarType dst_data_dtype); + protected: // Command Buffer Management @@ -993,6 +1041,12 @@ class ComputeGraph final { */ void prepack(); + // + // Optional Graph Execution + // + + void optional_warmup_execute(); + // // Graph Execution // @@ -1047,6 +1101,14 @@ class ComputeGraph final { return can_use_int8_dot_product_; } + inline void set_has_data_dependent_shapes() { + config_.has_data_dependent_shapes = true; + } + + inline bool has_data_dependent_shapes() const { + return config_.has_data_dependent_shapes; + } + /* * Check whether the GPU supports 8 bit buffers. 
*/ diff --git a/backends/vulkan/runtime/graph/GraphConfig.cpp b/backends/vulkan/runtime/graph/GraphConfig.cpp index da5efbf8342..9a919a42573 100644 --- a/backends/vulkan/runtime/graph/GraphConfig.cpp +++ b/backends/vulkan/runtime/graph/GraphConfig.cpp @@ -64,7 +64,9 @@ GraphConfig::GraphConfig() { enable_local_wg_size_override = false; local_wg_size_override = {}; + has_data_dependent_shapes = false; expect_dynamic_shapes = false; + force_resize = false; external_adapter = nullptr; } diff --git a/backends/vulkan/runtime/graph/GraphConfig.h b/backends/vulkan/runtime/graph/GraphConfig.h index aa5cd8f8c4e..20d01362ef1 100644 --- a/backends/vulkan/runtime/graph/GraphConfig.h +++ b/backends/vulkan/runtime/graph/GraphConfig.h @@ -33,8 +33,14 @@ struct GraphConfig final { bool enable_local_wg_size_override; utils::uvec3 local_wg_size_override; + // If true, then resize functions should always be called even if input shapes + // have not changed. + bool has_data_dependent_shapes = false; // Whether or not the ComputeGraph should expect input shapes to be dynamic - bool expect_dynamic_shapes; + bool expect_dynamic_shapes = false; + // Used for testing/debugging only. Forces ExecuteNode to trigger the resize + // function even if none of the inputs have been updated. + bool force_resize = false; // Execution properties that determine specifics re: how command buffer // submission is handled, etc. 0 means this field is not set. @@ -65,6 +71,10 @@ struct GraphConfig final { // many command buffers. size_t execute_max_cmds = 0; + // If true, then the graph will be executed once immediately after it is + // compiled. + bool warmup_execute_after_compile = false; + vkapi::Adapter* external_adapter; // Generate a default graph config with pre-configured settings diff --git a/backends/vulkan/runtime/graph/Logging.cpp b/backends/vulkan/runtime/graph/Logging.cpp index 081083e3a63..c9406ca0ef2 100644 --- a/backends/vulkan/runtime/graph/Logging.cpp +++ b/backends/vulkan/runtime/graph/Logging.cpp @@ -17,6 +17,82 @@ namespace vkcompute { +std::ostream& operator<<(std::ostream& os, const std::vector& sizes) { + if (sizes.size() == 0) { + os << "[]"; + return os; + } + os << "["; + for (int i = 0; i < sizes.size() - 1; ++i) { + os << sizes.at(i) << ", "; + } + os << sizes.at(sizes.size() - 1); + os << "]"; + return os; +} + +std::string make_arg_json(ComputeGraph* const compute_graph, ValueRef arg) { + std::stringstream ss; + ss << "{\"type\": \"" << compute_graph->get_val_type(arg) << "\", "; + ss << "\"value_ref\": " << arg; + if (compute_graph->val_is_tensor(arg)) { + ss << ", \"dtype\": \""; + ss << compute_graph->dtype_of(arg) << "\""; + ss << ", \"sizes\": "; + ss << compute_graph->sizes_of(arg); + ss << ", \"storage\": \""; + ss << compute_graph->storage_type_of(arg) << "\""; + ss << ", \"packed_dim\": "; + ss << compute_graph->packed_dim_of(arg); + } else if (compute_graph->val_is_tref(arg)) { + ss << ", \"sizes\": "; + ss << compute_graph->sizes_of(arg); + ss << ", \"dtype\": \""; + ss << compute_graph->dtype_of(arg) << "\""; + } else if (compute_graph->val_is_value_list(arg)) { + ValueListPtr val_list = compute_graph->get_value_list(arg); + ss << ", \"values\": ["; + for (const ValueRef& value : *val_list) { + ss << value << ", "; + } + ss << "]"; + } else if (compute_graph->val_is_int_list(arg)) { + ss << ", \"values\": "; + ss << *compute_graph->get_int_list(arg); + } else if (compute_graph->val_is_int(arg)) { + ss << ", \"value\": "; + ss << compute_graph->get_int(arg); + } else if 
(compute_graph->val_is_double(arg)) { + ss << ", \"value\": "; + ss << compute_graph->get_double(arg); + } else if (compute_graph->val_is_bool(arg)) { + ss << ", \"value\": "; + ss << compute_graph->get_bool(arg); + } else if (compute_graph->val_is_symint(arg)) { + ss << ", \"value\": "; + ss << compute_graph->read_symint(arg); + } + ss << "}"; + + return ss.str(); +} + +std::string make_operator_json( + ComputeGraph* const compute_graph, + std::string& op_name, + std::vector& args) { + std::stringstream ss; + ss << "\"name\": \"" << op_name << "\", \"args\": ["; + for (size_t i = 0; i < args.size(); ++i) { + ss << make_arg_json(compute_graph, args[i]); + if (i + 1 < args.size()) { + ss << ", "; + } + } + ss << "]"; + return ss.str(); +} + void ComputeGraph::print_readable() { std::set input_set; for (const IOValueRef& io_val : inputs()) { diff --git a/backends/vulkan/runtime/graph/Logging.h b/backends/vulkan/runtime/graph/Logging.h index fb2f66e2d6f..359ba649322 100644 --- a/backends/vulkan/runtime/graph/Logging.h +++ b/backends/vulkan/runtime/graph/Logging.h @@ -10,6 +10,8 @@ #include +#include + #include #include #include @@ -42,6 +44,8 @@ inline std::ostream& operator<<(std::ostream& os, const utils::ivec4& v) { return utils::operator<<(os, v); } +std::ostream& operator<<(std::ostream& os, const std::vector& sizes); + template inline std::ostream& operator<<(std::ostream& os, const std::optional& opt) { os << "["; @@ -52,4 +56,11 @@ inline std::ostream& operator<<(std::ostream& os, const std::optional& opt) { return os; } +std::string make_arg_json(ComputeGraph* const compute_graph, ValueRef arg); + +std::string make_operator_json( + ComputeGraph* const compute_graph, + std::string& op_name, + std::vector& args); + } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp index d1add8227de..ab48ec3f4c3 100644 --- a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp +++ b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp @@ -60,8 +60,21 @@ void DispatchNode::encode(ComputeGraph* graph) { write_push_constant_data(); +#ifdef ET_EVENT_TRACER_ENABLED + std::string event_name; + if (!operator_json.empty()) { + event_name += "\"operator\": {" + operator_json + "}, "; + } + event_name += "\"kernel_name\": \"" + shader_.kernel_name + "\", "; + event_name += "\"operator_id\": " + std::to_string(operator_count); +#endif + context->report_shader_dispatch_start( +#ifdef ET_EVENT_TRACER_ENABLED + event_name, +#else shader_.kernel_name, +#endif global_workgroup_size_, local_workgroup_size_, node_id_); diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp b/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp index 953f15e7b4d..a1a089c88e5 100644 --- a/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp +++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp @@ -14,16 +14,26 @@ ExecuteNode::ExecuteNode( const ResizeFunction& resize_fn, const std::vector& resize_args, const std::vector& args, - const std::string& name) + const std::string& name, + const bool has_data_dependent_shape) : resize_fn_(resize_fn), resize_args_(resize_args), args_(args), - name_(name) {} + name_(name), + has_data_dependent_shape_(has_data_dependent_shape) { +#ifdef ET_EVENT_TRACER_ENABLED + operator_json = set_and_get_current_operator_json(""); + operator_count = get_current_operator_count(); +#endif +} bool ExecuteNode::trigger_resize(ComputeGraph* graph) { - const bool any_arg_updated = was_any_arg_updated(graph); - if (resize_fn_ && 
any_arg_updated) { + bool any_arg_updated = was_any_arg_updated(graph); + if (resize_fn_ && + (any_arg_updated || graph->graphconfig().force_resize || + has_data_dependent_shape_)) { resize_fn_(graph, args_, resize_args_); + any_arg_updated = true; } return any_arg_updated; } diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.h b/backends/vulkan/runtime/graph/ops/ExecuteNode.h index 323036cef90..2084f075b5b 100644 --- a/backends/vulkan/runtime/graph/ops/ExecuteNode.h +++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.h @@ -57,7 +57,8 @@ class ExecuteNode { const ResizeFunction& resize_fn = nullptr, const std::vector& resize_args = {}, const std::vector& args = {}, - const std::string& name = "Graph Node"); + const std::string& name = "Graph Node", + const bool has_data_dependent_shape = false); virtual ~ExecuteNode() = default; @@ -87,6 +88,12 @@ class ExecuteNode { const std::vector resize_args_; const std::vector args_; const std::string name_; + bool has_data_dependent_shape_ = false; + +#ifdef ET_EVENT_TRACER_ENABLED + std::string operator_json; + size_t operator_count = 0; +#endif }; } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp index 62e1dc86f43..143203cdcd0 100644 --- a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp @@ -53,14 +53,21 @@ api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) { const std::vector packed_sizes = graph->sizes_of(packed_); size_t numel = utils::multiply_integers(packed_sizes); api::StagingBuffer staging( - graph->context(), graph->dtype_of(packed_), numel); + graph->context(), + graph->dtype_of(packed_), + numel, + vkapi::CopyDirection::HOST_TO_DEVICE); staging.set_staging_zeros(); return staging; } TensorRefPtr tref = graph->get_tref(tref_); size_t numel = utils::multiply_integers(tref->sizes); - api::StagingBuffer staging(graph->context(), tref->dtype, numel); + api::StagingBuffer staging( + graph->context(), + tref->dtype, + numel, + vkapi::CopyDirection::HOST_TO_DEVICE); graph->update_staging_nbytes_in_cmd(staging.buffer().mem_size_as_size_t()); size_t nbytes = numel * vkapi::element_size(tref->dtype); staging.copy_from(tref->data, nbytes); diff --git a/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.glsl b/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.glsl index bd210e210ce..3d5814eb6d0 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.glsl @@ -21,13 +21,17 @@ ${layout_declare_tensor(B, "r", "mat1_tensor", DTYPE, "texture3d")} ${layout_declare_tensor(B, "r", "mat2_tensor", DTYPE, "texture3d")} $if HAS_BIAS: ${layout_declare_tensor(B, "r", "bias_tensor", DTYPE, "texture3d")} -${layout_declare_ubo(B, "ivec4", "out_sizes")} -${layout_declare_ubo(B, "ivec3", "out_limits")} -${layout_declare_ubo(B, "ivec4", "mat1_sizes")} -${layout_declare_ubo(B, "ivec4", "mat2_sizes")} -$if HAS_BIAS: - ${layout_declare_ubo(B, "ivec4", "bias_sizes")} - ${layout_declare_ubo(B, "float", "alpha", "float", "beta")} + +layout(push_constant) uniform restrict Block { + ivec4 out_sizes; + ivec4 mat1_sizes; + ivec4 mat2_sizes; + ivec3 out_limits; + $if HAS_BIAS: + ivec4 bias_sizes; + float alpha; + float beta; +}; #include "indexing_utils.h" diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op_defs.glslh b/backends/vulkan/runtime/graph/ops/glsl/binary_op_defs.glslh new 
file mode 100644 index 00000000000..e2bdec703ca --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op_defs.glslh @@ -0,0 +1,56 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef BINARY_OP_DEFS_GLSLH +#define BINARY_OP_DEFS_GLSLH + +// +// Power operation that handles negative and zero bases +// +// In GLSL, pow(x, y) is undefined for x < 0. This function provides +// a safe implementation that: +// - Handles x == 0 (returns 0 for y > 0, returns 1 for y == 0) +// - Handles x < 0 by using absolute value and preserving sign for odd integer exponents +// - Uses standard pow() for x > 0 +// + +// Scalar overload +T power_of(T x, T y) { + if (x == 0.0) { + // Handle 0^y: 0^0 = 1, 0^y = 0 for y > 0 + return (y == 0.0) ? T(1.0) : T(0.0); + } + + // Use absolute value to avoid undefined behavior + float result = pow(abs(x), y); + + // For negative bases with odd integer exponents, preserve the negative sign + if (x < 0.0) { + float int_y = round(y); + if (abs(y - int_y) < 1e-5 && int(int_y) % 2 == 1) { + result = -result; + } + } + + return T(result); +} + +#ifdef VEC4_T + +// Vector overload +VEC4_T power_of(VEC4_T x, VEC4_T y) { + VEC4_T result; + for (int i = 0; i < 4; i++) { + result[i] = power_of(x[i], y[i]); + } + return result; +} + +#endif // VEC4_T + +#endif // BINARY_OP_DEFS_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_q8ta_q8ta_q8to.glsl b/backends/vulkan/runtime/graph/ops/glsl/binary_q8ta_q8ta_q8to.glsl new file mode 100644 index 00000000000..d0bd1809d11 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_q8ta_q8ta_q8to.glsl @@ -0,0 +1,76 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
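power_of() above pins down pow() behaviour that GLSL leaves undefined for zero and negative bases. A small host-side C++ sketch of the same rule, useful as a CPU reference when checking shader output; the tolerance and the odd-exponent test (% 2 != 0, so negative odd exponents also keep the sign) are choices made here for illustration:

#include <cmath>
#include <cstdio>

// CPU reference for the shader's power_of(): 0^0 == 1, 0^y == 0 for y > 0,
// and a negative base keeps its sign when the exponent is (nearly) an odd
// integer; otherwise pow(|x|, y) is used as-is.
static float power_of_ref(float x, float y) {
  if (x == 0.0f) {
    return (y == 0.0f) ? 1.0f : 0.0f;
  }
  float result = std::pow(std::fabs(x), y);
  if (x < 0.0f) {
    const float int_y = std::round(y);
    if (std::fabs(y - int_y) < 1e-5f &&
        static_cast<long>(int_y) % 2 != 0) {
      result = -result;
    }
  }
  return result;
}

int main() {
  std::printf("%g\n", power_of_ref(-2.0f, 3.0f));  // -8: odd exponent keeps the sign
  std::printf("%g\n", power_of_ref(-2.0f, 2.0f));  //  4: even exponent drops it
  std::printf("%g\n", power_of_ref(0.0f, 0.0f));   //  1 by convention
  return 0;
}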
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define NAME ${VARIANT_NAME} + +#define VEC4_T ${texel_load_type(DTYPE, "buffer")} +#define T ${texel_load_component_type(DTYPE, "buffer")} + +$if IO_STORAGE == "buffer": + #define PACKED_INT8_OUTPUT_BUFFER + #define PACKED_INT8_INPUT_BUFFER + +#define op(X, Y) ${OPERATOR} + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "indexing.glslh" +#include "common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_out", "int", IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_in_a", "int", IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_in_b", "int", IO_STORAGE, is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "out_sizes")} + +layout(push_constant) uniform restrict Block { + float input_a_scale; + int input_a_zp; + float input_b_scale; + int input_b_zp; + float output_inv_scale; + int output_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const int tid = int(gl_GlobalInvocationID.x); + + const int W4 = div_up_4(out_sizes.x); + const int H = out_sizes.y; + const int C4 = div_up_4(out_sizes.z); + const int N = out_sizes.w; + + if (tid >= W4 * H * C4 * N) { + return; + } + + const ivec4 in_block_1 = t_packed_int8_in_a[tid]; + const ivec4 in_block_2 = t_packed_int8_in_b[tid]; + + ivec4 out_block = ivec4(pack_into_int32(ivec4(output_zp))); + + for (int row = 0; row < 4; row++) { + vec4 in_texel_1 = unpack_and_dequantize( + in_block_1[row], input_a_scale, input_a_zp); + vec4 in_texel_2 = unpack_and_dequantize( + in_block_2[row], input_b_scale, input_b_zp); + + vec4 out_texel = op(in_texel_1, in_texel_2); + out_block[row] = quantize_and_pack(out_texel, output_inv_scale, output_zp); + } + + t_packed_int8_out[tid] = out_block; +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_q8ta_q8ta_q8to.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_q8ta_q8ta_q8to.yaml new file mode 100644 index 00000000000..e19ed8839eb --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_q8ta_q8ta_q8to.yaml @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +binary_q8ta_q8ta_q8to: + parameter_names_with_default_values: + OPERATOR: X + Y + NDIM: 3 + DTYPE: float + PACKING: C_packed + IO_STORAGE: buffer + generate_variant_forall: + IO_STORAGE: + - VALUE: buffer + shader_variants: + - NAME: add_q8ta_q8ta_q8to + OPERATOR: X + Y diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_buffer.glsl new file mode 100644 index 00000000000..4d58a5d2e24 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_buffer.glsl @@ -0,0 +1,47 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
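The add_q8ta_q8ta_q8to shader above works on packed int8 texels, but per element it is the familiar dequantize -> add -> requantize affine round trip. A scalar C++ sketch of that arithmetic, assuming the usual mapping real = (q - zp) * scale and q = round(real * inv_scale) + zp; the helper name and the clamping bounds are illustrative:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// One int8 element of the quantized add: dequantize both inputs, add in
// float, then requantize with the output's inverse scale and zero point.
static int8_t add_q8(
    int8_t a, float input_a_scale, int input_a_zp,
    int8_t b, float input_b_scale, int input_b_zp,
    float output_inv_scale, int output_zp) {
  const float fa = (static_cast<int>(a) - input_a_zp) * input_a_scale;
  const float fb = (static_cast<int>(b) - input_b_zp) * input_b_scale;
  const float fo = fa + fb;  // op(X, Y) == X + Y for the add variant
  const int q = static_cast<int>(std::round(fo * output_inv_scale)) + output_zp;
  return static_cast<int8_t>(std::clamp(q, -128, 127));
}

int main() {
  // (10 * 0.5) + (20 * 0.5) = 15.0, requantized with scale 0.5 (inv 2.0) -> 30
  std::printf("%d\n", add_q8(10, 0.5f, 0, 20, 0.5f, 0, 2.0f, 0));
  return 0;
}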
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define NAME ${VARIANT_NAME} + +#define T ${buffer_scalar_type(DTYPE)} + +#define op(X, Y) ${OPERATOR} + +${define_active_storage_type(STORAGE)} +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "indexing.glslh" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} + +${layout_declare_ubo(B, "BufferMetadata", "outp")} +${layout_declare_ubo(B, "BufferMetadata", "inp")} + +layout(push_constant) uniform restrict Block { + float scalar_value; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "binary_op_defs.glslh" + +void main() { + const uint out_bufi = gl_GlobalInvocationID.x; + if (out_of_bounds(out_bufi, outp)) { + return; + } + + t_out[out_bufi] = T(op(t_in[out_bufi], T(scalar_value))); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_buffer.yaml new file mode 100644 index 00000000000..b818132cf9b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_buffer.yaml @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +binary_scalar_buffer: + parameter_names_with_default_values: + OPERATOR: power_of(X, Y) + NDIM: 3 + DTYPE: float + PACKING: C_packed + STORAGE: buffer + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + - VALUE: int32 + shader_variants: + - NAME: pow_scalar_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.glsl new file mode 100644 index 00000000000..f02ddf35271 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.glsl @@ -0,0 +1,52 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define NAME ${VARIANT_NAME} + +#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} +#define T ${texel_load_component_type(DTYPE, STORAGE)} + +#define op(X, Y) ${OPERATOR} + +${define_active_storage_type(STORAGE)} +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "indexing.glslh" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} + +${layout_declare_ubo(B, "TextureMetadata", "outp")} +${layout_declare_ubo(B, "TextureMetadata", "inp")} + +layout(push_constant) uniform restrict Block { + float scalar_value; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "binary_op_defs.glslh" + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (out_of_bounds(pos, outp)) { + return; + } + + VEC4_T in_texel = texelFetch(t_in, pos, 0); + VEC4_T out_texel = VEC4_T(op(in_texel, VEC4_T(scalar_value))); + + imageStore(t_out, pos, out_texel); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.yaml new file mode 100644 index 00000000000..3e731bf7a15 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.yaml @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +binary_scalar_texture: + parameter_names_with_default_values: + OPERATOR: power_of(X, Y) + NDIM: 3 + DTYPE: float + PACKING: C_packed + STORAGE: texture3d + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + - VALUE: int32 + shader_variants: + - NAME: pow_scalar_texture3d diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.glsl b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.glsl index 6d164ae2645..f61081d33b7 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.glsl @@ -3,14 +3,16 @@ #define PRECISION ${PRECISION} #define T ${buffer_scalar_type(DTYPE)} +#define DST_T ${buffer_scalar_type(BUF_DTYPE)} ${define_required_extensions(DTYPE)} +${define_required_extensions(BUF_DTYPE)} layout(std430) buffer; #include "indexing.glslh" -${layout_declare_tensor(B, "w", "nchw_buf", DTYPE, STORAGE)} +${layout_declare_tensor(B, "w", "nchw_buf", BUF_DTYPE, STORAGE)} ${layout_declare_tensor(B, "r", "t_inp", DTYPE, STORAGE)} ${layout_declare_ubo(B, "BufferMetadata", "inp")} @@ -32,5 +34,5 @@ void main() { uint nchwi = tensor_idx_to_contiguous_idx(inp, inp_tidx); - nchw_buf[nchwi] = t_inp[inp_bufi]; + nchw_buf[nchwi] = DST_T(t_inp[inp_bufi]); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml index 929108cca5e..1ee7d2db8c1 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml @@ -7,15 +7,19 @@ buffer_to_nchw: parameter_names_with_default_values: DTYPE: float + BUF_DTYPE: float STORAGE: buffer USE_PUSH_CONST: True generate_variant_forall: - DTYPE: - - VALUE: half - - VALUE: float - - VALUE: double - - VALUE: int8 - - VALUE: uint8 - - VALUE: int32 + combination: + parameter_names: [DTYPE, BUF_DTYPE] + combos: + - parameter_values: [half, half] + - parameter_values: [half, float] + - 
parameter_values: [float, float] + - parameter_values: [double, double] + - parameter_values: [int8, int8] + - parameter_values: [uint8, uint8] + - parameter_values: [int32, int32] shader_variants: - NAME: buffer_to_nchw diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl deleted file mode 100644 index 7e21bcf0eba..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl +++ /dev/null @@ -1,400 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define IN_T ${buffer_scalar_type(IN_DTYPE)} -#define SCALE_OUT_T ${buffer_scalar_type(SCALE_OUT_DTYPE)} -#define ZP_OUT_T ${buffer_scalar_type(ZP_OUT_DTYPE)} - -#define ${MODE} - -${define_active_storage_type("buffer")} -${define_required_extensions(IN_DTYPE)} -${define_required_extensions(SCALE_OUT_DTYPE)} -${define_required_extensions(ZP_OUT_DTYPE)} - -#extension GL_EXT_control_flow_attributes : require - -layout(std430) buffer; - -${layout_declare_tensor(B, "w", "t_scale", SCALE_OUT_DTYPE, "buffer")} -${layout_declare_tensor(B, "w", "t_zero_point", ZP_OUT_DTYPE, "buffer")} -${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")} - -$if MODE == "per_tensor": - layout(push_constant) uniform restrict Block { - int quant_min; - int quant_max; - float eps; - }; -$if MODE == "per_token": - layout(push_constant) uniform restrict Block { - int num_tokens; - int quant_min; - int quant_max; - }; -$if MODE == "block_wise": - layout(push_constant) uniform BlockPC { - ivec4 blockSize; // WHCN (>=1) - ivec4 numBlocks; // #blocks along W,H,C,N - ivec4 blockStride; // {1, #W, #W * #H, #W * #H * #C} - int mapping_type; // 0=ASYM, 1=SYM, 2=SYM_NO_CLIP - int quant_min; - int quant_max; - float eps; - }; - -${layout_declare_ubo(B, "ivec4", "t_in_sizes")} -${layout_declare_ubo(B, "ivec4", "t_in_strides")} -${layout_declare_ubo(B, "ivec4", "t_scale_sizes")} -${layout_declare_ubo(B, "ivec4", "t_scale_strides")} -${layout_declare_ubo(B, "ivec4", "t_zero_point_sizes")} -${layout_declare_ubo(B, "ivec4", "t_zero_point_strides")} - -#include "indexing_utils.h" -#include "choose_qparams.glslh" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -#define NWORKERS 64 - -// Shared memory for reduction - must match local work group size -shared float shared_min[NWORKERS]; -shared float shared_max[NWORKERS]; - -/* - Quantization Parameter Computation Shader (Buffer Storage) - This shader computes quantization parameters (scale and zero_point) for converting - floating-point tensors to n-bit integer representations while preserving the - original data range as much as possible. The computed parameters enable efficient - quantization by mapping the continuous floating-point range to discrete integer values. - - Important Considerations: - (+) The input tensor is assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) - - Workgroup Configuration: - - choose_qparams_per_tensor - This mode computes a single set of quantization parameters for the entire tensor. - Uses parallel reduction across all threads to find global min/max values. 
- - (*) global_wg_size: {1, 1, 1} (single workgroup processes entire tensor) - (*) local_wg_size: {64, 1, 1} (matches NWORKERS for shared memory) - - - choose_qparams_per_token - This mode computes separate quantization parameters for each token in the tensor. - Each workgroup processes one token independently to find token-specific min/max. - - (*) global_wg_size: {num_tokens, 1, 1} (one workgroup per token) - (*) local_wg_size: {1, 1, 1} (single thread per token) - - - choose_qparams_block_wise - This mode computes quantization parameters for each block of elements, allowing - fine-grained control over quantization granularity within the tensor. Each block - is processed independently to find its own min/max values and compute corresponding - scale and zero_point parameters. - - (*) global_wg_size: {nBlocks, 1u, 1u} (one workgroup per block) - (*) local_wg_size: {1, 1, 1} (single thread per block) - - Block-wise quantization supports multiple mapping types for scale/zero_point calculation: - - - mapping_type = 0 (ASYMMETRIC): - Uses asymmetric quantization where the full floating-point range [min, max] is - mapped to the quantized range [quant_min, quant_max]. This preserves the original - data distribution but may not center zero optimally. - - Calculation: - scale = (max - min) / (quant_max - quant_min) - zero_point = quant_min - round(min / scale) - - Example: For range [-3.5, 10.2] mapping to int4 [-8, 7]: - scale = (10.2 - (-3.5)) / (7 - (-8)) = 13.7 / 15 = 0.913 - zero_point = -8 - round(-3.5 / 0.913) = -8 - (-4) = -4 - - - mapping_type = 1 (SYMMETRIC): - Uses symmetric quantization where the range is centered around zero. The scale - is computed based on the maximum absolute value, ensuring zero is exactly - representable in the quantized domain. - - Calculation: - max_abs = max(abs(min), abs(max)) - scale = max_abs / ((quant_max - quant_min) / 2) - zero_point = (quant_max + quant_min + 1) / 2 // midpoint - - Example: For range [-3.5, 10.2] mapping to int4 [-8, 7]: - max_abs = max(3.5, 10.2) = 10.2 - scale = 10.2 / ((7 - (-8)) / 2) = 10.2 / 7.5 = 1.36 - zero_point = (-8 + 7 + 1) / 2 = 0 - - - mapping_type = 2 (SYMMETRIC_NO_CLIPPING_ERR): - A variant of symmetric quantization that minimizes clipping errors by computing - separate scales for positive and negative ranges, then using the maximum. This - reduces quantization error on the dominant range while ensuring no values are - clipped. - - Calculation: - smin = abs(min) / abs(quant_min) // scale for negative range - smax = max / quant_max // scale for positive range - scale = max(smin, smax) // use larger scale to avoid clipping - zero_point = (quant_max + quant_min + 1) / 2 // midpoint - - Example: For range [-3.5, 10.2] mapping to int4 [-8, 7]: - smin = 3.5 / 8 = 0.4375 - smax = 10.2 / 7 = 1.457 - scale = max(0.4375, 1.457) = 1.457 // use smax to avoid clipping positives - zero_point = (-8 + 7 + 1) / 2 = 0 - - Tree Reduction Algorithm for Min/Max Finding: - The shader uses a parallel tree reduction algorithm to efficiently find minimum and - maximum values across multiple threads. This approach reduces the number of memory - accesses and synchronization points compared to sequential scanning. - - Example with 8 threads processing values [10, 1, 8, 1, 0, 2, 3, 5]: - - Step 1 - Initial Population: - Each thread loads its assigned value into shared memory arrays. 
- shared_min: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | - shared_max: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | - Thread ID: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | - - Step 2 - Stride 1 (Compare Adjacent Pairs): - Threads 0,2,4,6 compare with threads 1,3,5,7 respectively. - shared_min: | 1 | | 1 | | 0 | | 3 | | (min(10,1), min(8,1), min(0,2), min(3,5)) - shared_max: | 10 | | 8 | | 2 | | 5 | | (max(10,1), max(8,1), max(0,2), max(3,5)) - Active: | 0 | | 2 | | 4 | | 6 | | - - Step 3 - Stride 2 (Compare Pairs of Pairs): - Threads 0,4 compare with threads 2,6 respectively. - shared_min: | 1 | | | | 0 | | | | (min(1,1), min(0,3)) - shared_max: | 10 | | | | 5 | | | | (max(10,8), max(2,5)) - Active: | 0 | | | | 4 | | | | - - Step 4 - Stride 4 (Final Comparison): - Thread 0 compares with thread 4 to get final result. - shared_min: | 0 | | | | | | | | (min(1,0) = 0) - shared_max: | 10 | | | | | | | | (max(10,5) = 10) - Active: | 0 | | | | | | | | - - Final Result: global_min = 0, global_max = 10 (stored in shared_min[0], shared_max[0]) - - The tree reduction completes in log_2(N) steps where N is the number of threads, - providing O(log N) time complexity instead of O(N) for sequential reduction. - - Quantization Parameter Calculation: - Once min/max values are determined, the shader computes: - - scale = (max - min) / (quant_max - quant_min) - - zero_point = quantization offset to map floating-point zero to integer range - - Mode-Specific Behavior: - - Per-Tensor: Single workgroup with strided access across entire tensor - - Per-Token: Multiple workgroups, each processing one token independently - - Block-Wise: Each thread processes assigned blocks using nested loops over block dimensions -*/ - -#ifdef per_tensor - -void choose_qparams_per_tensor() { - uint global_id = gl_GlobalInvocationID.x; - uint local_id = gl_LocalInvocationID.x; - uint total_threads = gl_NumWorkGroups.x * gl_WorkGroupSize.x; - - uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w); - - // Each thread processes multiple elements with stride - float thread_min = 1.0/0.0; // +infinity - float thread_max = -1.0/0.0; // -infinity - bool found_valid = false; - - for (uint i = global_id; i < total_elements; i += total_threads) { - float val = t_in[i]; - if (!isnan(val) && !isinf(val)) { - if (!found_valid) { - thread_min = val; - thread_max = val; - found_valid = true; - } else { - thread_min = min(thread_min, val); - thread_max = max(thread_max, val); - } - } - } - - // Intra-group reduction using shared memory - shared_min[local_id] = thread_min; - shared_max[local_id] = thread_max; - barrier(); - - // Tree reduction within work group - for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { - if (local_id < stride) { - float other_min = shared_min[local_id + stride]; - float other_max = shared_max[local_id + stride]; - - if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { - shared_min[local_id] = other_min; - } - if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { - shared_max[local_id] = other_max; - } - } - barrier(); - } - - // Final result calculation (single workgroup only) - if (local_id == 0) { - float global_min = shared_min[0]; - float global_max = shared_max[0]; - - float scale_val; - int zero_point_val; - // Use default values: mapping_type=0 (ASYMMETRIC), eps from push constant - calc_scale_zp(global_min, global_max, quant_min, quant_max, 0, eps, scale_val, zero_point_val); - - t_scale[0] = 
SCALE_OUT_T(scale_val); - t_zero_point[0] = ZP_OUT_T(zero_point_val); - } -} - -#elif defined(per_token) - -void choose_qparams_per_token() { - uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w); - uint token_size = total_elements / uint(num_tokens); - - const uint TOTAL_TOKENS = uint(num_tokens); - - /* each invocation handles token-ids: id, id+STRIDE, id+2·STRIDE … */ - const uint STRIDE = gl_WorkGroupSize.x * gl_NumWorkGroups.x; - for (uint token_id = gl_GlobalInvocationID.x; token_id < TOTAL_TOKENS; token_id += STRIDE) { - // Calculate the start and end indices for this token - uint token_start = token_id * token_size; - uint token_end = token_start + token_size; - - // Each thread processes the entire token - float lo = 1.0/0.0; // +INF - float hi = -1.0/0.0; // -INF - bool found_valid = false; - - // Process all elements in this token - for (uint i = token_start; i < token_end; i++) { - float val = t_in[i]; - if (!isnan(val) && !isinf(val)) { - if (!found_valid) { - lo = hi = val; - found_valid = true; - } else { - lo = min(lo, val); - hi = max(hi, val); - } - } - } - - if (!found_valid) { - // If no valid values were found, use default values - lo = 0.0; - hi = 0.0; - } - - // Calculate scale and zero point directly - float scale_val; - int zero_point_val; - // Use default values: mapping_type=0 (ASYMMETRIC), eps=1e-5 - calc_scale_zp(lo, hi, quant_min, quant_max, 0, 1e-5, scale_val, zero_point_val); - - // Write results - t_scale[token_id] = SCALE_OUT_T(scale_val); - t_zero_point[token_id] = ZP_OUT_T(zero_point_val); - } -} - -#elif defined(block_wise) - -ivec4 block_id_to_coord(uint bid) { - ivec4 bc; - bc.w = int(bid) / blockStride.w; - - int r = int(bid) - bc.w * blockStride.w; - bc.z = r / blockStride.z; - - r -= bc.z * blockStride.z; - bc.y = r / blockStride.y; - - r -= bc.y * blockStride.y; - bc.x = r; - return bc; -} - -void choose_qparams_block_wise() { - const uint TOTAL_BLOCKS = uint(numBlocks.x * numBlocks.y * numBlocks.z * numBlocks.w); - - // each invocation handles block-ids: id, id+STRIDE, id+2·STRIDE - const uint STRIDE = gl_WorkGroupSize.x * gl_NumWorkGroups.x; - for (uint block_id = gl_GlobalInvocationID.x; block_id < TOTAL_BLOCKS; block_id += STRIDE) { - // block -> WHCN coordinate - ivec4 bc = block_id_to_coord(block_id); - ivec4 blockStart = bc * blockSize; // first element (inclusive) - ivec4 blockEnd = blockStart + blockSize; // last element (exclusive) - - // min / max scan over the block - float lo = 1.0/0.0; // +INF - float hi = -1.0/0.0; // -INF - bool found_valid = false; - - // Calculate actual block dimensions - ivec4 actualBlockSize = blockEnd - blockStart; - int blockElements = actualBlockSize.x * actualBlockSize.y * actualBlockSize.z * actualBlockSize.w; - - // Linear iteration over block elements - for (int elemIdx = 0; elemIdx < blockElements; ++elemIdx) { - // Convert linear index to 4D coordinates within block - int remaining = elemIdx; - int dn = remaining / (actualBlockSize.x * actualBlockSize.y * actualBlockSize.z); - remaining -= dn * (actualBlockSize.x * actualBlockSize.y * actualBlockSize.z); - int dc = remaining / (actualBlockSize.x * actualBlockSize.y); - remaining -= dc * (actualBlockSize.x * actualBlockSize.y); - int dh = remaining / actualBlockSize.x; - int dw = remaining - dh * actualBlockSize.x; - - ivec4 tidx = blockStart + ivec4(dw, dh, dc, dn); - uint idx = tidx_to_bufi(tidx, t_in_strides); - float v = t_in[idx]; - - if (!isnan(v) && !isinf(v)) { - if (!found_valid) { - lo = hi = v; - found_valid 
= true; - } else { - lo = min(lo, v); - hi = max(hi, v); - } - } - } - - // Handle the case where no valid values were found in the block - if (!found_valid) { - lo = 0.0; - hi = 0.0; - } - - float scale_val; - int zero_point_val; - calc_scale_zp(lo, hi, quant_min, quant_max, mapping_type, eps, scale_val, zero_point_val); - - t_scale[block_id] = SCALE_OUT_T(scale_val); - t_zero_point[block_id] = ZP_OUT_T(zero_point_val); - } -} - -#endif - -void main() { - choose_qparams_${MODE}(); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml deleted file mode 100644 index 8459b043baa..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml +++ /dev/null @@ -1,22 +0,0 @@ -choose_qparams_buffer: - parameter_names_with_default_values: - IN_DTYPE: float - SCALE_OUT_DTYPE: float - ZP_OUT_DTYPE: int32 - MODE: per_tensor - generate_variant_forall: - IN_DTYPE: - - VALUE: float - SCALE_OUT_DTYPE: - - VALUE: float - ZP_OUT_DTYPE: - - VALUE: int32 - - VALUE: int8 - - VALUE: float - shader_variants: - - NAME: choose_qparams_tensor_buffer - MODE: per_tensor - - NAME: choose_qparams_per_token_asymmetric_buffer - MODE: per_token - - NAME: choose_qparams_block_wise_buffer - MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl index 639fe312148..7234b50a3f5 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl @@ -19,6 +19,8 @@ #define MAX_THREADS 256 ${define_active_storage_type(STORAGE)} + +${define_required_extensions(DTYPE)} ${define_required_extensions("int8")} #extension GL_EXT_control_flow_attributes : require @@ -126,8 +128,8 @@ void find_min_max_for_row(const int output_y) { const int X4 = div_4(input_sizes.x); // Initialize thread-local min/max - float local_min = 1e30; - float local_max = -1e30; + T local_min = T(1e30); + T local_max = T(-1e30); // Each thread processes elements along their assigned output_id with stride // NUM_WORKERS_PER_OUTPUT @@ -187,7 +189,7 @@ void main() { calculate_scale_and_zero_point( local_min, local_max, quant_min, quant_max, scale, zero_point); - scales_out[i] = scale; + scales_out[i] = T(scale); zps_out[i] = zero_point; } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.yaml index 1594bb574bd..5dbf3d7adaa 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.yaml @@ -14,5 +14,6 @@ choose_qparams_per_row: - VALUE: buffer DTYPE: - VALUE: float + - VALUE: half shader_variants: - NAME: choose_qparams_per_row diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl deleted file mode 100644 index a17a3ae41dd..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl +++ /dev/null @@ -1,533 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
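For the choose_qparams family above, the asymmetric case boils down to scale = (max - min) / (quant_max - quant_min) and a zero point that shifts min onto quant_min. A host-side C++ sketch of that rule; extending the range to include zero, the eps floor, and the zero-point clamp are conventional extras assumed here, not read off the shaders:

#include <algorithm>
#include <cmath>
#include <cstdio>

// Asymmetric quantization parameters from an observed [min, max] range.
static void choose_qparams_asymmetric(
    float min_val, float max_val, int quant_min, int quant_max, float eps,
    float* scale, int* zero_point) {
  // Assumed extra step: keep 0.0f exactly representable.
  min_val = std::min(min_val, 0.0f);
  max_val = std::max(max_val, 0.0f);
  float s = (max_val - min_val) / static_cast<float>(quant_max - quant_min);
  s = std::max(s, eps);  // assumed eps floor so constant inputs don't yield scale 0
  const int zp = quant_min - static_cast<int>(std::round(min_val / s));
  *scale = s;
  *zero_point = std::clamp(zp, quant_min, quant_max);  // assumed clamp
}

int main() {
  float scale = 0.f;
  int zp = 0;
  // Matches the [-3.5, 10.2] -> int4 [-8, 7] example in the docs above:
  // scale ~= 0.913, zero_point == -4.
  choose_qparams_asymmetric(-3.5f, 10.2f, -8, 7, 1e-5f, &scale, &zp);
  std::printf("scale=%f zero_point=%d\n", scale, zp);
  return 0;
}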
- */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define IN_T ${buffer_scalar_type(IN_DTYPE)} -#define FVEC4_T ${texel_load_type(IN_DTYPE, "texture3d")} -#define SCALE_OUT_T ${buffer_scalar_type(SCALE_OUT_DTYPE)} -#define ZP_OUT_T ${buffer_scalar_type(ZP_OUT_DTYPE)} - -#define ${MODE} - -${define_active_storage_type("texture3d")} -${define_required_extensions(IN_DTYPE)} -${define_required_extensions(SCALE_OUT_DTYPE)} -${define_required_extensions(ZP_OUT_DTYPE)} - -#extension GL_EXT_control_flow_attributes : require - -layout(std430) buffer; - -$if MODE != "block_wise": - ${layout_declare_tensor(B, "w", "t_scale", SCALE_OUT_DTYPE, "texture3d")} - ${layout_declare_tensor(B, "w", "t_zero_point", ZP_OUT_DTYPE, "texture3d")} -$else: - ${layout_declare_tensor(B, "w", "t_scale", SCALE_OUT_DTYPE, "buffer")} - ${layout_declare_tensor(B, "w", "t_zero_point", ZP_OUT_DTYPE, "buffer")} - -${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} - -$if MODE == "per_tensor": - layout(push_constant) uniform restrict Block { - int quant_min; - int quant_max; - float eps; - }; -$if MODE == "per_token": - layout(push_constant) uniform restrict Block { - int num_tokens; - int quant_min; - int quant_max; - }; -$if MODE == "block_wise": - layout(push_constant) uniform BlockPC { - ivec4 blockSize; // WHCN (>=1) - ivec4 numBlocks; // #blocks along W,H,C,N - ivec4 blockStride; // {1, #W, #W * #H, #W * #H * #C} - int mapping_type; // 0=ASYM, 1=SYM, 2=SYM_NO_CLIP - int quant_min; - int quant_max; - float eps; - }; - -${layout_declare_ubo(B, "ivec3", "t_in_limits")} -$if MODE != "block_wise": - ${layout_declare_ubo(B, "ivec3", "t_scale_limits")} - ${layout_declare_ubo(B, "ivec3", "t_zero_point_limits")} -$else: - ${layout_declare_ubo(B, "ivec4", "t_scale_sizes")} - ${layout_declare_ubo(B, "ivec4", "t_scale_strides")} - ${layout_declare_ubo(B, "ivec4", "t_zero_point_sizes")} - ${layout_declare_ubo(B, "ivec4", "t_zero_point_strides")} - - -#include "indexing_utils.h" -#include "choose_qparams.glslh" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -#define NWORKERS 64 - -// Shared memory for reduction - must match local work group size -shared float shared_min[NWORKERS]; -shared float shared_max[NWORKERS]; - -/*/* - Quantization Parameter Computation Shader (Buffer Storage) - This shader computes quantization parameters (scale and zero_point) for converting - floating-point tensors to n-bit integer representations while preserving the - original data range as much as possible. The computed parameters enable efficient - quantization by mapping the continuous floating-point range to discrete integer values. - - Important Considerations: - (+) The input tensor is assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) - - Workgroup Configuration: - - choose_qparams_per_tensor - This mode computes a single set of quantization parameters for the entire tensor. - Uses parallel reduction across all threads to find global min/max values. - - (*) global_wg_size: default - (*) local_wg_size: default - - - choose_qparams_per_token - This mode computes separate quantization parameters for each token in the tensor. - Each workgroup processes one token independently to find token-specific min/max. - - (*) global_wg_size: default - (*) local_wg_size: {1, 1, 1} - - - choose_qparams_block_wise - This mode computes quantization parameters for each block of elements, allowing - fine-grained control over quantization granularity within the tensor. 
Each block - is processed independently to find its own min/max values and compute corresponding - scale and zero_point parameters. - - NOTE: This mode currently only supports buffer storage for the output. - - (*) global_wg_size: {nBlocks, 1u, 1u} (one workgroup per block) - (*) local_wg_size: {1, 1, 1} (single thread per block) - - Tree Reduction Algorithm for Min/Max Finding: - The shader uses a parallel tree reduction algorithm to efficiently find minimum and - maximum values across multiple threads. This approach reduces the number of memory - accesses and synchronization points compared to sequential scanning. - - Example with 8 threads processing values [10, 1, 8, 1, 0, 2, 3, 5]: - - Step 1 - Initial Population: - Each thread loads its assigned value into shared memory arrays. - shared_min: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | - shared_max: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | - Thread ID: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | - - Step 2 - Stride 1 (Compare Adjacent Pairs): - Threads 0,2,4,6 compare with threads 1,3,5,7 respectively. - shared_min: | 1 | | 1 | | 0 | | 3 | | (min(10,1), min(8,1), min(0,2), min(3,5)) - shared_max: | 10 | | 8 | | 2 | | 5 | | (max(10,1), max(8,1), max(0,2), max(3,5)) - Active: | 0 | | 2 | | 4 | | 6 | | - - Step 3 - Stride 2 (Compare Pairs of Pairs): - Threads 0,4 compare with threads 2,6 respectively. - shared_min: | 1 | | | | 0 | | | | (min(1,1), min(0,3)) - shared_max: | 10 | | | | 5 | | | | (max(10,8), max(2,5)) - Active: | 0 | | | | 4 | | | | - - Step 4 - Stride 4 (Final Comparison): - Thread 0 compares with thread 4 to get final result. - shared_min: | 0 | | | | | | | | (min(1,0) = 0) - shared_max: | 10 | | | | | | | | (max(10,5) = 10) - Active: | 0 | | | | | | | | - - Final Result: global_min = 0, global_max = 10 (stored in shared_min[0], shared_max[0]) - - The tree reduction completes in log_2(N) steps where N is the number of threads, - providing O(log N) time complexity instead of O(N) for sequential reduction. 
- - Quantization Parameter Calculation: - Once min/max values are determined, the shader computes: - - scale = (max - min) / (quant_max - quant_min) - - zero_point = quantization offset to map floating-point zero to integer range - - Mode-Specific Behavior: - - Per-Tensor: Single workgroup with strided access across entire tensor - - Per-Token: Multiple workgroups, each processing one token independently -*/ - -#ifdef per_tensor - -void choose_qparams_per_tensor() { - uint global_id = gl_GlobalInvocationID.x; - uint local_id = gl_LocalInvocationID.x; - uint group_id = gl_WorkGroupID.x; - uint total_threads = gl_NumWorkGroups.x * gl_WorkGroupSize.x; - - uint total_texels = uint(t_in_limits.x * t_in_limits.y * t_in_limits.z); - - // Each thread processes multiple texels with stride - float thread_min = 1.0/0.0; // +infinity - float thread_max = -1.0/0.0; // -infinity - bool found_valid = false; - - // Process texels with stride across all threads - for (uint texel_idx = global_id; texel_idx < total_texels; texel_idx += total_threads) { - // Convert linear texel index to 3D coordinates - uint z = texel_idx / uint(t_in_limits.x * t_in_limits.y); - uint remainder = texel_idx % uint(t_in_limits.x * t_in_limits.y); - uint y = remainder / uint(t_in_limits.x); - uint x = remainder % uint(t_in_limits.x); - ivec3 texel_pos = ivec3(int(x), int(y), int(z)); - - FVEC4_T texel_data = load_texel(t_in, texel_pos); - - // For texture storage, we assume width-packed (packed_dim = 0) - // Calculate number of valid elements in this texel (handle padding) - int packed_dim = 0; // Width dimension is packed - ivec4 sizes = ivec4(t_in_limits, 1); // Convert limits to sizes format - ivec4 tensor_coord = to_tensor_idx(texel_pos, sizes, packed_dim); - - // Calculate total tensor elements to determine padding - int total_elements = t_in_limits.x * t_in_limits.y * t_in_limits.z * 4; - int linear_tensor_idx = tensor_coord.x + tensor_coord.y * sizes.x + - tensor_coord.z * sizes.x * sizes.y; - int remaining_elements = total_elements - (linear_tensor_idx); - int valid_elements = min(4, remaining_elements); - - // Find min/max within this texel, considering only valid elements - if (valid_elements >= 1 && !isnan(texel_data.x) && !isinf(texel_data.x)) { - if (!found_valid) { - thread_min = texel_data.x; - thread_max = texel_data.x; - found_valid = true; - } else { - thread_min = min(thread_min, texel_data.x); - thread_max = max(thread_max, texel_data.x); - } - } - - if (valid_elements >= 2 && !isnan(texel_data.y) && !isinf(texel_data.y)) { - if (!found_valid) { - thread_min = texel_data.y; - thread_max = texel_data.y; - found_valid = true; - } else { - thread_min = min(thread_min, texel_data.y); - thread_max = max(thread_max, texel_data.y); - } - } - - if (valid_elements >= 3 && !isnan(texel_data.z) && !isinf(texel_data.z)) { - if (!found_valid) { - thread_min = texel_data.z; - thread_max = texel_data.z; - found_valid = true; - } else { - thread_min = min(thread_min, texel_data.z); - thread_max = max(thread_max, texel_data.z); - } - } - - if (valid_elements >= 4 && !isnan(texel_data.w) && !isinf(texel_data.w)) { - if (!found_valid) { - thread_min = texel_data.w; - thread_max = texel_data.w; - found_valid = true; - } else { - thread_min = min(thread_min, texel_data.w); - thread_max = max(thread_max, texel_data.w); - } - } - } - - // Intra-workgroup reduction using shared memory - shared_min[local_id] = thread_min; - shared_max[local_id] = thread_max; - barrier(); - - // Tree reduction within work group - for (uint stride = 
gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { - if (local_id < stride) { - float other_min = shared_min[local_id + stride]; - float other_max = shared_max[local_id + stride]; - - if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { - shared_min[local_id] = other_min; - } - if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { - shared_max[local_id] = other_max; - } - } - barrier(); - } - - // Final result calculation (single workgroup only for reliability) - if (local_id == 0 && group_id == 0) { - float global_min = shared_min[0]; - float global_max = shared_max[0]; - - float scale_val; - int zero_point_val; - calc_scale_zp(global_min, global_max, quant_min, quant_max, 0, eps, scale_val, zero_point_val); - - write_texel(t_scale, ivec3(0, 0, 0), vec4(SCALE_OUT_T(scale_val), 0.0, 0.0, 0.0)); - write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(ZP_OUT_T(zero_point_val), 0, 0, 0)); - } -} - -#elif defined(per_token) - -void choose_qparams_per_token() { - // Each token is processed by multiple workgroups for parallel reduction - uint local_id = gl_LocalInvocationID.x; - uint group_id = gl_WorkGroupID.x; - uint total_workgroups = gl_NumWorkGroups.x; - - uint total_texels = uint(t_in_limits.x * t_in_limits.y * t_in_limits.z); - - // Calculate texels per token (assuming last dimension contains the token data) - // For per-token quantization, we assume tokens are along the last dimension - uint texels_per_token = total_texels / uint(num_tokens); - - // Calculate how many tokens each workgroup should process - uint tokens_per_workgroup = (uint(num_tokens) + total_workgroups - 1) / total_workgroups; - - // Calculate which tokens this workgroup is responsible for - uint start_token = group_id * tokens_per_workgroup; - uint end_token = min(start_token + tokens_per_workgroup, uint(num_tokens)); - - // Process each token assigned to this workgroup - for (uint token_id = start_token; token_id < end_token; token_id++) { - // Calculate the texel range for this token - uint token_start_texel = token_id * texels_per_token; - uint token_end_texel = token_start_texel + texels_per_token; - - // Each thread processes multiple texels within the token - float thread_min = 1.0/0.0; // +infinity - float thread_max = -1.0/0.0; // -infinity - bool found_valid = false; - - // Process texels within this token only - for (uint texel_idx = token_start_texel + local_id; texel_idx < token_end_texel; texel_idx += gl_WorkGroupSize.x) { - // Convert linear texel index to 3D coordinates - uint z = texel_idx / uint(t_in_limits.x * t_in_limits.y); - uint remainder = texel_idx % uint(t_in_limits.x * t_in_limits.y); - uint y = remainder / uint(t_in_limits.x); - uint x = remainder % uint(t_in_limits.x); - ivec3 texel_pos = ivec3(int(x), int(y), int(z)); - - FVEC4_T texel_data = load_texel(t_in, texel_pos); - - // For texture storage, we assume width-packed (packed_dim = 0) - // Calculate number of valid elements in this texel (handle padding) - int packed_dim = 0; // Width dimension is packed - ivec4 sizes = ivec4(t_in_limits, 1); // Convert limits to sizes format - ivec4 tensor_coord = to_tensor_idx(texel_pos, sizes, packed_dim); - - // Calculate total tensor elements to determine padding - int total_elements = t_in_limits.x * t_in_limits.y * t_in_limits.z * 4; - int linear_tensor_idx = tensor_coord.x + tensor_coord.y * sizes.x + - tensor_coord.z * sizes.x * sizes.y; - int remaining_elements = total_elements - (linear_tensor_idx); - int 
valid_elements = min(4, remaining_elements); - - // Find min/max within this texel, considering only valid elements - if (valid_elements >= 1 && !isnan(texel_data.x) && !isinf(texel_data.x)) { - if (!found_valid) { - thread_min = texel_data.x; - thread_max = texel_data.x; - found_valid = true; - } else { - thread_min = min(thread_min, texel_data.x); - thread_max = max(thread_max, texel_data.x); - } - } - - if (valid_elements >= 2 && !isnan(texel_data.y) && !isinf(texel_data.y)) { - if (!found_valid) { - thread_min = texel_data.y; - thread_max = texel_data.y; - found_valid = true; - } else { - thread_min = min(thread_min, texel_data.y); - thread_max = max(thread_max, texel_data.y); - } - } - - if (valid_elements >= 3 && !isnan(texel_data.z) && !isinf(texel_data.z)) { - if (!found_valid) { - thread_min = texel_data.z; - thread_max = texel_data.z; - found_valid = true; - } else { - thread_min = min(thread_min, texel_data.z); - thread_max = max(thread_max, texel_data.z); - } - } - - if (valid_elements >= 4 && !isnan(texel_data.w) && !isinf(texel_data.w)) { - if (!found_valid) { - thread_min = texel_data.w; - thread_max = texel_data.w; - found_valid = true; - } else { - thread_min = min(thread_min, texel_data.w); - thread_max = max(thread_max, texel_data.w); - } - } - } - - // Intra-workgroup reduction using shared memory - shared_min[local_id] = thread_min; - shared_max[local_id] = thread_max; - barrier(); - - // Tree reduction within work group - for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { - if (local_id < stride) { - float other_min = shared_min[local_id + stride]; - float other_max = shared_max[local_id + stride]; - - // Handle infinity values properly - if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { - shared_min[local_id] = other_min; - } - if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { - shared_max[local_id] = other_max; - } - } - barrier(); - } - - // Final calculation for this token - if (local_id == 0) { - float token_min = shared_min[0]; - float token_max = shared_max[0]; - - float scale_val; - int zero_point_val; - calc_scale_zp(token_min, token_max, quant_min, quant_max, 0, 1e-5, scale_val, zero_point_val); - - // Convert token_id to 3D coordinates for output texture - // Assuming output tensors have the same layout as input but with different dimensions - uint out_z = token_id / uint(t_scale_limits.x * t_scale_limits.y); - uint out_remainder = token_id % uint(t_scale_limits.x * t_scale_limits.y); - uint out_y = out_remainder / uint(t_scale_limits.x); - uint out_x = out_remainder % uint(t_scale_limits.x); - ivec3 out_pos = ivec3(int(out_x), int(out_y), int(out_z)); - - write_texel(t_scale, out_pos, vec4(SCALE_OUT_T(scale_val), 0.0, 0.0, 0.0)); - write_texel(t_zero_point, out_pos, ivec4(ZP_OUT_T(zero_point_val), 0, 0, 0)); - } - - // Synchronize before processing next token - barrier(); - } -} - -#elif defined(block_wise) - -ivec4 block_id_to_coord(uint bid) { - ivec4 bc; - bc.w = int(bid) / blockStride.w; - - int r = int(bid) - bc.w * blockStride.w; - bc.z = r / blockStride.z; - - r -= bc.z * blockStride.z; - bc.y = r / blockStride.y; - - r -= bc.y * blockStride.y; - bc.x = r; - return bc; -} - -void choose_qparams_block_wise() { - const uint T = uint(numBlocks.x * numBlocks.y * numBlocks.z * numBlocks.w); - const uint STRIDE = gl_WorkGroupSize.x * gl_NumWorkGroups.x; - - // tensor full size in WHCN order - const ivec4 tensorSz = blockSize * numBlocks; - - // 
Process blocks with stride for better parallelization - for (uint blkIdx = gl_GlobalInvocationID.x; blkIdx < T; blkIdx += STRIDE) { - // block index in WHCN - const ivec4 b4d = block_id_to_coord(blkIdx); - const ivec4 blockStart = b4d * blockSize; - const ivec4 blockEnd = blockStart + blockSize; - - // scan all elements inside the block - float vmin = 3.402823e38; // +FLT_MAX - float vmax = -3.402823e38; // -FLT_MAX - bool found_valid = false; - - // Calculate total elements in block for linear iteration - const int blockElements = blockSize.x * blockSize.y * blockSize.z * blockSize.w; - - // Linear iteration over block elements (more cache-friendly) - for (int elemIdx = 0; elemIdx < blockElements; ++elemIdx) { - // Convert linear index to 4D coordinates within block - int remaining = elemIdx; - int dn = remaining / (blockSize.x * blockSize.y * blockSize.z); - remaining -= dn * (blockSize.x * blockSize.y * blockSize.z); - int dc = remaining / (blockSize.x * blockSize.y); - remaining -= dc * (blockSize.x * blockSize.y); - int dh = remaining / blockSize.x; - int dw = remaining - dh * blockSize.x; - - ivec4 tidx = blockStart + ivec4(dw, dh, dc, dn); - - // skip padding when tensor size is not an exact multiple of block - if (any(greaterThanEqual(tidx, tensorSz))) { continue; } - - // tensor index -> (x,y,z,component) inside input texture - ivec4 posi = to_texture_elem_pos(tidx, tensorSz, 0); // 0 = W_DIM (width packed) - - // fetch texel and pick the element inside it - FVEC4_T texl = load_texel(t_in, posi.xyz); - float v; - if (posi.w == 0) v = texl.x; - else if (posi.w == 1) v = texl.y; - else if (posi.w == 2) v = texl.z; - else v = texl.w; - - if (!isnan(v) && !isinf(v)) { - if (!found_valid) { - vmin = vmax = v; - found_valid = true; - } else { - vmin = min(vmin, v); - vmax = max(vmax, v); - } - } - } - - // Handle case where no valid values were found - if (!found_valid) { - vmin = 0.0; - vmax = 0.0; - } - - // compute scale / zero‑point (same maths as buffer kernel) - float scale; - int zp; - calc_scale_zp(vmin, vmax, quant_min, quant_max, mapping_type, eps, scale, zp); - - // Write the scalar values directly to buffer using linear index - t_scale[blkIdx] = SCALE_OUT_T(scale); - t_zero_point[blkIdx] = ZP_OUT_T(zp); - } -} - -#endif - -void main() { - choose_qparams_${MODE}(); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml deleted file mode 100644 index 12228822d4b..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml +++ /dev/null @@ -1,22 +0,0 @@ -choose_qparams_texture: - parameter_names_with_default_values: - IN_DTYPE: float - SCALE_OUT_DTYPE: float - ZP_OUT_DTYPE: int32 - MODE: per_tensor - generate_variant_forall: - IN_DTYPE: - - VALUE: float - SCALE_OUT_DTYPE: - - VALUE: float - ZP_OUT_DTYPE: - - VALUE: int32 - - VALUE: int8 - - VALUE: float - shader_variants: - - NAME: choose_qparams_tensor_texture3d - MODE: per_tensor - - NAME: choose_qparams_per_token_asymmetric_texture3d - MODE: per_token - - NAME: choose_qparams_block_wise_texture3d - MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/clone.glsl b/backends/vulkan/runtime/graph/ops/glsl/clone.glsl index 3bd1af8bb0c..e7f18526d06 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/clone.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/clone.glsl @@ -16,7 +16,10 @@ layout(std430) buffer; ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} ${layout_declare_tensor(B, "r", 
"t_in", DTYPE, STORAGE)} -${layout_declare_ubo(B, "ivec3", "out_limits")} + +layout(push_constant) uniform restrict Block { + ivec3 out_limits; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; diff --git a/backends/vulkan/runtime/graph/ops/glsl/clone.yaml b/backends/vulkan/runtime/graph/ops/glsl/clone.yaml index 1fdbf506bfd..a85d201046e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/clone.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/clone.yaml @@ -7,5 +7,7 @@ clone: DTYPE: - VALUE: half - VALUE: float + - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: clone diff --git a/backends/vulkan/runtime/graph/ops/glsl/common.glslh b/backends/vulkan/runtime/graph/ops/glsl/common.glslh index 732b7006c2c..9ade64910f2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/common.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/common.glslh @@ -9,6 +9,10 @@ #ifndef COMMON_GLSLH #define COMMON_GLSLH +#ifdef DEBUG_MODE +#extension GL_EXT_debug_printf : enable +#endif + #define mul_2(x) ((x) << 1) #define mul_4(x) ((x) << 2) #define mul_8(x) ((x) << 3) @@ -29,21 +33,66 @@ #define mod_4(x) ((x) & 3) #define mod_8(x) ((x) & 7) -struct TensorIndex4D { - ivec4 data; -}; +int sign_extend_8bit(const int val) { + if ((val & 0x80) != 0) { + return val | (~0xFF); + } + return val; +} + +int extract_8bit_from_packed_int_le(const int packed, const int i) { + // account for little endian + int byte = sign_extend_8bit(packed >> (8 * i) & 0xFF); + return byte; +} + +ivec4 unpack_int8x4(const int packed) { + return ivec4( + extract_8bit_from_packed_int_le(packed, 0), + extract_8bit_from_packed_int_le(packed, 1), + extract_8bit_from_packed_int_le(packed, 2), + extract_8bit_from_packed_int_le(packed, 3)); +} + +int pack_4xqint_into_int32( + const int val0, + const int val1, + const int val2, + const int val3) { + int packed = (val0 & 0xFF) | ((val1 & 0xFF) << 8) | ((val2 & 0xFF) << 16) | + ((val3 & 0xFF) << 24); + + return packed; +} + +int pack_into_int32(const ivec4 quant_vals) { + int packed = ((quant_vals[0] & 0xFF) << 0) | ((quant_vals[1] & 0xFF) << 8) | + ((quant_vals[2] & 0xFF) << 16) | ((quant_vals[3] & 0xFF) << 24); + + return packed; +} + +vec4 unpack_and_dequantize( + const int packed_int8_vals, + const float scale, + const int zp) { + ivec4 unpacked = unpack_int8x4(packed_int8_vals); + return vec4(unpacked - zp) * scale; +} + +int quantize_and_pack(const vec4 vals, const float inv_scale, const int zp) { + ivec4 quantized = ivec4(round(vals * inv_scale) + zp); + quantized = clamp(quantized, -128, 127); + return pack_into_int32(quantized); +} #ifdef DEBUG_MODE -#extension GL_EXT_debug_printf : require +#define printf debugPrintfEXT -void printTensorIndex4D(const TensorIndex4D index) { +void printVec4(vec4 texel) { debugPrintfEXT( - "tensor_idx: %d, %d, %d, %d\\n", - index.data.x, - index.data.y, - index.data.z, - index.data.w); + "texel: %f, %f, %f, %f\\n", texel.x, texel.y, texel.z, texel.w); } #endif // DEBUG_MODE diff --git a/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.yaml index 39f96df5e90..36d0b879bdd 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.yaml @@ -6,6 +6,7 @@ concat_buffer: DTYPE: - VALUE: half - VALUE: float + - VALUE: int32 shader_variants: - NAME: concat_1_buffer NUM_INPUTS: 1 diff --git a/backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl 
b/backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl index afab0c524d6..0611defa4c3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl @@ -113,8 +113,6 @@ void main() { VEC4_T out_texel = imageLoad(t_out, out_pos); - VEC4_T test_texel = VEC4_T(-1.0); - for (int comp = 0; comp < 4; ++comp) { ivec4 out_tidx = out_read_start_tidx; out_tidx[out_packed_dim] += comp; @@ -124,7 +122,6 @@ void main() { // of the previous input batch; if so, then don't overwrite this texel // element if (out_tidx[concat_dim] < concat_offset) { - test_texel[comp] = -5.0; continue; } @@ -164,7 +161,6 @@ void main() { inp${i}_packed_dim); out_texel[comp] = texelFetch(t_inp${i}, in_posi.xyz, 0)[in_posi.w]; - test_texel[comp] = out_texel[comp]; continue; } else { diff --git a/backends/vulkan/runtime/graph/ops/glsl/concat_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/concat_texture.yaml index ed5003382a1..d3de77d8ea9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/concat_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/concat_texture.yaml @@ -6,6 +6,7 @@ concat_texture: DTYPE: - VALUE: half - VALUE: float + - VALUE: int32 shader_variants: - NAME: concat_1_texture3d NUM_INPUTS: 1 diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl index 0f5dbc41273..88746c5594e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl @@ -60,7 +60,7 @@ void main() { int num_steps = ((-ipos.y) + dilation.y - 1) / dilation.y; start.y = ipos.y + num_steps * dilation.y; } - const ivec2 end = min(ipos + overlay_region.xy, ivec2(in_sizes.xy)); + const ivec2 end = min(ipos + overlay_region.xy, in_sizes.xy); // Compute the start of the kernel based on how far we are skipping ahead when // reading the input. Note that these are "canonical" indices. 
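The int8 helpers added to common.glslh above (sign_extend_8bit, unpack_int8x4, pack_into_int32, quantize_and_pack) keep four signed 8-bit lanes in one int32, lane i in byte i, little endian. Below is a small C++ mirror of that bit layout, usable as a host-side round-trip check; the function names follow the shader, everything else is illustrative:

#include <array>
#include <cstdint>
#include <cstdio>

// Sign-extend the low 8 bits of val, as sign_extend_8bit does in the shader.
int32_t sign_extend_8bit(int32_t val) {
  return (val & 0x80) ? (val | ~0xFF) : val;
}

// Mirror of unpack_int8x4: lane i lives in byte i (little endian).
std::array<int32_t, 4> unpack_int8x4(int32_t packed) {
  std::array<int32_t, 4> out;
  for (int i = 0; i < 4; ++i) {
    out[i] = sign_extend_8bit((packed >> (8 * i)) & 0xFF);
  }
  return out;
}

// Mirror of pack_into_int32: mask each lane to 8 bits and shift into place.
int32_t pack_int8x4(const std::array<int32_t, 4>& v) {
  uint32_t packed = 0;
  for (int i = 0; i < 4; ++i) {
    packed |= (static_cast<uint32_t>(v[i]) & 0xFFu) << (8 * i);
  }
  return static_cast<int32_t>(packed);
}

int main() {
  const std::array<int32_t, 4> lanes = {-128, -1, 0, 127};
  const auto round_trip = unpack_int8x4(pack_int8x4(lanes));
  std::printf("%d %d %d %d\n", round_trip[0], round_trip[1], round_trip[2],
              round_trip[3]);  // -128 -1 0 127
  return 0;
}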
ivec2 kstart = (start - ipos) / dilation; diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh index 41825cba867..6f460d1398c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh @@ -27,6 +27,60 @@ struct Conv2DParams { int K4; }; +struct Conv2dTensorIndex { + ivec3 data; + int texel_i; +}; + +struct Conv2dBlockIndex { + ivec3 data; +}; + +Conv2dTensorIndex block_idx_to_tensor_idx(const Conv2dBlockIndex block_idx) { + Conv2dTensorIndex tensor_idx; + tensor_idx.data.x = mul_4(block_idx.data.x); + tensor_idx.data.y = block_idx.data.y; + tensor_idx.data.z = block_idx.data.z; + tensor_idx.texel_i = 0; + return tensor_idx; +} + +struct Conv2dBlockExtents { + ivec3 data; + int data_xz; +}; + +Conv2dBlockExtents make_block_extents(const ivec4 tensor_sizes) { + Conv2dBlockExtents block_sizes; + block_sizes.data.x = div_up_4(tensor_sizes.x); + block_sizes.data.y = tensor_sizes.y; + block_sizes.data.z = div_up_4(tensor_sizes.z); + + block_sizes.data_xz = block_sizes.data.x * block_sizes.data.z; + + return block_sizes; +} + +Conv2dBlockIndex linear_idx_to_block_idx( + const int idx, const Conv2dBlockExtents block_extents) { + Conv2dBlockIndex block_idx; + block_idx.data.z = idx % block_extents.data.z; + + const int row = idx / block_extents.data.z; + block_idx.data.x = row % block_extents.data.x; + block_idx.data.y = row / block_extents.data.x; + + return block_idx; +} + +bool block_idx_out_of_bounds( + const Conv2dBlockIndex block_idx, + const Conv2dBlockExtents block_extents) { + return block_idx.data.x >= block_extents.data.x || + block_idx.data.y >= block_extents.data.y || + block_idx.data.z >= block_extents.data.z; +} + #ifdef DEBUG_MODE void printConv2DParams(const Conv2DParams params) { diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl index 02fbef29b75..9089f87d658 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl @@ -54,7 +54,7 @@ void main() { // Compute the start and end of the input indices to load. Padding is assumed // to be constant 0 padding, so reads from the padding region are skipped. 
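A host-side C++ mirror of the Conv2dBlockExtents helpers introduced in conv2d_common.glslh above, assuming the same packed layout (width and channels grouped into blocks of 4, channel-block index varying fastest within a row, then width block, then row). The sizes used in main() are made up purely for illustration:

#include <cassert>
#include <cstdio>

struct BlockExtents {
  int x, y, z;  // width blocks, rows, channel blocks
  int xz;       // x * z, blocks per row
};

int div_up_4(int n) { return (n + 3) / 4; }

// Mirror of make_block_extents: width and channels are grouped by 4.
BlockExtents make_block_extents(int W, int H, int C) {
  BlockExtents e;
  e.x = div_up_4(W);
  e.y = H;
  e.z = div_up_4(C);
  e.xz = e.x * e.z;
  return e;
}

// Mirror of linear_idx_to_block_idx: z varies fastest, then x, then y.
void linear_idx_to_block_idx(int idx, const BlockExtents& e,
                             int& bx, int& by, int& bz) {
  bz = idx % e.z;
  const int row = idx / e.z;
  bx = row % e.x;
  by = row / e.x;
}

int main() {
  const BlockExtents e = make_block_extents(/*W=*/10, /*H=*/3, /*C=*/6);
  // 10 -> 3 width blocks, 6 -> 2 channel blocks, 3 * 2 = 6 blocks per row.
  assert(e.x == 3 && e.z == 2 && e.xz == 6);

  int bx, by, bz;
  linear_idx_to_block_idx(/*idx=*/7, e, bx, by, bz);
  // idx 7: bz = 7 % 2 = 1, row = 3, bx = 3 % 3 = 0, by = 3 / 3 = 1.
  std::printf("block = (%d, %d, %d)\n", bx, by, bz);  // (0, 1, 1)
  return 0;
}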
const ivec2 start = ipos; - const ivec2 end = ipos + overlay_region.xy; + const ivec2 end = min(ipos + overlay_region.xy, in_sizes.xy); VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0); int kx = 0; diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl index 19250419baf..7448b042cad 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl @@ -97,6 +97,10 @@ void main() { for (int y = start.y, i = 0; i < TILE_SIZE + BATCH_SIZE_Y - 1; y += dilation.y, i++) { for (int x = start.x, j = 0; j < TILE_SIZE + BATCH_SIZE_X - 1; x += dilation.x, j++) { in_texels[j] = texelFetch(t_in, ivec3(x, y, pos.z), 0); + // Set to zero if reading out of bounds + if (any(greaterThanEqual(ivec2(x, y), in_sizes.xy))) { + in_texels[j] = VEC4_T(0); + } } // from 2nd iteration onwards accumulate dot product in 2nd sum diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8_utils.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8_utils.glslh new file mode 100644 index 00000000000..836a138f6bc --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8_utils.glslh @@ -0,0 +1,145 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef CONV2D_DW_Q8_UTILS_GLSLH +#define CONV2D_DW_Q8_UTILS_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +vec4 dequantize(const int packed_texel, const float scale, const int zp) { + return vec4(unpack_int8x4(packed_texel) - zp) * scale; +} + +vec4 dequantize(const int packed_texel, const vec4 scales) { + return vec4(unpack_int8x4(packed_texel)) * scales; +} + +bool in_bounds( + const int block_w, + const int block_h, + const int block_c4, + const Conv2dBlockExtents block_extents) { + ivec3 idx = ivec3(block_w, block_h, block_c4); + if (any(lessThan(idx, ivec3(0)))) { + return false; + } + if (any(greaterThanEqual(idx, block_extents.data))) { + return false; + } + + return true; +} + +struct FPOutBlock { + vec4[4] data; +}; + +ivec4 quantize( + const vec4 texel, const float inv_scale, const int zp) { + vec4 quantized = round(texel * inv_scale) + zp; + return clamp(ivec4(quantized), -128, 127); +} + +ivec4 quantize_and_pack( + FPOutBlock out_block, const float inv_scale, const int zp) { + ivec4 packed_block; + for (int row = 0; row < 4; ++row) { + ivec4 quantized_texel = quantize(out_block.data[row], inv_scale, zp); + packed_block[row] = pack_into_int32(quantized_texel); + } + return packed_block; +} + +// Load a 4xint8 block of weights. Equivalent to unpacked_weights[kh][kw][c:c+4]. +int load_weight_1w4c( + int kw, // w coordinate + int kh, // h coordinate + int oc4, // channel block + int KW4, // kernel width / 4 (rounded up) + int OC4 // out channels count / 4 (rounded up) + ) { + + // Find the packed block index. Weights are packed as 4W4C tiles. + int kw4 = kw / 4; // W block + int linear_idx = ((kh * KW4 + kw4) * OC4 + oc4) * 4; + int block_x_offset = kw % 4; +#ifdef WEIGHT_BUFFER + return t_packed_int8_weight[linear_idx + block_x_offset]; +#else + return texelFetch(t_packed_int8_weight, ivec2(oc4, kh * KW4 + kw4), 0)[block_x_offset]; +#endif +} + +// Load a 4xint8 block of inputs - channel c through c+3 (c = oc4*4) at +// the given spatial location. 
Equivalent to unpacked_input[0][c:c+4][h][w]. +int load_input_1w4c( + int w, // w coordinate + int h, // h coordinate + int oc4, // channel block + int OC4, // out channels / 4 (rounded up) + Conv2dBlockExtents block_extents +) { + int block_w = w / 4; + + if (in_bounds(block_w, h, oc4, block_extents) && w >= 0) { +#ifdef PACKED_INT8_INPUT_BUFFER + const int buffer_idx = + (h * block_extents.data_xz + block_w * block_extents.data.z + oc4) * 4 + (w % 4); + return t_packed_int8_input[buffer_idx]; +#else + #error Unimplemented +#endif + } else { + return pack_into_int32(ivec4(input_zp)); + } +} + +#ifdef DEBUG_MODE + +void printInputWindow1D(const InputWindow1D input_window) { + debugPrintfEXT("InputWindow1D contents (len = %d): \\n", input_window.len); + for (int i = 0; i < min(input_window.len, MAX_WINDOW_WIDTH); ++i) { + debugPrintfEXT( + " [%d]: (%.3f, %.3f, %.3f, %.3f) \\n", + i, + input_window.data[i].x, + input_window.data[i].y, + input_window.data[i].z, + input_window.data[i].w); + } +} + +void printWeightRow(const WeightRow weight_row) { + debugPrintfEXT("WeightRow contents (len = %d): \\n", weight_row.len); + for (int i = 0; i < min(weight_row.len, MAX_KERNEL_WIDTH); ++i) { + debugPrintfEXT( + " [%d]: (%.3f, %.3f, %.3f, %.3f) \\n", + i, + weight_row.data[i].x, + weight_row.data[i].y, + weight_row.data[i].z, + weight_row.data[i].w); + } +} + +void printFPOutBlock(const FPOutBlock out_block) { + debugPrintfEXT("FPOutBlock contents: \\n"); + for (int i = 0; i < 4; ++i) { + debugPrintfEXT( + " [%d]: (%.3f, %.3f, %.3f, %.3f) \\n", + i, + out_block.data[i].x, + out_block.data[i].y, + out_block.data[i].z, + out_block.data[i].w); + } + } + +#endif // DEBUG_MODE + +#endif // CONV2D_DW_Q8_UTILS_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.glsl new file mode 100644 index 00000000000..bc61de32073 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.glsl @@ -0,0 +1,137 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, "buffer")} +#define T ${texel_load_component_type(DTYPE, "buffer")} + +$if IO_STORAGE == "buffer": + #define PACKED_INT8_OUTPUT_BUFFER + #define PACKED_INT8_INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define MAX_WINDOW_WIDTH 12 +#define MAX_KERNEL_WIDTH 5 + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", IO_STORAGE, is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; + float output_inv_scale; + int output_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "apply_bias", "1")} + +#include "conv2d_dw_q8_utils.glslh" + +void main() { + const int tid = int(gl_GlobalInvocationID.x); + Conv2dBlockExtents out_block_extents = make_block_extents(output_sizes); + + Conv2dBlockIndex out_block_idx = linear_idx_to_block_idx( + tid, out_block_extents); + + if (block_idx_out_of_bounds(out_block_idx, out_block_extents)) { + return; + } + + const int out_h = out_block_idx.data.y; + const int out_w = mul_4(out_block_idx.data.x); + + Conv2dBlockExtents in_block_extents = make_block_extents(input_sizes); + + const int Kw4 = div_up_4(conv2d_params.kernel_size.x); + + // Compute 4 channels for 4 output elements. + ivec4 acc[4]; + [[unroll]] for (int i = 0; i < 4; ++i) { + acc[i] = ivec4(0); + } + + for (int ky = 0; ky < conv2d_params.kernel_size.y; ky++) { + const int h = out_h * conv2d_params.stride.y - conv2d_params.padding.y + + ky * conv2d_params.dilation.y; + + for (int kx = 0; kx < conv2d_params.kernel_size.x; kx++) { + const int w = out_w * conv2d_params.stride.x - conv2d_params.padding.x + + kx * conv2d_params.dilation.x; + + // Load and unpack weights. + const int packed_weight_4c = load_weight_1w4c( + kx, + ky, + out_block_idx.data.z, + Kw4, + out_block_extents.data.z + ); + + const ivec4 weight_4c = unpack_int8x4(packed_weight_4c); + + [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) { + ivec4 input_texel = unpack_int8x4(load_input_1w4c( + w + conv2d_params.stride.x * subtile_w, + h, + out_block_idx.data.z, + out_block_extents.data.z, + in_block_extents)); + acc[subtile_w] += weight_4c * input_texel; + } + } + } + + // Apply input zero point as weight_sum * input_zp. 
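The correction flagged by the comment above ("apply input zero point as weight_sum * input_zp") follows from expanding the affine dot product: sum_k (x_q[k] - zp_x) * s_x * w_q[k] * s_w = (sum_k x_q[k] * w_q[k] - zp_x * sum_k w_q[k]) * s_x * s_w. A host-side C++ sketch of that identity plus the final int8 requantization, using per-tensor scales and made-up values rather than the shader's per-channel tiles:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Quantization parameters (illustrative values, not taken from the diff).
  const float input_scale = 0.05f, weight_scale = 0.02f;
  const int input_zp = 3;
  const float output_inv_scale = 1.0f / 0.1f;
  const int output_zp = -5;

  const std::vector<int32_t> x_q = {7, -2, 3, 11};  // quantized activations
  const std::vector<int32_t> w_q = {1, -4, 6, 2};   // quantized weights
  const float bias = 0.25f;

  // Integer-only accumulation, exactly what the shader's inner loop produces.
  int32_t acc = 0, weight_sum = 0;
  for (std::size_t k = 0; k < x_q.size(); ++k) {
    acc += x_q[k] * w_q[k];
    weight_sum += w_q[k];
  }

  // Fold out the input zero point, then apply the combined scale and bias.
  const float out =
      (static_cast<float>(acc) - static_cast<float>(weight_sum) * input_zp) *
          (input_scale * weight_scale) + bias;

  // Requantize the result to int8, as quantize()/pack_into_int32 do.
  int out_q = static_cast<int>(std::lround(out * output_inv_scale)) + output_zp;
  out_q = std::clamp(out_q, -128, 127);

  // Reference: dequantize first, then accumulate in float.
  float ref = 0.0f;
  for (std::size_t k = 0; k < x_q.size(); ++k) {
    ref += (x_q[k] - input_zp) * input_scale * (w_q[k] * weight_scale);
  }
  ref += bias;

  std::printf("out=%f ref=%f out_q=%d\n", out, ref, out_q);  // out == ref
  return 0;
}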
+ vec4 weight_sums = vec4(t_weight_sums[out_block_idx.data.z]); + const vec4 weight_scales = vec4(t_weight_scales[out_block_idx.data.z]); + + vec4 facc[4]; + [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) { + facc[subtile_w] = vec4(acc[subtile_w]); + facc[subtile_w] -= weight_sums * input_zp; + facc[subtile_w] *= weight_scales * input_scale; + } + + if (apply_bias > 0) { + const vec4 bias = vec4(t_bias[out_block_idx.data.z]); + [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) { + facc[subtile_w] += bias; + } + } + + ivec4 packed_out; + [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) { + packed_out[subtile_w] = pack_into_int32(quantize(facc[subtile_w], output_inv_scale, output_zp)); + } + +#ifdef PACKED_INT8_OUTPUT_BUFFER + t_packed_int8_output[tid] = packed_out; +#else + imageStore(t_packed_int8_output, out_block_idx.data, packed_out); +#endif +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.yaml new file mode 100644 index 00000000000..77f801668a4 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.yaml @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +conv2d_dw_q8ta_q8csw_q8to: + parameter_names_with_default_values: + DTYPE: float + IO_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, WEIGHT_STORAGE] + combos: + - parameter_values: [buffer, texture2d] + DTYPE: + - VALUE: float + shader_variants: + - NAME: conv2d_dw_q8ta_q8csw_q8to diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block.glslh index 7add8c4cd16..3be8bf32a61 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block.glslh @@ -23,7 +23,7 @@ #extension GL_EXT_control_flow_attributes : require -#include "common.glslh" +#include "indexing.glslh" #include "conv2d_common.glslh" struct Im2ColMatrixIdx { diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block_load.glslh index c02b070e17e..18ed8074a8a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block_load.glslh @@ -23,7 +23,7 @@ #extension GL_EXT_debug_printf : require -#include "common.glslh" +#include "indexing.glslh" #include "conv2d_common.glslh" #include "conv2d_fp_im2col_block.glslh" #include "linear_fp_input_tile.glslh" diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block_store.glslh index 2171d75c628..6c4dd7f0b52 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block_store.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block_store.glslh @@ -20,7 +20,7 @@ #extension GL_EXT_control_flow_attributes : require -#include "common.glslh" +#include "indexing.glslh" #include "conv2d_common.glslh" #include "conv2d_fp_im2col_block.glslh" #include "linear_fp_output_tile.glslh" diff --git 
a/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_input_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_input_tile_load.glslh new file mode 100644 index 00000000000..4456043bb9f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_input_tile_load.glslh @@ -0,0 +1,52 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef CONV2D_FP_INPUT_TILE_LOAD +#define CONV2D_FP_INPUT_TILE_LOAD + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_input_tile.glslh" + +VEC4_T load_fp_input_texel(const Conv2dTensorIndex tidx) { +#ifdef INPUT_BUFFER + VEC4_T texel = VEC4_T(0); + const int c_idx = mul_4(tidx.data.z); + const int c_stride = input_sizes.y * input_sizes.x; + + const int base_buf_i = c_idx * c_stride + tidx.data.y * input_sizes.x + tidx.data.x; + const int limit = min(input_sizes.z - c_idx, 4); + + for (int i = 0; i < limit; i++) { + texel[i] = t_fp_input[base_buf_i + i * c_stride]; + } + return texel; +#else + return texelFetch(t_fp_input, tidx.data, 0); +#endif +} + +void load_fp_input_tile( + out FPInputTile tile, + const Conv2dBlockIndex block_idx) { +#if TILE_M == 4 && TILE_K4 == 1 + Conv2dTensorIndex load_tidx = block_idx_to_tensor_idx(block_idx); + [[unroll]] for (int w = 0; w < TILE_M; w++) { + if (load_tidx.data.x < input_sizes.x) { + tile.data[w][0] = load_fp_input_texel(load_tidx); + } else { + tile.data[w][0] = VEC4_T(0); + } + load_tidx.data.x++; + } +#else + not_implemented; +#endif +} + +#endif // CONV2D_FP_INPUT_TILE_LOAD diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_input_block_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_input_block_load.glslh new file mode 100644 index 00000000000..44c226f6891 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_input_block_load.glslh @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef CONV2D_INT8_INPUT_BLOCK_LOAD +#define CONV2D_INT8_INPUT_BLOCK_LOAD + +#extension GL_EXT_control_flow_attributes : require + +#include "conv2d_common.glslh" +#include "conv2d_int8_activation_block.glslh" + +void store_packed_int8_input_block( + const Conv2dBlockIndex block_idx, + const Conv2dBlockExtents block_extents, + const Int8ActivationBlock packed_int8_block) { +#ifdef OUTPUT_BUFFER + const int buffer_idx = block_idx.data.y * block_extents.data_xz + + block_idx.data.x * block_extents.data.z + block_idx.data.z; + t_packed_int8_input[buffer_idx] = packed_int8_block.data; +#else + imageStore(t_packed_int8_input, block_idx.data, packed_int8_block.data); +#endif +} + +#endif // CONV2D_INT8_INPUT_BLOCK_LOAD diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_input_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_input_tile_load.glslh new file mode 100644 index 00000000000..44aa09912ec --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_input_tile_load.glslh @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#ifndef CONV2D_INT8_INPUT_TILE_LOAD +#define CONV2D_INT8_INPUT_TILE_LOAD + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_int8_input_tile.glslh" + +struct Int8InputTileIndex { +#ifdef PACKED_INT8_INPUT_BUFFER + int data; +#else + ivec3 data; +#endif +}; + +Int8InputTileIndex make_initial_int8_input_tile_index( + const Conv2dBlockIndex block_idx, + const Conv2dBlockExtents block_extents) { + Int8InputTileIndex idx; +#ifdef PACKED_INT8_INPUT_BUFFER + idx.data = block_idx.data.y * block_extents.data_xz + + block_idx.data.x * block_extents.data.z; +#else + idx.data = ivec3(block_idx.data.x, block_idx.data.y, 0); +#endif + return idx; +} + +Int8InputTileIndex make_initial_int8_input_tile_index( + const Conv2dBlockIndex block_idx, + const Conv2dBlockExtents block_extents, + const int group_k4_offset) { + Int8InputTileIndex idx; +#ifdef PACKED_INT8_INPUT_BUFFER + idx.data = block_idx.data.y * block_extents.data_xz + + block_idx.data.x * block_extents.data.z + group_k4_offset; +#else + idx.data = ivec3(block_idx.data.x, block_idx.data.y, group_k4_offset); +#endif + return idx; +} + +void load_packed_int8_input_tile( + out Int8InputTile int8_tile, + const Int8InputTileIndex idx) { +#ifdef PACKED_INT8_INPUT_BUFFER + int8_tile.data[0][0] = t_packed_int8_input[idx.data]; +#else + int8_tile.data[0][0] = texelFetch(t_packed_int8_input, idx.data, 0); +#endif + + // Guard against unsupported tile sizes +#if TILE_M4 != 1 || TILE_K4 != 1 + not_implemented; +#endif +} + +void increment_k4(inout Int8InputTileIndex idx) { +#ifdef PACKED_INT8_INPUT_BUFFER + idx.data += 1; +#else + idx.data.z += 1; +#endif +} + +#endif // CONV2D_INT8_INPUT_TILE_LOAD diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_output_tile_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_output_tile_store.glslh new file mode 100644 index 00000000000..27244f67953 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_output_tile_store.glslh @@ -0,0 +1,48 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#ifndef CONV2D_INT8_OUTPUT_TILE_STORE +#define CONV2D_INT8_OUTPUT_TILE_STORE + +#extension GL_EXT_control_flow_attributes : require + +#include "conv2d_common.glslh" +#include "linear_int8_output_tile.glslh" + +void store_packed_int8_output_tile( + const Int8OutTile int8_tile, + Conv2dBlockIndex block_idx, + const Conv2dBlockExtents block_extents) { +#ifdef PACKED_INT8_OUTPUT_BUFFER + [[unroll]] for (int m4 = 0; m4 < TILE_M4; m4++) { + int buffer_idx = block_idx.data.y * block_extents.data_xz + + (block_idx.data.x + m4) * block_extents.data.z + block_idx.data.z; + [[unroll]] for (int n4 = 0; n4 < TILE_N4; n4++) { + if (block_idx.data.x + m4 < block_extents.data.x && + block_idx.data.z + n4 < block_extents.data.z) { + t_packed_int8_output[buffer_idx++] = int8_tile.data[m4][n4]; + } + } + } +#else + [[unroll]] for (int m4 = 0; m4 < TILE_M4; m4++) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; n4++) { + if (block_idx.data.x + m4 < block_extents.data.x && + block_idx.data.z + n4 < block_extents.data.z) { + const ivec3 idx_offset = ivec3(m4, 0, n4); + imageStore( + t_packed_int8_output, + block_idx.data + idx_offset, + int8_tile.data[m4][n4]); + } + } + } +#endif +} + +#endif // CONV2D_INT8_OUTPUT_TILE_STORE diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_q8ta_q8csw_q8to_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_q8ta_q8csw_q8to_tiled.glsl new file mode 100644 index 00000000000..16c12b3ee5a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_q8ta_q8csw_q8to_tiled.glsl @@ -0,0 +1,144 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, "buffer")} +#define T ${texel_load_component_type(DTYPE, "buffer")} + +$if IO_STORAGE == "buffer": + #define PACKED_INT8_OUTPUT_BUFFER + #define PACKED_INT8_INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +// corresponds to input/output width dim +#define TILE_M4 1 +// corresponds to input channels dim +#define TILE_K4 1 +// corresponds to output channels dim +#define TILE_N4 2 + +#define TILE_M 4 +#define TILE_K 4 +#define TILE_N 8 + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; + float output_inv_scale; + int output_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "apply_bias", "1")} + +#include "conv2d_int8_input_tile_load.glslh" +#include "linear_int8_weight_tile_load.glslh" +#include "linear_fp_output_tile_int8_int8_compute.glslh" +#include "linear_int_weight_sums_load.glslh" +#include "linear_fp_weight_scales_load.glslh" +#include "linear_fp_bias_load.glslh" +#include "linear_int8_output_tile_compute.glslh" +#include "conv2d_int8_output_tile_store.glslh" + +void main() { + Conv2dBlockIndex output_block_idx; + output_block_idx.data.z = int(gl_GlobalInvocationID.x) * TILE_N4; + output_block_idx.data.x = int(gl_GlobalInvocationID.y) * TILE_M4; + output_block_idx.data.y = int(gl_GlobalInvocationID.z); + + Conv2dBlockExtents output_block_extents = make_block_extents(output_sizes); + if (block_idx_out_of_bounds(output_block_idx, output_block_extents)) { + return; + } + + Conv2dBlockExtents input_block_extents = make_block_extents(input_sizes); + + Int32Accum out_accum; + initialize(out_accum); + + Int8InputTile int8_input_tile; + Int8WeightTile int8_weight_tile; + + Int8InputTileIndex input_idx = make_initial_int8_input_tile_index( + output_block_idx, input_block_extents); + + for (int k4 = 0; k4 < conv2d_params.K4_per_group; k4++) { + load_packed_int8_input_tile(int8_input_tile, input_idx); + + load_int8_weight_tile( + int8_weight_tile, + output_block_idx.data.z, + k4, + output_block_extents.data.z); + + int_accumulate_with_int8_weight( + out_accum, int8_input_tile, int8_weight_tile); + + increment_k4(input_idx); + } + + FPPerOutChannelParams weight_scales_tile; + load_weight_scales_tile(weight_scales_tile, output_block_idx.data.z); + + IntPerOutChannelParams weight_sums_tile; + load_weight_sums_tile(weight_sums_tile, output_block_idx.data.z); + + Int8OutTile int8_out_tile; + initialize(int8_out_tile); + + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, output_block_idx.data.z); + + 
compute_int8_out_tile_with_int32_accum( + int8_out_tile, + out_accum, + input_scale, + input_zp, + output_inv_scale, + output_zp, + weight_sums_tile, + weight_scales_tile, + bias_tile); + } + else { + compute_int8_out_tile_with_int32_accum( + int8_out_tile, + out_accum, + input_scale, + input_zp, + output_inv_scale, + output_zp, + weight_sums_tile, + weight_scales_tile); + } + + store_packed_int8_output_tile( + int8_out_tile, output_block_idx, output_block_extents); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_q8ta_q8csw_q8to_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_q8ta_q8csw_q8to_tiled.yaml new file mode 100644 index 00000000000..23803dc6da1 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_q8ta_q8csw_q8to_tiled.yaml @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +conv2d_pw_q8ta_q8csw_q8to_tiled: + parameter_names_with_default_values: + DTYPE: float + IO_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, WEIGHT_STORAGE] + combos: + - parameter_values: [buffer, texture2d] + DTYPE: + - VALUE: float + shader_variants: + - NAME: conv2d_pw_q8ta_q8csw_q8to_tiled diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl index 9f84afeb1a1..ef50a1aca9f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl @@ -12,10 +12,12 @@ #define PRECISION ${PRECISION} -#define VEC4_T ${texel_type(DTYPE)} +$if DTYPE == "half": + #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require + #define VEC4_T f16vec4 +$else: + #define VEC4_T ${texel_type(DTYPE)} -#define TILE_SIZE_X uint16_t(${TILE_SIZE_X}) -#define TILE_SIZE_Y uint16_t(${TILE_SIZE_Y}) #define op(X, A, B) ${OPERATOR} @@ -50,119 +52,90 @@ ${layout_declare_spec_const(C, "int", "ngroups", "1")} * size is only 1x1, making it easier to re-use loaded texels from t_kernel. */ void main() { - const int out_limits_scaled[2] = - {(out_limits.x + (TILE_SIZE_X - 1)) / TILE_SIZE_X, - (out_limits.y + (TILE_SIZE_Y - 1)) / TILE_SIZE_Y}; - const uint16_t div_by_x = uint16_t(gl_GlobalInvocationID.x / out_limits_scaled[0]); - const uint16_t out_pos_xy[2] = {uint16_t(gl_GlobalInvocationID.x % out_limits_scaled[0]), div_by_x}; - const int out_pos_z = int(gl_GlobalInvocationID.y); + int inputAndOutputWidth = out_limits.x; + int inputAndOutputHeight = out_limits.y; + int outputChannel = out_limits.z*4; - // If the top left position is out of bounds, then this invocation will have - // no work to do. 
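For orientation before the rewritten conv2d_pw_s1p0 body below: each invocation now produces one output texel (four output channels at a single (x, y)), and for every packed group of four input channels it reads one input texel and four weight texels, accumulating four dot products, one per output-channel lane. A reduced C++ sketch of a single group's contribution, with plain arrays standing in for texelFetch and purely illustrative values:

#include <array>
#include <cstdio>

using Vec4 = std::array<float, 4>;

float dot(const Vec4& a, const Vec4& b) {
  return a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + a[3] * b[3];
}

int main() {
  // One input texel: input channels c..c+3 at a fixed (x, y).
  const Vec4 input = {1.0f, 2.0f, 3.0f, 4.0f};

  // Four weight texels for this input-channel group. weights[i][j] is the
  // weight connecting input channel c+i to output channel o+j, matching the
  // texels read as t_kernel[ivec2(inputC * 4 + i, threadOutChannel)].
  const std::array<Vec4, 4> weights = {{{0.1f, 0.2f, 0.3f, 0.4f},
                                        {0.5f, 0.6f, 0.7f, 0.8f},
                                        {0.9f, 1.0f, 1.1f, 1.2f},
                                        {1.3f, 1.4f, 1.5f, 1.6f}}};

  // Accumulate one partial output texel, mirroring the shader's
  // outputTexel[j] += dot(inputVec, column j of the 4x4 weight tile).
  Vec4 out = {0.0f, 0.0f, 0.0f, 0.0f};
  for (int j = 0; j < 4; ++j) {
    const Vec4 column = {weights[0][j], weights[1][j], weights[2][j],
                         weights[3][j]};
    out[j] += dot(input, column);
  }

  std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);
  return 0;
}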
- if (out_pos_xy[1] >= out_limits_scaled[1] || out_pos_z >= out_limits.z) { + // Divided by 4 because the input channels are packed + int inputChannel = in_group_size/4; + + int threadHW = int(gl_GlobalInvocationID.x); + int threadOutChannel = int(gl_GlobalInvocationID.y); + + int xIdx = threadHW % inputAndOutputWidth; + int yIdx = threadHW / inputAndOutputWidth; + + if (threadHW >= inputAndOutputWidth * inputAndOutputHeight && threadOutChannel >= outputChannel) { return; } - // Output position for TILE_SIZE = 2 - // +--------+--------+ - // | pos[0] | pos[1] | - // +--------+--------+ - // | pos[2] | pos[3] | - // +--------+--------+ - uint16_t pos[TILE_SIZE_X * TILE_SIZE_Y * 2]; - for (uint16_t y = uint16_t(0), i = uint16_t(0); y < TILE_SIZE_Y; ++y) { - for (uint16_t x = uint16_t(0); x < TILE_SIZE_X; ++x) { - pos[i * 2] = out_pos_xy[0] * TILE_SIZE_X + x; - pos[i * 2 + 1] = out_pos_xy[1] * TILE_SIZE_Y + y; - i++; - } - } + VEC4_T outputTexel = VEC4_T(texelFetch(t_bias, ivec2(threadOutChannel, 0), 0)); - // Final output array where each element is a tensor value. - // Tuple of consecutive 4 elements represents a single output texel. - float sum[TILE_SIZE_X * TILE_SIZE_Y * 4]; + VEC4_T inputVec; + VEC4_T weight1OutputChannelPacked; + VEC4_T weight2OutputChannelPacked; + VEC4_T weight3OutputChannelPacked; + VEC4_T weight4OutputChannelPacked; - // Initialize the output array with the bias value - for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y * 4; i++) { - sum[i] = 0; - } + // By unrolling the loop in sets of 4, this significantly reduces the number of branching instructions + // and enables the compiler to rearrange instructions for more efficient memory retrieval and compute + for (int inputC = 0; inputC < inputChannel; inputC += 1) { - int z4 = 0; - // Since the kernel is 1x1, we only have to loop over the depth dimension. - for (int z = 0; z < in_group_size; z += 4, ++z4) { - // During prepacking, the weight tensor has been permuted so that the - // channel (IC) dim is along the x-axis, and the batch (OC) dim is along - // the z-axis. - float kernel_values[4 * 4]; // 4 channels, 4 elements per channel - - // Load kernel values from texels to array - [[unroll]] for (int i = 0; i < 4; ++i) { - const vec4 k_tex = texelFetch(t_kernel, ivec2(z + i, out_pos_z), 0); - kernel_values[i * 4 + 0] = k_tex.x; - kernel_values[i * 4 + 1] = k_tex.y; - kernel_values[i * 4 + 2] = k_tex.z; - kernel_values[i * 4 + 3] = k_tex.w; - } - - for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) { - const vec4 in_tex = texelFetch(t_in, ivec3(pos[i * 2], pos[i * 2 + 1], z4), 0); - // Load the input texel into an array - float tex_values[4]; - tex_values[0] = in_tex.x; - tex_values[1] = in_tex.y; - tex_values[2] = in_tex.z; - tex_values[3] = in_tex.w; - - // For 2x2 tile size algorithm works as follows. - // To explain the calculations below, the contents of one in_tex and the - // group of 4 texels loaded from t_kernel are shown: - // - // in_tex t_kernel - // -x-> ---x---> - // +---+ +----+----+----+----+ - // ^ | w | ^ | D0 | D1 | D2 | D3 | - // | +---+ | +----+----+----+----+ - // | | z | | | C0 | C1 | C2 | C3 | - // z +---+ z +----+----+----+----+ - // | | y | | | B0 | B2 | B2 | B3 | - // | +---+ | +----+----+----+----+ - // | x | | A0 | A1 | A2 | A3 | - // +---+ +----+----+----+----+ - // - // In the t_kernel graphic, cells sharing the same letter are from - // the same batch/output channel index, and the number denotes a unique - // channel index. 
To calculate the output texel, the following - // calculation is performed: - // - // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ - // | x | | D0 | | y | | D1 | | z | | D2 | | w | | D3 | - // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ - // | x | | C0 | | y | | C1 | | z | | C2 | | w | | C3 | - // +---+X+----+ + +---+X+----+ + +---+X+----+ + +---+X+----+ - // | x | | B0 | | y | | B1 | | z | | B2 | | w | | B3 | - // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ - // | x | | A0 | | y | | A1 | | z | | A2 | | w | | A3 | - // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ - // - // which is what is expressed in the following calculations. This is done - // for each output position. - for (int j = 0; j < 4; ++j) { - sum[i * 4 + j] = tex_values[0] * kernel_values[0 + j] + sum[i * 4 + j]; - sum[i * 4 + j] = tex_values[1] * kernel_values[4 + j] + sum[i * 4 + j]; - sum[i * 4 + j] = tex_values[2] * kernel_values[8 + j] + sum[i * 4 + j]; - sum[i * 4 + j] = tex_values[3] * kernel_values[12 + j] + sum[i * 4 + j]; - } - } - } + inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0)); + + weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0)); + weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0)); + weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0)); + weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0)); + + outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); + + inputC += 1; + + inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0)); - const vec4 bias = texelFetch(t_bias, ivec2(out_pos_z, 0), 0); + weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0)); + weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0)); + weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0)); + weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0)); - for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) { - const ivec3 pos_l = ivec3(pos[i * 2], pos[i * 2 + 1], out_pos_z); - if (all(lessThan(pos_l.xy, out_limits.xy))) { - const vec4 out_sum = vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]); - imageStore(t_out, pos_l, op(out_sum + bias, out_min, out_max)); - } + outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], 
weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); + + inputC += 1; + + inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0)); + + weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0)); + weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0)); + weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0)); + weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0)); + + outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); + + inputC += 1; + + inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0)); + + weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0)); + weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0)); + weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0)); + weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0)); + + outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); } + + imageStore(t_out, ivec3(xIdx, yIdx, threadOutChannel), op(vec4(outputTexel), out_min, out_max)); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml index ebfee11c405..bab3c715540 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml @@ -9,8 +9,6 @@ conv2d_pw_s1p0: OPERATOR: X NDIM: 3 DTYPE: float - TILE_SIZE_X: 1 - TILE_SIZE_Y: 4 generate_variant_forall: DTYPE: - VALUE: half diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8_utils.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8_utils.glslh new file mode 100644 index 00000000000..279f4f17f13 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8_utils.glslh @@ -0,0 +1,151 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef CONV2D_Q8_UTILS_GLSLH +#define CONV2D_Q8_UTILS_GLSLH + +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_integer_dot_product : require + +#include "linear_int_accumulator.glslh" + +struct Int8InputWindow1D { + int[MAX_WINDOW_WIDTH] data; + int len; +}; + +Int8InputWindow1D initial_input_window() { + Int8InputWindow1D input_window; + for (int i = 0; i < MAX_WINDOW_WIDTH; ++i) { + input_window.data[i] = 0; + } + input_window.len = 0; + return input_window; +} + +bool in_bounds( + const int block_w, + const int block_h, + const int block_c4, + const Conv2dBlockExtents block_extents) { + ivec3 idx = ivec3(block_w, block_h, block_c4); + if (any(lessThan(idx, ivec3(0)))) { + return false; + } + if (any(greaterThanEqual(idx, block_extents.data))) { + return false; + } + + return true; +} + +Int8InputWindow1D load_input_window( + const int w_start, + const int w_end, + const int h, + const int c4, + const Conv2dBlockExtents block_extents, + const ivec4 input_zps) { + Int8InputWindow1D input_window = initial_input_window(); + + const int block_w_start = div_4(w_start); + const int block_w_end = div_4(w_end); + + int window_i = 0; + for (int block_w = block_w_start; block_w <= block_w_end; ++block_w) { + ivec4 input_block = input_zps; + + if (in_bounds(block_w, h, c4, block_extents)) { +#ifdef PACKED_INT8_INPUT_BUFFER + const int buffer_idx = + h * block_extents.data_xz + block_w * block_extents.data.z + c4; + input_block = t_packed_int8_input[buffer_idx]; +#else + input_block = texelFetch(t_packed_int8_input, ivec3(block_w, h, c4), 0); +#endif + } + + const int loaded_w_start = mul_4(block_w); + for (int row = 0; row < 4; ++row) { + if (loaded_w_start + row >= w_start && loaded_w_start + row <= w_end) { + input_window.data[window_i++] = input_block[row]; + } + } + } + input_window.len = window_i; + return input_window; +} + +ivec4 load_weight_block( + const int ic4, + const int kx, + const int ky, + const int oc4, + const int IC4, + const int Kw, + const int Kh, + const int OC4) { +#ifdef PACKED_INT8_WEIGHTS_BUFFER + const int block_x = oc4 * Kw + kx; + const int block_y = ky * IC4 + ic4; + return t_packed_int8_weight[block_y * (Kw * OC4) + block_x]; +#else + return texelFetch( + t_packed_int8_weight, ivec2(oc4 * Kw + kx, ky * IC4 + ic4), 0); +#endif +} + +void perform_conv1d( + inout Int32Accum accum, + const Int8InputWindow1D input_window, + const ivec4 weight_block, + const int kx) { + [[unroll]] for (int out_w = 0; out_w < 4; ++out_w) { + const int window_i = out_w * conv2d_params.stride.x + kx; + [[unroll]] for (int out_c = 0; out_c < 4; ++out_c) { + accum.data[out_w][0][out_c] = dotPacked4x8AccSatEXT( + input_window.data[window_i], + weight_block[out_c], + accum.data[out_w][0][out_c]); + } + } +} + +#ifdef DEBUG_MODE + +void printInt8InputWindow1D(const Int8InputWindow1D input_window) { + debugPrintfEXT("Int8InputWindow1D contents (len = %d): \\n", input_window.len); + for (int i = 0; i < min(input_window.len, MAX_WINDOW_WIDTH); ++i) { + ivec4 unpacked = unpack_int8x4(input_window.data[i]); + debugPrintfEXT( + " [%d]: (%d, %d, %d, %d) \\n", + i, + unpacked.x, + unpacked.y, + unpacked.z, + unpacked.w); + } +} + +void printWeightBlock(const ivec4 weight_block) { + debugPrintfEXT("WeightBlock contents: \\n"); + for (int i = 0; i < 4; ++i) { + ivec4 unpacked = unpack_int8x4(weight_block[i]); + debugPrintfEXT( + " [%d]: (%d, 
%d, %d, %d) \\n", + i, + unpacked.x, + unpacked.y, + unpacked.z, + unpacked.w); + } +} + +#endif // DEBUG_MODE + +#endif // CONV2D_Q8_UTILS_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.glsl new file mode 100644 index 00000000000..5839b13aeaa --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.glsl @@ -0,0 +1,173 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, "buffer")} +#define T ${texel_load_component_type(DTYPE, "buffer")} + +$if IO_STORAGE == "buffer": + #define PACKED_INT8_OUTPUT_BUFFER + #define PACKED_INT8_INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define MAX_WINDOW_WIDTH 16 + +// corresponds to input/output width dim +#define TILE_M4 1 +// corresponds to input channels dim +#define TILE_K4 1 +// corresponds to output channels dim +#define TILE_N4 1 + +#define TILE_M 4 +#define TILE_K 4 +#define TILE_N 4 + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; + float output_inv_scale; + int output_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "apply_bias", "1")} + +#include "im2col_packed_int8_utils.glslh" +#include "conv2d_int8_input_tile_load.glslh" +#include "linear_int8_weight_tile_load.glslh" +#include "linear_fp_output_tile_int8_int8_compute.glslh" +#include "linear_int_weight_sums_load.glslh" +#include "linear_fp_weight_scales_load.glslh" +#include "linear_fp_bias_load.glslh" +#include "linear_int8_output_tile_compute.glslh" +#include "conv2d_int8_output_tile_store.glslh" + +#include "conv2d_q8_utils.glslh" + +void main() { + Conv2dBlockIndex out_block_idx; + out_block_idx.data.z = int(gl_GlobalInvocationID.x) * TILE_N4; + out_block_idx.data.x = int(gl_GlobalInvocationID.y) * TILE_M4; + out_block_idx.data.y = int(gl_GlobalInvocationID.z); + + Conv2dBlockExtents out_block_extents = make_block_extents(output_sizes); + if (block_idx_out_of_bounds(out_block_idx, out_block_extents)) { + return; + } + + const int out_w = mul_4(out_block_idx.data.x); + const int w_start = + (out_w * conv2d_params.stride.x) - conv2d_params.padding.x; + const int w_end = ((out_w + 3) * conv2d_params.stride.x) - + conv2d_params.padding.x + + (conv2d_params.kernel_size.x - 1) * conv2d_params.dilation.x; + + Conv2dBlockExtents 
in_block_extents = make_block_extents(input_sizes); + + const ivec4 input_zps = ivec4(pack_into_int32(ivec4(input_zp))); + const vec4 weight_scales = vec4(t_weight_scales[out_block_idx.data.z]); + + Int32Accum out_accum; + initialize(out_accum); + + const int IC4_per_group = div_up_4(conv2d_params.in_channels_per_group); + + const int n = mul_4(out_block_idx.data.z); + const int group_idx = n / conv2d_params.out_channels_per_group; + const int group_ic4_offset = group_idx * IC4_per_group; + + for (int ky = 0; ky < conv2d_params.kernel_size.y; ky++) { + const int h = out_block_idx.data.y * conv2d_params.stride.y - + conv2d_params.padding.y + ky * conv2d_params.dilation.y; + + for (int ic4 = 0; ic4 < IC4_per_group; ic4++) { + Int8InputWindow1D int8_input_window = load_input_window( + w_start, + w_end, + h, + group_ic4_offset + ic4, + in_block_extents, + input_zps); + + for (int kx = 0; kx < conv2d_params.kernel_size.x; kx++) { + const ivec4 weight_block = load_weight_block( + ic4, + kx, + ky, + out_block_idx.data.z, + IC4_per_group, + conv2d_params.kernel_size.x, + conv2d_params.kernel_size.y, + out_block_extents.data.z); + + perform_conv1d(out_accum, int8_input_window, weight_block, kx); + } + } + } + + FPPerOutChannelParams weight_scales_tile; + load_weight_scales_tile(weight_scales_tile, out_block_idx.data.z); + + IntPerOutChannelParams weight_sums_tile; + load_weight_sums_tile(weight_sums_tile, out_block_idx.data.z); + + Int8OutTile int8_out_tile; + initialize(int8_out_tile); + + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, out_block_idx.data.z); + + compute_int8_out_tile_with_int32_accum( + int8_out_tile, + out_accum, + input_scale, + input_zp, + output_inv_scale, + output_zp, + weight_sums_tile, + weight_scales_tile, + bias_tile); + } + else { + compute_int8_out_tile_with_int32_accum( + int8_out_tile, + out_accum, + input_scale, + input_zp, + output_inv_scale, + output_zp, + weight_sums_tile, + weight_scales_tile); + } + + store_packed_int8_output_tile( + int8_out_tile, out_block_idx, out_block_extents); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.yaml new file mode 100644 index 00000000000..7d33434940c --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.yaml @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +conv2d_q8ta_q8csw_q8to: + parameter_names_with_default_values: + DTYPE: float + IO_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, WEIGHT_STORAGE] + combos: + - parameter_values: [buffer, texture2d] + - parameter_values: [texture3d, texture2d] + DTYPE: + - VALUE: float + shader_variants: + - NAME: conv2d_q8ta_q8csw_q8to diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to_linear_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to_linear_tiled.glsl new file mode 100644 index 00000000000..b44e37766fc --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to_linear_tiled.glsl @@ -0,0 +1,149 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, "buffer")} +#define T ${texel_load_component_type(DTYPE, "buffer")} + +$if IO_STORAGE == "buffer": + #define PACKED_INT8_OUTPUT_BUFFER + #define PACKED_INT8_INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +// corresponds to input/output width dim +#define TILE_M4 1 +// corresponds to input channels dim +#define TILE_K4 1 +// corresponds to output channels dim +#define TILE_N4 2 + +#define TILE_M 4 +#define TILE_K 4 +#define TILE_N 8 + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "im2col_sizes")} +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; + float output_inv_scale; + int output_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "apply_bias", "1")} + +#include "conv2d_int8_input_tile_load.glslh" +#include "linear_int8_weight_tile_load.glslh" +#include "linear_fp_output_tile_int8_int8_compute.glslh" +#include "linear_int_weight_sums_load.glslh" +#include "linear_fp_weight_scales_load.glslh" +#include "linear_fp_bias_load.glslh" +#include "linear_int8_output_tile_compute.glslh" +#include "conv2d_int8_output_tile_store.glslh" + +void main() { + Conv2dBlockIndex output_block_idx; + output_block_idx.data.z = int(gl_GlobalInvocationID.x) * TILE_N4; + output_block_idx.data.x = int(gl_GlobalInvocationID.y) * TILE_M4; + output_block_idx.data.y = int(gl_GlobalInvocationID.z); + + Conv2dBlockExtents output_block_extents = make_block_extents(output_sizes); + if (block_idx_out_of_bounds(output_block_idx, output_block_extents)) { + return; + } + + const int n = mul_4(output_block_idx.data.z); + + const int group_idx = n / conv2d_params.out_channels_per_group; + const int group_k4_offset = group_idx * conv2d_params.K4_per_group; + + Conv2dBlockExtents input_block_extents = make_block_extents(im2col_sizes); + + Int32Accum out_accum; + initialize(out_accum); + + Int8InputTile int8_input_tile; + Int8WeightTile int8_weight_tile; + + Int8InputTileIndex input_idx = make_initial_int8_input_tile_index( + output_block_idx, input_block_extents, group_k4_offset); + + for (int k4 = 0; k4 < conv2d_params.K4_per_group; k4++) { + load_packed_int8_input_tile(int8_input_tile, input_idx); + + load_int8_weight_tile( + int8_weight_tile, + output_block_idx.data.z, + k4, + output_block_extents.data.z); + + int_accumulate_with_int8_weight( + out_accum, int8_input_tile, int8_weight_tile); + + increment_k4(input_idx); + } + + FPPerOutChannelParams weight_scales_tile; + 
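+  // Rough sketch of the epilogue below (the exact math is defined in the included
+  // linear_int8_output_tile_compute.glslh): each int32 accumulator is corrected by
+  // input_zp * weight_sum for its output channel, scaled by
+  // input_scale * weight_scale (plus bias when apply_bias != 0), and then
+  // requantized with output_inv_scale and output_zp before being stored as
+  // packed int8.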
load_weight_scales_tile(weight_scales_tile, output_block_idx.data.z); + + IntPerOutChannelParams weight_sums_tile; + load_weight_sums_tile(weight_sums_tile, output_block_idx.data.z); + + Int8OutTile int8_out_tile; + initialize(int8_out_tile); + + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, output_block_idx.data.z); + + compute_int8_out_tile_with_int32_accum( + int8_out_tile, + out_accum, + input_scale, + input_zp, + output_inv_scale, + output_zp, + weight_sums_tile, + weight_scales_tile, + bias_tile); + } + else { + compute_int8_out_tile_with_int32_accum( + int8_out_tile, + out_accum, + input_scale, + input_zp, + output_inv_scale, + output_zp, + weight_sums_tile, + weight_scales_tile); + } + + store_packed_int8_output_tile( + int8_out_tile, output_block_idx, output_block_extents); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to_linear_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to_linear_tiled.yaml new file mode 100644 index 00000000000..14d303b99e7 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to_linear_tiled.yaml @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +conv2d_q8ta_q8csw_q8to_linear_tiled: + parameter_names_with_default_values: + DTYPE: float + IO_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, WEIGHT_STORAGE] + combos: + - parameter_values: [buffer, texture2d] + - parameter_values: [texture3d, texture2d] + DTYPE: + - VALUE: float + shader_variants: + - NAME: conv2d_q8ta_q8csw_q8to_linear_tiled diff --git a/backends/vulkan/runtime/graph/ops/glsl/convert.glslh b/backends/vulkan/runtime/graph/ops/glsl/convert.glslh new file mode 100644 index 00000000000..b901bc7e9d9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/convert.glslh @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef CONVERT_GLSLH +#define CONVERT_GLSLH + +// Scalar Conversions + +#ifdef T + +#if T == float16_t + +#define convert_to_T(x) T(clamp(x, -65504, 65504)); + +#else + +#define convert_to_T(x) T(x); + +#endif // T == float16_t + +#endif // T + +#endif // CONVERT_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.glsl b/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.glsl deleted file mode 100644 index 39aa9b11a0d..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.glsl +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define VEC4_T ${texel_type(DTYPE)} - -layout(std430) buffer; - -${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "existing_out", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} - -layout(push_constant) uniform restrict Block { - ivec4 out_sizes; - ivec4 in_sizes; - // Operates on (x, y, z) logical extents. 
- // channel_range is stored in range.w - ivec4 range; - // Analogus to range variable in copy. It defines the # of channel being - // copied. - // dst channel offset is stored in dst_offset.w - ivec4 dst_offset; - int src_channel_offset; -}; - -#include "indexing_utils.h" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 out_axis_map = unhash_axis_map(out_layout); -const lowp int packed_dim = unhash_packed_dim(out_layout); - -${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 in_axis_map = unhash_axis_map(in_layout); - -void main() { - // Note: Unlike other shaders, the range is often not equal to the destination - // texture extent. - const ivec3 lpos = ivec3(gl_GlobalInvocationID); - if (any(greaterThanEqual(lpos, range.xyz))) { - return; - } - - const ivec3 out_lpos = lpos + dst_offset.xyz; - - const ivec4 out_tidx = lpos_to_tidx(out_lpos, out_sizes, out_axis_map.w, packed_dim); - - // First read the existing values to make sure the boundary values stay. - VEC4_T v = load_texel_lpos(existing_out, out_lpos, out_axis_map); - - ivec4 in_tidx = out_tidx; - for (int i=0; i<4; i++) { - - in_tidx[packed_dim] = out_tidx[packed_dim] - dst_offset.w + i; - - // Handle the partial update for begining of channel in an existing tensor. - // If the source channel index is below zero or exceeds the range, we skip - // updating the element to avoid overwriting existing data. - if ((in_tidx[packed_dim] < 0) || (in_tidx[packed_dim] >= range.w)) { - continue; - } - - // Readjust for the source offset. - in_tidx[packed_dim] += src_channel_offset; - - ivec4 in_posi = tidx_to_posi(in_tidx, in_sizes, in_axis_map, packed_dim); - v[i] = load_texel(t_in, in_posi.xyz)[in_posi.w]; - } - - write_texel_lpos(t_out, out_lpos, v, out_axis_map); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml deleted file mode 100644 index 984d9a09d43..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml +++ /dev/null @@ -1,12 +0,0 @@ -copy_channel_offset: - parameter_names_with_default_values: - DTYPE: float - NDIM: 3 - STORAGE: texture3d - generate_variant_forall: - DTYPE: - - VALUE: half - - VALUE: float - - VALUE: int32 - shader_variants: - - NAME: copy_channel_offset diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.glsl b/backends/vulkan/runtime/graph/ops/glsl/copy_offset.glsl deleted file mode 100644 index 178814a90c3..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.glsl +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#version 450 core - -#define PRECISION ${PRECISION} - -${define_active_storage_type(STORAGE)} - -layout(std430) buffer; - -${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} - -layout(push_constant) uniform restrict Block { - ivec3 range; - // xyz is source offset w is channel size - ivec4 src_offset; - // xyz is destination offset w is channel size - ivec4 dst_offset; -}; - -#include "indexing_utils.h" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 out_axis_map = unhash_axis_map(out_layout); - -${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 in_axis_map = unhash_axis_map(in_layout); - -${layout_declare_spec_const(C, "int", "batch_index_function", "0")} - -void main() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, range))) { - return; - } - - ivec3 in_pos = pos + src_offset.xyz; - ivec3 out_pos = pos + dst_offset.xyz; - if (src_offset.w > 0) { - if (batch_index_function == 1) { - // batch index is calculated using source channel size - const int channel_index = pos.z % src_offset.w; - const int batch_index = pos.z / src_offset.w; - out_pos.z = channel_index + dst_offset.z + batch_index * dst_offset.w; - } else if (batch_index_function == 2) { - // batch index is calculated using destination channel size - const int channel_index = pos.z % dst_offset.w; - const int batch_index = pos.z / dst_offset.w; - in_pos.z = channel_index + src_offset.z + batch_index * src_offset.w; - } - } - - write_texel_lpos( - t_out, - out_pos, - load_texel_lpos(t_in, in_pos, in_axis_map), - out_axis_map); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml deleted file mode 100644 index 09f5ca36ea4..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml +++ /dev/null @@ -1,17 +0,0 @@ -copy_offset: - parameter_names_with_default_values: - DTYPE: float - NDIM: 3 - STORAGE: texture3d - generate_variant_forall: - DTYPE: - - VALUE: half - - VALUE: float - - VALUE: int32 - - VALUE: int8 - - VALUE: uint8 - STORAGE: - - VALUE: texture3d - - VALUE: texture2d - shader_variants: - - NAME: copy_offset diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.glsl b/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.glsl deleted file mode 100644 index 3100565d08a..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.glsl +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define VEC4_T ${texel_type(DTYPE)} - -layout(std430) buffer; - -${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "existing_out", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} - -layout(push_constant) uniform restrict Block { - ivec4 range; - - // xyz is source offset w is channel size - ivec4 src_offset; - - // xyz is destination offset w is channel size - ivec4 dst_offset; -}; - -#include "indexing_utils.h" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 out_axis_map = unhash_axis_map(out_layout); -const lowp int packed_dim = unhash_packed_dim(out_layout); - -${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 in_axis_map = unhash_axis_map(in_layout); - -void main() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, range.xyz))) { - return; - } - - // Position in input tensor - ivec3 in_pos = pos + src_offset.xyz; - in_pos[packed_dim] = pos[packed_dim] + (src_offset[packed_dim] >> 2); - - // Read input value mapping to this output texel - VEC4_T in_value = load_texel_lpos(t_in, in_pos, in_axis_map); - - // Starting offset to read from a texel - const int src_lane_offset = src_offset[packed_dim] & 0x3; - const bool has_src_lane_offset = src_lane_offset != 0; - - // If input lane offset is non zero i.e packed texel is composed from multiple sources - if (has_src_lane_offset) { - // Boundary values will come from next input texel in the packed dim. - ivec3 next_in_pos = in_pos; - next_in_pos[packed_dim] = in_pos[packed_dim] + 1; - VEC4_T next_value = load_texel_lpos(t_in, next_in_pos, in_axis_map); - - // Keep input values from the end of current input pixel based on src_lane_offset - // offset 1 means the first lane of current input texel is not a part of the output texel - // offset 2 means first 2 lanes are not and so on - // Copy next texel's values towards the end of input texel, based on lane offset - // offset 1 means the first lane from next texel is part of the input texel - // offset 2 means first 2 lanes from next texel is part of the input texel and so on - if (src_lane_offset == 1) { - in_value = ivec4(in_value.yzw, next_value.x); - } else if (src_lane_offset == 2) { - in_value = ivec4(in_value.zw, next_value.xy); - } else { - in_value = ivec4(in_value.w, next_value.xyz); - } - } - - // Starting offset to write at within a texel - const int out_lane_offset = dst_offset[packed_dim] & 0x3; - const bool has_dst_lane_offset = out_lane_offset != 0; - - ivec3 out_pos = pos + dst_offset.xyz; - out_pos[packed_dim] = pos[packed_dim] + (dst_offset[packed_dim] >> 2); - - VEC4_T out_value; - - // If lane offset is non zero i.e packed texel is composed from multiple sources - if (has_dst_lane_offset) { - // When position in packed dim is > 0 - if (pos[packed_dim] > 0) { - // Boundary values will come from previous input texel in the packed dim. 
- ivec3 prev_in_pos = in_pos; - prev_in_pos[packed_dim] = in_pos[packed_dim] - 1; - VEC4_T prev_value = load_texel_lpos(t_in, prev_in_pos, in_axis_map); - - // Shift values toward the beginning based on out_lane_offset - // offset 1 means the last lane from the previous texel is a part of the output texel - // offset 2 means last 2 lanes and so on - if (out_lane_offset == 1) { - out_value.x = prev_value.w; - } else if (out_lane_offset == 2) { - out_value.xy = prev_value.zw; - } else { - out_value.xyz = prev_value.yzw; - } - } else { - // When position in packed dim is == 0 - // Boundary values will be the previous texel values. - out_value = load_texel_lpos(existing_out, out_pos, out_axis_map); - } - - // Copy input values towards the end of output array, based on lane offset - // offset 1 means the first lane from previous texel is part of the output texel starting at offset - // offset 2 means first 2 lanes from the previous texel is part of the output texel and so on - if (out_lane_offset == 1) { - out_value.yzw = in_value.xyz; - } else if (out_lane_offset == 2) { - out_value.zw = in_value.xy; - } else { - out_value.w = in_value.x; - } - } else { - out_value = in_value; - } - - write_texel_lpos( - t_out, - out_pos, - out_value, - out_axis_map); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml deleted file mode 100644 index 6e55876cb28..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml +++ /dev/null @@ -1,12 +0,0 @@ -copy_packed_dim_offset: - parameter_names_with_default_values: - DTYPE: float - NDIM: 3 - STORAGE: texture3d - generate_variant_forall: - DTYPE: - - VALUE: half - - VALUE: float - - VALUE: int32 - shader_variants: - - NAME: copy_packed_dim_offset diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize.glslh b/backends/vulkan/runtime/graph/ops/glsl/dequantize.glslh deleted file mode 100644 index 7194bebda35..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize.glslh +++ /dev/null @@ -1,16 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#ifndef DEQUANTIZE_GLSLH -#define DEQUANTIZE_GLSLH - -OUT_T dequantize_val(IN_T qvalue, float scale_val, int zero_point_val) { - return OUT_T(float(int(qvalue) - zero_point_val) * scale_val); -} - -#endif // DEQUANTIZE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl deleted file mode 100644 index 57dc2d53fff..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl +++ /dev/null @@ -1,263 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define IN_T ${buffer_scalar_type(IN_DTYPE)} -#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} -#define SCALE_T ${buffer_scalar_type(SCALE_DTYPE)} -#define ZP_T ${buffer_scalar_type(ZP_DTYPE)} - -#define ${MODE} - -${define_active_storage_type("buffer")} -${define_required_extensions(IN_DTYPE)} -${define_required_extensions(OUT_DTYPE)} -${define_required_extensions(SCALE_DTYPE)} -${define_required_extensions(ZP_DTYPE)} - -layout(std430) buffer; - -#include "indexing_utils.h" - -${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")} -${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")} - -$if MODE == "per_tensor": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int quant_min; - int quant_max; - }; -$if MODE == "per_token": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int num_tokens; - int quant_min; - int quant_max; - }; -$if MODE == "per_channel": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int axis; - int num_channels; - int quant_min; - int quant_max; - }; -$if MODE == "block_wise": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - ivec4 blockSize; // bW, bH, bC, bN - ivec4 numBlocks; // tW/bW, tH/bH, tC/bC, tN/bN - ivec4 blockStride; // pre-computed linear strides for the block grid - int quant_min; - int quant_max; - }; - -${layout_declare_ubo(B, "int", "out_numel")} -${layout_declare_ubo(B, "ivec4", "t_in_sizes")} -${layout_declare_ubo(B, "ivec4", "t_in_strides")} -${layout_declare_ubo(B, "ivec4", "t_out_sizes")} -${layout_declare_ubo(B, "ivec4", "t_out_strides")} - -${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} -${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} - -#include "dequantize.glslh" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); -const lowp ivec4 in_dim_order = unhash_dim_order(in_layout); - -/* - Dequantization Shader (Buffer Storage) - This shader converts n-bit integer tensor values back to floating-point representations - using pre-computed quantization parameters (scale and zero_point). The dequantization - reconstructs the original floating-point values from their discrete integer representations - with minimal precision loss. - - Important Considerations: - (+) All input tensors are assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) - (+) The axis map layout is assumed to be a standard layout for scales and zero_points - (++) The scale and zero_point tensors must be implemented as buffers - - Workgroup Configuration: - - dequantize_per_tensor - This mode reverses the uniform quantization applied across the entire tensor by using the - single scale and zero_point values to convert quantized integer values back to their original - floating-point representation. 
- - (*) global_wg_size: default - (*) local_wg_size: default - - - dequantize_per_token - This mode reverses the quantization applied individually to each token (or element) in the - input by using separate scale and zero_point values for each token. For a tensor of shape - [B, S, H], it applies the inverse transformation token-wise across the B*S tokens, converting - quantized values back to their original floating-point representation for each group of H - elements independently. - - (*) global_wg_size: default - (*) local_wg_size: default - - - dequantize_per_channel - This mode reverses the quantization applied separately to each channel of the input tensor - by using distinct scale and zero_point values for each channel. For a tensor of shape - [B, C, H, W] with axis = 1, it applies the inverse transformation channel-wise across the C - channels, converting quantized values back to their original floating-point representation - independently for each channel. - - (*) global_wg_size: default - (*) local_wg_size: default - - - dequantize_block_wise - This mode reverses the block-wise quantization applied to groups of elements by using separate - scale and zero_point values for each block. Equivalent to dequantize_affine, it applies the - inverse affine transformation per block to convert quantized values back to their original - floating-point representation. For example, if the tensor shape is [6, 9, 4] and - blockSize = [3, 3, 2], the tensor is divided into 12 blocks, each containing 18 elements, - and dequantization is performed independently on each block. - - (*) global_wg_size: default - (*) local_wg_size: default - - Dequantization Formula: - value = (qvalue - zero_point) * scale -*/ - -#ifdef per_tensor - -void dequantize_per_tensor() { - const int out_bufi = int(gl_GlobalInvocationID.x); - - if (out_bufi >= out_numel) { - return; - } - - const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); - const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); - - IN_T qvalue = t_in[in_bufi]; - OUT_T value = dequantize_val(qvalue, float(t_scale[0]), int(t_zero_point[0])); - - t_out[out_bufi] = value; -} - -#elif defined(per_token) - -void dequantize_per_token() { - const int out_bufi = int(gl_GlobalInvocationID.x); - - if (out_bufi >= out_numel) { - return; - } - - const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); - const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); - - IN_T qvalue = t_in[in_bufi]; - - int token_idx = 0; - - if (t_out_sizes.w > 1) { - // 4D tensor - token_idx = out_tidx.w * (t_out_sizes.z * t_out_sizes.y) + out_tidx.z * t_out_sizes.y + out_tidx.y; - } else if (t_out_sizes.z > 1) { - // 3D tensor - token_idx = out_tidx.z * t_out_sizes.y + out_tidx.y; - } else if (t_out_sizes.y > 1) { - // 2D tensor - token_idx = out_tidx.y; - } - // For 1D tensor, token_idx remains 0 - - token_idx = min(token_idx, num_tokens - 1); - - OUT_T value = dequantize_val(qvalue, float(t_scale[token_idx]), int(t_zero_point[token_idx])); - - t_out[out_bufi] = value; -} - -#elif defined(per_channel) - -void dequantize_per_channel() { - const int out_bufi = int(gl_GlobalInvocationID.x); - - if (out_bufi >= out_numel) { - return; - } - - const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); - const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); - - IN_T qvalue = t_in[in_bufi]; - - // Calculate channel index based on the dequantization axis (already converted to WHCN) - // The axis parameter is now in WHCN coordinate 
system: - // axis 0 -> W dimension (tidx.x) - // axis 1 -> H dimension (tidx.y) - // axis 2 -> C dimension (tidx.z) - // axis 3 -> N dimension (tidx.w) - int channel_idx = 0; - - if (axis == 0) { - channel_idx = out_tidx.x; - } else if (axis == 1) { - channel_idx = out_tidx.y; - } else if (axis == 2) { - channel_idx = out_tidx.z; - } else if (axis == 3) { - channel_idx = out_tidx.w; - } - - channel_idx = min(channel_idx, num_channels - 1); - - OUT_T value = dequantize_val(qvalue, float(t_scale[channel_idx]), int(t_zero_point[channel_idx])); - - t_out[out_bufi] = value; -} - -#else // block_wise - -void dequantize_block_wise() { - const int out_bufi = int(gl_GlobalInvocationID.x); - - if (out_bufi >= out_numel) { - return; - } - - const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); - const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); - - IN_T qvalue = t_in[in_bufi]; - - const ivec4 bcoord = out_tidx / blockSize; - - const int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; - - const OUT_T value = dequantize_val(qvalue, float(t_scale[block_id]), int(t_zero_point[block_id])); - - t_out[out_bufi] = value; -} - -#endif - -void main() { - dequantize_${MODE}(); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml deleted file mode 100644 index a4375038a75..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml +++ /dev/null @@ -1,31 +0,0 @@ -dequantize_buffer: - parameter_names_with_default_values: - IN_DTYPE: int32 - OUT_DTYPE: float - SCALE_DTYPE: float - ZP_DTYPE: int32 - MODE: per_tensor - generate_variant_forall: - IN_DTYPE: - - VALUE: uint8 - - VALUE: int8 - - VALUE: int32 - OUT_DTYPE: - - VALUE: half - - VALUE: float - - VALUE: double - SCALE_DTYPE: - - VALUE: float - ZP_DTYPE: - - VALUE: int8 - - VALUE: int32 - - VALUE: float - shader_variants: - - NAME: dequantize_per_tensor_buffer - MODE: per_tensor - - NAME: dequantize_per_token_buffer - MODE: per_token - - NAME: dequantize_per_channel_buffer - MODE: per_channel - - NAME: dequantize_block_wise_buffer - MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl deleted file mode 100644 index 19276cd8f7f..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl +++ /dev/null @@ -1,347 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define IN_T ${buffer_scalar_type(IN_DTYPE)} -#define IVEC4_T ${texel_load_type(IN_DTYPE, "texture3d")} - -#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} -#define FVEC4_T ${texel_load_type(OUT_DTYPE, "texture3d")} -#define SCALE_T ${buffer_scalar_type(SCALE_DTYPE)} -#define ZP_T ${buffer_scalar_type(ZP_DTYPE)} - -#define ${MODE} - -${define_active_storage_type("texture3d")} -${define_required_extensions(IN_DTYPE)} -${define_required_extensions(OUT_DTYPE)} -${define_required_extensions(SCALE_DTYPE)} -${define_required_extensions(ZP_DTYPE)} - -#extension GL_EXT_control_flow_attributes : require - -layout(std430) buffer; - -${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")} -${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} - -$if MODE == "per_tensor": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int quant_min; - int quant_max; - }; -$if MODE == "per_token": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int num_tokens; - int quant_min; - int quant_max; - }; -$if MODE == "per_channel": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int axis; - int num_channels; - int quant_min; - int quant_max; - }; -$if MODE == "block_wise": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - ivec4 blockSize; // bW, bH, bC, bN - ivec4 numBlocks; // tW/bW, tH/bH, tC/bC, tN/bN - ivec4 blockStride; // pre-computed linear strides for the block grid - int quant_min; - int quant_max; - }; - -${layout_declare_ubo(B, "ivec3", "t_in_limits")} -${layout_declare_ubo(B, "ivec3", "t_out_limits")} - -#include "indexing_utils.h" -#include "dequantize.glslh" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -/* - * DEQUANTIZATION SHADER (TEXTURE STORAGE) - * - * This shader converts n-bit integer tensor values back to floating-point representations - * using pre-computed quantization parameters (scale and zero_point). The dequantization - * reconstructs the original floating-point values from their discrete integer representations - * with minimal precision loss. - * - * ALGORITHM: - * 1. Load quantized integer texel (4 values) from 3D texture - * 2. Apply dequantization formula to each component: value = (qvalue - zero_point) * scale - * 3. 
Store reconstructed floating-point texel to output texture - * - * WORKGROUP CONFIGURATION: - * - Per-Tensor Mode: - * - Global WG Size: {W, H, C/4} for input size (W, H, C) with width-packing - * - Local WG Size: Default (typically {8, 8, 1} or based on global WG size) - * - Per-Token Mode: - * - Global WG Size: {W, H, C/4} for input size (W, H, C) with width-packing - * - Local WG Size: Default (typically {8, 8, 1} or based on global WG size) - * - * SUPPORTED CONFIGURATIONS: - * - Texture Storage: Uses 3D texture indexing with texel-based processing - * - Assumes width-packed layout (packed_dim = 0) for input/output textures - * - Handles texel padding for non-multiple-of-4 tensor dimensions - * - For per-token mode: scale/zero_point tensors must use buffer storage - * - Input/output textures: Must use standard axis mapping for per-token mode - * - * DEQUANTIZATION FORMULA VISUALIZATION: - * For integer range [quant_min, quant_max] mapped back to [min_val, max_val]: - * - * Integer Domain: Floating Point Domain: - * quant_min ──────────────► min_val - * │ │ - * │ scale = (max_val - min_val) / (quant_max - quant_min) - * │ zero_point = quant_min - round(min_val / scale) - * │ │ - * quant_max ──────────────► max_val - * - * Texel Dequantization Process: - * Input Texel: [-103, -128, -123, -96] (int4) - * Per-component dequantization with scale=0.1, zero_point=-128: - * Component 0: (-103 - (-128)) * 0.1 = 25 * 0.1 = 2.5 - * Component 1: (-128 - (-128)) * 0.1 = 0 * 0.1 = 0.0 - * Component 2: (-123 - (-128)) * 0.1 = 5 * 0.1 = 0.5 - * Component 3: (-96 - (-128)) * 0.1 = 32 * 0.1 = 3.2 - * Output Texel: [2.5, 0.0, 0.5, 3.2] (float4) - * - * PER-TENSOR DEQUANTIZATION: - * - Single scale and zero_point values for entire tensor - * - All texel components use same dequantization parameters - * - Parameters passed as push constants for efficiency - * - Each thread processes one texel (4 elements) independently - * - Formula: value[i] = (qvalue[i] - zero_point) * scale - * - * PER-TOKEN DEQUANTIZATION: - * - Separate scale and zero_point for each token - * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) - * - Parameters stored in buffer arrays indexed by token_id - * - Each thread calculates token_id from its 3D texture position - * - Scale/zero_point buffers accessed directly (not as textures) - * - Formula: value[i] = (qvalue[i] - zero_point[token_id]) * scale[token_id] - * - * Token ID calculation for texel at position (x, y, z): - * - 3D tensor: token_id = z * texture_height + y - * - 2D tensor: token_id = y - * - 1D tensor: token_id = 0 - */ - -#ifdef per_tensor - -void dequantize_per_tensor() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - // Skip if out of bounds - if (any(greaterThanEqual(pos, t_in_limits))) { - return; - } - - IVEC4_T intex = load_texel(t_in, pos); - FVEC4_T outtex; - - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T qvalue = IN_T(intex[i]); - OUT_T value = dequantize_val(qvalue, float(t_scale[0]), int(t_zero_point[0])); - - $if OUT_DTYPE == "double": - outtex[i] = float(value); - $else: - outtex[i] = value; - } - write_texel(t_out, pos, outtex); -} - -#elif defined(per_token) - -void dequantize_per_token() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, t_in_limits))) { - return; - } - - IVEC4_T intex = load_texel(t_in, pos); - - int token_idx = 0; - ivec3 dims = t_in_limits; - - if (dims.z > 1) { - // 3D tensor - token_idx = pos.z * dims.y + pos.y; - } else if (dims.y > 1) { - // 2D 
tensor - token_idx = pos.y; - } - // For 1D tensor, token_idx remains 0 - - token_idx = min(token_idx, num_tokens - 1); - - // Scale and zero_point are prepacked as buffers, so direct access - float scale_val = float(t_scale[token_idx]); - int zero_point_val = int(t_zero_point[token_idx]); - - FVEC4_T outtex; - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T qvalue = IN_T(intex[i]); - OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); - $if OUT_DTYPE == "double": - outtex[i] = float(value); - $else: - outtex[i] = value; - } - - write_texel(t_out, pos, outtex); -} - -#elif defined(per_channel) - -void dequantize_per_channel() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, t_in_limits))) { - return; - } - - IVEC4_T intex = load_texel(t_in, pos); - FVEC4_T outtex; - - // Calculate channel index based on the dequantization axis (already converted to WHCN) - // The axis parameter is now in WHCN coordinate system: - // axis 0 -> W dimension (pos.x) - // axis 1 -> H dimension (pos.y) - // axis 2 -> C dimension (pos.z) - // axis 3 -> N dimension (batch folding in texture storage) - - if (axis == 0) { - // Width dimension - each texel component has different channel index - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T qvalue = IN_T(intex[i]); - int channel_idx = pos.x * 4 + i; - channel_idx = min(channel_idx, num_channels - 1); - - float scale_val = float(t_scale[channel_idx]); - int zero_point_val = int(t_zero_point[channel_idx]); - OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); - $if OUT_DTYPE == "double": - outtex[i] = float(value); - $else: - outtex[i] = value; - } - } else if (axis == 1) { - int channel_idx = pos.y; - channel_idx = min(channel_idx, num_channels - 1); - float scale_val = float(t_scale[channel_idx]); - int zero_point_val = int(t_zero_point[channel_idx]); - - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T qvalue = IN_T(intex[i]); - OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); - $if OUT_DTYPE == "double": - outtex[i] = float(value); - $else: - outtex[i] = value; - } - } else if (axis == 2) { - // Channel dimension - for 4D tensors, need to account for batch-channel folding - // The Z coordinate contains folded batch*channel information - // We need to extract the actual channel index from the folded dimension - int folded_idx = pos.z; - int channel_idx = folded_idx % num_channels; - - float scale_val = float(t_scale[channel_idx]); - int zero_point_val = int(t_zero_point[channel_idx]); - - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T qvalue = IN_T(intex[i]); - OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); - $if OUT_DTYPE == "double": - outtex[i] = float(value); - $else: - outtex[i] = value; - } - } else if (axis == 3) { - // Batch dimension - for 4D tensors, need to account for batch-channel folding - // The Z coordinate contains folded batch*channel information - // We need to extract the actual channel index from the folded dimension - int folded_idx = pos.z; - // In this case num_channels actually corresponds to the number of channels - // the C dimension N(C)HW - int channel_idx = folded_idx / num_channels; - - float scale_val = float(t_scale[channel_idx]); - int zero_point_val = int(t_zero_point[channel_idx]); - - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T qvalue = IN_T(intex[i]); - OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); - $if OUT_DTYPE == "double": - outtex[i] = float(value); - $else: - outtex[i] = value; - } - } - - 
write_texel(t_out, pos, outtex); -} - -#else // block_wise - -void dequantize_block_wise() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, t_in_limits))) - return; - - IVEC4_T intex = load_texel(t_in, pos); - FVEC4_T outtex; - - ivec4 base_tidx = ivec4(pos.x * 4, pos.y, pos.z, 0); - int foldedZ = pos.z; - - int C_total = numBlocks.z * blockSize.z; - - [[unroll]] for (int i = 0; i < 4; ++i) { - ivec4 tidx = ivec4(base_tidx.x + i, base_tidx.y, (foldedZ % C_total), (foldedZ / C_total)); - - ivec4 bcoord = tidx / blockSize; - int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; - - IN_T qvalue = IN_T(intex[i]); - OUT_T value = dequantize_val(qvalue, float(t_scale[block_id]), int(t_zero_point[block_id])); - $if OUT_DTYPE == "double": - outtex[i] = float(value); - $else: - outtex[i] = value; - } - - write_texel(t_out, pos, outtex); -} - -#endif - -void main() { - dequantize_${MODE}(); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml deleted file mode 100644 index 7a58e9410d3..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml +++ /dev/null @@ -1,31 +0,0 @@ -dequantize_texture: - parameter_names_with_default_values: - IN_DTYPE: int32 - OUT_DTYPE: float - SCALE_DTYPE: float - ZP_DTYPE: int32 - MODE: per_tensor - generate_variant_forall: - IN_DTYPE: - - VALUE: uint8 - - VALUE: int8 - - VALUE: int32 - OUT_DTYPE: - - VALUE: half - - VALUE: float - - VALUE: double - SCALE_DTYPE: - - VALUE: float - ZP_DTYPE: - - VALUE: int8 - - VALUE: int32 - - VALUE: float - shader_variants: - - NAME: dequantize_per_tensor_texture3d - MODE: per_tensor - - NAME: dequantize_per_token_texture3d - MODE: per_token - - NAME: dequantize_per_channel_texture3d - MODE: per_channel - - NAME: dequantize_block_wise_texture3d - MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml b/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml deleted file mode 100644 index 0e7b491c433..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml +++ /dev/null @@ -1,12 +0,0 @@ -embedding: - parameter_names_with_default_values: - DTYPE: float - NDIM: 3 - STORAGE: texture3d - generate_variant_forall: - DTYPE: - - VALUE: half - - VALUE: float - - VALUE: int32 - shader_variants: - - NAME: embedding diff --git a/backends/vulkan/runtime/graph/ops/glsl/embedding_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/embedding_buffer.glsl new file mode 100644 index 00000000000..c1a21e44c60 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/embedding_buffer.glsl @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
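+ *
+ * Rough summary of the lookup performed below: for an output element at tensor
+ * index (x, i0, i1, ...), the shader reads row = t_indices[i0, i1, ...] and
+ * writes t_out = t_weight[row * width(weight) + x], i.e. element x of the
+ * embedding row selected by the index tensor.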
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define T ${buffer_scalar_type(DTYPE)} + +${define_active_storage_type("buffer")} +${define_required_extensions(DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +#include "indexing.glslh" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_indices", "int", "buffer")} +${layout_declare_tensor(B, "r", "t_weight", DTYPE, "buffer")} + +${layout_declare_ubo(B, "BufferMetadata", "outp")} +${layout_declare_ubo(B, "BufferMetadata", "indices")} +${layout_declare_ubo(B, "BufferMetadata", "weight")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +TensorIndex out_tidx_to_indices_tidx(const TensorIndex out_tidx) { + TensorIndex indices_tidx; + int d = 0; + // First half of the index + [[unroll]] for (uint d = 0; d < ndim(indices); ++d) { + indices_tidx.data[div_4(d)][mod_4(d)] = idx_at(out_tidx, d + 1); + } + [[unroll]] for (uint d = ndim(indices); d < DIMLIMIT; ++d) { + indices_tidx.data[div_4(d)][mod_4(d)] = 0; + } + return indices_tidx; +} + +int load_embedding_idx(const TensorIndex indices_tidx) { + const uint bufi = tensor_idx_to_linear_idx(indices, indices_tidx); + return t_indices[bufi]; +} + +T load_weight_elem(const int embedding_idx, const uint dim_idx) { + uint bufi = uint(embedding_idx) * width(weight) + dim_idx; + return t_weight[bufi]; +} + +void main() { + const uint out_bufi = gl_GlobalInvocationID.x; + if (out_of_bounds(out_bufi, outp)) { + return; + } + + TensorIndex out_tidx = linear_idx_to_tensor_idx(outp, out_bufi); + TensorIndex indices_tidx = out_tidx_to_indices_tidx(out_tidx); + + const uint bufi = tensor_idx_to_linear_idx(indices, indices_tidx); + const int embedding_idx = load_embedding_idx(indices_tidx); + + t_out[out_bufi] = load_weight_elem(embedding_idx, x(out_tidx)); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/embedding_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/embedding_buffer.yaml new file mode 100644 index 00000000000..fdd4d6f13e1 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/embedding_buffer.yaml @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +embedding_buffer: + parameter_names_with_default_values: + DTYPE: float + STORAGE: buffer + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: embedding_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/embedding.glsl b/backends/vulkan/runtime/graph/ops/glsl/embedding_legacy.glsl similarity index 100% rename from backends/vulkan/runtime/graph/ops/glsl/embedding.glsl rename to backends/vulkan/runtime/graph/ops/glsl/embedding_legacy.glsl diff --git a/backends/vulkan/runtime/graph/ops/glsl/embedding_legacy.yaml b/backends/vulkan/runtime/graph/ops/glsl/embedding_legacy.yaml new file mode 100644 index 00000000000..a3cf16db4c4 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/embedding_legacy.yaml @@ -0,0 +1,12 @@ +embedding_legacy: + parameter_names_with_default_values: + DTYPE: float + NDIM: 3 + STORAGE: texture3d + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + - VALUE: int32 + shader_variants: + - NAME: embedding_legacy diff --git a/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.glsl new file mode 100644 index 00000000000..9a6295a8094 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.glsl @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_load_type(DTYPE, "texture3d")} +#define T ${texel_load_component_type(DTYPE, "texture3d")} + +${define_active_storage_type("texture3d")} +${define_required_extensions(DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +#include "common.glslh" +#include "indexing.glslh" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_indices", "int", "texture3d")} +${layout_declare_tensor(B, "r", "t_weight", DTYPE, "buffer")} + +${layout_declare_ubo(B, "TextureMetadata", "outp")} +${layout_declare_ubo(B, "TextureMetadata", "indices")} +${layout_declare_ubo(B, "BufferMetadata", "weight")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +int load_embedding_idx(const TensorIndex4D out_tidx) { + TensorIndex4D indices_tidx; + indices_tidx.data.xyz = out_tidx.data.yzw; + indices_tidx.data.w = 0; + + TextureElementIndex elem_pos = tensor4d_idx_to_texture_element_idx_simple( + indices, indices_tidx); + + const ivec4 in_texel = texelFetch(t_indices, elem_pos.pos, 0); + return in_texel[elem_pos.comp]; +} + +VEC4_T load_weight_texel(const int embedding_idx, const int dim_idx) { + int buf_i = embedding_idx * int(width(weight)) + dim_idx; + VEC4_T weight_texel; + [[unroll]] for (int i = 0; i < 4; ++i) { + weight_texel[i] = T(t_weight[buf_i++]); + } + return weight_texel; +} + +void main() { + const ivec3 out_pos = ivec3(gl_GlobalInvocationID); + + if (out_of_bounds(out_pos, outp)) { + return; + } + + TensorIndex4D out_tidx = texture_pos_to_tensor4d_idx_simple(outp, out_pos); + const int embedding_idx = load_embedding_idx(out_tidx); + + const VEC4_T weight_texel = load_weight_texel(embedding_idx, out_tidx.data.x); + + imageStore(t_out, out_pos, weight_texel); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.yaml new file mode 
100644 index 00000000000..475db0941ce --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.yaml @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +embedding_texture: + parameter_names_with_default_values: + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: embedding_texture3d diff --git a/backends/vulkan/runtime/graph/ops/glsl/expand_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/expand_buffer.yaml index 6d90e1fa8b1..887f7893061 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/expand_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/expand_buffer.yaml @@ -6,5 +6,6 @@ expand_buffer: - VALUE: half - VALUE: float - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: expand_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/full.yaml b/backends/vulkan/runtime/graph/ops/glsl/full.yaml index eff78a7938d..5d7a983cae3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/full.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/full.yaml @@ -14,5 +14,7 @@ full: DTYPE: - VALUE: half - VALUE: float + - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: full diff --git a/backends/vulkan/runtime/graph/ops/glsl/gather_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/gather_buffer.glsl new file mode 100644 index 00000000000..318631a160f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/gather_buffer.glsl @@ -0,0 +1,57 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
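 *
 * gather_buffer: one invocation per output element. The index tensor is read
 * at the output element's own position, and the value found there replaces
 * the coordinate along gather_dim before the element is loaded from the
 * input tensor.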
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define T ${buffer_scalar_type(DTYPE)} + +${define_active_storage_type("buffer")} +${define_required_extensions(DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +#include "indexing.glslh" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_input", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_index", "int", "buffer")} + +${layout_declare_ubo(B, "BufferMetadata", "outp")} +${layout_declare_ubo(B, "BufferMetadata", "inp")} +${layout_declare_ubo(B, "BufferMetadata", "index")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int gather_dim = 0; + +void main() { + const uint out_bufi = gl_GlobalInvocationID.x; + if (out_of_bounds(out_bufi, outp)) { + return; + } + + TensorIndex out_tidx = linear_idx_to_tensor_idx(outp, out_bufi); + + // Load the index value at the same position in the index tensor + const uint index_bufi = tensor_idx_to_linear_idx(index, out_tidx); + const int gather_idx = t_index[index_bufi]; + + // Construct the input tensor index by replacing the gather dimension + // with the gathered index value + TensorIndex input_tidx = out_tidx; + input_tidx.data[div_4(gather_dim)][mod_4(gather_dim)] = gather_idx; + + // Load from input tensor and store to output + const uint input_bufi = tensor_idx_to_linear_idx(inp, input_tidx); + + t_out[out_bufi] = t_input[input_bufi]; +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/gather_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/gather_buffer.yaml new file mode 100644 index 00000000000..8e2cff21b61 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/gather_buffer.yaml @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +gather_buffer: + parameter_names_with_default_values: + DTYPE: float + STORAGE: buffer + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + - VALUE: int32 + - VALUE: uint8 + shader_variants: + - NAME: gather_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/gather_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/gather_texture.glsl new file mode 100644 index 00000000000..71e352a7875 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/gather_texture.glsl @@ -0,0 +1,67 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
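 *
 * gather_texture: one invocation per output texel. The index texel fetched at
 * the same position supplies four gather indices, one per packed component;
 * each component substitutes its index along gather_dim and copies the
 * corresponding single element out of the input texture.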
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_load_type(DTYPE, "texture3d")} +#define T ${texel_load_component_type(DTYPE, "texture3d")} + +${define_active_storage_type("texture3d")} +${define_required_extensions(DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +#include "common.glslh" +#include "indexing.glslh" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_input", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_index", "int", "texture3d")} + +${layout_declare_ubo(B, "TextureMetadata", "outp")} +${layout_declare_ubo(B, "TextureMetadata", "inp")} +${layout_declare_ubo(B, "TextureMetadata", "index")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int gather_dim = 0; + +void main() { + const ivec3 out_pos = ivec3(gl_GlobalInvocationID); + + if (out_of_bounds(out_pos, outp)) { + return; + } + + TensorIndex4D out_tidx = texture_pos_to_tensor4d_idx_simple(outp, out_pos); + ivec4 idx_texel = texelFetch(t_index, out_pos, 0); + + VEC4_T out_texel = VEC4_T(0); + + int limit = min( + 4, outp.sizes[outp.packed_dim] - out_tidx.data[outp.packed_dim]); + for (int comp = 0; comp < 4; comp++) { + TensorIndex4D input_tidx = out_tidx; + int gather_idx = idx_texel[comp]; + input_tidx.data[gather_dim] = gather_idx; + + TextureElementIndex input_elem_pos = tensor4d_idx_to_texture_element_idx_simple( + inp, input_tidx); + + VEC4_T input_texel = texelFetch(t_input, input_elem_pos.pos, 0); + out_texel[comp] = input_texel[input_elem_pos.comp]; + + out_tidx.data[outp.packed_dim]++; + } + + imageStore(t_out, out_pos, out_texel); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/gather_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/gather_texture.yaml new file mode 100644 index 00000000000..dd38ecd0a7d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/gather_texture.yaml @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +gather_texture: + parameter_names_with_default_values: + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + - VALUE: int32 + - VALUE: uint8 + shader_variants: + - NAME: gather_texture3d diff --git a/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8.glsl b/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8.glsl new file mode 100644 index 00000000000..3ecaa597ecc --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8.glsl @@ -0,0 +1,73 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
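 *
 * Converts a packed int8 input image into im2col layout. Each invocation
 * produces one 4x4 block of the im2col matrix (four consecutive output-width
 * positions by four consecutive entries of the flattened kernel-window /
 * input-channel axis), substituting the input zero point for out-of-bounds
 * taps before repacking the block into an ivec4.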
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +$if STORAGE == "buffer": + #define PACKED_INT8_OUTPUT_BUFFER + #define PACKED_INT8_INPUT_BUFFER + +#define TILE_M4 1 +#define TILE_N4 1 +#define TILE_K4 1 + +#define TILE_M 4 +#define TILE_N 4 +#define TILE_K 4 + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", STORAGE, is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "im2col_sizes")} +// Sizes of the output image +${layout_declare_ubo(B, "ivec4", "output_sizes")} +// Sizes of the input image +${layout_declare_ubo(B, "ivec4", "input_sizes")} + +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(push_constant) uniform restrict Block { + float inv_scale; + int zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "conv2d_int8_output_tile_store.glslh" +#include "im2col_packed_int8_utils.glslh" + +void main() { + const int out_buf_idx = int(gl_GlobalInvocationID.x); + Conv2dBlockExtents im2col_block_extents = make_block_extents(im2col_sizes); + + Conv2dBlockIndex im2col_block_idx = linear_idx_to_block_idx( + out_buf_idx, im2col_block_extents); + + if (block_idx_out_of_bounds(im2col_block_idx, im2col_block_extents)) { + return; + } + + Im2ColBlockLoadIndices load_ixs = im2col_block_idx_to_load_ixs( + im2col_block_idx); + + Conv2dBlockExtents input_block_extents = make_block_extents(input_sizes); + + const ivec4 input_zps = ivec4(pack_into_int32(ivec4(zp))); + Int8OutTile int8_im2col_tile; + int8_im2col_tile.data[0][0] = load_im2col_block( + load_ixs, input_block_extents, zp, input_zps); + + store_packed_int8_output_tile( + int8_im2col_tile, im2col_block_idx, im2col_block_extents); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8.yaml b/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8.yaml new file mode 100644 index 00000000000..1d00ddddd6e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8.yaml @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +im2col_packed_int8: + parameter_names_with_default_values: + STORAGE: buffer + generate_variant_forall: + STORAGE: + - VALUE: buffer + - VALUE: texture3d + shader_variants: + - NAME: im2col_packed_int8 diff --git a/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8_utils.glslh b/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8_utils.glslh new file mode 100644 index 00000000000..f2617aec7c7 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8_utils.glslh @@ -0,0 +1,287 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
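 *
 * The helpers below map an im2col coordinate (w, h, k_in_group, group) back
 * to an input-image coordinate: k_in_group splits into an input channel
 * within the group plus a (kernel_x, kernel_y) tap, and each spatial
 * coordinate is computed as w * stride - padding + kernel * dilation. When
 * the addressed channels are 4-aligned and the width stride is 1, whole
 * packed rows can be reused directly; otherwise elements are gathered one
 * byte at a time.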
+ */ + +#ifndef IM2COL_PACKED_INT8_GLSLH +#define IM2COL_PACKED_INT8_GLSLH + +#include "indexing.glslh" + +struct Conv2dBlockElementIndex { + int x4; + int y; + int z4; + + int row; + int col; +}; + +struct Im2ColBlockLoadIndices { + bool block_aligned; + bool cols_aligned; + bool rows_contiguous; + + int im2col_w_start; + int im2col_h; + int k_in_group_start; + int group_idx; + + Conv2dBlockElementIndex block_idx_start; +}; + +Conv2dBlockElementIndex tidx_to_block_elem_idx(const TensorIndex4D tidx) { + Conv2dBlockElementIndex block_idx; + block_idx.x4 = div_4(tidx.data.x); + block_idx.row = mod_4(tidx.data.x); + + block_idx.y = tidx.data.y; + + block_idx.z4 = div_4(tidx.data.z); + block_idx.col = mod_4(tidx.data.z); + + return block_idx; +} + +TensorIndex4D get_input_tensor_tidx( + const int w, + const int h, + const int k_in_group, + const int group_idx) { + TensorIndex4D tidx; + tidx.data.w = 0; + + const int c_in_group = k_in_group % conv2d_params.in_channels_per_group; + const int row = k_in_group / conv2d_params.in_channels_per_group; + const int kernel_x = row % conv2d_params.kernel_size.x; + const int kernel_y = row / conv2d_params.kernel_size.x; + + tidx.data.z = group_idx * conv2d_params.in_channels_per_group + c_in_group; + + tidx.data.x = (w * conv2d_params.stride.x) - conv2d_params.padding.x + + (kernel_x * conv2d_params.dilation.x); + tidx.data.y = (h * conv2d_params.stride.y) - conv2d_params.padding.y + + (kernel_y * conv2d_params.dilation.y); + + return tidx; +} + +Im2ColBlockLoadIndices im2col_block_idx_to_load_ixs( + Conv2dBlockIndex im2col_block_idx) { + const int im2col_w = mul_4(im2col_block_idx.data.x); + const int im2col_h = im2col_block_idx.data.y; + const int im2col_k = mul_4(im2col_block_idx.data.z); + + const int group_idx = im2col_k / conv2d_params.K_per_group; + const int k_in_group = im2col_k % conv2d_params.K_per_group; + + TensorIndex4D input_tidx = + get_input_tensor_tidx(im2col_w, im2col_h, k_in_group, group_idx); + + bool cols_aligned = (mod_4(input_tidx.data.z) == 0) && + (input_tidx.data.z + 3 < conv2d_params.in_channels_per_group); + + bool rows_aligned = mod_4(input_tidx.data.x) == 0; + bool rows_contiguous = conv2d_params.stride.x == 1; + + Im2ColBlockLoadIndices load_ixs; + load_ixs.block_aligned = cols_aligned && rows_aligned && rows_contiguous; + load_ixs.cols_aligned = cols_aligned; + load_ixs.rows_contiguous = rows_contiguous; + + load_ixs.im2col_w_start = im2col_w; + load_ixs.im2col_h = im2col_h; + load_ixs.k_in_group_start = k_in_group; + load_ixs.group_idx = group_idx; + + load_ixs.block_idx_start = tidx_to_block_elem_idx(input_tidx); + + return load_ixs; +} + +bool is_block_elem_idx_in_bounds( + const Conv2dBlockElementIndex idx, + const Conv2dBlockExtents block_extents) { + const ivec3 block_idx = ivec3(idx.x4, idx.y, idx.z4); + if (any(lessThan(block_idx, ivec3(0))) || + any(greaterThanEqual(block_idx, block_extents.data))) { + return false; + } + return true; +} + +int load_packed_int8_input_element( + const Conv2dBlockElementIndex idx, + const Conv2dBlockExtents block_extents, + const int input_zp) { + // bounds checking + if (!is_block_elem_idx_in_bounds(idx, block_extents)) { + return input_zp; + } +#ifdef PACKED_INT8_INPUT_BUFFER + const int buf_idx = + idx.y * block_extents.data_xz + idx.x4 * block_extents.data.z + idx.z4; + const ivec4 tile = t_packed_int8_input[buf_idx]; +#else + const ivec4 tile = + texelFetch(t_packed_int8_input, ivec3(idx.x4, idx.y, idx.z4), 0); +#endif + return extract_8bit_from_packed_int_le(tile[idx.row], 
idx.col); +} + +Conv2dBlockElementIndex get_packed_int8_input_element_idx( + const int im2col_w, + const int im2col_h, + const int k_in_group, + const int group_idx) { + TensorIndex4D input_tidx = + get_input_tensor_tidx(im2col_w, im2col_h, k_in_group, group_idx); + + return tidx_to_block_elem_idx(input_tidx); +} + +ivec4 load_im2col_block_aligned( + const Im2ColBlockLoadIndices load_ixs, + const Conv2dBlockExtents block_extents) { +#ifdef PACKED_INT8_INPUT_BUFFER + const int buf_idx = load_ixs.block_idx_start.y * block_extents.data_xz + + load_ixs.block_idx_start.x4 * block_extents.data.z + + load_ixs.block_idx_start.z4; + return t_packed_int8_input[buf_idx]; +#else + return texelFetch( + t_packed_int8_input, + ivec3( + load_ixs.block_idx_start.x4, + load_ixs.block_idx_start.y, + load_ixs.block_idx_start.z4), + 0); +#endif +} + +ivec4 load_im2col_block_c_aligned_w_contiguous( + const Im2ColBlockLoadIndices load_ixs, + const Conv2dBlockExtents block_extents, + const ivec4 input_zps) { + ivec4 im2col_block; + Conv2dBlockElementIndex block_elem_idx = load_ixs.block_idx_start; + +#ifdef PACKED_INT8_INPUT_BUFFER + int buf_idx = load_ixs.block_idx_start.y * block_extents.data_xz + + load_ixs.block_idx_start.x4 * block_extents.data.z + + load_ixs.block_idx_start.z4; +#endif + + ivec4 in_block = input_zps; + if (is_block_elem_idx_in_bounds(block_elem_idx, block_extents)) { +#ifdef PACKED_INT8_INPUT_BUFFER + in_block = t_packed_int8_input[buf_idx]; +#else + in_block = texelFetch( + t_packed_int8_input, + ivec3(block_elem_idx.x4, block_elem_idx.y, block_elem_idx.z4), + 0); +#endif + } + + int current_row = 0; + int r_limit = min(4 - block_elem_idx.row, 4); + for (int r = 0; r < r_limit; r++) { + im2col_block[current_row++] = in_block[r + block_elem_idx.row]; + } + + in_block = input_zps; + block_elem_idx.x4++; +#ifdef PACKED_INT8_INPUT_BUFFER + buf_idx += block_extents.data.z; +#endif + + if (is_block_elem_idx_in_bounds(block_elem_idx, block_extents)) { +#ifdef PACKED_INT8_INPUT_BUFFER + in_block = t_packed_int8_input[buf_idx]; +#else + in_block = texelFetch( + t_packed_int8_input, + ivec3(block_elem_idx.x4, block_elem_idx.y, block_elem_idx.z4), + 0); +#endif + } + + for (int r = 0; current_row < 4; ++r) { + im2col_block[current_row++] = in_block[r]; + } + + return im2col_block; +} + +ivec4 load_im2col_block_no_alignment( + const Im2ColBlockLoadIndices load_ixs, + const Conv2dBlockExtents block_extents, + const int input_zp) { + ivec4 im2col_block; + + for (int r = 0; r < 4; r++) { + const int im2col_w = load_ixs.im2col_w_start + r; + ivec4 row_values; + for (int c = 0; c < 4; c++) { + const int k_in_group = load_ixs.k_in_group_start + c; + + if (k_in_group >= conv2d_params.logical_K_per_group) { + row_values[c] = input_zp; + continue; + } + + Conv2dBlockElementIndex block_idx = get_packed_int8_input_element_idx( + im2col_w, load_ixs.im2col_h, k_in_group, load_ixs.group_idx); + + row_values[c] = + load_packed_int8_input_element(block_idx, block_extents, input_zp); + } + + im2col_block[r] = pack_into_int32(row_values); + } + return im2col_block; +} + +ivec4 load_im2col_block( + const Im2ColBlockLoadIndices load_ixs, + const Conv2dBlockExtents block_extents, + const int input_zp, + const ivec4 input_zps) { + if (load_ixs.cols_aligned && load_ixs.rows_contiguous) { + return load_im2col_block_c_aligned_w_contiguous( + load_ixs, block_extents, input_zps); + } + return load_im2col_block_no_alignment(load_ixs, block_extents, input_zp); +} + +#ifdef DEBUG_MODE + +void printLoadIndices(const 
Im2ColBlockLoadIndices load_ixs) { + debugPrintfEXT("LoadIndices: \\n"); + + if (load_ixs.block_aligned) { + debugPrintfEXT(" block_aligned \\n"); + } + if (load_ixs.cols_aligned) { + debugPrintfEXT(" cols_aligned \\n"); + } + if (load_ixs.rows_contiguous) { + debugPrintfEXT(" rows_contiguous \\n"); + } + + debugPrintfEXT( + " block_idx_start: %d %d %d || %d %d \\n", + load_ixs.block_idx_start.x4, + load_ixs.block_idx_start.y, + load_ixs.block_idx_start.z4, + load_ixs.block_idx_start.row, + load_ixs.block_idx_start.col); +} + +#endif + +#endif // IM2COL_PACKED_INT8_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl index d7bef9f0163..1498ed01aef 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl @@ -16,10 +16,11 @@ ${define_active_storage_type(STORAGE)} ${define_required_extensions(DTYPE)} +${define_required_extensions(BUF_DTYPE)} layout(std430) buffer; -${layout_declare_buffer(B, "w", "buf_out", DTYPE)} +${layout_declare_buffer(B, "w", "buf_out", BUF_DTYPE)} ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} $if USE_PUSH_CONST: diff --git a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml index 646d8f1be81..ebbc55dd9dc 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml @@ -7,17 +7,21 @@ image_to_nchw: parameter_names_with_default_values: DTYPE: float + BUF_DTYPE: float STORAGE: texture3d TO_STAGING: True USE_PUSH_CONST: True generate_variant_forall: - DTYPE: - - VALUE: half - - VALUE: float - - VALUE: double - - VALUE: int8 - - VALUE: uint8 - - VALUE: int32 + combination: + parameter_names: [DTYPE, BUF_DTYPE] + combos: + - parameter_values: [half, half] + - parameter_values: [half, float] + - parameter_values: [float, float] + - parameter_values: [double, double] + - parameter_values: [int8, int8] + - parameter_values: [uint8, uint8] + - parameter_values: [int32, int32] shader_variants: - NAME: image_to_nchw_texture3d - NAME: image_to_nchw_texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml b/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml index abef2225cd9..6bf4c71a3c0 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml @@ -8,5 +8,6 @@ index_select: - VALUE: half - VALUE: float - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: index_select diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml b/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml index a306e3ce47d..716f7ecf2d0 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml @@ -8,5 +8,6 @@ index_select_channel: - VALUE: half - VALUE: float - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: index_select_channel diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh b/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh index 81783422ab4..b9ac0e5dace 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh @@ -9,14 +9,11 @@ #ifndef INDEXING_GLSLH #define INDEXING_GLSLH +#include "common.glslh" + #define DIMLIMIT 8 #define DIMLIMIT_DIV4 2 -#define mul_4(x) ((x) << 2) 
-#define div_4(x) ((x) >> 2) - -#define mod_4(x) ((x) & 3) - // // BufferMetadata // @@ -56,6 +53,14 @@ uint stride_at(const BufferMetadata meta, const uint dim) { return meta.strides[div_4(dim)][mod_4(dim)]; } +uint width(const BufferMetadata meta) { + return meta.sizes[0][0]; +} + +uint height(const BufferMetadata meta) { + return meta.sizes[0][1]; +} + uint size_at(const BufferMetadata meta, const int dim) { return meta.sizes[div_4(dim)][mod_4(dim)]; } @@ -81,6 +86,25 @@ bool are_equal(const BufferMetadata meta1, const BufferMetadata meta2) { return true; } +bool out_of_bounds(const uint bufi, const BufferMetadata meta) { + return bufi >= meta.ndim_numel[1]; +} + +// +// TextureMetadata +// + +struct TextureMetadata { + ivec4 sizes; + ivec3 limits; + ivec4 axis_map; + int packed_dim; +}; + +bool out_of_bounds(const ivec3 pos, const TextureMetadata meta) { + return any(greaterThanEqual(pos, meta.limits)); +} + // // TensorIndex // @@ -98,6 +122,10 @@ uint idx_at(const TensorIndex tidx, const int dim) { return tidx.data[div_4(dim)][mod_4(dim)]; } +uint idx_at(const TensorIndex tidx, const uint dim) { + return tidx.data[div_4(dim)][mod_4(dim)]; +} + void permute(inout TensorIndex tidx, const ivec4 permute_order[DIMLIMIT_DIV4]) { TensorIndex new_tidx = tidx; for (int d = 0; d < DIMLIMIT; ++d) { @@ -107,6 +135,41 @@ void permute(inout TensorIndex tidx, const ivec4 permute_order[DIMLIMIT_DIV4]) { tidx = new_tidx; } +uint x(const TensorIndex tidx) { + return tidx.data[0][0]; +} + +// +// TensorIndex4D (useful for texture backed tensors) +// + +struct TensorIndex4D { + ivec4 data; +}; + +TensorIndex4D zero_tensor4d_idx() { + TensorIndex4D tidx; + tidx.data = ivec4(0); + return tidx; +} + +bool out_of_bounds(const TensorIndex4D tidx, const BufferMetadata meta) { + return any(greaterThanEqual(tidx.data, meta.sizes[0])); +} + +bool out_of_bounds(const TensorIndex4D tidx, const TextureMetadata meta) { + return any(greaterThanEqual(tidx.data, meta.sizes)); +} + +// +// TextureElementIndex +// + +struct TextureElementIndex { + ivec3 pos; + int comp; +}; + // // Index Conversions // @@ -133,6 +196,14 @@ void contiguous_idx_to_tensor_idx( } } +TensorIndex contiguous_idx_to_tensor_idx( + const BufferMetadata meta, + uint contiguous_idx) { + TensorIndex tidx; + contiguous_idx_to_tensor_idx(meta, contiguous_idx, tidx); + return tidx; +} + uint tensor_idx_to_contiguous_idx( const BufferMetadata meta, const TensorIndex tidx) { @@ -165,6 +236,14 @@ void linear_idx_to_tensor_idx( } } +TensorIndex linear_idx_to_tensor_idx( + const BufferMetadata meta, + uint linear_idx) { + TensorIndex tidx; + linear_idx_to_tensor_idx(meta, linear_idx, tidx); + return tidx; +} + uint tensor_idx_to_linear_idx( const BufferMetadata meta, const TensorIndex tidx) { @@ -180,6 +259,80 @@ void clamp_tensor_idx(const BufferMetadata meta, inout TensorIndex tidx) { tidx.data[1] = min(tidx.data[1], meta.sizes[1] - 1); } +// Does not account for axis mapping +TensorIndex4D texture_pos_to_tensor4d_idx_simple( + const TextureMetadata meta, const ivec3 pos) { + TensorIndex4D tidx; + tidx.data.xyz = pos; + tidx.data.w = 0; + tidx.data[meta.packed_dim] *= 4; + + // Compute batch idx accounting for batch concatenation, assuming channels as + // the concatenation dim. 
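  // e.g. with sizes = [W, H, C = 6, N = 2] and channels packed, the
  // concatenated z axis spans N * align_up_4(C) = 16 channel slots, so a
  // logical z of 12 resolves to batch 12 / 8 = 1 and channel 12 % 8 = 4.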
+ if (meta.sizes.w > 1) { + int channels = meta.sizes.z; + if (meta.packed_dim == 2) { + channels = align_up_4(channels); + } + tidx.data.w = tidx.data.z / channels; + tidx.data.z = tidx.data.z % channels; + } + return tidx; +} + +// Does not account for axis mapping +ivec3 tensor4d_idx_to_texel_pos_simple( + const TextureMetadata meta, const TensorIndex4D tidx) { + ivec3 texel_pos; + + const int packed_dim_idx = tidx.data[meta.packed_dim]; + + texel_pos = tidx.data.xyz; + texel_pos[meta.packed_dim] = div_4(packed_dim_idx); + + // Account for batch concatenation, assuming channels as the concatenation dim + if (meta.sizes.w > 1) { + int channels_ntexels = meta.sizes.z; + if (meta.packed_dim == 2) { + channels_ntexels = div_up_4(channels_ntexels); + } + texel_pos.z += tidx.data.w * channels_ntexels; + } + + return texel_pos; +} + +// Does not account for axis mapping +TextureElementIndex tensor4d_idx_to_texture_element_idx_simple( + const TextureMetadata meta, const TensorIndex4D tidx) { + const int packed_dim_idx = tidx.data[meta.packed_dim]; + TextureElementIndex tex_idx; + tex_idx.pos = tidx.data.xyz; + tex_idx.pos[meta.packed_dim] = div_4(packed_dim_idx); + tex_idx.comp = mod_4(packed_dim_idx); + + // Account for batch concatenation, assuming channels as the concatenation dim + if (meta.sizes.w > 1) { + int channels_ntexels = meta.sizes.z; + if (meta.packed_dim == 2) { + channels_ntexels = div_up_4(channels_ntexels); + } + tex_idx.pos.z += tidx.data.w * channels_ntexels; + } + + return tex_idx; +} + +uint tensor4d_idx_to_linear_idx( + const BufferMetadata meta, + const TensorIndex4D tidx) { + uint lin_idx = 0; + for (int d = 0; d < 4; ++d) { + lin_idx += meta.strides[0][d] * tidx.data[d]; + } + return lin_idx; +} + // // Debug utilities // @@ -194,6 +347,21 @@ void printTensorIndex(const TensorIndex tidx) { ); } +void printTensorIndex4D(const TensorIndex4D tidx) { + debugPrintfEXT( + "TensorIndex4D: [%u, %u, %u, %u]\\n", + tidx.data[0], tidx.data[1], tidx.data[2], tidx.data[3] + ); +} + +void printTextureElementIndex(const TextureElementIndex tex_idx) { + debugPrintfEXT( + "TextureElementIndex: pos=[%d %d %d] comp=%d\\n", + tex_idx.pos.x, tex_idx.pos.y, tex_idx.pos.z, tex_idx.comp + ); +} + + void printBufferMetadata(const BufferMetadata meta) { debugPrintfEXT( "BufferMetadata: ndim=%u numel=%u\\n sizes=[%u %u %u %u %u %u %u %u]\\n dim_order=[%u %u %u %u %u %u %u %u]\\n strides=[%u %u %u %u %u %u %u %u]\\n", @@ -211,6 +379,16 @@ void printBufferMetadata(const BufferMetadata meta) { ); } +void printTextureMetadata(const TextureMetadata meta) { + debugPrintfEXT( + "TextureMetadata:\\n sizes=[%u %u %u %u]\\n limits=[%u %u %u]\\n axis_map=[%u %u %u %u]\\n packed_dim=%u\\n", + meta.sizes[0], meta.sizes[1], meta.sizes[2], meta.sizes[3], + meta.limits[0], meta.limits[1], meta.limits[2], + meta.axis_map[0], meta.axis_map[1], meta.axis_map[2], meta.axis_map[3], + meta.packed_dim + ); +} + #endif #endif // INDEXING_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh index da326b26e93..c95abdcb230 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh @@ -16,19 +16,6 @@ #include "common.glslh" -int sign_extend_8bit(const int val) { - if ((val & 0x80) != 0) { - return val | (~0xFF); - } - return val; -} - -int extract_8bit_from_packed_int_le(const int packed, const int i) { - // account for little endian - int byte = 
sign_extend_8bit(packed >> (8 * i) & 0xFF); - return byte; -} - // Extract a 4-bit value from a packed int (little endian) // It is assumed that the 4-bit value is in the range [0, 15] int extract_4bit_from_packed_int_le(const int packed, const int col) { diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.yaml index cb9cdc4a046..a252055ed40 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.yaml @@ -16,6 +16,7 @@ linear_dq8ca_q4gsw_tiled: generate_variant_forall: DTYPE: - VALUE: float + - VALUE: half shader_variants: - NAME: linear_dq8ca_q4gsw_tiled_texture3d_texture2d - NAME: linear_dq8ca_q4gsw_tiled_texture3d_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int4_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int4_compute.glslh index 7bc7071ab1f..0a11ed6f482 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int4_compute.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int4_compute.glslh @@ -77,7 +77,7 @@ void accumulate_out_tile_with_int_accum_from_int4_weights( out_tile.data[m][n4] = fma(VEC4_T(accum_adjusted), - input_scale_m * weight_scales.data[n4], + VEC4_T(input_scale_m * weight_scales.data[n4]), out_tile.data[m][n4]); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh index 68ac269e9d7..850dc7943c0 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh @@ -75,7 +75,7 @@ void accumulate_out_tile_with_int_accum( input_zp_vec * weight_sums.data[n4] + accum.data[m][n4]; out_tile.data[m][n4] = fma(VEC4_T(accum_adjusted), - input_q_scale * weight_scales.data[0], + VEC4_T(input_q_scale * weight_scales.data[n4]), out_tile.data[m][n4]); } } @@ -98,7 +98,7 @@ void accumulate_out_tile_with_int_accum( input_zp_vec * weight_sums.data[n4] + accum.data[m][n4]; out_tile.data[m][n4] = fma(VEC4_T(accum_adjusted), - input_q_scale * weight_scales.data[n4], + VEC4_T(input_q_scale * weight_scales.data[n4]), out_tile.data[m][n4]); out_tile.data[m][n4] += bias.data[n4]; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh index a6dbd7e78a2..8f19418cd19 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh @@ -43,13 +43,6 @@ ivec4 quantize( return clamp(ivec4(quantized), -128, 127); } -int pack_into_int32(const ivec4 quant_vals) { - int packed = ((quant_vals[0] & 0xFF) << 0) | ((quant_vals[1] & 0xFF) << 8) | - ((quant_vals[2] & 0xFF) << 16) | ((quant_vals[3] & 0xFF) << 24); - - return packed; -} - void quantize_and_pack( out Int8InputBlock packed, const FPInputTile in_block, diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile.glslh index 5d8f78bae7c..177e0741269 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile.glslh @@ -31,13 
+31,13 @@ void printInt8InputTile(const Int8InputTile tile) { [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { - debugPrintfEXT(" tile[%d][%d] (ivec4): ", m4, k4); + debugPrintfEXT(" tile[%d][%d]:\\n", m4, k4); // Each ivec4 contains 4 packed integers, each integer contains 4 8-bit // values [[unroll]] for (int vec_idx = 0; vec_idx < 4; ++vec_idx) { int packed_int = tile.data[m4][k4][vec_idx]; - debugPrintfEXT("packed_int[%d]=%d -> [", vec_idx, packed_int); + debugPrintfEXT(" [", vec_idx, packed_int); // Extract 4 8-bit values from this packed integer [[unroll]] for (int byte_idx = 0; byte_idx < 4; ++byte_idx) { @@ -48,6 +48,7 @@ void printInt8InputTile(const Int8InputTile tile) { debugPrintfEXT("%d] ", val); } } + debugPrintfEXT("(packed=%d)\\n", packed_int); } debugPrintfEXT("\\n"); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_output_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_output_tile.glslh new file mode 100644 index 00000000000..cc3c4e9e089 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_output_tile.glslh @@ -0,0 +1,68 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Macro Settings: + * - TILE_M + * - TILE_N4 + */ + +#ifndef LINEAR_INT8_OUTPUT_TILE_GLSLH +#define LINEAR_INT8_OUTPUT_TILE_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +struct Int8OutTile { + ivec4 data[TILE_M4][TILE_N4]; +}; + +void initialize(out Int8OutTile tile) { + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + tile.data[m4][n4] = ivec4(0); + } + } +} + +#ifdef DEBUG_MODE + +#include "linear_common.glslh" + +void printInt8OutTile(const Int8OutTile tile) { + debugPrintfEXT( + "Int8OutTile [TILE_M4=%d][TILE_N4=%d]:\\n", TILE_M4, TILE_N4); + + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + debugPrintfEXT(" tile[%d][%d]:\\n", m4, n4); + + // Each ivec4 contains 4 packed integers, each integer contains 4 8-bit + // values + [[unroll]] for (int vec_idx = 0; vec_idx < 4; ++vec_idx) { + int packed_int = tile.data[m4][n4][vec_idx]; + debugPrintfEXT(" [", vec_idx, packed_int); + + // Extract 4 8-bit values from this packed integer + [[unroll]] for (int byte_idx = 0; byte_idx < 4; ++byte_idx) { + int val = extract_8bit_from_packed_int_le(packed_int, byte_idx); + if (byte_idx < 3) { + debugPrintfEXT("%d, ", val); + } else { + debugPrintfEXT("%d] ", val); + } + } + debugPrintfEXT("(packed=%d)\\n", packed_int); + } + debugPrintfEXT("\\n"); + } + } +} + +#endif // DEBUG_MODE + +#endif // LINEAR_INT8_OUTPUT_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_output_tile_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_output_tile_compute.glslh new file mode 100644 index 00000000000..1251ca60b87 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_output_tile_compute.glslh @@ -0,0 +1,93 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines functions to compute a FPOutTile using int8 input and weight tiles. 
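 * The produced tile is a requantized Int8OutTile: the int32 accumulator has
 * input_zp * weight_sums subtracted, is scaled by weight_scale * input_scale
 * (plus bias in the overload that takes one), multiplied by the inverse
 * output scale, rounded, offset by the output zero point, clamped to
 * [-128, 127] and packed back four int8 values per int32.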
+ * + * Settings: + * - TILE_M: The number of rows in the output tile. + * - TILE_N4: The number of (groups of 4) columns in the output tile. + */ + +#ifndef LINEAR_INT8_OUTPUT_TILE_INT8_INT8_COMPUTE_GLSLH +#define LINEAR_INT8_OUTPUT_TILE_INT8_INT8_COMPUTE_GLSLH + +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_integer_dot_product : require + +#include "linear_fp_per_out_channel_params.glslh" +#include "linear_int8_output_tile.glslh" +#include "linear_int_accumulator.glslh" +#include "linear_int_per_out_channel_params.glslh" + +void compute_int8_out_tile_with_int32_accum( + out Int8OutTile out_tile, + const Int32Accum accum, + const float input_q_scale, + const int input_q_zp, + const float output_q_inv_scale, + const int output_q_zp, + const IntPerOutChannelParams weight_sums, + const FPPerOutChannelParams weight_scales) { + ivec4 input_zp_vec = ivec4(-input_q_zp); + ivec4 output_zp_vec = ivec4(-output_q_zp); + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + [[unroll]] for (int m4i = 0; m4i < 4; ++m4i) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + const int m = mul_4(m4) + m4i; + // Compute floating point output values + ivec4 accum_adjusted = + input_zp_vec * weight_sums.data[n4] + accum.data[m][n4]; + vec4 float_out_texel = + vec4(accum_adjusted) * vec4(weight_scales.data[n4] * input_q_scale); + // Requantize to int8 + float_out_texel = + round(float_out_texel * output_q_inv_scale) + output_q_zp; + ivec4 quantized_out_texel = clamp(ivec4(float_out_texel), -128, 127); + + out_tile.data[m4][n4][m4i] = pack_into_int32(quantized_out_texel); + } + } + } +} + +void compute_int8_out_tile_with_int32_accum( + out Int8OutTile out_tile, + const Int32Accum accum, + const float input_q_scale, + const int input_q_zp, + const float output_q_inv_scale, + const int output_q_zp, + const IntPerOutChannelParams weight_sums, + const FPPerOutChannelParams weight_scales, + const FPPerOutChannelParams bias) { + ivec4 input_zp_vec = ivec4(-input_q_zp); + ivec4 output_zp_vec = ivec4(-output_q_zp); + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + [[unroll]] for (int m4i = 0; m4i < 4; ++m4i) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + const int m = mul_4(m4) + m4i; + // Compute floating point output values + ivec4 accum_adjusted = + input_zp_vec * weight_sums.data[n4] + accum.data[m][n4]; + vec4 float_out_texel = + fma(vec4(accum_adjusted), + vec4(weight_scales.data[n4]) * input_q_scale, + vec4(bias.data[n4])); + // Requantize to int8 + float_out_texel = + round(float_out_texel * output_q_inv_scale) + output_q_zp; + ivec4 quantized_out_texel = clamp(ivec4(float_out_texel), -128, 127); + + out_tile.data[m4][n4][m4i] = pack_into_int32(quantized_out_texel); + } + } + } +} + +#endif // LINEAR_INT8_OUTPUT_TILE_INT8_INT8_COMPUTE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.glsl index 0ad91643219..878821d4189 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_tiled.glsl @@ -76,9 +76,6 @@ void main() { const int N4 = div_up_4(output_sizes.x); // number of texels in each row const int N8 = div_up_8(output_sizes.x); // number of texels in each row - bool should_print = (n8 == 0) && (m4 == 0); - should_print = false; - // VEC4_T out_texels[4][2]; FPOutTile out_tile; initialize(out_tile); diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml 
b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml index aa1de3077fc..989729f2d7f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q8ta_q8csw_tiled.yaml @@ -11,7 +11,7 @@ linear_q8ta_q8csw_tiled: PACKED_INT8_INPUT_STORAGE: buffer WEIGHT_STORAGE: texture2d TILE_M4: 1 - TILE_N4: 1 + TILE_N4: 2 TILE_K4: 1 generate_variant_forall: DTYPE: diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.glsl index c766a3cd7d0..31e04c3a86a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.glsl @@ -98,12 +98,17 @@ void main() { // Preload weight tensor [[unroll]] for (int r = 0; r < 4; r++) { $if QUANT_NBITS == 4: + $if WEIGHT_STORAGE == "buffer": + u8vec4 packed_weight_tex; + $else: + uvec4 packed_weight_tex; + $for c in range(0, TILE_TXCOLS, 2): $if WEIGHT_STORAGE == "buffer": qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol; - const u8vec4 packed_weight_tex = t_weight[qmat2_bufi + ${c}] + packed_weight_tex = t_weight[qmat2_bufi + ${c}] $else: - const uvec4 packed_weight_tex = texelFetch( + packed_weight_tex = texelFetch( t_weight, ivec2(weight_txcol + ${c}, pos + r), 0); qmat2[r][${c}] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0); diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.yaml index 3dff6855142..f05dc7104c4 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.yaml @@ -12,7 +12,7 @@ linear_qcsnw_coop: WEIGHT_STORAGE: texture2d SCALES_STORAGE: texture2d TILE_ROWS: 4 - TILE_TXCOLS: 1 + TILE_TXCOLS: 2 QUANT_NBITS: 8 generate_variant_forall: TILE_ROWS: diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl index f6f05aab7ca..d966de7282e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl @@ -18,19 +18,17 @@ ${define_required_extensions(DTYPE)} -$if WEIGHT_STORAGE == "buffer": - ${define_required_extensions("int8")} - -#extension GL_EXT_control_flow_attributes : require - layout(std430) buffer; ${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_in", DTYPE, IN_STORAGE, is_scalar_array=False)} -$if QUANT_NBITS == 4: - ${layout_declare_tensor(B, "r", "t_weight", "uint8", WEIGHT_STORAGE, is_scalar_array=False)} +$if WEIGHT_STORAGE == "buffer": + ${layout_declare_tensor(B, "r", "t_weight", "uint", WEIGHT_STORAGE, is_scalar_array=True)} $else: - ${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)} + $if QUANT_NBITS == 4: + ${layout_declare_tensor(B, "r", "t_weight", "uint8", WEIGHT_STORAGE, is_scalar_array=False)} + $else: + ${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_scales", DTYPE, SCALES_STORAGE, is_scalar_array=False)} @@ -49,108 +47,156 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { // txcol stands for "texel column". One txcol corresponds to 4 scalar columns. 
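  // Each invocation computes a TILE_ROWS x (TILE_TXCOLS * 4) tile of the
  // output; out_txcol is the first texel column of that tile and out_row its
  // first row.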
$if TILE_TXCOLS > 1: - const uint16_t global_wg_x = uint16_t(divup(out_sizes.x, 4 * TILE_TXCOLS)); - const uint16_t out_txcol = uint16_t( - (gl_GlobalInvocationID.x % global_wg_x) * TILE_TXCOLS); + const int global_wg_x = divup(out_sizes.x, 4 * TILE_TXCOLS); + const int out_txcol = (int(gl_GlobalInvocationID.x) % global_wg_x) * TILE_TXCOLS; $else: - const uint16_t global_wg_x = uint16_t(divup4(out_sizes.x)); - const uint16_t out_txcol = uint16_t(gl_GlobalInvocationID.x % global_wg_x); + const int global_wg_x = divup4(out_sizes.x); + const int out_txcol = int(gl_GlobalInvocationID.x) % global_wg_x; - const uint16_t out_row = uint16_t( - (gl_GlobalInvocationID.x / global_wg_x) * TILE_ROWS); + const int out_row = (int(gl_GlobalInvocationID.x) / global_wg_x) * TILE_ROWS; $if QUANT_NBITS == 4: - const uint16_t weight_txcol = uint16_t(out_txcol / 2); + const int weight_txcol = out_txcol / 2; - if (out_row >= uint16_t(out_sizes.y)) { + if (out_row >= int(out_sizes.y)) { return; } - VEC4_T mat1[TILE_ROWS]; - VEC4_T qmat2[4][TILE_TXCOLS]; - VEC4_T sums[TILE_ROWS][TILE_TXCOLS]; + T sums[TILE_ROWS * TILE_TXCOLS * 4]; - VEC4_T scales[TILE_TXCOLS]; - $for c in range(TILE_TXCOLS): - $if SCALES_STORAGE == "buffer": - scales[${c}] = VEC4_T(t_scales[out_txcol + ${c}]); - $else: - scales[${c}] = VEC4_T( - texelFetch(t_scales, u16vec2(out_txcol + ${c}, 0), 0)); + $if QUANT_NBITS == 4: + // accumulate mat1 elements sum so -8 bias can be applied using it later. + T mat1_accum[TILE_ROWS]; + $for r in range(TILE_ROWS): + mat1_accum[${r}] = T(0.0); - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + for (int r = 0; r < TILE_ROWS; ++r) { $for c in range(TILE_TXCOLS): - sums[r][${c}] = VEC4_T(0.0); + $for j in range(4): + sums[r * TILE_TXCOLS * 4 + ${c} * 4 + ${j}] = T(0.0); } - for (uint16_t pos = uint16_t(0), txpos = uint16_t(0); - pos < uint16_t(in_sizes.x); - pos += uint16_t(4), txpos += uint16_t(1)) { + const int in_row_txstride = div4(in_sizes.x); + + $if WEIGHT_STORAGE == "buffer": + $if QUANT_NBITS == 4: + uint qmat2_bufi = weight_txcol; + $else: + uint qmat2_bufi = out_txcol; + + for (int pos = 0, txpos = 0; + txpos < in_row_txstride; + pos += 4, txpos += 1) { + + T mat1[TILE_ROWS * 4]; + + // Preload input tensor + for (int i = 0; i < TILE_ROWS; i++) { + $if IN_STORAGE == "buffer": + VEC4_T mat1_vec4 = t_in[(out_row + i) * in_row_txstride + txpos]; + $else: + VEC4_T mat1_vec4 = VEC4_T(texelFetch(t_in, ivec3(txpos, out_row + i, 0), 0)); + $for j in range(4): + mat1[i * 4 + ${j}] = mat1_vec4[${j}]; + + $if QUANT_NBITS == 4: + // Accumulate mat1 element sum, this will be multiplied with -8 later for converting 4 bit data to a signed number. 
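        // Each 4-bit weight w is stored as the unsigned value q = w + 8, so
        // sum_k (q_k - 8) * x_k = sum_k q_k * x_k - 8 * sum_k x_k; keeping a
        // running sum of x per input row lets the -8 correction be applied
        // once, after the accumulation loop.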
+ mat1_accum[i] += mat1[i * 4 + 0] + mat1[i * 4 + 1] + mat1[i * 4 + 2] + mat1[i * 4 + 3]; + } + $if WEIGHT_STORAGE == "buffer": - uint qmat2_bufi; uint weight_row_txstride = div4(weight_sizes.x); + uint encoded_weight; // Preload weight tensor - [[unroll]] for (int r = 0; r < 4; r++) { + for (int r = 0; r < 4; r++) { + T qmat2[TILE_TXCOLS * 4]; + $if QUANT_NBITS == 4: + uvec4 packed_weight_tex; + $else: + ivec4 packed_weight_tex; + $if QUANT_NBITS == 4: $for c in range(0, TILE_TXCOLS, 2): $if WEIGHT_STORAGE == "buffer": - qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol; - const u8vec4 packed_weight_tex = t_weight[qmat2_bufi + ${c}] + encoded_weight = t_weight[qmat2_bufi + ${c}]; + qmat2[${c} * 4 * TILE_TXCOLS + 0] = T((encoded_weight >> 4) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 1] = T((encoded_weight >> 12) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 2] = T((encoded_weight >> 20) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 3] = T((encoded_weight >> 28)); + + qmat2[${c} * 4 * TILE_TXCOLS + 4] = T((encoded_weight) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 5] = T((encoded_weight >> 8) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 6] = T((encoded_weight >> 16) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 7] = T((encoded_weight >> 24) & 0xF); $else: - const uvec4 packed_weight_tex = texelFetch( - t_weight, u16vec2(weight_txcol + ${c}, pos + r), 0); - - qmat2[r][${c}] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0); - qmat2[r][${c + 1}] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0); + packed_weight_tex = texelFetch( + t_weight, ivec2(weight_txcol + ${c}, pos + r), 0); + qmat2[${c} * 4 * TILE_TXCOLS + 0] = T(packed_weight_tex.x >> 4); + qmat2[${c} * 4 * TILE_TXCOLS + 1] = T(packed_weight_tex.y >> 4); + qmat2[${c} * 4 * TILE_TXCOLS + 2] = T(packed_weight_tex.z >> 4); + qmat2[${c} * 4 * TILE_TXCOLS + 3] = T(packed_weight_tex.w >> 4); + + qmat2[${c} * 4 * TILE_TXCOLS + 4] = T(packed_weight_tex.x & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 5] = T(packed_weight_tex.y & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 6] = T(packed_weight_tex.z & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 7] = T(packed_weight_tex.w & 0xF); $else: $for c in range(TILE_TXCOLS): $if WEIGHT_STORAGE == "buffer": - qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol; - qmat2[r][${c}] = t_weight[qmat2_bufi + ${c}]; + encoded_weight = t_weight[qmat2_bufi + ${c}]; + packed_weight_tex = ivec4(encoded_weight & 0xFF, (encoded_weight >> 8) & 0xFF, (encoded_weight >> 16) & 0xFF, encoded_weight >> 24); $else: - qmat2[r][${c}] = VEC4_T( - texelFetch(t_weight, u16vec2(out_txcol + ${c}, pos + r), 0)); - } - - $if IN_STORAGE == "buffer": - uint in_row_txstride = div4(in_sizes.x); - - // Preload input tensor - [[unroll]] for (int i = 0; i < TILE_ROWS; i++) { - $if IN_STORAGE == "buffer": - mat1[i] = t_in[(out_row + i) * in_row_txstride + txpos]; - $else: - mat1[i] = VEC4_T( - texelFetch(t_in, u16vec3(txpos, out_row + i, 0), 0)); - } + packed_weight_tex = ivec4(texelFetch(t_weight, ivec2(out_txcol + ${c}, pos + r), 0)); + $for j in range(4): + qmat2[${c} * 4 + ${j}] = T(packed_weight_tex[${j}]); - // Accumulate output - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - $for c in range(TILE_TXCOLS): - sums[r][${c}] += mat1[r].x * qmat2[0][${c}] + - mat1[r].y * qmat2[1][${c}] + - mat1[r].z * qmat2[2][${c}] + - mat1[r].w * qmat2[3][${c}]; + for (int tr = 0; tr < TILE_ROWS; ++tr) { + $for c in range(TILE_TXCOLS): + $for j in range(4): + sums[tr * TILE_TXCOLS * 4 + ${c} * 4 + ${j}] += qmat2[${c} * 4 + ${j}] * mat1[tr * 4 + r]; + } + $if WEIGHT_STORAGE == 
"buffer": + qmat2_bufi += weight_row_txstride; } } + VEC4_T scales[TILE_TXCOLS]; + $for c in range(TILE_TXCOLS): + $if SCALES_STORAGE == "buffer": + scales[${c}] = VEC4_T(t_scales[out_txcol + ${c}]); + $else: + scales[${c}] = VEC4_T( + texelFetch(t_scales, ivec2(out_txcol + ${c}, 0), 0)); + // Store to output tensor $if OUT_STORAGE == "buffer": uint out_bufi; uint out_row_txstride = div4(out_sizes.x); - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + for (int r = 0; r < TILE_ROWS; ++r) { + VEC4_T scaled_sums; $for c in range(TILE_TXCOLS): + $if QUANT_NBITS == 4: + scaled_sums.x = (sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 0] + mat1_accum[r] * -8.0) * scales[${c}].x; + scaled_sums.y = (sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 1] + mat1_accum[r] * -8.0) * scales[${c}].y; + scaled_sums.z = (sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 2] + mat1_accum[r] * -8.0) * scales[${c}].z; + scaled_sums.w = (sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 3] + mat1_accum[r] * -8.0) * scales[${c}].w; + $else: + scaled_sums.x = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 0] * scales[${c}].x; + scaled_sums.y = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 1] * scales[${c}].y; + scaled_sums.z = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 2] * scales[${c}].z; + scaled_sums.w = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 3] * scales[${c}].w; + $if OUT_STORAGE == "buffer": if (out_row + r < out_sizes.y) { out_bufi = (out_row + r) * out_row_txstride + out_txcol; - t_out[out_bufi + ${c}] = sums[r][${c}] * scales[${c}]; + t_out[out_bufi + ${c}] = scaled_sums; } $else: imageStore( t_out, ivec3(out_txcol + ${c}, out_row + r, 0), - sums[r][${c}] * scales[${c}]); + scaled_sums); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml index 1c9ec4e524a..81824a12026 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml @@ -20,6 +20,8 @@ linear_qcsnw_tiled: SUFFIX: o4x1 - VALUE: 2 SUFFIX: o4x2 + - VALUE: 3 + SUFFIX: o4x3 - VALUE: 4 SUFFIX: o4x4 shader_variants: @@ -35,8 +37,18 @@ linear_qcsnw_tiled: - NAME: linear_qcs4w_tiled_texture3d_texture3d_texture2d_texture2d_float TILE_TXCOLS: 2 QUANT_NBITS: 4 + - NAME: linear_qcs4w_tiled_texture3d_texture3d_buffer_texture2d_float + TILE_TXCOLS: 2 + QUANT_NBITS: 4 + WEIGHT_STORAGE: buffer - NAME: linear_qcs4w_tiled_buffer_buffer_texture2d_texture2d_float IN_STORAGE: buffer OUT_STORAGE: buffer TILE_TXCOLS: 2 QUANT_NBITS: 4 + - NAME: linear_qcs4w_tiled_buffer_buffer_buffer_texture2d_float + IN_STORAGE: buffer + OUT_STORAGE: buffer + WEIGHT_STORAGE: buffer + TILE_TXCOLS: 2 + QUANT_NBITS: 4 diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl index 074624dc37e..a16f5405cbb 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl @@ -5,13 +5,14 @@ #define T ${buffer_scalar_type(DTYPE)} ${define_required_extensions(DTYPE)} +${define_required_extensions(BUF_DTYPE)} layout(std430) buffer; #include "indexing.glslh" ${layout_declare_tensor(B, "w", "t_outp", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "nchw_in", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "nchw_in", BUF_DTYPE, STORAGE)} ${layout_declare_ubo(B, "BufferMetadata", "outp")} @@ -44,5 +45,5 @@ void main() { nchwi = tensor_idx_to_contiguous_idx(outp, outp_tidx); } - t_outp[outp_bufi] = nchw_in[nchwi]; + t_outp[outp_bufi] = 
T(nchw_in[nchwi]); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml index 9d6c3aa76a9..602fd1bc65a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml @@ -7,15 +7,19 @@ nchw_to_buffer: parameter_names_with_default_values: DTYPE: float + BUF_DTYPE: float STORAGE: buffer USE_PUSH_CONST: True generate_variant_forall: - DTYPE: - - VALUE: half - - VALUE: float - - VALUE: double - - VALUE: int8 - - VALUE: uint8 - - VALUE: int32 + combination: + parameter_names: [DTYPE, BUF_DTYPE] + combos: + - parameter_values: [half, half] + - parameter_values: [half, float] + - parameter_values: [float, float] + - parameter_values: [double, double] + - parameter_values: [int8, int8] + - parameter_values: [uint8, uint8] + - parameter_values: [int32, int32] shader_variants: - NAME: nchw_to_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl index f3f604e10cd..15676fb0500 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl @@ -16,11 +16,12 @@ ${define_active_storage_type(STORAGE)} ${define_required_extensions(DTYPE)} +${define_required_extensions(BUF_DTYPE)} layout(std430) buffer; ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} -${layout_declare_buffer(B, "r", "buf_in", DTYPE)} +${layout_declare_buffer(B, "r", "buf_in", BUF_DTYPE)} $if USE_PUSH_CONST: layout(push_constant) uniform restrict Block { diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml index 85119c8d508..f6809e4024a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml @@ -11,13 +11,16 @@ nchw_to_image: FROM_STAGING: True USE_PUSH_CONST: True generate_variant_forall: - DTYPE: - - VALUE: half - - VALUE: float - - VALUE: double - - VALUE: int8 - - VALUE: uint8 - - VALUE: int32 + combination: + parameter_names: [DTYPE, BUF_DTYPE] + combos: + - parameter_values: [half, half] + - parameter_values: [half, float] + - parameter_values: [float, float] + - parameter_values: [double, double] + - parameter_values: [int8, int8] + - parameter_values: [uint8, uint8] + - parameter_values: [int32, int32] shader_variants: - NAME: nchw_to_image_texture3d - NAME: nchw_to_image_texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.glsl b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.glsl index 0079526c248..18e9b4c7275 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.glsl @@ -12,12 +12,14 @@ $if not NO_INT8_BUFFERS: ${define_required_extensions("uint8")} -$if STORAGE == "buffer": - ${define_required_extensions("int8")} layout(std430) buffer; -${layout_declare_tensor(B, "w", "t_qmat2", "uint8", STORAGE, is_scalar_array=False)} +$if STORAGE == "buffer" and NO_INT8_BUFFERS: + ${layout_declare_tensor(B, "w", "t_qmat2", "uint", STORAGE, is_scalar_array=True)} +$else: + ${layout_declare_tensor(B, "w", "t_qmat2", "uint8", STORAGE, is_scalar_array=False)} + $if NO_INT8_BUFFERS: ${layout_declare_tensor(B, "r", "nchw_4x2", "uint", "buffer")} 
$else: @@ -35,7 +37,10 @@ $else: #define BUF_T uint8_t $if STORAGE == "buffer": - #define UVEC4_T u8vec4 + $if NO_INT8_BUFFERS: + #define UVEC4_T uvec4 + $else: + #define UVEC4_T u8vec4 $else: #define UVEC4_T uvec4 @@ -48,7 +53,7 @@ uint get_second(const BUF_T packed) { } uint combine(const uint first, const uint second) { - return (first << 4 | second); + return first * 16 + second; } $if NO_INT8_BUFFERS: @@ -155,8 +160,12 @@ void main() { $if STORAGE == "buffer": int stride = qmat2_sizes.x >> 2; - t_qmat2[packed_pos.y * stride + packed_pos.x] = out_tex_1; - t_qmat2[(packed_pos.y + 1) * stride + packed_pos.x] = out_tex_2; + $if NO_INT8_BUFFERS: + t_qmat2[packed_pos.y * stride + packed_pos.x] = out_tex_1.x | (out_tex_1.y << 8) | (out_tex_1.z << 16) | (out_tex_1.w << 24); + t_qmat2[(packed_pos.y + 1) * stride + packed_pos.x] = out_tex_2.x | (out_tex_2.y << 8) | (out_tex_2.z << 16) | (out_tex_2.w << 24); + $else: + t_qmat2[packed_pos.y * stride + packed_pos.x] = out_tex_1; + t_qmat2[(packed_pos.y + 1) * stride + packed_pos.x] = out_tex_2; $else: imageStore(t_qmat2, packed_pos.xy, out_tex_1); imageStore(t_qmat2, ivec2(packed_pos.x, packed_pos.y + 1), out_tex_2); diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.yaml b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.yaml index 145f4301f14..6bddb4c62cd 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.yaml @@ -14,3 +14,6 @@ pack_int4_linear_weight_transposed_interleaved: STORAGE: buffer - NAME: pack_int4_linear_weight_transposed_interleaved_nobitw8buffer_texture2d NO_INT8_BUFFERS: true + - NAME: pack_int4_linear_weight_transposed_interleaved_nobitw8buffer_buffer + STORAGE: buffer + NO_INT8_BUFFERS: true diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_dw_weights.glsl b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_dw_weights.glsl new file mode 100644 index 00000000000..da4162b6e58 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_dw_weights.glsl @@ -0,0 +1,72 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +${define_active_storage_type(STORAGE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_packed_int8_weight", "int", STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_int8_weight", "int", "buffer")} + +layout(push_constant) uniform restrict Block { + ivec4 qmat2_sizes; + ivec3 orig_sizes; // [K_h, aligned_K_w, OC] +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "common.glslh" + +void main() { + // The size of the source weight tensor is [K_h, aligned_K_w, OC] for depthwise conv. + // Each shader invocation processes a 4x4 block of weights for a group of output channels. 
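  // Each int in the source buffer already packs the int8 weights of one
  // 4-wide output-channel group at a single (h, w) kernel position; rows
  // r = 0..3 of the output block take the packings at w .. w + 3 (zero-filled
  // past the kernel width), and the block is written to position (oc4, k4).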
+ const int oc4 = int(gl_GlobalInvocationID.x); + const int k4 = int(gl_GlobalInvocationID.y); + const int k = mul_4(k4); + + const int H = orig_sizes.x; + const int orig_W = orig_sizes.y; + const int W4 = div_up_4(orig_W); + const int OC = orig_sizes.z; + + const int h = k4 / W4; + const int w4 = k4 % W4; + const int w = mul_4(w4); + + // Determine the total number of blocks and check bounds + const int OC4 = div_up_4(OC); + const int K4 = H * W4; + + if (oc4 >= OC4 || k4 >= K4) { + return; + } + + ivec4 packed_block; + + int buf_idx = (h * orig_W + w) * OC4 + oc4; + int r_limit = min(4, orig_W - w); + [[unroll]] for (int r = 0; r < r_limit; r++) { + packed_block[r] = t_int8_weight[buf_idx]; + buf_idx += OC4; + } + [[unroll]] for (int r = r_limit; r < 4; r++) { + packed_block[r] = 0; + } + +#ifdef USING_BUFFER + t_packed_int8_weight[k4 * OC4 + oc4] = packed_block; +#else + imageStore(t_packed_int8_weight, ivec2(oc4, k4), packed_block); +#endif +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_dw_weights.yaml b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_dw_weights.yaml new file mode 100644 index 00000000000..9cfa3108ff0 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_dw_weights.yaml @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +pack_q8_conv2d_dw_weights: + parameter_names_with_default_values: + STORAGE: buffer + generate_variant_forall: + STORAGE: + - VALUE: buffer + - VALUE: texture2d + shader_variants: + - NAME: pack_q8_conv2d_dw_weights diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_weights.glsl b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_weights.glsl new file mode 100644 index 00000000000..e9982a8273d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_weights.glsl @@ -0,0 +1,82 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
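To make the depthwise weight re-layout in pack_q8_conv2d_dw_weights above easier to follow, here is a hedged NumPy sketch of the index math. It assumes the source buffer can be viewed as a flat [K_h, K_w, OC4] array of already-packed 32-bit words and that each destination block gathers four consecutive kernel-width positions for one group of output channels, zero-padding past the true width. Names and shapes here are illustrative only.

# Hedged sketch (NumPy): index math of pack_q8_conv2d_dw_weights.
import numpy as np

def pack_dw_weights(src, K_h, K_w, OC4):
    W4 = (K_w + 3) // 4
    dst = np.zeros((K_h * W4, OC4, 4), dtype=np.int32)
    for h in range(K_h):
        for w4 in range(W4):
            for oc4 in range(OC4):
                for r in range(4):
                    w = w4 * 4 + r
                    if w < K_w:
                        # mirrors buf_idx = (h * orig_W + w) * OC4 + oc4
                        dst[h * W4 + w4, oc4, r] = src[(h * K_w + w) * OC4 + oc4]
    return dst

src = np.arange(3 * 3 * 2, dtype=np.int32)   # K_h=3, K_w=3, OC4=2
print(pack_dw_weights(src, 3, 3, 2).shape)   # (3, 2, 4)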
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +${define_active_storage_type(STORAGE)} + +#extension GL_EXT_control_flow_attributes : require + +${define_required_extensions("int8")} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_packed_int8_weight", "int", STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_int8_weight", "int8", "buffer")} + +layout(push_constant) uniform restrict Block { + ivec4 qmat2_sizes; + ivec4 orig_sizes; // [OC, K_h, K_w, IC] +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "common.glslh" + +void main() { + const int block_x = int(gl_GlobalInvocationID.x); + const int block_y = int(gl_GlobalInvocationID.y); + + const int kx = block_x % orig_sizes.z; + const int oc4 = block_x / orig_sizes.z; + + const int OC4 = div_up_4(orig_sizes.x); + const int IC4 = div_up_4(orig_sizes.w); + + const int nblocks_x = orig_sizes.z * OC4; + const int nblocks_y = IC4 * orig_sizes.y; + + const int ic4 = block_y % IC4; + const int ky = block_y / IC4; + + if (block_x >= nblocks_x || block_y >= nblocks_y) { + return; + } + + const int oc = mul_4(oc4); + const int ic = mul_4(ic4); + + const int oc_stride = align_up_4(orig_sizes.y * orig_sizes.z * orig_sizes.w); + const int oc_offset = oc * oc_stride; + const int ky_offset = ky * (orig_sizes.z * orig_sizes.w); + const int kx_offset = kx * orig_sizes.w; + int buf_idx = oc_offset + ky_offset + kx_offset + ic; + + ivec4 packed_block = ivec4(0); + for (int row = 0; row < 4; row++) { + if (oc + row < orig_sizes.x) { + ivec4 weight_vals = ivec4(0); + for (int col = 0; col < 4; col++) { + if (ic + col < orig_sizes.w) { + weight_vals[col] = int(t_int8_weight[buf_idx + col]); + } + } + packed_block[row] = pack_into_int32(weight_vals); + } + buf_idx += oc_stride; + } + +#ifdef USING_BUFFER + const int out_buf_idx = block_y * (nblocks_x) + block_x; + t_packed_int8_weight[out_buf_idx] = packed_block; +#else + imageStore(t_packed_int8_weight, ivec2(block_x, block_y), packed_block); +#endif +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_weights.yaml b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_weights.yaml new file mode 100644 index 00000000000..9331de6e758 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_q8_conv2d_weights.yaml @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
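The regular conv2d weight packing shader above gathers a 4 (output channel) by 4 (input channel) block and relies on pack_into_int32 from common.glslh, which is not shown in this diff. The sketch below assumes that helper packs four signed 8-bit values into consecutive bytes of a 32-bit word, with lane 0 in the lowest byte; treat it as an assumption, not a statement of the real implementation.

# Hedged sketch (Python): assumed behavior of pack_into_int32. The real helper
# lives in common.glslh and may differ in byte order.
def pack_into_int32(vals4):
    word = 0
    for i, v in enumerate(vals4):
        word |= (v & 0xFF) << (8 * i)
    # reinterpret as a signed 32-bit integer
    return word - (1 << 32) if word & (1 << 31) else word

print(hex(pack_into_int32([1, -1, 2, -2]) & 0xFFFFFFFF))  # 0xfe02ff01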
+ +pack_q8_conv2d_weights: + parameter_names_with_default_values: + STORAGE: buffer + generate_variant_forall: + STORAGE: + - VALUE: buffer + - VALUE: texture2d + shader_variants: + - NAME: pack_q8_conv2d_weights diff --git a/backends/vulkan/runtime/graph/ops/glsl/pad_channel.yaml b/backends/vulkan/runtime/graph/ops/glsl/pad_channel.yaml index 02afc3846a2..91306bd4cbf 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/pad_channel.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/pad_channel.yaml @@ -8,5 +8,7 @@ pad_channel: DTYPE: - VALUE: float - VALUE: half + - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: pad_channel diff --git a/backends/vulkan/runtime/graph/ops/glsl/pad_height_width.yaml b/backends/vulkan/runtime/graph/ops/glsl/pad_height_width.yaml index dd74ec9cc28..2eb57291bb2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/pad_height_width.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/pad_height_width.yaml @@ -8,5 +8,7 @@ pad_height_width: DTYPE: - VALUE: float - VALUE: half + - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: pad_height_width diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.yaml index 81675ae8917..6fe5a67c286 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.yaml @@ -6,5 +6,6 @@ permute_buffer: - VALUE: half - VALUE: float - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: permute_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/permute_texture.yaml index f68b8dcdd3d..22d1bdd7b51 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/permute_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/permute_texture.yaml @@ -6,5 +6,6 @@ permute_texture: - VALUE: half - VALUE: float - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: permute_texture3d diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w.glsl similarity index 100% rename from backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.glsl rename to backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w.glsl diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w.yaml new file mode 100644 index 00000000000..e453214bc1a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w.yaml @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
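Several YAML files in this diff (nchw_to_buffer above, quantize_and_pack_4h4w below, and the sdpa_* files later) switch from a flat per-parameter list to a combination block. Conceptually this asks the shader codegen to cross the listed parameter combos with the remaining forall values; the Python below is only a sketch of that idea, with hypothetical variant construction, and is not the actual codegen.

# Illustrative sketch (Python): conceptual expansion of a `combination` block.
from itertools import product

combos = [("texture3d", "texture3d"), ("buffer", "texture3d"), ("buffer", "buffer")]
dtypes = ["half", "float"]

variants = []
for (out_storage, in_storage), dtype in product(combos, dtypes):
    variants.append({
        "OUTPUT_STORAGE": out_storage,
        "INPUT_STORAGE": in_storage,
        "DTYPE": dtype,
    })

print(len(variants))  # 3 storage combos x 2 dtypes = 6 variants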
+ +quantize_and_pack_4h4w: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: texture3d + INPUT_STORAGE: texture3d + STORAGE: texture3d + GRANULARITY: per_tensor + generate_variant_forall: + combination: + parameter_names: [OUTPUT_STORAGE, INPUT_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: quantize_and_pack_4h4w_per_tensor diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input_with_sums.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w_with_group_sums.glsl similarity index 100% rename from backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input_with_sums.glsl rename to backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w_with_group_sums.glsl diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w_with_group_sums.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w_with_group_sums.yaml new file mode 100644 index 00000000000..bdbc81c59d7 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w_with_group_sums.yaml @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +quantize_and_pack_4h4w_with_group_sums: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: buffer + INPUT_STORAGE: texture3d + NUM_GROUPS_PER_WG: 2 + NUM_WORKERS_PER_GROUP: 32 + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: quantize_and_pack_4h4w_with_group_sums_o2w32_buffer_texture3d + - NAME: quantize_and_pack_4h4w_with_group_sums_o2w32_buffer_buffer + OUTPUT_STORAGE: buffer + INPUT_STORAGE: buffer + - NAME: quantize_and_pack_4h4w_with_group_sums_o4w16_buffer_texture3d + NUM_GROUPS_PER_WG: 4 + NUM_WORKERS_PER_GROUP: 16 + - NAME: quantize_and_pack_4h4w_with_group_sums_o4w16_buffer_buffer + NUM_GROUPS_PER_WG: 4 + NUM_WORKERS_PER_GROUP: 16 + OUTPUT_STORAGE: buffer + INPUT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4w4c.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4w4c.glsl new file mode 100644 index 00000000000..dfa0b5a95bf --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4w4c.glsl @@ -0,0 +1,77 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)} +#define T ${texel_load_component_type(DTYPE, INPUT_STORAGE)} + +// corresponds to the input width dim +#define TILE_M4 1 +// corresponds to the input channels dim +#define TILE_K4 1 + +#define TILE_M 4 + +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER +$if INPUT_STORAGE == "buffer": + #define INPUT_BUFFER + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_input", "int", OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_fp_input", DTYPE, INPUT_STORAGE)} + +${layout_declare_ubo(B, "ivec4", "input_sizes")} + +layout(push_constant) uniform restrict Block { + float inv_scale; + int zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "conv2d_fp_input_tile_load.glslh" +#include "linear_int8_input_block.glslh" + +void store_packed_int8_block( + const Conv2dBlockIndex block_idx, + const Conv2dBlockExtents block_extents, + const Int8InputBlock packed_int8_block) { +#ifdef OUTPUT_BUFFER + const int buffer_idx = block_idx.data.y * block_extents.data_xz + + block_idx.data.x * block_extents.data.z + block_idx.data.z; + t_packed_int8_input[buffer_idx] = packed_int8_block.data; +#else + imageStore(t_packed_int8_input, block_idx.data, packed_int8_block.data); +#endif +} + +void main() { + Conv2dBlockIndex block_idx; + block_idx.data = ivec3(gl_GlobalInvocationID); + + Conv2dBlockExtents block_extents = make_block_extents(input_sizes); + if (block_idx_out_of_bounds(block_idx, block_extents)) { + return; + } + + FPInputTile fp_tile; + load_fp_input_tile(fp_tile, block_idx); + + Int8InputBlock int8_block; + quantize_and_pack(int8_block, fp_tile, inv_scale, zp); + + store_packed_int8_block(block_idx, block_extents, int8_block); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4w4c.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4w4c.yaml new file mode 100644 index 00000000000..fecc93df07b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4w4c.yaml @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +quantize_and_pack_4w4c: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: texture3d + INPUT_STORAGE: texture3d + generate_variant_forall: + combination: + parameter_names: [OUTPUT_STORAGE, INPUT_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [texture3d, buffer] + - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] + DTYPE: + - VALUE: float + shader_variants: + - NAME: quantize_and_pack_4w4c_per_tensor diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.yaml deleted file mode 100644 index 37721db1ba8..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.yaml +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
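quantize_and_pack_4w4c above quantizes a block of floats with a single (inv_scale, zp) pair and packs it into 32-bit words via helpers from linear_int8_input_block.glslh, which are not shown in this diff. The Python below is a rough sketch of the per-tensor step; the int8 clamp range and packing order are assumptions.

# Rough sketch (Python): per-tensor quantization of a 4-element group followed
# by byte packing. Clamp range and byte order are assumptions.
def quantize(v, inv_scale, zp, qmin=-128, qmax=127):
    return max(qmin, min(qmax, round(v * inv_scale) + zp))

def pack4(q4):
    word = 0
    for i, q in enumerate(q4):
        word |= (q & 0xFF) << (8 * i)
    return word

vals = [0.1, -0.25, 1.3, 0.0]
q = [quantize(v, inv_scale=64.0, zp=0) for v in vals]
print(q, hex(pack4(q)))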
- -quantize_and_pack_linear_input: - parameter_names_with_default_values: - DTYPE: float - OUTPUT_STORAGE: texture3d - INPUT_STORAGE: texture3d - STORAGE: texture3d - GRANULARITY: per_tensor - generate_variant_forall: - DTYPE: - - VALUE: half - - VALUE: float - shader_variants: - - NAME: quantize_and_pack_linear_input_per_tensor_texture3d_texture3d - - NAME: quantize_and_pack_linear_input_per_tensor_buffer_texture3d - OUTPUT_STORAGE: buffer - - NAME: quantize_and_pack_linear_input_per_tensor_buffer_buffer - OUTPUT_STORAGE: buffer - INPUT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input_with_sums.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input_with_sums.yaml deleted file mode 100644 index 3fc66db2718..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input_with_sums.yaml +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -quantize_and_pack_linear_input_with_sums: - parameter_names_with_default_values: - DTYPE: float - OUTPUT_STORAGE: buffer - INPUT_STORAGE: texture3d - NUM_GROUPS_PER_WG: 2 - NUM_WORKERS_PER_GROUP: 32 - generate_variant_forall: - DTYPE: - - VALUE: half - - VALUE: float - shader_variants: - - NAME: quantize_and_pack_linear_input_with_sums_o2w32_buffer_texture3d - - NAME: quantize_and_pack_linear_input_with_sums_o2w32_buffer_buffer - OUTPUT_STORAGE: buffer - INPUT_STORAGE: buffer - - NAME: quantize_and_pack_linear_input_with_sums_o4w16_buffer_texture3d - NUM_GROUPS_PER_WG: 4 - NUM_WORKERS_PER_GROUP: 16 - - NAME: quantize_and_pack_linear_input_with_sums_o4w16_buffer_buffer - NUM_GROUPS_PER_WG: 4 - NUM_WORKERS_PER_GROUP: 16 - OUTPUT_STORAGE: buffer - INPUT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl deleted file mode 100644 index 7bf3a932c6c..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl +++ /dev/null @@ -1,257 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define IN_T ${buffer_scalar_type(IN_DTYPE)} -#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} -#define SCALE_T ${buffer_scalar_type(SCALE_DTYPE)} -#define ZP_T ${buffer_scalar_type(ZP_DTYPE)} - -#define ${MODE} - -${define_active_storage_type("buffer")} -${define_required_extensions(IN_DTYPE)} -${define_required_extensions(OUT_DTYPE)} -${define_required_extensions(SCALE_DTYPE)} -${define_required_extensions(ZP_DTYPE)} - -layout(std430) buffer; - -#include "indexing_utils.h" - -${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")} -${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")} - -$if MODE == "per_tensor": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int quant_min; - int quant_max; - }; -$if MODE == "per_token": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int num_tokens; - int quant_min; - int quant_max; - }; -$if MODE == "per_channel": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int axis; - int num_channels; - int quant_min; - int quant_max; - }; -$if MODE == "block_wise": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - ivec4 blockSize; // bW, bH, bC, bN - ivec4 numBlocks; // tW/bW, tH/bH, tC/bC, tN/bN - ivec4 blockStride; // pre-computed linear strides for the block grid - int quant_min; - int quant_max; - }; - -${layout_declare_ubo(B, "int", "out_numel")} -${layout_declare_ubo(B, "ivec4", "t_in_sizes")} -${layout_declare_ubo(B, "ivec4", "t_in_strides")} -${layout_declare_ubo(B, "ivec4", "t_out_sizes")} -${layout_declare_ubo(B, "ivec4", "t_out_strides")} - -${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} -${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} - -#include "quantize.glslh" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); -const lowp ivec4 in_dim_order = unhash_dim_order(in_layout); - -/* - Quantization Shader (Buffer Storage) - This shader converts floating-point tensor values to n-bit integer representations - using pre-computed quantization parameters (scale and zero_point). The quantization - maps floating-point values to a discrete integer range while preserving the original - data distribution as much as possible. - - Important Considerations: - (+) All input tensors are assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) - (+) The axis map layout is assumed to be a standard layout for scales and zero_points - (++) The scale and zero_point tensors must be implemented as buffers - - Workgroup Configuration: - - quantize_per_tensor - This mode applies uniform quantization across the entire tensor using a single scale - and zero_point value. 
- - (*) global_wg_size: default - (*) local_wg_size: default - - - quantize_per_token - This mode applies quantization individually to each token (or element) in the input, - using separate scale and zero_point values for each token. For instance if we have - a tensor of shape [B, S, H] then we have B*S tokens (and s+zp pairs) of H elements each. - - (*) global_wg_size: default - (*) local_wg_size: default - - - quantize_per_channel - This mode applies quantization separately to each channel of the input tensor, using - distinct scale and zero_point values for each channel. For example, if the tensor shape - is [B, C, H, W] and axis = 1, quantization parameters are computed per channel C, allowing - each channel to be quantized independently. - - (*) global_wg_size: default - (*) local_wg_size: default - - - quantize_block_wise - This mode applies quantization in blocks or groups of elements, allowing different scale - and zero_point values for each block. It is equivalent to quantize_affine, where quantization - parameters are affine transformations applied per block. For example, if the tensor shape - is [6, 9, 4] and blockSize = [3, 3, 2], then we have 12 blocks each with 18 elements. - - (*) global_wg_size: default - (*) local_wg_size: default - - Quantization Formula: - qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max). -*/ - -#ifdef per_tensor - -void quantize_per_tensor() { - const int out_bufi = int(gl_GlobalInvocationID.x); - - if (out_bufi >= out_numel) { - return; - } - - const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); - const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); - - IN_T value = t_in[in_bufi]; - OUT_T qvalue = quantize_val(value, float(t_scale[0]), int(t_zero_point[0])); - - t_out[out_bufi] = qvalue; -} - -#elif defined(per_token) - -void quantize_per_token() { - const int out_bufi = int(gl_GlobalInvocationID.x); - - if (out_bufi >= out_numel) { - return; - } - - const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); - const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); - - IN_T value = t_in[in_bufi]; - - int token_idx = 0; - - if (t_out_sizes.w > 1) { - // 4D tensor - token_idx = out_tidx.w * (t_out_sizes.z * t_out_sizes.y) + out_tidx.z * t_out_sizes.y + out_tidx.y; - } else if (t_out_sizes.z > 1) { - // 3D tensor - token_idx = out_tidx.z * t_out_sizes.y + out_tidx.y; - } else if (t_out_sizes.y > 1) { - // 2D tensor - token_idx = out_tidx.y; - } - // For 1D tensor, token_idx remains 0 - - token_idx = min(token_idx, num_tokens - 1); - - OUT_T qvalue = quantize_val(value, float(t_scale[token_idx]), int(t_zero_point[token_idx])); - - t_out[out_bufi] = qvalue; -} - -#elif defined(per_channel) - -void quantize_per_channel() { - const int out_bufi = int(gl_GlobalInvocationID.x); - - if (out_bufi >= out_numel) { - return; - } - - const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); - const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); - - IN_T value = t_in[in_bufi]; - - // Calculate channel index based on the quantization axis (already converted to WHCN) - // The axis parameter is now in WHCN coordinate system: - // axis 0 -> W dimension (tidx.x) - // axis 1 -> H dimension (tidx.y) - // axis 2 -> C dimension (tidx.z) - // axis 3 -> N dimension (tidx.w) - int channel_idx = 0; - - if (axis == 0) { - channel_idx = out_tidx.x; - } else if (axis == 1) { - channel_idx = out_tidx.y; - } else if (axis == 2) { - channel_idx = out_tidx.z; - } else if (axis == 3) { - 
channel_idx = out_tidx.w; - } - - channel_idx = min(channel_idx, num_channels - 1); - - OUT_T qvalue = quantize_val(value, float(t_scale[channel_idx]), int(t_zero_point[channel_idx])); - - t_out[out_bufi] = qvalue; -} - -#else // block_wise - -void quantize_block_wise() { - const int out_bufi = int(gl_GlobalInvocationID.x); - - if (out_bufi >= out_numel) { - return; - } - - const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); - const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); - - IN_T value = t_in[in_bufi]; - - const ivec4 bcoord = out_tidx / blockSize; - - const int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; - - const OUT_T qvalue = quantize_val(value, float(t_scale[block_id]), int(t_zero_point[block_id])); - - t_out[out_bufi] = qvalue; -} - -#endif - -void main() { - quantize_${MODE}(); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml deleted file mode 100644 index fb5853ecd20..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml +++ /dev/null @@ -1,31 +0,0 @@ -quantize_buffer: - parameter_names_with_default_values: - IN_DTYPE: float - OUT_DTYPE: int32 - SCALE_DTYPE: float - ZP_DTYPE: int32 - MODE: per_tensor - generate_variant_forall: - IN_DTYPE: - - VALUE: half - - VALUE: float - - VALUE: double - OUT_DTYPE: - - VALUE: uint8 - - VALUE: int8 - - VALUE: int32 - SCALE_DTYPE: - - VALUE: float - ZP_DTYPE: - - VALUE: int8 - - VALUE: int32 - - VALUE: float - shader_variants: - - NAME: quantize_per_tensor_buffer - MODE: per_tensor - - NAME: quantize_per_token_buffer - MODE: per_token - - NAME: quantize_per_channel_buffer - MODE: per_channel - - NAME: quantize_block_wise_buffer - MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl deleted file mode 100644 index 12e5769f50d..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl +++ /dev/null @@ -1,312 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
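For reference, the buffer quantize shader removed above applied the affine formula qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max), and in block_wise mode looked the (scale, zero_point) pair up by integer-dividing the tensor index by the block size and flattening with precomputed block strides. A small Python restatement of that lookup:

# Python restatement of the block-wise parameter lookup the removed
# quantize_buffer shader performed.
def quantize_val(value, scale, zero_point, quant_min, quant_max):
    return max(quant_min, min(quant_max, round(value / scale) + zero_point))

def block_id(tidx, block_size, block_stride):
    return sum((t // b) * s for t, b, s in zip(tidx, block_size, block_stride))

# WHCN sizes [6, 9, 4, 1] with blockSize [3, 3, 2, 1]:
# numBlocks = [2, 3, 2, 1], blockStride = [1, 2, 6, 12].
print(block_id((4, 7, 3, 0), (3, 3, 2, 1), (1, 2, 6, 12)))  # block (1, 2, 1, 0) -> 11
print(quantize_val(0.37, 0.05, 0, -128, 127))               # 7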
- */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define IN_T ${buffer_scalar_type(IN_DTYPE)} -#define FVEC4_T ${texel_load_type(IN_DTYPE, "texture3d")} - -#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} -#define IVEC4_T ${texel_load_type(OUT_DTYPE, "texture3d")} -#define SCALE_T ${buffer_scalar_type(SCALE_DTYPE)} -#define ZP_T ${buffer_scalar_type(ZP_DTYPE)} - -#define ${MODE} - -${define_active_storage_type("texture3d")} -${define_required_extensions(IN_DTYPE)} -${define_required_extensions(OUT_DTYPE)} -${define_required_extensions(SCALE_DTYPE)} -${define_required_extensions(ZP_DTYPE)} - -#extension GL_EXT_control_flow_attributes : require - -layout(std430) buffer; - -#include "indexing_utils.h" - -${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")} -${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} - -$if MODE == "per_tensor": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int quant_min; - int quant_max; - }; -$if MODE == "per_token": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int num_tokens; - int quant_min; - int quant_max; - }; -$if MODE == "per_channel": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int axis; - int num_channels; - int quant_min; - int quant_max; - }; -$if MODE == "block_wise": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict BlockPC { - ivec4 blockSize; // WHCN - ivec4 numBlocks; // (#W,#H,#C,#N) - ivec4 blockStride; // {1, #W, #W * #H, #W * #H * #C} - int quant_min; - int quant_max; - }; - -${layout_declare_ubo(B, "ivec3", "t_in_limits")} -${layout_declare_ubo(B, "ivec3", "t_out_limits")} - -${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} -${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} - -#include "quantize.glslh" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -/* - Quantization Shader (Texture Storage) - This shader converts floating-point tensor values to n-bit integer representations - using pre-computed quantization parameters (scale and zero_point). The quantization - maps floating-point values to a discrete integer range while preserving the original - data distribution as much as possible. - - Important Considerations: - (+) All input tensors are assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) - (+) The axis map layout is assumed to be a standard layout for scales and zero_points - (++) The scale and zero_point tensors must be implemented as buffers - - Workgroup Configuration: - - quantize_per_tensor - This mode applies uniform quantization across the entire tensor using a single scale - and zero_point value. - - (*) global_wg_size: default - (*) local_wg_size: default - - - quantize_per_token - This mode applies quantization individually to each token (or element) in the input, - using separate scale and zero_point values for each token. For instance if we have - a tensor of shape [B, S, H] then we have B*S tokens (and s+zp pairs) of H elements each. 
- - (*) global_wg_size: default - (*) local_wg_size: default - - - quantize_per_channel - This mode applies quantization separately to each channel of the input tensor, using - distinct scale and zero_point values for each channel. For example, if the tensor shape - is [B, C, H, W] and axis = 1, quantization parameters are computed per channel C, allowing - each channel to be quantized independently. - - (*) global_wg_size: default - (*) local_wg_size: Default with special handling for batch dimension. When quantizing along - the batch axis, Z dimension is set to 1 to ensure correct workgroup dispatching. Otherwise, - uses standard workgroup size derived from global workgroup dimensions. - - - quantize_block_wise - This mode applies quantization in blocks or groups of elements, allowing different scale - and zero_point values for each block. It is equivalent to quantize_affine, where quantization - parameters are affine transformations applied per block. For example, if the tensor shape - is [6, 9, 4] and blockSize = [3, 3, 2], then we have 12 blocks each with 18 elements. - - (*) global_wg_size: default - (*) local_wg_size: Default with special handling for batch dimension. When quantizing along - the batch axis, Z dimension is set to 1 to ensure correct workgroup dispatching. Otherwise, - uses standard workgroup size derived from global workgroup dimensions. - - Quantization Formula: - qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max). -*/ - -#ifdef per_tensor - -void quantize_per_tensor() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, t_in_limits))) { - return; - } - - FVEC4_T intex = load_texel(t_in, pos); - IVEC4_T outtex; - - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T value = IN_T(intex[i]); - OUT_T qvalue = quantize_val(value, float(t_scale[0]), int(t_zero_point[0])); - outtex[i] = qvalue; - } - write_texel(t_out, pos, outtex); -} - -#elif defined(per_token) - -void quantize_per_token() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, t_in_limits))) { - return; - } - - FVEC4_T intex = load_texel(t_in, pos); - - int token_idx = 0; - ivec3 dims = t_in_limits; - - if (dims.z > 1) { - // 3D tensor - token_idx = pos.z * dims.y + pos.y; - } else if (dims.y > 1) { - // 2D tensor - token_idx = pos.y; - } - // For 1D tensor, token_idx remains 0 - - token_idx = min(token_idx, num_tokens - 1); - - // Scale and zero_point are prepacked as buffers, so direct access - float scale_val = float(t_scale[token_idx]); - int zero_point_val = int(t_zero_point[token_idx]); - - IVEC4_T outtex; - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T value = IN_T(intex[i]); - OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); - outtex[i] = qvalue; - } - - write_texel(t_out, pos, outtex); -} - -#elif defined(per_channel) - -void quantize_per_channel() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, t_in_limits))) { - return; - } - - FVEC4_T intex = load_texel(t_in, pos); - IVEC4_T outtex; - - // Calculate channel index based on the quantization axis (already converted to WHCN) - // The axis parameter is now in WHCN coordinate system: - // axis 0 -> W dimension (pos.x for texture, but width-packed so pos.x * 4 + component) - // axis 1 -> H dimension (pos.y) - // axis 2 -> C dimension (pos.z / C), but for 4D tensors this includes batch-channel folding - // axis 3 -> N dimension (pos.z / N), but for 4D tensors this includes batch-channel folding - - if (axis == 0) { - 
// Width dimension - each texel component has different channel index - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T value = IN_T(intex[i]); - int channel_idx = pos.x * 4 + i; - channel_idx = min(channel_idx, num_channels - 1); - - float scale_val = float(t_scale[channel_idx]); - int zero_point_val = int(t_zero_point[channel_idx]); - OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); - outtex[i] = qvalue; - } - } else if (axis == 1) { - // Height dimension - all texel components use same channel index - int channel_idx = pos.y; - channel_idx = min(channel_idx, num_channels - 1); - float scale_val = float(t_scale[channel_idx]); - int zero_point_val = int(t_zero_point[channel_idx]); - - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T value = IN_T(intex[i]); - OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); - outtex[i] = qvalue; - } - } else if (axis == 2) { - // Channel dimension - for 4D tensors, need to account for batch-channel folding - // The Z coordinate contains folded batch*channel information - // We need to extract the actual channel index from the folded dimension - int folded_idx = pos.z; - int channel_idx = folded_idx % num_channels; - - float scale_val = float(t_scale[channel_idx]); - int zero_point_val = int(t_zero_point[channel_idx]); - - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T value = IN_T(intex[i]); - OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); - outtex[i] = qvalue; - } - } else if (axis == 3) { - // Batch dimension - for 4D tensors, need to account for batch-channel folding - // The Z coordinate contains folded batch*channel information - // We need to extract the actual batch index from the folded dimension - int folded_idx = pos.z; - int batch_idx = folded_idx / num_channels; - - float scale_val = float(t_scale[batch_idx]); - int zero_point_val = int(t_zero_point[batch_idx]); - - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T value = IN_T(intex[i]); - OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); - outtex[i] = qvalue; - } - } - - write_texel(t_out, pos, outtex); -} - -#else // block_wise - -void quantize_block_wise() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, t_in_limits))) - return; - - FVEC4_T intex = load_texel(t_in, pos); - IVEC4_T outtex; - - ivec4 base_tidx = ivec4(pos.x * 4, pos.y, pos.z, 0); - int foldedZ = pos.z; - - int C_total = numBlocks.z * blockSize.z; - - [[unroll]] for (int i = 0; i < 4; ++i) { - ivec4 tidx = ivec4(base_tidx.x + i, base_tidx.y, (foldedZ % C_total), (foldedZ / C_total)); - - ivec4 bcoord = tidx / blockSize; - int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; - - IN_T value = IN_T(intex[i]); - OUT_T qvalue = quantize_val(value, float(t_scale[block_id]), int(t_zero_point[block_id])); - outtex[i] = qvalue; - } - - write_texel(t_out, pos, outtex); -} - -#endif - -void main() { - quantize_${MODE}(); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml deleted file mode 100644 index 03d418ff2f7..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml +++ /dev/null @@ -1,31 +0,0 @@ -quantize_texture: - parameter_names_with_default_values: - IN_DTYPE: float - OUT_DTYPE: int32 - SCALE_DTYPE: float - ZP_DTYPE: int32 - MODE: per_tensor - generate_variant_forall: - IN_DTYPE: - - VALUE: half - - VALUE: float - - VALUE: double - OUT_DTYPE: - - VALUE: uint8 - 
- VALUE: int8 - - VALUE: int32 - SCALE_DTYPE: - - VALUE: float - ZP_DTYPE: - - VALUE: int8 - - VALUE: int32 - - VALUE: float - shader_variants: - - NAME: quantize_per_tensor_texture3d - MODE: per_tensor - - NAME: quantize_per_token_texture3d - MODE: per_token - - NAME: quantize_per_channel_texture3d - MODE: per_channel - - NAME: quantize_block_wise_texture3d - MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/reduce_op_defs.glslh b/backends/vulkan/runtime/graph/ops/glsl/reduce_op_defs.glslh new file mode 100644 index 00000000000..e5f61da7586 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/reduce_op_defs.glslh @@ -0,0 +1,96 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef REDUCE_OP_DEFS_GLSLH +#define REDUCE_OP_DEFS_GLSLH + +struct Accum { + ACCUM_T val; + uint idx; + uint count; +}; + +void init_accum(out Accum accum, T val, uint idx) { + accum.val = ACCUM_T(val); + accum.idx = idx; + accum.count = 1; +} + +void init_accum_zero(out Accum accum) { + accum.val = T(0); + accum.idx = 0; + accum.count = 0; +} + +// Sum / Mean + +void update_accum_sum(inout Accum accum, T val, uint idx) { + accum.val += ACCUM_T(val); + accum.count += 1; +} + +void merge_accum_sum(inout Accum accum, const Accum other) { + accum.val += other.val; + accum.count += other.count; +} + +void postprocess_accum_mean(inout Accum accum) { + accum.val /= T(accum.count); +} + +// Amax (maximum value) + +void update_accum_amax(inout Accum accum, T in_val, uint idx) { + ACCUM_T val = ACCUM_T(in_val); + if (val > accum.val) { + accum.val = val; + accum.idx = idx; + } + // For equivalence, select the lower index + if (val == accum.val && idx < accum.idx) { + accum.idx = idx; + } +} + +void merge_accum_amax(inout Accum accum, const Accum other) { + if (other.val > accum.val) { + accum.val = other.val; + accum.idx = other.idx; + } + // For equivalence, select the lower index + if (other.val == accum.val && other.idx < accum.idx) { + accum.idx = other.idx; + } +} + +// Amin (minimum value) + +void update_accum_amin(inout Accum accum, T in_val, uint idx) { + ACCUM_T val = ACCUM_T(in_val); + if (val < accum.val) { + accum.val = val; + accum.idx = idx; + } + // For equivalence, select the lower index + if (val == accum.val && idx < accum.idx) { + accum.idx = idx; + } +} + +void merge_accum_amin(inout Accum accum, const Accum other) { + if (other.count > 0 && (accum.count == 0 || other.val < accum.val)) { + accum.val = other.val; + accum.idx = other.idx; + } + // For equivalence, select the lower index + if (other.val == accum.val && other.idx < accum.idx) { + accum.idx = other.idx; + } +} + +#endif // REDUCE_OP_DEFS_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.glsl new file mode 100644 index 00000000000..af5f5f661e7 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.glsl @@ -0,0 +1,122 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
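The new reduce_op_defs.glslh above defines a small accumulator "algebra": an init, a per-element update, a pairwise merge, and an optional postprocess, with amax/amin breaking ties toward the lower index. The Python mirror below shows the amax case; it is illustrative only and not part of the runtime.

# Illustrative Python mirror of the amax accumulator in reduce_op_defs.glslh:
# track (value, index), prefer the larger value and, on ties, the lower index.
from dataclasses import dataclass

@dataclass
class Accum:
    val: float
    idx: int

def update_amax(acc, val, idx):
    if val > acc.val:
        acc.val, acc.idx = val, idx
    # on ties, keep the lower index
    if val == acc.val and idx < acc.idx:
        acc.idx = idx

def merge_amax(acc, other):
    if other.val > acc.val:
        acc.val, acc.idx = other.val, other.idx
    if other.val == acc.val and other.idx < acc.idx:
        acc.idx = other.idx

# (The shader initializes from the first element; -inf is used here for brevity.)
acc = Accum(val=float("-inf"), idx=0)
for i, v in enumerate([1.0, 3.0, 3.0, 2.0]):
    update_amax(acc, v, i)
print(acc)  # Accum(val=3.0, idx=1)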
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define ACCUM_T ${accum_scalar_type(DTYPE)} +#define T ${texel_load_component_type(DTYPE, "buffer")} + +#define NUM_OUTPUTS_PER_WG 1 +#define NUM_WORKERS_PER_OUTPUT 64 + +${define_active_storage_type("buffer")} +${define_required_extensions(DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +#include "indexing.glslh" +#include "convert.glslh" +#include "reduce_op_defs.glslh" + +$if OUTPUT_IS_INDICES: + ${layout_declare_tensor(B, "w", "t_out", "int", "buffer")} +$else: + ${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")} + +${layout_declare_tensor(B, "r", "t_in", DTYPE, "buffer")} + +${layout_declare_ubo(B, "BufferMetadata", "outp")} +${layout_declare_ubo(B, "BufferMetadata", "inp")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +// Shared memory for cooperative reduction +shared Accum shared_values[NUM_OUTPUTS_PER_WG][NUM_WORKERS_PER_OUTPUT]; + +#define init_fn ${INIT_ACCUM_FN} +#define update_fn ${UPDATE_ACCUM_FN} +#define merge_fn ${MERGE_ACCUM_FN} + +$if POSTPROCESS_ACCUM_FN != "none": + #define postprocess_fn ${POSTPROCESS_ACCUM_FN} + +$if OOB_INIT_MODE == "zero": + #define OOB_INIT_MODE 0 +$else: + #define OOB_INIT_MODE 1 + +$if OUTPUT_IS_INDICES: + #define OUTPUT_IS_INDICES + +void main() { + const uint out_bufi = gl_GlobalInvocationID.y; + + if (out_of_bounds(out_bufi, outp)) { + return; + } + + // Local indices + const uint worker_id = gl_LocalInvocationID.x; + const uint output_id = gl_LocalInvocationID.y; + + const uint in_bufi_base = out_bufi * width(inp); + + Accum local_accum; + // Initialize accumulator with the first element being processed + if (worker_id < width(inp)) { + const uint in_bufi = in_bufi_base + worker_id; + init_fn(local_accum, t_in[in_bufi], worker_id); + } + // For out of bounds case, initialization depends on reduction op + else { +#if OOB_INIT_MODE == 0 + // Init with a zero value + init_accum_zero(local_accum); +#else + // Init with the first value (i.e. amin, amax) + init_fn(local_accum, t_in[in_bufi_base], 0); +#endif + } + + for (uint x = worker_id + NUM_WORKERS_PER_OUTPUT; x < width(inp); + x += NUM_WORKERS_PER_OUTPUT) { + update_fn(local_accum, t_in[in_bufi_base + x], x); + } + + shared_values[output_id][worker_id] = local_accum; + + memoryBarrierShared(); + barrier(); + + for (int i = NUM_WORKERS_PER_OUTPUT / 2; i > 0; i >>= 1) { + if (worker_id < i) { + merge_fn( + shared_values[output_id][worker_id], + shared_values[output_id][worker_id + i]); + } + memoryBarrierShared(); + barrier(); + } + + if (worker_id == 0) { + local_accum = shared_values[output_id][0]; +#ifdef postprocess_fn + postprocess_fn(local_accum); +#endif + +#ifdef OUTPUT_IS_INDICES + t_out[out_bufi] = int(0); // int(local_accum.idx); +#else + t_out[out_bufi] = convert_to_T(local_accum.val); +#endif + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.yaml new file mode 100644 index 00000000000..e5a94165b96 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.yaml @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
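reduce_per_row_buffer above reduces one output element per row with a fixed pool of 64 workers: each worker accumulates a strided slice of the row, the partials go to shared memory, and a halving tree merge combines them. A hedged Python simulation of that structure for the sum case:

# Hedged Python simulation of the cooperative row reduction in
# reduce_per_row_buffer: strided per-worker accumulation followed by a
# shared-memory style tree merge. Sum only; amax/argmax follow the same shape.
def reduce_row_sum(row, num_workers=64):
    # Phase 1: each worker accumulates row[worker_id::num_workers].
    partials = [sum(row[w::num_workers]) for w in range(num_workers)]
    # Phase 2: tree merge, halving the number of active workers each step.
    stride = num_workers // 2
    while stride > 0:
        for w in range(stride):
            partials[w] += partials[w + stride]
        stride //= 2
    return partials[0]

row = list(range(1, 101))
assert reduce_row_sum(row) == sum(row)  # 5050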
+ +reduce_per_row_buffer: + parameter_names_with_default_values: + DTYPE: float + INIT_ACCUM_FN: init_accum + UPDATE_ACCUM_FN: update_accum_sum + MERGE_ACCUM_FN: merge_accum_sum + POSTPROCESS_ACCUM_FN: none + OOB_INIT_MODE: zero + OUTPUT_IS_INDICES: false + generate_variant_forall: + DTYPE: + - VALUE: float + - VALUE: half + - VALUE: int32 + shader_variants: + - NAME: sum_per_row_buffer + - NAME: mean_per_row_buffer + POSTPROCESS_ACCUM_FN: postprocess_accum_mean + - NAME: amax_per_row_buffer + UPDATE_ACCUM_FN: update_accum_amax + MERGE_ACCUM_FN: merge_accum_amax + OOB_INIT_MODE: first_element + - NAME: amin_per_row_buffer + UPDATE_ACCUM_FN: update_accum_amin + MERGE_ACCUM_FN: merge_accum_amin + OOB_INIT_MODE: first_element + - NAME: argmax_per_row_buffer + UPDATE_ACCUM_FN: update_accum_amax + MERGE_ACCUM_FN: merge_accum_amax + OOB_INIT_MODE: first_element + OUTPUT_IS_INDICES: true + - NAME: argmin_per_row_buffer + UPDATE_ACCUM_FN: update_accum_amin + MERGE_ACCUM_FN: merge_accum_amin + OOB_INIT_MODE: first_element + OUTPUT_IS_INDICES: true diff --git a/backends/vulkan/runtime/graph/ops/glsl/repeat_channel.yaml b/backends/vulkan/runtime/graph/ops/glsl/repeat_channel.yaml index 4147e82965a..c48237f7568 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/repeat_channel.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/repeat_channel.yaml @@ -6,5 +6,7 @@ repeat_channel: DTYPE: - VALUE: half - VALUE: float + - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: repeat_channel diff --git a/backends/vulkan/runtime/graph/ops/glsl/repeat_interleave.yaml b/backends/vulkan/runtime/graph/ops/glsl/repeat_interleave.yaml index 5c284a580c9..f56172dc7f0 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/repeat_interleave.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/repeat_interleave.yaml @@ -6,5 +6,7 @@ repeat_interleave: DTYPE: - VALUE: half - VALUE: float + - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: repeat_interleave diff --git a/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl index 30375728921..155eda467c4 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl @@ -13,23 +13,29 @@ #define VEC4_T ${texel_load_type(DTYPE, STORAGE)} ${define_required_extensions(DTYPE)} +${define_active_storage_type(STORAGE)} layout(std430) buffer; -${layout_declare_tensor(B, "w", "xqout", DTYPE, STORAGE)} -${layout_declare_tensor(B, "w", "xkout", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "xq", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "xk", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "freqs_cos", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "freqs_sin", DTYPE, STORAGE)} -${layout_declare_ubo(B, "ivec3", "xqout_limits")} -${layout_declare_ubo(B, "ivec3", "xkout_limits")} +#include "indexing.glslh" -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +${layout_declare_tensor(B, "w", "t_xqout", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "w", "t_xkout", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_xq", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_xk", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_freqs_cos", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_freqs_sin", DTYPE, STORAGE, is_scalar_array=False)} -layout(constant_id = 3) const int packed_dim = 0; 
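The rotary_embedding shader reworked in the hunk below rotates consecutive (even, odd) element pairs of the query and key projections by per-position (cos, sin) frequencies. The Python below sketches that math, assuming the standard complex-rotation form described in the shader's own comment; it is not a transcription of the shader.

# Sketch (Python) of the rotary embedding math, assuming the standard
# complex-rotation form on consecutive (even, odd) element pairs.
import math

def apply_rope(x, freqs_cos, freqs_sin):
    out = []
    for i in range(0, len(x), 2):
        c, s = freqs_cos[i // 2], freqs_sin[i // 2]
        x_r, x_i = x[i], x[i + 1]
        out += [x_r * c - x_i * s, x_r * s + x_i * c]
    return out

head_dim = 4
theta = [1.0 / (10000 ** (2 * j / head_dim)) for j in range(head_dim // 2)]
pos = 3
cos = [math.cos(pos * t) for t in theta]
sin = [math.sin(pos * t) for t in theta]
print(apply_rope([1.0, 0.0, 0.0, 1.0], cos, sin))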
+$if STORAGE == "buffer": + ${layout_declare_ubo(B, "BufferMetadata", "xqout")} + ${layout_declare_ubo(B, "BufferMetadata", "xkout")} + ${layout_declare_ubo(B, "BufferMetadata", "freqs_cos")} +$else: + ${layout_declare_ubo(B, "TextureMetadata", "xqout")} + ${layout_declare_ubo(B, "TextureMetadata", "xkout")} + ${layout_declare_ubo(B, "TextureMetadata", "freqs_cos")} -#include "indexing_utils.h" +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; /* * This shader computes rotary positional embeddings which are used in the Llama @@ -39,7 +45,7 @@ layout(constant_id = 3) const int packed_dim = 0; * 1. xq (batch_size, sequence_len, num_heads, head_dim) * 2. xk (batch_size, sequence_len, num_kv_heads, head_dim) * 3. freqs_cos (sequence_len, head_dim / 2) - * 4. freqs_cos (sequence_len, head_dim / 2) + * 4. freqs_sin (sequence_len, head_dim / 2) * * Two output tensors are produced, with the same shapes as xq and xk * respectively. @@ -66,23 +72,43 @@ void main() { // Each thread will write to two output locations to maximize data re-use. // One texel loaded from the freqs_cos/freqs_sin tensors can be used to // calculate two output texels. - const ivec3 x_pos_1 = ivec3( - gl_GlobalInvocationID.x * 2, gl_GlobalInvocationID.yz); - const ivec3 x_pos_2 = ivec3(x_pos_1.x + 1, x_pos_1.yz); + TensorIndex4D out_tidx_1 = zero_tensor4d_idx(); + out_tidx_1.data.x = int(gl_GlobalInvocationID.x) * 8; + out_tidx_1.data.yz = ivec2(gl_GlobalInvocationID.yz); + + TensorIndex4D out_tidx_2 = out_tidx_1; + out_tidx_2.data.x += 4; - if (any(greaterThanEqual(x_pos_2, xqout_limits))) { + if (out_of_bounds(out_tidx_2, xqout)) { return; } - const ivec3 freqs_pos = ivec3(gl_GlobalInvocationID.xz, 0); + TensorIndex4D freqs_tidx = zero_tensor4d_idx(); + freqs_tidx.data.x = int(gl_GlobalInvocationID.x) * 4; + freqs_tidx.data.y = out_tidx_1.data.z; - VEC4_T cos_tex = load_texel(freqs_cos, freqs_pos); - VEC4_T sin_tex = load_texel(freqs_sin, freqs_pos); +#ifdef USING_BUFFER + const uint freqs_texel_bufi = div_4(tensor4d_idx_to_linear_idx(freqs_cos, freqs_tidx)); + VEC4_T cos_tex = t_freqs_cos[freqs_texel_bufi]; + VEC4_T sin_tex = t_freqs_sin[freqs_texel_bufi]; - // Compute xqout + uint x_texel_bufi_1 = div_4(tensor4d_idx_to_linear_idx(xqout, out_tidx_1)); + uint x_texel_bufi_2 = div_4(tensor4d_idx_to_linear_idx(xqout, out_tidx_2)); + VEC4_T x_tex_1 = t_xq[x_texel_bufi_1]; + VEC4_T x_tex_2 = t_xq[x_texel_bufi_2]; + +#else // USING_TEXTURE + const ivec3 freqs_pos = tensor4d_idx_to_texel_pos_simple(freqs_cos, freqs_tidx); + VEC4_T cos_tex = texelFetch(t_freqs_cos, freqs_pos, 0); + VEC4_T sin_tex = texelFetch(t_freqs_sin, freqs_pos, 0); - VEC4_T x_tex_1 = load_texel(xq, x_pos_1); - VEC4_T x_tex_2 = load_texel(xq, x_pos_2); + const ivec3 x_pos_1 = tensor4d_idx_to_texel_pos_simple(xqout, out_tidx_1); + const ivec3 x_pos_2 = tensor4d_idx_to_texel_pos_simple(xqout, out_tidx_2); + VEC4_T x_tex_1 = texelFetch(t_xq, x_pos_1, 0); + VEC4_T x_tex_2 = texelFetch(t_xq, x_pos_2, 0); +#endif + + // Compute xqout // Separate into even and odd elements VEC4_T x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz); @@ -94,20 +120,34 @@ void main() { VEC4_T xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y); VEC4_T xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w); - write_texel(xqout, x_pos_1, xout_tex_1); - write_texel(xqout, x_pos_2, xout_tex_2); +#ifdef USING_BUFFER + t_xqout[x_texel_bufi_1] = xout_tex_1; + t_xqout[x_texel_bufi_2] = xout_tex_2; +#else // USING_TEXTURE + imageStore(t_xqout, x_pos_1, xout_tex_1); + 
imageStore(t_xqout, x_pos_2, xout_tex_2); +#endif // n_heads will be greater than or equal to n_kv_heads, therefore xq and xqout // may have a larger height dim than xk and xkout. Only compute xkout if this // invocation is still within bounds. - if (any(greaterThanEqual(x_pos_2, xkout_limits))) { + if (out_of_bounds(out_tidx_2, xkout)) { return; } // Compute xkout - x_tex_1 = load_texel(xk, x_pos_1); - x_tex_2 = load_texel(xk, x_pos_2); +#ifdef USING_BUFFER + x_texel_bufi_1 = div_4(tensor4d_idx_to_linear_idx(xkout, out_tidx_1)); + x_texel_bufi_2 = div_4(tensor4d_idx_to_linear_idx(xkout, out_tidx_2)); + + x_tex_1 = t_xk[x_texel_bufi_1]; + x_tex_2 = t_xk[x_texel_bufi_2]; + +#else // USING_TEXTURE + x_tex_1 = texelFetch(t_xk, x_pos_1, 0); + x_tex_2 = texelFetch(t_xk, x_pos_2, 0); +#endif x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz); x_i = VEC4_T(x_tex_1.yw, x_tex_2.yw); @@ -118,6 +158,11 @@ void main() { xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y); xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w); - write_texel(xkout, x_pos_1, xout_tex_1); - write_texel(xkout, x_pos_2, xout_tex_2); +#ifdef USING_BUFFER + t_xkout[x_texel_bufi_1] = xout_tex_1; + t_xkout[x_texel_bufi_2] = xout_tex_2; +#else // USING_TEXTURE + imageStore(t_xkout, x_pos_1, xout_tex_1); + imageStore(t_xkout, x_pos_2, xout_tex_2); +#endif } diff --git a/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml index a81fd564d10..ba8aa400958 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml @@ -3,6 +3,9 @@ rotary_embedding: DTYPE: float STORAGE: texture3d generate_variant_forall: + STORAGE: + - VALUE: texture3d + - VALUE: buffer DTYPE: - VALUE: half - VALUE: float diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl index 1dff0017f30..652453bbec7 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl @@ -16,6 +16,8 @@ ${define_active_storage_type(STORAGE)} +${define_required_extensions(DTYPE)} + #extension GL_EXT_control_flow_attributes : require layout(std430) buffer; @@ -74,6 +76,7 @@ void main() { const int Q_H = q_projected_sizes.y; // sequence length const int S = q_projected_sizes.z; + const int S_aligned = align_up_4(S); // manually determine size of the context_len dim of the attention weight. // The "actual" tensor sizes may have been aligned to a multiple of 4 to allow // memory loads to be aligned to texel boundaries. @@ -85,7 +88,7 @@ void main() { } // Initialize thread-local min/max - T local_exp_sum = 0; + T local_exp_sum = T(0); const int context_len_aligned_down = context_len - mod_4(context_len); const int C4_limit = div_4(context_len_aligned_down); @@ -94,7 +97,7 @@ void main() { // number of threads in the work group. 
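The softmax changes in this hunk normalize each attention-weight row over a sequence length that has been aligned up to a multiple of 4, so the shader sums exp() only over the valid context_len entries and writes zeros into the padded tail. A small Python sketch of that behavior:

# Python sketch of the padded row softmax performed by
# sdpa_attn_weights_softmax: exponentials are summed over the valid entries
# only, and padded positions beyond context_len are written as zero.
import math

def padded_softmax(row, context_len):
    denom = sum(math.exp(v) for v in row[:context_len])
    return [math.exp(v) / denom if i < context_len else 0.0
            for i, v in enumerate(row)]

row = [0.5, 1.0, -0.25, 0.0]          # length aligned up to 4
out = padded_softmax(row, context_len=3)
print(out, sum(out))                  # padded tail is 0, valid part sums to 1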
for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) { VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S, Q_H); + c4, s, q_h, context_texel_len, S_aligned, Q_H); for (int comp = 0; comp < 4; comp++) { local_exp_sum += exp(in_texel[comp]); @@ -106,7 +109,7 @@ void main() { for (int c4 = C4_limit; c4 < context_texel_len; ++c4) { const int c_base = mul_4(c4); VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S, Q_H); + c4, s, q_h, context_texel_len, S_aligned, Q_H); [[unroll]] for (int comp = 0; comp < 4; comp++) { if (c_base + comp < context_len) { @@ -136,11 +139,11 @@ void main() { // Now go back through each element in the row and normalize for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) { VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S, Q_H); + c4, s, q_h, context_texel_len, S_aligned, Q_H); VEC4_T out_texel = exp(in_texel) / local_exp_sum; store_attn_weights_softmax_c4( - out_texel, c4, s, q_h, context_texel_len, S, Q_H); + out_texel, c4, s, q_h, context_texel_len, S_aligned, Q_H); } // First thread in the work group responsible for handling last texel if it // contains any padded elements @@ -148,7 +151,7 @@ void main() { for (int c4 = C4_limit; c4 < context_texel_len; ++c4) { const int c_base = mul_4(c4); VEC4_T in_texel = load_attn_weights_c4( - c4, s, q_h, context_texel_len, S, Q_H); + c4, s, q_h, context_texel_len, S_aligned, Q_H); // Ensure that padding elements are set to 0. VEC4_T out_texel = VEC4_T(0); @@ -158,7 +161,7 @@ void main() { } } store_attn_weights_softmax_c4( - out_texel, c4, s, q_h, context_texel_len, S, Q_H); + out_texel, c4, s, q_h, context_texel_len, S_aligned, Q_H); } } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.yaml index 8abf50399e0..66ec030680e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.yaml @@ -14,5 +14,6 @@ sdpa_attn_weights_softmax: - VALUE: buffer DTYPE: - VALUE: float + - VALUE: half shader_variants: - NAME: sdpa_attn_weights_softmax diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl index 2900d63666b..7dec6c1697f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl @@ -81,6 +81,7 @@ void main() { const int Q_H = q_projected_sizes.y; // sequence length const int S = q_projected_sizes.z; + const int S_aligned = align_up_4(S); // number of K/V heads const int KV_H = k_cache_sizes.y; @@ -118,55 +119,27 @@ void main() { } // Otherwise, need to actually compute output tile else { - const bool dont_check_bounds = (S - s) >= TILE_M && - (context_len - c) >= TILE_N; - - if (dont_check_bounds) { - for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) { - load_q_projected_tile_no_checks( - q_tile, - d4, - s, - q_h, - D4, - Q_H, - S); - - load_k_cache_tile_no_checks( - w_tile, - d4, - c, - kv_h, - D4, - context_len, - C, - KV_H); - - fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); - } - } else { - for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) { - load_q_projected_tile_with_checks( - q_tile, - d4, - s, - q_h, - D4, - Q_H, - S); - - load_k_cache_tile_with_checks( - w_tile, - d4, - c, - kv_h, 
- D4, - context_len, - C, - KV_H); - - fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); - } + for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) { + load_q_projected_tile_with_checks( + q_tile, + d4, + s, + q_h, + D4, + Q_H, + S); + + load_k_cache_tile_with_checks( + w_tile, + d4, + c, + kv_h, + D4, + context_len, + C, + KV_H); + + fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); } } @@ -205,7 +178,7 @@ void main() { s, q_h, context_texel_len, - S, + S_aligned, Q_H); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml index 6a4cffcc913..d5cadc36060 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml @@ -12,10 +12,14 @@ sdpa_compute_attn_weights_coop: TILE_K4: 1 TILE_N4: 1 generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, K_CACHE_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: float - VALUE: half shader_variants: - - NAME: sdpa_compute_attn_weights_coop_texture3d_texture3d - - NAME: sdpa_compute_attn_weights_coop_buffer_texture3d - IO_STORAGE: buffer + - NAME: sdpa_compute_attn_weights_coop diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl index 95c22d91b80..2892f74e05f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl @@ -93,6 +93,7 @@ void main() { const int Q_H = q_projected_sizes.y; // sequence length const int S = q_projected_sizes.z; + const int S_aligned = align_up_4(S); // number of K/V heads const int KV_H = k_cache_sizes.y; @@ -129,55 +130,28 @@ void main() { } // Otherwise, need to actually compute output tile else { - const bool dont_check_bounds = (S - s) >= TILE_M && - (context_len - c) >= TILE_N; - - if (dont_check_bounds) { - for (int d4 = 0; d4 < D4; d4++) { - load_q_projected_tile_no_checks( - q_tile, - d4, - s, - q_h, - D4, - Q_H, - S); - - load_k_cache_tile_no_checks( - w_tile, - d4, - c, - kv_h, - D4, - context_len, - C, - KV_H); - - fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); - } - } else { - for (int d4 = 0; d4 < D4; d4++) { - load_q_projected_tile_with_checks( - q_tile, - d4, - s, - q_h, - D4, - Q_H, - S); - - load_k_cache_tile_with_checks( - w_tile, - d4, - c, - kv_h, - D4, - context_len, - C, - KV_H); - - fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); - } + for (int d4 = 0; d4 < D4; d4++) { + load_q_projected_tile_with_checks( + q_tile, + d4, + s, + q_h, + D4, + Q_H, + S); + + load_k_cache_tile_with_checks( + w_tile, + d4, + c, + kv_h, + D4, + context_len, + C, + KV_H); + + + fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); } // Apply scale and mask @@ -196,6 +170,6 @@ void main() { s, q_h, context_texel_len, - S, + S_aligned, Q_H); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml index 6aadbbc379e..7fc016cf3c3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml @@ -13,10 +13,14 @@ 
sdpa_compute_attn_weights_tiled: TILE_K4: 1 TILE_N4: 1 generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, K_CACHE_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: float - VALUE: half shader_variants: - - NAME: sdpa_compute_attn_weights_tiled_texture3d_texture3d - - NAME: sdpa_compute_attn_weights_tiled_buffer_texture3d - IO_STORAGE: buffer + - NAME: sdpa_compute_attn_weights_tiled diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl index 5f408b7581d..cc60193cf18 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl @@ -81,6 +81,7 @@ void main() { const int Q_H = q_projected_sizes.y; // sequence length const int S = q_projected_sizes.z; + const int S_aligned = align_up_4(S); // number of K/V heads const int KV_H = v_cache_sizes.y; @@ -120,7 +121,7 @@ void main() { s, q_h, context_texel_len, - S, + S_aligned, Q_H); load_v_cache_tile_no_checks( @@ -146,7 +147,7 @@ void main() { s, q_h, context_texel_len, - S, + S_aligned, Q_H); load_v_cache_tile_with_checks( diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml index ccebf8f7c1c..33ec2f8b322 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml @@ -12,10 +12,14 @@ sdpa_compute_out_coop: TILE_K4: 1 TILE_N4: 1 generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, V_CACHE_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: float - VALUE: half shader_variants: - - NAME: sdpa_compute_out_coop_texture3d_texture3d - - NAME: sdpa_compute_out_coop_buffer_texture3d - IO_STORAGE: buffer + - NAME: sdpa_compute_out_coop diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl index 0063ebf9d38..385ad7a921e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl @@ -75,6 +75,7 @@ void main() { const int Q_H = q_projected_sizes.y; // sequence length const int S = q_projected_sizes.z; + const int S_aligned = align_up_4(S); // number of K/V heads const int KV_H = v_cache_sizes.y; @@ -113,7 +114,7 @@ void main() { s, q_h, context_texel_len, - S, + S_aligned, Q_H); load_v_cache_tile_no_checks( @@ -136,7 +137,7 @@ void main() { s, q_h, context_texel_len, - S, + S_aligned, Q_H); load_v_cache_tile_with_checks( diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml index 7fbce29e908..eac2c6f37dd 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml @@ -13,10 +13,14 @@ sdpa_compute_out_tiled: TILE_K4: 1 TILE_N4: 1 generate_variant_forall: + combination: + parameter_names: [IO_STORAGE, V_CACHE_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: float - 
VALUE: half shader_variants: - - NAME: sdpa_compute_out_tiled_texture3d_texture3d - - NAME: sdpa_compute_out_tiled_buffer_texture3d - IO_STORAGE: buffer + - NAME: sdpa_compute_out_tiled diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh index 03132db1348..1880397181d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh @@ -44,7 +44,6 @@ void load_k_cache_tile_no_checks( const int context_len, const int C, const int KV_H) { - bool should_print = d4_start == 0 && c_start == 0 && kv_h == 0; [[unroll]] for (int c = 0; c < TILE_N; ++c) { const int c4 = div_4(c); const int c4i = mod_4(c); diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl index 932696fff02..5f7e4c2719d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl @@ -5,6 +5,8 @@ #define IN_VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)} #define T ${buffer_scalar_type(DTYPE)} +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER $if INPUT_STORAGE == "buffer": #define INPUT_BUFFER @@ -78,13 +80,17 @@ void main() { const int S = projected_sizes.z; const int H = projected_sizes.y; - if (d4 >= D4 || s >= S || h >= H) { + const int c = s + input_pos; // idx along max_context_len dim + const int C = cache_sizes.z; + + if (d4 >= D4 || c >= C || h >= H) { return; } - const int c = s + input_pos; // idx along max_context_len dim - const int C = cache_sizes.y; + IN_VEC4_T in_texel = IN_VEC4_T(0.0); + if (s < S) { + in_texel = read_projected_d4(d4, h, s, D4, H, S); + } - IN_VEC4_T in_texel = read_projected_d4(d4, h, s, D4, H, S); write_cache_d4(in_texel, d4, c, h, D4, C, H); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml index 85f4ce090f8..5ec2f3e190c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml @@ -10,10 +10,14 @@ sdpa_kv_cache_update: INPUT_STORAGE: texture3d OUTPUT_STORAGE: texture3d generate_variant_forall: + combination: + parameter_names: [OUTPUT_STORAGE, INPUT_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [texture3d, buffer] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: half - VALUE: float shader_variants: - - NAME: sdpa_kv_cache_update_texture3d - - NAME: sdpa_kv_cache_update_buffer - INPUT_STORAGE: buffer + - NAME: sdpa_kv_cache_update diff --git a/backends/vulkan/runtime/graph/ops/glsl/select.glslh b/backends/vulkan/runtime/graph/ops/glsl/select.glslh index 6509015b4b6..5390e2a4bb2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/select.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/select.glslh @@ -9,70 +9,87 @@ #ifndef SELECT_GLSLH #define SELECT_GLSLH -#ifndef USING_BUFFER +#ifdef USING_BUFFER /* - * Enable the fast path if a texel loaded from the input texture can be used as - * is to store to the output texture. The following conditions must be met: + * Converts output tensor indices to input tensor indices for the select operation + * on buffer storage. * - * 1. The input and output textures have the same packed dimension. - * 2. 
The selected_dim must not be the packed dimension of the input. - * 3. The packed dimension of the input must "map" to the packed dimension of - * the output. This occurs if selected_dim is greater than the packed dimension - * of the input. + * This is done by "inserting" the select index at the selected_dim in the input + * tensor index. + * + * Parameters assumed to be defined: + * - inp: BufferMetadata + * - selected_dim + * - index */ -bool can_use_fast_path() { - if (out_packed_dim != in_packed_dim) { - return false; +TensorIndex out_tidx_to_in_tidx(const TensorIndex out_tidx) { + TensorIndex in_tidx; + initialize(in_tidx); + + int in_size = int(size_at(inp, selected_dim)); + int adjusted_index = index; + if (index < 0) { + adjusted_index = index + in_size; } - if (selected_dim <= in_packed_dim) { - return false; + + // Copy indices before selected_dim + for (int d = 0; d < selected_dim; d++) { + in_tidx.data[div_4(d)][mod_4(d)] = idx_at(out_tidx, d); } - return true; + + // Insert the selected index + in_tidx.data[div_4(selected_dim)][mod_4(selected_dim)] = adjusted_index; + + // Copy indices after selected_dim (shifted by 1) + for (int d = selected_dim; d < int_ndim(inp) - 1; d++) { + in_tidx.data[div_4(d + 1)][mod_4(d + 1)] = idx_at(out_tidx, d); + } + + return in_tidx; } -#endif // USING_BUFFER +#else // texture storage /* - * Given an output tensor index, return the corresponding input tensor index for - * the select operator. This is done by "inserting" the select index at the - * selected_dim in the input tensor index. + * Converts output tensor indices to input tensor indices for the select operation + * on texture storage. * - * A simple example is (note all tensor index are in WHCN order): - * out_tidx = [7, 5, 9] - * selected_dim = 2 - * index = 3 - * in_tidx = [7, 3, 5, 9] + * This is done by "inserting" the select index at the selected_dim in the input + * tensor index. 
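Both storage branches of this header implement the same mapping: the input index is the output index with the select index spliced in at selected_dim, and a negative index wraps by the input size along that dim. Below is a minimal host-side C++ sketch of that arithmetic; the helper name and example values are illustrative and not part of the shader library.

#include <cassert>
#include <vector>

// Illustrative mirror of out_tidx_to_in_tidx for select: copy dims before
// selected_dim, insert the (wrapped) index, then copy the remaining dims
// shifted up by one.
std::vector<int> select_out_to_in(
    const std::vector<int>& out_tidx,
    int selected_dim,
    int index,
    int in_size_at_dim) {
  const int adjusted = index < 0 ? index + in_size_at_dim : index;
  std::vector<int> in_tidx(out_tidx.begin(), out_tidx.begin() + selected_dim);
  in_tidx.push_back(adjusted);
  in_tidx.insert(in_tidx.end(), out_tidx.begin() + selected_dim, out_tidx.end());
  return in_tidx;
}

int main() {
  // WHCN indices: selecting index 3 along dim 1 of an input sized 8 there.
  assert((select_out_to_in({7, 5, 9}, 1, 3, 8) == std::vector<int>{7, 3, 5, 9}));
  // Negative index wraps: -2 along a dim of size 8 reads element 6.
  assert((select_out_to_in({7, 5, 9}, 1, -2, 8) == std::vector<int>{7, 6, 5, 9}));
  return 0;
}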
* - * This function assumes that the following variables are defined in the layout: - * - in_sizes + * Parameters assumed to be defined: + * - inp: TextureMetadata * - selected_dim * - index */ -ivec4 out_tidx_to_in_tidx(const ivec4 out_tidx) { - ivec4 in_tidx = ivec4(0); +TensorIndex4D out_tidx_to_in_tidx(const TensorIndex4D out_tidx) { + TensorIndex4D in_tidx; + in_tidx.data = ivec4(0); int adjusted_index = index; if (index < 0) { - adjusted_index = index + in_sizes[selected_dim]; + adjusted_index = index + inp.sizes[selected_dim]; } // Handle different dimensions for selection if (selected_dim == 0) { // Select from width dimension - in_tidx = ivec4(adjusted_index, out_tidx.x, out_tidx.y, out_tidx.z); + in_tidx.data = ivec4(adjusted_index, out_tidx.data.x, out_tidx.data.y, out_tidx.data.z); } else if (selected_dim == 1) { // Select from height dimension - in_tidx = ivec4(out_tidx.x, adjusted_index, out_tidx.y, out_tidx.z); + in_tidx.data = ivec4(out_tidx.data.x, adjusted_index, out_tidx.data.y, out_tidx.data.z); } else if (selected_dim == 2) { // Select from channel dimension - in_tidx = ivec4(out_tidx.x, out_tidx.y, adjusted_index, out_tidx.z); + in_tidx.data = ivec4(out_tidx.data.x, out_tidx.data.y, adjusted_index, out_tidx.data.z); } else if (selected_dim == 3) { // Select from batch dimension - in_tidx = ivec4(out_tidx.x, out_tidx.y, out_tidx.z, adjusted_index); + in_tidx.data = ivec4(out_tidx.data.x, out_tidx.data.y, out_tidx.data.z, adjusted_index); } return in_tidx; } +#endif // USING_BUFFER + #endif // SELECT_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/slice.glslh b/backends/vulkan/runtime/graph/ops/glsl/slice.glslh index 87325754f4d..0a815c85d66 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/slice.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/slice.glslh @@ -9,49 +9,61 @@ #ifndef SLICE_GLSLH #define SLICE_GLSLH -#ifndef USING_BUFFER +#include "indexing.glslh" -/** - * Enable the fast path if a texel loaded from the input texture can be used as - * is to store to the output texture. The following conditions must be met: +#ifdef USING_BUFFER + +/* + * Converts output tensor indices to input tensor indices for the slice operation + * on buffer storage. * - * 1. The input and output textures have the same packed dimension. - * 2. The select_dim must not be the packed dimension of the input. + * Parameters assumed to be defined: + * - inp: BufferMetadata + * - selected_dim + * - start + * - step */ -bool can_use_fast_path() { - if (out_packed_dim != in_packed_dim) { - return false; - } - if (in_packed_dim == selected_dim) { - return false; +TensorIndex out_tidx_to_in_tidx(const TensorIndex out_tidx) { + TensorIndex in_tidx = out_tidx; + + int in_size = int(size_at(inp, selected_dim)); + int adjusted_start = start; + if (start < 0) { + adjusted_start = start + in_size; } - return true; + + uint out_idx = idx_at(out_tidx, selected_dim); + in_tidx.data[div_4(selected_dim)][mod_4(selected_dim)] = + adjusted_start + int(out_idx) * step; + + return in_tidx; } -#endif // USING_BUFFER +#else // texture storage /* - * Converts output tensor indices to input tensor indices for the slice operation. - * This function maps the output indices to the corresponding input indices based on - * the slice parameters (start, step, selected_dim). + * Converts output tensor indices to input tensor indices for the slice operation + * on texture storage. 
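The slice mapping is the same one-liner in both storage paths: along selected_dim, the input index is adjusted_start + out_idx * step, where a negative start wraps by the input size and every other dim passes through unchanged. A short host-side C++ sketch of that mapping; the function name is illustrative.

#include <cassert>
#include <vector>

// Illustrative mirror of out_tidx_to_in_tidx for slice: only selected_dim
// changes; all other dims are copied as-is.
std::vector<int> slice_out_to_in(
    std::vector<int> out_tidx,
    int selected_dim,
    int start,
    int step,
    int in_size_at_dim) {
  const int adjusted_start = start < 0 ? start + in_size_at_dim : start;
  out_tidx[selected_dim] = adjusted_start + out_tidx[selected_dim] * step;
  return out_tidx;
}

int main() {
  // Slicing dim 2 of size 10 with start = -4, step = 2: output index 1 maps
  // to input index 6 + 1 * 2 = 8.
  assert((slice_out_to_in({7, 5, 1, 0}, 2, -4, 2, 10) == std::vector<int>{7, 5, 8, 0}));
  return 0;
}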
* - * Parameters assumed to be defined in the layout specifier: - * - in_sizes + * Parameters assumed to be defined: + * - inp: TextureMetadata * - selected_dim * - start * - step */ -ivec4 out_tidx_to_in_tidx(const ivec4 out_tidx) { - ivec4 in_tidx = out_tidx; +TensorIndex4D out_tidx_to_in_tidx(const TensorIndex4D out_tidx) { + TensorIndex4D in_tidx = out_tidx; int adjusted_start = start; if (start < 0) { - adjusted_start = start + in_sizes[selected_dim]; + adjusted_start = start + inp.sizes[selected_dim]; } - in_tidx[selected_dim] = adjusted_start + out_tidx[selected_dim] * step; + in_tidx.data[selected_dim] = adjusted_start + out_tidx.data[selected_dim] * step; return in_tidx; } +#endif // USING_BUFFER + #endif // SLICE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/softmax.glsl b/backends/vulkan/runtime/graph/ops/glsl/softmax.glsl index d35492bc367..9b44d5c5a94 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/softmax.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/softmax.glsl @@ -23,8 +23,10 @@ layout(std430) buffer; ${layout_declare_tensor(B, "w", "tout", DTYPE, STORAGE)} ${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)} -${layout_declare_ubo(B, "ivec3", "tout_limits")} -${layout_declare_ubo(B, "ivec4", "tin_sizes")} +layout(push_constant) uniform restrict Block { + ivec4 tin_sizes; + ivec3 tout_limits; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -42,7 +44,8 @@ layout(constant_id = 5) const int group_dim = 1; // work group will write into its assigned element in the shared array. #define MAX_NTHREADS 16 -shared vec4 shared_vecs[MAX_NTHREADS]; +shared vec4 shared_max[MAX_NTHREADS]; +shared vec4 shared_sum[MAX_NTHREADS]; #include "indexing_utils.h" @@ -102,13 +105,13 @@ void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) { i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) { max_elements = max(max_elements, load_texel(tin, scan_pos)); } - shared_vecs[smi] = max_elements; + shared_max[smi] = max_elements; barrier(); // Iterate over the partial maximums to obtain the overall maximum group_i = tid.y * NWORKERS; - max_elements = shared_vecs[group_i++]; + max_elements = shared_max[group_i++]; for (int i = 1; i < NWORKERS; ++i, group_i++) { - max_elements = max(max_elements, shared_vecs[group_i]); + max_elements = max(max_elements, shared_max[group_i]); } scan_pos[reduce_dim] = tid.x; @@ -118,13 +121,13 @@ void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) { i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) { denominators += exp(load_texel(tin, scan_pos) - max_elements); } - shared_vecs[smi] = denominators; + shared_sum[smi] = denominators; barrier(); // Iterate over the partial sums to obtain the overall sum group_i = tid.y * NWORKERS; - denominators = shared_vecs[group_i++]; + denominators = shared_sum[group_i++]; for (int i = 1; i < NWORKERS; ++i, group_i++) { - denominators += shared_vecs[group_i]; + denominators += shared_sum[group_i]; } // Determine if there are any padding elements in the final texel of the @@ -184,13 +187,13 @@ void softmax_packed_dim(const ivec2 tid, ivec3 scan_pos) { max_elements.x = max(intex[i], max_elements.x); } } - shared_vecs[smi] = max_elements; + shared_max[smi] = max_elements; barrier(); // Iterate over the partial maximums to obtain the overall maximum group_i = tid.y * NWORKERS; - max_elements = shared_vecs[group_i++]; + max_elements = shared_max[group_i++]; for (int i = 1; i < NWORKERS; ++i, group_i++) { - max_elements = max(max_elements, shared_vecs[group_i]); + max_elements = 
max(max_elements, shared_max[group_i]); } // Each element of the texel is itself a partial maximum; iterate over the // texel to find the actual maximum @@ -214,13 +217,13 @@ void softmax_packed_dim(const ivec2 tid, ivec3 scan_pos) { denominators.x += exp(intex[i] - max_element); } } - shared_vecs[smi] = denominators; + shared_sum[smi] = denominators; barrier(); // Iterate over the partial sums to obtain the overall sum group_i = tid.y * NWORKERS; - denominators = shared_vecs[group_i++]; + denominators = shared_sum[group_i++]; for (int i = 1; i < NWORKERS; ++i, group_i++) { - denominators += shared_vecs[group_i]; + denominators += shared_sum[group_i]; } // Reduce over the accumulated texel to find the overall sum float denominator = 0; diff --git a/backends/vulkan/runtime/graph/ops/glsl/split_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/split_buffer.glsl new file mode 100644 index 00000000000..0505c9e7bcd --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/split_buffer.glsl @@ -0,0 +1,50 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define T ${buffer_scalar_type(DTYPE)} + +${define_active_storage_type("buffer")} +${define_required_extensions(DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +#include "indexing.glslh" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_input", DTYPE, "buffer")} + +${layout_declare_ubo(B, "BufferMetadata", "outp")} +${layout_declare_ubo(B, "BufferMetadata", "inp")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int split_dim = 0; +layout(constant_id = 4) const int split_idx = 0; +layout(constant_id = 5) const int split_offset = 0; + +void main() { + const uint out_bufi = gl_GlobalInvocationID.x; + if (out_of_bounds(out_bufi, outp)) { + return; + } + + TensorIndex out_tidx = linear_idx_to_tensor_idx(outp, out_bufi); + + TensorIndex input_tidx = out_tidx; + input_tidx.data[div_4(split_dim)][mod_4(split_dim)] += split_offset; + + const uint input_bufi = tensor_idx_to_linear_idx(inp, input_tidx); + + t_out[out_bufi] = t_input[input_bufi]; +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/split_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/split_buffer.yaml new file mode 100644 index 00000000000..45dbff832f9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/split_buffer.yaml @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +split_buffer: + parameter_names_with_default_values: + DTYPE: float + STORAGE: buffer + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + - VALUE: int32 + - VALUE: uint8 + shader_variants: + - NAME: split_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/split_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/split_texture.glsl new file mode 100644 index 00000000000..92d7ce548e2 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/split_texture.glsl @@ -0,0 +1,66 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_load_type(DTYPE, "texture3d")} +#define T ${texel_load_component_type(DTYPE, "texture3d")} + +${define_active_storage_type("texture3d")} +${define_required_extensions(DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +#include "common.glslh" +#include "indexing.glslh" + +${layout_declare_tensor(B, "w", "t_output", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_input", DTYPE, "texture3d")} + +${layout_declare_ubo(B, "TextureMetadata", "outp")} +${layout_declare_ubo(B, "TextureMetadata", "inp")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int split_dim = 0; +layout(constant_id = 4) const int split_idx = 0; +layout(constant_id = 5) const int split_offset = 0; + +void main() { + const ivec3 out_pos = ivec3(gl_GlobalInvocationID); + + if (out_of_bounds(out_pos, outp)) { + return; + } + + TensorIndex4D out_tidx = texture_pos_to_tensor4d_idx_simple(outp, out_pos); + + VEC4_T out_texel = VEC4_T(0); + + int limit = min( + 4, outp.sizes[outp.packed_dim] - out_tidx.data[outp.packed_dim]); + + TensorIndex4D input_tidx = out_tidx; + input_tidx.data[split_dim] += split_offset; + + for (int comp = 0; comp < limit; comp++) { + TextureElementIndex input_elem_pos = tensor4d_idx_to_texture_element_idx_simple( + inp, input_tidx); + + VEC4_T input_texel = texelFetch(t_input, input_elem_pos.pos, 0); + out_texel[comp] = input_texel[input_elem_pos.comp]; + + input_tidx.data[outp.packed_dim]++; + } + + imageStore(t_output, out_pos, out_texel); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/split_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/split_texture.yaml new file mode 100644 index 00000000000..6a1613a401e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/split_texture.yaml @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
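The new split shaders treat each output of the split independently: a single specialization constant split_offset shifts every output index along split_dim back into the source tensor, so output i simply copies the sub-range it owns. A host-side C++ sketch of how such offsets line up for equal-sized chunks; the equal-chunk assumption and helper name are illustrative, since the shaders only consume a precomputed offset.

#include <cassert>
#include <vector>

// Illustrative: per-output offsets along split_dim when a dimension of size
// dim_size is split into chunks of size chunk_size (last chunk may be smaller).
std::vector<int> split_offsets(int dim_size, int chunk_size) {
  std::vector<int> offsets;
  for (int off = 0; off < dim_size; off += chunk_size) {
    offsets.push_back(off);
  }
  return offsets;
}

int main() {
  // Splitting a dim of size 10 into chunks of 4 yields outputs starting at
  // input indices 0, 4 and 8.
  assert((split_offsets(10, 4) == std::vector<int>{0, 4, 8}));
  // Inside the shader, in_idx = out_idx + split_offset along split_dim, e.g.
  // out_idx 1 of output 2 reads input index 1 + 8 = 9.
  assert(1 + split_offsets(10, 4)[2] == 9);
  return 0;
}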
+ +split_texture: + parameter_names_with_default_values: + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + - VALUE: int32 + - VALUE: uint8 + shader_variants: + - NAME: split_texture3d diff --git a/backends/vulkan/runtime/graph/ops/glsl/transfer_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/transfer_buffer.glsl index 7605c59c72f..73b753ccc0b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/transfer_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/transfer_buffer.glsl @@ -11,18 +11,23 @@ #define PRECISION ${PRECISION} #define UBO_PARAMS ${UBO_PARAMS} -#define VEC4_T ${texel_type(DTYPE)} #define T ${buffer_scalar_type(DTYPE)} ${define_active_storage_type("buffer")} ${define_required_extensions(DTYPE)} +#extension GL_EXT_control_flow_attributes : require + layout(std430) buffer; -#include "indexing_utils.h" +#include "indexing.glslh" + ${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")} ${layout_declare_tensor(B, "r", "t_in", DTYPE, "buffer")} +${layout_declare_ubo(B, "BufferMetadata", "outp")} +${layout_declare_ubo(B, "BufferMetadata", "inp")} + $if UBO_PARAMS: $if OP_NAME == "slice": ${layout_declare_ubo(B, "int", "start")} @@ -32,10 +37,6 @@ $if UBO_PARAMS: ${layout_declare_ubo(B, "int", "index")} layout(push_constant) uniform restrict Block { - ivec4 in_sizes; - ivec4 out_strides; - ivec4 in_strides; - int out_numel; int selected_dim; $if not UBO_PARAMS: $if OP_NAME == "slice": @@ -46,24 +47,19 @@ layout(push_constant) uniform restrict Block { int index; }; -${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} -${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} - -const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); - layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; #include "${OP_NAME}.glslh" void main() { - const int out_bufi = ivec3(gl_GlobalInvocationID).x; - if (out_bufi >= out_numel) { + const uint out_bufi = gl_GlobalInvocationID.x; + if (out_of_bounds(out_bufi, outp)) { return; } - const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_dim_order); - ivec4 in_tidx = out_tidx_to_in_tidx(out_tidx); + TensorIndex out_tidx = linear_idx_to_tensor_idx(outp, out_bufi); + TensorIndex in_tidx = out_tidx_to_in_tidx(out_tidx); - const int in_bufi = tidx_to_bufi(in_tidx, in_strides); + const uint in_bufi = tensor_idx_to_linear_idx(inp, in_tidx); t_out[out_bufi] = t_in[in_bufi]; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/transfer_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/transfer_buffer.yaml index f68b2bd1250..62bab110828 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/transfer_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/transfer_buffer.yaml @@ -8,6 +8,7 @@ transfer_buffer: - VALUE: half - VALUE: float - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: select_buffer OP_NAME: select diff --git a/backends/vulkan/runtime/graph/ops/glsl/transfer_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/transfer_texture.glsl index 0f34713cb43..d2c9c025242 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/transfer_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/transfer_texture.glsl @@ -11,19 +11,25 @@ #define PRECISION ${PRECISION} #define UBO_PARAMS ${UBO_PARAMS} -#define VEC4_T ${texel_type(DTYPE)} -#define T ${buffer_scalar_type(DTYPE)} +#define VEC4_T ${texel_load_type(DTYPE, "texture3d")} +#define T ${texel_load_component_type(DTYPE, "texture3d")} ${define_active_storage_type("texture3d")} 
${define_required_extensions(DTYPE)} +#extension GL_EXT_control_flow_attributes : require + layout(std430) buffer; -#include "indexing_utils.h" +#include "common.glslh" +#include "indexing.glslh" ${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")} ${layout_declare_tensor(B, "r", "t_in", DTYPE, "texture3d")} +${layout_declare_ubo(B, "TextureMetadata", "outp")} +${layout_declare_ubo(B, "TextureMetadata", "inp")} + $if UBO_PARAMS: $if OP_NAME == "slice": ${layout_declare_ubo(B, "int", "start")} @@ -33,8 +39,6 @@ $if UBO_PARAMS: ${layout_declare_ubo(B, "int", "index")} layout(push_constant) uniform restrict Block { - ivec4 out_sizes; - ivec4 in_sizes; int selected_dim; $if not UBO_PARAMS: $if OP_NAME == "slice": @@ -45,48 +49,33 @@ layout(push_constant) uniform restrict Block { int index; }; -${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 out_axis_map = unhash_axis_map(out_layout); -const lowp int out_packed_dim = unhash_packed_dim(out_layout); - -${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 in_axis_map = unhash_axis_map(in_layout); -const lowp int in_packed_dim = unhash_packed_dim(in_layout); - layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; #include "${OP_NAME}.glslh" void main() { - const ivec3 lpos = ivec3(gl_GlobalInvocationID); - ivec4 out_tidx = lpos_to_tidx(lpos, out_sizes, out_axis_map.w, out_packed_dim); + const ivec3 out_pos = ivec3(gl_GlobalInvocationID); - if (any(greaterThanEqual(out_tidx, out_sizes))) { + if (out_of_bounds(out_pos, outp)) { return; } - if (can_use_fast_path()) { - ivec4 in_tidx = out_tidx_to_in_tidx(out_tidx); - ivec3 in_pos = tidx_to_pos(in_tidx, in_sizes, in_axis_map, in_packed_dim); - VEC4_T in_texel = VEC4_T(load_texel(t_in, in_pos)); + TensorIndex4D out_tidx = texture_pos_to_tensor4d_idx_simple(outp, out_pos); + VEC4_T out_texel = VEC4_T(0); - write_texel_lpos(t_out, lpos, in_texel, out_axis_map); - } - else { - VEC4_T out_texel = VEC4_T(0); - for (int texel_i = 0; texel_i < 4; ++texel_i) { - ivec4 in_tidx = out_tidx_to_in_tidx(out_tidx); - ivec3 in_pos = tidx_to_pos(in_tidx, in_sizes, in_axis_map, in_packed_dim); - int element_idx = in_tidx[in_packed_dim] % 4; - - VEC4_T in_texel = VEC4_T(load_texel(t_in, in_pos)); - T selected_value = T(in_texel[element_idx]); + int limit = min( + 4, outp.sizes[outp.packed_dim] - out_tidx.data[outp.packed_dim]); + for (int comp = 0; comp < limit; comp++) { + TensorIndex4D in_tidx = out_tidx_to_in_tidx(out_tidx); - out_texel[texel_i] = selected_value; + TextureElementIndex in_elem_pos = tensor4d_idx_to_texture_element_idx_simple( + inp, in_tidx); - out_tidx[out_packed_dim]++; - } + VEC4_T in_texel = texelFetch(t_in, in_elem_pos.pos, 0); + out_texel[comp] = in_texel[in_elem_pos.comp]; - write_texel_lpos(t_out, lpos, out_texel, out_axis_map); + out_tidx.data[outp.packed_dim]++; } + + imageStore(t_out, out_pos, out_texel); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/transfer_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/transfer_texture.yaml index 6922f120e49..7824801ddb6 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/transfer_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/transfer_texture.yaml @@ -8,6 +8,7 @@ transfer_texture: - VALUE: half - VALUE: float - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: select_texture3d OP_NAME: select diff --git a/backends/vulkan/runtime/graph/ops/glsl/unpack_4w4c_and_dequantize.glsl 
b/backends/vulkan/runtime/graph/ops/glsl/unpack_4w4c_and_dequantize.glsl new file mode 100644 index 00000000000..be0a39bac3c --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/unpack_4w4c_and_dequantize.glsl @@ -0,0 +1,131 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)} +#define T ${texel_load_component_type(DTYPE, INPUT_STORAGE)} + +// corresponds to the output width dim +#define TILE_M4 1 +// corresponds to the output channels dim +#define TILE_K4 1 + +#define TILE_M 4 + +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER +$if INPUT_STORAGE == "buffer": + #define INPUT_BUFFER + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_fp_output", DTYPE, OUTPUT_STORAGE)} +${layout_declare_tensor(B, "r", "t_packed_int8_output", "int", INPUT_STORAGE, is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} + +layout(push_constant) uniform restrict Block { + float scale; + int zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "linear_fp_input_tile.glslh" +#include "linear_int8_input_tile.glslh" + +void load_packed_int8_tile( + out Int8InputTile int8_tile, + const Conv2dBlockIndex block_idx, + const Conv2dBlockExtents block_extents) { +#ifdef INPUT_BUFFER + const int buffer_idx = block_idx.data.y * block_extents.data_xz + + block_idx.data.x * block_extents.data.z + block_idx.data.z; + int8_tile.data[0][0] = t_packed_int8_output[buffer_idx]; +#else + int8_tile.data[0][0] = texelFetch(t_packed_int8_output, block_idx.data, 0); +#endif +} + +VEC4_T +dequantize_8bit(const ivec4 val, const float q_scale, const int q_zero_point) { + return VEC4_T(val - q_zero_point) * q_scale; +} + +void unpack_and_dequantize( + out FPInputTile fp_tile, + const Int8InputTile int8_tile, + const float q_scale, + const int q_zero_point) { + [[unroll]] for (int w = 0; w < 4; ++w) { + int packed = int8_tile.data[0][0][w]; + fp_tile.data[w][0] = dequantize_8bit( + ivec4( + extract_8bit_from_packed_int_le(packed, 0), + extract_8bit_from_packed_int_le(packed, 1), + extract_8bit_from_packed_int_le(packed, 2), + extract_8bit_from_packed_int_le(packed, 3)), + q_scale, + q_zero_point); + } +} + +void store_fp_output_texel( + const Conv2dTensorIndex tidx, + const VEC4_T out_texel) { +#ifdef OUTPUT_BUFFER + const int c_idx = mul_4(tidx.data.z); + const int c_stride = output_sizes.y * output_sizes.x; + + const int base_buf_i = c_idx * c_stride + tidx.data.y * output_sizes.x + tidx.data.x; + const int limit = min(output_sizes.z - c_idx, 4); + + for (int i = 0; i < limit; ++i) { + t_fp_output[base_buf_i + i * c_stride] = out_texel[i]; + } +#else + imageStore(t_fp_output, tidx.data, out_texel); +#endif +} + +void store_fp_tile( + const FPInputTile block, + const Conv2dBlockIndex block_idx) { + Conv2dTensorIndex store_tidx = block_idx_to_tensor_idx(block_idx); + [[unroll]] for (int w = 0; w < 4; w++) { + if (store_tidx.data.x < output_sizes.x) { + store_fp_output_texel(store_tidx, block.data[w][0]); + } + store_tidx.data.x++; + } +} + +void main() { + Conv2dBlockIndex block_idx; + block_idx.data = ivec3(gl_GlobalInvocationID); + + Conv2dBlockExtents block_extents = 
make_block_extents(output_sizes); + if (block_idx_out_of_bounds(block_idx, block_extents)) { + return; + } + + Int8InputTile int8_tile; + load_packed_int8_tile(int8_tile, block_idx, block_extents); + + FPInputTile fp_tile; + unpack_and_dequantize( + fp_tile, int8_tile, scale, zp); + + store_fp_tile(fp_tile, block_idx); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/unpack_4w4c_and_dequantize.yaml b/backends/vulkan/runtime/graph/ops/glsl/unpack_4w4c_and_dequantize.yaml new file mode 100644 index 00000000000..9f2a584a6c3 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/unpack_4w4c_and_dequantize.yaml @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +unpack_4w4c_and_dequantize: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: texture3d + INPUT_STORAGE: texture3d + generate_variant_forall: + combination: + parameter_names: [OUTPUT_STORAGE, INPUT_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [buffer, texture3d] + - parameter_values: [texture3d, buffer] + - parameter_values: [buffer, buffer] + DTYPE: + - VALUE: float + shader_variants: + - NAME: unpack_4w4c_and_dequantize_per_tensor diff --git a/backends/vulkan/runtime/graph/ops/glsl/view.yaml b/backends/vulkan/runtime/graph/ops/glsl/view.yaml index 33364a25225..e963d253424 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/view.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/view.yaml @@ -8,5 +8,6 @@ view: - VALUE: half - VALUE: float - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: view diff --git a/backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl index 2c02803a9b1..96b9aa85a1f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl @@ -18,6 +18,8 @@ ${layout_declare_ubo(B, "BufferMetadata", "inp")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +${layout_declare_spec_const(C, "int", "all_contiguous", "0")} + /* * The insight behind the view operation is that the contiguous index of each * tensor element in the input and output tensors are the same. @@ -28,17 +30,20 @@ void main() { return; } - TensorIndex outp_tidx; - linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx); + uint inp_bufi = outp_bufi; + if (all_contiguous == 0) { + TensorIndex outp_tidx; + linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx); - // To map the output to the input, find the input element that has the same - // contiguous index as the output element. - const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx); + // To map the output to the input, find the input element that has the same + // contiguous index as the output element. 
+ const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx); - TensorIndex inp_tidx; - contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx); + TensorIndex inp_tidx; + contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx); - const uint inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx); + inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx); + } t_outp[outp_bufi] = t_inp[inp_bufi]; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.glsl new file mode 100644 index 00000000000..a926c9fea11 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.glsl @@ -0,0 +1,54 @@ +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} +#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} + +${define_required_extensions(IN_DTYPE)} +${define_required_extensions(OUT_DTYPE)} + +layout(std430) buffer; + +#include "indexing.glslh" + +${layout_declare_buffer(B, "w", "t_outp", OUT_DTYPE)} +${layout_declare_buffer(B, "r", "t_inp", IN_DTYPE)} + +${layout_declare_ubo(B, "BufferMetadata", "outp")} +${layout_declare_ubo(B, "BufferMetadata", "inp")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "all_contiguous", "0")} + +/* + * The insight behind the view_convert operation is that the contiguous index of each + * tensor element in the input and output tensors are the same, but the data types + * may be different and need conversion. + */ +void main() { + const uint outp_bufi = gl_GlobalInvocationID.x; + if (outp_bufi >= numel(outp)) { + return; + } + + uint inp_bufi = outp_bufi; + + if (all_contiguous == 0) { + TensorIndex outp_tidx; + linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx); + + // To map the output to the input, find the input element that has the same + // contiguous index as the output element. + const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx); + + TensorIndex inp_tidx; + contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx); + + inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx); + } + + // Convert data type from input to output + t_outp[outp_bufi] = OUT_T(t_inp[inp_bufi]); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.yaml new file mode 100644 index 00000000000..11d56cad4a9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.yaml @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +view_convert_buffer: + parameter_names_with_default_values: + IN_DTYPE: float + OUT_DTYPE: float + STORAGE: buffer + generate_variant_forall: + combination: + parameter_names: [IN_DTYPE, OUT_DTYPE] + combos: + - parameter_values: [int32, float] + - parameter_values: [int32, half] + - parameter_values: [uint8, float] + - parameter_values: [uint8, half] + - parameter_values: [uint8, int32] + shader_variants: + - NAME: view_convert_buffer diff --git a/backends/vulkan/runtime/graph/ops/impl/ArgReduce.cpp b/backends/vulkan/runtime/graph/ops/impl/ArgReduce.cpp new file mode 100644 index 00000000000..68a51602f74 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/ArgReduce.cpp @@ -0,0 +1,56 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. 
+ * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +#include + +namespace vkcompute { + +void arg_reduce_impl( + ComputeGraph& graph, + const std::vector& args, + const std::string& op_name) { + int arg_idx = 0; + const ValueRef in = args.at(arg_idx++); + const ValueRef dim = args.at(arg_idx++); + const ValueRef keepdim = args.at(arg_idx++); + const ValueRef out = args.at(arg_idx++); + + VK_CHECK_COND(graph.is_buffer_storage(in)); + + int64_t dim_val = 0; + if (graph.val_is_not_none(dim)) { + dim_val = graph.extract_scalar(dim); + } + const int64_t ndim = graph.dim_of(in); + const int64_t normalized_dim = normalize(dim_val, graph.dim_of(in)); + + VK_CHECK_COND(normalized_dim == ndim - 1); + + // Use the reduce_per_row_node function + add_reduce_per_row_node(graph, in, keepdim, out, op_name); +} + +void argmin(ComputeGraph& graph, const std::vector& args) { + arg_reduce_impl(graph, args, "argmin"); +} + +void argmax(ComputeGraph& graph, const std::vector& args) { + arg_reduce_impl(graph, args, "argmax"); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.argmin.default, argmin); + VK_REGISTER_OP(aten.argmax.default, argmax); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp index 757afd06849..a6dd8f07f53 100644 --- a/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp @@ -19,6 +19,18 @@ namespace vkcompute { +void resize_batch_norm_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef self = args.at(1).refs.at(0); + + // For batch norm, output dimensions are the same as input dimensions + std::vector new_out_sizes = graph->sizes_of(self); + graph->virtual_resize(out, new_out_sizes); +} + ValueRef check_and_prepack_arg( ComputeGraph& graph, ValueRef arg_ref, @@ -101,7 +113,7 @@ void add_native_batch_norm_node( // Resize Args {}, // Resizing Logic - nullptr)); + resize_batch_norm_node)); } void native_batch_norm(ComputeGraph& graph, const std::vector& args) { diff --git a/backends/vulkan/runtime/graph/ops/impl/BinaryScalarOp.cpp b/backends/vulkan/runtime/graph/ops/impl/BinaryScalarOp.cpp new file mode 100644 index 00000000000..15553706494 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/BinaryScalarOp.cpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
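The new ArgReduce entry points only accept buffer storage and a reduction over the innermost dimension: the requested dim is normalized (negative dims wrap by the rank) and must equal ndim - 1 before the per-row reduce node is added. A tiny C++ sketch of that guard, assuming normalize has the usual wrap-negative-dims behaviour; the helper name below is illustrative.

#include <cassert>
#include <stdexcept>

// Illustrative: wrap a possibly-negative dim into [0, ndim) and enforce the
// "last dim only" restriction used by the Vulkan argmin/argmax lowering.
int normalized_last_dim_or_throw(int dim, int ndim) {
  const int normalized = dim < 0 ? dim + ndim : dim;
  if (normalized != ndim - 1) {
    throw std::invalid_argument("arg reduce only supports the last dim");
  }
  return normalized;
}

int main() {
  assert(normalized_last_dim_or_throw(-1, 4) == 3);  // dim = -1 on a 4-D tensor
  assert(normalized_last_dim_or_throw(3, 4) == 3);   // explicit last dim
  return 0;
}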
+ */ + +#include + +#include +#include + +#include +#include + +#include + +namespace vkcompute { + +void resize_binary_scalar_op_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); + + const std::vector in_sizes = graph->sizes_of(in); + + graph->virtual_resize(out, in_sizes); +} + +void add_binary_scalar_op_node( + ComputeGraph& graph, + const ValueRef in, + const ValueRef scalar, + const ValueRef out, + const std::string& op_name) { + ValueRef arg = prepack_standard_like(graph, in, out, true); + + // Extract scalar value + float scalar_val = graph.extract_scalar(scalar); + + // Pick shader + std::string kernel_name = op_name + "_scalar"; + kernel_name.reserve(kShaderNameReserve); + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); + add_dtype_suffix(kernel_name, graph.dtype_of(in)); + + vkapi::ParamsBindList param_ubos = {graph.meta_ubo(out), graph.meta_ubo(in)}; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {arg, vkapi::kRead}}, + // Shader params buffers + param_ubos, + // Push Constants + {PushConstantDataInfo(&scalar_val, sizeof(scalar_val))}, + // Specialization Constants + {}, + // Resize Args + {}, + // Resizing Logic + resize_binary_scalar_op_node)); +} + +void pow_tensor_scalar(ComputeGraph& graph, const std::vector& args) { + return add_binary_scalar_op_node(graph, args[0], args[1], args[2], "pow"); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.pow.Tensor_Scalar, pow_tensor_scalar); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp index a4a96ffdb88..5b8615e0a70 100644 --- a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp @@ -34,150 +34,6 @@ void resize_choose_qparams_per_row( graph->virtual_resize(input_zeros, new_sizes); } -utils::uvec3 choose_qparams_pick_global_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - - // For per-tensor quantization, we want a single workgroup that can handle - // all elements with proper reduction. The shader uses NWORKERS=64 threads. 
- const ValueRef input = args.at(1).refs.at(0); - - if (graph->is_buffer_storage(input)) { - // For buffer storage, use a single workgroup in X dimension - // The shader will handle strided access across all elements - return {1u, 1u, 1u}; - } else { - // For texture storage, use the default logic - return graph->create_global_wg_size(args.at(0).refs.at(0)); - } -} - -utils::uvec3 choose_qparams_pick_local_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const utils::uvec3& global_workgroup_size, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - - const ValueRef input = args.at(1).refs.at(0); - - if (graph->is_buffer_storage(input)) { - // For buffer storage, use 64 threads in X dimension to match NWORKERS - // This ensures the shared memory arrays are properly sized - return {64u, 1u, 1u}; - } else { - // For texture storage, use the default logic - return graph->create_local_wg_size(global_workgroup_size); - } -} - -utils::uvec3 choose_qparams_per_token_pick_global_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - - const ValueRef input = args.at(1).refs.at(0); - - if (graph->is_buffer_storage(input)) { - // For per-token quantization, we need one workgroup per token - // Calculate number of tokens (product of all dimensions except the last - // one) - const auto input_sizes = graph->sizes_of(input); - int64_t num_tokens = 1; - for (size_t i = 0; i < input_sizes.size() - 1; i++) { - num_tokens *= input_sizes[i]; - } - - return {static_cast(num_tokens), 1u, 1u}; - } else { - // For texture storage, use the default logic - return graph->create_global_wg_size(args.at(0).refs.at(0)); - } -} - -utils::uvec3 choose_qparams_per_token_pick_local_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const utils::uvec3& global_workgroup_size, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - - const ValueRef input = args.at(1).refs.at(0); - - if (graph->is_buffer_storage(input)) { - return {1u, 1u, 1u}; - } else { - // For texture storage, use the default logic - return graph->create_local_wg_size(global_workgroup_size); - } -} - -utils::uvec3 choose_qparams_block_wise_pick_global_wg_size( - ComputeGraph* g, - const vkapi::ShaderInfo&, - const std::vector& a, - const std::vector& r) { - const ValueRef input = a.at(2).refs.at(0); - const auto blkRef = r.at(0); - const auto inSz = g->sizes_of(input); - const auto blkList = g->get_int_list(blkRef); - - // Use same code as in add_choose_qparams_block_wise_node - utils::ivec4 block_size_vec = utils::make_whcn_ivec4(*blkList); - utils::ivec4 tensor_size_whcn = utils::make_whcn_ivec4(inSz); - - // Calculate numBlocks: ceil(tensorSize / blockSize) (both in WHCN order) - utils::ivec4 nBlk = { - (tensor_size_whcn[0] + block_size_vec[0] - 1) / block_size_vec[0], - (tensor_size_whcn[1] + block_size_vec[1] - 1) / block_size_vec[1], - (tensor_size_whcn[2] + block_size_vec[2] - 1) / block_size_vec[2], - (tensor_size_whcn[3] + block_size_vec[3] - 1) / block_size_vec[3]}; - - uint32_t nBlocks = nBlk[0] * nBlk[1] * nBlk[2] * nBlk[3]; - - // For texture storage, use more threads to better utilize GPU parallelism - // Each thread can process multiple blocks with stride - if (g->is_buffer_storage(input)) { - return {nBlocks, 1u, 1u}; - } else { - // For texture storage, use more workgroups to better utilize GPU - 
// Aim for ~64-256 threads per workgroup for good occupancy - uint32_t preferred_threads_per_wg = 64; - uint32_t num_workgroups = - (nBlocks + preferred_threads_per_wg - 1) / preferred_threads_per_wg; - num_workgroups = std::max(1u, std::min(num_workgroups, nBlocks)); - return {num_workgroups * preferred_threads_per_wg, 1u, 1u}; - } -} - -utils::uvec3 choose_qparams_block_wise_pick_local_wg_size( - ComputeGraph* g, - const vkapi::ShaderInfo&, - const utils::uvec3& global_wg_size, - const std::vector& a, - const std::vector&) { - const ValueRef input = a.at(2).refs.at(0); - - if (g->is_buffer_storage(input)) { - return {1u, 1u, 1u}; - } else { - // For texture storage, use 64 threads per workgroup for better occupancy - uint32_t local_size = std::min(64u, global_wg_size[0]); - return {local_size, 1u, 1u}; - } -} - vkapi::ShaderInfo pick_choose_qparams_per_row_shader( ComputeGraph* graph, const std::vector& args, @@ -222,160 +78,6 @@ utils::uvec3 pick_choose_qparams_per_row_local_wg_size( return {workers_per_output, outputs_per_wg, 1u}; } -void add_choose_qparams_tensor_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& eps, - const ValueRef& scale_out, - const ValueRef& zero_point_out) { - std::string kernel_name("choose_qparams_tensor"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale_out)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point_out)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(zero_point_out)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must be an integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - float eps_val = static_cast(graph.get_double(eps)); - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(scale_out), - graph.strides_ubo(scale_out), - graph.sizes_ubo(zero_point_out), - graph.strides_ubo(zero_point_out)}; - } else { - param_ubos = { - graph.logical_limits_ubo(input), - graph.logical_limits_ubo(scale_out), - graph.logical_limits_ubo(zero_point_out)}; - } - - push_constants = { - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - PushConstantDataInfo(&eps_val, sizeof(float)), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - choose_qparams_pick_global_wg_size, - choose_qparams_pick_local_wg_size, - // Inputs and Outputs - {{scale_out, vkapi::kWrite}, - {zero_point_out, vkapi::kWrite}, - {input, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - {}, - // Resize Args - {}, - // Resizing Logic - nullptr)); -} - -void 
add_choose_qparams_per_token_asymmetric_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& scale_out, - const ValueRef& zero_point_out) { - std::string kernel_name("choose_qparams_per_token_asymmetric"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale_out)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point_out)); - - // Calculate number of tokens (product of all dimensions except the last one) - int64_t num_tokens = 1; - const auto input_sizes = graph.sizes_of(input); - for (size_t i = 0; i < input_sizes.size() - 1; i++) { - num_tokens *= input_sizes[i]; - } - - int num_tokens_val = static_cast(num_tokens); - int quant_min_val = -128; // Fixed for asymmetric quantization - int quant_max_val = 127; // Fixed for asymmetric quantization - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(scale_out), - graph.strides_ubo(scale_out), - graph.sizes_ubo(zero_point_out), - graph.strides_ubo(zero_point_out)}; - } else { - param_ubos = { - graph.logical_limits_ubo(input), - graph.logical_limits_ubo(scale_out), - graph.logical_limits_ubo(zero_point_out)}; - } - - push_constants = { - PushConstantDataInfo(&num_tokens_val, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - choose_qparams_per_token_pick_global_wg_size, - choose_qparams_per_token_pick_local_wg_size, - // Inputs and Outputs - {{scale_out, vkapi::kWrite}, - {zero_point_out, vkapi::kWrite}, - {input, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - {}, - // Resize Args - {}, - // Resizing Logic - nullptr)); -} - void add_choose_qparams_per_row_node( ComputeGraph& graph, const ValueRef& input, @@ -427,221 +129,6 @@ void add_choose_qparams_per_row_node( resize_choose_qparams_per_row)); } -void add_choose_qparams_block_wise_node( - ComputeGraph& graph, - ValueRef input, - ValueRef block_size, - int mapping_type, // 0 / 1 / 2 - ValueRef quant_min, - ValueRef quant_max, - ValueRef eps, - ValueRef scale_out, - ValueRef zp_out) { - const auto input_sizes = graph.sizes_of(input); - const auto block_size_list = graph.get_int_list(block_size); - - // For shader compatibility, we still need to convert to WHCN order - // but the output shape calculation is now handled correctly in resize - // function - utils::ivec4 block_size_vec = utils::make_whcn_ivec4(*block_size_list); - utils::ivec4 tensor_size_whcn = utils::make_whcn_ivec4(input_sizes); - - // Calculate numBlocks: ceil(tensorSize / blockSize) (both in WHCN order) - utils::ivec4 num_blocks_vec = { - (tensor_size_whcn[0] + block_size_vec[0] - 1) / block_size_vec[0], - (tensor_size_whcn[1] + block_size_vec[1] - 1) / block_size_vec[1], - (tensor_size_whcn[2] + block_size_vec[2] - 1) / block_size_vec[2], - (tensor_size_whcn[3] + block_size_vec[3] - 1) / block_size_vec[3]}; - - // Calculate blockStride: pre-computed linear strides for the block grid - utils::ivec4 block_stride_vec = { - 1, - num_blocks_vec[0], - num_blocks_vec[0] * num_blocks_vec[1], - num_blocks_vec[0] * num_blocks_vec[1] * num_blocks_vec[2]}; - - // Handle optional quant_min 
and quant_max parameters - int qmin, qmax; - if (graph.val_is_none(quant_min) || graph.val_is_none(quant_max)) { - // Use default values based on target_dtype (similar to - // _get_and_check_qmin_qmax) For now, assume int8 range as default - this - // should match the Python implementation - qmin = -128; - qmax = 127; - } else { - qmin = static_cast(graph.get_int(quant_min)); - qmax = static_cast(graph.get_int(quant_max)); - } - - float eps_val; - if (graph.val_is_none(eps)) { - // Use default eps value (similar to Python implementation) - eps_val = 1.192092896e-07f; // torch.finfo(torch.float32).eps - } else { - eps_val = static_cast(graph.get_double(eps)); - } - - // Create push constants vector - std::vector push_constants = { - PushConstantDataInfo(&block_size_vec, sizeof(block_size_vec)), - PushConstantDataInfo(&num_blocks_vec, sizeof(num_blocks_vec)), - PushConstantDataInfo(&block_stride_vec, sizeof(block_stride_vec)), - PushConstantDataInfo(&mapping_type, sizeof(int)), - PushConstantDataInfo(&qmin, sizeof(int)), - PushConstantDataInfo(&qmax, sizeof(int)), - PushConstantDataInfo(&eps_val, sizeof(float))}; - - std::string kernel_name("choose_qparams_block_wise"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale_out)); - add_dtype_suffix(kernel_name, graph.dtype_of(zp_out)); - - vkapi::ParamsBindList param_ubos; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(scale_out), - graph.strides_ubo(scale_out), - graph.sizes_ubo(zp_out), - graph.strides_ubo(zp_out)}; - } else { - // For texture input, the shader uses buffer storage for outputs - // so we need buffer UBOs for the output tensors - param_ubos = { - graph.logical_limits_ubo(input), - graph.sizes_ubo(scale_out), - graph.strides_ubo(scale_out), - graph.sizes_ubo(zp_out), - graph.strides_ubo(zp_out)}; - } - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - choose_qparams_block_wise_pick_global_wg_size, - choose_qparams_block_wise_pick_local_wg_size, - // Inputs and Outputs - {{scale_out, vkapi::kWrite}, - {zp_out, vkapi::kWrite}, - {input, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - {}, - // Resize Args - {block_size}, - // Resizing Logic - nullptr)); -} - -void choose_qparams_tensor_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef eps = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; - const ValueRef out_tuple_ref = args[arg_idx++]; - - ValueRef scale_out = kDummyValueRef; - ValueRef zero_point_out = kDummyValueRef; - - { - const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref); - scale_out = out_tuple->at(0); - zero_point_out = out_tuple->at(1); - } - - // Void the unused dtype parameter to match ATen signature - (void)dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale_out)); - VK_CHECK_COND(graph.val_is_tensor(zero_point_out)); - - // Verify input is a floating point type - VK_CHECK_COND(graph.dtype_of(input) == vkapi::kFloat); - - // Get scale and zero point output dtypes - vkapi::ScalarType scale_out_dtype = 
graph.dtype_of(scale_out); - vkapi::ScalarType zero_point_out_dtype = graph.dtype_of(zero_point_out); - - // Verify supported output types for scale (fp32 only for now) - VK_CHECK_COND(scale_out_dtype == vkapi::kFloat); - - // Verify supported output types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_out_dtype == vkapi::kInt || - zero_point_out_dtype == vkapi::kChar || - zero_point_out_dtype == vkapi::kFloat); - - // Check that texture storage is width packed - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim); - } - - add_choose_qparams_tensor_node( - graph, input, quant_min, quant_max, eps, scale_out, zero_point_out); -} - -void choose_qparams_per_token_asymmetric_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; - const ValueRef out_tuple_ref = args[arg_idx++]; - - ValueRef scale_out = kDummyValueRef; - ValueRef zero_point_out = kDummyValueRef; - - { - const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref); - scale_out = out_tuple->at(0); - zero_point_out = out_tuple->at(1); - } - - // Void the unused parameter to match ATen signature - (void)dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale_out)); - VK_CHECK_COND(graph.val_is_tensor(zero_point_out)); - - // Verify input is a floating point type - VK_CHECK_COND(graph.dtype_of(input) == vkapi::kFloat); - - // Get scale and zero point output dtypes - vkapi::ScalarType scale_out_dtype = graph.dtype_of(scale_out); - vkapi::ScalarType zero_point_out_dtype = graph.dtype_of(zero_point_out); - - // Verify supported output types for scale (fp32 only for now) - VK_CHECK_COND(scale_out_dtype == vkapi::kFloat); - - // Verify supported output types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_out_dtype == vkapi::kInt || - zero_point_out_dtype == vkapi::kChar || - zero_point_out_dtype == vkapi::kFloat); - - // Check that texture storage is width packed - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim); - } - - add_choose_qparams_per_token_asymmetric_node( - graph, input, scale_out, zero_point_out); -} - bool can_use_choose_qparams_per_row( ComputeGraph& graph, const ValueRef input, @@ -671,17 +158,21 @@ bool can_use_choose_qparams_per_row( void choose_qparams_affine_impl( ComputeGraph& graph, const std::vector& args) { - int arg_idx = 0; + size_t arg_idx = 0; + size_t last_arg_idx = args.size() - 1; const ValueRef input = args[arg_idx++]; const ValueRef mapping_type = args[arg_idx++]; + (void)mapping_type; const ValueRef block_size = args[arg_idx++]; const ValueRef target_dtype = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; const ValueRef eps = args[arg_idx++]; + (void)eps; const ValueRef scale_dtype = args[arg_idx++]; const ValueRef zero_point_dtype = args[arg_idx++]; - const ValueRef out_tuple_ref = args[arg_idx++]; + + const ValueRef out_tuple_ref = args[last_arg_idx]; // Suppress unused variable warnings (void)target_dtype; @@ -704,59 +195,7 @@ void choose_qparams_affine_impl( graph, input, quant_min, quant_max, scale_out, zero_point_out); } - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale_out)); - VK_CHECK_COND(graph.val_is_tensor(zero_point_out)); - - // Verify input is a floating point type 
- VK_CHECK_COND(graph.dtype_of(input) == vkapi::kFloat); - - // Get scale and zero point dtypes from arguments - vkapi::ScalarType scale_out_dtype = graph.dtype_of(scale_out); - vkapi::ScalarType zero_point_out_dtype = graph.dtype_of(zero_point_out); - - // Verify supported output types for scale (fp32 only for now) - VK_CHECK_COND(scale_out_dtype == vkapi::kFloat); - - // Verify supported output types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_out_dtype == vkapi::kInt || - zero_point_out_dtype == vkapi::kChar || - zero_point_out_dtype == vkapi::kFloat); - - // Check that texture storage is width packed - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim); - } - - const auto input_sizes = graph.sizes_of(input); - const auto block_size_list = graph.get_int_list(block_size); - VK_CHECK_COND(block_size_list->size() == input_sizes.size()); - - std::string mapping_type_str = graph.get_string(mapping_type); - int mapping_type_val = 0; // Default to ASYMMETRIC - - if (mapping_type_str == "ASYMMETRIC" || mapping_type_str.empty()) { - mapping_type_val = 0; // ASYMMETRIC - } else if (mapping_type_str == "SYMMETRIC") { - mapping_type_val = 1; - } else if (mapping_type_str == "SYMMETRIC_NO_CLIPPING_ERR") { - mapping_type_val = 2; - } else { - VK_THROW("Unsupported mapping_type: ", mapping_type_str); - } - - add_choose_qparams_block_wise_node( - graph, - input, - block_size, - mapping_type_val, - quant_min, - quant_max, - eps, - scale_out, - zero_point_out); + VK_THROW("Unsupported input case for choose_qparams_affine"); } void choose_qparams_per_row( @@ -769,27 +208,11 @@ void choose_qparams_per_row( const ValueRef input_scales = args[arg_idx++]; const ValueRef input_zps = args[arg_idx++]; - // ValueRef scale_out = kDummyValueRef; - // ValueRef zero_point_out = kDummyValueRef; - // - // { - // const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref); - // scale_out = out_tuple->at(0); - // zero_point_out = out_tuple->at(1); - // } - // - add_choose_qparams_per_row_node( graph, input, quant_min, quant_max, input_scales, input_zps); } REGISTER_OPERATORS { - VK_REGISTER_OP( - quantized_decomposed.choose_qparams.tensor, choose_qparams_tensor_impl); - VK_REGISTER_OP( - quantized_decomposed.choose_qparams_per_token_asymmetric.default, - choose_qparams_per_token_asymmetric_impl); - // Register the per-channel quantization operator VK_REGISTER_OP(etvk.choose_qparams_per_row.default, choose_qparams_per_row); diff --git a/backends/vulkan/runtime/graph/ops/impl/Clone.cpp b/backends/vulkan/runtime/graph/ops/impl/Clone.cpp index 0ae9d53a481..a64cb0143a9 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Clone.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Clone.cpp @@ -48,9 +48,9 @@ void add_clone_node( // Inputs and Outputs {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Parameter Buffers - {graph.logical_limits_ubo(out)}, - // Push Constants {}, + // Push Constants + {graph.logical_limits_pc_of(out)}, // Specialization Constants {}, // Resize Args @@ -76,6 +76,7 @@ void add_image_to_buffer_node( const ValueRef buffer) { std::string kernel_name = "clone_image_to_buffer"; add_dtype_suffix(kernel_name, graph.dtype_of(image)); + add_dtype_suffix(kernel_name, graph.dtype_of(buffer)); vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); graph.execute_nodes().emplace_back(new DynamicDispatchNode( @@ -103,6 +104,7 @@ void add_buffer_to_image_node( const ValueRef image) { std::string kernel_name = "clone_buffer_to_image"; 
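// Illustrative sketch (not part of the patch): the surrounding Clone.cpp hunks
// now append the buffer dtype as well as the image dtype to the clone kernel
// names, so each (image dtype, buffer dtype) pair resolves to its own shader
// variant. The suffix strings below ("half", "float") are assumptions for
// illustration only; the actual strings produced by add_dtype_suffix are not
// shown in this diff.
#include <iostream>
#include <string>

static void append_dtype_suffix(std::string& name, const std::string& dtype) {
  name += "_" + dtype;
}

int main() {
  std::string kernel_name = "clone_buffer_to_image";
  append_dtype_suffix(kernel_name, "half");   // image dtype
  append_dtype_suffix(kernel_name, "float");  // buffer dtype
  std::cout << kernel_name << "\n";  // clone_buffer_to_image_half_float
  return 0;
}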
add_dtype_suffix(kernel_name, graph.dtype_of(image)); + add_dtype_suffix(kernel_name, graph.dtype_of(buffer)); vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); graph.execute_nodes().emplace_back(new DynamicDispatchNode( diff --git a/backends/vulkan/runtime/graph/ops/impl/Common.cpp b/backends/vulkan/runtime/graph/ops/impl/Common.cpp index 6c701224f7f..71690ffc604 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Common.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Common.cpp @@ -56,4 +56,27 @@ utils::uvec3 pick_hw_square_wg_size( return {16u, 4u, 1u}; } +utils::uvec3 pick_wc_square_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)graph; + (void)shader; + (void)args; + (void)resize_args; + // Some inactive invocations are okay; set 6 as the threshold to use the + // a square wg size. + if (global_workgroup_size[0u] >= 6 && global_workgroup_size[2u] >= 6) { + return {8u, 1u, 8u}; + } + // If channels dim is sufficiently small, then bias towards width dim to + // reduce the number of inactive invocations. + if (global_workgroup_size[2u] < 2u) { + return {64u, 1u, 1u}; + } + return {16u, 1u, 4u}; +} + } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Common.h b/backends/vulkan/runtime/graph/ops/impl/Common.h index 1831ab2a845..b412f737c13 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Common.h +++ b/backends/vulkan/runtime/graph/ops/impl/Common.h @@ -54,4 +54,11 @@ utils::uvec3 pick_hw_square_wg_size( const std::vector& args, const std::vector& resize_args); +utils::uvec3 pick_wc_square_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args); + } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index ded1defe973..479bb44ae6f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -105,7 +105,8 @@ ValueRef prepack_biases( ValueRef v = graph.add_tensor( {out_channels}, graph.dtype_of(weight), storage_type, memory_layout); - vkapi::ShaderInfo shader = get_nchw_to_tensor_shader(graph, v); + vkapi::ShaderInfo shader = + get_nchw_to_tensor_shader(graph, v, graph.dtype_of(weight)); graph.prepack_nodes().emplace_back(new PrepackNode( graph, @@ -364,6 +365,10 @@ utils::uvec3 conv2d_global_wg_size( if (method == Conv2dMethod::Depthwise || method == Conv2dMethod::Pointwise) { wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1}; + + if (shader.kernel_name.find("s1p0") != std::string::npos) { + wg_size[0] *= 4; + } } return wg_size; diff --git a/backends/vulkan/runtime/graph/ops/impl/Copy.cpp b/backends/vulkan/runtime/graph/ops/impl/Copy.cpp deleted file mode 100644 index bd648dbae2d..00000000000 --- a/backends/vulkan/runtime/graph/ops/impl/Copy.cpp +++ /dev/null @@ -1,317 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include - -#include -#include -#include -#include -#include - -namespace vkcompute { - -using utils::ivec3; -using utils::ivec4; -using utils::uvec3; - -void add_copy_offset_node( - ComputeGraph& graph, - const ValueRef in, - const ivec3& range, - const ivec4& src_offset, - const ivec4& dst_offset, - const ValueRef out, - bool calc_out_pos_using_src_chnl, - bool calc_in_pos_using_dst_chnl) { - std::string kernel_name = "copy_offset"; - kernel_name.reserve(kShaderNameReserve); - add_dtype_suffix(kernel_name, graph.dtype_of(out)); - add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); - - auto shader = VK_KERNEL_FROM_STR(kernel_name); - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - default_pick_local_wg_size, - // Inputs and Outputs - { - {out, vkapi::kWrite}, - {in, vkapi::kRead}, - }, - // Parameter buffers - {}, - // Push Constants - { - PushConstantDataInfo(&range, sizeof(range), sizeof(ivec4)), - PushConstantDataInfo(&src_offset, sizeof(src_offset), sizeof(ivec4)), - PushConstantDataInfo(&dst_offset, sizeof(dst_offset), sizeof(ivec4)), - }, - // Specialization Constants - {graph.hashed_layout_of(out), - graph.hashed_layout_of(in), - (calc_out_pos_using_src_chnl ? 1 - : calc_in_pos_using_dst_chnl ? 2 - : 0)}, - // Resize Args - {}, - // Resizing Logic - nullptr)); -} - -void add_copy_packed_dim_offset_node( - ComputeGraph& graph, - const ValueRef in, - const ivec3& range, - const ivec4& src_offset, - const ivec4& dst_offset, - const ValueRef out) { - // Check the packed dimension is same for both tensors, also check if the - // packed dimension is Width or Height. Since the function does not support - // channel packing. - VK_CHECK_COND( - graph.packed_dim_of(in) == graph.packed_dim_of(out) && - (graph.packed_dim_of(in) == WHCN::kWidthDim || - graph.packed_dim_of(in) == WHCN::kHeightDim)); - - std::string kernel_name = "copy_packed_dim_offset"; - kernel_name.reserve(kShaderNameReserve); - add_dtype_suffix(kernel_name, graph.dtype_of(out)); - - const std::vector in_sizes = graph.sizes_of(in); - const std::vector out_sizes = graph.sizes_of(out); - - // A copy of range with the last element set to batch size of the input tensor - ivec4 final_range = { - range[0], range[1], range[2], dim_at(in_sizes, kBatch4D)}; - ivec3 global_wg_size = graph.logical_limits_of(out); - - const auto packed_dim = graph.packed_dim_of(in); - // The starting offset in a texel where this tensor will start copying from - const auto src_lane_offset = src_offset[packed_dim] & 0x3; - // The starting offset in a texel where this tensor will start copying to - const auto dst_lane_offset = dst_offset[packed_dim] & 0x3; - - // The total packed texels this tensor will be copied from - // The first texel of tensor data in packed dimension will be copied from - // remaining lanes from current source Hence (4 - src_lane_offset) is added - // to tensor size in packed dimension - const auto src_packed_size = utils::div_up_4( - (4 - src_lane_offset) + utils::val_at(-packed_dim, out_sizes)); - - // The total packed texels this tensor will be copied to - // The first texel of tensor data in packed dimension will be copied to - // remaining lanes from previous write Hence (4 - dst_lane_offset) is added - // to tensor size in packed dimension - const auto dst_packed_size = utils::div_up_4( - (4 - dst_lane_offset) + utils::val_at(-packed_dim, in_sizes)); - - // If the starting src offset is not 0, and the total packed 
texels is - // greater than the source texel range - const bool has_additional_src_work = - src_lane_offset != 0 && src_packed_size > final_range[packed_dim]; - // If the starting dst offset is not 0, and the total packed texels is - // greater than the source texel range - const bool has_additional_dst_work = - dst_lane_offset != 0 && dst_packed_size > final_range[packed_dim]; - - if (has_additional_src_work || has_additional_dst_work) { - global_wg_size[packed_dim]++; // Increase the global work group size in - // packed dimension - final_range[packed_dim]++; // Increase the range in packed dimension - } - - auto shader = VK_KERNEL_FROM_STR(kernel_name); - - graph.execute_nodes().emplace_back(new DispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - graph.create_local_wg_size(global_wg_size), - // Inputs and Outputs - { - {out, vkapi::kWrite}, - {out, vkapi::kRead}, - {in, vkapi::kRead}, - }, - // Parameter buffers - {}, - // Push Constants - { - PushConstantDataInfo( - &final_range, sizeof(final_range), sizeof(ivec4)), - PushConstantDataInfo(&src_offset, sizeof(src_offset), sizeof(ivec4)), - PushConstantDataInfo(&dst_offset, sizeof(dst_offset), sizeof(ivec4)), - }, - // Specialization Constants - {graph.hashed_layout_of(out), graph.hashed_layout_of(in)}, - // Resize Args - {}, - // Resizing Logic - nullptr)); -} - -void add_copy_channel_offset_node( - ComputeGraph& graph, - const ValueRef in, - int32_t channel_range, - int32_t src_channel_offset, - int32_t dst_channel_offset, - const ValueRef out) { - // Likely need to prepad these numbers. - const std::vector in_sizes = graph.sizes_of(in); - const std::vector out_sizes = graph.sizes_of(out); - - VK_CHECK_COND(graph.packed_dim_of(in) == WHCN::kChannelsDim); - VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kChannelsDim); - - // NOTE: This function should be able to support 1d and 2d tensors when - // range=1, src_offset=dst_offset=1. - VK_CHECK_COND(graph.dim_of(in) >= 3, "Src dim should be at least 3"); - VK_CHECK_COND(graph.dim_of(out) >= 3, "Dst dim should be at least 3"); - - VK_CHECK_COND( - dim_at(in_sizes) >= src_channel_offset + channel_range, - "Src channel (", - src_channel_offset, - ") and range (", - channel_range, - ") should be less than or equal to input tensor's channel size (", - dim_at(in_sizes), - ")"); - - VK_CHECK_COND( - dim_at(out_sizes) >= dst_channel_offset + channel_range, - "Dst channel (", - dst_channel_offset, - ") and range (", - channel_range, - ") should be less than or equal to input tensor's channel size (", - dim_at(out_sizes), - ")"); - - VK_CHECK_COND(channel_range >= 0, "Channel range must be non-negative"); - VK_CHECK_COND( - src_channel_offset >= 0, "Src channel offset must be non-negative"); - VK_CHECK_COND( - dst_channel_offset >= 0, "Dst channel offset must be non-negative"); - - std::string kernel_name = "copy_channel_offset"; - kernel_name.reserve(kShaderNameReserve); - add_dtype_suffix(kernel_name, graph.dtype_of(out)); - - int32_t out_channels = dim_at(out_sizes); - - // Copy one batch at a time. - for (int batch_idx = 0; batch_idx < dim_at(in_sizes); batch_idx++) { - // Mapping the tensor NCHW coordinates into texture XYZ coordinates - int32_t dst_first_z = dst_channel_offset / 4; - int32_t dst_last_z = (dst_channel_offset + channel_range - 1) / 4; - - // We copy the entire width and height dimension. For the channel dimension, - // we use the z-dimension of the global_size to specify the texture range. 
- // The shader combines the global invocation id and the dst_offset to get - // the actual coordinate. - - const ivec3 dst_offset{ - 0, 0, dst_first_z + batch_idx * utils::div_up_4(out_channels)}; - - const uvec3 global_size{ - utils::safe_downcast(dim_at(in_sizes)), - utils::safe_downcast(dim_at(in_sizes)), - utils::safe_downcast(dst_last_z - dst_first_z + 1)}; - const uvec3 local_size = graph.create_local_wg_size(global_size); - - const utils::ivec4 range_params = { - static_cast(global_size[0]), - static_cast(global_size[1]), - static_cast(global_size[2]), - channel_range}; - - const ivec4 offset_params = { - dst_offset[0], dst_offset[1], dst_offset[2], dst_channel_offset}; - - auto shader = VK_KERNEL_FROM_STR(kernel_name); - - graph.execute_nodes().emplace_back(new DispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - global_size, - local_size, - // Inputs and Outputs - { - {out, vkapi::kWrite}, - {out, vkapi::kRead}, - {in, vkapi::kRead}, - }, - // Parameter buffers - {}, - // Push Constants - {graph.sizes_pc_of(out), - graph.sizes_pc_of(in), - PushConstantDataInfo(&range_params, sizeof(range_params)), - PushConstantDataInfo(&offset_params, sizeof(offset_params)), - PushConstantDataInfo(&src_channel_offset, sizeof(src_channel_offset))}, - // Specialization Constants - {graph.hashed_layout_of(out), graph.hashed_layout_of(in)}, - // Resize Args - {}, - // Resizing Logic - nullptr)); - } -} - -void add_copy_offset_node( - ComputeGraph& graph, - ValueRef in, - ValueRef range_ref, - ValueRef src_offset_ref, - ValueRef dst_offset_ref, - ValueRef out) { - ivec3 range = utils::make_ivec3(*graph.get_int_list(range_ref)); - ivec3 src = utils::make_ivec3(*graph.get_int_list(src_offset_ref)); - ivec3 dst = utils::make_ivec3(*graph.get_int_list(dst_offset_ref)); - - ivec4 src_offset = {src[0], src[1], src[2], 0}; - ivec4 dst_offset = {dst[0], dst[1], dst[2], 0}; - - add_copy_offset_node( - graph, in, range, src_offset, dst_offset, out, false, false); -} - -void copy_offset(ComputeGraph& graph, const std::vector& args) { - add_copy_offset_node(graph, args[0], args[1], args[2], args[3], args[4]); -} - -void copy_channel_offset( - ComputeGraph& graph, - const std::vector& args) { - ValueRef in = args[0]; - ValueRef channel_range_ref = args[1]; - ValueRef src_channel_offset_ref = args[2]; - ValueRef dst_channel_offset_ref = args[3]; - ValueRef out = args[4]; - - auto channel_range = graph.extract_scalar(channel_range_ref); - auto src_channel_offset = - graph.extract_scalar(src_channel_offset_ref); - auto dst_channel_offset = - graph.extract_scalar(dst_channel_offset_ref); - - add_copy_channel_offset_node( - graph, in, channel_range, src_channel_offset, dst_channel_offset, out); -} - -REGISTER_OPERATORS { - VK_REGISTER_OP(etvk.copy_offset, copy_offset); - VK_REGISTER_OP(etvk.copy_channel_offset, copy_channel_offset); -} - -} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Copy.h b/backends/vulkan/runtime/graph/ops/impl/Copy.h deleted file mode 100644 index 41956d482d9..00000000000 --- a/backends/vulkan/runtime/graph/ops/impl/Copy.h +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#include - -#include - -namespace vkcompute { - -// add_copy_offset_node resumes the vkCmdCopyImage command. 
It copies the -// texture extents specified by the range, src_offset, and dst_offset (all are -// in texture coordinate (x, y, z) from the input image to the output image. -// src_offset.w and dst_offset.w may contain channel size information. -// -// It is possible to have input and output to point to the same image -// object. But when the source range and destination range overlap, the behavior -// is undefined. -// -// boolean flags calc_out_pos_using_src_chnl and calc_in_pos_using_dst_chnl -// can be used to specify an indexing function in the shader -// If calc_out_pos_using_src_chnl is set to true channel and batch index will be -// calculated based on source channel size and will be used to determine -// destination texel position. -// -// If calc_in_pos_using_dst_chnl is set to truechannel and batch index will be -// calculated based on destination channel size and will be used to determine -// source texel position. -// -// If both are true calc_out_pos_using_src_chnl is picked. If both are false no -// index calculation happens. -void add_copy_offset_node( - ComputeGraph& graph, - const ValueRef in, - const utils::ivec3& range, - const utils::ivec4& src_offset, - const utils::ivec4& dst_offset, - const ValueRef out, - bool calc_out_pos_using_src_chnl, - bool calc_in_pos_using_dst_chnl); - -// add_copy_packed_dim_offset_node behaves similar to add_copy_node, except that -// its used when copying packed dimension, if tensor is width or height packed. -// src_offset.w and dst_offset.w may contain channel size information. -// -// It copies the texture extents specified by the range, src_offset, and -// dst_offset (all are in texture coordinate (x, y, z) from the input image to -// the output image. -void add_copy_packed_dim_offset_node( - ComputeGraph& graph, - const ValueRef in, - const utils::ivec3& range, - const utils::ivec4& src_offset, - const utils::ivec4& dst_offset, - const ValueRef out); - -// add_copy_channel_offset_node behaves similar to add_copy_node, except that it -// works on the channel dimensions of the tensor (up to 4 dimensions in NCHW). -// The range and offset arguments are in the tensor coordinate. It assumes the -// underlying texture is channel-packed. -// -// This function is specialized implementation for copying -// channel packed values. The complication comes from when reading / writing the -// channel dimension on indices that are not aligned to packing, we will need -// be careful about the boundaries. -// -// It achieves the following: -// out[:, dst_channel_offset:dst_channel_offset + channel_range, :, :] = -// in [:, src_channel_offset:src_channel_offset + channel_range, :, :] -void add_copy_channel_offset_node( - ComputeGraph& graph, - const ValueRef in, - int32_t channel_range, - int32_t src_channel_offset, - int32_t dst_channel_offset, - const ValueRef out); - -} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp deleted file mode 100644 index a217734653d..00000000000 --- a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp +++ /dev/null @@ -1,843 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include -#include -#include - -#include -#include -#include -#include - -namespace vkcompute { - -void resize_dequantize_node( - ComputeGraph* graph, - const std::vector& args, - const std::vector& extra_args) { - (void)extra_args; - - const ValueRef out = args.at(0).refs.at(0); - const ValueRef in = args.at(1).refs.at(0); - - const std::vector in_sizes = graph->sizes_of(in); - graph->virtual_resize(out, in_sizes); -} - -utils::uvec3 dequantize_per_channel_local_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const utils::uvec3& global_workgroup_size, - const std::vector& args, - const std::vector& resize_args) { - (void)args; - (void)resize_args; - - const ValueRef input = args.at(1).refs.at(0); - - utils::uvec3 local_wg_size = - graph->create_local_wg_size(global_workgroup_size); - - // WORKAROUND: The CommandBuffer::dispatch function divides - // global_workgroup_size by local_workgroup_size to get the number of - // workgroups to dispatch. We need to ensure that we dispatch the correct - // number of workgroups in the Z dimension to cover all batch-channel - // combinations. - // - // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], - // local_wg_size[2]) might reduce the number of workgroups dispatched. To - // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, - // we set local_wg_size[2] = 1. - const auto input_sizes = graph->sizes_of(input); - if (input_sizes.size() == 4 && !graph->is_buffer_storage(input) && - global_workgroup_size[2] > 1) { - local_wg_size[2] = 1; - } - - return local_wg_size; -} - -utils::uvec3 dequantize_block_wise_local_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const utils::uvec3& global_workgroup_size, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - const ValueRef input = args.at(1).refs.at(0); - - utils::uvec3 local_wg_size = - graph->create_local_wg_size(global_workgroup_size); - - // WORKAROUND: The CommandBuffer::dispatch function divides - // global_workgroup_size by local_workgroup_size to get the number of - // workgroups to dispatch. We need to ensure that we dispatch the correct - // number of workgroups in the Z dimension to cover all batch-channel - // combinations. - // - // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], - // local_wg_size[2]) might reduce the number of workgroups dispatched. To - // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, - // we set local_wg_size[2] = 1. 
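// Illustrative sketch (not part of the patch), assuming the dispatch count per
// axis is derived as ceil(global / local), as the WORKAROUND comment above
// describes. The div_up helper is a stand-in, not the project's own utility.
#include <cassert>
#include <cstdint>

static uint32_t div_up(uint32_t num, uint32_t den) {
  return (num + den - 1) / den;
}

int main() {
  const uint32_t global_z = 6;  // e.g. 6 batch-channel combinations along Z
  // With local_z = 4, only ceil(6 / 4) = 2 workgroups are dispatched in Z,
  // fewer than the 6 the shader expects in order to cover one per Z index.
  assert(div_up(global_z, 4) == 2);
  // Forcing local_wg_size[2] = 1 dispatches exactly global_z workgroups.
  assert(div_up(global_z, 1) == global_z);
  return 0;
}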
- const auto input_sizes = graph->sizes_of(input); - if (input_sizes.size() == 4 && !graph->is_buffer_storage(input) && - global_workgroup_size[2] > 1) { - local_wg_size[2] = 1; - } - - return local_wg_size; -} - -void add_dequantize_per_tensor_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& scale, - const ValueRef& zero_point, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& output) { - std::string kernel_name("dequantize_per_tensor"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(output)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(input)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must be an integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.numel_ubo(input), - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(output), - graph.strides_ubo(output)}; - } else { - param_ubos = { - graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; - } - - push_constants = { - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(output), - graph.hashed_layout_of(input), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - default_pick_local_wg_size, - // Inputs and Outputs - {{output, vkapi::kWrite}, - {input, vkapi::kRead}, - {{scale, zero_point}, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - spec_vars, - // Resize Args - {}, - // Resizing Logic - resize_dequantize_node)); -} - -void add_dequantize_per_token_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& scale, - const ValueRef& zero_point, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& output) { - std::string kernel_name("dequantize_per_token"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(output)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(input)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must 
be an integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - - int num_tokens = static_cast(graph.sizes_of(scale)[0]); - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.numel_ubo(input), - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(output), - graph.strides_ubo(output)}; - } else { - param_ubos = { - graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; - } - - push_constants = { - PushConstantDataInfo(&num_tokens, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(output), - graph.hashed_layout_of(input), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - default_pick_local_wg_size, - // Inputs and Outputs - {{output, vkapi::kWrite}, - {input, vkapi::kRead}, - {{scale, zero_point}, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - spec_vars, - // Resize Args - {}, - // Resizing Logic - resize_dequantize_node)); -} - -void add_dequantize_per_channel_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& scale, - const ValueRef& zero_point, - const ValueRef& axis, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& output) { - std::string kernel_name("dequantize_per_channel"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(output)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - - int axis_val = static_cast(graph.get_int(axis)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(input)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must be an integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - - // Normalize axis and convert from NCHW to WHCN using utility functions - const auto input_sizes = graph.sizes_of(input); - const int64_t ndim = graph.dim_of(input); - - // Normalize axis to handle negative indices - axis_val = normalize(axis_val, ndim); - - // Convert from NCHW axis to WHCN axis for shader (vulkan representation) - int axis_whcn = nchw_dim_to_whcn_dim(axis_val, ndim); - - int num_channels; - if (axis_val == 0 && ndim == 4 && !graph.is_buffer_storage(input)) { - // For batch dimension 
dequantization in 4D tensors, pass the actual number - // of channels so the shader can correctly unfold the batch-channel folding - num_channels = static_cast(input_sizes[1]); // Channel dimension - } else { - num_channels = static_cast(input_sizes[axis_val]); - } - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.numel_ubo(input), - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(output), - graph.strides_ubo(output)}; - } else { - param_ubos = { - graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; - } - - push_constants = { - PushConstantDataInfo(&axis_whcn, sizeof(int)), - PushConstantDataInfo(&num_channels, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(output), - graph.hashed_layout_of(input), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - dequantize_per_channel_local_wg_size, - // Inputs and Outputs - {{output, vkapi::kWrite}, - {input, vkapi::kRead}, - {{scale, zero_point}, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - spec_vars, - // Resize Args - {}, - // Resizing Logic - resize_dequantize_node)); -} - -void add_dequantize_block_wise_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& block_size, - const ValueRef& scale, - const ValueRef& zero_point, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& output) { - std::string kernel_name("dequantize_block_wise"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(output)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(input)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must be an integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - - const auto input_sizes = graph.sizes_of(input); - const auto block_size_list = graph.get_int_list(block_size); - - // Convert dimensions to WHCN order for shader - utils::ivec4 block_size_vec = utils::make_whcn_ivec4(*block_size_list); - utils::ivec4 tensor_size_whcn = utils::make_whcn_ivec4(input_sizes); - - // Calculate numBlocks: tensorSize / blockSize (both in WHCN order) - utils::ivec4 num_blocks_vec = { - tensor_size_whcn[0] / block_size_vec[0], - tensor_size_whcn[1] / block_size_vec[1], - tensor_size_whcn[2] / block_size_vec[2], - tensor_size_whcn[3] / block_size_vec[3]}; - - // Calculate blockStride: pre-computed linear strides for the block grid - utils::ivec4 block_stride_vec = { - 1, - 
num_blocks_vec[0], - num_blocks_vec[0] * num_blocks_vec[1], - num_blocks_vec[0] * num_blocks_vec[1] * num_blocks_vec[2]}; - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.numel_ubo(input), - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(output), - graph.strides_ubo(output)}; - } else { - param_ubos = { - graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; - } - - push_constants = { - PushConstantDataInfo(&block_size_vec, sizeof(block_size_vec)), - PushConstantDataInfo(&num_blocks_vec, sizeof(num_blocks_vec)), - PushConstantDataInfo(&block_stride_vec, sizeof(block_stride_vec)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(output), - graph.hashed_layout_of(input), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - dequantize_block_wise_local_wg_size, - // Inputs and Outputs - {{output, vkapi::kWrite}, - {input, vkapi::kRead}, - {{scale, zero_point}, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - spec_vars, - // Resize Args - {}, - // Resizing Logic - resize_dequantize_node)); -} - -void dequantize_per_tensor_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; - const ValueRef zero_point = args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; - const ValueRef output_dtype = args[arg_idx++]; - const ValueRef output = args[arg_idx++]; - - // Suppress unused variable warnings - dtype and output_dtype are inferred - (void)dtype; - (void)output_dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale)); - VK_CHECK_COND(graph.val_is_tensor(zero_point)); - VK_CHECK_COND(graph.val_is_tensor(output)); - - // Verify input is an integer type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kByte || - graph.dtype_of(input) == vkapi::kChar || - graph.dtype_of(input) == vkapi::kInt); - - // Get scale and zero point dtypes - vkapi::ScalarType scale_dtype = graph.dtype_of(scale); - vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); - - // Verify supported types for scale (fp32 only for now) - VK_CHECK_COND(scale_dtype == vkapi::kFloat); - - // Verify supported types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || - zero_point_dtype == vkapi::kFloat); - - // Check that scale and zero_point have buffer storage and width packing - VK_CHECK_COND(graph.is_buffer_storage(scale)); - VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); - VK_CHECK_COND(graph.is_buffer_storage(zero_point)); - VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); - - // Check that tensors with texture storage have standard axis map - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.has_standard_axis_map(input)); - } - if (!graph.is_buffer_storage(output)) { - VK_CHECK_COND(graph.has_standard_axis_map(output)); - } - - add_dequantize_per_tensor_node( - graph, input, scale, zero_point, quant_min, quant_max, output); -} - -void 
dequantize_per_token_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; - const ValueRef zero_point = args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; - const ValueRef output_dtype = args[arg_idx++]; - const ValueRef output = args[arg_idx++]; - - // Suppress unused variable warnings - dtype and output_dtype are inferred - (void)dtype; - (void)output_dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale)); - VK_CHECK_COND(graph.val_is_tensor(zero_point)); - VK_CHECK_COND(graph.val_is_tensor(output)); - - // Verify input is an integer type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kByte || - graph.dtype_of(input) == vkapi::kChar || - graph.dtype_of(input) == vkapi::kInt); - - // Get scale and zero point dtypes - vkapi::ScalarType scale_dtype = graph.dtype_of(scale); - vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); - - // Verify supported types for scale (fp32 only for now) - VK_CHECK_COND(scale_dtype == vkapi::kFloat); - - // Verify supported types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || - zero_point_dtype == vkapi::kFloat); - - // Check that scale and zero_point have buffer storage and width packing - VK_CHECK_COND(graph.is_buffer_storage(scale)); - VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); - VK_CHECK_COND(graph.is_buffer_storage(zero_point)); - VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); - - // Check that tensors with texture storage have standard axis map - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.has_standard_axis_map(input)); - } - if (!graph.is_buffer_storage(output)) { - VK_CHECK_COND(graph.has_standard_axis_map(output)); - } - - // Calculate number of tokens (product of all dimensions except the last one) - int64_t num_tokens = 1; - const auto input_sizes = graph.sizes_of(input); - for (size_t i = 0; i < input_sizes.size() - 1; i++) { - num_tokens *= input_sizes[i]; - } - - const auto scale_sizes = graph.sizes_of(scale); - const auto zero_point_sizes = graph.sizes_of(zero_point); - - // Calculate total number of elements in scale and zero_point tensors - int64_t scale_numel = 1; - for (size_t i = 0; i < scale_sizes.size(); i++) { - scale_numel *= scale_sizes[i]; - } - - int64_t zero_point_numel = 1; - for (size_t i = 0; i < zero_point_sizes.size(); i++) { - zero_point_numel *= zero_point_sizes[i]; - } - - // Check that the total number of elements matches num_tokens - // This allows for both 1D tensors (size [num_tokens]) and reshaped tensors - // (size [num_tokens, 1]) - VK_CHECK_COND(scale_numel == num_tokens); - VK_CHECK_COND(zero_point_numel == num_tokens); - - add_dequantize_per_token_node( - graph, input, scale, zero_point, quant_min, quant_max, output); -} - -void dequantize_per_channel_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; - const ValueRef zero_point = args[arg_idx++]; - const ValueRef axis = args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; - const ValueRef output_dtype = args[arg_idx++]; - const ValueRef 
output = args[arg_idx++]; - - // Suppress unused variable warnings - dtype and output_dtype are inferred - (void)dtype; - (void)output_dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale)); - VK_CHECK_COND(graph.val_is_tensor(zero_point)); - VK_CHECK_COND(graph.val_is_tensor(output)); - - // Verify input is an integer type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kByte || - graph.dtype_of(input) == vkapi::kChar || - graph.dtype_of(input) == vkapi::kInt); - - // Get scale and zero point dtypes - vkapi::ScalarType scale_dtype = graph.dtype_of(scale); - vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); - - // Verify supported types for scale (fp32 only for now) - VK_CHECK_COND(scale_dtype == vkapi::kFloat); - - // Verify supported types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || - zero_point_dtype == vkapi::kFloat); - - // Check that scale and zero_point have buffer storage and width packing - VK_CHECK_COND(graph.is_buffer_storage(scale)); - VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); - VK_CHECK_COND(graph.is_buffer_storage(zero_point)); - VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); - - // Check that tensors with texture storage have standard axis map - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.has_standard_axis_map(input)); - } - if (!graph.is_buffer_storage(output)) { - VK_CHECK_COND(graph.has_standard_axis_map(output)); - } - - // Normalize axis - int axis_val = static_cast(graph.get_int(axis)); - const auto input_sizes = graph.sizes_of(input); - int ndim = graph.dim_of(input); - if (axis_val < 0) { - axis_val += ndim; - } - - // Verify axis is valid - VK_CHECK_COND(axis_val >= 0 && axis_val < ndim); - - // Get number of channels along the specified axis - int64_t num_channels = input_sizes[axis_val]; - - const auto scale_sizes = graph.sizes_of(scale); - const auto zero_point_sizes = graph.sizes_of(zero_point); - - // Calculate total number of elements in scale and zero_point tensors - int64_t scale_numel = 1; - for (size_t i = 0; i < scale_sizes.size(); i++) { - scale_numel *= scale_sizes[i]; - } - - int64_t zero_point_numel = 1; - for (size_t i = 0; i < zero_point_sizes.size(); i++) { - zero_point_numel *= zero_point_sizes[i]; - } - - // Check that the total number of elements matches num_channels - VK_CHECK_COND(scale_numel == num_channels); - VK_CHECK_COND(zero_point_numel == num_channels); - - add_dequantize_per_channel_node( - graph, input, scale, zero_point, axis, quant_min, quant_max, output); -} - -void dequantize_affine_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef block_size = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; - const ValueRef zero_point = args[arg_idx++]; - const ValueRef input_dtype = args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef output_dtype = args[arg_idx++]; - const ValueRef output = args[arg_idx++]; - - // Suppress unused variable warnings - (void)input_dtype; - (void)output_dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale)); - VK_CHECK_COND(graph.val_is_tensor(zero_point)); - VK_CHECK_COND(graph.val_is_tensor(output)); - - // Verify input is an integer type - VK_CHECK_COND( - 
graph.dtype_of(input) == vkapi::kByte || - graph.dtype_of(input) == vkapi::kChar || - graph.dtype_of(input) == vkapi::kInt); - - // Get scale and zero point dtypes - vkapi::ScalarType scale_dtype = graph.dtype_of(scale); - vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); - - // Verify supported types for scale (fp32 only for now) - VK_CHECK_COND(scale_dtype == vkapi::kFloat); - - // Verify supported types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || - zero_point_dtype == vkapi::kFloat); - - // Check that scale and zero_point have buffer storage and width packing - VK_CHECK_COND(graph.is_buffer_storage(scale)); - VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); - VK_CHECK_COND(graph.is_buffer_storage(zero_point)); - VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); - - // Check that tensors with texture storage have standard axis map - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.has_standard_axis_map(input)); - } - if (!graph.is_buffer_storage(output)) { - VK_CHECK_COND(graph.has_standard_axis_map(output)); - } - - // Verify block_size is valid (each dimension must divide evenly into input - // size) - const auto input_sizes = graph.sizes_of(input); - const auto block_size_list = graph.get_int_list(block_size); - VK_CHECK_COND(block_size_list->size() == input_sizes.size()); - - for (size_t i = 0; i < input_sizes.size(); i++) { - if ((*block_size_list)[i] > 1) { - VK_CHECK_COND( - input_sizes[i] % (*block_size_list)[i] == 0, - "Input size at dimension ", - i, - " (", - input_sizes[i], - ") must be divisible by block_size at dimension ", - i, - " (", - (*block_size_list)[i], - ")"); - } - } - - add_dequantize_block_wise_node( - graph, - input, - block_size, - scale, - zero_point, - quant_min, - quant_max, - output); -} - -REGISTER_OPERATORS { - VK_REGISTER_OP( - quantized_decomposed.dequantize_per_tensor.tensor, - dequantize_per_tensor_impl); - VK_REGISTER_OP( - quantized_decomposed.dequantize_per_token.default, - dequantize_per_token_impl); - VK_REGISTER_OP( - quantized_decomposed.dequantize_per_channel.default, - dequantize_per_channel_impl); - - // TorchAO affine dequantization operators - VK_REGISTER_OP(torchao.dequantize_affine.default, dequantize_affine_impl); -} - -} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp b/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp index 475e7796b09..61d27d48f6c 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp @@ -36,14 +36,66 @@ void check_embedding_args( VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kChannelsDim); } +void resize_embedding_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef indices = args.at(1).refs.at(0); + const ValueRef weight = args.at(1).refs.at(1); + + const std::vector indices_sizes = graph->sizes_of(indices); + const std::vector weight_sizes = graph->sizes_of(weight); + + // Output shape is indices.shape + [embedding_dim] + // where embedding_dim is the last dimension of weight + std::vector out_sizes = indices_sizes; + out_sizes.push_back(weight_sizes.back()); + + graph->virtual_resize(out, out_sizes); +} + void add_embedding_node( + ComputeGraph& graph, + const ValueRef indices, + const ValueRef weight, + const ValueRef out) { + std::string 
kernel_name = "embedding"; + kernel_name.reserve(kShaderNameReserve); + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + vkapi::ParamsBindList param_ubos = { + graph.meta_ubo(out), graph.meta_ubo(indices), graph.meta_ubo(weight)}; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {{indices, weight}, vkapi::kRead}}, + // Shader params buffers + param_ubos, + // Push Constants + {}, + // Specialization Constants + {}, + // Resize Args + {}, + // Resizing Logic + resize_embedding_node)); +} + +void add_embedding_legacy_node( ComputeGraph& graph, ValueRef weight, ValueRef in, ValueRef out) { check_embedding_args(graph, weight, in, out); - std::string kernel_name = "embedding"; + std::string kernel_name = "embedding_legacy"; kernel_name.reserve(kShaderNameReserve); add_dtype_suffix(kernel_name, graph.dtype_of(out)); @@ -69,16 +121,25 @@ void add_embedding_node( } void embedding(ComputeGraph& graph, const std::vector& args) { - ValueRef in = args[1]; + ValueRef weight_data = args[0]; + ValueRef indices = args[1]; ValueRef out = args[5]; - ValueRef weight = prepack_standard( - graph, - args[0], - StorageType::TEXTURE_2D, - GPUMemoryLayout::TENSOR_HEIGHT_PACKED); + // Legacy implementation that accepts channels packed texture tensors for + // input/output. Needed to support some old models still in circulation. + if (graph.is_standard_channels_packed_texture_tensor(indices)) { + ValueRef weight = prepack_standard( + graph, weight_data, utils::kTexture2D, utils::kHeightPacked); + + add_embedding_legacy_node(graph, weight, indices, out); + return; + } + + ValueRef weight = + prepack_standard(graph, weight_data, utils::kBuffer, utils::kWidthPacked); - add_embedding_node(graph, weight, in, out); + // New implementation for contiguous buffer and width-packed texture tensors + add_embedding_node(graph, indices, weight, out); } REGISTER_OPERATORS { diff --git a/backends/vulkan/runtime/graph/ops/impl/Gather.cpp b/backends/vulkan/runtime/graph/ops/impl/Gather.cpp new file mode 100644 index 00000000000..584a8d0437b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Gather.cpp @@ -0,0 +1,89 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include + +#include +#include + +#include + +#include + +namespace vkcompute { + +using utils::GPUMemoryLayout; +using utils::StorageType; + +void resize_gather_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef index = args.at(1).refs.at(1); + + // Output shape is the same as index shape + std::vector out_sizes = graph->sizes_of(index); + graph->virtual_resize(out, out_sizes); +} + +void add_gather_node( + ComputeGraph& graph, + const ValueRef input, + const int64_t dim, + const ValueRef index, + const ValueRef out) { + std::string kernel_name = "gather"; + kernel_name.reserve(kShaderNameReserve); + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + vkapi::ParamsBindList param_ubos = { + graph.meta_ubo(out), graph.meta_ubo(input), graph.meta_ubo(index)}; + + const int64_t dim_whcn = graph.dim_of(input) - dim - 1; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {{input, index}, vkapi::kRead}}, + // Shader params buffers + param_ubos, + // Push Constants + {}, + // Specialization Constants + {static_cast(dim_whcn)}, + // Resize Args + {}, + // Resizing Logic + resize_gather_node)); +} + +void gather(ComputeGraph& graph, const std::vector& args) { + ValueRef input = args[0]; + ValueRef dim_ref = args[1]; + ValueRef index = args[2]; + ValueRef out = args[4]; + + int64_t dim = graph.extract_scalar(dim_ref); + + add_gather_node(graph, input, dim, index, out); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.gather.default, gather); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp index 38d70271f4f..67c3f377f0c 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp @@ -182,16 +182,14 @@ void add_addmm_naive_texture_node( // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1, mat2, self}, vkapi::kRead}}, // Shader params buffers - { - graph.sizes_ubo(out), - graph.logical_limits_ubo(out), - graph.sizes_ubo(mat1), - graph.sizes_ubo(mat2), - graph.sizes_ubo(self), - graph.create_params_buffer(params), - }, - // Push Constants {}, + // Push Constants + {graph.sizes_pc_of(out), + graph.sizes_pc_of(mat1), + graph.sizes_pc_of(mat2), + graph.logical_limits_pc_of(out), + graph.sizes_pc_of(self), + PushConstantDataInfo(¶ms, sizeof(params))}, // Specialization Constants {graph.hashed_layout_of(out), graph.hashed_layout_of(mat1), diff --git a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp index 47ecf5f18d2..6c687ec67a8 100644 --- a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp @@ -162,14 +162,12 @@ void add_matmul_naive_texture3d_node( // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}}, // Shader params buffers - { - graph.sizes_ubo(out), - graph.logical_limits_ubo(out), - graph.sizes_ubo(mat1), - graph.sizes_ubo(mat2), - }, - // Push Constants {}, + // Push Constants + {graph.sizes_pc_of(out), + graph.sizes_pc_of(mat1), + graph.sizes_pc_of(mat2), + graph.logical_limits_pc_of(out)}, // Specialization Constants 
{graph.hashed_layout_of(out), graph.hashed_layout_of(mat1), diff --git a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp index 9ac4c963bc3..329620e80e6 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp @@ -109,11 +109,15 @@ void add_permute_node( { IntListPtr permute_dims_ptr = graph.get_int_list(permute_dims); const int32_t permute_ndim = - utils::safe_downcast(permute_dims_ptr->size()); + utils::safe_downcast(permute_dims_ptr->size()); for (int32_t nchw_i = permute_ndim - 1, whcn_i = 0; nchw_i >= 0; nchw_i--, whcn_i++) { - const int32_t permute_dim_nchw = permute_dims_ptr->at(nchw_i); + int32_t permute_dim_nchw = + utils::safe_downcast(permute_dims_ptr->at(nchw_i)); + if (permute_dim_nchw < 0) { + permute_dim_nchw += permute_ndim; + } const int32_t permute_dim_whcn = permute_ndim - 1 - permute_dim_nchw; whcn_permute_dims[whcn_i] = permute_dim_whcn; diff --git a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp index 250fcdd5490..d405825fad1 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp @@ -59,7 +59,11 @@ void resize_pool2d_node( if (is_max_pool2d) { const ValueRef indices = args.at(0).refs.at(1); - graph->virtual_resize(indices, new_out_sizes); + // For max_pool2d variant, indices tensor will be a 0-dim tensor - only + // resize the indices tensor if this is not the case. + if (graph->sizes_of(indices).size() > 0) { + graph->virtual_resize(indices, new_out_sizes); + } } } @@ -137,7 +141,7 @@ void max_pool2d(ComputeGraph& graph, const std::vector& args) { struct DivisorParams final { int32_t divisor_override; - bool count_include_pad; + int32_t count_include_pad; }; DivisorParams create_divisor_params( @@ -148,7 +152,7 @@ DivisorParams create_divisor_params( graph.val_is_int(divisor_override) ? static_cast(graph.get_int(divisor_override)) : 0, - graph.get_bool(count_include_pad)}; + int32_t(graph.get_bool(count_include_pad))}; } void add_avg_pool2d_node( diff --git a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp deleted file mode 100644 index 88f77261f4f..00000000000 --- a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp +++ /dev/null @@ -1,836 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include -#include -#include - -#include -#include - -#include - -namespace vkcompute { - -void resize_quantize_node( - ComputeGraph* graph, - const std::vector& args, - const std::vector& extra_args) { - (void)extra_args; - - const ValueRef out = args.at(0).refs.at(0); - const ValueRef in = args.at(1).refs.at(0); - - const std::vector in_sizes = graph->sizes_of(in); - graph->virtual_resize(out, in_sizes); -} - -utils::uvec3 quantize_per_channel_local_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const utils::uvec3& global_workgroup_size, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)args; - (void)resize_args; - - const ValueRef input = args.at(1).refs.at(0); - - utils::uvec3 local_wg_size = - graph->create_local_wg_size(global_workgroup_size); - - // WORKAROUND: The CommandBuffer::dispatch function divides - // global_workgroup_size by local_workgroup_size to get the number of - // workgroups to dispatch. For per-channel quantization along the batch axis, - // we need to ensure that we dispatch the correct number of workgroups in the - // Z dimension to cover all batch-channel combinations. - // - // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], - // local_wg_size[2]) might reduce the number of workgroups dispatched. To - // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, - // we set local_wg_size[2] = 1. - const auto input_sizes = graph->sizes_of(input); - if (input_sizes.size() == 4 && !graph->is_buffer_storage(input) && - global_workgroup_size[2] > 1) { - local_wg_size[2] = 1; - } - - return local_wg_size; -} - -utils::uvec3 quantize_block_wise_local_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const utils::uvec3& global_workgroup_size, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - const ValueRef input = args.at(1).refs.at(0); - - utils::uvec3 local_wg_size = - graph->create_local_wg_size(global_workgroup_size); - - // WORKAROUND: The CommandBuffer::dispatch function divides - // global_workgroup_size by local_workgroup_size to get the number of - // workgroups to dispatch. For per-channel quantization along the batch axis, - // we need to ensure that we dispatch the correct number of workgroups in the - // Z dimension to cover all batch-channel combinations. - // - // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], - // local_wg_size[2]) might reduce the number of workgroups dispatched. To - // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, - // we set local_wg_size[2] = 1. 
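The workaround above comes down to how the dispatch count is derived from the global and local workgroup sizes. A standalone sketch of that arithmetic, using a hypothetical Z extent and a plain-int stand-in for the runtime's div_up helper:

// Sketch only: not part of this diff. div_up mirrors the rounding-up division
// that the dispatch call applies per axis.
#include <cstdint>
#include <iostream>

static uint32_t div_up(uint32_t n, uint32_t d) {
  return (n + d - 1) / d;
}

int main() {
  // Hypothetical global Z extent: 6 batch-channel combinations to cover.
  const uint32_t global_z = 6;

  // With local_z = 4, only div_up(6, 4) = 2 workgroups are dispatched in Z,
  // fewer than the 6 the shader's indexing expects.
  std::cout << "local_z = 4 -> workgroups_z = " << div_up(global_z, 4) << "\n";

  // Forcing local_z = 1 dispatches div_up(6, 1) = 6 workgroups in Z, one per
  // batch-channel combination, which is what the workaround guarantees.
  std::cout << "local_z = 1 -> workgroups_z = " << div_up(global_z, 1) << "\n";
  return 0;
}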
- const auto input_sizes = graph->sizes_of(input); - if (input_sizes.size() == 4 && !graph->is_buffer_storage(input) && - global_workgroup_size[2] > 1) { - local_wg_size[2] = 1; - } - - return local_wg_size; -} - -void add_quantize_per_tensor_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& scale, - const ValueRef& zero_point, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& output) { - std::string kernel_name("quantize_per_tensor"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(output)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(output)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must be an integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.numel_ubo(input), - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(output), - graph.strides_ubo(output)}; - } else { - param_ubos = { - graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; - } - - push_constants = { - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(output), - graph.hashed_layout_of(input), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - default_pick_local_wg_size, - // Inputs and Outputs - {{output, vkapi::kWrite}, - {input, vkapi::kRead}, - {{scale, zero_point}, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - spec_vars, - // Resize Args - {}, - // Resizing Logic - resize_quantize_node)); -} - -void add_quantize_per_token_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& scale, - const ValueRef& zero_point, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& output) { - std::string kernel_name("quantize_per_token"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(output)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(output)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must be an 
integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - - int num_tokens = static_cast(graph.sizes_of(scale)[0]); - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.numel_ubo(input), - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(output), - graph.strides_ubo(output), - }; - push_constants = { - PushConstantDataInfo(&num_tokens, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - } else { - param_ubos = { - graph.logical_limits_ubo(input), - graph.logical_limits_ubo(output), - }; - push_constants = { - PushConstantDataInfo(&num_tokens, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - } - - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(output), - graph.hashed_layout_of(input), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - default_pick_local_wg_size, - // Inputs and Outputs - {{output, vkapi::kWrite}, - {input, vkapi::kRead}, - {{scale, zero_point}, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - spec_vars, - // Resize Args - {}, - // Resizing Logic - resize_quantize_node)); -} - -void add_quantize_per_channel_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& scale, - const ValueRef& zero_point, - const ValueRef& axis, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& output) { - std::string kernel_name("quantize_per_channel"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(output)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - - int axis_val = static_cast(graph.get_int(axis)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(output)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must be an integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - - // Normalize axis and convert from NCHW to WHCN using utility functions - const auto input_sizes = graph.sizes_of(input); - const int64_t ndim = graph.dim_of(input); - - // Normalize axis to handle negative indices - axis_val = normalize(axis_val, ndim); - - // Convert from NCHW axis to WHCN axis for shader (vulkan representation) - 
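The axis handling here relies on the same NCHW-to-WHCN dimension mapping used in the Permute change earlier in this diff. A standalone sketch of that mapping, with a local stand-in for nchw_dim_to_whcn_dim:

// Sketch only: not part of this diff. WHCN counts dimensions from the
// fastest-moving axis, so an NCHW dim d maps to ndim - 1 - d.
#include <cstdint>
#include <iostream>

static int64_t nchw_dim_to_whcn_dim_sketch(const int64_t nchw_dim, const int64_t ndim) {
  return ndim - 1 - nchw_dim;
}

int main() {
  const int64_t ndim = 4;
  // N=0, C=1, H=2, W=3 in NCHW order become 3, 2, 1, 0 in WHCN order.
  for (int64_t d = 0; d < ndim; ++d) {
    std::cout << "nchw dim " << d << " -> whcn dim "
              << nchw_dim_to_whcn_dim_sketch(d, ndim) << "\n";
  }
  return 0;
}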
int axis_whcn = nchw_dim_to_whcn_dim(axis_val, ndim); - - int num_channels; - if (axis_val == 0 && ndim == 4 && !graph.is_buffer_storage(input)) { - // For batch dimension quantization in 4D tensors, pass the actual number of - // channels so the shader can correctly unfold the batch-channel folding - num_channels = static_cast(input_sizes[1]); // Channel dimension - } else { - num_channels = static_cast(input_sizes[axis_val]); - } - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.numel_ubo(input), - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(output), - graph.strides_ubo(output), - }; - push_constants = { - PushConstantDataInfo(&axis_whcn, sizeof(int)), - PushConstantDataInfo(&num_channels, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - } else { - param_ubos = { - graph.logical_limits_ubo(input), - graph.logical_limits_ubo(output), - }; - push_constants = { - PushConstantDataInfo(&axis_whcn, sizeof(int)), - PushConstantDataInfo(&num_channels, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - } - - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(output), - graph.hashed_layout_of(input), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - quantize_per_channel_local_wg_size, - // Inputs and Outputs - {{output, vkapi::kWrite}, - {input, vkapi::kRead}, - {{scale, zero_point}, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - spec_vars, - // Resize Args - {}, - // Resizing Logic - resize_quantize_node)); -} - -void add_quantize_block_wise_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& block_size, - const ValueRef& scale, - const ValueRef& zero_point, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& output) { - std::string kernel_name("quantize_block_wise"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(output)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(output)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must be an integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - - const auto input_sizes = graph.sizes_of(input); - const auto block_size_list = graph.get_int_list(block_size); - - // Convert PyTorch dimensions to WHCN order for shader - utils::ivec4 block_size_vec = utils::make_whcn_ivec4(*block_size_list); - utils::ivec4 tensor_size_whcn = 
utils::make_whcn_ivec4(input_sizes); - - // Calculate numBlocks: tensorSize / blockSize (both in WHCN order) - utils::ivec4 num_blocks_vec = { - tensor_size_whcn[0] / block_size_vec[0], - tensor_size_whcn[1] / block_size_vec[1], - tensor_size_whcn[2] / block_size_vec[2], - tensor_size_whcn[3] / block_size_vec[3]}; - - // Calculate blockStride: pre-computed linear strides for the block grid - utils::ivec4 block_stride_vec = { - 1, - num_blocks_vec[0], - num_blocks_vec[0] * num_blocks_vec[1], - num_blocks_vec[0] * num_blocks_vec[1] * num_blocks_vec[2]}; - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.numel_ubo(input), - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(output), - graph.strides_ubo(output)}; - } else { - param_ubos = { - graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; - } - - push_constants = { - PushConstantDataInfo(&block_size_vec, sizeof(block_size_vec)), - PushConstantDataInfo(&num_blocks_vec, sizeof(num_blocks_vec)), - PushConstantDataInfo(&block_stride_vec, sizeof(block_stride_vec)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(output), - graph.hashed_layout_of(input), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - quantize_block_wise_local_wg_size, - // Inputs and Outputs - {{output, vkapi::kWrite}, - {input, vkapi::kRead}, - {{scale, zero_point}, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - spec_vars, - // Resize Args - {}, - // Resizing Logic - resize_quantize_node)); -} - -void quantize_per_tensor_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; - const ValueRef zero_point = args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; - const ValueRef output = args[arg_idx++]; - - // Suppress unused variable warning - dtype is inferred from output - (void)dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale)); - VK_CHECK_COND(graph.val_is_tensor(zero_point)); - VK_CHECK_COND(graph.val_is_tensor(output)); - - // Verify input is a floating point type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kDouble || - graph.dtype_of(input) == vkapi::kFloat || - graph.dtype_of(input) == vkapi::kHalf); - - // Get scale and zero point dtypes - vkapi::ScalarType scale_dtype = graph.dtype_of(scale); - vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); - - // Verify supported types for scale (fp32 only for now) - VK_CHECK_COND(scale_dtype == vkapi::kFloat); - - // Verify supported types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || - zero_point_dtype == vkapi::kFloat); - - add_quantize_per_tensor_node( - graph, input, scale, zero_point, quant_min, quant_max, output); -} - -void quantize_per_token_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; - const ValueRef zero_point = 
args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; - const ValueRef output = args[arg_idx++]; - - // Suppress unused variable warning - dtype is inferred from output - (void)dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale)); - VK_CHECK_COND(graph.val_is_tensor(zero_point)); - VK_CHECK_COND(graph.val_is_tensor(output)); - - // Verify input is a floating point type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kDouble || - graph.dtype_of(input) == vkapi::kFloat || - graph.dtype_of(input) == vkapi::kHalf); - - // Get scale and zero point dtypes - vkapi::ScalarType scale_dtype = graph.dtype_of(scale); - vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); - - // Verify supported types for scale (fp32 only for now) - VK_CHECK_COND(scale_dtype == vkapi::kFloat); - - // Verify supported types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || - zero_point_dtype == vkapi::kFloat); - - // Check that scale and zero_point have buffer storage and width packing - VK_CHECK_COND(graph.is_buffer_storage(scale)); - VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); - VK_CHECK_COND(graph.is_buffer_storage(zero_point)); - VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); - - // Check that tensors with texture storage have standard axis map - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.has_standard_axis_map(input)); - } - if (!graph.is_buffer_storage(output)) { - VK_CHECK_COND(graph.has_standard_axis_map(output)); - } - - // Calculate number of tokens (product of all dimensions except the last one) - int64_t num_tokens = 1; - const auto input_sizes = graph.sizes_of(input); - for (size_t i = 0; i < input_sizes.size() - 1; i++) { - num_tokens *= input_sizes[i]; - } - - const auto scale_sizes = graph.sizes_of(scale); - const auto zero_point_sizes = graph.sizes_of(zero_point); - - // Calculate total number of elements in scale and zero_point tensors - int64_t scale_numel = 1; - for (size_t i = 0; i < scale_sizes.size(); i++) { - scale_numel *= scale_sizes[i]; - } - - int64_t zero_point_numel = 1; - for (size_t i = 0; i < zero_point_sizes.size(); i++) { - zero_point_numel *= zero_point_sizes[i]; - } - - // Check that the total number of elements matches num_tokens - // This allows for both 1D tensors (size [num_tokens]) and reshaped tensors - // (size [num_tokens, 1]) - VK_CHECK_COND(scale_numel == num_tokens); - VK_CHECK_COND(zero_point_numel == num_tokens); - - add_quantize_per_token_node( - graph, input, scale, zero_point, quant_min, quant_max, output); -} - -void quantize_per_channel_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; - const ValueRef zero_point = args[arg_idx++]; - const ValueRef axis = args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; - const ValueRef output = args[arg_idx++]; - - // Suppress unused variable warning - dtype is inferred from output - (void)dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale)); - VK_CHECK_COND(graph.val_is_tensor(zero_point)); - VK_CHECK_COND(graph.val_is_tensor(output)); - - // Verify 
input is a floating point type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kDouble || - graph.dtype_of(input) == vkapi::kFloat || - graph.dtype_of(input) == vkapi::kHalf); - - // Get scale and zero point dtypes - vkapi::ScalarType scale_dtype = graph.dtype_of(scale); - vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); - - // Verify supported types for scale (fp32 only for now) - VK_CHECK_COND(scale_dtype == vkapi::kFloat); - - // Verify supported types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || - zero_point_dtype == vkapi::kFloat); - - // Check that scale and zero_point have buffer storage and width packing - VK_CHECK_COND(graph.is_buffer_storage(scale)); - VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); - VK_CHECK_COND(graph.is_buffer_storage(zero_point)); - VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); - - // Check that tensors with texture storage have standard axis map - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.has_standard_axis_map(input)); - } - if (!graph.is_buffer_storage(output)) { - VK_CHECK_COND(graph.has_standard_axis_map(output)); - } - - // Normalize axis - int axis_val = static_cast(graph.get_int(axis)); - const auto input_sizes = graph.sizes_of(input); - int64_t ndim = graph.dim_of(input); - if (axis_val < 0) { - axis_val += ndim; - } - - // Verify axis is valid - VK_CHECK_COND(axis_val >= 0 && axis_val < ndim); - - // Get number of channels along the specified axis - int64_t num_channels = input_sizes[axis_val]; - - const auto scale_sizes = graph.sizes_of(scale); - const auto zero_point_sizes = graph.sizes_of(zero_point); - - // Calculate total number of elements in scale and zero_point tensors - int64_t scale_numel = 1; - for (size_t i = 0; i < scale_sizes.size(); i++) { - scale_numel *= scale_sizes[i]; - } - - int64_t zero_point_numel = 1; - for (size_t i = 0; i < zero_point_sizes.size(); i++) { - zero_point_numel *= zero_point_sizes[i]; - } - - // Check that the total number of elements matches num_channels - VK_CHECK_COND(scale_numel == num_channels); - VK_CHECK_COND(zero_point_numel == num_channels); - - add_quantize_per_channel_node( - graph, input, scale, zero_point, axis, quant_min, quant_max, output); -} - -void quantize_affine_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef block_size = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; - const ValueRef zero_point = args[arg_idx++]; - const ValueRef output_dtype = args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef output = args[arg_idx++]; - - // Suppress unused variable warnings - (void)output_dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale)); - VK_CHECK_COND(graph.val_is_tensor(zero_point)); - VK_CHECK_COND(graph.val_is_tensor(output)); - - // Verify input is a floating point type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kDouble || - graph.dtype_of(input) == vkapi::kFloat || - graph.dtype_of(input) == vkapi::kHalf); - - // Get scale and zero point dtypes - vkapi::ScalarType scale_dtype = graph.dtype_of(scale); - vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); - - // Verify supported types for scale (fp32 only for now) - VK_CHECK_COND(scale_dtype == vkapi::kFloat); - - // Verify supported 
types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || - zero_point_dtype == vkapi::kFloat); - - // Check that scale and zero_point have buffer storage and width packing - VK_CHECK_COND(graph.is_buffer_storage(scale)); - VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); - VK_CHECK_COND(graph.is_buffer_storage(zero_point)); - VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); - - // Check that tensors with texture storage have standard axis map - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.has_standard_axis_map(input)); - } - if (!graph.is_buffer_storage(output)) { - VK_CHECK_COND(graph.has_standard_axis_map(output)); - } - - // Verify block_size is valid (each dimension must divide evenly into input - // size) - const auto input_sizes = graph.sizes_of(input); - const auto block_size_list = graph.get_int_list(block_size); - VK_CHECK_COND(block_size_list->size() == input_sizes.size()); - - for (size_t i = 0; i < input_sizes.size(); i++) { - if ((*block_size_list)[i] > 1) { - VK_CHECK_COND( - input_sizes[i] % (*block_size_list)[i] == 0, - "Input size at dimension ", - i, - " (", - input_sizes[i], - ") must be divisible by block_size at dimension ", - i, - " (", - (*block_size_list)[i], - ")"); - } - } - - add_quantize_block_wise_node( - graph, - input, - block_size, - scale, - zero_point, - quant_min, - quant_max, - output); -} - -REGISTER_OPERATORS { - VK_REGISTER_OP( - quantized_decomposed.quantize_per_tensor.tensor, - quantize_per_tensor_impl); - VK_REGISTER_OP( - quantized_decomposed.quantize_per_token.default, quantize_per_token_impl); - VK_REGISTER_OP( - quantized_decomposed.quantize_per_channel.default, - quantize_per_channel_impl); - - // TorchAO affine quantization operators - VK_REGISTER_OP(torchao.quantize_affine.default, quantize_affine_impl); -} - -} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.cpp new file mode 100644 index 00000000000..8ebbf6dcb99 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.cpp @@ -0,0 +1,452 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include +#include + +namespace vkcompute { + +// +// General utilities +// + +bool is_gemv(ComputeGraph* graph, const ValueRef& fp_input) { + return graph->size_at(-2, fp_input) == 1; +} + +// +// Dispatch utilities (Linear) +// + +std::tuple get_quantized_input_num_blocks( + ComputeGraph& graph, + const ValueRef input) { + std::vector input_sizes = graph.sizes_of(input); + const int64_t ndim = graph.dim_of(input); + + const int64_t M = input_sizes.at(ndim - 2); + const int64_t K = input_sizes.at(ndim - 1); + + const int64_t num_blocks_M = utils::div_up(M, int64_t(4)); + const int64_t num_blocks_K = utils::div_up(K, int64_t(4)); + + return std::make_tuple(num_blocks_M, num_blocks_K); +} + +utils::uvec3 quantize_and_pack_4h4w_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef input = args.at(1).refs.at(0); + int64_t num_blocks_M, num_blocks_K; + std::tie(num_blocks_M, num_blocks_K) = + get_quantized_input_num_blocks(*graph, input); + + return { + utils::safe_downcast(num_blocks_K), + utils::safe_downcast(num_blocks_M), + 1u}; +} + +vkapi::ShaderInfo pick_quantize_and_pack_4h4w_with_group_sums_shader( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef packed_int_input = args.at(0).refs.at(0); + const ValueRef fp_input = args.at(1).refs.at(0); + const ValueRef group_size = resize_args.at(0); + + const int64_t group_size_val = graph->extract_scalar(group_size); + + std::string shader_name = "quantize_and_pack_4h4w_with_group_sums"; + if (group_size_val >= 128) { + shader_name += "_o2w32"; + } else { + shader_name += "_o4w16"; + } + + add_storage_type_suffix( + shader_name, graph->storage_type_of(packed_int_input)); + add_storage_type_suffix(shader_name, graph->storage_type_of(fp_input)); + add_dtype_suffix(shader_name, graph->dtype_of(fp_input)); + + return VK_KERNEL_FROM_STR(shader_name); +} + +utils::uvec3 pick_quantize_and_pack_4h4w_with_group_sums_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef fp_input = args.at(1).refs.at(0); + // For gemv cases, skip the quantize and pack input step in favor of computing + // the quantized linear as a weight only quantized linear operation. The + // rationale for this is that gemv is a memory bound operation and may not + // necessarily benefit from quantizing the input and computing with integer + // accumulation. 
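A standalone sketch of the shape check behind is_gemv, with hypothetical input sizes; the helper below stands in for graph->size_at(-2, fp_input) == 1:

// Sketch only: not part of this diff.
#include <cstdint>
#include <iostream>
#include <vector>

static bool is_gemv_like(const std::vector<int64_t>& input_sizes) {
  // A linear/matmul input of shape [..., M, K] with M == 1 is a
  // matrix-vector product (gemv).
  const size_t ndim = input_sizes.size();
  return ndim >= 2 && input_sizes[ndim - 2] == 1;
}

int main() {
  // Single-token decode style input: gemv, so input quantization is skipped
  // and the op runs as a weight-only quantized linear.
  std::cout << is_gemv_like({1, 1, 4096}) << "\n";  // 1
  // Prefill style input: gemm, so the quantize-and-pack step is dispatched.
  std::cout << is_gemv_like({1, 128, 4096}) << "\n";  // 0
  return 0;
}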
+ if (is_gemv(graph, fp_input)) { + return {0u, 0u, 0u}; + } + + const ValueRef group_size = resize_args.at(0); + int64_t num_blocks_M, num_blocks_K; + std::tie(num_blocks_M, num_blocks_K) = + get_quantized_input_num_blocks(*graph, fp_input); + + const int64_t group_size_val = graph->extract_scalar(group_size); + const int64_t blocks_per_group = group_size_val / 4; + + const int64_t num_groups = num_blocks_K / blocks_per_group; + + return { + utils::safe_downcast(num_groups), + utils::safe_downcast(num_blocks_M), + 1u}; +} + +utils::uvec3 pick_quantize_and_pack_4h4w_with_group_sums_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef fp_input = args.at(1).refs.at(0); + // For gemv, skip the quantize input step since the quantized linear is + // computed as a weight only quantized linear operation. + if (is_gemv(graph, fp_input)) { + return {1u, 1u, 1u}; + } + + uint32_t groups_per_wg = 2u; + uint32_t workers_per_group = 32u; + + if (shader.kernel_name.find("o4w16") != std::string::npos) { + groups_per_wg = 4u; + workers_per_group = 16u; + } + + return {groups_per_wg, 1u, workers_per_group}; +} + +// +// Dispatch logic (Linear) +// + +void add_quantize_and_pack_4h4w_node( + ComputeGraph& graph, + const QuantizationConfig& input_quant_config, + const ValueRef fp_input, + const ValueRef packed_input_scale, + const ValueRef packed_input_zp, + const ValueRef input_scale_data, + const ValueRef input_zp_data, + const ValueRef packed_int_input, + const ValueRef group_size) { + // Only certain quantization types supported at the moment + VK_CHECK_COND(input_quant_config.granularity == kPerTensor); + + int64_t num_blocks_M, num_blocks_K; + std::tie(num_blocks_M, num_blocks_K) = + get_quantized_input_num_blocks(graph, fp_input); + + float inv_scale = 1.0f / graph.extract_scalar(input_scale_data); + int32_t zp = graph.extract_scalar(input_zp_data); + + std::string shader_name = "quantize_and_pack_4h4w_per_tensor"; + add_storage_type_suffix(shader_name, graph.storage_type_of(packed_int_input)); + add_storage_type_suffix(shader_name, graph.storage_type_of(fp_input)); + add_dtype_suffix(shader_name, graph.dtype_of(fp_input)); + + vkapi::ParamsBindList param_buffers = {graph.sizes_ubo(fp_input)}; + + std::vector push_constants = { + PushConstantDataInfo(&inv_scale, sizeof(inv_scale)), + PushConstantDataInfo(&zp, sizeof(zp)), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(shader_name), + quantize_and_pack_4h4w_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{packed_int_input, vkapi::kWrite}, {fp_input, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize args + {})); +} + +void add_quantize_and_pack_4h4w_with_group_sums_node( + ComputeGraph& graph, + const QuantizationConfig& input_quant_config, + const ValueRef fp_input, + const ValueRef int_input_sums, + const ValueRef packed_input_scales, + const ValueRef packed_input_zps, + const ValueRef packed_int_input, + const ValueRef group_size) { + // Only certain quantization types supported at the moment + VK_CHECK_COND(input_quant_config.granularity == kPerChannel); + + int64_t num_blocks_M, num_blocks_K; + std::tie(num_blocks_M, num_blocks_K) = + get_quantized_input_num_blocks(graph, fp_input); + + 
vkapi::ParamsBindList param_buffers = {graph.sizes_ubo(fp_input)}; + + const int32_t group_size_val = graph.extract_scalar(group_size); + const int32_t blocks_per_group = utils::div_up(group_size_val, int32_t(4)); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + pick_quantize_and_pack_4h4w_with_group_sums_shader, + pick_quantize_and_pack_4h4w_with_group_sums_global_wg_size, + pick_quantize_and_pack_4h4w_with_group_sums_local_wg_size, + // Inputs and Outputs + {{{packed_int_input, int_input_sums}, vkapi::kWrite}, + {{fp_input, packed_input_scales, packed_input_zps}, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + {}, + // Specialization Constants + {blocks_per_group}, + // Resize args + {group_size})); +} + +// +// Dispatch utilities (Conv2d) +// + +utils::uvec3 pick_quantize_and_pack_4w4c_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef fp_input = args.at(1).refs.at(0); + + const uint32_t W = graph->size_at(-1, fp_input); + const uint32_t H = graph->size_at(-2, fp_input); + const uint32_t C = graph->size_at(-3, fp_input); + + const uint32_t W4 = utils::div_up_4(W); + const uint32_t C4 = utils::div_up_4(C); + + return {W4, H, C4}; +} + +utils::uvec3 pick_unpack_4w4c_and_dequantize_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef fp_output = args.at(0).refs.at(0); + + const uint32_t W = graph->size_at(-1, fp_output); + const uint32_t H = graph->size_at(-2, fp_output); + const uint32_t C = graph->size_at(-3, fp_output); + + const uint32_t W4 = utils::div_up_4(W); + const uint32_t C4 = utils::div_up_4(C); + + return {W4, H, C4}; +} + +// +// Dispatch logic (Conv2d) +// + +void add_quantize_and_pack_4w4c_node( + ComputeGraph& graph, + const ValueRef fp_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef packed_int8_input) { + float inv_scale = 1.0f / graph.extract_scalar(input_scale); + int32_t zp = graph.extract_scalar(input_zp); + + // Get shader for quantized conv2d linear tiled + std::string kernel_name = "quantize_and_pack_4w4c_per_tensor"; + add_storage_type_suffix( + kernel_name, graph.storage_type_of(packed_int8_input)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(fp_input)); + add_dtype_suffix(kernel_name, graph.dtype_of(fp_input)); + + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + + vkapi::ParamsBindList param_buffers = {graph.sizes_ubo(fp_input)}; + + std::vector push_constants = { + PushConstantDataInfo(&inv_scale, sizeof(inv_scale)), + PushConstantDataInfo(&zp, sizeof(zp)), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + pick_quantize_and_pack_4w4c_global_wg_size, + pick_wc_square_wg_size, + // Inputs and Outputs + {{packed_int8_input, vkapi::kWrite}, {fp_input, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize args + {}, + // Resizing Logic + nullptr)); +} + +void add_unpack_4w4c_and_dequantize_node( + ComputeGraph& graph, + const ValueRef packed_int8_output, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef fp_output) { + float scale = graph.extract_scalar(output_scale); + int32_t zp = graph.extract_scalar(output_zp); + + // Get shader for quantized conv2d linear tiled 
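A standalone sketch of the block math used by the 4W4C dispatch helpers above, with hypothetical activation extents and a local stand-in for utils::div_up_4:

// Sketch only: not part of this diff. One invocation of the quantize/pack
// shader covers one 4-wide by 4-channel block of a single row.
#include <cstdint>
#include <iostream>

static uint32_t div_up_4(uint32_t n) {
  return (n + 3u) / 4u;
}

int main() {
  // Hypothetical activation extents (W x H x C).
  const uint32_t W = 57, H = 31, C = 13;

  const uint32_t W4 = div_up_4(W);  // 15 width blocks
  const uint32_t C4 = div_up_4(C);  // 4 channel blocks

  // Matches the {W4, H, C4} global work size returned above.
  std::cout << "blocks: " << W4 << " x " << H << " x " << C4 << " = "
            << (W4 * H * C4) << "\n";
  return 0;
}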
+ std::string kernel_name = "unpack_4w4c_and_dequantize_per_tensor"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(fp_output)); + add_storage_type_suffix( + kernel_name, graph.storage_type_of(packed_int8_output)); + add_dtype_suffix(kernel_name, graph.dtype_of(fp_output)); + + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + + vkapi::ParamsBindList param_buffers = {graph.sizes_ubo(fp_output)}; + + std::vector push_constants = { + PushConstantDataInfo(&scale, sizeof(scale)), + PushConstantDataInfo(&zp, sizeof(zp)), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + pick_unpack_4w4c_and_dequantize_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{fp_output, vkapi::kWrite}, {packed_int8_output, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize args + {}, + // Resizing Logic + nullptr)); +} + +// +// Operator Entrypoints +// + +void quantize_per_tensor_impl( + ComputeGraph& graph, + const std::vector& args) { + size_t arg_idx = 0; + size_t last_arg_idx = args.size() - 1; + const ValueRef fp_input = args[arg_idx++]; + const ValueRef scale = args[arg_idx++]; + const ValueRef zero_point = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + (void)quant_min; + const ValueRef quant_max = args[arg_idx++]; + (void)quant_max; + const ValueRef dtype = args[arg_idx++]; + (void)dtype; + + const ValueRef int8_output = args[last_arg_idx]; + + VK_CHECK_COND( + graph.estimate_memory_layout_of(int8_output) == utils::kPackedInt8_4W4C); + + add_quantize_and_pack_4w4c_node( + graph, fp_input, scale, zero_point, int8_output); +} + +void dequantize_per_tensor_impl( + ComputeGraph& graph, + const std::vector& args) { + size_t arg_idx = 0; + size_t last_arg_idx = args.size() - 1; + const ValueRef int8_input = args[arg_idx++]; + const ValueRef scale = args[arg_idx++]; + const ValueRef zero_point = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + (void)quant_min; + const ValueRef quant_max = args[arg_idx++]; + (void)quant_max; + const ValueRef dtype = args[arg_idx++]; + (void)dtype; + const ValueRef output_dtype = args[arg_idx++]; + (void)output_dtype; + + const ValueRef fp_output = args[last_arg_idx]; + + VK_CHECK_COND( + graph.estimate_memory_layout_of(int8_input) == utils::kPackedInt8_4W4C); + + add_unpack_4w4c_and_dequantize_node( + graph, int8_input, scale, zero_point, fp_output); +} + +void qdq8ta_conv2d_input( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef fp_input = args.at(idx++); + const ValueRef scale = args.at(idx++); + const ValueRef zero_point = args.at(idx++); + const ValueRef fp_output = args.at(idx++); + + TmpTensor packed_int8_input( + &graph, + graph.sizes_of(fp_input), + vkapi::kInt8x4, + utils::kBuffer, + utils::kPackedInt8_4W4C); + + add_quantize_and_pack_4w4c_node( + graph, fp_input, scale, zero_point, packed_int8_input); + + add_unpack_4w4c_and_dequantize_node( + graph, packed_int8_input, scale, zero_point, fp_output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP( + quantized_decomposed.quantize_per_tensor.default, + quantize_per_tensor_impl); + VK_REGISTER_OP( + quantized_decomposed.dequantize_per_tensor.default, + dequantize_per_tensor_impl); + VK_REGISTER_OP(etvk.qdq8ta_conv2d_input.default, qdq8ta_conv2d_input); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.h 
b/backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.h new file mode 100644 index 00000000000..96e9cc7c1d3 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace vkcompute { + +// +// General utils +// + +bool is_gemv(ComputeGraph* graph, const ValueRef& fp_input); + +// +// Quantize, Dequantize for Linear/Matmul +// + +void add_quantize_and_pack_4h4w_node( + ComputeGraph& graph, + const QuantizationConfig& input_quant_config, + const ValueRef fp_input, + const ValueRef packed_input_scale, + const ValueRef packed_input_zp, + const ValueRef input_scale_data, + const ValueRef input_zp_data, + const ValueRef packed_int_input, + const ValueRef group_size); + +void add_quantize_and_pack_4h4w_with_group_sums_node( + ComputeGraph& graph, + const QuantizationConfig& input_quant_config, + const ValueRef fp_input, + const ValueRef int_input_sums, + const ValueRef packed_input_scales, + const ValueRef packed_input_zps, + const ValueRef packed_int_input, + const ValueRef group_size); + +// +// Quantize, Dequantize for Convolution +// + +void add_quantize_and_pack_4w4c_node( + ComputeGraph& graph, + const ValueRef fp_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef packed_int8_input); + +void add_unpack_4w4c_and_dequantize_node( + ComputeGraph& graph, + const ValueRef packed_int8_output, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef fp_output); + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedBinary.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedBinary.cpp new file mode 100644 index 00000000000..99b5880c2eb --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedBinary.cpp @@ -0,0 +1,210 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include +#include +#include + +namespace vkcompute { + +// +// Shader dispatch utilities +// + +utils::uvec3 pick_q8ta_q8ta_q8to_binary_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef packed_int8_output = args.at(0).refs.at(0); + + const uint32_t W = graph->size_at(-1, packed_int8_output); + const uint32_t H = graph->size_at(-2, packed_int8_output); + const uint32_t C = graph->size_at(-3, packed_int8_output); + + const uint32_t W4 = utils::div_up_4(W); + const uint32_t C4 = utils::div_up_4(C); + + return {W4 * H * C4, 1, 1}; +} + +// +// Dispatch nodes +// + +void add_q8ta_q8ta_q8to_binary_node( + ComputeGraph& graph, + const ValueRef packed_int8_input_a, + const ValueRef packed_int8_input_b, + const ValueRef input_a_scale, + const ValueRef input_a_zp, + const ValueRef input_b_scale, + const ValueRef input_b_zp, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef alpha, + const ValueRef packed_int8_output, + const std::string& op_name) { + float input_a_scale_val = graph.extract_scalar(input_a_scale); + int32_t input_a_zp_val = graph.extract_scalar(input_a_zp); + float input_b_scale_val = graph.extract_scalar(input_b_scale); + int32_t input_b_zp_val = graph.extract_scalar(input_b_zp); + + float output_inv_scale_val = 1.0f / graph.extract_scalar(output_scale); + int32_t output_zp_val = graph.extract_scalar(output_zp); + + float alpha_val = 1.0f; + // String is checked since some ops pass in an unused string argument in + // place of alpha + if (is_valid(alpha) && !graph.val_is_string(alpha)) { + alpha_val = graph.extract_scalar(alpha); + } + + std::string kernel_name = op_name + "_q8ta_q8ta_q8to"; + add_storage_type_suffix( + kernel_name, graph.storage_type_of(packed_int8_output)); + + vkapi::ParamsBindList param_buffers = {graph.sizes_ubo(packed_int8_output)}; + + std::vector push_constants = { + PushConstantDataInfo(&input_a_scale_val, sizeof(input_a_scale_val)), + PushConstantDataInfo(&input_a_zp_val, sizeof(input_a_zp_val)), + PushConstantDataInfo(&input_b_scale_val, sizeof(input_b_scale_val)), + PushConstantDataInfo(&input_b_zp_val, sizeof(input_b_zp_val)), + PushConstantDataInfo(&output_inv_scale_val, sizeof(output_inv_scale_val)), + PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)), + PushConstantDataInfo(&alpha_val, sizeof(alpha_val)), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + pick_q8ta_q8ta_q8to_binary_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{packed_int8_output, vkapi::kWrite}, + {{packed_int8_input_a, packed_int8_input_b}, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize args + {}, + // Resizing Logic + nullptr)); +} + +// +// High level operator impl +// + +void add_q8ta_q8ta_q8to( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef packed_int8_input_a = args.at(idx++); + const ValueRef packed_int8_input_b = args.at(idx++); + const ValueRef input_a_scale = args.at(idx++); + const ValueRef input_a_zp = args.at(idx++); + const ValueRef input_b_scale = args.at(idx++); + const ValueRef input_b_zp = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef alpha = args.at(idx++); + const ValueRef packed_int8_output 
= args.at(idx++); + + add_q8ta_q8ta_q8to_binary_node( + graph, + packed_int8_input_a, + packed_int8_input_b, + input_a_scale, + input_a_zp, + input_b_scale, + input_b_zp, + output_scale, + output_zp, + alpha, + packed_int8_output, + "add"); +} + +// +// Test operators +// + +void add_q8ta_q8ta_q8to_test( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef fp_input_a = args.at(idx++); + const ValueRef fp_input_b = args.at(idx++); + const ValueRef input_a_scale = args.at(idx++); + const ValueRef input_a_zp = args.at(idx++); + const ValueRef input_b_scale = args.at(idx++); + const ValueRef input_b_zp = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef alpha = args.at(idx++); + const ValueRef fp_output = args.at(idx++); + + TmpTensor packed_int8_input_a( + &graph, + graph.sizes_of(fp_input_a), + vkapi::kInt8x4, + utils::kBuffer, + utils::kPackedInt8_4W4C); + + TmpTensor packed_int8_input_b( + &graph, + graph.sizes_of(fp_input_b), + vkapi::kInt8x4, + utils::kBuffer, + utils::kPackedInt8_4W4C); + + TmpTensor packed_int8_output( + &graph, + graph.sizes_of(fp_output), + vkapi::kInt8x4, + utils::kBuffer, + utils::kPackedInt8_4W4C); + + add_quantize_and_pack_4w4c_node( + graph, fp_input_a, input_a_scale, input_a_zp, packed_int8_input_a); + + add_quantize_and_pack_4w4c_node( + graph, fp_input_b, input_b_scale, input_b_zp, packed_int8_input_b); + + std::vector add_args = { + packed_int8_input_a, + packed_int8_input_b, + input_a_scale, + input_a_zp, + input_b_scale, + input_b_zp, + output_scale, + output_zp, + alpha, + packed_int8_output}; + + add_q8ta_q8ta_q8to(graph, add_args); + + add_unpack_4w4c_and_dequantize_node( + graph, packed_int8_output, output_scale, output_zp, fp_output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.add_q8ta_q8ta_q8to.default, add_q8ta_q8ta_q8to); + VK_REGISTER_OP(et_vk.add_q8ta_q8ta_q8to.test, add_q8ta_q8ta_q8to_test); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp index 51f8138485e..d7d5ad6db1e 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp @@ -9,6 +9,8 @@ #include #include +#include +#include #include #include #include @@ -19,6 +21,86 @@ namespace vkcompute { // Utility functions // +bool is_pointwise(ComputeGraph* graph, const ValueRef& kernel_size) { + const auto kernel_size_list = graph->get_int_list(kernel_size); + return kernel_size_list->at(0) == 1 && kernel_size_list->at(1) == 1; +} + +bool is_s1p1d1( + ComputeGraph* graph, + const ValueRef& stride, + const ValueRef& padding, + const ValueRef& dilation) { + const auto stride_list = graph->get_int_list(stride); + const auto padding_list = graph->get_int_list(padding); + const auto dilation_list = graph->get_int_list(dilation); + if (stride_list->at(0) != 1 && stride_list->at(1) != 1) { + return false; + } + if (padding_list->at(0) != 1 && padding_list->at(1) != 1) { + return false; + } + if (dilation_list->at(0) != 1 && dilation_list->at(1) != 1) { + return false; + } + return true; +} + +bool is_s1p0d1_pointwise( + ComputeGraph* graph, + const ValueRef& kernel_size, + const ValueRef& stride, + const ValueRef& padding, + const ValueRef& dilation) { + if (is_pointwise(graph, kernel_size)) { + const auto stride_list = graph->get_int_list(stride); + const auto padding_list = 
graph->get_int_list(padding); + const auto dilation_list = graph->get_int_list(dilation); + if (stride_list->at(0) != 1 && stride_list->at(1) != 1) { + return false; + } + if (padding_list->at(0) != 0 && padding_list->at(1) != 0) { + return false; + } + if (dilation_list->at(0) != 1 && dilation_list->at(1) != 1) { + return false; + } + return true; + } + return false; +} + +bool should_use_im2col( + ComputeGraph* graph, + const ValueRef kernel_size, + const ValueRef groups) { + const auto kernel_size_list = graph->get_int_list(kernel_size); + + // Always use im2col for pointwise convolutions + if (kernel_size_list->at(0) * kernel_size_list->at(1) == 1) { + return true; + } + + // For large kernel sizes, the im2col matrix will be too big. Not only will + // this result in a larger footprint for the im2col matrix, but the cost of + // performing the im2col procedure will also become prohibitive. In these + // cases it is faster to just compute convolution directly without going + // through im2col. Empirically, im2col works well for 3x3 convolution and + // not for 5x5 convolution, so set the limit at 10. + if (kernel_size_list->at(0) * kernel_size_list->at(1) > 10) { + return false; + } + + // Only use im2col for non-grouped convolutions; manual experimentation shows + // that im2col becomes very slow when dealing with grouped convolutions. The + // reason for this is likely that memory access in the im2col shader becomes + // too non-linear due to needing to keep convolution groups contiguous + // in memory. This means that the channels of the input tensor (which are + // originally contiguous in memory) will be split up during the im2col + // procedure. + return graph->get_int(groups) == 1; +} + struct Conv2DParams { utils::ivec2 kernel_size; utils::ivec2 stride; @@ -135,6 +217,43 @@ std::vector calculate_input_im2col_sizes( return {M, K}; } +std::vector calculate_packed_int8_input_im2col_sizes( + ComputeGraph* graph, + const ValueRef& input, + const ValueRef& output, + const ValueRef& kernel_size, + const ValueRef& groups) { + std::vector in_sizes = graph->sizes_of(input); + const int64_t in_channels = utils::val_at(-3, in_sizes); + + std::vector out_sizes = graph->sizes_of(output); + const int64_t out_height = utils::val_at(-2, out_sizes); + const int64_t out_width = utils::val_at(-1, out_sizes); + + // Represents the number of channel groups + const int64_t groups_val = graph->extract_scalar(groups); + // No need to div_up because in_channels % groups_val = 0 + const int64_t in_channels_per_group = in_channels / groups_val; + + const auto kernel_size_list = graph->get_int_list(kernel_size); + + // Align to the next multiple of 4 to ensure that data loads align nicely with + // texel boundaries. We want to ensure that the first data element of each + // group is at the start of its texel. + const int64_t flattened_kernel_len = utils::align_up_4( + in_channels_per_group * kernel_size_list->at(0) * + kernel_size_list->at(1)); + + // K -> flattened convolution window (repeated for each group) + const int64_t K = flattened_kernel_len * groups_val; + // M -> number of elements in 2D output plane. This is aligned to the next + // multiple of 4 since the im2col shader operates on 4x4 blocks.
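A standalone walk-through of calculate_packed_int8_input_im2col_sizes with hypothetical conv parameters, using a local stand-in for utils::align_up_4:

// Sketch only: not part of this diff.
#include <cstdint>
#include <iostream>

static int64_t align_up_4(int64_t n) {
  return (n + 3) / 4 * 4;
}

int main() {
  // Hypothetical grouped conv: 26 input channels, 2 groups, 3x3 kernel,
  // output plane 17 wide by 9 high.
  const int64_t in_channels = 26, groups = 2;
  const int64_t K_h = 3, K_w = 3;
  const int64_t out_width = 17, out_height = 9;

  const int64_t in_channels_per_group = in_channels / groups;  // 13
  // Flattened window per group, padded so each group starts on a texel
  // boundary: 13 * 3 * 3 = 117 -> 120.
  const int64_t flattened_kernel_len =
      align_up_4(in_channels_per_group * K_h * K_w);
  const int64_t K = flattened_kernel_len * groups;  // 240
  const int64_t W = align_up_4(out_width);          // 20
  const int64_t H = out_height;                     // 9

  std::cout << "packed im2col sizes {K, H, W} = {" << K << ", " << H << ", "
            << W << "}\n";
  return 0;
}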
+ const int64_t W = utils::align_up_4(out_width); + const int64_t H = out_height; + + return {K, H, W}; +} + std::vector calculate_output_im2col_sizes( ComputeGraph* graph, const ValueRef& output) { @@ -178,6 +297,33 @@ utils::uvec3 im2col_global_wg_size( return {K4, M4, 1}; } +utils::uvec3 im2col_packed_int8_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef input_im2col = args.at(0).refs.at(0); + + std::vector im2col_sizes = graph->sizes_of(input_im2col); + const uint32_t K = utils::safe_downcast(im2col_sizes[0]); + const uint32_t H = utils::safe_downcast(im2col_sizes[1]); + const uint32_t W = utils::safe_downcast(im2col_sizes[2]); + + const uint32_t K4 = utils::div_up(K, 4u); + const uint32_t W4 = utils::div_up(W, 4u); + + return {K4 * W4 * H, 1, 1}; +} + +utils::uvec3 im2col_packed_int8_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + return {64, 1, 1}; +} + utils::uvec3 col2im_global_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, @@ -197,6 +343,229 @@ utils::uvec3 col2im_global_wg_size( return {N4, M4, 1}; } +utils::uvec3 pick_static_quantized_conv2d_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef packed_int8_output = args.at(0).refs.at(0); + + const uint32_t W = graph->size_at(-1, packed_int8_output); + const uint32_t H = graph->size_at(-2, packed_int8_output); + const uint32_t C = graph->size_at(-3, packed_int8_output); + + uint32_t C_per_tile = 4; + uint32_t W_per_tile = 4; + + if (shader.kernel_name.find("linear") != std::string::npos) { + C_per_tile = 8; + } + + const uint32_t num_W_tiles = utils::div_up(W, W_per_tile); + const uint32_t num_C_tiles = utils::div_up(C, C_per_tile); + + return {num_C_tiles, num_W_tiles, H}; +} + +utils::uvec3 pick_static_quantized_conv2d_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + return pick_hw_square_wg_size( + graph, shader, global_workgroup_size, args, resize_args); +} + +utils::uvec3 int8_conv2d_dw_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef packed_int8_output = args.at(0).refs.at(0); + + const uint32_t W = graph->size_at(-1, packed_int8_output); + const uint32_t H = graph->size_at(-2, packed_int8_output); + const uint32_t C = graph->size_at(-3, packed_int8_output); + + const uint32_t W4 = utils::div_up_4(W); + const uint32_t C4 = utils::div_up_4(C); + + return {C4 * W4 * H, 1, 1}; +} + +// +// Prepack nodes +// + +ValueRef prepack_quantized_conv2d_weight( + ComputeGraph& graph, + const QuantizationConfig& weight_quant_config, + const ValueRef weight_data, + const ValueRef input, + const ValueRef output, + const ValueRef groups, + const ValueRef kernel_size) { + VK_CHECK_COND(weight_quant_config.nbits == 8); + VK_CHECK_COND(weight_quant_config.is_symmetric); + + const int32_t groups_val = graph.get_int(groups); + + const int64_t OC = graph.size_at(-3, output); + const int64_t IC = graph.size_at(-3, input) / groups_val; + + int64_t K_h; + int64_t K_w; + + { + const auto kernel_size_list = graph.get_int_list(kernel_size); + K_h = 
kernel_size_list->at(0); + K_w = kernel_size_list->at(1); + } + + const int64_t num_blocks_OC = utils::div_up_4(OC); + const int64_t num_blocks_IC = utils::div_up_4(IC); + + const int64_t num_blocks_y = num_blocks_IC * K_h; + const int64_t num_blocks_x = K_w * num_blocks_OC; + + // The packed tensor arranges blocks as [OC_blocks * K_total, IC_blocks] + const int64_t output_height = num_blocks_y; + const int64_t output_width = num_blocks_x * 4; + + // Store the original sizes of the weight data to pass to the shader + utils::ivec4 orig_sizes = { + utils::safe_downcast(OC), + utils::safe_downcast(K_h), + utils::safe_downcast(K_w), + utils::safe_downcast(IC)}; + + std::vector packed_weight_sizes{output_height, output_width}; + + utils::StorageType storage_type = utils::kTexture2D; + uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim(); + if (output_width > max_extent * 4 || output_height > max_extent) { + storage_type = utils::kBuffer; + } + + ValueRef packed_weight = graph.add_tensor( + packed_weight_sizes, + vkcompute::vkapi::kInt, + storage_type, + utils::kWidthPacked); + + utils::uvec3 global_wg_size = { + utils::safe_downcast(num_blocks_x), + utils::safe_downcast(num_blocks_y), + 1u}; + + std::string kernel_name = "pack_q8_conv2d_weights"; + add_storage_type_suffix(kernel_name, storage_type); + + graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_wg_size, + graph.create_local_wg_size(global_wg_size), + // Inputs and Outputs + weight_data, + packed_weight, + // UBOs + {}, + // Specialization Constants + {}, + // Push Constants + {graph.sizes_pc_of(packed_weight), + PushConstantDataInfo(&orig_sizes, sizeof(utils::ivec4))})); + + return packed_weight; +} + +ValueRef prepack_quantized_conv2d_dw_weight( + ComputeGraph& graph, + const QuantizationConfig& weight_quant_config, + const ValueRef weight_data, + const ValueRef kernel_size) { + VK_CHECK_COND(weight_quant_config.nbits == 8); + VK_CHECK_COND(weight_quant_config.is_symmetric); + + std::vector weight_orig_sizes = graph.sizes_of(weight_data); + const int64_t ndim = graph.dim_of(weight_data); + + // For depthwise convolution, expect weight layout [K_h, aligned_K_w, OC] + VK_CHECK_COND(ndim == 3); + int64_t K_h = weight_orig_sizes.at(0); + int64_t K_w = weight_orig_sizes.at(1); + int64_t aligned_K_w = utils::align_up_4(K_w); + int64_t OC = weight_orig_sizes.at(2); + + // The packing format packs the weight tensor into blocks of 4 output channels + // (OC) and 4 kernel elements (K_h * aligned_K_w) + int64_t OC_per_block = 4; + int64_t K_per_block = 4; + + // To figure out the size of the output tensor, determine the number of blocks + // along each dimension. + const int64_t total_K_elements = K_h * aligned_K_w; + const int64_t num_blocks_K = utils::div_up(total_K_elements, K_per_block); + const int64_t num_blocks_OC = utils::div_up(OC, OC_per_block); + + // The blocks are arranged in a transposed manner, such that the transposed + // weight block is indexed like packed_weights[k4][oc4] - this is to allow for + // optimal memory coalescing when computing the depthwise convolution. + int64_t output_height = num_blocks_K; + // The base dtype of the packed tensor is int32 (each int32 contains 4x 8bit + // values) and each block is represented as a ivec4. Therefore the width dim + // of the packed tensor is multiplied by 4. 
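As an aside on the layout described above, the sketch below (illustrative only; the lane order actually used by the packing shaders may differ) shows how four signed 8-bit values occupy one 32-bit integer, which is why each block is an `ivec4` and the packed width dimension is multiplied by 4:

```cpp
// How four int8 values share one 32-bit integer (little-endian lane order
// assumed here for illustration).
#include <cassert>
#include <cstdint>

uint32_t pack_4xint8(int8_t a, int8_t b, int8_t c, int8_t d) {
  return uint32_t(uint8_t(a)) | (uint32_t(uint8_t(b)) << 8) |
      (uint32_t(uint8_t(c)) << 16) | (uint32_t(uint8_t(d)) << 24);
}

int8_t unpack_int8(uint32_t packed, int lane) {
  return int8_t((packed >> (8 * lane)) & 0xFF);
}

int main() {
  const uint32_t packed = pack_4xint8(-1, 2, -3, 4);
  assert(unpack_int8(packed, 0) == -1);
  assert(unpack_int8(packed, 2) == -3);
  return 0;
}
```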
+ int64_t output_width = num_blocks_OC * 4; + + // Store the original sizes of the weight data to pass to the shader + utils::ivec3 orig_sizes = { + utils::safe_downcast(K_h), + utils::safe_downcast(K_w), + utils::safe_downcast(OC)}; + + std::vector packed_weight_sizes{output_height, output_width}; + + utils::StorageType storage_type = utils::kTexture2D; + uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim(); + if (output_width > max_extent * 4 || output_height > max_extent) { + storage_type = utils::kBuffer; + } + + ValueRef packed_weight = graph.add_tensor( + packed_weight_sizes, + vkcompute::vkapi::kInt, + storage_type, + utils::kWidthPacked); + + utils::uvec3 global_wg_size = { + utils::safe_downcast(num_blocks_OC), + utils::safe_downcast(num_blocks_K), + 1u}; + + std::string kernel_name = "pack_q8_conv2d_dw_weights"; + add_storage_type_suffix(kernel_name, storage_type); + + graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_wg_size, + graph.create_local_wg_size(global_wg_size), + // Inputs and Outputs + weight_data, + packed_weight, + // UBOs + {}, + // Specialization Constants + {}, + // Push Constants + {graph.sizes_pc_of(packed_weight), + PushConstantDataInfo(&orig_sizes, sizeof(utils::ivec3))})); + + return packed_weight; +} + // // Dispatch nodes // @@ -251,6 +620,57 @@ void add_input_im2col_node( nullptr)); } +void add_input_im2col_packed_int8_node( + ComputeGraph& graph, + const ValueRef input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const ValueRef output, + const ValueRef input_im2col) { + Conv2DParams conv_params = create_conv2d_params( + graph, input, output, kernel_size, stride, padding, dilation, groups); + + float inv_scale = 1.0f / graph.extract_scalar(input_scale); + int32_t zp = graph.extract_scalar(input_zp); + + std::string kernel_name = "im2col_packed_int8"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(input_im2col)); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(input_im2col), + graph.sizes_ubo(output), + graph.sizes_ubo(input), + graph.create_params_buffer(conv_params)}; + + std::vector push_constants = { + PushConstantDataInfo(&inv_scale, sizeof(inv_scale)), + PushConstantDataInfo(&zp, sizeof(zp)), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + im2col_packed_int8_global_wg_size, + im2col_packed_int8_local_wg_size, + // Inputs and Outputs + {{input_im2col, vkapi::kWrite}, {input, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize args + {}, + // Resizing Logic + nullptr)); +} + void add_quantize_and_pack_im2col_node( ComputeGraph& graph, const ValueRef input_image, @@ -468,6 +888,181 @@ void add_conv2d_q8ta_q8csw_linear_node( nullptr)); } +void add_conv2d_q8ta_q8csw_q8to_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef packed_int8_input_im2col, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef packed_weight, + const ValueRef packed_weight_sums, + const ValueRef packed_weight_scales, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef bias_data, + const ValueRef packed_bias, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef 
dilation, + const ValueRef groups, + const ValueRef packed_int8_output) { + Conv2DParams conv_params = create_conv2d_params( + graph, + packed_int8_input, + packed_int8_output, + kernel_size, + stride, + padding, + dilation, + groups); + + const bool use_im2col = should_use_im2col(&graph, kernel_size, groups); + + float input_scale_val = graph.extract_scalar(input_scale); + int32_t input_zp_val = graph.extract_scalar(input_zp); + + float output_inv_scale_val = 1.0f / graph.extract_scalar(output_scale); + int32_t output_zp_val = graph.extract_scalar(output_zp); + + std::string kernel_name = use_im2col ? "conv2d_q8ta_q8csw_q8to_linear_tiled" + : "conv2d_q8ta_q8csw_q8to"; + add_storage_type_suffix( + kernel_name, graph.storage_type_of(packed_int8_output)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_weight)); + add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales)); + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(packed_int8_output), + graph.sizes_ubo(packed_int8_input_im2col), + graph.create_params_buffer(conv_params)}; + + std::vector push_constants = { + PushConstantDataInfo(&input_scale_val, sizeof(input_scale_val)), + PushConstantDataInfo(&input_zp_val, sizeof(input_zp_val)), + PushConstantDataInfo(&output_inv_scale_val, sizeof(output_inv_scale_val)), + PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)), + }; + + uint32_t apply_bias = 1; + if (graph.val_is_none(bias_data)) { + apply_bias = 0; + } + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + pick_static_quantized_conv2d_global_wg_size, + pick_static_quantized_conv2d_local_wg_size, + // Inputs and Outputs + {{packed_int8_output, vkapi::kWrite}, + {{packed_int8_input_im2col, + packed_weight, + packed_weight_sums, + packed_weight_scales, + packed_bias}, + vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {apply_bias}, + // Resize args + {}, + // Resizing Logic + nullptr)); +} + +void add_conv2d_dw_q8ta_q8csw_q8to_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef packed_weight, + const ValueRef packed_weight_sums, + const ValueRef packed_weight_scales, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef bias_data, + const ValueRef packed_bias, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const ValueRef packed_int8_output) { + Conv2DParams conv_params = create_conv2d_params( + graph, + packed_int8_input, + packed_int8_output, + kernel_size, + stride, + padding, + dilation, + groups); + + // Verify this is actually a depthwise convolution + const int64_t groups_val = graph.extract_scalar(groups); + const int64_t in_channels = graph.size_at(-3, packed_int8_input); + VK_CHECK_COND(groups_val == in_channels); + + float input_scale_val = graph.extract_scalar(input_scale); + int32_t input_zp_val = graph.extract_scalar(input_zp); + + float output_inv_scale_val = 1.0f / graph.extract_scalar(output_scale); + int32_t output_zp_val = graph.extract_scalar(output_zp); + + std::string kernel_name = "conv2d_dw_q8ta_q8csw_q8to"; + add_storage_type_suffix( + kernel_name, graph.storage_type_of(packed_int8_output)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_weight)); + 
add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales)); + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(packed_int8_output), + graph.sizes_ubo(packed_int8_input), + graph.create_params_buffer(conv_params)}; + + std::vector push_constants = { + PushConstantDataInfo(&input_scale_val, sizeof(input_scale_val)), + PushConstantDataInfo(&input_zp_val, sizeof(input_zp_val)), + PushConstantDataInfo(&output_inv_scale_val, sizeof(output_inv_scale_val)), + PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)), + }; + + uint32_t apply_bias = 1; + if (graph.val_is_none(bias_data)) { + apply_bias = 0; + } + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + int8_conv2d_dw_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{packed_int8_output, vkapi::kWrite}, + {{packed_int8_input, + packed_weight, + packed_weight_sums, + packed_weight_scales, + packed_bias}, + vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {apply_bias}, + // Resize args + {}, + // Resizing Logic + nullptr)); +} + // // High level operator impl // @@ -564,16 +1159,12 @@ void quantized_conv2d_impl( ValueRef packed_weight_sums = prepack_standard( graph, weight_sums_data, utils::kBuffer, utils::kWidthPacked); - // Allocate quantized + packed im2col matrix for input - const int64_t num_blocks_M = utils::div_up_4(input_im2col_sizes.at(0)); - const int64_t num_blocks_K = utils::div_up_4(input_im2col_sizes.at(1)); - TmpTensor input_int_im2col( &graph, - {num_blocks_M, num_blocks_K * 4}, - vkapi::kInt, + input_im2col_sizes, + vkapi::kInt8x4, utils::kBuffer, - utils::kWidthPacked); + utils::kPackedInt8_4H4W); add_quantize_and_pack_im2col_node( graph, @@ -687,9 +1278,303 @@ void conv2d_q8csw(ComputeGraph& graph, const std::vector& args) { output_image); } +// Implementation for statically quantized conv2d, which expects input, weight, +// and output tensors to all have packed int8 dtype/memory layout. 
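The dispatch nodes above hand the shaders an input scale and zero point, per-channel weight scales, precomputed per-channel weight sums, and an output scale and zero point. A scalar sketch of the affine requantization these parameters suggest (an assumption about the shader's math, not the actual GLSL) is:

```cpp
// Scalar reference for int8 -> int8 requantization of one output element.
#include <algorithm>
#include <cmath>
#include <cstdint>

int8_t requantize_output(
    int32_t acc,           // sum over the window of x_q * w_q
    int32_t weight_sum,    // sum over the window of w_q (precomputed per channel)
    float input_scale,
    int32_t input_zp,
    float weight_scale,    // per output channel
    float output_inv_scale,
    int32_t output_zp,
    float bias = 0.0f) {
  // Remove the input zero-point contribution:
  //   sum((x_q - in_zp) * w_q) = sum(x_q * w_q) - in_zp * sum(w_q)
  const int32_t corrected = acc - input_zp * weight_sum;
  // Dequantize, add bias, then requantize to the output's scale and zero point.
  const float real = float(corrected) * input_scale * weight_scale + bias;
  const float q = std::round(real * output_inv_scale) + float(output_zp);
  return int8_t(std::max(-128.0f, std::min(127.0f, q)));
}
```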
+void static_quantized_conv2d_impl( + ComputeGraph& graph, + const QuantizationConfig& input_quant_config, + const QuantizationConfig& weight_quant_config, + const QuantizationConfig& output_quant_config, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef weight_data, + const ValueRef weight_sums_data, + const ValueRef weight_scales_data, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef bias_data, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const ValueRef packed_int8_output) { + // Currently, only certain quantization configs are supported + VK_CHECK_COND(input_quant_config.granularity == kPerTensor); + VK_CHECK_COND(input_quant_config.nbits == 8); + + VK_CHECK_COND(weight_quant_config.granularity == kPerChannel); + VK_CHECK_COND(weight_quant_config.nbits == 8); + VK_CHECK_COND(weight_quant_config.is_symmetric); + + VK_CHECK_COND(output_quant_config.granularity == kPerTensor); + VK_CHECK_COND(output_quant_config.nbits == 8); + + // Check for depthwise conv + const int64_t groups_val = graph.extract_scalar(groups); + const int64_t in_channels = graph.size_at(-3, packed_int8_input); + + // Depthwise convs have a specialized implementation, since the regular conv + // implementations require that the number of input and output channels per + // group is a multiple of 4. This is so that all values that are part of the + // same 4Wx4C block have the same group index. + const bool is_depthwise = (groups_val == in_channels); + + const bool use_im2col = should_use_im2col(&graph, kernel_size, groups); + // For pointwise convolution with stride = 1, padding = 0, dilation = 1, the + // input tensor is already equivalent to its im2col representation. In this + // case we can skip the im2col procedure and pass in the input image to the + // convolution_as_matmul implementation directly.
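The pointwise shortcut mentioned above can be seen directly: with a 1x1 kernel, stride 1, no padding, and no dilation, the NCHW input viewed as an (IC, H*W) matrix already is its im2col matrix, so the convolution reduces to a plain matmul. A minimal sketch (not from the patch):

```cpp
// A 1x1/s1/p0/d1 convolution expressed as (OC, IC) x (IC, H*W) matmul.
#include <cstdint>
#include <vector>

// weights: OC x IC, input: IC x (H*W), returns output: OC x (H*W)
std::vector<float> pointwise_conv_as_matmul(
    const std::vector<float>& weights,
    const std::vector<float>& input,
    int64_t OC, int64_t IC, int64_t HW) {
  std::vector<float> output(OC * HW, 0.0f);
  for (int64_t oc = 0; oc < OC; ++oc) {
    for (int64_t ic = 0; ic < IC; ++ic) {
      const float w = weights[oc * IC + ic];
      for (int64_t p = 0; p < HW; ++p) {
        // output[oc][y][x] = sum_ic W[oc][ic] * X[ic][y][x]
        output[oc * HW + p] += w * input[ic * HW + p];
      }
    }
  }
  return output;
}
```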
+ const bool is_optimizable_pw = + is_s1p0d1_pointwise(&graph, kernel_size, stride, padding, dilation); + + ValueRef packed_weight; + if (is_depthwise) { + packed_weight = prepack_quantized_conv2d_dw_weight( + graph, weight_quant_config, weight_data, kernel_size); + } else if (use_im2col) { + packed_weight = prepack_quantized_linear_weight( + graph, weight_quant_config, weight_data); + } else { + packed_weight = prepack_quantized_conv2d_weight( + graph, + weight_quant_config, + weight_data, + packed_int8_input, + packed_int8_output, + groups, + kernel_size); + } + + ValueRef packed_weight_sums = prepack_standard( + graph, weight_sums_data, utils::kBuffer, utils::kWidthPacked); + + ValueRef packed_weight_scales = prepack_standard( + graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked); + + // See quantized_conv2d_impl for why this is needed + TmpTensor dummy_bias( + &graph, + {}, + graph.dtype_of(weight_scales_data), + utils::kBuffer, + utils::kWidthPacked); + + ValueRef packed_bias = dummy_bias.vref; + if (graph.val_is_not_none(bias_data)) { + packed_bias = + prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked); + } + + // Depthwise conv path + if (is_depthwise) { + add_conv2d_dw_q8ta_q8csw_q8to_node( + graph, + packed_int8_input, + input_scale, + input_zp, + packed_weight, + packed_weight_sums, + packed_weight_scales, + output_scale, + output_zp, + bias_data, + packed_bias, + kernel_size, + stride, + padding, + dilation, + groups, + packed_int8_output); + return; + } + + std::vector input_im2col_sizes = + calculate_packed_int8_input_im2col_sizes( + &graph, packed_int8_input, packed_int8_output, kernel_size, groups); + + ValueRef packed_int8_input_im2col = packed_int8_input; + if (use_im2col && !is_optimizable_pw) { + TmpTensor packed_int8_input_im2col_tensor( + &graph, + input_im2col_sizes, + vkapi::kInt8x4, + graph.storage_type_of(packed_int8_input), + utils::kPackedInt8_4W4C); + + packed_int8_input_im2col = packed_int8_input_im2col_tensor.vref; + + add_input_im2col_packed_int8_node( + graph, + packed_int8_input, + input_scale, + input_zp, + kernel_size, + stride, + padding, + dilation, + groups, + packed_int8_output, + packed_int8_input_im2col); + } + + add_conv2d_q8ta_q8csw_q8to_node( + graph, + packed_int8_input, + packed_int8_input_im2col, + input_scale, + input_zp, + packed_weight, + packed_weight_sums, + packed_weight_scales, + output_scale, + output_zp, + bias_data, + packed_bias, + kernel_size, + stride, + padding, + dilation, + groups, + packed_int8_output); +} + +void conv2d_q8ta_q8csw_q8to( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef packed_int8_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_sums_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef kernel_size = args.at(idx++); + const ValueRef stride = args.at(idx++); + const ValueRef padding = args.at(idx++); + const ValueRef dilation = args.at(idx++); + const ValueRef groups = args.at(idx++); + const ValueRef packed_int8_output = args.at(idx++); + + QuantizationConfig input_quant_config(8, kPerTensor, {}); + QuantizationConfig weight_quant_config(8, kPerChannel, {}); + QuantizationConfig output_quant_config(8, kPerTensor, {}); + + 
static_quantized_conv2d_impl( + graph, + input_quant_config, + weight_quant_config, + output_quant_config, + packed_int8_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + output_scale, + output_zp, + bias_data, + kernel_size, + stride, + padding, + dilation, + groups, + packed_int8_output); +} + +// +// Test operators +// + +void conv2d_q8ta_q8csw_q8to_test( + ComputeGraph& graph, + const std::vector& args, + utils::StorageType io_storage_type) { + int32_t idx = 0; + const ValueRef fp_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_sums_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef kernel_size = args.at(idx++); + const ValueRef stride = args.at(idx++); + const ValueRef padding = args.at(idx++); + const ValueRef dilation = args.at(idx++); + const ValueRef groups = args.at(idx++); + const ValueRef fp_output = args.at(idx++); + + TmpTensor packed_int8_input( + &graph, + graph.sizes_of(fp_input), + vkapi::kInt8x4, + io_storage_type, + utils::kPackedInt8_4W4C); + + TmpTensor packed_int8_output( + &graph, + graph.sizes_of(fp_output), + vkapi::kInt8x4, + io_storage_type, + utils::kPackedInt8_4W4C); + + add_quantize_and_pack_4w4c_node( + graph, fp_input, input_scale, input_zp, packed_int8_input); + + std::vector conv2d_args = { + packed_int8_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + output_scale, + output_zp, + bias_data, + kernel_size, + stride, + padding, + dilation, + groups, + packed_int8_output}; + + conv2d_q8ta_q8csw_q8to(graph, conv2d_args); + + add_unpack_4w4c_and_dequantize_node( + graph, packed_int8_output, output_scale, output_zp, fp_output); +} + +void conv2d_q8ta_q8csw_q8to_test_buffer( + ComputeGraph& graph, + const std::vector& args) { + conv2d_q8ta_q8csw_q8to_test(graph, args, utils::kBuffer); +} + +void conv2d_q8ta_q8csw_q8to_test_texture( + ComputeGraph& graph, + const std::vector& args) { + conv2d_q8ta_q8csw_q8to_test(graph, args, utils::kBuffer); +} + REGISTER_OPERATORS { VK_REGISTER_OP(et_vk.conv2d_q8ta_q8csw.default, conv2d_q8ta_q8csw); VK_REGISTER_OP(et_vk.conv2d_q8csw.default, conv2d_q8csw); + VK_REGISTER_OP( + etvk.conv2d_q8ta_q8csw_q8to.test_texture, + conv2d_q8ta_q8csw_q8to_test_texture); + VK_REGISTER_OP( + etvk.conv2d_q8ta_q8csw_q8to.test_buffer, + conv2d_q8ta_q8csw_q8to_test_buffer); + VK_REGISTER_OP(et_vk.conv2d_q8ta_q8csw_q8to.default, conv2d_q8ta_q8csw_q8to); + VK_REGISTER_OP( + et_vk.conv2d_q8ta_q8csw_q8to_dw.default, conv2d_q8ta_q8csw_q8to); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.h b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.h new file mode 100644 index 00000000000..c3ea15bc318 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.h @@ -0,0 +1,18 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include + +namespace vkcompute { + +// This header is intentionally empty as all quantize/dequantize functions +// have been moved to QuantizeDequantize.h + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp index 7fbfcee5cb1..7a42d463f2a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -19,10 +20,6 @@ namespace vkcompute { // Shader dispatch utilities // -bool is_gemv(ComputeGraph* graph, const ValueRef& fp_input) { - return graph->size_at(-2, fp_input) == 1; -} - void resize_linear_qw_node( ComputeGraph* graph, const std::vector& args, @@ -77,6 +74,10 @@ utils::uvec3 quantized_linear_global_wg_size( M_per_tile = 1; } + if (shader.kernel_name.find("q8ta_q8csw_tiled") != std::string::npos) { + N_per_tile = 8; + } + const uint32_t num_N_tiles = utils::div_up(N, N_per_tile); const uint32_t num_M_tiles = utils::div_up(M, M_per_tile); @@ -101,120 +102,6 @@ utils::uvec3 quantized_linear_local_wg_size( } } -std::tuple get_quantized_input_num_blocks( - ComputeGraph& graph, - const ValueRef input) { - std::vector input_sizes = graph.sizes_of(input); - const int64_t ndim = graph.dim_of(input); - - const int64_t M = input_sizes.at(ndim - 2); - const int64_t K = input_sizes.at(ndim - 1); - - const int64_t num_blocks_M = utils::div_up(M, int64_t(4)); - const int64_t num_blocks_K = utils::div_up(K, int64_t(4)); - - return std::make_tuple(num_blocks_M, num_blocks_K); -} - -utils::uvec3 quant_pack_input_global_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const std::vector& args, - const std::vector& resize_args) { - const ValueRef input = args.at(1).refs.at(0); - int64_t num_blocks_M, num_blocks_K; - std::tie(num_blocks_M, num_blocks_K) = - get_quantized_input_num_blocks(*graph, input); - - return { - utils::safe_downcast(num_blocks_K), - utils::safe_downcast(num_blocks_M), - 1u}; -} - -vkapi::ShaderInfo pick_quantize_and_pack_input_with_sums_shader( - ComputeGraph* graph, - const std::vector& args, - const std::vector& resize_args) { - const ValueRef packed_int_input = args.at(0).refs.at(0); - const ValueRef fp_input = args.at(1).refs.at(0); - const ValueRef group_size = resize_args.at(0); - - const int64_t group_size_val = graph->extract_scalar(group_size); - - std::string shader_name = "quantize_and_pack_linear_input_with_sums"; - if (group_size_val >= 128) { - shader_name += "_o2w32"; - } else { - shader_name += "_o4w16"; - } - - add_storage_type_suffix( - shader_name, graph->storage_type_of(packed_int_input)); - add_storage_type_suffix(shader_name, graph->storage_type_of(fp_input)); - add_dtype_suffix(shader_name, graph->dtype_of(fp_input)); - - return VK_KERNEL_FROM_STR(shader_name); -} - -utils::uvec3 pick_quantize_and_pack_input_with_sums_global_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const std::vector& args, - const std::vector& resize_args) { - const ValueRef fp_input = args.at(1).refs.at(0); - // For gemv cases, skip the quantize and pack input step in favor of computing - // the quantized linear as a weight only quantized linear operation. The - // rationale for this is that gemv is a memory bound operation and may not - // necessarily benefit from quantizing the input and computing with integer - // accumulation. 
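The removed comment above argues that gemv is memory bound and therefore gains little from quantizing the activations. A back-of-the-envelope traffic model (illustrative only, with arbitrary example sizes) makes the point: the weight bytes dwarf the single activation row, so shrinking the activations saves almost nothing while costing an extra dispatch.

```cpp
// Rough memory-traffic model for a gemv with K x N int8 weights.
#include <cstdint>
#include <iostream>

struct Traffic {
  double weight_bytes;
  double activation_bytes;
};

Traffic gemv_traffic_bytes(int64_t K, int64_t N, double act_bytes_per_elem) {
  return {
      double(K) * N * 1.0,             // int8 weights, read once
      double(K) * act_bytes_per_elem}; // one activation row
}

int main() {
  const Traffic fp16_act = gemv_traffic_bytes(4096, 4096, 2.0);
  const Traffic int8_act = gemv_traffic_bytes(4096, 4096, 1.0);
  // Weights: ~16.8 MB either way; activations: 8 KB vs 4 KB, i.e. negligible.
  std::cout << fp16_act.weight_bytes << " " << fp16_act.activation_bytes << "\n";
  std::cout << int8_act.weight_bytes << " " << int8_act.activation_bytes << "\n";
  return 0;
}
```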
- if (is_gemv(graph, fp_input)) { - return {0u, 0u, 0u}; - } - - const ValueRef group_size = resize_args.at(0); - int64_t num_blocks_M, num_blocks_K; - std::tie(num_blocks_M, num_blocks_K) = - get_quantized_input_num_blocks(*graph, fp_input); - - const int64_t group_size_val = graph->extract_scalar(group_size); - const int64_t blocks_per_group = group_size_val / 4; - - const int64_t num_groups = num_blocks_K / blocks_per_group; - - return { - utils::safe_downcast(num_groups), - utils::safe_downcast(num_blocks_M), - 1u}; -} - -utils::uvec3 pick_quantize_and_pack_input_with_sums_local_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const utils::uvec3& global_workgroup_size, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - - const ValueRef fp_input = args.at(1).refs.at(0); - // For gemv, skip the quantize input step since the quantized linear is - // computed as a weight only quantized linear operation. - if (is_gemv(graph, fp_input)) { - return {1u, 1u, 1u}; - } - - uint32_t groups_per_wg = 2u; - uint32_t workers_per_group = 32u; - - if (shader.kernel_name.find("o4w16") != std::string::npos) { - groups_per_wg = 4u; - workers_per_group = 16u; - } - - return {groups_per_wg, 1u, workers_per_group}; -} - vkapi::ShaderInfo pick_linear_qw_shader( ComputeGraph* graph, const std::vector& args, @@ -417,7 +304,7 @@ ValueRef prepack_quantized_linear_weight( /* * Shader dispatch for linear with quantized weight but fp activations. */ -DynamicDispatchNode make_linear_qw_node( +void add_linear_qw_node( ComputeGraph& graph, const QuantizationConfig& weight_quant_config, const ValueRef fp_input, @@ -454,7 +341,7 @@ DynamicDispatchNode make_linear_qw_node( const ValueRef is_4bit_flag = weight_quant_config.nbits == 4 ? 
group_size : kDummyValueRef; - return DynamicDispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, pick_linear_qw_shader, quantized_linear_global_wg_size, @@ -472,98 +359,10 @@ DynamicDispatchNode make_linear_qw_node( // Resize args {is_4bit_flag, weight_data}, // Resizing Logic - resize_linear_qw_node); + resize_linear_qw_node)); } -DynamicDispatchNode make_quantize_and_pack_linear_input_node( - ComputeGraph& graph, - const QuantizationConfig& input_quant_config, - const ValueRef fp_input, - const ValueRef packed_input_scale, - const ValueRef packed_input_zp, - const ValueRef input_scale_data, - const ValueRef input_zp_data, - const ValueRef packed_int_input, - const ValueRef group_size) { - // Only certain quantization types supported at the moment - VK_CHECK_COND(input_quant_config.granularity == kPerTensor); - - int64_t num_blocks_M, num_blocks_K; - std::tie(num_blocks_M, num_blocks_K) = - get_quantized_input_num_blocks(graph, fp_input); - - float inv_scale = 1.0f / graph.extract_scalar(input_scale_data); - int32_t zp = graph.extract_scalar(input_zp_data); - - std::string shader_name = "quantize_and_pack_linear_input_per_tensor"; - add_storage_type_suffix(shader_name, graph.storage_type_of(packed_int_input)); - add_storage_type_suffix(shader_name, graph.storage_type_of(fp_input)); - add_dtype_suffix(shader_name, graph.dtype_of(fp_input)); - - vkapi::ParamsBindList param_buffers = {graph.sizes_ubo(fp_input)}; - - std::vector push_constants = { - PushConstantDataInfo(&inv_scale, sizeof(inv_scale)), - PushConstantDataInfo(&zp, sizeof(zp)), - }; - - return DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(shader_name), - quant_pack_input_global_wg_size, - default_pick_local_wg_size, - // Inputs and Outputs - {{packed_int_input, vkapi::kWrite}, {fp_input, vkapi::kRead}}, - // Shader params buffers - param_buffers, - // Push Constants - push_constants, - // Specialization Constants - {}, - // Resize args - {}); -} - -DynamicDispatchNode make_quantize_and_pack_linear_input_with_sums_node( - ComputeGraph& graph, - const QuantizationConfig& input_quant_config, - const ValueRef fp_input, - const ValueRef int_input_sums, - const ValueRef packed_input_scales, - const ValueRef packed_input_zps, - const ValueRef packed_int_input, - const ValueRef group_size) { - // Only certain quantization types supported at the moment - VK_CHECK_COND(input_quant_config.granularity == kPerChannel); - - int64_t num_blocks_M, num_blocks_K; - std::tie(num_blocks_M, num_blocks_K) = - get_quantized_input_num_blocks(graph, fp_input); - - vkapi::ParamsBindList param_buffers = {graph.sizes_ubo(fp_input)}; - - const int32_t group_size_val = graph.extract_scalar(group_size); - const int32_t blocks_per_group = utils::div_up(group_size_val, int32_t(4)); - - return DynamicDispatchNode( - graph, - pick_quantize_and_pack_input_with_sums_shader, - pick_quantize_and_pack_input_with_sums_global_wg_size, - pick_quantize_and_pack_input_with_sums_local_wg_size, - // Inputs and Outputs - {{{packed_int_input, int_input_sums}, vkapi::kWrite}, - {{fp_input, packed_input_scales, packed_input_zps}, vkapi::kRead}}, - // Shader params buffers - param_buffers, - // Push Constants - {}, - // Specialization Constants - {blocks_per_group}, - // Resize args - {group_size}); -} - -DynamicDispatchNode make_linear_qa_qw_node( +void add_linear_qa_qw_node( ComputeGraph& graph, const QuantizationConfig& input_quant_config, const QuantizationConfig& weight_quant_config, @@ -611,8 +410,7 @@ DynamicDispatchNode 
make_linear_qa_qw_node( apply_bias = 0; } - // Add the compute node - return DynamicDispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), quantized_linear_global_wg_size, @@ -634,10 +432,10 @@ DynamicDispatchNode make_linear_qa_qw_node( // Resize args {fp_input}, // Resizing Logic - nullptr); + nullptr)); } -DynamicDispatchNode make_linear_dqa_qw_node( +void add_linear_dqa_qw_node( ComputeGraph& graph, const QuantizationConfig& input_quant_config, const QuantizationConfig& weight_quant_config, @@ -681,8 +479,7 @@ DynamicDispatchNode make_linear_dqa_qw_node( const ValueRef is_4bit_flag = weight_quant_config.nbits == 4 ? group_size : kDummyValueRef; - // Add the compute node - return DynamicDispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, pick_linear_dqa_qw_shader, quantized_linear_global_wg_size, @@ -708,7 +505,7 @@ DynamicDispatchNode make_linear_dqa_qw_node( // Resize args {is_4bit_flag, weight_data}, // Resizing Logic - resize_linear_qw_node); + resize_linear_qw_node)); } // @@ -766,7 +563,7 @@ void quantized_linear_impl( // 2. Input is not quantized if (!graph.can_use_int8_dot_product() || input_quant_config.granularity == kNoQuantization) { - DynamicDispatchNode linear_qw_node(make_linear_qw_node( + add_linear_qw_node( graph, weight_quant_config, fp_input, @@ -777,9 +574,8 @@ void quantized_linear_impl( group_size, bias_data, packed_bias, - output)); + output); - graph.execute_nodes().emplace_back(new DynamicDispatchNode(linear_qw_node)); return; } // Otherwise, use input and weight quantized linear computed with integer @@ -802,39 +598,27 @@ void quantized_linear_impl( graph, weight_sums_data, utils::kBuffer, utils::kWidthPacked); // Allocate temporary tensor to store quantized and packed input - - int64_t num_blocks_M, num_blocks_K; - std::tie(num_blocks_M, num_blocks_K) = - get_quantized_input_num_blocks(graph, fp_input); - - const int64_t int_input_height = num_blocks_M; - const int64_t int_input_width = num_blocks_K * 4; - TmpTensor packed_int_input( &graph, - {int_input_height, int_input_width}, - vkapi::kInt, + graph.sizes_of(fp_input), + vkapi::kInt8x4, utils::kBuffer, - utils::kWidthPacked); + utils::kPackedInt8_4H4W); // Non dynamically quantized input case if (!input_quant_config.is_dynamic) { - DynamicDispatchNode quantize_and_pack_linear_node( - make_quantize_and_pack_linear_input_node( - graph, - input_quant_config, - fp_input, - packed_input_scale, - packed_input_zp, - input_scale, - input_zp, - packed_int_input, - group_size)); - - graph.execute_nodes().emplace_back( - new DynamicDispatchNode(quantize_and_pack_linear_node)); - - DynamicDispatchNode linear_qa_qw_node(make_linear_qa_qw_node( + add_quantize_and_pack_4h4w_node( + graph, + input_quant_config, + fp_input, + packed_input_scale, + packed_input_zp, + input_scale, + input_zp, + packed_int_input, + group_size); + + add_linear_qa_qw_node( graph, input_quant_config, weight_quant_config, @@ -851,10 +635,7 @@ void quantized_linear_impl( group_size, bias_data, packed_bias, - output)); - - graph.execute_nodes().emplace_back( - new DynamicDispatchNode(linear_qa_qw_node)); + output); return; } @@ -875,21 +656,17 @@ void quantized_linear_impl( utils::kBuffer, utils::kWidthPacked); - DynamicDispatchNode quantize_and_pack_input_with_sums_node( - make_quantize_and_pack_linear_input_with_sums_node( - graph, - input_quant_config, - fp_input, - int_input_sums, - packed_input_scale, - packed_input_zp, - packed_int_input, - 
group_size)); - - graph.execute_nodes().emplace_back( - new DynamicDispatchNode(quantize_and_pack_input_with_sums_node)); - - DynamicDispatchNode linear_dqa_qw_node(make_linear_dqa_qw_node( + add_quantize_and_pack_4h4w_with_group_sums_node( + graph, + input_quant_config, + fp_input, + int_input_sums, + packed_input_scale, + packed_input_zp, + packed_int_input, + group_size); + + add_linear_dqa_qw_node( graph, input_quant_config, weight_quant_config, @@ -907,10 +684,7 @@ void quantized_linear_impl( group_size, bias_data, packed_bias, - output)); - - graph.execute_nodes().emplace_back( - new DynamicDispatchNode(linear_dqa_qw_node)); + output); } void linear_q8ta_q8csw(ComputeGraph& graph, const std::vector& args) { diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp index 89c9e847724..18958ccc3ce 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp @@ -61,15 +61,15 @@ utils::uvec3 linear_qcsnw_tiled_global_wg_size( std::vector mat1_sizes = graph->sizes_of(mat1); const int64_t M = utils::val_at(-2, mat1_sizes); - uint32_t out_tile_nrows = 4; - if (M % 6 == 0) { - out_tile_nrows = 2; + uint32_t out_tile_nrows = 1; + if (M % 3 == 0) { + out_tile_nrows = 3; } else if (M % 4 == 0) { out_tile_nrows = 4; - } else if (M % 1 == 0) { - out_tile_nrows = 1; + } else if (M % 2 == 0) { + out_tile_nrows = 2; } else { - out_tile_nrows = 4; + out_tile_nrows = 1; } // Number of output texels in the output tile @@ -225,7 +225,8 @@ void add_linear_qcs8w_node( } else { pcs = { graph.logical_limits_pc_of(out_W_packed), - graph.sizes_pc_of(mat1_W_packed)}; + graph.sizes_pc_of(mat1_W_packed), + graph.sizes_pc_of(q_mat2)}; } const utils::uvec3 global_wg = { @@ -308,19 +309,19 @@ void add_linear_qcsnw_tiled_node( std::vector mat1_sizes = graph.sizes_of(mat1); const int64_t M = utils::val_at(-2, mat1_sizes); - uint32_t out_tile_nrows = 4; - if (M % 6 == 0) { - kernel_name += "_o4x2"; - out_tile_nrows = 2; + uint32_t out_tile_nrows = 1; + if (M % 3 == 0) { + kernel_name += "_o4x3"; + out_tile_nrows = 3; } else if (M % 4 == 0) { kernel_name += "_o4x4"; out_tile_nrows = 4; - } else if (M % 1 == 0) { + } else if (M % 2 == 0) { + kernel_name += "_o4x2"; + out_tile_nrows = 2; + } else { kernel_name += "_o4x1"; out_tile_nrows = 1; - } else { - kernel_name += "_o4x4"; - out_tile_nrows = 4; } // Number of output texels in the output tile @@ -351,7 +352,9 @@ void add_linear_qcsnw_tiled_node( // Shader params buffers {}, // Push Constants - {{graph.sizes_pc_of(out), graph.sizes_pc_of(mat1)}}, + {{graph.sizes_pc_of(out), + graph.sizes_pc_of(mat1), + graph.sizes_pc_of(q_mat2)}}, // Specialization Constants {}, // Resize Args diff --git a/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp b/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp index 6ad1d7f371d..856783ce219 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp @@ -52,6 +52,26 @@ void resize_reduce2d_node( graph->virtual_resize(out, new_sizes); } +void resize_reduce_per_row_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); + + const bool keepdim = graph->extract_scalar(resize_args.at(0)); + + std::vector new_sizes = graph->sizes_of(in); + if (keepdim) { + // Per-row reduction always reduces 
along the last dimension (width) + new_sizes.back() = 1; + } else { + // Remove the last dimension + new_sizes.pop_back(); + } + graph->virtual_resize(out, new_sizes); +} + utils::uvec3 reduce_global_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, @@ -237,12 +257,89 @@ void add_reduce2d_node( resize_reduce2d_node)); } +utils::uvec3 reduce_per_row_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef out = args.at(0).refs.at(0); + return {1u, utils::safe_downcast(graph->numel_of(out)), 1u}; +} + +utils::uvec3 reduce_per_row_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)global_workgroup_size; + (void)args; + (void)resize_args; + + uint32_t outputs_per_wg = 1u; + uint32_t workers_per_output = 64u; + + return {workers_per_output, outputs_per_wg, 1u}; +} + +void add_reduce_per_row_node( + ComputeGraph& graph, + const ValueRef input, + const ValueRef keepdim_ref, + const ValueRef output, + const std::string& op_name) { + std::string kernel_name = op_name + "_per_row"; + kernel_name.reserve(kShaderNameReserve); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + + vkapi::ParamsBindList param_ubos = { + graph.meta_ubo(output), + graph.meta_ubo(input), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + // Global workgroup size function + reduce_per_row_global_wg_size, + // Local workgroup size function + reduce_per_row_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, {input, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + {}, + // Specialization Constants + {}, + // Resize Args + {keepdim_ref}, + // Resizing Logic + resize_reduce_per_row_node)); +} + #define DEFINE_REDUCE_FN(op_name, out_arg_idx) \ void op_name(ComputeGraph& graph, const std::vector& args) { \ - const std::vector dims_list = \ - graph.extract_int_or_symint_list(args[1]); \ + std::vector dims_list; \ + if (graph.val_is_not_none(args[1])) { \ + dims_list = graph.extract_int_or_symint_list(args[1]); \ + } else if (graph.dim_of(args[0]) == 1) { \ + dims_list = {-1}; \ + } else { \ + VK_THROW("dims_list=None only supported for 1D tensors"); \ + } \ if (dims_list.size() == 1) { \ - const int64_t dim_val = dims_list.at(0); \ + int64_t dim_val = dims_list.at(0); \ + int64_t ndim = graph.dim_of(args[0]); \ + if ((dim_val == -1 || dim_val == ndim - 1) && \ + graph.is_buffer_storage(args[0])) { \ + return add_reduce_per_row_node( \ + graph, args[0], args[2], args[out_arg_idx], #op_name); \ + } \ const ValueRef dim_ref = graph.get_or_add_value_for_int(dim_val); \ return add_reduce_node( \ graph, args[0], dim_ref, args[out_arg_idx], #op_name); \ diff --git a/backends/vulkan/runtime/graph/ops/impl/Reduce.h b/backends/vulkan/runtime/graph/ops/impl/Reduce.h new file mode 100644 index 00000000000..7d38e438d31 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Reduce.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include + +#include + +namespace vkcompute { + +void add_reduce_per_row_node( + ComputeGraph& graph, + const ValueRef input, + const ValueRef keepdim_ref, + const ValueRef output, + const std::string& op_name); + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp b/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp index 72c1637a2c9..2b42c0bd150 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp @@ -14,8 +14,6 @@ #include #include -#include - namespace vkcompute { namespace { diff --git a/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp index fcc8fe4b265..e1914f350b7 100644 --- a/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp @@ -43,10 +43,17 @@ utils::uvec3 rotary_embedding_global_wg_size( const ValueRef xq_out = args.at(0).refs.at(0); - utils::uvec3 global_wg_size = graph->logical_limits_of(xq_out); - global_wg_size[0] /= 2; + // Head dim texel size + const uint32_t D4 = utils::div_up_4(graph->size_at(-1, xq_out)); + // Divide by 2 since each invocation computes 2 output locations + const uint32_t D8 = utils::div_up(D4, uint32_t(2)); - return global_wg_size; + // Number of query heads + const uint32_t QH = graph->size_at(-2, xq_out); + // Input tokens sequence length + const uint32_t S = graph->size_at(-3, xq_out); + + return {D8, QH, S}; } void add_rotary_embedding_node( @@ -73,8 +80,14 @@ void add_rotary_embedding_node( VK_CHECK_COND(graph.has_standard_axis_map(freqs_sin)); std::string kernel_name = "rotary_embedding"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(xq_out)); add_dtype_suffix(kernel_name, graph.dtype_of(xq_out)); + vkapi::ParamsBindList param_ubos = { + graph.meta_ubo(xq_out), + graph.meta_ubo(xk_out), + graph.meta_ubo(freqs_cos)}; + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), @@ -84,7 +97,7 @@ void add_rotary_embedding_node( {{{xq_out, xk_out}, vkapi::kWrite}, {{xq, xk, freqs_cos, freqs_sin}, vkapi::kRead}}, // Parameter buffers - {graph.logical_limits_ubo(xq_out), graph.logical_limits_ubo(xk_out)}, + param_ubos, // Push Constants {}, // Specialization Constants diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp index 8edaebd11ff..d28d2c90fcb 100644 --- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp @@ -50,7 +50,7 @@ void resize_compute_attn_weights_node( std::vector out_sizes = { 1, // batch num_q_heads, - seq_len, + utils::align_up_4(seq_len), utils::align_up_4(context_len)}; graph->virtual_resize(attn_weights, out_sizes); @@ -282,6 +282,7 @@ void add_sdpa_kv_cache_update_node( const ValueRef projected, const ValueRef cache) { std::string kernel_name("sdpa_kv_cache_update"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(cache)); add_storage_type_suffix(kernel_name, graph.storage_type_of(projected)); add_dtype_suffix(kernel_name, graph.dtype_of(projected)); @@ -470,10 +471,31 @@ void sdpa_impl(ComputeGraph& graph, const std::vector& args) { VK_CHECK_COND(graph.val_is_none(attn_mask)); const int64_t num_q_heads = graph.size_at(-2, q_projected); - const int64_t max_seq_len = graph.size_at(-3, q_projected); - + int64_t max_seq_len = graph.size_at(-3, q_projected); const int64_t max_context_len = 
graph.size_at(-3, k_cache); + const utils::StorageType attn_weights_storage = + graph.storage_type_of(q_projected); + + // If using buffer storage for attn weights, we need to ensure that the buffer + // numel limit is not exceeded. If needed, manually adjust max_seq_len based + // on the buffer numel limit. + if (attn_weights_storage == utils::kBuffer) { + const int64_t max_buffer_numel = graph.max_buffer_numel(); + if (num_q_heads * max_seq_len * max_context_len >= max_buffer_numel) { + // Compute the maximum possible value for max_seq_len that will hit + // the buffer numel limit. + max_seq_len = max_buffer_numel / (num_q_heads * max_context_len); + // Adjust down to the nearest multiple of 4 to make sure the limit is + // not hit. + if (max_seq_len % 4 != 0) { + max_seq_len = (max_seq_len / 4) * 4; + } else { + max_seq_len -= 4; + } + } + } + std::vector attn_weight_full_sizes = { 1, // batch num_q_heads, @@ -484,14 +506,14 @@ void sdpa_impl(ComputeGraph& graph, const std::vector& args) { &graph, attn_weight_full_sizes, graph.dtype_of(q_projected), - graph.storage_type_of(q_projected), + attn_weights_storage, utils::kWidthPacked); TmpTensor attn_weights_softmax( &graph, attn_weight_full_sizes, graph.dtype_of(q_projected), - graph.storage_type_of(q_projected), + attn_weights_storage, utils::kWidthPacked); add_sdpa_compute_attn_weights_node( @@ -525,10 +547,11 @@ void sdpa_with_kv_cache_impl( (void)sequence_len; - const ValueRef k_cache = prepack_standard( - graph, k_cache_data, utils::kTexture3D, utils::kWidthPacked); - const ValueRef v_cache = prepack_standard( - graph, v_cache_data, utils::kTexture3D, utils::kWidthPacked); + utils::StorageType cache_storage = graph.storage_type_of(q_projected); + const ValueRef k_cache = + graph.add_tensor_like(k_cache_data, cache_storage, utils::kWidthPacked); + const ValueRef v_cache = + graph.add_tensor_like(v_cache_data, cache_storage, utils::kWidthPacked); update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1}); update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1}); @@ -546,10 +569,51 @@ void sdpa_with_kv_cache_impl( out}); } +void compute_attn_weight_with_kv_cache_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef q_projected = args[arg_idx++]; + const ValueRef k_projected = args[arg_idx++]; + const ValueRef v_projected = args[arg_idx++]; + const ValueRef k_cache_data = args[arg_idx++]; + const ValueRef v_cache_data = args[arg_idx++]; + const ValueRef input_pos_symint = args[arg_idx++]; + const ValueRef sequence_len = args[arg_idx++]; + const ValueRef attn_mask = args[arg_idx++]; + (void)attn_mask; + const ValueRef dropout_p = args[arg_idx++]; + (void)dropout_p; + const ValueRef is_causal = args[arg_idx++]; + (void)is_causal; + const ValueRef scale = args[arg_idx++]; + (void)scale; + + // Output tensors + const ValueRef out = args[arg_idx++]; + + (void)sequence_len; + + const utils::StorageType cache_storage = graph.storage_type_of(q_projected); + const ValueRef k_cache = + graph.add_tensor_like(k_cache_data, cache_storage, utils::kWidthPacked); + const ValueRef v_cache = + graph.add_tensor_like(v_cache_data, cache_storage, utils::kWidthPacked); + + update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1}); + update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1}); + + add_sdpa_compute_attn_weights_node( + graph, q_projected, k_cache, input_pos_symint, out); +} + REGISTER_OPERATORS { VK_REGISTER_OP(sdpa_with_kv_cache.default, 
sdpa_with_kv_cache_impl); VK_REGISTER_OP(update_cache.default, update_cache_impl); VK_REGISTER_OP(llama.custom_sdpa.default, sdpa_impl); + VK_REGISTER_OP( + testing.compute_attn_weight_with_kv_cache.default, + compute_attn_weight_with_kv_cache_impl); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Softmax.cpp b/backends/vulkan/runtime/graph/ops/impl/Softmax.cpp index 5e645e29e3d..2d683719ba2 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Softmax.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Softmax.cpp @@ -139,9 +139,9 @@ void add_softmax_node( // Inputs and Outputs {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Shader params buffers - {graph.logical_limits_ubo(out), graph.sizes_ubo(in)}, - // Push Constants {}, + // Push Constants + {graph.sizes_pc_of(in), graph.logical_limits_pc_of(out)}, // Specialization Constants {graph.packed_dim_of(out), reduce_dim_xyz, group_dim}, // Resize Args diff --git a/backends/vulkan/runtime/graph/ops/impl/Split.cpp b/backends/vulkan/runtime/graph/ops/impl/Split.cpp index f87af08ee69..4e62ae8806d 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Split.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Split.cpp @@ -8,134 +8,131 @@ #include -#include +#include +#include #include -#include #include + #include -namespace vkcompute { +#include -void add_split_with_sizes_default_node( - ComputeGraph& graph, - ValueRef in, - const std::vector& split_sizes, - int64_t dim, - ValueRef out_list_ref) { - const ValueListPtr out_list = graph.get_value_list(out_list_ref); +namespace vkcompute { - const int64_t input_ndim = graph.dim_of(in); +using utils::GPUMemoryLayout; +using utils::StorageType; + +void resize_split_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + (void)resize_args; + const ValueRef input = args.at(0).refs.at(0); + const ValueRef split_sizes_ref = args.at(1).refs.at(0); + const ValueRef dim_ref = args.at(2).refs.at(0); + const ValueRef out_list_ref = args.at(3).refs.at(0); + + const ValueListPtr out_list = graph->get_value_list(out_list_ref); + const std::vector split_sizes = + *(graph->get_int_list(split_sizes_ref)); + const int64_t dim = graph->extract_scalar(dim_ref); + + const int64_t input_ndim = graph->dim_of(input); const DimIndex dim_index = dim < 0 ? 
static_cast(dim) : static_cast(dim - input_ndim); - VK_CHECK_COND(out_list->size() == split_sizes.size()); + std::vector input_sizes = graph->sizes_of(input); for (int split_idx = 0; split_idx < split_sizes.size(); split_idx++) { const int64_t split_size = split_sizes.at(split_idx); const ValueRef out_ref = out_list->at(split_idx); - VK_CHECK_COND(dim_at(graph.sizes_of(out_ref), dim_index) == split_size); - } - - const auto packed_dim = graph.packed_dim_of(in); - const auto packed_dim_index = static_cast(kWidth4D - packed_dim); + std::vector out_sizes = input_sizes; + out_sizes.at(dim_index) = split_size; - // Index of dimension to be concatenated in (w, h, c * b) coordinate system - const auto dim_xyz_index = std::min(2, -dim_index - 1); - - utils::ivec4 src_offset = utils::make_ivec4({0, 0, 0, 0}, false); - utils::ivec4 dst_offset = utils::make_ivec4({0, 0, 0, 0}, false); - - const bool is_splitting_channel = (dim_index == kChannel4D); - - // if splitting channels - if (is_splitting_channel) { - // set source offset w as channel size of the input tensor - src_offset[3] = dim_at(graph.sizes_of(in), kChannel4D); + graph->virtual_resize(out_ref, out_sizes); } +} - for (ValueRef out_ref : *out_list) { - // Doesn't need to use split_size since we have already verified that the - // output tensor's size matches with the split_size. - const auto out_channel_size = dim_at(graph.sizes_of(out_ref), kChannel4D); - const utils::ivec3 range = graph.logical_limits_of(out_ref); - - if (dim_index == packed_dim_index) { - // if splitting channels, use add_copy_channel_offset_node function as - // add_copy_packed_dim_offset_node does not support channel packing - if (is_splitting_channel) { - add_copy_channel_offset_node( - graph, in, out_channel_size, src_offset[2], dst_offset[2], out_ref); - src_offset[dim_xyz_index] += out_channel_size; - } else { - // dst_offset[3] is not used now but will be used in the future when - // add_copy_packed_dim_offset_node will support channel packing - // - // set destination offset w as channel size of the output tensor if - // splitting channel - dst_offset[3] = is_splitting_channel ? out_channel_size : 0; - add_copy_packed_dim_offset_node( - graph, in, range, src_offset, dst_offset, out_ref); - src_offset[dim_xyz_index] += - dim_at(graph.sizes_of(out_ref), packed_dim_index); - } - } else { - // set destination offset w as channel size of the output tensor if - // splitting channels - dst_offset[3] = is_splitting_channel ? out_channel_size : 0; - add_copy_offset_node( - graph, in, range, src_offset, dst_offset, out_ref, false, true); - src_offset[dim_xyz_index] += - is_splitting_channel ? 
out_channel_size : range[dim_xyz_index]; - } +void add_split_node( + ComputeGraph& graph, + const ValueRef input, + const std::vector& split_sizes, + const int64_t dim, + const ValueRef out, + const int split_idx) { + std::string kernel_name = "split"; + kernel_name.reserve(kShaderNameReserve); + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + vkapi::ParamsBindList param_ubos = { + graph.meta_ubo(out), graph.meta_ubo(input)}; + + int64_t dim_whcn = nchw_dim_to_whcn_dim(dim, graph.dim_of(input)); + + // Calculate the offset for this split by summing previous split sizes + int64_t split_offset = 0; + for (int i = 0; i < split_idx; i++) { + split_offset += split_sizes[i]; } + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {input, vkapi::kRead}}, + // Shader params buffers + param_ubos, + // Push Constants + {}, + // Specialization Constants + {utils::safe_downcast(dim_whcn), + static_cast(split_idx), + static_cast(split_offset)}, + // Resize Args + {}, + // Resizing Logic + nullptr)); } -void add_split_with_sizes_default_node( +void add_split_with_sizes_node( ComputeGraph& graph, - ValueRef in, - ValueRef split_sizes_ref, - ValueRef dim_ref, - ValueRef out) { - int64_t dim = graph.extract_scalar(dim_ref); - std::vector split_sizes = *(graph.get_int_list(split_sizes_ref)); + const ValueRef input, + const std::vector& split_sizes, + const int64_t dim, + const ValueRef out_list_ref) { + const ValueListPtr out_list = graph.get_value_list(out_list_ref); + + VK_CHECK_COND(out_list->size() == split_sizes.size()); - add_split_with_sizes_default_node(graph, in, split_sizes, dim, out); + // Dispatch a shader for each output tensor + for (int split_idx = 0; split_idx < split_sizes.size(); split_idx++) { + const ValueRef out_ref = out_list->at(split_idx); + add_split_node(graph, input, split_sizes, dim, out_ref, split_idx); + } } void split_with_sizes_copy_default( ComputeGraph& graph, const std::vector& args) { - add_split_with_sizes_default_node(graph, args[0], args[1], args[2], args[3]); -} - -void add_split_tensor_node( - ComputeGraph& graph, - ValueRef in, - ValueRef split_size_ref, - ValueRef dim_ref, - ValueRef out) { - const int64_t split_size = graph.extract_scalar(split_size_ref); - const int64_t dim = graph.extract_scalar(dim_ref); - - const int64_t input_ndim = graph.dim_of(in); - const DimIndex dim_index = dim < 0 ? 
static_cast(dim) - : static_cast(dim - input_ndim); - const int64_t size = dim_at(graph.sizes_of(in), dim_index); - const std::vector split_sizes(size / split_size, split_size); + ValueRef input = args[0]; + ValueRef split_sizes_ref = args[1]; + ValueRef dim_ref = args[2]; + ValueRef out_list_ref = args[3]; - add_split_with_sizes_default_node(graph, in, split_sizes, dim, out); -} + int64_t dim = graph.extract_scalar(dim_ref); + std::vector split_sizes = *(graph.get_int_list(split_sizes_ref)); -void split_tensor(ComputeGraph& graph, const std::vector& args) { - add_split_tensor_node(graph, args[0], args[1], args[2], args[3]); + add_split_with_sizes_node(graph, input, split_sizes, dim, out_list_ref); } REGISTER_OPERATORS { VK_REGISTER_OP( aten.split_with_sizes_copy.default, split_with_sizes_copy_default); - VK_REGISTER_OP(aten.split.Tensor, split_tensor); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp b/backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp index 13801b45cc7..e2b73b2f3f2 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp @@ -32,8 +32,13 @@ void add_squeeze_copy_dims_node( // 2. Squeeze outter most dim // For these cases, just pass input to output via clone. for (int i = 0; i < dims.size(); ++i) { - if (dims.at(i) != 0 && in_sizes.at(dims.at(i)) == 1) { - squeeze_dims.push_back(dims.at(i)); + // adjust negative dims + int64_t dim_val = dims.at(i); + if (dim_val < 0) { + dim_val += in_dim; + } + if (dims.at(i) != 0 && in_sizes.at(dim_val) == 1) { + squeeze_dims.push_back(dim_val); } } if (squeeze_dims.size() == 0) { diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index 648d7b8da09..db7c5a7e88b 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -27,7 +27,10 @@ void add_staging_to_tensor_node( VK_CHECK_COND(graph.val_is_staging(in_staging)); vkapi::ShaderInfo shader = get_nchw_to_tensor_shader( - graph, out_tensor, graph.int8_buffers_enabled()); + graph, + out_tensor, + graph.dtype_of(in_staging), + graph.int8_buffers_enabled()); vkapi::ParamsBindList param_buffers = {}; if (graph.is_buffer_storage(out_tensor)) { @@ -66,16 +69,6 @@ bool is_bitw8_shader(const vkapi::ShaderInfo& shader) { return shader_prefix_str == kBitw8PrefixStr; } -vkapi::ShaderInfo get_tensor_to_staging_shader( - ComputeGraph* graph, - const std::vector& args, - const std::vector& resize_args) { - (void)resize_args; - const ValueRef in_tensor = args.at(1).refs.at(0); - return get_tensor_to_nchw_shader( - *graph, in_tensor, graph->int8_buffers_enabled()); -} - utils::uvec3 tensor_to_staging_global_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, @@ -110,8 +103,11 @@ void add_tensor_to_staging_node( const ValueRef out_staging) { VK_CHECK_COND(graph.val_is_staging(out_staging)); - vkapi::ShaderInfo shader = - get_tensor_to_nchw_shader(graph, in_tensor, graph.int8_buffers_enabled()); + vkapi::ShaderInfo shader = get_tensor_to_nchw_shader( + graph, + in_tensor, + graph.dtype_of(out_staging), + graph.int8_buffers_enabled()); vkapi::ParamsBindList param_buffers = {}; if (graph.is_buffer_storage(in_tensor)) { @@ -151,8 +147,8 @@ void add_prepack_standard_node( const ValueRef tensor_data, const ValueRef tensor, const bool transpose_hw = false) { - vkapi::ShaderInfo shader = - get_nchw_to_tensor_shader(graph, tensor, graph.int8_buffers_enabled()); + 
vkapi::ShaderInfo shader = get_nchw_to_tensor_shader( + graph, tensor, graph.dtype_of(tensor_data), graph.int8_buffers_enabled()); vkapi::ParamsBindList param_buffers = {}; if (graph.is_buffer_storage(tensor)) { @@ -289,7 +285,7 @@ ValueRef prepack_int4_linear_weight_transposed_interleaved( const int64_t N = qmat2_orig_sizes.at(ndim - 2); const int64_t N_div2 = N / int64_t(2); - utils::StorageType storage_type = utils::kTexture2D; + utils::StorageType storage_type = utils::kBuffer; uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim(); if (N_div2 > max_extent * 4 || K > max_extent) { storage_type = utils::kBuffer; diff --git a/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp b/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp index f07522d2578..eb03639abf1 100644 --- a/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp @@ -81,9 +81,58 @@ void sym_add(ComputeGraph& graph, const std::vector& args) { new ExecuteNode(resize_sym_add_node, args)); } +void select_as_symint_impl( + ComputeGraph* graph, + const std::vector& unused, + const std::vector& args) { + (void)unused; // Unused parameter + + const ValueRef x = args.at(0); + const ValueRef dim = args.at(1); + const ValueRef index = args.at(2); + const ValueRef out = args.at(3); + + const int64_t dim_val = graph->extract_scalar(dim); + int64_t index_val = graph->extract_scalar(index); + + const std::vector x_sizes = graph->sizes_of(x); + const vkapi::ScalarType x_dtype = graph->dtype_of(x); + + if (index_val < 0) { + index_val += x_sizes[dim_val]; + } + + const StagingPtr x_staging = graph->get_staging(graph->staging_of(x)); + + int32_t x_val; + switch (x_dtype) { + case vkapi::ScalarType::Int: + x_val = x_staging->select_element_at_dim( + x_sizes, dim_val, index_val); + break; + case vkapi::ScalarType::Long: + x_val = static_cast(x_staging->select_element_at_dim( + x_sizes, dim_val, index_val)); + break; + default: + VK_THROW("Unsupported dtype for select_as_symint"); + } + + graph->set_symint(out, x_val); +} + +void select_as_symint(ComputeGraph& graph, const std::vector& args) { + select_as_symint_impl(&graph, {}, args); + + graph.execute_nodes().emplace_back(new ExecuteNode( + select_as_symint_impl, args, {}, "select_as_symint", true)); + graph.set_has_data_dependent_shapes(); +} + REGISTER_OPERATORS { VK_REGISTER_OP(sym_size.int, sym_size_int); VK_REGISTER_OP(add, sym_add); + VK_REGISTER_OP(et_vk.select_as_symint.default, select_as_symint); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Transfer.cpp b/backends/vulkan/runtime/graph/ops/impl/Transfer.cpp index 60127ecf9bd..1823271824a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Transfer.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Transfer.cpp @@ -50,15 +50,16 @@ void add_transfer_copy_node( (transfer_type == TransferType::SELECT || graph.is_scalar_or_none(step_ref)); - vkapi::ParamsBindList param_buffers; + vkapi::ParamsBindList param_ubos = {graph.meta_ubo(out), graph.meta_ubo(in)}; + if (!param_is_scalar) { if (transfer_type == TransferType::SELECT) { - param_buffers = { - graph.get_or_create_int_param_buffer(index_or_start_ref, 0)}; + param_ubos.append( + graph.get_or_create_int_param_buffer(index_or_start_ref, 0)); } else { // TransferType::SLICE - param_buffers = { - graph.get_or_create_int_param_buffer(index_or_start_ref, 0), - graph.get_or_create_int_param_buffer(step_ref, 1)}; + param_ubos.append( + 
graph.get_or_create_int_param_buffer(index_or_start_ref, 0)); + param_ubos.append(graph.get_or_create_int_param_buffer(step_ref, 1)); } } else { transfer_params.index_or_start_ref = @@ -69,18 +70,6 @@ void add_transfer_copy_node( } std::vector push_constants; - push_constants.reserve(graph.is_buffer_storage(out) ? 5 : 3); - - if (graph.is_buffer_storage(out)) { - push_constants.emplace_back(graph.sizes_pc_of(in)); - push_constants.emplace_back(graph.strides_pc_of(out)); - push_constants.emplace_back(graph.strides_pc_of(in)); - push_constants.emplace_back(graph.numel_pc_of(out)); - } else { - push_constants.emplace_back(graph.sizes_pc_of(out)); - push_constants.emplace_back(graph.sizes_pc_of(in)); - } - if (param_is_scalar) { push_constants.emplace_back(&transfer_params, sizeof(transfer_params)); } else { @@ -88,11 +77,6 @@ void add_transfer_copy_node( &transfer_params.dim, sizeof(transfer_params.dim)); } - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(out), - graph.hashed_layout_of(in), - }; - // Determine the shader directly std::string kernel_name; if (transfer_type == TransferType::SELECT) { @@ -115,11 +99,11 @@ void add_transfer_copy_node( // Inputs and Outputs {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Parameter buffers - param_buffers, + param_ubos, // Push Constants push_constants, // Specialization Constants - spec_vars, + {}, // Resize Args resize_args, // Resizing Logic diff --git a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp index 0a98f6d8f43..602fe1ef129 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp @@ -54,12 +54,33 @@ void resize_unsqueeze_node( const ValueRef in = args.at(1).refs.at(0); const ValueRef dims_ref = extra_args.at(0); - const IntListPtr dims = graph->get_int_list(dims_ref); + std::vector dims_vec; + if (graph->is_scalar_or_none(dims_ref)) { + // Handle scalar case + int64_t dim = graph->extract_scalar(dims_ref); + dims_vec.push_back(dim); + } else { + // Handle list case + const IntListPtr dims = graph->get_int_list(dims_ref); + dims_vec.assign(dims->begin(), dims->end()); + } std::vector out_sizes = graph->sizes_of(in); + std::vector unsqueezed_dims; + + if (graph->val_is_int_list(dims_ref)) { + const IntListPtr dims = graph->get_int_list(dims_ref); + for (int64_t d : *dims) { + unsqueezed_dims.push_back(d); + } + } else { + const int64_t dim = graph->extract_scalar(dims_ref); + unsqueezed_dims.push_back(dim); + } + // Insert singleton dimensions at the specified positions - for (auto dim : *dims) { + for (auto dim : dims_vec) { int64_t d = dim; if (d < 0) { d += static_cast(out_sizes.size()) + 1; diff --git a/backends/vulkan/runtime/graph/ops/impl/View.cpp b/backends/vulkan/runtime/graph/ops/impl/View.cpp index 8701a6246b0..5e2c898573a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/View.cpp @@ -60,6 +60,16 @@ void resize_view_node( } } +void resize_to_dim_order_copy_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); + const std::vector in_sizes = graph->sizes_of(in); + graph->virtual_resize(out, in_sizes); +} + void add_view_node( ComputeGraph& graph, ValueRef in, @@ -98,6 +108,11 @@ void add_view_copy_buffer_node( std::string kernel_name = "view_buffer"; add_dtype_suffix(kernel_name, graph.dtype_of(out)); + bool 
all_contiguous = graph.is_contiguous_buffer_tensor(in) && + graph.is_contiguous_buffer_tensor(out); + + int32_t all_contiguous_int = all_contiguous ? 1 : 0; + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), @@ -110,7 +125,41 @@ void add_view_copy_buffer_node( // Push Constants {}, // Specialization Constants + {all_contiguous_int}, + // Resize Args + resize_args, + // Resizing Logic + resize_fn)); +} + +void add_view_copy_convert_buffer_node( + ComputeGraph& graph, + ValueRef in, + ValueRef out, + const std::vector& resize_args, + const ExecuteNode::ResizeFunction& resize_fn) { + std::string kernel_name = "view_convert_buffer"; + add_dtype_suffix(kernel_name, graph.dtype_of(in)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + bool all_contiguous = graph.is_contiguous_buffer_tensor(in) && + graph.is_contiguous_buffer_tensor(out); + + int32_t all_contiguous_int = all_contiguous ? 1 : 0; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {in, vkapi::kRead}}, + // Parameter Buffers + {graph.buffer_meta_ubo(out), graph.buffer_meta_ubo(in)}, + // Push Constants {}, + // Specialization Constants + {all_contiguous_int}, // Resize Args resize_args, // Resizing Logic @@ -132,8 +181,38 @@ void view(ComputeGraph& graph, const std::vector& args) { return add_view_node(graph, in, sizes, out); } +void to_dim_order_copy(ComputeGraph& graph, const std::vector& args) { + int args_idx = 0; + const ValueRef in = args.at(args_idx++); + const ValueRef dtype = args.at(args_idx++); + (void)dtype; + const ValueRef layout = args.at(args_idx++); + (void)layout; + const ValueRef device = args.at(args_idx++); + (void)device; + const ValueRef pin_memory = args.at(args_idx++); + (void)pin_memory; + const ValueRef non_blocking = args.at(args_idx++); + (void)non_blocking; + const ValueRef dim_order = args.at(args_idx++); + (void)dim_order; + + const ValueRef out = args.at(args_idx++); + + VK_CHECK_COND(graph.is_buffer_storage(in) && graph.is_buffer_storage(out)); + + if (graph.dtype_of(in) == graph.dtype_of(out)) { + return add_view_copy_buffer_node( + graph, in, out, {}, resize_to_dim_order_copy_node); + } + + return add_view_copy_convert_buffer_node( + graph, in, out, {}, resize_to_dim_order_copy_node); +} + REGISTER_OPERATORS { VK_REGISTER_OP(aten.view_copy.default, view); + VK_REGISTER_OP(dim_order_ops._to_dim_order_copy.default, to_dim_order_copy); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/View.h b/backends/vulkan/runtime/graph/ops/impl/View.h index 7a7a8d57742..c8e52492417 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.h +++ b/backends/vulkan/runtime/graph/ops/impl/View.h @@ -24,6 +24,19 @@ void add_view_copy_buffer_node( const std::vector& resize_args, const ExecuteNode::ResizeFunction& resize_fn); +/* + * Dispatches the view_convert_buffer compute shader. This can be used to + * implement ops that preserve the "contiguous" indexes of elements between the + * input and output while converting between different data types such as + * view_copy with dtype conversion. 
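+ *
+ * As an illustrative example: converting a contiguous fp32 buffer tensor into a
+ * contiguous fp16 buffer tensor via dim_order_ops._to_dim_order_copy.default
+ * selects a shader variant suffixed with both dtypes (along the lines of
+ * view_convert_buffer_float_half) and sets the all_contiguous specialization
+ * constant to 1.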
+ */ +void add_view_copy_convert_buffer_node( + ComputeGraph& graph, + ValueRef in, + ValueRef out, + const std::vector& resize_args, + const ExecuteNode::ResizeFunction& resize_fn); + void add_view_node( ComputeGraph& graph, ValueRef in, diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h index b62bf661995..05234c7790f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h +++ b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h @@ -69,7 +69,7 @@ template < std::is_integral::value && std::is_signed::value, int>::type = 0> T nchw_dim_to_whcn_dim(const T& nchw_dim, const int64_t ndim) { - return ndim - 1 - nchw_dim; + return ndim - 1 - normalize(nchw_dim, ndim); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp index c90bfa402bb..c2adca526fb 100644 --- a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp @@ -23,6 +23,7 @@ bool is_bitw8(vkapi::ScalarType dtype) { vkapi::ShaderInfo get_nchw_to_tensor_shader( ComputeGraph& graph, const ValueRef dst, + const vkapi::ScalarType staging_dtype, bool int8_buffer_enabled, bool push_constant_variant) { std::string kernel_name; @@ -45,6 +46,7 @@ vkapi::ShaderInfo get_nchw_to_tensor_shader( if (dst_storage_type == utils::kBuffer) { kernel_name = "nchw_to_buffer"; add_dtype_suffix(kernel_name, dst_dtype); + add_dtype_suffix(kernel_name, staging_dtype); return VK_KERNEL_FROM_STR(kernel_name); } @@ -54,6 +56,7 @@ vkapi::ShaderInfo get_nchw_to_tensor_shader( } add_storage_type_suffix(kernel_name, dst_storage_type); add_dtype_suffix(kernel_name, dst_dtype); + add_dtype_suffix(kernel_name, staging_dtype); return VK_KERNEL_FROM_STR(kernel_name); } @@ -61,6 +64,7 @@ vkapi::ShaderInfo get_nchw_to_tensor_shader( vkapi::ShaderInfo get_tensor_to_nchw_shader( ComputeGraph& graph, const ValueRef src, + const vkapi::ScalarType staging_dtype, bool int8_buffer_enabled, bool push_constant_variant) { std::string kernel_name; @@ -83,6 +87,7 @@ vkapi::ShaderInfo get_tensor_to_nchw_shader( if (src_storage_type == utils::kBuffer) { kernel_name = "buffer_to_nchw"; add_dtype_suffix(kernel_name, src_dtype); + add_dtype_suffix(kernel_name, staging_dtype); return VK_KERNEL_FROM_STR(kernel_name); } @@ -92,6 +97,7 @@ vkapi::ShaderInfo get_tensor_to_nchw_shader( } add_storage_type_suffix(kernel_name, src_storage_type); add_dtype_suffix(kernel_name, src_dtype); + add_dtype_suffix(kernel_name, staging_dtype); return VK_KERNEL_FROM_STR(kernel_name); } diff --git a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h index 71c92b833b7..a4419de3932 100644 --- a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h +++ b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h @@ -15,11 +15,14 @@ namespace vkcompute { vkapi::ShaderInfo get_nchw_to_tensor_shader( ComputeGraph& graph, const ValueRef dst, + const vkapi::ScalarType staging_dtype, bool int8_buffer_enabled = true, bool push_constant_variant = true); + vkapi::ShaderInfo get_tensor_to_nchw_shader( ComputeGraph& graph, const ValueRef src, + const vkapi::ScalarType staging_dtype, bool int8_buffer_enabled = true, bool push_constant_variant = true); diff --git a/backends/vulkan/runtime/utils/StorageUtils.cpp b/backends/vulkan/runtime/utils/StorageUtils.cpp new file mode 100644 index 
00000000000..cfe3d9e159a --- /dev/null +++ b/backends/vulkan/runtime/utils/StorageUtils.cpp @@ -0,0 +1,25 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace vkcompute { +namespace utils { + +bool is_packed_int8_layout(const GPUMemoryLayout layout) { + switch (layout) { + case kPackedInt8_4W4C: + case kPackedInt8_4H4W: + return true; + default: + return false; + } +} + +} // namespace utils +} // namespace vkcompute diff --git a/backends/vulkan/runtime/utils/StorageUtils.h b/backends/vulkan/runtime/utils/StorageUtils.h index 20addf88c53..a269adccecb 100644 --- a/backends/vulkan/runtime/utils/StorageUtils.h +++ b/backends/vulkan/runtime/utils/StorageUtils.h @@ -8,6 +8,7 @@ #pragma once +#include #include namespace vkcompute { @@ -84,9 +85,24 @@ enum class GPUMemoryLayout : uint8_t { * 2. For texture backed tensors, the packed dim will be the specified dim. * The axis map will be `{0, 1, 2, 2}`. */ + TENSOR_WIDTH_PACKED = 0u, TENSOR_HEIGHT_PACKED = 1u, TENSOR_CHANNELS_PACKED = 2u, + + /* + * The following memory layouts are used for quantized int8 tensors. For the + * above "standard" memory layouts, 4 elements along the packed dim are stored + * in each texel (4-component vectorized type). However, for packed int8 + * memory layouts, an additional level of packing is used where 4 int8 values + * are packed into each int32, and each int32 is packed into each ivec4. + * Conceptually, this allows an additional packed dimension to be used. + * When loading a ivec4 from the GPU storage buffer / texture, data for a + * 16 element block is loaded, rather than 4 elements along one dimension. 
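+   *
+   * As an illustration, one way to pack four int8 values a, b, c, d into a
+   * single int32 (first element in the lowest byte) is:
+   *
+   *   packed = (a & 0xFF) | ((b & 0xFF) << 8) | ((c & 0xFF) << 16) | ((d & 0xFF) << 24);
+   *
+   * An ivec4 then holds four such packed int32s, so one load covers a 4x4 block
+   * of 16 int8 elements spanning the two packed dimensions (width and channels
+   * for TENSOR_PACKED_INT8_4W4C, height and width for TENSOR_PACKED_INT8_4H4W).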
+ */ + + TENSOR_PACKED_INT8_4W4C = 3u, + TENSOR_PACKED_INT8_4H4W = 4u, }; static constexpr GPUMemoryLayout kWidthPacked = @@ -98,6 +114,12 @@ static constexpr GPUMemoryLayout kHeightPacked = static constexpr GPUMemoryLayout kChannelsPacked = GPUMemoryLayout::TENSOR_CHANNELS_PACKED; +static constexpr GPUMemoryLayout kPackedInt8_4W4C = + GPUMemoryLayout::TENSOR_PACKED_INT8_4W4C; + +static constexpr GPUMemoryLayout kPackedInt8_4H4W = + GPUMemoryLayout::TENSOR_PACKED_INT8_4H4W; + template T to_packed_dim(const GPUMemoryLayout layout) { switch (layout) { @@ -107,11 +129,17 @@ T to_packed_dim(const GPUMemoryLayout layout) { return 1; case kChannelsPacked: return 2; + case kPackedInt8_4W4C: + return 2; + case kPackedInt8_4H4W: + return 0; }; // Should be unreachable return 0; } +bool is_packed_int8_layout(const GPUMemoryLayout layout); + inline std::ostream& operator<<( std::ostream& os, const StorageType storage_type) { @@ -142,6 +170,12 @@ inline std::ostream& operator<<( case kChannelsPacked: os << "TENSOR_CHANNELS_PACKED"; break; + case kPackedInt8_4W4C: + os << "TENSOR_PACKED_INT8_4W4C"; + break; + case kPackedInt8_4H4W: + os << "TENSOR_PACKED_INT8_4H4W"; + break; } return os; } diff --git a/backends/vulkan/runtime/vk_api/Adapter.cpp b/backends/vulkan/runtime/vk_api/Adapter.cpp index 0e87dde1922..e0b2f1c978b 100644 --- a/backends/vulkan/runtime/vk_api/Adapter.cpp +++ b/backends/vulkan/runtime/vk_api/Adapter.cpp @@ -11,6 +11,7 @@ #include #include +#include namespace vkcompute { namespace vkapi { @@ -112,9 +113,10 @@ VkDevice create_logical_device( #ifdef VK_KHR_shader_integer_dot_product VK_KHR_SHADER_INTEGER_DOT_PRODUCT_EXTENSION_NAME, #endif /* VK_KHR_shader_integer_dot_product */ -#if defined(VK_KHR_pipeline_executable_properties) && defined(VULKAN_DEBUG) +#if defined(VK_KHR_pipeline_executable_properties) && \ + defined(ETVK_INSPECT_PIPELINES) VK_KHR_PIPELINE_EXECUTABLE_PROPERTIES_EXTENSION_NAME, -#endif /* VK_KHR_pipeline_executable_properties */ +#endif /* VK_KHR_pipeline_executable_properties && ETVK_INSPECT_PIPELINES */ }; std::vector enabled_device_extensions; @@ -412,6 +414,11 @@ std::string Adapter::stringize() const { #endif /* VK_KHR_shader_float16_int8 */ ss << " }" << std::endl; + ss << " Shader 64bit Features {" << std::endl; + PRINT_BOOL(physical_device_.supports_int64_shader_types, shaderInt64) + PRINT_BOOL(physical_device_.supports_float64_shader_types, shaderFloat64) + ss << " }" << std::endl; + #ifdef VK_KHR_shader_integer_dot_product ss << " Shader Integer Dot Product Features {" << std::endl; PRINT_PROP( diff --git a/backends/vulkan/runtime/vk_api/Adapter.h b/backends/vulkan/runtime/vk_api/Adapter.h index 6a68b487348..65d0977b533 100644 --- a/backends/vulkan/runtime/vk_api/Adapter.h +++ b/backends/vulkan/runtime/vk_api/Adapter.h @@ -225,6 +225,14 @@ class Adapter final { return physical_device_.supports_int16_shader_types; } + inline bool supports_int64_shader_types() { + return physical_device_.supports_int64_shader_types; + } + + inline bool supports_float64_shader_types() { + return physical_device_.supports_float64_shader_types; + } + inline bool has_full_float16_buffers_support() { return supports_16bit_storage_buffers() && supports_float16_shader_types(); } diff --git a/backends/vulkan/runtime/vk_api/Device.cpp b/backends/vulkan/runtime/vk_api/Device.cpp index a21130f1231..7a3a825f5ec 100644 --- a/backends/vulkan/runtime/vk_api/Device.cpp +++ b/backends/vulkan/runtime/vk_api/Device.cpp @@ -16,6 +16,7 @@ #include #include #include +#include namespace vkcompute { 
namespace vkapi { @@ -45,6 +46,8 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle) queue_families{}, num_compute_queues(0), supports_int16_shader_types(false), + supports_int64_shader_types(false), + supports_float64_shader_types(false), has_unified_memory(false), has_timestamps(false), timestamp_period(0), @@ -97,6 +100,12 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle) if (features2.features.shaderInt16 == VK_TRUE) { supports_int16_shader_types = true; } + if (features2.features.shaderInt64 == VK_TRUE) { + supports_int64_shader_types = true; + } + if (features2.features.shaderFloat64 == VK_TRUE) { + supports_float64_shader_types = true; + } // Check if there are any memory types have both the HOST_VISIBLE and the // DEVICE_LOCAL property flags diff --git a/backends/vulkan/runtime/vk_api/Device.h b/backends/vulkan/runtime/vk_api/Device.h index f5b7154d260..917df514c4b 100644 --- a/backends/vulkan/runtime/vk_api/Device.h +++ b/backends/vulkan/runtime/vk_api/Device.h @@ -12,7 +12,7 @@ #include -#include +#include #include namespace vkcompute { @@ -57,6 +57,8 @@ struct PhysicalDevice final { // Metadata uint32_t num_compute_queues; bool supports_int16_shader_types; + bool supports_int64_shader_types; + bool supports_float64_shader_types; bool has_unified_memory; bool has_timestamps; float timestamp_period; diff --git a/backends/vulkan/runtime/vk_api/Exception.cpp b/backends/vulkan/runtime/vk_api/Exception.cpp index c07349fa7ca..5bcf047aaf1 100644 --- a/backends/vulkan/runtime/vk_api/Exception.cpp +++ b/backends/vulkan/runtime/vk_api/Exception.cpp @@ -10,6 +10,13 @@ #include +#ifdef ETVK_BOOST_STACKTRACE_AVAILABLE +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif // _GNU_SOURCE +#include +#endif // ETVK_BOOST_STACKTRACE_AVAILABLE + namespace vkcompute { namespace vkapi { @@ -65,6 +72,11 @@ Error::Error(SourceLocation source_location, std::string msg) std::ostringstream oss; oss << "Exception raised from " << source_location_ << ": "; oss << msg_; +#ifdef ETVK_BOOST_STACKTRACE_AVAILABLE + oss << "\n"; + oss << "Stack trace:\n"; + oss << boost::stacktrace::stacktrace(); +#endif // ETVK_BOOST_STACKTRACE_AVAILABLE what_ = oss.str(); } @@ -74,6 +86,11 @@ Error::Error(SourceLocation source_location, const char* cond, std::string msg) oss << "Exception raised from " << source_location_ << ": "; oss << "(" << cond << ") is false! 
"; oss << msg_; +#ifdef ETVK_BOOST_STACKTRACE_AVAILABLE + oss << "\n"; + oss << "Stack trace:\n"; + oss << boost::stacktrace::stacktrace(); +#endif // ETVK_BOOST_STACKTRACE_AVAILABLE what_ = oss.str(); } @@ -95,6 +112,12 @@ std::ostream& operator<<(std::ostream& out, const VulkanExtension result) { case VulkanExtension::INTEGER_DOT_PRODUCT: out << "VK_KHR_shader_integer_dot_product"; break; + case VulkanExtension::SHADER_INT64: + out << "shaderInt64"; + break; + case VulkanExtension::SHADER_FLOAT64: + out << "shaderFloat64"; + break; } return out; } diff --git a/backends/vulkan/runtime/vk_api/Exception.h b/backends/vulkan/runtime/vk_api/Exception.h index a883a68fefc..aa1ef1f2526 100644 --- a/backends/vulkan/runtime/vk_api/Exception.h +++ b/backends/vulkan/runtime/vk_api/Exception.h @@ -83,6 +83,8 @@ enum class VulkanExtension : uint8_t { INT16_STORAGE, INT8_STORAGE, INTEGER_DOT_PRODUCT, + SHADER_INT64, + SHADER_FLOAT64, }; class ShaderNotSupportedError : public std::exception { diff --git a/backends/vulkan/runtime/vk_api/Pipeline.cpp b/backends/vulkan/runtime/vk_api/Pipeline.cpp index 994b46b8c76..6fa85924223 100644 --- a/backends/vulkan/runtime/vk_api/Pipeline.cpp +++ b/backends/vulkan/runtime/vk_api/Pipeline.cpp @@ -298,10 +298,11 @@ ComputePipeline::ComputePipeline( }; VkPipelineCreateFlags flags = 0u; -#if defined(VULKAN_DEBUG) && defined(VK_KHR_pipeline_executable_properties) +#if defined(VK_KHR_pipeline_executable_properties) && \ + defined(ETVK_INSPECT_PIPELINES) flags = VK_PIPELINE_CREATE_CAPTURE_STATISTICS_BIT_KHR | VK_PIPELINE_CREATE_CAPTURE_INTERNAL_REPRESENTATIONS_BIT_KHR | flags; -#endif /* VULKAN_DEBUG && VK_KHR_pipeline_executable_properties */ +#endif // VK_KHR_pipeline_executable_properties && ETVK_INSPECT_PIPELINES const VkComputePipelineCreateInfo compute_pipeline_create_info{ VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO, // sType diff --git a/backends/vulkan/runtime/vk_api/Runtime.cpp b/backends/vulkan/runtime/vk_api/Runtime.cpp index c3376e2ccbf..8bd4f8843bf 100644 --- a/backends/vulkan/runtime/vk_api/Runtime.cpp +++ b/backends/vulkan/runtime/vk_api/Runtime.cpp @@ -14,12 +14,6 @@ #include #include -#ifdef USE_VOLK_HEADER_ONLY -// For volk.h, define this before including volk.h in exactly one CPP file. 
-#define VOLK_IMPLEMENTATION -#include -#endif /* USE_VOLK_HEADER_ONLY */ - namespace vkcompute { namespace vkapi { @@ -88,7 +82,7 @@ VkInstance create_instance(const RuntimeConfig& config) { const VkApplicationInfo application_info{ VK_STRUCTURE_TYPE_APPLICATION_INFO, // sType nullptr, // pNext - "PyTorch Vulkan Backend", // pApplicationName + "ExecuTorch Vulkan Delegate", // pApplicationName 0, // applicationVersion nullptr, // pEngineName 0, // engineVersion diff --git a/backends/vulkan/runtime/vk_api/Shader.cpp b/backends/vulkan/runtime/vk_api/Shader.cpp index 4356f92efe7..c932d0a264b 100644 --- a/backends/vulkan/runtime/vk_api/Shader.cpp +++ b/backends/vulkan/runtime/vk_api/Shader.cpp @@ -32,7 +32,9 @@ ShaderInfo::ShaderInfo( const bool requires_shader_int16_ext, const bool requires_16bit_storage_ext, const bool requires_8bit_storage_ext, - const bool requires_integer_dot_product_ext) + const bool requires_integer_dot_product_ext, + const bool requires_shader_int64_ext, + const bool requires_shader_float64_ext) : src_code{ spirv_bin, size, @@ -43,7 +45,9 @@ ShaderInfo::ShaderInfo( requires_shader_int16(requires_shader_int16_ext), requires_16bit_storage(requires_16bit_storage_ext), requires_8bit_storage(requires_8bit_storage_ext), - requires_integer_dot_product(requires_integer_dot_product_ext) { + requires_integer_dot_product(requires_integer_dot_product_ext), + requires_shader_int64(requires_shader_int64_ext), + requires_shader_float64(requires_shader_float64_ext) { } bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) { diff --git a/backends/vulkan/runtime/vk_api/Shader.h b/backends/vulkan/runtime/vk_api/Shader.h index 21332381406..6311710f02b 100644 --- a/backends/vulkan/runtime/vk_api/Shader.h +++ b/backends/vulkan/runtime/vk_api/Shader.h @@ -66,6 +66,8 @@ struct ShaderInfo final { bool requires_16bit_storage = false; bool requires_8bit_storage = false; bool requires_integer_dot_product = false; + bool requires_shader_int64 = false; + bool requires_shader_float64 = false; explicit ShaderInfo(); @@ -78,7 +80,9 @@ struct ShaderInfo final { const bool requires_shader_int16_ext, const bool requires_16bit_storage_ext, const bool requires_8bit_storage_ext, - const bool requires_integer_dot_product_ext); + const bool requires_integer_dot_product_ext, + const bool requires_shader_int64_ext, + const bool requires_shader_float64_ext); operator bool() const { return src_code.bin != nullptr; diff --git a/backends/vulkan/runtime/vk_api/Types.h b/backends/vulkan/runtime/vk_api/Types.h index b3309aa6c69..f4415b5c08f 100644 --- a/backends/vulkan/runtime/vk_api/Types.h +++ b/backends/vulkan/runtime/vk_api/Types.h @@ -43,7 +43,8 @@ _(double, VK_FORMAT_R64G64B64A64_SFLOAT, Double) \ _(int8_t, VK_FORMAT_R8G8B8A8_SINT, QInt8) \ _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, QUInt8) \ - _(int32_t, VK_FORMAT_R32G32B32A32_SINT, QInt32) + _(int32_t, VK_FORMAT_R32G32B32A32_SINT, QInt32) \ + _(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int8x4) namespace vkcompute { namespace vkapi { diff --git a/backends/vulkan/runtime/vk_api/memory/Allocator.cpp b/backends/vulkan/runtime/vk_api/memory/Allocator.cpp index 7976d0ddee5..1d814533ede 100644 --- a/backends/vulkan/runtime/vk_api/memory/Allocator.cpp +++ b/backends/vulkan/runtime/vk_api/memory/Allocator.cpp @@ -141,19 +141,25 @@ VulkanImage Allocator::create_image( allocate_memory); } -VulkanBuffer Allocator::create_staging_buffer(const VkDeviceSize size) { +VulkanBuffer Allocator::create_staging_buffer( + const VkDeviceSize size, + const CopyDirection direction) { const 
VkBufferUsageFlags buffer_usage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; VmaAllocationCreateInfo alloc_create_info = {}; - alloc_create_info.flags = DEFAULT_ALLOCATION_STRATEGY; + alloc_create_info.flags = + DEFAULT_ALLOCATION_STRATEGY | VMA_ALLOCATION_CREATE_MAPPED_BIT; alloc_create_info.usage = VMA_MEMORY_USAGE_AUTO_PREFER_DEVICE; // Staging buffers are accessed by both the CPU and GPU, so set the // appropriate flags to indicate that the host device will be accessing // the data from this buffer. - alloc_create_info.flags |= - VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT | - VMA_ALLOCATION_CREATE_MAPPED_BIT; + if (direction == CopyDirection::HOST_TO_DEVICE) { + alloc_create_info.flags |= + VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT; + } else { + alloc_create_info.flags |= VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT; + } alloc_create_info.usage = VMA_MEMORY_USAGE_AUTO_PREFER_HOST; alloc_create_info.requiredFlags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; alloc_create_info.preferredFlags = diff --git a/backends/vulkan/runtime/vk_api/memory/Allocator.h b/backends/vulkan/runtime/vk_api/memory/Allocator.h index 8f76ca932b7..9a731fc6766 100644 --- a/backends/vulkan/runtime/vk_api/memory/Allocator.h +++ b/backends/vulkan/runtime/vk_api/memory/Allocator.h @@ -23,6 +23,17 @@ namespace vkcompute { namespace vkapi { +/** + * Indicates the direction of a copy to or from a staging buffer. + * + * HOST_TO_DEVICE: Data is written by the host and read by the device. + * DEVICE_TO_HOST: Data is written by the device and read by the host. + */ +enum class CopyDirection : uint8_t { + HOST_TO_DEVICE = 0u, + DEVICE_TO_HOST = 1u, +}; + constexpr VmaAllocationCreateFlags DEFAULT_ALLOCATION_STRATEGY = VMA_ALLOCATION_CREATE_STRATEGY_MIN_MEMORY_BIT; @@ -66,7 +77,7 @@ class Allocator final { const bool allow_transfer = false, const bool allocate_memory = true); - VulkanBuffer create_staging_buffer(const VkDeviceSize); + VulkanBuffer create_staging_buffer(const VkDeviceSize, const CopyDirection); VulkanBuffer create_storage_buffer( const VkDeviceSize, diff --git a/backends/vulkan/serialization/schema.fbs b/backends/vulkan/serialization/schema.fbs index b6670b6f53d..9d738bc386f 100644 --- a/backends/vulkan/serialization/schema.fbs +++ b/backends/vulkan/serialization/schema.fbs @@ -20,6 +20,7 @@ enum VkDataType : byte { FLOAT32 = 5, FLOAT64 = 6, INT64 = 7, + UNSET = 127, } // Describes what kind of GPU resource should be used to represent a tensor. The @@ -39,6 +40,8 @@ enum VkMemoryLayout : ubyte { TENSOR_WIDTH_PACKED = 0, TENSOR_HEIGHT_PACKED = 1, TENSOR_CHANNELS_PACKED = 2, + PACKED_INT8_4W4C = 3, + PACKED_INT8_4H4W = 4, DEFAULT_LAYOUT = 255, } @@ -55,6 +58,9 @@ table VkTensor { storage_type:VkStorageType = DEFAULT_STORAGE; // Memory layout that should be used to represent this tensor memory_layout:VkMemoryLayout = DEFAULT_LAYOUT; + // dtype to use for staging buffer. This may be different from the tensor's datatype + // if force_fp16 is enabled to force all float tensors to be represented as fp16. 
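+  // For example, with force_fp16 enabled a float32 graph input keeps a FLOAT32
+  // staging buffer while the GPU tensor itself is stored as FLOAT16, and the
+  // conversion happens during the staging copy.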
+ staging_datatype:VkDataType = UNSET; } table Null {} diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index 78ac51c8808..43ea6c7ce30 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -50,10 +50,12 @@ def __init__( program: ExportedProgram, delegate_mapping_builder: DelegateMappingBuilder, downcast_64_bit: bool = True, + force_fp16: bool = False, ) -> None: self.program = program self.delegate_mapping_builder = delegate_mapping_builder self.downcast_64_bit = downcast_64_bit + self.force_fp16 = force_fp16 self.chain = [] self.values = [] self.input_ids = [] @@ -135,6 +137,12 @@ def maybe_add_constant_tensor(self, node: Node) -> int: if is_param_node(self.program, node): tensor = self.get_param_tensor(node) + effective_dtype = self.get_effective_dtype(tensor.dtype) + + # Convert the tensor dtype if needed + if tensor.dtype != effective_dtype: + tensor = tensor.to(effective_dtype) + # Serialize tensor data to bytes tensor = tensor.contiguous() size = tensor.untyped_storage().nbytes() @@ -222,6 +230,29 @@ def create_symint_value(self) -> int: self.values.append(vk_graph_schema.VkValue(vk_graph_schema.SymInt(0))) return new_id + def get_effective_dtype(self, dtype: torch.dtype) -> torch.dtype: + if self.downcast_64_bit and dtype == torch.float64: + return torch.float32 + elif self.downcast_64_bit and dtype == torch.int64: + return torch.int32 + elif self.force_fp16 and dtype == torch.float32: + return torch.float16 + else: + return dtype + + def get_staging_dtype(self, dtype: torch.dtype) -> torch.dtype: + # Since 64 bit types are not guaranteed to be supported on all GPUs, + # the conversion between 32 bit and 64 bit types is handled on the CPU + # side. The conversion will occur when copying the staging buffer + # contents to/from ETensor data pointers, rather than in the shader to + # copy between GPU buffer/image to staging buffer. + if self.downcast_64_bit and dtype == torch.float64: + return torch.float32 + elif self.downcast_64_bit and dtype == torch.int64: + return torch.int32 + else: + return dtype + def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int: # Negative id indicates that this tensor will have its own dedicated memory. mem_obj_id = -1 @@ -236,14 +267,16 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int: storage_type = spec.etvk_node_repr.storage_type memory_layout = spec.etvk_node_repr.memory_layout - # Apply downcast logic before getting VK datatype - effective_dtype = spec.dtype - if self.downcast_64_bit and spec.dtype == torch.float64: - effective_dtype = torch.float32 - elif self.downcast_64_bit and spec.dtype == torch.int64: - effective_dtype = torch.int32 + effective_dtype = self.get_effective_dtype(spec.dtype) + # For constant tensors, the datatype of the original tensor will have been + # converted to the effective dtype. Otherwise, the type of the staging buffer + # for inputs/outputs should match the original tensor dtype. 
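+        # For example: with downcast_64_bit=True, an int64 graph input uses int32
+        # for both the GPU tensor and the staging buffer (the 64 <-> 32 bit
+        # conversion is handled on the CPU when copying to/from the ETensor),
+        # whereas with force_fp16=True a float32 input keeps a float32 staging
+        # buffer and the float32 -> float16 conversion is done during the staging
+        # copy on the GPU.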
+ staging_dtype = ( + effective_dtype if constant_id >= 0 else self.get_staging_dtype(spec.dtype) + ) datatype = self.get_vk_datatype(effective_dtype) + staging_datatype = self.get_vk_datatype(staging_dtype) new_id = len(self.values) self.values.append( @@ -255,6 +288,7 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int: mem_obj_id=mem_obj_id, storage_type=storage_type, memory_layout=memory_layout, + staging_datatype=staging_datatype, ) ) ) diff --git a/backends/vulkan/serialization/vulkan_graph_schema.py b/backends/vulkan/serialization/vulkan_graph_schema.py index aa7641bd927..236183ce42f 100644 --- a/backends/vulkan/serialization/vulkan_graph_schema.py +++ b/backends/vulkan/serialization/vulkan_graph_schema.py @@ -31,6 +31,7 @@ class VkDataType(IntEnum): FLOAT32 = 5 FLOAT64 = 6 INT64 = 7 + UNSET = 127 class VkStorageType(IntEnum): @@ -47,6 +48,8 @@ class VkMemoryLayout(IntEnum): TENSOR_WIDTH_PACKED = 0 TENSOR_HEIGHT_PACKED = 1 TENSOR_CHANNELS_PACKED = 2 + PACKED_INT8_4W4C = 3 + PACKED_INT8_4H4W = 4 DEFAULT_LAYOUT = 255 def __str__(self) -> str: @@ -61,6 +64,7 @@ class VkTensor: mem_obj_id: int storage_type: VkStorageType = VkStorageType.DEFAULT_STORAGE memory_layout: VkMemoryLayout = VkMemoryLayout.DEFAULT_LAYOUT + staging_datatype: VkDataType = VkDataType.UNSET @dataclass diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index a9ba62b6f9f..94c1f824633 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -19,6 +19,8 @@ def get_vulkan_preprocessor_flags(no_volk, is_fbcode): default_flags = [] android_flags = [] + debug_mode = read_config("etvk", "debug", "0") == "1" + if not no_volk: for flags in [default_flags, android_flags]: flags.append("-DUSE_VULKAN_WRAPPER") @@ -32,6 +34,10 @@ def get_vulkan_preprocessor_flags(no_volk, is_fbcode): if link_moltenvk: mac_flags = [] + if debug_mode: + mac_flags.append("-DETVK_BOOST_STACKTRACE_AVAILABLE") + default_flags.append("-DETVK_BOOST_STACKTRACE_AVAILABLE") + VK_API_PREPROCESSOR_FLAGS += select({ "DEFAULT": default_flags, "ovr_config//os:android": android_flags, @@ -59,7 +65,6 @@ def get_vulkan_preprocessor_flags(no_volk, is_fbcode): if etvk_default_cache_path != "": VK_API_PREPROCESSOR_FLAGS += ["-DETVK_DEFAULT_CACHE_PATH={}".format(etvk_default_cache_path)] - debug_mode = read_config("etvk", "debug", "0") == "1" if debug_mode: VK_API_PREPROCESSOR_FLAGS += ["-DVULKAN_DEBUG"] @@ -136,6 +141,8 @@ def vulkan_spv_shader_lib(name, spv_filegroups, is_fbcode = False, no_volk = Fal ) def define_common_targets(is_fbcode = False): + debug_mode = read_config("etvk", "debug", "0") == "1" + runtime.python_library( name = "gen_vulkan_spv_lib", srcs = [ @@ -185,6 +192,7 @@ def define_common_targets(is_fbcode = False): else: for deps in [default_deps, android_deps]: deps.append("fbsource//third-party/volk:volk-header") + deps.append("fbsource//third-party/volk:volk-implementation") if is_fbcode: VK_API_DEPS += [ @@ -200,6 +208,10 @@ def define_common_targets(is_fbcode = False): "//third-party/khronos:moltenVK_static" ] + if debug_mode: + mac_deps.append("fbsource//third-party/boost:boost") + default_deps.append("fbsource//third-party/boost:boost") + VK_API_DEPS += select({ "DEFAULT": default_deps, "ovr_config//os:android": android_deps, diff --git a/backends/vulkan/test/TARGETS b/backends/vulkan/test/TARGETS index 53fad86f90c..ee296a4f68f 100644 --- a/backends/vulkan/test/TARGETS +++ b/backends/vulkan/test/TARGETS @@ -34,7 +34,6 @@ python_unittest( deps = [ "//caffe2:torch", 
"//executorch/backends/vulkan/_passes:vulkan_passes", - "//executorch/backends/vulkan/quantizer:vulkan_quantizer", "//executorch/backends/vulkan:vulkan_preprocess", "//pytorch/ao:torchao", # @manual ] diff --git a/backends/vulkan/test/custom_ops/CMakeLists.txt b/backends/vulkan/test/custom_ops/CMakeLists.txt index 97b632338db..6db814815fb 100644 --- a/backends/vulkan/test/custom_ops/CMakeLists.txt +++ b/backends/vulkan/test/custom_ops/CMakeLists.txt @@ -48,7 +48,9 @@ if(TARGET vulkan_backend) # Prototyping utility files set(PROTOTYPING_UTILS_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}) - set(PROTOTYPING_UTILS_CPP ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp) + set(PROTOTYPING_UTILS_CPP ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/conv2d_utils.cpp + ) # Prototyping shaders message(STATUS "shader stuff") @@ -95,4 +97,8 @@ if(TARGET vulkan_backend) add_operator_prototype(q8csw_conv2d) add_operator_prototype(q4gsw_linear) add_operator_prototype(choose_qparams_per_row) + add_operator_prototype(qdq8ta_conv2d_activations) + add_operator_prototype(q8ta_q8csw_q8to_conv2d) + add_operator_prototype(q8ta_q8csw_q8to_conv2d_dw) + add_operator_prototype(q8ta_q8ta_q8to_add) endif() diff --git a/backends/vulkan/test/custom_ops/conv2d_utils.cpp b/backends/vulkan/test/custom_ops/conv2d_utils.cpp new file mode 100644 index 00000000000..3dbbf0a4c0f --- /dev/null +++ b/backends/vulkan/test/custom_ops/conv2d_utils.cpp @@ -0,0 +1,92 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "conv2d_utils.h" + +namespace executorch { +namespace vulkan { +namespace prototyping { + +std::string make_test_case_conv_params_suffix(const Conv2dConfig& config) { + std::string suffix; + // Only print groups if not equal to 1 + if (config.groups != 1) { + suffix += "g=" + std::to_string(config.groups); + suffix += " "; + } + + suffix += "k="; + if (config.kernel.h == config.kernel.w) { + suffix += std::to_string(config.kernel.w); + } else { + suffix += + std::to_string(config.kernel.w) + "," + std::to_string(config.kernel.h); + } + // Only print stride if either dimension is not 1 + if (config.stride.h > 1 || config.stride.w > 1) { + suffix += ",s="; + if (config.stride.h == config.stride.w) { + suffix += std::to_string(config.stride.w); + } else { + suffix += std::to_string(config.stride.w) + "," + + std::to_string(config.stride.h); + } + } + // Only print padding if either dimension is not 1 + if (config.padding.h != 1 || config.padding.w != 1) { + suffix += ",p="; + if (config.padding.h == config.padding.w) { + suffix += std::to_string(config.padding.w); + } else { + suffix += std::to_string(config.padding.w) + "," + + std::to_string(config.padding.h); + } + } + // Only print dilation if either dimension is not 1 + if (config.dilation.h != 1 || config.dilation.w != 1) { + suffix += ",d="; + if (config.dilation.h == config.dilation.w) { + suffix += std::to_string(config.dilation.w); + } else { + suffix += std::to_string(config.dilation.w) + "," + + std::to_string(config.dilation.h); + } + } + return suffix; +} + +std::string to_string(const vkcompute::utils::StorageType storage_type) { + switch (storage_type) { + case vkcompute::utils::kTexture3D: + return "Tex"; + case vkcompute::utils::kTexture2D: + return "Tex2D"; + case vkcompute::utils::kBuffer: + return "Buf"; + } +} + +std::string make_test_case_name( + const Conv2dConfig& config, + 
const bool is_performance, + const vkcompute::utils::StorageType fp_storage_type, + const vkcompute::utils::StorageType int8_storage_type) { + std::string test_case_name = is_performance ? "PERF " : "ACCU "; + test_case_name += std::to_string(config.channels.in) + "->" + + std::to_string(config.channels.out) + + " I=" + std::to_string(config.input_size.h) + "," + + std::to_string(config.input_size.w) + " " + + make_test_case_conv_params_suffix(config); + + test_case_name += + " " + to_string(fp_storage_type) + "->" + to_string(int8_storage_type); + + return test_case_name; +} + +} // namespace prototyping +} // namespace vulkan +} // namespace executorch diff --git a/backends/vulkan/test/custom_ops/conv2d_utils.h b/backends/vulkan/test/custom_ops/conv2d_utils.h new file mode 100644 index 00000000000..416f6c50061 --- /dev/null +++ b/backends/vulkan/test/custom_ops/conv2d_utils.h @@ -0,0 +1,96 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +#include +#include + +namespace executorch { +namespace vulkan { +namespace prototyping { + +// Component structs for better readability +struct KernelSize { + int32_t h; + int32_t w; + + KernelSize(int32_t height, int32_t width) : h(height), w(width) {} +}; + +struct Stride { + int32_t h; + int32_t w; + + Stride(int32_t height, int32_t width) : h(height), w(width) {} +}; + +struct Padding { + int32_t h; + int32_t w; + + Padding(int32_t height, int32_t width) : h(height), w(width) {} +}; + +struct Dilation { + int32_t h; + int32_t w; + + Dilation(int32_t height = 1, int32_t width = 1) : h(height), w(width) {} +}; + +struct OutInChannels { + int32_t out; + int32_t in; + + OutInChannels(int32_t out_channels, int32_t in_channels) + : out(out_channels), in(in_channels) {} +}; + +struct InputSize2D { + int32_t h; + int32_t w; + + InputSize2D(int32_t height, int32_t width) : h(height), w(width) {} +}; + +// Conv2d configuration struct +struct Conv2dConfig { + OutInChannels channels; + InputSize2D input_size; + KernelSize kernel; + Stride stride; + Padding padding; + Dilation dilation; + int32_t groups; // Number of groups for grouped convolution + std::string test_case_name = "placeholder"; + std::string op_name = "conv2d"; + + // Calculate output dimensions + int64_t get_output_height() const { + return (input_size.h + 2 * padding.h - dilation.h * (kernel.h - 1) - 1) / + stride.h + + 1; + } + + int64_t get_output_width() const { + return (input_size.w + 2 * padding.w - dilation.w * (kernel.w - 1) - 1) / + stride.w + + 1; + } +}; + +std::string make_test_case_name( + const Conv2dConfig& config, + const bool is_performance, + const vkcompute::utils::StorageType fp_storage_type, + const vkcompute::utils::StorageType int8_storage_type); + +} // namespace prototyping +} // namespace vulkan +} // namespace executorch diff --git a/backends/vulkan/test/custom_ops/glsl/packed_int32_canvas_buffer.glsl b/backends/vulkan/test/custom_ops/glsl/packed_int32_canvas_buffer.glsl index c1d90fadf7e..e2d198b129f 100644 --- a/backends/vulkan/test/custom_ops/glsl/packed_int32_canvas_buffer.glsl +++ b/backends/vulkan/test/custom_ops/glsl/packed_int32_canvas_buffer.glsl @@ -12,8 +12,6 @@ ${define_active_storage_type("texture3d")} -#extension GL_EXT_debug_printf : enable - layout(std430) buffer; ${layout_declare_tensor(B, "w", "t_out", "int", "texture3d")} @@ -33,12 +31,6 @@ void 
main() { // Pack four 8-bit values equal to 1 into a single uint int packed = (1 << 0) | (1 << 8) | (1 << 16) | (1 << 24); - debugPrintfEXT( - "t_out[%i, %i] = %i\\n", - lpos.x, lpos.y, - packed); - - // Placeholder: just copy input to output ivec4 in_texel = ivec4(packed); imageStore(t_out, lpos, in_texel); diff --git a/backends/vulkan/test/custom_ops/glsl/packed_int32_canvas_texture3d.glsl b/backends/vulkan/test/custom_ops/glsl/packed_int32_canvas_texture3d.glsl index be6717efdaa..80e6fc27909 100644 --- a/backends/vulkan/test/custom_ops/glsl/packed_int32_canvas_texture3d.glsl +++ b/backends/vulkan/test/custom_ops/glsl/packed_int32_canvas_texture3d.glsl @@ -12,8 +12,6 @@ ${define_active_storage_type("texture2d")} -#extension GL_EXT_debug_printf : enable - layout(std430) buffer; ${layout_declare_tensor(B, "w", "t_out", "int", "texture3d")} @@ -33,12 +31,6 @@ void main() { // Pack four 8-bit values equal to 1 into a single uint int packed = (1 << 0) | (1 << 8) | (1 << 16) | (1 << 24); - debugPrintfEXT( - "t_out[%i, %i] = %i\\n", - lpos.x, lpos.y, - packed); - - // Placeholder: just copy input to output ivec4 in_texel = ivec4(packed); imageStore(t_out, lpos, in_texel); diff --git a/backends/vulkan/test/custom_ops/q4gsw_linear.cpp b/backends/vulkan/test/custom_ops/q4gsw_linear.cpp index 59d9d694c2c..2af1488541d 100644 --- a/backends/vulkan/test/custom_ops/q4gsw_linear.cpp +++ b/backends/vulkan/test/custom_ops/q4gsw_linear.cpp @@ -552,7 +552,7 @@ int main(int argc, char* argv[]) { generate_quantized_linear_test_cases, quantized_linear_flop_calculator, "QuantizedLinearQ4GSW", - 10, + 3, 10, ref_fn); diff --git a/backends/vulkan/test/custom_ops/q8csw_conv2d.cpp b/backends/vulkan/test/custom_ops/q8csw_conv2d.cpp index d566e5b2646..219bccb04c3 100644 --- a/backends/vulkan/test/custom_ops/q8csw_conv2d.cpp +++ b/backends/vulkan/test/custom_ops/q8csw_conv2d.cpp @@ -8,6 +8,7 @@ #include #include #include +#include "conv2d_utils.h" #include "utils.h" #include @@ -18,76 +19,6 @@ using namespace vkcompute; static constexpr int64_t kRefDimSizeLimit = 100; -// Component structs for better readability -struct KernelSize { - int32_t h; - int32_t w; - - KernelSize(int32_t height, int32_t width) : h(height), w(width) {} -}; - -struct Stride { - int32_t h; - int32_t w; - - Stride(int32_t height, int32_t width) : h(height), w(width) {} -}; - -struct Padding { - int32_t h; - int32_t w; - - Padding(int32_t height, int32_t width) : h(height), w(width) {} -}; - -struct Dilation { - int32_t h; - int32_t w; - - Dilation(int32_t height = 1, int32_t width = 1) : h(height), w(width) {} -}; - -struct OutInChannels { - int32_t out; - int32_t in; - - OutInChannels(int32_t out_channels, int32_t in_channels) - : out(out_channels), in(in_channels) {} -}; - -struct InputSize2D { - int32_t h; - int32_t w; - - InputSize2D(int32_t height, int32_t width) : h(height), w(width) {} -}; - -// Conv2d configuration struct -struct Conv2dConfig { - OutInChannels channels; - InputSize2D input_size; - KernelSize kernel; - Stride stride; - Padding padding; - Dilation dilation; - int32_t groups; // Number of groups for grouped convolution - std::string test_case_name = "placeholder"; - std::string op_name = "conv2d_q8ta_q8csw"; - - // Calculate output dimensions - int64_t get_output_height() const { - return (input_size.h + 2 * padding.h - dilation.h * (kernel.h - 1) - 1) / - stride.h + - 1; - } - - int64_t get_output_width() const { - return (input_size.w + 2 * padding.w - dilation.w * (kernel.w - 1) - 1) / - stride.w + - 1; - } -}; - // 
Utility function to create a test case from a Conv2dConfig TestCase create_test_case_from_config( const Conv2dConfig& config, @@ -366,13 +297,20 @@ std::vector generate_quantized_conv2d_test_cases() { Stride(1, 1), Padding(1, 1), Dilation(1, 1), - 8}, + 1}, {OutInChannels(128, 64), InputSize2D(128, 128), KernelSize(3, 3), Stride(1, 1), Padding(1, 1), Dilation(1, 1), + 1}, + {OutInChannels(128, 1024), + InputSize2D(128, 128), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), 1}}; // Test with different storage types and data types @@ -394,6 +332,7 @@ std::vector generate_quantized_conv2d_test_cases() { std::to_string(config.kernel.h) + "/" + std::to_string(config.kernel.w); + config.op_name = "conv2d_q8ta_q8csw"; config.test_case_name = prefix + suffix; // The default operator tested is activation + weight quantized conv2d; // however, only test this if the int8 dot product extension is supported @@ -763,7 +702,7 @@ int64_t quantized_conv2d_flop_calculator(const TestCase& test_case) { int main(int argc, char* argv[]) { set_debugging(false); set_print_output(false); - set_print_latencies(false); + set_print_latencies(true); set_use_gpu_timestamps(true); print_performance_header(); diff --git a/backends/vulkan/test/custom_ops/q8csw_linear.cpp b/backends/vulkan/test/custom_ops/q8csw_linear.cpp index 23973426fcc..4aa6f00d3f5 100644 --- a/backends/vulkan/test/custom_ops/q8csw_linear.cpp +++ b/backends/vulkan/test/custom_ops/q8csw_linear.cpp @@ -471,7 +471,7 @@ int main(int argc, char* argv[]) { generate_quantized_linear_test_cases, quantized_linear_flop_calculator, "QuantizedLinear", - 0, + 3, 10, ref_fn); diff --git a/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d.cpp b/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d.cpp new file mode 100644 index 00000000000..13d8f48bae8 --- /dev/null +++ b/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d.cpp @@ -0,0 +1,662 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include + +#include +#include + +#include + +#include "conv2d_utils.h" +#include "utils.h" + +// #define DEBUG_MODE + +using namespace executorch::vulkan::prototyping; + +using namespace vkcompute; + +static constexpr int64_t kRefDimSizeLimit = 100; + +// Utility function to create a test case from a Conv2dConfig +TestCase create_test_case_from_config( + const Conv2dConfig& config, + vkapi::ScalarType input_dtype, + utils::StorageType fp_storage_type, + utils::StorageType int8_storage_type) { + TestCase test_case; + test_case.set_name(config.test_case_name); + + std::string operator_suffix = ".test"; + if (int8_storage_type == utils::kTexture3D) { + operator_suffix += "_texture"; + } else { + operator_suffix += "_buffer"; + } + + // Set the operator name for the test case + std::string operator_name = "etvk." + config.op_name + operator_suffix; + test_case.set_operator_name(operator_name); + + // Calculate output dimensions + int64_t H_out = config.get_output_height(); + int64_t W_out = config.get_output_width(); + + // Input tensor (float/half) - [1, C_in, H_in, W_in] (batch size always 1) + std::vector input_size = { + 1, config.channels.in, config.input_size.h, config.input_size.w}; + + utils::GPUMemoryLayout fp_memory_layout = fp_storage_type == utils::kBuffer + ? 
utils::kWidthPacked + : utils::kChannelsPacked; + + ValueSpec input_tensor( + input_size, + input_dtype, + fp_storage_type, + fp_memory_layout, +#ifdef DEBUG_MODE + DataGenType::RANDOM +#else + DataGenType::RANDOM +#endif + ); + + if (debugging()) { + print_valuespec_data(input_tensor, "input_tensor"); + } + + float input_scale_val = 0.008123; + ValueSpec input_scale(input_scale_val); + + int32_t input_zero_point_val = 2; + ValueSpec input_zero_point(input_zero_point_val); + + // Quantized weight tensor (int8) - [C_out, C_in_per_group * K_h * K_w] + // Memory layout: height, width, then channels - in_c is innermost (stride 1) + // in the second dimension + const int64_t in_channels_per_group = config.channels.in / config.groups; + const int64_t in_features = utils::align_up_4( + in_channels_per_group * config.kernel.h * config.kernel.w); + std::vector weight_size = {config.channels.out, in_features}; + ValueSpec quantized_weight( + weight_size, + vkapi::kChar, // int8 for quantized weights + fp_storage_type, + utils::kWidthPacked, + DataGenType::RANDINT8); + quantized_weight.set_constant(true); + + if (debugging()) { + print_valuespec_data(quantized_weight, "weight_tensor"); + } + + const int64_t aligned_out_channels = utils::align_up_4(config.channels.out); + + // Weight quantization scales (float/half, per-channel) + ValueSpec weight_scales( + {aligned_out_channels}, // Per output channel + input_dtype, + fp_storage_type, + utils::kWidthPacked, + DataGenType::RANDOM_SCALES); + weight_scales.set_constant(true); + + ValueSpec weight_sums( + {aligned_out_channels}, // Per output channel + vkapi::kInt, + fp_storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + weight_sums.set_constant(true); + + // Compute weight_sums data based on quantized weights + compute_weight_sums( + weight_sums, quantized_weight, config.channels.out, in_features); + + // Bias (optional, float/half) - [C_out] + ValueSpec bias( + {aligned_out_channels}, // Per output channel + input_dtype, + fp_storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + bias.set_constant(true); + + // Output quantization parameters + float output_scale_val = 0.05314; + ValueSpec output_scale(output_scale_val); + + int32_t output_zero_point_val = -1; + ValueSpec output_zero_point(output_zero_point_val); + + // Stride and padding parameters + ValueSpec stride({config.stride.h, config.stride.w}); + ValueSpec padding({config.padding.h, config.padding.w}); + + // Dilation and groups parameters + ValueSpec dilation({config.dilation.h, config.dilation.w}); + ValueSpec groups(config.groups); + + // Kernel size parameters + ValueSpec kernel_size({config.kernel.h, config.kernel.w}); + + // Output tensor (float/half) - [1, C_out, H_out, W_out] (batch size always 1) + ValueSpec output( + {1, config.channels.out, H_out, W_out}, + input_dtype, + fp_storage_type, + fp_memory_layout, + DataGenType::ZEROS); + + // Add all specs to test case for q8ta_q8csw_q8to operation + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(input_scale); + test_case.add_input_spec(input_zero_point); + test_case.add_input_spec(quantized_weight); + test_case.add_input_spec(weight_sums); + test_case.add_input_spec(weight_scales); + test_case.add_input_spec(output_scale); + test_case.add_input_spec(output_zero_point); + test_case.add_input_spec(bias); + test_case.add_input_spec(kernel_size); + test_case.add_input_spec(stride); + test_case.add_input_spec(padding); + test_case.add_input_spec(dilation); + test_case.add_input_spec(groups); + + 
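+  // Register the output spec, then set the comparison tolerance below to one
+  // output quantization step (output_scale_val) plus a small epsilon, to account
+  // for rounding onto the quantized output grid defined by output_scale /
+  // output_zero_point when comparing against the float reference.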
test_case.add_output_spec(output); + + test_case.set_abs_tolerance(output_scale_val + 1e-4f); + + return test_case; +} + +// Generate easy test cases for quantized conv2d operation (for debugging) +std::vector generate_quantized_conv2d_easy_cases() { + std::vector test_cases; + + // Single simple configuration for debugging + Conv2dConfig config = { + OutInChannels(16, 8), // channels (out, in) + InputSize2D(21, 17), // input_size (h, w) + KernelSize(3, 3), // kernel + Stride(1, 1), // stride + Padding(1, 1), // padding + Dilation(1, 1), // dilation + 2, // groups + }; + config.op_name = "conv2d_q8ta_q8csw_q8to"; + + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; + + // Generate test cases for each combination + for (const utils::StorageType fp_storage_type : storage_types) { + for (const utils::StorageType int8_storage_type : storage_types) { + config.test_case_name = make_test_case_name( + config, false, fp_storage_type, int8_storage_type); + test_cases.push_back(create_test_case_from_config( + config, vkapi::kFloat, fp_storage_type, int8_storage_type)); + } + } + + return test_cases; +} + +// Generate test cases for quantized conv2d operation +std::vector generate_quantized_conv2d_test_cases() { + std::vector test_cases; + if (!vkcompute::api::context()->adapter_ptr()->supports_int8_dot_product()) { + return test_cases; + } + + std::vector configs = { + // Pointwise convolutions: kernel size 1x1 + {OutInChannels(32, 3), + InputSize2D(64, 64), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + {OutInChannels(64, 32), + InputSize2D(32, 32), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + {OutInChannels(96, 64), + InputSize2D(16, 16), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + {OutInChannels(13, 7), + InputSize2D(57, 33), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + // General 2D convolutions + {OutInChannels(32, 3), + InputSize2D(64, 64), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}, + {OutInChannels(32, 3), + InputSize2D(64, 64), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1}, + {OutInChannels(64, 32), + InputSize2D(8, 8), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}, + {OutInChannels(64, 32), + InputSize2D(64, 64), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}, + {OutInChannels(64, 32), + InputSize2D(64, 64), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1}, + {OutInChannels(16, 32), + InputSize2D(77, 77), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}, + // Grouped convolutions + {OutInChannels(64, 32), + InputSize2D(64, 64), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 2}, + {OutInChannels(96, 96), + InputSize2D(81, 81), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 3}, + {OutInChannels(96, 96), + InputSize2D(64, 64), + KernelSize(5, 5), + Stride(2, 2), + Padding(2, 2), + Dilation(1, 1), + 4}, + // Performance cases (pointwise - will use im2col) + {OutInChannels(128, 128), + InputSize2D(128, 128), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + {OutInChannels(128, 128), + InputSize2D(128, 128), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + // Performance cases (3x3 convs - will use im2col) + {OutInChannels(32, 3), + InputSize2D(256, 256), + 
KernelSize(3, 3), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + {OutInChannels(64, 32), + InputSize2D(128, 128), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}, + {OutInChannels(64, 64), + InputSize2D(128, 128), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}, + // Performance cases (grouped convs) + {OutInChannels(64, 64), + InputSize2D(128, 128), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 2}, + {OutInChannels(96, 96), + InputSize2D(128, 128), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 3}, + {OutInChannels(128, 128), + InputSize2D(128, 128), + KernelSize(5, 5), + Stride(2, 2), + Padding(2, 2), + Dilation(1, 1), + 4}}; + + // Test with different storage types and data types + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; + + // Generate test cases for each combination + for (auto& config : configs) { + bool is_performance = config.channels.out > kRefDimSizeLimit || + config.channels.in > kRefDimSizeLimit || + config.input_size.h > kRefDimSizeLimit || + config.input_size.w > kRefDimSizeLimit; + + config.op_name = "conv2d_q8ta_q8csw_q8to"; + + for (const utils::StorageType fp_storage_type : storage_types) { + for (const utils::StorageType int8_storage_type : storage_types) { + config.test_case_name = make_test_case_name( + config, is_performance, fp_storage_type, int8_storage_type); + test_cases.push_back(create_test_case_from_config( + config, vkapi::kFloat, fp_storage_type, int8_storage_type)); + } + } + } + + return test_cases; +} + +// Reference implementation for activation, weight, and output quantized conv2d +void conv2d_q8ta_q8csw_q8to_reference_impl(TestCase& test_case) { + // Extract input specifications + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& input_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& input_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_sums_spec = test_case.inputs()[idx++]; + (void)weight_sums_spec; + const ValueSpec& weight_scales_spec = test_case.inputs()[idx++]; + const ValueSpec& output_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& output_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& bias_spec = test_case.inputs()[idx++]; + const ValueSpec& kernel_size_spec = test_case.inputs()[idx++]; + const ValueSpec& stride_spec = test_case.inputs()[idx++]; + const ValueSpec& padding_spec = test_case.inputs()[idx++]; + const ValueSpec& dilation_spec = test_case.inputs()[idx++]; + const ValueSpec& groups_spec = test_case.inputs()[idx++]; + + // Extract output specification (mutable reference) + ValueSpec& output_spec = test_case.outputs()[0]; + + // Get tensor dimensions + auto input_sizes = input_spec.get_tensor_sizes(); // [N, C_in, H_in, W_in] + auto weight_sizes = + weight_spec.get_tensor_sizes(); // [C_out, C_in_per_group * K_h * K_w] + auto output_sizes = + output_spec.get_tensor_sizes(); // [N, C_out, H_out, W_out] + + int64_t N = input_sizes[0]; + int64_t C_in = input_sizes[1]; + int64_t H_in = input_sizes[2]; + int64_t W_in = input_sizes[3]; + int64_t C_out = output_sizes[1]; + int64_t H_out = output_sizes[2]; + int64_t W_out = output_sizes[3]; + + // Get kernel dimensions from kernel_size ValueSpec + auto kernel_size_data = kernel_size_spec.get_int32_data(); + int64_t K_h = kernel_size_data[0]; + int64_t K_w = kernel_size_data[1]; + + // Get stride, 
padding, dilation, and groups + auto stride_data = stride_spec.get_int32_data(); + auto padding_data = padding_spec.get_int32_data(); + auto dilation_data = dilation_spec.get_int32_data(); + int64_t stride_h = stride_data[0]; + int64_t stride_w = stride_data[1]; + int64_t pad_h = padding_data[0]; + int64_t pad_w = padding_data[1]; + int64_t dilation_h = dilation_data[0]; + int64_t dilation_w = dilation_data[1]; + int64_t groups = groups_spec.get_int_value(); + + // Skip for large tensors since computation time will be extremely slow + if (N > kRefDimSizeLimit || C_in > kRefDimSizeLimit || + H_in > kRefDimSizeLimit || W_in > kRefDimSizeLimit || + C_out > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + // Get raw data pointers + auto& input_data = input_spec.get_float_data(); + const float input_scale = input_scale_spec.get_float_value(); + const int32_t input_zero_point = input_zeros_spec.get_int_value(); + + auto& weight_data = weight_spec.get_int8_data(); + auto& weight_scales_data = weight_scales_spec.get_float_data(); + auto& bias_data = bias_spec.get_float_data(); + + const float output_scale = output_scale_spec.get_float_value(); + const int32_t output_zero_point = output_zeros_spec.get_int_value(); + + // Calculate channels per group for grouped convolution + int64_t C_in_per_group = C_in / groups; + int64_t C_out_per_group = C_out / groups; + + // Calculate number of output elements + int64_t num_output_elements = N * C_out * H_out * W_out; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_output_elements); + + const int in_features = utils::align_up_4(C_in_per_group * K_h * K_w); + + // Perform activation, weight, and output quantized conv2d operation + for (int64_t n = 0; n < N; ++n) { + for (int64_t out_c = 0; out_c < C_out; ++out_c) { + for (int64_t out_h = 0; out_h < H_out; ++out_h) { + for (int64_t out_w = 0; out_w < W_out; ++out_w) { + int32_t int_sum = 0; + int32_t weight_sum = 0; // Track weight sum on the fly + + // Determine which group this output channel belongs to + int64_t group_idx = out_c / C_out_per_group; + int64_t in_c_start = group_idx * C_in_per_group; + int64_t in_c_end = (group_idx + 1) * C_in_per_group; + + // Convolution operation with integer accumulation + for (int64_t in_c = in_c_start; in_c < in_c_end; ++in_c) { + for (int64_t kh = 0; kh < K_h; ++kh) { + for (int64_t kw = 0; kw < K_w; ++kw) { + // Calculate input position with dilation + int64_t in_h = out_h * stride_h - pad_h + kh * dilation_h; + int64_t in_w = out_w * stride_w - pad_w + kw * dilation_w; + + // Check bounds (zero padding) + if (in_h >= 0 && in_h < H_in && in_w >= 0 && in_w < W_in) { + // Get input value and quantize to int8 + int64_t input_idx = n * (C_in * H_in * W_in) + + in_c * (H_in * W_in) + in_h * W_in + in_w; + + float quant_input_f = + std::round(input_data[input_idx] / input_scale) + + input_zero_point; + quant_input_f = + std::min(std::max(quant_input_f, -128.0f), 127.0f); + int8_t quantized_input = static_cast(quant_input_f); + + // Get quantized weight (already int8) + // Weight layout: [C_out, C_in_per_group * K_h * K_w] + int64_t weight_idx = out_c * in_features + + (kh * (K_w * C_in_per_group) + kw * C_in_per_group + + (in_c % C_in_per_group)); + int8_t quantized_weight = weight_data[weight_idx]; + + // Integer multiplication and accumulation + int_sum 
+= static_cast(quantized_input) * + static_cast(quantized_weight); + + // Track weight sum for this output channel on the fly + weight_sum += static_cast(quantized_weight); + } else { + // For zero padding, we still need to account for the weight + // in weight_sum when input is effectively 0 (but quantized 0 + // is input_zero_point) + int64_t weight_idx = out_c * in_features + + (kh * (K_w * C_in_per_group) + kw * C_in_per_group + + (in_c % C_in_per_group)); + int8_t quantized_weight = weight_data[weight_idx]; + + // Add contribution from zero-padded input (quantized zero = + // input_zero_point) + int_sum += static_cast(input_zero_point) * + static_cast(quantized_weight); + + // Track weight sum for this output channel on the fly + weight_sum += static_cast(quantized_weight); + } + } + } + } + + // Convert accumulated integer result to float and apply scales + // Final result = (int_sum - zero_point_correction) * input_scale * + // weight_scale + bias zero_point_correction = input_zero_point * + // sum_of_weights_for_this_output_channel + int32_t zero_point_correction = input_zero_point * weight_sum; + int32_t accum_adjusted = int_sum - zero_point_correction; + float float_result = + accum_adjusted * input_scale * weight_scales_data[out_c]; + + // Add bias and store result + float_result += bias_data[out_c]; + + // Quantize the output to int8 + float quant_output_f = + std::round(float_result / output_scale) + output_zero_point; + quant_output_f = std::min(std::max(quant_output_f, -128.0f), 127.0f); + int8_t quantized_output = static_cast(quant_output_f); + + // Dequantize back to float + float dequant_output = + (static_cast(quantized_output) - output_zero_point) * + output_scale; + + int64_t output_idx = n * (C_out * H_out * W_out) + + out_c * (H_out * W_out) + out_h * W_out + out_w; + ref_data[output_idx] = dequant_output; + } + } + } + } +} + +void reference_impl(TestCase& test_case) { + conv2d_q8ta_q8csw_q8to_reference_impl(test_case); +} + +// Custom FLOP calculator for quantized conv2d operation +int64_t quantized_conv2d_flop_calculator(const TestCase& test_case) { + int kernel_idx = 9; // kernel_size is at index 9 for q8ta_q8csw_q8to + + // Get input and weight dimensions + const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes(); + const auto& output_sizes = test_case.outputs()[0].get_tensor_sizes(); + + const auto& kernel_sizes = test_case.inputs()[kernel_idx].get_int32_data(); + + int64_t N = input_sizes[0]; + int64_t C_in = input_sizes[1]; + int64_t C_out = output_sizes[1]; + int64_t K_h = kernel_sizes[0]; + int64_t K_w = kernel_sizes[1]; + int64_t H_out = output_sizes[2]; + int64_t W_out = output_sizes[3]; + + // Calculate FLOPs for quantized conv2d operation + // Each output element requires: + // - C_in * K_h * K_w multiply-accumulate operations + // - Additional operations for quantization/dequantization + int64_t output_elements = N * C_out * H_out * W_out; + int64_t ops_per_output = C_in * K_h * K_w; + + int64_t flop = output_elements * (ops_per_output); + + return flop; +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); +#ifdef DEBUG_MODE + set_print_latencies(true); +#else + set_print_latencies(false); +#endif + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout + << "Quantized Conv2d Operation with Output Quantization Prototyping Framework" + << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = reference_impl; + + // Execute test cases using the new framework with custom FLOP 
calculator + auto results = execute_test_cases( +#ifdef DEBUG_MODE + generate_quantized_conv2d_easy_cases, +#else + generate_quantized_conv2d_test_cases, +#endif + quantized_conv2d_flop_calculator, + "QuantizedConv2dQ8ToQ8To", +#ifdef DEBUG_MODE + 0, + 1, +#else + 3, + 10, +#endif + ref_fn); + + return 0; +} diff --git a/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d_dw.cpp b/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d_dw.cpp new file mode 100644 index 00000000000..2d8d32dde74 --- /dev/null +++ b/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d_dw.cpp @@ -0,0 +1,602 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include + +#include +#include + +#include + +#include "conv2d_utils.h" +#include "utils.h" + +// #define DEBUG_MODE + +using namespace executorch::vulkan::prototyping; + +using namespace vkcompute; + +static constexpr int64_t kRefDimSizeLimit = 100; + +// Utility function to create a test case from a Conv2dConfig for depthwise +// convolution +TestCase create_test_case_from_config( + const Conv2dConfig& config, + vkapi::ScalarType input_dtype, + utils::StorageType fp_storage_type, + utils::StorageType int8_storage_type) { + TestCase test_case; + test_case.set_name(config.test_case_name); + + std::string operator_suffix = ".test"; + if (int8_storage_type == utils::kTexture3D) { + operator_suffix += "_texture"; + } else { + operator_suffix += "_buffer"; + } + + // Set the operator name for the test case + std::string operator_name = "etvk." + config.op_name + operator_suffix; + test_case.set_operator_name(operator_name); + + // Calculate output dimensions + int64_t H_out = config.get_output_height(); + int64_t W_out = config.get_output_width(); + + // Input tensor (float/half) - [1, C_in, H_in, W_in] (batch size always 1) + std::vector input_size = { + 1, config.channels.in, config.input_size.h, config.input_size.w}; + + utils::GPUMemoryLayout fp_memory_layout = fp_storage_type == utils::kBuffer + ? 
utils::kWidthPacked + : utils::kChannelsPacked; + + ValueSpec input_tensor( + input_size, + input_dtype, + fp_storage_type, + fp_memory_layout, + DataGenType::RANDOM); + + if (debugging()) { + print_valuespec_data(input_tensor, "input_tensor", false, 64); + } + + float input_scale_val = 0.008123; + ValueSpec input_scale(input_scale_val); + + int32_t input_zero_point_val = 2; + ValueSpec input_zero_point(input_zero_point_val); + + // Quantized weight tensor (int8) for depthwise convolution + // Memory layout: [K_h, K_w, OC] + // For depthwise conv: groups = channels.out, in_channels_per_group = 1 + std::vector weight_size = { + config.kernel.h, config.kernel.w, config.channels.out}; + ValueSpec quantized_weight( + weight_size, + vkapi::kChar, // int8 for quantized weights + fp_storage_type, + utils::kWidthPacked, + DataGenType::RANDINT8); + quantized_weight.set_constant(true); + + if (debugging()) { + print_valuespec_data(quantized_weight, "weight_tensor", false, 64); + } + + // Weight quantization scales (float/half, per-channel) + ValueSpec weight_scales( + {config.channels.out}, // Per output channel + input_dtype, + fp_storage_type, + utils::kWidthPacked, + DataGenType::RANDOM_SCALES); + weight_scales.set_constant(true); + + ValueSpec weight_sums( + {config.channels.out}, // Per output channel + vkapi::kInt, + fp_storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + weight_sums.set_constant(true); + + // Compute weight_sums data based on quantized weights for depthwise layout + // For depthwise conv: each output channel has K_h * K_w weights + // Custom computation for depthwise layout [K_h, K_w, OC] + auto& weight_sums_data = weight_sums.get_int32_data(); + auto& quantized_weight_data = quantized_weight.get_int8_data(); + + weight_sums_data.resize(config.channels.out); + + for (int64_t out_c = 0; out_c < config.channels.out; ++out_c) { + int32_t sum = 0; + for (int64_t kh = 0; kh < config.kernel.h; ++kh) { + for (int64_t kw = 0; kw < config.kernel.w; ++kw) { + // Weight indexing for depthwise layout [K_h, K_w, OC] + int64_t weight_idx = kh * (config.kernel.w * config.channels.out) + + kw * config.channels.out + out_c; + sum += static_cast(quantized_weight_data[weight_idx]); + } + } + weight_sums_data[out_c] = sum; + } + + // Bias (optional, float/half) - [C_out] + ValueSpec bias( + {config.channels.out}, // Per output channel + input_dtype, + fp_storage_type, + utils::kWidthPacked, + DataGenType::RANDOM); + bias.set_constant(true); + + // Output quantization parameters + float output_scale_val = 0.05314; + ValueSpec output_scale(output_scale_val); + + int32_t output_zero_point_val = -1; + ValueSpec output_zero_point(output_zero_point_val); + + // Stride and padding parameters + ValueSpec stride({config.stride.h, config.stride.w}); + ValueSpec padding({config.padding.h, config.padding.w}); + + // Dilation and groups parameters + ValueSpec dilation({config.dilation.h, config.dilation.w}); + ValueSpec groups(config.groups); + + // Kernel size parameters + ValueSpec kernel_size({config.kernel.h, config.kernel.w}); + + // Output tensor (float/half) - [1, C_out, H_out, W_out] (batch size always 1) + ValueSpec output( + {1, config.channels.out, H_out, W_out}, + input_dtype, + fp_storage_type, + fp_memory_layout, + DataGenType::ZEROS); + + // Add all specs to test case for q8ta_q8csw_q8to operation + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(input_scale); + test_case.add_input_spec(input_zero_point); + test_case.add_input_spec(quantized_weight); + 
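The loop above fills weight_sums for the depthwise case by walking weights as kh * (K_w * OC) + kw * OC + oc. As a hedged illustration of that indexing (not part of the patch), repacking a conventional [OC, 1, K_h, K_w] depthwise weight into the [K_h, K_w, OC] order assumed by this test could look like the sketch below; the helper name and the starting layout are assumptions.

#include <cstdint>
#include <vector>

// Repack a depthwise weight stored contiguously as [OC, 1, K_h, K_w] into the
// [K_h, K_w, OC] ordering used by the test above. Purely illustrative.
std::vector<int8_t> pack_dw_weight_khkwoc(
    const std::vector<int8_t>& w_oihw,
    int64_t OC,
    int64_t K_h,
    int64_t K_w) {
  std::vector<int8_t> packed(static_cast<size_t>(K_h * K_w * OC));
  for (int64_t oc = 0; oc < OC; ++oc) {
    for (int64_t kh = 0; kh < K_h; ++kh) {
      for (int64_t kw = 0; kw < K_w; ++kw) {
        // Destination index matches the weight_idx formula used in the test
        // and in the depthwise reference implementation below.
        packed[kh * (K_w * OC) + kw * OC + oc] =
            w_oihw[oc * (K_h * K_w) + kh * K_w + kw];
      }
    }
  }
  return packed;
}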
test_case.add_input_spec(weight_sums); + test_case.add_input_spec(weight_scales); + test_case.add_input_spec(output_scale); + test_case.add_input_spec(output_zero_point); + test_case.add_input_spec(bias); + test_case.add_input_spec(kernel_size); + test_case.add_input_spec(stride); + test_case.add_input_spec(padding); + test_case.add_input_spec(dilation); + test_case.add_input_spec(groups); + + test_case.add_output_spec(output); + + test_case.set_abs_tolerance(output_scale_val + 1e-4f); + + return test_case; +} + +// Generate easy test cases for quantized depthwise conv2d operation (for +// debugging) +std::vector generate_quantized_conv2d_dw_easy_cases() { + std::vector test_cases; + + // Single simple configuration for debugging - depthwise convolution + Conv2dConfig config = { + OutInChannels(8, 8), // channels (out, in) - equal for depthwise + InputSize2D(8, 8), // input_size (h, w) + KernelSize(3, 3), // kernel + Stride(1, 1), // stride + Padding(1, 1), // padding + Dilation(1, 1), // dilation + 8, // groups = channels.out for depthwise + }; + config.op_name = "conv2d_q8ta_q8csw_q8to"; + + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; + + // Generate test cases for each combination + for (const utils::StorageType fp_storage_type : storage_types) { + for (const utils::StorageType int8_storage_type : storage_types) { + config.test_case_name = make_test_case_name( + config, false, fp_storage_type, int8_storage_type); + test_cases.push_back(create_test_case_from_config( + config, vkapi::kFloat, fp_storage_type, int8_storage_type)); + } + } + + return test_cases; +} + +// Generate test cases for quantized depthwise conv2d operation +std::vector generate_quantized_conv2d_dw_test_cases() { + std::vector test_cases; + if (!vkcompute::api::context()->adapter_ptr()->supports_int8_dot_product()) { + return test_cases; + } + + std::vector configs = { + // Depthwise convolutions: groups = channels.out, channels.in = + // channels.out + {OutInChannels(32, 32), + InputSize2D(64, 64), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 32}, + {OutInChannels(64, 64), + InputSize2D(32, 32), + KernelSize(3, 3), + Stride(2, 2), + Padding(2, 2), + Dilation(1, 1), + 64}, + {OutInChannels(64, 64), + InputSize2D(32, 32), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 64}, + {OutInChannels(80, 80), + InputSize2D(16, 16), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 80}, + {OutInChannels(16, 16), + InputSize2D(57, 33), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 16}, + // Different kernel sizes for depthwise + {OutInChannels(32, 32), + InputSize2D(64, 64), + KernelSize(5, 5), + Stride(1, 1), + Padding(2, 2), + Dilation(1, 1), + 32}, + {OutInChannels(96, 96), + InputSize2D(64, 64), + KernelSize(5, 5), + Stride(2, 2), + Padding(2, 2), + Dilation(1, 1), + 96}, + // Performance cases + {OutInChannels(128, 128), + InputSize2D(128, 128), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 128}, + {OutInChannels(64, 64), + InputSize2D(256, 256), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 64}, + {OutInChannels(288, 288), + InputSize2D(16, 16), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 288}, + {OutInChannels(32, 32), + InputSize2D(128, 128), + KernelSize(3, 3), + Stride(1, 1), + Padding(2, 2), + Dilation(1, 1), + 32}}; + + // Test with different storage types and data types + std::vector storage_types = { + 
utils::kTexture3D, utils::kBuffer}; + + // Generate test cases for each combination + for (auto& config : configs) { + bool is_performance = config.channels.out > kRefDimSizeLimit || + config.channels.in > kRefDimSizeLimit || + config.input_size.h > kRefDimSizeLimit || + config.input_size.w > kRefDimSizeLimit; + + config.op_name = "conv2d_q8ta_q8csw_q8to"; + + for (const utils::StorageType fp_storage_type : storage_types) { + for (const utils::StorageType int8_storage_type : storage_types) { + config.test_case_name = make_test_case_name( + config, is_performance, fp_storage_type, utils::kBuffer); + test_cases.push_back(create_test_case_from_config( + config, vkapi::kFloat, fp_storage_type, int8_storage_type)); + } + } + } + + return test_cases; +} + +// Reference implementation for activation, weight, and output quantized +// depthwise conv2d +void conv2d_q8ta_q8csw_q8to_dw_reference_impl(TestCase& test_case) { + // Extract input specifications + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& input_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& input_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_sums_spec = test_case.inputs()[idx++]; + (void)weight_sums_spec; + const ValueSpec& weight_scales_spec = test_case.inputs()[idx++]; + const ValueSpec& output_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& output_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& bias_spec = test_case.inputs()[idx++]; + const ValueSpec& kernel_size_spec = test_case.inputs()[idx++]; + const ValueSpec& stride_spec = test_case.inputs()[idx++]; + const ValueSpec& padding_spec = test_case.inputs()[idx++]; + const ValueSpec& dilation_spec = test_case.inputs()[idx++]; + const ValueSpec& groups_spec = test_case.inputs()[idx++]; + + // Extract output specification (mutable reference) + ValueSpec& output_spec = test_case.outputs()[0]; + + // Get tensor dimensions + auto input_sizes = input_spec.get_tensor_sizes(); // [N, C_in, H_in, W_in] + auto weight_sizes = + weight_spec.get_tensor_sizes(); // [K_h, align_up_4(K_w), OC] + auto output_sizes = + output_spec.get_tensor_sizes(); // [N, C_out, H_out, W_out] + + int64_t N = input_sizes[0]; + int64_t C_in = input_sizes[1]; + int64_t H_in = input_sizes[2]; + int64_t W_in = input_sizes[3]; + int64_t C_out = output_sizes[1]; + int64_t H_out = output_sizes[2]; + int64_t W_out = output_sizes[3]; + + // Get kernel dimensions from kernel_size ValueSpec + auto kernel_size_data = kernel_size_spec.get_int32_data(); + int64_t K_h = kernel_size_data[0]; + int64_t K_w = kernel_size_data[1]; + + // Get stride, padding, dilation, and groups + auto stride_data = stride_spec.get_int32_data(); + auto padding_data = padding_spec.get_int32_data(); + auto dilation_data = dilation_spec.get_int32_data(); + int64_t stride_h = stride_data[0]; + int64_t stride_w = stride_data[1]; + int64_t pad_h = padding_data[0]; + int64_t pad_w = padding_data[1]; + int64_t dilation_h = dilation_data[0]; + int64_t dilation_w = dilation_data[1]; + int64_t groups = groups_spec.get_int_value(); + + // Skip for large tensors since computation time will be extremely slow + if (N > kRefDimSizeLimit || C_in > kRefDimSizeLimit || + H_in > kRefDimSizeLimit || W_in > kRefDimSizeLimit || + C_out > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != 
vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + // Verify this is a depthwise convolution + if (groups != C_out || C_in != C_out) { + throw std::invalid_argument( + "This is not a depthwise convolution configuration"); + } + + // Get raw data pointers + auto& input_data = input_spec.get_float_data(); + const float input_scale = input_scale_spec.get_float_value(); + const int32_t input_zero_point = input_zeros_spec.get_int_value(); + + auto& weight_data = weight_spec.get_int8_data(); + auto& weight_scales_data = weight_scales_spec.get_float_data(); + auto& bias_data = bias_spec.get_float_data(); + + const float output_scale = output_scale_spec.get_float_value(); + const int32_t output_zero_point = output_zeros_spec.get_int_value(); + + // Calculate number of output elements + int64_t num_output_elements = N * C_out * H_out * W_out; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_output_elements); + + // Perform activation, weight, and output quantized depthwise conv2d operation + for (int64_t n = 0; n < N; ++n) { + for (int64_t out_c = 0; out_c < C_out; ++out_c) { + for (int64_t out_h = 0; out_h < H_out; ++out_h) { + for (int64_t out_w = 0; out_w < W_out; ++out_w) { + int32_t int_sum = 0; + int32_t weight_sum = 0; // Track weight sum on the fly + + // For depthwise convolution, each output channel corresponds to one + // input channel + int64_t in_c = out_c; + + // Convolution operation with integer accumulation + for (int64_t kh = 0; kh < K_h; ++kh) { + for (int64_t kw = 0; kw < K_w; ++kw) { + // Calculate input position with dilation + int64_t in_h = out_h * stride_h - pad_h + kh * dilation_h; + int64_t in_w = out_w * stride_w - pad_w + kw * dilation_w; + + // Check bounds (zero padding) + if (in_h >= 0 && in_h < H_in && in_w >= 0 && in_w < W_in) { + // Get input value and quantize to int8 + int64_t input_idx = n * (C_in * H_in * W_in) + + in_c * (H_in * W_in) + in_h * W_in + in_w; + + float quant_input_f = + std::round(input_data[input_idx] / input_scale) + + input_zero_point; + quant_input_f = + std::min(std::max(quant_input_f, -128.0f), 127.0f); + int8_t quantized_input = static_cast(quant_input_f); + + // Get quantized weight using depthwise layout [K_h, K_w, OC] + int64_t weight_idx = kh * (K_w * C_out) + kw * C_out + out_c; + int8_t quantized_weight = weight_data[weight_idx]; + + if (false && in_w == 0 && in_h == 0 && out_c == 0) { + std::cout << "input: " << input_data[input_idx] << std::endl; + std::cout << "quantized_input: " << (int)quantized_input + << std::endl; + std::cout << "quantized_weight: " << (int)quantized_weight + << std::endl; + } + // Integer multiplication and accumulation + int_sum += static_cast(quantized_input) * + static_cast(quantized_weight); + + // Track weight sum for this output channel on the fly + weight_sum += static_cast(quantized_weight); + } else { + // For zero padding, we still need to account for the weight + // in weight_sum when input is effectively 0 (but quantized 0 + // is input_zero_point) + int64_t weight_idx = kh * (K_w * C_out) + kw * C_out + out_c; + int8_t quantized_weight = weight_data[weight_idx]; + + // Add contribution from zero-padded input (quantized zero = + // input_zero_point) + int_sum += static_cast(input_zero_point) * + static_cast(quantized_weight); + + // Track weight sum for this output channel on the fly + weight_sum += static_cast(quantized_weight); + } + } + } + + // Convert accumulated integer result to float and apply scales + // Final result = (int_sum - 
zero_point_correction) * input_scale * + // weight_scale + bias zero_point_correction = input_zero_point * + // sum_of_weights_for_this_output_channel + int32_t zero_point_correction = input_zero_point * weight_sum; + int32_t accum_adjusted = int_sum - zero_point_correction; + float float_result = + accum_adjusted * input_scale * weight_scales_data[out_c]; + + // Add bias and store result + float_result += bias_data[out_c]; + + // Quantize the output to int8 + float quant_output_f = + std::round(float_result / output_scale) + output_zero_point; + quant_output_f = std::min(std::max(quant_output_f, -128.0f), 127.0f); + int8_t quantized_output = static_cast(quant_output_f); + + if (false && out_c < 4 && out_h < 1 && out_w < 4) { + std::cout << "int_sum[" << out_c << ", " << out_h << ", " << out_w + << "] = " << int_sum << ", " << float_result << ", " + << output_scale << ", " << quant_output_f << std::endl; + } + + // Dequantize back to float + float dequant_output = + (static_cast(quantized_output) - output_zero_point) * + output_scale; + + int64_t output_idx = n * (C_out * H_out * W_out) + + out_c * (H_out * W_out) + out_h * W_out + out_w; + ref_data[output_idx] = dequant_output; + } + } + } + } +} + +void reference_impl(TestCase& test_case) { + conv2d_q8ta_q8csw_q8to_dw_reference_impl(test_case); +} + +// Custom FLOP calculator for quantized depthwise conv2d operation +int64_t quantized_conv2d_dw_flop_calculator(const TestCase& test_case) { + int kernel_idx = 9; // kernel_size is at index 9 for q8ta_q8csw_q8to + + // Get input and weight dimensions + const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes(); + const auto& output_sizes = test_case.outputs()[0].get_tensor_sizes(); + + const auto& kernel_sizes = test_case.inputs()[kernel_idx].get_int32_data(); + + int64_t N = input_sizes[0]; + int64_t C_out = output_sizes[1]; + int64_t K_h = kernel_sizes[0]; + int64_t K_w = kernel_sizes[1]; + int64_t H_out = output_sizes[2]; + int64_t W_out = output_sizes[3]; + + // Calculate FLOPs for quantized depthwise conv2d operation + // Each output element requires: + // - K_h * K_w multiply-accumulate operations (only one input channel per + // output channel) + // - Additional operations for quantization/dequantization + int64_t output_elements = N * C_out * H_out * W_out; + int64_t ops_per_output = K_h * K_w; + + int64_t flop = output_elements * ops_per_output; + + return flop; +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); + set_print_latencies(false); + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout + << "Quantized Depthwise Conv2d Operation with Output Quantization Prototyping Framework" + << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = reference_impl; + + // Execute test cases using the new framework with custom FLOP calculator + auto results = execute_test_cases( +#ifdef DEBUG_MODE + generate_quantized_conv2d_dw_easy_cases, +#else + generate_quantized_conv2d_dw_test_cases, +#endif + quantized_conv2d_dw_flop_calculator, + "QuantizedDepthwiseInt8Conv2d", +#ifdef DEBUG_MODE + 0, + 1, +#else + 3, + 10, +#endif + ref_fn); + + return 0; +} diff --git a/backends/vulkan/test/custom_ops/q8ta_q8ta_q8to_add.cpp b/backends/vulkan/test/custom_ops/q8ta_q8ta_q8to_add.cpp new file mode 100644 index 00000000000..eb8e6908060 --- /dev/null +++ b/backends/vulkan/test/custom_ops/q8ta_q8ta_q8to_add.cpp @@ -0,0 +1,258 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. 
+// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include "utils.h" + +using namespace executorch::vulkan::prototyping; + +// Utility function to create a test case for quantized add operation +TestCase create_quantized_add_test_case( + const std::vector& sizes, + utils::StorageType storage_type, + vkapi::ScalarType input_dtype) { + TestCase test_case; + + // Create a descriptive name for the test case + std::string size_str = ""; + for (size_t i = 0; i < sizes.size(); ++i) { + size_str += std::to_string(sizes[i]); + if (i < sizes.size() - 1) + size_str += "x"; + } + + std::string storage_str = + (storage_type == utils::kTexture3D) ? "Texture3D" : "Buffer"; + std::string dtype_str = (input_dtype == vkapi::kFloat) ? "Float" : "Half"; + + std::string test_name = + "QuantizedAdd_" + size_str + "_" + storage_str + "_" + dtype_str; + test_case.set_name(test_name); + + // Set the operator name for the test case + test_case.set_operator_name("et_vk.add_q8ta_q8ta_q8to.test"); + + utils::GPUMemoryLayout io_memory_layout = storage_type == utils::kBuffer + ? utils::kWidthPacked + : utils::kChannelsPacked; + + // Input tensor A (float/half) + ValueSpec input_a( + sizes, input_dtype, storage_type, io_memory_layout, DataGenType::RANDOM); + + // Input tensor B (float/half) + ValueSpec input_b( + sizes, input_dtype, storage_type, io_memory_layout, DataGenType::RANDOM); + + // Quantization parameters for input A + float input_a_scale_val = 0.007843; // 2/255 approximately + ValueSpec input_a_scale(input_a_scale_val); + + int32_t input_a_zero_point_val = 3; + ValueSpec input_a_zero_point(input_a_zero_point_val); + + // Quantization parameters for input B + float input_b_scale_val = 0.009412; // 2.4/255 approximately + ValueSpec input_b_scale(input_b_scale_val); + + int32_t input_b_zero_point_val = -2; + ValueSpec input_b_zero_point(input_b_zero_point_val); + + // Output quantization parameters + float output_scale_val = 0.015686; // 4/255 approximately + ValueSpec output_scale(output_scale_val); + + int32_t output_zero_point_val = 1; + ValueSpec output_zero_point(output_zero_point_val); + + // Alpha parameter + float alpha_val = 1.0f; + ValueSpec alpha(alpha_val); + + // Output tensor (float/half) + ValueSpec output( + sizes, input_dtype, storage_type, io_memory_layout, DataGenType::ZEROS); + + // Add all specs to test case for q8ta_q8ta_q8to add operation + test_case.add_input_spec(input_a); + test_case.add_input_spec(input_b); + test_case.add_input_spec(input_a_scale); + test_case.add_input_spec(input_a_zero_point); + test_case.add_input_spec(input_b_scale); + test_case.add_input_spec(input_b_zero_point); + test_case.add_input_spec(output_scale); + test_case.add_input_spec(output_zero_point); + test_case.add_input_spec(alpha); + + test_case.add_output_spec(output); + + test_case.set_abs_tolerance(output_scale_val + 1e-4f); + + return test_case; +} + +// Generate test cases for quantized add operation +std::vector generate_quantized_add_test_cases() { + std::vector test_cases; + + // Define different input size configurations + std::vector> size_configs = { + {3, 32, 32}, // Small square + {8, 64, 64}, // Medium square + {16, 16, 16}, // 3D cube + {8, 32, 16}, // 3D rectangular + {7, 7, 13}, // Irregular sizes + }; + + // Storage types to test + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; + + // Data types to test + std::vector data_types = 
{vkapi::kFloat}; + + // Generate test cases for each combination + for (const auto& sizes : size_configs) { + for (const auto& storage_type : storage_types) { + for (const auto& data_type : data_types) { + test_cases.push_back( + create_quantized_add_test_case(sizes, storage_type, data_type)); + } + } + } + + return test_cases; +} + +// Reference implementation for quantized add operation +void add_q8ta_q8ta_q8to_reference_impl(TestCase& test_case) { + // Extract input specifications + int32_t idx = 0; + const ValueSpec& input_a_spec = test_case.inputs()[idx++]; + const ValueSpec& input_b_spec = test_case.inputs()[idx++]; + const ValueSpec& input_a_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& input_a_zero_point_spec = test_case.inputs()[idx++]; + const ValueSpec& input_b_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& input_b_zero_point_spec = test_case.inputs()[idx++]; + const ValueSpec& output_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& output_zero_point_spec = test_case.inputs()[idx++]; + const ValueSpec& alpha_spec = test_case.inputs()[idx++]; + + // Extract output specification (mutable reference) + ValueSpec& output_spec = test_case.outputs()[0]; + + // Get tensor dimensions + auto input_sizes = input_a_spec.get_tensor_sizes(); + int64_t num_elements = input_a_spec.numel(); + + if (input_a_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + // Get raw data pointers + auto& input_a_data = input_a_spec.get_float_data(); + auto& input_b_data = input_b_spec.get_float_data(); + + const float input_a_scale = input_a_scale_spec.get_float_value(); + const int32_t input_a_zero_point = input_a_zero_point_spec.get_int_value(); + const float input_b_scale = input_b_scale_spec.get_float_value(); + const int32_t input_b_zero_point = input_b_zero_point_spec.get_int_value(); + const float output_scale = output_scale_spec.get_float_value(); + const int32_t output_zero_point = output_zero_point_spec.get_int_value(); + const float alpha = alpha_spec.get_float_value(); + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_elements); + + // Perform quantized add operation + for (int64_t i = 0; i < num_elements; ++i) { + // Quantize input A to int8 + float quant_a_f = + std::round(input_a_data[i] / input_a_scale) + input_a_zero_point; + quant_a_f = std::min(std::max(quant_a_f, -128.0f), 127.0f); + int8_t quantized_a = static_cast(quant_a_f); + + // Quantize input B to int8 + float quant_b_f = + std::round(input_b_data[i] / input_b_scale) + input_b_zero_point; + quant_b_f = std::min(std::max(quant_b_f, -128.0f), 127.0f); + int8_t quantized_b = static_cast(quant_b_f); + + // Dequantize both inputs to a common scale for addition + float dequant_a = + (static_cast(quantized_a) - input_a_zero_point) * input_a_scale; + float dequant_b = + (static_cast(quantized_b) - input_b_zero_point) * input_b_scale; + + // Perform addition in float space with alpha + float float_result = dequant_a + alpha * dequant_b; + + // Quantize the result to int8 + float quant_output_f = + std::round(float_result / output_scale) + output_zero_point; + quant_output_f = std::min(std::max(quant_output_f, -128.0f), 127.0f); + int8_t quantized_output = static_cast(quant_output_f); + + // Dequantize back to float for comparison + float dequant_output = + (static_cast(quantized_output) - output_zero_point) * + output_scale; + + ref_data[i] = dequant_output; + } +} + +void reference_impl(TestCase& test_case) { + 
add_q8ta_q8ta_q8to_reference_impl(test_case); +} + +// Custom FLOP calculator for quantized add operation +int64_t quantized_add_flop_calculator(const TestCase& test_case) { + // Calculate total elements from the first input tensor + int64_t total_elements = 1; + if (!test_case.empty() && test_case.num_inputs() > 0 && + test_case.inputs()[0].is_tensor()) { + const auto& sizes = test_case.inputs()[0].get_tensor_sizes(); + for (int64_t size : sizes) { + total_elements *= size; + } + } + + // Quantized add operation includes: + // - 2 quantizations (float to int8) + // - 2 dequantizations (int8 to float) + // - 1 addition + // For simplicity, we count this as 1 FLOP per element (the addition) + return total_elements; +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); + set_print_latencies(false); + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout << "Quantized Add Operation (q8ta_q8ta_q8to) Prototyping Framework" + << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = reference_impl; + + // Execute test cases using the new framework with custom FLOP calculator + auto results = execute_test_cases( + generate_quantized_add_test_cases, + quantized_add_flop_calculator, + "QuantizedAddQ8taQ8taQ8to", + 0, + 1, + ref_fn); + + return 0; +} diff --git a/backends/vulkan/test/custom_ops/qdq8ta_conv2d_activations.cpp b/backends/vulkan/test/custom_ops/qdq8ta_conv2d_activations.cpp new file mode 100644 index 00000000000..5275e6c9335 --- /dev/null +++ b/backends/vulkan/test/custom_ops/qdq8ta_conv2d_activations.cpp @@ -0,0 +1,251 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include +#include "utils.h" + +#include + +using namespace executorch::vulkan::prototyping; +using namespace vkcompute; + +static constexpr int64_t kRefDimSizeLimit = 512; + +// QDQ8TA Conv2D configuration struct for 4D tensor quantize-dequantize testing +struct QDQ8TAConv2DConfig { + int64_t batch_size; // N dimension + int64_t in_channels; // C dimension + int64_t height; // H dimension + int64_t width; // W dimension + std::string test_case_name = "placeholder"; + std::string op_name = "qdq8ta_conv2d_input"; +}; + +// Utility function to create a test case from a QDQ8TAConv2DConfig +TestCase create_test_case_from_config( + const QDQ8TAConv2DConfig& config, + utils::StorageType storage_type, + vkapi::ScalarType input_dtype) { + TestCase test_case; + + // Create a descriptive name for the test case + std::string storage_str = + (storage_type == utils::kTexture3D) ? "Texture3D" : "Buffer"; + std::string dtype_str = (input_dtype == vkapi::kFloat) ? "Float" : "Half"; + + std::string test_name = + config.test_case_name + "_" + storage_str + "_" + dtype_str; + test_case.set_name(test_name); + + // Set the operator name for the test case + std::string operator_name = "etvk." 
+ config.op_name + ".default"; + test_case.set_operator_name(operator_name); + + // Input tensor (float) - [N, C, H, W] + std::vector input_size = { + config.batch_size, config.in_channels, config.height, config.width}; + ValueSpec input_tensor( + input_size, + input_dtype, + storage_type, + utils::kChannelsPacked, // Use channels packed for conv2d tensors + DataGenType::RANDOM); + + float scale_val = 0.007112; + ValueSpec scale(scale_val); + + // Generate random zero point within quantization range + int32_t zero_point_val = -2; + ValueSpec zero_point(zero_point_val); + + // Output tensor (float) - same shape as input [N, C, H, W] + ValueSpec output_tensor( + input_size, + input_dtype, + storage_type, + utils::kChannelsPacked, + DataGenType::ZEROS); + + // Add all specs to test case + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(scale); + test_case.add_input_spec(zero_point); + test_case.add_output_spec(output_tensor); + + test_case.set_abs_tolerance(scale_val + 1e-4); + + return test_case; +} + +// Generate easy test cases for qdq8ta_conv2d operation (for debugging) +std::vector generate_qdq8ta_conv2d_easy_cases() { + std::vector test_cases; + + // Single simple configuration for debugging + QDQ8TAConv2DConfig config = { + 1, // batch_size + 3, // in_channels + 4, // height + 4, // width + "simple", // test_case_name + }; + + // Test with both storage types + std::vector storage_types = {utils::kTexture3D}; + std::vector float_types = {vkapi::kFloat}; + + // Generate test cases for each combination + for (const auto& storage_type : storage_types) { + for (const auto& input_dtype : float_types) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, input_dtype)); + } + } + + return test_cases; +} + +// Generate test cases for qdq8ta_conv2d operation +std::vector generate_qdq8ta_conv2d_test_cases() { + std::vector test_cases; + + std::vector configs = { + // Small test cases for correctness + {1, 3, 16, 16}, + {1, 8, 32, 32}, + {1, 16, 24, 24}, + {1, 32, 12, 12}, + {1, 1, 64, 64}, + {1, 3, 64, 64}, + {1, 4, 16, 16}, + + // Different tensor sizes + {1, 8, 20, 20}, + {1, 16, 14, 14}, + {1, 8, 28, 28}, + + // Odd tensor sizes + {1, 3, 15, 15}, + {1, 13, 31, 31}, + {1, 17, 23, 23}, + + // Performance test cases (larger tensors) + {1, 64, 128, 128}, + {1, 32, 64, 64}, + {1, 128, 56, 56}, + }; + + // Test with different storage types + std::vector storage_types = {utils::kTexture3D}; + + for (auto config : configs) { + std::string prefix = + (config.batch_size < kRefDimSizeLimit && + config.in_channels < kRefDimSizeLimit && + config.height < kRefDimSizeLimit && config.width < kRefDimSizeLimit) + ? 
"correctness_" + : "performance_"; + std::string generated_test_case_name = prefix + + std::to_string(config.batch_size) + "_" + + std::to_string(config.in_channels) + "_" + + std::to_string(config.height) + "_" + std::to_string(config.width); + + config.test_case_name = generated_test_case_name; + + for (const auto& storage_type : storage_types) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, vkapi::kFloat)); + } + } + + return test_cases; +} + +// Reference implementation for qdq8ta_conv2d operation +void qdq8ta_conv2d_reference_impl(TestCase& test_case) { + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& scale_spec = test_case.inputs()[idx++]; + const ValueSpec& zero_point_spec = test_case.inputs()[idx++]; + + // Extract output specification + ValueSpec& output_spec = test_case.outputs()[0]; + + // Get tensor dimensions + auto input_sizes = input_spec.get_tensor_sizes(); // [N, C, H, W] + int64_t N = input_sizes[0]; + int64_t C = input_sizes[1]; + int64_t H = input_sizes[2]; + int64_t W = input_sizes[3]; + + // Skip for large tensors since computation time will be extremely slow + if (N > kRefDimSizeLimit || C > kRefDimSizeLimit || H > kRefDimSizeLimit || + W > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions (N, C, H, W) exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + // Get raw data pointers + auto& input_data = input_spec.get_float_data(); + + // Extract the randomized scale and zero point values (following + // q8csw_conv2d.cpp pattern) + float scale = scale_spec.get_float_value(); + int32_t zero_point = zero_point_spec.get_int_value(); + int32_t quant_min = -128; + int32_t quant_max = 127; + + // Prepare output data + auto& ref_data = output_spec.get_ref_float_data(); + int64_t num_elements = N * C * H * W; + ref_data.resize(num_elements); + + // Perform quantize-dequantize operation on each element + for (int64_t i = 0; i < num_elements; ++i) { + float input_val = input_data[i]; + + // Quantize: quantized = round(input / scale + zero_point) + float quantized_float = std::round(input_val / scale) + zero_point; + + // Clamp to quantization range + quantized_float = std::max(quantized_float, static_cast(quant_min)); + quantized_float = std::min(quantized_float, static_cast(quant_max)); + + int32_t quantized_int = static_cast(quantized_float); + + // Dequantize: output = (quantized - zero_point) * scale + float dequantized = (quantized_int - zero_point) * scale; + + ref_data[i] = dequantized; + } +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); + set_print_latencies(false); + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout << "QDQ8TA Conv2D Operation Prototyping Framework" << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = qdq8ta_conv2d_reference_impl; + + auto results = execute_test_cases( + generate_qdq8ta_conv2d_test_cases, "QDQ8TAConv2D", 0, 1, ref_fn); + + return 0; +} diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl index 3162857c2d3..ace56a7cf25 100644 --- a/backends/vulkan/test/custom_ops/targets.bzl +++ b/backends/vulkan/test/custom_ops/targets.bzl @@ -57,12 +57,15 @@ def define_common_targets(is_fbcode = False): name = "prototyping_utils", srcs = [ "utils.cpp", + "conv2d_utils.cpp", ], headers = [ "utils.h", + 
"conv2d_utils.h", ], exported_headers = [ "utils.h", + "conv2d_utils.h", ], platforms = get_platforms(), deps = [ @@ -97,3 +100,7 @@ def define_common_targets(is_fbcode = False): define_custom_op_test_binary("q8csw_conv2d") define_custom_op_test_binary("choose_qparams_per_row") define_custom_op_test_binary("q4gsw_linear") + define_custom_op_test_binary("qdq8ta_conv2d_activations") + define_custom_op_test_binary("q8ta_q8csw_q8to_conv2d") + define_custom_op_test_binary("q8ta_q8csw_q8to_conv2d_dw") + define_custom_op_test_binary("q8ta_q8ta_q8to_add") diff --git a/backends/vulkan/test/custom_ops/utils.cpp b/backends/vulkan/test/custom_ops/utils.cpp index 2aa827a4d5a..79f1dd6b777 100644 --- a/backends/vulkan/test/custom_ops/utils.cpp +++ b/backends/vulkan/test/custom_ops/utils.cpp @@ -661,7 +661,11 @@ float collect_gpu_timing_us(ComputeGraph& graph) { float total_duration_us = 0.0f; for (const auto& shader_result : results) { if (shader_result.kernel_name.find("nchw_to") == std::string::npos && - shader_result.kernel_name.find("to_nchw") == std::string::npos) { + shader_result.kernel_name.find("to_nchw") == std::string::npos && + shader_result.kernel_name.find("quantize_and_pack_4w4c") == + std::string::npos && + shader_result.kernel_name.find("unpack_4w4c_and_dequantize") == + std::string::npos) { // Calculate duration from start and end times, convert from ns to μs uint64_t duration_ns = shader_result.end_time_ns - shader_result.start_time_ns; @@ -1284,7 +1288,7 @@ TestResult execute_test_cases( try { result = execute_test_case(test_case, warmup_runs, benchmark_runs); result.set_operator_name(test_case.operator_name()); - } catch (const vkcompute::vkapi::ShaderNotSupportedError& e) { + } catch (const vkcompute::vkapi::ShaderNotSupportedError&) { result = BenchmarkResult( test_case.name().empty() ? 
"unnamed_test_case" : test_case.name(), test_case.operator_name()); @@ -1715,6 +1719,41 @@ void compute_weight_sums( } } +// Compute weight sums for 4D quantized conv2d operations +// Weight layout: [C_out, K_h, K_w, align_up_4(C_in_per_group)] +void compute_weight_sums_4d( + ValueSpec& weight_sums, + const ValueSpec& quantized_weight, + int64_t out_channels, + int64_t kernel_h, + int64_t kernel_w, + int64_t aligned_in_channels) { + auto& weight_sums_data = weight_sums.get_int32_data(); + auto& quantized_weight_data = quantized_weight.get_int8_data(); + + weight_sums_data.resize(out_channels); + + // For each output channel, compute the sum of quantized weights + for (int64_t out_c = 0; out_c < out_channels; ++out_c) { + int32_t sum = 0; + + for (int64_t kh = 0; kh < kernel_h; ++kh) { + for (int64_t kw = 0; kw < kernel_w; ++kw) { + for (int64_t in_c = 0; in_c < aligned_in_channels; ++in_c) { + // Weight indexing: [out_c, kh, kw, in_c] + int64_t weight_idx = + out_c * (kernel_h * kernel_w * aligned_in_channels) + + kh * (kernel_w * aligned_in_channels) + kw * aligned_in_channels + + in_c; + sum += static_cast(quantized_weight_data[weight_idx]); + } + } + } + + weight_sums_data[out_c] = sum; + } +} + // Helper function to unpack 4-bit values from uint8 (same as in // q4gsw_linear.cpp) std::pair unpack_4bit_utils(uint8_t packed) { diff --git a/backends/vulkan/test/custom_ops/utils.h b/backends/vulkan/test/custom_ops/utils.h index f1736f1d144..b80f28639e8 100644 --- a/backends/vulkan/test/custom_ops/utils.h +++ b/backends/vulkan/test/custom_ops/utils.h @@ -653,6 +653,16 @@ void compute_weight_sums( int64_t out_features, int64_t elements_per_output_feature); +// Compute weight sums for 4D quantized conv2d operations +// Weight layout: [C_out, K_h, K_w, align_up_4(C_in_per_group)] +void compute_weight_sums_4d( + ValueSpec& weight_sums, + const ValueSpec& quantized_weight, + int64_t out_channels, + int64_t kernel_h, + int64_t kernel_w, + int64_t aligned_in_channels); + // Compute weight sums for 4-bit group symmetric quantized weights void compute_weight_sums_4bit_grouped( ValueSpec& weight_sums, diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 8c5d0c4797b..b21a8458a89 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -6,7 +6,6 @@ import itertools - from collections import namedtuple from typing import Callable @@ -1108,21 +1107,84 @@ def get_index_select_inputs(): @register_test_suite("aten.embedding.default") def get_embedding_inputs(): - Test = namedtuple("VkEmbeddingTest", ["weight", "indices"]) + Test = namedtuple("EmbeddingTest", ["weight", "indices"]) Test.__new__.__defaults__ = (None, None) test_cases = [ - Test(weight=[10, 9], indices=[0, 2]), + Test(weight=[10, 9], indices=[3, 5]), Test(weight=[10, 9], indices=[2, 3, 4, 5, 7]), Test(weight=[10, 9], indices=[[0, 2], [1, 4], [7, 7]]), Test(weight=[10, 9], indices=[[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]), - Test(weight=[10, 9], indices=[[[3, 1, 4], [1, 5, 9]], [[2, 6, 5], [3, 5, 8]]]), ] - test_suite = VkTestSuite([tuple(tc) + (-1, "false", "false") for tc in test_cases]) + # Channels packed test cases currently fail on Mac, so they are not included. + # However the test case definition is kept for later debugging. 
+ test_suite_cpack = VkTestSuite( + [tuple(tc) + (-1, "false", "false") for tc in test_cases] + ) + + test_suite_cpack.dtypes = ["at::kFloat"] + test_suite_cpack.layouts = ["utils::kChannelsPacked"] + test_suite_cpack.test_name_suffix = "cpacked" + + test_suite_wpack = VkTestSuite( + [tuple(tc) + (-1, "false", "false") for tc in test_cases] + ) + + test_suite_wpack.dtypes = ["at::kFloat"] + test_suite_wpack.layouts = ["utils::kWidthPacked"] + test_suite_wpack.storage_types = ["utils::kBuffer", "utils::kTexture3D"] + test_suite_wpack.test_name_suffix = "wpacked" + + return test_suite_wpack + + +@register_test_suite("aten.gather.default") +def get_gather_inputs(): + Test = namedtuple("GatherTest", ["input", "dim", "index"]) + Test.__new__.__defaults__ = (None, None, None) + + test_cases = [ + # Simple 2D case + Test(input=[4, 4], dim=1, index=[[1, 2], [2, 1], [3, 3], [3, 1]]), + # # 1D cases + Test(input=[10], dim=0, index=[0, 2, 5, 7, 9]), + Test(input=[8], dim=0, index=[1, 3, 5]), + # # 2D cases with different dims + Test(input=[5, 8], dim=0, index=[[0, 1], [2, 3], [4, 0]]), + Test( + input=[5, 8], + dim=1, + index=[[0, 2, 4], [1, 3, 5], [6, 7, 0], [1, 2, 3], [4, 5, 6]], + ), + # # 3D cases + Test( + input=[3, 4, 5], + dim=0, + index=[ + [[0, 1, 2, 0, 1], [1, 2, 0, 1, 2], [2, 0, 1, 2, 0], [0, 1, 2, 0, 1]] + ], + ), + Test( + input=[3, 4, 5], + dim=1, + index=[ + [[0, 1, 2, 3], [1, 2, 3, 0], [2, 3, 0, 1], [3, 0, 1, 2], [0, 1, 2, 3]] + ], + ), + Test( + input=[3, 4, 5], dim=2, index=[[[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 0]]] + ), + ] + + test_suite = VkTestSuite( + [tuple(tc) + (False, "false", "false") for tc in test_cases] + ) test_suite.dtypes = ["at::kFloat"] - test_suite.layouts = ["utils::kChannelsPacked"] + test_suite.layouts = ["utils::kWidthPacked", "utils::kChannelsPacked"] + test_suite.storage_types = ["utils::kBuffer", "utils::kTexture3D"] + return test_suite @@ -1456,64 +1518,11 @@ def get_split_with_sizes_inputs(): test_suite.layouts = [ "utils::kWidthPacked", - "utils::kHeightPacked", - "utils::kChannelsPacked", - ] - test_suite.data_gen = "make_seq_tensor" - test_suite.dtypes = ["at::kFloat"] - return test_suite - - -@register_test_suite("aten.split.Tensor") -def get_split_tensor_inputs(): - test_suite = VkTestSuite( - [ - # Split on Width - ((S1, 7, 10, 12), 12, 3), - ((S1, 7, 10, 12), 3, 3), - ((S1, 7, 10, 12), 1, 3), - ((7, 10, 12), 12, 2), - ((7, 10, 12), 3, 2), - ((7, 10, 12), 1, 2), - ((10, 12), 12, 1), - ((10, 12), 3, 1), - ((10, 12), 1, 1), - ((12,), 12, 0), - ((12,), 3, 0), - ((12,), 1, 0), - # Split on Height - ((S1, 7, 12, 8), 12, 2), - ((S1, 7, 12, 8), 3, 2), - ((S1, 7, 12, 8), 1, 2), - ((7, 12, 8), 12, 1), - ((7, 12, 8), 3, 1), - ((7, 12, 8), 1, 1), - ((12, 8), 12, 0), - ((12, 8), 3, 0), - ((12, 8), 1, 0), - # Split on Batch - ((12, 7, 10, 10), 12, 0), - ((12, 7, 10, 10), 3, 0), - ((12, 7, 10, 10), 1, 0), - # Split on Channel - ((7, 15, 10, 10), 15, 1), - ((7, 15, 10, 10), 5, 1), - ((7, 15, 10, 10), 3, 1), - ((7, 15, 10, 10), 1, 1), - ((15, 10, 10), 15, 0), - ((15, 10, 10), 5, 0), - ((15, 10, 10), 3, 0), - ((15, 10, 10), 1, 0), - ] - ) - - test_suite.layouts = [ - "utils::kWidthPacked", - "utils::kHeightPacked", "utils::kChannelsPacked", ] test_suite.data_gen = "make_seq_tensor" test_suite.dtypes = ["at::kFloat"] + test_suite.storage_types = ["utils::kBuffer", "utils::kTexture3D"] return test_suite @@ -1543,6 +1552,21 @@ def get_reduce_inputs(is_softmax: bool = False): ] +def get_reduce_per_row_inputs(): + inputs = [ + ((5, 10), 1, False), + ((5, 16), -1, 
True), + ((5, 16), -1, False), + ((7, 21), -1, True), + ((7, 21), -1, False), + ((3, 7, 280), -1, True), + ((3, 7, 280), -1, False), + ((3, 17, 77), -1, True), + ((3, 17, 77), -1, False), + ] + return inputs + + @register_test_suite(["aten._softmax.default", "aten._log_softmax.default"]) def get_softmax_inputs(): test_suite = VkTestSuite(get_reduce_inputs(is_softmax=True)) @@ -1562,6 +1586,20 @@ def get_reduce_op_inputs(): "utils::kChannelsPacked", "utils::kWidthPacked", ] + + per_row_suite = VkTestSuite(get_reduce_per_row_inputs()) + per_row_suite.layouts = ["utils::kWidthPacked"] + per_row_suite.storage_types = ["utils::kBuffer"] + per_row_suite.test_name_suffix = "per_row" + return [test_suite, per_row_suite] + + +@register_test_suite(["aten.argmin.default", "aten.argmax.default"]) +def get_reduce_arg_op_inputs(): + test_suite = VkTestSuite(get_reduce_per_row_inputs()) + test_suite.layouts = ["utils::kWidthPacked"] + test_suite.storage_types = ["utils::kBuffer"] + test_suite.dtypes = ["at::kFloat"] return test_suite @@ -1947,3 +1985,28 @@ def get_where_inputs(): test_suite.atol = "1e-4" test_suite.rtol = "1e-4" return test_suite + + +@register_test_suite("aten.pow.Tensor_Scalar") +def get_pow_tensor_scalar_inputs(): + test_suite = VkTestSuite( + [ + ((M1,), 2.0), + ((M2, M1), 2.0), + ((S1, M1, M2), 0.5), + ((S1, S2, S2, M2), 2.5), + ((S, S1, S2), -1.0), + ((M1, M2), 4.0), + ((S1, S2), 1.5), + ] + ) + test_suite.storage_types = [ + "utils::kBuffer", + "utils::kTexture3D", + ] + test_suite.layouts = [ + "utils::kWidthPacked", + "utils::kChannelsPacked", + ] + test_suite.dtypes = ["at::kFloat"] + return test_suite diff --git a/backends/vulkan/test/op_tests/choose_qparams_test.cpp b/backends/vulkan/test/op_tests/choose_qparams_test.cpp deleted file mode 100644 index 3b1094a1e84..00000000000 --- a/backends/vulkan/test/op_tests/choose_qparams_test.cpp +++ /dev/null @@ -1,786 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include - -#include - -#include -#include -#include - -#include -#include - -#include "test_utils.h" - -#include -#include - -namespace torch { -namespace executor { -namespace native { - -// Forward declarations of the functions we're testing -std::tuple choose_qparams_tensor_out( - const Tensor& input, - int64_t quant_min, - int64_t quant_max, - ET_UNUSED double eps, - ScalarType dtype, - Tensor& scale_out, - Tensor& zero_point_out); - -std::tuple choose_qparams_per_token_asymmetric_out( - const Tensor& input, - ScalarType dtype, - Tensor& scale_out, - Tensor& zero_point_out); - -// Wrapper function for choose_qparams_tensor_out without context -Tensor& choose_qparams_tensor_out_no_context( - const Tensor& input, - int64_t quant_min, - int64_t quant_max, - ET_UNUSED double eps, - ScalarType dtype, - Tensor& scale_out, - Tensor& zero_point_out) { - torch::executor::native::choose_qparams_tensor_out( - input, quant_min, quant_max, eps, dtype, scale_out, zero_point_out); - return scale_out; -} - -// Wrapper function for choose_qparams_per_token_asymmetric_out without context -Tensor& choose_qparams_per_token_asymmetric_out_no_context( - const Tensor& input, - ScalarType dtype, - Tensor& scale_out, - Tensor& zero_point_out) { - torch::executor::native::choose_qparams_per_token_asymmetric_out( - input, dtype, scale_out, zero_point_out); - return scale_out; -} - -// ATen wrapper for choose_qparams_tensor -std::tuple choose_qparams_tensor_aten( - const at::Tensor& input, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - auto scale_out = at::empty({}, at::device(at::kCPU).dtype(at::kDouble)); - auto zero_point_out = at::empty({}, at::device(at::kCPU).dtype(at::kLong)); - double eps = 1e-7; - - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - - // Use WRAP_TO_ATEN with the wrapper function - WRAP_TO_ATEN(choose_qparams_tensor_out_no_context, 5) - (input, quant_min, quant_max, eps, et_dtype, scale_out, zero_point_out); - - return {scale_out, zero_point_out}; -} - -// ATen wrapper for choose_qparams_per_token_asymmetric -std::tuple choose_qparams_per_token_asymmetric_aten( - const at::Tensor& input, - at::ScalarType dtype) { - // Calculate output sizes for scale and zero_point tensors - std::vector output_sizes; - for (int64_t i = 0; i < input.dim() - 1; i++) { - output_sizes.push_back(input.size(i)); - } - output_sizes.push_back(1); - - auto scale_out = - at::empty(output_sizes, at::device(at::kCPU).dtype(at::kDouble)); - auto zero_point_out = - at::empty(output_sizes, at::device(at::kCPU).dtype(at::kLong)); - - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - - // Use WRAP_TO_ATEN with the wrapper function - WRAP_TO_ATEN(choose_qparams_per_token_asymmetric_out_no_context, 2) - (input, et_dtype, scale_out, zero_point_out); - - return {scale_out, zero_point_out}; -} - -} // namespace native -} // namespace executor -} // namespace torch - -// -// Reference Implementation -// - -/* - * Reference implementation of choose_qparams_tensor - */ -std::tuple choose_qparams_tensor_reference_impl( - const at::Tensor& input, - int64_t quant_min, - int64_t quant_max) { - // Create output tensors - at::Tensor scale_out = at::empty({}, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_out = - at::empty({}, at::device(at::kCPU).dtype(at::kLong)); - - // Find min and max values in the input tensor - float min_val = input.min().item(); - float max_val = input.max().item(); - - // Extend the [min, max] interval to ensure it contains 0 - min_val 
= std::min(min_val, 0.f); - max_val = std::max(max_val, 0.f); - - // Calculate scale - double scale = - (static_cast(max_val) - min_val) / (quant_max - quant_min); - - // Handle small scale - constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f; - if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) { - scale = 0.1; - } - - if (scale < SMALL_SCALE_THRESHOLD) { - float org_scale = scale; - scale = SMALL_SCALE_THRESHOLD; - // Adjust min and max based on new scale - if (min_val == 0.0f) { - max_val = SMALL_SCALE_THRESHOLD * (quant_max - quant_min); - } else if (max_val == 0.0f) { - min_val = -SMALL_SCALE_THRESHOLD * (quant_max - quant_min); - } else { - float amplifier = SMALL_SCALE_THRESHOLD / org_scale; - min_val *= amplifier; - max_val *= amplifier; - } - } - - // Calculate zero point - double zero_point_from_min = quant_min - min_val / static_cast(scale); - double zero_point_from_max = quant_max - max_val / static_cast(scale); - double zero_point_from_min_error = - std::abs(quant_min) - std::abs(min_val / static_cast(scale)); - double zero_point_from_max_error = - std::abs(quant_max) - std::abs(max_val / static_cast(scale)); - double initial_zero_point = - zero_point_from_min_error < zero_point_from_max_error - ? zero_point_from_min - : zero_point_from_max; - - // Nudge zero point to be an integer - int64_t nudged_zero_point = 0; - if (initial_zero_point < quant_min) { - nudged_zero_point = quant_min; - } else if (initial_zero_point > quant_max) { - nudged_zero_point = quant_max; - } else { - nudged_zero_point = std::nearbyint(static_cast(initial_zero_point)); - } - - // Set output values - use item_mutable() for scalar tensors - scale_out.fill_(scale); - zero_point_out.fill_(nudged_zero_point); - - return std::make_tuple(scale_out, zero_point_out); -} - -/* - * Reference implementation of choose_qparams_per_token_asymmetric - */ -std::tuple -choose_qparams_per_token_asymmetric_reference_impl( - const at::Tensor& input, - at::ScalarType dtype) { - // For per-token quantization, we need to compute scale and zero_point for - // each token - int64_t quant_min = -128; - int64_t quant_max = 127; - - // Calculate output sizes - std::vector output_sizes; - for (int64_t i = 0; i < input.dim() - 1; i++) { - output_sizes.push_back(input.size(i)); - } - output_sizes.push_back(1); - - // Create output tensors - at::Tensor scale_out = - at::empty(output_sizes, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_out = - at::empty(output_sizes, at::device(at::kCPU).dtype(at::kLong)); - - // Calculate number of tokens - int64_t num_tokens = 1; - for (int64_t i = 0; i < input.dim() - 1; i++) { - num_tokens *= input.size(i); - } - - // Reshape input to [num_tokens, last_dim] - at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); - - // Process each token - for (int64_t token_idx = 0; token_idx < num_tokens; token_idx++) { - at::Tensor token = reshaped_input[token_idx]; - - // Find min and max values for this token - float min_val = token.min().item(); - float max_val = token.max().item(); - - // Extend the [min, max] interval to ensure it contains 0 - min_val = std::min(min_val, 0.f); - max_val = std::max(max_val, 0.f); - - // Calculate scale - double scale = - (static_cast(max_val) - min_val) / (quant_max - quant_min); - - // Handle small scale - constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f; - if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) { - scale = 0.1; - } - - if (scale < SMALL_SCALE_THRESHOLD) { - float org_scale = scale; - scale = 
SMALL_SCALE_THRESHOLD; - // Adjust min and max based on new scale - if (min_val == 0.0f) { - max_val = SMALL_SCALE_THRESHOLD * (quant_max - quant_min); - } else if (max_val == 0.0f) { - min_val = -SMALL_SCALE_THRESHOLD * (quant_max - quant_min); - } else { - float amplifier = SMALL_SCALE_THRESHOLD / org_scale; - min_val *= amplifier; - max_val *= amplifier; - } - } - - // Calculate zero point - double zero_point_from_min = - quant_min - min_val / static_cast(scale); - double zero_point_from_max = - quant_max - max_val / static_cast(scale); - double zero_point_from_min_error = - std::abs(quant_min) - std::abs(min_val / static_cast(scale)); - double zero_point_from_max_error = - std::abs(quant_max) - std::abs(max_val / static_cast(scale)); - double initial_zero_point = - zero_point_from_min_error < zero_point_from_max_error - ? zero_point_from_min - : zero_point_from_max; - - // Nudge zero point to be an integer - int64_t nudged_zero_point = 0; - if (initial_zero_point < quant_min) { - nudged_zero_point = quant_min; - } else if (initial_zero_point > quant_max) { - nudged_zero_point = quant_max; - } else { - nudged_zero_point = - std::nearbyint(static_cast(initial_zero_point)); - } - - // Set output values for this token - use index_put_ for safety - scale_out.view({num_tokens, 1}).index_put_({token_idx, 0}, scale); - zero_point_out.view({num_tokens, 1}) - .index_put_({token_idx, 0}, nudged_zero_point); - } - - return std::make_tuple(scale_out, zero_point_out); -} - -// Forward declaration of implementation functions -void test_vulkan_choose_qparams_tensor_impl( - const std::vector& input_sizes, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage); - -void test_vulkan_choose_qparams_per_token_asymmetric_impl( - const std::vector& input_sizes, - at::ScalarType dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage); - -// Wrapper function to test both buffer and texture storage types -void test_vulkan_choose_qparams_tensor( - const std::vector& input_sizes, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - // Test with buffer storage - test_vulkan_choose_qparams_tensor_impl( - input_sizes, - quant_min, - quant_max, - dtype, - vkcompute::utils::kBuffer, - vkcompute::utils::kBuffer); - - // Test with texture storage - test_vulkan_choose_qparams_tensor_impl( - input_sizes, - quant_min, - quant_max, - dtype, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - -// Wrapper function to test both buffer and texture storage types -void test_vulkan_choose_qparams_per_token_asymmetric( - const std::vector& input_sizes, - at::ScalarType dtype) { - // Test with buffer storage - test_vulkan_choose_qparams_per_token_asymmetric_impl( - input_sizes, dtype, vkcompute::utils::kBuffer, vkcompute::utils::kBuffer); - - // Test with texture storage - test_vulkan_choose_qparams_per_token_asymmetric_impl( - input_sizes, - dtype, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - -void test_reference_choose_qparams_tensor( - const std::vector& input_sizes, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); - - // Get reference output - auto [reference_scale, reference_zero_point] = - 
choose_qparams_tensor_reference_impl(input, quant_min, quant_max); - - // Get implementation output - auto [impl_scale, impl_zero_point] = - torch::executor::native::choose_qparams_tensor_aten( - input, quant_min, quant_max, dtype); - - // Compare outputs - const bool scale_correct = at::allclose(reference_scale, impl_scale); - const bool zero_point_correct = - at::equal(reference_zero_point, impl_zero_point); - - if (!scale_correct || !zero_point_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference scale:" << std::endl; - std::cout << reference_scale << std::endl; - std::cout << "implementation scale:" << std::endl; - std::cout << impl_scale << std::endl; - std::cout << "reference zero_point:" << std::endl; - std::cout << reference_zero_point << std::endl; - std::cout << "implementation zero_point:" << std::endl; - std::cout << impl_zero_point << std::endl; - } - - ASSERT_TRUE(scale_correct && zero_point_correct); -} - -void test_vulkan_choose_qparams_tensor_impl( - const std::vector& input_sizes, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage) { - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); - - // Get reference output - auto [reference_scale, reference_zero_point] = - torch::executor::native::choose_qparams_tensor_aten( - input, quant_min, quant_max, dtype); - - // Build Vulkan choose_qparams_tensor graph - using namespace vkcompute; - - GraphConfig config; - config.set_storage_type_override(in_storage); - ComputeGraph graph(config); - - IOValueRef r_input = graph.add_input_tensor( - input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); - - const ValueRef r_quant_min = graph.add_scalar(quant_min); - const ValueRef r_quant_max = graph.add_scalar(quant_max); - - // Output tensors - const ValueRef r_scale = graph.add_tensor({}, vkapi::kFloat, out_storage); - const ValueRef r_zero_point = graph.add_tensor({}, vkapi::kInt, out_storage); - - // Create output tuple - const ValueRef r_out_tuple = graph.add_value_list({r_scale, r_zero_point}); - - // Add eps and dtype parameters to match ATen signature - const ValueRef r_eps = graph.add_scalar(6.1e-5); - const ValueRef r_dtype = - graph.add_scalar(static_cast(dtype)); - - VK_GET_OP_FN("quantized_decomposed.choose_qparams.tensor") - (graph, - { - r_input.value, - r_quant_min, - r_quant_max, - r_eps, - r_dtype, - r_out_tuple, - }); - - ValueRef staging_scale = graph.set_output_tensor(r_scale); - ValueRef staging_zero_point = graph.set_output_tensor(r_zero_point); - - graph.prepare(); - - graph.prepack(); - - // Run Vulkan choose_qparams_tensor - graph.copy_into_staging( - r_input.staging, input.const_data_ptr(), input.numel()); - - graph.execute(); - - // Create output tensors to hold the results - use types that match GPU output - at::Tensor vk_scale = - at::empty({}, at::device(at::kCPU).dtype(at::kFloat)).contiguous(); - at::Tensor vk_zero_point = - at::empty({}, at::device(at::kCPU).dtype(at::kInt)).contiguous(); - - // Copy results from GPU to CPU - graph.copy_from_staging( - staging_scale, vk_scale.mutable_data_ptr(), vk_scale.numel()); - 
graph.copy_from_staging( - staging_zero_point, - vk_zero_point.mutable_data_ptr(), - vk_zero_point.numel()); - - // Convert reference values to match Vulkan output types for comparison - at::Tensor reference_scale_float = reference_scale.to(at::kFloat); - at::Tensor reference_zero_point_int = reference_zero_point.to(at::kInt); - - // Compare outputs - const bool scale_correct = at::allclose(reference_scale_float, vk_scale); - const bool zero_point_correct = - at::equal(reference_zero_point_int, vk_zero_point); - - if (!scale_correct || !zero_point_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " storage type: " - << (in_storage == vkcompute::utils::kBuffer ? "buffer" - : "texture") - << std::endl; - - // make sure that there arent a ton of elements in the input tensor - if (input.numel() < 100) { - std::cout << "input:" << std::endl; - std::cout << input << "\n" << std::endl; - std::cout << "reference scale:" << std::endl; - std::cout << reference_scale << std::endl; - std::cout << "vulkan scale:" << std::endl; - std::cout << vk_scale << "\n" << std::endl; - std::cout << "reference zero_point:" << std::endl; - std::cout << reference_zero_point << std::endl; - std::cout << "vulkan zero_point:" << std::endl; - std::cout << vk_zero_point << std::endl; - } - } - - ASSERT_TRUE(scale_correct && zero_point_correct); -} - -TEST(VulkanChooseQparamsTest, test_reference_choose_qparams_tensor_int8) { - test_reference_choose_qparams_tensor( - {2, 3, 4}, // input sizes - -128, // quant_min - 127, // quant_max - at::kChar); -} - -TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_uint8_4D) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_choose_qparams_tensor( - {5, 3, 2, 4}, // input sizes - 0, // quant_min - 255, // quant_max - at::kByte); -} - -TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_int8_2D) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_choose_qparams_tensor( - {5, 5}, // input sizes - -128, // quant_min - 127, // quant_max - at::kChar); -} - -TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_int8_3D) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_choose_qparams_tensor( - {12, 8, 2}, // input sizes - -128, // quant_min - 127, // quant_max - at::kChar); -} - -TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_int8_4D) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_choose_qparams_tensor( - {10, 10, 6, 4}, // input sizes - -128, // quant_min - 127, // quant_max - at::kChar); -} - -void test_reference_choose_qparams_per_token_asymmetric( - const std::vector& input_sizes, - at::ScalarType dtype) { - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); - - // Get reference output - auto [reference_scale, reference_zero_point] = - choose_qparams_per_token_asymmetric_reference_impl(input, dtype); - - // Get implementation output - auto [impl_scale, impl_zero_point] = - torch::executor::native::choose_qparams_per_token_asymmetric_aten( - input, dtype); - - // Compare 
outputs - const bool scale_correct = at::allclose(reference_scale, impl_scale); - const bool zero_point_correct = - at::equal(reference_zero_point, impl_zero_point); - - if (!scale_correct || !zero_point_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference scale:" << std::endl; - std::cout << reference_scale << std::endl; - std::cout << "implementation scale:" << std::endl; - std::cout << impl_scale << std::endl; - std::cout << "reference zero_point:" << std::endl; - std::cout << reference_zero_point << std::endl; - std::cout << "implementation zero_point:" << std::endl; - std::cout << impl_zero_point << std::endl; - } - - ASSERT_TRUE(scale_correct && zero_point_correct); -} - -void test_vulkan_choose_qparams_per_token_asymmetric_impl( - const std::vector& input_sizes, - at::ScalarType dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage) { - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); - - // Calculate output sizes - std::vector output_sizes; - for (int64_t i = 0; i < input.dim() - 1; i++) { - output_sizes.push_back(input.size(i)); - } - output_sizes.push_back(1); - - // Get reference output - auto [reference_scale, reference_zero_point] = - torch::executor::native::choose_qparams_per_token_asymmetric_aten( - input, dtype); - - // Build Vulkan choose_qparams_per_token_asymmetric graph - using namespace vkcompute; - - GraphConfig config; - config.set_storage_type_override(in_storage); - ComputeGraph graph(config); - - IOValueRef r_input = graph.add_input_tensor( - input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); - - // Output tensors - const ValueRef r_scale = - graph.add_tensor(output_sizes, vkapi::kFloat, out_storage); - const ValueRef r_zero_point = - graph.add_tensor(output_sizes, vkapi::kInt, out_storage); - - // Create output tuple - const ValueRef r_out_tuple = graph.add_value_list({r_scale, r_zero_point}); - - // Add dtype parameter to match ATen signature - const ValueRef r_dtype = - graph.add_scalar(static_cast(dtype)); - - VK_GET_OP_FN( - "quantized_decomposed.choose_qparams_per_token_asymmetric.default") - (graph, - { - r_input.value, - r_dtype, - r_out_tuple, - }); - - ValueRef staging_scale = graph.set_output_tensor(r_scale); - ValueRef staging_zero_point = graph.set_output_tensor(r_zero_point); - - graph.prepare(); - - graph.prepack(); - - // Run Vulkan choose_qparams_per_token_asymmetric - graph.copy_into_staging( - r_input.staging, input.const_data_ptr(), input.numel()); - - graph.execute(); - - // Create output tensors to hold the results - use types that match GPU output - at::Tensor vk_scale = - at::empty(output_sizes, at::device(at::kCPU).dtype(at::kFloat)) - .contiguous(); - at::Tensor vk_zero_point = - at::empty(output_sizes, at::device(at::kCPU).dtype(at::kInt)) - .contiguous(); - - // Copy results from GPU to CPU - graph.copy_from_staging( - staging_scale, vk_scale.mutable_data_ptr(), vk_scale.numel()); - graph.copy_from_staging( - staging_zero_point, - vk_zero_point.mutable_data_ptr(), - vk_zero_point.numel()); - - // Convert reference values to match Vulkan output types for comparison - at::Tensor reference_scale_float = reference_scale.to(at::kFloat); - at::Tensor reference_zero_point_int = reference_zero_point.to(at::kInt); - - // Compare 
outputs - const bool scale_correct = at::allclose(reference_scale_float, vk_scale); - const bool zero_point_correct = - at::equal(reference_zero_point_int, vk_zero_point); - if (!scale_correct || !zero_point_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " storage type: " - << (in_storage == vkcompute::utils::kBuffer ? "buffer" - : "texture") - << std::endl; - - if (input.numel() < 100) { - std::cout << "input:" << std::endl; - std::cout << input << "\n" << std::endl; - std::cout << "reference scale:" << std::endl; - std::cout << reference_scale << std::endl; - std::cout << "vulkan scale:" << std::endl; - std::cout << vk_scale << "\n" << std::endl; - std::cout << "reference zero_point:" << std::endl; - std::cout << reference_zero_point << std::endl; - std::cout << "vulkan zero_point:" << std::endl; - std::cout << vk_zero_point << std::endl; - } - } - - ASSERT_TRUE(scale_correct && zero_point_correct); -} - -TEST( - VulkanChooseQparamsTest, - test_reference_choose_qparams_per_token_asymmetric_int8) { - test_reference_choose_qparams_per_token_asymmetric( - {2, 3, 4}, // input sizes (2*3=6 tokens) - at::kChar); -} - -TEST( - VulkanChooseQparamsTest, - test_vulkan_choose_qparams_per_token_asymmetric_int8_1D) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_choose_qparams_per_token_asymmetric({7}, at::kChar); -} - -TEST( - VulkanChooseQparamsTest, - test_vulkan_choose_qparams_per_token_asymmetric_int8_2D) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_choose_qparams_per_token_asymmetric({2, 2}, at::kChar); -} - -TEST( - VulkanChooseQparamsTest, - test_vulkan_choose_qparams_per_token_asymmetric_int8_3D) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_choose_qparams_per_token_asymmetric({3, 6, 4}, at::kChar); -} - -TEST( - VulkanChooseQparamsTest, - test_vulkan_choose_qparams_per_token_asymmetric_int8_4D) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_choose_qparams_per_token_asymmetric({128, 2, 16, 3}, at::kChar); -} diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp deleted file mode 100644 index 9fca2c632d3..00000000000 --- a/backends/vulkan/test/op_tests/dequantize_test.cpp +++ /dev/null @@ -1,2492 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
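The reference implementation removed above selects quantization parameters by extending the observed [min, max] range to include zero, deriving a scale from that range, then choosing whichever candidate zero point introduces the smaller error and nudging it onto an integer inside [quant_min, quant_max]. The following is a compact Python sketch of that logic for orientation (it omits the small-scale min/max re-amplification branch for brevity and is not the ExecuTorch kernel itself):

def choose_qparams_asymmetric(values, quant_min=-128, quant_max=127, eps=6.1e-5):
    # Extend the observed range so it always contains zero.
    min_val = min(min(values), 0.0)
    max_val = max(max(values), 0.0)

    scale = (max_val - min_val) / (quant_max - quant_min)
    if scale == 0.0:
        scale = 0.1  # all-zero input fallback, mirroring the reference
    scale = max(scale, eps)  # clamp very small scales

    # Pick the candidate zero point with the smaller error, then nudge it
    # onto an integer inside [quant_min, quant_max].
    zp_from_min = quant_min - min_val / scale
    zp_from_max = quant_max - max_val / scale
    err_min = abs(quant_min) - abs(min_val / scale)
    err_max = abs(quant_max) - abs(max_val / scale)
    initial_zp = zp_from_min if err_min < err_max else zp_from_max

    zero_point = int(round(min(max(initial_zp, quant_min), quant_max)))
    return scale, zero_point

# Example: an input spanning [-0.3, 1.2] yields scale ~0.0059 and zero_point -77.
print(choose_qparams_asymmetric([-0.3, 0.0, 0.7, 1.2]))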
- */ - -#include - -#include - -#include -#include -#include - -#include -#include - -#include "test_utils.h" - -#include -#include -#include -#include - -namespace torch { -namespace executor { -namespace native { - -// Forward declarations of the functions we're testing -Tensor& dequantize_per_tensor_out( - const Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - executorch::aten::optional out_dtype, - Tensor& out); - -Tensor& dequantize_per_token_out( - const Tensor& input, - const Tensor& scale, - const Tensor& zero_points, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - ScalarType out_dtype, - Tensor& out); - -Tensor& dequantize_per_channel_out( - const Tensor& input, - const Tensor& scale, - const std::optional& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - executorch::aten::optional out_dtype, - Tensor& out); - -Tensor& dequantize_per_tensor_tensor_args_out( - const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - executorch::aten::optional out_dtype, - Tensor& out); - -// Wrapper function for dequantize_per_tensor_out without context -Tensor& dequantize_per_tensor_out_no_context( - const Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - executorch::aten::optional out_dtype, - Tensor& out) { - return torch::executor::native::dequantize_per_tensor_out( - input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out); -} - -// Wrapper function for dequantize_per_token_out without context -Tensor& dequantize_per_token_out_no_context( - const Tensor& input, - const Tensor& scale, - const Tensor& zero_points, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - ScalarType out_dtype, - Tensor& out) { - return torch::executor::native::dequantize_per_token_out( - input, scale, zero_points, quant_min, quant_max, dtype, out_dtype, out); -} - -// Wrapper function for dequantize_per_channel_out without context -Tensor& dequantize_per_channel_out_no_context( - const Tensor& input, - const Tensor& scale, - const std::optional& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - executorch::aten::optional out_dtype, - Tensor& out) { - return torch::executor::native::dequantize_per_channel_out( - input, - scale, - zero_points, - axis, - quant_min, - quant_max, - dtype, - out_dtype, - out); -} - -// Wrapper function for dequantize_per_tensor_tensor_args_out without context -Tensor& dequantize_per_tensor_tensor_args_out_no_context( - const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - executorch::aten::optional out_dtype, - Tensor& out) { - return torch::executor::native::dequantize_per_tensor_tensor_args_out( - input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out); -} - -// ATen wrapper for dequantize_per_tensor -at::Tensor dequantize_per_tensor_aten( - const at::Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - auto out = at::empty_like(input, out_dtype); - // Convert at::ScalarType to executorch::ScalarType - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); - - executorch::aten::optional 
opt_et_out_dtype(et_out_dtype); - - WRAP_TO_ATEN(dequantize_per_tensor_out_no_context, 7) - (input, - scale, - zero_point, - quant_min, - quant_max, - et_dtype, - opt_et_out_dtype, - out); - return out; -} - -// ATen wrapper for dequantize_per_token -at::Tensor dequantize_per_token_aten( - const at::Tensor& input, - const at::Tensor& scale, - const at::Tensor& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - auto out = at::empty_like(input, out_dtype); - // Convert at::ScalarType to executorch::ScalarType - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); - - WRAP_TO_ATEN(dequantize_per_token_out_no_context, 7) - (input, - scale, - zero_points, - quant_min, - quant_max, - et_dtype, - et_out_dtype, - out); - return out; -} - -// ATen wrapper for dequantize_per_channel -at::Tensor dequantize_per_channel_aten( - const at::Tensor& input, - const at::Tensor& scale, - const std::optional& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - auto out = at::empty_like(input, out_dtype); - // Convert at::ScalarType to executorch::ScalarType - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); - - executorch::aten::optional opt_et_out_dtype(et_out_dtype); - - WRAP_TO_ATEN(dequantize_per_channel_out_no_context, 8) - (input, - scale, - zero_points, - axis, - quant_min, - quant_max, - et_dtype, - opt_et_out_dtype, - out); - return out; -} - -// ATen wrapper for dequantize_per_tensor with tensor args -at::Tensor dequantize_per_tensor_tensor_args_aten( - const at::Tensor& input, - const at::Tensor& scale, - const at::Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - auto out = at::empty_like(input, out_dtype); - // Convert at::ScalarType to executorch::ScalarType - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); - - executorch::aten::optional opt_et_out_dtype(et_out_dtype); - - WRAP_TO_ATEN(dequantize_per_tensor_tensor_args_out_no_context, 7) - (input, - scale, - zero_point, - quant_min, - quant_max, - et_dtype, - opt_et_out_dtype, - out); - return out; -} - -} // namespace native -} // namespace executor -} // namespace torch - -void check_dequantize_args( - int64_t quant_min, - int64_t quant_max, - c10::ScalarType in_dtype, - c10::ScalarType out_dtype) { - using namespace vkcompute; - - // Check that quant_min <= quant_max - VK_CHECK_COND( - quant_min <= quant_max, - "quant_min must be <= quant_max, got quant_min: ", - quant_min, - " quant_max: ", - quant_max); - - // Check that input dtype is a quantized type - switch (in_dtype) { - case c10::kByte: - case c10::kChar: - case c10::kShort: - case c10::kInt: - case c10::kLong: - break; - default: - VK_THROW( - "Unsupported input dtype: ", - scalar_type_name(in_dtype), - " (", - static_cast(in_dtype), - ")"); - } - - // Check that output dtype is a floating point type - switch (out_dtype) { - case c10::kHalf: - case c10::kFloat: - case c10::kDouble: - break; - default: - VK_THROW( - "Unsupported output dtype: ", - scalar_type_name(out_dtype), - " (", - static_cast(out_dtype), - ")"); - } -} - -/** - * Helper function to validate dequantize_per_channel arguments - * Similar to the validation in quantize_test.cpp 
- */ -void check_dequantize_per_channel_args( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t axis) { - // Normalize axis - int64_t normalized_axis = axis; - if (normalized_axis < 0) { - normalized_axis += input_sizes.size(); - } - - ASSERT_GE(normalized_axis, 0) - << "axis " << axis << " is not legal, normalized axis " << normalized_axis - << " should be >= 0"; - - ASSERT_LT(normalized_axis, static_cast(input_sizes.size())) - << "axis " << axis << " is not legal, normalized axis " << normalized_axis - << " should be < input.dim() " << input_sizes.size(); - - int64_t num_channels = input_sizes[normalized_axis]; - - ASSERT_EQ(num_channels, static_cast(scales.size())) - << "Expected scales.size() to match input.size(axis) (" << num_channels - << "), but got " << scales.size(); - - ASSERT_EQ(num_channels, static_cast(zero_points.size())) - << "Expected zero_points.size() to match input.size(axis) (" - << num_channels << "), but got " << zero_points.size(); -} - -// -// Reference Implementation -// - -/* - * Reference implementation of dequantize_per_tensor - */ -at::Tensor dequantize_per_tensor_reference_impl( - const at::Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - // Create output tensor with the target dtype - at::Tensor out = at::empty_like(input, out_dtype); - - // Dequantize the input tensor - at::Tensor flat_input = input.flatten(); - at::Tensor flat_out = out.flatten(); - - // Store casted values to avoid repeated casting - const int32_t zero_point_int32 = static_cast(zero_point); - const float scale_float = static_cast(scale); - - for (int i = 0; i < flat_input.numel(); i++) { - double dequantized_value = 0.0; - - // Extract quantized value and dequantize based on input dtype - // Following the CPU implementation pattern: (input - zero_point) * scale - if (dtype == at::kByte) { - uint8_t qvalue = flat_input[i].item(); - dequantized_value = (qvalue - zero_point_int32) * scale_float; - } else if (dtype == at::kChar) { - int8_t qvalue = flat_input[i].item(); - dequantized_value = (qvalue - zero_point_int32) * scale_float; - } else if (dtype == at::kShort) { - int16_t qvalue = flat_input[i].item(); - dequantized_value = (qvalue - zero_point_int32) * scale_float; - } else if (dtype == at::kInt) { - int32_t qvalue = flat_input[i].item(); - dequantized_value = (qvalue - zero_point_int32) * scale_float; - } else if (dtype == at::kLong) { - int64_t qvalue = flat_input[i].item(); - dequantized_value = (qvalue - zero_point_int32) * scale_float; - } - - // Store result based on output dtype - if (out_dtype == at::kFloat) { - flat_out[i] = static_cast(dequantized_value); - } else if (out_dtype == at::kDouble) { - flat_out[i] = dequantized_value; - } else if (out_dtype == at::kHalf) { - flat_out[i] = static_cast(dequantized_value); - } - } - - return out.reshape(input.sizes()); -} - -/* - * Reference implementation of dequantize_per_token - */ -at::Tensor dequantize_per_token_reference_impl( - const at::Tensor& input, - const at::Tensor& scale, - const at::Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - // Create output tensor with the target dtype - at::Tensor out = at::empty_like(input, out_dtype); - - // Calculate number of tokens - int num_tokens = 1; - for (int i = 0; i < input.dim() - 1; i++) { - num_tokens *= input.size(i); - } - - // Verify that the number of 
tokens matches the size of scale and zero_point - // tensors - assert(num_tokens == scale.numel()); - assert(num_tokens == zero_point.numel()); - - // Reshape input to [num_tokens, last_dim] - at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); - at::Tensor reshaped_out = out.reshape({num_tokens, input.size(-1)}); - - // Dequantize each token separately - for (int token_idx = 0; token_idx < num_tokens; token_idx++) { - // Get scale and zero_point for this token - float token_scale = scale[token_idx].item(); - int64_t token_zero_point = zero_point[token_idx].item(); - - // Store casted values to avoid repeated casting - const int32_t token_zero_point_int32 = - static_cast(token_zero_point); - - // Dequantize the token - for (int i = 0; i < input.size(-1); i++) { - double dequantized_value = 0.0; - - // Extract quantized value and dequantize based on input dtype - // Following the CPU implementation pattern: (input - zero_point) * scale - if (dtype == at::kByte) { - uint8_t qvalue = reshaped_input[token_idx][i].item(); - dequantized_value = (qvalue - token_zero_point_int32) * token_scale; - } else if (dtype == at::kChar) { - int8_t qvalue = reshaped_input[token_idx][i].item(); - dequantized_value = (qvalue - token_zero_point_int32) * token_scale; - } else if (dtype == at::kShort) { - int16_t qvalue = reshaped_input[token_idx][i].item(); - dequantized_value = (qvalue - token_zero_point_int32) * token_scale; - } else if (dtype == at::kInt) { - int32_t qvalue = reshaped_input[token_idx][i].item(); - dequantized_value = (qvalue - token_zero_point_int32) * token_scale; - } else if (dtype == at::kLong) { - int64_t qvalue = reshaped_input[token_idx][i].item(); - dequantized_value = (qvalue - token_zero_point_int32) * token_scale; - } else { - throw std::runtime_error("Unsupported input dtype"); - } - - // Store result based on output dtype - if (out_dtype == at::kFloat) { - reshaped_out[token_idx][i] = static_cast(dequantized_value); - } else if (out_dtype == at::kDouble) { - reshaped_out[token_idx][i] = dequantized_value; - } else if (out_dtype == at::kHalf) { - reshaped_out[token_idx][i] = static_cast(dequantized_value); - } - } - } - - return out; -} - -/* - * Reference implementation of dequantize_per_channel - */ -at::Tensor dequantize_per_channel_reference_impl( - const at::Tensor& input, - const at::Tensor& scale, - const std::optional& zero_point, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - // Normalize axis to handle negative values - int64_t normalized_axis = axis; - if (normalized_axis < 0) { - normalized_axis += input.dim(); - } - - // Create output tensor with the same shape as input but with target dtype - at::Tensor output = at::empty_like(input, out_dtype); - - // Get the number of channels along the quantization axis - int64_t num_channels = input.size(normalized_axis); - - // Calculate strides for efficient indexing - std::vector input_strides; - std::vector input_sizes; - for (int64_t i = 0; i < input.dim(); i++) { - input_sizes.push_back(input.size(i)); - input_strides.push_back(input.stride(i)); - } - - // Get data pointers - const double* scale_data = scale.const_data_ptr(); - const int64_t* zero_point_data = nullptr; - if (zero_point.has_value()) { - zero_point_data = zero_point.value().const_data_ptr(); - } - - // Iterate through all elements in the tensor - int64_t total_elements = input.numel(); - - // Helper lambda to convert flat index to multi-dimensional coordinates - auto 
flat_to_coords = [&](int64_t flat_idx, std::vector& coords) { - int64_t remaining = flat_idx; - for (int64_t dim = input.dim() - 1; dim >= 0; dim--) { - coords[dim] = remaining % input_sizes[dim]; - remaining /= input_sizes[dim]; - } - }; - - // Process each element - std::vector coords(input.dim()); - for (int64_t flat_idx = 0; flat_idx < total_elements; flat_idx++) { - // Convert flat index to coordinates - flat_to_coords(flat_idx, coords); - - // Get the channel index for this element - int64_t channel_idx = coords[normalized_axis]; - - // Get the quantization parameters for this channel - double channel_scale = scale_data[channel_idx]; - int64_t channel_zero_point = 0; - if (zero_point_data != nullptr) { - channel_zero_point = zero_point_data[channel_idx]; - } - - // Store casted values to avoid repeated casting - const int32_t channel_zero_point_int32 = - static_cast(channel_zero_point); - const float channel_scale_float = static_cast(channel_scale); - - // Get the input value and dequantize - double dequantized_value = 0.0; - - // Extract quantized value and dequantize based on input dtype - // Following the CPU implementation pattern: (input - zero_point) * scale - if (dtype == at::kByte) { - uint8_t qvalue = input.flatten()[flat_idx].item(); - dequantized_value = - (qvalue - channel_zero_point_int32) * channel_scale_float; - } else if (dtype == at::kChar) { - int8_t qvalue = input.flatten()[flat_idx].item(); - dequantized_value = - (qvalue - channel_zero_point_int32) * channel_scale_float; - } else if (dtype == at::kShort) { - int16_t qvalue = input.flatten()[flat_idx].item(); - dequantized_value = - (qvalue - channel_zero_point_int32) * channel_scale_float; - } else if (dtype == at::kInt) { - int32_t qvalue = input.flatten()[flat_idx].item(); - dequantized_value = - (qvalue - channel_zero_point_int32) * channel_scale_float; - } else if (dtype == at::kLong) { - int64_t qvalue = input.flatten()[flat_idx].item(); - dequantized_value = - (qvalue - channel_zero_point_int32) * channel_scale_float; - } else { - throw std::runtime_error("Unsupported input dtype"); - } - - // Store the result based on output dtype - if (out_dtype == at::kFloat) { - output.flatten()[flat_idx] = static_cast(dequantized_value); - } else if (out_dtype == at::kDouble) { - output.flatten()[flat_idx] = dequantized_value; - } else if (out_dtype == at::kHalf) { - output.flatten()[flat_idx] = static_cast(dequantized_value); - } - } - - return output; -} - -// Forward declaration of implementation functions -void test_vulkan_dequantize_per_token_impl( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage); - -void test_vulkan_dequantize_per_channel_impl( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage); - -void test_vulkan_dequantize_per_tensor_tensor_impl( - const std::vector& input_sizes, - float scale, - int zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage); - -// Wrapper 
function to test both buffer and texture storage types -void test_vulkan_dequantize_per_token( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - // Test with buffer storage - test_vulkan_dequantize_per_token_impl( - input_sizes, - scales, - zero_points, - quant_min, - quant_max, - dtype, - out_dtype, - vkcompute::utils::kBuffer, - vkcompute::utils::kBuffer); - - // Telling the system to expect a float instead of a double - // since the shader can only return 32bit anyways - if (out_dtype == at::kDouble) { - out_dtype = at::kFloat; - } - - // Test with texture storage - test_vulkan_dequantize_per_token_impl( - input_sizes, - scales, - zero_points, - quant_min, - quant_max, - dtype, - out_dtype, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - -// Wrapper function to test both buffer and texture storage types -void test_vulkan_dequantize_per_channel( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - // Test with buffer storage - test_vulkan_dequantize_per_channel_impl( - input_sizes, - scales, - zero_points, - axis, - quant_min, - quant_max, - dtype, - out_dtype, - vkcompute::utils::kBuffer, - vkcompute::utils::kBuffer); - - // Telling the system to expect a float instead of a double - // since the shader can only return 32bit anyways - if (out_dtype == at::kDouble) { - out_dtype = at::kFloat; - } - - // Test with texture storage - test_vulkan_dequantize_per_channel_impl( - input_sizes, - scales, - zero_points, - axis, - quant_min, - quant_max, - dtype, - out_dtype, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - -// Wrapper function to test both buffer and texture storage types -void test_vulkan_dequantize_per_tensor_tensor( - const std::vector& input_sizes, - float scale, - int zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - // Test with buffer storage - test_vulkan_dequantize_per_tensor_tensor_impl( - input_sizes, - scale, - zero_point, - quant_min, - quant_max, - dtype, - out_dtype, - vkcompute::utils::kBuffer, - vkcompute::utils::kBuffer); - - // Telling the system to expect a float instead of a double - // since the shader can only return 32bit anyways - if (out_dtype == at::kDouble) { - out_dtype = at::kFloat; - } - - // Test with texture storage - test_vulkan_dequantize_per_tensor_tensor_impl( - input_sizes, - scale, - zero_point, - quant_min, - quant_max, - dtype, - out_dtype, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - -void test_reference_dequantize_per_tensor( - const std::vector& input_sizes, - float scale, - int zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - check_dequantize_args(quant_min, quant_max, dtype, out_dtype); - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - - // Create a quantized input tensor with values from quant_min to quant_max - at::Tensor input; - if (dtype == at::kByte) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); - } else if (dtype == at::kChar) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); - } else if (dtype == at::kShort) { - input = - at::zeros(input_sizes_int64, 
at::device(at::kCPU).dtype(at::kShort)); - } else if (dtype == at::kInt) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); - } else { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); - } - - // Fill with a simple pattern: values from quant_min to quant_max in steps - float step = 1.0f; - if (input.numel() > 1) { - step = static_cast(quant_max - quant_min) / (input.numel() - 1); - } - - auto flat_input = input.flatten(); - for (int i = 0; i < flat_input.numel(); i++) { - int64_t qvalue = quant_min + i * step; - if (dtype == at::kByte) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kChar) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kShort) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kInt) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kLong) { - flat_input[i] = static_cast(qvalue); - } - } - - // Reshape back to original dimensions - input = flat_input.reshape(input_sizes_int64); - - // Get reference output - at::Tensor reference_out = dequantize_per_tensor_reference_impl( - input, scale, zero_point, quant_min, quant_max, dtype, out_dtype); - - // Get implementation output - at::Tensor impl_out = torch::executor::native::dequantize_per_tensor_aten( - input, scale, zero_point, quant_min, quant_max, dtype, out_dtype); - - // Compare outputs - const bool output_correct = at::allclose(reference_out, impl_out); - if (!output_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " scale: " << scale << std::endl; - std::cout << " zero_point: " << zero_point << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " input dtype: " << dtype << std::endl; - std::cout << " output dtype: " << out_dtype << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference:" << std::endl; - std::cout << reference_out << std::endl; - std::cout << "implementation:" << std::endl; - std::cout << impl_out << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -TEST( - VulkanDequantizePerTensorTest, - test_reference_dequantize_per_tensor_uint8_to_float) { - test_reference_dequantize_per_tensor( - {2, 3, 4}, // input sizes - 0.1, // scale - 5, // zero_point - 0, // quant_min - 255, // quant_max - at::kByte, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTensorTest, - test_reference_dequantize_per_tensor_int8_to_float) { - test_reference_dequantize_per_tensor( - {3, 4, 5}, // input sizes - 0.05, // scale - 0, // zero_point - -128, // quant_min - 127, // quant_max - at::kChar, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTensorTest, - test_reference_dequantize_per_tensor_int32_to_float) { - test_reference_dequantize_per_tensor( - {4, 6, 2}, // input sizes - 0.2, // scale - 2, // zero_point - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kInt, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTensorTest, - test_reference_dequantize_per_tensor_uint8_to_half) { - test_reference_dequantize_per_tensor( - {7, 4}, // input sizes - 0.1, // scale - 10, // zero_point - 0, // quant_min - 255, // quant_max - at::kByte, // input dtype (uint8) - at::kHalf); // output dtype -} - -TEST( - VulkanDequantizePerTensorTest, - test_reference_dequantize_per_tensor_int32_to_half) { - 
test_reference_dequantize_per_tensor( - {2, 6, 5}, // input sizes - 0.3, // scale - -10, // zero_point - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kInt, // input dtype - at::kHalf); // output dtype -} - -// No Vulkan tests for quantized_decomposed.dequantize_per_tensor.default -// because it is not going to be implemented in Vulkan since we will -// be handling any future calls to this op via the export stage - -void test_reference_dequantize_per_token( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - check_dequantize_args(quant_min, quant_max, dtype, out_dtype); - int num_tokens = 1; - for (int i = 0; i < input_sizes.size() - 1; i++) { - num_tokens *= input_sizes[i]; - } - - ASSERT_EQ(num_tokens, scales.size()); - ASSERT_EQ(num_tokens, zero_points.size()); - - // Create input tensor with quantized values - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input; - if (dtype == at::kByte) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); - } else if (dtype == at::kChar) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); - } else if (dtype == at::kShort) { - input = - at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); - } else if (dtype == at::kInt) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); - } else { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); - } - - // Fill with a simple pattern: values from quant_min to quant_max in steps - at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); - for (int token_idx = 0; token_idx < num_tokens; token_idx++) { - float step = 1.0f; - if (input.size(-1) > 1) { - step = static_cast(quant_max - quant_min) / (input.size(-1) - 1); - } - - for (int i = 0; i < input.size(-1); i++) { - int64_t qvalue = quant_min + i * step; - if (dtype == at::kByte) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kChar) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kShort) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kInt) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kLong) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } - } - } - - // Reshape back to original dimensions - input = reshaped_input.reshape(input_sizes_int64); - - // Create scale and zero_point tensors - at::Tensor scale_tensor = - at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat)); - at::Tensor zero_point_tensor = - at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output - at::Tensor reference_out = dequantize_per_token_reference_impl( - input, - scale_tensor, - zero_point_tensor, - quant_min, - quant_max, - dtype, - out_dtype); - - // Get implementation output - at::Tensor impl_out = torch::executor::native::dequantize_per_token_aten( - input, - scale_tensor, - zero_point_tensor, - quant_min, - quant_max, - dtype, - out_dtype); - - // Compare outputs - const bool output_correct = at::allclose(reference_out, impl_out); - if (!output_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " scale(s):"; - for (size_t i = 0; i < scales.size(); i++) { - std::cout << " " << scales[i] << " "; - } 
- std::cout << "" << std::endl; - std::cout << " zero_point(s):"; - for (size_t i = 0; i < zero_points.size(); i++) { - std::cout << " " << zero_points[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " input dtype: " << dtype << std::endl; - std::cout << " output dtype: " << out_dtype << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference:" << std::endl; - std::cout << reference_out << std::endl; - std::cout << "implementation:" << std::endl; - std::cout << impl_out << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -void test_vulkan_dequantize_per_token_impl( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage) { - check_dequantize_args(quant_min, quant_max, dtype, out_dtype); - int num_tokens = 1; - for (int i = 0; i < input_sizes.size() - 1; i++) { - num_tokens *= input_sizes[i]; - } - - ASSERT_EQ(num_tokens, scales.size()); - ASSERT_EQ(num_tokens, zero_points.size()); - - // Create input tensor with quantized values - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input; - if (dtype == at::kByte) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); - } else if (dtype == at::kChar) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); - } else if (dtype == at::kShort) { - input = - at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); - } else if (dtype == at::kInt) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); - } else { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); - } - - // Fill with a simple pattern: values from quant_min to quant_max in steps - at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); - for (int token_idx = 0; token_idx < num_tokens; token_idx++) { - float step = 1.0f; - if (input.size(-1) > 1) { - step = static_cast(quant_max - quant_min) / (input.size(-1) - 1); - } - - for (int i = 0; i < input.size(-1); i++) { - int64_t qvalue = quant_min + i * step; - if (dtype == at::kByte) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kChar) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kShort) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kInt) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kLong) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } - } - } - - // Reshape back to original dimensions - input = reshaped_input.reshape(input_sizes_int64); - - // Create scale and zero_point tensors - at::Tensor scale_tensor = - at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat)); - at::Tensor zero_point_tensor = - at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output - at::Tensor reference_out = torch::executor::native::dequantize_per_token_aten( - input, - scale_tensor, - zero_point_tensor, - quant_min, - quant_max, - dtype, - out_dtype); - - // Build Vulkan dequantize_per_token graph - using namespace vkcompute; - - GraphConfig config; - 
config.set_storage_type_override(in_storage); - ComputeGraph graph(config); - - IOValueRef r_input = graph.add_input_tensor( - input.sizes().vec(), from_at_scalartype(dtype), in_storage); - IOValueRef r_scale = graph.add_input_tensor( - scale_tensor.sizes().vec(), - vkapi::kFloat, - utils::kBuffer, - utils::kWidthPacked); - IOValueRef r_zero_point = graph.add_input_tensor( - zero_point_tensor.sizes().vec(), - vkapi::kInt, - utils::kBuffer, - utils::kWidthPacked); - - const ValueRef r_quant_min = graph.add_scalar(quant_min); - const ValueRef r_quant_max = graph.add_scalar(quant_max); - - const ValueRef r_out = graph.add_tensor( - input.sizes().vec(), from_at_scalartype(out_dtype), out_storage); - - const ValueRef r_dtype = - graph.add_scalar(static_cast(out_dtype)); - - VK_GET_OP_FN("quantized_decomposed.dequantize_per_token.default") - (graph, - { - r_input.value, - r_scale.value, - r_zero_point.value, - r_quant_min, - r_quant_max, - r_dtype, - r_dtype, - r_out, - }); - - ValueRef staging_out = graph.set_output_tensor(r_out); - - graph.prepare(); - - graph.prepack(); - - // Copy input data to GPU - graph.copy_into_staging( - r_input.staging, input.const_data_ptr(), input.numel()); - - // Convert scale tensor to float and copy to GPU - at::Tensor scale_float = scale_tensor.to(at::kFloat); - graph.copy_into_staging( - r_scale.staging, scale_float.const_data_ptr(), scale_float.numel()); - - // Convert zero_point tensor to int and copy to GPU - at::Tensor zero_point_int = zero_point_tensor.to(at::kInt); - graph.copy_into_staging( - r_zero_point.staging, - zero_point_int.const_data_ptr(), - zero_point_int.numel()); - - // Execute the graph - graph.execute(); - - // Copy output data back to CPU - at::Tensor vk_out = at::empty_like(reference_out).contiguous(); - graph.copy_from_staging( - staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); - - // Compare outputs with appropriate tolerance for half precision - bool output_correct; - if (out_dtype == at::kHalf) { - // Use higher tolerance for half precision due to limited precision - output_correct = - at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); - } else { - output_correct = - at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5); - } - if (!output_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " scale(s):"; - for (size_t i = 0; i < scales.size(); i++) { - std::cout << " " << scales[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " zero_point(s):"; - for (size_t i = 0; i < zero_points.size(); i++) { - std::cout << " " << zero_points[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " storage type: " - << (in_storage == vkcompute::utils::kBuffer ? 
"buffer" - : "texture") - << std::endl; - std::cout << " input dtype: " << dtype << std::endl; - std::cout << " output dtype: " << out_dtype << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference:" << std::endl; - std::cout << reference_out << std::endl; - std::cout << "vulkan:" << std::endl; - std::cout << vk_out << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -TEST( - VulkanDequantizePerTokenTest, - test_reference_dequantize_per_token_uint8_to_float) { - std::vector scales = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6}; - std::vector zero_points = {5, 10, 15, 20, 25, 30}; - - test_reference_dequantize_per_token( - {2, 3, 4}, // input sizes (2*3=6 tokens) - scales, - zero_points, - 0, // quant_min - 255, // quant_max - at::kByte, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - test_reference_dequantize_per_token_int8_to_float) { - std::vector scales = {0.05, 0.1, 0.15, 0.2}; - std::vector zero_points = {0, -5, 5, 10}; - - test_reference_dequantize_per_token( - {2, 2, 5}, // input sizes (2*2=4 tokens) - scales, - zero_points, - -128, // quant_min - 127, // quant_max - at::kChar, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - test_reference_dequantize_per_token_int32_to_float) { - std::vector scales = {0.05, 0.1, 0.15, 0.2}; - std::vector zero_points = {0, -5, 5, 10}; - - test_reference_dequantize_per_token( - {2, 2, 10}, // input sizes (2*2=4 tokens) - scales, - zero_points, - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kInt, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - test_reference_dequantize_per_token_int8_to_half) { - std::vector scales = {0.05, 0.1, 0.15, 0.2}; - std::vector zero_points = {0, -5, 5, 10}; - - test_reference_dequantize_per_token( - {4, 1, 5}, // input sizes (4*1=4 tokens) - scales, - zero_points, - -128, // quant_min - 127, // quant_max - at::kChar, // input dtype (int8) - at::kHalf); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - test_reference_dequantize_per_token_int32_to_half) { - std::vector scales = {0.05, 0.1}; - std::vector zero_points = {0, -5}; - - test_reference_dequantize_per_token( - {2, 2}, // input sizes (2 tokens) - scales, - zero_points, - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kInt, // input dtype - at::kHalf); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - test_vulkan_dequantize_per_token_uint8_to_float) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6}; - std::vector zero_points = {5, 10, 15, 20, 25, 30}; - - test_vulkan_dequantize_per_token( - {2, 3, 6}, // input sizes (2*3=6 tokens) - scales, - zero_points, - 0, // quant_min - 255, // quant_max - at::kByte, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - test_vulkan_dequantize_per_token_int8_to_float) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.05, 0.0}; - std::vector zero_points = {10, -5}; - - test_vulkan_dequantize_per_token( - {2, 2}, // input sizes (2*2=4 tokens) - scales, - zero_points, - -128, // quant_min - 127, // quant_max - at::kChar, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - 
-    test_vulkan_dequantize_per_token_int32_to_float) {
-  std::vector<float> scales = {
-      0.0001, 0.0002, 0.0003, 0.0, 0.0011, 0.0102, 0.1003, 0.0};
-  std::vector<int> zero_points = {100, -100, 50, -50, 12, -6, 4, -24};
-
-  test_vulkan_dequantize_per_token(
-      {2, 2, 2, 12}, // input sizes (2*2=4 tokens)
-      scales,
-      zero_points,
-      -2147483648, // quant_min
-      2147483647, // quant_max
-      at::kInt, // input dtype
-      at::kFloat); // output dtype
-}
-
-TEST(
-    VulkanDequantizePerTokenTest,
-    test_vulkan_dequantize_per_token_int8_to_half) {
-  if (!vkcompute::api::context()
-           ->adapter_ptr()
-           ->has_full_int8_buffers_support()) {
-    GTEST_SKIP();
-  }
-  if (!vkcompute::api::context()
-           ->adapter_ptr()
-           ->has_full_float16_buffers_support()) {
-    GTEST_SKIP();
-  }
-  std::vector<float> scales = {0.05, 0.2};
-  std::vector<int> zero_points = {2, -5};
-
-  test_vulkan_dequantize_per_token(
-      {2, 2}, // input sizes (2=4 tokens)
-      scales,
-      zero_points,
-      -128, // quant_min
-      127, // quant_max
-      at::kChar, // input dtype
-      at::kHalf); // output dtype
-}
-
-TEST(
-    VulkanDequantizePerTokenTest,
-    test_vulkan_dequantize_per_token_int32_to_half) {
-  if (!vkcompute::api::context()
-           ->adapter_ptr()
-           ->has_full_float16_buffers_support()) {
-    GTEST_SKIP();
-  }
-  // Use much smaller scales to avoid overflow to infinity in half precision
-  // Half precision max value is ~65504, so with int32 values around 2e9,
-  // we need scales smaller than 65504/2e9 ≈ 3e-5 to avoid overflow
-  std::vector<float> scales = {1e-5, 2e-5, 1.5e-5};
-  std::vector<int> zero_points = {20, -15, 1};
-
-  test_vulkan_dequantize_per_token(
-      {3, 6}, // input sizes (3 tokens)
-      scales,
-      zero_points,
-      std::numeric_limits<int32_t>::min(), // quant_min
-      std::numeric_limits<int32_t>::max(), // quant_max
-      at::kInt, // input dtype
-      at::kHalf); // output dtype
-}
-
-TEST(
-    VulkanDequantizePerTokenTest,
-    test_vulkan_dequantize_per_token_int8_to_double) {
-  if (!vkcompute::api::context()
-           ->adapter_ptr()
-           ->has_full_int8_buffers_support()) {
-    GTEST_SKIP();
-  }
-  std::vector<float> scales = {0.05, 0.001};
-  std::vector<int> zero_points = {10, -5};
-
-  test_vulkan_dequantize_per_token(
-      {2, 2}, // input sizes (2 tokens)
-      scales,
-      zero_points,
-      -128, // quant_min
-      127, // quant_max
-      at::kChar, // input dtype
-      at::kDouble); // output dtype
-}
-
-void test_reference_dequantize_per_channel(
-    const std::vector<int>& input_sizes,
-    const std::vector<float>& scales,
-    const std::vector<int>& zero_points,
-    int64_t axis,
-    int64_t quant_min,
-    int64_t quant_max,
-    at::ScalarType dtype,
-    at::ScalarType out_dtype) {
-  check_dequantize_args(quant_min, quant_max, dtype, out_dtype);
-  check_dequantize_per_channel_args(input_sizes, scales, zero_points, axis);
-
-  std::vector<int64_t> input_sizes_int64(
-      input_sizes.begin(), input_sizes.end());
-
-  // Create input tensor with quantized values
-  at::Tensor input;
-  if (dtype == at::kByte) {
-    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte));
-  } else if (dtype == at::kChar) {
-    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar));
-  } else if (dtype == at::kShort) {
-    input =
-        at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort));
-  } else if (dtype == at::kInt) {
-    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt));
-  } else {
-    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong));
-  }
-
-  // Fill with a simple pattern: values from quant_min to quant_max in steps
-  float step = 1.0f;
-  if (input.numel() > 1) {
-    step = static_cast<float>(quant_max - quant_min) / (input.numel() - 1);
-  }
-
auto flat_input = input.flatten(); - for (int i = 0; i < flat_input.numel(); i++) { - int64_t qvalue = quant_min + i * step; - if (dtype == at::kByte) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kChar) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kShort) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kInt) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kLong) { - flat_input[i] = static_cast(qvalue); - } - } - - // Reshape back to original dimensions - input = flat_input.reshape(input_sizes_int64); - - // Create scale and zero_point tensors - at::Tensor scale_tensor = - at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_tensor = - at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output - at::Tensor my_ref = dequantize_per_channel_reference_impl( - input, - scale_tensor, - zero_point_tensor, - axis, - quant_min, - quant_max, - dtype, - out_dtype); - - // Get implementation output - at::Tensor cpu_ref = torch::executor::native::dequantize_per_channel_aten( - input, - scale_tensor, - zero_point_tensor, - axis, - quant_min, - quant_max, - dtype, - out_dtype); - - // Compare outputs - const bool output_correct = at::allclose(my_ref, cpu_ref); - if (!output_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " axis: " << axis << std::endl; - std::cout << " input sizes:"; - for (size_t i = 0; i < input_sizes.size(); i++) { - std::cout << " " << input_sizes[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " scale(s):"; - for (size_t i = 0; i < scales.size(); i++) { - std::cout << " " << scales[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " zero_point(s):"; - for (size_t i = 0; i < zero_points.size(); i++) { - std::cout << " " << zero_points[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " input dtype: " << dtype << std::endl; - std::cout << " output dtype: " << out_dtype << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "cpu_ref:" << std::endl; - std::cout << cpu_ref << std::endl; - std::cout << "my_ref:" << std::endl; - std::cout << my_ref << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -void test_vulkan_dequantize_per_channel_impl( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage) { - check_dequantize_args(quant_min, quant_max, dtype, out_dtype); - check_dequantize_per_channel_args(input_sizes, scales, zero_points, axis); - - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - - // Create random float tensor - at::Tensor float_x = - at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); - - // Create scale and zero_point tensors - at::Tensor scale_tensor = - at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat)); - at::Tensor zero_point_tensor = - at::tensor(zero_points, at::device(at::kCPU).dtype(at::kInt)); - - // Map the dtype to the corresponding quantized type and quantize the float - // tensor - c10::ScalarType qtype; - at::Tensor adjusted_zero_points = zero_point_tensor; - - if (dtype == 
at::kByte) { - qtype = c10::kQUInt8; - // ATEN ONLY: Adjust zero points for unsigned types (must be non-negative) - adjusted_zero_points = at::clamp_min(zero_point_tensor, 0); - } else if (dtype == at::kChar) { - qtype = c10::kQInt8; - } else if (dtype == at::kInt) { - qtype = c10::kQInt32; - } else { - std::cout << "invalid dtype for ATEN: " << dtype << std::endl; - std::cout << " --> Delegating to c10::kQInt32" << std::endl; - qtype = c10::kQInt32; - } - - // Normalize axis for ATen (ATen doesn't handle negative axes in - // quantize_per_channel) - int64_t normalized_axis = axis; - if (normalized_axis < 0) { - normalized_axis += input_sizes_int64.size(); - } - - // Quantize using ATen - at::Tensor quantized_aten = at::quantize_per_channel( - float_x, scale_tensor, adjusted_zero_points, normalized_axis, qtype); - - // Get ATen dequantized output - at::Tensor aten_out = at::dequantize(quantized_aten).to(out_dtype); - - // Extract the quantized values (int_repr) to use with our implementations - at::Tensor quantized_input = quantized_aten.int_repr().to(dtype); - - // Get reference output using - // torch::executor::native::dequantize_per_channel_aten - at::Tensor reference_out = - torch::executor::native::dequantize_per_channel_aten( - quantized_input, - scale_tensor.to(at::kDouble), - zero_point_tensor.to(at::kLong), - axis, - quant_min, - quant_max, - dtype, - out_dtype); - - // Build Vulkan dequantize_per_channel graph - using namespace vkcompute; - - GraphConfig config; - config.set_storage_type_override(in_storage); - ComputeGraph graph(config); - - // Add tensors to graph - IOValueRef r_input = graph.add_input_tensor( - quantized_input.sizes().vec(), - from_at_scalartype(quantized_input.scalar_type()), - in_storage); - - IOValueRef r_scale = graph.add_input_tensor( - scale_tensor.sizes().vec(), - vkapi::kFloat, - utils::kBuffer, - utils::kWidthPacked); - - IOValueRef r_zero_point = graph.add_input_tensor( - adjusted_zero_points.sizes().vec(), - vkapi::kInt, - utils::kBuffer, - utils::kWidthPacked); - - ValueRef r_out = graph.add_tensor( - quantized_input.sizes().vec(), - from_at_scalartype(out_dtype), - out_storage); - - const ValueRef r_axis = graph.add_scalar(axis); - const ValueRef r_quant_min = graph.add_scalar(quant_min); - const ValueRef r_quant_max = graph.add_scalar(quant_max); - const ValueRef r_dtype = - graph.add_scalar(static_cast(dtype)); - const ValueRef r_output_dtype = - graph.add_scalar(static_cast(out_dtype)); - - VK_GET_OP_FN("quantized_decomposed.dequantize_per_channel.default") - (graph, - { - r_input.value, - r_scale.value, - r_zero_point.value, - r_axis, - r_quant_min, - r_quant_max, - r_dtype, - r_output_dtype, - r_out, - }); - - ValueRef staging_out = graph.set_output_tensor(r_out); - - graph.prepare(); - graph.prepack(); - - // Copy input data to GPU - graph.copy_into_staging( - r_input.staging, - quantized_input.const_data_ptr(), - quantized_input.numel()); - - // copy scale tensor to GPU - graph.copy_into_staging( - r_scale.staging, scale_tensor.const_data_ptr(), scale_tensor.numel()); - - // copy zero_point tensor to GPU - graph.copy_into_staging( - r_zero_point.staging, - zero_point_tensor.const_data_ptr(), - zero_point_tensor.numel()); - - // Execute the graph - graph.execute(); - - // Copy output data back to CPU - at::Tensor vk_out = at::empty_like(reference_out).contiguous(); - graph.copy_from_staging( - staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); - - // Compare outputs with appropriate tolerance for half precision - bool output_correct; 
- if (out_dtype == at::kHalf) { - // Use higher tolerance for half precision due to limited precision - output_correct = - at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); - } else { - output_correct = - at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5); - } - if (!output_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " axis: " << axis << std::endl; - std::cout << " input sizes:"; - for (size_t i = 0; i < input_sizes.size(); i++) { - std::cout << " " << input_sizes[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " scale(s):"; - for (size_t i = 0; i < scales.size(); i++) { - std::cout << " " << scales[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " zero_point(s):"; - for (size_t i = 0; i < zero_points.size(); i++) { - std::cout << " " << zero_points[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " input dtype: " << dtype << std::endl; - std::cout << " output dtype: " << out_dtype << std::endl; - std::cout << " storage: " << in_storage << std::endl; - std::cout << std::endl; - - std::cout << "\033[91m quantized_input: \033[0m" << std::endl; - std::cout << quantized_input << std::endl; - std::cout << "\033[91m aten: \033[0m" << std::endl; - std::cout << aten_out << std::endl; - std::cout << "\033[91m reference: \033[0m" << std::endl; - std::cout << reference_out << std::endl; - std::cout << "\033[91m vulkan: \033[0m" << std::endl; - std::cout << vk_out << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -TEST( - VulkanDequantizePerChannelTest, - test_reference_dequantize_per_channel_uint8_to_float_3D_axis0) { - std::vector scales = {0.1, 0.2, 0.3}; - std::vector zero_points = {0, 5, -2}; - - test_reference_dequantize_per_channel( - {3, 4, 2}, // input sizes - scales, - zero_points, - 0, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kFloat); -} - -TEST( - VulkanDequantizePerChannelTest, - test_reference_dequantize_per_channel_int8_to_float_3D_axis2) { - std::vector scales = {0.1, 0.2}; - std::vector zero_points = {0, 5}; - - test_reference_dequantize_per_channel( - {3, 4, 2}, // input sizes - scales, - zero_points, - 2, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); -} - -TEST( - VulkanDequantizePerChannelTest, - test_reference_dequantize_per_channel_int8_to_float_3D_axisn1) { - std::vector scales = {0.1, 0.2}; - std::vector zero_points = {0, 5}; - - test_reference_dequantize_per_channel( - {3, 4, 2}, // input sizes - scales, - zero_points, - -1, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); -} - -TEST( - VulkanDequantizePerChannelTest, - test_reference_dequantize_per_channel_int32_to_float_4D_axis0) { - std::vector scales = {0.1, 0.2, 0.00002}; - std::vector zero_points = {0, 5, -4}; - - test_reference_dequantize_per_channel( - {3, 4, 2, 5}, // input sizes - scales, - zero_points, - 0, // axis - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kInt, - at::kFloat); -} - -// END OF REFERENCE TESTS - -TEST( - VulkanDequantizePerChannelTest, - test_vulkan_dequantize_per_channel_int8_to_float_axis0) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales(9, 0.1f); - std::vector zero_points(9, 2); - - // 1D Tensor - test_vulkan_dequantize_per_channel( - {9}, // input 
sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 2D Tensor - test_vulkan_dequantize_per_channel( - {9, 14}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 3D Tensor - test_vulkan_dequantize_per_channel( - {9, 7, 11}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 17, 5, 5}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 4D Tensor (negative axis) - test_vulkan_dequantize_per_channel( - {5, 17, 5, 9}, // input sizes - scales, - zero_points, - -1, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); -} - -TEST( - VulkanDequantizePerChannelTest, - test_vulkan_dequantize_per_channel_int8_to_float_axis1) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales(14, 0.001f); - std::vector zero_points(14, -5); - - // 2D Tensor - test_vulkan_dequantize_per_channel( - {9, 14}, // input sizes - scales, - zero_points, - 1, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 3D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 11}, // input sizes - scales, - zero_points, - 1, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 5, 5}, // input sizes - scales, - zero_points, - 1, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 4D Tensor (negative axis) - test_vulkan_dequantize_per_channel( - {9, 7, 14, 5}, // input sizes - scales, - zero_points, - -2, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); -} - -TEST( - VulkanDequantizePerChannelTest, - test_vulkan_dequantize_per_channel_int8_to_float_axis2) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales(11, 0.5f); - std::vector zero_points(11, 12); - - // 3D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 11}, // input sizes - scales, - zero_points, - 2, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 11, 5}, // input sizes - scales, - zero_points, - 2, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 4D Tensor (negative axis) - test_vulkan_dequantize_per_channel( - {9, 11, 14, 5}, // input sizes - scales, - zero_points, - -3, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); -} - -TEST( - VulkanDequantizePerChannelTest, - test_vulkan_dequantize_per_channel_int8_to_float_axis3) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales(7, 0.5f); - std::vector zero_points(7, 12); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 11, 7}, // input sizes - scales, - zero_points, - 3, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 4D Tensor (negative axis) - test_vulkan_dequantize_per_channel( - {7, 14, 11, 7}, // input sizes - scales, - zero_points, - -4, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); -} - -TEST( - 
VulkanDequantizePerChannelTest, - test_vulkan_dequantize_per_channel_uint8_to_float_comprehensive) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.1, 0.2, 0.0001, 0.5, 0.02}; - std::vector zero_points = {0, 5, -5, 1, 12}; - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - 0, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kFloat); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 5, 11, 7}, // input sizes - scales, - zero_points, - 1, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kFloat); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 5, 7}, // input sizes - scales, - zero_points, - 2, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kFloat); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 11, 5}, // input sizes - scales, - zero_points, - 3, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kFloat); - - // 4D Tensor (negative axis) - test_vulkan_dequantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - -4, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kFloat); -} - -TEST( - VulkanDequantizePerChannelTest, - test_vulkan_dequantize_per_channel_8bit_to_half) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_float16_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.1, 0.2, 0.01, 0.5, 0.02}; - std::vector zero_points = {0, 5, 5, 1, 12}; - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kHalf); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 5, 11, 7}, // input sizes - scales, - zero_points, - 1, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kHalf); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 5, 7}, // input sizes - scales, - zero_points, - 2, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kHalf); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 11, 5}, // input sizes - scales, - zero_points, - 3, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kHalf); - - // 4D Tensor (negative axis) - test_vulkan_dequantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - -4, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kHalf); -} - -TEST( - VulkanDequantizePerChannelTest, - test_vulkan_dequantize_per_channel_8bit_to_double) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.1, 0.2, 0.01, 0.5, 0.02}; - std::vector zero_points = {0, 5, 5, 1, 12}; - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kDouble); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 5, 11, 7}, // input sizes - scales, - zero_points, - 1, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kDouble); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 5, 7}, // input sizes - scales, - zero_points, - 2, // axis - 0, // quant_min - 255, // quant_max - 
at::kByte, - at::kDouble); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 11, 5}, // input sizes - scales, - zero_points, - 3, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kDouble); - - // 4D Tensor (negative axis) - test_vulkan_dequantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - -4, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kDouble); -} - -void test_vulkan_dequantize_per_tensor_tensor_impl( - const std::vector& input_sizes, - float scale, - int zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage) { - check_dequantize_args(quant_min, quant_max, dtype, out_dtype); - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - - // Create a quantized input tensor with values from quant_min to quant_max - at::Tensor input; - if (dtype == at::kByte) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); - } else if (dtype == at::kChar) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); - } else if (dtype == at::kShort) { - input = - at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); - } else if (dtype == at::kInt) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); - } else { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); - } - - // Fill with a simple pattern: values from quant_min to quant_max in steps - float step = 1.0f; - if (input.numel() > 1) { - step = static_cast(quant_max - quant_min) / (input.numel() - 1); - } - - auto flat_input = input.flatten(); - for (int i = 0; i < flat_input.numel(); i++) { - int64_t qvalue = quant_min + i * step; - if (dtype == at::kByte) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kChar) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kShort) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kInt) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kLong) { - flat_input[i] = static_cast(qvalue); - } - } - - // Reshape back to original dimensions - input = flat_input.reshape(input_sizes_int64); - - // Create scale and zero_point as tensors (single element tensors) - at::Tensor scale_tensor = - at::tensor({scale}, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_tensor = - at::tensor({zero_point}, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output using tensor variant - at::Tensor reference_out = - torch::executor::native::dequantize_per_tensor_tensor_args_aten( - input, - scale_tensor, - zero_point_tensor, - quant_min, - quant_max, - dtype, - out_dtype); - - // Build Vulkan dequantize_per_tensor.tensor graph - using namespace vkcompute; - - GraphConfig config; - config.set_storage_type_override(in_storage); - ComputeGraph graph(config); - - IOValueRef r_input = graph.add_input_tensor( - input.sizes().vec(), from_at_scalartype(dtype), in_storage); - - // Add scale and zero_point as tensor inputs (buffer storage, width packed) - IOValueRef r_scale = graph.add_input_tensor( - scale_tensor.sizes().vec(), - vkapi::kFloat, - utils::kBuffer, - utils::kWidthPacked); - IOValueRef r_zero_point = graph.add_input_tensor( - zero_point_tensor.sizes().vec(), - vkapi::kInt, - utils::kBuffer, - utils::kWidthPacked); - - const ValueRef r_quant_min = graph.add_scalar(quant_min); - 
-  const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max);
-
-  const ValueRef r_out = graph.add_tensor(
-      input.sizes().vec(), from_at_scalartype(out_dtype), out_storage);
-
-  const ValueRef r_dtype =
-      graph.add_scalar<int64_t>(static_cast<int64_t>(dtype));
-  const ValueRef r_out_dtype =
-      graph.add_scalar<int64_t>(static_cast<int64_t>(out_dtype));
-
-  VK_GET_OP_FN("quantized_decomposed.dequantize_per_tensor.tensor")
-  (graph,
-   {
-       r_input.value,
-       r_scale.value,
-       r_zero_point.value,
-       r_quant_min,
-       r_quant_max,
-       r_dtype,
-       r_out_dtype,
-       r_out,
-   });
-
-  ValueRef staging_out = graph.set_output_tensor(r_out);
-
-  graph.prepare();
-  graph.prepack();
-
-  // Run Vulkan dequantize_per_tensor.tensor
-  graph.copy_into_staging(
-      r_input.staging, input.const_data_ptr(), input.numel());
-
-  // Convert scale tensor to float and copy to GPU
-  at::Tensor scale_float = scale_tensor.to(at::kFloat);
-  graph.copy_into_staging(
-      r_scale.staging, scale_float.const_data_ptr(), scale_float.numel());
-
-  // Convert zero_point tensor to int and copy to GPU
-  at::Tensor zero_point_int = zero_point_tensor.to(at::kInt);
-  graph.copy_into_staging(
-      r_zero_point.staging,
-      zero_point_int.const_data_ptr(),
-      zero_point_int.numel());
-
-  graph.execute();
-
-  at::Tensor vk_out = at::empty_like(reference_out).contiguous();
-  graph.copy_from_staging(
-      staging_out, vk_out.mutable_data_ptr(), vk_out.numel());
-
-  // Compare outputs with appropriate tolerance for half precision
-  bool output_correct;
-  if (out_dtype == at::kHalf) {
-    // Use higher tolerance for half precision due to limited precision
-    output_correct =
-        at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2);
-  } else {
-    output_correct =
-        at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5);
-  }
-  if (!output_correct) {
-    std::cout << "\n"
-              << "Failed with parameters: " << std::endl;
-    std::cout << "  scale: " << scale << std::endl;
-    std::cout << "  zero_point: " << zero_point << std::endl;
-    std::cout << "  quant_min: " << quant_min << std::endl;
-    std::cout << "  quant_max: " << quant_max << std::endl;
-    std::cout << "  storage type: "
-              << (in_storage == vkcompute::utils::kBuffer ? "buffer"
-                                                          : "texture")
-              << std::endl;
-    std::cout << "  input dtype: " << dtype << std::endl;
-    std::cout << "  output dtype: " << out_dtype << std::endl;
-
-    std::cout << "input:" << std::endl;
-    std::cout << input << std::endl;
-    std::cout << "reference:" << std::endl;
-    std::cout << reference_out << std::endl;
-    std::cout << "vulkan:" << std::endl;
-    std::cout << vk_out << std::endl;
-  }
-
-  ASSERT_TRUE(output_correct);
-}
-
-TEST(
-    VulkanDequantizePerTensorTensorTest,
-    test_vulkan_dequantize_per_tensor_tensor_int8_to_float) {
-  if (!vkcompute::api::context()
-           ->adapter_ptr()
-           ->has_full_int8_buffers_support()) {
-    GTEST_SKIP();
-  }
-  test_vulkan_dequantize_per_tensor_tensor(
-      {2, 3, 4}, // input sizes
-      0.01, // scale
-      1, // zero_point
-      -128, // quant_min
-      127, // quant_max
-      at::kChar, // input dtype
-      at::kFloat); // output dtype
-}
-
-TEST(
-    VulkanDequantizePerTensorTensorTest,
-    test_vulkan_dequantize_per_tensor_tensor_uint8_to_float) {
-  if (!vkcompute::api::context()
-           ->adapter_ptr()
-           ->has_full_int8_buffers_support()) {
-    GTEST_SKIP();
-  }
-  test_vulkan_dequantize_per_tensor_tensor(
-      {2, 3, 4, 12}, // input sizes
-      0.1, // scale
-      5, // zero_point
-      0, // quant_min
-      255, // quant_max
-      at::kByte, // input dtype
-      at::kFloat); // output dtype
-}
-
-TEST(
-    VulkanDequantizePerTensorTensorTest,
-    test_vulkan_dequantize_per_tensor_tensor_int32_to_float) {
-  if (!vkcompute::api::context()
-           ->adapter_ptr()
-           ->has_full_int8_buffers_support()) {
-    GTEST_SKIP();
-  }
-  test_vulkan_dequantize_per_tensor_tensor(
-      {2, 3}, // input sizes
-      0.01, // scale
-      12, // zero_point
-      std::numeric_limits<int32_t>::min(), // quant_min
-      std::numeric_limits<int32_t>::max(), // quant_max
-      at::kInt, // input dtype
-      at::kFloat); // output dtype
-}
-
-TEST(
-    VulkanDequantizePerTensorTensorTest,
-    test_vulkan_dequantize_per_tensor_tensor_uint8_to_half) {
-  if (!vkcompute::api::context()
-           ->adapter_ptr()
-           ->has_full_int8_buffers_support()) {
-    GTEST_SKIP();
-  }
-  test_vulkan_dequantize_per_tensor_tensor(
-      {3, 4}, // input sizes
-      0.3, // scale
-      2, // zero_point
-      0, // quant_min
-      255, // quant_max
-      at::kByte, // input dtype
-      at::kHalf); // output dtype
-}
-
-TEST(
-    VulkanDequantizePerTensorTensorTest,
-    test_vulkan_dequantize_per_tensor_tensor_int8_to_double) {
-  if (!vkcompute::api::context()
-           ->adapter_ptr()
-           ->has_full_int8_buffers_support()) {
-    GTEST_SKIP();
-  }
-  test_vulkan_dequantize_per_tensor_tensor(
-      {2, 3, 4}, // input sizes
-      0.03, // scale
-      -2, // zero_point
-      -128, // quant_min
-      127, // quant_max
-      at::kChar, // input dtype
-      at::kDouble); // output dtype
-}
diff --git a/backends/vulkan/test/op_tests/quantize_test.cpp b/backends/vulkan/test/op_tests/quantize_test.cpp
deleted file mode 100644
index 86eebcf9b14..00000000000
--- a/backends/vulkan/test/op_tests/quantize_test.cpp
+++ /dev/null
@@ -1,2188 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */ - -#include - -#include - -#include -#include -#include - -#include -#include - -#include "test_utils.h" - -#include -#include -#include - -float eps = 1e-7; - -namespace torch { -namespace executor { -namespace native { - -// Forward declarations of the functions we're testing -Tensor& quantize_per_tensor_out( - const Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out); - -Tensor& quantize_per_token_out( - const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out); - -Tensor& quantize_per_channel_out( - const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out); - -Tensor& quantize_per_tensor_tensor_args_out( - const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out); - -// Wrapper function for quantize_per_tensor_out without context -Tensor& quantize_per_tensor_out_no_context( - const Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out) { - return torch::executor::native::quantize_per_tensor_out( - input, scale, zero_point, quant_min, quant_max, dtype, out); -} - -// Wrapper function for quantize_per_token_out without context -Tensor& quantize_per_token_out_no_context( - const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out) { - return torch::executor::native::quantize_per_token_out( - input, scale, zero_point, quant_min, quant_max, dtype, out); -} - -// Wrapper function for quantize_per_channel_out without context -Tensor& quantize_per_channel_out_no_context( - const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out) { - return torch::executor::native::quantize_per_channel_out( - input, scale, zero_point, axis, quant_min, quant_max, dtype, out); -} - -// Wrapper function for quantize_per_tensor_tensor_args_out without context -Tensor& quantize_per_tensor_tensor_args_out_no_context( - const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out) { - return torch::executor::native::quantize_per_tensor_tensor_args_out( - input, scale, zero_point, quant_min, quant_max, dtype, out); -} - -// ATen wrapper for quantize_per_tensor -at::Tensor quantize_per_tensor_aten( - const at::Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - auto out = at::empty_like(input, dtype); - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - - WRAP_TO_ATEN(quantize_per_tensor_out_no_context, 6) - (input, scale, zero_point, quant_min, quant_max, et_dtype, out); - return out; -} - -// ATen wrapper for quantize_per_token -at::Tensor quantize_per_token_aten( - const at::Tensor& input, - const at::Tensor& scale, - const at::Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - auto out = at::empty_like(input, dtype); - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - - WRAP_TO_ATEN(quantize_per_token_out_no_context, 6) - (input, scale, zero_point, quant_min, 
quant_max, et_dtype, out); - return out; -} - -// ATen wrapper for quantize_per_channel -at::Tensor quantize_per_channel_aten( - const at::Tensor& input, - const at::Tensor& scale, - const at::Tensor& zero_point, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - auto out = at::empty_like(input, dtype); - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - - WRAP_TO_ATEN(quantize_per_channel_out_no_context, 7) - (input, scale, zero_point, axis, quant_min, quant_max, et_dtype, out); - return out; -} - -// ATen wrapper for quantize_per_tensor with tensor args -at::Tensor quantize_per_tensor_tensor_args_aten( - const at::Tensor& input, - const at::Tensor& scale, - const at::Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - auto out = at::empty_like(input, dtype); - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - - WRAP_TO_ATEN(quantize_per_tensor_tensor_args_out_no_context, 6) - (input, scale, zero_point, quant_min, quant_max, et_dtype, out); - return out; -} - -} // namespace native -} // namespace executor -} // namespace torch - -void check_quantize_args( - int64_t quant_min, - int64_t quant_max, - c10::ScalarType out_dtype) { - using namespace vkcompute; - int32_t quant_min_lower_bound = 0, quant_max_upper_bound = 0; - switch (out_dtype) { - case c10::kByte: - quant_min_lower_bound = - static_cast(std::numeric_limits::min()); - quant_max_upper_bound = - static_cast(std::numeric_limits::max()); - break; - case c10::kChar: - quant_min_lower_bound = - static_cast(std::numeric_limits::min()); - quant_max_upper_bound = - static_cast(std::numeric_limits::max()); - break; - case c10::kBits16: - case c10::kUInt16: - quant_min_lower_bound = std::numeric_limits::min(); - quant_max_upper_bound = std::numeric_limits::max(); - break; - case c10::kShort: - quant_min_lower_bound = std::numeric_limits::min(); - quant_max_upper_bound = std::numeric_limits::max(); - break; - case c10::kInt: - quant_min_lower_bound = std::numeric_limits::min(); - quant_max_upper_bound = std::numeric_limits::max(); - break; - default: - VK_CHECK_COND(false, "Unsupported dtype: ", scalar_type_name(out_dtype)); - } - VK_CHECK_COND( - quant_min >= quant_min_lower_bound, - "quant_min out of bound for dtype, expected quant_min_lower_bound: ", - quant_min_lower_bound, - " actual quant_min: ", - quant_min); - - VK_CHECK_COND( - quant_max <= quant_max_upper_bound, - "quant_max out of bound for dtype, expected quant_max_upper_bound: ", - quant_max_upper_bound, - " actual quant_max: ", - quant_max); -} - -/** - * Helper function to validate quantize_per_channel arguments - * Similar to the validation in op_quantize.cpp - */ -void check_quantize_per_channel_args( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t axis) { - // Normalize axis - int64_t normalized_axis = axis; - if (normalized_axis < 0) { - normalized_axis += input_sizes.size(); - } - - ASSERT_GE(normalized_axis, 0) - << "axis " << axis << " is not legal, normalized axis " << normalized_axis - << " should be >= 0"; - - ASSERT_LT(normalized_axis, static_cast(input_sizes.size())) - << "axis " << axis << " is not legal, normalized axis " << normalized_axis - << " should be < input.dim() " << input_sizes.size(); - - int64_t num_channels = input_sizes[normalized_axis]; - - ASSERT_EQ(num_channels, static_cast(scales.size())) - << "Expected scales.size() to match input.size(axis) (" << num_channels - << "), but got " << 
scales.size(); - - ASSERT_EQ(num_channels, static_cast(zero_points.size())) - << "Expected zero_points.size() to match input.size(axis) (" - << num_channels << "), but got " << zero_points.size(); -} - -// -// Reference Implementation -// - -/* - * Reference implementation of quantize_per_tensor - */ -at::Tensor quantize_per_tensor_reference_impl( - const at::Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - // Create output tensor with the target dtype - at::Tensor out = at::empty_like(input, dtype); - - // Quantize the input tensor - float inv_scale = 1.0 / scale; - - // Iterate through the tensor and quantize each element - at::Tensor float_input = input.to(at::kFloat); - at::Tensor float_values = float_input.flatten(); - - auto out_flat = out.flatten(); - - for (int i = 0; i < float_values.numel(); i++) { - float value = float_values[i].item(); - int64_t qvalue = zero_point + std::nearbyint(inv_scale * value); - - qvalue = std::max(qvalue, quant_min); - qvalue = std::min(qvalue, quant_max); - - if (dtype == at::kByte) { - out_flat[i] = static_cast(qvalue); - } else if (dtype == at::kChar) { - out_flat[i] = static_cast(qvalue); - } else if (dtype == at::kShort) { - out_flat[i] = static_cast(qvalue); - } else if (dtype == at::kInt) { - out_flat[i] = static_cast(qvalue); - } else if (dtype == at::kLong) { - out_flat[i] = static_cast(qvalue); - } - } - - return out.reshape(input.sizes()); -} - -/* - * Reference implementation of quantize_per_token - */ -at::Tensor quantize_per_token_reference_impl( - const at::Tensor& input, - const at::Tensor& scale, - const at::Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - // Create output tensor with the target dtype - at::Tensor out = at::empty_like(input, dtype); - - // Calculate number of tokens - int num_tokens = 1; - for (int i = 0; i < input.dim() - 1; i++) { - num_tokens *= input.size(i); - } - - // Verify that the number of tokens matches the size of scale and zero_point - // tensors - assert(num_tokens == scale.numel()); - assert(num_tokens == zero_point.numel()); - - // Reshape input to [num_tokens, last_dim] - at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); - at::Tensor reshaped_out = out.reshape({num_tokens, input.size(-1)}); - - // Quantize each token separately - for (int token_idx = 0; token_idx < num_tokens; token_idx++) { - // Use float for scale since Vulkan doesn't support double - float token_scale = scale[token_idx].item(); - // Use int for zero_point since Vulkan doesn't support int64_t - int token_zero_point = zero_point[token_idx].item(); - - float inv_scale = 1.0 / token_scale; - - // Quantize the token - for (int i = 0; i < input.size(-1); i++) { - float value = reshaped_input[token_idx][i].item(); - int qvalue = token_zero_point + std::nearbyint(inv_scale * value); - - qvalue = std::max(qvalue, quant_min); - qvalue = std::min(qvalue, quant_max); - - if (dtype == at::kByte) { - reshaped_out[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kChar) { - reshaped_out[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kShort) { - reshaped_out[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kInt) { - reshaped_out[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kLong) { - reshaped_out[token_idx][i] = static_cast(qvalue); - } - } - } - - return out; -} - -/* - * Reference implementation of quantize_per_channel - */ -at::Tensor 
quantize_per_channel_reference_impl( - const at::Tensor& input, - const at::Tensor& scale, - const at::Tensor& zero_point, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - // Normalize axis to handle negative values - int64_t normalized_axis = axis; - if (normalized_axis < 0) { - normalized_axis += input.dim(); - } - - // Create output tensor with the same shape as input but with target dtype - at::Tensor output = at::empty_like(input, dtype); - - // Get the number of channels along the quantization axis - int64_t num_channels = input.size(normalized_axis); - - // Calculate strides for efficient indexing - std::vector input_strides; - std::vector input_sizes; - for (int64_t i = 0; i < input.dim(); i++) { - input_sizes.push_back(input.size(i)); - input_strides.push_back(input.stride(i)); - } - - // Get data pointers - const float* input_data = input.const_data_ptr(); - const double* scale_data = scale.const_data_ptr(); - const int64_t* zero_point_data = zero_point.const_data_ptr(); - - // Iterate through all elements in the tensor - int64_t total_elements = input.numel(); - - // Helper lambda to convert flat index to multi-dimensional coordinates - auto flat_to_coords = [&](int64_t flat_idx, std::vector& coords) { - int64_t remaining = flat_idx; - for (int64_t dim = input.dim() - 1; dim >= 0; dim--) { - coords[dim] = remaining % input_sizes[dim]; - remaining /= input_sizes[dim]; - } - }; - - // Process each element - std::vector coords(input.dim()); - for (int64_t flat_idx = 0; flat_idx < total_elements; flat_idx++) { - // Convert flat index to coordinates - flat_to_coords(flat_idx, coords); - - // Get the channel index for this element - int64_t channel_idx = coords[normalized_axis]; - - // Get the quantization parameters for this channel - double channel_scale = scale_data[channel_idx]; - int64_t channel_zero_point = zero_point_data[channel_idx]; - - // Get the input value - float input_value = input_data[flat_idx]; - - // Apply quantization formula: round(input / scale) + zero_point - float inv_scale = 1.0f / static_cast(channel_scale); - int64_t quantized_value = static_cast( - static_cast(channel_zero_point) + - std::nearbyint(static_cast(inv_scale * input_value))); - - // Clamp to quantization bounds - quantized_value = std::max(quantized_value, quant_min); - quantized_value = std::min(quantized_value, quant_max); - - // Store the result based on output dtype - switch (dtype) { - case at::kByte: { - uint8_t* output_data = output.mutable_data_ptr(); - output_data[flat_idx] = static_cast(quantized_value); - break; - } - case at::kChar: { - int8_t* output_data = output.mutable_data_ptr(); - output_data[flat_idx] = static_cast(quantized_value); - break; - } - case at::kShort: { - int16_t* output_data = output.mutable_data_ptr(); - output_data[flat_idx] = static_cast(quantized_value); - break; - } - case at::kInt: { - int32_t* output_data = output.mutable_data_ptr(); - output_data[flat_idx] = static_cast(quantized_value); - break; - } - default: - assert(false && "Unsupported output dtype"); - } - } - - return output; -} - -// Forward declaration of implementation functions -void test_vulkan_quantize_per_token_impl( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype, - at::ScalarType dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage); - -void test_vulkan_quantize_per_channel_impl( - 
const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype, - at::ScalarType dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage); - -void test_vulkan_quantize_per_tensor_tensor_impl( - const std::vector& input_sizes, - float scale, - int zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype, - at::ScalarType dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage); - -// Wrapper function to test both buffer and texture storage types -void test_vulkan_quantize_per_token( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype = at::kFloat, - at::ScalarType dtype = at::kInt) { - // Test with buffer storage - test_vulkan_quantize_per_token_impl( - input_sizes, - scales, - zero_points, - quant_min, - quant_max, - in_dtype, - dtype, - vkcompute::utils::kBuffer, - vkcompute::utils::kBuffer); - - // If the in_dtype is a double, convert to float for texture implementation - // since they don't support 64bit as inputs - if (in_dtype == at::kDouble) { - in_dtype = at::kFloat; - } - - // Test with texture storage - test_vulkan_quantize_per_token_impl( - input_sizes, - scales, - zero_points, - quant_min, - quant_max, - in_dtype, - dtype, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - -// Wrapper function to test both buffer and texture storage types -void test_vulkan_quantize_per_channel( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype = at::kFloat, - at::ScalarType dtype = at::kInt) { - // Test with buffer storage - test_vulkan_quantize_per_channel_impl( - input_sizes, - scales, - zero_points, - axis, - quant_min, - quant_max, - in_dtype, - dtype, - vkcompute::utils::kBuffer, - vkcompute::utils::kBuffer); - - // If the in_dtype is a double, convert to float for texture implementation - // since they don't support 64bit as inputs - if (in_dtype == at::kDouble) { - in_dtype = at::kFloat; - } - - test_vulkan_quantize_per_channel_impl( - input_sizes, - scales, - zero_points, - axis, - quant_min, - quant_max, - in_dtype, - dtype, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - -// Wrapper function to test both buffer and texture storage types -void test_vulkan_quantize_per_tensor_tensor( - const std::vector& input_sizes, - float scale, - int zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype = at::kFloat, - at::ScalarType dtype = at::kInt) { - // Test with buffer storage - test_vulkan_quantize_per_tensor_tensor_impl( - input_sizes, - scale, - zero_point, - quant_min, - quant_max, - in_dtype, - dtype, - vkcompute::utils::kBuffer, - vkcompute::utils::kBuffer); - - // If the in_dtype is a double, convert to float for texture implementation - // since they don't support 64bit as inputs - if (in_dtype == at::kDouble) { - in_dtype = at::kFloat; - } - - // Test with texture storage - test_vulkan_quantize_per_tensor_tensor_impl( - input_sizes, - scale, - zero_point, - quant_min, - quant_max, - in_dtype, - dtype, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - -void test_reference_quantize_per_tensor( - const std::vector& input_sizes, 
-    float scale,
-    int zero_point,
-    int64_t quant_min,
-    int64_t quant_max,
-    at::ScalarType in_dtype = at::kFloat,
-    at::ScalarType dtype = at::kInt) {
-  check_quantize_args(quant_min, quant_max, dtype);
-  std::vector<int64_t> input_sizes_int64(
-      input_sizes.begin(), input_sizes.end());
-  at::Tensor input =
-      at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype));
-
-  // Fill with a simple pattern: values from 0 to 1 in steps
-  float step = 1.0f / (input.numel() - 1);
-  auto flat_input = input.flatten();
-  for (int i = 0; i < flat_input.numel(); i++) {
-    flat_input[i] = i * step;
-  }
-
-  // Reshape back to original dimensions
-  input = flat_input.reshape(input_sizes_int64);
-
-  scale = scale < eps ? eps : scale;
-
-  // Get reference output
-  at::Tensor reference_out = quantize_per_tensor_reference_impl(
-      input, scale, zero_point, quant_min, quant_max, dtype);
-
-  // Get implementation output
-  at::Tensor impl_out = torch::executor::native::quantize_per_tensor_aten(
-      input, scale, zero_point, quant_min, quant_max, dtype);
-
-  // Convert to int for consistent display regardless of underlying type
-  at::Tensor reference_int = reference_out.to(at::kInt);
-  at::Tensor impl_int = impl_out.to(at::kInt);
-
-  const bool output_correct = at::equal(reference_int, impl_int);
-  if (!output_correct) {
-    at::Tensor diffs = at::abs(reference_int - impl_int);
-
-    std::cout << "\n"
-              << "Failed with parameters: " << std::endl;
-    std::cout << "  scale: " << scale << std::endl;
-    std::cout << "  zero_point: " << zero_point << std::endl;
-    std::cout << "  quant_min: " << quant_min << std::endl;
-    std::cout << "  quant_max: " << quant_max << std::endl;
-
-    std::cout << "input:" << std::endl;
-    std::cout << input << std::endl;
-    std::cout << "reference:" << std::endl;
-    std::cout << reference_int << std::endl;
-    std::cout << "my_reference:" << std::endl;
-    std::cout << impl_int << std::endl;
-  }
-
-  ASSERT_TRUE(output_correct);
-}
-
-TEST(
-    VulkanQuantizePerTensorTest,
-    test_reference_quantize_per_tensor_float_to_int8) {
-  test_reference_quantize_per_tensor(
-      {2, 3, 4}, // input sizes
-      0.1, // scale
-      0, // zero_point
-      -128, // quant_min
-      127, // quant_max
-      at::kFloat,
-      at::kChar);
-}
-
-TEST(
-    VulkanQuantizePerTensorTest,
-    test_reference_quantize_per_tensor_float_to_int32) {
-  test_reference_quantize_per_tensor(
-      {2, 3, 4}, // input sizes
-      0.04, // scale
-      5, // zero_point
-      std::numeric_limits<int32_t>::min(), // quant_min
-      std::numeric_limits<int32_t>::max(), // quant_max
-      at::kFloat,
-      at::kInt);
-}
-
-TEST(
-    VulkanQuantizePerTensorTest,
-    test_reference_quantize_per_tensor_half_to_uint8) {
-  test_reference_quantize_per_tensor(
-      {2, 3, 4}, // input sizes
-      0.2, // scale
-      2, // zero_point
-      0, // quant_min
-      255, // quant_max
-      at::kHalf,
-      at::kByte);
-}
-
-TEST(
-    VulkanQuantizePerTensorTest,
-    test_reference_quantize_per_tensor_half_to_int32) {
-  test_reference_quantize_per_tensor(
-      {2, 3, 4}, // input sizes
-      0.01, // scale
-      1, // zero_point
-      std::numeric_limits<int32_t>::min(), // quant_min
-      std::numeric_limits<int32_t>::max(), // quant_max
-      at::kHalf,
-      at::kInt);
-}
-
-// No Vulkan tests for quantized_decomposed.quantize_per_tensor.default
-// because it is not going to be implemented in Vulkan since we will
-// be handling any future calls to this op via the export stage
-
-void test_reference_quantize_per_token(
-    const std::vector<int>& input_sizes,
-    const std::vector<float>& pre_scales,
-    const std::vector<int>& zero_points,
-    int64_t quant_min,
-    int64_t quant_max,
-    at::ScalarType in_dtype = at::kFloat,
-    at::ScalarType
dtype = at::kInt) { - check_quantize_args(quant_min, quant_max, dtype); - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); - - // Fill with a simple pattern: values from 0 to 1 in steps - float step = 1.0 / (input.numel() - 1); - auto flat_input = input.flatten(); - for (int i = 0; i < flat_input.numel(); i++) { - flat_input[i] = i * step; - } - - // Reshape back to original dimensions - input = flat_input.reshape(input_sizes_int64); - - // Calculate number of tokens - int num_tokens = 1; - for (int i = 0; i < input.dim() - 1; i++) { - num_tokens *= input.size(i); - } - - // Verify that the number of tokens matches the size of scales and zero_points - ASSERT_EQ(num_tokens, pre_scales.size()); - ASSERT_EQ(num_tokens, zero_points.size()); - - std::vector scales = pre_scales; - for (auto& s : scales) { - s = s < eps ? eps : s; - } - - // Create scale and zero_point tensors - at::Tensor scale_tensor = - at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_tensor = - at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output - at::Tensor reference_out = quantize_per_token_reference_impl( - input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype); - - // Get implementation output - at::Tensor impl_out = torch::executor::native::quantize_per_token_aten( - input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype); - - // Convert to int for consistent display regardless of underlying type - at::Tensor reference_int = reference_out.to(at::kInt); - at::Tensor impl_int = impl_out.to(at::kInt); - - const bool output_correct = at::equal(reference_int, impl_out); - if (!output_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " scale(s):"; - for (size_t i = 0; i < scales.size(); i++) { - std::cout << " " << scales[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " zero_point(s):"; - for (size_t i = 0; i < zero_points.size(); i++) { - std::cout << " " << zero_points[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference:" << std::endl; - std::cout << reference_int << std::endl; - std::cout << "my_reference:" << std::endl; - std::cout << impl_out << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -void test_vulkan_quantize_per_token_impl( - const std::vector& input_sizes, - const std::vector& pre_scales, - const std::vector& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype = at::kFloat, - at::ScalarType dtype = at::kInt, - const vkcompute::utils::StorageType in_storage = - vkcompute::utils::kTexture3D, - const vkcompute::utils::StorageType out_storage = - vkcompute::utils::kTexture3D) { - check_quantize_args(quant_min, quant_max, dtype); - int num_tokens = 1; - for (int i = 0; i < input_sizes.size() - 1; i++) { - num_tokens *= input_sizes[i]; - } - - ASSERT_EQ(num_tokens, pre_scales.size()); - ASSERT_EQ(num_tokens, zero_points.size()); - - std::vector scales = pre_scales; - for (auto& s : scales) { - s = s < eps ? 
eps : s; - } - - // Create input tensor with random values - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); - at::Tensor scale_tensor = - at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_tensor = - at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output to show what we would compare against - at::Tensor reference_out = torch::executor::native::quantize_per_token_aten( - input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype); - - using namespace vkcompute; - - GraphConfig config; - config.set_storage_type_override(in_storage); - ComputeGraph graph(config); - - IOValueRef r_input = graph.add_input_tensor( - input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); - IOValueRef r_scale = graph.add_input_tensor( - scale_tensor.sizes().vec(), - vkapi::kFloat, - utils::kBuffer, - utils::kWidthPacked); - IOValueRef r_zero_point = graph.add_input_tensor( - zero_point_tensor.sizes().vec(), - vkapi::kInt, - utils::kBuffer, - utils::kWidthPacked); - - const ValueRef r_quant_min = graph.add_scalar(quant_min); - const ValueRef r_quant_max = graph.add_scalar(quant_max); - - const ValueRef r_out = graph.add_tensor( - input.sizes().vec(), from_at_scalartype(dtype), out_storage); - - const ValueRef r_dtype = - graph.add_scalar(static_cast(dtype)); - - VK_GET_OP_FN("quantized_decomposed.quantize_per_token.default") - (graph, - { - r_input.value, - r_scale.value, - r_zero_point.value, - r_quant_min, - r_quant_max, - r_dtype, - r_out, - }); - - ValueRef staging_out = graph.set_output_tensor(r_out); - - graph.prepare(); - - graph.prepack(); - - // Copy input data to GPU - graph.copy_into_staging( - r_input.staging, input.const_data_ptr(), input.numel()); - - // Convert scale tensor to float and copy to GPU - at::Tensor scale_float = scale_tensor.to(at::kFloat); - graph.copy_into_staging( - r_scale.staging, scale_float.const_data_ptr(), scale_float.numel()); - - // Convert zero_point tensor to int and copy to GPU - at::Tensor zero_point_int = zero_point_tensor.to(at::kInt); - graph.copy_into_staging( - r_zero_point.staging, - zero_point_int.const_data_ptr(), - zero_point_int.numel()); - - // Execute the graph - graph.execute(); - - // Copy output data back to CPU - at::Tensor vk_out = at::empty_like(reference_out).contiguous(); - graph.copy_from_staging( - staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); - - // Compare outputs - at::Tensor reference_int = reference_out.to(at::kInt); - at::Tensor vk_int = vk_out.to(at::kInt); - - // Tolerance is 1 to address rounding errors and fp math differences between - // CPU/GPU - const bool output_correct = - at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1); - if (!output_correct) { - at::Tensor diffs = at::abs(reference_int - vk_int); - - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " scale(s):"; - for (size_t i = 0; i < scales.size(); i++) { - std::cout << " " << scales[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " zero_point(s):"; - for (size_t i = 0; i < zero_points.size(); i++) { - std::cout << " " << zero_points[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " storage type: " - << (in_storage == vkcompute::utils::kBuffer ? 
"buffer" - : "texture") - << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference:" << std::endl; - std::cout << reference_int << std::endl; - std::cout << "vulkan:" << std::endl; - std::cout << vk_int << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -TEST( - VulkanQuantizePerTokenTest, - test_reference_quantize_per_token_float_to_int8) { - std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; - std::vector zero_points = {1, 2, 3, 0, -1, -2}; - - test_reference_quantize_per_token( - {2, 3, 4}, // input sizes (2*3=6 tokens) - scales, - zero_points, - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerTokenTest, - test_reference_quantize_per_token_float_to_int32) { - std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; - std::vector zero_points = {1, 2, 3, 0, -1, -2}; - - test_reference_quantize_per_token( - {2, 3, 4}, // input sizes (2*3=6 tokens) - scales, - zero_points, - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kFloat, - at::kInt); -} - -TEST( - VulkanQuantizePerTokenTest, - test_reference_quantize_per_token_half_to_int32) { - std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; - std::vector zero_points = {1, 2, 3, 0, -1, -2}; - - test_reference_quantize_per_token( - {2, 3, 4}, // input sizes (2*3=6 tokens) - scales, - zero_points, - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kHalf, - at::kInt); -} - -TEST( - VulkanQuantizePerTokenTest, - test_reference_quantize_per_token_half_to_uint8) { - std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; - std::vector zero_points = {1, 2, 3, 0, -1, -2}; - - test_reference_quantize_per_token( - {2, 3, 4}, // input sizes (2*3=6 tokens) - scales, - zero_points, - 0, // quant_min - 255, // quant_max - at::kHalf, - at::kByte); -} - -TEST( - VulkanQuantizePerTokenTest, - test_vulkan_quantize_per_token_float_to_uint8) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = { - -0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, 0.4}; - std::vector zero_points = {-8, 0, 15, 20, 19, 12, 47, 1, -50, -12}; - - test_vulkan_quantize_per_token( - {5, 2, 4}, // input sizes (5*2=10 tokens) - scales, - zero_points, - 0, // quant_min - 255, // quant_max - at::kFloat, - at::kByte); -} - -TEST(VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_float_to_int8) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = { - -0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, 0.4}; - std::vector zero_points = {-8, 0, 15, 20, 19, 12, 47, 1, -50, -12}; - - test_vulkan_quantize_per_token( - {5, 2, 4}, // input sizes (5 tokens) - scales, - zero_points, - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerTokenTest, - test_vulkan_quantize_per_token_float_to_int32) { - std::vector scales = { - -0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, 0.4}; - std::vector zero_points = {-8, 0, 15, 20, 19, 12, 47, 1, -50, -12}; - - test_vulkan_quantize_per_token( - {5, 2, 4}, // input sizes (5*2=10 tokens) - scales, - zero_points, - -2147483648, // quant_min - 2147483647, // quant_max - at::kFloat, - at::kInt); -} - -TEST( - VulkanQuantizePerTokenTest, - test_vulkan_quantize_per_token_float_to_int32_small_scales) { - std::vector scales = { - 0, - 2.9387358770557188e-39f, - 1.40129846e-45f, - 
1.17549435e-38f, - 0.0000000000001}; - std::vector zero_points = {20, -10, 15, 200, 50}; - - test_vulkan_quantize_per_token( - {5, 2}, // input sizes (3 tokens) - scales, - zero_points, - -2147483648, // quant_min - 2147483647, // quant_max - at::kFloat, - at::kInt); -} - -TEST( - VulkanQuantizePerTokenTest, - test_vulkan_quantize_per_token_float_to_uint8_many_tokens) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales(18, 0.1); - std::vector zero_points(18, 5); - - // Alternate scale values - for (size_t i = 0; i < scales.size(); i++) { - scales[i] = (i % 2 == 0) ? 0.3 : -0.5; - } - - test_vulkan_quantize_per_token( - {3, 3, 2, 3}, // input sizes (3*3*2=18 tokens) - scales, - zero_points, - 0, // quant_min - 125, // quant_max - at::kFloat, - at::kByte); -} - -TEST(VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_half_to_int8) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_float16_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.1, 0.2}; - std::vector zero_points = {0, 5}; - - test_vulkan_quantize_per_token( - {2, 2}, // input sizes (2*2=4 tokens) - scales, - zero_points, - -128, // quant_min - 127, // quant_max - at::kHalf, // input dtype - at::kChar); // output dtype -} - -TEST( - VulkanQuantizePerTokenTest, - test_vulkan_quantize_per_token_double_to_int8) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.1, 0.2}; - std::vector zero_points = {0, 5}; - - test_vulkan_quantize_per_token( - {2, 2}, // input sizes (2*2=4 tokens) - scales, - zero_points, - -128, // quant_min - 127, // quant_max - at::kDouble, // input dtype - at::kChar); // output dtype -} - -void test_reference_quantize_per_channel( - const std::vector& input_sizes, - const std::vector& pre_scales, - const std::vector& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype = at::kFloat, - at::ScalarType dtype = at::kInt) { - check_quantize_args(quant_min, quant_max, dtype); - check_quantize_per_channel_args(input_sizes, pre_scales, zero_points, axis); - - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); - - // Fill with a simple pattern: values from 0 to 1 in steps - float step = 1.0f / (input.numel() - 1); - auto flat_input = input.flatten(); - for (int i = 0; i < flat_input.numel(); i++) { - flat_input[i] = i * step; - } - - // Reshape back to original dimensions - input = flat_input.reshape(input_sizes_int64); - - std::vector scales = pre_scales; - for (auto& s : scales) { - s = s < eps ? 
eps : s; - } - - // Create scale and zero_point tensors - at::Tensor scale_tensor = - at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_tensor = - at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output - at::Tensor my_ref = quantize_per_channel_reference_impl( - input, - scale_tensor, - zero_point_tensor, - axis, - quant_min, - quant_max, - dtype); - - // Get implementation output - at::Tensor cpu_ref = torch::executor::native::quantize_per_channel_aten( - input, - scale_tensor, - zero_point_tensor, - axis, - quant_min, - quant_max, - dtype); - - // Get direct ATen implementation output - c10::ScalarType aten_dtype = dtype; - if (dtype == at::kChar) { - aten_dtype = c10::kQInt8; - } else if (dtype == at::kByte) { - aten_dtype = c10::kQUInt8; - } - - // Normalize axis for ATen (it doesn't handle negative values) - int64_t normalized_axis = axis; - if (normalized_axis < 0) { - normalized_axis += input.dim(); - } - - at::Tensor aten_ref = at::quantize_per_channel( - input, scale_tensor, zero_point_tensor, normalized_axis, aten_dtype); - - // Convert to int for consistent display regardless of underlying type - at::Tensor my_ref_int = my_ref.to(at::kInt); - at::Tensor cpu_ref_int = cpu_ref.to(at::kInt); - // For quantized tensors, we need to use int_repr() to get the underlying - // integer values - at::Tensor aten_ref_int = aten_ref.int_repr().to(at::kInt); - - const bool output_correct = at::equal(my_ref_int, cpu_ref_int); - if (!output_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " axis: " << axis << std::endl; - std::cout << " input sizes:"; - for (size_t i = 0; i < input_sizes.size(); i++) { - std::cout << " " << input_sizes[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " scale(s):"; - for (size_t i = 0; i < scales.size(); i++) { - std::cout << " " << scales[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " zero_point(s):"; - for (size_t i = 0; i < zero_points.size(); i++) { - std::cout << " " << zero_points[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "aten_ref:" << std::endl; - std::cout << aten_ref_int << std::endl; - std::cout << "cpu_ref:" << std::endl; - std::cout << cpu_ref_int << std::endl; - std::cout << "my_ref:" << std::endl; - std::cout << my_ref_int << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -void test_vulkan_quantize_per_channel_impl( - const std::vector& input_sizes, - const std::vector& pre_scales, - const std::vector& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype = at::kFloat, - at::ScalarType dtype = at::kInt, - const vkcompute::utils::StorageType in_storage = - vkcompute::utils::kTexture3D, - const vkcompute::utils::StorageType out_storage = - vkcompute::utils::kTexture3D) { - check_quantize_args(quant_min, quant_max, dtype); - check_quantize_per_channel_args(input_sizes, pre_scales, zero_points, axis); - - std::vector scales = pre_scales; - for (auto& s : scales) { - s = s < eps ? 
eps : s; - } - - // Create input tensor with random values - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); - at::Tensor scale_tensor = - at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_tensor = - at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output - at::Tensor reference_out = torch::executor::native::quantize_per_channel_aten( - input, - scale_tensor, - zero_point_tensor, - axis, - quant_min, - quant_max, - dtype); - - using namespace vkcompute; - - GraphConfig config; - config.set_storage_type_override(in_storage); - ComputeGraph graph(config); - - IOValueRef r_input = graph.add_input_tensor( - input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); - IOValueRef r_scale = graph.add_input_tensor( - scale_tensor.sizes().vec(), - vkapi::kFloat, - utils::kBuffer, - utils::kWidthPacked); - IOValueRef r_zero_point = graph.add_input_tensor( - zero_point_tensor.sizes().vec(), - vkapi::kInt, - utils::kBuffer, - utils::kWidthPacked); - - const ValueRef r_axis = graph.add_scalar(axis); - const ValueRef r_quant_min = graph.add_scalar(quant_min); - const ValueRef r_quant_max = graph.add_scalar(quant_max); - - const ValueRef r_out = graph.add_tensor( - input.sizes().vec(), from_at_scalartype(dtype), out_storage); - - const ValueRef r_dtype = - graph.add_scalar(static_cast(dtype)); - - VK_GET_OP_FN("quantized_decomposed.quantize_per_channel.default") - (graph, - { - r_input.value, - r_scale.value, - r_zero_point.value, - r_axis, - r_quant_min, - r_quant_max, - r_dtype, - r_out, - }); - - ValueRef staging_out = graph.set_output_tensor(r_out); - - graph.prepare(); - graph.prepack(); - - // Copy input data to GPU - graph.copy_into_staging( - r_input.staging, input.const_data_ptr(), input.numel()); - - // Convert scale tensor to float and copy to GPU - at::Tensor scale_float = scale_tensor.to(at::kFloat); - graph.copy_into_staging( - r_scale.staging, scale_float.const_data_ptr(), scale_float.numel()); - - // Convert zero_point tensor to int and copy to GPU - at::Tensor zero_point_int = zero_point_tensor.to(at::kInt); - graph.copy_into_staging( - r_zero_point.staging, - zero_point_int.const_data_ptr(), - zero_point_int.numel()); - - // Execute the graph - graph.execute(); - - // Copy output data back to CPU - at::Tensor vk_out = at::empty_like(reference_out).contiguous(); - graph.copy_from_staging( - staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); - - // Compare outputs - at::Tensor reference_int = reference_out.to(at::kInt); - at::Tensor vk_int = vk_out.to(at::kInt); - - // Tolerance is 1 to address rounding errors and fp math differences between - // CPU/GPU - const bool output_correct = - at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1); - if (!output_correct) { - at::Tensor diffs = at::abs(reference_int - vk_int); - - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " axis: " << axis << std::endl; - std::cout << " input sizes:"; - for (size_t i = 0; i < input_sizes.size(); i++) { - std::cout << " " << input_sizes[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " scale(s):"; - for (size_t i = 0; i < scales.size(); i++) { - std::cout << " " << scales[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " zero_point(s):"; - for (size_t i = 0; i < zero_points.size(); i++) { - std::cout << " " << zero_points[i] << " "; - } - 
std::cout << "" << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " storage type: " - << (in_storage == vkcompute::utils::kBuffer ? "buffer" - : "texture") - << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference:" << std::endl; - std::cout << reference_int << std::endl; - std::cout << "vulkan:" << std::endl; - std::cout << vk_int << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -TEST( - VulkanQuantizePerChannelTest, - test_reference_quantize_per_channel_float_to_int8_3D_axis0) { - std::vector scales = {0.1, 0.2, 0.3}; - std::vector zero_points = {0, 5, -2}; - - test_reference_quantize_per_channel( - {3, 4, 2}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerChannelTest, - test_reference_quantize_per_channel_float_to_int8_3D_axis2) { - std::vector scales = {0.1, 0.2}; - std::vector zero_points = {0, 5}; - - test_reference_quantize_per_channel( - {3, 4, 2}, // input sizes - scales, - zero_points, - 2, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerChannelTest, - test_reference_quantize_per_channel_float_to_int8_3D_axisn1) { - std::vector scales = {0.1, 0.2}; - std::vector zero_points = {0, 5}; - - test_reference_quantize_per_channel( - {3, 4, 2}, // input sizes - scales, - zero_points, - -1, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerChannelTest, - test_reference_quantize_per_channel_float_to_int8_4D_axis0) { - std::vector scales = {0.1, 0.2, 0.00002}; - std::vector zero_points = {0, 5, -4}; - - test_reference_quantize_per_channel( - {3, 4, 2, 5}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -// END OF REFERENCE TESTS - -TEST( - VulkanQuantizePerChannelTest, - test_vulkan_quantize_per_channel_float_to_int8_axis0) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales(9, 0.1f); - std::vector zero_points(9, 2); - - // 1D Tensor - test_vulkan_quantize_per_channel( - {9}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); - - // 2D Tensor - test_vulkan_quantize_per_channel( - {9, 14}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); - - // 3D Tensor - test_vulkan_quantize_per_channel( - {9, 7, 11}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 17, 5, 5}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); - - // 4D Tensor (negative axis) - test_vulkan_quantize_per_channel( - {5, 17, 5, 9}, // input sizes - scales, - zero_points, - -1, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerChannelTest, - test_vulkan_quantize_per_channel_float_to_int8_axis1) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales(14, 0.001f); - std::vector zero_points(14, -5); - - // 2D Tensor - 
test_vulkan_quantize_per_channel( - {9, 14}, // input sizes - scales, - zero_points, - 1, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); - - // 3D Tensor - test_vulkan_quantize_per_channel( - {9, 14, 11}, // input sizes - scales, - zero_points, - 1, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 14, 5, 5}, // input sizes - scales, - zero_points, - 1, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); - - // 4D Tensor (negative axis) - test_vulkan_quantize_per_channel( - {9, 7, 14, 5}, // input sizes - scales, - zero_points, - -2, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerChannelTest, - test_vulkan_quantize_per_channel_float_to_int8_axis2) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales(11, 0.5f); - std::vector zero_points(11, 12); - - // 3D Tensor - test_vulkan_quantize_per_channel( - {9, 14, 11}, // input sizes - scales, - zero_points, - 2, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 14, 11, 5}, // input sizes - scales, - zero_points, - 2, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); - - // 4D Tensor (negative axis) - test_vulkan_quantize_per_channel( - {9, 11, 14, 5}, // input sizes - scales, - zero_points, - -3, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerChannelTest, - test_vulkan_quantize_per_channel_float_to_int8_axis3) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales(7, 0.5f); - std::vector zero_points(7, 12); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 14, 11, 7}, // input sizes - scales, - zero_points, - 3, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); - - // 4D Tensor (negative axis) - test_vulkan_quantize_per_channel( - {7, 14, 11, 7}, // input sizes - scales, - zero_points, - -4, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerChannelTest, - test_vulkan_quantize_per_channel_float_to_uint8_comprehensive) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.1, 0.2, 0.0001, 0.5, 0.02}; - std::vector zero_points = {0, 5, -5, 1, 12}; - - // 4D Tensor - test_vulkan_quantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - 0, // axis - 0, // quant_min - 255, // quant_max - at::kFloat, - at::kByte); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 5, 11, 7}, // input sizes - scales, - zero_points, - 1, // axis - 0, // quant_min - 255, // quant_max - at::kFloat, - at::kByte); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 14, 5, 7}, // input sizes - scales, - zero_points, - 2, // axis - 0, // quant_min - 255, // quant_max - at::kFloat, - at::kByte); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 14, 11, 5}, // input sizes - scales, - zero_points, - 3, // axis - 0, // quant_min - 255, // quant_max - at::kFloat, - at::kByte); - - // 4D Tensor (negative axis) - test_vulkan_quantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - -4, // axis - 0, // quant_min - 
255, // quant_max - at::kFloat, - at::kByte); -} - -TEST( - VulkanQuantizePerChannelTest, - test_vulkan_quantize_per_channel_half_to_8bit) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_float16_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.1, 0.2, 0.01, 0.5, 0.02}; - std::vector zero_points = {0, 5, 5, 1, 12}; - - // 4D Tensor - test_vulkan_quantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kHalf, - at::kChar); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 5, 11, 7}, // input sizes - scales, - zero_points, - 1, // axis - -128, // quant_min - 127, // quant_max - at::kHalf, - at::kChar); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 14, 5, 7}, // input sizes - scales, - zero_points, - 2, // axis - 0, // quant_min - 255, // quant_max - at::kHalf, - at::kByte); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 14, 11, 5}, // input sizes - scales, - zero_points, - 3, // axis - -128, // quant_min - 127, // quant_max - at::kHalf, - at::kChar); - - // 4D Tensor (negative axis) - test_vulkan_quantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - -4, // axis - 0, // quant_min - 255, // quant_max - at::kHalf, - at::kByte); -} - -TEST( - VulkanQuantizePerChannelTest, - test_vulkan_quantize_per_channel_double_to_8bit) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.1, 0.2, 0.01, 0.5, 0.02}; - std::vector zero_points = {0, 5, 5, 1, 12}; - - // 4D Tensor - test_vulkan_quantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kDouble, - at::kChar); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 5, 11, 7}, // input sizes - scales, - zero_points, - 1, // axis - -128, // quant_min - 127, // quant_max - at::kDouble, - at::kChar); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 14, 5, 7}, // input sizes - scales, - zero_points, - 2, // axis - 0, // quant_min - 255, // quant_max - at::kDouble, - at::kByte); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 14, 11, 5}, // input sizes - scales, - zero_points, - 3, // axis - -128, // quant_min - 127, // quant_max - at::kDouble, - at::kChar); - - // 4D Tensor (negative axis) - test_vulkan_quantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - -4, // axis - 0, // quant_min - 255, // quant_max - at::kDouble, - at::kByte); -} - -void test_vulkan_quantize_per_tensor_tensor_impl( - const std::vector& input_sizes, - float scale, - int zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype = at::kFloat, - at::ScalarType dtype = at::kInt, - const vkcompute::utils::StorageType in_storage = - vkcompute::utils::kTexture3D, - const vkcompute::utils::StorageType out_storage = - vkcompute::utils::kTexture3D) { - check_quantize_args(quant_min, quant_max, dtype); - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); - - scale = scale < eps ? 
eps : scale; - - // Create scale and zero_point as tensors (single element tensors) - at::Tensor scale_tensor = - at::tensor({scale}, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_tensor = - at::tensor({zero_point}, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output using tensor variant - at::Tensor reference_out = - torch::executor::native::quantize_per_tensor_tensor_args_aten( - input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype); - - // Build Vulkan quantize_per_tensor.tensor graph - using namespace vkcompute; - - GraphConfig config; - config.set_storage_type_override(in_storage); - ComputeGraph graph(config); - - IOValueRef r_input = graph.add_input_tensor( - input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); - - // Add scale and zero_point as tensor inputs (buffer storage, width packed) - IOValueRef r_scale = graph.add_input_tensor( - scale_tensor.sizes().vec(), - vkapi::kFloat, - utils::kBuffer, - utils::kWidthPacked); - IOValueRef r_zero_point = graph.add_input_tensor( - zero_point_tensor.sizes().vec(), - vkapi::kInt, - utils::kBuffer, - utils::kWidthPacked); - - const ValueRef r_quant_min = graph.add_scalar(quant_min); - const ValueRef r_quant_max = graph.add_scalar(quant_max); - - const ValueRef r_out = graph.add_tensor( - input.sizes().vec(), from_at_scalartype(dtype), out_storage); - - const ValueRef r_dtype = - graph.add_scalar(static_cast(dtype)); - - VK_GET_OP_FN("quantized_decomposed.quantize_per_tensor.tensor") - (graph, - { - r_input.value, - r_scale.value, - r_zero_point.value, - r_quant_min, - r_quant_max, - r_dtype, - r_out, - }); - - ValueRef staging_out = graph.set_output_tensor(r_out); - - graph.prepare(); - graph.prepack(); - - // Run Vulkan quantize_per_tensor.tensor - graph.copy_into_staging( - r_input.staging, input.const_data_ptr(), input.numel()); - - // Convert scale tensor to float and copy to GPU - at::Tensor scale_float = scale_tensor.to(at::kFloat); - graph.copy_into_staging( - r_scale.staging, scale_float.const_data_ptr(), scale_float.numel()); - - // Convert zero_point tensor to int and copy to GPU - at::Tensor zero_point_int = zero_point_tensor.to(at::kInt); - graph.copy_into_staging( - r_zero_point.staging, - zero_point_int.const_data_ptr(), - zero_point_int.numel()); - - graph.execute(); - - at::Tensor vk_out = at::empty_like(reference_out).contiguous(); - graph.copy_from_staging( - staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); - - // Compare outputs - // For quantized types, we need to compare the actual integer values - at::Tensor reference_int = reference_out.to(at::kInt); - at::Tensor vk_int = vk_out.to(at::kInt); - - // Tolerance is 1 to address rounding errors and fp math differences between - // CPU/GPU - const bool output_correct = - at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1); - if (!output_correct) { - at::Tensor diffs = at::abs(reference_int - vk_int); - - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " scale: " << scale << std::endl; - std::cout << " zero_point: " << zero_point << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " storage type: " - << (in_storage == vkcompute::utils::kBuffer ? 
"buffer" - : "texture") - << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference:" << std::endl; - std::cout << reference_int << std::endl; - std::cout << "vulkan:" << std::endl; - std::cout << vk_int << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -TEST( - VulkanQuantizePerTensorTensorTest, - test_vulkan_quantize_per_tensor_tensor_float_to_int8) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_quantize_per_tensor_tensor( - {2, 3, 4}, // input sizes - 0.01, // scale - 1, // zero_point - -128, // quant_min - 127, // quant_max - at::kFloat, // input dtype - at::kChar); // output dtype -} - -TEST( - VulkanQuantizePerTensorTensorTest, - test_vulkan_quantize_per_tensor_tensor_float_to_uint8) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_quantize_per_tensor_tensor( - {2, 3, 4, 12}, // input sizes - 0.1, // scale - 5, // zero_point - 0, // quant_min - 255, // quant_max - at::kFloat, // input dtype - at::kByte); // output dtype -} - -TEST( - VulkanQuantizePerTensorTensorTest, - test_vulkan_quantize_per_tensor_tensor_float_to_int32) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_quantize_per_tensor_tensor( - {2, 3}, // input sizes - 0.01, // scale - 12, // zero_point - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kFloat, // input dtype - at::kInt); // output dtype -} - -TEST( - VulkanQuantizePerTensorTensorTest, - test_vulkan_quantize_per_tensor_tensor_half_to_uint8) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_quantize_per_tensor_tensor( - {3, 4}, // input sizes - 0.3, // scale - 2, // zero_point - 0, // quant_min - 255, // quant_max - at::kHalf, // input dtype - at::kByte); // output dtype -} - -TEST( - VulkanQuantizePerTensorTensorTest, - test_vulkan_quantize_per_tensor_tensor_double_to_int8) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_quantize_per_tensor_tensor( - {2, 3, 4}, // input sizes - 0.03, // scale - -2, // zero_point - -128, // quant_min - 127, // quant_max - at::kDouble, // input dtype - at::kChar); // output dtype -} diff --git a/backends/vulkan/test/op_tests/sdpa_test.cpp b/backends/vulkan/test/op_tests/sdpa_test.cpp index a94e68a53af..c3347b339a7 100644 --- a/backends/vulkan/test/op_tests/sdpa_test.cpp +++ b/backends/vulkan/test/op_tests/sdpa_test.cpp @@ -23,6 +23,24 @@ #include #include +// +// SDPA Mode Enum +// + +enum class SDPAMode { DECOMPOSED, FUSED, ATTN_WEIGHT_ONLY }; + +std::ostream& operator<<(std::ostream& os, const SDPAMode& mode) { + switch (mode) { + case SDPAMode::DECOMPOSED: + return os << "DECOMPOSED"; + case SDPAMode::FUSED: + return os << "FUSED"; + case SDPAMode::ATTN_WEIGHT_ONLY: + return os << "ATTN_WEIGHT_ONLY"; + } + return os; +} + namespace torch { namespace executor { namespace native { @@ -74,7 +92,7 @@ at::Tensor sdpa_with_kv_cache_aten( const int64_t seq_len, // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy - const std::optional attn_mask, + const std::optional& attn_mask, const double dropout_p, const bool is_causal, // @lint-ignore CLANGTIDY 
facebook-hte-ParameterMightThrowOnCopy @@ -161,10 +179,11 @@ at::Tensor sdpa_reference_impl( at::Tensor& value_cache, const int64_t start_pos, const int64_t seq_len, - const std::optional __attn_mask_ignored, + const std::optional& __attn_mask_ignored, const double dropout_p, const bool is_causal, - const std::optional scale) { + const std::optional scale, + SDPAMode mode = SDPAMode::DECOMPOSED) { at::Tensor attn_mask = construct_attention_mask(q_projected, key_cache, start_pos); @@ -202,6 +221,10 @@ at::Tensor sdpa_reference_impl( float scale_factor = 1.0 / sqrt(q_transposed.size(-1)); at::Tensor attn_weight = attn_weight_prescale * scale_factor + attn_mask; + if (mode == SDPAMode::ATTN_WEIGHT_ONLY) { + return attn_weight; + } + at::Tensor attn_weight_softmax = at::softmax(attn_weight, -1); at::Tensor out = at::matmul(attn_weight_softmax, v_transposed); @@ -268,7 +291,8 @@ void test_vulkan_sdpa( const int num_kv_heads, const int batch_size, vkcompute::utils::StorageType storage_type, - at::ScalarType dtype = at::kFloat) { + at::ScalarType dtype = at::kFloat, + SDPAMode mode = SDPAMode::DECOMPOSED) { // compute the max sequence length int max_seq_len = start_input_pos; for (int i = 0; i < sequence_lens.size(); ++i) { @@ -296,6 +320,9 @@ void test_vulkan_sdpa( // Get reference output at::Tensor out = at::empty_like(q); + if (mode == SDPAMode::ATTN_WEIGHT_ONLY) { + out = at::empty({batch_size, num_heads, init_seq_len, init_seq_len}); + } // Build Vulkan SDPA graph using namespace vkcompute; @@ -330,22 +357,87 @@ void test_vulkan_sdpa( const ValueRef r_out = graph.add_tensor( out.sizes().vec(), from_at_scalartype(out.scalar_type()), storage_type); - VK_GET_OP_FN("sdpa_with_kv_cache.default") - (graph, - { - r_q.value, - r_k.value, - r_v.value, - r_k_cache_data, - r_v_cache_data, - r_input_pos_symint, - kDummyValueRef, // sequence_len - kDummyValueRef, // attn_mask - kDummyValueRef, // dropout_p - kDummyValueRef, // is_causal - kDummyValueRef, // scale - r_out, - }); + switch (mode) { + case SDPAMode::DECOMPOSED: { + const ValueRef r_k_cache = graph.add_tensor( + k_cache_data.sizes().vec(), + from_at_scalartype(k_cache_data.scalar_type()), + storage_type); + const ValueRef r_v_cache = graph.add_tensor( + v_cache_data.sizes().vec(), + from_at_scalartype(v_cache_data.scalar_type()), + storage_type); + const ValueRef r_dummy_out = graph.add_tensor( + {1}, from_at_scalartype(out.scalar_type()), utils::kBuffer); + VK_GET_OP_FN("update_cache.default") + (graph, + { + r_k.value, + r_k_cache, + r_input_pos_symint, + r_dummy_out, + }); + VK_GET_OP_FN("update_cache.default") + (graph, + { + r_v.value, + r_v_cache, + r_input_pos_symint, + r_dummy_out, + }); + VK_GET_OP_FN("llama.custom_sdpa.default") + (graph, + { + r_q.value, + r_k_cache, + r_v_cache, + r_input_pos_symint, + kDummyValueRef, // attn_mask + kDummyValueRef, // dropout_p + kDummyValueRef, // is_causal + kDummyValueRef, // scale + r_out, + }); + } break; + case SDPAMode::FUSED: + VK_GET_OP_FN("sdpa_with_kv_cache.default") + (graph, + { + r_q.value, + r_k.value, + r_v.value, + r_k_cache_data, + r_v_cache_data, + r_input_pos_symint, + kDummyValueRef, // sequence_len + kDummyValueRef, // attn_mask + kDummyValueRef, // dropout_p + kDummyValueRef, // is_causal + kDummyValueRef, // scale + r_out, + }); + break; + case SDPAMode::ATTN_WEIGHT_ONLY: + VK_GET_OP_FN("testing.compute_attn_weight_with_kv_cache.default") + (graph, + { + r_q.value, + r_k.value, + r_v.value, + r_k_cache_data, + r_v_cache_data, + r_input_pos_symint, + kDummyValueRef, // 
sequence_len + kDummyValueRef, // attn_mask + kDummyValueRef, // dropout_p + kDummyValueRef, // is_causal + kDummyValueRef, // scale + r_out, + }); + break; + default: + VK_THROW("Unsupported SDPA mode"); + } ValueRef staging_out = graph.set_output_tensor(r_out); @@ -378,7 +470,7 @@ void test_vulkan_sdpa( v = at::rand_like(k); at::Tensor reference_out = sdpa_reference_impl( - q, k, v, k_cache, v_cache, input_pos, seq_len, {}, 0.0, true, {}); + q, k, v, k_cache, v_cache, input_pos, seq_len, {}, 0.0, true, {}, mode); graph.set_symint(r_input_pos_symint, input_pos); graph.resize_input(0, q.sizes().vec()); @@ -393,15 +485,38 @@ void test_vulkan_sdpa( graph.execute(); - out = at::empty_like(q); + if (mode == SDPAMode::ATTN_WEIGHT_ONLY) { + const int context_len = input_pos + seq_len; + const int context_len_align_up4 = (context_len + 3) & ~3; + const int seq_len_align_up4 = (seq_len + 3) & ~3; + + out = at::empty( + {batch_size, num_heads, seq_len_align_up4, context_len_align_up4}, + q.options()); + } else { + out = at::empty_like(q); + } EXTRACT_TENSOR(out); + if (mode == SDPAMode::ATTN_WEIGHT_ONLY) { + // Index vk_out to only include the relevant seq_len and context_len + // dimensions + int context_len = input_pos + seq_len; + vk_out = vk_out.index( + {at::indexing::Slice(), + at::indexing::Slice(), + at::indexing::Slice(0, seq_len), + at::indexing::Slice(0, context_len)}); + } + const bool output_correct = at::allclose(reference_out, vk_out); if (!output_correct) { // Print only differing tensor elements side by side for easier comparison auto ref_flat = reference_out.flatten(); auto vk_flat = vk_out.flatten(); auto numel = ref_flat.numel(); + std::cout << "While testing " << mode << " mode with " << storage_type + << " storage" << std::endl; std::cout << "reference_out\tvk_out\tindex" << std::endl; int first_diff_idx = -1; auto sizes = reference_out.sizes(); @@ -466,27 +581,32 @@ void test_vulkan_sdpa( const int num_kv_heads, const int batch_size, at::ScalarType dtype = at::kFloat) { - // Test texture - test_vulkan_sdpa( - start_input_pos, - sequence_lens, - head_dim, - num_heads, - num_kv_heads, - batch_size, - vkcompute::utils::kTexture3D, - dtype); - - // Test buffer - test_vulkan_sdpa( - start_input_pos, - sequence_lens, - head_dim, - num_heads, - num_kv_heads, - batch_size, - vkcompute::utils::kBuffer, - dtype); + for (SDPAMode mode : + {SDPAMode::ATTN_WEIGHT_ONLY, SDPAMode::DECOMPOSED, SDPAMode::FUSED}) { + // Test texture + test_vulkan_sdpa( + start_input_pos, + sequence_lens, + head_dim, + num_heads, + num_kv_heads, + batch_size, + vkcompute::utils::kTexture3D, + dtype, + mode); + + // Test buffer + test_vulkan_sdpa( + start_input_pos, + sequence_lens, + head_dim, + num_heads, + num_kv_heads, + batch_size, + vkcompute::utils::kBuffer, + dtype, + mode); + } } TEST(VulkanSDPATest, test_sdpa_op_small_params) { diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl index b9386f92772..dae2eddf8b2 100644 --- a/backends/vulkan/test/op_tests/targets.bzl +++ b/backends/vulkan/test/op_tests/targets.bzl @@ -177,33 +177,6 @@ def define_common_targets(is_fbcode = False): "//executorch/extension/tensor:tensor", ] ) - define_test_targets( - "quantize_test", - extra_deps = [ - ":test_utils", - "//executorch/kernels/quantized/cpu:op_quantize", - "//executorch/extension/tensor:tensor", - "//executorch/extension/aten_util:aten_bridge", - ] - ) - define_test_targets( - "dequantize_test", - extra_deps = [ - ":test_utils", - 
"//executorch/kernels/quantized/cpu:op_dequantize", - "//executorch/extension/tensor:tensor", - "//executorch/extension/aten_util:aten_bridge", - ] - ) - define_test_targets( - "choose_qparams_test", - extra_deps = [ - ":test_utils", - "//executorch/kernels/quantized/cpu:op_choose_qparams", - "//executorch/extension/tensor:tensor", - "//executorch/extension/aten_util:aten_bridge", - ] - ) define_test_targets( "quantized_linear_test", extra_deps = [ diff --git a/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py b/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py index 76eb9dbe838..cd27915225b 100644 --- a/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py +++ b/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py @@ -241,7 +241,7 @@ def generate_benchmark_fixture(self) -> str: return at::from_blob(values.data(), sizes, at::kFloat).toType(dtype).detach().clone(); }} -at::Tensor make_index_tensor_1d(std::vector indices) {{ +at::Tensor make_index_tensor_1d(std::vector indices) {{ at::ScalarType dtype = at::kInt; std::vector sizes = {{static_cast(indices.size())}}; @@ -249,7 +249,7 @@ def generate_benchmark_fixture(self) -> str: return at::from_blob(indices.data(), sizes, dtype).detach().clone(); }} -at::Tensor make_index_tensor_2d(std::vector> indices) {{ +at::Tensor make_index_tensor_2d(std::vector> indices) {{ at::ScalarType dtype = at::kInt; std::vector sizes = {{ static_cast(indices.size()), @@ -265,7 +265,7 @@ def generate_benchmark_fixture(self) -> str: return at::from_blob(acc.data(), sizes, dtype).detach().clone(); }} -at::Tensor make_index_tensor_3d(std::vector>> indices) {{ +at::Tensor make_index_tensor_3d(std::vector>> indices) {{ at::ScalarType dtype = at::kInt; std::vector sizes = {{ static_cast(indices.size()), diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_base.py b/backends/vulkan/test/op_tests/utils/gen_correctness_base.py index 80b4d5dead9..49419a50399 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_base.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_base.py @@ -348,7 +348,7 @@ def generate_suite_cpp(self) -> str: return at::from_blob(values.data(), sizes, at::kFloat).toType(dtype).detach().clone(); }} -at::Tensor make_index_tensor_1d(std::vector indices) {{ +at::Tensor make_index_tensor_1d(std::vector indices) {{ at::ScalarType dtype = at::kInt; std::vector sizes = {{static_cast(indices.size())}}; @@ -356,14 +356,14 @@ def generate_suite_cpp(self) -> str: return at::from_blob(indices.data(), sizes, dtype).detach().clone(); }} -at::Tensor make_index_tensor_2d(std::vector> indices) {{ +at::Tensor make_index_tensor_2d(std::vector> indices) {{ at::ScalarType dtype = at::kInt; std::vector sizes = {{ static_cast(indices.size()), static_cast(indices[0].size())}}; // Flatten indices as from_blob reads garbage otherwise. - std::vector acc; + std::vector acc; for (auto& vec: indices) {{ acc.insert(acc.end(), vec.begin(), vec.end()); }} @@ -372,7 +372,7 @@ def generate_suite_cpp(self) -> str: return at::from_blob(acc.data(), sizes, dtype).detach().clone(); }} -at::Tensor make_index_tensor_3d(std::vector>> indices) {{ +at::Tensor make_index_tensor_3d(std::vector>> indices) {{ at::ScalarType dtype = at::kInt; std::vector sizes = {{ static_cast(indices.size()), @@ -380,7 +380,7 @@ def generate_suite_cpp(self) -> str: static_cast(indices[0][0].size())}}; // Flatten indices as from_blob reads garbage otherwise. 
- std::vector acc; + std::vector acc; for (auto& v: indices) {{ for (auto& vv: v) {{ acc.insert(acc.end(), vv.begin(), vv.end()); diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py index c368c23c539..08bc502f964 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py @@ -34,6 +34,7 @@ class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple std::tie(test_dtype, default_storage_type, default_memory_layout) = GetParam(); config.set_storage_type_override(default_storage_type); config.set_memory_layout_override(default_memory_layout); + config.force_resize = true; graph = new ComputeGraph(config); if (test_dtype == at::kHalf) {{ diff --git a/backends/vulkan/test/scripts/test_model.sh b/backends/vulkan/test/scripts/test_model.sh index 5f06d2c039b..40ec88bae70 100755 --- a/backends/vulkan/test/scripts/test_model.sh +++ b/backends/vulkan/test/scripts/test_model.sh @@ -111,6 +111,7 @@ build_core_libraries_and_devtools() { -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \ diff --git a/backends/vulkan/test/scripts/test_op.sh b/backends/vulkan/test/scripts/test_op.sh index 1ec07b7f75f..797089e54dc 100755 --- a/backends/vulkan/test/scripts/test_op.sh +++ b/backends/vulkan/test/scripts/test_op.sh @@ -138,6 +138,7 @@ build_core_libraries() { -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \ diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index f8194f0b32c..03a3263c293 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -11,20 +11,14 @@ from typing import Tuple import executorch.backends.vulkan.test.utils as test_utils - import torch - from executorch.backends.transforms.convert_dtype_pass import I64toI32 - from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner - from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend - from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, ) - from executorch.exir import ( EdgeCompileConfig, EdgeProgramManager, @@ -36,11 +30,8 @@ ) from executorch.extension.pytree import tree_flatten from torch.export import Dim, export, ExportedProgram - from torchao.quantization.granularity import PerGroup - from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e - from torchao.quantization.pt2e.quantizer import Quantizer from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ from torchao.utils import unwrap_tensor_subclass @@ -69,9 +60,6 @@ def lower_module( edge_program = to_edge_transform_and_lower( program, compile_config=edge_compile_config, - transform_passes=[ - I64toI32(edge_compile_config._skip_dim_order), - ], partitioner=[VulkanPartitioner(compile_options)], ) @@ -1969,102 +1957,6 @@ def forward(self, x): 
             sample_inputs,
         )
 
-    def test_vulkan_backend_full_quantization_workflow(self):
-        class FullQuantizationWorkflowModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, x):
-                # Step 1: Choose quantization parameters per tensor
-                scale, zero_point = (
-                    torch.ops.quantized_decomposed.choose_qparams.tensor(
-                        x,
-                        quant_min=-2147483648,  # int32 min
-                        quant_max=2147483647,  # int32 max
-                        eps=1e-5,
-                        dtype=torch.int32,
-                    )
-                )
-
-                # Step 2: Quantize using the calculated parameters
-                quantized = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(
-                    x,
-                    scale,
-                    zero_point,
-                    quant_min=-2147483648,  # int32 min
-                    quant_max=2147483647,  # int32 max
-                    dtype=torch.int32,
-                )
-
-                # Step 3: Dequantize back to float
-                dequantized = (
-                    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor(
-                        quantized,
-                        scale,
-                        zero_point,
-                        quant_min=-2147483648,  # int32 min
-                        quant_max=2147483647,  # int32 max
-                        dtype=torch.int32,
-                    )
-                )
-
-                return dequantized
-
-        full_workflow_module = FullQuantizationWorkflowModule()
-        sample_inputs = (torch.rand(size=(2, 3, 4), dtype=torch.float32),)
-
-        # Use higher tolerance since quantization introduces some error
-        self.lower_module_and_test_output(
-            full_workflow_module, sample_inputs, atol=5e-3, rtol=5e-3
-        )
-
-    def test_vulkan_backend_full_per_token_quantization_workflow(self):
-        class FullPerTokenQuantizationWorkflowModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, x):
-                # Step 1: Choose quantization parameters per token
-                scale, zero_point = (
-                    torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
-                        x,
-                        dtype=torch.int32,
-                    )
-                )
-
-                # Step 2: Quantize using the calculated parameters per token
-                quantized = torch.ops.quantized_decomposed.quantize_per_token.default(
-                    x,
-                    scale,
-                    zero_point,
-                    quant_min=-2147483648,  # int32 min
-                    quant_max=2147483647,  # int32 max
-                    dtype=torch.int32,
-                )
-
-                # Step 3: Dequantize back to float per token
-                dequantized = (
-                    torch.ops.quantized_decomposed.dequantize_per_token.default(
-                        quantized,
-                        scale,
-                        zero_point,
-                        quant_min=-2147483648,  # int32 min
-                        quant_max=2147483647,  # int32 max
-                        dtype=torch.int32,
-                        output_dtype=torch.float32,
-                    )
-                )
-
-                return dequantized
-
-        full_per_token_workflow_module = FullPerTokenQuantizationWorkflowModule()
-        sample_inputs = (torch.rand(size=(6, 4), dtype=torch.float32),)
-
-        # Use higher tolerance since quantization introduces some error
-        self.lower_module_and_test_output(
-            full_per_token_workflow_module, sample_inputs, atol=5e-3, rtol=5e-3
-        )
-
     def test_vulkan_backend_different_required_reprs(self):
         class ComplexModule(torch.nn.Module):
             """
@@ -2482,6 +2374,7 @@ def forward(self, x):
             rtol=1e-1,
         )
 
+    @unittest.skip("Cannot run on swiftshader due to no integer dot product support")
     def test_vulkan_backend_xnnpack_pt2e_quantized_conv_sequence(self):
         """
         Test a sequence of convolution layers quantized with PT2E quantization.
@@ -2572,6 +2465,7 @@ def forward(self, x):
             rtol=1e-1,
         )
 
+    @unittest.skip("Cannot run on swiftshader due to no integer dot product support")
     def test_vulkan_backend_xnnpack_pt2e_quantized_conv_sequence_all_reduced(self):
        """
        Test a sequence of convolution layers quantized with PT2E quantization.
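
The two workflow tests removed above exercised the standard affine quantize/dequantize round trip (choose_qparams -> quantize -> dequantize). For reference, a minimal sketch of that round trip over an int8 range, written with plain torch ops purely to illustrate the math (not the quantized_decomposed ops the deleted tests called, and affine_roundtrip is an illustrative helper, not part of the codebase):

    import torch

    def affine_roundtrip(x: torch.Tensor, qmin: int = -128, qmax: int = 127):
        # Derive per-tensor scale/zero_point from the observed range, with an
        # eps floor so a constant input does not produce a zero scale.
        eps = 1e-5
        x_min = torch.minimum(x.min(), torch.zeros(()))
        x_max = torch.maximum(x.max(), torch.zeros(()))
        scale = torch.clamp((x_max - x_min) / (qmax - qmin), min=eps)
        zero_point = torch.clamp(qmin - torch.round(x_min / scale), qmin, qmax)

        # Quantize: q = clamp(round(x / scale) + zero_point, qmin, qmax)
        q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.int8)
        # Dequantize: x_hat = (q - zero_point) * scale
        x_hat = (q.to(torch.float32) - zero_point) * scale
        return q, x_hat

    x = torch.rand(2, 3, 4)
    q, x_hat = affine_roundtrip(x)
    # Reconstruction error is bounded by half a quantization step (scale / 2),
    # which is why the deleted tests used a relaxed atol/rtol of 5e-3.
    assert torch.allclose(x, x_hat, atol=5e-3, rtol=5e-3)
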
diff --git a/backends/vulkan/test/test_vulkan_passes.py b/backends/vulkan/test/test_vulkan_passes.py index 4a30ab6c2de..438126a179f 100644 --- a/backends/vulkan/test/test_vulkan_passes.py +++ b/backends/vulkan/test/test_vulkan_passes.py @@ -3,15 +3,8 @@ import torch -from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform -from executorch.backends.vulkan._passes import FuseQuantizedOpsTransform from executorch.backends.vulkan._passes.fuse_patterns import FusePatternsPass -from executorch.backends.vulkan.quantizer.vulkan_quantizer import ( - get_symmetric_quantization_config, - VulkanQuantizer, -) - from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge from executorch.exir.backend.canonical_partitioners.config_partitioner import ( @@ -94,66 +87,6 @@ def op_node_count(graph_module: torch.fx.GraphModule, canonical_op_name: str) -> class TestVulkanPasses(unittest.TestCase): - def test_fuse_int8pack_mm(self): - K = 256 - N = 256 - model = SingleLinearModule(K, N) - sample_inputs = model.get_sample_inputs() - - quantizer = VulkanQuantizer() - quantizer.set_global( - get_symmetric_quantization_config(is_dynamic=False, weight_bits=8) - ) - - edge_manager = quantize_and_lower_module( - model, - sample_inputs, - quantizer, - ) - - ep = edge_manager._edge_programs["forward"] - edge_manager.transform( - [ - AddmmToLinearTransform(), - FuseQuantizedOpsTransform(ep), - ] - ) - - gm = ep.graph_module - - self.assertEqual(op_node_count(gm, "_weight_int8pack_mm.default"), 1) - self.assertEqual(op_node_count(gm, "dequantize_per_channel.default"), 0) - - def test_fuse_linear_qcs4w(self): - K = 256 - N = 256 - model = SingleLinearModule(K, N) - sample_inputs = model.get_sample_inputs() - - quantizer = VulkanQuantizer() - quantizer.set_global( - get_symmetric_quantization_config(is_dynamic=False, weight_bits=4) - ) - - edge_manager = quantize_and_lower_module( - model, - sample_inputs, - quantizer, - ) - - ep = edge_manager._edge_programs["forward"] - edge_manager.transform( - [ - AddmmToLinearTransform(), - FuseQuantizedOpsTransform(ep), - ] - ) - - gm = ep.graph_module - - self.assertEqual(op_node_count(gm, "linear_qcs4w.default"), 1) - self.assertEqual(op_node_count(gm, "dequantize_per_channel.default"), 0) - def test_fuse_rotary_emb(self): """Test conversion of rotary embedding pattern to et_vk.apply_rotary_emb custom op.""" @@ -238,7 +171,8 @@ def _reshape_for_broadcast(self, freqs_cis: torch.Tensor, x: torch.Tensor): # Apply the rotary embedding pass ep = edge_manager._edge_programs["forward"] - rotary_pass = FusePatternsPass(ep) + rotary_pass = FusePatternsPass() + rotary_pass._exported_program = ep result = rotary_pass.call(ep.graph_module) # Verify that the pass was successful diff --git a/backends/vulkan/test/tester.py b/backends/vulkan/test/tester.py index b2066a06ec0..0707c09158f 100644 --- a/backends/vulkan/test/tester.py +++ b/backends/vulkan/test/tester.py @@ -44,8 +44,9 @@ def __init__( class Partition(BaseStages.Partition): def __init__(self, partitioner: Optional[Partitioner] = None): + vk_compile_spec = {"skip_bool_tensors": True} super().__init__( - partitioner=partitioner or VulkanPartitioner(), + partitioner=partitioner or VulkanPartitioner(vk_compile_spec), ) @@ -55,6 +56,10 @@ def __init__( partitioners: Optional[List[Partitioner]] = None, edge_compile_config: Optional[EdgeCompileConfig] = None, ): + if partitioners is None: + vk_compile_spec = {"skip_bool_tensors": True} + partitioners = [VulkanPartitioner(vk_compile_spec)] + 
super().__init__( default_partitioner_cls=VulkanPartitioner, partitioners=partitioners, diff --git a/backends/vulkan/test/utils.py b/backends/vulkan/test/utils.py index 41c1d92bd00..ab0a15ce4cf 100644 --- a/backends/vulkan/test/utils.py +++ b/backends/vulkan/test/utils.py @@ -8,18 +8,14 @@ import logging from collections import OrderedDict from copy import deepcopy - from enum import auto, Enum from typing import Any, List, Optional, Tuple import executorch.backends.vulkan.utils as utils - import torch - from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner - from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, @@ -31,15 +27,165 @@ serialize_from_bundled_program_to_flatbuffer, ) from executorch.exir import ExecutorchProgramManager, to_edge_transform_and_lower + +from executorch.exir.backend.backend_api import _get_node_list_with_same_tag + +from executorch.exir.backend.partitioner import ( + DelegationSpec, + Partitioner, + PartitionResult, +) + +from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer + +from executorch.exir.lowered_backend_module import ( + create_exported_program_from_submodule, + create_submodule_from_nodes, +) from executorch.extension.pybindings.portable_lib import ( # @manual _load_for_executorch_from_buffer, ) from executorch.extension.pytree import tree_flatten from torch.export import export +from torch.export.exported_program import ExportedProgram +from torch.export.graph_signature import InputKind +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch.fx.passes.operator_support import OperatorSupportBase from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +class NodeFlagIsSetChecker(OperatorSupportBase): + """ + Check if a node is marked with a given field in node.meta["custom"] + """ + + def __init__(self, field: str) -> None: + super().__init__() + self.field = field + + def check_field(self, node: torch.fx.Node) -> bool: + if "custom" not in node.meta: + return False + + custom_map = node.meta["custom"] + if self.field not in custom_map: + return False + + return custom_map[self.field] + + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: + if node.op == "placeholder" or node.op == "output": + return False + + # Check if the node itself is tagged + if self.check_field(node): + return True + + # Check if any direct user of this node is tagged + for user in node.users: + if self.check_field(user): + return True + + return False + + +class FlagBasedPartitioner(Partitioner): + """ + Partitioner that partitions based on whether node.meta["custom"][field] is set to + True. 
+ """ + + def __init__(self, field: str) -> None: + super().__init__() + self.field = field + self.delegation_spec = DelegationSpec("custom_partition", []) + + def partition(self, exported_program: ExportedProgram) -> PartitionResult: + capability_partitioner = CapabilityBasedPartitioner( + exported_program.graph_module, + NodeFlagIsSetChecker(self.field), + allows_single_node_partition=True, + ) + partition_list = capability_partitioner.propose_partitions() + + partition_tags = {} + for partition in partition_list: + for node in partition.nodes: + tag = f"tag{partition.id}" + node.meta["delegation_tag"] = tag + partition_tags[tag] = self.delegation_spec + + tag_constant_data(exported_program) + tag_mutated_buffer(exported_program) + + return PartitionResult( + tagged_exported_program=exported_program, partition_tags=partition_tags + ) + + +def mark_node_range( + graph_module: torch.fx.GraphModule, + end_idx: int = (2**31 - 1), + start_idx: int = 0, + field: str = "_in_target_subgraph", +): + call_fn_count = 0 + for node in graph_module.graph.nodes: + if "custom" not in node.meta: + node.meta["custom"] = {} + + node.meta["custom"][field] = False + + if node.op != "call_function": + continue + + call_fn_count += 1 + if call_fn_count >= start_idx and call_fn_count < end_idx: + node.meta["custom"][field] = True + + +def extract_submodule_program( + tagged_graph_module: torch.fx.GraphModule, + owning_program: ExportedProgram, + field: str = "_in_target_subgraph", +) -> ExportedProgram: + tagged_graph_module_output_node = tagged_graph_module.graph.output_node() + + partitioner = FlagBasedPartitioner(field) + partition_result = partitioner.partition(owning_program) + + tag, delegation_spec = next(iter(partition_result.partition_tags.items())) + node_list = _get_node_list_with_same_tag(tagged_graph_module, tag, owning_program) + + replace_ctx = tagged_graph_module._set_replace_hook( + owning_program.graph_signature.get_replace_hook() + ) + with replace_ctx: + submodule, call_module_node = create_submodule_from_nodes( + tagged_graph_module, node_list, tag + ) + + submodule_output_node = submodule.graph.output_node() + # Copy the output node meta from the original output node, because + # create_submodule_from_nodes doesn't cover the meta field + submodule_output_node.meta = tagged_graph_module_output_node.meta + + ( + submodule_program, + _, + _, + ) = create_exported_program_from_submodule( + submodule, + owning_program, + tag, + call_module_node, + False, + ) + + return submodule_program + + class QuantizationMode(Enum): """Enum to describe how a model should be quantized.""" @@ -50,11 +196,16 @@ class QuantizationMode(Enum): def get_exported_graph( model, sample_inputs, + sample_kwargs=None, dynamic_shapes=None, qmode=QuantizationMode.NONE, ) -> torch.fx.GraphModule: export_training_graph = export( - model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True + model, + sample_inputs, + kwargs=sample_kwargs, + dynamic_shapes=dynamic_shapes, + strict=True, ).module() if qmode == QuantizationMode.NONE: @@ -76,12 +227,101 @@ def random_uniform_tensor(shape, low=0.0, high=1.0, device=None, dtype=None): if dtype is None: dtype = torch.float32 + # Handle integer types using randint + if dtype in ( + torch.int, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.long, + torch.short, + ): + low_int = int(low) + high_int = int(high) + # randint requires high > low, so ensure at least a range of 1 + if high_int <= low_int: + high_int = low_int + 1 + return torch.randint(low_int, 
high_int, shape, device=device, dtype=dtype) + + # Handle unsigned integer types + if dtype in (torch.uint8,): + low_int = max(0, int(low)) + high_int = int(high) + if high_int <= low_int: + high_int = low_int + 1 + return torch.randint(low_int, high_int, shape, device=device, dtype=dtype) + + # Handle boolean type + if dtype == torch.bool: + return torch.randint(0, 2, shape, device=device, dtype=torch.int8).bool() + + # Handle floating-point types (float16, float32, float64, bfloat16) return torch.empty(shape, device=device, dtype=dtype).uniform_(low, high) +def generate_sample_inputs( + exported_program: ExportedProgram, + low: float = -1.0, + high: float = 1.0, +) -> Tuple[torch.Tensor, ...]: + """ + Analyze the exported program graph to determine input shapes and dtypes, + then generate random sample inputs. + + Uses the graph signature to identify only user inputs (excluding parameters, + buffers, and other non-input placeholders). + + Args: + exported_program: The exported program to analyze + low: Lower bound for random uniform values (default: -1.0) + high: Upper bound for random uniform values (default: 1.0) + + Returns: + Tuple of randomly generated tensors matching the input specs + """ + sample_inputs = [] + + # Get the set of user input names by filtering input_specs for USER_INPUT kind + user_input_names = set() + for spec in exported_program.graph_signature.input_specs: + if spec.kind == InputKind.USER_INPUT: + if hasattr(spec.arg, "name"): + user_input_names.add(spec.arg.name) + + for node in exported_program.graph.nodes: + if node.op != "placeholder": + continue + + # Only process nodes that are user inputs (not parameters, buffers, etc.) + if node.name not in user_input_names: + continue + + if "val" in node.meta: + val = node.meta["val"] + shape = None + dtype = None + + if isinstance(val, torch.Tensor): + shape = tuple(val.shape) + dtype = val.dtype + elif hasattr(val, "shape") and hasattr(val, "dtype"): + # Handle FakeTensor or similar + shape = tuple(val.shape) + dtype = val.dtype + + if shape is not None and dtype is not None: + tensor = random_uniform_tensor(shape, low=low, high=high, dtype=dtype) + sample_inputs.append(tensor) + + inputs_flattened, _ = tree_flatten(sample_inputs) + return inputs_flattened + + def export_model_to_vulkan( model, sample_inputs, + sample_kwargs=None, dynamic_shapes=None, operator_blocklist=None, operator_allowlist=None, @@ -90,10 +330,17 @@ def export_model_to_vulkan( qmode=QuantizationMode.NONE, ): compile_options = {} - exported_graph = get_exported_graph(model, sample_inputs, qmode=qmode) + exported_graph = get_exported_graph( + model, + sample_inputs, + sample_kwargs=sample_kwargs, + dynamic_shapes=dynamic_shapes, + qmode=qmode, + ) program = export( exported_graph, sample_inputs, + kwargs=sample_kwargs, dynamic_shapes=dynamic_shapes, strict=True, ) @@ -262,16 +509,25 @@ def check_outputs_equal( ) return result else: + result = True for i in range(len(ref_output)): - if not torch.allclose( - model_output[i], ref_output[i], atol=atol, rtol=rtol - ): - print(f"\n=== Output {i} comparison failed ===") - print_tensor_comparison_errors( - model_output[i], ref_output[i], atol, rtol - ) - return False - return True + if isinstance(ref_output[i], torch.Tensor): + if not torch.allclose( + model_output[i], ref_output[i], atol=atol, rtol=rtol + ): + print(f"\n=== Output {i} comparison failed ===") + print_tensor_comparison_errors( + model_output[i], ref_output[i], atol, rtol + ) + result = False + elif isinstance(ref_output[i], int): + if not 
model_output[i] == ref_output[i]: + print(f"\n=== Output {i} comparison failed ===") + print(f"{model_output[i]} vs {ref_output[[i]]}") + result = False + else: + print(f"WARNING: Output {i} has type {type(ref_output[i])}") + return result else: # If one output, eager returns tensor while executor tuple of size 1 result = torch.allclose(model_output[0], ref_output, atol=atol, rtol=rtol) @@ -303,17 +559,17 @@ def run_and_check_output( Returns: bool: True if outputs match within tolerance, False otherwise """ - # Load the ExecutorTorch program + # Load the ExecuTorch program executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer) # Flatten inputs for execution inputs_flattened, _ = tree_flatten(sample_inputs) - # Run the ExecutorTorch program + # Run the ExecuTorch program model_output = executorch_module.run_method("forward", tuple(inputs_flattened)) # Generate reference outputs using the reference model - ref_output = reference_model(*sample_inputs) + ref_output, _ = tree_flatten(reference_model(*sample_inputs)) # Check if outputs are equal return check_outputs_equal( @@ -415,11 +671,55 @@ def lower_module_and_test_output( return True +def create_bundled_program( + executorch_program: ExecutorchProgramManager, + sample_inputs: Tuple[torch.Tensor, ...], + expected_outputs: List[Any], + method_name: str = "forward", +) -> bytes: + """ + Create a bundled program containing the model and test cases for correctness testing. + + Args: + executorch_program: The ExecutorchProgramManager to bundle + sample_inputs: Sample inputs for the model + expected_outputs: Expected outputs from running the model with sample_inputs + method_name: Name of the method to test (default: "forward") + + Returns: + Serialized bundled program as bytes + """ + # Flatten sample inputs to match expected format + inputs_flattened, _ = tree_flatten(sample_inputs) + + # Create test suite with the sample inputs and expected outputs + test_suites = [ + MethodTestSuite( + method_name=method_name, + test_cases=[ + MethodTestCase( + inputs=inputs_flattened, + expected_outputs=expected_outputs, + ) + ], + ) + ] + + # Create bundled program + bundled_program = BundledProgram(executorch_program, test_suites) + + # Serialize to flatbuffer + bundled_buffer = serialize_from_bundled_program_to_flatbuffer(bundled_program) + + return bundled_buffer + + def save_bundled_program( model: torch.nn.Module, sample_inputs: Tuple[torch.Tensor], output_path: str, method_name: str = "forward", + sample_kwargs=None, et_program: Optional[ExecutorchProgramManager] = None, dynamic_shapes=None, ) -> str: @@ -439,32 +739,29 @@ def save_bundled_program( """ # If no ExecutorchProgramManager provided, export to Vulkan if et_program is None: - et_program = export_model_to_vulkan(model, sample_inputs, dynamic_shapes) + et_program = export_model_to_vulkan( + model, + sample_inputs, + sample_kwargs=sample_kwargs, + dynamic_shapes=dynamic_shapes, + ) + + if sample_kwargs is None: + sample_kwargs = {} # Generate expected outputs by running the model - expected_outputs = [getattr(model, method_name)(*sample_inputs)] + expected_outputs = [getattr(model, method_name)(*sample_inputs, **sample_kwargs)] - # Flatten sample inputs to match expected format - inputs_flattened, _ = tree_flatten(sample_inputs) - - # Create test suite with the sample inputs and expected outputs - test_suites = [ - MethodTestSuite( - method_name=method_name, - test_cases=[ - MethodTestCase( - inputs=inputs_flattened, - expected_outputs=expected_outputs, - ) - ], - ) - ] 
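Note: the helpers added above in backends/vulkan/test/utils.py (export_model_to_vulkan and create_bundled_program) compose into a small end-to-end flow for producing a .bpte bundle. A sketch under the assumption that these helpers are importable from executorch.backends.vulkan.test.utils; MyModule and the input shape are placeholders:

    import torch
    from executorch.backends.vulkan.test.utils import (
        create_bundled_program,
        export_model_to_vulkan,
    )

    model = MyModule().eval()              # hypothetical eager module
    sample_inputs = (torch.randn(1, 16),)  # hypothetical input shape

    # Lower to the Vulkan delegate, then bundle the program with a reference test case.
    et_program = export_model_to_vulkan(model, sample_inputs)
    expected_outputs = [model(*sample_inputs)]
    bp_buffer = create_bundled_program(et_program, sample_inputs, expected_outputs)

    with open("my_module.bpte", "wb") as f:  # same .bpte extension used by save_bundled_program
        f.write(bp_buffer)

save_bundled_program above wraps the same steps and additionally flattens any keyword arguments into the recorded inputs.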
+ # Flatten sample inputs with kwargs to match expected format + inputs_flattened, _ = tree_flatten((sample_inputs, sample_kwargs)) # Create bundled program - bp = BundledProgram(et_program, test_suites) - - # Serialize to flatbuffer - bp_buffer = serialize_from_bundled_program_to_flatbuffer(bp) + bp_buffer = create_bundled_program( + et_program, + tuple(inputs_flattened), + expected_outputs, + method_name, + ) # Ensure output path has correct extension if not output_path.endswith(".bpte"): @@ -783,3 +1080,26 @@ def find_bad_operators( "all_operators": all_operators, "test_count": test_count, } + + +def make_indent(indent_level): + indent_str = "" + for _ in range(indent_level): + indent_str += " " + return indent_str + + +def print_output(outputs, n: int = 0, indent_level: int = 0): + if isinstance(outputs, (list, tuple)): + print(f"{make_indent(indent_level)}output_{n} = {type(outputs)}") + new_indent_level = indent_level + 2 + for n, test_out in enumerate(outputs): + print_output(test_out, n, new_indent_level) + elif isinstance(outputs, torch.Tensor): + print( + f"{make_indent(indent_level)}output_{n} = test_utils.random_uniform_tensor({outputs.shape}, low={outputs.min().item()}, high={outputs.max().item()}, dtype={outputs.dtype})" + ) + elif isinstance(outputs, int): + print(f"{make_indent(indent_level)}output_{n} = {outputs}") + else: + print(f"{make_indent(indent_level)}output_{n} = {type(outputs)}") diff --git a/backends/vulkan/test/utils/test_utils.cpp b/backends/vulkan/test/utils/test_utils.cpp index 07d28229221..038a838484d 100644 --- a/backends/vulkan/test/utils/test_utils.cpp +++ b/backends/vulkan/test/utils/test_utils.cpp @@ -44,6 +44,7 @@ vkapi::ShaderInfo get_nchw_to_tensor_shader( if (v_dst.storage_type() == utils::kBuffer) { kernel_name = "nchw_to_buffer"; add_dtype_suffix(kernel_name, v_dst.dtype()); + add_dtype_suffix(kernel_name, v_dst.dtype()); return VK_KERNEL_FROM_STR(kernel_name); } @@ -53,6 +54,7 @@ vkapi::ShaderInfo get_nchw_to_tensor_shader( } add_storage_type_suffix(kernel_name, v_dst.storage_type()); add_dtype_suffix(kernel_name, v_dst.dtype()); + add_dtype_suffix(kernel_name, v_dst.dtype()); return VK_KERNEL_FROM_STR(kernel_name); } @@ -78,6 +80,7 @@ vkapi::ShaderInfo get_tensor_to_nchw_shader( if (v_src.storage_type() == utils::kBuffer) { kernel_name = "buffer_to_nchw"; add_dtype_suffix(kernel_name, v_src.dtype()); + add_dtype_suffix(kernel_name, v_src.dtype()); return VK_KERNEL_FROM_STR(kernel_name); } @@ -87,6 +90,7 @@ vkapi::ShaderInfo get_tensor_to_nchw_shader( } add_storage_type_suffix(kernel_name, v_src.storage_type()); add_dtype_suffix(kernel_name, v_src.dtype()); + add_dtype_suffix(kernel_name, v_src.dtype()); return VK_KERNEL_FROM_STR(kernel_name); } @@ -395,7 +399,11 @@ void record_matmul_texture3d( _(int8_t, QInt8) void fill_vtensor(api::vTensor& vten, std::vector& data) { - api::StagingBuffer staging_buffer(api::context(), vten.dtype(), data.size()); + api::StagingBuffer staging_buffer( + api::context(), + vten.dtype(), + data.size(), + vkapi::CopyDirection::HOST_TO_DEVICE); #define CASE(ctype, name) \ case vkapi::ScalarType::name: { \ @@ -482,7 +490,10 @@ void fill_vtensor( void extract_vtensor(api::vTensor& vten, std::vector& data) { api::StagingBuffer staging_buffer( - api::context(), vten.dtype(), vten.staging_buffer_numel()); + api::context(), + vten.dtype(), + vten.staging_buffer_numel(), + vkapi::CopyDirection::DEVICE_TO_HOST); if (vten.storage_type() == utils::StorageType::BUFFER) { record_buffer_to_nchw_op(api::context(), vten, 
staging_buffer.buffer()); diff --git a/backends/vulkan/test/utils/test_utils.h b/backends/vulkan/test/utils/test_utils.h index 1fd40b6f815..2af445bc800 100644 --- a/backends/vulkan/test/utils/test_utils.h +++ b/backends/vulkan/test/utils/test_utils.h @@ -42,7 +42,8 @@ vkcompute::api::StagingBuffer staging_buffer_##tensor( \ vkcompute::api::context(), \ vkapi::kFloat, \ - tensor.staging_buffer_numel()); \ + tensor.staging_buffer_numel(), \ + vkapi::CopyDirection::HOST_TO_DEVICE); \ record_nchw_to_image_op( \ vkcompute::api::context(), staging_buffer_##tensor.buffer(), tensor); @@ -50,7 +51,8 @@ vkcompute::api::StagingBuffer staging_buffer_##tensor( \ vkcompute::api::context(), \ vkapi::kFloat, \ - tensor.staging_buffer_numel()); \ + tensor.staging_buffer_numel(), \ + vkapi::CopyDirection::DEVICE_TO_HOST); \ record_image_to_nchw_op( \ vkcompute::api::context(), tensor, staging_buffer_##tensor.buffer()); diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index a193d02da88..15d5cf9254e 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -102,14 +102,15 @@ TEST_F(VulkanComputeAPITest, print_adapter) { std::cout << *(context()->adapter_ptr()) << std::endl; } -#if defined(VULKAN_DEBUG) && defined(VK_KHR_pipeline_executable_properties) +#if defined(VK_KHR_pipeline_executable_properties) && \ + defined(ETVK_INSPECT_PIPELINES) TEST_F(VulkanComputeAPITest, print_shader_executable_properties) { context()->print_shader_executable_properties( VK_KERNEL(binary_add_nobroadcast__test_half), {0}); } -#endif // VULKAN_DEBUG && VK_KHR_pipeline_executable_properties +#endif // VK_KHR_pipeline_executable_properties && ETVK_INSPECT_PIPELINES std::vector get_reference_strides( const std::vector& sizes, @@ -187,6 +188,8 @@ std::vector get_reference_strides( default: return {}; } + default: + VK_THROW("Unsupported memory layout: ", layout); } return {}; } @@ -527,7 +530,8 @@ TEST_F(VulkanComputeAPITest, spec_var_classes_test) { TEST_F(VulkanComputeAPITest, spec_var_shader_test) { size_t len = 16; - StagingBuffer buffer(context(), vkapi::kFloat, len); + StagingBuffer buffer( + context(), vkapi::kFloat, len, vkapi::CopyDirection::DEVICE_TO_HOST); float scale = 3.0f; float offset = 1.5f; @@ -599,7 +603,10 @@ TEST_F(VulkanComputeAPITest, update_params_between_submit) { } StagingBuffer staging_buffer( - context(), vkapi::kFloat, a.staging_buffer_numel()); + context(), + vkapi::kFloat, + a.staging_buffer_numel(), + vkapi::CopyDirection::DEVICE_TO_HOST); record_image_to_nchw_op(context(), a, staging_buffer.buffer()); submit_to_gpu(); @@ -619,7 +626,8 @@ TEST_F(VulkanComputeAPITest, update_params_between_submit) { template void test_storage_buffer_type(const size_t len) { - StagingBuffer buffer(context(), dtype, len); + StagingBuffer buffer( + context(), dtype, len, vkapi::CopyDirection::DEVICE_TO_HOST); std::string kernel_name("idx_fill_buffer"); switch (dtype) { @@ -1908,413 +1916,6 @@ TEST(VulkanComputeGraphTest, test_clone) { } } -TEST(VulkanComputeGraphTest, test_etvk_copy_offset_node) { - GraphConfig config; - ComputeGraph graph(config); - - int64_t n = 6; - int64_t c = 12; - int64_t h = 4; - int64_t w = 8; - utils::GPUMemoryLayout memory_layout = - utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED; - - std::vector size = {n, c, h, w}; - - IOValueRef a = graph.add_input_tensor(size, vkapi::kFloat, memory_layout); - - IOValueRef out = {}; - out.value = graph.add_tensor(size, vkapi::kFloat, 
memory_layout); - - // Notice that copy_node operates on in texture's x, y, z dimension. In the - // comment, we provide the cooresponding coordinate in nchw. - - // src_offset is (n=0, c=4, h=1, w=1) - ValueRef src_offset_ref = graph.add_scalar_list({1, 1, 1}); - - // dst_offset is (n=1, c=8, h=2, w=0) in nchw coordinate - // Argument is {x, y, z}. - // x = 0 since w = 0 - // y = 2 since h = 2 - // z = c / 4 + 2 since - // 1. there c/4 planes per batch, n=1 means we are on the first batch; - // 2. +2 because c = 8, with channel packing it means two texels. - ValueRef dst_offset_ref = graph.add_scalar_list({0, 2, c / 4 + 2}); - - // range is (n=1, c=8, h=2, w=4) - // Argument is {x, y, z}. - // x = 4 since w = 4 - // y = 2 since h = 2 - // z = 2 since we are only copying 8 channels, hence 2 texel. n = 1 can be a - // bit misleading here, since it gives the impression that we are copying the - // entire channel. However, remember when we copy, we are trying to - // dst[dst_offset:dst_offset + range] = src[src_offset:src_offset + range], - // range must be non zero. - ValueRef range_ref = graph.add_scalar_list({4, 2, 2}); - - auto copyFn = VK_GET_OP_FN("etvk.copy_offset"); - copyFn( - graph, {a.value, range_ref, src_offset_ref, dst_offset_ref, out.value}); - - out.staging = graph.set_output_tensor(out.value); - - graph.prepare(); - graph.prepack(); - - fill_vtensor(graph, a, 0.0f, /*iota = */ true); - - graph.execute(); - - EXTRACT_TENSOR(out); - EXTRACT_TENSOR(a); - - // We will examine the results in the dst_range - // The value in the cooresponding coordinate should match between the source - // and destination tensor. We loop thru the range, calculate both the src and - // dst index using the offsets, and compare the values in the extracted - // vector. They should match. 
- int n_idx = 0; - // at each nested loop, index range from dst_offset to dst_offset + range - - for (int c_idx = 0; c_idx < 8; c_idx++) { - for (int h_idx = 0; h_idx < 2; h_idx++) { - for (int w_idx = 0; w_idx < 4; w_idx++) { - auto dst_idx = - get_buf_idx(graph, out, {n_idx + 1, c_idx + 8, h_idx + 2, w_idx}); - auto src_idx = - get_buf_idx(graph, a, {n_idx, c_idx + 4, h_idx + 1, w_idx + 1}); - - EXPECT_TRUE(data_out[dst_idx] == data_a[src_idx]); - } - } - } -} - -TEST(VulkanComputeGraphTest, DISABLED_test_etvk_copy_channel_offset_node) { - GraphConfig config; - ComputeGraph graph(config); - - int64_t n = 2; - int64_t c = 12; - int64_t h = 4; - int64_t w = 8; - utils::GPUMemoryLayout memory_layout = - utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED; - - std::vector size = {n, c, h, w}; - - IOValueRef a = graph.add_input_tensor(size, vkapi::kFloat, memory_layout); - - IOValueRef out = {}; - out.value = graph.add_tensor(size, vkapi::kFloat, memory_layout); - - int64_t src_offset = 2; - int64_t dst_offset = 3; - int64_t range = 7; - - ValueRef src_offset_ref = graph.add_scalar(src_offset); - ValueRef dst_offset_ref = graph.add_scalar(dst_offset); - ValueRef range_ref = graph.add_scalar(range); - - auto copyFn = VK_GET_OP_FN("etvk.copy_channel_offset"); - copyFn( - graph, {a.value, range_ref, src_offset_ref, dst_offset_ref, out.value}); - - out.staging = graph.set_output_tensor(out.value); - - graph.prepare(); - graph.prepack(); - - fill_vtensor(graph, a, 0.0f, true); - - graph.execute(); - - EXTRACT_TENSOR(out); - EXTRACT_TENSOR(a); - - for (int n_idx = 0; n_idx < n; n_idx++) { - for (int c_idx = 0; c_idx < range; c_idx++) { - for (int h_idx = 0; h_idx < h; h_idx++) { - for (int w_idx = 0; w_idx < w; w_idx++) { - auto src_idx = - get_buf_idx(graph, a, {n_idx, c_idx + src_offset, h_idx, w_idx}); - auto dst_idx = get_buf_idx( - graph, out, {n_idx, c_idx + dst_offset, h_idx, w_idx}); - EXPECT_TRUE(data_out[dst_idx] == data_a[src_idx]); - } - } - } - } -} - -TEST( - VulkanComputeGraphTest, - DISABLED_test_etvk_copy_channel_offset_node_clean_boundary) { - // Tricky part for channel copy is handling the boundary across multiple copy. - // For example, when we concat two [3, 1, 1] nchw-tensors along the channel - // dimension, due to channel packing, elements from different source texel - // will be packed into same destination texel at the boundaries. - GraphConfig config; - ComputeGraph graph(config); - - int64_t n = 2; - int64_t c = 12; - int64_t h = 4; - int64_t w = 8; - utils::GPUMemoryLayout memory_layout = - utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED; - - std::vector size = {n, c, h, w}; - - IOValueRef zero = graph.add_input_tensor(size, vkapi::kFloat, memory_layout); - IOValueRef a = graph.add_input_tensor(size, vkapi::kFloat, memory_layout); - IOValueRef b = graph.add_input_tensor(size, vkapi::kFloat, memory_layout); - - IOValueRef out = {}; - out.value = graph.add_tensor(size, vkapi::kFloat, memory_layout); - - auto copyFn = VK_GET_OP_FN("etvk.copy_channel_offset"); - - // Make sure entire out tensor is zeroed. The zero tensor will be filled with - // zero later. 
- copyFn( - graph, - {zero.value, - graph.add_scalar(c), - graph.add_scalar(0), - graph.add_scalar(0), - out.value}); - - int64_t a_src_offset = 0; - int64_t a_dst_offset = 2; - int64_t a_range = 5; - // a will write to channge [2, 7) - copyFn( - graph, - {a.value, - graph.add_scalar(a_range), - graph.add_scalar(a_src_offset), - graph.add_scalar(a_dst_offset), - out.value}); - - // b will write to channel [6, 11) - // Intentional for b to override channel=6 - int64_t b_src_offset = 0; - int64_t b_dst_offset = 6; - int64_t b_range = 5; - - copyFn( - graph, - {b.value, - graph.add_scalar(b_range), - graph.add_scalar(b_src_offset), - graph.add_scalar(b_dst_offset), - out.value}); - - out.staging = graph.set_output_tensor(out.value); - - graph.prepare(); - graph.prepack(); - - float a_value = 1.0f; - float b_value = 2.0f; - float zero_value = 0.0f; - fill_vtensor(graph, a, a_value); - fill_vtensor(graph, b, b_value); - fill_vtensor(graph, zero, zero_value); - - graph.execute(); - - EXTRACT_TENSOR(out); - - for (int n_idx = 0; n_idx < n; n_idx++) { - // c_idx only up to a_range-1 because the expected overwrite by b - for (int c_idx = a_dst_offset; c_idx < a_dst_offset + a_range - 1; - c_idx++) { - for (int h_idx = 0; h_idx < h; h_idx++) { - for (int w_idx = 0; w_idx < w; w_idx++) { - auto dst_idx = get_buf_idx(graph, out, {n_idx, c_idx, h_idx, w_idx}); - EXPECT_TRUE(data_out[dst_idx] == a_value); - } - } - } - } - - for (int n_idx = 0; n_idx < n; n_idx++) { - for (int c_idx = b_dst_offset; c_idx < b_dst_offset + b_range; c_idx++) { - for (int h_idx = 0; h_idx < h; h_idx++) { - for (int w_idx = 0; w_idx < w; w_idx++) { - auto dst_idx = get_buf_idx(graph, out, {n_idx, c_idx, h_idx, w_idx}); - EXPECT_TRUE(data_out[dst_idx] == b_value); - } - } - } - } - - // Also verify that data before a_dst_offset and after b_dst_offset + b_range - // are untouched. - for (int n_idx = 0; n_idx < n; n_idx++) { - for (int c_idx = 0; c_idx < a_dst_offset; c_idx++) { - for (int h_idx = 0; h_idx < h; h_idx++) { - for (int w_idx = 0; w_idx < w; w_idx++) { - auto dst_idx = get_buf_idx(graph, out, {n_idx, c_idx, h_idx, w_idx}); - EXPECT_TRUE(data_out[dst_idx] == zero_value); - } - } - } - } - - for (int n_idx = 0; n_idx < n; n_idx++) { - for (int c_idx = b_dst_offset + b_range; c_idx < c; c_idx++) { - for (int h_idx = 0; h_idx < h; h_idx++) { - for (int w_idx = 0; w_idx < w; w_idx++) { - auto dst_idx = get_buf_idx(graph, out, {n_idx, c_idx, h_idx, w_idx}); - EXPECT_TRUE(data_out[dst_idx] == zero_value); - } - } - } - } -} - -TEST(VulkanComputeGraphTest, test_etvk_copy_offset_int_node) { - GraphConfig config; - ComputeGraph graph(config); - - int64_t n = 6; - int64_t c = 12; - int64_t h = 4; - int64_t w = 8; - utils::GPUMemoryLayout memory_layout = - utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED; - - std::vector size = {n, c, h, w}; - - IOValueRef a = graph.add_input_tensor(size, vkapi::kInt, memory_layout); - - IOValueRef out = {}; - out.value = graph.add_tensor(size, vkapi::kInt, memory_layout); - - // Notice that copy_node operates on in texture's x, y, z dimension. In the - // comment, we provide the cooresponding coordinate in nchw. - - // src_offset is (n=0, c=4, h=1, w=1) - ValueRef src_offset_ref = graph.add_scalar_list({1, 1, 1}); - - // dst_offset is (n=1, c=8, h=2, w=0) in nchw coordinate - // Argument is {x, y, z}. - // x = 0 since w = 0 - // y = 2 since h = 2 - // z = c / 4 + 2 since - // 1. there c/4 planes per batch, n=1 means we are on the first batch; - // 2. 
+2 because c = 8, with channel packing it means two texels. - ValueRef dst_offset_ref = graph.add_scalar_list({0, 2, c / 4 + 2}); - - // range is (n=1, c=8, h=2, w=4) - // Argument is {x, y, z}. - // x = 4 since w = 4 - // y = 2 since h = 2 - // z = 2 since we are only copying 8 channels, hence 2 texel. n = 1 can be a - // bit misleading here, since it gives the impression that we are copying the - // entire channel. However, remember when we copy, we are trying to - // dst[dst_offset:dst_offset + range] = src[src_offset:src_offset + range], - // range must be non zero. - ValueRef range_ref = graph.add_scalar_list({4, 2, 2}); - - auto copyFn = VK_GET_OP_FN("etvk.copy_offset"); - copyFn( - graph, {a.value, range_ref, src_offset_ref, dst_offset_ref, out.value}); - - out.staging = graph.set_output_tensor(out.value); - - graph.prepare(); - graph.prepack(); - - fill_vtensor(graph, a, 0, /*iota = */ true); - - graph.execute(); - - EXTRACT_TENSOR(out); - EXTRACT_TENSOR(a); - - // We will examine the results in the dst_range - // The value in the cooresponding coordinate should match between the source - // and destination tensor. We loop thru the range, calculate both the src and - // dst index using the offsets, and compare the values in the extracted - // vector. They should match. - int n_idx = 0; - // at each nested loop, index range from dst_offset to dst_offset + range - - for (int c_idx = 0; c_idx < 8; c_idx++) { - for (int h_idx = 0; h_idx < 2; h_idx++) { - for (int w_idx = 0; w_idx < 4; w_idx++) { - auto dst_idx = - get_buf_idx(graph, out, {n_idx + 1, c_idx + 8, h_idx + 2, w_idx}); - auto src_idx = - get_buf_idx(graph, a, {n_idx, c_idx + 4, h_idx + 1, w_idx + 1}); - - EXPECT_TRUE(data_out[dst_idx] == data_a[src_idx]); - } - } - } -} - -TEST(VulkanComputeGraphTest, DISABLED_test_etvk_copy_channel_offset_int_node) { - GraphConfig config; - ComputeGraph graph(config); - - int64_t n = 2; - int64_t c = 12; - int64_t h = 4; - int64_t w = 8; - utils::GPUMemoryLayout memory_layout = - utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED; - - std::vector size = {n, c, h, w}; - - IOValueRef a = graph.add_input_tensor(size, vkapi::kFloat, memory_layout); - - IOValueRef out = {}; - out.value = graph.add_tensor(size, vkapi::kFloat, memory_layout); - - int64_t src_offset = 2; - int64_t dst_offset = 3; - int64_t range = 7; - - ValueRef src_offset_ref = graph.add_scalar(src_offset); - ValueRef dst_offset_ref = graph.add_scalar(dst_offset); - ValueRef range_ref = graph.add_scalar(range); - - auto copyFn = VK_GET_OP_FN("etvk.copy_channel_offset"); - copyFn( - graph, {a.value, range_ref, src_offset_ref, dst_offset_ref, out.value}); - - out.staging = graph.set_output_tensor(out.value); - - graph.prepare(); - graph.prepack(); - - fill_vtensor(graph, a, 0.0f, true); - - graph.execute(); - - EXTRACT_TENSOR(out); - EXTRACT_TENSOR(a); - - for (int n_idx = 0; n_idx < n; n_idx++) { - for (int c_idx = 0; c_idx < range; c_idx++) { - for (int h_idx = 0; h_idx < h; h_idx++) { - for (int w_idx = 0; w_idx < w; w_idx++) { - auto src_idx = - get_buf_idx(graph, a, {n_idx, c_idx + src_offset, h_idx, w_idx}); - auto dst_idx = get_buf_idx( - graph, out, {n_idx, c_idx + dst_offset, h_idx, w_idx}); - EXPECT_TRUE(data_out[dst_idx] == data_a[src_idx]); - } - } - } - } -} - TEST(VulkanComputeGraphTest, test_view_change_packing) { std::vector> layout_pairs = { @@ -2417,7 +2018,11 @@ void run_from_gpu_test( vten.sizes_ubo()); } - StagingBuffer staging_buffer(context(), dtype, vten.staging_buffer_numel()); + StagingBuffer staging_buffer( + 
context(), + dtype, + vten.staging_buffer_numel(), + vkapi::CopyDirection::DEVICE_TO_HOST); if (dtype == vkapi::kChar && !context()->adapter_ptr()->has_full_int8_buffers_support()) { @@ -2453,7 +2058,10 @@ void round_trip_test( // Create and fill input staging buffer StagingBuffer staging_buffer_in( - context(), dtype, vten.staging_buffer_numel()); + context(), + dtype, + vten.staging_buffer_numel(), + vkapi::CopyDirection::HOST_TO_DEVICE); std::vector data_in(staging_buffer_in.numel()); for (int i = 0; i < staging_buffer_in.numel(); i++) { @@ -2463,7 +2071,10 @@ void round_trip_test( // Output staging buffer StagingBuffer staging_buffer_out( - context(), dtype, vten.staging_buffer_numel()); + context(), + dtype, + vten.staging_buffer_numel(), + vkapi::CopyDirection::DEVICE_TO_HOST); record_nchw_to_image_op(context(), staging_buffer_in.buffer(), vten); diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index 96f200eecbc..2ca2ddf19b7 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -8,26 +8,18 @@ from typing import Any, List, Optional, Set, Tuple, Union import torch - from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( VkMemoryLayout, VkStorageType, ) - from executorch.exir.backend.canonical_partitioners.config_partitioner import ( format_target_name, ) - from executorch.exir.dialects.edge._ops import EdgeOpOverload - from executorch.exir.tensor import TensorSpec - from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param - from torch._subclasses.fake_tensor import FakeTensor, FakeTensorConverter - from torch.export import ExportedProgram - from torch.export.exported_program import InputKind from torch.export.graph_signature import TensorArgument @@ -128,7 +120,7 @@ def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool: is_get_attr_node(node) or is_param(program, node) or is_buffer(program, node) - or is_constant(program, node) + or is_lifted_tensor_constant(program, node) ) @@ -206,6 +198,8 @@ def is_tensor_arg_node(node: Any) -> bool: if isinstance(node, torch.fx.Node): return is_tensor_node(node) elif isinstance(node, (list, tuple)): + if len(node) == 0: + return False return all(is_tensor_node(n) for n in node) return False @@ -257,6 +251,47 @@ def tensor_node_is_bool(node: torch.fx.Node) -> bool: return False +def ndim_of(node: Any) -> Optional[int]: + """ + Returns the number of dimensions of the tensor produced by the given node + """ + if not is_single_tensor_node(node): + return None + + return node.meta["val"].ndim + + +def is_unsqueezed_vector(node: torch.fx.Node) -> bool: + """ + Returns True if the node's tensor has all dimensions equal to 1 except for the last dimension. 
+ """ + if not is_single_tensor_node(node): + return False + + tensor = node.meta["val"] + assert isinstance(tensor, FakeTensor) + + if len(tensor.shape) < 1: + return False + # All dims except last are 1, last can be any size + return all(dim == 1 for dim in tensor.shape[:-1]) + + +def op_contains_bool_tensor(node: torch.fx.Node) -> bool: + """ + Returns true if the operator used to compute the given node contains a bool tensor + """ + if is_tensor_node(node) and tensor_node_is_bool(node): + return True + + for arg_node in node.args: + # pyre-ignore[6] + if is_tensor_node(arg_node) and tensor_node_is_bool(arg_node): + return True + + return False + + def get_primary_arg_idx(self, node: torch.fx.Node) -> Optional[int]: primary_arg_idx: Optional[int] = None for i, arg_node in enumerate(node.args): @@ -330,6 +365,18 @@ def find_quant_user(node: torch.fx.Node) -> Optional[torch.fx.Node]: return None +def node_has_target(node: Any, target: str): + if not hasattr(node, "target"): + return False + + if isinstance(node.target, str): + return node.target == target + elif hasattr(node.target, "name"): + return node.target.name() == target + + return False + + ## ## Memory Layout, Storage Type Determination ## @@ -344,12 +391,27 @@ def find_quant_user(node: torch.fx.Node) -> Optional[torch.fx.Node]: VkStorageType.TEXTURE_3D, } +# Memory layouts available to non-quantized tensors all_memory_layouts: Set[VkMemoryLayout] = { VkMemoryLayout.TENSOR_WIDTH_PACKED, VkMemoryLayout.TENSOR_HEIGHT_PACKED, VkMemoryLayout.TENSOR_CHANNELS_PACKED, } +# Memory layouts available to quantized tensors +all_quantized_memory_layouts: Set[VkMemoryLayout] = { + VkMemoryLayout.PACKED_INT8_4W4C, + VkMemoryLayout.PACKED_INT8_4H4W, +} + +universal_memory_layout_set: Set[VkMemoryLayout] = { + VkMemoryLayout.TENSOR_WIDTH_PACKED, + VkMemoryLayout.TENSOR_HEIGHT_PACKED, + VkMemoryLayout.TENSOR_CHANNELS_PACKED, + VkMemoryLayout.PACKED_INT8_4W4C, + VkMemoryLayout.PACKED_INT8_4H4W, +} + MemoryLayoutSet = Set[VkMemoryLayout] MemoryLayoutSetList = Union[MemoryLayoutSet, List[MemoryLayoutSet]] @@ -400,6 +462,12 @@ def required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageEx height = (height + 3) // 4 elif layout == VkMemoryLayout.TENSOR_CHANNELS_PACKED: channels = (channels + 3) // 4 + elif layout == VkMemoryLayout.PACKED_INT8_4W4C: + width = (width + 3) // 4 + channels = (channels + 3) // 4 + elif layout == VkMemoryLayout.PACKED_INT8_4H4W: + height = (height + 3) // 4 + width = (width + 3) // 4 else: raise RuntimeError(f"Unsupported memory layout {layout}") @@ -558,6 +626,16 @@ def make_intersect(self, other: "TensorRepSet") -> "TensorRepSet": self.valid_texture_layouts & other.valid_texture_layouts, ) + def make_union(self, other: "TensorRepSet") -> "TensorRepSet": + """ + Merge this TensorRepSet with another TensorRepSet, returning a new TensorRepSet + with the union of the two. + """ + return TensorRepSet( + self.valid_buffer_layouts | other.valid_buffer_layouts, + self.valid_texture_layouts | other.valid_texture_layouts, + ) + def is_compatible(self, storage: TensorRepr) -> bool: """ Check if this TensorRepr is compatible with the given TensorRepSet. 
@@ -683,28 +761,44 @@ def make_filtered_tensor_repset( if len(tensor_val.shape) > 4: return TensorRepSet(tensor_repset.valid_buffer_layouts, set()) - # Bool tensors are currently not supported - if tensor_val.dtype == torch.bool: - return NO_STORAGE - return TensorRepSet(tensor_repset.valid_buffer_layouts, valid_texture_layouts) ## Convenience TensorRepSet definitions +# Only includes memory layouts that can be used by non-quantized tensors + CONTIGUOUS_ANY = TensorRepSet( {VkMemoryLayout.TENSOR_WIDTH_PACKED}, {VkMemoryLayout.TENSOR_WIDTH_PACKED} ) CONTIGUOUS_BUFFER = TensorRepSet({VkMemoryLayout.TENSOR_WIDTH_PACKED}, set()) WIDTH_PACKED_TEXTURE = TensorRepSet(set(), {VkMemoryLayout.TENSOR_WIDTH_PACKED}) +HEIGHT_PACKED_TEXTURE = TensorRepSet(set(), {VkMemoryLayout.TENSOR_HEIGHT_PACKED}) CHANNELS_PACKED_TEXTURE = TensorRepSet(set(), {VkMemoryLayout.TENSOR_CHANNELS_PACKED}) +CHANNELS_PACKED_ANY = TensorRepSet( + {VkMemoryLayout.TENSOR_CHANNELS_PACKED}, {VkMemoryLayout.TENSOR_CHANNELS_PACKED} +) + +CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER = TensorRepSet( + {VkMemoryLayout.TENSOR_WIDTH_PACKED}, {VkMemoryLayout.TENSOR_CHANNELS_PACKED} +) + ANY_TEXTURE = TensorRepSet(set(), all_memory_layouts) ANY_BUFFER = TensorRepSet(all_memory_layouts, set()) - ANY_STORAGE = TensorRepSet(all_memory_layouts, all_memory_layouts) + +# Only includes memory layouts that can be used by quantized tensors + +PACKED_INT8_4W4C_BUFFER = TensorRepSet({VkMemoryLayout.PACKED_INT8_4W4C}, set()) + +# Special use RepSets + NO_STORAGE = TensorRepSet(set(), set()) +ALL_STORAGES_REPSET = TensorRepSet( + universal_memory_layout_set, universal_memory_layout_set +) class TensorRepSetList: @@ -828,19 +922,19 @@ def __init__( # noqa: C901 # Now, go through the arguments of the operator and create a filtered repset # for each based on the actual tensor value. args_repset_list = TensorRepSetList([]) - common_arg_repset = ANY_STORAGE + common_arg_repset = ALL_STORAGES_REPSET for i, arg_node in enumerate(op_node.args): arg_repset = inputs_repsets[i] - # Use ANY_STORAGE for non-tensor nodes so they don't cause the op repsets to - # appear empty + # Use ALL_STORAGES_REPSET for non-tensor nodes so they don't cause the op + # repsets to appear empty if not is_tensor_arg_node(arg_node): - args_repset_list.append(ANY_STORAGE) + args_repset_list.append(ALL_STORAGES_REPSET) # NO_STORAGE is used to denote that an input is either a non tensor arg or # a weight tensor that is not prepacked. Similar to the above, use - # ANY_STORAGE in this case. + # ALL_STORAGES_REPSET in this case. elif arg_repset.is_empty(): - args_repset_list.append(ANY_STORAGE) + args_repset_list.append(ALL_STORAGES_REPSET) else: assert not arg_repset.is_empty() @@ -853,7 +947,7 @@ def __init__( # noqa: C901 # Repeat for output tensors. 
outs_repset_list = TensorRepSetList([]) - common_out_repset = ANY_STORAGE + common_out_repset = ALL_STORAGES_REPSET if num_tensors_in_node(op_node) == 1: common_out_repset = make_filtered_tensor_repset( op_node.meta["val"], outputs_repsets[0], texture_limits @@ -1026,6 +1120,25 @@ def try_constrain_with_arg_repset( self.assert_sync_contraints() return True + def try_constrain_with_out_repset(self, repset: TensorRepSet): + # Skip for operators that must synchronize the input and output representations + # or operators that have more than one output repset + if self.sync_primary_io_repr or len(self.outs_repset_list) > 1: + return False + + out_current_repset = self.outs_repset_list[0] + + if out_current_repset == repset: + return False + + if not out_current_repset.any_in_common(repset): + return False + + self.outs_repset_list[0] = out_current_repset.make_intersect(repset) + + self.assert_sync_contraints() + return True + def pick_representations(self) -> Tuple[TensorReprList, TensorReprList]: """ For each tensor participating in the op, pick a representation for it among the @@ -1218,6 +1331,36 @@ def is_in_8bit_range(tensor: torch.Tensor) -> bool: ## +def normalize_dims(dims: Union[int, List[int]], ndim: int) -> Union[int, List[int]]: + """ + Normalize dimension indices to be non-negative and within [0, ndim). + Accepts a single int or a list of ints. + """ + if isinstance(dims, int): + if dims < 0: + dims += ndim + + return dims + + normalized = [] + for d in dims: + if d < 0: + d += ndim + normalized.append(d) + + return normalized + + +def nchw_dim_to_whcn_dim(nchw_dim: int, ndim: int) -> int: + # Handle negative indices for nchw_dim + if nchw_dim < 0: + nchw_dim += ndim + + assert nchw_dim >= 0 and nchw_dim < ndim + whcn_dim = (ndim - 1) - nchw_dim + return whcn_dim + + def get_tensor_val_str(tensor_val: FakeTensor) -> str: return f"{tensor_val.dtype}: {tensor_val.shape}" @@ -1269,6 +1412,7 @@ def update_program_state_dict( updated_tensor: torch.Tensor, ) -> None: target_name = None + kind = None # Iterate over all the tensors in the graph signature, and find # the one corresponding to the parameter/buffer name for input_ in program.graph_signature.input_specs: @@ -1277,6 +1421,7 @@ def update_program_state_dict( and isinstance(input_.arg, TensorArgument) and input_.arg.name == buffer_name ): + kind = input_.kind target_name = input_.target break @@ -1286,6 +1431,9 @@ def update_program_state_dict( ), f"could not find {buffer_name} in source program signature" assert target_name in program.state_dict, f"could not find {target_name}" + if kind == InputKind.PARAMETER: + updated_tensor = torch.nn.Parameter(updated_tensor, requires_grad=False) + # Finally, overwrite the current tensor with updated tensor program.state_dict[target_name] = updated_tensor diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 69d3cdef75d..3ccbdc8ab85 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -6,12 +6,11 @@ # pyre-strict +import copy from functools import partial - -from typing import Any, Dict, final, List +from typing import Any, Callable, Dict, final, List import executorch.backends.vulkan.utils as utils - from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform @@ -22,14 +21,12 @@ FoldQDQPass, FuseQuantizedOpsTransform, 
insert_prepack_nodes, - RemoveLocalScalarDenseOpsTransform, RemoveRedundantOpsTransform, SqueezeUnsqueezeInputs, TagMemoryMetaPass, ) from executorch.backends.vulkan._passes.fuse_patterns import FusePatternsPass from executorch.backends.vulkan._passes.remove_asserts import RemoveAssertsTransform - from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( VkMemoryLayout, @@ -39,7 +36,6 @@ serialize_vulkan_graph, ) from executorch.backends.xnnpack._passes import FuseBatchNormPass - from executorch.exir.backend.backend_details import ( BackendDetails, CompileSpec, @@ -47,16 +43,12 @@ PreprocessResult, ) from executorch.exir.backend.utils import DelegateMappingBuilder - from executorch.exir.memory_planning import greedy, MemoryPlanningAlgorithmSuite from executorch.exir.pass_base import ExportPass, PassBase - from executorch.exir.passes import MemoryPlanningPass, SpecPropPass - from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass - -from executorch.exir.program._program import _copy_module - +from executorch.exir.program._program import _transform +from torch._export.verifier import Verifier from torch.export._remove_auto_functionalized_pass import ( unsafe_remove_auto_functionalized_pass, ) @@ -64,28 +56,34 @@ DEFAULT_DEBUG_HANDLE = 65535 +class _any_op(Verifier): + # Set training dialect to skip functional check in base verifier + dialect = "TRAINING" + + def allowed_op_types(self): + return (Callable,) + + # pyre-ignore def apply_passes(program: ExportedProgram, passes) -> ExportedProgram: for p in passes: - if issubclass(type(p), ExportPass) or issubclass(type(p), PassBase): - new_gm = program.graph_module - # This is a workaround to allow the memory planning pass to work without - # having to first apply ToOutVarPass(). See the `greedy()` function in - # `exir.memory_planning`; if this attribute isn't set, assertions in - # `collect_spec_from_nodes()` will fail. - if isinstance(p, MemoryPlanningPass): - new_gm.encounter_to_out_var_failure = True - - new_gm_res = p(new_gm) - assert new_gm_res is not None - new_gm = new_gm_res.graph_module - + if isinstance(p, MemoryPlanningPass) and hasattr(p, "run"): + p.run(program.graph_module) + + elif issubclass(type(p), ExportPass) or issubclass(type(p), PassBase): + # Some passes require the ep to be provided. However, since the ep may be + # updated with each pass applied, the ep must be set right before calling + # the pass. _exported_program is the attribute used by XNNPACK and Vulkan + # passes to store the exported program. + if hasattr(p, "_exported_program"): + p._exported_program = program + + program = _transform(program, p, override_verifiers=[_any_op]) # See the application of this function in exir/program/_program.py for more # details on why this step is necessary. 
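Note: the reworked apply_passes above assigns the current ExportedProgram to a pass's _exported_program attribute immediately before invoking it, so each pass observes the program as rewritten by the passes that ran before it. A minimal sketch of a pass written against that convention; MyVulkanPass is hypothetical and not part of this change:

    import torch
    from executorch.exir.pass_base import ExportPass, PassResult

    class MyVulkanPass(ExportPass):
        _exported_program = None  # populated by apply_passes() right before call()

        def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
            ep = self._exported_program
            # Consult up-to-date program state (graph signature, state_dict) while
            # deciding how to rewrite graph_module.
            if ep is not None and len(ep.graph_signature.user_inputs) == 0:
                return PassResult(graph_module, False)
            return PassResult(graph_module, True)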
if isinstance(p, SpecPropPass): - p.update_placeholder_tensor_specs(program, new_gm) + p.update_placeholder_tensor_specs(program, program.graph_module) - _copy_module(program.graph_module, new_gm) else: program = p(program) @@ -112,6 +110,9 @@ def parse_compile_spec(compile_specs: List[CompileSpec]) -> Dict[str, Any]: if spec.key == "downcast_64_bit": options[spec.key] = bool.from_bytes(spec.value, byteorder="little") + if spec.key == "force_fp16": + options[spec.key] = bool.from_bytes(spec.value, byteorder="little") + # Unhandled options are ignored return options @@ -127,15 +128,21 @@ def preprocess( # noqa: C901 module_compile_spec: List[CompileSpec], ) -> PreprocessResult: compile_options = parse_compile_spec(module_compile_spec) - limits_x = compile_options.get( - "texture_limits_x", utils.DEFAULT_TEXTURE_LIMITS[0] - ) - limits_y = compile_options.get( - "texture_limits_y", utils.DEFAULT_TEXTURE_LIMITS[1] - ) - limits_z = compile_options.get( - "texture_limits_z", utils.DEFAULT_TEXTURE_LIMITS[2] - ) + + default_texture_limits = copy.deepcopy(utils.DEFAULT_TEXTURE_LIMITS) + # 2048 is the typical limit value for 3D textures, but mobile GPUs often support + # 16384. Since the Vulkan delegate primarily targets mobile GPUs at the moment, + # 16394 is the default texture limit used. This option is provided as a + # convenient way to switch to using a limit of 2048 for image textures which + # will be compatible with most GPUs. + if compile_options.get("small_texture_limits", False): + default_texture_limits[0] = 2048 + default_texture_limits[1] = 2048 + default_texture_limits[2] = 2048 + + limits_x = compile_options.get("texture_limits_x", default_texture_limits[0]) + limits_y = compile_options.get("texture_limits_y", default_texture_limits[1]) + limits_z = compile_options.get("texture_limits_z", default_texture_limits[2]) texture_limits = (limits_x, limits_y, limits_z) default_storage_type = compile_options.get( @@ -145,6 +152,7 @@ def preprocess( # noqa: C901 "memory_layout_override", VkMemoryLayout.TENSOR_WIDTH_PACKED ) downcast_64_bit = compile_options.get("downcast_64_bit", True) + force_fp16 = compile_options.get("force_fp16", False) program = unsafe_remove_auto_functionalized_pass(program) @@ -154,16 +162,16 @@ def preprocess( # noqa: C901 program = apply_passes( program, [ - FusePatternsPass(program), - RemoveRedundantOpsTransform(), + FuseBatchNormPass(program), + FusePatternsPass(), + FuseClampPass(), AddmmToLinearTransform(), - FuseQuantizedOpsTransform(program), - FoldQDQPass(program), + RemoveRedundantOpsTransform(), + FuseQuantizedOpsTransform(), + FoldQDQPass(), SqueezeUnsqueezeInputs(), FuseViewCopyTransform(), ViewCopyToSqueezeUnsqueezePass(), - FuseBatchNormPass(program), - FuseClampPass(), ], ) @@ -179,9 +187,6 @@ def preprocess( # noqa: C901 program, [ RemoveAssertsTransform(), - # Since this pass may replace a scalar argument with a tensor argument, - # this pass may result in a non ATen compliant graph structure. - RemoveLocalScalarDenseOpsTransform(), insert_prepack_nodes, ], ) @@ -199,28 +204,39 @@ def preprocess( # noqa: C901 texture_limits, default_storage_type=default_storage_type, default_memory_layout=default_memory_layout, + force_fp16=force_fp16, ), ], ) # Finally, apply dynamic shape passes and memory planning pass. These passes # must be applied only when the graph structure is finalized. 
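Note: preprocess() above reads several boolean options (downcast_64_bit, force_fp16, small_texture_limits) from the compile specs, and parse_compile_spec decodes booleans with bool.from_bytes(spec.value, byteorder="little"). A sketch of how a caller might encode such options, assuming CompileSpec from executorch.exir.backend.backend_details; the small_texture_limits entry assumes the same one-byte encoding, which is not shown in this hunk:

    from executorch.exir.backend.backend_details import CompileSpec

    def bool_spec(key: str, value: bool) -> CompileSpec:
        # one-byte little-endian encoding, matching bool.from_bytes(...) in parse_compile_spec
        return CompileSpec(key, value.to_bytes(1, byteorder="little"))

    vulkan_compile_specs = [
        bool_spec("force_fp16", True),
        bool_spec("downcast_64_bit", True),
        bool_spec("small_texture_limits", True),  # clamp image texture limits to 2048 (assumed key handling)
    ]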
- greedy_memory_planning = partial(greedy, allow_overlapping_allocations=False) - mem_planning_suite = MemoryPlanningAlgorithmSuite( - algo_list=[greedy_memory_planning] - ) - program = apply_passes( - program, - [ - ConstraintBasedSymShapeEvalPass(), - MemoryPlanningPass(memory_planning_algo=mem_planning_suite), - ], - ) + final_passes = [ + ConstraintBasedSymShapeEvalPass(), + ] + if not compile_options.get("skip_memory_planning", False): + greedy_memory_planning = partial( + greedy, allow_overlapping_allocations=False + ) + mem_planning_suite = MemoryPlanningAlgorithmSuite( + algo_list=[greedy_memory_planning] + ) + # This is a workaround to allow the memory planning pass to work without having + # to first apply ToOutVarPass(). See the `greedy()` function in + # `exir.memory_planning`; if this attribute isn't set, assertions in + # `collect_spec_from_nodes()` will fail. + program.graph_module.encounter_to_out_var_failure = True + final_passes.append( + MemoryPlanningPass(memory_planning_algo=mem_planning_suite) + ) + + program = apply_passes(program, final_passes) graph_builder = VkGraphBuilder( program, DelegateMappingBuilder(generated_identifiers=True), downcast_64_bit=downcast_64_bit, + force_fp16=force_fp16, ) vk_graph = graph_builder.build_graph() diff --git a/backends/xnnpack/CMakeLists.txt b/backends/xnnpack/CMakeLists.txt index 33bf84b9066..625e3d2523f 100644 --- a/backends/xnnpack/CMakeLists.txt +++ b/backends/xnnpack/CMakeLists.txt @@ -35,7 +35,10 @@ if(EXECUTORCH_XNNPACK_ENABLE_KLEIDI) add_definitions(-DENABLE_XNNPACK_KLEIDI) endif() -set(_common_compile_options -Wno-deprecated-declarations -fPIC) +set(_common_compile_options + $<$:/wd4996> + $<$>:-Wno-deprecated-declarations -fPIC> +) set(_xnnpack_schema__include_dir "${CMAKE_BINARY_DIR}/schema/include") # Paths to headers generated from the .fbs files. diff --git a/backends/xnnpack/README.md b/backends/xnnpack/README.md index 6e6be7ddb4c..7c6a7ccbc33 100644 --- a/backends/xnnpack/README.md +++ b/backends/xnnpack/README.md @@ -134,4 +134,4 @@ create an issue on [github](https://www.github.com/pytorch/executorch/issues). ## See Also For more information about the XNNPACK Backend, please check out the following resources: - [XNNPACK Backend](https://pytorch.org/executorch/main/backends-xnnpack) -- [XNNPACK Backend Internals](https://pytorch.org/executorch/main/backend-delegates-xnnpack-reference) +- [XNNPACK Backend Internals](https://pytorch.org/executorch/main/backends/xnnpack/backend-delegates-xnnpack-reference) diff --git a/backends/xnnpack/_passes/TARGETS b/backends/xnnpack/_passes/TARGETS index 6f7b13d8026..4977ad08936 100644 --- a/backends/xnnpack/_passes/TARGETS +++ b/backends/xnnpack/_passes/TARGETS @@ -8,6 +8,7 @@ runtime.python_library( deps = [ "//caffe2:torch", "//executorch/backends/transforms:addmm_mm_to_linear", + "//executorch/backends/transforms:remove_clone_ops", "//executorch/backends/transforms:lib", "//executorch/backends/xnnpack/partition:partitioner_graphs", "//executorch/backends/xnnpack/serialization:xnnpack_schema", diff --git a/backends/xnnpack/_passes/__init__.py b/backends/xnnpack/_passes/__init__.py index 141718bde6f..4992d7a4abd 100644 --- a/backends/xnnpack/_passes/__init__.py +++ b/backends/xnnpack/_passes/__init__.py @@ -4,8 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-unsafe + from typing import List, Optional, Type +from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform + from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import ( @@ -23,6 +27,9 @@ from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass +from executorch.backends.xnnpack._passes.propagate_custom_meta_pass import ( + PropagateCustomMetaPass, +) from executorch.backends.xnnpack._passes.remove_redundant_copy_pass import ( RemoveRedundantCopyPass, ) @@ -39,6 +46,11 @@ from torch.export import ExportedProgram +class XNNPACKRemoveCloneOpsTransform(RemoveCloneOpsTransform): + def __init__(self): + super().__init__(preserve_input_output_copies=True) + + class XNNPACKPassManager: def __init__( self, @@ -55,10 +67,12 @@ def __init__( if not passes: # All the XNNPACK passes self.passes = [ + XNNPACKRemoveCloneOpsTransform, # TODO - remove this pass once we have a better support for dim_order ops lowering DimOrderOpsRevertPass, ConvertToUpsampleBilinear2d, ConvertToLinearPass, + PropagateCustomMetaPass, ConvertToSDPAPass, ConstPropPass, FuseBatchNormPass, diff --git a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py index 85e9889ca36..179006bc1b6 100644 --- a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py +++ b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-unsafe + from enum import Enum from typing import Optional, Tuple @@ -110,7 +112,9 @@ def is_nhwc_node(node: torch.fx.Node) -> bool: if len(quantize_node.all_input_nodes) > 0: actual_node = quantize_node.args[0] if actual_node.op == "placeholder": - return not actual_node.meta["val"][0].is_contiguous() + return ChannelsLastTaggedReshapePass._is_nhwc_tensor( + actual_node.meta["val"][0] + ) else: return actual_node.meta.get( ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False @@ -125,7 +129,9 @@ def is_nchw_node(node: torch.fx.Node) -> bool: if len(quantize_node.all_input_nodes) > 0: actual_node = quantize_node.args[0] if actual_node.op == "placeholder": - return actual_node.meta["val"][0].is_contiguous() + return not ChannelsLastTaggedReshapePass._is_nhwc_tensor( + actual_node.meta["val"][0] + ) else: return not actual_node.meta.get( ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False @@ -133,10 +139,33 @@ def is_nchw_node(node: torch.fx.Node) -> bool: return not ChannelsLastTaggedReshapePass.is_nhwc_node(node) + @staticmethod + def _is_nhwc_tensor(tensor: torch.Tensor) -> bool: + nhwc = tensor.is_contiguous(memory_format=torch.channels_last) + nchw = tensor.is_contiguous() + # if both are true false + # if both nchw and nhwc are true + # then we want to see this is nchw hence return false + # if either of nchw or nhwc is false, then just rely on hwc + # if both are false, mayb channels_last_3d, then return nhwc + # however this should not happen here + # return (not (nchw and nhwc)) and nhwc + # Readable version + if nchw and nhwc: + return False + else: + return nhwc + + def _is_nhwc(self, tensor: torch.Tensor) -> bool: + return ChannelsLastTaggedReshapePass._is_nhwc_tensor(tensor) + def requires_nhwc_input(self, node: torch.fx.Node) -> bool: return node.target in self.memory_sensitive_ops_nhwc def requires_nchw_inputs(self, node: torch.fx.Node) -> bool: + if node.target == exir_ops.edge.aten.view_copy.default: + return True + return node.target in self.memory_sensitive_ops_nchw def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool: @@ -315,11 +344,8 @@ def input_dim_order( self, input_node: torch.fx.Node, input_order: InputDimOrder ) -> bool: if input_node.op == "placeholder": - return ( - input_node.meta["val"].is_contiguous() - if input_order == InputDimOrder.NCHW - else not input_node.meta["val"].is_contiguous() - ) + is_nhwc = self._is_nhwc(input_node.meta["val"]) + return not is_nhwc if input_order == InputDimOrder.NCHW else is_nhwc else: return ( ChannelsLastTaggedReshapePass.is_nchw_node(input_node) @@ -348,7 +374,7 @@ def input_to_nhwc( self.mark_as_nhwc_node(input_node) if input_node.op == "placeholder": - if not input_node.meta["val"][0].is_contiguous(): + if self._is_nhwc(input_node.meta["val"][0]): return elif ChannelsLastTaggedReshapePass.is_nhwc_node(input_node): return @@ -420,7 +446,7 @@ def input_to_nchw( self.mark_as_nchw_node(input_node) if input_node.op == "placeholder": - if input_node.meta["val"].is_contiguous(): + if not self._is_nhwc(input_node.meta["val"]): return elif ChannelsLastTaggedReshapePass.is_nchw_node(input_node): return @@ -462,17 +488,17 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901 and isinstance(node.meta["val"], torch.Tensor) and len(node.meta["val"].shape) == 4 ): - if node.meta["val"].is_contiguous(): - self.mark_as_nchw_node(node) - else: + if self._is_nhwc(node.meta["val"]): self.mark_as_nhwc_node(node) + else: + self.mark_as_nchw_node(node) continue # Need special case for output node because it can have 
multiple output dim orders as we can output a tuple multiple nodes if node.op == "output": out_tuple = node.args[0] for out_node in out_tuple: - if out_node.meta["val"].is_contiguous(): + if not self._is_nhwc(out_node.meta["val"]): self.input_to_nchw(graph_module, out_node, node) else: self.input_to_nhwc(graph_module, out_node, node) diff --git a/backends/xnnpack/_passes/propagate_custom_meta_pass.py b/backends/xnnpack/_passes/propagate_custom_meta_pass.py new file mode 100644 index 00000000000..b1a03514446 --- /dev/null +++ b/backends/xnnpack/_passes/propagate_custom_meta_pass.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass +from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant +from executorch.exir.pass_base import PassResult + + +class PropagateCustomMetaPass(XNNPACKPass): + """ + Pass to propagate node.meta['custom'] from parent nodes to their q/dq child nodes. + For all quantize/dequantize nodes in the graph, if the parent node has a + node.meta['custom'] entry, this pass will copy that value to the q/dq node's meta. + """ + + def call(self, graph_module: torch.fx.GraphModule): + graph = graph_module.graph + + for node in graph.nodes: + if not (is_quant(node) or is_dequant(node)): + continue + + # Get the parent node (first input argument) + if len(node.all_input_nodes) == 0: + continue + + parent_node = node.args[0] + if not isinstance(parent_node, torch.fx.Node): + continue + + if "custom" in parent_node.meta: + node.meta["custom"] = parent_node.meta["custom"] + + graph_module.recompile() + + # Since we are overriding "call", we need to call the parent's "call" + # to retrace the graph and regenerate metadata + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, True) diff --git a/backends/xnnpack/operators/__init__.py b/backends/xnnpack/operators/__init__.py index d17b7abd6a1..e06f337f9ee 100644 --- a/backends/xnnpack/operators/__init__.py +++ b/backends/xnnpack/operators/__init__.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + from . 
import ( # noqa node_visitor, op_abs, @@ -14,7 +16,9 @@ op_cat, op_ceiling, op_clamp, + op_clone, op_conv2d, + op_cos, op_div, op_dynamic_dequantize_ops, op_dynamic_quantize_ops, @@ -41,6 +45,7 @@ op_relu, op_rsqrt, op_sigmoid, + op_sin, op_skip_ops, op_slice_copy, op_softmax, @@ -52,4 +57,5 @@ op_sub, op_tanh, op_to_copy, + op_view_copy, ) diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py index 6a055c9413f..4643ada9336 100644 --- a/backends/xnnpack/operators/node_visitor.py +++ b/backends/xnnpack/operators/node_visitor.py @@ -232,7 +232,7 @@ def get_per_channel_dtype( if quant_params.dtype == torch.int32: return XNNDatatype.xnn_datatype_qcint32 elif quant_params.dtype == torch.int8: - if quant_params.is_per_channel_group: + if quant_params.per_channel_group: # 4-bit per channel group quantized weights # No 8-bit support yet assert ( @@ -275,14 +275,14 @@ def get_per_channel_dtype( return dtype def get_quant_params( - self, quant_params: QuantParams, xnn_graph: XNNGraph + self, quant_params: QuantParams, xnn_graph: XNNGraph, external_tag: str = None ) -> XNNQuantParams: if quant_params.per_channel: scale = cast(torch.Tensor, quant_params.scale) buffer_idx = len(xnn_graph.constant_data) num_scales = scale.numel() - if quant_params.is_per_channel_group: + if quant_params.per_channel_group: scale = scale.to(torch.bfloat16) num_bytes = scale.untyped_storage().nbytes() @@ -291,16 +291,21 @@ def get_quant_params( ctypes.POINTER(ctypes.c_char * num_bytes), ).contents scale_name = hashlib.sha256(bytes(scale_array)).hexdigest() + scale_name = "scale_" + scale_name xnn_graph.constant_data.append( ConstantDataOffset( offset=UINT64_MAX, size=num_bytes, named_key=scale_name ) ) + if external_tag is not None: + logging.info( + f"Adding constant data with name, key {scale_name} and external_tag {external_tag} to named_data_store" + ) self._named_data_store.add_named_data( - scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT + scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT, external_tag ) - if quant_params.is_per_channel_group: + if quant_params.per_channel_group: return PerChannelGroupQuant( scale=[], channel_dim=quant_params.axis, @@ -335,7 +340,7 @@ def _check_per_channel_group_params( ) -> None: # Make sure things are lining up for per_channel_group quantization case # Has to be done this late because we don't have clean access to the actual tensor - assert quant_params.is_per_channel_group, "Not per_channel_group quantization" + assert quant_params.per_channel_group, "Not per_channel_group quantization" # linear weights will be in [oc, ic]. And per_channel quantization must be on axis 0 num_groups = cast(torch.Tensor, quant_params.scale).shape[1] assert ( @@ -470,13 +475,19 @@ def define_tensor( # noqa: C901 assert f"Unsupported weight per channel quantization axis for depthwise conv2d / conv_transpose2d : {quant_params.axis}, expecting 0 / 1." 
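# Illustrative sketch (annotation, not part of the patch): with the node_visitor
# changes in this patch, constant data is registered in the named data store
# under a content-derived key with a readable prefix: blockwise scales are
# serialized as bfloat16 under "scale_<sha256 of the bytes>", and regular
# constant tensors under "<fx tensor name>_<sha256 of the bytes>", optionally
# tagged with node.meta["custom"]["delegate_constant_tag"]. A rough standalone
# approximation of the key derivation; `constant_key` is a hypothetical helper
# and the real code reads the raw storage bytes via ctypes:
import hashlib

import torch


def constant_key(fx_name: str, const: torch.Tensor) -> str:
    raw = bytes(const.contiguous().view(torch.uint8).flatten().tolist())
    return fx_name + "_" + hashlib.sha256(raw).hexdigest()


# e.g. constant_key("linear_weight", torch.ones(4, 4)) yields
# "linear_weight_" followed by 64 hex digits.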
# Serialize tensor value + custom_meta = tensor.meta.get("custom", None) + external_tag = ( + custom_meta.get("delegate_constant_tag", None) if custom_meta else None + ) ser_val = ( XValue(xvalue_union=tvalue) if quant_params is None else XValue( xvalue_union=XNNQuantizedTensorValue( tensor_value=tvalue, - quant_params=self.get_quant_params(quant_params, xnn_graph), + quant_params=self.get_quant_params( + quant_params, xnn_graph, external_tag + ), ) ) ) @@ -614,7 +625,7 @@ def get_serialized_buffer_index( f"Serializing constant data node {tensor} but tensor value has no bytes", ) sha256_hash = hashlib.sha256(bytes(array)) - named_key = sha256_hash.hexdigest() + named_key = tensor.name + "_" + sha256_hash.hexdigest() size = const_val.untyped_storage().nbytes() xnn_graph.constant_data.append( @@ -626,7 +637,6 @@ def get_serialized_buffer_index( custom_meta.get("delegate_constant_tag", None) if custom_meta else None ) if external_tag is not None: - external_tag = custom_meta.get("delegate_constant_tag", None) logging.info( f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store" ) diff --git a/backends/xnnpack/operators/op_clone.py b/backends/xnnpack/operators/op_clone.py new file mode 100644 index 00000000000..e4ddf187ecc --- /dev/null +++ b/backends/xnnpack/operators/op_clone.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from typing import Dict + +import torch +from executorch.backends.xnnpack.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( + XNNCopy, + XNNGraph, + XNode, +) +from executorch.backends.xnnpack.utils.utils import get_input_node + + +@register_node_visitor +class CloneVisitor(NodeVisitor): + target = "aten.clone.default" + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + xnn_graph: XNNGraph, + vals_to_ids: Dict[torch.fx.Node, int], + debug_handle: int, + ) -> None: + self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids) + + # Sanity check that the input and output dim order are the same. We don't + # handle dim order conversions yet. + dim_order = node.kwargs.get("dim_order", None) + input_meta = node.args[0].meta["val"] + assert dim_order is None or list(input_meta.dim_order()) == dim_order + + # input + input_id = vals_to_ids[get_input_node(node, 0)] + + # output + output_id = vals_to_ids[node] + + ser_node = XNode( + xnode_union=XNNCopy( + input_id=input_id, + output_id=output_id, + flags=0, + ), + debug_handle=debug_handle, + ) + xnn_graph.xnodes.append(ser_node) diff --git a/backends/xnnpack/operators/op_cos.py b/backends/xnnpack/operators/op_cos.py new file mode 100644 index 00000000000..aa3166c96dd --- /dev/null +++ b/backends/xnnpack/operators/op_cos.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
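# A minimal standalone sketch, not part of the patch: CloneVisitor above only
# serializes clones whose output dim order matches the input (dim order
# conversions are not handled yet), so only no-op clones reach XNNCopy. The
# distinction it checks can be reproduced with Tensor.dim_order():
import torch

_x = torch.randn(1, 3, 4, 4)
assert torch.clone(_x).dim_order() == _x.dim_order()  # dim order preserved
assert (
    torch.clone(_x, memory_format=torch.channels_last).dim_order()
    != _x.dim_order()
)  # dim order changes, so this clone is not lowered as an XNN copy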
+ +from typing import Dict + +import torch +from executorch.backends.xnnpack.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( + XNNCos, + XNNGraph, + XNode, +) +from executorch.backends.xnnpack.utils.utils import get_input_node + + +@register_node_visitor +class CosVisitor(NodeVisitor): + target = "aten.cos.default" + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + xnn_graph: XNNGraph, + vals_to_ids: Dict[torch.fx.Node, int], + debug_handle: int, + ) -> None: + self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids) + + # input + input_id = vals_to_ids[get_input_node(node, 0)] + + # output + output_id = vals_to_ids[node] + + ser_node = XNode( + xnode_union=XNNCos( + input_id=input_id, + output_id=output_id, + flags=0, + ), + debug_handle=debug_handle, + ) + xnn_graph.xnodes.append(ser_node) diff --git a/backends/xnnpack/operators/op_sin.py b/backends/xnnpack/operators/op_sin.py new file mode 100644 index 00000000000..56fe9396103 --- /dev/null +++ b/backends/xnnpack/operators/op_sin.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict + +import torch +from executorch.backends.xnnpack.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( + XNNGraph, + XNNSin, + XNode, +) +from executorch.backends.xnnpack.utils.utils import get_input_node + + +@register_node_visitor +class SinVisitor(NodeVisitor): + target = "aten.sin.default" + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + xnn_graph: XNNGraph, + vals_to_ids: Dict[torch.fx.Node, int], + debug_handle: int, + ) -> None: + self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids) + + # input + input_id = vals_to_ids[get_input_node(node, 0)] + + # output + output_id = vals_to_ids[node] + + ser_node = XNode( + xnode_union=XNNSin( + input_id=input_id, + output_id=output_id, + flags=0, + ), + debug_handle=debug_handle, + ) + xnn_graph.xnodes.append(ser_node) diff --git a/backends/xnnpack/operators/op_skip_ops.py b/backends/xnnpack/operators/op_skip_ops.py index 19df74e77ac..04be2b274b2 100644 --- a/backends/xnnpack/operators/op_skip_ops.py +++ b/backends/xnnpack/operators/op_skip_ops.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + from typing import Dict import torch @@ -59,16 +61,6 @@ class OpTCopyDefault(OpSkipOps): target = "aten.t_copy.default" -@register_node_visitor -class OpViewCopyDefault(OpSkipOps): - """ - currently, do nothing if node is view_copy.default - need to handle this later on, currently view it as one of skip ops - """ - - target = "aten.view_copy.default" - - @register_node_visitor class OpSymSizeInt(OpSkipOps): """ diff --git a/backends/xnnpack/operators/op_view_copy.py b/backends/xnnpack/operators/op_view_copy.py new file mode 100644 index 00000000000..5a8bf342eab --- /dev/null +++ b/backends/xnnpack/operators/op_view_copy.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from typing import Dict + +import torch +from executorch.backends.xnnpack.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( + XNNGraph, + XNNStaticReshape, + XNode, +) +from executorch.backends.xnnpack.utils.utils import ( + check_or_raise, + get_input_node, + PERM_NCHW_TO_NHWC, +) + + +@register_node_visitor +class ViewCopyVisitor(NodeVisitor): + target = "aten.view_copy.default" + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + xnn_graph: XNNGraph, + vals_to_ids: Dict[torch.fx.Node, int], + debug_handle: int, + ) -> None: + self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids) + + input_node = get_input_node(node, 0) + + # input + input_id = vals_to_ids[input_node] + + # output + output_id = vals_to_ids[node] + + # input shape + check_or_raise( + "val" in input_node.meta, + "Missing val in tensor metadata for input when serializing XNNStaticReshape", + ) + + # output shape + check_or_raise( + "val" in node.meta, + "Missing val in tensor metadata for input when serializing XNNStaticReshape", + ) + + new_shape = node.args[1] + check_or_raise( + all(isinstance(d, int) for d in new_shape), + "Symbolic reshape parameter is not supported in XNNStaticReshape", + ) + + # PyTorch uses -1 for inferred dims, whereas XNNPACK expects 0. + new_shape = tuple(d if d != -1 else 0 for d in new_shape) + + # Handle NCHW dim order - if this op is in NCHW order, we need to permute the + # view shape correspondingly. + if "XNN_NHWC_NODE" in node.meta: + check_or_raise(len(new_shape) == 4, "Invalid NCHW shape") + new_shape = [new_shape[PERM_NCHW_TO_NHWC[n]] for n in range(4)] + + num_dynamic_dims = sum(1 for d in new_shape if d == 0) + + check_or_raise( + num_dynamic_dims <= 1, + "XNNPACK reshape only supports 1 dynamic dimension.", + ) + + ser_node = XNode( + xnode_union=XNNStaticReshape( + num_dims=len(new_shape), + new_shape=new_shape, + input_id=input_id, + output_id=output_id, + flags=0, + ), + debug_handle=debug_handle, + ) + xnn_graph.xnodes.append(ser_node) diff --git a/backends/xnnpack/operators/quant_params.py b/backends/xnnpack/operators/quant_params.py index 88a1f660f0e..f1c87c0b8b6 100644 --- a/backends/xnnpack/operators/quant_params.py +++ b/backends/xnnpack/operators/quant_params.py @@ -89,6 +89,9 @@ def __init__( # Groupwise quantization for weight self.per_channel_group = False self.group_size = group_size + + tensor = q_input.meta["val"] + if self.group_size > 0: assert ( self.per_channel is True @@ -96,12 +99,29 @@ def __init__( assert ( cast(torch.Tensor, scale).ndim == 2 ), "Scale must be 2D for per channel groupwise quant" - self.per_channel_group = True - assert group_size > 0, "Group size must be greater than 0" - self.is_per_channel_group = self.per_channel and self.group_size > 0 + # Assumed scale shape - [out_channels, in_channels/group_size] + input_channels = cast(torch.Tensor, scale).shape[1] * self.group_size + # 2d weight tensor shape - [out_channels, in_channels] + assert ( + tensor.shape[1] == input_channels + ), "Invalid input channels for groupwise quant" + # Prefer per_channel over per_channel_group when group_size == input_channels for non int4 cases only + # int4 case need more fixes to map qb4w to qc4w. 
Incorrect scales being passed down to xnnpack. + self.per_channel_group = ( + self.group_size <= input_channels + if self.is_qc4w + else self.group_size < input_channels + ) + + if not self.per_channel_group: + if cast(torch.Tensor, scale).ndim == 2: + # TODO: don't reshape scale for per_channel cases + assert ( + cast(torch.Tensor, scale).shape[1] == 1 + ), "Invalid scale shape for per channel quantization" + scale = cast(torch.Tensor, scale).squeeze(1) - if per_channel and not self.is_per_channel_group: - tensor = q_input.meta["val"] + if per_channel and not self.per_channel_group: assert ( tensor.shape[self.axis] == cast(torch.Tensor, self.scale).shape[0] ), f"Invalid size of per channel quantization scales, axis: {self.axis}, scale size: {self.scale.shape}, tensor shape: {tensor.shape}" @@ -110,6 +130,39 @@ def __init__( tensor.shape[self.axis] == cast(torch.Tensor, self.zp).shape[0] ), f"Invalid size of per channel quantization zero-points, axis: {self.axis}, zp size: {self.zp.shape}, tensor shape: {tensor.shape}" + def __str__(self) -> str: + """String representation of QuantParams for debugging and logging.""" + assert isinstance(self.scale, float) or isinstance(self.scale, torch.Tensor) + scale_str = ( + f"{self.scale}" + if isinstance(self.scale, float) + else f"tensor{tuple(self.scale.shape)}" + ) + assert isinstance(self.zp, float) or isinstance(self.zp, torch.Tensor) + zp_str = ( + f"{self.zp}" + if isinstance(self.zp, float) + else f"tensor{tuple(self.zp.shape)}" + ) + + return ( + f"QuantParams(" + f"per_channel={self.per_channel}, " + f"per_channel_group={self.per_channel_group}, " + f"scale={scale_str}, " + f"zp={zp_str}, " + f"axis={self.axis}, " + f"dtype={self.dtype}, " + f"qmin={self.qmin}, " + f"qmax={self.qmax}, " + f"is_dynamic={self.is_dynamic}, " + f"is_input={self.is_input}, " + f"is_output={self.is_output}, " + f"group_size={self.group_size}, " + f"is_qc4w={self.is_qc4w}" + f")" + ) + def quantize_tensor(self, tensor: torch.Tensor) -> torch.Tensor: # Do nothing if already quantized by the Quantizer if tensor.dtype == self.dtype: diff --git a/backends/xnnpack/partition/config/__init__.py b/backends/xnnpack/partition/config/__init__.py index e393f1c9ac8..26ac6275ef1 100644 --- a/backends/xnnpack/partition/config/__init__.py +++ b/backends/xnnpack/partition/config/__init__.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import List, Type @@ -22,7 +23,9 @@ CatConfig, CeilConfig, ClampConfig, + CloneDimOrderConfig, ConstantPadConfig, + CosConfig, DeQuantizedPerTensorConfig, DivConfig, # EluConfig, @@ -45,6 +48,7 @@ ReciprocalSquareRootConfig, ReLUConfig, SigmoidConfig, + SinConfig, SliceCopyConfig, SoftmaxConfig, SquareRootConfig, @@ -76,6 +80,7 @@ BMMConfig, CatConfig, CeilConfig, + CloneDimOrderConfig, ConstantPadConfig, ConvolutionConfig, ClampConfig, @@ -105,6 +110,8 @@ TanhConfig, ToDimOrderCopyConfig, SigmoidConfig, + SinConfig, + CosConfig, SliceCopyConfig, SoftmaxConfig, SquareRootConfig, diff --git a/backends/xnnpack/partition/config/gemm_configs.py b/backends/xnnpack/partition/config/gemm_configs.py index f65f9cb3398..d025c8e6029 100644 --- a/backends/xnnpack/partition/config/gemm_configs.py +++ b/backends/xnnpack/partition/config/gemm_configs.py @@ -458,9 +458,7 @@ def get_deps( a bool indicating if the deps are valid and a list of all the dep nodes. 
This handles the src partition for """ - if self.src_partitions is None: - # Cache src partitions so we don't have to recompute them every time - self.src_partitions = get_source_partitions(ep.graph, self.linear_modules) + self.src_partitions = get_source_partitions(ep.graph, self.linear_modules) # src_partition is None if node is not in source partition, # otherwise gives us the linear source partition it belongs to diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py index 559d1522275..d36072d1991 100644 --- a/backends/xnnpack/partition/config/generic_node_configs.py +++ b/backends/xnnpack/partition/config/generic_node_configs.py @@ -380,6 +380,43 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]: return [ConfigPrecisionType.FP32] +class ViewCopyConfig(GenericNodePartitionerConfig): + target_name = "view_copy.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32] + + def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: + """ + XNNPACK's static_reshape only supports 1 dynamic dimension. + """ + if not self.check_common_constraints(node, ep): + return False + + new_shape = node.args[1] + + # Check for symbolic dims. They aren't lowerable to XNNPACK currently. + symbolic_dim_indices = [ + i for i, d in enumerate(new_shape) if not isinstance(d, int) + ] + if not all(isinstance(n, int) for n in new_shape): + why( + node, + reason=f"Symbolic reshape is not supported. Output shape is {new_shape} and dims at {symbolic_dim_indices} are symbolic.", + ) + return False + + dynamic_dim_indices = [i for i, d in enumerate(new_shape) if d == -1] + if len(dynamic_dim_indices) > 1: + why( + node, + reason=f"Only a single inferred dimension is supported. Output shape is {new_shape} and dims {dynamic_dim_indices} are inferred.", + ) + return False + + return True + + class FloorConfig(GenericNodePartitionerConfig): target_name = "floor.default" @@ -636,3 +673,39 @@ class BMMConfig(GenericNodePartitionerConfig): def supported_precision_types(self) -> List[ConfigPrecisionType]: return [ConfigPrecisionType.FP32] + + +class SinConfig(GenericNodePartitionerConfig): + target_name = "sin.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32] + + +class CloneDimOrderConfig(GenericNodePartitionerConfig): + target_name = "_clone_dim_order.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32] + + def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: + if not self.check_common_constraints(node, ep): + return False + + # Only partition no-op _clone_dim_order nodes (output dim order = input). + # We can relax this in the future. + # This is also a conservative check and doesn't consider ambiguity. 
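# Side note (annotation, not part of the patch): "ambiguity" above refers to
# tensors whose strides satisfy more than one memory format, e.g. a 4D tensor
# with C == 1 or H == W == 1 is contiguous under both torch.contiguous_format
# and torch.channels_last. That is the same case that
# ChannelsLastTaggedReshapePass._is_nhwc_tensor() earlier in this patch
# conservatively resolves to NCHW. A quick standalone check:
#
#   x = torch.randn(2, 1, 8, 8)
#   assert x.is_contiguous()
#   assert x.is_contiguous(memory_format=torch.channels_last)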
+ dim_order = node.kwargs.get("dim_order", None) + input_meta = node.args[0].meta["val"] + if dim_order is not None and list(input_meta.dim_order()) != dim_order: + why(node, reason="Only dim-order preserving clones are supported.") + return False + + return True + + +class CosConfig(GenericNodePartitionerConfig): + target_name = "cos.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32] diff --git a/backends/xnnpack/recipes/TARGETS b/backends/xnnpack/recipes/TARGETS index 6b6c1ddfe82..5d452b4a4b7 100644 --- a/backends/xnnpack/recipes/TARGETS +++ b/backends/xnnpack/recipes/TARGETS @@ -30,6 +30,7 @@ runtime.python_library( deps = [ "//caffe2:torch", "//executorch/export:lib", + "//executorch/runtime:runtime", # @manual "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer", "//executorch/backends/xnnpack/partition:xnnpack_partitioner", ":xnnpack_recipe_types", diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 78eaaf6d039..4881844ac6d 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -174,13 +174,12 @@ payload (deprecated) or via offsets to the constant_data_ptr. If no constant data associated with the tensor value, then returns nullptr. */ const uint8_t* getConstantDataPtr( - const fb_xnnpack::XNNTensorValue* tensor_value, + uint32_t buffer_idx, GraphPtr flatbuffer_graph, const uint8_t* constant_data_ptr, const NamedDataMap* named_data_map, std::vector& freeable_buffers, XNNWeightsCache* weights_cache) { - auto buffer_idx = tensor_value->constant_buffer_idx(); if (buffer_idx) { if (!constant_data_ptr) { // TODO(T172265611): Remove constant_buffer in flatbuffer path after BC @@ -230,6 +229,22 @@ const uint8_t* getConstantDataPtr( return nullptr; } +const uint8_t* getConstantDataPtr( + const fb_xnnpack::XNNTensorValue* tensor_value, + GraphPtr flatbuffer_graph, + const uint8_t* constant_data_ptr, + const NamedDataMap* named_data_map, + std::vector& freeable_buffers, + XNNWeightsCache* weights_cache) { + return getConstantDataPtr( + tensor_value->constant_buffer_idx(), + flatbuffer_graph, + constant_data_ptr, + named_data_map, + freeable_buffers, + weights_cache); +} + /** Define serialized tensor value into the subgraph. While also keeping track of the remapped ids from @@ -434,22 +449,15 @@ Error defineTensor( const float* scale = qparams->scale()->data(); if (qparams->scale_buffer_idx() != 0) { - // if scales are stored in named data, then retrieve it - ConstantDataOffsetPtr scale_buffer_offset = - flatbuffer_graph->constant_data()->Get( - qparams->scale_buffer_idx()); - const std::string& data_name = - scale_buffer_offset->named_key()->str(); - Result scale_buffer = - named_data_map->get_data(data_name.c_str()); + scale = reinterpret_cast(getConstantDataPtr( + qparams->scale_buffer_idx(), + flatbuffer_graph, + constant_data_ptr, + named_data_map, + freeable_buffers, + weights_cache)); ET_CHECK_OR_RETURN_ERROR( - scale_buffer.ok(), - Internal, - "Failed to get constant data for key %s from named_data_map. 
Error code: %u", - data_name.c_str(), - static_cast(scale_buffer.error())); - scale = reinterpret_cast(scale_buffer.get().data()); - freeable_buffers.push_back(std::move(scale_buffer.get())); + scale != nullptr, Internal, "Failed to load scale data."); } status = xnn_define_channelwise_quantized_tensor_value_v2( /*subgraph=*/subgraph_ptr, @@ -483,22 +491,15 @@ Error defineTensor( // Block scales are preferably serialized as bf16 but can also be // serialized as fp32 for backwards compatability. if (qparams->scale_buffer_idx() != 0) { - ConstantDataOffsetPtr scale_buffer_offset = - flatbuffer_graph->constant_data()->Get( - qparams->scale_buffer_idx()); - const std::string& data_name = - scale_buffer_offset->named_key()->str(); - Result scale_buffer = - named_data_map->get_data(data_name.c_str()); + scale_data = reinterpret_cast(getConstantDataPtr( + qparams->scale_buffer_idx(), + flatbuffer_graph, + constant_data_ptr, + named_data_map, + freeable_buffers, + weights_cache)); ET_CHECK_OR_RETURN_ERROR( - scale_buffer.ok(), - Internal, - "Failed to get constant data for key %s from named_data_map. Error code: %u", - data_name.c_str(), - static_cast(scale_buffer.error())); - scale_data = - reinterpret_cast(scale_buffer.get().data()); - freeable_buffers.push_back(std::move(scale_buffer.get())); + scale_data != nullptr, Internal, "Failed to load scale data."); scale_numel = qparams->num_scales(); } else { // Read fp32 scales, convert to bf16. @@ -1458,6 +1459,34 @@ Error defineBatchMatrixMultiplyNode( return Error::Ok; } +/* + * Defines a copy node in the XNN subgraph. + */ +Error defineCopyNode( + xnn_subgraph_t subgraph_ptr, + const std::unordered_map& remapped_ids, + const NodePtr node, + const fb_xnnpack::XNNGraph* graph) noexcept { + MAYBE_UNUSED(graph); + + auto graph_node = node->xnode_union_as_XNNCopy(); + + xnn_status status = xnn_define_copy( + subgraph_ptr, + remapped_ids.at(graph_node->input_id()), + remapped_ids.at(graph_node->output_id()), + graph_node->flags()); + + ET_CHECK_OR_RETURN_ERROR( + status == xnn_status_success, + Internal, + "Failed to create copy node %i with code: %s", + node->debug_handle(), + xnn_status_to_string(status)); + + return Error::Ok; +} + /* Returns not Implemented Error code. 
This function is meant to be called when the compiler encountes a XNodeType from the flatbuffer @@ -1538,8 +1567,9 @@ Error defineGenericUnaryNode( MAYBE_UNUSED(graph); \ auto graph_node = node->xnode_union_as_XNN##name(); \ std::pair min_max = getOutputMinMax(node); \ - union xnn_unary_params params = { \ - .clamp = {.min = min_max.first, .max = min_max.second}}; \ + union xnn_unary_params params; \ + params.clamp.min = min_max.first; \ + params.clamp.max = min_max.second; \ return defineGenericUnaryNode( \ subgraph_ptr, \ remapped_ids, \ @@ -1553,48 +1583,49 @@ Error defineGenericUnaryNode( } // Macro for unary operations with leaky_relu parameters -#define _DEFINE_UNARY_NODE_WITH_LEAKY_RELU(name) \ - Error define##name##Node( \ - xnn_subgraph_t subgraph_ptr, \ - const std::unordered_map& remapped_ids, \ - const NodePtr node, \ - const fb_xnnpack::XNNGraph* graph) noexcept { \ - MAYBE_UNUSED(graph); \ - auto graph_node = node->xnode_union_as_XNNLeakyReLU(); \ - union xnn_unary_params params = { \ - .leaky_relu = {.negative_slope = graph_node->negative_slope()}}; \ - return defineGenericUnaryNode( \ - subgraph_ptr, \ - remapped_ids, \ - graph_node->input_id(), \ - graph_node->output_id(), \ - graph_node->flags(), \ - xnn_unary_leaky_relu, \ - ¶ms, \ - node->xnode_union_type(), \ - node->debug_handle()); \ +#define _DEFINE_UNARY_NODE_WITH_LEAKY_RELU(name) \ + Error define##name##Node( \ + xnn_subgraph_t subgraph_ptr, \ + const std::unordered_map& remapped_ids, \ + const NodePtr node, \ + const fb_xnnpack::XNNGraph* graph) noexcept { \ + MAYBE_UNUSED(graph); \ + auto graph_node = node->xnode_union_as_XNNLeakyReLU(); \ + union xnn_unary_params params; \ + params.leaky_relu.negative_slope = graph_node->negative_slope(); \ + return defineGenericUnaryNode( \ + subgraph_ptr, \ + remapped_ids, \ + graph_node->input_id(), \ + graph_node->output_id(), \ + graph_node->flags(), \ + xnn_unary_leaky_relu, \ + ¶ms, \ + node->xnode_union_type(), \ + node->debug_handle()); \ } // Macro for unary operations with elu parameters -#define _DEFINE_UNARY_NODE_WITH_ELU(name) \ - Error define##name##Node( \ - xnn_subgraph_t subgraph_ptr, \ - const std::unordered_map& remapped_ids, \ - const NodePtr node, \ - const fb_xnnpack::XNNGraph* graph) noexcept { \ - MAYBE_UNUSED(graph); \ - auto graph_node = node->xnode_union_as_XNNELU(); \ - union xnn_unary_params params = {.elu = {.alpha = graph_node->alpha()}}; \ - return defineGenericUnaryNode( \ - subgraph_ptr, \ - remapped_ids, \ - graph_node->input_id(), \ - graph_node->output_id(), \ - graph_node->flags(), \ - xnn_unary_elu, \ - ¶ms, \ - node->xnode_union_type(), \ - node->debug_handle()); \ +#define _DEFINE_UNARY_NODE_WITH_ELU(name) \ + Error define##name##Node( \ + xnn_subgraph_t subgraph_ptr, \ + const std::unordered_map& remapped_ids, \ + const NodePtr node, \ + const fb_xnnpack::XNNGraph* graph) noexcept { \ + MAYBE_UNUSED(graph); \ + auto graph_node = node->xnode_union_as_XNNELU(); \ + union xnn_unary_params params; \ + params.elu.alpha = graph_node->alpha(); \ + return defineGenericUnaryNode( \ + subgraph_ptr, \ + remapped_ids, \ + graph_node->input_id(), \ + graph_node->output_id(), \ + graph_node->flags(), \ + xnn_unary_elu, \ + ¶ms, \ + node->xnode_union_type(), \ + node->debug_handle()); \ } // Generic helper function for binary operations @@ -1627,25 +1658,26 @@ Error defineGenericBinaryNode( } // Macro for binary operations with min/max parameters -#define _DEFINE_BINARY_NODE_WITH_MINMAX(name, op_type) \ - Error define##name##Node( \ - xnn_subgraph_t 
subgraph_ptr, \ - const std::unordered_map& remapped_ids, \ - const NodePtr node, \ - const fb_xnnpack::XNNGraph* graph) noexcept { \ - MAYBE_UNUSED(graph); \ - auto graph_node = node->xnode_union_as_XNN##name(); \ - std::pair min_max = getOutputMinMax(node); \ - struct xnn_binary_params params = { \ - .output_min = min_max.first, .output_max = min_max.second}; \ - return defineGenericBinaryNode( \ - subgraph_ptr, \ - remapped_ids, \ - graph_node, \ - op_type, \ - ¶ms, \ - node->xnode_union_type(), \ - node->debug_handle()); \ +#define _DEFINE_BINARY_NODE_WITH_MINMAX(name, op_type) \ + Error define##name##Node( \ + xnn_subgraph_t subgraph_ptr, \ + const std::unordered_map& remapped_ids, \ + const NodePtr node, \ + const fb_xnnpack::XNNGraph* graph) noexcept { \ + MAYBE_UNUSED(graph); \ + auto graph_node = node->xnode_union_as_XNN##name(); \ + std::pair min_max = getOutputMinMax(node); \ + struct xnn_binary_params params; \ + params.output_min = min_max.first; \ + params.output_max = min_max.second; \ + return defineGenericBinaryNode( \ + subgraph_ptr, \ + remapped_ids, \ + graph_node, \ + op_type, \ + ¶ms, \ + node->xnode_union_type(), \ + node->debug_handle()); \ } // Macro for binary operations without parameters @@ -1689,6 +1721,8 @@ _DEFINE_UNARY_NODE_NO_PARAMS(Log, xnn_unary_log) _DEFINE_UNARY_NODE_NO_PARAMS(Negate, xnn_unary_negate) _DEFINE_UNARY_NODE_NO_PARAMS(Square, xnn_unary_square) _DEFINE_UNARY_NODE_NO_PARAMS(Abs, xnn_unary_abs) +_DEFINE_UNARY_NODE_NO_PARAMS(Sin, xnn_unary_sine) +_DEFINE_UNARY_NODE_NO_PARAMS(Cos, xnn_unary_cosine) // Unary Ops with min/max params _DEFINE_UNARY_NODE_WITH_MINMAX(Clamp, xnn_unary_clamp) @@ -1736,6 +1770,8 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) { _DEFINE(Floor) _DEFINE(PReLU) _DEFINE(Sigmoid) + _DEFINE(Sin) + _DEFINE(Cos) // Others _DEFINE(FullyConnected) @@ -1757,6 +1793,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) { _DEFINE(Concatenate5) _DEFINE(StaticSlice) _DEFINE(BatchMatrixMultiply) + _DEFINE(Copy) case fb_xnnpack::XNodeUnion::NONE: default: // Adding here as a catch all, just in case return &defineNotImplementedNode; @@ -1895,9 +1932,8 @@ ET_NODISCARD Error XNNCompiler::compileModel( xnn_weights_cache_t weights_cache_ptr = nullptr; #endif -#ifdef ENABLE_XNNPACK_SHARED_WORKSPACE - ET_CHECK_OR_RETURN_ERROR( - workspace != nullptr, Internal, "Failed to initialize XNNPACK workspace"); + // NOLINTBEGIN(facebook-hte-NullableDereference) - weights cache is allowed to + // be null status = xnn_create_runtime_v4( subgraph.get(), weights_cache_ptr, @@ -1905,14 +1941,7 @@ ET_NODISCARD Error XNNCompiler::compileModel( ::executorch::extension::threadpool::get_pthreadpool(), runtime_flags, &runtime_ptr); -#else - status = xnn_create_runtime_v3( - subgraph.get(), - weights_cache_ptr, - ::executorch::extension::threadpool::get_pthreadpool(), - runtime_flags, - &runtime_ptr); -#endif + // NOLINTEND(facebook-hte-NullableDereference) ET_CHECK_OR_RETURN_ERROR( xnn_status_success == status, diff --git a/backends/xnnpack/runtime/XNNExecutor.h b/backends/xnnpack/runtime/XNNExecutor.h index f7084a5dd88..c7926744dd6 100644 --- a/backends/xnnpack/runtime/XNNExecutor.h +++ b/backends/xnnpack/runtime/XNNExecutor.h @@ -9,13 +9,13 @@ #pragma once #include +#include #include #include #include #include #include -#include #include #include @@ -35,9 +35,11 @@ class XNNExecutor { std::vector output_ids_; std::vector externals_; std::vector packed_data_names_; + std::shared_ptr workspace_; public: - XNNExecutor() = 
default; + XNNExecutor(std::shared_ptr workspace) + : workspace_(workspace) {} inline size_t getNumInputs() { return input_ids_.size(); @@ -51,6 +53,10 @@ class XNNExecutor { return packed_data_names_; } + inline std::shared_ptr get_workspace() { + return workspace_; + } + /** * Initialize the XNNExecutor with a given runtime and input/output ids. * The input/output ids are expected to be sorted in order of their diff --git a/backends/xnnpack/runtime/XNNPACKBackend.cpp b/backends/xnnpack/runtime/XNNPACKBackend.cpp index b05919ecf2b..76e83d4b57b 100644 --- a/backends/xnnpack/runtime/XNNPACKBackend.cpp +++ b/backends/xnnpack/runtime/XNNPACKBackend.cpp @@ -7,7 +7,10 @@ */ #include +#include #include +#include +#include #include #include #include @@ -21,14 +24,18 @@ namespace executorch { namespace backends { +using executorch::backends::xnnpack::WorkspaceSharingMode; +using executorch::backends::xnnpack::XNNWorkspace; using executorch::backends::xnnpack::delegate::XNNWeightsCache; using executorch::ET_RUNTIME_NAMESPACE::Backend; using executorch::ET_RUNTIME_NAMESPACE::BackendExecutionContext; using executorch::ET_RUNTIME_NAMESPACE::BackendInitContext; +using executorch::ET_RUNTIME_NAMESPACE::BackendOptionContext; using executorch::ET_RUNTIME_NAMESPACE::CompileSpec; using executorch::ET_RUNTIME_NAMESPACE::DelegateHandle; using executorch::ET_RUNTIME_NAMESPACE::NamedDataMap; using executorch::runtime::ArrayRef; +using executorch::runtime::BackendOption; using executorch::runtime::Error; using executorch::runtime::EValue; using executorch::runtime::FreeableBuffer; @@ -51,23 +58,8 @@ class XnnpackBackend final return; } -#ifdef ENABLE_XNNPACK_SHARED_WORKSPACE - // Create a workspace for the XNNExecutor to use. This workspace will be - // shared across all delegate instances. - ET_LOG(Debug, "Creating XNN workspace"); - xnn_workspace_t workspace = nullptr; - status = xnn_create_workspace(&workspace); - if (status != xnn_status_success) { - ET_LOG( - Error, - "Failed to create XNN workspace, XNNPACK status: 0x%x", - (unsigned int)status); - workspace = nullptr; - return; - } - workspace_.reset(workspace); - ET_LOG(Debug, "Created XNN workspace: %p", workspace_.get()); -#endif // ENABLE_XNNPACK_SHARED_WORKSPACE + // Workspace manager is initialized with the appropriate default mode in its + // constructor } bool is_available() const override { @@ -85,11 +77,16 @@ class XnnpackBackend final } const NamedDataMap* named_data_map = context.get_named_data_map(); - // thread safe. This can heppen when multiple threads call init() on + // thread safe. This can happen when multiple threads call init() on // the same backend instance. -#ifdef ENABLE_XNNPACK_SHARED_WORKSPACE - const std::lock_guard lock(workspace_mutex_); -#endif + + auto program_id = + reinterpret_cast(context.get_runtime_allocator()); + auto workspace_result = get_or_create_workspace(program_id); + if (!workspace_result.ok()) { + return workspace_result.error(); + } + auto workspace = workspace_result.get(); #ifdef ENABLE_XNNPACK_WEIGHTS_CACHE const std::lock_guard lock_weight_cache(weights_cache_mutex_); @@ -97,17 +94,19 @@ class XnnpackBackend final context.get_runtime_allocator(), named_data_map); #endif + auto [workspace_lock, workspace_ptr] = workspace->acquire(); + // Executor has been allocated but not constructed, ensure that runtime_ is // nullptr by constructing it in place here. NOTE: Since we use placement // new and since this type is not trivially destructible, we must call the // destructor manually in destroy(). 
- new (executor) xnnpack::delegate::XNNExecutor; + new (executor) xnnpack::delegate::XNNExecutor(workspace); Error err = xnnpack::delegate::XNNCompiler::compileModel( processed->data(), processed->size(), executor, weights_cache_.get(), - workspace_.get(), + workspace_ptr, named_data_map); // This backend does not need its processed data after compiling the model. processed->Free(); @@ -130,14 +129,12 @@ class XnnpackBackend final Span args) const override { auto executor = static_cast(handle); -#ifdef ENABLE_XNNPACK_SHARED_WORKSPACE - const std::lock_guard lock(workspace_mutex_); -#endif - #ifdef ENABLE_XNNPACK_WEIGHTS_CACHE const std::lock_guard lock_weights_cache(weights_cache_mutex_); #endif + auto [raii_lock, _] = executor->get_workspace()->acquire(); + // Prepare Inputs/Outputs and Propagate Input Shapes Error err = executor->prepare_args(args); if (err != Error::Ok) { @@ -158,13 +155,6 @@ class XnnpackBackend final void destroy(DelegateHandle* handle) const override { if (handle != nullptr) { - // This is needed to serialize access to xnn_delete_runtime which is not - // thread safe. This can heppen when multiple threads call destroy() on - // the same backend instance. -#ifdef ENABLE_XNNPACK_SHARED_WORKSPACE - const std::lock_guard lock(workspace_mutex_); -#endif - auto executor = static_cast(handle); #ifdef ENABLE_XNNPACK_PROFILING @@ -176,18 +166,87 @@ class XnnpackBackend final weights_cache_mutex_); weights_cache_->delete_packed_data(executor->get_packed_data_names()); #endif + + // This is needed to serialize access to xnn_delete_runtime which is not + // thread safe. This can heppen when multiple threads call destroy() on + // the same backend instance. Make sure to hold onto the workspace + // shared_ptr, as the pointer in the executor is freed, which includes + // the mutex referenced by raii_lock. + auto workspace = executor->get_workspace(); + auto [raii_lock, _] = workspace->acquire(); + // XNNExecutor is not trivially destructible. Since this was constructed // manually in init(), we must destroy it manually here. executor->~XNNExecutor(); } } + Error get_option_internal( + BackendOptionContext& context, + executorch::runtime::Span& + backend_options) const { + // Intentionally not locking here as it is not required. 
+ + // Verify that the expected option key is present and modify the value + for (size_t i = 0; i < backend_options.size(); ++i) { + if (strcmp( + backend_options[i].key, + xnnpack::workspace_sharing_mode_option_key) == 0) { + // Set the value to what was stored by set_option + backend_options[i].value = + static_cast(workspace_manager_.get_sharing_mode()); + } + } + + return Error::Ok; + } + + Error get_option( + BackendOptionContext& context, + executorch::runtime::Span& + backend_options) override { + return get_option_internal(context, backend_options); + } + + Error set_option( + BackendOptionContext& context, + const executorch::runtime::Span& + backend_options) override { + if (backend_options.size() > 0) { + for (const auto& option : backend_options) { + if (strcmp(option.key, xnnpack::workspace_sharing_mode_option_key) == + 0) { + if (auto* val = std::get_if(&option.value)) { + if (*val < 0 || + *val > static_cast(WorkspaceSharingMode::Count)) { + ET_LOG( + Error, + "XNNPACK workspace sharing mode must be between 0 and %d, inclusive, but was %d.", + static_cast(WorkspaceSharingMode::Count), + *val); + return Error::InvalidArgument; + } + + ET_LOG( + Debug, "Setting XNNPACK workspace sharing mode to %d.", *val); + auto status = workspace_manager_.set_sharing_mode( + static_cast(*val)); + if (status != Error::Ok) { + return status; + } + } else { + ET_LOG(Error, "XNNPACK workspace sharing mode must be an integer."); + return Error::InvalidArgument; + } + } + } + } + return Error::Ok; + } + private: - // This is a global workspace for all delegate instances. - mutable std::mutex workspace_mutex_; - std::unique_ptr workspace_{ - nullptr, - &xnn_release_workspace}; + // Workspace manager for handling workspace sharing modes + mutable xnnpack::XNNWorkspaceManager workspace_manager_; // Weights cache is global to all delegate instances. mutable std::mutex weights_cache_mutex_; @@ -195,13 +254,21 @@ class XnnpackBackend final std::make_unique(); // Lock Hiearchy for Mutexes: - // workspace_mutex_ // weights_cache_mutex_ + // workspace_meta_mutex_ + // workspace_mutex_ (owned by executor) + + // Retrieve a workspace for the given method ID, depending on the sharing + // mode. + Result> get_or_create_workspace( + uintptr_t program_id) const { + return workspace_manager_.get_or_create_workspace(program_id); + } }; namespace { -auto cls = XnnpackBackend(); -Backend backend{"XnnpackBackend", &cls}; +auto backend_instance = XnnpackBackend(); +Backend backend{xnnpack::xnnpack_backend_key, &backend_instance}; static auto success_with_compiler = register_backend(backend); } // namespace diff --git a/backends/xnnpack/runtime/XNNPACKBackend.h b/backends/xnnpack/runtime/XNNPACKBackend.h new file mode 100644 index 00000000000..aca72f8652b --- /dev/null +++ b/backends/xnnpack/runtime/XNNPACKBackend.h @@ -0,0 +1,42 @@ +#pragma once + +namespace executorch::backends::xnnpack { +/// The key for the backend. This is used to register the backend, check +/// availability, and get/set options. +const char xnnpack_backend_key[] = "XnnpackBackend"; + +/// The key for the workspace sharing option. See the WorkspaceSharingMode enum +/// for a description of the associated functionality. +const char workspace_sharing_mode_option_key[] = "workspace_sharing_mode"; + +/// Workspace sharing mode. This is a backend option that can be set via the +/// set_option API to control memory sharing between CALL_DELEGATE instances. +/// This is useful for reducing memory consumption. 
+enum class WorkspaceSharingMode { + /// No workspace sharing. Each CALL_DELEGATE instance will have its own + /// workspace (memory arena). + Disabled = 0, + + /// All CALL_DELEGATE instances in a given program will share a workspace. + /// This reduces memory consumption + /// for methods with multiple delegate calls, at the cost of only allowing one + /// method to execute at a time. + PerModel = 1, + + /// All CALL_DELEGATE instances accross all loaded methods will share a + /// workspace. This reduces memory + /// consumption by overlapping activation memory between methods but enforces + /// synchronization between + /// methods. If multiple methods are run concurrently, it may block as only + /// one delegate call occur + /// at a time. Additionally, the workspace does not shrink when a method is + /// unloaded, so memory will + /// only be reclaimed when all XNNPACK-delegated methods are unloaded. + Global = 2, + + /// The number of workspace sharing modes. This is not a valid mode and is + /// only used for tracking the + // maximum enum value. + Count, +}; +} // namespace executorch::backends::xnnpack diff --git a/backends/xnnpack/runtime/XNNWeightsCache.cpp b/backends/xnnpack/runtime/XNNWeightsCache.cpp index 1a230c19976..54191b72825 100644 --- a/backends/xnnpack/runtime/XNNWeightsCache.cpp +++ b/backends/xnnpack/runtime/XNNWeightsCache.cpp @@ -11,6 +11,9 @@ #include #include #include +#include +#include +#include #include #include @@ -155,21 +158,45 @@ size_t XNNWeightsCache::look_up( return packed_weight_entry->second.offset; } +/** + * Reserve space in the weight cache for n bytes of weight data, aligned to + * context->kPackedAllocationAlignment. This function will return nullptr if + * the allocation fails. + */ void* XNNWeightsCache::reserve_space(XNNWeightsCache* context, size_t n) { // MemoryAllocator* allocator = context->runtime_allocator_; // void* reserved_pointer = allocator->allocate(n, // context->kPackedAllocationAlignment); // return reserved_pointer; - std::string data_container; - data_container.resize(n + context->kPackedAllocationAlignment); - void* maybe_aligned_space = data_container.data(); - void* aligned_space = (void*)((intptr_t)maybe_aligned_space + 64 - - (intptr_t)maybe_aligned_space % 64); - - context->packed_pointer_to_container_[aligned_space] = - std::move(data_container); - return aligned_space; + try { + std::string data_container; + size_t raw_allocation_size = n + context->kPackedAllocationAlignment - 1; + data_container.resize(raw_allocation_size); + + void* maybe_aligned_space = data_container.data(); + void* aligned_space = std::align( + context->kPackedAllocationAlignment, + n, + maybe_aligned_space, + raw_allocation_size // Note that std::align mutates this value. + ); + ET_CHECK_MSG(aligned_space != nullptr, "Memory alignment failed."); + + context->packed_pointer_to_container_[aligned_space] = + std::move(data_container); + return aligned_space; + } catch (std::bad_alloc& e) { + // XNNPACK can gracefully handle allocation failures, so return nullptr. + // We want to be able to recover from a failed attempt to load a large + // model without a crash. 
+ ET_LOG( + Error, + "XNN weight cache failed to allocate %zu bytes: %s.", + n, + e.what()); + return nullptr; + } } size_t XNNWeightsCache::look_up_or_insert( @@ -201,11 +228,11 @@ size_t XNNWeightsCache::look_up_or_insert( weight_bias_name.append(bias_entry->second); } } - PackedDataMeta packed_data_metadata = { - .offset = next_offset, - .ref_count = - 0, // ref_count is only incremented after finalizing for runtime - .in_current_runtime = true}; + PackedDataMeta packed_data_metadata; + packed_data_metadata.offset = next_offset; + packed_data_metadata.ref_count = + 0; // ref_count is only incremented after finalizing for runtime + packed_data_metadata.in_current_runtime = true; context->name_to_packed_data_metadata_[weight_bias_name] = packed_data_metadata; } else { diff --git a/backends/xnnpack/runtime/XNNWorkspace.h b/backends/xnnpack/runtime/XNNWorkspace.h new file mode 100644 index 00000000000..36596b05089 --- /dev/null +++ b/backends/xnnpack/runtime/XNNWorkspace.h @@ -0,0 +1,67 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include +#include + +namespace executorch::backends::xnnpack { + +using WorkspacePtr = + std::unique_ptr; + +/// A lightweight wrapper around an underlying xnn_workspace_t instance, bundled +/// with appropriate synchronization. +class XNNWorkspace { + public: + XNNWorkspace(WorkspacePtr workspace) : workspace_(std::move(workspace)){}; + XNNWorkspace(const XNNWorkspace&) = delete; + XNNWorkspace& operator=(const XNNWorkspace&) = delete; + // Not moveable due to std::mutex. + XNNWorkspace(XNNWorkspace&&) = delete; + XNNWorkspace& operator=(XNNWorkspace&&) = delete; + + std::pair, xnn_workspace_t> acquire() { + auto lock = std::unique_lock(mutex_); + return {std::move(lock), workspace_.get()}; + } + + // Return the workspace pointer withot acquiring the lock. This should be used + // carefully, as it can lead to crashes or data corruption if the workspace is + // used concurrently.s + xnn_workspace_t unsafe_get_workspace() { + return workspace_.get(); + } + + static runtime::Result> create() { + // Because this class can't be moved, we need to construct it in-place. + xnn_workspace_t workspace = nullptr; + auto status = xnn_create_workspace(&workspace); + if (status != xnn_status_success) { + ET_LOG( + Error, + "Failed to create XNN workspace, XNNPACK status: 0x%x", + (unsigned int)status); + return runtime::Error::Internal; + } + + return std::make_shared( + WorkspacePtr(workspace, &xnn_release_workspace)); + } + + private: + std::mutex mutex_; + WorkspacePtr workspace_; +}; + +} // namespace executorch::backends::xnnpack diff --git a/backends/xnnpack/runtime/XNNWorkspaceManager.cpp b/backends/xnnpack/runtime/XNNWorkspaceManager.cpp new file mode 100644 index 00000000000..d8c6dae4d6d --- /dev/null +++ b/backends/xnnpack/runtime/XNNWorkspaceManager.cpp @@ -0,0 +1,130 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include // For PRIuPTR + +namespace executorch::backends::xnnpack { + +using executorch::runtime::Error; +using executorch::runtime::Result; + +XNNWorkspaceManager::XNNWorkspaceManager() { +#ifdef ENABLE_XNNPACK_SHARED_WORKSPACE + sharing_mode_ = WorkspaceSharingMode::Global; +#else + sharing_mode_ = WorkspaceSharingMode::Disabled; +#endif // ENABLE_XNNPACK_SHARED_WORKSPACE +} + +runtime::Error XNNWorkspaceManager::set_sharing_mode( + WorkspaceSharingMode mode) { + // Validate that the mode is valid + if (static_cast(mode) < 0 || + static_cast(mode) >= static_cast(WorkspaceSharingMode::Count)) { + ET_LOG( + Error, + "XNNPACK workspace sharing mode must be between 0 and %d, inclusive, but was %d.", + static_cast(WorkspaceSharingMode::Count) - 1, + static_cast(mode)); + return runtime::Error::InvalidArgument; + } + + sharing_mode_ = mode; + return runtime::Error::Ok; +} + +WorkspaceSharingMode XNNWorkspaceManager::get_sharing_mode() const { + return sharing_mode_.load(); +} + +Result> +XNNWorkspaceManager::get_or_create_workspace(uintptr_t program_id) const { + auto mode = sharing_mode_.load(); + + // Get or create the workspace according to the current sharing mode. + if (mode == WorkspaceSharingMode::Disabled) { + ET_LOG(Debug, "Instantiating workspace."); + auto create_result = XNNWorkspace::create(); + if (!create_result.ok()) { + return create_result.error(); + } + + return create_result.get(); + } else if (mode == WorkspaceSharingMode::PerModel) { + return get_or_create_model_workspace(program_id); + } else if (mode == WorkspaceSharingMode::Global) { + return get_or_create_global_workspace(); + } else { + ET_LOG( + Error, "Invalid workspace sharing mode: %d.", static_cast(mode)); + return Error::Internal; + } +} + +Result> +XNNWorkspaceManager::get_or_create_global_workspace() const { + std::scoped_lock lock(workspace_meta_mutex_); + + // Check for an existing (live) global workspace. + std::shared_ptr workspace = {}; + if (auto live_workspace = global_workspace_.lock()) { + workspace = live_workspace; + } + + // Allocate a new workspace if needed. + if (!workspace) { + auto create_result = XNNWorkspace::create(); + if (!create_result.ok()) { + return create_result.error(); + } + workspace = create_result.get(); + ET_LOG( + Debug, + "Created global workspace %p.", + workspace->unsafe_get_workspace()); + global_workspace_ = workspace; + } + + return workspace; +} + +Result> +XNNWorkspaceManager::get_or_create_model_workspace(uintptr_t program_id) const { + std::scoped_lock lock(workspace_meta_mutex_); + + // Check for an existing (live) workspace for this program. + auto match = model_workspaces_.find(program_id); + std::shared_ptr workspace = {}; + if (match != model_workspaces_.end()) { + if (auto live_workspace = match->second.lock()) { + workspace = live_workspace; + } + } + + // Allocate a new workspace if needed. 
+ if (!workspace) { + auto create_result = XNNWorkspace::create(); + if (!create_result.ok()) { + return create_result.error(); + } + workspace = create_result.get(); + ET_LOG( + Debug, + "Created workspace %p for program %" PRIuPTR ".", + workspace->unsafe_get_workspace(), + program_id); + model_workspaces_.insert( + {program_id, std::weak_ptr(workspace)}); + } + + return workspace; +} + +} // namespace executorch::backends::xnnpack diff --git a/backends/xnnpack/runtime/XNNWorkspaceManager.h b/backends/xnnpack/runtime/XNNWorkspaceManager.h new file mode 100644 index 00000000000..52db1184bbd --- /dev/null +++ b/backends/xnnpack/runtime/XNNWorkspaceManager.h @@ -0,0 +1,94 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +namespace executorch::backends::xnnpack { + +/** + * XNNWorkspaceManager manages XNNPACK workspaces based on the configured + * workspace sharing mode. + * + * It supports three modes: + * - Disabled: Each delegate instance gets its own workspace + * - PerModel: All delegate instances in a model share a workspace + * - Global: All delegate instances across all models share a workspace + */ +class XNNWorkspaceManager { + public: + XNNWorkspaceManager(); + ~XNNWorkspaceManager() = default; + + /** + * Set the workspace sharing mode. + * + * @param mode The workspace sharing mode to set. + * @return Error::Ok if the mode was set successfully. + */ + runtime::Error set_sharing_mode(WorkspaceSharingMode mode); + + /** + * Get the current workspace sharing mode. + * + * @return The current workspace sharing mode. + */ + WorkspaceSharingMode get_sharing_mode() const; + + /** + * Retrieve a workspace for the given program ID, depending on the sharing + * mode. A workspace will be created if needed. + * + * @param program_id The ID of the program requesting a workspace. + * @return A Result containing a shared_ptr to the workspace, or an error. + */ + runtime::Result> get_or_create_workspace( + uintptr_t program_id) const; + + private: + // The active sharing mode. Changes to this affect only models loaded after + // the change. + std::atomic sharing_mode_; + + // A mutex guarding global_workspace_ and model_workspaces_. Note that this + // mutex only guards the top-level definitions, not the contents of the + // workspace. The contents of the workspace are guarded by the workspace's own + // mutex in the XNNWorkspace class. + mutable std::mutex workspace_meta_mutex_; + + // A global workspace for all delegate instances, if global sharing is + // enabled. Lazy initialized. Stored as a weak pointer to allow automatic + // cleanup when all references are released. + mutable std::weak_ptr global_workspace_; + + // A map from program id to workspace for delegate instances, if per model + // sharing is enabled. Workspaces are owned by the executor instances via + // shared_ptr. They are tracked here via weak pointers to allow automatic + // cleanup when the executors are destroyed while being retrievable when + // instantiating new executors. + mutable std::unordered_map> + model_workspaces_; + + // Retrieve the global workspace, lazy initializing it if needed. + runtime::Result> + get_or_create_global_workspace() const; + + // Get or create a workspace for the given program ID. 
+ runtime::Result> get_or_create_model_workspace( + uintptr_t program_id) const; +}; + +} // namespace executorch::backends::xnnpack diff --git a/backends/xnnpack/runtime/utils/utils.cpp b/backends/xnnpack/runtime/utils/utils.cpp index bbcb8bc071c..0e017df978b 100644 --- a/backends/xnnpack/runtime/utils/utils.cpp +++ b/backends/xnnpack/runtime/utils/utils.cpp @@ -206,8 +206,8 @@ void vst1(int8_t* out, int8x8_t vout) { template <> void quantize_tensor_arm64_q8_wrapper( - const float* __restrict__ in, - uint8_t* __restrict__ out, + const float* ET_RESTRICT in, + uint8_t* ET_RESTRICT out, const int64_t N, const float scale, const int32_t zero_point) { @@ -216,8 +216,8 @@ void quantize_tensor_arm64_q8_wrapper( template <> void quantize_tensor_arm64_q8_wrapper( - const float* __restrict__ in, - int8_t* __restrict__ out, + const float* ET_RESTRICT in, + int8_t* ET_RESTRICT out, const int64_t N, const float scale, const int32_t zero_point) { diff --git a/backends/xnnpack/runtime/utils/utils.h b/backends/xnnpack/runtime/utils/utils.h index 2eb079f0b0c..de8ee7970dd 100644 --- a/backends/xnnpack/runtime/utils/utils.h +++ b/backends/xnnpack/runtime/utils/utils.h @@ -82,8 +82,8 @@ void vst1(T* out, Tx8 vout); template void quantize_tensor_arm64_q8( - const float* __restrict__ in, - underlying_t* __restrict__ out, + const float* ET_RESTRICT in, + underlying_t* ET_RESTRICT out, const int64_t N, const float scale, const int32_t zero_point) { @@ -117,8 +117,8 @@ void quantize_tensor_arm64_q8( template void quantize_tensor_arm64_q8_wrapper( - const float* __restrict__ in, - T* __restrict__ out, + const float* ET_RESTRICT in, + T* ET_RESTRICT out, const int64_t N, const float scale, const int32_t zero_point); diff --git a/backends/xnnpack/serialization/runtime_schema.fbs b/backends/xnnpack/serialization/runtime_schema.fbs index 950318f18dc..fb2c9b1598c 100644 --- a/backends/xnnpack/serialization/runtime_schema.fbs +++ b/backends/xnnpack/serialization/runtime_schema.fbs @@ -156,6 +156,9 @@ union XNodeUnion { XNNGelu: _XNNNode1x1, XNNTanh: _XNNNode1x1, XNNExp: _XNNNode1x1, + XNNSin: _XNNNode1x1, + XNNCopy: _XNNNode1x1, + XNNCos: _XNNNode1x1, } union XValueUnion { diff --git a/backends/xnnpack/serialization/schema.fbs b/backends/xnnpack/serialization/schema.fbs index a4efc627cbb..203469421d1 100644 --- a/backends/xnnpack/serialization/schema.fbs +++ b/backends/xnnpack/serialization/schema.fbs @@ -152,6 +152,9 @@ union XNodeUnion { XNNGelu: _XNNNode1x1, XNNTanh: _XNNNode1x1, XNNExp: _XNNNode1x1, + XNNSin: _XNNNode1x1, + XNNCopy: _XNNNode1x1, + XNNCos: _XNNNode1x1, } union XValueUnion { diff --git a/backends/xnnpack/serialization/xnnpack_graph_schema.py b/backends/xnnpack/serialization/xnnpack_graph_schema.py index 99b64708f86..e95a55e1c01 100644 --- a/backends/xnnpack/serialization/xnnpack_graph_schema.py +++ b/backends/xnnpack/serialization/xnnpack_graph_schema.py @@ -347,6 +347,21 @@ class XNNPReLU(XNNNode2x1): pass +@dataclass +class XNNSin(XNNNode1x1): + pass + + +@dataclass +class XNNCos(XNNNode1x1): + pass + + +@dataclass +class XNNCopy(XNNNode1x1): + pass + + @dataclass class XNNScaledDotProductAttention: query_id: int @@ -402,6 +417,10 @@ class XNNScaledDotProductAttention: XNNLog, XNNGelu, XNNTanh, + XNNExp, + XNNSin, + XNNCopy, + XNNCos, ] diff --git a/backends/xnnpack/targets.bzl b/backends/xnnpack/targets.bzl index 0eab89a00f9..796fd887e33 100644 --- a/backends/xnnpack/targets.bzl +++ b/backends/xnnpack/targets.bzl @@ -59,6 +59,9 @@ def define_common_targets(): exported_deps = [ 
"//executorch/runtime/backend:interface" + aten_suffix, ], + exported_headers = [ + "runtime/XNNPACKBackend.h", + ], deps = [ third_party_dep("XNNPACK"), "//executorch/backends/xnnpack/serialization:xnnpack_flatbuffer_header", @@ -70,3 +73,13 @@ def define_common_targets(): # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) link_whole = True, ) + + runtime.cxx_library( + name = "xnnpack_interface", + visibility = [ + "@EXECUTORCH_CLIENTS", + ], + exported_headers = [ + "runtime/XNNPACKBackend.h", + ], + ) diff --git a/backends/xnnpack/test/TARGETS b/backends/xnnpack/test/TARGETS index 5f3581b6aeb..1729a893ff4 100644 --- a/backends/xnnpack/test/TARGETS +++ b/backends/xnnpack/test/TARGETS @@ -107,9 +107,21 @@ runtime.python_test( deps = [ "//executorch/backends/xnnpack/recipes:xnnpack_recipes", "//executorch/export:lib", + "//executorch/runtime:runtime", # @manual "//pytorch/vision:torchvision", # @manual "//executorch/backends/xnnpack/test/tester:tester", "//executorch/examples/models:models", # @manual "//executorch/examples/xnnpack:models", # @manual ], ) + +runtime.python_test( + name = "test_xnnpack_partitioner", + srcs = ["test_xnnpack_partitioner.py"], + deps = [ + "//caffe2:torch", + "//executorch/backends/xnnpack/partition:xnnpack_partitioner", + "//executorch/exir:lib", + "//executorch/extension/pybindings:portable_lib", + ], +) diff --git a/backends/xnnpack/test/ops/test_clone.py b/backends/xnnpack/test/ops/test_clone.py new file mode 100644 index 00000000000..0396b9b2bea --- /dev/null +++ b/backends/xnnpack/test/ops/test_clone.py @@ -0,0 +1,137 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +import unittest + +import torch +from executorch.backends.xnnpack.test.tester import Tester + + +class TestClone(unittest.TestCase): + def setUp(self): + torch._dynamo.reset() + + class Clone(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + z = torch.clone(x) + return z + + class CloneWithMemoryFormat(torch.nn.Module): + def __init__(self, memory_format): + super().__init__() + self.memory_format = memory_format + + def forward(self, x): + z = torch.clone(x, memory_format=self.memory_format) + return z + + def _test_clone_partitioned(self, inputs): + """Test that dim-order preserving clones are partitioned (removed)""" + ( + Tester(self.Clone(), inputs) + .export() + .check_count({"torch.ops.aten.clone.default": 1}) + .dump_artifact() + .to_edge_transform_and_lower() + .dump_artifact() + .check_not( + [ + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" + ] + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + def test_fp16_clone(self): + """Test FP16 clone - should be partitioned""" + inputs = (torch.randn(2, 3, 4, 5).to(torch.float16),) + self._test_clone_partitioned(inputs) + + def test_fp32_clone(self): + """Test FP32 clone - should be partitioned""" + inputs = (torch.randn(2, 3, 4, 5),) + self._test_clone_partitioned(inputs) + + def test_fp32_clone_2d(self): + """Test FP32 clone with 2D tensor - should be partitioned""" + inputs = (torch.randn(10, 20),) + self._test_clone_partitioned(inputs) + + def test_fp32_clone_3d(self): + """Test FP32 clone with 3D tensor - should be partitioned""" + inputs = (torch.randn(2, 3, 4),) + self._test_clone_partitioned(inputs) + + def test_fp32_clone_with_contiguous_format(self): + """Test FP32 clone with contiguous memory format - should be partitioned""" + inputs = (torch.randn(1, 3, 4, 4),) + ( + Tester(self.CloneWithMemoryFormat(torch.contiguous_format), inputs) + .export() + .to_edge_transform_and_lower() + .dump_artifact() + .check_not( + [ + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" + ] + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + def test_fp32_clone_with_channels_last_not_partitioned(self): + """Test FP32 clone with channels_last memory format - should NOT be partitioned""" + inputs = (torch.randn(1, 3, 4, 4),) + ( + Tester(self.CloneWithMemoryFormat(torch.channels_last), inputs) + .export() + .to_edge_transform_and_lower() + # Clone with channels_last changes dim order, so should NOT be delegated + .check( + [ + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" + ] + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + def test_fp32_clone_channels_last_to_contiguous_not_partitioned(self): + """Test clone from channels_last to contiguous - should NOT be partitioned""" + + class CloneChannelsLastToContiguous(torch.nn.Module): + def forward(self, x): + # Start with channels_last input + y = x.to(memory_format=torch.channels_last) + # Clone back to contiguous (changes dim order) + z = torch.clone(y, memory_format=torch.contiguous_format) + return z + + inputs = (torch.randn(1, 3, 4, 4),) + ( + Tester(CloneChannelsLastToContiguous(), inputs) + .export() + .to_edge_transform_and_lower() + .dump_artifact() + # Clone that changes dim order should NOT be delegated + .check( + [ + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" + ] + ) + .to_executorch() + .serialize() + 
.run_method_and_compare_outputs() + ) diff --git a/backends/xnnpack/test/ops/test_conv1d.py b/backends/xnnpack/test/ops/test_conv1d.py index 036500b29d5..35d9bced512 100644 --- a/backends/xnnpack/test/ops/test_conv1d.py +++ b/backends/xnnpack/test/ops/test_conv1d.py @@ -126,7 +126,9 @@ def _test_conv1d( # quantized operators to be loaded and we don't want to do that in the test. if not skip_to_executorch: tester.to_executorch().serialize().run_method_and_compare_outputs( - num_runs=10, atol=0.02, rtol=0.02 + num_runs=10, + atol=0.04 if quantized else 1e-03, + rtol=0.02 if quantized else 1e-03, ) def test_fp16_conv1d(self): diff --git a/backends/xnnpack/test/ops/test_cos.py b/backends/xnnpack/test/ops/test_cos.py new file mode 100644 index 00000000000..e40d8a812f3 --- /dev/null +++ b/backends/xnnpack/test/ops/test_cos.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.backends.xnnpack.test.tester import Tester + + +class TestCos(unittest.TestCase): + def setUp(self): + torch._dynamo.reset() + + class Cos(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + z = torch.cos(x) + return z + + def _test_cos(self, inputs, legacy_mode: bool = False, atol: float = 1e-4): + tester = ( + Tester(self.Cos(), inputs) + .export() + .check_count({"torch.ops.aten.cos.default": 1}) + ) + + if legacy_mode: + tester = tester.to_edge().partition() + else: + tester = tester.to_edge_transform_and_lower() + + ( + tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .check_not(["executorch_exir_dialects_edge__ops_aten_cos_default"]) + .to_executorch() + .serialize() + .run_method_and_compare_outputs(atol=atol) + ) + + def test_fp16_cos(self): + inputs = ( + torch.Tensor( + [ + [0.0, 0.1, 0.5, 0.785398], + [-0.5, -0.785398, 1.5708, -1.5708], + ], + ).to(torch.float16), + ) + self._test_cos(inputs, legacy_mode=False, atol=2e-3) + + def test_fp16_cos_legacy_mode(self): + inputs = ( + torch.Tensor( + [ + [0.0, 0.1, 0.5, 0.785398], + [-0.5, -0.785398, 1.5708, -1.5708], + ], + ).to(torch.float16), + ) + self._test_cos(inputs, legacy_mode=True, atol=2e-3) + + def test_fp32_cos(self): + inputs = ( + torch.Tensor( + [ + [0.0, 0.1, 0.5, 0.785398], + [-0.5, -0.785398, 1.5708, -1.5708], + ], + ), + ) + self._test_cos(inputs, legacy_mode=False) + + def test_fp32_cos_legacy_mode(self): + inputs = ( + torch.Tensor( + [ + [0.0, 0.1, 0.5, 0.785398], + [-0.5, -0.785398, 1.5708, -1.5708], + ], + ), + ) + self._test_cos(inputs, legacy_mode=True) diff --git a/backends/xnnpack/test/ops/test_linear.py b/backends/xnnpack/test/ops/test_linear.py index ac6fec25732..dc92a9542a9 100644 --- a/backends/xnnpack/test/ops/test_linear.py +++ b/backends/xnnpack/test/ops/test_linear.py @@ -395,7 +395,9 @@ def _test_groupwise_dq_linear( quantize_( mod, Int8DynamicActivationIntxWeightConfig( - weight_dtype=torch.int4, weight_granularity=PerGroup(group_size) + # pyre-ignore[16] + weight_dtype=torch.int4, + weight_granularity=PerGroup(group_size), ), ) unwrap_tensor_subclass(mod) diff --git a/backends/xnnpack/test/ops/test_sin.py b/backends/xnnpack/test/ops/test_sin.py new file mode 100644 index 00000000000..6a1b323e14c --- /dev/null +++ b/backends/xnnpack/test/ops/test_sin.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.backends.xnnpack.test.tester import Tester + + +class TestSin(unittest.TestCase): + def setUp(self): + torch._dynamo.reset() + + class Sin(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + z = torch.sin(x) + return z + + def _test_sin(self, inputs, legacy_mode: bool = False): + tester = ( + Tester(self.Sin(), inputs) + .export() + .check_count({"torch.ops.aten.sin.default": 1}) + ) + + if legacy_mode: + tester = tester.to_edge().partition() + else: + tester = tester.to_edge_transform_and_lower() + + ( + tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .check_not(["executorch_exir_dialects_edge__ops_aten_sin_default"]) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + def test_fp16_sin(self): + inputs = ( + torch.Tensor( + [ + [0.0, 0.1, 0.5, 0.785398], + [-0.5, -0.785398, 1.5708, -1.5708], + ], + ).to(torch.float16), + ) + self._test_sin(inputs, legacy_mode=False) + + def test_fp16_sin_legacy_mode(self): + inputs = ( + torch.Tensor( + [ + [0.0, 0.1, 0.5, 0.785398], + [-0.5, -0.785398, 1.5708, -1.5708], + ], + ).to(torch.float16), + ) + self._test_sin(inputs, legacy_mode=True) + + def test_fp32_sin(self): + inputs = ( + torch.Tensor( + [ + [0.0, 0.1, 0.5, 0.785398], + [-0.5, -0.785398, 1.5708, -1.5708], + ], + ), + ) + self._test_sin(inputs, legacy_mode=False) + + def test_fp32_sin_legacy_mode(self): + inputs = ( + torch.Tensor( + [ + [0.0, 0.1, 0.5, 0.785398], + [-0.5, -0.785398, 1.5708, -1.5708], + ], + ), + ) + self._test_sin(inputs, legacy_mode=True) diff --git a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py index a73a0eb0ad1..d823af9735e 100644 --- a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py +++ b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + import unittest import torch @@ -23,6 +25,7 @@ is_quant, is_tagged_as_implicit_q_dq, ) +from executorch.exir.dialects._ops import ops as exir_ops class TestChannelsLastTaggedReshapePass(unittest.TestCase): @@ -480,3 +483,153 @@ def test_q_dq_nodes_around_copy_are_tagged(self): # Compare outputs tester.run_method_and_compare_outputs() + + def test_fp32_channels_last_tagged_reshape_pass_nhwc_view(self): + # Views are always run in NCHW for now. + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 3, 3) + self.conv2 = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x): + y = self.conv1(x) + y = y.view((1, 3, 3, -1)) + return self.conv2(y) + + inputs = (torch.randn(1, 3, 8, 8),) + ( + Tester(Model(), inputs) + .export() + .to_edge() + .check_node_count( + { + exir_ops.edge.aten.convolution.default: 2, + exir_ops.edge.aten.view_copy.default: 1, + } + ) + .run_passes(self.PassStage) + .run_method_and_compare_outputs() + .check_node_count( + { + exir_ops.edge.aten.convolution.default: 2, + exir_ops.edge.aten.view_copy.default: 1, + # 4 dim order conversions - a pair at the start and end and a pair + # around the view. 
+ exir_ops.edge.aten._to_copy.default: 4, + } + ) + ) + + def test_fp32_channels_last_tagged_reshape_pass_nchw_view_channel_modified(self): + # View cannot run in NHWC because channel and/or batch are modified. + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 3, 3) + self.conv2 = torch.nn.Conv2d(6, 3, 3) + + def forward(self, x): + y = self.conv1(x) + y = y.view((1, 6, 6, -1)) + return self.conv2(y) + + inputs = (torch.randn(1, 3, 8, 8),) + ( + Tester(Model(), inputs) + .export() + .to_edge() + .check_node_count( + { + exir_ops.edge.aten.convolution.default: 2, + exir_ops.edge.aten.view_copy.default: 1, + } + ) + .run_passes(self.PassStage) + .run_method_and_compare_outputs() + .check_node_count( + { + exir_ops.edge.aten.convolution.default: 2, + exir_ops.edge.aten.view_copy.default: 1, + exir_ops.edge.aten._to_copy.default: 4, + } + ) + ) + + def test_fp32_channels_last_tagged_reshape_pass_nchw_view_batch_modified(self): + # View cannot run in NHWC because channel and/or batch are modified. + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 3, 3) + self.conv2 = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x): + y = self.conv1(x) + y = y.view((2, 3, 6, -1)) + return self.conv2(y) + + inputs = (torch.randn(1, 3, 8, 8),) + ( + Tester(Model(), inputs) + .export() + .to_edge() + .check_node_count( + { + exir_ops.edge.aten.convolution.default: 2, + exir_ops.edge.aten.view_copy.default: 1, + } + ) + .run_passes(self.PassStage) + .run_method_and_compare_outputs() + .check_node_count( + { + exir_ops.edge.aten.convolution.default: 2, + exir_ops.edge.aten.view_copy.default: 1, + exir_ops.edge.aten._to_copy.default: 4, + } + ) + ) + + def test_fp32_channels_last_tagged_reshape_pass_flatten_view(self): + # View cannot run in NHWC because tensor rank changes. + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 3, 3) + self.linear1 = torch.nn.Linear(36 * 3, 1) + + def forward(self, x): + y = self.conv1(x) + y = y.view((x.shape[0], -1)) + return self.linear1(y) + + inputs = (torch.randn(1, 3, 8, 8),) + tester = ( + Tester(Model(), inputs) + .export() + .to_edge() + .check_node_count( + { + exir_ops.edge.aten.convolution.default: 1, + exir_ops.edge.aten.view_copy.default: 1, + } + ) + .run_passes(self.PassStage) + .run_method_and_compare_outputs() + .check_node_count( + { + exir_ops.edge.aten.convolution.default: 1, + exir_ops.edge.aten.view_copy.default: 1, + exir_ops.edge.aten._to_copy.default: 2, + } + ) + ) + + # Verify view is not tagged. + graph = tester.get_artifact().exported_program().module().graph + view_nodes = [ + n for n in graph.nodes if n.target == exir_ops.edge.aten.view_copy.default + ] + self.assertEqual(1, len(view_nodes)) + self.assertTrue(ChannelsLastTaggedReshapePass(None).is_nchw_node(view_nodes[0])) diff --git a/backends/xnnpack/test/passes/test_propagate_custom_meta_pass.py b/backends/xnnpack/test/passes/test_propagate_custom_meta_pass.py new file mode 100644 index 00000000000..2d876b372cb --- /dev/null +++ b/backends/xnnpack/test/passes/test_propagate_custom_meta_pass.py @@ -0,0 +1,185 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
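Editor's note: the pass test introduced below keeps the tagged weights in memory and loads them from buffers. As a hedged sketch only (file names are arbitrary, and _tensor_data is the same private mapping the test itself reads, keyed by the "model" tag from gen_tag_fn), the resulting program and external weights could be written to disk as a .pte/.ptd pair like this:

def save_pte_and_ptd(exec_prog, pte_path="model.pte", ptd_path="model.ptd"):
    """Write an ExecutorchProgramManager's program and its tagged external weights."""
    with open(pte_path, "wb") as f:
        f.write(exec_prog.buffer)
    # "model" is the tag assigned by the external-constants pass in the test below.
    with open(ptd_path, "wb") as f:
        f.write(bytes(exec_prog._tensor_data["model"]))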
+ +import unittest + +from typing import Tuple, Union + +import executorch.backends.test.harness.stages as BaseStages + +import torch +from executorch.backends.xnnpack.partition.config.xnnpack_config import ( + ConfigPrecisionType, +) +from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, +) +from executorch.backends.xnnpack.test.tester import Quantize as XNNPackQuantize, Tester +from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower + +from executorch.exir import ExecutorchProgramManager +from executorch.exir._serialize import _deserialize_pte_binary +from executorch.exir.passes.external_constants_pass import ( + delegate_external_constants_pass_unlifted, +) +from executorch.extension.flat_tensor.serialize.serialize import ( + _deserialize_to_flat_tensor, +) + +from torchao.quantization.granularity import PerGroup +from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig + +try: + import executorch.extension.pybindings.portable_lib # noqa[F401] + import executorch.kernels.quantized # noqa[F401] + + has_quantized_ops = True +except: + has_quantized_ops = False + print("Missing quantized ops") + + +class TestPropagateCustomMetaPass(unittest.TestCase): + class ModuleLinear(torch.nn.Module): + def __init__( + self, + in_size: int = 2, + input_channels: int = 4, + output_channels: int = 4, + dtype: torch.dtype = torch.float, + use_bias: bool = False, + ): + super().__init__() + self.linear = torch.nn.Linear( + input_channels, output_channels, bias=use_bias + ).to(dtype=dtype) + + self.ic = input_channels + self.oc = output_channels + assert dtype in [torch.float, torch.half], "Unsupported op dtype" + self.op_dtype = dtype + self.in_size = in_size + + def forward(self, x: torch.Tensor): + return self.linear(x) + + def get_random_inputs(self): + inp = torch.randn(self.in_size, self.ic).to(self.op_dtype) + return (inp,) + + class Export(BaseStages.Export): + def run( + self, + artifact: torch.nn.Module, + inputs: Tuple[torch.Tensor], + ) -> None: + + tagged_module = torch.export.export( + artifact, inputs, dynamic_shapes=self.dynamic_shapes, strict=True + ).module() + delegate_external_constants_pass_unlifted( + module=tagged_module, + gen_tag_fn=lambda x: "model", # This is the filename the weights will be saved to. 
In this case, weights will be saved as "model.ptd" + ) + self.exported_program = torch.export.export( + tagged_module, inputs, dynamic_shapes=self.dynamic_shapes, strict=True + ) + + def _test_linear( + self, + partitioner: XnnpackPartitioner, + quantization_stage: Union[BaseStages.Quantize, BaseStages.Quantize_], + ) -> ExecutorchProgramManager: + eager_model = self.ModuleLinear( + in_size=1, + input_channels=32, + output_channels=2, + ) + test_inputs = eager_model.get_random_inputs() + + tester = Tester(eager_model, test_inputs) + tester.quantize(quantization_stage) + tester.export(self.Export()) + tester.to_edge_transform_and_lower( + ToEdgeTransformAndLower([partitioner]) + ).to_executorch() + tester.run_method_and_compare_outputs() + + exec = tester.get_artifact() + program_buffer = exec.buffer + self.assertEqual(len(exec._tensor_data), 1) + data_buffer = bytes(exec._tensor_data["model"]) + self.assertTrue(len(data_buffer) > 200) + from executorch.extension.pybindings import portable_lib as runtime + + module = runtime._load_for_executorch_from_buffer(program_buffer, data_buffer) + output = module.forward(test_inputs) + reference_output = exec.exported_program().module()( + test_inputs[0], + ) + self.assertTrue(torch.allclose(output[0], reference_output, 1e-2)) + + # with self.assertRaises(RuntimeError): + # runtime._load_for_executorch_from_buffer(program_buffer).forward( + # test_inputs + # ) + + return exec + + def test_quantize_(self): + # Quantize with torchao quantize_ API. + DynamicallyQuantizedPartitioner = XnnpackPartitioner( + config_precisions=ConfigPrecisionType.DYNAMIC_QUANT, + per_op_mode=False, + ) + linear_config = Int8DynamicActivationIntxWeightConfig( + weight_dtype=torch.int4, + weight_granularity=PerGroup(32), + ) + exec = self._test_linear( + DynamicallyQuantizedPartitioner, BaseStages.Quantize_(config=linear_config) + ) + # PTE file has no named data. + pte_file = _deserialize_pte_binary(exec.buffer) + self.assertEqual(pte_file.named_data, None) + + # PTD file contains quantized weight and scale. + ptd_file = _deserialize_to_flat_tensor(bytes(exec._tensor_data["model"])) + self.assertEqual(len(ptd_file.named_data), 2) + + def test_pt2e_quantize(self): + # Quantize with pt2e quantize. + quant_configs = [ + # per_tensor + get_symmetric_quantization_config(is_per_channel=False, is_dynamic=False), + # per_channel + get_symmetric_quantization_config(is_per_channel=True, is_dynamic=False), + # per_channel_dynamic + get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True), + ] + for quant_config in quant_configs: + precision = ( + ConfigPrecisionType.DYNAMIC_QUANT + if quant_config.input_activation.is_dynamic + else ConfigPrecisionType.STATIC_QUANT + ) + for per_op_mode in [True, False]: + partitioner = XnnpackPartitioner( + config_precisions=precision, per_op_mode=per_op_mode + ) + exec = self._test_linear( + partitioner, XNNPackQuantize(quantization_config=quant_config) + ) + # PTE file has no named data. + pte_file = _deserialize_pte_binary(exec.buffer) + self.assertEqual(pte_file.named_data, None) + + # PTD file contains quantized weight, and potentially scale. 
+ ptd_file = _deserialize_to_flat_tensor( + bytes(exec._tensor_data["model"]) + ) + self.assertTrue(len(ptd_file.named_data) >= 1) diff --git a/backends/xnnpack/test/runtime/test_workspace_manager.cpp b/backends/xnnpack/test/runtime/test_workspace_manager.cpp new file mode 100644 index 00000000000..ddb7074a1ce --- /dev/null +++ b/backends/xnnpack/test/runtime/test_workspace_manager.cpp @@ -0,0 +1,280 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include + +#include + +using namespace ::testing; + +using executorch::backends::xnnpack::WorkspaceSharingMode; +using executorch::backends::xnnpack::XNNWorkspace; +using executorch::backends::xnnpack::XNNWorkspaceManager; +using executorch::runtime::Error; +using executorch::runtime::Result; + +class XNNWorkspaceManagerTest : public ::testing::Test { + protected: + void SetUp() override { + // Log calls will abort if PAL is not initialized. + executorch::runtime::runtime_init(); + + // Initialize a new workspace manager for each test. + workspace_manager_ = std::make_unique(); + } + + std::unique_ptr workspace_manager_; +}; + +TEST_F(XNNWorkspaceManagerTest, SetAndGetSharingMode) { + // Test setting and getting the sharing mode + EXPECT_EQ( + workspace_manager_->set_sharing_mode(WorkspaceSharingMode::Disabled), + Error::Ok); + EXPECT_EQ( + workspace_manager_->get_sharing_mode(), WorkspaceSharingMode::Disabled); + + EXPECT_EQ( + workspace_manager_->set_sharing_mode(WorkspaceSharingMode::PerModel), + Error::Ok); + EXPECT_EQ( + workspace_manager_->get_sharing_mode(), WorkspaceSharingMode::PerModel); + + EXPECT_EQ( + workspace_manager_->set_sharing_mode(WorkspaceSharingMode::Global), + Error::Ok); + EXPECT_EQ( + workspace_manager_->get_sharing_mode(), WorkspaceSharingMode::Global); +} + +TEST_F(XNNWorkspaceManagerTest, SetInvalidSharingMode) { + // First set a valid mode to ensure we're starting from a known state. + EXPECT_EQ( + workspace_manager_->set_sharing_mode(WorkspaceSharingMode::Disabled), + Error::Ok); + EXPECT_EQ( + workspace_manager_->get_sharing_mode(), WorkspaceSharingMode::Disabled); + + // Try to set an invalid mode. + WorkspaceSharingMode invalid_mode = static_cast(70); + EXPECT_EQ( + workspace_manager_->set_sharing_mode(invalid_mode), + Error::InvalidArgument); + + // The mode should not have changed. + EXPECT_EQ( + workspace_manager_->get_sharing_mode(), WorkspaceSharingMode::Disabled); +} + +TEST_F(XNNWorkspaceManagerTest, DisabledMode) { + // Verify that each call retrieves a new workspace when sharing is disabled. 
+ workspace_manager_->set_sharing_mode(WorkspaceSharingMode::Disabled); + + uintptr_t program_id = 12345; + auto workspace1_result = + workspace_manager_->get_or_create_workspace(program_id); + ASSERT_TRUE(workspace1_result.ok()); + auto workspace1 = workspace1_result.get(); + + auto workspace2_result = + workspace_manager_->get_or_create_workspace(program_id); + ASSERT_TRUE(workspace2_result.ok()); + auto workspace2 = workspace2_result.get(); + + auto workspace3_result = + workspace_manager_->get_or_create_workspace(program_id + 1); + ASSERT_TRUE(workspace3_result.ok()); + auto workspace3 = workspace3_result.get(); + + EXPECT_NE(workspace1, workspace2); + EXPECT_NE(workspace1, workspace3); + EXPECT_NE(workspace2, workspace3); + EXPECT_NE( + workspace1->unsafe_get_workspace(), workspace2->unsafe_get_workspace()); + EXPECT_NE( + workspace1->unsafe_get_workspace(), workspace3->unsafe_get_workspace()); + EXPECT_NE( + workspace2->unsafe_get_workspace(), workspace3->unsafe_get_workspace()); +} + +TEST_F(XNNWorkspaceManagerTest, PerModelMode) { + // In PerModel mode, calls with the same program_id should return the same + // workspace. + workspace_manager_->set_sharing_mode(WorkspaceSharingMode::PerModel); + + // Get two workspaces with the same program ID and one different. + uintptr_t program_id = 12345; + auto workspace1_result = + workspace_manager_->get_or_create_workspace(program_id); + ASSERT_TRUE(workspace1_result.ok()); + auto workspace1 = workspace1_result.get(); + + auto workspace2_result = + workspace_manager_->get_or_create_workspace(program_id); + ASSERT_TRUE(workspace2_result.ok()); + auto workspace2 = workspace2_result.get(); + + auto workspace3_result = + workspace_manager_->get_or_create_workspace(program_id + 1); + ASSERT_TRUE(workspace3_result.ok()); + auto workspace3 = workspace3_result.get(); + + // Workspace 1 and 2 should be the same, but different from workspace 3. + EXPECT_EQ(workspace1, workspace2); + EXPECT_EQ( + workspace1->unsafe_get_workspace(), workspace2->unsafe_get_workspace()); + + EXPECT_NE(workspace1, workspace3); + EXPECT_NE( + workspace1->unsafe_get_workspace(), workspace3->unsafe_get_workspace()); +} + +TEST_F(XNNWorkspaceManagerTest, GlobalMode) { + // In Global mode, all calls should return the same workspace. 
+ workspace_manager_->set_sharing_mode(WorkspaceSharingMode::Global); + + // Get workspaces with different program IDs + uintptr_t program_id1 = 12345; + auto workspace1_result = + workspace_manager_->get_or_create_workspace(program_id1); + ASSERT_TRUE(workspace1_result.ok()); + auto workspace1 = workspace1_result.get(); + + uintptr_t program_id2 = 67890; + auto workspace2_result = + workspace_manager_->get_or_create_workspace(program_id2); + ASSERT_TRUE(workspace2_result.ok()); + auto workspace2 = workspace2_result.get(); + + EXPECT_EQ(workspace1, workspace2); + EXPECT_EQ( + workspace1->unsafe_get_workspace(), workspace2->unsafe_get_workspace()); +} + +TEST_F(XNNWorkspaceManagerTest, PerModelModeCleanup) { + // Test that workspaces are properly cleaned up when shared_ptr is destroyed + workspace_manager_->set_sharing_mode(WorkspaceSharingMode::PerModel); + + uintptr_t program_id = 12345; + xnn_workspace_t raw_workspace1 = nullptr; + + // Create a scope to control the lifetime of workspace1 + { + auto workspace1_result = + workspace_manager_->get_or_create_workspace(program_id); + ASSERT_TRUE(workspace1_result.ok()); + auto workspace1 = workspace1_result.get(); + + // Store the raw pointer for later comparison + raw_workspace1 = workspace1->unsafe_get_workspace(); + + // Let workspace1 go out of scope and be destroyed + } + + // Get a new workspace with the same program ID + auto workspace2_result = + workspace_manager_->get_or_create_workspace(program_id); + ASSERT_TRUE(workspace2_result.ok()); + auto workspace2 = workspace2_result.get(); + + // Since the previous workspace was destroyed, we should get a new one. + EXPECT_NE(workspace2->unsafe_get_workspace(), raw_workspace1); +} + +TEST_F(XNNWorkspaceManagerTest, GlobalModeCleanup) { + // Test that global workspaces are properly cleaned up when all users + // are destroyed. + workspace_manager_->set_sharing_mode(WorkspaceSharingMode::Global); + + uintptr_t program_id = 12345; + xnn_workspace_t raw_workspace1 = nullptr; + + // Create a scope to control the lifetime of workspace1 + { + auto workspace1_result = + workspace_manager_->get_or_create_workspace(program_id); + ASSERT_TRUE(workspace1_result.ok()); + auto workspace1 = workspace1_result.get(); + + // Store the raw pointer for later comparison + raw_workspace1 = workspace1->unsafe_get_workspace(); + + // Let workspace1 go out of scope and be destroyed + } + + // Get a new workspace (program ID doesn't matter in Global mode) + auto workspace2_result = + workspace_manager_->get_or_create_workspace(program_id); + ASSERT_TRUE(workspace2_result.ok()); + auto workspace2 = workspace2_result.get(); + + // Since the previous workspace was destroyed, we should get a new one. 
+ EXPECT_NE(workspace2->unsafe_get_workspace(), raw_workspace1); +} + +TEST_F(XNNWorkspaceManagerTest, SwitchingModes) { + // Test switching between different sharing modes + + // Start with Disabled mode + workspace_manager_->set_sharing_mode(WorkspaceSharingMode::Disabled); + + // Get a workspace + uintptr_t program_id = 12345; + auto workspace1_result = + workspace_manager_->get_or_create_workspace(program_id); + ASSERT_TRUE(workspace1_result.ok()); + auto workspace1 = workspace1_result.get(); + + // Switch to PerModel mode + workspace_manager_->set_sharing_mode(WorkspaceSharingMode::PerModel); + + // Get another workspace with the same program ID + auto workspace2_result = + workspace_manager_->get_or_create_workspace(program_id); + ASSERT_TRUE(workspace2_result.ok()); + auto workspace2 = workspace2_result.get(); + + // Should be a different workspace + EXPECT_NE(workspace1, workspace2); + + // Get another workspace with the same program ID in PerModel mode + auto workspace3_result = + workspace_manager_->get_or_create_workspace(program_id); + ASSERT_TRUE(workspace3_result.ok()); + auto workspace3 = workspace3_result.get(); + + // Should be the same workspace as workspace2 + EXPECT_EQ(workspace2, workspace3); + + // Switch to Global mode + workspace_manager_->set_sharing_mode(WorkspaceSharingMode::Global); + + // Get another workspace + auto workspace4_result = + workspace_manager_->get_or_create_workspace(program_id); + ASSERT_TRUE(workspace4_result.ok()); + auto workspace4 = workspace4_result.get(); + + // Should be a different workspace since we switched modes + EXPECT_NE(workspace3, workspace4); + + // Get a workspace with a different program ID in Global mode + uintptr_t different_program_id = 67890; + auto workspace5_result = + workspace_manager_->get_or_create_workspace(different_program_id); + ASSERT_TRUE(workspace5_result.ok()); + auto workspace5 = workspace5_result.get(); + + // Should be the same workspace as workspace4 + EXPECT_EQ(workspace4, workspace5); +} diff --git a/backends/xnnpack/test/runtime/test_workspace_sharing.cpp b/backends/xnnpack/test/runtime/test_workspace_sharing.cpp new file mode 100644 index 00000000000..66f0d012acd --- /dev/null +++ b/backends/xnnpack/test/runtime/test_workspace_sharing.cpp @@ -0,0 +1,179 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include +#include + +#include + +using namespace ::testing; + +using executorch::backends::xnnpack::workspace_sharing_mode_option_key; +using executorch::backends::xnnpack::WorkspaceSharingMode; +using executorch::backends::xnnpack::xnnpack_backend_key; +using executorch::extension::Module; +using executorch::extension::TensorPtr; +using executorch::runtime::BackendOption; +using executorch::runtime::BackendOptions; +using executorch::runtime::Error; + +TensorPtr create_input_tensor(float val); +void run_and_validate_two_models( + std::optional mode1 = std::nullopt, + std::optional mode2 = std::nullopt); +void set_and_check_workspace_sharing_mode(WorkspaceSharingMode mode); + +TEST(WorkspaceSharing, SetMode) { + // Try setting and reading back the mode a few times. 
+ set_and_check_workspace_sharing_mode(WorkspaceSharingMode::Disabled); + set_and_check_workspace_sharing_mode(WorkspaceSharingMode::PerModel); + set_and_check_workspace_sharing_mode(WorkspaceSharingMode::Global); +} + +TEST(WorkspaceSharing, SetInvalidMode) { + // Make sure we can't set an invalid mode. + + // Set to an initial known value. + set_and_check_workspace_sharing_mode(WorkspaceSharingMode::PerModel); + + // Set to a bad value. + BackendOptions<1> backend_options; + backend_options.set_option(workspace_sharing_mode_option_key, 70); + + auto status = executorch::runtime::set_option( + xnnpack_backend_key, backend_options.view()); + ASSERT_EQ(status, Error::InvalidArgument); + + // Make sure the option is still set to a valid value. + BackendOption read_option; + strcpy(read_option.key, workspace_sharing_mode_option_key); + read_option.value = -1; + status = get_option(xnnpack_backend_key, read_option); + + ASSERT_TRUE( + std::get(read_option.value) == + static_cast(WorkspaceSharingMode::PerModel)); +} + +TEST(WorkspaceSharing, RunWithDisabledMode) { + // Load and run some PTEs with workspace sharing disabled. + run_and_validate_two_models(WorkspaceSharingMode::Disabled); +} + +TEST(WorkspaceSharing, RunWithPerModelMode) { + // Load and run some PTEs with per-model workspace sharing. + run_and_validate_two_models(WorkspaceSharingMode::PerModel); +} + +TEST(WorkspaceSharing, RunWithGlobalMode) { + // Load and run some PTEs with global workspace sharing. + run_and_validate_two_models(WorkspaceSharingMode::Global); +} + +TEST(WorkspaceSharing, RunWithModeSwitch) { + // Check each pair of modes, loading one model in one mode and the other in + // the other mode. + + std::array modes = { + WorkspaceSharingMode::Disabled, + WorkspaceSharingMode::PerModel, + WorkspaceSharingMode::Global}; + + for (auto i = 0; i < modes.size(); ++i) { + for (auto j = i + 1; j < modes.size(); ++j) { + run_and_validate_two_models(modes[i], modes[j]); + } + } +} + +TensorPtr create_input_tensor(float val) { + // Create an f32 tensor with shape [10, 10, 10], matching the input of the + // test models. + std::vector data(1000, val); + + // Note that the tensor pointer takes ownership of the data vector. + return executorch::extension::make_tensor_ptr({10, 10, 10}, std::move(data)); +} + +void run_and_validate_two_models( + std::optional mode1, + std::optional mode2) { + // Load and run two models, verifying that the output tensors are correct, + // optionally setting sharing mode. + + if (mode1) { + set_and_check_workspace_sharing_mode(*mode1); + } + + Module mod1(std::getenv("ET_XNNPACK_GENERATED_ADD_LARGE_PTE_PATH")); + + auto a = create_input_tensor(1.0); + auto b = create_input_tensor(2.0); + auto c = create_input_tensor(3.0); + + auto result = mod1.forward({a, b, c}); + EXPECT_TRUE(result.ok()); + + // Expected output is 2a + 2b + c. + auto output_val = 1.0 * 2 + 2.0 * 2 + 3.0; + auto& output_tensor = result.get()[0].toTensor(); + for (auto i = 0; i < output_tensor.numel(); ++i) { + ASSERT_EQ(output_tensor.const_data_ptr()[i], output_val); + } + + if (mode2) { + set_and_check_workspace_sharing_mode(*mode2); + } + + Module mod2(std::getenv("ET_XNNPACK_GENERATED_SUB_LARGE_PTE_PATH")); + + auto result2 = mod2.forward({a, b, c}); + EXPECT_TRUE(result2.ok()); + + // Expected output is zero (the subtract operations cancel out). 
+ auto& output_tensor2 = result2.get()[0].toTensor(); + for (auto i = 0; i < output_tensor2.numel(); ++i) { + ASSERT_EQ(output_tensor2.const_data_ptr()[i], 0); + } + + // Run mod1 again to validate that it gives correct results in the second mode + auto result3 = mod1.forward({a, b, c}); + EXPECT_TRUE(result3.ok()); + + // Expected output is still 2a + 2b + c + auto& output_tensor3 = result3.get()[0].toTensor(); + for (auto i = 0; i < output_tensor3.numel(); ++i) { + ASSERT_EQ(output_tensor3.const_data_ptr()[i], output_val); + } +} + +void set_and_check_workspace_sharing_mode(WorkspaceSharingMode mode) { + executorch::runtime::runtime_init(); + + BackendOptions<1> backend_options; + backend_options.set_option( + workspace_sharing_mode_option_key, static_cast(mode)); + + auto status = executorch::runtime::set_option( + xnnpack_backend_key, backend_options.view()); + ASSERT_EQ(status, Error::Ok); + + // Read the option back to sanity check. + BackendOption read_option; + strcpy(read_option.key, workspace_sharing_mode_option_key); + read_option.value = -1; + status = get_option(xnnpack_backend_key, read_option); + + ASSERT_TRUE(std::get(read_option.value) == static_cast(mode)); +} diff --git a/backends/xnnpack/test/runtime/test_xnnexecutor.cpp b/backends/xnnpack/test/runtime/test_xnnexecutor.cpp index b2a56f6283d..568c3c4ec35 100644 --- a/backends/xnnpack/test/runtime/test_xnnexecutor.cpp +++ b/backends/xnnpack/test/runtime/test_xnnexecutor.cpp @@ -18,7 +18,7 @@ using executorch::runtime::Span; using executorch::runtime::testing::TensorFactory; TEST(XNNExecutorTest, ArgumentWithTooManyDimensions) { - XNNExecutor executor; + XNNExecutor executor({}); xnn_subgraph_t subgraph = nullptr; xnn_runtime_t rt = nullptr; et_pal_init(); diff --git a/backends/xnnpack/test/targets.bzl b/backends/xnnpack/test/targets.bzl index f175e9655ea..04517c035fe 100644 --- a/backends/xnnpack/test/targets.bzl +++ b/backends/xnnpack/test/targets.bzl @@ -63,3 +63,26 @@ def define_common_targets(): "ET_MODULE_LINEAR_XNN_DATA_PATH": "$(location fbcode//executorch/test/models:exported_xnnpack_program_and_data[ModuleLinear.ptd])", }, ) + + runtime.cxx_test( + name = "test_workspace_sharing", + srcs = ["runtime/test_workspace_sharing.cpp"], + deps = [ + "//executorch/extension/module:module", + "//executorch/extension/tensor:tensor", + "//executorch/backends/xnnpack:xnnpack_backend", + ], + env = { + "ET_XNNPACK_GENERATED_ADD_LARGE_PTE_PATH": "$(location fbcode//executorch/test/models:exported_xnnp_delegated_programs[ModuleAddLarge.pte])", + "ET_XNNPACK_GENERATED_SUB_LARGE_PTE_PATH": "$(location fbcode//executorch/test/models:exported_xnnp_delegated_programs[ModuleSubLarge.pte])", + }, + ) + + runtime.cxx_test( + name = "test_workspace_manager", + srcs = ["runtime/test_workspace_manager.cpp"], + deps = [ + third_party_dep("XNNPACK"), + "//executorch/backends/xnnpack:xnnpack_backend", + ], + ) diff --git a/backends/xnnpack/test/test_xnnpack_partitioner.py b/backends/xnnpack/test/test_xnnpack_partitioner.py index 8cd9eb92d56..894fab4098f 100644 --- a/backends/xnnpack/test/test_xnnpack_partitioner.py +++ b/backends/xnnpack/test/test_xnnpack_partitioner.py @@ -9,8 +9,13 @@ import unittest import torch +import torch.nn.functional as F + from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner from executorch.exir import to_edge, to_edge_transform_and_lower +from executorch.extension.pybindings.portable_lib import ( + _load_for_executorch_from_buffer, +) from torch.export import export @@ -82,3 
+87,77 @@ def test_no_warning_for_to_edge_transform_and_lower_workflow(self): log_contents = log_capture_string.getvalue() self.assertNotIn("DEPRECATION WARNING", log_contents) + + def test_multi_method_partitioning_with_shared_weights(self): + """ + Test that multi-method models with shared weights are correctly partitioned. + Verify that: + 1. Both methods are fully lowered to XNNPACK. + 2. Constants are not duplicated between named data and constant buffers. + 3. Program executes correctly. + """ + + class MultiMethodModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(8, 16) + self.linear2 = torch.nn.Linear(16, 8) + + def forward(self, x): + return self.linear2(F.sigmoid(self.linear(x))) + + def forward_2(self, x): + return self.linear2(F.relu(self.linear(x))) + + def example_inputs(self): + return (torch.randn(1, 8),) + + model = MultiMethodModel() + + # Get eager reference output. + example_inputs = model.example_inputs() + with torch.no_grad(): + fwd1_eager = model.forward(*example_inputs) + fwd2_eager = model.forward_2(*example_inputs) + + # Export both methods + ep_fwd = export(model, model.example_inputs(), strict=True) + # Patch the forward, as export only traces the 'forward' method. + model.forward = model.forward_2 + ep_fwd_2 = export(model, model.example_inputs(), strict=True) + + # Convert to edge and lower to executorch + edge = to_edge({"forward": ep_fwd, "forward_2": ep_fwd_2}) + lowered = edge.to_backend(XnnpackPartitioner(force_fp32_dynamic_linear=True)) + executorch = lowered.to_executorch() + + # Check that graph is fully delegated. + nodes_1 = list(lowered._edge_programs["forward"].graph.nodes) + nodes_2 = list(lowered._edge_programs["forward_2"].graph.nodes) + self.assertEqual(len(nodes_1), 5) + self.assertEqual(len(nodes_2), 5) + expected_node_names = [ + "x", + "lowered_module_0", + "executorch_call_delegate", + "getitem", + "output_1", + ] + for n in expected_node_names: + self.assertTrue(any(node.name == n for node in nodes_1)) + self.assertTrue(any(node.name == n for node in nodes_2)) + + # Check that weights are not duplicated. + self.assertEqual(len(executorch._named_data.pte_data), 4) + self.assertEqual(len(executorch._named_data.buffers), 4) + self.assertEqual(len(executorch._named_data.external_data), 0) + + # Check that there are no constant buffers (besides the placeholder). + self.assertEqual(len(executorch._emitter_output.program.constant_buffer), 1) + + # Check for model correctness. + executorch_module = _load_for_executorch_from_buffer(executorch.buffer) + fwd1_et = executorch_module.run_method("forward", example_inputs) + fwd2_et = executorch_module.run_method("forward_2", example_inputs) + self.assertTrue(torch.allclose(fwd1_eager, fwd1_et[0], 1e-3)) + self.assertTrue(torch.allclose(fwd2_eager, fwd2_et[0], 1e-3)) diff --git a/backends/xnnpack/xnnpack_preprocess.py b/backends/xnnpack/xnnpack_preprocess.py index 05fb53a837d..cdceb8a90a1 100644 --- a/backends/xnnpack/xnnpack_preprocess.py +++ b/backends/xnnpack/xnnpack_preprocess.py @@ -71,6 +71,11 @@ def generate_node_to_external_map( if node.op == "output": for output_nodes in node.args: for output_node in output_nodes: + if output_node in node_to_external_map: + raise RuntimeError( + f"Output node '{output_node}' is already in the inputs. " + "This is likely due to pass through arguments, which are not supported in XNNPACK Delegate." 
+ ) node_to_external_map[output_node] = ExternalMeta( external_id=len(node_to_external_map), io_type=XNN_VALUE_FLAG_EXTERNAL_OUTPUT, diff --git a/codegen/tools/CMakeLists.txt b/codegen/tools/CMakeLists.txt index 489a96aafb6..2d61a4d68c1 100644 --- a/codegen/tools/CMakeLists.txt +++ b/codegen/tools/CMakeLists.txt @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -24,10 +25,23 @@ target_include_directories( # Compile options target_compile_options( - selective_build PUBLIC -Wno-deprecated-declarations -fPIC -frtti -fexceptions + selective_build + PUBLIC -Wno-deprecated-declarations + -fPIC + -frtti + -fexceptions + -Werror + -Wunused-variable + -Wno-unknown-argument ) +# We pass -Wno-unknown-argument because our build system passes -fPIC, which is +# needed for Unix builds but is an unknown (ignored) argument on Windows builds. # Link against required libraries +if(TARGET bundled_program) + target_compile_definitions(selective_build PRIVATE -DET_BUNDLE_IO) + target_link_libraries(selective_build PRIVATE bundled_program) +endif() target_link_libraries(selective_build PRIVATE executorch_core program_schema) # Install the module diff --git a/codegen/tools/combine_prim_ops_headers.py b/codegen/tools/combine_prim_ops_headers.py new file mode 100644 index 00000000000..b579de2047d --- /dev/null +++ b/codegen/tools/combine_prim_ops_headers.py @@ -0,0 +1,164 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Script to combine multiple selected_prim_ops.h header files into a single header. +This is used by selected_prim_operators_genrule to merge prim ops headers from dependencies. +""" + +import argparse +import os +import sys
from pathlib import Path +from typing import List, Set + + +def read_header_file(file_path: Path) -> Set[str]: + """ + Read a selected_prim_ops.h file and extract its prim op #define macros. + + Args: + file_path: Path to the header file + + Returns: + A set of the unique prim op #define lines found in the file. + """ + macros = set() + + try: + with open(file_path, "r") as f: + for line in f: + line = line.strip() + + # Extract #define statements for prim ops + if line.startswith("#define INCLUDE_") and not line.startswith( + "#define EXECUTORCH_ENABLE" + ): + macros.add(line) + except FileNotFoundError: + print(f"Warning: Header file not found: {file_path}", file=sys.stderr) + except Exception as e: + print(f"Error reading {file_path}: {e}", file=sys.stderr) + + return macros + + +def combine_prim_ops_headers(header_file_paths: List[str], output_path: str) -> None: + """ + Combine multiple selected_prim_ops.h files into a single header.
+ + Args: + header_file_paths: List of directories, each expected to contain a selected_prim_ops.h to combine + output_path: Path to output the combined header + """ + all_macros = set() + has_selective_build = False + + # Read all header files and collect unique macros + for header_file_path in header_file_paths: + header_file = Path(header_file_path) / "selected_prim_ops.h" + if os.path.exists(header_file): + macros = read_header_file(header_file) + all_macros.update(macros) + if len(all_macros) > 0: + has_selective_build = True + else: + print( + f"Warning: Header file does not exist: {header_file}", file=sys.stderr + ) + + # Generate combined header + header_content = [ + "// Combined header for selective prim ops build", + "// This file is auto-generated by combining multiple selected_prim_ops.h files", + "// Do not edit manually.", + "", + "#pragma once", + "", + ] + + if all_macros and has_selective_build: + header_content.extend( + [ + "// Enable selective build for prim ops", + "#define EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD", + "", + "// Combined prim ops macros from all dependencies", + ] + ) + + # Sort macros for deterministic output + sorted_macros = sorted(all_macros) + header_content.extend(sorted_macros) + else: + header_content.extend( + [ + "// No prim ops found in dependencies - all prim ops will be included", + "// Selective build is disabled", + ] + ) + + header_content.append("") + + # Write the combined header + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, "w") as f: + f.write("\n".join(header_content)) + + +def _get_header_file_paths_from_query_output(query_output_file: str) -> List[str]: + """ + Parse the output of a Buck query command to extract header file paths. + + Args: + query_output_file: Path to the file containing the query output + + Returns: + List of header file paths + """ + header_file_paths = [] + assert ( + query_output_file[0] == "@" + ), "query_output_file is not a valid file path, or it doesn't start with '@'."
+ query_output_file = query_output_file[1:] + + with open(query_output_file, "r") as f: + for line in f: + # Extract the header file path from the query output + header_file_paths += line.split() + return header_file_paths + + +def main(): + parser = argparse.ArgumentParser( + description="Combine multiple selected_prim_ops.h header files" + ) + parser.add_argument( + "--header_files", + required=True, + help="Comma-separated list of header file paths", + ) + parser.add_argument( + "--output_dir", required=True, help="Output directory for combined header" + ) + + args = parser.parse_args() + import os + + header_file_paths = _get_header_file_paths_from_query_output(args.header_files) + + if not header_file_paths: + print("Error: No header files provided", file=sys.stderr) + sys.exit(1) + + # Generate output path + output_path = os.path.join(args.output_dir, "selected_prim_ops.h") + + combine_prim_ops_headers(header_file_paths, output_path) + + +if __name__ == "__main__": + main() diff --git a/codegen/tools/gen_all_oplist.py b/codegen/tools/gen_all_oplist.py index 5cb93bb9153..f33c3dc935d 100644 --- a/codegen/tools/gen_all_oplist.py +++ b/codegen/tools/gen_all_oplist.py @@ -10,7 +10,7 @@ import sys from functools import reduce from pathlib import Path -from typing import Any, List +from typing import Any, Dict, List import yaml from torchgen.selective_build.selector import ( @@ -72,6 +72,19 @@ def _raise_if_check_prim_ops_fail(options): raise Exception(error) +def _selected_ops_model_dict_is_empty(model_dict: Dict[str, Any]) -> bool: + return ( + not model_dict.get("build_features", []) + and not model_dict.get("custom_classes", []) + and not model_dict.get("et_kernel_metadata", None) + and not model_dict.get("include_all_non_op_selectives", False) + and not model_dict.get("include_all_operators", False) + and not model_dict.get("kernel_metadata", {}) + and not model_dict.get("operators", {}) + ) + + +# flake8: noqa: C901 def main(argv: List[Any]) -> None: """This binary generates 3 files: @@ -171,6 +184,11 @@ def main(argv: List[Any]) -> None: ), f"{model_file_name} is not a valid file path. This is likely a BUCK issue." with open(model_file_name, "rb") as model_file: model_dict = yaml.safe_load(model_file) + # It is possible that we created an empty yaml file. + # This is because et_operator_library may only contain prim ops. + # In that case selected_operators.yaml will be empty. 
+ if _selected_ops_model_dict_is_empty(model_dict): + continue resolved = resolve_model_file_path_to_buck_target(model_file_name) for op in model_dict["operators"]: model_dict["operators"][op]["debug_info"] = [resolved] diff --git a/codegen/tools/gen_oplist.py b/codegen/tools/gen_oplist.py index cca5bf1b1d2..28506050a8e 100644 --- a/codegen/tools/gen_oplist.py +++ b/codegen/tools/gen_oplist.py @@ -9,6 +9,7 @@ import os import sys from enum import IntEnum +from pathlib import Path from typing import Any, Dict, List, Optional, Set import yaml @@ -158,7 +159,7 @@ def _get_et_kernel_metadata_from_ops_yaml(ops_yaml_path: str) -> Dict[str, List[ def _dump_yaml( op_list: List[str], - output_path: str, + output_path: Path, model_name: Optional[str] = None, et_kernel_metadata: Optional[Dict[str, List[str]]] = None, include_all_operators: bool = False, @@ -212,20 +213,23 @@ def create_kernel_key(maybe_kernel_key: str) -> str: def gen_oplist( - output_path: str, + output_path: Path, model_file_path: Optional[str] = None, ops_schema_yaml_path: Optional[str] = None, root_ops: Optional[str] = None, ops_dict: Optional[str] = None, include_all_operators: bool = False, ): - assert ( + if not ( model_file_path or ops_schema_yaml_path or root_ops or ops_dict or include_all_operators - ), "Need to provide either model_file_path or ops_schema_yaml_path or root_ops or ops_dict or include_all_operators." + ): + # dump empty yaml file + _dump_yaml([], output_path) + return assert output_path, "Need to provide output_path for dumped yaml file." op_set = set() @@ -326,9 +330,15 @@ def main(args: List[Any]) -> None: ) options = parser.parse_args(args) + # check if the output_path is a directory, then generate operators + # under selected_operators.yaml + if Path(options.output_path).is_dir(): + output_path = Path(options.output_path) / "selected_operators.yaml" + else: + output_path = Path(options.output_path) try: gen_oplist( - output_path=options.output_path, + output_path=output_path, model_file_path=options.model_file_path, ops_schema_yaml_path=options.ops_schema_yaml_path, root_ops=options.root_ops, diff --git a/codegen/tools/gen_ops_def.py b/codegen/tools/gen_ops_def.py index aba3f9242ac..98fdab73fd1 100644 --- a/codegen/tools/gen_ops_def.py +++ b/codegen/tools/gen_ops_def.py @@ -23,7 +23,7 @@ def get_operators(model_file: str) -> List[Operator]: print("Processing model file: ", model_file) with open(model_file, "rb") as f: flatbuffer = f.read() - program = _deserialize_pte_binary(flatbuffer) + program = _deserialize_pte_binary(flatbuffer).program print(f"Program loaded from model file: {model_file}") operators = program.execution_plan[0].operators return operators diff --git a/codegen/tools/gen_selected_prim_ops.py b/codegen/tools/gen_selected_prim_ops.py new file mode 100644 index 00000000000..4535ffaa57a --- /dev/null +++ b/codegen/tools/gen_selected_prim_ops.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
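Editor's note: for context on the generator added in this new file, the sketch below (illustrative only; the output directory is arbitrary, the op names are taken from the file's own docstrings, and the import path follows the Buck base_module executorch.codegen.tools) shows how the op list is expected to be turned into selected_prim_ops.h macros, which combine_prim_ops_headers.py can later merge across dependencies:

import os

from executorch.codegen.tools.gen_selected_prim_ops import (
    normalize_op_name,
    write_selected_prim_ops,
)

ops = ["executorch_prim::et_view.default", "aten::sym_size.int"]
print([normalize_op_name(op) for op in ops])
# -> ['INCLUDE_EXECUTORCH_PRIM_ET_VIEW_DEFAULT', 'INCLUDE_ATEN_SYM_SIZE_INT']

out_dir = "/tmp/selected_prim_ops_demo"  # arbitrary demo path
os.makedirs(out_dir, exist_ok=True)
# Writes out_dir/selected_prim_ops.h with one "#define INCLUDE_..." line per op.
write_selected_prim_ops(ops, out_dir)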
+ +# pyre-unsafe + +import argparse +import os +import sys +from typing import Any, List + +from torchgen.code_template import CodeTemplate # type: ignore[import-not-found] + + +selected_prim_ops_h_template_str = """#pragma once +/** + * Generated by executorch/codegen/tools/gen_selected_prim_ops.py + */ + +$defines +""" +selected_prim_ops_h_template = CodeTemplate(selected_prim_ops_h_template_str) + + +def normalize_op_name(op_name: str) -> str: + """ + Normalize an operator name to a macro-safe format. + Convert op names like "executorch_prim::et_view.default" to "EXECUTORCH_PRIM_ET_VIEW_DEFAULT" + or "aten::sym_size.int" to "ATEN_SYM_SIZE_INT" + """ + # Remove namespace separator and replace with underscore + normalized = op_name.replace("::", "_") + # Replace dots with underscores + normalized = normalized.replace(".", "_") + # Convert to uppercase + normalized = normalized.upper() + # Add INCLUDE_ prefix + normalized = f"INCLUDE_{normalized}" + return normalized + + +def write_selected_prim_ops(prim_op_names: List[str], output_dir: str) -> None: + """ + Generate selected_prim_ops.h from a list of prim op names. + + Args: + prim_op_names: List of prim op names like ["executorch_prim::et_view.default", "aten::sym_size.int"] + output_dir: Directory where to write selected_prim_ops.h + """ + # Generate #define statements for each op + defines = [] + for op_name in prim_op_names: + macro_name = normalize_op_name(op_name) + defines.append(f"#define {macro_name}") + + # Join all defines with newlines + defines_str = "\n".join(defines) + + # Generate header content + header_contents = selected_prim_ops_h_template.substitute(defines=defines_str) + + # Write to file + selected_prim_ops_path = os.path.join(output_dir, "selected_prim_ops.h") + with open(selected_prim_ops_path, "wb") as out_file: + out_file.write(header_contents.encode("utf-8")) + + +def main(argv: List[Any]) -> None: + parser = argparse.ArgumentParser(description="Generate selected prim ops header") + parser.add_argument( + "--prim-op-names", + "--prim_op_names", + help="Comma-separated list of prim op names to include", + required=True, + ) + parser.add_argument( + "--output-dir", + "--output_dir", + help="The directory to store the output header file (selected_prim_ops.h)", + required=True, + ) + + options = parser.parse_args(argv) + + # Parse comma-separated prim op names + prim_op_names = [ + name.strip() for name in options.prim_op_names.split(",") if name.strip() + ] + + write_selected_prim_ops(prim_op_names, options.output_dir) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/codegen/tools/selective_build.cpp b/codegen/tools/selective_build.cpp index d33ff12ec9f..a34789e129d 100644 --- a/codegen/tools/selective_build.cpp +++ b/codegen/tools/selective_build.cpp @@ -1,16 +1,21 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. + * Copyright 2025 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ +#include +#include #include #include -#include -#include +#ifdef ET_BUNDLE_IO +#include +#include +#endif namespace py = pybind11; @@ -186,8 +191,39 @@ get_kernel_tensor_metadatas_from_execution_plan( const executorch_flatbuffer::Program* _get_program_from_buffer( const py::bytes& buffer) { + // Access the Python bytes without copying and get raw pointer/size. 
+ const std::string_view sv = buffer.cast(); +#ifdef ET_BUNDLE_IO + void* buf_ptr = const_cast(static_cast(sv.data())); + const size_t buf_len = sv.size(); + + // If this is a bundled program, extract the inner ExecuTorch program bytes. + if (executorch::bundled_program::is_bundled_program(buf_ptr, buf_len)) { + const void* program_data = nullptr; + size_t program_size = 0; + + const auto status = executorch::bundled_program::get_program_data( + buf_ptr, // serialized BundledProgram start + buf_len, // total size of the BundledProgram blob + &program_data, // [out] pointer to inner .pte bytes + &program_size // [out] size of inner .pte bytes + ); + + if (status != ::executorch::runtime::Error::Ok || program_data == nullptr || + program_size == 0) { + throw std::runtime_error( + "bundled_program::get_program_data() failed or returned empty data"); + } + + // program_data points directly at the flatbuffer-encoded Program region. + return executorch_flatbuffer::GetProgram( + reinterpret_cast(program_data)); + } +#endif + // Otherwise treat the buffer as a raw .pte (flatbuffer Program with optional + // extended header). return executorch_flatbuffer::GetProgram( - buffer.cast().data()); + reinterpret_cast(sv.data())); } py::list _get_program_operators(const executorch_flatbuffer::Program* program) { diff --git a/codegen/tools/targets.bzl b/codegen/tools/targets.bzl index 39de8fcb482..c11982409f0 100644 --- a/codegen/tools/targets.bzl +++ b/codegen/tools/targets.bzl @@ -17,10 +17,8 @@ def define_common_targets(is_fbcode = False): ], deps = [ "//executorch/codegen:gen_lib", - ] + select({ - "DEFAULT": [], - "ovr_config//os:linux": [] if runtime.is_oss else ["//executorch/codegen/tools:selective_build"], # TODO(larryliu0820) :selective_build doesn't build in OSS yet - }), + "//executorch/codegen/tools:selective_build", + ], ) runtime.python_binary( @@ -29,7 +27,7 @@ def define_common_targets(is_fbcode = False): deps = [ ":gen_oplist_lib", ], - preload_deps = [] if runtime.is_oss else ["//executorch/codegen/tools:selective_build"], # TODO(larryliu0820) :selective_build doesn't build in OSS yet + preload_deps = ["//executorch/codegen/tools:selective_build"], package_style = "inplace", visibility = [ "//executorch/...", @@ -103,6 +101,26 @@ def define_common_targets(is_fbcode = False): _is_external_target = True, ) + runtime.python_library( + name = "combine_prim_ops_headers_lib", + srcs = ["combine_prim_ops_headers.py"], + base_module = "executorch.codegen.tools", + visibility = ["//executorch/..."], + ) + + runtime.python_binary( + name = "combine_prim_ops_headers", + main_module = "executorch.codegen.tools.combine_prim_ops_headers", + package_style = "inplace", + visibility = [ + "PUBLIC", + ], + deps = [ + ":combine_prim_ops_headers_lib", + ], + _is_external_target = True, + ) + runtime.python_test( name = "test_gen_all_oplist", srcs = [ @@ -155,27 +173,48 @@ def define_common_targets(is_fbcode = False): _is_external_target = True, ) - if not runtime.is_oss: - runtime.cxx_python_extension( - name = "selective_build", - srcs = [ - "selective_build.cpp", - ], - base_module = "executorch.codegen.tools", - types = ["selective_build.pyi"], - preprocessor_flags = [ - "-DEXECUTORCH_PYTHON_MODULE_NAME=selective_build", - ], - deps = [ - "//executorch/runtime/core:core", - "//executorch/schema:program", - ], - external_deps = [ - "pybind11", - ], - use_static_deps = True, - visibility = ["//executorch/codegen/..."], - ) + runtime.python_library( + name = "gen_selected_prim_ops_lib", + srcs = 
["gen_selected_prim_ops.py"], + base_module = "executorch.codegen.tools", + visibility = ["//executorch/..."], + external_deps = ["torchgen"], + ) + + runtime.python_binary( + name = "gen_selected_prim_ops", + main_module = "executorch.codegen.tools.gen_selected_prim_ops", + package_style = "inplace", + visibility = [ + "PUBLIC", + ], + deps = [ + ":gen_selected_prim_ops_lib", + ], + _is_external_target = True, + ) + + + runtime.cxx_python_extension( + name = "selective_build", + srcs = [ + "selective_build.cpp", + ], + base_module = "executorch.codegen.tools", + types = ["selective_build.pyi"], + preprocessor_flags = [ + "-DEXECUTORCH_PYTHON_MODULE_NAME=selective_build", + ], + deps = [ + "//executorch/runtime/core:core", + "//executorch/schema:program", + ], + external_deps = [ + "pybind11", + ], + use_static_deps = True, + visibility = ["//executorch/codegen/..."], + ) # TODO(larryliu0820): This is a hack to only run these two on fbcode. These targets depends on exir which is only available in fbcode. @@ -214,10 +253,12 @@ def define_common_targets(is_fbcode = False): ], ) + if runtime.is_oss or is_fbcode: + # Doesn't work on xplat. But works on fbcode and OSS. runtime.python_test( - name = "test_selective_build", + name = "test_tools_selective_build", srcs = [ - "test/test_selective_build.py", + "test/test_tools_selective_build.py", ], package_style = "inplace", visibility = [ diff --git a/codegen/tools/test/test_gen_oplist.py b/codegen/tools/test/test_gen_oplist.py index f5c6829d6a0..18689cd2505 100644 --- a/codegen/tools/test/test_gen_oplist.py +++ b/codegen/tools/test/test_gen_oplist.py @@ -8,6 +8,7 @@ import os import tempfile import unittest +from pathlib import Path from typing import Dict, List from unittest.mock import NonCallableMock, patch @@ -77,7 +78,7 @@ def test_gen_op_list_with_valid_root_ops( gen_oplist.main(args) mock_dump_yaml.assert_called_once_with( ["aten::add", "aten::mul"], - output_path, + Path(output_path), None, {"aten::add": ["default"], "aten::mul": ["default"]}, False, @@ -100,7 +101,7 @@ def test_gen_op_list_with_root_ops_and_dtypes( gen_oplist.main(args) mock_dump_yaml.assert_called_once_with( ["aten::add", "aten::mul"], - output_path, + Path(output_path), None, { "aten::add": [ @@ -129,7 +130,7 @@ def test_gen_op_list_with_both_op_list_and_ops_schema_yaml_merges( gen_oplist.main(args) mock_dump_yaml.assert_called_once_with( ["aten::add.out", "aten::mul.out", "aten::relu.out"], - output_path, + Path(output_path), test_path, { "aten::relu.out": ["default"], @@ -153,7 +154,7 @@ def test_gen_op_list_with_include_all_operators( gen_oplist.main(args) mock_dump_yaml.assert_called_once_with( ["aten::add", "aten::mul"], - output_path, + Path(output_path), None, {"aten::add": ["default"], "aten::mul": ["default"]}, True, @@ -164,7 +165,7 @@ def test_get_custom_build_selector_with_both_allowlist_and_yaml( ) -> None: op_list = ["aten::add", "aten::mul"] filename = os.path.join(self.temp_dir.name, "selected_operators.yaml") - gen_oplist._dump_yaml(op_list, filename, "model.pte") + gen_oplist._dump_yaml(op_list, Path(filename), "model.pte") self.assertTrue(os.path.isfile(filename)) with open(filename) as f: es = yaml.safe_load(f) diff --git a/codegen/tools/test/test_selective_build.py b/codegen/tools/test/test_tools_selective_build.py similarity index 100% rename from codegen/tools/test/test_selective_build.py rename to codegen/tools/test/test_tools_selective_build.py diff --git a/configurations/CMakeLists.txt b/configurations/CMakeLists.txt index 
fa5412ac476..fb154ff88bc 100644 --- a/configurations/CMakeLists.txt +++ b/configurations/CMakeLists.txt @@ -63,6 +63,6 @@ if(EXECUTORCH_BUILD_KERNELS_OPTIMIZED) install( TARGETS optimized_native_cpu_ops_lib EXPORT ExecuTorchTargets - DESTINATION lib + DESTINATION ${CMAKE_INSTALL_LIBDIR} ) endif() diff --git a/desktop/README.md b/desktop/README.md index c774cec9c0c..2c00be632e7 100644 --- a/desktop/README.md +++ b/desktop/README.md @@ -1,18 +1,32 @@ -# Experimental: PyTorch Unified Python-less Solution +# ExecuTorch: Inference on consumer Desktops/Laptops with GPUs -This folder contains the experimental PyTorch Unified Python-less Solution, for both compiler and runtime. Proceed with caution. +## Overview +ExecuTorch is a lightweight, flexible runtime designed for efficient AI inference, historically focused on mobile and embedded devices. With the growing demand for local inference on personal desktops and laptops—especially those equipped with consumer GPUs (e.g., gaming PCs with NVIDIA hardware)—ExecuTorch is experimenting on expanding its capabilities to support these platforms. -## torch dependency -We use the pinned pytorch version from `install_requirements.py` and CI should be using `.ci/docker/ci_commit_pins/pytorch.txt` which should be consistent with `install_requirements.py`. +## Historical Context +- **Mobile and Embedded Focus**: ExecuTorch’s initial target market was mobile and embedded devices. +- **Desktop/Laptop Support**: Previously, desktop and laptop ("AI PC") inference was enabled through backends such as XNNPACK, OpenVino, and Qualcomm NPUs. +- **No CUDA Support**: For a long time, ExecuTorch did not offer a CUDA backend, limiting GPU acceleration on NVIDIA hardware. +## Recent Developments +With increased demand for local inference on consumer desktops and laptops, exemplified by popular runtimes like llama.cpp and MLX, ExecuTorch is now experimenting with CUDA and Metal support. This is achieved by leveraging Inductor compiler technology from PyTorch, specifically using Ahead-of-Time Inductor [AOTI](https://docs.pytorch.org/docs/stable/torch.compiler_aot_inductor.html) to avoid reinventing the wheel. -## Compiler -All code should live in `compiler/` folder. Code uses `torch` nightly as mentioned in torch dependency section. +## Key Benefits +- **Model Agnostic**: Validated on models such as [Voxtral](../examples/models/voxtral), [Gemma3-4b](../examples/models/gemma3), ResNet, and [Whisper](../examples/models/whisper/README.md). Theoretically, any model exportable via torch.export is supported. +- **PyTorch Ecosystem Integration**: Enables workflows for fine-tuning, quantization, and compilation within the PyTorch ecosystem. +- **No Python Runtime During Inference**: Ideal for native applications (e.g., written in C++) embedding AI capabilities. +- **No libtorch Dependency**: Reduces binary size, making deployment easier for resource-constrained applications. +- **Efficient GPU Support**: Uses AOTI-powered CUDA backend for efficient inference on NVIDIA GPUs. -## Runtime -All code should live in `runtime/` folder. CMake build system should leverage `libtorch` in the pip install of `torch` nightly. To build runtime, we need to point `CMAKE_PREFIX_PATH` to the pip install location of `torch` nightly. 
This way we can do: +## Backends -```cmake -find_package(torch REQUIRED) -``` +Backends leveraging AoTi +- [CUDA backend](../backends/cuda) +- [Metal backend](../backends/apple/metal) + +## Roadmap & Limitations +- **Experimental Status**: CUDA and Metal backends via AoTi are currently experimental. Contributions and feedback are welcome! +- **Model Compatibility**: While most models exportable via torch.export should work, validation is ongoing for broader model support. +- **Portability**: Figuring out the balance and trade-off between performance, portability and model filesize. +- **Windows-native WIP**: On windows we only supports WSL right now. Native Windows support is WIP. diff --git a/devtools/bundled_program/test/test_bundle_data.py b/devtools/bundled_program/test/test_bundle_data.py index a587a8672e9..9fdeb4a776d 100644 --- a/devtools/bundled_program/test/test_bundle_data.py +++ b/devtools/bundled_program/test/test_bundle_data.py @@ -18,7 +18,7 @@ from executorch.devtools.bundled_program.util.test_util import ( get_common_executorch_program, ) -from executorch.exir._serialize import _serialize_pte_binary +from executorch.exir._serialize import _PTEFile, _serialize_pte_binary class TestBundle(unittest.TestCase): @@ -72,7 +72,11 @@ def test_bundled_program(self) -> None: self.assertEqual( bundled_program.serialize_to_schema().program, - bytes(_serialize_pte_binary(executorch_program.executorch_program)), + bytes( + _serialize_pte_binary( + pte_file=_PTEFile(program=executorch_program.executorch_program) + ) + ), ) def test_bundled_program_from_pte(self) -> None: diff --git a/devtools/etdump/tests/etdump_test.cpp b/devtools/etdump/tests/etdump_test.cpp index d095844986f..fd35caca557 100644 --- a/devtools/etdump/tests/etdump_test.cpp +++ b/devtools/etdump/tests/etdump_test.cpp @@ -345,7 +345,7 @@ TEST_F(ProfilerETDumpTest, DebugEventTensorList) { EValue* values_p[2] = {&evalue_1, &evalue_2}; BoxedEvalueList a_box(values_p, storage, 2); - EValue evalue(a_box); + EValue evalue(&a_box); evalue.tag = Tag::ListTensor; etdump_gen[i]->create_event_block("test_block"); diff --git a/devtools/etrecord/tests/TARGETS b/devtools/etrecord/tests/TARGETS index f25f0464c9e..59f73bdb406 100644 --- a/devtools/etrecord/tests/TARGETS +++ b/devtools/etrecord/tests/TARGETS @@ -23,5 +23,6 @@ runtime.python_library( "//executorch/exir/tests:models", "//executorch/backends/xnnpack/partition:xnnpack_partitioner", "//executorch/export:lib", + "//executorch/runtime:runtime", # @manual ], ) diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index 6d046d8f2e8..6b6b4f583a6 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -73,6 +73,7 @@ from executorch.devtools.inspector.numerical_comparator import ( L1Comparator, MSEComparator, + NumericalComparatorBase, SNRComparator, ) from executorch.exir import ExportedProgram @@ -1036,10 +1037,8 @@ def __init__( source_time_scale: The time scale of the performance data retrieved from the runtime. The default time hook implentation in the runtime returns NS. target_time_scale: The target time scale to which the users want their performance data converted to. Defaults to MS. debug_buffer_path: Debug buffer file path that contains the debug data referenced by ETDump for intermediate and program outputs. - delegate_metadata_parser: Optional function to parse delegate metadata from an Profiling Event. 
Expected signature of the function is: - (delegate_metadata_list: List[bytes]) -> Union[List[str], Dict[str, Any]] - delegate_time_scale_converter: Optional function to convert the time scale of delegate profiling data. If not given, use the conversion ratio of - target_time_scale/source_time_scale. + delegate_metadata_parser: Optional function to parse delegate metadata from an Profiling Event. Expected signature of the function is (delegate_metadata_list: List[bytes]) -> Union[List[str], Dict[str, Any]]. + delegate_time_scale_converter: Optional function to convert the time scale of delegate profiling data. If not given, use the conversion ratio of target_time_scale/source_time_scale. enable_module_hierarchy: Enable submodules in the operator graph. Defaults to False. Returns: @@ -1169,6 +1168,7 @@ def _consume_etrecord(self) -> None: def _get_aot_intermediate_outputs_and_op_names( self, + disable_debug_handle_valdiation: bool = False, ) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, List[str]]]: """ Capture intermediate outputs only if _representative_inputs are provided @@ -1184,6 +1184,7 @@ def _get_aot_intermediate_outputs_and_op_names( self._etrecord.exported_program, self._etrecord.export_graph_id, self._etrecord.edge_dialect_program, + disable_debug_handle_valdiation, ): export_program = self._etrecord.exported_program else: @@ -1304,10 +1305,9 @@ def print_data_tabular( Displays the underlying EventBlocks in a structured tabular format, with each row representing an Event. Args: - file: Which IO stream to print to. Defaults to stdout. - Not used if this is in an IPython environment such as a Jupyter notebook. - include_units: Whether headers should include units (default true) - include_delegate_debug_data: Whether to include delegate debug metadata (default false) + file: Which IO stream to print to. Defaults to stdout. Not used if this is in an IPython environment such as a Jupyter notebook. + include_units: Whether headers should include units (default true). + include_delegate_debug_data: Whether to include delegate debug metadata (default false). Returns: None @@ -1404,7 +1404,11 @@ def get_exported_program( else self._etrecord.graph_map.get(graph) ) - def calculate_numeric_gap(self, distance: str = "MSE"): + def calculate_numeric_gap( + self, + distance: Union[str, NumericalComparatorBase], + disable_debug_handle_valdiation: bool = False, + ): """ Compares logged intermediate outputs from the exported graph (in ETRecord) with runtime outputs (in ETDump) using a user-specific numerical comparator. @@ -1415,13 +1419,23 @@ def calculate_numeric_gap(self, distance: str = "MSE"): compare the intermediate outputs from the AOT and the runtime. Args: - distance: the metrics the inspector will use for gap calculation. Should be one of "MSE", "L1" and "SNR". + distance: The metrics the inspector will use for gap calculation. Can be either: + - A string: one of "MSE", "L1", or "SNR" for built-in comparators. + - A custom NumericalComparatorBase instance: allows you to define custom comparison logic + by subclassing NumericalComparatorBase and implementing the compare() method. + disable_debug_handle_validation: Often when aten graph has symbolic shape nodes and inbuilt ops like gt/lt etc., + during re-export of such a graph 'from_node' information is lost from node.meta. As a result we loose + connection between edge IR nodes and aten nodes for such ops. 
By default we validate that every edge IR + node has corresponding node in aten IR, and when such validation fails numeric debugger falls back to edge + IR as reference graph. This flag allows one to override such behavior and make best effort comparison. Returns: pd.DataFrame: A DataFrame listing corresponding operator intermediate outputs from both stages and their computed numerical gaps. """ aot_intermediate_outputs, aot_debug_handle_to_op_names = ( - self._get_aot_intermediate_outputs_and_op_names() + self._get_aot_intermediate_outputs_and_op_names( + disable_debug_handle_valdiation + ) ) if len(aot_intermediate_outputs) == 0 or len(aot_debug_handle_to_op_names) == 0: raise ValueError( @@ -1434,15 +1448,18 @@ def calculate_numeric_gap(self, distance: str = "MSE"): mapping = map_runtime_aot_intermediate_outputs( aot_intermediate_outputs, runtime_intermediate_outputs ) - metric = distance.strip().upper() - if metric == "MSE": - comparator = MSEComparator() - elif metric == "L1": - comparator = L1Comparator() - elif metric == "SNR": - comparator = SNRComparator() + if isinstance(distance, NumericalComparatorBase): + comparator = distance else: - raise ValueError(f"Unsupported distance metric {distance!r}") + metric = distance.strip().upper() + if metric == "MSE": + comparator = MSEComparator() + elif metric == "L1": + comparator = L1Comparator() + elif metric == "SNR": + comparator = SNRComparator() + else: + raise ValueError(f"Unsupported distance metric {distance!r}") rows = [] for (aot_debug_handle, aot_intermediate_output), ( @@ -1451,6 +1468,15 @@ def calculate_numeric_gap(self, distance: str = "MSE"): ) in mapping.items(): if aot_intermediate_output is None or runtime_intermediate_output is None: continue + # If aot outputs length is > 1 then comparison fails since we dont really have + # any instances where runtime intermediate output is a tuple or list + # This does not happen when edge dialect program is reference for comparison + # but happens in aten graph where ops like unbind remain undecomposed + if ( + isinstance(aot_intermediate_output, Sequence) + and len(aot_intermediate_output) > 1 + ): + continue rows.append( { "aot_ops": find_op_names( diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index ee7ebb2f5ea..a3933ffb993 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -657,13 +657,21 @@ def _combine_aot_overlapped_intermediate_outputs( # Combine all AOT debug_handles into a list aot_combined_debug_handle = [t[0] for t in aot_map.keys()] - if set(aot_combined_debug_handle) != set(runtime_debug_handle): - # AOT combined debug_handle and runtime debug_handle do not match. + # Reason we dont check for exact match: + # in some experiments where we want to rewrite the aten graph that was + # lowered, so as to use custom ops like int4_matmul, we lose some nodes + # on the graph and thus lose some debug handles. And we dont find + # exact match within connected components. + if not set(aot_combined_debug_handle).issubset(set(runtime_debug_handle)): + # AOT combined debug_handle is not a subset of runtime debug_handle. 
return (-1,), None # Pick the last intermediate output last_int = runtime_debug_handle[negative_index] key = (last_int,) + if key not in aot_map: + # If the last intermediate output is not in the AOT map, return None + return (-1,), None return runtime_debug_handle, aot_map[key] @@ -965,7 +973,7 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]: # Ensure both sequences have the same length if len(a) != len(b): raise ValueError( - f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison." + f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison. len(a): {len(a)} len(b): {len(b)}." ) # Compare each element in the sequences and return the list of results @@ -990,6 +998,9 @@ def get_ancestor_node_identifiers(node: Node) -> List[str]: Returns: the identifiers of all its ancestor nodes """ + if FROM_NODE_KEY not in node.meta: + return [] + node_source = node.meta[FROM_NODE_KEY] node_source = node_source[-1] ancestor_node_ids: List[str] = [f"{node_source.name}.{str(node_source.graph_id)}"] @@ -1056,11 +1067,16 @@ def _find_n_match_node(node: Node) -> None: if node.op in ("output", "placeholder"): return node_id = f"{node.name}.{exported_program_graph_id}" - parent_node_id = get_parent_node_identifier(node) + parent_node_ids = get_ancestor_node_identifiers(node) if node_id in ancestors_node_id_to_debug_handle: matched_debug_handles.add(ancestors_node_id_to_debug_handle[node_id]) - elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle: - matched_debug_handles.add(ancestors_node_id_to_debug_handle[parent_node_id]) + elif parent_node_ids: + for parent_node_id in parent_node_ids: + if parent_node_id in ancestors_node_id_to_debug_handle: + matched_debug_handles.add( + ancestors_node_id_to_debug_handle[parent_node_id] + ) + break bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node) return matched_debug_handles @@ -1094,15 +1110,17 @@ def _equip_debug_handle(node: Node) -> None: if node.op in ("output", "placeholder"): return node_id = f"{node.name}.{exported_program_graph_id}" - parent_node_id = get_parent_node_identifier(node) + parent_node_ids = get_ancestor_node_identifiers(node) + node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE if node_id in ancestors_node_id_to_debug_handle: node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[node_id] - elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle: - node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[ - parent_node_id - ] - else: - node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE + elif parent_node_ids: + for parent_node_id in parent_node_ids: + if parent_node_id in ancestors_node_id_to_debug_handle: + node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[ + parent_node_id + ] + break bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle) @@ -1111,6 +1129,7 @@ def propagate_back_debug_handle( exported_program: ExportedProgram, exported_program_graph_id: int, edge_dialect_program: ExportedProgram, + disable_debug_handle_valdiation: bool = False, ) -> bool: """ Propagate debug handle from edge dialect program back to the exported program while maintain the correctness @@ -1124,6 +1143,10 @@ def propagate_back_debug_handle( Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1. The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping. 
+ disable_debug_handle_validation is used to avoid _verify_graph_match() in case of debug handle mismatch. + This can happen when we are comparing against aten graph in which case not all debug handles are matched + in aten graph. Example of this is when symbolic shape nodes are re-exported. + Return: True if every debug handle in the edge dialect program has a corresponding node in the exported program, otherwise, return False. """ # 1. Extract mapping from ancestor node identifiers to debug handles @@ -1137,7 +1160,9 @@ def propagate_back_debug_handle( ) # 3. Verify if every debug handle in edge dialect program has a corresponding node - if not _verify_graph_match(edge_dialect_program, matched_debug_handles): + if not disable_debug_handle_valdiation and not _verify_graph_match( + edge_dialect_program, matched_debug_handles + ): return False # 4. Apply debug handles to the exported program diff --git a/devtools/inspector/numerical_comparator/__init__.py b/devtools/inspector/numerical_comparator/__init__.py index daacb5496ae..0090c50025f 100644 --- a/devtools/inspector/numerical_comparator/__init__.py +++ b/devtools/inspector/numerical_comparator/__init__.py @@ -13,9 +13,13 @@ MSEComparator, ) +from executorch.devtools.inspector.numerical_comparator.numerical_comparator_base import ( + NumericalComparatorBase, +) + from executorch.devtools.inspector.numerical_comparator.snr_numerical_comparator import ( SNRComparator, ) -__all__ = ["L1Comparator", "MSEComparator", "SNRComparator"] +__all__ = ["L1Comparator", "MSEComparator", "SNRComparator", "NumericalComparatorBase"] diff --git a/devtools/inspector/numerical_comparator/inspector_numerical_comparator_base.py b/devtools/inspector/numerical_comparator/inspector_numerical_comparator_base.py deleted file mode 100644 index b6dac7e1970..00000000000 --- a/devtools/inspector/numerical_comparator/inspector_numerical_comparator_base.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - - -from abc import ABC, abstractmethod -from typing import Any - - -class InspectorNumericalComparatorBase(ABC): - @abstractmethod - def compare(self, a: Any, b: Any) -> float: - """Compare two intermediate output and return a result. - - This method should be overridden by subclasses to provide custom comparison logic. - - Args: - a: The first intermediate output to compare. - b: The second intermediate output to compare. - - Returns: - A numerical result indicating the comparison outcome. 
- """ - pass diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index a3afed07ed8..93a74915e84 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -10,6 +10,7 @@ import os import random import statistics +import sys import tempfile import unittest from contextlib import redirect_stdout @@ -681,7 +682,7 @@ def test_calculate_numeric_gap(self): aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} - inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda: ( + inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: ( aot_intermediate_outputs, aot_debug_handle_to_op_name, ) @@ -721,6 +722,79 @@ def test_calculate_numeric_gap(self): # gap should equal 3.0 self.assertEqual(row["gap"][0], 3.0) + def test_calculate_numeric_gap_with_custom_comparator(self): + """Test calculate_numeric_gap with a custom NumericalComparatorBase implementation.""" + from executorch.devtools.inspector.numerical_comparator import ( + NumericalComparatorBase, + ) + + # Create a custom comparator that returns the max absolute difference + class MaxAbsDiffComparator(NumericalComparatorBase): + def compare(self, a, b): + if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): + return torch.max(torch.abs(a - b)).item() + return abs(a - b) + + # Create a context manager to patch functions called by Inspector.__init__ + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ), patch.object( + _inspector, "gen_graphs_from_etrecord" + ): + # Call the constructor of Inspector + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=ETRECORD_PATH, + ) + + aot_intermediate_outputs = { + (0,): torch.tensor([1.0, 2.0, 3.0]), + (1,): torch.tensor([4.0, 5.0, 6.0]), + } + + runtime_intermediate_outputs = { + (0,): ([torch.tensor([2.0, 1.0, 5.0])], 1), + (1,): ([torch.tensor([3.0, 6.0, 5.0])], 1), + } + + aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} + runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} + + inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: ( + aot_intermediate_outputs, + aot_debug_handle_to_op_name, + ) + inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( + lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) + ) + + # Create custom comparator instance + custom_comparator = MaxAbsDiffComparator() + + # Test with custom comparator + df = inspector_instance.calculate_numeric_gap(distance=custom_comparator) + self.assertIsInstance(df, pd.DataFrame) + self.assertEqual(len(df), 2) + cols = set(df.columns) + expected_cols = { + "aot_ops", + "aot_intermediate_output", + "runtime_ops", + "runtime_intermediate_output", + "gap", + } + self.assertEqual(cols, expected_cols) + + # Verify the custom comparator logic + # For (0,): max(|[1.0, 2.0, 3.0] - [2.0, 1.0, 5.0]|) = max([1.0, 1.0, 2.0]) = 2.0 + self.assertEqual(df.iloc[0]["gap"][0], 2.0) + # For (1,): max(|[4.0, 5.0, 6.0] - [3.0, 6.0, 5.0]|) = max([1.0, 1.0, 1.0]) = 1.0 + self.assertEqual(df.iloc[1]["gap"][0], 1.0) + @unittest.skip("ci config values are not propagated") def test_intermediate_tensor_comparison_with_torch_export(self): """Test intermediate tensor comparison using torch.export.export and 
to_edge_transform_and_lower.""" @@ -838,6 +912,123 @@ def _gen_random_runtime_output( ) -> List[Union[None, List[torch.Tensor], bool, float, int, str, torch.Tensor]]: return [torch.randn(RAW_DATA_SIZE)] + @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows") + def test_disable_debug_handle_validation_with_symbolic_shapes(self): + """ + Test that demonstrates the issue with symbolic shape related nodes losing from_node info + during dynamic shape based export, and shows how disable_debug_handle_valdiation parameter + in propagate_back_debug_handle allows validation to be bypassed. + """ + from executorch.devtools.inspector._inspector_utils import ( + propagate_back_debug_handle, + ) + + class SymbolicShapeModel(torch.nn.Module): + """Model that will have symbolic shape related operations after export.""" + + def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + # This will create symbolic shape nodes during dynamic export + batch_size = x.shape[0] + x = x + torch.rand((batch_size, 1)) + # Masking operation that creates gt/lt nodes + valid_mask = mask > 0.5 + x = torch.where(valid_mask, x, torch.zeros_like(x)) + return x + + # Create model and dynamic inputs + model = SymbolicShapeModel() + batch_size = 2 + seq_len = 4 + x = torch.randn(batch_size, seq_len) + mask = torch.rand(batch_size, seq_len) + example_inputs = (x, mask) + + # Export with dynamic shapes to create symbolic shape related nodes + dynamic_shapes = { + "x": {0: torch.export.Dim("batch_size", min=1, max=10)}, + "mask": {0: torch.export.Dim("batch_size", min=1, max=10)}, + } + + exported_program = torch.export.export( + model, example_inputs, dynamic_shapes=dynamic_shapes, strict=True + ) + + """ + In this case origina aten graph has sym_size_int_2 node but when we look at + nodes metadata in edge_program_manager, its sym_size node's from_node says + sym_size_int_3 which is not in the original aten graph. 
+ """ + # Create edge program - this is where from_node info can be lost for symbolic shape nodes + edge_program_manager: EdgeProgramManager = to_edge(exported_program) + edge_program_manager_copy = copy.deepcopy(edge_program_manager) + et_program_manager: ExecutorchProgramManager = ( + edge_program_manager.to_executorch() + ) + + with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file: + etrecord_path = tmp_file.name + + # Generate ETRecord with the exported program (aten graph) + generate_etrecord( + etrecord_path, + edge_program_manager_copy, + et_program_manager, + exported_program=exported_program, + ) + + # Create Inspector and get etrecord + with patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object(EventBlock, "_gen_from_etdump"): + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=etrecord_path, + ) + + # Extract the necessary values from the inspector's etrecord + exported_program_from_etrecord = ( + inspector_instance._etrecord.exported_program + ) + export_graph_id = inspector_instance._etrecord.export_graph_id + edge_dialect_program = inspector_instance._etrecord.edge_dialect_program + + # Ensure we have all the necessary components + self.assertIsNotNone(exported_program_from_etrecord) + self.assertIsNotNone(export_graph_id) + self.assertIsNotNone(edge_dialect_program) + + # Test propagate_back_debug_handle with validation enabled (should fail or return False) + # This demonstrates the issue with symbolic shape nodes losing from_node info + validation_enabled_result = propagate_back_debug_handle( + exported_program_from_etrecord, + export_graph_id, + edge_dialect_program, + disable_debug_handle_valdiation=False, + ) + + # With validation enabled, it should return False when from_node info is lost + self.assertFalse( + validation_enabled_result, + "propagate_back_debug_handle should return False when validation is enabled " + "and symbolic shape nodes lose from_node info", + ) + + # Test propagate_back_debug_handle with validation disabled (should succeed) + # This shows how the disable_debug_handle_valdiation flag allows the function to work + validation_disabled_result = propagate_back_debug_handle( + exported_program_from_etrecord, + export_graph_id, + edge_dialect_program, + disable_debug_handle_valdiation=True, + ) + + # With validation disabled, it should return True even when from_node info is lost + self.assertTrue( + validation_disabled_result, + "propagate_back_debug_handle should return True when validation is disabled, " + "allowing best effort comparison even when from_node info is lost", + ) + def _gen_random_events(self) -> List[Event]: events = [] for i in range(2): diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index 26fe38acfac..8c4bb4b38b9 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -334,13 +334,15 @@ def test_map_runtime_aot_intermediate_outputs_no_overlaps(self): self.assertEqual(actual, expected) def test_map_runtime_aot_intermediate_outputs_partial_match(self): - # Partial match between aot and runtime debug_handles will return empty + # Partial match between aot and runtime debug_handles will return + # matching debug handles from runtime aot_intermediate_outputs = {(2,): 100, (9,): 300} runtime_intermediate_outputs = {(2, 3): (200, 1), (8, 9): (300, 1)} actual = map_runtime_aot_intermediate_outputs( aot_intermediate_outputs, runtime_intermediate_outputs ) - 
expected = {} + # Since the runtime output debug handle of 9 is there in aot debug handle + expected = {((8, 9), 300): ((8, 9), 300)} self.assertEqual(actual, expected) def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime(self): diff --git a/devtools/scripts/profile_model.sh b/devtools/scripts/profile_model.sh index 8697c97cd02..a4d50f6c6fc 100755 --- a/devtools/scripts/profile_model.sh +++ b/devtools/scripts/profile_model.sh @@ -7,7 +7,7 @@ #!/bin/bash -# ExecutorTorch Model Profiling Script +# ExecuTorch Model Profiling Script # # This script automates the process of building executor_runner with profiling enabled, # running model inference with ETDump collection, and generating CSV profiling reports. diff --git a/devtools/visualization/model_explorer_styles/cluster_highlight_style.json b/devtools/visualization/model_explorer_styles/cluster_highlight_style.json new file mode 100644 index 00000000000..cced07d6a55 --- /dev/null +++ b/devtools/visualization/model_explorer_styles/cluster_highlight_style.json @@ -0,0 +1,236 @@ + [ + { + "queries": [ + { + "type": "node_type", + "nodeType": "op_nodes" + }, + { + "type": "regex", + "queryRegex": "quantize", + "matchTypes": [ + "title" + ] + } + ], + "nodeType": "op_nodes", + "styles": { + "node_bg_color": { + "id": "node_bg_color", + "value": "#dce9e9" + } + }, + "version": "v2" + }, + { + "queries": [ + { + "type": "node_type", + "nodeType": "op_nodes" + }, + { + "type": "regex", + "queryRegex": "aten.", + "matchTypes": [ + "title" + ] + } + ], + "nodeType": "op_nodes", + "styles": { + "node_bg_color": { + "id": "node_bg_color", + "value": "#b4e3f5" + } + }, + "version": "v2" + }, + { + "queries": [ + { + "type": "node_type", + "nodeType": "layer_nodes" + }, + { + "type": "regex", + "queryRegex": "cluster", + "matchTypes": [ + "title" + ] + } + ], + "nodeType": "op_nodes", + "styles": { + "node_bg_color": { + "id": "node_bg_color", + "value": "#d0eae9" + }, + "node_border_color": { + "id": "node_border_color", + "value": "#ffffff" + } + }, + "version": "v2" + }, + { + "queries": [ + { + "type": "node_type", + "nodeType": "layer_nodes" + }, + { + "type": "regex", + "queryRegex": "partition 0", + "matchTypes": [ + "title" + ] + } + ], + "nodeType": "op_nodes", + "styles": { + "node_bg_color": { + "id": "node_bg_color", + "value": "#fff1d5" + } + }, + "version": "v2" + }, + { + "queries": [ + { + "type": "node_type", + "nodeType": "layer_nodes" + }, + { + "type": "regex", + "queryRegex": "partition 1", + "matchTypes": [ + "title" + ] + } + ], + "nodeType": "op_nodes", + "styles": { + "node_bg_color": { + "id": "node_bg_color", + "value": "#fdffcc" + } + }, + "version": "v2" + }, + { + "queries": [ + { + "type": "node_type", + "nodeType": "layer_nodes" + }, + { + "type": "regex", + "queryRegex": "partition 2", + "matchTypes": [ + "title" + ] + } + ], + "nodeType": "op_nodes", + "styles": { + "node_bg_color": { + "id": "node_bg_color", + "value": "#ccffcc" + } + }, + "version": "v2" + }, + { + "queries": [ + { + "type": "node_type", + "nodeType": "layer_nodes" + }, + { + "type": "regex", + "queryRegex": "partition 3", + "matchTypes": [ + "title" + ] + } + ], + "nodeType": "op_nodes", + "styles": { + "node_bg_color": { + "id": "node_bg_color", + "value": "#ccffff" + } + }, + "version": "v2" + }, + { + "queries": [ + { + "type": "node_type", + "nodeType": "layer_nodes" + }, + { + "type": "regex", + "queryRegex": "partition 4", + "matchTypes": [ + "title" + ] + } + ], + "nodeType": "op_nodes", + "styles": { + "node_bg_color": { + 
"id": "node_bg_color", + "value": "#ffc6e2" + } + }, + "version": "v2" + }, + { + "queries": [ + { + "type": "node_type", + "nodeType": "layer_nodes" + }, + { + "type": "regex", + "queryRegex": "partition 5", + "matchTypes": [ + "title" + ] + } + ], + "nodeType": "op_nodes", + "styles": { + "node_bg_color": { + "id": "node_bg_color", + "value": "#ffcaff" + } + }, + "version": "v2" + }, + { + "queries": [ + { + "type": "node_type", + "nodeType": "layer_nodes" + }, + { + "type": "regex", + "queryRegex": "partition 6", + "matchTypes": [ + "title" + ] + } + ], + "nodeType": "op_nodes", + "styles": { + "node_bg_color": { + "id": "node_bg_color", + "value": "#d7d7ff" + } + }, + "version": "v2" + } +] \ No newline at end of file diff --git a/devtools/visualization/visualization_utils.py b/devtools/visualization/visualization_utils.py index b21a953f4d2..b76d164b61b 100644 --- a/devtools/visualization/visualization_utils.py +++ b/devtools/visualization/visualization_utils.py @@ -1,24 +1,32 @@ # Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - +import json import subprocess import time from typing import Any, Callable, Type -from executorch.exir import EdgeProgramManager, ExecutorchProgramManager -from executorch.exir.program._program import _update_exported_program_graph_module +from executorch.exir import EdgeProgramManager, ExecutorchProgramManager # type: ignore +from executorch.exir.program._program import ( # type: ignore + _update_exported_program_graph_module, +) + from torch._export.verifier import Verifier -from torch.export.exported_program import ExportedProgram -from torch.fx import GraphModule +from torch.export.exported_program import ExportedProgram # type: ignore +from torch.fx import GraphModule, Node # type: ignore try: from model_explorer import config, consts, visualize_from_config # type: ignore + from model_explorer.config import ModelExplorerConfig # type: ignore + from model_explorer.pytorch_exported_program_adater_impl import ( # type: ignore + PytorchExportedProgramAdapterImpl, + ) except ImportError: print( - "Error: 'model_explorer' is not installed. Install using devtools/install_requirement.sh" + "Error: 'model_explorer' is not installed. Install using devtools/install_requirements.sh" ) raise @@ -139,6 +147,134 @@ def visualize_model_explorer( ) +def _save_model_as_json(cur_config: ModelExplorerConfig, file_name: str): + """Save the graphs stored in the `cur_config` in JSON format, which can be loaded by the Model Explorer GUI. + + :param cur_config: ModelExplorerConfig containing the graph for visualization. + :param file_name: Name of the JSON file for storage. + """ + # Extract the graphs from the config file. + graphs_list = json.loads(cur_config.get_transferrable_data()["graphs_list"]) + graphs = json.loads(graphs_list[0])["graphs"] + + # The returned dictionary is missing the `collectionLabel` entry. Add it manually. + for graph in graphs: + graph["collectionLabel"] = "Executorch" + + # Create the JSON according to the structure required by the Model Explorer GUI. + json_data = { + "label": "Executorch", + "graphs": graphs, + "graphsWithLevel": [ + {"graph": graph, "level": level} for level, graph in enumerate(graphs) + ], + } + + # Store the JSON. 
+ with open(file_name, "w") as f: + json.dump(json_data, f) + + +def visualize_with_clusters( + exported_program: ExportedProgram, + json_file_name: str | None = None, + reuse_server: bool = False, + get_node_partition_name: Callable[[Node], str | None] = lambda node: node.meta.get( + "delegation_tag", None + ), + get_node_qdq_cluster_name: Callable[ + [Node], str | None + ] = lambda node: node.meta.get("cluster", None), + **kwargs, +): + """Visualize exported programs using the Model Explorer. The QDQ clusters and individual partitions are highlighted. + + To install the Model Explorer, run `devtools/install_requirements.sh`. + To display a stored json file, first launch the Model Explorer server by running `model-explorer`, and then + use the GUI to open the json. + + NOTE: FireFox seems to have issues rendering the graphs. Other browsers seem to work well. + + :param exported_program: Program to visualize. + :param json_file_name: If not None, a JSON of the visualization will be stored in the provided file. The JSON can + then be opened in the Model Explorer GUI later. + If None, a Model Explorer instance will be launched with the model visualization. + :param reuse_server: If True, an existing instance of the Model Explorer server will be used (if one exists). + Otherwise, a new instance at a separate port will start. + :param get_node_partition_name: Function which takes a `Node` and returns a string with the name of the partition + the `Node` belongs to, or `None` if it has no partition. + :param get_node_qdq_cluster_name: Function which takes a `Node` and returns a string with the name of the QDQ + cluster the `Node` belongs to, or `None` if it has no cluster. + :param kwargs: Additional kwargs for the `visualize_from_config()` function. + """ + + cur_config = config() + + # Create a Model Explorer graph from the `exported_program`. + adapter = PytorchExportedProgramAdapterImpl( + exported_program, consts.DEFAULT_SETTINGS + ) + graphs = adapter.convert() + + nodes = list(exported_program.graph.nodes) + explorer_nodes = graphs["graphs"][0].nodes + + # Highlight QDQ clusters and individual partitions. + known_partition_names = [] + for explorer_node, node in zip(explorer_nodes, nodes, strict=True): + # Generate the `namespace` for the node, which will determine node grouping in the visualizer. + # The character "/" is used as a divider when a node has multiple namespaces. + namespace = "" + + if (partition_name := get_node_partition_name(node)) is not None: + # If the nodes are tagged by the partitioner, highlight the tagged groups. + + # Create a custom naming for the partitions ("partition " where `i` = 0, 1, 2, ...). + if partition_name not in known_partition_names: + known_partition_names.append(partition_name) + partition_id = known_partition_names.index(partition_name) + + safe_partition_name = partition_name.replace( + "/", ":" + ) # Avoid using unwanted "/". + namespace += f"partition {partition_id} ({safe_partition_name})" + + if (cluster_name := get_node_qdq_cluster_name(node)) is not None: + # Highlight the QDQ cluster. + + # Add a separator, in case the namespace already contains the `partition`. + if len(namespace) != 0: + namespace += "/" + + # Create a custom naming for the clusters ("cluster ()"). + safe_cluster_name = cluster_name.replace( + "/", ":" + ) # Avoid using unwanted "/". + namespace += f"cluster ({safe_cluster_name})" + + explorer_node.namespace = namespace + + # Store the modified graph in the config. 
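+    # (The `graphs://<name>/<index>` URL constructed below identifies this
+    # in-memory entry by its position in `graphs_list`; no file is written
+    # unless `json_file_name` is given.)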
+ graphs_index = len(cur_config.graphs_list) + cur_config.graphs_list.append(graphs) + name = "Executorch" + model_source: config.ModelSource = {"url": f"graphs://{name}/{graphs_index}"} + cur_config.model_sources.append(model_source) + + if json_file_name is not None: + # Just save the visualization. + _save_model_as_json(cur_config, json_file_name) + + else: + # Start the ModelExplorer server, and visualize the graph in the browser. + if reuse_server: + cur_config.set_reuse_server() + visualize_from_config( + cur_config, + **kwargs, + ) + + def visualize_graph( graph_module: GraphModule, exported_program: ExportedProgram | EdgeProgramManager | ExecutorchProgramManager, diff --git a/devtools/visualization/visualization_utils_test.py b/devtools/visualization/visualization_utils_test.py index 4f44241518f..0d470a7f359 100644 --- a/devtools/visualization/visualization_utils_test.py +++ b/devtools/visualization/visualization_utils_test.py @@ -24,7 +24,7 @@ from model_explorer.config import ModelExplorerConfig # type: ignore except ImportError: print( - "Error: 'model_explorer' is not installed. Install using devtools/install_requirement.sh" + "Error: 'model_explorer' is not installed. Install using devtools/install_requirements.sh" ) raise diff --git a/docs/.gitignore b/docs/.gitignore index 980fbad8320..b9b2a3753e5 100644 --- a/docs/.gitignore +++ b/docs/.gitignore @@ -3,3 +3,4 @@ /sphinxbuild_py /sphinxbuild_cpp /src +source/sg_execution_times.rst diff --git a/docs/Makefile b/docs/Makefile index 219998d4b4d..c4f5e571ff8 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -10,6 +10,15 @@ BUILDDIR = _build # Put it first so that "make" without argument is like "make help". +html-noplot: + $(SPHINXBUILD) -D plot_gallery=0 -b html $(SPHINXOPTS) "$(SOURCEDIR)" "$(BUILDDIR)/html" + +html-stable: + # Stable differs from 'make html' in that it shows the release version + # instead of "main (version)" in the docs and version switcher. + # See conf.py for more details. + RELEASE=true $(MAKE) html + help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/README.md b/docs/README.md index e30decb9362..845267b32f6 100644 --- a/docs/README.md +++ b/docs/README.md @@ -43,7 +43,7 @@ To build the documentation locally: git clone -b viable/strict https://github.com/pytorch/executorch.git && cd executorch ``` -1. If you don't have it already, start either a Python virtual envitonment: +1. If you don't have it already, start either a Python virtual environment: ```bash python3 -m venv .venv && source .venv/bin/activate && pip install --upgrade pip @@ -111,7 +111,7 @@ You can use the variables in both regular text and code blocks. ## Including READMEs to the Documentation Build You might want to include some of the `README.md` files from various directories -in this repositories in your documentation build. To do that, create an `.md` +in this repository in your documentation build. To do that, create an `.md` file and use the `{include}` directive to insert your `.md` files. Example: ```` @@ -177,7 +177,7 @@ file: ```` In the `index.md` file, I would add `tutorials/selective-build-tutorial` in -both the `toctree` and the `cusotmcarditem` sections. +both the `toctree` and the `customcarditem` sections. 
# Auto-generated API documentation diff --git a/docs/source/_static/css/custom.css b/docs/source/_static/css/custom.css deleted file mode 100644 index 3ae9585701e..00000000000 --- a/docs/source/_static/css/custom.css +++ /dev/null @@ -1,194 +0,0 @@ -/** - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -/* sphinx-design styles for cards/tabs -*/ -:root { - --sd-color-info: #ee4c2c; - --sd-color-primary: #6c6c6d; - --sd-color-primary-highlight: #f3f4f7; - --sd-color-card-border-hover: #ee4c2c; - --sd-color-card-border: #f3f4f7; - --sd-color-card-background: #fff; - --sd-color-card-text: inherit; - --sd-color-card-header: transparent; - --sd-color-card-footer: transparent; - --sd-color-tabs-label-active: #ee4c2c; - --sd-color-tabs-label-hover: #ee4c2c; - --sd-color-tabs-label-inactive: #6c6c6d; - --sd-color-tabs-underline-active: #ee4c2c; - --sd-color-tabs-underline-hover: #fabdbd; - --sd-color-tabs-underline-inactive: transparent; - --sd-color-tabs-overline: rgb(222, 222, 222); - --sd-color-tabs-underline: rgb(222, 222, 222); -} - -.sd-text-info { - color: #ee4c2c; -} - -.sd-card-img-top { - background: #ee4c2c; - height: 5px !important; -} - -.sd-card { - position: relative; - background-color: #fff; - opacity: 1.0; - border-radius: 0px; - width: 30%; - border: none; - padding-bottom: 0px; -} - - -.sd-card-img:hover { - opacity: 1.0; - background-color: #f3f4f7; -} - - -.sd-card:after { - display: block; - opacity: 1; - content: ''; - border-bottom: solid 1px #ee4c2c; - background-color: #fff; - transform: scaleX(0); - transition: transform .250s ease-in-out; - transform-origin: 0% 50%; -} - -.sd-card:hover { - background-color: #fff; - opacity: 1; - border-top: 1px solid #f3f4f7; - border-left: 1px solid #f3f4f7; - border-right: 1px solid #f3f4f7; -} - -.sd-card:hover:after { - transform: scaleX(1); -} - -.card-prerequisites:hover { - transition: none; - border: none; -} - -.card-prerequisites:hover:after { - transition: none; - transform: none; -} - -.card-prerequisites:after { - display: block; - content: ''; - border-bottom: none; - background-color: #fff; - transform: none; - transition: none; - transform-origin: none; -} - - -details.sd-dropdown { - font-weight: 300; - width: auto; -} - -details.sd-dropdown:after { - border: none; - transition: none; -} - -details.sd-dropdown:hover { - border: none; - transition: none; -} - -details.sd-dropdown .sd-summary-content { - font-weight: 300; -} - -details.sd-dropdown .highlight .n { - font-weight: normal; -} - -.et-page-column1 { - float: left; - width: 70%; - font-size: 1rem; -} - -.et-page-column2 { - float: right; - padding-top: 40px; - padding-left: 60px; - padding-right: 60px; - padding-bottom: 60px; - width: 30%; -} - -.et-page-column-row:after { - content: ""; - display: table; - clear: both; -} - -/* For screens smaller than 768px (typical mobile devices) */ -@media screen and (max-width: 768px) { - .et-page-column1, .et-page-column2 { - float: none; /* Remove floats */ - width: 100%; /* Full width for both columns */ - padding: 0; - font-size: 1rem; - } - - .et-page-column2 img { - display: none; - } - .et-page-column-row:after { - content: ""; - display: table; - clear: both; - } -} - -article.pytorch-article .class .method dt { - border-top: none; -} - -article.pytorch-article .class .simple dt { - border-top: none; -} - -article.pytorch-article .function 
dt.sig { - border-top: none; -} - -/* styles needed for 3rd level left nav */ - -.pytorch-left-menu ul, .pytorch-right-menu ul { - margin-left: 1.2em; -} - -.pytorch-left-menu li.toctree-l2.current > a { - color: #e44c2c; -} - -/* The next two styles enable normal hihglighting in the third level nav -in right side bar.*/ -#pytorch-right-menu .side-scroll-highlight { - color: #6c6c6d; -} - -#pytorch-right-menu a.reference.internal.side-scroll-highlight-local { - color: #ee4c2c; -} diff --git a/docs/source/_static/css/progress-bar.css b/docs/source/_static/css/progress-bar.css deleted file mode 100644 index 9b3aeb9d301..00000000000 --- a/docs/source/_static/css/progress-bar.css +++ /dev/null @@ -1,117 +0,0 @@ -/** - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -.progress-bar-wrapper { - margin-top: auto; - display: flex; - justify-content: space-between; - margin-bottom: 20px; - position: sticky; - top: 0; - background: white; - padding-top: 20px; - padding-bottom: 20px; - z-index: 2; -} - -.progress-bar-item { - position: relative; - display: flex; - flex-direction: column; - align-items: center; - flex: 1; - - @media (max-width: 768px) { - font-size: 12px; - } -} - -.progress-bar-item::before { - position: absolute; - content: ""; - border-bottom: 2px solid #ccc; - width: 100%; - top: 20px; - left: -50%; - z-index: 2; -} - -.progress-bar-item::after { - position: absolute; - content: ""; - border-bottom: 2px solid #ccc; - width: 100%; - top: 20px; - left: 50%; - z-index: 2; -} - -.progress-bar-item .step-number { - position: relative; - z-index: 5; - display: flex; - justify-content: center; - align-items: center; - width: 40px; - height: 40px; - border-radius: 50%; - border-color: #812CE5; - border-style: solid; - border-width: 1px; - color: #812CE5; - background: #fff; - margin-bottom: 6px; -} - -.progress-bar-item.active { - font-weight: bold; -} - -.progress-bar-item.completed .step-number { - background-color: #812CE5; - color: white; -} - -.progress-bar-item.completed::after { - position: absolute; - content: ""; - border-bottom: 2px solid #812CE5; - width: 100%; - top: 20px; - left: 50%; - z-index: 3; -} - -.progress-bar-item:first-child::before { - content: none; -} - -.progress-bar-item:last-child::after { - content: none; -} - -.progress-bar-item a:link { - color: #262626 !important; -} - -.step-caption:first-child { - margin-left: 10px; -} - -.step-caption { - text-align: center; -} - -.step-caption a:link { - color: #262626 !important; -} - -.step-caption a:hover { - color: #ee4c2c; - text-decoration: underline; -} diff --git a/docs/source/_static/img/ExecuTorch-Logo-cropped.svg b/docs/source/_static/img/ExecuTorch-Logo-cropped.svg deleted file mode 100644 index 9e0ef52fbd8..00000000000 --- a/docs/source/_static/img/ExecuTorch-Logo-cropped.svg +++ /dev/null @@ -1,57 +0,0 @@ - - - - - - - - - - - diff --git a/docs/source/_static/img/executorch-chip-logo-circle-16.png b/docs/source/_static/img/executorch-chip-logo-circle-16.png new file mode 100644 index 00000000000..a3966ae27db Binary files /dev/null and b/docs/source/_static/img/executorch-chip-logo-circle-16.png differ diff --git a/docs/source/_static/img/executorch-chip-logo-circle-32.png b/docs/source/_static/img/executorch-chip-logo-circle-32.png new file mode 100644 index 00000000000..83f1018a76c Binary files /dev/null and 
b/docs/source/_static/img/executorch-chip-logo-circle-32.png differ diff --git a/docs/source/_static/img/executorch-chip-logo.svg b/docs/source/_static/img/executorch-chip-logo.svg new file mode 100644 index 00000000000..11e5ed60956 --- /dev/null +++ b/docs/source/_static/img/executorch-chip-logo.svg @@ -0,0 +1,205 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/source/_static/img/swiftpm_xcode1.png b/docs/source/_static/img/swiftpm_xcode1.png index 4e624ed43df..b9acb23847b 100644 Binary files a/docs/source/_static/img/swiftpm_xcode1.png and b/docs/source/_static/img/swiftpm_xcode1.png differ diff --git a/docs/source/_static/img/visualization/1.png b/docs/source/_static/img/visualization/1.png new file mode 100644 index 00000000000..9d76c793492 Binary files /dev/null and b/docs/source/_static/img/visualization/1.png differ diff --git a/docs/source/_static/img/visualization/2.png b/docs/source/_static/img/visualization/2.png new file mode 100644 index 00000000000..0efe1fe8555 Binary files /dev/null and b/docs/source/_static/img/visualization/2.png differ diff --git a/docs/source/_static/img/visualization/3.png b/docs/source/_static/img/visualization/3.png new file mode 100644 index 00000000000..18d45bc4412 Binary files /dev/null and b/docs/source/_static/img/visualization/3.png differ diff --git a/docs/source/_static/img/visualization/4.png b/docs/source/_static/img/visualization/4.png new file mode 100644 index 00000000000..9e20a92d962 Binary files /dev/null and b/docs/source/_static/img/visualization/4.png differ diff --git a/docs/source/_static/img/visualization/5.png b/docs/source/_static/img/visualization/5.png new file mode 100644 index 00000000000..08becaee177 Binary files /dev/null and b/docs/source/_static/img/visualization/5.png differ diff --git a/docs/source/_static/img/visualization/6.png b/docs/source/_static/img/visualization/6.png new file mode 100644 index 00000000000..342b47bc415 Binary files /dev/null and b/docs/source/_static/img/visualization/6.png differ diff --git a/docs/source/_static/img/visualization/visualize_with_clusters_example.png b/docs/source/_static/img/visualization/visualize_with_clusters_example.png new file mode 100644 index 00000000000..938ae24ae48 Binary files /dev/null and b/docs/source/_static/img/visualization/visualize_with_clusters_example.png differ diff --git a/docs/source/_static/js/progress-bar.js b/docs/source/_static/js/progress-bar.js deleted file mode 100644 index 878251cfc60..00000000000 --- a/docs/source/_static/js/progress-bar.js +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -document.addEventListener("DOMContentLoaded", function() { - const steps = Array.from(document.querySelectorAll('.progress-bar-item')); - const h2s = Array.from(document.querySelectorAll('h2')); - - // Populate captions from h2s - h2s.forEach((h2, index) => { - const captionElem = document.getElementById(`caption-${index + 1}`); - if (captionElem) { - captionElem.innerText = h2.innerText; - } - }); - - // Throttle function to optimize performance - function throttle(func, delay) { - let lastCall = 0; - return function() { - const now = Date.now(); - if (now - lastCall < delay) return; - lastCall = now; - func.apply(this, arguments); - } - } - - document.addEventListener("scroll", throttle(function() { - let activeIndex = 0; - let closestDistance = Number.MAX_VALUE; - const totalHeight = document.documentElement.scrollHeight; - const viewportHeight = window.innerHeight; - const scrollBottom = window.scrollY + viewportHeight; - const isAtBottom = totalHeight === scrollBottom; - - h2s.forEach((h2, index) => { - const rect = h2.getBoundingClientRect(); - const distanceToTop = Math.abs(rect.top); - if (distanceToTop < closestDistance) { - closestDistance = distanceToTop; - activeIndex = index; - } - }); - - steps.forEach((step, index) => { - if (isAtBottom) { - step.classList.remove('active'); - step.classList.add('completed'); - } else { - if (index < activeIndex) { - step.classList.remove('active'); - step.classList.add('completed'); - } else if (index === activeIndex) { - step.classList.add('active'); - step.classList.remove('completed'); - } else { - step.classList.remove('active', 'completed'); - } - } - }); - }, 100)); -}); diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html deleted file mode 100644 index 55f91103b35..00000000000 --- a/docs/source/_templates/layout.html +++ /dev/null @@ -1,145 +0,0 @@ -{% extends "!layout.html" %} - -{% block extrahead %} -{% if 'getting-started-setup' in pagename%} - - -{% elif 'compiler-delegate-and-partitioner' in pagename%} - - -{% elif 'xtensa' in pagename%} - - -{% elif 'qualcomm-ai-engine-direct-backend' in pagename%} - - -{% elif 'coreml' in pagename%} - - -{% elif 'mps' in pagename%} - - -{% endif %} -{{ super() }} -{% endblock %} - - -{% block sidebartitle %} - - {% include "searchbox.html" %} -{% endblock %} - -{%- block content %} -{% if 'tutorials' in pagename %} - - - -{% endif %} -{{ super() }} - -{% endblock %} - - - -{% block menu %} - {% if 'singlehtml' not in builder %} - {% set global_toc = toctree(collapse=theme_collapse_navigation|tobool, - includehidden=theme_includehidden|tobool, - titles_only=theme_titles_only|tobool) %} - {% endif %} - {% if global_toc %} - {{ global_toc }} - {% else %} - -
{{ toc }}
- {% endif %} -{% endblock %} - - -{% block footer %} -{{ super() }} - - -{{ super() }} - - -{{ super() }} - -{% endblock %} diff --git a/docs/source/advanced-topics-section.md b/docs/source/advanced-topics-section.md new file mode 100644 index 00000000000..e7b7f5490c6 --- /dev/null +++ b/docs/source/advanced-topics-section.md @@ -0,0 +1,112 @@ +(advanced-topics-section)= + +# Advanced + +Deep dive into ExecuTorch's advanced features for optimization, customization, and integration. + +This section covers advanced concepts for developers who need to customize ExecuTorch for specific use cases, optimize performance, or integrate with custom hardware backends. + +## Quantization & Optimization + +Techniques for model compression and performance optimization. + +**→ {doc}`quantization-optimization` — Quantization strategies and performance optimization** + +Key topics: + +- Quantization strategies and techniques +- Performance profiling and optimization + +## Model Export + +Learn the core ExecuTorch workflow, exporting PyTorch models to the `.pte` format for edge deployment. + +**→ {doc}`using-executorch-export`** - Model Export & Lowering + +Key topics: + +- Export and Lowering Workflow +- Hardware Backend Selection & Optimization +- Dynamic Shapes & Advanced Model Features + + +## Kernel Library + +Deep dive into ExecuTorch's kernel implementation and customization. + +**→ {doc}`kernel-library-advanced` — Kernel library deep dive and customization** + +Key topics: + +- Kernel library architecture +- Custom kernel implementation +- Selective build and optimization + +## Backend & Delegates + +**→ {doc}`backend-delegate-advanced` — Backend delegate integration** + +Key topics: + +- Learn how to integrate Backend Delegate into ExecuTorch and more +- XNNPACK Delegate Internals +- Debugging Delegation + + +## Runtime & Integration + +Advanced runtime features and backend integration. + +**→ {doc}`runtime-integration-advanced` — Runtime customization and backend integration** + +Key topics: + +- Backend delegate implementation +- Platform abstraction layer +- Custom runtime integration + +## Compiler & IR + +Advanced compiler features and intermediate representation details. + +**→ {doc}`compiler-ir-advanced` — Compiler passes and IR specification** + +Key topics: + +- Custom compiler passes +- Memory planning strategies +- Backend dialect and EXIR +- Ops set definition + + +## File Formats + +ExecuTorch file format specifications and internals. 
+ +**→ {doc}`file-formats-advanced` — PTE and PTD file format specifications** + +Key topics: + +- PTE file format internals +- PTD file format specification +- Custom file format handling + +## Next Steps + +After exploring advanced topics: + +- **{doc}`tools-sdk-section`** - Developer tools for debugging and profiling +- **{doc}`api-section`** - Complete API reference documentation + +```{toctree} +:hidden: +:maxdepth: 2 +:caption: Advanced Topics + +quantization-optimization +using-executorch-export +kernel-library-advanced +backend-delegate-advanced +runtime-integration-advanced +compiler-ir-advanced +file-formats-advanced diff --git a/docs/source/android-arm-vgf.md b/docs/source/android-arm-vgf.md new file mode 100644 index 00000000000..51111900add --- /dev/null +++ b/docs/source/android-arm-vgf.md @@ -0,0 +1 @@ +```{include} backends/arm-vgf/arm-vgf-overview.md diff --git a/docs/source/android-backends.md b/docs/source/android-backends.md new file mode 100644 index 00000000000..d4da0966ed9 --- /dev/null +++ b/docs/source/android-backends.md @@ -0,0 +1,28 @@ +(android-backends)= +# Backends + +Available hardware acceleration backends for Android deployment. + +## CPU Acceleration + +- {doc}`android-xnnpack` — XNNPACK CPU acceleration + +## GPU Acceleration + +- {doc}`android-vulkan` — Vulkan GPU acceleration + +## NPU/Accelerator Backends + +- {doc}`android-qualcomm` — Qualcomm AI Engine (NPU) +- {doc}`android-mediatek` — MediaTek NPU acceleration +- {doc}`android-arm-vgf` — ARM VGF Backend +- {doc}`backends/samsung/samsung-overview` — Samsung Exynos NPU + +```{toctree} +:hidden: +android-xnnpack +android-vulkan +android-qualcomm +android-mediatek +android-arm-vgf +backends/samsung/samsung-overview diff --git a/docs/source/android-examples.md b/docs/source/android-examples.md new file mode 100644 index 00000000000..057fd48bc55 --- /dev/null +++ b/docs/source/android-examples.md @@ -0,0 +1,9 @@ +# Examples & Demos + +- [Working with LLMs - Android Examples](https://github.com/meta-pytorch/executorch-examples/blob/main/llm/android/LlamaDemo/README.md) - ExecuTorch Llama Android Demo App +- [Demo Apps](https://github.com/meta-pytorch/executorch-examples/tree/main/dl3/android/DeepLabV3Demo#executorch-android-demo-app) - DeepLab v3 model for image segmentation +- {doc}`tutorial-arm-vgf` — Export a simple PyTorch model for the ExecuTorch VGF backend + +```{toctree} +:hidden: +tutorial-arm-vgf diff --git a/docs/source/android-mediatek.md b/docs/source/android-mediatek.md new file mode 100644 index 00000000000..7034fe439dd --- /dev/null +++ b/docs/source/android-mediatek.md @@ -0,0 +1 @@ +```{include} backends-mediatek.md diff --git a/docs/source/android-qualcomm.md b/docs/source/android-qualcomm.md new file mode 100644 index 00000000000..f484d771a8b --- /dev/null +++ b/docs/source/android-qualcomm.md @@ -0,0 +1 @@ +```{include} backends-qualcomm.md diff --git a/docs/source/android-samsung-exynos.md b/docs/source/android-samsung-exynos.md new file mode 100644 index 00000000000..4c5a470edca --- /dev/null +++ b/docs/source/android-samsung-exynos.md @@ -0,0 +1 @@ +```{include} backends-samsung-exynos.md diff --git a/docs/source/android-section.md b/docs/source/android-section.md new file mode 100644 index 00000000000..a5774352bc1 --- /dev/null +++ b/docs/source/android-section.md @@ -0,0 +1,23 @@ +(android-section)= + +# Android + +Deploy ExecuTorch on Android devices with hardware acceleration support. 
+ +## Quick Start & Integration + +- {doc}`using-executorch-android` — Complete Android integration guide + +## Backends + +- {doc}`android-backends` — Available Android backends and acceleration options + +## Examples & Demos + +- {doc}`android-examples` — Explore Android Examples & Demos + +```{toctree} +:hidden: +using-executorch-android +android-backends +android-examples diff --git a/docs/source/android-vulkan.md b/docs/source/android-vulkan.md new file mode 100644 index 00000000000..aa987835989 --- /dev/null +++ b/docs/source/android-vulkan.md @@ -0,0 +1 @@ +```{include} backends/vulkan/vulkan-overview.md diff --git a/docs/source/android-xnnpack.md b/docs/source/android-xnnpack.md new file mode 100644 index 00000000000..4a85dec946b --- /dev/null +++ b/docs/source/android-xnnpack.md @@ -0,0 +1 @@ +```{include} backends/xnnpack/xnnpack-overview.md diff --git a/docs/source/api-section.md b/docs/source/api-section.md new file mode 100644 index 00000000000..ab2573aefa9 --- /dev/null +++ b/docs/source/api-section.md @@ -0,0 +1,26 @@ +(api-section)= +# API + +In this section, find complete API documentation for ExecuTorch's export, runtime, and extension interfaces. Includes comprehensive references for Python, C++, and Java APIs across all supported platforms. + +- {doc}`export-to-executorch-api-reference` — Export to ExecuTorch API Reference +- {doc}`executorch-runtime-api-reference` — ExecuTorch Runtime API Reference +- {doc}`runtime-python-api-reference` — Runtime Python API Reference +- {doc}`api-life-cycle` — API Life Cycle +- [Android doc →](https://pytorch.org/executorch/main/javadoc/) — Android API Documentation +- {doc}`extension-module` — Extension Module +- {doc}`extension-tensor` — Extension Tensor +- {doc}`running-a-model-cpp-tutorial` — Detailed C++ Runtime APIs Tutorial + +```{toctree} +:hidden: +:maxdepth: 1 +:caption: API Reference + +export-to-executorch-api-reference +executorch-runtime-api-reference +runtime-python-api-reference +api-life-cycle +extension-module +extension-tensor +running-a-model-cpp-tutorial diff --git a/docs/source/archive/backends-cadence-legacy.md b/docs/source/archive/backends-cadence-legacy.md new file mode 100644 index 00000000000..21f60477c63 --- /dev/null +++ b/docs/source/archive/backends-cadence-legacy.md @@ -0,0 +1,238 @@ +# Cadence Xtensa Backend (Legacy / Outdated) + +```{warning} +**⚠️ THIS DOCUMENTATION IS OUTDATED AND NO LONGER MAINTAINED** + +**For current Cadence backend documentation and support:** +- Please refer to the up-to-date documentation in [backends-cadence.md](../backends-cadence.md) +``` + +--- +# Cadence Xtensa Backend + + +In this tutorial we will walk you through the process of getting setup to build ExecuTorch for an Xtensa HiFi4 DSP and running a simple model on it. + +[Cadence](https://www.cadence.com/en_US/home.html) is both a hardware and software vendor, providing solutions for many computational workloads, including to run on power-limited embedded devices. The [Xtensa HiFi4 DSP](https://www.cadence.com/en_US/home/tools/ip/tensilica-ip/hifi-dsps/hifi-4.html) is a Digital Signal Processor (DSP) that is optimized for running audio based neural networks such as wake word detection, Automatic Speech Recognition (ASR), etc. + +In addition to the chip, the HiFi4 Neural Network Library ([nnlib](https://github.com/foss-xtensa/nnlib-hifi4)) offers an optimized set of library functions commonly used in NN processing that we utilize in this example to demonstrate how common operations can be accelerated. 
+ +On top of being able to run on the Xtensa HiFi4 DSP, another goal of this tutorial is to demonstrate how portable ExecuTorch is and its ability to run on a low-power embedded device such as the Xtensa HiFi4 DSP. This workflow does not require any delegates, it uses custom operators and compiler passes to enhance the model and make it more suitable to running on Xtensa HiFi4 DSPs. A custom [quantizer](https://pytorch.org/tutorials/prototype/quantization_in_pytorch_2_0_export_tutorial.html) is used to represent activations and weights as `uint8` instead of `float`, and call appropriate operators. Finally, custom kernels optimized with Xtensa intrinsics provide runtime acceleration. + +::::{grid} 2 +:::{grid-item-card} What you will learn in this tutorial: +:class-card: card-prerequisites +* In this tutorial you will learn how to export a quantized model with a linear operation targeted for the Xtensa HiFi4 DSP. +* You will also learn how to compile and deploy the ExecuTorch runtime with the kernels required for running the quantized model generated in the previous step on the Xtensa HiFi4 DSP. +::: +:::{grid-item-card} Tutorials we recommend you complete before this: +:class-card: card-prerequisites +* [Introduction to ExecuTorch](intro-how-it-works.md) +* [Getting Started](getting-started.md) +* [Building ExecuTorch with CMake](using-executorch-building-from-source.md) +::: +:::: + +```{note} +The linux part of this tutorial has been designed and tested on Ubuntu 22.04 LTS, and requires glibc 2.34. Workarounds are available for other distributions, but will not be covered in this tutorial. +``` + +## Prerequisites (Hardware and Software) + +In order to be able to succesfully build and run ExecuTorch on a Xtensa HiFi4 DSP you'll need the following hardware and software components. + +### Hardware + - [i.MX RT600 Evaluation Kit](https://www.nxp.com/design/development-boards/i-mx-evaluation-and-development-boards/i-mx-rt600-evaluation-kit:MIMXRT685-EVK) + +### Software + - x86-64 Linux system (For compiling the DSP binaries) + - [MCUXpresso IDE](https://www.nxp.com/design/software/development-software/mcuxpresso-software-and-tools-/mcuxpresso-integrated-development-environment-ide:MCUXpresso-IDE) + - This IDE is supported on multiple platforms including MacOS. You can use it on any of the supported platforms as you'll only be using this to flash the board with the DSP images that you'll be building later on in this tutorial. +- [J-Link](https://www.segger.com/downloads/jlink/) + - Needed to flash the board with the firmware images. You can install this on the same platform that you installed the MCUXpresso IDE on. + - Note: depending on the version of the NXP board, another probe than JLink might be installed. In any case, flashing is done using the MCUXpresso IDE in a similar way. + - [MCUXpresso SDK](https://mcuxpresso.nxp.com/en/select?device=EVK-MIMXRT685) + - Download this SDK to your Linux machine, extract it and take a note of the path where you store it. You'll need this later. +- [Xtensa compiler](https://tensilicatools.com/platform/i-mx-rt600/) + - Download this to your Linux machine. This is needed to build ExecuTorch for the HiFi4 DSP. +- For cases with optimized kernels, the [nnlib repo](https://github.com/foss-xtensa/nnlib-hifi4). + +## Setting up Developer Environment + +Step 1. In order to be able to successfully install all the software components specified above users will need to go through the NXP tutorial linked below. 
Although the tutorial itself walks through a Windows setup, most of the steps translate over to a Linux installation too. + +[NXP tutorial on setting up the board and dev environment](https://www.nxp.com/document/guide/getting-started-with-i-mx-rt600-evaluation-kit:GS-MIMXRT685-EVK?section=plug-it-in) + +```{note} +Before proceeding forward to the next section users should be able to succesfullly flash the **dsp_mu_polling_cm33** sample application from the tutorial above and notice output on the UART console indicating that the Cortex-M33 and HiFi4 DSP are talking to each other. +``` + +Step 2. Make sure you have completed the ExecuTorch setup tutorials linked to at the top of this page. + +## Working Tree Description + +The working tree is: + +``` +executorch +├── backends +│ └── cadence +│ ├── aot +│ ├── ops_registration +│ ├── tests +│ ├── utils +│ ├── hifi +│ │ ├── kernels +│ │ ├── operators +│ │ └── third-party +│ │ └── hifi4-nnlib +│ └── [other cadence DSP families] +│ ├── kernels +│ ├── operators +│ └── third-party +│ └── [any required lib] +└── examples + └── cadence + ├── models + └── operators +``` + +***AoT (Ahead-of-Time) Components***: + +The AoT folder contains all of the python scripts and functions needed to export the model to an ExecuTorch `.pte` file. In our case, [export_example.py](https://github.com/pytorch/executorch/blob/main/backends/cadence/aot/export_example.py) is an API that takes a model (nn.Module) and representative inputs and runs it through the quantizer (from [quantizer.py](https://github.com/pytorch/executorch/blob/main/backends/cadence/aot/quantizer/quantizer.py)). Then a few compiler passes, also defined in [quantizer.py](https://github.com/pytorch/executorch/blob/main/backends/cadence/aot/quantizer/quantizer.py), will replace operators with custom ones that are supported and optimized on the chip. Any operator needed to compute things should be defined in [ops_registrations.py](https://github.com/pytorch/executorch/blob/main/backends/cadence/aot/ops_registrations.py) and have corresponding implemetations in the other folders. + +***Operators***: + +The operators folder contains two kinds of operators: existing operators from the [ExecuTorch portable library](https://github.com/pytorch/executorch/tree/main/kernels/portable/cpu) and new operators that define custom computations. The former is simply dispatching the operator to the relevant ExecuTorch implementation, while the latter acts as an interface, setting up everything needed for the custom kernels to compute the outputs. + +***Kernels***: + +The kernels folder contains the optimized kernels that will run on the HiFi4 chip. They use Xtensa intrinsics to deliver high performance at low-power. + +## Build + +In this step, you will generate the ExecuTorch program from different models. You'll then use this Program (the `.pte` file) during the runtime build step to bake this Program into the DSP image. + +***Simple Model***: + +The first, simple model is meant to test that all components of this tutorial are working properly, and simply does an add operation. The generated file is called `add.pte`. + +```bash +cd executorch +python3 -m examples.portable.scripts.export --model_name="add" +``` + +***Quantized Operators***: + +The other, more complex model are custom operators, including: + - a quantized [linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) operation. 
The model is defined [here](https://github.com/pytorch/executorch/blob/main/examples/cadence/operators/test_quantized_linear_op.py#L30). Linear is the backbone of most Automatic Speech Recognition (ASR) models. + - a quantized [conv1d](https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html) operation. The model is defined [here](https://github.com/pytorch/executorch/blob/main/examples/cadence/operators/test_quantized_conv1d_op.py#L40). Convolutions are important in wake word and many denoising models. + +In both cases the generated file is called `CadenceDemoModel.pte`. + +```bash +cd executorch +python3 -m examples.cadence.operators.quantized__op +``` + +***Small Model: RNNT predictor***: + +The torchaudio [RNNT-emformer](https://pytorch.org/audio/stable/tutorials/online_asr_tutorial.html) model is an Automatic Speech Recognition (ASR) model, comprised of three different submodels: an encoder, a predictor and a joiner. +The [predictor](https://github.com/pytorch/executorch/blob/main/examples/cadence/models/rnnt_predictor.py) is a sequence of basic ops (embedding, ReLU, linear, layer norm) and can be exported using: + +```bash +cd executorch +python3 -m examples.cadence.models.rnnt_predictor +``` + +The generated file is called `CadenceDemoModel.pte`. + +### Runtime + +**Building the DSP firmware image** +In this step, you'll be building the DSP firmware image that consists of the sample ExecuTorch runner along with the Program generated from the previous step. This image when loaded onto the DSP will run through the model that this Program consists of. + +***Step 1***. Configure the environment variables needed to point to the Xtensa toolchain that you have installed in the previous step. The three environment variables that need to be set include: +```bash +# Directory in which the Xtensa toolchain was installed +export XTENSA_TOOLCHAIN=/home/user_name/cadence/XtDevTools/install/tools +# The version of the toolchain that was installed. This is essentially the name of the directory +# that is present in the XTENSA_TOOLCHAIN directory from above. +export TOOLCHAIN_VER=RI-2021.8-linux +# The Xtensa core that you're targeting. +export XTENSA_CORE=nxp_rt600_RI2021_8_newlib +``` + +***Step 2***. Clone the [nnlib repo](https://github.com/foss-xtensa/nnlib-hifi4), which contains optimized kernels and primitives for HiFi4 DSPs, with `git clone git@github.com:foss-xtensa/nnlib-hifi4.git`. + +***Step 3***. Run the CMake build. +In order to run the CMake build, you need the path to the following: +- The Program generated in the previous step +- Path to the NXP SDK root. This should have been installed already in the [Setting up Developer Environment](#setting-up-developer-environment) section. This is the directory that contains the folders such as boards, components, devices, and other. + +```bash +cd executorch +./install_executorch.sh --clean +mkdir cmake-out +# prebuild and install executorch library +cmake -DCMAKE_TOOLCHAIN_FILE=/backends/cadence/cadence.cmake \ + -DCMAKE_INSTALL_PREFIX=cmake-out \ + -DCMAKE_BUILD_TYPE=Debug \ + -DPYTHON_EXECUTABLE=python3 \ + -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ + -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=OFF \ + -DEXECUTORCH_BUILD_PTHREADPOOL=OFF \ + -DEXECUTORCH_BUILD_CPUINFO=OFF \ + -Bcmake-out . 
+ +cmake --build cmake-out -j --target install --config Debug +# build cadence runner +cmake -DCMAKE_BUILD_TYPE=Debug \ + -DCMAKE_TOOLCHAIN_FILE=/examples/backends/cadence.cmake \ + -DCMAKE_PREFIX_PATH=/cmake-out \ + -DMODEL_PATH= \ + -DNXP_SDK_ROOT_DIR= \ + -DNN_LIB_BASE_DIR= \ + -Bcmake-out/examples/cadence \ + examples/cadence + +cmake --build cmake-out/examples/cadence -j8 -t cadence_executorch_example +``` + +After having succesfully run the above step you should see two binary files in their CMake output directory. +```bash +> ls cmake-xt/*.bin +cmake-xt/dsp_data_release.bin cmake-xt/dsp_text_release.bin +``` + +## Deploying and Running on Device + +***Step 1***. You now take the DSP binary images generated from the previous step and copy them over into your NXP workspace created in the [Setting up Developer Environment](#setting-up-developer-environment) section. Copy the DSP images into the `dsp_binary` section highlighted in the image below. + +![MCUXpresso IDE](../_static/img/dsp_binary.png) + +```{note} +As long as binaries have been built using the Xtensa toolchain on Linux, flashing the board and running on the chip can be done only with the MCUXpresso IDE, which is available on all platforms (Linux, MacOS, Windows). +``` + +***Step 2***. Clean your work space + +***Step 3***. Click **Debug your Project** which will flash the board with your binaries. + +On the UART console connected to your board (at a default baud rate of 115200), you should see an output similar to this: + +```bash +> screen /dev/tty.usbmodem0007288234991 115200 +Executed model +Model executed successfully. +First 20 elements of output 0 +0.165528 0.331055 ... +``` + +## Conclusion and Future Work + +In this tutorial, you have learned how to export a quantized operation, build the ExecuTorch runtime and run this model on the Xtensa HiFi4 DSP chip. + +The (quantized linear) model in this tutorial is a typical operation appearing in ASR models, and can be extended to a complete ASR model by creating the model as a new test and adding the needed operators/kernels to [operators](https://github.com/pytorch/executorch/blob/main/backends/cadence/hifi/operators) and [kernels](https://github.com/pytorch/executorch/blob/main/backends/cadence/hifi/kernels). + +Other models can be created following the same structure, always assuming that operators and kernels are available. 
diff --git a/docs/source/backend-delegate-advanced.md b/docs/source/backend-delegate-advanced.md new file mode 100644 index 00000000000..e82e5ee035d --- /dev/null +++ b/docs/source/backend-delegate-advanced.md @@ -0,0 +1,28 @@ +(backend-delegate-advanced)= + +# Backend & Delegates + +## Integration + +- {doc}`backend-delegates-integration` — Learn how to integrate a backend delegate into ExecuTorch + +## Dependency Management + +- {doc}`backend-delegates-dependencies` — Manage third-party dependencies for backend delegates + +## Overview + +- {doc}`compiler-delegate-and-partitioner` — Understanding backends, delegates, and the partitioner system + +## Debugging + +- {doc}`debug-backend-delegate` — Tools and techniques for debugging delegation issues + +```{toctree} +:hidden: +:maxdepth: 1 + +backend-delegates-integration +backend-delegates-dependencies +compiler-delegate-and-partitioner +debug-backend-delegate diff --git a/docs/source/backend-delegates-dependencies.md b/docs/source/backend-delegates-dependencies.md index f2068989bd2..06f23ca36bc 100644 --- a/docs/source/backend-delegates-dependencies.md +++ b/docs/source/backend-delegates-dependencies.md @@ -49,7 +49,7 @@ for these third-party dependencies. `executorch/third-party` then try to use that if possible. This helps with reducing the binary size when the delegate is enabled. * The rest of the ExecuTorch code, outside of the delegate, should not depend on - this. And it should should build and run correctly without this dependency + this. And it should build and run correctly without this dependency when the delegate is disabled at build time. More details in the section [below](#runtime-dependencies). diff --git a/docs/source/backend-delegates-integration.md b/docs/source/backend-delegates-integration.md index 0179ceff872..130da0d3225 100644 --- a/docs/source/backend-delegates-integration.md +++ b/docs/source/backend-delegates-integration.md @@ -23,12 +23,13 @@ the top level ExecuTorch package. For third-party dependencies, please refer to At a minimum, a delegate must provide CMake support for building its C++ sources. -For the CMake setup, the delegate dir should be included by the -top level `CMakeLists.txt` file using `add_subdirectory` CMake command, and -should be built conditionally with an ExecuTorch build flag like -`EXECUTORCH_BUILD_`, see `EXECUTORCH_BUILD_XNNPACK` for example. -For third-party dependencies, please refer to -[this](backend-delegates-dependencies.md). +For the CMake setup: + +- The delegate directory should be included by the top-level `CMakeLists.txt` file using the `add_subdirectory` command. +- It should be built conditionally using an ExecuTorch build flag like `EXECUTORCH_BUILD_`. +(See `EXECUTORCH_BUILD_XNNPACK` for an example.) + +For third-party dependencies, please refer to [this](backend-delegates-dependencies.md). +::::{grid} 2 + +:::{grid-item-card} Tutorials we recommend you complete before this: +:class-card: card-prerequisites +* [Introduction to ExecuTorch](intro-how-it-works.md) +* [Getting Started](getting-started.md) +* [Building ExecuTorch with CMake](using-executorch-building-from-source.md) +::: + +:::{grid-item-card} What you will learn in this tutorial: +:class-card: card-prerequisites +In this tutorial you will learn how to export a simple PyTorch model for the ExecuTorch Ethos-U backend. 
+::: + +:::: + +```{tip} +If you are already familiar with this delegate, you may want to jump directly to the examples: +* [Examples in the ExecuTorch repository](https://github.com/pytorch/executorch/tree/main/examples/arm) +* [A commandline compiler for example models](https://github.com/pytorch/executorch/blob/main/examples/arm/aot_arm_compiler.py) +``` + +This tutorial serves as an introduction to using ExecuTorch to deploy PyTorch models on Arm® Ethos™-U targets. It is based on `ethos_u_minimal_example.ipynb`, provided in Arm’s examples folder. + +## Prerequisites + +### Hardware + +To successfully complete this tutorial, you will need a Linux machine with aarch64 or x86_64 processor architecture, or a macOS™ machine with Apple® Silicon. + +To enable development without a specific development board, we will be using a [Fixed Virtual Platform (FVP)](https://www.arm.com/products/development-tools/simulation/fixed-virtual-platforms), simulating [Arm® Corstone™-300](https://developer.arm.com/Processors/Corstone-300)(cs300) and [Arm® Corstone™-320](https://developer.arm.com/Processors/Corstone-320)(cs320)systems. Think of it as virtual hardware. + +### Software + +First, you will need to install ExecuTorch. Please follow the recommended tutorials to set up a working ExecuTorch development environment. + +In addition to this, you need to install a number of SDK dependencies for generating Ethos-U command streams. Scripts to automate this are available in the main [ExecuTorch repository](https://github.com/pytorch/executorch/tree/main/examples/arm/). +To install Ethos-U dependencies, run +```bash +./examples/arm/setup.sh --i-agree-to-the-contained-eula +``` +This will install: +- [TOSA Serialization Library](https://www.mlplatform.org/tosa/software.html) for serializing the Exir IR graph into TOSA IR. +- [Ethos-U Vela graph compiler](https://pypi.org/project/ethos-u-vela/) for compiling TOSA flatbuffers into a Ethos-U command stream. +- [Arm GNU Toolchain](https://developer.arm.com/Tools%20and%20Software/GNU%20Toolchain) for cross compilation. +- [Corstone SSE-300 FVP](https://developer.arm.com/documentation/100966/1128/Arm--Corstone-SSE-300-FVP) for testing on Ethos-U55 reference design. +- [Corstone SSE-320 FVP](https://developer.arm.com/documentation/109760/0000/SSE-320-FVP) for testing on Ethos-U85 reference design. + +## Set Up the Developer Environment + +The setup.sh script generates a setup_path.sh script that you need to source whenever you restart your shell. Run: + +```{bash} +source examples/arm/arm-scratch/setup_path.sh +``` + +As a simple check that your environment is set up correctly, run `which FVP_Corstone_SSE-320` and make sure that the executable is located where you expect, in the `examples/arm` tree. + +## Build + +### Ahead-of-Time (AOT) components + +The ExecuTorch Ahead-of-Time (AOT) pipeline takes a PyTorch Model (a `torch.nn.Module`) and produces a `.pte` binary file, which is then consumed by the ExecuTorch Runtime. This [document](getting-started-architecture.md) goes in much more depth about the ExecuTorch software stack for both AoT as well as Runtime. + +The example below shows how to quantize a model consisting of a single addition, and export it it through the AOT flow using the EthosU backend. For more details, see `examples/arm/ethos_u_minimal_example.ipynb`. 
+```python +import torch + +class Add(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + +example_inputs = (torch.ones(1,1,1,1),torch.ones(1,1,1,1)) + +model = Add() +model = model.eval() +exported_program = torch.export.export(model, example_inputs) +graph_module = exported_program.graph_module + + +from executorch.backends.arm.ethosu import EthosUCompileSpec +from executorch.backends.arm.quantizer import ( + EthosUQuantizer, + get_symmetric_quantization_config, +) +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e + +# Create a compilation spec describing the target for configuring the quantizer +# Some args are used by the Arm Vela graph compiler later in the example. Refer to Arm Vela documentation for an +# explanation of its flags: https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela/-/blob/main/OPTIONS.md +compile_spec = EthosUCompileSpec( + target="ethos-u55-128", + system_config="Ethos_U55_High_End_Embedded", + memory_mode="Shared_Sram", + extra_flags=["--output-format=raw", "--debug-force-regor"] + ) + +# Create and configure quantizer to use a symmetric quantization config globally on all nodes +quantizer = EthosUQuantizer(compile_spec) +operator_config = get_symmetric_quantization_config() +quantizer.set_global(operator_config) + +# Post training quantization +quantized_graph_module = prepare_pt2e(graph_module, quantizer) +quantized_graph_module(*example_inputs) # Calibrate the graph module with the example input +quantized_graph_module = convert_pt2e(quantized_graph_module) + + +# Create a new exported program using the quantized_graph_module +quantized_exported_program = torch.export.export(quantized_graph_module, example_inputs) +from executorch.backends.arm.ethosu import EthosUPartitioner +from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, +) +from executorch.extension.export_util.utils import save_pte_program + +# Create partitioner from compile spec +partitioner = EthosUPartitioner(compile_spec) + +# Lower the exported program to the Ethos-U backend +edge_program_manager = to_edge_transform_and_lower( + quantized_exported_program, + partitioner=[partitioner], + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + ), + ) + +# Convert edge program to executorch +executorch_program_manager = edge_program_manager.to_executorch( + config=ExecutorchBackendConfig(extract_delegate_segments=False) + ) + + +# Save pte file +save_pte_program(executorch_program_manager, "ethos_u_minimal_example.pte") +``` + + +```{tip} +For a quick start, you can use the script `examples/arm/aot_arm_compiler.py` to produce the pte. +To produce a pte file equivalent to the one above, run +`python -m examples.arm.aot_arm_compiler --model_name=add --delegate --quantize --output=ethos_u_minimal_example.pte` +``` + +### Runtime: + +After the AOT compilation flow is done, the runtime can be cross compiled and linked to the produced `.pte`-file using the Arm cross-compilation toolchain. This is done in two steps: + +First, build and install the ExecuTorch libraries and EthosUDelegate: +``` +# In ExecuTorch top-level, with sourced setup_path.sh +cmake -DCMAKE_BUILD_TYPE=Release --preset arm-baremetal -B cmake-out-arm . +cmake --build cmake-out-arm --target install -j$(nproc) +``` +Second, build and link the `arm_executor_runner` and generate kernel bindings for any non delegated ops. This is the actual program that will run on target. 
+ +``` +# In ExecuTorch top-level, with sourced setup_path.sh +cmake -DCMAKE_TOOLCHAIN_FILE=`pwd`/examples/arm/ethos-u-setup/arm-none-eabi-gcc.cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DET_PTE_FILE_PATH=ethos_u_minimal_example.pte \ + -DTARGET_CPU=cortex-m55 \ + -DETHOSU_TARGET_NPU_CONFIG=ethos-u55-128 \ + -DMEMORY_MODE=Shared_Sram \ + -DSYSTEM_CONFIG=Ethos_U55_High_End_Embedded \ + -Bethos_u_minimal_example \ + examples/arm/executor_runner +cmake --build ethos_u_minimal_example -j$(nproc) -- arm_executor_runner +``` + +```{tip} +For a quick start, you can use the script `backends/arm/scripts/build_executor_runner.sh` to build the runner. +To build a runner equivalent to the one above, run +`./backends/arm/scripts/build_executor_runner.sh --pte=ethos_u_minimal_example.pte` +```` + +The block diagram below shows, at the high level, how the various build artifacts are generated and are linked together to generate the final bare-metal executable. + +![](arm-delegate-runtime-build.svg) + + +## Running on Corstone FVP Platforms + +Finally, use the `backends/arm/scripts/run_fvp.sh` utility script to run the .elf-file on simulated Arm hardware. +``` +backends/arm/scripts/run_fvp.sh --elf=$(find ethos_u_minimal_example -name arm_executor_runner) --target=ethos-u55-128 +``` +The example application is by default built with an input of ones, so the expected result of the quantized addition should be close to 2. + + +## Takeaways + +In this tutorial you have learned how to use ExecuTorch to export a PyTorch model to an executable that can run on an embedded target, and then run that executable on simulated hardware. +To learn more, check out these learning paths: + +https://learn.arm.com/learning-paths/embedded-and-microcontrollers/rpi-llama3/ +https://learn.arm.com/learning-paths/embedded-and-microcontrollers/visualizing-ethos-u-performance/ + +## FAQs + +If you encountered any bugs or issues following this tutorial please file a bug/issue here on [Github](https://github.com/pytorch/executorch/issues/new). + + +``` +Arm is a registered trademark of Arm Limited (or its subsidiaries or affiliates). +``` diff --git a/docs/source/backends/arm-vgf/arm-vgf-overview.md b/docs/source/backends/arm-vgf/arm-vgf-overview.md new file mode 100644 index 00000000000..4d693354dbc --- /dev/null +++ b/docs/source/backends/arm-vgf/arm-vgf-overview.md @@ -0,0 +1,114 @@ +# Arm VGF Backend + +The Arm® VGF backend is the ExecuTorch solution for lowering PyTorch models to VGF compatible hardware. +It leverages the TOSA operator set and the [ML SDK for Vulkan®](https://github.com/arm/ai-ml-sdk-for-vulkan?tab=readme-ov-file) to produce a .PTE file. +The VGF backend also supports execution from a .PTE file and provides functionality to extract the corresponding VGF file for integration into various applications. + +## Features + +- Wide operator support for delegating large parts of models to the VGF target. +- A quantizer that optimizes quantization for the VGF target. + +## Target Requirements + +The target system must include ML SDK for Vulkan and a Vulkan driver with Vulkan API >= 1.3. 
+ +## Development Requirements + +```{tip} +All requirements can be downloaded using `examples/arm/setup.sh --enable-mlsdk-deps --disable-ethos-u-deps` and added to the path using +`source examples/arm/arm-scratch/setup_path.sh` +``` + +For the AOT flow, compilation of a model to `.pte` format using the VGF backend, the requirements are: +- [TOSA Serialization Library](https://www.mlplatform.org/tosa/software.html) for serializing the Exir IR graph into TOSA IR. +- [ML SDK Model Converter](https://github.com/arm/ai-ml-sdk-model-converter) for converting TOSA flatbuffers to VGF files. + +And for building and running your application using the generic executor_runner: +- [Vulkan API](https://www.vulkan.org) should be set up locally for GPU execution support. +- [ML Emulation Layer for Vulkan](https://github.com/arm/ai-ml-emulation-layer-for-vulkan) for testing on Vulkan API. + +## Using the Arm VGF Backend + +The [VGF Minimal Example](https://github.com/pytorch/executorch/blob/main/examples/arm/vgf_minimal_example.ipynb) demonstrates how to lower a module using the VGF backend. + +The main configuration point for the lowering is the `VgfCompileSpec` consumed by the partitioner and quantizer. +The full user-facing API is documented below. + +```python +class VgfCompileSpec(tosa_spec: executorch.backends.arm.tosa.specification.TosaSpecification | str | None = None, compiler_flags: list[str] | None = None) +``` +Compile spec for VGF compatible targets. + +Args: +- **tosa_spec**: TOSA specification that should be targeted. +- **compiler_flags**: Extra compiler flags for converter_backend. + +```python +def VgfCompileSpec.dump_debug_info(self, debug_mode: executorch.backends.arm.common.arm_compile_spec.ArmCompileSpec.DebugMode | None): +``` +Dump debugging information into the intermediates path. + +Args: +- **debug_mode**: The debug mode to use for dumping debug information. + +```python +def VgfCompileSpec.dump_intermediate_artifacts_to(self, output_path: str | None): +``` +Sets a path for dumping intermediate results during such as tosa and pte. + +Args: +- **output_path**: Path to dump intermediate results to. + +```python +def VgfCompileSpec.get_intermediate_path(self) -> str | None: +``` +Gets the path used for dumping intermediate results such as tosa and pte. + +Returns: + Path where intermediate results are saved. + +```python +def VgfCompileSpec.get_output_format() -> str: +``` +Returns a constant string that is the output format of the class. + + + +### Partitioner API + +See [Partitioner API](arm-vgf-partitioner.md) for more information of the Partitioner API. + +## Quantization + +The VGF quantizer supports [Post Training Quantization (PT2E)](https://docs.pytorch.org/ao/main/tutorials_source/pt2e_quant_ptq.html) +and [Quantization-Aware Training (QAT)](https://docs.pytorch.org/ao/main/tutorials_source/pt2e_quant_qat.html). + +For more information on quantization, see [Quantization](arm-vgf-quantization.md). + +## Runtime Integration + +The VGF backend can use the default ExecuTorch runner. The steps required for building and running it are explained in the [VGF Backend Tutorial](tutorials/vgf-getting-started.md). +The example application is recommended to use for testing basic functionality of your lowered models, as well as a starting point for developing runtime integrations for your own targets. 
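+
+To show how the pieces above fit together before lowering, here is a minimal sketch that configures a `VgfCompileSpec`, optionally keeps intermediate artifacts for inspection, and hands the spec to the `VgfPartitioner`. The output directory name is purely illustrative; the full export-and-lower flow is shown in the getting-started tutorial linked above.
+
+```python
+# Minimal sketch: build a compile spec, keep intermediate TOSA/VGF artifacts
+# around for inspection, and create the partitioner that consumes the spec.
+# The "./vgf_intermediates" path is an arbitrary example location.
+from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner
+
+compile_spec = VgfCompileSpec("TOSA-1.0+INT")
+compile_spec.dump_intermediate_artifacts_to("./vgf_intermediates")
+
+partitioner = VgfPartitioner(compile_spec)
+# `partitioner` is then passed to to_edge_transform_and_lower(...,
+# partitioner=[partitioner]) as shown in the getting-started tutorial.
+```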
+ +## Reference + +**→{doc}`/backends/arm-vgf/arm-vgf-partitioner` — Partitioner options.** + +**→{doc}`/backends/arm-vgf/arm-vgf-quantization` — Supported quantization schemes.** + +**→{doc}`/backends/arm-vgf/arm-vgf-troubleshooting` — Debug common issues.** + +**→{doc}`/backends/arm-vgf/tutorials/arm-vgf-tutorials` — Tutorials.** + + +```{toctree} +:maxdepth: 2 +:hidden: +:caption: Arm VGF Backend + +arm-vgf-partitioner +arm-vgf-quantization +arm-vgf-troubleshooting +tutorials/arm-vgf-tutorials +``` diff --git a/docs/source/backends/arm-vgf/arm-vgf-partitioner.md b/docs/source/backends/arm-vgf/arm-vgf-partitioner.md new file mode 100644 index 00000000000..e3cbd2f9d22 --- /dev/null +++ b/docs/source/backends/arm-vgf/arm-vgf-partitioner.md @@ -0,0 +1,47 @@ +# Partitioner API + +The `VgfPartitioner` controls what parts of a model is delegated to the Arm VGF backend. Below is a reference of the various functions the partitioner provides: + +```python +class VgfPartitioner(compile_spec: executorch.backends.arm.vgf.compile_spec.VgfCompileSpec, additional_checks: Optional[Sequence[torch.fx.passes.operator_support.OperatorSupportBase]] = None) -> None +``` +Partitions subgraphs supported by the Arm Vgf backend. + +Args: +- **compile_spec**: The Vgf compilation specification. +- **additional_checks**: Optional sequence of additional operator support checks. + +```python +def VgfPartitioner.ops_to_not_decompose(self, ep: torch.export.exported_program.ExportedProgram) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.node.Node], bool]]]: +``` +Return operators and a filter that should not be decomposed. + +Provide a base set of ops to preserve as-is and a predicate that keeps +certain activations whole when surrounded by quantize/dequantize ops in +a quantized graph. This helps downstream TOSA lowering and delegation. + +Args: +- **ep (ExportedProgram)**: Program used to infer target-specific policy. + +Returns: +- **Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]**: + A list of op overloads to keep intact, and an optional filter + function that returns True when an op should not be decomposed. + +```python +def VgfPartitioner.partition(self, exported_program: torch.export.exported_program.ExportedProgram) -> executorch.exir.backend.partitioner.PartitionResult: +``` +Partition the program and tag TOSA-compatible subgraphs. + +Run the FX capability-based partitioner to propose subgraphs, then +refine tags by removing boundary-only quantize/dequantize nodes and by +rejecting partitions that would lower to no-ops. Emit a detailed report +of rejected nodes and their reasons. + +Args: +- **exported_program (ExportedProgram)**: Program to analyze and + partition. + +Returns: +- **PartitionResult**: The input program with nodes tagged for delegation + and a mapping of partition tags to delegation specs. diff --git a/docs/source/backends/arm-vgf/arm-vgf-quantization.md b/docs/source/backends/arm-vgf/arm-vgf-quantization.md new file mode 100644 index 00000000000..23f3246eb6b --- /dev/null +++ b/docs/source/backends/arm-vgf/arm-vgf-quantization.md @@ -0,0 +1,99 @@ +# Quantization + +The Arm VGF delegate can be used to execute quantized models. To quantize a model so that is supported by this delegate, the `VgfQuantizer` should be used. + +Currently the symmetric `int8` config defined by `executorch.backends.arm.quantizer.arm_quantizer.get_symmetric_quantization_config` is the main config available to use with the VGF quantizer. 
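+
+For orientation, the snippet below is a condensed sketch of the post-training quantization (PT2E) flow, following the VGF getting-started tutorial; the single-addition module is just a stand-in for your own model and representative inputs.
+
+```python
+import torch
+from executorch.backends.arm.quantizer import (
+    VgfQuantizer,
+    get_symmetric_quantization_config,
+)
+from executorch.backends.arm.vgf import VgfCompileSpec
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+# Stand-in model; replace with your own module and representative inputs.
+class Add(torch.nn.Module):
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        return x + y
+
+example_inputs = (torch.ones(1, 1, 1, 1), torch.ones(1, 1, 1, 1))
+graph_module = torch.export.export(Add().eval(), example_inputs).graph_module
+
+# Configure the quantizer from a VGF compile spec and apply the symmetric
+# int8 config to the whole graph.
+compile_spec = VgfCompileSpec("TOSA-1.0+INT")
+quantizer = VgfQuantizer(compile_spec)
+quantizer.set_global(get_symmetric_quantization_config())
+
+# Annotate, calibrate with representative inputs, then convert.
+prepared = prepare_pt2e(graph_module, quantizer)
+prepared(*example_inputs)
+quantized_graph_module = convert_pt2e(prepared)
+```
+
+The quantized graph module can then be re-exported and lowered with the `VgfPartitioner` in the usual way.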
+ +### Supported Quantization Schemes + +The quantization schemes supported by the VGF Backend are: +- 8-bit symmetric weights with 8-bit asymmetric activations (via the PT2E quantization flow). + - Supports both static and dynamic activations + - Supports per-channel and per-tensor schemes + +Weight-only quantization is not currently supported on the VGF backend. + +### Quantization API + +```python +class VgfQuantizer(compile_spec: 'VgfCompileSpec') -> 'None' +``` +Quantizer supported by the Arm Vgf backend. + +Args: +- **compile_spec (VgfCompileSpec)**: Backend compile specification for Vgf + targets. + +```python +def VgfQuantizer.quantize_with_submodules(self, model: 'GraphModule', calibration_samples: 'list[tuple]', is_qat: 'bool' = False): +``` +Quantizes a GraphModule in a way such that conditional submodules are handled properly. + +Args: +- **model (GraphModule)**: The model to quantize. +- **calibration_samples (list[tuple])**: A list of inputs to used to + calibrate the model during quantization. To properly calibrate a + model with submodules, at least one sample per code path is + needed. +- **is_qat (bool)**: Whether to do quantization aware training or not. + +Returns: +- **GraphModule**: The quantized model. + +```python +def VgfQuantizer.set_global(self, quantization_config: 'QuantizationConfig') -> 'TOSAQuantizer': +``` +Set quantization_config for submodules not matched by other filters. + +Args: +- **quantization_config (QuantizationConfig)**: Configuration to apply to + modules that are not captured by name or type filters. + +```python +def VgfQuantizer.set_io(self, quantization_config: 'QuantizationConfig') -> 'TOSAQuantizer': +``` +Set quantization_config for input and output nodes. + +Args: +- **quantization_config (QuantizationConfig)**: Configuration describing + activation quantization for model inputs and outputs. + +```python +def VgfQuantizer.set_module_name(self, module_name: 'str', quantization_config: 'Optional[QuantizationConfig]') -> 'TOSAQuantizer': +``` +Set quantization_config for submodules with a given module name. + +For example, calling set_module_name("blocks.sub") quantizes supported +patterns for that submodule with the provided quantization_config. + +Args: +- **module_name (str)**: Fully qualified module name to configure. +- **quantization_config (QuantizationConfig)**: Configuration to apply to + the named submodule. + +```python +def VgfQuantizer.set_module_type(self, module_type: 'Callable', quantization_config: 'QuantizationConfig') -> 'TOSAQuantizer': +``` +Set quantization_config for submodules with a given module type. + +For example, calling set_module_type(Sub) quantizes supported patterns +in each Sub instance with the provided quantization_config. + +Args: +- **module_type (Callable)**: Type whose submodules should use the + provided quantization configuration. +- **quantization_config (QuantizationConfig)**: Configuration to apply to + submodules of the given type. + +```python +def VgfQuantizer.transform_for_annotation(self, model: 'GraphModule') -> 'GraphModule': +``` +Transform the graph to prepare it for quantization annotation. + +Currently transforms scalar values to tensor attributes. + +Args: +- **model (GraphModule)**: Model whose graph will be transformed. + +Returns: +- **GraphModule**: Transformed model prepared for annotation. 
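+
+As a usage illustration of the setters above, the sketch below mixes a global configuration with per-module overrides. `torch.nn.Linear` and the `"blocks.sub"` module name are hypothetical choices standing in for modules in your own model.
+
+```python
+import torch
+from executorch.backends.arm.quantizer import (
+    VgfQuantizer,
+    get_symmetric_quantization_config,
+)
+from executorch.backends.arm.vgf import VgfCompileSpec
+
+quantizer = VgfQuantizer(VgfCompileSpec("TOSA-1.0+INT"))
+
+base_config = get_symmetric_quantization_config()
+per_channel_config = get_symmetric_quantization_config(is_per_channel=True)
+
+quantizer.set_global(base_config)                                # default for unmatched modules
+quantizer.set_module_type(torch.nn.Linear, per_channel_config)   # all Linear submodules
+quantizer.set_module_name("blocks.sub", base_config)             # one named submodule (hypothetical)
+quantizer.set_io(base_config)                                    # model inputs and outputs
+```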
diff --git a/docs/source/backends/arm-vgf/arm-vgf-troubleshooting.md b/docs/source/backends/arm-vgf/arm-vgf-troubleshooting.md new file mode 100644 index 00000000000..6100bc94b0c --- /dev/null +++ b/docs/source/backends/arm-vgf/arm-vgf-troubleshooting.md @@ -0,0 +1,7 @@ +# Arm VGF Troubleshooting + +This page describes common issues that you may encounter when using the Arm VGF backend and how to debug and resolve them. + +## How do you visualize VGF files + +The [VGF Adapter for Model Explorer](https://github.com/arm/vgf-adapter-model-explorer) enables visualization of VGF files and can be useful for debugging. diff --git a/docs/source/backends/arm-vgf/tutorials/arm-vgf-tutorials.md b/docs/source/backends/arm-vgf/tutorials/arm-vgf-tutorials.md new file mode 100644 index 00000000000..ceb4304a814 --- /dev/null +++ b/docs/source/backends/arm-vgf/tutorials/arm-vgf-tutorials.md @@ -0,0 +1,10 @@ +# Arm VGF Backend Tutorials + +**→{doc}`vgf-getting-started`** + +```{toctree} +:maxdepth: 2 +:hidden: +:caption: Tutorials + +vgf-getting-started diff --git a/docs/source/backends/arm-vgf/tutorials/vgf-getting-started.md b/docs/source/backends/arm-vgf/tutorials/vgf-getting-started.md new file mode 100644 index 00000000000..fe4a019528d --- /dev/null +++ b/docs/source/backends/arm-vgf/tutorials/vgf-getting-started.md @@ -0,0 +1,213 @@ +# Getting Started Tutorial + + +::::{grid} 2 + +:::{grid-item-card} Tutorials we recommend you complete before this: +:class-card: card-prerequisites +* [Introduction to ExecuTorch](intro-how-it-works.md) +* [Getting Started](getting-started.md) +* [Building ExecuTorch with CMake](using-executorch-building-from-source.md) +::: + +:::{grid-item-card} What you will learn in this tutorial: +:class-card: card-prerequisites +In this tutorial you will learn how to export a simple PyTorch model for the ExecuTorch VGF backend. +::: + +:::: + +```{warning} +This delegate is under active development, to get best results please use a recent version. +The VGF backend support is in early development and you may encounter issues. +You may encounter some rough edges and features which may be documented or planned but not implemented, please refer to the in-tree documentation for the latest status of features. +``` + +```{tip} +If you are already familiar with this delegate, you may want to jump directly to the examples: +* [Examples in the ExecuTorch repository](https://github.com/pytorch/executorch/tree/main/examples/arm) +* [A commandline compiler for example models](https://github.com/pytorch/executorch/blob/main/examples/arm/aot_arm_compiler.py) +``` + +This tutorial serves as an introduction to using ExecuTorch to deploy PyTorch models on VGF targets. The tutorial is based on `vgf_minimal_example.ipyb`, provided in Arm's example folder. + +## Prerequisites + +### Hardware + +To successfully complete this tutorial, you will need a Linux machine with aarch64 or x86_64 processor architecture, or a macOS™ machine with Apple® Silicon. + +To enable development without a specific development board, we will be using the [ML SDK for Vulkan®](https://github.com/arm/ai-ml-sdk-for-vulkan/) to emulate the program consumer. + +### Software + +First, you will need to install ExecuTorch. Please follow the recommended tutorials if you haven't already, to set up a working ExecuTorch development environment. 
For the VGF backend it's recommended you [install from source](https://docs.pytorch.org/executorch/stable/using-executorch-building-from-source.html), or from a [nightly](https://download.pytorch.org/whl/nightly/executorch/). + +In addition to this, you need to install a number of SDK dependencies for generating VGF files. Scripts to automate this are available in the main [ExecuTorch repository](https://github.com/pytorch/executorch/tree/main/examples/arm/). To install VGF dependencies, run +```bash +./examples/arm/setup.sh --i-agree-to-the-contained-eula --disable-ethos-u-deps --enable-mlsdk-deps +``` +This will install: +- [TOSA Serialization Library](https://www.mlplatform.org/tosa/software.html) for serializing the Exir IR graph into TOSA IR. +- [ML SDK Model Converter](https://github.com/arm/ai-ml-sdk-model-converter) for converting TOSA flatbuffers to VGF files. +- [Vulkan API](https://www.vulkan.org) should be set up locally for GPU execution support. +- [ML Emulation Layer for Vulkan](https://github.com/arm/ai-ml-emulation-layer-for-vulkan) for testing on Vulkan API. + + +## Set Up the Developer Environment + +The `setup.sh` script has generated a `setup_path.sh` script that you need to source whenever you restart your shell. Do this by running + +`source examples/arm/arm-scratch/setup_path.sh` + +As a simple check that your environment is set up correctly, run + +```bash +which model-converter +``` +Make sure the executable is located where you expect, in the `examples/arm` tree. + +## Build + +### Ahead-of-Time (AOT) components + +The ExecuTorch Ahead-of-Time (AOT) pipeline takes a PyTorch Model (a `torch.nn.Module`) and produces a `.pte` binary file, which is then typically consumed by the ExecuTorch Runtime. This [document](https://github.com/pytorch/executorch/blob/main/docs/source/getting-started-architecture.md) goes in much more depth about the ExecuTorch software stack for both AoT as well as Runtime. + +The example below shows how to quantize a model consisting of a single addition, and export it it through the AOT flow using the VGF backend. For more details, se `examples/arm/vgf_minimal_example.ipynb`. 
+ +```python +import torch + +class Add(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + +example_inputs = (torch.ones(1,1,1,1),torch.ones(1,1,1,1)) + +model = Add() +model = model.eval() +exported_program = torch.export.export(model, example_inputs) +graph_module = exported_program.graph_module + + +from executorch.backends.arm.quantizer import ( + VgfQuantizer, + get_symmetric_quantization_config, +) +from executorch.backends.arm.vgf import VgfCompileSpec +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e + +# Create a compilation spec describing the target for configuring the quantizer +compile_spec = VgfCompileSpec("TOSA-1.0+INT") + +# Create and configure quantizer to use a symmetric quantization config globally on all nodes +quantizer = VgfQuantizer(compile_spec) +operator_config = get_symmetric_quantization_config(is_per_channel=False) +quantizer.set_global(operator_config) + +# Post training quantization +quantized_graph_module = prepare_pt2e(graph_module, quantizer) +quantized_graph_module(*example_inputs) # Calibrate the graph module with the example input +quantized_graph_module = convert_pt2e(quantized_graph_module) + + +# Create a new exported program using the quantized_graph_module +quantized_exported_program = torch.export.export(quantized_graph_module, example_inputs) +import os +from executorch.backends.arm.vgf import VgfPartitioner +from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, +) +from executorch.extension.export_util.utils import save_pte_program + +# Create partitioner from compile spec +partitioner = VgfPartitioner(compile_spec) + +# Lower the exported program to the VGF backend +edge_program_manager = to_edge_transform_and_lower( + quantized_exported_program, + partitioner=[partitioner], + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + ), +) + +# Convert edge program to executorch +executorch_program_manager = edge_program_manager.to_executorch( + config=ExecutorchBackendConfig(extract_delegate_segments=False) +) + + +# Save pte file +cwd_dir = os.getcwd() +pte_base_name = "simple_example" +pte_name = pte_base_name + ".pte" +pte_path = os.path.join(cwd_dir, pte_name) +save_pte_program(executorch_program_manager, pte_name) +assert os.path.exists(pte_path), "Build failed; no .pte-file found" +``` + + +```{tip} +For a quick start, you can use the script `examples/arm/aot_arm_compiler.py` to produce the pte. +To produce a pte file equivalent to the one above, run +`python -m examples.arm.aot_arm_compiler --model_name=add --delegate --quantize --output=simple_example.pte --target=vgf` +``` + +### Runtime: + +## Build executor runtime + +After the AOT compilation flow is done, we can build the executor runner target. For this tutorial, the default runner can be used. Build it with the following configuration: + +```bash +# In ExecuTorch top-level, with sourced setup_path.sh +cmake \ + -DCMAKE_INSTALL_PREFIX=cmake-out \ + -DCMAKE_BUILD_TYPE=Debug \ + -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ + -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ + -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ + -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ + -DEXECUTORCH_BUILD_XNNPACK=OFF \ + -DEXECUTORCH_BUILD_VULKAN=ON \ + -DEXECUTORCH_BUILD_VGF=ON \ + -DEXECUTORCH_ENABLE_LOGGING=ON \ + -DPYTHON_EXECUTABLE=python \ + -Bcmake-out . 
+ +cmake --build cmake-out --target executor_runner` +``` + + +The block diagram below demonstrates, at the high level, how the various build artifacts are generated and are linked together to generate the final bare-metal executable. + +![](arm-delegate-runtime-build.svg) + + +## Deploying and running on device + +Since we are using the Vulkan emulation layer, we can run the executor runner with the VGF delegate on the host machine: + +```bash +./cmake-out/executor_runner -model_path simple_example.pte +``` + +The example application is by default built with an input of ones, so the expected result of the quantized addition should be close to 2. + +## Takeaways + +In this tutorial you have learned how to use ExecuTorch to export a PyTorch model to an executable that can run on an embedded target, and then run that executable on simulated hardware. + + +## FAQs + +Issue: glslc is not found when configuring the executor runner. +Solution: The Vulkan sdk is likely not in your path, check whether setup_path.sh contains something like +`export PATH=$(pwd)/examples/arm/arm-scratch/vulkan_sdk/1.4.321.1/x86_64/bin:$PATH`. +If not, add it and source the file. + +If you encountered any bugs or issues following this tutorial please file a bug/issue here on [Github](https://github.com/pytorch/executorch/issues/new). diff --git a/docs/source/backends/coreml/coreml-op-support.md b/docs/source/backends/coreml/coreml-op-support.md new file mode 100644 index 00000000000..107de9f6a80 --- /dev/null +++ b/docs/source/backends/coreml/coreml-op-support.md @@ -0,0 +1,5 @@ +# Op support + +The Core ML backend supports almost all PyTorch operators. + +If an operator in your model is not supported by Core ML, you will see a warning about this during lowering. If you want to guarantee that your model fully delegates to Core ML, you can set [`lower_full_graph=True`](coreml-partitioner.md) in the `CoreMLPartitioner`. When set, lowering will fail if an unsupported operator is encountered. diff --git a/docs/source/backends/coreml/coreml-overview.md b/docs/source/backends/coreml/coreml-overview.md new file mode 100644 index 00000000000..bff0cb8994e --- /dev/null +++ b/docs/source/backends/coreml/coreml-overview.md @@ -0,0 +1,112 @@ +# Core ML Backend + +Core ML delegate is the ExecuTorch solution to take advantage of Apple's [Core ML framework](https://developer.apple.com/documentation/coreml) for on-device ML. With Core ML, a model can run on CPU, GPU, and the Apple Neural Engine (ANE). + +## Features + +- Dynamic dispatch to the CPU, GPU, and ANE. +- Supports fp32 and fp16 computation. + +## Target Requirements + +Below are the minimum OS requirements on various hardware for running a Core ML-delegated ExecuTorch model: + +- [macOS](https://developer.apple.com/macos) >= 13.0 +- [iOS](https://developer.apple.com/ios/) >= 16.0 +- [iPadOS](https://developer.apple.com/ipados/) >= 16.0 +- [tvOS](https://developer.apple.com/tvos/) >= 16.0 + +## Development Requirements + +To develop you need: + +- [macOS](https://developer.apple.com/macos) >= 13.0 +- [Xcode](https://developer.apple.com/documentation/xcode) >= 14.1 + + +Before starting, make sure you install the Xcode Command Line Tools: + +```bash +xcode-select --install +``` + +---- + +## Using the Core ML Backend + +To target the Core ML backend during the export and lowering process, pass an instance of the `CoreMLPartitioner` to `to_edge_transform_and_lower`. The example below demonstrates this process using the MobileNet V2 model from torchvision. 
+ +```python +import torch +import torchvision.models as models +from torchvision.models.mobilenetv2 import MobileNet_V2_Weights +from executorch.backends.apple.coreml.partition import CoreMLPartitioner +from executorch.exir import to_edge_transform_and_lower + +mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval() +sample_inputs = (torch.randn(1, 3, 224, 224), ) + +et_program = to_edge_transform_and_lower( + torch.export.export(mobilenet_v2, sample_inputs), + partitioner=[CoreMLPartitioner()], +).to_executorch() + +with open("mv2_coreml.pte", "wb") as file: + et_program.write_to_file(file) +``` + +See [Partitioner API](coreml-partitioner.md) for a reference on available partitioner options. + +---- + +## Quantization + +The Core ML delegate can also be used as a backend to execute quantized models. See [Core ML Quantization](coreml-quantization.md) for more information on available quantization schemes and APIs. + +## Backward compatibility + +Core ML supports backward compatibility via the [`minimum_deployment_target`](coreml-partitioner.md#coreml-compilespec) option. A model exported with a specific deployment target is guaranteed to work on all deployment targets >= the specified deployment target. For example, a model exported with `coremltools.target.iOS17` will work on iOS 17 or higher. + +---- + +## Runtime integration + +To run the model on device, use the standard ExecuTorch runtime APIs. See [Running on Device](getting-started.md#running-on-device) for more information, including building the iOS frameworks. + +When building from source, pass `-DEXECUTORCH_BUILD_COREML=ON` when configuring the CMake build to compile the Core ML backend. + +Due to the use of static initializers for registration, it may be necessary to use whole-archive to link against the `coremldelegate` target. This can typically be done by passing `"$"` to `target_link_libraries`. + +``` +# CMakeLists.txt +add_subdirectory("executorch") +... +target_link_libraries( + my_target + PRIVATE executorch + extension_module_static + extension_tensor + optimized_native_cpu_ops_lib + $) +``` + +No additional steps are necessary to use the backend beyond linking the target. A Core ML-delegated .pte file will automatically run on the registered backend. + +## Reference + +**→{doc}`/backends/coreml/coreml-troubleshooting` — Debug common issues.** + +**→{doc}`/backends/coreml/coreml-partitioner` — Partitioner options.** + +**→{doc}`/backends/coreml/coreml-quantization` — Supported quantization schemes.** + +**→{doc}`/backends/coreml/coreml-op-support` — Supported operators.** + +```{toctree} +:maxdepth: 2 +:hidden: +:caption: Core ML Backend +coreml-troubleshooting +coreml-partitioner +coreml-quantization +coreml-op-support diff --git a/docs/source/backends/coreml/coreml-partitioner.md b/docs/source/backends/coreml/coreml-partitioner.md new file mode 100644 index 00000000000..36cb0a74363 --- /dev/null +++ b/docs/source/backends/coreml/coreml-partitioner.md @@ -0,0 +1,114 @@ +# Partitioner API + +The Core ML partitioner API allows for configuration of the model delegation to Core ML. Passing a `CoreMLPartitioner` instance with no additional parameters will run as much of the model as possible on the Core ML backend with default settings. This is the most common use case. 
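+
+The sketch below contrasts the default usage with a configured one. The partitioner options and compile spec fields are described in the rest of this page; the exact keyword spellings passed to `generate_compile_specs` are assumptions based on the option names documented here, so verify them against the constructor linked below if in doubt.
+
+```python
+import coremltools as ct
+
+from executorch.backends.apple.coreml.compiler import CoreMLBackend
+from executorch.backends.apple.coreml.partition import CoreMLPartitioner
+
+# Most common case: delegate everything Core ML supports, using default settings.
+partitioner = CoreMLPartitioner()
+
+# Advanced case: pass compile specs and partitioner options explicitly.
+compile_specs = CoreMLBackend.generate_compile_specs(
+    compute_unit=ct.ComputeUnit.CPU_AND_NE,      # use the CPU and ANE, but not the GPU
+    minimum_deployment_target=ct.target.iOS17,   # pin the deployment target explicitly
+    compute_precision=ct.precision.FLOAT16,      # the default precision, shown for clarity
+)
+configured_partitioner = CoreMLPartitioner(
+    compile_specs=compile_specs,
+    take_over_mutable_buffer=False,  # keep mutable buffers as graph I/O instead of MLState
+    lower_full_graph=False,          # let unsupported ops fall back to ExecuTorch CPU kernels
+)
+```
+
+Either instance is then passed to `to_edge_transform_and_lower`, exactly as in the overview example.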
For advanced use cases, the partitioner exposes the following options via the [constructor](https://github.com/pytorch/executorch/blob/14ff52ff89a89c074fc6c14d3f01683677783dcd/backends/apple/coreml/partition/coreml_partitioner.py#L60): + + + - `skip_ops_for_coreml_delegation`: Allows you to skip ops for delegation by Core ML. By default, all ops that Core ML supports will be delegated. See [here](https://github.com/pytorch/executorch/blob/14ff52ff89a89c074fc6c14d3f01683677783dcd/backends/apple/coreml/test/test_coreml_partitioner.py#L42) for an example of skipping an op for delegation. +- `compile_specs`: A list of `CompileSpec`s for the Core ML backend. These control low-level details of Core ML delegation, such as the compute unit (CPU, GPU, ANE), the iOS deployment target, and the compute precision (FP16, FP32). These are discussed more below. +- `take_over_mutable_buffer`: A boolean that indicates whether PyTorch mutable buffers in stateful models should be converted to [Core ML `MLState`](https://developer.apple.com/documentation/coreml/mlstate). If set to `False`, mutable buffers in the PyTorch graph are converted to graph inputs and outputs to the Core ML lowered module under the hood. Generally, setting `take_over_mutable_buffer` to true will result in better performance, but using `MLState` requires iOS >= 18.0, macOS >= 15.0, and Xcode >= 16.0. +- `take_over_constant_data`: A boolean that indicates whether PyTorch constant data like model weights should be consumed by the Core ML delegate. If set to False, constant data is passed to the Core ML delegate as inputs. By default, take_over_constant_data=True. +- `lower_full_graph`: A boolean that indicates whether the entire graph must be lowered to Core ML. If set to True and Core ML does not support an op, an error is raised during lowering. If set to False and Core ML does not support an op, the op is executed on the CPU by ExecuTorch. Although setting `lower_full_graph`=False can allow a model to lower where it would otherwise fail, it can introduce performance overhead in the model when there are unsupported ops. You will see warnings about unsupported ops during lowering if there are any. By default, `lower_full_graph`=False. + + +#### Core ML CompileSpec + +A list of `CompileSpec`s is constructed with [`CoreMLBackend.generate_compile_specs`](https://github.com/pytorch/executorch/blob/14ff52ff89a89c074fc6c14d3f01683677783dcd/backends/apple/coreml/compiler/coreml_preprocess.py#L210). Below are the available options: +- `compute_unit`: this controls the compute units (CPU, GPU, ANE) that are used by Core ML. The default value is `coremltools.ComputeUnit.ALL`. The available options from coremltools are: + - `coremltools.ComputeUnit.ALL` (uses the CPU, GPU, and ANE) + - `coremltools.ComputeUnit.CPU_ONLY` (uses the CPU only) + - `coremltools.ComputeUnit.CPU_AND_GPU` (uses both the CPU and GPU, but not the ANE) + - `coremltools.ComputeUnit.CPU_AND_NE` (uses both the CPU and ANE, but not the GPU) +- `minimum_deployment_target`: The minimum iOS deployment target (e.g., `coremltools.target.iOS18`). By default, the smallest deployment target needed to deploy the model is selected. During export, you will see a warning about the "Core ML specification version" that was used for the model, which maps onto a deployment target as discussed [here](https://apple.github.io/coremltools/mlmodel/Format/Model.html#model). If you need to control the deployment target, please specify it explicitly. 
+- `compute_precision`: The compute precision used by Core ML (`coremltools.precision.FLOAT16` or `coremltools.precision.FLOAT32`). The default value is `coremltools.precision.FLOAT16`. Note that the compute precision is applied no matter what dtype is specified in the exported PyTorch model. For example, an FP32 PyTorch model will be converted to FP16 when delegating to the Core ML backend by default. Also note that the ANE only supports FP16 precision. +- `model_type`: Whether the model should be compiled to the Core ML [mlmodelc format](https://developer.apple.com/documentation/coreml/downloading-and-compiling-a-model-on-the-user-s-device) during .pte creation ([`CoreMLBackend.MODEL_TYPE.COMPILED_MODEL`](https://github.com/pytorch/executorch/blob/14ff52ff89a89c074fc6c14d3f01683677783dcd/backends/apple/coreml/compiler/coreml_preprocess.py#L71)), or whether it should be compiled to mlmodelc on device ([`CoreMLBackend.MODEL_TYPE.MODEL`](https://github.com/pytorch/executorch/blob/14ff52ff89a89c074fc6c14d3f01683677783dcd/backends/apple/coreml/compiler/coreml_preprocess.py#L70)). Using `CoreMLBackend.MODEL_TYPE.COMPILED_MODEL` and doing compilation ahead of time should improve the first time on-device model load time. + +### Dynamic and Enumerated Shapes in Core ML Export + +When exporting an `ExportedProgram` to Core ML, **dynamic shapes** are mapped to [`RangeDim`](https://apple.github.io/coremltools/docs-guides/source/flexible-inputs.html#set-the-range-for-each-dimension). +This enables Core ML `.pte` files to accept inputs with varying dimensions at runtime. + +⚠️ **Note:** The Apple Neural Engine (ANE) does not support true dynamic shapes. If a model relies on `RangeDim`, Core ML will fall back to scheduling the model on the CPU or GPU instead of the ANE. + +--- + +#### Enumerated Shapes + +To enable limited flexibility on the ANE—and often achieve better performance overall—you can export models using **[enumerated shapes](https://apple.github.io/coremltools/docs-guides/source/flexible-inputs.html#select-from-predetermined-shapes)**. + +- Enumerated shapes are *not fully dynamic*. +- Instead, they define a **finite set of valid input shapes** that Core ML can select from at runtime. +- This approach allows some adaptability while still preserving ANE compatibility. + +--- + +#### Specifying Enumerated Shapes + +Unlike `RangeDim`, **enumerated shapes are not part of the `ExportedProgram` itself.** +They must be provided through a compile spec. + +For reference on how to do this, see: +- The annotated code snippet below, and +- The [end-to-end test in ExecuTorch](https://github.com/pytorch/executorch/blob/main/backends/apple/coreml/test/test_enumerated_shapes.py), which demonstrates how to specify enumerated shapes during export. + + +```python +class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(10, 5) + self.linear2 = torch.nn.Linear(11, 5) + + def forward(self, x, y): + return self.linear1(x).sum() + self.linear2(y) + +model = Model() +example_inputs = ( + torch.randn((4, 6, 10)), + torch.randn((5, 11)), +) + +# Specify the enumerated shapes. Below we specify that: +# +# * x can take shape [1, 5, 10] and y can take shape [3, 11], or +# * x can take shape [4, 6, 10] and y can take shape [5, 11] +# +# Any other input shapes will result in a runtime error. 
+# +# Note that we must export x and y with dynamic shapes in the ExportedProgram +# because some of their dimensions are dynamic +enumerated_shapes = {"x": [[1, 5, 10], [4, 6, 10]], "y": [[3, 11], [5, 11]]} +dynamic_shapes = [ + { + 0: torch.export.Dim.AUTO(min=1, max=4), + 1: torch.export.Dim.AUTO(min=5, max=6), + }, + {0: torch.export.Dim.AUTO(min=3, max=5)}, +] +ep = torch.export.export( + model.eval(), example_inputs, dynamic_shapes=dynamic_shapes +) + +# If enumerated shapes are specified for multiple inputs, we must export +# for iOS18+ +compile_specs = CoreMLBackend.generate_compile_specs( + minimum_deployment_target=ct.target.iOS18 +) +compile_specs.append( + CoreMLBackend.generate_enumerated_shapes_compile_spec( + ep, + enumerated_shapes, + ) +) + +# When using an enumerated shape compile spec, you must specify lower_full_graph=True +# in the CoreMLPartitioner. We do not support using enumerated shapes +# for partially exported models +partitioner = CoreMLPartitioner( + compile_specs=compile_specs, lower_full_graph=True +) +delegated_program = executorch.exir.to_edge_transform_and_lower( + ep, + partitioner=[partitioner], +) +et_prog = delegated_program.to_executorch() +``` diff --git a/docs/source/backends/coreml/coreml-quantization.md b/docs/source/backends/coreml/coreml-quantization.md new file mode 100644 index 00000000000..151aabd8144 --- /dev/null +++ b/docs/source/backends/coreml/coreml-quantization.md @@ -0,0 +1,162 @@ +# Quantization + +To quantize a PyTorch model for the Core ML backend, use the `CoreMLQuantizer`. `Quantizers` are backend specific, which means the `CoreMLQuantizer` is configured to quantize models to leverage the quantized operators offered by the Core ML backend. + +### Supported Quantization Schemes + +The CoreML delegate supports the following quantization schemes: + +- 8-bit static and weight-only quantization via the PT2E flow; dynamic quantization is not supported by CoreML. +- 4-bit weight-only affine quantization (per-group or per-channel) via the quantize_ flow +- 1-8 bit weight-only LUT quantization (per grouped-channel) via the quantize_ flow + +### 8-bit Quantization using the PT2E Flow + +Quantization with the Core ML backend requires exporting the model for iOS 17 or later. +To perform 8-bit quantization with the PT2E flow, follow these steps: + +1) Create a [`coremltools.optimize.torch.quantization.LinearQuantizerConfig`](https://apple.github.io/coremltools/source/coremltools.optimize.torch.quantization.html#coremltools.optimize.torch.quantization.LinearQuantizerConfig) and use it to create an instance of a `CoreMLQuantizer`. +2) Use `torch.export.export` to export a graph module that will be prepared for quantization. +3) Call `prepare_pt2e` to prepare the model for quantization. +4) Run the prepared model with representative samples to calibrate the quantizated tensor activation ranges. +5) Call `convert_pt2e` to quantize the model. +6) Export and lower the model using the standard flow. + +The output of `convert_pt2e` is a PyTorch model which can be exported and lowered using the normal flow. As it is a regular PyTorch model, it can also be used to evaluate the accuracy of the quantized model using standard PyTorch techniques. 
+ +```python +import torch +import coremltools as ct +import torchvision.models as models +from torchvision.models.mobilenetv2 import MobileNet_V2_Weights +from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer +from executorch.backends.apple.coreml.partition import CoreMLPartitioner +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from executorch.exir import to_edge_transform_and_lower +from executorch.backends.apple.coreml.compiler import CoreMLBackend + +mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval() +sample_inputs = (torch.randn(1, 3, 224, 224), ) + +# Step 1: Define a LinearQuantizerConfig and create an instance of a CoreMLQuantizer +# Note that "linear" here does not mean only linear layers are quantized, but that linear (aka affine) quantization +# is being performed +static_8bit_config = ct.optimize.torch.quantization.LinearQuantizerConfig( + global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig( + quantization_scheme="symmetric", + activation_dtype=torch.quint8, + weight_dtype=torch.qint8, + weight_per_channel=True, + ) +) +quantizer = CoreMLQuantizer(static_8bit_config) + +# Step 2: Export the model for training +training_gm = torch.export.export(mobilenet_v2, sample_inputs).module() + +# Step 3: Prepare the model for quantization +prepared_model = prepare_pt2e(training_gm, quantizer) + +# Step 4: Calibrate the model on representative data +# Replace with your own calibration data +for calibration_sample in [torch.randn(1, 3, 224, 224)]: + prepared_model(calibration_sample) + +# Step 5: Convert the calibrated model to a quantized model +quantized_model = convert_pt2e(prepared_model) + +# Step 6: Export the quantized model to Core ML +et_program = to_edge_transform_and_lower( + torch.export.export(quantized_model, sample_inputs), + partitioner=[ + CoreMLPartitioner( + # iOS17 is required for the quantized ops in this example + compile_specs=CoreMLBackend.generate_compile_specs( + minimum_deployment_target=ct.target.iOS17 + ) + ) + ], +).to_executorch() +``` + +The above does static quantization (activations and weights are quantized). + +You can see a full description of available quantization configs in the [coremltools documentation](https://apple.github.io/coremltools/source/coremltools.optimize.torch.quantization.html#coremltools.optimize.torch.quantization.LinearQuantizerConfig). For example, the config below will perform weight-only quantization: + +``` +weight_only_8bit_config = ct.optimize.torch.quantization.LinearQuantizerConfig( + global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig( + quantization_scheme="symmetric", + activation_dtype=torch.float32, + weight_dtype=torch.qint8, + weight_per_channel=True, + ) +) +quantizer = CoreMLQuantizer(weight_only_8bit_config) +``` + +Quantizing activations requires calibrating the model on representative data. Also note that PT2E currently requires passing at least 1 calibration sample before calling `convert_pt2e`, even for data-free weight-only quantization. + +See [PyTorch 2 Export Post Training Quantization](https://docs.pytorch.org/ao/main/tutorials_source/pt2e_quant_ptq.html) for more information. + +### LLM quantization with quantize_ + +The Core ML backend also supports quantizing models with the [torchao](https://github.com/pytorch/ao) quantize_ API. This is most commonly used for LLMs, requiring more advanced quantization. 
Since quantize_ is not backend aware, it is important to use a config that is compatible with Core ML: + +* Quantize embedding/linear layers with IntxWeightOnlyConfig (with weight_dtype torch.int4 or torch.int8, using PerGroup or PerAxis granularity). Using 4-bit or PerGroup quantization requires exporting with minimum_deployment_target >= ct.target.iOS18. Using 8-bit quantization with per-axis granularity is supported on ct.target.IOS16+. See [Core ML `CompileSpec`](coreml-partitioner.md#coreml-compilespec) for more information on setting the deployment target. +* Quantize embedding/linear layers with CodebookWeightOnlyConfig (with dtype torch.uint1 through torch.uint8, using various block sizes). Quantizing with CodebookWeightOnlyConfig requires exporting with minimum_deployment_target >= ct.target.iOS18, see [Core ML `CompileSpec`](coreml-partitioner.md#coreml-compilespec) for more information on setting the deployment target. + +Below is an example that quantizes embeddings to 8-bits per-axis and linear layers to 4-bits using group_size=32 with affine quantization: + +```python +from torchao.quantization.granularity import PerGroup, PerAxis +from torchao.quantization.quant_api import ( + IntxWeightOnlyConfig, + quantize_, +) + +# Quantize embeddings with 8-bits, per channel +embedding_config = IntxWeightOnlyConfig( + weight_dtype=torch.int8, + granularity=PerAxis(0), +) +quantize_( + eager_model, + embedding_config, + lambda m, fqn: isinstance(m, torch.nn.Embedding), +) + +# Quantize linear layers with 4-bits, per-group +linear_config = IntxWeightOnlyConfig( + weight_dtype=torch.int4, + granularity=PerGroup(32), +) +quantize_( + eager_model, + linear_config, +) +``` + +Below is another example that uses codebook quantization to quantize both embeddings and linear layers to 3-bits. +In the coremltools documentation, this is called [palettization](https://apple.github.io/coremltools/docs-guides/source/opt-palettization-overview.html): + +``` +from torchao.quantization.quant_api import ( + quantize_, +) +from torchao.prototype.quantization.codebook_coreml import CodebookWeightOnlyConfig + +quant_config = CodebookWeightOnlyConfig( + dtype=torch.uint3, + # There is one LUT per 16 rows + block_size=[16, -1], +) + +quantize_( + eager_model, + quant_config, + lambda m, fqn: isinstance(m, torch.nn.Embedding) or isinstance(m, torch.nn.Linear), +) +``` + +Both of the above examples will export and lower to Core ML with the to_edge_transform_and_lower API. diff --git a/docs/source/backends/coreml/coreml-troubleshooting.md b/docs/source/backends/coreml/coreml-troubleshooting.md new file mode 100644 index 00000000000..0c764b9d51b --- /dev/null +++ b/docs/source/backends/coreml/coreml-troubleshooting.md @@ -0,0 +1,22 @@ +# Troubleshooting + +This page describes common issues that you may encounter when using the Core ML backend and how to debug and resolve them. + +### Issues during lowering +1. "ValueError: In op, of type [X], named [Y], the named input [Z] must have the same data type as the named input x. However, [Z] has dtype fp32 whereas x has dtype fp16." + +This happens because the model is in FP16, but Core ML interprets some of the arguments as FP32, which leads to a type mismatch. The solution is to keep the PyTorch model in FP32. Note that the model will be still be converted to FP16 during lowering to Core ML unless specified otherwise in the compute_precision [Core ML `CompileSpec`](coreml-partitioner.md#coreml-compilespec). 
Also see the [related issue in coremltools](https://github.com/apple/coremltools/issues/2480). + +### Issues during runtime +1. [ETCoreMLModelCompiler.mm:55] [Core ML] Failed to compile model, error = Error Domain=com.apple.mlassetio Code=1 "Failed to parse the model specification. Error: Unable to parse ML Program: at unknown location: Unknown opset 'CoreML7'." UserInfo={NSLocalizedDescription=Failed to par$ + +This means the model requires the Core ML opset 'CoreML7', which requires running the model on iOS >= 17 or macOS >= 14. + +## Extracting the mlpackage for profiling and debugging + +[Core ML *.mlpackage files](https://apple.github.io/coremltools/docs-guides/source/convert-to-ml-program.html#save-ml-programs-as-model-packages) can be extracted from a Core ML-delegated *.pte file. This can help with debugging and profiling for users who are more familiar with *.mlpackage files: +```bash +python examples/apple/coreml/scripts/extract_coreml_models.py -m /path/to/model.pte +``` + +Note that if the ExecuTorch model has graph breaks, there may be multiple extracted *.mlpackage files. diff --git a/docs/source/backends/mps/mps-overview.md b/docs/source/backends/mps/mps-overview.md new file mode 100644 index 00000000000..a2280defad5 --- /dev/null +++ b/docs/source/backends/mps/mps-overview.md @@ -0,0 +1,120 @@ +# MPS Backend + +MPS delegate is the ExecuTorch solution to take advantage of Apple's GPU for on-device ML using the [MPS Graph](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph?language=objc) framework and tuned kernels provided by [MPS](https://developer.apple.com/documentation/metalperformanceshaders?language=objc). + +## Target Requirements + +Below are the minimum OS requirements on various hardware for running a MPS-delegated ExecuTorch model: +- [macOS](https://developer.apple.com/macos) >= 12.4 +- [iOS](https://www.apple.com/ios) >= 15.4 + +## Development Requirements +To develop you need: + +- [Xcode](https://developer.apple.com/xcode/) >= 14.1 + +Before starting, make sure you install the Xcode Command Line Tools: + +```bash +xcode-select --install +``` + +## Using the MPS Backend + +In this step, you will generate a simple ExecuTorch program that lowers MobileNetV3 model to the MPS delegate. You'll then pass this Program (the `.pte` file) during the runtime to run it using the MPS backend. + +```bash +cd executorch +# Note: `mps_example` script uses by default the MPSPartitioner for ops that are not yet supported by the MPS delegate. To turn it off, pass `--no-use_partitioner`. +python3 -m examples.apple.mps.scripts.mps_example --model_name="mv3" --bundled --use_fp16 + +# To see all options, run following command: +python3 -m examples.apple.mps.scripts.mps_example --help +``` + +### Runtime + +**Building the MPS executor runner:** +```bash +# In this step, you'll be building the `mps_executor_runner` that is able to run MPS lowered modules: +cd executorch +./examples/apple/mps/scripts/build_mps_executor_runner.sh +``` + +## Run the mv3 generated model using the mps_executor_runner + +```bash +./cmake-out/examples/apple/mps/mps_executor_runner --model_path mv3_mps_float16_bundled.pte --bundled_program +``` + +- You should see the following results. Note that no output file will be generated in this example: +``` +I 00:00:00.003290 executorch:mps_executor_runner.mm:286] Model file mv3_mps_float16_bundled.pte is loaded. 
+I 00:00:00.003306 executorch:mps_executor_runner.mm:292] Program methods: 1 +I 00:00:00.003308 executorch:mps_executor_runner.mm:294] Running method forward +I 00:00:00.003311 executorch:mps_executor_runner.mm:349] Setting up non-const buffer 1, size 606112. +I 00:00:00.003374 executorch:mps_executor_runner.mm:376] Setting up memory manager +I 00:00:00.003376 executorch:mps_executor_runner.mm:392] Loading method name from plan +I 00:00:00.018942 executorch:mps_executor_runner.mm:399] Method loaded. +I 00:00:00.018944 executorch:mps_executor_runner.mm:404] Loading bundled program... +I 00:00:00.018980 executorch:mps_executor_runner.mm:421] Inputs prepared. +I 00:00:00.118731 executorch:mps_executor_runner.mm:438] Model executed successfully. +I 00:00:00.122615 executorch:mps_executor_runner.mm:501] Model verified successfully. +``` + +### [Optional] Run the generated model directly using pybind +1. Make sure `pybind` MPS support was installed: +```bash +CMAKE_ARGS="-DEXECUTORCH_BUILD_MPS=ON" ./install_executorch.sh +``` +2. Run the `mps_example` script to trace the model and run it directly from python: +```bash +cd executorch +# Check correctness between PyTorch eager forward pass and ExecuTorch MPS delegate forward pass +python3 -m examples.apple.mps.scripts.mps_example --model_name="mv3" --no-use_fp16 --check_correctness +# You should see following output: `Results between ExecuTorch forward pass with MPS backend and PyTorch forward pass for mv3_mps are matching!` + +# Check performance between PyTorch MPS forward pass and ExecuTorch MPS forward pass +python3 -m examples.apple.mps.scripts.mps_example --model_name="mv3" --no-use_fp16 --bench_pytorch +``` + +### Profiling: +1. [Optional] Generate an [ETRecord](etrecord.rst) while you're exporting your model. +```bash +cd executorch +python3 -m examples.apple.mps.scripts.mps_example --model_name="mv3" --generate_etrecord -b +``` +2. Run your Program on the ExecuTorch runtime and generate an [ETDump](etdump.md). +``` +./cmake-out/examples/apple/mps/mps_executor_runner --model_path mv3_mps_float16_bundled.pte --bundled_program --dump-outputs +``` +3. Create an instance of the Inspector API by passing in the ETDump you have sourced from the runtime along with the optionally generated ETRecord from step 1. +```bash +python3 -m devtools.inspector.inspector_cli --etdump_path etdump.etdp --etrecord_path etrecord.bin +``` + +## Runtime integration + +***Step 1***. Create the ExecuTorch core and MPS delegate frameworks to link on iOS +```bash +cd executorch +./scripts/build_apple_frameworks.sh --mps +``` + +`mps_delegate.xcframework` will be in `cmake-out` folder, along with `executorch.xcframework` and `portable_delegate.xcframework`: +```bash +cd cmake-out && ls +``` + +***Step 2***. Link the frameworks into your XCode project: +Go to project Target’s `Build Phases` - `Link Binaries With Libraries`, click the **+** sign and add the frameworks: files located in `Release` folder. +- `executorch.xcframework` +- `portable_delegate.xcframework` +- `mps_delegate.xcframework` + +From the same page, include the needed libraries for the MPS delegate: +- `MetalPerformanceShaders.framework` +- `MetalPerformanceShadersGraph.framework` +- `Metal.framework` + +In this tutorial, you have learned how to lower a model to the MPS delegate, build the mps_executor_runner and run a lowered model through the MPS delegate, or directly on device using the MPS delegate static library. 
diff --git a/docs/source/backends/nxp/nxp-overview.md b/docs/source/backends/nxp/nxp-overview.md new file mode 100644 index 00000000000..973bffe6f19 --- /dev/null +++ b/docs/source/backends/nxp/nxp-overview.md @@ -0,0 +1,71 @@ +# NXP eIQ Neutron Backend + +This manual page is dedicated to introduction NXP eIQ Neutron backend. +NXP offers accelerated machine learning models inference on edge devices. +To learn more about NXP's machine learning acceleration platform, please refer to [the official NXP website](https://www.nxp.com/applications/technologies/ai-and-machine-learning:MACHINE-LEARNING). + +
+For up-to-date information about running ExecuTorch on the Neutron backend, please visit the manual page. +
+ +## Features + + +ExecuTorch v1.0 supports running machine learning models on selected NXP chips (for now only i.MXRT700). +Among currently supported machine learning models are: +- Convolution-based neutral networks +- Full support for MobileNetV2 and CifarNet + +## Target Requirements + +- Hardware with NXP's [i.MXRT700](https://www.nxp.com/products/i.MX-RT700) chip or a evaluation board like MIMXRT700-EVK. + +## Development Requirements + +- [MCUXpresso IDE](https://www.nxp.com/design/design-center/software/development-software/mcuxpresso-software-and-tools-/mcuxpresso-integrated-development-environment-ide:MCUXpresso-IDE) or [MCUXpresso Visual Studio Code extension](https://www.nxp.com/design/design-center/software/development-software/mcuxpresso-software-and-tools-/mcuxpresso-for-visual-studio-code:MCUXPRESSO-VSC) +- [MCUXpresso SDK 25.06](https://mcuxpresso.nxp.com/mcuxsdk/25.06.00/html/index.html) +- eIQ Neutron Converter for MCUXPresso SDK 25.06, what you can download from eIQ PyPI: + +```commandline +$ pip install --index-url https://eiq.nxp.com/repository neutron_converter_SDK_25_06 +``` + +Instead of manually installing requirements, except MCUXpresso IDE and SDK, you can use the setup script: +```commandline +$ ./examples/nxp/setup.sh +``` + +## Using NXP eIQ Backend + +To test converting a neural network model for inference on NXP eIQ Neutron backend, you can use our example script: + +```shell +# cd to the root of executorch repository +./examples/nxp/aot_neutron_compile.sh [model (cifar10 or mobilenetv2)] +``` + +For a quick overview how to convert a custom PyTorch model, take a look at our [example python script](https://github.com/pytorch/executorch/tree/release/1.0/examples/nxp/aot_neutron_compile.py). + + +## Runtime Integration + +To learn how to run the converted model on the NXP hardware, use one of our example projects on using ExecuTorch runtime from MCUXpresso IDE example projects list. +For more finegrained tutorial, visit [this manual page](https://mcuxpresso.nxp.com/mcuxsdk/latest/html/middleware/eiq/executorch/docs/nxp/topics/example_applications.html). + +## Reference + +**→{doc}`nxp-partitioner` — Partitioner options.** + +**→{doc}`nxp-quantization` — Supported quantization schemes.** + +**→{doc}`tutorials/nxp-tutorials` — Tutorials.** + +```{toctree} +:maxdepth: 2 +:hidden: +:caption: NXP Backend + +nxp-partitioner +nxp-quantization +tutorials/nxp-tutorials +``` diff --git a/docs/source/backends/nxp/nxp-partitioner.rst b/docs/source/backends/nxp/nxp-partitioner.rst new file mode 100644 index 00000000000..d6ef1c216fd --- /dev/null +++ b/docs/source/backends/nxp/nxp-partitioner.rst @@ -0,0 +1,43 @@ +=============== +Partitioner API +=============== + +The Neutron partitioner API allows for configuration of the model delegation to Neutron. Passing an ``NeutronPartitioner`` instance with no additional parameters will run as much of the model as possible on the Neutron backend. This is the most common use-case. + +It has the following arguments: + +* `compile_spec` - list of key-value pairs defining compilation: +* `custom_delegation_options` - custom options for specifying node delegation. + +-------------------- +Compile Spec Options +-------------------- +To generate the Compile Spec for Neutron backend, you can use the `generate_neutron_compile_spec` function or directly the `NeutronCompileSpecBuilder().neutron_compile_spec()` +Following fields can be set: + +* `config` - NXP platform defining the Neutron NPU configuration, e.g. "imxrt700". 
+* `neutron_converter_flavor` - Flavor of the neutron-converter module to use. Neutron-converter module named neutron_converter_SDK_25_06' has flavor 'SDK_25_06'. You shall set the flavour to the MCUXpresso SDK version you will use. +* `extra_flags` - Extra flags for the Neutron compiler. +* `operators_not_to_delegate` - List of operators that will not be delegated. + +------------------------- +Custom Delegation Options +------------------------- +By default the Neutron backend is defensive, what means it does not delegate operators which cannot be decided statically during partitioning. But as the model author you typically have insight into the model and so you can allow opportunistic delegation for some cases. For list of options, see +`CustomDelegationOptions `_ + +================ +Operator Support +================ + +Operators are the building blocks of the ML model. See `IRs `_ for more information on the PyTorch operator set. + +This section lists the Edge operators supported by the Neutron backend. +For detailed constraints of the operators see the conditions in the ``is_supported_*`` functions in the `Node converters `_ + + +.. csv-table:: Operator Support + :file: op-support.csv + :header-rows: 1 + :widths: 20 15 30 30 + :align: center \ No newline at end of file diff --git a/docs/source/backends/nxp/nxp-quantization.md b/docs/source/backends/nxp/nxp-quantization.md new file mode 100644 index 00000000000..da7bc94f821 --- /dev/null +++ b/docs/source/backends/nxp/nxp-quantization.md @@ -0,0 +1,106 @@ +# NXP eIQ Neutron Quantization + +The eIQ Neutron NPU requires the operators delegated to be quantized. To quantize the PyTorch model for the Neutron backend, use the `NeutronQuantizer` from `backends/nxp/quantizer/neutron_quantizer.py`. +The `NeutronQuantizer` is configured to quantize the model with quantization scheme supported by the eIQ Neutron NPU. + +### Supported Quantization Schemes + +The Neutron delegate supports the following quantization schemes: + +- Static quantization with 8-bit symmetric weights and 8-bit asymmetric activations (via the PT2E quantization flow), per-tensor granularity. + - Following operators are supported at this moment: + - `aten.abs.default` + - `aten.adaptive_avg_pool2d.default` + - `aten.addmm.default` + - `aten.add.Tensor` + - `aten.avg_pool2d.default` + - `aten.cat.default` + - `aten.conv1d.default` + - `aten.conv2d.default` + - `aten.dropout.default` + - `aten.flatten.using_ints` + - `aten.hardtanh.default` + - `aten.hardtanh_.default` + - `aten.linear.default` + - `aten.max_pool2d.default` + - `aten.mean.dim` + - `aten.mul.Tensor` + - `aten.pad.default` + - `aten.permute.default` + - `aten.relu.default` and `aten.relu_.default` + - `aten.reshape.default` + - `aten.view.default` + - `aten.softmax.int` + - `aten.tanh.default`, `aten.tanh_.default` + - `aten.sigmoid.default` + - `aten.slice_copy.Tensor` + +### Static 8-bit Quantization Using the PT2E Flow + +To perform 8-bit quantization with the PT2E flow, perform the following steps prior to exporting the model to edge: + +1) Create an instance of the `NeutronQuantizer` class. +2) Use `torch.export.export` to export the model to ATen Dialect. +3) Call `prepare_pt2e` with the instance of the `NeutronQuantizer` to annotate the model with observers for quantization. +4) As static quantization is required, run the prepared model with representative samples to calibrate the quantized tensor activation ranges. +5) Call `convert_pt2e` to quantize the model. 
+6) Export and lower the model using the standard flow. + +The output of `convert_pt2e` is a PyTorch model which can be exported and lowered using the normal flow. As it is a regular PyTorch model, it can also be used to evaluate the accuracy of the quantized model using standard PyTorch techniques. + +To quantize the model, you can use the PT2E workflow: + +```python +import torch +import torchvision.models as models +from torchvision.models.mobilenetv2 import MobileNet_V2_Weights +from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec +from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner +from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec +from executorch.exir import to_edge_transform_and_lower +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e + +model = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval() +sample_inputs = (torch.randn(1, 3, 224, 224), ) + +target_spec = NeutronTargetSpec(target="imxrt700", converter_flavor="SDK_25_09") +quantizer = NeutronQuantizer(neutron_target_spec=target_spec) # (1) + +training_ep = torch.export.export(model, sample_inputs).module() # (2) +prepared_model = prepare_pt2e(training_ep, quantizer) # (3) + +for cal_sample in [torch.randn(1, 3, 224, 224)]: # Replace with representative model inputs + prepared_model(cal_sample) # (4) Calibrate + +quantized_model = convert_pt2e(prepared_model) # (5) + +compile_spec = generate_neutron_compile_spec( +    "imxrt700", +    operators_not_to_delegate=None, +    neutron_converter_flavor="SDK_25_06", +) + +et_program = to_edge_transform_and_lower( # (6) +    torch.export.export(quantized_model, sample_inputs), +    partitioner=[NeutronPartitioner(compile_spec=compile_spec)], +).to_executorch() +``` + +Alternatively, you can use the predefined post-training quantization helper from the NXP backend implementation: +```python +from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec +from executorch.backends.nxp.quantizer.utils import calibrate_and_quantize + +... + +target_spec = NeutronTargetSpec(target="imxrt700", converter_flavor="SDK_25_09") +quantized_graph_module = calibrate_and_quantize( +    aten_model, +    calibration_inputs, +    NeutronQuantizer(neutron_target_spec=target_spec), +) +``` + +See [PyTorch 2 Export Post Training Quantization](https://docs.pytorch.org/ao/main/tutorials_source/pt2e_quant_ptq.html) for more information.
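+
+As noted above, the output of `convert_pt2e` is still a regular PyTorch module, so a quick accuracy sanity check against the float model is just a forward pass. A minimal sketch, reusing `model`, `quantized_model`, and `sample_inputs` from the first example on this page:
+
+```python
+# Compare the float model and the converted (quantized) model on a single input.
+with torch.no_grad():
+    reference = model(*sample_inputs)
+    quantized = quantized_model(*sample_inputs)
+
+print("max abs difference:", (reference - quantized).abs().max().item())
+```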
diff --git a/docs/source/backends/nxp/op-support.csv b/docs/source/backends/nxp/op-support.csv new file mode 100644 index 00000000000..581ec3ffb94 --- /dev/null +++ b/docs/source/backends/nxp/op-support.csv @@ -0,0 +1,21 @@ +Operator,Compute DType,Quantization,Constraints +aten.abs.default,int8,static int8, +aten._adaptive_avg_pool2d.default,int8,static int8,"ceil_mode=False, count_include_pad=False, divisor_override=False" +aten.addmm.default,int8,static int8,2D tensor only +aten.add.Tensor,int8,static int8,"alpha = 1, input tensor of rame rank" +aten.avg_pool2d.default,int8,static int8,"ceil_mode=False, count_include_pad=False, divisor_override=False" +aten.cat.default,int8,static int8,"input_channels % 8 = 0, output_channels %8 = 0" +aten.clone.default,int8,static int8, +aten.constant_pad_nd.default,int8,static int8,"H or W padding only" +aten.convolution.default,int8,static int8,"1D or 2D convolution, constant weights, groups=1 or groups=channels_count (depthwise)" +aten.hardtanh.default,int8,static int8,"supported ranges: <0,6>, <-1, 1>, <0,1>, <0,inf>" +aten.max_pool2d.default,int8,static int8,"dilation=1, ceil_mode=False" +aten.max_pool2d_with_indices.default,int8,static int8,"dilation=1, ceil_mode=False" +aten.mean.dim,int8,static int8,"4D tensor only, dims = [-1,-2] or [-2,-1]" +aten.mul.Tensor, int8, static int8, "tensor-size % 8 = 0" +aten.mm.default,int8,static int8,"2D tensor only" +aten.relu.default,int8,static int8, +aten.tanh.default,int8,static int8, +aten.view_copy.default,int8,static int8, +aten.sigmoid.default,int8,static int8, +aten.slice_copy.Tensor, int8, static int8 diff --git a/docs/source/backends/nxp/tutorials/nxp-basic-tutorial.md b/docs/source/backends/nxp/tutorials/nxp-basic-tutorial.md new file mode 100644 index 00000000000..3f183a44f29 --- /dev/null +++ b/docs/source/backends/nxp/tutorials/nxp-basic-tutorial.md @@ -0,0 +1,24 @@ +# Preparing a Model for NXP eIQ Neutron Backend + +This guide demonstrating the use of ExecuTorch AoT flow to convert a PyTorch model to ExecuTorch +format and delegate the model computation to eIQ Neutron NPU using the eIQ Neutron Backend. + +## Step 1: Environment Setup + +This tutorial is intended to be run from a Linux and uses Conda or Virtual Env for Python environment management. For full setup details and system requirements, see [Getting Started with ExecuTorch](/getting-started). + +Create a Conda environment and install the ExecuTorch Python package. +```bash +conda create -y --name executorch python=3.12 +conda activate executorch +conda install executorch +``` + +Run the setup.sh script to install the neutron-converter: +```commandline +$ ./examples/nxp/setup.sh +``` + +## Step 2: Model Preparation and Running the Model on Target + +See the example `aot_neutron_compile.py` and its [README](https://github.com/pytorch/executorch/blob/release/1.0/examples/nxp/README.md) file. 
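+
+If you prefer to see the flow inline, the condensed sketch below mirrors what `aot_neutron_compile.py` does for a quantized model: quantize with the `NeutronQuantizer`, lower with the `NeutronPartitioner`, and save a `.pte` file. Treat it as a sketch only; the model, calibration data, and output file name are placeholders, and the example script and its README remain the authoritative reference.
+
+```python
+import torch
+import torchvision.models as models
+from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
+
+from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
+from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
+from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
+from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
+from executorch.backends.nxp.quantizer.utils import calibrate_and_quantize
+from executorch.exir import to_edge_transform_and_lower
+
+# Placeholder model and calibration data; replace these with your own.
+model = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
+sample_inputs = (torch.randn(1, 3, 224, 224),)
+calibration_inputs = [sample_inputs]
+
+aten_model = torch.export.export(model, sample_inputs).module()
+
+# The converter flavor should match the MCUXpresso SDK version you use (SDK 25.06 in these docs).
+target_spec = NeutronTargetSpec(target="imxrt700", converter_flavor="SDK_25_06")
+quantized_module = calibrate_and_quantize(
+    aten_model,
+    calibration_inputs,
+    NeutronQuantizer(neutron_target_spec=target_spec),
+)
+
+compile_spec = generate_neutron_compile_spec(
+    "imxrt700",
+    neutron_converter_flavor="SDK_25_06",
+)
+
+et_program = to_edge_transform_and_lower(
+    torch.export.export(quantized_module, sample_inputs),
+    partitioner=[NeutronPartitioner(compile_spec=compile_spec)],
+).to_executorch()
+
+et_program.save("mobilenet_v2_neutron.pte")  # placeholder output name
+```
+
+Deploying and running the resulting `.pte` on the board is covered by the MCUXpresso example applications referenced in the backend overview.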
diff --git a/docs/source/backends/nxp/tutorials/nxp-tutorials.md b/docs/source/backends/nxp/tutorials/nxp-tutorials.md new file mode 100644 index 00000000000..eb5b164d668 --- /dev/null +++ b/docs/source/backends/nxp/tutorials/nxp-tutorials.md @@ -0,0 +1,10 @@ +# NXP Tutorials + +**→{doc}`nxp-basic-tutorial` — Lower and run a model on the NXP eIQ Neutron backend.** + +```{toctree} +:hidden: +:maxdepth: 1 + +nxp-basic-tutorial +``` diff --git a/docs/source/backends/samsung/samsung-op-support-table.csv b/docs/source/backends/samsung/samsung-op-support-table.csv new file mode 100644 index 00000000000..7d925c43400 --- /dev/null +++ b/docs/source/backends/samsung/samsung-op-support-table.csv @@ -0,0 +1,45 @@ +Operator,Quantization,Constraints +add,static int8, +avg_pool2d,static int8,"ceil_mode=False, divisor_override=pooling_region" +batch_norm,static int8, +bmm,static int8, +cat,static int8,at most 1 constant tensor +clamp,static int8, +constant_pad_nd,static int8,padding_value=0.0 only +conv2d,static int8,constant weights +dequantize_per_channel,, +dequantize_per_tensor,, +div,static int8, +embedding,static int8, +expand_copy,,"expanding at most one axis, new dimensions must be size 1" +gelu,static int8, +getitem,, +hardsigmoid,static int8, +hardswish,static int8, +hardtanh,static int8, +layer_norm,static int8,norm at last axis only +leaky_relu,static int8, +linear,static int8,constant weights +log_softmax,static int8, +max_pool2d,static int8,"ceil_mode=False, indices not supported" +maximum,, +mean_dim,static int8, +minimum,, +mul,static int8, +permute,static int8, +pixel_shuffle,, +quantize_per_channel,, +quantize_per_tensor,, +relu,static int8, +reshape,static int8, +rsqrt,static int8, +select,static int8, +slice_copy,static int8, +softmax,static int8, +sqrt,static int8, +squeeze,static int8, +sub,static int8, +to_copy,,memory_format=contiguous only +unsqueeze,static int8, +upsample_bilinear2d,static int8, +upsample_nearest2d,static int8, diff --git a/docs/source/backends/samsung/samsung-op-support.rst b/docs/source/backends/samsung/samsung-op-support.rst new file mode 100644 index 00000000000..ecccd565021 --- /dev/null +++ b/docs/source/backends/samsung/samsung-op-support.rst @@ -0,0 +1,11 @@ +================ +Operator Support +================ + +This page lists the PyTorch operators currently supported by the Samsung Exynos backend. + +.. csv-table:: Operator Support + :file: samsung-op-support-table.csv + :header-rows: 1 + :widths: 25 15 55 + :align: center diff --git a/docs/source/backends/samsung/samsung-overview.md b/docs/source/backends/samsung/samsung-overview.md new file mode 100644 index 00000000000..8b0dea0c696 --- /dev/null +++ b/docs/source/backends/samsung/samsung-overview.md @@ -0,0 +1,117 @@ +# Samsung Exynos Backend + +ExecuTorch's Samsung Exynos backend enables the execution of ExecuTorch models on +Samsung SoCs via the NPU/DSP. The delegate is built on top of the +[Samsung Exynos AI Litecore SDK]((https://soc-developer.semiconductor.samsung.com/global/development/ai-litecore)). 
+ +## Features + +- Wide range of operator support +- Supported inference precisions: + - FP16 + - 8-bit statically quantized (int8/uint8) + - 16-bit statically quantized (int16/uint16) + +## Target Requirements + +Currently, the Samsung Exynos backend is supported only for devices with the +following chipsets: + +- Exynos 2500 (E9955) + +## Development Requirements + +The [Samsung Exynos AI Litecore SDK](https://soc-developer.semiconductor.samsung.com/global/development/ai-litecore) +is required to build the Exynos backend from source, and is also required to +export models to the Exynos delegate. + +---- + +## Using the Samsung Exynos Backend + +To target the Exynos backend during the export and lowering process, pass an instance of +the `EnnPartitioner` to `to_edge_transform_and_lower`. The example below +demonstrates this process using the MobileNet V2 model from torchvision. + +```python +import torch +import torchvision.models as models +from torchvision.models.mobilenetv2 import MobileNet_V2_Weights +from executorch.backends.samsung.partition.enn_partitioner import EnnPartitioner +from executorch.backends.samsung.serialization.compile_options import ( + gen_samsung_backend_compile_spec, +) +from executorch.exir import to_edge_transform_and_lower + +mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval() +sample_inputs = (torch.randn(1, 3, 224, 224), ) + +chipset = "E9955" +compile_specs = [gen_samsung_backend_compile_spec(chipset)] + +et_program = to_edge_transform_and_lower( + torch.export.export(mobilenet_v2, sample_inputs), + partitioner=[EnnPartitioner(compile_specs)], +).to_executorch() + +with open("mv2_xnnpack.pte", "wb") as file: + et_program.write_to_file(file) +``` + +See [Partitioner API](samsung-partitioner.md) for a reference on available partitioner options. + +---- + +## Quantization + +The Samsung Exynos backend support statically quantized models with 8-bit and 16-bit +integral types. + +See [Samsung Exynos Quantization](samsung-quantization.md) for more +information on available quantization schemes and APIs. + +---- + +## Runtime Integration + +To run the model on-device, use the standard ExecuTorch runtime APIs. + +The Exynos backend is currently not available in any of ExecuTorch's published packages. +To access it, build ExecuTorch from source. When building from source, pass +`-DEXECUTORCH_BUILD_EXYNOS=ON` when configuring the CMake build. See [Running on Device](/getting-started.md#running-on-device) +for more information. + +Then, to link against the backend, add the `executorch_backends` CMake target as a build +dependency. + +``` +# CMakeLists.txt +add_subdirectory("executorch") +... +target_link_libraries( + my_target + PRIVATE executorch + executorch_backends + ... +) +``` + +No additional steps are necessary to use the backend beyond linking the target. Any +Exynos delegated .pte file will automatically run on the registered backend. 
+ +## Reference + +**→{doc}`samsung-partitioner` — Partitioner options.** + +**→{doc}`samsung-quantization` — Supported quantization schemes.** + +**→{doc}`samsung-op-support` — Supported operators.** + +```{toctree} +:maxdepth: 2 +:hidden: +:caption: Exynos Backend + +samsung-partitioner +samsung-quantization +samsung-op-support diff --git a/docs/source/backends/samsung/samsung-partitioner.md b/docs/source/backends/samsung/samsung-partitioner.md new file mode 100644 index 00000000000..eb84a795551 --- /dev/null +++ b/docs/source/backends/samsung/samsung-partitioner.md @@ -0,0 +1,29 @@ +# Partitioner API + +The `EnnPartitioner` API is the primary entrypoint when exporting a model to the Samsung +Exynos backend. The partitioner is responsible for determining which parts of the model +should be lowered to the backend and also provides an interface for configuring the +behaviour of the backend. + +Currently, the configuration options for `EnnPartitioner` can be generated automatically +using the `gen_samsung_backend_compile_spec` API. For instance, + +```python +from executorch.backends.samsung.partition.enn_partitioner import EnnPartitioner +from executorch.backends.samsung.serialization.compile_options import ( + gen_samsung_backend_compile_spec, +) + +from executorch.exir import to_edge_transform_and_lower + +chipset = "E9955" +compile_specs = [gen_samsung_backend_compile_spec(chipset)] + +et_program = to_edge_transform_and_lower( + exported_program, + partitioner=[EnnPartitioner(compile_specs)], +).to_executorch() +``` + +At the moment, only `"E9955"` is supported as a valid chipset name, which corresponds to +the Exynose 2500 SoC. Support for additional chipsets will be added in the future. diff --git a/docs/source/backends/samsung/samsung-quantization.md b/docs/source/backends/samsung/samsung-quantization.md new file mode 100644 index 00000000000..ad4b50cb93d --- /dev/null +++ b/docs/source/backends/samsung/samsung-quantization.md @@ -0,0 +1,60 @@ +# Quantization + +The Exynos backend currently supports executing statically quantized 8-bit models. + +### 8-bit quantization with the PT2E quantization flow + +To perform 8-bit quantization with the PT2E flow, perform the following steps prior to exporting the model: + +1) Create an instance of the `EnnQuantizer` class and set the desired quantization behaviour. +2) Use `torch.export.export` to obtain a graph module representation of the source model. +3) Use `prepare_pt2e` to prepare the model for quantization. +4) Execute the prepared model with representative samples to calibrate the quantizated tensor activation ranges. +5) Use `convert_pt2e` to quantize the model. +6) Export and lower the model using the standard export flow. + +The output of `convert_pt2e` is a PyTorch model which can be exported and lowered using +the same export flow as non-quantized models. As it is a regular PyTorch model, it can +also be used to evaluate the accuracy of the quantized model using standard PyTorch +techniques. + +The below example shows how to quantize a MobileNetV2 model using the PT2E quantization flow. 
+ +```python +import torch +import torchvision.models as models +from torchvision.models.mobilenetv2 import MobileNet_V2_Weights + +from executorch.backends.samsung.partition.enn_partitioner import EnnPartitioner +from executorch.backends.samsung.quantizer.quantizer import EnnQuantizer, Precision + +from executorch.exir import to_edge_transform_and_lower +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e + +model = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval() +sample_inputs = (torch.randn(1, 3, 224, 224), ) + +# Currently, "A8W8" is the only supported precision mode +precision = "A8W8" +is_per_channel = True +is_qat = False + +quantizer = EnnQuantizer() +quantizer.set_quant_params(precision, is_per_channel, is_qat) # (1) + +training_ep = torch.export.export(model, sample_inputs).module() # (2) +prepared_model = prepare_pt2e(training_ep, quantizer) # (3) + +for cal_sample in [torch.randn(1, 3, 224, 224)]: # Replace with representative model inputs + prepared_model(cal_sample) # (4) Calibrate + +quantized_model = convert_pt2e(prepared_model) # (5) + +et_program = to_edge_transform_and_lower( # (6) + torch.export.export(quantized_model, sample_inputs), + partitioner=[EnnPartitioner()], +).to_executorch() +``` + +See [PyTorch 2 Export Post Training Quantization](https://docs.pytorch.org/ao/main/tutorials_source/pt2e_quant_ptq.html) +for more information. diff --git a/docs/source/backends/template/README.md b/docs/source/backends/template/README.md new file mode 100644 index 00000000000..e7cb037bd6c --- /dev/null +++ b/docs/source/backends/template/README.md @@ -0,0 +1,53 @@ +# Backend Documentation Template + +This template provides a standardized structure and starting point for backend documentation. It is intended to provide a uniform experience for users while allowing for backends to customize their documentation as needed. + +## Template Structure + +The template includes the following files: + +### Required Pages + +- `backend-overview.md` - Main backend overview and introduction + +### Recommended Pages + +- `backend-quantization.md` - Quantization support and API documentation +- `backend-partitioner.md` - Partitioner API reference +- `op-support.csv` - Operator support data in CSV format + +### Optional Pages (and Subsections) + +- `backend-troubleshooting.md` - Common issues and troubleshooting guide +- `backend-op-support.rst` - Operator support documentation (RST format) +- `backend-arch-internals.md` - Architecture and internals documentation +- `tutorials/backend-tutorials.md` - Tutorial sub-section + - Use this sub-section to provide tutorials for your backend. + - Tutorials should explain how a user can accomplish a task, in a step by step manner. + - Some examples might include: + - An end to end example of lowering and running a model on a specific platform. +- `tutorials/backend-guides.md` - Guides sub-section + - Use this sub-section to provide guides or how-tos for backend-specific functionality. + - Guides should focus on providing information and building conceptual understanding, rather than giving step by step directions. + - Some examples might include: + - LLM attention management / static attention + - Performance optimization guide + +## Using the Template + +To use this template for a new backend: + +1. Copy the entire `template` directory contents to your backend's documentation directory +2. Rename files to match your backend name (e.g., `backend-overview.md` → `mybackend-overview.md`) +3. 
Populate the content for your backend. + +### Additional Customization + +You may need to: +- Add backend-specific sections to any file +- Remove sections that don't apply to your backend +- Update the operator support CSV with your backend's supported operators +- Add backend-specific images or diagrams +- Update cross-references and links + +Try to keep the landing page (`backend-overview.md`) simple and straigtforward. Use the child pages and sections to provide more detailed information. diff --git a/docs/source/backends/template/backend-arch-internals.md b/docs/source/backends/template/backend-arch-internals.md new file mode 100644 index 00000000000..66c4a27eb4e --- /dev/null +++ b/docs/source/backends/template/backend-arch-internals.md @@ -0,0 +1,8 @@ +# {BACKEND_NAME} Architecture and Internals + +This page covers internal implementation details of the backend, and is mainly aimed at contributors and heavy power users. This is an optional page for each backend and has no set structure. + +Some topics to consider: + * High-level design of the backend + * Details on the lowering flow + * Internal debugging tools and techniques diff --git a/docs/source/backends/template/backend-overview.md b/docs/source/backends/template/backend-overview.md new file mode 100644 index 00000000000..666b70e1584 --- /dev/null +++ b/docs/source/backends/template/backend-overview.md @@ -0,0 +1,54 @@ +# Backend Template + +Provide a brief overview/description of the backend. At a high-level, what does it do? Consider linking to top-level vendor documentation for the target hardware family and/or framework (Core ML, XNNPACK, etc.). + +## Features + +List high-level features of backend, such as operator and hardware support. + +## Target Requirements + +What hardware and software is required to run the backend on a specific device? For example, does it require specific iOS or Android OS versions? If it's an NPU, what hardware models are supported? + +## Development Requirements + +What software and hardware is needed to create a .PTE file targeting this backend? Are there any additional dependencies that need to be installed that are not included with the ExecuTorch pip package? How does the user install them? + +## Using *Backend Name* + +This section describes the steps users need to take in order to generate a .PTE targeting this backend. Include a full code sample for exporting and lowering a model to this backend. Make sure relevant imports for the backend partitioner are included. + +## Runtime Integration + +This section is intended to tell the user all of the steps they'll need to take to be able to run a .PTE file on-device that is targeting the given backend. +- What CMake targets should they link to? +- How is this backend compiled from source? +- Is the backend bundled by default in iOS and/or Android pre-built libraries? 
+ +## Reference + +**→{doc}`backend-partitioner` — Partitioner options.** + +**→{doc}`backend-quantization` — Supported quantization schemes.** + +**→{doc}`backend-troubleshooting` — Debug common issues.** + +**→{doc}`backend-arch-internals` — Backend internals.** + +**→{doc}`tutorials/backend-tutorials` — Tutorials.** + +**→{doc}`guides/backend-guides` — Tutorials.** + +```{toctree} +:maxdepth: 2 +:hidden: +:caption: {BACKEND} Backend + +backend-troubleshooting +backend-partitioner +backend-quantization +backend-op-support +backend-arch-internals +tutorials/backend-tutorials +guides/backend-guides +``` diff --git a/docs/source/backends/template/backend-partitioner.rst b/docs/source/backends/template/backend-partitioner.rst new file mode 100644 index 00000000000..981e5744aed --- /dev/null +++ b/docs/source/backends/template/backend-partitioner.rst @@ -0,0 +1,25 @@ +========================== +{BACKEND_NAME} Partitioner API +========================== + +Document the partitioner API for the backend, including configuration options and compile specs. + +- ``option1``: Description of the option and values. +- ``option2``: Description of the second option. +- ``option3``: Description of the third option. + +{ADDITIONAL_PARTITIONER_DETAILS} + +================ +Operator Support +================ + +This page lists the operators supported by the {BACKEND_NAME} backend. Operators are the building blocks of the ML model. See `IRs `_ for more information on the PyTorch operator set. + +{OPERATOR_SUPPORT_NOTES} + +.. csv-table:: Operator Support + :file: op-support.csv + :header-rows: 1 + :widths: 20 15 30 30 + :align: center diff --git a/docs/source/backends/template/backend-quantization.md b/docs/source/backends/template/backend-quantization.md new file mode 100644 index 00000000000..4997a56e248 --- /dev/null +++ b/docs/source/backends/template/backend-quantization.md @@ -0,0 +1,31 @@ +# {BACKEND_NAME} Quantization + +Document quantization schemes and flows for the backend. This should include a description of each scheme and a code example to perform quantization. Example sections for PT2E and quantize_ are included below, to be replaced with details for the target backend. + +For each supported quantization scheme, include the following: + * What is the quantization scheme? + * How are weights quantized? + * How are activations quantized? Static or dynamic? + * How many bits? + * What is the granularity? Per-tensor, per-channel, group/block-wise? + * What are the steps to quantize a model with this scheme? + * Include a code sample. + * If the quantization flow only supports a small set of operators - for example, linear only - note this. + +### Supported Quantization Schemes +The {BACKEND_NAME} delegate supports the following quantization schemes: + +- {QUANTIZATION_SCHEME_1} +- {QUANTIZATION_SCHEME_2} + +### {QUANTIZATION_METHOD_1} using the PT2E Flow + +[Description] + +[Code Sample] + +### LLM Quantization with quantize_ + +[Description] + +[Code Sample] diff --git a/docs/source/backends/template/backend-troubleshooting.md b/docs/source/backends/template/backend-troubleshooting.md new file mode 100644 index 00000000000..851c04f34ea --- /dev/null +++ b/docs/source/backends/template/backend-troubleshooting.md @@ -0,0 +1,15 @@ +# {BACKEND_NAME} Troubleshooting + +This page describes common issues that you may encounter when using the {BACKEND_NAME} backend and how to debug and resolve them. 
+
+## {COMMON_ISSUE_1}
+
+{ISSUE_DESCRIPTION_1}
+
+{SOLUTION_STEPS_1}
+
+## {COMMON_ISSUE_2}
+
+{ISSUE_DESCRIPTION_2}
+
+{SOLUTION_STEPS_2}
diff --git a/docs/source/backends/template/guides/backend-basic-guide.md b/docs/source/backends/template/guides/backend-basic-guide.md
new file mode 100644
index 00000000000..44f86d8bd4d
--- /dev/null
+++ b/docs/source/backends/template/guides/backend-basic-guide.md
@@ -0,0 +1,3 @@
+# Using {FEATURE} on {BACKEND_NAME}
+
+This is a placeholder guide.
diff --git a/docs/source/backends/template/guides/backend-guides.md b/docs/source/backends/template/guides/backend-guides.md
new file mode 100644
index 00000000000..dbeaf25742a
--- /dev/null
+++ b/docs/source/backends/template/guides/backend-guides.md
@@ -0,0 +1,10 @@
+# {BACKEND_NAME} Guides
+
+**→{doc}`{backend_name}-basic-guide` — Guide description.**
+
+```{toctree}
+:hidden:
+:maxdepth: 1
+
+{backend_name}-basic-guide
+```
diff --git a/docs/source/backends/template/op-support.csv b/docs/source/backends/template/op-support.csv
new file mode 100644
index 00000000000..66af56d6a44
--- /dev/null
+++ b/docs/source/backends/template/op-support.csv
@@ -0,0 +1,6 @@
+Operator,Compute DType,Quantization,Constraints
+{OPERATOR_1},{DTYPE_SUPPORT_1},{QUANTIZATION_SUPPORT_1},{CONSTRAINTS_1}
+{OPERATOR_2},{DTYPE_SUPPORT_2},{QUANTIZATION_SUPPORT_2},{CONSTRAINTS_2}
+{OPERATOR_3},{DTYPE_SUPPORT_3},{QUANTIZATION_SUPPORT_3},{CONSTRAINTS_3}
+{OPERATOR_4},{DTYPE_SUPPORT_4},{QUANTIZATION_SUPPORT_4},{CONSTRAINTS_4}
+{OPERATOR_5},{DTYPE_SUPPORT_5},{QUANTIZATION_SUPPORT_5},{CONSTRAINTS_5}
diff --git a/docs/source/backends/template/tutorials/backend-basic-tutorial.md b/docs/source/backends/template/tutorials/backend-basic-tutorial.md
new file mode 100644
index 00000000000..23d76857116
--- /dev/null
+++ b/docs/source/backends/template/tutorials/backend-basic-tutorial.md
@@ -0,0 +1,91 @@
+# Preparing a Model for {BACKEND_NAME}
+
+This is a placeholder tutorial.
+
+## Step 1: Environment Setup
+
+This tutorial is intended to be run from a {SUPPORTED_HOST_OS} and uses Conda for Python environment management. For full setup details and system requirements, see [Getting Started with ExecuTorch](/getting-started).
+
+Create a Conda environment and install the ExecuTorch Python package.
+```bash
+conda create -y --name executorch python=3.12
+conda activate executorch
+pip install executorch
+```
+
+{ADDITIONAL_SETUP_STEPS}
+
+## Step 2: Model Preparation
+
+Create a Python file named `export_{model_filename}.py`. This script will be responsible for loading the {EXAMPLE_MODEL} model from {MODEL_SOURCE} and creating a {BACKEND_NAME}-targeted .pte file.
+
+```py
+# export_{model_filename}.py
+from executorch.backends.{backend_name}.partition.{backend_name}_partitioner import {BackendName}Partitioner
+from executorch.exir import to_edge_transform_and_lower
+import torch
+import {MODEL_IMPORT}
+```
+
+### Model Instantiation and Example Inputs
+
+Instantiate the {EXAMPLE_MODEL} model from [{MODEL_SOURCE}]({MODEL_SOURCE_URL}). The export process also needs an example model input to trace the model. The model takes {MODEL_INPUT_DESCRIPTION}, so we'll create {INPUT_TUPLE_DESCRIPTION}.
+```py
+model = {MODEL_INSTANTIATION_CODE}
+example_inputs = ({EXAMPLE_INPUTS},)
+```
+
+### Lower the Model
+
+Next, export and lower the model to ExecuTorch. Note that the `{BackendName}Partitioner` passed to the `partitioner` parameter tells ExecuTorch to target the {BACKEND_NAME} backend.
+```py +exported_program = torch.export.export(model, example_inputs) + +executorch_program = to_edge_transform_and_lower( + exported_program, + partitioner=[{BackendName}Partitioner()], +).to_executorch() + +executorch_program.save("{model_filename}_{backend_name}.pte") +``` + +### Run the Script + +Save the above script to export_{model_filename}.py and run the script. You should see a file named `{model_filename}_{backend_name}.pte` in the current directory. +```bash +python export_{model_filename}.py +``` + +## Step 3: Running the Model + +The .pte file created in the previous step can be run on a variety of devices, including {SUPPORTED_PLATFORMS}. ExecuTorch provides runtime APIs and language bindings for a variety of platforms. This tutorial will demonstrate running the model on a desktop using the Python runtime. + +### Smoke Test + +First, we'll verify that the model loads and runs correctly by running the model with {TEST_INPUT_DESCRIPTION}. Create a new script, named `run_{model_filename}.py`, and add the following code. +```py +# run_{model_filename}.py + +from executorch.runtime import Runtime +import torch + +runtime = Runtime.get() + +input_tensor = {TEST_INPUT_TENSOR} +program = runtime.load_program("{model_filename}_{backend_name}.pte") +method = program.load_method("forward") +outputs = method.execute([input_tensor])[0] + +print(outputs) +``` + +When running the script with `python run_{model_filename}.py`, you should see {EXPECTED_OUTPUT_DESCRIPTION} printed to the console. +``` +{EXPECTED_OUTPUT_EXAMPLE} +``` + +# Next Steps + + - See [Edge Platforms](/edge-platforms-section) to deploy the .pte file on {SUPPORTED_PLATFORMS}. + - See [Model Export and Lowering](/using-executorch-export) for more information on model preparation. + - See [{BACKEND_NAME} Overview](/backends/{backend_name}/{backend_name}-overview) for more information about the {BACKEND_NAME} backend. diff --git a/docs/source/backends/template/tutorials/backend-tutorials.md b/docs/source/backends/template/tutorials/backend-tutorials.md new file mode 100644 index 00000000000..15e226dd5c5 --- /dev/null +++ b/docs/source/backends/template/tutorials/backend-tutorials.md @@ -0,0 +1,10 @@ +# {BACKEND_NAME} Tutorials + +**→{doc}`{backend_name}-basic-tutorial` — Lower and run a model on the {BACKEND_NAME} backend.** + +```{toctree} +:hidden: +:maxdepth: 1 + +{backend_name}-basic-tutorial +``` diff --git a/docs/source/backends/vulkan/tutorials/etvk-llama-tutorial.md b/docs/source/backends/vulkan/tutorials/etvk-llama-tutorial.md new file mode 100644 index 00000000000..cb14c72331e --- /dev/null +++ b/docs/source/backends/vulkan/tutorials/etvk-llama-tutorial.md @@ -0,0 +1,159 @@ +# Exporting Llama 3.2 1B/3B Instruct to ExecuTorch Vulkan and running on device + +This tutorial assumes that you have a working local copy of the ExecuTorch repo, +and have gone through the steps to install the executorch pip package or have +installed it by building from source. + +This tutorial also assumes that you have the Android SDK tools installed and +that you are able to connect to an Android device via `adb`. + +Finally, the Android NDK should also be installed, and your environment should +have a variable `ANDROID_NDK` that points to the root directory of the NDK. + +```shell +export ANDROID_NDK= +``` + +## Download the Llama 3.2 1B/3B Instruct model checkpoint and tokenizer + +The model checkpoint and tokenizer can be downloaded from the +[Meta Llama website](https://www.llama.com/llama-downloads/). 
+
+The model files should be downloaded to `~/.llama/checkpoints/Llama3.2-1B-Instruct`.
+
+## Export the Llama 3.2 1B/3B model
+
+First, navigate to the root of the ExecuTorch repo.
+
+```shell
+# Navigate to executorch root
+cd ~/executorch
+```
+
+Then, set some environment variables to describe how the model should be
+exported. Feel free to tune the values to your preferences.
+
+```shell
+export LLM_NAME=Llama3.2 && \
+export LLM_SIZE=1B && \
+export LLM_SUFFIX="-Instruct" && \
+export QUANT=8da4w && \
+export BACKEND=vulkan && \
+export GROUP_SIZE=64 && \
+export CONTEXT_LENGTH=2048
+```
+
+Then, export the Llama 3.2 1B/3B Instruct model to ExecuTorch Vulkan. Note that
+the `--vulkan-force-fp16` flag is set, which will improve model inference
+latency at the cost of model accuracy. Feel free to remove this flag.
+
+```shell
+python -m examples.models.llama.export_llama \
+  -c $HOME/.llama/checkpoints/${LLM_NAME}-${LLM_SIZE}${LLM_SUFFIX}/consolidated.00.pth \
+  -p $HOME/.llama/checkpoints/${LLM_NAME}-${LLM_SIZE}${LLM_SUFFIX}/params.json \
+  -d fp32 --${BACKEND} \
+  -qmode ${QUANT} -G ${GROUP_SIZE} \
+  --max_seq_length ${CONTEXT_LENGTH} \
+  --max_context_length ${CONTEXT_LENGTH} \
+  -kv --use_sdpa_with_kv_cache \
+  --metadata '{"append_eos_to_prompt": 0, "get_bos_id":128000, "get_eos_ids":[128009, 128001]}' \
+  --model "llama3_2" \
+  --output_name $HOME/.llama/checkpoints/${LLM_NAME}-${LLM_SIZE}${LLM_SUFFIX}/${LLM_NAME}-${LLM_SIZE}${LLM_SUFFIX}_${BACKEND}_${QUANT}_g${GROUP_SIZE}_c${CONTEXT_LENGTH}.pte
+```
+
+After exporting the model, push the exported `.pte` file and the tokenizer to
+your device.
+
+```shell
+adb shell mkdir -p /data/local/tmp/llama && \
+adb push ~/.llama/checkpoints/${LLM_NAME}-${LLM_SIZE}${LLM_SUFFIX}/tokenizer.model \
+  /data/local/tmp/llama/${LLM_NAME}-${LLM_SIZE}${LLM_SUFFIX}_tokenizer.model && \
+adb push ~/.llama/checkpoints/${LLM_NAME}-${LLM_SIZE}${LLM_SUFFIX}/${LLM_NAME}-${LLM_SIZE}${LLM_SUFFIX}_${BACKEND}_${QUANT}_g${GROUP_SIZE}_c${CONTEXT_LENGTH}.pte \
+  /data/local/tmp/llama/${LLM_NAME}-${LLM_SIZE}${LLM_SUFFIX}_${BACKEND}_${QUANT}_g${GROUP_SIZE}_c${CONTEXT_LENGTH}.pte
+```
+
+## Build Core ExecuTorch Components
+
+To run the `.pte` file on device, the core libraries, including the Vulkan
+backend, must first be compiled for Android.
+
+```shell
+cmake . \
+  -DCMAKE_INSTALL_PREFIX=cmake-out-android-so \
+  -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
+  -DANDROID_SUPPORT_FLEXIBLE_PAGE_SIZES=ON \
+  --preset "android-arm64-v8a" \
+  -DANDROID_PLATFORM=android-28 \
+  -DPYTHON_EXECUTABLE=python \
+  -DCMAKE_BUILD_TYPE=Release \
+  -DEXECUTORCH_PAL_DEFAULT=posix \
+  -DEXECUTORCH_BUILD_LLAMA_JNI=ON \
+  -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \
+  -DEXECUTORCH_BUILD_VULKAN=ON \
+  -DEXECUTORCH_BUILD_TESTS=OFF \
+  -Bcmake-out-android-so && \
+cmake --build cmake-out-android-so -j16 --target install --config Release
```

+
+## Build and push the llama runner binary to Android
+
+Then, build a binary that can be used to run the `.pte` file.
+ +```shell +cmake examples/models/llama \ + -DCMAKE_INSTALL_PREFIX=cmake-out-android-so \ + -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ + -DANDROID_SUPPORT_FLEXIBLE_PAGE_SIZES=ON \ + -DEXECUTORCH_ENABLE_LOGGING=ON \ + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=android-28 \ + -DCMAKE_BUILD_TYPE=Release \ + -DPYTHON_EXECUTABLE=python \ + -Bcmake-out-android-so/examples/models/llama && \ +cmake --build cmake-out-android-so/examples/models/llama -j16 --config Release +``` + +Once the binary is built, it can be pushed to your Android device. + +```shell +adb shell mkdir /data/local/tmp/etvk/ && \ +adb push cmake-out-android-so/examples/models/llama/llama_main /data/local/tmp/etvk/ +``` + +## Execute the llama runner binary + +Finally, we can execute the lowered `.pte` file on your device. + +```shell +adb shell /data/local/tmp/etvk/llama_main \ + --model_path=/data/local/tmp/llama/${LLM_NAME}-${LLM_SIZE}${LLM_SUFFIX}_${BACKEND}_${QUANT}_g${GROUP_SIZE}_c${CONTEXT_LENGTH}.pte \ + --tokenizer_path=/data/local/tmp/llama/${LLM_NAME}-${LLM_SIZE}${LLM_SUFFIX}_tokenizer.model \ + --temperature=0 --seq_len=400 --warmup \ + --prompt=\"\<\|begin_of_text\|\>\<\|start_header_id\|\>system\<\|end_header_id\|\>Write me a short poem.\<\|eot_id\|\>\<\|start_header_id\|\>assistant\<\|end_header_id\|\>\" +``` + +Here is some sample output captured from a Galaxy S24: + +```shell +E tokenizers:hf_tokenizer.cpp:60] Error parsing json file: [json.exception.parse_error.101] parse error at line 1, column 1: syntax error while parsing value - invalid literal; last read: 'I' +<|begin_of_text|><|start_header_id|>system<|end_header_id|>Write me a short poem.<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +Here is a short poem I came up with: + +"Moonlight whispers secrets to the night +A gentle breeze that rustles the light +The stars up high, a twinkling show +A peaceful world, where dreams grow slow" + +I hope you enjoy it!<|eot_id|> + +PyTorchObserver {"prompt_tokens":14,"generated_tokens":54,"model_load_start_ms":1760077800721,"model_load_end_ms":1760077802998,"inference_start_ms":1760077802998,"inference_end_ms":1760077804187,"prompt_eval_end_ms":1760077803162,"first_token_ms":1760077803162,"aggregate_sampling_time_ms":19,"SCALING_FACTOR_UNITS_PER_SECOND":1000} + Prompt Tokens: 14 Generated Tokens: 54 + Model Load Time: 2.277000 (seconds) + Total inference time: 1.189000 (seconds) Rate: 45.416316 (tokens/second) + Prompt evaluation: 0.164000 (seconds) Rate: 85.365854 (tokens/second) + Generated 54 tokens: 1.025000 (seconds) Rate: 52.682927 (tokens/second) + Time to first generated token: 0.164000 (seconds) + Sampling time over 68 tokens: 0.019000 (seconds) +``` diff --git a/docs/source/backends/vulkan/tutorials/etvk-profiling-tutorial.md b/docs/source/backends/vulkan/tutorials/etvk-profiling-tutorial.md new file mode 100644 index 00000000000..07982d81c1c --- /dev/null +++ b/docs/source/backends/vulkan/tutorials/etvk-profiling-tutorial.md @@ -0,0 +1,144 @@ +# Executing and profiling an ExecuTorch Vulkan model on device + +This tutorial assumes that you have a working local copy of the ExecuTorch repo, +and have gone through the steps to install the executorch pip package or have +installed it by building from source. + +This tutorial also assumes that you have the Android SDK tools installed and +that you are able to connect to an Android device via `adb`. 
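+
+If you want to confirm the device connection before proceeding, `adb devices` should list your device (the serial number and device name below are placeholders and will differ on your system):
+
+```shell
+adb devices
+# List of devices attached
+# <device-serial>    device
+```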
+ +Finally, the Android NDK should also be installed, and your environment should +have a variable `ANDROID_NDK` that points to the root directory of the NDK. + +```shell +export ANDROID_NDK= +``` + +## Lower a model to ExecuTorch Vulkan and obtain the `.pte` file + + +The commands in this tutorial are assumed to be executed from ExecuTorch's root +directory. + +```shell +cd ~/executorch +``` + +For this tutorial, we will use the export script in +[`executorch/examples/vulkan/export.py`](https://github.com/pytorch/executorch/tree/main/examples/vulkan), +however any method of generating a `.pte` file will suffice. In this tutorial, +the InceptionV3 model is exported. + +```shell +python -m examples.vulkan.export --model_name=ic3 -o . -fp16 +``` + +After exporting, there should be a file called `ic3_vulkan.pte` in the root +directory of ExecuTorch. Feel free to modify the `-o` argument of the script to +control where the `.pte` file will be stored. + +Then, push the `.pte` file to device. + +```shell +adb shell mkdir -p /data/local/tmp/etvk/models/ && \ +adb push ic3_vulkan.pte /data/local/tmp/etvk/models/ic3_vulkan.pte +``` + +## Build the `executor_runner` binary and push to device + +To be able to run the `.pte` file on device, first the core libraries, +including the Vulkan backend, must be compiled for Android. Note that +`-DEXECUTORCH_ENABLE_EVENT_TRACER=ON` is used to turn on profiling, and +`-DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON` is used to build the runner binary that +will be used to execute and profile the `.pte` file. + + +```shell +cmake . \ + -DCMAKE_INSTALL_PREFIX=cmake-out-android-so \ + -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ + -DANDROID_SUPPORT_FLEXIBLE_PAGE_SIZES=ON \ + --preset "android-arm64-v8a" \ + -DANDROID_PLATFORM=android-28 \ + -DPYTHON_EXECUTABLE=python \ + -DCMAKE_BUILD_TYPE=Release \ + -DEXECUTORCH_PAL_DEFAULT=posix \ + -DEXECUTORCH_BUILD_LLAMA_JNI=ON \ + -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ + -DEXECUTORCH_BUILD_VULKAN=ON \ + -DEXECUTORCH_BUILD_TESTS=OFF \ + -DEXECUTORCH_BUILD_EXTENSION_EVALUE_UTIL=ON \ + -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \ + -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \ + -Bcmake-out-android-so && \ +cmake --build cmake-out-android-so -j16 --target install --config Release +``` + +Once the build completes, we can push the runner binary to device. + +```shell +adb push cmake-out-android-so/executor_runner /data/local/tmp/etvk/executor_runner +``` + +## Execute the `.pte` file + +Finally, we can execute the lowered `.pte` file on your device. 
To test run the +model file without profiling: + +```shell +adb shell /data/local/tmp/etvk/executor_runner \ + --model_path /data/local/tmp/etvk/models/ic3_vulkan.pte +``` + +Now, with profiling: + +```shell +MODEL_NAME=ic3 && \ +BACKEND=vulkan && \ +NUM_ITERS=3 && \ +adb shell mkdir -p /data/local/tmp/etvk/etdumps/ && \ +adb shell /data/local/tmp/etvk/executor_runner \ + --model_path /data/local/tmp/etvk/models/${MODEL_NAME}_${BACKEND}.pte \ + --num_executions=${NUM_ITERS} \ + --etdump_path /data/local/tmp/etvk/etdumps/${MODEL_NAME}_${BACKEND}.etdp && \ +adb pull /data/local/tmp/etvk/etdumps/${MODEL_NAME}_${BACKEND}.etdp ${MODEL_NAME}_${BACKEND}.etdp && \ +adb shell rm /data/local/tmp/etvk/etdumps/${MODEL_NAME}_${BACKEND}.etdp && \ +python devtools/inspector/inspector_cli.py \ + --etdump_path ${MODEL_NAME}_${BACKEND}.etdp +``` + +Here is some sample (tailed) output from a Samsung Galaxy S24: + +```shell +├─────┼────────────────────┼────────────────────────────────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼────────────┼───────────────────┼─────────────────────────┼────────────────────────────────────────────────────────┤ +│ 165 │ Execute │ conv2d_clamp_half_163 │ 0.345082 │ 0.346164 │ 0.346247 │ 0.345748 │ 0.344812 │ 0.346268 │ [] │ True │ │ [2081488974948084, 2081488995911052, 2081489016763676] │ +├─────┼────────────────────┼────────────────────────────────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼────────────┼───────────────────┼─────────────────────────┼────────────────────────────────────────────────────────┤ +│ 166 │ Execute │ conv2d_clamp_half_164 │ 0.306124 │ 0.30654 │ 0.306998 │ 0.306557 │ 0.30602 │ 0.307112 │ [] │ True │ │ [2081488975294716, 2081488996256228, 2081489017110204] │ +├─────┼────────────────────┼────────────────────────────────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼────────────┼───────────────────┼─────────────────────────┼────────────────────────────────────────────────────────┤ +│ 167 │ Execute │ set_zero_int32_165 │ 0.00240245 │ 0.00244403 │ 0.00248561 │ 0.00244403 │ 0.00239205 │ 0.002496 │ [] │ True │ │ [2081488975601100, 2081488996563132, 2081489017417680] │ +├─────┼────────────────────┼────────────────────────────────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼────────────┼───────────────────┼─────────────────────────┼────────────────────────────────────────────────────────┤ +│ 168 │ Execute │ concat_2_texture3d_half_166 │ 0.0122305 │ 0.01248 │ 0.0125634 │ 0.0124108 │ 0.0121682 │ 0.0125842 │ [] │ True │ │ [2081488975603960, 2081488996565940, 2081489017420436] │ +├─────┼────────────────────┼────────────────────────────────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼────────────┼───────────────────┼─────────────────────────┼────────────────────────────────────────────────────────┤ +│ 169 │ Execute │ set_zero_int32_167 │ 0.00157056 │ 0.00161195 │ 0.00161214 │ 0.00159478 │ 0.00156021 │ 0.00161219 │ [] │ True │ │ [2081488975616804, 2081488996578888, 2081489017432968] │ +├─────┼────────────────────┼────────────────────────────────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼────────────┼───────────────────┼─────────────────────────┼────────────────────────────────────────────────────────┤ +│ 170 │ Execute │ concat_3_texture3d_half_168 │ 
0.0420369 │ 0.0423281 │ 0.0427857 │ 0.0423974 │ 0.0419641 │ 0.0429001 │ [] │ True │ │ [2081488975618728, 2081488996580864, 2081489017434944] │ +├─────┼────────────────────┼────────────────────────────────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼────────────┼───────────────────┼─────────────────────────┼────────────────────────────────────────────────────────┤ +│ 171 │ Execute │ update_concat_offset_3_int32_169 │ 0.00261035 │ 0.00265193 │ 0.00265212 │ 0.00263468 │ 0.00259995 │ 0.00265217 │ [] │ True │ │ [2081488975661992, 2081488996623556, 2081489017477272] │ +├─────┼────────────────────┼────────────────────────────────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼────────────┼───────────────────┼─────────────────────────┼────────────────────────────────────────────────────────┤ +│ 172 │ Execute │ concat_1_texture3d_half_170 │ 0.00758157 │ 0.00774789 │ 0.00803914 │ 0.00779994 │ 0.00753999 │ 0.00811195 │ [] │ True │ │ [2081488975664956, 2081488996626572, 2081489017480288] │ +├─────┼────────────────────┼────────────────────────────────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼────────────┼───────────────────┼─────────────────────────┼────────────────────────────────────────────────────────┤ +│ 173 │ Execute │ mean2d_half_171 │ 0.0147889 │ 0.0148721 │ 0.0150384 │ 0.0149067 │ 0.0147681 │ 0.01508 │ [] │ True │ │ [2081488975673432, 2081488996634476, 2081489017488400] │ +├─────┼────────────────────┼────────────────────────────────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼────────────┼───────────────────┼─────────────────────────┼────────────────────────────────────────────────────────┤ +│ 174 │ Execute │ view_half_172 │ 0.00644803 │ 0.00644803 │ 0.00653119 │ 0.00648268 │ 0.00644803 │ 0.00655198 │ [] │ True │ │ [2081488975688876, 2081488996649712, 2081489017503532] │ +├─────┼────────────────────┼────────────────────────────────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼────────────┼───────────────────┼─────────────────────────┼────────────────────────────────────────────────────────┤ +│ 175 │ Execute │ view_half_173 │ 0.00488806 │ 0.00488806 │ 0.00488806 │ 0.00488806 │ 0.00488806 │ 0.00488806 │ [] │ True │ │ [2081488975695688, 2081488996656524, 2081489017510448] │ +├─────┼────────────────────┼────────────────────────────────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼────────────┼───────────────────┼─────────────────────────┼────────────────────────────────────────────────────────┤ +│ 176 │ Execute │ linear_naive_texture3d_half_174 │ 0.586726 │ 0.590096 │ 0.595338 │ 0.590876 │ 0.585884 │ 0.596648 │ [] │ True │ │ [2081488975700940, 2081488996661776, 2081489017515700] │ +├─────┼────────────────────┼────────────────────────────────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼────────────┼───────────────────┼─────────────────────────┼────────────────────────────────────────────────────────┤ +│ 177 │ Execute │ image_to_nchw_texture3d_half_float_175 │ 0.00270395 │ 0.00270414 │ 0.00274572 │ 0.00272139 │ 0.00270391 │ 0.00275612 │ [] │ True │ │ [2081488976297952, 2081488997248024, 2081489018106160] │ 
+├─────┼────────────────────┼────────────────────────────────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼────────────┼───────────────────┼─────────────────────────┼────────────────────────────────────────────────────────┤ +│ 178 │ Execute │ DELEGATE_CALL │ 20.8864 │ 20.9461 │ 21.5925 │ 21.1906 │ 20.8715 │ 21.7541 │ [] │ False │ │ [358395625, 380178646, 401147657] │ +├─────┼────────────────────┼────────────────────────────────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼──────────────┼────────────┼───────────────────┼─────────────────────────┼────────────────────────────────────────────────────────┤ +│ 179 │ Execute │ Method::execute │ 20.8867 │ 20.9464 │ 21.593 │ 21.191 │ 20.8718 │ 21.7547 │ [] │ False │ │ [358395521, 380178542, 401147552] │ +╘═════╧════════════════════╧════════════════════════════════════════╧══════════════╧══════════════╧══════════════╧══════════════╧══════════════╧══════════════╧════════════╧═══════════════════╧═════════════════════════╧════════════════════════════════════════════════════════╛ +``` diff --git a/docs/source/backends/vulkan/tutorials/vulkan-tutorials.md b/docs/source/backends/vulkan/tutorials/vulkan-tutorials.md new file mode 100644 index 00000000000..953c93a9c12 --- /dev/null +++ b/docs/source/backends/vulkan/tutorials/vulkan-tutorials.md @@ -0,0 +1,13 @@ +# Vulkan Backend Tutorials + +**→{doc}`etvk-profiling-tutorial`** + +**→{doc}`etvk-llama-tutorial`** + +```{toctree} +:maxdepth: 2 +:hidden: +:caption: Tutorials + +etvk-profiling-tutorial +etvk-llama-tutorial diff --git a/docs/source/backends/vulkan/vulkan-op-support-table.csv b/docs/source/backends/vulkan/vulkan-op-support-table.csv new file mode 100644 index 00000000000..34d2ece924a --- /dev/null +++ b/docs/source/backends/vulkan/vulkan-op-support-table.csv @@ -0,0 +1,113 @@ +Namespace,Operator,Notes +aten,_log_softmax, +aten,_native_batch_norm_legit_no_training, +aten,_softmax, +aten,_to_copy,dtype conversion between float types only +aten,_weight_int8pack_mm, +aten,abs, +aten,add, +aten,addmm, +aten,amax,keepdim=True required; max 2D reductions +aten,amin,keepdim=True required; max 2D reductions +aten,arange, +aten,avg_pool2d, +aten,bmm, +aten,cat, +aten,clamp, +aten,clone, +aten,constant_pad_nd, +aten,convolution,batch=1 for 2D conv; no transposed 1D conv; no 3D conv +aten,cos, +aten,div, +aten,div.Tensor_mode, +aten,embedding, +aten,eq, +aten,exp, +aten,expand_copy,no resize support +aten,flip, +aten,full, +aten,full_like, +aten,ge, +aten,gelu, +aten,gt, +aten,hardshrink, +aten,hardtanh, +aten,index_select, +aten,le, +aten,leaky_relu, +aten,linear, +aten,lt, +aten,max_pool2d, +aten,max_pool2d_with_indices, +aten,mean,keepdim=True required; max 2D reductions +aten,minimum, +aten,mm, +aten,native_group_norm, +aten,native_layer_norm,resize supported +aten,neg, +aten,ones, +aten,ones_like, +aten,permute, +aten,permute_copy, +aten,pow, +aten,relu, +aten,repeat, +aten,round, +aten,rsqrt, +aten,scalar_tensor, +aten,select_copy, +aten,sigmoid, +aten,sin, +aten,slice_copy, +aten,split, +aten,split_with_sizes_copy, +aten,sqrt, +aten,squeeze_copy, +aten,sub, +aten,sum,keepdim=True required; max 2D reductions +aten,t_copy, +aten,tanh, +aten,unsqueeze_copy, +aten,upsample_bilinear2d, +aten,upsample_nearest2d, +aten,view_copy, +aten,zeros, +aten,zeros_like, +aten,_assert_scalar,removed via graph pass +aten,sym_constrain_range_for_size,removed via graph pass +aten,sym_size, +dim_order_ops,_clone_dim_order,no dtype 
conversion; removable if no dtype change +dim_order_ops,_to_dim_order_copy,no dtype conversion; removable if no dtype change +llama,custom_sdpa, +llama,sdpa_with_kv_cache, +llama,update_cache, +operator,add, +operator,eq, +operator,ge, +operator,getitem, +operator,gt, +operator,le, +operator,lt, +quantized_decomposed,choose_qparams, +quantized_decomposed,choose_qparams_per_token_asymmetric, +quantized_decomposed,dequantize_per_channel, +quantized_decomposed,dequantize_per_tensor, +quantized_decomposed,dequantize_per_token, +quantized_decomposed,quantize_per_channel, +quantized_decomposed,quantize_per_tensor, +quantized_decomposed,quantize_per_token, +torchao,choose_qparams_affine, +torchao,dequantize_affine, +torchao,quantize_affine, +et_vk,add_q8ta_q8ta_q8to,no resize support +et_vk,apply_rotary_emb, +et_vk,conv2d_q8ta_q8csw_q8to,no resize support +et_vk,conv2d_q8ta_q8csw_q8to_dw,no resize support +et_vk,conv_with_clamp,batch=1 for 2D conv; no transposed 1D conv +et_vk,dequantize_q8to_from_conv2d,no resize support +et_vk,grid_priors, +et_vk,linear_dq8ca_q4gsw, +et_vk,linear_q4gsw, +et_vk,linear_q8ta_q8csw, +et_vk,linear_qcs4w, +et_vk,quantize_q8ta_for_conv2d,no resize support diff --git a/docs/source/backends/vulkan/vulkan-op-support.rst b/docs/source/backends/vulkan/vulkan-op-support.rst new file mode 100644 index 00000000000..547f7f9dc6c --- /dev/null +++ b/docs/source/backends/vulkan/vulkan-op-support.rst @@ -0,0 +1,46 @@ +================ +Operator Support +================ + +This page lists the operators currently supported by the Vulkan backend. The +source of truth for this information is `op_registry.py `_, +which is used by the Vulkan Partitioner to determine which operators should be +lowered to the Vulkan backend and additionally describes the capabilities of +each operator implementation. + +If an operator used in your model is not in this list, feel free to create a +feature request on Github and we will do our best to add an implementation for +the operator. + +The namespace of an operator describes where it originates from: + +* **aten** - operators in this namespace correspond 1:1 to operators in PyTorch's + `ATen library `_. + They all support fp16 and fp32 dtypes at a minimum. +* **dim_order_op** - these operators are inserted when lowering to ExecuTorch in + order to manage optimal tensor memory layouts. They are typically removed, + since the Vulkan backend manages optimal tensor representations internally. +* **llama** - custom ops targeted for LLM inference. These are typically inserted + by model source transformations applied to a `nn.Module` and are not invoked + directly by a PyTorch model. +* **operator** - these operators work with symbolic integers, which are also + supported by the Vulkan backend. +* **quantized_decomposed** / **torchao** - these ops are introduced by quantization + workflows (either torchao's `quantize_` API or the PT2E quantization flow). + They typically represent quantizing/dequantizing a tensor, or choosing the + quantization parameters for a tensor. In practice, most instances of these + operators will be fused into a custom op in the **et_vk** namespace. +* **et_vk** - these are custom operators implemented only in the Vulkan backend. + They typically represent quantized variants of **aten** operators, or fusions + of common operator patterns. They are inserted by operator fusion graph passes + when lowering to the Vulkan backend. + +All operators support dynamic input shapes unless otherwise noted (i.e. "no +resize support"). 
The expectation is that over time, all operators will be able
+to support dynamic shapes.
+
+.. csv-table:: Vulkan Backend Operator Support
+   :file: vulkan-op-support-table.csv
+   :header-rows: 1
+   :widths: 25 25 75
+   :align: left
diff --git a/docs/source/backends/vulkan/vulkan-overview.md b/docs/source/backends/vulkan/vulkan-overview.md
new file mode 100644
index 00000000000..ede7d330e4b
--- /dev/null
+++ b/docs/source/backends/vulkan/vulkan-overview.md
@@ -0,0 +1,163 @@
+# Vulkan Backend
+
+The ExecuTorch Vulkan (ET-VK) backend enables ExecuTorch models to execute on
+GPUs via the cross-platform [Vulkan API](https://www.vulkan.org/). Although
+Vulkan API support is almost ubiquitous among modern GPUs, the ExecuTorch Vulkan
+backend is currently developed with a specific focus on **Android GPUs**.
+
+## Features
+
+- Wide operator support via an in-tree [GLSL compute shader library](https://github.com/pytorch/executorch/tree/main/backends/vulkan/runtime/graph/ops/glsl)
+- Support for models that require dynamic shapes
+- Support for FP32 and FP16 inference modes
+- Support for quantized linear layers with 8-bit/4-bit weights and 8-bit dynamically quantized activations
+- Support for quantized linear layers with 8-bit/4-bit weights and FP32/FP16 activations
+
+Note that the Vulkan backend is under active development, and its GLSL compute
+shader library is continually expanded over time. Additional support for
+quantized operators (e.g. quantized convolution) and additional quantization
+modes is on the way.
+
+## Target Requirements
+
+- The target GPU and driver must support Vulkan 1.1
+
+## Development Requirements
+
+To contribute to the Vulkan delegate, the [Vulkan SDK](https://vulkan.lunarg.com/sdk/home#android)
+must be installed on the development system. After installation, the `glslc` binary must
+be found in your `PATH` in order to compile Vulkan shaders. This can be checked by
+running
+
+```sh
+glslc --version
+```
+
+If this is not the case after completing the Vulkan SDK installation, you may have to
+go into `~/VulkanSDK/<version>/` and run
+
+```sh
+source setup-env.sh
+```
+
+or alternatively,
+
+```sh
+python install_vulkan.py
+```
+
+The [Android NDK](https://developer.android.com/ndk/downloads) must also be installed.
+Any NDK version past NDK r17c should suffice.
+
+----
+
+## Using the Vulkan Backend
+
+To lower a model to the Vulkan backend during the export and lowering process,
+pass an instance of `VulkanPartitioner` to `to_edge_transform_and_lower`. The
+example below demonstrates this process using the MobileNet V2 model from
+torchvision.
+
+```python
+import torch
+import torchvision.models as models
+
+from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
+from executorch.exir import to_edge_transform_and_lower
+
+from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
+
+mobilenet_v2 = models.mobilenetv2.mobilenet_v2(
+    weights=MobileNet_V2_Weights.DEFAULT
+).eval()
+
+sample_inputs = (torch.randn(1, 3, 224, 224),)
+
+exported_program = torch.export.export(mobilenet_v2, sample_inputs)
+
+etvk_program = to_edge_transform_and_lower(
+    exported_program,
+    partitioner=[VulkanPartitioner()],
+).to_executorch()
+
+with open("mv2_vulkan.pte", "wb") as file:
+    etvk_program.write_to_file(file)
+```
+
+See [Partitioner API](vulkan-partitioner.md)
+for a reference on available partitioner options.
+
+----
+
+## Quantization
+
+The Vulkan delegate currently supports execution of quantized linear layers.
+See [Vulkan Quantization](vulkan-quantization.md) +for more information on available quantization schemes and APIs. + +---- + +## Runtime Integration + +To run the model on-device, use the standard ExecuTorch runtime APIs. + +For integration in Android applications, the Vulkan backend is included in the +[executorch-android-vulkan](https://mvnrepository.com/artifact/org.pytorch/executorch-android-vulkan) +package. + +When building from source, pass `-DEXECUTORCH_BUILD_VULKAN=ON` when configuring +the CMake build to compile the Vulkan backend. See [Running on Device](/getting-started.md#running-on-device) +for more information. + +To link against the backend, add the `executorch_backends` CMake target as a +build dependency, or link directly against `libvulkan_backend`. Due to the use +of static initialization to register available compute shaders and operators, +it is required to ensure that the library is linked with `--whole-archive`. + +```cmake +# CMakeLists.txt +find_package(executorch CONFIG REQUIRED COMPONENTS vulkan_backend executorch_backends) + +... +target_link_libraries( + my_target + PRIVATE + executorch + executorch_backends + ... +) + +# Ensure that unused code is not discarded. The required linker options may be +# different depending on the target platform. Typically, the +# executorch_target_link_options_shared_lib function from +# executorch/tools/cmake/Utils.cmake can be used to set the required linker +# options. +target_link_options( + executorch_backends INTERFACE "SHELL:LINKER:--whole-archive \ + $ \ + LINKER:--no-whole-archive" +) +``` + +No additional steps are necessary to use the backend beyond linking the target. +Any Vulkan-delegated .pte file will automatically run on the registered backend. + +## Additional Resources + +**→{doc}`/backends/vulkan/vulkan-partitioner`** + +**→{doc}`/backends/vulkan/vulkan-quantization`** + +**→{doc}`/backends/vulkan/vulkan-troubleshooting`** + +```{toctree} +:maxdepth: 2 +:hidden: +:caption: Vulkan Backend + +vulkan-partitioner +vulkan-quantization +vulkan-op-support +vulkan-troubleshooting + +tutorials/vulkan-tutorials diff --git a/docs/source/backends/vulkan/vulkan-partitioner.md b/docs/source/backends/vulkan/vulkan-partitioner.md new file mode 100644 index 00000000000..566ec491b47 --- /dev/null +++ b/docs/source/backends/vulkan/vulkan-partitioner.md @@ -0,0 +1,55 @@ +# Partitioner API + +[VulkanPartitioner](https://github.com/pytorch/executorch/blob/main/backends/vulkan/partitioner/vulkan_partitioner.py) +is a Python class that controls what operators in a model can or should be +delegated to the Vulkan backend. It is the primary entrypoint to the Vulkan +backend and is also used to configure the behaviour of the Vulkan backend. + +## Usage + +For most use-cases, constructing `VulkanPartitioner()` with no arguments is +sufficient. In this case, the partitioner will lower as much of the model to +the Vulkan backend as possible. + +```python +etvk_program = to_edge_transform_and_lower( + exported_program, + partitioner=[VulkanPartitioner()], +).to_executorch() +``` + +## Common Config Options + +Generally, the Vulkan backend is configured by passing a `compile_options` +dictionary to `VulkanPartitioner()`, i.e. 
+
+```python
+compile_options = {
+    "require_dynamic_shapes": True,
+    "force_fp16": True,
+}
+
+etvk_program = to_edge_transform_and_lower(
+    exported_program,
+    partitioner=[VulkanPartitioner(compile_options)],
+).to_executorch()
+```
+
+### `require_dynamic_shapes`
+
+If a model is expected to use dynamic shapes, then it is recommended to set the
+`"require_dynamic_shapes"` key in `compile_options`.
+
+Not all operators in Vulkan support dynamic shapes at the moment, although the
+majority do. This flag will prevent operators that don't support dynamic shapes
+from being lowered to Vulkan.
+
+### `force_fp16`
+
+This option causes the Vulkan backend to internally convert all FP32 tensors to
+FP16. This can improve inference latency and memory footprint at the cost of
+model accuracy.
+
+FP32 input tensors will automatically be converted to FP16 upon entering the
+Vulkan backend, and FP16 outputs will automatically be converted to FP32 as
+they are returned.
diff --git a/docs/source/backends/vulkan/vulkan-quantization.md b/docs/source/backends/vulkan/vulkan-quantization.md
new file mode 100644
index 00000000000..89c9f7514b0
--- /dev/null
+++ b/docs/source/backends/vulkan/vulkan-quantization.md
@@ -0,0 +1,163 @@
+# Quantization
+
+The Vulkan backend currently supports execution of quantized linear layers,
+where weights are symmetrically quantized to 8-bit or 4-bit with per output
+channel or per group quantization scales.
+
+Support for additional quantized operators and quantization schemes (e.g. static
+and dynamic quantized convolution, statically quantized linear) is
+under active development and will be added soon.
+
+### 4-bit quantization with torchao `quantize_`
+
+The `quantize_` API from [torchao](https://github.com/pytorch/ao) allows for
+more advanced quantization schemes, and is the quantization workflow needed to
+access 4-bit quantization. 4-bit quantization is commonly used for LLMs.
+
+Two options are available to execute linear layers with 4-bit quantization:
+
+1. Dynamically quantized activations via `Int8DynamicActivationIntxWeightConfig`
+2. Weight only quantization via `IntxWeightOnlyConfig`
+
+Dynamically quantized activations can provide a significant boost in latency
+compared to weight only quantization, since it allows GPUs to leverage
+accelerated integer dot product instructions when computing matrix
+multiplication.
+
+Below is a simple example of quantizing a sequence of linear layers using
+the `quantize_` API.
+ +```python +import torch + +from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner + +from executorch.exir import to_edge_transform_and_lower +from torchao.quantization.granularity import PerGroup +from torchao.quantization.quant_api import ( + Int8DynamicActivationIntxWeightConfig, + IntxWeightOnlyConfig, + quantize_, +) +from torchao.utils import unwrap_tensor_subclass + + +class LinearSequenceModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(128, 64, bias=False) + self.linear2 = torch.nn.Linear(64, 32, bias=False) + self.linear3 = torch.nn.Linear(32, 16, bias=False) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + return x + + +linear_sequence_module = LinearSequenceModule() + +M = 32 +sample_inputs = (torch.randn(M, 128),) + +group_size = 32 + +q_config_8da4w = Int8DynamicActivationIntxWeightConfig( + weight_dtype=torch.int4, weight_granularity=PerGroup(group_size) +) + +q_config_4w = IntxWeightOnlyConfig( + weight_dtype=torch.int4, granularity=PerGroup(group_size) +) + +quantize_(linear_sequence_module, q_config_8da4w) +unwrap_tensor_subclass(linear_sequence_module) + +# Regular export path from here +exported_program = torch.export.export(linear_sequence_module, sample_inputs) + +etvk_program = to_edge_transform_and_lower( + exported_program, + partitioner=[VulkanPartitioner()], +).to_executorch() +``` + +### 8-bit quantization with PT2E quantization + +For 8-bit quantized linear layers, currently the only quantization scheme +supported is weight only quantization, with weights that are symmetrically +quantized to 8 bits with per output channel quantization scales. + +To access this quantization mode, the PT2E quantization flow must be used. At a +high level, the steps to quantize a model are: + +1) Create an instance of the `VulkanQuantizer` class and specify desired quantization behaviour +2) Use `torch.export.export` to prepare for quantization. +3) Call `prepare_pt2e` to prepare the exported graph for quantization. +4) Execute the prepared model with representative samples to calibrate the quantizated tensor activation ranges. +5) Call `convert_pt2e` to quantize the model. +6) Export and lower the model using the standard flow. 
+ +For example: + +```python +import torch + +from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner + +from executorch.backends.vulkan.quantizer.vulkan_quantizer import ( + get_symmetric_quantization_config, + VulkanQuantizer, +) + +from executorch.exir import to_edge_transform_and_lower + +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e + +from torchao.utils import unwrap_tensor_subclass + + +class LinearSequenceModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(128, 64, bias=False) + self.linear2 = torch.nn.Linear(64, 32, bias=False) + self.linear3 = torch.nn.Linear(32, 16, bias=False) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + return x + + +linear_sequence_module = LinearSequenceModule() + +M = 32 +# Create sample inputs +sample_inputs = (torch.randn(M, 128),) + +# Setup quantizer +quantizer = VulkanQuantizer() +quantizer.set_global(get_symmetric_quantization_config(is_dynamic=False, weight_bits=8)) + +# Export the model +exported_program = torch.export.export(linear_sequence_module, sample_inputs) +graph_module = exported_program.module() + +# Quantize the exported program with PT2E quantization flow +quantized_module = prepare_pt2e(graph_module, quantizer) +# Calibrate. In practice, this would be done by iterating over a real dataset +quantized_module(*sample_inputs) +quantized_module = convert_pt2e(quantized_module) + +# Export once more +exported_program = torch.export.export(quantized_module, sample_inputs) + +# Lower to vulkan +etvk_program = to_edge_transform_and_lower( + exported_program, + partitioner=[VulkanPartitioner()], +).to_executorch() +``` diff --git a/docs/source/backends/vulkan/vulkan-troubleshooting.md b/docs/source/backends/vulkan/vulkan-troubleshooting.md new file mode 100644 index 00000000000..9845f588004 --- /dev/null +++ b/docs/source/backends/vulkan/vulkan-troubleshooting.md @@ -0,0 +1,57 @@ +# Troubleshooting + +This page describes common issues that you may encounter when using the Vulkan +backend and how to debug and resolve them. + +## Vulkan Backend Not Found + +If you try to execute a .pte file that has been lowered to the Vulkan backend +and you see an error like: + +```shell +E 00:00:00.366934 executorch:method.cpp:74] Backend VulkanBackend is not registered. +``` + +This error indicates the Vulkan backend is not registered with the runtime. This +can happen because the backend was not compiled or linked, or because the +registration code was optimized out. + +First, make sure that when building ExecuTorch, cmake is configured with +`-DEXECUTORCH_BUILD_VULKAN=ON`. + +Next, make sure that your application is linking the `vulkan_backend` target, +or the `executorch_backends` target. + +Finally, ensure that `vulkan_backend` or `executorch_backends` is being linked +with the equivalent of `--whole-archive`. + +## Slow Performance + +Performance issues can be caused by a variety of factors: + +* A key compute shader (most often convolution or linear) is not performing well + on your target GPU +* Unsupported operators are causing too many graph breaks +* An existing operator is lacking support for some memory layout or storage type + resulting in a high number of copies being inserted to ensure tensors are in + a required representation for the next operator + +If you experience poor on-device performance for a particular model, please +obtain some profiling data while running your model. 
The +[profiling tutorial](./tutorials/etvk-profiling-tutorial.md) can +be a good reference for how to do this. + +Then, please file an issue on Github with the following details: + +* The device(s) you have tested with, and which devices exhibit poor performance + running the model +* The profiling data collected from executing the model +* The release version of ExecuTorch you are using, or the commit hash you built + from if you built from source +* If available, an export script that can be used to export your model to aid + in reproducing the issue +* If available, the `.pte` file you are testing with to aid in reproducing the + issue. + +We will do our best to patch performance problems in the Vulkan backend and +help you resolve your issue. diff --git a/docs/source/backends/xnnpack/op-support.csv b/docs/source/backends/xnnpack/op-support.csv new file mode 100644 index 00000000000..5350fed8d12 --- /dev/null +++ b/docs/source/backends/xnnpack/op-support.csv @@ -0,0 +1,47 @@ +Operator,Compute DType,Quantization,Constraints +_to_dim_order_copy,"fp16, fp32",,no dtype conversion +abs,"fp16, fp32",, +add,"fp16, fp32",PT2E: static int8,alpha=1 +avg_pool2d,"fp16, fp32",PT2E: static int8,"ceil_mode=False, count_include_pad=False, divisor_override=pooling_region" +bmm,"fp16, fp32",, +cat,"fp16, fp32",PT2E: static int8, +ceil,"fp16, fp32",, +clamp,"fp16, fp32",, +constant_pad_nd,"fp16, fp32",,no negative padding values +conv1d,"fp16, fp32","PT2E: static or dynamic int8 activations +8-bit weights, symmetric per-tensor or per-channel",constant weights +conv2d,"fp16, fp32","PT2E: static or dynamic int8 activations +8-bit weights, symmetric per-tensor or per-channel",constant weights +dequantize_per_tensor,"fp16, fp32",, +div,"fp16, fp32",, +elu,"fp16, fp32",, +exp,"fp16, fp32",, +floor,"fp16, fp32",, +gelu,"fp16, fp32",, +hardswish,"fp16, fp32",, +hardtanh,"fp16, fp32",, +leaky_relu,"fp16, fp32",, +linear,"fp16, fp32","PT2E: static or dynamic int8 activations +8-bit weights, symmetric per-tensor or per-channel + +quantize\_: 8-bit dynamic activations +4-bit groupwise weights",constant weights +log,"fp16, fp32",, +max_pool2d,"fp16, fp32",,"stride ≤ kernel_size, ceil_mode only for static shapes" +maximum,"fp16, fp32",, +mean,"fp16, fp32",,"4D tensors only; dims=[-2,-1] or [-1,-2]" +minimum,"fp16, fp32",, +mul,"fp16, fp32",PT2E: static int8, +neg,"fp16, fp32",, +permute_copy,"fp16, fp32",, +pow,"fp16, fp32",,power=2 only +quantize_per_tensor,"fp16, fp32",, +relu,"fp16, fp32",, +rsqrt,"fp16, fp32",, +sigmoid,"fp16, fp32",, +slice_copy,"fp16, fp32",,"no zero-dim tensors, no dynamic shapes" +softmax,"fp16, fp32",,dim must be last dimension +sqrt,"fp16, fp32",, +sub,"fp16, fp32",,alpha=1 +tanh,"fp16, fp32",, +upsample_bilinear2d,"fp16, fp32",,no dynamic output sizes diff --git a/docs/source/backends/xnnpack/xnnpack-arch-internals.md b/docs/source/backends/xnnpack/xnnpack-arch-internals.md new file mode 100644 index 00000000000..52bcd3704cb --- /dev/null +++ b/docs/source/backends/xnnpack/xnnpack-arch-internals.md @@ -0,0 +1,146 @@ +# Architecture and Internals + +This is a high-level overview of the ExecuTorch XNNPACK backend delegate. This high performance delegate is aimed to reduce CPU inference latency for ExecuTorch models. We will provide a brief introduction to the XNNPACK library and explore the delegate’s overall architecture and intended use cases. + +## What is XNNPACK? 
+XNNPACK is a library of highly-optimized neural network operators for ARM, x86, and WebAssembly architectures in Android, iOS, Windows, Linux, and macOS environments. It is an open-source project; you can find more information about it on [GitHub](https://github.com/google/XNNPACK).
+
+## What are ExecuTorch delegates?
+A delegate is an entry point for backends to process and execute parts of the ExecuTorch program. Delegated portions of ExecuTorch models hand off execution to backends. The XNNPACK backend delegate is one of many available in ExecuTorch. It leverages the XNNPACK third-party library to accelerate ExecuTorch programs efficiently across a variety of CPUs. More detailed information on the delegates and developing your own delegates is available [here](/compiler-delegate-and-partitioner.md). It is recommended that you get familiar with that content before continuing on to the Architecture section.
+
+## Architecture
+![High Level XNNPACK delegate Architecture](/backends/xnnpack/xnnpack-delegate-architecture.png)
+
+### Ahead-of-time
+In the ExecuTorch export flow, lowering to the XNNPACK delegate happens at the `to_backend()` stage. In this stage, the model is partitioned by the `XnnpackPartitioner`. Partitioned sections of the graph are converted to an XNNPACK-specific graph representation and then serialized via flatbuffer. The serialized flatbuffer is then ready to be deserialized and executed by the XNNPACK backend at runtime.
+
+![ExecuTorch XNNPACK delegate Export Flow](/backends/xnnpack/xnnpack-et-flow-diagram.png)
+
+#### Partitioner
+The partitioner is implemented by backend delegates to mark nodes suitable for lowering. The `XnnpackPartitioner` lowers using node targets and module metadata. Some more references for partitioners can be found [here](/compiler-delegate-and-partitioner.md).
+
+##### Module-based partitioning
+
+`source_fn_stack` is embedded in the node’s metadata and gives information on where these nodes come from. For example, modules like `torch.nn.Linear`, when captured and exported with `to_edge`, generate groups of nodes for their computation. The group of nodes associated with computing the linear module then has a `source_fn_stack` of `torch.nn.Linear`. Partitioning based on `source_fn_stack` allows us to identify groups of nodes which are lowerable via XNNPACK.
+
+For example, after capturing `torch.nn.Linear`, you would find the following key in the metadata for the addmm node associated with linear:
+```python
+>>> print(linear_node.meta["source_fn_stack"])
+'source_fn_stack': ('fn', <class 'torch.nn.Linear'>)
+```
+
+##### Op-based partitioning
+
+The `XnnpackPartitioner` also partitions using op targets. It traverses the graph and identifies individual nodes which are lowerable to XNNPACK. A drawback to module-based partitioning is that operators which come from [decompositions](https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py) may be skipped. For example, an operator like `torch.nn.Hardsigmoid` is decomposed into adds, muls, divs, and clamps. While hardsigmoid is not lowerable, we can lower the decomposed ops. Relying on `source_fn_stack` metadata alone would skip these lowerable operators because they belong to a non-lowerable module, so to improve model performance, we greedily lower operators based on op targets as well as `source_fn_stack`.
+
+##### Passes
+
+Before any serialization, we apply passes on the subgraphs to prepare the graph. These passes are essentially graph transformations that help improve the performance of the delegate.
We give an overview of the most significant passes and their function below. For a description of all passes see [here](https://github.com/pytorch/executorch/tree/main/backends/xnnpack/_passes):
+
+* Channels Last Reshape
+  * ExecuTorch tensors tend to be contiguous before passing them into delegates, while XNNPACK only accepts channels-last memory layout. This pass minimizes the number of permutation operators inserted to pass tensors in channels-last memory format.
+* Conv1d to Conv2d
+  * Allows us to delegate Conv1d nodes by transforming them to Conv2d.
+* Conv and BN Fusion
+  * Fuses batch norm operations with the previous convolution node.
+
+#### Serialization
+After partitioning the lowerable subgraphs from the model, the XNNPACK delegate pre-processes these subgraphs and serializes them via flatbuffer for the XNNPACK backend.
+
+##### Serialization Schema
+
+The XNNPACK delegate uses flatbuffer for serialization. In order to improve runtime performance, the XNNPACK delegate’s flatbuffer [schema](https://github.com/pytorch/executorch/blob/main/backends/xnnpack/serialization/schema.fbs) mirrors the XNNPACK Library’s graph-level API calls. The serialized data are arguments to XNNPACK’s APIs, so that at runtime, the XNNPACK execution graph can efficiently be created with successive calls to XNNPACK’s APIs.
+
+### Runtime
+The XNNPACK backend’s runtime interfaces with the ExecuTorch runtime through the custom `init` and `execute` functions. Each delegated subgraph is contained in an individually serialized XNNPACK blob. When the model is initialized, ExecuTorch calls `init` on all XNNPACK blobs to load the subgraph from the serialized flatbuffer. Afterwards, when the model is executed, each subgraph is executed via the backend through the custom `execute` function. To read more about how delegate runtimes interface with ExecuTorch, refer to this [resource](/compiler-delegate-and-partitioner.md).
+
+#### **XNNPACK Library**
+The XNNPACK delegate supports CPUs on multiple platforms; more information on the supported hardware architectures can be found on the XNNPACK Library’s [README](https://github.com/google/XNNPACK).
+
+#### **Init**
+When calling the XNNPACK delegate’s `init`, we deserialize the preprocessed blobs via flatbuffer. We define the nodes (operators) and edges (intermediate tensors) to build the XNNPACK execution graph using the information we serialized ahead-of-time. As mentioned earlier, the majority of processing has been done ahead-of-time, so that at runtime we can just call the XNNPACK APIs with the serialized arguments in succession. As static data is added to the execution graph, XNNPACK performs weight packing at runtime to prepare static data like weights and biases for efficient execution. After creating the execution graph, we create the runtime object and pass it on to `execute`.
+
+Since weight packing creates an extra copy of the weights inside XNNPACK, we free the original copy of the weights inside the preprocessed XNNPACK blob, which removes some of the memory overhead.
+
+#### **Execute**
+When executing the XNNPACK subgraphs, we prepare the tensor inputs and outputs and feed them to the XNNPACK runtime graph. After executing the runtime graph, the output pointers are filled with the computed tensors.
+
+#### **Profiling**
+Basic profiling for the XNNPACK delegate can be enabled with the compiler flag `-DEXECUTORCH_ENABLE_EVENT_TRACER` (add `-DENABLE_XNNPACK_PROFILING` for additional details).
With ExecuTorch's Developer Tools integration, you can also use the Developer Tools to profile the model. You can follow the steps in [Using the ExecuTorch Developer Tools to Profile a Model](/tutorials/devtools-integration-tutorial) on how to profile ExecuTorch models and use the Developer Tools' Inspector API to view XNNPACK's internal profiling information. An example implementation is available in the `executor_runner` (see [tutorial here](/tutorial-xnnpack-delegate-lowering.md#profiling)).
+
+
+[comment]: <> (TODO: Refactor quantizer to a more official quantization doc)
+## Quantization
+The XNNPACK delegate can also be used as a backend to execute symmetrically quantized models. For quantized model delegation, we quantize models using the `XNNPACKQuantizer`. `Quantizers` are backend specific, which means the `XNNPACKQuantizer` is configured to quantize models to leverage the quantized operators offered by the XNNPACK Library. We will not go over the details of how to implement your custom quantizer here; you can follow the docs [here](https://pytorch.org/tutorials/prototype/pt2e_quantizer.html) to do so. However, we will provide a brief overview of how to quantize the model to leverage quantized execution of the XNNPACK delegate.
+
+### Configuring the XNNPACKQuantizer
+
+```python
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
+    XNNPACKQuantizer,
+    get_symmetric_quantization_config,
+)
+quantizer = XNNPACKQuantizer()
+quantizer.set_global(get_symmetric_quantization_config())
+```
+Here we initialize the `XNNPACKQuantizer` and set the quantization config to be symmetrically quantized. Symmetric quantization is when weights are symmetrically quantized with `qmin = -127` and `qmax = 127`, which forces the quantization zero points to be zero. `get_symmetric_quantization_config()` can be configured with the following arguments:
+* `is_per_channel`
+  * Weights are quantized across channels
+* `is_qat`
+  * Quantization-aware training
+* `is_dynamic`
+  * Dynamic quantization
+
+We can then configure the `XNNPACKQuantizer` as we wish. We set the following configs below as an example:
+```python
+(
+    quantizer.set_global(quantization_config)
+    .set_object_type(torch.nn.Conv2d, quantization_config)  # can configure by module type
+    .set_object_type(torch.nn.functional.linear, quantization_config)  # or torch functional op type
+    .set_module_name("foo.bar", quantization_config)  # or by module fully qualified name
+)
+```
+
+### Quantizing your model with the XNNPACKQuantizer
+After configuring our quantizer, we are now ready to quantize our model:
+```python
+from torch.export import export
+
+exported_model = export(model_to_quantize, example_inputs).module()
+prepared_model = prepare_pt2e(exported_model, quantizer)
+print(prepared_model.graph)
+```
+Prepare performs some Conv2d-BN fusion, and inserts quantization observers in the appropriate places. For Post-Training Quantization, we generally calibrate our model after this step. We run sample examples through the `prepared_model` to observe the statistics of the tensors and calculate the quantization parameters.
+
+Finally, we convert our model here:
+```python
+quantized_model = convert_pt2e(prepared_model)
+print(quantized_model)
+```
+You will now see the Q/DQ representation of the model, which means `torch.ops.quantized_decomposed.dequantize_per_tensor` are inserted at quantized operator inputs and `torch.ops.quantized_decomposed.quantize_per_tensor` are inserted at operator outputs.
Example:
+
+```python
+def _qdq_quantized_linear(
+    x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
+    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
+    bias_fp32,
+    out_scale, out_zero_point, out_quant_min, out_quant_max
+):
+    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
+    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
+    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
+    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
+        out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8)
+    return out_i8
+```
+
+
+You can read more in-depth explanations of PyTorch 2 quantization [here](https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html).
+
+## See Also
+- [Integrating XNNPACK Delegate in Android AAR](/using-executorch-android.md)
+- [Complete the Lowering to XNNPACK Tutorial](/tutorial-xnnpack-delegate-lowering.md)
diff --git a/docs/source/xnnpack-delegate-architecture.png b/docs/source/backends/xnnpack/xnnpack-delegate-architecture.png
similarity index 100%
rename from docs/source/xnnpack-delegate-architecture.png
rename to docs/source/backends/xnnpack/xnnpack-delegate-architecture.png
diff --git a/docs/source/xnnpack-et-flow-diagram.png b/docs/source/backends/xnnpack/xnnpack-et-flow-diagram.png
similarity index 100%
rename from docs/source/xnnpack-et-flow-diagram.png
rename to docs/source/backends/xnnpack/xnnpack-et-flow-diagram.png
diff --git a/docs/source/backends/xnnpack/xnnpack-overview.md b/docs/source/backends/xnnpack/xnnpack-overview.md
new file mode 100644
index 00000000000..5ef92c81126
--- /dev/null
+++ b/docs/source/backends/xnnpack/xnnpack-overview.md
@@ -0,0 +1,100 @@
+# XNNPACK Backend
+
+The XNNPACK delegate is the ExecuTorch solution for CPU execution on mobile devices. [XNNPACK](https://github.com/google/XNNPACK/tree/master) is a library that provides optimized kernels for machine learning operators on Arm and x86 CPUs.
+
+## Features
+
+- Wide operator support on Arm and x86 CPUs, available on any modern mobile phone.
+- Support for a wide variety of quantization schemes and quantized operators.
+- Supports fp32 and fp16 activations.
+- Supports 8-bit quantization.
+
+## Target Requirements
+
+- ARM64 on Android, iOS, macOS, Linux, and Windows.
+- ARMv7 (with NEON) on Android.
+- ARMv6 (with VFPv2) on Linux.
+- x86 and x86-64 (up to AVX512) on Windows, Linux, Android.
+
+## Development Requirements
+
+The XNNPACK delegate does not introduce any development system requirements beyond those required by
+the core ExecuTorch runtime.
+
+----
+
+## Using the XNNPACK Backend
+
+To target the XNNPACK backend during the export and lowering process, pass an instance of the `XnnpackPartitioner` to `to_edge_transform_and_lower`. The example below demonstrates this process using the MobileNet V2 model from torchvision.
+
+```python
+import torch
+import torchvision.models as models
+from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.exir import to_edge_transform_and_lower
+
+mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
+sample_inputs = (torch.randn(1, 3, 224, 224), )
+
+et_program = to_edge_transform_and_lower(
+    torch.export.export(mobilenet_v2, sample_inputs),
+    partitioner=[XnnpackPartitioner()],
+).to_executorch()
+
+with open("mv2_xnnpack.pte", "wb") as file:
+    et_program.write_to_file(file)
+```
+
+See [Partitioner API](/backends/xnnpack/xnnpack-partitioner) for a reference on available partitioner options.
+
+----
+
+## Quantization
+
+The XNNPACK delegate can also be used as a backend to execute symmetrically quantized models. See [XNNPACK Quantization](/backends/xnnpack/xnnpack-quantization) for more information on available quantization schemes and APIs.
+
+----
+
+## Runtime Integration
+
+To run the model on-device, use the standard ExecuTorch runtime APIs.
+
+The XNNPACK delegate is included by default in the published Android, iOS, and pip packages. When building from source, pass `-DEXECUTORCH_BUILD_XNNPACK=ON` when configuring the CMake build to compile the XNNPACK backend. See [Running on Device](/getting-started.md#running-on-device) for more information.
+
+To link against the backend, add the `executorch_backends` CMake target as a build dependency, or link directly against `libxnnpack_backend`. Due to the use of static registration, it may be necessary to link with whole-archive. This can typically be done by passing `"$<LINK_LIBRARY:WHOLE_ARCHIVE,xnnpack_backend>"` to `target_link_libraries`.
+
+```
+# CMakeLists.txt
+add_subdirectory("executorch")
+...
+target_link_libraries(
+    my_target
+    PRIVATE executorch
+    executorch_backends
+    ...
+)
+```
+
+No additional steps are necessary to use the backend beyond linking the target. Any XNNPACK-delegated .pte file will automatically run on the registered backend.
+
+## Reference
+
+**→{doc}`/backends/xnnpack/xnnpack-troubleshooting` — Debug common issues.**
+
+**→{doc}`/backends/xnnpack/xnnpack-partitioner` — Partitioner options and supported operators.**
+
+**→{doc}`/backends/xnnpack/xnnpack-quantization` — Supported quantization schemes.**
+
+**→{doc}`/backends/xnnpack/xnnpack-arch-internals` — XNNPACK backend internals.**
+
+```{toctree}
+:maxdepth: 2
+:hidden:
+:caption: XNNPACK Backend
+
+xnnpack-partitioner
+xnnpack-quantization
+xnnpack-troubleshooting
+xnnpack-arch-internals
+```
diff --git a/docs/source/backends/xnnpack/xnnpack-partitioner.rst b/docs/source/backends/xnnpack/xnnpack-partitioner.rst
new file mode 100644
index 00000000000..a0881aa3a6a
--- /dev/null
+++ b/docs/source/backends/xnnpack/xnnpack-partitioner.rst
@@ -0,0 +1,24 @@
+===============
+Partitioner API
+===============
+
+The XNNPACK partitioner API allows for configuration of the model delegation to XNNPACK. Passing an ``XnnpackPartitioner`` instance with no additional parameters will run as much of the model as possible on the XNNPACK backend. This is the most common use-case. For advanced use cases, the partitioner exposes the following options via the `constructor `_:
+
+- ``configs``: Control which operators are delegated to XNNPACK. By default, all available operators are delegated. See `../config/__init__.py `_ for an exhaustive list of available operator configs.
+- ``config_precisions``: Filter operators by data type.
By default, all precisions are delegated. One or more of ``ConfigPrecisionType.FP32``, ``ConfigPrecisionType.STATIC_QUANT``, or ``ConfigPrecisionType.DYNAMIC_QUANT``. See `ConfigPrecisionType `_.
+- ``per_op_mode``: If true, emit individual delegate calls for every operator. This is an advanced option intended to reduce memory overhead in some contexts at the cost of a small amount of runtime overhead. Defaults to false.
+- ``verbose``: If true, print additional information during lowering.
+
+================
+Operator Support
+================
+
+This section lists the operators supported by the XNNPACK backend. Operators are the building blocks of the ML model. See `IRs `_ for more information on the PyTorch operator set.
+
+All operators support dynamic input shapes unless otherwise noted.
+
+.. csv-table:: Operator Support
+   :file: op-support.csv
+   :header-rows: 1
+   :widths: 20 15 30 30
+   :align: center
diff --git a/docs/source/backends/xnnpack/xnnpack-quantization.md b/docs/source/backends/xnnpack/xnnpack-quantization.md
new file mode 100644
index 00000000000..e0180393f9e
--- /dev/null
+++ b/docs/source/backends/xnnpack/xnnpack-quantization.md
@@ -0,0 +1,94 @@
+# Quantization
+
+The XNNPACK delegate can also be used as a backend to execute symmetrically quantized models. To quantize a PyTorch model for the XNNPACK backend, use the `XNNPACKQuantizer`. `Quantizers` are backend specific, which means the `XNNPACKQuantizer` is configured to quantize models to leverage the quantized operators offered by the XNNPACK Library.
+
+### Supported Quantization Schemes
+The XNNPACK delegate supports the following quantization schemes:
+
+- 8-bit symmetric weights with 8-bit asymmetric activations (via the PT2E quantization flow).
+  - Supports both static and dynamic activations.
+  - Supports per-channel and per-tensor schemes.
+  - Supports linear, convolution, add, mul, cat, and adaptive avg pool 2d operators.
+
+Weight-only quantization is not currently supported on XNNPACK.
+
+### 8-bit Quantization using the PT2E Flow
+
+To perform 8-bit quantization with the PT2E flow, perform the following steps prior to exporting the model:
+
+1) Create an instance of the `XNNPACKQuantizer` class. Set quantization parameters.
+2) Use `torch.export.export` to prepare for quantization.
+3) Call `prepare_pt2e` to prepare the model for quantization.
+4) For static quantization, run the prepared model with representative samples to calibrate the quantized tensor activation ranges.
+5) Call `convert_pt2e` to quantize the model.
+6) Export and lower the model using the standard flow.
+
+The output of `convert_pt2e` is a PyTorch model which can be exported and lowered using the normal flow. As it is a regular PyTorch model, it can also be used to evaluate the accuracy of the quantized model using standard PyTorch techniques.
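+
+As a small illustration, here is a minimal accuracy-check sketch. It is not part of the ExecuTorch APIs; it assumes a hypothetical `eval_dataloader` yielding `(input, label)` batches whose shapes match those used for export, so substitute your own validation data and metric.
+
+```python
+import torch
+
+def evaluate_top1(model, eval_dataloader):
+    # Count top-1 matches over the evaluation set.
+    correct, total = 0, 0
+    with torch.no_grad():
+        for inputs, labels in eval_dataloader:
+            preds = model(inputs).argmax(dim=-1)
+            correct += (preds == labels).sum().item()
+            total += labels.numel()
+    return correct / max(total, 1)
+
+# `quantized_model` is the output of convert_pt2e (see the full example below); comparing its
+# score against the original FP32 model gives a quick estimate of the accuracy impact.
+# evaluate_top1(quantized_model, eval_dataloader)
+```
+
+The full flow, from configuring the quantizer through lowering the converted model, is shown in the example below.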
+
+```python
+import torch
+import torchvision.models as models
+from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import XNNPACKQuantizer, get_symmetric_quantization_config
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.exir import to_edge_transform_and_lower
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+model = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
+sample_inputs = (torch.randn(1, 3, 224, 224), )
+
+qparams = get_symmetric_quantization_config(is_per_channel=True)  # (1)
+quantizer = XNNPACKQuantizer()
+quantizer.set_global(qparams)
+
+training_ep = torch.export.export(model, sample_inputs).module()  # (2)
+prepared_model = prepare_pt2e(training_ep, quantizer)  # (3)
+
+for cal_sample in [torch.randn(1, 3, 224, 224)]:  # Replace with representative model inputs
+    prepared_model(cal_sample)  # (4) Calibrate
+
+quantized_model = convert_pt2e(prepared_model)  # (5)
+
+et_program = to_edge_transform_and_lower(  # (6)
+    torch.export.export(quantized_model, sample_inputs),
+    partitioner=[XnnpackPartitioner()],
+).to_executorch()
+```
+
+See [PyTorch 2 Export Post Training Quantization](https://docs.pytorch.org/ao/main/tutorials_source/pt2e_quant_ptq.html) for more information.
+
+### LLM quantization with quantize_
+
+The XNNPACK backend also supports quantizing models with the [torchao](https://github.com/pytorch/ao) `quantize_` API. This is most commonly used for LLMs, which require more advanced quantization. Since `quantize_` is not backend aware, it is important to use a config that is compatible with CPU/XNNPACK:
+
+* To quantize embeddings, use `IntxWeightOnlyConfig` (with weight_dtype torch.int2, torch.int4, or torch.int8, using PerGroup or PerAxis granularity)
+* To quantize linear layers with 4-bit weights and 8-bit dynamic activations, use `Int8DynamicActivationIntxWeightConfig` (with weight_dtype=torch.int4, using PerGroup or PerAxis granularity)
+
+Below is a simple example, but a more detailed tutorial including accuracy evaluation on popular LLM benchmarks can be found in the [torchao documentation](https://docs.pytorch.org/ao/main/serving.html#mobile-deployment-with-executorch).
+
+```python
+import torch
+from torchao.quantization.granularity import PerGroup, PerAxis
+from torchao.quantization.quant_api import (
+    IntxWeightOnlyConfig,
+    Int8DynamicActivationIntxWeightConfig,
+    quantize_,
+)
+
+# Quantize embeddings with 8 bits, per channel; eager_model is your PyTorch model
+embedding_config = IntxWeightOnlyConfig(
+    weight_dtype=torch.int8,
+    granularity=PerAxis(0),
+)
+quantize_(
+    eager_model,
+    embedding_config,
+    lambda m, fqn: isinstance(m, torch.nn.Embedding),
+)
+
+
+# Quantize linear layers with 8-bit dynamic activations and 4-bit weights
+linear_config = Int8DynamicActivationIntxWeightConfig(
+    weight_dtype=torch.int4,
+    weight_granularity=PerGroup(32),
+)
+quantize_(eager_model, linear_config)
+```
diff --git a/docs/source/backends/xnnpack/xnnpack-troubleshooting.md b/docs/source/backends/xnnpack/xnnpack-troubleshooting.md
new file mode 100644
index 00000000000..2716937e893
--- /dev/null
+++ b/docs/source/backends/xnnpack/xnnpack-troubleshooting.md
@@ -0,0 +1,49 @@
+# Troubleshooting
+
+This page describes common issues that you may encounter when using the XNNPACK backend and how to debug and resolve them.
+
+## XNNPACK Backend Not Found
+
+This error indicates the XNNPACK backend is not registered with the runtime.
This can happen because the backend was not compiled or linked, or because the registration code was optimized out.
+
+The XNNPACK backend is built by default for Python, Android, iOS, and in most CMake presets.
+
+* Set the `EXECUTORCH_BUILD_XNNPACK=ON` CMake option when building from source.
+  * Either by passing the option during CMake configuration or setting it inside the user CMake logic before including ExecuTorch.
+  * See [Building from Source](/using-executorch-building-from-source).
+* On iOS, link the `backend_xnnpack` [framework](/using-executorch-ios).
+* If the backend is still not found, link with `WHOLE_ARCHIVE`.
+  * Pass `"$<LINK_LIBRARY:WHOLE_ARCHIVE,xnnpack_backend>"` to `target_link_libraries` in CMake.
+
+## Slow Performance
+
+* Try reducing the thread count using [_unsafe_reset_threadpool](/using-executorch-faqs.md#inference-is-slow-performance-troubleshooting).
+  * Small models may benefit from using fewer threads than default.
+  * Try values between 1 and 4 threads and measure performance on your model.
+* Use [op-level profiling](/tutorials/devtools-integration-tutorial) to understand which operators are taking the most time.
+  * The XNNPACK backend provides operator-level timing for delegated operators.
+* See general performance troubleshooting tips in [Performance Troubleshooting](/using-executorch-faqs.md#inference-is-slow-performance-troubleshooting).
+
+
+## Debugging Why Nodes Are Not Partitioned
+
+* To debug cases where operators are not delegated to XNNPACK, you can enable the partitioner's internal debug logging before calling `to_edge_transform_and_lower`. This will print diagnostic messages explaining why specific nodes failed to partition.
+
+```python
+
+# caption: Enable internal partition debug logging
+
+import logging
+
+logger = logging.getLogger("executorch.backends.xnnpack.partition")
+logger.setLevel(logging.DEBUG)
+
+if not logger.handlers:
+    ch = logging.StreamHandler()
+    ch.setLevel(logging.DEBUG)
+    formatter = logging.Formatter("[%(levelname)s] %(name)s: %(message)s")
+    ch.setFormatter(formatter)
+    logger.addHandler(ch)
+```
\ No newline at end of file
diff --git a/docs/source/build-run-openvino.md b/docs/source/build-run-openvino.md
index dc6f098850f..9b4c48fee5a 100644
--- a/docs/source/build-run-openvino.md
+++ b/docs/source/build-run-openvino.md
@@ -61,7 +61,7 @@ For more information about OpenVINO build, refer to the [OpenVINO Build Instruct
Follow the steps below to setup your build environment:
-1. **Setup ExecuTorch Environment**: Refer to the [Environment Setup](getting-started-setup.md#environment-setup) guide for detailed instructions on setting up the ExecuTorch environment.
+1. **Setup ExecuTorch Environment**: Refer to the [Environment Setup](using-executorch-building-from-source.md#environment-setup) guide for detailed instructions on setting up the ExecuTorch environment.
2. **Setup OpenVINO Backend Environment**
- Install the dependent libs. Ensure that you are inside `executorch/backends/openvino/` directory
@@ -92,7 +92,7 @@ The exported model will be saved as 'resnet50.pte' in the current directory.
### Build C++ OpenVINO Examples
-After building the OpenVINO backend following the [instructions](#setup) above, the executable will be saved in `/cmake-out/backends/openvino/`.
+After building the OpenVINO backend following the [instructions](#setup) above, the executable will be saved in `/cmake-out/`.
The executable requires a model file (`.pte` file generated in the aot step) and the number of inference executions.
@@ -101,7 +101,7 @@ The executable requires a model file (`.pte` file generated in the aot step) and Run inference with a given model for 10 executions: ``` -./openvino_executor_runner \ +./executor_runner \ --model_path=model.pte \ --num_executions=10 ``` diff --git a/docs/source/bundled-io.md b/docs/source/bundled-io.md index 79897737268..d901710bfb7 100644 --- a/docs/source/bundled-io.md +++ b/docs/source/bundled-io.md @@ -17,7 +17,7 @@ This stage mainly focuses on the creation of a `BundledProgram` and dumping it o ### Step 1: Create a Model and Emit its ExecuTorch Program. -ExecuTorch Program can be emitted from user's model by using ExecuTorch APIs. Follow the [Generate and emit sample ExecuTorch program](getting-started.md#exporting) or [Exporting to ExecuTorch tutorial](https://pytorch.org/executorch/main/tutorials/export-to-executorch-tutorial). +ExecuTorch Program can be emitted from user's model by using ExecuTorch APIs. Follow the [Generate and emit sample ExecuTorch program](getting-started.md#exporting) or [Exporting to ExecuTorch tutorial](tutorials/export-to-executorch-tutorial) . ### Step 2: Construct `List[MethodTestSuite]` to hold test info @@ -194,18 +194,18 @@ regenerate_bundled_program = deserialize_from_flatbuffer_to_bundled_program(seri ``` ## Runtime Stage -This stage mainly focuses on executing the model with the bundled inputs and and comparing the model's output with the bundled expected output. We provide multiple APIs to handle the key parts of it. +This stage mainly focuses on executing the model with the bundled inputs and comparing the model's output with the bundled expected output. We provide multiple APIs to handle the key parts of it. ### Get ExecuTorch Program Pointer from `BundledProgram` Buffer We need the pointer to ExecuTorch program to do the execution. To unify the process of loading and executing `BundledProgram` and Program flatbuffer, we create an API for this -`executorch::bundled_program::get_program_data`. Check out an [example usage](https://github.com/pytorch/executorch/blob/release/0.6/examples/devtools/example_runner/example_runner.cpp#L128-L137) of this API. +`executorch::bundled_program::get_program_data`. Check out an [example usage](https://github.com/pytorch/executorch/blob/release/1.0/examples/devtools/example_runner/example_runner.cpp#L128-L137) of this API. ### Load Bundled Input to Method -To execute the program on the bundled input, we need to load the bundled input into the method. Here we provided an API called `executorch::bundled_program::load_bundled_input`. Check out an [example usage](https://github.com/pytorch/executorch/blob/release/0.6/examples/devtools/example_runner/example_runner.cpp#L253-L259) of this API. +To execute the program on the bundled input, we need to load the bundled input into the method. Here we provided an API called `executorch::bundled_program::load_bundled_input`. Check out an [example usage](https://github.com/pytorch/executorch/blob/release/1.0/examples/devtools/example_runner/example_runner.cpp#L253-L259) of this API. ### Verify the Method's Output. -We call `executorch::bundled_program::verify_method_outputs` to verify the method's output with bundled expected outputs. Check out an [example usage](https://github.com/pytorch/executorch/blob/release/0.6/examples/devtools/example_runner/example_runner.cpp#L300-L311) of this API. +We call `executorch::bundled_program::verify_method_outputs` to verify the method's output with bundled expected outputs. 
Check out an [example usage](https://github.com/pytorch/executorch/blob/release/1.0/examples/devtools/example_runner/example_runner.cpp#L301-L307) of this API. ### Runtime Example @@ -213,13 +213,29 @@ Please checkout our [example runner](https://github.com/pytorch/executorch/blob/ ```bash cd executorch - ./examples/devtools/build_example_runner.sh - ./cmake-out/examples/devtools/example_runner --bundled_program_path {your-bpte-file} --output_verification +./examples/devtools/build_example_runner.sh +./cmake-out/examples/devtools/example_runner --bundled_program_path {your-bpte-file} --output_verification ``` It is expected to see no output from running the above mentioned snippet. +For a detailed example of how the runner should be like, please refer to our [example runner](https://github.com/pytorch/executorch/blob/release/1.0/examples/devtools/example_runner/example_runner.cpp). -For a detailed example of how the runner should be like, please refer to our [example runner](https://github.com/pytorch/executorch/blob/release/0.6/examples/devtools/example_runner/example_runner.cpp). + +### Try the Complete Workflow + +To test the entire end-to-end workflow including building the example runner, exporting a model, and verifying the bundled program execution, you can use the test script: + +```bash +cd executorch +./examples/devtools/test_example_runner.sh +``` + +This script will: +1. Build the example runner using `build_example_runner.sh` +2. Export a MobileNetV2 model as a bundled program +3. Run the example runner with the bundled program to verify correctness + +This is a great way to ensure your environment is set up correctly and to see the complete BundledProgram workflow in action. ## Common Errors diff --git a/docs/source/compiler-delegate-and-partitioner.md b/docs/source/compiler-delegate-and-partitioner.md index c633bb1fd12..c0449e7366b 100644 --- a/docs/source/compiler-delegate-and-partitioner.md +++ b/docs/source/compiler-delegate-and-partitioner.md @@ -1,4 +1,4 @@ -# Backends and Delegates +# Understanding Backends and Delegates Audience: Vendors, Backend Delegate developers, who are interested in integrating their own compilers and hardware as part of ExecuTorch @@ -37,7 +37,7 @@ The diagram looks like following There are mainly two Ahead-of-Time entry point for backend to implement: `partition` and `preprocess`. `partitioner` is an algorithm implemented by the backend to tag the nodes to be lowered to the backend. `to_backend` API will apply the partition algorithm and lower each subgraph, which consists of connected tagged nodes, to the targeted backend. Every subgraph -will be sent to the `preprocess` part provided by the backend to compiled as a binary blob. +will be sent to the `preprocess` part provided by the backend to be compiled as a binary blob. During partition, the `exported_program` is not allowed to mutate the program, and it's supposed to apply tag to each node. The `PartitionResult` includes both tagged exported program and the partition tags dictionary for `to_backend` to look up the tag and @@ -131,7 +131,7 @@ static auto success_with_compiler = register_backend(backend); Providing consistent debugging experience, be it for runtime failures or performance profiling, is important. ExecuTorch employs native Developer Tools for this purpose, which enables correlating program instructions to original PyTorch code, via debug handles. You can read more about it [here](etrecord.rst). 
-Delegated program or subgraphs are opaque to ExecuTorch runtime and appear as a special `call_delegate` instruction, which asks corresponding backend to handle the execution of the subgraph or program. Due to the opaque nature of backend delgates, native Developer Tools does not have visibility into delegated program. Thus the debugging, functional or performance, experiences of delegated execution suffers significantly as compared to it's non-delegated counterpart. +Delegated program or subgraphs are opaque to ExecuTorch runtime and appear as a special `call_delegate` instruction, which asks corresponding backend to handle the execution of the subgraph or program. Due to the opaque nature of backend delegates, native Developer Tools does not have visibility into delegated program. Thus the debugging, functional or performance, experiences of delegated execution suffers significantly as compared to it's non-delegated counterpart. In order to provide consistent debugging experience to users, regardless of the use of delegation for a model, Developer Tools provide an interface to correlate delegated (sub)graph to original (sub)graph. The Developer Tools do so via debug handles map which allows delegates to generate internal handles that can be associated with the original (sub)graph consumed by the delegate. Then at runtime, backend developer can report error or profiling information using the internal handle, which will be mapped to original (sub)graph using the debug handle map. For more information, please refer to [Delegate Debugging](delegate-debugging.md). @@ -194,8 +194,8 @@ qnnpack is one backend and xnnpack is another backend. We haven't open-sourced these two backends delegates yet, and this example won't run out of box. It can be used as a reference to see how it can be done. -This option is easy to try becuase usually all backends will implement their own -parititioner. However this option may get different results if we change the +This option is easy to try because usually all backends will implement their own +partitioner. However this option may get different results if we change the order of to_backend call. If we want to have a better control on the nodes, like which backend they should go, option 2 is better. diff --git a/docs/source/compiler-entry-points.md b/docs/source/compiler-entry-points.md new file mode 100644 index 00000000000..ac5623c6769 --- /dev/null +++ b/docs/source/compiler-entry-points.md @@ -0,0 +1,9 @@ +# Compiler Entry Points + +```{toctree} +:maxdepth: 1 + +compiler-backend-dialect +compiler-custom-compiler-passes +compiler-memory-planning +``` diff --git a/docs/source/compiler-ir-advanced.md b/docs/source/compiler-ir-advanced.md new file mode 100644 index 00000000000..b6d24026d5a --- /dev/null +++ b/docs/source/compiler-ir-advanced.md @@ -0,0 +1,31 @@ +(compiler-ir-advanced)= +# Compiler & IR + +Advanced compiler features and intermediate representation specifications. 
+ +## Compiler Passes + +- {doc}`compiler-custom-compiler-passes` — Custom compiler passes and optimization + +## Memory Management + +- {doc}`compiler-memory-planning` — Advanced memory planning strategies + +## Intermediate Representation + +- {doc}`ir-exir` — EXIR (Export Intermediate Representation) specification +- {doc}`ir-ops-set-definition` — Ops set definition and operator standardization + +## Backend dialect + +- {doc}`compiler-backend-dialect` — Backend dialect and compiler integration + +```{toctree} +:hidden: +:maxdepth: 1 + +compiler-custom-compiler-passes +compiler-memory-planning +ir-exir +ir-ops-set-definition +compiler-backend-dialect diff --git a/docs/source/conf.py b/docs/source/conf.py index 65845c03868..f69fc243255 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -24,7 +24,7 @@ import sys from typing import Any -import pytorch_sphinx_theme +import pytorch_sphinx_theme2 # type: ignore[import-not-found] # To let us import ./custom_directives.py sys.path.insert(0, os.path.abspath(".")) @@ -63,13 +63,10 @@ "sphinx_design", "sphinx_gallery.gen_gallery", "sphinx_reredirects", + "sphinx_sitemap", + "sphinxcontrib.mermaid", ] -if not FBCODE: - extensions += [ - "executorch_custom_versions", - ] - this_file_dir = os.path.abspath(os.path.dirname(__file__)) doxygen_xml_dir = os.path.join( os.path.dirname(this_file_dir), # {repo_root}/docs/ @@ -77,36 +74,60 @@ "xml", # {repo_root}/docs/cpp/build/xml ) -html_favicon = "_static/img/ExecuTorch-Logo-cropped.svg" - -# Get ET_VERSION_DOCS during the build. -et_version_docs = os.environ.get("ET_VERSION_DOCS", None) -print(f"et_version_docs: {et_version_docs}") - -# The code below will cut version displayed in the dropdown like this: -# By default, set to "main". -# If it's a tag like refs/tags/v1.2.3-rc4 or refs/tags/v1.2.3, then -# cut to 1.2 -# the version varible is used in layout.html: https://github.com/pytorch/executorch/blob/main/docs/source/_templates/layout.html#L29 -version = release = "main" -if et_version_docs: - if et_version_docs.startswith("refs/tags/v"): - version = ".".join( - et_version_docs.split("/")[-1].split("-")[0].lstrip("v").split(".")[:2] - ) - elif et_version_docs.startswith("refs/heads/release/"): - version = et_version_docs.split("/")[-1] -print(f"Version: {version}") -html_title = " ".join((project, version, "documentation")) +html_favicon = "_static/img/executorch-chip-logo.svg" + +# Import executorch version +# Adopted from PyTorch docs pattern +from executorch import version as et_version # type: ignore[attr-defined] + +executorch_version = str(et_version.__version__) + +# Check if this is a release build from environment variable +# The workflow sets RELEASE=true for tagged releases, RELEASE=false otherwise +# We need to properly parse the string as a boolean (any non-empty string is truthy in Python) +RELEASE = os.environ.get("RELEASE", "false").lower() == "true" + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +version = "main" +# The full version, including alpha/beta/rc tags. +release = "main" + +# Customized html_title here. +# Default is " ".join(project, release, "documentation") if not set +if RELEASE: + # Turn 0.8.0a0+a90e907 into 0.8 + # Note: the release candidates should no longer have the aHASH suffix, but in any + # case we wish to leave only major.minor, even for rc builds. 
+ version = ".".join(executorch_version.split("+")[0].split(".")[:2]) + html_title = " ".join((project, version, "documentation")) + release = version + +switcher_version = "main" if not RELEASE else version + +print(f"executorch_version: {executorch_version}") +print(f"Version: {version}, RELEASE: {RELEASE}") + +html_baseurl = "https://docs.pytorch.org/executorch/" # needed for sphinx-sitemap +sitemap_locales = [None] +sitemap_excludes = [ + "search.html", + "genindex.html", +] +sitemap_url_scheme = "{link}" breathe_projects = {"ExecuTorch": "../build/xml/"} breathe_default_project = "ExecuTorch" -templates_path = ["_templates"] autodoc_typehints = "description" myst_enable_extensions = [ "colon_fence", + "deflist", + "html_image", ] myst_heading_anchors = 4 @@ -162,38 +183,93 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = "pytorch_sphinx_theme" -html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] +html_theme = "pytorch_sphinx_theme2" +html_theme_path = [pytorch_sphinx_theme2.get_html_theme_path()] # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. # + html_theme_options = { + "logo": { + "image_light": "_static/img/et-logo.png", + "image_dark": "_static/img/et-logo.png", + }, + "navigation_with_keys": False, + "canonical_url": "https://docs.pytorch.org/executorch/stable/", + "switcher": { + "json_url": "https://docs.pytorch.org/executorch/executorch-versions.json", # for testing only, will need to replace to the correct json file on the executorch website when it's added in the repo. + "version_match": switcher_version, + }, + "show_toc_level": 2, + "analytics_id": "GTM-T8XT4PS", + "icon_links": [ + { + "name": "X", + "url": "https://x.com/PyTorch", + "icon": "fa-brands fa-x-twitter", + }, + { + "name": "GitHub", + "url": "https://github.com/pytorch/executorch", + "icon": "fa-brands fa-github", + }, + { + "name": "Discourse", + "url": "https://discuss.pytorch.org/", + "icon": "fa-brands fa-discourse", + }, + { + "name": "PyPi", + "url": "https://pypi.org/project/executorch", + "icon": "fa-brands fa-python", + }, + ], + "show_version_warning_banner": True, + "use_edit_page_button": True, + "header_links_before_dropdown": 8, + "navbar_align": "left", + "navbar_start": ["navbar-logo", "version-switcher"], + "navbar_center": ["navbar-nav"], + "navbar_end": ["search-field-custom", "theme-switcher", "navbar-icon-links"], + "navbar_persistent": [], +} + +theme_variables = pytorch_sphinx_theme2.get_theme_variables() +templates_path = [ + "_templates", + os.path.join(os.path.dirname(pytorch_sphinx_theme2.__file__), "templates"), +] + +html_context = { + "theme_variables": theme_variables, + "display_github": True, + "github_url": "https://github.com", + "github_user": "pytorch", + "github_repo": "executorch", + "feedback_url": "https://github.com/pytorch/executorch", + "github_version": "main", + "doc_path": "docs/source", "pytorch_project": "executorch", "display_version": True, - "logo_only": True, - "collapse_navigation": True, # changed to True to enable 3rd level nav. - "sticky_navigation": False, - "navigation_depth": 4, - "includehidden": True, - "titles_only": False, - "analytics_id": "GTM-T8XT4PS", } + # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. 
They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ["_static"] -html_css_files = ["css/custom.css", "progress-bar.css"] -html_js_files = ["js/progress-bar.js"] +# Add custom 404 page for GitHub Pages +html_additional_pages = {"404": "404.html"} + # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = { "python": ("https://docs.python.org/", None), "numpy": ("https://numpy.org/doc/stable/", None), - "torch": ("https://pytorch.org/docs/stable/", None), + "torch": ("https://docs.pytorch.org/docs/stable/", None), } # Redirects for moved pages @@ -202,7 +278,8 @@ "export-overview": "using-executorch-export.html", "runtime-build-and-cross-compilation": "using-executorch-building-from-source.html", "tutorials/export-to-executorch-tutorial": "../using-executorch-export.html", - "build-run-vulkan": "backends-vulkan.html", + "build-run-vulkan": "backends/vulkan/vulkan-overview.html", + "backends-vulkan": "backends/vulkan/vulkan-overview.html", "executorch-arm-delegate-tutorial": "backends-arm-ethos-u.html", "build-run-coreml": "backends-coreml.html", "build-run-mediatek-backend": "backends-mediatek.html", diff --git a/docs/source/debug-backend-delegate.md b/docs/source/debug-backend-delegate.md index 86dddd75868..efb4653a994 100644 --- a/docs/source/debug-backend-delegate.md +++ b/docs/source/debug-backend-delegate.md @@ -6,60 +6,607 @@ We provide a list of util functions to give users insights on what happened to t The `get_delegation_info()` method provides a summary of what happened to the model after the `to_backend()` call: ```python +import torch +from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner +from executorch.exir import to_edge_transform_and_lower +from torch.export import Dim, export +from torchvision.models.mobilenetv2 import MobileNet_V2_Weights +import torchvision.models as models + +# Dependency needed for debugging delegates from executorch.devtools.backend_debug import get_delegation_info from tabulate import tabulate -# ... 
After call to to_backend(), but before to_executorch() -graph_module = edge_manager.exported_program().graph_module + +model = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval() +sample_inputs = (torch.randn(1, 3, 224, 224), ) + +et_program = to_edge_transform_and_lower( + torch.export.export(model, sample_inputs), + partitioner=[XnnpackPartitioner()] +) +graph_module = et_program.exported_program().graph_module delegation_info = get_delegation_info(graph_module) +# print the summary like the number of delegated nodes, non-delegated nodes, etc print(delegation_info.get_summary()) df = delegation_info.get_operator_delegation_dataframe() +# print the table including op_type, occurrences_in_delegated_graphs, occurrences_in_non_delegated_graphs print(tabulate(df, headers="keys", tablefmt="fancy_grid")) ``` Example printout: ``` -Total delegated subgraphs: 86 -Number of delegated nodes: 473 -Number of non-delegated nodes: 430 +Total delegated subgraphs: 2 +Number of delegated nodes: 203 +Number of non-delegated nodes: 4 ``` +| | op_type | occurrences_in_delegated_graphs | occurrences_in_non_delegated_graphs | +|----|---------------------------------------------------|---------------------------------|-------------------------------------| +| 0 | aten__native_batch_norm_legit_no_training_default | 52 | 0 | +| 1 | aten_add_tensor | 10 | 0 | +| 2 | aten_convolution_default | 52 | 0 | +| 3 | aten_hardtanh_default | 35 | 0 | +| 4 | aten_linear_default | 1 | 0 | +| 5 | aten_mean_dim | 1 | 0 | +| 6 | aten_view_copy_default | 0 | 1 | +| 7 | dim_order_ops__clone_dim_order_default | 0 | 1 | +| 8 | getitem | 52 | 2 | +| 9 | **Total** | **203** | **4** | -| | op_type | occurrences_in_delegated_graphs | occurrences_in_non_delegated_graphs | -|----|---------------------------------|------- |-----| -| 0 | aten__softmax_default | 12 | 0 | -| 1 | aten_add_tensor | 37 | 0 | -| 2 | aten_addmm_default | 48 | 0 | -| 3 | aten_arange_start_step | 0 | 25 | -| | ... | | | -| 23 | aten_view_copy_default | 170 | 48 | -| | ... | | | -| 26 | Total | 473 | 430 | -From the table, the operator `aten_view_copy_default` appears 170 times in delegate graphs and 48 times in non-delegated graphs. Users can use information like this to debug. +From the table, the operator `aten_view_copy_default` appears 0 times in delegate graphs and 1 times in non-delegated graphs. Users can use information like this to debug. `get_item node` is a special case, it means getting the output from the delegate subgraph. ## Visualize delegated graph -To see a more detailed view, use the `format_delegated_graph()` method to get a str of printout of the whole graph or use `print_delegated_graph()` to print directly: +To see a more detailed view, use the `format_delegated_graph()` method to get a string representation of the entire graph or use `print_delegated_graph()` to print directly: ```python from executorch.exir.backend.utils import format_delegated_graph -graph_module = edge_manager.exported_program().graph_module +graph_module = et_program.exported_program().graph_module print(format_delegated_graph(graph_module)) # or call print_delegated_graph(graph_module) ``` -It will print the whole model as well as the subgraph consumed by the backend. The generic debug function provided by fx like `print_tabular()` or `print_readable()` will only show `call_delegate` but hide the the subgraph consumes by the backend, while this function exposes the contents inside the subgraph. 
+It will print the whole model as well as the subgraph consumed by the backend. The generic debug functions provided by fx, like `print_tabular()` or `print_readable()`, will only show `call_delegate` and hide the subgraph consumed by the backend, while this function exposes the contents inside the subgraph.
-In the example printout below, observe that `embedding` and `add` operators are delegated to `XNNPACK` while the `sub` operator is not.
+In the example printout below, observe that there are two subgraphs: `aten_view_copy_default` is not delegated, while most of the other ops are delegated.
+
``` -%aten_unsqueeze_copy_default_22 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.unsqueeze_copy.default](args = (%aten_arange_start_step_23, -2), kwargs = {}) - %aten_unsqueeze_copy_default_23 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.unsqueeze_copy.default](args = (%aten_arange_start_step_24, -1), kwargs = {}) +graph(): + %b_features_0_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_0_1_num_batches_tracked] + %b_features_1_conv_0_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_1_conv_0_1_num_batches_tracked] + %b_features_1_conv_2_num_batches_tracked : [num_users=0] = placeholder[target=b_features_1_conv_2_num_batches_tracked] + %b_features_2_conv_0_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_2_conv_0_1_num_batches_tracked] + %b_features_2_conv_1_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_2_conv_1_1_num_batches_tracked] + %b_features_2_conv_3_num_batches_tracked : [num_users=0] = placeholder[target=b_features_2_conv_3_num_batches_tracked] + %b_features_3_conv_0_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_3_conv_0_1_num_batches_tracked] + %b_features_3_conv_1_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_3_conv_1_1_num_batches_tracked] + %b_features_3_conv_3_num_batches_tracked : [num_users=0] = placeholder[target=b_features_3_conv_3_num_batches_tracked] + %b_features_4_conv_0_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_4_conv_0_1_num_batches_tracked] + %b_features_4_conv_1_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_4_conv_1_1_num_batches_tracked] + %b_features_4_conv_3_num_batches_tracked : [num_users=0] = placeholder[target=b_features_4_conv_3_num_batches_tracked] + %b_features_5_conv_0_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_5_conv_0_1_num_batches_tracked] + %b_features_5_conv_1_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_5_conv_1_1_num_batches_tracked] + %b_features_5_conv_3_num_batches_tracked : [num_users=0] = placeholder[target=b_features_5_conv_3_num_batches_tracked] + %b_features_6_conv_0_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_6_conv_0_1_num_batches_tracked] + %b_features_6_conv_1_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_6_conv_1_1_num_batches_tracked] + %b_features_6_conv_3_num_batches_tracked : [num_users=0] = placeholder[target=b_features_6_conv_3_num_batches_tracked] + %b_features_7_conv_0_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_7_conv_0_1_num_batches_tracked] + %b_features_7_conv_1_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_7_conv_1_1_num_batches_tracked] + %b_features_7_conv_3_num_batches_tracked : [num_users=0] = placeholder[target=b_features_7_conv_3_num_batches_tracked] + %b_features_8_conv_0_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_8_conv_0_1_num_batches_tracked] + %b_features_8_conv_1_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_8_conv_1_1_num_batches_tracked] + %b_features_8_conv_3_num_batches_tracked : [num_users=0] = placeholder[target=b_features_8_conv_3_num_batches_tracked] + %b_features_9_conv_0_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_9_conv_0_1_num_batches_tracked] + %b_features_9_conv_1_1_num_batches_tracked : [num_users=0] = 
placeholder[target=b_features_9_conv_1_1_num_batches_tracked] + %b_features_9_conv_3_num_batches_tracked : [num_users=0] = placeholder[target=b_features_9_conv_3_num_batches_tracked] + %b_features_10_conv_0_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_10_conv_0_1_num_batches_tracked] + %b_features_10_conv_1_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_10_conv_1_1_num_batches_tracked] + %b_features_10_conv_3_num_batches_tracked : [num_users=0] = placeholder[target=b_features_10_conv_3_num_batches_tracked] + %b_features_11_conv_0_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_11_conv_0_1_num_batches_tracked] + %b_features_11_conv_1_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_11_conv_1_1_num_batches_tracked] + %b_features_11_conv_3_num_batches_tracked : [num_users=0] = placeholder[target=b_features_11_conv_3_num_batches_tracked] + %b_features_12_conv_0_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_12_conv_0_1_num_batches_tracked] + %b_features_12_conv_1_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_12_conv_1_1_num_batches_tracked] + %b_features_12_conv_3_num_batches_tracked : [num_users=0] = placeholder[target=b_features_12_conv_3_num_batches_tracked] + %b_features_13_conv_0_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_13_conv_0_1_num_batches_tracked] + %b_features_13_conv_1_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_13_conv_1_1_num_batches_tracked] + %b_features_13_conv_3_num_batches_tracked : [num_users=0] = placeholder[target=b_features_13_conv_3_num_batches_tracked] + %b_features_14_conv_0_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_14_conv_0_1_num_batches_tracked] + %b_features_14_conv_1_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_14_conv_1_1_num_batches_tracked] + %b_features_14_conv_3_num_batches_tracked : [num_users=0] = placeholder[target=b_features_14_conv_3_num_batches_tracked] + %b_features_15_conv_0_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_15_conv_0_1_num_batches_tracked] + %b_features_15_conv_1_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_15_conv_1_1_num_batches_tracked] + %b_features_15_conv_3_num_batches_tracked : [num_users=0] = placeholder[target=b_features_15_conv_3_num_batches_tracked] + %b_features_16_conv_0_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_16_conv_0_1_num_batches_tracked] + %b_features_16_conv_1_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_16_conv_1_1_num_batches_tracked] + %b_features_16_conv_3_num_batches_tracked : [num_users=0] = placeholder[target=b_features_16_conv_3_num_batches_tracked] + %b_features_17_conv_0_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_17_conv_0_1_num_batches_tracked] + %b_features_17_conv_1_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_17_conv_1_1_num_batches_tracked] + %b_features_17_conv_3_num_batches_tracked : [num_users=0] = placeholder[target=b_features_17_conv_3_num_batches_tracked] + %b_features_18_1_num_batches_tracked : [num_users=0] = placeholder[target=b_features_18_1_num_batches_tracked] + %x : [num_users=1] = placeholder[target=x] %lowered_module_0 : [num_users=1] = get_attr[target=lowered_module_0] backend_id: XnnpackBackend lowered graph(): - %aten_embedding_default : [num_users=1] = 
placeholder[target=aten_embedding_default] - %aten_embedding_default_1 : [num_users=1] = placeholder[target=aten_embedding_default_1] - %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_embedding_default, %aten_embedding_default_1), kwargs = {}) - return (aten_add_tensor,) - %executorch_call_delegate : [num_users=1] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_0, %aten_embedding_default, %aten_embedding_default_1), kwargs = {}) - %aten_sub_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.sub.Tensor](args = (%aten_unsqueeze_copy_default, %aten_unsqueeze_copy_default_1), kwargs = {}) + %p_features_0_0_weight : [num_users=1] = placeholder[target=p_features_0_0_weight] + %p_features_0_1_weight : [num_users=1] = placeholder[target=p_features_0_1_weight] + %p_features_0_1_bias : [num_users=1] = placeholder[target=p_features_0_1_bias] + %p_features_1_conv_0_0_weight : [num_users=1] = placeholder[target=p_features_1_conv_0_0_weight] + %p_features_1_conv_0_1_weight : [num_users=1] = placeholder[target=p_features_1_conv_0_1_weight] + %p_features_1_conv_0_1_bias : [num_users=1] = placeholder[target=p_features_1_conv_0_1_bias] + %p_features_1_conv_1_weight : [num_users=1] = placeholder[target=p_features_1_conv_1_weight] + %p_features_1_conv_2_weight : [num_users=1] = placeholder[target=p_features_1_conv_2_weight] + %p_features_1_conv_2_bias : [num_users=1] = placeholder[target=p_features_1_conv_2_bias] + %p_features_2_conv_0_0_weight : [num_users=1] = placeholder[target=p_features_2_conv_0_0_weight] + %p_features_2_conv_0_1_weight : [num_users=1] = placeholder[target=p_features_2_conv_0_1_weight] + %p_features_2_conv_0_1_bias : [num_users=1] = placeholder[target=p_features_2_conv_0_1_bias] + %p_features_2_conv_1_0_weight : [num_users=1] = placeholder[target=p_features_2_conv_1_0_weight] + %p_features_2_conv_1_1_weight : [num_users=1] = placeholder[target=p_features_2_conv_1_1_weight] + %p_features_2_conv_1_1_bias : [num_users=1] = placeholder[target=p_features_2_conv_1_1_bias] + %p_features_2_conv_2_weight : [num_users=1] = placeholder[target=p_features_2_conv_2_weight] + %p_features_2_conv_3_weight : [num_users=1] = placeholder[target=p_features_2_conv_3_weight] + %p_features_2_conv_3_bias : [num_users=1] = placeholder[target=p_features_2_conv_3_bias] + %p_features_3_conv_0_0_weight : [num_users=1] = placeholder[target=p_features_3_conv_0_0_weight] + %p_features_3_conv_0_1_weight : [num_users=1] = placeholder[target=p_features_3_conv_0_1_weight] + %p_features_3_conv_0_1_bias : [num_users=1] = placeholder[target=p_features_3_conv_0_1_bias] + %p_features_3_conv_1_0_weight : [num_users=1] = placeholder[target=p_features_3_conv_1_0_weight] + %p_features_3_conv_1_1_weight : [num_users=1] = placeholder[target=p_features_3_conv_1_1_weight] + %p_features_3_conv_1_1_bias : [num_users=1] = placeholder[target=p_features_3_conv_1_1_bias] + %p_features_3_conv_2_weight : [num_users=1] = placeholder[target=p_features_3_conv_2_weight] + %p_features_3_conv_3_weight : [num_users=1] = placeholder[target=p_features_3_conv_3_weight] + %p_features_3_conv_3_bias : [num_users=1] = placeholder[target=p_features_3_conv_3_bias] + %p_features_4_conv_0_0_weight : [num_users=1] = placeholder[target=p_features_4_conv_0_0_weight] + %p_features_4_conv_0_1_weight : [num_users=1] = placeholder[target=p_features_4_conv_0_1_weight] + %p_features_4_conv_0_1_bias : [num_users=1] = 
placeholder[target=p_features_4_conv_0_1_bias] + %p_features_4_conv_1_0_weight : [num_users=1] = placeholder[target=p_features_4_conv_1_0_weight] + %p_features_4_conv_1_1_weight : [num_users=1] = placeholder[target=p_features_4_conv_1_1_weight] + %p_features_4_conv_1_1_bias : [num_users=1] = placeholder[target=p_features_4_conv_1_1_bias] + %p_features_4_conv_2_weight : [num_users=1] = placeholder[target=p_features_4_conv_2_weight] + %p_features_4_conv_3_weight : [num_users=1] = placeholder[target=p_features_4_conv_3_weight] + %p_features_4_conv_3_bias : [num_users=1] = placeholder[target=p_features_4_conv_3_bias] + %p_features_5_conv_0_0_weight : [num_users=1] = placeholder[target=p_features_5_conv_0_0_weight] + %p_features_5_conv_0_1_weight : [num_users=1] = placeholder[target=p_features_5_conv_0_1_weight] + %p_features_5_conv_0_1_bias : [num_users=1] = placeholder[target=p_features_5_conv_0_1_bias] + %p_features_5_conv_1_0_weight : [num_users=1] = placeholder[target=p_features_5_conv_1_0_weight] + %p_features_5_conv_1_1_weight : [num_users=1] = placeholder[target=p_features_5_conv_1_1_weight] + %p_features_5_conv_1_1_bias : [num_users=1] = placeholder[target=p_features_5_conv_1_1_bias] + %p_features_5_conv_2_weight : [num_users=1] = placeholder[target=p_features_5_conv_2_weight] + %p_features_5_conv_3_weight : [num_users=1] = placeholder[target=p_features_5_conv_3_weight] + %p_features_5_conv_3_bias : [num_users=1] = placeholder[target=p_features_5_conv_3_bias] + %p_features_6_conv_0_0_weight : [num_users=1] = placeholder[target=p_features_6_conv_0_0_weight] + %p_features_6_conv_0_1_weight : [num_users=1] = placeholder[target=p_features_6_conv_0_1_weight] + %p_features_6_conv_0_1_bias : [num_users=1] = placeholder[target=p_features_6_conv_0_1_bias] + %p_features_6_conv_1_0_weight : [num_users=1] = placeholder[target=p_features_6_conv_1_0_weight] + %p_features_6_conv_1_1_weight : [num_users=1] = placeholder[target=p_features_6_conv_1_1_weight] + %p_features_6_conv_1_1_bias : [num_users=1] = placeholder[target=p_features_6_conv_1_1_bias] + %p_features_6_conv_2_weight : [num_users=1] = placeholder[target=p_features_6_conv_2_weight] + %p_features_6_conv_3_weight : [num_users=1] = placeholder[target=p_features_6_conv_3_weight] + %p_features_6_conv_3_bias : [num_users=1] = placeholder[target=p_features_6_conv_3_bias] + %p_features_7_conv_0_0_weight : [num_users=1] = placeholder[target=p_features_7_conv_0_0_weight] + %p_features_7_conv_0_1_weight : [num_users=1] = placeholder[target=p_features_7_conv_0_1_weight] + %p_features_7_conv_0_1_bias : [num_users=1] = placeholder[target=p_features_7_conv_0_1_bias] + %p_features_7_conv_1_0_weight : [num_users=1] = placeholder[target=p_features_7_conv_1_0_weight] + %p_features_7_conv_1_1_weight : [num_users=1] = placeholder[target=p_features_7_conv_1_1_weight] + %p_features_7_conv_1_1_bias : [num_users=1] = placeholder[target=p_features_7_conv_1_1_bias] + %p_features_7_conv_2_weight : [num_users=1] = placeholder[target=p_features_7_conv_2_weight] + %p_features_7_conv_3_weight : [num_users=1] = placeholder[target=p_features_7_conv_3_weight] + %p_features_7_conv_3_bias : [num_users=1] = placeholder[target=p_features_7_conv_3_bias] + %p_features_8_conv_0_0_weight : [num_users=1] = placeholder[target=p_features_8_conv_0_0_weight] + %p_features_8_conv_0_1_weight : [num_users=1] = placeholder[target=p_features_8_conv_0_1_weight] + %p_features_8_conv_0_1_bias : [num_users=1] = placeholder[target=p_features_8_conv_0_1_bias] + %p_features_8_conv_1_0_weight : 
[num_users=1] = placeholder[target=p_features_8_conv_1_0_weight] + %p_features_8_conv_1_1_weight : [num_users=1] = placeholder[target=p_features_8_conv_1_1_weight] + %p_features_8_conv_1_1_bias : [num_users=1] = placeholder[target=p_features_8_conv_1_1_bias] + %p_features_8_conv_2_weight : [num_users=1] = placeholder[target=p_features_8_conv_2_weight] + %p_features_8_conv_3_weight : [num_users=1] = placeholder[target=p_features_8_conv_3_weight] + %p_features_8_conv_3_bias : [num_users=1] = placeholder[target=p_features_8_conv_3_bias] + %p_features_9_conv_0_0_weight : [num_users=1] = placeholder[target=p_features_9_conv_0_0_weight] + %p_features_9_conv_0_1_weight : [num_users=1] = placeholder[target=p_features_9_conv_0_1_weight] + %p_features_9_conv_0_1_bias : [num_users=1] = placeholder[target=p_features_9_conv_0_1_bias] + %p_features_9_conv_1_0_weight : [num_users=1] = placeholder[target=p_features_9_conv_1_0_weight] + %p_features_9_conv_1_1_weight : [num_users=1] = placeholder[target=p_features_9_conv_1_1_weight] + %p_features_9_conv_1_1_bias : [num_users=1] = placeholder[target=p_features_9_conv_1_1_bias] + %p_features_9_conv_2_weight : [num_users=1] = placeholder[target=p_features_9_conv_2_weight] + %p_features_9_conv_3_weight : [num_users=1] = placeholder[target=p_features_9_conv_3_weight] + %p_features_9_conv_3_bias : [num_users=1] = placeholder[target=p_features_9_conv_3_bias] + %p_features_10_conv_0_0_weight : [num_users=1] = placeholder[target=p_features_10_conv_0_0_weight] + %p_features_10_conv_0_1_weight : [num_users=1] = placeholder[target=p_features_10_conv_0_1_weight] + %p_features_10_conv_0_1_bias : [num_users=1] = placeholder[target=p_features_10_conv_0_1_bias] + %p_features_10_conv_1_0_weight : [num_users=1] = placeholder[target=p_features_10_conv_1_0_weight] + %p_features_10_conv_1_1_weight : [num_users=1] = placeholder[target=p_features_10_conv_1_1_weight] + %p_features_10_conv_1_1_bias : [num_users=1] = placeholder[target=p_features_10_conv_1_1_bias] + %p_features_10_conv_2_weight : [num_users=1] = placeholder[target=p_features_10_conv_2_weight] + %p_features_10_conv_3_weight : [num_users=1] = placeholder[target=p_features_10_conv_3_weight] + %p_features_10_conv_3_bias : [num_users=1] = placeholder[target=p_features_10_conv_3_bias] + %p_features_11_conv_0_0_weight : [num_users=1] = placeholder[target=p_features_11_conv_0_0_weight] + %p_features_11_conv_0_1_weight : [num_users=1] = placeholder[target=p_features_11_conv_0_1_weight] + %p_features_11_conv_0_1_bias : [num_users=1] = placeholder[target=p_features_11_conv_0_1_bias] + %p_features_11_conv_1_0_weight : [num_users=1] = placeholder[target=p_features_11_conv_1_0_weight] + %p_features_11_conv_1_1_weight : [num_users=1] = placeholder[target=p_features_11_conv_1_1_weight] + %p_features_11_conv_1_1_bias : [num_users=1] = placeholder[target=p_features_11_conv_1_1_bias] + %p_features_11_conv_2_weight : [num_users=1] = placeholder[target=p_features_11_conv_2_weight] + %p_features_11_conv_3_weight : [num_users=1] = placeholder[target=p_features_11_conv_3_weight] + %p_features_11_conv_3_bias : [num_users=1] = placeholder[target=p_features_11_conv_3_bias] + %p_features_12_conv_0_0_weight : [num_users=1] = placeholder[target=p_features_12_conv_0_0_weight] + %p_features_12_conv_0_1_weight : [num_users=1] = placeholder[target=p_features_12_conv_0_1_weight] + %p_features_12_conv_0_1_bias : [num_users=1] = placeholder[target=p_features_12_conv_0_1_bias] + %p_features_12_conv_1_0_weight : [num_users=1] = 
placeholder[target=p_features_12_conv_1_0_weight] + %p_features_12_conv_1_1_weight : [num_users=1] = placeholder[target=p_features_12_conv_1_1_weight] + %p_features_12_conv_1_1_bias : [num_users=1] = placeholder[target=p_features_12_conv_1_1_bias] + %p_features_12_conv_2_weight : [num_users=1] = placeholder[target=p_features_12_conv_2_weight] + %p_features_12_conv_3_weight : [num_users=1] = placeholder[target=p_features_12_conv_3_weight] + %p_features_12_conv_3_bias : [num_users=1] = placeholder[target=p_features_12_conv_3_bias] + %p_features_13_conv_0_0_weight : [num_users=1] = placeholder[target=p_features_13_conv_0_0_weight] + %p_features_13_conv_0_1_weight : [num_users=1] = placeholder[target=p_features_13_conv_0_1_weight] + %p_features_13_conv_0_1_bias : [num_users=1] = placeholder[target=p_features_13_conv_0_1_bias] + %p_features_13_conv_1_0_weight : [num_users=1] = placeholder[target=p_features_13_conv_1_0_weight] + %p_features_13_conv_1_1_weight : [num_users=1] = placeholder[target=p_features_13_conv_1_1_weight] + %p_features_13_conv_1_1_bias : [num_users=1] = placeholder[target=p_features_13_conv_1_1_bias] + %p_features_13_conv_2_weight : [num_users=1] = placeholder[target=p_features_13_conv_2_weight] + %p_features_13_conv_3_weight : [num_users=1] = placeholder[target=p_features_13_conv_3_weight] + %p_features_13_conv_3_bias : [num_users=1] = placeholder[target=p_features_13_conv_3_bias] + %p_features_14_conv_0_0_weight : [num_users=1] = placeholder[target=p_features_14_conv_0_0_weight] + %p_features_14_conv_0_1_weight : [num_users=1] = placeholder[target=p_features_14_conv_0_1_weight] + %p_features_14_conv_0_1_bias : [num_users=1] = placeholder[target=p_features_14_conv_0_1_bias] + %p_features_14_conv_1_0_weight : [num_users=1] = placeholder[target=p_features_14_conv_1_0_weight] + %p_features_14_conv_1_1_weight : [num_users=1] = placeholder[target=p_features_14_conv_1_1_weight] + %p_features_14_conv_1_1_bias : [num_users=1] = placeholder[target=p_features_14_conv_1_1_bias] + %p_features_14_conv_2_weight : [num_users=1] = placeholder[target=p_features_14_conv_2_weight] + %p_features_14_conv_3_weight : [num_users=1] = placeholder[target=p_features_14_conv_3_weight] + %p_features_14_conv_3_bias : [num_users=1] = placeholder[target=p_features_14_conv_3_bias] + %p_features_15_conv_0_0_weight : [num_users=1] = placeholder[target=p_features_15_conv_0_0_weight] + %p_features_15_conv_0_1_weight : [num_users=1] = placeholder[target=p_features_15_conv_0_1_weight] + %p_features_15_conv_0_1_bias : [num_users=1] = placeholder[target=p_features_15_conv_0_1_bias] + %p_features_15_conv_1_0_weight : [num_users=1] = placeholder[target=p_features_15_conv_1_0_weight] + %p_features_15_conv_1_1_weight : [num_users=1] = placeholder[target=p_features_15_conv_1_1_weight] + %p_features_15_conv_1_1_bias : [num_users=1] = placeholder[target=p_features_15_conv_1_1_bias] + %p_features_15_conv_2_weight : [num_users=1] = placeholder[target=p_features_15_conv_2_weight] + %p_features_15_conv_3_weight : [num_users=1] = placeholder[target=p_features_15_conv_3_weight] + %p_features_15_conv_3_bias : [num_users=1] = placeholder[target=p_features_15_conv_3_bias] + %p_features_16_conv_0_0_weight : [num_users=1] = placeholder[target=p_features_16_conv_0_0_weight] + %p_features_16_conv_0_1_weight : [num_users=1] = placeholder[target=p_features_16_conv_0_1_weight] + %p_features_16_conv_0_1_bias : [num_users=1] = placeholder[target=p_features_16_conv_0_1_bias] + %p_features_16_conv_1_0_weight : [num_users=1] = 
placeholder[target=p_features_16_conv_1_0_weight] + %p_features_16_conv_1_1_weight : [num_users=1] = placeholder[target=p_features_16_conv_1_1_weight] + %p_features_16_conv_1_1_bias : [num_users=1] = placeholder[target=p_features_16_conv_1_1_bias] + %p_features_16_conv_2_weight : [num_users=1] = placeholder[target=p_features_16_conv_2_weight] + %p_features_16_conv_3_weight : [num_users=1] = placeholder[target=p_features_16_conv_3_weight] + %p_features_16_conv_3_bias : [num_users=1] = placeholder[target=p_features_16_conv_3_bias] + %p_features_17_conv_0_0_weight : [num_users=1] = placeholder[target=p_features_17_conv_0_0_weight] + %p_features_17_conv_0_1_weight : [num_users=1] = placeholder[target=p_features_17_conv_0_1_weight] + %p_features_17_conv_0_1_bias : [num_users=1] = placeholder[target=p_features_17_conv_0_1_bias] + %p_features_17_conv_1_0_weight : [num_users=1] = placeholder[target=p_features_17_conv_1_0_weight] + %p_features_17_conv_1_1_weight : [num_users=1] = placeholder[target=p_features_17_conv_1_1_weight] + %p_features_17_conv_1_1_bias : [num_users=1] = placeholder[target=p_features_17_conv_1_1_bias] + %p_features_17_conv_2_weight : [num_users=1] = placeholder[target=p_features_17_conv_2_weight] + %p_features_17_conv_3_weight : [num_users=1] = placeholder[target=p_features_17_conv_3_weight] + %p_features_17_conv_3_bias : [num_users=1] = placeholder[target=p_features_17_conv_3_bias] + %p_features_18_0_weight : [num_users=1] = placeholder[target=p_features_18_0_weight] + %p_features_18_1_weight : [num_users=1] = placeholder[target=p_features_18_1_weight] + %p_features_18_1_bias : [num_users=1] = placeholder[target=p_features_18_1_bias] + %b_features_0_1_running_mean : [num_users=1] = placeholder[target=b_features_0_1_running_mean] + %b_features_0_1_running_var : [num_users=1] = placeholder[target=b_features_0_1_running_var] + %b_features_1_conv_0_1_running_mean : [num_users=1] = placeholder[target=b_features_1_conv_0_1_running_mean] + %b_features_1_conv_0_1_running_var : [num_users=1] = placeholder[target=b_features_1_conv_0_1_running_var] + %b_features_1_conv_2_running_mean : [num_users=1] = placeholder[target=b_features_1_conv_2_running_mean] + %b_features_1_conv_2_running_var : [num_users=1] = placeholder[target=b_features_1_conv_2_running_var] + %b_features_2_conv_0_1_running_mean : [num_users=1] = placeholder[target=b_features_2_conv_0_1_running_mean] + %b_features_2_conv_0_1_running_var : [num_users=1] = placeholder[target=b_features_2_conv_0_1_running_var] + %b_features_2_conv_1_1_running_mean : [num_users=1] = placeholder[target=b_features_2_conv_1_1_running_mean] + %b_features_2_conv_1_1_running_var : [num_users=1] = placeholder[target=b_features_2_conv_1_1_running_var] + %b_features_2_conv_3_running_mean : [num_users=1] = placeholder[target=b_features_2_conv_3_running_mean] + %b_features_2_conv_3_running_var : [num_users=1] = placeholder[target=b_features_2_conv_3_running_var] + %b_features_3_conv_0_1_running_mean : [num_users=1] = placeholder[target=b_features_3_conv_0_1_running_mean] + %b_features_3_conv_0_1_running_var : [num_users=1] = placeholder[target=b_features_3_conv_0_1_running_var] + %b_features_3_conv_1_1_running_mean : [num_users=1] = placeholder[target=b_features_3_conv_1_1_running_mean] + %b_features_3_conv_1_1_running_var : [num_users=1] = placeholder[target=b_features_3_conv_1_1_running_var] + %b_features_3_conv_3_running_mean : [num_users=1] = placeholder[target=b_features_3_conv_3_running_mean] + %b_features_3_conv_3_running_var : [num_users=1] = 
placeholder[target=b_features_3_conv_3_running_var] + %b_features_4_conv_0_1_running_mean : [num_users=1] = placeholder[target=b_features_4_conv_0_1_running_mean] + %b_features_4_conv_0_1_running_var : [num_users=1] = placeholder[target=b_features_4_conv_0_1_running_var] + %b_features_4_conv_1_1_running_mean : [num_users=1] = placeholder[target=b_features_4_conv_1_1_running_mean] + %b_features_4_conv_1_1_running_var : [num_users=1] = placeholder[target=b_features_4_conv_1_1_running_var] + %b_features_4_conv_3_running_mean : [num_users=1] = placeholder[target=b_features_4_conv_3_running_mean] + %b_features_4_conv_3_running_var : [num_users=1] = placeholder[target=b_features_4_conv_3_running_var] + %b_features_5_conv_0_1_running_mean : [num_users=1] = placeholder[target=b_features_5_conv_0_1_running_mean] + %b_features_5_conv_0_1_running_var : [num_users=1] = placeholder[target=b_features_5_conv_0_1_running_var] + %b_features_5_conv_1_1_running_mean : [num_users=1] = placeholder[target=b_features_5_conv_1_1_running_mean] + %b_features_5_conv_1_1_running_var : [num_users=1] = placeholder[target=b_features_5_conv_1_1_running_var] + %b_features_5_conv_3_running_mean : [num_users=1] = placeholder[target=b_features_5_conv_3_running_mean] + %b_features_5_conv_3_running_var : [num_users=1] = placeholder[target=b_features_5_conv_3_running_var] + %b_features_6_conv_0_1_running_mean : [num_users=1] = placeholder[target=b_features_6_conv_0_1_running_mean] + %b_features_6_conv_0_1_running_var : [num_users=1] = placeholder[target=b_features_6_conv_0_1_running_var] + %b_features_6_conv_1_1_running_mean : [num_users=1] = placeholder[target=b_features_6_conv_1_1_running_mean] + %b_features_6_conv_1_1_running_var : [num_users=1] = placeholder[target=b_features_6_conv_1_1_running_var] + %b_features_6_conv_3_running_mean : [num_users=1] = placeholder[target=b_features_6_conv_3_running_mean] + %b_features_6_conv_3_running_var : [num_users=1] = placeholder[target=b_features_6_conv_3_running_var] + %b_features_7_conv_0_1_running_mean : [num_users=1] = placeholder[target=b_features_7_conv_0_1_running_mean] + %b_features_7_conv_0_1_running_var : [num_users=1] = placeholder[target=b_features_7_conv_0_1_running_var] + %b_features_7_conv_1_1_running_mean : [num_users=1] = placeholder[target=b_features_7_conv_1_1_running_mean] + %b_features_7_conv_1_1_running_var : [num_users=1] = placeholder[target=b_features_7_conv_1_1_running_var] + %b_features_7_conv_3_running_mean : [num_users=1] = placeholder[target=b_features_7_conv_3_running_mean] + %b_features_7_conv_3_running_var : [num_users=1] = placeholder[target=b_features_7_conv_3_running_var] + %b_features_8_conv_0_1_running_mean : [num_users=1] = placeholder[target=b_features_8_conv_0_1_running_mean] + %b_features_8_conv_0_1_running_var : [num_users=1] = placeholder[target=b_features_8_conv_0_1_running_var] + %b_features_8_conv_1_1_running_mean : [num_users=1] = placeholder[target=b_features_8_conv_1_1_running_mean] + %b_features_8_conv_1_1_running_var : [num_users=1] = placeholder[target=b_features_8_conv_1_1_running_var] + %b_features_8_conv_3_running_mean : [num_users=1] = placeholder[target=b_features_8_conv_3_running_mean] + %b_features_8_conv_3_running_var : [num_users=1] = placeholder[target=b_features_8_conv_3_running_var] + %b_features_9_conv_0_1_running_mean : [num_users=1] = placeholder[target=b_features_9_conv_0_1_running_mean] + %b_features_9_conv_0_1_running_var : [num_users=1] = placeholder[target=b_features_9_conv_0_1_running_var] + 
%b_features_9_conv_1_1_running_mean : [num_users=1] = placeholder[target=b_features_9_conv_1_1_running_mean] + %b_features_9_conv_1_1_running_var : [num_users=1] = placeholder[target=b_features_9_conv_1_1_running_var] + %b_features_9_conv_3_running_mean : [num_users=1] = placeholder[target=b_features_9_conv_3_running_mean] + %b_features_9_conv_3_running_var : [num_users=1] = placeholder[target=b_features_9_conv_3_running_var] + %b_features_10_conv_0_1_running_mean : [num_users=1] = placeholder[target=b_features_10_conv_0_1_running_mean] + %b_features_10_conv_0_1_running_var : [num_users=1] = placeholder[target=b_features_10_conv_0_1_running_var] + %b_features_10_conv_1_1_running_mean : [num_users=1] = placeholder[target=b_features_10_conv_1_1_running_mean] + %b_features_10_conv_1_1_running_var : [num_users=1] = placeholder[target=b_features_10_conv_1_1_running_var] + %b_features_10_conv_3_running_mean : [num_users=1] = placeholder[target=b_features_10_conv_3_running_mean] + %b_features_10_conv_3_running_var : [num_users=1] = placeholder[target=b_features_10_conv_3_running_var] + %b_features_11_conv_0_1_running_mean : [num_users=1] = placeholder[target=b_features_11_conv_0_1_running_mean] + %b_features_11_conv_0_1_running_var : [num_users=1] = placeholder[target=b_features_11_conv_0_1_running_var] + %b_features_11_conv_1_1_running_mean : [num_users=1] = placeholder[target=b_features_11_conv_1_1_running_mean] + %b_features_11_conv_1_1_running_var : [num_users=1] = placeholder[target=b_features_11_conv_1_1_running_var] + %b_features_11_conv_3_running_mean : [num_users=1] = placeholder[target=b_features_11_conv_3_running_mean] + %b_features_11_conv_3_running_var : [num_users=1] = placeholder[target=b_features_11_conv_3_running_var] + %b_features_12_conv_0_1_running_mean : [num_users=1] = placeholder[target=b_features_12_conv_0_1_running_mean] + %b_features_12_conv_0_1_running_var : [num_users=1] = placeholder[target=b_features_12_conv_0_1_running_var] + %b_features_12_conv_1_1_running_mean : [num_users=1] = placeholder[target=b_features_12_conv_1_1_running_mean] + %b_features_12_conv_1_1_running_var : [num_users=1] = placeholder[target=b_features_12_conv_1_1_running_var] + %b_features_12_conv_3_running_mean : [num_users=1] = placeholder[target=b_features_12_conv_3_running_mean] + %b_features_12_conv_3_running_var : [num_users=1] = placeholder[target=b_features_12_conv_3_running_var] + %b_features_13_conv_0_1_running_mean : [num_users=1] = placeholder[target=b_features_13_conv_0_1_running_mean] + %b_features_13_conv_0_1_running_var : [num_users=1] = placeholder[target=b_features_13_conv_0_1_running_var] + %b_features_13_conv_1_1_running_mean : [num_users=1] = placeholder[target=b_features_13_conv_1_1_running_mean] + %b_features_13_conv_1_1_running_var : [num_users=1] = placeholder[target=b_features_13_conv_1_1_running_var] + %b_features_13_conv_3_running_mean : [num_users=1] = placeholder[target=b_features_13_conv_3_running_mean] + %b_features_13_conv_3_running_var : [num_users=1] = placeholder[target=b_features_13_conv_3_running_var] + %b_features_14_conv_0_1_running_mean : [num_users=1] = placeholder[target=b_features_14_conv_0_1_running_mean] + %b_features_14_conv_0_1_running_var : [num_users=1] = placeholder[target=b_features_14_conv_0_1_running_var] + %b_features_14_conv_1_1_running_mean : [num_users=1] = placeholder[target=b_features_14_conv_1_1_running_mean] + %b_features_14_conv_1_1_running_var : [num_users=1] = placeholder[target=b_features_14_conv_1_1_running_var] + 
%b_features_14_conv_3_running_mean : [num_users=1] = placeholder[target=b_features_14_conv_3_running_mean] + %b_features_14_conv_3_running_var : [num_users=1] = placeholder[target=b_features_14_conv_3_running_var] + %b_features_15_conv_0_1_running_mean : [num_users=1] = placeholder[target=b_features_15_conv_0_1_running_mean] + %b_features_15_conv_0_1_running_var : [num_users=1] = placeholder[target=b_features_15_conv_0_1_running_var] + %b_features_15_conv_1_1_running_mean : [num_users=1] = placeholder[target=b_features_15_conv_1_1_running_mean] + %b_features_15_conv_1_1_running_var : [num_users=1] = placeholder[target=b_features_15_conv_1_1_running_var] + %b_features_15_conv_3_running_mean : [num_users=1] = placeholder[target=b_features_15_conv_3_running_mean] + %b_features_15_conv_3_running_var : [num_users=1] = placeholder[target=b_features_15_conv_3_running_var] + %b_features_16_conv_0_1_running_mean : [num_users=1] = placeholder[target=b_features_16_conv_0_1_running_mean] + %b_features_16_conv_0_1_running_var : [num_users=1] = placeholder[target=b_features_16_conv_0_1_running_var] + %b_features_16_conv_1_1_running_mean : [num_users=1] = placeholder[target=b_features_16_conv_1_1_running_mean] + %b_features_16_conv_1_1_running_var : [num_users=1] = placeholder[target=b_features_16_conv_1_1_running_var] + %b_features_16_conv_3_running_mean : [num_users=1] = placeholder[target=b_features_16_conv_3_running_mean] + %b_features_16_conv_3_running_var : [num_users=1] = placeholder[target=b_features_16_conv_3_running_var] + %b_features_17_conv_0_1_running_mean : [num_users=1] = placeholder[target=b_features_17_conv_0_1_running_mean] + %b_features_17_conv_0_1_running_var : [num_users=1] = placeholder[target=b_features_17_conv_0_1_running_var] + %b_features_17_conv_1_1_running_mean : [num_users=1] = placeholder[target=b_features_17_conv_1_1_running_mean] + %b_features_17_conv_1_1_running_var : [num_users=1] = placeholder[target=b_features_17_conv_1_1_running_var] + %b_features_17_conv_3_running_mean : [num_users=1] = placeholder[target=b_features_17_conv_3_running_mean] + %b_features_17_conv_3_running_var : [num_users=1] = placeholder[target=b_features_17_conv_3_running_var] + %b_features_18_1_running_mean : [num_users=1] = placeholder[target=b_features_18_1_running_mean] + %b_features_18_1_running_var : [num_users=1] = placeholder[target=b_features_18_1_running_var] + %x : [num_users=1] = placeholder[target=x] + %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%x, %p_features_0_0_weight, None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default, %p_features_0_1_weight, %p_features_0_1_bias, %b_features_0_1_running_mean, %b_features_0_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default, 0), kwargs = {}) + %aten_hardtanh_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default, %p_features_1_conv_0_0_weight, None, [1, 1], [1, 1], [1, 1], 
False, [0, 0], 32), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_1, %p_features_1_conv_0_1_weight, %p_features_1_conv_0_1_bias, %b_features_1_conv_0_1_running_mean, %b_features_1_conv_0_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_1, 0), kwargs = {}) + %aten_hardtanh_default_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_1, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_2 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_1, %p_features_1_conv_1_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_2 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_2, %p_features_1_conv_2_weight, %p_features_1_conv_2_bias, %b_features_1_conv_2_running_mean, %b_features_1_conv_2_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_2, 0), kwargs = {}) + %aten_convolution_default_3 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%getitem_2, %p_features_2_conv_0_0_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_3 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_3, %p_features_2_conv_0_1_weight, %p_features_2_conv_0_1_bias, %b_features_2_conv_0_1_running_mean, %b_features_2_conv_0_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_3, 0), kwargs = {}) + %aten_hardtanh_default_2 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_3, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_4 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_2, %p_features_2_conv_1_0_weight, None, [2, 2], [1, 1], [1, 1], False, [0, 0], 96), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_4 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_4, %p_features_2_conv_1_1_weight, %p_features_2_conv_1_1_bias, %b_features_2_conv_1_1_running_mean, %b_features_2_conv_1_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_4, 0), kwargs = {}) + %aten_hardtanh_default_3 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_4, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_5 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_3, 
%p_features_2_conv_2_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_5 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_5, %p_features_2_conv_3_weight, %p_features_2_conv_3_bias, %b_features_2_conv_3_running_mean, %b_features_2_conv_3_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_5 : [num_users=2] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_5, 0), kwargs = {}) + %aten_convolution_default_6 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%getitem_5, %p_features_3_conv_0_0_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_6 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_6, %p_features_3_conv_0_1_weight, %p_features_3_conv_0_1_bias, %b_features_3_conv_0_1_running_mean, %b_features_3_conv_0_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_6 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_6, 0), kwargs = {}) + %aten_hardtanh_default_4 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_6, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_7 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_4, %p_features_3_conv_1_0_weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 144), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_7 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_7, %p_features_3_conv_1_1_weight, %p_features_3_conv_1_1_bias, %b_features_3_conv_1_1_running_mean, %b_features_3_conv_1_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_7 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_7, 0), kwargs = {}) + %aten_hardtanh_default_5 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_7, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_8 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_5, %p_features_3_conv_2_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_8 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_8, %p_features_3_conv_3_weight, %p_features_3_conv_3_bias, %b_features_3_conv_3_running_mean, %b_features_3_conv_3_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_8 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_8, 0), kwargs = {}) + %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%getitem_5, %getitem_8), kwargs = {}) + %aten_convolution_default_9 : [num_users=1] = 
call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_add_tensor, %p_features_4_conv_0_0_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_9 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_9, %p_features_4_conv_0_1_weight, %p_features_4_conv_0_1_bias, %b_features_4_conv_0_1_running_mean, %b_features_4_conv_0_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_9 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_9, 0), kwargs = {}) + %aten_hardtanh_default_6 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_9, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_10 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_6, %p_features_4_conv_1_0_weight, None, [2, 2], [1, 1], [1, 1], False, [0, 0], 144), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_10 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_10, %p_features_4_conv_1_1_weight, %p_features_4_conv_1_1_bias, %b_features_4_conv_1_1_running_mean, %b_features_4_conv_1_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_10 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_10, 0), kwargs = {}) + %aten_hardtanh_default_7 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_10, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_11 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_7, %p_features_4_conv_2_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_11 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_11, %p_features_4_conv_3_weight, %p_features_4_conv_3_bias, %b_features_4_conv_3_running_mean, %b_features_4_conv_3_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_11 : [num_users=2] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_11, 0), kwargs = {}) + %aten_convolution_default_12 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%getitem_11, %p_features_5_conv_0_0_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_12 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_12, %p_features_5_conv_0_1_weight, %p_features_5_conv_0_1_bias, %b_features_5_conv_0_1_running_mean, %b_features_5_conv_0_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_12 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_12, 0), kwargs = {}) + %aten_hardtanh_default_8 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_12, 0.0, 
6.0), kwargs = {}) + %aten_convolution_default_13 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_8, %p_features_5_conv_1_0_weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 192), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_13 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_13, %p_features_5_conv_1_1_weight, %p_features_5_conv_1_1_bias, %b_features_5_conv_1_1_running_mean, %b_features_5_conv_1_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_13 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_13, 0), kwargs = {}) + %aten_hardtanh_default_9 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_13, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_14 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_9, %p_features_5_conv_2_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_14 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_14, %p_features_5_conv_3_weight, %p_features_5_conv_3_bias, %b_features_5_conv_3_running_mean, %b_features_5_conv_3_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_14 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_14, 0), kwargs = {}) + %aten_add_tensor_1 : [num_users=2] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%getitem_11, %getitem_14), kwargs = {}) + %aten_convolution_default_15 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_add_tensor_1, %p_features_6_conv_0_0_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_15 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_15, %p_features_6_conv_0_1_weight, %p_features_6_conv_0_1_bias, %b_features_6_conv_0_1_running_mean, %b_features_6_conv_0_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_15 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_15, 0), kwargs = {}) + %aten_hardtanh_default_10 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_15, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_16 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_10, %p_features_6_conv_1_0_weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 192), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_16 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_16, %p_features_6_conv_1_1_weight, %p_features_6_conv_1_1_bias, %b_features_6_conv_1_1_running_mean, %b_features_6_conv_1_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_16 : [num_users=1] = 
call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_16, 0), kwargs = {}) + %aten_hardtanh_default_11 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_16, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_17 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_11, %p_features_6_conv_2_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_17 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_17, %p_features_6_conv_3_weight, %p_features_6_conv_3_bias, %b_features_6_conv_3_running_mean, %b_features_6_conv_3_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_17 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_17, 0), kwargs = {}) + %aten_add_tensor_2 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_add_tensor_1, %getitem_17), kwargs = {}) + %aten_convolution_default_18 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_add_tensor_2, %p_features_7_conv_0_0_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_18 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_18, %p_features_7_conv_0_1_weight, %p_features_7_conv_0_1_bias, %b_features_7_conv_0_1_running_mean, %b_features_7_conv_0_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_18 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_18, 0), kwargs = {}) + %aten_hardtanh_default_12 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_18, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_19 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_12, %p_features_7_conv_1_0_weight, None, [2, 2], [1, 1], [1, 1], False, [0, 0], 192), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_19 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_19, %p_features_7_conv_1_1_weight, %p_features_7_conv_1_1_bias, %b_features_7_conv_1_1_running_mean, %b_features_7_conv_1_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_19 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_19, 0), kwargs = {}) + %aten_hardtanh_default_13 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_19, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_20 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_13, %p_features_7_conv_2_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_20 : [num_users=1] = 
call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_20, %p_features_7_conv_3_weight, %p_features_7_conv_3_bias, %b_features_7_conv_3_running_mean, %b_features_7_conv_3_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_20 : [num_users=2] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_20, 0), kwargs = {}) + %aten_convolution_default_21 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%getitem_20, %p_features_8_conv_0_0_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_21 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_21, %p_features_8_conv_0_1_weight, %p_features_8_conv_0_1_bias, %b_features_8_conv_0_1_running_mean, %b_features_8_conv_0_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_21 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_21, 0), kwargs = {}) + %aten_hardtanh_default_14 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_21, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_22 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_14, %p_features_8_conv_1_0_weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 384), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_22 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_22, %p_features_8_conv_1_1_weight, %p_features_8_conv_1_1_bias, %b_features_8_conv_1_1_running_mean, %b_features_8_conv_1_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_22 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_22, 0), kwargs = {}) + %aten_hardtanh_default_15 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_22, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_23 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_15, %p_features_8_conv_2_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_23 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_23, %p_features_8_conv_3_weight, %p_features_8_conv_3_bias, %b_features_8_conv_3_running_mean, %b_features_8_conv_3_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_23 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_23, 0), kwargs = {}) + %aten_add_tensor_3 : [num_users=2] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%getitem_20, %getitem_23), kwargs = {}) + %aten_convolution_default_24 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_add_tensor_3, %p_features_9_conv_0_0_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + 
%aten__native_batch_norm_legit_no_training_default_24 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_24, %p_features_9_conv_0_1_weight, %p_features_9_conv_0_1_bias, %b_features_9_conv_0_1_running_mean, %b_features_9_conv_0_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_24 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_24, 0), kwargs = {}) + %aten_hardtanh_default_16 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_24, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_25 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_16, %p_features_9_conv_1_0_weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 384), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_25 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_25, %p_features_9_conv_1_1_weight, %p_features_9_conv_1_1_bias, %b_features_9_conv_1_1_running_mean, %b_features_9_conv_1_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_25 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_25, 0), kwargs = {}) + %aten_hardtanh_default_17 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_25, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_26 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_17, %p_features_9_conv_2_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_26 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_26, %p_features_9_conv_3_weight, %p_features_9_conv_3_bias, %b_features_9_conv_3_running_mean, %b_features_9_conv_3_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_26 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_26, 0), kwargs = {}) + %aten_add_tensor_4 : [num_users=2] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_add_tensor_3, %getitem_26), kwargs = {}) + %aten_convolution_default_27 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_add_tensor_4, %p_features_10_conv_0_0_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_27 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_27, %p_features_10_conv_0_1_weight, %p_features_10_conv_0_1_bias, %b_features_10_conv_0_1_running_mean, %b_features_10_conv_0_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_27 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_27, 0), kwargs = {}) + %aten_hardtanh_default_18 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_27, 0.0, 6.0), kwargs = 
{}) + %aten_convolution_default_28 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_18, %p_features_10_conv_1_0_weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 384), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_28 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_28, %p_features_10_conv_1_1_weight, %p_features_10_conv_1_1_bias, %b_features_10_conv_1_1_running_mean, %b_features_10_conv_1_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_28 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_28, 0), kwargs = {}) + %aten_hardtanh_default_19 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_28, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_29 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_19, %p_features_10_conv_2_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_29 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_29, %p_features_10_conv_3_weight, %p_features_10_conv_3_bias, %b_features_10_conv_3_running_mean, %b_features_10_conv_3_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_29 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_29, 0), kwargs = {}) + %aten_add_tensor_5 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_add_tensor_4, %getitem_29), kwargs = {}) + %aten_convolution_default_30 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_add_tensor_5, %p_features_11_conv_0_0_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_30 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_30, %p_features_11_conv_0_1_weight, %p_features_11_conv_0_1_bias, %b_features_11_conv_0_1_running_mean, %b_features_11_conv_0_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_30 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_30, 0), kwargs = {}) + %aten_hardtanh_default_20 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_30, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_31 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_20, %p_features_11_conv_1_0_weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 384), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_31 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_31, %p_features_11_conv_1_1_weight, %p_features_11_conv_1_1_bias, %b_features_11_conv_1_1_running_mean, %b_features_11_conv_1_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_31 : [num_users=1] = 
call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_31, 0), kwargs = {}) + %aten_hardtanh_default_21 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_31, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_32 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_21, %p_features_11_conv_2_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_32 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_32, %p_features_11_conv_3_weight, %p_features_11_conv_3_bias, %b_features_11_conv_3_running_mean, %b_features_11_conv_3_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_32 : [num_users=2] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_32, 0), kwargs = {}) + %aten_convolution_default_33 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%getitem_32, %p_features_12_conv_0_0_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_33 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_33, %p_features_12_conv_0_1_weight, %p_features_12_conv_0_1_bias, %b_features_12_conv_0_1_running_mean, %b_features_12_conv_0_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_33 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_33, 0), kwargs = {}) + %aten_hardtanh_default_22 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_33, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_34 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_22, %p_features_12_conv_1_0_weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 576), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_34 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_34, %p_features_12_conv_1_1_weight, %p_features_12_conv_1_1_bias, %b_features_12_conv_1_1_running_mean, %b_features_12_conv_1_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_34 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_34, 0), kwargs = {}) + %aten_hardtanh_default_23 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_34, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_35 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_23, %p_features_12_conv_2_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_35 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_35, %p_features_12_conv_3_weight, %p_features_12_conv_3_bias, 
%b_features_12_conv_3_running_mean, %b_features_12_conv_3_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_35 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_35, 0), kwargs = {}) + %aten_add_tensor_6 : [num_users=2] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%getitem_32, %getitem_35), kwargs = {}) + %aten_convolution_default_36 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_add_tensor_6, %p_features_13_conv_0_0_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_36 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_36, %p_features_13_conv_0_1_weight, %p_features_13_conv_0_1_bias, %b_features_13_conv_0_1_running_mean, %b_features_13_conv_0_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_36 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_36, 0), kwargs = {}) + %aten_hardtanh_default_24 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_36, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_37 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_24, %p_features_13_conv_1_0_weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 576), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_37 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_37, %p_features_13_conv_1_1_weight, %p_features_13_conv_1_1_bias, %b_features_13_conv_1_1_running_mean, %b_features_13_conv_1_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_37 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_37, 0), kwargs = {}) + %aten_hardtanh_default_25 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_37, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_38 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_25, %p_features_13_conv_2_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_38 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_38, %p_features_13_conv_3_weight, %p_features_13_conv_3_bias, %b_features_13_conv_3_running_mean, %b_features_13_conv_3_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_38 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_38, 0), kwargs = {}) + %aten_add_tensor_7 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_add_tensor_6, %getitem_38), kwargs = {}) + %aten_convolution_default_39 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_add_tensor_7, %p_features_14_conv_0_0_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + 
%aten__native_batch_norm_legit_no_training_default_39 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_39, %p_features_14_conv_0_1_weight, %p_features_14_conv_0_1_bias, %b_features_14_conv_0_1_running_mean, %b_features_14_conv_0_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_39 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_39, 0), kwargs = {}) + %aten_hardtanh_default_26 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_39, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_40 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_26, %p_features_14_conv_1_0_weight, None, [2, 2], [1, 1], [1, 1], False, [0, 0], 576), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_40 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_40, %p_features_14_conv_1_1_weight, %p_features_14_conv_1_1_bias, %b_features_14_conv_1_1_running_mean, %b_features_14_conv_1_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_40 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_40, 0), kwargs = {}) + %aten_hardtanh_default_27 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_40, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_41 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_27, %p_features_14_conv_2_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_41 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_41, %p_features_14_conv_3_weight, %p_features_14_conv_3_bias, %b_features_14_conv_3_running_mean, %b_features_14_conv_3_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_41 : [num_users=2] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_41, 0), kwargs = {}) + %aten_convolution_default_42 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%getitem_41, %p_features_15_conv_0_0_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_42 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_42, %p_features_15_conv_0_1_weight, %p_features_15_conv_0_1_bias, %b_features_15_conv_0_1_running_mean, %b_features_15_conv_0_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_42 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_42, 0), kwargs = {}) + %aten_hardtanh_default_28 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_42, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_43 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = 
(%aten_hardtanh_default_28, %p_features_15_conv_1_0_weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 960), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_43 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_43, %p_features_15_conv_1_1_weight, %p_features_15_conv_1_1_bias, %b_features_15_conv_1_1_running_mean, %b_features_15_conv_1_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_43 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_43, 0), kwargs = {}) + %aten_hardtanh_default_29 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_43, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_44 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_29, %p_features_15_conv_2_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_44 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_44, %p_features_15_conv_3_weight, %p_features_15_conv_3_bias, %b_features_15_conv_3_running_mean, %b_features_15_conv_3_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_44 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_44, 0), kwargs = {}) + %aten_add_tensor_8 : [num_users=2] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%getitem_41, %getitem_44), kwargs = {}) + %aten_convolution_default_45 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_add_tensor_8, %p_features_16_conv_0_0_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_45 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_45, %p_features_16_conv_0_1_weight, %p_features_16_conv_0_1_bias, %b_features_16_conv_0_1_running_mean, %b_features_16_conv_0_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_45 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_45, 0), kwargs = {}) + %aten_hardtanh_default_30 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_45, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_46 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_30, %p_features_16_conv_1_0_weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 960), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_46 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_46, %p_features_16_conv_1_1_weight, %p_features_16_conv_1_1_bias, %b_features_16_conv_1_1_running_mean, %b_features_16_conv_1_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_46 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_46, 0), kwargs = {}) + %aten_hardtanh_default_31 : 
[num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_46, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_47 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_31, %p_features_16_conv_2_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_47 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_47, %p_features_16_conv_3_weight, %p_features_16_conv_3_bias, %b_features_16_conv_3_running_mean, %b_features_16_conv_3_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_47 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_47, 0), kwargs = {}) + %aten_add_tensor_9 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_add_tensor_8, %getitem_47), kwargs = {}) + %aten_convolution_default_48 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_add_tensor_9, %p_features_17_conv_0_0_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_48 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_48, %p_features_17_conv_0_1_weight, %p_features_17_conv_0_1_bias, %b_features_17_conv_0_1_running_mean, %b_features_17_conv_0_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_48 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_48, 0), kwargs = {}) + %aten_hardtanh_default_32 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_48, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_49 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_32, %p_features_17_conv_1_0_weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 960), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_49 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_49, %p_features_17_conv_1_1_weight, %p_features_17_conv_1_1_bias, %b_features_17_conv_1_1_running_mean, %b_features_17_conv_1_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_49 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_49, 0), kwargs = {}) + %aten_hardtanh_default_33 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_49, 0.0, 6.0), kwargs = {}) + %aten_convolution_default_50 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_hardtanh_default_33, %p_features_17_conv_2_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_50 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_50, %p_features_17_conv_3_weight, %p_features_17_conv_3_bias, 
%b_features_17_conv_3_running_mean, %b_features_17_conv_3_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_50 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_50, 0), kwargs = {}) + %aten_convolution_default_51 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%getitem_50, %p_features_18_0_weight, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + %aten__native_batch_norm_legit_no_training_default_51 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten_convolution_default_51, %p_features_18_1_weight, %p_features_18_1_bias, %b_features_18_1_running_mean, %b_features_18_1_running_var, 0.1, 1e-05), kwargs = {}) + %getitem_51 : [num_users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default_51, 0), kwargs = {}) + %aten_hardtanh_default_34 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem_51, 0.0, 6.0), kwargs = {}) + %aten_mean_dim : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mean.dim](args = (%aten_hardtanh_default_34, [-1, -2], True), kwargs = {}) + return (aten_mean_dim,) + %executorch_call_delegate : [num_users=1] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_0, %x), kwargs = {}) + %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 0), kwargs = {}) + %aten_view_copy_default : [num_users=1] = call_function[target=executorch.exir.memory.view](args = (%getitem, [1, 1280]), kwargs = {}) + %alloc : [num_users=1] = call_function[target=executorch.exir.memory.alloc](args = (((1, 1280), torch.float32),), kwargs = {}) + %dim_order_ops__clone_dim_order_default : [num_users=1] = call_function[target=torch.ops.dim_order_ops._clone_dim_order.out](args = (%aten_view_copy_default,), kwargs = {dim_order: [0, 1], out: %alloc}) + %lowered_module_1 : [num_users=1] = get_attr[target=lowered_module_1] + backend_id: XnnpackBackend + lowered graph(): + %p_classifier_1_weight : [num_users=1] = placeholder[target=p_classifier_1_weight] + %p_classifier_1_bias : [num_users=1] = placeholder[target=p_classifier_1_bias] + %dim_order_ops__clone_dim_order_default : [num_users=1] = placeholder[target=dim_order_ops__clone_dim_order_default] + %aten_linear_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.linear.default](args = (%dim_order_ops__clone_dim_order_default, %p_classifier_1_weight, %p_classifier_1_bias), kwargs = {}) + return (aten_linear_default,) + %executorch_call_delegate_1 : [num_users=1] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_1, %dim_order_ops__clone_dim_order_default), kwargs = {}) + %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate_1, 0), kwargs = {}) + return (getitem_1,) ``` +
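For readers following the delegated graph dump above: once a model has been lowered, a similar dump and a delegation summary can be printed from Python. The sketch below is illustrative only — it uses a tiny hypothetical module rather than the model in the dump, and assumes the XNNPACK partitioner plus the `get_delegation_info` helper from `executorch.devtools.backend_debug` available in recent ExecuTorch releases.

```python
import torch

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.devtools.backend_debug import get_delegation_info
from executorch.exir import to_edge_transform_and_lower


class TinyModel(torch.nn.Module):
    """Small stand-in model; any export-compliant nn.Module works here."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(self.linear(x))


# Capture the graph and lower supported subgraphs to the XNNPACK backend.
exported = torch.export.export(TinyModel().eval(), (torch.randn(1, 16),))
edge = to_edge_transform_and_lower(exported, partitioner=[XnnpackPartitioner()])

graph_module = edge.exported_program().graph_module
print(graph_module.graph)  # call_delegate / getitem nodes, analogous to the dump above
print(get_delegation_info(graph_module).get_summary())  # counts of delegated vs. non-delegated ops
```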
diff --git a/docs/source/desktop-backends.md b/docs/source/desktop-backends.md new file mode 100644 index 00000000000..e4220edb47f --- /dev/null +++ b/docs/source/desktop-backends.md @@ -0,0 +1,27 @@ +(desktop-backends)= +# Backends + +Available hardware acceleration backends for desktop platforms. + +## Linux Backends + +- {doc}`desktop-xnnpack` — XNNPACK (CPU acceleration) +- {doc}`desktop-openvino` — OpenVINO (Intel hardware optimization) + +## macOS Backends + +- {doc}`desktop-coreml` — CoreML (recommended for Apple Silicon) +- {doc}`desktop-mps` — Metal Performance Shaders (Apple Silicon GPU) +- {doc}`desktop-xnnpack` — XNNPACK (CPU acceleration) + +## Windows Backends + +- {doc}`desktop-xnnpack` — XNNPACK (CPU acceleration) +- {doc}`desktop-openvino` — OpenVINO (Intel hardware optimization) + +```{toctree} +:hidden: +desktop-xnnpack +desktop-openvino +desktop-coreml +desktop-mps diff --git a/docs/source/desktop-coreml.md b/docs/source/desktop-coreml.md new file mode 100644 index 00000000000..48271326d87 --- /dev/null +++ b/docs/source/desktop-coreml.md @@ -0,0 +1 @@ +```{include} backends-coreml.md diff --git a/docs/source/desktop-mps.md b/docs/source/desktop-mps.md new file mode 100644 index 00000000000..d6f305d33aa --- /dev/null +++ b/docs/source/desktop-mps.md @@ -0,0 +1 @@ +```{include} backends-mps.md diff --git a/docs/source/desktop-openvino.md b/docs/source/desktop-openvino.md new file mode 100644 index 00000000000..a0fd5774c73 --- /dev/null +++ b/docs/source/desktop-openvino.md @@ -0,0 +1 @@ +```{include} build-run-openvino.md diff --git a/docs/source/desktop-section.md b/docs/source/desktop-section.md new file mode 100644 index 00000000000..bf306e7c43b --- /dev/null +++ b/docs/source/desktop-section.md @@ -0,0 +1,24 @@ +(desktop-section)= +# Desktop & Laptop Platforms + +Deploy ExecuTorch on Linux, macOS, and Windows with optimized backends for each platform. 
+ +## Platform Overview & Runtime + +- {doc}`using-executorch-cpp` — C++ runtime integration guide +- {doc}`using-executorch-building-from-source` — Building ExecuTorch from source + +## Backends + +- {doc}`desktop-backends` — Available desktop backends and platform-specific optimization + +## Tutorials + +- {doc}`raspberry_pi_llama_tutorial` — Cross compiling ExecuTorch for the Raspberry Pi on Linux Host + +```{toctree} +:hidden: +using-executorch-cpp +using-executorch-building-from-source +desktop-backends +raspberry_pi_llama_tutorial diff --git a/docs/source/desktop-xnnpack.md b/docs/source/desktop-xnnpack.md new file mode 100644 index 00000000000..4a85dec946b --- /dev/null +++ b/docs/source/desktop-xnnpack.md @@ -0,0 +1 @@ +```{include} backends/xnnpack/xnnpack-overview.md diff --git a/docs/source/developer-tools.md b/docs/source/developer-tools.md new file mode 100644 index 00000000000..d3b90b7adc8 --- /dev/null +++ b/docs/source/developer-tools.md @@ -0,0 +1,16 @@ +# Tools + +```{toctree} +:maxdepth: 1 + +devtools-overview +bundled-io +etrecord +etdump +runtime-profiling +model-debugging +model-inspector +memory-planning-inspection +delegate-debugging +devtools-tutorial +``` diff --git a/docs/source/devtools-overview.md b/docs/source/devtools-overview.md index 449dd1485dc..ac797252daf 100644 --- a/docs/source/devtools-overview.md +++ b/docs/source/devtools-overview.md @@ -17,7 +17,7 @@ The ExecuTorch Developer Tools support the following features: - **Debugging** - Intermediate outputs and output quality analysis - **Numerical Discrepancy Detection** - Operator-level numerical discrepancy detection between AOT and runtime intermediate outputs to streamline numerical debugging and validation. - **Memory Allocation Insights** - Visualize how memory is planned, where all the live tensors are at any point in time -- **Visualization** - Coming soon +- **Visualization** - Visualize the model as a computational graph (see more [here](visualize.md)) ## Fundamental components of the Developer Tools @@ -41,6 +41,6 @@ More details are available in the [ETDump documentation](etdump.md) on how to ge ### Inspector APIs -The Inspector Python APIs are the main user enrty point into the Developer Tools. They join the data sourced from ETDump and ETRecord to give users access to all the performance and debug data sourced from the runtime along with linkage back to eager model source code and module hierarchy in an easy to use API. +The Inspector Python APIs are the main user entry point into the Developer Tools. They join the data sourced from ETDump and ETRecord to give users access to all the performance and debug data sourced from the runtime along with linkage back to eager model source code and module hierarchy in an easy to use API. More details are available in the [Inspector API documentation](model-inspector.rst) on how to use the Inspector APIs. diff --git a/docs/source/devtools-tutorial.md b/docs/source/devtools-tutorial.md index 7c6cedc311b..6d540dc7f35 100644 --- a/docs/source/devtools-tutorial.md +++ b/docs/source/devtools-tutorial.md @@ -1,3 +1,3 @@ ## Developer Tools Usage Tutorial -Please refer to the [Developer Tools tutorial](https://pytorch.org/executorch/main/tutorials/devtools-integration-tutorial) for a walkthrough on how to profile a model in ExecuTorch using the Developer Tools. +Please refer to the [Developer Tools tutorial](tutorials/devtools-integration-tutorial) for a walkthrough on how to profile a model in ExecuTorch using the Developer Tools. 
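As a companion to the Inspector APIs described in the devtools overview above, here is a minimal usage sketch. The ETDump and ETRecord file names are placeholders; both artifacts must be generated beforehand as described in the ETDump and ETRecord docs.

```python
# Join runtime profiling events (ETDump) with ahead-of-time debug info
# (ETRecord) and print a per-operator table. File paths are placeholders.
from executorch.devtools import Inspector

inspector = Inspector(etdump_path="etdump.etdp", etrecord="etrecord.bin")
inspector.print_data_tabular()
```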
diff --git a/docs/source/edge-platforms-section.md b/docs/source/edge-platforms-section.md new file mode 100644 index 00000000000..209986507fa --- /dev/null +++ b/docs/source/edge-platforms-section.md @@ -0,0 +1,75 @@ +(edge-platforms-section)= +# Edge + +Deploy ExecuTorch on mobile, desktop, and embedded platforms with optimized backends for each. + +ExecuTorch supports deployment across a wide variety of edge computing platforms, from high-end mobile devices to constrained embedded systems and microcontrollers. + +## Android + +Deploy ExecuTorch on Android devices with hardware acceleration support. + +**→ {doc}`android-section` — Complete Android deployment guide** + +Key features: + +- Hardware acceleration support (CPU, GPU, NPU) +- Multiple backend options (XNNPACK, Vulkan, Qualcomm, MediaTek, ARM, Samsung) +- Comprehensive examples and demos + +## iOS + +Deploy ExecuTorch on iOS devices with Apple hardware acceleration. + +**→ {doc}`ios-section` — Complete iOS deployment guide** + +Key features: +- Apple hardware optimization (CoreML, MPS, XNNPACK) +- Swift and Objective-C integration +- LLM and computer vision examples + +## Desktop & Laptop Platforms + +Deploy ExecuTorch on Linux, macOS, and Windows with optimized backends. + +**→ {doc}`desktop-section` — Complete desktop deployment guide** + +Key features: +- Cross-platform C++ runtime +- Platform-specific optimization (OpenVINO, CoreML, MPS) +- CPU and GPU acceleration options + +## Embedded Systems + +Deploy ExecuTorch on constrained embedded systems and microcontrollers. + +**→ {doc}`embedded-section` — Complete embedded deployment guide** + +Key features: + +- Resource-constrained deployment +- DSP and NPU acceleration (Cadence, ARM Ethos-U, NXP) +- Custom backend development support +- LLM and computer vision examples + +## Troubleshooting & Support + +- **{doc}`using-executorch-troubleshooting`** - Common issues and solutions across all platforms + +## Next Steps + +After choosing your platform: + +- **{doc}`backends-section`** - Deep dive into backend selection and optimization +- **{doc}`llm/working-with-llms`** - Working with Large Language Models on edge devices + +```{toctree} +:hidden: +:maxdepth: 3 +:caption: Edge Platforms + +android-section +ios-section +desktop-section +embedded-section +using-executorch-troubleshooting diff --git a/docs/source/embedded-arm-ethos-u.md b/docs/source/embedded-arm-ethos-u.md new file mode 100644 index 00000000000..cdc544a6553 --- /dev/null +++ b/docs/source/embedded-arm-ethos-u.md @@ -0,0 +1 @@ +```{include} backends-arm-ethos-u.md diff --git a/docs/source/embedded-backends.md b/docs/source/embedded-backends.md new file mode 100644 index 00000000000..4ed7962ef42 --- /dev/null +++ b/docs/source/embedded-backends.md @@ -0,0 +1,20 @@ +(embedded-backends)= +# Backends + +Available hardware acceleration backends for embedded systems. 
+ +## DSP Acceleration + +- {doc}`embedded-cadence` — Cadence Xtensa DSP processors + +## NPU Acceleration + +- {doc}`embedded-arm-ethos-u` — ARM Ethos-U NPU acceleration +- {doc}`embedded-nxp` — NXP eIQ Neutron Backend + + +```{toctree} +:hidden: +embedded-cadence +embedded-arm-ethos-u +embedded-nxp diff --git a/docs/source/embedded-cadence.md b/docs/source/embedded-cadence.md new file mode 100644 index 00000000000..d2f7ea78259 --- /dev/null +++ b/docs/source/embedded-cadence.md @@ -0,0 +1 @@ +```{include} backends-cadence.md diff --git a/docs/source/embedded-nxp.md b/docs/source/embedded-nxp.md new file mode 100644 index 00000000000..65ae8daff43 --- /dev/null +++ b/docs/source/embedded-nxp.md @@ -0,0 +1 @@ +```{include} backends/nxp/nxp-overview.md diff --git a/docs/source/embedded-section.md b/docs/source/embedded-section.md new file mode 100644 index 00000000000..aac64190030 --- /dev/null +++ b/docs/source/embedded-section.md @@ -0,0 +1,43 @@ +(embedded-section)= + +# Embedded Systems + +Deploy ExecuTorch on constrained embedded systems and microcontrollers. + +## API Reference & Development + +Start here for C++ development with ExecuTorch runtime APIs and essential tutorials. + +- {doc}`executorch-runtime-api-reference` — **Start here**: Complete runtime API reference for embedded development +- {doc}`running-a-model-cpp-tutorial` — Step-by-step C++ API tutorial with practical examples +- {doc}`extension-module` — Custom module extensions for specialized functionality +- {doc}`extension-tensor` — Tensor operations and memory management extensions + +## Build & Integration Guide + +- {doc}`using-executorch-cpp` — Complete setup guide for C++ runtime integration +- {doc}`using-executorch-building-from-source` — Building from Source + +## Choose Backend for acceleration + +- {doc}`embedded-backends` — Available embedded backends and acceleration options + +## Tutorials + +- {doc}`tutorial-arm-ethos-u` — Export a simple PyTorch model for the ExecuTorch Ethos-U backend +- {doc}`raspberry_pi_llama_tutorial` — Deploy a LLaMA model on a Raspberry Pi +- {doc}`pico2_tutorial` — Deploy a demo MNIST model on the Raspberry Pi Pico 2 + + +```{toctree} +:hidden: +executorch-runtime-api-reference +running-a-model-cpp-tutorial +extension-module +extension-tensor +using-executorch-cpp +using-executorch-building-from-source +embedded-backends +tutorial-arm-ethos-u +raspberry_pi_llama_tutorial +pico2_tutorial diff --git a/docs/source/etrecord.rst b/docs/source/etrecord.rst index 1ab84a6ee10..39bc45cab5a 100644 --- a/docs/source/etrecord.rst +++ b/docs/source/etrecord.rst @@ -23,13 +23,120 @@ It should be provided to the `Inspector API `__ to link ba Generating an ``ETRecord`` -------------------------- -The user should use the following API to generate an ``ETRecord`` file. They -will be expected to provide the Edge Dialect program (returned by the call to ``to_edge()``), -the ExecuTorch program (returned by the call to ``to_executorch()``), and optional models that -they are interested in working with via our tooling. +There are multiple ways to generate an ``ETRecord`` for debugging purposes: + +Method 1: Using the ``generate_etrecord`` Parameter (Recommended) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The recommended approach is to enable ``ETRecord`` generation by passing ``generate_etrecord=True`` +to your export API calls. 
This can be used with: + +* ``executorch.export()`` - High-level export API +* ``to_edge()`` - Edge dialect conversion +* ``to_edge_transform_and_lower()`` - Edge conversion with transformations and lowering + +After export completes, retrieve the ``ETRecord`` using the ``get_etrecord()`` method, and save it using the ``save()`` method: + +**Example with** ``executorch.export()``: + +.. code-block:: python + + import executorch + from executorch.export import ExportRecipe + + # Export with ETRecord generation enabled + session = executorch.export( + model=model, + example_inputs=[example_inputs], + export_recipe=recipe, + generate_etrecord=True # Enable ETRecord generation + ) + + # Get and save the ETRecord + etrecord = session.get_etrecord() + etrecord.save("model_debug.etrecord") + +**Example with** ``to_edge()``: + +.. code-block:: python + + from executorch.exir.program import to_edge + from torch.export import export + + # Export model first + exported_program = export(model, example_inputs) + + # Convert to edge with ETRecord generation + edge_manager = to_edge( + exported_program, + generate_etrecord=True # Enable ETRecord generation + ) + + # Apply transformations + edge_manager = edge_manager.to_backend() + et_manager = edge_manager.to_executorch() + + # Get and save ETRecord + etrecord = et_manager.get_etrecord() + etrecord.save("edge_debug.etrecord") + +**Example with** ``to_edge_transform_and_lower()``: + +.. code-block:: python + + from executorch.exir.program import to_edge_transform_and_lower + from torch.export import export + + # Export model first + exported_program = export(model, example_inputs) + + # Transform and lower with ETRecord generation + edge_manager = to_edge_transform_and_lower( + exported_program, + partitioner=[MyPartitioner()], + generate_etrecord=True # Enable ETRecord generation + ) + + et_manager = edge_manager.to_executorch() + + # Get and save ETRecord + etrecord = et_manager.get_etrecord() + etrecord.save("debug.etrecord") + +Method 2: Using the ``generate_etrecord()`` Function +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +You can also use the standalone ``generate_etrecord()`` function to generate an ``ETRecord``. +This method requires you to provide the Edge Dialect program (returned by ``to_edge()``), +the ExecuTorch program (returned by ``to_executorch()``), and optional models. .. warning:: - Users should do a deepcopy of the output of ``to_edge()`` and pass in the deepcopy to the ``generate_etrecord`` API. This is needed because the subsequent call, ``to_executorch()``, does an in-place mutation and will lose debug data in the process. + When using the standalone function, users should do a deepcopy of the output of ``to_edge()`` and pass in the deepcopy to the ``generate_etrecord`` API. This is needed because the subsequent call, ``to_executorch()``, does an in-place mutation and will lose debug data in the process. + +**Example:** + +.. code-block:: python + + import copy + from executorch.devtools import generate_etrecord + from torch.export import export + + # Export and convert to edge + aten_dialect = export(model, example_inputs, strict=True) + edge_program = to_edge(aten_dialect) + + # Create copy for ETRecord (needed because to_executorch modifies in-place) + edge_program_copy = copy.deepcopy(edge_program) + + # Convert to ExecutorchProgramManager + executorch_program = edge_program_copy.to_executorch() + + # Generate ETRecord separately + generate_etrecord( + "debug.etrecord", + edge_program, + executorch_program, + ) .. 
currentmodule:: executorch.devtools.etrecord._etrecord .. autofunction:: generate_etrecord diff --git a/docs/source/examples-end-to-end-to-lower-model-to-delegate.md b/docs/source/examples-end-to-end-to-lower-model-to-delegate.md index 4ef6bcd0d6e..fd14d718531 100644 --- a/docs/source/examples-end-to-end-to-lower-model-to-delegate.md +++ b/docs/source/examples-end-to-end-to-lower-model-to-delegate.md @@ -19,7 +19,7 @@ There are three flows for delegating a program to a backend: is good for reusing lowered modules exported from other flows. 1. Lower parts of a module according to a partitioner. This is good for lowering models that include both lowerable and non-lowerable nodes, and is - the most streamlined procecss. + the most streamlined process. ### Flow 1: Lowering the whole module diff --git a/docs/source/examples.md b/docs/source/examples.md new file mode 100644 index 00000000000..6a3a8ac29c9 --- /dev/null +++ b/docs/source/examples.md @@ -0,0 +1,9 @@ +# Examples + +```{toctree} +:maxdepth: 1 + +Building an ExecuTorch Android Demo App +Building an ExecuTorch iOS Demo App +tutorial-arm +``` diff --git a/docs/source/executorch_custom_versions.py b/docs/source/executorch_custom_versions.py deleted file mode 100644 index 590f21b10ec..00000000000 --- a/docs/source/executorch_custom_versions.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Sphinx extension to replace ${executorch_version:TAG} with version numbers. - -It also defines a special variable ${executorch_version} that is set to the value -of `EXECUTORCH_VERSION` defined in this file. - -This custom extension pulls third-party version strings from files in the -.ci/docker/ci_commit_pins directory, and uses them to expand specific strings in -markdown files. - -For example, `${executorch_version:pytorch}` will be replaced with the -appropriate pytorch version string used by CI. -""" - -import os - -from docutils import nodes - -version_file_names = [ - "buck2.txt", - "pytorch.txt", -] - -EXECUTORCH_VERSION = "0.7.0" - -variables: dict[str, str] = {} - - -def populate_version_variable(): - variables["${executorch_version}"] = EXECUTORCH_VERSION - cwd = os.getcwd() - version_file_path = os.path.join(cwd, "..", ".ci", "docker", "ci_commit_pins") - - for file_name in version_file_names: - file_path = os.path.join(version_file_path, file_name) - with open(file_path, "r") as f: - var_name = "${executorch_version:" + file_name.split(".")[0] + "}" - variables[var_name] = f.read().strip() - - -populate_version_variable() - - -def replace_variables(app, doctree, docname): - # Replace in regular text: - for node in doctree.traverse(nodes.Text): - new_text = node.astext() - for var, value in variables.items(): - new_text = new_text.replace(var, value) - node.parent.replace(node, nodes.Text(new_text)) - # Replace in code blocks: - for node in doctree.traverse(nodes.literal_block): - new_text = node.astext() - for var, value in variables.items(): - new_text = new_text.replace(var, value) - - classes = node.get("classes", []) - # check if the output is generated by sphinx-gallery and if yes, keep the original - # CSS classes. Otherwise, the sphinx-gallery generated outputs are - # formatted as regular code blocks with gray background instead of pink. 
- is_sphinx_gallery = any("sphx-glr" in class_ for class_ in classes) - - language = node.get("language") - - if is_sphinx_gallery: - new_literal_block = nodes.literal_block(new_text, new_text, classes=classes) - else: - new_literal_block = nodes.literal_block( - new_text, - new_text, - classes=["highlight-none", "notranslate"], - language=language, - ) - - node.parent.replace(node, new_literal_block) - - -def setup(app): - app.connect("doctree-resolved", replace_variables) diff --git a/docs/source/export-overview.md b/docs/source/export-overview.md index d07701d06cd..c96716a0949 100644 --- a/docs/source/export-overview.md +++ b/docs/source/export-overview.md @@ -11,5 +11,5 @@ program, making it easier for you to understand and implement the process. To learn more about exporting your model: -* Complete the [Exporting to ExecuTorch tutorial](https://pytorch.org/executorch/main/tutorials/export-to-executorch-tutorial). +* Complete the [Exporting to ExecuTorch tutorial](tutorials/export-to-executorch-tutorial) . * Read the [torch.export documentation](https://pytorch.org/docs/2.1/export.html). diff --git a/docs/source/extension-module.md b/docs/source/extension-module.md index 29aa6712d37..690256fecbb 100644 --- a/docs/source/extension-module.md +++ b/docs/source/extension-module.md @@ -6,7 +6,7 @@ In the [Detailed C++ Runtime APIs Tutorial](running-a-model-cpp-tutorial.md), we ## Example -Let's see how we can run the `SimpleConv` model generated from the [Exporting to ExecuTorch tutorial](https://pytorch.org/executorch/main/tutorials/export-to-executorch-tutorial) using the `Module` and [`TensorPtr`](extension-tensor.md) APIs: +Let's see how we can run the `SimpleConv` model generated from the [Exporting to ExecuTorch tutorial](tutorials/export-to-executorch-tutorial) using the `Module` and [`TensorPtr`](extension-tensor.md) APIs: ```cpp #include diff --git a/docs/source/file-formats-advanced.md b/docs/source/file-formats-advanced.md new file mode 100644 index 00000000000..c16ebccfd65 --- /dev/null +++ b/docs/source/file-formats-advanced.md @@ -0,0 +1,17 @@ +(file-formats-advanced)= + +# File Formats + +ExecuTorch file format specifications and internal structure. + +## Program File Formats + +- {doc}`pte-file-format` — PTE (PyTorch ExecuTorch) file format specification +- {doc}`ptd-file-format` — PTD file format specification + +```{toctree} +:hidden: +:maxdepth: 1 + +pte-file-format +ptd-file-format diff --git a/docs/source/getting-started-architecture.md b/docs/source/getting-started-architecture.md index ef4a12d1a7f..617d521b802 100644 --- a/docs/source/getting-started-architecture.md +++ b/docs/source/getting-started-architecture.md @@ -4,7 +4,7 @@ This page describes the technical architecture of ExecuTorch and its individual **Context** -In order to target on-device AI with diverse hardware, critical power requirements, and realtime processing needs, a single monolithic solution is not practical. Instead, a modular, layered, and extendable architecture is desired. ExecuTorch defines a streamlined workflow to prepare (export, transformation, and compilation) and execute a PyTorch program, with opinionated out-of-the-box default components and well-defined entry points for customizations. This architecture greatly improves portability, allowing engineers to use a performant lightweight, cross-platform runtime that easily integrates into different devices and platforms. 
+In order to target on-device AI with diverse hardware, critical power requirements, and real-time processing needs, a single monolithic solution is not practical. Instead, a modular, layered, and extensible architecture is desired. ExecuTorch defines a streamlined workflow to prepare (export, transformation, and compilation) and execute a PyTorch program, with opinionated out-of-the-box default components and well-defined entry points for customizations. This architecture greatly improves portability, allowing engineers to use a performant lightweight, cross-platform runtime that easily integrates into different devices and platforms. ## Overview @@ -89,6 +89,6 @@ _Executor_ is the entry point to load the program and execute it. The execution ## Developer Tools -It should be efficient for users to go from research to production using the flow above. Productivity is essentially important, for users to author, optimize and deploy their models. We provide [ExecuTorch Developer Tools](devtools-overview.md) to improve productivity. The Developer Tools are not in the diagram. Instead it's a tool set that covers the developer workflow in all three phases. +It should be efficient for users to go from research to production using the flow above. Productivity is especially important, for users to author, optimize and deploy their models. We provide [ExecuTorch Developer Tools](devtools-overview.md) to improve productivity. The Developer Tools are not in the diagram. Instead it's a tool set that covers the developer workflow in all three phases. During the program preparation and execution, users can use the ExecuTorch Developer Tools to profile, debug, or visualize the program. Since the end-to-end flow is within the PyTorch ecosystem, users can correlate and display performance data along with graph visualization as well as direct references to the program source code and model hierarchy. We consider this to be a critical component for quickly iterating and lowering PyTorch programs to edge devices and environments. diff --git a/docs/source/getting-started.md b/docs/source/getting-started.md index d3d9662f5c3..7e5d658a559 100644 --- a/docs/source/getting-started.md +++ b/docs/source/getting-started.md @@ -1,5 +1,5 @@ # Getting Started with ExecuTorch -This section is intended to describe the necessary steps to take PyTorch model and run it using ExecuTorch. To use the framework, you will typically need to take the following steps: +This section is intended to describe the necessary steps to take a PyTorch model and run it using ExecuTorch. To use the framework, you will typically need to take the following steps: - Install the ExecuTorch python package and runtime libraries. - Export the PyTorch model for the target hardware configuration. - Run the model using the ExecuTorch runtime APIs on your development platform. @@ -8,11 +8,11 @@ This section is intended to describe the necessary steps to take PyTorch model a ## System Requirements The following are required to install the ExecuTorch host libraries, needed to export models and run from Python. Requirements for target end-user devices are backend dependent. See the appropriate backend documentation for more information. -- Python 3.10 - 3.12 +- Python 3.10 - 3.13 - g++ version 7 or higher, clang++ version 5 or higher, or another C++17-compatible toolchain. -- Linux (x86_64 or ARM64) or macOS (ARM64). +- Linux (x86_64 or ARM64), macOS (ARM64), or Windows (x86_64). 
- Intel-based macOS systems require building PyTorch from source (see [Building From Source](using-executorch-building-from-source.md) for instructions). - - Windows is supported via WSL. +- On Windows, Visual Studio 2022 or later. ## Installation To use ExecuTorch, you will need to install both the Python package and the appropriate platform-specific runtime libraries. Pip is the recommended way to install the ExecuTorch python package. @@ -25,6 +25,7 @@ pip install executorch To build the framework from source, see [Building From Source](using-executorch-building-from-source.md). Backend delegates may require additional dependencies. See the appropriate backend documentation for more information. +> **_NOTE:_** On Windows, ExecuTorch requires a [Visual Studio Developer PowerShell](https://learn.microsoft.com/en-us/visualstudio/ide/reference/command-prompt-powershell?view=vs-2022). Running outside of a developer prompt will result in errors related to CL.exe.
@@ -44,7 +45,7 @@ ExecuTorch provides hardware acceleration for a wide variety of hardware. The mo For mobile use cases, consider using XNNPACK for Android and Core ML or XNNPACK for iOS as a first step. See [Hardware Backends](backends-overview.md) for more information. ### Exporting -Exporting is done using Python APIs. ExecuTorch provides a high degree of customization during the export process, but the typical flow is as follows. This example uses the MobileNet V2 image classification model implementation in torchvision, but the process supports any [export-compliant](https://pytorch.org/docs/stable/export.html) PyTorch model. For users working with Hugging Face models, +Exporting is done using Python APIs. ExecuTorch provides a high degree of customization during the export process, but the typical flow is as follows. This example uses the MobileNet V2 image classification model implementation in torchvision, but the process supports any [export-compliant](https://pytorch.org/docs/stable/export.html) PyTorch model. For Hugging Face models, you can find a list of supported models in the [*huggingface/optimum-executorch*](https://github.com/huggingface/optimum-executorch) repo. ```python @@ -68,7 +69,7 @@ with open("model.pte", "wb") as f: If the model requires varying input sizes, you will need to specify the varying dimensions and bounds as part of the `export` call. See [Model Export and Lowering](using-executorch-export.md) for more information. -The hardware backend to target is controlled by the partitioner parameter to to\_edge\_transform\_and\_lower. In this example, the XnnpackPartitioner is used to target mobile CPUs. See the [backend-specific documentation](backends-overview.md) for information on how to use each backend. +The hardware backend to target is controlled by the partitioner parameter to `to_edge_transform_and_lower`. In this example, the XnnpackPartitioner is used to target mobile CPUs. See the [backend-specific documentation](backends-overview.md) for information on how to use each backend. Quantization can also be done at this stage to reduce model size and runtime. Quantization is backend-specific. See the documentation for the target backend for a full description of supported quantization schemes. @@ -76,7 +77,7 @@ Quantization can also be done at this stage to reduce model size and runtime. Qu After successfully generating a .pte file, it is common to use the Python runtime APIs to validate the model on the development platform. This can be used to evaluate model accuracy before running on-device. -For the MobileNet V2 model from torchvision used in this example, image inputs are expected as a normalized, float32 tensor with a dimensions of (batch, channels, height, width). The output See [torchvision.models.mobilenet_v2](https://pytorch.org/vision/main/models/generated/torchvision.models.mobilenet_v2.html) for more information on the input and output tensor format for this model. +For the MobileNet V2 model from torchvision used in this example, image inputs are expected as a normalized, float32 tensor with dimensions of (batch, channels, height, width). The output is a tensor containing class logits. See [torchvision.models.mobilenet_v2](https://pytorch.org/vision/main/models/generated/torchvision.models.mobilenet_v2.html) for more information on the input and output tensor format for this model.
```python import torch @@ -89,7 +90,7 @@ input_tensor: torch.Tensor = torch.randn(1, 3, 224, 224) program = runtime.load_program("model.pte") method = program.load_method("forward") output: List[torch.Tensor] = method.execute([input_tensor]) -print("Run succesfully via executorch") +print("Run successfully via executorch") from torchvision.models.mobilenetv2 import MobileNet_V2_Weights import torchvision.models as models @@ -103,7 +104,7 @@ print(torch.allclose(output[0], eager_reference_output, rtol=1e-3, atol=1e-5)) For complete examples of exporting and running the model, please refer to our [examples GitHub repository](https://github.com/meta-pytorch/executorch-examples/tree/main/mv2/python). -Additionally, if you work with Hugging Face models, the [*huggingface/optimum-executorch*](https://github.com/huggingface/optimum-executorch) library simplifies running these models end-to-end with ExecuTorch, using familiar Hugging Face APIs. Visit the repository for specific examples and supported models. +Additionally, for Hugging Face models, the [*huggingface/optimum-executorch*](https://github.com/huggingface/optimum-executorch) library simplifies running these models end-to-end with ExecuTorch using familiar Hugging Face APIs. Visit the repository for specific examples and supported models.
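For reference, here is a condensed sketch of the export-and-lower flow described in this section, using `to_edge_transform_and_lower` with the `XnnpackPartitioner`. This is a sketch only; the import paths follow current ExecuTorch releases and may change between versions.

```python
import torch
import torchvision.models as models
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

# Eager model and a sample input with dimensions (batch, channels, height, width).
model = models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224),)

# Capture the graph, then lower supported subgraphs to the XNNPACK backend.
exported_program = torch.export.export(model, sample_inputs)
executorch_program = to_edge_transform_and_lower(
    exported_program,
    partitioner=[XnnpackPartitioner()],
).to_executorch()

# Serialize to a .pte file for the runtime.
with open("model.pte", "wb") as f:
    f.write(executorch_program.buffer)
```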
@@ -131,7 +132,7 @@ dependencies { ``` #### Runtime APIs -Models can be loaded and run using the `Module` class: +Models can be loaded and run from Java or Kotlin using the `Module` class. ```java import org.pytorch.executorch.EValue; import org.pytorch.executorch.Module; @@ -147,8 +148,11 @@ EValue[] output = model.forward(input_evalue); float[] scores = output[0].toTensor().getDataAsFloatArray(); ``` +Note that the [C++](#c) APIs can be used when targeting Android native. + For a full example of running a model on Android, see the [DeepLabV3AndroidDemo](https://github.com/meta-pytorch/executorch-examples/tree/main/dl3/android/DeepLabV3Demo). For more information on Android development, including building from source, a full description of the Java APIs, and information on using ExecuTorch from Android native code, see [Using ExecuTorch on Android](using-executorch-android.md). + ### iOS #### Installation @@ -165,22 +169,27 @@ For more information on iOS integration, including an API reference, logging set ExecuTorch provides C++ APIs, which can be used to target embedded or mobile devices. The C++ APIs provide a greater level of control compared to other language bindings, allowing for advanced memory management, data loading, and platform integration. #### Installation -CMake is the preferred build system for the ExecuTorch C++ runtime. To use with CMake, clone the ExecuTorch repository as a subdirectory of your project, and use CMake's `add_subdirectory("executorch")` to include the dependency. The `executorch` target, as well as kernel and backend targets will be made available to link against. The runtime can also be built standalone to support diverse toolchains. See [Using ExecuTorch with C++](using-executorch-cpp.md) for a detailed description of build integration, targets, and cross compilation. +CMake is the preferred build system for the ExecuTorch C++ runtime. To use with CMake, clone the ExecuTorch repository as a subdirectory of your project, and use CMake's `add_subdirectory("executorch")` to include the dependency. The `executorch` target, as well as kernel and backend targets will be made available to link against. The runtime can also be built standalone to support diverse toolchains. See [Using ExecuTorch with C++](using-executorch-cpp.md) and [Building from Source](using-executorch-building-from-source.md) for a detailed description of build integration, targets, and cross compilation. ``` git clone -b viable/strict https://github.com/pytorch/executorch.git ``` -```python +```cmake +# Set CMAKE_CXX_STANDARD to 17 or above. +set(CMAKE_CXX_STANDARD 17) + # CMakeLists.txt +set(EXECUTORCH_BUILD_PRESET_FILE ${CMAKE_SOURCE_DIR}/executorch/tools/cmake/preset/llm.cmake) +# Set other ExecuTorch options here. + add_subdirectory("executorch") ... target_link_libraries( my_target PRIVATE executorch - extension_module_static - extension_tensor - optimized_native_cpu_ops_lib - xnnpack_backend) + executorch::backends + executorch::extensions + executorch::kernels) ``` @@ -226,5 +235,5 @@ ExecuTorch provides a high-degree of customizability to support diverse hardware - [Using ExecuTorch on Android](using-executorch-android.md) and [Using ExecuTorch on iOS](using-executorch-ios.md) for mobile runtime integration. - [Using ExecuTorch with C++](using-executorch-cpp.md) for embedded and mobile native development. - [Profiling and Debugging](using-executorch-troubleshooting.md) for developer tooling and debugging. 
-- [API Reference](export-to-executorch-api-reference.md) for a full description of available APIs. +- [API Reference](export-to-executorch-api-reference.rst) for a full description of available APIs. - [Examples](https://github.com/pytorch/executorch/tree/main/examples) for demo apps and example code. diff --git a/docs/source/index.md b/docs/source/index.md index d0c9142cf4a..b65139319a7 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -1,297 +1,195 @@ (home)= # Welcome to the ExecuTorch Documentation -**ExecuTorch** is PyTorch's solution to training and inference on the -Edge. +**ExecuTorch** is PyTorch's solution for efficient AI inference on edge devices — from mobile phones to embedded systems. ## Key Value Propositions -- **Portability:** Compatibility with a wide variety of computing - platforms, from high-end mobile phones to highly constrained - embedded systems and microcontrollers. -- **Productivity:** Enabling developers to use the same toolchains and - Developer Tools from PyTorch model authoring and conversion, to - debugging and deployment to a wide variety of platforms. -- **Performance:** Providing end users with a seamless and - high-performance experience due to a lightweight runtime and - utilizing full hardware capabilities such as CPUs, NPUs, and DSPs. - -ExecuTorch provides support for: - -* **Strong Model Support** LLMs (Large Language Models), - CV (Computer Vision), ASR (Automatic Speech Recognition), TTS (Text To Speech) -* **All Major Platforms** Android, Mac, Linux, Windows -* **Rich Acceleration Support** Apple, Arm, Cadence, MediaTek, NXP, OpenVino, Qualcomm, Vulkan, XNNPACK - -### Documentation Navigation -#### Introduction -- [Overview](intro-overview) -- [How it Works](intro-how-it-works) -- [Getting Started with Architecture](getting-started-architecture) -- [Concepts](concepts) -#### Usage -- [Getting Started](getting-started) -- [Using Executorch Export](using-executorch-export) -- [Using Executorch on Android](using-executorch-android) -- [Using Executorch on iOS](using-executorch-ios) -- [Using Executorch with C++](using-executorch-cpp) -- [Runtime Integration](using-executorch-runtime-integration) -- [Troubleshooting](using-executorch-troubleshooting) -- [Building from Source](using-executorch-building-from-source) -- [Quantization](quantization-overview) -- [FAQs](using-executorch-faqs) -#### Examples -- [Android Demo Apps](https://github.com/meta-pytorch/executorch-examples/tree/main/dl3/android/DeepLabV3Demo#executorch-android-demo-app) -- [iOS Demo Apps](https://github.com/meta-pytorch/executorch-examples/tree/main/mv3/apple/ExecuTorchDemo) -- [Hugging Face Models](https://github.com/huggingface/optimum-executorch/blob/main/README.md) -#### Backends -- [Overview](backends-overview) -- [XNNPACK](backends-xnnpack) -- [Core ML](backends-coreml) -- [MPS](backends-mps) -- [Vulkan](backends-vulkan) -- [ARM Ethos-U](backends-arm-ethos-u) -- [Qualcomm](backends-qualcomm) -- [MediaTek](backends-mediatek) -- [Cadence](backends-cadence) -- [OpenVINO](build-run-openvino) -- [NXP](backend-nxp) -#### Developer Tools -- [Overview](devtools-overview) -- [Bundled IO](bundled-io) -- [ETRecord](etrecord) -- [ETDump](etdump) -- [Runtime Profiling](runtime-profiling) -- [Model Debugging](model-debugging) -- [Model Inspector](model-inspector) -- [Memory Planning Inspection](memory-planning-inspection) -- [Delegate Debugging](delegate-debugging) -- [Tutorial](devtools-tutorial) -#### Runtime -- [Overview](runtime-overview) -- [Extension 
Module](extension-module) -- [Extension Tensor](extension-tensor) -- [Detailed C++ Runtime APIs Tutorial](running-a-model-cpp-tutorial) -- [Backend Delegate Implementation and Linking](runtime-backend-delegate-implementation-and-linking) -- [Platform Abstraction Layer](runtime-platform-abstraction-layer) -#### Portable C++ Programming -- [PTE File Format](pte-file-format) -- [PTD File Format](ptd-file-format) -#### API Reference -- [Export to Executorch API Reference](export-to-executorch-api-reference) -- [Executorch Runtime API Reference](executorch-runtime-api-reference) -- [Runtime Python API Reference](runtime-python-api-reference) -- [API Life Cycle](api-life-cycle) -- [Javadoc](https://pytorch.org/executorch/main/javadoc/) -#### Kernel Library -- [Overview](kernel-library-overview) -- [Custom ATen Kernel](kernel-library-custom-aten-kernel) -- [Selective Build](kernel-library-selective-build) -#### Working with LLMs -- [Getting Started](llm/getting-started.md) -- [Exporting LLMs](llm/export-llm.md) -- [Exporting custom LLMs](llm/export-custom-llm.md) -- [Running with C++](llm/run-with-c-plus-plus.md) -- [Running on Android (XNNPack)](llm/llama-demo-android.md) -- [Running on Android (QNN)](llm/build-run-llama3-qualcomm-ai-engine-direct-backend.md) -- [Running on iOS](llm/run-on-ios.md) -#### Backend Development -- [Delegates Integration](backend-delegates-integration) -- [XNNPACK Reference](backend-delegates-xnnpack-reference) -- [Dependencies](backend-delegates-dependencies) -- [Compiler Delegate and Partitioner](compiler-delegate-and-partitioner) -- [Debug Backend Delegate](debug-backend-delegate) -#### IR Specification -- [EXIR](ir-exir) -- [Ops Set Definition](ir-ops-set-definition) -#### Compiler Entry Points -- [Backend Dialect](compiler-backend-dialect) -- [Custom Compiler Passes](compiler-custom-compiler-passes) -- [Memory Planning](compiler-memory-planning) -#### Contributing -- [Contributing](contributing) +- **Portability:** Run on diverse platforms, from high-end mobile to constrained microcontrollers +- **Performance:** Lightweight runtime with full hardware acceleration (CPU, GPU, NPU, DSP) +- **Productivity:** Use familiar PyTorch tools from authoring to deployment -```{toctree} -:glob: -:maxdepth: 1 -:caption: Introduction -:hidden: +--- -intro-overview -intro-how-it-works -getting-started-architecture -concepts -``` +## 🎯 Wins & Success Stories -```{toctree} -:glob: -:maxdepth: 1 -:caption: Usage -:hidden: +::::{grid} 1 +:class-container: success-showcase +:::{grid-item-card} +:class-header: bg-primary text-white +:class-body: text-center +[View All Success Stories →](success-stories) +::: +:::: -getting-started -using-executorch-export -using-executorch-android -using-executorch-ios -using-executorch-cpp -using-executorch-runtime-integration -using-executorch-troubleshooting -using-executorch-building-from-source -using-executorch-faqs -``` +--- -```{toctree} -:glob: -:maxdepth: 1 -:caption: Examples -:hidden: +## Quick Navigation -Building an ExecuTorch Android Demo App -Building an ExecuTorch iOS Demo App -tutorial-arm.md -``` +::::{grid} 2 -```{toctree} -:glob: -:maxdepth: 1 -:caption: Backends -:hidden: +:::{grid-item-card} **Get Started** +:link: quick-start-section +:link-type: doc -backends-overview -backends-xnnpack -backends-coreml -backends-mps -backends-vulkan -backends-arm-ethos-u -backends-qualcomm -backends-mediatek -backends-cadence -OpenVINO Backend -backends-nxp -``` +New to ExecuTorch? Start here for installation and your first model deployment. 
+::: -```{toctree} -:glob: -:maxdepth: 1 -:caption: Developer Tools -:hidden: +:::{grid-item-card} **Deploy on Edge Platforms** +:link: edge-platforms-section +:link-type: doc -devtools-overview -bundled-io -etrecord -etdump -runtime-profiling -model-debugging -model-inspector -memory-planning-inspection -delegate-debugging -devtools-tutorial -``` +Deploy on Android, iOS, Laptops / Desktops and embedded platforms with optimized backends. +::: -```{toctree} -:glob: -:maxdepth: 1 -:caption: Runtime -:hidden: +:::{grid-item-card} **Work with LLMs** +:link: llm/working-with-llms +:link-type: doc -runtime-overview -extension-module -extension-tensor -running-a-model-cpp-tutorial -runtime-backend-delegate-implementation-and-linking -runtime-platform-abstraction-layer -portable-cpp-programming -pte-file-format -ptd-file-format -``` +Export, optimize, and deploy Large Language Models on edge devices. +::: -```{toctree} -:glob: -:maxdepth: 1 -:caption: API Reference -:hidden: +:::{grid-item-card} 🔧 **Developer Tools** +:link: tools-section +:link-type: doc -export-to-executorch-api-reference -executorch-runtime-api-reference -runtime-python-api-reference -api-life-cycle -Javadoc -``` +Profile, debug, and inspect your models with comprehensive tooling. +::: -```{toctree} -:glob: -:maxdepth: 1 -:caption: Quantization -:hidden: +:::: -quantization-overview -``` +--- -```{toctree} -:glob: -:maxdepth: 1 -:caption: Kernel Library -:hidden: +## Explore Documentation -kernel-library-overview -kernel-library-custom-aten-kernel -kernel-library-selective-build -``` +::::{grid} 1 +:::{grid-item-card} **Intro** +:link: intro-section +:link-type: doc -```{toctree} -:glob: -:maxdepth: 2 -:caption: Working with LLMs -:hidden: +**Overview, architecture, and core concepts** — Understand how ExecuTorch works and its benefits +::: +:::: -Getting Started -Exporting LLMs with export_llm -Exporting custom LLMs -Running with C++ -Running on Android -Running on Android -Running on iOS -``` +::::{grid} 1 +:::{grid-item-card} **Quick Start** +:link: quick-start-section +:link-type: doc -```{toctree} -:glob: -:maxdepth: 1 -:caption: Backend Development -:hidden: +**Get started with ExecuTorch** — Install, export your first model, and run inference +::: +:::: -backend-delegates-integration -backend-delegates-xnnpack-reference -backend-delegates-dependencies -compiler-delegate-and-partitioner -debug-backend-delegate -``` +::::{grid} 1 +:::{grid-item-card} **Edge** +:link: edge-platforms-section +:link-type: doc -```{toctree} -:glob: -:maxdepth: 1 -:caption: IR Specification -:hidden: +**Android, iOS, Desktop, Embedded** — Platform-specific deployment guides and examples +::: +:::: -ir-exir -ir-ops-set-definition -``` +::::{grid} 1 +:::{grid-item-card} **Backends** +:link: backends-section +:link-type: doc -```{toctree} -:glob: -:maxdepth: 1 -:caption: Compiler Entry Points -:hidden: +**CPU, GPU, NPU/Accelerator backends** — Hardware acceleration and backend selection +::: +:::: + +::::{grid} 1 +:::{grid-item-card} **LLMs** +:link: llm/working-with-llms +:link-type: doc + +**LLM export, optimization, and deployment** — Complete LLM workflow for edge devices +::: +:::: + +::::{grid} 1 +:::{grid-item-card} **Advanced** +:link: advanced-topics-section +:link-type: doc + +**Quantization, memory planning, custom passes** — Deep customization and optimization +::: +:::: + +::::{grid} 1 +:::{grid-item-card} **Tools** +:link: tools-section +:link-type: doc + +**Developer tools, profiling, debugging** — Comprehensive development and 
debugging suite +::: +:::: -compiler-backend-dialect -compiler-custom-compiler-passes -compiler-memory-planning -``` +::::{grid} 1 +:::{grid-item-card} **API** +:link: api-section +:link-type: doc + +**API Reference Usages & Examples** — Detailed Python, C++, and Java API references +::: +:::: + +::::{grid} 1 +:::{grid-item-card} **💬 Support** +:link: support-section +:link-type: doc + +**FAQ, troubleshooting, contributing** — Get help and contribute to the project +::: +:::: + +--- + +## What's Supported + +::::{grid} 3 + +:::{grid-item} +**Model Types** + +- Large Language Models (LLMs) +- Computer Vision (CV) +- Speech Recognition (ASR) +- Text-to-Speech (TTS) +- More ... +::: + +:::{grid-item} +**Platforms** + +- Android & iOS +- Linux, macOS, Windows +- Embedded & MCUs +- Go **→ {doc}`edge-platforms-section`** +::: + +:::{grid-item} +**Rich Acceleration** + +- CPU +- GPU +- NPU +- DSP +- Go **→ {doc}`backends-section`** +::: + +:::: ```{toctree} -:glob: -:maxdepth: 1 -:caption: Contributing :hidden: +:maxdepth: 1 -contributing -``` +intro-section +quick-start-section +edge-platforms-section +backends-section +llm/working-with-llms +advanced-topics-section +tools-section +api-section +support-section diff --git a/docs/source/intro-how-it-works.md b/docs/source/intro-how-it-works.md index 3e6d384a62f..3ced602fed4 100644 --- a/docs/source/intro-how-it-works.md +++ b/docs/source/intro-how-it-works.md @@ -6,7 +6,7 @@ At a high-level, there are three steps for running a PyTorch model with ExecuTor 1. **Export the model.** The first step is to capture the PyTorch program as a graph, which is a new representation of the model that can be expressed in terms of a series of operators such as addition, multiplication, or convolution. This process safely preserves the semantics of the original PyTorch program. This representation is the first step to enable running the model on edge use cases that have low memory and/or low compute. 1. **Compile the exported model to an ExecuTorch program.** Given an exported model from step 1, convert it to an executable format called an ExecuTorch program that the runtime can use for inference. This step provides entry points for various optimizations such as compressing the model (e.g., quantization) to reduce size and further compiling subgraphs down to on-device specialized hardware accelerators to improve latency. It also provides an entry point for memory planning, i.e. to efficiently plan the location of intermediate tensors to reduce the runtime memory footprint. -1. **Run the ExecuTorch program on a target device.** Given an input--such as an image represented as an input activation tensor--the ExecuTorch runtime loads the ExecuTorch program, executes the instructions represented by the program, and computes an output. This step is efficient because (1) the runtime is lightweight and (2) an efficient execution plan has already been calculated in steps 1 and 2, making it possible to do performant inference. Furthermore, portability of the core runtime enabled performant execution even on highly-constrained devices. +1. **Run the ExecuTorch program on a target device.** Given an input--such as an image represented as an input activation tensor--the ExecuTorch runtime loads the ExecuTorch program, executes the instructions represented by the program, and computes an output. 
This step is efficient because (1) the runtime is lightweight and (2) an efficient execution plan has already been calculated in steps 1 and 2, making it possible to do performant inference. Furthermore, portability of the core runtime enables performant execution even on highly-constrained devices. This figure illustrates the three-step process of exporting a PyTorch program, compiling it into an ExecuTorch program that targets a specific hardware device, and finally executing the program on the device using the ExecuTorch runtime. ![name](_static/img/how-executorch-works-high-level.png) diff --git a/docs/source/intro-overview.md b/docs/source/intro-overview.md index 96c7982b8fe..be2fd468716 100644 --- a/docs/source/intro-overview.md +++ b/docs/source/intro-overview.md @@ -20,7 +20,7 @@ Key value propositions of ExecuTorch are: ## Why ExecuTorch? Supporting on-device AI presents unique challenges with diverse hardware, -critical power requirements, low/no internet connectivity, and realtime +critical power requirements, low/no internet connectivity, and real-time processing needs. These constraints have historically prevented or slowed down the creation of scalable and performant on-device AI solutions. We designed ExecuTorch, backed by our industry partners like Meta, Arm, Apple, and Qualcomm, diff --git a/docs/source/intro-section.md b/docs/source/intro-section.md new file mode 100644 index 00000000000..2f6f3c57c88 --- /dev/null +++ b/docs/source/intro-section.md @@ -0,0 +1,27 @@ +(intro-section)= + +# Intro + +Overview, architecture, and core concepts of ExecuTorch. + +ExecuTorch is PyTorch's solution for training and inference on the Edge, providing portability, productivity, and performance for edge computing platforms. + +## Getting Started with ExecuTorch + +New to ExecuTorch? Start with these foundational topics: + +- **{doc}`intro-overview`** - High-level overview of ExecuTorch capabilities +- **{doc}`intro-how-it-works`** - Technical overview of the ExecuTorch workflow +- **{doc}`getting-started-architecture`** - System architecture and components +- **{doc}`concepts`** - Core concepts and terminology + +```{toctree} +:hidden: +:maxdepth: 2 +:caption: Introduction Topics + +intro-overview +intro-how-it-works +getting-started-architecture +concepts +``` diff --git a/docs/source/ios-backends.md b/docs/source/ios-backends.md new file mode 100644 index 00000000000..cb186f53319 --- /dev/null +++ b/docs/source/ios-backends.md @@ -0,0 +1,19 @@ +(ios-backends)= +# Backends + +Available hardware acceleration backends for iOS deployment. 
+ +## Apple Hardware Acceleration (Recommended) + +- {doc}`ios-coreml` — CoreML (NPU/GPU, recommended for iOS) +- {doc}`ios-mps` — Metal Performance Shaders (GPU) + +## CPU Acceleration + +- {doc}`ios-xnnpack` — XNNPACK (CPU acceleration) + +```{toctree} +:hidden: +ios-coreml +ios-mps +ios-xnnpack diff --git a/docs/source/ios-coreml.md b/docs/source/ios-coreml.md new file mode 100644 index 00000000000..ff6551aa0c2 --- /dev/null +++ b/docs/source/ios-coreml.md @@ -0,0 +1 @@ +```{include} backends/coreml/coreml-overview.md diff --git a/docs/source/ios-examples.md b/docs/source/ios-examples.md new file mode 100644 index 00000000000..86acf3273a6 --- /dev/null +++ b/docs/source/ios-examples.md @@ -0,0 +1,4 @@ +# Examples & Demos + +- [iOS LLM Examples Repository](https://github.com/meta-pytorch/executorch-examples/tree/main/llm/apple) +- [MobileViT Demo App](https://github.com/meta-pytorch/executorch-examples/tree/main/mv3/apple/ExecuTorchDemo) diff --git a/docs/source/ios-mps.md b/docs/source/ios-mps.md new file mode 100644 index 00000000000..13717675ba5 --- /dev/null +++ b/docs/source/ios-mps.md @@ -0,0 +1 @@ +```{include} backends/mps/mps-overview.md diff --git a/docs/source/ios-section.md b/docs/source/ios-section.md new file mode 100644 index 00000000000..33c9a61ce1d --- /dev/null +++ b/docs/source/ios-section.md @@ -0,0 +1,23 @@ +(ios-section)= +# iOS + +Deploy ExecuTorch on iOS devices with Apple hardware acceleration. + +## Quick Start & Integration + +- {doc}`using-executorch-ios` — Complete iOS integration guide + +## Backends + +- {doc}`ios-backends` — Available iOS backends and acceleration options + +## Examples & Demos + +- {doc}`ios-examples` — Explore iOS Examples & Demos + + +```{toctree} +:hidden: +using-executorch-ios +ios-backends +ios-examples diff --git a/docs/source/ios-xnnpack.md b/docs/source/ios-xnnpack.md new file mode 100644 index 00000000000..4a85dec946b --- /dev/null +++ b/docs/source/ios-xnnpack.md @@ -0,0 +1 @@ +```{include} backends/xnnpack/xnnpack-overview.md diff --git a/docs/source/ir-specification.md b/docs/source/ir-specification.md new file mode 100644 index 00000000000..c58098ffc67 --- /dev/null +++ b/docs/source/ir-specification.md @@ -0,0 +1,8 @@ +# IR Specification + +```{toctree} +:maxdepth: 1 + +ir-exir +ir-ops-set-definition +``` diff --git a/docs/source/kernel-library-advanced.md b/docs/source/kernel-library-advanced.md new file mode 100644 index 00000000000..5f0215b87c1 --- /dev/null +++ b/docs/source/kernel-library-advanced.md @@ -0,0 +1,23 @@ +(kernel-library-advanced)= + +# Kernel Library Deep Dive + +Advanced kernel implementation and customization for ExecuTorch. + +## Kernel Library Overview + +- {doc}`kernel-library-overview` — Architecture and design of the kernel library + +- {doc}`kernel-library-custom-aten-kernel` — Kernel registration and customization + +## Build Optimization + +- {doc}`kernel-library-selective-build` — Selective build for reduced binary footprint + +```{toctree} +:hidden: +:maxdepth: 1 + +kernel-library-overview +kernel-library-custom-aten-kernel +kernel-library-selective-build diff --git a/docs/source/kernel-library-overview.md b/docs/source/kernel-library-overview.md index cfd46524097..a826b334ba4 100644 --- a/docs/source/kernel-library-overview.md +++ b/docs/source/kernel-library-overview.md @@ -1,7 +1,7 @@ -This page provides a description of the Portable Kernel Library and the Optimized Kernel Library, which are the default kernel libraries shipped with ExecuTorch. 
It is recommended reading for those who are interested in executing ExecuTorch programs with these kernel libraries, or for those who want to implement their own kernels and kernel libraries. - # Overview of ExecuTorch’s Kernel Libraries +This page provides a description of the Portable Kernel Library and the Optimized Kernel Library, which are the default kernel libraries shipped with ExecuTorch. It is recommended reading for those who are interested in executing ExecuTorch programs with these kernel libraries, or for those who want to implement their own kernels and kernel libraries. + An ExecuTorch program encodes instructions that describe the computation that should be performed by the program. Many of these instructions will correspond to calling a specific ATen operator, for example `aten.convolution`. However, one of the core design principles of ExecuTorch is that the signature of an operator should be separate from the implementation of the operator. This means that the ExecuTorch runtime does not ship with any standard implementation for ATen operators; users must make sure to link against kernel libraries that contain implementations of the operators required by their ExecuTorch program, and configure [operator registration](kernel-library-custom-aten-kernel.md) to map an operator signature to the desired implementation. This makes it easy to adjust the implementation of operators such as `aten.convolution` that will be called when executing an ExecuTorch program; it allows users to select the exact operator implementations that will meet the unique performance, memory usage, battery usage, etc. constraints of their use-case. **In essence, a kernel library is simply a collection of ATen operator implementations that follow a common theme or design principle**. Note that due to ExecuTorch’s selective build process (discussed in the following section), operator implementations are linked individually. This means that users can easily mix different kernel libraries in their build without sacrificing build size. diff --git a/docs/source/kernel-library-selective-build.md b/docs/source/kernel-library-selective-build.md index 7d6495656a2..edec9567b7b 100644 --- a/docs/source/kernel-library-selective-build.md +++ b/docs/source/kernel-library-selective-build.md @@ -61,11 +61,11 @@ gen_selected_ops( ROOT_OPS # comma separated operator names to be selected INCLUDE_ALL_OPS # boolean flag to include all operators OPS_FROM_MODEL # path to a pte file of model to select operators from - DTYPE_SELECTIVE_BUILD # boolean flag to enable dtye selection + DTYPE_SELECTIVE_BUILD # boolean flag to enable dtype selection ) ``` -The macro makes a call to gen_oplist.py, which requires a [distinct selection](https://github.com/BujSet/executorch/blob/main/codegen/tools/gen_oplist.py#L222-L228) of API choice. `OPS_SCHEMA_YAML`, `ROOT_OPS`, `INCLUDE_ALL_OPS`, and `OPS_FROM_MODEL` are mutually exclusive options, and should not be used in conjunction. +The macro makes a call to gen_oplist.py, which requires a [distinct selection](https://github.com/pytorch/executorch/blob/main/codegen/tools/gen_oplist.py#L222-L228) of API choice. `OPS_SCHEMA_YAML`, `ROOT_OPS`, `INCLUDE_ALL_OPS`, and `OPS_FROM_MODEL` are mutually exclusive options, and should not be used in conjunction. ### Select all ops @@ -83,7 +83,7 @@ This API lets users pass in a list of operator names. Note that this API can be ### Select ops from model -This API lets users pass in a pte file of an exported model. 
When used, the pte file will be parsed to generate a yaml file that enumerates the operators and dtypes used in the model. +This API lets users pass in a pte file of an exported model. When used, the pte file will be parsed to generate a yaml file that enumerates the operators and dtypes used in the model. ### Dtype Selective Build @@ -91,7 +91,7 @@ Beyond pruning the binary to remove unused operators, the binary size can furthe ## Example Walkthrough -In [examples/selective_build/CMakeLists.txt](https://github.com/BujSet/executorch/blob/main/examples/selective_build/CMakeLists.txt#L48-L72), we have the following cmake config options: +In [examples/selective_build/CMakeLists.txt](https://github.com/pytorch/executorch/blob/main/examples/selective_build/advanced/CMakeLists.txt), we have the following cmake config options: 1. `EXECUTORCH_SELECT_OPS_YAML` 2. `EXECUTORCH_SELECT_OPS_LIST` @@ -99,10 +99,10 @@ In [examples/selective_build/CMakeLists.txt](https://github.com/BujSet/executorc 4. `EXECUTORCH_SELECT_OPS_FROM_MODEL` 5. `EXECUTORCH_DTYPE_SELECTIVE_BUILD` -These options allow a user to tailor the cmake build process to utilize the different APIs, and results in different invocations on the `gen_selected_ops` [function](https://github.com/BujSet/executorch/blob/main/examples/selective_build/CMakeLists.txt#L110-L123). The following table describes some examples of how the invocation changes when these configs are set: +These options allow a user to tailor the cmake build process to utilize the different APIs, and result in different invocations of the `gen_selected_ops` [function](https://github.com/pytorch/executorch/blob/main/examples/selective_build/advanced/CMakeLists.txt). The following table describes some examples of how the invocation changes when these configs are set: | Example cmake Call | Resultant `gen_selected_ops` Invocation | -| :----: | :---:| +| :----: | :---:|
| `cmake -D… -DSELECT_OPS_LIST="aten::add.out,aten::mm.out"` | `gen_selected_ops("" "${SELECT_OPS_LIST}" "" "" "")` |
| `cmake -D… -DSELECT_OPS_YAML=ON` | `set(_custom_ops_yaml ${EXECUTORCH_ROOT}/examples/portable/custom_ops/custom_ops.yaml)` <br/> `gen_selected_ops("${_custom_ops_yaml}" "" "")` |
| `cmake -D… -DEXECUTORCH_SELECT_OPS_FROM_MODEL="model.pte.out"` | `gen_selected_ops("" "" "" "${_model_path}" "")`
| diff --git a/docs/source/kernel-library.md b/docs/source/kernel-library.md new file mode 100644 index 00000000000..a995a20973b --- /dev/null +++ b/docs/source/kernel-library.md @@ -0,0 +1,9 @@ +# Kernel Library + +```{toctree} +:maxdepth: 1 + +kernel-library-overview +kernel-library-custom-aten-kernel +kernel-library-selective-build +``` diff --git a/docs/source/llm/build-run-llama3-qualcomm-ai-engine-direct-backend.md b/docs/source/llm/build-run-llama3-qualcomm-ai-engine-direct-backend.md index 4587589a51b..1168c4c04a3 100644 --- a/docs/source/llm/build-run-llama3-qualcomm-ai-engine-direct-backend.md +++ b/docs/source/llm/build-run-llama3-qualcomm-ai-engine-direct-backend.md @@ -1,6 +1,7 @@ -# Building and Running Llama 3 8B Instruct with Qualcomm AI Engine Direct Backend +# Run Llama 3 3B Instruct on Android (with Qualcomm AI Engine Direct Backend) -This tutorial demonstrates how to export Llama 3 8B Instruct for Qualcomm AI Engine Direct Backend and running the model on a Qualcomm device. +This tutorial demonstrates how to export and run the Llama 3 3B Instruct model on a Qualcomm device using the Qualcomm AI Engine Direct Backend via ExecuTorch. +We use a static Llama [implementation](https://github.com/pytorch/executorch/blob/main/examples/qualcomm/oss_scripts/llama/model/static_llama.py) to optimize performance and memory usage during on-device inference. ## Prerequisites @@ -13,10 +14,8 @@ This tutorial demonstrates how to export Llama 3 8B Instruct for Qualcomm AI Eng ## Instructions -### Step 1: Prepare the checkpoint of the model and optimized matrix from [Spin Quant](https://github.com/facebookresearch/SpinQuant) - -1. For Llama 3 tokenizer and checkpoint, please refer to https://github.com/meta-llama/llama-models/blob/main/README.md for further instructions on how to download `tokenizer.model`, `consolidated.00.pth` and `params.json`. -2. To get the optimized matrix, please refer to [SpinQuant on GitHub](https://github.com/facebookresearch/SpinQuant). You can download the optimized rotation matrices in the Quantized Models section. Please choose **LLaMA-3-8B/8B_W4A16KV16_lr_1.5_seed_0**. +### Step 1: Prepare the checkpoint and tokenizer of the model. +1. For Llama 3 tokenizer and checkpoint, please refer to [instructions](https://www.llama.com/models/llama-3) for further instructions on how to download `tokenizer.model`, `consolidated.00.pth` and `params.json`. ### Step 2: Export to ExecuTorch with Qualcomm AI Engine Direct Backend Deploying large language models like Llama 3 on-device presents the following challenges: @@ -25,122 +24,74 @@ Deploying large language models like Llama 3 on-device presents the following ch 2. High model loading and inference time. 3. Difficulty in quantization. -To address these challenges, we have implemented the following solutions: -1. Using `quantization.pt2e_quantize = "qnn_16a4w'` to quantize activations and weights, thereby reducing the on-disk model size and alleviating memory pressure during inference. -2. Using `backed.qnn.num_sharding = 8` to shard the model into sub-parts. -3. Performing graph transformations to convert or decompose operations into more accelerator-friendly operations. -4. Using `backend.qnn.optimized_rotation_path = ""` to apply R1 and R2 of [Spin Quant](https://github.com/facebookresearch/SpinQuant) to improve accuracy. -5. Using `quantization.calibration_data = "<|start_header_id|>system<|end_header_id|..."` to ensure that during quantization, the calibration includes special tokens in the prompt template. 
For more details on the prompt template, refer to [the model card](https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/). +To address these, we apply the following optimizations: + +1. Quantization: Apply the `quant_recipe` when setting the quantization config to reduce model size and memory usage. + +2. Mixed Precision Quantization: compresses KV cache tensors to 8-bit and applies `QuantDtype.use_16a8w` to the LM head. + +3. Model Sharding: Set `num_sharding` = 4 to shard the model into sub-parts. This helps reduce memory pressure and improve performance during on-device inference. The number of shards might be different depending on the model size. + +4. Graph Transformations: Convert operations into accelerator-friendly formats for better runtime performance. + +You can find the full optimization configuration in this [file](https://github.com/pytorch/executorch/blob/main/examples/qualcomm/oss_scripts/llama/__init__.py), as shown below: + +``` python +@register_llm_model("llama3_2-3b_instruct") +@dataclass(init=False, frozen=True) +class Llama3_2_3B_Instruct(LLMModelConfig): + repo_id = None + params_path = None + convert_weights = None + transform_weight = True + # The Llama3_2 enabled should be instruct, however, Llama's tokenizer does not provide utility to apply chat template. + instruct_model = False + + num_sharding = 4 + masked_softmax = False + + # SeqMSE Quantization: optimizes the parameter encodings of each layer of a model individually to minimize the difference between the layer’s original and quantized outputs. (Implementation details: ./backends/qualcomm/_passes/seq_mse.py) In this configuration, we set `seq_mse_candidates` = 0, which means SeqMSE quantization is not applied. + seq_mse_candidates = 0 + r1 = False + r2 = False + r3 = False + # quant recipe + quant_recipe = Llama3_3BQuantRecipe +``` + To export with the Qualcomm AI Engine Direct Backend, ensure the following: -1. The host machine has more than 100GB of memory (RAM + swap space). +1. The host machine has more than 64GB of memory (RAM + swap space). 2. The entire process takes a few hours. ```bash -# path/to/config.yaml -base: - model_class: llama3 - checkpoint: path/to/consolidated.00.pth - params: path/to/params.json - tokenizer_path: path/to/tokenizer.model - metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' -model: - use_kv_cache: True - enable_dynamic_shape: False -quantization: - pt2e_quantize: qnn_16a4w - # Please note that calibration_data must include the prompt template for special tokens. - calibration_data: "<|start_header_id|>system<|end_header_id|>\n\nYou are a funny chatbot.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCould you tell me about Facebook?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" -backend: - qnn: - enabled: True - num_sharding: 8 - - -# export_llm -python -m extension.llm.export.export_llm \ - --config path/to/config.yaml +# export llama +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-3b_instruct --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" 
--tasks wikitext --limit 1 --compile_only ``` +Note: end-to-end [instructions](https://github.com/pytorch/executorch/blob/main/examples/qualcomm/oss_scripts/llama/README.md) ### Step 3: Invoke the Runtime on an Android smartphone with Qualcomm SoCs -1. Build executorch with Qualcomm AI Engine Direct Backend for android - ```bash - cmake \ - -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK_ROOT}/build/cmake/android.toolchain.cmake" \ - -DANDROID_ABI=arm64-v8a \ - -DCMAKE_INSTALL_PREFIX=cmake-android-out \ - -DCMAKE_BUILD_TYPE=Release \ - -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ - -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ - -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ - -DEXECUTORCH_BUILD_QNN=ON \ - -DQNN_SDK_ROOT=${QNN_SDK_ROOT} \ - -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \ - -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ - -DEXECUTORCH_BUILD_KERNELS_LLM=ON \ - -Bcmake-android-out . - - cmake --build cmake-android-out -j16 --target install --config Release - ``` -2. Build llama runner for android -```bash - cmake \ - -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK_ROOT}"/build/cmake/android.toolchain.cmake \ - -DANDROID_ABI=arm64-v8a \ - -DCMAKE_INSTALL_PREFIX=cmake-android-out \ - -DCMAKE_BUILD_TYPE=Release -DPYTHON_EXECUTABLE=python \ - -DEXECUTORCH_BUILD_QNN=ON \ - -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \ - -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ - -DEXECUTORCH_BUILD_KERNELS_LLM=ON \ - -Bcmake-android-out/examples/models/llama examples/models/llama - - cmake --build cmake-android-out/examples/models/llama -j16 --config Release -``` -3. Run on Android via adb shell -*Pre-requisite*: Make sure you enable USB debugging via developer options on your phone - **3.1 Connect your android phone** -**3.2 We need to push required QNN libraries to the device.** -```bash -# make sure you have write-permission on below path. 
-DEVICE_DIR=/data/local/tmp/llama -adb shell mkdir -p ${DEVICE_DIR} -adb push ${QNN_SDK_ROOT}/lib/aarch64-android/libQnnHtp.so ${DEVICE_DIR} -adb push ${QNN_SDK_ROOT}/lib/aarch64-android/libQnnSystem.so ${DEVICE_DIR} -adb push ${QNN_SDK_ROOT}/lib/aarch64-android/libQnnHtpV69Stub.so ${DEVICE_DIR} -adb push ${QNN_SDK_ROOT}/lib/aarch64-android/libQnnHtpV73Stub.so ${DEVICE_DIR} -adb push ${QNN_SDK_ROOT}/lib/aarch64-android/libQnnHtpV75Stub.so ${DEVICE_DIR} -adb push ${QNN_SDK_ROOT}/lib/hexagon-v69/unsigned/libQnnHtpV69Skel.so ${DEVICE_DIR} -adb push ${QNN_SDK_ROOT}/lib/hexagon-v73/unsigned/libQnnHtpV73Skel.so ${DEVICE_DIR} -adb push ${QNN_SDK_ROOT}/lib/hexagon-v75/unsigned/libQnnHtpV75Skel.so ${DEVICE_DIR} -``` - -**3.3 Upload model, tokenizer and llama runner binary to phone** -```bash -adb push ${DEVICE_DIR} -adb push ${DEVICE_DIR} -adb push cmake-android-out/lib/libqnn_executorch_backend.so ${DEVICE_DIR} -adb push cmake-out-android/examples/models/llama/llama_main ${DEVICE_DIR} -``` +**3.2 Make sure the following artifact is present before running the model.** +-- artifact/ + └── llama_qnn.pte -**3.4 Run model** +**3.3 Run model** ```bash -adb shell "cd ${DEVICE_DIR} && ./llama_main --model_path --tokenizer_path --prompt \"<|start_header_id|>system<|end_header_id|>\n\nYou are a funny chatbot.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCould you tell me about Facebook?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n\" --seq_len 128" -``` -You should see the message: -``` -<|start_header_id|>system<|end_header_id|>\n\nYou are a funny chatbot.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCould you tell me about Facebook?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHello! I'd be delighted to chat with you about Facebook. Facebook is a social media platform that was created in 2004 by Mark Zuckerberg and his colleagues while he was a student at Harvard University. It was initially called "Facemaker" but later changed to Facebook, which is a combination of the words "face" and "book". The platform was initially intended for people to share their thoughts and share information with their friends, but it quickly grew to become one of the +# Run llama +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-3b_instruct --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 --pre_gen_pte ${PATH_TO_ARTIFACT} ``` ## What is coming? - Performance improvements - Reduce the memory pressure during inference to support 12GB Qualcomm devices -- Support more LLMs (Qwen, Phi-4-mini, etc.) +- Broader LLM Support via [Optimum ExecuTorch](https://github.com/huggingface/optimum-executorch?tab=readme-ov-file#llms-large-language-models) + + - Already supported models (e.g.): Llama2, Llama3, Gemma, Qwen, Phi-4, SmolLM. 
For usage examples, please refer to [README](https://github.com/pytorch/executorch/blob/main/examples/qualcomm/oss_scripts/llama/README.md) ## FAQ If you encounter any issues while reproducing the tutorial, please file a github -issue on ExecuTorch repo and tag use `#qcom_aisw` tag +[issue](https://github.com/pytorch/executorch/issues) on ExecuTorch repo and tag use `#qcom_aisw` tag \ No newline at end of file diff --git a/docs/source/llm/export-custom-llm.md b/docs/source/llm/export-custom-llm.md index 57537ba31d8..4797f773fa3 100644 --- a/docs/source/llm/export-custom-llm.md +++ b/docs/source/llm/export-custom-llm.md @@ -81,7 +81,7 @@ with open("nanogpt.pte", "wb") as file: To export, run the script with `python export_nanogpt.py` (or python3, as appropriate for your environment). It will generate a `nanogpt.pte` file in the current directory. -For more information, see [Exporting to ExecuTorch](https://pytorch.org/executorch/main/tutorials/export-to-executorch-tutorial) and +For more information, see [Exporting to ExecuTorch](../tutorials/export-to-executorch-tutorial) and [torch.export](https://pytorch.org/docs/stable/export.html). ## Backend delegation @@ -143,7 +143,7 @@ example_inputs = ( # long as they adhere to the rules specified in the dynamic shape configuration. # Here we set the range of 0th model input's 1st dimension as # [0, model.config.block_size]. -# See https://pytorch.org/executorch/main/concepts.html#dynamic-shapes +# See ../concepts.html#dynamic-shapes # for details about creating dynamic shapes. dynamic_shape = ( {1: torch.export.Dim("token_dim", max=model.config.block_size - 1)}, diff --git a/docs/source/llm/export-llm-optimum.md b/docs/source/llm/export-llm-optimum.md new file mode 100644 index 00000000000..1a104f77bc4 --- /dev/null +++ b/docs/source/llm/export-llm-optimum.md @@ -0,0 +1,171 @@ +# Exporting LLMs with HuggingFace's Optimum ExecuTorch + +[Optimum ExecuTorch](https://github.com/huggingface/optimum-executorch) provides a streamlined way to export Hugging Face transformer models to ExecuTorch format. It offers seamless integration with the Hugging Face ecosystem, making it easy to export models directly from the Hugging Face Hub. + +## Overview + +Optimum ExecuTorch supports a much wider variety of model architectures compared to ExecuTorch's native `export_llm` API. While `export_llm` focuses on a limited set of highly optimized models (Llama, Qwen, Phi, and SmolLM) with advanced features like SpinQuant and attention sink, Optimum ExecuTorch can export diverse architectures including Gemma, Mistral, GPT-2, BERT, T5, Whisper, Voxtral, and many others. + +### Use Optimum ExecuTorch when: +- You need to export models beyond the limited set supported by `export_llm` +- Exporting directly from Hugging Face Hub model IDs, including model variants such as finetunes +- You want a simpler interface with Hugging Face ecosystem integration + +### Use export_llm when: +- Working with one of the highly optimized supported models (Llama, Qwen, Phi, SmolLM) +- You need advanced optimizations like SpinQuant or attention sink +- You need pt2e quantization for QNN/CoreML/Vulkan backends +- Working with Llama models requiring custom checkpoints + +See [Exporting LLMs](export-llm.md) for details on using the native `export_llm` API. 
+ +## Prerequisites + +### Installation + +First, clone and install Optimum ExecuTorch from source: + +```bash +git clone https://github.com/huggingface/optimum-executorch.git +cd optimum-executorch +pip install '.[dev]' +``` + +For access to the latest features and optimizations, install dependencies in dev mode: + +```bash +python install_dev.py +``` + +This installs `executorch`, `torch`, `torchao`, `transformers`, and other dependencies from nightly builds or source. + +## Supported Models + +Optimum ExecuTorch supports a wide range of model architectures including decoder-only LLMs (Llama, Qwen, Gemma, Mistral, etc.), multimodal models, vision models, audio models (Whisper), encoder models (BERT, RoBERTa), and seq2seq models (T5). + +For the complete list of supported models, see the [Optimum ExecuTorch documentation](https://github.com/huggingface/optimum-executorch#-supported-models). + +## Export Methods + +Optimum ExecuTorch offers two ways to export models: + +### Method 1: CLI Export + +The CLI is the simplest way to export models. It provides a single command to convert models from Hugging Face Hub to ExecuTorch format. + +#### Basic Export + +```bash +optimum-cli export executorch \ + --model "HuggingFaceTB/SmolLM2-135M-Instruct" \ + --task "text-generation" \ + --recipe "xnnpack" \ + --output_dir="./smollm2_exported" +``` + +#### With Optimizations + +Add custom SDPA, KV cache optimization, and quantization: + +```bash +optimum-cli export executorch \ + --model "HuggingFaceTB/SmolLM2-135M-Instruct" \ + --task "text-generation" \ + --recipe "xnnpack" \ + --use_custom_sdpa \ + --use_custom_kv_cache \ + --qlinear 8da4w \ + --qembedding 8w \ + --output_dir="./smollm2_exported" +``` + +#### Available CLI Arguments + +Key arguments for LLM export include `--model`, `--task`, `--recipe` (backend), `--use_custom_sdpa`, `--use_custom_kv_cache`, `--qlinear` (linear quantization), `--qembedding` (embedding quantization), and `--max_seq_len`. + +For the complete list of arguments, run: +```bash +optimum-cli export executorch --help +``` + +## Optimization Options + +### Custom Operators + +Optimum ExecuTorch includes custom SDPA (~3x speedup) and custom KV cache (~2.5x speedup) operators. Enable with `--use_custom_sdpa` and `--use_custom_kv_cache`. + +### Quantization + +Optimum ExecuTorch uses [TorchAO](https://github.com/pytorch/ao) for quantization. Common options: +- `--qlinear 8da4w`: int8 dynamic activation + int4 weight (recommended) +- `--qembedding 4w` or `--qembedding 8w`: int4/int8 embedding quantization + +Example: +```bash +optimum-cli export executorch \ + --model "meta-llama/Llama-3.2-1B" \ + --task "text-generation" \ + --recipe "xnnpack" \ + --use_custom_sdpa \ + --use_custom_kv_cache \ + --qlinear 8da4w \ + --qembedding 4w \ + --output_dir="./llama32_1b" +``` + +### Backend Support + +Supported backends: `xnnpack` (CPU), `coreml` (Apple GPU), `portable` (baseline), `cuda` (NVIDIA GPU). Specify with `--recipe`. + +## Exporting Different Model Types + +Optimum ExecuTorch supports various model architectures with different tasks: + +- **Decoder-only LLMs**: Use `--task text-generation` +- **Multimodal LLMs**: Use `--task multimodal-text-to-text` +- **Seq2Seq models** (T5): Use `--task text2text-generation` +- **ASR models** (Whisper): Use `--task automatic-speech-recognition` + +For detailed examples of exporting each model type, see the [Optimum ExecuTorch export guide](https://github.com/huggingface/optimum-executorch/blob/main/optimum/exporters/executorch/README.md). 
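In addition to the CLI, models can also be exported programmatically through the same `ExecuTorchModelForCausalLM` class that is used for verification below. The snippet is a minimal sketch, assuming `from_pretrained` accepts a Hub model ID plus a `recipe` argument as described in the Optimum ExecuTorch README; consult that project for the authoritative signature and the full set of optimization kwargs.

```python
from optimum.executorch import ExecuTorchModelForCausalLM
from transformers import AutoTokenizer

# Export on the fly from a Hub model ID; the (assumed) `recipe` kwarg selects the backend.
model = ExecuTorchModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M-Instruct",
    recipe="xnnpack",
)

# Quick smoke test of the exported program.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
print(model.text_generation(tokenizer=tokenizer, prompt="Once upon a time", max_seq_len=64))
```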
+ +## Running Exported Models + +### Verifying Output with Python + +After exporting, you can verify the model output in Python before deploying to device using classes from `modeling.py`, such as the `ExecuTorchModelForCausalLM` class for LLMs: + +```python +from optimum.executorch import ExecuTorchModelForCausalLM +from transformers import AutoTokenizer + +# Load the exported model +model = ExecuTorchModelForCausalLM.from_pretrained("./smollm2_exported") +tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct") + +# Generate text +generated_text = model.text_generation( + tokenizer=tokenizer, + prompt="Once upon a time", + max_seq_len=128, +) +print(generated_text) +``` + +### Running on Device + +After verifying your model works correctly, deploy it to device: + +- [Running with C++](run-with-c-plus-plus.md) - Run exported models using ExecuTorch's C++ runtime +- [Running on Android](https://github.com/meta-pytorch/executorch-examples/tree/main/llm/android) - Deploy to Android devices +- [Running on iOS](https://github.com/meta-pytorch/executorch-examples/tree/main/llm/apple) - Deploy to iOS devices + +## Performance + +For performance benchmarks and on-device metrics, see the [Optimum ExecuTorch benchmarks](https://github.com/huggingface/optimum-executorch#-benchmarks-on-mobile-devices) and the [ExecuTorch Benchmark Dashboard](https://hud.pytorch.org/benchmark/llms?repoName=pytorch%2Fexecutorch). + +## Additional Resources + +- [Optimum ExecuTorch GitHub](https://github.com/huggingface/optimum-executorch) - Full documentation and examples +- [Supported Models](https://github.com/huggingface/optimum-executorch#-supported-models) - Complete model list +- [Export Guide](https://github.com/huggingface/optimum-executorch/blob/main/optimum/exporters/executorch/README.md) - Detailed export examples +- [TorchAO Quantization](https://github.com/pytorch/ao) - Quantization library documentation diff --git a/docs/source/llm/export-llm.md b/docs/source/llm/export-llm.md index 462d9a51849..8156c31a97b 100644 --- a/docs/source/llm/export-llm.md +++ b/docs/source/llm/export-llm.md @@ -4,7 +4,7 @@ Instead of needing to manually write code to call torch.export(), use ExecuTorch ## Prerequisites -The LLM export functionality requires the `pytorch_tokenizers` package. If you encounter a `ModuleNotFoundError: No module named 'pytorch_tokenizers'` error, install it from the ExecutorTorch source code: +The LLM export functionality requires the `pytorch_tokenizers` package. If you encounter a `ModuleNotFoundError: No module named 'pytorch_tokenizers'` error, install it from the ExecuTorch source code: ```bash pip install -e ./extension/llm/tokenizers/ @@ -20,11 +20,13 @@ As of this doc, the list of supported LLMs include the following: The up-to-date list of supported LLMs can be found in the code [here](https://github.com/pytorch/executorch/blob/main/extension/llm/export/config/llm_config.py#L32). +**Note:** If you need to export models that are not on this list or other model architectures (such as Gemma, Mistral, BERT, T5, Whisper, etc.), see [Exporting LLMs with Optimum](export-llm-optimum.md) which supports a much wider variety of models from Hugging Face Hub. + ## The export_llm API `export_llm` is ExecuTorch's high-level export API for LLMs. In this tutorial, we will focus on exporting Llama 3.2 1B using this API. 
`export_llm`'s arguments are specified either through CLI args or through a yaml configuration whose fields are defined in [`LlmConfig`](https://github.com/pytorch/executorch/blob/main/extension/llm/export/config/llm_config.py). To call `export_llm`: ``` -python -m executorch.examples.extension.llm.export.export_llm +python -m executorch.extension.llm.export.export_llm --config +base. ``` @@ -78,7 +80,7 @@ python -m extension.llm.export.export_llm \ - `use_shared_embedding` can help for models with tied input/output embedding layers, given that you quantize using TorchAO low bit ops (`quantization.qmode: torchao:8da(\\d+)w` or `quantization.qmode: torchao:fpa(\d+)w`), see more [here](https://github.com/pytorch/executorch/blob/main/extension/llm/export/config/llm_config.py#L307). - `use_attention_sink` to extend generation by removing from the beginning of the KV cache when the max context length is reached. - `quantize_kv_cache` quantizes the KV cache in int8. -- `local_global_attention` impements [Local-Global Attention](https://arxiv.org/abs/2411.09604), making specific attention layers use a much smaller localized sliding window KV cache. +- `local_global_attention` implements [Local-Global Attention](https://arxiv.org/abs/2411.09604), making specific attention layers use a much smaller localized sliding window KV cache. ## Quantization Quantization options are defined by [`QuantizationConfig`](https://github.com/pytorch/executorch/blob/main/extension/llm/export/config/llm_config.py#L283). ExecuTorch does quantization in two ways: @@ -92,7 +94,7 @@ The quantization modes are defined [here](https://github.com/pytorch/executorch/ Common ones to use are: - `8da4w`: short for int8 dynamic activation + int4 weight quantization. -- `int8`: int8 weight-only quanziation. +- `int8`: int8 weight-only quantization. Group size is specified with: - `group_size`: 8, 32, 64, etc. @@ -112,7 +114,7 @@ base: metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' model: use_kv_cache: True - use_sdpa_withp_kv_cache: True + use_sdpa_with_kv_cache: True quantization: embedding_quantize: 4,32 qmode: 8da4w @@ -142,7 +144,7 @@ base: metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' model: use_kv_cache: True - use_sdpa_withp_kv_cache: True + use_sdpa_with_kv_cache: True quantization: embedding_quantize: 4,32 qmode: 8da4w diff --git a/docs/source/llm/getting-started.md b/docs/source/llm/getting-started.md index 849418342b6..95caae6ddd9 100644 --- a/docs/source/llm/getting-started.md +++ b/docs/source/llm/getting-started.md @@ -18,9 +18,13 @@ To follow this guide, you'll need to install ExecuTorch. Please see [Setting Up Deploying LLMs to ExecuTorch can be boiled down to a two-step process: (1) exporting the LLM to a `.pte` file and (2) running the `.pte` file using our C++ APIs or Swift/Java bindings. 
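To make step (1) concrete, the sketch below shows what producing a `.pte` file looks like with the core export APIs (`torch.export` plus `to_edge`) for a toy, non-LLM module; the LLM guides listed below wrap this same flow and add model-specific options such as KV caches and quantization.

```python
import torch
from executorch.exir import to_edge

class TinyModule(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

# Step (1): capture the PyTorch program and lower it to an ExecuTorch program.
exported = torch.export.export(TinyModule().eval(), (torch.randn(1, 8),))
et_program = to_edge(exported).to_executorch()

# The serialized buffer is what gets written out as a .pte file and shipped to the device.
with open("tiny_module.pte", "wb") as f:
    f.write(et_program.buffer)
```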
-- [Exporting LLMs](export-llm.md) +### Exporting +- [Exporting LLMs](export-llm.md) - Export using ExecuTorch's native `export_llm` API with advanced optimizations +- [Exporting LLMs with Optimum](export-llm-optimum.md) - Export Hugging Face models with broader architecture support - [Exporting custom LLMs](export-custom-llm.md) + +### Running - [Running with C++](run-with-c-plus-plus.md) -- [Running on Android (XNNPack)](llama-demo-android.md) +- [Running on Android (XNNPack)](https://github.com/meta-pytorch/executorch-examples/tree/main/llm/android) - [Running on Android (Qualcomm)](build-run-llama3-qualcomm-ai-engine-direct-backend.md) - [Running on iOS](https://github.com/meta-pytorch/executorch-examples/tree/main/llm/apple) diff --git a/docs/source/llm/llama-demo-android.md b/docs/source/llm/llama-demo-android.md deleted file mode 100644 index 023f82baf33..00000000000 --- a/docs/source/llm/llama-demo-android.md +++ /dev/null @@ -1,2 +0,0 @@ -```{include} ../../../examples/demo-apps/android/LlamaDemo/README.md -``` diff --git a/docs/source/llm/run-on-ios.md b/docs/source/llm/run-on-ios.md index 88ad94c38d3..f096995fca9 100644 --- a/docs/source/llm/run-on-ios.md +++ b/docs/source/llm/run-on-ios.md @@ -80,17 +80,22 @@ do { #### Generating -Generate up to a given number of tokens from an initial prompt. The callback block is invoked once per token as it’s produced. +Generate tokens from an initial prompt, configured with an `ExecuTorchLLMConfig` object. The callback block is invoked once per token as it’s produced. Objective-C: ```objectivec +ExecuTorchLLMConfig *config = [[ExecuTorchLLMConfig alloc] initWithBlock:^(ExecuTorchLLMConfig *c) { + c.temperature = 0.8; + c.sequenceLength = 2048; +}]; + NSError *error = nil; -BOOL success = [runner generate:@"Once upon a time" - sequenceLength:50 - withTokenCallback:^(NSString *token) { - NSLog(@"Generated token: %@", token); - } - error:&error]; +BOOL success = [runner generateWithPrompt:@"Once upon a time" + config:config + tokenCallback:^(NSString *token) { + NSLog(@"Generated token: %@", token); + } + error:&error]; if (!success) { NSLog(@"Generation failed: %@", error); } @@ -99,7 +104,10 @@ if (!success) { Swift: ```swift do { - try runner.generate("Once upon a time", sequenceLength: 50) { token in + try runner.generate("Once upon a time", Config { + $0.temperature = 0.8 + $0.sequenceLength = 2048 + }) { token in print("Generated token:", token) } } catch { @@ -121,6 +129,136 @@ Swift: runner.stop() ``` +#### Resetting + +To clear the prefilled tokens from the KV cache and reset generation stats, call: + +Objective-C: +```objectivec +[runner reset]; +``` + +Swift: +```swift +runner.reset() +``` + +### MultimodalRunner + +The `ExecuTorchLLMMultimodalRunner` class (bridged to Swift as `MultimodalRunner`) provides an interface for loading and running multimodal models that can accept a sequence of text, image, and audio inputs. + +#### Multimodal Inputs + +Inputs are provided as an array of `ExecuTorchLLMMultimodalInput` (or `MultimodalInput` in Swift). You can create inputs from String for text, `ExecuTorchLLMImage` for images (`Image` in Swift), and `ExecuTorchLLMAudio` for audio features (`Audio`) in Swift. 
+ +Objective-C: +```objectivec +ExecuTorchLLMMultimodalInput *textInput = [ExecuTorchLLMMultimodalInput inputWithText:@"What's in this image?"]; + +NSData *imageData = ...; // Your raw image bytes +ExecuTorchLLMImage *image = [[ExecuTorchLLMImage alloc] initWithData:imageData width:336 height:336 channels:3]; +ExecuTorchLLMMultimodalInput *imageInput = [ExecuTorchLLMMultimodalInput inputWithImage:image]; +``` + +Swift: +```swift +let textInput = MultimodalInput("What's in this image?") + +let imageData: Data = ... // Your raw image bytes +let image = Image(data: imageData, width: 336, height: 336, channels: 3) +let imageInput = MultimodalInput(image) + +let audioFeatureData: Data = ... // Your raw audio feature bytes +let audio = Audio(float: audioFeatureData, batchSize: 1, bins: 128, frames: 3000) +let audioInput = MultimodalInput(audio) +``` + +#### Initialization + +Create a runner by specifying the paths to your multimodal model and its tokenizer. + +Objective-C: +```objectivec +NSString *modelPath = [[NSBundle mainBundle] pathForResource:@"llava" ofType:@"pte"]; +NSString *tokenizerPath = [[NSBundle mainBundle] pathForResource:@"llava_tokenizer" ofType:@"bin"]; + +ExecuTorchLLMMultimodalRunner *runner = [[ExecuTorchLLMMultimodalRunner alloc] initWithModelPath:modelPath + tokenizerPath:tokenizerPath]; +``` + +Swift: +```swift +let modelPath = Bundle.main.path(forResource: "llava", ofType: "pte")! +let tokenizerPath = Bundle.main.path(forResource: "llava_tokenizer", ofType: "bin")! + +let runner = MultimodalRunner(modelPath: modelPath, tokenizerPath: tokenizerPath) +``` + +#### Loading + +Explicitly load the model before generation. + +Objective-C: +```objectivec +NSError *error = nil; +BOOL success = [runner loadWithError:&error]; +if (!success) { + NSLog(@"Failed to load: %@", error); +} +``` + +Swift: +```swift +do { + try runner.load() +} catch { + print("Failed to load: \(error)") +} +``` + +#### Generating + +Generate tokens from an ordered array of multimodal inputs. + +Objective-C: +```objectivec +NSArray *inputs = @[textInput, imageInput]; + +ExecuTorchLLMConfig *config = [[ExecuTorchLLMConfig alloc] initWithBlock:^(ExecuTorchLLMConfig *c) { + c.sequenceLength = 768; +}]; + +NSError *error = nil; +BOOL success = [runner generateWithInputs:inputs + config:config + tokenCallback:^(NSString *token) { + NSLog(@"Generated token: %@", token); + } + error:&error]; +if (!success) { + NSLog(@"Generation failed: %@", error); +} +``` + +Swift: +```swift +let inputs = [textInput, imageInput] + +do { + try runner.generate(inputs, Config { + $0.sequenceLength = 768 + }) { token in + print("Generated token:", token) + } +} catch { + print("Generation failed:", error) +} +``` + +#### Stopping and Resetting + +The stop and reset methods for `MultimodalRunner` behave identically to those on `TextRunner`. + ## Demo Get hands-on with our [etLLM iOS Demo App](https://github.com/meta-pytorch/executorch-examples/tree/main/llm/apple) to see the LLM runtime APIs in action. diff --git a/docs/source/llm/run-with-c-plus-plus.md b/docs/source/llm/run-with-c-plus-plus.md index f987fcab2a5..217afad847b 100644 --- a/docs/source/llm/run-with-c-plus-plus.md +++ b/docs/source/llm/run-with-c-plus-plus.md @@ -10,7 +10,7 @@ Before you begin, make sure you have: - Please also see [Model Metadata](#model-metadata) section for important metadata to be serialized into `.pte`. 2. 
A tokenizer file compatible with your model - For HuggingFace tokenizers, this is a JSON file `tokenizer.json` - - For SentencePiece tokenizers, this is is a `tokenizer.model` file and normally live alongside the weights file + - For SentencePiece tokenizers, this is a `tokenizer.model` file and normally lives alongside the weights file 3. CMake and a C++ compiler installed - CMake version 3.29 or higher - g++ or clang compiler diff --git a/docs/source/llm/working-with-llms.md b/docs/source/llm/working-with-llms.md new file mode 100644 index 00000000000..e4088efd12b --- /dev/null +++ b/docs/source/llm/working-with-llms.md @@ -0,0 +1,19 @@ +(working-with-llms)= + +# LLMs + +Learn how to export LLM models and deploy them across different platforms and runtime environments. This section covers the complete workflow from model export to running inference on mobile devices and edge hardware. + + +```{toctree} +:maxdepth: 1 +:caption: Working with LLMs + +getting-started +export-llm +export-llm-optimum +export-custom-llm +run-with-c-plus-plus +build-run-llama3-qualcomm-ai-engine-direct-backend +run-on-ios +``` diff --git a/docs/source/new-contributor-guide.md b/docs/source/new-contributor-guide.md index d2074a3379f..ec5e67afc87 100644 --- a/docs/source/new-contributor-guide.md +++ b/docs/source/new-contributor-guide.md @@ -103,13 +103,6 @@ Before you can start writing any code, you need to get a copy of ExecuTorch code * The `origin` entries show your forked GitHub repository. They tell you that when you run `git pull` or `git push`, your changes will go from/to your GitHub fork. * The `upstream` entries show the main ExecuTorch repository. If you want to sync the latest changes from there, you can run `git fetch upstream`. - - Let's sync from both your fork _and_ the main ExecuTorch branch, getting the latest changes from each of them. To do this, run: - - ```bash - git fetch --all --prune - ``` - 4. If you just cloned your fork, your GitHub repository will tell you your branch is up-to-date: ![](_static/img/new-contributor-guide/synced_fork.png) diff --git a/docs/source/pico2_tutorial.md b/docs/source/pico2_tutorial.md new file mode 100644 index 00000000000..6718e63d05a --- /dev/null +++ b/docs/source/pico2_tutorial.md @@ -0,0 +1,199 @@ +# Pico2: A simple MNIST Tutorial + +Deploy your PyTorch models directly to Raspberry Pi Pico2 microcontroller with ExecuTorch. + +## What You'll Build + +A 28×28 MNIST digit classifier running on memory constrained, low power microcontrollers: + +- Input: ASCII art digits (0, 1, 4, 7) +- Output: Real-time predictions via USB serial +- Memory: <400KB total footprint + +## Prerequisites + +- [Environment Setup section](https://docs.pytorch.org/executorch/1.0/using-executorch-building-from-source.html) + +- Refer to this link on how to accept 'EULA' agreement and setup toolchain [link](https://docs.pytorch.org/executorch/1.0/backends-arm-ethos-u.html#development-requirements) + +- Verify ARM toolchain + +```bash +which arm-none-eabi-gcc # --> arm/arm-scratch/arm-gnu-toolchain-13.3.rel1-x86_64-arm-none-eabi/bin/ +``` + +## Step 1: Generate pte from given example Model + +- Use the [provided example model](https://github.com/pytorch/executorch/blob/main/examples/raspberry_pi/pico2/export_mlp_mnist.py) + +```bash +python export_mlp_mnist.py # Creates balanced_tiny_mlp_mnist.pte +``` + +- **Note:** This is hand-crafted MNIST Classifier (proof-of-concept), and not production trained. 
This tiny MLP recognizes digits 0, 1, 4, and 7 using manually designed feature detectors. + +## Step 2: Build Firmware for Pico2 + +```bash +# Generate model (Creates balanced_tiny_mlp_mnist.pte) +cd ./examples/raspberry_pi/pico2 +python export_mlp_mnist.py +cd - + +# Build Pico2 firmware (one command!) + +./examples/raspberry_pi/pico2/build_firmware_pico.sh --model=balanced_tiny_mlp_mnist.pte # This creates executorch_pico.uf2, a firmware image for Pico2 +``` + +Output: **executorch_pico.uf2** firmware file (examples/raspberry_pi/pico2/build/) + +**Note:** '[build_firmware_pico.sh](https://github.com/pytorch/executorch/blob/main/examples/raspberry_pi/pico2/build_firmware_pico.sh)' script converts given model pte to hex array and generates C code for the same via this helper [script](https://github.com/pytorch/executorch/blob/main/examples/raspberry_pi/pico2/pte_to_array.py). This C code is then compiled to generate final .uf2 binary which is then flashed to Pico2. + +## Step 3: Flash to Pico2 + +Hold BOOTSEL button on Pico2 +Connect USB → Mounts as ^RPI-RP2^ drive +Drag & drop ^executorch_pico.uf2^ file +Release BOOTSEL → Pico2 reboots with your model + +## Step 4: Verify Deployment + +**Success indicators:** + +- LED blinks 10× at 500ms → Model running ✅ +- LED blinks 10× at 100ms → Error, check serial ❌ + +**View predictions:** + +```bash +# Connect serial terminal +screen /dev/tty.usbmodem1101 115200 +# Expected output: + +Something like: + +=== Digit 7 === +############################ +############################ + #### + #### + #### + #### + #### + #### + #### + #### + #### + #### + #### + #### + #### + #### + #### + #### + #### + #### + #### + #### + #### + #### + #### + #### +#### +### + +Input stats: 159 white pixels out of 784 total +Running neural network inference... +✅ Neural network results: + Digit 0: 370.000 + Digit 1: 0.000 + Digit 2: -3.000 + Digit 3: -3.000 + Digit 4: 860.000 + Digit 5: -3.000 + Digit 6: -3.000 + Digit 7: 1640.000 ← PREDICTED + Digit 8: -3.000 + Digit 9: -3.000 + +� PREDICTED: 7 (Expected: 7) ✅ CORRECT! +``` + +## Memory Optimization Tips + +### Pico2 Constraints + +- 520KB SRAM (runtime memory) +- 4MB Flash (model storage) +- Keep models small: + +### Common Issues + +- "Memory allocation failed" → Reduce model size and use quantization +- "Operator missing" → Use selective build: ^--operators=add,mul,relu^ +- "Import error" → Check ^arm-none-eabi-gcc^ toolchain setup. 
+ +In order to resolve some of the issues above, refer to the following guides: + +- [ExecuTorch Quantization Optimization Guide](https://docs.pytorch.org/executorch/1.0/quantization-optimization.html) +- [Model Export & Lowering](https://docs.pytorch.org/executorch/1.0/using-executorch-export.html) and +- [Selective Build support](https://docs.pytorch.org/executorch/1.0/kernel-library-selective-build.html) + +### Firmware Size Analysis + +```bash +cd +ls -al examples/raspberry_pi/pico2/build/executorch_pico.elf +``` + +- **Overall section sizes** + +```bash +arm-none-eabi-size -A examples/raspberry_pi/pico2/build/executorch_pico.elf +``` + +- **Detailed section breakdown** + +```bash +arm-none-eabi-objdump -h examples/raspberry_pi/pico2/build/executorch_pico.elf +``` + +- **Symbol sizes (largest consumers)** + +```bash +arm-none-eabi-nm --print-size --size-sort --radix=d examples/raspberry_pi/pico2/build/executorch_pico.elf | tail -20 +``` + +### Model Memory Footprint + +- **Model data specifically** + +```bash +arm-none-eabi-nm --print-size --size-sort --radix=d examples/raspberry_pi/pico2/build/executorch_pico.elf | grep -i model +``` + +- **Check what's in .bss (uninitialized data)** + +```bash +arm-none-eabi-objdump -t examples/raspberry_pi/pico2/build/executorch_pico.elf | grep ".bss" | head -10 +``` + +- **Memory map overview** + +```bash +arm-none-eabi-readelf -l examples/raspberry_pi/pico2/build/executorch_pico.elf +``` + +## Next Steps + +### Scale up your deployment + +- Use real production trained model +- Optimize further → INT8 quantization, pruning + +### Happy Inference! + +**Result:** PyTorch model → Pico2 deployment in 4 simple steps 🚀 +Total tutorial time: ~15 minutes + +**Conclusion:** Real-time inference on memory constrained, low power microcontrollers, a complete PyTorch → ExecuTorch → Pico2 demo MNIST deployment diff --git a/docs/source/platforms-desktop.md b/docs/source/platforms-desktop.md new file mode 100644 index 00000000000..ba22786576f --- /dev/null +++ b/docs/source/platforms-desktop.md @@ -0,0 +1,23 @@ +# Desktop & Laptop + +ExecuTorch supports desktop and laptop deployment across Linux, macOS, and Windows. + +## Platform-Specific Guides +- [C++ Runtime Integration](using-executorch-cpp) - Complete setup guide +- [Building from Source](using-executorch-building-from-source) + +## Available Backends by Platform + +### Linux +- [XNNPACK (CPU)](backends/xnnpack/xnnpack-overview.md) +- [OpenVINO (Intel)](build-run-openvino) +- [ARM Ethos-U (ARM64)](backends-arm-ethos-u) + +### macOS +- [CoreML (recommended)](backends-coreml) +- [MPS (Apple Silicon)](backends-mps) +- [XNNPACK (CPU)](backends/xnnpack/xnnpack-overview.md) + +### Windows +- [XNNPACK (CPU)](backends/xnnpack/xnnpack-overview.md) +- [OpenVINO (Intel)](build-run-openvino) diff --git a/docs/source/platforms-embedded.md b/docs/source/platforms-embedded.md new file mode 100644 index 00000000000..5ea248fc0d9 --- /dev/null +++ b/docs/source/platforms-embedded.md @@ -0,0 +1,19 @@ +# Embedded Platforms + +ExecuTorch supports embedded devices from microcontrollers to edge devices. 
+ +## Platform-Specific Guides +- [C++ Runtime Integration](using-executorch-cpp) - Complete setup guide +- [Building from Source](using-executorch-building-from-source) + +## Available Backends by Device Type + +### Microcontrollers +- [Cadence Xtensa Backend](backends-cadence) +- [ARM Ethos-U NPU Backend](backends-arm-ethos-u) +- [Custom Backend Development](backend-delegates-integration) + +### Edge Devices +- [ARM Ethos-U NPU Backend](backends-arm-ethos-u) +- [NXP eIQ Neutron Backend](backend-nxp) +- [Custom Hardware Integration](backend-delegates-integration) diff --git a/docs/source/ptd-file-format.md b/docs/source/ptd-file-format.md index 6381e8a071c..c7bad1f34c0 100644 --- a/docs/source/ptd-file-format.md +++ b/docs/source/ptd-file-format.md @@ -111,7 +111,7 @@ The flatbuffer-encoded metadata follows the headers and contains: ### Tensor Layout If a data segment contains a canonical tensor, it may have associated layout information: -- **Scalar type**: Data type (float32, int32, etc.) using ExecutorTorch scalar types. +- **Scalar type**: Data type (float32, int32, etc.) using ExecuTorch scalar types. - **Sizes**: Dimensions of the tensor. - **Dim order**: Memory layout order specifying how dimensions are arranged in memory. diff --git a/docs/source/quantization-optimization.md b/docs/source/quantization-optimization.md new file mode 100644 index 00000000000..d2005b3adac --- /dev/null +++ b/docs/source/quantization-optimization.md @@ -0,0 +1,20 @@ +(quantization-optimization)= + +# Quantization & Optimization + +Advanced techniques for model compression and performance optimization. + +## Quantization Strategies + +- {doc}`quantization-overview` — Comprehensive quantization strategies and techniques + +## Performance Optimization + +- {doc}`runtime-profiling` — Performance profiling and optimization techniques + +```{toctree} +:hidden: +:maxdepth: 1 + +quantization-overview +runtime-profiling diff --git a/docs/source/quantization-overview.md b/docs/source/quantization-overview.md index fdceee80e8e..81b15f6c8bb 100644 --- a/docs/source/quantization-overview.md +++ b/docs/source/quantization-overview.md @@ -14,7 +14,7 @@ Quantization in ExecuTorch is backend-specific. Each backend defines how models The PT2E quantization workflow has three main steps: 1. Configure a backend-specific quantizer. -2. Prepare, calibrate, convert, and evalute the quantized model in PyTorch +2. Prepare, calibrate, convert, and evaluate the quantized model in PyTorch 3. Lower the model to the target backend ## 1. Configure a Backend-Specific Quantizer @@ -28,8 +28,8 @@ These quantizers usually support configs that allow users to specify quantizatio Not all quantization options are supported by all backends. 
Consult backend-specific guides for supported quantization modes and configuration, and how to initialize the backend-specific PT2E quantizer: -* [XNNPACK quantization](backends-xnnpack.md#quantization) -* [CoreML quantization](backends-coreml.md#quantization) +* [XNNPACK quantization](backends/xnnpack/xnnpack-quantization.md) +* [CoreML quantization](backends/coreml/coreml-quantization.md) * [QNN quantization](backends-qualcomm.md#step-2-optional-quantize-your-model) diff --git a/docs/source/quantization.md b/docs/source/quantization.md new file mode 100644 index 00000000000..b5ee9f21897 --- /dev/null +++ b/docs/source/quantization.md @@ -0,0 +1,7 @@ +# Quantization + +```{toctree} +:maxdepth: 1 + +quantization-overview +``` diff --git a/docs/source/quick-start-section.md b/docs/source/quick-start-section.md new file mode 100644 index 00000000000..b6940d2acef --- /dev/null +++ b/docs/source/quick-start-section.md @@ -0,0 +1,38 @@ +(quick-start-section)= +# Quick Start + +Get started with ExecuTorch in just a few steps. + +This section walks you through the essential steps to get ExecuTorch up and running, from initial setup to exporting your first model for edge deployment. + +## What You'll Learn + +Follow these guides in order to get started with ExecuTorch: + +- **{doc}`getting-started`** - Initial Setup: Set up your development environment and run your first ExecuTorch example. + +- **{doc}`using-executorch-export`** - Exporting your model: Export for Edge deployment. + +- **{doc}`using-executorch-building-from-source`** - Building from Source: Build ExecuTorch from source for custom configurations and development. + +## Prerequisites + +- Python 3.10-3.13 +- PyTorch 2.9+ +- Basic familiarity with PyTorch model development + +## Next Steps + +After completing the quick start, explore: + +- **{doc}`edge-platforms-section`** - Deploy to specific platforms (Android, iOS, Desktop, Embedded) +- **{doc}`backends-section`** - Choose the right acceleration backend for your hardware + +```{toctree} +:hidden: +:maxdepth: 2 +:caption: Quick Start Guide + +getting-started +using-executorch-export +using-executorch-building-from-source diff --git a/docs/source/raspberry_pi_llama_tutorial.md b/docs/source/raspberry_pi_llama_tutorial.md new file mode 100644 index 00000000000..1e886db694a --- /dev/null +++ b/docs/source/raspberry_pi_llama_tutorial.md @@ -0,0 +1,394 @@ +# ExecuTorch on Raspberry Pi + +## TLDR + +This tutorial demonstrates how to deploy **Llama models on Raspberry Pi 4/5 devices** using ExecuTorch: + +- **Prerequisites**: Linux host machine, Python 3.10-3.13, conda environment, Raspberry Pi 4/5 +- **Setup**: Automated cross-compilation using `setup.sh` script for ARM toolchain installation +- **Export**: Convert Llama models to optimized `.pte` format with quantization options +- **Deploy**: Transfer binaries to Raspberry Pi and configure runtime libraries +- **Optimize**: Build optimization and performance tuning techniques +- **Result**: Efficient on-device Llama inference + +## Prerequisites and Hardware Requirements + +### Host Machine Requirements + +**Operating System**: Linux x86_64 (Ubuntu 20.04+ or CentOS Stream 9+) + +**Software Dependencies**: + +- **Python 3.10-3.13** (ExecuTorch requirement) +- **conda** or **venv** for environment management +- **CMake 3.29.6+** +- **Git** for repository cloning + +### Target Device Requirements + +**Supported Devices**: **Raspberry Pi 4** and **Raspberry Pi 5** with **64-bit OS** + +**Memory Requirements**: + +- **RAM & Storage** Varies 
by model size and optimization level +- **64-bit Raspberry Pi OS** (Bullseye or newer) + +### Verification Commands + +Verify your host machine compatibility: +```bash +# Check OS and architecture +uname -s # Should output: Linux +uname -m # Should output: x86_64 + +# Check Python version +python3 --version # Should be 3.10-3.13 + +# Check required tools +hash cmake git md5sum 2>/dev/null || echo "Missing required tools" + +cmake --version # Should be 3.29.6+ at minimum + +## Development Environment Setup + +### Clone ExecuTorch Repository + +First, clone the ExecuTorch repository with the Raspberry Pi support: + +```bash +# Create project directory +mkdir ~/executorch-rpi && cd ~/executorch-rpi && git clone -b release/1.0 https://github.com/pytorch/executorch.git && +cd executorch +``` + +### Create Conda Environment + +```bash +# Create conda environment +conda create -yn executorch python=3.10.0 +conda activate executorch + +# Upgrade pip +pip install --upgrade pip +``` + +Alternative: Virtual Environment +If you prefer Python's built-in virtual environment: + +```bash +python3 -m venv .venv +source .venv/bin/activate +pip install --upgrade pip +``` + +Refer to → {doc}`getting-started` for more details. + +## Cross-Compilation Toolchain Step + +Run the following script on your Linux host machine: + +```bash +# Run the Raspberry Pi setup script for Pi 5 +examples/raspberry_pi/setup.sh pi5 +``` + +On successful completion, you should see the following output: + +```bash +[100%] Linking CXX executable llama_main +[100%] Built target llama_main +[SUCCESS] LLaMA runner built successfully + +==== Verifying Build Outputs ==== +[SUCCESS] ✓ llama_main (6.1M) +[SUCCESS] ✓ libllama_runner.so (4.0M) +[SUCCESS] ✓ libextension_module.a (89K) - static library + +✓ ExecuTorch cross-compilation setup completed successfully! +``` + +## Model Preparation and Export + +### Download Llama Models + +Download the Llama model from Hugging Face or any other source, and make sure that following files exist. + +- consolidated.00.pth (model weights) +- params.json (model config) +- tokenizer.model (tokenizer) + +### Export Llama to ExecuTorch Format + +After downloading the Llama model, export it to ExecuTorch format using the provided script: + +```bash + +#### Set these paths to point to the exported files. Following is an example instruction to export a llama model + +LLAMA_QUANTIZED_CHECKPOINT=path/to/consolidated.00.pth +LLAMA_PARAMS=path/to/params.json + +python -m extension.llm.export.export_llm \ + --config examples/models/llama/config/llama_xnnpack_spinquant.yaml \ + +base.model_class="llama3_2" \ + +base.checkpoint="${LLAMA_QUANTIZED_CHECKPOINT:?}" \ + +base.params="${LLAMA_PARAMS:?}" +``` + +The file llama3_2.pte will be generated at the place where you run the command + +- For more details see [Option A: Download and Export Llama3.2 1B/3B Model](https://github.com/pytorch/executorch/blob/main/examples/models/llama/README.md#option-a-download-and-export-llama32-1b3b-model) +- Also refer to → {doc}`llm/export-llm` for more details. 
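Before transferring anything to the Raspberry Pi, it can be worth a quick host-side check that the exported `.pte` file deserializes cleanly. This is a minimal sketch that assumes the `executorch.runtime` Python bindings and their `Runtime.get()` / `load_program()` entry points are available in the environment used for export; actual inference still happens on the Pi through `llama_main`.

```python
from executorch.runtime import Runtime

# Load the exported program with the host-side Python runtime (sanity check only).
runtime = Runtime.get()
program = runtime.load_program("llama3_2.pte")

# If the program parses, its methods (typically "forward") are listed here.
print("Methods:", list(program.method_names))
```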
+
+## Raspberry Pi Deployment
+
+### Transfer Binaries to Raspberry Pi
+
+After successful cross-compilation, transfer the required files:
+
+```bash
+##### Set Raspberry Pi details
+export RPI_UN="pi" # Your Raspberry Pi username
+export RPI_IP="your-rpi-ip-address"
+
+##### Create deployment directory on Raspberry Pi
+ssh $RPI_UN@$RPI_IP 'mkdir -p ~/executorch-deployment'
+##### Copy main executable
+scp cmake-out/examples/models/llama/llama_main $RPI_UN@$RPI_IP:~/executorch-deployment/
+##### Copy runtime library
+scp cmake-out/examples/models/llama/runner/libllama_runner.so $RPI_UN@$RPI_IP:~/executorch-deployment/
+##### Copy model file
+scp llama3_2.pte $RPI_UN@$RPI_IP:~/executorch-deployment/
+scp ./tokenizer.model $RPI_UN@$RPI_IP:~/executorch-deployment/
+```
+
+### Configure Runtime Libraries on Raspberry Pi
+
+SSH into your Raspberry Pi and configure the runtime:
+
+#### Set up library environment
+
+```bash
+cd ~/executorch-deployment
+echo 'export LD_LIBRARY_PATH=$(pwd):$LD_LIBRARY_PATH' > setup_env.sh
+chmod +x setup_env.sh
+
+#### Make executable
+
+chmod +x llama_main
+```
+
+## Dry Run
+
+```bash
+source setup_env.sh
+./llama_main --help
+```
+
+Make sure the output does not contain any GLIBC or other library mismatch errors. If you see any, follow the troubleshooting steps below.
+
+## Troubleshooting
+
+### Issue 1: GLIBC Version Mismatch
+
+**Problem:** The binary was compiled with a newer GLIBC version (2.38) than what's available on your Raspberry Pi (2.36).
+
+**Error Symptoms:**
+
+```bash
+./llama_main: /lib/aarch64-linux-gnu/libm.so.6: version `GLIBC_2.38' not found (required by ./llama_main)
+./llama_main: /lib/aarch64-linux-gnu/libc.so.6: version `GLIBC_2.38' not found (required by ./llama_main)
+./llama_main: /lib/aarch64-linux-gnu/libstdc++.so.6: version `CXXABI_1.3.15' not found (required by ./llama_main)
+./llama_main: /lib/aarch64-linux-gnu/libc.so.6: version `GLIBC_2.38' not found (required by /lib/libllama_runner.so)
+```
+
+**There are two potential solutions:**
+
+- **Solution A**: Modify the Pi to match the binary (run on Pi)
+
+- **Solution B**: Modify the binary to match the Pi (run on host)
+
+#### Solution A: Upgrade GLIBC on Raspberry Pi (Recommended)
+
+1. **Check your current GLIBC version:**
+
+```bash
+ldd --version
+# Output: ldd (Debian GLIBC 2.36-9+rpt2+deb12u12) 2.36
+```
+
+2. **⚠️ Compatibility Warning and Safety Check:**
+
+```bash
+# Just check and warn - don't do the upgrade
+current_glibc=$(ldd --version | head -n1 | grep -o '[0-9]\+\.[0-9]\+')
+required_glibc="2.38"
+
+echo "Current GLIBC: $current_glibc"
+echo "Required GLIBC: $required_glibc"
+
+if [[ $(echo "$current_glibc < $required_glibc" | bc -l) -eq 1 ]]; then
+    echo ""
+    echo "⚠️  WARNING: Your GLIBC version is too old"
+    echo "   You need to upgrade to continue with the next steps"
+    echo "   Consider using Solution B (rebuild binary) for better safety"
+    echo ""
+else
+    echo "✅ Your GLIBC version is already compatible"
+fi
+```
+
+**NOTE:** If the output shows "⚠️ WARNING: Your GLIBC version is too old", proceed with either the upgrade in Step 3 below or Solution B. Otherwise, skip the next step, as your device is __already compatible__, and go directly to Step 4.
+
+3. 
**Upgrade to newer GLIBC:** + +```bash +# Add Debian unstable repository +echo "deb http://deb.debian.org/debian sid main contrib non-free" | sudo tee -a /etc/apt/sources.list + +# Update package lists +sudo apt update + +# Install newer GLIBC packages +sudo apt-get -t sid install libc6 libstdc++6 + +# Reboot system +sudo reboot +``` + +4. **Verify compatibility after reboot:** + +```bash +cd ~/executorch-deployment +source setup_env.sh + +# Test that the binary works +if ./llama_main --help &>/dev/null; then + echo "✅ GLIBC upgrade successful - binary is compatible" +else + echo "❌ GLIBC upgrade failed - binary still incompatible" + echo "Consider rolling back or refer to documentation for troubleshooting" +fi +``` + +5. **Test the fix:** + +```bash +cd ~/executorch-deployment +source setup_env.sh +./llama_main --model_path ./llama3_2.pte --tokenizer_path ./tokenizer.model --seq_len 128 --prompt "Hello" +``` + +**Important Notes:** + +- Select "Yes" when prompted to restart services +- Press Enter to keep current version for configuration files +- Backup important data before upgrading + +#### Solution B: Rebuild with Raspberry Pi's GLIBC (Advanced) + +If you prefer not to upgrade your Raspberry Pi system: + +1. **Copy Pi's filesystem to host machine:** + +```bash +# On Raspberry Pi - install rsync +ssh pi@ +sudo apt update && sudo apt install rsync +exit + +# On host machine - copy Pi's filesystem +mkdir -p ~/rpi5-sysroot +rsync -aAXv --exclude={"/proc","/sys","/dev","/run","/tmp","/mnt","/media","/lost+found"} \ + pi@:/ ~/rpi5-sysroot +``` + +2. **Update CMake toolchain file:** +```bash +# Edit arm-toolchain-pi5.cmake +# Replace this line: +# set(CMAKE_SYSROOT "${TOOLCHAIN_PATH}/aarch64-none-linux-gnu/libc") + +# With this: +set(CMAKE_SYSROOT "/home/yourusername/rpi5-sysroot") +set(CMAKE_FIND_ROOT_PATH "${CMAKE_SYSROOT}") +``` + +3. **Rebuild binaries:** +```bash +# Clean and rebuild +rm -rf cmake-out +./examples/raspberry_pi/rpi_setup.sh pi5 --force-rebuild + +# Verify GLIBC version +strings ./cmake-out/examples/models/llama/llama_main | grep GLIBC_ +# Should show max GLIBC_2.36 (matching your Pi) +``` + +--- + +### Issue 2: Library Not Found + +**Problem:** Required libraries are not found at runtime. + +**Error Symptoms:** +```bash +./llama_main: error while loading shared libraries: libllama_runner.so: cannot open shared object file +``` + +**Solution:** +```bash +# Ensure you're in the correct directory and environment is set +cd ~/executorch-deployment +source setup_env.sh +./llama_main --help +``` + +**Root Cause:** Either `LD_LIBRARY_PATH` is not set or you're not in the deployment directory. + +--- + +### Issue 3: Tokenizer JSON Parsing Warnings + +**Problem:** Warning messages about JSON parsing errors after running the llama_main binary. + +**Error Symptoms:** + +```bash +E tokenizers:hf_tokenizer.cpp:60] Error parsing json file: [json.exception.parse_error.101] +``` + +**Solution:** These warnings can be safely ignored. They don't affect model inference. + +--- + + +## Quick Test Command + +After resolving issues, test with: + +```bash +cd ~/executorch-deployment +source setup_env.sh +./llama_main --model_path ./llama3_2.pte --tokenizer_path ./tokenizer.model --seq_len 128 --prompt "What is the meaning of life?" 
+``` + +## Debugging Tools + +Enable ExecuTorch logging: + +```bash +# Set log level for debugging +export ET_LOG_LEVEL=Info +./llama_main --model_path ./model.pte --verbose +``` + +## Final Run command + +```bash +cd ~/executorch-deployment +source setup_env.sh +./llama_main --model_path ./llama3_2.pte --tokenizer_path ./tokenizer.model --seq_len 128 --prompt "What is the meaning of life?" +``` + +Happy Inferencing! diff --git a/docs/source/running-a-model-cpp-tutorial.md b/docs/source/running-a-model-cpp-tutorial.md index a12ef122bc8..5ae4235995d 100644 --- a/docs/source/running-a-model-cpp-tutorial.md +++ b/docs/source/running-a-model-cpp-tutorial.md @@ -6,13 +6,13 @@ In this tutorial, we will cover how to run an ExecuTorch model in C++ using the For a high level overview of the ExecuTorch Runtime please see [Runtime Overview](runtime-overview.md), and for more in-depth documentation on each API please see the [Runtime API Reference](executorch-runtime-api-reference.rst). -[Here](https://github.com/pytorch/executorch/blob/main/examples/portable/executor_runner/executor_runner.cpp) is a fully functional version C++ model runner, and the [Setting up ExecuTorch](getting-started-setup.md) doc shows how to build and run it. +[Here](https://github.com/pytorch/executorch/blob/main/examples/portable/executor_runner/executor_runner.cpp) is a fully functional version C++ model runner, and the [Setting up ExecuTorch](getting-started-setup.rst) doc shows how to build and run it. ## Prerequisites You will need an ExecuTorch model to follow along. We will be using -the model `SimpleConv` generated from the [Exporting to ExecuTorch tutorial](https://pytorch.org/executorch/main/tutorials/export-to-executorch-tutorial). +the model `SimpleConv` generated from the [Exporting to ExecuTorch tutorial](tutorials/export-to-executorch-tutorial) . ## Model Loading @@ -96,7 +96,7 @@ MemoryManager memory_manager(&method_allocator, &planned_memory); ## Loading a Method -In ExecuTorch we load and initialize from the `Program` at a method granularity. Many programs will only have one method 'forward'. `load_method` is where initialization is done, from setting up tensor metadata, to intializing delegates, etc. +In ExecuTorch we load and initialize from the `Program` at a method granularity. Many programs will only have one method 'forward'. `load_method` is where initialization is done, from setting up tensor metadata, to initializing delegates, etc. ``` cpp Result method = program->load_method(method_name); diff --git a/docs/source/runtime-integration-advanced.md b/docs/source/runtime-integration-advanced.md new file mode 100644 index 00000000000..a76265c4093 --- /dev/null +++ b/docs/source/runtime-integration-advanced.md @@ -0,0 +1,20 @@ +(runtime-integration-advanced)= + +# Runtime & Integration + +Advanced runtime integration topics + +## Platform Integration + +- {doc}`runtime-platform-abstraction-layer` — Platform abstraction layer for cross-platform deployment + +## Portable C++ Programming + +- {doc}`portable-cpp-programming` — Portable C++ programming for cross-platform deployment + +```{toctree} +:hidden: +:maxdepth: 1 + +runtime-platform-abstraction-layer +portable-cpp-programming diff --git a/docs/source/runtime-overview.md b/docs/source/runtime-overview.md index 96a618a2a41..1df3da40478 100644 --- a/docs/source/runtime-overview.md +++ b/docs/source/runtime-overview.md @@ -11,7 +11,7 @@ Works](intro-how-it-works.md). 
At the highest level, the ExecuTorch runtime is responsible for: * Loading binary `.pte` program files that were generated by the - [`to_executorch()`](https://pytorch.org/executorch/main/tutorials/export-to-executorch-tutorial) step of the + [`to_executorch()`](tutorials/export-to-executorch-tutorial) step of the model-lowering process. * Executing the series of instructions that implement a lowered model. diff --git a/docs/source/runtime-profiling.md b/docs/source/runtime-profiling.md index 120d31954fd..56b62de599d 100644 --- a/docs/source/runtime-profiling.md +++ b/docs/source/runtime-profiling.md @@ -20,4 +20,4 @@ We provide access to all the profiling data via the Python [Inspector API](model - Through the Inspector API, users can do a wide range of analysis varying from printing out performance details to doing more finer granular calculation on module level. -Please refer to the [Developer Tools tutorial](https://pytorch.org/executorch/main/tutorials/devtools-integration-tutorial) for a step-by-step walkthrough of the above process on a sample model. +Please refer to the [Developer Tools tutorial](tutorials/devtools-integration-tutorial) for a step-by-step walkthrough of the above process on a sample model. diff --git a/docs/source/runtime.md b/docs/source/runtime.md new file mode 100644 index 00000000000..1d96cc53188 --- /dev/null +++ b/docs/source/runtime.md @@ -0,0 +1,15 @@ +# Runtime + +```{toctree} +:maxdepth: 1 + +runtime-overview +extension-module +extension-tensor +running-a-model-cpp-tutorial +runtime-backend-delegate-implementation-and-linking +runtime-platform-abstraction-layer +portable-cpp-programming +pte-file-format +ptd-file-format +``` diff --git a/docs/source/success-stories.md b/docs/source/success-stories.md new file mode 100644 index 00000000000..5b876437580 --- /dev/null +++ b/docs/source/success-stories.md @@ -0,0 +1,133 @@ +(success-stories)= + +# Success Stories + +Discover how organizations are leveraging ExecuTorch to deploy AI models at scale on edge devices. + +--- + +## Featured Success Stories + +::::{grid} 1 +:gutter: 3 + +:::{grid-item-card} **Meta's Family of Apps** +:class-header: bg-primary text-white + +**Industry:** Social Media & Messaging +**Hardware:** Android & iOS Devices +**Impact:** Billions of users, latency reduction + +Powers Instagram, WhatsApp, Facebook, and Messenger with real-time on-device AI for content ranking, recommendations, and privacy-preserving features at scale. + +[Read Blog →](https://engineering.fb.com/2025/07/28/android/executorch-on-device-ml-meta-family-of-apps/) +::: + +:::{grid-item-card} **Meta Quest & Ray-Ban Smart Glasses** +:class-header: bg-success text-white + +**Industry:** AR/VR & Wearables +**Hardware:** Quest 3, Ray-Ban Meta Smart Glasses, Meta Ray-Ban Display + +Enables real-time computer vision, hand tracking, voice commands, and translation on power-constrained wearable devices. +::: + +:::{grid-item-card} **Liquid AI: Efficient, Flexible On-Device Intelligence** +:class-header: bg-info text-white + +**Industry:** Artificial Intelligence / Edge Computing +**Hardware:** CPU via PyTorch ExecuTorch +**Impact:** 2× faster inference, lower latency, seamless multimodal deployment + +Liquid AI builds foundation models that make AI work where the cloud can't. In its LFM2 series, the team uses PyTorch ExecuTorch within the LEAP Edge SDK to deploy high-performance multimodal models efficiently across devices. 
ExecuTorch provides the flexibility to support custom architectures and processing pipelines while reducing inference latency through graph optimization and caching. Together, they enable faster, more efficient, privacy-preserving AI that runs entirely on the edge. + +[Read Blog →](https://www.liquid.ai/blog/how-liquid-ai-uses-executorch-to-power-efficient-flexible-on-device-intelligence) +::: + +:::{grid-item-card} **PrivateMind: Complete Privacy with On-Device AI** +:class-header: bg-warning text-white + +**Industry:** Privacy & Personal Computing +**Hardware:** iOS & Android Devices +**Impact:** 100% on-device processing + +PrivateMind delivers a fully private AI assistant using ExecuTorch's .pte format. Built with React Native ExecuTorch, it supports LLaMA, Qwen, Phi-4, and custom models with offline speech-to-text and PDF chat capabilities. + +[Visit →](https://privatemind.swmansion.com) +::: + +:::{grid-item-card} **NimbleEdge: On-Device Agentic AI Platform** +:class-header: bg-danger text-white + +**Industry:** AI Infrastructure +**Hardware:** iOS & Android Devices +**Impact:** 30% higher TPS on iOS, faster time-to-market with Qwen/Gemma models + +NimbleEdge successfully integrated ExecuTorch with its open-source DeliteAI platform to enable agentic workflows orchestrated in Python on mobile devices. The extensible ExecuTorch ecosystem allowed implementation of on-device optimization techniques leveraging contextual sparsity. ExecuTorch significantly accelerated the release of "NimbleEdge AI" for iOS, enabling models like Qwen 2.5 with tool calling support and achieving up to 30% higher transactions per second. + +[Visit →](https://nimbleedge.com) • [Blog →](https://www.nimbleedge.com/blog/meet-nimbleedge-ai-the-first-truly-private-on-device-assistant) • [iOS App →](https://apps.apple.com/in/app/nimbleedge-ai/id6746237456) +::: + +:::: + +--- + +## Featured Ecosystem Integrations and Interoperability + +::::{grid} 2 2 3 3 +:gutter: 2 + +:::{grid-item-card} **Hugging Face Transformers** +:class-header: bg-secondary text-white + +Popular models from Hugging Face easily export to ExecuTorch format for on-device deployment. + +[Learn More →](https://github.com/huggingface/optimum-executorch/) +::: + +:::{grid-item-card} **React Native ExecuTorch** +:class-header: bg-secondary text-white + +Declarative toolkit for running AI models and LLMs in React Native apps with privacy-first, on-device execution. + +[Explore →](https://docs.swmansion.com/react-native-executorch/) • [Blog →](https://expo.dev/blog/how-to-run-ai-models-with-react-native-executorch) +::: + +:::{grid-item-card} **torchao** +:class-header: bg-secondary text-white + +PyTorch-native quantization and optimization library for preparing efficient models for ExecuTorch deployment. + +[Blog →](https://pytorch.org/blog/torchao-quantized-models-and-quantization-recipes-now-available-on-huggingface-hub/) • [Qwen Example →](https://huggingface.co/pytorch/Qwen3-4B-INT8-INT4) • [Phi Example →](https://huggingface.co/pytorch/Phi-4-mini-instruct-INT8-INT4) +::: + +:::{grid-item-card} **Unsloth** +:class-header: bg-secondary text-white + +Optimize LLM fine-tuning with faster training and reduced VRAM usage, then deploy efficiently with ExecuTorch. 
+ +[Example Model →](https://huggingface.co/metascroy/Qwen3-4B-int8-int4-unsloth) • [Blog →](https://docs.unsloth.ai/new/quantization-aware-training-qat) +::: + +:::{grid-item-card} **Ultralytics** +:class-header: bg-secondary text-white + +Deploy on-device inference for Ultralytics YOLO models using ExecuTorch. +[Explore →](https://docs.ultralytics.com/integrations/executorch/) +::: + +:::: + +--- + +## Featured Demos + +- **Text and Multimodal LLM demo mobile apps** - Text (Llama, Qwen3, Phi-4) and multimodal (Gemma3, Voxtral) mobile demo apps. [Try →](https://github.com/meta-pytorch/executorch-examples/tree/main/llm) + +- **Voxtral** - Deploy audio-text-input LLM on CPU (via XNNPACK) and on CUDA. [Try →](https://github.com/pytorch/executorch/blob/main/examples/models/voxtral/README.md) + +- **LoRA adapter** - Export two LoRA adapters that share a single foundation weight file, saving memory and disk space. [Try →](https://github.com/meta-pytorch/executorch-examples/tree/main/program-data-separation/cpp/lora_example) + +- **OpenVINO from Intel** - Deploy [Yolo12](https://github.com/pytorch/executorch/tree/main/examples/models/yolo12), [Llama](https://github.com/pytorch/executorch/tree/main/examples/openvino/llama), and [Stable Diffusion](https://github.com/pytorch/executorch/tree/main/examples/openvino/stable_diffusion) on [OpenVINO from Intel](https://www.intel.com/content/www/us/en/developer/articles/community/optimizing-executorch-on-ai-pcs.html). + +*Want to showcase your demo? [Submit here →](https://github.com/pytorch/executorch/issues)* diff --git a/docs/source/support-section.md b/docs/source/support-section.md new file mode 100644 index 00000000000..64c47a3e55b --- /dev/null +++ b/docs/source/support-section.md @@ -0,0 +1,17 @@ +(support-section)= +# Support + +In this section, find answers to common questions, troubleshooting guides, and information on how to contribute to the ExecuTorch project. Get help with issues and learn how to participate in the community. + +- {doc}`using-executorch-faqs` — FAQ +- {doc}`using-executorch-troubleshooting` — Common Issues +- {doc}`contributing` — Contributing + +```{toctree} +:hidden: +:maxdepth: 1 +:caption: Support + +using-executorch-faqs +using-executorch-troubleshooting +contributing diff --git a/docs/source/tools-section.md b/docs/source/tools-section.md new file mode 100644 index 00000000000..c54b4933c44 --- /dev/null +++ b/docs/source/tools-section.md @@ -0,0 +1,32 @@ +(tools-sdk-section)= + +# Tools + +In this section, explore ExecuTorch's comprehensive developer tools for profiling, debugging, and model inspection. These tools help optimize performance and troubleshoot issues during development and deployment. 
+ +- {doc}`devtools-overview` — Developer Tools Overview +- {doc}`bundled-io` — Bundled I/O +- {doc}`etrecord` — ETRecord +- {doc}`etdump` — ETDump +- {doc}`runtime-profiling` — Profiling Suite +- {doc}`model-debugging` — Debugging Tools +- {doc}`model-inspector` — Model Inspector +- {doc}`memory-planning-inspection` — Memory Planning Inspection +- {doc}`devtools-tutorial` — Development Utilities +- {doc}`visualization` — Model Visualization + +```{toctree} +:hidden: +:maxdepth: 1 +:caption: Tools + +devtools-overview +bundled-io +etrecord +etdump +runtime-profiling +model-debugging +model-inspector +memory-planning-inspection +devtools-tutorial +visualization diff --git a/docs/source/tutorial-arm.md b/docs/source/tutorial-arm.md deleted file mode 100644 index 0692b631154..00000000000 --- a/docs/source/tutorial-arm.md +++ /dev/null @@ -1,467 +0,0 @@ -# Arm® Backend Tutorial - - -::::{grid} 2 - -:::{grid-item-card} Tutorials we recommend you complete before this: -:class-card: card-prerequisites -* [Introduction to ExecuTorch](intro-how-it-works.md) -* [Getting Started](getting-started.md) -* [Building ExecuTorch with CMake](using-executorch-building-from-source.md) -::: - -:::{grid-item-card} What you will learn in this tutorial: -:class-card: card-prerequisites -In this tutorial you will learn how to export a simple PyTorch model for ExecuTorch Arm backends. -::: - -:::: - -```{warning} -This delegate is under active development, to get best results please use a recent version. -The TOSA and Ethos(tm) backend support is reasonably mature and used in production by some users. -The VGF backend support is in early development and you may encounter issues. -You may encounter some rough edges and features which may be documented or planned but not implemented, please refer to the in-tree documentation for the latest status of features. -``` - -```{tip} -If you are already familiar with this delegate, you may want to jump directly to the examples: -* [Examples in the ExecuTorch repository](https://github.com/pytorch/executorch/tree/main/examples/arm) -* [Compilation for Ethos-U](https://github.com/pytorch/executorch/blob/main/examples/arm/ethos_u_minimal_example.ipynb) -* [A commandline compiler for example models](https://github.com/pytorch/executorch/blob/main/examples/arm/aot_arm_compiler.py) -``` - -## Prerequisites - -Let's make sure you have everything you need before you get started. - -### Hardware - -To successfully complete this tutorial, you will need a Linux or MacOS host machine with Arm aarch64 or x86_64 processor architecture. - -The target device will be an emulated platform to enable development without a specific development board. This tutorial has guidance for both Ethos-U targets and VGF via the ML SDK for Vulkan®. - -For Ethos-U and Cortex-M, We will be using a [Fixed Virtual Platform (FVP)](https://www.arm.com/products/development-tools/simulation/fixed-virtual-platforms), simulating [Corstone-300](https://developer.arm.com/Processors/Corstone-300)(cs300) and [Corstone-320](https://developer.arm.com/Processors/Corstone-320)(cs320)systems. Since we will be using the FVP (think of it as virtual hardware), we won't be requiring any real embedded hardware for this tutorial. - -For VGF we will be using the [ML SDK for Vulkan(R)](https://github.com/arm/ai-ml-sdk-for-vulkan/)) to emulate the program consumer. - -### Software - -First, you will need to install ExecuTorch. 
Please follow the recommended tutorials if you haven't already, to set up a working ExecuTorch development environment. For the VGF backend it's recommended you [install from source](https://docs.pytorch.org/executorch/stable/using-executorch-building-from-source.html), or from a [nightly](https://download.pytorch.org/whl/nightly/executorch/). - -In addition to this, you need to install a number of SDK dependencies for generating Ethos-U command streams or VGF files. There are scripts which automate this, which are found in the main [ExecuTorch repository](https://github.com/pytorch/executorch/tree/main/examples/arm/). - -## Set Up the Developer Environment - -In this section, we will do a one-time setup of the platform support files needed to run ExecuTorch programs in this tutorial. It is recommended to run the script in a conda or venv environment. - -With a checkout of the ExecuTorch repository, we will use the `examples/arm/setup.sh` script to pull each item in an automated fashion. - -For Ethos-U run: -```bash -./examples/arm/setup.sh --i-agree-to-the-contained-eula -``` - -For VGF run: -```bash -./examples/arm/setup.sh --i-agree-to-the-contained-eula --disable-ethos-u-deps --enable-mlsdk-deps -``` -It is possible to install both sets of dependencies if you omit the disable options. - - -### Notes: - -```{warning} -The `setup.sh` script has generated a `setup_path.sh` script that you need to source whenever you restart your shell. -``` - -i.e. run -`source executorch/examples/arm/ethos-u-scratch/setup_path.sh` - - -To confirm your environment is set up correctly and will enable you to generate .pte's for your target: - -For Ethos-U run: -```bash -# Check for Vela, which converts TOSA to Ethos-U command streams. -which vela -``` - -For VGF run: -```bash -# Check for model-converter, which converts TOSA to ML-SDK VGF format. -which model-converter -``` - -To ensure there's no environment pollution you should confirm these binaries reside within your executorch checkout, under the examples/arm tree. Other versions may present compatibility issues, so this should be corrected by modifying your environment variables such as ${PATH} appropriately. - - -## Convert the PyTorch Model to the `.pte` File - -`.pte` is a binary file produced by ExecuTorch Ahead-of-Time (AoT) pipeline by taking in a PyTorch Model (a torch.nn.Module), exporting it, running a variety of passes, and finally serializing it to a `.pte` file format. This binary file is typically consumed by the ExecuTorch Runtime. This [document](https://github.com/pytorch/executorch/blob/main/docs/source/getting-started-architecture.md) goes in much more depth about the ExecuTorch software stack for both AoT as well as Runtime. - -In this section, we will primarily focus on the AoT flow with the end goal of producing a `.pte` file. There are a set of export configurations to target different backends at runtime. For each, the AoT flow will produce a unique `.pte` file. We will explore a couple of different configurations producing different `.pte` files, particularly interesting for our Corstone-300 system and available processing elements. - -Before we get started, let's first talk about the PyTorch modules we will be using. - -### PyTorch Example Modules -We will use a couple of simple PyTorch Modules to explore the end-to-end flow. These modules will be used in various different ways throughout the tutorial, referring to them by their ``. 
- -#### SoftmaxModule -This is a very simple PyTorch module with just one [Softmax](https://pytorch.org/docs/stable/generated/torch.nn.Softmax.html#torch.nn.Softmax) operator. - -```python -import torch - -class SoftmaxModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.softmax = torch.nn.Softmax() - - def forward(self, x): - z = self.softmax(x) - return z -``` - -Running it using the Python environment (on the same development Linux machine), you get the expected output. - -```python ->>> m = SoftmaxModule() ->>> m(torch.ones(2,2)) -tensor([[0.5000, 0.5000], - [0.5000, 0.5000]]) -``` - -#### AddModule -Let's write another simple PyTorch module with just one [Add](https://pytorch.org/docs/stable/generated/torch.add.html#torch.add) operator. - -```python -class AddModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x + x -``` - -Running it in python shows that 1 + 1 produces 2 as exepected: - -```python ->>> m = AddModule() ->>> m(torch.ones(5, dtype=torch.int32)) # integer types for non-quantized Ethos-U delegation -tensor([2, 2, 2, 2, 2], dtype=torch.int32) -``` -Keep the inputs and outputs to these modules in mind. When you will lower and run this through alternate means as opposed to running on this Linux machine, you will use the same inputs, and expect the outputs to match with the one shown here. - -```{tip} -you need to be aware of data types for running networks on the Ethos-U as it is an integer only co-processor. For this example you use integer types explicitly, for typical use of such a flow networks are built and trained in floating point, and then are quantized from floating point to integer for efficient inference. -``` - -#### MobileNetV2 Module -[MobileNetV2](https://arxiv.org/abs/1801.04381) is a commonly used network for edge and mobile devices. -It's also available as a default model in [torchvision](https://github.com/pytorch/vision), so you can load it with the sample code below. -``` -from torchvision.models import mobilenet_v2 # @manual -from torchvision.models.mobilenetv2 import MobileNet_V2_Weights - -mv2 = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT) -``` -For more details, refer to the code snippet [here](https://github.com/pytorch/executorch/blob/2354945d47f67f60d9a118ea1a08eef8ba2364b5/examples/models/mobilenet_v2/model.py#L18). - -### Non-delegated Workflow - -In the ExecuTorch AoT pipeline, one of the options is to select a backend. ExecuTorch offers a variety of different backends. Selecting backend is optional, it is typically done to target a particular mode of acceleration or hardware for a given model compute requirements. Without any backends, ExecuTorch runtime will fallback to using, available by default, a highly portable set of operators. - -It's expected that on platforms with dedicated acceleration like the Ethos-U55, that the non-delegated flow is used for two primary cases: -1. When the network is designed to be very small and best suited to run on the Cortex-M alone. -2. When the network has a mix of operations that can target the NPU and those that can't, e.g. the Ethos-U55 supports integer operations and so floating point softmax will fall back to execute on the CPU. - -In this flow, without any backend delegates, to illustrate the portability of the ExecuTorch runtime, as well as of the operator library you will skip specifying the backend during the `.pte` generation. - -Following script will serve as a helper utility to help generating the `.pte` file. 
This is available in the `examples/arm` directory. - -```bash -python3 -m examples.arm.aot_arm_compiler --model_name="softmax" -# This should produce ./softmax_arm_ethos-u55-128.pte -``` - -### Delegated Workflow - -Working with Arm, you introduced a new Arm backend delegate for ExecuTorch. This backend is under active development and has a limited set of features available as of writing this. - -By including a following step during the ExecuTorch AoT export pipeline to generate the `.pte` file, you can enable this backend delegate. - -```python -from executorch.backends.arm.arm_backend import generate_ethosu_compile_spec - -graph_module_edge.exported_program = to_backend( - model.exported_program, - ArmPartitioner(generate_ethosu_compile_spec("ethos-u55-128"))) -``` - -Similar to the non-delegate flow, the same script will server as a helper utility to help generate the `.pte` file. Notice the `--delegate` option to enable the `to_backend` call. - -For Ethos targets: -```bash -python3 -m examples.arm.aot_arm_compiler --model_name="add" --delegate -# This targets the default of ethos-u55-128, see --help for further targets -# should produce ./add_arm_delegate_ethos-u55-128.pte -``` - -For basic post-training quantization: -```bash -python3 -m examples.arm.aot_arm_compiler --model_name="mv2" --delegate --quantize -# This targets the default of ethos-u55-128, see --help for further targets -# should produce ./mv2_arm_delegate_ethos-u55-128.pte -``` - - -For VGF targets: -```bash -python3 -m examples.arm.aot_arm_compiler --model_name="add" --target=vgf --delegate -# should produce ./add_arm_delegate_vgf.pte -``` - -For basic post-training quantization: -```bash -python3 -m examples.arm.aot_arm_compiler --model_name="mv2" --target=vgf --delegate --quantize -# should produce ./mv2_arm_delegate_vgf.pte -``` - -To capture intermediates such as VGF for lower level integration, invoke with the "-i" option: -```bash -python3 -m examples.arm.aot_arm_compiler --model_name="mv2" --target=vgf --delegate --quantize -i ./mv2_output -# should produce ./mv2_arm_delegate_vgf.pte and intermediates in ./mv2_out/ -``` - -
- -At the end of this, you should have a number of different `.pte` files. - -- the SoftmaxModule, without any backend delegates. -- the AddModule, targeting the Arm Ethos-U backend. -- the Quantized MV2Model, targeting the Arm Ethos-U backend. -- the AddModule, targeting the VGF backend. -- the Quantized MV2Model, targeting the VGF backend. - -Now let's try to run these `.pte` files on a target. - -## Getting a Bare-Metal Executable - -In this section, you will go over steps that you need to go through to build the runtime application. This then run on the target device. In the executorch repository you have a functioning script which does the exact same steps. It is located at `executorch/examples/arm/run.sh`. You will use that to build necessary pieces and finally run the previously generated PTE file on an FVP. - -By default the `run.sh` will use `arm_test/` as an build and output folder and you will find the build artifacts under it. This can be controlled/overrided with the `--et_build_root` and the `--output` flags if needed. - -e.g. running `examples/arm/run.sh --model_name=add --target=ethos-u85-128` will produce a pte and elf file like this: - -```bash -arm_test/add/add_arm_delegate_ethos-u85-128.pte -arm_test/add/cmake-out/arm_executor_runner -``` -Also before you get started, make sure that you have completed ExecuTorch cmake build setup, and the instructions to setup the development environment described [earlier](#set-up-the-developer-environment). - -The block diagram below demonstrates, at the high level, how the various build artifacts are generated and are linked together to generate the final bare-metal executable. - -![](arm-delegate-runtime-build.svg) - -```{tip} -The `generate_pte_file` function in `run.sh` script produces the `.pte` files based on the models provided through `--model_name` input argument -``` - -### Generating ExecuTorch Libraries - -ExecuTorch's CMake build system produces a set of build pieces which are critical to building the ExecuTorch runtime with-in the bare-metal environment you have for Corstone FVPs from Ethos-U SDK. - -[This](using-executorch-building-from-source.md) document provides a detailed overview of each individual build piece. For running either variant of the `.pte` file, you will need a core set of libraries. Here is a list, - -- `libexecutorch.a` -- `libportable_kernels.a` -- `libportable_ops_lib.a` - -To run a `.pte` file with the Arm backend delegate call instructions, you will need the Arm backend delegate runtime library, that is, - -- `libexecutorch_delegate_ethos_u.a` - -These libraries are generated by the `backends/arm/scripts/build_executorch.sh` script called from the `run.sh` script. - -### Building the executor_runner Bare-Metal Application - -The SDK dir is the same one prepared [earlier](#setup-the-arm-ethos-u-software-development). And, you will be passing the `.pte` file (any one of them) generated above. - -Note, you have to generate a new `executor-runner` binary if you want to change the model or the `.pte` file. This constraint is from the constrained bare-metal runtime environment you have for Corstone-300/Corstone-320 platforms. The build also generates a kernel registration library for the relevant operators which could not be delegated to the EthosU, see the [Kernel Library Selective Build documentation](https://docs.pytorch.org/executorch/stable/kernel-library-selective-build.html). 
- -This step is executed by the build_executor_runner.sh script, which is invoked from the run.sh in the backends/arm/scripts folder. - -```{tip} -The `run.sh` script takes in `--target` option, which provides a way to provide a specific target, Corstone-300(ethos-u55-128) or Corstone-320(ethos-u85-128) -``` - -## Running on Corstone FVP Platforms - -Once the elf is prepared, regardless of the `.pte` file variant is used to generate the bare metal elf. `run.sh` will run the FVP for you via the `backends/arm/scripts/run_fvp.sh` script. - -#### Automatic FVP Selection - -- To run a specific test model with the compiler flag and target -```bash -./run.sh --model_name=mv2 --delegate --quantize --target=ethos-u85-128 -``` - -- To run a specific test model and target -```bash -./run.sh --model_name=mv2 --delegate --target=ethos-u85-128 -``` - -- To run all the test models iteratively in a loop , simply run -```bash -./run.sh -``` - -Note that you could use `build_executor_runner.sh` and `run_fvp.sh` scripts in tandem by passing the relevant --target argument (e.g., --target=ethos-u55-128), the correct FVP binary will be chosen automatically. For more details, see the [section on Runtime Integration](https://docs.pytorch.org/executorch/main/backends-arm-ethos-u.html#runtime-integration). - - -#### Manual FVP Binary Selection - -- If you build for the Ethos delegate U55/U65 target (e.g., using --target=ethos-u55-128 or --target=ethos-u65-256 with `build_executor_runner.sh` and `run_fvp.sh`), you should use the corresponding FVP binary: - - For U55: - ```bash - examples/arm/ethos-u-scratch/FVP-corstone300/models/Linux64_GCC-9.3/FVP_Corstone_SSE-300_Ethos-U55 - ``` - - For U65: - ```bash - examples/arm/ethos-u-scratch/FVP-corstone300/models/Linux64_GCC-9.3/FVP_Corstone_SSE-300_Ethos-U65 - ``` -- And say if you are not building for an Ethos target, use: - ```bash - examples/arm/ethos-u-scratch/FVP-corstone320/models/Linux64_GCC-9.3/FVP_Corstone_SSE-320 - ``` - -Following is an example usage: - -```bash -ethos_u_build_dir=examples/arm/executor_runner/ - -elf=$(find ${ethos_u_build_dir} -name "arm_executor_runner") - -FVP_Corstone_SSE-320 \ - -C mps4_board.subsystem.ethosu.num_macs=128 \ - -C mps4_board.visualisation.disable-visualisation=1 \ - -C vis_hdlcd.disable_visualisation=1 \ - -C mps4_board.telnetterminal0.start_telnet=0 \ - -C mps4_board.uart0.out_file='-' \ - -C mps4_board.uart0.shutdown_on_eot=1 \ - -a "${elf}" \ - --timelimit 120 || true # seconds- after which sim will kill itself -``` - -#### Verification of Successful FVP Execution -After running the FVP command, either automatically or manually, you should see output similar to the following on your shell if the execution is successful: - -```console -I [executorch:arm_executor_runner.cpp:364] Model in 0x70000000 $ -I [executorch:arm_executor_runner.cpp:366] Model PTE file loaded. Size: 4425968 bytes. -I [executorch:arm_executor_runner.cpp:376] Model buffer loaded, has 1 methods -I [executorch:arm_executor_runner.cpp:384] Running method forward -I [executorch:arm_executor_runner.cpp:395] Setup Method allocator pool. Size: 62914560 bytes. -I [executorch:arm_executor_runner.cpp:412] Setting up planned buffer 0, size 752640. -I [executorch:ArmBackendEthosU.cpp:79] ArmBackend::init 0x70000070 -I [executorch:arm_executor_runner.cpp:445] Method loaded. -I [executorch:arm_executor_runner.cpp:447] Preparing inputs... -I [executorch:arm_executor_runner.cpp:461] Input prepared. 
-I [executorch:arm_executor_runner.cpp:463] Starting the model execution... -I [executorch:ArmBackendEthosU.cpp:118] ArmBackend::execute 0x70000070 -I [executorch:ArmBackendEthosU.cpp:298] Tensor input/output 0 will be permuted -I [executorch:arm_perf_monitor.cpp:120] NPU Inferences : 1 -I [executorch:arm_perf_monitor.cpp:121] Profiler report, CPU cycles per operator: -I [executorch:arm_perf_monitor.cpp:125] ethos-u : cycle_cnt : 1498202 cycles -I [executorch:arm_perf_monitor.cpp:132] Operator(s) total: 1498202 CPU cycles -I [executorch:arm_perf_monitor.cpp:138] Inference runtime: 6925114 CPU cycles total -I [executorch:arm_perf_monitor.cpp:140] NOTE: CPU cycle values and ratio calculations require FPGA and identical CPU/NPU frequency -I [executorch:arm_perf_monitor.cpp:149] Inference CPU ratio: 99.99 % -I [executorch:arm_perf_monitor.cpp:153] Inference NPU ratio: 0.01 % -I [executorch:arm_perf_monitor.cpp:162] cpu_wait_for_npu_cntr : 729 CPU cycles -I [executorch:arm_perf_monitor.cpp:167] Ethos-U PMU report: -I [executorch:arm_perf_monitor.cpp:168] ethosu_pmu_cycle_cntr : 5920305 -I [executorch:arm_perf_monitor.cpp:171] ethosu_pmu_cntr0 : 359921 -I [executorch:arm_perf_monitor.cpp:171] ethosu_pmu_cntr1 : 0 -I [executorch:arm_perf_monitor.cpp:171] ethosu_pmu_cntr2 : 0 -I [executorch:arm_perf_monitor.cpp:171] ethosu_pmu_cntr3 : 503 -I [executorch:arm_perf_monitor.cpp:178] Ethos-U PMU Events:[ETHOSU_PMU_EXT0_RD_DATA_BEAT_RECEIVED, ETHOSU_PMU_EXT1_RD_DATA_BEAT_RECEIVED, ETHOSU_PMU_EXT0_WR_DATA_BEAT_WRITTEN, ETHOSU_PMU_NPU_IDLE] -I [executorch:arm_executor_runner.cpp:470] model_pte_loaded_size: 4425968 bytes. -I [executorch:arm_executor_runner.cpp:484] method_allocator_used: 1355722 / 62914560 free: 61558838 ( used: 2 % ) -I [executorch:arm_executor_runner.cpp:491] method_allocator_planned: 752640 bytes -I [executorch:arm_executor_runner.cpp:493] method_allocator_loaded: 966 bytes -I [executorch:arm_executor_runner.cpp:494] method_allocator_input: 602116 bytes -I [executorch:arm_executor_runner.cpp:495] method_allocator_executor: 0 bytes -I [executorch:arm_executor_runner.cpp:498] temp_allocator_used: 0 / 1048576 free: 1048576 ( used: 0 % ) -I [executorch:arm_executor_runner.cpp:152] Model executed successfully. -I [executorch:arm_executor_runner.cpp:156] 1 outputs: -Output[0][0]: -0.749744 -Output[0][1]: -0.019224 -Output[0][2]: 0.134570 -...(Skipped) -Output[0][996]: -0.230691 -Output[0][997]: -0.634399 -Output[0][998]: -0.115345 -Output[0][999]: 1.576386 -I [executorch:arm_executor_runner.cpp:177] Program complete, exiting. -I [executorch:arm_executor_runner.cpp:179] -``` - -```{note} -The `run.sh` script provides various options to select a particular FVP target, use desired models, select portable kernels and can be explored using the `--help` argument -``` - -## Running on the VGF backend with the standard executor_runner for Linux - -Follow typical [Building ExecuTorch with CMake](using-executorch-building-from-source.md) flow to build the linux target, ensuring that the VGF delegate is enabled. 
- -```bash --DEXECUTORCH_BUILD_VGF=ON -``` - -A full example buld line is: -``` -cmake bash \ - -DCMAKE_INSTALL_PREFIX=cmake-out \ - -DCMAKE_BUILD_TYPE=Release \ - -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ - -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ - -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ - -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ - -DEXECUTORCH_BUILD_XNNPACK=OFF \ - -DEXECUTORCH_BUILD_VULKAN=ON \ - -DEXECUTORCH_BUILD_VGF=ON \ - -DEXECUTORCH_ENABLE_LOGGING=ON \ - -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ - -DPYTHON_EXECUTABLE=python \ - -Bcmake-out . -cmake --build cmake-out -j25 --target install --config Release -``` - -You can then invoke the executor runner on the host machine, which will use the VGF delegate, and requires the vulkan layer drivers we installed with setup.sh. - -```bash -./cmake-out/executor_runner -model_path add_arm_delegate_vgf.pte -``` - - -## Takeaways -In this tutorial you have learnt how to use the ExecuTorch software to both export a standard model from PyTorch and to run it on the compact and fully functioned ExecuTorch runtime, enabling a smooth path for offloading models from PyTorch to Arm based platforms. - -To recap, there are two major flows: - * A direct flow which offloads work onto the Cortex-M using libraries built into ExecuTorch. - * A delegated flow which partitions the graph into sections for Cortex-M and sections which can be offloaded and accelerated on the Ethos-U hardware. - -Both of these flows continue to evolve, enabling more use-cases and better performance. - -## FAQs - - -If you encountered any bugs or issues following this tutorial please file a bug/issue here on [Github](https://github.com/pytorch/executorch/issues/new). diff --git a/docs/source/tutorial-template.md b/docs/source/tutorial-template.md index b25731afa17..73b787c9e2c 100644 --- a/docs/source/tutorial-template.md +++ b/docs/source/tutorial-template.md @@ -9,12 +9,12 @@ :::{grid-item-card} Tutorials we recommend you complete before this: :class-card: card-prerequisites * [Introduction to ExecuTorch](intro-how-it-works.md) -* [Setting up ExecuTorch](getting-started-setup.md) -* [Building ExecuTorch with CMake](runtime-build-and-cross-compilation.md) +* [Setting up ExecuTorch](getting-started-setup.rst) +* [Building ExecuTorch with CMake](using-executorch-building-from-source.md) ::: :::: -## Prerequsites (Hardware and Software) +## Prerequisites (Hardware and Software) Provide instructions on what kind of hardware and software are pre-requisite for the tutorial. 
diff --git a/docs/source/tutorial-xnnpack-delegate-lowering.md b/docs/source/tutorial-xnnpack-delegate-lowering.md index bccd4e4add3..5c88246b0ba 100644 --- a/docs/source/tutorial-xnnpack-delegate-lowering.md +++ b/docs/source/tutorial-xnnpack-delegate-lowering.md @@ -11,8 +11,8 @@ In this tutorial, you will learn how to export an XNNPACK lowered Model and run :::{grid-item-card} Before you begin it is recommended you go through the following: :class-card: card-prerequisites * [Setting up ExecuTorch](getting-started-setup.rst) -* [Model Lowering Tutorial](https://pytorch.org/executorch/main/tutorials/export-to-executorch-tutorial) -* [ExecuTorch XNNPACK Delegate](backends-xnnpack.md) +* [Model Lowering Tutorial](tutorials/export-to-executorch-tutorial) +* [ExecuTorch XNNPACK Delegate](backends/xnnpack/xnnpack-overview.md) ::: :::: @@ -74,7 +74,7 @@ After lowering to the XNNPACK Program, we can then prepare it for executorch and ## Lowering a Quantized Model to XNNPACK -The XNNPACK delegate can also execute symmetrically quantized models. To understand the quantization flow and learn how to quantize models, refer to [Custom Quantization](quantization-custom-quantization.md) note. For the sake of this tutorial, we will leverage the `quantize()` python helper function conveniently added to the `executorch/executorch/examples` folder. +The XNNPACK delegate can also execute symmetrically quantized models. To understand the quantization flow and learn how to quantize models, refer to [Quantization Overview](quantization-overview.md). For the sake of this tutorial, we will leverage the `quantize()` python helper function conveniently added to the `executorch/executorch/examples` folder. ```python from torch.export import export diff --git a/docs/source/tutorials_source/bundled_program.bp b/docs/source/tutorials_source/bundled_program.bp deleted file mode 100644 index 8afe3cfee26..00000000000 Binary files a/docs/source/tutorials_source/bundled_program.bp and /dev/null differ diff --git a/docs/source/usage.md b/docs/source/usage.md new file mode 100644 index 00000000000..6ffc136093b --- /dev/null +++ b/docs/source/usage.md @@ -0,0 +1,19 @@ +# Usage + +This section describes how to use Executorch. It covers everything from +getting started to platform-specific implementations, runtime integration, +troubleshooting, and frequently asked questions. + +```{toctree} +:maxdepth: 1 + +getting-started +using-executorch-export +using-executorch-android +using-executorch-ios +using-executorch-cpp +using-executorch-runtime-integration +using-executorch-troubleshooting +using-executorch-building-from-source +using-executorch-faqs +``` diff --git a/docs/source/using-executorch-android.md b/docs/source/using-executorch-android.md index 23513302063..e097722b8e6 100644 --- a/docs/source/using-executorch-android.md +++ b/docs/source/using-executorch-android.md @@ -1,12 +1,20 @@ + # Using ExecuTorch on Android -To use from Android, ExecuTorch provides Java/Kotlin API bindings and Android platform integration, available as an AAR file. +🚀 Quick Start: __New to ExecuTorch__ ? Jump to [Using AAR from Maven Central](#using-aar-from-maven-central) for the fastest setup, then see the [Runtime Integration](#runtime-integration) example. -Note: This page covers Android app integration through the AAR library. The ExecuTorch C++ APIs can also be used from Android native, and the documentation can be found on [this page about cross compilation](using-executorch-building-from-source.md#cross-compilation). 
+To use from Android, ExecuTorch provides Java/Kotlin API bindings and Android platform integration, available as an AAR file. +Note: This page covers Android app integration through the AAR library. The ExecuTorch C++ APIs can also be used from Android native, and the documentation can be found on this page about cross compilation. ## Installation -All ExecuTorch Android libraries are packaged into an [Android library (AAR)](https://developer.android.com/studio/projects/android-library), `executorch.aar` for both generic (image/audio processing) and LLM (LLaMA) use case. In each release, prebuilt AAR artifacts are uploaded to [Maven](https://repo.maven.apache.org/maven2/org/pytorch/executorch-android/) and S3. Users can also build the AAR from source. +__Choose your installation method:__ + +- __[Maven Central](#using-aar-from-maven-central)__ (recommended): Easiest for most developers +- __[Direct AAR file](#using-aar-file-directly)__: For specific versions or offline development +- __[Build from source](#building-from-source)__: For custom backends or contributions + +All ExecuTorch Android libraries are packaged into an Android library (AAR), executorch.aar for both generic (image/audio processing) and LLM (LLaMA) use case. In each release, prebuilt AAR artifacts are uploaded to Maven and S3. Users can also build the AAR from source. ### Contents of library @@ -14,52 +22,63 @@ The AAR artifact contains the Java library for users to integrate with their Jav - [Java library](https://github.com/pytorch/executorch/tree/main/extension/android/executorch_android/src/main/java/org/pytorch/executorch) - JNI contains the JNI binding for the corresponding Java code, and ExecuTorch native library, including - - core ExecuTorch runtime libraries + - Core ExecuTorch runtime libraries - XNNPACK backend - Portable kernels - Optimized kernels - Quantized kernels - LLaMa-specific Custom ops library. -- Comes with two ABI variants, arm64-v8a and x86\_64. +- Comes with two ABI variants, arm64-v8a and x86_64. The AAR library can be used for generic Android device with arm64-v8a or x86_64 architecture. It can be used across form factors, including phones, tablets, tv boxes, etc, as it does not contain any UI components. ## Using AAR from Maven Central -ExecuTorch is available on [Maven Central](https://mvnrepository.com/artifact/org.pytorch/executorch-android). - -Simply add the target [`org.pytorch:executorch-android:${executorch_version}`](https://repo.maven.apache.org/maven2/org/pytorch/executorch-android/${executorch_version}/) to your Android app dependency (build.gradle), and build your app. +✅ Recommended for most developers +ExecuTorch is available on Maven Central. +Simply add the target org.pytorch:executorch-android:${executorch_version} to your Android app dependency (build.gradle), and build your app. For example: -For example: -``` -# app/build.gradle.kts +```kotlin +app/build.gradle.kts dependencies { - implementation("org.pytorch:executorch-android:${executorch_version}") +implementation("org.pytorch:executorch-android:${executorch_version}") } ``` -Note: If you want to use release v0.5.0, please use dependency `org.pytorch:executorch-android:0.5.1`. - -Click the screenshot below to watch the *demo video* on how to add the package and run a simple ExecuTorch model with Android Studio. +Note: If you want to use release v1.0.0, please use dependency org.pytorch:executorch-android:1.0.0. 
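+
+After a Gradle sync, you can confirm which ExecuTorch artifact actually ends up on the app's runtime classpath. This is a minimal sketch that assumes a standard Android project layout with an `app` module and the Gradle wrapper checked in; the module and configuration names may differ in your project.
+
+```bash
+# List the app module's resolved release runtime dependencies and filter for ExecuTorch.
+# The version printed should match the ${executorch_version} declared in build.gradle.kts.
+./gradlew :app:dependencies --configuration releaseRuntimeClasspath | grep -i executorch
+```
+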
+Click the screenshot below to watch the demo video on how to add the package and run a simple ExecuTorch model with Android Studio. - Integrating and Running ExecuTorch on Android +Integrating and Running ExecuTorch on Android ## Using AAR file directly You can also directly specify an AAR file in the app. We upload pre-built AAR to S3 during each release, or as a snapshot. -### Released versions (recommended) +### Latest Released versions (Recommended) + +Starting from [v1.0.0](https://github.com/pytorch/executorch/releases/tag/v1.0.0), there are respective executorch.aar library available by backends + +| AAR | SHASUMS | Backend | +| ------- | --- | ------- | +| [executorch.aar](https://ossci-android.s3.amazonaws.com/executorch/release/1.0.0-xnnpack/executorch.aar) | [executorch.aar.sha256sums](https://ossci-android.s3.amazonaws.com/executorch/release/1.0.0-xnnpack/executorch.aar.sha256sums) | [XNNPACK](backends/xnnpack/xnnpack-overview.md) | +| [executorch.aar](https://ossci-android.s3.amazonaws.com/executorch/release/1.0.0-qnn/executorch.aar) | [executorch.aar.sha256sums](https://ossci-android.s3.amazonaws.com/executorch/release/1.0.0-qnn/executorch.aar.sha256sums) | [Qualcomm AI Engine](backends-qualcomm.md) | +| [executorch.aar](https://ossci-android.s3.amazonaws.com/executorch/release/1.0.0-vulkan/executorch.aar) | [executorch.aar.sha256sums](https://ossci-android.s3.amazonaws.com/executorch/release/1.0.0-vulkan/executorch.aar.sha256sums) | [Vulkan](backends/vulkan/vulkan-overview.md) | + +### Older Released versions + +Download the older released version | Version | AAR | SHASUMS | | ------- | --- | ------- | -| [${executorch_version}](https://github.com/pytorch/executorch/releases/tag/${executorch_version}) | [executorch.aar](https://ossci-android.s3.amazonaws.com/executorch/release/${executorch_version}/executorch.aar) | [executorch.aar.sha256sums](https://ossci-android.s3.amazonaws.com/executorch/release/${executorch_version}/executorch.aar.sha256sums) | +| [v0.7.0](https://github.com/pytorch/executorch/releases/tag/v0.7.0) | [executorch.aar](https://ossci-android.s3.amazonaws.com/executorch/release/v0.7.0/executorch.aar) | [executorch.aar.sha256sums](https://ossci-android.s3.amazonaws.com/executorch/release/v0.7.0/executorch.aar.sha256sums) | | [v0.6.0](https://github.com/pytorch/executorch/releases/tag/v0.6.0) | [executorch.aar](https://ossci-android.s3.amazonaws.com/executorch/release/v0.6.0/executorch.aar) | [executorch.aar.sha256sums](https://ossci-android.s3.amazonaws.com/executorch/release/v0.6.0/executorch.aar.sha256sums) | | [v0.5.0](https://github.com/pytorch/executorch/releases/tag/v0.5.0) | [executorch.aar](https://ossci-android.s3.amazonaws.com/executorch/release/v0.5.0-rc3/executorch.aar) | [executorch.aar.sha256sums](https://ossci-android.s3.amazonaws.com/executorch/release/v0.5.0-rc3/executorch.aar.sha256sums) | ### Snapshots from main branch Starting from 2025-04-12, you can download nightly `main` branch snapshots: + * `executorch.aar`: `https://ossci-android.s3.amazonaws.com/executorch/release/snapshot-{YYYYMMDD}/executorch.aar` * `executorch.aar.sha256sums`: `https://ossci-android.s3.amazonaws.com/executorch/release/snapshot-{YYYYMMDD}/executorch.aar.sha256sums` * Replace `YYYYMMDD` with the actual date you want to use. 
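+
+Whether you download a release or a snapshot AAR, it is worth checking the file against the published checksums before adding it to your app. A minimal sketch, assuming `executorch.aar` and `executorch.aar.sha256sums` were downloaded into the current directory and the checksum file lists the AAR by its plain file name:
+
+```bash
+# Verify the downloaded AAR against the published checksum file;
+# sha256sum exits with a non-zero status if the hash does not match.
+sha256sum -c executorch.aar.sha256sums && echo "AAR checksum OK"
+```
+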
@@ -72,35 +91,37 @@ curl -O https://ossci-android.s3.amazonaws.com/executorch/release/snapshot-20250 curl -O https://ossci-android.s3.amazonaws.com/executorch/release/snapshot-20250412/executorch.aar.sha256sums ``` -We aim to make every daily snapshot available and useable. However, for best stability, please use releases, not snapshots. +We aim to make every daily snapshot available and usable. However, for best stability, please use releases, not snapshots. ## Using AAR file To add the AAR file to your app: -1. Download the AAR. -2. Add it to your gradle build rule as a file path. +Download the AAR. +Add it to your gradle build rule as a file path. +An AAR file itself does not contain dependency info, unlike the Maven one which bundled with pom.xml. The Java package requires fbjni and soloader, and currently requires users to explicitly declare the dependency. Therefore, two more dependencies in gradle rule is required: -An AAR file itself does not contain dependency info, unlike the Maven one which bundled with pom.xml. The Java package requires `fbjni` and `soloader`, and currently requires users to explicitly declare the dependency. Therefore, two more `dependencies` in gradle rule is required: -``` +```kotlin implementation("com.facebook.soloader:soloader:0.10.5") -implementation("com.facebook.fbjni:fbjni:0.5.1") +implementation("com.facebook.fbjni:fbjni:0.7.0") ``` ### Example usage -In your app working directory, such as executorch/examples/demo-apps/android/LlamaDemo, -``` +In your app working directory, such as executorch-examples/llm/android/LlamaDemo, + +```sh mkdir -p app/libs curl https://ossci-android.s3.amazonaws.com/executorch/release/${executorch_version}/executorch.aar -o app/libs/executorch.aar ``` And include it in gradle: -``` -# app/build.gradle.kts + +```kotlin +app/build.gradle.kts dependencies { - implementation(files("libs/executorch.aar")) - implementation("com.facebook.soloader:soloader:0.10.5") - implementation("com.facebook.fbjni:fbjni:0.5.1") +implementation(files("libs/executorch.aar")) +implementation("com.facebook.soloader:soloader:0.10.5") +implementation("com.facebook.fbjni:fbjni:0.7.0") } ``` @@ -108,52 +129,62 @@ Now you can compile your app with the ExecuTorch Android library. ## Building from Source -`scripts/build_android_library.sh` is a helper script to build the Java library (into .jar), native library (into .so), and the packaged AAR file. - -You need Android [SDK](https://developer.android.com/studio) and [NDK](https://developer.android.com/ndk/downloads) to use it. - -Current NDK version used in ExecuTorch CI: r27b. +```text +scripts/build_android_library.sh +``` -You need to set `ANDROID_HOME` to Android SDK home and `ANDROID_NDK` to the correct NDK root (containing NOTICE file). +is a helper script to build the Java library (into .jar), native library (into .so), and the packaged AAR file. +You need Android SDK and NDK to use it. +Current NDK version used in ExecuTorch CI: r28c. +You need to set ANDROID_HOME to Android SDK home and ANDROID_NDK to the correct NDK root (containing NOTICE file). -``` +```sh export ANDROID_HOME=/path/to/sdk export ANDROID_NDK=/path/to/ndk sh scripts/build_android_library.sh ``` -Currently, XNNPACK backend is always built with the script. +NOTE: Currently, XNNPACK backend is always built with the script. ### Optional environment variables -Optionally, set these environment variables before running `build_android_library.sh`. +Optionally, set these environment variables before running build_android_library.sh. 
-#### ANDROID_ABIS -Set environment variable `ANDROID_ABIS` to either `arm64-v8a` or `x86_64` if you only need to build the native library for one ABI only. -``` +- __ANDROID_ABIS__ + +Set the environment variable `ANDROID_ABIS` to either `arm64-v8a` or `x86_64` if you only need to build the native library for one ABI. + +```sh export ANDROID_ABIS=arm64-v8a -# or -# export ANDROID_ABIS=x86_64 +``` + +Or: + +```sh +export ANDROID_ABIS=x86_64 +``` + +Then run the script: + +```sh sh scripts/build_android_library.sh ``` -#### EXECUTORCH_CMAKE_BUILD_TYPE -Set environment variable `EXECUTORCH_CMAKE_BUILD_TYPE` to `Release` or `Debug` based on your needs. +- __EXECUTORCH_CMAKE_BUILD_TYPE__ + +Set the environment variable `EXECUTORCH_CMAKE_BUILD_TYPE` to `Release` or `Debug` based on your needs. -#### Using MediaTek backend +- __Using MediaTek backend__ -To use [MediaTek backend](backends-mediatek.md), -after installing and setting up the SDK, set `NEURON_BUFFER_ALLOCATOR_LIB` and `NEURON_USDK_ADAPTER_LIB` to the corresponding path. +To use the MediaTek backend, after installing and setting up the SDK, set `NEURON_BUFFER_ALLOCATOR_LIB` and `NEURON_USDK_ADAPTER_LIB` to the corresponding paths. -#### Using Qualcomm AI Engine Backend +- __Using Qualcomm AI Engine Backend__ -To use [Qualcomm AI Engine Backend](backends-qualcomm.md#qualcomm-ai-engine-backend), -after installing and setting up the SDK, set `QNN_SDK_ROOT` to the corresponding path. +To use the Qualcomm AI Engine Backend, after installing and setting up the SDK, set `QNN_SDK_ROOT` to the corresponding path. -#### Using Vulkan Backend +- __Using Vulkan Backend__ -To use [Vulkan Backend](backends-vulkan.md#vulkan-backend), -set `EXECUTORCH_BUILD_VULKAN` to `ON`. +To use the Vulkan Backend, set `EXECUTORCH_BUILD_VULKAN` to `ON`. ## Android Backends @@ -161,11 +192,12 @@ The following backends are available for Android: | Backend | Type | Doc | | ------- | -------- | --- | -| [XNNPACK](https://github.com/google/XNNPACK) | CPU | [Doc](backends-xnnpack.md) | +| [XNNPACK](https://github.com/google/XNNPACK) | CPU | [Doc](backends/xnnpack/xnnpack-overview.md) | | [MediaTek NeuroPilot](https://neuropilot.mediatek.com/) | NPU | [Doc](backends-mediatek.md) | | [Qualcomm AI Engine](https://www.qualcomm.com/developer/software/qualcomm-ai-engine-direct-sdk) | NPU | [Doc](backends-qualcomm.md) | -| [Vulkan](https://www.vulkan.org/) | GPU | [Doc](backends-vulkan.md) | +| [Vulkan](https://www.vulkan.org/) | GPU | [Doc](backends/vulkan/vulkan-overview.md) | +Start with XNNPACK (CPU backend) for maximum compatibility, then add hardware-specific backends for optimization.
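To tie the options above together, here is a hedged sketch of a from-source AAR build that enables one of the optional backends. It assumes the SDK/NDK paths shown earlier and that `EXECUTORCH_BUILD_VULKAN` is picked up by the script as described; treat it as an illustration rather than the canonical invocation:

```sh
# Assumed paths; point these at your actual SDK/NDK installs.
export ANDROID_HOME=/path/to/sdk
export ANDROID_NDK=/path/to/ndk

# Optional settings described above (the combination is an assumption for this sketch):
export ANDROID_ABIS=arm64-v8a              # build a single ABI
export EXECUTORCH_CMAKE_BUILD_TYPE=Release # release build for performance
export EXECUTORCH_BUILD_VULKAN=ON          # add the Vulkan backend on top of XNNPACK

sh scripts/build_android_library.sh
```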
## Runtime Integration @@ -175,26 +207,27 @@ Here is an example code sample in Java that demonstrates how to integrate ExecuT import org.pytorch.executorch.EValue; import org.pytorch.executorch.Module; import org.pytorch.executorch.Tensor; - public class MainActivity extends Activity { - private Module module; - - @Override - protected void onCreate(Bundle savedInstanceState) { - super.onCreate(savedInstanceState); - // Load the ExecuTorch module - Module module = Module.load("/data/local/tmp/add.pte"); - Tensor tensor1 = Tensor.fromBlob(new float[] {1.0f}, new long[] {1}); - Tensor tensor2 = Tensor.fromBlob(new float[] {20.0f}, new long[] {1}); - - EValue eValue1 = EValue.from(tensor1); - EValue eValue2 = EValue.from(tensor2); - float result = module.forward(eValue1, eValue2)[0].toTensor().getDataAsFloatArray()[0]; - } +private Module module; + +@Override +protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + // Load the ExecuTorch module + module = Module.load("/data/local/tmp/add.pte"); + + Tensor tensor1 = Tensor.fromBlob(new float[] {1.0f}, new long[] {1}); + Tensor tensor2 = Tensor.fromBlob(new float[] {20.0f}, new long[] {1}); + + EValue eValue1 = EValue.from(tensor1); + EValue eValue2 = EValue.from(tensor2); + + float result = module.forward(eValue1, eValue2)[0].toTensor().getDataAsFloatArray()[0]; } ``` -Push the corresponding pte file to the phone: +Push the corresponding pte file to your Android device: + ```sh adb push extension/module/test/resources/add.pte /data/local/tmp/ ``` @@ -202,7 +235,7 @@ adb push extension/module/test/resources/add.pte /data/local/tmp/ This example loads an ExecuTorch module, prepares input data, runs inference, and processes the output data. Please use [DeepLabV3AndroidDemo](https://github.com/meta-pytorch/executorch-examples/tree/main/dl3/android/DeepLabV3Demo) -and [LlamaDemo](https://github.com/pytorch/executorch/tree/main/examples/demo-apps/android/LlamaDemo) for the code examples +and [LlamaDemo](https://github.com/meta-pytorch/executorch-examples/tree/main/llm/android/LlamaDemo) for the code examples using ExecuTorch AAR package. ## Java API reference diff --git a/docs/source/using-executorch-building-from-source.md b/docs/source/using-executorch-building-from-source.md index d48f9d26db7..8e1772086de 100644 --- a/docs/source/using-executorch-building-from-source.md +++ b/docs/source/using-executorch-building-from-source.md @@ -5,6 +5,7 @@ Even if you don't use CMake directly, CMake can emit scripts for other format like Make, Ninja or Xcode. For information, see [cmake-generators(7)](https://cmake.org/cmake/help/latest/manual/cmake-generators.7.html). ## System Requirements + ### Operating System ExecuTorch is tested on the following systems, although it should also work in similar environments.
@@ -16,27 +17,31 @@ ExecuTorch is tested on the following systems, although it should also work in s * macOS (x86_64/ARM64) * Big Sur (11.0)+ * Windows (x86_64) + * Windows 10+ with Visual Studio 2022+ and [Clang-CL](https://learn.microsoft.com/en-us/cpp/build/clang-support-msbuild?view=msvc-170) * Windows Subsystem for Linux (WSL) with any of the Linux options - * Windows 10+ with Visual Studio 2022+ (experimental) ### Software Requirements + * `conda` or another virtual environment manager - `conda` is recommended as it provides cross-language support and integrates smoothly with `pip` (Python's built-in package manager) - Otherwise, Python's built-in virtual environment manager `python venv` is a good alternative. * `g++` version 7 or higher, `clang++` version 5 or higher, or another C++17-compatible toolchain. -* `python` version 3.10-3.12 -* `Xcode Command Line Tools` (macOS only) +* `python` version 3.10-3.13 * `ccache` (optional) - A compiler cache that speeds up recompilation +* **macOS** + - `Xcode Command Line Tools` +* **Windows** + - `Visual Studio Clang Tools` - See [Clang/LLVM support in Visual Studio](https://learn.microsoft.com/en-us/cpp/build/clang-support-msbuild?view=msvc-170). -Additional dependencies will be installed automatically when running the [Python installation](#building-the-python-package). +Additional dependencies will be automatically installed when running the [Python installation](#building-the-python-package). Note that the cross-compilable core runtime code supports a wider range of -toolchains, down to C++17. See the [Runtime Overview](runtime-overview.md) for +toolchains, down to C++17. See [Runtime Overview](runtime-overview.md) for portability details. ## Environment Setup - Clone the ExecuTorch repository from GitHub and create a conda environment as follows. Venv can be used in place on conda. + Clone the ExecuTorch repository from GitHub and create a conda environment. Venv can be used in place of conda. ```bash git clone -b viable/strict https://github.com/pytorch/executorch.git cd executorch @@ -44,6 +49,13 @@ portability details. conda activate executorch ``` +> **_NOTE:_** Additional Windows Setup +> +> ExecuTorch requires symlinks to be enabled to build the Python components. To enable symlinks, run the following command before cloning the repository. Missing symlinks will manifest as an error related to `version.py` when running `pip install .`. See [src/README.md](https://github.com/pytorch/executorch/blob/main/src/README.md) for more information. +> ```bash +> git config --system core.symlinks true +> ``` +
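As noted above, venv can be used in place of conda. A minimal sketch of that alternative, run from inside the cloned `executorch` directory (the environment name `.venv` is arbitrary):

```bash
# Instead of the conda commands above: create and activate a built-in venv.
python3 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip
```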
## Building the Python package @@ -60,7 +72,7 @@ portability details. * `--clean`: Removes build artifacts. * `--editable`: Install the ExecuTorch python package in editable mode (see [Editable Install](#editable-install)). * `--minimal`: Install only the minimal set of dependencies required to run ExecuTorch. Do not install dependencies for examples. - * `--use-pt-pinned-commit`: Install the pinned PyTorch commit. When not specified, the latest PyTorch nightly build is installed. + * `--use-pt-pinned-commit`: Install the pinned PyTorch commit or release version. When not specified, the latest PyTorch nightly build is installed. For Intel-based macOS systems, use `--use-pt-pinned-commit --minimal`. As PyTorch does not provide pre-built binaries for Intel Mac, installation requires building PyTorch from source. Instructions can be found in [PyTorch Installation](https://github.com/pytorch/pytorch#installation). @@ -71,6 +83,13 @@ portability details. CMAKE_ARGS="-DEXECUTORCH_BUILD_MPS=ON" ./install_executorch.sh ``` +### Verify the Build + +To verify that the Python components are installed correctly, run the following command. This will create a file named `mv2_xnnpack_fp32.pte` in the current directory for the MobileNet V2 model with the XNNPACK backend. If it completes without error, the ExecuTorch Python components are installed successfully. +```bash +python -m executorch.examples.xnnpack.aot_compiler --model_name="mv2" --delegate +``` + ### Editable Install For development, include the `--editable` flag, which allows for local changes to ExecuTorch Python code to be reflected without a re-install. Note that when C++ files are modified, you will need to re-run the full installation to reflect the changes. ```bash @@ -112,47 +131,39 @@ portability details. ## Building the C++ Runtime -The ExecuTorch C++ runtime is built using CMake. It can be compiled standalone to run examples, added as a CMake dependency, or cross-compiled for Android, iOS, or embedded platforms. +The ExecuTorch runtime uses CMake as the build system. When using ExecuTorch from C++ user code with CMake, adding ExecuTorch as a submodule and referencing via CMake `add_subdirectory` will build the runtime as part of the user build. -### Configuring +When user code is not using CMake, the runtime can be built standalone and linked. The CMake options described below apply in both cases. Scripts are also provided for [Android AAR](#cross-compiling-for-android) and [iOS framework](#cross-compiling-for-ios) builds. -Configuration should be done after cloning, pulling the upstream repo, or changing build options. Once this is done, you won't need to do it again until you pull from the upstream repo or modify any CMake-related files. +| Use Case | How to Build | +| :------------------------- | :--------------------------------------------------------------------------------- | +| C++ with user CMake | Use CMake `add_subdirectory`. | +| C++ without user CMake | Build ExecuTorch standalone with CMake. Link the libraries with the user build. | +| Android with Java/Kotlin | Use [scripts/build_android_library.sh](#cross-compiling-for-android). | +| Android with C++ | Follow C++ build steps, [cross-compile for Android](#cross-compiling-for-android). | +| iOS | Use [scripts/build_ios_frameworks.sh](#cross-compiling-for-ios). | -```bash -# cd to the root of the executorch repo -cd executorch - -# Clean and configure the CMake build system. It's good practice to do this -# whenever cloning or pulling the upstream repo.
-./install_executorch.sh --clean -(mkdir cmake-out && cd cmake-out && cmake ..) -``` +### Configuring -### Building +Configuration should be done after cloning, pulling the upstream repo, or changing build options. Once this is done, you won't need to do it again until you pull from the upstream repo or modify any CMake-related files. -Build all targets with `cmake --build`. +When building as a submodule as part of a user CMake build, ExecuTorch CMake options can be specified either as part of the user CMake configuration or in user CMake code. +CMake configuration for standalone runtime build: ```bash -# cd to the root of the executorch repo -cd executorch - -# Build using the configuration that you previously generated under the -# `cmake-out` directory. -# -# NOTE: The `-j` argument specifies how many jobs/processes to use when -# building, and tends to speed up the build significantly. It's typical to use -# "core count + 1" as the `-j` value. -cmake --build cmake-out -j9 +mkdir cmake-out +cmake -B cmake-out --preset [preset] [options] +cmake --build cmake-out -j10 ``` -> **_TIP:_** For faster rebuilds, consider installing ccache (see [Compiler Cache section](#compiler-cache-ccache) above). On first builds, ccache populates its cache. Subsequent builds with the same compiler flags can be significantly faster. - -### Build Presets +#### Build Presets -ExecuTorch provides fine-grained control over what is built, as described in [Build Options](#build-options). These options are grouped into CMake presets to cover common scenarios, while providing the ability to override individual options. Presets can be specified when configuring CMake by specifying `--preset [name]` when configuring. +ExecuTorch provides fine-grained control over what is built, as described in [Build Options](#build-options). These options are grouped into CMake presets to cover common scenarios while preserving the ability to override individual options. Presets can be specified when configuring CMake by specifying `--preset [name]` when configuring. Preset values for common scenarios are listed below. Using a platform preset is recommended to avoid needing to specify many fine-grained build options. + * `android-arm64-v8a` - Build features and backends common for arm64-v8a Android targets. + * `android-x86_64` - Build features and backends common for x86_64 Android targets. * `arm-baremetal` - Build for bare-metal ARM targets. * `ios` - Build features and backends common for iOS targets. * `macos` - Build features and backends common for Mac targets. @@ -161,77 +172,34 @@ Preset values for common scenarios are listed below. Using a platform preset is * `profiling` - Build the ExecuTorch runtime with profiling enabled. * `zephyr` - Build for Zephyr RTOS. +User CMake: +```cmake +set(EXECUTORCH_BUILD_PRESET_FILE ${CMAKE_SOURCE_DIR}/executorch/tools/cmake/preset/llm.cmake) +``` + +Standalone build: ```bash # Configure the build with the ios preset. cmake .. --preset ios ``` -### CMake Targets and Libraries - -To link against the ExecuTorch framework from CMake, the following top-level targets are exposed: - - * `executorch::backends`: Contains all configured backends. - * `executorch::extensions`: Contains all configured extensions. - * `executorch::kernels`: Contains all configured kernel libraries. - -The backends, extensions, and kernels included in these targets are controlled by the various `EXECUTORCH_` CMake options specified by the build. 
Using these targets will automatically pull in the required dependencies to use the configured features. - -### Running an Example Model +#### Build Options -The example `executor_runner` binary can be used to run a model and sanity-check the build. Run the following commands to generate and run a simple model. -You should see the message "Model executed successfully" followed by the output values. +CMake options can be used for fine-grained control of build type, control which features are built, and configure functionality, such as logging. Options are typically specified during CMake configuration. Default values of each option are set by the active preset, but can be overridden by specifying the option when configuring. -``` bash -python -m examples.portable.scripts.export --model_name="add" -./cmake-out/executor_runner --model_path add.pte -``` +Note that many build options require other options to be enabled. This may require enabling multiple options to enable a given feature. The CMake build output will provide an error message when a required option is not enabled. +User CMake: +```cmake +set(EXECUTORCH_BUILD_XNNPACK ON) ``` -I 00:00:00.000526 executorch:executor_runner.cpp:82] Model file add.pte is loaded. -I 00:00:00.000595 executorch:executor_runner.cpp:91] Using method forward -I 00:00:00.000612 executorch:executor_runner.cpp:138] Setting up planned buffer 0, size 48. -I 00:00:00.000669 executorch:executor_runner.cpp:161] Method loaded. -I 00:00:00.000685 executorch:executor_runner.cpp:171] Inputs prepared. -I 00:00:00.000764 executorch:executor_runner.cpp:180] Model executed successfully. -I 00:00:00.000770 executorch:executor_runner.cpp:184] 1 outputs: -Output 0: tensor(sizes=[1], [2.]) -``` - -### Compiler Cache (ccache) - -ExecuTorch automatically detects and enables [ccache](https://ccache.dev/) if it's installed. This significantly speeds up recompilation by caching previously compiled objects: - -- If ccache is detected, you'll see: `ccache found and enabled for faster builds` -- If ccache is not installed, you'll see: `ccache not found, builds will not be cached` - -To install ccache: +Standalone build: ```bash -# Ubuntu/Debian -sudo apt install ccache - -# macOS -brew install ccache - -# CentOS/RHEL -sudo yum install ccache -# or -sudo dnf install ccache +cmake .. -DEXECUTORCH_BUILD_XNNPACK=ON ``` -No additional configuration is needed - the build system will automatically use ccache when available. - -See [CMakeLists.txt](https://github.com/pytorch/executorch/blob/main/CMakeLists.txt) - -
- -## Build Options - -CMake options can be used to for fine-grained control of build type, control which features are built, and configure functionality, such as logging. Options are typically specified during CMake configuration. Default values of each option are set by the active preset, but can be overridden by specifying the option when configuring. - -Note that many build options require other options to be enabled. This may require enabling multiple options to enable a given feature. The CMake build output will provide an error message when a required option is not enabled. - -#### Build Type +##### Build Type The CMake build is typically set to `Debug` or `Release`. For production use or profiling, release mode should be used to improve performance and reduce binary size. It disables program verification and executorch logging and adds optimizations flags. The `EXECUTORCH_OPTIMIZE_SIZE` flag can be used to further optimize for size with a small performance tradeoff. @@ -240,7 +208,7 @@ The CMake build is typically set to `Debug` or `Release`. For production use or cmake .. -DCMAKE_BUILD_TYPE=Release ``` -#### Backends +##### Backends Typically, each hardware backend exposes a CMake option to control whether the backend is built. See backend-specific documentation for more details. @@ -260,7 +228,7 @@ Typically, each hardware backend exposes a CMake option to control whether the b cmake .. -DEXECUTORCH_BUILD_XNNPACK=ON -DEXECUTORCH_BUILD_VULKAN=ON ``` -#### Extensions +##### Extensions ExecuTorch extensions provide optional functionality outside of the core runtime. As the core runtime is designed to run in constrained environments, these features are typically disabled by default. Extensions include higher-level APIs (Module and Tensor), multi-threading support (Threadpool), training, and more. @@ -281,7 +249,7 @@ ExecuTorch extensions provide optional functionality outside of the core runtime cmake .. -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON ``` -#### Logging +##### Logging Logging is enabled by default in debug builds and disabled in release. When enabled, the default log level is Info. Both log enable and level can be overriden with options. See [Logging](using-executorch-runtime-integration.md#logging). Disabling logging and decreasing log verbosity will reduce binary size by stripping unused strings from the build. @@ -293,7 +261,39 @@ Logging is enabled by default in debug builds and disabled in release. When enab cmake .. -DEXECUTORCH_ENABLE_LOGGING=ON -DEXECUTORCH_LOG_LEVEL=debug ``` -#### Output Libraries +### Building + +Build all targets with `cmake --build`. + +```bash +# cd to the root of the executorch repo +cd executorch + +# Build using the configuration that you previously generated under the +# `cmake-out` directory. +# +# NOTE: The `-j` argument specifies how many jobs/processes to use when +# building, and tends to speed up the build significantly. It's typical to use +# "core count + 1" as the `-j` value. +cmake --build cmake-out -j9 +``` + +> **_TIP:_** For faster rebuilds, consider installing ccache (see [Compiler Cache section](#compiler-cache-ccache) above). On first builds, ccache populates its cache. Subsequent builds with the same compiler flags can be significantly faster. + +
+ + +## CMake Targets and Output Libraries + +To link against the ExecuTorch framework from CMake, the following top-level targets are exposed: + + * `executorch::backends`: Contains all configured backends. + * `executorch::extensions`: Contains all configured extensions. + * `executorch::kernels`: Contains all configured kernel libraries. + +The backends, extensions, and kernels included in these targets are controlled by the various `EXECUTORCH_` CMake options specified by the build. Using these targets will automatically pull in the required dependencies to use the configured features. + +### Linking Without CMake To link against the runtime from outside of the CMake ecosystem, the runtime can be first built with CMake and then linked directly. A few of the relevant top-level targets are described below. Note that this is a more involved process than using CMake and is only recommended when using CMake is not viable. @@ -312,6 +312,26 @@ To link against the runtime from outside of the CMake ecosystem, the runtime can Backends typically introduce additional targets. See backend-specific documentation for more details. +### Verify the Build + +To verify the build, ExecuTorch optionally compiles a simple, stand-alone model runner to run PTE files with all-one input tensors. It is not enabled by default in most presets, but can be enabled by configuring with `-DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON -DEXECUTORCH_BUILD_EXTENSION_EVALUE_UTIL=ON`. + +Once compiled, invoke the runner with a sample PTE (such as the one generated by [verifying the Python build](#verify-the-build)). +```bash +cmake-out/executor_runner --model_path=mv2_xnnpack_fp32.pte +``` + +If the runner runs successfully, you should see output similar to the following: +``` +I 00:00:00.043703 executorch:executor_runner.cpp:379] Model executed successfully 1 time(s) in 15.013292 ms. +I 00:00:00.043720 executorch:executor_runner.cpp:383] 1 outputs: +Output 0: tensor(sizes=[1, 1000], [ + -0.509859, 0.300644, 0.0953884, 0.147724, 0.231202, 0.338554, 0.206888, -0.0575762, -0.389273, -0.0606864, + ..., + 0.421219, 0.100447, -0.506771, -0.115824, -0.693017, -0.183262, 0.154781, -0.410684, 0.0119296, 0.449713, +]) +``` +
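As a sketch of the "C++ with user CMake" flow described above (submodule plus `add_subdirectory`, then linking the exported targets), a user `CMakeLists.txt` might look roughly like the following; the project name, target name, and submodule path are placeholders:

```cmake
cmake_minimum_required(VERSION 3.24)
project(my_app LANGUAGES CXX)

# Assumption: ExecuTorch is vendored as a git submodule at third-party/executorch.
add_subdirectory(third-party/executorch)

add_executable(my_app main.cpp)

# Link against the top-level targets listed above; what they contain is
# controlled by the EXECUTORCH_* options set when configuring this build.
target_link_libraries(my_app PRIVATE
    executorch::backends
    executorch::extensions
    executorch::kernels)
```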
## Cross-Compiling for Android @@ -325,8 +345,7 @@ Backends typically introduce additional targets. See backend-specific documentat ### Building the AAR -With the NDK installed, the `build_android_library.sh` script will build the ExecuTorch Java AAR. This file contains the ExecuTorch Java bindings -and native code. See [Using the AAR File](using-executorch-android.md#using-aar-file) for usage. +With the NDK installed, the `build_android_library.sh` script will build the ExecuTorch Java AAR, which contains ExecuTorch Java bindings. See [Using the AAR File](using-executorch-android.md#using-aar-file) for usage. ```bash export ANDROID_ABIS=arm64-v8a @@ -335,36 +354,21 @@ mkdir -p $BUILD_AAR_DIR sh scripts/build_android_library.sh ``` -### Building the Example Runner +### Android Native -The native executor runner can be cross-compiled for android and deployed via ADB. This step is intended as -an example of CMake cross compilation and is not necessary for integration into an app. +To use the ExecuTorch runtime from native Android C++ code, the runtime can be cross-compiled for Android. The recommended approach is to add ExecuTorch as a submodule of the user project and use [CMake](https://developer.android.com/ndk/guides/cmake) for the native build. The above steps for C++ with CMake can be followed. +For direct cross-compilation, the ExecuTorch runtime can be configured to build with the NDK toolchain: ```bash -# Run the following lines from the `executorch/` folder -./install_executorch.sh --clean -mkdir cmake-android-out && cd cmake-android-out - # point -DCMAKE_TOOLCHAIN_FILE to the location where ndk is installed -cmake -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a .. - -cd .. -cmake --build cmake-android-out -j9 - -adb shell mkdir -p /data/local/tmp/executorch -# push the binary to an Android device -adb push cmake-android-out/executor_runner /data/local/tmp/executorch -# push the model file -adb push add.pte /data/local/tmp/executorch - -adb shell "/data/local/tmp/executorch/executor_runner --model_path /data/local/tmp/executorch/add.pte" +cmake -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a .. ```
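The configure command above stops short of actually compiling. A hedged continuation, assuming a separate `cmake-android-out` build directory and an NDK install pointed to by `$ANDROID_NDK`:

```bash
# Configure into a dedicated build directory with the NDK toolchain (as above), then build.
cmake -B cmake-android-out \
    -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
    -DANDROID_ABI=arm64-v8a
cmake --build cmake-android-out -j10
```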
## Cross-Compiling for iOS -For iOS, we'll build [frameworks](https://developer.apple.com/documentation/xcode/creating-a-multi-platform-binary-framework-bundle) instead of static libraries. The frameworks contain the compiled ExecuTorch runtime and public headers. +iOS binaries are built as [frameworks](https://developer.apple.com/documentation/xcode/creating-a-multi-platform-binary-framework-bundle) instead of static libraries. The frameworks contain the compiled ExecuTorch runtime and public headers. ### Pre-requisites @@ -385,119 +389,36 @@ xcode-select --install ``` Run the above command with `--help` flag to learn more on how to build additional backends -(like [Core ML](backends-coreml.md), [MPS](backends-mps.md) or XNNPACK), etc. +(like [Core ML](backends/coreml/coreml-overview.md), [MPS](backends/mps/mps-overview.md) or XNNPACK), etc. Note that some backends may require additional dependencies and certain versions of Xcode and iOS. See backend-specific documentation for more details. 2. Copy over the generated `.xcframework` bundles to your Xcode project, link them against your targets and don't forget to add an extra linker flag `-all_load`. -Check out the [iOS Demo App](https://github.com/meta-pytorch/executorch-examples/tree/main/mv3/apple/ExecuTorchDemo) tutorial for more info. - -
- -## Building on Windows - -ExecuTorch provides experimental support for native Windows builds. - -> **_NOTE:_** All commands should be executed on Windows powershell in administrator mode. - -### Environment Setup - -#### Pre-requisites +See the [iOS Demo App](https://github.com/meta-pytorch/executorch-examples/tree/main/mv3/apple/ExecuTorchDemo) tutorial for example usage of the ExecuTorch frameworks. -1. Install miniconda for Windows from the [official website](https://docs.conda.io/en/latest/miniconda.html). -2. Install Git for Windows from the [official website](https://git-scm.com/download/win). -3. Install ClangCL for Windows from the [official website](https://learn.microsoft.com/en-us/cpp/build/clang-support-msbuild?view=msvc-170) or through a [Visual Studio](https://learn.microsoft.com/en-us/cpp/build/clang-support-msbuild?view=msvc-170) or [Visual Studio Code](https://code.visualstudio.com/docs/cpp/config-clang-mac) installation. +## Compiler Cache (ccache) -#### Clone and Configure Environment - -```bash -git config --global core.symlinks true -git clone --recurse -submodules https://github.com/pytorch/executorch.git -cd executorch -conda create -yn et python=3.12 -conda activate et -``` - -If Conda is not available, run conda-hook.ps1, where `$miniconda_dir` is the directory where miniconda is installed. -This is `“C:\Users\\AppData\Local”` by default. - -```bash -$miniconda_dir\\shell\\condabin\\conda-hook.ps1 -``` - -### Build the Python Package - -Run `install_executorch.bat` to build and install the ExecuTorch Python package and runtime bindings. - -```bash -cd executorch -./install_executorch.bat -``` - -> **_NOTE_** Many components are not currently buildable on Windows. These instructions install a very minimal ExecuTorch which can be used as a sanity check. +ExecuTorch automatically detects and enables [ccache](https://ccache.dev/) if it's installed. This significantly speeds up recompilation by caching previously compiled objects: -### Build the C++ Runtime +- If ccache is detected, you'll see: `ccache found and enabled for faster builds` +- If ccache is not installed, you'll see: `ccache not found, builds will not be cached` +To install ccache: ```bash -del -Recurse -Force cmake-out; ` -cmake . ` - -DCMAKE_INSTALL_PREFIX=cmake-out ` - -DPYTHON_EXECUTABLE=$miniconda_dir\\envs\\et\\python.exe ` - -DCMAKE_PREFIX_PATH=$miniconda_dir\\envs\\et\\Lib\\site-packages ` - -DCMAKE_BUILD_TYPE=Release ` - -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON ` - -DEXECUTORCH_BUILD_FLATC=ON ` - -DEXECUTORCH_BUILD_PYBIND=OFF ` - -DEXECUTORCH_BUILD_XNNPACK=ON ` - -DEXECUTORCH_BUILD_KERNELS_LLM=ON ` - -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON ` - -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON ` - -DEXECUTORCH_ENABLE_LOGGING=ON ` - -T ClangCL ` - -Bcmake-out; ` -cmake --build cmake-out -j64 --target install --config Release -``` - -> **_NOTE_** `$miniconda_dir` is the directory where you installed miniconda. This is `“C:\Users\\AppData\Local”` by default. - -### Running an Example Model - -To validate the installation by running a model, create a file named export_mv2.py. Then, run the powershell commands to export and run the model. -The expected output is a tensor of size 1x1000, containing class scores. 
- -```py -# export_mv2.py -import torch -from executorch.exir import to_edge_transform_and_lower -from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner -from torchvision.models import mobilenet_v2 -from torchvision.models.mobilenetv2 import MobileNet_V2_Weights - -mv2 = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval() -example_inputs = (torch.randn((1, 3, 224, 224)),) - -program = to_edge_transform_and_lower( - torch.export.export(model, example_inputs) -).to_executorch() - -with open("mv2_xnnpack.pte", "wb") as file: - executorch_program.write_to_file(file) -``` +# Ubuntu/Debian +sudo apt install ccache -```bash -python .\\export_mv2.py -.\\cmake-out\\backends\\xnnpack\\Release\\xnn_executor_runner.exe --model_path=.\\mv2_xnnpack.pte -``` +# macOS +brew install ccache -```bash -Output 0: tensor(sizes=[1, 1000], [ - -0.50986, 0.30064, 0.0953904, 0.147726, 0.231205, 0.338555, 0.206892, -0.0575775, … ]) +# CentOS/RHEL +sudo yum install ccache +# or +sudo dnf install ccache ``` -## Next Steps +No additional configuration is needed - the build system will automatically use ccache when available. -* [Selective Build](kernel-library-selective-build.md) to link only kernels used by the program. This can provide significant binary size savings. -* Tutorials on building [Android](https://github.com/meta-pytorch/executorch-examples/tree/main/dl3/android/DeepLabV3Demo#executorch-android-demo-app) and [iOS](https://github.com/meta-pytorch/executorch-examples/tree/main/mv3/apple/ExecuTorchDemo) demo apps. -* Tutorials on deploying applications to embedded devices such as [ARM Cortex-M/Ethos-U](backends-arm-ethos-u.md) and [XTensa HiFi DSP](backends-cadence.md). +See [CMakeLists.txt](https://github.com/pytorch/executorch/blob/main/CMakeLists.txt) diff --git a/docs/source/using-executorch-cpp.md b/docs/source/using-executorch-cpp.md index 3736226bc06..5505ade9573 100644 --- a/docs/source/using-executorch-cpp.md +++ b/docs/source/using-executorch-cpp.md @@ -69,7 +69,7 @@ The runner source code can be found in the ExecuTorch repo under [examples/porta ## Next Steps -- [Runtime API Reference](executorch-runtime-api-reference.md) for documentation on the available C++ runtime APIs. +- [Runtime API Reference](executorch-runtime-api-reference.rst) for documentation on the available C++ runtime APIs. - [Running an ExecuTorch Model Using the Module Extension in C++](extension-module.md) for information on the high-level Module API. - [Managing Tensor Memory in C++](extension-tensor.md) for information on high-level tensor APIs. - [Running an ExecuTorch Model in C++ Tutorial](running-a-model-cpp-tutorial.md) for information on the low-level runtime APIs. diff --git a/docs/source/using-executorch-export.md b/docs/source/using-executorch-export.md index 2a887bb346d..ae73cb5aeac 100644 --- a/docs/source/using-executorch-export.md +++ b/docs/source/using-executorch-export.md @@ -24,7 +24,7 @@ Quantization - the process of using reduced precision to reduce inference time a ExecuTorch backends provide hardware acceleration for a specific hardware target. In order to achieve maximum performance on target hardware, ExecuTorch optimizes the model for a specific backend during the export and lowering process. This means that the resulting .pte file is specialized for the specific hardware. In order to deploy to multiple backends, such as Core ML on iOS and Arm CPU on Android, it is common to generate a dedicated .pte file for each. 
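For concreteness, here is a hedged sketch of producing one such backend-specific `.pte` (XNNPACK in this case), using the `to_edge_transform_and_lower` and partitioner APIs discussed in this guide; `MyModel` and the input shape are placeholders:

```python
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

model = MyModel().eval()                         # placeholder nn.Module
example_inputs = (torch.randn(1, 3, 224, 224),)  # placeholder input shape

# Export, delegate supported partitions to XNNPACK, and serialize the program.
exported = torch.export.export(model, example_inputs)
et_program = to_edge_transform_and_lower(
    exported,
    partitioner=[XnnpackPartitioner()],
).to_executorch()

with open("model_xnnpack.pte", "wb") as f:
    et_program.write_to_file(f)
```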
-The choice of hardware backend is informed by the hardware that the model is intended to be deployed on. Each backend has specific hardware requires and level of model support. See the documentation for each hardware backend for more details. +The choice of hardware backend is informed by the hardware that the model is intended to be deployed on. Each backend has specific hardware requirements and level of model support. See the documentation for each hardware backend for more details. As part of the .pte file creation process, ExecuTorch identifies portions of the model (partitions) that are supported for the given backend. These sections are processed by the backend ahead of time to support efficient execution. Portions of the model that are not supported on the delegate, if any, are executed using the portable fallback implementation on CPU. This allows for partial model acceleration when not all model operators are supported on the backend, but may have negative performance implications. In addition, multiple partitioners can be specified in order of priority. This allows for operators not supported on GPU to run on CPU via XNNPACK, for example. @@ -32,10 +32,10 @@ As part of the .pte file creation process, ExecuTorch identifies portions of the Commonly used hardware backends are listed below. For mobile, consider using XNNPACK for Android and XNNPACK or Core ML for iOS. To create a .pte file for a specific backend, pass the appropriate partitioner class to `to_edge_transform_and_lower`. See the appropriate backend documentation and the [Export and Lowering](#export-and-lowering) section below for more information. -- [XNNPACK (Mobile CPU)](backends-xnnpack.md) -- [Core ML (iOS)](backends-coreml.md) -- [Metal Performance Shaders (iOS GPU)](backends-mps.md) -- [Vulkan (Android GPU)](backends-vulkan.md) +- [XNNPACK (CPU)](backends/xnnpack/xnnpack-overview.md) +- [Core ML (iOS)](backends/coreml/coreml-overview.md) +- [Metal Performance Shaders (iOS GPU)](backends/mps/mps-overview.md) +- [Vulkan (Android GPU)](backends/vulkan/vulkan-overview.md) - [Qualcomm NPU](backends-qualcomm.md) - [MediaTek NPU](backends-mediatek.md) - [Arm Ethos-U NPU](backends-arm-ethos-u.md) @@ -141,7 +141,6 @@ delegate_external_constants_pass_unlifted( exported_program = export(tagged_module, inputs, dynamic_shapes=dynamic_shapes) executorch_program = to_edge_transform_and_lower( exported_program, - transform_passes = [partial_function], partitioner = [XnnpackPartitioner()] ).to_executorch() ``` @@ -184,6 +183,7 @@ For more complex use cases, dynamic shape specification allows for mathematical Before integrating the runtime code, it is common to test the exported model from Python. This can be used to evaluate model accuracy and sanity check behavior before moving to the target device. Note that not all hardware backends are available from Python, as they may require specialized hardware to function. See the specific backend documentation for more information on hardware requirements and the availablilty of simulators. The XNNPACK delegate used in this example is always available on host machines. ```python +import torch from executorch.runtime import Runtime runtime = Runtime.get() @@ -194,9 +194,19 @@ method = program.load_method("forward") outputs = method.execute([input_tensor]) ``` -Pybindings currently does not support loading program and data. To run a model with PTE and PTD components, please use the [Extension Module](extension-module.md). 
There is also an E2E demo in [executorch-examples](https://github.com/meta-pytorch/executorch-examples/tree/main/program-data-separation). +To run a model with program and data separated, please use the [ExecuTorch Module pybindings](https://github.com/pytorch/executorch/blob/main/extension/pybindings/README.md). +```python +import torch +from executorch.extension.pybindings import portable_lib + +input_tensor = torch.randn(1, 3, 32, 32) +module = portable_lib._load_for_executorch("model.pte", "model.ptd") +outputs = module.forward([input_tensor]) +``` + +There is also an E2E demo in [executorch-examples](https://github.com/meta-pytorch/executorch-examples/tree/main/program-data-separation). -For more information, see [Runtime API Reference](executorch-runtime-api-reference.md). +For more information, see [Runtime API Reference](executorch-runtime-api-reference.rst). ## Advanced Topics @@ -270,7 +280,7 @@ decode_ep = torch.export.export(DecodeWrapper(model), ...) ## Next Steps -The PyTorch and ExecuTorch export and lowering APIs provide a high level of customizability to meet the needs of diverse hardware and models. See [torch.export](https://pytorch.org/docs/main/export.html) and [Export API Reference](export-to-executorch-api-reference.md) for more information. +The PyTorch and ExecuTorch export and lowering APIs provide a high level of customizability to meet the needs of diverse hardware and models. See [torch.export](https://pytorch.org/docs/main/export.html) and [Export API Reference](export-to-executorch-api-reference.rst) for more information. For advanced use cases, see the following: - [Quantization Overview](quantization-overview.md) for information on quantizing models to reduce inference time and memory footprint. diff --git a/docs/source/using-executorch-faqs.md b/docs/source/using-executorch-faqs.md index d1bd0390569..c147403c9e8 100644 --- a/docs/source/using-executorch-faqs.md +++ b/docs/source/using-executorch-faqs.md @@ -16,7 +16,7 @@ if you are using Ubuntu, or use an equivalent install command. ### ModuleNotFoundError: No module named 'pytorch_tokenizers' -The `pytorch_tokenizers` package is required for LLM export functionality. Install it from the ExecutorTorch source code: +The `pytorch_tokenizers` package is required for LLM export functionality. Install it from the ExecuTorch source code: ``` pip install -e ./extension/llm/tokenizers/ ``` @@ -48,7 +48,7 @@ Thread count can be set with the following function. Ensure this is done prior t ::executorch::extension::threadpool::get_threadpool()->_unsafe_reset_threadpool(num_threads); ``` -For a deeper investgiation into model performance, ExecuTorch supports operator-level performance profiling. See [Using the ExecuTorch Developer Tools to Profile a Model](devtools-integration-tutorial.md) for more information. +For a deeper investigation into model performance, ExecuTorch supports operator-level performance profiling. See [Using the ExecuTorch Developer Tools to Profile a Model](devtools-integration-tutorial.md) for more information. ### Missing Logs diff --git a/docs/source/using-executorch-ios.md b/docs/source/using-executorch-ios.md index 3e12f174177..78d22080d8d 100644 --- a/docs/source/using-executorch-ios.md +++ b/docs/source/using-executorch-ios.md @@ -18,7 +18,9 @@ The ExecuTorch Runtime for iOS and macOS (ARM64) is distributed as a collection Link your binary with the ExecuTorch runtime and any backends or kernels used by the exported ML model. 
It is recommended to link the core runtime to the components that use ExecuTorch directly, and link kernels and backends against the main app target. -**Note:** To access logs, link against the Debug build of the ExecuTorch runtime, i.e., the `executorch_debug` framework. For optimal performance, always link against the Release version of the deliverables (those without the `_debug` suffix), which have all logging overhead removed. +**Note:** You may need to add some extra linker flags for the build settings of the components that links against ExecuTorch backends or kernels to let them register properly at the app startup. See the [Linkage](#Linkage) section for more details. + +**Note:** To access logs, link against the Debug build of the ExecuTorch runtime, i.e., the `executorch_debug` framework. For optimal performance, always link against the Release version of the deliverables (those without the `_debug` suffix), which have all logging overhead removed. See the [Logging](#Logging) section for more details. ### Swift Package Manager @@ -26,7 +28,7 @@ The prebuilt ExecuTorch runtime, backend, and kernels are available as a [Swift #### Xcode -In Xcode, go to `File > Add Package Dependencies`. Paste the URL of the [ExecuTorch repo](https://github.com/pytorch/executorch) into the search bar and select it. Make sure to change the branch name to the desired ExecuTorch version in format "swiftpm-", (e.g. "swiftpm-0.7.0"), or a branch name in format "swiftpm-." (e.g. "swiftpm-0.8.0-20250801") for a [nightly build](https://ossci-ios.s3.amazonaws.com/list.html) on a specific date. +In Xcode, go to `File > Add Package Dependencies`. Paste the URL of the [ExecuTorch repo](https://github.com/pytorch/executorch) into the search bar and select it. Make sure to change the branch name to the desired ExecuTorch version in format "swiftpm-", (e.g. "swiftpm-1.0.0"), or a branch name in format "swiftpm-." (e.g. "swiftpm-1.1.0-20251101") for a [nightly build](https://ossci-ios.s3.amazonaws.com/list.html) on a specific date. ![](_static/img/swiftpm_xcode1.png) @@ -59,7 +61,7 @@ let package = Package( ], dependencies: [ // Use "swiftpm-." branch name for a nightly build. - .package(url: "https://github.com/pytorch/executorch.git", branch: "swiftpm-0.7.0") + .package(url: "https://github.com/pytorch/executorch.git", branch: "swiftpm-1.0.0") ], targets: [ .target( @@ -70,6 +72,10 @@ let package = Package( .product(name: "kernels_optimized", package: "executorch"), // Add other backends and kernels as needed. ]), + linkerSettings: [ + // Force load all symbols from static libraries to trigger backends and kernels registration + .unsafeFlags(["-Wl,-all_load"]) + ] ] ) ``` @@ -107,7 +113,7 @@ git clone -b viable/strict https://github.com/pytorch/executorch.git --depth 1 - python3 -m venv .venv && source .venv/bin/activate && pip install --upgrade pip ``` -4. Install the required dependencies, including those needed for the backends like [Core ML](backends-coreml.md) or [MPS](backends-mps.md), if you plan to build them later: +4. 
Install the required dependencies, including those needed for the backends like [Core ML](backends/coreml/coreml-overview.md) or [MPS](backends/mps/mps-overview.md), if you plan to build them later: ```bash ./install_requirements.sh diff --git a/docs/source/using-executorch-runtime-integration.md b/docs/source/using-executorch-runtime-integration.md index 550cb3eb71a..36bc4f6b2fe 100644 --- a/docs/source/using-executorch-runtime-integration.md +++ b/docs/source/using-executorch-runtime-integration.md @@ -64,7 +64,7 @@ namespace { ``` ### Weak Symbol Override -ExecuTorch also provides a link-time method to override the PAL using weak symbols. This method is primarily maintained for backwards compatability. +ExecuTorch also provides a link-time method to override the PAL using weak symbols. This method is primarily maintained for backwards compatibility. To override one or more PAL methods, take the following steps: diff --git a/docs/source/using-executorch-troubleshooting.md b/docs/source/using-executorch-troubleshooting.md index 56c2e1a0653..75648dc5b46 100644 --- a/docs/source/using-executorch-troubleshooting.md +++ b/docs/source/using-executorch-troubleshooting.md @@ -1,11 +1,11 @@ # Profiling and Debugging -To faciliate model and runtime integration, ExecuTorch provides tools to profile model resource utilization, numerics, and more. This section describes the available troubleshooting tools and steps to resolve issues when integrating ExecuTorch. +To facilitate model and runtime integration, ExecuTorch provides tools to profile model resource utilization, numerics, and more. This section describes the available troubleshooting tools and steps to resolve issues when integrating ExecuTorch. ## General Troubleshooting Steps - To troubleshoot failure of runtime API calls, such as loading or running a model, ensure that ExecuTorch framework logging is enabled. See [Logging](using-executorch-runtime-integration.md#logging) for more information. -- As a prelimatinary step to troubleshoot slow run times, ensure that performance testing is being done in a release build, and that the model is delegated. See [Inference is Slow](using-executorch-faqs.md#inference-is-slow--performance-troubleshooting) for more information. +- As a preliminary step to troubleshoot slow run times, ensure that performance testing is being done in a release build, and that the model is delegated. See [Inference is Slow](using-executorch-faqs.md#inference-is-slow--performance-troubleshooting) for more information. - Check [Frequently Asked Questions](using-executorch-faqs.md) for common issues and questions encountered during install, model export, and runtime integration. ## Developer Tools @@ -16,5 +16,5 @@ The ExecuTorch developer tools, or devtools, are a collection of tooling for tro - [Frequently Asked Questions](using-executorch-faqs.md) for solutions to commonly encountered questions and issues. - [Introduction to the ExecuTorch Developer Tools](runtime-profiling.md) for a high-level introduction to available developer tooling. -- [Using the ExecuTorch Developer Tools to Profile a Model](https://pytorch.org/executorch/main/tutorials/devtools-integration-tutorial) for information on runtime performance profiling. +- [Using the ExecuTorch Developer Tools to Profile a Model](tutorials/devtools-integration-tutorial) for information on runtime performance profiling. - [Inspector APIs](runtime-profiling.md) for reference material on trace inspector APIs. 
diff --git a/docs/source/visualize.md b/docs/source/visualize.md new file mode 100644 index 00000000000..fdd868df4f0 --- /dev/null +++ b/docs/source/visualize.md @@ -0,0 +1,144 @@ +# Visualize a Model using ModelExplorer + +The [visualization_utils.py](../../devtools/visualization/visualization_utils.py) contains functions for +visualizing ExecuTorch models as computational graphs using the `ModelExplorer` utility. + +## Installation + +To install the `ModelExplorer` and its dependencies, run: + +``` +./devtools/install_requirements.sh +``` + +## Visualize a model + +The function `visualize()` takes an `ExportedProgram` and launches a `ModelExplorer` server instance. A browser tab will +open, containing the visualization. + +The operations in the graph will be grouped together into collapsible nodes, based on which `nn.Module` instances they +originate from (see **Figure 1**). These nodes can be expanded by clicking the button in their top +left corner, as shown +in **Figure 2**. The model can contain an entire hierarchy of collapsible nodes, reflecting its +original _PyTorch_ +implementation (see **Figure 3**). +
+*Figure 1: Model visualization collapsed into a single node representing the original module.*
+
+*Figure 2: Button to expand a node.*
+
+*Figure 3: Hierarchy of expandable nodes.*
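A hedged usage sketch for `visualize()` follows; it assumes the helper is importable from an `executorch.devtools.visualization` package that wraps the file referenced above, and `MyModel` is a placeholder:

```python
import torch
# Assumed import path for the helper described in visualization_utils.py.
from executorch.devtools.visualization.visualization_utils import visualize

model = MyModel().eval()                       # placeholder nn.Module
example_inputs = (torch.randn(1, 3, 32, 32),)  # placeholder input shape

exported_program = torch.export.export(model, example_inputs)
visualize(exported_program)  # launches a ModelExplorer server and opens a browser tab
```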
+ +The **Model Explorer GUI** provides a button in the top left corner of the screen (see **Figure 4**), +which expands all the nested expandable nodes. The result will display all the low-level operations, surrounded by +rectangles which indicate their membership to specific `nn.Module` instances. +
+*Figure 4: Expand all nodes.*
+ + +Sometimes, it is not ideal to view the model like this. Focusing on visualizing the origin of the final nodes can make +it harder to see the flow of data in the graph. For this purpose, a button in the top left corner can flatten all the +layers (expandable nodes), effectively hiding the original `nn.Module` instances and just displaying the model as a +computational graph (see **Figure 5**). + +
+*Figure 5: Flatten the model to a simple computational graph.*
+ +--- + +# Visualize a Model with Highlighted QDQ Clusters and Partitions + +The [visualization_utils.py](../../devtools/visualization/visualization_utils.py) contains the function +`visualize_with_clusters()` which takes an `ExportedProgram` and visualizes it using the `ModelExplorer` utility. +It groups QDQ clusters and individual partitions together to improve readability. Example usage is available +in [examples/nxp/aot_neutron_compile.py](../../examples/nxp/aot_neutron_compile.py). + +An example of the visualization is shown in **Figure 6.** +
+*Figure 6: Example of the QDQ cluster and partition highlighting visualization.*
+ +## Usage + +There are two main use cases for the visualization: + +### 1. Launching the `ModelExplorer` and Visualizing the Model Immediately + +Call: + +```python +visualize_with_clusters(exported_program) +``` + +This starts a `ModelExplorer` server and opens a browser tab with the visualization. + +By default, each call starts a new server instance and opens a new browser tab. +To reuse an existing server, set the `reuse_server` parameter to `True`. + +Starting the server is **blocking**, so the rest of your script will not run. + +### 2. Storing a Serialized Graph and Visualizing Later (Non-blocking) + +To save the visualization to a JSON file, call: + +```python +visualize_with_clusters(exported_program, "my_model.json") +``` + +This just saves the visualization in the file, and it does **not** start the `ModelExplorer` server. You can then open +the file in the `ModelExplorer` GUI at any point. To launch the server, run: + +```bash + model-explorer [model-file-json] +``` + +If the `model-file-json` is provided, the `ModelExplorer` will open the model visualization. Otherwise, the +`ModelBuilder` GUI home page will appear. In that case, click **Select from your computer**, choose the JSON file, +and then click **View selected models** to display the graph. + +--- + +## Styling the Graph + +`visualize_with_clusters()` supports custom grouping of nodes into QDQ clusters and partitions. + +You can pass the following optional parameters: + +- `get_node_partition_name` +- `get_node_qdq_cluster_name` + +These are functions that take a node and return a string identifying the partition or cluster it belongs to. +Nodes with the same partition/cluster string will be grouped together and labeled accordingly in the visualization. + +### Load a predefined style for QDQ cluster and partition highlighting. + +A color style for the QDQ cluster and partition highlighting is already provided +in [devtools/visualization/model_explorer_styles/cluster_highlight_style.json](../../devtools/visualization/model_explorer_styles/cluster_highlight_style.json). +To load it follow these steps: + +1. Click the **palette icon** in the top-right corner of the `ModelExplorer` interface. +2. Click **Import rules**. +3. Select + the [cluster_highlight_style.json](../../devtools/visualization/model_explorer_styles/cluster_highlight_style.json) + file to apply predefined styles that highlight each partition in a different color. + +
+*Figure 7: Add custom color styling to the graph.*
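For the custom grouping callbacks described under "Styling the Graph" above, a hedged sketch is shown below. The predicate logic, the assumption that nodes expose a `.name` attribute (as torch.fx nodes do), and the keyword-style call are all assumptions to adapt to your own partitioning scheme:

```python
def get_node_partition_name(node) -> str:
    # Placeholder heuristic: group delegated nodes into one partition, everything else into none.
    return "delegated_partition" if "call_delegate" in node.name else ""

def get_node_qdq_cluster_name(node) -> str:
    # Placeholder heuristic: cluster quantize/dequantize ops under a shared label.
    return "qdq_cluster" if "quant" in node.name else ""

# Serialize the styled graph to JSON for later viewing in the ModelExplorer GUI.
visualize_with_clusters(
    exported_program,
    "my_model.json",
    get_node_partition_name=get_node_partition_name,
    get_node_qdq_cluster_name=get_node_qdq_cluster_name,
)
```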
diff --git a/examples/apple/coreml/llama/export.py b/examples/apple/coreml/llama/export.py index 48edc3c0669..af2fa3c74ee 100644 --- a/examples/apple/coreml/llama/export.py +++ b/examples/apple/coreml/llama/export.py @@ -23,7 +23,6 @@ from executorch.exir.backend.utils import format_delegated_graph from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes import MemoryPlanningPass -from executorch.exir.passes.quant_fusion_pass import QuantFusionPass from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass from executorch.extension.export_util.utils import save_pte_program @@ -211,9 +210,7 @@ def main() -> None: executorch_program = edge_manager.to_executorch( ExecutorchBackendConfig( extract_delegate_segments=True, - passes=[ - QuantFusionPass(), - ], + do_quant_fusion_and_const_prop=True, memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), ) diff --git a/examples/apple/coreml/llama/export_static_llm_coreml.py b/examples/apple/coreml/llama/export_static_llm_coreml.py new file mode 100644 index 00000000000..a3fd8201414 --- /dev/null +++ b/examples/apple/coreml/llama/export_static_llm_coreml.py @@ -0,0 +1,504 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Export script for static attention LLM models to CoreML via ExecuTorch. + +Usage: + python export_static_llm_coreml.py \ + --checkpoint /path/to/model.pth \ + --params /path/to/params.json \ + --output static_llm_coreml_model.pte \ + --max_context_len 1024 \ + --input_len 32 \ + --embedding_quantize 4,32 \ + --coreml_quantize c4w \ + --target_split_size 1048 +""" + +import argparse +import json + +import coremltools as ct +import torch +import torch.nn as nn +import torch.utils._pytree as pytree + +from executorch.backends.apple.coreml.compiler import CoreMLBackend +from executorch.backends.apple.coreml.partition import CoreMLPartitioner +from executorch.examples.apple.coreml.llama.utils import ( + replace_linear_with_split_linear, +) +from executorch.examples.models.llama.llama_transformer import construct_transformer +from executorch.examples.models.llama.model_args import ModelArgs +from executorch.examples.models.llama.rope import Rope +from executorch.examples.models.llama.static_attention import StaticAttentionIOManager +from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower +from executorch.exir.backend.utils import format_delegated_graph +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.exir.passes import MemoryPlanningPass +from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass +from executorch.extension.export_util.utils import save_pte_program +from torch.library import impl, Library +from torchao.quantization.granularity import PerAxis, PerGroup +from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ + +# Define custom graph break op +lib = Library("executorch_utils", "DEF") +lib.define("graph_break.Tensor(Tensor x) -> Tensor") + + +@impl(lib, "graph_break.Tensor", "CompositeExplicitAutograd") +def graph_break_impl(x): + return x + + +class ExecutorchGraphBreakModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, *args, **kwargs): + return tuple( + ( + 
torch.ops.executorch_utils.graph_break.Tensor(a) + if isinstance(a, torch.Tensor) + else a + ) + for a in args + ) + + +class BlockWithGraphBreak(nn.Module): + def __init__(self, block: nn.Module, break_before: bool = True): + super().__init__() + self.graph_break = ExecutorchGraphBreakModule() + self.block = block + self.break_before = break_before + + def forward(self, *args, **kwargs): + if self.break_before: + new_args = self.graph_break(*args) + out = self.block(*new_args, **kwargs) + return out + else: + out = self.block(*args, **kwargs) + out = self.graph_break(*out) + return out + + +def remove_graph_break_(edge_manager): + from executorch.exir.dialects._ops import ops as exir_ops + + for n in edge_manager.exported_program().graph_module.graph.nodes: + if n.target == exir_ops.edge.executorch_utils.graph_break.Tensor: + n.replace_all_uses_with(n.args[0]) + edge_manager.exported_program().graph_module.graph.eliminate_dead_code() + + +def load_model(checkpoint_path: str, params_path: str, max_context_len: int): + """Load the model from checkpoint with static_mha attention type.""" + with open(params_path, "r") as f: + params = json.loads(f.read()) + + # TODO: to support lookahead decoding, the static model outputs + # full logits, but if we are not using lookahead decoding, we can have a + # more efficient model by setting generate_full_logits=False and supplying the last + # valid token + args = ModelArgs( + max_context_len=max_context_len, + generate_full_logits=True, + **params, + ) + args.attention_type = "static_mha" + args.attention_kwargs = {"decompose_sdpa_in_mha": True} + + with torch.device("meta"): + model = construct_transformer(args) + + checkpoint = torch.load( + checkpoint_path, map_location="cpu", mmap=True, weights_only=True + ) + if "model" in checkpoint: + checkpoint = checkpoint["model"] + + # Rename attention weight keys for static attention + for i in range(len(model.layers)): + if f"layers.{i}.attention.wq.weight" in checkpoint: + checkpoint[f"layers.{i}.attention.wqs.0.weight"] = checkpoint.pop( + f"layers.{i}.attention.wq.weight" + ) + if f"layers.{i}.attention.wk.weight" in checkpoint: + checkpoint[f"layers.{i}.attention.wks.0.weight"] = checkpoint.pop( + f"layers.{i}.attention.wk.weight" + ) + if f"layers.{i}.attention.wv.weight" in checkpoint: + checkpoint[f"layers.{i}.attention.wvs.0.weight"] = checkpoint.pop( + f"layers.{i}.attention.wv.weight" + ) + + missing, unexpected = model.load_state_dict( + checkpoint, + strict=False, + assign=True, + ) + if missing: + print(f"Missing keys: {missing}") + if unexpected: + print(f"Unexpected keys: {unexpected}") + + return model, args + + +def _get_metadata(model_args, example_inputs, input_len, cache_len, float_dtype): + """ + Generate metadata methods for the C++ runner. 
+ + The C++ runner needs these constant methods to understand the model structure: + - vocab_size: Vocabulary size + - head_dim: Head dimension + - n_heads_per_cache: Number of KV heads + - freqs_cos, freqs_sin: Pre-computed RoPE frequencies + - freqs_cos_input_index, freqs_sin_input_index: Input indices for RoPE + - kv_cache_specs: Tensor describing cache input/output indices and lengths + - mask_specs: Tensor describing mask input indices + - forward_input_len: Input length for forward method + """ + # Pre-compute RoPE frequencies for the full context + rope = Rope(model_args) + freqs_cos, freqs_sin = rope.get_freqs(None, model_args.max_context_len) + print(f"Pre-computed RoPE frequencies shape: {freqs_cos.shape}, {freqs_sin.shape}") + + # Flatten example inputs to get the pytree spec + flat_inputs, in_spec = pytree.tree_flatten(example_inputs) + + # Reconstruct input indices from the pytree spec + input_indices = pytree.tree_unflatten( + list(range(in_spec.num_leaves)), + in_spec, + ) + + # input_indices structure: + # (token_idx, { + # "masks": {cache_len: mask_idx}, + # "freqs_cos_override": freqs_cos_idx, + # "freqs_sin_override": freqs_sin_idx, + # "in_cache_state": ({k_cache_ids: k_cache_idx}, {v_cache_ids: v_cache_idx}) + # }) + + # Get the options dict indices + opts_indices = input_indices[1] + + # Build KV cache specs: [k_in_idx, k_out_idx, v_in_idx, v_out_idx, cache_len] + # For static_mha, output cache indices follow the same order as inputs + # Output structure: (logits, {"out_cache_state": ({k_ids: k_out}, {v_ids: v_out})}) + k_cache_in_indices = opts_indices["in_cache_state"][0] + v_cache_in_indices = opts_indices["in_cache_state"][1] + + # Sort by layer to ensure consistent ordering + sorted_k_cache_ids = sorted(k_cache_in_indices.keys()) + + # Output indices are in the same order (after logits) + # Logits is output 0, then k_caches, then v_caches + kv_cache_specs = [] + for i, cache_id in enumerate(sorted_k_cache_ids): + k_in_idx = k_cache_in_indices[cache_id] + v_in_idx = v_cache_in_indices[cache_id] + # Output indices: k_caches come after logits (idx 1 to n_layers), + # v_caches come after k_caches (idx n_layers+1 to 2*n_layers) + k_out_idx = 1 + i + v_out_idx = 1 + len(sorted_k_cache_ids) + i + kv_cache_specs.append([k_in_idx, k_out_idx, v_in_idx, v_out_idx, cache_len]) + + print(f"KV cache specs (k_in, k_out, v_in, v_out, cache_len): {kv_cache_specs}") + + # Build mask specs: [mask_idx, cache_len] + mask_specs = [ + [mask_idx, c_len] for c_len, mask_idx in opts_indices["masks"].items() + ] + print(f"Mask specs (mask_idx, cache_len): {mask_specs}") + + return { + "vocab_size": model_args.vocab_size, + "head_dim": model_args.head_dim, + "n_heads_per_cache": model_args.n_kv_heads, + "freqs_cos": freqs_cos.to(float_dtype), + "freqs_sin": freqs_sin.to(float_dtype), + "freqs_cos_input_index": torch.tensor( + [opts_indices["freqs_cos_override"]], dtype=torch.int64 + ), + "freqs_sin_input_index": torch.tensor( + [opts_indices["freqs_sin_override"]], dtype=torch.int64 + ), + "mask_specs": torch.tensor(mask_specs, dtype=torch.int64), + "kv_cache_specs": torch.tensor(kv_cache_specs, dtype=torch.int64), + "forward_input_len": input_len, + } + + +def main(): + parser = argparse.ArgumentParser( + description="Export static attention Llama model to CoreML" + ) + + # Model paths + parser.add_argument( + "-c", + "--checkpoint", + required=True, + help="Path to model checkpoint (.pth)", + ) + parser.add_argument( + "-p", + "--params", + required=True, + help="Path to params.json", + ) 
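As an aside on the index bookkeeping used by `_get_metadata` above: the input positions handed to the C++ runner are recovered purely from the pytree structure of the example inputs, by flattening once and then unflattening a range of integers. A minimal, standalone sketch of that trick (the nested structure here is made up for illustration and is much smaller than the real cache dictionaries):

```
import torch
import torch.utils._pytree as pytree

# A toy nested input, similar in spirit to (tokens, options_dict).
example = (
    torch.zeros(1, 4, dtype=torch.int32),
    {"masks": {64: torch.zeros(1, 4)}, "freqs_cos_override": torch.zeros(4)},
)

flat, spec = pytree.tree_flatten(example)
# Replace every leaf with its position in the flattened argument list.
indices = pytree.tree_unflatten(list(range(spec.num_leaves)), spec)
print(indices)  # (0, {'masks': {64: 1}, 'freqs_cos_override': 2})
```

The same structure-preserving round trip is what lets the export script read the mask, RoPE override, and KV-cache input indices out of `input_indices` without hard-coding an argument order.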
+ parser.add_argument( + "-o", + "--output", + default="model.pte", + help="Output filename for the .pte model", + ) + + # Model configuration + parser.add_argument( + "--max_context_len", + type=int, + default=1024, + help="Maximum context length", + ) + parser.add_argument( + "--input_len", + type=int, + default=32, + help="Input sequence length per forward pass", + ) + parser.add_argument( + "--dtype", + type=str, + choices=["fp16", "fp32"], + default="fp16", + help="Model dtype. The ANE requires fp16.", + ) + + # Quantization options + parser.add_argument( + "-E", + "--embedding_quantize", + default="8,0", + type=str, + help="Embedding quantization: ',', e.g., '4,32' or '8,0' for per-channel", + ) + parser.add_argument( + "--linear_quantize", + default="c4w", + choices=["b4w", "c4w"], + help="CoreML linear quantization: b4w (blockwise 4-bit) or c4w (channelwise 4-bit). The ANE requires channelwise.", + ) + + # Linear splitting options + parser.add_argument( + "--target_split_size", + type=int, + default=1024, + help="Split linear layers into chunks of this size (helps with ANE performance)", + ) + parser.add_argument( + "--max_splits", + type=int, + default=8, + help="Maximum number of splits for linear layers", + ) + + # Graph break options + parser.add_argument( + "--no_graph_breaks", + action="store_true", + help="Disable graph breaks between transformer blocks", + ) + + args = parser.parse_args() + + # Compute cache length + + print("Quantization and datatype:") + print(f"\tEmbedding quantize: {args.embedding_quantize}") + print(f"\tLinear quantize: {args.linear_quantize}") + print(f"\tDtype: {args.dtype}") + + cache_len = args.max_context_len - args.input_len + print("\nGeneration configuration:") + print(f"\tMax context length: {args.max_context_len}") + print(f"\tInput length: {args.input_len}") + print(f"\tCache length: {cache_len}") + + print("\nLinear splitting:") + print(f"\tTarget split size: {args.target_split_size}") + print(f"\tMax splits: {args.max_splits}") + + # Load model + print(f"\nLoading model from {args.checkpoint}...") + model, model_args = load_model( + args.checkpoint, + args.params, + args.max_context_len, + ) + print(f"Model loaded: {model_args.n_layers} layers, {model_args.dim} dim") + + # Set dtype + float_dtype = {"fp16": torch.float16, "fp32": torch.float32}[args.dtype] + model = model.to(float_dtype).eval() + + # Apply linear splitting (before quantization) + if args.target_split_size is not None: + print(f"\nSplitting linear layers with target size {args.target_split_size}...") + replace_linear_with_split_linear( + model, + out_target_split_size=args.target_split_size, + out_max_splits=args.max_splits, + in_target_split_size=1, + in_max_splits=1, + ) + + # Apply embedding quantization + if args.embedding_quantize: + bitwidth, group_size = args.embedding_quantize.split(",") + bitwidth = int(bitwidth) + group_size = int(group_size) + assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization" + + print(f"\nQuantizing embeddings: {bitwidth}-bit, group_size={group_size}...") + if group_size == 0: + granularity = PerAxis(0) + else: + granularity = PerGroup(group_size) + weight_dtype = getattr(torch, f"int{bitwidth}") + + quantize_( + model, + IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity), + lambda m, fqn: isinstance(m, torch.nn.Embedding), + ) + + # Apply linear quantization + if args.linear_quantize == "b4w": + print("\nQuantizing linear layers: 4-bit blockwise (group_size=32)...") + quantize_( + model, + 
IntxWeightOnlyConfig( + weight_dtype=torch.int4, + granularity=PerGroup(32), + ), + ) + elif args.linear_quantize == "c4w": + print("\nQuantizing linear layers: 4-bit channelwise...") + quantize_( + model, + IntxWeightOnlyConfig( + weight_dtype=torch.int4, + granularity=PerAxis(0), + ), + ) + + # Add graph breaks between transformer blocks + # Keeping model pieces smaller helps with ANE performance + if not args.no_graph_breaks: + print("\nAdding graph breaks between before/after the transformer blocks...") + n_layers = len(model.layers) + model.layers[0] = BlockWithGraphBreak(model.layers[0], break_before=True) + model.layers[n_layers - 1] = BlockWithGraphBreak( + model.layers[n_layers - 1], break_before=False + ) + + # Create IO manager and example inputs + mgr = StaticAttentionIOManager( + model_args, + input_len=args.input_len, + cache_lens=cache_len, + batch_size=1, + dtype=float_dtype, + style="smart_mask", # Use smart_mask to match C++ StaticTransformerRunner + mask_val=float("-inf"), + ) + example_inputs = ( + torch.zeros(1, args.input_len, dtype=torch.int32), + { + "masks": mgr.masks, + "freqs_cos_override": mgr.freqs_cos[: args.input_len], + "freqs_sin_override": mgr.freqs_sin[: args.input_len], + "in_cache_state": (mgr.k_caches, mgr.v_caches), + }, + ) + + # Test eager execution + print("\nTesting eager execution...") + with torch.no_grad(): + model(*example_inputs) + print("Eager execution successful!") + + # Export the model + print("\nExporting model...") + ep = torch.export.export(model, example_inputs) + print("Export successful!") + print(ep) + + # Generate metadata for C++ runner + print("\nGenerating metadata for C++ runner...") + constant_methods = _get_metadata( + model_args, example_inputs, args.input_len, cache_len, float_dtype + ) + + # Setup CoreML partitioner + print("\nSetting up CoreML partitioner...") + compile_specs = CoreMLBackend.generate_compile_specs( + minimum_deployment_target=ct.target.iOS18, + compute_precision={ + torch.float16: ct.precision.FLOAT16, + torch.float32: ct.precision.FLOAT32, + }[float_dtype], + compute_unit=ct.ComputeUnit.CPU_AND_NE, + model_type=CoreMLBackend.MODEL_TYPE.MODEL, + ) + partitioner = CoreMLPartitioner( + compile_specs=compile_specs, + take_over_mutable_buffer=False, + skip_ops_for_coreml_delegation=[], + ) + + # Lower to edge with constant methods for C++ runner + print("\nLowering to edge...") + edge_compile_config = EdgeCompileConfig(_check_ir_validity=False) + edge_manager = to_edge_transform_and_lower( + ep, + partitioner=[partitioner], + constant_methods=constant_methods, + compile_config=edge_compile_config, + ) + + print("\nDelegated program:") + print(format_delegated_graph(edge_manager.exported_program().graph_module)) + + # Convert to ExecuTorch + print("\nConverting to ExecuTorch...") + remove_graph_break_(edge_manager) + executorch_program = edge_manager.to_executorch( + ExecutorchBackendConfig( + extract_delegate_segments=True, + do_quant_fusion_and_const_prop=True, + memory_planning_pass=MemoryPlanningPass( + alloc_graph_input=False, alloc_graph_output=False + ), + sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), + ) + ) + + # Save the program + filename = save_pte_program(executorch_program, args.output) + print(f"\nSaved ExecuTorch program to {filename}") + + +if __name__ == "__main__": + main() diff --git a/examples/apple/coreml/llama/readme.md b/examples/apple/coreml/llama/readme.md index 14dff0c8580..46e9043a5fc 100644 --- a/examples/apple/coreml/llama/readme.md +++ 
b/examples/apple/coreml/llama/readme.md @@ -1,5 +1,41 @@ # ANE-friendly Llama models +To export a static, ANE-friendly model, use: + +``` +python export_static_llm_coreml.py \ + --checkpoint /path/to/model.pth \ + --params /path/to/params.json \ + --output static_llm_coreml_model.pte +``` + +To test in Python, use: + +``` +python run_static_llm.py \ + --model static_llm_coreml_model.pte \ + --params /path/to/params.json \ + --tokenizer /path/to/tokenizer.model \ + --prompt "Once upon a time" \ + --max_new_tokens 100 \ + --lookahead +``` + +(Enabling lookahead decoding is optional, but does improve performance.) + +The static model has several ANE optimizations, including: +* Splitting linear layers for improved performance (controlled by target_split_size and max_splits args) +* Splitting the .pte into multiple Core ML pieces for improved performance (can be disabled with no_graph_breaks) +* Re-writing SDPA to avoid 5-D tensors to improve performance. This also fixes an accuracy bug that was introduced in iOS 26 (addresses this: https://github.com/pytorch/executorch/issues/15833) + + +We are working on adding a C++ runner as well. + + +# Deprecated (export.py, run.py, and run_lookahead.py) + +Below we describe export.py, run.py, and run_lookahead.py. But these are deprecated and will eventually be removed because we are unifying around the static model formulation. + This directory contains ANE-friendly Llama models. Export model with: diff --git a/examples/apple/coreml/llama/run_static_llm.py b/examples/apple/coreml/llama/run_static_llm.py new file mode 100644 index 00000000000..2cd526aec42 --- /dev/null +++ b/examples/apple/coreml/llama/run_static_llm.py @@ -0,0 +1,323 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Run script for static attention Llama models exported with export_static_llm_coreml.py.
+ +Usage: + python run_static_llm.py \ + --model llama1b_static.pte \ + --params $HOME/models/llama1b/params.json \ + --tokenizer $HOME/models/llama1b/tokenizer.model \ + --prompt "Once upon a time" \ + --max_new_tokens 100 +""" + +import argparse +import json +import time +from typing import Any, Dict, List, Tuple + +import sentencepiece as spm +import torch +import torch.utils._pytree as pytree + +from executorch.examples.models.llama.model_args import ModelArgs +from executorch.examples.models.llama.runner.generation import next_token +from executorch.examples.models.llama.static_attention import StaticAttentionIOManager +from executorch.runtime import Runtime + + +class Tokenizer: + """Wrapper to support both SentencePiece and Tiktoken tokenizers.""" + + def __init__(self, model_path: str): + try: + print("Trying to load sentencepiece") + sp = spm.SentencePieceProcessor() + sp.load(model_path) + self.tokenizer = sp + self._is_sentencepiece = True + except Exception: + print("Trying to load tiktoken") + from executorch.examples.models.llama.tokenizer import tiktoken + + self.tokenizer = tiktoken.Tokenizer(model_path) + self._is_sentencepiece = False + + def encode(self, text: str, bos: bool = True, eos: bool = False) -> List[int]: + if self._is_sentencepiece: + bos_string = "" if bos else "" + eos_string = "" if eos else "" + return self.tokenizer.encode(f"{bos_string}{text}{eos_string}") + return self.tokenizer.encode(text, bos=bos, eos=eos) + + def decode(self, tokens: List[int]) -> str: + if self._is_sentencepiece: + return self.tokenizer.decode(tokens) + return self.tokenizer.decode(tokens) + + def decode_token(self, token: int) -> str: + if self._is_sentencepiece: + return self.tokenizer.decode([token]) + try: + return self.tokenizer.decode_token(token) + except UnicodeDecodeError: + return f"<{token}>" + + @property + def stop_tokens(self) -> List[int]: + if self._is_sentencepiece: + return [self.tokenizer.eos_id()] + return self.tokenizer.stop_tokens + + +def create_pte_wrapper( + method, + k_cache_keys: List[str], + v_cache_keys: List[str], +): + """ + Create a wrapper function that adapts PTE execution to the interface + expected by StaticAttentionIOManager. 
+ + The wrapper: + - Takes (tokens, options_dict) like the eager model + - Flattens inputs using pytree + - Executes the PTE method + - Reconstructs outputs to match eager model format: (logits, {"out_cache_state": (k_dict, v_dict)}) + """ + + def wrapper( + tokens: torch.Tensor, options: Dict[str, Any] + ) -> Tuple[torch.Tensor, Dict[str, Any]]: + # Build the same input structure as during export + inputs = (tokens, options) + + # Flatten using pytree (same order as torch.export) + flat_inputs, _ = pytree.tree_flatten(inputs) + + # Execute PTE model + outputs = method.execute(flat_inputs) + + # First output is logits + logits = outputs[0] + + # Remaining outputs are k_cache updates then v_cache updates + num_layers = len(k_cache_keys) + k_updates = outputs[1 : 1 + num_layers] + v_updates = outputs[1 + num_layers : 1 + 2 * num_layers] + + # Reconstruct the output cache state dicts + k_cache_dict = dict(zip(k_cache_keys, k_updates)) + v_cache_dict = dict(zip(v_cache_keys, v_updates)) + + attn_updates = {"out_cache_state": (k_cache_dict, v_cache_dict)} + + return logits, attn_updates + + return wrapper + + +def main(): + parser = argparse.ArgumentParser(description="Run static attention Llama model") + + parser.add_argument( + "-m", + "--model", + required=True, + help="Path to exported .pte model", + ) + parser.add_argument( + "-p", + "--params", + required=True, + help="Path to params.json", + ) + parser.add_argument( + "-t", + "--tokenizer", + required=True, + help="Path to tokenizer model", + ) + parser.add_argument( + "--prompt", + type=str, + default="Once upon a time,", + help="Input prompt", + ) + parser.add_argument( + "--max_new_tokens", + type=int, + default=100, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.6, + help="Sampling temperature", + ) + parser.add_argument( + "--top_p", + type=float, + default=0.9, + help="Top-p (nucleus) sampling threshold", + ) + parser.add_argument( + "--input_len", + type=int, + default=32, + help="Input sequence length (must match export)", + ) + parser.add_argument( + "--cache_len", + type=int, + default=992, + help="Cache length (must match export: max_context_len - input_len)", + ) + parser.add_argument( + "--lookahead", + action="store_true", + help="Enable lookahead (speculative) decoding", + ) + parser.add_argument( + "--ngram_size", + type=int, + default=5, + help="N-gram size for lookahead decoding", + ) + parser.add_argument( + "--window_size", + type=int, + default=4, + help="Window size for lookahead decoding", + ) + parser.add_argument( + "--n_verifications", + type=int, + default=4, + help="Number of verification branches for lookahead decoding", + ) + + args = parser.parse_args() + + # Load tokenizer + tokenizer = Tokenizer(args.tokenizer) + + # Load model params + with open(args.params, "r") as f: + params = json.loads(f.read()) + + # Create model args + model_args = ModelArgs( + max_context_len=args.cache_len + args.input_len, + generate_full_logits=True, + **params, + ) + model_args.attention_type = "static_mha" + + print(f"Model config: {model_args.n_layers} layers, dim={model_args.dim}") + print(f"Input length: {args.input_len}, Cache length: {args.cache_len}") + + # Create StaticAttentionIOManager + mgr = StaticAttentionIOManager( + model_args, + input_len=args.input_len, + cache_lens=args.cache_len, + batch_size=1, + dtype=torch.float16, + style="smart_mask", # Use smart_mask to match C++ StaticTransformerRunner + mask_val=float("-inf"), + ) + + # Load PTE 
model + print(f"Loading model from {args.model}...") + runtime = Runtime.get() + program = runtime.load_program(args.model) + method = program.load_method("forward") + + metadata = method.metadata + print( + f"Method metadata: num_inputs={metadata.num_inputs()}, num_outputs={metadata.num_outputs()}" + ) + + # Get cache keys in insertion order (NOT sorted alphabetically!) + # Pytree preserves dict insertion order in Python 3.7+ + # The caches are created in layer order (0, 1, 2, ..., n_layers-1) + k_cache_keys = list(mgr.k_caches.keys()) + v_cache_keys = list(mgr.v_caches.keys()) + + # Create wrapper function that adapts PTE to eager interface + model_fn = create_pte_wrapper(method, k_cache_keys, v_cache_keys) + + # Encode prompt + prompt_tokens = tokenizer.encode(args.prompt, bos=True, eos=False) + print(f"\nPrompt: {args.prompt}") + print(f"Prompt tokens: {len(prompt_tokens)}") + print("-" * 50) + + # Reset manager + mgr.reset() + + # Prefill using StaticAttentionIOManager.prefill + print("Prefilling...", end=" ", flush=True) + start_time = time.time() + logits = mgr.prefill(model_fn, prompt_tokens) + prefill_time = time.time() - start_time + print(f"done in {prefill_time:.2f}s") + + # Get first token from prefill logits + first_token = next_token(logits[:, -1, :], args.temperature, args.top_p) + + # Decode using StaticAttentionIOManager.decode or lookahead_decode + print(f"\n{args.prompt}", end="", flush=True) + print(tokenizer.decode_token(first_token), end="", flush=True) + + decode_start = time.time() + + if args.lookahead: + # Use lookahead (speculative) decoding + print( + f"\n[Using lookahead decoding: ngram={args.ngram_size}, window={args.window_size}, verifications={args.n_verifications}]" + ) + generated_tokens = mgr.lookahead_decode( + model_fn, + first_token, + n=args.max_new_tokens - 1, # -1 because first_token counts + ngram_size=args.ngram_size, + window_size=args.window_size, + n_verifications=args.n_verifications, + stop_tokens=tokenizer.stop_tokens, + ) + else: + # Use standard autoregressive decoding + generated_tokens = mgr.decode( + model_fn, + first_token, + n=args.max_new_tokens - 1, # -1 because first_token counts + stop_tokens=tokenizer.stop_tokens, + ) + + # Print generated tokens (skip first as it's the init_token we already printed) + for token in generated_tokens[1:]: + if token in tokenizer.stop_tokens: + break + print(tokenizer.decode_token(token), end="", flush=True) + + decode_time = time.time() - decode_start + total_generated = len(generated_tokens) + tokens_per_sec = total_generated / decode_time if decode_time > 0 else 0 + + print("\n" + "-" * 50) + print(f"Prefill: {len(prompt_tokens)} tokens in {prefill_time:.2f}s") + print( + f"Decode: {total_generated} tokens in {decode_time:.2f}s ({tokens_per_sec:.2f} tok/s)" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/apple/coreml/scripts/extract_coreml_models.py b/examples/apple/coreml/scripts/extract_coreml_models.py index b3778a22625..593a270186b 100644 --- a/examples/apple/coreml/scripts/extract_coreml_models.py +++ b/examples/apple/coreml/scripts/extract_coreml_models.py @@ -21,7 +21,7 @@ def extract_coreml_models(pte_data: bytes): - program = deserialize_pte_binary(pte_data) + program = deserialize_pte_binary(pte_data).program delegates: List[BackendDelegate] = sum( [execution_plan.delegates for execution_plan in program.execution_plan], [] ) diff --git a/examples/arm/README.md b/examples/arm/README.md index 9cce33bdade..e57644b9f74 100644 --- a/examples/arm/README.md +++ 
b/examples/arm/README.md @@ -10,7 +10,7 @@ The main scripts are `setup.sh`, `run.sh` and `aot_arm_compiler.py`. `setup.sh` will install the needed tools and with --root-dir you can change the path to a scratch folder where it will download and generate build artifacts. If supplied, you must also supply the same folder to run.sh with ---scratch-dir= If not supplied both script will use examples/arm/ethos-u-scratch +--scratch-dir= If not supplied both script will use examples/arm/arm-scratch `run.sh` can be used to build, run and test a model in an easy way and it will call cmake for you and in cases you want to run a simulator it will start it also. The script will call `aot_arm_compiler.py` @@ -89,7 +89,7 @@ $ cd $ ./examples/arm/setup.sh --i-agree-to-the-contained-eula # Step [2] - Setup path to tools, The `setup.sh` script has generated a script that you need to source every time you restart you shell. -$ source examples/arm/ethos-u-scratch/setup_path.sh +$ source examples/arm/arm-scratch/setup_path.sh # Step [3] - build and run ExecuTorch and executor_runner baremetal example application # on a Corstone(TM)-320 FVP to run a simple PyTorch model from a file. diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index d7e1b64e3ca..4c4a1e8eac2 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -9,40 +9,37 @@ import argparse import copy -import json import logging import os +import sys from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import torch from examples.devtools.scripts.export_bundled_program import save_bundled_program -from executorch.backends.arm.arm_backend import ( - ArmCompileSpecBuilder, - is_ethosu, - is_tosa, - is_vgf, -) -from executorch.backends.arm.ethosu import EthosUPartitioner +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec +from executorch.backends.arm.ethosu import EthosUCompileSpec from executorch.backends.arm.quantizer import ( - EthosUQuantizer, + get_symmetric_a16w8_quantization_config, get_symmetric_quantization_config, - TOSAQuantizer, - VgfQuantizer, ) from executorch.backends.arm.tosa import TosaSpecification -from executorch.backends.arm.tosa.partitioner import TOSAPartitioner -from executorch.backends.arm.tosa.specification import get_tosa_spec +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec +from executorch.backends.arm.util._factory import create_partitioner, create_quantizer from executorch.backends.arm.util.arm_model_evaluator import ( - GenericModelEvaluator, - MobileNetV2Evaluator, + evaluate_model, + evaluator_calibration_data, ) -from executorch.backends.arm.vgf import VgfPartitioner +from executorch.backends.arm.vgf import VgfCompileSpec # To use Cortex-M backend +from executorch.backends.cortex_m.passes.convert_to_cortex_m_pass import ( + ConvertToCortexMPass, +) + from executorch.backends.cortex_m.passes.quantized_op_fusion_pass import ( QuantizedOpFusionPass, ) @@ -60,9 +57,11 @@ ExecutorchBackendConfig, to_edge_transform_and_lower, ) -from executorch.exir.backend.compile_spec_schema import CompileSpec + from executorch.extension.export_util.utils import save_pte_program from tabulate import tabulate +from torch.export import ExportedProgram +from torch.fx import GraphModule from torch.utils.data import DataLoader # Quantize model if required using the standard export quantizaion flow. 
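The import changes above replace the old `ArmCompileSpecBuilder`/`is_ethosu`/`is_tosa` helpers with concrete compile-spec classes plus a small factory module. A minimal sketch of how the new pieces fit together for a TOSA target (other targets use `EthosUCompileSpec` or `VgfCompileSpec` the same way; error handling omitted):

```
from executorch.backends.arm.tosa import TosaSpecification
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
from executorch.backends.arm.util._factory import create_partitioner, create_quantizer

# One compile-spec object now carries the target description...
compile_spec = TosaCompileSpec(TosaSpecification.create_from_string("TOSA-1.0+INT"))

# ...and the factory functions pick the matching quantizer/partitioner for it.
quantizer = create_quantizer(compile_spec)
partitioner = create_partitioner(compile_spec)
```

This is what lets `quantize()` and `to_edge_TOSA_delegate()` below drop their per-backend `is_ethosu`/`is_tosa`/`is_vgf` branches.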
@@ -76,98 +75,189 @@ logging.basicConfig(level=logging.WARNING, format=FORMAT) -def get_model_and_inputs_from_name( - model_name: str, model_input: str | None -) -> Tuple[torch.nn.Module, Any]: - """Given the name of an example pytorch model, return it and example inputs. +def _load_example_inputs(model_input: str | None) -> Any: # nosec B614 + """Load example inputs from a `.pt` file when a path is provided.""" + if model_input is None: + return None + + logging.info(f"Load model input from {model_input}") + + if model_input.endswith(".pt"): + return torch.load( + model_input, weights_only=False + ) # nosec B614 trusted artifacts + + raise RuntimeError( + f"Model input data '{model_input}' is not a valid name. Use --model_input " + ".pt e.g. saved with torch.save()" + ) + + +def _load_internal_model( + model_name: str, example_inputs: Any +) -> Optional[Tuple[torch.nn.Module, Any]]: + """Load a bundled example model from the internal `MODELS` mapping.""" + logging.info( + "Loading internal models is deprecated. Use --model_name .py/.pt " + "or a model from examples/models." + ) + + if model_name not in MODELS: + return None + + logging.info(f"Internal model {model_name}") + + model = MODELS[model_name]() + inputs = ( + example_inputs + if example_inputs is not None + else MODELS[model_name].example_input + ) + + return model, inputs + + +def _load_registered_model( + model_name: str, example_inputs: Any +) -> Optional[Tuple[torch.nn.Module, Any]]: + """Load a registered example model from `examples.models`.""" + if model_name not in MODEL_NAME_TO_MODEL: + return None + + logging.warning( + "Using a model from examples/models not all of these are currently supported" + ) + logging.info( + f"Load {model_name} -> {MODEL_NAME_TO_MODEL[model_name]} from examples/models" + ) + + model, tmp_example_inputs, _, _ = EagerModelFactory.create_model( + *MODEL_NAME_TO_MODEL[model_name] + ) + inputs = example_inputs if example_inputs is not None else tmp_example_inputs + + return model, inputs + + +def _load_python_module_model( + model_name: str, example_inputs: Any +) -> Optional[Tuple[torch.nn.Module, Any]]: + """Load a model and inputs from a Python source file. + + The file must define `ModelUnderTest` and `ModelInputs` attributes. - Raises RuntimeError if there is no example model corresponding to the given name. """ - example_inputs = None - if model_input is not None: - logging.info(f"Load model input from {model_input}") - if model_input.endswith(".pt"): - example_inputs = torch.load(model_input, weights_only=False) - else: - raise RuntimeError( - f"Model input data '{model_input}' is not a valid name. Use --model_input .pt e.g. 
saved with torch.save()" - ) + if not model_name.endswith(".py"): + return None - # Case 1: Model is defined in this file - if model_name in models.keys(): - logging.info(f"Internal model {model_name}") - model = models[model_name]() - if example_inputs is None: - example_inputs = models[model_name].example_input - # Case 2: Model is defined in examples/models/ - elif model_name in MODEL_NAME_TO_MODEL.keys(): - logging.warning( - "Using a model from examples/models not all of these are currently supported" - ) - logging.info( - f"Load {model_name} -> {MODEL_NAME_TO_MODEL[model_name]} from examples/models" - ) + logging.info( + f"Load model file {model_name} " + "Variable ModelUnderTest= ModelInputs=" + ) - model, tmp_example_inputs, _, _ = EagerModelFactory.create_model( - *MODEL_NAME_TO_MODEL[model_name] - ) - if example_inputs is None: - example_inputs = tmp_example_inputs - # Case 3: Model is in an external python file loaded as a module. - # ModelUnderTest should be a torch.nn.module instance - # ModelInputs should be a tuple of inputs to the forward function - elif model_name.endswith(".py"): - logging.info( - f"Load model file {model_name} Variable ModelUnderTest= ModelInputs=" - ) - import importlib.util - - # load model's module and add it - spec = importlib.util.spec_from_file_location("tmp_model", model_name) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - model = module.ModelUnderTest - if example_inputs is None: - example_inputs = module.ModelInputs - # Case 4: Model is in an saved model file torch.save(model) - elif model_name.endswith(".pth") or model_name.endswith(".pt"): - logging.info(f"Load model file {model_name}") - model = torch.load(model_name, weights_only=False) - if example_inputs is None: - raise RuntimeError( - f"Model '{model_name}' requires input data specify --model_input .pt" - ) - else: + import importlib.util + + spec = importlib.util.spec_from_file_location("tmp_model", model_name) + if spec is None or spec.loader is None: + raise RuntimeError(f"Unable to load model file {model_name}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + sys.modules["tmp_model"] = module + model = module.ModelUnderTest + inputs = example_inputs if example_inputs is not None else module.ModelInputs + + return model, inputs + + +def _load_serialized_model( + model_name: str, example_inputs: Any +) -> Optional[Tuple[torch.nn.Module, Any]]: # nosec B614 + """Load a serialized Torch model saved via `torch.save`.""" + if not model_name.endswith((".pth", ".pt")): + return None + + logging.info(f"Load model file {model_name}") + + model = torch.load(model_name, weights_only=False) # nosec B614 trusted inputs + if example_inputs is None: raise RuntimeError( - f"Model '{model_name}' is not a valid name. Use --help for a list of available models." + f"Model '{model_name}' requires input data specify --model_input .pt" ) - logging.debug(f"Loaded model: {model}") - logging.debug(f"Loaded input: {example_inputs}") + return model, example_inputs +def get_model_and_inputs_from_name( + model_name: str, model_input: str | None +) -> Tuple[torch.nn.Module, Any]: + """Resolve a model name into a model instance and example inputs. + + Args: + model_name: Identifier for the model. It can be a key in + `MODEL_NAME_TO_MODEL`, a Python module path, or a serialized + model file path. + model_input: Optional path to a `.pt` file containing example inputs. + + Returns: + Tuple of `(model, example_inputs)` ready for compilation. 
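For the `.py` loader above, the referenced file only needs to define the two module-level attributes that `_load_python_module_model` reads. A minimal illustrative model file (the file and class names here are hypothetical):

```
# my_model.py  ->  pass via --model_name my_model.py
import torch


class MyModel(torch.nn.Module):
    def forward(self, x):
        return x + x


ModelUnderTest = MyModel()  # must be a torch.nn.Module instance
ModelInputs = (torch.ones(5, dtype=torch.int32),)  # tuple of inputs to forward()
```

If no `--model_input` is given, `ModelInputs` doubles as the example inputs used for export.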
+ + Raises: + RuntimeError: If the model cannot be resolved or required inputs are + missing. + + """ + example_inputs = _load_example_inputs(model_input) + + loaders = ( + _load_internal_model, + _load_registered_model, + _load_python_module_model, + _load_serialized_model, + ) + + for loader in loaders: + result = loader(model_name, example_inputs) + if result is not None: + model, example_inputs = result + logging.debug(f"Loaded model: {model}") + logging.debug(f"Loaded input: {example_inputs}") + return model, example_inputs + + raise RuntimeError( + f"Model '{model_name}' is not a valid name. Use --help for a list of available models." + ) + + def quantize( - model: torch.nn.Module, + model: GraphModule, model_name: str, - compile_specs: list[CompileSpec], + compile_specs: EthosUCompileSpec | VgfCompileSpec | TosaCompileSpec, example_inputs: Tuple[torch.Tensor], evaluator_name: str | None, evaluator_config: Dict[str, Any] | None, -) -> torch.nn.Module: - """This is the official recommended flow for quantization in pytorch 2.0 export""" + is_int16x8: bool = False, +) -> GraphModule: + """This is the official recommended flow for quantization in pytorch 2.0 + export. + + """ logging.info("Quantizing Model...") logging.debug(f"Original model: {model}") - quantizer = None - if is_ethosu(compile_specs): - quantizer = EthosUQuantizer(compile_specs) - elif is_tosa(compile_specs): - quantizer = TOSAQuantizer(get_tosa_spec(compile_specs)) - elif is_vgf(compile_specs): - quantizer = VgfQuantizer(compile_specs) + + quantizer = create_quantizer(compile_specs) + + if is_int16x8: + if compile_specs.tosa_spec.support_extension("int16"): + operator_config = get_symmetric_a16w8_quantization_config( + is_per_channel=True + ) + else: + raise ValueError( + f"Context TOSA spec {compile_specs.tosa_spec} doesn't support int16" + ) else: - raise RuntimeError("Unsupported compilespecs for quantization!") + operator_config = get_symmetric_quantization_config(is_per_channel=True) - operator_config = get_symmetric_quantization_config() quantizer.set_global(operator_config) m = prepare_pt2e(model, quantizer) @@ -188,46 +278,6 @@ def quantize( return m -# Simple example models -class AddModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x + x - - example_input = (torch.ones(5, dtype=torch.int32),) - can_delegate = True - - -class AddModule2(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - return x + y - - example_input = ( - torch.ones(5, dtype=torch.int32), - torch.ones(5, dtype=torch.int32), - ) - can_delegate = True - - -class AddModule3(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - return (x + y, x + x) - - example_input = ( - torch.ones(5, dtype=torch.int32), - torch.ones(5, dtype=torch.int32), - ) - can_delegate = True - - class QuantAddTest(torch.nn.Module): def __init__(self): super().__init__() @@ -276,48 +326,29 @@ def forward(self, w, x, y, z): can_delegate = True # when quantized -class SoftmaxModule(torch.nn.Module): +class QuantLinearTest(torch.nn.Module): def __init__(self): super().__init__() - self.softmax = torch.nn.Softmax(dim=0) + # Define a simple linear layer + self.linear = torch.nn.Linear(61, 37) def forward(self, x): - z = self.softmax(x) - return z + return self.linear(x) - example_input = (torch.ones(2, 2),) + example_input = (torch.randn([8, 61], dtype=torch.float32),) can_delegate = True -class MultipleOutputsModule(torch.nn.Module): - def 
forward(self, x: torch.Tensor, y: torch.Tensor): - return (x * y, x.sum(dim=-1, keepdim=True)) - - example_input = (torch.randn(10, 4, 5), torch.randn(10, 4, 5)) - can_delegate = True - - -models = { - "add": AddModule, - "add2": AddModule2, - "add3": AddModule3, +MODELS = { "qadd": QuantAddTest, "qadd2": QuantAddTest2, "qops": QuantOpTest, - "softmax": SoftmaxModule, - "MultipleOutputsModule": MultipleOutputsModule, + # TODO: Remove this from here, once we have dedicated MCU test pipeline ready. This is an interim solution. + # See https://github.com/pytorch/executorch/discussions/13944 + "qlinear": QuantLinearTest, } -calibration_data = { - "add": (torch.randn(1, 5),), - "add2": ( - torch.randn(1, 5), - torch.randn(1, 5), - ), - "add3": ( - torch.randn(32, 5), - torch.randn(32, 5), - ), +CALIBRATION_DATA = { "qadd": (torch.randn(32, 2, 1),), "qadd2": ( torch.randn(32, 2, 1), @@ -329,15 +360,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): torch.randn(32, 2, 1) * -0.000001, torch.randn(32, 2, 1) * 1000, ), - "softmax": (torch.randn(32, 2, 2),), -} - -evaluators = { - "generic": GenericModelEvaluator, - "mv2": MobileNetV2Evaluator, } -targets = [ +TARGETS = [ "ethos-u55-32", "ethos-u55-64", "ethos-u55-128", @@ -350,6 +375,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): "vgf", "TOSA-1.0+INT", "TOSA-1.0+FP", + "TOSA-1.0+INT+int16", ] @@ -361,26 +387,14 @@ def get_calibration_data( ): # Firstly, if the model is being evaluated, take the evaluators calibration function if it has one if evaluator_name is not None: - evaluator = evaluators[evaluator_name] - - if hasattr(evaluator, "get_calibrator"): - assert evaluator_config is not None + evaluator_data = evaluator_calibration_data(evaluator_name, evaluator_config) + if evaluator_data is not None: + return evaluator_data - config_path = Path(evaluator_config) - with config_path.open() as f: - config = json.load(f) - - if evaluator_name == "mv2": - return evaluator.get_calibrator( - training_dataset_path=config["training_dataset_path"] - ) - else: - raise RuntimeError(f"Unknown evaluator: {evaluator_name}") - - # If the model is in the calibration_data dictionary, get the data from there + # If the model is in the CALIBRATION_DATA dictionary, get the data from there # This is used for the simple model examples provided - if model_name in calibration_data: - return calibration_data[model_name] + if model_name in CALIBRATION_DATA: + return CALIBRATION_DATA[model_name] # As a last resort, fallback to the scripts previous behavior and return the example inputs return example_inputs @@ -393,20 +407,24 @@ def get_compile_spec( memory_mode: Optional[str] = None, quantize: bool = False, config: Optional[str] = None, -) -> list[CompileSpec]: - spec_builder = None + debug_mode: Optional[str] = None, +) -> TosaCompileSpec | EthosUCompileSpec | VgfCompileSpec: + compile_spec = None if target.startswith("TOSA"): try: tosa_spec = TosaSpecification.create_from_string(target) - except: + except Exception: tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT") - spec_builder = ArmCompileSpecBuilder().tosa_compile_spec(tosa_spec) + compile_spec = TosaCompileSpec(tosa_spec) elif "ethos-u" in target: - spec_builder = ArmCompileSpecBuilder().ethosu_compile_spec( + extra_flags = ["--verbose-operators", "--verbose-cycle-estimate"] + if debug_mode is not None: + extra_flags.append("--enable-debug-db") + compile_spec = EthosUCompileSpec( target, system_config=system_config, memory_mode=memory_mode, - extra_flags="--verbose-operators 
--verbose-cycle-estimate", + extra_flags=extra_flags, config_ini=config, ) elif "vgf" in target: @@ -414,58 +432,18 @@ def get_compile_spec( tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT") else: tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+FP") - spec_builder = ArmCompileSpecBuilder().vgf_compile_spec(tosa_spec) + compile_spec = VgfCompileSpec(tosa_spec) + else: + raise RuntimeError(f"Unkown target {target}") if intermediates is not None: - spec_builder.dump_intermediate_artifacts_to(intermediates) + compile_spec.dump_intermediate_artifacts_to(intermediates) - return spec_builder.build() + if debug_mode is not None: + mode = ArmCompileSpec.DebugMode[debug_mode.upper()] + compile_spec.dump_debug_info(mode) - -def evaluate_model( - model_name: str, - intermediates: str, - model_fp32: torch.nn.Module, - model_int8: torch.nn.Module, - example_inputs: Tuple[torch.Tensor], - evaluator_name: str, - evaluator_config: str | None, -) -> None: - evaluator = evaluators[evaluator_name] - - # Get the path of the TOSA flatbuffer that is dumped - intermediates_path = Path(intermediates) - tosa_paths = list(intermediates_path.glob("*.tosa")) - - if evaluator.REQUIRES_CONFIG: - assert evaluator_config is not None - - config_path = Path(evaluator_config) - with config_path.open() as f: - config = json.load(f) - - if evaluator_name == "mv2": - init_evaluator = evaluator( - model_name, - model_fp32, - model_int8, - example_inputs, - str(tosa_paths[0]), - config["batch_size"], - config["validation_dataset_path"], - ) - else: - raise RuntimeError(f"Unknown evaluator {evaluator_name}") - else: - init_evaluator = evaluator( - model_name, model_fp32, model_int8, example_inputs, str(tosa_paths[0]) - ) - - quant_metrics = init_evaluator.evaluate() - output_json_path = intermediates_path / "quant_metrics.json" - - with output_json_path.open("w") as json_file: - json.dump(quant_metrics, json_file) + return compile_spec def dump_delegation_info(edge, intermediate_files_folder: Optional[str] = None): @@ -489,7 +467,7 @@ def get_args(): "-m", "--model_name", required=True, - help=f"Model file .py/.pth/.pt, builtin model or a model from examples/models. Valid names: {set(list(models.keys())+list(MODEL_NAME_TO_MODEL.keys()))}", + help=f"Model file .py/.pth/.pt or a model from examples/models. Valid names: {set(MODEL_NAME_TO_MODEL.keys())}", ) parser.add_argument( "--model_input", @@ -525,8 +503,8 @@ def get_args(): action="store", required=False, default="ethos-u55-128", - choices=targets, - help=f"For ArmBackend delegated models, pick the target, and therefore the instruction set generated. valid targets are {targets}", + choices=TARGETS, + help=f"For ArmBackend delegated models, pick the target, and therefore the instruction set generated. 
valid targets are {TARGETS}", ) parser.add_argument( "-e", @@ -534,7 +512,7 @@ def get_args(): required=False, nargs="?", const="generic", - choices=["generic", "mv2"], + choices=["generic", "mv2", "deit_tiny", "resnet18"], help="Flag for running evaluation of the model.", ) parser.add_argument( @@ -592,7 +570,7 @@ def get_args(): "--config", required=False, default="Arm/vela.ini", - help="Specify custom vela configuration file (vela.ini)", + help="Specify custom vela configuration file (vela.ini) for Ethos-U targets.", ) parser.add_argument( "--non_strict_export", @@ -604,7 +582,13 @@ def get_args(): parser.add_argument( "--enable_qdq_fusion_pass", action="store_true", - help="Enable the QuantizedOpFusionPass fusion step", + help="Enable the Quantized qdq fusion Op passes", + ) + parser.add_argument( + "--enable_debug_mode", + required=False, + choices=["json", "tosa"], + help="Flag to enable ATen-to-TOSA debug mode and dumping of Vela's debug database.", ) args = parser.parse_args() @@ -624,9 +608,9 @@ def get_args(): torch.ops.load_library(args.so_library) if ( - args.model_name in models.keys() + args.model_name in MODELS.keys() and args.delegate is True - and models[args.model_name].can_delegate is False + and MODELS[args.model_name].can_delegate is False ): raise RuntimeError(f"Model {args.model_name} cannot be delegated.") @@ -711,25 +695,36 @@ def save_bpte_program(exec_prog, original_model: torch.nn.Module, output_name: s save_bundled_program(exec_prog, method_test_suites, output_name) -def quantize_model(args, model: torch.nn.Module, example_inputs, compile_spec): - model_int8 = quantize( +def quantize_model( + args, + model: GraphModule, + example_inputs: Tuple[torch.Tensor], + compile_spec, +) -> Tuple[GraphModule, ExportedProgram]: + + is_int16x8 = True if args.target == "TOSA-1.0+INT+int16" else False + model_quant = quantize( model, args.model_name, compile_spec, example_inputs, args.evaluate, args.evaluate_config, + is_int16x8, ) # Wrap quantized model back into an exported_program exported_program = torch.export.export( - model_int8, example_inputs, strict=args.strict_export + model_quant, example_inputs, strict=args.strict_export ) - return model_int8, exported_program + return model_quant, exported_program def to_edge_TOSA_delegate( - exported_program, args, model: torch.nn.Module, example_inputs + exported_program: ExportedProgram, + args, + model: GraphModule, + example_inputs: Tuple[torch.Tensor], ): # As we can target multiple output encodings, one must # be specified. 
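Pulling the pieces above together, the delegated-and-quantized path boils down to roughly the following flow. This is a condensed sketch of what `quantize_model()`/`to_edge_TOSA_delegate()` do, not the exact implementation: `convert_pt2e` is the standard PT2E conversion step (not visible in this hunk), and calibration is reduced to a single forward call.

```
exported = torch.export.export(model, example_inputs)
graph_module = exported.module()

quantizer = create_quantizer(compile_spec)
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

prepared = prepare_pt2e(graph_module, quantizer)
prepared(*example_inputs)            # calibration (the real script loops over calibration data)
quantized = convert_pt2e(prepared)

edge = to_edge_transform_and_lower(
    torch.export.export(quantized, example_inputs),
    partitioner=[create_partitioner(compile_spec)],
)
exec_prog = edge.to_executorch()
```

For the `TOSA-1.0+INT+int16` target, the only difference is that `get_symmetric_a16w8_quantization_config(is_per_channel=True)` is set as the global config instead.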
@@ -740,23 +735,16 @@ def to_edge_TOSA_delegate( args.memory_mode, args.quantize, args.config, + args.enable_debug_mode, ) - model_int8 = None + model_quant = None if args.quantize: - model_int8, exported_program = quantize_model( + model_quant, exported_program = quantize_model( args, model, example_inputs, compile_spec ) - model = model_int8 - - if is_ethosu(compile_spec): - partitioner = EthosUPartitioner(compile_spec) - elif is_tosa(compile_spec): - partitioner = TOSAPartitioner(compile_spec) - elif is_vgf(compile_spec): - partitioner = VgfPartitioner(compile_spec) - else: - raise RuntimeError(f"Unhandled compile spec: {compile_spec}") + + partitioner = create_partitioner(compile_spec) edge = to_edge_transform_and_lower( exported_program, @@ -766,11 +754,16 @@ def to_edge_TOSA_delegate( ), ) - return model_int8, edge + return model_quant, edge -def to_edge_no_delegate(exported_program, args, model: torch.nn.Module, example_inputs): - model_int8 = None +def to_edge_no_delegate( + exported_program: ExportedProgram, + args, + model: GraphModule, + example_inputs: Tuple[torch.Tensor], +): + model_quant = None if args.quantize: # As we can target multiple output encodings, one must # be specified. @@ -781,11 +774,12 @@ def to_edge_no_delegate(exported_program, args, model: torch.nn.Module, example_ args.memory_mode, args.quantize, args.config, + args.enable_debug_mode, ) model, exported_program = quantize_model( args, model, example_inputs, compile_spec ) - model_int8 = model + model_quant = model edge = to_edge_transform_and_lower( exported_program, @@ -794,25 +788,27 @@ def to_edge_no_delegate(exported_program, args, model: torch.nn.Module, example_ ), ) - return model_int8, edge + return model_quant, edge -def transform_for_cortex_m_backend(edge, args): +def transform_for_cortex_m_backend(edge_program_manager, args): # Let's make sure we are using optimized Cortex M backend # NB: If we can't find and replace ops those are expected to be replaced, # bad things will happen at runtime, like "missing operator" errors! 
# Instantiate the mandatory ReplaceQuantNodesPass - passes = [ReplaceQuantNodesPass()] - - # Conditionally add the QuantizedOpFusionPass + passes = [ReplaceQuantNodesPass] if args.enable_qdq_fusion_pass: - passes.append(QuantizedOpFusionPass()) - - # Apply the passes - edge = edge.transform(passes) - - return edge + passes += [ConvertToCortexMPass, QuantizedOpFusionPass] + current_edge = edge_program_manager + for pass_cls in passes: + transform_pass = ( + pass_cls(current_edge.exported_program()) + if pass_cls.__name__ == "QuantizedLinearFusionPass" + else pass_cls() + ) + current_edge = current_edge.transform([transform_pass]) + return current_edge if __name__ == "__main__": # noqa: C901 @@ -829,20 +825,29 @@ def transform_for_cortex_m_backend(edge, args): exported_program = torch.export.export( model, example_inputs, strict=args.strict_export ) + model = exported_program.module() model_fp32 = model + model_name = os.path.basename(os.path.splitext(args.model_name)[0]) if args.intermediates: os.makedirs(args.intermediates, exist_ok=True) + # We only support Python3.10 and above, so use a later pickle protocol + torch.export.save( + exported_program, + f"{args.intermediates}/{model_name}_exported_program.pt2", + pickle_protocol=5, + ) + # Quantize if required - model_int8 = None + model_quant = None if args.delegate: - model_int8, edge = to_edge_TOSA_delegate( + model_quant, edge = to_edge_TOSA_delegate( exported_program, args, model, example_inputs ) else: - model_int8, edge = to_edge_no_delegate( + model_quant, edge = to_edge_no_delegate( exported_program, args, model, example_inputs ) @@ -867,7 +872,6 @@ def transform_for_cortex_m_backend(edge, args): else: raise e - model_name = os.path.basename(os.path.splitext(args.model_name)[0]) output_name = f"{model_name}" + ( f"_arm_delegate_{args.target}" if args.delegate is True @@ -903,7 +907,7 @@ def transform_for_cortex_m_backend(edge, args): if args.bundleio: # Realize the quantization impact on numerics when generating reference output - reference_model = original_model if not model_int8 else model_int8 + reference_model = original_model if not model_quant else model_quant save_bpte_program(exec_prog, reference_model, output_file_name) print(f"Bundle PTE file saved as {output_file_name}") else: @@ -914,8 +918,9 @@ def transform_for_cortex_m_backend(edge, args): evaluate_model( args.model_name, args.intermediates, + args.target, model_fp32, - model_int8, + model_quant, example_inputs, args.evaluate, args.evaluate_config, diff --git a/examples/arm/asan/CMakeLists.txt b/examples/arm/asan/CMakeLists.txt new file mode 100644 index 00000000000..9d7960fe6ac --- /dev/null +++ b/examples/arm/asan/CMakeLists.txt @@ -0,0 +1,21 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +add_library(executorch_asan STATIC asan_runtime.c) + +target_compile_features(executorch_asan PRIVATE c_std_11) + +target_compile_options( + executorch_asan PRIVATE -fno-sanitize=address -fno-sanitize=kernel-address + -fno-sanitize=undefined +) + +set_target_properties(executorch_asan PROPERTIES OUTPUT_NAME "asan") + +install( + TARGETS executorch_asan + EXPORT ExecuTorchTargets + ARCHIVE DESTINATION lib +) diff --git a/examples/arm/asan/asan_runtime.c b/examples/arm/asan/asan_runtime.c new file mode 100644 index 00000000000..6d7441c4d4e --- /dev/null +++ b/examples/arm/asan/asan_runtime.c @@ -0,0 +1,462 @@ +/* Copyright 2025 Arm Limited and/or its affiliates. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * A lightweight AddressSanitizer runtime tailored for the ExecuTorch bare + * metal examples. The goal is to provide basic memory safety diagnostics while + * keeping the runtime self-contained. + * + * This implementation shares the following characteristics: + * * Shadow memory resolution is 16 bytes per shadow byte. + * * Only coarse grained poisoning is implemented. Consumers should rely on + * __asan_poison_memory_region / __asan_unpoison_memory_region to describe + * invalid regions (for example heap red-zones). + * * Stack poisoning is not implemented: the stack malloc/free stubs fall back + * to the compiler inserted slow path. This keeps the runtime small while + * still enabling heap / global diagnostics. + * * The runtime prints diagnostics and traps on the first detected error. + * + * Note that this does not aim to be a drop-in replacement for compiler-rt's + * runtime. It is intentionally minimal to suit resource constrained bare-metal + * targets and to mirror the structure of the existing ubsan runtime. + */ + +#include +#include +#include +#include +#include + +#ifndef ASAN_RUNTIME_PREFIX +#define ASAN_RUNTIME_PREFIX "[ASAN] " +#endif + +/* Stringification needs two layers so macro arguments expand before '#'. */ +#define ASAN_STRINGIZE_IMPL(x) #x +#define ASAN_STRINGIZE(x) ASAN_STRINGIZE_IMPL(x) + +/* Memory map extracted from the Corstone linker scripts. Update with care. */ +#define ASAN_ITCM_START 0x10000000u +#define ASAN_ITCM_SIZE 0x00080000u + +#define ASAN_BROM_START 0x11000000u +#define ASAN_BROM_SIZE 0x00020000u + +#define ASAN_BRAM_START 0x12000000u +#define ASAN_BRAM_SIZE 0x00200000u + +#define ASAN_DTCM_START 0x30000000u +#define ASAN_DTCM_SIZE 0x00080000u + +#define ASAN_SRAM_START 0x31000000u +#define ASAN_SRAM_SIZE 0x00200000u + +#define ASAN_DDR_START 0x70000000u +#define ASAN_DDR_SIZE 0x10000000u + +/* Shadow setup: 16 bytes of application memory are represented by 1 byte. 
*/ +#define ASAN_SHADOW_SCALE 4u +#define ASAN_SHADOW_GRANULARITY (1u << ASAN_SHADOW_SCALE) +#define ASAN_SHADOW_MASK (ASAN_SHADOW_GRANULARITY - 1u) + +#define ASAN_SHADOW_SIZE(region_size) (((region_size) + ASAN_SHADOW_MASK) >> ASAN_SHADOW_SCALE) + +#define ASAN_SHADOW_SIZE_ITCM ASAN_SHADOW_SIZE(ASAN_ITCM_SIZE) +#define ASAN_SHADOW_SIZE_BROM ASAN_SHADOW_SIZE(ASAN_BROM_SIZE) +#define ASAN_SHADOW_SIZE_BRAM ASAN_SHADOW_SIZE(ASAN_BRAM_SIZE) +#define ASAN_SHADOW_SIZE_DTCM ASAN_SHADOW_SIZE(ASAN_DTCM_SIZE) +#define ASAN_SHADOW_SIZE_SRAM ASAN_SHADOW_SIZE(ASAN_SRAM_SIZE) +#define ASAN_SHADOW_SIZE_DDR ASAN_SHADOW_SIZE(ASAN_DDR_SIZE) + +#define ASAN_SHADOW_TOTAL_SIZE \ + (ASAN_SHADOW_SIZE_ITCM + ASAN_SHADOW_SIZE_BROM + ASAN_SHADOW_SIZE_BRAM + \ + ASAN_SHADOW_SIZE_DTCM + ASAN_SHADOW_SIZE_SRAM + ASAN_SHADOW_SIZE_DDR) + +/* Shadow memory lives in .asan_shadow so the linker can park it in DDR. */ +__attribute__((section(".asan_shadow"), aligned(16))) +static uint8_t g_asan_shadow[ASAN_SHADOW_TOTAL_SIZE]; + +typedef struct { + uintptr_t start; + uintptr_t end; + uint8_t* shadow; + size_t shadow_size; + const char* name; +} asan_region_t; + +static asan_region_t g_regions[] = { + {ASAN_ITCM_START, + ASAN_ITCM_START + ASAN_ITCM_SIZE, + NULL, + ASAN_SHADOW_SIZE_ITCM, + "ITCM"}, + {ASAN_BROM_START, + ASAN_BROM_START + ASAN_BROM_SIZE, + NULL, + ASAN_SHADOW_SIZE_BROM, + "BROM"}, + {ASAN_BRAM_START, + ASAN_BRAM_START + ASAN_BRAM_SIZE, + NULL, + ASAN_SHADOW_SIZE_BRAM, + "BRAM"}, + {ASAN_DTCM_START, + ASAN_DTCM_START + ASAN_DTCM_SIZE, + NULL, + ASAN_SHADOW_SIZE_DTCM, + "DTCM"}, + {ASAN_SRAM_START, + ASAN_SRAM_START + ASAN_SRAM_SIZE, + NULL, + ASAN_SHADOW_SIZE_SRAM, + "SRAM"}, + {ASAN_DDR_START, + ASAN_DDR_START + ASAN_DDR_SIZE, + NULL, + ASAN_SHADOW_SIZE_DDR, + "DDR"}, +}; + +static bool g_asan_initialized = false; + +typedef struct { + uintptr_t start; + uintptr_t end; + const asan_region_t* region; +} asan_check_result; + +typedef struct { + uintptr_t start; + uintptr_t end; + const char* name; +} asan_peripheral_range_t; + +/* Allow-list for memory-mapped peripherals that must bypass checking. */ +static const asan_peripheral_range_t g_peripheral_ranges[] = { + {0xE0000000u, 0xE0100000u, "SCS"}, + {0x40000000u, 0x60000000u, "Peripheral"}, +}; + +static bool asan_region_contains(const asan_region_t* region, + uintptr_t begin, + uintptr_t end) { + return region && begin >= region->start && end <= region->end; +} + +static const char* asan_region_name(const asan_region_t* region) { + return region ? 
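To make the 16-byte shadow granularity concrete, here is a small worked example (written in Python for brevity, mirroring the `ASAN_SHADOW_*` macros above; the specific address is just an illustration):

```
ASAN_SHADOW_SCALE = 4                          # one shadow byte covers 2**4 = 16 bytes
ASAN_SHADOW_MASK = (1 << ASAN_SHADOW_SCALE) - 1

DDR_START, DDR_SIZE = 0x70000000, 0x10000000   # 256 MiB DDR region
shadow_size = (DDR_SIZE + ASAN_SHADOW_MASK) >> ASAN_SHADOW_SCALE
print(hex(shadow_size))                        # 0x1000000 -> 16 MiB of shadow for DDR

addr = 0x70001234                              # hypothetical access inside DDR
shadow_index = (addr - DDR_START) >> ASAN_SHADOW_SCALE
print(hex(shadow_index))                       # 0x123 -> the shadow byte consulted/poisoned
```

A non-zero value at that shadow byte is what `asan_shadow_is_poisoned` reports as an error for any access overlapping those 16 bytes.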
region->name : ""; +} + +static void asan_shadow_set(asan_region_t* region, + uintptr_t begin, + size_t size, + uint8_t value) { + if (!region || size == 0 || region->shadow == NULL) { + return; + } + uintptr_t offset = begin - region->start; + size_t shadow_begin = offset >> ASAN_SHADOW_SCALE; + size_t shadow_end = + (offset + size + ASAN_SHADOW_MASK) >> ASAN_SHADOW_SCALE; + if (shadow_end > region->shadow_size) { + shadow_end = region->shadow_size; + } + if (shadow_begin >= shadow_end) { + return; + } + memset(region->shadow + shadow_begin, value, shadow_end - shadow_begin); +} + +static bool asan_shadow_is_poisoned(const asan_region_t* region, + uintptr_t begin, + size_t size) { + if (!region || size == 0 || region->shadow == NULL) { + return true; + } + uintptr_t offset = begin - region->start; + size_t shadow_begin = offset >> ASAN_SHADOW_SCALE; + size_t shadow_end = + (offset + size + ASAN_SHADOW_MASK) >> ASAN_SHADOW_SCALE; + if (shadow_end > region->shadow_size) { + shadow_end = region->shadow_size; + } + for (size_t idx = shadow_begin; idx < shadow_end; ++idx) { + if (region->shadow[idx] != 0) { + return true; + } + } + return false; +} + +static asan_region_t* asan_find_region(uintptr_t begin, uintptr_t end) { + for (size_t i = 0; i < sizeof(g_regions) / sizeof(g_regions[0]); ++i) { + asan_region_t* region = &g_regions[i]; + if (asan_region_contains(region, begin, end)) { + return region; + } + } + return NULL; +} + +static void asan_report_error(const char* kind, + void* addr, + size_t size, + const char* reason, + const asan_region_t* region) { + printf(ASAN_RUNTIME_PREFIX "%s of size %zu at %p failed: %s (region=%s)\n", + kind, + size, + addr, + reason, + asan_region_name(region)); + fflush(stdout); +#if defined(__GNUC__) + __builtin_trap(); +#else + while (1) { + } +#endif +} + +static bool asan_check_address(const char* kind, void* addr, size_t size) { + if (!g_asan_initialized) { + return true; + } + uintptr_t begin = (uintptr_t)addr; + uintptr_t end = begin + size; + for (size_t i = 0; i < sizeof(g_peripheral_ranges) / sizeof(g_peripheral_ranges[0]); ++i) { + const asan_peripheral_range_t* range = &g_peripheral_ranges[i]; + if (begin >= range->start && end <= range->end) { + return true; + } + } + if (end < begin) { + asan_report_error(kind, addr, size, "overflow in address range", NULL); + return false; + } + asan_region_t* region = asan_find_region(begin, end); + if (!region) { + asan_report_error(kind, addr, size, "address outside tracked regions", NULL); + return false; + } + if (asan_shadow_is_poisoned(region, begin, size)) { + asan_report_error(kind, addr, size, "poisoned shadow", region); + return false; + } + return true; +} + +/* ----------- Sanitizer runtime entry points ----------- */ + +int __asan_option_detect_stack_use_after_return = 0; + +void __asan_init(void) { + if (g_asan_initialized) { + return; + } + uint8_t* shadow_cursor = g_asan_shadow; + for (size_t i = 0; i < sizeof(g_regions) / sizeof(g_regions[0]); ++i) { + g_regions[i].shadow = shadow_cursor; + shadow_cursor += g_regions[i].shadow_size; + /* Mark entire region as accessible by default. 
*/ + asan_shadow_set(&g_regions[i], g_regions[i].start, g_regions[i].end - g_regions[i].start, 0); + } + g_asan_initialized = true; +} + +void __asan_version_mismatch_check_v8(void) {} + +void __asan_handle_no_return(void) {} + +void __asan_poison_memory_region(void* addr, size_t size) { + if (!g_asan_initialized) { + return; + } + uintptr_t begin = (uintptr_t)addr; + uintptr_t end = begin + size; + asan_region_t* region = asan_find_region(begin, end); + if (!region) { + return; + } + asan_shadow_set(region, begin, size, 0xFF); +} + +void __asan_unpoison_memory_region(void* addr, size_t size) { + if (!g_asan_initialized) { + return; + } + uintptr_t begin = (uintptr_t)addr; + uintptr_t end = begin + size; + asan_region_t* region = asan_find_region(begin, end); + if (!region) { + return; + } + asan_shadow_set(region, begin, size, 0x00); +} + +void __asan_alloca_poison(void* addr, size_t size) { + if (!g_asan_initialized) { + return; + } + __asan_poison_memory_region(addr, size); +} + +void __asan_allocas_unpoison(void* top, void* bottom) { + if (!g_asan_initialized) { + return; + } + uintptr_t begin = (uintptr_t)bottom; + uintptr_t end = (uintptr_t)top; + if (end <= begin) { + return; + } + __asan_unpoison_memory_region(bottom, end - begin); +} + +#define ASAN_DEFINE_LOAD_NOABORT(N) \ + void __asan_load##N##_noabort(void* addr) { \ + (void)asan_check_address("load" ASAN_STRINGIZE(N), addr, N); \ + } + +ASAN_DEFINE_LOAD_NOABORT(1) +ASAN_DEFINE_LOAD_NOABORT(2) +ASAN_DEFINE_LOAD_NOABORT(4) +ASAN_DEFINE_LOAD_NOABORT(8) +ASAN_DEFINE_LOAD_NOABORT(16) + +#undef ASAN_DEFINE_LOAD_NOABORT + +void __asan_loadN_noabort(void* addr, size_t size) { + (void)asan_check_address("loadN", addr, size); +} + +#define ASAN_DEFINE_STORE_NOABORT(N) \ + void __asan_store##N##_noabort(void* addr) { \ + (void)asan_check_address("store" ASAN_STRINGIZE(N), addr, N); \ + } + +ASAN_DEFINE_STORE_NOABORT(1) +ASAN_DEFINE_STORE_NOABORT(2) +ASAN_DEFINE_STORE_NOABORT(4) +ASAN_DEFINE_STORE_NOABORT(8) +ASAN_DEFINE_STORE_NOABORT(16) + +#undef ASAN_DEFINE_STORE_NOABORT + +void __asan_storeN_noabort(void* addr, size_t size) { + (void)asan_check_address("storeN", addr, size); +} + +/* The compiler still emits the reporting entry points. Delegate to the same helper. */ +#define ASAN_DEFINE_REPORT_LOAD(N) \ + void __asan_report_load##N(void* addr) { \ + asan_report_error("load" ASAN_STRINGIZE(N), addr, N, "reported hook", NULL); \ + } + +ASAN_DEFINE_REPORT_LOAD(1) +ASAN_DEFINE_REPORT_LOAD(2) +ASAN_DEFINE_REPORT_LOAD(4) +ASAN_DEFINE_REPORT_LOAD(8) +ASAN_DEFINE_REPORT_LOAD(16) + +#undef ASAN_DEFINE_REPORT_LOAD + +void __asan_report_load_n(void* addr, size_t size) { + asan_report_error("loadN", addr, size, "reported hook", NULL); +} + +#define ASAN_DEFINE_REPORT_STORE(N) \ + void __asan_report_store##N(void* addr) { \ + asan_report_error("store" ASAN_STRINGIZE(N), addr, N, "reported hook", NULL); \ + } + +ASAN_DEFINE_REPORT_STORE(1) +ASAN_DEFINE_REPORT_STORE(2) +ASAN_DEFINE_REPORT_STORE(4) +ASAN_DEFINE_REPORT_STORE(8) +ASAN_DEFINE_REPORT_STORE(16) + +#undef ASAN_DEFINE_REPORT_STORE + +void __asan_report_store_n(void* addr, size_t size) { + asan_report_error("storeN", addr, size, "reported hook", NULL); +} + +/* Stubbed APIs required by the instrumentation. + * Intentional no-ops: we rely on the compiler slow path to handle stack + * poisoning so the runtime stays minimal. 
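Returning NULL from the __asan_stack_malloc_* hooks below keeps each instrumented frame on the regular stack, which lines up with __asan_option_detect_stack_use_after_return being left at 0 above.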
*/ +#define ASAN_STACK_MALLOC_FREE(N) \ + void* __asan_stack_malloc_##N(size_t size) { \ + (void)size; \ + return NULL; /* fall back to compiler slow path */ \ + } \ + \ + void __asan_stack_free_##N(void* ptr, size_t size) { \ + (void)ptr; \ + (void)size; \ + } + +ASAN_STACK_MALLOC_FREE(0) +ASAN_STACK_MALLOC_FREE(1) +ASAN_STACK_MALLOC_FREE(2) +ASAN_STACK_MALLOC_FREE(3) +ASAN_STACK_MALLOC_FREE(4) +ASAN_STACK_MALLOC_FREE(5) +ASAN_STACK_MALLOC_FREE(6) +ASAN_STACK_MALLOC_FREE(7) +ASAN_STACK_MALLOC_FREE(8) +ASAN_STACK_MALLOC_FREE(9) +ASAN_STACK_MALLOC_FREE(10) + +#undef ASAN_STACK_MALLOC_FREE + +struct __asan_global { + uintptr_t beg; + size_t size; + size_t size_with_redzone; + const char* name; + const char* module_name; + uintptr_t has_dynamic_init; + uintptr_t location; + uintptr_t odr_indicator; +}; + +void __asan_register_globals(struct __asan_global* globals, size_t n) { + if (!g_asan_initialized) { + return; + } + for (size_t i = 0; i < n; ++i) { + __asan_unpoison_memory_region((void*)globals[i].beg, globals[i].size); + } +} + +void __asan_unregister_globals(struct __asan_global* globals, size_t n) { + if (!g_asan_initialized) { + return; + } + for (size_t i = 0; i < n; ++i) { + __asan_poison_memory_region((void*)globals[i].beg, globals[i].size); + } +} + +void __asan_register_image_globals(struct __asan_global* globals, + size_t n) { + __asan_register_globals(globals, n); +} + +void __asan_unregister_image_globals(struct __asan_global* globals, + size_t n) { + __asan_unregister_globals(globals, n); +} + +/* Weak aliases so that missing hooks do not cause link failures. */ +void __asan_before_dynamic_init(const char* module_name) { + (void)module_name; +} + +void __asan_after_dynamic_init(void) {} + +__attribute__((constructor)) static void asan_runtime_constructor(void) { + __asan_init(); +} diff --git a/examples/arm/ethos-u-setup/core_platform/0001-Remove-hello_world-from-applications.patch b/examples/arm/ethos-u-setup/core_platform/0001-Remove-hello_world-from-applications.patch new file mode 100644 index 00000000000..11590a8578f --- /dev/null +++ b/examples/arm/ethos-u-setup/core_platform/0001-Remove-hello_world-from-applications.patch @@ -0,0 +1,25 @@ +From f6a7d867212336b3e344c21240a2a03671bffd65 Mon Sep 17 00:00:00 2001 +From: Per Held +Date: Wed, 17 Sep 2025 13:46:05 +0200 +Subject: Remove hello_world from applications + +--- + applications/CMakeLists.txt | 2 +- + 1 file changed, 1 insertion(+), 1 deletion(-) + +diff --git a/applications/CMakeLists.txt b/applications/CMakeLists.txt +index a017575..130f0f7 100644 +--- a/applications/CMakeLists.txt ++++ b/applications/CMakeLists.txt +@@ -21,7 +21,7 @@ add_subdirectory(driver_unit_tests) + + add_subdirectory(freertos) + +-add_subdirectory(hello_world) ++#add_subdirectory(hello_world) + + add_subdirectory(threadx_demo) + +-- +2.43.0 + diff --git a/examples/arm/ethos_u_minimal_example.ipynb b/examples/arm/ethos_u_minimal_example.ipynb index e63a7d37e58..ac6e53564eb 100644 --- a/examples/arm/ethos_u_minimal_example.ipynb +++ b/examples/arm/ethos_u_minimal_example.ipynb @@ -58,7 +58,7 @@ "model = Add()\n", "model = model.eval()\n", "exported_program = torch.export.export(model, example_inputs)\n", - "graph_module = exported_program.module()\n", + "graph_module = exported_program.graph_module\n", "\n", "_ = graph_module.print_readable()" ] @@ -80,7 +80,7 @@ "metadata": {}, "outputs": [], "source": [ - "from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder\n", + "from executorch.backends.arm.ethosu import 
EthosUCompileSpec\n", "from executorch.backends.arm.quantizer import (\n", " EthosUQuantizer,\n", " get_symmetric_quantization_config,\n", @@ -90,13 +90,12 @@ "# Create a compilation spec describing the target for configuring the quantizer\n", "# Some args are used by the Arm Vela graph compiler later in the example. Refer to Arm Vela documentation for an\n", "# explanation of its flags: https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela/-/blob/main/OPTIONS.md\n", - "spec_builder = ArmCompileSpecBuilder().ethosu_compile_spec(\n", + "compile_spec = EthosUCompileSpec(\n", " target=\"ethos-u55-128\",\n", " system_config=\"Ethos_U55_High_End_Embedded\",\n", " memory_mode=\"Shared_Sram\",\n", - " extra_flags=\"--output-format=raw --debug-force-regor\"\n", + " extra_flags=[\"--output-format=raw\", \"--debug-force-regor\"]\n", " )\n", - "compile_spec = spec_builder.build()\n", "\n", "# Create and configure quantizer to use a symmetric quantization config globally on all nodes\n", "quantizer = EthosUQuantizer(compile_spec)\n", @@ -161,7 +160,7 @@ " config=ExecutorchBackendConfig(extract_delegate_segments=False)\n", " )\n", "\n", - "_ = executorch_program_manager.exported_program().module().print_readable()\n", + "_ = executorch_program_manager.exported_program().graph_module.print_readable()\n", "\n", "# Save pte file\n", "save_pte_program(executorch_program_manager, \"ethos_u_minimal_example.pte\")" @@ -186,7 +185,7 @@ "source": [ "%%bash\n", "# Ensure the arm-none-eabi-gcc toolchain and FVP:s are available on $PATH\n", - "source ethos-u-scratch/setup_path.sh\n", + "source arm-scratch/setup_path.sh\n", "\n", "# Build executorch libraries cross-compiled for arm baremetal to executorch/cmake-out-arm\n", "cmake --preset arm-baremetal \\\n", @@ -202,7 +201,7 @@ "outputs": [], "source": [ "%%bash \n", - "source ethos-u-scratch/setup_path.sh\n", + "source arm-scratch/setup_path.sh\n", "\n", "# Build example executor runner application to examples/arm/ethos_u_minimal_example\n", "cmake -DCMAKE_TOOLCHAIN_FILE=$(pwd)/ethos-u-setup/arm-none-eabi-gcc.cmake \\\n", @@ -233,7 +232,7 @@ "outputs": [], "source": [ "%%bash \n", - "source ethos-u-scratch/setup_path.sh\n", + "source arm-scratch/setup_path.sh\n", "\n", "# Run the example\n", "../../backends/arm/scripts/run_fvp.sh --elf=ethos_u_minimal_example/arm_executor_runner --target=ethos-u55-128" @@ -242,7 +241,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv (3.10.15)", + "display_name": "et_env", "language": "python", "name": "python3" }, @@ -256,7 +255,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.15" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/arm/example_modules/add.py b/examples/arm/example_modules/add.py index d29206083f8..a3063ea1b25 100644 --- a/examples/arm/example_modules/add.py +++ b/examples/arm/example_modules/add.py @@ -15,13 +15,15 @@ import torch +b = 2 + class myModelAdd(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): - return x + x + return x + x + b ModelUnderTest = myModelAdd() diff --git a/examples/arm/executor_runner/CMakeLists.txt b/examples/arm/executor_runner/CMakeLists.txt index ff6f73398c3..43c42068017 100644 --- a/examples/arm/executor_runner/CMakeLists.txt +++ b/examples/arm/executor_runner/CMakeLists.txt @@ -6,30 +6,59 @@ cmake_minimum_required(VERSION 3.20) project(arm_executor_runner) -option(SEMIHOSTING "Enable semihosting" OFF) -option( - ET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE - 
"Set ET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE to specify memory alloction pool size" - OFF -) option( ET_MODEL_PTE_ADDR "Place in memory that the PTE file is located/flashed, if set to OFF the PTE is built into the code as a big data area." OFF ) -option(ET_BUNDLE_IO "Set to compile in BundleIO support" OFF) -option(ET_ATOL "Set atol to use for BundleIO testing" OFF) -option(ET_RTOL "Set rtol to use for BundleIO testing" OFF) -option(ET_DUMP_INPUT "Dump input in log" OFF) -option(ET_DUMP_OUTPUT "Dump output in log" ON) -option(FETCH_ETHOS_U_CONTENT - "Fetch ethos_u dependencies instead of relying on pre-downloads" ON -) + set(ET_NUM_INFERENCES "1" CACHE STRING "Number of inferences to run" ) +option(ET_LOG_DUMP_INPUT "Dump input in log" OFF) +option(ET_LOG_DUMP_OUTPUT "Dump output in log" ON) + +option(ET_BUNDLE_IO "Set to compile in BundleIO support" OFF) +set(ET_ATOL + "0.01" + CACHE STRING "Set atol to use for BundleIO testing (Requires ET_BUNDLE_IO)" +) +set(ET_RTOL + "0.01" + CACHE STRING "Set atol to use for BundleIO testing (Requires ET_BUNDLE_IO)" +) + +option( + ET_DUMP_OUTPUTS + "Collect and print outputs as a base64 buffer in the log (Requires EXECUTORCH_ENABLE_EVENT_TRACER)" + OFF +) +option( + ET_DUMP_INTERMEDIATE_OUTPUTS + "Collect and print intermediate outputs as a base64 buffer in the log (Requires EXECUTORCH_ENABLE_EVENT_TRACER)" + OFF +) +set(ET_DEBUG_BUFFER_SIZE + "2097152" + CACHE + STRING + "Size of buffer to collect intermediate outputs/outputs buffers (Requires EXECUTORCH_ENABLE_EVENT_TRACER and ET_DUMP_OUTPUTS or ET_DUMP_INTERMEDIATE_OUTPUTS)" +) + +option(SEMIHOSTING "Enable semihosting" OFF) + +option( + ET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE + "Set ET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE to specify memory alloction pool size" + OFF +) + +option(FETCH_ETHOS_U_CONTENT + "Fetch ethos_u dependencies instead of relying on pre-downloads" ON +) + if(NOT DEFINED ET_MODEL_PTE_ADDR AND NOT DEFINED ET_PTE_FILE_PATH AND NOT DEFINED SEMIHOSTING @@ -61,7 +90,7 @@ set(ET_PTE_FILE_PATH CACHE PATH "Path to ExecuTorch model pte" ) set(ETHOS_SDK_PATH - "${ET_DIR_PATH}/examples/arm/ethos-u-scratch/ethos-u" + "${ET_DIR_PATH}/examples/arm/arm-scratch/ethos-u" CACHE PATH "Path to Ethos-U bare metal driver/env" ) set(PYTHON_EXECUTABLE @@ -206,10 +235,10 @@ list( -Map=arm_executor_runner.map ) -# Prefer to generate kernel bindings from model file if possible, which is when -# 1. Not building for semihosting 2. Not building with bundleio If that is not -# the case, fallback to select_ops_list If the model file does not contain any -# aten ops, a workaround is currently needed to avoid crashing. +# Figure out which ops to include: For semihosting build, use +# (user-set)SELECT_OPS_MODEL variable. For normal build, use +# EXECUTORCH_SELECT_OPS_MODEL to include ops automatically. If the pte contains +# no undelegated ops, use neither. 
execute_process( COMMAND python "${ET_DIR_PATH}/codegen/tools/gen_oplist.py" @@ -235,11 +264,6 @@ elseif(${FOUND_OPS_IN_FILE}) message( "gen_oplist: EXECUTORCH_SELECT_OPS_MODEL=${ET_PTE_FILE_PATH} is used to auto generate ops from" ) -elseif(NOT ${FOUND_OPS_IN_FILE} AND ${ET_BUNDLE_IO}) - set(EXECUTORCH_SELECT_OPS_MODEL "") - message( - "gen_oplist: Building with ET_BUNDLE_IO and .bpte is not supported to auto generate ops from will use EXECUTORCH_SELECT_OPS_LIST=${EXECUTORCH_SELECT_OPS_LIST}" - ) else() set(EXECUTORCH_SELECT_OPS_LIST "") set(EXECUTORCH_SELECT_OPS_MODEL "") @@ -302,12 +326,44 @@ endif() # Need whole-archive to ensure C++ ctor's are called - this may be wasteful for # bin size as we link in a number of other symbols -target_link_libraries(arm_executor_runner ${arm_executor_runner_link}) +target_link_libraries(arm_executor_runner PUBLIC ${arm_executor_runner_link}) target_link_options( arm_executor_runner PUBLIC LINKER:-Map=arm_executor_runner.map ) +# Sanitizers +if(CMAKE_BUILD_TYPE MATCHES "UndefinedSanitizer") + set(_et_runner_ubsan_flag -fsanitize=undefined) + target_compile_options(arm_executor_runner PRIVATE ${_et_runner_ubsan_flag}) + target_link_options(arm_executor_runner PRIVATE ${_et_runner_ubsan_flag}) + if(NOT TARGET executorch_ubsan) + add_subdirectory( + ${ET_DIR_PATH}/examples/arm/ubsan + ${CMAKE_CURRENT_BINARY_DIR}/ubsan_runtime + ) + endif() + target_link_directories( + arm_executor_runner PRIVATE $ + ) + target_link_libraries(arm_executor_runner PRIVATE executorch_ubsan) +endif() + +if(CMAKE_BUILD_TYPE MATCHES "AddressSanitizer") + set(_et_runner_asan_flags -fsanitize=kernel-address -fasan-shadow-offset=0x0) + target_compile_options(arm_executor_runner PRIVATE ${_et_runner_asan_flags}) + target_link_options(arm_executor_runner PRIVATE ${_et_runner_asan_flags}) + if(NOT TARGET executorch_asan) + add_subdirectory( + ${ET_DIR_PATH}/examples/arm/asan ${CMAKE_CURRENT_BINARY_DIR}/asan_runtime + ) + endif() + target_link_libraries(arm_executor_runner PRIVATE executorch_asan) + target_compile_definitions( + arm_executor_runner PRIVATE EXECUTORCH_ENABLE_ADDRESS_SANITIZER + ) +endif() + # ET headers and generated headers includes target_include_directories( arm_executor_runner @@ -322,37 +378,29 @@ if(NOT ${ET_MODEL_PTE_ADDR} AND NOT SEMIHOSTING) add_dependencies(arm_executor_runner gen_model_header) endif() -if(SEMIHOSTING) - target_compile_definitions(arm_executor_runner PUBLIC SEMIHOSTING) -endif() - -if(ET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE) +if(ET_MODEL_PTE_ADDR) target_compile_definitions( - arm_executor_runner - PUBLIC - ET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE=${ET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE} + arm_executor_runner PUBLIC -DET_MODEL_PTE_ADDR=${ET_MODEL_PTE_ADDR} ) endif() -target_compile_definitions( - arm_executor_runner - PUBLIC - ET_ARM_BAREMETAL_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE=${ET_ARM_BAREMETAL_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE} -) -if(DEFINED ET_ARM_BAREMETAL_FAST_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE) +if(ET_NUM_INFERENCES) target_compile_definitions( - arm_executor_runner - PUBLIC - ET_ARM_BAREMETAL_FAST_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE=${ET_ARM_BAREMETAL_FAST_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE} + arm_executor_runner PUBLIC ET_NUM_INFERENCES=${ET_NUM_INFERENCES} ) endif() -if(ET_MODEL_PTE_ADDR) - target_compile_definitions( - arm_executor_runner PUBLIC -DET_MODEL_PTE_ADDR=${ET_MODEL_PTE_ADDR} - ) +if(ET_LOG_DUMP_INPUT) + target_compile_definitions(arm_executor_runner PUBLIC -DET_LOG_DUMP_INPUT) endif() +if(ET_LOG_DUMP_OUTPUT) + 
target_compile_definitions(arm_executor_runner PUBLIC -DET_LOG_DUMP_OUTPUT) +endif() + +# Devtool BundleIO: Use Bundle PTE with input and reference output included to +# check if it matches. + if(ET_BUNDLE_IO) target_compile_definitions(arm_executor_runner PUBLIC -DET_BUNDLE_IO) endif() @@ -365,17 +413,50 @@ if(ET_RTOL) target_compile_definitions(arm_executor_runner PUBLIC ET_RTOL=${ET_RTOL}) endif() -if(ET_DUMP_INPUT) - target_compile_definitions(arm_executor_runner PUBLIC -DET_DUMP_INPUT) +# Devtools ETDump: Speed and dumping output + +if(ET_DUMP_OUTPUTS) + target_compile_definitions(arm_executor_runner PUBLIC -DET_DUMP_OUTPUTS) endif() -if(ET_DUMP_OUTPUT) - target_compile_definitions(arm_executor_runner PUBLIC -DET_DUMP_OUTPUT) +if(ET_DUMP_INTERMEDIATE_OUTPUTS) + target_compile_definitions( + arm_executor_runner PUBLIC -DET_DUMP_INTERMEDIATE_OUTPUTS + ) endif() -if(ET_NUM_INFERENCES) +if(ET_DEBUG_BUFFER_SIZE) target_compile_definitions( - arm_executor_runner PUBLIC ET_NUM_INFERENCES=${ET_NUM_INFERENCES} + arm_executor_runner PUBLIC ET_DEBUG_BUFFER_SIZE=${ET_DEBUG_BUFFER_SIZE} + ) +endif() + +# Semihosting FVP (FVP Simulator can access host filesystem) + +if(SEMIHOSTING) + target_compile_definitions(arm_executor_runner PUBLIC SEMIHOSTING) +endif() + +# Memory buffer sizes for Executorch flow + +if(ET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE) + target_compile_definitions( + arm_executor_runner + PUBLIC + ET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE=${ET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE} + ) +endif() + +target_compile_definitions( + arm_executor_runner + PUBLIC + ET_ARM_BAREMETAL_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE=${ET_ARM_BAREMETAL_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE} +) +if(DEFINED ET_ARM_BAREMETAL_FAST_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE) + target_compile_definitions( + arm_executor_runner + PUBLIC + ET_ARM_BAREMETAL_FAST_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE=${ET_ARM_BAREMETAL_FAST_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE} ) endif() diff --git a/examples/arm/executor_runner/Corstone-300.ld b/examples/arm/executor_runner/Corstone-300.ld index f5b063a35c6..e5f5b2a1410 100644 --- a/examples/arm/executor_runner/Corstone-300.ld +++ b/examples/arm/executor_runner/Corstone-300.ld @@ -237,6 +237,13 @@ SECTIONS * (ethosu_core_out_queue) . = ALIGN(4); } > DDR :rom_dram + .asan_shadow (NOLOAD) : + { + . = ALIGN(16); + __asan_shadow_start = .; + KEEP(*(.asan_shadow)) + __asan_shadow_end = .; + } > DDR :null .ddr_noload (NOLOAD) : { . = ALIGN(16); diff --git a/examples/arm/executor_runner/Corstone-320.ld b/examples/arm/executor_runner/Corstone-320.ld index 62bb6240913..8f9b1e826a1 100644 --- a/examples/arm/executor_runner/Corstone-320.ld +++ b/examples/arm/executor_runner/Corstone-320.ld @@ -252,6 +252,13 @@ SECTIONS /* Place data for scatter loading here */ __etext = .; } > DDR :rom_dram + .asan_shadow (NOLOAD) : + { + . = ALIGN(16); + __asan_shadow_start = .; + KEEP(*(.asan_shadow)) + __asan_shadow_end = .; + } > DDR :null .ddr_noload (NOLOAD) : { . = ALIGN(16); diff --git a/examples/arm/executor_runner/arm_executor_runner.cpp b/examples/arm/executor_runner/arm_executor_runner.cpp index d56710e27ad..89ebcd292f7 100644 --- a/examples/arm/executor_runner/arm_executor_runner.cpp +++ b/examples/arm/executor_runner/arm_executor_runner.cpp @@ -6,10 +6,10 @@ * LICENSE file in the root directory of this source tree. */ -/* This is an example executorch runner running on Arm Cortex-m and Ethos-U +/* This is an example ExecuTorch runner running on Arm Cortex-M and Ethos-U * based hardware. 
This example tries to illustrate a few ways to use ExecuTorch * and you can use it as is or remove the unneeded parts. Please use this code - * as inpiration. + * as inspiration. * * Some defines used to configure the code: * @@ -20,24 +20,43 @@ * that is controlled by your memory mode via the * ETHOSU_MODEL cmake parameter. * If SEMIHOSTING is define this is not used - * ET_DUMP_INPUT - Control if you want input to be dumped to the log. - * ET_DUMP_OUTPUT - Control if you want output to be dumped to the log. - * ET_BUNDLE_IO - Build in devtools BundelIO, this makes it possible to + * ET_NUM_INFERENCES - Number of times to run the inference + * ET_LOG_DUMP_INPUT - Control if you want input to be dumped to the log. + * ET_LOG_DUMP_OUTPUT - Control if you want output to be dumped to the log. + * + * Devtool BundleIO: Use Bundle PTE with input and reference output included to + * check if it matches. + * + * ET_BUNDLE_IO - Build in Devtools BundleIO, this makes it possible to * use bpte with bundled input and output refdata to * compare output. * See also ET_ATOL and ET_RTOL - * ET_ATOL - The atol used to compare the output and ref data when - * using ET_BUNDLE_IO - * ET_RTOL - The rtol used to compare the output and ref data when - * using ET_BUNDLE_IO - * ET_EVENT_TRACER_ENABLED - Build in devtools event trace code to generate - * ETDump and print it base64 coded of it in the logs - * so you can get it out of your embedded target. - * This can be used to benchmark where time is spent. - * If you run on Ethos-U the delegate/commandstream - * is run in one go, this means that per op - * measurements is not possible. - * Warning: CPU time meassurements is NOT possible in the FVP simulator and a + * ET_ATOL - The atol used to compare the output and ref data + * when using ET_BUNDLE_IO + * ET_RTOL - The rtol used to compare the output and ref data when using ET_BUNDLE_IO + * + * Devtools ETDump: Speed and dumping output + * + * ET_EVENT_TRACER_ENABLED - Build in Devtools ETDump event trace code + * to generate cycle data and print it base64 + * coded in the log so you can get it out of + * your embedded target. This can be used to + * benchmark where time is spent. If you run + * on Ethos-U the delegate/commandstream is + * run in one go, this means that per op + * measurements are not possible. + * ET_DUMP_OUTPUTS - Collect and print outputs as a base64 buffer + * in the log, see ExecuTorch Devtools for more + * info. (Requires ET_EVENT_TRACER_ENABLED) + * ET_DUMP_INTERMEDIATE_OUTPUTS - Collect and print intermediate outputs as a + * base64 buffer in the log, see ExecuTorch + * Devtools for more info. + * (Requires ET_EVENT_TRACER_ENABLED) + * ET_DEBUG_BUFFER_SIZE - Override the size of memory area used by + * ET_DUMP_OUTPUTS or + * ET_DUMP_INTERMEDIATE_OUTPUTS + * + * Warning: CPU time measurements are NOT possible in the FVP simulator and a * real target or FPGA must be used. NPU number are roughly OK, and can be used * as guidance if timeing adaptor values are set correctly. * @@ -54,11 +73,12 @@ * left over memory after code is linked. This needs to be big enough to fit * and run your model. In our example using the FVP simulator we have much * memory and set this quite high to be able to test larger models. - * Regarding heap/mallocs type of allocation from executorch, + * Regarding heap/mallocs type of allocation from ExecuTorch, * et_pal_allocate() is not implemented or needed.
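+ * Allocations instead come from the fixed pools below, served by ArmMemoryAllocator (see arm_memory_allocator.h), a simple arena with no per-allocation free.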
* - * ET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE - Size of memory area - * used when setting up the model + * ET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE - Size of memory area + * used when setting up + * the model * ET_ARM_BAREMETAL_FAST_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE - Size of memory area * used when running * inferences @@ -67,6 +87,7 @@ #include #include #include +#include #include #include #include @@ -75,6 +96,7 @@ #include #include #include +#include #include #include "arm_memory_allocator.h" @@ -86,10 +108,21 @@ #if defined(ET_EVENT_TRACER_ENABLED) #include + +#if defined(ET_DUMP_INTERMEDIATE_OUTPUTS) || defined(ET_DUMP_OUTPUTS) +#include + +#if !defined(ET_DEBUG_BUFFER_SIZE) +#define ET_DEBUG_BUFFER_SIZE (2 * 1024 * 1024) +#endif + +#endif + #if !defined(SEMIHOSTING) #include #endif -#endif + +#endif // defined(ET_EVENT_TRACER_ENABLED) #if defined(SEMIHOSTING) @@ -152,14 +185,17 @@ using executorch::runtime::Result; using executorch::runtime::Span; using executorch::runtime::Tag; using executorch::runtime::TensorInfo; +using executorch::runtime::toString; #if defined(ET_BUNDLE_IO) using executorch::bundled_program::compute_method_output_error_stats; using executorch::bundled_program::ErrorStats; using executorch::bundled_program::verify_method_outputs; #endif #if defined(ET_EVENT_TRACER_ENABLED) +using executorch::etdump::BufferDataSink; using executorch::etdump::ETDumpGen; using executorch::etdump::ETDumpResult; +using executorch::runtime::EventTracerDebugLogLevel; using torch::executor::etdump_result; #endif /** @@ -362,6 +398,19 @@ class Box { } }; +template +void fill_tensor_with_default_value(Tensor& tensor) { + ValueType fill_value{}; + if constexpr (std::is_same_v) { + fill_value = true; + } else { + fill_value = ValueType(1); + } + + ValueType* data_ptr = tensor.mutable_data_ptr(); + std::fill(data_ptr, data_ptr + tensor.numel(), fill_value); +} + Error prepare_input_tensors( Method& method, MemoryAllocator& allocator, @@ -377,8 +426,7 @@ Error prepare_input_tensors( "Wrong number of inputs allocated compared to method"); #endif - EValue* input_evalues = - static_cast(allocator.allocate(num_inputs * sizeof(EValue*))); + EValue* input_evalues = allocator.allocateList(num_inputs); ET_CHECK_OR_RETURN_ERROR( input_evalues != nullptr, MemoryAllocationFailed, @@ -420,23 +468,18 @@ Error prepare_input_tensors( if (input_evalues[i].isTensor()) { Tensor& tensor = input_evalues[i].toTensor(); switch (tensor.scalar_type()) { - case ScalarType::Int: - std::fill( - tensor.mutable_data_ptr(), - tensor.mutable_data_ptr() + tensor.numel(), - 1); - break; - case ScalarType::Float: - std::fill( - tensor.mutable_data_ptr(), - tensor.mutable_data_ptr() + tensor.numel(), - 1.0); - break; - case ScalarType::Char: - std::fill( - tensor.mutable_data_ptr(), - tensor.mutable_data_ptr() + tensor.numel(), - 1); +#define HANDLE_SCALAR_TYPE(cpp_type, scalar_name) \ + case ScalarType::scalar_name: \ + fill_tensor_with_default_value(tensor); \ + break; + ET_FORALL_SCALAR_TYPES(HANDLE_SCALAR_TYPE) +#undef HANDLE_SCALAR_TYPE + default: + ET_LOG( + Error, + "Unhandled ScalarType %s", + toString(tensor.scalar_type())); + err = Error::InvalidArgument; break; } } else { @@ -505,6 +548,9 @@ struct RunnerContext { Box> method; #if defined(ET_EVENT_TRACER_ENABLED) Box etdump_gen; +#if defined(ET_DUMP_INTERMEDIATE_OUTPUTS) || defined(ET_DUMP_OUTPUTS) + void* debug_buffer; +#endif #endif #if defined(SEMIHOSTING) Box input_file_allocator; @@ -539,7 +585,7 @@ void runner_init( } #endif auto loader = 
BufferDataLoader(program_data, ctx.program_data_len); - ET_LOG(Info, "PTE Model data loaded. Size: %lu bytes.", ctx.program_data_len); + ET_LOG(Info, "PTE Model data loaded. Size: %zu bytes.", ctx.program_data_len); // Parse the program file. This is immutable, and can also be reused // between multiple execution invocations across multiple threads. @@ -552,7 +598,7 @@ void runner_init( program.error()); } - ET_LOG(Info, "Model buffer loaded, has %lu methods", program->num_methods()); + ET_LOG(Info, "Model buffer loaded, has %zu methods", program->num_methods()); { const auto method_name_result = program->get_method_name(0); @@ -572,7 +618,7 @@ void runner_init( ET_LOG( Info, - "Setup Method allocator pool. Size: %lu bytes.", + "Setup Method allocator pool. Size: %zu bytes.", method_allocation_pool_size); ctx.method_allocator.reset( @@ -622,7 +668,60 @@ void runner_init( ET_LOG(Info, "Setting up ETDump"); ctx.etdump_gen.reset(); event_tracer_ptr = &ctx.etdump_gen.value(); -#endif + +#if defined(ET_DUMP_INTERMEDIATE_OUTPUTS) || defined(ET_DUMP_OUTPUTS) + // Alloc debug buffer and create if and only if we need to log intermediate + // tensor outputs + ctx.debug_buffer = ctx.method_allocator->allocate(ET_DEBUG_BUFFER_SIZE, 16); + if (ctx.debug_buffer != nullptr) { + Span debug_buffer_span( + (uint8_t*)ctx.debug_buffer, ET_DEBUG_BUFFER_SIZE); + + Result result = + ctx.etdump_gen.value().set_debug_buffer(debug_buffer_span); + + if (result.ok()) { + // Everything worked, we got the buffer setup, lets enable output logging + // depending on the compile flag ET_DUMP_INTERMEDIATE_OUTPUTS e.g. + // kIntermediateOutputs or kProgramOutputs +#if defined(ET_DUMP_INTERMEDIATE_OUTPUTS) + ET_LOG( + Info, + "ETDump: Allocated intermediate output buffer size: %d at 0x%p", + ET_DEBUG_BUFFER_SIZE, + ctx.debug_buffer); + ctx.etdump_gen.value().set_event_tracer_debug_level( + EventTracerDebugLogLevel::kIntermediateOutputs); +#else // defined(ET_DUMP_INTERMEDIATE_OUTPUTS) + ET_LOG( + Info, + "ETDump: Allocated output buffer size: %d at 0x%p", + ET_DEBUG_BUFFER_SIZE, + ctx.debug_buffer); + ctx.etdump_gen.value().set_event_tracer_debug_level( + EventTracerDebugLogLevel::kProgramOutputs); +#endif // defined(ET_DUMP_INTERMEDIATE_OUTPUTS) + + } else { + // set_debug_buffer() failed + // Here we would free ctx.debug_buffer if it was possible, but we can't as + // the allocator don't support it. 
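+      // Clearing the pointer only makes the later ctx.debug_buffer != nullptr checks skip the debug path; the allocation itself remains in the method allocator pool.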
+ ctx.debug_buffer = nullptr; + ET_LOG( + Error, + "ETDump: Could not set_debug_buffer() for output buffer size %zu error:0x%" PRIx32, + ET_DEBUG_BUFFER_SIZE, + result.error()); + } + } else { + // debug buffer allocation failed + ET_LOG( + Error, + "ETDump: Could not allocate memory for output buffer size %zu", + ET_DEBUG_BUFFER_SIZE); + } +#endif // defined(ET_DUMP_INTERMEDIATE_OUTPUTS) || defined(ET_DUMP_OUTPUTS) +#endif // defined(ET_EVENT_TRACER_ENABLED) ctx.method.reset( program->load_method(ctx.method_name, &memory_manager, event_tracer_ptr)); @@ -660,7 +759,7 @@ void runner_init( ET_CHECK_MSG( status == Error::Ok, "Failed to prepare inputs 0x%" PRIx32, status); } -#if defined(ET_DUMP_INPUT) +#if defined(ET_LOG_DUMP_INPUT) { std::vector inputs((*ctx.method.value())->inputs_size()); ET_LOG(Info, "%zu inputs: ", inputs.size()); @@ -712,22 +811,22 @@ void runner_init( ET_LOG(Info, "Input prepared."); } -void log_mem_status(const RunnerContext& ctx) { +void log_mem_status(RunnerContext& ctx) { size_t executor_memsize = ctx.method_allocator->used_size() - ctx.executor_membase; #if defined(ET_MODEL_PTE_ADDR) ET_LOG( Info, - "model_pte_program_size: %lu bytes. (pte size unknown when not baked into elf)", + "model_pte_program_size: %zu bytes. (pte size unknown when not baked into elf)", ctx.program_data_len); ET_LOG( Info, - "model_pte_loaded_size: %lu bytes. (pte size unknown when not baked into elf)", + "model_pte_loaded_size: %zu bytes. (pte size unknown when not baked into elf)", ctx.pte_size); #else - ET_LOG(Info, "model_pte_program_size: %lu bytes.", ctx.program_data_len); - ET_LOG(Info, "model_pte_loaded_size: %lu bytes.", ctx.pte_size); + ET_LOG(Info, "model_pte_program_size: %zu bytes.", ctx.program_data_len); + ET_LOG(Info, "model_pte_loaded_size: %zu bytes.", ctx.pte_size); #endif #if defined(SEMIHOSTING) @@ -765,6 +864,20 @@ void log_mem_status(const RunnerContext& ctx) { if (ctx.temp_allocator->size() > 0) { ET_LOG(Info, "temp_allocator: %zu", ctx.temp_allocator->size()); } +#if defined(ET_EVENT_TRACER_ENABLED) +#if defined(ET_DUMP_INTERMEDIATE_OUTPUTS) || defined(ET_DUMP_OUTPUTS) + if (ctx.debug_buffer != nullptr) { + size_t outputdump_len = ctx.etdump_gen->get_data_sink()->get_used_bytes(); + ET_LOG( + Info, + "ETDump_outputs_buffer: %zu / %zu free: %zu ( used: %zu %% ) ", + outputdump_len, + ET_DEBUG_BUFFER_SIZE, + ET_DEBUG_BUFFER_SIZE - outputdump_len, + 100 * outputdump_len / ET_DEBUG_BUFFER_SIZE); + } +#endif +#endif } void print_outputs(RunnerContext& ctx) { @@ -779,7 +892,7 @@ void print_outputs(RunnerContext& ctx) { if (outputs[i].isTensor()) { Tensor tensor = outputs[i].toTensor(); #if !defined(SEMIHOSTING) -#if defined(ET_DUMP_OUTPUT) +#if defined(ET_LOG_DUMP_OUTPUT) // The output might be collected and parsed so printf() is used instead // of ET_LOG() here for (int j = 0; j < tensor.numel(); ++j) { @@ -811,7 +924,7 @@ void print_outputs(RunnerContext& ctx) { } } #endif -#else +#else //! defined(SEMIHOSTING) char out_filename[255]; snprintf(out_filename, 255, "%s-%d.bin", ctx.output_basename, i); ET_LOG(Info, "Writing output to file: %s", out_filename); @@ -819,7 +932,7 @@ void print_outputs(RunnerContext& ctx) { auto written_size = fwrite(tensor.const_data_ptr(), 1, tensor.nbytes(), out_file); fclose(out_file); -#endif +#endif //! 
defined(SEMIHOSTING) } else { printf("Output[%d]: Not Tensor\n", i); } @@ -835,29 +948,96 @@ void write_etdump(RunnerContext& ctx) { if (result.buf != nullptr && result.size > 0) { // On a device with no file system we can't just write it out // to the file-system so we base64 encode it and dump it on the log. + bool dump_outputs = false; int mode = base64_enc_modifier_padding | base64_dec_modifier_skipspace; - size_t len = result.size; - size_t encoded_len = base64_encoded_size(result.size, mode); + size_t etdump_len = result.size; + size_t encoded_etdump_len = base64_encoded_size(etdump_len, mode); + size_t base64buffer_len = encoded_etdump_len; +#if defined(ET_DUMP_INTERMEDIATE_OUTPUTS) || defined(ET_DUMP_OUTPUTS) + // Make base64 buffer fit both so it can be reused istead of allocating two + // buffers. + size_t outputdump_len = 0; + size_t encoded_outputdump_len = 0; + if (ctx.debug_buffer != nullptr) { + outputdump_len = ctx.etdump_gen->get_data_sink()->get_used_bytes(); + if (outputdump_len > 0) { + encoded_outputdump_len = base64_encoded_size(outputdump_len, mode); + if (encoded_outputdump_len > 0) { + base64buffer_len = + std::max(encoded_etdump_len, encoded_outputdump_len); + dump_outputs = true; + } else { + ET_LOG( + Error, + "Problem getting the size of the base64 ETDump output buffers"); + } + } else { + ET_LOG(Error, "No ETDump output buffers saved in the data area"); + } + } +#endif + ET_LOG(Info, "[base64] buffer size: %d", base64buffer_len); + uint8_t* encoded_buf = reinterpret_cast( - ctx.method_allocator->allocate(encoded_len + 1)); + ctx.method_allocator->allocate(base64buffer_len + 1)); if (encoded_buf != nullptr) { - int ret = base64_encode( - encoded_buf, (uint8_t*)result.buf, &encoded_len, &len, mode); - encoded_buf[encoded_len] = 0x00; // Ensure null termination - ET_LOG(Info, "Writing etdump.bin [base64]"); + int ret; + const char* debug_buffer_flag = ""; + printf("#[RUN THIS]\n"); +#if defined(ET_DUMP_INTERMEDIATE_OUTPUTS) || defined(ET_DUMP_OUTPUTS) + if (dump_outputs) { + ret = base64_encode( + encoded_buf, + (uint8_t*)ctx.debug_buffer, + &encoded_outputdump_len, + &outputdump_len, + mode); + encoded_buf[encoded_outputdump_len] = 0x00; // Ensure null termination + printf("# Writing debug_buffer.bin [base64]\n"); + printf("echo \"%s\" | base64 -d >debug_buffer.bin\n", encoded_buf); + debug_buffer_flag = "--debug_buffer_path debug_buffer.bin"; + } +#endif + ret = base64_encode( + encoded_buf, + (uint8_t*)result.buf, + &encoded_etdump_len, + &etdump_len, + mode); + encoded_buf[encoded_etdump_len] = 0x00; // Ensure null termination + printf("# Writing etdump.bin [base64]\n"); + printf("echo \"%s\" | base64 -d >etdump.bin\n", encoded_buf); + + printf("# Generate cpu cycle table with:\n"); printf( - "#[RUN THIS]\necho \"%s\" | base64 -d >etdump.bin\npython3 -m devtools.inspector.inspector_cli --etdump_path etdump.bin --source_time_scale cycles --target_time_scale cycles\n#[END]\n", - encoded_buf); + "python3 -m devtools.inspector.inspector_cli --etdump_path etdump.bin %s --source_time_scale cycles --target_time_scale cycles\n", + debug_buffer_flag); + printf("#[END]\n"); + } else { ET_LOG( Error, "Could not allocate memory etdump base64 encoding size %zu", - encoded_len + 1); + encoded_etdump_len + 1); } } -#else - // Dump the etdump data containing profiling/debugging data to the specified - // file. 
+#else // !defined(SEMIHOSTING) +#if defined(ET_DUMP_INTERMEDIATE_OUTPUTS) || defined(ET_DUMP_OUTPUTS) + if (ctx.debug_buffer != nullptr) { + // Dump the etdump outputs data to a file. + size_t outputdump_len = ctx.etdump_gen->get_data_sink()->get_used_bytes(); + const char* etdump_output_filename = "debug_buffer.bin"; + ET_LOG( + Info, + "Writing etdump debug_buffer to file: %s", + etdump_output_filename); + FILE* f = fopen(etdump_output_filename, "w+"); + fwrite((uint8_t*)ctx.debug_buffer, 1, outputdump_len, f); + fclose(f); + } +#endif + + // Dump the etdump data containing profiling/debugging data to a file. etdump_result result = ctx.etdump_gen->get_etdump_data(); if (result.buf != nullptr && result.size > 0) { // On a device with a file system we can just write it out @@ -869,11 +1049,12 @@ void write_etdump(RunnerContext& ctx) { fclose(f); free(result.buf); } -#endif -#endif +#endif // !defined(SEMIHOSTING) +#endif // defined(ET_EVENT_TRACER_ENABLED) } -void verify_result(RunnerContext& ctx, const void* model_pte) { +bool verify_result(RunnerContext& ctx, const void* model_pte) { + bool model_ok = false; #if defined(ET_BUNDLE_IO) if (ctx.bundle_io) { // Check result @@ -899,6 +1080,7 @@ void verify_result(RunnerContext& ctx, const void* model_pte) { if (status == Error::Ok) { ET_LOG(Info, "Model output match expected BundleIO bpte ref data."); ET_LOG(Info, "TEST: BundleIO index[%d] Test_result: PASS", testset_idx); + model_ok = true; } else { ET_LOG( Error, @@ -906,19 +1088,24 @@ void verify_result(RunnerContext& ctx, const void* model_pte) { et_rtol, et_atol); ET_LOG(Error, "TEST: BundleIO index[%d] Test_result: FAIL", testset_idx); + ET_LOG( + Error, "Bundle verification failed with status 0x%" PRIx32, status); + model_ok = false; } - ET_CHECK_MSG( - status == Error::Ok, - "Bundle verification failed with status 0x%" PRIx32, - status); + } else { + // No checking done, assume true + model_ok = true; } -#else +#else // defined(ET_BUNDLE_IO) (void)ctx; (void)model_pte; -#endif + // No checking done, assume true + model_ok = true; +#endif // defined(ET_BUNDLE_IO) + return model_ok; } -void run_model(RunnerContext& ctx, const void* model_pte) { +bool run_model(RunnerContext& ctx, const void* model_pte) { Error status; ET_LOG(Info, "Starting running %d inferences...", num_inferences); int n = 0; @@ -946,7 +1133,10 @@ void run_model(RunnerContext& ctx, const void* model_pte) { ET_LOG(Info, "%d inferences finished", num_inferences); print_outputs(ctx); - verify_result(ctx, model_pte); + bool model_ok = verify_result(ctx, model_pte); + ET_LOG(Info, "Model run: %d", model_ok); + + return model_ok; } } // namespace @@ -1047,10 +1237,14 @@ int main(int argc, const char* argv[]) { model_pte[7]); runner_init(ctx, input_buffers, pte_size); - run_model(ctx, model_pte); + bool model_ok = run_model(ctx, model_pte); + ET_LOG(Info, "Model run: %d", model_ok); + log_mem_status(ctx); write_etdump(ctx); + ET_CHECK_MSG(model_ok == true, "Problem running model"); + ET_LOG(Info, "Program complete, exiting."); #if defined(SEMIHOSTING) _exit(0); diff --git a/examples/arm/executor_runner/arm_memory_allocator.cpp b/examples/arm/executor_runner/arm_memory_allocator.cpp index 6b627625ae1..de670df29ae 100644 --- a/examples/arm/executor_runner/arm_memory_allocator.cpp +++ b/examples/arm/executor_runner/arm_memory_allocator.cpp @@ -6,12 +6,38 @@ #include "arm_memory_allocator.h" +#if defined(EXECUTORCH_ENABLE_ADDRESS_SANITIZER) +extern "C" { +void __asan_poison_memory_region(void* addr, size_t size); +void 
__asan_unpoison_memory_region(void* addr, size_t size); +} + +static void asan_poison_buffer(uint8_t* base, size_t size) { + if (base != nullptr && size > 0) { + __asan_poison_memory_region(base, size); + } +} + +static void asan_unpoison_buffer(void* base, size_t size) { + if (base != nullptr && size > 0) { + __asan_unpoison_memory_region(base, size); + } +} +#endif + ArmMemoryAllocator::ArmMemoryAllocator(uint32_t size, uint8_t* base_address) - : MemoryAllocator(size, base_address), used_(0) {} + : MemoryAllocator(size, base_address), used_(0) { +#if defined(EXECUTORCH_ENABLE_ADDRESS_SANITIZER) + asan_poison_buffer(base_address, size); +#endif +} void* ArmMemoryAllocator::allocate(size_t size, size_t alignment) { void* ret = executorch::runtime::MemoryAllocator::allocate(size, alignment); if (ret != nullptr) { +#if defined(EXECUTORCH_ENABLE_ADDRESS_SANITIZER) + asan_unpoison_buffer(ret, size); +#endif // Align with the same code as in MemoryAllocator::allocate() to keep // used_ "in sync" As alignment is expected to be power of 2 (checked by // MemoryAllocator::allocate()) we can check it the lower bits @@ -37,4 +63,7 @@ size_t ArmMemoryAllocator::free_size() const { void ArmMemoryAllocator::reset() { executorch::runtime::MemoryAllocator::reset(); used_ = 0; +#if defined(EXECUTORCH_ENABLE_ADDRESS_SANITIZER) + asan_poison_buffer(base_address(), size()); +#endif } diff --git a/examples/arm/executor_runner/arm_perf_monitor.cpp b/examples/arm/executor_runner/arm_perf_monitor.cpp index 58a47105743..35fd114f777 100644 --- a/examples/arm/executor_runner/arm_perf_monitor.cpp +++ b/examples/arm/executor_runner/arm_perf_monitor.cpp @@ -19,7 +19,7 @@ namespace { #if defined(ETHOSU55) || defined(ETHOSU65) const uint32_t ethosu_pmuCountersUsed = 4; #elif defined(ETHOSU85) -const uint32_t ethosu_pmuCountersUsed = 5; +const uint32_t ethosu_pmuCountersUsed = 7; #else #error No NPU target defined #endif @@ -65,11 +65,14 @@ void ethosu_inference_begin(struct ethosu_driver* drv, void*) { ETHOSU_PMU_Set_EVTYPER(drv, 2, ETHOSU_PMU_EXT_RD_DATA_BEAT_RECEIVED); ETHOSU_PMU_Set_EVTYPER(drv, 3, ETHOSU_PMU_EXT_WR_DATA_BEAT_WRITTEN); ETHOSU_PMU_Set_EVTYPER(drv, 4, ETHOSU_PMU_NPU_IDLE); - // Enable the 5 counters + ETHOSU_PMU_Set_EVTYPER(drv, 5, ETHOSU_PMU_MAC_ACTIVE); + ETHOSU_PMU_Set_EVTYPER(drv, 6, ETHOSU_PMU_WD_ACTIVE); + // Enable the 7 counters ETHOSU_PMU_CNTR_Enable( drv, ETHOSU_PMU_CNT1_Msk | ETHOSU_PMU_CNT2_Msk | ETHOSU_PMU_CNT3_Msk | - ETHOSU_PMU_CNT4_Msk | ETHOSU_PMU_CNT5_Msk); + ETHOSU_PMU_CNT4_Msk | ETHOSU_PMU_CNT5_Msk | ETHOSU_PMU_CNT6_Msk | + ETHOSU_PMU_CNT7_Msk); #else #error No NPU target defined #endif @@ -214,7 +217,7 @@ void StopMeasurements(int num_inferences) { #elif defined(ETHOSU85) ET_LOG( Info, - "Ethos-U PMU Events:[ETHOSU_PMU_SRAM_RD_DATA_BEAT_RECEIVED, ETHOSU_PMU_SRAM_WR_DATA_BEAT_WRITTEN, ETHOSU_PMU_EXT_RD_DATA_BEAT_RECEIVED, ETHOSU_PMU_EXT_WR_DATA_BEAT_WRITTEN, ETHOSU_PMU_NPU_IDLE]"); + "Ethos-U PMU Events:[ETHOSU_PMU_SRAM_RD_DATA_BEAT_RECEIVED, ETHOSU_PMU_SRAM_WR_DATA_BEAT_WRITTEN, ETHOSU_PMU_EXT_RD_DATA_BEAT_RECEIVED, ETHOSU_PMU_EXT_WR_DATA_BEAT_WRITTEN, ETHOSU_PMU_NPU_IDLE, ETHOSU_PMU_MAC_ACTIVE, ETHOSU_PMU_WD_ACTIVE]"); #else #error No NPU target defined #endif diff --git a/examples/arm/executor_runner/pte_to_header.py b/examples/arm/executor_runner/pte_to_header.py index 1b5fad05a12..65213bc729e 100644 --- a/examples/arm/executor_runner/pte_to_header.py +++ b/examples/arm/executor_runner/pte_to_header.py @@ -59,7 +59,7 @@ def input_file_path(path): if __name__ == "__main__": args = 
parser.parse_args() outfile = os.path.join(args.outdir, args.outfile) - attr = f'__attribute__((section("{args.section}"), aligned(16))) char ' + attr = f'__attribute__((section("{args.section}"), aligned(16))) unsigned char ' with open(args.pte, "rb") as fr, open(outfile, "w") as fw: data = fr.read() diff --git a/examples/arm/pruning_minimal_example.ipynb b/examples/arm/pruning_minimal_example.ipynb new file mode 100644 index 00000000000..37bffffd763 --- /dev/null +++ b/examples/arm/pruning_minimal_example.ipynb @@ -0,0 +1,566 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c0156802", + "metadata": {}, + "source": [ + "# Copyright 2025 Arm Limited and/or its affiliates.\n", + "#\n", + "# This source code is licensed under the BSD-style license found in the\n", + "# LICENSE file in the root directory of this source tree." + ] + }, + { + "cell_type": "markdown", + "id": "26b849fd", + "metadata": {}, + "source": [ + "# Introduction\n", + "Model conditioning techniques like pruning modify the weights of a Machine Learning model and in some cases allow significant speed-up of the inference execution, reduction of the memory footprint and reduction in the overall power consumption of the system. Assuming you can optimise your workload without loss in accuracy and you target an Arm® Ethos™ NPU or a GPU with a Neural Engine, you should consider pruning the neural network before compiling it in the to_edge_transform_and_lower stage." + ] + }, + { + "cell_type": "markdown", + "id": "9a7d6d97", + "metadata": {}, + "source": [ + "# Why apply model conditioning?\n", + "The Ethos-U hardware has a dedicated weight decoder to process the model weights. At the same time, the compiler arranges the weights into blocks and the blocks are then fed to the hardware weight decoder. As part of the block arrangement process, the compiler compresses sequences of zero weights and clusters of weights. To avoid any doubt, the compression by the compiler is lossless - to the same input tensor, irrespective of whether compression was applied or not, the output tensor from execution on the NPU will be the same. If the model you provide in the to_edge_transform_and_lower stage is optimised to have sequences of zero weights and/or clusters of the same weights, the compiler will be able to compress these weights very efficiently. The good compression would result in lower number of memory accesses by the NPU at runtime, which would mean that the MAC engines are not waiting on memory accesses resulting in better overall performance. In other words, if you have a memory bound model, you should consider pruning and clustering your neural network before lowering it in the to_edge_transform_and_lower stage.\n", + "\n", + "The Ethos-U85 hardware also has hardware support for 2:4 sparse weights - if you have 2:4 sparse weights, the MAC array will skip multiplications where the result will be 0. The 2:4 sparsity allow power savings for all configurations and provides a speed-up on compute-bound neural networks.\n", + "\n", + "Before we begin, make sure you are running the Jupyter notebook from the correct python virtual environment variable." + ] + }, + { + "cell_type": "markdown", + "id": "d6532247", + "metadata": {}, + "source": [ + "# Prerequisites\n", + "Let's import python the packages you will need to run through the jupyter notebook." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a8a191d7", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torchvision import datasets, transforms\n", + "from torch import nn\n", + "import torch.nn.utils.prune as prune\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import DataLoader, Subset\n", + "import random\n", + "\n", + "from executorch.backends.arm.ethosu import EthosUPartitioner\n", + "from executorch.exir import (\n", + " EdgeCompileConfig,\n", + " ExecutorchBackendConfig,\n", + " to_edge_transform_and_lower,\n", + ")\n", + "from executorch.backends.arm.ethosu import EthosUCompileSpec\n", + "from executorch.backends.arm.quantizer import (\n", + " EthosUQuantizer,\n", + " get_symmetric_quantization_config,\n", + ")\n", + "from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e\n", + "from executorch.extension.export_util.utils import save_pte_program" + ] + }, + { + "cell_type": "markdown", + "id": "6af794bc", + "metadata": {}, + "source": [ + "# Model conditioning with PyTorch and deployment with ExecuTorch \n", + "We'll define a simple model with 3 back-to-back Linear layers. We will execute the model on the Ethos-U85 NPU, then we will prune the model and execute the pruned variant on the Ethos-U85 and compare the performance." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3e37c2ce", + "metadata": {}, + "outputs": [], + "source": [ + "LR = 1e-3\n", + "NUM_EPOCHS = 1\n", + "BATCH_SIZE = 128\n", + "\n", + "# Data\n", + "transform = transforms.Compose([transforms.ToTensor()])\n", + "train_ds = datasets.MNIST(\"./data\", train=True, download=True, transform=transform)\n", + "test_ds = datasets.MNIST(\"./data\", train=False, transform=transform)\n", + "\n", + "train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)\n", + "test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)\n", + "\n", + "class Simple_NN(nn.Module): \n", + " def __init__(self):\n", + " super().__init__()\n", + " self.flatten = nn.Flatten()\n", + " self.fc1 = nn.Linear(28 * 28, 512)\n", + " self.fc2 = nn.Linear(512, 256)\n", + " self.fc3 = nn.Linear(256, 10)\n", + "\n", + " def forward(self, x):\n", + " x = self.flatten(x)\n", + " x = F.relu(self.fc1(x))\n", + " x = F.relu(self.fc2(x))\n", + " x = self.fc3(x)\n", + " return x\n", + " \n", + " def prunable_parameters(self):\n", + " return (\n", + " (self.fc1, \"weight\"),\n", + " (self.fc2, \"weight\"),\n", + " (self.fc3, \"weight\"),\n", + " )\n", + "\n", + " def prune(self, pruning_method: prune.BasePruningMethod, amount: float = 0.1):\n", + " # reference https://pytorch.org/tutorials/intermediate/pruning_tutorial.html\n", + "\n", + " # produces a mask that is multiplied with the parameter\n", + " prune.global_unstructured(\n", + " self.prunable_parameters(),\n", + " pruning_method=pruning_method,\n", + " amount=amount,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "6db1e58d", + "metadata": {}, + "source": [ + "We define a simple model with 3 back-to-back linear layers. Linear is highly memory bound operation because every weight is read once only from the external memory. It is impossible to buffer the weights in memory(you usually have more weights in the external memory than space in the SARM) and reuse them for the computation. In comparison, in a convolution you usually have small filter sizes(e.g. 
3x3 filter) which means you can buffer all the convolution weights in memory and reuse them for the computation. If your model or module within the model is composed entirely of Linear layers, the workload will be memory bound and pruning is likely to provide good speed-up.\n", + "\n", + "Next, let's define a simple function to train the network and a function to evaluate the accuracy of the model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "477312ae", + "metadata": {}, + "outputs": [], + "source": [ + "# Training loop\n", + "def train(model):\n", + " # The model is simple enough that we can train it on CPU\n", + " device = \"cpu\"\n", + " for epoch in range(NUM_EPOCHS):\n", + " # ---- Training ----\n", + " model.train()\n", + " opt = torch.optim.Adam(model.parameters(), lr=LR)\n", + " criterion = torch.nn.CrossEntropyLoss()\n", + " for step, (inp, out_real) in enumerate(train_loader):\n", + " inp, out_real = inp.to(device), out_real.to(device)\n", + " opt.zero_grad()\n", + " out_pred = model(inp)\n", + " loss = criterion(out_pred, out_real)\n", + " #print(f\"Loss: {loss.item():.4f}\")\n", + " loss.backward()\n", + " opt.step()\n", + "\n", + "def evaluate(model):\n", + " # ---- Evaluation ----\n", + " correct, total = 0, 0\n", + " with torch.no_grad():\n", + " for inp, out_real in test_loader:\n", + " out_pred = model(inp)\n", + " preds = out_pred.argmax(1)\n", + " correct += (preds == out_real).sum().item()\n", + " total += out_real.size(0)\n", + "\n", + " acc = 100 * correct / total\n", + " print(f\"Top 1 accuracy = {acc:.2f}%\")" + ] + }, + { + "cell_type": "markdown", + "id": "a4750eaf", + "metadata": {}, + "source": [ + "Let's instantiate the model and train it. In order to get reproducible results, we will fix the seed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc68a7d9", + "metadata": {}, + "outputs": [], + "source": [ + "SEED = 123\n", + "torch.manual_seed(SEED)\n", + "model = Simple_NN()\n", + "train(model)\n", + "print(\"Evaluate FP32 model accuracy\")\n", + "evaluate(model)" + ] + }, + { + "cell_type": "markdown", + "id": "9837d9ba", + "metadata": {}, + "source": [ + "We obtain 96% top1 accuracy for the FP32 model.\n", + "\n", + "Next, we would like to apply post-training quantization with ExecuTorch and evaluate the accuracy of the quantized model. It is important to calibrate the quantized model on a few real samples from the MNIST dataset to get good quantization parameters." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "855c542f", + "metadata": {}, + "outputs": [], + "source": [ + "# MNIST images are 28x28 in greyscale, hence the shape is 1x1x28x28\n", + "example_inputs = (torch.randn(1,1,28,28),)\n", + "exported_program = torch.export.export(model, example_inputs)\n", + "graph_module = exported_program.module(check_guards=False)\n", + "\n", + "# Create a compilation spec describing the target for configuring the quantizer\n", + "compile_spec = EthosUCompileSpec(\n", + " target=\"ethos-u85-128\",\n", + " system_config=\"Ethos_U85_SYS_Flash_High\",\n", + " memory_mode=\"Shared_Sram\",\n", + " extra_flags=[\"--output-format=raw\", \"--debug-force-regor --verbose-weights\"]\n", + " )\n", + "\n", + "# Create and configure quantizer to use a symmetric quantization config globally on all nodes\n", + "quantizer = EthosUQuantizer(compile_spec)\n", + "operator_config = get_symmetric_quantization_config()\n", + "quantizer.set_global(operator_config)\n", + "\n", + "# Post training quantization, need a few example images to obtain good quantization parameters\n", + "subset_indices = random.sample(range(len(train_ds)), 50)\n", + "calibration_set = Subset(train_ds, subset_indices)\n", + "calibration_loader = DataLoader(calibration_set, shuffle=False)\n", + "\n", + "quantized_graph_module = prepare_pt2e(graph_module, quantizer)\n", + "for batch_images,label in calibration_loader:\n", + " quantized_graph_module(*batch_images) # Calibrate the graph module with the example input\n", + "quantized_graph_module = convert_pt2e(quantized_graph_module)" + ] + }, + { + "cell_type": "markdown", + "id": "996faefd", + "metadata": {}, + "source": [ + "Next, let us evaluate the accuracy of the quantized model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63da2b30", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Accuracy of the quantized model\")\n", + "evaluate(quantized_graph_module)" + ] + }, + { + "cell_type": "markdown", + "id": "2ff3462c", + "metadata": {}, + "source": [ + "We maintain the 96% top1 accuracy for the quantized model. Next, let's compile the model for the Ethos-U backend. We will define a function `generate_pte` that calls `to_edge_transform_and_lower` and saves the pte file on device." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa8259f4", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_pte(quantized_exported_program,compile_spec,name):\n", + " # Create partitioner from compile spec\n", + " partitioner = EthosUPartitioner(compile_spec)\n", + "\n", + " # Lower the exported program to the Ethos-U backend\n", + " edge_program_manager = to_edge_transform_and_lower(\n", + " quantized_exported_program,\n", + " partitioner=[partitioner],\n", + " compile_config=EdgeCompileConfig(\n", + " _check_ir_validity=False,\n", + " ),\n", + " )\n", + "\n", + " # Convert edge program to executorch\n", + " executorch_program_manager = edge_program_manager.to_executorch(\n", + " config=ExecutorchBackendConfig(extract_delegate_segments=False)\n", + " )\n", + "\n", + " # Save pte file\n", + " save_pte_program(executorch_program_manager, f\"{name}.pte\")\n", + "\n", + "# Create a new exported program using the quantized_graph_module\n", + "quantized_exported_program = torch.export.export(quantized_graph_module, example_inputs)\n", + "generate_pte(quantized_exported_program,compile_spec,\"original_model\")" + ] + }, + { + "cell_type": "markdown", + "id": "2b6cae04", + "metadata": {}, + "source": [ + "Note that as part of the compilation process in `to_edge_transform_and_lower`, we get Weight Compression information:\n", + "```\n", + "Original Weights Size 522.50 KiB\n", + "NPU Encoded Weights Size 507.44 KiB\n", + "```\n", + "In other words, the original Weights are 522KB and after compilation and encoding by the compiler, we get 507KB of weights that will be read by the NPU at runtime. Remember this is for the case when we've not applied pruning or clustering. This will generate original_model.pte file that we will deploy on device later on. \n", + "\n", + "Next, let's move on to prune the model and evaluate its accuracy. We have a lot of weights in the original network, so we will apply 95% pruning rate." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "493eed60", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Prune the model\")\n", + "model.prune(pruning_method=prune.L1Unstructured, amount=0.95)\n", + "print(\"Evaluate pruned model accuracy\")\n", + "evaluate(model)" + ] + }, + { + "cell_type": "markdown", + "id": "82460ba6", + "metadata": {}, + "source": [ + "We obtain 37% top1 accuracy for the pruned model. That can seem surprising at first sight, but remember that when we prune, we randomly set 95% of the weights to 0. It is normal to lose accuracy when applying pruning. We need to retrain the model in order to recover the accuracy we've lost from the pruning. We can do that easily by calling the train function one more time. Once we are done with the retraining, it is important to remove the parameters we've pruned." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c816ad25", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Train the pruned model to recover the lost information\")\n", + "train(model)\n", + "# Remove the pruned parameters when we've retrained the model and recovered the lost accuracy\n", + "for a,b in model.prunable_parameters():\n", + " prune.remove(a, b)\n", + "\n", + "print(\"Evaluate pruned model accuracy\")\n", + "evaluate(model)" + ] + }, + { + "cell_type": "markdown", + "id": "fbb70d47", + "metadata": {}, + "source": [ + "We obtain 96% top1 accuracy for the pruned workload so we have recovered the accuracy we've lost with the pruning. 
Let's quantize the pruned model, evaluate the accuracy of the int8 network and obtain a pte file." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3cdb0f59", + "metadata": {}, + "outputs": [], + "source": [ + "pruned_exported_program = torch.export.export(model, example_inputs)\n", + "pruned_graph_module = pruned_exported_program.module(check_guards=False)\n", + "quantized_pruned_graph_module = prepare_pt2e(pruned_graph_module, quantizer)\n", + "for batch_images,label in calibration_loader:\n", + " quantized_pruned_graph_module(*batch_images) # Calibrate the graph module with the example input\n", + "quantized_pruned_graph_module = convert_pt2e(quantized_pruned_graph_module)\n", + "print(\"Accuracy of the pruned quantized model\")\n", + "evaluate(quantized_pruned_graph_module)\n", + "\n", + "quantized_ep_pruned = torch.export.export(quantized_pruned_graph_module, example_inputs)\n", + "generate_pte(quantized_ep_pruned,compile_spec,\"pruned_model\")" + ] + }, + { + "cell_type": "markdown", + "id": "4263714e", + "metadata": {}, + "source": [ + "We obtain 96% top1 accuracy of the quantized pruned model. What is interesting is that this time, the NPU encoded weights size shrank considerably:\n", + "```\n", + "Original Weights Size 522.50 KiB\n", + "NPU Encoded Weights Size 46.12 KiB\n", + "```\n", + "In other words, we are now solving the MNIST classification problem with just 46KB of encoded weights. This is a significant reduction from the 507KB we had in the original model.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "562fdb16", + "metadata": {}, + "source": [ + "# NPU performance\n", + "In the sections above, we generated two pte files - one pte for the original model and another pte for the pruned model. These models perform very similarly in terms of accuracy. 
Let's run both of these models on the NPU and analyse the performance at runtime.\n", + "\n", + "# Performance of the original model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4bdd91dc", + "metadata": {}, + "outputs": [], + "source": [ + "%%bash\n", + "# Ensure the arm-none-eabi-gcc toolchain and FVPs are available on $PATH\n", + "source arm-scratch/setup_path.sh\n", + "\n", + "# Build executorch libraries cross-compiled for arm baremetal to executorch/cmake-out-arm\n", + "cmake --preset arm-baremetal \\\n", + "-DCMAKE_BUILD_TYPE=Release \\\n", + "-B../../cmake-out-arm ../..\n", + "cmake --build ../../cmake-out-arm --target install -j$(nproc) " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "756ab779", + "metadata": {}, + "outputs": [], + "source": [ + "%%bash \n", + "source arm-scratch/setup_path.sh\n", + "# Build example executor runner application to examples/arm/ethos_u_minimal_example\n", + "cmake -DCMAKE_TOOLCHAIN_FILE=$(pwd)/ethos-u-setup/arm-none-eabi-gcc.cmake \\\n", + " -DCMAKE_BUILD_TYPE=Release \\\n", + " -DET_PTE_FILE_PATH=original_model.pte \\\n", + " -DTARGET_CPU=cortex-m55 \\\n", + " -DETHOSU_TARGET_NPU_CONFIG=ethos-u85-128 \\\n", + " -DMEMORY_MODE=Shared_Sram \\\n", + " -DSYSTEM_CONFIG=Ethos_U85_SYS_DRAM_Mid \\\n", + " -Bethos_u_original_model \\\n", + " executor_runner\n", + "cmake --build ethos_u_original_model -j$(nproc) -- arm_executor_runner" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0a525a09", + "metadata": {}, + "outputs": [], + "source": [ + "%%bash \n", + "source arm-scratch/setup_path.sh\n", + "# Run the original model\n", + "../../backends/arm/scripts/run_fvp.sh --elf=ethos_u_original_model/arm_executor_runner --target=ethos-u85-128" + ] + }, + { + "cell_type": "markdown", + "id": "23ebdc46", + "metadata": {}, + "source": [ + "We obtain a total of 99k NPU Active cycles. The MAC engines of the NPU are active for 8k cycles and the Weight Decoder is active for 74k NPU cycles. It's worth noting that the data flow in the Ethos-U is pipelined: the MAC array and the Weight Decoder work at the same time. With a total of 99k NPU cycles but only 8k active MAC cycles against 74k Weight Decoder active cycles, the NPU spends most of its time decoding weights and the MAC array is underutilized. Pruning is designed to alleviate that bottleneck. 
Let's analyse the performance of the pruned workload.\n", + "\n", + "# Performance of the pruned model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e7c09926", + "metadata": {}, + "outputs": [], + "source": [ + "%%bash \n", + "source arm-scratch/setup_path.sh\n", + "\n", + "# Build example executor runner application to examples/arm/ethos_u_minimal_example\n", + "cmake -DCMAKE_TOOLCHAIN_FILE=$(pwd)/ethos-u-setup/arm-none-eabi-gcc.cmake \\\n", + " -DCMAKE_BUILD_TYPE=Release \\\n", + " -DET_PTE_FILE_PATH=pruned_model.pte \\\n", + " -DTARGET_CPU=cortex-m55 \\\n", + " -DETHOSU_TARGET_NPU_CONFIG=ethos-u85-128 \\\n", + " -DMEMORY_MODE=Shared_Sram \\\n", + " -DSYSTEM_CONFIG=Ethos_U85_SYS_DRAM_Mid \\\n", + " -Bethos_u_pruned_model \\\n", + " executor_runner\n", + "cmake --build ethos_u_pruned_model -j$(nproc) -- arm_executor_runner" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "891947f7", + "metadata": {}, + "outputs": [], + "source": [ + "%%bash \n", + "source arm-scratch/setup_path.sh\n", + "# Run the pruned model\n", + "../../backends/arm/scripts/run_fvp.sh --elf=ethos_u_pruned_model/arm_executor_runner --target=ethos-u85-128" + ] + }, + { + "cell_type": "markdown", + "id": "e55ae929", + "metadata": {}, + "source": [ + "On the pruned model, the inference completes in 22k NPU cycles. The NPU still performs 8k MACs, but this time the number of cycles when the weight decoder is active has dropped to 17k cycles. \n", + "It's also worth noting that the size of the pte file has been reduced significantly - from 518 KB for the original model to 57 KB for the pruned workload. " + ] + }, + { + "cell_type": "markdown", + "id": "d934fe41", + "metadata": {}, + "source": [ + "# Conclusion\n", + "We defined a simple model to classify the MNIST dataset. The model uses Linear layers and is heavily memory-bound on the external memory. We pruned the model and obtained similar int8 accuracy between the original workload and the pruned counterpart. Let us put the results from the runtime in a table and draw a few conclusions: \n", + "\n", + "| Model | NPU_ACTIVE cycles | NPU Encoded Weight Size | Weight Decoder Active Cycles | External memory beats read | Size of the pte file |\n", + "| ----------------------------------------|----------------- | ------------------------- | -----------------------------|---------------------------------|-----------------------|\n", + "| Original model | 97k | 506 KB | 74k | 32k | 517 KB |\n", + "| Pruned model | 22k | 46 KB | 8k | 3k | 57 KB |\n", + "\n", + "For the pruned network, we obtain a significant uplift - over 3x improvement in the inference speed and a drastic reduction in the number of cycles when the Weight Decoder is active. The NPU will consume lower power and the size of the pruned model that we save on-device is significantly smaller compared to the original network."
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv_py3.10", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/arm/run.sh b/examples/arm/run.sh index 77dddfe6451..3e743905655 100755 --- a/examples/arm/run.sh +++ b/examples/arm/run.sh @@ -36,11 +36,13 @@ config="" memory_mode="" pte_placement="elf" et_build_root="${et_root_dir}/arm_test" -ethos_u_scratch_dir=${script_dir}/ethos-u-scratch +arm_scratch_dir=${script_dir}/arm-scratch scratch_dir_set=false toolchain=arm-none-eabi-gcc select_ops_list="aten::_softmax.out" qdq_fusion_op=false +model_explorer=false +perf_overlay=false function help() { echo "Usage: $(basename $0) [options]" @@ -53,13 +55,13 @@ function help() { echo " --no_quantize Do not quantize the model (can't override builtin models)" echo " --portable_kernels= TO BE DEPRECATED: Alias to select_ops_list." echo " --select_ops_list= Comma separated list of portable (non delagated) kernels to include Default: ${select_ops_list}" - echo " NOTE: This is used when select_ops_model is not possible to use, e.g. for semihosting or bundleio." + echo " NOTE: This is only used when building for semihosting." echo " See https://docs.pytorch.org/executorch/stable/kernel-library-selective-build.html for more information." echo " --target= Target to build and run for Default: ${target}" echo " --output= Target build output folder Default: ${output_folder}" echo " --bundleio Create Bundled pte using Devtools BundelIO with Input/RefOutput included" echo " --etdump Adds Devtools etdump support to track timing, etdump area will be base64 encoded in the log" - echo " --build_type= Build with Release, Debug or RelWithDebInfo, default is ${build_type}" + echo " --build_type= Build with Release, Debug, RelWithDebInfo, UndefinedSanitizer or AddressSanitizer, default is ${build_type}" echo " --extra_build_flags= Extra flags to pass to cmake like -DET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE=60000 Default: none " echo " --build_only Only build, don't run" echo " --toolchain= Ethos-U: Toolchain can be specified (e.g. bare metal as arm-none-eabi-gcc or zephyr as arm-zephyr-eabi-gcc Default: ${toolchain}" @@ -69,8 +71,10 @@ function help() { echo " --memory_mode= Ethos-U: Memory mode to select from the Vela configuration file (see vela.ini), e.g. Shared_Sram/Sram_Only. Default: 'Shared_Sram' for Ethos-U55 targets, 'Sram_Only' for Ethos-U85 targets" echo " --pte_placement= Ethos-U: Control if runtime has PTE baked into the elf or if its placed in memory outside of the elf, defaults to ${pte_placement}" echo " --et_build_root= Executorch build output root folder to use, defaults to ${et_build_root}" - echo " --scratch-dir= Path to your Ethos-U scrach dir if you not using default ${ethos_u_scratch_dir}" + echo " --scratch-dir= Path to your Arm scrach dir if you not using default ${arm_scratch_dir}" echo " --qdq_fusion_op Enable QDQ fusion op" + echo " --model_explorer Enable model explorer to visualize TOSA graph." + echo " --perf_overlay With --model_explorer, include performance data from FVP PMU trace." 
exit 0 } @@ -97,13 +101,20 @@ for arg in "$@"; do --memory_mode=*) memory_mode="${arg#*=}";; --pte_placement=*) pte_placement="${arg#*=}";; --et_build_root=*) et_build_root="${arg#*=}";; - --scratch-dir=*) ethos_u_scratch_dir="${arg#*=}" ; scratch_dir_set=true ;; + --scratch-dir=*) arm_scratch_dir="${arg#*=}" ; scratch_dir_set=true ;; --qdq_fusion_op) qdq_fusion_op=true;; + --model_explorer) model_explorer=true ;; + --perf_overlay) perf_overlay=true ;; *) ;; esac done +if [ "$perf_overlay" = true ] && [ "$model_explorer" != true ]; then + echo "Error: --perf_overlay requires --model_explorer" >&2 + exit 1 +fi + if ! [[ ${pte_placement} == "elf" ]]; then if ! [[ "$pte_placement" =~ ^0x[0-9a-fA-F]{1,16}$ ]]; then echo "ERROR: Placing the PTE in memory failed, address is larger then 64bit $pte_placement" @@ -113,8 +124,8 @@ if ! [[ ${pte_placement} == "elf" ]]; then fi # Default Ethos-u tool folder override with --scratch-dir= -ethos_u_scratch_dir=$(realpath ${ethos_u_scratch_dir}) -setup_path_script=${ethos_u_scratch_dir}/setup_path.sh +arm_scratch_dir=$(realpath ${arm_scratch_dir}) +setup_path_script=${arm_scratch_dir}/setup_path.sh if [[ ${toolchain} == "arm-none-eabi-gcc" ]]; then toolchain_cmake=${et_root_dir}/examples/arm/ethos-u-setup/${toolchain}.cmake elif [[ ${toolchain} == "arm-zephyr-eabi-gcc" ]]; then @@ -201,6 +212,7 @@ bundleio_flag="" etrecord_flag="" et_dump_flag="" qdq_fusion_op_flag="" +fvp_pmu_flag="" if [ "$build_with_etdump" = true ] ; then et_dump_flag="--etdump" etrecord_flag="--etrecord" @@ -222,7 +234,6 @@ if [[ -z "$model_name" ]]; then test_model=( "softmax" # 0 "add" # 1 - "add3" # 2 "qadd" # 3 "qadd2" # 4 "qops" # 5 @@ -231,7 +242,6 @@ if [[ -z "$model_name" ]]; then model_compiler_flags=( "" # 0 softmax "--delegate" # 1 add - "--delegate" # 2 add3 "--delegate --quantize" # 3 qadd "--delegate --quantize" # 4 qadd2 "--delegate --quantize" # 5 qops @@ -272,6 +282,11 @@ for i in "${!test_model[@]}"; do output_folder=${et_build_root}/${model_short_name} fi + if [ "$perf_overlay" = true ] ; then + model_compiler_flags+="--enable_debug_mode tosa" + fvp_pmu_flag="--trace_file=${output_folder}/pmu_trace.gz" + fi + mkdir -p ${output_folder} output_folder=$(realpath ${output_folder}) pte_file="${output_folder}/${model_filename_ext}" @@ -289,6 +304,12 @@ for i in "${!test_model[@]}"; do pte_file=$(realpath ${pte_file}) + if [ "${etrecord_flag}" != "" ] ; then + etrecord_filename="${output_folder}/${model_filename}_etrecord.bin" + etrecord_filename=$(realpath ${etrecord_filename}) + etrecord_flag="--etrecord=${etrecord_filename}" + fi + [[ -f ${pte_file} ]] || { >&2 echo "Failed to generate a pte file - ${pte_file}"; exit 1; } echo "pte_data_size: $(wc -c ${pte_file})" echo "pte_file: ${pte_file}" @@ -300,7 +321,8 @@ for i in "${!test_model[@]}"; do set -x backends/arm/scripts/build_executor_runner_vkml.sh --build_type=${build_type} \ --extra_build_flags="${extra_build_flags}" \ - --output="${output_folder}" + --output="${output_folder}" \ + ${bundleio_flag} if [ "$build_only" = false ] ; then backends/arm/scripts/run_vkml.sh --model=${pte_file} --build_path=${output_folder} fi @@ -319,13 +341,23 @@ for i in "${!test_model[@]}"; do fi set -x - backends/arm/scripts/build_executor_runner.sh --et_build_root="${et_build_root}" --pte="${pte_file_or_mem}" --build_type=${build_type} --target=${target} --system_config=${system_config} --memory_mode=${memory_mode} ${bundleio_flag} ${et_dump_flag} --extra_build_flags="${extra_build_flags}" --ethosu_tools_dir="${ethos_u_scratch_dir}" 
--toolchain="${toolchain}" --select_ops_list="${select_ops_list}" + backends/arm/scripts/build_executor_runner.sh --et_build_root="${et_build_root}" --pte="${pte_file_or_mem}" --build_type=${build_type} --target=${target} --system_config=${system_config} --memory_mode=${memory_mode} ${bundleio_flag} ${et_dump_flag} --extra_build_flags="${extra_build_flags}" --ethosu_tools_dir="${arm_scratch_dir}" --toolchain="${toolchain}" --select_ops_list="${select_ops_list}" if [ "$build_only" = false ] ; then # Execute the executor_runner on FVP Simulator - backends/arm/scripts/run_fvp.sh --elf=${elf_file} ${model_data} --target=$target + + backends/arm/scripts/run_fvp.sh --elf=${elf_file} ${model_data} --target=$target ${etrecord_flag} ${fvp_pmu_flag} fi set +x fi + + if [ "$model_explorer" = true ]; then + tosa_flatbuffer_path=$(find ${output_folder} -name "*TOSA*.tosa" | head -n 1) + perf_flags="" + if [ "$perf_overlay" = true ]; then + perf_flags+="--trace ${output_folder}/pmu_trace.gz --tables ${output_folder}/output/out_debug.xml" + fi + python3 ${script_dir}/visualize.py --model_path ${tosa_flatbuffer_path} ${perf_flags} + fi done exit 0 diff --git a/examples/arm/run_mcu_models_fvp.sh b/examples/arm/run_mcu_models_fvp.sh deleted file mode 100755 index 68d5ec03003..00000000000 --- a/examples/arm/run_mcu_models_fvp.sh +++ /dev/null @@ -1,308 +0,0 @@ -#!/usr/bin/env bash -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# Copyright 2023-2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# Prerequisite steps: (run the following commands before running this script) -# 1. Setup your environment for Arm FVP -# a. Setup Conda environment / venv -# b. ./install_executorch.sh --clean ; ./install_executorch.sh --editable; -# c. examples/arm/setup.sh --i-agree-to-the-contained-eula; -# d. source examples/arm/ethos-u-scratch/setup_path.sh -# 2. bash examples/selective_build/test_selective_build.sh cmake - -set -u - -# Valid targets for MCU model validation -VALID_TARGETS=( - "cortex-m55" - "cortex-m85" -) - -# Default models for MCU validation with portable kernels -DEFAULT_MODELS=(mv2 mv3 lstm) -# Available models (on FVP) -AVAILABLE_MODELS=(mv2 mv3 lstm) -# Add the following models if you want to enable them later (atm they are not working on FVP) -# edsr w2l ic3 ic4 resnet18 resnet50 - -# Variables -TARGET="" -MODELS=() -PASSED_MODELS=() -FAILED_MODELS=() - -# Function to validate target -validate_target() { - local target=$1 - for valid_target in "${VALID_TARGETS[@]}"; do - if [[ "$target" == "$valid_target" ]]; then - return 0 - fi - done - return 1 -} - -# Function to validate models -validate_models() { - local invalid_models=() - for model in "${MODELS[@]}"; do - if [[ ! 
" ${AVAILABLE_MODELS[*]} " =~ " $model " ]]; then - invalid_models+=("$model") - fi - done - - if [[ ${#invalid_models[@]} -gt 0 ]]; then - echo "❌ Error: Invalid model(s): ${invalid_models[*]}" - echo "Available models: ${AVAILABLE_MODELS[*]}" - return 1 - fi - return 0 -} - -cpu_to_ethos_target() { - local cpu=$1 - case $cpu in - cortex-m55) - echo "ethos-u55-128" - ;; - cortex-m85) - echo "ethos-u85-128" - ;; - *) - echo "Unknown CPU: $cpu" >&2 - return 1 - ;; - esac -} - -# Function to show usage -show_usage() { - echo "Usage: $0 --target= [--models=]" - echo "" - echo "MCU Model Validation without delegation" - echo "" - echo "Required arguments:" - echo " --target= Target platform for validation" - echo "" - echo "Optional arguments:" - echo " --models= Comma-separated list of models to test" - echo " (overrides default model list)" - echo "" - echo "Valid targets:" - printf ' %s\n' "${VALID_TARGETS[@]}" - echo "" - echo "Available models:" - printf ' %s\n' "${AVAILABLE_MODELS[@]}" - echo "" - echo "Examples:" - echo " $0 --target=ethos-u85-128" - echo " $0 --target=ethos-u55-128 --models=mv2,mv3,resnet18" - echo "" - echo "Default behavior:" - echo " - Uses all available models: ${DEFAULT_MODELS[*]}" - echo " - Runs with portable kernels (no delegation)" -} - -# Function to display summary -show_summary() { - local total_models=${#MODELS[@]} - - echo "" - echo "════════════════════════════════════════════════════════════════" - echo "🏁 MCU MODEL VALIDATION SUMMARY - TARGET: $TARGET" - echo "════════════════════════════════════════════════════════════════" - echo "" - - # Show individual results - for model in "${MODELS[@]}"; do - if [[ " ${PASSED_MODELS[*]} " =~ " $model " ]]; then - printf "%-12s : ✅ Passed\n" "$model" - elif [[ " ${FAILED_MODELS[*]} " =~ " $model " ]]; then - printf "%-12s : ❌ Failed\n" "$model" - else - printf "%-12s : ⏭️ Skipped\n" "$model" - fi - done - - echo "" - echo "────────────────────────────────────────────────────────────────" - - # Show statistics - local passed_count=${#PASSED_MODELS[@]} - local failed_count=${#FAILED_MODELS[@]} - local success_rate=$((passed_count * 100 / total_models)) - - echo "📊 STATISTICS:" - echo " Total Models : $total_models" - echo " ✅ Passed : $passed_count" - echo " ❌ Failed : $failed_count" - echo " 📈 Success Rate : $success_rate%" - echo "" - - # Show model selection info - if [[ ${#MODELS[@]} -eq ${#DEFAULT_MODELS[@]} ]] && [[ "${MODELS[*]}" == "${DEFAULT_MODELS[*]}" ]]; then - echo "📋 Model Selection: Default (all available models)" - else - echo "📋 Model Selection: Custom (${MODELS[*]})" - fi - echo "" - - # Overall result - if [[ $failed_count -eq 0 ]]; then - echo "🎉 OVERALL RESULT: ALL TESTS PASSED!" - echo "🔧 Mode: Portable Kernels (No Delegation)" - else - echo "⚠️ OVERALL RESULT: $failed_count/$total_models TESTS FAILED" - echo "🔧 Mode: Portable Kernels (No Delegation)" - echo "" - echo "🔍 Failed models: ${FAILED_MODELS[*]}" - fi - - echo "════════════════════════════════════════════════════════════════" - echo "" -} - -# Parse command line arguments -while [[ $# -gt 0 ]]; do - case $1 in - --target=*) - TARGET="${1#*=}" - shift - ;; - --models=*) - IFS=',' read -ra MODELS <<< "${1#*=}" - shift - ;; - -h|--help) - show_usage - exit 0 - ;; - *) - echo "❌ Error: Unknown argument '$1'" - echo "" - show_usage - exit 1 - ;; - esac -done - -# Check if target is provided -if [[ -z "$TARGET" ]]; then - echo "❌ Error: --target argument is required" - echo "" - show_usage - exit 1 -fi - -# Validate target -if ! 
validate_target "$TARGET"; then - echo "❌ Error: Invalid target '$TARGET'" - echo "" - show_usage - exit 1 -fi - -# Use default models if none specified -if [[ ${#MODELS[@]} -eq 0 ]]; then - MODELS=("${DEFAULT_MODELS[@]}") -fi - -# Validate models -if ! validate_models; then - exit 1 -fi - -# Remove duplicates from models array -IFS=" " read -r -a MODELS <<< "$(printf '%s\n' "${MODELS[@]}" | sort -u | tr '\n' ' ')" - -echo "🎯 MCU Model Validation - Target: $TARGET" -echo "📋 Processing models: ${MODELS[*]}" -echo "🔧 Mode: Portable Kernels (No Delegation)" -echo "" - -echo "🔨 Building ExecuteTorch libraries (one-time setup)..." -if ! backends/arm/scripts/build_executorch.sh; then - echo "❌ Failed to build ExecuteTorch libraries" - exit 1 -fi -echo "✅ ExecuteTorch libraries built successfully" -echo "" - -ETHOS_TARGET=$(cpu_to_ethos_target "$TARGET") -if [[ $? -ne 0 ]]; then - echo "Invalid CPU target: $TARGET" - exit 1 -fi -echo "Using ETHOS target: $ETHOS_TARGET" - -# Process each model -for model in "${MODELS[@]}"; do - echo "=== 🚀 Processing $model for $TARGET ===" - - # Track if this model succeeds - MODEL_SUCCESS=true - - # Step 1: Create directory - echo "📁 Creating directory arm_test/$model" - mkdir -p "arm_test/$model" - - # Step 2: AOT compilation (quantized, no delegation = portable kernels) - echo "⚙️ AOT compilation for $model" - if ! python3 -m examples.arm.aot_arm_compiler \ - -m "$model" \ - --target="$ETHOS_TARGET" \ - --quantize \ - --output="arm_test/$model"; then - echo "❌ AOT compilation failed for $model" - MODEL_SUCCESS=false - fi - - # Step 3: Build executor runner (only if AOT succeeded) - if [[ "$MODEL_SUCCESS" == true ]]; then - echo "🔨 Building executor runner for $model" - if ! backends/arm/scripts/build_executor_runner.sh \ - --pte="arm_test/$model/${model}_arm_${ETHOS_TARGET}.pte" \ - --target="$ETHOS_TARGET" \ - --output="arm_test/$model"; then - echo "❌ Executor runner build failed for $model" - MODEL_SUCCESS=false - fi - fi - - # Step 4: Run on FVP (only if build succeeded) - if [[ "$MODEL_SUCCESS" == true ]]; then - echo "🏃 Running $model on FVP with portable kernels" - if ! backends/arm/scripts/run_fvp.sh \ - --elf="arm_test/$model/arm_executor_runner" \ - --target="$ETHOS_TARGET"; then - echo "❌ FVP execution failed for $model" - MODEL_SUCCESS=false - fi - fi - - # Record result - if [[ "$MODEL_SUCCESS" == true ]]; then - echo "✅ $model completed successfully" - PASSED_MODELS+=("$model") - else - echo "❌ $model failed" - FAILED_MODELS+=("$model") - fi - - echo "" -done - -# Show comprehensive summary -show_summary - -# Exit with appropriate code for CI -if [[ ${#FAILED_MODELS[@]} -eq 0 ]]; then - exit 0 # Success -else - exit 1 # Failure -fi diff --git a/examples/arm/setup.sh b/examples/arm/setup.sh index 2aa6590c64d..b36dd0f5c04 100755 --- a/examples/arm/setup.sh +++ b/examples/arm/setup.sh @@ -16,7 +16,7 @@ script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) et_dir=$(realpath $script_dir/../..) 
ARCH="$(uname -m)" OS="$(uname -s)" -root_dir="${script_dir}/ethos-u-scratch" # TODO: rename +root_dir="${script_dir}/arm-scratch" eula_acceptance=0 enable_baremetal_toolchain=1 target_toolchain="" @@ -26,7 +26,7 @@ enable_model_converter=0 # model-converter tool for VGF output enable_vgf_lib=0 # vgf reader - runtime backend dependency enable_emulation_layer=0 # Vulkan layer driver - emulates Vulkan ML extensions enable_vulkan_sdk=0 # Download and export Vulkan SDK required by emulation layer -mlsdk_manifest_url="https://github.com/arm/ai-ml-sdk-manifest.git" +enable_mlsdk_pip_install=0 # This is a temporary option that will soon be the default # Figure out if setup.sh was called or sourced and save it into "is_script_sourced" (return 0 2>/dev/null) && is_script_sourced=1 || is_script_sourced=0 @@ -36,6 +36,9 @@ toolchain_url="" toolchain_dir="" toolchain_md5_checksum="" +# Load logging helpers early so option parsing can emit status messages. +source "$et_dir/backends/arm/scripts/utils.sh" + # List of supported options and their descriptions OPTION_LIST=( @@ -49,6 +52,7 @@ OPTION_LIST=( "--enable-emulation-layer Enable MLSDK Vulkan emulation layer" "--disable-ethos-u-deps Do not setup what is needed for Ethos-U" "--enable-mlsdk-deps Setup what is needed for MLSDK" + "--install-mlsdk-deps-with-pip Use MLSDK PyPi package instead of building from source" "--mlsdk-manifest-url URL to the MLSDK manifest for vulkan." "--help Display help" ) @@ -138,6 +142,10 @@ function check_options() { enable_vela=0 shift ;; + --install-mlsdk-deps-with-pip) + enable_mlsdk_pip_install=1 + shift + ;; --enable-mlsdk-deps) enable_model_converter=1 enable_vgf_lib=1 @@ -145,19 +153,8 @@ function check_options() { enable_vulkan_sdk=1 shift ;; - --mlsdk-manifest-url) - # Ensure that there is a url provided. - if [[ -n "$2" && "${2:0:1}" != "-" ]]; then - mlsdk_manifest_url="$2" - shift 2 - else - echo "Error: --mlsdk-manifest-url requires a URL argument." - print_usage "$@" - exit 1 - fi - ;; --setup-test-dependency) - echo "Installing test dependency..." + log_step "deps" "Installing test dependency..." 
source $et_dir/backends/arm/scripts/install_models_for_test.sh exit 0 ;; @@ -174,19 +171,32 @@ function check_options() { } function setup_root_dir() { - mkdir -p ${root_dir} - root_dir=$(realpath ${root_dir}) + mkdir -p "${root_dir}" + root_dir=$(realpath "${root_dir}") + log_step "main" "Prepared root dir at ${root_dir}" setup_path_script="${root_dir}/setup_path" } function setup_ethos_u_tools() { + log_step "ethos-u-tools" "Installing Ethos-U Python tooling" CMAKE_POLICY_VERSION_MINIMUM=3.5 BUILD_PYBIND=1 pip install --no-dependencies -r $et_dir/backends/arm/requirements-arm-ethos-u.txt } +function setup_mlsdk_dependencies() { + log_step "mlsdk" "Installing MLSDK dependencies from pip" + pip install -r $et_dir/backends/arm/requirements-arm-vgf.txt +} + function create_setup_path(){ cd "${root_dir}" clear_setup_path + log_step "path" "Generating setup path scripts at ${setup_path_script}" + + local use_mlsdk_pip=0 + if use_mlsdk_pip_package; then + use_mlsdk_pip=1 + fi if [[ "${enable_fvps}" -eq 1 ]]; then setup_path_fvp @@ -200,20 +210,48 @@ function create_setup_path(){ setup_path_vulkan fi - if [[ "${enable_model_converter}" -eq 1 ]]; then + if [[ "${enable_model_converter}" -eq 1 && "${use_mlsdk_pip}" -eq 0 ]]; then setup_path_model_converter fi - if [[ "${enable_vgf_lib}" -eq 1 ]]; then + if [[ "${enable_vgf_lib}" -eq 1 && "${use_mlsdk_pip}" -eq 0 ]]; then setup_path_vgf_lib fi if [[ "${enable_emulation_layer}" -eq 1 ]]; then - setup_path_emulation_layer + if [[ "${use_mlsdk_pip}" -eq 0 ]]; then + setup_path_emulation_layer + else + setup_path_emulation_layer_from_pip + fi fi - echo "[main] Update path by running 'source ${setup_path_script}.sh'" - echo "[main] Or for fish shell use 'source ${setup_path_script}.fish'" + log_step "path" "Update PATH by sourcing ${setup_path_script}.{sh|fish}" +} + +function use_mlsdk_pip_package() { + os=$(uname -s) + arch=$(uname -m) + + if [[ "${enable_mlsdk_pip_install}" -eq 0 ]]; then + return 1 + fi + + if [[ "$os" == "Darwin" ]]; then + if [[ "${enable_mlsdk_pip_install}" -eq 1 ]]; then + log_step "mlsdk" "[error] MLSDK pip install not yet supported on MacOS" + exit 1 + fi + fi + + if [[ "$arch" == "arm64" || "$arch" == "aarch64" ]]; then + if [[ "${enable_mlsdk_pip_install}" -eq 1 ]]; then + log_step "mlsdk" "[error] MLSDK pip install not yet supported on aarch64" + exit 1 + fi + fi + + return 0 } @@ -228,12 +266,12 @@ if [[ $is_script_sourced -eq 0 ]]; then check_options "$@" # Import utils - source $et_dir/backends/arm/scripts/utils.sh source $et_dir/backends/arm/scripts/fvp_utils.sh source $et_dir/backends/arm/scripts/toolchain_utils.sh source $et_dir/backends/arm/scripts/vulkan_utils.sh + source $et_dir/backends/arm/scripts/mlsdk_utils.sh - echo "[main]: Checking platform and os" + log_step "main" "Checking platform and OS" check_platform_support check_os_support @@ -242,20 +280,21 @@ if [[ $is_script_sourced -eq 0 ]]; then # Setup the root dir setup_root_dir cd "${root_dir}" - echo "[main] Using root dir ${root_dir} and options:" - echo "enable-fvps=${enable_fvps}" - echo "target-toolchain=${target_toolchain}" - echo "enable-baremetal-toolchain=${enable_baremetal_toolchain}" - echo "enable-model-converter=${enable_model_converter}" - echo "enable-vgf-lib=${enable_vgf_lib}" - echo "enable-emulation-layer=${enable_emulation_layer}" - echo "enable-vulkan-sdk=${enable_vulkan_sdk}" - echo "enable-vela=${enable_vela}" - echo "mlsdk-manifest-url=${mlsdk_manifest_url}" + if [[ "${mlsdk_manifest_dir}" != /* ]]; then + 
mlsdk_manifest_dir="${root_dir}/${mlsdk_manifest_dir}" + fi + + log_step "options" \ + "root=${root_dir}, target-toolchain=${target_toolchain:-}, mlsdk-dir=${mlsdk_manifest_dir}" + log_step "options" \ + "ethos-u: fvps=${enable_fvps}, toolchain=${enable_baremetal_toolchain}, vela=${enable_vela} | " \ + "mlsdk: model-converter=${enable_model_converter}, vgf-lib=${enable_vgf_lib}, " \ + "emu-layer=${enable_emulation_layer}, vulkan-sdk=${enable_vulkan_sdk}" # Setup toolchain if [[ "${enable_baremetal_toolchain}" -eq 1 ]]; then + log_step "toolchain" "Configuring baremetal toolchain (${target_toolchain:-gnu})" # Select appropriate toolchain select_toolchain setup_toolchain @@ -263,6 +302,7 @@ if [[ $is_script_sourced -eq 0 ]]; then # Setup FVP if [[ "${enable_fvps}" -eq 1 ]]; then + log_step "fvp" "Setting up Arm Fixed Virtual Platforms" check_fvp_eula setup_fvp install_fvp @@ -270,26 +310,71 @@ if [[ $is_script_sourced -eq 0 ]]; then # Setup Vulkan SDK if [[ "${enable_vulkan_sdk}" -eq 1 ]]; then + log_step "vulkan" "Setting up Vulkan SDK" setup_vulkan_sdk fi if [[ "${enable_model_converter}" -eq 1 || \ "${enable_vgf_lib}" -eq 1 || \ "${enable_emulation_layer}" -eq 1 ]]; then - source $et_dir/backends/arm/scripts/mlsdk_utils.sh -u "${mlsdk_manifest_url}" - setup_model_converter ${root_dir} ${mlsdk_manifest_dir} ${enable_model_converter} ${enable_vgf_lib} ${enable_emulation_layer} + log_step "mlsdk" "Configuring MLSDK components (model-converter=${enable_model_converter}, " \ + "vgf-lib=${enable_vgf_lib}, emu-layer=${enable_emulation_layer})" + if use_mlsdk_pip_package; then + setup_mlsdk_dependencies + else + log_step "mlsdk" "Installing MLSDK dependencies from source" + setup_mlsdk ${root_dir} \ + ${mlsdk_manifest_dir} \ + ${enable_model_converter} \ + ${enable_vgf_lib} \ + ${enable_emulation_layer} + fi fi # Create the setup_path.sh used to create the PATH variable for shell create_setup_path - # Setup the tosa_reference_model and dependencies - CMAKE_POLICY_VERSION_MINIMUM=3.5 BUILD_PYBIND=1 pip install --no-dependencies -r $et_dir/backends/arm/requirements-arm-tosa.txt + # Setup the TOSA reference model and serialization dependencies + log_step "deps" "Installing TOSA reference model dependencies" + CMAKE_POLICY_VERSION_MINIMUM=3.5 \ + pip install --no-dependencies -r "$et_dir/backends/arm/requirements-arm-tosa.txt" + + pushd "$root_dir" + if [[ ! -d "tosa-tools" ]]; then + git clone https://git.gitlab.arm.com/tosa/tosa-tools.git + fi + + pushd tosa-tools + git fetch origin main + git checkout 8468d041c50c6d806f3c1c18c66d7ef641e46580 # serialization lib pybindings + git cherry-pick 368f0cd745b2a1569bf36f077daeba95775de192 # perf fix for >2gb models + if [[ ! -d "reference_model" ]]; then + log_step "main" "[error] Missing reference_model directory in tosa-tools repo." + exit 1 + fi + if [[ ! -d "serialization" ]]; then + log_step "main" "[error] Missing serialization directory in tosa-tools repo." + exit 1 + fi + + + export CMAKE_BUILD_PARALLEL_LEVEL="$(get_parallel_jobs)" + + CMAKE_POLICY_VERSION_MINIMUM=3.5 \ + BUILD_PYBIND=1 \ + pip install --no-dependencies ./reference_model + + CMAKE_POLICY_VERSION_MINIMUM=3.5 \ + BUILD_PYBIND=1 \ + pip install --no-dependencies ./serialization + popd + popd if [[ "${enable_vela}" -eq 1 ]]; then + log_step "deps" "Installing Ethos-U Vela compiler" setup_ethos_u_tools fi - echo "[main] success!" 
+ log_step "main" "Setup complete" exit 0 fi diff --git a/examples/arm/ubsan/CMakeLists.txt b/examples/arm/ubsan/CMakeLists.txt new file mode 100644 index 00000000000..8d5d23211b1 --- /dev/null +++ b/examples/arm/ubsan/CMakeLists.txt @@ -0,0 +1,18 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +add_library(executorch_ubsan STATIC ubsan_runtime.c) + +target_compile_features(executorch_ubsan PRIVATE c_std_11) + +target_compile_options(executorch_ubsan PRIVATE -fno-sanitize=undefined) + +set_target_properties(executorch_ubsan PROPERTIES OUTPUT_NAME "ubsan") + +install( + TARGETS executorch_ubsan + EXPORT ExecuTorchTargets + ARCHIVE DESTINATION lib +) diff --git a/examples/arm/ubsan/ubsan_runtime.c b/examples/arm/ubsan/ubsan_runtime.c new file mode 100644 index 00000000000..62f411073ba --- /dev/null +++ b/examples/arm/ubsan/ubsan_runtime.c @@ -0,0 +1,488 @@ +/* Copyright 2025 Arm Limited and/or its affiliates. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#ifndef UBSAN_RUNTIME_PREFIX +#define UBSAN_RUNTIME_PREFIX "[UBSAN] " +#endif + +typedef struct { + const char* filename; + uint32_t line; + uint32_t column; +} __ubsan_source_location; + +typedef struct { + uint16_t type_kind; + uint16_t type_info; + char type_name[]; +} __ubsan_type_descriptor; + +typedef struct { + __ubsan_source_location location; + const __ubsan_type_descriptor* type; +} __ubsan_overflow_data; + +typedef struct { + __ubsan_source_location location; + const __ubsan_type_descriptor* lhs_type; + const __ubsan_type_descriptor* rhs_type; +} __ubsan_shift_out_of_bounds_data; + +typedef struct { + __ubsan_source_location location; + const __ubsan_type_descriptor* array_type; + const __ubsan_type_descriptor* index_type; +} __ubsan_out_of_bounds_data; + +typedef struct { + __ubsan_source_location location; + const __ubsan_type_descriptor* type; + uint8_t log_alignment; + uint8_t type_check_kind; +} __ubsan_type_mismatch_data_v1; + +typedef struct { + __ubsan_source_location location; + const __ubsan_type_descriptor* type; +} __ubsan_vla_bound_data; + +typedef struct { + __ubsan_source_location location; + __ubsan_source_location attr_location; +} __ubsan_nonnull_return_data_v1; + +typedef struct { + __ubsan_source_location location; + __ubsan_source_location attr_location; + uint8_t arg_index; +} __ubsan_nullability_arg_data; + +typedef struct { + __ubsan_source_location location; + const __ubsan_type_descriptor* from_type; + const __ubsan_type_descriptor* to_type; +} __ubsan_float_cast_overflow_data; + +typedef struct { + __ubsan_source_location location; + const __ubsan_type_descriptor* type; +} __ubsan_invalid_value_data; + +typedef struct { + __ubsan_source_location location; + __ubsan_source_location attr_location; + uint32_t arg_index; +} __ubsan_nonnull_arg_data; + +typedef struct { + __ubsan_source_location location; +} __ubsan_pointer_overflow_data; + +typedef struct { + __ubsan_source_location location; + __ubsan_source_location assumption_location; + uint64_t alignment; + uint8_t type_check_kind; +} __ubsan_alignment_assumption_data; + +static const char* ubsan_get_type_name(const __ubsan_type_descriptor* type) { + if (!type) { + return ""; + } + return type->type_name; +} + +static const char* ubsan_type_check_kind_string(uint8_t kind) 
{ + switch (kind) { + case 0: + return "load of"; + case 1: + return "store to"; + case 2: + return "reference binding to"; + case 3: + return "member access within"; + case 4: + return "member call on"; + case 5: + return "constructor call for"; + case 6: + return "downcast of"; + case 7: + return "downcast of"; + case 8: + return "upcast of"; + case 9: + return "cast to virtual base of"; + default: + return "use of"; + } +} + +static uintptr_t ubsan_ptr_value(const void* ptr) { + return (uintptr_t)ptr; +} + +static void ubsan_abort(void) { +#if defined(__GNUC__) + __builtin_trap(); +#else + abort(); +#endif + while (1) { + } +} + +static void ubsan_print_location(const __ubsan_source_location* loc) { + if (!loc || !loc->filename) { + printf(UBSAN_RUNTIME_PREFIX "unknown location: "); + return; + } + printf(UBSAN_RUNTIME_PREFIX "%s:%u:%u: ", loc->filename, loc->line, + loc->column); +} + +static void ubsan_report_with_message(const __ubsan_source_location* loc, + const char* message) { + ubsan_print_location(loc); + printf("%s\n", message); + fflush(stdout); + ubsan_abort(); +} + +static void ubsan_report_overflow(const __ubsan_overflow_data* data, + const char* op, + uintptr_t lhs, + uintptr_t rhs) { + const char* type_name = ubsan_get_type_name(data->type); + char message[256]; + snprintf( + message, + sizeof(message), + "%s on type '%s' (lhs=0x%08" PRIxPTR ", rhs=0x%08" PRIxPTR ")", + op, + type_name, + lhs, + rhs); + ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_add_overflow(__ubsan_overflow_data* data, void* lhs, + void* rhs) { + ubsan_report_overflow( + data, + "addition overflow", + ubsan_ptr_value(lhs), + ubsan_ptr_value(rhs)); +} + +void __ubsan_handle_sub_overflow(__ubsan_overflow_data* data, void* lhs, + void* rhs) { + ubsan_report_overflow( + data, + "subtraction overflow", + ubsan_ptr_value(lhs), + ubsan_ptr_value(rhs)); +} + +void __ubsan_handle_mul_overflow(__ubsan_overflow_data* data, void* lhs, + void* rhs) { + ubsan_report_overflow( + data, + "multiplication overflow", + ubsan_ptr_value(lhs), + ubsan_ptr_value(rhs)); +} + +void __ubsan_handle_negate_overflow(__ubsan_overflow_data* data, void* value) { + ubsan_report_overflow( + data, + "negation overflow", + ubsan_ptr_value(value), + 0); +} + +void __ubsan_handle_divrem_overflow(__ubsan_overflow_data* data, void* lhs, + void* rhs) { + ubsan_report_overflow( + data, + "division remainder overflow", + ubsan_ptr_value(lhs), + ubsan_ptr_value(rhs)); +} + +void __ubsan_handle_shift_out_of_bounds(__ubsan_shift_out_of_bounds_data* data, + void* lhs, void* rhs) { + const char* lhs_type = ubsan_get_type_name(data->lhs_type); + const char* rhs_type = ubsan_get_type_name(data->rhs_type); + uintptr_t lhs_val = ubsan_ptr_value(lhs); + uintptr_t rhs_val = ubsan_ptr_value(rhs); + char message[256]; + snprintf( + message, + sizeof(message), + "shift out of bounds (lhs=0x%08" PRIxPTR " of type '%s', rhs=0x%08" PRIxPTR + " of type '%s')", + lhs_val, + lhs_type, + rhs_val, + rhs_type); + ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_out_of_bounds(__ubsan_out_of_bounds_data* data, + void* index) { + uintptr_t idx_val = ubsan_ptr_value(index); + const char* idx_type = ubsan_get_type_name(data->index_type); + const char* array_type = ubsan_get_type_name(data->array_type); + char message[256]; + snprintf( + message, + sizeof(message), + "index out of bounds (index=0x%08" PRIxPTR " of type '%s' on array '%s')", + idx_val, + idx_type, + array_type); + 
ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_type_mismatch_v1(__ubsan_type_mismatch_data_v1* data, + void* ptr) { + uintptr_t address = (uintptr_t)ptr; + size_t alignment = + (data->log_alignment < (sizeof(size_t) * 8)) + ? ((size_t)1 << data->log_alignment) + : 0; + const char* type_name = ubsan_get_type_name(data->type); + const char* check_desc = ubsan_type_check_kind_string(data->type_check_kind); + + char message[256]; + if (address == 0) { + snprintf( + message, + sizeof(message), + "%s null pointer of type '%s'", + check_desc, + type_name); + } else if (alignment && (address & (alignment - 1))) { + snprintf( + message, + sizeof(message), + "%s misaligned address 0x%08" PRIxPTR " for type '%s' (alignment %zu)", + check_desc, + address, + type_name, + alignment); + } else { + snprintf( + message, + sizeof(message), + "%s address 0x%08" PRIxPTR " with insufficient alignment for type '%s'", + check_desc, + address, + type_name); + } + + ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_vla_bound_not_positive(__ubsan_vla_bound_data* data, + void* bound) { + uintptr_t bound_val = ubsan_ptr_value(bound); + char message[256]; + snprintf( + message, + sizeof(message), + "variable length array bound (%" PRIuPTR ") is not positive", + (uintptr_t)bound_val); + ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_load_invalid_value(__ubsan_invalid_value_data* data, + void* pointer) { + uintptr_t addr = ubsan_ptr_value(pointer); + const char* type_name = ubsan_get_type_name(data->type); + char message[256]; + snprintf( + message, + sizeof(message), + "load of invalid value at 0x%08" PRIxPTR " for type '%s'", + addr, + type_name); + ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_nonnull_return_v1(__ubsan_nonnull_return_data_v1* data, + __ubsan_source_location* where) { + (void)where; // Some toolchains leave this null; attr_location is reliable. + char message[256]; + if (data->attr_location.filename) { + snprintf( + message, + sizeof(message), + "null pointer returned from function marked 'returns_nonnull' " + "(attribute at %s:%u:%u)", + data->attr_location.filename, + data->attr_location.line, + data->attr_location.column); + } else { + snprintf( + message, + sizeof(message), + "null pointer returned from function marked 'returns_nonnull'"); + } + ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_nullability_return_v1( + __ubsan_nonnull_return_data_v1* data, __ubsan_source_location* where) { + (void)where; // Some toolchains leave this null; attr_location is reliable. + char message[256]; + snprintf( + message, + sizeof(message), + "null returned from non-null return (attribute at %s:%u:%u)", + data->attr_location.filename ? data->attr_location.filename : "", + data->attr_location.line, + data->attr_location.column); + ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_nullability_arg_v1(__ubsan_nullability_arg_data* data, + __ubsan_source_location* where) { + (void)where; // Some toolchains leave this null; attr_location is reliable. + char message[256]; + snprintf( + message, + sizeof(message), + "null passed to non-null argument #%u (attribute at %s:%u:%u)", + data->arg_index, + data->attr_location.filename ? 
data->attr_location.filename : "", + data->attr_location.line, + data->attr_location.column); + ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_nonnull_arg(__ubsan_nonnull_arg_data* data) { + char message[256]; + snprintf( + message, + sizeof(message), + "null pointer passed to argument marked 'nonnull' (argument #%u, attribute at %s:%u:%u)", + data->arg_index, + data->attr_location.filename ? data->attr_location.filename : "", + data->attr_location.line, + data->attr_location.column); + ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_float_cast_overflow( + __ubsan_float_cast_overflow_data* data, void* from) { + uintptr_t raw = ubsan_ptr_value(from); + const char* from_type = ubsan_get_type_name(data->from_type); + const char* to_type = ubsan_get_type_name(data->to_type); + char message[256]; + snprintf( + message, + sizeof(message), + "floating point cast overflow (value bits=0x%08" PRIxPTR + ", from '%s' to '%s')", + raw, + from_type, + to_type); + ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_pointer_overflow(__ubsan_pointer_overflow_data* data, + void* base, void* result) { + uintptr_t base_val = ubsan_ptr_value(base); + uintptr_t result_val = ubsan_ptr_value(result); + char message[256]; + snprintf( + message, + sizeof(message), + "pointer overflow (base=0x%08" PRIxPTR ", result=0x%08" PRIxPTR ")", + base_val, + result_val); + ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_alignment_assumption( + __ubsan_alignment_assumption_data* data, void* pointer, + void* alignment, void* offset) { + uintptr_t ptr_val = ubsan_ptr_value(pointer); + uintptr_t align_val = ubsan_ptr_value(alignment); + uintptr_t offset_val = ubsan_ptr_value(offset); + char message[256]; + snprintf( + message, + sizeof(message), + "alignment assumption violated (ptr=0x%08" PRIxPTR ", alignment=%" PRIuPTR + ", offset=%" PRIuPTR ", required alignment=%" PRIu64 ")", + ptr_val, + align_val, + offset_val, + (unsigned long long)data->alignment); + ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_builtin_unreachable(__ubsan_source_location* location) { + ubsan_report_with_message(location, "execution reached an unreachable point"); +} + +void __ubsan_handle_missing_return(__ubsan_source_location* location) { + ubsan_report_with_message(location, + "control reached end of void function without " + "returning"); +} + +void __ubsan_handle_invalid_builtin(__ubsan_source_location* location) { + ubsan_report_with_message(location, "invalid builtin usage"); +} + +void __ubsan_handle_cfi_check_fail(__ubsan_source_location* location, + void* data, void* vtable) { + uintptr_t type_hash = ubsan_ptr_value(data); + uintptr_t vtable_ptr = ubsan_ptr_value(vtable); + char message[256]; + snprintf( + message, + sizeof(message), + "control-flow integrity check failed (type hash=0x%08" PRIxPTR + ", vtable=0x%08" PRIxPTR ")", + type_hash, + vtable_ptr); + ubsan_report_with_message(location, message); +} + +void __ubsan_handle_cfi_check_fail_abort(__ubsan_source_location* location, + void* data, void* vtable) { + __ubsan_handle_cfi_check_fail(location, data, vtable); +} + +void __ubsan_handle_dynamic_type_cache_miss(void* data, void* ptr) { + uintptr_t type_hash = ubsan_ptr_value(data); + uintptr_t object_ptr = ubsan_ptr_value(ptr); + printf( + UBSAN_RUNTIME_PREFIX + "dynamic type cache miss (type hash=0x%08" PRIxPTR ", object=0x%08" PRIxPTR + ")\n", + type_hash, + object_ptr); + 
fflush(stdout); + ubsan_abort(); +} + +void __ubsan_on_error(void) { + printf(UBSAN_RUNTIME_PREFIX "runtime error detected\n"); + fflush(stdout); + ubsan_abort(); +} diff --git a/examples/arm/vgf_minimal_example.ipynb b/examples/arm/vgf_minimal_example.ipynb index 35378817a7d..f01dfd8d977 100644 --- a/examples/arm/vgf_minimal_example.ipynb +++ b/examples/arm/vgf_minimal_example.ipynb @@ -24,7 +24,7 @@ "Before you begin:\n", "1. (In a clean virtual environment with a compatible Python version) Install executorch using `./install_executorch.sh`\n", "2. Install MLSDK and Tosa using `examples/arm/setup.sh --disable-ethos-u-deps --enable-mlsdk-deps` (For further guidance, refer to https://docs.pytorch.org/executorch/main/tutorial-arm.html)\n", - "3. Export vulkan environment variables and add MLSDK components to PATH and LD_LIBRARY_PATH using `examples/arm/ethos-u-scratch/setup_path.sh`\n", + "3. Export vulkan environment variables and add MLSDK components to PATH and LD_LIBRARY_PATH using `examples/arm/arm-scratch/setup_path.sh`\n", "\n", "With all commands executed from the base `executorch` folder.\n", "\n", @@ -56,8 +56,8 @@ "\n", "model = Add()\n", "model = model.eval()\n", - "exported_program = torch.export.export_for_training(model, example_inputs)\n", - "graph_module = exported_program.module()\n", + "exported_program = torch.export.export(model, example_inputs)\n", + "graph_module = exported_program.graph_module\n", "\n", "_ = graph_module.print_readable()" ] @@ -82,21 +82,15 @@ "metadata": {}, "outputs": [], "source": [ - "from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder\n", - "from executorch.backends.arm.tosa import ( \n", - " TosaSpecification,\n", - ")\n", + "from executorch.backends.arm.vgf import VgfCompileSpec\n", "\n", "# Create a compilation spec describing the floating point target.\n", - "tosa_spec = TosaSpecification.create_from_string(\"TOSA-1.0+FP\")\n", - "\n", - "spec_builder = ArmCompileSpecBuilder().vgf_compile_spec(tosa_spec)\n", - "compile_spec = spec_builder.build()\n", + "compile_spec = VgfCompileSpec(\"TOSA-1.0+FP\")\n", "\n", "_ = graph_module.print_readable()\n", "\n", "# Create a new exported program using the graph_module\n", - "exported_program = torch.export.export_for_training(graph_module, example_inputs)" + "exported_program = torch.export.export(graph_module, example_inputs)" ] }, { @@ -122,13 +116,11 @@ " VgfQuantizer,\n", " get_symmetric_quantization_config,\n", ")\n", + "from executorch.backends.arm.vgf import VgfCompileSpec\n", "from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e\n", "\n", "# Create a compilation spec describing the target for configuring the quantizer\n", - "tosa_spec = TosaSpecification.create_from_string(\"TOSA-1.0+INT\")\n", - "\n", - "spec_builder = ArmCompileSpecBuilder().vgf_compile_spec(tosa_spec)\n", - "compile_spec = spec_builder.build()\n", + "compile_spec = VgfCompileSpec(\"TOSA-1.0+INT\")\n", "\n", "# Create and configure quantizer to use a symmetric quantization config globally on all nodes\n", "quantizer = VgfQuantizer(compile_spec)\n", @@ -143,7 +135,7 @@ "_ = quantized_graph_module.print_readable()\n", "\n", "# Create a new exported program using the quantized_graph_module\n", - "quantized_exported_program = torch.export.export_for_training(quantized_graph_module, example_inputs)" + "quantized_exported_program = torch.export.export(quantized_graph_module, example_inputs)" ] }, { @@ -171,7 +163,7 @@ "source": [ "%%bash\n", "# Ensure the vulkan environment variables and 
MLSDK components are available on $PATH\n", - "source ethos-u-scratch/setup_path.sh" + "source arm-scratch/setup_path.sh" ] }, { @@ -206,7 +198,7 @@ " config=ExecutorchBackendConfig(extract_delegate_segments=False)\n", ")\n", "\n", - "executorch_program_manager.exported_program().module().print_readable()\n", + "executorch_program_manager.exported_program().graph_module.print_readable()\n", "\n", "# Save pte file\n", "cwd_dir = os.getcwd()\n", @@ -240,7 +232,7 @@ "source": [ "%%bash\n", "# Ensure the vulkan environment variables and MLSDK components are available on $PATH\n", - "source ethos-u-scratch/setup_path.sh\n", + "source arm-scratch/setup_path.sh\n", "\n", "# Compiled programs will appear in the executorch/cmake-out directory we create here.\n", "# Build example executor runner application to examples/arm/vgf_minimal_example\n", @@ -249,6 +241,7 @@ " -DCMAKE_BUILD_TYPE=Debug \\\n", " -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \\\n", " -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \\\n", + " -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \\\n", " -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \\\n", " -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \\\n", " -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \\\n", diff --git a/examples/arm/visualize.py b/examples/arm/visualize.py new file mode 100644 index 00000000000..a176931e44f --- /dev/null +++ b/examples/arm/visualize.py @@ -0,0 +1,294 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import gzip +import io +import json +import xml.etree.ElementTree as ET # nosec B405 +from pathlib import Path + +from typing import Any, Callable, Dict, Iterable, NamedTuple, Union + +import pandas as pd + +from executorch.devtools.visualization.visualization_utils import ( + visualize_model_explorer, +) +from model_explorer import config as model_explorer_config, node_data_builder as ndb + +COMPILER_OP_ID = "scheduled_id" + + +class Tables(NamedTuple): + queue: pd.DataFrame + group: pd.DataFrame + perf: pd.DataFrame + source: pd.DataFrame + + +def parse_tables(tables_path: Path) -> Tables: + """ + Parse the XML debug tables file and extract required tables as pandas DataFrames. + """ + required_tables = {"queue", "group", "perf", "source"} + try: + tree = ET.parse(tables_path) # nosec B314 + except ET.ParseError as e: + raise ValueError(f"Failed to parse XML tables file {tables_path}: {e}") + + tables: Dict[str, pd.DataFrame] = {} + for table in tree.getroot().findall("table"): + name = table.attrib.get("name") + if name in required_tables: + text = table.text or "" + tables[name] = pd.read_csv(io.StringIO(text)) + + missing = required_tables - tables.keys() + if missing: + raise ValueError(f"Missing required tables in XML: {missing}") + + return Tables(**tables) + + +def get_trace_file_objects(trace_file_path: Path) -> list[Dict[str, Any]]: + """ + Load and return the 'traceEvents' list from a gzip-compressed JSON trace file. 
+ """ + try: + with gzip.open(trace_file_path, "rt", encoding="utf-8") as file: + data = json.load(file) + except (OSError, json.JSONDecodeError) as e: + raise ValueError(f"Failed to read or parse trace file {trace_file_path}: {e}") + + if "traceEvents" not in data: + raise KeyError(f"'traceEvents' key not found in {trace_file_path}") + + return data["traceEvents"] + + +def get_subops(df_group: pd.DataFrame) -> set: + return set(df_group[df_group["id"] != df_group["group_id"]]["id"]) + + +def transform_events( + objects: Iterable[Dict[str, Any]], queue_df: pd.DataFrame, sub_ops: set +) -> None: + """ + Annotate the 'queue' table in-place with duration based on trace events. + """ + queue_df_len = len(queue_df) + offsets = queue_df["offset"].astype(int) + + start_ts, cmd_index, chain_len = 0, 0, 1 + + def is_end_of_command(qread_offset: int, end_idx: int) -> bool: + if end_idx >= queue_df_len: + return qread_offset > offsets[cmd_index] + return qread_offset == offsets[end_idx] + + for event in (e for e in objects if e.get("tid") == "qread"): + if cmd_index >= queue_df_len: + break + + qread_offset = 4 * int(event["args"]["qread"]) + + while (cmd_index + chain_len <= queue_df_len - 1) and queue_df.iloc[ + cmd_index + chain_len + ]["scheduled_id"] in sub_ops: + chain_len += 1 + + end_idx = cmd_index + chain_len + if is_end_of_command(qread_offset, end_idx): + end_ts = int(event["ts"]) - 1 + queue_df.loc[cmd_index, ["duration"]] = [ + end_ts - start_ts, + ] + start_ts = end_ts + cmd_index = end_idx + chain_len = 1 + + +Agg = Union[str, Callable[[pd.Series], Any]] + + +def list_unique(s: pd.Series) -> list[Any]: + return sorted(set(s.dropna())) + + +def build_perf_df(tables: Tables) -> tuple[pd.DataFrame, pd.DataFrame]: + """ + Build a performance DataFrame summarizing queue metrics grouped by source_id. + Returns a tuple of (perf_df, cmd_to_op_df) where cmd_to_op_df is needed for unmapped op tracking. + """ + tables.queue["cmd_id"] = tables.queue.index + + excluded = {"optimised_id", "scheduled_id", "offset"} + col_funcs: Dict[str, Agg] = { + c: "sum" for c in tables.queue.columns if c not in excluded + } + + col_funcs.update({"cmdstream_id": list_unique, "cmd_id": list_unique}) + + cmd_to_op_df = tables.queue.groupby(COMPILER_OP_ID).agg(col_funcs).reset_index() + + opt_df = ( + pd.merge(tables.perf[["id", "source_id"]], tables.group, on="id", how="left") + .rename(columns={"id": COMPILER_OP_ID}) + .merge(cmd_to_op_df, on=COMPILER_OP_ID, how="inner") + ) + + exclude_columns = ["source_id"] + src_col_funcs: Dict[str, Agg] = { + col: "sum" for col in opt_df.columns if col not in exclude_columns + } + src_col_funcs[COMPILER_OP_ID] = list_unique + + perf_df = opt_df.groupby("source_id").agg(src_col_funcs).reset_index() + + return perf_df, cmd_to_op_df + + +def check_unmapped_ops( + tables: Tables, src_df: pd.DataFrame, cmd_to_op_df: pd.DataFrame +) -> None: + """ + Identify operators in the performance data that are not mapped to any source operation. 
+ """ + opt_ids_in_src_table = set() + for opt_ids in src_df[COMPILER_OP_ID].dropna(): + if type(opt_ids) is list: + opt_ids_in_src_table.update(opt_ids) + + opt_df = pd.merge( + tables.perf[["id", "source_id"]], tables.group, on="id", how="left" + ) + opt_df = opt_df.rename(columns={"id": COMPILER_OP_ID}) + opt_df = pd.merge(opt_df, cmd_to_op_df, on=COMPILER_OP_ID, how="inner") + + unmapped_operators = opt_df[ + ~opt_df[COMPILER_OP_ID].isin(list(opt_ids_in_src_table)) + ] + + if not unmapped_operators.empty: + print("Warning: There are unmapped operators in the performance data.") + print(unmapped_operators) + return None + + +def build_src_df(tables: Tables, perf_df: pd.DataFrame) -> pd.DataFrame: + """ + Merge source table with performance metrics and total NPU cycles. + Returns a tuple of (src_df, cmd_to_op_df) where df_cmd_to_op is needed for unmapped op tracking. + """ + return pd.merge( + tables.source.rename(columns={"id": "source_id"})[["ext_key", "source_id"]], + perf_df, + on="source_id", + how="left", + ).merge( + tables.perf[["source_id", "npu_cycles"]] + .groupby("source_id") + .sum(numeric_only=True) + .reset_index(), + on="source_id", + how="left", + ) + + +def get_model_node_data(df: pd.DataFrame) -> ndb.ModelNodeData: + """ + Convert source-level metrics into ModelExplorer node data for duration. + """ + durations = df["duration"].fillna(0).astype(int) + + duration_results: Dict[str, ndb.NodeDataResult] = {} + + for src, dur in zip(df["ext_key"], durations): + node_id = f"main/op{int(src)}" + duration_results[node_id] = ndb.NodeDataResult(value=int(dur)) + + gradient = [ + ndb.GradientItem(stop=0.0, bgColor="#ffffff"), + ndb.GradientItem(stop=0.1, bgColor="#33FF00"), + ndb.GradientItem(stop=0.2, bgColor="#66FF00"), + ndb.GradientItem(stop=0.5, bgColor="#FFFF00"), + ndb.GradientItem(stop=0.7, bgColor="#FF6600"), + ndb.GradientItem(stop=1.0, bgColor="#FF0000"), + ] + + return ndb.ModelNodeData( + graphsData={ + "main": ndb.GraphNodeData(results=duration_results, gradient=gradient) + } + ) + + +def build_overlay_data(trace_path: Path, tables_path: Path) -> ndb.ModelNodeData: + """ + Build ModelExplorer node data from trace and tables files. + """ + tables = parse_tables(tables_path) + events = get_trace_file_objects(trace_path) + transform_events(events, tables.queue, get_subops(tables.group)) + perf_df, cmd_to_op_df = build_perf_df(tables) + src_df = build_src_df(tables, perf_df) + check_unmapped_ops(tables, src_df, cmd_to_op_df) + + return get_model_node_data(src_df) + + +def validate_file_exists(file_path: Path) -> None: + if not file_path.exists(): + raise FileNotFoundError(f"{file_path} not found") + + +def validate_perf_mode_args(trace: str, tables: str) -> None: + if not (trace and tables): + raise ValueError( + "Both --trace and --tables must be provided for perf mode, or neither for default mode" + ) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Visualize a model using model explorer." 
+ ) + parser.add_argument( + "--model_path", required=True, type=str, help="Path to the model file" + ) + parser.add_argument( + "--trace", + required=False, + help="(perf mode) PMU trace JSON.gz file with performance data", + ) + parser.add_argument( + "--tables", + required=False, + help="(perf mode) Vela debug database tables XML file", + ) + + args = parser.parse_args() + model_file = Path(args.model_path).resolve() + validate_file_exists(model_file) + + config = model_explorer_config().add_model_from_path(str(model_file)) + + if args.trace or args.tables: + validate_perf_mode_args(args.trace, args.tables) + trace_file = Path(args.trace).resolve() + tables_file = Path(args.tables).resolve() + validate_file_exists(trace_file) + validate_file_exists(tables_file) + + config.add_node_data( + "Duration (Cycles)", build_overlay_data(trace_file, tables_file) + ) + + visualize_model_explorer(config=config, extensions=["tosa_adapter_model_explorer"]) + + +if __name__ == "__main__": + main() diff --git a/examples/cadence/models/babyllama.py b/examples/cadence/models/babyllama.py index 1b576a1a3eb..f393cd30037 100644 --- a/examples/cadence/models/babyllama.py +++ b/examples/cadence/models/babyllama.py @@ -14,8 +14,10 @@ from executorch.backends.cadence.aot.export_example import export_and_run_model -from executorch.examples.models.llama.llama_transformer import ModelArgs, Transformer - +from executorch.examples.models.llama.llama_transformer import ( + construct_transformer, + ModelArgs, +) FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) @@ -32,7 +34,7 @@ def main() -> None: ) seq = 64 b = 1 - model = Transformer(args) + model = construct_transformer(args) example_inputs = (torch.randint(0, 10, [b, seq], dtype=torch.int64),) export_and_run_model(model, example_inputs) diff --git a/backends/nxp/backend/ir/edge_passes/__init__.py b/examples/cuda/scripts/__init__.py old mode 100755 new mode 100644 similarity index 100% rename from backends/nxp/backend/ir/edge_passes/__init__.py rename to examples/cuda/scripts/__init__.py diff --git a/examples/cuda/scripts/export.py b/examples/cuda/scripts/export.py new file mode 100644 index 00000000000..c103d7ee50a --- /dev/null +++ b/examples/cuda/scripts/export.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Example script for exporting simple models to flatbuffer with CUDA delegate. + +import argparse +import pathlib + +import torch + +from executorch.backends.cuda.cuda_backend import CudaBackend + +from executorch.backends.cuda.cuda_partitioner import CudaPartitioner + +from executorch.examples.models import MODEL_NAME_TO_MODEL +from executorch.examples.models.model_factory import EagerModelFactory + +from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower + +from executorch.extension.export_util.utils import save_pte_program +from torch._inductor.decomposition import conv1d_to_conv2d +from torch.nn.attention import SDPBackend + +# Script to export a model with CUDA delegation. 
+ +_EDGE_COMPILE_CONFIG = EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, # TODO(T182928844): enable dim_order in backend +) + + +def is_fbcode(): + return not hasattr(torch.version, "git_version") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + + parser.add_argument( + "-m", + "--model_name", + required=True, + help=f"Provide model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}", + ) + parser.add_argument( + "--output_dir", + type=pathlib.Path, + default=pathlib.Path("./"), + help="Output directory for the exported model", + ) + parser.add_argument("--generate_etrecord", action=argparse.BooleanOptionalAction) + parser.add_argument("--save_processed_bytes", action=argparse.BooleanOptionalAction) + + args = parser.parse_args() + return args + + +def save_processed_bytes(processed_bytes, base_name: str): + filename = f"{base_name}.bin" + print(f"Saving processed bytes to {filename}") + with open(filename, "wb") as file: + file.write(processed_bytes) + return + + +def main(): + args = parse_args() + + if args.model_name not in MODEL_NAME_TO_MODEL: + raise RuntimeError( + f"Model {args.model_name} is not a valid name. " + f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}." + ) + + ( + model, + example_args, + example_kwargs, + dynamic_shapes, + ) = EagerModelFactory.create_model(*MODEL_NAME_TO_MODEL[args.model_name]) + model = model.eval() + exported_programs = torch.export.export( + model, + args=example_args, + kwargs=example_kwargs, + dynamic_shapes=dynamic_shapes, + ) + print(exported_programs) + + partitioner = CudaPartitioner( + [CudaBackend.generate_method_name_compile_spec(args.model_name)] + ) + # Add decompositions for triton to generate kernels. + exported_programs = exported_programs.run_decompositions( + { + torch.ops.aten.conv1d.default: conv1d_to_conv2d, + } + ) + with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]): + et_prog = to_edge_transform_and_lower( + exported_programs, + partitioner=[partitioner], + compile_config=_EDGE_COMPILE_CONFIG, + generate_etrecord=args.generate_etrecord, + ) + exec_program = et_prog.to_executorch() + save_pte_program(exec_program, args.model_name, args.output_dir) + if args.generate_etrecord: + exec_program.get_etrecord().save(f"{args.model_name}_etrecord.bin") + + +if __name__ == "__main__": + main() diff --git a/examples/demo-apps/android/LlamaDemo/.gitignore b/examples/demo-apps/android/LlamaDemo/.gitignore deleted file mode 100644 index 41853c0472c..00000000000 --- a/examples/demo-apps/android/LlamaDemo/.gitignore +++ /dev/null @@ -1,12 +0,0 @@ -*.iml -.gradle -/local.properties -.idea -.DS_Store -/build -/captures -.externalNativeBuild -.cxx -local.properties -*.so -*.aar diff --git a/examples/demo-apps/android/LlamaDemo/README.md b/examples/demo-apps/android/LlamaDemo/README.md deleted file mode 100644 index 9a6b3b020e7..00000000000 --- a/examples/demo-apps/android/LlamaDemo/README.md +++ /dev/null @@ -1,174 +0,0 @@ -# ExecuTorch Llama Android Demo App - -**[UPDATE - 2025-05-15]** We have added support for running Qwen3 0.6B and 4B model. Please see [this tutorial](https://github.com/pytorch/executorch/tree/main/examples/models/qwen3#summary) for export. Loading and running Qwen3 with this app is the same as Llama, as in this doc. - -We’re excited to share that the newly revamped Android demo app is live and includes many new updates to provide a more intuitive and smoother user experience with a chat use case! 
The primary goal of this app is to showcase how easily ExecuTorch can be integrated into an Android demo app and how to exercise the many features ExecuTorch and Llama models have to offer. - -This app serves as a valuable resource to inspire your creativity and provide foundational code that you can customize and adapt for your particular use case. - -Please dive in and start exploring our demo app today! We look forward to any feedback and are excited to see your innovative ideas. - - -## Key Concepts -From this demo app, you will learn many key concepts such as: -* How to prepare Llama models, build the ExecuTorch library, and model inferencing across delegates -* Expose the ExecuTorch library via JNI layer -* Familiarity with current ExecuTorch app-facing capabilities - -The goal is for you to see the type of support ExecuTorch provides and feel comfortable with leveraging it for your use cases. - -## Supporting Models -As a whole, the models that this app supports are (varies by delegate): -* Llama 3.2 Quantized 1B/3B -* Llama 3.2 1B/3B in BF16 -* Llama Guard 3 1B -* Llama 3.1 8B -* Llama 3 8B -* Llama 2 7B -* LLaVA-1.5 vision model (only XNNPACK) -* Qwen 3 0.6B, 1.7B, and 4B - - -## Building the APK -First it’s important to note that currently ExecuTorch provides support across 3 delegates. Once you identify the delegate of your choice, select the README link to get a complete end-to-end instructions for environment set-up to exporting the models to build ExecuTorch libraries and apps to run on device: - -| Delegate | Resource | -| ------------- | ------------- | -| XNNPACK (CPU-based library) | [link](https://github.com/pytorch/executorch/blob/main/examples/demo-apps/android/LlamaDemo/docs/delegates/xnnpack_README.md) | -| QNN (Qualcomm AI Accelerators) | [link](https://github.com/pytorch/executorch/blob/main/examples/demo-apps/android/LlamaDemo/docs/delegates/qualcomm_README.md) | -| MediaTek (MediaTek AI Accelerators) | [link](https://github.com/pytorch/executorch/blob/main/examples/demo-apps/android/LlamaDemo/docs/delegates/mediatek_README.md) | - - -## How to Use the App - -This section will provide the main steps to use the app, along with a code snippet of the ExecuTorch API. - -For loading the app, development, and running on device we recommend Android Studio: -1. Open Android Studio and select "Open an existing Android Studio project" to open examples/demo-apps/android/LlamaDemo. -2. Run the app (^R). This builds and launches the app on the phone. - -### Opening the App - -Below are the UI features for the app. - -Select the settings widget to get started with picking a model, its parameters and any prompts. -

- - - -### Select Models and Parameters - -Once you've selected the model, tokenizer, and model type you are ready to click on "Load Model" to have the app load the model and go back to the main Chat activity. -

- - - -Optional Parameters: -* Temperature: Defaulted to 0, you can adjust the temperature for the model as well. The model will reload upon any adjustments. -* System Prompt: Without any formatting, you can enter in a system prompt. For example, "you are a travel assistant" or "give me a response in a few sentences". -* User Prompt: More for the advanced user, if you would like to manually input a prompt then you can do so by modifying the `{{user prompt}}`. You can also modify the special tokens as well. Once changed then go back to the main Chat activity to send. - -#### ExecuTorch App API - -```java -// Upon returning to the Main Chat Activity -mModule = new LlmModule( - ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType()), - modelPath, - tokenizerPath, - temperature); -int loadResult = mModule.load(); -``` - -* `modelCategory`: Indicate whether it’s a text-only or vision model -* `modePath`: path to the .pte file -* `tokenizerPath`: path to the tokenizer file -* `temperature`: model parameter to adjust the randomness of the model’s output - - -### User Prompt -Once model is successfully loaded then enter any prompt and click the send (i.e. generate) button to send it to the model. -

- -You can provide it more follow-up questions as well. -

- -#### ExecuTorch App API - -```java -mModule.generate(prompt,sequence_length, MainActivity.this); -``` -* `prompt`: User formatted prompt -* `sequence_length`: Number of tokens to generate in response to a prompt -* `MainActivity.this`: Indicate that the callback functions (OnResult(), OnStats()) are present in this class. - -[*LLaVA-1.5: Only for XNNPACK delegate*] - -For LLaVA-1.5 implementation, select the exported LLaVA .pte and tokenizer file in the Settings menu and load the model. After this you can send an image from your gallery or take a live picture along with a text prompt to the model. - -

- - -### Output Generated -To show completion of the follow-up question, here is the complete detailed response from the model. -

- -#### ExecuTorch App API - -Ensure you have the following functions in your callback class that you provided in the `mModule.generate()`. For this example, it is `MainActivity.this`. -```java - @Override - public void onResult(String result) { - //...result contains token from response - //.. onResult will continue to be invoked until response is complete - } - - @Override - public void onStats(String stats) { - //... will be a json. See extension/llm/stats.h for the field definitions - } - -``` - -## Instrumentation Test -You can run the instrumentation test for sanity check. The test loads a model pte file and tokenizer.bin file -under `/data/local/tmp/llama`. - -### Model preparation -Go to ExecuTorch root, -```sh -curl -C - -Ls "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt" --output stories110M.pt -curl -C - -Ls "https://raw.githubusercontent.com/karpathy/llama2.c/master/tokenizer.model" --output tokenizer.model -# Create params.json file -touch params.json -echo '{"dim": 768, "multiple_of": 32, "n_heads": 12, "n_layers": 12, "norm_eps": 1e-05, "vocab_size": 32000}' > params.json -python -m extension.llm.export.export_llm base.checkpoint=stories110M.pt base.params=params.json model.dtype_override="fp16" export.output_name=stories110m_h.pte model.use_kv_cache=True -python -m pytorch_tokenizers.tools.llama2c.convert -t tokenizer.model -o tokenizer.bin -``` -### Push model -```sh -adb mkdir -p /data/local/tmp/llama -adb push stories110m_h.pte /data/local/tmp/llama -adb push tokenizer.bin /data/local/tmp/llama -``` - -### Run test -Go to `examples/demo-apps/android/LlamaDemo`, -```sh -./gradlew connectedAndroidTest -``` - -## Reporting Issues -If you encountered any bugs or issues following this tutorial please file a bug/issue here on [Github](https://github.com/pytorch/executorch/issues/new), or join our discord [here](https://lnkd.in/gWCM4ViK). diff --git a/examples/demo-apps/android/LlamaDemo/SDK-quick-setup-guide.md b/examples/demo-apps/android/LlamaDemo/SDK-quick-setup-guide.md deleted file mode 100644 index 9ae79e96763..00000000000 --- a/examples/demo-apps/android/LlamaDemo/SDK-quick-setup-guide.md +++ /dev/null @@ -1,94 +0,0 @@ -# Guide to set up Java/SDK/NDK for Android - -Follow this doc if you haven't set up Java/SDK/NDK for Android development -already. -This doc provides a CLI tutorial to set them up. Otherwise, you can do the same -thing with Android Studio GUI. - -## Set up Java 17 -1. Download the archive from Oracle website. -Make sure you have read and agree with the terms and conditions from the website before downloading. -```bash -export DEV_HOME= -cd $DEV_HOME -``` -Linux: -```bash -curl https://download.oracle.com/java/17/archive/jdk-17.0.10_linux-x64_bin.tar.gz -o jdk-17.0.10.tar.gz -``` -macOS: -```bash -curl https://download.oracle.com/java/17/archive/jdk-17.0.10_macos-aarch64_bin.tar.gz -o jdk-17.0.10.tar.gz -``` -2. Unzip the archive. The directory named `jdk-17.0.10` is the Java root directory. -```bash -tar xf jdk-17.0.10.tar.gz -``` -3. Set `JAVA_HOME` and update `PATH`. 
- -Linux: -```bash -export JAVA_HOME="$DEV_HOME"/jdk-17.0.10 -export PATH="$JAVA_HOME/bin:$PATH" -``` -macOS: -```bash -export JAVA_HOME="$DEV_HOME"/jdk-17.0.10.jdk/Contents/Home -export PATH="$JAVA_HOME/bin:$PATH" -``` - -Note: Oracle has tutorials for installing Java on -[Linux](https://docs.oracle.com/en/java/javase/17/install/installation-jdk-linux-platforms.html#GUID-4A6BD592-1840-4BB4-A758-4CD49E9EE88B) -and [macOS](https://docs.oracle.com/en/java/javase/17/install/installation-jdk-macos.html#GUID-E8A251B6-D9A9-4276-ABC8-CC0DAD62EA33). -Some Linux distributions has JDK package in package manager. For example, Debian users can install -openjdk-17-jdk package. - -## Set up Android SDK/NDK -Android has a command line tool [sdkmanager](https://developer.android.com/tools/sdkmanager) which -helps users managing SDK and other tools related to Android development. - -1. Go to https://developer.android.com/studio and download the archive from "Command line tools -only" section. Make sure you have read and agree with the terms and conditions from the website. - -Linux: -```bash -curl https://dl.google.com/android/repository/commandlinetools-linux-11076708_latest.zip -o commandlinetools.zip -``` -macOS: -```bash -curl https://dl.google.com/android/repository/commandlinetools-mac-11076708_latest.zip -o commandlinetools.zip -``` -2. Unzip. -```bash -unzip commandlinetools.zip -``` -3. Specify a root for Android SDK. For example, we can put it under `$DEV_HOME/sdk`. - -``` -mkdir -p $DEV_HOME/sdk -export ANDROID_HOME="$(realpath $DEV_HOME/sdk)" -# Install SDK 34 -./cmdline-tools/bin/sdkmanager --sdk_root="${ANDROID_HOME}" --install "platforms;android-34" -# Install NDK -./cmdline-tools/bin/sdkmanager --sdk_root="${ANDROID_HOME}" --install "ndk;26.3.11579264" -# The NDK root is then under `ndk/`. -export ANDROID_NDK="$ANDROID_HOME/ndk/26.3.11579264" -``` - -### (Optional) Android Studio Setup -If you want to use Android Studio and never set up Java/SDK/NDK before, or if -you use the newly installed ones, follow these steps to set Android Studio to use -them. - -Copy these output paths to be used by Android Studio -```bash -echo $ANDROID_HOME -echo $ANDROID_NDK -echo $JAVA_HOME -``` - -Open a project in Android Studio. In Project Structure (File -> Project -Structure, or `⌘;`) -> SDK Location, -* Set Android SDK Location to the path of $ANDROID_HOME -* Set Android NDK Location to the path of $ANDROID_NDK -* Set JDK location (Click Gradle Settings link) -> Gradle JDK -> Add JDK... to the path of $JAVA_HOME diff --git a/examples/demo-apps/android/LlamaDemo/app/.gitignore b/examples/demo-apps/android/LlamaDemo/app/.gitignore deleted file mode 100644 index 796b96d1c40..00000000000 --- a/examples/demo-apps/android/LlamaDemo/app/.gitignore +++ /dev/null @@ -1 +0,0 @@ -/build diff --git a/examples/demo-apps/android/LlamaDemo/app/build.gradle.kts b/examples/demo-apps/android/LlamaDemo/app/build.gradle.kts deleted file mode 100644 index 19cfda847db..00000000000 --- a/examples/demo-apps/android/LlamaDemo/app/build.gradle.kts +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -plugins { - id("com.android.application") - id("org.jetbrains.kotlin.android") -} - -val qnnVersion: String? = project.findProperty("qnnVersion") as? 
String - -android { - namespace = "com.example.executorchllamademo" - compileSdk = 34 - - defaultConfig { - applicationId = "com.example.executorchllamademo" - minSdk = 28 - targetSdk = 33 - versionCode = 1 - versionName = "1.0" - - testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner" - vectorDrawables { useSupportLibrary = true } - externalNativeBuild { cmake { cppFlags += "" } } - } - - buildTypes { - release { - isMinifyEnabled = false - proguardFiles(getDefaultProguardFile("proguard-android-optimize.txt"), "proguard-rules.pro") - } - } - compileOptions { - sourceCompatibility = JavaVersion.VERSION_1_8 - targetCompatibility = JavaVersion.VERSION_1_8 - } - kotlinOptions { jvmTarget = "1.8" } - buildFeatures { compose = true } - composeOptions { kotlinCompilerExtensionVersion = "1.4.3" } - packaging { resources { excludes += "/META-INF/{AL2.0,LGPL2.1}" } } -} - -dependencies { - implementation("androidx.core:core-ktx:1.9.0") - implementation("androidx.lifecycle:lifecycle-runtime-ktx:2.6.1") - implementation("androidx.activity:activity-compose:1.7.0") - implementation(platform("androidx.compose:compose-bom:2023.03.00")) - implementation("androidx.compose.ui:ui") - implementation("androidx.compose.ui:ui-graphics") - implementation("androidx.compose.ui:ui-tooling-preview") - implementation("androidx.compose.material3:material3") - implementation("androidx.appcompat:appcompat:1.6.1") - implementation("androidx.camera:camera-core:1.3.0-rc02") - implementation("androidx.constraintlayout:constraintlayout:2.2.0-alpha12") - implementation("com.facebook.fbjni:fbjni:0.5.1") - implementation("com.google.code.gson:gson:2.8.6") - implementation(files("libs/executorch.aar")) - implementation("com.google.android.material:material:1.12.0") - implementation("androidx.activity:activity:1.9.0") - implementation("org.json:json:20250107") - if (!qnnVersion.isNullOrEmpty()) { - implementation("com.qualcomm.qti:qnn-runtime:$qnnVersion") - } - testImplementation("junit:junit:4.13.2") - androidTestImplementation("androidx.test.ext:junit:1.1.5") - androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1") - androidTestImplementation(platform("androidx.compose:compose-bom:2023.03.00")) - androidTestImplementation("androidx.compose.ui:ui-test-junit4") - debugImplementation("androidx.compose.ui:ui-tooling") - debugImplementation("androidx.compose.ui:ui-test-manifest") -} - -tasks.register("setup") { - doFirst { - exec { - commandLine("sh", "examples/demo-apps/android/LlamaDemo/setup.sh") - workingDir("../../../../../") - } - } -} - -tasks.register("setupQnn") { - doFirst { - exec { - commandLine("sh", "examples/demo-apps/android/LlamaDemo/setup-with-qnn.sh") - workingDir("../../../../../") - } - } -} - -tasks.register("download_prebuilt_lib") { - doFirst { - exec { - commandLine("sh", "examples/demo-apps/android/LlamaDemo/download_prebuilt_lib.sh") - workingDir("../../../../../") - } - } -} diff --git a/examples/demo-apps/android/LlamaDemo/app/proguard-rules.pro b/examples/demo-apps/android/LlamaDemo/app/proguard-rules.pro deleted file mode 100644 index 481bb434814..00000000000 --- a/examples/demo-apps/android/LlamaDemo/app/proguard-rules.pro +++ /dev/null @@ -1,21 +0,0 @@ -# Add project specific ProGuard rules here. -# You can control the set of applied configuration files using the -# proguardFiles setting in build.gradle. 
-# -# For more details, see -# http://developer.android.com/guide/developing/tools/proguard.html - -# If your project uses WebView with JS, uncomment the following -# and specify the fully qualified class name to the JavaScript interface -# class: -#-keepclassmembers class fqcn.of.javascript.interface.for.webview { -# public *; -#} - -# Uncomment this to preserve the line number information for -# debugging stack traces. -#-keepattributes SourceFile,LineNumberTable - -# If you keep the line number information, uncomment this to -# hide the original source file name. -#-renamesourcefileattribute SourceFile \ No newline at end of file diff --git a/examples/demo-apps/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PerfTest.java b/examples/demo-apps/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PerfTest.java deleted file mode 100644 index 32ec24a0df9..00000000000 --- a/examples/demo-apps/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PerfTest.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package com.example.executorchllamademo; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; - -import android.os.Bundle; -import androidx.test.ext.junit.runners.AndroidJUnit4; -import androidx.test.platform.app.InstrumentationRegistry; -import java.io.File; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import org.json.JSONException; -import org.json.JSONObject; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.pytorch.executorch.extension.llm.LlmCallback; -import org.pytorch.executorch.extension.llm.LlmModule; - -@RunWith(AndroidJUnit4.class) -public class PerfTest implements LlmCallback { - - private static final String RESOURCE_PATH = "/data/local/tmp/llama/"; - private static final String TOKENIZER_BIN = "tokenizer.bin"; - - private final List results = new ArrayList<>(); - private final List tokensPerSecond = new ArrayList<>(); - - @Test - public void testTokensPerSecond() { - String tokenizerPath = RESOURCE_PATH + TOKENIZER_BIN; - // Find out the model name - File directory = new File(RESOURCE_PATH); - Arrays.stream(directory.listFiles()) - .filter(file -> file.getName().endsWith(".pte")) - .forEach( - model -> { - LlmModule mModule = new LlmModule(model.getPath(), tokenizerPath, 0.8f); - // Print the model name because there might be more than one of them - report("ModelName", model.getName()); - - int loadResult = mModule.load(); - // Check that the model can be load successfully - assertEquals(0, loadResult); - - // Run a testing prompt - mModule.generate("How do you do! 
I'm testing llama2 on mobile device", PerfTest.this); - assertFalse(tokensPerSecond.isEmpty()); - - final Float tps = tokensPerSecond.get(tokensPerSecond.size() - 1); - report("TPS", tps); - }); - } - - @Override - public void onResult(String result) { - results.add(result); - } - - @Override - public void onStats(String result) { - try { - JSONObject jsonObject = new JSONObject(result); - int numGeneratedTokens = jsonObject.getInt("generated_tokens"); - int inferenceEndMs = jsonObject.getInt("inference_end_ms"); - int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms"); - float tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000; - tokensPerSecond.add(tps); - } catch (JSONException e) { - } - } - - private void report(final String metric, final Float value) { - Bundle bundle = new Bundle(); - bundle.putFloat(metric, value); - InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle); - } - - private void report(final String key, final String value) { - Bundle bundle = new Bundle(); - bundle.putString(key, value); - InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle); - } -} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/AndroidManifest.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/AndroidManifest.xml deleted file mode 100644 index 7096a7d4e76..00000000000 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/AndroidManifest.xml +++ /dev/null @@ -1,85 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/BUCK b/examples/demo-apps/android/LlamaDemo/app/src/main/BUCK deleted file mode 100644 index a64e11d1306..00000000000 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/BUCK +++ /dev/null @@ -1,67 +0,0 @@ -load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target") -load("@fbsource//tools/build_defs/android:fb_android_binary.bzl", "fb_android_binary") -load("@fbsource//tools/build_defs/android:fb_android_library.bzl", "fb_android_library") -load("@fbsource//tools/build_defs/android:fb_android_resource.bzl", "fb_android_resource") - -oncall("executorch") - -non_fbcode_target(_kind = fb_android_resource, - name = "app_res", - package = "com.example.executorchllamademo", - res = "res", -) - -non_fbcode_target(_kind = fb_android_library, - name = "app_lib", - srcs = [ - "java/com/example/executorchllamademo/AppLog.java", - "java/com/example/executorchllamademo/BackendType.java", - "java/com/example/executorchllamademo/DemoSharedPreferences.java", - "java/com/example/executorchllamademo/ETImage.java", - "java/com/example/executorchllamademo/ETLogging.java", - "java/com/example/executorchllamademo/LlmBenchmarkRunner.java", - "java/com/example/executorchllamademo/LogsActivity.java", - "java/com/example/executorchllamademo/LogsAdapter.java", - "java/com/example/executorchllamademo/MainActivity.java", - "java/com/example/executorchllamademo/Message.java", - "java/com/example/executorchllamademo/MessageAdapter.java", - "java/com/example/executorchllamademo/MessageType.java", - "java/com/example/executorchllamademo/ModelRunner.java", - "java/com/example/executorchllamademo/ModelRunnerCallback.java", - "java/com/example/executorchllamademo/ModelType.java", - "java/com/example/executorchllamademo/ModelUtils.java", - "java/com/example/executorchllamademo/PromptFormat.java", - "java/com/example/executorchllamademo/SettingsActivity.java", - 
"java/com/example/executorchllamademo/SettingsFields.java", - ], - autoglob = False, - language = "JAVA", - deps = [ - ":app_res", - "//third-party/java/androidx/constraintlayout/constraintlayout:constraintlayout", - "//third-party/java/com/google/code/gson/gson:gson", - "//xplat/executorch/extension/android:executorch_llama", - ], -) - -non_fbcode_target(_kind = fb_android_binary, - name = "ExecuTorchLlamaDemo", - keystore = "//fbandroid/keystores:debug", - manifest = "AndroidManifest.xml", - manifest_entries = { - "min_sdk_version": 21, - "target_sdk_version": 34, - "version_code": "1", - "version_name": "1.0", - }, - package_type = "release", - skip_proguard = True, - deps = [ - ":app_lib", - ":app_res", - "//third-party/java/androidx/appcompat/appcompat:appcompat", - "//third-party/java/com/google/code/gson/gson:gson", - "//xplat/executorch/extension/android:executorch_llama", - "//xplat/executorch/extension/android/jni:executorch_llama_jni", - ], -) diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/AppLog.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/AppLog.java deleted file mode 100644 index 36d07419381..00000000000 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/AppLog.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package com.example.executorchllamademo; - -import java.text.SimpleDateFormat; -import java.util.Date; -import java.util.Locale; - -public class AppLog { - private final Long timestamp; - private final String message; - - public AppLog(String message) { - this.timestamp = getCurrentTimeStamp(); - this.message = message; - } - - public Long getTimestamp() { - return timestamp; - } - - public String getMessage() { - return message; - } - - public String getFormattedLog() { - return "[" + getFormattedTimeStamp() + "] " + message; - } - - private Long getCurrentTimeStamp() { - return System.currentTimeMillis(); - } - - private String getFormattedTimeStamp() { - return formatDate(timestamp); - } - - private String formatDate(long milliseconds) { - SimpleDateFormat formatter = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.getDefault()); - Date date = new Date(milliseconds); - return formatter.format(date); - } -} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/BackendType.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/BackendType.java deleted file mode 100644 index 7c84799795f..00000000000 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/BackendType.java +++ /dev/null @@ -1,7 +0,0 @@ -package com.example.executorchllamademo; - -public enum BackendType { - XNNPACK, - QUALCOMM, - MEDIATEK -} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/DemoSharedPreferences.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/DemoSharedPreferences.java deleted file mode 100644 index 99a94c00ebb..00000000000 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/DemoSharedPreferences.java +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. 
and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package com.example.executorchllamademo; - -import android.content.Context; -import android.content.SharedPreferences; -import com.google.gson.Gson; -import com.google.gson.reflect.TypeToken; -import java.lang.reflect.Type; -import java.util.ArrayList; - -public class DemoSharedPreferences { - Context context; - SharedPreferences sharedPreferences; - - public DemoSharedPreferences(Context context) { - this.context = context; - this.sharedPreferences = getSharedPrefs(); - } - - private SharedPreferences getSharedPrefs() { - return context.getSharedPreferences( - context.getString(R.string.demo_pref_file_key), Context.MODE_PRIVATE); - } - - public String getSavedMessages() { - return sharedPreferences.getString(context.getString(R.string.saved_messages_json_key), ""); - } - - public void addMessages(MessageAdapter messageAdapter) { - SharedPreferences.Editor editor = sharedPreferences.edit(); - Gson gson = new Gson(); - String msgJSON = gson.toJson(messageAdapter.getSavedMessages()); - editor.putString(context.getString(R.string.saved_messages_json_key), msgJSON); - editor.apply(); - } - - public void removeExistingMessages() { - SharedPreferences.Editor editor = sharedPreferences.edit(); - editor.remove(context.getString(R.string.saved_messages_json_key)); - editor.apply(); - } - - public void addSettings(SettingsFields settingsFields) { - SharedPreferences.Editor editor = sharedPreferences.edit(); - Gson gson = new Gson(); - String settingsJSON = gson.toJson(settingsFields); - editor.putString(context.getString(R.string.settings_json_key), settingsJSON); - editor.apply(); - } - - public String getSettings() { - return sharedPreferences.getString(context.getString(R.string.settings_json_key), ""); - } - - public void saveLogs() { - SharedPreferences.Editor editor = sharedPreferences.edit(); - Gson gson = new Gson(); - String msgJSON = gson.toJson(ETLogging.getInstance().getLogs()); - editor.putString(context.getString(R.string.logs_json_key), msgJSON); - editor.apply(); - } - - public void removeExistingLogs() { - SharedPreferences.Editor editor = sharedPreferences.edit(); - editor.remove(context.getString(R.string.logs_json_key)); - editor.apply(); - } - - public ArrayList getSavedLogs() { - String logsJSONString = - sharedPreferences.getString(context.getString(R.string.logs_json_key), null); - if (logsJSONString == null || logsJSONString.isEmpty()) { - return new ArrayList<>(); - } - Gson gson = new Gson(); - Type type = new TypeToken>() {}.getType(); - ArrayList appLogs = gson.fromJson(logsJSONString, type); - if (appLogs == null) { - return new ArrayList<>(); - } - return appLogs; - } -} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETImage.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETImage.java deleted file mode 100644 index e68c8472626..00000000000 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETImage.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -package com.example.executorchllamademo; - -import android.content.ContentResolver; -import android.graphics.Bitmap; -import android.graphics.BitmapFactory; -import android.graphics.Color; -import android.net.Uri; -import androidx.annotation.Nullable; -import java.io.FileNotFoundException; -import java.io.InputStream; - -public class ETImage { - private int width; - private int height; - private final byte[] bytes; - private final Uri uri; - private final ContentResolver contentResolver; - - ETImage(ContentResolver contentResolver, Uri uri) { - this.contentResolver = contentResolver; - this.uri = uri; - bytes = getBytesFromImageURI(uri); - } - - public int getWidth() { - return width; - } - - public int getHeight() { - return height; - } - - public Uri getUri() { - return uri; - } - - public byte[] getBytes() { - return bytes; - } - - public int[] getInts() { - // We need to convert the byte array to an int array because - // the runner expects an int array as input. - int[] intArray = new int[bytes.length]; - for (int i = 0; i < bytes.length; i++) { - intArray[i] = (bytes[i++] & 0xFF); - } - return intArray; - } - - private byte[] getBytesFromImageURI(Uri uri) { - try { - int RESIZED_IMAGE_WIDTH = 336; - Bitmap bitmap = resizeImage(uri, RESIZED_IMAGE_WIDTH); - - if (bitmap == null) { - ETLogging.getInstance().log("Unable to get bytes from Image URI. Bitmap is null"); - return new byte[0]; - } - - width = bitmap.getWidth(); - height = bitmap.getHeight(); - - byte[] rgbValues = new byte[width * height * 3]; - - for (int y = 0; y < height; y++) { - for (int x = 0; x < width; x++) { - // Get the color of the current pixel - int color = bitmap.getPixel(x, y); - - // Extract the RGB values from the color - int red = Color.red(color); - int green = Color.green(color); - int blue = Color.blue(color); - - // Store the RGB values in the byte array - rgbValues[y * width + x] = (byte) red; - rgbValues[(y * width + x) + height * width] = (byte) green; - rgbValues[(y * width + x) + 2 * height * width] = (byte) blue; - } - } - return rgbValues; - } catch (FileNotFoundException e) { - throw new RuntimeException(e); - } - } - - @Nullable - private Bitmap resizeImage(Uri uri, int maxLength) throws FileNotFoundException { - InputStream inputStream = contentResolver.openInputStream(uri); - if (inputStream == null) { - ETLogging.getInstance().log("Unable to resize image, input streams is null"); - return null; - } - Bitmap bitmap = BitmapFactory.decodeStream(inputStream); - if (bitmap == null) { - ETLogging.getInstance().log("Unable to resize image, bitmap during decode stream is null"); - return null; - } - - float aspectRatio; - int finalWidth, finalHeight; - - if (bitmap.getWidth() > bitmap.getHeight()) { - // width > height --> width = maxLength, height scale with aspect ratio - aspectRatio = bitmap.getWidth() / (float) bitmap.getHeight(); - finalWidth = maxLength; - finalHeight = Math.round(maxLength / aspectRatio); - } else { - // height >= width --> height = maxLength, width scale with aspect ratio - aspectRatio = bitmap.getHeight() / (float) bitmap.getWidth(); - finalHeight = maxLength; - finalWidth = Math.round(maxLength / aspectRatio); - } - - return Bitmap.createScaledBitmap(bitmap, finalWidth, finalHeight, false); - } -} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETLogging.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETLogging.java deleted file mode 100644 index 
e595348945f..00000000000 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETLogging.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package com.example.executorchllamademo; - -import android.app.Application; -import android.util.Log; -import java.util.ArrayList; - -public class ETLogging extends Application { - private static ETLogging singleton; - - private ArrayList logs; - private DemoSharedPreferences mDemoSharedPreferences; - - @Override - public void onCreate() { - super.onCreate(); - singleton = this; - mDemoSharedPreferences = new DemoSharedPreferences(this.getApplicationContext()); - logs = mDemoSharedPreferences.getSavedLogs(); - if (logs == null) { // We don't have existing sharedPreference stored - logs = new ArrayList<>(); - } - } - - public static ETLogging getInstance() { - return singleton; - } - - public void log(String message) { - AppLog appLog = new AppLog(message); - logs.add(appLog); - Log.d("ETLogging", appLog.getMessage()); - } - - public ArrayList getLogs() { - return logs; - } - - public void clearLogs() { - logs.clear(); - mDemoSharedPreferences.removeExistingLogs(); - } - - public void saveLogs() { - mDemoSharedPreferences.saveLogs(); - } -} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LlmBenchmarkRunner.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LlmBenchmarkRunner.java deleted file mode 100644 index 8c2d60252a0..00000000000 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LlmBenchmarkRunner.java +++ /dev/null @@ -1,223 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -package com.example.executorchllamademo; - -import android.app.Activity; -import android.app.ActivityManager; -import android.content.Intent; -import android.os.Build; -import android.os.Bundle; -import android.util.Log; -import android.widget.TextView; -import androidx.annotation.NonNull; -import com.google.gson.Gson; -import java.io.File; -import java.io.FileWriter; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -public class LlmBenchmarkRunner extends Activity implements ModelRunnerCallback { - ModelRunner mModelRunner; - - String mPrompt; - TextView mTextView; - StatsDump mStatsDump; - - @Override - protected void onCreate(Bundle savedInstanceState) { - super.onCreate(savedInstanceState); - setContentView(R.layout.activity_benchmarking); - mTextView = findViewById(R.id.log_view); - - Intent intent = getIntent(); - - File modelDir = new File(intent.getStringExtra("model_dir")); - File model = - Arrays.stream(modelDir.listFiles()) - .filter(file -> file.getName().endsWith(".pte")) - .findFirst() - .get(); - String tokenizerPath = intent.getStringExtra("tokenizer_path"); - - float temperature = intent.getFloatExtra("temperature", 0.8f); - mPrompt = intent.getStringExtra("prompt"); - if (mPrompt == null) { - mPrompt = "The ultimate answer"; - } - - mStatsDump = new StatsDump(); - mStatsDump.modelName = model.getName().replace(".pte", ""); - mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this); - mStatsDump.loadStart = System.nanoTime(); - } - - @Override - public void onModelLoaded(int status) { - mStatsDump.loadEnd = System.nanoTime(); - mStatsDump.loadStatus = status; - if (status != 0) { - Log.e("LlmBenchmarkRunner", "Loaded failed: " + status); - onGenerationStopped(); - return; - } - mStatsDump.generateStart = System.nanoTime(); - mModelRunner.generate(mPrompt); - } - - @Override - public void onTokenGenerated(String token) { - runOnUiThread( - () -> { - mTextView.append(token); - }); - } - - @Override - public void onStats(String stats) { - mStatsDump.tokens = stats; - } - - @Override - public void onGenerationStopped() { - mStatsDump.generateEnd = System.nanoTime(); - runOnUiThread( - () -> { - mTextView.append(mStatsDump.toString()); - }); - - final BenchmarkMetric.BenchmarkModel benchmarkModel = - BenchmarkMetric.extractBackendAndQuantization(mStatsDump.modelName); - final List results = new ArrayList<>(); - // The list of metrics we have atm includes: - // Load status - results.add(new BenchmarkMetric(benchmarkModel, "load_status", mStatsDump.loadStatus, 0)); - // Model load time - results.add( - new BenchmarkMetric( - benchmarkModel, - "model_load_time(ms)", - (mStatsDump.loadEnd - mStatsDump.loadStart) * 1e-6, - 0.0f)); - // LLM generate time - results.add( - new BenchmarkMetric( - benchmarkModel, - "generate_time(ms)", - (mStatsDump.generateEnd - mStatsDump.generateStart) * 1e-6, - 0.0f)); - // Token per second - results.add( - new BenchmarkMetric(benchmarkModel, "token_per_sec", extractTPS(mStatsDump.tokens), 0.0f)); - - try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { - Gson gson = new Gson(); - writer.write(gson.toJson(results)); - } catch (IOException e) { - e.printStackTrace(); - } - } - - private double extractTPS(final String tokens) { - final Matcher m = Pattern.compile("\\d+\\.?\\d*").matcher(tokens); - if (m.find()) { - return Double.parseDouble(m.group()); - } else { 
- return 0.0f; - } - } -} - -class BenchmarkMetric { - public static class BenchmarkModel { - // The model name, i.e. stories110M - String name; - String backend; - String quantization; - - public BenchmarkModel(final String name, final String backend, final String quantization) { - this.name = name; - this.backend = backend; - this.quantization = quantization; - } - } - - BenchmarkModel benchmarkModel; - - // The metric name, i.e. TPS - String metric; - - // The actual value and the option target value - double actualValue; - double targetValue; - - public static class DeviceInfo { - // Let's see which information we want to include here - final String device = Build.BRAND; - // The phone model and Android release version - final String arch = Build.MODEL; - final String os = "Android " + Build.VERSION.RELEASE; - final long totalMem = new ActivityManager.MemoryInfo().totalMem; - final long availMem = new ActivityManager.MemoryInfo().availMem; - } - - DeviceInfo deviceInfo = new DeviceInfo(); - - public BenchmarkMetric( - final BenchmarkModel benchmarkModel, - final String metric, - final double actualValue, - final double targetValue) { - this.benchmarkModel = benchmarkModel; - this.metric = metric; - this.actualValue = actualValue; - this.targetValue = targetValue; - } - - // TODO (huydhn): Figure out a way to extract the backend and quantization information from - // the .pte model itself instead of parsing its name - public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) { - final Matcher m = - Pattern.compile("(?\\w+)_(?[\\w\\+]+)_(?\\w+)").matcher(model); - if (m.matches()) { - return new BenchmarkMetric.BenchmarkModel( - m.group("name"), m.group("backend"), m.group("quantization")); - } else { - return new BenchmarkMetric.BenchmarkModel(model, "", ""); - } - } -} - -class StatsDump { - int loadStatus; - long loadStart; - long loadEnd; - long generateStart; - long generateEnd; - String tokens; - String modelName; - - @NonNull - @Override - public String toString() { - return "loadStart: " - + loadStart - + "\nloadEnd: " - + loadEnd - + "\ngenerateStart: " - + generateStart - + "\ngenerateEnd: " - + generateEnd - + "\n" - + tokens; - } -} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LogsActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LogsActivity.java deleted file mode 100644 index 7777b275e6e..00000000000 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LogsActivity.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -package com.example.executorchllamademo; - -import android.app.AlertDialog; -import android.content.DialogInterface; -import android.os.Build; -import android.os.Bundle; -import android.widget.ImageButton; -import android.widget.ListView; -import androidx.appcompat.app.AppCompatActivity; -import androidx.core.content.ContextCompat; -import androidx.core.graphics.Insets; -import androidx.core.view.ViewCompat; -import androidx.core.view.WindowInsetsCompat; - -public class LogsActivity extends AppCompatActivity { - - private LogsAdapter mLogsAdapter; - - @Override - protected void onCreate(Bundle savedInstanceState) { - super.onCreate(savedInstanceState); - setContentView(R.layout.activity_logs); - if (Build.VERSION.SDK_INT >= 21) { - getWindow().setStatusBarColor(ContextCompat.getColor(this, R.color.status_bar)); - getWindow().setNavigationBarColor(ContextCompat.getColor(this, R.color.nav_bar)); - } - ViewCompat.setOnApplyWindowInsetsListener( - requireViewById(R.id.main), - (v, insets) -> { - Insets systemBars = insets.getInsets(WindowInsetsCompat.Type.systemBars()); - v.setPadding(systemBars.left, systemBars.top, systemBars.right, systemBars.bottom); - return insets; - }); - - setupLogs(); - setupClearLogsButton(); - } - - @Override - public void onResume() { - super.onResume(); - mLogsAdapter.clear(); - mLogsAdapter.addAll(ETLogging.getInstance().getLogs()); - mLogsAdapter.notifyDataSetChanged(); - } - - private void setupLogs() { - ListView mLogsListView = requireViewById(R.id.logsListView); - mLogsAdapter = new LogsAdapter(this, R.layout.logs_message); - - mLogsListView.setAdapter(mLogsAdapter); - mLogsAdapter.addAll(ETLogging.getInstance().getLogs()); - mLogsAdapter.notifyDataSetChanged(); - } - - private void setupClearLogsButton() { - ImageButton clearLogsButton = requireViewById(R.id.clearLogsButton); - clearLogsButton.setOnClickListener( - view -> { - new AlertDialog.Builder(this) - .setTitle("Delete Logs History") - .setMessage("Do you really want to delete logs history?") - .setIcon(android.R.drawable.ic_dialog_alert) - .setPositiveButton( - android.R.string.yes, - new DialogInterface.OnClickListener() { - public void onClick(DialogInterface dialog, int whichButton) { - // Clear the messageAdapter and sharedPreference - ETLogging.getInstance().clearLogs(); - mLogsAdapter.clear(); - mLogsAdapter.notifyDataSetChanged(); - } - }) - .setNegativeButton(android.R.string.no, null) - .show(); - }); - } - - @Override - protected void onDestroy() { - super.onDestroy(); - ETLogging.getInstance().saveLogs(); - } -} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LogsAdapter.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LogsAdapter.java deleted file mode 100644 index 76c6a1aa1b4..00000000000 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LogsAdapter.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -package com.example.executorchllamademo; - -import android.view.LayoutInflater; -import android.view.View; -import android.view.ViewGroup; -import android.widget.ArrayAdapter; -import android.widget.TextView; -import androidx.annotation.NonNull; -import java.util.Objects; - -public class LogsAdapter extends ArrayAdapter { - public LogsAdapter(android.content.Context context, int resource) { - super(context, resource); - } - - static class ViewHolder { - private TextView logTextView; - } - - @NonNull - @Override - public View getView(int position, View convertView, @NonNull ViewGroup parent) { - ViewHolder mViewHolder = null; - - String logMessage = Objects.requireNonNull(getItem(position)).getFormattedLog(); - - if (convertView == null || convertView.getTag() == null) { - mViewHolder = new ViewHolder(); - convertView = LayoutInflater.from(getContext()).inflate(R.layout.logs_message, parent, false); - mViewHolder.logTextView = convertView.requireViewById(R.id.logsTextView); - } else { - mViewHolder = (ViewHolder) convertView.getTag(); - } - mViewHolder.logTextView.setText(logMessage); - return convertView; - } -} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java deleted file mode 100644 index f995c5bc65a..00000000000 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java +++ /dev/null @@ -1,847 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -package com.example.executorchllamademo; - -import android.Manifest; -import android.app.ActivityManager; -import android.app.AlertDialog; -import android.content.ContentResolver; -import android.content.ContentValues; -import android.content.Intent; -import android.content.pm.PackageManager; -import android.net.Uri; -import android.os.Build; -import android.os.Bundle; -import android.os.Handler; -import android.os.Looper; -import android.os.Process; -import android.provider.MediaStore; -import android.system.ErrnoException; -import android.system.Os; -import android.util.Log; -import android.view.View; -import android.view.inputmethod.InputMethodManager; -import android.widget.EditText; -import android.widget.ImageButton; -import android.widget.ImageView; -import android.widget.LinearLayout; -import android.widget.ListView; -import android.widget.TextView; -import android.widget.Toast; -import androidx.activity.result.ActivityResultLauncher; -import androidx.activity.result.PickVisualMediaRequest; -import androidx.activity.result.contract.ActivityResultContracts; -import androidx.annotation.NonNull; -import androidx.appcompat.app.AppCompatActivity; -import androidx.constraintlayout.widget.ConstraintLayout; -import androidx.core.app.ActivityCompat; -import androidx.core.content.ContextCompat; -import androidx.core.content.res.ResourcesCompat; -import com.google.gson.Gson; -import com.google.gson.reflect.TypeToken; -import java.lang.reflect.Type; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.Executor; -import java.util.concurrent.Executors; -import org.json.JSONException; -import org.json.JSONObject; -import org.pytorch.executorch.extension.llm.LlmCallback; -import org.pytorch.executorch.extension.llm.LlmModule; - -public class MainActivity extends AppCompatActivity implements Runnable, LlmCallback { - private EditText mEditTextMessage; - private ImageButton mThinkModeButton; - private ImageButton mSendButton; - private ImageButton mGalleryButton; - private ImageButton mCameraButton; - private ListView mMessagesView; - private MessageAdapter mMessageAdapter; - private LlmModule mModule = null; - private Message mResultMessage = null; - private ImageButton mSettingsButton; - private TextView mMemoryView; - private ActivityResultLauncher mPickGallery; - private ActivityResultLauncher mCameraRoll; - private List mSelectedImageUri; - private ConstraintLayout mMediaPreviewConstraintLayout; - private LinearLayout mAddMediaLayout; - private static final int MAX_NUM_OF_IMAGES = 5; - private static final int REQUEST_IMAGE_CAPTURE = 1; - private Uri cameraImageUri; - private DemoSharedPreferences mDemoSharedPreferences; - private SettingsFields mCurrentSettingsFields; - private Handler mMemoryUpdateHandler; - private Runnable memoryUpdater; - private boolean mThinkMode = false; - private int promptID = 0; - private static final int CONVERSATION_HISTORY_MESSAGE_LOOKBACK = 2; - private Executor executor; - - @Override - public void onResult(String result) { - if (result.equals(PromptFormat.getStopToken(mCurrentSettingsFields.getModelType()))) { - return; - } - result = PromptFormat.replaceSpecialToken(mCurrentSettingsFields.getModelType(), result); - if (result.equals("\n\n") || result.equals("\n")) { - if (!mResultMessage.getText().isEmpty()) { - mResultMessage.appendText(result); - run(); - } - } else { - mResultMessage.appendText(result); - run(); - } - } - - @Override - public void onStats(String stats) { - runOnUiThread( - () -> { - if (mResultMessage != 
null) { - float tps = 0; - try { - JSONObject jsonObject = new JSONObject(stats); - int numGeneratedTokens = jsonObject.getInt("generated_tokens"); - int inferenceEndMs = jsonObject.getInt("inference_end_ms"); - int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms"); - tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000; - } catch (JSONException e) { - Log.e("LLM", "Error parsing JSON: " + e.getMessage()); - } - mResultMessage.setTokensPerSecond(tps); - mMessageAdapter.notifyDataSetChanged(); - } - }); - } - - private void setLocalModel(String modelPath, String tokenizerPath, float temperature) { - Message modelLoadingMessage = new Message("Loading model...", false, MessageType.SYSTEM, 0); - ETLogging.getInstance().log("Loading model " + modelPath + " with tokenizer " + tokenizerPath); - runOnUiThread( - () -> { - mSendButton.setEnabled(false); - mMessageAdapter.add(modelLoadingMessage); - mMessageAdapter.notifyDataSetChanged(); - }); - if (mModule != null) { - ETLogging.getInstance().log("Start deallocating existing module instance"); - mModule.resetNative(); - mModule = null; - ETLogging.getInstance().log("Completed deallocating existing module instance"); - } - long runStartTime = System.currentTimeMillis(); - mModule = - new LlmModule( - ModelUtils.getModelCategory( - mCurrentSettingsFields.getModelType(), mCurrentSettingsFields.getBackendType()), - modelPath, - tokenizerPath, - temperature); - int loadResult = mModule.load(); - long loadDuration = System.currentTimeMillis() - runStartTime; - String modelLoadError = ""; - String modelInfo = ""; - if (loadResult != 0) { - // TODO: Map the error code to a reason to let the user know why model loading failed - modelInfo = "*Model could not load (Error Code: " + loadResult + ")*" + "\n"; - loadDuration = 0; - AlertDialog.Builder builder = new AlertDialog.Builder(this); - builder.setTitle("Load failed: " + loadResult); - runOnUiThread( - () -> { - AlertDialog alert = builder.create(); - alert.show(); - }); - } else { - String[] segments = modelPath.split("/"); - String pteName = segments[segments.length - 1]; - segments = tokenizerPath.split("/"); - String tokenizerName = segments[segments.length - 1]; - modelInfo = - "Successfully loaded model. " - + pteName - + " and tokenizer " - + tokenizerName - + " in " - + (float) loadDuration / 1000 - + " sec." - + " You can send text or image for inference"; - - if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5) { - ETLogging.getInstance().log("Llava start prefill prompt"); - mModule.resetContext(); - mModule.prefillPrompt(PromptFormat.getLlavaPresetPrompt()); - ETLogging.getInstance().log("Llava completes prefill prompt"); - } - } - - Message modelLoadedMessage = new Message(modelInfo, false, MessageType.SYSTEM, 0); - - String modelLoggingInfo = - modelLoadError - + "Model path: " - + modelPath - + "\nTokenizer path: " - + tokenizerPath - + "\nBackend: " - + mCurrentSettingsFields.getBackendType().toString() - + "\nModelType: " - + ModelUtils.getModelCategory( - mCurrentSettingsFields.getModelType(), mCurrentSettingsFields.getBackendType()) - + "\nTemperature: " - + temperature - + "\nModel loaded time: " - + loadDuration - + " ms"; - ETLogging.getInstance().log("Load complete. 
" + modelLoggingInfo); - - runOnUiThread( - () -> { - mSendButton.setEnabled(true); - mMessageAdapter.remove(modelLoadingMessage); - mMessageAdapter.add(modelLoadedMessage); - mMessageAdapter.notifyDataSetChanged(); - }); - } - - private void loadLocalModelAndParameters( - String modelFilePath, String tokenizerFilePath, float temperature) { - Runnable runnable = - new Runnable() { - @Override - public void run() { - setLocalModel(modelFilePath, tokenizerFilePath, temperature); - } - }; - new Thread(runnable).start(); - } - - private void populateExistingMessages(String existingMsgJSON) { - Gson gson = new Gson(); - Type type = new TypeToken>() {}.getType(); - ArrayList savedMessages = gson.fromJson(existingMsgJSON, type); - for (Message msg : savedMessages) { - mMessageAdapter.add(msg); - } - mMessageAdapter.notifyDataSetChanged(); - } - - private int setPromptID() { - - return mMessageAdapter.getMaxPromptID() + 1; - } - - @Override - protected void onCreate(Bundle savedInstanceState) { - super.onCreate(savedInstanceState); - setContentView(R.layout.activity_main); - - if (Build.VERSION.SDK_INT >= 21) { - getWindow().setStatusBarColor(ContextCompat.getColor(this, R.color.status_bar)); - getWindow().setNavigationBarColor(ContextCompat.getColor(this, R.color.nav_bar)); - } - - try { - Os.setenv("ADSP_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true); - Os.setenv("LD_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true); - } catch (ErrnoException e) { - finish(); - } - - mThinkModeButton = requireViewById(R.id.thinkModeButton); - mEditTextMessage = requireViewById(R.id.editTextMessage); - mSendButton = requireViewById(R.id.sendButton); - mSendButton.setEnabled(false); - mMessagesView = requireViewById(R.id.messages_view); - mMessageAdapter = new MessageAdapter(this, R.layout.sent_message, new ArrayList()); - mMessagesView.setAdapter(mMessageAdapter); - mDemoSharedPreferences = new DemoSharedPreferences(this.getApplicationContext()); - String existingMsgJSON = mDemoSharedPreferences.getSavedMessages(); - if (!existingMsgJSON.isEmpty()) { - populateExistingMessages(existingMsgJSON); - promptID = setPromptID(); - } - mSettingsButton = requireViewById(R.id.settings); - mSettingsButton.setOnClickListener( - view -> { - Intent myIntent = new Intent(MainActivity.this, SettingsActivity.class); - MainActivity.this.startActivity(myIntent); - }); - - mThinkModeButton.setOnClickListener( - view -> { - if (mThinkMode) { - mThinkMode = false; - mThinkModeButton.setImageDrawable( - ResourcesCompat.getDrawable( - getResources(), R.drawable.baseline_lightbulb_24, null)); - } else { - mThinkMode = true; - mThinkModeButton.setImageDrawable( - ResourcesCompat.getDrawable(getResources(), R.drawable.blue_lightbulb_24, null)); - } - runOnUiThread( - () -> { - String thinkingModeText = mThinkMode ? 
"on" : "off"; - mMessageAdapter.add( - new Message( - "Thinking mode is " + thinkingModeText, false, MessageType.SYSTEM, 0)); - mMessageAdapter.notifyDataSetChanged(); - }); - }); - - mCurrentSettingsFields = new SettingsFields(); - mMemoryUpdateHandler = new Handler(Looper.getMainLooper()); - onModelRunStopped(); - setupMediaButton(); - setupGalleryPicker(); - setupCameraRoll(); - startMemoryUpdate(); - setupShowLogsButton(); - executor = Executors.newSingleThreadExecutor(); - } - - @Override - protected void onPause() { - super.onPause(); - mDemoSharedPreferences.addMessages(mMessageAdapter); - } - - @Override - protected void onResume() { - super.onResume(); - // Check for if settings parameters have changed - Gson gson = new Gson(); - String settingsFieldsJSON = mDemoSharedPreferences.getSettings(); - if (!settingsFieldsJSON.isEmpty()) { - SettingsFields updatedSettingsFields = - gson.fromJson(settingsFieldsJSON, SettingsFields.class); - if (updatedSettingsFields == null) { - // Added this check, because gson.fromJson can return null - askUserToSelectModel(); - return; - } - boolean isUpdated = !mCurrentSettingsFields.equals(updatedSettingsFields); - boolean isLoadModel = updatedSettingsFields.getIsLoadModel(); - setBackendMode(updatedSettingsFields.getBackendType()); - if (isUpdated) { - if (isLoadModel) { - // If users change the model file, but not pressing loadModelButton, we won't load the new - // model - checkForUpdateAndReloadModel(updatedSettingsFields); - } else { - askUserToSelectModel(); - } - - checkForClearChatHistory(updatedSettingsFields); - // Update current to point to the latest - mCurrentSettingsFields = new SettingsFields(updatedSettingsFields); - } - } else { - askUserToSelectModel(); - } - } - - private void setBackendMode(BackendType backendType) { - if (backendType.equals(BackendType.XNNPACK) || backendType.equals(BackendType.QUALCOMM)) { - setXNNPACKMode(); - } else if (backendType.equals(BackendType.MEDIATEK)) { - setMediaTekMode(); - } - } - - private void setXNNPACKMode() { - requireViewById(R.id.addMediaButton).setVisibility(View.VISIBLE); - } - - private void setMediaTekMode() { - requireViewById(R.id.addMediaButton).setVisibility(View.GONE); - } - - private void checkForClearChatHistory(SettingsFields updatedSettingsFields) { - if (updatedSettingsFields.getIsClearChatHistory()) { - mMessageAdapter.clear(); - mMessageAdapter.notifyDataSetChanged(); - mDemoSharedPreferences.removeExistingMessages(); - // changing to false since chat history has been cleared. 
- updatedSettingsFields.saveIsClearChatHistory(false); - mDemoSharedPreferences.addSettings(updatedSettingsFields); - } - } - - private void checkForUpdateAndReloadModel(SettingsFields updatedSettingsFields) { - // TODO need to add 'load model' in settings and queue loading based on that - String modelPath = updatedSettingsFields.getModelFilePath(); - String tokenizerPath = updatedSettingsFields.getTokenizerFilePath(); - double temperature = updatedSettingsFields.getTemperature(); - if (!modelPath.isEmpty() && !tokenizerPath.isEmpty()) { - if (updatedSettingsFields.getIsLoadModel() - || !modelPath.equals(mCurrentSettingsFields.getModelFilePath()) - || !tokenizerPath.equals(mCurrentSettingsFields.getTokenizerFilePath()) - || temperature != mCurrentSettingsFields.getTemperature()) { - loadLocalModelAndParameters( - updatedSettingsFields.getModelFilePath(), - updatedSettingsFields.getTokenizerFilePath(), - (float) updatedSettingsFields.getTemperature()); - updatedSettingsFields.saveLoadModelAction(false); - mDemoSharedPreferences.addSettings(updatedSettingsFields); - } - } else { - askUserToSelectModel(); - } - } - - private void askUserToSelectModel() { - String askLoadModel = - "To get started, select your desired model and tokenizer " + "from the top right corner"; - Message askLoadModelMessage = new Message(askLoadModel, false, MessageType.SYSTEM, 0); - ETLogging.getInstance().log(askLoadModel); - runOnUiThread( - () -> { - mMessageAdapter.add(askLoadModelMessage); - mMessageAdapter.notifyDataSetChanged(); - }); - } - - private void setupShowLogsButton() { - ImageButton showLogsButton = requireViewById(R.id.showLogsButton); - showLogsButton.setOnClickListener( - view -> { - Intent myIntent = new Intent(MainActivity.this, LogsActivity.class); - MainActivity.this.startActivity(myIntent); - }); - } - - private void setupMediaButton() { - mAddMediaLayout = requireViewById(R.id.addMediaLayout); - mAddMediaLayout.setVisibility(View.GONE); // We hide this initially - - ImageButton addMediaButton = requireViewById(R.id.addMediaButton); - addMediaButton.setOnClickListener( - view -> { - mAddMediaLayout.setVisibility(View.VISIBLE); - }); - - mGalleryButton = requireViewById(R.id.galleryButton); - mGalleryButton.setOnClickListener( - view -> { - // Launch the photo picker and let the user choose only images. - mPickGallery.launch( - new PickVisualMediaRequest.Builder() - .setMediaType(ActivityResultContracts.PickVisualMedia.ImageOnly.INSTANCE) - .build()); - }); - mCameraButton = requireViewById(R.id.cameraButton); - mCameraButton.setOnClickListener( - view -> { - Log.d("CameraRoll", "Check permission"); - if (ContextCompat.checkSelfPermission(MainActivity.this, Manifest.permission.CAMERA) - != PackageManager.PERMISSION_GRANTED) { - ActivityCompat.requestPermissions( - MainActivity.this, - new String[] {Manifest.permission.CAMERA}, - REQUEST_IMAGE_CAPTURE); - } else { - launchCamera(); - } - }); - } - - private void setupCameraRoll() { - // Registers a camera roll activity launcher. 
- mCameraRoll = - registerForActivityResult( - new ActivityResultContracts.TakePicture(), - result -> { - if (result && cameraImageUri != null) { - Log.d("CameraRoll", "Photo saved to uri: " + cameraImageUri); - mAddMediaLayout.setVisibility(View.GONE); - List uris = new ArrayList<>(); - uris.add(cameraImageUri); - showMediaPreview(uris); - } else { - // Delete the temp image file based on the url since the photo is not successfully - // taken - if (cameraImageUri != null) { - ContentResolver contentResolver = MainActivity.this.getContentResolver(); - contentResolver.delete(cameraImageUri, null, null); - Log.d("CameraRoll", "No photo taken. Delete temp uri"); - } - } - }); - mMediaPreviewConstraintLayout = requireViewById(R.id.mediaPreviewConstraintLayout); - ImageButton mediaPreviewCloseButton = requireViewById(R.id.mediaPreviewCloseButton); - mediaPreviewCloseButton.setOnClickListener( - view -> { - mMediaPreviewConstraintLayout.setVisibility(View.GONE); - mSelectedImageUri = null; - }); - - ImageButton addMoreImageButton = requireViewById(R.id.addMoreImageButton); - addMoreImageButton.setOnClickListener( - view -> { - Log.d("addMore", "clicked"); - mMediaPreviewConstraintLayout.setVisibility(View.GONE); - // Direct user to select type of input - mCameraButton.callOnClick(); - }); - } - - private String updateMemoryUsage() { - ActivityManager.MemoryInfo memoryInfo = new ActivityManager.MemoryInfo(); - ActivityManager activityManager = (ActivityManager) getSystemService(ACTIVITY_SERVICE); - if (activityManager == null) { - return "---"; - } - activityManager.getMemoryInfo(memoryInfo); - long totalMem = memoryInfo.totalMem / (1024 * 1024); - long availableMem = memoryInfo.availMem / (1024 * 1024); - long usedMem = totalMem - availableMem; - return usedMem + "MB"; - } - - private void startMemoryUpdate() { - mMemoryView = requireViewById(R.id.ram_usage_live); - memoryUpdater = - new Runnable() { - @Override - public void run() { - mMemoryView.setText(updateMemoryUsage()); - mMemoryUpdateHandler.postDelayed(this, 1000); - } - }; - mMemoryUpdateHandler.post(memoryUpdater); - } - - @Override - public void onRequestPermissionsResult( - int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) { - super.onRequestPermissionsResult(requestCode, permissions, grantResults); - if (requestCode == REQUEST_IMAGE_CAPTURE && grantResults.length != 0) { - if (grantResults[0] == PackageManager.PERMISSION_GRANTED) { - launchCamera(); - } else if (grantResults[0] == PackageManager.PERMISSION_DENIED) { - Log.d("CameraRoll", "Permission denied"); - } - } - } - - private void launchCamera() { - ContentValues values = new ContentValues(); - values.put(MediaStore.Images.Media.TITLE, "New Picture"); - values.put(MediaStore.Images.Media.DESCRIPTION, "From Camera"); - values.put(MediaStore.Images.Media.RELATIVE_PATH, "DCIM/Camera/"); - cameraImageUri = - MainActivity.this - .getContentResolver() - .insert(MediaStore.Images.Media.EXTERNAL_CONTENT_URI, values); - mCameraRoll.launch(cameraImageUri); - } - - private void setupGalleryPicker() { - // Registers a photo picker activity launcher in single-select mode. 
- mPickGallery = - registerForActivityResult( - new ActivityResultContracts.PickMultipleVisualMedia(MAX_NUM_OF_IMAGES), - uris -> { - if (!uris.isEmpty()) { - Log.d("PhotoPicker", "Selected URIs: " + uris); - mAddMediaLayout.setVisibility(View.GONE); - for (Uri uri : uris) { - MainActivity.this - .getContentResolver() - .takePersistableUriPermission(uri, Intent.FLAG_GRANT_READ_URI_PERMISSION); - } - showMediaPreview(uris); - } else { - Log.d("PhotoPicker", "No media selected"); - } - }); - - mMediaPreviewConstraintLayout = requireViewById(R.id.mediaPreviewConstraintLayout); - ImageButton mediaPreviewCloseButton = requireViewById(R.id.mediaPreviewCloseButton); - mediaPreviewCloseButton.setOnClickListener( - view -> { - mMediaPreviewConstraintLayout.setVisibility(View.GONE); - mSelectedImageUri = null; - }); - - ImageButton addMoreImageButton = requireViewById(R.id.addMoreImageButton); - addMoreImageButton.setOnClickListener( - view -> { - Log.d("addMore", "clicked"); - mMediaPreviewConstraintLayout.setVisibility(View.GONE); - mGalleryButton.callOnClick(); - }); - } - - private List getProcessedImagesForModel(List uris) { - List imageList = new ArrayList<>(); - if (uris != null) { - uris.forEach( - (uri) -> { - imageList.add(new ETImage(this.getContentResolver(), uri)); - }); - } - return imageList; - } - - private void showMediaPreview(List uris) { - if (mSelectedImageUri == null) { - mSelectedImageUri = uris; - } else { - mSelectedImageUri.addAll(uris); - } - - if (mSelectedImageUri.size() > MAX_NUM_OF_IMAGES) { - mSelectedImageUri = mSelectedImageUri.subList(0, MAX_NUM_OF_IMAGES); - Toast.makeText( - this, "Only max " + MAX_NUM_OF_IMAGES + " images are allowed", Toast.LENGTH_SHORT) - .show(); - } - Log.d("mSelectedImageUri", mSelectedImageUri.size() + " " + mSelectedImageUri); - - mMediaPreviewConstraintLayout.setVisibility(View.VISIBLE); - - List imageViews = new ArrayList(); - - // Pre-populate all the image views that are available from the layout (currently max 5) - imageViews.add(requireViewById(R.id.mediaPreviewImageView1)); - imageViews.add(requireViewById(R.id.mediaPreviewImageView2)); - imageViews.add(requireViewById(R.id.mediaPreviewImageView3)); - imageViews.add(requireViewById(R.id.mediaPreviewImageView4)); - imageViews.add(requireViewById(R.id.mediaPreviewImageView5)); - - // Hide all the image views (reset state) - for (int i = 0; i < imageViews.size(); i++) { - imageViews.get(i).setVisibility(View.GONE); - } - - // Only show/render those that have proper Image URIs - for (int i = 0; i < mSelectedImageUri.size(); i++) { - imageViews.get(i).setVisibility(View.VISIBLE); - imageViews.get(i).setImageURI(mSelectedImageUri.get(i)); - } - - // For LLava, we want to call prefill_image as soon as an image is selected - // Llava only support 1 image for now - if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5) { - List processedImageList = getProcessedImagesForModel(mSelectedImageUri); - if (!processedImageList.isEmpty()) { - mMessageAdapter.add( - new Message("Llava - Starting image Prefill.", false, MessageType.SYSTEM, 0)); - mMessageAdapter.notifyDataSetChanged(); - Runnable runnable = - () -> { - Process.setThreadPriority(Process.THREAD_PRIORITY_MORE_FAVORABLE); - ETLogging.getInstance().log("Starting runnable prefill image"); - ETImage img = processedImageList.get(0); - ETLogging.getInstance().log("Llava start prefill image"); - mModule.prefillImages( - img.getInts(), - img.getWidth(), - img.getHeight(), - ModelUtils.VISION_MODEL_IMAGE_CHANNELS); - }; - 
executor.execute(runnable); - } - } - } - - private void addSelectedImagesToChatThread(List selectedImageUri) { - if (selectedImageUri == null) { - return; - } - mMediaPreviewConstraintLayout.setVisibility(View.GONE); - for (int i = 0; i < selectedImageUri.size(); i++) { - Uri imageURI = selectedImageUri.get(i); - Log.d("image uri ", "test " + imageURI.getPath()); - mMessageAdapter.add(new Message(imageURI.toString(), true, MessageType.IMAGE, 0)); - } - mMessageAdapter.notifyDataSetChanged(); - } - - private String getConversationHistory() { - String conversationHistory = ""; - - ArrayList conversations = - mMessageAdapter.getRecentSavedTextMessages(CONVERSATION_HISTORY_MESSAGE_LOOKBACK); - if (conversations.isEmpty()) { - return conversationHistory; - } - - int prevPromptID = conversations.get(0).getPromptID(); - String conversationFormat = - PromptFormat.getConversationFormat(mCurrentSettingsFields.getModelType()); - String format = conversationFormat; - for (int i = 0; i < conversations.size(); i++) { - Message conversation = conversations.get(i); - int currentPromptID = conversation.getPromptID(); - if (currentPromptID != prevPromptID) { - conversationHistory = conversationHistory + format; - format = conversationFormat; - prevPromptID = currentPromptID; - } - if (conversation.getIsSent()) { - format = - format - .replace(PromptFormat.USER_PLACEHOLDER, conversation.getText()) - .replace(PromptFormat.THINKING_MODE_PLACEHOLDER, ""); - } else { - format = format.replace(PromptFormat.ASSISTANT_PLACEHOLDER, conversation.getText()); - } - } - conversationHistory = conversationHistory + format; - - return conversationHistory; - } - - private String getTotalFormattedPrompt(String conversationHistory, String rawPrompt) { - if (conversationHistory.isEmpty()) { - return mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt, mThinkMode); - } - - return mCurrentSettingsFields.getFormattedSystemPrompt() - + conversationHistory - + mCurrentSettingsFields.getFormattedUserPrompt(rawPrompt, mThinkMode); - } - - private void onModelRunStarted() { - mSendButton.setClickable(false); - mSendButton.setImageResource(R.drawable.baseline_stop_24); - mSendButton.setOnClickListener( - view -> { - mModule.stop(); - }); - } - - private void onModelRunStopped() { - mSendButton.setClickable(true); - mSendButton.setImageResource(R.drawable.baseline_send_24); - mSendButton.setOnClickListener( - view -> { - try { - InputMethodManager imm = (InputMethodManager) getSystemService(INPUT_METHOD_SERVICE); - imm.hideSoftInputFromWindow(getCurrentFocus().getWindowToken(), 0); - } catch (Exception e) { - ETLogging.getInstance().log("Keyboard dismissal error: " + e.getMessage()); - } - addSelectedImagesToChatThread(mSelectedImageUri); - String finalPrompt; - String rawPrompt = mEditTextMessage.getText().toString(); - if (ModelUtils.getModelCategory( - mCurrentSettingsFields.getModelType(), mCurrentSettingsFields.getBackendType()) - == ModelUtils.VISION_MODEL) { - finalPrompt = - mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt, mThinkMode); - } else { - finalPrompt = getTotalFormattedPrompt(getConversationHistory(), rawPrompt); - } - // We store raw prompt into message adapter, because we don't want to show the extra - // tokens from system prompt - mMessageAdapter.add(new Message(rawPrompt, true, MessageType.TEXT, promptID)); - mMessageAdapter.notifyDataSetChanged(); - mEditTextMessage.setText(""); - mResultMessage = new Message("", false, MessageType.TEXT, promptID); - 
mMessageAdapter.add(mResultMessage); - // Scroll to bottom of the list - mMessagesView.smoothScrollToPosition(mMessageAdapter.getCount() - 1); - // After images are added to prompt and chat thread, we clear the imageURI list - // Note: This has to be done after imageURIs are no longer needed by LlmModule - mSelectedImageUri = null; - promptID++; - Runnable runnable = - new Runnable() { - @Override - public void run() { - Process.setThreadPriority(Process.THREAD_PRIORITY_MORE_FAVORABLE); - ETLogging.getInstance().log("starting runnable generate()"); - runOnUiThread( - new Runnable() { - @Override - public void run() { - onModelRunStarted(); - } - }); - long generateStartTime = System.currentTimeMillis(); - if (ModelUtils.getModelCategory( - mCurrentSettingsFields.getModelType(), - mCurrentSettingsFields.getBackendType()) - == ModelUtils.VISION_MODEL) { - mModule.generate( - finalPrompt, ModelUtils.VISION_MODEL_SEQ_LEN, MainActivity.this, false); - } else if (mCurrentSettingsFields.getModelType() == ModelType.LLAMA_GUARD_3) { - String llamaGuardPromptForClassification = - PromptFormat.getFormattedLlamaGuardPrompt(rawPrompt); - ETLogging.getInstance() - .log("Running inference.. prompt=" + llamaGuardPromptForClassification); - mModule.generate( - llamaGuardPromptForClassification, - llamaGuardPromptForClassification.length() + 64, - MainActivity.this, - false); - } else { - ETLogging.getInstance().log("Running inference.. prompt=" + finalPrompt); - mModule.generate( - finalPrompt, - (int) (finalPrompt.length() * 0.75) + 64, - MainActivity.this, - false); - } - - long generateDuration = System.currentTimeMillis() - generateStartTime; - mResultMessage.setTotalGenerationTime(generateDuration); - runOnUiThread( - new Runnable() { - @Override - public void run() { - onModelRunStopped(); - } - }); - ETLogging.getInstance().log("Inference completed"); - } - }; - executor.execute(runnable); - }); - mMessageAdapter.notifyDataSetChanged(); - } - - @Override - public void run() { - runOnUiThread( - new Runnable() { - @Override - public void run() { - mMessageAdapter.notifyDataSetChanged(); - } - }); - } - - @Override - public void onBackPressed() { - super.onBackPressed(); - if (mAddMediaLayout != null && mAddMediaLayout.getVisibility() == View.VISIBLE) { - mAddMediaLayout.setVisibility(View.GONE); - } else { - // Default behavior of back button - finish(); - } - } - - @Override - protected void onDestroy() { - super.onDestroy(); - mMemoryUpdateHandler.removeCallbacks(memoryUpdater); - // This is to cover the case where the app is shutdown when user is on MainActivity but - // never clicked on the logsActivity - ETLogging.getInstance().saveLogs(); - } -} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/Message.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/Message.java deleted file mode 100644 index b2e5380e2a5..00000000000 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/Message.java +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -package com.example.executorchllamademo; - -import java.text.SimpleDateFormat; -import java.util.Date; -import java.util.Locale; - -public class Message { - private String text; - private final boolean isSent; - private float tokensPerSecond; - private long totalGenerationTime; - private final long timestamp; - private final MessageType messageType; - private String imagePath; - private final int promptID; - - private static final String TIMESTAMP_FORMAT = "hh:mm a"; // example: 2:23 PM - - public Message(String text, boolean isSent, MessageType messageType, int promptID) { - this.isSent = isSent; - this.messageType = messageType; - this.promptID = promptID; - - if (messageType == MessageType.IMAGE) { - this.imagePath = text; - } else { - this.text = text; - } - - if (messageType != MessageType.SYSTEM) { - this.timestamp = System.currentTimeMillis(); - } else { - this.timestamp = (long) 0; - } - } - - public int getPromptID() { - return promptID; - } - - public MessageType getMessageType() { - return messageType; - } - - public String getImagePath() { - return imagePath; - } - - public String getText() { - return text; - } - - public void appendText(String text) { - this.text += text; - } - - public boolean getIsSent() { - return isSent; - } - - public void setTokensPerSecond(float tokensPerSecond) { - this.tokensPerSecond = tokensPerSecond; - } - - public void setTotalGenerationTime(long totalGenerationTime) { - this.totalGenerationTime = totalGenerationTime; - } - - public float getTokensPerSecond() { - return tokensPerSecond; - } - - public long getTotalGenerationTime() { - return totalGenerationTime; - } - - public long getTimestamp() { - return timestamp; - } - - public String getFormattedTimestamp() { - SimpleDateFormat formatter = new SimpleDateFormat(TIMESTAMP_FORMAT, Locale.getDefault()); - Date date = new Date(timestamp); - return formatter.format(date); - } -} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageAdapter.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageAdapter.java deleted file mode 100644 index 31aaa9a1d5f..00000000000 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageAdapter.java +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package com.example.executorchllamademo; - -import android.net.Uri; -import android.view.LayoutInflater; -import android.view.View; -import android.view.ViewGroup; -import android.widget.ArrayAdapter; -import android.widget.ImageView; -import android.widget.TextView; -import java.util.ArrayList; -import java.util.Collections; - -public class MessageAdapter extends ArrayAdapter { - - private final ArrayList savedMessages; - - public MessageAdapter( - android.content.Context context, int resource, ArrayList savedMessages) { - super(context, resource); - this.savedMessages = savedMessages; - } - - @Override - public View getView(int position, View convertView, ViewGroup parent) { - Message currentMessage = getItem(position); - int layoutIdForListItem; - - if (currentMessage.getMessageType() == MessageType.SYSTEM) { - layoutIdForListItem = R.layout.system_message; - } else { - layoutIdForListItem = - currentMessage.getIsSent() ? 
R.layout.sent_message : R.layout.received_message; - } - View listItemView = - LayoutInflater.from(getContext()).inflate(layoutIdForListItem, parent, false); - if (currentMessage.getMessageType() == MessageType.IMAGE) { - ImageView messageImageView = listItemView.requireViewById(R.id.message_image); - messageImageView.setImageURI(Uri.parse(currentMessage.getImagePath())); - TextView messageTextView = listItemView.requireViewById(R.id.message_text); - messageTextView.setVisibility(View.GONE); - } else { - TextView messageTextView = listItemView.requireViewById(R.id.message_text); - messageTextView.setText(currentMessage.getText()); - } - - String metrics = ""; - TextView tokensView; - if (currentMessage.getTokensPerSecond() > 0) { - metrics = String.format("%.2f", currentMessage.getTokensPerSecond()) + "t/s "; - } - - if (currentMessage.getTotalGenerationTime() > 0) { - metrics = metrics + (float) currentMessage.getTotalGenerationTime() / 1000 + "s "; - } - - if (currentMessage.getTokensPerSecond() > 0 || currentMessage.getTotalGenerationTime() > 0) { - tokensView = listItemView.requireViewById(R.id.generation_metrics); - tokensView.setText(metrics); - TextView separatorView = listItemView.requireViewById(R.id.bar); - separatorView.setVisibility(View.VISIBLE); - } - - if (currentMessage.getTimestamp() > 0) { - TextView timestampView = listItemView.requireViewById(R.id.timestamp); - timestampView.setText(currentMessage.getFormattedTimestamp()); - } - - return listItemView; - } - - @Override - public void add(Message msg) { - super.add(msg); - savedMessages.add(msg); - } - - @Override - public void clear() { - super.clear(); - savedMessages.clear(); - } - - public ArrayList getSavedMessages() { - return savedMessages; - } - - public ArrayList getRecentSavedTextMessages(int numOfLatestPromptMessages) { - ArrayList recentMessages = new ArrayList(); - int lastIndex = savedMessages.size() - 1; - // In most cases lastIndex >=0 . - // A situation where the user clears chat history and enters prompt. Causes lastIndex=-1 . - if (lastIndex >= 0) { - Message messageToAdd = savedMessages.get(lastIndex); - int oldPromptID = messageToAdd.getPromptID(); - - for (int i = 0; i < savedMessages.size(); i++) { - messageToAdd = savedMessages.get(lastIndex - i); - if (messageToAdd.getMessageType() != MessageType.SYSTEM) { - if (messageToAdd.getPromptID() != oldPromptID) { - numOfLatestPromptMessages--; - oldPromptID = messageToAdd.getPromptID(); - } - if (numOfLatestPromptMessages > 0) { - if (messageToAdd.getMessageType() == MessageType.TEXT) { - recentMessages.add(messageToAdd); - } - } else { - break; - } - } - } - // To place the order in [input1, output1, input2, output2...] - Collections.reverse(recentMessages); - } - - return recentMessages; - } - - public int getMaxPromptID() { - int maxPromptID = -1; - for (Message msg : savedMessages) { - - maxPromptID = Math.max(msg.getPromptID(), maxPromptID); - } - return maxPromptID; - } -} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageType.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageType.java deleted file mode 100644 index 6042acb5726..00000000000 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageType.java +++ /dev/null @@ -1,15 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. 
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-package com.example.executorchllamademo;
-
-public enum MessageType {
-  TEXT,
-  IMAGE,
-  SYSTEM
-}
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunner.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunner.java
deleted file mode 100644
index a1bc205c4ac..00000000000
--- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunner.java
+++ /dev/null
@@ -1,109 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-package com.example.executorchllamademo;
-
-import android.os.Handler;
-import android.os.HandlerThread;
-import android.os.Looper;
-import android.os.Message;
-import androidx.annotation.NonNull;
-import org.json.JSONException;
-import org.json.JSONObject;
-import org.pytorch.executorch.extension.llm.LlmCallback;
-import org.pytorch.executorch.extension.llm.LlmModule;
-
-/** A helper class to handle all model running logic within this class. */
-public class ModelRunner implements LlmCallback {
-  LlmModule mModule = null;
-
-  String mModelFilePath = "";
-  String mTokenizerFilePath = "";
-
-  ModelRunnerCallback mCallback = null;
-
-  HandlerThread mHandlerThread = null;
-  Handler mHandler = null;
-
-  /**
-   * ] Helper class to separate between UI logic and model runner logic. Automatically handle
-   * generate() request on worker thread.
-   *
-   * @param modelFilePath
-   * @param tokenizerFilePath
-   * @param callback
-   */
-  ModelRunner(
-      String modelFilePath,
-      String tokenizerFilePath,
-      float temperature,
-      ModelRunnerCallback callback) {
-    mModelFilePath = modelFilePath;
-    mTokenizerFilePath = tokenizerFilePath;
-    mCallback = callback;
-
-    mModule = new LlmModule(mModelFilePath, mTokenizerFilePath, 0.8f);
-    mHandlerThread = new HandlerThread("ModelRunner");
-    mHandlerThread.start();
-    mHandler = new ModelRunnerHandler(mHandlerThread.getLooper(), this);
-
-    mHandler.sendEmptyMessage(ModelRunnerHandler.MESSAGE_LOAD_MODEL);
-  }
-
-  int generate(String prompt) {
-    Message msg = Message.obtain(mHandler, ModelRunnerHandler.MESSAGE_GENERATE, prompt);
-    msg.sendToTarget();
-    return 0;
-  }
-
-  void stop() {
-    mModule.stop();
-  }
-
-  @Override
-  public void onResult(String result) {
-    mCallback.onTokenGenerated(result);
-  }
-
-  @Override
-  public void onStats(String stats) {
-    float tps = 0;
-    try {
-      JSONObject jsonObject = new JSONObject(stats);
-      int numGeneratedTokens = jsonObject.getInt("generated_tokens");
-      int inferenceEndMs = jsonObject.getInt("inference_end_ms");
-      int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms");
-      tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000;
-    } catch (JSONException e) {
-    }
-    mCallback.onStats("tokens/second: " + tps);
-  }
-}
-
-class ModelRunnerHandler extends Handler {
-  public static int MESSAGE_LOAD_MODEL = 1;
-  public static int MESSAGE_GENERATE = 2;
-
-  private final ModelRunner mModelRunner;
-
-  public ModelRunnerHandler(Looper looper, ModelRunner modelRunner) {
-    super(looper);
-    mModelRunner = modelRunner;
-  }
-
-  @Override
-  public void handleMessage(@NonNull android.os.Message msg) {
-    if (msg.what == MESSAGE_LOAD_MODEL) {
-      int status = mModelRunner.mModule.load();
-      mModelRunner.mCallback.onModelLoaded(status);
-    } else if (msg.what == MESSAGE_GENERATE) {
-      mModelRunner.mModule.generate((String) msg.obj, mModelRunner);
-      mModelRunner.mCallback.onGenerationStopped();
-    }
-  }
-}
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunnerCallback.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunnerCallback.java
deleted file mode 100644
index 5e8b6f00e3d..00000000000
--- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunnerCallback.java
+++ /dev/null
@@ -1,24 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-package com.example.executorchllamademo;
-
-/**
- * A helper interface within the app for MainActivity and Benchmarking to handle callback from
- * ModelRunner.
- */
-public interface ModelRunnerCallback {
-
-  void onModelLoaded(int status);
-
-  void onTokenGenerated(String token);
-
-  void onStats(String stats);
-
-  void onGenerationStopped();
-}
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelType.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelType.java
deleted file mode 100644
index 9f8132504ea..00000000000
--- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelType.java
+++ /dev/null
@@ -1,18 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-package com.example.executorchllamademo;
-
-public enum ModelType {
-  LLAMA_3,
-  LLAMA_3_1,
-  LLAMA_3_2,
-  LLAVA_1_5,
-  LLAMA_GUARD_3,
-  QWEN_3,
-}
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java
deleted file mode 100644
index cf7ab1756ce..00000000000
--- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */ - -package com.example.executorchllamademo; - -public class ModelUtils { - // XNNPACK or QNN - static final int TEXT_MODEL = 1; - - // XNNPACK - static final int VISION_MODEL = 2; - static final int VISION_MODEL_IMAGE_CHANNELS = 3; - static final int VISION_MODEL_SEQ_LEN = 768; - static final int TEXT_MODEL_SEQ_LEN = 256; - - // MediaTek - static final int MEDIATEK_TEXT_MODEL = 3; - - // QNN static llama - static final int QNN_TEXT_MODEL = 4; - - public static int getModelCategory(ModelType modelType, BackendType backendType) { - if (backendType.equals(BackendType.XNNPACK)) { - switch (modelType) { - case LLAVA_1_5: - return VISION_MODEL; - case LLAMA_3: - case LLAMA_3_1: - case LLAMA_3_2: - case QWEN_3: - default: - return TEXT_MODEL; - } - } else if (backendType.equals(BackendType.MEDIATEK)) { - return MEDIATEK_TEXT_MODEL; - } else if (backendType.equals(BackendType.QUALCOMM)) { - return QNN_TEXT_MODEL; - } - - return TEXT_MODEL; // default - } -} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java deleted file mode 100644 index 524ad7cbf6d..00000000000 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package com.example.executorchllamademo; - -public class PromptFormat { - - public static final String SYSTEM_PLACEHOLDER = "{{ system_prompt }}"; - public static final String USER_PLACEHOLDER = "{{ user_prompt }}"; - public static final String ASSISTANT_PLACEHOLDER = "{{ assistant_response }}"; - public static final String THINKING_MODE_PLACEHOLDER = "{{ thinking_mode }}"; - public static final String DEFAULT_SYSTEM_PROMPT = "Answer the questions in a few sentences"; - - public static String getSystemPromptTemplate(ModelType modelType) { - switch (modelType) { - case LLAMA_3: - case LLAMA_3_1: - case LLAMA_3_2: - return "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n" - + SYSTEM_PLACEHOLDER - + "<|eot_id|>"; - case LLAVA_1_5: - return "USER: "; - case QWEN_3: - return "<|im_start|>system\n" + "You are a helpful assistant.\n" + "<|im_end|>\n"; - default: - return SYSTEM_PLACEHOLDER; - } - } - - public static String getUserPromptTemplate(ModelType modelType, boolean thinkingMode) { - switch (modelType) { - case LLAMA_3: - case LLAMA_3_1: - case LLAMA_3_2: - case LLAMA_GUARD_3: - return "<|start_header_id|>user<|end_header_id|>\n" - + USER_PLACEHOLDER - + "<|eot_id|>" - + "<|start_header_id|>assistant<|end_header_id|>"; - - case QWEN_3: - return "<|im_start|>user\n" - + USER_PLACEHOLDER - + "\n<|im_end|>\n" - + "<|im_start|>assistant\n" - + THINKING_MODE_PLACEHOLDER; - case LLAVA_1_5: - default: - return USER_PLACEHOLDER; - } - } - - public static String getConversationFormat(ModelType modelType) { - switch (modelType) { - case LLAMA_3: - case LLAMA_3_1: - case LLAMA_3_2: - return getUserPromptTemplate(modelType, false) - + "\n" - + ASSISTANT_PLACEHOLDER - + "<|eot_id|>"; - case LLAVA_1_5: - return USER_PLACEHOLDER + " ASSISTANT:"; - case QWEN_3: - return getUserPromptTemplate(modelType, false) + "<|im_end|>\n"; - default: - return USER_PLACEHOLDER; - } - } - - public static String 
getStopToken(ModelType modelType) { - switch (modelType) { - case LLAMA_3: - case LLAMA_3_1: - case LLAMA_3_2: - case LLAMA_GUARD_3: - return "<|eot_id|>"; - case LLAVA_1_5: - return ""; - case QWEN_3: - return "<|endoftext|>"; - default: - return ""; - } - } - - public static String getThinkingModeToken(ModelType modelType, boolean thinkingMode) { - switch (modelType) { - case QWEN_3: - return thinkingMode ? "" : "\n\n\n\n\n"; - default: - return ""; - } - } - - public static String replaceSpecialToken(ModelType modelType, String token) { - switch (modelType) { - case QWEN_3: - switch (token) { - case "<|im_end|>": - return ""; - case "": - return "Thinking...\n"; - case "": - return "\nDone thinking"; - } - default: - return token; - } - } - - public static String getLlavaPresetPrompt() { - return "A chat between a curious human and an artificial intelligence assistant. The assistant" - + " gives helpful, detailed, and polite answers to the human's questions. USER: "; - } - - public static String getFormattedLlamaGuardPrompt(String userPrompt) { - return getUserPromptTemplate(ModelType.LLAMA_GUARD_3, false) - .replace( - USER_PLACEHOLDER, getLlamaGuardPresetPrompt().replace(USER_PLACEHOLDER, userPrompt)); - } - - public static String getLlamaGuardPresetPrompt() { - return "Task: Check if there is unsafe content in 'User' messages in conversations" - + " according our safety policy with the below categories.\n" - + "\n" - + "\n" - + "S1: Violent Crimes. \n" - + "S2: Non-Violent Crimes. \n" - + "S3: Sex Crimes. \n" - + "S4: Child Exploitation. \n" - + "S5: Defamation. \n" - + "S6: Specialized Advice. \n" - + "S7: Privacy. \n" - + "S8: Intellectual Property. \n" - + "S9: Indiscriminate Weapons. \n" - + "S10: Hate. \n" - + "S11: Self-Harm. \n" - + "S12: Sexual Content. \n" - + "S13: Elections. \n" - + "S14: Code Interpreter Abuse.\n" - + "\n" - + "\n" - + "\n" - + "\n" - + "User: " - + USER_PLACEHOLDER - + "\n" - + "\n" - + "Provide your safety assessment for ONLY THE LAST User message in the above" - + " conversation:\n" - + " - First line must read 'safe' or 'unsafe'.\n" - + " - If unsafe, a second line must include a comma-separated list of violated" - + " categories."; - } -} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java deleted file mode 100644 index 0e388a5b0a4..00000000000 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java +++ /dev/null @@ -1,463 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -package com.example.executorchllamademo; - -import android.app.AlertDialog; -import android.content.DialogInterface; -import android.os.Build; -import android.os.Bundle; -import android.text.Editable; -import android.text.TextWatcher; -import android.view.View; -import android.widget.Button; -import android.widget.EditText; -import android.widget.ImageButton; -import android.widget.TextView; -import androidx.appcompat.app.AppCompatActivity; -import androidx.core.content.ContextCompat; -import androidx.core.graphics.Insets; -import androidx.core.view.ViewCompat; -import androidx.core.view.WindowInsetsCompat; -import com.google.gson.Gson; -import java.io.File; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -public class SettingsActivity extends AppCompatActivity { - - private String mModelFilePath = ""; - private String mTokenizerFilePath = ""; - private TextView mBackendTextView; - private TextView mModelTextView; - private TextView mTokenizerTextView; - private TextView mModelTypeTextView; - private EditText mSystemPromptEditText; - private EditText mUserPromptEditText; - private Button mLoadModelButton; - private double mSetTemperature; - private String mSystemPrompt; - private String mUserPrompt; - private BackendType mBackendType; - private ModelType mModelType; - public SettingsFields mSettingsFields; - - private DemoSharedPreferences mDemoSharedPreferences; - public static double TEMPERATURE_MIN_VALUE = 0.0; - - @Override - protected void onCreate(Bundle savedInstanceState) { - super.onCreate(savedInstanceState); - setContentView(R.layout.activity_settings); - if (Build.VERSION.SDK_INT >= 21) { - getWindow().setStatusBarColor(ContextCompat.getColor(this, R.color.status_bar)); - getWindow().setNavigationBarColor(ContextCompat.getColor(this, R.color.nav_bar)); - } - ViewCompat.setOnApplyWindowInsetsListener( - requireViewById(R.id.main), - (v, insets) -> { - Insets systemBars = insets.getInsets(WindowInsetsCompat.Type.systemBars()); - v.setPadding(systemBars.left, systemBars.top, systemBars.right, systemBars.bottom); - return insets; - }); - mDemoSharedPreferences = new DemoSharedPreferences(getBaseContext()); - mSettingsFields = new SettingsFields(); - setupSettings(); - } - - private void setupSettings() { - mBackendTextView = requireViewById(R.id.backendTextView); - mModelTextView = requireViewById(R.id.modelTextView); - mTokenizerTextView = requireViewById(R.id.tokenizerTextView); - mModelTypeTextView = requireViewById(R.id.modelTypeTextView); - ImageButton backendImageButton = requireViewById(R.id.backendImageButton); - ImageButton modelImageButton = requireViewById(R.id.modelImageButton); - ImageButton tokenizerImageButton = requireViewById(R.id.tokenizerImageButton); - ImageButton modelTypeImageButton = requireViewById(R.id.modelTypeImageButton); - mSystemPromptEditText = requireViewById(R.id.systemPromptText); - mUserPromptEditText = requireViewById(R.id.userPromptText); - loadSettings(); - - // TODO: The two setOnClickListeners will be removed after file path issue is resolved - backendImageButton.setOnClickListener( - view -> { - setupBackendSelectorDialog(); - }); - modelImageButton.setOnClickListener( - view -> { - setupModelSelectorDialog(); - }); - tokenizerImageButton.setOnClickListener( - view -> { - setupTokenizerSelectorDialog(); - }); - modelTypeImageButton.setOnClickListener( - view -> { - setupModelTypeSelectorDialog(); - }); - mModelFilePath = mSettingsFields.getModelFilePath(); - if (!mModelFilePath.isEmpty()) { - 
mModelTextView.setText(getFilenameFromPath(mModelFilePath)); - } - mTokenizerFilePath = mSettingsFields.getTokenizerFilePath(); - if (!mTokenizerFilePath.isEmpty()) { - mTokenizerTextView.setText(getFilenameFromPath(mTokenizerFilePath)); - } - mModelType = mSettingsFields.getModelType(); - ETLogging.getInstance().log("mModelType from settings " + mModelType); - if (mModelType != null) { - mModelTypeTextView.setText(mModelType.toString()); - } - mBackendType = mSettingsFields.getBackendType(); - ETLogging.getInstance().log("mBackendType from settings " + mBackendType); - if (mBackendType != null) { - mBackendTextView.setText(mBackendType.toString()); - setBackendSettingMode(); - } - - setupParameterSettings(); - setupPromptSettings(); - setupClearChatHistoryButton(); - setupLoadModelButton(); - } - - private void setupLoadModelButton() { - mLoadModelButton = requireViewById(R.id.loadModelButton); - mLoadModelButton.setEnabled(true); - mLoadModelButton.setOnClickListener( - view -> { - new AlertDialog.Builder(this) - .setTitle("Load Model") - .setMessage("Do you really want to load the new model?") - .setIcon(android.R.drawable.ic_dialog_alert) - .setPositiveButton( - android.R.string.yes, - new DialogInterface.OnClickListener() { - public void onClick(DialogInterface dialog, int whichButton) { - mSettingsFields.saveLoadModelAction(true); - mLoadModelButton.setEnabled(false); - onBackPressed(); - } - }) - .setNegativeButton(android.R.string.no, null) - .show(); - }); - } - - private void setupClearChatHistoryButton() { - Button clearChatButton = requireViewById(R.id.clearChatButton); - clearChatButton.setOnClickListener( - view -> { - new AlertDialog.Builder(this) - .setTitle("Delete Chat History") - .setMessage("Do you really want to delete chat history?") - .setIcon(android.R.drawable.ic_dialog_alert) - .setPositiveButton( - android.R.string.yes, - new DialogInterface.OnClickListener() { - public void onClick(DialogInterface dialog, int whichButton) { - mSettingsFields.saveIsClearChatHistory(true); - } - }) - .setNegativeButton(android.R.string.no, null) - .show(); - }); - } - - private void setupParameterSettings() { - setupTemperatureSettings(); - } - - private void setupTemperatureSettings() { - mSetTemperature = mSettingsFields.getTemperature(); - EditText temperatureEditText = requireViewById(R.id.temperatureEditText); - temperatureEditText.setText(String.valueOf(mSetTemperature)); - temperatureEditText.addTextChangedListener( - new TextWatcher() { - @Override - public void beforeTextChanged(CharSequence s, int start, int count, int after) {} - - @Override - public void onTextChanged(CharSequence s, int start, int before, int count) {} - - @Override - public void afterTextChanged(Editable s) { - mSetTemperature = Double.parseDouble(s.toString()); - // This is needed because temperature is changed together with model loading - // Once temperature is no longer in LlmModule constructor, we can remove this - mSettingsFields.saveLoadModelAction(true); - saveSettings(); - } - }); - } - - private void setupPromptSettings() { - setupSystemPromptSettings(); - setupUserPromptSettings(); - } - - private void setupSystemPromptSettings() { - mSystemPrompt = mSettingsFields.getSystemPrompt(); - mSystemPromptEditText.setText(mSystemPrompt); - mSystemPromptEditText.addTextChangedListener( - new TextWatcher() { - @Override - public void beforeTextChanged(CharSequence s, int start, int count, int after) {} - - @Override - public void onTextChanged(CharSequence s, int start, int before, int count) {} - 
- @Override - public void afterTextChanged(Editable s) { - mSystemPrompt = s.toString(); - } - }); - - ImageButton resetSystemPrompt = requireViewById(R.id.resetSystemPrompt); - resetSystemPrompt.setOnClickListener( - view -> { - new AlertDialog.Builder(this) - .setTitle("Reset System Prompt") - .setMessage("Do you really want to reset system prompt?") - .setIcon(android.R.drawable.ic_dialog_alert) - .setPositiveButton( - android.R.string.yes, - new DialogInterface.OnClickListener() { - public void onClick(DialogInterface dialog, int whichButton) { - // Clear the messageAdapter and sharedPreference - mSystemPromptEditText.setText(PromptFormat.DEFAULT_SYSTEM_PROMPT); - } - }) - .setNegativeButton(android.R.string.no, null) - .show(); - }); - } - - private void setupUserPromptSettings() { - mUserPrompt = mSettingsFields.getUserPrompt(); - mUserPromptEditText.setText(mUserPrompt); - mUserPromptEditText.addTextChangedListener( - new TextWatcher() { - @Override - public void beforeTextChanged(CharSequence s, int start, int count, int after) {} - - @Override - public void onTextChanged(CharSequence s, int start, int before, int count) {} - - @Override - public void afterTextChanged(Editable s) { - if (isValidUserPrompt(s.toString())) { - mUserPrompt = s.toString(); - } else { - showInvalidPromptDialog(); - } - } - }); - - ImageButton resetUserPrompt = requireViewById(R.id.resetUserPrompt); - resetUserPrompt.setOnClickListener( - view -> { - new AlertDialog.Builder(this) - .setTitle("Reset Prompt Template") - .setMessage("Do you really want to reset the prompt template?") - .setIcon(android.R.drawable.ic_dialog_alert) - .setPositiveButton( - android.R.string.yes, - new DialogInterface.OnClickListener() { - public void onClick(DialogInterface dialog, int whichButton) { - // Clear the messageAdapter and sharedPreference - mUserPromptEditText.setText( - PromptFormat.getUserPromptTemplate(mModelType, false)); - } - }) - .setNegativeButton(android.R.string.no, null) - .show(); - }); - } - - private boolean isValidUserPrompt(String userPrompt) { - return userPrompt.contains(PromptFormat.USER_PLACEHOLDER); - } - - private void showInvalidPromptDialog() { - new AlertDialog.Builder(this) - .setTitle("Invalid Prompt Format") - .setMessage( - "Prompt format must contain " - + PromptFormat.USER_PLACEHOLDER - + ". 
Do you want to reset prompt format?") - .setIcon(android.R.drawable.ic_dialog_alert) - .setPositiveButton( - android.R.string.yes, - (dialog, whichButton) -> { - mUserPromptEditText.setText(PromptFormat.getUserPromptTemplate(mModelType, false)); - }) - .setNegativeButton(android.R.string.no, null) - .show(); - } - - private void setupBackendSelectorDialog() { - // Convert enum to list - List backendTypesList = new ArrayList<>(); - for (BackendType backendType : BackendType.values()) { - backendTypesList.add(backendType.toString()); - } - // Alert dialog builder takes in arr of string instead of list - String[] backendTypes = backendTypesList.toArray(new String[0]); - AlertDialog.Builder backendTypeBuilder = new AlertDialog.Builder(this); - backendTypeBuilder.setTitle("Select backend type"); - backendTypeBuilder.setSingleChoiceItems( - backendTypes, - -1, - (dialog, item) -> { - mBackendTextView.setText(backendTypes[item]); - mBackendType = BackendType.valueOf(backendTypes[item]); - setBackendSettingMode(); - dialog.dismiss(); - }); - - backendTypeBuilder.create().show(); - } - - private void setupModelSelectorDialog() { - String[] pteFiles = listLocalFile("/data/local/tmp/llama/", new String[] {".pte"}); - AlertDialog.Builder modelPathBuilder = new AlertDialog.Builder(this); - modelPathBuilder.setTitle("Select model path"); - - modelPathBuilder.setSingleChoiceItems( - pteFiles, - -1, - (dialog, item) -> { - mModelFilePath = pteFiles[item]; - mModelTextView.setText(getFilenameFromPath(mModelFilePath)); - mLoadModelButton.setEnabled(true); - dialog.dismiss(); - }); - - modelPathBuilder.create().show(); - } - - private static boolean fileHasExtension(String file, String[] suffix) { - return Arrays.stream(suffix).anyMatch(entry -> file.endsWith(entry)); - } - - private static String[] listLocalFile(String path, String[] suffix) { - File directory = new File(path); - if (directory.exists() && directory.isDirectory()) { - File[] files = directory.listFiles((dir, name) -> (fileHasExtension(name, suffix))); - String[] result = new String[files.length]; - for (int i = 0; i < files.length; i++) { - if (files[i].isFile() && fileHasExtension(files[i].getName(), suffix)) { - result[i] = files[i].getAbsolutePath(); - } - } - return result; - } - return new String[] {}; - } - - private void setupModelTypeSelectorDialog() { - // Convert enum to list - List modelTypesList = new ArrayList<>(); - for (ModelType modelType : ModelType.values()) { - modelTypesList.add(modelType.toString()); - } - // Alert dialog builder takes in arr of string instead of list - String[] modelTypes = modelTypesList.toArray(new String[0]); - AlertDialog.Builder modelTypeBuilder = new AlertDialog.Builder(this); - modelTypeBuilder.setTitle("Select model type"); - modelTypeBuilder.setSingleChoiceItems( - modelTypes, - -1, - (dialog, item) -> { - mModelTypeTextView.setText(modelTypes[item]); - mModelType = ModelType.valueOf(modelTypes[item]); - mUserPromptEditText.setText(PromptFormat.getUserPromptTemplate(mModelType, false)); - dialog.dismiss(); - }); - - modelTypeBuilder.create().show(); - } - - private void setupTokenizerSelectorDialog() { - String[] tokenizerFiles = - listLocalFile("/data/local/tmp/llama/", new String[] {".bin", ".json", ".model"}); - AlertDialog.Builder tokenizerPathBuilder = new AlertDialog.Builder(this); - tokenizerPathBuilder.setTitle("Select tokenizer path"); - tokenizerPathBuilder.setSingleChoiceItems( - tokenizerFiles, - -1, - (dialog, item) -> { - mTokenizerFilePath = tokenizerFiles[item]; - 
mTokenizerTextView.setText(getFilenameFromPath(mTokenizerFilePath)); - mLoadModelButton.setEnabled(true); - dialog.dismiss(); - }); - - tokenizerPathBuilder.create().show(); - } - - private String getFilenameFromPath(String uriFilePath) { - String[] segments = uriFilePath.split("/"); - if (segments.length > 0) { - return segments[segments.length - 1]; // get last element (aka filename) - } - return ""; - } - - private void setBackendSettingMode() { - if (mBackendType.equals(BackendType.XNNPACK) || mBackendType.equals(BackendType.QUALCOMM)) { - setXNNPACKSettingMode(); - } else if (mBackendType.equals(BackendType.MEDIATEK)) { - setMediaTekSettingMode(); - } - } - - private void setXNNPACKSettingMode() { - requireViewById(R.id.modelLayout).setVisibility(View.VISIBLE); - requireViewById(R.id.tokenizerLayout).setVisibility(View.VISIBLE); - requireViewById(R.id.parametersView).setVisibility(View.VISIBLE); - requireViewById(R.id.temperatureLayout).setVisibility(View.VISIBLE); - mModelFilePath = ""; - mTokenizerFilePath = ""; - } - - private void setMediaTekSettingMode() { - requireViewById(R.id.modelLayout).setVisibility(View.GONE); - requireViewById(R.id.tokenizerLayout).setVisibility(View.GONE); - requireViewById(R.id.parametersView).setVisibility(View.GONE); - requireViewById(R.id.temperatureLayout).setVisibility(View.GONE); - mModelFilePath = "/in/mtk/llama/runner"; - mTokenizerFilePath = "/in/mtk/llama/runner"; - } - - private void loadSettings() { - Gson gson = new Gson(); - String settingsFieldsJSON = mDemoSharedPreferences.getSettings(); - if (!settingsFieldsJSON.isEmpty()) { - mSettingsFields = gson.fromJson(settingsFieldsJSON, SettingsFields.class); - } - } - - private void saveSettings() { - mSettingsFields.saveModelPath(mModelFilePath); - mSettingsFields.saveTokenizerPath(mTokenizerFilePath); - mSettingsFields.saveParameters(mSetTemperature); - mSettingsFields.savePrompts(mSystemPrompt, mUserPrompt); - mSettingsFields.saveModelType(mModelType); - mSettingsFields.saveBackendType(mBackendType); - mDemoSharedPreferences.addSettings(mSettingsFields); - } - - @Override - public void onBackPressed() { - super.onBackPressed(); - saveSettings(); - } -} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsFields.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsFields.java deleted file mode 100644 index 94036f43947..00000000000 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsFields.java +++ /dev/null @@ -1,148 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */
-
-package com.example.executorchllamademo;
-
-public class SettingsFields {
-
-  public String getModelFilePath() {
-    return modelFilePath;
-  }
-
-  public String getTokenizerFilePath() {
-    return tokenizerFilePath;
-  }
-
-  public double getTemperature() {
-    return temperature;
-  }
-
-  public String getSystemPrompt() {
-    return systemPrompt;
-  }
-
-  public ModelType getModelType() {
-    return modelType;
-  }
-
-  public BackendType getBackendType() {
-    return backendType;
-  }
-
-  public String getUserPrompt() {
-    return userPrompt;
-  }
-
-  public String getFormattedSystemAndUserPrompt(String prompt, boolean thinkingMode) {
-    return getFormattedSystemPrompt() + getFormattedUserPrompt(prompt, thinkingMode);
-  }
-
-  public String getFormattedSystemPrompt() {
-    return PromptFormat.getSystemPromptTemplate(modelType)
-        .replace(PromptFormat.SYSTEM_PLACEHOLDER, systemPrompt);
-  }
-
-  public String getFormattedUserPrompt(String prompt, boolean thinkingMode) {
-    return userPrompt
-        .replace(PromptFormat.USER_PLACEHOLDER, prompt)
-        .replace(
-            PromptFormat.THINKING_MODE_PLACEHOLDER,
-            PromptFormat.getThinkingModeToken(modelType, thinkingMode));
-  }
-
-  public boolean getIsClearChatHistory() {
-    return isClearChatHistory;
-  }
-
-  public boolean getIsLoadModel() {
-    return isLoadModel;
-  }
-
-  private String modelFilePath;
-  private String tokenizerFilePath;
-  private double temperature;
-  private String systemPrompt;
-  private String userPrompt;
-  private boolean isClearChatHistory;
-  private boolean isLoadModel;
-  private ModelType modelType;
-  private BackendType backendType;
-
-  public SettingsFields() {
-    ModelType DEFAULT_MODEL = ModelType.LLAMA_3;
-    BackendType DEFAULT_BACKEND = BackendType.XNNPACK;
-
-    modelFilePath = "";
-    tokenizerFilePath = "";
-    temperature = SettingsActivity.TEMPERATURE_MIN_VALUE;
-    systemPrompt = "";
-    userPrompt = PromptFormat.getUserPromptTemplate(DEFAULT_MODEL, false);
-    isClearChatHistory = false;
-    isLoadModel = false;
-    modelType = DEFAULT_MODEL;
-    backendType = DEFAULT_BACKEND;
-  }
-
-  public SettingsFields(SettingsFields settingsFields) {
-    this.modelFilePath = settingsFields.modelFilePath;
-    this.tokenizerFilePath = settingsFields.tokenizerFilePath;
-    this.temperature = settingsFields.temperature;
-    this.systemPrompt = settingsFields.getSystemPrompt();
-    this.userPrompt = settingsFields.getUserPrompt();
-    this.isClearChatHistory = settingsFields.getIsClearChatHistory();
-    this.isLoadModel = settingsFields.getIsLoadModel();
-    this.modelType = settingsFields.modelType;
-    this.backendType = settingsFields.backendType;
-  }
-
-  public void saveModelPath(String modelFilePath) {
-    this.modelFilePath = modelFilePath;
-  }
-
-  public void saveTokenizerPath(String tokenizerFilePath) {
-    this.tokenizerFilePath = tokenizerFilePath;
-  }
-
-  public void saveModelType(ModelType modelType) {
-    this.modelType = modelType;
-  }
-
-  public void saveBackendType(BackendType backendType) {
-    this.backendType = backendType;
-  }
-
-  public void saveParameters(Double temperature) {
-    this.temperature = temperature;
-  }
-
-  public void savePrompts(String systemPrompt, String userPrompt) {
-    this.systemPrompt = systemPrompt;
-    this.userPrompt = userPrompt;
-  }
-
-  public void saveIsClearChatHistory(boolean needToClear) {
-    this.isClearChatHistory = needToClear;
-  }
-
-  public void saveLoadModelAction(boolean shouldLoadModel) {
-    this.isLoadModel = shouldLoadModel;
-  }
-
-  public boolean equals(SettingsFields anotherSettingsFields) {
-    if (this == anotherSettingsFields) return true;
-    return modelFilePath.equals(anotherSettingsFields.modelFilePath)
-        && tokenizerFilePath.equals(anotherSettingsFields.tokenizerFilePath)
-        && temperature == anotherSettingsFields.temperature
-        && systemPrompt.equals(anotherSettingsFields.systemPrompt)
-        && userPrompt.equals(anotherSettingsFields.userPrompt)
-        && isClearChatHistory == anotherSettingsFields.isClearChatHistory
-        && isLoadModel == anotherSettingsFields.isLoadModel
-        && modelType == anotherSettingsFields.modelType
-        && backendType == anotherSettingsFields.backendType;
-  }
-}
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/banner_shape.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/banner_shape.xml
deleted file mode 100644
index 0868ffffa6f..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_add_24.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_add_24.xml
deleted file mode 100644
index 2ae27b8409e..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_add_photo_alternate_24.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_add_photo_alternate_24.xml
deleted file mode 100644
index 7077fedd483..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_article_24.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_article_24.xml
deleted file mode 100644
index a6837b9c69f..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_close_24.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_close_24.xml
deleted file mode 100644
index fb902d4331b..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_delete_forever_24.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_delete_forever_24.xml
deleted file mode 100644
index 4680bc6629e..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_lightbulb_24.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_lightbulb_24.xml
deleted file mode 100644
index aa045396d28..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_restart_alt_24.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_restart_alt_24.xml
deleted file mode 100644
index 860470ab109..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_send_24.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_send_24.xml
deleted file mode 100644
index 2de1f642089..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_settings_24.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_settings_24.xml
deleted file mode 100644
index c51d84b9f4f..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_stop_24.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_stop_24.xml
deleted file mode 100644
index 832e2585954..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/blue_lightbulb_24.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/blue_lightbulb_24.xml
deleted file mode 100644
index 585cd3b1892..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/btn.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/btn.xml
deleted file mode 100644
index ceb3ac56c9e..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/chat_background.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/chat_background.xml
deleted file mode 100644
index eb8b9d1f1a9..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/custom_button_round.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/custom_button_round.xml
deleted file mode 100644
index 87c82d2a38d..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/expand_circle_down.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/expand_circle_down.xml
deleted file mode 100644
index 0a7a71f0700..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/ic_launcher_background.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/ic_launcher_background.xml
deleted file mode 100644
index 07d5da9cbf1..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/ic_launcher_foreground.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/ic_launcher_foreground.xml
deleted file mode 100644
index 7706ab9e6d4..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/input_text_shape.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/input_text_shape.xml
deleted file mode 100644
index 35c778a437d..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/logo.png b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/logo.png
deleted file mode 100644
index 60e3e5174e9..00000000000
Binary files a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/logo.png and /dev/null differ
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/outline_add_box_48.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/outline_add_box_48.xml
deleted file mode 100644
index bb45d63d85b..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/outline_camera_alt_48.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/outline_camera_alt_48.xml
deleted file mode 100644
index c7b4b2e4a1d..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/outline_image_48.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/outline_image_48.xml
deleted file mode 100644
index a8bb4b2f646..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/prompt_shape.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/prompt_shape.xml
deleted file mode 100644
index 5f81396e382..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/received_message.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/received_message.xml
deleted file mode 100644
index c2288b5bfce..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/sent_message.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/sent_message.xml
deleted file mode 100644
index e8d13ca4e12..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/three_dots.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/three_dots.xml
deleted file mode 100644
index afbe22da808..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_benchmarking.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_benchmarking.xml
deleted file mode 100644
index 6e48b5de8be..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_logs.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_logs.xml
deleted file mode 100644
index b327a544f25..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_main.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_main.xml
deleted file mode 100644
index 52bf533521a..00000000000
diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_settings.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_settings.xml
deleted file mode 100644
index 0ec551ae364..00000000000