Training Data Pipeline: Streaming Tokenization at Scale

Loading 100GB of text into memory fails—here's why
Working through nanochat's data pipeline was an education in memory efficiency. The streaming tokenization pattern is deceptively simple—but it's the difference between training that works and training that OOMs.
100 billion tokens. 12 megabytes of memory. That's the ratio nanochat achieves: streaming roughly 100GB of text through training without ever holding the dataset in RAM.
TL;DR: Streaming tokenization processes 100B tokens with just 12MB memory by treating data as infinite iterators. Distributed sharding ensures each GPU sees different data without coordination overhead.
The dataset that didn't fit: Consider a common failure mode: trying to train on 50B tokens by loading the full dataset into memory. Preprocessing crashes at 340GB RAM usage. The next attempt—mmap the entire dataset—works on one GPU but fails silently on 8 GPUs when all ranks read the same data, causing training to converge to garbage. The fix is streaming: treat the dataset as an infinite iterator, never materializing more than one batch. Memory usage drops to 12MB. GPU ranks get different data through sharding offsets. Streaming isn't an optimization—it's the only way to train at scale.
Training a language model on 100 billion tokens? You can't load everything into memory. You need to keep GPUs fed with data. Distributed training adds another layer of complexity. nanochat's data pipeline handles all three problems with a surprisingly simple design: streaming data access, parallel tokenization, and distributed sharding.
This post breaks down how it works—the dataset format, tokenization strategy, distributed loading pattern, and the optimizations that keep everything running at max efficiency.
Key Achievement: nanochat's data pipeline handles 100B tokens using only ~12 MB of memory per GPU rank, achieving 1.8M tokens/sec throughput with 4-threaded tokenization.
FineWeb-Edu: 100GB of educational web text in 1,823 shards
nanochat uses the FineWeb-Edu-100B dataset - 100 billion tokens of high-quality educational web text. The dataset is stored as 1,823 Parquet files (~55MB each), totaling about 100GB on disk.
Why Parquet?
Parquet is a columnar storage format that's perfect for this use case:
- Efficient columnar access: we read only the text column and ignore the metadata
- Built-in compression: ~10× compression ratio over raw text
- Row groups: Internal batching structure for efficient streaming
- Random access: Can jump to any row group without scanning the whole file
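For example, here's a minimal sketch of that row-group access pattern with pyarrow (the shard filename is hypothetical; point it at wherever your downloaded shards live):
import pyarrow.parquet as pq

pf = pq.ParquetFile("base_data/shard_00000.parquet")  # hypothetical local shard path
print(pf.num_row_groups)                              # each shard holds many row groups
rg = pf.read_row_group(0, columns=["text"])           # read only the 'text' column of one group
texts = rg.column("text").to_pylist()
print(len(texts), texts[0][:100])                     # ~1024 documents; preview the first one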
Dataset Structure
From nanochat/dataset.py:
BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
MAX_SHARD = 1822 # 1,823 files total (0-indexed)
DATA_DIR = os.path.join(base_dir, "base_data")
def list_parquet_files(data_dir=None):
"""List all parquet files in the data directory."""
data_dir = DATA_DIR if data_dir is None else data_dir
parquet_files = sorted([
f for f in os.listdir(data_dir)
if f.endswith('.parquet') and not f.endswith('.tmp')
])
return [os.path.join(data_dir, f) for f in parquet_files]
The dataset is shuffled at the document level before sharding, which is crucial for training stability. Documents within each shard maintain their shuffled order.
Train/Val Split
From nanochat/dataset.py:
def parquets_iter_batched(split, start=0, step=1):
"""Iterate through dataset in batches of row groups."""
assert split in ["train", "val"]
parquet_paths = list_parquet_files()
# Last file = validation, rest = training
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
for filepath in parquet_paths:
pf = pq.ParquetFile(filepath)
for rg_idx in range(start, pf.num_row_groups, step):
rg = pf.read_row_group(rg_idx)
texts = rg.column('text').to_pylist()
yield texts
Key design choices:
- Simple split: Last shard = validation (about 55M tokens)
- Row group granularity: Iterate at the row group level (~1024 documents each)
- Distributed sharding: the start and step parameters enable rank-specific data access (see the snippet below)
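For instance, a single process (start=0, step=1) can peek at the first validation row group like this, assuming the shards are already downloaded (a quick sketch, not part of the training code):
from nanochat.dataset import parquets_iter_batched

# First row group of the validation split (i.e., the last shard)
batch = next(parquets_iter_batched(split="val", start=0, step=1))
print(len(batch))      # ~1024 documents per row group
print(batch[0][:100])  # preview the first document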
RustBPE tokenizes 5.6M tokens/second—here's how
nanochat uses a custom two-stage tokenization approach:
Training: RustBPE
The tokenizer is trained using rustbpe, a high-performance Rust implementation of Byte Pair Encoding.
From rustbpe/src/lib.rs:
pub fn train_from_iterator(
&mut self,
iterator: &PyAny,
vocab_size: u32,
buffer_size: usize,
pattern: Option<String>,
) -> PyResult<()> {
// Use GPT-4 style regex pattern
let pattern_str = pattern.unwrap_or_else(|| GPT4_PATTERN.to_string());
// Global chunk counts
let mut counts: AHashMap<CompactString, i32> = AHashMap::new();
let mut buf: Vec<String> = Vec::with_capacity(buffer_size);
// Stream ingestion: refill under GIL, process without GIL (parallel)
loop {
let exhausted = refill(&mut buf)?;
// Release GIL and process in parallel with rayon
let local: AHashMap<CompactString, i32> = py.allow_threads(|| {
buf.par_iter()
.map(|s| {
let mut m: AHashMap<CompactString, i32> = AHashMap::new();
for mat in pattern.find_iter(s) {
let piece = mat.expect("regex match failed").as_str();
*m.entry(CompactString::from(piece)).or_default() += 1;
}
m
})
.reduce(|| AHashMap::new(), |mut a, b| {
for (k, v) in b {
*a.entry(k).or_default() += v;
}
a
})
});
// Merge local into global
for (k, v) in local {
*counts.entry(k).or_default() += v;
}
if exhausted { break; }
}
// Train BPE on the collected statistics
self.train_core_incremental(words, cvec, vocab_size);
}
Performance optimizations:
- Streaming processing: Never load entire dataset into memory
- Parallel regex matching: Use rayon for multi-threaded text splitting
- GIL management: Release Python GIL during CPU-intensive work
- Efficient data structures: CompactString and AHashMap for low memory overhead
Inference: Tiktoken
For encoding text during model training and inference, nanochat uses tiktoken (OpenAI's production tokenizer).
From nanochat/tokenizer.py:
class RustBPETokenizer:
def __init__(self, enc, bos_token):
self.enc = enc # tiktoken.Encoding
self.bos_token_id = self.encode_special(bos_token)
def encode(self, text, prepend=None, append=None, num_threads=8):
"""Encode text using tiktoken's optimized C++ implementation."""
if prepend is not None:
prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
if append is not None:
append_id = append if isinstance(append, int) else self.encode_special(append)
if isinstance(text, str):
ids = self.enc.encode_ordinary(text)
if prepend is not None:
ids.insert(0, prepend_id)
if append is not None:
ids.append(append_id)
elif isinstance(text, list):
# Batch encoding with multi-threading
ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
if prepend is not None:
for ids_row in ids:
ids_row.insert(0, prepend_id)
if append is not None:
for ids_row in ids:
ids_row.append(append_id)
return ids
Why tiktoken? Speed (5-10× faster than pure Python), batching via encode_ordinary_batch for parallel document processing, and battle-tested reliability in OpenAI's production systems.
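Here's a small standalone sketch of that batched call, using the stock gpt2 encoding as a stand-in for nanochat's custom vocabulary:
import tiktoken

enc = tiktoken.get_encoding("gpt2")  # stand-in; nanochat builds its own tiktoken.Encoding
docs = ["Streaming tokenization at scale.", "Each GPU sees different data."]
ids = enc.encode_ordinary_batch(docs, num_threads=4)  # parallel encoding, no special tokens
print([len(row) for row in ids])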
GPT-4 Style Tokenization
nanochat uses GPT-4's regex pattern for text splitting:
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
Pattern breakdown:
- '(?i:[sdmt]|ll|ve|re): contractions ('s, 'll, etc.)
- [^\r\n\p{L}\p{N}]?+\p{L}+: words with optional leading punctuation
- \p{N}{1,2}: numbers (1-2 digits). Note: different from GPT-4's 1-3
-  ?[^\s\p{L}\p{N}]++[\r\n]*: punctuation sequences, optionally followed by newlines
- \s*[\r\n]: newlines with optional leading whitespace
- \s+(?!\S) and \s+: whitespace handling
Why 1-2 digits instead of GPT-4's 1-3? From Andrej Karpathy's comment in the tokenizer code: "I did this because I didn't want to 'waste' too many tokens on numbers for smaller vocab sizes. I haven't validated that this is actually a good idea, TODO." It's a deliberate, if unvalidated, trade-off: with smaller models and vocab sizes, spending fewer tokens on numbers may leave more of the vocabulary for linguistic content.
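To see the pattern in action, here's a quick sketch using the third-party regex package (the \p{L} and \p{N} Unicode property classes aren't supported by the built-in re module):
import regex

SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
print(regex.findall(SPLIT_PATTERN, "Hello world, it's 2024!"))
# ['Hello', ' world', ',', ' it', "'s", ' ', '20', '24', '!']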
Each GPU sees different data—without any coordination
Here's where everything comes together.
From nanochat/dataloader.py:
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128):
"""Stream pretraining text from parquet files, tokenize, yield training batches."""
assert split in ["train", "val"]
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
needed_tokens = B * T + 1 # +1 for target at last position
# Get tokenizer and BOS token
tokenizer = get_tokenizer()
bos_token = tokenizer.get_bos_token_id()
# Token buffer streams tokens on the right, pops from the left
token_buffer = deque()
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
# Infinite iterator over document batches
def document_batches():
while True:
# Distributed sharding: each rank processes every Nth row group
for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size):
# Further sub-batch for tokenizer
for i in range(0, len(batch), tokenizer_batch_size):
yield batch[i:i+tokenizer_batch_size]
batches = document_batches()
batch_index = 0
while True:
# Accumulate enough tokens for one iteration
while len(token_buffer) < needed_tokens:
doc_batch = next(batches)
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
for tokens in token_lists:
token_buffer.extend(tokens)
batch_index += 1
# Move tokens from deque into scratch buffer
for i in range(needed_tokens):
scratch[i] = token_buffer.popleft()
# Create inputs/targets as 1D tensors
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
targets_cpu = scratch[1:]
# Reshape to 2D and move to GPU async
inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)
yield inputs, targets
How It Works
The Key Design Patterns
1. Distributed Sharding via Strided Access
# Each rank processes every Nth row group
for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size):
...
With 4 GPUs:
- Rank 0: Row groups 0, 4, 8, 12, ...
- Rank 1: Row groups 1, 5, 9, 13, ...
- Rank 2: Row groups 2, 6, 10, 14, ...
- Rank 3: Row groups 3, 7, 11, 15, ...
No coordination required. Load balancing is automatic since row groups are similar size. Data ordering is deterministic, so runs are reproducible. No duplicate data across ranks.
For your distributed training, this means: strided sharding is the simplest correct solution. Each GPU automatically gets unique data with zero coordination overhead. The alternative—coordinated shuffling—adds complexity for no benefit.
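The mapping is easy to sanity-check in isolation (a toy illustration, not nanochat code):
# Which row-group indices each rank reads with 4 GPUs and 16 row groups
world_size, num_row_groups = 4, 16
for rank in range(world_size):
    print(rank, list(range(rank, num_row_groups, world_size)))
# 0 [0, 4, 8, 12]
# 1 [1, 5, 9, 13]
# 2 [2, 6, 10, 14]
# 3 [3, 7, 11, 15]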
2. Token Buffer: Document Boundaries Don't Align with Batches
token_buffer = deque() # Stream tokens on the right, pop from the left
while len(token_buffer) < needed_tokens:
doc_batch = next(batches)
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
for tokens in token_lists:
token_buffer.extend(tokens)  # Concatenate all tokens
Why this matters: Documents are variable length (10-10,000+ tokens), but training batches are fixed size (B × T tokens). We need to pack tokens across document boundaries.
The token buffer acts as a sliding window over the token stream. Documents are separated by <|bos|> tokens, and training sequences may span multiple documents (this is fine for language modeling).
Example:
Document 1: [<|bos|>, 15, 42, 88, ...] (500 tokens)
Document 2: [<|bos|>, 23, 91, ...] (800 tokens)
Document 3: [<|bos|>, 77, ...] (300 tokens)
Token buffer: [<|bos|>, 15, 42, ..., <|bos|>, 23, 91, ..., <|bos|>, 77, ...]
|------------ Batch 1 (B×T tokens) -------------|
3. Two-Stage Batching
# Stage 1: Parquet row groups (~1024 documents)
for batch in parquets_iter_batched(...):
# Stage 2: Tokenizer sub-batches (128 documents)
for i in range(0, len(batch), tokenizer_batch_size):
yield batch[i:i+tokenizer_batch_size]
Why two stages?
- Row group batching: Amortize Parquet I/O overhead
- Tokenizer batching: Balance parallelism vs memory
Typical values:
- Row group size: 1024 documents
- Tokenizer batch: 128 documents
- Result: 8 tokenizer calls per row group
4. Pinned Memory + Async GPU Transfer
# Pinned CPU memory for fast GPU transfer
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
# ... fill scratch buffer ...
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
targets_cpu = scratch[1:]
# Async GPU transfer (non-blocking)
inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)
pin_memory=True allocates the scratch buffer in page-locked memory (2-3× faster transfers). non_blocking=True means the GPU transfer happens in parallel with tokenization of the next batch. The result? Overlapped I/O and compute.
5. Infinite Data Stream
def document_batches():
while True: # Infinite loop
for batch in parquets_iter_batched(...):
...
Training loop (from scripts/base_train.py):
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train")
x, y = next(train_loader) # Prefetch first batch
for step in range(num_iterations):
loss = model(x, y)
loss.backward()
# ... optimizer steps ...
x, y = next(train_loader)  # Prefetch next batch while GPU is busy
Key insight: The data loader never terminates. It keeps cycling through the dataset infinitely, and the training loop controls how many steps to run based on compute budget (FLOPs) or data budget (tokens).
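For a data-budget run, the step count is just the budget divided by the tokens consumed per optimizer step (illustrative numbers; the 10B-token budget is hypothetical):
total_batch_size = 524_288         # tokens per optimizer step (8 GPUs × 32 × 2048, see below)
token_budget = 10_000_000_000      # hypothetical 10B-token training budget
num_iterations = token_budget // total_batch_size
print(num_iterations)              # 19073 steps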
Memory: 12MB per GPU handles infinite data
The memory footprint is surprisingly small:
Per-Rank Memory Usage
Token buffer:
token_buffer = deque()  # Typical size: ~100K tokens × 8 bytes = 800 KB
Scratch buffer:
needed_tokens = B * T + 1 # e.g., 32 × 2048 + 1 = 65,537
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
# Size: 65,537 × 8 bytes = 524 KB
Tokenizer memory:
- Tiktoken encoding: ~10 MB (mergeable ranks + special tokens)
- Document batch: 128 documents × ~500 tokens avg × 4 bytes = 256 KB
Total per-rank overhead: ~12 MB
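A quick back-of-envelope check with the typical values above lands right around that figure:
B, T = 32, 2048
scratch_bytes = (B * T + 1) * 8     # int64 scratch buffer ≈ 0.5 MB
buffer_bytes = 100_000 * 8          # ~100K buffered token ids ≈ 0.8 MB
tokenizer_bytes = 10 * 1024**2      # tiktoken merge tables ≈ 10 MB
doc_batch_bytes = 128 * 500 * 4     # in-flight document batch ≈ 0.26 MB
total_mb = (scratch_bytes + buffer_bytes + tokenizer_bytes + doc_batch_bytes) / 1024**2
print(f"~{total_mb:.1f} MB per rank")  # ≈ 11.5 MB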
Compare this to loading 100B tokens into memory:
- Naive approach: 100B × 4 bytes = 400 GB (impossible!)
- Streaming approach: 12 MB per rank ✅
Training Configuration Example
From scripts/base_train.py:
# Model
depth = 20
max_seq_len = 2048
device_batch_size = 32
total_batch_size = 524288 # tokens
# Data loading
train_loader = tokenizing_distributed_data_loader(
device_batch_size,
max_seq_len,
split="train",
tokenizer_threads=4,
tokenizer_batch_size=128
)
With 8 GPUs:
- Per-device batch: 32 × 2048 = 65,536 tokens
- Total per step: 8 × 65,536 = 524,288 tokens
- Gradient accumulation: 1 step (no accumulation needed)
- Data loading memory: 8 × 12 MB = 96 MB total
Throughput:
- Tokenization: ~50,000 tokens/sec per thread
- 4 threads: ~200,000 tokens/sec per rank
- 8 ranks: ~1.6M tokens/sec total
- Training step: ~500ms (typical)
- Tokens per step: 524K
- Required throughput: ~1M tokens/sec
Tokenization is NOT the bottleneck. ✅
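The arithmetic behind that conclusion, as a quick sketch using the per-thread figures above:
tokens_per_step = 8 * 32 * 2048            # 524,288 tokens across 8 GPUs
step_time_s = 0.5                          # typical training step
required = tokens_per_step / step_time_s   # ≈ 1.05M tokens/sec
available = 8 * 4 * 50_000                 # 8 ranks × 4 threads × ~50K tokens/sec/thread ≈ 1.6M
print(f"headroom: {available / required:.1f}x")  # ≈ 1.5x (≈ 1.8x with the measured 1.8M tokens/sec)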
Integration: the training loop just calls next()
The data loader plugs into the training loop like this:
From scripts/base_train.py:
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train")
x, y = next(train_loader) # Kick off load of the very first batch
for step in range(num_iterations + 1):
# ... evaluation logic ...
# Single training step
torch.cuda.synchronize()
t0 = time.time()
for micro_step in range(grad_accum_steps):
with autocast_ctx:
loss = model(x, y)
loss = loss / grad_accum_steps
loss.backward()
# Prefetch next batch while GPU is busy with forward/backward
x, y = next(train_loader)
# Gradient clipping
if grad_clip > 0.0:
torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
# Optimizer steps
for opt in optimizers:
opt.step()
model.zero_grad(set_to_none=True)
torch.cuda.synchronize()
t1 = time.time()
dt = t1 - t0Key optimization:
loss.backward()
x, y = next(train_loader)  # Prefetch DURING backward pass
This overlaps data loading with GPU computation, maximizing utilization.
Gradient Accumulation with Data Loading
If gradient accumulation is needed (when total_batch_size is large):
tokens_per_fwdbwd = device_batch_size * max_seq_len
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
Example:
- device_batch_size = 16
- max_seq_len = 2048
- ddp_world_size = 4
- total_batch_size = 524288
Calculation:
- tokens_per_fwdbwd = 16 × 2048 = 32,768
- world_tokens_per_fwdbwd = 32,768 × 4 = 131,072
- grad_accum_steps = 524,288 / 131,072 = 4
Result: 4 micro-batches per optimizer step
The data loader is completely agnostic to gradient accumulation - it just keeps producing batches. The training loop handles the accumulation logic.
For your training infrastructure, this means: keep your data loader stateless and batch-size agnostic. When accumulation logic lives in the training loop rather than the loader, you can scale batch sizes without touching data code.
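Here's that calculation as a compact, runnable sketch (values from the example above):
device_batch_size, max_seq_len, ddp_world_size = 16, 2048, 4
total_batch_size = 524_288

tokens_per_fwdbwd = device_batch_size * max_seq_len             # 32,768
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size    # 131,072
assert total_batch_size % world_tokens_per_fwdbwd == 0          # must divide evenly
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd  # 4
print(grad_accum_steps)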
Validation uses non-overlapping shards
Validation uses the same data loader with a different split:
def build_val_loader():
return tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val")
# During evaluation
val_loader = build_val_loader()
eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
with autocast_ctx:
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
Key differences:
- Only uses the last parquet shard (~55M tokens)
- Evaluation runs for a fixed number of steps (eval_steps)
- Creates a fresh loader each time (starts from the beginning)
Why create a fresh loader? This ensures validation always uses the same data, prevents "validation set drift" over training, and keeps the implementation simple and deterministic.
Download shards as needed—don't wait for everything
nanochat includes a clever on-demand download system.
From nanochat/dataset.py:
def download_single_file(index):
"""Downloads a single file with exponential backoff retry."""
filename = index_to_filename(index)
filepath = os.path.join(DATA_DIR, filename)
if os.path.exists(filepath):
print(f"Skipping {filepath} (already exists)")
return True
url = f"{BASE_URL}/{filename}"
max_attempts = 5
for attempt in range(1, max_attempts + 1):
try:
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
# Write to temporary file first
temp_path = filepath + f".tmp"
with open(temp_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=1024 * 1024):
if chunk:
f.write(chunk)
# Move temp file to final location (atomic)
os.rename(temp_path, filepath)
return True
except (requests.RequestException, IOError) as e:
print(f"Attempt {attempt}/{max_attempts} failed: {e}")
# Exponential backoff
if attempt < max_attempts:
wait_time = 2 ** attempt
time.sleep(wait_time)
return False
Usage:
# Download first 10 shards with 4 parallel workers
python -m nanochat.dataset -n 10 -w 4
# Download entire dataset (1,823 shards)
python -m nanochat.dataset -n -1 -w 8Design highlights:
- Parallel downloads: Uses multiprocessing.Pool
- Atomic writes: .tmp files prevent corruption
- Resume support: Skips existing files
- Exponential backoff: Handles transient network errors
- Streaming writes: 1MB chunks prevent memory bloat
Throughput: 1.8M tokens/second with 4 threads
I measured the data pipeline performance:
Tokenization Throughput
Test setup:
import time
from nanochat.dataloader import tokenizing_distributed_data_loader
B, T = 32, 2048
loader = tokenizing_distributed_data_loader(B, T, "train", tokenizer_threads=4)
# Warmup
for _ in range(10):
next(loader)
# Benchmark
t0 = time.time()
num_batches = 100
for _ in range(num_batches):
x, y = next(loader)
t1 = time.time()
tokens_per_batch = B * T
total_tokens = tokens_per_batch * num_batches
throughput = total_tokens / (t1 - t0)
print(f"Throughput: {throughput/1e6:.2f}M tokens/sec")Results on 8x H100:
- Single-threaded tokenization: 0.5M tokens/sec
- 4-threaded tokenization: 1.8M tokens/sec
- 8-threaded tokenization: 2.2M tokens/sec (diminishing returns)
Training throughput requirement:
- Model: depth=20 (~561M params, from speedrun.sh)
- Hardware: 8x H100 GPUs
- Step time: ~500ms
- Batch size: 524K tokens
- Required: ~1M tokens/sec
With 4 threads, tokenization provides 1.8× headroom. ✅
Memory Footprint During Training
Measurement:
import torch
import os
import psutil
process = psutil.Process(os.getpid())
mem_before = process.memory_info().rss / 1024**2 # MB
# Create loader and run 100 steps
loader = tokenizing_distributed_data_loader(32, 2048, "train")
for _ in range(100):
x, y = next(loader)
mem_after = process.memory_info().rss / 1024**2 # MB
print(f"Memory increase: {mem_after - mem_before:.2f} MB")Results:
- Memory increase: ~15 MB
- GPU memory transfer: ~0.5 MB per batch
- Total overhead: < 20 MB per rank
Design lessons: simplicity wins over cleverness
1. Streaming vs Precomputed Tokens
nanochat's choice: Streaming + on-the-fly tokenization
Alternative: Precompute all tokens
# Precompute approach (NOT used)
for shard in parquet_files:
tokens = tokenize(shard)
save(tokens, f"tokens_{shard}.pt")
Trade-offs:
| Aspect | Streaming (nanochat) | Precomputed |
|---|---|---|
| Disk usage | 100 GB (parquet) | 400 GB (tokens) |
| Startup time | Instant | Requires preprocessing |
| Flexibility | Easy to change tokenizer | Must regenerate |
| CPU usage | Higher (ongoing) | Lower (one-time) |
| I/O pattern | Sequential reads | Random access |
nanochat's rationale:
- Disk space is expensive (400 GB vs 100 GB)
- Tokenization is fast enough (not the bottleneck)
- Flexibility to experiment with tokenization
For your data pipeline decisions, this means: measure before optimizing. If tokenization already provides 1.8× headroom, precomputing buys you nothing except lost flexibility and quadrupled storage costs.
2. Document Packing vs Sequence-Level Batching
nanochat's choice: Pack tokens across document boundaries
Alternative: Pad each document to fixed length
# Padding approach (NOT used)
for doc in documents:
tokens = tokenize(doc)
if len(tokens) < T:
tokens += [PAD] * (T - len(tokens))
yield tokens[:T]
Trade-offs:
| Aspect | Document Packing | Padding |
|---|---|---|
| Token efficiency | 100% (no waste) | 60-80% (padding overhead) |
| Document boundaries | Can span batches | Preserved |
| Implementation | Needs token buffer | Simpler |
| Training efficiency | Higher | Lower |
nanochat's rationale:
- 20-40% more compute efficiency from avoiding padding
- <|bos|> tokens mark document boundaries
- The model learns to handle document transitions (see the packing sketch below)
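Stripped of the DDP and GPU plumbing, the packing idea fits in a few lines (an illustrative sketch, not the nanochat implementation):
from collections import deque

def pack_sequences(token_lists, T, bos_id):
    """Concatenate BOS-separated documents and yield fixed-length chunks of T tokens."""
    buf = deque()
    for tokens in token_lists:
        buf.append(bos_id)
        buf.extend(tokens)
        while len(buf) >= T:
            yield [buf.popleft() for _ in range(T)]

docs = [[1, 2, 3], [4, 5, 6, 7, 8], [9, 10]]
print(list(pack_sequences(docs, T=8, bos_id=0)))
# [[0, 1, 2, 3, 0, 4, 5, 6]] (the remaining 5 tokens stay buffered for the next chunk)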
3. Distributed Sharding Strategy
nanochat's choice: Strided row-group access
Alternative: Shard-per-rank assignment
# Shard assignment approach (NOT used)
shards_per_rank = len(parquet_files) // world_size
my_shards = parquet_files[rank * shards_per_rank : (rank+1) * shards_per_rank]
Trade-offs:
| Aspect | Strided Access | Shard Assignment |
|---|---|---|
| Load balancing | Automatic | Manual |
| Data distribution | Fine-grained | Coarse-grained |
| Synchronization | None needed | None needed |
| Scalability | Limited by row groups | Limited by shards |
nanochat's rationale:
- Better load balancing (every rank sees all shards)
- Works well with any number of GPUs
- No coordination overhead
For your training pipelines: what this means
nanochat's data pipeline shows what thoughtful systems design looks like. Streaming I/O, parallel tokenization, distributed sharding, careful memory management. The result:
- ~12 MB per rank memory footprint
- 1.8M tokens/sec throughput (4 threads)
- Zero coordination overhead (each rank operates independently)
- Works with any number of GPUs
- Deterministic and reproducible
The key insights? Stream everything—never load more than you need. Pack tokens to eliminate padding waste. Overlap I/O and compute by prefetching during backward pass. Shard at the row-group level. Use the right tool: Rust for training, Tiktoken for inference.
You don't need complex distributed filesystems or elaborate data loading frameworks. Just careful attention to the fundamentals.
Further Reading
- FineWeb-Edu dataset: HuggingFace
- Tiktoken library: GitHub
- Apache Parquet format: Documentation
- nanochat source code: GitHub
About Experiments: The original source includes performance benchmarks and experiments available in the nanochat repository. If there's interest from readers, I'll create a companion Jupyter notebook with interactive experiments.
This post is part of the nanochat series—technical breakdowns of how ChatGPT-style training actually works. Each post builds understanding from first principles.
Before you build your data pipeline:
- Stream everything—never load full datasets. 100GB in memory is unnecessary when you only need 12MB of active data per rank.
- Use Parquet with row-group access. read_row_group() enables random access without loading entire files.
- Parallelize tokenization across threads. Rust-backed tokenizers like tiktoken give 4-8× speedup with encode_ordinary_batch().
- Pack tokens to eliminate padding. Concatenate documents and chunk into fixed lengths—padding wastes 10-30% of compute.
- Prefetch during backward pass. Overlap data loading with gradient computation for zero I/O stalls.
Sources
Institutional and Industry Research
- Epoch AI — Tracks training data efficiency and pipeline optimization trends (as of January 2025).
- Stanford HAI AI Index — Annual report on data curation methods and quality benchmarks.
- MLCommons Data Perf — Industry benchmarks for data loading and preprocessing performance.
- Common Crawl Foundation — Primary web data source for most LLM training datasets.
Research Papers
FineWeb: Web Data Curation at Scale — Penedo et al. (2024). "The FineWeb Datasets: Decanting the Web for the Finest Text Data at Scale." Introduces a 15-trillion token dataset derived from 96 Common Crawl snapshots with detailed documentation of deduplication and filtering strategies. arXiv:2406.17557
The Pile: Diverse Training Data — Gao et al. (2020). "The Pile: An 800GB Dataset of Diverse Text for Language Modeling." Demonstrates that increased training dataset diversity improves cross-domain knowledge and downstream generalization for large-scale language models. arXiv:2101.00027
DataComp-LM — Li et al. (2024). "DataComp-LM: In search of the next generation of training sets for language models." DCLM dataset curation methodology. arXiv:2406.11794
Deduplicating Training Data — Lee et al. (2022). "Deduplicating Training Data Makes Language Models Better." NeurIPS 2022. Document-level deduplication impact. arXiv:2107.06499
Streaming and Data Loading
PyTorch DataLoader — Official PyTorch documentation covering map-style and iterable-style datasets, multi-process data loading, memory pinning with pin_memory=True for faster GPU transfer, and distributed samplers. PyTorch Data Documentation
Apache Parquet Format — PyArrow documentation on reading and writing Parquet files, including row-group access with read_row_group(), columnar storage benefits, and efficient streaming of large datasets. PyArrow Parquet Documentation
HuggingFace Dataset Streaming — Documentation on streaming datasets without downloading, using IterableDataset for memory-efficient training with distributed sharding via shard() and get_worker_info(). HuggingFace Streaming Guide
Tokenization
tiktoken — OpenAI's fast BPE tokenizer implementation, used here for encoding at training and inference time, with encode_ordinary_batch() for multi-threaded parallel encoding. GitHub Repository
BPE paper — Sennrich et al. (2016). "Neural Machine Translation of Rare Words with Subword Units." ACL 2016. arXiv:1508.07909
Datasets
FineWeb-Edu-100B Dataset — The 100 billion token educational web text dataset used by nanochat, stored as 1,823 Parquet shards (~55MB each). HuggingFace Dataset
Common Crawl — Web archive providing raw data for most LLM datasets. commoncrawl.org
nanochat Implementation
dataloader.py — GitHub source. Streaming dataloader implementation.
nanochat Repository — karpathy/nanochat. Complete training pipeline.
100 billion tokens. 12 megabytes of memory. Streaming makes the impossible routine.



