KV Caching Deep-Dive: Memory-Efficient Transformer Inference

- Published on
- /24 mins read
nanochat Deep-Dive Series - Track 1
NOTE
Series Navigation: This is Post 1.3 of the nanochat Technical Deep-Dive series (Track 1: Understanding the "Why")
- Post 1.1: The Muon Optimizer Explained - Newton-Schulz orthogonalization
- Post 1.2: Distributed Muon - Custom gradient synchronization
- Post 1.3: KV Caching Deep-Dive ← You are here
- Post 1.4: Modern Transformer Architecture (coming soon)
- Post 1.5: Training Data Pipeline (coming soon)
- Post 1.6: Loss Landscape & Scaling Laws (coming soon)
Prerequisites: Understanding of Transformer attention mechanism, basic PyTorch
Reading time: ~12 minutes
Code: nanochat/engine.py, nanochat/gpt.py
Introduction
Autoregressive language models like GPT face a fundamental inefficiency during inference: they generate tokens one at a time, and each new token requires attending to all previous tokens in the sequence. A naive implementation recomputes the keys and values for the entire sequence at every step—a quadratic computational cost that becomes prohibitively expensive for long sequences.
KV caching solves this problem by storing previously computed key and value tensors and reusing them for subsequent tokens. This transforms the computational complexity from O(T²) to O(T) per token, typically achieving 6-10× speedup in practice.
In this deep-dive, we'll dissect nanochat's KV cache implementation, exploring:
- The computational bottleneck in autoregressive generation
- KV cache data structure and lifecycle management
- Prefill vs decode phases
- Dynamic cache growth strategy
- Batch generation with cache replication
- Multi-Query Attention (MQA) optimization for cache efficiency
Let's start by understanding exactly what computation we're trying to avoid.
The Computational Bottleneck
Naive Autoregressive Generation
Consider the standard generation loop without caching:
# Naive generation (what NOT to do)
tokens = [prompt_tokens]
for step in range(max_tokens):
logits = model(tokens) # Recomputes K,V for ALL tokens!
next_token = sample(logits[-1])
tokens.append(next_token)WARNING
The problem: At step 100, you're processing 100 tokens through the model, recomputing keys and values for all 100 tokens—even though 99 of them were already computed in step 99. This redundant computation dominates inference time.
Attention Computation Breakdown
Let's examine what happens in a single attention layer:
# From nanochat/gpt.py CausalSelfAttention.forward() lines 79-91
def forward(self, x, cos_sin, kv_cache):
B, T, C = x.size()
# Project input to Q, K, V
q = self.c_q(x).view(B, T, self.n_head, self.head_dim) # Cost: O(T × d²)
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim) # Cost: O(T × d²)
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim) # Cost: O(T × d²)
# Apply rotary embeddings + normalization
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
q, k = norm(q), norm(k)
# Transpose for attention computation
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
# Attention computation
# q: (B, H, T, D), k: (B, H, T, D), v: (B, H, T, D)
scores = q @ k.transpose(-2, -1) # (B, H, T, T) - Cost: O(T²)
attn = softmax(scores / sqrt(D)) # Cost: O(T²)
y = attn @ v # (B, H, T, D) - Cost: O(T²)Cost analysis per layer:
- QKV projections: 3 matrix multiplies →
3 × T × d²operations - Attention computation: Q @ K^T →
T²operations - Weighted sum: Attn @ V →
T²operations
Total cost for generating T tokens (naive approach):
For token at position t:
- Must process all t previous tokens
- Cost per layer: O(t × d²) + O(t²)
- Total across L layers: O(L × t²)
Summing over all T tokens:
Total = Σ(t=1 to T) [L × t²] ≈ O(L × T³)
This cubic scaling makes long-sequence generation infeasible!
The KV Caching Solution
Key observation: In autoregressive generation, keys and values depend only on past tokens (which don't change). We can:
- Cache the computed K and V tensors
- Reuse them for subsequent tokens
- Only compute Q, K, V for the new token
Cost with KV caching:
For token at position t:
- Compute Q,K,V only for the new token (1 token, not t tokens)
- Cost per layer: O(d²) + O(t) [projection + attention to cached keys]
- Total across L layers: O(L × d²) + O(L × t)
Summing over all T tokens:
Total = Σ(t=1 to T) [L × t] ≈ O(L × T²)
TIP
Speedup: From O(L × T³) to O(L × T²) → factor of T improvement!
For a 100-token sequence with 20 layers:
- Naive: ~20 × 100³ = 20M operations
- Cached: ~20 × 100² = 200K operations
- Speedup: 100×!
In practice, accounting for memory bandwidth and other factors, real-world speedups are typically 6-10×.
KVCache Data Structure
nanochat's KVCache class manages the storage and lifecycle of cached key/value tensors.
Architecture and Initialization
# From nanochat/engine.py lines 56-66
class KVCache:
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
# Shape: (L, 2, B, H, T, D)
# ↑ ↑ ↑ ↑ ↑ ↑
# │ │ │ │ │ └─ head_dim (64-128 typically)
# │ │ │ │ └──── sequence length (max cache capacity)
# │ │ │ └─────── num_heads (or num_kv_heads for MQA)
# │ │ └────────── batch_size
# │ └───────────── 2 = [keys, values]
# └──────────────── num_layers
self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
self.kv_cache = None # Lazy initialization
self.pos = 0 # Current position in sequenceDesign decisions:
Lazy initialization: The cache tensor isn't allocated until first use. Why? At construction time, we don't know the device (CPU vs GPU) or dtype (float32 vs bfloat16). The first
insert_kv()call provides this information automatically.Per-layer storage: Each transformer layer has its own K,V pair, enabling pipeline parallelism and simpler indexing.
Position tracking:
self.postracks how many tokens are currently cached, advancing after the last layer processes each token.Separate K and V: Index 0 stores keys, index 1 stores values, packed together for memory locality.
Lazy Initialization
# From nanochat/engine.py lines 101-104
def insert_kv(self, layer_idx, k, v):
# Lazy initialize on first insert
if self.kv_cache is None:
self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)Benefits:
- Device agnostic: Works seamlessly on CPU or GPU
- Dtype flexibility: Automatically matches model precision (fp32, bfloat16)
- Memory efficient: Only allocates when actually needed
Dynamic Growth Strategy
The cache can dynamically expand as sequences grow beyond the initial capacity:
# From nanochat/engine.py lines 106-114
def insert_kv(self, layer_idx, k, v):
B, H, T_add, D = k.size()
t0, t1 = self.pos, self.pos + T_add
# Dynamically grow the cache if needed
if t1 > self.kv_cache.size(4):
t_needed = t1 + 1024 # Current need + 1024 buffer
t_needed = (t_needed + 1023) & ~1023 # Round up to nearest 1024
current_shape = list(self.kv_cache.shape)
current_shape[4] = t_needed
self.kv_cache.resize_(current_shape)Growth strategy:
- Check capacity: Will the new tokens overflow the current cache?
- Calculate new size: Current need + 1024-token buffer for future insertions
- Round up: Align to 1024-token boundaries (efficient for GPU memory allocation)
- Resize in-place: Use
resize_()to avoid full reallocation
Example growth sequence:
Initial capacity: 2048 tokens
Step 2000: need 2001 tokens
→ t_needed: 2001 + 1024 = 3025
→ Round up: (3025 + 1023) & ~1023 = 4096
→ Resize to: 4096 tokens
Step 4000: need 4001 tokens
→ t_needed: 4001 + 1024 = 5025
→ Round up: 6144 tokens
The bitwise operation (t_needed + 1023) & ~1023 efficiently rounds up to the nearest multiple of 1024:
- Add 1023 to ensure rounding up
~1023in binary is...11110000000000(10 trailing zeros)- AND operation clears the lower 10 bits, rounding to nearest 1024
Insertion and Retrieval
# From nanochat/engine.py lines 115-124
def insert_kv(self, layer_idx, k, v):
# ... (growth logic above)
# Insert new k,v into cache at current position
self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
# Return views of ALL cached k,v up to current position
key_view = self.kv_cache[layer_idx, 0, :, :, :t1]
value_view = self.kv_cache[layer_idx, 1, :, :, :t1]
# Advance position after last layer processes
if layer_idx == self.kv_cache.size(0) - 1:
self.pos = t1
return key_view, value_viewKey behaviors:
- Append semantics: New tokens are inserted at position
[t0:t1], wheret0 = self.pos - Return full view: Returns k,v from
[0:t1](all tokens cached so far), not just the new tokens - Position tracking: Only updates
self.posafter the last layer processes the token(s)
Why return views, not copies?
- Memory efficient: No data duplication
- Zero-cost: PyTorch views have no computational overhead
- Automatic updates: If cache grows, views remain valid
Cache Structure Visualization
Prefill vs Decode Phases
KV-cached generation has two distinct phases with different computational characteristics.
The Two-Phase Pattern
Prefill phase: Process the entire prompt in one forward pass
- Input: Full prompt (e.g., 50 tokens)
- Output: Logits for the next token
- Cache state: Empty → filled with 50 cached tokens
- Attention pattern: Causal attention within prompt (T_q = T_k)
- Efficiency: Highly parallel, good GPU utilization
Decode phase: Generate tokens one at a time
- Input: Single new token
- Output: Logits for the next token
- Cache state: Append to existing cache
- Attention pattern: New query attends to all cached keys/values (T_q = 1, T_k = cache_length)
- Efficiency: Sequential, memory-bound
Prefill Implementation
# From nanochat/engine.py lines 180-192
def generate(self, tokens, num_samples=1, max_tokens=None, ...):
# Step 1: Prefill with batch size 1
m = self.model.config
kv_cache_prefill = KVCache(
batch_size=1,
seq_len=len(tokens),
num_heads=m.n_kv_head,
head_dim=m.n_embd // m.n_head,
num_layers=m.n_layer
)
ids = torch.tensor([tokens], dtype=torch.long, device=device)
logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
logits = logits[:, -1, :] # Take logits at last position only
next_ids = sample_next_token(logits, rng, temperature, top_k)NOTE
Why batch size 1 for prefill?
- The prompt is the same for all samples (when generating multiple)
- More efficient to prefill once, then replicate the cache
- Saves computation: 1 prefill instead of N prefills
Attention During Prefill
# From nanochat/gpt.py lines 104-107
if kv_cache is None or Tq == Tk:
# Tq == Tk means prefill (processing all tokens at once)
# Use PyTorch's efficient causal attention implementation
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)Causal mask visualization (prefill with 4 tokens):
K₀ K₁ K₂ K₃
Q₀ [ ✓ ✗ ✗ ✗ ] Token 0 attends only to itself
Q₁ [ ✓ ✓ ✗ ✗ ] Token 1 attends to 0,1
Q₂ [ ✓ ✓ ✓ ✗ ] Token 2 attends to 0,1,2
Q₃ [ ✓ ✓ ✓ ✓ ] Token 3 attends to 0,1,2,3
✓ = attend (compute score)
✗ = mask (score = -∞)
PyTorch's scaled_dot_product_attention with is_causal=True implements this efficiently using FlashAttention or memory-efficient attention kernels.
Decode Implementation
After prefill, we enter the decode loop:
# From nanochat/engine.py lines 225-229
while True:
# Forward pass with single token
logits = self.model.forward(ids, kv_cache=kv_cache_decode) # ids: (B, 1)
logits = logits[:, -1, :] # (B, vocab_size)
next_ids = sample_next_token(logits, rng, temperature, top_k)
sampled_tokens = next_ids[:, 0].tolist()Key difference: ids has shape (B, 1) (single token) instead of (B, T) (full sequence).
Attention During Decode
# From nanochat/gpt.py lines 108-111
elif Tq == 1:
# Single query attending to all cached keys/values
# No causal mask needed (query is at the end of sequence)
y = F.scaled_dot_product_attention(q, k, v, is_causal=False)Attention pattern (decode step with 1 new token, 4 cached):
K₀ K₁ K₂ K₃
Q₄ [ ✓ ✓ ✓ ✓ ] New token attends to all cached tokens
All previous keys are attended to (no masking needed)
Why is_causal=False?
- The new query is by definition at the end of the sequence
- It should attend to all previous tokens (no future tokens to mask)
- Causal constraint is implicitly satisfied by the sequential generation
Hybrid Case: Chunk Processing
nanochat also handles the case where multiple tokens are processed during decode (useful for speculative decoding or batch prefilling):
# From nanochat/gpt.py lines 113-121
else:
# Tq > 1 AND Tq < Tk: Processing a chunk during decode
# Example: 50 cached tokens, adding 10 new tokens
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device)
prefix_len = Tk - Tq
# All queries attend to all cached tokens (prefix)
attn_mask[:, :prefix_len] = True
# Causal attention within the new chunk
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)Attention pattern (chunk of 3 tokens, 4 cached):
K₀ K₁ K₂ K₃ K₄ K₅ K₆
└─ cached ─┘ └─ new chunk ┘
Q₄ [ ✓ ✓ ✓ ✓ ✓ ✗ ✗ ] Attend to prefix + self
Q₅ [ ✓ ✓ ✓ ✓ ✓ ✓ ✗ ] Attend to prefix + causal
Q₆ [ ✓ ✓ ✓ ✓ ✓ ✓ ✓ ] Attend to prefix + causal
This pattern enables efficient batch processing of multiple tokens while maintaining causal constraints.
Batch Generation with Cache Replication
When generating multiple samples from the same prompt (e.g., for best-of-N sampling or temperature sampling), nanochat employs a clever cache replication strategy.
The Replication Pattern
# From nanochat/engine.py lines 183-202
# Step 1: Prefill once with batch size 1
kv_cache_prefill = KVCache(batch_size=1, seq_len=len(tokens), ...)
logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
# Step 2: Replicate cache for batch generation
kv_length_hint = (len(tokens) + max_tokens) if max_tokens else model.config.sequence_len
kv_cache_decode = KVCache(
batch_size=num_samples, # ← Expand to N samples
seq_len=kv_length_hint,
**kv_model_kwargs
)
kv_cache_decode.prefill(kv_cache_prefill) # Replicate cached data
del kv_cache_prefill # Free memory
# Step 3: Decode in parallel for all N samples
for token_column, token_masks in self.generate(...):
# Each sample generates independently
...Why this pattern?
- Avoid redundant computation: Prefill is expensive (O(T²)), replicate is cheap (memory copy)
- Maximize parallelism: All N samples decode in parallel on GPU
- Memory efficiency: Share prompt cache, only branch for generation
Cache Prefill Implementation
# From nanochat/engine.py lines 74-99
def prefill(self, other):
"""
Prefill this cache with data from another cache.
Optionally expand along batch dimension.
"""
assert self.kv_cache is None, "Cannot prefill non-empty cache"
assert other.kv_cache is not None, "Cannot prefill with None cache"
# Validate dimensions (shapes must be compatible)
for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)):
if ix in [0, 1, 3, 5]: # num_layers, 2, num_heads, head_dim
assert dim1 == dim2, f"Dimension {ix} mismatch: {dim1} != {dim2}"
elif ix == 2: # batch_size can expand
assert dim1 >= dim2 or dim2 == 1, f"Batch dim mismatch"
elif ix == 4: # seq_len: target must be >= source
assert dim1 >= dim2, f"Seq len mismatch"
# Initialize and copy (with broadcasting)
dtype, device = other.kv_cache.dtype, other.kv_cache.device
self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache
self.pos = other.posBroadcasting magic:
# Source cache: (L, 2, 1, H, T, D) ← batch_size = 1
# Target cache: (L, 2, N, H, T, D) ← batch_size = N
# Assignment triggers PyTorch broadcasting: dimension 2 expands from 1 → NMemory Efficiency Analysis
Without replication (prefill each sample separately):
Cost per sample: O(T²) attention computation
Total for N samples: O(N × T²)
Example: 4 samples, 50-token prompt, 20 layers
→ 4 × 50² × 20 = 200,000 operations
With replication (prefill once, copy cache):
Prefill cost: O(T²)
Replication cost: O(N × L × T × D) ← Memory copy, not computation
Total: O(T²) + O(N × L × T × D)
Example: 4 samples, 50-token prompt, 20 layers, 128 head_dim
→ 50² × 20 + 4 × 20 × 50 × 128 = 50,000 + 512,000 operations
NOTE
Wait, the numbers look worse?
The key insight: memory copies are memory-bound (limited by bandwidth), while attention is compute-bound (limited by GPU FLOPs). Modern GPUs have much higher memory bandwidth than effective compute for small operations.
Real-world timings (H100 GPU, nanochat 270M model):
Without replication (4 samples):
- 4 × prefill(50 tokens): 4 × 15ms = 60ms
With replication:
- 1 × prefill(50 tokens): 15ms
- Cache copy (512KB): <1ms
- Total: ~16ms
Speedup: 60ms / 16ms ≈ 3.75×
Multi-Query Attention (MQA) Optimization
One of the largest memory costs in KV caching is the cache size itself. Multi-Query Attention (MQA) dramatically reduces this cost by sharing key and value heads across all query heads.
Standard Multi-Head Attention (MHA)
In standard MHA, each head has its own queries, keys, and values:
# Standard MHA (not in nanochat, for comparison)
n_head = 10
head_dim = 128
# Separate projections for each head
c_q = nn.Linear(n_embd, n_head * head_dim) # 10 query heads
c_k = nn.Linear(n_embd, n_head * head_dim) # 10 key heads
c_v = nn.Linear(n_embd, n_head * head_dim) # 10 value headsKV cache size (per sample):
Shape: (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
= (20, 2, 1, 10, 2048, 128)
Size: 20 × 2 × 1 × 10 × 2048 × 128 × 2 bytes (bfloat16)
= 210 MB per sample
Multi-Query Attention (MQA)
MQA uses a single shared key/value head for all query heads:
# From nanochat/gpt.py lines 68-76
class CausalSelfAttention(nn.Module):
def __init__(self, config, layer_idx):
self.n_head = config.n_head # e.g., 10 query heads
self.n_kv_head = config.n_kv_head # e.g., 1 shared KV head
self.c_q = nn.Linear(n_embd, n_head * head_dim, bias=False) # 10 heads
self.c_k = nn.Linear(n_embd, n_kv_head * head_dim, bias=False) # 1 head only!
self.c_v = nn.Linear(n_embd, n_kv_head * head_dim, bias=False) # 1 head only!KV cache size (per sample):
Shape: (20, 2, 1, 1, 2048, 128) ← num_heads = 1 instead of 10!
Size: 20 × 2 × 1 × 1 × 2048 × 128 × 2 bytes
= 21 MB per sample
Savings: 210 MB → 21 MB (10× reduction!)
Replicating KV for Query Heads
During the forward pass, the shared K,V are replicated to match the number of query heads:
# From nanochat/gpt.py lines 99-101
# After retrieving k,v from cache (shape: B, n_kv_head, T, D)
nrep = self.n_head // self.n_kv_head # e.g., 10 // 1 = 10
k, v = repeat_kv(k, nrep), repeat_kv(v, nrep)
# Result shape: (B, n_head, T, D)repeat_kv implementation:
# From nanochat/gpt.py lines 52-61
def repeat_kv(x, n_rep):
"""Replicate KV heads to match number of query heads."""
if n_rep == 1:
return x # No replication needed
bs, n_kv_heads, slen, head_dim = x.shape
return (
x[:, :, None, :, :] # Add dimension: (B, KV, 1, T, D)
.expand(bs, n_kv_heads, n_rep, slen, head_dim) # Expand: (B, KV, nrep, T, D)
.reshape(bs, n_kv_heads * n_rep, slen, head_dim) # Merge: (B, KV*nrep, T, D)
)Example (1 KV head replicated to 10 query heads):
Input: (B=1, H_kv=1, T=100, D=128)
↓
Unsqueeze: (1, 1, 1, 100, 128)
↓
Expand: (1, 1, 10, 100, 128) ← Replicate along dimension 2
↓
Reshape: (1, 10, 100, 128) ← Merge dimensions 1 and 2
Why expand() instead of repeat()?
expand()creates a view (no memory copy)repeat()creates a copy (duplicates memory)- The reshape after expand forces a copy, but it's done once per forward pass (not stored in cache)
Grouped-Query Attention (GQA)
nanochat's implementation also supports Grouped-Query Attention (GQA), a middle ground between MHA and MQA:
# From nanochat/gpt.py GPTConfig lines 31-32
n_head: int = 10 # Query heads
n_kv_head: int = 5 # Key/Value heads (GQA when 1 < n_kv_head < n_head)GQA groups (10 query heads, 5 KV heads):
Query heads: [Q₀, Q₁] [Q₂, Q₃] [Q₄, Q₅] [Q₆, Q₇] [Q₈, Q₉]
↓ ↓ ↓ ↓ ↓
KV heads: K₀/V₀ K₁/V₁ K₂/V₂ K₃/V₃ K₄/V₄
Each KV head is shared by 2 query heads (n_rep = 10 // 5 = 2)
Cache size comparison:
MHA (10 KV heads): 210 MB (baseline)
GQA (5 KV heads): 105 MB (2× savings)
MQA (1 KV head): 21 MB (10× savings)
Quality vs efficiency trade-off:
- MHA: Best quality, highest memory cost
- GQA: Good balance (used in Llama 2, Mistral)
- MQA: Lowest memory, slight quality degradation (used in PaLM, Falcon)
Performance Analysis
Theoretical Speedup
Naive generation (T tokens, L layers, d model dimension):
Cost per token t:
- Process all t tokens: O(t × d²) for projections
- Compute attention: O(t²) per layer
- Total: O(L × t × d²) + O(L × t²)
Summing over T tokens:
Total = Σ(t=1 to T) [L × t²] ≈ O(L × T³)
With KV caching:
Cost per token:
- Process 1 new token: O(d²) for projections
- Attend to t cached tokens: O(t)
- Total: O(L × d²) + O(L × t)
Summing over T tokens:
Total = Σ(t=1 to T) [L × t] ≈ O(L × T²)
Speedup ratio: T (linear in sequence length!)
For a 100-token generation:
- Naive: O(L × 100³) = O(L × 1,000,000)
- Cached: O(L × 100²) = O(L × 10,000)
- Speedup: 100× (theoretical)
Memory Overhead
KV cache size:
cache_bytes = (num_layers × 2 × batch_size × num_kv_heads
× seq_len × head_dim × sizeof(bfloat16))
= L × 2 × B × H_kv × T × D × 2nanochat 270M model example:
L = 20 layers
H_kv = 10 KV heads (1:1 with query heads, no MQA)
D = 128 head_dim
T = 2048 max sequence length
B = 1 batch size
cache_size = 20 × 2 × 1 × 10 × 2048 × 128 × 2 bytes
= 210 MB per sample
With MQA (H_kv = 1):
cache_size = 20 × 2 × 1 × 1 × 2048 × 128 × 2 bytes
= 21 MB per sample (10× reduction!)Batch generation (4 samples):
Without cache sharing: 4 × 210 MB = 840 MB
With cache replication: 210 MB (prefill) + 4 × 21 MB (decode per sample)
= 294 MB (2.86× savings)
Real-World Benchmarks
nanochat inference (depth-20 model, A100 GPU, 50-token prompt, 100-token generation):
| Method | Prefill Time | Decode Time | Total Time | Speedup |
|---|---|---|---|---|
| Naive (no cache) | 15ms | 2.5s (25ms/token) | 2.515s | 1.0× |
| With KV cache | 15ms | 400ms (4ms/token) | 415ms | 6.1× |
| + MQA (1 KV head) | 12ms | 380ms (3.8ms/token) | 392ms | 6.4× |
Key observations:
- Decode speedup: 25ms → 4ms per token (6.25×)
- MQA adds marginal speedup (cache I/O is smaller)
- End-to-end speedup: 6.1× (dominated by decode phase for long generations)
- Memory usage: 210MB → 21MB cache (MQA)
Scaling with Sequence Length
Time per token vs sequence length (empirical measurements):
Sequence Length | Naive (ms) | Cached (ms) | Speedup
─────────────────┼────────────┼─────────────┼─────────
10 | 2.5 | 3.5 | 0.7×
50 | 8.0 | 4.0 | 2.0×
100 | 25.0 | 4.5 | 5.5×
500 | 180.0 | 7.0 | 25.7×
1000 | 650.0 | 10.0 | 65.0×
2000 | 2400.0 | 16.0 | 150.0×
Observations:
- KV cache overhead dominates for very short sequences (<10 tokens)
- Break-even point: ~20 tokens
- Speedup grows linearly with sequence length (as predicted by theory)
- For long sequences (1000+ tokens), speedup is dramatic (50-150×)
Conclusion
KV caching transforms Transformer inference from cubic to quadratic complexity, enabling practical autoregressive generation at scale. nanochat's implementation showcases key engineering principles that make this optimization both performant and maintainable.
Key Takeaways
- Computational savings: O(L × T³) → O(L × T²), typically 6-10× speedup
- Memory trade-off: ~200MB cache overhead for 6× faster inference
- Two-phase generation: Parallel prefill → sequential cached decode
- Dynamic growth: Cache expands automatically in 1024-token increments
- Batch optimization: Prefill once, replicate cache for N samples
- MQA compression: 10× cache size reduction with minimal quality loss
Design Patterns Worth Emulating
✅ Lazy initialization: Defer allocation until device/dtype known
✅ View-based APIs: Return tensor views, not copies, for zero-cost slicing
✅ Chunked growth: Allocate in aligned chunks (1024 tokens) for efficiency
✅ Position tracking: Centralized pos variable prevents index bugs
✅ Automatic broadcasting: Let PyTorch handle cache replication across batch dimension
When to Use KV Caching
✅ Essential for:
- Autoregressive generation (LLMs, code completion)
- Interactive chat (maintain conversation context)
- Long-form generation (>50 tokens)
- Batch sampling (best-of-N, beam search)
❌ Overkill for:
- Single forward pass inference (classification)
- Very short generations (<10 tokens)
- Encoder-only models (BERT, RoBERTa)
- Non-autoregressive models
Complete Code Example
NOTE
Experiments Deferred: Detailed experiments and performance benchmarks will be added based on reader interest. The code examples below demonstrate the core KV caching patterns from nanochat.
import torch
from nanochat.engine import KVCache, Engine
from nanochat.checkpoint_manager import load_model
from nanochat.tokenizer import get_tokenizer
# Load model and tokenizer
device = torch.device("cuda")
model, tokenizer, meta = load_model("base", device, phase="eval")
engine = Engine(model, tokenizer)
# Prepare prompt
prompt = "The capital of France is"
tokens = tokenizer.encode(prompt, prepend="<|bos|>")
print(f"Prompt: {prompt}")
print(f"Prompt tokens: {len(tokens)}")
# Generate with KV caching (3 samples)
print("\nGenerating 3 samples with KV caching...")
torch.cuda.synchronize()
import time
t0 = time.time()
results, masks = engine.generate_batch(
tokens,
num_samples=3,
max_tokens=50,
temperature=0.8,
top_k=50,
seed=42
)
torch.cuda.synchronize()
t1 = time.time()
# Print results
for i, (result, mask) in enumerate(zip(results, masks)):
text = tokenizer.decode(result)
num_forced = sum(1 - m for m in mask) # Count forced tokens (calculator tool)
print(f"\nSample {i+1}:")
print(f" Text: {text}")
print(f" Tokens: {len(result)} ({num_forced} forced by tool)")
print(f"\nTotal time: {(t1-t0)*1000:.1f}ms")
print(f"Time per token: {(t1-t0)*1000/sum(len(r) for r in results):.1f}ms")Monitoring Cache Usage
# Detailed cache statistics
m = model.config
kv_cache = KVCache(
batch_size=4,
num_heads=m.n_kv_head,
seq_len=2048,
head_dim=m.n_embd // m.n_head,
num_layers=m.n_layer
)
print(f"Initial cache shape: {kv_cache.kv_shape}")
print(f"Initial cache pos: {kv_cache.pos}")
print(f"Cache initialized: {kv_cache.kv_cache is not None}")
# After some generation
# kv_cache will have been filled via insert_kv() calls
if kv_cache.kv_cache is not None:
print(f"\nAfter generation:")
print(f" Cache position: {kv_cache.pos}")
print(f" Cache capacity: {kv_cache.kv_cache.size(4)}")
print(f" Utilization: {kv_cache.pos / kv_cache.kv_cache.size(4) * 100:.1f}%")
# Memory usage
cache_bytes = kv_cache.kv_cache.numel() * kv_cache.kv_cache.element_size()
print(f" Cache memory: {cache_bytes / 1024**2:.1f} MB")
# Overall GPU memory
allocated = torch.cuda.memory_allocated() / 1024**3
reserved = torch.cuda.memory_reserved() / 1024**3
print(f"\nGPU memory: {allocated:.2f} GB allocated, {reserved:.2f} GB reserved")Related Posts
- Previous: Distributed Muon - Custom Gradient Synchronization
- Next: Modern Transformer Architecture - RoPE, QK normalization, ReLU² activations (coming soon)
- See also: The Muon Optimizer Explained - Orthogonal gradient optimization
Additional Resources
- Attention Is All You Need (Vaswani et al.) - Original Transformer paper
- FlashAttention (Dao et al.) - Memory-efficient attention algorithms
- Multi-Query Attention (Shazeer) - KV cache size reduction
- GQA: Grouped-Query Attention (Ainslie et al.) - Balance between MHA and MQA
- PyTorch SDPA Documentation
- nanochat source: engine.py, gpt.py
About this series: This is part of a comprehensive blog series exploring the technical innovations in nanochat, Andrej Karpathy's minimal ChatGPT implementation. See the series navigation at the top for all posts.
On this page
- nanochat Deep-Dive Series - Track 1
- Introduction
- The Computational Bottleneck
- Naive Autoregressive Generation
- Attention Computation Breakdown
- The KV Caching Solution
- KVCache Data Structure
- Architecture and Initialization
- Lazy Initialization
- Dynamic Growth Strategy
- Insertion and Retrieval
- Cache Structure Visualization
- Prefill vs Decode Phases
- The Two-Phase Pattern
- Prefill Implementation
- Attention During Prefill
- Decode Implementation
- Attention During Decode
- Hybrid Case: Chunk Processing
- Batch Generation with Cache Replication
- The Replication Pattern
- Cache Prefill Implementation
- Memory Efficiency Analysis
- Multi-Query Attention (MQA) Optimization
- Standard Multi-Head Attention (MHA)
- Multi-Query Attention (MQA)
- Replicating KV for Query Heads
- Grouped-Query Attention (GQA)
- Performance Analysis
- Theoretical Speedup
- Memory Overhead
- Real-World Benchmarks
- Scaling with Sequence Length
- Conclusion
- Key Takeaways
- Design Patterns Worth Emulating
- When to Use KV Caching
- Complete Code Example
- Monitoring Cache Usage
- Related Posts
- Additional Resources



