Building an LLM from Scratch in PyTorch: The Full Lifecycle Cheatsheet
Most LLM tutorials give you one of two things: a high-level diagram with boxes and arrows, or a 10,000-line codebase with no explanation of why each piece exists.
This post is neither. It's a step-by-step lifecycle — 8 phases, each with working PyTorch code, the reasoning behind every decision, and an explicit Do / Don't list that captures the mistakes that cost most beginners weeks of wasted compute.
By the end you'll have built, trained, modernised, scaled, and aligned a language model, following the same broad lifecycle behind the major LLMs you've used.
Phase 1: Core Transformer → the engine
Phase 2: Train a Tiny LLM → prove the pipeline works
Phase 3: Modernise → match 2026 architecture
Phase 4: Scale Efficiently → push past toy datasets
Phase 5: Mixture of Experts → conditional computation
Phase 6: SFT → turn autocomplete into an assistant
Phase 7: Reward Modelling → teach the model what "good" looks like
Phase 8: RLHF → optimise for human preference
Phase 1: Build the Core Transformer
What to Build
A minimal GPT-style decoder-only transformer. This is the skeleton every subsequent phase builds on.
Input tokens
↓
Token Embedding + Positional Embedding
↓
N × Transformer Block:
├── LayerNorm
├── Multi-Head Self-Attention (with causal mask)
├── Residual connection
├── LayerNorm
├── Feed-Forward Network
└── Residual connection
↓
Final LayerNorm
↓
Linear projection → vocabulary logits
Why It Matters
The publicly documented open-weight LLMs (Llama, Mistral and their relatives) are scaled, refined versions of this structure, and closed models such as GPT-4, Claude, and Gemini are widely understood to be transformer-based as well. Understanding each component at the code level lets you read an LLM paper and translate it into an implementation.
The Code
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q = nn.Linear(d_model, d_model, bias=False)
        self.k = nn.Linear(d_model, d_model, bias=False)
        self.v = nn.Linear(d_model, d_model, bias=False)
        self.out = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        q = self.q(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        scale = math.sqrt(self.head_dim)
        scores = (q @ k.transpose(-2, -1)) / scale  # (B, H, T, T)
        # Causal mask: each position can only attend to itself and earlier positions
        causal_mask = torch.triu(
            torch.ones(T, T, device=x.device), diagonal=1
        ).bool()
        scores = scores.masked_fill(causal_mask, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        out = (attn @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.out(out)

class FeedForward(nn.Module):
    def __init__(self, d_model: int, expansion: int = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_model * expansion),
            nn.GELU(),
            nn.Linear(d_model * expansion, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = FeedForward(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.norm1(x))  # pre-norm + residual
        x = x + self.ff(self.norm2(x))
        return x

class TinyGPT(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, n_heads: int,
                 n_layers: int, max_seq_len: int):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.blocks = nn.ModuleList([TransformerBlock(d_model, n_heads)
                                     for _ in range(n_layers)])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        B, T = idx.shape
        positions = torch.arange(T, device=idx.device)
        x = self.token_emb(idx) + self.pos_emb(positions)
        for block in self.blocks:
            x = block(x)
        return self.head(self.norm(x))  # logits: (B, T, vocab_size)
Do / Don't
| Do | Use a decoder-only architecture. Encoder-decoder (like T5) adds complexity you don't need for a generative LLM. |
| Do | Use pre-norm (norm → attention → residual), not post-norm. Pre-norm is generally easier to train stably, especially as depth grows. |
| Don't | Let the model peek at future tokens. The causal mask (torch.triu) is not optional: without it the model can read the answer straight from the input, so the task becomes trivial and the training loss is meaningless. |
| Don't | Skip the residual connections. Without them, gradients shrink as they pass back through the stack, and deeper models train poorly or not at all. |
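To make the causal-mask row concrete, here is a dependency-free sketch (plain Python, no PyTorch) of what masking does to one head's attention weights. The function name `causal_attention_weights` and the example scores are illustrative, not part of the model code above.

```python
import math

def causal_attention_weights(scores: list[list[float]]) -> list[list[float]]:
    """Row-wise softmax over a T x T score matrix with future positions masked."""
    T = len(scores)
    weights = []
    for i in range(T):
        # Position i may only see positions 0..i; later ones are set to -inf.
        masked = [scores[i][j] if j <= i else float("-inf") for j in range(T)]
        m = max(masked)
        exps = [math.exp(s - m) for s in masked]  # exp(-inf) evaluates to 0.0
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

# Hypothetical raw scores for a 3-token sequence.
scores = [[0.5, 2.0, -1.0],
          [1.0, 1.0, 3.0],
          [0.0, 0.0, 0.0]]
weights = causal_attention_weights(scores)
for row in weights:
    print([round(w, 3) for w in row])
```

The first row puts all of its weight on position 0 even though the raw score for position 1 was higher: that is the mask at work. The upper triangle stays at exactly zero, and every row still sums to 1.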
Phase 2: Train a Tiny LLM
What to Build
A complete training pipeline: character-level tokeniser, dataset batching, cross-entropy loss for next-token prediction, and a generation loop.
Why It Matters
A model that compiles is not a model that trains. This phase validates the entire forward/backward pass end-to-end on a dataset you can inspect by eye — before spending any real compute.
The Code
import torch
import torch.nn.functional as F

# --- 1. Tokeniser (character-level, no libraries needed) ---
with open("tiny_shakespeare.txt", encoding="utf-8") as f:
    text = f.read()
chars = sorted(set(text))
vocab_size = len(chars)
stoi = {c: i for i, c in enumerate(chars)}
itos = {i: c for i, c in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda ids: "".join(itos[i] for i in ids)
data = torch.tensor(encode(text), dtype=torch.long)

# --- 2. Batching ---
def get_batch(data: torch.Tensor, batch_size: int, block_size: int,
              device: str) -> tuple[torch.Tensor, torch.Tensor]:
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    return x.to(device), y.to(device)

# --- 3. Training loop ---
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyGPT(
    vocab_size=vocab_size,
    d_model=128,
    n_heads=4,
    n_layers=4,
    max_seq_len=256,
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for step in range(5000):
    x, y = get_batch(data, batch_size=32, block_size=256, device=device)
    logits = model(x)  # (B, T, vocab_size)
    loss = F.cross_entropy(
        logits.view(-1, vocab_size),  # (B*T, vocab_size)
        y.view(-1),                   # (B*T,)
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % 500 == 0:
        print(f"step {step:>5} | loss {loss.item():.4f}")

# --- 4. Generation ---
@torch.no_grad()
def generate(model, prompt: str, max_new_tokens: int = 200) -> str:
    model.eval()
    idx = torch.tensor([encode(prompt)], dtype=torch.long, device=device)
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -256:]  # crop to max_seq_len
        logits = model(idx_cond)
        next_token_logits = logits[:, -1, :]  # last position only
        probs = F.softmax(next_token_logits / 0.8, dim=-1)  # temperature=0.8
        next_token = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_token], dim=1)
    return decode(idx[0].tolist())

print(generate(model, "HAMLET:"))
Expected Training Signal
step 0 | loss 4.17 ← random (log(65) ≈ 4.17 for 65-char vocab — correct)
step 500 | loss 2.51
step 1000 | loss 2.19
step 2500 | loss 1.94
step 5000 | loss 1.72 ← coherent-ish Shakespeare. Pipeline is working.
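If the loss doesn't decrease from step 0, check the learning rate, the one-token shift between x and y, and the residual connections. If it instead collapses towards zero within a few hundred steps, suspect the causal mask: the model is reading its targets from the input. Fix either problem before continuing.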
Do / Don't
| Do | Test with a small, inspectable dataset (Shakespeare, a book). You need to read the outputs and see them improve. |
| Do | Verify your starting loss matches theory: log(vocab_size). A much higher value points to bad initialisation or mismatched targets; a loss that later drops far below what the data can plausibly support points to a leak (usually a broken causal mask). |
| Don't | Expect quality generations at this stage. The goal is to prove the loss curve goes down. Output quality comes later. |
| Don't | Move to Phase 3 until the pipeline runs cleanly on CPU. GPU bugs are harder to debug. |
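The log(vocab_size) check can be reproduced by hand. The sketch below, in plain Python, computes the cross-entropy of a model that assigns the same logit to every token, which is roughly what a freshly initialised network does. The helper names are illustrative; 65 matches the character vocabulary in the expected training signal, and 8192 is simply another vocabulary size to compare against.

```python
import math

def expected_initial_loss(vocab_size: int) -> float:
    """Cross-entropy of a uniform prediction over vocab_size tokens."""
    return math.log(vocab_size)

def cross_entropy(logits: list[float], target: int) -> float:
    """Negative log-probability of `target` under softmax(logits)."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

uniform_logits = [0.0] * 65                       # every character equally likely
char_level = cross_entropy(uniform_logits, target=3)
bpe_level = expected_initial_loss(8192)           # a larger, subword-sized vocabulary

print(f"65 tokens:   {char_level:.4f}")
print(f"8192 tokens: {bpe_level:.4f}")
```

The target index doesn't matter for a uniform prediction, which is why the starting loss depends only on the vocabulary size.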
Phase 3: Modernise the Architecture
What to Build
Replace the "classic" components with three upgrades popularised by Llama and adopted by most open-weight LLMs since: RMSNorm, SwiGLU, and RoPE. Add a KV cache for inference speed.
Why It Matters
These aren't cosmetic changes. Each one addresses a real problem:
- RMSNorm drops LayerNorm's mean-centring and bias, so it is cheaper to compute, and in practice it trains at least as stably.
- SwiGLU tends to outperform a GELU feed-forward layer at a matched parameter count.
- RoPE encodes relative position directly in the attention scores, and adapts to longer contexts more readily than learned absolute embeddings (usually with extra scaling tricks).
- KV cache reduces the attention work for each new token from O(n²) (re-processing the whole prefix) to O(n) (one new query against cached keys and values).
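To see where the KV-cache saving comes from, it helps to count how many token positions the model has to push through its layers while generating. This is a back-of-the-envelope sketch in plain Python; `tokens_processed` is a hypothetical helper, and it counts positions rather than FLOPs.

```python
def tokens_processed(prompt_len: int, new_tokens: int, use_cache: bool) -> int:
    """Total token positions run through the network while generating."""
    total = 0
    for step in range(new_tokens):
        if use_cache:
            # First step encodes the prompt once; later steps feed only the newest token.
            total += prompt_len if step == 0 else 1
        else:
            # Without a cache the whole sequence so far is re-encoded at every step.
            total += prompt_len + step
    return total

without_cache = tokens_processed(prompt_len=8, new_tokens=4, use_cache=False)
with_cache = tokens_processed(prompt_len=8, new_tokens=4, use_cache=True)
print(without_cache, with_cache)
```

The gap widens quickly: the uncached total grows quadratically with the number of generated tokens, the cached total grows linearly.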
The Code
import torch
import torch.nn as nn
import torch.nn.functional as F

# --- RMSNorm: simpler, faster than LayerNorm ---
class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).sqrt()
        return x / rms * self.weight

# --- SwiGLU: gated FFN used in Llama, PaLM, Mistral ---
class SwiGLU(nn.Module):
    def __init__(self, d_model: int, expansion: int = 4):
        super().__init__()
        hidden = int(d_model * expansion * 2 / 3)  # standard Llama sizing
        self.gate = nn.Linear(d_model, hidden, bias=False)
        self.up = nn.Linear(d_model, hidden, bias=False)
        self.down = nn.Linear(hidden, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(F.silu(self.gate(x)) * self.up(x))

# --- RoPE: Rotary Position Embeddings ---
def precompute_rope(head_dim: int, max_seq_len: int,
                    base: int = 10000) -> tuple[torch.Tensor, torch.Tensor]:
    theta = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_seq_len).float()
    freqs = torch.outer(t, theta)  # (T, head_dim/2)
    cos = freqs.cos().repeat_interleave(2, dim=-1)
    sin = freqs.sin().repeat_interleave(2, dim=-1)
    return cos, sin  # each: (T, head_dim)

def apply_rope(x: torch.Tensor, cos: torch.Tensor,
               sin: torch.Tensor) -> torch.Tensor:
    # Rotate pairs of dimensions: [x0, x1] → [x0*cos - x1*sin, x0*sin + x1*cos]
    x_rotated = torch.stack([-x[..., 1::2], x[..., ::2]], dim=-1)
    x_rotated = x_rotated.flatten(-2)
    return x * cos + x_rotated * sin

# --- KV Cache: store past keys/values, avoid recomputing ---
class KVCache:
    def __init__(self):
        self.k_cache: torch.Tensor | None = None
        self.v_cache: torch.Tensor | None = None

    def update(self, k: torch.Tensor, v: torch.Tensor
               ) -> tuple[torch.Tensor, torch.Tensor]:
        if self.k_cache is None:
            self.k_cache, self.v_cache = k, v
        else:
            self.k_cache = torch.cat([self.k_cache, k], dim=2)
            self.v_cache = torch.cat([self.v_cache, v], dim=2)
        return self.k_cache, self.v_cache

    def reset(self):
        self.k_cache = self.v_cache = None

# --- Modern Transformer Block combining all three upgrades ---
class ModernTransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, max_seq_len: int):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)
        self.q = nn.Linear(d_model, d_model, bias=False)
        self.k = nn.Linear(d_model, d_model, bias=False)
        self.v = nn.Linear(d_model, d_model, bias=False)
        self.o = nn.Linear(d_model, d_model, bias=False)
        self.ff = SwiGLU(d_model)
        cos, sin = precompute_rope(self.head_dim, max_seq_len)
        self.register_buffer("cos", cos)
        self.register_buffer("sin", sin)

    def forward(self, x: torch.Tensor,
                kv_cache: KVCache | None = None) -> torch.Tensor:
        B, T, C = x.shape
        h = self.n_heads
        d = self.head_dim
        normed = self.norm1(x)  # normalise once, reuse for q, k and v
        q = self.q(normed).view(B, T, h, d).transpose(1, 2)
        k = self.k(normed).view(B, T, h, d).transpose(1, 2)
        v = self.v(normed).view(B, T, h, d).transpose(1, 2)
        # Tokens already in the cache shift the positions of the new tokens
        past = 0
        if kv_cache is not None and kv_cache.k_cache is not None:
            past = kv_cache.k_cache.size(2)
        # Apply RoPE to queries and keys at their absolute positions
        cos, sin = self.cos[past:past + T], self.sin[past:past + T]
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)
        # Use KV cache at inference time
        if kv_cache is not None:
            k, v = kv_cache.update(k, v)
        # The causal mask is needed when several tokens are processed at once
        # (training, prompt prefill). A single new token decoded against a
        # filled cache may attend to every cached position, so no mask is used.
        # Assumes T == 1 whenever past > 0.
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=(past == 0))
        out = attn.transpose(1, 2).contiguous().view(B, T, C)
        x = x + self.o(out)
        x = x + self.ff(self.norm2(x))
        return x
Do / Don't
| Do | Upgrade incrementally: swap LayerNorm → RMSNorm first, verify loss is equal or better, then add SwiGLU, then RoPE. One change at a time. |
| Do | Use PyTorch's F.scaled_dot_product_attention. It dispatches to a fused kernel (such as FlashAttention) when one is available, which avoids materialising the full T × T attention matrix. |
| Don't | Mix old and new position schemes. Once RoPE is applied to queries and keys, drop the learned positional embedding from Phase 1, otherwise position is encoded twice. Re-validate the loss on the full combination. |
| Don't | Enable the KV cache during training. Training processes every position of a sequence in one pass, so there is nothing to reuse, and tensors cached from earlier steps would be stale. It's an inference-only optimisation. |
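The claim that RoPE encodes relative position can be checked with a single pair of dimensions. In this plain-Python sketch, `rotate` applies the same 2-D rotation that `apply_rope` applies to each pair, and `rope_score` is the resulting attention score. The frequency `theta=0.1` and the example vectors are arbitrary choices for illustration.

```python
import math

def rotate(pair: tuple[float, float], pos: int, theta: float) -> tuple[float, float]:
    """Rotate one (x0, x1) pair by the angle pos * theta."""
    x0, x1 = pair
    angle = pos * theta
    c, s = math.cos(angle), math.sin(angle)
    return (x0 * c - x1 * s, x0 * s + x1 * c)

def rope_score(q, k, q_pos: int, k_pos: int, theta: float = 0.1) -> float:
    """Dot product of a rotated query and a rotated key."""
    qr = rotate(q, q_pos, theta)
    kr = rotate(k, k_pos, theta)
    return qr[0] * kr[0] + qr[1] * kr[1]

q, k = (1.0, 2.0), (0.5, -1.5)
near_start = rope_score(q, k, q_pos=5, k_pos=2)      # offset of 3
far_along = rope_score(q, k, q_pos=103, k_pos=100)   # same offset of 3
different = rope_score(q, k, q_pos=5, k_pos=0)       # offset of 5
print(near_start, far_along, different)
```

The first two scores agree up to floating-point error because only the offset between the two positions survives the dot product. That is also why the cached path in the block above has to rotate new tokens by their absolute position rather than restarting from zero.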
Phase 4: Scale Efficiently
What to Build
Three complementary scaling techniques: BPE tokenisation (subword vocabulary), mixed precision training (FP16/BF16), and gradient accumulation (simulate large batches without large GPU memory).
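Why It Matters
Character-level models plateau quickly. Subword tokenisation (used by virtually every production LLM) packs roughly 3–4 characters of English text into each token, so the model sees more meaning per position and per unit of context. Mixed precision roughly halves activation memory and speeds up matrix multiplies, while the weights and optimizer state stay in FP32. Gradient accumulation lets a single GPU simulate a 256-sample batch by summing gradients over several smaller micro-batches.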
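Where that compression comes from is easiest to see in a single BPE merge step. The sketch below is a toy in plain Python, not the Hugging Face `tokenizers` implementation used in this phase: it finds the most frequent adjacent pair and fuses it into one token. The helper names and the example string are illustrative.

```python
from collections import Counter

def most_frequent_pair(tokens: list[str]) -> tuple[str, str]:
    """Return the adjacent token pair that occurs most often."""
    pairs = Counter(zip(tokens, tokens[1:]))
    return pairs.most_common(1)[0][0]

def merge_pair(tokens: list[str], pair: tuple[str, str]) -> list[str]:
    """Replace each left-to-right occurrence of `pair` with one merged token."""
    merged, i = [], 0
    while i < len(tokens):
        if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == pair:
            merged.append(tokens[i] + tokens[i + 1])
            i += 2
        else:
            merged.append(tokens[i])
            i += 1
    return merged

tokens = list("aaabdaaabac")          # start from characters
pair = most_frequent_pair(tokens)
merged = merge_pair(tokens, pair)
print(pair, merged, f"{len(tokens)} -> {len(merged)} tokens")
```

A real tokeniser repeats this step thousands of times over a large corpus, adding one vocabulary entry per merge until it reaches the target vocabulary size.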
The Code
import torch
import torch.nn.functional as F
from tokenizers import Tokenizer, models, trainers, pre_tokenizers

# --- 1. Train a BPE tokeniser ---
def train_bpe_tokeniser(files: list[str], vocab_size: int = 8192) -> Tokenizer:
    tokenizer = Tokenizer(models.BPE())
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
    trainer = trainers.BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=["<|pad|>", "<|bos|>", "<|eos|>"],
    )
    tokenizer.train(files, trainer)
    return tokenizer

# --- 2. Mixed Precision + Gradient Accumulation training loop ---
def train(model, dataloader, optimizer, n_epochs: int,
          accum_steps: int = 8, device: str = "cuda",
          amp_dtype: torch.dtype = torch.bfloat16):
    # Loss scaling is only needed for FP16; with BF16 the scaler is a no-op.
    # torch.amp.GradScaler needs a recent PyTorch (2.3+); older releases
    # expose the same class as torch.cuda.amp.GradScaler.
    scaler = torch.amp.GradScaler(device, enabled=(amp_dtype == torch.float16))
    model.train()
    optimizer.zero_grad()
    for epoch in range(n_epochs):
        for step, (x, y) in enumerate(dataloader):
            x, y = x.to(device), y.to(device)
            # autocast runs the forward pass in BF16/FP16; the weights stay in FP32
            with torch.autocast(device_type=device, dtype=amp_dtype):
                logits = model(x)
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)), y.view(-1)
                )
            loss = loss / accum_steps  # normalise before accumulating
            scaler.scale(loss).backward()
            # Only update weights every accum_steps steps
            if (step + 1) % accum_steps == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                print(f"epoch {epoch} step {step} | loss {loss.item() * accum_steps:.4f}")
        # Checkpoint after every epoch
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
        }, f"checkpoint_epoch_{epoch}.pt")
Do / Don't
| Do | Checkpoint often when pushing GPU memory limits. The cost of saving is seconds; the cost of a crashed run is hours. |
| Do | Use torch.bfloat16 over float16 when your GPU supports it (A100, H100, RTX 30xx+). BF16 has the same dynamic range as FP32 and doesn't need a loss scaler. |
| Don't | Assume bigger batch sizes always improve results. Larger batches can hurt generalisation. Use gradient accumulation to hit your target effective batch size, then tune the learning rate separately. |
| Don't | Forget clip_grad_norm_. Without it, a single bad batch can produce a massive gradient update that corrupts the model. max_norm=1.0 is a common default. |
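Why dividing the loss by accum_steps gives the same update as one large batch can be verified on a one-parameter model. This plain-Python sketch compares the gradient of a mean-squared error over 8 samples with the sum of 4 micro-batch gradients, each scaled by 1/accum_steps. The data, the weight `w`, and the helper `grad_mse` are made up for illustration, and the equivalence assumes equal-sized micro-batches.

```python
def grad_mse(w: float, xs: list[float], ys: list[float]) -> float:
    """d/dw of mean((w * x - y) ** 2) over the given samples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.8, 12.1, 14.0, 15.9]
w = 0.5
micro_batch, accum_steps = 2, 4

# One pass over the full batch of 8.
full_grad = grad_mse(w, xs, ys)

# Four micro-batches of 2, each loss divided by accum_steps before "backward".
accumulated = 0.0
for step in range(accum_steps):
    lo = step * micro_batch
    hi = lo + micro_batch
    accumulated += grad_mse(w, xs[lo:hi], ys[lo:hi]) / accum_steps

effective_batch = micro_batch * accum_steps
print(full_grad, accumulated, effective_batch)
```

The two gradients match up to floating-point rounding, which is the whole trick: memory scales with the micro-batch, the update behaves like the effective batch.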
Phase 5: Mixture of Experts (MoE)
What to Build
Replace the dense feed-forward layer with a sparse MoE layer: multiple expert FFNs with a learned router that activates only the top-K experts per token.
Why It Matters
MoE is how Mixtral 8x7B reaches GPT-3.5-class benchmark quality at roughly the inference compute of a 13B dense model. The model has about 47B parameters in total, but only about 13B are active for any given token. You get a big model's capacity with a small model's per-token compute, although all experts still have to fit in memory.
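The routing decision for a single token fits in a few lines of plain Python: softmax the router logits, keep the top-K experts, and renormalise their weights so they sum to 1. `route_top_k` and the example logits are illustrative; the logits are chosen so the softmax comes out at 0.5, 0.3, 0.1, 0.1.

```python
import math

def softmax(logits: list[float]) -> list[float]:
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def route_top_k(logits: list[float], top_k: int = 2) -> list[tuple[int, float]]:
    """Return (expert_index, weight) for the top_k experts of one token."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    chosen = ranked[:top_k]
    kept = sum(probs[i] for i in chosen)
    # Renormalise so the kept experts' weights sum to 1.
    return [(i, probs[i] / kept) for i in chosen]

router_logits = [math.log(0.5), math.log(0.3), math.log(0.1), math.log(0.1)]
assignment = route_top_k(router_logits, top_k=2)
print(assignment)
```

Only experts 0 and 1 run for this token; the other two cost nothing at inference time, which is where the saving in per-token compute comes from.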
The Code
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoELayer(nn.Module):
    def __init__(self, d_model: int, n_experts: int = 8, top_k: int = 2):
        super().__init__()
        assert top_k <= n_experts
        self.n_experts = n_experts
        self.top_k = top_k
        # One FFN per expert (SwiGLU is the Phase 3 feed-forward block)
        self.experts = nn.ModuleList([SwiGLU(d_model) for _ in range(n_experts)])
        # Router: projects each token to a score over all experts
        self.router = nn.Linear(d_model, n_experts, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        x_flat = x.view(B * T, C)  # (N, d_model)
        # --- Routing ---
        router_logits = self.router(x_flat)  # (N, n_experts)
        router_probs = F.softmax(router_logits, dim=-1)
        # Select top-K experts per token
        top_k_probs, top_k_idx = router_probs.topk(self.top_k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)  # renormalise
        # --- Dispatch each token to its selected experts ---
        output = torch.zeros_like(x_flat)
        for expert_idx in range(self.n_experts):
            # Find which tokens selected this expert and at what weight
            mask = (top_k_idx == expert_idx).any(dim=-1)  # (N,)
            if not mask.any():
                continue
            tokens_for_expert = x_flat[mask]  # (n_selected, C)
            weight_idx = (top_k_idx[mask] == expert_idx).nonzero(as_tuple=True)[1]
            weights = top_k_probs[mask].gather(1, weight_idx.unsqueeze(1))
            expert_out = self.experts[expert_idx](tokens_for_expert)
            output[mask] += expert_out * weights  # weighted contribution
        # --- Load balancing auxiliary loss ---
        # Without this, the router collapses to always using the same 2 experts.
        # Switch-Transformer-style: (fraction of assignments sent to each expert)
        # times (mean router probability for that expert), summed over experts.
        dispatch_frac = (
            F.one_hot(top_k_idx, self.n_experts).float().sum(1).mean(0) / self.top_k
        )  # (n_experts,), sums to 1
        mean_prob = router_probs.mean(0)  # (n_experts,)
        self.aux_loss = self.n_experts * (dispatch_frac * mean_prob).sum()
        return output.view(B, T, C)

# In your training loop, add the auxiliary loss:
# total_loss = main_loss + 0.01 * moe_layer.aux_loss
Do / Don't
| Do | Start with n_experts=4, top_k=2. Validate the routing is balanced (all experts used roughly equally) before scaling. |
| Do | Add the load-balancing auxiliary loss. Without it, the router will specialise 1–2 experts and ignore the rest — you get a dense model with extra compute overhead. |
| Don't | Activate all experts at once. That's a dense model. The entire point of MoE is conditional, sparse computation. |
| Don't | Add MoE before your baseline transformer is training correctly. It adds routing complexity and the aux loss is easy to miscalibrate. |
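To judge whether routing is balanced, it helps to know what the auxiliary loss looks like at the two extremes. The sketch below uses one common formulation (the Switch-Transformer-style product of dispatch fraction and mean router probability), written in plain Python. The function name and the example numbers are illustrative.

```python
def load_balance_loss(dispatch_frac: list[float], mean_prob: list[float]) -> float:
    """n_experts * sum_i(f_i * P_i).

    f_i: fraction of routed assignments that went to expert i (sums to 1).
    P_i: mean router probability assigned to expert i (sums to 1).
    """
    n_experts = len(dispatch_frac)
    return n_experts * sum(f * p for f, p in zip(dispatch_frac, mean_prob))

# Perfectly balanced routing over 4 experts.
balanced = load_balance_loss([0.25] * 4, [0.25] * 4)

# Collapsed routing: two experts receive every token.
collapsed = load_balance_loss([0.5, 0.5, 0.0, 0.0], [0.45, 0.45, 0.05, 0.05])

print(balanced, collapsed)
```

A value near 1.0 means usage is roughly uniform; the further above 1.0 it sits, the more traffic is concentrated on a few experts. Logging this number next to per-expert token counts is a cheap way to run the balance check from the first row of the table.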
Phase 6: Supervised Fine-Tuning (SFT)
What to Build
Fine-tune your pre-trained model on instruction-response pairs so it transitions from "word predictor" to "helpful assistant."
Why It Matters
A raw pre-trained LLM is a document completer: given "What is the capital of France?", it is as likely to continue with more quiz questions as it is to answer. That's not what users want. SFT teaches the model the format of being an assistant: receiving an instruction and generating a helpful, well-formatted response.
The Code
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

IGNORE_INDEX = -100  # F.cross_entropy skips targets with this value by default

# --- Instruction dataset format ---
class InstructionDataset(Dataset):
    def __init__(self, examples: list[dict], tokenizer, max_len: int = 512):
        # Assumes the Phase 4 tokeniser (tokenizers.Tokenizer), whose encode()
        # returns an Encoding with .ids, and that <|user|> and <|assistant|>
        # were added to its special tokens.
        self.data = []
        for ex in examples:
            # Format: special tokens wrap the instruction/response
            prompt = f"<|user|>\n{ex['instruction']}\n<|assistant|>\n"
            text = prompt + f"{ex['response']}<|eos|>"
            # Approximation: assumes the prompt tokenises the same way on its
            # own as it does as a prefix of the full text.
            prompt_len = len(tokenizer.encode(prompt).ids)
            tokens = tokenizer.encode(text).ids[:max_len]
            self.data.append((tokens, prompt_len))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        tokens, prompt_len = self.data[idx]
        x = torch.tensor(tokens[:-1], dtype=torch.long)
        y = torch.tensor(tokens[1:], dtype=torch.long)
        # Only compute loss on the response tokens, not the instruction:
        # y[t] is still a prompt token for t < prompt_len - 1
        y[:prompt_len - 1] = IGNORE_INDEX
        return x, y

# --- Pad variable-length examples; pass as DataLoader(collate_fn=collate) ---
def collate(batch, pad_id: int = 0):
    xs, ys = zip(*batch)
    x = nn.utils.rnn.pad_sequence(list(xs), batch_first=True, padding_value=pad_id)
    y = nn.utils.rnn.pad_sequence(list(ys), batch_first=True,
                                  padding_value=IGNORE_INDEX)
    return x, y

# --- Loss over response tokens only (prompt and padding are ignored) ---
def sft_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    return F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        targets.view(-1),
        ignore_index=IGNORE_INDEX,
    )

# --- LoRA for efficient fine-tuning (add to existing Linear layers) ---
class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, r: int = 16, alpha: float = 32):
        super().__init__()
        self.base = base
        self.lora_A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))
        self.scale = alpha / r
        # Freeze the base weights
        for p in self.base.parameters():
            p.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scale
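How small the LoRA update is can be checked with simple arithmetic. For one linear layer the adapter adds r × in_features plus out_features × r trainable parameters next to the frozen in_features × out_features weight. The 4096 × 4096 layer below is a hypothetical size picked for illustration, and `lora_param_counts` is not part of the code above.

```python
def lora_param_counts(in_features: int, out_features: int, r: int) -> tuple[int, int]:
    """Return (frozen base parameters, trainable LoRA parameters) for one layer."""
    full = in_features * out_features            # frozen weight matrix
    lora = r * in_features + out_features * r    # lora_A plus lora_B
    return full, lora

full, lora = lora_param_counts(4096, 4096, r=16)
fraction = lora / full
print(f"base: {full:,}  lora: {lora:,}  trainable fraction: {fraction:.2%}")
```

The fraction scales linearly with r, so doubling the rank doubles the trainable parameters while the base weights stay frozen.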
Example training data format:
examples = [
{
"instruction": "Explain gradient descent in simple terms.",
"response": "Gradient descent is an optimisation algorithm that adjusts model weights by moving in the direction that most reduces the loss function — like rolling a ball downhill to find the lowest point in a valley."
},
{
"instruction": "Write a Python function to reverse a string.",
"response": "```python\ndef reverse_string(s: str) -> str:\n return s[::-1]\n```"
},
]
Do / Don't
| Do | Align tone and format to your target use case. A customer support assistant needs different SFT data than a coding assistant. The format of your training data becomes the format of the model's outputs. |
| Do | Use LoRA instead of full fine-tuning. It trains ~1% of parameters with minimal quality loss — essential for limited GPU memory. |
| Don't | Skip SFT and go straight to RLHF. RLHF has no signal to work with if the base model doesn't produce roughly coherent responses yet. |
| Don't | Compute loss on the instruction tokens. The model should learn to generate responses, not repeat prompts. Mask the instruction in the loss. |
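One common way to mask the instruction is to overwrite its targets with -100, the default ignore_index of PyTorch's cross-entropy loss. The off-by-one is easy to get wrong, so here is a tiny plain-Python sketch with made-up token ids; `build_targets` is a hypothetical helper.

```python
IGNORE_INDEX = -100  # PyTorch's default ignore_index for cross-entropy

def build_targets(tokens: list[int], prompt_len: int) -> tuple[list[int], list[int]]:
    """Shift by one for next-token prediction and hide prompt targets."""
    inputs = tokens[:-1]
    targets = tokens[1:]
    # targets[t] is tokens[t + 1], so it is still a prompt token while t + 1 < prompt_len.
    masked = [IGNORE_INDEX if t < prompt_len - 1 else tok
              for t, tok in enumerate(targets)]
    return inputs, masked

# Three prompt tokens (11, 12, 13), two response tokens (21, 22), then eos (99).
tokens = [11, 12, 13, 21, 22, 99]
inputs, targets = build_targets(tokens, prompt_len=3)
print(inputs)
print(targets)
```

The last prompt token (13) is still an input, and its target is the first response token (21). That position must stay unmasked, or the model never learns how to start a response.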
Phase 7: Reward Modelling
What to Build
A separate model that takes a (prompt, response) pair and outputs a scalar reward score representing how much a human would prefer that response.
Why It Matters
RLHF needs a reward signal. You can't run a human in the loop for every training step. The reward model is a learned proxy for human preference — trained on human-annotated pairwise rankings, then used to score millions of responses during RL training.
The Code
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

class RewardModel(nn.Module):
    """
    A transformer that maps (prompt + response) → scalar reward.
    Built on top of a pre-trained / SFT-fine-tuned backbone.
    """
    def __init__(self, backbone: TinyGPT):
        super().__init__()
        self.backbone = backbone
        d_model = backbone.token_emb.embedding_dim
        self.reward_head = nn.Linear(d_model, 1, bias=False)

    def hidden_states(self, input_ids: torch.Tensor) -> torch.Tensor:
        # TinyGPT.forward returns vocabulary logits, so re-run its layers here
        # and stop before the language-model head.
        T = input_ids.size(1)
        positions = torch.arange(T, device=input_ids.device)
        x = self.backbone.token_emb(input_ids) + self.backbone.pos_emb(positions)
        for block in self.backbone.blocks:
            x = block(x)
        return self.backbone.norm(x)  # (B, T, d_model)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden = self.hidden_states(input_ids)  # (B, T, d_model)
        # Use the last token's hidden state as the sequence representation.
        # Assumes sequences are not right-padded (batch size 1 or left padding).
        last_hidden = hidden[:, -1, :]  # (B, d_model)
        return self.reward_head(last_hidden).squeeze(-1)  # (B,) scalar rewards

class RewardDataset(Dataset):
    """Each example is a (prompt, chosen_response, rejected_response) triple."""
    def __init__(self, examples: list[dict], tokenizer, max_len: int = 512):
        # Assumes the Phase 4 tokeniser, whose encode() returns an Encoding with .ids.
        # Sequences differ in length, so batching needs padding or batch_size=1.
        self.pairs = []
        for ex in examples:
            chosen = tokenizer.encode(ex["prompt"] + ex["chosen"]).ids[:max_len]
            rejected = tokenizer.encode(ex["prompt"] + ex["rejected"]).ids[:max_len]
            self.pairs.append((
                torch.tensor(chosen, dtype=torch.long),
                torch.tensor(rejected, dtype=torch.long),
            ))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        return self.pairs[idx]

def reward_model_loss(reward_model: RewardModel,
                      chosen_ids: torch.Tensor,
                      rejected_ids: torch.Tensor) -> torch.Tensor:
    """
    Bradley-Terry pairwise preference loss:
    maximise P(chosen > rejected) = σ(r_chosen - r_rejected)
    """
    r_chosen = reward_model(chosen_ids)
    r_rejected = reward_model(rejected_ids)
    # Negative log-likelihood of chosen being preferred
    return -F.logsigmoid(r_chosen - r_rejected).mean()
# Example human preference data
examples = [
{
"prompt": "<|user|>\nWhat is 2+2?\n<|assistant|>\n",
"chosen": "4.",
"rejected": "I'm not sure, mathematics is complex.",
},
{
"prompt": "<|user|>\nExplain photosynthesis.\n<|assistant|>\n",
"chosen": "Photosynthesis converts sunlight, CO₂, and water into glucose and oxygen.",
"rejected": "Plants do something with light I think.",
},
]
Do / Don't
| Do | Collect pairwise rankings (A vs B), not absolute scores. Humans are far more consistent at "which is better?" than "rate this 1–10." |
| Do | Validate the reward model separately before using it in RLHF. Run it on a held-out preference set and check it agrees with humans ≥65% of the time. |
| Don't | Assume one reward model fits all domains. A reward model trained on creative writing preferences will give nonsense scores for code correctness. Train domain-specific models. |
| Don't | Use the SFT model directly as the reward model without a separate reward head. They need different training objectives. |
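Both the training loss and the held-out validation check reduce to simple functions of the two reward scores in each pair. This plain-Python sketch shows the Bradley-Terry loss for a single pair and the pairwise accuracy you would compare against the agreement threshold above. The reward values are invented for illustration.

```python
import math

def bradley_terry_loss(r_chosen: float, r_rejected: float) -> float:
    """-log(sigmoid(r_chosen - r_rejected)) for one preference pair."""
    margin = r_chosen - r_rejected
    return math.log1p(math.exp(-margin))

def pairwise_accuracy(scored_pairs: list[tuple[float, float]]) -> float:
    """Fraction of pairs where the chosen response gets the higher reward."""
    correct = sum(1 for chosen, rejected in scored_pairs if chosen > rejected)
    return correct / len(scored_pairs)

print(bradley_terry_loss(0.0, 0.0))   # no preference: log(2)
print(bradley_terry_loss(2.0, 0.0))   # chosen clearly ahead: small loss
print(bradley_terry_loss(0.0, 2.0))   # ranking inverted: large loss

# Hypothetical (chosen, rejected) reward scores on a held-out preference set.
held_out = [(1.2, 0.3), (0.1, 0.4), (2.0, -1.0), (0.5, 0.2)]
accuracy = pairwise_accuracy(held_out)
print(f"agreement with human labels: {accuracy:.0%}")
```

Only the difference between the two scores matters, so the absolute scale of the reward is arbitrary. This is one reason to monitor the raw reward values once RLHF starts.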
Phase 8: RLHF with PPO
What to Build
Use Proximal Policy Optimisation (PPO) to fine-tune the SFT model using reward signals from the reward model — with a KL-divergence penalty to prevent the policy from drifting too far from the original SFT behaviour.
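Why It Matters
This is the final alignment step. SFT teaches format. RLHF teaches preference: which format, tone, and content humans actually reward. InstructGPT (the precursor to ChatGPT) and Llama-2-chat both document this pipeline, Anthropic has described RLHF as part of Claude's training, and newer methods such as DPO and GRPO simplify or replace parts of it.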
The Code
import torch
import torch.nn.functional as F

def response_logprobs(model, full_ids: torch.Tensor,
                      prompt_len: int) -> torch.Tensor:
    """Log-prob of each response token given the prompt and earlier response tokens."""
    # Logits at position t predict token t + 1, so shift by one.
    logits = model(full_ids)[:, prompt_len - 1:-1, :]
    targets = full_ids[:, prompt_len:]
    return F.log_softmax(logits, dim=-1).gather(
        -1, targets.unsqueeze(-1)
    ).squeeze(-1)

def compute_ppo_loss(
    policy_logprobs: torch.Tensor,  # log probs from the current policy
    old_logprobs: torch.Tensor,     # log probs from the policy at sampling time
    rewards: torch.Tensor,          # reward model scores
    ref_logprobs: torch.Tensor,     # log probs from the frozen SFT reference model
    kl_coeff: float = 0.1,          # KL penalty weight
    clip_eps: float = 0.2,          # PPO clipping range
) -> torch.Tensor:
    # --- KL Penalty: penalise deviation from the SFT reference model ---
    # Detached: the penalty shapes the reward, it is not differentiated through.
    kl_penalty = policy_logprobs.detach() - ref_logprobs  # per-token KL estimate
    # Simplification: the shaped sequence reward stands in for the advantage.
    # Full PPO subtracts a learned value baseline (for example via GAE).
    shaped_reward = rewards - kl_coeff * kl_penalty.sum(dim=-1)
    # --- PPO Clipped Surrogate Objective ---
    log_ratio = policy_logprobs.sum(-1) - old_logprobs.sum(-1)
    ratio = log_ratio.exp()  # probability ratio π_θ / π_old
    # Clip the ratio to [1-ε, 1+ε] to prevent too-large updates
    clipped_ratio = ratio.clamp(1 - clip_eps, 1 + clip_eps)
    ppo_loss = -torch.min(
        ratio * shaped_reward,
        clipped_ratio * shaped_reward,
    ).mean()
    return ppo_loss

def rlhf_training_step(
    policy_model,
    ref_model,      # frozen SFT model, provides the KL baseline
    reward_model,
    prompt_ids: torch.Tensor,
    optimizer,
):
    prompt_len = prompt_ids.size(1)
    with torch.no_grad():
        # 1. Sample a response from the current policy. Assumes a generate()
        #    method that returns only the new tokens (TinyGPT does not define one).
        response_ids = policy_model.generate(prompt_ids, max_new_tokens=128)
        full_ids = torch.cat([prompt_ids, response_ids], dim=1)
        # 2. Score with reward model
        rewards = reward_model(full_ids)
        # 3. Get reference (SFT) log probs for KL penalty
        ref_logprobs = response_logprobs(ref_model, full_ids, prompt_len)
        # 4. Get old policy log probs (from the sampling step)
        old_logprobs = response_logprobs(policy_model, full_ids, prompt_len)
    # 5. Forward pass with gradient enabled. With one update per sample the
    #    ratio starts at exactly 1; clipping only matters when several updates
    #    reuse the same samples.
    policy_logprobs = response_logprobs(policy_model, full_ids, prompt_len)
    # 6. Compute PPO loss
    loss = compute_ppo_loss(
        policy_logprobs=policy_logprobs,
        old_logprobs=old_logprobs,
        rewards=rewards,
        ref_logprobs=ref_logprobs,
        kl_coeff=0.1,
    )
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(policy_model.parameters(), 1.0)
    optimizer.step()
    return loss.item(), rewards.mean().item()
Monitoring What Matters
# The three numbers to watch during RLHF training
# (ppo_loss, mean_reward and kl_div are whatever your training loop logs):
print(f"PPO loss:      {ppo_loss:.4f}")     # noisy; don't expect a steady decrease
print(f"Mean reward:   {mean_reward:.4f}")  # should increase
print(f"KL divergence: {kl_div:.4f}")       # must stay bounded (roughly < 4–6 nats)
# If KL divergence explodes: increase kl_coeff (0.1 → 0.2)
# If reward barely moves: decrease kl_coeff (0.1 → 0.05)
# If reward increases but outputs sound robotic: the reward model is being gamed
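The KL number being monitored is usually estimated from the log-probabilities of the sampled tokens. A minimal plain-Python sketch of that estimate, and of how it feeds the shaped reward, is below. The log-prob values are invented, and `sequence_kl` and `shaped_reward` are illustrative names.

```python
def sequence_kl(policy_logprobs: list[float], ref_logprobs: list[float]) -> float:
    """Sample-based KL estimate: sum over tokens of log pi(token) - log pi_ref(token)."""
    return sum(p - r for p, r in zip(policy_logprobs, ref_logprobs))

def shaped_reward(reward: float, kl: float, kl_coeff: float = 0.1) -> float:
    """Reward model score minus the KL penalty."""
    return reward - kl_coeff * kl

# Hypothetical per-token log-probs for a 3-token response.
policy_lp = [-1.0, -0.5, -2.0]   # current policy
ref_lp = [-1.2, -0.9, -2.1]      # frozen SFT reference
kl = sequence_kl(policy_lp, ref_lp)
shaped = shaped_reward(reward=2.0, kl=kl)
print(f"KL estimate: {kl:.2f} nats, shaped reward: {shaped:.2f}")
```

Because it is computed on single samples, the estimate can come out negative for an individual response. It is the average over many responses that should stay bounded.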
Do / Don't
| Do | Add the KL penalty. Without it, the policy will find degenerate outputs that score high on the reward model but read as garbled or unsafe text. This is called reward hacking. |
| Do | Monitor real outputs alongside metrics. A reward score of +3.2 means nothing if the actual text is repetitive or evasive. Sample 20–30 responses every 100 steps and read them. |
| Don't | Optimise only for reward score. The reward model is an imperfect proxy for human preference. PPO that maximises it completely will overfit to the reward model's blind spots. |
| Don't | Start RLHF from the base pre-trained model. Always start from the SFT checkpoint. RLHF with no instruction-following baseline diverges immediately. |
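The clipping rule is what keeps a single PPO update from moving the policy too far, and it is easiest to read on scalars. The sketch below evaluates the clipped surrogate for one sample in plain Python. The function name is illustrative, and the sign convention matches the objective being maximised (the loss is its negative).

```python
def clipped_surrogate(ratio: float, advantage: float, clip_eps: float = 0.2) -> float:
    """PPO objective for one sample: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    clipped = max(1 - clip_eps, min(ratio, 1 + clip_eps))
    return min(ratio * advantage, clipped * advantage)

# Good response, policy already much more likely to produce it: gain is capped.
print(clipped_surrogate(ratio=1.5, advantage=1.0))
# Bad response the policy has drifted towards: the full penalty applies.
print(clipped_surrogate(ratio=1.5, advantage=-1.0))
# Good response the policy has drifted away from: no cap on recovering.
print(clipped_surrogate(ratio=0.5, advantage=1.0))
# Bad response already made less likely: the extra credit is capped.
print(clipped_surrogate(ratio=0.5, advantage=-1.0))
```

The asymmetry is deliberate: taking the minimum caps the reward for moving further in a direction that already looks good, but never hides the cost of a move that looks bad.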
The Complete Lifecycle at a Glance
┌────────────────────────────────────────────────────────────────┐
│ PHASE GOAL KEY CODE │
├────────────────────────────────────────────────────────────────┤
│ 1. Core Working transformer MHA + causal mask │
│ 2. Train Prove pipeline works Cross-entropy + gen │
│ 3. Modernise 2026 architecture RMSNorm/SwiGLU/RoPE │
│ 4. Scale Handle real data BPE + AMP + grad accum │
│ 5. MoE Sparse computation Router + load balance │
│ 6. SFT Instruction following Response-masked loss │
│ 7. Reward Human preference proxy Bradley-Terry loss │
│ 8. RLHF Align to preferences PPO + KL penalty │
└────────────────────────────────────────────────────────────────┘
Start here: Phase 1 → 2 (tiny Shakespeare dataset, CPU)
Add modern: Phase 3 (swap components, verify loss unchanged)
Add scale: Phase 4 (BPE, AMP, gradient accumulation)
Optional: Phase 5 (MoE, only if you need more capacity per unit of compute)
Align: Phase 6 → 7 → 8 (SFT first, always)
Summary
Building an LLM from scratch is not one task — it's eight sequential phases, each with its own failure modes and validation criteria.
Phase 1 gives you the engine: decoder-only transformer with causal masking, residual connections, and pre-norm. Phase 2 validates it actually trains: check that loss starts at log(vocab_size) and falls. Phase 3 modernises it to match production LLMs: RMSNorm, SwiGLU, RoPE, and KV cache for fast inference. Phase 4 makes it trainable at real scale: BPE tokenisation, mixed precision, gradient accumulation, and compulsive checkpointing. Phase 5 adds conditional computation via Mixture of Experts — only useful once the dense baseline is solid. Phase 6 is the transition from autocomplete to assistant: SFT with instruction-response pairs, response-masked loss, and LoRA for efficiency. Phase 7 trains a reward model that learns human preference from pairwise rankings. Phase 8 uses PPO with KL-penalty to align the policy to that reward model without drifting into reward-hacking.
The universal lesson cutting across all eight phases: don't skip validation gates. Each phase assumes the previous one is working correctly. A broken causal mask in Phase 1 corrupts every phase that follows. A reward model that doesn't outperform chance in Phase 7 makes RLHF in Phase 8 meaningless. Build the smallest thing that proves correctness at each step, then move forward.
This is the backbone of every modern LLM. The architectures differ at the margins. The lifecycle does not.
Implementing one of these phases and hitting a specific wall? Drop a question in the comments — happy to dig in.