Distributed Muon: Custom Gradient Synchronization for Memory-Efficient Training

📚 nanochat Blog Series - Track 1: Technical Deep-Dives
Part 2 of 6 - Understanding the "Why" behind nanochat's technical innovations
- 1.1 The Muon Optimizer Explained
- 1.2 Distributed Muon (You are here)
- 1.3 KV Caching Deep-Dive (Coming soon)
- 1.4 Modern Transformer Architecture (Coming soon)
- 1.5 Training Data Pipeline (Coming soon)
- 1.6 Loss Landscape & Scaling Laws (Coming soon)
Introduction
Training large language models requires distributing computation across multiple GPUs. While PyTorch's Distributed Data Parallel (DDP) makes this conceptually simple, it comes with significant memory overhead—every GPU stores a complete copy of the model parameters, gradients, and optimizer states.
For a 1 billion parameter model on 8 GPUs, this means 8× redundant copies of everything.
Enter ZeRO (Zero Redundancy Optimizer), which eliminates this redundancy by sharding optimizer states and gradients across GPUs. Microsoft's DeepSpeed popularized ZeRO, but nanochat implements a custom variant specifically tailored for the Muon optimizer.
Why custom? Because Muon's Newton-Schulz orthogonalization requires preserving the 2D matrix structure of parameters—you can't just slice weight matrices arbitrarily.
NOTE
Prerequisites: Understanding of the Muon optimizer and basic distributed training concepts. Reading time: ~12 minutes.
In this post, we'll dissect DistMuon, nanochat's distributed Muon implementation that achieves:
- ~2-3× memory savings compared to standard DDP
- Seamless integration with existing training loops
- Custom reduce_scatter → compute → all_gather pattern optimized for Muon
Let's explore how it works under the hood.
The DDP Baseline: Understanding the Problem
Standard DDP's Synchronization Model
PyTorch DDP follows a simple but memory-inefficient pattern:
# Pseudo-code for standard DDP
# Pseudo-code for standard DDP
for step in training_loop:
    loss = model(x, y)
    loss.backward()   # Compute gradients locally
    # DDP hooks: all_reduce gradients (implicit, happens in backward)
    optimizer.step()  # Each rank updates full model independently
During backward(), DDP's hooks automatically trigger an all_reduce operation that averages gradients across all ranks. This ensures every GPU has identical gradients before the optimizer step.
Memory Overhead Analysis
For each parameter, every rank stores:
- Parameters (model weights): P bytes
- Gradients: P bytes
- Optimizer states: depends on the optimizer
  - Adam/AdamW: 2 states (`exp_avg`, `exp_avg_sq`) = 2P bytes
  - Muon: 1 state (`momentum_buffer`) = P bytes
Total per rank (Muon + DDP): P + P + P = 3P bytes
Total across N ranks: 3P × N bytes
For a 1B parameter model (2GB in bfloat16) on 8 GPUs:
Memory per rank = 2GB (params) + 2GB (grads) + 2GB (momentum) = 6GB
Total memory = 6GB × 8 = 48GB
This redundancy is wasteful. Can we do better?
What ZeRO-2 Offers
ZeRO has three stages of optimization:
- Stage 1: Shard optimizer states across ranks
- Stage 2: Shard gradients + optimizer states
- Stage 3: Shard parameters + gradients + optimizer states
DistMuon implements ZeRO-2, keeping parameters replicated but sharding gradients and optimizer states. This strikes a balance between memory efficiency and implementation complexity.
Memory per rank (ZeRO-2): P + P/N + P/N = P(1 + 2/N) bytes
For our 1B parameter example with 8 GPUs:
Memory per rank = 2GB (params) + 0.25GB (grads/8) + 0.25GB (momentum/8) = 2.5GB
Total memory = 2.5GB × 8 = 20GB
Savings: 48GB - 20GB = 28GB (58% reduction!)
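As a quick sanity check on these numbers, the snippet below reproduces the arithmetic. The helper name and structure are ours (not from nanochat), and real training adds activations, temporary buffers, and allocator overhead on top of these figures:

```python
# Back-of-the-envelope check of the DDP vs ZeRO-2 numbers above.
# Assumes bfloat16 (2 bytes) for params, grads, and the Muon momentum buffer.
def per_rank_memory_gb(n_params: float, n_ranks: int, bytes_per_el: int = 2) -> dict:
    p = n_params * bytes_per_el / 1e9   # parameter bytes, in GB
    ddp = 3 * p                         # params + grads + momentum, all replicated
    zero2 = p * (1 + 2 / n_ranks)       # params replicated, grads + momentum sharded
    return {"ddp_per_rank": ddp, "zero2_per_rank": zero2,
            "ddp_total": ddp * n_ranks, "zero2_total": zero2 * n_ranks}

print(per_rank_memory_gb(1e9, 8))
# {'ddp_per_rank': 6.0, 'zero2_per_rank': 2.5, 'ddp_total': 48.0, 'zero2_total': 20.0}
```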
DistMuon Architecture: Three Key Design Decisions
Parameter Grouping by Shape
DistMuon groups all parameters by their shape before assigning ownership.
From the nanochat codebase (view on GitHub):
rank = dist.get_rank()
shapes = sorted({p.shape for p in params})  # Unique shapes, sorted
param_groups = []
for shape in shapes:
    group_params = [p for p in params if p.shape == shape]
    device, dtype = group_params[0].device, group_params[0].dtype
    assert all(p.device == device for p in group_params)
    assert all(p.dtype == dtype for p in group_params)
    if rank == 0:
        print(f"Muon: Grouping {len(group_params)} params of shape {shape}")
    param_groups.append(dict(
        params=group_params,
        zero_buffer=torch.zeros_like(group_params[0])
    ))
Why group by shape?
- Efficient batched operations: Newton-Schulz can process multiple matrices of the same shape simultaneously
- Simplified communication: `reduce_scatter` and `all_gather` require uniform tensor shapes
- Better GPU utilization: batched matrix operations maximize throughput
TIP
Example: A transformer model might have 100 parameters of shape [768, 768] (attention matrices), 50 of shape [3072, 768] (FFN matrices), and 50 of shape [768, 3072]. DistMuon creates 3 parameter groups, enabling efficient batched Newton-Schulz within each group.
Block-Cyclic Parameter Assignment
Within each shape group, parameters are assigned to ranks in a block-cyclic pattern:
world_size = dist.get_world_size()
for base_i in range(0, len(params), world_size):
    owner_idx = base_i + rank  # Each rank owns param at (base + rank)
Visual representation (4 GPUs, 10 parameters):
Param indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
                └────────┘  └────────┘  └──┘
                  Block 0     Block 1   Block 2
Rank 0 owns: [0, 4, 8 ] ← indices 0, 4, 8
Rank 1 owns: [ 1, 5, 9 ] ← indices 1, 5, 9
Rank 2 owns: [ 2, 6 ] ← indices 2, 6
Rank 3 owns: [ 3, 7 ] ← indices 3, 7
Why block-cyclic?
- Load balancing: distributes parameters roughly evenly across ranks
- Simplicity: each rank's ownership is a simple calculation (`base_i + rank`), as the sketch after this list shows
- Graceful remainders: uneven parameter counts are handled with zero-buffer padding
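To make the ownership rule concrete, here is a tiny standalone helper (ours, for illustration only) that reproduces the table above:

```python
def owned_indices(rank: int, world_size: int, n_params: int) -> list[int]:
    # Block-cyclic ownership: within each block of `world_size` consecutive
    # params, the rank owns the one at offset `rank` (if it exists).
    return [base_i + rank
            for base_i in range(0, n_params, world_size)
            if base_i + rank < n_params]

for r in range(4):
    print(f"Rank {r} owns: {owned_indices(r, world_size=4, n_params=10)}")
# Rank 0 owns: [0, 4, 8]
# Rank 1 owns: [1, 5, 9]
# Rank 2 owns: [2, 6]
# Rank 3 owns: [3, 7]
```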
The Three-Phase Update Pattern
DistMuon orchestrates a custom communication pattern that shards computation while maintaining parameter replication: for each block of parameters, reduce_scatter averages gradients onto an owner rank, the owner computes the Muon update, and all_gather replicates the updated parameters back to every rank.
Let's examine each phase in detail.
Phase 1: Reduce-Scatter (Gradient Averaging)
The reduce_scatter operation combines gradients from all ranks and distributes the averaged results.
From the nanochat codebase (view on GitHub):
all_reduce_futures = []
for group in self.param_groups:
    params = group["params"]
    zero_buffer = group["zero_buffer"]
    for base_i in range(0, len(params), world_size):
        owner_idx = base_i + rank
        # Each rank collects gradients for world_size consecutive params
        rs_input = [p.grad for p in params[base_i:base_i + world_size]]
        # Pad with zeros if we don't have enough params
        rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
        # Output buffer: gradient for the param this rank owns
        rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer)
        # Launch async reduce_scatter
        work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future()
        all_reduce_futures.append(work)
What happens here:
- Input collection: Each rank gathers gradients for a block of `world_size` parameters
- Padding: If the block is incomplete (e.g., the last block has fewer params), pad with `zero_buffer`
- Reduce-scatter: All ranks participate in averaging gradients
- Output: Each rank receives the averaged gradient for its owned parameter
Example (4 GPUs, block 0 with params [0,1,2,3]):
Rank 0: rs_input = [grad₀[0], grad₀[1], grad₀[2], grad₀[3]] → rs_output = avg(grad[0])
Rank 1: rs_input = [grad₁[0], grad₁[1], grad₁[2], grad₁[3]] → rs_output = avg(grad[1])
Rank 2: rs_input = [grad₂[0], grad₂[1], grad₂[2], grad₂[3]] → rs_output = avg(grad[2])
Rank 3: rs_input = [grad₃[0], grad₃[1], grad₃[2], grad₃[3]] → rs_output = avg(grad[3])
After reduce-scatter, each rank has the averaged gradient for its owned parameter, ready for computation.
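If you want to convince yourself of these semantics without spinning up multiple GPUs, the following single-process simulation (ours, plain tensor math rather than `torch.distributed`) mimics what `reduce_scatter` with `ReduceOp.AVG` leaves on each rank for block 0:

```python
import torch

torch.manual_seed(0)
world_size = 4
# grads[r][i] = rank r's local gradient for parameter i (block 0 has 4 params)
grads = [[torch.randn(3, 3) for _ in range(world_size)] for _ in range(world_size)]

# reduce_scatter with ReduceOp.AVG: average element-wise across ranks,
# then rank r keeps only entry r of the averaged list.
averaged = [torch.stack([grads[r][i] for r in range(world_size)]).mean(0)
            for i in range(world_size)]
for r in range(world_size):
    rs_output = averaged[r]   # what rank r's owned-parameter gradient becomes
    print(f"rank {r} holds avg grad for param {r}, norm = {rs_output.norm():.3f}")
```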
Phase 2: Compute Update (Owner Ranks Only)
Once gradients are averaged, each rank computes the Muon update for its owned parameters.
From the nanochat codebase (view on GitHub):
future_idx = 0
all_gather_futures = []
for group in self.param_groups:
    params = group["params"]
    for base_i in range(0, len(params), world_size):
        owner_idx = base_i + rank
        # Wait for reduce_scatter to complete
        all_reduce_futures[future_idx].wait()
        future_idx += 1
        # Only owner computes the update
        if owner_idx < len(params):
            p = params[owner_idx]
            g = p.grad  # Already averaged across ranks
            state = self.state[p]
            # Initialize momentum buffer if needed
            if "momentum_buffer" not in state:
                state["momentum_buffer"] = torch.zeros_like(g)
            buf = state["momentum_buffer"]
            # Momentum accumulation
            buf.lerp_(g, 1.0 - group["momentum"])
            # Nesterov momentum (if enabled)
            g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
            # Newton-Schulz orthogonalization
            g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
            # Aspect-ratio scaled step
            scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
            p.add_(g, alpha=-group["lr"] * scale)
Key points:
- Wait synchronization: `wait()` ensures the gradient is ready before computation
- Owner-only execution: Non-owner ranks skip computation (idle during this phase)
- Standard Muon update: Same as single-GPU Muon (see Post 1.1)
  - Momentum accumulation with `lerp_`
  - Optional Nesterov momentum
  - Newton-Schulz orthogonalization
  - Aspect-ratio scaling: `sqrt(max(1, height/width))`
NOTE
Memory efficiency: Each rank stores momentum_buffer only for its owned parameters, achieving 1/N sharding.
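The orthogonalization call itself is covered in Post 1.1; for reference, here is a condensed sketch of the quintic Newton-Schulz iteration along the lines of the public Muon reference implementation. Treat the coefficients, dtype handling, and function name as illustrative rather than nanochat's exact code:

```python
import torch

def newtonschulz5_sketch(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Quintic Newton-Schulz iteration that pushes G's singular values toward 1,
    # approximating the orthogonal factor U V^T of G (coefficients as in the
    # public Muon reference implementation).
    assert G.ndim >= 2
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.float()
    transposed = G.size(-2) > G.size(-1)
    if transposed:                  # iterate on the "wide" orientation
        X = X.mT
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)  # bound the spectral norm
    for _ in range(steps):
        A = X @ X.mT
        X = a * X + (b * A + c * A @ A) @ X
    return X.mT if transposed else X

W = torch.randn(1024, 768)
O = newtonschulz5_sketch(W)
s = torch.linalg.svdvals(O)
print(f"singular values in [{s.min():.2f}, {s.max():.2f}]")  # squeezed into a band around 1
```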
Phase 3: All-Gather (Parameter Replication)
After computing updates, ranks replicate their updated parameters to all other ranks.
From the nanochat codebase (view on GitHub):
# Replicate updated parameters to all ranks
ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
ag_output = params[base_i:base_i + world_size]
ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))])
work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
all_gather_futures.append(work)

# Wait for all gathers to complete (outside the loops)
torch.futures.collect_all(all_gather_futures).wait()
What happens:
- Input: Each rank's owned parameter (or zero_buffer if padding)
- Output: List of tensors to populate with gathered parameters
- All-gather: Broadcast each rank's parameter to all other ranks
- Synchronization: `collect_all().wait()` ensures all communications complete
After all-gather, every rank has identical copies of all parameters, ready for the next forward pass.
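Bugs in a hand-rolled ZeRO scheme usually show up as silently diverging replicas, so a cheap invariant check is worth having. The helper below is ours (not part of nanochat) and assumes an already-initialized process group:

```python
import torch
import torch.distributed as dist

def assert_params_replicated(params) -> None:
    # Debug-only check: after the all_gather phase, every rank should hold
    # identical parameters. Compares a scalar fingerprint of each tensor
    # across ranks via MIN/MAX all_reduce. Call occasionally, not every step.
    for p in params:
        fp = p.detach().double().sum()
        fp_min, fp_max = fp.clone(), fp.clone()
        dist.all_reduce(fp_min, op=dist.ReduceOp.MIN)
        dist.all_reduce(fp_max, op=dist.ReduceOp.MAX)
        assert torch.equal(fp_min, fp_max), "parameter replicas diverged across ranks"
```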
Comparing DistMuon and DistAdamW
Both optimizers implement ZeRO-2, but their strategies differ due to algorithmic requirements.
Key Differences
| Feature | DistAdamW | DistMuon | Reason |
|---|---|---|---|
| Parameter Requirements | Any shape | 2D only | Newton-Schulz needs matrices |
| Sharding Strategy | Slice along dim 0 | Block-cyclic whole params | Preserve aspect ratio |
| State Storage | Slice-local (exp_avg, exp_avg_sq) | Param-local (momentum_buffer) | Matrix operations |
| Compute Pattern | All ranks on slices | Owner ranks only | Simplify NS batching |
| Reduce-scatter Input | Full tensor | List of tensors | Shape uniformity |
| Memory Efficiency | ~1/N states | ~1/N states | Similar overall |
| Load Balance | Perfect (slicing) | Imperfect (padding) | Trade-off for simplicity |
DistAdamW's Sharding Approach
From the nanochat codebase (view on GitHub):
for base_i in range(len(params)):
    grad = params[base_i].grad
    rank_size = grad.shape[0] // world_size
    grad_slice = torch.empty_like(grad[:rank_size])
    reduce_scatter_futures.append(
        dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True)
    )
DistAdamW slices each parameter along the first dimension, distributing rows across ranks. This works for any parameter shape and achieves perfect load balancing.
WARNING
Why doesn't DistMuon do this? Newton-Schulz requires the full 2D matrix structure with its original aspect ratio. Slicing a [768, 768] matrix into [192, 768] slices would change the aspect ratio from 1:1 to 1:4, breaking the orthogonalization geometry. DistMuon preserves matrices intact by assigning whole parameters to owner ranks.
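A small experiment makes the warning concrete. The snippet below (ours) uses an exact SVD-based orthogonalization as a stand-in for Newton-Schulz and shows that orthogonalizing four [192, 768] row-slices independently gives a very different result than orthogonalizing the full [768, 768] matrix:

```python
import torch

def orthogonalize(M: torch.Tensor) -> torch.Tensor:
    # Exact U V^T orthogonalization: the quantity Newton-Schulz approximates.
    U, _, Vh = torch.linalg.svd(M, full_matrices=False)
    return U @ Vh

torch.manual_seed(0)
G = torch.randn(768, 768)
whole = orthogonalize(G)                                    # full matrix, 1:1 aspect ratio
sliced = torch.cat([orthogonalize(s) for s in G.chunk(4)])  # four 192x768 slices, 1:4 each

print(torch.allclose(whole, sliced, atol=1e-4))              # False: the updates differ
print(((whole - sliced).norm() / whole.norm()).item())       # large relative difference
```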
Comparison: Sharding Granularity
Example: 4 GPUs, parameter shape [1024, 768]
DistAdamW:
┌────────────────────┐
│ Rank 0 (256 rows)  │   Each rank stores:
├────────────────────┤   - param slice      [256, 768]
│ Rank 1 (256 rows)  │   - grad slice       [256, 768]
├────────────────────┤   - exp_avg slice    [256, 768]
│ Rank 2 (256 rows)  │   - exp_avg_sq slice [256, 768]
├────────────────────┤
│ Rank 3 (256 rows)  │
└────────────────────┘
DistMuon (within a block of 4 params):
Rank 0: param[0] [1024,768] ← Full matrix
Rank 1: param[1] [1024,768] ← Full matrix
Rank 2: param[2] [1024,768] ← Full matrix
Rank 3: param[3] [1024,768] ← Full matrix
Each rank stores:
- Full param (replicated)
- momentum_buffer for owned param only
Memory Analysis and Efficiency Gains
Memory Breakdown Per Rank
Standard DDP + Muon:
Parameters: P bytes (full model)
Gradients: P bytes (full model)
Momentum buffers: P bytes (full model)
─────────────────────────────
Total per rank: 3P bytes
Total across N: 3P × N bytes
DistMuon (ZeRO-2):
Parameters: P bytes (replicated)
Gradients: P/N bytes (sharded)
Momentum buffers: P/N bytes (sharded)
─────────────────────────────
Total per rank: P(1 + 2/N) bytes
Total across N: P(N + 2) bytes
Efficiency Calculations
| N ranks | DDP Total | DistMuon Total | Memory Savings | Savings % |
|---|---|---|---|---|
| 2 | 6P | 4P | 2P | 33% |
| 4 | 12P | 6P | 6P | 50% |
| 8 | 24P | 10P | 14P | 58% |
| 16 | 48P | 18P | 30P | 63% |
| 64 | 192P | 66P | 126P | 66% |
Asymptotic behavior: As N → ∞, savings approach 67% (2/3 reduction).
Practical Example: nanochat's 270M Model
nanochat's depth-20 model has ~270M parameters in bfloat16 (540 MB total).
8× H100 GPUs (80GB each):
| Metric | Standard DDP | DistMuon | Savings |
|---|---|---|---|
| Params | 540 MB | 540 MB | 0 MB |
| Grads | 540 MB | 67.5 MB | 472.5 MB |
| States | 540 MB | 67.5 MB | 472.5 MB |
| Total/rank | 1.62 GB | 675 MB | 945 MB (58%) |
| Total/cluster | 12.96 GB | 5.4 GB | 7.56 GB |
This 945 MB per-rank saving allows:
- Larger batch sizes (more memory for activations)
- Longer sequences (quadratic attention memory)
- Bigger models (more than 2× the parameters fit in the same per-rank budget for params, grads, and optimizer state)
Implementation Deep-Dive: Async Communication
Why Asynchronous Operations?
DistMuon uses async_op=True throughout to overlap communication with computation:
# Launch all reduce-scatters without waiting
for group in self.param_groups:
    for base_i in range(0, len(params), world_size):
        work = dist.reduce_scatter(..., async_op=True).get_future()
        all_reduce_futures.append(work)

# Compute and gather (wait only when needed)
for group in self.param_groups:
    for base_i in range(0, len(params), world_size):
        all_reduce_futures[future_idx].wait()  # Wait for specific gradient
        if owner_idx < len(params):
            # Compute Muon update
            ...
        work = dist.all_gather(..., async_op=True).get_future()
        all_gather_futures.append(work)

# Final synchronization
torch.futures.collect_all(all_gather_futures).wait()
Benefits:
- Communication-computation overlap: While GPU computes updates for earlier parameters, network transfers gradients for later parameters
- Pipelining: Reduce-scatter and all-gather operations can overlap across parameter groups
- Lower latency: Non-blocking calls prevent idle GPU time
Synchronization Pattern
Time ──────────────────────────────────────────────►
Rank 0:
[reduce_scatter₀] [wait] [compute₀] [all_gather₀]
[reduce_scatter₁] [wait] [compute₁] [all_gather₁]
[reduce_scatter₂] [wait] ...
Rank 1:
[reduce_scatter₀] [wait] [compute₀] [all_gather₀]
[reduce_scatter₁] [wait] [compute₁] [all_gather₁]
[reduce_scatter₂] [wait] ...
TIP
Key insight: By launching all reduce-scatters first, then processing them sequentially with compute + gather, we maximize overlap.
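The future-based pattern is easy to try outside nanochat. The toy script below (ours) spawns two CPU processes with the gloo backend and overlaps an async collective with local work; it uses `all_reduce` rather than `reduce_scatter` because collective support varies by backend, but the launch → compute → wait structure is the same:

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int) -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29555")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    grad = torch.full((4,), float(rank))
    # Launch the collective without blocking...
    fut = dist.all_reduce(grad, op=dist.ReduceOp.SUM, async_op=True).get_future()
    # ...overlap unrelated local work while communication is in flight...
    local = torch.randn(256, 256) @ torch.randn(256, 256)
    # ...and wait only when the result is actually needed.
    fut.wait()
    grad /= world_size  # turn the SUM into an average
    print(f"rank {rank}: averaged grad = {grad.tolist()}, local work done ({local.norm():.0f})")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```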
Integration with Training: Seamless Drop-In Replacement
One of DistMuon's best features is its zero-friction integration.
From the training script (view on GitHub):
optimizers = model.setup_optimizers(
    unembedding_lr=unembedding_lr,
    embedding_lr=embedding_lr,
    matrix_lr=matrix_lr,
    weight_decay=weight_decay
)
adamw_optimizer, muon_optimizer = optimizers
The `setup_optimizers()` method automatically selects DistMuon when running distributed:
def setup_optimizers(self, ...):
    # Separate parameters by dimensionality
    matrix_params = [p for p in self.parameters() if p.ndim == 2]
    vector_params = [p for p in self.parameters() if p.ndim < 2]
    # Use Dist* optimizers if distributed, else regular
    if dist.is_initialized():
        from nanochat.muon import DistMuon
        from nanochat.adamw import DistAdamW
        muon_opt = DistMuon(matrix_params, lr=matrix_lr, ...)
        adamw_opt = DistAdamW([{"params": vector_params}], lr=embedding_lr, ...)
    else:
        muon_opt = Muon(matrix_params, lr=matrix_lr, ...)
        adamw_opt = AdamW(vector_params, lr=embedding_lr, ...)
    return adamw_opt, muon_opt
Training loop (unchanged):
lrm = get_lr_multiplier(step)
for opt in optimizers:
    for group in opt.param_groups:
        group["lr"] = group["initial_lr"] * lrm
muon_momentum = get_muon_momentum(step)
for group in muon_optimizer.param_groups:
    group["momentum"] = muon_momentum
for opt in optimizers:
    opt.step()  # DistMuon.step() handles all communication!
model.zero_grad(set_to_none=True)
No special handling is needed: just call opt.step() and DistMuon orchestrates all distributed operations internally.
Performance Characteristics
Communication Cost Analysis
For P parameters and N ranks:
| Operation | Data Volume (per rank) | Time Complexity |
|---|---|---|
| Reduce-scatter | Send: P/N, Recv: P/N | O(P/N) |
| Compute (Muon) | Local only | O(P/N) |
| All-gather | Send: P/N, Recv: P | O(P) |
| Total per step | Send: 2P/N, Recv: P(1 + 1/N) | O(P) |
Key takeaway: Communication scales linearly with parameter count, independent of N for large N.
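Plugging nanochat's depth-20 model into this simplified accounting gives a feel for the volumes involved. The arithmetic below follows the table's model and ignores ring-algorithm and protocol overhead, which push real transfer sizes closer to P per collective:

```python
# Simplified per-rank communication volume, following the table above.
P_bytes = 270e6 * 2           # ~270M params, bfloat16 gradients ≈ 540 MB
N = 8
send = 2 * P_bytes / N        # reduce-scatter shard out + all-gather contribution
recv = P_bytes * (1 + 1 / N)  # full params back from all-gather + owned grad shard
print(f"send ≈ {send / 1e6:.0f} MB, recv ≈ {recv / 1e6:.1f} MB per rank per step")
# send ≈ 135 MB, recv ≈ 607.5 MB per rank per step
```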
Scaling Behavior
Weak scaling (increase model size proportionally with GPUs):
- ✅ Near-linear: Memory per rank stays constant
- ✅ Communication stays constant: Each rank sends/receives same amount
Strong scaling (fixed model size, increase GPUs):
- ⚠️ Sub-linear: Communication overhead increases relative to computation
- ⚠️ Sweet spot: 8-64 GPUs for typical Transformers
- ❌ Poor at scale: Beyond 128 GPUs, communication dominates
Example: Training nanochat's 270M model
- 8 GPUs: ~90% scaling efficiency
- 64 GPUs: ~75% scaling efficiency
- 512 GPUs: ~40% scaling efficiency (not recommended)
Conclusion
DistMuon demonstrates that domain-specific optimizations can significantly improve distributed training efficiency. By tailoring the ZeRO-2 pattern to Muon's unique needs—preserving matrix structure, enabling batched Newton-Schulz, and implementing block-cyclic assignment—nanochat achieves:
- 58-67% memory savings vs standard DDP (8-64 GPUs)
- Seamless integration with existing codebases
- Efficient scaling for typical Transformer training workloads
Key Takeaways
✅ When to use DistMuon:
- Training Transformers with 2D weight matrices
- Memory-constrained multi-GPU setups (8-64 GPUs)
- Want ZeRO-2 benefits without DeepSpeed dependency
❌ When to avoid:
- Single GPU training (use regular Muon)
- Models with mostly 1D parameters (use DistAdamW)
- Extreme scale (>128 GPUs, consider ZeRO-3 or model parallelism)
Design Principles Worth Remembering
- Group by shape: Enable batched operations by processing uniform tensors together
- Block-cyclic assignment: Balance load while maintaining simplicity
- Async communication: Overlap network transfers with computation
- Preserve algorithmic invariants: Don't break Newton-Schulz by slicing matrices
What's Next in This Series
💾 Post 1.3: KV Caching Deep-Dive (Coming Soon)
Memory-efficient Transformer inference with prefill-and-clone patterns and dynamic cache growth.
🏗️ Post 1.4: Modern Architecture Choices (Coming Soon)
RoPE, QK normalization, Multi-Query Attention, and design trade-offs explained.
Further Reading
- ZeRO Paper (Rajbhandari et al., 2020) - Original ZeRO optimization stages
- PyTorch DDP Tutorial - Understanding standard distributed training
- PyTorch Distributed Collective Ops - `reduce_scatter` and `all_gather` documentation
- Muon Optimizer - Original Muon paper
- Post 1.1: The Muon Optimizer Explained - Prerequisite reading
- nanochat source: muon.py on GitHub
About this series: This is part of a comprehensive blog series exploring the technical innovations in nanochat, Andrej Karpathy's minimal ChatGPT implementation.