Skip to content
Merged
179 changes: 137 additions & 42 deletions QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import gc
import inspect
import logging
import os
import shutil
import subprocess
import warnings
Expand Down Expand Up @@ -65,6 +64,7 @@ class QEFFBaseModel(ABC):
_start = 0
_end = 0
_total_layers = None
_layerwise_active = False
_pytorch_transforms: List[PytorchTransform]
_onnx_transforms = [BaseOnnxTransform]

Expand Down Expand Up @@ -332,6 +332,25 @@ def _resolve_pkv_layers(pkv_obj):
return tuple(layers)
return None

def _resolve_pkv_names(layer_idx, layer_state):
    """Return the ONNX past-key/value names for a single layer.

    If the wrapped model exposes ``get_onnx_past_key_value_names`` and that hook
    returns something other than ``None``, its result is used (copied into a list).
    Otherwise the names are derived from the number of entries in ``layer_state``:
    two entries give key/value names, four entries give self/cross key/value names.

    Raises:
        ValueError: if ``layer_state`` has neither 2 nor 4 entries and no
            model hook supplied the names.
    """
    # Model-specific naming hook takes precedence over the generic scheme.
    if hasattr(self.model, "get_onnx_past_key_value_names"):
        hook_names = self.model.get_onnx_past_key_value_names(layer_idx, layer_state)
        if hook_names is not None:
            return list(hook_names)

    state_len = len(layer_state)
    if state_len not in (2, 4):
        raise ValueError(
            f"Unknown shape of past_key_values! Expected length of past_key_values for each layer to be either 2 or 4 but got {state_len}"
        )
    if state_len == 4:
        return [
            f"past_key_self.{layer_idx}",
            f"past_value_self.{layer_idx}",
            f"past_key_cross.{layer_idx}",
            f"past_value_cross.{layer_idx}",
        ]
    return [f"past_key.{layer_idx}", f"past_value.{layer_idx}"]

