
Ivis4ml/AdaMem


AdaMem

Query-Adaptive Latent Working Memory for Long-Context Language Models

AdaMem compresses an arbitrarily long context into a fixed budget of K soft memory tokens, conditioned on the current query. The decoder LLM operates on these K tokens instead of the raw long context. This reduces decoder-side compute, the KV-cache size during generation, and the cost of serving repeated queries over the same context.

```text
Long Context ──chunk──> ChunkEncoder ──> Coarse Memory H (cacheable, query-independent)
                                              │
Query q ───────────────────────────> LatentResampler(H, q) ──> K memory tokens Z
                                                                      │
                                                          [Z ; q] ──> Frozen LLM ──> answer
```

Key Ideas

  • Query-adaptive: unlike query-agnostic compressors such as ICAE or AutoCompressor, the resampler selects and fuses the information relevant to the current query
  • Two-stage caching: Stage A (ChunkEncoder) runs offline and its output is cacheable; Stage B (LatentResampler) is lightweight and runs per query. Coarse memory is computed once per context and reused across queries
  • Fixed decoder budget: the LLM always sees K memory tokens (default 32) regardless of context length, i.e. 128x compression at 4K tokens (4,096 / 32) and roughly 3,000x at 100K tokens (100,000 / 32 ≈ 3,125)
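
The budget arithmetic can be checked directly from the default hyperparameters (P = 256 tokens per chunk, S = 2 summary states per chunk, K = 32). This is a minimal sketch, assuming a final partial chunk is padded and counted as a full chunk; `memory_budget` is a hypothetical helper, not part of the package.

```python
import math


def memory_budget(context_len, chunk_size=256, summary_per_chunk=2, n_latents=32):
    """Sequence length at each stage for one context, using the documented defaults."""
    n_chunks = math.ceil(context_len / chunk_size)  # N fixed-size chunks (last one padded)
    coarse_len = n_chunks * summary_per_chunk       # M = N * S coarse memory states
    ratio = context_len / n_latents                 # raw context tokens per decoder memory token
    return {"N": n_chunks, "M": coarse_len, "K": n_latents, "ratio": ratio}


for length in (4_096, 100_000):
    print(length, memory_budget(length))
```

Under these defaults a 4,096-token context yields N = 16 chunks and M = 32 coarse states, the same size as K, so Stage A does most of the compression at that length. The resampler's fixed K starts to matter as contexts grow (M = 782 at 100K tokens).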

Architecture

| Stage | Module | Input | Output | Trainable |
|---|---|---|---|---|
| A | ChunkEncoder | Context token embeddings | Coarse memory [B, M, D], M = N * S (N chunks, S summary states per chunk) | Yes |
| B | LatentResampler | Coarse memory + query | Memory tokens [B, K, D] | Yes |
| | Projection | Memory tokens | LLM-dim memory tokens | Yes |
| C | SoftPrefixInterface | Memory tokens + decoder prompt | LLM logits + loss | No (frozen LLM) |

The design follows the BLIP-2 paradigm. ChunkEncoder plays the role of BLIP-2's image encoder, producing query-independent representations that can be cached (unlike BLIP-2's frozen image encoder, it is trained). LatentResampler plays the role of the Q-Former, whose learnable latent queries cross-attend to the encoder output.
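
A minimal single-head sketch of Stage B in plain Python illustrates why the output size depends only on K and not on M. It assumes that FiLM conditioning applies a query-derived per-dimension scale and shift to the learnable latents before they cross-attend to H. `film_params`, `resample`, and the weight layout are hypothetical. The real `resampler.py` presumably adds learned projections, multiple heads, and stacked layers.

```python
import math
import random


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def film_params(query_vec, W_gamma, W_beta):
    """Hypothetical FiLM head: per-dimension scale and shift from a pooled query vector."""
    gamma = [1.0 + g for g in matvec(W_gamma, query_vec)]  # centred on identity scaling
    beta = matvec(W_beta, query_vec)
    return gamma, beta


def resample(H, query_vec, latents, W_gamma, W_beta):
    """Map coarse memory H [M, D] to K memory tokens [K, D], where K = len(latents)."""
    d = len(H[0])
    gamma, beta = film_params(query_vec, W_gamma, W_beta)
    Z = []
    for z in latents:
        zq = [g * x + b for x, g, b in zip(z, gamma, beta)]  # query-conditioned latent query
        scores = [sum(a * h for a, h in zip(zq, row)) / math.sqrt(d) for row in H]
        weights = softmax(scores)  # attention weights over all M coarse states
        Z.append([sum(w * row[j] for w, row in zip(weights, H)) for j in range(d)])
    return Z


rng = random.Random(0)
D, Dq, K = 8, 4, 32
latents = [[rng.gauss(0, 1) for _ in range(D)] for _ in range(K)]  # learnable in a real model
W_gamma = [[rng.gauss(0, 0.1) for _ in range(Dq)] for _ in range(D)]
W_beta = [[rng.gauss(0, 0.1) for _ in range(Dq)] for _ in range(D)]
q = [rng.gauss(0, 1) for _ in range(Dq)]
for M in (32, 782):  # roughly the coarse-memory lengths for 4K and 100K contexts
    H = [[rng.gauss(0, 1) for _ in range(D)] for _ in range(M)]
    Z = resample(H, q, latents, W_gamma, W_beta)
    print(M, "->", len(Z), "x", len(Z[0]))
```

Each memory token is a convex combination of coarse states, weighted by how well the query-conditioned latent matches them. Changing the query therefore changes which evidence each of the K slots fuses.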

Installation

```bash
pip install -e .

# With evaluation dependencies
pip install -e ".[eval]"

# With development dependencies
pip install -e ".[dev]"
```

Requires Python >= 3.10 and PyTorch >= 2.1.

Quick Start

```python
from adamem.config import AdaMemConfig
from adamem.models.adamem import AdaMem

# Load config and build model
config = AdaMemConfig.from_yaml("configs/default.yaml")
model = AdaMem(config.model)
model.setup_llm("Qwen/Qwen2.5-0.5B")

# context_ids, context_mask, query_ids, decoder_ids and labels are token-ID
# tensors produced by your data pipeline (see adamem/data/)

# Forward pass (training)
output = model(
    context_ids=context_ids,        # [B, L] long context
    context_mask=context_mask,      # [B, L] mask over context tokens
    query_ids=query_ids,            # [B, Lq] question
    decoder_input_ids=decoder_ids,  # [B, Ld] = [query; answer] for teacher forcing
    labels=labels,                  # [B, Ld] with -100 on query positions
)
loss = output.loss

# Generation (inference)
tokens = model.generate(
    context_ids=context_ids,
    query_ids=query_ids,
    max_new_tokens=256,
)

# Cache coarse memory once, then answer many queries against it
coarse = model.encode_context(context_ids, context_mask)
answers = [
    model.generate(query_ids=q, coarse_memory=coarse, max_new_tokens=256)
    for q in queries
]
```
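
The multi-query pattern above amounts to memoizing Stage A per context. The sketch below keys a cache on the context token IDs, so each distinct context is encoded once and every later query only pays for Stage B and decoding. `CoarseMemoryCache` is hypothetical, and the toy callables stand in for `model.encode_context` and `model.generate`.

```python
from typing import Callable, Dict, List, Sequence, Tuple


class CoarseMemoryCache:
    """Memoize query-independent Stage A output; only Stage B and decoding run per query."""

    def __init__(self, encode: Callable, generate: Callable):
        self.encode = encode      # stand-in for model.encode_context
        self.generate = generate  # stand-in for model.generate(query_ids=..., coarse_memory=...)
        self._store: Dict[Tuple[int, ...], object] = {}
        self.encode_calls = 0

    def answer(self, context_ids: Sequence[int], query_ids: Sequence[int]):
        # Convert to plain ints so lists and 1-D tensors give the same hashable key;
        # a content hash would be a cheaper key for very long contexts.
        key = tuple(int(t) for t in context_ids)
        if key not in self._store:
            self._store[key] = self.encode(context_ids)
            self.encode_calls += 1
        return self.generate(query_ids, self._store[key])


def toy_encode(ctx):
    return sum(ctx)  # toy "coarse memory"


def toy_generate(query, coarse):
    return [coarse + t for t in query]  # toy "answer tokens"


cache = CoarseMemoryCache(toy_encode, toy_generate)
context = list(range(10_000))
answers: List[List[int]] = [cache.answer(context, q) for q in ([1], [2, 3], [4])]
print(cache.encode_calls)  # 1
```

A serving deployment would also bound this cache, for example with LRU eviction, because each entry holds a full coarse memory of size M x D.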

Data Preparation

Phase 1 pretraining uses FineWeb (sample-10BT, ~30GB). By default, the training script streams it from the Hugging Face Hub, but downloading it locally is recommended for faster I/O:

```bash
bash scripts/download_fineweb.sh
```

This downloads the parquet files to /home/greenland-user/data/fineweb-10BT/. The path is configured via data.pretrain_local_path in experiments/configs/pretrain.yaml. When that path is set and the data exists, the training script reads from disk instead of streaming.
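
A sketch of the local-versus-streaming decision described above follows. It assumes that "data exists" means the directory contains at least one `.parquet` file. `choose_pretrain_source` is hypothetical and only returns a description of what to load. With the Hugging Face `datasets` library, the two branches would correspond to `load_dataset("parquet", data_files=...)` and `load_dataset("HuggingFaceFW/fineweb", name="sample-10BT", streaming=True)`.

```python
import tempfile
from pathlib import Path
from typing import Optional


def choose_pretrain_source(local_path: Optional[str]) -> dict:
    """Prefer local parquet shards when configured and present; otherwise stream from the Hub."""
    if local_path and Path(local_path).is_dir():
        files = sorted(Path(local_path).rglob("*.parquet"))
        if files:
            return {"mode": "local", "data_files": [str(f) for f in files]}
    # Fallback: stream the FineWeb sample-10BT subset from the Hugging Face Hub
    return {"mode": "stream", "repo": "HuggingFaceFW/fineweb", "name": "sample-10BT"}


with tempfile.TemporaryDirectory() as tmp:
    print(choose_pretrain_source(tmp)["mode"])  # stream (no parquet files yet)
    (Path(tmp) / "000_00000.parquet").touch()
    print(choose_pretrain_source(tmp)["mode"])  # local
```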

Training

Three-phase training recipe (LLM stays frozen in Phases 1-2):

| Phase | Objective | Data |
|---|---|---|
| 1. Memory pretraining | Continuation loss + KL distillation against the full-context LLM | FineWeb (1-2B tokens) |
| 2. Task fine-tuning | Answer NLL + contrastive + diversity regularization | LongBench, HotpotQA, NarrativeQA |
| 3. Decoder unfreezing (optional) | Same as Phase 2, with the top 2-4 LLM layers unfrozen | Same as Phase 2 |
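
Two of these objectives can be made concrete on toy distributions: the masked answer NLL and the KL distillation term. This is a minimal sketch under common conventions. Labels of -100 are ignored, as in the Quick Start. The KL is taken from the full-context teacher to the memory-conditioned student. All helper names are hypothetical, the contrastive and diversity terms are not shown, and how the terms are weighted is not specified here.

```python
import math
from typing import List, Sequence, Tuple

IGNORE = -100  # label value excluded from the loss


def build_decoder_inputs(query: Sequence[int], answer: Sequence[int]) -> Tuple[List[int], List[int]]:
    """decoder_input_ids = [query; answer]; labels mask the query span with -100."""
    return list(query) + list(answer), [IGNORE] * len(query) + list(answer)


def log_softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def masked_nll(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Mean negative log-likelihood over positions whose label is not -100.

    Assumes logits[t] already predicts labels[t] (any causal shift applied beforehand).
    """
    terms = [-log_softmax(row)[y] for row, y in zip(logits, labels) if y != IGNORE]
    return sum(terms) / max(len(terms), 1)


def kl_teacher_student(teacher_logits: Sequence[float], student_logits: Sequence[float]) -> float:
    """KL(p_teacher || p_student) at one position; the teacher sees the full context."""
    lp, lq = log_softmax(teacher_logits), log_softmax(student_logits)
    return sum(math.exp(a) * (a - b) for a, b in zip(lp, lq))


dec_ids, labels = build_decoder_inputs([11, 12, 13], [7, 8])
print(dec_ids, labels)  # [11, 12, 13, 7, 8] [-100, -100, -100, 7, 8]
```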

```bash
# Phase 1
python experiments/step1_pretrain.py --config configs/default.yaml --output_dir outputs/pretrain

# Phase 2
python experiments/step2_finetune.py \
    --config experiments/configs/finetune_longbench.yaml \
    --checkpoint outputs/pretrain/final \
    --output_dir outputs/phase2

# Evaluation
python experiments/step3_evaluate.py \
    --checkpoint outputs/phase2/final \
    --benchmarks hotpotqa narrativeqa longbench_multidoc \
    --output_dir outputs/results

# Full pipeline
bash experiments/run_all.sh
```
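
Phase 3 unfreezes the top 2-4 LLM layers. The sketch below selects those parameters by name. It assumes Hugging Face-style parameter names of the form `model.layers.<i>.`, as used by Qwen2 models. `params_to_unfreeze` is hypothetical. In PyTorch, the selected parameters would get `requires_grad = True` while all others stay frozen.

```python
import re
from typing import Iterable, List

LAYER_RE = re.compile(r"(?:^|\.)layers\.(\d+)\.")


def params_to_unfreeze(param_names: Iterable[str], num_layers: int, top_n: int) -> List[str]:
    """Return names of parameters that belong to the top `top_n` decoder layers."""
    if not 1 <= top_n <= num_layers:
        raise ValueError("top_n must be between 1 and num_layers")
    first = num_layers - top_n  # index of the lowest layer to unfreeze
    selected = []
    for name in param_names:
        match = LAYER_RE.search(name)
        if match and int(match.group(1)) >= first:
            selected.append(name)
    return selected


# Toy parameter list for a 24-layer decoder
names = [f"model.layers.{i}.mlp.down_proj.weight" for i in range(24)]
names += ["model.embed_tokens.weight", "model.norm.weight", "lm_head.weight"]
print(params_to_unfreeze(names, num_layers=24, top_n=2))
# ['model.layers.22.mlp.down_proj.weight', 'model.layers.23.mlp.down_proj.weight']
```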

Wrapper scripts are also available:

```bash
adamem-train --phase 1 --config configs/default.yaml
adamem-eval --checkpoint outputs/phase2/final
```

Evaluation

Benchmarks grouped by information structure:

| Group | Tasks | Expected behavior |
|---|---|---|
| A: Distributed fusion | Multi-doc QA, multi-hop reasoning | AdaMem's strength |
| B: Exact local recall | Needle-in-a-haystack, number extraction | AdaMem's weakness |
| C: Redundancy-heavy | Summarization, theme extraction | High compression is acceptable |
| D: Sparse lookup | Single-chunk factoid QA | RAG is competitive |
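
QA results are scored with F1 and EM (see `evaluation/metrics.py`). The sketch below uses the widely used SQuAD-style definitions: lowercase the text, strip punctuation and English articles, then compare bags of tokens. Whether `metrics.py` normalizes answers in exactly this way is an assumption.

```python
import re
import string
from collections import Counter

PUNCT = set(string.punctuation)


def normalize(text: str) -> str:
    text = text.lower()
    text = "".join(ch for ch in text if ch not in PUNCT)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(pred: str, gold: str) -> float:
    return float(normalize(pred) == normalize(gold))


def token_f1(pred: str, gold: str) -> float:
    p, g = normalize(pred).split(), normalize(gold).split()
    if not p or not g:
        return float(p == g)  # both empty -> 1.0, otherwise 0.0
    overlap = sum((Counter(p) & Counter(g)).values())
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(p), overlap / len(g)
    return 2 * precision * recall / (precision + recall)


print(exact_match("The Eiffel Tower!", "eiffel tower"))  # 1.0
print(round(token_f1("Paris, France", "Paris"), 3))      # 0.667
```

Token F1 gives partial credit for verbose answers, which matters for the Group A tasks where a compressed memory may yield a correct but loosely phrased answer.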

Project Structure

```text
adamem/
    config.py              # AdaMemConfig with YAML serialization
    data/
        chunking.py        # Fixed-size chunk splitting
        pretraining.py     # Phase 1 data pipeline
        task_data.py       # Phase 2 TaskDataset with format parsers
    models/
        adamem.py          # Full pipeline: ChunkEncoder -> Resampler -> Decoder
        chunk_encoder.py   # Stage A: per-chunk transformer with summary states
        resampler.py       # Stage B: FiLM-conditioned latent resampler
        projection.py      # Linear / MLP projection to LLM dim
        interfaces/
            soft_prefix.py # Primary: prepend memory tokens to decoder input
            prefix_kv.py   # Alternative: inject as KV pairs (experimental)
    training/
        losses.py          # Contrastive, diversity, distillation losses
        trainer.py         # Training loop utilities
    evaluation/
        benchmarks.py      # Benchmark loading and parsing
        metrics.py         # F1, EM, ROUGE scoring
    utils/
        checkpoints.py     # Save/load with unified protocol
        misc.py            # Logging, seeding, device utilities
experiments/
    step0_sanity_check.py  # Smoke tests on real model
    step1_pretrain.py      # Phase 1 training script
    step2_finetune.py      # Phase 2 training script
    step3_evaluate.py      # Evaluation across benchmark groups
    step4_ablation.py      # Ablation study manager
    step5_baselines.py     # Baseline comparisons
    step6_analysis.py      # Visualization and analysis
    run_all.sh             # End-to-end pipeline
    configs/               # Experiment-specific YAML configs
configs/
    default.yaml           # Default configuration
tests/
    test_models.py         # 24 tests: smoke, integration, train-inference consistency
```
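
`data/chunking.py` performs fixed-size chunk splitting before Stage A. The sketch below assumes non-overlapping chunks and a right-padded last chunk with a 0/1 mask. `split_into_chunks` and `pad_id` are hypothetical; the real padding and overlap behavior may differ.

```python
from typing import List, Sequence, Tuple


def split_into_chunks(ids: Sequence[int], chunk_size: int = 256, pad_id: int = 0
                      ) -> Tuple[List[List[int]], List[List[int]]]:
    """Split a token sequence into non-overlapping chunks of `chunk_size` (P).

    Returns (chunks, masks), each of shape [N, P]; mask is 1 for real tokens, 0 for padding.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    chunks, masks = [], []
    for start in range(0, len(ids), chunk_size):
        piece = list(ids[start:start + chunk_size])
        pad = chunk_size - len(piece)
        chunks.append(piece + [pad_id] * pad)
        masks.append([1] * len(piece) + [0] * pad)
    return chunks, masks


chunks, masks = split_into_chunks(list(range(1, 601)), chunk_size=256)
print(len(chunks), sum(masks[-1]))  # 3 88
```

With S = 2 summary states per chunk, these three chunks would give a coarse memory of M = 6 states for this 600-token context.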

Configuration

All settings are in a single YAML file. Key hyperparameters:

| Parameter | Default | Description |
|---|---|---|
| K (resampler.n_latents) | 32 | Memory token budget |
| S (chunk_encoder.summary_per_chunk) | 2 | Summary states per chunk |
| P (chunk_encoder.chunk_size) | 256 | Tokens per chunk |
| Base LLM | Qwen2.5-0.5B | Frozen decoder |

See configs/default.yaml for the full configuration.

Hypotheses

  • H1: Query-adaptive latent memory outperforms query-agnostic compression (ICAE, AutoCompressor) on distributed evidence tasks
  • H2: Under the same decoder token budget, adaptive memory outperforms top-k retrieval on tasks requiring fusion of many weak evidence pieces
  • H3: Two-stage caching (offline coarse memory + online resampling) matches full-context quality while reducing repeated-query cost

Roadmap

See TODO.md for the current status and planned work, including:

  • Near-term: Phase 1/2 pilots, evaluation, ablations
  • Mid-term: Adaptive budget (dynamic K/S), causal latent slots, budget controller
  • Long-term: Memory instruction tuning, native memory pre-training

Design rationale and prior art analysis: PROPOSAL.md. Experiment execution guide: EXPERIMENTS.md.

License

MIT
