José David Baena

Distributed Muon: Custom Gradient Synchronization for Memory-Efficient Training


📚 nanochat Blog Series - Track 1: Technical Deep-Dives

Part 2 of 6 - Understanding the "Why" behind nanochat's technical innovations

  • 1.1 The Muon Optimizer Explained
  • 1.2 Distributed Muon (You are here)
  • 1.3 KV Caching Deep-Dive (Coming soon)
  • 1.4 Modern Transformer Architecture (Coming soon)
  • 1.5 Training Data Pipeline (Coming soon)
  • 1.6 Loss Landscape & Scaling Laws (Coming soon)

Introduction

Training large language models requires distributing computation across multiple GPUs. While PyTorch's Distributed Data Parallel (DDP) makes this conceptually simple, it comes with significant memory overhead—every GPU stores a complete copy of the model parameters, gradients, and optimizer states.

For a 1 billion parameter model on 8 GPUs, this means eight full copies of everything, seven of them redundant.

Enter ZeRO (Zero Redundancy Optimizer), which eliminates this redundancy by sharding optimizer states and gradients across GPUs. Microsoft's DeepSpeed popularized ZeRO, but nanochat implements a custom variant specifically tailored for the Muon optimizer.

Why custom? Because Muon's Newton-Schulz orthogonalization requires preserving the 2D matrix structure of parameters—you can't just slice weight matrices arbitrarily.

NOTE

Prerequisites: Understanding of the Muon optimizer and basic distributed training concepts. Reading time: ~12 minutes.

In this post, we'll dissect DistMuon, nanochat's distributed Muon implementation that achieves:

  • ~2-3× memory savings compared to standard DDP
  • Seamless integration with existing training loops
  • Custom reduce_scatter → compute → all_gather pattern optimized for Muon

Let's explore how it works under the hood.


The DDP Baseline: Understanding the Problem

Standard DDP's Synchronization Model

PyTorch DDP follows a simple but memory-inefficient pattern:

Standard DDP Training Pattern
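# Sketch of a standard DDP loop (assumes dist.init_process_group() has run;
# local_rank and train_loader are placeholders)
from torch.nn.parallel import DistributedDataParallel as DDP

model = DDP(model, device_ids=[local_rank])
for x, y in train_loader:
    loss = model(x, y)                     # forward pass; the model returns the loss
    loss.backward()                        # DDP hooks all_reduce (average) grads during backward
    optimizer.step()                       # each rank updates its full model copy independently
    optimizer.zero_grad(set_to_none=True)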

During backward(), DDP's hooks automatically trigger an all_reduce operation that averages gradients across all ranks. This ensures every GPU has identical gradients before the optimizer step.
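The averaging that DDP performs can be modeled in a few lines of plain Python. This toy single-process sketch (hypothetical helper names, floats instead of tensors, one gradient entry per rank-sized shard) also checks a property that the rest of this post builds on: an averaging all_reduce gives the same result as an averaging reduce-scatter followed by an all-gather.

```python
def all_reduce_avg(per_rank):
    """per_rank[r] is rank r's list of local gradient entries; every rank gets the mean."""
    n = len(per_rank)
    mean = [sum(vals) / n for vals in zip(*per_rank)]
    return [list(mean) for _ in range(n)]


def reduce_scatter_then_all_gather(per_rank):
    """Same result in two steps: rank r averages only entry r, then all ranks share."""
    n = len(per_rank)
    shards = [sum(per_rank[k][r] for k in range(n)) / n for r in range(n)]  # reduce-scatter
    return [list(shards) for _ in range(n)]                                  # all-gather


# grads[k][i]: rank k's local gradient entry i (4 ranks, 4 entries each)
grads = [[1.0, 2.0, 3.0, 4.0],
         [3.0, 2.0, 1.0, 0.0],
         [2.0, 2.0, 2.0, 2.0],
         [6.0, 2.0, 2.0, 2.0]]
print(all_reduce_avg(grads)[0])
print(all_reduce_avg(grads) == reduce_scatter_then_all_gather(grads))
```

The two-step form is what lets an optimizer insert work between the steps: after the reduce-scatter, each rank holds a finished average for only its own shard.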

Memory Overhead Analysis

For each parameter, every rank stores:

  1. Parameters (model weights): P bytes
  2. Gradients: P bytes
  3. Optimizer states: Depends on optimizer
    • Adam/AdamW: 2 states (exp_avg, exp_avg_sq) = 2P bytes
    • Muon: 1 state (momentum_buffer) = P bytes

Total per rank (Muon + DDP): P + P + P = 3P bytes
Total across N ranks: 3P × N bytes

For a 1B parameter model (2GB in bfloat16) on 8 GPUs:

Memory per rank = 2GB (params) + 2GB (grads) + 2GB (momentum) = 6GB
Total memory = 6GB × 8 = 48GB

This redundancy is wasteful. Can we do better?

What ZeRO-2 Offers

ZeRO has three stages of optimization:

  • Stage 1: Shard optimizer states across ranks
  • Stage 2: Shard gradients + optimizer states
  • Stage 3: Shard parameters + gradients + optimizer states

DistMuon implements ZeRO-2, keeping parameters replicated but sharding gradients and optimizer states. This strikes a balance between memory efficiency and implementation complexity.

Memory per rank (ZeRO-2): P + P/N + P/N = P(1 + 2/N) bytes

For our 1B parameter example with 8 GPUs:

Memory per rank = 2GB (params) + 0.25GB (grads/8) + 0.25GB (momentum/8) = 2.5GB
Total memory = 2.5GB × 8 = 20GB
Savings: 48GB - 20GB = 28GB (58% reduction!)
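The arithmetic above can be reproduced with a small helper. This is a minimal sketch assuming, as the text does, that gradients and the momentum buffer have the same size as the parameters (P bytes each); memory_per_rank is a hypothetical name, not part of nanochat, and activations are ignored.

```python
def memory_per_rank(p_bytes: float, n_ranks: int, zero2: bool) -> float:
    """Per-rank bytes for params + grads + one Muon momentum buffer.

    Assumes grads and momentum are the same size as the params (P bytes each).
    """
    if zero2:
        # Params stay replicated; grads and momentum are sharded 1/N per rank.
        return p_bytes * (1 + 2 / n_ranks)
    # DDP: params, grads, and momentum are all replicated on every rank.
    return 3 * p_bytes


P_GB = 2.0  # 1B parameters in bfloat16
N = 8
ddp = memory_per_rank(P_GB, N, zero2=False)
zero = memory_per_rank(P_GB, N, zero2=True)
print(f"DDP:    {ddp:.2f} GB/rank, {ddp * N:.1f} GB total")   # 6.00 GB/rank, 48.0 GB total
print(f"ZeRO-2: {zero:.2f} GB/rank, {zero * N:.1f} GB total")  # 2.50 GB/rank, 20.0 GB total
print(f"Savings: {1 - zero / ddp:.1%}")                         # 58.3%
```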

DistMuon Architecture: Three Key Design Decisions

Parameter Grouping by Shape

DistMuon groups all parameters by their shape before assigning ownership.

From the nanochat codebase:

nanochat/muon.py - Parameter Grouping
rank = dist.get_rank()
shapes = sorted({p.shape for p in params})  # Unique shapes, sorted
param_groups = []
for shape in shapes:
    group_params = [p for p in params if p.shape == shape]
    device, dtype = group_params[0].device, group_params[0].dtype
    assert all(p.device == device for p in group_params)
    assert all(p.dtype == dtype for p in group_params)
    if rank == 0:
        print(f"Muon: Grouping {len(group_params)} params of shape {shape}")
    param_groups.append(dict(
        params=group_params,
        zero_buffer=torch.zeros_like(group_params[0])
    ))

Why group by shape?

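  1. Uniform communication: the list-based reduce_scatter and all_gather calls expect every tensor in a block to have the same shape, so only same-shape parameters can share a block
  2. Batching potential: Newton-Schulz can process several same-shape matrices in one batched call (the compute phase shown below still orthogonalizes one owned matrix at a time)
  3. Better GPU utilization: where batching is used, batched matrix operations keep the GPU busier than many small launches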

TIP

Example: A transformer model might have 100 parameters of shape [768, 768] (attention matrices), 50 of shape [3072, 768] (FFN matrices), and 50 of shape [768, 3072]. DistMuon creates 3 parameter groups, so every block of world_size parameters within a group can be reduce-scattered and all-gathered as a list of same-shape tensors.
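A minimal sketch of the grouping step using plain shape tuples instead of tensors, with the hypothetical shape counts from the example above. group_by_shape is an illustrative name, not nanochat's API; it mirrors the sorted-shapes loop in nanochat/muon.py.

```python
# Hypothetical shapes mirroring the example: attention and FFN weight matrices.
shapes = [(768, 768)] * 100 + [(3072, 768)] * 50 + [(768, 3072)] * 50


def group_by_shape(param_shapes):
    """Return {shape: [param indices]} with shapes in sorted (deterministic) order.

    Sorting keeps the group order identical on every rank, so the per-group
    collectives line up across ranks.
    """
    groups = {}
    for shape in sorted(set(param_shapes)):
        groups[shape] = [i for i, s in enumerate(param_shapes) if s == shape]
    return groups


groups = group_by_shape(shapes)
for shape, idxs in groups.items():
    print(f"Muon: Grouping {len(idxs)} params of shape {shape}")
```

Because tuples sort lexicographically, (768, 3072) comes before (3072, 768); the exact order matters less than the fact that every rank computes the same one.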

Block-Cyclic Parameter Assignment

Within each shape group, parameters are assigned to ranks in a block-cyclic pattern:

Block-Cyclic Assignment Pattern
world_size = dist.get_world_size()
for base_i in range(0, len(params), world_size):
    owner_idx = base_i + rank  # Each rank owns param at (base + rank)

Visual representation (4 GPUs, 10 parameters):

Param indices:  [0, 1, 2, 3,  4, 5, 6, 7,  8, 9]
                 └─────────┘  └─────────┘  └──┘
                  Block 0      Block 1    Block 2

Rank 0 owns:    [0,          4,          8    ]  ← indices 0, 4, 8
Rank 1 owns:    [   1,          5,          9 ]  ← indices 1, 5, 9
Rank 2 owns:    [      2,          6          ]  ← indices 2, 6
Rank 3 owns:    [         3,          7       ]  ← indices 3, 7

Why block-cyclic?

  • Load balancing: Distributes parameters roughly evenly across ranks
  • Simplicity: Each rank's ownership is a simple calculation (base_i + rank)
  • Uneven counts handled: when a group's size is not a multiple of world_size, the last block is padded with zero buffers, so some ranks simply own one parameter fewer
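The ownership rule can be checked directly against the 4-GPU, 10-parameter picture above. This is an illustrative sketch (owned_indices is a hypothetical helper) that walks the blocks exactly as the snippet does and skips padding slots.

```python
def owned_indices(num_params: int, world_size: int, rank: int) -> list:
    """Indices of the params `rank` owns under the cyclic assignment above."""
    owned = []
    for base_i in range(0, num_params, world_size):
        owner_idx = base_i + rank
        if owner_idx < num_params:  # slots past the end are padding, owned by nobody
            owned.append(owner_idx)
    return owned


for r in range(4):
    print(f"Rank {r} owns: {owned_indices(10, 4, r)}")
```

Equivalently, parameter i is owned by rank i % world_size, i.e. `list(range(rank, num_params, world_size))`.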

The Three-Phase Update Pattern

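DistMuon orchestrates a custom three-phase communication pattern that shards the optimizer computation while keeping parameters replicated:

  1. Reduce-scatter: average gradients across ranks and deliver each parameter's averaged gradient to its owner rank
  2. Compute: each owner rank applies the Muon update to the parameters it owns
  3. All-gather: owners send their updated parameters to every rank, restoring full replicas

Let's examine each phase in detail.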


Phase 1: Reduce-Scatter (Gradient Averaging)

The reduce_scatter operation combines gradients from all ranks and distributes the averaged results.

From the nanochat codebase:

nanochat/muon.py - Reduce-Scatter Phase
all_reduce_futures = []
for group in self.param_groups:
    params = group["params"]
    zero_buffer = group["zero_buffer"]
    for base_i in range(0, len(params), world_size):
        owner_idx = base_i + rank
        # Each rank collects gradients for world_size consecutive params
        rs_input = [p.grad for p in params[base_i:base_i + world_size]]
        # Pad with zeros if we don't have enough params
        rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
        # Output buffer: gradient for the param this rank owns
        rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer)
        # Launch async reduce_scatter
        work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future()
        all_reduce_futures.append(work)

What happens here:

  1. Input collection: Each rank lists its own local gradients for a block of world_size parameters
  2. Padding: If the block is incomplete (e.g., last block with fewer params), pad with zero_buffer
  3. Reduce-scatter: All ranks participate in averaging gradients
  4. Output: Each rank receives the averaged gradient for its owned parameter

Example (4 GPUs, block 0 with params [0,1,2,3]; gradᵣ[i] is rank r's local gradient for param i, and avg(grad[i]) is its mean over all four ranks):

Rank 0: rs_input = [grad₀[0], grad₀[1], grad₀[2], grad₀[3]]  → rs_output = avg(grad[0])
Rank 1: rs_input = [grad₁[0], grad₁[1], grad₁[2], grad₁[3]]  → rs_output = avg(grad[1])
Rank 2: rs_input = [grad₂[0], grad₂[1], grad₂[2], grad₂[3]]  → rs_output = avg(grad[2])
Rank 3: rs_input = [grad₃[0], grad₃[1], grad₃[2], grad₃[3]]  → rs_output = avg(grad[3])

After reduce-scatter, each rank has the averaged gradient for its owned parameter, ready for computation.
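To make the semantics concrete, here is a single-process model of an averaging reduce-scatter on an incomplete block (3 parameters, 4 ranks), including the zero padding from the code above. It uses toy floats instead of tensors; simulate_reduce_scatter_avg is a hypothetical stand-in for dist.reduce_scatter(output, input_list, op=ReduceOp.AVG).

```python
def simulate_reduce_scatter_avg(rs_inputs):
    """Single-process model of an averaging reduce-scatter.

    rs_inputs[k] is rank k's input list (one entry per slot; floats here).
    Returns outputs[r] = mean over ranks k of rs_inputs[k][r].
    """
    world_size = len(rs_inputs)
    return [sum(rs_inputs[k][r] for k in range(world_size)) / world_size
            for r in range(world_size)]


world_size = 4
num_params = 3  # incomplete block: slot 3 is padding
# local_grads[k][i]: rank k's local gradient for param i (toy values)
local_grads = [[float((k + 1) * (i + 1)) for i in range(num_params)]
               for k in range(world_size)]

rs_inputs = []
for k in range(world_size):
    row = list(local_grads[k])
    row.extend([0.0] * (world_size - len(row)))  # pad with the "zero_buffer"
    rs_inputs.append(row)

outputs = simulate_reduce_scatter_avg(rs_inputs)
print(outputs)  # rank 3's value lands in a scratch buffer and is ignored
```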


Phase 2: Compute Update (Owner Ranks Only)

Once gradients are averaged, each rank computes the Muon update for its owned parameters.

From the nanochat codebase:

nanochat/muon.py - Compute Phase
future_idx = 0
all_gather_futures = []
for group in self.param_groups:
    params = group["params"]
    for base_i in range(0, len(params), world_size):
        owner_idx = base_i + rank
        # Wait for reduce_scatter to complete
        all_reduce_futures[future_idx].wait()
        future_idx += 1
        
        # Only owner computes the update
        if owner_idx < len(params):
            p = params[owner_idx]
            g = p.grad  # Already averaged across ranks
            state = self.state[p]
            
            # Initialize momentum buffer if needed
            if "momentum_buffer" not in state:
                state["momentum_buffer"] = torch.zeros_like(g)
            
            buf = state["momentum_buffer"]
            # Momentum accumulation
            buf.lerp_(g, 1.0 - group["momentum"])
            # Nesterov momentum (if enabled)
            g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
            # Newton-Schulz orthogonalization
            g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
            # Aspect-ratio scaled step
            scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
            p.add_(g, alpha=-group["lr"] * scale)

Key points:

  1. Wait synchronization: wait() ensures the gradient is ready before computation
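  2. Owner-only execution: each parameter is updated by exactly one rank; a rank skips computation only when its slot in a block is padding (owner_idx >= len(params)), so it idles only during incomplete final blocks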
  3. Standard Muon update: Same as single-GPU Muon (see Post 1.1)
    • Momentum accumulation with lerp_
    • Optional Nesterov momentum
    • Newton-Schulz orthogonalization
    • Aspect-ratio scaling: sqrt(max(1, height/width))
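The momentum and scaling parts of the update can be traced by hand. This is a scalar sketch under the assumption that one matrix entry is representative (Newton-Schulz is omitted because it acts on the whole matrix); lerp, muon_momentum_step, and aspect_scale are hypothetical helpers that reproduce the formulas in the code above.

```python
def lerp(start: float, end: float, weight: float) -> float:
    """Same formula torch.Tensor.lerp_ uses: start + weight * (end - start)."""
    return start + weight * (end - start)


def muon_momentum_step(grad: float, buf: float, momentum: float, nesterov: bool):
    """Momentum part of one Muon step for a single entry (before Newton-Schulz).

    Returns (direction, new_buf).
    """
    buf = lerp(buf, grad, 1.0 - momentum)   # buf = momentum * buf + (1 - momentum) * grad
    # Nesterov: direction = (1 - momentum) * grad + momentum * buf
    direction = lerp(grad, buf, momentum) if nesterov else buf
    return direction, buf


def aspect_scale(rows: int, cols: int) -> float:
    """sqrt(max(1, rows / cols)): tall matrices get a larger step, wide ones don't."""
    return max(1.0, rows / cols) ** 0.5


d, b = muon_momentum_step(grad=1.0, buf=0.0, momentum=0.95, nesterov=True)
print(round(b, 4), round(d, 4))                          # 0.05 0.0975
print(aspect_scale(3072, 768), aspect_scale(768, 3072))  # 2.0 1.0
```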

NOTE

Memory efficiency: Each rank stores momentum_buffer only for its owned parameters, achieving 1/N sharding.


Phase 3: All-Gather (Parameter Replication)

After computing updates, ranks replicate their updated parameters to all other ranks.

From the nanochat codebase:

nanochat/muon.py - All-Gather Phase
        # (continues inside the per-block loop of the compute phase)
        # Replicate updated parameters to all ranks
        ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
        ag_output = params[base_i:base_i + world_size]
        ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))])
        work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
        all_gather_futures.append(work)
 
# Wait for all gathers to complete (outside the loops)
torch.futures.collect_all(all_gather_futures).wait()

What happens:

  1. Input: Each rank's owned parameter (or zero_buffer if padding)
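  2. Output: the block's own parameter tensors (padded with torch.empty_like scratch buffers), so all_gather writes each owner's updated values directly into every rank's parameters
  3. All-gather: every rank receives every other rank's owned parameter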
  4. Synchronization: collect_all().wait() ensures all communications complete

After all-gather, every rank has identical copies of all parameters, ready for the next forward pass.
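Putting phases 2 and 3 together, this single-process simulation shows why every rank ends up identical even though each rank updated only one entry. Assumptions: toy scalar "parameters", averaged gradients taken as given from Phase 1, and plain SGD standing in for the Muon update; simulate_all_gather is a hypothetical stand-in for dist.all_gather.

```python
def simulate_all_gather(ag_inputs):
    """Single-process model of all_gather: rank r receives every rank's input."""
    return [list(ag_inputs) for _ in ag_inputs]


world_size, lr = 4, 0.1
avg_grads = [2.5, 5.0, 7.5]                               # Phase 1 output (toy values)
n = len(avg_grads)
params = [[10.0, 20.0, 30.0] for _ in range(world_size)]  # replicated copy on each rank

for base_i in range(0, n, world_size):
    # Phase 2: each rank updates only the param it owns in this block.
    for rank in range(world_size):
        owner_idx = base_i + rank
        if owner_idx < n:
            params[rank][owner_idx] -= lr * avg_grads[owner_idx]  # SGD stands in for Muon
    # Phase 3: each rank contributes its owned param, or 0.0 padding if it owns none ...
    ag_inputs = [params[rank][base_i + rank] if base_i + rank < n else 0.0
                 for rank in range(world_size)]
    gathered = simulate_all_gather(ag_inputs)
    # ... and every rank copies all real slots into its own params (padding is dropped).
    for rank in range(world_size):
        for slot, value in enumerate(gathered[rank]):
            if base_i + slot < n:
                params[rank][base_i + slot] = value

print(params)
```

Before the gather, rank 1's copy of param 0 is still stale (10.0); the all-gather overwrites it with rank 0's updated value, which is exactly the replication step.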


Comparing DistMuon and DistAdamW

Both optimizers implement ZeRO-2, but their strategies differ due to algorithmic requirements.

Key Differences

| Feature | DistAdamW | DistMuon | Reason |
|---|---|---|---|
| Parameter requirements | Any shape (dim 0 divisible by world_size) | 2D only | Newton-Schulz needs matrices |
| Sharding strategy | Slice along dim 0 | Block-cyclic, whole params | Preserve matrix structure and aspect ratio |
| State storage | Slice-local (exp_avg, exp_avg_sq) | Param-local (momentum_buffer) | Matrix operations |
| Compute pattern | All ranks on slices | Owner rank per param | Newton-Schulz needs the full matrix |
| Reduce-scatter input | Full tensor (reduce_scatter_tensor) | List of tensors (reduce_scatter) | Shape uniformity |
| Memory efficiency | ~1/N states | ~1/N states | Similar overall |
| Load balance | Even (equal row slices) | Uneven (padding) | Trade-off for simplicity |

DistAdamW's Sharding Approach

From the nanochat codebase:

nanochat/adamw.py - Tensor Slicing Strategy
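for base_i in range(len(params)):
    grad = params[base_i].grad
    rank_size = grad.shape[0] // world_size           # rows of dim 0 per rank
    grad_slice = torch.empty_like(grad[:rank_size])   # output: this rank's row slice
    # Average the full gradient across ranks; rank r keeps rows [r*rank_size, (r+1)*rank_size)
    reduce_scatter_futures.append(
        dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True)
    )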

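DistAdamW slices each parameter along the first dimension, distributing rows across ranks. This works for any parameter shape whose first dimension is divisible by world_size (reduce_scatter_tensor requires the input to be exactly world_size times the size of the output), and every rank gets an equal share.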
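A small sketch of the row partition this implies, using the [1024, 768] example from the diagram below. row_slices is a hypothetical helper; the divisibility check reflects the reduce_scatter_tensor size requirement rather than any explicit check in nanochat.

```python
def row_slices(num_rows: int, world_size: int):
    """Row range [start, end) of dim 0 that each rank owns.

    reduce_scatter_tensor needs input size == world_size * output size, so the
    first dimension must divide evenly; this sketch raises otherwise.
    """
    if num_rows % world_size != 0:
        raise ValueError(f"{num_rows} rows not divisible by world_size={world_size}")
    rank_size = num_rows // world_size
    return [(r * rank_size, (r + 1) * rank_size) for r in range(world_size)]


print(row_slices(1024, 4))
```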

WARNING

Why doesn't DistMuon do this? Newton-Schulz needs the full 2D matrix, because orthogonalization is not row-separable. Orthogonalizing four [192, 768] slices of a [768, 768] matrix independently yields rows that are orthonormal within each slice but not across slices, so stacking the results does not reproduce the orthogonalized full matrix. Slicing also distorts the aspect ratio (1:1 becomes 1:4). DistMuon keeps matrices intact by assigning whole parameters to owner ranks.
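A small pure-Python check makes the problem visible. It swaps in the classic cubic Newton-Schulz iteration rather than nanochat's quintic zeropower_via_newtonschulz5 (whose tuned coefficients only push singular values into a band around 1), because the cubic version converges cleanly to the exact orthogonal factor. The 4×4 matrix is arbitrary, and splitting it into two [2, 4] row slices is a miniature stand-in for the [768, 768] → [192, 768] case.

```python
def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


def newton_schulz_cubic(m, steps=40):
    """Classic cubic Newton-Schulz: X <- 1.5 X - 0.5 (X X^T) X.

    After Frobenius normalization every singular value is <= 1, and the
    iteration drives them all to 1, approaching the orthogonal factor U V^T.
    """
    norm = sum(v * v for row in m for v in row) ** 0.5
    x = [[v / norm for v in row] for row in m]
    for _ in range(steps):
        xxt_x = matmul(matmul(x, transpose(x)), x)
        x = [[1.5 * x[i][j] - 0.5 * xxt_x[i][j] for j in range(len(x[0]))]
             for i in range(len(x))]
    return x


a = [[2.0, 1.0, 0.0, 0.0],
     [1.0, 3.0, 1.0, 0.0],
     [0.0, 1.0, 4.0, 1.0],
     [0.0, 0.0, 1.0, 5.0]]

full = newton_schulz_cubic(a)                                      # whole matrix
sliced = newton_schulz_cubic(a[:2]) + newton_schulz_cubic(a[2:])   # two row slices, stacked


def max_cross_term(x):
    """Largest |<row i, row j>| between a top-half row and a bottom-half row."""
    gram = matmul(x, transpose(x))
    return max(abs(gram[i][j]) for i in range(2) for j in range(2, 4))


print(f"full matrix:   max cross-row dot = {max_cross_term(full):.2e}")
print(f"sliced matrix: max cross-row dot = {max_cross_term(sliced):.2f}")
```

The full result has orthonormal rows throughout, while the stacked slices are orthonormal only within each slice; the cross-slice dot products stay far from zero.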

Comparison: Sharding Granularity

Example: 4 GPUs, parameter shape [1024, 768]

DistAdamW (one parameter, sliced by rows):
┌───────────────────┐
│ Rank 0 (256 rows) │  Each rank stores:
├───────────────────┤  - full param [1024, 768] (replicated; updates only its slice)
│ Rank 1 (256 rows) │  - grad slice [256, 768] (after reduce-scatter)
├───────────────────┤  - exp_avg slice [256, 768]
│ Rank 2 (256 rows) │  - exp_avg_sq slice [256, 768]
├───────────────────┤
│ Rank 3 (256 rows) │
└───────────────────┘

DistMuon (one block of 4 same-shape params, each kept whole):
Rank 0 owns param[0] [1024, 768]  ← full matrix
Rank 1 owns param[1] [1024, 768]  ← full matrix
Rank 2 owns param[2] [1024, 768]  ← full matrix
Rank 3 owns param[3] [1024, 768]  ← full matrix

Each rank stores:
- all params (replicated)
- momentum_buffer for its owned params only

Memory Analysis and Efficiency Gains

Memory Breakdown Per Rank

Standard DDP + Muon:

Parameters:          P bytes (full model)
Gradients:           P bytes (full model)
Momentum buffers:    P bytes (full model)
─────────────────────────────
Total per rank:      3P bytes
Total across N:      3P × N bytes

DistMuon (ZeRO-2):

Parameters:          P bytes (replicated)
Gradients:           P/N bytes (sharded)
Momentum buffers:    P/N bytes (sharded)
─────────────────────────────
Total per rank:      P(1 + 2/N) bytes
Total across N:      P(N + 2) bytes

Efficiency Calculations

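| N ranks | DDP total | DistMuon total | Memory savings | Savings % |
|---|---|---|---|---|
| 2 | 6P | 4P | 2P | 33% |
| 4 | 12P | 6P | 6P | 50% |
| 8 | 24P | 10P | 14P | 58% |
| 16 | 48P | 18P | 30P | 63% |
| 64 | 192P | 66P | 126P | 66% |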

Asymptotic behavior: As N → ∞, savings approach 67% (2/3 reduction).
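The limit follows from the two totals in the table (3PN for DDP, P(N + 2) for DistMuon); for N = 8 it gives 2/3 − 1/12 = 7/12 ≈ 58%, matching the row above.

```latex
\text{savings}(N) = 1 - \frac{P(N+2)}{3PN} = \frac{2N-2}{3N} = \frac{2}{3} - \frac{2}{3N}
\;\longrightarrow\; \frac{2}{3} \quad (N \to \infty)
```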

Practical Example: nanochat's 270M Model

nanochat's depth-20 model has ~270M parameters in bfloat16 (540 MB total).

8× H100 GPUs (80GB each):

| Metric | Standard DDP | DistMuon | Savings |
|---|---|---|---|
| Params | 540 MB | 540 MB | 0 MB |
| Grads | 540 MB | 67.5 MB | 472.5 MB |
| States | 540 MB | 67.5 MB | 472.5 MB |
| Total/rank | 1.62 GB | 675 MB | 945 MB (58%) |
| Total/cluster | 12.96 GB | 5.4 GB | 7.56 GB |

This 945 MB per-rank saving allows:

  • Larger batch sizes (more memory for activations)
  • Longer sequences (quadratic attention memory)
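  • Bigger models (about 648M parameters fit in the 1.62 GB per rank that the 270M model needs under DDP: 1.62 GB / 1.25 = 1.296 GB, or 648M bfloat16 parameters, ignoring activations)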

Implementation Deep-Dive: Async Communication

Why Asynchronous Operations?

DistMuon uses async_op=True throughout to overlap communication with computation:

Asynchronous Communication Pattern
all_reduce_futures, all_gather_futures = [], []

# Launch all reduce-scatters without waiting
for group in self.param_groups:
    params = group["params"]
    for base_i in range(0, len(params), world_size):
        work = dist.reduce_scatter(..., async_op=True).get_future()
        all_reduce_futures.append(work)

# Compute and gather (wait only when needed)
future_idx = 0
for group in self.param_groups:
    params = group["params"]
    for base_i in range(0, len(params), world_size):
        owner_idx = base_i + rank
        all_reduce_futures[future_idx].wait()  # Wait for this block's gradients only
        future_idx += 1
        if owner_idx < len(params):
            # Compute Muon update
            ...
        work = dist.all_gather(..., async_op=True).get_future()
        all_gather_futures.append(work)

# Final synchronization
torch.futures.collect_all(all_gather_futures).wait()

Benefits:

  1. Communication-computation overlap: While GPU computes updates for earlier parameters, network transfers gradients for later parameters
  2. Pipelining: compute for one block can run while reduce-scatters and all-gathers for other blocks are still in flight
  3. Less idle time: non-blocking calls keep the GPU from stalling on every collective
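The launch-everything-then-wait-in-order pattern can be illustrated without GPUs. This is a CPU-only analogy using Python threads from concurrent.futures, not torch.distributed; fake_collective is a hypothetical stand-in with a fixed 50 ms latency. Real collectives share network bandwidth, so the timing here only illustrates the shape of the overlap.

```python
import time
from concurrent.futures import ThreadPoolExecutor, wait


def fake_collective(value: int) -> int:
    """Stand-in for a network collective: takes 50 ms, then returns its input."""
    time.sleep(0.05)
    return value


num_blocks = 4
start = time.perf_counter()
with ThreadPoolExecutor(max_workers=num_blocks) as pool:
    # Launch every "reduce-scatter" up front, like the async_op=True calls.
    rs_futures = [pool.submit(fake_collective, b) for b in range(num_blocks)]
    gather_futures = []
    for fut in rs_futures:
        grad = fut.result()          # block only on the block we need next
        update = grad + 100          # toy "compute" while later transfers are in flight
        gather_futures.append(pool.submit(fake_collective, update))  # async "all-gather"
    wait(gather_futures)             # analogous to torch.futures.collect_all(...).wait()
elapsed = time.perf_counter() - start

results = [f.result() for f in gather_futures]
print(results, f"{elapsed:.2f}s")    # serial execution would take at least 8 x 50 ms
```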

Synchronization Pattern

Time ──────────────────────────────────────────────►

Every rank runs the same schedule:
  Network: [rs₀][rs₁][rs₂] ...        [ag₀] [ag₁] [ag₂] ...
  Compute:      [compute₀] [compute₁] [compute₂] ...

  rs = reduce_scatter (all launched up front), ag = all_gather.
  Each compute step waits only for its own block's reduce_scatter, and
  that block's all_gather is launched as soon as the compute step ends.

TIP

Key insight: By launching all reduce-scatters first, then processing them sequentially with compute + gather, we maximize overlap.


Integration with Training: Seamless Drop-In Replacement

One of DistMuon's best features is its zero-friction integration.

From the training script:

scripts/base_train.py - Optimizer Setup
optimizers = model.setup_optimizers(
    unembedding_lr=unembedding_lr,
    embedding_lr=embedding_lr,
    matrix_lr=matrix_lr,
    weight_decay=weight_decay
)
adamw_optimizer, muon_optimizer = optimizers

The setup_optimizers() method automatically selects DistMuon (and DistAdamW) when running distributed. A simplified sketch of that selection logic:

Automatic DistMuon Selection Pattern
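import torch
import torch.distributed as dist

def setup_optimizers(self, unembedding_lr, embedding_lr, matrix_lr, weight_decay):
    # Simplified sketch: split by dimensionality. In practice the 2D embedding and
    # unembedding matrices also belong with AdamW, since Muon targets hidden-layer matrices.
    matrix_params = [p for p in self.parameters() if p.ndim == 2]
    vector_params = [p for p in self.parameters() if p.ndim < 2]

    # Use Dist* optimizers if distributed, else the single-GPU versions
    if dist.is_initialized():
        from nanochat.muon import DistMuon
        from nanochat.adamw import DistAdamW
        muon_opt = DistMuon(matrix_params, lr=matrix_lr)  # other hyperparameters elided
        adamw_opt = DistAdamW([{"params": vector_params}], lr=embedding_lr)  # elided
    else:
        from nanochat.muon import Muon
        muon_opt = Muon(matrix_params, lr=matrix_lr)  # other hyperparameters elided
        adamw_opt = torch.optim.AdamW(vector_params, lr=embedding_lr)  # elided

    # The training loop rescales group["lr"] from group["initial_lr"], so record it here
    for opt in (adamw_opt, muon_opt):
        for group in opt.param_groups:
            group["initial_lr"] = group["lr"]
    return adamw_opt, muon_opt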

Training loop (unchanged):

Training Loop - No Special Handling Needed
lrm = get_lr_multiplier(step)
for opt in optimizers:
    for group in opt.param_groups:
        group["lr"] = group["initial_lr"] * lrm
muon_momentum = get_muon_momentum(step)
for group in muon_optimizer.param_groups:
    group["momentum"] = muon_momentum
for opt in optimizers:
    opt.step()  # DistMuon.step() handles all communication!
model.zero_grad(set_to_none=True)

No special handling needed—just call opt.step() and DistMuon orchestrates all distributed operations internally.


Performance Characteristics

Communication Cost Analysis

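For P bytes of gradients/parameters and N ranks, assuming bandwidth-optimal (ring-style) collectives:

| Operation | Per-rank buffers | Per-rank traffic |
|---|---|---|
| Reduce-scatter | In: P (local grads), Out: P/N | ≈ P(N−1)/N sent and received |
| Compute (Muon) | Local only | None (each rank does ~1/N of the Newton-Schulz work) |
| All-gather | In: P/N, Out: P | ≈ P(N−1)/N sent and received |
| Total per step | | ≈ 2P(N−1)/N sent and received |

Key takeaway: per-rank traffic approaches 2P as N grows, so it scales linearly with parameter count and is essentially independent of N for large N. This matches the volume of the single all_reduce that DDP performs, since an all_reduce is equivalent to a reduce-scatter followed by an all-gather.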

Scaling Behavior

Weak scaling (fixed per-GPU batch, so the global batch grows with the GPU count):

  • Near-linear: per-rank compute and activation memory stay constant
  • Communication stays roughly constant: per-rank traffic depends on model size, not on the number of GPUs

Strong scaling (fixed model size, increase GPUs):

  • ⚠️ Sub-linear: Communication overhead increases relative to computation
  • ⚠️ Sweet spot: 8-64 GPUs for typical Transformers
  • Poor at scale: Beyond 128 GPUs, communication dominates

Example: Training nanochat's 270M model

  • 8 GPUs: ~90% scaling efficiency
  • 64 GPUs: ~75% scaling efficiency
  • 512 GPUs: ~40% scaling efficiency (not recommended)

Conclusion

DistMuon demonstrates that domain-specific optimizations can significantly improve distributed training efficiency. By tailoring the ZeRO-2 pattern to Muon's unique needs (preserving matrix structure, grouping same-shape matrices, and implementing block-cyclic assignment), nanochat achieves:

  1. 58-66% memory savings vs standard DDP (8-64 GPUs)
  2. Seamless integration with existing codebases
  3. Efficient scaling for typical Transformer training workloads

Key Takeaways

When to use DistMuon:

  • Training Transformers with 2D weight matrices
  • Memory-constrained multi-GPU setups (8-64 GPUs)
  • Want ZeRO-2 benefits without DeepSpeed dependency

When to avoid:

  • Single GPU training (use regular Muon)
  • Models with mostly 1D parameters (use DistAdamW)
  • Extreme scale (>128 GPUs, consider ZeRO-3 or model parallelism)

Design Principles Worth Remembering

  1. Group by shape: Give collectives (and any batched math) uniform tensors to work with
  2. Block-cyclic assignment: Balance load while maintaining simplicity
  3. Async communication: Overlap network transfers with computation
  4. Preserve algorithmic invariants: Don't break Newton-Schulz by slicing matrices

What's Next in This Series

💾 Post 1.3: KV Caching Deep-Dive (Coming Soon)

Memory-efficient Transformer inference with prefill-and-clone patterns and dynamic cache growth.

🏗️ Post 1.4: Modern Architecture Choices (Coming Soon)

RoPE, QK normalization, Multi-Query Attention, and design trade-offs explained.


About this series: This is part of a comprehensive blog series exploring the technical innovations in nanochat, Andrej Karpathy's minimal ChatGPT implementation.