# Create input_names from example_inputs
input_names = []
for param in inspect.signature(self.model.forward).parameters:
Expand All @@ -342,21 +361,7 @@ def _resolve_pkv_layers(pkv_obj):
input_names.append(param)
continue
for i in range(len(pkv_layers)):
if len(pkv_layers[0]) == 2:
input_names.extend([f"past_key.{i}", f"past_value.{i}"])
elif len(pkv_layers[0]) == 4:
input_names.extend(
[
f"past_key_self.{i}",
f"past_value_self.{i}",
f"past_key_cross.{i}",
f"past_value_cross.{i}",
]
)
else:
raise ValueError(
f"Unknown shape of past_key_values! Expected length of past_key_values for each layer to be either 2 or 4 but got {len(pkv_layers[0])}"
)
input_names.extend(_resolve_pkv_names(i, pkv_layers[i]))
elif param == "compressed_kvs":
for i in range(len(example_inputs["compressed_kvs"])):
input_names.extend(
Expand Down Expand Up @@ -488,6 +493,7 @@ def _export_layerwise(
prefill_only: Optional[bool] = False,
**export_kwargs,
) -> str:
cache_probe = export_kwargs.pop("_layerwise_cache_probe", False)
idx = int(QEFFBaseModel._start)
end_idx = int(getattr(QEFFBaseModel, "_end", idx + 1))
if end_idx <= idx:
Expand All @@ -502,6 +508,20 @@ def _export_layerwise(
self.onnx_path = onnx_path
return onnx_path

# Layer-wise reuse: if the merged final ONNX from a prior run exists
# under final_data/, skip per-window export entirely. The driver's
# stitch step picks up the same merged file, so re-running the same
# example without changes goes straight to the QPC compile.
final_data_dir = export_dir / "final_data"
if final_data_dir.is_dir():
total_layers = int(getattr(QEFFBaseModel, "_total_layers", 0) or 0)
cached_merged = final_data_dir / f"merged_0-{total_layers}.onnx"
if total_layers > 0 and cached_merged.is_file():
self.onnx_path = cached_merged
return self.onnx_path
if cache_probe:
return None

# check if the model is in meta state or weights are offloaded
self._model_offloaded_check()

Expand All @@ -525,6 +545,25 @@ def _resolve_pkv_layers(pkv_obj):
return tuple(layers)
return None

def _resolve_pkv_names(layer_idx, layer_state):
    """Return the ONNX past-key/value names for a single layer.

    If the wrapped model exposes ``get_onnx_past_key_value_names`` and that hook
    returns something other than ``None``, its result is used (copied into a list).
    Otherwise the names are derived from the number of entries in ``layer_state``:
    two entries give key/value names, four entries give self/cross key/value names.

    Raises:
        ValueError: if ``layer_state`` has neither 2 nor 4 entries and no
            model hook supplied the names.
    """
    # Model-specific naming hook takes precedence over the generic scheme.
    if hasattr(self.model, "get_onnx_past_key_value_names"):
        hook_names = self.model.get_onnx_past_key_value_names(layer_idx, layer_state)
        if hook_names is not None:
            return list(hook_names)

    state_len = len(layer_state)
    if state_len not in (2, 4):
        raise ValueError(
            f"Unknown shape of past_key_values! Expected length of past_key_values for each layer to be either 2 or 4 but got {state_len}"
        )
    if state_len == 4:
        return [
            f"past_key_self.{layer_idx}",
            f"past_value_self.{layer_idx}",
            f"past_key_cross.{layer_idx}",
            f"past_value_cross.{layer_idx}",
        ]
    return [f"past_key.{layer_idx}", f"past_value.{layer_idx}"]

is_vision = hasattr(self.model, "language_model")
output_name = []
output_name.append("logits")
Expand All @@ -534,19 +573,37 @@ def _resolve_pkv_layers(pkv_obj):
if "deepstack_features_RetainedState" in output_names:
output_name.append("deepstack_features_RetainedState")
output_name.append("image_idx_output")
retained_state_suffix = (
"_InternalRetainedState" if export_kwargs.get("use_onnx_subfunctions", False) else "_RetainedState"
)
for layer_idx in range(idx, end_idx):
output_name.append(f"past_key.{layer_idx}_InternalRetainedState")
output_name.append(f"past_value.{layer_idx}_InternalRetainedState")
layer_states = _resolve_pkv_layers(example_inputs.get("past_key_values"))
if layer_states is None:
output_name.append(f"past_key.{layer_idx}{retained_state_suffix}")
output_name.append(f"past_value.{layer_idx}{retained_state_suffix}")
else:
output_name.extend(
[
f"{name}{retained_state_suffix}"
for name in _resolve_pkv_names(layer_idx, layer_states[layer_idx])
]
)

# For some decoder wrappers (e.g. VLM language wrappers), forward does not accept
# `inputs_embeds`; keep `input_ids` in those cases.
if idx >= 1:
z = example_inputs.pop("input_ids")
if is_vision:
hidden_size = self.model.language_model.config.hidden_size
embed_dtype = getattr(self.model.language_model.config, "torch_dtype", None)
else:
hidden_size = self.model.model.config.hidden_size
inputs_embeds = torch.rand(z.shape[0], z.shape[1], hidden_size, device=z.device)
embed_dtype = getattr(self.model.model.config, "torch_dtype", None)
# Match the model's dtype so per-window export does not introduce a
# float32/float16 mismatch when running through fp16 decoder layers.
if embed_dtype is None:
embed_dtype = next(self.model.parameters()).dtype
inputs_embeds = torch.rand(z.shape[0], z.shape[1], hidden_size, device=z.device, dtype=embed_dtype)
example_inputs["inputs_embeds"] = inputs_embeds
dynamic_axes["inputs_embeds"] = dynamic_axes.pop("input_ids")

Expand Down Expand Up @@ -579,23 +636,12 @@ def _resolve_pkv_layers(pkv_obj):
continue
example_inputs["past_key_values"] = [val for i, val in enumerate(pkv_layers) if i < window_size]
for i in range(len(example_inputs["past_key_values"])):
if len(example_inputs["past_key_values"][0]) == 2:
for layer_offset in range(len(example_inputs["past_key_values"])):
layer_idx = idx + layer_offset
input_names.extend([f"past_key.{layer_idx}", f"past_value.{layer_idx}"])
elif len(example_inputs["past_key_values"][0]) == 4:
for layer_offset in range(len(example_inputs["past_key_values"])):
layer_idx = idx + layer_offset
input_names.extend(
[
f"past_key_self.{i}",
f"past_value_self.{i}",
f"past_key_cross.{i}",
f"past_value_cross.{i}",
]
)
else:
raise ValueError(
f"Unknown shape of past_key_values! Expected length of past_key_values for each layer to be either 2 or 4 but got {len(example_inputs['past_key_values'][0])}"
_resolve_pkv_names(layer_idx, example_inputs["past_key_values"][layer_offset])
)
break
elif param == "compressed_kvs":
for layer_offset in range(len(example_inputs["compressed_kvs"])):
layer_idx = idx + layer_offset
Expand Down Expand Up @@ -641,6 +687,26 @@ def _resolve_pkv_layers(pkv_obj):
_onnx_transforms = [SplitTensorsTransform, CustomOpTransform, RenameFunctionOutputsTransform]
onnx_transforms = OnnxTransformPipeline(transforms=_onnx_transforms)
model, transformed = onnx_transforms.apply(model, **transform_kwargs)

def _rename_graph_value(graph: onnx.GraphProto, old_name: str, new_name: str) -> None:
    """Rename ``old_name`` to ``new_name`` throughout ``graph``, in place.

    Touches node inputs and outputs, initializer names, and the names of the
    graph's input, output and value_info entries. Only the given graph is
    visited. Does nothing when both names are equal.
    """
    if old_name == new_name:
        return

    def _substituted(names):
        # Build the replacement list before assigning so the slice-assignment
        # swaps the whole repeated field at once.
        return [new_name if entry == old_name else entry for entry in names]

    for node in graph.node:
        node.input[:] = _substituted(node.input)
        node.output[:] = _substituted(node.output)

    for initializer in graph.initializer:
        if initializer.name == old_name:
            initializer.name = new_name

    for collection in (graph.input, graph.output, graph.value_info):
        for value_info in collection:
            if value_info.name == old_name:
                value_info.name = new_name

for output_idx, expected_name in enumerate(output_names):
if output_idx >= len(model.graph.output):
break
current_name = model.graph.output[output_idx].name
_rename_graph_value(model.graph, current_name, expected_name)

onnx.save(model, layer_onnx_path_tmp)
self.onnx_path = layer_onnx_path_tmp
return layer_onnx_path_tmp
Expand Down Expand Up @@ -735,6 +801,7 @@ def _compile(
For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored.
"""

layerwise_cache_probe = compiler_options.pop("_layerwise_cache_probe", False)
moe_prefill_packed_chunk_size = compiler_options.pop("moe_prefill_packed_chunk_size", None)
if onnx_path is None:
# If weights were offloaded after export, compiling must use the existing
Expand All @@ -754,11 +821,15 @@ def _compile(
num_devices=mdp_ts_num_devices,
qaic_config=qaic_config,
moe_prefill_packed_chunk_size=moe_prefill_packed_chunk_size,
_layerwise_cache_probe=layerwise_cache_probe,
**compiler_options,
)
onnx_path = Path(onnx_path)
if os.environ.get("LAYERWISE_EXPORT", "False") == "True":
if QEFFBaseModel._layerwise_active:
if onnx_path is None:
return None
onnx_path = Path(onnx_path)
return onnx_path
onnx_path = Path(onnx_path)

compile_dir = Path(compile_dir or onnx_path.parent)
qpc_path = compile_dir / "qpc"
Expand Down Expand Up @@ -824,6 +895,31 @@ def _compile(
continue
command.append(f"{option}={value}")

# Final custom-IO normalization against ONNX I/O names.
# This only rewrites retained-state aliases:
# *_InternalRetainedState <-> *_RetainedState.
# Any other custom-IO key is preserved as-is for backward compatibility.
if custom_io is not None and onnx_path is not None:
try:
model = onnx.load(onnx_path, load_external_data=False)
io_names = {value.name for value in list(model.graph.input) + list(model.graph.output)}
normalized_custom_io = {}
for io_name, dtype in custom_io.items():
resolved_name = io_name
if io_name not in io_names:
if io_name.endswith("_InternalRetainedState"):
candidate = io_name[: -len("_InternalRetainedState")] + "_RetainedState"
if candidate in io_names:
resolved_name = candidate
elif io_name.endswith("_RetainedState"):
candidate = io_name[: -len("_RetainedState")] + "_InternalRetainedState"
if candidate in io_names:
resolved_name = candidate
normalized_custom_io[resolved_name] = dtype
custom_io = normalized_custom_io
except Exception:
pass

if use_onnx_subfunctions:
logger.info("Using ONNX subfunctions for compilation.")
command.append("-sub-functions")
Expand All @@ -841,14 +937,13 @@ def _compile(

compile_dir = qpc_path.with_name(qpc_path.name + "-" + compile_hash)
qpc_path = compile_dir / "qpc"
qpc_path.mkdir(parents=True, exist_ok=True)

if (qpc_path / "programqpc.bin").is_file():
self.qpc_path = qpc_path
return qpc_path
if qpc_path.is_dir():
if (qpc_path / "programqpc.bin").is_file():
self.qpc_path = qpc_path
return qpc_path
# Probably compilation failure last time, delete directory to start over
# Probably compilation failure last time, delete directory to start over.
shutil.rmtree(qpc_path)
compile_dir.mkdir(parents=True, exist_ok=True)

# Write the generated MDP partition config file (not if user provided it)

Expand Down
30 changes: 29 additions & 1 deletion QEfficient/generation/cloud_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,26 @@

import numpy as np


def _public_retained_state_name(output_name: str) -> Optional[str]:
    """Translate an internal retained-state output name to its public form.

    Names ending in ``_InternalRetainedState`` are returned with that suffix
    replaced by ``_RetainedState``; any other name yields ``None``.
    """
    suffix = "_InternalRetainedState"
    if not output_name.endswith(suffix):
        return None
    stem = output_name[: -len(suffix)]
    return stem + "_RetainedState"


def is_retained_state_name(name: str) -> bool:
    """Return True when an I/O binding participates in retained-state cache flow.

    A binding counts as retained state when its name starts with one of the
    cache prefixes (``past_``, ``conv_state.``, ``recurrent_state.``,
    ``compressed_``, ``k_pe``). Binding names may carry a ``scope/`` prefix, so
    the part after the last ``/`` is checked as well; otherwise prefixed cache
    buffers would not be recognised.

    Args:
        name: Input or output binding name.

    Returns:
        True if the full name or its basename starts with a retained-state prefix.
    """
    prefixes = ("past_", "conv_state.", "recurrent_state.", "compressed_", "k_pe")
    # Keep the full-name check so every name accepted before is still accepted.
    if name.startswith(prefixes):
        return True
    basename = name.rsplit("/", 1)[-1]
    return basename.startswith(prefixes)


def _add_basename_binding_aliases(binding_index_map: Dict[str, int], bindings) -> None:
    """Register each binding's basename as an extra key in ``binding_index_map``.

    The basename is the part of ``binding.name`` after the last ``/`` (the whole
    name when there is no ``/``). Existing keys are never overwritten, so exact
    names and earlier bindings win over later aliases. The map is updated in place.
    """
    for entry in bindings:
        short_name = entry.name.rpartition("/")[2]
        binding_index_map.setdefault(short_name, entry.index)


try:
import qaicrt

Expand Down Expand Up @@ -101,6 +121,7 @@ def __init__(
]
self.bindings = iodesc.selected_set.bindings
self.binding_index_map = {binding.name: binding.index for binding in self.bindings}
_add_basename_binding_aliases(self.binding_index_map, self.bindings)
# Create and load Program
prog_properties = qaicrt.QAicProgramProperties()
prog_properties.dataPathTimeoutMs = 60_000
Expand Down Expand Up @@ -226,8 +247,15 @@ def run(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
buffer_index = self.binding_index_map[output_name]
if self.qbuffers[buffer_index].size == 0:
continue
outputs[output_name] = np.frombuffer(
output = np.frombuffer(
bytes(output_qbuffers[buffer_index]),
self.aic_to_np_dtype_mapping[self.bindings[buffer_index].type],
).reshape(self.buf_dims[buffer_index][1])
outputs[output_name] = output
output_basename = output_name.rsplit("/", 1)[-1]
outputs.setdefault(output_basename, output)
public_name = _public_retained_state_name(output_name)
if public_name is not None:
outputs[public_name] = output
outputs.setdefault(public_name.rsplit("/", 1)[-1], output)
return outputs
10 changes: 2 additions & 8 deletions QEfficient/generation/text_generation_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import transformers
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.generation.cloud_infer import QAICInferenceSession, is_retained_state_name
from QEfficient.utils import padding_check_and_fix
from QEfficient.utils.constants import Constants
from QEfficient.utils.logging_utils import logger
Expand Down Expand Up @@ -500,13 +500,7 @@ def __init__(
self._set_tokenizer_params() # set tokenizer params
# Skip inputs/outputs
self._session.skip_buffers(
[x for x in self._session.input_names + self._session.output_names if x.startswith("past_")]
)
self._session.skip_buffers(
[x for x in self._session.input_names + self._session.output_names if x.startswith("compressed_")]
)
self._session.skip_buffers(
[x for x in self._session.input_names + self._session.output_names if x.startswith("k_pe")]
[x for x in self._session.input_names + self._session.output_names if is_retained_state_name(x)]
)

def _set_tokenizer_params(self):
Expand Down
Loading
Loading