
Building an LLM from Scratch in PyTorch: The Full Lifecycle Cheatsheet

Most LLM tutorials give you one of two things: a high-level diagram with boxes and arrows, or a 10,000-line codebase with no explanation of why each piece exists.

This post is neither. It's a step-by-step lifecycle — 8 phases, each with working PyTorch code, the reasoning behind every decision, and an explicit Do / Don't list that captures the mistakes that cost most beginners weeks of wasted compute.

By the end you'll have built, trained, modernised, scaled, and aligned a language model, following the same broad lifecycle behind the major LLMs you've used.

Phase 1: Core Transformer    → the engine
Phase 2: Train a Tiny LLM    → prove the pipeline works
Phase 3: Modernise           → match 2026 architecture
Phase 4: Scale Efficiently   → push past toy datasets
Phase 5: Mixture of Experts  → conditional computation
Phase 6: SFT                 → turn autocomplete into an assistant
Phase 7: Reward Modelling    → teach the model what "good" looks like
Phase 8: RLHF                → optimise for human preference

Phase 1: Build the Core Transformer

What to Build

A minimal GPT-style decoder-only transformer. This is the skeleton every subsequent phase builds on.

Input tokens
Token Embedding + Positional Embedding
N × Transformer Block:
    ├── LayerNorm
    ├── Multi-Head Self-Attention (with causal mask)
    ├── Residual connection
    ├── LayerNorm
    ├── Feed-Forward Network
    └── Residual connection
Final LayerNorm
Linear projection → vocabulary logits

Why It Matters

Modern open LLMs such as Llama and Mistral are scaled, refined versions of this exact structure, and GPT-4, Claude, and Gemini are, as far as is publicly known, built on the same decoder-only transformer idea. Understanding each component at the code level means you can read most LLM papers and translate them directly into an implementation.

The Code

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        self.q = nn.Linear(d_model, d_model, bias=False)
        self.k = nn.Linear(d_model, d_model, bias=False)
        self.v = nn.Linear(d_model, d_model, bias=False)
        self.out = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape

        q = self.q(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        # Causal mask: each position can only attend to itself and earlier positions
        scale = math.sqrt(self.head_dim)
        scores = (q @ k.transpose(-2, -1)) / scale          # (B, H, T, T)
        causal_mask = torch.triu(
            torch.ones(T, T, device=x.device), diagonal=1
        ).bool()
        scores = scores.masked_fill(causal_mask, float("-inf"))

        attn = F.softmax(scores, dim=-1)
        out = (attn @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.out(out)


class FeedForward(nn.Module):
    def __init__(self, d_model: int, expansion: int = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_model * expansion),
            nn.GELU(),
            nn.Linear(d_model * expansion, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn  = MultiHeadAttention(d_model, n_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff    = FeedForward(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.norm1(x))   # pre-norm + residual
        x = x + self.ff(self.norm2(x))
        return x


class TinyGPT(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, n_heads: int,
                 n_layers: int, max_seq_len: int):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb   = nn.Embedding(max_seq_len, d_model)
        self.blocks    = nn.ModuleList([TransformerBlock(d_model, n_heads)
                                        for _ in range(n_layers)])
        self.norm      = nn.LayerNorm(d_model)
        self.head      = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        B, T = idx.shape
        positions = torch.arange(T, device=idx.device)
        x = self.token_emb(idx) + self.pos_emb(positions)
        for block in self.blocks:
            x = block(x)
        return self.head(self.norm(x))          # logits: (B, T, vocab_size)

Do / Don't

Do Follow decoder-only architecture. Encoder-decoder (like T5) adds complexity you don't need for a generative LLM.
Do Use pre-norm (norm → attention → residual), not post-norm. Pre-norm trains more stably, especially as models get deeper.
Don't Let the model peek at future tokens. The causal mask (torch.triu) is not optional — without it the model learns a trivial task (copy next token) and the training loss is meaningless.
Don't Skip the residual connections. Without them, gradients have no direct path back to the early layers, and deeper stacks train slowly or not at all.
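One way to check that the mask really blocks future tokens is to change only the last position's keys and values and confirm that the outputs for every earlier position stay the same. The sketch below re-implements single-head masked attention with the same `torch.triu` mask as `MultiHeadAttention` so it runs on its own; the helper name `causal_attention` is illustrative.

```python
import math
import torch
import torch.nn.functional as F


def causal_attention(q, k, v):
    # q, k, v: (B, T, d). Same masking as MultiHeadAttention above, single head.
    T = q.size(1)
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))
    return F.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
B, T, d = 2, 8, 16
q, k, v = (torch.randn(B, T, d) for _ in range(3))
out = causal_attention(q, k, v)

# Perturb only the last position's key and value
k2, v2 = k.clone(), v.clone()
k2[:, -1] += 10.0
v2[:, -1] += 10.0
out2 = causal_attention(q, k2, v2)

print(torch.allclose(out[:, :-1], out2[:, :-1]))  # earlier positions must be unaffected
print(torch.allclose(out[:, -1], out2[:, -1]))    # the last position does see the change
```

If the first check fails on your real model, future tokens are leaking into the past.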

Phase 2: Train a Tiny LLM

What to Build

A complete training pipeline: character-level tokeniser, dataset batching, cross-entropy loss for next-token prediction, and a generation loop.

Why It Matters

A model that compiles is not a model that trains. This phase validates the entire forward/backward pass end-to-end on a dataset you can inspect by eye — before spending any real compute.
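A common complementary sanity check is to overfit one fixed batch: if the loss will not fall on 128 tokens the model sees over and over, nothing downstream will work either. This sketch uses a hypothetical bigram stand-in model so it is self-contained; in practice you would run the same loop on `TinyGPT`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# Hypothetical stand-in model: a lookup table from token to next-token logits
class Bigram(nn.Module):
    def __init__(self, vocab_size: int):
        super().__init__()
        self.table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        return self.table(idx)                   # (B, T, vocab_size)


torch.manual_seed(0)
vocab_size = 65
x = torch.randint(vocab_size, (4, 32))          # one fixed batch, reused every step
y = torch.roll(x, shifts=-1, dims=1)            # any fixed targets work for this check

model = Bigram(vocab_size)
opt = torch.optim.AdamW(model.parameters(), lr=1e-1)

losses = []
for _ in range(200):
    logits = model(x)
    loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses.append(loss.item())

print(f"first {losses[0]:.3f} -> last {losses[-1]:.3f}")
```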

The Code

import torch
import torch.nn.functional as F

# --- 1. Tokeniser (character-level — no libraries needed) ---
text = open("tiny_shakespeare.txt").read()
chars = sorted(set(text))
vocab_size = len(chars)
stoi = {c: i for i, c in enumerate(chars)}
itos = {i: c for i, c in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda ids: "".join(itos[i] for i in ids)

data = torch.tensor(encode(text), dtype=torch.long)

# --- 2. Batching ---
def get_batch(data: torch.Tensor, batch_size: int, block_size: int,
              device: str) -> tuple[torch.Tensor, torch.Tensor]:
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    return x.to(device), y.to(device)

# --- 3. Training loop ---
device = "cuda" if torch.cuda.is_available() else "cpu"

model = TinyGPT(
    vocab_size=vocab_size,
    d_model=128,
    n_heads=4,
    n_layers=4,
    max_seq_len=256,
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for step in range(5000):
    x, y = get_batch(data, batch_size=32, block_size=256, device=device)
    logits = model(x)                                  # (B, T, vocab_size)
    loss = F.cross_entropy(
        logits.view(-1, vocab_size),                   # (B*T, vocab_size)
        y.view(-1),                                    # (B*T,)
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 500 == 0:
        print(f"step {step:>5} | loss {loss.item():.4f}")

# --- 4. Generation ---
@torch.no_grad()
def generate(model, prompt: str, max_new_tokens: int = 200) -> str:
    model.eval()
    idx = torch.tensor([encode(prompt)], dtype=torch.long, device=device)
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -256:]                       # crop to max_seq_len
        logits = model(idx_cond)
        next_token_logits = logits[:, -1, :]           # last position only
        probs = F.softmax(next_token_logits / 0.8, dim=-1)  # temperature=0.8
        next_token = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_token], dim=1)
    return decode(idx[0].tolist())

print(generate(model, "HAMLET:"))

Expected Training Signal

step     0 | loss 4.17   ← random (log(65) ≈ 4.17 for 65-char vocab — correct)
step   500 | loss 2.51
step  1000 | loss 2.19
step  2500 | loss 1.94
step  5000 | loss 1.72   ← coherent-ish Shakespeare. Pipeline is working.

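If the loss doesn't decrease from step 0, check that optimizer.step() runs, that the learning rate is sane, and that targets are shifted by one position. If it instead plunges far below these numbers (toward zero), the causal mask is probably leaking future tokens. Fix before continuing.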

Do / Don't

Do Test with a small, inspectable dataset (Shakespeare, a book). You need to read the outputs and see them improve.
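Do Verify your starting loss is close to theory: log(vocab_size). If it's much higher, the initial logits are too large (an overconfident initialisation). A loss that then falls far faster than expected points to a data leak, such as a broken causal mask.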
Don't Expect quality generations at this stage. The goal is to prove the loss curve goes down. Output quality comes later.
Don't Move to Phase 3 until the pipeline runs cleanly on CPU. GPU bugs are harder to debug.
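Where the log(vocab_size) target comes from: a model that knows nothing assigns equal probability 1/V to every token, and the cross-entropy of that uniform prediction is exactly log V. A minimal check with all-zero logits and the 65-character vocabulary:

```python
import math
import torch
import torch.nn.functional as F

vocab_size = 65
targets = torch.randint(vocab_size, (1024,))

# Equal logits for every token = uniform prediction
uniform_logits = torch.zeros(1024, vocab_size)
loss = F.cross_entropy(uniform_logits, targets)

print(f"{loss.item():.4f} vs log(V) = {math.log(vocab_size):.4f}")
```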

Phase 3: Modernise the Architecture

What to Build

Replace the "classic" components with the three upgrades popularised by Llama and adopted by most open-weight LLMs since: RMSNorm, SwiGLU, and RoPE embeddings. Add a KV cache for inference speed.

Why It Matters

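These aren't cosmetic changes. Each one addresses a real problem:
- RMSNorm drops LayerNorm's mean-centering and bias, so it is cheaper to compute while training about as stably; the speedup depends on implementation and hardware.
- SwiGLU, a gated FFN, has tended to outperform GELU FFNs at a matched parameter count in published comparisons.
- RoPE encodes relative position directly in the query-key dot product, and combined with RoPE scaling techniques it extends to longer contexts more gracefully than learned absolute position embeddings.
- A KV cache stores past keys and values, so each new token needs only O(n) attention work over the cached prefix instead of reprocessing the whole sequence (O(n²)) at every step.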
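The relative-position property can be checked numerically: with RoPE, the attention score between a query at position m and a key at position n depends only on m − n. This self-contained sketch repeats the interleaved `precompute_rope` / `apply_rope` construction used in this phase's code; `rope_score` is a hypothetical helper.

```python
import torch


def precompute_rope(head_dim: int, max_seq_len: int, base: int = 10000):
    # Interleaved (pairwise) rotation angles, one frequency per dimension pair
    theta = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = torch.outer(torch.arange(max_seq_len).float(), theta)
    return freqs.cos().repeat_interleave(2, dim=-1), freqs.sin().repeat_interleave(2, dim=-1)


def apply_rope(x, cos, sin):
    # Rotate each (x0, x1) pair by the position's angle
    x_rot = torch.stack([-x[..., 1::2], x[..., ::2]], dim=-1).flatten(-2)
    return x * cos + x_rot * sin


def rope_score(q, k, m, n, cos, sin):
    # Attention logit between a query at position m and a key at position n
    return (apply_rope(q, cos[m], sin[m]) * apply_rope(k, cos[n], sin[n])).sum()


torch.manual_seed(0)
head_dim, max_len = 16, 64
cos, sin = precompute_rope(head_dim, max_len)
q, k = torch.randn(head_dim), torch.randn(head_dim)

# Same offset (m - n = 3) at two different absolute positions
s1 = rope_score(q, k, m=10, n=7, cos=cos, sin=sin)
s2 = rope_score(q, k, m=40, n=37, cos=cos, sin=sin)
print(torch.allclose(s1, s2, atol=1e-4))
```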

The Code

import torch
import torch.nn as nn
import torch.nn.functional as F


# --- RMSNorm: simpler, faster than LayerNorm ---
class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).sqrt()
        return x / rms * self.weight


# --- SwiGLU: gated FFN used in Llama, PaLM, Mistral ---
class SwiGLU(nn.Module):
    def __init__(self, d_model: int, expansion: int = 4):
        super().__init__()
        hidden = int(d_model * expansion * 2 / 3)  # standard Llama sizing
        self.gate = nn.Linear(d_model, hidden, bias=False)
        self.up   = nn.Linear(d_model, hidden, bias=False)
        self.down = nn.Linear(hidden,  d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(F.silu(self.gate(x)) * self.up(x))


# --- RoPE: Rotary Position Embeddings ---
def precompute_rope(head_dim: int, max_seq_len: int,
                    base: int = 10000) -> tuple[torch.Tensor, torch.Tensor]:
    theta = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_seq_len)
    freqs = torch.outer(t, theta)               # (T, head_dim/2)
    cos = freqs.cos().repeat_interleave(2, dim=-1)
    sin = freqs.sin().repeat_interleave(2, dim=-1)
    return cos, sin                             # each: (T, head_dim)

def apply_rope(x: torch.Tensor, cos: torch.Tensor,
               sin: torch.Tensor) -> torch.Tensor:
    # Rotate pairs of dimensions: [x0, x1] → [x0*cos - x1*sin, x0*sin + x1*cos]
    x_rotated = torch.stack([-x[..., 1::2], x[..., ::2]], dim=-1)
    x_rotated = x_rotated.flatten(-2)
    return x * cos + x_rotated * sin


# --- KV Cache: store past keys/values, avoid recomputing ---
class KVCache:
    def __init__(self):
        self.k_cache: torch.Tensor | None = None
        self.v_cache: torch.Tensor | None = None

    def update(self, k: torch.Tensor, v: torch.Tensor
               ) -> tuple[torch.Tensor, torch.Tensor]:
        if self.k_cache is None:
            self.k_cache, self.v_cache = k, v
        else:
            self.k_cache = torch.cat([self.k_cache, k], dim=2)
            self.v_cache = torch.cat([self.v_cache, v], dim=2)
        return self.k_cache, self.v_cache

    def reset(self):
        self.k_cache = self.v_cache = None


# --- Modern Transformer Block combining all three upgrades ---
class ModernTransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, max_seq_len: int):
        super().__init__()
        self.n_heads  = n_heads
        self.head_dim = d_model // n_heads
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)
        self.q  = nn.Linear(d_model, d_model, bias=False)
        self.k  = nn.Linear(d_model, d_model, bias=False)
        self.v  = nn.Linear(d_model, d_model, bias=False)
        self.o  = nn.Linear(d_model, d_model, bias=False)
        self.ff = SwiGLU(d_model)
        cos, sin = precompute_rope(self.head_dim, max_seq_len)
        self.register_buffer("cos", cos)
        self.register_buffer("sin", sin)

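    def forward(self, x: torch.Tensor,
                kv_cache: KVCache | None = None) -> torch.Tensor:
        B, T, C = x.shape
        h = self.n_heads
        d = self.head_dim

        # Tokens already in the cache = absolute position offset for RoPE
        past_len = 0
        if kv_cache is not None and kv_cache.k_cache is not None:
            past_len = kv_cache.k_cache.size(2)
        # This sketch supports a full prefill (empty cache) or one token per decode step
        assert past_len == 0 or T == 1, "chunked prefill needs an explicit attention mask"

        xn = self.norm1(x)                       # normalise once, reuse for q/k/v
        q = self.q(xn).view(B, T, h, d).transpose(1, 2)
        k = self.k(xn).view(B, T, h, d).transpose(1, 2)
        v = self.v(xn).view(B, T, h, d).transpose(1, 2)

        # Apply RoPE at each token's absolute position, not always starting from 0
        cos = self.cos[past_len:past_len + T]
        sin = self.sin[past_len:past_len + T]
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)

        # Use KV cache at inference time
        if kv_cache is not None:
            k, v = kv_cache.update(k, v)

        # is_causal=True applies a top-left-aligned mask, which is only correct when
        # there is no cached prefix. A single decode step may attend to every cached key.
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=(past_len == 0))
        out = attn.transpose(1, 2).contiguous().view(B, T, C)
        x = x + self.o(out)
        x = x + self.ff(self.norm2(x))
        return x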
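Why the `2 / 3` in the SwiGLU sizing: SwiGLU has three weight matrices instead of two, so shrinking the hidden width to two thirds of `4 * d_model` keeps the parameter count roughly equal to the classic FFN from Phase 1. A quick count, ignoring the Phase 1 FFN's biases (the Phase 3 layers use `bias=False`):

```python
def ffn_params(d_model: int, expansion: int = 4) -> int:
    # Classic FFN: up (d -> 4d) and down (4d -> d)
    return 2 * d_model * (d_model * expansion)


def swiglu_params(d_model: int, expansion: int = 4) -> int:
    # SwiGLU: gate and up (d -> hidden) plus down (hidden -> d)
    hidden = int(d_model * expansion * 2 / 3)
    return 3 * d_model * hidden


for d in (128, 768, 4096):
    print(d, ffn_params(d), swiglu_params(d), swiglu_params(d) / ffn_params(d))
```

When `4 * d_model` is divisible by 3 the counts match exactly; otherwise the `int()` truncation leaves SwiGLU marginally smaller.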

Do / Don't

Do Upgrade incrementally: swap LayerNorm → RMSNorm first, verify loss is equal or better, then add SwiGLU, then RoPE. One change at a time.
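Do Use PyTorch's F.scaled_dot_product_attention. It dispatches to a fused FlashAttention or memory-efficient kernel when the hardware, dtype, and inputs allow, which avoids materialising the full T × T attention matrix and saves substantial memory on long sequences.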
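Don't Assume components that helped individually will combine cleanly. A mix of old and new norm/position schemes (e.g. LayerNorm + RoPE) is a different model from the full RMSNorm + SwiGLU + RoPE stack. Re-validate the loss curve for the final combination.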
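Don't Enable the KV cache during training. Training already processes whole sequences in parallel under the causal mask, so a cache adds nothing there; if it persists across batches it mixes unrelated sequences and keeps old autograd graphs alive. It's an inference-only optimisation.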
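A minimal check that cached, one-token-at-a-time decoding reproduces full causal attention, using `F.scaled_dot_product_attention` directly on random tensors. It also shows a subtlety: for a single new query against a longer cache, `is_causal=True` applies a top-left-aligned mask that would restrict the query to the first cached key, so the decode step passes no mask at all.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, d = 1, 2, 6, 8
q, k, v = (torch.randn(B, H, T, d) for _ in range(3))

# Training-style: full causal attention over all T positions at once
full = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Inference-style: one new query per step against a growing cache
k_cache, v_cache, steps = None, None, []
for t in range(T):
    k_t, v_t = k[:, :, t:t + 1], v[:, :, t:t + 1]
    k_cache = k_t if k_cache is None else torch.cat([k_cache, k_t], dim=2)
    v_cache = v_t if v_cache is None else torch.cat([v_cache, v_t], dim=2)
    # The newest query may see every cached key, so no mask here
    steps.append(F.scaled_dot_product_attention(q[:, :, t:t + 1], k_cache, v_cache))

incremental = torch.cat(steps, dim=2)
print(torch.allclose(full, incremental, atol=1e-5))
```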

Phase 4: Scale Efficiently

What to Build

Three complementary scaling techniques: BPE tokenisation (subword vocabulary), mixed precision training (FP16/BF16), and gradient accumulation (simulate large batches without large GPU memory).

Why It Matters

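Character-level models plateau quickly. Subword tokenisation (used by production LLMs) packs roughly 3–4× more characters of English text into each token, so the same context window and compute cover more text. Mixed precision roughly halves activation memory and speeds up matrix multiplies on tensor-core GPUs. Gradient accumulation lets a single GPU such as an A100 reach a large effective batch, e.g. 32 samples × 8 accumulation steps = 256 samples per update.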
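Why the training loop below divides the loss by `accum_steps`: with equal-sized micro-batches and a mean-reduced loss, summing the gradients of `loss / accum_steps` over the micro-batches gives the same gradient as one big batch. A small self-contained check with a linear layer:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(16, 4)
x, y = torch.randn(32, 16), torch.randint(4, (32,))

# Reference: one batch of 32
model.zero_grad()
F.cross_entropy(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulated: 4 micro-batches of 8, each loss scaled by 1 / accum_steps
accum_steps = 4
model.zero_grad()
for xb, yb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    (F.cross_entropy(model(xb), yb) / accum_steps).backward()
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-6))
```

With padding or a varying number of tokens per micro-batch, this only approximates the big-batch gradient.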

The Code

import torch
from tokenizers import Tokenizer, decoders, models, pre_tokenizers, trainers

# --- 1. Train a BPE tokeniser ---
def train_bpe_tokeniser(files: list[str], vocab_size: int = 8192) -> Tokenizer:
    tokenizer = Tokenizer(models.BPE())
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
    tokenizer.decoder = decoders.ByteLevel()        # map byte-level tokens back to text
    trainer = trainers.BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=["<|pad|>", "<|bos|>", "<|eos|>"],
        initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),  # cover all 256 bytes
    )
    tokenizer.train(files, trainer)
    return tokenizer


# --- 2. Mixed Precision + Gradient Accumulation training loop ---
def train(model, dataloader, optimizer, n_epochs: int,
          accum_steps: int = 8, device: str = "cuda",
          amp_dtype: torch.dtype = torch.bfloat16):

    device_type = torch.device(device).type
    # Loss scaling is only needed for FP16; with BF16 the scaler is disabled (a no-op)
    scaler = torch.amp.GradScaler(device_type, enabled=(amp_dtype == torch.float16))
    model.train()
    optimizer.zero_grad()

    for epoch in range(n_epochs):
        for step, (x, y) in enumerate(dataloader):
            x, y = x.to(device), y.to(device)

            # autocast: runs eligible ops in BF16/FP16, keeping master weights in FP32
            with torch.autocast(device_type=device_type, dtype=amp_dtype):
                logits = model(x)
                loss = torch.nn.functional.cross_entropy(
                    logits.view(-1, logits.size(-1)), y.view(-1)
                )
                loss = loss / accum_steps           # normalise before accumulating

            scaler.scale(loss).backward()

            # Only update weights every accum_steps steps
            if (step + 1) % accum_steps == 0:
                scaler.unscale_(optimizer)          # so clipping sees the true gradient norm
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

                print(f"epoch {epoch} step {step} | loss {loss.item() * accum_steps:.4f}")

        # Checkpoint after every epoch
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
        }, f"checkpoint_epoch_{epoch}.pt")

Do / Don't

Do Checkpoint often when pushing GPU memory limits. The cost of saving is seconds; the cost of a crashed run is hours.
Do Use torch.bfloat16 over float16 when your GPU supports it (A100, H100, RTX 30xx+). BF16 has the same dynamic range as FP32 and doesn't need a loss scaler.
Don't Assume bigger batch sizes always improve results. Larger batches can hurt generalisation. Use gradient accumulation to hit your target effective batch size, then tune the learning rate separately.
Don't Forget clip_grad_norm_. Without it, a single bad batch can produce a massive gradient update that corrupts the model. max_norm=1.0 is a common default.
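The range difference behind the BF16 recommendation is easy to see with `torch.finfo`: FP16 overflows just above 65504, which is why it needs a loss scaler, while BF16 keeps FP32's 8-bit exponent range at the cost of precision.

```python
import torch

print(torch.finfo(torch.float16).max)    # largest finite FP16 value
print(torch.finfo(torch.bfloat16).max)   # BF16 shares FP32's exponent range
print(torch.finfo(torch.float32).max)

big = torch.tensor(70000.0)
print(big.to(torch.float16))    # overflows to inf
print(big.to(torch.bfloat16))   # stays finite, but rounded to a coarse grid
```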

Phase 5: Mixture of Experts (MoE)

What to Build

Replace the dense feed-forward layer with a sparse MoE layer: multiple expert FFNs with a learned router that activates only the top-K experts per token.

Why It Matters

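MoE is how Mixtral 8x7B matches or beats Llama 2 70B and GPT-3.5 on many benchmarks at a much lower inference cost. It has about 47B parameters in total, but only about 13B are active for each token, because each token is routed to 2 of the 8 experts in every layer. You get much of a big model's capacity at closer to a small model's per-token compute (although all 47B parameters still have to fit in memory).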

The Code

import torch
import torch.nn as nn
import torch.nn.functional as F


class MoELayer(nn.Module):
    def __init__(self, d_model: int, n_experts: int = 8, top_k: int = 2):
        super().__init__()
        assert top_k <= n_experts
        self.n_experts = n_experts
        self.top_k = top_k

        # One FFN per expert
        self.experts = nn.ModuleList([SwiGLU(d_model) for _ in range(n_experts)])

        # Router: projects each token to a score over all experts
        self.router = nn.Linear(d_model, n_experts, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        x_flat = x.reshape(B * T, C)               # (N, d_model)

        # --- Routing ---
        router_logits = self.router(x_flat)         # (N, n_experts)
        router_probs  = F.softmax(router_logits, dim=-1)

        # Select top-K experts per token
        top_k_probs, top_k_idx = router_probs.topk(self.top_k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)  # renormalise

        # --- Dispatch each token to its selected experts ---
        output = torch.zeros_like(x_flat)
        for expert_idx in range(self.n_experts):
            # (token, slot) pairs where this expert is in the token's top-K
            token_idx, slot_idx = (top_k_idx == expert_idx).nonzero(as_tuple=True)
            if token_idx.numel() == 0:
                continue

            weights = top_k_probs[token_idx, slot_idx].unsqueeze(-1)  # (n_selected, 1)
            expert_out = self.experts[expert_idx](x_flat[token_idx])  # (n_selected, C)
            # Weighted contribution, scattered back to the tokens' rows
            output.index_add_(0, token_idx, (expert_out * weights).to(output.dtype))

        # --- Load balancing auxiliary loss (Switch Transformer style) ---
        # Without this, the router tends to collapse onto a few experts.
        # f: fraction of routing assignments sent to each expert (no gradient)
        # P: mean router probability per expert (carries the gradient)
        assignments = F.one_hot(top_k_idx, self.n_experts).sum(1).float()  # (N, n_experts)
        f = assignments.mean(0) / self.top_k               # sums to 1
        P = router_probs.mean(0)
        self.aux_loss = self.n_experts * (f * P).sum()

        return output.view(B, T, C)


# In your training loop, add the auxiliary loss:
# total_loss = main_loss + 0.01 * moe_layer.aux_loss

Do / Don't

Do Start with n_experts=4, top_k=2. Validate the routing is balanced (all experts used roughly equally) before scaling.
Do Add the load-balancing auxiliary loss. Without it, the router tends to collapse onto a few experts and ignore the rest, so you pay for the extra parameters without getting their capacity.
Don't Activate all experts at once. That's a dense model. The entire point of MoE is conditional, sparse computation.
Don't Add MoE before your baseline transformer is training correctly. It adds routing complexity and the aux loss is easy to miscalibrate.
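To validate balance before scaling, you can measure the fraction of routing assignments per expert together with a Switch Transformer-style auxiliary loss, n_experts × Σ fᵢ·Pᵢ. The sketch below compares a router with no preference against one that has collapsed onto expert 0; the `load_balance` helper and the synthetic logits are illustrative.

```python
import torch
import torch.nn.functional as F


def load_balance(router_logits: torch.Tensor, top_k: int):
    """Per-expert assignment fraction and Switch-style aux loss (sketch)."""
    n_experts = router_logits.size(-1)
    probs = F.softmax(router_logits, dim=-1)                 # (N, E)
    top_k_idx = probs.topk(top_k, dim=-1).indices            # (N, k)
    # Fraction of all routing assignments that went to each expert (sums to 1)
    frac = F.one_hot(top_k_idx, n_experts).sum(1).float().mean(0) / top_k
    aux = n_experts * (frac * probs.mean(0)).sum()
    return frac, aux


torch.manual_seed(0)
N, E = 4096, 8
balanced_logits = 0.01 * torch.randn(N, E)     # router with no strong preference
collapsed_logits = balanced_logits.clone()
collapsed_logits[:, 0] += 10.0                  # router that always favours expert 0

frac_b, aux_b = load_balance(balanced_logits, top_k=2)
frac_c, aux_c = load_balance(collapsed_logits, top_k=2)
print("balanced ", frac_b.tolist(), f"aux={aux_b.item():.3f}")
print("collapsed", frac_c.tolist(), f"aux={aux_c.item():.3f}")
```

When the router probabilities are uniform the loss sits at 1.0; a collapsed router pushes it well above that, which is the signal the auxiliary term penalises.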

Phase 6: Supervised Fine-Tuning (SFT)

What to Build

Fine-tune your pre-trained model on instruction-response pairs so it transitions from "word predictor" to "helpful assistant."

Why It Matters

A raw pre-trained LLM is a document completer: it continues text in whatever style the context suggests. Given "What is the capital of France?", it may just as plausibly continue with more quiz questions as answer. SFT teaches the model the format of being an assistant: receiving an instruction and generating a helpful, well-formatted response.

The Code

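import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

IGNORE_INDEX = -100   # F.cross_entropy skips targets with this value by default


# --- Instruction dataset format ---
class InstructionDataset(Dataset):
    def __init__(self, examples: list[dict], tokenizer, max_len: int = 512):
        # Assumes tokenizer.encode(str) returns a list of token ids and that
        # <|user|>, <|assistant|>, <|eos|> are registered special tokens
        self.data = []

        for ex in examples:
            # Format: special tokens wrap the instruction/response
            prompt = f"<|user|>\n{ex['instruction']}\n<|assistant|>\n"
            prompt_ids   = tokenizer.encode(prompt)
            response_ids = tokenizer.encode(f"{ex['response']}<|eos|>")
            tokens = (prompt_ids + response_ids)[:max_len]
            # Labels mirror the tokens, but the prompt is hidden from the loss
            labels = ([IGNORE_INDEX] * len(prompt_ids) + response_ids)[:max_len]
            self.data.append((tokens, labels))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        tokens, labels = self.data[idx]
        x = torch.tensor(tokens[:-1], dtype=torch.long)
        y = torch.tensor(labels[1:],  dtype=torch.long)   # shifted: predict the next token
        return x, y


# --- Pad variable-length examples into a batch ---
# e.g. DataLoader(ds, batch_size=8, collate_fn=lambda b: sft_collate(b, pad_id))
def sft_collate(batch, pad_id: int):
    max_len = max(x.size(0) for x, _ in batch)
    xs = torch.full((len(batch), max_len), pad_id, dtype=torch.long)
    ys = torch.full((len(batch), max_len), IGNORE_INDEX, dtype=torch.long)
    for i, (x, y) in enumerate(batch):
        xs[i, :x.size(0)] = x
        ys[i, :y.size(0)] = y
    return xs, ys


# --- Only compute loss on the response tokens, not the instruction ---
def sft_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Prompt and padding positions carry IGNORE_INDEX, so every example's
    # instruction is masked individually, whatever its length
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        ignore_index=IGNORE_INDEX,
    )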


# --- LoRA for efficient fine-tuning (add to existing Linear layers) ---
class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, r: int = 16, alpha: float = 32):
        super().__init__()
        self.base   = base
        self.lora_A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))
        self.scale  = alpha / r

        # Freeze the base weights
        for p in self.base.parameters():
            p.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scale
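Where a figure like "~1% of parameters" comes from: for a `d_out × d_in` layer, `LoRALinear` above adds `r × (d_in + d_out)` trainable parameters (A and B) on top of the frozen `d_out × d_in` base. The layer sizes below are chosen for illustration; 4096 is the hidden size of 7B-scale Llama models, and 128 is the `d_model` used earlier in this post.

```python
def lora_fraction(d_in: int, d_out: int, r: int) -> float:
    # Trainable LoRA params (A: r x d_in, B: d_out x r) relative to the base weight
    return r * (d_in + d_out) / (d_in * d_out)


print(f"{lora_fraction(4096, 4096, 16):.4%}")   # large projection, r = 16
print(f"{lora_fraction(128, 128, 16):.4%}")     # tiny model, same rank
```

The fraction shrinks as layers grow, so the savings are dramatic for large models and modest for a model as small as this post's.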

Example training data format:

examples = [
    {
        "instruction": "Explain gradient descent in simple terms.",
        "response": "Gradient descent is an optimisation algorithm that adjusts model weights by moving in the direction that most reduces the loss function — like rolling a ball downhill to find the lowest point in a valley."
    },
    {
        "instruction": "Write a Python function to reverse a string.",
        "response": "```python\ndef reverse_string(s: str) -> str:\n    return s[::-1]\n```"
    },
]

Do / Don't

Do Align tone and format to your target use case. A customer support assistant needs different SFT data than a coding assistant. The format of your training data becomes the format of the model's outputs.
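Do Consider LoRA instead of full fine-tuning when GPU memory is tight. On large models it typically trains around 1% of parameters or less, often with little quality loss; on a model as small as this post's, full fine-tuning is usually affordable.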
Don't Skip SFT and go straight to RLHF. RLHF has no signal to work with if the base model doesn't produce roughly coherent responses yet.
Don't Compute loss on the instruction tokens. The model should learn to generate responses, not repeat prompts. Mask the instruction in the loss.
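A minimal sketch of masking the instruction: prompt positions get the label -100, which `F.cross_entropy` ignores by default, so the loss averages only over response tokens. The token ids and the `build_example` helper are hypothetical.

```python
import torch
import torch.nn.functional as F

IGNORE = -100   # F.cross_entropy's default ignore_index


def build_example(prompt_ids: list[int], response_ids: list[int]):
    """Inputs/targets for next-token prediction with the prompt masked out."""
    tokens = prompt_ids + response_ids
    labels = [IGNORE] * len(prompt_ids) + response_ids
    x = torch.tensor(tokens[:-1])
    y = torch.tensor(labels[1:])     # shifted by one, like the pre-training targets
    return x, y


# 3 prompt tokens, 2 response tokens
x, y = build_example([5, 6, 7], [8, 9])
print(x.tolist(), y.tolist())

torch.manual_seed(0)
logits = torch.randn(len(x), 10)
masked = F.cross_entropy(logits, y, ignore_index=IGNORE)
# Same value as averaging only over positions whose target is a response token
keep = y != IGNORE
manual = F.cross_entropy(logits[keep], y[keep])
print(torch.allclose(masked, manual))
```

Note that the last prompt position still contributes: its target is the first response token.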

Phase 7: Reward Modelling

What to Build

A separate model that takes a (prompt, response) pair and outputs a scalar reward score representing how much a human would prefer that response.

Why It Matters

RLHF needs a reward signal. You can't run a human in the loop for every training step. The reward model is a learned proxy for human preference — trained on human-annotated pairwise rankings, then used to score millions of responses during RL training.

The Code

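import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset


class RewardModel(nn.Module):
    """
    A transformer that maps (prompt + response) → scalar reward.
    Built on top of a pre-trained / SFT-fine-tuned backbone.
    """
    def __init__(self, backbone: TinyGPT):
        super().__init__()
        self.backbone = backbone
        d_model = backbone.token_emb.embedding_dim
        self.reward_head = nn.Linear(d_model, 1, bias=False)

    def hidden_states(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Run TinyGPT up to its final norm, skipping the vocabulary projection:
        # self.backbone(input_ids) would return logits, not hidden states
        bb = self.backbone
        positions = torch.arange(input_ids.size(1), device=input_ids.device)
        x = bb.token_emb(input_ids) + bb.pos_emb(positions)
        for block in bb.blocks:
            x = block(x)
        return bb.norm(x)                               # (B, T, d_model)

    def forward(self, input_ids: torch.Tensor,
                lengths: torch.Tensor | None = None) -> torch.Tensor:
        # Use the last real token's hidden state as the sequence representation
        hidden = self.hidden_states(input_ids)          # (B, T, d_model)
        if lengths is None:
            last_hidden = hidden[:, -1, :]              # (B, d_model)
        else:
            # Right-padded batch: pick position length - 1 in each row
            batch_idx = torch.arange(hidden.size(0), device=hidden.device)
            last_hidden = hidden[batch_idx, lengths - 1]
        return self.reward_head(last_hidden).squeeze(-1)  # (B,) scalar rewards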


class RewardDataset(Dataset):
    """Each example is a (prompt, chosen_response, rejected_response) triple."""
    def __init__(self, examples: list[dict], tokenizer, max_len: int = 512):
        self.pairs = []
        for ex in examples:
            chosen   = tokenizer.encode(ex["prompt"] + ex["chosen"])[:max_len]
            rejected = tokenizer.encode(ex["prompt"] + ex["rejected"])[:max_len]
            self.pairs.append((
                torch.tensor(chosen,   dtype=torch.long),
                torch.tensor(rejected, dtype=torch.long),
            ))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        return self.pairs[idx]


def reward_model_loss(reward_model: RewardModel,
                      chosen_ids: torch.Tensor,
                      rejected_ids: torch.Tensor) -> torch.Tensor:
    """
    Bradley-Terry pairwise preference loss:
    maximise P(chosen > rejected) = σ(r_chosen - r_rejected)
    """
    r_chosen   = reward_model(chosen_ids)
    r_rejected = reward_model(rejected_ids)
    # Negative log-likelihood of chosen being preferred
    return -F.logsigmoid(r_chosen - r_rejected).mean()


# Example human preference data
examples = [
    {
        "prompt":    "<|user|>\nWhat is 2+2?\n<|assistant|>\n",
        "chosen":    "4.",
        "rejected":  "I'm not sure, mathematics is complex.",
    },
    {
        "prompt":    "<|user|>\nExplain photosynthesis.\n<|assistant|>\n",
        "chosen":    "Photosynthesis converts sunlight, CO₂, and water into glucose and oxygen.",
        "rejected":  "Plants do something with light I think.",
    },
]

Do / Don't

Do Collect pairwise rankings (A vs B), not absolute scores. Humans are far more consistent at "which is better?" than "rate this 1–10."
Do Validate the reward model separately before using it in RLHF. Run it on a held-out preference set and check it agrees with humans ≥65% of the time.
Don't Assume one reward model fits all domains. A reward model trained on creative writing preferences will give unreliable scores for code correctness. Train, or at least validate, reward models per domain.
Don't Use the SFT model directly as the reward model without a separate reward head. They need different training objectives.
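Two quick numbers for validating a reward model: the Bradley-Terry loss equals log 2 ≈ 0.693 when the model can't tell chosen from rejected (chance level), and preference accuracy on a held-out set is what the ≥65% check measures. The rewards below are made-up values for illustration.

```python
import math
import torch
import torch.nn.functional as F


def bradley_terry_loss(r_chosen, r_rejected):
    # -log σ(r_chosen - r_rejected), averaged over pairs
    return -F.logsigmoid(r_chosen - r_rejected).mean()


def preference_accuracy(r_chosen, r_rejected):
    # Fraction of pairs where the reward model ranks the human-preferred response higher
    return (r_chosen > r_rejected).float().mean().item()


tie = bradley_terry_loss(torch.tensor([0.0]), torch.tensor([0.0]))
print(tie.item(), math.log(2))          # equal rewards: chance level

# Hypothetical held-out rewards for 4 preference pairs
r_c = torch.tensor([1.2, 0.3, -0.5, 2.0])
r_r = torch.tensor([0.1, 0.9, -1.0, 1.5])
print(preference_accuracy(r_c, r_r))    # 3 of 4 pairs ranked correctly
```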

Phase 8: RLHF with PPO

What to Build

Use Proximal Policy Optimisation (PPO) to fine-tune the SFT model using reward signals from the reward model — with a KL-divergence penalty to prevent the policy from drifting too far from the original SFT behaviour.

Why It Matters

This is the final alignment step. SFT teaches format. RLHF teaches preference — which format, tone, and content humans actually reward. ChatGPT, Claude, and Llama-2-chat all use this pipeline (or a variant like DPO/GRPO).
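The clipping in PPO is easiest to see on a handful of probability ratios with a positive advantage: once the ratio exceeds 1 + ε, the objective stops increasing, so the gradient pushing the policy further in that direction becomes zero. A minimal sketch of the surrogate term alone, not a full PPO step:

```python
import torch


def clipped_objective(ratio: torch.Tensor, advantage: torch.Tensor,
                      clip_eps: float = 0.2) -> torch.Tensor:
    # PPO surrogate (to be maximised): min(r * A, clip(r, 1 - eps, 1 + eps) * A)
    return torch.min(ratio * advantage,
                     ratio.clamp(1 - clip_eps, 1 + clip_eps) * advantage)


ratios = torch.tensor([0.5, 1.0, 1.1, 1.5], requires_grad=True)
obj = clipped_objective(ratios, advantage=torch.tensor(1.0))
obj.sum().backward()
print(obj.tolist())          # the gain is capped once the ratio passes 1 + eps
print(ratios.grad.tolist())  # no gradient beyond the clip for a positive advantage
```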

The Code

import torch
import torch.nn.functional as F


def response_logprobs(model, full_ids: torch.Tensor,
                      response_len: int) -> torch.Tensor:
    """Per-token log probs of the response tokens, conditioned on the prompt."""
    logits = model(full_ids)[:, :-1, :]            # position t predicts token t+1
    targets = full_ids[:, 1:]
    logprobs = F.log_softmax(logits, dim=-1).gather(
        -1, targets.unsqueeze(-1)
    ).squeeze(-1)                                  # (B, T_full - 1)
    return logprobs[:, -response_len:]             # (B, response_len)


def compute_ppo_loss(
    policy_logprobs:    torch.Tensor,   # (B, T_resp) log probs from the current policy
    old_logprobs:       torch.Tensor,   # (B, T_resp) log probs from the policy at sampling time
    rewards:            torch.Tensor,   # (B,) reward model scores
    ref_logprobs:       torch.Tensor,   # (B, T_resp) log probs from the frozen SFT reference model
    kl_coeff:           float = 0.1,    # KL penalty weight
    clip_eps:           float = 0.2,    # PPO clipping range
) -> torch.Tensor:

    # --- KL Penalty: penalise deviation from the SFT reference model ---
    # Uses the sampling-time log probs, so the penalty is part of a fixed reward
    kl_penalty = (old_logprobs - ref_logprobs).sum(dim=-1)   # (B,) sequence KL estimate
    shaped_reward = rewards - kl_coeff * kl_penalty

    # Simplified advantage: subtract the batch mean (a baseline) and normalise;
    # needs several prompts per batch. Full PPO-RLHF uses a value head + GAE instead.
    advantage = shaped_reward - shaped_reward.mean()
    if advantage.numel() > 1:
        advantage = advantage / (advantage.std() + 1e-8)

    # --- PPO Clipped Surrogate Objective (sequence-level ratio) ---
    log_ratio = policy_logprobs.sum(-1) - old_logprobs.sum(-1)
    ratio = log_ratio.exp()                        # probability ratio π_θ / π_old

    # Clip the ratio to [1-ε, 1+ε] to prevent too-large updates
    clipped_ratio = ratio.clamp(1 - clip_eps, 1 + clip_eps)
    ppo_loss = -torch.min(
        ratio * advantage,
        clipped_ratio * advantage,
    ).mean()

    return ppo_loss


def rlhf_training_step(
    policy_model,
    ref_model,          # frozen SFT model: provides the KL baseline
    reward_model,
    prompt_ids: torch.Tensor,
    optimizer,
    kl_coeff: float = 0.1,
    ppo_epochs: int = 4,   # optimisation passes over the same sampled batch
):
    with torch.no_grad():
        # 1. Sample a response from the current policy.
        #    Assumes a generate() method that returns only the new token ids.
        response_ids = policy_model.generate(prompt_ids, max_new_tokens=128)
        full_ids = torch.cat([prompt_ids, response_ids], dim=1)
        T_resp = response_ids.size(1)

        # 2. Score the full (prompt + response) sequence with the reward model
        rewards = reward_model(full_ids)

        # 3. Reference (SFT) log probs for the KL penalty
        ref_logprobs = response_logprobs(ref_model, full_ids, T_resp)

        # 4. Log probs of the policy that generated the sample (the "old" policy)
        old_logprobs = response_logprobs(policy_model, full_ids, T_resp)

    # 5-6. Several PPO updates on the same sample. On the first pass the ratio
    #      is 1 (no dropout assumed); clipping only matters on later passes.
    for _ in range(ppo_epochs):
        policy_logprobs = response_logprobs(policy_model, full_ids, T_resp)
        loss = compute_ppo_loss(
            policy_logprobs=policy_logprobs,
            old_logprobs=old_logprobs,
            rewards=rewards,
            ref_logprobs=ref_logprobs,
            kl_coeff=kl_coeff,
        )
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(policy_model.parameters(), 1.0)
        optimizer.step()

    kl_div = (old_logprobs - ref_logprobs).sum(-1).mean()   # sampled KL estimate, in nats
    return loss.item(), rewards.mean().item(), kl_div.item()

Monitoring What Matters

# The three numbers to watch during RLHF training:
ppo_loss, mean_reward, kl_div = rlhf_training_step(
    policy_model, ref_model, reward_model, prompt_ids, optimizer
)
print(f"PPO loss:      {ppo_loss:.4f}")      # noisy; read it alongside reward and KL
print(f"Mean reward:   {mean_reward:.4f}")    # should increase
print(f"KL divergence: {kl_div:.4f}")        # must stay bounded (< 4–6 nats)

# If KL divergence explodes: increase kl_coeff (0.1 → 0.2)
# If reward barely moves: decrease kl_coeff (0.1 → 0.05)
# If reward increases but outputs sound robotic: the reward model is being gamed

Do / Don't

Do Add the KL penalty. Without it, the policy will find responses that exploit the reward model, scoring high while producing garbled or unsafe text. This is called reward hacking.
Do Monitor real outputs alongside metrics. A reward score of +3.2 means nothing if the actual text is repetitive or evasive. Sample 20–30 responses every 100 steps and read them.
Don't Optimise only for reward score. The reward model is an imperfect proxy for human preference. PPO that maximises it completely will overfit to the reward model's blind spots.
Don't Start RLHF from the base pre-trained model. Always start from the SFT checkpoint. Without an instruction-following baseline, RLHF tends to diverge or stall.
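Why the per-token log-ratio between policy and reference log probs is called a KL *estimate*: averaged over tokens sampled from the policy, log π(token) − log π_ref(token) is an unbiased estimate of KL(π‖π_ref), even though individual per-token values can be negative. The sketch below checks this on two made-up categorical distributions over a 5-token vocabulary.

```python
import torch

torch.manual_seed(0)
# Two hypothetical next-token distributions
p = torch.softmax(torch.tensor([2.0, 1.0, 0.5, 0.0, -1.0]), dim=0)   # policy
q = torch.softmax(torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0]), dim=0)    # reference

exact_kl = (p * (p.log() - q.log())).sum()

# Sample tokens from the policy and average log p - log q over the samples
samples = torch.multinomial(p, num_samples=200_000, replacement=True)
estimate = (p.log() - q.log())[samples].mean()

print(f"exact {exact_kl.item():.4f}  sampled {estimate.item():.4f}")
```

This is why a single KL reading is noisy and is best tracked as a running average during training.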

The Complete Lifecycle at a Glance

┌────────────────────────────────────────────────────────────────┐
│  PHASE         GOAL                    KEY CODE                │
├────────────────────────────────────────────────────────────────┤
│  1. Core       Working transformer     MHA + causal mask       │
│  2. Train      Prove pipeline works    Cross-entropy + gen     │
│  3. Modernise  2026 architecture       RMSNorm/SwiGLU/RoPE     │
│  4. Scale      Handle real data        BPE + AMP + grad accum  │
│  5. MoE        Sparse computation      Router + load balance   │
│  6. SFT        Instruction following   Response-masked loss    │
│  7. Reward     Human preference proxy  Bradley-Terry loss      │
│  8. RLHF       Align to preferences    PPO + KL penalty        │
└────────────────────────────────────────────────────────────────┘

Start here:  Phase 1 → 2 (tiny Shakespeare dataset, CPU)
Add modern:  Phase 3 (swap components, verify loss unchanged)
Add scale:   Phase 4 (BPE, AMP, gradient accumulation)
Optional:    Phase 5 (MoE — only if you need parameter efficiency)
Align:       Phase 6 → 7 → 8 (SFT first, always)

Summary

Building an LLM from scratch is not one task — it's eight sequential phases, each with its own failure modes and validation criteria.

Phase 1 gives you the engine: a decoder-only transformer with causal masking, residual connections, and pre-norm. Phase 2 validates that it actually trains: check that the loss starts near log(vocab_size) and falls. Phase 3 modernises it to match production LLMs: RMSNorm, SwiGLU, RoPE, and a KV cache for fast inference. Phase 4 makes it trainable at real scale: BPE tokenisation, mixed precision, gradient accumulation, and frequent checkpointing. Phase 5 adds conditional computation via Mixture of Experts, which is only useful once the dense baseline is solid. Phase 6 is the transition from autocomplete to assistant: SFT with instruction-response pairs, response-masked loss, and LoRA for efficiency. Phase 7 trains a reward model that learns human preference from pairwise rankings. Phase 8 uses PPO with a KL penalty to align the policy to that reward model without drifting into reward hacking.

The universal lesson cutting across all eight phases: don't skip validation gates. Each phase assumes the previous one is working correctly. A broken causal mask in Phase 1 corrupts every phase that follows. A reward model that doesn't outperform chance in Phase 7 makes RLHF in Phase 8 meaningless. Build the smallest thing that proves correctness at each step, then move forward.

This is the backbone of every modern LLM. The architectures differ at the margins. The lifecycle does not.


Implementing one of these phases and hitting a specific wall? Drop a question in the comments — happy to dig in.

Questions or discussion? Connect on LinkedIn, X or reach out via email.
