Reinforcement Learning from Human Feedback (RLHF)

Track 2: Practical Guides - Post 2.3 of 6
This post builds on Fine-tuning for Chat (SFT).
Prerequisites and Installation
Before starting RL training, ensure you have a properly configured nanochat environment and a trained SFT model.
System Requirements:
- CUDA: 11.8+ or 12.x (required for GPU training)
- Python: 3.10-3.11 (nanochat compatibility)
- RAM: 32GB+ (for data loading and rollout generation)
- GPU: 24GB+ VRAM recommended (e.g., RTX 3090, A100)
  - Single GPU: Works but slower (~32 hours for a full run)
  - 8× GPUs: Optimal for production training (~4 hours)
- Disk: 10GB+ for nanochat repository and checkpoints
Installation:
# Clone nanochat repository
git clone https://github.com/karpathy/nanochat.git
cd nanochat
# Install uv package manager (fast, reliable dependency management)
curl -LsSf https://astral.sh/uv/install.sh | sh
# Create virtual environment and install dependencies
uv venv
source .venv/bin/activate # On Windows: .venv\Scripts\activate
uv pip install -e .
# Verify installation
python -c "import torch; print(f'PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}')"
Required Checkpoint: RL training requires a trained SFT model. Check that one exists:
# Verify SFT checkpoint exists
ls -la chatsft_checkpoints/
# Should see directories like: d12/, d20/, or d24/ with model_step_*.pt files
# If missing, see: /blog/nanochat-deep-dive/fine-tuning-for-chat-sft
Common Installation Issues:
| Error | Cause | Solution |
|---|---|---|
| ImportError: No module named 'nanochat' | Not installed in editable mode | Run uv pip install -e . from the nanochat root |
| CUDA out of memory during rollouts | Insufficient VRAM for device_batch_size=8 | Reduce to --device_batch_size=4 or --device_batch_size=2 |
| FileNotFoundError: chatsft_checkpoints/ | Missing SFT checkpoint | Train an SFT model first (see prerequisites above) |
| RuntimeError: CUDA error | CUDA version mismatch | Install a PyTorch build whose CUDA version is supported by your NVIDIA driver |
| Slow rollout generation | Autoregressive sampling of num_samples completions per question dominates step time | Reduce max_new_tokens or num_samples (fewer samples give noisier advantages) |
Introduction
Supervised Fine-Tuning (SFT) teaches a model to follow instructions by imitating human-written responses. But imitation has limits—how do you train a model to be more helpful, more accurate, or better at reasoning when you can't easily demonstrate the perfect answer?
This is where Reinforcement Learning from Human Feedback (RLHF) shines. Instead of imitating demonstrations, the model learns by trying different responses and receiving feedback on which ones are better. This allows it to discover solutions that might be better than any single human demonstration.
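This post walks through nanochat's RL implementation, which applies a simplified form of GRPO (Group Relative Policy Optimization) to improve mathematical reasoning on GSM8K. Here the feedback is automatic: a program checks whether the final answer is correct, and no human labels or learned reward model are involved. The post covers: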
- The fundamental difference between SFT and RL for chat models
- How to design reward functions for subjective qualities
- The rollout generation process and advantage estimation
- Policy gradient optimization without trust regions or PPO
- Practical considerations for RL training stability
- Tool use integration in RL (calculator for math problems)
The RL Paradigm for Chat Models
From Imitation to Optimization
SFT optimizes:
max E_{(x,y)~D} [log p(y|x)]
Where D is a dataset of (question, answer) pairs.
RL optimizes:
max E_{x~D, y~π} [R(x, y)]
Where π is the policy (our model), and R(x,y) is a reward function scoring how good response y is for question x.
Key difference: SFT learns from fixed demonstrations. RL learns from its own generations and improves based on feedback.
Why RL After SFT?
Consider this GSM8K problem:
Question: Weng earns $12/hour babysitting. Yesterday she did 50 minutes. How much did she earn?
SFT might produce:
"She worked 50 minutes which is 50/60 hours, so 12 * 50/60 = $10"
(Correct reasoning, correct answer)
But RL can explore:
"First convert to hours: 50 minutes = 50/60 ≈ 0.833 hours
Then multiply: 12 * 0.833 ≈ $10"
(Alternative valid approach, though the rounding makes it less exact)
Or discover:
"12/60 = $0.2 per minute
50 * 0.2 = $10"
(More direct solution)
RL explores the solution space and learns which strategies work best, potentially finding better approaches than the training demonstrations.
nanochat's RL Implementation
Simplified GRPO → REINFORCE
From scripts/chat_rl.py:
"""
I put GRPO in quotes because we actually end up with something a lot
simpler and more similar to just REINFORCE:
1) Delete trust region, so there is no KL regularization to a reference model
2) We are on policy, so there's no need for PPO ratio+clip.
3) We use GAPO style normalization that is token-level, not sequence-level.
4) Instead of z-score normalization (r - mu)/sigma, only use (r - mu) as the advantage.
"""
Translation: nanochat uses vanilla policy gradients (REINFORCE) with a group-mean baseline and token-level loss normalization. It uses no PPO ratio or clipping and no reference model.
Core RL Hyperparameters
source = "sft" # Start from SFT model (not base)
device_batch_size = 8 # Max samples per forward pass
examples_per_step = 16 # Training examples per gradient step
num_samples = 16 # Samples per example (for advantage estimation)
max_new_tokens = 256 # Max response length
temperature = 1.0 # Sampling temperature
top_k = 50 # Top-k sampling
# Learning rates (same structure as SFT)
unembedding_lr = 0.004
embedding_lr = 0.2
matrix_lr = 0.02
weight_decay = 0.0
init_lr_frac = 0.05 # Start at 5% of base LR
num_epochs = 1 # One pass through GSM8K train set
save_every = 60 # Checkpoint frequency
eval_every = 60 # Evaluation frequency
eval_examples = 400 # Examples for pass@k evaluation
Key insight: We start from the SFT model, not the base model. SFT provides a good initialization: the model already knows how to format responses. RL then refines the quality of its reasoning.
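These settings determine how much work each optimizer step does. The sketch below spells out that arithmetic. It rests on three assumptions: 8 GPUs, the 7,473-problem GSM8K train split, and a run length of (train problems // examples_per_step) * num_epochs steps. The helper rl_step_budget is hypothetical. It is a reading of the hyperparameters, not code quoted from chat_rl.py.

```python
def rl_step_budget(num_train, examples_per_step, num_samples,
                   device_batch_size, world_size, num_epochs=1):
    """Derive per-step work from the RL hyperparameters (illustrative helper)."""
    assert examples_per_step % world_size == 0, "examples must split evenly across ranks"
    assert num_samples % device_batch_size == 0, "samples must split into full batches"
    return {
        # Optimizer steps for the whole run
        "num_steps": (num_train // examples_per_step) * num_epochs,
        # Questions each GPU handles per optimizer step
        "examples_per_rank": examples_per_step // world_size,
        # Sampling steps (and forward/backward passes) per question
        "passes_per_example": num_samples // device_batch_size,
        # Completions generated across all GPUs per optimizer step
        "sequences_per_step": examples_per_step * num_samples,
    }

budget = rl_step_budget(num_train=7473, examples_per_step=16, num_samples=16,
                        device_batch_size=8, world_size=8)
print(budget)
```

With world_size=1, a single GPU handles all 16 questions per step instead of 2. That is why a single-GPU run takes roughly 8× as long.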
The Training Dataset: GSM8K
GSM8K (Grade School Math 8K) is well suited to RL because:
- Clear success criterion: Either the final answer is correct or not (binary reward)
- Tool use: Reference solutions annotate arithmetic steps as calculator calls, so the model can learn when to invoke a tool
- Diverse strategies: Multiple valid solution paths exist
Example from tasks/gsm8k.py:
Question:
Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?
Answer:
Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute.
Working 50 minutes, she earned 0.2 x 50 = $<<0.2*50=10>>10.
#### 10
Notice the <<...>> tags—these are calculator tool calls. The model learns to invoke tools and use their outputs.
Conversation Format for GSM8K
The dataset is converted to conversations:
messages = [
    {"role": "user", "content": "Weng earns $12 an hour..."},
    {"role": "assistant", "content": [
        {"type": "text", "text": "Weng earns 12/60 = $"},
        {"type": "python", "text": "12/60"},
        {"type": "python_output", "text": "0.2"},
        {"type": "text", "text": "0.2 per minute.\nWorking 50 minutes, she earned 0.2 x 50 = $"},
        {"type": "python", "text": "0.2*50"},
        {"type": "python_output", "text": "10"},
        {"type": "text", "text": "10.\n#### 10"},
    ]}
]
The tokenizer renders this with special tokens:
<|user_start|>Weng earns $12 an hour...<|user_end|>
<|assistant_start|>Weng earns 12/60 = $<|python_start|>12/60<|python_end|><|output_start|>0.2<|output_end|>0.2 per minute...
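The conversion from <<expr=result>> annotations to the parts list above fits in a few lines. This minimal sketch assumes each annotation holds one expression and one result, separated by the last "=". The function name annotations_to_parts is hypothetical, and the actual code in tasks/gsm8k.py may differ.

```python
import re

def annotations_to_parts(answer):
    """Split a GSM8K answer with <<expr=result>> annotations into typed parts."""
    parts = []
    # The capturing group keeps the <<...>> annotations in the split output
    for piece in re.split(r"(<<[^>]+>>)", answer):
        if piece.startswith("<<") and piece.endswith(">>"):
            expr, result = piece[2:-2].rsplit("=", 1)
            parts.append({"type": "python", "text": expr})          # model writes this
            parts.append({"type": "python_output", "text": result})  # tool returns this
        elif piece:  # skip empty strings produced at the boundaries
            parts.append({"type": "text", "text": piece})
    return parts

answer = ("Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute.\n"
          "Working 50 minutes, she earned 0.2 x 50 = $<<0.2*50=10>>10.\n"
          "#### 10")
for part in annotations_to_parts(answer):
    print(part)
```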
Rollout Generation
The Rollout Loop
The core of RL training is generating rollouts—sampling completions from the current policy:
@torch.no_grad()
def get_batch():
    assistant_end = tokenizer.encode_special("<|assistant_end|>")
    rank_indices = range(ddp_rank, len(train_task), ddp_world_size)
    for example_idx in itertools.cycle(rank_indices):
        # Get the conversation
        conversation = train_task[example_idx]
        # Tokenize: remove assistant message, prime for completion
        tokens = tokenizer.render_for_completion(conversation)
        prefix_length = len(tokens)
        # Generate num_samples completions (num_samples must be divisible by device_batch_size)
        model.eval()
        generated_token_sequences = []
        masks = []
        num_sampling_steps = num_samples // device_batch_size
        for sampling_step in range(num_sampling_steps):
            seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF
            try:
                with autocast_ctx:
                    sequences_batch, masks_batch = engine.generate_batch(
                        tokens,
                        num_samples=device_batch_size,
                        max_tokens=max_new_tokens,
                        temperature=temperature,
                        top_k=top_k,
                        seed=seed,
                    )
            except torch.cuda.OutOfMemoryError:
                # Don't shrink device_batch_size here: assigning to it inside this
                # function makes it a local variable (UnboundLocalError on first use),
                # and skipping the step would leave fewer than num_samples rollouts.
                # Fail loudly and rerun with a smaller --device_batch_size instead.
                logging.error(f"OOM during rollout at step {step}, example {example_idx}.")
                torch.cuda.empty_cache()
                raise
            generated_token_sequences.extend(sequences_batch)
            masks.extend(masks_batch)
Key points:
- Multiple samples per question: We generate 16 completions (num_samples) for each question so that their mean reward can serve as the baseline for advantage estimation
- Batched generation: device_batch_size=8 avoids OOM, so we run 2 sampling steps (8 × 2 = 16)
- Deterministic seeds: Reproducibility via hash((step, example_idx, sampling_step))
- Masks track generation: mask=1 for sampled tokens, mask=0 for forced tokens (prompts, tool outputs)
Reward Calculation
After generating completions, compute rewards:
rewards = []
for sample_tokens in generated_token_sequences:
    # Extract generated response (after prompt)
    generated_tokens = sample_tokens[prefix_length:]
    generated_text = tokenizer.decode(generated_tokens)
    # Calculate reward
    reward = train_task.reward(conversation, generated_text)
    rewards.append(reward)
The reward() function delegates to the task's evaluate() method, which extracts and compares the final answer:
def reward(self, conversation, assistant_response):
    """Binary reward: 1.0 if answer is correct, 0.0 otherwise"""
    is_correct = self.evaluate(conversation, assistant_response)
    return float(is_correct)
Extraction logic:
GSM_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
def extract_answer(completion):
    """Extract numerical answer after #### marker"""
    match = GSM_RE.search(completion)
    if match:
        match_str = match.group(1).strip().replace(",", "")
        return match_str
    return None
Example:
Response: "She earned 0.2 * 50 = $10. #### 10"
Extracted: "10"
Ground truth: "10"
Reward: 1.0 ✓
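To try the extraction end to end, here is a self-contained sketch. It applies the same regex to the reference solution and to the model's response, then compares the extracted strings. The helper binary_reward is hypothetical, and the real evaluate() may normalize answers differently.

```python
import re

# Same pattern as extract_answer above: the number after the "#### " marker
GSM_RE = re.compile(r"#### (\-?[0-9\.\,]+)")

def extract_answer(completion):
    match = GSM_RE.search(completion)
    if match:
        return match.group(1).strip().replace(",", "")
    return None

def binary_reward(reference_solution, assistant_response):
    """1.0 if both extracted answers exist and match as strings, else 0.0."""
    ref = extract_answer(reference_solution)
    pred = extract_answer(assistant_response)
    return float(pred is not None and pred == ref)

reference = "Working 50 minutes, she earned 0.2 x 50 = $10.\n#### 10"
print(binary_reward(reference, "She earned 0.2 * 50 = $10. #### 10"))
print(binary_reward(reference, "She earned $10."))     # no #### marker
print(binary_reward(reference, "She earned $10. #### 10."))
```

The character class [0-9\.\,] has one notable effect. A response ending in "#### 10." extracts "10.", which does not string-match "10". Such a response scores 0 even though the number is right.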
Advantage Estimation
Simple mean-baseline advantage:
rewards = torch.tensor(rewards, dtype=torch.float, device=device) # (B,)
mu = rewards.mean()
advantages = rewards - mu # Simple baseline, no z-score normalization
Why mean baseline?
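- Reduces variance in gradient estimates
- Centers advantages around zero (positive for above-average samples, negative for below-average ones)
- Simpler than z-score normalization (r - mu)/sigma, which divides by zero whenever every sample in a group gets the same reward (sigma = 0), a common case with binary rewards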
Example with 16 samples:
Rewards: [0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1]
Mean: 0.5625
Advantages: [-0.56, 0.44, -0.56, -0.56, 0.44, 0.44, -0.56, -0.56, 0.44, -0.56, 0.44, 0.44, -0.56, 0.44, 0.44, 0.44]
Samples with correct answers get positive advantages (upweighted), incorrect get negative (downweighted).
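The computation above can be reproduced in plain Python for a group where 9 of 16 samples are correct. The helper group_advantages is hypothetical; nanochat does the same subtraction on a tensor.

```python
def group_advantages(rewards):
    """Mean-baseline advantages for one group of samples drawn for the same question."""
    mu = sum(rewards) / len(rewards)
    return [r - mu for r in rewards]

rewards = [0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1]  # 9 correct out of 16
advantages = group_advantages(rewards)
print(sum(rewards) / len(rewards))  # 0.5625
print(advantages[:2])               # [-0.5625, 0.4375]

# A group where every sample is correct (or every sample is wrong) gets all-zero
# advantages, so that question contributes no gradient this step.
print(group_advantages([1] * 16) == [0.0] * 16)  # True
```

The last case is also the one where z-score normalization would divide by a zero standard deviation. The plain mean baseline simply yields zeros.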
Collation and Padding
Pad sequences to uniform length for batching:
max_length = max(len(seq) for seq in generated_token_sequences)
padded_sequences = [seq + [assistant_end] * (max_length - len(seq))
for seq in generated_token_sequences]
padded_masks = [mask + [0] * (max_length - len(mask))
for mask in masks]
ids = torch.tensor(padded_sequences, dtype=torch.long, device=device)
mask_ids = torch.tensor(padded_masks, dtype=torch.long, device=device)
# Autoregressive setup
inputs = ids[:, :-1]
targets = ids[:, 1:].clone()
targets[mask_ids[:, 1:] == 0] = -1 # Mask prompt, tool outputs, and padding
Masking strategy:
- Prompt tokens: mask=0 → target=-1 (not trained on)
- Tool outputs: mask=0 → target=-1 (not trained on)
- Padding: mask=0 → target=-1 (not trained on)
- Sampled tokens: mask=1 → target=token_id (trained on)
This ensures we only optimize the model's own generated text, not forced context.
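A toy version with plain Python lists makes the shift-and-mask step concrete. Integer IDs stand in for tokens, pad_token plays the role of assistant_end, and the function name collate is hypothetical.

```python
def collate(sequences, masks, pad_token):
    """Pad rollouts to equal length and build shifted inputs/targets.

    A target is -1 wherever the token being predicted was not sampled by the
    model (prompt, tool output, padding), so the loss ignores it.
    """
    max_len = max(len(s) for s in sequences)
    padded = [s + [pad_token] * (max_len - len(s)) for s in sequences]
    padded_masks = [m + [0] * (max_len - len(m)) for m in masks]
    inputs = [row[:-1] for row in padded]
    targets = [
        [tok if keep else -1 for tok, keep in zip(row[1:], mrow[1:])]
        for row, mrow in zip(padded, padded_masks)
    ]
    return inputs, targets

# Two rollouts sharing the prompt [10, 11]; the second one is shorter.
inputs, targets = collate([[10, 11, 20, 21], [10, 11, 30]],
                          [[0, 0, 1, 1], [0, 0, 1]], pad_token=99)
print(inputs)   # [[10, 11, 20], [10, 11, 30]]
print(targets)  # [[-1, 20, 21], [-1, 30, -1]]
```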
Policy Gradient Optimization
The PG Objective
for example_step in range(examples_per_rank):
    # Get the rollouts for one training example (num_samples sequences)
    sequences_all, inputs_all, targets_all, rewards_all, advantages_all = next(batch_iterator)
    model.train()
    # Process in device_batch_size chunks
    num_passes = inputs_all.size(0) // device_batch_size
    for pass_idx in range(num_passes):
        b0, b1 = pass_idx * device_batch_size, (pass_idx + 1) * device_batch_size
        inputs = inputs_all[b0:b1]
        targets = targets_all[b0:b1]
        advantages = advantages_all[b0:b1]
        try:
            # Per-token NLL (0 at ignored targets); negate to get log-probabilities
            with autocast_ctx:
                logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
            # NaN indicates instability: skip this pass. Don't zero the gradients
            # here, since that would discard gradients already accumulated from
            # earlier passes and examples in this step.
            if torch.isnan(logp).any():
                logging.warning(f"NaN detected in policy gradient at step {step}, pass {pass_idx}. Skipping.")
                continue
            # Policy gradient objective
            pg_obj = (logp * advantages.unsqueeze(-1)).sum()
            # Normalize by valid tokens and batch structure
            num_valid = (targets >= 0).sum().clamp(min=1)
            pg_obj = pg_obj / (num_valid * num_passes * examples_per_rank)
            # Minimize negative objective (maximize objective)
            loss = -pg_obj
            loss.backward()
        except torch.cuda.OutOfMemoryError:
            # Gradients may already be partially accumulated, so fail loudly and
            # rerun with a smaller --device_batch_size rather than continuing.
            logging.error(f"OOM during forward/backward at step {step}. Clearing cache.")
            torch.cuda.empty_cache()
            raise
Mathematical breakdown:
- Log probability: model(inputs, targets) returns the per-token negative log-likelihood (NLL), so we negate it to get log-probabilities
- Weighted by advantage: logp * advantages upweights samples with above-average reward and downweights the rest
- Summed over tokens: every sampled token in a response shares that response's advantage
- Normalized: dividing by the number of valid tokens (and by num_passes * examples_per_rank) keeps the gradient scale independent of response length and of how the batch is split
- Gradient ascent: maximizing the objective is done by minimizing -pg_obj
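In math, this is the estimator the loop above implements, written for one question x with G sampled responses y_1, …, y_G and rewards r_i. The symbols G, m, N_p, P and E are notation introduced here, not names from the source. Strictly speaking, the group mean includes each sample's own reward. That scales the expected gradient by (G-1)/G without changing its direction.

```latex
% REINFORCE with a baseline b that does not depend on the sampled response y
\nabla_\theta \, \mathbb{E}_{y \sim \pi_\theta}\!\left[R(x, y)\right]
  = \mathbb{E}_{y \sim \pi_\theta}\!\left[\big(R(x, y) - b\big)\, \nabla_\theta \log \pi_\theta(y \mid x)\right]

% Group-mean advantage (no division by the standard deviation)
A_i = r_i - \frac{1}{G} \sum_{j=1}^{G} r_j

% Loss for pass p: N_p = valid target tokens in the pass, P = num_passes,
% E = examples_per_rank, m_{i,t} = 1 for sampled tokens and 0 otherwise.
% Gradients from all passes accumulate before the optimizer step.
\mathcal{L}_p = -\frac{1}{N_p \, P \, E} \sum_{i \in p} \sum_{t} m_{i,t} \, A_i \,
  \log \pi_\theta\!\left(y_{i,t} \mid x, y_{i,<t}\right)
```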
Why No PPO?
Traditional PPO uses:
ratio = torch.exp(logp - old_logp)
clipped_ratio = torch.clamp(ratio, 1-epsilon, 1+epsilon)
loss = -torch.min(ratio * advantages, clipped_ratio * advantages).mean()
nanochat skips this because:
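- On-policy: every rollout is generated by the current weights and used for a single gradient step. When the gradient is computed, ratio = 1 exactly and the clip is inactive, so the PPO gradient reduces to the plain policy gradient
- Simplicity: Clipping adds a hyperparameter (epsilon) and complexity
- Stability: The mean baseline and small learning rates provide sufficient stability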
When would you need PPO? Off-policy learning (reusing old samples) or very large policy updates.
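The on-policy argument can be checked numerically. This pure-Python sketch uses central finite differences in place of autograd and an assumed typical epsilon of 0.2. It compares the gradient of the vanilla objective A·log π with that of PPO's clipped surrogate at the point where the current policy equals the sampling policy. All function names are hypothetical.

```python
import math

def pg_objective(logp, advantage):
    """Vanilla policy-gradient objective for one token: A * log pi."""
    return advantage * logp

def ppo_clip_objective(logp, old_logp, advantage, epsilon=0.2):
    """PPO clipped surrogate for one token."""
    ratio = math.exp(logp - old_logp)
    clipped = min(max(ratio, 1 - epsilon), 1 + epsilon)
    return min(ratio * advantage, clipped * advantage)

def numerical_grad(f, x, h=1e-5):
    """Central finite difference of f at x."""
    return (f(x + h) - f(x - h)) / (2 * h)

old_logp = -2.3  # log-prob of a sampled token under the policy that generated it
for adv in (0.4375, -0.5625):
    g_pg = numerical_grad(lambda lp: pg_objective(lp, adv), old_logp)
    g_ppo = numerical_grad(lambda lp: ppo_clip_objective(lp, old_logp, adv), old_logp)
    print(adv, round(g_pg, 6), round(g_ppo, 6))

# Once the policy has moved far enough (ratio > 1 + epsilon) with a positive
# advantage, the clip engages and the PPO gradient vanishes:
print(numerical_grad(lambda lp: ppo_clip_objective(lp, old_logp, 0.4375), old_logp + 0.5))
```

Clipping only matters once samples are reused after the policy has changed. That is the off-policy case named above.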
No Reference Model KL
Some RLHF methods add KL divergence to a reference model:
loss = -pg_obj + beta * KL(π, π_ref)
nanochat skips this:
- Trust in SFT initialization: Starting from a good SFT model reduces the need for regularization
- Conservative learning rates: init_lr_frac=0.05 means small updates
- Short training: A single epoch limits how far the policy can drift
When would you need KL regularization? Long training runs or when the reward function is exploitable (e.g., a learned reward model that can be gamed).
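nanochat omits this term. For illustration only, here is a sketch of how such a penalty is commonly added at the token level. It assumes per-token log-probs are available from both the policy and a frozen reference model. beta=0.05 is an arbitrary placeholder, the simple log-ratio estimator is one of several used in practice, and the helper name is hypothetical.

```python
def kl_penalized_objective(logp, ref_logp, advantages, mask, beta=0.05):
    """Token-level PG objective minus a KL penalty toward a reference model.

    logp / ref_logp / mask: one list of per-token values per sequence.
    advantages: one value per sequence. Returns an objective to maximize.
    """
    pg_sum, kl_sum, num_valid = 0.0, 0.0, 0
    for lp_seq, ref_seq, adv, m_seq in zip(logp, ref_logp, advantages, mask):
        for lp, ref, m in zip(lp_seq, ref_seq, m_seq):
            if m:  # only tokens the model sampled count
                pg_sum += adv * lp
                # Single-sample estimate of KL(pi || pi_ref) at a token sampled from pi
                kl_sum += lp - ref
                num_valid += 1
    num_valid = max(num_valid, 1)
    return pg_sum / num_valid - beta * kl_sum / num_valid

# One sequence, two sampled tokens; the reference model finds them less likely.
objective = kl_penalized_objective([[-1.0, -2.0]], [[-1.5, -2.5]], [0.5], [[1, 1]])
loss = -objective  # minimize the negative, as with pg_obj
print(loss)
```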
Tool Use in RL: The Calculator
Tool State Machine in the Engine
From nanochat/engine.py:
if next_token == python_start:
    state.in_python_block = True
    state.python_expr_tokens = []
elif next_token == python_end and state.in_python_block:
    state.in_python_block = False
    if state.python_expr_tokens:
        expr = self.tokenizer.decode(state.python_expr_tokens)
        result = use_calculator(expr)
        if result is not None:
            result_tokens = self.tokenizer.encode(str(result))
            state.forced_tokens.append(output_start)
            state.forced_tokens.extend(result_tokens)
            state.forced_tokens.append(output_end)
    state.python_expr_tokens = []
elif state.in_python_block:
    state.python_expr_tokens.append(next_token)
State machine:
- Model generates <|python_start|>: the engine enters "tool mode" and collects the following tokens
- Model generates <|python_end|>: the engine evaluates the expression with use_calculator()
- Engine forces <|output_start|>, the result tokens, and <|output_end|> into the sequence
- Model continues generating, incorporating the tool output
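Replaying a fixed stream of model tokens through the same state machine shows where the forced tokens land and how they are masked. In this sketch, strings stand in for token IDs and toy_calculator is a simplified stand-in for use_calculator with no timeout. Both function names are hypothetical. The real engine feeds the forced tokens back to the model one at a time; this replay only reproduces their order and masks.

```python
PY_START, PY_END = "<|python_start|>", "<|python_end|>"
OUT_START, OUT_END = "<|output_start|>", "<|output_end|>"

def toy_calculator(expr):
    """Stand-in for use_calculator: whitelisted arithmetic only, no power operator."""
    expr = expr.replace(",", "")
    if any(ch not in "0123456789*+-/.() " for ch in expr) or "**" in expr:
        return None
    try:
        return eval(expr, {"__builtins__": {}}, {})  # whitelist blocks names and calls
    except (SyntaxError, ZeroDivisionError, TypeError):
        return None

def run_tool_state_machine(model_tokens, calculator=toy_calculator):
    """Return the full sequence plus a mask: 1 = produced by the model, 0 = forced."""
    sequence, mask = [], []
    in_python, expr_tokens = False, []
    for tok in model_tokens:
        sequence.append(tok)
        mask.append(1)
        if tok == PY_START:
            in_python, expr_tokens = True, []
        elif tok == PY_END and in_python:
            in_python = False
            result = calculator("".join(expr_tokens)) if expr_tokens else None
            if result is not None:
                for forced in (OUT_START, str(result), OUT_END):
                    sequence.append(forced)
                    mask.append(0)
            expr_tokens = []
        elif in_python:
            expr_tokens.append(tok)
    return sequence, mask

tokens = ["Weng earns 12/60 = $", PY_START, "12", "/", "60", PY_END, "0.2 per minute."]
seq, mask = run_tool_state_machine(tokens)
print(seq)
print(mask)  # [1, 1, 1, 1, 1, 1, 0, 0, 0, 1]
```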
Safe Calculator Evaluation
def use_calculator(expr):
    """Evaluate a math expression safely"""
    expr = expr.replace(",", "")
    # Only allow numeric chars and basic operators
    if any([x not in "0123456789*+-/.() " for x in expr]):
        return None
    # Disallow power operator (expensive)
    if "**" in expr:
        return None
    return eval_with_timeout(expr, max_time=3)
Safety measures:
- Whitelist characters (no variables, no function calls)
- No power operator (prevents denial-of-service expressions such as 9**9**9**9)
- 3-second timeout (bounds the runtime of any expression that passes the checks)
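eval_with_timeout is not shown here. One alternative avoids eval entirely: walk Python's abstract syntax tree and permit only numeric literals and + - * /. This is a swapped-in technique, not what nanochat does, and the name safe_arithmetic is hypothetical. Since ** is never accepted, an expression cannot blow up the way 9**9**9**9 does, which is why this sketch has no timeout.

```python
import ast
import operator

# Only these operators are allowed: no power, comparisons, or bitwise ops
_BIN_OPS = {ast.Add: operator.add, ast.Sub: operator.sub,
            ast.Mult: operator.mul, ast.Div: operator.truediv}
_UNARY_OPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}

def safe_arithmetic(expr):
    """Evaluate + - * / arithmetic on numeric literals; return None for anything else."""
    expr = expr.replace(",", "")  # allow thousands separators, as use_calculator does
    try:
        tree = ast.parse(expr, mode="eval")
    except (SyntaxError, ValueError, RecursionError):
        return None

    def walk(node):
        if isinstance(node, ast.Expression):
            return walk(node.body)
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _BIN_OPS:
            return _BIN_OPS[type(node.op)](walk(node.left), walk(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPS:
            return _UNARY_OPS[type(node.op)](walk(node.operand))
        # Names, calls, attributes, subscripts, power, etc. are rejected
        raise ValueError(f"disallowed syntax: {type(node).__name__}")

    try:
        return walk(tree)
    except (ValueError, ZeroDivisionError, RecursionError):
        return None

print(safe_arithmetic("12/60"), safe_arithmetic("9**9"), safe_arithmetic("__import__('os')"))
```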
Why Tool Use Matters for RL
The calculator provides groundedness:
- Correct calculations: 12/60 = 0.2 is always computed correctly
- Reduced hallucination: The model doesn't need to memorize arithmetic
- Credit assignment: When a rollout that called the tool ends with the correct answer, its tool-call tokens share the positive advantage, so invoking the tool is reinforced
During training:
- Tool invocation tokens: sampled by the model, so they receive gradient (the model learns when to call tools)
- Tool output tokens: forced by the environment, so they are masked out of the loss
This creates a natural division: the model controls when to use tools, the environment provides what the tools return.
Evaluation: Pass@k Metric
What is Pass@k?
In addition to pass@1 (the accuracy of a single sample), we measure:
Pass@k: Probability that at least one of k samples is correct
This is more forgiving and mirrors real usage, where users often retry a prompt several times.
Implementation
if step % eval_every == 0:
    model.eval()
    passk = torch.zeros(device_batch_size, device=device)
    with autocast_ctx:
        records_iter = run_gsm8k_eval(val_task, tokenizer, engine,
                                      num_samples=device_batch_size,
                                      max_examples=eval_examples,
                                      temperature=1.0)
        records = list(records_iter)  # generation happens lazily, inside autocast
    # Count problems solved by at least one of the first k samples, k=1..device_batch_size
    for k in range(1, device_batch_size + 1):
        passk[k-1] = sum(any(o["is_correct"] for o in r["outcomes"][:k])
                         for r in records)
    # Aggregate counts across ranks, then divide by the total number of problems
    num_records = torch.tensor(len(records), dtype=torch.long, device=device)
    if ddp:
        dist.all_reduce(num_records, op=dist.ReduceOp.SUM)
        dist.all_reduce(passk, op=dist.ReduceOp.SUM)
    passk = passk / num_records.item()
Interpretation:
Step 0 | Pass@1: 0.2500, Pass@2: 0.3750, Pass@4: 0.5000, Pass@8: 0.6250
- 25% of problems are solved on first try
- 62.5% are solved if you try 8 times
- Diversity in sampling helps
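The evaluation loop checks whether any of the first k samples is correct. A common alternative comes from the Codex paper (Chen et al., 2021). It estimates pass@k from all n samples, c of them correct, as 1 - C(n-c, k)/C(n, k). Averaging over every size-k subset gives it lower variance for the same samples. The sketch shows both estimators for one problem; neither helper name comes from nanochat.

```python
from math import comb

def pass_at_k_first_k(outcomes, k):
    """What the evaluation loop computes for one problem: is any of the first k correct?"""
    return float(any(outcomes[:k]))

def pass_at_k_unbiased(n, c, k):
    """Estimate pass@k from n samples with c correct: 1 - C(n-c, k) / C(n, k)."""
    if n - c < k:
        return 1.0  # every size-k subset contains a correct sample
    return 1.0 - comb(n - c, k) / comb(n, k)

outcomes = [False, False, True, False, False, False, True, False]  # 8 samples, 2 correct
for k in (1, 2, 4, 8):
    print(k, pass_at_k_first_k(outcomes, k),
          round(pass_at_k_unbiased(len(outcomes), sum(outcomes), k), 4))
```

For this single problem, the first-k version jumps from 0 to 1 depending on where the correct samples happen to land. The unbiased estimate rises smoothly. Averaged over many problems, both estimate the same quantity.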
Why Temperature=1.0 for Evaluation?
- During training: temperature=1.0 (explore diverse solutions)
- During evaluation: temperature=1.0 (measure the quality of diverse solutions)
With temperature=0.0 (greedy decoding), every sample would follow the single most likely path. All k samples would be identical, and pass@k would collapse to pass@1. Sampling at temperature 1.0 lets pass@k reveal whether the model can reach a correct answer through different paths.
Learning Rate Schedule
Linear Decay
def get_lr_multiplier(it):
    lrm = 1.0 - it / num_steps
    return lrm
# Apply each step
lrm = get_lr_multiplier(step)
for opt in optimizers:
    for group in opt.param_groups:
        group["lr"] = group["initial_lr"] * lrm
Schedule visualization (1000 steps):
Step 0: lrm=1.000 → lr=100%
Step 250: lrm=0.750 → lr=75%
Step 500: lrm=0.500 → lr=50%
Step 750: lrm=0.250 → lr=25%
Step 1000: lrm=0.000 → lr=0%
Why linear for RL?
- RL training is inherently noisy (rewards are sparse/binary)
- Gradual reduction prevents large updates late in training
- Decays to zero by the end of training, so the final updates are very small
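Here are the schedule's actual values. This sketch assumes, as the init_lr_frac comment in the hyperparameters suggests, that initial_lr is the base learning rate multiplied by init_lr_frac. Note that in a loop over range(num_steps) the last executed step is num_steps - 1. The multiplier therefore bottoms out at 1/num_steps rather than exactly 0. The helper current_lr is hypothetical.

```python
def get_lr_multiplier(it, num_steps):
    """Linear decay: 1.0 at it=0, falling to 0.0 at it=num_steps."""
    return 1.0 - it / num_steps

def current_lr(base_lr, init_lr_frac, it, num_steps):
    # initial_lr = base LR scaled down by init_lr_frac, then decayed linearly
    return base_lr * init_lr_frac * get_lr_multiplier(it, num_steps)

num_steps = 1000
for it in (0, 250, 500, 750, num_steps - 1):
    # matrix_lr = 0.02 and init_lr_frac = 0.05, as in the hyperparameters above
    print(it, round(get_lr_multiplier(it, num_steps), 3), current_lr(0.02, 0.05, it, num_steps))
```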
Practical Training Guide
Step 1: Ensure You Have an SFT Model
# Check for SFT checkpoint
ls -la chatsft_checkpoints/
# Should see: d12/ or d24/ with model_step_*.pt
If you don't have one, run SFT first (see Fine-tuning for Chat).
Step 2: Configure RL Training
Create rl_config.txt:
run = "gsm8k_rl_run1"
source = "sft"
num_epochs = 1
examples_per_step = 16
device_batch_size = 8
num_samples = 16
temperature = 1.0
init_lr_frac = 0.05
eval_every = 30
save_every = 30
Step 3: Launch Training
Single GPU:
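python -m scripts.chat_rl rl_config.txt
Multi-GPU (8 GPUs):
torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- rl_config.txt
nanochat's configurator appears to follow nanoGPT's convention: a bare argument is executed as a Python config file, and --key=value arguments (e.g. --device_batch_size=4) override individual settings. Verify this against nanochat/configurator.py in your checkout.
Step 4: Monitor Training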
Watch for:
Step 0 | Pass@1: 0.2500, Pass@2: 0.3750, Pass@4: 0.5000, Pass@8: 0.6250
Step 0/500 | Example step 0 | Pass 0 | loss: 0.532145 | Average reward: 0.5625
Step 0/500 | Average reward: 0.5625 | Average sequence length: 147.23
Step 60 | Pass@1: 0.3125, Pass@2: 0.4375, Pass@4: 0.5625, Pass@8: 0.6875
Step 60/500 | Example step 0 | Pass 0 | loss: 0.421876 | Average reward: 0.6250
Step 60/500 | Average reward: 0.6250 | Average sequence length: 138.45
Step 120 | Pass@1: 0.3750, Pass@2: 0.5000, Pass@4: 0.6250, Pass@8: 0.7500
Good signs:
- Pass@k metrics increasing over time
- Average reward improving
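- Loss noisy but bounded (the policy-gradient loss depends on each batch's advantages, so it is not a reliable progress signal on its own; rely on reward and pass@k)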
- Sequence length stabilizing (model not degenerating to very short/long outputs)
Bad signs:
- Pass@k plateauing or decreasing (model not learning)
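- Average reward stuck at 0.0 or 1.0 (reward function broken, or the task is too hard or too easy; either way every advantage is zero and nothing is learned)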
- Sequence length → 0 or → max_tokens (model collapsing)
Advanced Topics
Custom Reward Functions
Binary rewards are simple but coarse. You can design richer rewards:
def reward(self, conversation, assistant_response):
    # Base reward: correctness
    is_correct = self.evaluate(conversation, assistant_response)
    reward = float(is_correct)
    # Bonus for shorter solutions, only when correct (otherwise short wrong answers are rewarded)
    if is_correct:
        length = len(assistant_response)
        reward += 0.1 * max(0, 1 - length / 500)
    # Bonus for using the calculator (decoded tool calls appear as <|python_start|>, not <<)
    if "<|python_start|>" in assistant_response:
        reward += 0.05
    # Penalty for format violations
    if "####" not in assistant_response:
        reward -= 0.2
    return reward
Design principles:
- Dominant term: Correctness should be the main driver
- Small bonuses: Auxiliary rewards should be 10-20% of main reward
- Avoid exploitation: Don't reward superficial patterns (e.g., just typing "####" without solving)
Dense Rewards
Instead of binary outcome, reward intermediate progress:
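import re

TOOL_OUTPUT_RE = re.compile(r"<\|output_start\|>(.*?)<\|output_end\|>")

def reward(self, conversation, assistant_response):
    # Full credit for a correct final answer (check this first)
    if self.evaluate(conversation, assistant_response):
        return 1.0
    # Partial credit: intermediate values from the reference solution that also appear
    # among the model's calculator outputs. Checking each tool output against its own
    # expression would always pass, since the engine computed it. Assumes the reference
    # assistant message comes last, in the parts format shown earlier. String matching
    # is crude ("10" vs "10.0").
    reference_parts = conversation["messages"][-1]["content"]
    reference_values = {p["text"] for p in reference_parts if p["type"] == "python_output"}
    model_values = set(TOOL_OUTPUT_RE.findall(assistant_response))
    correct_steps = len(reference_values & model_values)
    # Keep partial credit well below the reward for a correct answer
    return min(0.5, 0.1 * correct_steps)
This provides feedback even when the final answer is wrong, helping the model learn intermediate steps.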
Multi-Objective Optimization
Combine multiple reward signals:
# Placeholders: correct, response, and count_explanation_sentences come from your task code
rewards_dict = {
    "correctness": 1.0 if correct else 0.0,
    "efficiency": -len(response) / 1000, # Shorter is better
    "clarity": count_explanation_sentences(response) / 10,
}
# Weighted sum
reward = (
    1.0 * rewards_dict["correctness"] +
    0.1 * rewards_dict["efficiency"] +
    0.05 * rewards_dict["clarity"]
)
Track individual components in logging to diagnose what the model is optimizing for.
Reward Shaping Pitfalls
Pitfall 1: Overfitting to Proxies
# BAD: Reward word count
reward += 0.01 * word_count # Model learns to be verbose
Pitfall 2: Conflicting Signals
# BAD: Contradictory rewards
reward += 0.1 if len(response) < 100 else 0 # Reward brevity
reward += 0.1 if "detailed explanation" in response else 0 # Reward detail
Pitfall 3: Reward Hacking
# BAD: Exploitable pattern
reward += 0.5 if "Therefore, the answer is" in response else 0
# Model learns to always write this phrase regardless of correctness
Solution: Always tie rewards to outcome metrics (accuracy, user satisfaction, etc.) and validate on held-out sets.
Comparison: SFT vs RL
| Aspect | SFT | RL |
|---|---|---|
| Objective | Imitate demonstrations | Maximize reward |
| Data | Fixed (x, y) pairs | Dynamic (generate y, score it) |
| Training Signal | Cross-entropy loss | Policy gradient |
| Exploration | None (teacher forcing) | Sampling-based |
| Strengths | Stable, fast, learns format | Discovers novel solutions |
| Weaknesses | Limited by demos | Noisy, slower, can diverge |
| Use Case | Teach format & basics | Optimize for quality |
Best Practice: SFT first (provides good initialization), then RL (optimizes quality).
Debugging Tips
Reward Distribution Analysis
# During rollout generation
print(f"Reward distribution: {torch.bincount(rewards.long())}")
# Output: tensor([10, 6]) → 10 wrong, 6 correct
# Calculate success rate
success_rate = rewards.sum() / len(rewards)
print(f"Success rate: {success_rate:.2%}")
What to look for:
- Early training: ~25-50% success (better than random, not perfect)
- Late training: ~60-80% success (strong performance)
- All zeros: Reward function broken or task too hard
- All ones: Reward function broken or task too easy
Gradient Norms
# After loss.backward()
total_norm = 0.0
for p in model.parameters():
    if p.grad is not None:
        param_norm = p.grad.detach().norm(2)
        total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
print(f"Gradient norm: {total_norm:.4f}")
Typical values:
- Healthy: 0.1 - 10.0
- Too small (<0.01): Learning rate too low or gradients vanishing
- Too large (>100): Learning rate too high or exploding gradients
Advantage Distribution
advantages = rewards - rewards.mean()
print(f"Advantages: min={advantages.min():.3f}, max={advantages.max():.3f}, std={advantages.std():.3f}")
Healthy distribution:
Advantages: min=-0.562, max=0.438, std=0.512
Centered around zero, reasonable spread.
Degenerate:
Advantages: min=0.000, max=0.000, std=0.000
All rewards in the group are identical (all correct or all wrong), so every advantage is zero and this example contributes no gradient.
Sequence Length Tracking
sequence_lengths = [len(seq) for seq in generated_token_sequences]
print(f"Seq lengths: min={min(sequence_lengths)}, max={max(sequence_lengths)}, mean={sum(sequence_lengths)/len(sequence_lengths):.1f}")
Healthy:
Seq lengths: min=87, max=203, mean=145.3
Degenerate:
Seq lengths: min=3, max=5, mean=4.1 → Model collapsed (just says "####0")
Seq lengths: min=256, max=256, mean=256.0 → Model rambling (hitting max_tokens)
Performance Expectations
Training Times
On 8× A100 GPUs (80GB):
| Metric | Value |
|---|---|
| Examples per step | 16 |
| Samples per example | 16 |
| Total sequences per step | 256 |
| Steps per epoch (GSM8K) | ~467 (7,473 train problems / 16 per step) |
| Time per step | ~30 seconds |
| Total training time | ~4.5 hours |
| Cost (AWS) | ~$110 |
Expected Improvements
Starting from SFT model:
| Metric | SFT | After RL | Delta (percentage points) |
|---|---|---|---|
| GSM8K Pass@1 | 15-25% | 30-40% | +15 |
| GSM8K Pass@4 | 25-35% | 50-60% | +25 |
| GSM8K Pass@8 | 30-40% | 60-70% | +30 |
NOTE
These are rough estimates for a 12-layer model. Larger models see bigger gains.
Memory Requirements
Per GPU:
| Component | Memory |
|---|---|
| Model (BF16) | ~450 MB |
| Rollout generation (batch=8) | ~8 GB |
| Forward pass (batch=8) | ~8 GB |
| Gradients | ~450 MB |
| Total | ~17 GB |
RL is more memory-intensive than SFT due to rollout generation.
Common Pitfalls
1. Starting from Base Model
Symptom: Model generates gibberish or doesn't follow format.
Solution: Always start from SFT. The base model doesn't know conversation structure.
2. Insufficient Samples per Example
Symptom: High variance in rewards, unstable training.
Solution: Increase num_samples (try 16-32). More samples = better advantage estimates.
3. Learning Rate Too High
Symptom: Pass@k oscillates wildly or collapses to zero.
Solution: Reduce init_lr_frac (try 0.02-0.05) or lower the base learning rates.
4. Reward Hacking
Symptom: Model achieves high reward but produces nonsense.
Example: Model learns to output "####" followed by random numbers, getting partial credit.
Solution: Make reward function more robust—check that the reasoning is present, not just the format.
5. Mode Collapse
Symptom: All generated responses become identical.
Solution:
- Increase temperature (try 1.0-1.2)
- Add an entropy bonus to the objective (subtract it from the loss), e.g. loss = -pg_obj - 0.01 * entropy
- Reduce training duration (model is overfitting)
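In policy-gradient code that uses an entropy bonus, the entropy is typically the mean per-token entropy of the policy's next-token distribution at sampled positions. The sketch below computes that entropy for one position from raw logits. It also adds a cheap collapse check over a group's decoded responses. Both helpers are hypothetical and not part of nanochat.

```python
import math

def categorical_entropy(logits):
    """Entropy (in nats) of softmax(logits), computed stably."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return -sum((e / z) * math.log(e / z) for e in exps if e > 0)

def distinct_fraction(responses):
    """Share of unique completions in a group; 1/len(responses) means full collapse."""
    return len(set(responses)) / len(responses)

print(categorical_entropy([0.0, 0.0, 0.0, 0.0]))  # uniform over 4 tokens: log(4)
print(categorical_entropy([50.0, 0.0, 0.0]))      # nearly deterministic: close to 0
print(distinct_fraction(["#### 10"] * 16))        # every sample identical
```

A distinct fraction that falls toward 1/num_samples while entropy drops toward zero is the mode-collapse signature described above.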
Extending Beyond GSM8K
Other Tasks for RL
Coding (HumanEval):
def reward(self, conversation, assistant_response):
    # Extract code, run test cases
    code = extract_code_block(assistant_response)
    test_results = run_tests(code, self.test_cases)
    return float(all(test_results))
Instruction Following:
def reward(self, conversation, assistant_response):
    # Check if response follows constraints
    instruction = parse_instruction(conversation)
    follows_format = check_format(assistant_response, instruction.format)
    includes_keywords = check_keywords(assistant_response, instruction.keywords)
    return 0.5 * follows_format + 0.5 * includes_keywords
Conversational Quality:
def reward(self, conversation, assistant_response):
    # Use a reward model (small LM trained on human preferences)
    return reward_model.score(conversation + [assistant_response])
Multi-Turn RL
For dialogue, optimize over full conversations:
def get_batch_multiturn():
    # sample_conversation and reward_conversation are placeholders for your task code
    # Start with conversation up to turn N-1
    conversation = sample_conversation(max_turns=3)
    # Generate 16 candidate versions of turn N
    tokens = tokenizer.render_for_completion(conversation)
    sequences, masks = engine.generate_batch(tokens, num_samples=16)
    responses = [tokenizer.decode(seq[len(tokens):]) for seq in sequences]
    # Reward based on full conversation
    rewards = [reward_conversation(conversation, response) for response in responses]
    yield sequences, masks, rewards
State of the Art: What's Next?
Constitutional AI
Train models to self-critique and revise:
1. Generate initial response
2. Critique: "What's wrong with this answer?"
3. Revise: Generate improved response
4. Reward revision
Process-Supervised RL
Instead of rewarding only final answers, reward intermediate reasoning steps:
Reward:
- Each correct reasoning step: +0.1
- Correct final answer: +1.0
- Self-correction: +0.2
Learned Reward Models
Instead of hand-coded rewards, train a model on human preferences:
1. Collect human comparisons: "Response A is better than Response B"
2. Train reward model to predict human preferences
3. Use reward model in RL loop
This is the "HF" (Human Feedback) in RLHF!
Conclusion
Reinforcement Learning allows chat models to go beyond imitation, discovering solutions through exploration and optimization. The key insights:
- Build on SFT: Start from a strong instruction-following model
- Simple works: Vanilla policy gradients (REINFORCE) are effective
- Tool integration: Calculator provides grounding for math reasoning
- Pass@k evaluation: Measures solution diversity, not just single-path accuracy
- Reward design matters: Binary rewards are simple; consider shaping for complex tasks
nanochat's RL implementation is minimalistic yet powerful. By understanding these principles, you can adapt it to your own tasks—coding, instruction following, creative writing, or any domain where "better" is easier to judge than to demonstrate.
The next post covers building custom evaluation tasks: creating benchmarks that measure what truly matters for your use case.
Related Posts
Previous in series:
- Fine-tuning for Chat (SFT) - Transform base models into chat assistants
Next in series:
- Building Custom Evaluation Tasks - Create domain-specific benchmarks
Related posts:
- Training Your First Model - Foundation for SFT and RL
- Modern Transformer Architecture - Understanding the model you're optimizing
Part of the nanochat Deep-Dive Series • Track 2: Practical Guides
GitHub: nanochat repository
RL Script: scripts/chat_rl.py
TIP
Pass@k evaluation is a powerful metric for measuring solution diversity. Consider implementing it for your own tasks, even outside RL!