José David Baena

Fine-tuning for Chat (SFT)


Track 2: Practical Guides - Post 2.2 of 6

This post builds on Training Your First Model. View all posts in this track →

SFT transforms your base model into something you can talk to

The difference between a base model and a chat model isn't magic—it's supervised fine-tuning on the right data with the right masking strategy. I've walked through nanochat's SFT pipeline end-to-end to understand exactly where that transformation happens.

Your base model predicts tokens. SFT teaches it to have a conversation—and 2K high-quality examples beat 10K noisy ones.

TL;DR: Special tokens define conversation boundaries. Masking trains only on assistant responses. Dataset quality matters more than quantity—2K curated conversations outperform 10K scraped ones. One epoch is often enough. OpenAI's InstructGPT paper from 2022 showed that quality data beats quantity. Three years later, that insight still drives the best fine-tuning results.

The dataset that broke everything: Consider a common failure mode: collecting 50K conversations from Reddit for SFT. The model learns to mimic Reddit style—sarcasm, incomplete sentences, "this" as a valid reply. Users hate it. Replacing the dataset with 3K manually curated support conversations—clear questions, helpful answers, complete sentences—transforms user satisfaction on the same base model with the same training recipe. The lesson isn't "more data good." It's "curated data beats scraped data, every time." Two weeks of human curation beats two months of web scraping.

You've trained a base language model that's excellent at predicting the next token in documents. But to build a useful chatbot, you need more than raw language modeling—you need a model that understands conversation structure, follows instructions, answers questions accurately, and maintains helpful, coherent dialogue.

This transformation happens through Supervised Fine-Tuning (SFT), where you train your base model on carefully curated conversation datasets. Here's what we'll cover in nanochat's SFT implementation:

  • The conversation format and tokenization strategy for chat
  • How to prepare and mix training datasets effectively
  • The SFT training loop with specialized optimizers and schedulers
  • Evaluation strategies for chat models
  • Practical considerations for dataset selection and hyperparameters

Special tokens turn documents into structured conversations

The Fundamental Shift

Base model training teaches a model to predict p(token | previous_tokens) from raw text. Chat fine-tuning teaches it to predict p(assistant_response | conversation_history) with structure:

<|bos|>
<|user_start|>What is the capital of France?<|user_end|>
<|assistant_start|>The capital of France is Paris.<|assistant_end|>
<|user_start|>What's the population?<|user_end|>
<|assistant_start|>Paris has approximately 2.1 million residents...<|assistant_end|>

This structure is enforced through special tokens that delimit different parts of the conversation.

Special Tokens Design

From nanochat/tokenizer.py:

SPECIAL_TOKENS = [
    "<|bos|>",              # Beginning of sequence (document delimiter)
    "<|user_start|>",       # User message start
    "<|user_end|>",         # User message end
    "<|assistant_start|>",  # Assistant message start
    "<|assistant_end|>",    # Assistant message end
    "<|python_start|>",     # Python tool invocation
    "<|python_end|>",       # Python tool end
    "<|output_start|>",     # Tool output start
    "<|output_end|>",       # Tool output end
]

Key design decisions:

  1. Explicit delimiters: Each role and boundary is marked, making it unambiguous to the model what's happening
  2. Tool support: Built-in support for tool calling (Python REPL) for agentic behavior
  3. Output masking: Tool outputs aren't supervised (the model doesn't generate them at inference)
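Each special token occupies a single dedicated ID in the vocabulary rather than being split into subwords. A quick way to confirm this, using the encode_special helper that appears later in this post (a minimal sketch, assuming the tokenizer is built via nanochat's get_tokenizer()):

from nanochat.tokenizer import get_tokenizer

tokenizer = get_tokenizer()

# Each special token maps to exactly one ID; it is never split into subwords.
for tok in ["<|bos|>", "<|user_start|>", "<|assistant_end|>", "<|python_start|>"]:
    print(tok, "->", tokenizer.encode_special(tok))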

Masking trains only on what the model should generate

The Tokenization Strategy

The most critical function in SFT is render_conversation(), which converts a conversation dict into token IDs with a supervision mask:

def render_conversation(self, conversation, max_tokens=2048):
    """
    Returns:
    - ids: list[int] - token IDs of the rendered conversation
    - mask: list[int] - same length, 1 for tokens to supervise, 0 for others
    """
    ids, mask = [], []
    
    def add_tokens(token_ids, mask_val):
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        ids.extend(token_ids)
        mask.extend([mask_val] * len(token_ids))

What Gets Supervised?

This is the most important design decision in SFT:

Token Type                   Masked?      Reason
<|bos|>                      ✓ Yes (0)    Document delimiter, not predicted
<|user_start|>               ✓ Yes (0)    Structural token, not generated
User message content         ✓ Yes (0)    User wrote this, not assistant
<|user_end|>                 ✓ Yes (0)    Structural token, not generated
<|assistant_start|>          ✓ Yes (0)    Already given at inference
Assistant message content    No (1)       This is what we train!
<|assistant_end|>            No (1)       Model must learn to stop
Tool outputs                 ✓ Yes (0)    Come from environment, not model

From nanochat/tokenizer.py:

if message["role"] == "user":
    value_ids = self.encode(content)
    add_tokens(user_start, 0)      # Not supervised
    add_tokens(value_ids, 0)        # Not supervised
    add_tokens(user_end, 0)         # Not supervised
elif message["role"] == "assistant":
    add_tokens(assistant_start, 0)  # Not supervised (already given at inference)
    if isinstance(content, str):
        value_ids = self.encode(content)
        add_tokens(value_ids, 1)    # SUPERVISED!
    add_tokens(assistant_end, 1)    # SUPERVISED (must learn to stop!)

Why mask user messages? The model should predict assistant responses given user inputs, not regenerate what the user said. This is called teacher forcing with selective supervision.

Why supervise <|assistant_end|>? The model must learn when to stop generating. This token becomes the natural stopping point.
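A quick way to confirm both points on a rendered conversation, using the tokenizer APIs shown throughout this post (a small sanity check, not part of the training script):

from nanochat.tokenizer import get_tokenizer

tokenizer = get_tokenizer()
ids, mask = tokenizer.render_conversation({
    "messages": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "The capital of France is Paris."},
    ]
})

# The final token is <|assistant_end|>, and it is supervised (mask=1).
assert ids[-1] == tokenizer.encode_special("<|assistant_end|>")
assert mask[-1] == 1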

Conversation Mask Visualizer

The first exchange from the France example above, tokenized with its supervision mask:

<|bos|><|user_start|>What·is·the·capital·of·France?<|user_end|><|assistant_start|>The·capital·of·France·is·Paris.<|assistant_end|>

Supervised tokens (mask=1) are the ones the model learns to predict; masked tokens (mask=0) are context only. For this conversation: 27 total tokens, 12 supervised, 15 masked, a supervision ratio of 44.4%.

SFT Masking Rules

Token Type             Mask    Reason
<|bos|>                0       Document delimiter, not predicted
<|user_start|>         0       Structural token, not generated
User content           0       User wrote this, not assistant
<|assistant_start|>    0       Already given at inference
Assistant content      1       This is what we train!
<|assistant_end|>      1       Model must learn to stop

Why this matters: By only supervising assistant tokens, the model learns to generate helpful responses given user context. The special token <|assistant_end|> is crucial—it teaches the model when to stop generating. A typical dataset has 50-60% supervision ratio; if yours is much lower, your assistant responses might be too short.

Visualizing the Mask

The tokenizer includes a helpful debugging function:

def visualize_tokenization(self, ids, mask):
    RED = '\033[91m'    # Masked (not supervised)
    GREEN = '\033[92m'  # Supervised
    RESET = '\033[0m'
    tokens = []
    for token_id, mask_val in zip(ids, mask):
        token_str = self.decode([token_id])
        color = GREEN if mask_val == 1 else RED
        tokens.append(f"{color}{token_str}{RESET}")
    return '|'.join(tokens)

This makes it immediately obvious which tokens the model is being trained on.

Dataset mixing balances instruction diversity and depth

Dataset Selection

From scripts/chat_sft.py:

train_ds = TaskMixture([
    ARC(subset="ARC-Easy", split="train"),        # 2.3K rows
    ARC(subset="ARC-Challenge", split="train"),   # 1.1K rows
    GSM8K(subset="main", split="train"),          # 8K rows
    SmolTalk(split="train", stop=10_000),         # 10K rows
])  # Total: 21.4K training examples
 
val_ds = SmolTalk(split="test")  # 24K validation examples

Dataset composition strategy:

  1. Reasoning tasks (ARC, GSM8K): Teach structured problem-solving
  2. Conversational data (SmolTalk): Teach natural dialogue patterns
  3. Balance: ~50% general conversation, ~50% specific skills

The TaskMixture Pattern

The TaskMixture class elegantly handles multi-dataset training:

class TaskMixture(Task):
    def __init__(self, tasks, **kwargs):
        self.tasks = tasks
        self.lengths = [len(task) for task in self.tasks]
        self.num_conversations = sum(self.lengths)
        
        # Build index map of (task_idx, local_idx) pairs
        self.index_map = []
        for task_idx, task_length in enumerate(self.lengths):
            for local_idx in range(task_length):
                self.index_map.append((task_idx, local_idx))
        
        # Deterministically shuffle to mix tasks
        rng = random.Random(42)
        rng.shuffle(self.index_map)
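Retrieval is then just a lookup into index_map. The exact accessor in nanochat may differ, but here is a minimal sketch of the idea, continuing the class above and following the num_examples/get_example API used elsewhere in this post:

    def num_examples(self):
        return self.num_conversations

    def get_example(self, index):
        # Global index -> (which task, which local example). After the shuffle
        # above, consecutive indices draw from different source datasets.
        task_idx, local_idx = self.index_map[index]
        return self.tasks[task_idx].get_example(local_idx)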

Why shuffle? Without shuffling, the model would see all ARC examples, then all GSM8K, then all SmolTalk. This can lead to catastrophic forgetting—later tasks overwrite earlier learnings. Shuffling creates a mixed curriculum.

For your fine-tuning datasets, this means: always shuffle across task types. If you have 3 datasets of 10K examples each, don't train sequentially—interleave them. Catastrophic forgetting silently destroys your early training work.

For your production pipeline, this means: deterministic seeding (seed=42) isn't just good practice—it's your debugging lifeline. When users report regressions, you can reproduce the exact training order.

Dataset Mixture Planner

Design your SFT data mix to balance quality, diversity, and volume. An example mixture:

Source      Weight    Samples Used    Epochs
ShareGPT    25%       12,500          0.25
Alpaca      25%       12,500          0.24
FLAN        40%       20,000          0.20
Custom      10%       5,000           1.00

For this mix: average quality 3.6, average diversity 4.0, at most 1.0 epochs over any single source, and low overfitting risk.

Data Mixing Tips

  • Quality over quantity: 5K high-quality examples often beat 50K noisy ones
  • Watch epochs: Seeing the same example 3+ times leads to memorization
  • Task diversity: Mix coding, writing, reasoning, and conversational data
  • Temperature sampling: For larger sources, you can subsample with temperature

Why deterministic (seed=42)? Reproducibility. The same codebase produces identical training order every time.

Data Collation and Padding

Conversations have variable lengths. The data generator handles this with padding:

def sft_data_generator(dataset, batch_size):
    pad_token_id = tokenizer.encode_special("<|assistant_end|>")
    
    def collate_and_yield(batch):
        nrows = len(batch)
        ncols = max(len(ids) for ids, mask in batch) - 1
        
        inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
        targets = torch.full((nrows, ncols), -1, dtype=torch.long)  # -1 = ignore
        
        for i, (ids, mask) in enumerate(batch):
            n = len(ids)
            ids_tensor = torch.tensor(ids, dtype=torch.long)
            inputs[i, :n-1] = ids_tensor[:-1]
            
            row_targets = ids_tensor[1:]
            mask_tensor = torch.tensor(mask[1:], dtype=torch.long)
            row_targets[mask_tensor == 0] = -1  # Apply supervision mask
            targets[i, :n-1] = row_targets
        
        return inputs.to(device), targets.to(device)

Critical details:

  1. Pad token choice: Using <|assistant_end|> as padding is safe because padded positions get -1 targets (ignored in loss)
  2. Ignore index: the loss is computed with ignore_index=-1, so -1 targets contribute nothing (this must be set explicitly; PyTorch's default ignore_index is -100)
  3. Shifted targets: targets are the inputs shifted left by one position (standard language modeling setup)
  4. Mask application: Zero-masked positions get -1 targets, so loss doesn't update on them

For your data pipeline, this means: the -1 ignore trick is your friend for variable-length sequences. You pay for padding compute, but not for padding gradients—the model only learns from real tokens.
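Here is a toy demonstration of that trick, assuming the model computes cross-entropy with ignore_index=-1 as the collator above implies (illustrative, not nanochat code):

import torch
import torch.nn.functional as F

# One sequence, 4 positions, vocab of 10; positions 1 and 2 are masked/padded.
logits = torch.randn(1, 4, 10, requires_grad=True)
targets = torch.tensor([[3, -1, -1, 7]])

loss = F.cross_entropy(
    logits.view(-1, logits.size(-1)),  # (B*T, V)
    targets.view(-1),                  # (B*T,)
    ignore_index=-1,                   # -1 targets are dropped from the loss
)
loss.backward()
print(logits.grad[0, 1].abs().sum())   # tensor(0.): masked positions get no gradient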

Lower learning rates prevent catastrophic forgetting

Hyperparameters

# Precision
dtype = "bfloat16"              # Memory efficient, stable training
device_batch_size = 4           # Max per GPU without OOM
 
# Optimization
num_epochs = 1                  # Often sufficient for SFT!
target_examples_per_step = 32   # Effective batch size
unembedding_lr = 0.004          # Output layer learning rate
embedding_lr = 0.2              # Input embedding LR (higher!)
matrix_lr = 0.02                # Attention/MLP matrices
weight_decay = 0.0              # No weight decay for SFT
init_lr_frac = 0.02             # Start at 2% of base LR
 
# Evaluation
eval_every = 100                # Validation loss frequency
eval_steps = 100                # Steps to average for val loss
eval_metrics_every = 200        # Full benchmark suite frequency

Why different learning rates?

  • Embedding (0.2): Highest—embeddings learn token meanings from scratch for new special tokens
  • Matrices (0.02): Medium—attention/MLP parameters fine-tune existing knowledge
  • Unembedding (0.004): Lowest—output distribution already well-calibrated from base training

This is the same layer-wise learning rate strategy used in base training, but scaled down by ~50% for fine-tuning.

For your fine-tuning setup, this means: don't use a single learning rate. Embeddings need higher LR to learn new special tokens; output layers need lower LR to preserve calibration.

For your training budget, this means: the embedding LR sitting 50× above the unembedding LR isn't arbitrary: it compensates for sparse one-hot gradients that only touch one embedding row per token. Get this wrong and new special tokens never learn their meanings.
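To make the grouping concrete, here is a simplified sketch of layer-wise learning rates with a plain AdamW optimizer and a toy model. The module names are hypothetical and nanochat's actual optimizer setup differs; only the per-group LR idea carries over:

import torch
import torch.nn as nn

# Toy stand-in for a transformer with the three parameter families discussed above.
class TinyLM(nn.Module):
    def __init__(self, vocab=100, dim=32):
        super().__init__()
        self.wte = nn.Embedding(vocab, dim)   # input embeddings
        self.blocks = nn.Linear(dim, dim)     # stands in for attention/MLP matrices
        self.lm_head = nn.Linear(dim, vocab)  # unembedding

model = TinyLM()
optimizer = torch.optim.AdamW(
    [
        {"params": model.wte.parameters(), "lr": 0.2},        # highest: learn new token meanings
        {"params": model.blocks.parameters(), "lr": 0.02},    # medium: fine-tune existing knowledge
        {"params": model.lm_head.parameters(), "lr": 0.004},  # lowest: preserve output calibration
    ],
    weight_decay=0.0,
)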

Gradient Accumulation

examples_per_step = device_batch_size * ddp_world_size
assert target_examples_per_step % examples_per_step == 0
grad_accum_steps = target_examples_per_step // examples_per_step
 
# Training step with error handling
num_tokens = torch.tensor(0, device=device)
for micro_step in range(grad_accum_steps):
    train_inputs, train_targets = next(train_iter)
    
    try:
        with autocast_ctx:
            loss = model(train_inputs, train_targets)
        
        # Check for NaN loss (indicates training instability)
        if torch.isnan(loss):
            logging.warning(f"NaN loss detected at step {step}, micro-step {micro_step}. Skipping batch.")
            optimizer.zero_grad()
            continue
        
        loss = loss / grad_accum_steps  # Normalize for accumulation
        loss.backward()
        num_tokens += (train_targets >= 0).sum()
        
    except RuntimeError as e:
        if "out of memory" in str(e):
            logging.error(f"OOM at step {step}, micro-step {micro_step}. Clearing cache and skipping batch.")
            torch.cuda.empty_cache()
            optimizer.zero_grad()
            continue
        else:
            # Re-raise unexpected errors
            raise e

Key insight: Dividing loss by grad_accum_steps before .backward() ensures gradients have the same scale as a single large batch. This is mathematically equivalent to averaging gradients across micro-batches.
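A toy check of that equivalence (illustrative only, not nanochat code):

import torch

w = torch.tensor(1.0, requires_grad=True)
x = torch.tensor([1.0, 2.0, 3.0, 4.0])

# One full batch of 4 examples.
(w * x).mean().backward()
full_grad = w.grad.clone()
w.grad = None

# Two micro-batches of 2 examples, each loss divided by grad_accum_steps = 2.
for chunk in x.split(2):
    ((w * chunk).mean() / 2).backward()

print(torch.allclose(w.grad, full_grad))  # True: accumulated grads match the full batch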

Error handling additions:

  • NaN detection: Training instability (exploding gradients, numerical overflow) can cause NaN losses. Skipping the batch prevents corrupting model weights.
  • OOM recovery: Out-of-memory errors during forward/backward passes are caught, cache is cleared, and training continues with the next batch.
  • Gradient reset: optimizer.zero_grad() ensures partial gradients from failed batches don't accumulate.

For your overnight training runs, this means: these guards are what keep training alive while you sleep. One NaN or OOM shouldn't kill a 12-hour job. Log the skip events and review them in the morning.

A pre-warmed start and linear decay stabilize fine-tuning

Linear Decay

SFT uses simple linear decay from init_lr_frac down to zero:

def get_lr_multiplier(it):
    lrm = 1.0 - it / num_iterations
    return lrm
 
# Apply to all optimizers
for opt in optimizers:
    for group in opt.param_groups:
        group["lr"] = group["initial_lr"] * lrm

Why linear instead of cosine? SFT is typically short (1 epoch, ~600 steps for 21K examples). Linear decay is simpler and works well for short training runs. The model doesn't have time to get "stuck" in local minima that cosine annealing helps with.

Warmup Through init_lr_frac

Instead of explicit warmup, we start at 2% of the base learning rate:

for opt in optimizers:
    for group in opt.param_groups:
        group["lr"] = group["lr"] * init_lr_frac  # Start at 2%
        group["initial_lr"] = group["lr"]

This is effectively a "pre-warmed" start: instead of ramping the learning rate up from zero, training begins immediately at a small, safe learning rate (2% of base), and the linear schedule then decays it toward zero.
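Putting the two pieces together, and assuming matrix_lr = 0.02, init_lr_frac = 0.02, and roughly 600 steps (the one-epoch, 21K-example run described earlier), the effective learning rate looks like this:

base_lr = 0.02
init_lr_frac = 0.02
num_iterations = 600

initial_lr = base_lr * init_lr_frac      # 4.0e-4 at step 0 (the "pre-warmed" start)
for it in (0, 300, 599):
    lrm = 1.0 - it / num_iterations      # linear decay multiplier
    print(it, initial_lr * lrm)          # 4.0e-04, 2.0e-04, ~6.7e-07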

Automated evaluation catches formatting and instruction-following issues

Two-Level Evaluation

SFT evaluation happens at two levels:

1. Validation Loss (Every 100 Steps)

if step % eval_every == 0:
    model.eval()
    val_iter = iter(build_val_loader())
    losses = []
    for _ in range(eval_steps):
        val_inputs, val_targets = next(val_iter)
        with torch.no_grad(), autocast_ctx:
            loss = model(val_inputs, val_targets)
        losses.append(loss)
    val_loss = torch.stack(losses).mean()

What it tells you: How well the model predicts assistant responses. Lower is better, but doesn't capture task performance.

2. Task Metrics (Every 200 Steps)

if step % eval_metrics_every == 0:
    metrics = {}
    with torch.no_grad(), autocast_ctx:
        metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, 
                                             batch_size=device_batch_size*2, max_problems=1024)
        metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine,
                                                 batch_size=device_batch_size*2, max_problems=1024)
        metrics["gsm8k_acc"] = run_chat_eval("GSM8K", model, tokenizer, engine, 
                                              max_problems=64)
        metrics["humaneval_acc"] = run_chat_eval("HumanEval", model, tokenizer, engine,
                                                  max_problems=64)

What it tells you: Actual task performance on multiple choice (MMLU, ARC) and generative (GSM8K, HumanEval) benchmarks.

Categorical vs Generative Evaluation

From scripts/chat_eval.py, there are two evaluation modes:

Categorical Evaluation (MMLU, ARC)

def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=None):
    # Render conversation up to answer
    prompt_ids = [tokenizer.render_for_completion(conv) for conv in conversations]
    
    # Get logits for the batch
    logits = model(prompt_ids)  # (B, T, V)
    
    # Focus on answer position and available letters
    letters = conversation['letters']  # e.g., ["A", "B", "C", "D"]
    letter_ids = [tokenizer.encode(letter)[0] for letter in letters]
    focus_logits = logits[idx, answer_pos, letter_ids]
    
    # Argmax over constrained choices
    predicted_letter = letters[focus_logits.argmax().item()]

Why constrained evaluation? Multiple choice tasks are easier when you only evaluate the model's confidence across valid choices (A, B, C, D) rather than generating free text. This is standard practice in benchmarks like MMLU.

Generative Evaluation (GSM8K, HumanEval)

def run_generative_eval(task_object, tokenizer, model, engine, 
                        num_samples, max_new_tokens, temperature, top_k, max_problems=None):
    # Tokenize prompt
    encoded_prompt = tokenizer.render_for_completion(conversation)
    
    # Generate completions
    results, _ = engine.generate_batch(
        encoded_prompt,
        num_samples=num_samples,
        max_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
    )
    
    # Decode and evaluate
    completions = [tokenizer.decode(result[prefix_len:]) for result in results]
    outcomes = [task_object.evaluate(conversation, completion) for completion in completions]
    passed = any(outcomes)  # Pass-at-k evaluation

Pass-at-k: For code generation (HumanEval), we generate k samples and pass if any are correct. This is more forgiving and reflects real-world usage (developers try multiple times).

Step-by-step: from base model to chat assistant

Step 1: Prepare Your Base Model

Ensure you have a trained base model:

# Check available models
ls -la base_checkpoints/
 
# Should see: d12/ or d24/ directories with model_step_*.pt files

Step 2: Configure Your Training

Create a config file sft_config.txt:

run = "my_sft_run"
source = "mid"           # or "base" depending on your checkpoint
num_epochs = 1
target_examples_per_step = 32
device_batch_size = 4    # Adjust based on your GPU memory
eval_every = 50
eval_metrics_every = 100

Step 3: Launch SFT Training

Single GPU:

python -m scripts.chat_sft --config sft_config.txt

Multi-GPU (8 GPUs):

torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft --config sft_config.txt

Step 4: Monitor Training

Watch the output for key metrics:

Step 00000 | Validation loss: 2.451234
Step 00000 | mmlu_acc: 0.234000, arc_easy_acc: 0.421000, gsm8k_acc: 0.012000, humaneval_acc: 0.000000
Step 00100 | Training loss: 1.823456 | lrm: 0.980000 | num_tokens: 45,231
Step 00200 | mmlu_acc: 0.287000, arc_easy_acc: 0.498000, gsm8k_acc: 0.078000, humaneval_acc: 0.031000

What to look for:

  • Training loss should decrease steadily
  • Validation loss should track training loss (if it diverges upward, you're overfitting)
  • Task metrics should improve over time, especially on tasks in the training mixture
  • num_tokens shows how many supervised tokens per step (varies due to conversation length)

Step 5: Evaluate the Final Model

After training completes, evaluate on all benchmarks:

# Quote the task list so the shell doesn't treat "|" as a pipe
torchrun --nproc_per_node=8 -m scripts.chat_eval -- -i sft -a "ARC-Easy|ARC-Challenge|MMLU|GSM8K|HumanEval"

Advanced: custom datasets, system messages, multi-turn, and tool use

Custom Dataset Integration

To add your own conversation dataset, create a Task class:

from tasks.common import Task
 
class MyCustomTask(Task):
    def __init__(self, split, **kwargs):
        super().__init__(**kwargs)
        # Load your data
        self.data = self.load_custom_data(split)
    
    def num_examples(self):
        return len(self.data)
    
    def get_example(self, index):
        row = self.data[index]
        # Return conversation dict
        return {
            "messages": [
                {"role": "user", "content": row["question"]},
                {"role": "assistant", "content": row["answer"]},
            ]
        }

Then add it to the training mixture:

train_ds = TaskMixture([
    ARC(subset="ARC-Easy", split="train"),
    GSM8K(subset="main", split="train"),
    SmolTalk(split="train", stop=10_000),
    MyCustomTask(split="train"),  # Your data!
])

Handling System Messages

Some datasets include system messages (instructions for the assistant's behavior):

messages = [
    {"role": "system", "content": "You are a helpful math tutor."},
    {"role": "user", "content": "How do I solve x^2 = 4?"},
    {"role": "assistant", "content": "To solve x^2 = 4, take the square root..."}
]

The tokenizer automatically handles this by merging the system message with the first user message:

if conversation["messages"][0]["role"] == "system":
    conversation = copy.deepcopy(conversation)
    messages = conversation["messages"]
    assert messages[1]["role"] == "user"
    messages[1]["content"] = messages[0]["content"] + "\n\n" + messages[1]["content"]
    messages = messages[1:]  # Remove system message

Multi-Turn Conversations

The tokenizer handles arbitrary-length conversations:

messages = [
    {"role": "user", "content": "What is 2+2?"},
    {"role": "assistant", "content": "2+2 equals 4."},
    {"role": "user", "content": "What about 2+3?"},
    {"role": "assistant", "content": "2+3 equals 5."},
    {"role": "user", "content": "Can you explain why?"},
    {"role": "assistant", "content": "Addition combines quantities..."},
]

All assistant responses are supervised, allowing the model to learn context-dependent responses.

Tool-Augmented Training

For agentic behavior, conversations can include tool calls:

messages = [
    {"role": "user", "content": "What is 123 * 456?"},
    {"role": "assistant", "content": [
        {"type": "text", "text": "Let me calculate that for you."},
        {"type": "python", "text": "123 * 456"},
        {"type": "python_output", "text": "56088"},
        {"type": "text", "text": "The result is 56,088."},
    ]},
]

The tokenizer renders this as:

<|assistant_start|>
Let me calculate that for you.
<|python_start|>123 * 456<|python_end|>
<|output_start|>56088<|output_end|>
The result is 56,088.
<|assistant_end|>

Where:

  • Text and Python code are supervised (mask=1)
  • Python outputs are not supervised (mask=0) because they come from the environment
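You can verify that masking with the same debugging helpers shown in the next section; a small sketch, reusing the messages list from the tool-call example above:

from nanochat.tokenizer import get_tokenizer

tokenizer = get_tokenizer()
ids, mask = tokenizer.render_conversation({"messages": messages})  # messages from the example above

# Locate the tool-output span and confirm its contents are not supervised.
out_start = tokenizer.encode_special("<|output_start|>")
out_end = tokenizer.encode_special("<|output_end|>")
i, j = ids.index(out_start), ids.index(out_end)
assert all(m == 0 for m in mask[i + 1 : j])  # the "56088" tokens carry mask=0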

These debugging patterns reveal common SFT failures

Visualize Tokenization

Use the built-in visualizer to inspect what's being supervised:

from nanochat.tokenizer import get_tokenizer
 
tokenizer = get_tokenizer()
conversation = {
    "messages": [
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi there!"},
    ]
}
 
ids, mask = tokenizer.render_conversation(conversation)
print(tokenizer.visualize_tokenization(ids, mask))

Output shows green (supervised) and red (not supervised) tokens.

Check Masking Statistics

Monitor what percentage of tokens are supervised:

ids, mask = tokenizer.render_conversation(conversation)
num_supervised = sum(mask)           # mask entries are 1 (supervised) or 0 (masked)
total_tokens = len(mask)
supervision_ratio = num_supervised / total_tokens
print(f"Supervising {num_supervised}/{total_tokens} tokens ({100*supervision_ratio:.1f}%)")

Typical ratios:

  • 50-60% for conversational data (user messages are ~half the tokens)
  • 30-40% for datasets with long questions and short answers
  • 70-80% for datasets with short questions and long explanations

Monitor Training Dynamics

Track the number of supervised tokens per step:

num_tokens = (train_targets >= 0).sum()

If this varies wildly (e.g., 100 tokens to 10,000 tokens per step), you may want to:

  1. Truncate conversations to max_tokens (already done in render_conversation)
  2. Use batch packing (advanced: pack multiple conversations into one sequence)
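Batch packing goes beyond what the collator above does, but the greedy idea is simple; a toy sketch (illustrative only):

def pack_conversations(rendered, max_len=2048):
    """Greedily pack (ids, mask) pairs into sequences of at most max_len tokens."""
    packs, cur_ids, cur_mask = [], [], []
    for ids, mask in rendered:
        if cur_ids and len(cur_ids) + len(ids) > max_len:
            packs.append((cur_ids, cur_mask))
            cur_ids, cur_mask = [], []
        cur_ids.extend(ids)
        cur_mask.extend(mask)
    if cur_ids:
        packs.append((cur_ids, cur_mask))
    return packs

# Usage: rendered = [tokenizer.render_conversation(c) for c in conversations]
#        packed = pack_conversations(rendered)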

Expect 70%+ instruction-following after 1 epoch

Typical Training Times

On 8× A100 GPUs (80GB):

Dataset Size     Batch Size    Duration      Cost (AWS)
21K examples     32            ~1 hour       ~$25
50K examples     32            ~2.5 hours    ~$60
100K examples    32            ~5 hours      ~$120

Expected Accuracy Improvements

Starting from a base model (untrained on instructions):

Metric       Base Model      After SFT    Delta
MMLU         25% (random)    35-45%       +10-20%
ARC-Easy     40-50%          60-70%       +10-20%
GSM8K        0-5%            15-30%       +15-25%
HumanEval    0-2%            10-20%       +10-18%

NOTE

These are rough estimates for a 12-layer model. Larger models see bigger gains.

Memory Requirements

Per GPU:

Batch Size    BF16 Model Size    Activation Memory    Total Memory
1             ~450 MB            ~2 GB                ~3 GB
2             ~450 MB            ~4 GB                ~5 GB
4             ~450 MB            ~8 GB                ~10 GB
8             ~450 MB            ~16 GB               ~18 GB

Activation memory scales linearly with batch size and sequence length.

These pitfalls kill SFT quality before training ends

1. Overfitting on Small Datasets

Symptom: Training loss decreases but validation loss increases.

Solution:

  • Reduce num_epochs (try 0.5 epochs)
  • Add more diverse data to the mixture
  • Increase validation evaluation frequency to catch overfitting early

2. Catastrophic Forgetting

Symptom: Model performs well on recent tasks but forgets earlier ones.

Solution:

  • Ensure TaskMixture is shuffled (it is by default)
  • Add continual learning: include base model data in the mixture
  • Use smaller learning rates

3. Poor Multi-Turn Performance

Symptom: Model handles single-turn questions but fails on follow-ups.

Solution:

  • Ensure training data includes multi-turn conversations
  • Increase max_tokens to avoid truncating context
  • Evaluate specifically on multi-turn benchmarks

4. Model Doesn't Stop

Symptom: Model generates excessively long responses or doesn't emit <|assistant_end|>.

Solution:

  • Verify <|assistant_end|> is supervised (mask=1)
  • Check dataset quality: do assistant messages end properly?
  • Add length penalty during generation

5. Low Supervised Token Ratio

Symptom: Only 20-30% of tokens are supervised (expected is 50-60%).

Solution:

  • Check conversation balance: too many long user messages?
  • Verify masking logic: are you accidentally masking assistant messages?
  • Consider datasets with richer assistant responses

SFT trains hours, base training takes weeks—here's why

Aspect           Base Training            SFT Training
Objective        Next token prediction    Instruction following
Data             Raw documents            Conversations
Supervision      All tokens               Assistant responses only
Duration         Weeks                    Hours
Iterations       Millions                 Hundreds
Learning Rate    Higher (0.006-0.4)       Lower (0.004-0.2)
Epochs           1 (of massive data)      1-3 (of curated data)
Goal             Learn language           Learn behavior

Extending to Multi-Modal

The conversation format naturally extends to multi-modal inputs:

messages = [
    {
        "role": "user", 
        "content": [
            {"type": "image", "url": "path/to/image.jpg"},
            {"type": "text", "text": "What's in this image?"}
        ]
    },
    {
        "role": "assistant",
        "content": "I see a cat sitting on a red couch."
    }
]

To support this:

  1. Add <|image_start|> and <|image_end|> special tokens
  2. Encode images as token sequences (e.g., via a vision encoder)
  3. Update render_conversation() to handle image content types

Next Steps

You now understand:

  • ✅ How conversations are tokenized with special tokens and masking
  • ✅ The SFT training loop with optimizers and schedulers
  • ✅ Evaluation strategies for chat models
  • ✅ Practical considerations for dataset selection

What's next?

  1. Reinforcement Learning from Human Feedback - Take your SFT model further with RL to optimize for human preferences
  2. Building Custom Evaluation Tasks - Create your own benchmarks to measure what matters for your use case
  3. Modern Transformer Architecture - RoPE, QK normalization, and the design choices that make chat models efficient

SFT transforms language models into chatbots—but only if you get the details right

Careful conversation design, selective supervision, and multi-task training make the difference. Get masking wrong, and your model learns to parrot users instead of responding to them. Get dataset mixing wrong, and it forgets half its skills by the time training ends.

For your fine-tuning pipeline, this means: the difference between a useful assistant and an expensive autocomplete is in these details.


Before you run your first SFT job:

  1. Verify your masking ratios. Print supervision stats: if <40% of tokens are supervised, your data probably has too much user content or masked tool outputs. Aim for 45-60%.

  2. Confirm <|assistant_end|> is supervised. Your model must learn to stop. Check the mask visualization—green on that token or your model will ramble forever.

  3. Shuffle your dataset mixture. Sequential training (all ARC, then all SmolTalk) causes catastrophic forgetting. Interleave with TaskMixture or equivalent.

  4. Test on multi-turn before deploying. Single-turn accuracy doesn't predict multi-turn coherence. Run 3+ turn conversations through your eval pipeline.

  5. One epoch is usually enough. SFT overfits fast on small datasets. Watch validation loss—if it starts climbing, stop.


Your model's personality starts here. Make it a good one.

The next post covers Reinforcement Learning from Human Feedback (RLHF), optimizing for subjective human preferences that are hard to capture in supervised datasets.


Part of the nanochat Deep-Dive Series • Track 2: Practical Guides

GitHub: nanochat repository
SFT Script: scripts/chat_sft.py


Sources

Datasets

  • SmolTalk: HuggingFace. "SmolTalk". Conversational dataset used in nanochat SFT.
  • OpenAssistant: Köpf, A. et al. (2023). "OpenAssistant Conversations". arXiv:2304.07327. Open-source conversation dataset.
  • Dolly 2.0: Databricks (2023). dolly-v2-12b. 15K human-written instruction examples.

Industry Standards & Research (as of January 2025)

  • Stanford HAI AI Index 2024: Instruction Tuning Trends. Documents 3× improvement in instruction-following capability since 2022 through SFT advances.
  • Anthropic Constitutional AI: Training Helpful Assistants. Industry-leading approach to alignment through supervised feedback.
  • MLCommons Benchmarks: Chat Model Evaluation. Emerging standards for evaluating instruction-following models.

TIP

Experiment notebooks: Due to reader interest, interactive Jupyter notebooks for hands-on experiments are planned. Let us know if you'd like to see them!


Your model's personality starts here. The data you choose and how you mask it determines everything that follows.