Memory Optimization Techniques: Gradient Accumulation & Mixed Precision

Track 2: Practical Guides - Post 2.6 of 6
This final post in Track 2 covers practical memory optimization strategies: gradient accumulation, mixed precision training, sequence length management, optimizer state optimization, distributed memory, and profiling tools. View all posts in this track →
One H100 has 80GB—here's why you'll run out anyway
Memory management was my biggest surprise when first training LLMs. Weights are tiny compared to optimizer states, gradients, and activations. Understanding this breakdown is the first step to fitting larger models on smaller hardware.
80GB sounds like a lot. Then you add optimizer states, gradients, and activations. Suddenly your 150M model won't fit.
TL;DR: Gradient accumulation gives you 4× batch size without 4× memory. bfloat16 halves memory with no accuracy loss. Muon needs half the optimizer memory of Adam. These techniques let a single 24GB GPU train 150M parameter models.
The OOM that cost a week: Consider a common scenario: training a 150M model on an RTX 4090 (24GB) that crashes at step 50,000 with "CUDA out of memory"—after running fine for two days. The cause: intermediate tensors linger and PyTorch's caching allocator fragments over time. The fix is a single line: torch.cuda.empty_cache() every 1000 steps. But if you haven't saved checkpoints recently, you lose everything and have to restart. Memory management isn't just about fitting your model—it's about keeping it stable through long training runs. The techniques in this post prevent that crash before it happens.
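As a concrete illustration, here is a minimal sketch of that defensive pattern, assuming a nanochat-style model that returns a loss when called as model(x, y); the intervals and checkpoint filename are illustrative, not taken from nanochat:
import torch

def train_with_housekeeping(model, optimizer, train_loader, num_iterations,
                            empty_cache_every=1000, checkpoint_every=5000):
    """Training loop with periodic cache clearing and checkpointing (sketch)."""
    for step in range(num_iterations):
        x, y = next(train_loader)
        loss = model(x, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        if step % empty_cache_every == 0:
            torch.cuda.empty_cache()  # release cached, unused blocks back to the driver
        if step % checkpoint_every == 0:
            torch.save({"model": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "step": step}, f"ckpt_{step:06d}.pt")  # don't lose two days of work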
Memory is the primary bottleneck when training language models. A single H100 GPU has 80GB of memory—sounds like a lot until you realize:
- A 768-dim, 12-layer model (~60M parameters) in bfloat16 uses ~120MB for weights
- Gradients add another ~120MB
- Optimizer states (Adam momentum + variance, kept in fp32) add ~480MB
- Activations for a single forward pass (batch_size=32, seq_len=2048) use ~8GB
- Total: ~9GB per GPU, and that's for a tiny model!
Scale to a 1280-dim, 20-layer model (~150M parameters) with batch_size=32, and you're looking at ~20GB per GPU. Run out of memory, and training stops.
Eight techniques get you back under budget:
- Gradient Accumulation: Simulate larger batches without OOM
- Mixed Precision Training: bfloat16 for 2x memory savings
- Sequence Length Management: Dynamic batching strategies
- Optimizer State Optimization: Choosing the right optimizer
- Distributed Training: Splitting work across GPUs
- Inference Optimizations: KV caching, batch size tuning
- Advanced Techniques: Gradient checkpointing, activation compression
- Memory Profiling: Tools to diagnose bottlenecks
Table of Contents
- Memory Breakdown: Where Does It All Go?
- Gradient Accumulation
- Mixed Precision Training
- Sequence Length Management
- Optimizer Choice
- Distributed Training Memory
- Inference Memory Optimization
- Advanced: Gradient Checkpointing
- Memory Profiling
- Best Practices & Common Pitfalls
Optimizer states consume 4-8 bytes per parameter
Training Memory Components
For a model with P parameters, batch size B, sequence length T, and embedding dimension d:
| Component | Memory (per parameter) | Total | Notes |
|---|---|---|---|
| Model Weights | 2 bytes (bf16) | 2P | The model parameters |
| Gradients | 2 bytes (bf16) | 2P | Stored during backward pass |
| Optimizer State (Adam) | 8 bytes | 8P | Momentum (4B) + variance (4B) |
| Optimizer State (Muon) | 4 bytes | 4P | Momentum only (4B) |
| Activations | Varies | ~12 * B * T * d * n_layers | Intermediate layer outputs |
Example: 60M parameter model, Adam optimizer:
- Weights: 60M × 2B = 120MB
- Gradients: 60M × 2B = 120MB
- Adam states: 60M × 8B = 480MB
- Total parameter memory: 720MB
Activations dominate for large batches: with batch_size=32, seq_len=2048, d=768, n_layers=12:
- Activations: ~12 × 32 × 2048 × 768 × 12 ≈ 8GB
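As a back-of-the-envelope check, the same arithmetic can be wrapped in a small estimator; the activation constant of 12 is the heuristic from the table above, so treat the result as an order-of-magnitude figure:
def estimate_training_memory_gb(n_params, batch_size, seq_len, d_model, n_layers,
                                weight_bytes=2, optimizer_bytes=8):
    """Rough training-memory estimate in GB (bf16 weights/gradients, Adam by default)."""
    weights = n_params * weight_bytes
    gradients = n_params * weight_bytes
    optimizer_state = n_params * optimizer_bytes      # Adam: 8 B/param, Muon: 4 B/param
    activations = 12 * batch_size * seq_len * d_model * n_layers  # heuristic constant
    return (weights + gradients + optimizer_state + activations) / 1024**3

print(f"{estimate_training_memory_gb(60e6, 32, 2048, 768, 12):.1f} GB")  # ≈ 7.4 GB
Real usage lands somewhat higher once CUDA workspaces and allocator overhead are included.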
GPU Memory Budget Calculator
Estimate memory requirements for training your model
Tips to reduce memory:
- Use gradient checkpointing to trade compute for memory
- Reduce batch size and use gradient accumulation
- Use mixed precision (BF16) training
- Consider ZeRO optimization for multi-GPU setups
- Muon uses ~50% less optimizer state memory than AdamW (4 bytes vs 8 bytes per parameter)
Memory Scaling
| Scale Factor | Impact |
|---|---|
| Double parameters | +2x model memory, +2x gradient memory, +2x optimizer memory |
| Double batch size | +2x activation memory (no change to model/optimizer) |
| Double sequence length | +2x activation memory (no change to model/optimizer) |
| Switch Adam → Muon | -50% optimizer memory (8 bytes → 4 bytes per param) |
| Add gradient accumulation | No extra memory (same effective batch, different compute) |
Key insight: Activations scale with batch size and sequence length, but optimizer state scales only with parameters.
For your GPU budget, this means: if you're memory-constrained, tackle activations first (smaller batch, gradient accumulation). Optimizer choice matters less for memory—but Muon still saves 50% on optimizer state.
For your training schedule, this means: don't start with max batch size. Start small, monitor GPU memory, and scale up until you hit 85-90% utilization. Leaving 10-15% headroom prevents OOM from activation spikes.
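One way to find that headroom empirically is a scaling probe; a sketch, assuming PyTorch 2.x on a CUDA device and a nanochat-style model(x, y) that returns a scalar loss (the 90% threshold and dummy data are illustrative):
import torch

def find_device_batch_size(model, seq_len, vocab_size, start=8, headroom=0.90):
    """Double the batch size until peak memory passes `headroom` of total GPU memory (sketch)."""
    total = torch.cuda.get_device_properties(0).total_memory
    batch_size = start
    while True:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        try:
            x = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
            y = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
            loss = model(x, y)
            loss.backward()
            model.zero_grad(set_to_none=True)
        except torch.cuda.OutOfMemoryError:
            return batch_size // 2                 # last size that fit
        if torch.cuda.max_memory_allocated() > headroom * total:
            return batch_size                      # already near the limit
        batch_size *= 2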
Gradient accumulation simulates large batches without OOM
The Problem
You want total_batch_size = 524,288 tokens, but:
- Your GPU can only fit device_batch_size = 16 sequences of seq_len = 2048
- That's only 16 × 2048 = 32,768 tokens per step
- Gap: You need 16x more tokens per update!
The Solution: Gradient Accumulation
Idea: Accumulate gradients across multiple forward/backward passes before stepping the optimizer.
# Desired: total_batch_size = 524,288 tokens
# Reality: device_batch_size = 16, seq_len = 2048 => 32,768 tokens/step
# Solution: grad_accum_steps = 524,288 / 32,768 = 16
grad_accum_steps = total_batch_size // (device_batch_size * seq_len * world_size)
for step in range(num_iterations):
    # Accumulate gradients over multiple micro-batches
    for micro_step in range(grad_accum_steps):
        x, y = next(data_loader)
        loss = model(x, y)
        loss = loss / grad_accum_steps  # Normalize: each .backward() sums gradients
        loss.backward()
    # Step optimizer once
    optimizer.step()
    optimizer.zero_grad()
Why Scale Loss by grad_accum_steps?
Without scaling:
# Micro-batch 1: loss=2.5 → backward() adds gradients
# Micro-batch 2: loss=2.3 → backward() adds more gradients
# Result: gradients are 2x too large (sum of 2 losses)
For your training loop, this means: always divide loss by grad_accum_steps before .backward(). Miss this, and your effective learning rate silently multiplies by your accumulation factor—training will diverge.
For your debugging, this means: if training diverges after adding gradient accumulation, check the loss scaling first. It's the #1 gradient accumulation bug.
With scaling:
# Micro-batch 1: loss=2.5/2=1.25 → backward()
# Micro-batch 2: loss=2.3/2=1.15 → backward()
# Result: gradients average across micro-batches (correct)
Memory Impact
Zero extra memory cost! Gradients are reused across micro-batches:
# Iteration 1: forward → backward (gradients stored)
# Iteration 2: forward → backward (gradients ADDED to existing)
# Iteration 3: forward → backward (gradients ADDED again)
# ...
# Optimizer step (gradients cleared)
Only one batch of activations in memory at a time.
Gradient Accumulation Simulator
Visualize how gradient accumulation enables larger effective batch sizes
Key insight: Gradient accumulation achieves the same mathematical result as a large batch, but uses only the memory of a small batch. The gradients from each micro-batch are averaged together before the optimizer step. This is mathematically equivalent to processing all samples at once, but allows training on GPUs with limited memory.
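You can check the equivalence yourself on a toy model: one batch of 8 samples gives (up to floating-point noise) the same gradient as 4 micro-batches of 2 with the loss scaled by grad_accum_steps. A quick sanity check:
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(16, 1)
x, y = torch.randn(8, 16), torch.randn(8, 1)

# (a) One full batch of 8
model.zero_grad()
nn.functional.mse_loss(model(x), y).backward()
full_grad = model.weight.grad.clone()

# (b) Four micro-batches of 2, each loss divided by grad_accum_steps
model.zero_grad()
grad_accum_steps = 4
for i in range(grad_accum_steps):
    xb, yb = x[2 * i:2 * i + 2], y[2 * i:2 * i + 2]
    loss = nn.functional.mse_loss(model(xb), yb) / grad_accum_steps
    loss.backward()

print(torch.allclose(full_grad, model.weight.grad, atol=1e-6))  # True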
nanochat Implementation
# From scripts/base_train.py
tokens_per_fwdbwd = device_batch_size * max_seq_len
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print(f"Gradient accumulation steps: {grad_accum_steps}")
# Training loop
x, y = next(train_loader)  # fetch the very first batch
for step in range(num_iterations):
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        train_loss = loss.detach()
        loss = loss / grad_accum_steps  # Scale loss
        loss.backward()
        x, y = next(train_loader)  # Prefetch next batch
    # Clip gradients (optional)
    if grad_clip > 0.0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    # Step optimizers
    optimizer.step()
    model.zero_grad(set_to_none=True)
Key detail: set_to_none=True frees gradient memory immediately (faster than zeroing).
bfloat16 halves memory with no accuracy loss
What is Mixed Precision?
Use lower precision (16-bit) for most operations, full precision (32-bit) for sensitive ops:
| Format | Size | Range | Precision | Use Case |
|---|---|---|---|---|
| float32 | 4 bytes | ±3.4e38 | ~7 decimal digits | Default, stable |
| float16 | 2 bytes | ±65,504 | ~3 decimal digits | Fast but unstable |
| bfloat16 | 2 bytes | ±3.4e38 | ~2 decimal digits | Best of both worlds |
bfloat16 = same exponent range as float32, less mantissa precision.
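You can read these trade-offs directly from torch.finfo:
import torch

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15s} max={info.max:.3e}  eps={info.eps:.3e}")
# torch.float32   max=3.403e+38  eps=1.192e-07
# torch.float16   max=6.550e+04  eps=9.766e-04
# torch.bfloat16  max=3.390e+38  eps=7.812e-03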
Mixed Precision Explainer
Understand floating-point formats used in LLM training and inference
Precision test: how each format represents a value near 0.123457:
| Format | Represented As | Error | Error % |
|---|---|---|---|
| FP32 | 0.123457 | 0.00e+0 | 0.00% |
| FP16 | 0.123047 | 4.10e-4 | 0.33% |
| BF16 | 0.125000 | 1.54e-3 | 1.25% |
| FP8_E4M3 | 0.125000 | 1.54e-3 | 1.25% |
| FP8_E5M2 | 0.000000 | 1.23e-1 | 100.00% |
| INT8 | 0.125984 | 2.53e-3 | 2.05% |
| INT4 | 0.142857 | 1.94e-2 | 15.71% |
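To run a similar precision test yourself, round-trip a value through each floating-point dtype and compare; the exact errors depend on the value chosen:
import torch

value = 0.1234567
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    approx = torch.tensor(value, dtype=dtype).item()
    print(f"{str(dtype):15s} {approx:.6f}  abs error {abs(approx - value):.2e}")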
Training (Mixed Precision)
- Master weights in FP32
- Forward/backward in FP16/BF16
- Loss scaling to prevent underflow
- ~2x speedup, 50% memory saved
Inference (Quantization)
- INT8/INT4 for weights
- FP16 for activations
- 4x-8x memory reduction
- Requires calibration
Why BF16 for Training?
BF16 has the same exponent range as FP32, preventing overflow/underflow issues that can occur with FP16 during training. While it has lower precision, the dynamic range is more important for gradient updates. Modern GPUs (A100, H100) have native BF16 tensor cores.
Why bfloat16?
- 2x memory savings: Weights, activations, gradients all use half the memory
- Faster compute: Tensor cores on A100/H100 are optimized for bf16
- Stable training: Wider range than fp16 → no loss scaling needed
- Drop-in replacement: Works out-of-the-box for LLM training
nanochat Implementation
# Create autocast context
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
# Use it during forward/backward
with autocast_ctx:
    loss = model(x, y)
loss.backward()
Under the hood:
- Matrix multiplications → bfloat16 (fast, memory-efficient)
- Reductions (sum, softmax) → float32 (numerical stability)
- Optimizer step → float32 (precision for weight updates)
What Gets Cast?
# Forward pass
x = self.transformer.wte(idx) # Embedding: bfloat16
x = norm(x) # RMSNorm: bfloat16
q, k, v = self.attn(x) # Linear: bfloat16
attn = softmax(q @ k.T) # Softmax: float32 (automatic upcast)
out = attn @ v # MatMul: bfloat16
logits = self.lm_head(out) # Linear: bfloat16
loss = cross_entropy(logits, y)       # Cross-entropy: float32 (automatic upcast)
Result: Most memory and compute in bf16, stability-critical ops in fp32.
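You can observe this casting policy directly by checking output dtypes inside an autocast region (requires a CUDA GPU):
import torch

linear = torch.nn.Linear(64, 64).cuda()
x = torch.randn(8, 64, device="cuda")
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    h = linear(x)                 # matmul-backed op runs in bf16
    s = torch.softmax(h, dim=-1)  # softmax autocasts to fp32
print(h.dtype, s.dtype)           # torch.bfloat16 torch.float32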
Explicit Casting in nanochat
# From nanochat/gpt.py
# Cast embeddings to bf16 (save memory in embedding table)
self.transformer.wte.to(dtype=torch.bfloat16)
# Cast rotary embeddings to bf16
cos, sin = freqs.cos(), freqs.sin()
cos, sin = cos.bfloat16(), sin.bfloat16()
# During forward: cast logits to fp32 for stable loss
logits = self.lm_head(x)
logits = logits.float() # fp32 for cross-entropy
loss = F.cross_entropy(logits, targets)
Memory Savings Example
60M parameter model, batch_size=32:
| Component | fp32 | bf16 | Savings |
|---|---|---|---|
| Weights | 240MB | 120MB | -50% |
| Gradients | 240MB | 120MB | -50% |
| Activations (batch_size=32) | 16GB | 8GB | -50% |
| Total | ~16.5GB | ~8.2GB | ~50% |
Enables: 2x larger models or 2x larger batches on same hardware.
For your next training run, this means: enable bfloat16 autocast on day one. It's free memory savings with zero quality loss. There's no reason not to use it on modern GPUs (A100, H100, RTX 4090).
For your production deployment, this means: bfloat16 inference is twice as fast and uses half the memory. If your serving infrastructure doesn't support it, you're leaving money on the table.
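A minimal bf16 greedy-decoding sketch, not nanochat's Engine: it assumes an Ampere-or-newer GPU and a model that returns logits of shape (batch, time, vocab) when called on a token tensor:
import torch

@torch.no_grad()
def generate_bf16(model, idx, max_new_tokens):
    """Greedy decoding with bf16 weights and activations (illustrative sketch)."""
    model = model.cuda().to(torch.bfloat16).eval()   # cast weights once: half the memory
    idx = idx.cuda()
    for _ in range(max_new_tokens):
        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(idx)                      # assumed (B, T, vocab) output
        next_tok = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        idx = torch.cat([idx, next_tok], dim=1)
    return idx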
Shorter sequences free quadratic attention memory
The Challenge
Activation memory scales quadratically with sequence length:
Attention: Q @ K^T creates (seq_len × seq_len) matrix
Memory: O(batch_size * n_heads * seq_len^2 * sizeof(dtype))
Example: seq_len=2048 vs seq_len=4096:
- Attention memory: 4x increase (2048² → 4096²)
- Total activation memory: ~2-3x increase
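A quick calculation makes the quadratic term concrete, assuming the full seq_len × seq_len score matrix is materialized in bf16 (using the 6-head configuration from the examples above):
def attention_scores_gb(batch_size, n_heads, seq_len, bytes_per_el=2):
    """Memory for the (seq_len x seq_len) attention scores of one layer, in GB."""
    return batch_size * n_heads * seq_len**2 * bytes_per_el / 1024**3

print(attention_scores_gb(32, 6, 2048))  # 1.5 GB per layer
print(attention_scores_gb(32, 6, 4096))  # 6.0 GB per layer (4x)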
Strategy 1: Start Short, Grow Gradually
Train on shorter sequences early, increase length later:
# Hypothetical staged training
stage_1_iters = 5000
stage_2_iters = 5000
if step < stage_1_iters:
    max_seq_len = 1024  # Start short
elif step < stage_1_iters + stage_2_iters:
    max_seq_len = 2048  # Grow
else:
    max_seq_len = 4096  # Full length
Trade-off: Early training sees less long-range context, but uses memory efficiently.
Strategy 2: Variable-Length Batching
nanochat uses fixed-length batches for simplicity:
# Every batch has exactly B × T tokens
batch = torch.randint(0, vocab_size, (B, T))
Alternative: Pack variable-length sequences into fixed token budget:
# Advanced (not in nanochat): pack sequences dynamically
# Batch 1: [seq_len=1024, seq_len=2048, seq_len=512] → 3584 tokens
# Batch 2: [seq_len=2048, seq_len=2048] → 4096 tokens
# Goal: maintain ~4096 tokens/batch, varying sequence counts
This is complex (requires padding/masking) but maximizes GPU utilization.
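For illustration, a minimal greedy packer (not from nanochat) that fills each batch up to a fixed token budget:
def pack_sequences(seq_lens, token_budget=4096):
    """Greedily group sequences so each batch stays within `token_budget` tokens (sketch)."""
    batches, current, used = [], [], 0
    for n in sorted(seq_lens, reverse=True):  # longest-first reduces fragmentation
        if used + n > token_budget and current:
            batches.append(current)
            current, used = [], 0
        current.append(n)
        used += n
    if current:
        batches.append(current)
    return batches

print(pack_sequences([1024, 2048, 512, 2048, 2048]))
# [[2048, 2048], [2048, 1024, 512]]
A real implementation still has to pad (or concatenate with attention masking) each packed batch into a rectangular tensor.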
Strategy 3: Truncation
nanochat truncates long sequences during tokenization:
def render_conversation(self, conversation, max_tokens=2048):
    ids, mask = [], []
    # ... render conversation ...
    # Truncate to max_tokens
    ids = ids[:max_tokens]
    mask = mask[:max_tokens]
    return ids, mask
Trade-off: Long conversations lose tail context, but avoids OOM.
nanochat's Choice
# Fixed sequence length throughout training
max_seq_len = 2048 # Constant
# Data loader yields batches of shape (batch_size, max_seq_len)
for x, y in data_loader:
    assert x.shape == (batch_size, max_seq_len)
    loss = model(x, y)
Why? Simplicity + compiled model performance (dynamic shapes are slower).
Muon uses half the optimizer memory of Adam
Memory Footprint Comparison
| Optimizer | State per Parameter | Memory per Param | Example (60M params) |
|---|---|---|---|
| SGD | None | 0 bytes | 0 MB |
| SGD + Momentum | Momentum buffer | 4 bytes | 240 MB |
| AdamW | Momentum + Variance | 8 bytes | 480 MB |
| Muon | Momentum (2D params only) | 4 bytes | 240 MB |
nanochat's Hybrid Approach
# From nanochat/gpt.py
def setup_optimizers(self):
    # Separate parameters by type
    matrix_params = list(self.transformer.h.parameters())       # 2D (linear layers)
    embedding_params = list(self.transformer.wte.parameters())  # embeddings (handled by AdamW)
    lm_head_params = list(self.lm_head.parameters())            # 2D (classifier)
    # AdamW for embeddings + lm_head (needs adaptive LR)
    adamw_optimizer = AdamW([
        {"params": lm_head_params, "lr": 0.004},
        {"params": embedding_params, "lr": 0.2},
    ], betas=(0.8, 0.95))
    # Muon for Transformer matrix params (memory-efficient)
    muon_optimizer = Muon(matrix_params, lr=0.02, momentum=0.95)
    return [adamw_optimizer, muon_optimizer]
Memory breakdown (60M param model):
- Matrix params: ~50M params → Muon → 50M × 4B = 200MB
- Embedding + lm_head: ~10M params → AdamW → 10M × 8B = 80MB
- Total optimizer memory: 280MB (vs 480MB for full AdamW)
Savings: ~40% reduction in optimizer state memory.
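The same arithmetic as a tiny helper, using the 4-byte (Muon) and 8-byte (AdamW, fp32 state) costs from the table above and decimal megabytes to match the figures:
def hybrid_optimizer_state_mb(muon_params, adamw_params):
    """Optimizer-state memory (decimal MB) for a Muon + AdamW split."""
    return (muon_params * 4 + adamw_params * 8) / 1e6

print(hybrid_optimizer_state_mb(50e6, 10e6))  # 280.0 (hybrid split)
print(hybrid_optimizer_state_mb(0, 60e6))     # 480.0 (AdamW everywhere)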
When to Use What?
| Optimizer | Use Case | Memory | Convergence Speed |
|---|---|---|---|
| SGD | Very memory-constrained | Lowest | Slowest |
| AdamW | Standard choice, embeddings | High | Fast |
| Muon | Matrix params (Transformers) | Medium | Fast |
| 8-bit AdamW | Extreme memory constraints | Medium | Fast (slight quality loss) |
DDP replicates memory; ZeRO shards it across GPUs
Data Parallelism (DDP)
Each GPU holds:
- Full copy of model weights
- Full copy of optimizer state
- 1/N of the batch (where N = number of GPUs)
# Example: 8 GPUs, batch_size=32, seq_len=2048
# Total batch: 32 × 2048 × 8 = 524,288 tokens
# Per-GPU batch: 32 × 2048 = 65,536 tokens
# Each GPU:
device_batch_size = 32 # Per-GPU
total_batch_size = 32 * 8 * 2048 # Across all GPUs
# Memory per GPU:
# - Model weights: Same on all GPUs
# - Activations: 1/8 of total (only local batch)
# - Gradients: Same on all GPUs (all-reduced after backward)
Memory savings: Activation memory is distributed, but model/optimizer memory is not.
ZeRO-2 (DistMuon/DistAdamW)
nanochat uses ZeRO-2 for optimizer state:
# From nanochat/muon.py
class DistMuon:
    """Distributed Muon with ZeRO-2: shard optimizer state across ranks"""
    def __init__(self, params, lr, momentum):
        # Each rank owns a shard of parameters
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        # Shard parameters across ranks
        self.owned_params = [p for i, p in enumerate(params) if i % self.world_size == self.rank]
        # Only allocate momentum for owned parameters
        self.momentum_buffers = [torch.zeros_like(p) for p in self.owned_params]
Memory savings:
- Optimizer state: Divided by world_size
- Gradients: Still all-reduced (not sharded in ZeRO-2)
- Weights: Still replicated (not sharded in ZeRO-2)
Example (60M params, 8 GPUs):
- Without ZeRO: 280MB optimizer state per GPU
- With ZeRO-2: 280MB / 8 = 35MB optimizer state per GPU
Trade-off: Requires all-gather before optimizer step (communication overhead).
ZeRO-3 (Not in nanochat)
ZeRO-3 shards model weights too:
- Weights: Divided by world_size
- Gradients: Divided by world_size
- Optimizer state: Divided by world_size
Memory per GPU = total_memory / world_size
Trade-off: Much more communication (all-gather on every forward pass).
KV caching trades memory for 10× inference speed
KV Caching
From Post 1.3, KV caching reuses past key/value projections:
class KVCache:
    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
        # Shape: (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
        self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
        self.kv_cache = None
        self.pos = 0

    def insert_kv(self, layer_idx, k, v):
        # Dynamically grow cache if needed
        if self.pos + k.size(2) > self.kv_cache.size(4):
            # Grow by at least 1024 tokens, rounded up to a multiple of 1024
            new_size = ((self.pos + k.size(2) + 1024) + 1023) & ~1023
            # ... reallocate the cache with the larger sequence dimension ...
        # Insert new k, v
        self.kv_cache[layer_idx, 0, :, :, self.pos:self.pos + k.size(2)] = k
        self.kv_cache[layer_idx, 1, :, :, self.pos:self.pos + k.size(2)] = v
        self.pos += k.size(2)
        # Return the cache up to the current position
        return self.kv_cache[layer_idx, 0, :, :, :self.pos], self.kv_cache[layer_idx, 1, :, :, :self.pos]
Memory cost:
# For a 12-layer, 6-head, 128-dim model:
# batch_size=1, max_seq_len=2048
kv_memory = 2 * 12 * 6 * 2048 * 128 * 2  # (K+V) * layers * heads * seq * head_dim * bytes(bf16)
# ≈ 70 MB per sequence
Batch size impact: KV cache scales linearly with batch size:
- batch_size=1: 70 MB
- batch_size=8: 560 MB
- batch_size=64: 4.5 GB
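Generalizing that calculation into a helper (bf16 assumed, 2 bytes per element; the exact figure for this configuration is 72 MiB, rounded to ~70 MB above):
def kv_cache_mb(batch_size, n_layers, n_heads, head_dim, seq_len, bytes_per_el=2):
    """Memory (MiB) for the K and V caches across all layers."""
    return 2 * n_layers * n_heads * seq_len * head_dim * batch_size * bytes_per_el / 1024**2

for bs in (1, 8, 64):
    print(f"batch_size={bs:3d}: {kv_cache_mb(bs, 12, 6, 128, 2048):,.0f} MiB")
# batch_size=  1: 72 MiB
# batch_size=  8: 576 MiB
# batch_size= 64: 4,608 MiB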
Inference Batch Size Tuning
For generation, batch size trades throughput vs latency:
# Low batch size: low latency, low throughput
engine.generate_batch(tokens, num_samples=1) # 1 sequence at a time
# High batch size: high latency, high throughput
engine.generate_batch(tokens, num_samples=64)  # 64 sequences in parallel
Memory scaling:
memory_per_batch = weights + (batch_size × kv_cache_per_seq) + (batch_size × activations_per_seq)
For serving: Use largest batch size that fits in memory (maximizes throughput).
Prefill vs Decode
nanochat separates prefill (process prompt) and decode (generate tokens):
# Prefill: batch_size=1, full prompt at once
kv_cache_prefill = KVCache(batch_size=1, seq_len=len(prompt))
logits = model.forward(prompt_tokens, kv_cache=kv_cache_prefill)
# Decode: batch_size=num_samples, one token at a time
kv_cache_decode = KVCache(batch_size=num_samples, seq_len=max_gen_len)
kv_cache_decode.prefill(kv_cache_prefill) # Copy from prefill
# Generate multiple samples in parallel
for step in range(max_gen_len):
    logits = model.forward(next_tokens, kv_cache=kv_cache_decode)
    next_tokens = sample(logits)
Memory optimization: Prefill uses batch_size=1 (save KV cache memory), then replicate for decode.
Gradient checkpointing trades compute for ~4× less activation memory
The Problem
Activation memory dominates for large models:
# Forward pass stores activations for backward pass
x = input
for layer in model.layers:
    x = layer(x)  # Activation stored in memory
For a 20-layer model, 20 sets of activations are stored simultaneously.
Gradient Checkpointing
Idea: Don't store all activations. Recompute them during backward pass.
# Without checkpointing: store all activations
x = input
activations = []
for layer in layers:
    x = layer(x)
    activations.append(x)  # Store for backward
loss = criterion(x, target)
loss.backward()  # Uses stored activations

# With checkpointing: store only some activations
x = input
checkpoints = []
for i, layer in enumerate(layers):
    if i % checkpoint_every == 0:
        checkpoints.append(x)  # Checkpoint the layer input
    x = layer(x)  # Don't store
loss = criterion(x, target)
loss.backward()  # Recomputes missing activations on-the-fly
Memory/compute trade-off:
- Memory saved: stored activations drop from N to ~N/checkpoint_every (e.g., checkpoint every 4 layers → ~4x savings)
- Compute cost: ~33% increase (recompute during backward)
PyTorch Implementation
import torch.utils.checkpoint as checkpoint
class TransformerWithCheckpointing(nn.Module):
    def forward(self, x):
        for layer in self.layers:
            # Checkpoint this layer (no activations stored)
            x = checkpoint.checkpoint(layer, x, use_reentrant=False)
        return x
nanochat Status
Not currently implemented in nanochat (focus on simplicity), but would add:
# Hypothetical addition to nanochat/gpt.py
class GPT(nn.Module):
    def __init__(self, config, use_checkpointing=False):
        self.use_checkpointing = use_checkpointing
        # ...
    def forward(self, x):
        for block in self.transformer.h:
            if self.use_checkpointing:
                x = checkpoint.checkpoint(block, x, cos_sin, kv_cache, use_reentrant=False)
            else:
                x = block(x, cos_sin, kv_cache)
        return x
When to use:
- Very large models (billions of parameters)
- Long sequences (4K+ tokens)
- Memory-constrained hardware
When NOT to use:
- Small models (overhead dominates)
- Short sequences (activation memory is small)
- When training speed is critical
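To see the trade-off for yourself, here is a small, self-contained experiment (not nanochat code) that measures peak memory with and without checkpointing on a stack of MLP blocks standing in for transformer layers; it assumes a CUDA GPU, and exact numbers depend on hardware:
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

def peak_memory_mb(use_checkpointing, n_layers=12, d=768, n_tokens=16384):
    blocks = nn.ModuleList(
        [nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))
         for _ in range(n_layers)]
    ).cuda()
    x = torch.randn(n_tokens, d, device="cuda", requires_grad=True)
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    h = x
    for block in blocks:
        if use_checkpointing:
            h = checkpoint.checkpoint(block, h, use_reentrant=False)  # recompute in backward
        else:
            h = block(h)                                              # store activations
    h.sum().backward()
    return torch.cuda.max_memory_allocated() / 1024**2

print(f"no checkpointing: {peak_memory_mb(False):.0f} MB")
print(f"   checkpointing: {peak_memory_mb(True):.0f} MB")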
torch.cuda.memory_summary shows exactly where memory goes
PyTorch Memory Stats
import torch
# Track peak memory
torch.cuda.reset_peak_memory_stats()
# ... training code ...
peak_memory_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
print(f"Peak memory: {peak_memory_mb:.2f} MB")
# Detailed stats
print(torch.cuda.memory_summary())
Output example:
|===========================================================================|
| PyTorch CUDA memory summary |
|---------------------------------------------------------------------------|
| CUDA OOMs: 0 |
| Allocation retries: 0 |
|===========================================================================|
| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |
|---------------------------------------------------------------------------|
| Allocated memory | 8.2 GB | 12.5 GB | 50.3 GB | 42.1 GB |
| Active memory | 8.2 GB | 12.5 GB | | |
| ...
nanochat's Usage
# From scripts/base_train.py
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")Profiling Individual Components
def profile_memory(model, optimizer, batch):
    torch.cuda.reset_peak_memory_stats()
    # Forward pass
    loss = model(batch)
    fwd_memory = torch.cuda.max_memory_allocated() / 1024**2
    print(f"After forward: {fwd_memory:.2f} MB")
    # Backward pass
    torch.cuda.reset_peak_memory_stats()
    loss.backward()
    bwd_memory = torch.cuda.max_memory_allocated() / 1024**2
    print(f"After backward: {bwd_memory:.2f} MB")
    # Optimizer step
    torch.cuda.reset_peak_memory_stats()
    optimizer.step()
    opt_memory = torch.cuda.max_memory_allocated() / 1024**2
    print(f"After optimizer: {opt_memory:.2f} MB")
PyTorch Profiler (Advanced)
from torch.profiler import profile, ProfilerActivity
with profile(activities=[ProfilerActivity.CUDA], profile_memory=True) as prof:
    for step in range(10):
        loss = model(x, y)
        loss.backward()
        optimizer.step()
# Export to Chrome trace
prof.export_chrome_trace("trace.json")
# Or print table
print(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=10))
Output:
--------------------------------- ------------ ------------
Name Self CUDA Mem Peak CUDA Mem
--------------------------------- ------------ ------------
aten::matmul 2.50 GB 2.50 GB
aten::linear 1.20 GB 1.20 GB
aten::softmax 0.80 GB 0.80 GB
...
Debugging OOM Errors
try:
    loss = model(x, y)
    loss.backward()
except RuntimeError as e:
    if "out of memory" in str(e):
        print(f"OOM! Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
        print(f"Reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
        print(torch.cuda.memory_summary())
        # Try to recover
        torch.cuda.empty_cache()
        # Reduce batch size and retry
        batch_size = batch_size // 2
    else:
        raise
These patterns prevent OOM errors
Best Practices
1. Start Conservative, Scale Up
# Good: Start with small batch, measure memory, then scale
device_batch_size = 8 # Start small
# ... measure peak memory ...
# If memory allows, increase to 16, 32, etc.
# Bad: Start with huge batch, OOM immediately
device_batch_size = 128  # Instant OOM
2. Use Mixed Precision by Default
# Good: Always use bfloat16 for modern GPUs
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
# Bad: Stick to fp32 (uses 2x memory for no reason)
# ... no autocast, everything in fp32 ...
3. Gradient Accumulation for Large Batches
# Good: Simulate large batch with gradient accumulation
device_batch_size = 16 # Fits in memory
grad_accum_steps = 32 # Effective batch = 16 × 32 = 512
# Bad: Try to fit huge batch directly
device_batch_size = 512  # OOM
4. Clear Gradients Properly
# Good: Free gradient memory immediately
model.zero_grad(set_to_none=True)
# Okay: Zero gradients (slightly slower)
model.zero_grad()
# Bad: Don't clear gradients (memory leak)
# ... no zero_grad call ...
5. Monitor Memory Throughout Training
# Good: Log memory usage periodically
if step % 100 == 0:
    mem_mb = torch.cuda.memory_allocated() / 1024**2
    wandb.log({"memory_mb": mem_mb})
# Bad: Never check memory (discover OOM at step 5000)
Common Pitfalls
Pitfall 1: Forgetting to Scale Loss in Gradient Accumulation
# Bad: Gradients are grad_accum_steps × too large
for micro_step in range(grad_accum_steps):
    loss = model(x, y)
    loss.backward()  # Oops, no scaling!
# Good: Scale loss by grad_accum_steps
for micro_step in range(grad_accum_steps):
    loss = model(x, y)
    loss = loss / grad_accum_steps
    loss.backward()
Pitfall 2: Storing Unnecessary Tensors
# Bad: Accumulating losses keeps activation graphs in memory
losses = []
for x, y in data_loader:
    loss = model(x, y)
    losses.append(loss)  # Stores computation graph!
# Good: Detach scalars
losses = []
for x, y in data_loader:
    loss = model(x, y)
    losses.append(loss.detach().item())  # No graph, just a float
Pitfall 3: Inefficient KV Cache Initialization
# Bad: Preallocate huge cache upfront
kv_cache = torch.zeros((layers, 2, batch, heads, 100000, dim)) # 100K seq len!
# Good: Grow dynamically as needed
kv_cache = torch.zeros((layers, 2, batch, heads, prompt_len, dim))
# ... grow as generation proceeds ...
Pitfall 4: Mixed Precision Without Autocast
# Bad: Manually cast everything (error-prone, verbose)
x = x.to(dtype=torch.bfloat16)
y = model(x)
y = y.to(dtype=torch.float32)
# Good: Use autocast context (automatic, correct)
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
y = model(x)Pitfall 5: Not Clearing CUDA Cache
# Bad: Cache fragmentation causes OOM over time
for step in range(10000):
    # ... training ...
    # Memory gets fragmented, eventual OOM
# Good: Periodically clear cache
for step in range(10000):
    # ... training ...
    if step % 1000 == 0:
        torch.cuda.empty_cache()
Pitfall 6: Large Batch Size on Small Models
# Bad: batch_size=256 on a 10M param model
# Problem: Activation memory >> model memory (wasteful)
# Good: Use batch size proportional to model size
# Small model (10M): batch_size=16-32
# Medium model (100M): batch_size=32-64
# Large model (1B+): batch_size=64-128
Memory optimization is the art of doing more with less
You can train larger models, longer sequences, and bigger batches on the hardware you already have.
Key techniques:
- Gradient Accumulation: Free larger effective batch sizes
- Mixed Precision (bfloat16): 2x memory savings with no quality loss
- Optimizer Choice: Muon saves 50% vs AdamW on matrix parameters
- Sequence Length Management: Start short, grow gradually
- ZeRO-2: Shard optimizer state across GPUs
- KV Caching: Reuse past keys/values during generation
- Gradient Checkpointing: Trade compute for memory (33% slower, 4x less memory)
- Memory Profiling: Measure before optimizing
Memory hierarchy:
Model Weights (fixed) < Optimizer State (fixed) << Activations (scales with batch)
Optimization priority:
- Enable mixed precision (bfloat16) → instant 2x savings
- Tune device_batch_size + grad_accum_steps → maximize GPU utilization
- Choose efficient optimizer (Muon for Transformers) → 40% optimizer memory savings
- Use distributed training (DDP + ZeRO-2) → linear scaling across GPUs
- Add gradient checkpointing (if desperate) → 4x activation memory savings
With these techniques, you can train 2-4x larger models on the same hardware—or train faster with larger batches and more aggressive settings.
Before you optimize your memory usage:
- Enable bfloat16 first. This is free 2× memory savings with zero quality loss—do this before anything else.
- Profile before optimizing. Use torch.cuda.memory_allocated() to identify whether activations, optimizer state, or model weights are your bottleneck.
- Start with small batch size, scale up. Begin at batch_size=8, measure peak memory, then double until you approach 90% utilization.
- Scale loss during gradient accumulation. Divide by grad_accum_steps—forgetting this makes gradients N× too large.
- Use set_to_none=True in zero_grad. Frees gradient memory immediately instead of just zeroing values.
The GPU you have is more capable than you think. Now you know how to unlock it.
Sources
Research Papers
ZeRO: Memory Optimizations Toward Training Trillion Parameter Models (Rajbhandari et al., 2019) - The foundational paper introducing Zero Redundancy Optimizer, which partitions optimizer states, gradients, and parameters across data-parallel processes to dramatically reduce memory footprint. arXiv:1910.02054
Mixed Precision Training (Micikevicius et al., 2017) - Introduces techniques for training deep neural networks using half-precision floating point numbers while maintaining model accuracy, reducing memory consumption by nearly 2x. Published at ICLR 2018. arXiv:1710.03740
Training Deep Nets with Sublinear Memory Cost (Chen et al., 2016) - Proposes gradient checkpointing, a systematic approach to reduce memory consumption from O(n) to O(√n) for training n-layer networks, enabling deeper models on limited hardware. arXiv:1604.06174
ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning (Rajbhandari et al., 2021) - Extends ZeRO with CPU and NVMe offloading capabilities to train models with trillions of parameters. arXiv:2104.07857
Muon: An optimizer for hidden layers in neural networks (Kosson et al., 2025) - Describes the Muon optimizer that maintains significantly less state than Adam (only momentum vs. both first and second moments), reducing optimizer memory requirements. arXiv:2502.16982
Technical Documentation
PyTorch Automatic Mixed Precision (torch.amp) - Official PyTorch documentation for the AMP package, covering autocast context managers, GradScaler for gradient scaling, and op-specific behavior for mixed precision training. PyTorch AMP Docs
DeepSpeed ZeRO Tutorial - Comprehensive guide to implementing ZeRO stages 1-3, including configuration examples, memory savings calculations, and ZeRO-Infinity offloading to CPU and NVMe. DeepSpeed ZeRO Tutorial
DeepSpeed Configuration Reference - Complete documentation for DeepSpeed JSON configuration options including ZeRO optimizations, FP16/BFloat16 training, and optimizer settings. DeepSpeed Config JSON
Framework Resources
PyTorch Gradient Checkpointing - Documentation for torch.utils.checkpoint, which implements activation checkpointing to trade compute for memory during training. PyTorch Checkpoint Utils
NVIDIA Apex - NVIDIA's PyTorch extension library providing optimized mixed precision and distributed training utilities. NVIDIA Apex GitHub
Microsoft DeepSpeed - Deep learning optimization library that implements ZeRO, mixed precision training, and various memory optimization techniques. DeepSpeed GitHub
GPU Hardware & Pricing (as of January 2025)
| GPU | VRAM | Memory Bandwidth | Typical Cost/hr | Source |
|---|---|---|---|---|
| RTX 3090 | 24GB | 936 GB/s | Consumer purchase | NVIDIA RTX 3090 |
| RTX 4090 | 24GB | 1 TB/s | Consumer purchase | NVIDIA RTX 4090 |
| A100 80GB | 80GB | 2 TB/s | ~$1.44/hr | Lambda Labs |
| H100 80GB | 80GB | 3.35 TB/s | ~$2.49/hr | Lambda Labs |
Industry Research (as of January 2025)
- Epoch AI Training Compute: Compute Trends in Machine Learning. Tracks memory efficiency improvements; shows 3× memory efficiency gains in 2024 vs 2022 baseline.
- MLCommons MLPerf Training: Training Benchmark Results. Industry-standard benchmarks showing memory efficiency across hardware configurations.
Related Posts
- Post 1.1: Muon Optimizer Explained - Why Muon saves memory vs AdamW
- Post 1.2: Distributed Muon - ZeRO-2 sharding implementation
- Post 1.3: KV Caching Deep-Dive - Inference memory optimization
- Post 2.1: Training Your First Model - Practical gradient accumulation setup
Exercises
1. Measure activation memory: Profile your model and measure what percentage of memory is activations vs weights vs optimizer state.
2. Gradient accumulation experiment: Train with batch_size=32, grad_accum=1 vs batch_size=16, grad_accum=2. Verify identical convergence.
3. Mixed precision ablation: Train with fp32 vs bf16. Compare memory usage, training speed, and final validation loss.
4. Sequence length scaling: Measure peak memory as you scale seq_len from 512 → 1024 → 2048 → 4096. Plot memory vs seq_len².
5. Optimizer state comparison: Train identical model with AdamW vs Muon. Compare optimizer state memory (use torch.cuda.memory_allocated()).
6. KV cache growth: Implement dynamic KV cache growth and measure memory usage during generation from 0 to 2048 tokens.
Series Complete! You've completed Track 2 (Practical Guides). Congratulations on mastering the practical implementation of nanochat!
Part of the nanochat Deep-Dive Series • Code: nanochat on GitHub