diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index 169da9c150..0bf8d8f9e9 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -381,11 +381,20 @@ def get_moe_group_name(group): AscendKVQuantMeta.set_value(step_context.block_offsets.device, step_context.model_config.dtype, record_file, total_layers) + cu_seqlens = None + has_initial_state = None + + if step_context.state_offsets is not None: + q_start_loc = step_context.q_start_loc + cu_seqlens = torch.cat((q_start_loc, step_context.q_seqlens.sum().unsqueeze(0))).int() + if not step_context.is_decoding: + has_initial_state = ~(step_context.q_seqlens == step_context.kv_seqlens) + attn_meta_cls = cls.get_attention_metadata_cls() attn_metadata = attn_meta_cls( step_context.is_decoding, step_context.block_offsets, - q_start_loc=None, + q_start_loc=cu_seqlens, q_seqlens=q_seqlens_cpu, # kv_seqlens_expanded is only expanded in paged prefill, # otherwise it equals kv_seqlens_cpu @@ -398,6 +407,7 @@ def get_moe_group_name(group): max_kv_seq_len=max_kv_seq_len, quant_policy=step_context.kv_quant_policy, quant_meta=AscendKVQuantMeta.quant_meta, + has_initial_state=has_initial_state, ) step_context.attn_metadata = attn_metadata @@ -440,7 +450,9 @@ def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_ def init(): """Initialize Ascend backend.""" try: + from dlinfer.vendor.ascend.triton_ops.triton_utils import init_device_properties_triton from torch_npu.contrib import transfer_to_npu # noqa: F401 + init_device_properties_triton() except ImportError: logger.warning('Failed to import torch_npu. Please make sure torch_npu is installed correctly. ' 'Ascend initialization skipped.') diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py index 78afe49040..e0bead2889 100644 --- a/lmdeploy/pytorch/backends/dlinfer/attention.py +++ b/lmdeploy/pytorch/backends/dlinfer/attention.py @@ -18,6 +18,7 @@ class DlinferAttentionMetadata(AttentionMetadata): max_kv_seq_len: int = 1 quant_meta: dict = None cu_seq_lens_kv: Tensor | None = None + has_initial_state: Tensor | None = None class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]): diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index e78b87e811..e3c2800e47 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -443,7 +443,8 @@ def from_hf_config( model_path, tp=tp, is_draft_model=is_draft_model, - spec_method=spec_method) + spec_method=spec_method, + device_type=device_type) if model_config.k_head_dim is None: assert model_config.head_dim is not None diff --git a/lmdeploy/pytorch/configurations/qwen3_5.py b/lmdeploy/pytorch/configurations/qwen3_5.py index 4c78705a25..0c12b7d93c 100644 --- a/lmdeploy/pytorch/configurations/qwen3_5.py +++ b/lmdeploy/pytorch/configurations/qwen3_5.py @@ -44,7 +44,8 @@ def build(cls, hf_config, model_path: str = None, tp: int = 1, **kwargs): conv_state_shape = (num_delta_layers, conv_dim, conv_kernel_size) recurrent_state_shape = (num_delta_layers, num_v_heads, head_k_dim, head_v_dim) - if is_bf16_supported(): + device_type = kwargs.get('device_type', 'cuda') + if is_bf16_supported(device_type): dtype = torch.bfloat16 else: dtype = torch.float16