feat: Turbomind linear GDN prefix caching #4465
Conversation
…ntation; add related tests. This change improves memory management for hybrid models by increasing the checkpoint interval, which may reduce memory usage at the cost of more recompute after prefix hits.
TurboMind now treats Qwen3.5 hybrid attention as two cache families:

- KV blocks for the full-attention layers (the existing prefix cache)
- periodic Gated DeltaNet (GDN) state checkpoints for the linear-attention layers

On a prefix hit, TurboMind restores both:

- the matched cached KV blocks
- the GDN conv/recurrent states from the nearest compatible checkpoint

Key changes
New findings

The important new finding is that hybrid prefix caching must budget for three separate memory buckets:

- KV cache blocks
- live per-sequence GDN state
- the cached GDN checkpoint pool
Before this change, TurboMind only effectively budgeted KV blocks plus live GDN state. The cached GDN checkpoint pool was lazy and not included in the initial capacity estimate. For Qwen3.5-27B AWQ on
So while a single live GDN state is constant-size, dense GDN checkpointing grows linearly with cached prefix length and can significantly reduce available context capacity once it is budgeted conservatively. `linear_prefix_cache_interval_blocks` controls how many KV blocks elapse between GDN snapshots. Example: with the default interval of 64, a GDN snapshot (37.9 MiB) is saved every 64 int8 KV blocks (65 MiB), a reasonable tradeoff between reuse and memory footprint.

Real-model observations

Validated on `QuantTrio/Qwen3.5-27B-AWQ`. Observed hybrid prefix hit on repeated prompt:
Observed context-capacity impact when GDN checkpoint memory is included in the budget:
This showed that interval

Runtime behavior
Validation
Pull request overview
Adds hybrid prefix caching support in TurboMind for Qwen3.5-style hybrid attention models by extending the existing KV prefix cache with periodic Gated DeltaNet (linear attention) state checkpoints, and wires the new option through Python config + CLI into the C++ engine.
Changes:
- Introduces `linear_prefix_cache_interval_blocks` across Python config/CLI and TurboMind engine params to control linear-attention checkpoint cadence.
- Extends TurboMind prefix-cache matching/caching to additionally capture and restore GDN conv/recurrent states at interval boundaries.
- Adds Python tests to validate config defaults/validation and API server forwarding behavior.
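The checkpoint cadence traded off here can be sketched numerically. The helper below is a hypothetical back-of-envelope calculation using the figures quoted in this PR (37.9 MiB per GDN snapshot, 65 MiB per 64 int8 KV blocks; the second figure is assumed to be the total for the span, which is an interpretation, not stated in the source):

```python
# Assumed figures from the PR discussion; units are MiB.
SNAPSHOT_MIB = 37.9          # one dense GDN snapshot
KV_MIB_PER_64_BLOCKS = 65.0  # int8 KV size of a 64-block span (assumed total)


def checkpoint_overhead(interval_blocks: int) -> float:
    """Extra GDN-checkpoint memory as a fraction of the KV memory it covers."""
    kv_mib = interval_blocks * (KV_MIB_PER_64_BLOCKS / 64)
    return SNAPSHOT_MIB / kv_mib


# Doubling the interval halves the relative checkpoint overhead,
# at the cost of more recompute after a prefix hit.
for interval in (64, 128, 256):
    print(f"interval={interval}: +{checkpoint_overhead(interval):.0%} over KV")
```

Under these assumptions the overhead falls roughly inversely with the interval, which matches the claim that dense checkpointing is linear in cached prefix length.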
Reviewed changes
Copilot reviewed 17 out of 17 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| tests/test_lmdeploy/test_turbomind/test_engine_config.py | Adds tests for default/validation/override of the new interval config. |
| tests/test_lmdeploy/test_turbomind/test_api_server.py | Ensures API server forwards the new option into TurbomindEngineConfig and preserves default CUDA batch sizing behavior. |
| src/turbomind/turbomind.cc | Parses linear_prefix_cache_interval_blocks and removes the previous hard block on prefix caching with linear attention. |
| src/turbomind/models/llama/SequenceManager.h | Adds per-sequence pending checkpoint tensors/metadata and threads the interval into SequenceManager. |
| src/turbomind/models/llama/SequenceManager.cc | Budgets cache blocks considering checkpoint overhead; integrates trie verify hooks; restores linear states on prefix hits. |
| src/turbomind/models/llama/llama_params.h | Adds the new engine parameter to EngineParam. |
| src/turbomind/models/llama/GatedDeltaNetLayer.h | Adds capture staging buffers and bookkeeping for checkpoint capture during prefill. |
| src/turbomind/models/llama/GatedDeltaNetLayer.cc | Computes per-request capture counts, allocates staging opportunistically, and publishes captured checkpoint slices to sequences for caching. |
| src/turbomind/models/llama/gated_delta_net_kernels.h | Extends kernel launcher APIs to optionally write checkpoint captures. |
| src/turbomind/models/llama/gated_delta_net_kernels.cu | Implements conv/recurrent checkpoint capture paths and adds new overloads for the launchers. |
| src/turbomind/models/llama/BlockTrie.h | Extends trie nodes to optionally own a linear-state slot; returns a richer match result including linear checkpoint state. |
| src/turbomind/models/llama/BlockTrie.cc | Stores/retrieves linear checkpoint state in trie nodes and releases it when nodes are invalidated. |
| src/turbomind/engine/engine.cc | Passes the new interval into SequenceManager. |
| lmdeploy/messages.py | Adds config field, docs, and validation for linear_prefix_cache_interval_blocks. |
| lmdeploy/cli/utils.py | Adds --linear-prefix-cache-interval-blocks CLI option. |
| lmdeploy/cli/serve.py | Wires the CLI option into TurbomindEngineConfig for API server. |
| lmdeploy/cli/cli.py | Exposes the CLI option on the chat CLI path. |
```python
cache_chunk_size: int = -1
cache_block_seq_len: int = 64
enable_prefix_caching: bool = False
linear_prefix_cache_interval_blocks: int = 64
quant_policy: int = 0
```
PR description says the default linear_prefix_cache_interval_blocks is 2 KV blocks, but the implementation sets the default to 64 here (and tests assert 64). Please either update the default value to match the PR description or adjust the PR description/user-facing docs so they align with the shipped default.
```python
def __post_init__(self):
    """Check input validation."""
    assert self.dtype in ['auto', 'float16', 'bfloat16']
    assert self.tp >= 1, 'tp must be a positive integer'
    assert self.cache_max_entry_count > 0, 'invalid cache_max_entry_count'
    assert self.linear_prefix_cache_interval_blocks >= 1, \
        'invalid linear_prefix_cache_interval_blocks'
    assert self.quant_policy in (0, 4, 8), 'invalid quant_policy'
```
__post_init__ uses assert for validation; in pydantic.dataclasses these become ValidationErrors, but only when Python assertions are enabled. Running with -O disables asserts and would skip this validation. Consider switching to explicit pydantic validators (or raising ValueError) so the check for linear_prefix_cache_interval_blocks >= 1 is enforced regardless of optimization flags.
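A minimal sketch of this suggestion, using a hypothetical `EngineConfigSketch` stand-in rather than the real `TurbomindEngineConfig`:

```python
from dataclasses import dataclass


@dataclass
class EngineConfigSketch:
    """Hypothetical stand-in illustrating explicit validation."""
    linear_prefix_cache_interval_blocks: int = 64

    def __post_init__(self):
        # An explicit raise is enforced regardless of `python -O`,
        # unlike `assert`, which -O strips out entirely.
        if self.linear_prefix_cache_interval_blocks < 1:
            raise ValueError('invalid linear_prefix_cache_interval_blocks')


EngineConfigSketch()  # the default of 64 passes
try:
    EngineConfigSketch(linear_prefix_cache_interval_blocks=0)
except ValueError as e:
    print(e)
```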
```cpp
                               linear_prefix_recurrent_state_shape,
                               linear_prefix_recurrent_state_dtype);
}
TM_LOG_WARNING("[SegMgr] prefix caching is %s", enable_prefix_caching ? "enabled" : "disabled");
```
Log tag typo: this message uses [SegMgr] but the surrounding logs and component name are SeqMgr. This makes grep-based debugging harder; please rename the tag to [SeqMgr] (or the consistent tag used elsewhere).
```diff
- TM_LOG_WARNING("[SegMgr] prefix caching is %s", enable_prefix_caching ? "enabled" : "disabled");
+ TM_LOG_WARNING("[SeqMgr] prefix caching is %s", enable_prefix_caching ? "enabled" : "disabled");
```
```cpp
linear_conv_states_[slot]      = {};
linear_recurrent_states_[slot] = {};
```
ReleaseLinearStateSlot clears the tensors (linear_conv_states_[slot] = {} etc.) before returning the slot to the free list. This defeats the purpose of pre-sizing a slot pool and can cause frequent GPU reallocations/churn when prefix checkpoints are captured repeatedly. Consider keeping the tensors allocated and just returning the slot to the free list (or only freeing under memory-pressure), so reuse is truly O(1) and avoids allocator overhead.
```diff
- linear_conv_states_[slot] = {};
- linear_recurrent_states_[slot] = {};
+ // Keep tensors allocated to allow true O(1) reuse and avoid allocator churn.
```
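The reuse pattern being suggested can be sketched as a free-list slot pool; `SlotPool` and its byte buffers below are illustrative stand-ins for the GPU tensor pool, not the actual TurboMind code:

```python
class SlotPool:
    """Free-list pool: release returns the index, buffers stay allocated."""

    def __init__(self, n_slots: int, slot_bytes: int):
        # Pre-size every slot once; buffers persist across acquire/release.
        self.buffers = [bytearray(slot_bytes) for _ in range(n_slots)]
        self.free = list(range(n_slots))

    def acquire(self) -> int:
        return self.free.pop()  # O(1); IndexError when the pool is exhausted

    def release(self, slot: int) -> None:
        self.free.append(slot)  # do NOT drop self.buffers[slot]


pool = SlotPool(n_slots=2, slot_bytes=16)
slot = pool.acquire()
first = id(pool.buffers[slot])
pool.release(slot)
# Re-acquiring hands back the same, still-allocated buffer: no churn.
assert id(pool.buffers[pool.acquire()]) == first
```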
```cpp
const int pos    = t_local_start - (D_CONV - 1) + tok + d;
const int ring_d = (history_len + pos + 1) % D_CONV;
```
In fused_conv1d_batched_kernel_v2, the capture path computes ring_d = (history_len + pos + 1) % D_CONV where pos can be negative for the first few tokens (pos = t_local_start - (D_CONV - 1) + tok + d). In C/C++, % with a negative numerator yields a negative result, which would produce an out-of-bounds write when indexing capture_ptr + ring_d * conv_dim + .... Please adjust the modulo computation to be non-negative (e.g., add a multiple of D_CONV before %, or reuse the existing ring_start-based indexing logic) so early-token captures cannot write to negative indices.
```diff
  const int pos = t_local_start - (D_CONV - 1) + tok + d;
- const int ring_d = (history_len + pos + 1) % D_CONV;
+ int ring_d = (history_len + pos + 1) % D_CONV;
+ if (ring_d < 0) {
+     ring_d += D_CONV;
+ }
```
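The bug class can be demonstrated outside CUDA. The sketch below emulates C/C++ truncated-division `%` (Python's native `%` is already non-negative) to show why a negative numerator would index the ring buffer out of bounds, and how the normalization fixes it:

```python
D_CONV = 4  # ring-buffer depth, as in the kernel


def c_style_mod(a: int, d: int) -> int:
    """Emulate C's `%`: int() truncates toward zero, so the sign of the
    result follows the numerator."""
    return a - d * int(a / d)


def safe_ring_index(a: int, d: int) -> int:
    """The suggested fix: fold a negative remainder back into [0, d)."""
    r = c_style_mod(a, d)
    return r + d if r < 0 else r


# history_len + pos + 1 can be negative for the first few captured tokens:
print(c_style_mod(-3, D_CONV))      # negative under C semantics -> OOB write
print(safe_ring_index(-3, D_CONV))  # folded into a valid ring slot
```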
Qwen3.5 Hybrid Prefix Caching in TurboMind
Summary
This change adds prefix caching support for Qwen3.5 hybrid-attention models in TurboMind.
The implementation keeps `quant_policy` scoped to KV cache quantization. GDN prefix checkpoints remain in the model/state dtypes in this version.

User-Facing Changes
- Adds `linear_prefix_cache_interval_blocks` to `TurbomindEngineConfig`
- Adds `--linear-prefix-cache-interval-blocks` to the TurboMind CLI surface
- Default: `2` KV blocks
- Rejects values `< 1`

Runtime Design
Hybrid cache structure
Prefix matching
On a prefix hit:

- the matched cached KV blocks are attached to the sequence
- the GDN conv/recurrent states are restored from the nearest preceding checkpoint

Cache maintenance
Main Code Areas
- `lmdeploy/messages.py`
- `lmdeploy/cli/utils.py`
- `lmdeploy/cli/cli.py`
- `lmdeploy/cli/serve.py`
- `src/turbomind/turbomind.cc`
- `src/turbomind/models/llama/llama_params.h`
- `src/turbomind/engine/engine.cc`
- `src/turbomind/models/llama/BlockTrie.h`
- `src/turbomind/models/llama/BlockTrie.cc`
- `src/turbomind/models/llama/SequenceManager.h`
- `src/turbomind/models/llama/SequenceManager.cc`
- `src/turbomind/models/llama/GatedDeltaNetLayer.h`
- `src/turbomind/models/llama/GatedDeltaNetLayer.cc`
- `src/turbomind/models/llama/gated_delta_net_kernels.h`
- `src/turbomind/models/llama/gated_delta_net_kernels.cu`
Python tests
- `tests/test_lmdeploy/test_turbomind/test_engine_config.py`
- `tests/test_lmdeploy/test_turbomind/test_api_server.py`
  - `api_server` forwards hybrid prefix-cache options into `TurbomindEngineConfig`
  - `api_server` uses the normal default CUDA `max_batch_size` when the user does not set one explicitly

Test commands run
Observed results
- `test_engine_config.py` + `test_api_server.py`: 5 passed
- `test_converter.py`: 5 passed
- `_turbomind` rebuilt successfully

Real-Model Validation
Model: `QuantTrio/Qwen3.5-27B-AWQ`

Command used:
Observed startup details:
- `max cached tokens: 533248`
- `8320` tokens

Observed hybrid prefix-cache hit on repeated request:
Request details:
- First request: `prompt_tokens=626`, `completion_tokens=24`
- Second request: `prompt_tokens=626`, `completion_tokens=24`

This confirms that the second request reused both normal cached KV blocks and a compatible linear-attention checkpoint.
Notes
`quant_policy` remains KV-only in this PR.