Query-Adaptive Latent Working Memory for Long-Context Language Models
AdaMem compresses arbitrarily long context into a fixed budget of K soft memory tokens, conditioned on the current query. The decoder LLM operates on these K tokens instead of the raw long context, which reduces decoder-side compute, the generation KV-cache size, and the cost of serving repeated queries over the same context.
Long Context ──chunk──> ChunkEncoder ──> Coarse Memory H (cacheable, query-independent)
                                            │
Query q ───────────────────────────> LatentResampler(H, q) ──> K memory tokens Z
                                                                   │
                                                                [Z ; q] ──> Frozen LLM ──> answer
- Query-adaptive: unlike ICAE or AutoCompressor (query-agnostic), the resampler selects and fuses information relevant to the current query
- Two-stage caching: Stage A (ChunkEncoder) is offline and cacheable; Stage B (Resampler) is lightweight and per-query. Coarse memory is computed once, reused across queries
- Fixed decoder budget: the LLM always sees K tokens (default 32) regardless of context length, which is roughly 100x compression at 4K tokens (4096 / 32 = 128) and roughly 3000x at 100K tokens (100,000 / 32 = 3125)
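The compression figures in the last bullet follow directly from dividing the context length by K. A minimal sketch of that arithmetic, assuming the ratio is measured in raw context tokens per memory token (`compression_ratio` is a hypothetical helper, not part of the package):

```python
def compression_ratio(context_len: int, k: int = 32) -> float:
    """Raw context tokens per memory token seen by the decoder."""
    if context_len <= 0 or k <= 0:
        raise ValueError("context_len and k must be positive")
    return context_len / k


for n_tokens in (4_096, 100_000):
    # 4096 / 32 = 128 and 100000 / 32 = 3125
    print(f"{n_tokens} tokens -> {compression_ratio(n_tokens):.0f}x")
```

Because K is fixed, the ratio grows linearly with context length while the decoder-side sequence length stays constant.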
| Stage | Module | Input | Output | Trainable |
|---|---|---|---|---|
| A | ChunkEncoder | Context token embeddings | Coarse memory [B, M, D] (M = N * S) | Yes |
| B | LatentResampler | Coarse memory + query | Memory tokens [B, K, D] | Yes |
| — | Projection | Memory tokens | LLM-dim memory tokens | Yes |
| C | SoftPrefixInterface | Memory tokens + decoder prompt | LLM logits + loss | No (frozen LLM) |
The design follows the BLIP-2 paradigm: ChunkEncoder is analogous to a frozen image encoder producing cacheable representations; LatentResampler is analogous to Q-Former with learnable latent queries cross-attending to encoder output.
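To make the shapes in the table concrete, here is a small sketch of the sequence-length bookkeeping. It assumes N is the number of fixed-size chunks, that the last partial chunk is padded (hence the ceiling), and it uses the documented defaults P = 256, S = 2, K = 32. `memory_shapes` is a hypothetical helper for illustration only.

```python
import math


def memory_shapes(context_len: int, chunk_size: int = 256,
                  summary_per_chunk: int = 2, n_latents: int = 32) -> dict:
    """Sequence lengths at each stage for a context of `context_len` tokens."""
    # N: number of chunks (assumption: the final partial chunk is padded)
    n_chunks = math.ceil(context_len / chunk_size)
    # M = N * S: coarse memory length produced by Stage A
    coarse_len = n_chunks * summary_per_chunk
    # K: memory tokens produced by Stage B, independent of context length
    return {"N": n_chunks, "M": coarse_len, "K": n_latents}


print(memory_shapes(4_096))    # {'N': 16, 'M': 32, 'K': 32}
print(memory_shapes(100_000))  # {'N': 391, 'M': 782, 'K': 32}
```

Under these assumptions the coarse memory grows with the context (M scales with N), while the decoder input stays at K tokens.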
pip install -e .
# With evaluation dependencies
pip install -e ".[eval]"
# With development dependencies
pip install -e ".[dev]"
Requires Python >= 3.10 and PyTorch >= 2.1.
from adamem.config import AdaMemConfig
from adamem.models.adamem import AdaMem
# Load config and build model
config = AdaMemConfig.from_yaml("configs/default.yaml")
model = AdaMem(config.model)
model.setup_llm("Qwen/Qwen2.5-0.5B")
# Forward pass (training)
output = model(
    context_ids=context_ids,        # [B, L] long context
    context_mask=context_mask,
    query_ids=query_ids,            # [B, Lq] question
    decoder_input_ids=decoder_ids,  # [B, Ld] = [query; answer] for teacher forcing
    labels=labels,                  # [B, Ld] with -100 on query positions
)
loss = output.loss
# Generation (inference)
tokens = model.generate(
    context_ids=context_ids,
    query_ids=query_ids,
    max_new_tokens=256,
)
# Cache coarse memory for multi-query serving
coarse = model.encode_context(context_ids, context_mask)
for q in queries:
    tokens = model.generate(query_ids=q, coarse_memory=coarse)
Phase 1 pretraining uses FineWeb (sample-10BT, ~30GB). By default the training script streams from HuggingFace Hub, but downloading locally is recommended for faster I/O:
bash scripts/download_fineweb.sh
This downloads parquet files to /home/greenland-user/data/fineweb-10BT/. The path is configured via data.pretrain_local_path in experiments/configs/pretrain.yaml. When the local path is set and data exists, the training script reads from disk instead of streaming.
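The local-versus-streaming decision described above can be sketched as follows. This is an illustration of the stated behavior, not the training script's actual code: `resolve_pretrain_source` is a hypothetical name, and treating "data exists" as "at least one .parquet file under the directory" is an assumption.

```python
from pathlib import Path
from typing import Optional


def resolve_pretrain_source(local_path: Optional[str]) -> str:
    """Return "local" if parquet files exist under local_path, else "stream"."""
    if local_path:
        root = Path(local_path)
        # Assumption: any .parquet file below the directory counts as existing data
        if root.is_dir() and any(root.rglob("*.parquet")):
            return "local"
    return "stream"


print(resolve_pretrain_source(None))  # stream
```

Falling back to streaming when the path is unset or empty keeps the default behavior unchanged for users who skip the download step.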
Three-phase training recipe (LLM stays frozen in Phases 1-2):
| Phase | Objective | Data |
|---|---|---|
| 1. Memory Pretraining | Continuation + KL distillation vs full-context LLM | FineWeb (1-2B tokens) |
| 2. Task Fine-tuning | Answer NLL + contrastive + diversity regularization | LongBench, HotpotQA, NarrativeQA |
| 3. Decoder Unfreezing (optional) | Same as Phase 2, unfreeze top 2-4 LLM layers | Same as Phase 2 |
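As an illustration of the Phase 1 distillation term, the sketch below computes a per-position KL divergence between a teacher distribution (the full-context LLM) and a student distribution (the LLM conditioned on memory tokens). The direction KL(teacher || student), the temperature parameter, and the function names are assumptions for illustration; the actual formulation lives in training/losses.py and may differ.

```python
import math
from typing import Sequence


def log_softmax(logits: Sequence[float]) -> list[float]:
    """Numerically stable log-softmax over a single logit vector."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def kl_distillation(teacher_logits: Sequence[float],
                    student_logits: Sequence[float],
                    temperature: float = 1.0) -> float:
    """KL(teacher || student) for one token position (assumed direction)."""
    t = log_softmax([x / temperature for x in teacher_logits])
    s = log_softmax([x / temperature for x in student_logits])
    # sum_i p_teacher(i) * (log p_teacher(i) - log p_student(i))
    return sum(math.exp(tl) * (tl - sl) for tl, sl in zip(t, s))


print(kl_distillation([2.0, 0.5, -1.0], [2.0, 0.5, -1.0]))  # 0.0
```

The value is zero when both distributions match and positive otherwise, so minimizing it pushes the memory-conditioned predictions toward the full-context ones.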
# Phase 1
python experiments/step1_pretrain.py --config configs/default.yaml --output_dir outputs/pretrain
# Phase 2
python experiments/step2_finetune.py \
    --config experiments/configs/finetune_longbench.yaml \
    --checkpoint outputs/pretrain/final \
    --output_dir outputs/phase2
# Evaluation
python experiments/step3_evaluate.py \
    --checkpoint outputs/phase2/final \
    --benchmarks hotpotqa narrativeqa longbench_multidoc \
    --output_dir outputs/results
# Full pipeline
bash experiments/run_all.sh
Wrapper scripts are also available:
adamem-train --phase 1 --config configs/default.yaml
adamem-eval --checkpoint outputs/phase2/final
Benchmarks grouped by information structure:
| Group | Tasks | Expected behavior |
|---|---|---|
| A: Distributed fusion | Multi-doc QA, multi-hop reasoning | AdaMem's strength |
| B: Exact local recall | Needle-in-haystack, number extraction | AdaMem's weakness |
| C: Redundancy-heavy | Summarization, theme extraction | High compression OK |
| D: Sparse lookup | Single-chunk factoid QA | RAG competitive |
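The QA-style groups are scored with EM and F1 (see evaluation/metrics.py in the project layout). The sketch below shows the common SQuAD-style definitions of these two metrics; the normalization rules (lowercasing, stripping punctuation and English articles) are an assumption and may not match the repository's implementation exactly.

```python
import re
import string
from collections import Counter


def normalize(text: str) -> list[str]:
    """Lowercase, drop punctuation and articles, split on whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return text.split()


def exact_match(pred: str, gold: str) -> float:
    """1.0 if the normalized token sequences are identical, else 0.0."""
    return float(normalize(pred) == normalize(gold))


def token_f1(pred: str, gold: str) -> float:
    """Harmonic mean of token-level precision and recall."""
    p, g = normalize(pred), normalize(gold)
    if not p or not g:
        return float(p == g)
    overlap = sum((Counter(p) & Counter(g)).values())
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(p), overlap / len(g)
    return 2 * precision * recall / (precision + recall)


print(exact_match("The Eiffel Tower.", "eiffel tower"))  # 1.0
```

EM is the stricter signal for Group B (exact local recall), while token F1 gives partial credit that is more informative for Groups A and D.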
adamem/
  config.py                 # AdaMemConfig with YAML serialization
  data/
    chunking.py             # Fixed-size chunk splitting
    pretraining.py          # Phase 1 data pipeline
    task_data.py            # Phase 2 TaskDataset with format parsers
  models/
    adamem.py               # Full pipeline: ChunkEncoder -> Resampler -> Decoder
    chunk_encoder.py        # Stage A: per-chunk transformer with summary states
    resampler.py            # Stage B: FiLM-conditioned latent resampler
    projection.py           # Linear / MLP projection to LLM dim
  interfaces/
    soft_prefix.py          # Primary: prepend memory tokens to decoder input
    prefix_kv.py            # Alternative: inject as KV pairs (experimental)
  training/
    losses.py               # Contrastive, diversity, distillation losses
    trainer.py              # Training loop utilities
  evaluation/
    benchmarks.py           # Benchmark loading and parsing
    metrics.py              # F1, EM, ROUGE scoring
  utils/
    checkpoints.py          # Save/load with unified protocol
    misc.py                 # Logging, seeding, device utilities
experiments/
  step0_sanity_check.py     # Smoke tests on real model
  step1_pretrain.py         # Phase 1 training script
  step2_finetune.py         # Phase 2 training script
  step3_evaluate.py         # Evaluation across benchmark groups
  step4_ablation.py         # Ablation study manager
  step5_baselines.py        # Baseline comparisons
  step6_analysis.py         # Visualization and analysis
  run_all.sh                # End-to-end pipeline
  configs/                  # Experiment-specific YAML configs
configs/
  default.yaml              # Default configuration
tests/
  test_models.py            # 24 tests: smoke, integration, train-inference consistency
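For orientation, fixed-size chunk splitting of the kind listed under data/chunking.py can be sketched as below. This is a hedged illustration rather than the module's code: the function name, the right-padding of the last chunk, and `pad_id=0` are all assumptions.

```python
def split_into_chunks(token_ids: list[int], chunk_size: int = 256,
                      pad_id: int = 0) -> tuple[list[list[int]], list[list[int]]]:
    """Split token ids into fixed-size chunks, right-padding the last one.

    Returns (chunks, masks); mask entries are 1 for real tokens, 0 for padding.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    chunks, masks = [], []
    for start in range(0, len(token_ids), chunk_size):
        piece = token_ids[start:start + chunk_size]
        n_pad = chunk_size - len(piece)
        chunks.append(piece + [pad_id] * n_pad)
        masks.append([1] * len(piece) + [0] * n_pad)
    return chunks, masks


chunks, masks = split_into_chunks(list(range(1, 601)), chunk_size=256)
print(len(chunks), sum(masks[-1]))  # 3 88
```

Returning a mask alongside each chunk lets a per-chunk encoder ignore padded positions in the final chunk.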
All settings are in a single YAML file. Key hyperparameters:
| Parameter | Default | Description |
|---|---|---|
| K (resampler.n_latents) | 32 | Memory token budget |
| S (chunk_encoder.summary_per_chunk) | 2 | Summary states per chunk |
| P (chunk_encoder.chunk_size) | 256 | Tokens per chunk |
| Base LLM | Qwen2.5-0.5B | Frozen decoder |
See configs/default.yaml for the full configuration.
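The nested keys in the table (resampler.n_latents, chunk_encoder.chunk_size, and so on) suggest a config layout along the following lines. This is a sketch only: the class names, the `base_llm` key, and `from_dict` are hypothetical and do not describe the real AdaMemConfig; only the three hyperparameter names and their defaults come from the table.

```python
from dataclasses import dataclass, field

DEFAULT_LLM = "Qwen/Qwen2.5-0.5B"


@dataclass
class ChunkEncoderCfg:
    chunk_size: int = 256         # P: tokens per chunk
    summary_per_chunk: int = 2    # S: summary states per chunk


@dataclass
class ResamplerCfg:
    n_latents: int = 32           # K: memory token budget


@dataclass
class ModelCfg:
    chunk_encoder: ChunkEncoderCfg = field(default_factory=ChunkEncoderCfg)
    resampler: ResamplerCfg = field(default_factory=ResamplerCfg)
    base_llm: str = DEFAULT_LLM   # hypothetical key name

    @classmethod
    def from_dict(cls, raw: dict) -> "ModelCfg":
        """Build a config from a parsed YAML mapping; missing keys keep defaults."""
        return cls(
            chunk_encoder=ChunkEncoderCfg(**raw.get("chunk_encoder", {})),
            resampler=ResamplerCfg(**raw.get("resampler", {})),
            base_llm=raw.get("base_llm", DEFAULT_LLM),
        )


cfg = ModelCfg.from_dict({"resampler": {"n_latents": 64}})
print(cfg.resampler.n_latents, cfg.chunk_encoder.chunk_size)  # 64 256
```

In this sketch an unknown key inside a nested section raises a TypeError, which surfaces typos in a YAML file early.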
- H1: Query-adaptive latent memory outperforms query-agnostic compression (ICAE, AutoCompressor) on distributed evidence tasks
- H2: Under the same decoder token budget, adaptive memory outperforms top-k retrieval on tasks requiring fusion of many weak evidence pieces
- H3: Two-stage caching (offline coarse memory + online resampling) matches full-context quality while reducing repeated-query cost
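H3 depends on the coarse memory being computed once per context and reused for every subsequent query. A minimal sketch of such a cache is shown below, assuming contexts are plain sequences of token ids and that the encoder is passed in as a callable (in practice this role would be played by something like `model.encode_context`). `CoarseMemoryCache` is a hypothetical class, not part of the package.

```python
import hashlib
from typing import Any, Callable, Dict, Sequence


class CoarseMemoryCache:
    """Memoize an expensive context encoder, keyed by a hash of the token ids."""

    def __init__(self, encode_fn: Callable[[Sequence[int]], Any]) -> None:
        self._encode_fn = encode_fn
        self._store: Dict[str, Any] = {}
        self.misses = 0  # number of times the encoder was actually run

    @staticmethod
    def _key(context_ids: Sequence[int]) -> str:
        data = ",".join(map(str, context_ids)).encode("utf-8")
        return hashlib.sha256(data).hexdigest()

    def get(self, context_ids: Sequence[int]) -> Any:
        key = self._key(context_ids)
        if key not in self._store:
            self.misses += 1
            self._store[key] = self._encode_fn(context_ids)
        return self._store[key]


# Stand-in encoder for illustration; a real one would return a [B, M, D] tensor.
cache = CoarseMemoryCache(lambda ids: [sum(ids)])
for _ in range(3):          # three queries over the same context
    coarse = cache.get([1, 2, 3])
print(cache.misses)         # 1
```

With this pattern the Stage A cost is paid once per distinct context, and only the per-query resampling and decoding run for each additional query.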
See TODO.md for the current status and planned work, including:
- Near-term: Phase 1/2 pilots, evaluation, ablations
- Mid-term: Adaptive budget (dynamic K/S), causal latent slots, budget controller
- Long-term: Memory instruction tuning, native memory pre-training
Design rationale and prior art analysis: PROPOSAL.md. Experiment execution guide: EXPERIMENTS.md.
MIT