
feat: Turbomind linear gdn prefix caching #4465

Open

lapy wants to merge 5 commits into InternLM:main from lapy:turbomind-linear-gdn-prefix-caching

Conversation

@lapy
Contributor

@lapy lapy commented Mar 25, 2026

Qwen3.5 Hybrid Prefix Caching in TurboMind

Summary

This change adds prefix caching support for Qwen3.5 hybrid-attention models in TurboMind.

  • Full-attention layers keep using the existing KV prefix cache.
  • Gated DeltaNet linear-attention layers now store checkpointed recurrent state at configurable KV-block boundaries.
  • Prefix matches for hybrid models restore both:
    • shared KV blocks for full-attention layers
    • the closest compatible GDN checkpoint state

The implementation keeps quant_policy scoped to KV cache quantization. GDN prefix checkpoints remain in the model/state dtypes in this version.

User-Facing Changes

  • Added linear_prefix_cache_interval_blocks to TurbomindEngineConfig
  • Added --linear-prefix-cache-interval-blocks to the TurboMind CLI surface
  • Default interval is 2 KV blocks
  • Validation rejects values < 1
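The new option and its validation can be sketched with a minimal self-contained stand-in. This is not the actual TurbomindEngineConfig class (which lives in lmdeploy/messages.py and uses pydantic dataclasses); the class name below is hypothetical and only mirrors the behavior described above.

```python
from dataclasses import dataclass


@dataclass
class EngineConfigSketch:
    """Toy stand-in for TurbomindEngineConfig, showing only the new field."""
    enable_prefix_caching: bool = False
    linear_prefix_cache_interval_blocks: int = 2  # default per this description

    def __post_init__(self):
        # Mirrors the validation rule: values < 1 are rejected.
        if self.linear_prefix_cache_interval_blocks < 1:
            raise ValueError('invalid linear_prefix_cache_interval_blocks')


cfg = EngineConfigSketch(enable_prefix_caching=True)
print(cfg.linear_prefix_cache_interval_blocks)  # → 2
try:
    EngineConfigSketch(linear_prefix_cache_interval_blocks=0)
except ValueError as e:
    print(e)  # → invalid linear_prefix_cache_interval_blocks
```

In the real config the same rule is wired through the CLI flag as well, so both `TurbomindEngineConfig(...)` and `--linear-prefix-cache-interval-blocks` reject values below 1.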

Runtime Design

Hybrid cache structure

  • Existing KV prefix cache is unchanged for full-attention layers.
  • A second cache family stores GDN prefix checkpoints.
  • Each checkpoint stores:
    • convolution state
    • recurrent state
  • Checkpoints are attached to trie nodes at the configured interval.

Prefix matching

On a prefix hit:

  • TurboMind matches normal KV blocks as before.
  • For hybrid models it also finds the deepest trie node with a valid GDN checkpoint.
  • Reusable prefix length is clamped to the deepest compatible linear checkpoint.
  • The matched GDN state is restored into the live per-sequence GDN buffers before decode continues.
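The matching rule above can be illustrated with a toy trie. The node layout and function names here are illustrative only, not the actual BlockTrie API; the point is the clamping step: the reusable prefix cannot extend past the deepest node that carries a GDN checkpoint.

```python
from dataclasses import dataclass, field


@dataclass
class TrieNode:
    # Hypothetical stand-in for a BlockTrie node: one cached KV block per
    # edge, plus an optional GDN checkpoint attached at interval boundaries.
    children: dict = field(default_factory=dict)
    has_gdn_checkpoint: bool = False


def match_prefix(root: TrieNode, block_keys: list) -> tuple:
    """Walk the trie along the request's block keys.

    Returns (kv_blocks_matched, reusable_blocks); reusable_blocks is clamped
    to the deepest node with a GDN checkpoint, since linear-attention state
    beyond that point is not recoverable.
    """
    node, depth, deepest_ckpt = root, 0, 0
    for key in block_keys:
        nxt = node.children.get(key)
        if nxt is None:
            break
        node, depth = nxt, depth + 1
        if node.has_gdn_checkpoint:
            deepest_ckpt = depth
    return depth, deepest_ckpt


# Build a chain of 8 cached blocks with a checkpoint every 2 blocks.
root = TrieNode()
node = root
for i in range(8):
    child = TrieNode(has_gdn_checkpoint=((i + 1) % 2 == 0))
    node.children[i] = child
    node = child

print(match_prefix(root, list(range(8))))  # → (8, 8)
print(match_prefix(root, list(range(7))))  # → (7, 6): clamped to checkpoint
```

With 64-token blocks, a full 8-block match corresponds to a 512-token reusable prefix, which is the shape of the `linear_cache_len 512` log line shown later in the validation section.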

Cache maintenance

  • GDN checkpoint slots are released when trie nodes are invalidated.
  • When KV cached blocks are freed or evicted, the trie is verified immediately so the corresponding GDN checkpoints are pruned in the same path.
  • If the GDN checkpoint pool is exhausted, TurboMind skips storing deeper checkpoints instead of aborting the request.
  • Warm-up requests never allocate GDN prefix checkpoint staging.
  • Large real batches that cannot afford checkpoint staging continue to run; they simply skip storing new GDN checkpoints for that batch.
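The slot-pool behavior in the list above (skip on exhaustion, release on trie invalidation) can be sketched as follows. The class name and methods are hypothetical, not the TurboMind implementation.

```python
class CheckpointSlotPool:
    """Hypothetical fixed-size pool of GDN checkpoint slots.

    When the pool is exhausted, callers skip storing deeper checkpoints
    rather than failing the request, matching the behavior described above.
    """

    def __init__(self, capacity: int):
        self.free = list(range(capacity))

    def try_acquire(self):
        # Returns a slot id, or None when exhausted (caller simply skips
        # storing a new checkpoint; the request keeps running).
        return self.free.pop() if self.free else None

    def release(self, slot: int):
        # Called when the trie node owning this slot is invalidated.
        self.free.append(slot)


pool = CheckpointSlotPool(2)
a = pool.try_acquire()
b = pool.try_acquire()
print(pool.try_acquire())  # → None: exhausted, skip instead of aborting
pool.release(a)
print(pool.try_acquire() == a)  # → True: slot reused after invalidation
```

The key design property is that checkpoint storage is strictly best-effort: acquisition failure never propagates as a request error.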

Main Code Areas

  • CLI and engine config
    • lmdeploy/messages.py
    • lmdeploy/cli/utils.py
    • lmdeploy/cli/cli.py
    • lmdeploy/cli/serve.py
    • src/turbomind/turbomind.cc
    • src/turbomind/models/llama/llama_params.h
    • src/turbomind/engine/engine.cc
  • Core hybrid prefix-cache logic
    • src/turbomind/models/llama/BlockTrie.h
    • src/turbomind/models/llama/BlockTrie.cc
    • src/turbomind/models/llama/SequenceManager.h
    • src/turbomind/models/llama/SequenceManager.cc
    • src/turbomind/models/llama/GatedDeltaNetLayer.h
    • src/turbomind/models/llama/GatedDeltaNetLayer.cc
    • src/turbomind/models/llama/gated_delta_net_kernels.h
    • src/turbomind/models/llama/gated_delta_net_kernels.cu

Test Coverage Added

Python tests

  • tests/test_lmdeploy/test_turbomind/test_engine_config.py
    • default interval value
    • validation for invalid interval values
    • explicit override handling
  • tests/test_lmdeploy/test_turbomind/test_api_server.py
    • TurboMind api_server forwards hybrid prefix-cache options into TurbomindEngineConfig
    • TurboMind api_server uses the normal default CUDA max_batch_size when the user does not set one explicitly

Test commands run

python -m pytest -q tests/test_lmdeploy/test_turbomind/test_engine_config.py tests/test_lmdeploy/test_turbomind/test_api_server.py
python -m pytest -q tests/test_lmdeploy/test_turbomind/test_converter.py
cmake --build /root/lmdeploy/build --target _turbomind -j4

Observed results

  • test_engine_config.py + test_api_server.py: 5 passed
  • test_converter.py: 5 passed
  • _turbomind rebuilt successfully

Real-Model Validation

Model:

  • QuantTrio/Qwen3.5-27B-AWQ

Command used:

TM_LOG_LEVEL=INFO CUDA_VISIBLE_DEVICES=1,2 lmdeploy serve api_server \
  QuantTrio/Qwen3.5-27B-AWQ \
  --tp 2 \
  --server-port 23335 \
  --reasoning-parser qwen-qwq \
  --tool-call-parser qwen3coder \
  --quant-policy 8 \
  --enable-prefix-caching

Observed startup details:

  • Server reached full Uvicorn startup successfully.
  • TurboMind reported max cached tokens: 533248.
  • Warm-up completed successfully through 8320 tokens.

Observed hybrid prefix-cache hit on repeated request:

[TM][INFO] [SeqMgr][match] ID 2, hit blocks 8, linear_cache_len 512, cache_len 0
[TM][INFO] [SeqMgr][match] ID 2, after matching, blocks 8, cache_len 512

Request details:

  • request 1: prompt_tokens=626, completion_tokens=24
  • request 2: prompt_tokens=626, completion_tokens=24

This confirms that the second request reused both normal cached KV blocks and a compatible linear-attention checkpoint.

Notes

  • quant_policy remains KV-only in this PR.

lapy added 5 commits March 25, 2026 08:26
@lapy
Contributor Author

lapy commented Mar 25, 2026

TurboMind now treats Qwen3.5 hybrid attention as two cache families:

  • standard KV prefix cache for full-attention layers
  • checkpointed Gated DeltaNet (GDN) state for linear-attention layers

On a prefix hit, TurboMind restores both:

  • shared KV blocks for the matched full-attention prefix
  • the deepest compatible cached GDN checkpoint for the matched linear-attention prefix

Key changes

  • Changed the default linear_prefix_cache_interval_blocks from 2 to 64.
  • Updated the CLI/help text to describe the tradeoff more clearly: larger values reduce GDN checkpoint memory usage but increase recompute after a prefix hit.

New findings

The important new finding is that hybrid prefix caching must budget for three separate memory buckets:

  • KV cache blocks
  • live per-sequence GDN state
  • cached GDN prefix checkpoints

Before this change, TurboMind only effectively budgeted KV blocks plus live GDN state. The cached GDN checkpoint pool was lazy and not included in the initial capacity estimate.

For Qwen3.5-27B AWQ on tp=2 with quant_policy=8, a single cached GDN checkpoint is much larger than it first appears:

  • one int8 KV block of 64 items: about 1.016 MiB
  • one GDN checkpoint snapshot slot: about 37.9 MiB per rank

So while a single live GDN state is constant-size, dense GDN checkpointing grows linearly with cached prefix length and can significantly reduce available context capacity once the checkpoint pool is included in the capacity budget.

linear_prefix_cache_interval_blocks controls how many KV blocks elapse between saved GDN snapshots.

Example: with the default interval of 64, a GDN snapshot (37.9 MiB) is saved once per 64 int8 KV blocks (65 MiB), a reasonable tradeoff between reuse and memory footprint.
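The arithmetic behind that tradeoff can be checked directly. The per-block and per-checkpoint sizes below are the figures quoted above for Qwen3.5-27B AWQ at tp=2 with quant_policy=8; only the ratios are computed here.

```python
KV_BLOCK_MIB = 1.016       # one int8 KV block of 64 tokens (from this PR)
GDN_CHECKPOINT_MIB = 37.9  # one GDN checkpoint slot per rank (from this PR)


def checkpoint_overhead(interval_blocks: int) -> float:
    """Checkpoint memory as a fraction of the KV memory it shadows."""
    kv_mib = interval_blocks * KV_BLOCK_MIB
    return GDN_CHECKPOINT_MIB / kv_mib


print(round(64 * KV_BLOCK_MIB, 1))       # → 65.0  MiB of KV per snapshot
print(round(checkpoint_overhead(2), 2))  # → 18.65 : interval 2 is ~19x the KV
print(round(checkpoint_overhead(64), 2)) # → 0.58  : interval 64 is sub-1x
```

This matches the capacity measurements below: at interval 2 the checkpoint pool dwarfs the KV it covers, which is why max cached tokens collapsed, while at interval 64 and 128 the overhead stays below the KV footprint.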

Real-model observations

Validated on QuantTrio/Qwen3.5-27B-AWQ with TurboMind and real repeated requests.

Observed hybrid prefix hit on repeated prompt:

  • hit blocks 8
  • linear_cache_len 512
  • after matching, blocks 8, cache_len 512

Observed context-capacity impact when GDN checkpoint memory is included in the budget:

  • interval 2: max cached tokens = 27200
  • interval 64: max cached tokens = 337600
  • interval 128: max cached tokens = 413952

This showed that interval 2 is far too dense for this model on 32GB V100s, while 64 and 128 are both practical. Based on these results, the default was changed to 64.

Runtime behavior

  • Huge prompts can still run even if new GDN checkpoints cannot be stored.
  • If GDN checkpoint staging would exceed the per-batch budget, TurboMind skips storing new GDN checkpoints for that batch instead of aborting the request.
  • If the GDN checkpoint slot pool is exhausted, deeper checkpoints are skipped until cached entries are evicted.
  • Prefix caching remains opportunistic acceleration data rather than a hard requirement for forward progress.

Validation

  • QuantTrio/Qwen3.5-27B-AWQ
  • CUDA_VISIBLE_DEVICES=1,2
  • tp=2
  • quant_policy=8
  • prefix caching enabled

@lvhan028 lvhan028 requested review from Copilot and lzhangzz March 26, 2026 09:56
@lvhan028 lvhan028 added the enhancement New feature or request label Apr 2, 2026
Contributor

Copilot AI left a comment


Pull request overview

Adds hybrid prefix caching support in TurboMind for Qwen3.5-style hybrid attention models by extending the existing KV prefix cache with periodic Gated DeltaNet (linear attention) state checkpoints, and wires the new option through Python config + CLI into the C++ engine.

Changes:

  • Introduces linear_prefix_cache_interval_blocks across Python config/CLI and TurboMind engine params to control linear-attention checkpoint cadence.
  • Extends TurboMind prefix-cache matching/caching to additionally capture and restore GDN conv/recurrent states at interval boundaries.
  • Adds Python tests to validate config defaults/validation and API server forwarding behavior.

Reviewed changes

Copilot reviewed 17 out of 17 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
tests/test_lmdeploy/test_turbomind/test_engine_config.py Adds tests for default/validation/override of the new interval config.
tests/test_lmdeploy/test_turbomind/test_api_server.py Ensures API server forwards the new option into TurbomindEngineConfig and preserves default CUDA batch sizing behavior.
src/turbomind/turbomind.cc Parses linear_prefix_cache_interval_blocks and removes the previous hard block on prefix caching with linear attention.
src/turbomind/models/llama/SequenceManager.h Adds per-sequence pending checkpoint tensors/metadata and threads the interval into SequenceManager.
src/turbomind/models/llama/SequenceManager.cc Budgets cache blocks considering checkpoint overhead; integrates trie verify hooks; restores linear states on prefix hits.
src/turbomind/models/llama/llama_params.h Adds the new engine parameter to EngineParam.
src/turbomind/models/llama/GatedDeltaNetLayer.h Adds capture staging buffers and bookkeeping for checkpoint capture during prefill.
src/turbomind/models/llama/GatedDeltaNetLayer.cc Computes per-request capture counts, allocates staging opportunistically, and publishes captured checkpoint slices to sequences for caching.
src/turbomind/models/llama/gated_delta_net_kernels.h Extends kernel launcher APIs to optionally write checkpoint captures.
src/turbomind/models/llama/gated_delta_net_kernels.cu Implements conv/recurrent checkpoint capture paths and adds new overloads for the launchers.
src/turbomind/models/llama/BlockTrie.h Extends trie nodes to optionally own a linear-state slot; returns a richer match result including linear checkpoint state.
src/turbomind/models/llama/BlockTrie.cc Stores/retrieves linear checkpoint state in trie nodes and releases it when nodes are invalidated.
src/turbomind/engine/engine.cc Passes the new interval into SequenceManager.
lmdeploy/messages.py Adds config field, docs, and validation for linear_prefix_cache_interval_blocks.
lmdeploy/cli/utils.py Adds --linear-prefix-cache-interval-blocks CLI option.
lmdeploy/cli/serve.py Wires the CLI option into TurbomindEngineConfig for API server.
lmdeploy/cli/cli.py Exposes the CLI option on the chat CLI path.


Comment on lines 283 to 287
cache_chunk_size: int = -1
cache_block_seq_len: int = 64
enable_prefix_caching: bool = False
linear_prefix_cache_interval_blocks: int = 64
quant_policy: int = 0

Copilot AI Apr 2, 2026


PR description says the default linear_prefix_cache_interval_blocks is 2 KV blocks, but the implementation sets the default to 64 here (and tests assert 64). Please either update the default value to match the PR description or adjust the PR description/user-facing docs so they align with the shipped default.

Comment on lines 302 to 309
def __post_init__(self):
    """Check input validation."""
    assert self.dtype in ['auto', 'float16', 'bfloat16']
    assert self.tp >= 1, 'tp must be a positive integer'
    assert self.cache_max_entry_count > 0, 'invalid cache_max_entry_count'
    assert self.linear_prefix_cache_interval_blocks >= 1, \
        'invalid linear_prefix_cache_interval_blocks'
    assert self.quant_policy in (0, 4, 8), 'invalid quant_policy'

Copilot AI Apr 2, 2026


__post_init__ uses assert for validation; in pydantic.dataclasses these become ValidationErrors, but only when Python assertions are enabled. Running with -O disables asserts and would skip this validation. Consider switching to explicit pydantic validators (or raising ValueError) so the check for linear_prefix_cache_interval_blocks >= 1 is enforced regardless of optimization flags.
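This can be demonstrated without lmdeploy at all: a standalone snippet run once normally and once under `python -O` shows the `assert` being stripped. The `check` function below is a hypothetical stand-in for the validation, not the real config code.

```python
import subprocess
import sys
import textwrap

# Validation written with `assert`, as in the reviewed __post_init__.
snippet = textwrap.dedent("""
    def check(interval):
        assert interval >= 1, 'invalid linear_prefix_cache_interval_blocks'
        return interval
    print(check(0))
""")

default = subprocess.run([sys.executable, '-c', snippet],
                         capture_output=True, text=True)
optimized = subprocess.run([sys.executable, '-O', '-c', snippet],
                           capture_output=True, text=True)

print(default.returncode != 0)   # → True: assert fires in a normal run
print(optimized.stdout.strip())  # → 0: under -O the invalid value passes
```

An explicit `raise ValueError(...)` (or a pydantic field validator) is unaffected by `-O`, which is why the reviewer suggests it.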

linear_prefix_recurrent_state_shape,
linear_prefix_recurrent_state_dtype);
}
TM_LOG_WARNING("[SegMgr] prefix caching is %s", enable_prefix_caching ? "enabled" : "disabled");

Copilot AI Apr 2, 2026


Log tag typo: this message uses [SegMgr] but the surrounding logs and component name are SeqMgr. This makes grep-based debugging harder; please rename the tag to [SeqMgr] (or the consistent tag used elsewhere).

Suggested change
TM_LOG_WARNING("[SegMgr] prefix caching is %s", enable_prefix_caching ? "enabled" : "disabled");
TM_LOG_WARNING("[SeqMgr] prefix caching is %s", enable_prefix_caching ? "enabled" : "disabled");

Comment on lines +242 to +243
linear_conv_states_[slot] = {};
linear_recurrent_states_[slot] = {};

Copilot AI Apr 2, 2026


ReleaseLinearStateSlot clears the tensors (linear_conv_states_[slot] = {} etc.) before returning the slot to the free list. This defeats the purpose of pre-sizing a slot pool and can cause frequent GPU reallocations/churn when prefix checkpoints are captured repeatedly. Consider keeping the tensors allocated and just returning the slot to the free list (or only freeing under memory-pressure), so reuse is truly O(1) and avoids allocator overhead.

Suggested change
linear_conv_states_[slot] = {};
linear_recurrent_states_[slot] = {};
// Keep tensors allocated to allow true O(1) reuse and avoid allocator churn.

Comment on lines +1301 to +1302
const int pos = t_local_start - (D_CONV - 1) + tok + d;
const int ring_d = (history_len + pos + 1) % D_CONV;

Copilot AI Apr 2, 2026


In fused_conv1d_batched_kernel_v2, the capture path computes ring_d = (history_len + pos + 1) % D_CONV where pos can be negative for the first few tokens (pos = t_local_start - (D_CONV - 1) + tok + d). In C/C++, % with a negative numerator yields a negative result, which would produce an out-of-bounds write when indexing capture_ptr + ring_d * conv_dim + .... Please adjust the modulo computation to be non-negative (e.g., add a multiple of D_CONV before %, or reuse the existing ring_start-based indexing logic) so early-token captures cannot write to negative indices.

Suggested change
const int pos = t_local_start - (D_CONV - 1) + tok + d;
const int ring_d = (history_len + pos + 1) % D_CONV;
const int pos = t_local_start - (D_CONV - 1) + tok + d;
int ring_d = (history_len + pos + 1) % D_CONV;
if (ring_d < 0) {
    ring_d += D_CONV;
}

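The hazard here is specific to C/C++ truncated division; Python's `%` is floored and always non-negative for a positive divisor, so a direct Python port would hide the bug. The sketch below emulates C `%` semantics to show both the out-of-bounds index and the suggested fix; `D_CONV = 4` is an illustrative window size, not taken from the kernel.

```python
import math

D_CONV = 4  # illustrative convolution window size


def c_mod(a: int, b: int) -> int:
    """C/C++ `%` semantics: truncated division, result sign follows the dividend."""
    return a - int(math.trunc(a / b)) * b


def safe_ring(a: int, b: int) -> int:
    """The suggested fix: add a multiple of b before `%` to land in [0, b)."""
    return (c_mod(a, b) + b) % b


# With history_len = 0 and an early token, the raw expression goes negative:
pos = -2  # pos = t_local_start - (D_CONV - 1) + tok + d can be < 0 early on
raw = c_mod(0 + pos + 1, D_CONV)
print(raw)                             # → -1 in C/C++: out-of-bounds write
print(safe_ring(0 + pos + 1, D_CONV))  # → 3: valid ring index
```

One `+ b` is enough here because `c_mod(a, b)` is always strictly greater than `-b`, so a single addition brings the value into range.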